In [9]:
config = {
        "learning_rate": 1e-4,
        "batch_size": 128,
        "epochs": 1,
        "model": "RN50",
        
        "temperature": 0.07,
        
        "num_train_samples": 50000,
        "num_test_samples": 5000,
    }

In [10]:
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from open_clip import tokenizer
from torch.utils.data import Subset
import tqdm
import random
import wandb
import datetime
import open_clip

wandb.finish()

# Initialize your W&B run
wandb.init(
    project="sparsify-clip",  # The name of your project
    config=config,
    name="CLIP-{}".format(datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")),  # A readable name for this run
)

# Path to train images and annotations
train_image_dir = './coco/images/train2017/'                          # Path to train2017 images
train_annotation_file = './coco/annotations/captions_train2017.json'  # Path to train2017 captions

# Path to test (val) images and annotations
test_image_dir = './coco/images/val2017/'                          # Path to val2017 images
test_annotation_file = './coco/annotations/captions_val2017.json'  # Path to val2017 captions

# Define the transform to be applied to the images
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # or whatever size your model expects
    transforms.ToTensor()
])

# Create the training dataset
train_coco = dset.CocoCaptions(
    root=train_image_dir,
    annFile=train_annotation_file,
    transform=transform
)

# Create the test dataset
test_coco = dset.CocoCaptions(
    root=test_image_dir,
    annFile=test_annotation_file,
    transform=transform
)

# Subset the training dataset
num_training_samples = config["num_train_samples"]
subset_indices = list(range(num_training_samples))
train_coco = Subset(train_coco, subset_indices)

# Subset the test dataset
num_test_samples = config["num_test_samples"]
subset_indices = list(range(num_test_samples))
test_coco = Subset(test_coco, subset_indices)

# Every image has 5 captions at max, we need to sample one of them
# Create collate function to sample one caption per image
def collate_fn(batch):
    images, captions = zip(*batch)
    images = torch.stack(images, 0)
    sel_captions = []
    for list_captions in captions:
        caption = random.choice(list_captions)
        sel_captions.append(caption)
    return images, sel_captions

# Create DataLoader
batch_size = config["batch_size"]
train_loader = DataLoader(train_coco, batch_size=batch_size, shuffle=True , drop_last=True, collate_fn=collate_fn)
test_loader  = DataLoader(test_coco , batch_size=batch_size, shuffle=False, drop_last=True, collate_fn=collate_fn)

0,1
backward_r1,▁▁▁▁
backward_ravg,▁▁█▃
forward_r1,▁▁▁▁
forward_ravg,▁▁▁█
train_loss,▄▄█▁▆▄▃▃▃▅▇▇▇▇▇▇▇▇▇▇▆▇▆▆▆▆▆▆▆▆▆▅▇▇▇▇▇▇▇▆

0,1
backward_r1,0.02
backward_ravg,0.127
backward_recall,0.0/0.1/0.22
forward_r1,0.02
forward_ravg,0.113
forward_recall,0.0/0.1/0.22
train_loss,1.29571


loading annotations into memory...
Done (t=0.62s)
creating index...
index created!
loading annotations into memory...
Done (t=0.03s)
creating index...
index created!


In [11]:
def contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    """
    image_embeds: (batch_size, embed_dim)
    text_embeds: (batch_size, embed_dim)
    temperature: scalar float for scaling similarities
    returns: scalar loss (contrastive)
    """
    # Normalize embeddings (optional, but typical in CLIP-like models)
    image_embeds = F.normalize(image_embeds, dim=-1)
    text_embeds  = F.normalize(text_embeds, dim=-1)
    
    # Similarity matrix, shape (bs, bs)
    logits = image_embeds @ text_embeds.t()
    logits = logits / temperature

    # Targets are just the diagonal (i.e. 0->0, 1->1, ...)
    batch_size = image_embeds.size(0)
    target = torch.arange(batch_size, device=logits.device)

    # CE loss for image->text
    loss_i2t = F.cross_entropy(logits, target)
    # CE loss for text->image
    loss_t2i = F.cross_entropy(logits.t(), target)

    # Average the two directions
    return (loss_i2t + loss_t2i) / 2

In [12]:
def lunif_loss(x, t=2):
    
    x = F.normalize(x, dim=-1)
        
    # Compute pairwise distances between all embeddings
    sq_pdist = torch.pdist(x, p=2).pow(2)
    
    # Apply the uniformity loss formula
    return sq_pdist.mul(-t).exp().mean().log()

In [13]:
def compute_centroids(text_embeddings, visual_embeddings):
    """
    Computes the centroid for each pair of samples between text embeddings and visual embeddings
    by calculating the mean of the corresponding feature vectors across the two modalities.

    Parameters:
    - text_embeddings (torch.Tensor): Tensor of shape (batch_size1, feature_dim) representing text embeddings.
    - visual_embeddings (torch.Tensor): Tensor of shape (batch_size2, feature_dim) representing visual embeddings.

    Returns:
    - torch.Tensor: Tensor of shape (batch_size1, batch_size2, feature_dim) representing the centroid for each pair.
    """

    # Get batch sizes
    batch_size1 = text_embeddings.shape[0]   # For text embeddings
    batch_size2 = visual_embeddings.shape[0]  # For visual embeddings

    # Compute centroids by averaging text and visual embeddings
    # Expand the dimensions to allow pairwise computation
    text_expanded = text_embeddings.unsqueeze(1)  # Shape: [batch_size1, 1, feature_dim]
    visual_expanded = visual_embeddings.unsqueeze(0)  # Shape: [1, batch_size2, feature_dim]

    # Compute the centroid by averaging the embeddings
    centroids = (text_expanded + visual_expanded) / 2.0

    # Compute norms of the centroids
    centroid_norms = torch.norm(centroids, dim=-1)

    return centroid_norms, centroids


In [14]:
def compute_metric_ret(score_matrix, ids, ids_txt, direction='forward'):
    
    print(len(ids_txt),len(ids))
    print(score_matrix.shape)
    assert score_matrix.shape == (len(ids_txt),len(ids))

    if direction == 'forward': ### text-to-vision retrieval
        indice_matrix = score_matrix.sort(dim=-1,descending=True)[1].tolist()
        rank = []
        for i in range(len(ids_txt)):
            # gt_indice = ids.index(ids_txt[i][0])
            gt_indice = ids.index(ids_txt[i])
            rank.append(indice_matrix[i].index(gt_indice))
        
        rank = torch.tensor(rank).to(score_matrix)
        
        vr_r1 = (rank < 1).sum().item() / len(ids_txt)
        vr_r5 = (rank < 5).sum().item() / len(ids_txt)
        vr_r10 = (rank < 10).sum().item() / len(ids_txt)
        v_medianR = torch.median(rank).item() +1
        v_meanR = torch.mean(rank).item() +1
 
        eval_log = {'forward_r1': round(vr_r1*100,3),
                    'forward_recall': f'{round(vr_r1*100,1)}/{round(vr_r5*100,1)}/{round(vr_r10*100,3)}',
                    'forward_ravg': round((vr_r1 + vr_r5 + vr_r10)/3 *100,3)
                   }
   
    else: ### vision-to-text retrieval
       
        indice_matrix = score_matrix.sort(dim=0,descending=True)[1].permute(1,0).tolist()
        rank = []
        for i in range(len(ids)):
            gt_indices=[]
            for idx, id in enumerate(ids_txt):
                if id == ids[i]:
                    gt_indices.append(idx)

            rank.append(min([indice_matrix[i].index(idx) for idx in gt_indices]))
        
        rank = torch.tensor(rank).to(score_matrix)
        
        tr_r1 = (rank < 1).sum().item() / len(ids)
        tr_r5 = (rank < 5).sum().item() / len(ids)
        tr_r10 = (rank < 10).sum().item() / len(ids)
        t_medianR = torch.median(rank).item() +1
        t_meanR = torch.mean(rank).item() +1

        eval_log = {
                    'backward_r1': round(tr_r1*100,3),
                    'backward_recall': f'{round(tr_r1*100,1)}/{round(tr_r5*100,1)}/{round(tr_r10*100,3)}',
                    'backward_ravg': round((tr_r1 + tr_r5 + tr_r10)/3 *100,3)
                  }
    

    return eval_log

In [15]:
import numpy as np

def evaluate_model(model, test_loader, device):
    """
    Evaluate the (OpenCLIP) model on the given test_loader by computing
    text-to-image and image-to-text retrieval metrics.

    Args:
        model (nn.Module): The trained (DataParallel) model.
        test_loader (DataLoader): A DataLoader for the evaluation set.
        device (torch.device): The device (CPU or GPU).
    """

    # Put model into eval mode
    model.eval()
    
    # Prepare storage
    all_image_embeds = []
    all_text_embeds  = []
    
    # IDs for retrieval
    # We'll assign each sample a unique ID. Because your `collate_fn` is
    # picking exactly one caption per image, we can treat each batch entry
    # as a 1:1 mapping of (image_i <-> text_i).
    ids_img = []
    ids_txt = []
    
    current_index = 0

    # No gradient needed during evaluation
    with torch.no_grad():
        for images, captions_list in tqdm.tqdm(test_loader, desc="Evaluating"):
            # Move images to device
            images = images.to(device)

            # Tokenize captions
            text_tokens = tokenizer.tokenize(captions_list)
            text_tokens = text_tokens.to(device)

            # Extract embeddings using the .module references in DataParallel
            image_embeds = model.module.encode_image(images)
            text_embeds  = model.module.encode_text(text_tokens)

            # Move them to CPU for later concatenation
            image_embeds = image_embeds.cpu()
            text_embeds  = text_embeds.cpu()
            
            # Track
            bs = images.size(0)
            all_image_embeds.append(image_embeds)
            all_text_embeds.append(text_embeds)

            # For retrieval, we label these samples from current_index to current_index + bs - 1
            sample_ids = list(range(current_index, current_index + bs))
            ids_img.extend(sample_ids)
            ids_txt.extend(sample_ids)
            current_index += bs
    
    # Concatenate everything
    all_image_embeds = torch.cat(all_image_embeds, dim=0)  # shape [N, embed_dim]
    all_text_embeds  = torch.cat(all_text_embeds, dim=0)   # shape [N, embed_dim]

    # Normalize embeddings for more stable retrieval
    all_image_embeds = F.normalize(all_image_embeds, dim=-1)
    all_text_embeds  = F.normalize(all_text_embeds, dim=-1)

    # Compute pairwise similarity: [N_text, N_image]
    # Because we aligned IDs, this is effectively [N, N].
    similarity_matrix = all_text_embeds @ all_image_embeds.t()

    # Use the given function compute_metric_ret to compute retrieval metrics.
    # text->image: direction='forward'
    log_forward  = compute_metric_ret(similarity_matrix, ids_img, ids_txt, direction='forward')
    # image->text: direction='backward'
    log_backward = compute_metric_ret(similarity_matrix, ids_img, ids_txt, direction='backward')

    # You can combine or print them:
    final_log = {**log_forward, **log_backward}
    print("Evaluation Results:", final_log)

    return final_log

In [16]:
# %%prun to profile and see where the time is spent


#model_name = "ViT-B-32"        # Example architecture
model_name = config["model"]    # Example architecture
device = "cuda" if torch.cuda.is_available() else "cpu"
#device = "cpu"

# Create model & transforms from scratch (no pretrained weights)
model, preprocess, _ = open_clip.create_model_and_transforms(
    model_name,
    pretrained=None,
    device=device
)

# Put the model into training mode
model.train()

# If you want to fine-tune *everything* from scratch, ensure all parameters require grad:
for param in model.parameters():
    param.requires_grad = True


# Example config
lr = config["learning_rate"]
epochs = config["epochs"]
temperature = config["temperature"]

# Move the model to multiple GPUs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3])  # Use 4 GPUs

optimizer = optim.AdamW(model.parameters(), lr=lr)

current_batch = 0

for epoch in range(epochs):
    for images, captions_list in tqdm.tqdm(train_loader):
        
        current_batch += 1
        
        # Move data to the primary device
        images = images.to(device)
        captions = captions_list

        # Tokenize text
        text_tokens = tokenizer.tokenize(captions)
        text_tokens = text_tokens.to(device)

        # Encode image and text
        image_embeds = model.module.encode_image(images)  # Use .module for methods inside DataParallel
        text_embeds = model.module.encode_text(text_tokens)
        
        
        # Coompute the lunif loss
        lunif_img = lunif_loss(image_embeds)
        lunif_txt = lunif_loss(text_embeds)
        lunif = (lunif_img + lunif_txt) / 2
        

        # Compute the contrastive loss
        loss = contrastive_loss(image_embeds, text_embeds, temperature=temperature) + lunif / 2
        wandb.log({"train_loss": loss.item()})

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if current_batch % 20 == 0:
            print(f"[Epoch {epoch+1}/{epochs}]  Batch: {current_batch}  Loss: {loss.item():.5f}")
            test_results = evaluate_model(model, test_loader, device)
            
            wandb.log(test_results)

    print(f"[Epoch {epoch+1}/{epochs}]  Loss: {loss.item():.4f}")


  5%|▍         | 19/390 [00:22<07:10,  1.16s/it]

[Epoch 1/1]  Batch: 20  Loss: 3.48998


Evaluating: 100%|██████████| 39/39 [00:43<00:00,  1.11s/it]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


  5%|▌         | 20/390 [01:11<1:35:13, 15.44s/it]

Evaluation Results: {'forward_r1': 0.04, 'forward_recall': '0.0/0.1/0.14', 'forward_ravg': 0.08, 'backward_r1': 0.02, 'backward_recall': '0.0/0.1/0.22', 'backward_ravg': 0.114}


 10%|█         | 39/390 [01:35<07:33,  1.29s/it]  

[Epoch 1/1]  Batch: 40  Loss: 3.90224


Evaluating: 100%|██████████| 39/39 [00:43<00:00,  1.12s/it]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 10%|█         | 40/390 [02:24<1:31:41, 15.72s/it]

Evaluation Results: {'forward_r1': 0.02, 'forward_recall': '0.0/0.3/0.521', 'forward_ravg': 0.267, 'backward_r1': 0.04, 'backward_recall': '0.0/0.2/0.501', 'backward_ravg': 0.254}


 15%|█▌        | 59/390 [02:47<06:27,  1.17s/it]  

[Epoch 1/1]  Batch: 60  Loss: 3.51593


Evaluating: 100%|██████████| 39/39 [00:43<00:00,  1.12s/it]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 15%|█▌        | 60/390 [03:36<1:25:21, 15.52s/it]

Evaluation Results: {'forward_r1': 0.1, 'forward_recall': '0.1/0.3/0.601', 'forward_ravg': 0.347, 'backward_r1': 0.08, 'backward_recall': '0.1/0.3/0.541', 'backward_ravg': 0.3}


 20%|██        | 79/390 [03:57<05:59,  1.16s/it]  

[Epoch 1/1]  Batch: 80  Loss: 3.45159


Evaluating: 100%|██████████| 39/39 [00:44<00:00,  1.13s/it]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 21%|██        | 80/390 [04:47<1:20:46, 15.63s/it]

Evaluation Results: {'forward_r1': 0.08, 'forward_recall': '0.1/0.6/0.942', 'forward_ravg': 0.534, 'backward_r1': 0.12, 'backward_recall': '0.1/0.4/0.701', 'backward_ravg': 0.407}


 25%|██▌       | 99/390 [05:08<05:35,  1.15s/it]  

[Epoch 1/1]  Batch: 100  Loss: 3.31263


Evaluating: 100%|██████████| 39/39 [00:43<00:00,  1.11s/it]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 26%|██▌       | 100/390 [05:57<1:14:33, 15.43s/it]

Evaluation Results: {'forward_r1': 0.1, 'forward_recall': '0.1/0.6/0.921', 'forward_ravg': 0.528, 'backward_r1': 0.06, 'backward_recall': '0.1/0.4/0.781', 'backward_ravg': 0.414}


 31%|███       | 119/390 [06:19<05:13,  1.16s/it]  

[Epoch 1/1]  Batch: 120  Loss: 3.08761


Evaluating: 100%|██████████| 39/39 [00:43<00:00,  1.11s/it]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 31%|███       | 120/390 [07:08<1:09:05, 15.36s/it]

Evaluation Results: {'forward_r1': 0.02, 'forward_recall': '0.0/0.7/1.182', 'forward_ravg': 0.641, 'backward_r1': 0.1, 'backward_recall': '0.1/0.6/0.921', 'backward_ravg': 0.534}


 36%|███▌      | 139/390 [07:30<04:54,  1.17s/it]  

[Epoch 1/1]  Batch: 140  Loss: 3.07114


Evaluating: 100%|██████████| 39/39 [00:43<00:00,  1.12s/it]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 36%|███▌      | 140/390 [08:19<1:05:01, 15.61s/it]

Evaluation Results: {'forward_r1': 0.18, 'forward_recall': '0.2/0.7/1.142', 'forward_ravg': 0.674, 'backward_r1': 0.18, 'backward_recall': '0.2/0.5/0.841', 'backward_ravg': 0.507}


 41%|████      | 159/390 [08:41<04:31,  1.18s/it]  

[Epoch 1/1]  Batch: 160  Loss: 2.97673


Evaluating: 100%|██████████| 39/39 [00:43<00:00,  1.13s/it]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 41%|████      | 160/390 [09:30<59:43, 15.58s/it]

Evaluation Results: {'forward_r1': 0.08, 'forward_recall': '0.1/0.6/1.242', 'forward_ravg': 0.648, 'backward_r1': 0.04, 'backward_recall': '0.0/0.7/1.242', 'backward_ravg': 0.654}


 46%|████▌     | 179/390 [09:52<04:03,  1.16s/it]

[Epoch 1/1]  Batch: 180  Loss: 2.93536


Evaluating: 100%|██████████| 39/39 [00:43<00:00,  1.11s/it]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 46%|████▌     | 180/390 [10:41<54:17, 15.51s/it]

Evaluation Results: {'forward_r1': 0.12, 'forward_recall': '0.1/0.7/1.382', 'forward_ravg': 0.741, 'backward_r1': 0.12, 'backward_recall': '0.1/0.5/1.062', 'backward_ravg': 0.554}


 51%|█████     | 199/390 [11:03<03:44,  1.17s/it]

[Epoch 1/1]  Batch: 200  Loss: 3.06189


Evaluating: 100%|██████████| 39/39 [00:43<00:00,  1.11s/it]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 51%|█████▏    | 200/390 [11:52<48:53, 15.44s/it]

Evaluation Results: {'forward_r1': 0.14, 'forward_recall': '0.1/0.9/1.663', 'forward_ravg': 0.901, 'backward_r1': 0.2, 'backward_recall': '0.2/0.5/1.082', 'backward_ravg': 0.601}


 56%|█████▌    | 219/390 [12:14<03:17,  1.15s/it]

[Epoch 1/1]  Batch: 220  Loss: 2.74465


Evaluating: 100%|██████████| 39/39 [00:45<00:00,  1.17s/it]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 56%|█████▋    | 220/390 [13:05<45:46, 16.15s/it]

Evaluation Results: {'forward_r1': 0.16, 'forward_recall': '0.2/0.8/1.402', 'forward_ravg': 0.781, 'backward_r1': 0.2, 'backward_recall': '0.2/0.9/1.663', 'backward_ravg': 0.915}


 61%|██████▏   | 239/390 [13:27<02:59,  1.19s/it]

[Epoch 1/1]  Batch: 240  Loss: 2.77369


Evaluating: 100%|██████████| 39/39 [00:43<00:00,  1.12s/it]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 62%|██████▏   | 240/390 [14:16<38:46, 15.51s/it]

Evaluation Results: {'forward_r1': 0.3, 'forward_recall': '0.3/1.0/1.683', 'forward_ravg': 0.982, 'backward_r1': 0.18, 'backward_recall': '0.2/0.8/1.623', 'backward_ravg': 0.868}


 66%|██████▋   | 259/390 [14:38<02:36,  1.19s/it]

[Epoch 1/1]  Batch: 260  Loss: 2.67248


Evaluating: 100%|██████████| 39/39 [00:43<00:00,  1.11s/it]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 67%|██████▋   | 260/390 [15:27<33:30, 15.47s/it]

Evaluation Results: {'forward_r1': 0.18, 'forward_recall': '0.2/1.0/1.803', 'forward_ravg': 1.002, 'backward_r1': 0.26, 'backward_recall': '0.3/0.9/1.863', 'backward_ravg': 1.022}


 72%|███████▏  | 279/390 [15:49<02:08,  1.16s/it]

[Epoch 1/1]  Batch: 280  Loss: 2.79094


Evaluating: 100%|██████████| 39/39 [00:44<00:00,  1.15s/it]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 72%|███████▏  | 280/390 [16:39<28:57, 15.79s/it]

Evaluation Results: {'forward_r1': 0.16, 'forward_recall': '0.2/1.1/1.903', 'forward_ravg': 1.042, 'backward_r1': 0.26, 'backward_recall': '0.3/0.9/1.863', 'backward_ravg': 1.015}


 77%|███████▋  | 299/390 [17:01<01:45,  1.16s/it]

[Epoch 1/1]  Batch: 300  Loss: 2.48632


Evaluating: 100%|██████████| 39/39 [00:44<00:00,  1.14s/it]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 77%|███████▋  | 300/390 [17:51<23:34, 15.72s/it]

Evaluation Results: {'forward_r1': 0.321, 'forward_recall': '0.3/1.3/2.444', 'forward_ravg': 1.349, 'backward_r1': 0.321, 'backward_recall': '0.3/1.0/1.843', 'backward_ravg': 1.062}


 82%|████████▏ | 319/390 [18:13<01:24,  1.19s/it]

[Epoch 1/1]  Batch: 320  Loss: 2.51419


Evaluating: 100%|██████████| 39/39 [00:44<00:00,  1.14s/it]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 82%|████████▏ | 320/390 [19:03<18:25, 15.79s/it]

Evaluation Results: {'forward_r1': 0.3, 'forward_recall': '0.3/1.4/2.143', 'forward_ravg': 1.289, 'backward_r1': 0.22, 'backward_recall': '0.2/1.4/2.284', 'backward_ravg': 1.309}


 87%|████████▋ | 339/390 [20:07<04:00,  4.71s/it]

[Epoch 1/1]  Batch: 340  Loss: 2.48775


Evaluating: 100%|██████████| 39/39 [02:05<00:00,  3.21s/it]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 87%|████████▋ | 340/390 [22:22<36:24, 43.69s/it]

Evaluation Results: {'forward_r1': 0.28, 'forward_recall': '0.3/1.4/2.544', 'forward_ravg': 1.409, 'backward_r1': 0.3, 'backward_recall': '0.3/1.4/2.544', 'backward_ravg': 1.409}


 92%|█████████▏| 359/390 [22:48<00:42,  1.36s/it]

[Epoch 1/1]  Batch: 360  Loss: 2.44322


Evaluating: 100%|██████████| 39/39 [01:21<00:00,  2.10s/it]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 92%|█████████▏| 360/390 [24:16<13:35, 27.19s/it]

Evaluation Results: {'forward_r1': 0.341, 'forward_recall': '0.3/1.7/3.285', 'forward_ravg': 1.776, 'backward_r1': 0.361, 'backward_recall': '0.4/1.6/2.945', 'backward_ravg': 1.636}


 97%|█████████▋| 379/390 [24:38<00:12,  1.17s/it]

[Epoch 1/1]  Batch: 380  Loss: 2.36717


Evaluating: 100%|██████████| 39/39 [00:44<00:00,  1.15s/it]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])


 97%|█████████▋| 380/390 [25:28<02:38, 15.81s/it]

Evaluation Results: {'forward_r1': 0.401, 'forward_recall': '0.4/1.7/3.285', 'forward_ravg': 1.81, 'backward_r1': 0.361, 'backward_recall': '0.4/1.6/3.025', 'backward_ravg': 1.649}


100%|██████████| 390/390 [25:39<00:00,  3.95s/it]


[Epoch 1/1]  Loss: 2.2905


In [17]:
final_log = evaluate_model(model, test_loader, device)

# 3) final_log now contains your forward/backward R@1, R@5, R@10 metrics


Evaluating: 100%|██████████| 39/39 [00:43<00:00,  1.12s/it]


4992 4992
torch.Size([4992, 4992])
4992 4992
torch.Size([4992, 4992])
Evaluation Results: {'forward_r1': 0.421, 'forward_recall': '0.4/1.6/2.945', 'forward_ravg': 1.669, 'backward_r1': 0.341, 'backward_recall': '0.3/1.2/2.444', 'backward_ravg': 1.329}


In [18]:
""" %doctest_mode


def dataset_details():
    # Print dataset details
    print('Number of samples:', len(train_coco)) # 118287 images

    # Access a specific sample (4th sample here)
    img, target = train_coco[3]  # Load the 4th sample (index 3)

    # Display information about the sample
    print("Image Size:", img.size())  # Torch tensor size
    #plt.imshow(img.permute(1, 2, 0))  # Display the image
    print("Captions:", target)  # Captions for the image

for images, captions_list in train_loader:
    # images.shape is e.g. (N, 3, 224, 224)
    # captions_list has length N, but each item might be a tuple of possible captions

    plt.imshow(images[0].permute(1, 2, 0))
    plt.show()
    plt.imshow(images[1].permute(1, 2, 0))
    plt.show()

    print("Image batch size:", images.shape[0], "Shape:", images.shape)
    print("Captions list length:", len(captions_list))
    
    print("Captions list:", list(captions_list))

    print("Number of chosen captions:", len(list(captions_list[0])))
    
    captions = list(captions_list[0])

    # Then tokenize
    text_tokens = tokenizer.tokenize(captions)
    print("Text tokens shape:", text_tokens.shape)

    # Now encode
    #image_embeds = model.encode_image(images.to(device))
    #text_embeds = model.encode_text(text_tokens.to(device))

    # Should both be shape (N, D)
    #print("Image embeds shape:", image_embeds.shape)
    #print("Text  embeds shape:", text_embeds.shape)

    break  # just to test one batch
    

def collate_fn_debug(batch):
    print("Bath type:", type(batch)) # This is a list
    print("Batch size:", len(batch))
    print("Batch:", batch)
    images, captions = zip(*batch)
    
    print("Images type:", type(images))
    print("Images size:", len(images))
    print("Images:", images)
    
    print("Captions type:", type(captions))
    print("Captions size:", len(captions))
    print("Captions:", captions) # This is a tuple of lists, each list contains 5 captions for each image
    
    # Select one caption per image
    sel_captions = []
    for list_captions in captions:
        #print("List Captions:", list_captions)
        caption = random.choice(list_captions)
        sel_captions.append(caption)
    
    print("Selected Captions:", sel_captions)    



for images, captions_list in train_loader:
    break

# DONE: ensure that each tuple of captions has the same length, or the data loader will fail (defalut is collate(samples, collate_fn_map=collate_fn_map) from error message)

 """

' %doctest_mode\n\n\ndef dataset_details():\n    # Print dataset details\n    print(\'Number of samples:\', len(train_coco)) # 118287 images\n\n    # Access a specific sample (4th sample here)\n    img, target = train_coco[3]  # Load the 4th sample (index 3)\n\n    # Display information about the sample\n    print("Image Size:", img.size())  # Torch tensor size\n    #plt.imshow(img.permute(1, 2, 0))  # Display the image\n    print("Captions:", target)  # Captions for the image\n\nfor images, captions_list in train_loader:\n    # images.shape is e.g. (N, 3, 224, 224)\n    # captions_list has length N, but each item might be a tuple of possible captions\n\n    plt.imshow(images[0].permute(1, 2, 0))\n    plt.show()\n    plt.imshow(images[1].permute(1, 2, 0))\n    plt.show()\n\n    print("Image batch size:", images.shape[0], "Shape:", images.shape)\n    print("Captions list length:", len(captions_list))\n    \n    print("Captions list:", list(captions_list))\n\n    print("Number of chosen 