In [9]:
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
from open_clip import tokenizer
from torch.utils.data import Subset
import numpy as np
import tqdm
import random
import wandb
import datetime
import open_clip

In [10]:
config = {
        
        "run_name":                     "CLIP-{}".format(datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")),  # A readable name for this run
        "device_id":                    1,  # GPU id
        
        "learning_rate":                1e-4,
        "batch_size":                   128,
        "epochs":                       1,
        "model":                        "RN50",
        
        "temperature":                  0.07,
        
        "loss_type":                    "anchor+lunif",
        
        "num_train_samples":            -1,            # -1 for all
        "num_test_samples":             -1,            # -1 for all
        "evaluate_every_n_batches":     50,
    }

In [11]:
def contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    """
    image_embeds: (batch_size, embed_dim)
    text_embeds: (batch_size, embed_dim)
    temperature: scalar float for scaling similarities
    returns: scalar loss (contrastive)
    """
    # Normalize embeddings (optional, but typical in CLIP-like models)
    image_embeds = F.normalize(image_embeds, dim=-1)
    text_embeds  = F.normalize(text_embeds, dim=-1)
    
    # Similarity matrix, shape (bs, bs)
    logits = image_embeds @ text_embeds.t()
    logits = logits / temperature

    # Targets are just the diagonal (i.e. 0->0, 1->1, ...)
    batch_size = image_embeds.size(0)
    target = torch.arange(batch_size, device=logits.device)

    # CE loss for image->text
    loss_i2t = F.cross_entropy(logits, target)
    # CE loss for text->image
    loss_t2i = F.cross_entropy(logits.t(), target)

    # Average the two directions
    return (loss_i2t + loss_t2i) / 2

In [12]:
def lunif_loss(x, t=2):
    
    x = F.normalize(x, dim=-1)
        
    # Compute pairwise distances between all embeddings
    sq_pdist = torch.pdist(x, p=2).pow(2)
    
    # Apply the uniformity loss formula
    return sq_pdist.mul(-t).exp().mean().log()

In [13]:
def compute_centroids(text_embeddings, visual_embeddings):
    """
    Computes the centroid for each pair of samples between text embeddings and visual embeddings
    by calculating the mean of the corresponding feature vectors across the two modalities.

    Parameters:
    - text_embeddings (torch.Tensor): Tensor of shape (batch_size1, feature_dim) representing text embeddings.
    - visual_embeddings (torch.Tensor): Tensor of shape (batch_size2, feature_dim) representing visual embeddings.

    Returns:
    - torch.Tensor: Tensor of shape (batch_size1, batch_size2, feature_dim) representing the centroid for each pair.
    """

    # Get batch sizes
    batch_size1 = text_embeddings.shape[0]   # For text embeddings
    batch_size2 = visual_embeddings.shape[0]  # For visual embeddings

    # Compute centroids by averaging text and visual embeddings
    # Expand the dimensions to allow pairwise computation
    text_expanded = text_embeddings.unsqueeze(1)  # Shape: [batch_size1, 1, feature_dim]
    visual_expanded = visual_embeddings.unsqueeze(0)  # Shape: [1, batch_size2, feature_dim]

    # Compute the centroid by averaging the embeddings
    centroids = (text_expanded + visual_expanded) / 2.0

    # Compute norms of the centroids
    centroid_norms = torch.norm(centroids, dim=-1)

    return centroid_norms, centroids


In [14]:
def compute_metric_ret(score_matrix, ids, ids_txt, direction='forward'):
    
    print(len(ids_txt),len(ids))
    print(score_matrix.shape)
    assert score_matrix.shape == (len(ids_txt),len(ids))

    if direction == 'forward': ### text-to-vision retrieval
        indice_matrix = score_matrix.sort(dim=-1,descending=True)[1].tolist()
        rank = []
        for i in range(len(ids_txt)):
            # gt_indice = ids.index(ids_txt[i][0])
            gt_indice = ids.index(ids_txt[i])
            rank.append(indice_matrix[i].index(gt_indice))
        
        rank = torch.tensor(rank).to(score_matrix)
        
        vr_r1 = (rank < 1).sum().item() / len(ids_txt)
        vr_r5 = (rank < 5).sum().item() / len(ids_txt)
        vr_r10 = (rank < 10).sum().item() / len(ids_txt)
        v_medianR = torch.median(rank).item() +1
        v_meanR = torch.mean(rank).item() +1
 
        eval_log = {'forward_r1': round(vr_r1*100,3),
                    'forward_recall': f'{round(vr_r1*100,1)}/{round(vr_r5*100,1)}/{round(vr_r10*100,3)}',
                    'forward_ravg': round((vr_r1 + vr_r5 + vr_r10)/3 *100,3)
                   }
   
    else: ### vision-to-text retrieval
       
        indice_matrix = score_matrix.sort(dim=0,descending=True)[1].permute(1,0).tolist()
        rank = []
        for i in range(len(ids)):
            gt_indices=[]
            for idx, id in enumerate(ids_txt):
                if id == ids[i]:
                    gt_indices.append(idx)

            rank.append(min([indice_matrix[i].index(idx) for idx in gt_indices]))
        
        rank = torch.tensor(rank).to(score_matrix)
        
        tr_r1 = (rank < 1).sum().item() / len(ids)
        tr_r5 = (rank < 5).sum().item() / len(ids)
        tr_r10 = (rank < 10).sum().item() / len(ids)
        t_medianR = torch.median(rank).item() +1
        t_meanR = torch.mean(rank).item() +1

        eval_log = {
                    'backward_r1': round(tr_r1*100,3),
                    'backward_recall': f'{round(tr_r1*100,1)}/{round(tr_r5*100,1)}/{round(tr_r10*100,3)}',
                    'backward_ravg': round((tr_r1 + tr_r5 + tr_r10)/3 *100,3)
                  }
    

    return eval_log

In [15]:
def evaluate_model(model, test_loader, device):
    """
    Evaluate the (OpenCLIP) model on the given test_loader by computing
    text-to-image and image-to-text retrieval metrics.

    Args:
        model (nn.Module): The trained (DataParallel) model.
        test_loader (DataLoader): A DataLoader for the evaluation set.
        device (torch.device): The device (CPU or GPU).
    """
    
    # Put model into eval mode
    model.eval()
    
    # Prepare storage
    all_image_embeds = []
    all_text_embeds  = []
    
    # IDs for retrieval
    # We'll assign each sample a unique ID. Because your `collate_fn` is
    # picking exactly one caption per image, we can treat each batch entry
    # as a 1:1 mapping of (image_i <-> text_i).
    ids_img = []
    ids_txt = []
    
    current_index = 0

    # No gradient needed during evaluation
    with torch.no_grad():
        for images, captions_list in tqdm.tqdm(test_loader, desc="Evaluating"):
            # Move images to device
            images = images.to(device)

            # Tokenize captions
            text_tokens = tokenizer.tokenize(captions_list)
            text_tokens = text_tokens.to(device)

            # Extract embeddings using the .module references in DataParallel
            image_embeds = model.module.encode_image(images)
            text_embeds  = model.module.encode_text(text_tokens)

            # Move them to CPU for later concatenation
            image_embeds = image_embeds.cpu()
            text_embeds  = text_embeds.cpu()
            
            # Track
            bs = images.size(0)
            all_image_embeds.append(image_embeds)
            all_text_embeds.append(text_embeds)

            # For retrieval, we label these samples from current_index to current_index + bs - 1
            sample_ids = list(range(current_index, current_index + bs))
            ids_img.extend(sample_ids)
            ids_txt.extend(sample_ids)
            current_index += bs
    
    # Concatenate everything
    all_image_embeds = torch.cat(all_image_embeds, dim=0)  # shape [N, embed_dim]
    all_text_embeds  = torch.cat(all_text_embeds, dim=0)   # shape [N, embed_dim]

    # Normalize embeddings for more stable retrieval
    all_image_embeds = F.normalize(all_image_embeds, dim=-1)
    all_text_embeds  = F.normalize(all_text_embeds, dim=-1)

    # Compute pairwise similarity: [N_text, N_image]
    # Because we aligned IDs, this is effectively [N, N].
    similarity_matrix = all_text_embeds @ all_image_embeds.t()

    # Use the given function compute_metric_ret to compute retrieval metrics.
    # text->image: direction='forward'
    log_forward  = compute_metric_ret(similarity_matrix, ids_img, ids_txt, direction='forward')
    # image->text: direction='backward'
    log_backward = compute_metric_ret(similarity_matrix, ids_img, ids_txt, direction='backward')

    # You can combine or print them:
    final_log = {**log_forward, **log_backward}
    print("Evaluation Results:", final_log)

    return final_log

In [16]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # In case you want a 3D version
import numpy as np
from sklearn.decomposition import PCA


def visualize_embeddings(text_embeddings, vision_embeddings, 
                         sample_size=1000, method='pca', 
                         title="Embeddings Visualization",
                         save_path=None):
    """
    Visualizes text and vision embeddings in 2D using PCA or (optionally) t-SNE.

    Args:
        text_embeddings (torch.Tensor): 
            Shape [N, D] containing text embeddings.
        vision_embeddings (torch.Tensor):
            Shape [N, D] containing vision/image embeddings.
        sample_size (int): 
            If the embeddings contain more than 'sample_size' samples, 
            randomly pick this many for faster plotting. Set -1 to use all.
        method (str): 
            "pca" or "tsne". Currently only "pca" is implemented below for simplicity.
        title (str): 
            Title for the plot.
        save_path (str, optional): 
            If provided, saves the plot to this path instead of showing it.
    """
    # Detach from graph and bring to CPU if the tensors require grad
    text_np = text_embeddings.detach().cpu().numpy()
    vision_np = vision_embeddings.detach().cpu().numpy()
    
    # Optionally downsample for quicker plotting
    if sample_size != -1:
        n_text = text_np.shape[0]
        n_vision = vision_np.shape[0]
        
        # We assume text_np and vision_np have the same length. If not, adjust accordingly.
        # Use the minimum just in case.
        n_samples = min(n_text, n_vision)
        
        if n_samples > sample_size:
            indices = np.random.choice(n_samples, size=sample_size, replace=False)
            text_np = text_np[indices]
            vision_np = vision_np[indices]
    
    # Combine for joint dimensionality reduction
    all_data = np.concatenate([text_np, vision_np], axis=0)
    
    # Apply dimensionality reduction
    # Currently implementing PCA; you can implement t-SNE if you prefer
    if method.lower() == "pca":
        reducer = PCA(n_components=2)
        reduced = reducer.fit_transform(all_data)
    else:
        raise NotImplementedError("Only 'pca' is implemented in this example.")
    
    # Split back into text and vision
    text_reduced = reduced[: len(text_np)]
    vision_reduced = reduced[len(text_np) : ]

    # Plot
    plt.figure(figsize=(8, 6))
    plt.scatter(text_reduced[:, 0], text_reduced[:, 1], c='red', alpha=0.6, label='Text')
    plt.scatter(vision_reduced[:, 0], vision_reduced[:, 1], c='blue', alpha=0.6, label='Vision')
    
    plt.title(title)
    plt.xlabel("Component 1")
    plt.ylabel("Component 2")
    plt.legend()
    
    if save_path is not None:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Plot saved to {save_path}")
    else:
        plt.show()

ModuleNotFoundError: No module named 'sklearn'

In [8]:
def train_model(config, train_loader, test_loader, device):
    
    model_name = config["model"]   


    # Create model & transforms from scratch (no pretrained weights)
    model, preprocess, _ = open_clip.create_model_and_transforms(
        model_name,
        pretrained=None,
        device=device
    )

    # Put the model into training mode
    model.train()

    # If you want to fine-tune *everything* from scratch, ensure all parameters require grad:
    for param in model.parameters():
        param.requires_grad = True


    # Example config
    lr = config["learning_rate"]
    epochs = config["epochs"]
    temperature = config["temperature"]

    # Move the model to multiple GPUs
    model = model.to(device)
    model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3])  # Use 4 GPUs

    optimizer = optim.AdamW(model.parameters(), lr=lr)

    current_batch = 0

    for epoch in range(epochs):
        for images, captions_list in tqdm.tqdm(train_loader):
            
            current_batch += 1
            
            # Move data to the primary device
            images = images.to(device)
            captions = captions_list

            # Tokenize text
            text_tokens = tokenizer.tokenize(captions)
            text_tokens = text_tokens.to(device)

            # Encode image and text
            image_embeds = model.module.encode_image(images)  # Use .module for methods inside DataParallel
            text_embeds = model.module.encode_text(text_tokens)
            
            visualize_embeddings(text_embeds, 
                     vision_embeds, 
                     sample_size=1000, 
                     method='pca', 
                     title="CLIP Embeddings Visualization",
                     save_path="embeddings_plot.png")
            
            if config["loss_type"] == "anchor":
                loss = contrastive_loss(image_embeds, text_embeds, temperature=temperature)
            elif config["loss_type"] == "anchor+lunif":
                lunif_img = lunif_loss(image_embeds)
                lunif_txt = lunif_loss(text_embeds)
                lunif = (lunif_img + lunif_txt) / 2
                loss = contrastive_loss(image_embeds, text_embeds, temperature=temperature) + lunif / 2
                
            wandb.log({"train_loss": loss.item()})

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if current_batch % config["evaluate_every_n_batches"] == 0:
                print(f"[Epoch {epoch+1}/{epochs}]  Batch: {current_batch}  Loss: {loss.item():.5f}")
                print("Evaluating model...")
                test_results = evaluate_model(model, test_loader, device)
                
                wandb.log(test_results)

        print(f"[Epoch {epoch+1}/{epochs}]  Loss: {loss.item():.4f}")
    
    return model

In [9]:
def dataset_loader(config):

    # Path to train images and annotations
    train_image_dir = './coco/images/train2017/'                          # Path to train2017 images
    train_annotation_file = './coco/annotations/captions_train2017.json'  # Path to train2017 captions

    # Path to test (val) images and annotations
    test_image_dir = './coco/images/val2017/'                          # Path to val2017 images
    test_annotation_file = './coco/annotations/captions_val2017.json'  # Path to val2017 captions

    # Define the transform to be applied to the images
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # or whatever size your model expects
        transforms.ToTensor()
    ])

    # Create the training dataset
    train_coco = dset.CocoCaptions(
        root=train_image_dir,
        annFile=train_annotation_file,
        transform=transform
    )

    # Create the test dataset
    test_coco = dset.CocoCaptions(
        root=test_image_dir,
        annFile=test_annotation_file,
        transform=transform
    )
    
    if config["num_train_samples"] != -1:
        print(f"Subsetting the training dataset to {config['num_train_samples']} samples")

        # Subset the training dataset
        num_training_samples = config["num_train_samples"]
        subset_indices = list(range(num_training_samples))
        train_coco = Subset(train_coco, subset_indices)
    
    if config["num_test_samples"] != -1:
        print(f"Subsetting the test dataset to {config['num_test_samples']} samples")

        # Subset the test dataset
        num_test_samples = config["num_test_samples"]
        subset_indices = list(range(num_test_samples))
        test_coco = Subset(test_coco, subset_indices)

    # Every image has 5 captions at max, we need to sample one of them
    # Create collate function to sample one caption per image
    def collate_fn(batch):
        images, captions = zip(*batch)
        images = torch.stack(images, 0)
        sel_captions = []
        for list_captions in captions:
            caption = random.choice(list_captions)
            sel_captions.append(caption)
        return images, sel_captions

    # Create DataLoader
    batch_size = config["batch_size"]
    train_loader = DataLoader(train_coco, batch_size=batch_size, shuffle=True , drop_last=True, collate_fn=collate_fn, num_workers=12)
    test_loader  = DataLoader(test_coco , batch_size=batch_size, shuffle=False, drop_last=True, collate_fn=collate_fn, num_workers=12)
    
    return train_loader, test_loader

In [None]:
def set_seed(seed: int):
    random.seed(seed)  # Python random module
    np.random.seed(seed)  # NumPy random module
    torch.manual_seed(seed)  # PyTorch CPU random numbers
    torch.cuda.manual_seed(seed)  # PyTorch GPU random numbers for a single GPU
    torch.cuda.manual_seed_all(seed)  # PyTorch GPU random numbers for all GPUs
    torch.backends.cudnn.deterministic = True  # Ensure deterministic behavior for cuDNN
    torch.backends.cudnn.benchmark = False  # Disable benchmark for deterministic behavior

In [None]:
# %%prun

if __name__ == "__main__":
    
    set_seed(42)
    
    # Finish any existing W&B runs before starting a new one
    wandb.finish()

    # Initialize your W&B run
    wandb.init(project="sparsify-clip", config=config, name=config["run_name"])
    
    # Print the config
    print("Config:", config)
    
    # Set the device
    device_id = config["device_id"]
    device = torch.device("cuda:{}".format(device_id) if torch.cuda.is_available() else "cpu")

    # Load the dataset
    train_loader, test_loader = dataset_loader(config)
    
    # Train the model
    model = train_model(config, train_loader, test_loader, device)
    
    # Final evaluation of the model
    final_log = evaluate_model(model, test_loader, device)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: Currently logged in as: [33mnoostale[0m ([33mnoostale-organization[0m). Use [1m`wandb login --relogin`[0m to force relogin


Config: {'run_name': 'CLIP-2025-01-01-22-36-48', 'device_id': 1, 'learning_rate': 0.0001, 'batch_size': 128, 'epochs': 1, 'model': 'RN50', 'temperature': 0.07, 'loss_type': 'anchor+lunif', 'num_train_samples': -1, 'num_test_samples': -1, 'evaluate_every_n_batches': 50}
loading annotations into memory...
Done (t=0.64s)
creating index...
index created!
loading annotations into memory...
Done (t=0.04s)
creating index...
index created!


  5%|▌         | 49/924 [00:27<07:40,  1.90it/s]

[Epoch 1/1]  Batch: 50  Loss: 2.97053
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:08<00:00,  4.42it/s]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


  5%|▌         | 50/924 [00:41<1:06:59,  4.60s/it]

Evaluation Results: {'forward_r1': 0.12, 'forward_recall': '0.1/0.4/0.921', 'forward_ravg': 0.487, 'backward_r1': 0.14, 'backward_recall': '0.1/0.7/1.082', 'backward_ravg': 0.628}


 11%|█         | 99/924 [01:06<07:01,  1.96it/s]  

[Epoch 1/1]  Batch: 100  Loss: 2.65684
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:08<00:00,  4.39it/s]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 11%|█         | 100/924 [01:20<1:04:19,  4.68s/it]

Evaluation Results: {'forward_r1': 0.26, 'forward_recall': '0.3/1.0/1.963', 'forward_ravg': 1.075, 'backward_r1': 0.16, 'backward_recall': '0.2/0.7/1.342', 'backward_ravg': 0.741}


 16%|█▌        | 149/924 [01:45<06:48,  1.90it/s]  

[Epoch 1/1]  Batch: 150  Loss: 2.52721
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:08<00:00,  4.42it/s]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 16%|█▌        | 150/924 [02:00<59:55,  4.64s/it]

Evaluation Results: {'forward_r1': 0.22, 'forward_recall': '0.2/1.4/2.384', 'forward_ravg': 1.329, 'backward_r1': 0.341, 'backward_recall': '0.3/1.4/2.704', 'backward_ravg': 1.476}


 22%|██▏       | 199/924 [02:24<06:13,  1.94it/s]

[Epoch 1/1]  Batch: 200  Loss: 2.47699
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:08<00:00,  4.37it/s]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 22%|██▏       | 200/924 [02:39<56:52,  4.71s/it]

Evaluation Results: {'forward_r1': 0.521, 'forward_recall': '0.5/1.9/3.405', 'forward_ravg': 1.956, 'backward_r1': 0.3, 'backward_recall': '0.3/1.4/2.704', 'backward_ravg': 1.469}


 27%|██▋       | 249/924 [03:04<05:50,  1.93it/s]

[Epoch 1/1]  Batch: 250  Loss: 2.02027
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:08<00:00,  4.39it/s]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 27%|██▋       | 250/924 [03:19<52:45,  4.70s/it]

Evaluation Results: {'forward_r1': 0.521, 'forward_recall': '0.5/2.1/3.646', 'forward_ravg': 2.097, 'backward_r1': 0.561, 'backward_recall': '0.6/1.7/3.045', 'backward_ravg': 1.756}


 32%|███▏      | 299/924 [03:44<05:24,  1.93it/s]

[Epoch 1/1]  Batch: 300  Loss: 2.05095
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:08<00:00,  4.40it/s]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 32%|███▏      | 300/924 [03:58<48:26,  4.66s/it]

Evaluation Results: {'forward_r1': 0.661, 'forward_recall': '0.7/2.6/4.627', 'forward_ravg': 2.624, 'backward_r1': 0.501, 'backward_recall': '0.5/2.4/4.367', 'backward_ravg': 2.411}


 38%|███▊      | 349/924 [04:23<04:57,  1.93it/s]

[Epoch 1/1]  Batch: 350  Loss: 2.01733
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:08<00:00,  4.38it/s]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 38%|███▊      | 350/924 [04:37<44:22,  4.64s/it]

Evaluation Results: {'forward_r1': 0.521, 'forward_recall': '0.5/2.8/4.888', 'forward_ravg': 2.751, 'backward_r1': 0.601, 'backward_recall': '0.6/2.5/4.708', 'backward_ravg': 2.611}


 43%|████▎     | 399/924 [05:02<04:33,  1.92it/s]

[Epoch 1/1]  Batch: 400  Loss: 1.77863
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:08<00:00,  4.36it/s]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 43%|████▎     | 400/924 [05:17<41:21,  4.74s/it]

Evaluation Results: {'forward_r1': 0.561, 'forward_recall': '0.6/2.8/5.389', 'forward_ravg': 2.918, 'backward_r1': 0.741, 'backward_recall': '0.7/3.0/5.349', 'backward_ravg': 3.025}


 49%|████▊     | 449/924 [05:42<04:06,  1.93it/s]

[Epoch 1/1]  Batch: 450  Loss: 1.92751
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:08<00:00,  4.38it/s]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 49%|████▊     | 450/924 [05:57<37:34,  4.76s/it]

Evaluation Results: {'forward_r1': 0.721, 'forward_recall': '0.7/3.5/5.97', 'forward_ravg': 3.392, 'backward_r1': 0.741, 'backward_recall': '0.7/2.9/5.228', 'backward_ravg': 2.965}


 54%|█████▍    | 499/924 [06:22<03:41,  1.92it/s]

[Epoch 1/1]  Batch: 500  Loss: 1.96207
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:08<00:00,  4.41it/s]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 54%|█████▍    | 500/924 [06:36<33:10,  4.70s/it]

Evaluation Results: {'forward_r1': 0.861, 'forward_recall': '0.9/3.6/6.591', 'forward_ravg': 3.679, 'backward_r1': 0.761, 'backward_recall': '0.8/3.2/5.849', 'backward_ravg': 3.272}


 59%|█████▉    | 549/924 [07:01<03:15,  1.92it/s]

[Epoch 1/1]  Batch: 550  Loss: 1.76249
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:08<00:00,  4.36it/s]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 60%|█████▉    | 550/924 [07:16<29:33,  4.74s/it]

Evaluation Results: {'forward_r1': 0.801, 'forward_recall': '0.8/3.4/6.21', 'forward_ravg': 3.472, 'backward_r1': 0.861, 'backward_recall': '0.9/3.2/5.609', 'backward_ravg': 3.239}


 65%|██████▍   | 599/924 [07:41<02:48,  1.93it/s]

[Epoch 1/1]  Batch: 600  Loss: 1.68925
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:08<00:00,  4.45it/s]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 65%|██████▍   | 600/924 [07:55<24:40,  4.57s/it]

Evaluation Results: {'forward_r1': 1.042, 'forward_recall': '1.0/4.3/7.352', 'forward_ravg': 4.24, 'backward_r1': 1.062, 'backward_recall': '1.1/4.0/7.011', 'backward_ravg': 4.026}


 70%|███████   | 649/924 [08:20<02:21,  1.95it/s]

[Epoch 1/1]  Batch: 650  Loss: 1.57810
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:08<00:00,  4.40it/s]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 70%|███████   | 650/924 [08:34<21:14,  4.65s/it]

Evaluation Results: {'forward_r1': 1.262, 'forward_recall': '1.3/4.5/7.772', 'forward_ravg': 4.507, 'backward_r1': 1.182, 'backward_recall': '1.2/4.4/7.412', 'backward_ravg': 4.347}


 76%|███████▌  | 699/924 [08:59<01:56,  1.93it/s]

[Epoch 1/1]  Batch: 700  Loss: 1.56632
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:08<00:00,  4.42it/s]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 76%|███████▌  | 700/924 [09:13<17:06,  4.58s/it]

Evaluation Results: {'forward_r1': 1.162, 'forward_recall': '1.2/4.7/7.752', 'forward_ravg': 4.534, 'backward_r1': 1.102, 'backward_recall': '1.1/3.6/6.951', 'backward_ravg': 3.873}


 81%|████████  | 749/924 [09:38<01:30,  1.93it/s]

[Epoch 1/1]  Batch: 750  Loss: 1.35949
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:08<00:00,  4.42it/s]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 81%|████████  | 750/924 [09:53<13:33,  4.68s/it]

Evaluation Results: {'forward_r1': 1.262, 'forward_recall': '1.3/5.0/8.754', 'forward_ravg': 5.015, 'backward_r1': 1.122, 'backward_recall': '1.1/4.6/8.474', 'backward_ravg': 4.728}


 86%|████████▋ | 799/924 [10:18<01:04,  1.92it/s]

[Epoch 1/1]  Batch: 800  Loss: 1.28336
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:08<00:00,  4.40it/s]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 87%|████████▋ | 800/924 [10:32<09:48,  4.74s/it]

Evaluation Results: {'forward_r1': 1.402, 'forward_recall': '1.4/5.3/9.315', 'forward_ravg': 5.355, 'backward_r1': 1.262, 'backward_recall': '1.3/5.0/8.373', 'backward_ravg': 4.868}


 92%|█████████▏| 849/924 [10:58<00:38,  1.93it/s]

[Epoch 1/1]  Batch: 850  Loss: 1.39391
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:08<00:00,  4.37it/s]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 92%|█████████▏| 850/924 [11:12<05:47,  4.70s/it]

Evaluation Results: {'forward_r1': 1.382, 'forward_recall': '1.4/5.3/9.555', 'forward_ravg': 5.422, 'backward_r1': 1.302, 'backward_recall': '1.3/5.0/8.694', 'backward_ravg': 5.001}


 97%|█████████▋| 899/924 [11:37<00:12,  1.93it/s]

[Epoch 1/1]  Batch: 900  Loss: 1.47090
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:08<00:00,  4.43it/s]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 97%|█████████▋| 900/924 [11:51<01:51,  4.65s/it]

Evaluation Results: {'forward_r1': 1.242, 'forward_recall': '1.2/5.4/9.515', 'forward_ravg': 5.402, 'backward_r1': 1.522, 'backward_recall': '1.5/5.7/9.696', 'backward_ravg': 5.656}


100%|██████████| 924/924 [12:04<00:00,  1.28it/s]


[Epoch 1/1]  Loss: 1.1218


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.27it/s]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])
Evaluation Results: {'forward_r1': 1.502, 'forward_recall': '1.5/6.3/10.156', 'forward_ravg': 5.996, 'backward_r1': 1.482, 'backward_recall': '1.5/5.5/9.696', 'backward_ravg': 5.556}


In [11]:
""" %doctest_mode


def dataset_details():
    # Print dataset details
    print('Number of samples:', len(train_coco)) # 118287 images

    # Access a specific sample (4th sample here)
    img, target = train_coco[3]  # Load the 4th sample (index 3)

    # Display information about the sample
    print("Image Size:", img.size())  # Torch tensor size
    #plt.imshow(img.permute(1, 2, 0))  # Display the image
    print("Captions:", target)  # Captions for the image

for images, captions_list in train_loader:
    # images.shape is e.g. (N, 3, 224, 224)
    # captions_list has length N, but each item might be a tuple of possible captions

    plt.imshow(images[0].permute(1, 2, 0))
    plt.show()
    plt.imshow(images[1].permute(1, 2, 0))
    plt.show()

    print("Image batch size:", images.shape[0], "Shape:", images.shape)
    print("Captions list length:", len(captions_list))
    
    print("Captions list:", list(captions_list))

    print("Number of chosen captions:", len(list(captions_list[0])))
    
    captions = list(captions_list[0])

    # Then tokenize
    text_tokens = tokenizer.tokenize(captions)
    print("Text tokens shape:", text_tokens.shape)

    # Now encode
    #image_embeds = model.encode_image(images.to(device))
    #text_embeds = model.encode_text(text_tokens.to(device))

    # Should both be shape (N, D)
    #print("Image embeds shape:", image_embeds.shape)
    #print("Text  embeds shape:", text_embeds.shape)

    break  # just to test one batch
    

def collate_fn_debug(batch):
    print("Bath type:", type(batch)) # This is a list
    print("Batch size:", len(batch))
    print("Batch:", batch)
    images, captions = zip(*batch)
    
    print("Images type:", type(images))
    print("Images size:", len(images))
    print("Images:", images)
    
    print("Captions type:", type(captions))
    print("Captions size:", len(captions))
    print("Captions:", captions) # This is a tuple of lists, each list contains 5 captions for each image
    
    # Select one caption per image
    sel_captions = []
    for list_captions in captions:
        #print("List Captions:", list_captions)
        caption = random.choice(list_captions)
        sel_captions.append(caption)
    
    print("Selected Captions:", sel_captions)    



for images, captions_list in train_loader:
    break

# DONE: ensure that each tuple of captions has the same length, or the data loader will fail (defalut is collate(samples, collate_fn_map=collate_fn_map) from error message)

 """

' %doctest_mode\n\n\ndef dataset_details():\n    # Print dataset details\n    print(\'Number of samples:\', len(train_coco)) # 118287 images\n\n    # Access a specific sample (4th sample here)\n    img, target = train_coco[3]  # Load the 4th sample (index 3)\n\n    # Display information about the sample\n    print("Image Size:", img.size())  # Torch tensor size\n    #plt.imshow(img.permute(1, 2, 0))  # Display the image\n    print("Captions:", target)  # Captions for the image\n\nfor images, captions_list in train_loader:\n    # images.shape is e.g. (N, 3, 224, 224)\n    # captions_list has length N, but each item might be a tuple of possible captions\n\n    plt.imshow(images[0].permute(1, 2, 0))\n    plt.show()\n    plt.imshow(images[1].permute(1, 2, 0))\n    plt.show()\n\n    print("Image batch size:", images.shape[0], "Shape:", images.shape)\n    print("Captions list length:", len(captions_list))\n    \n    print("Captions list:", list(captions_list))\n\n    print("Number of chosen 