## 1. Imports and Setup

In [None]:
import os
import pickle
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models

# Set device
# Use the CUDA GPU when PyTorch reports one as available, otherwise the CPU.
# The model and every input tensor are moved to this device further below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

In [None]:
# Configuration
# Filesystem layout; relative paths are resolved against the working directory.
DATA_DIR = '../data'
IMAGES_DIR = os.path.join(DATA_DIR, 'Images')  # image files that get captioned
CAPTIONS_FILE = os.path.join(DATA_DIR, 'captions.txt')  # read with pandas.read_csv for ground truth
MODELS_DIR = '../models'  # vocab.pkl, val_images.pkl and the *.pth checkpoints are read from here

## 2. Load Model

In [None]:
# Model architecture (same as training)
class EncoderCNN(nn.Module):
    """Image encoder: a frozen ResNet-50 trunk plus a trainable projection.

    The ImageNet-pretrained ResNet-50 is used without its last child module
    (the classification layer). Its output is flattened per image, mapped to
    ``embed_size`` by a linear layer and then batch-normalised.
    """

    def __init__(self, embed_size):
        """Build the encoder.

        Args:
            embed_size: Size of the feature vector produced for each image.
        """
        super().__init__()

        backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        feature_dim = backbone.fc.in_features

        # Keep every child of the backbone except the final ``fc`` classifier.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        # The pretrained trunk is frozen; only ``fc`` and ``bn`` keep gradients.
        self.resnet.requires_grad_(False)

        self.fc = nn.Linear(feature_dim, embed_size)
        self.bn = nn.BatchNorm1d(embed_size)

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: Batch of image tensors accepted by ResNet-50
                (presumably ``(batch, 3, H, W)`` -- confirm against the caller).

        Returns:
            Tensor of shape ``(batch, embed_size)``.
        """
        # No autograd graph is recorded for the frozen trunk.
        with torch.no_grad():
            pooled = self.resnet(images)
        flat = pooled.flatten(start_dim=1)
        return self.bn(self.fc(flat))


class DecoderRNN(nn.Module):
    """LSTM decoder that turns an image feature vector into a token sequence.

    In ``forward`` the image feature is the first LSTM input, followed by the
    embedded caption tokens (teacher forcing). In ``generate`` each predicted
    token is embedded and fed back in as the next input.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=2, dropout=0.5):
        """Build the decoder.

        Args:
            embed_size: Size of the token embeddings; must equal the size of the
                image feature vector, since both are fed to the same LSTM.
            hidden_size: Size of the LSTM hidden state.
            vocab_size: Number of tokens; index 0 is treated as padding.
            num_layers: Number of stacked LSTM layers.
            dropout: Dropout probability applied to the LSTM outputs in
                ``forward`` and between LSTM layers when ``num_layers > 1``.
        """
        super().__init__()

        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        # nn.LSTM warns about inter-layer dropout on a single-layer network,
        # so it is only enabled when there is more than one layer.
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers,
                            batch_first=True, dropout=dropout if num_layers > 1 else 0)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, features, captions):
        """Compute vocabulary logits for every caption position.

        Args:
            features: Image features of shape ``(batch, embed_size)``.
            captions: Token indices of shape ``(batch, seq_len)``.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.
        """
        # The last token is dropped so that, once the image feature is
        # prepended, the input sequence is exactly seq_len steps long.
        embeddings = self.embedding(captions[:, :-1])
        features = features.unsqueeze(1)
        embeddings = torch.cat([features, embeddings], dim=1)
        lstm_out, _ = self.lstm(embeddings)
        lstm_out = self.dropout(lstm_out)
        outputs = self.fc(lstm_out)
        return outputs

    def generate(self, features, max_length=30, temperature=1.0, end_idx=2, sample=False):
        """Generate a caption for a single image.

        Args:
            features: Image features of shape ``(1, embed_size)``; only a batch
                of one is supported because tokens are read with ``.item()``.
            max_length: Maximum number of tokens to emit.
            temperature: Softmax temperature, used only when ``sample`` is True.
                Greedy decoding ignores it, because ``argmax`` is unaffected by
                positive rescaling of the logits.
            end_idx: Index of the end-of-caption token; generation stops after
                emitting it. Defaults to 2 (``<END>``).
            sample: If True, draw each token from the temperature-scaled
                softmax distribution instead of taking the most likely token.

        Returns:
            List of generated token indices, including ``end_idx`` if reached.

        Raises:
            ValueError: If ``sample`` is True and ``temperature`` is not positive.
        """
        if sample and temperature <= 0:
            raise ValueError(f'temperature must be positive when sampling, got {temperature}')

        generated = []
        # Zero initial state, created directly on the features' device/dtype.
        h = torch.zeros(self.num_layers, 1, self.hidden_size,
                        device=features.device, dtype=features.dtype)
        c = torch.zeros_like(h)
        x = features.unsqueeze(1)

        for _ in range(max_length):
            lstm_out, (h, c) = self.lstm(x, (h, c))
            logits = self.fc(lstm_out.squeeze(1))

            if sample:
                probs = torch.softmax(logits / temperature, dim=1)
                predicted = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                predicted = logits.argmax(dim=1)

            token = predicted.item()
            generated.append(token)

            if token == end_idx:
                break

            # Feed the chosen token back in as the next input step.
            x = self.embedding(predicted).unsqueeze(1)

        return generated


class ImageCaptioningModel(nn.Module):
    """Encoder-decoder captioner pairing ``EncoderCNN`` with ``DecoderRNN``."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=2):
        """Create the encoder and decoder sub-modules.

        Args:
            embed_size: Shared size of image features and token embeddings.
            hidden_size: Hidden state size of the decoder LSTM.
            vocab_size: Number of tokens the decoder can emit.
            num_layers: Number of stacked LSTM layers in the decoder.
        """
        super().__init__()
        self.encoder = EncoderCNN(embed_size)
        self.decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

    def forward(self, images, captions):
        """Return the decoder outputs for ``captions`` conditioned on ``images``."""
        return self.decoder(self.encoder(images), captions)

    def generate_caption_indices(self, image, max_length=30, temperature=1.0):
        """Produce token indices describing one image.

        Note that this switches the whole model to evaluation mode and leaves
        it there.

        Args:
            image: Image tensor passed to the encoder.
            max_length: Forwarded to the decoder's ``generate``.
            temperature: Forwarded to the decoder's ``generate``.

        Returns:
            Whatever the decoder's ``generate`` returns (a list of token indices).
        """
        self.eval()
        with torch.no_grad():
            encoded = self.encoder(image)
            return self.decoder.generate(encoded, max_length, temperature)

In [None]:
# Vocabulary class for decoding
class Vocabulary:
    """Two-way lookup between caption words and their integer indices."""

    def __init__(self):
        """Start with empty lookup tables and the four special token strings."""
        self.word2idx = {}
        self.idx2word = {}
        self.pad_token = '<PAD>'
        self.start_token = '<START>'
        self.end_token = '<END>'
        self.unk_token = '<UNK>'

    def load(self, vocab_data):
        """Populate the lookup tables from previously saved data.

        Args:
            vocab_data: Mapping with a ``'word2idx'`` entry (stored as-is) and
                an ``'idx2word'`` entry whose keys are converted to ``int``.
        """
        self.word2idx = vocab_data['word2idx']
        # Keys may have been saved as strings; normalise them to integers.
        self.idx2word = dict(
            (int(key), word) for key, word in vocab_data['idx2word'].items()
        )

    def decode(self, indices):
        """Turn a sequence of indices into a space-separated caption.

        Unknown indices become the unknown token, start and padding tokens are
        dropped, and everything from the first end token onwards is ignored.
        """
        skipped = (self.start_token, self.pad_token)
        kept = []
        for token_id in indices:
            token = self.idx2word.get(token_id, self.unk_token)
            if token == self.end_token:
                break
            if token in skipped:
                continue
            kept.append(token)
        return ' '.join(kept)

    def __len__(self):
        """Return the number of words in ``word2idx``."""
        return len(self.word2idx)

In [None]:
# Load vocabulary
# NOTE: pickle.load can execute arbitrary code while unpickling; only open
# vocab.pkl files that come from a trusted source (e.g. your own training run).
with open(os.path.join(MODELS_DIR, 'vocab.pkl'), 'rb') as f:
    vocab_data = pickle.load(f)

# Vocabulary.load reads the 'word2idx' and 'idx2word' entries of vocab_data.
vocab = Vocabulary()
vocab.load(vocab_data)
print(f'Vocabulary loaded: {len(vocab)} words')

In [None]:
# Load model
# best_model.pth supplies the weights plus the epoch / val_loss it was saved at.
checkpoint = torch.load(os.path.join(MODELS_DIR, 'best_model.pth'), map_location=device)

# final_model.pth is read only for its architecture hyperparameters, so it is
# mapped to the CPU: any tensors it also contains (presumably a second copy of
# the weights -- not confirmed here) would otherwise occupy accelerator memory.
model_config = torch.load(os.path.join(MODELS_DIR, 'final_model.pth'), map_location='cpu')

model = ImageCaptioningModel(
    embed_size=model_config['embed_size'],
    hidden_size=model_config['hidden_size'],
    vocab_size=model_config['vocab_size'],
    num_layers=model_config['num_layers']
).to(device)

model.load_state_dict(checkpoint['model_state_dict'])
# Inference only: eval mode disables dropout and uses batch-norm running stats.
model.eval()

print(f"Model loaded from epoch {checkpoint['epoch']} with val_loss: {checkpoint['val_loss']:.4f}")

In [None]:
# Image transform (same as training)
# Resize to 224x224, convert the PIL image to a tensor, then normalise each of
# the three channels with the mean/std values below (the statistics documented
# for torchvision's ImageNet-pretrained models).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])
])

## 3. Generate Caption Function

In [None]:
def generate_caption(image_path: str, model: nn.Module) -> str:
    """
    Takes a path to an image and returns a generated caption string.

    Relies on the notebook-level ``transform``, ``device`` and ``vocab``.

    Args:
        image_path: Path to the image file
        model: The trained ImageCaptioningModel (must provide
            ``generate_caption_indices``)

    Returns:
        str: Generated caption for the image
    """
    # Load and preprocess image. The context manager closes the underlying
    # file as soon as the RGB copy has been made.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    # unsqueeze(0) adds the batch dimension the model expects.
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Generate caption indices. eval()/no_grad are kept here so the function
    # is safe to call regardless of the state the caller left the model in.
    model.eval()
    with torch.no_grad():
        caption_indices = model.generate_caption_indices(image_tensor)

    # Decode to string
    return vocab.decode(caption_indices)

In [None]:
# Test the function
# os.listdir returns entries in arbitrary order and includes hidden files and
# sub-directories; sort and filter so the same three image files are captioned
# on every run and nothing unreadable is passed to the image loader.
test_images = sorted(
    name for name in os.listdir(IMAGES_DIR)
    if not name.startswith('.') and os.path.isfile(os.path.join(IMAGES_DIR, name))
)[:3]
for img_name in test_images:
    img_path = os.path.join(IMAGES_DIR, img_name)
    caption = generate_caption(img_path, model)
    print(f'{img_name}: {caption}')

## 4. Demonstration

In [None]:
# Load validation images and ground truth captions
# NOTE: pickle.load can execute arbitrary code while unpickling; only open
# val_images.pkl files that come from a trusted source.
with open(os.path.join(MODELS_DIR, 'val_images.pkl'), 'rb') as f:
    val_images = pickle.load(f)

# Load captions dataframe for ground truth
# The cells below read its 'image' (file name) and 'caption' columns.
df = pd.read_csv(CAPTIONS_FILE)

print(f'Validation images: {len(val_images)}')

In [None]:
def display_prediction(image_name, show_all_captions=True):
    """
    Display an image with its generated caption and ground truth captions.

    Relies on the notebook-level ``IMAGES_DIR``, ``model`` and ``df``.

    Args:
        image_name: File name of the image inside ``IMAGES_DIR``; also used to
            look up captions in the ``image`` column of ``df``.
        show_all_captions: If True, list every ground truth caption in the
            title; otherwise show only the first one.

    Returns:
        tuple: ``(generated, ground_truth)`` -- the generated caption string
        and the list of ground truth captions (empty if none are found).
    """
    img_path = os.path.join(IMAGES_DIR, image_name)

    # Generate caption
    generated = generate_caption(img_path, model)

    # Get ground truth captions
    ground_truth = df.loc[df['image'] == image_name, 'caption'].tolist()

    # Build the title text
    title_parts = [f'Generated: {generated}\n\n']
    if show_all_captions:
        title_parts.append('Ground Truth:\n')
        title_parts.extend(f'{i}. {gt}\n' for i, gt in enumerate(ground_truth, 1))
    elif ground_truth:
        title_parts.append(f'Ground Truth: {ground_truth[0]}')
    else:
        # An image without any caption rows must not crash the display.
        title_parts.append('Ground Truth: (none found)')
    title = ''.join(title_parts)

    # Display
    _, ax = plt.subplots(figsize=(10, 10))
    # imshow copies the pixel data, so the file can be closed right afterwards.
    with Image.open(img_path) as image:
        ax.imshow(image)
    ax.axis('off')
    ax.set_title(title, fontsize=10, wrap=True)
    plt.tight_layout()
    plt.show()

    return generated, ground_truth

In [None]:
# Demonstrate on random validation images
# Fixed seed so the same sample is drawn on every run.
np.random.seed(42)
# Sampling without replacement fails if more items are requested than exist,
# so cap the sample size at the number of validation images.
sample_size = min(10, len(val_images))
# tolist() turns NumPy string scalars back into plain Python strings.
sample_images = np.random.choice(val_images, size=sample_size, replace=False).tolist()

results = []
for img_name in sample_images:
    gen, gt = display_prediction(img_name)
    results.append({'image': img_name, 'generated': gen, 'ground_truth': gt})

## 5. Analysis

### Successful Captions

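We analyze cases where the model generates semantically accurate captions that capture the main subjects and actions in the image. As a rough automatic proxy for accuracy, each generated caption is scored by the fraction of its words that also appear in the best-matching ground-truth caption.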

In [None]:
def compute_word_overlap(generated, ground_truths):
    """
    Score a generated caption by the share of its words found in a reference.

    Both sides are lower-cased and split on whitespace into sets of unique
    words. For each reference the score is the fraction of the generated
    caption's words that also occur in that reference; the highest score is
    returned. The result is 0 when the generated caption has no words or when
    there are no references.
    """
    candidate = set(generated.lower().split())
    reference_sets = [set(ref.lower().split()) for ref in ground_truths]

    if not candidate:
        return 0

    ratios = [len(candidate & ref) / len(candidate) for ref in reference_sets]
    # The leading 0 is the floor used when there are no references at all.
    return max([0] + ratios)

In [None]:
# Score the model on a slice of the validation split and rank the outcomes
all_results = []

print('Evaluating validation set...')
for img_name in val_images[:100]:  # only the first 100 images, to keep this cell fast
    generated = generate_caption(os.path.join(IMAGES_DIR, img_name), model)
    ground_truth = df.loc[df['image'] == img_name, 'caption'].tolist()
    all_results.append({
        'image': img_name,
        'generated': generated,
        'ground_truth': ground_truth,
        'overlap': compute_word_overlap(generated, ground_truth),
    })

# Highest overlap first, so the head holds successes and the tail failures
all_results_sorted = sorted(all_results, key=lambda entry: entry['overlap'], reverse=True)

mean_overlap = np.mean([entry['overlap'] for entry in all_results])
print(f'Average word overlap: {mean_overlap:.2%}')

In [None]:
# Print the five best-scoring captions
print('=== SUCCESSFUL CAPTIONS (High Overlap) ===')
print()

separator = '-' * 50
# all_results_sorted is ordered best-first, so its head holds the top scores.
for entry in all_results_sorted[:5]:
    references = entry['ground_truth']
    print(f"Image: {entry['image']}")
    print(f"Generated: {entry['generated']}")
    # Only the first reference caption is shown, to keep the listing short.
    print(f"Ground Truth (1 of {len(references)}): {references[0]}")
    print(f"Word Overlap: {entry['overlap']:.2%}")
    print(separator)

In [None]:
# Visualize top 3 successful predictions
# all_results_sorted is ordered by overlap, highest first, so the first three
# entries are the best-scoring images. display_prediction generates the
# caption again for each image; its return value is not needed here.
print('Top 3 Successful Predictions:')
for result in all_results_sorted[:3]:
    display_prediction(result['image'])

### Failure Cases

We analyze cases where the model generates captions that don't accurately describe the image content. Common failure modes include:
- Hallucinating objects not present in the image
- Missing key subjects or actions
- Generating generic or repetitive descriptions

In [None]:
# Print the five worst-scoring captions
print('=== FAILURE CASES (Low Overlap) ===')
print()

separator = '-' * 50
# all_results_sorted is ordered best-first, so its tail holds the lowest
# scores; within this listing the very worst caption is printed last.
for entry in all_results_sorted[-5:]:
    references = entry['ground_truth']
    print(f"Image: {entry['image']}")
    print(f"Generated: {entry['generated']}")
    # Only the first reference caption is shown, to keep the listing short.
    print(f"Ground Truth (1 of {len(references)}): {references[0]}")
    print(f"Word Overlap: {entry['overlap']:.2%}")
    print(separator)

In [None]:
# Visualize bottom 3 (failure) predictions
# all_results_sorted is ordered by overlap, highest first, so the last three
# entries are the lowest-scoring images (the worst one is shown last).
print('Bottom 3 Failure Cases:')
for result in all_results_sorted[-3:]:
    display_prediction(result['image'])

In [None]:
# Aggregate statistics over the evaluated images
overlaps = [entry['overlap'] for entry in all_results]

# Three buckets that together cover every score:
# above 0.5, from 0.25 to 0.5 inclusive, and below 0.25.
high_count = sum(1 for score in overlaps if score > 0.5)
medium_count = sum(1 for score in overlaps if 0.25 <= score <= 0.5)
low_count = sum(1 for score in overlaps if score < 0.25)

print('=== SUMMARY ===')
print(f'Total images evaluated: {len(all_results)}')
print(f'Average word overlap: {np.mean(overlaps):.2%}')
print(f'Median word overlap: {np.median(overlaps):.2%}')
print(f'High quality (>50% overlap): {high_count} images')
print(f'Medium quality (25-50% overlap): {medium_count} images')
print(f'Low quality (<25% overlap): {low_count} images')

In [None]:
# Histogram of the per-image overlap scores, with the mean marked
overlap_mean = np.mean(overlaps)

plt.figure(figsize=(10, 5))
plt.hist(overlaps, bins=20, edgecolor='black', alpha=0.7)
# Dashed red line at the mean; its label feeds the legend below.
plt.axvline(overlap_mean, color='red', linestyle='--', label=f'Mean: {overlap_mean:.2%}')
plt.title('Distribution of Caption Quality (Word Overlap with Ground Truth)')
plt.xlabel('Word Overlap Score')
plt.ylabel('Number of Images')
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()

### Common Failure Patterns

Based on the analysis above, common failure patterns include:

1. **Object Misidentification**: The model sometimes confuses similar objects (e.g., dog vs. cat)
2. **Action Confusion**: The model may describe the wrong action taking place
3. **Generic Captions**: In ambiguous images, the model may fall back to generic descriptions
4. **Missing Context**: Important contextual elements like settings or backgrounds may be ignored

These issues are typical for image captioning models and can be improved with:
- More training data
- Attention mechanisms
- Larger encoder networks
- Beam search decoding

In [None]:
# Final demonstration: caption any image
def caption_image(image_path):
    """
    Caption an arbitrary image file and show it with the caption as its title.

    Uses the notebook-level ``model`` through ``generate_caption``.

    Args:
        image_path: Path to the image file.

    Returns:
        str: The generated caption.
    """
    caption = generate_caption(image_path, model)

    _, axes = plt.subplots(figsize=(8, 8))
    axes.imshow(Image.open(image_path))
    axes.axis('off')
    axes.set_title(f'Caption: {caption}', fontsize=12)
    plt.tight_layout()
    plt.show()

    return caption

# Example usage:
# caption_image('path/to/your/image.jpg')