In [None]:
#| default_exp evaluation.metrics

# Evaluation Metrics

> Functions to calculate cross-modal retrieval metrics (Recall@k and Mean Recall) and zero-shot classification accuracy for Indic-CLIP.

## Colab Setup

In [None]:
#| hide
# Mount Google Drive (Optional, but recommended for persistent storage)
from pathlib import Path
import sys
import os

def _prepend_to_sys_path(root):
    """Put `root` at the front of sys.path (and report it) unless it is already listed."""
    entry = str(root)
    if entry in sys.path:
        return
    sys.path.insert(0, entry)
    print(f"Added {root} to sys.path")


# Resolve PROJECT_ROOT for whichever environment the notebook runs in:
#   * Colab       -> mount Drive and use the project folder stored there
#   * local       -> use the current directory (or its parent when inside `nbs/`)
#   * setup error -> fall back to the resolved current directory
try:
    from google.colab import drive

    drive.mount('/content/drive')
    print("Google Drive mounted successfully.")

    PROJECT_ROOT = Path('/content/drive/MyDrive/Indic-Clip')  # Adjust path if needed
    if PROJECT_ROOT.exists():
        # Make the project importable, then work from inside it.
        _prepend_to_sys_path(PROJECT_ROOT)
        os.chdir(PROJECT_ROOT)
        print(f"Changed working directory to: {os.getcwd()}")
    else:
        print(f"Warning: Project directory not found at {PROJECT_ROOT}. Please ensure it exists.")
except ModuleNotFoundError:
    # `google.colab` is not importable, so this is a local run.
    print("Not running in Colab, skipping Drive mount.")
    PROJECT_ROOT = Path.cwd()
    # Notebooks live in `nbs/`; the project root is one level up in that case.
    PROJECT_ROOT = PROJECT_ROOT.parent if PROJECT_ROOT.name == 'nbs' else PROJECT_ROOT
    print(f"Running locally. Project root assumed: {PROJECT_ROOT}")
    _prepend_to_sys_path(PROJECT_ROOT)
except Exception as err:
    # Anything else that went wrong during the Colab setup (e.g. the mount itself).
    print(f"An error occurred during Colab setup: {err}")
    PROJECT_ROOT = Path('.').resolve()
    print(f"Defaulting project root to current dir: {PROJECT_ROOT}")

In [None]:
#| hide
# Install requirements if needed (especially in Colab)
# !pip install -qr requirements.txt
# !pip install scikit-learn # Ensure sklearn is installed

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import torch
import torch.nn.functional as F
import numpy as np
import logging
from typing import List, Dict, Callable
from fastcore.all import *
from sklearn.metrics import accuracy_score

# --- Project Imports ---
try:
    from indic_clip.core import get_logger, setup_logging
    from indic_clip.model.clip import IndicCLIP  # used for type hints and model methods
    from indic_clip.data.tokenization import IndicBERTTokenizer  # used for type hints
except ModuleNotFoundError:
    print("Could not import project modules in 11_evaluation_metrics.ipynb. Ensure path is correct.")
    # The project package is unavailable: fall back to the standard logging module
    # and to placeholder types so the annotations below still resolve.
    logging.basicConfig(level=logging.INFO)

    def get_logger(name):
        """Stand-in for the project's `get_logger`: returns the stdlib logger called `name`."""
        return logging.getLogger(name)

    def setup_logging():
        """Stand-in for the project's `setup_logging`: does nothing."""
        return None

    class IndicCLIP(torch.nn.Module):
        """Placeholder type used only for annotations when the project package is missing."""

    class IndicBERTTokenizer:
        """Placeholder type used only for annotations when the project package is missing."""

# Module-level logger shared by the metric functions in this notebook.
logger = get_logger(__name__)

## Cross-Modal Retrieval Metrics

In [None]:
#| export
def calculate_retrieval_metrics(
    image_features: torch.Tensor,
    text_features: torch.Tensor,
    logit_scale: torch.Tensor,
    k_values: List[int] = [1, 5, 10]
) -> Dict[str, float]:
    """
    Calculates Recall@k (R@k) and Mean Recall (MR) for image-to-text and text-to-image retrieval.

    Assumes image_features and text_features are L2 normalized and correspond to paired data
    along the batch dimension (i.e., image_i matches text_i).

    Args:
        image_features (torch.Tensor): Tensor of normalized image features (B, EmbeddingDim).
        text_features (torch.Tensor): Tensor of normalized text features (B, EmbeddingDim).
        logit_scale (torch.Tensor): The temperature scaling factor (scalar tensor, already exponentiated).
        k_values (List[int]): List of k values for which to calculate Recall@k. A k larger than
            the number of samples is evaluated over all samples instead of raising. The default
            list is only read, never mutated.

    Returns:
        Dict[str, float]: A dictionary containing calculated metrics like 'i2t_r@1', 't2i_r@5', 'mean_recall'.
            With an empty `k_values` only 'mean_recall' (0.0) is returned.

    Raises:
        ValueError: If image_features and text_features do not hold the same number of samples.
    """
    # The ground truth is "row i matches row i", which only makes sense for equally sized sets.
    if image_features.shape[0] != text_features.shape[0]:
        raise ValueError(
            "image_features and text_features must contain the same number of paired samples, "
            f"got {image_features.shape[0]} and {text_features.shape[0]}."
        )

    metrics: Dict[str, float] = {}
    k_list = list(k_values)
    if not k_list:
        # Nothing to rank: keep the documented "no recalls -> 0.0" behaviour instead of failing in max().
        metrics['mean_recall'] = 0.0
        logger.debug(f"Mean Recall (MR): {metrics['mean_recall']:.4f}")
        return metrics

    # Similarity matrix: rows are image queries, columns are text candidates.
    similarity = logit_scale * image_features @ text_features.t()
    num_samples = similarity.shape[0]

    # topk() raises when k exceeds the candidate count, so never ask for more than exist.
    top_k = min(max(k_list), num_samples)
    # Ground truth labels (query i should retrieve candidate i), shaped (B, 1) for broadcasting.
    gt_labels = torch.arange(num_samples, device=similarity.device).view(-1, 1)

    all_recalls = []
    # Image-to-text uses the matrix as is; text-to-image uses its transpose.
    for direction, scores in (("i2t", similarity), ("t2i", similarity.t())):
        ranked_indices = scores.topk(top_k, dim=1)[1]
        hits = ranked_indices == gt_labels
        for k in k_list:
            # A query counts as recalled when its partner appears among the first k results.
            recall_k = hits[:, :k].any(dim=1).float().mean().item()
            metrics[f'{direction}_r@{k}'] = recall_k
            all_recalls.append(recall_k)
            logger.debug(f"{direction} R@{k}: {recall_k:.4f}")

    # --- Mean Recall (MR) --- averaged over both directions and every k, as a plain Python float.
    metrics['mean_recall'] = float(np.mean(all_recalls))
    logger.debug(f"Mean Recall (MR): {metrics['mean_recall']:.4f}")

    return metrics

## Zero-Shot Classification Accuracy

In [None]:
#| export
def calculate_zeroshot_accuracy(
    image_features: torch.Tensor,
    image_labels: List[int] | np.ndarray | torch.Tensor,
    class_names: List[str],
    templates: List[str] | Callable[[str], str],
    model: IndicCLIP,
    tokenizer: IndicBERTTokenizer,
) -> float:
    """
    Calculates the Top-1 Zero-Shot Classification accuracy.

    Args:
        image_features (torch.Tensor): Tensor of normalized image features (B, EmbeddingDim).
        image_labels (List[int] | np.ndarray | torch.Tensor): Ground truth labels for the images (indices corresponding to class_names).
        class_names (List[str]): List of class names.
        templates (List[str] | Callable[[str], str]): List of prompt templates (e.g., "a photo of {}") 
                                                       or a function that takes a class name and returns a prompt.
        model (IndicCLIP): The trained IndicCLIP model instance.
        tokenizer (IndicBERTTokenizer): The tokenizer instance used by the model.

    Returns:
        float: The Top-1 zero-shot classification accuracy.
    """
    device = image_features.device
    num_classes = len(class_names)
    model.eval() # Ensure model is in evaluation mode

    logger.info(f"Starting zero-shot classification for {num_classes} classes.")

    # 1. Generate and encode text prompts for all classes
    all_prompts = []
    if callable(templates):
        for classname in class_names:
            all_prompts.append(templates(classname))
    else: # templates is a list of strings
        for template in templates:
            for classname in class_names:
                all_prompts.append(template.format(classname))
    
    logger.debug(f"Generated {len(all_prompts)} prompts.")

    with torch.no_grad(): # No need for gradients during text encoding
        # Tokenize prompts - assumes tokenizer handles batching
        tokenized_prompts = tokenizer.tokenize(all_prompts)
        # Move tensors within the tokenized output to the model's device
        tokenized_prompts = {k: v.to(device) for k, v in tokenized_prompts.items() if isinstance(v, torch.Tensor)}
        
        # Encode text prompts
        # Note: model.encode_text expects a dictionary like {'input_ids': ..., 'attention_mask': ...}
        # Ensure tokenizer output format matches this expectation.
        text_embeddings = model.encode_text(tokenized_prompts)
        
        # Normalize text embeddings
        text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)

        # Average embeddings if multiple templates were used
        if isinstance(templates, list) and len(templates) > 1:
            num_templates = len(templates)
            text_embeddings = text_embeddings.view(num_templates, num_classes, -1).mean(dim=0)
            text_embeddings = F.normalize(text_embeddings, p=2, dim=-1) # Re-normalize after averaging
        
        logger.debug(f"Encoded text embeddings shape: {text_embeddings.shape}")
            
        # 2. Calculate Similarity and Predict
        # Image features should already be normalized
        # Calculate cosine similarity (dot product of normalized features)
        similarity = image_features @ text_embeddings.t() # Shape: [B, NumClasses]
        logger.debug(f"Calculated similarity shape: {similarity.shape}")

        # Get predictions (index of the class with highest similarity)
        predictions = similarity.argmax(dim=1).cpu().numpy()
        logger.debug(f"Predictions shape: {predictions.shape}")

    # 3. Calculate Accuracy
    # Ensure labels are in numpy format
    if isinstance(image_labels, torch.Tensor):
        image_labels = image_labels.cpu().numpy()
    elif isinstance(image_labels, list):
        image_labels = np.array(image_labels)
        
    logger.debug(f"Ground truth labels shape: {image_labels.shape}")

    accuracy = accuracy_score(image_labels, predictions)
    logger.info(f"Zero-shot Top-1 Accuracy: {accuracy:.4f}")

    return accuracy

## Example Usage & Testing

In [None]:
#| eval: false
# Example usage for calculate_retrieval_metrics
if __name__ == '__main__':
    print("--- Testing Retrieval Metrics ---")
    torch.manual_seed(0)  # Fixed seed so this smoke test gives the same numbers on every run
    B = 4 # Batch size
    D = 512 # Embedding dimension
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Dummy normalized image features; each text feature is its paired image feature
    # plus a small amount of noise, so the diagonal of the similarity matrix dominates.
    img_feat = F.normalize(torch.randn(B, D, device=device), dim=-1)
    noise = F.normalize(torch.randn(B, D, device=device), dim=-1)
    txt_feat = F.normalize(img_feat * 0.8 + noise * 0.2, dim=-1)
    logit_scale = torch.ones([], device=device) * 100 # Use a high scale for testing
    k_vals = [1, 2]

    metrics = calculate_retrieval_metrics(img_feat, txt_feat, logit_scale, k_values=k_vals)
    print(f"Calculated Retrieval Metrics: {metrics}")
    assert 'i2t_r@1' in metrics
    assert 't2i_r@1' in metrics
    assert 'mean_recall' in metrics
    # Paired features are almost identical while unpaired ones are nearly orthogonal,
    # so every query must rank its own partner first.
    assert metrics['i2t_r@1'] == 1.0 and metrics['t2i_r@1'] == 1.0

    print("\n--- Testing Zero-Shot Classification (Requires Mock Model/Tokenizer) ---")

    # Lightweight stand-ins so the zero-shot function can be exercised without real weights
    class MockTokenizer:
        def tokenize(self, texts):
            # Simple tokenization: just create dummy tensors
            max_len = 10
            ids = torch.randint(0, 1000, (len(texts), max_len))
            mask = torch.ones_like(ids)
            return {'input_ids': ids, 'attention_mask': mask}

    class MockIndicCLIP(torch.nn.Module):
        def __init__(self, embed_dim):
            super().__init__()
            self.embed_dim = embed_dim
            # Dummy (initialised) parameter so the module has a device and can be moved with .to()
            self.dummy_param = torch.nn.Parameter(torch.zeros(1))

        def encode_text(self, tokenized_input):
            bs = tokenized_input['input_ids'].shape[0]
            # Return random normalized embeddings
            return F.normalize(torch.randn(bs, self.embed_dim, device=self.dummy_param.device), dim=-1)

        # eval() is inherited from nn.Module, which returns the module itself as callers expect.

    # Setup dummy data for zero-shot
    B_zs = 4 # Batch size for zero-shot test
    img_feat_zs = F.normalize(torch.randn(B_zs, D, device=device), dim=-1)
    img_labels_zs = [0, 1, 0, 2] # Example labels (indices)
    class_names_zs = ['cat', 'dog', 'bird']
    templates_zs = ["a photo of a {}"]
    mock_model = MockIndicCLIP(D).to(device)
    mock_tokenizer = MockTokenizer()

    accuracy = calculate_zeroshot_accuracy(
        img_feat_zs,
        img_labels_zs,
        class_names_zs,
        templates_zs,
        mock_model,
        mock_tokenizer
    )
    print(f"Calculated Zero-Shot Accuracy (dummy): {accuracy:.4f}")
    # The mock embeddings are random, so only the type and range of the result can be checked.
    assert isinstance(accuracy, float) and 0.0 <= accuracy <= 1.0

    print("\nTests completed.")

In [None]:
#| hide
# Export step: executing this cell calls nbdev's exporter directly from the notebook.
import nbdev
nbdev.nbdev_export() # Writes the `#| export` cells to the package modules; the terminal equivalent is the `nbdev_export` command