# MS COCO 2014 Dataset Preparation for CLIP Fine-tuning

This notebook prepares the COCO 2014 dataset for fine-tuning CLIP models.

## Features:
- Automatic Kaggle dataset download
- CLIP-specific image preprocessing (224x224, normalized)
- Text embedding caching for efficient training
- PyTorch Dataset implementation
- Verification and visualization

## Dataset Info:
- Training images: 82,783
- Validation images: 40,504
- Multiple captions per image
- Source: COCO 2014 from Kaggle

## 1. Install Dependencies

In [None]:
# Install required packages
!pip install -q transformers>=4.30.0 torch>=2.0.0 torchvision>=0.15.0
!pip install -q pillow kaggle pycocotools matplotlib tqdm

print("✓ All dependencies installed successfully!")

## 2. Import Libraries

In [None]:
# Standard library imports
import os
import json
import random
import textwrap
from pathlib import Path
from collections import defaultdict

# ML and Image processing
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image

# Transformers for CLIP
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer, CLIPTextModel

# Utilities
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# Fix the seed of every RNG this notebook relies on (PyTorch, NumPy, Python)
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Prefer the GPU when one is visible, otherwise run on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Extra hardware details are only meaningful on CUDA
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 3. Kaggle Authentication Setup

**Instructions:**
1. Download your `kaggle.json` from [Kaggle Account Settings](https://www.kaggle.com/settings)
2. Upload it using the file upload button in Colab (left sidebar → Files → Upload)
3. Run the cell below

In [None]:
# Setup Kaggle authentication
kaggle_json_path = 'kaggle.json'

# Check if kaggle.json exists in current directory
if not os.path.exists(kaggle_json_path):
    print("⚠️  kaggle.json not found!")
    print("\nPlease:")
    print("1. Download kaggle.json from https://www.kaggle.com/settings")
    print("2. Upload it to this Colab environment using the file browser")
    print("3. Re-run this cell")
else:
    # Create .kaggle directory in home
    kaggle_dir = Path.home() / '.kaggle'
    kaggle_dir.mkdir(exist_ok=True)
    
    # Copy kaggle.json to ~/.kaggle/
    target_path = kaggle_dir / 'kaggle.json'
    !cp {kaggle_json_path} {target_path}
    
    # Set proper permissions
    !chmod 600 {target_path}
    
    print("✓ Kaggle authentication configured successfully!")
    print(f"  Credentials saved to: {target_path}")
    
    # Verify authentication
    !kaggle --version

## 4. Download COCO 2014 Dataset

This will download ~13GB of data. Download time depends on your connection speed.

In [None]:
# Define paths
DATASET_DIR = Path('/content/coco2014')
DATASET_DIR.mkdir(exist_ok=True)

# Download dataset from Kaggle
print("Downloading COCO 2014 dataset from Kaggle...")
print("This may take 10-20 minutes depending on connection speed.\n")

!kaggle datasets download -d jeffaudi/coco-2014-dataset-for-yolov3 -p {DATASET_DIR} --unzip

print("\n" + "="*60)
print("Dataset download complete!")
print("="*60)

# Verify directory structure
expected_dirs = ['train2014', 'val2014', 'annotations']
print("\nVerifying directory structure:")

for dir_name in expected_dirs:
    dir_path = DATASET_DIR / dir_name
    if dir_path.exists():
        if dir_name == 'annotations':
            files = list(dir_path.glob('*.json'))
            print(f"  ✓ {dir_name}/: {len(files)} JSON files")
        else:
            files = list(dir_path.glob('*.jpg'))
            print(f"  ✓ {dir_name}/: {len(files):,} images")
    else:
        print(f"  ✗ {dir_name}/ NOT FOUND")

# Check for caption files
annotations_dir = DATASET_DIR / 'annotations'
caption_files = {
    'train': annotations_dir / 'captions_train2014.json',
    'val': annotations_dir / 'captions_val2014.json'
}

print("\nCaption files:")
for split, path in caption_files.items():
    if path.exists():
        # Load and print statistics
        with open(path, 'r') as f:
            data = json.load(f)
        print(f"  ✓ {path.name}")
        print(f"    - Images: {len(data['images']):,}")
        print(f"    - Captions: {len(data['annotations']):,}")
    else:
        print(f"  ✗ {path.name} NOT FOUND")

print("\n" + "="*60)

## 5. Load CLIP Model Components

Loading pre-trained CLIP model from HuggingFace: `openai/clip-vit-base-patch32`

In [None]:
# Checkpoint shared by the tokenizer and the text encoder
MODEL_NAME = 'openai/clip-vit-base-patch32'

print(f"Loading CLIP model: {MODEL_NAME}...")

# Tokenizer for turning captions into input ids
tokenizer = CLIPTokenizer.from_pretrained(MODEL_NAME)

# Text encoder: load, place on the selected device, switch to eval mode
text_encoder = CLIPTextModel.from_pretrained(MODEL_NAME).to(device)
text_encoder.eval()

print("✓ CLIP text encoder loaded successfully!")
print(f"  Model parameters: {sum(param.numel() for param in text_encoder.parameters()):,}")
print(f"  Text embedding dimension: {text_encoder.config.hidden_size}")
print(f"  Device: {device}")

## 6. Define Image Transforms

CLIP requires specific normalization values:
- Input size: 224×224
- Mean: [0.48145466, 0.4578275, 0.40821073]
- Std: [0.26862954, 0.26130258, 0.27577711]

In [None]:
# CLIP normalization constants
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
IMAGE_SIZE = 224

# Image transformation pipeline, mirroring CLIP's reference preprocessing:
# resize the shorter side to IMAGE_SIZE with bicubic interpolation, then
# center-crop to a square. Resizing directly to (IMAGE_SIZE, IMAGE_SIZE) would
# stretch non-square images and change their aspect ratio.
# The output is still a normalized [3, IMAGE_SIZE, IMAGE_SIZE] tensor.
image_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD)
])

# Define inverse transform for visualization
def denormalize(tensor):
    """
    Undo the CLIP normalization so an image tensor can be displayed.

    Args:
        tensor: Image tensor whose last three dimensions are [3, H, W]
            (a leading batch dimension broadcasts as well).

    Returns:
        ``tensor * CLIP_STD + CLIP_MEAN`` applied per channel. Values are not
        clamped; callers clamp to [0, 1] before plotting.
    """
    # Build the statistics on the input's device so CUDA tensors work too
    mean = torch.tensor(CLIP_MEAN, device=tensor.device).view(3, 1, 1)
    std = torch.tensor(CLIP_STD, device=tensor.device).view(3, 1, 1)
    return tensor * std + mean

# Summarize the preprocessing configuration, one setting per line
print(
    "✓ Image transforms configured:",
    f"  Input size: {IMAGE_SIZE}×{IMAGE_SIZE}",
    f"  Normalization mean: {CLIP_MEAN}",
    f"  Normalization std: {CLIP_STD}",
    sep="\n",
)

## 7. Encode and Cache Training Captions

This cell encodes all training captions using CLIP's text encoder and caches them to disk.
This saves GPU memory and computation time during training.

In [None]:
def encode_and_cache_captions(split='train', batch_size=64):
    """
    Encode captions using CLIP text encoder and cache to disk.

    Captions are grouped by image, packed into batches of whole images holding
    at most ``batch_size`` captions, and pushed through the module-level
    ``tokenizer`` / ``text_encoder``. The pooled outputs are written to
    ``DATASET_DIR/{split}_text_embeddings.pt``.

    Args:
        split: 'train' or 'val'
        batch_size: Maximum number of captions per forward pass. The captions
            of one image are never split across batches, so an image with more
            captions than ``batch_size`` is encoded as one larger batch.

    Returns:
        Path of the cache file. The file stores a dict with keys 'data' (a list
        of {'image_id', 'embeddings', 'captions'} dicts, one per image),
        'model_name' and 'embedding_dim'.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Load captions JSON
    caption_file = DATASET_DIR / 'annotations' / f'captions_{split}2014.json'
    print(f"Loading captions from {caption_file.name}...")

    with open(caption_file, 'r', encoding='utf-8') as f:
        coco_data = json.load(f)

    # Organize captions by image_id
    print("Organizing captions by image_id...")
    image_to_captions = defaultdict(list)

    for annotation in coco_data['annotations']:
        image_to_captions[annotation['image_id']].append(annotation['caption'])

    print(f"Found {len(image_to_captions)} unique images with captions")

    cache_data = []

    # Images (and their captions) queued for the next forward pass
    pending_ids = []
    pending_captions = []

    def flush_pending():
        """Encode the queued captions in one forward pass and store them per image."""
        if not pending_ids:
            return

        inputs = tokenizer(
            pending_captions,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors='pt'
        )

        # Move to device and encode
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = text_encoder(**inputs)

        # pooler_output is the pooled sentence representation; for CLIP this is
        # taken at the end-of-text token (CLIP has no CLS token).
        # NOTE(review): CLIPTextModel does not apply CLIP's text projection. If
        # training compares these against projected image features, the
        # projected variant may be needed - confirm against the training notebook.
        embeddings = outputs.pooler_output.cpu()

        # Hand each image its own rows of the batch output
        offset = 0
        for queued_id in pending_ids:
            captions = image_to_captions[queued_id]
            end = offset + len(captions)
            cache_data.append({
                'image_id': queued_id,
                # clone() so every entry owns its storage instead of being a
                # view into the whole batch tensor
                'embeddings': embeddings[offset:end].clone(),  # Shape: [num_captions, hidden_size]
                'captions': captions
            })
            offset = end

        pending_ids.clear()
        pending_captions.clear()

    print(f"\nEncoding captions (batch_size={batch_size})...")

    image_ids = list(image_to_captions.keys())

    with torch.no_grad():
        for img_id in tqdm(image_ids, desc=f"Encoding {split} captions"):
            captions = image_to_captions[img_id]

            # Start a new batch when this image would push the queue over the limit
            if pending_captions and len(pending_captions) + len(captions) > batch_size:
                flush_pending()

            pending_ids.append(img_id)
            pending_captions.extend(captions)

        # Encode whatever is left in the final, possibly smaller, batch
        flush_pending()

    # Save cache to disk
    cache_file = DATASET_DIR / f'{split}_text_embeddings.pt'
    print(f"\nSaving cache to {cache_file}...")

    torch.save({
        'data': cache_data,
        'model_name': MODEL_NAME,
        'embedding_dim': text_encoder.config.hidden_size
    }, cache_file)

    # Print statistics
    cache_size_mb = cache_file.stat().st_size / (1024 * 1024)
    total_captions = sum(len(item['captions']) for item in cache_data)
    # Guard against an annotation file without any captions
    avg_captions = total_captions / len(cache_data) if cache_data else 0.0

    print(f"\n{'='*60}")
    print(f"Cache created successfully!")
    print(f"{'='*60}")
    print(f"  File: {cache_file.name}")
    print(f"  Size: {cache_size_mb:.2f} MB")
    print(f"  Images: {len(cache_data):,}")
    print(f"  Total captions: {total_captions:,}")
    print(f"  Avg captions/image: {avg_captions:.2f}")
    print(f"  Embedding dimension: {text_encoder.config.hidden_size}")
    print(f"{'='*60}\n")

    return cache_file

# Build the text-embedding cache for the training split
train_cache_file = encode_and_cache_captions('train', batch_size=64)

## 8. Encode and Cache Validation Captions

In [None]:
# Build the text-embedding cache for the validation split
val_cache_file = encode_and_cache_captions('val', batch_size=64)

## 9. Define PyTorch Dataset Class

Custom Dataset class that:
- Loads images on-the-fly (memory efficient)
- Retrieves pre-computed text embeddings from cache
- Handles multiple captions per image

In [None]:
class COCOClipDataset(Dataset):
    """
    COCO Dataset for CLIP fine-tuning.

    Images are read from ``DATASET_DIR/{split}2014`` on access; text embeddings
    and captions come from the ``{split}_text_embeddings.pt`` cache file.

    Each item is a dict. With ``return_all_captions=False`` (default):
        image: Output of ``transform`` for the image
            ([3, 224, 224] with the default ``image_transforms``)
        text_embedding: Cached embedding of one randomly chosen caption
        caption: The caption text that embedding belongs to
        image_id: COCO image ID

    With ``return_all_captions=True`` the keys are ``image``,
    ``text_embeddings`` ([num_captions, embedding_dim]), ``captions`` (list)
    and ``image_id``.

    An image that cannot be loaded is reported on stdout and replaced by an
    all-zero [3, 224, 224] tensor instead of raising.
    """

    def __init__(self, split='train', transform=None, return_all_captions=False):
        """
        Args:
            split: 'train' or 'val'
            transform: Image transforms to apply (defaults to ``image_transforms``)
            return_all_captions: If True, return all captions. If False, randomly select one.
        """
        self.split = split
        self.transform = transform or image_transforms
        self.return_all_captions = return_all_captions

        # Set paths
        self.image_dir = DATASET_DIR / f'{split}2014'
        self.cache_file = DATASET_DIR / f'{split}_text_embeddings.pt'

        # Load cached embeddings. The cache holds only tensors and plain Python
        # containers, so the restricted weights_only loader is sufficient and
        # avoids unpickling arbitrary objects. map_location keeps the
        # embeddings on the CPU regardless of where they were saved from.
        print(f"Loading cached embeddings from {self.cache_file.name}...")
        cache = torch.load(self.cache_file, map_location='cpu', weights_only=True)
        self.cache_data = cache['data']
        self.embedding_dim = cache['embedding_dim']

        # Build index: image_id -> cache index
        self.image_id_to_idx = {
            item['image_id']: idx
            for idx, item in enumerate(self.cache_data)
        }

        print(f"  ✓ Loaded {len(self.cache_data)} images")
        print(f"  ✓ Embedding dimension: {self.embedding_dim}")

    def __len__(self):
        """Number of images (one dataset item per image, not per caption)."""
        return len(self.cache_data)

    def __getitem__(self, idx):
        """Return the image and caption data for cache entry ``idx`` as a dict."""
        # Get cached data
        item = self.cache_data[idx]
        image_id = item['image_id']
        embeddings = item['embeddings']  # [num_captions, embedding_dim]
        captions = item['captions']

        # COCO 2014 file names embed the split and the zero-padded image id
        image_filename = f'COCO_{self.split}2014_{image_id:012d}.jpg'
        image_path = self.image_dir / image_filename

        try:
            # The context manager closes the file handle once the pixels have
            # been converted, instead of leaving that to garbage collection.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
            image = self.transform(image)
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            # Best-effort fallback: a black image keeps the batch going.
            # The size is fixed at 224x224 and does not follow a custom transform.
            image = torch.zeros(3, 224, 224)

        # Select caption(s)
        if self.return_all_captions:
            # Return all captions and embeddings
            return {
                'image': image,
                'text_embeddings': embeddings,
                'captions': captions,
                'image_id': image_id
            }
        else:
            # Randomly select one caption
            caption_idx = random.randint(0, len(captions) - 1)
            return {
                'image': image,
                'text_embedding': embeddings[caption_idx],
                'caption': captions[caption_idx],
                'image_id': image_id
            }

print("✓ COCOClipDataset class defined successfully!")

## 10. Create Dataset Instances

In [None]:
# Instantiate one dataset per split, sharing the same image transforms
print("Creating dataset instances...\n")

train_dataset = COCOClipDataset(split='train', transform=image_transforms)
print()
val_dataset = COCOClipDataset(split='val', transform=image_transforms)

# Size overview for both splits
print(
    f"\n{'='*60}",
    "Dataset Summary",
    f"{'='*60}",
    f"Training set size: {len(train_dataset):,} images",
    f"Validation set size: {len(val_dataset):,} images",
    f"Total: {len(train_dataset) + len(val_dataset):,} images",
    f"{'='*60}",
    sep="\n",
)

# Smoke test: fetch the first training item and show what it contains
print("\nTesting dataset loading...")
sample = train_dataset[0]
print(
    f"  ✓ Image shape: {sample['image'].shape}",
    f"  ✓ Text embedding shape: {sample['text_embedding'].shape}",
    f"  ✓ Caption: \"{sample['caption']}\"",
    f"  ✓ Image ID: {sample['image_id']}",
    sep="\n",
)

## 11. Visualize Random Image-Caption Pairs

Verify that images and captions are correctly loaded and preprocessed.

In [None]:
def visualize_samples(dataset, num_samples=6, figsize=(15, 10)):
    """
    Visualize random samples from the dataset.

    Draws up to three images per row, each titled with its image ID and one
    caption, then prints the tensor shapes of the last sample shown.

    Args:
        dataset: COCOClipDataset instance
        num_samples: Number of samples to display (capped at ``len(dataset)``)
        figsize: Figure size

    Raises:
        ValueError: If there is nothing to display (``num_samples`` < 1 or an
            empty dataset).
    """
    # Never request more samples than the dataset holds
    num_samples = min(num_samples, len(dataset))
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1 and the dataset must not be empty")

    # Select random indices
    indices = random.sample(range(len(dataset)), num_samples)

    # Create subplot grid
    rows = (num_samples + 2) // 3
    cols = min(3, num_samples)

    # squeeze=False always returns a 2-D array of Axes, so a single subplot
    # needs no special case
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    for idx, ax in zip(indices, axes):
        # Get sample
        sample = dataset[idx]
        image = sample['image']
        image_id = sample['image_id']
        # Datasets created with return_all_captions=True expose 'captions'
        # (a list) instead of 'caption'; show the first one in that case
        caption = sample['caption'] if 'caption' in sample else sample['captions'][0]

        # Denormalize image for display
        image_display = denormalize(image)
        image_display = torch.clamp(image_display, 0, 1)
        image_display = image_display.permute(1, 2, 0).numpy()

        # Display image
        ax.imshow(image_display)
        ax.axis('off')

        # Add caption as title, wrapped at word boundaries
        wrapped_caption = textwrap.fill(caption, width=40)
        ax.set_title(f"ID: {image_id}\n{wrapped_caption}",
                     fontsize=9, pad=10)

    # Hide extra subplots
    for ax in axes[num_samples:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    # Shapes are reported for the last sample drawn
    embedding = sample['text_embedding'] if 'text_embedding' in sample else sample['text_embeddings']
    print(f"\nDisplayed {num_samples} random samples")
    print(f"  Image tensor shape: {sample['image'].shape}")
    print(f"  Text embedding shape: {embedding.shape}")

# Show six random image-caption pairs from the training split
print("Training Set Samples:")
print("=" * 60)
visualize_samples(train_dataset, num_samples=6)

# Same check for the validation split, set apart by a blank line and a rule
print(f"\n{'=' * 60}")
print("Validation Set Samples:")
print("=" * 60)
visualize_samples(val_dataset, num_samples=6)

## 12. Dataset Statistics and Sanity Checks

In [None]:
def dataset_statistics(dataset, name='Dataset', num_check=100):
    """
    Compute and display dataset statistics.

    Draws a random subset of samples, checks each one (image shape, embedding
    shape, non-empty caption, image actually loaded) and prints a summary.
    Failures are reported per index and do not abort the run.

    Args:
        dataset: COCOClipDataset instance (with return_all_captions=False)
        name: Dataset name for display
        num_check: Number of samples to check for integrity (capped at
            ``len(dataset)``)
    """
    print(f"\n{'='*60}")
    print(f"{name} Statistics")
    print(f"{'='*60}")

    # Basic stats
    print(f"Total samples: {len(dataset):,}")

    # Check a subset for integrity; the dataset may be smaller than num_check
    indices = random.sample(range(len(dataset)), min(num_check, len(dataset)))
    num_checked = len(indices)
    print(f"\nChecking {num_checked} random samples for integrity...")

    # Use the embedding size recorded in the cache when the dataset exposes it
    expected_emb_shape = (getattr(dataset, 'embedding_dim', 512),)

    valid_count = 0
    image_shapes = []
    embedding_shapes = []
    caption_lengths = []

    for idx in tqdm(indices, desc="Validating"):
        try:
            sample = dataset[idx]

            # Check shapes
            img_shape = sample['image'].shape
            emb_shape = sample['text_embedding'].shape
            cap_len = len(sample['caption'])

            # Check for expected shapes
            assert img_shape == (3, 224, 224), f"Invalid image shape: {img_shape}"
            assert emb_shape == expected_emb_shape, f"Invalid embedding shape: {emb_shape}"
            assert cap_len > 0, "Empty caption"
            # The dataset substitutes an all-zero tensor when an image cannot
            # be loaded; count that as a failure rather than a valid sample.
            assert bool(sample['image'].any()), "Image failed to load (all-zero fallback tensor)"

            # Only samples that passed every check feed the statistics below
            image_shapes.append(img_shape)
            embedding_shapes.append(emb_shape)
            caption_lengths.append(cap_len)

            valid_count += 1

        except Exception as e:
            print(f"  ✗ Error at index {idx}: {e}")

    # Rate is relative to the samples actually checked; avoid dividing by zero
    success_rate = 100 * valid_count / num_checked if num_checked else 0.0

    print(f"\nIntegrity Check Results:")
    print(f"  Valid samples: {valid_count}/{num_checked}")
    print(f"  Success rate: {success_rate:.2f}%")

    if valid_count > 0:
        print(f"\nShape Statistics:")
        print(f"  Image shape: {image_shapes[0]} (all samples)")
        print(f"  Text embedding shape: {embedding_shapes[0]} (all samples)")
        print(f"  Caption length range: {min(caption_lengths)}-{max(caption_lengths)} chars")
        print(f"  Average caption length: {sum(caption_lengths)/len(caption_lengths):.1f} chars")

    print(f"{'='*60}")

# Integrity report for each split
dataset_statistics(train_dataset, name='Training Set', num_check=100)
dataset_statistics(val_dataset, name='Validation Set', num_check=100)

# Closing banner listing the completed steps and where to go next
print(
    f"\n{'='*60}",
    "Dataset Preparation Complete!",
    f"{'='*60}",
    "✓ Images downloaded and organized",
    "✓ Text embeddings cached",
    "✓ PyTorch datasets created",
    "✓ Data integrity verified",
    "\nNext step: Run train_clip.ipynb to train the model!",
    f"{'='*60}",
    sep="\n",
)