In [30]:
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
from open_clip import tokenizer
from torch.utils.data import Subset
import numpy as np
import tqdm
import random
import wandb
import datetime
import open_clip

In [31]:
config = {
        
        "run_name":                     "CLIP-{}".format(datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")),  # A readable name for this run
        "device_id":                    1,      # GPU id
        "seed":                         42,     # Random seed
        
        "learning_rate":                1e-4,
        "batch_size":                   128,
        "epochs":                       1,
        "model":                        "RN50",
        
        "temperature":                  0.07,
        
        "loss_type":                    "anchor+lunif",   # anchor, anchor+lunif
        
        "num_train_samples":            -1,            # -1 for all
        "num_test_samples":             -1,            # -1 for all
        "evaluate_every_n_batches":     200,
        "visualize_every_n_batches":    10,
    }

In [32]:
def contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    """
    image_embeds: (batch_size, embed_dim)
    text_embeds: (batch_size, embed_dim)
    temperature: scalar float for scaling similarities
    returns: scalar loss (contrastive)
    """
    
    # Similarity matrix, shape (bs, bs)
    logits = image_embeds @ text_embeds.t()
    logits = logits / temperature

    # Targets are just the diagonal (i.e. 0->0, 1->1, ...)
    batch_size = image_embeds.size(0)
    target = torch.arange(batch_size, device=logits.device)

    # CE loss for image->text
    loss_i2t = F.cross_entropy(logits, target)
    # CE loss for text->image
    loss_t2i = F.cross_entropy(logits.t(), target)

    # Average the two directions
    return (loss_i2t + loss_t2i) / 2

In [33]:
def lunif_loss(x, t=2):
    # Compute pairwise distances between all embeddings
    sq_pdist = torch.pdist(x, p=2).pow(2)
    
    # Apply the uniformity loss formula
    return sq_pdist.mul(-t).exp().mean().log()

In [34]:
def compute_centroids(text_embeddings, visual_embeddings):
    """
    Computes the centroid for each pair of samples between text embeddings and visual embeddings
    by calculating the mean of the corresponding feature vectors across the two modalities.

    Parameters:
    - text_embeddings (torch.Tensor): Tensor of shape (batch_size1, feature_dim) representing text embeddings.
    - visual_embeddings (torch.Tensor): Tensor of shape (batch_size2, feature_dim) representing visual embeddings.

    Returns:
    - torch.Tensor: Tensor of shape (batch_size1, batch_size2, feature_dim) representing the centroid for each pair.
    """

    # Compute centroids by averaging text and visual embeddings
    # Expand the dimensions to allow pairwise computation
    text_expanded = text_embeddings.unsqueeze(1)  # Shape: [batch_size1, 1, feature_dim]
    visual_expanded = visual_embeddings.unsqueeze(0)  # Shape: [1, batch_size2, feature_dim]

    # Compute the centroid by averaging the embeddings
    centroids = (text_expanded + visual_expanded) / 2.0

    # Compute norms of the centroids
    centroid_norms = torch.norm(centroids, dim=-1)

    return centroid_norms, centroids


In [35]:
"""def compute_metric_ret(score_matrix, ids, ids_txt, direction='forward'):
    
    # Check that the score matrix has the correct shape
    assert score_matrix.shape == (len(ids_txt),len(ids))

    if direction == 'forward': ### text-to-vision retrieval
        indice_matrix = score_matrix.sort(dim=-1,descending=True)[1].tolist()
        rank = []
        for i in range(len(ids_txt)):
            # gt_indice = ids.index(ids_txt[i][0])
            gt_indice = ids.index(ids_txt[i])
            rank.append(indice_matrix[i].index(gt_indice))
        
        rank = torch.tensor(rank).to(score_matrix)
        
        vr_r1 = (rank < 1).sum().item() / len(ids_txt)
        vr_r5 = (rank < 5).sum().item() / len(ids_txt)
        vr_r10 = (rank < 10).sum().item() / len(ids_txt)
        v_medianR = torch.median(rank).item() +1
        v_meanR = torch.mean(rank).item() +1
 
        eval_log = {'forward_r1': round(vr_r1*100,3),
                    'forward_recall': f'{round(vr_r1*100,1)}/{round(vr_r5*100,1)}/{round(vr_r10*100,3)}',
                    'forward_ravg': round((vr_r1 + vr_r5 + vr_r10)/3 *100,3)
                   }
   
    else: ### vision-to-text retrieval
       
        indice_matrix = score_matrix.sort(dim=0,descending=True)[1].permute(1,0).tolist()
        rank = []
        for i in range(len(ids)):
            gt_indices=[]
            for idx, id in enumerate(ids_txt):
                if id == ids[i]:
                    gt_indices.append(idx)

            rank.append(min([indice_matrix[i].index(idx) for idx in gt_indices]))
        
        rank = torch.tensor(rank).to(score_matrix)
        
        tr_r1 = (rank < 1).sum().item() / len(ids)
        tr_r5 = (rank < 5).sum().item() / len(ids)
        tr_r10 = (rank < 10).sum().item() / len(ids)
        t_medianR = torch.median(rank).item() +1
        t_meanR = torch.mean(rank).item() +1

        eval_log = {
                    'backward_r1': round(tr_r1*100,3),
                    'backward_recall': f'{round(tr_r1*100,1)}/{round(tr_r5*100,1)}/{round(tr_r10*100,3)}',
                    'backward_ravg': round((tr_r1 + tr_r5 + tr_r10)/3 *100,3)
                  }
    
    return eval_log"""

"def compute_metric_ret(score_matrix, ids, ids_txt, direction='forward'):\n    \n    # Check that the score matrix has the correct shape\n    assert score_matrix.shape == (len(ids_txt),len(ids))\n\n    if direction == 'forward': ### text-to-vision retrieval\n        indice_matrix = score_matrix.sort(dim=-1,descending=True)[1].tolist()\n        rank = []\n        for i in range(len(ids_txt)):\n            # gt_indice = ids.index(ids_txt[i][0])\n            gt_indice = ids.index(ids_txt[i])\n            rank.append(indice_matrix[i].index(gt_indice))\n        \n        rank = torch.tensor(rank).to(score_matrix)\n        \n        vr_r1 = (rank < 1).sum().item() / len(ids_txt)\n        vr_r5 = (rank < 5).sum().item() / len(ids_txt)\n        vr_r10 = (rank < 10).sum().item() / len(ids_txt)\n        v_medianR = torch.median(rank).item() +1\n        v_meanR = torch.mean(rank).item() +1\n \n        eval_log = {'forward_r1': round(vr_r1*100,3),\n                    'forward_recall': f'{round(vr

In [36]:
"""def evaluate_model(model, test_loader, device):
    '''
    Evaluate the (OpenCLIP) model on the given test_loader by computing
    text-to-image and image-to-text retrieval metrics.

    Args:
        model (nn.Module): The trained (DataParallel) model.
        test_loader (DataLoader): A DataLoader for the evaluation set.
        device (torch.device): The device (CPU or GPU).
    '''
    
    # Put model into eval mode
    model.eval()
    
    # Prepare storage
    all_image_embeds = []
    all_text_embeds  = []
    
    # IDs for retrieval
    # We'll assign each sample a unique ID. Because your `collate_fn` is
    # picking exactly one caption per image, we can treat each batch entry
    # as a 1:1 mapping of (image_i <-> text_i).
    ids_img = []
    ids_txt = []
    
    current_index = 0

    # No gradient needed during evaluation
    with torch.no_grad():
        for images, captions_list in tqdm.tqdm(test_loader, desc="Evaluating"):
            # Move images to device
            images = images.to(device)

            # Tokenize captions
            text_tokens = tokenizer.tokenize(captions_list)
            text_tokens = text_tokens.to(device)

            # Extract embeddings using the .module references in DataParallel
            image_embeds = model.module.encode_image(images)
            text_embeds  = model.module.encode_text(text_tokens)

            # Move them to CPU for later concatenation
            image_embeds = image_embeds.cpu()
            text_embeds  = text_embeds.cpu()
            
            # Track
            bs = images.size(0)
            all_image_embeds.append(image_embeds)
            all_text_embeds.append(text_embeds)

            # For retrieval, we label these samples from current_index to current_index + bs - 1
            sample_ids = list(range(current_index, current_index + bs))
            ids_img.extend(sample_ids)
            ids_txt.extend(sample_ids)
            current_index += bs
    
    # Concatenate everything
    all_image_embeds = torch.cat(all_image_embeds, dim=0)  # shape [N, embed_dim]
    all_text_embeds  = torch.cat(all_text_embeds, dim=0)   # shape [N, embed_dim]

    # Normalize embeddings for more stable retrieval
    all_image_embeds = F.normalize(all_image_embeds, dim=-1)
    all_text_embeds  = F.normalize(all_text_embeds, dim=-1)

    # Compute pairwise similarity: [N_text, N_image]
    # Because we aligned IDs, this is effectively [N, N].
    similarity_matrix = all_text_embeds @ all_image_embeds.t()

    # Use the given function compute_metric_ret to compute retrieval metrics.
    # text->image: direction='forward'
    log_forward  = compute_metric_ret(similarity_matrix, ids_img, ids_txt, direction='forward')
    # image->text: direction='backward'
    log_backward = compute_metric_ret(similarity_matrix, ids_img, ids_txt, direction='backward')

    # You can combine or print them:
    final_log = {**log_forward, **log_backward}
    print("Evaluation Results:", final_log)

    return final_log"""

'def evaluate_model(model, test_loader, device):\n    \'\'\'\n    Evaluate the (OpenCLIP) model on the given test_loader by computing\n    text-to-image and image-to-text retrieval metrics.\n\n    Args:\n        model (nn.Module): The trained (DataParallel) model.\n        test_loader (DataLoader): A DataLoader for the evaluation set.\n        device (torch.device): The device (CPU or GPU).\n    \'\'\'\n    \n    # Put model into eval mode\n    model.eval()\n    \n    # Prepare storage\n    all_image_embeds = []\n    all_text_embeds  = []\n    \n    # IDs for retrieval\n    # We\'ll assign each sample a unique ID. Because your `collate_fn` is\n    # picking exactly one caption per image, we can treat each batch entry\n    # as a 1:1 mapping of (image_i <-> text_i).\n    ids_img = []\n    ids_txt = []\n    \n    current_index = 0\n\n    # No gradient needed during evaluation\n    with torch.no_grad():\n        for images, captions_list in tqdm.tqdm(test_loader, desc="Evaluating"):\n    

In [37]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import tqdm
from typing import List, Dict



def compute_metric_ret(score_matrix: torch.Tensor, ids: List[int], ids_txt: List[int], direction: str = 'forward') -> Dict[str, float]:
    """
    Compute retrieval metrics for either text-to-vision or vision-to-text retrieval.

    Args:
        score_matrix (torch.Tensor): Similarity matrix of shape [N_text, N_image].
        ids (List[int]): List of image IDs.
        ids_txt (List[int]): List of text IDs corresponding to images.
        direction (str): 'forward' for text-to-vision, 'backward' for vision-to-text.

    Returns:
        Dict[str, float]: Dictionary containing retrieval metrics.
    """
    assert score_matrix.shape == (len(ids_txt), len(ids)), f"Score matrix shape {score_matrix.shape} does not match (len(ids_txt), len(ids))"

    if direction == 'forward':  # Text-to-Vision Retrieval
        # Sort each row in descending order
        indice_matrix = score_matrix.sort(dim=-1, descending=True)[1].tolist()
        rank = []
        for i in range(len(ids_txt)):
            gt_indice = ids.index(ids_txt[i])
            rank.append(indice_matrix[i].index(gt_indice))

        rank = torch.tensor(rank).to(score_matrix.device)

        vr_r1 = (rank < 1).sum().item() / len(ids_txt)
        vr_r5 = (rank < 5).sum().item() / len(ids_txt)
        vr_r10 = (rank < 10).sum().item() / len(ids_txt)

        eval_log = {
            'forward_r1': round(vr_r1 * 100, 1),
            'forward_r5': round(vr_r5 * 100, 1),
            'forward_r10': round(vr_r10 * 100, 1),
            #'forward_recall': f'{round(vr_r1 * 100, 1)}/{round(vr_r5 * 100, 1)}/{round(vr_r10 * 100, 1)}',
            'forward_ravg': round((vr_r1 + vr_r5 + vr_r10) / 3 * 100, 1)
        }

    else:  # Vision-to-Text Retrieval
        # Sort each column in descending order
        indice_matrix = score_matrix.sort(dim=0, descending=True)[1].permute(1, 0).tolist()
        rank = []
        for i in range(len(ids)):
            gt_indices = [idx for idx, id_txt in enumerate(ids_txt) if id_txt == ids[i]]
            rank.append(min([indice_matrix[i].index(idx) for idx in gt_indices]))

        rank = torch.tensor(rank).to(score_matrix.device)

        tr_r1 = (rank < 1).sum().item() / len(ids)
        tr_r5 = (rank < 5).sum().item() / len(ids)
        tr_r10 = (rank < 10).sum().item() / len(ids)

        eval_log = {
            'backward_r1': round(tr_r1 * 100, 1),
            'backward_r5': round(tr_r5 * 100, 1),
            'backward_r10': round(tr_r10 * 100, 1),
            'backward_recall': f'{round(tr_r1 * 100,1)}/{round(tr_r5 * 100,1)}/{round(tr_r10 * 100,1)}',
            'backward_ravg': round((tr_r1 + tr_r5 + tr_r10) / 3 * 100, 1)
        }

    return eval_log

def compute_gap(feat_modality1: torch.Tensor, feat_modality2: torch.Tensor) -> float:
    """
    Compute the Euclidean distance between the centroids of two modalities.

    Args:
        feat_modality1 (torch.Tensor): Feature matrix of modality 1 with shape [N, D].
        feat_modality2 (torch.Tensor): Feature matrix of modality 2 with shape [N, D].

    Returns:
        float: Euclidean distance between centroids.
    """
    # Ensure features are normalized if required
    modality1_centroid = torch.mean(feat_modality1, dim=0)
    modality2_centroid = torch.mean(feat_modality2, dim=0)

    gap = modality1_centroid - modality2_centroid
    norm_gap = torch.norm(gap).item()

    return norm_gap

def compute_mean_angular_value_of_a_modality(feat_modality: torch.Tensor) -> float:
    """
    Compute the mean angular value (mean cosine similarity) of a modality.

    Args:
        feat_modality (torch.Tensor): Feature matrix with shape [N, D].

    Returns:
        float: Mean angular value.
    """
    # Compute cosine similarity matrix
    cos_sim = feat_modality @ feat_modality.T

    # Exclude diagonal elements by creating a mask
    mask = ~torch.eye(cos_sim.size(0), dtype=torch.bool, device=cos_sim.device)
    cos_sim_no_diag = cos_sim[mask]

    mean_cos_sim = cos_sim_no_diag.mean().item()

    return mean_cos_sim

def uniformity(features_modality1: torch.Tensor, features_modality2: torch.Tensor) -> float:
    """
    Calculate the uniformity metric for two modalities based on their features.

    Args:
        features_modality1 (torch.Tensor): Feature matrix of modality 1 with shape [N, D].
        features_modality2 (torch.Tensor): Feature matrix of modality 2 with shape [N, D].

    Returns:
        float: Uniformity metric (-W2).
    """
    # Concatenate the features of the two modalities
    Z = torch.cat([features_modality1, features_modality2], dim=0)  # Shape: [2 * N, D]

    # Compute the sample mean \mu_hat and covariance \Sigma
    mu_hat = torch.mean(Z, dim=0)  # Shape: [D]
    Sigma = torch.cov(Z.T)  # Shape: [D, D]

    # Calculate the trace and square root of the covariance matrix
    trace_Sigma = torch.trace(Sigma)  # Scalar
    sqrt_Sigma = torch.linalg.matrix_power(Sigma, 1 // 2)  # Matrix square root of Sigma, shape: [D, D]
    trace_sqrt_Sigma = torch.trace(sqrt_Sigma)  # Scalar

    # Dimensionality of the features
    m = Z.shape[1]

    # Compute the quadratic Wasserstein distance W2
    W2 = torch.sqrt(
        torch.norm(mu_hat)**2 + 1 + trace_Sigma - (2 / torch.sqrt(torch.tensor(m, dtype=Sigma.dtype))) * trace_sqrt_Sigma
    )

    # Return the uniformity metric (-W2)
    return -W2.item()

def mean_distance_of_true_pairs(features_modality1: torch.Tensor, features_modality2: torch.Tensor) -> float:
    """
    Compute the mean cosine similarity of true pairs between two modalities.

    Args:
        features_modality1 (torch.Tensor): Normalized feature matrix of modality 1 with shape [N, D].
        features_modality2 (torch.Tensor): Normalized feature matrix of modality 2 with shape [N, D].

    Returns:
        float: Mean cosine similarity of true pairs.
    """
    # Compute cosine similarity matrix
    cosine_sim = torch.matmul(features_modality1, features_modality2.T)

    # Extract diagonal elements (true pairs)
    cosine_sim_diag = torch.diag(cosine_sim)

    # Compute mean cosine similarity of true pairs
    cosine_tv_mean = torch.mean(cosine_sim_diag).item()

    return cosine_tv_mean

def evaluate_model(model: torch.nn.Module, test_loader: DataLoader, device: torch.device) -> Dict[str, float]:
    """
    Evaluate the (OpenCLIP) model on the given test_loader by computing
    text-to-image and image-to-text retrieval metrics, along with additional metrics.

    Args:
        model (torch.nn.Module): The trained (DataParallel) model.
        test_loader (DataLoader): A DataLoader for the evaluation set.
        device (torch.device): The device (CPU or GPU).

    Returns:
        Dict[str, float]: Dictionary containing all evaluation metrics.
    """
    # Put model into eval mode
    model.eval()

    # Prepare storage for embeddings
    all_image_embeds = []
    all_text_embeds = []

    # IDs for retrieval
    ids_img = []
    ids_txt = []

    current_index = 0

    # No gradient needed during evaluation
    with torch.no_grad():
        for images, captions_list in tqdm.tqdm(test_loader, desc="Evaluating"):
            # Move images to device
            images = images.to(device)

            # Tokenize captions
            text_tokens = tokenizer.tokenize(captions_list)
            text_tokens = text_tokens.to(device)

            # Extract embeddings using the .module references in DataParallel
            image_embeds = model.module.encode_image(images)
            text_embeds = model.module.encode_text(text_tokens)

            # Move embeddings to CPU for later concatenation
            image_embeds = image_embeds.cpu()
            text_embeds = text_embeds.cpu()

            # Store embeddings
            all_image_embeds.append(image_embeds)
            all_text_embeds.append(text_embeds)

            # Assign unique IDs
            bs = images.size(0)
            sample_ids = list(range(current_index, current_index + bs))
            ids_img.extend(sample_ids)
            ids_txt.extend(sample_ids)
            current_index += bs

    # Concatenate all embeddings
    all_image_embeds = torch.cat(all_image_embeds, dim=0)  # Shape: [N, D]
    all_text_embeds = torch.cat(all_text_embeds, dim=0)    # Shape: [N, D]

    # Normalize embeddings for more stable retrieval and metric computations
    all_image_embeds = F.normalize(all_image_embeds, dim=-1)
    all_text_embeds = F.normalize(all_text_embeds, dim=-1)

    # Compute pairwise similarity: [N_text, N_image]
    similarity_matrix = all_text_embeds @ all_image_embeds.t()

    # Compute retrieval metrics
    log_forward = compute_metric_ret(similarity_matrix, ids_img, ids_txt, direction='forward')   # Text-to-Vision
    log_backward = compute_metric_ret(similarity_matrix, ids_img, ids_txt, direction='backward') # Vision-to-Text

    # Compute additional metrics
    gap = compute_gap(all_image_embeds, all_text_embeds)
    mean_ang_image = compute_mean_angular_value_of_a_modality(all_image_embeds)
    mean_ang_text = compute_mean_angular_value_of_a_modality(all_text_embeds)
    uniformity_metric = uniformity(all_image_embeds, all_text_embeds)
    mean_cos_true_pairs = mean_distance_of_true_pairs(all_image_embeds, all_text_embeds)

    # Combine all metrics into final_log
    final_log = {
        **log_forward,
        **log_backward,
        'gap': round(gap, 4),
        'mean_angular_value_image': round(mean_ang_image, 4), # round to 4 decimal places
        'mean_angular_value_text': round(mean_ang_text, 4),
        'uniformity': round(uniformity_metric, 4),
        'mean_cosine_similarity_true_pairs': round(mean_cos_true_pairs, 4)
    }

    print("Evaluation Results:", final_log)

    return final_log

# Example usage (assuming you have a trained model, test_loader, and device defined)
# final_metrics = evaluate_model(model, test_loader, device)


In [38]:
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import umap
import matplotlib.pyplot as plt


def visualize_embeddings(text_embeddings, vision_embeddings, 
                         sample_size=1000, method='pca', 
                         title="Embeddings Visualization",
                         save_path=None):
    """
    Visualizes text and vision embeddings in 2D or 3D using PCA, t-SNE, or UMAP.

    Args:
        text_embeddings (torch.Tensor): 
            Shape [N, D] containing text embeddings.
        vision_embeddings (torch.Tensor):
            Shape [N, D] containing vision/image embeddings.
        sample_size (int): 
            If the embeddings contain more than 'sample_size' samples, 
            randomly pick this many for faster plotting. Set -1 to use all.
        method (str): 
            "pca", "tsne", or "umap".
        title (str): 
            Title for the plot.
        save_path (str, optional): 
            If provided, saves the plot to this path instead of showing it.
    """
    # Detach from graph and bring to CPU if the tensors require grad
    text_np = text_embeddings.detach().cpu().numpy()
    vision_np = vision_embeddings.detach().cpu().numpy()

    # Optionally downsample for quicker plotting
    if sample_size != -1:
        n_text = text_np.shape[0]
        n_vision = vision_np.shape[0]

        n_samples = min(n_text, n_vision)

        if n_samples > sample_size:
            indices = np.random.choice(n_samples, size=sample_size, replace=False)
            text_np = text_np[indices]
            vision_np = vision_np[indices]

    # Combine for joint dimensionality reduction
    all_data = np.concatenate([text_np, vision_np], axis=0)

    # Apply dimensionality reduction
    if method.lower() == "pca":
        reducer = PCA(n_components=3)
        reduced = reducer.fit_transform(all_data)
    elif method.lower() == "tsne":
        reducer = TSNE(n_components=3, perplexity=30, max_iter=250, random_state=42)
        reduced = reducer.fit_transform(all_data)
    elif method.lower() == "umap":
        reducer = umap.UMAP(n_components=3, random_state=42, n_jobs=1)
        reduced = reducer.fit_transform(all_data)
    else:
        raise NotImplementedError("Only 'pca', 'tsne', and 'umap' are implemented.")

    # Split back into text and vision
    text_reduced = reduced[: len(text_np)]
    vision_reduced = reduced[len(text_np):]

    # Plot 3D visualization
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(text_reduced[:, 0], text_reduced[:, 1], text_reduced[:, 2], 
               c='red', alpha=0.6, label='Text')
    ax.scatter(vision_reduced[:, 0], vision_reduced[:, 1], vision_reduced[:, 2], 
               c='blue', alpha=0.6, label='Vision')

    ax.set_title(title)
    ax.set_xlabel("Component 1")
    ax.set_ylabel("Component 2")
    ax.set_zlabel("Component 3")
    ax.legend()

    if save_path is not None:
        plt.savefig(save_path, dpi=300)
        wandb.log({title: wandb.Image(plt)})
        plt.close()
    else:
        plt.show()


In [39]:
def train_model(config, train_loader, test_loader, device):

    # Create model & transforms from scratch (no pretrained weights) #TODO: Use the tokenizer from the chosen model, not the default one
    model, preprocess, _ = open_clip.create_model_and_transforms(
        config["model"],
        pretrained=None,
        device=device
    )

    # Put the model into training mode
    model.train()

    # If you want to fine-tune *everything* from scratch, ensure all parameters require grad:
    for param in model.parameters():
        param.requires_grad = True

    # Set up training parameters from the config
    lr = config["learning_rate"]
    epochs = config["epochs"]
    temperature = config["temperature"]

    # Move the model to multiple GPUs
    model = model.to(device)
    model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3])  # Use 4 GPUs

    optimizer = optim.AdamW(model.parameters(), lr=lr)

    current_batch = 0

    for epoch in range(epochs):
        for images, captions_list in tqdm.tqdm(train_loader):
            
            current_batch += 1
            
            # Move data to the primary device
            images = images.to(device)
            captions = captions_list

            # Tokenize text
            text_tokens = tokenizer.tokenize(captions)
            text_tokens = text_tokens.to(device)

            # Encode image and text
            image_embeds = model.module.encode_image(images)  # Use .module for methods inside DataParallel
            text_embeds = model.module.encode_text(text_tokens)
            
            # Normalize embeddings
            image_embeds = F.normalize(image_embeds, dim=-1)
            text_embeds  = F.normalize(text_embeds, dim=-1)
            
            # Compute loss based on the experiment type
            if config["loss_type"] == "anchor":
                loss = contrastive_loss(image_embeds, text_embeds, temperature=temperature)
            elif config["loss_type"] == "anchor+lunif":
                lunif_img = lunif_loss(image_embeds)
                lunif_txt = lunif_loss(text_embeds)
                lunif = (lunif_img + lunif_txt) / 2
                loss = contrastive_loss(image_embeds, text_embeds, temperature=temperature) + lunif
            elif config["loss_type"] == "lunif(50batch)+frozen(text_embed)":
                if current_batch <= 50:
                    lunif_img = lunif_loss(image_embeds)
                    lunif_txt = lunif_loss(text_embeds)
                    lunif = (lunif_img + lunif_txt) / 2
                    loss = lunif
                else: # train on anchor loss with frozen text embeddings
                    text_embeds = text_embeds.detach()
                    loss = contrastive_loss(image_embeds, text_embeds, temperature=temperature)
                    
            wandb.log({"train_loss": loss.item()})

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if current_batch % config["visualize_every_n_batches"] == 0:
                 visualize_embeddings(text_embeds, 
                     image_embeds, 
                     sample_size=1000, 
                     method='umap', 
                     title="CLIP Embeddings Visualization",
                     save_path="embeddings_plot.png")
            
            if current_batch % config["evaluate_every_n_batches"] == 0:
                print(f"[Epoch {epoch+1}/{epochs}]  Batch: {current_batch}  Loss: {loss.item():.5f}")
                print("Evaluating model...")
                test_results = evaluate_model(model, test_loader, device)
                
                wandb.log(test_results)

        print(f"[Epoch {epoch+1}/{epochs}]  Loss: {loss.item():.4f}")
    
    return model

In [40]:
def dataset_loader(config):

    # Path to train images and annotations
    train_image_dir = './data/coco/images/train2017/'                          # Path to train2017 images
    train_annotation_file = './data/coco/annotations/captions_train2017.json'  # Path to train2017 captions

    # Path to test (val) images and annotations
    test_image_dir = './data/coco/images/val2017/'                          # Path to val2017 images
    test_annotation_file = './data/coco/annotations/captions_val2017.json'  # Path to val2017 captions

    # Define the transform to be applied to the images
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize the image to the model's required input size
        transforms.ToTensor()
    ])

    # Create the training dataset
    train_coco = dset.CocoCaptions(
        root=train_image_dir,
        annFile=train_annotation_file,
        transform=transform
    )

    # Create the test dataset
    test_coco = dset.CocoCaptions(
        root=test_image_dir,
        annFile=test_annotation_file,
        transform=transform
    )
    
    if config["num_train_samples"] != -1:
        print(f"Subsetting the training dataset to {config['num_train_samples']} samples")
        # Subset the training dataset
        num_training_samples = config["num_train_samples"]
        subset_indices = list(range(num_training_samples))
        train_coco = Subset(train_coco, subset_indices)
    
    if config["num_test_samples"] != -1:
        print(f"Subsetting the test dataset to {config['num_test_samples']} samples")
        # Subset the test dataset
        num_test_samples = config["num_test_samples"]
        subset_indices = list(range(num_test_samples))
        test_coco = Subset(test_coco, subset_indices)

    # Every image has 5 captions at max, we need to sample one of them
    # Create collate function to sample one caption per image
    def collate_fn(batch):
        images, captions = zip(*batch)
        images = torch.stack(images, 0)
        sel_captions = []
        for list_captions in captions:
            caption = random.choice(list_captions)
            sel_captions.append(caption)
        return images, sel_captions

    # Create DataLoader
    batch_size = config["batch_size"]
    train_loader = DataLoader(train_coco, batch_size=batch_size, shuffle=True , drop_last=True, collate_fn=collate_fn, num_workers=12)
    test_loader  = DataLoader(test_coco , batch_size=batch_size, shuffle=False, drop_last=True, collate_fn=collate_fn, num_workers=12)
    
    return train_loader, test_loader

In [41]:
def set_seed(seed: int):
    random.seed(seed)  # Python random module
    np.random.seed(seed)  # NumPy random module
    torch.manual_seed(seed)  # PyTorch CPU random numbers
    torch.cuda.manual_seed(seed)  # PyTorch GPU random numbers for a single GPU
    torch.cuda.manual_seed_all(seed)  # PyTorch GPU random numbers for all GPUs
    torch.backends.cudnn.deterministic = True  # Ensure deterministic behavior for cuDNN
    torch.backends.cudnn.benchmark = False  # Disable benchmark for deterministic behavior

In [42]:
def main(config):
    # Set the seed for reproducibility
    set_seed(config["seed"])
    
    # Finish any existing W&B runs before starting a new one
    wandb.finish()

    # Initialize your W&B run
    wandb.init(project="sparsify-clip", config=config, name=config["run_name"])
    
    # Print the config
    print("Config:", config)
    
    # Set the device
    device_id = config["device_id"]
    device = torch.device("cuda:{}".format(device_id) if torch.cuda.is_available() else "cpu")
    
    # Load the dataset
    print("\nLoading the dataset...")
    train_loader, test_loader = dataset_loader(config)
    print("Dataset loaded.\n")
    
    # Train the model
    print("Training the model...")
    model = train_model(config, train_loader, test_loader, device)
    print("Training complete.\n")
    
    # Final evaluation of the model
    print("Final evaluation of the model...")
    final_log = evaluate_model(model, test_loader, device)
    print("Evaluation complete.\n")
    
    # Save the model and upload it to W&B
    #torch.save(model.state_dict(), config["run_name"] + ".pt")
    #wandb.save(config["run_name"] + ".pt")
    
    wandb.finish()

In [43]:
# %%prun

if __name__ == "__main__":
    
    # Minimum configuration (overwrites defaults)
    config["num_train_samples"] = -1
    config["num_test_samples"] = -1
    config["evaluate_every_n_batches"] = 30
    config["visualize_every_n_batches"] = 30
    config["epochs"] = 1
    
    
    # Baseline
    config["loss_type"] = "anchor"
    print("\nTraining Baseline model")
    main(config)
    
    
    # Anchor + Lunif (HAVE TO FINISH TESTING)
    config["loss_type"] = "anchor+lunif"
    print("\nTraining Anchor + Lunif model")
    main(config)
    
    # Lunif(50itr)+frozen(text_embed)
    # config["loss_type"] = "lunif(50batch)+frozen(text_embed)"
    # print("\nTraining Lunif(50itr)+frozen(text_embed) model")
    # main(config)


Training Baseline model


Config: {'run_name': 'CLIP-2025-01-02-21-19-36', 'device_id': 1, 'seed': 42, 'learning_rate': 0.0001, 'batch_size': 128, 'epochs': 1, 'model': 'RN50', 'temperature': 0.07, 'loss_type': 'anchor', 'num_train_samples': -1, 'num_test_samples': -1, 'evaluate_every_n_batches': 30, 'visualize_every_n_batches': 30}

Loading the dataset...
loading annotations into memory...
Done (t=0.68s)
creating index...
index created!
loading annotations into memory...
Done (t=0.50s)
creating index...
index created!
Dataset loaded.

Training the model...


  3%|▎         | 29/924 [00:16<07:50,  1.90it/s]

[Epoch 1/1]  Batch: 30  Loss: 4.66858
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.08it/s]
  3%|▎         | 30/924 [00:32<1:16:31,  5.14s/it]

Evaluation Results: {'forward_r1': 0.0, 'forward_r5': 0.2, 'forward_r10': 0.4, 'forward_ravg': 0.2, 'backward_r1': 0.0, 'backward_r5': 0.3, 'backward_r10': 0.4, 'backward_recall': '0.0/0.3/0.4', 'backward_ravg': 0.2, 'gap': 1.307, 'mean_angular_value_image': 0.969, 'mean_angular_value_text': 0.9454, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.1086}


  6%|▋         | 59/924 [00:47<07:22,  1.95it/s]  

[Epoch 1/1]  Batch: 60  Loss: 4.67551
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.16it/s]
  6%|▋         | 60/924 [01:02<1:12:06,  5.01s/it]

Evaluation Results: {'forward_r1': 0.0, 'forward_r5': 0.2, 'forward_r10': 0.6, 'forward_ravg': 0.3, 'backward_r1': 0.1, 'backward_r5': 0.3, 'backward_r10': 0.7, 'backward_recall': '0.1/0.3/0.7', 'backward_ravg': 0.4, 'gap': 1.2163, 'mean_angular_value_image': 0.9119, 'mean_angular_value_text': 0.9344, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.2074}


 10%|▉         | 89/924 [01:17<07:13,  1.93it/s]  

[Epoch 1/1]  Batch: 90  Loss: 4.52909
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.12it/s]
 10%|▉         | 90/924 [01:32<1:09:37,  5.01s/it]

Evaluation Results: {'forward_r1': 0.1, 'forward_r5': 0.3, 'forward_r10': 0.8, 'forward_ravg': 0.4, 'backward_r1': 0.1, 'backward_r5': 0.5, 'backward_r10': 1.0, 'backward_recall': '0.1/0.5/1.0', 'backward_ravg': 0.5, 'gap': 1.1108, 'mean_angular_value_image': 0.8419, 'mean_angular_value_text': 0.8537, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.2839}


 13%|█▎        | 119/924 [01:47<06:55,  1.94it/s] 

[Epoch 1/1]  Batch: 120  Loss: 4.49053
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.14it/s]
 13%|█▎        | 120/924 [02:03<1:07:53,  5.07s/it]

Evaluation Results: {'forward_r1': 0.1, 'forward_r5': 0.5, 'forward_r10': 0.8, 'forward_ravg': 0.5, 'backward_r1': 0.1, 'backward_r5': 0.8, 'backward_r10': 1.3, 'backward_recall': '0.1/0.8/1.3', 'backward_ravg': 0.7, 'gap': 1.0501, 'mean_angular_value_image': 0.8281, 'mean_angular_value_text': 0.865, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.3561}


 16%|█▌        | 149/924 [02:18<06:44,  1.92it/s]  

[Epoch 1/1]  Batch: 150  Loss: 4.47749
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.09it/s]
 16%|█▌        | 150/924 [02:34<1:06:49,  5.18s/it]

Evaluation Results: {'forward_r1': 0.1, 'forward_r5': 0.6, 'forward_r10': 1.1, 'forward_ravg': 0.6, 'backward_r1': 0.1, 'backward_r5': 0.8, 'backward_r10': 1.6, 'backward_recall': '0.1/0.8/1.6', 'backward_ravg': 0.8, 'gap': 1.0092, 'mean_angular_value_image': 0.7581, 'mean_angular_value_text': 0.8325, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.3682}


 19%|█▉        | 179/924 [02:49<06:31,  1.90it/s]  

[Epoch 1/1]  Batch: 180  Loss: 4.38186
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.11it/s]
 19%|█▉        | 180/924 [03:04<1:02:56,  5.08s/it]

Evaluation Results: {'forward_r1': 0.1, 'forward_r5': 0.5, 'forward_r10': 1.1, 'forward_ravg': 0.6, 'backward_r1': 0.3, 'backward_r5': 0.8, 'backward_r10': 1.2, 'backward_recall': '0.3/0.8/1.2', 'backward_ravg': 0.8, 'gap': 0.9981, 'mean_angular_value_image': 0.7532, 'mean_angular_value_text': 0.8416, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.38}


 23%|██▎       | 209/924 [03:19<06:11,  1.93it/s]  

[Epoch 1/1]  Batch: 210  Loss: 4.31480
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.13it/s]
 23%|██▎       | 210/924 [03:35<1:00:14,  5.06s/it]

Evaluation Results: {'forward_r1': 0.1, 'forward_r5': 0.7, 'forward_r10': 1.6, 'forward_ravg': 0.8, 'backward_r1': 0.2, 'backward_r5': 0.7, 'backward_r10': 1.3, 'backward_recall': '0.2/0.7/1.3', 'backward_ravg': 0.8, 'gap': 0.9769, 'mean_angular_value_image': 0.7245, 'mean_angular_value_text': 0.803, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.3875}


 26%|██▌       | 239/924 [03:49<05:56,  1.92it/s]  

[Epoch 1/1]  Batch: 240  Loss: 4.09162
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.08it/s]
 26%|██▌       | 240/924 [04:05<57:55,  5.08s/it]

Evaluation Results: {'forward_r1': 0.1, 'forward_r5': 0.7, 'forward_r10': 1.7, 'forward_ravg': 0.8, 'backward_r1': 0.2, 'backward_r5': 0.6, 'backward_r10': 1.4, 'backward_recall': '0.2/0.6/1.4', 'backward_ravg': 0.7, 'gap': 0.9029, 'mean_angular_value_image': 0.6878, 'mean_angular_value_text': 0.7842, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4423}


 29%|██▉       | 269/924 [04:20<05:38,  1.93it/s]

[Epoch 1/1]  Batch: 270  Loss: 4.12768
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.12it/s]
 29%|██▉       | 270/924 [04:36<55:18,  5.07s/it]

Evaluation Results: {'forward_r1': 0.1, 'forward_r5': 1.0, 'forward_r10': 1.9, 'forward_ravg': 1.0, 'backward_r1': 0.3, 'backward_r5': 1.0, 'backward_r10': 1.7, 'backward_recall': '0.3/1.0/1.7', 'backward_ravg': 1.0, 'gap': 0.9052, 'mean_angular_value_image': 0.6554, 'mean_angular_value_text': 0.788, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4235}


 32%|███▏      | 299/924 [04:50<05:25,  1.92it/s]

[Epoch 1/1]  Batch: 300  Loss: 3.98940
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.11it/s]
 32%|███▏      | 300/924 [05:06<52:34,  5.05s/it]

Evaluation Results: {'forward_r1': 0.2, 'forward_r5': 1.1, 'forward_r10': 2.3, 'forward_ravg': 1.2, 'backward_r1': 0.2, 'backward_r5': 1.2, 'backward_r10': 2.6, 'backward_recall': '0.2/1.2/2.6', 'backward_ravg': 1.3, 'gap': 0.9431, 'mean_angular_value_image': 0.6749, 'mean_angular_value_text': 0.7609, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.3931}


 36%|███▌      | 329/924 [05:21<05:07,  1.94it/s]

[Epoch 1/1]  Batch: 330  Loss: 4.12866
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.14it/s]
 36%|███▌      | 330/924 [05:36<49:45,  5.03s/it]

Evaluation Results: {'forward_r1': 0.4, 'forward_r5': 1.2, 'forward_r10': 2.2, 'forward_ravg': 1.2, 'backward_r1': 0.4, 'backward_r5': 1.3, 'backward_r10': 2.3, 'backward_recall': '0.4/1.3/2.3', 'backward_ravg': 1.3, 'gap': 0.8817, 'mean_angular_value_image': 0.6113, 'mean_angular_value_text': 0.742, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4212}


 39%|███▉      | 359/924 [05:51<04:54,  1.92it/s]

[Epoch 1/1]  Batch: 360  Loss: 3.91214
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.07it/s]
 39%|███▉      | 360/924 [06:07<48:13,  5.13s/it]

Evaluation Results: {'forward_r1': 0.2, 'forward_r5': 1.0, 'forward_r10': 2.0, 'forward_ravg': 1.1, 'backward_r1': 0.4, 'backward_r5': 1.5, 'backward_r10': 2.6, 'backward_recall': '0.4/1.5/2.6', 'backward_ravg': 1.5, 'gap': 0.8237, 'mean_angular_value_image': 0.5679, 'mean_angular_value_text': 0.6577, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4348}


 42%|████▏     | 389/924 [06:22<04:36,  1.93it/s]

[Epoch 1/1]  Batch: 390  Loss: 4.16565
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.13it/s]
 42%|████▏     | 390/924 [06:37<44:51,  5.04s/it]

Evaluation Results: {'forward_r1': 0.3, 'forward_r5': 1.5, 'forward_r10': 2.8, 'forward_ravg': 1.6, 'backward_r1': 0.4, 'backward_r5': 1.8, 'backward_r10': 3.1, 'backward_recall': '0.4/1.8/3.1', 'backward_ravg': 1.8, 'gap': 0.8121, 'mean_angular_value_image': 0.4887, 'mean_angular_value_text': 0.6612, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4115}


 45%|████▌     | 419/924 [06:52<04:21,  1.93it/s]

[Epoch 1/1]  Batch: 420  Loss: 3.69120
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.12it/s]
 45%|████▌     | 420/924 [07:08<42:15,  5.03s/it]

Evaluation Results: {'forward_r1': 0.3, 'forward_r5': 1.9, 'forward_r10': 3.3, 'forward_ravg': 1.8, 'backward_r1': 0.3, 'backward_r5': 1.3, 'backward_r10': 2.8, 'backward_recall': '0.3/1.3/2.8', 'backward_ravg': 1.5, 'gap': 0.7847, 'mean_angular_value_image': 0.5307, 'mean_angular_value_text': 0.678, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4594}


 49%|████▊     | 449/924 [07:22<04:06,  1.93it/s]

[Epoch 1/1]  Batch: 450  Loss: 3.80907
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.12it/s]
 49%|████▊     | 450/924 [07:38<40:03,  5.07s/it]

Evaluation Results: {'forward_r1': 0.4, 'forward_r5': 1.8, 'forward_r10': 3.3, 'forward_ravg': 1.8, 'backward_r1': 0.5, 'backward_r5': 1.9, 'backward_r10': 3.3, 'backward_recall': '0.5/1.9/3.3', 'backward_ravg': 1.9, 'gap': 0.7714, 'mean_angular_value_image': 0.4515, 'mean_angular_value_text': 0.6655, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4399}


 52%|█████▏    | 479/924 [07:53<03:49,  1.94it/s]

[Epoch 1/1]  Batch: 480  Loss: 3.85243
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.15it/s]
 52%|█████▏    | 480/924 [08:09<37:48,  5.11s/it]

Evaluation Results: {'forward_r1': 0.6, 'forward_r5': 2.1, 'forward_r10': 3.6, 'forward_ravg': 2.1, 'backward_r1': 0.5, 'backward_r5': 1.9, 'backward_r10': 3.5, 'backward_recall': '0.5/1.9/3.5', 'backward_ravg': 2.0, 'gap': 0.7914, 'mean_angular_value_image': 0.5074, 'mean_angular_value_text': 0.6824, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4443}


 55%|█████▌    | 509/924 [08:23<03:34,  1.93it/s]

[Epoch 1/1]  Batch: 510  Loss: 3.65467
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.14it/s]
 55%|█████▌    | 510/924 [08:39<34:59,  5.07s/it]

Evaluation Results: {'forward_r1': 0.6, 'forward_r5': 2.4, 'forward_r10': 4.2, 'forward_ravg': 2.4, 'backward_r1': 0.7, 'backward_r5': 2.3, 'backward_r10': 4.1, 'backward_recall': '0.7/2.3/4.1', 'backward_ravg': 2.3, 'gap': 0.7504, 'mean_angular_value_image': 0.4906, 'mean_angular_value_text': 0.6248, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4585}


 58%|█████▊    | 539/924 [08:54<03:19,  1.93it/s]

[Epoch 1/1]  Batch: 540  Loss: 3.87560
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.17it/s]
 58%|█████▊    | 540/924 [09:09<32:04,  5.01s/it]

Evaluation Results: {'forward_r1': 0.4, 'forward_r5': 2.0, 'forward_r10': 3.9, 'forward_ravg': 2.1, 'backward_r1': 0.5, 'backward_r5': 2.2, 'backward_r10': 3.7, 'backward_recall': '0.5/2.2/3.7', 'backward_ravg': 2.1, 'gap': 0.7944, 'mean_angular_value_image': 0.4879, 'mean_angular_value_text': 0.6537, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4253}


 62%|██████▏   | 569/924 [09:24<03:04,  1.93it/s]

[Epoch 1/1]  Batch: 570  Loss: 3.64057
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.11it/s]
 62%|██████▏   | 570/924 [09:40<29:57,  5.08s/it]

Evaluation Results: {'forward_r1': 0.6, 'forward_r5': 2.4, 'forward_r10': 4.3, 'forward_ravg': 2.4, 'backward_r1': 0.5, 'backward_r5': 2.2, 'backward_r10': 4.2, 'backward_recall': '0.5/2.2/4.2', 'backward_ravg': 2.3, 'gap': 0.7345, 'mean_angular_value_image': 0.4724, 'mean_angular_value_text': 0.6745, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4727}


 65%|██████▍   | 599/924 [09:54<02:48,  1.93it/s]

[Epoch 1/1]  Batch: 600  Loss: 3.61471
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.08it/s]
 65%|██████▍   | 600/924 [10:10<27:35,  5.11s/it]

Evaluation Results: {'forward_r1': 0.5, 'forward_r5': 2.3, 'forward_r10': 4.4, 'forward_ravg': 2.4, 'backward_r1': 0.7, 'backward_r5': 2.7, 'backward_r10': 4.8, 'backward_recall': '0.7/2.7/4.8', 'backward_ravg': 2.7, 'gap': 0.6977, 'mean_angular_value_image': 0.3785, 'mean_angular_value_text': 0.6096, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4545}


 68%|██████▊   | 629/924 [10:25<02:32,  1.94it/s]

[Epoch 1/1]  Batch: 630  Loss: 3.70060
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.10it/s]
 68%|██████▊   | 630/924 [10:41<24:39,  5.03s/it]

Evaluation Results: {'forward_r1': 0.6, 'forward_r5': 2.7, 'forward_r10': 4.6, 'forward_ravg': 2.7, 'backward_r1': 0.6, 'backward_r5': 2.5, 'backward_r10': 4.7, 'backward_recall': '0.6/2.5/4.7', 'backward_ravg': 2.6, 'gap': 0.6921, 'mean_angular_value_image': 0.4382, 'mean_angular_value_text': 0.5744, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4717}


 71%|███████▏  | 659/924 [10:55<02:17,  1.93it/s]

[Epoch 1/1]  Batch: 660  Loss: 3.57458
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.15it/s]
 71%|███████▏  | 660/924 [11:11<22:16,  5.06s/it]

Evaluation Results: {'forward_r1': 0.4, 'forward_r5': 2.4, 'forward_r10': 4.6, 'forward_ravg': 2.5, 'backward_r1': 0.6, 'backward_r5': 2.9, 'backward_r10': 5.0, 'backward_recall': '0.6/2.9/5.0', 'backward_ravg': 2.8, 'gap': 0.6786, 'mean_angular_value_image': 0.3991, 'mean_angular_value_text': 0.5743, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4657}


 75%|███████▍  | 689/924 [11:26<02:01,  1.93it/s]

[Epoch 1/1]  Batch: 690  Loss: 3.45504
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.15it/s]
 75%|███████▍  | 690/924 [11:41<19:34,  5.02s/it]

Evaluation Results: {'forward_r1': 0.6, 'forward_r5': 3.1, 'forward_r10': 5.4, 'forward_ravg': 3.0, 'backward_r1': 0.5, 'backward_r5': 2.5, 'backward_r10': 4.6, 'backward_recall': '0.5/2.5/4.6', 'backward_ravg': 2.6, 'gap': 0.6667, 'mean_angular_value_image': 0.3623, 'mean_angular_value_text': 0.5204, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4475}


 78%|███████▊  | 719/924 [11:56<01:46,  1.93it/s]

[Epoch 1/1]  Batch: 720  Loss: 3.41312
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.14it/s]
 78%|███████▊  | 720/924 [12:11<17:02,  5.01s/it]

Evaluation Results: {'forward_r1': 0.6, 'forward_r5': 2.9, 'forward_r10': 5.4, 'forward_ravg': 3.0, 'backward_r1': 0.8, 'backward_r5': 2.7, 'backward_r10': 5.0, 'backward_recall': '0.8/2.7/5.0', 'backward_ravg': 2.8, 'gap': 0.6403, 'mean_angular_value_image': 0.3483, 'mean_angular_value_text': 0.5953, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4844}


 81%|████████  | 749/924 [12:26<01:33,  1.86it/s]

[Epoch 1/1]  Batch: 750  Loss: 3.40776
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  3.91it/s]
 81%|████████  | 750/924 [12:42<15:02,  5.19s/it]

Evaluation Results: {'forward_r1': 0.8, 'forward_r5': 3.2, 'forward_r10': 5.6, 'forward_ravg': 3.2, 'backward_r1': 0.8, 'backward_r5': 3.1, 'backward_r10': 5.4, 'backward_recall': '0.8/3.1/5.4', 'backward_ravg': 3.1, 'gap': 0.5914, 'mean_angular_value_image': 0.2981, 'mean_angular_value_text': 0.5173, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.476}


 84%|████████▍ | 779/924 [12:57<01:15,  1.93it/s]

[Epoch 1/1]  Batch: 780  Loss: 3.40646
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.14it/s]
 84%|████████▍ | 780/924 [13:12<12:02,  5.02s/it]

Evaluation Results: {'forward_r1': 0.7, 'forward_r5': 3.3, 'forward_r10': 5.7, 'forward_ravg': 3.2, 'backward_r1': 0.8, 'backward_r5': 3.3, 'backward_r10': 6.0, 'backward_recall': '0.8/3.3/6.0', 'backward_ravg': 3.4, 'gap': 0.6253, 'mean_angular_value_image': 0.3413, 'mean_angular_value_text': 0.5337, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4753}


 88%|████████▊ | 809/924 [13:27<00:59,  1.93it/s]

[Epoch 1/1]  Batch: 810  Loss: 3.51110
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.14it/s]
 88%|████████▊ | 810/924 [13:43<09:30,  5.00s/it]

Evaluation Results: {'forward_r1': 0.9, 'forward_r5': 3.6, 'forward_r10': 6.0, 'forward_ravg': 3.5, 'backward_r1': 0.8, 'backward_r5': 3.4, 'backward_r10': 6.5, 'backward_recall': '0.8/3.4/6.5', 'backward_ravg': 3.6, 'gap': 0.6217, 'mean_angular_value_image': 0.3144, 'mean_angular_value_text': 0.5376, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4721}


 91%|█████████ | 839/924 [13:57<00:44,  1.93it/s]

[Epoch 1/1]  Batch: 840  Loss: 3.54169
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.14it/s]
 91%|█████████ | 840/924 [14:13<06:59,  4.99s/it]

Evaluation Results: {'forward_r1': 0.8, 'forward_r5': 3.3, 'forward_r10': 6.2, 'forward_ravg': 3.5, 'backward_r1': 0.7, 'backward_r5': 3.0, 'backward_r10': 5.6, 'backward_recall': '0.7/3.0/5.6', 'backward_ravg': 3.1, 'gap': 0.64, 'mean_angular_value_image': 0.3852, 'mean_angular_value_text': 0.5128, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4719}


 94%|█████████▍| 869/924 [14:28<00:28,  1.93it/s]

[Epoch 1/1]  Batch: 870  Loss: 3.29147
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.16it/s]
 94%|█████████▍| 870/924 [14:43<04:30,  5.01s/it]

Evaluation Results: {'forward_r1': 1.0, 'forward_r5': 3.4, 'forward_r10': 5.9, 'forward_ravg': 3.4, 'backward_r1': 0.9, 'backward_r5': 3.7, 'backward_r10': 6.1, 'backward_recall': '0.9/3.7/6.1', 'backward_ravg': 3.6, 'gap': 0.5709, 'mean_angular_value_image': 0.3094, 'mean_angular_value_text': 0.4981, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4891}


 97%|█████████▋| 899/924 [14:58<00:12,  1.93it/s]

[Epoch 1/1]  Batch: 900  Loss: 3.45712
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.15it/s]
 97%|█████████▋| 900/924 [15:13<01:59,  4.97s/it]

Evaluation Results: {'forward_r1': 1.1, 'forward_r5': 4.0, 'forward_r10': 6.9, 'forward_ravg': 4.0, 'backward_r1': 1.1, 'backward_r5': 3.9, 'backward_r10': 7.1, 'backward_recall': '1.1/3.9/7.1', 'backward_ravg': 4.1, 'gap': 0.5993, 'mean_angular_value_image': 0.3355, 'mean_angular_value_text': 0.5235, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4946}


100%|██████████| 924/924 [15:25<00:00,  1.00s/it]


[Epoch 1/1]  Loss: 3.5070
Training complete.

Final evaluation of the model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.10it/s]


Evaluation Results: {'forward_r1': 0.9, 'forward_r5': 3.4, 'forward_r10': 5.9, 'forward_ravg': 3.4, 'backward_r1': 0.8, 'backward_r5': 3.7, 'backward_r10': 6.1, 'backward_recall': '0.8/3.7/6.1', 'backward_ravg': 3.5, 'gap': 0.5453, 'mean_angular_value_image': 0.2943, 'mean_angular_value_text': 0.4847, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.4946}
Evaluation complete.



0,1
backward_r1,▁▂▂▂▂▃▂▂▃▂▄▄▄▃▄▄▅▄▄▅▅▅▄▆▆▆▆▅▇█
backward_r10,▁▁▂▂▂▂▂▂▂▃▃▃▄▄▄▄▅▄▅▆▅▆▅▆▆▇▇▆▇█
backward_r5,▁▁▁▂▂▂▂▂▂▃▃▃▄▃▄▄▅▅▅▆▅▆▅▆▆▇▇▆██
backward_ravg,▁▁▂▂▂▂▂▂▂▃▃▃▄▃▄▄▅▄▅▅▅▆▅▆▆▇▇▆▇█
forward_r1,▁▁▂▂▂▂▂▂▂▂▄▂▃▃▄▅▅▄▅▄▅▄▅▅▆▅▇▆▇█
forward_r10,▁▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇▇█
forward_r5,▁▁▁▂▂▂▂▂▂▃▃▂▃▄▄▄▅▄▅▅▆▅▆▆▇▇▇▇▇█
forward_ravg,▁▁▁▂▂▂▂▂▂▃▃▃▄▄▄▄▅▄▅▅▆▅▆▆▇▇▇▇▇█
gap,█▇▆▆▅▅▅▄▄▅▄▃▃▃▃▃▃▃▃▂▂▂▂▂▁▂▁▂▁▁
mean_angular_value_image,█▇▇▇▆▆▅▅▅▅▄▄▃▃▃▃▃▃▃▂▂▂▂▂▁▁▁▂▁▁

0,1
backward_r1,1.1
backward_r10,7.1
backward_r5,3.9
backward_ravg,4.1
backward_recall,1.1/3.9/7.1
forward_r1,1.1
forward_r10,6.9
forward_r5,4
forward_ravg,4
gap,0.5993



Training Anchor + Lunif model


Config: {'run_name': 'CLIP-2025-01-02-21-19-36', 'device_id': 1, 'seed': 42, 'learning_rate': 0.0001, 'batch_size': 128, 'epochs': 1, 'model': 'RN50', 'temperature': 0.07, 'loss_type': 'anchor+lunif', 'num_train_samples': -1, 'num_test_samples': -1, 'evaluate_every_n_batches': 30, 'visualize_every_n_batches': 30}

Loading the dataset...
loading annotations into memory...
Done (t=0.71s)
creating index...
index created!
loading annotations into memory...
Done (t=0.04s)
creating index...
index created!
Dataset loaded.

Training the model...


  3%|▎         | 29/924 [00:16<08:00,  1.86it/s]

[Epoch 1/1]  Batch: 30  Loss: 1.26669
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.06it/s]
  3%|▎         | 30/924 [00:33<1:17:31,  5.20s/it]

Evaluation Results: {'forward_r1': 0.0, 'forward_r5': 0.2, 'forward_r10': 0.4, 'forward_ravg': 0.2, 'backward_r1': 0.1, 'backward_r5': 0.2, 'backward_r10': 0.3, 'backward_recall': '0.1/0.2/0.3', 'backward_ravg': 0.2, 'gap': 0.8923, 'mean_angular_value_image': 0.7963, 'mean_angular_value_text': 0.037, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.0257}


  6%|▋         | 59/924 [00:47<07:30,  1.92it/s]  

[Epoch 1/1]  Batch: 60  Loss: 1.44125
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.14it/s]
  6%|▋         | 60/924 [01:03<1:13:02,  5.07s/it]

Evaluation Results: {'forward_r1': 0.1, 'forward_r5': 0.6, 'forward_r10': 1.1, 'forward_ravg': 0.6, 'backward_r1': 0.1, 'backward_r5': 0.5, 'backward_r10': 0.9, 'backward_recall': '0.1/0.5/0.9', 'backward_ravg': 0.5, 'gap': 0.3031, 'mean_angular_value_image': 0.0742, 'mean_angular_value_text': 0.0262, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.0496}


 10%|▉         | 89/924 [01:18<07:13,  1.93it/s]  

[Epoch 1/1]  Batch: 90  Loss: 1.07148
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.07it/s]
 10%|▉         | 90/924 [01:33<1:10:21,  5.06s/it]

Evaluation Results: {'forward_r1': 0.1, 'forward_r5': 0.7, 'forward_r10': 1.4, 'forward_ravg': 0.7, 'backward_r1': 0.1, 'backward_r5': 0.6, 'backward_r10': 1.6, 'backward_recall': '0.1/0.6/1.6', 'backward_ravg': 0.8, 'gap': 0.2509, 'mean_angular_value_image': 0.0394, 'mean_angular_value_text': 0.025, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.0651}


 13%|█▎        | 119/924 [01:48<06:56,  1.93it/s] 

[Epoch 1/1]  Batch: 120  Loss: 0.88426
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.14it/s]
 13%|█▎        | 120/924 [02:04<1:07:43,  5.05s/it]

Evaluation Results: {'forward_r1': 0.2, 'forward_r5': 1.0, 'forward_r10': 1.6, 'forward_ravg': 1.0, 'backward_r1': 0.2, 'backward_r5': 0.8, 'backward_r10': 1.4, 'backward_recall': '0.2/0.8/1.4', 'backward_ravg': 0.8, 'gap': 0.231, 'mean_angular_value_image': 0.033, 'mean_angular_value_text': 0.0234, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.082}


 16%|█▌        | 149/924 [02:19<06:44,  1.92it/s]  

[Epoch 1/1]  Batch: 150  Loss: 0.71253
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.03it/s]
 16%|█▌        | 150/924 [02:35<1:07:28,  5.23s/it]

Evaluation Results: {'forward_r1': 0.2, 'forward_r5': 1.0, 'forward_r10': 1.7, 'forward_ravg': 1.0, 'backward_r1': 0.2, 'backward_r5': 0.8, 'backward_r10': 1.7, 'backward_recall': '0.2/0.8/1.7', 'backward_ravg': 0.9, 'gap': 0.2299, 'mean_angular_value_image': 0.0279, 'mean_angular_value_text': 0.0214, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.0856}


 19%|█▉        | 179/924 [02:50<06:27,  1.92it/s]  

[Epoch 1/1]  Batch: 180  Loss: 0.91073
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.12it/s]
 19%|█▉        | 180/924 [03:06<1:04:32,  5.21s/it]

Evaluation Results: {'forward_r1': 0.1, 'forward_r5': 1.1, 'forward_r10': 2.5, 'forward_ravg': 1.3, 'backward_r1': 0.4, 'backward_r5': 1.0, 'backward_r10': 1.9, 'backward_recall': '0.4/1.0/1.9', 'backward_ravg': 1.1, 'gap': 0.2135, 'mean_angular_value_image': 0.0334, 'mean_angular_value_text': 0.0145, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.0953}


 23%|██▎       | 209/924 [03:21<06:11,  1.92it/s]  

[Epoch 1/1]  Batch: 210  Loss: 0.58839
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.10it/s]
 23%|██▎       | 210/924 [03:37<1:01:05,  5.13s/it]

Evaluation Results: {'forward_r1': 0.3, 'forward_r5': 1.5, 'forward_r10': 2.8, 'forward_ravg': 1.5, 'backward_r1': 0.3, 'backward_r5': 1.1, 'backward_r10': 2.1, 'backward_recall': '0.3/1.1/2.1', 'backward_ravg': 1.2, 'gap': 0.2063, 'mean_angular_value_image': 0.0251, 'mean_angular_value_text': 0.0158, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.1067}


 26%|██▌       | 239/924 [03:51<05:56,  1.92it/s]  

[Epoch 1/1]  Batch: 240  Loss: 0.41345
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.11it/s]
 26%|██▌       | 240/924 [04:07<58:47,  5.16s/it]

Evaluation Results: {'forward_r1': 0.5, 'forward_r5': 1.7, 'forward_r10': 2.9, 'forward_ravg': 1.7, 'backward_r1': 0.3, 'backward_r5': 1.5, 'backward_r10': 2.9, 'backward_recall': '0.3/1.5/2.9', 'backward_ravg': 1.6, 'gap': 0.1978, 'mean_angular_value_image': 0.026, 'mean_angular_value_text': 0.0148, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.1207}


 29%|██▉       | 269/924 [04:22<05:38,  1.93it/s]

[Epoch 1/1]  Batch: 270  Loss: 0.24120
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.15it/s]
 29%|██▉       | 270/924 [04:38<54:39,  5.01s/it]

Evaluation Results: {'forward_r1': 0.3, 'forward_r5': 1.6, 'forward_r10': 3.4, 'forward_ravg': 1.7, 'backward_r1': 0.4, 'backward_r5': 1.6, 'backward_r10': 3.6, 'backward_recall': '0.4/1.6/3.6', 'backward_ravg': 1.9, 'gap': 0.1889, 'mean_angular_value_image': 0.0204, 'mean_angular_value_text': 0.0134, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.1217}


 32%|███▏      | 299/924 [04:52<05:23,  1.93it/s]

[Epoch 1/1]  Batch: 300  Loss: 0.20345
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.13it/s]
 32%|███▏      | 300/924 [05:09<54:24,  5.23s/it]

Evaluation Results: {'forward_r1': 0.6, 'forward_r5': 2.3, 'forward_r10': 3.8, 'forward_ravg': 2.2, 'backward_r1': 0.5, 'backward_r5': 2.4, 'backward_r10': 4.3, 'backward_recall': '0.5/2.4/4.3', 'backward_ravg': 2.4, 'gap': 0.184, 'mean_angular_value_image': 0.0229, 'mean_angular_value_text': 0.01, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.1298}


 36%|███▌      | 329/924 [05:23<05:09,  1.92it/s]

[Epoch 1/1]  Batch: 330  Loss: 0.27945
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.15it/s]
 36%|███▌      | 330/924 [05:39<49:56,  5.04s/it]

Evaluation Results: {'forward_r1': 0.4, 'forward_r5': 2.2, 'forward_r10': 4.1, 'forward_ravg': 2.3, 'backward_r1': 0.4, 'backward_r5': 2.0, 'backward_r10': 3.9, 'backward_recall': '0.4/2.0/3.9', 'backward_ravg': 2.1, 'gap': 0.2069, 'mean_angular_value_image': 0.0266, 'mean_angular_value_text': 0.0149, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.1381}


 39%|███▉      | 359/924 [05:54<04:53,  1.92it/s]

[Epoch 1/1]  Batch: 360  Loss: 0.00364
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.09it/s]
 39%|███▉      | 360/924 [06:09<47:45,  5.08s/it]

Evaluation Results: {'forward_r1': 0.6, 'forward_r5': 2.2, 'forward_r10': 4.2, 'forward_ravg': 2.3, 'backward_r1': 0.4, 'backward_r5': 2.2, 'backward_r10': 4.0, 'backward_recall': '0.4/2.2/4.0', 'backward_ravg': 2.2, 'gap': 0.193, 'mean_angular_value_image': 0.0264, 'mean_angular_value_text': 0.0093, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.1478}


 42%|████▏     | 389/924 [06:24<04:37,  1.93it/s]

[Epoch 1/1]  Batch: 390  Loss: 0.45124
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.08it/s]
 42%|████▏     | 390/924 [06:40<45:34,  5.12s/it]

Evaluation Results: {'forward_r1': 0.7, 'forward_r5': 2.2, 'forward_r10': 4.2, 'forward_ravg': 2.4, 'backward_r1': 0.7, 'backward_r5': 2.1, 'backward_r10': 4.2, 'backward_recall': '0.7/2.1/4.2', 'backward_ravg': 2.4, 'gap': 0.1929, 'mean_angular_value_image': 0.0227, 'mean_angular_value_text': 0.0164, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.1582}


 45%|████▌     | 419/924 [06:55<04:22,  1.93it/s]

[Epoch 1/1]  Batch: 420  Loss: 0.09275
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.14it/s]
 45%|████▌     | 420/924 [07:10<42:10,  5.02s/it]

Evaluation Results: {'forward_r1': 0.6, 'forward_r5': 2.7, 'forward_r10': 4.6, 'forward_ravg': 2.6, 'backward_r1': 0.7, 'backward_r5': 2.3, 'backward_r10': 4.5, 'backward_recall': '0.7/2.3/4.5', 'backward_ravg': 2.5, 'gap': 0.1853, 'mean_angular_value_image': 0.021, 'mean_angular_value_text': 0.0125, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.1599}


 49%|████▊     | 449/924 [07:25<04:07,  1.92it/s]

[Epoch 1/1]  Batch: 450  Loss: -0.19277
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.12it/s]
 49%|████▊     | 450/924 [07:41<39:41,  5.02s/it]

Evaluation Results: {'forward_r1': 0.8, 'forward_r5': 2.8, 'forward_r10': 5.1, 'forward_ravg': 2.9, 'backward_r1': 0.8, 'backward_r5': 3.2, 'backward_r10': 5.5, 'backward_recall': '0.8/3.2/5.5', 'backward_ravg': 3.1, 'gap': 0.1726, 'mean_angular_value_image': 0.0216, 'mean_angular_value_text': 0.0087, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.1599}


 52%|█████▏    | 479/924 [07:55<03:50,  1.93it/s]

[Epoch 1/1]  Batch: 480  Loss: 0.11901
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.07it/s]
 52%|█████▏    | 480/924 [08:11<37:29,  5.07s/it]

Evaluation Results: {'forward_r1': 0.9, 'forward_r5': 3.1, 'forward_r10': 5.5, 'forward_ravg': 3.2, 'backward_r1': 0.8, 'backward_r5': 2.8, 'backward_r10': 5.1, 'backward_recall': '0.8/2.8/5.1', 'backward_ravg': 2.9, 'gap': 0.1484, 'mean_angular_value_image': 0.0184, 'mean_angular_value_text': 0.0075, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.1729}


 55%|█████▌    | 509/924 [08:26<03:35,  1.93it/s]

[Epoch 1/1]  Batch: 510  Loss: -0.24350
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.15it/s]
 55%|█████▌    | 510/924 [08:41<34:40,  5.03s/it]

Evaluation Results: {'forward_r1': 0.8, 'forward_r5': 3.5, 'forward_r10': 6.1, 'forward_ravg': 3.5, 'backward_r1': 0.6, 'backward_r5': 2.9, 'backward_r10': 5.4, 'backward_recall': '0.6/2.9/5.4', 'backward_ravg': 3.0, 'gap': 0.1777, 'mean_angular_value_image': 0.0209, 'mean_angular_value_text': 0.0098, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.1733}


 58%|█████▊    | 539/924 [08:56<03:18,  1.94it/s]

[Epoch 1/1]  Batch: 540  Loss: 0.10842
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.05it/s]
 58%|█████▊    | 540/924 [09:12<32:36,  5.10s/it]

Evaluation Results: {'forward_r1': 1.0, 'forward_r5': 3.3, 'forward_r10': 5.5, 'forward_ravg': 3.3, 'backward_r1': 0.7, 'backward_r5': 3.2, 'backward_r10': 5.8, 'backward_recall': '0.7/3.2/5.8', 'backward_ravg': 3.2, 'gap': 0.1813, 'mean_angular_value_image': 0.0233, 'mean_angular_value_text': 0.0108, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.1765}


 62%|██████▏   | 569/924 [09:27<03:03,  1.93it/s]

[Epoch 1/1]  Batch: 570  Loss: -0.14703
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.14it/s]
 62%|██████▏   | 570/924 [09:42<30:07,  5.10s/it]

Evaluation Results: {'forward_r1': 1.0, 'forward_r5': 3.2, 'forward_r10': 5.9, 'forward_ravg': 3.4, 'backward_r1': 0.7, 'backward_r5': 3.2, 'backward_r10': 5.8, 'backward_recall': '0.7/3.2/5.8', 'backward_ravg': 3.3, 'gap': 0.1806, 'mean_angular_value_image': 0.0269, 'mean_angular_value_text': 0.0092, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.1782}


 65%|██████▍   | 599/924 [09:57<02:48,  1.92it/s]

[Epoch 1/1]  Batch: 600  Loss: -0.09571
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.14it/s]
 65%|██████▍   | 600/924 [10:13<27:03,  5.01s/it]

Evaluation Results: {'forward_r1': 1.0, 'forward_r5': 3.6, 'forward_r10': 6.4, 'forward_ravg': 3.7, 'backward_r1': 0.7, 'backward_r5': 3.2, 'backward_r10': 6.1, 'backward_recall': '0.7/3.2/6.1', 'backward_ravg': 3.3, 'gap': 0.1682, 'mean_angular_value_image': 0.0214, 'mean_angular_value_text': 0.01, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.1842}


 68%|██████▊   | 629/924 [10:27<02:33,  1.92it/s]

[Epoch 1/1]  Batch: 630  Loss: -0.19540
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.07it/s]
 68%|██████▊   | 630/924 [10:43<25:06,  5.13s/it]

Evaluation Results: {'forward_r1': 1.0, 'forward_r5': 3.7, 'forward_r10': 6.5, 'forward_ravg': 3.7, 'backward_r1': 0.8, 'backward_r5': 3.4, 'backward_r10': 6.4, 'backward_recall': '0.8/3.4/6.4', 'backward_ravg': 3.5, 'gap': 0.1772, 'mean_angular_value_image': 0.0248, 'mean_angular_value_text': 0.0114, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.1893}


 71%|███████▏  | 659/924 [10:58<02:17,  1.93it/s]

[Epoch 1/1]  Batch: 660  Loss: -0.09929
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.06it/s]
 71%|███████▏  | 660/924 [11:14<22:44,  5.17s/it]

Evaluation Results: {'forward_r1': 1.0, 'forward_r5': 4.0, 'forward_r10': 6.7, 'forward_ravg': 3.9, 'backward_r1': 0.8, 'backward_r5': 3.3, 'backward_r10': 6.4, 'backward_recall': '0.8/3.3/6.4', 'backward_ravg': 3.5, 'gap': 0.1561, 'mean_angular_value_image': 0.0222, 'mean_angular_value_text': 0.0092, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.199}


 75%|███████▍  | 689/924 [11:29<02:02,  1.92it/s]

[Epoch 1/1]  Batch: 690  Loss: -0.30728
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.04it/s]
 75%|███████▍  | 690/924 [11:45<20:09,  5.17s/it]

Evaluation Results: {'forward_r1': 1.1, 'forward_r5': 4.2, 'forward_r10': 7.5, 'forward_ravg': 4.3, 'backward_r1': 1.0, 'backward_r5': 4.0, 'backward_r10': 7.1, 'backward_recall': '1.0/4.0/7.1', 'backward_ravg': 4.0, 'gap': 0.1635, 'mean_angular_value_image': 0.0209, 'mean_angular_value_text': 0.0058, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.2029}


 78%|███████▊  | 719/924 [12:00<01:46,  1.93it/s]

[Epoch 1/1]  Batch: 720  Loss: -0.28695
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.15it/s]
 78%|███████▊  | 720/924 [12:15<17:17,  5.09s/it]

Evaluation Results: {'forward_r1': 1.1, 'forward_r5': 4.2, 'forward_r10': 7.0, 'forward_ravg': 4.1, 'backward_r1': 1.1, 'backward_r5': 3.7, 'backward_r10': 6.8, 'backward_recall': '1.1/3.7/6.8', 'backward_ravg': 3.9, 'gap': 0.1588, 'mean_angular_value_image': 0.0169, 'mean_angular_value_text': 0.0099, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.1984}


 81%|████████  | 749/924 [12:30<01:30,  1.93it/s]

[Epoch 1/1]  Batch: 750  Loss: -0.31633
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.08it/s]
 81%|████████  | 750/924 [12:46<14:51,  5.13s/it]

Evaluation Results: {'forward_r1': 1.2, 'forward_r5': 4.5, 'forward_r10': 7.7, 'forward_ravg': 4.5, 'backward_r1': 1.0, 'backward_r5': 4.5, 'backward_r10': 7.5, 'backward_recall': '1.0/4.5/7.5', 'backward_ravg': 4.4, 'gap': 0.1804, 'mean_angular_value_image': 0.0234, 'mean_angular_value_text': 0.0083, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.2033}


 84%|████████▍ | 779/924 [13:01<01:14,  1.94it/s]

[Epoch 1/1]  Batch: 780  Loss: -0.43263
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.15it/s]
 84%|████████▍ | 780/924 [13:16<12:05,  5.04s/it]

Evaluation Results: {'forward_r1': 1.3, 'forward_r5': 4.5, 'forward_r10': 7.8, 'forward_ravg': 4.5, 'backward_r1': 1.4, 'backward_r5': 4.6, 'backward_r10': 7.4, 'backward_recall': '1.4/4.6/7.4', 'backward_ravg': 4.5, 'gap': 0.1562, 'mean_angular_value_image': 0.0191, 'mean_angular_value_text': 0.0089, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.207}


 88%|████████▊ | 809/924 [13:31<00:59,  1.93it/s]

[Epoch 1/1]  Batch: 810  Loss: -0.26793
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.13it/s]
 88%|████████▊ | 810/924 [13:47<09:33,  5.03s/it]

Evaluation Results: {'forward_r1': 1.3, 'forward_r5': 4.7, 'forward_r10': 8.0, 'forward_ravg': 4.7, 'backward_r1': 1.3, 'backward_r5': 5.0, 'backward_r10': 8.2, 'backward_recall': '1.3/5.0/8.2', 'backward_ravg': 4.8, 'gap': 0.1576, 'mean_angular_value_image': 0.0176, 'mean_angular_value_text': 0.0079, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.2082}


 91%|█████████ | 839/924 [14:01<00:43,  1.94it/s]

[Epoch 1/1]  Batch: 840  Loss: -0.24267
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.13it/s]
 91%|█████████ | 840/924 [14:17<07:04,  5.05s/it]

Evaluation Results: {'forward_r1': 1.2, 'forward_r5': 4.6, 'forward_r10': 8.3, 'forward_ravg': 4.7, 'backward_r1': 1.2, 'backward_r5': 4.7, 'backward_r10': 7.9, 'backward_recall': '1.2/4.7/7.9', 'backward_ravg': 4.6, 'gap': 0.1651, 'mean_angular_value_image': 0.0181, 'mean_angular_value_text': 0.0084, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.2072}


 94%|█████████▍| 869/924 [14:32<00:28,  1.93it/s]

[Epoch 1/1]  Batch: 870  Loss: -0.46873
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.12it/s]
 94%|█████████▍| 870/924 [14:47<04:32,  5.04s/it]

Evaluation Results: {'forward_r1': 1.2, 'forward_r5': 4.8, 'forward_r10': 8.2, 'forward_ravg': 4.7, 'backward_r1': 1.1, 'backward_r5': 4.6, 'backward_r10': 8.0, 'backward_recall': '1.1/4.6/8.0', 'backward_ravg': 4.6, 'gap': 0.1702, 'mean_angular_value_image': 0.0274, 'mean_angular_value_text': 0.0072, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.2145}


 97%|█████████▋| 899/924 [15:02<00:12,  1.92it/s]

[Epoch 1/1]  Batch: 900  Loss: -0.32680
Evaluating model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.12it/s]
 97%|█████████▋| 900/924 [15:18<02:03,  5.13s/it]

Evaluation Results: {'forward_r1': 1.1, 'forward_r5': 5.1, 'forward_r10': 8.7, 'forward_ravg': 5.0, 'backward_r1': 1.3, 'backward_r5': 4.7, 'backward_r10': 8.0, 'backward_recall': '1.3/4.7/8.0', 'backward_ravg': 4.7, 'gap': 0.1992, 'mean_angular_value_image': 0.0251, 'mean_angular_value_text': 0.0113, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.2085}


100%|██████████| 924/924 [15:30<00:00,  1.01s/it]


[Epoch 1/1]  Loss: -0.1194
Training complete.

Final evaluation of the model...


Evaluating: 100%|██████████| 39/39 [00:09<00:00,  4.06it/s]


Evaluation Results: {'forward_r1': 1.4, 'forward_r5': 4.8, 'forward_r10': 8.4, 'forward_ravg': 4.9, 'backward_r1': 1.1, 'backward_r5': 4.9, 'backward_r10': 8.2, 'backward_recall': '1.1/4.9/8.2', 'backward_ravg': 4.7, 'gap': 0.1592, 'mean_angular_value_image': 0.0187, 'mean_angular_value_text': 0.0075, 'uniformity': nan, 'mean_cosine_similarity_true_pairs': 0.2217}
Evaluation complete.



0,1
backward_r1,▁▁▁▂▂▃▂▂▃▃▃▃▄▄▅▅▄▄▄▄▅▅▆▆▆█▇▇▆▇
backward_r10,▁▂▂▂▂▂▃▃▄▅▄▄▄▅▆▅▆▆▆▆▆▆▇▇▇▇████
backward_r5,▁▁▂▂▂▂▂▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▇▆▇▇██▇█
backward_ravg,▁▁▂▂▂▂▃▃▄▄▄▄▄▄▅▅▅▆▆▆▆▆▇▇▇█████
forward_r1,▁▂▂▂▂▂▃▄▃▄▃▄▅▄▅▆▅▆▆▆▆▆▇▇▇██▇▇▇
forward_r10,▁▂▂▂▂▃▃▃▄▄▄▄▄▅▅▅▆▅▆▆▆▆▇▇▇▇▇███
forward_r5,▁▂▂▂▂▂▃▃▃▄▄▄▄▅▅▅▆▅▅▆▆▆▇▇▇▇▇▇██
forward_ravg,▁▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▆▆▆▆▆▆▇▇▇▇████
gap,█▂▂▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_angular_value_image,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
backward_r1,1.3
backward_r10,8
backward_r5,4.7
backward_ravg,4.7
backward_recall,1.3/4.7/8.0
forward_r1,1.1
forward_r10,8.7
forward_r5,5.1
forward_ravg,5
gap,0.1992


In [44]:
""" %doctest_mode


def dataset_details():
    # Print dataset details
    print('Number of samples:', len(train_coco)) # 118287 images

    # Access a specific sample (4th sample here)
    img, target = train_coco[3]  # Load the 4th sample (index 3)

    # Display information about the sample
    print("Image Size:", img.size())  # Torch tensor size
    #plt.imshow(img.permute(1, 2, 0))  # Display the image
    print("Captions:", target)  # Captions for the image

for images, captions_list in train_loader:
    # images.shape is e.g. (N, 3, 224, 224)
    # captions_list has length N, but each item might be a tuple of possible captions

    plt.imshow(images[0].permute(1, 2, 0))
    plt.show()
    plt.imshow(images[1].permute(1, 2, 0))
    plt.show()

    print("Image batch size:", images.shape[0], "Shape:", images.shape)
    print("Captions list length:", len(captions_list))
    
    print("Captions list:", list(captions_list))

    print("Number of chosen captions:", len(list(captions_list[0])))
    
    captions = list(captions_list[0])

    # Then tokenize
    text_tokens = tokenizer.tokenize(captions)
    print("Text tokens shape:", text_tokens.shape)

    # Now encode
    #image_embeds = model.encode_image(images.to(device))
    #text_embeds = model.encode_text(text_tokens.to(device))

    # Should both be shape (N, D)
    #print("Image embeds shape:", image_embeds.shape)
    #print("Text  embeds shape:", text_embeds.shape)

    break  # just to test one batch
    

def collate_fn_debug(batch):
    print("Bath type:", type(batch)) # This is a list
    print("Batch size:", len(batch))
    print("Batch:", batch)
    images, captions = zip(*batch)
    
    print("Images type:", type(images))
    print("Images size:", len(images))
    print("Images:", images)
    
    print("Captions type:", type(captions))
    print("Captions size:", len(captions))
    print("Captions:", captions) # This is a tuple of lists, each list contains 5 captions for each image
    
    # Select one caption per image
    sel_captions = []
    for list_captions in captions:
        #print("List Captions:", list_captions)
        caption = random.choice(list_captions)
        sel_captions.append(caption)
    
    print("Selected Captions:", sel_captions)    



for images, captions_list in train_loader:
    break

# DONE: ensure that each tuple of captions has the same length, or the data loader will fail (defalut is collate(samples, collate_fn_map=collate_fn_map) from error message)

 """

' %doctest_mode\n\n\ndef dataset_details():\n    # Print dataset details\n    print(\'Number of samples:\', len(train_coco)) # 118287 images\n\n    # Access a specific sample (4th sample here)\n    img, target = train_coco[3]  # Load the 4th sample (index 3)\n\n    # Display information about the sample\n    print("Image Size:", img.size())  # Torch tensor size\n    #plt.imshow(img.permute(1, 2, 0))  # Display the image\n    print("Captions:", target)  # Captions for the image\n\nfor images, captions_list in train_loader:\n    # images.shape is e.g. (N, 3, 224, 224)\n    # captions_list has length N, but each item might be a tuple of possible captions\n\n    plt.imshow(images[0].permute(1, 2, 0))\n    plt.show()\n    plt.imshow(images[1].permute(1, 2, 0))\n    plt.show()\n\n    print("Image batch size:", images.shape[0], "Shape:", images.shape)\n    print("Captions list length:", len(captions_list))\n    \n    print("Captions list:", list(captions_list))\n\n    print("Number of chosen 