In [1]:
!pip install -q timm pandas tqdm albumentations opencv-python scikit-learn transformers torch torchvision torchaudio

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m79.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m67.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m46.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [6]:
import os
import cv2
import numpy as np
import pandas as pd
import itertools # For optimizer params later
from tqdm.autonotebook import tqdm
import albumentations as A
import matplotlib.pyplot as plt # For potential plotting, not explicitly used in core logic yet
import torch
from torch import nn
import torch.nn.functional as F
import timm
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
from collections import defaultdict # For metric calculation later


  from tqdm.autonotebook import tqdm
  check_for_updates()


### Block 1: Flickr8k Data Preparation & Initial Setup

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
zip_path1 = '/content/drive/MyDrive/Flickr8k.zip'

!ls -lh "$zip_path1"

!unzip -q "$zip_path1" -d "/content/sample_data/Flickr8k"

-rw------- 1 root root 1.1G Jun  1 06:35 /content/drive/MyDrive/Flickr8k.zip


In [4]:
# --- Configuration for Flickr8k Paths ---
# IMPORTANT: Update these paths to where your Flickr8k dataset is located if they are different
FLICKR8K_IMAGES_DIR_PATH = "/content/sample_data/Flickr8k/Images"
FLICKR8K_TOKEN_FILE_PATH = "/content/sample_data/Flickr8k/captions.txt"

print(f"Using Flickr8k Images Path: {FLICKR8K_IMAGES_DIR_PATH}")
print(f"Using Flickr8k Captions File: {FLICKR8K_TOKEN_FILE_PATH}")

Using Flickr8k Images Path: /content/sample_data/Flickr8k/Images
Using Flickr8k Captions File: /content/sample_data/Flickr8k/captions.txt


In [10]:
# Read the Flickr8k captions file
#    Assuming each line is in the format:
#        image_name.jpg, A caption describing the image...
# image,caption
#1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set of stairs in an entry way .
#    Use names=['image','caption'] since there is no header row.
df = pd.read_csv(
    FLICKR8K_TOKEN_FILE_PATH,
    sep=",",
    names=["image", "caption"],
    header=None,
    engine="python",
)

# 2. Strip any leading/trailing whitespace from caption text
df["caption"] = df["caption"].str.strip()

# 3. Assign a caption_number to each caption (0–4) per image.
#    This assumes that each image appears exactly 5 times, in sequence.
#    If your file isn’t guaranteed to be grouped by image, you can still use groupby().
df["caption_number"] = df.groupby("image").cumcount()

# 4. Create an integer ID for each image.
#    factorize() assigns a unique integer ID (0, 1, 2, …) to each distinct image filename.
df["id"] = df["image"].factorize()[0]

# 5. Reorder columns if you want them in the same order as your original snippet:
df = df[["image", "caption_number", "caption", "id"]]

# 6. Save to CSV
# Save to a new CSV (in the current directory)
PROCESSED_CAPTIONS_CSV_PATH = "/content/sample_data/Flickr8k/flickr8k_captions.csv"
df.to_csv(PROCESSED_CAPTIONS_CSV_PATH, index=False)

# 7. (Optional) Peek at the first few rows
print(df.head())

                       image  caption_number  \
0                      image               0   
1  1000268201_693b08cb0e.jpg               0   
2  1000268201_693b08cb0e.jpg               1   
3  1000268201_693b08cb0e.jpg               2   
4  1000268201_693b08cb0e.jpg               3   

                                             caption  id  
0                                            caption   0  
1  A child in a pink dress is climbing up a set o...   1  
2              A girl going into a wooden building .   1  
3   A little girl climbing into a wooden playhouse .   1  
4  A little girl climbing the stairs to her playh...   1  


In [11]:
print("\nProcessed Flickr8k DataFrame head:")
print(df.head())
print(f"Processed captions saved to: {os.path.abspath(PROCESSED_CAPTIONS_CSV_PATH)}")



Processed Flickr8k DataFrame head:
                       image  caption_number  \
0                      image               0   
1  1000268201_693b08cb0e.jpg               0   
2  1000268201_693b08cb0e.jpg               1   
3  1000268201_693b08cb0e.jpg               2   
4  1000268201_693b08cb0e.jpg               3   

                                             caption  id  
0                                            caption   0  
1  A child in a pink dress is climbing up a set o...   1  
2              A girl going into a wooden building .   1  
3   A little girl climbing into a wooden playhouse .   1  
4  A little girl climbing the stairs to her playh...   1  
Processed captions saved to: /content/sample_data/Flickr8k/flickr8k_captions.csv


### Block 2: CFG, Utility Classes, and Dataset Class

In [12]:
class CFG:
    debug = False # Set to True for smaller dataset runs and fewer epochs
    image_path = FLICKR8K_IMAGES_DIR_PATH # From Block 1
    captions_path = "." # Directory where processed_captions_file is saved
    processed_captions_file = PROCESSED_CAPTIONS_CSV_PATH # From Block 1

    batch_size = 32
    num_workers = 2 # Adjust based on your environment (0 often safest for notebooks)
    head_lr = 1e-3
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    weight_decay = 1e-3 # For projection heads in optimizer setup
    patience = 2 # For ReduceLROnPlateau
    factor = 0.5 # For ReduceLROnPlateau
    epochs = 3 if debug else 10 # Reduced for faster runs, adjust as needed
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model_name = 'resnet50'
    image_embedding = 2048 # Output dim of ResNet50 before projection
    text_encoder_model = "distilbert-base-uncased"
    text_embedding = 768   # Output dim of DistilBert before projection
    text_tokenizer = "distilbert-base-uncased" # Tokenizer model name
    max_length = 200 # Max sequence length for captions

    pretrained = True # Use pretrained weights for image and text encoders
    trainable = True # Fine-tune image and text encoders
    temperature = 1.0 # Initial value for temperature (used as fixed in this CLIPModel)

    size = 224 # Image size

    num_projection_layers = 1 # Informational, as ProjectionHead structure is fixed
    projection_dim = 256 # Dimension of the joint embedding space
    dropout = 0.1

cfg = CFG() # Create an instance of the configuration

In [13]:
#UTILS
class AvgMeter:
    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        self.avg, self.sum, self.count = [0] * 3

    def update(self, val, count=1):
        self.count += count
        self.sum += val * count
        if self.count > 0:
            self.avg = self.sum / self.count
        else:
            self.avg = 0


    def __repr__(self):
        text = f"{self.name}: {self.avg:.4f}"
        return text

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group["lr"] # Returns the LR of the first param group

In [14]:
class CLIPDataset(torch.utils.data.Dataset):
    def __init__(self, image_filenames, captions, tokenizer, transforms):
        self.image_filenames = image_filenames
        self.captions = list(captions)
        self.encoded_captions = tokenizer(
            self.captions, padding=True, truncation=True, max_length=cfg.max_length
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }
        image_file_path = os.path.join(cfg.image_path, self.image_filenames[idx])
        try:
            image = cv2.imread(image_file_path)
            if image is None:
                # print(f"Warning: Failed to read image {image_file_path} at index {idx}. Using placeholder.")
                image = np.zeros((cfg.size, cfg.size, 3), dtype=np.uint8) # Placeholder

            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            transformed = self.transforms(image=image)
            image_tensor = transformed['image']

            if not isinstance(image_tensor, torch.Tensor):
                 image_tensor = torch.from_numpy(image_tensor)

            if image_tensor.ndim == 3 and image_tensor.shape[0] != 3 and image_tensor.shape[2] in [1, 3, 4]: # HWC check (common if ToTensorV2 not last)
                image_tensor = image_tensor.permute(2, 0, 1)

            item['image'] = image_tensor.float()
            item['caption_text'] = self.captions[idx]
        except Exception as e:
            print(f"Error processing item {idx}, image {image_file_path}: {e}")
            item['image'] = torch.zeros((3, cfg.size, cfg.size)).float() # Fallback image
            item['caption_text'] = "error placeholder caption"
             # Ensure tokenized keys still exist even if error
            if 'input_ids' not in item:
                 temp_tokenizer = DistilBertTokenizer.from_pretrained(cfg.text_tokenizer)
                 dummy_tokens = temp_tokenizer(item['caption_text'],padding="max_length",truncation=True,max_length=cfg.max_length, return_tensors="pt")
                 item.update({key: values[0] for key, values in dummy_tokens.items()})
        return item

    def __len__(self):
        return len(self.captions)

def get_transforms(mode="train"):
    # Using standard ImageNet mean/std
    return A.Compose(
        [
            A.Resize(cfg.size, cfg.size, always_apply=True),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, always_apply=True),
            # If not using A.ToTensorV2(), manual conversion and permute needed in Dataset as currently done.
            # ToTensorV2(), # If used, it should be the last transform. It also converts HWC -> CHW.
        ]
    )

print(f"\nConfiguration instance `cfg` created. Device: {cfg.device}")
print(f"CFG Image path: {cfg.image_path}")
print(f"CFG Processed captions file: {os.path.join(cfg.captions_path, cfg.processed_captions_file)}")


Configuration instance `cfg` created. Device: cuda
CFG Image path: /content/sample_data/Flickr8k/Images
CFG Processed captions file: /content/sample_data/Flickr8k/flickr8k_captions.csv


### Block 3: Model Architecture and Loss Function
  

In [15]:
class ImageEncoder(nn.Module):
    def __init__(
        self, model_name=cfg.model_name, pretrained=cfg.pretrained, trainable=cfg.trainable
    ):
        super().__init__()
        self.model = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        )
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        return self.model(x)

In [16]:
class TextEncoder(nn.Module):
    def __init__(self, model_name=cfg.text_encoder_model, pretrained=cfg.pretrained, trainable=cfg.trainable):
        super().__init__()
        if pretrained:
            self.model = DistilBertModel.from_pretrained(model_name)
        else:
            # For non-pretrained, still use the config of the specified model to match dimensions etc.
            config = DistilBertConfig.from_pretrained(model_name)
            self.model = DistilBertModel(config=config)

        for p in self.model.parameters():
            p.requires_grad = trainable
        self.target_token_idx = 0 # Index of CLS token

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :] # CLS token embedding

In [17]:
class ProjectionHead(nn.Module):
    def __init__(
        self,
        embedding_dim, # Input embedding dimension from the encoder
        projection_dim=cfg.projection_dim, # Output dimension (joint embedding space)
        dropout=cfg.dropout
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x_gelu = self.gelu(projected)
        x_fc = self.fc(x_gelu)
        x_dropout = self.dropout(x_fc)
        x_residual = x_dropout + projected # Residual connection
        x_norm = self.layer_norm(x_residual)
        return x_norm

In [18]:
class CLIPModel(nn.Module):
    def __init__(
        self,
        temperature=cfg.temperature, # Fixed temperature from CFG
        image_embedding_dim=cfg.image_embedding, # Output dim of image_encoder
        text_embedding_dim=cfg.text_embedding,   # Output dim of text_encoder
    ):
        super().__init__()
        self.image_encoder = ImageEncoder() # Uses defaults from cfg
        self.text_encoder = TextEncoder()   # Uses defaults from cfg
        self.image_projection = ProjectionHead(embedding_dim=image_embedding_dim)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding_dim)
        self.temperature = temperature # Fixed temperature, not a learnable nn.Parameter here

    def encode_image(self, image_tensor):
        image_features = self.image_encoder(image_tensor)
        image_embeddings = self.image_projection(image_features)
        return F.normalize(image_embeddings, p=2, dim=-1)

    def encode_text(self, input_ids, attention_mask):
        text_features = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_embeddings = self.text_projection(text_features)
        return F.normalize(text_embeddings, p=2, dim=-1)

    def forward(self, batch):
        image_embeddings = self.encode_image(batch["image"])
        text_embeddings = self.encode_text(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )

        # Calculating the Loss using the user's specified formulation
        logits = (text_embeddings @ image_embeddings.T) / self.temperature

        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        # Targets based on similarities, scaled by temperature
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2.0 * self.temperature, dim=-1
        )

        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none') # Assumes targets is symmetric or handled by .T
        loss =  (images_loss + texts_loss) / 2.0
        return loss.mean()

def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(dim=1) # Sum over the class dimension
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()

print("\nModel architecture classes (ImageEncoder, TextEncoder, ProjectionHead, CLIPModel) and loss function defined.")


Model architecture classes (ImageEncoder, TextEncoder, ProjectionHead, CLIPModel) and loss function defined.


### Block 4: Training Loop and Data Loaders Setup

In [19]:
def make_train_valid_dfs():
    captions_file_to_read = os.path.join(cfg.captions_path, cfg.processed_captions_file)
    print(f"Reading captions for train/valid split from: {captions_file_to_read}")
    try:
        dataframe = pd.read_csv(captions_file_to_read)
        if dataframe.empty:
            print("Error: Loaded dataframe for splitting is empty.")
            return pd.DataFrame(), pd.DataFrame()
    except FileNotFoundError:
        print(f"Error: Processed captions file not found at {captions_file_to_read}. Cannot create train/valid splits.")
        return pd.DataFrame(), pd.DataFrame()

    if "id" not in dataframe.columns or dataframe["id"].isnull().all():
        print("Error: 'id' column is missing or all null in the dataframe. Cannot split data by image ID.")
        return pd.DataFrame(), pd.DataFrame()

    valid_id_values = dataframe["id"].dropna().astype(int)
    if valid_id_values.empty:
        print("Error: No valid 'id' values found for splitting.")
        return pd.DataFrame(), pd.DataFrame()

    unique_image_ids_in_df = sorted(valid_id_values.unique())

    if cfg.debug:
        debug_max_unique_ids = min(20, len(unique_image_ids_in_df)) # e.g., 20 unique images for debug
        image_ids_to_split = unique_image_ids_in_df[:debug_max_unique_ids]
        print(f"Debug mode: Using {len(image_ids_to_split)} unique image IDs for train/valid split.")
    else:
        image_ids_to_split = unique_image_ids_in_df

    if not image_ids_to_split: # Check if list is empty
        print("No image IDs available for splitting after filtering (or in debug mode).")
        return pd.DataFrame(), pd.DataFrame()

    np.random.seed(42)
    num_valid_samples = int(0.2 * len(image_ids_to_split))
    if len(image_ids_to_split) > 0 and num_valid_samples == 0 : # Ensure at least 1 validation ID if possible
        num_valid_samples = 1
    if num_valid_samples > len(image_ids_to_split): # If 0.2*len is > len (only if len<5)
        num_valid_samples = max(0, len(image_ids_to_split) -1 ) # Ensure at least 1 train ID if possible


    valid_image_ids = np.random.choice(
        image_ids_to_split, size=num_valid_samples, replace=False
    )
    train_image_ids = [id_ for id_ in image_ids_to_split if id_ not in valid_image_ids]

    train_dataframe = dataframe[dataframe["id"].isin(train_image_ids)].reset_index(drop=True)
    valid_dataframe = dataframe[dataframe["id"].isin(valid_image_ids)].reset_index(drop=True)

    print(f"Train DataFrame shape: {train_dataframe.shape} ({len(train_image_ids)} unique images)")
    print(f"Valid DataFrame shape: {valid_dataframe.shape} ({len(valid_image_ids)} unique images)")
    return train_dataframe, valid_dataframe


In [20]:
def build_loaders(dataframe, tokenizer, mode):
    if dataframe.empty or not all(col in dataframe.columns for col in ["image", "caption"]):
        print(f"Warning: DataFrame for mode '{mode}' is empty or missing key columns. Returning None for DataLoader.")
        return None

    transforms = get_transforms(mode=mode) # mode is "train" or "valid"/"test"
    dataset = CLIPDataset(
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=transforms,
    )

    if len(dataset) == 0:
        print(f"Warning: Dataset for mode '{mode}' is empty after initialization. Returning None for DataLoader.")
        return None

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        shuffle=True if mode == "train" else False,
        pin_memory=True if cfg.device == "cuda" else False
    )
    return dataloader


In [23]:
def train_epoch(model, train_loader, optimizer, lr_scheduler, scheduler_step_mode):
    loss_meter = AvgMeter()
    model.train() # Ensure model is in training mode
    tqdm_object = tqdm(train_loader, total=len(train_loader), desc=f"Training Epoch")
    for batch in tqdm_object:
        batch_on_device = {}
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch_on_device[k] = v.to(cfg.device)

        loss = model(batch_on_device)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if scheduler_step_mode == "batch":
            if lr_scheduler is not None:
                 lr_scheduler.step()

        count = batch_on_device["image"].size(0)
        loss_meter.update(loss.item(), count)
        tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
    return loss_meter

def valid_epoch(model, valid_loader):
    loss_meter = AvgMeter()
    model.eval() # Ensure model is in evaluation mode
    tqdm_object = tqdm(valid_loader, total=len(valid_loader), desc=f"Validation Epoch")
    with torch.no_grad(): # No gradients needed for validation
        for batch in tqdm_object:
            batch_on_device = {}
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch_on_device[k] = v.to(cfg.device)

            loss = model(batch_on_device)

            count = batch_on_device["image"].size(0)
            loss_meter.update(loss.item(), count)
            tqdm_object.set_postfix(valid_loss=loss_meter.avg)
    return loss_meter

In [24]:
# --- Main Training Script Execution ---
print("\n--- Starting CLIP Model Training Process ---")

# 1. Create train and valid DataFrames
# `df` (the full dataframe) should be available from Block 1
if 'df' not in locals() or df.empty:
    print("Base DataFrame 'df' is not available or empty. Attempting to reload...")
    try:
        df = pd.read_csv(os.path.join(cfg.captions_path, cfg.processed_captions_file))
        if df.empty: raise FileNotFoundError # Trigger except if reloaded df is empty
    except FileNotFoundError:
        print(f"Fatal: Could not load base DataFrame from {os.path.join(cfg.captions_path, cfg.processed_captions_file)}. Halting.")
        # exit() # Or handle more gracefully depending on execution environment


--- Starting CLIP Model Training Process ---


In [25]:
if not df.empty:
    train_df, valid_df = make_train_valid_dfs()

    if train_df.empty or valid_df.empty:
        print("Train or Valid DataFrame is empty after split. Halting training.")
    else:
        # 2. Initialize Tokenizer
        tokenizer = DistilBertTokenizer.from_pretrained(cfg.text_tokenizer)
        print("Tokenizer initialized.")

        # 3. Build DataLoaders
        train_loader = build_loaders(train_df, tokenizer, mode="train")
        valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

        if train_loader is None or valid_loader is None:
            print("Failed to build DataLoaders. Halting training.")
        else:
            print("DataLoaders built successfully.")
            # 4. Initialize Model
            model = CLIPModel().to(cfg.device)
            print(f"CLIPModel initialized on {cfg.device}.")

            # 5. Setup Optimizer
            params = [
                {"params": model.image_encoder.parameters(), "lr": cfg.image_encoder_lr, "weight_decay": 0.0}, # Explicitly no wd for encoders here
                {"params": model.text_encoder.parameters(), "lr": cfg.text_encoder_lr, "weight_decay": 0.0},  # Explicitly no wd for encoders here
                {"params": itertools.chain(
                    model.image_projection.parameters(), model.text_projection.parameters()
                ), "lr": cfg.head_lr, "weight_decay": cfg.weight_decay} # wd only for projection heads
            ]
            optimizer = torch.optim.AdamW(params) # AdamW will apply its own default wd if not set per group or here
            print("Optimizer AdamW initialized.")

            # 6. Setup LR Scheduler
            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode="min", patience=cfg.patience, factor=cfg.factor, verbose=True
            )
            print("LR Scheduler (ReduceLROnPlateau) initialized.")

            scheduler_step_mode = "epoch" # For ReduceLROnPlateau

            best_loss = float('inf')
            best_epoch = -1

            # 7. Training Loop
            for epoch in range(cfg.epochs):
                print(f"\nEpoch: {epoch + 1}/{cfg.epochs}")

                train_loss_meter = train_epoch(model, train_loader, optimizer, lr_scheduler, scheduler_step_mode)
                print(f"Epoch {epoch+1} Train Loss: {train_loss_meter.avg:.4f}, LR: {get_lr(optimizer):.1e}")

                valid_loss_meter = valid_epoch(model, valid_loader)
                print(f"Epoch {epoch+1} Valid Loss: {valid_loss_meter.avg:.4f}")

                if valid_loss_meter.avg < best_loss:
                    best_loss = valid_loss_meter.avg
                    best_epoch = epoch + 1
                    torch.save({
                        'epoch': epoch + 1,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'best_loss': best_loss,
                        'cfg': vars(cfg) # Save config for reference
                        }, "/content/sample_data/Flickr8k/best_model_checkpoint.pt")
                    print(f"Saved Best Model at Epoch {epoch+1}! Validation Loss: {best_loss:.4f}")

                if scheduler_step_mode == "epoch":
                    lr_scheduler.step(valid_loss_meter.avg)

            print("\n--- Training Complete ---")
            print(f"Best Validation Loss: {best_loss:.4f} achieved at Epoch {best_epoch}")
            print(f"Best model saved to: best_model_checkpoint.pt")
else:
    print("Base DataFrame 'df' is empty. Training process cannot start.")


Reading captions for train/valid split from: /content/sample_data/Flickr8k/flickr8k_captions.csv
Train DataFrame shape: (32366, 4) (6474 unique images)
Valid DataFrame shape: (8090, 4) (1618 unique images)


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Tokenizer initialized.


  A.Resize(cfg.size, cfg.size, always_apply=True),
  A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, always_apply=True),


DataLoaders built successfully.


model.safetensors:   0%|          | 0.00/102M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

CLIPModel initialized on cuda.
Optimizer AdamW initialized.
LR Scheduler (ReduceLROnPlateau) initialized.

Epoch: 1/10


Training Epoch:   0%|          | 0/1012 [00:00<?, ?it/s]

Epoch 1 Train Loss: 3.2941, LR: 1.0e-04


Validation Epoch:   0%|          | 0/253 [00:00<?, ?it/s]

Epoch 1 Valid Loss: 3.3022
Saved Best Model at Epoch 1! Validation Loss: 3.3022

Epoch: 2/10


Training Epoch:   0%|          | 0/1012 [00:00<?, ?it/s]

Epoch 2 Train Loss: 3.2572, LR: 1.0e-04


Validation Epoch:   0%|          | 0/253 [00:00<?, ?it/s]

Epoch 2 Valid Loss: 3.2994
Saved Best Model at Epoch 2! Validation Loss: 3.2994

Epoch: 3/10


Training Epoch:   0%|          | 0/1012 [00:00<?, ?it/s]

Epoch 3 Train Loss: 3.2518, LR: 1.0e-04


Validation Epoch:   0%|          | 0/253 [00:00<?, ?it/s]

Epoch 3 Valid Loss: 3.3000

Epoch: 4/10


Training Epoch:   0%|          | 0/1012 [00:00<?, ?it/s]

Epoch 4 Train Loss: 3.2485, LR: 1.0e-04


Validation Epoch:   0%|          | 0/253 [00:00<?, ?it/s]

Epoch 4 Valid Loss: 3.2980
Saved Best Model at Epoch 4! Validation Loss: 3.2980

Epoch: 5/10


Training Epoch:   0%|          | 0/1012 [00:00<?, ?it/s]

Epoch 5 Train Loss: 3.2499, LR: 1.0e-04


Validation Epoch:   0%|          | 0/253 [00:00<?, ?it/s]

Epoch 5 Valid Loss: 3.2984

Epoch: 6/10


Training Epoch:   0%|          | 0/1012 [00:00<?, ?it/s]

Epoch 6 Train Loss: 3.2484, LR: 1.0e-04


Validation Epoch:   0%|          | 0/253 [00:00<?, ?it/s]

Epoch 6 Valid Loss: 3.3006

Epoch: 7/10


Training Epoch:   0%|          | 0/1012 [00:00<?, ?it/s]

Epoch 7 Train Loss: 3.2478, LR: 1.0e-04


Validation Epoch:   0%|          | 0/253 [00:00<?, ?it/s]

Epoch 7 Valid Loss: 3.3025

Epoch: 8/10


Training Epoch:   0%|          | 0/1012 [00:00<?, ?it/s]

Epoch 8 Train Loss: 3.2461, LR: 5.0e-05


Validation Epoch:   0%|          | 0/253 [00:00<?, ?it/s]

Epoch 8 Valid Loss: 3.3059

Epoch: 9/10


Training Epoch:   0%|          | 0/1012 [00:00<?, ?it/s]

Epoch 9 Train Loss: 3.2452, LR: 5.0e-05


Validation Epoch:   0%|          | 0/253 [00:00<?, ?it/s]

Epoch 9 Valid Loss: 3.2982

Epoch: 10/10


Training Epoch:   0%|          | 0/1012 [00:00<?, ?it/s]

Epoch 10 Train Loss: 3.2458, LR: 5.0e-05


Validation Epoch:   0%|          | 0/253 [00:00<?, ?it/s]

Epoch 10 Valid Loss: 3.3009

--- Training Complete ---
Best Validation Loss: 3.2980 achieved at Epoch 4
Best model saved to: best_model_checkpoint.pt


In [26]:
!mv ./best_model_checkpoint.pt /content/sample_data/Flickr8k/.

### Block 5: Embedding Generation (After Training)

In [27]:
print("\n--- Preparing for Embedding Generation (Post-Training) ---")

MODEL_CHECKPOINT_PATH = "/content/sample_data/Flickr8k/best_model_checkpoint.pt" # Path to the saved best model

# Ensure CFG and CLIPModel class are defined from previous blocks
# cfg instance should already be defined

if os.path.exists(MODEL_CHECKPOINT_PATH):
    loaded_model_for_embedding = CLIPModel().to(cfg.device)
    try:
        checkpoint = torch.load(MODEL_CHECKPOINT_PATH, map_location=cfg.device)
        loaded_model_for_embedding.load_state_dict(checkpoint['model_state_dict'])
        print(f"Model loaded successfully for embedding generation from {MODEL_CHECKPOINT_PATH}")
        loaded_model_for_embedding.eval()
    except Exception as e:
        print(f"Error loading model from {MODEL_CHECKPOINT_PATH}: {e}. Cannot generate embeddings.")
        loaded_model_for_embedding = None # Ensure it's None if loading failed
else:
    print(f"Model checkpoint {MODEL_CHECKPOINT_PATH} not found. Cannot generate embeddings.")
    loaded_model_for_embedding = None



--- Preparing for Embedding Generation (Post-Training) ---
Model loaded successfully for embedding generation from /content/sample_data/Flickr8k/best_model_checkpoint.pt


In [28]:
if loaded_model_for_embedding:
    # Initialize tokenizer and image transforms
    tokenizer_for_embedding = DistilBertTokenizer.from_pretrained(cfg.text_tokenizer)
    image_transforms_for_embedding = get_transforms(mode="valid")

    # Prepare the DataFrame for which you want to generate embeddings.
    # Example: using the 'valid_df' created during training.
    # Or, load the full 'df' if you want embeddings for the entire dataset.
    # For consistency, let's use the 'valid_df' that was defined during the training block, if available.
    # If not, we re-create it.

    if 'valid_df' not in locals() or valid_df.empty:
        print("`valid_df` not in scope or empty from training. Re-creating a validation split for embedding example.")
        if 'df' not in locals() or df.empty:
             print("Base DataFrame 'df' not found or empty. Reloading processed captions...")
             try:
                 df = pd.read_csv(os.path.join(cfg.captions_path, cfg.processed_captions_file))
                 if df.empty: raise ValueError("Reloaded DataFrame for splitting is empty.")
             except Exception as e_reload:
                 print(f"Could not load base DataFrame: {e_reload}")
                 df = pd.DataFrame() # empty df

        if not df.empty:
            _, current_target_df_for_embeddings = make_train_valid_dfs() # This gets a new split
            # Ideally, for evaluating the *trained* model, you'd use the exact same valid_df from training.
            # If `valid_df` from the training block is still in scope, prefer using that.
            # This re-split is for a self-contained example here.
            print(f"Using a new validation split (shape: {current_target_df_for_embeddings.shape}) for embedding generation.")
        else:
            print("Base DataFrame 'df' is empty. Cannot generate splits for embedding example.")
            current_target_df_for_embeddings = pd.DataFrame()
    else:
        print("Using `valid_df` from the training phase for embedding generation.")
        current_target_df_for_embeddings = valid_df # Use the one from the training block

    if not current_target_df_for_embeddings.empty:
        print(f"Target DataFrame for embeddings (shape: {current_target_df_for_embeddings.shape})")

        def generate_image_embeddings_dict(dataframe, image_column_name, model_to_use, transforms, cfg_obj_ref):
            if dataframe.empty or image_column_name not in dataframe.columns:
                print(f"Warning: DataFrame empty or missing '{image_column_name}'. Returning empty dict for image embeddings.")
                return {}
            unique_image_filenames = sorted(list(dataframe[image_column_name].unique()))
            image_embeddings_map = {}
            model_to_use.eval() # Ensure model is in eval mode
            print(f"Generating embeddings for {len(unique_image_filenames)} unique images...")
            with torch.no_grad():
                for image_filename in tqdm(unique_image_filenames, desc="Generating Image Embeddings"):
                    image_path = os.path.join(cfg_obj_ref.image_path, image_filename)
                    try:
                        image = cv2.imread(image_path)
                        if image is None: continue
                        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                        transformed = transforms(image=image)
                        img_tensor = transformed['image']
                        if not isinstance(img_tensor, torch.Tensor): img_tensor = torch.from_numpy(img_tensor)
                        if img_tensor.ndim == 3 and img_tensor.shape[0] != 3: img_tensor = img_tensor.permute(2, 0, 1)
                        img_tensor = img_tensor.unsqueeze(0).float().to(cfg_obj_ref.device)
                        embedding = model_to_use.encode_image(img_tensor)
                        image_embeddings_map[image_filename] = embedding.squeeze(0).cpu()
                    except Exception as e_img: print(f"Error for image {image_path}: {e_img}")
            return image_embeddings_map

        def generate_text_embeddings_dict(dataframe, img_col, cap_num_col, cap_text_col, model_to_use, tokenizer_ref, cfg_obj_ref):
            if dataframe.empty or not all(c in dataframe.columns for c in [img_col, cap_num_col, cap_text_col]):
                print("Warning: DataFrame empty or missing columns for text embeddings. Returning empty dict.")
                return {}
            text_embeddings_map = {}
            captions_data = []
            for _, row in dataframe.iterrows():
                caption_key = f"{row[img_col]}_cap_{row[cap_num_col]}"
                captions_data.append((caption_key, row[cap_text_col]))
            model_to_use.eval() # Ensure model is in eval mode
            print(f"Generating embeddings for {len(captions_data)} captions...")
            with torch.no_grad():
                for i in tqdm(range(0, len(captions_data), cfg_obj_ref.batch_size), desc="Generating Text Embeddings"):
                    batch_data = captions_data[i:i + cfg_obj_ref.batch_size]
                    batch_keys = [item[0] for item in batch_data]
                    batch_texts = [item[1] for item in batch_data]
                    try:
                        tokens = tokenizer_ref(batch_texts, padding=True, truncation=True, max_length=cfg_obj_ref.max_length, return_tensors="pt")
                        input_ids = tokens["input_ids"].to(cfg_obj_ref.device)
                        attention_mask = tokens["attention_mask"].to(cfg_obj_ref.device)
                        batch_text_embeddings = model_to_use.encode_text(input_ids, attention_mask)
                        for key, embedding in zip(batch_keys, batch_text_embeddings):
                            text_embeddings_map[key] = embedding.cpu()
                    except Exception as e_txt: print(f"Error for text batch starting with {batch_keys[0] if batch_keys else 'N/A'}: {e_txt}")
            return text_embeddings_map

        # Generate embeddings
        generated_image_embeddings = generate_image_embeddings_dict(
            dataframe=current_target_df_for_embeddings,
            image_column_name='image',
            model_to_use=loaded_model_for_embedding,
            transforms=image_transforms_for_embedding,
            cfg_obj_ref=cfg
        )
        print(f"Generated {len(generated_image_embeddings)} unique image embeddings.")

        generated_text_embeddings = generate_text_embeddings_dict(
            dataframe=current_target_df_for_embeddings,
            img_col='image',
            cap_num_col='caption_number',
            cap_text_col='caption',
            model_to_use=loaded_model_for_embedding,
            tokenizer_ref=tokenizer_for_embedding,
            cfg_obj_ref=cfg
        )
        print(f"Generated {len(generated_text_embeddings)} text embeddings.")

        # Save these embeddings for the next step
        EMBEDDINGS_OUTPUT_DIR = "/content/sample_data/Flickr8k/final_embeddings/"
        os.makedirs(EMBEDDINGS_OUTPUT_DIR, exist_ok=True)
        IMG_EMB_PATH = os.path.join(EMBEDDINGS_OUTPUT_DIR, "final_image_embeddings.pt")
        TXT_EMB_PATH = os.path.join(EMBEDDINGS_OUTPUT_DIR, "final_text_embeddings.pt")

        if generated_image_embeddings:
            torch.save(generated_image_embeddings, IMG_EMB_PATH)
            print(f"Image embeddings saved to {IMG_EMB_PATH}")
        if generated_text_embeddings:
            torch.save(generated_text_embeddings, TXT_EMB_PATH)
            print(f"Text embeddings saved to {TXT_EMB_PATH}")
    else:
        print("Target DataFrame for embeddings is empty. Skipping embedding generation.")
else:
    print("Skipping embedding generation as trained model could not be loaded.")


Using `valid_df` from the training phase for embedding generation.
Target DataFrame for embeddings (shape: (8090, 4))
Generating embeddings for 1618 unique images...


  A.Resize(cfg.size, cfg.size, always_apply=True),
  A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, always_apply=True),


Generating Image Embeddings:   0%|          | 0/1618 [00:00<?, ?it/s]

Generated 1618 unique image embeddings.
Generating embeddings for 8090 captions...


Generating Text Embeddings:   0%|          | 0/253 [00:00<?, ?it/s]

Generated 8090 text embeddings.
Image embeddings saved to ./flickr8k_final_embeddings/final_image_embeddings.pt
Text embeddings saved to ./flickr8k_final_embeddings/final_text_embeddings.pt


### Block 6: Retrieval Accuracy Calculation (Using Dumped Embeddings)

In [32]:
print("\n--- Preparing for Retrieval Accuracy Calculation ---")

# Top-K values for evaluation
TOP_K_VALUES_EVAL = list(range(1, 6)) # k from 1 to 5

# --- Metric Helper Functions (should be defined if not already in scope from Block 2) ---
# Assuming AvgMeter, get_lr are defined. We need compute_top_k_metrics, etc. for retrieval.

def compute_top_k_retrieval_metrics(ranked_keys, ground_truth_keys, current_top_k_values_list):
    result = {k: {"accuracy": 0, "precision": 0, "recall": 0} for k in current_top_k_values_list}
    ground_truth_set = set(ground_truth_keys)
    if not ground_truth_set:
        return result # All metrics remain 0
    for k_val in current_top_k_values_list:
        top_k = ranked_keys[:k_val]
        hits = len(set(top_k) & ground_truth_set)
        result[k_val]["accuracy"] = 1 if hits > 0 else 0
        result[k_val]["precision"] = hits / k_val if k_val > 0 else 0
        result[k_val]["recall"] = hits / len(ground_truth_set)
    return result

def accumulate_retrieval_metrics(overall_metrics_dict, current_query_metrics_dict):
    for k_val in current_query_metrics_dict:
        for metric_name in current_query_metrics_dict[k_val]:
            overall_metrics_dict[k_val][metric_name] += current_query_metrics_dict[k_val][metric_name]

def init_retrieval_metrics_accumulator(current_top_k_values_list):
    return {k: {"accuracy": 0, "precision": 0, "recall": 0} for k in current_top_k_values_list}

def report_retrieval_metrics(title, accumulated_metrics_dict, total_valid_queries_count, current_top_k_values_list):
    print(f"\n--- {title} Retrieval Results ({total_valid_queries_count} Queries) ---")
    if total_valid_queries_count == 0:
        print("No valid queries were processed.")
        return
    for k_val in current_top_k_values_list:
        acc = accumulated_metrics_dict[k_val]["accuracy"] / total_valid_queries_count
        prec = accumulated_metrics_dict[k_val]["precision"] / total_valid_queries_count
        rec = accumulated_metrics_dict[k_val]["recall"] / total_valid_queries_count
        print(f"Top-{k_val}: Accuracy = {acc:.4f} | Precision = {prec:.4f} | Recall = {rec:.4f}")



--- Preparing for Retrieval Accuracy Calculation ---


In [33]:
# --- Main Evaluation Function using Loaded Embeddings ---
def run_final_retrieval_evaluation(
    image_embeddings_dict_loaded,
    text_embeddings_dict_loaded,
    top_k_list,
    calc_device="cpu" # Device for cosine similarity calculation
):
    if not image_embeddings_dict_loaded or not text_embeddings_dict_loaded:
        print("Error: Loaded image or text embeddings dictionary is empty. Cannot proceed with evaluation.")
        return

    print("Preparing matrices for final evaluation...")
    image_ids_eval_list = list(image_embeddings_dict_loaded.keys())
    img_matrix_list = [image_embeddings_dict_loaded[id_].unsqueeze(0).to(calc_device) if image_embeddings_dict_loaded[id_].ndim==1 else image_embeddings_dict_loaded[id_].to(calc_device) for id_ in image_ids_eval_list]
    image_matrix_eval = torch.cat(img_matrix_list, dim=0).cpu().numpy()

    caption_keys_eval_list = list(text_embeddings_dict_loaded.keys())
    txt_matrix_list = [text_embeddings_dict_loaded[id_].unsqueeze(0).to(calc_device) if text_embeddings_dict_loaded[id_].ndim==1 else text_embeddings_dict_loaded[id_].to(calc_device) for id_ in caption_keys_eval_list]
    text_matrix_eval = torch.cat(txt_matrix_list, dim=0).cpu().numpy()

    print(f"Evaluation Image matrix shape: {image_matrix_eval.shape}")
    print(f"Evaluation Text matrix shape: {text_matrix_eval.shape}")

    gt_img_to_capkeys = defaultdict(list)
    gt_capkey_to_imgid = {}
    for cap_key in caption_keys_eval_list:
        try:
            img_id_part = cap_key.split('_cap_')[0]
            gt_img_to_capkeys[img_id_part].append(cap_key)
            gt_capkey_to_imgid[cap_key] = img_id_part
        except IndexError: continue
    print(f"Built GT map for {len(gt_img_to_capkeys)} images for evaluation.")

    # 1. Image -> Text Retrieval
    if image_matrix_eval.size > 0 and text_matrix_eval.size > 0:
        i2t_metrics = init_retrieval_metrics_accumulator(top_k_list)
        i2t_queries = 0
        for idx, img_id in enumerate(tqdm(image_ids_eval_list, desc="Eval: Img2Txt")):
            if img_id not in gt_img_to_capkeys: continue
            query_img_vec = image_matrix_eval[idx].reshape(1, -1)
            sims = cosine_similarity(query_img_vec, text_matrix_eval)[0]
            ranked_indices = np.argsort(sims)[::-1]
            ranked_caps = [caption_keys_eval_list[i] for i in ranked_indices]
            gt_caps = gt_img_to_capkeys[img_id]
            metrics = compute_top_k_retrieval_metrics(ranked_caps, gt_caps, top_k_list)
            accumulate_retrieval_metrics(i2t_metrics, metrics)
            i2t_queries += 1
        report_retrieval_metrics("Image-to-Text (Final Eval)", i2t_metrics, i2t_queries, top_k_list)

    # 2. Text -> Text Retrieval
    if text_matrix_eval.size > 0:
        t2t_metrics = init_retrieval_metrics_accumulator(top_k_list)
        t2t_queries = 0
        for idx, cap_key_query in enumerate(tqdm(caption_keys_eval_list, desc="Eval: Txt2Txt")):
            img_id_query = gt_capkey_to_imgid.get(cap_key_query)
            if not img_id_query: continue
            query_txt_vec = text_matrix_eval[idx].reshape(1, -1)
            sims = cosine_similarity(query_txt_vec, text_matrix_eval)[0]
            sims[idx] = -1e9 # Exclude self
            ranked_indices = np.argsort(sims)[::-1]
            ranked_caps = [caption_keys_eval_list[i] for i in ranked_indices]
            gt_caps = [ck for ck in gt_img_to_capkeys.get(img_id_query, []) if ck != cap_key_query]
            metrics = compute_top_k_retrieval_metrics(ranked_caps, gt_caps, top_k_list)
            accumulate_retrieval_metrics(t2t_metrics, metrics)
            t2t_queries +=1
        report_retrieval_metrics("Text-to-Text (Final Eval)", t2t_metrics, t2t_queries, top_k_list)

    # 3. Text -> Image Retrieval
    if text_matrix_eval.size > 0 and image_matrix_eval.size > 0:
        t2i_metrics = init_retrieval_metrics_accumulator(top_k_list)
        t2i_queries = 0
        for idx, cap_key_query in enumerate(tqdm(caption_keys_eval_list, desc="Eval: Txt2Img")):
            gt_img_id = gt_capkey_to_imgid.get(cap_key_query)
            if not gt_img_id or gt_img_id not in image_ids_eval_list : continue
            query_txt_vec = text_matrix_eval[idx].reshape(1, -1)
            sims = cosine_similarity(query_txt_vec, image_matrix_eval)[0]
            ranked_indices = np.argsort(sims)[::-1]
            ranked_imgs = [image_ids_eval_list[i] for i in ranked_indices]
            metrics = compute_top_k_retrieval_metrics(ranked_imgs, [gt_img_id], top_k_list)
            accumulate_retrieval_metrics(t2i_metrics, metrics)
            t2i_queries += 1
        report_retrieval_metrics("Text-to-Image (Final Eval)", t2i_metrics, t2i_queries, top_k_list)
    print("\n✅ Final Retrieval Accuracy Calculation Complete.")

In [35]:
from sklearn.metrics.pairwise import cosine_similarity
# --- Example Usage of Final Evaluation (after Block 5 saves embeddings) ---
if __name__ == '__main__': # This check ensures this part only runs when script is executed directly
    # Paths where embeddings were saved by Block 5
    SAVED_IMG_EMB_PATH = "/content/sample_data/Flickr8k/final_embeddings/final_image_embeddings.pt"
    SAVED_TXT_EMB_PATH = "/content/sample_data/Flickr8k/final_embeddings/final_text_embeddings.pt"

    print(f"\n--- Running Final Retrieval Accuracy Calculation Example ---")

    loaded_final_image_embeddings = {}
    loaded_final_text_embeddings = {}

    if os.path.exists(SAVED_IMG_EMB_PATH):
        loaded_final_image_embeddings = torch.load(SAVED_IMG_EMB_PATH, map_location='cpu')
        print(f"Loaded {len(loaded_final_image_embeddings)} final image embeddings.")
    else:
        print(f"Error: Final Image embeddings file not found at {SAVED_IMG_EMB_PATH}.")

    if os.path.exists(SAVED_TXT_EMB_PATH):
        loaded_final_text_embeddings = torch.load(SAVED_TXT_EMB_PATH, map_location='cpu')
        print(f"Loaded {len(loaded_final_text_embeddings)} final text embeddings.")
    else:
        print(f"Error: Final Text embeddings file not found at {SAVED_TXT_EMB_PATH}.")

    if loaded_final_image_embeddings and loaded_final_text_embeddings:
        run_final_retrieval_evaluation(
            image_embeddings_dict_loaded=loaded_final_image_embeddings,
            text_embeddings_dict_loaded=loaded_final_text_embeddings,
            top_k_list=TOP_K_VALUES_EVAL, # k=1 to 5
            calc_device = cfg.device # Use device from cfg for calculations if GPU available
        )
    else:
        print("Cannot run final evaluation due to missing loaded embeddings.")


--- Running Final Retrieval Accuracy Calculation Example ---
Loaded 1618 final image embeddings.
Loaded 8090 final text embeddings.
Preparing matrices for final evaluation...
Evaluation Image matrix shape: (1618, 256)
Evaluation Text matrix shape: (8090, 256)
Built GT map for 1618 images for evaluation.


Eval: Img2Txt:   0%|          | 0/1618 [00:00<?, ?it/s]


--- Image-to-Text (Final Eval) Retrieval Results (1618 Queries) ---
Top-1: Accuracy = 0.0012 | Precision = 0.0012 | Recall = 0.0002
Top-2: Accuracy = 0.0025 | Precision = 0.0012 | Recall = 0.0005
Top-3: Accuracy = 0.0043 | Precision = 0.0014 | Recall = 0.0009
Top-4: Accuracy = 0.0049 | Precision = 0.0012 | Recall = 0.0010
Top-5: Accuracy = 0.0049 | Precision = 0.0012 | Recall = 0.0012


Eval: Txt2Txt:   0%|          | 0/8090 [00:00<?, ?it/s]


--- Text-to-Text (Final Eval) Retrieval Results (8090 Queries) ---
Top-1: Accuracy = 0.1540 | Precision = 0.1540 | Recall = 0.0385
Top-2: Accuracy = 0.2058 | Precision = 0.1195 | Recall = 0.0597
Top-3: Accuracy = 0.2410 | Precision = 0.0990 | Recall = 0.0743
Top-4: Accuracy = 0.2667 | Precision = 0.0850 | Recall = 0.0850
Top-5: Accuracy = 0.2886 | Precision = 0.0748 | Recall = 0.0935


Eval: Txt2Img:   0%|          | 0/8090 [00:00<?, ?it/s]


--- Text-to-Image (Final Eval) Retrieval Results (8090 Queries) ---
Top-1: Accuracy = 0.0012 | Precision = 0.0012 | Recall = 0.0012
Top-2: Accuracy = 0.0025 | Precision = 0.0012 | Recall = 0.0025
Top-3: Accuracy = 0.0042 | Precision = 0.0014 | Recall = 0.0042
Top-4: Accuracy = 0.0051 | Precision = 0.0013 | Recall = 0.0051
Top-5: Accuracy = 0.0061 | Precision = 0.0012 | Recall = 0.0061

✅ Final Retrieval Accuracy Calculation Complete.


In [36]:
!pip freeze > "/content/sample_data/Flickr8k/requirements.txt"