In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        os.path.join(dirname, filename)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [32]:
import os
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
import spacy
from tqdm import tqdm

import warnings
warnings.simplefilter('ignore')

# Enable deterministic CuDNN for consistent behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Set device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load English language model for tokenization
spacy_eng = spacy.load("en_core_web_sm")

Using device: cuda


# **1- Vocabulary & Text Processing**

In [33]:
class Vocabulary:
    """Vocabulary class to handle text tokenization and numericalization"""
    
    def __init__(self, freq_threshold):
        # Initialize special tokens
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold
    
    def __len__(self):
        return len(self.itos)
    
    @staticmethod
    def tokenizer_eng(text):
        """Tokenize English text using spaCy"""
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]
    
    def build_vocabulary(self, sentences):
        """Build vocabulary from sentences based on frequency threshold"""
        idx = 4  # Start index after special tokens
        frequency = {}
        
        for sentence in sentences:
            for word in self.tokenizer_eng(sentence):
                if word not in frequency:
                    frequency[word] = 1
                else:
                    frequency[word] += 1
                
                # Add word to vocabulary if it reaches frequency threshold
                if frequency[word] == self.freq_threshold:
                    self.itos[idx] = word
                    self.stoi[word] = idx
                    idx += 1
    
    def numericalize(self, sentence):
        """Convert sentence to numerical tokens"""
        tokenized_text = self.tokenizer_eng(sentence)
        return [self.stoi[word] if word in self.stoi else self.stoi["<UNK>"] 
                for word in tokenized_text]

# **2- Dataset Preparation**

In [None]:
class FlickrDataset(Dataset):
    """Dataset class for Flickr8k images and captions"""
    
    def __init__(self, root_dir, caption_path, freq_threshold=5, transform=None):
        self.root_dir = root_dir
        self.df = pd.read_csv(caption_path)
        self.transform = transform
        self.captions = self.df['caption']
        self.images = self.df['image']
        
        # Initialize and build vocabulary
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        """Return image and numericalized caption for a given index"""
        caption = self.captions[index]
        img_name = self.images[index]
        img = Image.open(os.path.join(self.root_dir, img_name)).convert("RGB")
        
        # Apply image transformations if specified
        if self.transform:
            img = self.transform(img)
        
        # Convert caption to numerical tokens with SOS and EOS
        numerical_caption = [self.vocab.stoi["<SOS>"]]
        numerical_caption += self.vocab.numericalize(caption)
        numerical_caption.append(self.vocab.stoi["<EOS>"])
        
        return img, torch.tensor(numerical_caption)

In [34]:
class MyCollate:
    """Collate function to handle padding in variable-length captions"""
    
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx
    
    def __call__(self, batch):
        # Stack images
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs, dim=0)
        
        # Pad captions and transpose for LSTM (seq_len, batch_size)
        targets = [item[1] for item in batch]
        targets = pad_sequence(targets, batch_first=False, padding_value=self.pad_idx)
        
        return imgs, targets

# **3- CNN Encoder Architecture**

In [35]:
class EncoderCNN(nn.Module):
    """CNN encoder using ResNet-50 pretrained model"""
    
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        # Load pretrained ResNet-50
        self.cnn = models.resnet50(pretrained=True)
        
        # Freeze all parameters
        for param in self.cnn.parameters():
            param.requires_grad_(False)
        
        # Replace final fully connected layer
        in_features = self.cnn.fc.in_features
        self.cnn.fc = nn.Linear(in_features, embed_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, images):
        """Extract image features"""
        features = self.dropout(self.relu(self.cnn(images)))
        return features.unsqueeze(1)

# **4- Attention Mechanism**

In [36]:
class Attention(nn.Module):
    """Attention mechanism for decoder"""
    
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
    
    def forward(self, encoder_out, decoder_hidden):
        """Calculate attention weights and context vector"""
        att1 = self.encoder_att(encoder_out)
        att2 = self.decoder_att(decoder_hidden)
        
        # Compute attention scores
        att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2)
        alpha = self.softmax(att)
        
        # Compute context vector
        context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)
        
        return context, alpha

# **5- LSTM Decoder**

In [37]:
class DecoderRNN(nn.Module):
    """LSTM decoder with attention mechanism"""
    
    def __init__(self, embed_size, hidden_size, vocab_size, attention_dim, encoder_dim=256, dropout=0.5):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(encoder_dim, hidden_size, attention_dim)
        
        # LSTM takes concatenated [embedded word + context vector]
        self.lstm = nn.LSTM(embed_size + encoder_dim, hidden_size, batch_first=False)
        
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.encoder_dim = encoder_dim
        self.hidden_size = hidden_size
    
    def forward(self, features, captions):
        """Generate caption predictions"""
        batch_size = features.size(0)
        
        # Embed captions (seq_len, batch_size) -> (seq_len, batch_size, embed_size)
        embeddings = self.dropout(self.embed(captions))
        
        # Initialize hidden state and cell state
        h = torch.zeros(1, batch_size, self.hidden_size).to(device)
        c = torch.zeros(1, batch_size, self.hidden_size).to(device)
        
        # Prepare output tensor (seq_len, batch_size, vocab_size)
        seq_len = captions.size(0)
        outputs = torch.zeros(seq_len, batch_size, self.fc.out_features).to(device)
        
        # Process each time step
        for t in range(seq_len):
            # Get context vector using attention
            context, alpha = self.attention(features, h.squeeze(0))
            
            # Combine embedded word and context
            lstm_input = torch.cat([embeddings[t], context], dim=1).unsqueeze(0)
            
            # LSTM step
            lstm_out, (h, c) = self.lstm(lstm_input, (h, c))
            
            # Predict next word
            output = self.fc(self.dropout(lstm_out.squeeze(0)))
            outputs[t] = output
        
        return outputs

# **6- Model Integration**

In [38]:
class CNNtoRNN(nn.Module):
    """Complete image captioning model combining CNN encoder and RNN decoder"""
    
    def __init__(self, embed_size, hidden_size, vocab_size, attention_dim):
        super(CNNtoRNN, self).__init__()
        self.encoder = EncoderCNN(embed_size)
        self.decoder = DecoderRNN(embed_size, hidden_size, vocab_size, attention_dim, embed_size)
    
    def forward(self, images, captions):
        """Forward pass through the entire model"""
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

# **7- Training Configuration**

In [39]:
def get_loader(root_folder, annotation_file, transform, batch_size=16, num_workers=4, shuffle=True):
    """Create data loader for the dataset"""
    dataset = FlickrDataset(root_folder, annotation_file, transform=transform)
    pad_idx = dataset.vocab.stoi["<PAD>"]
    
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        collate_fn=MyCollate(pad_idx),
        pin_memory=True
    )
    return loader, dataset

In [40]:
# Image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to ResNet input size
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize
])

In [41]:
# Prepare data loader
loader, dataset = get_loader(
    root_folder="/kaggle/input/flickr8kimagescaptions/flickr8k/images",
    annotation_file="/kaggle/input/flickr8kimagescaptions/flickr8k/captions.txt",
    transform=transform
)

In [42]:
# Model hyperparameters
vocab_size = len(dataset.vocab)
embed_size = 256
hidden_size = 512
attention_dim = 256
learning_rate = 3e-4

In [43]:
# Initialize model, loss, and optimizer
model = CNNtoRNN(embed_size, hidden_size, vocab_size, attention_dim).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [44]:
from torch.cuda.amp import GradScaler, autocast

def train_model(model, loader, optimizer, criterion, num_epochs=10):
    """Training loop with mixed-precision training"""
    model.train()
    scaler = GradScaler()  # For mixed-precision training
    
    for epoch in range(num_epochs):
        loop = tqdm(loader, total=len(loader), leave=True)
        
        for images, captions in loop:
            images = images.to(device)
            captions = captions.to(device)
            
            # Zero gradients
            optimizer.zero_grad()
            
            # Forward pass with mixed precision
            with autocast():
                # Predict next words (excluding EOS)
                outputs = model(images, captions[:-1])
                
                # Calculate loss (excluding SOS)
                loss = criterion(
                    outputs.reshape(-1, outputs.shape[2]), 
                    captions[1:].reshape(-1)
                )
            
            # Backward pass with gradient scaling
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            # Update progress bar
            loop.set_postfix(loss=loss.item())

In [45]:
# Train the model
num_epochs = 10
train_model(model, loader, optimizer, criterion, num_epochs)

100%|██████████| 2529/2529 [04:02<00:00, 10.45it/s, loss=3.44]
100%|██████████| 2529/2529 [04:00<00:00, 10.53it/s, loss=4.21]
100%|██████████| 2529/2529 [04:01<00:00, 10.49it/s, loss=2.77]
100%|██████████| 2529/2529 [04:02<00:00, 10.43it/s, loss=3.06]
100%|██████████| 2529/2529 [04:02<00:00, 10.42it/s, loss=2.4] 
100%|██████████| 2529/2529 [04:03<00:00, 10.41it/s, loss=2.92]
100%|██████████| 2529/2529 [04:01<00:00, 10.45it/s, loss=3.06]
100%|██████████| 2529/2529 [04:01<00:00, 10.47it/s, loss=2.59]
100%|██████████| 2529/2529 [04:02<00:00, 10.42it/s, loss=2.24]
100%|██████████| 2529/2529 [04:02<00:00, 10.43it/s, loss=2.62]
