# Initial Evaluation Run

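> Load a trained Indic-CLIP checkpoint and evaluate it on the held-out validation split, focusing on cross-modal (image↔text) retrieval metrics. The split is rebuilt from `valid_pct` and `seed`, which must match the values used during training.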

## Setup and Imports

In [None]:
#| hide
# Make the project package importable and run from the project root, so that
# relative data/checkpoint paths resolve the same way in Colab and locally.
from pathlib import Path
import sys
import os

# Determine project root based on environment (colab vs local)
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    from google.colab import drive
    if not Path('/content/drive').exists():
        drive.mount('/content/drive')
    if Path('/content/drive/MyDrive').exists():
        project_parent = '/content/drive/MyDrive/Indic-Clip'  # Adjust if your path differs
        print(f"Assuming Project Directory in Google Drive: {project_parent}")
    else:
        project_parent = '/content/indic-clip'  # Adjust if cloned to /content
        print(f"Assuming Project Directory in Colab /content: {project_parent}")
else:
    # Notebooks normally live in 'nbs/'; otherwise assume cwd is the project root.
    current_path = Path.cwd()
    if current_path.name == 'nbs':
        project_parent = str(current_path.parent)
    else:
        project_parent = str(current_path)
    print(f"Assuming Project Directory (Local): {project_parent}")

project_path = Path(project_parent)
if project_path.exists():
    # sys.path and the working directory are handled independently: the root may
    # already be on sys.path (editable install, re-run of this cell) while the
    # cwd is still 'nbs/', and relative paths would then silently break.
    if str(project_path) not in sys.path:
        sys.path.insert(0, str(project_path))
        print(f"Added {project_path} to sys.path")
    else:
        print(f"Project path {project_path} already in sys.path.")

    if Path.cwd().resolve() != project_path.resolve():
        try:
            os.chdir(project_path)
            print(f"Changed working directory to: {os.getcwd()}")
        except OSError as e:
            print(f"Warning: Could not change directory to {project_path}. Error: {e}")
else:
    print(f"Project path {project_path} not found.")

# Verify import after path adjustment
try:
    import indic_clip.core
    print("Imported indic_clip.core successfully.")
except ModuleNotFoundError:
    print("ERROR: Still cannot find indic_clip.core. Ensure project structure and path are correct.")


In [None]:
#| hide
# Install requirements if needed (e.g., in Colab)
# !pip install -qr requirements.txt
# !pip install scikit-learn # For accuracy metrics if used

In [None]:
import torch
import pandas as pd
import numpy as np
import logging
from pathlib import Path
from tqdm.notebook import tqdm

# Project specific imports
from indic_clip.core import (
    get_logger, setup_logging, CHECKPOINT_PATH, TOKENIZER_PATH,
    PROCESSED_DATA_PATH, DEFAULT_IMAGE_SIZE, DEFAULT_EMBED_DIM, PRETRAINED_TOKENIZER_NAME
)
from indic_clip.data.creation import IndicCLIPDataBlock, get_indic_clip_items
from indic_clip.data.tokenization import IndicBERTTokenizer
from indic_clip.model.clip import IndicCLIP # Needed for type hints in load_indic_clip_model
from indic_clip.inference import load_indic_clip_model #, extract_image_features, extract_text_features, compute_similarity
from indic_clip.evaluation.metrics import calculate_retrieval_metrics #, calculate_zeroshot_accuracy
# from indic_clip.evaluation.benchmarks import load_benchmark_data, create_zeroshot_dataloader, DEFAULT_PROMPT_TEMPLATES_HI, DEFAULT_ZS_CATEGORIES_HI # For ZS evaluation

# Configure logging first, then create this notebook's logger. The order matters:
# setup_logging() presumably installs the project's handlers/format, which the
# logger returned by get_logger should pick up. TODO confirm in indic_clip.core.
setup_logging()
logger = get_logger(__name__)

## Configuration

In [None]:
# --- Evaluation Configuration ---
# Model architecture: every value here must match what the checkpoint was trained with.
checkpoint_name = 'best_valid_loss.pth'          # <<< checkpoint file to evaluate (e.g., from training)
model_vision_backbone = 'resnet50'               # <<< vision backbone used during training
model_text_backbone = PRETRAINED_TOKENIZER_NAME  # <<< text backbone used during training
model_embed_dim = 512                            # <<< embedding dimension used during training

# Data: valid_pct and seed re-create the training/validation split. If either
# differs from training, the "validation" items can overlap the training data.
processed_data_path = PROCESSED_DATA_PATH / 'filtered_data.jsonl'
tokenizer_path = TOKENIZER_PATH
img_size = DEFAULT_IMAGE_SIZE
max_seq_len = 128
valid_pct = 0.25  # <<< must equal training's valid_pct
seed = 42         # <<< must equal training's seed
batch_size = 64   # inference batch size; tune to available GPU memory
num_workers = 4

# Top-K cut-offs passed to the retrieval metric computation
k_values = [1, 5, 10]

# Prefer the GPU when one is visible to torch
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")

# Full path of the checkpoint to load
checkpoint_file = CHECKPOINT_PATH / checkpoint_name
logger.info(f"Attempting to load checkpoint: {checkpoint_file}")

## Load Model and Tokenizer

In [None]:
# Load the tokenizer, then the checkpointed model. If anything fails, the error is
# logged and `model` (and possibly `tokenizer`) stays None, so later cells skip their work.
model, tokenizer = None, None
try:
    # The tokenizer comes first because the model config carries the instance.
    tokenizer = IndicBERTTokenizer.load_tokenizer(tokenizer_path, max_length=max_seq_len)
    logger.info(f"Tokenizer loaded from {tokenizer_path}")

    # Architecture description expected by load_indic_clip_model. The pretrained
    # flags don't affect loading the checkpoint's state dict.
    model_config = dict(
        embed_dim=model_embed_dim,
        vision_model_name=model_vision_backbone,
        vision_pretrained=False,
        text_model_name=model_text_backbone,
        text_pretrained=False,
        tokenizer=tokenizer,
    )

    model = load_indic_clip_model(
        checkpoint_path=checkpoint_file,
        model_config=model_config,
        device=device,
    )
    logger.info("Model loaded successfully.")
except FileNotFoundError as err:
    logger.error(f"Error: {err}. Please ensure the checkpoint file exists and the path is correct.")
except Exception as err:
    logger.error(f"An unexpected error occurred loading model/tokenizer: {err}", exc_info=True)

## Prepare Validation DataLoader

In [None]:
# Rebuild the DataLoaders with the same valid_pct/seed as training so that
# `dls.valid` holds the held-out split, then keep only the validation loader.
valid_dl = None
# Explicit `is not None` checks (as in the metrics cell): an object's truthiness
# can be overridden (e.g. via __len__) and must not decide whether loading succeeded.
if model is not None and tokenizer is not None:
    try:
        logger.info("Loading data items for validation set...")
        items_df = get_indic_clip_items(data_path=processed_data_path)

        # Treat a None result like an empty frame instead of failing with an
        # AttributeError on `.empty`. NOTE(review): whether get_indic_clip_items
        # can return None is not visible here.
        if items_df is not None and not items_df.empty:
            logger.info("Instantiating DataBlock for validation...")
            eval_dblock = IndicCLIPDataBlock(
                tokenizer_name_or_path=model_text_backbone,
                tokenizer_save_path=tokenizer_path,
                max_length=max_seq_len,
                img_size=img_size,
                valid_pct=valid_pct,  # same split as training
                seed=seed,            # same seed as training
                batch_size=batch_size,
                num_workers=num_workers,
                use_augmentations=False  # no augmentations for evaluation
            )

            logger.info("Creating DataLoaders to get validation set...")
            # shuffle_train=False presumably only disables shuffling of the (unused)
            # training loader; the split itself is fixed by valid_pct/seed. TODO confirm.
            dls = eval_dblock.get_dataloaders(items_df, shuffle_train=False)
            valid_dl = dls.valid
            logger.info(f"Validation DataLoader created with {len(valid_dl.dataset)} items and {len(valid_dl)} batches.")
        else:
            logger.error("Failed to load items dataframe. Cannot create DataLoader.")

    except Exception as e:
        logger.error(f"An error occurred preparing the DataLoader: {e}", exc_info=True)
else:
    logger.warning("Model or Tokenizer not loaded. Skipping DataLoader creation.")

## Extract Features from Validation Set

In [None]:
# Encode every validation batch with the model and collect the image and text
# embeddings on the CPU in dataloader order, so row i of each tensor is pair i.
all_image_features = []
all_text_features = []
val_image_features = None
val_text_features = None
n_failed_batches = 0

if model is not None and valid_dl is not None:
    logger.info("Extracting features from validation set...")
    model.eval()  # Ensure model is in eval mode
    with torch.no_grad():
        for batch in tqdm(valid_dl, desc="Extracting Features"):
            # Assumed batch layout: (img_tensor, (text_ids, text_mask), dummy_target),
            # per the 3-way unpacking below. TODO confirm against IndicCLIPDataBlock.
            img_batch, txt_tuple_batch, _ = batch
            img_batch = img_batch.to(device)
            txt_ids_batch = txt_tuple_batch[0].to(device)
            txt_mask_batch = txt_tuple_batch[1].to(device)

            try:
                # Both CPU copies are made before anything is appended. On CUDA,
                # errors can surface at the synchronising `.cpu()`, and a half-appended
                # batch would misalign image and text rows.
                img_feat = model.encode_image(img_batch).cpu()
                txt_feat = model.encode_text(txt_ids_batch, txt_mask_batch).cpu()
            except Exception as e:
                n_failed_batches += 1
                logger.error(f"Error extracting features for a batch: {e}", exc_info=True)
                continue

            all_image_features.append(img_feat)
            all_text_features.append(txt_feat)

    if n_failed_batches:
        # Skipped batches shrink the candidate pool, so metrics are not comparable
        # to a full-set run. Make that visible instead of silent.
        logger.warning(f"{n_failed_batches} of {len(valid_dl)} batches failed and were skipped; metrics cover only the remaining pairs.")

    if all_image_features and all_text_features:
        # Concatenate all features into single tensors
        val_image_features = torch.cat(all_image_features)
        val_text_features = torch.cat(all_text_features)
        logger.info(f"Feature extraction complete. Image features shape: {val_image_features.shape}, Text features shape: {val_text_features.shape}")
    else:
        logger.error("Feature extraction failed or resulted in empty lists.")
else:
    logger.warning("Model or DataLoader not available. Skipping feature extraction.")

## Calculate and Report Retrieval Metrics

In [None]:
# Compute image<->text retrieval metrics over the whole validation set.
# The device copies of the features are passed directly rather than rebinding
# val_image_features / val_text_features, so the CPU tensors produced by the
# extraction cell are left untouched. That keeps this cell re-runnable and
# avoids pinning the features in GPU memory.
retrieval_results = None
if val_image_features is not None and val_text_features is not None and model is not None:
    logger.info("Calculating retrieval metrics...")
    try:
        # Use the trained (learned) logit scale. It goes straight to `device`
        # instead of via a CPU round-trip.
        with torch.no_grad():
            logit_scale = model.logit_scale.exp().to(device)

        retrieval_results = calculate_retrieval_metrics(
            image_features=val_image_features.to(device),
            text_features=val_text_features.to(device),
            logit_scale=logit_scale,
            k_values=k_values
        )

        print("\n--- Retrieval Metrics (Validation Set) ---")
        if retrieval_results:
            for key, value in retrieval_results.items():
                print(f"  {key}: {value:.4f}")
        else:
            print("  No retrieval results calculated.")
        print("------------------------------------------")
        # Suggest saving to Results.md
        print("\nConsider adding these results to Results.md")

    except Exception as e:
        logger.error(f"Error calculating retrieval metrics: {e}", exc_info=True)

else:
    logger.warning("Features or model not available. Skipping retrieval metric calculation.")

## Zero-Shot Classification (Placeholder)

In [None]:
# --- Zero-Shot Evaluation (Example/Placeholder) ---
# This requires benchmark data to be prepared in the format expected by
# load_benchmark_data and create_zeroshot_dataloader (from 12_evaluation_benchmarks.ipynb)
#
# NOTE(review): before uncommenting, also import BENCHMARK_DATA_PATH (used below
# but not imported in the imports cell; presumably from indic_clip.core, so verify),
# calculate_zeroshot_accuracy (commented out in the metrics import) and the
# benchmarks helpers (commented-out import in the imports cell).

# benchmark_name_zs = 'flickr30k_hi' # Or another benchmark name
# df_zs = load_benchmark_data(benchmark_name_zs)

# if df_zs is not None and model and tokenizer:
#     logger.info(f"Preparing Zero-Shot DataLoader for benchmark: {benchmark_name_zs}")
#     # Assuming benchmark images are relative to BENCHMARK_DATA_PATH / benchmark_name_zs
#     zs_dl = create_zeroshot_dataloader(
#         df=df_zs,
#         benchmark_base_path=(BENCHMARK_DATA_PATH / benchmark_name_zs),
#         # Ensure label_col matches the column with integer class indices in your benchmark CSV/JSONL
#         label_col='label_idx', # <<< Adjust if necessary
#         batch_size=batch_size,
#         num_workers=num_workers
#     )
#
#     if zs_dl:
#         logger.info(f"Extracting image features for Zero-Shot benchmark...")
#         all_zs_image_features = []
#         all_zs_labels = []
#         with torch.no_grad():
#             for batch in tqdm(zs_dl, desc=f"ZS Features ({benchmark_name_zs})"):
#                 img_batch, lbl_batch = batch
#                 img_batch = img_batch.to(device)
#                 try:
#                     img_feat = model.encode_image(img_batch)
#                     all_zs_image_features.append(img_feat.cpu())
#                     all_zs_labels.append(lbl_batch.cpu()) # Labels are already on CPU from DataLoader
#                 except Exception as e:
#                     logger.error(f"Error extracting ZS features for a batch: {e}", exc_info=True)
#
#         if all_zs_image_features:
#             zs_image_features = torch.cat(all_zs_image_features).to(device)
#             zs_image_labels = torch.cat(all_zs_labels).numpy()
#
#             # Define class names and templates relevant to the benchmark
#             # zs_class_names = # Load appropriate class names for the benchmark
#             # zs_templates = DEFAULT_PROMPT_TEMPLATES_HI # Or other relevant templates
#
#             # logger.info("Calculating Zero-Shot accuracy...")
#             # accuracy = calculate_zeroshot_accuracy(
#             #     image_features=zs_image_features,
#             #     image_labels=zs_image_labels,
#             #     class_names=zs_class_names,
#             #     templates=zs_templates,
#             #     model=model,
#             #     tokenizer=tokenizer
#             # )
#             # print(f"\n--- Zero-Shot Accuracy ({benchmark_name_zs}) ---")
#             # print(f"  Top-1 Accuracy: {accuracy:.4f}")
#             # print("-----------------------------------")
#             print("\nZero-Shot evaluation code is placeholder. Uncomment and adapt with actual class names and benchmark data.")
#         else:
#              logger.error(f"Failed to extract features for Zero-Shot benchmark {benchmark_name_zs}.")
#     else:
#         logger.warning(f"Could not create Zero-Shot DataLoader for {benchmark_name_zs}.")
# else:
#     logger.warning(f"Could not load benchmark data for {benchmark_name_zs} or model/tokenizer not available. Skipping Zero-Shot evaluation.")

# Final marker so a "Run All" log shows the notebook reached the end.
logger.info("Evaluation script finished.")