In [14]:
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
from open_clip import tokenizer
from torch.utils.data import Subset
import numpy as np
import tqdm
import random
import wandb
import datetime
import open_clip

In [15]:
config = {
        
        "run_name":                     "CLIP-{}".format(datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")),  # A readable name for this run
        "device_id":                    1,      # GPU id
        "seed":                         42,     # Random seed
        
        "learning_rate":                1e-4,
        "batch_size":                   128,
        "epochs":                       1,
        "model":                        "RN50",
        
        "temperature":                  0.07,
        
        "loss_type":                    "anchor+lunif",   # anchor, anchor+lunif
        
        "num_train_samples":            -1,            # -1 for all
        "num_test_samples":             -1,            # -1 for all
        "evaluate_every_n_batches":     200,
    }

In [16]:
def contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    """
    image_embeds: (batch_size, embed_dim)
    text_embeds: (batch_size, embed_dim)
    temperature: scalar float for scaling similarities
    returns: scalar loss (contrastive)
    """
    
    # Similarity matrix, shape (bs, bs)
    logits = image_embeds @ text_embeds.t()
    logits = logits / temperature

    # Targets are just the diagonal (i.e. 0->0, 1->1, ...)
    batch_size = image_embeds.size(0)
    target = torch.arange(batch_size, device=logits.device)

    # CE loss for image->text
    loss_i2t = F.cross_entropy(logits, target)
    # CE loss for text->image
    loss_t2i = F.cross_entropy(logits.t(), target)

    # Average the two directions
    return (loss_i2t + loss_t2i) / 2

In [17]:
def lunif_loss(x, t=2):
    # Compute pairwise distances between all embeddings
    sq_pdist = torch.pdist(x, p=2).pow(2)
    
    # Apply the uniformity loss formula
    return sq_pdist.mul(-t).exp().mean().log()

In [18]:
def compute_centroids(text_embeddings, visual_embeddings):
    """
    Computes the centroid for each pair of samples between text embeddings and visual embeddings
    by calculating the mean of the corresponding feature vectors across the two modalities.

    Parameters:
    - text_embeddings (torch.Tensor): Tensor of shape (batch_size1, feature_dim) representing text embeddings.
    - visual_embeddings (torch.Tensor): Tensor of shape (batch_size2, feature_dim) representing visual embeddings.

    Returns:
    - torch.Tensor: Tensor of shape (batch_size1, batch_size2, feature_dim) representing the centroid for each pair.
    """

    # Compute centroids by averaging text and visual embeddings
    # Expand the dimensions to allow pairwise computation
    text_expanded = text_embeddings.unsqueeze(1)  # Shape: [batch_size1, 1, feature_dim]
    visual_expanded = visual_embeddings.unsqueeze(0)  # Shape: [1, batch_size2, feature_dim]

    # Compute the centroid by averaging the embeddings
    centroids = (text_expanded + visual_expanded) / 2.0

    # Compute norms of the centroids
    centroid_norms = torch.norm(centroids, dim=-1)

    return centroid_norms, centroids


In [19]:
def compute_metric_ret(score_matrix, ids, ids_txt, direction='forward'):
    
    # Check that the score matrix has the correct shape
    assert score_matrix.shape == (len(ids_txt),len(ids))

    if direction == 'forward': ### text-to-vision retrieval
        indice_matrix = score_matrix.sort(dim=-1,descending=True)[1].tolist()
        rank = []
        for i in range(len(ids_txt)):
            # gt_indice = ids.index(ids_txt[i][0])
            gt_indice = ids.index(ids_txt[i])
            rank.append(indice_matrix[i].index(gt_indice))
        
        rank = torch.tensor(rank).to(score_matrix)
        
        vr_r1 = (rank < 1).sum().item() / len(ids_txt)
        vr_r5 = (rank < 5).sum().item() / len(ids_txt)
        vr_r10 = (rank < 10).sum().item() / len(ids_txt)
        v_medianR = torch.median(rank).item() +1
        v_meanR = torch.mean(rank).item() +1
 
        eval_log = {'forward_r1': round(vr_r1*100,3),
                    'forward_recall': f'{round(vr_r1*100,1)}/{round(vr_r5*100,1)}/{round(vr_r10*100,3)}',
                    'forward_ravg': round((vr_r1 + vr_r5 + vr_r10)/3 *100,3)
                   }
   
    else: ### vision-to-text retrieval
       
        indice_matrix = score_matrix.sort(dim=0,descending=True)[1].permute(1,0).tolist()
        rank = []
        for i in range(len(ids)):
            gt_indices=[]
            for idx, id in enumerate(ids_txt):
                if id == ids[i]:
                    gt_indices.append(idx)

            rank.append(min([indice_matrix[i].index(idx) for idx in gt_indices]))
        
        rank = torch.tensor(rank).to(score_matrix)
        
        tr_r1 = (rank < 1).sum().item() / len(ids)
        tr_r5 = (rank < 5).sum().item() / len(ids)
        tr_r10 = (rank < 10).sum().item() / len(ids)
        t_medianR = torch.median(rank).item() +1
        t_meanR = torch.mean(rank).item() +1

        eval_log = {
                    'backward_r1': round(tr_r1*100,3),
                    'backward_recall': f'{round(tr_r1*100,1)}/{round(tr_r5*100,1)}/{round(tr_r10*100,3)}',
                    'backward_ravg': round((tr_r1 + tr_r5 + tr_r10)/3 *100,3)
                  }
    
    return eval_log

In [20]:
def evaluate_model(model, test_loader, device):
    """
    Evaluate the (OpenCLIP) model on the given test_loader by computing
    text-to-image and image-to-text retrieval metrics.

    Args:
        model (nn.Module): The trained (DataParallel) model.
        test_loader (DataLoader): A DataLoader for the evaluation set.
        device (torch.device): The device (CPU or GPU).
    """
    
    # Put model into eval mode
    model.eval()
    
    # Prepare storage
    all_image_embeds = []
    all_text_embeds  = []
    
    # IDs for retrieval
    # We'll assign each sample a unique ID. Because your `collate_fn` is
    # picking exactly one caption per image, we can treat each batch entry
    # as a 1:1 mapping of (image_i <-> text_i).
    ids_img = []
    ids_txt = []
    
    current_index = 0

    # No gradient needed during evaluation
    with torch.no_grad():
        for images, captions_list in tqdm.tqdm(test_loader, desc="Evaluating"):
            # Move images to device
            images = images.to(device)

            # Tokenize captions
            text_tokens = tokenizer.tokenize(captions_list)
            text_tokens = text_tokens.to(device)

            # Extract embeddings using the .module references in DataParallel
            image_embeds = model.module.encode_image(images)
            text_embeds  = model.module.encode_text(text_tokens)

            # Move them to CPU for later concatenation
            image_embeds = image_embeds.cpu()
            text_embeds  = text_embeds.cpu()
            
            # Track
            bs = images.size(0)
            all_image_embeds.append(image_embeds)
            all_text_embeds.append(text_embeds)

            # For retrieval, we label these samples from current_index to current_index + bs - 1
            sample_ids = list(range(current_index, current_index + bs))
            ids_img.extend(sample_ids)
            ids_txt.extend(sample_ids)
            current_index += bs
    
    # Concatenate everything
    all_image_embeds = torch.cat(all_image_embeds, dim=0)  # shape [N, embed_dim]
    all_text_embeds  = torch.cat(all_text_embeds, dim=0)   # shape [N, embed_dim]

    # Normalize embeddings for more stable retrieval
    all_image_embeds = F.normalize(all_image_embeds, dim=-1)
    all_text_embeds  = F.normalize(all_text_embeds, dim=-1)

    # Compute pairwise similarity: [N_text, N_image]
    # Because we aligned IDs, this is effectively [N, N].
    similarity_matrix = all_text_embeds @ all_image_embeds.t()

    # Use the given function compute_metric_ret to compute retrieval metrics.
    # text->image: direction='forward'
    log_forward  = compute_metric_ret(similarity_matrix, ids_img, ids_txt, direction='forward')
    # image->text: direction='backward'
    log_backward = compute_metric_ret(similarity_matrix, ids_img, ids_txt, direction='backward')

    # You can combine or print them:
    final_log = {**log_forward, **log_backward}
    print("Evaluation Results:", final_log)

    return final_log

In [21]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from sklearn.decomposition import PCA

def visualize_embeddings(text_embeddings, vision_embeddings, 
                         sample_size=1000, method='pca', 
                         title="Embeddings Visualization",
                         save_path=None):
    """
    Visualizes text and vision embeddings in 2D or 3D using PCA.

    Args:
        text_embeddings (torch.Tensor): 
            Shape [N, D] containing text embeddings.
        vision_embeddings (torch.Tensor):
            Shape [N, D] containing vision/image embeddings.
        sample_size (int): 
            If the embeddings contain more than 'sample_size' samples, 
            randomly pick this many for faster plotting. Set -1 to use all.
        method (str): 
            "pca". Currently only "pca" is implemented below for simplicity.
        title (str): 
            Title for the plot.
        save_path (str, optional): 
            If provided, saves the plot to this path instead of showing it.
    """
    # Detach from graph and bring to CPU if the tensors require grad
    text_np = text_embeddings.detach().cpu().numpy()
    vision_np = vision_embeddings.detach().cpu().numpy()

    # Optionally downsample for quicker plotting
    if sample_size != -1:
        n_text = text_np.shape[0]
        n_vision = vision_np.shape[0]

        n_samples = min(n_text, n_vision)

        if n_samples > sample_size:
            indices = np.random.choice(n_samples, size=sample_size, replace=False)
            text_np = text_np[indices]
            vision_np = vision_np[indices]

    # Combine for joint dimensionality reduction
    all_data = np.concatenate([text_np, vision_np], axis=0)

    # Apply dimensionality reduction
    if method.lower() == "pca":
        reducer = PCA(n_components=3)
        reduced = reducer.fit_transform(all_data)
    else:
        raise NotImplementedError("Only 'pca' is implemented in this example.")

    # Split back into text and vision
    text_reduced = reduced[: len(text_np)]
    vision_reduced = reduced[len(text_np):]

    # Plot 3D visualization
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(text_reduced[:, 0], text_reduced[:, 1], text_reduced[:, 2], 
               c='red', alpha=0.6, label='Text')
    ax.scatter(vision_reduced[:, 0], vision_reduced[:, 1], vision_reduced[:, 2], 
               c='blue', alpha=0.6, label='Vision')

    ax.set_title(title)
    ax.set_xlabel("Component 1")
    ax.set_ylabel("Component 2")
    ax.set_zlabel("Component 3")
    ax.legend()

    if save_path is not None:
        plt.savefig(save_path, dpi=300)
        wandb.log({title: wandb.Image(plt)})
        plt.close()
    else:
        plt.show()


In [22]:
def train_model(config, train_loader, test_loader, device):

    # Create model & transforms from scratch (no pretrained weights)
    model, preprocess, _ = open_clip.create_model_and_transforms(
        config["model"],
        pretrained=None,
        device=device
    )

    # Put the model into training mode
    model.train()

    # If you want to fine-tune *everything* from scratch, ensure all parameters require grad:
    for param in model.parameters():
        param.requires_grad = True

    # Set up training parameters from the config
    lr = config["learning_rate"]
    epochs = config["epochs"]
    temperature = config["temperature"]

    # Move the model to multiple GPUs
    model = model.to(device)
    model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3])  # Use 4 GPUs

    optimizer = optim.AdamW(model.parameters(), lr=lr)

    current_batch = 0

    for epoch in range(epochs):
        for images, captions_list in tqdm.tqdm(train_loader):
            
            current_batch += 1
            
            # Move data to the primary device
            images = images.to(device)
            captions = captions_list

            # Tokenize text
            text_tokens = tokenizer.tokenize(captions)
            text_tokens = text_tokens.to(device)

            # Encode image and text
            image_embeds = model.module.encode_image(images)  # Use .module for methods inside DataParallel
            text_embeds = model.module.encode_text(text_tokens)
            
            # Normalize embeddings
            image_embeds = F.normalize(image_embeds, dim=-1)
            text_embeds  = F.normalize(text_embeds, dim=-1)
            
            
            if config["loss_type"] == "anchor":
                loss = contrastive_loss(image_embeds, text_embeds, temperature=temperature)
            elif config["loss_type"] == "anchor+lunif":
                lunif_img = lunif_loss(image_embeds)
                lunif_txt = lunif_loss(text_embeds)
                lunif = (lunif_img + lunif_txt) / 2
                loss = contrastive_loss(image_embeds, text_embeds, temperature=temperature) + lunif
            elif config["loss_type"] == "lunif(50batch)+frozen(text_embed)":
                if current_batch <= 50:
                    lunif_img = lunif_loss(image_embeds)
                    lunif_txt = lunif_loss(text_embeds)
                    lunif = (lunif_img + lunif_txt) / 2
                    loss = lunif
                else: # train on anchor loss with frozen text embeddings
                    text_embeds = text_embeds.detach()
                    loss = contrastive_loss(image_embeds, text_embeds, temperature=temperature)
                    
            wandb.log({"train_loss": loss.item()})

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if current_batch % 50 == 0:
                 visualize_embeddings(text_embeds, 
                     image_embeds, 
                     sample_size=1000, 
                     method='pca', 
                     title="CLIP Embeddings Visualization",
                     save_path="embeddings_plot.png")
            
            if current_batch % config["evaluate_every_n_batches"] == 0:
                print(f"[Epoch {epoch+1}/{epochs}]  Batch: {current_batch}  Loss: {loss.item():.5f}")
                print("Evaluating model...")
                test_results = evaluate_model(model, test_loader, device)
                
                wandb.log(test_results)

        print(f"[Epoch {epoch+1}/{epochs}]  Loss: {loss.item():.4f}")
    
    return model

In [23]:
def dataset_loader(config):

    # Path to train images and annotations
    train_image_dir = './data/coco/images/train2017/'                          # Path to train2017 images
    train_annotation_file = './data/coco/annotations/captions_train2017.json'  # Path to train2017 captions

    # Path to test (val) images and annotations
    test_image_dir = './data/coco/images/val2017/'                          # Path to val2017 images
    test_annotation_file = './data/coco/annotations/captions_val2017.json'  # Path to val2017 captions

    # Define the transform to be applied to the images
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize the image to the model's required input size
        transforms.ToTensor()
    ])

    # Create the training dataset
    train_coco = dset.CocoCaptions(
        root=train_image_dir,
        annFile=train_annotation_file,
        transform=transform
    )

    # Create the test dataset
    test_coco = dset.CocoCaptions(
        root=test_image_dir,
        annFile=test_annotation_file,
        transform=transform
    )
    
    if config["num_train_samples"] != -1:
        print(f"Subsetting the training dataset to {config['num_train_samples']} samples")
        # Subset the training dataset
        num_training_samples = config["num_train_samples"]
        subset_indices = list(range(num_training_samples))
        train_coco = Subset(train_coco, subset_indices)
    
    if config["num_test_samples"] != -1:
        print(f"Subsetting the test dataset to {config['num_test_samples']} samples")
        # Subset the test dataset
        num_test_samples = config["num_test_samples"]
        subset_indices = list(range(num_test_samples))
        test_coco = Subset(test_coco, subset_indices)

    # Every image has 5 captions at max, we need to sample one of them
    # Create collate function to sample one caption per image
    def collate_fn(batch):
        images, captions = zip(*batch)
        images = torch.stack(images, 0)
        sel_captions = []
        for list_captions in captions:
            caption = random.choice(list_captions)
            sel_captions.append(caption)
        return images, sel_captions

    # Create DataLoader
    batch_size = config["batch_size"]
    train_loader = DataLoader(train_coco, batch_size=batch_size, shuffle=True , drop_last=True, collate_fn=collate_fn, num_workers=12)
    test_loader  = DataLoader(test_coco , batch_size=batch_size, shuffle=False, drop_last=True, collate_fn=collate_fn, num_workers=12)
    
    return train_loader, test_loader

In [24]:
def set_seed(seed: int):
    random.seed(seed)  # Python random module
    np.random.seed(seed)  # NumPy random module
    torch.manual_seed(seed)  # PyTorch CPU random numbers
    torch.cuda.manual_seed(seed)  # PyTorch GPU random numbers for a single GPU
    torch.cuda.manual_seed_all(seed)  # PyTorch GPU random numbers for all GPUs
    torch.backends.cudnn.deterministic = True  # Ensure deterministic behavior for cuDNN
    torch.backends.cudnn.benchmark = False  # Disable benchmark for deterministic behavior

In [25]:
def main(config):
    # Set the seed for reproducibility
    set_seed(config["seed"])
    
    # Finish any existing W&B runs before starting a new one
    wandb.finish()

    # Initialize your W&B run
    wandb.init(project="sparsify-clip", config=config, name=config["run_name"])
    
    # Print the config
    print("Config:", config)
    
    # Set the device
    device_id = config["device_id"]
    device = torch.device("cuda:{}".format(device_id) if torch.cuda.is_available() else "cpu")
    
    # Load the dataset
    print("Loading the dataset...")
    train_loader, test_loader = dataset_loader(config)
    print("Dataset loaded.\n")
    
    # Train the model
    print("Training the model...")
    model = train_model(config, train_loader, test_loader, device)
    print("Training complete.\n")
    
    # Final evaluation of the model
    print("Final evaluation of the model...")
    final_log = evaluate_model(model, test_loader, device)
    print("Evaluation complete.\n")
    
    # Save the model and upload it to W&B
    #torch.save(model.state_dict(), config["run_name"] + ".pt")
    #wandb.save(config["run_name"] + ".pt")
    
    wandb.finish()

In [26]:
# %%prun

if __name__ == "__main__":
    
    config["num_train_samples"] = 30000
    config["num_test_samples"] = 3000
    config["evaluate_every_n_batches"] = 10
    
    # Baseline
    config["loss_type"] = "anchor"
    print("\nTraining Baseline model")
    main(config)
    
    
    # Anchor + Lunif
    config["loss_type"] = "anchor+lunif"
    print("\nTraining Anchor + Lunif model")
    main(config)
    
    # Lunif(200itr)+frozen(text_embed)
    config["loss_type"] = "lunif(200itr)+frozen(text_embed)"
    print("\nTraining Lunif(200itr)+frozen(text_embed) model")
    main(config)

0,1
backward_r1,▁
backward_ravg,▁
forward_r1,▁
forward_ravg,▁
train_loss,█▆▅▆▄▃▃▁▁▁▇▇▇▇▆▇▆▆▇▆

0,1
backward_r1,0.034
backward_ravg,0.17
backward_recall,0.0/0.1/0.34
forward_r1,0
forward_ravg,0.272
forward_recall,0.0/0.2/0.611
train_loss,3.21035


Config: {'run_name': 'CLIP-2025-01-02-16-51-09', 'device_id': 1, 'seed': 42, 'learning_rate': 0.0001, 'batch_size': 128, 'epochs': 1, 'model': 'RN50', 'temperature': 0.07, 'loss_type': 'anchor', 'num_train_samples': 30000, 'num_test_samples': 3000, 'evaluate_every_n_batches': 10}
loading annotations into memory...
Done (t=0.67s)
creating index...
index created!
loading annotations into memory...
Done (t=0.03s)
creating index...
index created!
Subsetting the training dataset to 30000 samples
Subsetting the test dataset to 3000 samples


  4%|▍         | 9/234 [00:06<02:04,  1.81it/s]

[Epoch 1/1]  Batch: 10  Loss: 4.85013
Evaluating model...


Evaluating: 100%|██████████| 23/23 [00:06<00:00,  3.74it/s]
  4%|▍         | 10/234 [00:15<11:43,  3.14s/it]

Evaluation Results: {'forward_r1': 0.034, 'forward_recall': '0.0/0.1/0.374', 'forward_ravg': 0.181, 'backward_r1': 0.034, 'backward_recall': '0.0/0.2/0.34', 'backward_ravg': 0.181}


  8%|▊         | 19/234 [00:19<02:10,  1.64it/s]

[Epoch 1/1]  Batch: 20  Loss: 4.85205
Evaluating model...


Evaluating: 100%|██████████| 23/23 [00:06<00:00,  3.77it/s]
  9%|▊         | 20/234 [00:28<11:03,  3.10s/it]

Evaluation Results: {'forward_r1': 0.034, 'forward_recall': '0.0/0.1/0.272', 'forward_ravg': 0.147, 'backward_r1': 0.034, 'backward_recall': '0.0/0.2/0.34', 'backward_ravg': 0.181}


 12%|█▏        | 29/234 [00:32<02:05,  1.63it/s]

[Epoch 1/1]  Batch: 30  Loss: 4.85204
Evaluating model...


Evaluating: 100%|██████████| 23/23 [00:06<00:00,  3.64it/s]
 13%|█▎        | 30/234 [00:41<10:56,  3.22s/it]

Evaluation Results: {'forward_r1': 0.034, 'forward_recall': '0.0/0.2/0.306', 'forward_ravg': 0.181, 'backward_r1': 0.068, 'backward_recall': '0.1/0.2/0.374', 'backward_ravg': 0.204}


 17%|█▋        | 39/234 [00:46<02:00,  1.62it/s]

[Epoch 1/1]  Batch: 40  Loss: 4.85198
Evaluating model...


Evaluating: 100%|██████████| 23/23 [00:06<00:00,  3.68it/s]
 17%|█▋        | 40/234 [00:55<10:12,  3.16s/it]

Evaluation Results: {'forward_r1': 0.034, 'forward_recall': '0.0/0.2/0.34', 'forward_ravg': 0.192, 'backward_r1': 0.034, 'backward_recall': '0.0/0.2/0.374', 'backward_ravg': 0.192}


 21%|██        | 49/234 [00:59<01:53,  1.63it/s]

[Epoch 1/1]  Batch: 50  Loss: 4.85181
Evaluating model...


Evaluating: 100%|██████████| 23/23 [00:06<00:00,  3.70it/s]
 21%|██▏       | 50/234 [01:10<10:53,  3.55s/it]

Evaluation Results: {'forward_r1': 0.068, 'forward_recall': '0.1/0.2/0.476', 'forward_ravg': 0.26, 'backward_r1': 0.034, 'backward_recall': '0.0/0.2/0.374', 'backward_ravg': 0.204}


 25%|██▌       | 59/234 [01:14<01:51,  1.57it/s]

[Epoch 1/1]  Batch: 60  Loss: 4.85167
Evaluating model...


Evaluating: 100%|██████████| 23/23 [00:06<00:00,  3.50it/s]
 26%|██▌       | 60/234 [01:23<09:29,  3.27s/it]

Evaluation Results: {'forward_r1': 0.068, 'forward_recall': '0.1/0.2/0.34', 'forward_ravg': 0.192, 'backward_r1': 0.034, 'backward_recall': '0.0/0.2/0.408', 'backward_ravg': 0.204}


 29%|██▉       | 69/234 [01:28<01:42,  1.60it/s]

[Epoch 1/1]  Batch: 70  Loss: 4.85151
Evaluating model...


Evaluating: 100%|██████████| 23/23 [00:06<00:00,  3.70it/s]
 30%|██▉       | 70/234 [01:37<08:33,  3.13s/it]

Evaluation Results: {'forward_r1': 0.068, 'forward_recall': '0.1/0.2/0.34', 'forward_ravg': 0.192, 'backward_r1': 0.034, 'backward_recall': '0.0/0.3/0.408', 'backward_ravg': 0.238}


 34%|███▍      | 79/234 [01:41<01:34,  1.63it/s]

[Epoch 1/1]  Batch: 80  Loss: 4.85134
Evaluating model...


Evaluating: 100%|██████████| 23/23 [00:06<00:00,  3.72it/s]
 34%|███▍      | 80/234 [01:50<08:07,  3.17s/it]

Evaluation Results: {'forward_r1': 0.068, 'forward_recall': '0.1/0.3/0.442', 'forward_ravg': 0.26, 'backward_r1': 0.068, 'backward_recall': '0.1/0.2/0.374', 'backward_ravg': 0.204}


 38%|███▊      | 89/234 [01:55<01:29,  1.61it/s]

[Epoch 1/1]  Batch: 90  Loss: 4.85132
Evaluating model...


Evaluating: 100%|██████████| 23/23 [00:06<00:00,  3.48it/s]
 38%|███▊      | 90/234 [02:04<07:56,  3.31s/it]

Evaluation Results: {'forward_r1': 0.068, 'forward_recall': '0.1/0.2/0.645', 'forward_ravg': 0.317, 'backward_r1': 0.034, 'backward_recall': '0.0/0.3/0.408', 'backward_ravg': 0.238}


 42%|████▏     | 99/234 [02:08<01:24,  1.60it/s]

[Epoch 1/1]  Batch: 100  Loss: 4.84986
Evaluating model...


Evaluating: 100%|██████████| 23/23 [00:06<00:00,  3.72it/s]
 43%|████▎     | 100/234 [02:18<07:31,  3.37s/it]

Evaluation Results: {'forward_r1': 0.136, 'forward_recall': '0.1/0.3/0.51', 'forward_ravg': 0.328, 'backward_r1': 0.068, 'backward_recall': '0.1/0.2/0.442', 'backward_ravg': 0.226}


 47%|████▋     | 109/234 [02:23<01:18,  1.59it/s]

[Epoch 1/1]  Batch: 110  Loss: 4.83614
Evaluating model...


Evaluating: 100%|██████████| 23/23 [00:06<00:00,  3.71it/s]
 47%|████▋     | 110/234 [02:32<06:27,  3.13s/it]

Evaluation Results: {'forward_r1': 0.034, 'forward_recall': '0.0/0.3/0.645', 'forward_ravg': 0.34, 'backward_r1': 0.0, 'backward_recall': '0.0/0.3/0.442', 'backward_ravg': 0.26}


 51%|█████     | 119/234 [02:36<01:10,  1.62it/s]

[Epoch 1/1]  Batch: 120  Loss: 4.84486
Evaluating model...


Evaluating: 100%|██████████| 23/23 [00:06<00:00,  3.73it/s]
 51%|█████▏    | 120/234 [02:45<05:53,  3.10s/it]

Evaluation Results: {'forward_r1': 0.136, 'forward_recall': '0.1/0.4/0.747', 'forward_ravg': 0.419, 'backward_r1': 0.068, 'backward_recall': '0.1/0.3/0.51', 'backward_ravg': 0.283}


 55%|█████▌    | 129/234 [02:49<01:04,  1.62it/s]

[Epoch 1/1]  Batch: 130  Loss: 4.85269
Evaluating model...


Evaluating: 100%|██████████| 23/23 [00:06<00:00,  3.65it/s]
 56%|█████▌    | 130/234 [02:58<05:29,  3.17s/it]

Evaluation Results: {'forward_r1': 0.034, 'forward_recall': '0.0/0.2/0.34', 'forward_ravg': 0.181, 'backward_r1': 0.034, 'backward_recall': '0.0/0.2/0.34', 'backward_ravg': 0.181}


 59%|█████▉    | 139/234 [03:03<00:58,  1.62it/s]

[Epoch 1/1]  Batch: 140  Loss: 4.85047
Evaluating model...


Evaluating: 100%|██████████| 23/23 [00:06<00:00,  3.73it/s]
 60%|█████▉    | 140/234 [03:12<04:55,  3.14s/it]

Evaluation Results: {'forward_r1': 0.034, 'forward_recall': '0.0/0.2/0.34', 'forward_ravg': 0.181, 'backward_r1': 0.034, 'backward_recall': '0.0/0.2/0.34', 'backward_ravg': 0.181}


 64%|██████▎   | 149/234 [03:16<00:52,  1.63it/s]

[Epoch 1/1]  Batch: 150  Loss: 4.85056
Evaluating model...


Evaluating: 100%|██████████| 23/23 [00:06<00:00,  3.75it/s]
 64%|██████▍   | 150/234 [03:26<04:41,  3.35s/it]

Evaluation Results: {'forward_r1': 0.068, 'forward_recall': '0.1/0.2/0.408', 'forward_ravg': 0.238, 'backward_r1': 0.034, 'backward_recall': '0.0/0.1/0.34', 'backward_ravg': 0.17}


 68%|██████▊   | 159/234 [03:30<00:46,  1.60it/s]

[Epoch 1/1]  Batch: 160  Loss: 4.85086
Evaluating model...


Evaluating: 100%|██████████| 23/23 [00:06<00:00,  3.72it/s]
 68%|██████▊   | 160/234 [03:39<03:50,  3.11s/it]

Evaluation Results: {'forward_r1': 0.034, 'forward_recall': '0.0/0.3/0.611', 'forward_ravg': 0.317, 'backward_r1': 0.034, 'backward_recall': '0.0/0.1/0.408', 'backward_ravg': 0.192}


 72%|███████▏  | 169/234 [03:44<00:40,  1.62it/s]

[Epoch 1/1]  Batch: 170  Loss: 4.84861
Evaluating model...


Evaluating: 100%|██████████| 23/23 [00:06<00:00,  3.79it/s]
 73%|███████▎  | 170/234 [03:52<03:19,  3.12s/it]

Evaluation Results: {'forward_r1': 0.034, 'forward_recall': '0.0/0.2/0.34', 'forward_ravg': 0.192, 'backward_r1': 0.068, 'backward_recall': '0.1/0.2/0.306', 'backward_ravg': 0.181}


 76%|███████▋  | 179/234 [03:57<00:33,  1.62it/s]

[Epoch 1/1]  Batch: 180  Loss: 4.84469
Evaluating model...


Evaluating: 100%|██████████| 23/23 [00:06<00:00,  3.70it/s]
 77%|███████▋  | 180/234 [04:06<02:48,  3.12s/it]

Evaluation Results: {'forward_r1': 0.068, 'forward_recall': '0.1/0.3/0.442', 'forward_ravg': 0.272, 'backward_r1': 0.102, 'backward_recall': '0.1/0.4/0.51', 'backward_ravg': 0.328}


 81%|████████  | 189/234 [04:10<00:27,  1.62it/s]

[Epoch 1/1]  Batch: 190  Loss: 4.83611
Evaluating model...


Evaluating:   0%|          | 0/23 [00:01<?, ?it/s]
 81%|████████  | 189/234 [04:12<01:00,  1.34s/it]


KeyboardInterrupt: 

In [None]:
""" %doctest_mode


def dataset_details():
    # Print dataset details
    print('Number of samples:', len(train_coco)) # 118287 images

    # Access a specific sample (4th sample here)
    img, target = train_coco[3]  # Load the 4th sample (index 3)

    # Display information about the sample
    print("Image Size:", img.size())  # Torch tensor size
    #plt.imshow(img.permute(1, 2, 0))  # Display the image
    print("Captions:", target)  # Captions for the image

for images, captions_list in train_loader:
    # images.shape is e.g. (N, 3, 224, 224)
    # captions_list has length N, but each item might be a tuple of possible captions

    plt.imshow(images[0].permute(1, 2, 0))
    plt.show()
    plt.imshow(images[1].permute(1, 2, 0))
    plt.show()

    print("Image batch size:", images.shape[0], "Shape:", images.shape)
    print("Captions list length:", len(captions_list))
    
    print("Captions list:", list(captions_list))

    print("Number of chosen captions:", len(list(captions_list[0])))
    
    captions = list(captions_list[0])

    # Then tokenize
    text_tokens = tokenizer.tokenize(captions)
    print("Text tokens shape:", text_tokens.shape)

    # Now encode
    #image_embeds = model.encode_image(images.to(device))
    #text_embeds = model.encode_text(text_tokens.to(device))

    # Should both be shape (N, D)
    #print("Image embeds shape:", image_embeds.shape)
    #print("Text  embeds shape:", text_embeds.shape)

    break  # just to test one batch
    

def collate_fn_debug(batch):
    print("Bath type:", type(batch)) # This is a list
    print("Batch size:", len(batch))
    print("Batch:", batch)
    images, captions = zip(*batch)
    
    print("Images type:", type(images))
    print("Images size:", len(images))
    print("Images:", images)
    
    print("Captions type:", type(captions))
    print("Captions size:", len(captions))
    print("Captions:", captions) # This is a tuple of lists, each list contains 5 captions for each image
    
    # Select one caption per image
    sel_captions = []
    for list_captions in captions:
        #print("List Captions:", list_captions)
        caption = random.choice(list_captions)
        sel_captions.append(caption)
    
    print("Selected Captions:", sel_captions)    



for images, captions_list in train_loader:
    break

# DONE: ensure that each tuple of captions has the same length, or the data loader will fail (defalut is collate(samples, collate_fn_map=collate_fn_map) from error message)

 """