In [None]:
!pip install ndlinear

In [9]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
from transformers import ViTModel, ViTConfig
from transformers import DistilBertTokenizer, DistilBertModel, DistilBertConfig
from PIL import Image
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import time
from tqdm import tqdm
import torchvision.transforms as T
import glob
from ndlinear import NdLinear

In [10]:
class ImageTextDataset(Dataset):
    def __init__(self, data_dir, captions_file, tokenizer, transform, max_length=64):
        self.data_dir = data_dir
        self.tokenizer = tokenizer
        self.transform = transform
        self.max_length = max_length
        self.data = self._load_captions(captions_file)
        print(f"Loaded {len(self.data)} image-caption pairs")

    def _load_captions(self, filepath):
        data = []
        with open(filepath, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                
                # Split only on first comma to separate image name and caption
                parts = line.split(',', 1)
                if len(parts) == 2:
                    img_name, caption = parts[0].strip(), parts[1].strip()
                    
                    # Remove quotes if present
                    if caption.startswith('"') and caption.endswith('"'):
                        caption = caption[1:-1]
                    
                    data.append((img_name, caption))
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name, caption = self.data[idx]
        img_path = os.path.join(self.data_dir, img_name)
        
        # Handle image loading errors gracefully
        try:
            image = Image.open(img_path).convert("RGB")
            image = self.transform(image)
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a placeholder image (black)
            image = torch.zeros(3, 224, 224)
        
        # Encode the caption
        encoded = self.tokenizer(
            caption, 
            padding='max_length', 
            truncation=True, 
            max_length=self.max_length, 
            return_tensors="pt"
        )
        
        return {
            "image": image,
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "caption": caption,  # Store original caption for evaluation
            "image_name": img_name  # Store image name for evaluation
        }

In [11]:
class TextEncoder(nn.Module):
    def __init__(self, embed_dim, proj_dim):
        super().__init__()
        # self.model = DistilBertModel.from_pretrained('distilbert-base-uncased')
        # self.projection = nn.Linear(embed_dim, proj_dim)
        # self.layer_norm = nn.LayerNorm(proj_dim)
        self.model = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.projection = NdLinear(input_dims=(embed_dim,), hidden_size=(proj_dim,))
        self.layer_norm = nn.LayerNorm(proj_dim)

    def forward(self, input_ids, attention_mask):
        x = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        x = x[:, 0, :]  # B, T[cls], E
        x = self.projection(x)
        return self.layer_norm(x)

In [12]:
class ImageEncoder(nn.Module):
    def __init__(self, model_name='google/vit-base-patch16-224', proj_dim=256):
        super().__init__()
        self.model = ViTModel.from_pretrained(model_name)
        embed_dim = self.model.config.hidden_size

        for param in self.model.parameters():
            param.requires_grad = True

        # self.projection = nn.Linear(embed_dim, proj_dim)
        # self.layer_norm = nn.LayerNorm(proj_dim)
        self.projection = NdLinear(input_dims=(embed_dim,), hidden_size=(proj_dim,))
        self.layer_norm = nn.LayerNorm(proj_dim)

    def forward(self, x):
        outputs = self.model(pixel_values=x)
        cls_token = outputs.last_hidden_state[:, 0]  # CLS token
        x = self.projection(cls_token)
        return self.layer_norm(x)

In [13]:
class CLIPModel(nn.Module):
    def __init__(self, model_name='google/vit-base-patch16-224', proj_dim=256):
        super().__init__()
        self.image_encoder = ImageEncoder(model_name, proj_dim)
        self.text_encoder = TextEncoder(embed_dim=768, proj_dim=proj_dim)
        # Initialize temperature parameter (learnable)
        self.temperature = nn.Parameter(torch.ones([]) * np.log(1/0.07))

    def forward(self, batch, device):
        # Move inputs to device
        images = batch["image"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        
        # Get embeddings
        image_features = self.image_encoder(images)
        text_features = self.text_encoder(input_ids, attention_mask)
        
        # Normalize embeddings
        image_features = F.normalize(image_features, dim=1)
        text_features = F.normalize(text_features, dim=1)

        # Scaled pairwise cosine similarities [n, n]
        logits = torch.matmul(image_features, text_features.T) * torch.exp(self.temperature)
        
        # Contrastive loss
        labels = torch.arange(image_features.size(0)).to(device)
        loss_i2t = F.cross_entropy(logits, labels)
        loss_t2i = F.cross_entropy(logits.T, labels)
        loss = (loss_i2t + loss_t2i) / 2

        return loss, logits, image_features, text_features

    def encode_image(self, image, device):
        image = image.to(device)
        features = self.image_encoder(image)
        return F.normalize(features, dim=1)
    
    def encode_text(self, input_ids, attention_mask, device):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        features = self.text_encoder(input_ids, attention_mask)
        return F.normalize(features, dim=1)

In [14]:
def train_epoch(model, dataloader, optimizer, device, epoch):
    model.train()
    total_loss = 0
    start_time = time.time()
    
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}")
    for batch in progress_bar:
        loss, _, _, _ = model(batch, device)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        progress_bar.set_postfix({"loss": loss.item()})
    
    avg_loss = total_loss / len(dataloader)
    elapsed = time.time() - start_time
    
    print(f"Epoch {epoch+1} completed in {elapsed:.2f}s - Avg Loss: {avg_loss:.4f}")
    return avg_loss

In [15]:
def evaluate(model, dataloader, device):
    model.eval()
    i2t_correct_1 = 0
    i2t_correct_5 = 0
    t2i_correct_1 = 0
    t2i_correct_5 = 0
    total = 0
    
    all_image_features = []
    all_text_features = []
    all_captions = []
    all_image_names = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            loss, logits, image_features, text_features = model(batch, device)
            
            # Store features and metadata for analysis
            all_image_features.append(image_features)
            all_text_features.append(text_features)
            all_captions.extend(batch["caption"])
            all_image_names.extend(batch["image_name"])
            
            # Image-to-text retrieval (find the right caption for each image)
            i2t_similarity = logits
            i2t_sorted_indices = i2t_similarity.argsort(dim=1, descending=True)
            
            # Text-to-image retrieval (find the right image for each caption)
            t2i_similarity = logits.T
            t2i_sorted_indices = t2i_similarity.argsort(dim=1, descending=True)
            
            # Calculate batch accuracy
            batch_size = image_features.size(0)
            targets = torch.arange(batch_size).to(device)
            
            # Top-1 and Top-5 accuracy for image-to-text
            i2t_correct_1 += (i2t_sorted_indices[:, 0] == targets).sum().item()
            for k in range(min(5, batch_size)):
                i2t_correct_5 += (i2t_sorted_indices[:, k] == targets).sum().item()
            
            # Top-1 and Top-5 accuracy for text-to-image
            t2i_correct_1 += (t2i_sorted_indices[:, 0] == targets).sum().item()
            for k in range(min(5, batch_size)):
                t2i_correct_5 += (t2i_sorted_indices[:, k] == targets).sum().item()
            
            total += batch_size
    
    # Combine features for full dataset analysis
    all_image_features = torch.cat(all_image_features, dim=0)
    all_text_features = torch.cat(all_text_features, dim=0)
    
    # Calculate metrics
    i2t_top1_acc = i2t_correct_1 / total
    i2t_top5_acc = i2t_correct_5 / (total * min(5, total))
    t2i_top1_acc = t2i_correct_1 / total
    t2i_top5_acc = t2i_correct_5 / (total * min(5, total))
    
    results = {
        "i2t_top1": i2t_top1_acc,
        "i2t_top5": i2t_top5_acc,
        "t2i_top1": t2i_top1_acc,
        "t2i_top5": t2i_top5_acc,
        "image_features": all_image_features.cpu(),
        "text_features": all_text_features.cpu(),
        "captions": all_captions,
        "image_names": all_image_names
    }
    
    return results

In [16]:
def zero_shot_prediction(model, image_path, text_candidates, tokenizer, transform, device):
    """
    Perform zero-shot prediction using the trained CLIP model.
    
    Args:
        model: Trained CLIP model
        image_path: Path to the query image
        text_candidates: List of textual descriptions to match against
        tokenizer: Text tokenizer
        transform: Image transformation pipeline
        device: Device to run inference on
    
    Returns:
        List of (text, similarity score) pairs sorted by score
    """
    model.eval()
    
    # Load and process image
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0)  # Add batch dimension
    
    # Process text candidates
    text_tokens = tokenizer(
        text_candidates,
        padding='max_length',
        truncation=True,
        max_length=64,
        return_tensors="pt"
    )
    
    with torch.no_grad():
        # Encode image
        image_features = model.encode_image(image, device)
        
        # Encode text candidates
        text_features = model.encode_text(
            text_tokens["input_ids"], 
            text_tokens["attention_mask"],
            device
        )
        
        # Calculate similarities
        similarities = (image_features @ text_features.T).squeeze(0)
        
    # Sort by similarity
    similarities = similarities.cpu().numpy()
    sorted_indices = similarities.argsort()[::-1]
    
    # Return sorted text-similarity pairs
    results = [(text_candidates[i], similarities[i]) for i in sorted_indices]
    return results

In [17]:
def train_and_evaluate(model, train_loader, val_loader, device, save_dir, epochs=20, save_interval=5):
    """
    Train and evaluate the CLIP model
    """
    os.makedirs(save_dir, exist_ok=True)
    
    # Initialize optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    
    # Track metrics
    train_losses = []
    val_metrics = []
    
    for epoch in range(epochs):
        # Train for one epoch
        avg_loss = train_epoch(model, train_loader, optimizer, device, epoch)
        train_losses.append(avg_loss)
        
        # Evaluate on validation set
        val_results = evaluate(model, val_loader, device)
        val_metrics.append({
            "epoch": epoch + 1,
            "i2t_top1": val_results["i2t_top1"],
            "i2t_top5": val_results["i2t_top5"],
            "t2i_top1": val_results["t2i_top1"],
            "t2i_top5": val_results["t2i_top5"]
        })
        
        print(f"Validation metrics:")
        print(f"  Image-to-Text: Top-1: {val_results['i2t_top1']:.4f}, Top-5: {val_results['i2t_top5']:.4f}")
        print(f"  Text-to-Image: Top-1: {val_results['t2i_top1']:.4f}, Top-5: {val_results['t2i_top5']:.4f}")
        
        # Save checkpoint at intervals
        if (epoch + 1) % save_interval == 0:
            checkpoint_path = os.path.join(save_dir, f"clip_model_epoch_{epoch+1}.pt")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
                'val_metrics': val_metrics[-1]
            }, checkpoint_path)
            print(f"Model checkpoint saved to {checkpoint_path}")
    
    # Save final model
    final_path = os.path.join(save_dir, "clip_model_final.pt")
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
        'val_metrics': val_metrics[-1]
    }, final_path)
    print(f"Final model saved to {final_path}")
    
    return {
        "train_losses": train_losses,
        "val_metrics": val_metrics
    }

In [18]:
def visualize_results(train_results, save_dir):
    """
    Visualize training progress and metrics
    """
    os.makedirs(save_dir, exist_ok=True)
    
    # Plot training loss
    plt.figure(figsize=(10, 5))
    plt.plot(train_results["train_losses"])
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)
    plt.savefig(os.path.join(save_dir, 'training_loss.png'))
    
    # Plot validation metrics
    plt.figure(figsize=(12, 6))
    epochs = [m["epoch"] for m in train_results["val_metrics"]]
    plt.plot(epochs, [m["i2t_top1"] for m in train_results["val_metrics"]], label='I2T Top-1')
    plt.plot(epochs, [m["i2t_top5"] for m in train_results["val_metrics"]], label='I2T Top-5')
    plt.plot(epochs, [m["t2i_top1"] for m in train_results["val_metrics"]], label='T2I Top-1')
    plt.plot(epochs, [m["t2i_top5"] for m in train_results["val_metrics"]], label='T2I Top-5')
    plt.title('Validation Metrics')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(save_dir, 'validation_metrics.png'))
    
    plt.close('all')

In [19]:
def main():
    
    # Configuration
    data_dir = "/home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images"  # Update with your path
    captions_file = "/home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/captions.txt"  # Update with your path
    save_dir = "./model_checkpoints"
    results_dir = "./results"
    
    # Model and training parameters
    proj_dim = 256
    batch_size = 32
    num_epochs = 15
    save_interval = 5
    random_seed = 42
    
    # Set random seeds for reproducibility
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    
    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Initialize tokenizer and transforms
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    
    # Load dataset
    print("Loading dataset...")
    dataset = ImageTextDataset(
        data_dir=data_dir,
        captions_file=captions_file,
        tokenizer=tokenizer,
        transform=transform
    )
    
    # Split dataset into train and validation sets
    train_indices, val_indices = train_test_split(
        list(range(len(dataset))), 
        test_size=0.1,
        random_state=random_seed
    )
    
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)
    
    print(f"Training set size: {len(train_dataset)}")
    print(f"Validation set size: {len(val_dataset)}")
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True,
        num_workers=1,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset, 
        batch_size=batch_size, 
        shuffle=False,
        num_workers=1,
        pin_memory=True
    )
    
    # Initialize model
    print("Initializing CLIP model...")
    model = CLIPModel(
        model_name='google/vit-base-patch16-224',
        proj_dim=proj_dim
    ).to(device)
    
    # Train and evaluate model
    print("Starting training...")
    train_results = train_and_evaluate(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        save_dir=save_dir,
        epochs=num_epochs,
        save_interval=save_interval
    )
    
    # Visualize results
    print("Visualizing results...")
    visualize_results(train_results, results_dir)
    
    # Perform zero-shot prediction example
    print("\nPerforming zero-shot prediction example:")
    # Use the first image from validation set as example
    example_batch = next(iter(val_loader))
    example_img_name = example_batch["image_name"][0]
    example_img_path = os.path.join(data_dir, example_img_name)
    
    # Create some sample text candidates (including the actual caption)
    actual_caption = example_batch["caption"][0]
    text_candidates = [
        actual_caption,
        "A dog running on the beach",
        "A cat sitting on a window sill",
        "A person hiking in the mountains",
        "Children playing in a park"
    ]
    
    print(f"Query image: {example_img_name}")
    print(f"Actual caption: {actual_caption}")
    
    results = zero_shot_prediction(
        model=model,
        image_path=example_img_path,
        text_candidates=text_candidates,
        tokenizer=tokenizer,
        transform=transform,
        device=device
    )
    
    print("Zero-shot predictions (sorted by similarity):")
    for text, score in results:
        print(f"  Score: {score:.4f} - Text: {text}")
    
    print("\nTraining and evaluation completed!")


In [19]:
if __name__ == "__main__":
    main()

Using device: cuda
Loading dataset...
Loaded 40456 image-caption pairs
Training set size: 36410
Validation set size: 4046
Initializing CLIP model...


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Starting training...


Epoch 1:  27%|██▋       | 306/1138 [02:01<05:55,  2.34it/s, loss=0.639]

Error loading image /home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image: [Errno 2] No such file or directory: '/home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image'


Epoch 1: 100%|██████████| 1138/1138 [07:41<00:00,  2.46it/s, loss=0.134]


Epoch 1 completed in 461.94s - Avg Loss: 0.6242


Evaluating: 100%|██████████| 127/127 [00:36<00:00,  3.50it/s]


Validation metrics:
  Image-to-Text: Top-1: 0.8997, Top-5: 0.1991
  Text-to-Image: Top-1: 0.8982, Top-5: 0.1991


Epoch 2:  70%|███████   | 802/1138 [05:09<02:06,  2.65it/s, loss=0.247] 

Error loading image /home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image: [Errno 2] No such file or directory: '/home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image'


Epoch 2: 100%|██████████| 1138/1138 [07:20<00:00,  2.59it/s, loss=0.119] 


Epoch 2 completed in 440.02s - Avg Loss: 0.2147


Evaluating: 100%|██████████| 127/127 [00:36<00:00,  3.46it/s]


Validation metrics:
  Image-to-Text: Top-1: 0.9100, Top-5: 0.1991
  Text-to-Image: Top-1: 0.9021, Top-5: 0.1991


Epoch 3:  82%|████████▏ | 931/1138 [06:00<01:23,  2.49it/s, loss=0.0457]

Error loading image /home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image: [Errno 2] No such file or directory: '/home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image'


Epoch 3: 100%|██████████| 1138/1138 [07:20<00:00,  2.58it/s, loss=0.205] 


Epoch 3 completed in 440.33s - Avg Loss: 0.1659


Evaluating: 100%|██████████| 127/127 [00:36<00:00,  3.45it/s]


Validation metrics:
  Image-to-Text: Top-1: 0.9142, Top-5: 0.1992
  Text-to-Image: Top-1: 0.9071, Top-5: 0.1989


Epoch 4:  39%|███▊      | 439/1138 [02:50<04:24,  2.64it/s, loss=0.0919]

Error loading image /home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image: [Errno 2] No such file or directory: '/home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image'


Epoch 4: 100%|██████████| 1138/1138 [07:20<00:00,  2.58it/s, loss=0.198] 


Epoch 4 completed in 440.42s - Avg Loss: 0.1386


Evaluating: 100%|██████████| 127/127 [00:36<00:00,  3.46it/s]


Validation metrics:
  Image-to-Text: Top-1: 0.9135, Top-5: 0.1990
  Text-to-Image: Top-1: 0.9093, Top-5: 0.1990


Epoch 5:   8%|▊         | 92/1138 [00:35<06:44,  2.59it/s, loss=0.182] 

Error loading image /home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image: [Errno 2] No such file or directory: '/home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image'

Epoch 5:   8%|▊         | 93/1138 [00:36<06:39,  2.62it/s, loss=0.105]




Epoch 5: 100%|██████████| 1138/1138 [07:20<00:00,  2.58it/s, loss=0.12]  


Epoch 5 completed in 440.40s - Avg Loss: 0.1357


Evaluating: 100%|██████████| 127/127 [00:36<00:00,  3.45it/s]


Validation metrics:
  Image-to-Text: Top-1: 0.9152, Top-5: 0.1993
  Text-to-Image: Top-1: 0.9105, Top-5: 0.1991
Model checkpoint saved to ./model_checkpoints/clip_model_epoch_5.pt


Epoch 6:  88%|████████▊ | 1002/1138 [06:28<00:55,  2.46it/s, loss=0.187] 

Error loading image /home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image: [Errno 2] No such file or directory: '/home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image'


Epoch 6: 100%|██████████| 1138/1138 [07:20<00:00,  2.58it/s, loss=0.0977]


Epoch 6 completed in 440.82s - Avg Loss: 0.1157


Evaluating: 100%|██████████| 127/127 [00:36<00:00,  3.45it/s]


Validation metrics:
  Image-to-Text: Top-1: 0.9187, Top-5: 0.1989
  Text-to-Image: Top-1: 0.9197, Top-5: 0.1991


Epoch 7:   7%|▋         | 77/1138 [00:29<06:44,  2.62it/s, loss=0.0937] 

Error loading image /home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image: [Errno 2] No such file or directory: '/home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image'


Epoch 7: 100%|██████████| 1138/1138 [07:21<00:00,  2.57it/s, loss=0.233]  


Epoch 7 completed in 441.98s - Avg Loss: 0.1057


Evaluating: 100%|██████████| 127/127 [00:36<00:00,  3.46it/s]


Validation metrics:
  Image-to-Text: Top-1: 0.9086, Top-5: 0.1990
  Text-to-Image: Top-1: 0.9152, Top-5: 0.1989


Epoch 8:   0%|          | 0/1138 [00:00<?, ?it/s]

Error loading image /home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image: [Errno 2] No such file or directory: '/home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image'


Epoch 8: 100%|██████████| 1138/1138 [07:21<00:00,  2.58it/s, loss=0.043] 


Epoch 8 completed in 441.73s - Avg Loss: 0.1062


Evaluating: 100%|██████████| 127/127 [00:36<00:00,  3.43it/s]


Validation metrics:
  Image-to-Text: Top-1: 0.9187, Top-5: 0.1990
  Text-to-Image: Top-1: 0.9194, Top-5: 0.1990


Epoch 9:  19%|█▊        | 211/1138 [01:21<05:52,  2.63it/s, loss=0.0695]

Error loading image /home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image: [Errno 2] No such file or directory: '/home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image'


Epoch 9: 100%|██████████| 1138/1138 [07:21<00:00,  2.58it/s, loss=0.0709]


Epoch 9 completed in 441.81s - Avg Loss: 0.0974


Evaluating: 100%|██████████| 127/127 [00:36<00:00,  3.44it/s]


Validation metrics:
  Image-to-Text: Top-1: 0.9147, Top-5: 0.1992
  Text-to-Image: Top-1: 0.9182, Top-5: 0.1990


Epoch 10:  79%|███████▊  | 895/1138 [05:47<01:33,  2.60it/s, loss=0.115]  

Error loading image /home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image: [Errno 2] No such file or directory: '/home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image'


Epoch 10: 100%|██████████| 1138/1138 [07:21<00:00,  2.57it/s, loss=0.0242] 


Epoch 10 completed in 441.97s - Avg Loss: 0.0974


Evaluating: 100%|██████████| 127/127 [00:36<00:00,  3.44it/s]


Validation metrics:
  Image-to-Text: Top-1: 0.9207, Top-5: 0.1994
  Text-to-Image: Top-1: 0.9207, Top-5: 0.1992
Model checkpoint saved to ./model_checkpoints/clip_model_epoch_10.pt


Epoch 11:  31%|███       | 352/1138 [02:16<05:01,  2.60it/s, loss=0.0462] 

Error loading image /home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image: [Errno 2] No such file or directory: '/home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image'


Epoch 11: 100%|██████████| 1138/1138 [07:21<00:00,  2.58it/s, loss=0.0844] 


Epoch 11 completed in 441.47s - Avg Loss: 0.0846


Evaluating: 100%|██████████| 127/127 [00:36<00:00,  3.45it/s]


Validation metrics:
  Image-to-Text: Top-1: 0.9219, Top-5: 0.1994
  Text-to-Image: Top-1: 0.9197, Top-5: 0.1991


Epoch 12:  32%|███▏      | 361/1138 [02:19<04:57,  2.61it/s, loss=0.189]  

Error loading image /home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image: [Errno 2] No such file or directory: '/home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image'


Epoch 12: 100%|██████████| 1138/1138 [07:20<00:00,  2.58it/s, loss=0.026]  


Epoch 12 completed in 440.70s - Avg Loss: 0.0832


Evaluating: 100%|██████████| 127/127 [00:36<00:00,  3.46it/s]


Validation metrics:
  Image-to-Text: Top-1: 0.9145, Top-5: 0.1989
  Text-to-Image: Top-1: 0.9152, Top-5: 0.1990


Epoch 13:  14%|█▍        | 160/1138 [01:01<06:09,  2.65it/s, loss=0.102]  

Error loading image /home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image: [Errno 2] No such file or directory: '/home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image'


Epoch 13: 100%|██████████| 1138/1138 [07:21<00:00,  2.58it/s, loss=0.0588] 


Epoch 13 completed in 441.87s - Avg Loss: 0.0843


Evaluating: 100%|██████████| 127/127 [00:37<00:00,  3.36it/s]


Validation metrics:
  Image-to-Text: Top-1: 0.9184, Top-5: 0.1992
  Text-to-Image: Top-1: 0.9286, Top-5: 0.1992


Epoch 14:  36%|███▌      | 409/1138 [02:38<04:56,  2.46it/s, loss=0.0549] 

Error loading image /home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image: [Errno 2] No such file or directory: '/home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image'


Epoch 14: 100%|██████████| 1138/1138 [07:22<00:00,  2.57it/s, loss=0.103]  


Epoch 14 completed in 442.19s - Avg Loss: 0.0799


Evaluating: 100%|██████████| 127/127 [00:38<00:00,  3.33it/s]


Validation metrics:
  Image-to-Text: Top-1: 0.9204, Top-5: 0.1990
  Text-to-Image: Top-1: 0.9157, Top-5: 0.1990


Epoch 15:  84%|████████▎ | 951/1138 [06:09<01:11,  2.60it/s, loss=0.0536] 

Error loading image /home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image: [Errno 2] No such file or directory: '/home/b.gandhi/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images/image'


Epoch 15: 100%|██████████| 1138/1138 [07:21<00:00,  2.58it/s, loss=0.02]   


Epoch 15 completed in 441.51s - Avg Loss: 0.0748


Evaluating: 100%|██████████| 127/127 [00:37<00:00,  3.35it/s]


Validation metrics:
  Image-to-Text: Top-1: 0.9236, Top-5: 0.1990
  Text-to-Image: Top-1: 0.9256, Top-5: 0.1989
Model checkpoint saved to ./model_checkpoints/clip_model_epoch_15.pt
Final model saved to ./model_checkpoints/clip_model_final.pt
Visualizing results...

Performing zero-shot prediction example:
Query image: 3606093421_eddd46c2c7.jpg
Actual caption: two men in an orange raft boat
Zero-shot predictions (sorted by similarity):
  Score: 0.6272 - Text: two men in an orange raft boat
  Score: 0.0257 - Text: A dog running on the beach
  Score: 0.0018 - Text: A person hiking in the mountains
  Score: -0.0535 - Text: A cat sitting on a window sill
  Score: -0.3343 - Text: Children playing in a park

Training and evaluation completed!


In [20]:
import random
def test_random_image(checkpoint_path, image_dir):
    """
    Test the CLIP model on a random image from the specified directory
    """
    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Initialize model and load checkpoint
    model = CLIPModel(proj_dim=256).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f"Model loaded from {checkpoint_path}")
    
    # Initialize tokenizer and image transform
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    # Define some test classes for zero-shot classification
    classes = [
        "dog", "cat", "bird", "person", "car", "mountain", "beach", 
        "building", "flower", "tree", "food", "sunset", "garden", "trees", "city"
    ]

    # Define some test captions
    captions = [
        "A dog running in a field",
        "A cat sleeping on a couch",
        "A person hiking in the mountains",
        "A beautiful sunset over the ocean",
        "Children playing in a park",
        "A car driving on a highway",
        "A building in a city skyline",
        "Flowers in a garden",
        "A plate of delicious food",
        "A dog sitting on the beach",
        "A view of garden and skyscrapers"
    ]
    image_paths = glob.glob(image_dir + "*.jpeg")
    for img_path in image_paths:
        image = Image.open(img_path).convert("RGB")
        image_tensor = transform(image).unsqueeze(0)

        
        # Test zero-shot classification
        print("\nPerforming zero-shot classification...")
        with torch.no_grad():
            # Encode image
            image_features = model.encode_image(image_tensor, device)

            # Create text prompts using a template
            text_prompts = [f"a photo of a {cls}" for cls in classes]

            # Tokenize text
            text_tokens = tokenizer(
                text_prompts,
                padding='max_length',
                truncation=True,
                max_length=64,
                return_tensors="pt"
            )

            # Encode text
            text_features = model.encode_text(
                text_tokens["input_ids"], 
                text_tokens["attention_mask"],
                device
            )

            # Calculate similarities
            similarities = (image_features @ text_features.T).squeeze(0).cpu().numpy()

            # Sort by similarity
            sorted_indices = similarities.argsort()[::-1]

            # Print results
            print("Top 5 classifications:")
            for i in sorted_indices[:5]:
                print(f"  {classes[i]}: {similarities[i]:.4f}")

        # Test caption matching
        print("\nPerforming caption matching...")
        with torch.no_grad():
            # Tokenize captions
            caption_tokens = tokenizer(
                captions,
                padding='max_length',
                truncation=True,
                max_length=64,
                return_tensors="pt"
            )

            # Encode captions
            caption_features = model.encode_text(
                caption_tokens["input_ids"], 
                caption_tokens["attention_mask"],
                device
            )

            # Calculate similarities
            similarities = (image_features @ caption_features.T).squeeze(0).cpu().numpy()

            # Sort by similarity
            sorted_indices = similarities.argsort()[::-1]

            # Print results
            print("Top 3 matching captions:")
            for i in sorted_indices[:3]:
                print(f"  Score: {similarities[i]:.4f} - {captions[i]}")

In [21]:
checkpoint_path = "./model_checkpoints/clip_model_final.pt"

In [25]:
image_dir = "../CLIP/test_images/"

In [26]:
test_random_image(checkpoint_path, image_dir)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Using device: cuda
Model loaded from ./model_checkpoints/clip_model_final.pt

Performing zero-shot classification...
Top 5 classifications:
  beach: 0.4875
  sunset: 0.4549
  dog: 0.3377
  food: 0.3081
  flower: 0.2660

Performing caption matching...
Top 3 matching captions:
  Score: 0.5803 - A dog sitting on the beach
  Score: 0.5247 - A beautiful sunset over the ocean
  Score: 0.2511 - A dog running in a field

Performing zero-shot classification...
Top 5 classifications:
  city: 0.5318
  garden: 0.4198
  bird: 0.3731
  flower: 0.3682
  tree: 0.3492

Performing caption matching...
Top 3 matching captions:
  Score: 0.7126 - A view of garden and skyscrapers
  Score: 0.6368 - A building in a city skyline
  Score: 0.3620 - Flowers in a garden
