# Gradio Demo Application for Indic-CLIP

> Provides an interactive interface for testing the Indic-CLIP model using Gradio, deployable to Hugging Face Spaces.

## Setup and Imports

In [None]:
#| hide
# Environment bootstrap: locate the project checkout and make it importable.
# Intended mainly for Colab-like environments; the project library and its
# dependencies must already be installed.

# !pip install -q gradio torch torchvision torchaudio fastai transformers timm sentencepiece Pillow imagehash scikit-learn indic-nlp-library

import sys
from pathlib import Path
import os

if 'google.colab' not in sys.modules:
    # Local run: follow the usual nbdev layout, where notebooks live in
    # `<root>/nbs`; otherwise treat the working directory as the root.
    PROJECT_ROOT = Path.cwd().parent if Path.cwd().name == 'nbs' else Path.cwd()
    if str(PROJECT_ROOT) not in sys.path:
        sys.path.insert(0, str(PROJECT_ROOT))
    print(f"Running locally. Project root assumed: {PROJECT_ROOT}")
else:
    # Colab run: the project is expected on Google Drive (mounting is
    # optional; adjust PROJECT_ROOT if the project lives elsewhere).
    from google.colab import drive
    drive.mount('/content/drive')
    PROJECT_ROOT = Path('/content/drive/MyDrive/Indic-Clip') # Adjust path if needed
    # Make the project importable
    if str(PROJECT_ROOT) not in sys.path:
        sys.path.insert(0, str(PROJECT_ROOT))
    # Optionally switch the working directory so relative paths resolve
    # os.chdir(PROJECT_ROOT)
    print(f"Running in Colab. Project path added: {PROJECT_ROOT}")

# Safety net: guarantee PROJECT_ROOT exists before any later cell uses it
if 'PROJECT_ROOT' not in locals():
   PROJECT_ROOT = Path(".").resolve()
   print(f"Warning: PROJECT_ROOT not set by environment checks, defaulting to {PROJECT_ROOT}")

In [None]:
import logging
import os
import random
from pathlib import Path
from typing import List, Tuple, Dict

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

# Project specific imports
from indic_clip.inference import (
    load_indic_clip_model,
    extract_image_features,
    extract_text_features,
    compute_similarity
)
from indic_clip.core import (
    get_logger,
    setup_logging,
    CHECKPOINT_PATH,
    TOKENIZER_PATH,
    PRETRAINED_TOKENIZER_NAME, # Default text model name
    DEFAULT_EMBED_DIM, # Default embedding dim
    DEFAULT_IMAGE_SIZE,
    HINDI_RAW_PATH # For locating sample images
)
from indic_clip.data.tokenization import IndicBERTTokenizer
# Import IndicCLIP class directly for type hinting, even though we load via helper
from indic_clip.model.clip import IndicCLIP

# Setup logging: initialise logging through the project's `setup_logging`
# helper (imported from indic_clip.core), then obtain the logger that the
# rest of this notebook uses for status and error messages.
setup_logging()
logger = get_logger("indic_clip_app")

## Configuration and Model Loading

In [None]:
# --- Configuration ---
# Checkpoint served by the demo. Ideally this is the best performing model
# from the training runs (e.g. 'best_valid_loss.pth' or 'best_recall.pth').
CHECKPOINT_FILENAME = 'best_valid_loss.pth'
CHECKPOINT_FILE_PATH = CHECKPOINT_PATH / CHECKPOINT_FILENAME
TOKENIZER_DIR_PATH = TOKENIZER_PATH
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
IMAGE_SIZE = DEFAULT_IMAGE_SIZE  # keep consistent with the training image size
TOP_K = 5  # how many retrieval results to display

# --- !!! IMPORTANT: Model Configuration !!! ---
# These values MUST mirror the parameters the checkpoint was trained with.
# A checkpoint written by `learn.save` without `with_opt=False` may hold a
# nested state dict; `load_indic_clip_model` tries to cope with that, but the
# instantiation config still has to be right.
# Values below follow the example run in 10_training.ipynb.
MODEL_CONFIGURATION = dict(
    embed_dim=512,                              # from the training example
    vision_model_name='resnet50',               # from the training example
    vision_pretrained=False,                    # irrelevant when loading weights
    text_model_name=PRETRAINED_TOKENIZER_NAME,  # default, assumed from training
    text_pretrained=False,
    tokenizer=None,                             # placeholder, filled in once the tokenizer is loaded
)

# --- Load Tokenizer ---
# On any failure `tokenizer` is left as None so later cells can detect it.
try:
    tokenizer = IndicBERTTokenizer.load_tokenizer(TOKENIZER_DIR_PATH)
    MODEL_CONFIGURATION['tokenizer'] = tokenizer
    logger.info(f"Tokenizer loaded successfully from {TOKENIZER_DIR_PATH}")
except FileNotFoundError:
    logger.error(f"Tokenizer directory not found at {TOKENIZER_DIR_PATH}. Cannot start app.")
    tokenizer = None
except Exception as tokenizer_error:
    logger.error(f"Error loading tokenizer: {tokenizer_error}", exc_info=True)
    tokenizer = None

# --- Load Model ---
# The model is only built when a tokenizer is available; otherwise it stays None.
model: IndicCLIP = None
if tokenizer is not None:
    try:
        if CHECKPOINT_FILE_PATH.exists():
            model = load_indic_clip_model(
                checkpoint_path=CHECKPOINT_FILE_PATH,
                model_config=MODEL_CONFIGURATION,
                device=DEVICE
            )
            logger.info(f"Model loaded successfully from {CHECKPOINT_FILE_PATH} to device {DEVICE}.")
        else:
            logger.error(f"Checkpoint file not found: {CHECKPOINT_FILE_PATH}. Cannot load model.")
    except FileNotFoundError as missing_file_error:
        logger.error(f"Error: {missing_file_error}")
    except Exception as model_error:
        logger.error(f"Error loading model: {model_error}", exc_info=True)

if model is None:
    logger.critical("Model could not be loaded. The application cannot function.")
    # Optionally fail hard here if a missing model should stop the notebook:
    # raise RuntimeError("Failed to load IndicCLIP model.")

## Sample Gallery Data (for Demo)

In [None]:
# The demo ships with a small hard-coded gallery.
# A production application would read these from a file or a database.

# --- Text Gallery for Image-to-Text Retrieval ---
TEXT_GALLERY = [
    "एक लड़का फुटबॉल खेल रहा है।",
    "समुद्र तट पर सूर्यास्त।",
    "एक बिल्ली सोफे पर सो रही है।",
    "पारंपरिक साड़ी पहने एक महिला।",
    "एक मंदिर का प्रवेश द्वार।",
    "देवता गणेश की मूर्ति।",
    "एक व्यस्त भारतीय बाज़ार।",
    "लैपटॉप पर काम करता हुआ व्यक्ति।",
    "मेज़ पर रखी किताबों का ढेर।",
    "एक लाल रंग की स्पोर्ट्स कार।"
]

# --- Image Gallery for Text-to-Image Retrieval ---
# The sample images are expected under `<PROJECT_ROOT>/data/samples`;
# copy them there before running the demo.
SAMPLE_IMAGE_DIR = PROJECT_ROOT / 'data' / 'samples'
IMAGE_GALLERY_FILES = [
    SAMPLE_IMAGE_DIR / sample_name
    for sample_name in (
        'cat.jpg',
        'dog_park.jpg',
        'sunset_beach.jpg',
        'woman_saree.jpg',
        'temple.jpg',
        'ganesh.jpg',
        'market.jpg',
        'laptop.jpg',
        'books.jpg',
        'car.jpg',
    )
]

# Create the sample directory if it is missing
SAMPLE_IMAGE_DIR.mkdir(parents=True, exist_ok=True)

# Keep only the gallery images that are actually present on disk
valid_image_gallery_files = [candidate for candidate in IMAGE_GALLERY_FILES if candidate.exists()]
if len(valid_image_gallery_files) < len(IMAGE_GALLERY_FILES):
    logger.warning(f"Missing some sample images. Found {len(valid_image_gallery_files)} out of {len(IMAGE_GALLERY_FILES)} expected in {SAMPLE_IMAGE_DIR}")
    IMAGE_GALLERY_FILES = valid_image_gallery_files

# Gallery embeddings are computed once up front so each request only has to
# encode the query (optional, but it speeds the demo up).
text_gallery_features: torch.Tensor = None
image_gallery_features: torch.Tensor = None

if model is None or tokenizer is None:
    logger.warning("Model or tokenizer not loaded, cannot pre-encode gallery features.")
else:
    try:
        logger.info("Pre-encoding text gallery features...")
        text_gallery_features = extract_text_features(model, tokenizer, TEXT_GALLERY, device=DEVICE)
        logger.info(f"Encoded {len(TEXT_GALLERY)} text gallery items.")
    except Exception as text_encode_error:
        logger.error(f"Failed to pre-encode text gallery: {text_encode_error}", exc_info=True)

    if valid_image_gallery_files:
        try:
            logger.info("Pre-encoding image gallery features...")
            # The feature extractor is given the file paths directly
            image_gallery_features = extract_image_features(model, valid_image_gallery_files, img_size=IMAGE_SIZE, device=DEVICE)
            logger.info(f"Encoded {len(valid_image_gallery_files)} image gallery items.")
        except Exception as image_encode_error:
            logger.error(f"Failed to pre-encode image gallery: {image_encode_error}", exc_info=True)


## Gradio Interface Functions

In [None]:
def predict_text_from_image(image_input: Image.Image) -> str:
    """Image-to-Text retrieval callback for the Gradio UI.

    Encodes the uploaded image, scores it against the pre-encoded text
    gallery and returns the best matches as newline-separated
    "score: text" lines. Any problem is reported as an error string
    instead of being raised.
    """
    required_resources = (model, tokenizer, text_gallery_features)
    if any(resource is None for resource in required_resources):
        return "Error: Model, tokenizer, or text gallery features not loaded."
    if image_input is None:
        return "Error: Please upload an image."

    try:
        # Encode the query image (a PIL image is passed straight through)
        query_features = extract_image_features(model, image_input, img_size=IMAGE_SIZE, device=DEVICE)

        # Score the query against every gallery caption
        similarity_row = compute_similarity(model, query_features, text_gallery_features).squeeze(0)

        # Keep the best matches; never ask for more than the gallery holds
        num_results = min(TOP_K, len(TEXT_GALLERY))
        top_scores, top_indices = torch.topk(similarity_row, k=num_results, dim=-1)

        # One "score: caption" line per match
        formatted_lines = []
        for match_score, match_index in zip(top_scores, top_indices):
            formatted_lines.append(f"{match_score.item():.4f}: {TEXT_GALLERY[match_index.item()]}")
        return "\n".join(formatted_lines)

    except Exception as e:
        logger.error(f"Error in predict_text_from_image: {e}", exc_info=True)
        return f"An error occurred: {e}"

def predict_image_from_text(text_input: str) -> List[Tuple[str, str]]:
    """Gradio interface function for Text-to-Image retrieval.

    Encodes the query text, scores it against the pre-encoded image gallery
    and returns the best matching gallery images.

    Args:
        text_input: Query text typed by the user.

    Returns:
        A list of ``(image_path, caption)`` tuples for the top matches, where
        the caption is ``"<score>: <filename>"``. At most ``TOP_K`` entries
        are returned, fewer when the gallery is smaller.

    Raises:
        gr.Error: When the model/tokenizer/gallery features are unavailable,
            the query is empty, or retrieval fails. A Gallery output can only
            display real image files, so returning a placeholder path such as
            "error.png" (which is not created anywhere in this notebook) would
            hide the message; raising ``gr.Error`` lets Gradio show it.
    """
    if model is None or tokenizer is None or image_gallery_features is None or not valid_image_gallery_files:
        raise gr.Error("Error: Model, tokenizer, or image gallery features not loaded.")
    if not text_input or not text_input.strip():
        raise gr.Error("Error: Please enter text.")

    try:
        # 1. Encode the input text
        txt_feat = extract_text_features(model, tokenizer, text_input, device=DEVICE)

        # 2. Compute similarity with pre-encoded image gallery features
        #    (gallery first, query second: one score per gallery image)
        similarity = compute_similarity(model, image_gallery_features, txt_feat)

        # 3. Get top K results, capped by the gallery size
        num_results = min(TOP_K, len(valid_image_gallery_files))
        scores, indices = torch.topk(similarity.squeeze(-1), k=num_results, dim=0)

        # 4. Format results for Gradio Gallery (list of (image_path, caption))
        results = []
        for score, index in zip(scores.tolist(), indices.tolist()):
            img_path = valid_image_gallery_files[index]
            # Gradio needs string paths for local files
            results.append((str(img_path), f"{score:.4f}: {img_path.name}"))
        return results

    except Exception as e:
        logger.error(f"Error in predict_image_from_text: {e}", exc_info=True)
        raise gr.Error(f"An error occurred: {e}") from e

# Prompt template used when the project-level DEFAULT_PROMPT_TEMPLATES_EN has
# not been imported into the notebook namespace.
_FALLBACK_PROMPT_TEMPLATES = ["a photo of a {}."]


def _zero_shot_error(message: str) -> Dict[str, float]:
    """Wrap an error message as a single-entry {label: confidence} dict for gr.Label."""
    return {f"Error: {message}": 1.0}


def predict_zero_shot(image_input: Image.Image, candidate_labels_text: str) -> Dict[str, float]:
    """Gradio interface function for Zero-Shot Classification.

    Builds one text prompt per (template, label) pair, averages the prompt
    embeddings per label, and turns the image-to-label similarities into
    probabilities with a softmax.

    Args:
        image_input: Image uploaded by the user.
        candidate_labels_text: Comma-separated candidate labels.

    Returns:
        A mapping from label to probability. On failure a single-entry
        mapping ``{"Error: <message>": 1.0}`` is returned so that every value
        stays a float, as the ``gr.Label`` output requires.
    """
    if model is None or tokenizer is None:
        return _zero_shot_error("Model or tokenizer not loaded.")
    if image_input is None:
        return _zero_shot_error("Please upload an image.")
    if not candidate_labels_text or not candidate_labels_text.strip():
        return _zero_shot_error("Please enter candidate labels.")

    try:
        # Parse labels; drop blanks and duplicates (order preserved) so each
        # label maps to exactly one probability in the result dict.
        class_names = list(dict.fromkeys(
            label.strip() for label in candidate_labels_text.split(',') if label.strip()
        ))
        if not class_names:
            return _zero_shot_error("Invalid label format. Enter comma-separated labels.")

        # Use default English templates for simplicity in this demo.
        # A more advanced demo could detect language or allow template selection.
        # The project constant is used when it is in scope; otherwise fall back
        # to a single generic template instead of failing with a NameError.
        templates = list(globals().get("DEFAULT_PROMPT_TEMPLATES_EN") or _FALLBACK_PROMPT_TEMPLATES)

        # Encode image
        img_feat = extract_image_features(model, image_input, img_size=IMAGE_SIZE, device=DEVICE)

        # Generate prompts in template-major order: all classes for template 0,
        # then all classes for template 1, and so on. The reshape below relies on it.
        all_prompts = [template.format(classname) for template in templates for classname in class_names]

        text_embeddings = extract_text_features(model, tokenizer, all_prompts, device=DEVICE)

        # Average embeddings if multiple templates were used
        if len(templates) > 1:
            text_embeddings = text_embeddings.reshape(len(templates), len(class_names), -1).mean(dim=0)

        # Re-normalize: averaging unit vectors does not give a unit vector
        text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)

        # Flatten to one score per class. reshape(-1) keeps a 1-D tensor even
        # for a single label, where squeeze() would produce a 0-dim tensor
        # that cannot be indexed.
        # NOTE(review): assumes compute_similarity returns one row of
        # per-class scores for a single image - confirm against indic_clip.inference.
        similarity = compute_similarity(model, img_feat, text_embeddings).reshape(-1)

        # Apply softmax to get probabilities
        probs = F.softmax(similarity, dim=-1)

        # Format results for Gradio Label output
        return dict(zip(class_names, probs.tolist()))

    except Exception as e:
        logger.error(f"Error in predict_zero_shot: {e}", exc_info=True)
        return _zero_shot_error(f"An error occurred: {e}")

## Gradio Interface Definition

In [None]:
# Custom stylesheet handed to gr.Blocks below.
# The dark-mode slider accent previously used "#dfdqdq", which is not a valid
# hex colour ("q" is not a hex digit); it now uses a valid light grey.
css = """
.gradio-container { font-family: 'IBM Plex Sans', sans-serif; }
.gr-button { color: white; border-color: black; background: black; }
input[type='range'] { accent-color: black; }
.dark input[type='range'] { accent-color: #dfdfdf; }
.container { max-width: 1100px; margin: auto; padding-top: 1.5rem; }
#gallery { min-height: 22rem; margin-bottom: 15px; margin-left: auto; margin-right: auto; }
#gallery>div>.h-full { min-height: 20rem; }
.details:hover { text-decoration: underline; }
.feedback { font-size: 0.8rem; margin-bottom: 5px; }
.feedback textarea { font-size: 0.8rem; }
.feedback button { margin: 0; }
.gradio-container { max-width: 1140px !important; }
"""

# Build the three-tab demo UI. Handlers are wired to the buttons at the end of
# the `with block:` context; the app itself is launched in a later cell.
block = gr.Blocks(css=css, theme=gr.themes.Default())

# Example inputs are derived from the images that actually exist on disk.
# IMAGE_GALLERY_FILES may have been filtered down (possibly to nothing), so
# samples are looked up by filename rather than by fixed list position. That
# avoids an IndexError and keeps each label set paired with its own image.
image_examples = [str(sample_path) for sample_path in IMAGE_GALLERY_FILES[:3]]
text_examples = TEXT_GALLERY[:3]
samples_by_name = {sample_path.name: sample_path for sample_path in IMAGE_GALLERY_FILES}
zero_shot_examples = [
    [str(samples_by_name[file_name]), labels]
    for file_name, labels in (
        ("cat.jpg", "बिल्ली, कुत्ता, पक्षी"),        # Cat example
        ("woman_saree.jpg", "साड़ी, कुर्ता, पोशाक"),  # Saree example
        ("temple.jpg", "मंदिर, मस्जिद, चर्च"),        # Temple example
    )
    if file_name in samples_by_name
]

with block:
    gr.Markdown(
        """
        <div style="text-align: center; max-width: 1000px; margin: 20px auto;">
        <h1 style="font-weight: 900; font-size: 3rem;">
            Indic-CLIP
        </h1>
        <p style="margin-bottom: 10px; font-size: 94%">
            Multimodal Vision-Language Model for Indic Languages (Hindi/Sanskrit)
         </p>
         <p>Provide an image or text to retrieve corresponding matches, or perform zero-shot classification.</p>
         <p>Note: This demo uses a small, fixed gallery for retrieval. Model trained on Flickr8k-Hindi (example).</p>
       </div>
        """
    )
    with gr.Tabs():
        with gr.TabItem("🖼️ Image-to-Text Retrieval"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type="pil", label="Input Image")
                    submit_i2t = gr.Button("Retrieve Text", variant="primary")
                with gr.Column():
                    output_text = gr.Textbox(lines=TOP_K, label=f"Top {TOP_K} Text Matches (Score: Text)")
            # Only offer sample images when at least one exists on disk
            if image_examples:
                gr.Examples(
                    examples=image_examples,
                    inputs=input_image,
                    label="Sample Images"
                )

        with gr.TabItem("📝 Text-to-Image Retrieval"):
            with gr.Row():
                with gr.Column():
                    input_text = gr.Textbox(label="Input Text (Hindi/Sanskrit)")
                    submit_t2i = gr.Button("Retrieve Images", variant="primary")
                with gr.Column():
                    # Gallery output expects list of (image_path, caption) tuples
                    # NOTE(review): `.style()` is deprecated in later Gradio 3.x
                    # releases and removed in 4.x - confirm the Gradio version
                    # used for deployment.
                    output_gallery = gr.Gallery(label=f"Top {TOP_K} Image Matches (Score: Filename)", show_label=True).style(columns=TOP_K, height="auto", object_fit="contain")
            if text_examples:
                gr.Examples(
                    examples=text_examples,
                    inputs=input_text,
                    label="Sample Texts"
                )

        with gr.TabItem("🏷️ Zero-Shot Classification"):
            with gr.Row():
                with gr.Column():
                    input_image_zs = gr.Image(type="pil", label="Input Image")
                with gr.Column():
                    candidate_labels = gr.Textbox(label="Candidate Labels (Comma-separated)", placeholder="e.g., बिल्ली, कुत्ता, पक्षी, कार")
                    submit_zs = gr.Button("Classify Image", variant="primary")
                    output_labels = gr.Label(num_top_classes=3, label="Classification Results")
            if zero_shot_examples:
                gr.Examples(
                    examples=zero_shot_examples,
                    inputs=[input_image_zs, candidate_labels],
                    outputs=output_labels,
                    label="Sample Images and Labels"
                )

    # Define button click actions
    submit_i2t.click(
        predict_text_from_image,
        inputs=[input_image],
        outputs=[output_text]
    )
    submit_t2i.click(
        predict_image_from_text,
        inputs=[input_text],
        outputs=[output_gallery]
    )
    submit_zs.click(
        predict_zero_shot,
        inputs=[input_image_zs, candidate_labels],
        outputs=[output_labels]
    )

    # Launch the interface
    # block.launch(debug=True) # Use debug=True for testing

## Launch App

In [None]:
#| hide
# Keep this as the final executed cell (or place the equivalent call in
# app.py when deploying to Spaces). Nothing is launched inside Colab or when
# this code is imported rather than run as the main program.
if __name__ == '__main__' and 'google.colab' not in sys.modules:
    print("Launching Gradio interface...")
    # Refuse to start the server without a loaded model
    if model is None:
        print("ERROR: Model not loaded. Cannot launch Gradio app.")
    else:
        block.launch(share=False) # share=True to create public link