# Inference API

> Provides high-level functions for loading a trained Indic-CLIP model and performing inference (feature extraction, similarity computation).

In [None]:
#| default_exp inference

## Colab Setup

In [None]:
#| hide
# Environment bootstrap: inside Colab, mount Google Drive and switch into the
# project folder; anywhere else, derive the project root from the cwd.
# Either way the project root ends up importable via sys.path.
from pathlib import Path
import sys
import os

try:
    from google.colab import drive
    drive.mount('/content/drive')
    print("Google Drive mounted successfully.")
    # Project location on Drive -- adjust if your copy lives elsewhere.
    PROJECT_ROOT = Path('/content/drive/MyDrive/Indic-Clip')
    if PROJECT_ROOT.exists():
        # Make project packages importable (only once).
        if str(PROJECT_ROOT) not in sys.path:
            sys.path.insert(0, str(PROJECT_ROOT))
            print(f"Added {PROJECT_ROOT} to sys.path")
        # Relative paths used later in the notebook resolve against the project.
        os.chdir(PROJECT_ROOT)
        print(f"Changed working directory to: {os.getcwd()}")
    else:
        print(f"Warning: Project directory not found at {PROJECT_ROOT}. Please ensure it exists.")

except ModuleNotFoundError:
    # google.colab is unavailable -> we are running locally.
    print("Not running in Colab, skipping Drive mount.")
    PROJECT_ROOT = Path.cwd()
    # Notebooks usually live in <root>/nbs, so step up one level from there.
    if PROJECT_ROOT.name == 'nbs':
        PROJECT_ROOT = PROJECT_ROOT.parent
    print(f"Running locally. Project root assumed: {PROJECT_ROOT}")
    if str(PROJECT_ROOT) not in sys.path:
        sys.path.insert(0, str(PROJECT_ROOT))
        print(f"Added {PROJECT_ROOT} to sys.path")
except Exception as e:
    # Any other failure (e.g. the Drive mount itself): fall back to the cwd.
    print(f"An error occurred during Colab setup: {e}")
    PROJECT_ROOT = Path('.').resolve()
    print(f"Defaulting project root to current dir: {PROJECT_ROOT}")

In [None]:
#| hide
# Install requirements if needed (especially in Colab)
# !pip install -qr requirements.txt

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import torch
import logging
from pathlib import Path
from PIL import Image
from typing import List, Union, Dict, Optional

from fastcore.all import *
from fastai.vision.all import (
    PILImage, Resize, ToTensor, Normalize, imagenet_stats, Learner, DataLoaders, Datasets, TfmdLists
)

# --- Project Imports ---
try:
    from indic_clip.core import (
        get_logger, setup_logging, DEFAULT_IMAGE_SIZE, DEFAULT_EMBED_DIM, PRETRAINED_TOKENIZER_NAME, TOKENIZER_PATH,
        CHECKPOINT_PATH
    )
    from indic_clip.model.clip import IndicCLIP
    from indic_clip.data.tokenization import IndicBERTTokenizer
    # Import other necessary components if needed, e.g., specific DataLoaders
except ModuleNotFoundError:
    # Fallbacks keep this module importable without the indic_clip package
    # (e.g. for docs or smoke tests). They produce random, meaningless features.
    print("Could not import project modules in 13_inference.ipynb. Using fallbacks.")
    logging.basicConfig(level=logging.INFO)

    def get_logger(name): return logging.getLogger(name)
    def setup_logging(): pass
    DEFAULT_IMAGE_SIZE = 224
    DEFAULT_EMBED_DIM = 768
    PRETRAINED_TOKENIZER_NAME = "ai4bharat/indic-bert"
    TOKENIZER_PATH = Path('./models/tokenizer')
    CHECKPOINT_PATH = Path('./models/checkpoints')

    class IndicCLIP(torch.nn.Module):
        """Stand-in model that returns random unit-norm embeddings of size `embed_dim`."""

        def __init__(self, *args, embed_dim: int = DEFAULT_EMBED_DIM, **kwargs):
            super().__init__()
            self.embed_dim = embed_dim
            # Registered in __init__ (not as a class attribute) so it is a real
            # parameter: `next(model.parameters())`, `.to()` and `state_dict()` see it.
            self.logit_scale = torch.nn.Parameter(torch.tensor(1.0))
            # The inference helpers read `.out_features` from these when nothing was extracted.
            self.visual_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
            self.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)

        def _random_features(self, batch_size, device):
            """Return `batch_size` random L2-normalised vectors on `device`."""
            feats = torch.randn(batch_size, self.embed_dim, device=device)
            return feats / feats.norm(dim=-1, keepdim=True)

        def encode_image(self, x):
            return self._random_features(x.shape[0], x.device)

        def encode_text(self, input_ids, attention_mask=None, **kwargs):
            # Accepts input_ids positionally or by keyword (callers use keywords).
            return self._random_features(input_ids.shape[0], input_ids.device)

    class IndicBERTTokenizer:
        """Stand-in tokenizer producing random ids with a fixed sequence length of 32."""

        @classmethod
        def load_tokenizer(cls, *args, **kwargs): return cls()

        def tokenize(self, texts):
            if isinstance(texts, str): texts = [texts]
            bs = len(texts)
            seq_len = 32
            return {'input_ids': torch.randint(0, 1000, (bs, seq_len)),
                    'attention_mask': torch.ones((bs, seq_len), dtype=torch.long)}

logger = get_logger(__name__)

## Core Inference Functions

In [None]:
#| export
def load_indic_clip_model(
    checkpoint_path: Union[str, Path],
    model_config: Optional[Dict] = None, # Optional: pass model config if not saved with ckpt
    device: Optional[Union[str, torch.device]] = None
) -> IndicCLIP:
    """
    Loads a trained IndicCLIP model from a checkpoint file.

    Args:
        checkpoint_path (Union[str, Path]): Path to the .pth checkpoint file.
        model_config (Optional[Dict]): Keyword arguments for the IndicCLIP constructor
                                      (embed_dim, vision_model_name, text_model_name, etc.)
                                      needed to instantiate the model structure. The dict
                                      is copied, never mutated. If None, hard-coded defaults
                                      are used, which may not match the checkpoint. If it
                                      has no 'tokenizer' entry, one is loaded from TOKENIZER_PATH.
        device (Optional[Union[str, torch.device]]): Device to load the model onto ('cuda', 'cpu', etc.).
                                                      Defaults to CUDA if available, else CPU.

    Returns:
        IndicCLIP: The loaded model instance in evaluation mode, on `device`.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        ValueError: If a tokenizer is required but cannot be loaded from TOKENIZER_PATH.
        TypeError: If the model cannot be instantiated from the config.
        RuntimeError: If there's an issue loading the state dict.
    """
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.is_file():
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Loading model onto device: {device}")

    # --- Instantiate Model Structure ---
    config = _resolve_model_config(model_config)
    try:
        model = IndicCLIP(**config)
        logger.info(f"Instantiated IndicCLIP model structure with config: {config}")
    except Exception as e:
        logger.error(f"Failed to instantiate model with provided config: {e}", exc_info=True)
        raise TypeError(f"Could not instantiate IndicCLIP model. Check model_config. Error: {e}") from e

    # --- Load State Dict ---
    try:
        model.load_state_dict(_read_state_dict(checkpoint_path))
        logger.info(f"Successfully loaded model weights from {checkpoint_path}")
    except Exception as e:
        logger.error(f"Error loading state dict from {checkpoint_path}: {e}", exc_info=True)
        raise RuntimeError(f"Failed to load state dict. Ensure checkpoint is valid and config matches. Error: {e}") from e

    model.eval() # Set model to evaluation mode
    model.to(device) # Move model to the target device
    logger.info("Model set to evaluation mode.")

    return model


def _resolve_model_config(model_config: Optional[Dict]) -> Dict:
    """Return a copy of `model_config` (or default settings) that is guaranteed to contain a tokenizer."""
    if model_config is None:
        # Best practice: save the config with the checkpoint. These defaults may not match it.
        logger.warning("Model config not provided. Attempting to load using default parameters. This may fail or lead to unexpected behavior if the checkpoint was saved with different settings.")
        config = {
            'embed_dim': DEFAULT_EMBED_DIM,
            'vision_model_name': 'vit_base_patch16_224', # Make sure this matches your training!
            'vision_pretrained': False, # Pretrained flag doesn't matter much for loading weights
            'text_model_name': PRETRAINED_TOKENIZER_NAME, # Make sure this matches!
            'text_pretrained': False,
            'tokenizer': None, # Filled in below from TOKENIZER_PATH
        }
    else:
        # Shallow copy: injecting a tokenizer must not mutate the caller's dict.
        config = dict(model_config)
        if config.get('tokenizer') is None:
            logger.warning("Loading model without providing a tokenizer instance. Trying to load default.")

    if config.get('tokenizer') is None:
        try:
            tokenizer = IndicBERTTokenizer.load_tokenizer(TOKENIZER_PATH)
        except Exception as e:
            raise ValueError(f"Could not load default tokenizer required for model instantiation from {TOKENIZER_PATH}. Please provide it in model_config. Error: {e}") from e
        if tokenizer is None:
            raise ValueError(f"Could not load default tokenizer required for model instantiation from {TOKENIZER_PATH}. Please provide it in model_config.")
        config['tokenizer'] = tokenizer
    return config


def _read_state_dict(checkpoint_path: Path) -> Dict:
    """Load a checkpoint onto the CPU and unwrap it into a plain state dict without a DDP prefix."""
    # map_location='cpu' lets GPU-saved checkpoints load on any machine.
    # SECURITY NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
    state_dict = torch.load(checkpoint_path, map_location='cpu')

    # Some savers (e.g. Learner.save) wrap the weights under a 'model' or 'state_dict' key.
    if 'model' in state_dict:
        state_dict = state_dict['model']
    elif 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']

    # Strip the DistributedDataParallel 'module.' prefix only at the START of a key;
    # a plain str.replace would also corrupt names like 'encoder.submodule.weight'.
    prefix = 'module.'
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}

In [None]:
#| export
def _get_image_transform(img_size: int = DEFAULT_IMAGE_SIZE):
    """
    Creates the image preprocessing callable used at inference time.

    The returned function takes a PIL image and returns a float32 tensor of
    shape (3, H, W). It converts the image to RGB and resizes it to `img_size`
    by squishing, so the aspect ratio is not preserved. It then scales pixel
    values to [0, 1] and normalises them with `imagenet_stats`. The tensor stays
    on the CPU; callers move the stacked batch to the model's device.

    Args:
        img_size: Target size: an int for a square output, or a (height, width) pair.

    Returns:
        Callable[[PIL.Image.Image], torch.Tensor]: The preprocessing function.
    """
    import numpy as np  # local import: only needed for the PIL -> tensor conversion

    if isinstance(img_size, int):
        height = width = img_size
    else:
        height, width = img_size

    # Per-channel stats shaped (3, 1, 1) so they broadcast over an unbatched (3, H, W) image.
    mean = torch.tensor(imagenet_stats[0], dtype=torch.float32).view(-1, 1, 1)
    std = torch.tensor(imagenet_stats[1], dtype=torch.float32).view(-1, 1, 1)
    # Pillow >= 9.1 exposes resampling filters under Image.Resampling; older versions on Image.
    bilinear = getattr(Image, 'Resampling', Image).BILINEAR

    def _transform(img: Image.Image) -> torch.Tensor:
        # PIL's resize takes (width, height).
        img = img.convert('RGB').resize((width, height), resample=bilinear)
        pixels = np.asarray(img, dtype=np.float32) / 255.0  # (H, W, 3) in [0, 1]
        tensor = torch.from_numpy(pixels).permute(2, 0, 1).contiguous()  # -> (3, H, W)
        return (tensor - mean) / std

    return _transform

In [None]:
#| export
@torch.no_grad()
def extract_image_features(
    model: IndicCLIP,
    image_input: Union[str, Path, Image.Image, torch.Tensor, List[Union[str, Path, Image.Image, torch.Tensor]]],
    img_size: int = DEFAULT_IMAGE_SIZE,
    batch_size: int = 32,
    device: Optional[Union[str, torch.device]] = None
) -> torch.Tensor:
    """
    Extracts L2-normalized features for one or more images using the IndicCLIP model.

    Images that fail to load, have an unsupported type, or fail to transform are
    logged and skipped. The result can therefore have fewer rows than there are inputs.

    Args:
        model (IndicCLIP): The loaded IndicCLIP model instance.
        image_input: A single image or a list/tuple of images. Each image can be:
            - a path (str or Path);
            - a PIL Image;
            - an already-preprocessed tensor of shape (3, H, W) or (1, 3, H, W).
        img_size (int): The image size the model expects.
        batch_size (int): Batch size for processing multiple images.
        device (Optional[Union[str, torch.device]]): Device to use for inference. If None, uses model's device;
                                                      otherwise the model is moved there.

    Returns:
        torch.Tensor: A CPU tensor of unit-norm image features (N, embed_dim).
    """
    model.eval()
    if device is None:
        device = next(model.parameters()).device
    else:
        model.to(device)

    transform = _get_image_transform(img_size)

    if not isinstance(image_input, (list, tuple)):
        image_input = [image_input]

    all_features = []
    for start in range(0, len(image_input), batch_size):
        processed_batch = [
            t for t in (_preprocess_image(item, transform) for item in image_input[start:start + batch_size])
            if t is not None
        ]
        if not processed_batch:
            continue # Skip if batch is empty after errors

        # Move each item to the target device before stacking, so user-supplied
        # tensors and freshly transformed (CPU) images never mix devices.
        batch_tensor = torch.stack([t.to(device) for t in processed_batch])

        features = model.encode_image(batch_tensor)
        # Enforce the unit-norm contract relied upon by compute_similarity
        # (a no-op, up to rounding, if encode_image already normalizes).
        features = torch.nn.functional.normalize(features, dim=-1)
        all_features.append(features.cpu()) # Move features to CPU to accumulate

    if not all_features:
        logger.error("No image features could be extracted.")
        # Return an empty tensor with the correct embedding dimension.
        # NOTE(review): assumes the model exposes `visual_projection` as an nn.Linear.
        embed_dim = model.visual_projection.out_features
        return torch.empty((0, embed_dim))

    return torch.cat(all_features)


def _preprocess_image(item, transform) -> Optional[torch.Tensor]:
    """Turn one image input into a (3, H, W) tensor, or return None (after logging) if it cannot be used."""
    if isinstance(item, torch.Tensor):
        # Assumed to be preprocessed already to match model expectations; transforms are skipped.
        # A leading batch dimension of 1 is dropped so the item stacks with the others.
        return item.squeeze(0) if item.dim() == 4 and item.shape[0] == 1 else item
    if isinstance(item, (str, Path)):
        try:
            img = PILImage.create(item)
        except Exception as e:
            logger.error(f"Failed to load image: {item}. Error: {e}")
            return None
    elif isinstance(item, Image.Image):
        img = item
    else:
        logger.warning(f"Unsupported image input type: {type(item)}. Skipping.")
        return None

    try:
        return transform(img)
    except Exception as e:
        logger.error(f"Failed to transform image: {item}. Error: {e}")
        return None

In [None]:
#| export
@torch.no_grad()
def extract_text_features(
    model: IndicCLIP,
    tokenizer: IndicBERTTokenizer,
    text_input: Union[str, List[str]],
    batch_size: int = 32,
    device: Optional[Union[str, torch.device]] = None
) -> torch.Tensor:
    """
    Extracts L2-normalized features for one or more text strings using the IndicCLIP model.

    A batch that fails to tokenize or encode is logged (with traceback) and skipped.
    The result can therefore have fewer rows than `text_input` has strings.

    Args:
        model (IndicCLIP): The loaded IndicCLIP model instance.
        tokenizer (IndicBERTTokenizer): The tokenizer instance compatible with the model.
        text_input (Union[str, List[str]]): A single text string, or a list/tuple of strings.
        batch_size (int): Batch size for processing multiple texts.
        device (Optional[Union[str, torch.device]]): Device to use for inference. If None, uses model's device;
                                                      otherwise the model is moved there.

    Returns:
        torch.Tensor: A CPU tensor of unit-norm text features (N, embed_dim).
    """
    model.eval()
    if device is None:
        device = next(model.parameters()).device
    else:
        model.to(device)

    texts = [text_input] if isinstance(text_input, str) else list(text_input)

    all_features = []
    for start in range(0, len(texts), batch_size):
        batch_texts = texts[start:start + batch_size]
        try:
            inputs = tokenizer.tokenize(batch_texts)
            input_ids = inputs['input_ids'].to(device)
            attention_mask = inputs['attention_mask'].to(device)
            features = model.encode_text(input_ids=input_ids, attention_mask=attention_mask)
        except Exception as e:
            logger.error(f"Error processing text batch starting with '{batch_texts[0][:50]}...': {e}", exc_info=True)
            continue # Skip batch on error

        # Enforce the unit-norm contract relied upon by compute_similarity
        # (a no-op, up to rounding, if encode_text already normalizes).
        features = torch.nn.functional.normalize(features, dim=-1)
        all_features.append(features.cpu()) # Move features to CPU

    if not all_features:
        logger.error("No text features could be extracted.")
        # NOTE(review): assumes the model exposes `text_projection` as an nn.Linear.
        embed_dim = model.text_projection.out_features
        return torch.empty((0, embed_dim))

    return torch.cat(all_features)

In [None]:
#| export
@torch.no_grad()
def compute_similarity(
    model: IndicCLIP,
    image_features: torch.Tensor,
    text_features: torch.Tensor
) -> torch.Tensor:
    """
    Scores every image against every text using the model's learned temperature.

    Computes ``exp(model.logit_scale) * image_features @ text_features.T`` on the
    device that holds the model's parameters.

    Args:
        model (IndicCLIP): The loaded IndicCLIP model instance; only its `logit_scale`
            and its device are used.
        image_features (torch.Tensor): Image features (N, embed_dim), expected to be L2-normalized.
        text_features (torch.Tensor): Text features (M, embed_dim), expected to be L2-normalized.

    Returns:
        torch.Tensor: An (N, M) matrix of scaled dot products, on the model's device. These
            are scaled cosine similarities when the inputs are unit-norm.
    """
    model.eval()
    target = next(model.parameters()).device

    imgs = image_features.to(target)
    txts = text_features.to(target)
    scale = model.logit_scale.exp()  # stored in log space

    return (scale * imgs) @ txts.t()

## Example Usage

In [None]:
#| eval: false 
# --- Example: Load Model and Run Inference ---
if __name__ == '__main__':
    print("--- Running Inference API Example ---")

    # --- Configuration ---
    # !! IMPORTANT: Replace with the *actual* path to your trained checkpoint !!
    # checkpoint_name = 'best_recall_interactive.pth' # From 10_training example
    checkpoint_name = 'epoch_0.pth' # Default name if SaveModelCallback wasn't configured fully
    checkpoint_file = CHECKPOINT_PATH / checkpoint_name

    # !! IMPORTANT: Provide the config used during training if it wasn't saved with the checkpoint !!
    #             (If Learner.save was used, config is often not saved)
    # Example: Use the smaller ResNet18 config from the 10_training notebook test
    model_configuration = {
        'embed_dim': 512, # Example: ResNet18 from test
        'vision_model_name': 'resnet18', # Example: ResNet18 from test
        'vision_pretrained': False, # Doesn't matter for loading state_dict
        'text_model_name': PRETRAINED_TOKENIZER_NAME,
        'text_pretrained': False,
        'tokenizer': IndicBERTTokenizer.load_tokenizer(TOKENIZER_PATH) # Load the tokenizer used for training
    }

    # Create dummy checkpoint if it doesn't exist for demonstration
    if not checkpoint_file.exists():
        print(f"Warning: Checkpoint {checkpoint_file} not found. Creating a dummy model and saving it.")
        checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
        dummy_model = IndicCLIP(**model_configuration)
        torch.save(dummy_model.state_dict(), checkpoint_file)
        print("Dummy checkpoint created.")

    # --- Load Model and Tokenizer ---
    try:
        loaded_model = load_indic_clip_model(checkpoint_file, model_config=model_configuration)
        # Reuse the tokenizer already loaded for the config, so text is encoded
        # exactly as the model was built to expect.
        loaded_tokenizer = model_configuration['tokenizer']
        print("Model and tokenizer loaded successfully.")

        # --- Prepare Sample Data ---
        # Use a real image path if available, otherwise create dummy
        sample_image_path = PROJECT_ROOT / 'sample_image.jpg' # Create this image or use a real one
        if not sample_image_path.exists():
            print(f"Creating dummy image at: {sample_image_path}")
            dummy_pil_img = Image.new('RGB', (250, 300), color='cyan')
            dummy_pil_img.save(sample_image_path)

        sample_texts = [
            "एक बिल्ली सोफे पर सो रही है।", # A cat sleeping on a sofa
            "समुद्र तट पर सूर्यास्त।",    # Sunset on a beach
            "साड़ी पहने एक महिला।"       # A woman wearing a saree
        ]

        # --- Extract Features ---
        print("\nExtracting image features...")
        image_features = extract_image_features(loaded_model, sample_image_path)
        print(f"Image features shape: {image_features.shape}")

        print("\nExtracting text features...")
        text_features = extract_text_features(loaded_model, loaded_tokenizer, sample_texts)
        print(f"Text features shape: {text_features.shape}")

        # --- Compute Similarity ---
        print("\nComputing similarity...")
        similarity_matrix = compute_similarity(loaded_model, image_features, text_features)
        print(f"Similarity matrix shape: {similarity_matrix.shape}")
        print("Similarity scores:")
        print(similarity_matrix.cpu().numpy())

        # Find best matching text for the image: argmax within row 0 (the sample image),
        # not over the flattened matrix, which would be wrong for more than one image.
        if similarity_matrix.numel() == 0:
            print("\nNo similarity scores available (feature extraction failed).")
        else:
            best_match_idx = similarity_matrix[0].argmax().item()
            print(f"\nBest matching text for the image: '{sample_texts[best_match_idx]}' (Score: {similarity_matrix[0, best_match_idx]:.4f})")

    except FileNotFoundError as e:
        print(f"\nError: {e}")
        print("Please ensure the correct checkpoint and tokenizer paths are set.")
    except Exception as e:
        print(f"\nAn unexpected error occurred: {e}")
        import traceback
        traceback.print_exc()

    # Optional: Clean up dummy image
    # if 'dummy_pil_img' in locals() and sample_image_path.exists():
    #    os.remove(sample_image_path)
    #    print(f"Cleaned up dummy image: {sample_image_path}")

    print("\n--- Inference API Example Finished ---")

In [None]:
#| hide
import nbdev
# Running this cell writes the `#| export` cells of this notebook into the Python
# package; the `nbdev_export` command does the same from a terminal.
nbdev.nbdev_export()