In [1]:
# LightVision: Long-Text-Aware Offline Lightweight Text-to-Image Retrieval
# Refactored and Organized Notebook

# ============================================================================
# CELL 1: Project Setup and Environment Check
# ============================================================================

print("="*60)
print("LIGHTVISION: MOBILE-FRIENDLY IMAGE RETRIEVAL SYSTEM")
print("="*60)
print("Setting up environment and checking dependencies...")

import sys
import os
import torch

# Check CUDA availability
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"CUDA Version: {torch.version.cuda}")

# Project directories
# Both directories are resolved relative to the current working directory,
# so the notebook must be started from the repository root for them to line up.
BASE_DATA_DESTINATION = os.path.join(os.getcwd(), "data")
FLICKR8K_IMAGES_FOLDER_NAME = "Images"
CAPTIONS_JSON_FILENAME = "all_captions.json"
CHECKPOINT_DIR = os.path.join(os.getcwd(), "checkpoints")
# Upper-case alias of `device`; this is the name the Config class reads.
DEVICE = device

print(f"Base data directory: {BASE_DATA_DESTINATION}")
print(f"Checkpoint directory: {CHECKPOINT_DIR}")
print("Environment setup complete!")

LIGHTVISION: MOBILE-FRIENDLY IMAGE RETRIEVAL SYSTEM
Setting up environment and checking dependencies...
Device: cpu
Base data directory: /Users/erenyavuz/Desktop/KU/25 Spring/COMP447/Project/Repo/FlightVision/data
Checkpoint directory: /Users/erenyavuz/Desktop/KU/25 Spring/COMP447/Project/Repo/FlightVision/checkpoints
Environment setup complete!


In [2]:
# ============================================================================
# CELL 2: Install and Import Dependencies
# ============================================================================

# Install required packages (uncomment if needed)
# !pip install torch torchvision timm open-clip-torch datasets clip-benchmark

# Core imports
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
from tqdm import tqdm
import json
import random
import time

# Add LightVision to path
# NOTE(review): this is an absolute, Colab-style path, while the recorded output of
# the setup cell shows a local working directory under /Users/... — on that machine
# this entry presumably points at nothing. Confirm whether it is still needed or
# should be derived from the repository location instead.
sys.path.append('/content/LightVision')

# MobileCLIP imports
import mobileclip

print("All dependencies imported successfully!")



All dependencies imported successfully!


In [None]:
# ============================================================================
# CELL 3: Configuration and Parameters
# ============================================================================

class Config:
    """Centralized configuration class.

    Every setting is a plain class attribute, so values can be read either from
    the class or from an instance. The attributes that mirror module-level names
    (DEVICE, BASE_DATA_DESTINATION, CHECKPOINT_DIR, ...) are captured once, when
    this class body is executed; re-run this cell after changing them.
    """

    # Model configuration
    MODEL_NAME = 'mobileclip_s0'
    DEVICE = DEVICE
    EMBEDDING_DIM = 512

    # Training configuration
    BATCH_SIZE = 128
    LEARNING_RATE = 1e-4
    NUM_EPOCHS = 1
    NUM_WORKERS = 0

    # Data configuration
    BASE_DATA_DIR = BASE_DATA_DESTINATION
    IMAGES_DIR = os.path.join(BASE_DATA_DIR, FLICKR8K_IMAGES_FOLDER_NAME)
    CAPTIONS_FILE = os.path.join(BASE_DATA_DIR, CAPTIONS_JSON_FILENAME)
    NEW_CAPTIONS_FILE = os.path.join(BASE_DATA_DIR, "new_file.json")

    # Checkpoint configuration
    CHECKPOINT_DIR = CHECKPOINT_DIR
    BASE_MODEL_PATH = os.path.join(CHECKPOINT_DIR, 'mobileclip_s0.pt')
    # The training loop writes one checkpoint per epoch, named
    # "mobileclip_finetuned_epoch{N}_last.pt". Derive the path of the final epoch
    # from NUM_EPOCHS so it keeps pointing at the last checkpoint when the epoch
    # count is changed (with NUM_EPOCHS = 1 this is the same file name as before).
    FINETUNED_MODEL_PATH = os.path.join(
        CHECKPOINT_DIR, f'mobileclip_finetuned_epoch{NUM_EPOCHS}_last.pt'
    )

    # Positional embedding configuration
    LAMBDA2 = 4  # Interpolation parameter

    # Loss configuration
    TEMPERATURE = 0.07
    PCA_DIM = 32

# Instantiate the configuration and echo every public (non-underscore) setting.
config = Config()
print("Configuration loaded:")
for attr in filter(lambda name: not name.startswith('_'), dir(config)):
    print(f"  {attr}: {getattr(config, attr)}")


In [None]:
# ============================================================================
# CELL 4: Download Base Model and Setup Directories
# ============================================================================

# Create necessary directories
os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)
os.makedirs(config.BASE_DATA_DIR, exist_ok=True)

# Download base model if not exists.
# The download is done in Python rather than via `!wget ... -P {dir}`: the shell
# form breaks when the checkpoint directory contains spaces (the path is
# interpolated unquoted), depends on wget being installed, and reports success
# even when the command fails.
BASE_MODEL_URL = (
    "https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s0.pt"
)

if not os.path.exists(config.BASE_MODEL_PATH):
    import urllib.request

    print("Downloading base MobileCLIP model...")
    # Download to a temporary name and rename at the end, so an interrupted
    # transfer never leaves a truncated file that passes the exists() check above.
    partial_path = config.BASE_MODEL_PATH + ".part"
    try:
        urllib.request.urlretrieve(BASE_MODEL_URL, partial_path)
        os.replace(partial_path, config.BASE_MODEL_PATH)
    finally:
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Base model downloaded successfully!")
else:
    print("Base model already exists.")


In [None]:
# ============================================================================
# CELL 5: Model Loading and Setup Functions
# ============================================================================

def load_base_model(model_name=config.MODEL_NAME, checkpoint_path=None):
    """Build a MobileCLIP model, move it to the configured device and return it.

    Args:
        model_name: Architecture name handed to the mobileclip factory and tokenizer.
        checkpoint_path: Optional weights file. It is used only when it is given
            AND exists on disk; otherwise config.BASE_MODEL_PATH is loaded.

    Returns:
        Tuple of (model, preprocess, tokenizer).
    """
    print(f"Loading {model_name} model...")

    # Decide which weights file to load, announcing the choice.
    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f"Loading from checkpoint: {checkpoint_path}")
        weights_file = checkpoint_path
    else:
        print("Loading base pretrained model...")
        weights_file = config.BASE_MODEL_PATH

    model, _, preprocess = mobileclip.create_model_and_transforms(
        model_name, pretrained=weights_file
    )

    model.to(config.DEVICE)
    tokenizer = mobileclip.get_tokenizer(model_name)

    print(f"Model loaded successfully on {config.DEVICE}")
    return model, preprocess, tokenizer

def apply_positional_embedding_interpolation(model, lambda2=config.LAMBDA2, keep_len=20):
    """Apply Knowledge-Preserving Stretching for positional embeddings.

    Positions 0..keep_len keep their original embedding. Every later position
    `pos` is mapped onto the stretched source coordinate
    ``keep_len + (pos - keep_len) / lambda2`` and receives the linear
    interpolation of the two neighbouring original embeddings.

    The stretch is anchored at ``keep_len``. (Previously the source index was
    ``pos // lambda2`` counted from 0, which made position 21 jump back to the
    embeddings around index 5 and re-used the preserved prefix for all later
    positions.)

    Args:
        model: Model exposing
            ``text_encoder.get_positional_embedding().pos_embed.pos_embed``,
            indexed here as a 4-D tensor with positions on axis 2.
        lambda2: Integer stretch factor (>= 1). 1 leaves the table unchanged.
        keep_len: Last position index whose embedding is copied unchanged.

    Returns:
        The same model object, with its positional embedding replaced by a
        frozen (requires_grad=False) parameter.

    Raises:
        ValueError: If lambda2 < 1, keep_len < 0, or the embedding is missing.
    """
    if lambda2 < 1:
        raise ValueError(f"lambda2 must be >= 1, got {lambda2}")
    if keep_len < 0:
        raise ValueError(f"keep_len must be >= 0, got {keep_len}")

    print(f"Applying positional embedding interpolation with λ={lambda2}")

    embedding_holder = model.text_encoder.get_positional_embedding().pos_embed
    pos_embed = embedding_holder.pos_embed
    if pos_embed is None:
        raise ValueError("Positional embedding not found in text encoder.")

    max_pos = pos_embed.shape[2]

    # Read from a detached view of the original table and write into a copy, so
    # interpolation always uses unmodified source rows and keeps dtype/device.
    source = pos_embed.detach()
    modified_pos_embed = source.clone()

    # Positions 0..keep_len are already correct in the clone; only stretch the rest.
    for pos in range(keep_len + 1, max_pos):
        offset = pos - keep_len
        lower_idx = min(keep_len + offset // lambda2, max_pos - 1)
        upper_idx = min(lower_idx + 1, max_pos - 1)
        alpha = (offset % lambda2) / lambda2
        modified_pos_embed[:, :, pos, :] = (
            (1 - alpha) * source[:, :, lower_idx, :] +
            alpha * source[:, :, upper_idx, :]
        )

    # Update model with new positional embeddings
    embedding_holder.pos_embed = torch.nn.Parameter(modified_pos_embed, requires_grad=False)

    print("Positional embedding interpolation applied successfully")
    return model

def print_model_info(model):
    """Print the configured device plus total and trainable parameter counts."""
    # Collect (element count, trainable?) once, then aggregate both totals.
    param_stats = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total_params = sum(count for count, _ in param_stats)
    trainable_params = sum(count for count, trainable in param_stats if trainable)

    print(f"\nModel Information:")
    print(f"  Device: {config.DEVICE}")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")


In [None]:
# ============================================================================
# CELL 6: Load and Configure Base Model
# ============================================================================

# Load base model
# Called without a checkpoint path, so the weights at config.BASE_MODEL_PATH are used.
model, preprocess, tokenizer = load_base_model()

# Apply positional embedding interpolation
# This replaces the text encoder's positional embedding on the freshly loaded model.
model = apply_positional_embedding_interpolation(model)

# Print model information
print_model_info(model)

# Set model to evaluation mode initially
model.eval()

print("Base model loaded and configured successfully!")


In [None]:
# ============================================================================
# CELL 7: Test Base Model Performance
# ============================================================================

def test_model_inference(model, preprocess, tokenizer, test_texts=None, test_image_path=None):
    """Smoke-test text (and optionally image) encoding.

    Mixed precision is enabled only when config.DEVICE is a CUDA device. The
    image is no longer cast with `.half()`: under CUDA autocast the cast happens
    automatically, and on CPU a half-precision input fed to float32 weights
    raises a dtype error.

    Args:
        model: Model exposing encode_text / encode_image.
        preprocess: Callable turning a PIL image into a tensor.
        tokenizer: Callable turning a list of strings into a token tensor.
        test_texts: Optional list of prompts; a built-in set of four
            attribute-swap sentences is used when omitted.
        test_image_path: Optional image file. Ignored when missing on disk.

    Returns:
        The softmax probabilities of the image over the texts when an existing
        image path is given, otherwise the L2-normalised text features.
    """
    print("Testing model inference...")

    if test_texts is None:
        test_texts = [
            "The lemon on the left is yellow and the eggplant on the right is purple.",
            "The lemon on the left is purple and the eggplant on the right is yellow.",
            "The lemon on the right is yellow and the eggplant on the left is purple.",
            "The lemon on the right is purple and the eggplant on the left is yellow"
        ]

    # Autocast is a no-op when disabled, so the same code path serves CPU and GPU.
    use_amp = torch.device(config.DEVICE).type == "cuda"

    # Test text encoding
    text_tokens = tokenizer(test_texts).to(config.DEVICE)

    with torch.no_grad(), torch.autocast(device_type="cuda", enabled=use_amp):
        text_features = model.encode_text(text_tokens)
        text_features = F.normalize(text_features, dim=-1)

    print(f"Text encoding successful: {text_features.shape}")

    # Test image encoding if image provided
    if test_image_path and os.path.exists(test_image_path):
        image = preprocess(Image.open(test_image_path).convert('RGB')).unsqueeze(0)
        image = image.to(config.DEVICE)

        with torch.no_grad(), torch.autocast(device_type="cuda", enabled=use_amp):
            image_features = model.encode_image(image)
            image_features = F.normalize(image_features, dim=-1)

            # Calculate similarities
            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        print(f"Image encoding successful: {image_features.shape}")
        print(f"Text probabilities: {text_probs}")

        return text_probs

    return text_features

# Test the base model
# No image path is passed, so only the text-encoding path runs; as the last
# expression of the cell, the returned text features are also displayed.
print("Testing base model performance...")
test_model_inference(model, preprocess, tokenizer)


In [None]:
# ============================================================================
# UPDATED CELL 8: Data Loading and Dataset Classes with Train/Test Split
# ============================================================================

class Flickr8kCaptionedDataset(Dataset):
    """Dataset class for Flickr8k with custom captions.

    Each entry of ``self.samples`` is a 3-tuple ``(image_name, caption, extra)``:

    * JSON entries holding a dict with ``short_caption`` plus ``long_caption``
      (or the LLaVA key ``long_detailed``): ``caption`` is the short caption and
      ``extra`` is the LONG CAPTION TEXT.
    * JSON entries whose value is not a dict, and every text-file entry:
      ``extra`` is the literal string ``"standard"``.

    Entries whose image file does not exist in ``image_dir`` are skipped.
    """

    # Upper bound on resampling attempts when an image cannot be loaded.
    MAX_LOAD_RETRIES = 10

    def __init__(self, image_dir, captions_file, preprocess_fn, pull_from_json=True):
        """Index the captions file.

        Args:
            image_dir: Directory that contains the image files.
            captions_file: JSON file (pull_from_json=True) or text file.
            preprocess_fn: Callable applied to each loaded RGB PIL image.
            pull_from_json: Selects the JSON loader when True, the text loader otherwise.
        """
        self.image_dir = image_dir
        self.preprocess_fn = preprocess_fn
        self.samples = []

        print(f"Loading dataset from {captions_file}...")

        if pull_from_json:
            self._load_from_json(captions_file)
        else:
            self._load_from_txt(captions_file)

        print(f"Loaded {len(self.samples)} samples")

    def _load_from_json(self, captions_file):
        """Load captions from a JSON file mapping image name -> captions."""
        with open(captions_file, 'r', encoding='utf-8') as f:
            captions_data = json.load(f)

        for image_name, captions in captions_data.items():
            image_path = os.path.join(self.image_dir, image_name)
            if not os.path.exists(image_path):
                continue

            if isinstance(captions, dict):
                # Accept both the standard key and the LLaVA key for the long
                # caption; "long_caption" wins when both are present.
                long_key = next(
                    (key for key in ("long_caption", "long_detailed") if key in captions),
                    None,
                )
                if long_key is not None and "short_caption" in captions:
                    self.samples.append((
                        image_name,
                        captions["short_caption"],
                        captions[long_key]
                    ))
            else:
                self.samples.append((image_name, captions, "standard"))

    @staticmethod
    def _split_caption_line(line):
        """Split one stripped text line into (image_name, caption), or None.

        Supports ``name#caption``, ``name,caption`` and the Flickr8k token
        layout ``name#<index>\\t<caption>``. The separator that occurs FIRST in
        the line is used, so a '#' inside a comma-separated caption (or a ','
        inside a '#'-separated one) no longer breaks the image name.
        """
        candidates = [(line.find(sep), sep) for sep in ('#', ',') if sep in line]
        if not candidates:
            return None
        _, separator = min(candidates)

        image_name, _, caption = line.partition(separator)
        image_name, caption = image_name.strip(), caption.strip()

        if separator == '#':
            # Flickr8k.token.txt prefixes the caption with its index and a tab;
            # drop that prefix so it is not fed to the text encoder.
            index, tab, remainder = caption.partition('\t')
            if tab and index.strip().isdigit():
                caption = remainder.strip()

        return image_name, caption

    def _load_from_txt(self, captions_file):
        """Load captions from a text file with one image/caption pair per line."""
        with open(captions_file, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue

                parsed = self._split_caption_line(line)
                if parsed is None:
                    continue

                image_name, caption = parsed
                image_path = os.path.join(self.image_dir, image_name)
                if os.path.exists(image_path):
                    self.samples.append((image_name, caption, "standard"))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image_tensor, caption, extra)`` for sample ``idx``.

        ``extra`` is the long caption text or ``"standard"`` (see class docstring).
        If the image cannot be loaded, the error is printed and a random other
        sample is tried instead, up to MAX_LOAD_RETRIES attempts in total.

        Raises:
            RuntimeError: If every attempt fails (previously this recursed
                without bound).
        """
        for _ in range(self.MAX_LOAD_RETRIES):
            image_name, caption, caption_type = self.samples[idx]
            image_path = os.path.join(self.image_dir, image_name)

            try:
                # The context manager closes the underlying file handle.
                with Image.open(image_path) as raw_image:
                    image = raw_image.convert('RGB')
                image = self.preprocess_fn(image)
                return image, caption, caption_type
            except Exception as e:
                print(f"Error loading image {image_path}: {e}")
                idx = random.randint(0, len(self) - 1)

        raise RuntimeError(
            f"Could not load any image after {self.MAX_LOAD_RETRIES} attempts "
            f"(last tried: {image_path})"
        )

def create_train_test_split(llava_captions_file, images_dir, output_dir, test_ratio=0.125, seed=42):
    """
    Create train/test split from LLaVA captions

    Entries are kept only when their image exists in ``images_dir`` and they are
    dicts containing ``short_caption`` plus ``long_detailed`` or ``long_caption``.
    Captions are stripped and the long caption is always stored under
    ``long_caption``. Three files are written to ``output_dir``:
    ``train_captions.json``, ``test_captions.json`` and ``split_metadata.json``.

    The shuffle uses a private ``random.Random(seed)``, which yields the same
    order as seeding the global generator but leaves the global RNG state alone.

    Args:
        llava_captions_file: Path to LLaVA generated captions
        images_dir: Directory containing images
        output_dir: Output directory for split files
        test_ratio: Ratio for test set (0.125 = 1k out of 8k)
        seed: Seed of the shuffle that decides the split

    Returns:
        Dictionary with file paths and the train/test sizes

    Raises:
        ValueError: If test_ratio is outside [0, 1] or no valid image-caption
            pair is found (previously the latter failed with ZeroDivisionError).
    """
    if not 0 <= test_ratio <= 1:
        raise ValueError(f"test_ratio must be between 0 and 1, got {test_ratio}")

    print("="*50)
    print("CREATING TRAIN/TEST SPLIT")
    print("="*50)

    # Load LLaVA captions
    print(f"Loading captions from: {llava_captions_file}")
    with open(llava_captions_file, 'r', encoding='utf-8') as f:
        llava_data = json.load(f)

    print(f"Found {len(llava_data)} image-caption pairs")

    # Clean and filter existing images
    cleaned_data = {}
    for image_name, captions in llava_data.items():
        image_path = os.path.join(images_dir, image_name)
        if not os.path.exists(image_path):
            continue
        if not isinstance(captions, dict) or 'short_caption' not in captions:
            continue

        # Standardize format: the LLaVA key 'long_detailed' takes precedence
        # and is renamed to 'long_caption' for consistency.
        long_key = next(
            (key for key in ('long_detailed', 'long_caption') if key in captions),
            None,
        )
        if long_key is None:
            continue

        cleaned_data[image_name] = {
            'short_caption': captions['short_caption'].strip(),
            'long_caption': captions[long_key].strip()
        }

    print(f"Cleaned data: {len(cleaned_data)} valid image-caption pairs")

    if not cleaned_data:
        raise ValueError(
            f"No valid image-caption pairs found: none of the entries in "
            f"{llava_captions_file} has both captions and an image in {images_dir}"
        )

    # Analyze caption statistics (lengths are whitespace-separated word counts)
    short_lengths = [len(caps['short_caption'].split()) for caps in cleaned_data.values()]
    long_lengths = [len(caps['long_caption'].split()) for caps in cleaned_data.values()]
    short_avg = sum(short_lengths) / len(short_lengths)
    long_avg = sum(long_lengths) / len(long_lengths)

    print(f"Caption statistics:")
    print(f"  Short captions: avg {short_avg:.1f} words (range: {min(short_lengths)}-{max(short_lengths)})")
    print(f"  Long captions: avg {long_avg:.1f} words (range: {min(long_lengths)}-{max(long_lengths)})")

    # Create random split with a local generator for reproducible splits
    all_images = list(cleaned_data.keys())
    random.Random(seed).shuffle(all_images)

    # Calculate split sizes
    total_images = len(all_images)
    test_size = int(total_images * test_ratio)
    train_size = total_images - test_size

    print(f"\nSplit configuration:")
    print(f"  Total images: {total_images}")
    print(f"  Training: {train_size} images ({(train_size/total_images)*100:.1f}%)")
    print(f"  Testing: {test_size} images ({(test_size/total_images)*100:.1f}%)")

    # Create splits: the first test_size shuffled images form the test set
    test_images = all_images[:test_size]
    train_images = all_images[test_size:]

    train_data = {img: cleaned_data[img] for img in train_images}
    test_data = {img: cleaned_data[img] for img in test_images}

    # Save split files
    os.makedirs(output_dir, exist_ok=True)

    train_file = os.path.join(output_dir, "train_captions.json")
    test_file = os.path.join(output_dir, "test_captions.json")
    metadata_file = os.path.join(output_dir, "split_metadata.json")

    # Save metadata
    metadata = {
        'split_info': {
            'seed': seed,
            'total_images': total_images,
            'train_size': train_size,
            'test_size': test_size,
            'test_ratio': test_ratio
        },
        'caption_stats': {
            'short_caption_avg_length': short_avg,
            'long_caption_avg_length': long_avg
        },
        'files': {
            'original_llava_captions': os.path.basename(llava_captions_file),
            'train_captions': 'train_captions.json',
            'test_captions': 'test_captions.json'
        }
    }

    for path, payload in ((train_file, train_data), (test_file, test_data), (metadata_file, metadata)):
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2)

    print(f"\nFiles created:")
    print(f"  Training data: {train_file}")
    print(f"  Test data: {test_file}")
    print(f"  Metadata: {metadata_file}")

    return {
        'train_file': train_file,
        'test_file': test_file,
        'metadata_file': metadata_file,
        'train_size': train_size,
        'test_size': test_size
    }

def check_data_availability():
    """Check what data is available and create splits if needed.

    Looks for the images directory and, inside config.BASE_DATA_DIR, for
    captions_database.json, train_captions.json and test_captions.json.
    When the LLaVA captions and at least one image exist but either split file
    is missing, the train/test split is created. Without images the split is
    not attempted, because no caption entry could be matched to an image.

    Returns:
        Dict with keys 'images_dir', 'image_count', 'llava_captions',
        'train_captions', 'test_captions' and, when a split was created here,
        'split_result'.
    """
    print("Checking data availability...")

    # Check images
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
    images_exist = os.path.exists(config.IMAGES_DIR)
    image_count = 0
    if images_exist:
        image_count = sum(
            1 for f in os.listdir(config.IMAGES_DIR)
            if f.lower().endswith(image_extensions)
        )

    # Check for different caption file formats
    llava_captions = os.path.join(config.BASE_DATA_DIR, "captions_database.json")
    train_captions = os.path.join(config.BASE_DATA_DIR, "train_captions.json")
    test_captions = os.path.join(config.BASE_DATA_DIR, "test_captions.json")

    files_status = {
        'images_dir': images_exist,
        'image_count': image_count,
        'llava_captions': os.path.exists(llava_captions),
        'train_captions': os.path.exists(train_captions),
        'test_captions': os.path.exists(test_captions)
    }

    print(f"Data Status:")
    print(f"  Images directory: {config.IMAGES_DIR} ({'✓' if images_exist else '✗'})")
    print(f"  Image count: {image_count}")
    print(f"  LLaVA captions: {'✓' if files_status['llava_captions'] else '✗'}")
    print(f"  Training split: {'✓' if files_status['train_captions'] else '✗'}")
    print(f"  Test split: {'✓' if files_status['test_captions'] else '✗'}")

    # Create splits if LLaVA data exists but splits don't
    splits_missing = not (files_status['train_captions'] and files_status['test_captions'])
    if files_status['llava_captions'] and splits_missing:
        if image_count == 0:
            # Splitting filters captions by image existence, so it cannot
            # succeed without images; report instead of failing inside it.
            print(f"\nLLaVA captions found but no images available; skipping train/test split.")
        else:
            print(f"\nLLaVA captions found but splits missing. Creating train/test split...")
            split_result = create_train_test_split(
                llava_captions_file=llava_captions,
                images_dir=config.IMAGES_DIR,
                output_dir=config.BASE_DATA_DIR,
                test_ratio=0.125  # 1k test, 7k train
            )

            # Update status
            files_status['train_captions'] = True
            files_status['test_captions'] = True
            files_status['split_result'] = split_result

    return files_status

# Check data availability and create splits if needed
# `data_status` is read by the training cell to decide whether training can run.
data_status = check_data_availability()

In [None]:
# ============================================================================
# CELL 9: Loss Functions and Training Utilities
# ============================================================================

def PCA(input_tensor, PCA_dim=config.PCA_DIM):
    """Project rows onto the top principal directions and reconstruct them.

    The rows of ``input_tensor`` are mean-centred, the centred matrix (cast to
    float32) is decomposed with SVD, and the data is projected onto the first
    ``PCA_dim`` right singular vectors and mapped back, with the mean re-added.

    Returns:
        A float32 tensor with the same shape as ``input_tensor``.
    """
    feature_mean = input_tensor.mean(dim=0)
    centered = (input_tensor - feature_mean.unsqueeze(0)).float()

    # SVD instead of an explicit covariance eigendecomposition, for numerical stability.
    _, _, right_vectors = torch.linalg.svd(centered, full_matrices=False)
    basis = right_vectors[:PCA_dim].T  # columns are the leading principal directions

    # Down-project, then map back into the original feature space.
    projected = torch.mm(centered, basis)
    reconstructed = torch.mm(projected, basis.T)
    reconstructed += feature_mean

    return reconstructed

def single_loss(image_embeds, text_embeds, temperature=config.TEMPERATURE):
    """Symmetric contrastive loss between paired image and text embeddings.

    Row i of ``image_embeds`` is the positive for row i of ``text_embeds``.
    Both inputs are L2-normalised, their similarity matrix is divided by
    ``temperature``, and the image->text and text->image cross-entropies are averaged.
    """
    img_unit = F.normalize(image_embeds, dim=1)
    txt_unit = F.normalize(text_embeds, dim=1)

    # Pairwise similarities, sharpened by the temperature.
    logits = torch.matmul(img_unit, txt_unit.T) / temperature

    # The matching pair for each row sits on the diagonal.
    diagonal_targets = torch.arange(logits.size(0), device=logits.device)

    image_to_text = F.cross_entropy(logits, diagonal_targets)
    text_to_image = F.cross_entropy(logits.T, diagonal_targets)

    return (image_to_text + text_to_image) / 2

def long_clip_loss(image_embedding, long_embedding, short_embedding, logit_scale=None):
    """LongCLIP-style loss with PCA-based dual alignment.

    Two symmetric, label-smoothed contrastive terms are averaged:
    full image features vs. long-caption features, and PCA-compressed image
    features (config.PCA_DIM components) vs. short-caption features.

    Args:
        image_embedding: Image features, one row per sample.
        long_embedding: Long-caption text features, row-aligned with the images.
        short_embedding: Short-caption text features, row-aligned with the images.
        logit_scale: Optional multiplier applied to all similarity matrices.
            When omitted, ``exp(model.logit_scale)`` of the module-level
            ``model`` is used if available; otherwise ``1 / config.TEMPERATURE``
            (previously the similarities were left unscaled in that case).

    Returns:
        Scalar loss tensor.
    """
    # Normalize features
    image_features_long = F.normalize(image_embedding, dim=1)
    text_features_long = F.normalize(long_embedding, dim=1)
    text_features_short = F.normalize(short_embedding, dim=1)

    # Apply PCA to get compressed image features
    image_features_short = PCA(image_features_long, config.PCA_DIM)
    image_features_short = F.normalize(image_features_short, dim=1)

    # Resolve the temperature scaling factor
    if logit_scale is None:
        shared_model = globals().get('model')
        if shared_model is not None and hasattr(shared_model, 'logit_scale'):
            logit_scale = shared_model.logit_scale.exp()
        else:
            logit_scale = 1.0 / config.TEMPERATURE

    # Scaled similarity matrices; the text->image logits are the transposes
    sim_i2tl = logit_scale * torch.matmul(image_features_long, text_features_long.T)
    sim_tl2i = sim_i2tl.T
    sim_i2ts = logit_scale * torch.matmul(image_features_short, text_features_short.T)
    sim_ts2i = sim_i2ts.T

    # Create targets: the positive pair of each row is on the diagonal
    bs = image_embedding.size(0)
    targets = torch.arange(bs, device=image_embedding.device)

    # Calculate losses
    loss_itcl = (
        F.cross_entropy(sim_i2tl, targets, label_smoothing=0.1) +
        F.cross_entropy(sim_tl2i, targets, label_smoothing=0.1)
    ) / 2

    loss_itcs = (
        F.cross_entropy(sim_i2ts, targets, label_smoothing=0.1) +
        F.cross_entropy(sim_ts2i, targets, label_smoothing=0.1)
    ) / 2

    total_loss = (loss_itcl + loss_itcs) / 2
    return total_loss

print("Loss functions defined successfully!")


In [None]:
# ============================================================================
# CELL 10: Training Function
# ============================================================================

def train_model(config, images_dir, captions_file,
                long_clip_loss_fn=None, single_loss_fn=None, pull_from_json=True):
    """Train the model with given parameters.

    NOTE: this function trains the module-level ``model`` and uses the
    module-level ``preprocess`` and ``tokenizer``; they must exist before the call.

    Args:
        config: Configuration object (NUM_EPOCHS, BATCH_SIZE, NUM_WORKERS,
            LEARNING_RATE, DEVICE, CHECKPOINT_DIR are read).
        images_dir: Directory that contains the images.
        captions_file: Captions file handed to Flickr8kCaptionedDataset.
        long_clip_loss_fn: Optional callable ``(image, long_text, short_text)``.
            Used for batches that carry long captions.
        single_loss_fn: Callable ``(image, text)`` used for all other batches.
        pull_from_json: Whether ``captions_file`` is JSON (True) or plain text.
            (This parameter was referenced but never defined before, which made
            every call fail.)

    Returns:
        The trained module-level model.

    Raises:
        FileNotFoundError: If the images directory or captions file is missing.
        ValueError: If the dataset yields no full batch, or a batch needs
            ``single_loss_fn`` and none was provided.
    """
    print(f"Starting training with {config.NUM_EPOCHS} epochs...")

    # Check data availability
    if not os.path.exists(images_dir):
        raise FileNotFoundError(f"Images directory not found: {images_dir}")
    if not os.path.exists(captions_file):
        raise FileNotFoundError(f"Captions file not found: {captions_file}")

    # Create dataset and dataloader
    dataset = Flickr8kCaptionedDataset(images_dir, captions_file, preprocess, pull_from_json)
    dataloader = DataLoader(dataset, batch_size=config.BATCH_SIZE, shuffle=True,
                          num_workers=config.NUM_WORKERS, drop_last=True)
    if len(dataloader) == 0:
        # drop_last=True discards the only (partial) batch of a small dataset.
        raise ValueError(
            f"Dataset has {len(dataset)} samples, fewer than one batch of {config.BATCH_SIZE}"
        )

    # Setup optimizer
    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

    # Mixed precision only makes sense on CUDA; disabled autocast is a no-op.
    # NOTE(review): fp16 autocast without a gradient scaler can underflow small
    # gradients — consider adding one for GPU runs.
    use_amp = torch.device(config.DEVICE).type == "cuda"

    os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)

    # Training loop
    model.train()

    for epoch in range(config.NUM_EPOCHS):
        total_loss = 0.0
        avg_loss = float('nan')
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{config.NUM_EPOCHS}")

        # The dataset's third item is the long caption text for dual-caption
        # samples and the literal 'standard' for single-caption samples.
        for batch_idx, (images, captions, caption_types) in enumerate(progress_bar):
            images = images.to(config.DEVICE)
            tokenized_captions = tokenizer(captions).to(config.DEVICE)

            use_long_loss = (long_clip_loss_fn is not None and
                             any(ct != 'standard' for ct in caption_types))

            # Forward pass
            with torch.autocast(device_type="cuda", enabled=use_amp):
                image_features = model.encode_image(images)
                text_features = model.encode_text(tokenized_captions)

                # Choose loss function
                if use_long_loss:
                    # Encode the long captions as well; samples without one
                    # fall back to their short caption. (Previously the short
                    # features were passed for both text arguments.)
                    long_texts = [
                        long_text if long_text != 'standard' else short_text
                        for short_text, long_text in zip(captions, caption_types)
                    ]
                    tokenized_long = tokenizer(long_texts).to(config.DEVICE)
                    long_text_features = model.encode_text(tokenized_long)
                    loss = long_clip_loss_fn(image_features, long_text_features, text_features)
                else:
                    # Use single loss
                    if single_loss_fn is None:
                        raise ValueError("Single loss function must be provided")
                    loss = single_loss_fn(image_features, text_features)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update progress with the running mean loss of this epoch
            total_loss += loss.item()
            avg_loss = total_loss / (batch_idx + 1)
            progress_bar.set_postfix(loss=f"{avg_loss:.4f}")

        print(f"Epoch {epoch+1}/{config.NUM_EPOCHS}, Loss: {avg_loss:.4f}")

        # Save checkpoint
        checkpoint_path = os.path.join(config.CHECKPOINT_DIR, f"mobileclip_finetuned_epoch{epoch+1}_last.pt")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, checkpoint_path)
        print(f"Checkpoint saved: {checkpoint_path}")

    return model

print("Training function ready!")


In [None]:
# ============================================================================
# UPDATED CELL 11: Run Training with Proper Train Split
# ============================================================================

# Check if we have proper training data
# `data_status` comes from check_data_availability(); training needs both the
# training split file and at least one image.
if data_status['train_captions'] and data_status['image_count'] > 0:
    print("="*60)
    print("TRAINING WITH PROPER TRAIN/TEST SPLIT")
    print("="*60)
    
    # Use training split for training
    train_captions_file = os.path.join(config.BASE_DATA_DIR, "train_captions.json")
    
    print(f"Training with: {train_captions_file}")
    
    # Load training metadata to show split info
    # The metadata file is optional; it is only used for the summary below.
    metadata_file = os.path.join(config.BASE_DATA_DIR, "split_metadata.json")
    if os.path.exists(metadata_file):
        with open(metadata_file, 'r') as f:
            metadata = json.load(f)
        
        print(f"Training set size: {metadata['split_info']['train_size']} images")
        print(f"Test set size: {metadata['split_info']['test_size']} images")
        print(f"Test ratio: {metadata['split_info']['test_ratio']:.1%}")
    
    # Train the model on training split only
    # NOTE(review): this call passes `pull_from_json` by keyword, so train_model
    # must accept that parameter.
    trained_model = train_model(
        config=config,
        images_dir=config.IMAGES_DIR,
        captions_file=train_captions_file,  # Use training split
        pull_from_json=True,
        long_clip_loss_fn=long_clip_loss,
        single_loss_fn=single_loss
    )
    
    print("✓ Training completed on training split!")
    
else:
    print("❌ Training data not available. Please ensure you have:")
    print("1. LLaVA captions file (captions_database.json)")
    print("2. Images in the data/Images/ directory")
    if not data_status['llava_captions']:
        print("\nTo generate LLaVA captions, run your caption generation script first.")


In [None]:
# ============================================================================
# CELL 12: Load Trained Model (if available)
# ============================================================================

def load_trained_model(checkpoint_path=config.FINETUNED_MODEL_PATH):
    """Restore fine-tuned weights into the module-level ``model``.

    Args:
        checkpoint_path: Checkpoint file holding a 'model_state_dict' entry.

    Returns:
        True when the checkpoint existed and was loaded, False when the file is
        missing (the model is then left untouched).
    """
    # Guard clause: nothing to load, keep whatever weights the model has.
    if not os.path.exists(checkpoint_path):
        print(f"No trained model found at: {checkpoint_path}")
        print("Using base model for testing...")
        return False

    print(f"Loading trained model from: {checkpoint_path}")
    saved_state = torch.load(checkpoint_path, map_location=config.DEVICE)
    model.load_state_dict(saved_state['model_state_dict'])
    print("Trained model loaded successfully!")
    return True

# Try to load trained model
# `model_loaded` records whether fine-tuned weights were found; either way the
# model is switched to evaluation mode for the cells that follow.
model_loaded = load_trained_model()
model.eval()

In [None]:
# ============================================================================
# UPDATED CELL 13: Model Testing on Test Set (Not Training Set)
# ============================================================================

def _print_test_set_summary(results):
    """Print aggregate similarity and caption-length statistics.

    Args:
        results: Non-empty list of per-image dicts as produced by
            ``evaluate_on_test_set``.
    """
    count = len(results)
    short_similarities = [r['short_similarity'] for r in results]
    long_similarities = [r['long_similarity'] for r in results]
    improvements = [r['improvement'] for r in results]
    short_word_counts = [r['short_words'] for r in results]
    long_word_counts = [r['long_words'] for r in results]

    print("\n" + "="*50)
    print("TEST SET EVALUATION SUMMARY")
    print("="*50)

    print(f"Results based on {count} test images:")
    print("\nShort Captions:")
    print(f"  Average similarity: {sum(short_similarities)/count:.4f}")
    print(f"  Average length: {sum(short_word_counts)/count:.1f} words")

    print("\nLong Captions:")
    print(f"  Average similarity: {sum(long_similarities)/count:.4f}")
    print(f"  Average length: {sum(long_word_counts)/count:.1f} words")

    print("\nLong vs Short Analysis:")
    avg_improvement = sum(improvements) / count
    positive_improvements = sum(1 for imp in improvements if imp > 0)
    print(f"  Average improvement: {avg_improvement:+.4f}")
    print(f"  Long captions better: {positive_improvements}/{count} ({positive_improvements/count*100:.1f}%)")

    if avg_improvement > 0:
        print("  ✓ Model shows improvement on longer, detailed captions")
    else:
        print("  ⚠ Model performance decreased on longer captions")


def evaluate_on_test_set(model, preprocess, tokenizer, sample_size=5):
    """Score a reproducible sample of the reserved test split.

    Reads ``test_captions.json`` from ``config.BASE_DATA_DIR``, samples up to
    ``sample_size`` image names with a fixed seed, and for each image computes
    the scaled cosine similarity (x100) between the image embedding and both
    its short and its long caption. Per-image failures are printed and that
    image is skipped.

    Args:
        model: Object exposing ``eval()``, ``encode_image`` and ``encode_text``.
            It is switched to eval mode and not switched back.
        preprocess: Callable turning a PIL image into an image tensor.
        tokenizer: Callable turning a list of strings into a token tensor.
        sample_size: Maximum number of test images to score.

    Returns:
        None if the test split file does not exist. Otherwise a list (possibly
        empty) of dicts with keys ``image_name``, ``short_caption``,
        ``long_caption``, ``short_similarity``, ``long_similarity``,
        ``short_words``, ``long_words`` and ``improvement``.
    """
    # Local import: this cell cannot see whether the notebook's import cell
    # already provides contextlib.
    from contextlib import nullcontext

    print("="*60)
    print("EVALUATION ON TEST SET")
    print("="*60)

    test_captions_file = os.path.join(config.BASE_DATA_DIR, "test_captions.json")

    if not os.path.exists(test_captions_file):
        print("❌ Test set not found. Please ensure train/test split was created.")
        return None

    with open(test_captions_file, 'r') as f:
        test_data = json.load(f)

    print(f"Test set contains {len(test_data)} images")
    print(f"Running detailed evaluation on {sample_size} sample images...")

    # A private generator yields the same sample as seeding the module-level
    # RNG with 42, without resetting random state for the rest of the notebook.
    test_images = list(test_data.keys())
    sampler = random.Random(42)
    sample_images = sampler.sample(test_images, min(sample_size, len(test_images)))

    # Mixed precision is only entered when the tensors live on a CUDA device;
    # on CPU (or any other backend) inference runs in plain no_grad mode.
    use_amp = torch.device(config.DEVICE).type == 'cuda'

    results = []
    model.eval()

    for i, image_name in enumerate(sample_images):
        try:
            image_path = os.path.join(config.IMAGES_DIR, image_name)
            # Context manager closes the file handle once the tensor is built.
            with Image.open(image_path) as img:
                image = preprocess(img.convert('RGB')).unsqueeze(0).to(config.DEVICE)

            captions = test_data[image_name]
            short_caption = captions['short_caption']
            long_caption = captions['long_caption']

            # Index 0 is the short caption, index 1 the long caption.
            text_tokens = tokenizer([short_caption, long_caption]).to(config.DEVICE)

            amp_context = torch.autocast(device_type='cuda') if use_amp else nullcontext()
            with torch.no_grad(), amp_context:
                image_features = F.normalize(model.encode_image(image), dim=-1)
                text_features = F.normalize(model.encode_text(text_tokens), dim=-1)

                # TODO: 'improvement' below is the long-minus-short similarity, but
                # both captions go through the same tokenizer call. If the tokenizer
                # truncates to a fixed context length, the tail of the long caption
                # is dropped and the number is not meaningful - confirm that the
                # extended context is actually in effect.
                similarities = (100.0 * image_features @ text_features.T).squeeze(0)

            result = {
                'image_name': image_name,
                'short_caption': short_caption,
                'long_caption': long_caption,
                'short_similarity': float(similarities[0]),
                'long_similarity': float(similarities[1]),
                'short_words': len(short_caption.split()),
                'long_words': len(long_caption.split()),
                'improvement': float(similarities[1] - similarities[0])
            }
            results.append(result)

            print(f"\nTest {i+1}/{len(sample_images)}: {image_name}")
            print(f"  Short caption ({result['short_words']} words): {result['short_similarity']:.4f}")
            print(f"  Long caption ({result['long_words']} words): {result['long_similarity']:.4f}")
            print(f"  Improvement: {result['improvement']:+.4f}")

        except Exception as e:
            # Best effort per image: report the failure and keep going.
            print(f"Error processing {image_name}: {e}")

    if results:
        _print_test_set_summary(results)
    else:
        print("⚠ No test images could be evaluated.")

    return results

def compare_base_vs_finetuned_on_test():
    """Evaluate the base and the fine-tuned model on the test split and compare them.

    Both models are loaded through ``load_base_model`` and scored with
    ``evaluate_on_test_set(..., sample_size=10)``.

    Returns:
        None when ``test_captions.json`` is missing from ``config.BASE_DATA_DIR``;
        ``{'base': ...}`` when no checkpoint exists at
        ``config.FINETUNED_MODEL_PATH``; otherwise
        ``{'base': ..., 'finetuned': ...}``.
    """
    rule = "=" * 60
    print(rule)
    print("BASE vs FINE-TUNED COMPARISON ON TEST SET")
    print(rule)

    split_path = os.path.join(config.BASE_DATA_DIR, "test_captions.json")
    if not os.path.exists(split_path):
        print("❌ Test set not found.")
        return None

    # Base model: loaded as-is, nothing extra is applied to it.
    print("1. Evaluating Base Model...")
    base_model, base_preprocess, base_tokenizer = load_base_model(checkpoint_path=None)
    base_results = evaluate_on_test_set(base_model, base_preprocess, base_tokenizer, sample_size=10)

    # Without a checkpoint there is nothing to compare against.
    if not os.path.exists(config.FINETUNED_MODEL_PATH):
        print(f"❌ Fine-tuned model not found at: {config.FINETUNED_MODEL_PATH}")
        return {'base': base_results}

    print("\n2. Evaluating Fine-tuned Model...")
    ft_model, ft_preprocess, ft_tokenizer = load_base_model(checkpoint_path=config.FINETUNED_MODEL_PATH)
    # Author's note: positional embedding interpolation is deliberately not
    # re-applied here because it was already done during training.
    ft_results = evaluate_on_test_set(ft_model, ft_preprocess, ft_tokenizer, sample_size=10)

    if base_results and ft_results:
        def mean_of(rows, key):
            return sum(row[key] for row in rows) / len(rows)

        def relative_change(new, old):
            # Percent change; reported as 0 when the baseline is not positive.
            return ((new - old) / old * 100) if old > 0 else 0

        print("\n" + "=" * 50)
        print("COMPARISON SUMMARY")
        print("=" * 50)

        base_short = mean_of(base_results, 'short_similarity')
        base_long = mean_of(base_results, 'long_similarity')
        tuned_short = mean_of(ft_results, 'short_similarity')
        tuned_long = mean_of(ft_results, 'long_similarity')

        short_gain = relative_change(tuned_short, base_short)
        long_gain = relative_change(tuned_long, base_long)

        print("Short Caption Performance:")
        print(f"  Base model: {base_short:.4f}")
        print(f"  Fine-tuned: {tuned_short:.4f}")
        print(f"  Improvement: {short_gain:+.2f}%")

        print("\nLong Caption Performance:")
        print(f"  Base model: {base_long:.4f}")
        print(f"  Fine-tuned: {tuned_long:.4f}")
        print(f"  Improvement: {long_gain:+.2f}%")

        print("\nKey Findings:")
        if long_gain > short_gain:
            print(f"  ✓ Fine-tuning particularly helps with long captions (+{long_gain - short_gain:.2f}% extra benefit)")
        if tuned_long > tuned_short:
            print("  ✓ Fine-tuned model handles detailed descriptions better")

    return {'base': base_results, 'finetuned': ft_results}

# This cell only defines the helpers; nothing is evaluated here. Print how to
# invoke them (they read the held-out test split, not the training split).
print("\n".join([
    "Test set evaluation functions ready.",
    "Run evaluation with:",
    "  test_results = evaluate_on_test_set(model, preprocess, tokenizer)",
    "  comparison = compare_base_vs_finetuned_on_test()",
]))

In [None]:
# ============================================================================
# CELL 14: Performance Comparison Summary
# ============================================================================

def print_performance_summary():
    """Print the model configuration, data availability and suggested next steps.

    Reads the global ``config`` and ``model_loaded`` flag and calls
    ``check_data_availability()`` for the image count and caption-file map.
    Output goes to stdout only; nothing is returned.
    """
    rule = "=" * 60
    print(rule)
    print("LIGHTVISION PERFORMANCE SUMMARY")
    print(rule)

    print("Model Configuration:")
    print(f"  Base Model: {config.MODEL_NAME}")
    print(f"  Device: {config.DEVICE}")
    print(f"  Positional Embedding Interpolation: λ={config.LAMBDA2}")
    print(f"  PCA Dimension: {config.PCA_DIM}")

    if model_loaded:
        status_lines = [
            "  Status: ✓ Fine-tuned model loaded",
            f"  Checkpoint: {config.FINETUNED_MODEL_PATH}",
        ]
    else:
        status_lines = ["  Status: ⚠ Base model (no fine-tuning)"]
    print("\n".join(status_lines))

    print("\nData Configuration:")
    # Only the second and third elements of the returned triple are used here.
    availability = check_data_availability()
    image_count, caption_files = availability[1], availability[2]
    print(f"  Images available: {image_count}")
    print(f"  Caption files: {sum(caption_files.values())} available")

    print("\nNext Steps:")
    if model_loaded:
        next_steps = [
            "  1. ✓ Model is ready for deployment",
            "  2. Run comprehensive evaluation",
            "  3. Set up retrieval framework with FAISS",
        ]
    else:
        next_steps = [
            "  1. ⚠ Train the model using available data",
            "  2. Run comprehensive evaluation",
            "  3. Set up retrieval framework",
        ]
    for step in next_steps:
        print(step)

    print("\nFor comprehensive evaluation and retrieval setup:")
    print("  - Use the evaluation_system.py module")
    print("  - Use the retrieval_framework.py module")

# Print summary (reads the global model_loaded flag and calls check_data_availability())
print_performance_summary()


In [None]:
# ============================================================================
# CELL 15: Quick Interactive Test
# ============================================================================

def interactive_test():
    """Run an interactive session that encodes typed queries with the global ``model``.

    Each non-empty query is tokenized with the global ``tokenizer``, encoded
    by ``model.encode_text`` and L2-normalized; the feature shape and norm
    are printed. No image comparison is performed here.

    The loop ends when the user types ``quit``, ``exit`` or ``q`` (case
    insensitive), or when input can no longer be read (end of input or an
    interrupt). Errors raised while encoding a query are printed and the
    loop continues with the next prompt.
    """
    print("="*60)
    print("INTERACTIVE MODEL TEST")
    print("="*60)
    print("Enter text queries to test the model (type 'quit' to exit)")

    while True:
        try:
            query = input("\nEnter your query: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Input stream closed or the prompt was interrupted: end the
            # session cleanly instead of surfacing a traceback.
            print("\nInput closed; leaving interactive test.")
            break

        if query.lower() in ['quit', 'exit', 'q']:
            break

        if not query:
            continue

        try:
            # Encode the query
            text_tokens = tokenizer([query]).to(config.DEVICE)

            with torch.no_grad():
                text_features = model.encode_text(text_tokens)
                text_features = F.normalize(text_features, dim=-1)

            print("Query encoded successfully!")
            print(f"Text feature shape: {text_features.shape}")
            print(f"Feature norm: {text_features.norm():.4f}")

            # If you have a test image, you can compare similarities here

        except Exception as e:
            # Best effort per query: report and keep the session alive.
            print(f"Error processing query: {e}")

# interactive_test() blocks on input(), so it is not started automatically.
# Uncomment the line below to run interactive test
# interactive_test()

print("Interactive test function ready. Uncomment the line above to run it.")

In [None]:
# ============================================================================
# CELL 16: Export Model and Prepare for Deployment
# ============================================================================

def export_model_for_deployment():
    """Write the deployment description (and, if fine-tuned, the weights) to disk.

    Creates ``config.CHECKPOINT_DIR`` if needed, then writes
    ``deployment_config.json`` there. When the global ``model_loaded`` flag is
    true, the global ``model``'s state dict is also saved next to it as
    ``model_for_retrieval.pt``.

    Returns:
        The deployment configuration dict that was written to the JSON file.
    """
    print("Preparing model for deployment...")

    deployment_config = {
        'model_name': config.MODEL_NAME,
        'checkpoint_path': config.FINETUNED_MODEL_PATH if model_loaded else None,
        'device': str(config.DEVICE),
        'embedding_dim': config.EMBEDDING_DIM,
        'positional_interpolation': {
            # NOTE(review): recorded as always applied; nothing in this
            # function inspects the model to confirm it.
            'applied': True,
            'lambda': config.LAMBDA2
        },
        'training_config': {
            'batch_size': config.BATCH_SIZE,
            'learning_rate': config.LEARNING_RATE,
            # The configured epoch count, not a count read from the checkpoint.
            'epochs_trained': config.NUM_EPOCHS if model_loaded else 0
        }
    }

    # The directory may not exist yet (e.g. no training run has happened), and
    # open(..., 'w') does not create parent directories.
    os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)

    # Save deployment configuration
    config_path = os.path.join(config.CHECKPOINT_DIR, "deployment_config.json")
    with open(config_path, 'w') as f:
        json.dump(deployment_config, f, indent=2)

    print(f"Deployment configuration saved to: {config_path}")

    # Save current model state for retrieval framework (bare state dict)
    if model_loaded:
        retrieval_model_path = os.path.join(config.CHECKPOINT_DIR, "model_for_retrieval.pt")
        torch.save(model.state_dict(), retrieval_model_path)
        print(f"Model state saved for retrieval: {retrieval_model_path}")

    return deployment_config

# Export model: writes deployment_config.json into config.CHECKPOINT_DIR and,
# when model_loaded is true, model_for_retrieval.pt alongside it.
deployment_config = export_model_for_deployment()


In [None]:
# ============================================================================
# CELL 17: Notebook Summary and Next Steps
# ============================================================================

def print_notebook_summary():
    """Print the closing report for the notebook run.

    The report lists the completed tasks (two entries depend on the global
    ``model_loaded`` flag), the model settings read from ``config``, the
    generated file paths, and static next-step / contribution notes.
    Output goes to stdout only; nothing is returned.
    """
    def emit(lines):
        # Print every entry on its own line, in order.
        for text in lines:
            print(text)

    rule = "=" * 80
    emit([rule, "LIGHTVISION NOTEBOOK EXECUTION SUMMARY", rule])

    emit([
        "✓ COMPLETED TASKS:",
        "  1. ✓ Environment setup and dependency checking",
        "  2. ✓ Base MobileCLIP model loading",
        "  3. ✓ Positional embedding interpolation applied",
        "  4. ✓ Loss functions implemented (single_loss, long_clip_loss)",
        "  5. ✓ Training framework set up",
    ])

    # Items 6 and 7 reflect whether a fine-tuned checkpoint was loaded.
    if model_loaded:
        emit([
            "  6. ✓ Model fine-tuning completed",
            "  7. ✓ Trained model loaded and tested",
        ])
    else:
        emit([
            "  6. ⚠ Model fine-tuning skipped (no training data)",
            "  7. ⚠ Using base model for testing",
        ])

    emit([
        "  8. ✓ Model testing functions implemented",
        "  9. ✓ Deployment configuration prepared",
    ])

    print("\n📊 MODEL CONFIGURATION:")
    print(f"  • Base Model: {config.MODEL_NAME}")
    print(f"  • Device: {config.DEVICE}")
    print(f"  • Embedding Dimension: {config.EMBEDDING_DIM}")
    print(f"  • Positional Interpolation: λ={config.LAMBDA2}")
    print(f"  • PCA Dimension: {config.PCA_DIM}")

    print("\n📁 GENERATED FILES:")
    print(f"  • Base model: {config.BASE_MODEL_PATH}")
    if model_loaded:
        print(f"  • Fine-tuned model: {config.FINETUNED_MODEL_PATH}")
    deployment_json = os.path.join(config.CHECKPOINT_DIR, 'deployment_config.json')
    print(f"  • Deployment config: {deployment_json}")

    emit([
        "\n🚀 NEXT STEPS:",
        "  1. Run comprehensive evaluation:",
        "     python evaluation_system.py --mode full",
        "  2. Set up retrieval framework:",
        "     python retrieval_framework.py",
        "  3. Compare model performance:",
        "     python example_usage.py --example 2",
        "  4. Deploy to production environment",
    ])

    emit([
        "\n🔬 RESEARCH CONTRIBUTIONS:",
        "  • Lightweight mobile-compatible CLIP model",
        "  • Extended context length handling (77+ tokens)",
        "  • Principal Component Matching for multi-granularity alignment",
        "  • Knowledge-preserving positional embedding interpolation",
    ])

    emit([
        "\n📈 EXPECTED IMPROVEMENTS:",
        "  • Better handling of detailed, long-form queries",
        "  • Improved semantic understanding beyond 20 tokens",
        "  • Maintained efficiency for mobile deployment",
        "  • Enhanced retrieval accuracy for complex descriptions",
    ])

# Print final summary (console output only; the function returns nothing)
print_notebook_summary()

In [None]:
# ============================================================================
# CELL 18: Cleanup and Memory Management
# ============================================================================

def cleanup_memory():
    """Run garbage collection, then release PyTorch's unused cached GPU memory.

    When CUDA is available, the allocated and reserved GPU memory (in GB) is
    printed after the cleanup. Without CUDA only garbage collection runs.
    Nothing is returned.
    """
    import gc

    print("Cleaning up memory...")

    # Collect first: tensors that are only kept alive by reference cycles are
    # freed here, so their blocks are unoccupied by the time the cache is
    # emptied. Emptying the cache before collecting would leave them cached.
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

        gib = 1024**3
        memory_allocated = torch.cuda.memory_allocated() / gib  # GB
        memory_reserved = torch.cuda.memory_reserved() / gib   # GB
        print(f"GPU Memory - Allocated: {memory_allocated:.2f} GB, Reserved: {memory_reserved:.2f} GB")

    print("Memory cleanup completed!")

# Optional cleanup
# cleanup_memory()

# Closing banner. It is printed unconditionally and does not check whether the
# earlier cells actually trained or loaded a model.
print("\n".join([
    "=" * 80,
    "LIGHTVISION NOTEBOOK EXECUTION COMPLETED SUCCESSFULLY!",
    "=" * 80,
    "The model is now ready for deployment in the retrieval framework.",
    "Check the generated files in the checkpoints/ directory.",
    "Run the evaluation and retrieval modules for comprehensive testing.",
]))