In [42]:
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
from open_clip import tokenizer
from torch.utils.data import Subset
import numpy as np
import tqdm
import random
import wandb
import datetime
import open_clip

In [43]:
config = {
        
        "run_name":                     "CLIP-{}".format(datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")),  # A readable name for this run
        "device_id":                    1,      # GPU id
        "seed":                         42,     # Random seed
        
        "learning_rate":                1e-4,
        "batch_size":                   128,
        "epochs":                       1,
        "model":                        "RN50",
        
        "temperature":                  0.07,
        
        "loss_type":                    "anchor+lunif",   # anchor, anchor+lunif
        
        "num_train_samples":            -1,            # -1 for all
        "num_test_samples":             -1,            # -1 for all
        "evaluate_every_n_batches":     200,
        "visualize_every_n_batches":    10,
    }

In [44]:
def contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    """
    image_embeds: (batch_size, embed_dim)
    text_embeds: (batch_size, embed_dim)
    temperature: scalar float for scaling similarities
    returns: scalar loss (contrastive)
    """
    
    # Similarity matrix, shape (bs, bs)
    logits = image_embeds @ text_embeds.t()
    logits = logits / temperature

    # Targets are just the diagonal (i.e. 0->0, 1->1, ...)
    batch_size = image_embeds.size(0)
    target = torch.arange(batch_size, device=logits.device)

    # CE loss for image->text
    loss_i2t = F.cross_entropy(logits, target)
    # CE loss for text->image
    loss_t2i = F.cross_entropy(logits.t(), target)

    # Average the two directions
    return (loss_i2t + loss_t2i) / 2

In [45]:
def lunif_loss(x, t=2):
    # Compute pairwise distances between all embeddings
    sq_pdist = torch.pdist(x, p=2).pow(2)
    
    # Apply the uniformity loss formula
    return sq_pdist.mul(-t).exp().mean().log()

In [46]:
def compute_centroids(text_embeddings, visual_embeddings):
    """
    Computes the centroid for each pair of samples between text embeddings and visual embeddings
    by calculating the mean of the corresponding feature vectors across the two modalities.

    Parameters:
    - text_embeddings (torch.Tensor): Tensor of shape (batch_size1, feature_dim) representing text embeddings.
    - visual_embeddings (torch.Tensor): Tensor of shape (batch_size2, feature_dim) representing visual embeddings.

    Returns:
    - torch.Tensor: Tensor of shape (batch_size1, batch_size2, feature_dim) representing the centroid for each pair.
    """

    # Compute centroids by averaging text and visual embeddings
    # Expand the dimensions to allow pairwise computation
    text_expanded = text_embeddings.unsqueeze(1)  # Shape: [batch_size1, 1, feature_dim]
    visual_expanded = visual_embeddings.unsqueeze(0)  # Shape: [1, batch_size2, feature_dim]

    # Compute the centroid by averaging the embeddings
    centroids = (text_expanded + visual_expanded) / 2.0

    # Compute norms of the centroids
    centroid_norms = torch.norm(centroids, dim=-1)

    return centroid_norms, centroids


In [47]:
def compute_metric_ret(score_matrix, ids, ids_txt, direction='forward'):
    
    # Check that the score matrix has the correct shape
    assert score_matrix.shape == (len(ids_txt),len(ids))

    if direction == 'forward': ### text-to-vision retrieval
        indice_matrix = score_matrix.sort(dim=-1,descending=True)[1].tolist()
        rank = []
        for i in range(len(ids_txt)):
            # gt_indice = ids.index(ids_txt[i][0])
            gt_indice = ids.index(ids_txt[i])
            rank.append(indice_matrix[i].index(gt_indice))
        
        rank = torch.tensor(rank).to(score_matrix)
        
        vr_r1 = (rank < 1).sum().item() / len(ids_txt)
        vr_r5 = (rank < 5).sum().item() / len(ids_txt)
        vr_r10 = (rank < 10).sum().item() / len(ids_txt)
        v_medianR = torch.median(rank).item() +1
        v_meanR = torch.mean(rank).item() +1
 
        eval_log = {'forward_r1': round(vr_r1*100,3),
                    'forward_recall': f'{round(vr_r1*100,1)}/{round(vr_r5*100,1)}/{round(vr_r10*100,3)}',
                    'forward_ravg': round((vr_r1 + vr_r5 + vr_r10)/3 *100,3)
                   }
   
    else: ### vision-to-text retrieval
       
        indice_matrix = score_matrix.sort(dim=0,descending=True)[1].permute(1,0).tolist()
        rank = []
        for i in range(len(ids)):
            gt_indices=[]
            for idx, id in enumerate(ids_txt):
                if id == ids[i]:
                    gt_indices.append(idx)

            rank.append(min([indice_matrix[i].index(idx) for idx in gt_indices]))
        
        rank = torch.tensor(rank).to(score_matrix)
        
        tr_r1 = (rank < 1).sum().item() / len(ids)
        tr_r5 = (rank < 5).sum().item() / len(ids)
        tr_r10 = (rank < 10).sum().item() / len(ids)
        t_medianR = torch.median(rank).item() +1
        t_meanR = torch.mean(rank).item() +1

        eval_log = {
                    'backward_r1': round(tr_r1*100,3),
                    'backward_recall': f'{round(tr_r1*100,1)}/{round(tr_r5*100,1)}/{round(tr_r10*100,3)}',
                    'backward_ravg': round((tr_r1 + tr_r5 + tr_r10)/3 *100,3)
                  }
    
    return eval_log

In [48]:
def evaluate_model(model, test_loader, device):
    """
    Evaluate the (OpenCLIP) model on the given test_loader by computing
    text-to-image and image-to-text retrieval metrics.

    Args:
        model (nn.Module): The trained (DataParallel) model.
        test_loader (DataLoader): A DataLoader for the evaluation set.
        device (torch.device): The device (CPU or GPU).
    """
    
    # Put model into eval mode
    model.eval()
    
    # Prepare storage
    all_image_embeds = []
    all_text_embeds  = []
    
    # IDs for retrieval
    # We'll assign each sample a unique ID. Because your `collate_fn` is
    # picking exactly one caption per image, we can treat each batch entry
    # as a 1:1 mapping of (image_i <-> text_i).
    ids_img = []
    ids_txt = []
    
    current_index = 0

    # No gradient needed during evaluation
    with torch.no_grad():
        for images, captions_list in tqdm.tqdm(test_loader, desc="Evaluating"):
            # Move images to device
            images = images.to(device)

            # Tokenize captions
            text_tokens = tokenizer.tokenize(captions_list)
            text_tokens = text_tokens.to(device)

            # Extract embeddings using the .module references in DataParallel
            image_embeds = model.module.encode_image(images)
            text_embeds  = model.module.encode_text(text_tokens)

            # Move them to CPU for later concatenation
            image_embeds = image_embeds.cpu()
            text_embeds  = text_embeds.cpu()
            
            # Track
            bs = images.size(0)
            all_image_embeds.append(image_embeds)
            all_text_embeds.append(text_embeds)

            # For retrieval, we label these samples from current_index to current_index + bs - 1
            sample_ids = list(range(current_index, current_index + bs))
            ids_img.extend(sample_ids)
            ids_txt.extend(sample_ids)
            current_index += bs
    
    # Concatenate everything
    all_image_embeds = torch.cat(all_image_embeds, dim=0)  # shape [N, embed_dim]
    all_text_embeds  = torch.cat(all_text_embeds, dim=0)   # shape [N, embed_dim]

    # Normalize embeddings for more stable retrieval
    all_image_embeds = F.normalize(all_image_embeds, dim=-1)
    all_text_embeds  = F.normalize(all_text_embeds, dim=-1)

    # Compute pairwise similarity: [N_text, N_image]
    # Because we aligned IDs, this is effectively [N, N].
    similarity_matrix = all_text_embeds @ all_image_embeds.t()

    # Use the given function compute_metric_ret to compute retrieval metrics.
    # text->image: direction='forward'
    log_forward  = compute_metric_ret(similarity_matrix, ids_img, ids_txt, direction='forward')
    # image->text: direction='backward'
    log_backward = compute_metric_ret(similarity_matrix, ids_img, ids_txt, direction='backward')

    # You can combine or print them:
    final_log = {**log_forward, **log_backward}
    print("Evaluation Results:", final_log)

    return final_log

In [49]:
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import umap
import matplotlib.pyplot as plt


def visualize_embeddings(text_embeddings, vision_embeddings, 
                         sample_size=1000, method='pca', 
                         title="Embeddings Visualization",
                         save_path=None):
    """
    Visualizes text and vision embeddings in 2D or 3D using PCA, t-SNE, or UMAP.

    Args:
        text_embeddings (torch.Tensor): 
            Shape [N, D] containing text embeddings.
        vision_embeddings (torch.Tensor):
            Shape [N, D] containing vision/image embeddings.
        sample_size (int): 
            If the embeddings contain more than 'sample_size' samples, 
            randomly pick this many for faster plotting. Set -1 to use all.
        method (str): 
            "pca", "tsne", or "umap".
        title (str): 
            Title for the plot.
        save_path (str, optional): 
            If provided, saves the plot to this path instead of showing it.
    """
    # Detach from graph and bring to CPU if the tensors require grad
    text_np = text_embeddings.detach().cpu().numpy()
    vision_np = vision_embeddings.detach().cpu().numpy()

    # Optionally downsample for quicker plotting
    if sample_size != -1:
        n_text = text_np.shape[0]
        n_vision = vision_np.shape[0]

        n_samples = min(n_text, n_vision)

        if n_samples > sample_size:
            indices = np.random.choice(n_samples, size=sample_size, replace=False)
            text_np = text_np[indices]
            vision_np = vision_np[indices]

    # Combine for joint dimensionality reduction
    all_data = np.concatenate([text_np, vision_np], axis=0)

    # Apply dimensionality reduction
    if method.lower() == "pca":
        reducer = PCA(n_components=3)
        reduced = reducer.fit_transform(all_data)
    elif method.lower() == "tsne":
        reducer = TSNE(n_components=3, perplexity=30, max_iter=250, random_state=42)
        reduced = reducer.fit_transform(all_data)
    elif method.lower() == "umap":
        reducer = umap.UMAP(n_components=3, random_state=42, n_jobs=1)
        reduced = reducer.fit_transform(all_data)
    else:
        raise NotImplementedError("Only 'pca', 'tsne', and 'umap' are implemented.")

    # Split back into text and vision
    text_reduced = reduced[: len(text_np)]
    vision_reduced = reduced[len(text_np):]

    # Plot 3D visualization
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(text_reduced[:, 0], text_reduced[:, 1], text_reduced[:, 2], 
               c='red', alpha=0.6, label='Text')
    ax.scatter(vision_reduced[:, 0], vision_reduced[:, 1], vision_reduced[:, 2], 
               c='blue', alpha=0.6, label='Vision')

    ax.set_title(title)
    ax.set_xlabel("Component 1")
    ax.set_ylabel("Component 2")
    ax.set_zlabel("Component 3")
    ax.legend()

    if save_path is not None:
        plt.savefig(save_path, dpi=300)
        wandb.log({title: wandb.Image(plt)})
        plt.close()
    else:
        plt.show()


In [50]:
def train_model(config, train_loader, test_loader, device):

    # Create model & transforms from scratch (no pretrained weights) #TODO: Use the tokenizer from the chosen model, not the default one
    model, preprocess, _ = open_clip.create_model_and_transforms(
        config["model"],
        pretrained=None,
        device=device
    )

    # Put the model into training mode
    model.train()

    # If you want to fine-tune *everything* from scratch, ensure all parameters require grad:
    for param in model.parameters():
        param.requires_grad = True

    # Set up training parameters from the config
    lr = config["learning_rate"]
    epochs = config["epochs"]
    temperature = config["temperature"]

    # Move the model to multiple GPUs
    model = model.to(device)
    model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3])  # Use 4 GPUs

    optimizer = optim.AdamW(model.parameters(), lr=lr)

    current_batch = 0

    for epoch in range(epochs):
        for images, captions_list in tqdm.tqdm(train_loader):
            
            current_batch += 1
            
            # Move data to the primary device
            images = images.to(device)
            captions = captions_list

            # Tokenize text
            text_tokens = tokenizer.tokenize(captions)
            text_tokens = text_tokens.to(device)

            # Encode image and text
            image_embeds = model.module.encode_image(images)  # Use .module for methods inside DataParallel
            text_embeds = model.module.encode_text(text_tokens)
            
            # Normalize embeddings
            image_embeds = F.normalize(image_embeds, dim=-1)
            text_embeds  = F.normalize(text_embeds, dim=-1)
            
            # Compute loss based on the experiment type
            if config["loss_type"] == "anchor":
                loss = contrastive_loss(image_embeds, text_embeds, temperature=temperature)
            elif config["loss_type"] == "anchor+lunif":
                lunif_img = lunif_loss(image_embeds)
                lunif_txt = lunif_loss(text_embeds)
                lunif = (lunif_img + lunif_txt) / 2
                loss = contrastive_loss(image_embeds, text_embeds, temperature=temperature) + lunif
            elif config["loss_type"] == "lunif(50batch)+frozen(text_embed)":
                if current_batch <= 50:
                    lunif_img = lunif_loss(image_embeds)
                    lunif_txt = lunif_loss(text_embeds)
                    lunif = (lunif_img + lunif_txt) / 2
                    loss = lunif
                else: # train on anchor loss with frozen text embeddings
                    text_embeds = text_embeds.detach()
                    loss = contrastive_loss(image_embeds, text_embeds, temperature=temperature)
                    
            wandb.log({"train_loss": loss.item()})

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if current_batch % config["visualize_every_n_batches"] == 0:
                 visualize_embeddings(text_embeds, 
                     image_embeds, 
                     sample_size=1000, 
                     method='umap', 
                     title="CLIP Embeddings Visualization",
                     save_path="embeddings_plot.png")
            
            if current_batch % config["evaluate_every_n_batches"] == 0:
                print(f"[Epoch {epoch+1}/{epochs}]  Batch: {current_batch}  Loss: {loss.item():.5f}")
                print("Evaluating model...")
                test_results = evaluate_model(model, test_loader, device)
                
                wandb.log(test_results)

        print(f"[Epoch {epoch+1}/{epochs}]  Loss: {loss.item():.4f}")
    
    return model

In [51]:
def dataset_loader(config):

    # Path to train images and annotations
    train_image_dir = './data/coco/images/train2017/'                          # Path to train2017 images
    train_annotation_file = './data/coco/annotations/captions_train2017.json'  # Path to train2017 captions

    # Path to test (val) images and annotations
    test_image_dir = './data/coco/images/val2017/'                          # Path to val2017 images
    test_annotation_file = './data/coco/annotations/captions_val2017.json'  # Path to val2017 captions

    # Define the transform to be applied to the images
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize the image to the model's required input size
        transforms.ToTensor()
    ])

    # Create the training dataset
    train_coco = dset.CocoCaptions(
        root=train_image_dir,
        annFile=train_annotation_file,
        transform=transform
    )

    # Create the test dataset
    test_coco = dset.CocoCaptions(
        root=test_image_dir,
        annFile=test_annotation_file,
        transform=transform
    )
    
    if config["num_train_samples"] != -1:
        print(f"Subsetting the training dataset to {config['num_train_samples']} samples")
        # Subset the training dataset
        num_training_samples = config["num_train_samples"]
        subset_indices = list(range(num_training_samples))
        train_coco = Subset(train_coco, subset_indices)
    
    if config["num_test_samples"] != -1:
        print(f"Subsetting the test dataset to {config['num_test_samples']} samples")
        # Subset the test dataset
        num_test_samples = config["num_test_samples"]
        subset_indices = list(range(num_test_samples))
        test_coco = Subset(test_coco, subset_indices)

    # Every image has 5 captions at max, we need to sample one of them
    # Create collate function to sample one caption per image
    def collate_fn(batch):
        images, captions = zip(*batch)
        images = torch.stack(images, 0)
        sel_captions = []
        for list_captions in captions:
            caption = random.choice(list_captions)
            sel_captions.append(caption)
        return images, sel_captions

    # Create DataLoader
    batch_size = config["batch_size"]
    train_loader = DataLoader(train_coco, batch_size=batch_size, shuffle=True , drop_last=True, collate_fn=collate_fn, num_workers=12)
    test_loader  = DataLoader(test_coco , batch_size=batch_size, shuffle=False, drop_last=True, collate_fn=collate_fn, num_workers=12)
    
    return train_loader, test_loader

In [52]:
def set_seed(seed: int):
    random.seed(seed)  # Python random module
    np.random.seed(seed)  # NumPy random module
    torch.manual_seed(seed)  # PyTorch CPU random numbers
    torch.cuda.manual_seed(seed)  # PyTorch GPU random numbers for a single GPU
    torch.cuda.manual_seed_all(seed)  # PyTorch GPU random numbers for all GPUs
    torch.backends.cudnn.deterministic = True  # Ensure deterministic behavior for cuDNN
    torch.backends.cudnn.benchmark = False  # Disable benchmark for deterministic behavior

In [53]:
def main(config):
    # Set the seed for reproducibility
    set_seed(config["seed"])
    
    # Finish any existing W&B runs before starting a new one
    wandb.finish()

    # Initialize your W&B run
    wandb.init(project="sparsify-clip", config=config, name=config["run_name"])
    
    # Print the config
    print("Config:", config)
    
    # Set the device
    device_id = config["device_id"]
    device = torch.device("cuda:{}".format(device_id) if torch.cuda.is_available() else "cpu")
    
    # Load the dataset
    print("\nLoading the dataset...")
    train_loader, test_loader = dataset_loader(config)
    print("Dataset loaded.\n")
    
    # Train the model
    print("Training the model...")
    model = train_model(config, train_loader, test_loader, device)
    print("Training complete.\n")
    
    # Final evaluation of the model
    print("Final evaluation of the model...")
    final_log = evaluate_model(model, test_loader, device)
    print("Evaluation complete.\n")
    
    # Save the model and upload it to W&B
    #torch.save(model.state_dict(), config["run_name"] + ".pt")
    #wandb.save(config["run_name"] + ".pt")
    
    wandb.finish()

In [None]:
# %%prun

if __name__ == "__main__":
    
    # Minimum configuration (overwrites defaults)
    config["num_train_samples"] = -1
    config["num_test_samples"] = -1
    config["evaluate_every_n_batches"] = 50
    config["visualize_every_n_batches"] = 10
    config["epochs"] = 3
    
    
    # Baseline
    config["loss_type"] = "anchor"
    print("\nTraining Baseline model")
    main(config)
    
    
    # Anchor + Lunif
    config["loss_type"] = "anchor+lunif"
    print("\nTraining Anchor + Lunif model")
    main(config)
    
    # Lunif(50itr)+frozen(text_embed)
    # config["loss_type"] = "lunif(50batch)+frozen(text_embed)"
    # print("\nTraining Lunif(50itr)+frozen(text_embed) model")
    # main(config)


Training Baseline model


Config: {'run_name': 'CLIP-2025-01-02-17-52-04', 'device_id': 1, 'seed': 42, 'learning_rate': 0.0001, 'batch_size': 128, 'epochs': 3, 'model': 'RN50', 'temperature': 0.07, 'loss_type': 'anchor', 'num_train_samples': -1, 'num_test_samples': -1, 'evaluate_every_n_batches': 50, 'visualize_every_n_batches': 10}

Loading the dataset...
loading annotations into memory...
Done (t=0.67s)
creating index...
index created!
loading annotations into memory...
Done (t=0.04s)
creating index...
index created!
Dataset loaded.

Training the model...


  5%|▌         | 49/924 [00:31<07:49,  1.86it/s]

[Epoch 1/3]  Batch: 50  Loss: 4.59527
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.22it/s]
  5%|▌         | 50/924 [00:46<1:14:53,  5.14s/it]

Evaluation Results: {'forward_r1': 0.1, 'forward_recall': '0.1/0.4/0.721', 'forward_ravg': 0.401, 'backward_r1': 0.06, 'backward_recall': '0.1/0.3/0.681', 'backward_ravg': 0.341}


 11%|█         | 99/924 [01:15<07:12,  1.91it/s]  

[Epoch 1/3]  Batch: 100  Loss: 4.40957
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.08it/s]
 11%|█         | 100/924 [01:32<1:15:11,  5.48s/it]

Evaluation Results: {'forward_r1': 0.1, 'forward_recall': '0.1/0.5/0.942', 'forward_ravg': 0.514, 'backward_r1': 0.08, 'backward_recall': '0.1/0.5/1.042', 'backward_ravg': 0.554}


 12%|█▏        | 109/924 [01:36<09:38,  1.41it/s]  

In [41]:
""" %doctest_mode


def dataset_details():
    # Print dataset details
    print('Number of samples:', len(train_coco)) # 118287 images

    # Access a specific sample (4th sample here)
    img, target = train_coco[3]  # Load the 4th sample (index 3)

    # Display information about the sample
    print("Image Size:", img.size())  # Torch tensor size
    #plt.imshow(img.permute(1, 2, 0))  # Display the image
    print("Captions:", target)  # Captions for the image

for images, captions_list in train_loader:
    # images.shape is e.g. (N, 3, 224, 224)
    # captions_list has length N, but each item might be a tuple of possible captions

    plt.imshow(images[0].permute(1, 2, 0))
    plt.show()
    plt.imshow(images[1].permute(1, 2, 0))
    plt.show()

    print("Image batch size:", images.shape[0], "Shape:", images.shape)
    print("Captions list length:", len(captions_list))
    
    print("Captions list:", list(captions_list))

    print("Number of chosen captions:", len(list(captions_list[0])))
    
    captions = list(captions_list[0])

    # Then tokenize
    text_tokens = tokenizer.tokenize(captions)
    print("Text tokens shape:", text_tokens.shape)

    # Now encode
    #image_embeds = model.encode_image(images.to(device))
    #text_embeds = model.encode_text(text_tokens.to(device))

    # Should both be shape (N, D)
    #print("Image embeds shape:", image_embeds.shape)
    #print("Text  embeds shape:", text_embeds.shape)

    break  # just to test one batch
    

def collate_fn_debug(batch):
    print("Bath type:", type(batch)) # This is a list
    print("Batch size:", len(batch))
    print("Batch:", batch)
    images, captions = zip(*batch)
    
    print("Images type:", type(images))
    print("Images size:", len(images))
    print("Images:", images)
    
    print("Captions type:", type(captions))
    print("Captions size:", len(captions))
    print("Captions:", captions) # This is a tuple of lists, each list contains 5 captions for each image
    
    # Select one caption per image
    sel_captions = []
    for list_captions in captions:
        #print("List Captions:", list_captions)
        caption = random.choice(list_captions)
        sel_captions.append(caption)
    
    print("Selected Captions:", sel_captions)    



for images, captions_list in train_loader:
    break

# DONE: ensure that each tuple of captions has the same length, or the data loader will fail (defalut is collate(samples, collate_fn_map=collate_fn_map) from error message)

 """

' %doctest_mode\n\n\ndef dataset_details():\n    # Print dataset details\n    print(\'Number of samples:\', len(train_coco)) # 118287 images\n\n    # Access a specific sample (4th sample here)\n    img, target = train_coco[3]  # Load the 4th sample (index 3)\n\n    # Display information about the sample\n    print("Image Size:", img.size())  # Torch tensor size\n    #plt.imshow(img.permute(1, 2, 0))  # Display the image\n    print("Captions:", target)  # Captions for the image\n\nfor images, captions_list in train_loader:\n    # images.shape is e.g. (N, 3, 224, 224)\n    # captions_list has length N, but each item might be a tuple of possible captions\n\n    plt.imshow(images[0].permute(1, 2, 0))\n    plt.show()\n    plt.imshow(images[1].permute(1, 2, 0))\n    plt.show()\n\n    print("Image batch size:", images.shape[0], "Shape:", images.shape)\n    print("Captions list length:", len(captions_list))\n    \n    print("Captions list:", list(captions_list))\n\n    print("Number of chosen 