# Image Captioning with Vision Transformer - Training Notebook

This notebook trains an image captioning model that pairs a Vision Transformer (ViT) encoder with a Transformer decoder.

## Setup Instructions:
1. **Mount Google Drive** (run the cell below)
2. **Upload your data** to Google Drive:
   - Create folder: `/content/drive/MyDrive/image_captioning/`
   - Upload images to: `/content/drive/MyDrive/image_captioning/data/images/`
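   - Upload captions CSV to: `/content/drive/MyDrive/image_captioning/data/captions.csv` (it must contain an `image` column with file names relative to the images folder and a `caption` column)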
3. **Run the remaining cells in order**; the final cell, under "Training", starts the training run

In [None]:
# Mount Google Drive
# The mount point is /content/drive, which is the prefix of every data and
# checkpoint path configured later in this notebook.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install required packages
# The leading `!` makes the notebook run this line as a shell command;
# `-q` keeps pip's output quiet.
!pip install -q transformers torch torchvision pandas pillow tqdm

## Import dependencies

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
import os
import pandas as pd
from PIL import Image
import torchvision.transforms as transforms
from collections import Counter
import numpy as np
from tqdm import tqdm
import json
from transformers import ViTModel, ViTConfig
import math

# Report the runtime environment so a missing GPU is noticed before training.
cuda_ok = torch.cuda.is_available()

print("All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    # Device index 0 is the first visible CUDA device.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Vocabulary class

In [None]:
class Vocabulary:
    """Word-level vocabulary mapping caption tokens to integer ids.

    Ids 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``. Learned words receive consecutive ids after the
    ones already present, in the order in which they reach ``freq_threshold``
    occurrences.

    Tokenization is lower-casing followed by whitespace splitting, so
    punctuation attached to a word stays part of the token.
    """

    def __init__(self, freq_threshold=5):
        """Create a vocabulary holding only the four special tokens.

        Args:
            freq_threshold: Minimum number of occurrences a word needs (within
                one ``build_vocabulary`` call) before it is given an id.
        """
        self.freq_threshold = freq_threshold
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}

    def __len__(self):
        """Return the number of known tokens, special tokens included."""
        return len(self.itos)

    def build_vocabulary(self, sentence_list):
        """Add every sufficiently frequent word of ``sentence_list``.

        Word counts are local to this call. Words that already have an id keep
        it, so calling the method again extends the vocabulary instead of
        overwriting existing entries.

        Args:
            sentence_list: Iterable of caption strings.
        """
        frequencies = Counter()
        for sentence in sentence_list:
            for word in sentence.lower().split():
                frequencies[word] += 1
                # ">=" (rather than "==") lets thresholds below 1 admit every
                # word; the membership test keeps each word to a single id.
                if frequencies[word] >= self.freq_threshold and word not in self.stoi:
                    # Continue after the ids already handed out rather than
                    # restarting at a fixed offset.
                    idx = len(self.itos)
                    self.stoi[word] = idx
                    self.itos[idx] = word

    def numericalize(self, text):
        """Convert ``text`` to a list of token ids.

        Args:
            text: Caption string; it is lower-cased and split on whitespace.

        Returns:
            One id per token, with the ``<UNK>`` id for unknown tokens.
        """
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in text.lower().split()]

## Dataset class

In [None]:
class FlickrDataset(Dataset):
    """Image/caption pairs read from a captions CSV, split by image.

    The CSV must provide an ``image`` column (file name relative to
    ``root_dir``) and a ``caption`` column. Unique image names are shuffled
    with a fixed seed and partitioned so that all captions of one image land
    in the same split:

    * ``'train'``: the first ``int(n * split_ratio * 0.9)`` images,
    * ``'val'``: the next ``int(n * split_ratio * 0.1)`` images,
    * any other value: all remaining images.
    """

    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5,
                 split='train', split_ratio=0.8, vocab=None):
        """Load the CSV, select the requested split and prepare a vocabulary.

        Args:
            root_dir: Directory containing the image files.
            captions_file: Path of the captions CSV.
            transform: Optional callable applied to each loaded RGB image.
            freq_threshold: Minimum word frequency used when a vocabulary is
                built here.
            split: ``'train'``, ``'val'`` or anything else for the remainder.
            split_ratio: Fraction of images shared between train and val.
            vocab: Optional existing vocabulary to reuse (for example the
                training vocabulary for a validation split). When omitted, a
                vocabulary is built from this split's captions.
        """
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        unique_images = self.df['image'].unique()
        # A private RandomState seeded with 42 yields the same permutation as
        # seeding the global NumPy generator, without resetting global RNG
        # state for the rest of the program.
        np.random.RandomState(42).shuffle(unique_images)

        train_size = int(len(unique_images) * split_ratio * 0.9)
        val_size = int(len(unique_images) * split_ratio * 0.1)

        if split == 'train':
            selected_images = unique_images[:train_size]
        elif split == 'val':
            selected_images = unique_images[train_size:train_size + val_size]
        else:
            selected_images = unique_images[train_size + val_size:]

        self.df = self.df[self.df['image'].isin(selected_images)]
        self.imgs = self.df['image'].values
        self.captions = self.df['caption'].values

        if vocab is None:
            vocab = Vocabulary(freq_threshold)
            vocab.build_vocabulary(self.captions.tolist())
        self.vocab = vocab

    def __len__(self):
        """Return the number of (image, caption) rows in this split."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, caption_ids)`` for row ``index``.

        The caption is wrapped in ``<SOS>`` ... ``<EOS>`` and returned as a
        1-D integer tensor; the image is converted to RGB and passed through
        ``transform`` when one was given.
        """
        caption = self.captions[index]
        img_id = self.imgs[index]
        img_path = os.path.join(self.root_dir, img_id)
        # convert() produces an independent image, so the underlying file can
        # be closed right away instead of waiting for garbage collection.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])

        return img, torch.tensor(numericalized_caption)


In [None]:
class CaptionCollate:
    """Collate callable turning (image, caption) samples into batch tensors.

    Images are stacked along a new leading batch dimension; captions are
    right-padded with ``pad_idx`` up to the longest caption in the batch.
    """

    def __init__(self, pad_idx):
        # Token id written into the padded tail of shorter captions.
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Return ``(images, captions)`` tensors for a list of samples."""
        image_batch = torch.stack([sample[0] for sample in batch], dim=0)
        caption_batch = torch.nn.utils.rnn.pad_sequence(
            [sample[1] for sample in batch],
            batch_first=True,
            padding_value=self.pad_idx,
        )
        return image_batch, caption_batch

In [None]:
def get_transforms(image_size=224, mean=None, std=None):
    """Build the training and validation image pipelines.

    Both pipelines resize to ``image_size`` x ``image_size``, convert to a
    tensor and normalise per channel. The training pipeline additionally
    applies a random horizontal flip, a random rotation of up to 15 degrees
    and colour jitter.

    Args:
        image_size: Side length of the square output image.
        mean: Per-channel normalisation mean. Defaults to
            ``[0.485, 0.456, 0.406]``.
        std: Per-channel normalisation standard deviation. Defaults to
            ``[0.229, 0.224, 0.225]``.

    Returns:
        Tuple ``(train_transform, val_transform)``.
    """
    # NOTE(review): the Hugging Face model card for the ViT checkpoint loaded
    # elsewhere in this notebook appears to describe normalisation with mean
    # and std of 0.5 per channel - confirm, and pass mean/std explicitly if
    # the pretrained statistics should be matched. Defaults are unchanged.
    if mean is None:
        mean = [0.485, 0.456, 0.406]
    if std is None:
        std = [0.229, 0.224, 0.225]

    train_transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    # Validation images get no augmentation, so evaluation is deterministic.
    val_transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    return train_transform, val_transform

## Model architecture

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    A ``(1, max_len, d_model)`` table is precomputed: even feature columns
    hold sines and odd columns hold cosines of the position scaled by
    geometrically decreasing frequencies. The table is registered as a buffer,
    so it follows the module across devices but is not a trainable parameter.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        """Precompute the encoding table.

        Args:
            d_model: Feature size of the embeddings (odd sizes are supported).
            max_len: Longest sequence length the table covers.
            dropout: Dropout probability applied after adding the encoding.
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # the cosine block is trimmed to fit; for even d_model this is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)

        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encoding of its positions, after dropout.

        Args:
            x: Tensor indexed as ``(batch, sequence, feature)``; the code
                slices the table by ``x.size(1)``.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds the positional encoding "
                f"table size {max_len}; increase max_len."
            )
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

In [None]:
class VisionTransformerEncoder(nn.Module):
    """Image encoder that exposes the token sequence of a ViT backbone."""

    def __init__(self, embed_dim=768, pretrained=True):
        """Create the backbone.

        Args:
            embed_dim: Requested feature size. It configures the backbone only
                when ``pretrained`` is False; a pretrained checkpoint keeps
                its own hidden size.
            pretrained: Load ``google/vit-base-patch16-224-in21k`` when True,
                otherwise build a randomly initialised 12-layer, 12-head ViT
                for 224x224 images with 16x16 patches.
        """
        super(VisionTransformerEncoder, self).__init__()

        if pretrained:
            self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        else:
            config = ViTConfig(
                hidden_size=embed_dim,
                num_hidden_layers=12,
                num_attention_heads=12,
                intermediate_size=3072,
                image_size=224,
                patch_size=16
            )
            self.vit = ViTModel(config)

        # Report the feature size the backbone really produces. With a
        # pretrained checkpoint the requested embed_dim is not applied, so a
        # mismatch would otherwise only surface later as a shape error in
        # whatever consumes these features.
        actual_dim = self.vit.config.hidden_size
        if actual_dim != embed_dim:
            import warnings
            warnings.warn(
                f"Requested embed_dim={embed_dim} but the ViT backbone outputs "
                f"{actual_dim}-dimensional features; using {actual_dim}."
            )
        self.embed_dim = actual_dim

    def forward(self, images):
        """Encode ``images`` and return the backbone's last hidden state.

        Args:
            images: Pixel tensor passed to the backbone as ``pixel_values``.

        Returns:
            ``last_hidden_state`` of the ViT output (presumably shaped
            ``(batch, tokens, hidden)`` - confirm against the ViTModel docs).
        """
        outputs = self.vit(pixel_values=images)
        features = outputs.last_hidden_state
        return features

In [None]:
class TransformerDecoder(nn.Module):
    """Caption decoder: token embedding, positional encoding, a stack of
    ``nn.TransformerDecoderLayer`` blocks and a projection to vocabulary logits.
    """

    def __init__(self, vocab_size, embed_dim=768, num_layers=6, num_heads=8,
                 forward_expansion=4, dropout=0.1, max_length=100):
        """Build the decoder.

        Args:
            vocab_size: Number of output classes / embedding rows.
            embed_dim: Model width shared by embeddings and decoder layers.
            num_layers: Number of stacked decoder layers.
            num_heads: Attention heads per layer.
            forward_expansion: Feed-forward width as a multiple of ``embed_dim``.
            dropout: Dropout probability used throughout.
            max_length: Size of the positional encoding table.
        """
        super().__init__()

        self.embed_dim = embed_dim

        # Input side: token ids -> vectors -> vectors with position signal.
        self.word_embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_encoding = PositionalEncoding(embed_dim, max_length, dropout)

        # One layer serves as the template that nn.TransformerDecoder repeats.
        feedforward_dim = embed_dim * forward_expansion
        layer_template = nn.TransformerDecoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=feedforward_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(layer_template, num_layers)

        # Output side: hidden states -> vocabulary logits.
        self.fc_out = nn.Linear(embed_dim, vocab_size)
        # Kept as an attribute; forward() below does not call it.
        self.dropout = nn.Dropout(dropout)

    def forward(self, captions, encoder_out, tgt_mask=None, tgt_padding_mask=None):
        """Return vocabulary logits for every caption position.

        Args:
            captions: Token id tensor fed to the embedding.
            encoder_out: Encoder features used as attention memory.
            tgt_mask: Optional attention mask over caption positions.
            tgt_padding_mask: Optional mask marking padded caption positions.
        """
        token_states = self.positional_encoding(self.word_embedding(captions))

        hidden = self.transformer_decoder(
            tgt=token_states,
            memory=encoder_out,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_padding_mask,
        )

        return self.fc_out(hidden)

    def generate_square_subsequent_mask(self, sz, device):
        """Return an ``sz`` x ``sz`` float mask: ``-inf`` strictly above the
        diagonal and ``0`` elsewhere, so a position cannot attend to later ones.
        """
        # Start fully blocked, then keep only the strict upper triangle;
        # triu() zeroes the diagonal and everything below it.
        blocked = torch.full((sz, sz), float('-inf'), device=device)
        return torch.triu(blocked, diagonal=1)

In [None]:
class ImageCaptioningModel(nn.Module):
    """End-to-end captioner: a ViT image encoder feeding a Transformer decoder."""

    def __init__(self, vocab_size, embed_dim=768, num_decoder_layers=6,
                 num_heads=8, forward_expansion=4, dropout=0.1,
                 max_length=100, pretrained_vit=True):
        """Create the encoder and decoder with a shared ``embed_dim``.

        Args:
            vocab_size: Size of the caption vocabulary.
            embed_dim: Feature width passed to both encoder and decoder.
            num_decoder_layers: Number of decoder layers.
            num_heads: Attention heads per decoder layer.
            forward_expansion: Decoder feed-forward width multiplier.
            dropout: Decoder dropout probability.
            max_length: Size of the decoder's positional encoding table.
            pretrained_vit: Whether the encoder loads pretrained weights.
        """
        super().__init__()

        self.encoder = VisionTransformerEncoder(embed_dim=embed_dim, pretrained=pretrained_vit)

        decoder_options = dict(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            num_layers=num_decoder_layers,
            num_heads=num_heads,
            forward_expansion=forward_expansion,
            dropout=dropout,
            max_length=max_length,
        )
        self.decoder = TransformerDecoder(**decoder_options)

    def forward(self, images, captions, tgt_padding_mask=None):
        """Return decoder logits for ``captions`` conditioned on ``images``.

        A causal mask sized to the caption length is created on the captions'
        device and handed to the decoder with the optional padding mask.
        """
        memory = self.encoder(images)
        causal_mask = self.decoder.generate_square_subsequent_mask(
            captions.size(1), captions.device
        )
        return self.decoder(captions, memory, causal_mask, tgt_padding_mask)

## Trainer class

In [None]:
class Trainer:
    """Training loop for the captioning model with checkpointing and resume.

    Each batch is trained with teacher forcing: the model receives the caption
    without its last token and is scored against the caption without its first
    token, ignoring ``<PAD>`` positions in the loss.

    Config keys read: ``learning_rate``, ``num_epochs``, ``checkpoint_dir`` and
    (optionally) ``resume``.
    """

    def __init__(self, model, train_loader, val_loader, vocab, config):
        """Set up device, loss, optimizer, scheduler and training state.

        Args:
            model: Module called as ``model(images, caption_input, padding_mask)``.
            train_loader: Iterable of ``(images, captions)`` training batches.
            val_loader: Iterable of ``(images, captions)`` validation batches.
            vocab: Object whose ``stoi`` mapping contains ``"<PAD>"``.
            config: Dict of hyper-parameters and paths (see class docstring).
                When ``config['resume']`` is truthy, the latest checkpoint in
                ``checkpoint_dir`` is restored if it exists.
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.vocab = vocab
        self.config = config

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        self.criterion = nn.CrossEntropyLoss(ignore_index=vocab.stoi["<PAD>"])

        self.optimizer = optim.Adam(
            model.parameters(),
            lr=config['learning_rate'],
            betas=(0.9, 0.98),
            eps=1e-9
        )

        # Halve the learning rate after 3 epochs without validation improvement.
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=3
        )

        self.start_epoch = 0
        self.best_val_loss = float('inf')
        self.train_losses = []
        self.val_losses = []

        # Honour the 'resume' option so an interrupted run can continue from
        # the last saved epoch instead of silently starting over.
        if config.get('resume', False):
            self.load_checkpoint()

    def _batch_loss(self, images, captions):
        """Run one forward pass and return the loss tensor for the batch."""
        images = images.to(self.device)
        captions = captions.to(self.device)

        # Teacher forcing: predict token t+1 from tokens up to t.
        caption_input = captions[:, :-1]
        caption_target = captions[:, 1:]

        padding_mask = (caption_input == self.vocab.stoi["<PAD>"])

        predictions = self.model(images, caption_input, padding_mask)

        # Flatten batch and time so the criterion sees one row per token.
        predictions = predictions.reshape(-1, predictions.shape[2])
        caption_target = caption_target.reshape(-1)

        return self.criterion(predictions, caption_target)

    def train_epoch(self, epoch):
        """Train for one epoch and return the mean batch loss.

        Args:
            epoch: Zero-based epoch index, used only for the progress label.
        """
        self.model.train()
        epoch_loss = 0.0

        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch + 1}/{self.config["num_epochs"]}')

        for images, captions in pbar:
            loss = self._batch_loss(images, captions)

            self.optimizer.zero_grad()
            loss.backward()

            # Clip the global gradient norm to 1.0 before the update.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.optimizer.step()

            # Read the scalar once; it is used for both the sum and the bar.
            loss_value = loss.item()
            epoch_loss += loss_value
            pbar.set_postfix({'loss': loss_value})

        avg_loss = epoch_loss / len(self.train_loader)
        return avg_loss

    def validate(self):
        """Return the mean batch loss over the validation loader."""
        self.model.eval()
        epoch_loss = 0.0

        with torch.no_grad():
            for images, captions in tqdm(self.val_loader, desc='Validating'):
                epoch_loss += self._batch_loss(images, captions).item()

        avg_loss = epoch_loss / len(self.val_loader)
        return avg_loss

    def save_checkpoint(self, epoch, is_best=False):
        """Write ``latest_checkpoint.pth`` and, if ``is_best``, ``best_model.pth``.

        Args:
            epoch: Zero-based index of the epoch that just finished.
            is_best: Also store the checkpoint as the best model so far.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'best_val_loss': self.best_val_loss,
            'config': self.config
        }

        # Do not rely on the caller having created the directory.
        os.makedirs(self.config['checkpoint_dir'], exist_ok=True)

        checkpoint_path = os.path.join(self.config['checkpoint_dir'], 'latest_checkpoint.pth')
        torch.save(checkpoint, checkpoint_path)

        if is_best:
            best_path = os.path.join(self.config['checkpoint_dir'], 'best_model.pth')
            torch.save(checkpoint, best_path)
            print(f'Best model saved with validation loss: {self.best_val_loss:.4f}')

    def load_checkpoint(self, checkpoint_path=None):
        """Restore model, optimizer, scheduler and history from a checkpoint.

        Args:
            checkpoint_path: File written by ``save_checkpoint``. Defaults to
                ``latest_checkpoint.pth`` inside ``config['checkpoint_dir']``.

        Returns:
            True if a checkpoint was loaded, False if the file does not exist
            (training then starts from scratch).
        """
        if checkpoint_path is None:
            checkpoint_path = os.path.join(self.config['checkpoint_dir'], 'latest_checkpoint.pth')

        if not os.path.isfile(checkpoint_path):
            print(f'No checkpoint found at {checkpoint_path}; starting from scratch.')
            return False

        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # The stored epoch already finished, so continue with the next one.
        self.start_epoch = checkpoint['epoch'] + 1
        self.train_losses = list(checkpoint.get('train_losses', []))
        self.val_losses = list(checkpoint.get('val_losses', []))
        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))

        print(f'Resumed from {checkpoint_path}; continuing at epoch {self.start_epoch + 1}')
        return True

    def train(self):
        """Run epochs ``start_epoch`` .. ``num_epochs - 1``.

        After every epoch the scheduler is stepped on the validation loss, a
        checkpoint is saved and the loss history is written to
        ``training_history.json`` in the checkpoint directory.
        """
        print(f'Training on device: {self.device}')
        print(f'Number of training samples: {len(self.train_loader.dataset)}')
        print(f'Number of validation samples: {len(self.val_loader.dataset)}')
        print(f'Vocabulary size: {len(self.vocab)}')

        for epoch in range(self.start_epoch, self.config['num_epochs']):
            train_loss = self.train_epoch(epoch)
            self.train_losses.append(train_loss)

            val_loss = self.validate()
            self.val_losses.append(val_loss)

            self.scheduler.step(val_loss)

            print(f'\nEpoch {epoch + 1}/{self.config["num_epochs"]}:')
            print(f'Train Loss: {train_loss:.4f}')
            print(f'Val Loss: {val_loss:.4f}')
            print(f'Learning Rate: {self.optimizer.param_groups[0]["lr"]:.6f}')

            is_best = val_loss < self.best_val_loss
            if is_best:
                self.best_val_loss = val_loss

            self.save_checkpoint(epoch, is_best)

            history = {
                'train_losses': self.train_losses,
                'val_losses': self.val_losses
            }
            history_path = os.path.join(self.config['checkpoint_dir'], 'training_history.json')
            with open(history_path, 'w') as f:
                json.dump(history, f, indent=4)

        print('\nTraining completed!')
        print(f'Best validation loss: {self.best_val_loss:.4f}')

## Training

In [None]:
# Driver cell: configure the run, build data pipeline and model, then train.
banner = "=" * 80

print("\n" + banner)
print("STARTING TRAINING")
print(banner + "\n")

# CONFIGURATION - Modify these paths for your setup
config = {
    # Locations on the mounted Drive.
    'data_dir': '/content/drive/MyDrive/image_captioning/data/images',
    'captions_file': '/content/drive/MyDrive/image_captioning/data/captions.csv',
    'checkpoint_dir': '/content/drive/MyDrive/image_captioning/checkpoints',
    # Data loading.
    'batch_size': 32,
    'num_workers': 2,
    'image_size': 224,
    # Model shape.
    'embed_dim': 768,
    'num_decoder_layers': 6,
    'num_heads': 8,
    'forward_expansion': 4,
    'dropout': 0.1,
    'max_length': 100,
    # Optimisation.
    'learning_rate': 3e-4,
    'num_epochs': 30,
    'pretrained_vit': True,
    'resume': False
}

# Create checkpoint directory
os.makedirs(config['checkpoint_dir'], exist_ok=True)

# Load dataset
print('Loading dataset...')
train_transform, val_transform = get_transforms(config['image_size'])

dataset_options = {
    'root_dir': config['data_dir'],
    'captions_file': config['captions_file'],
}
train_dataset = FlickrDataset(transform=train_transform, split='train', **dataset_options)
val_dataset = FlickrDataset(transform=val_transform, split='val', **dataset_options)

# Validation captions must be encoded with the training vocabulary so that
# token ids mean the same thing in both splits.
vocab = train_dataset.vocab
val_dataset.vocab = vocab

# Create data loaders
pad_idx = vocab.stoi["<PAD>"]

loader_options = {
    'batch_size': config['batch_size'],
    'num_workers': config['num_workers'],
    'pin_memory': True,
}

# Only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(
    dataset=train_dataset,
    shuffle=True,
    collate_fn=CaptionCollate(pad_idx=pad_idx),
    **loader_options
)
val_loader = DataLoader(
    dataset=val_dataset,
    shuffle=False,
    collate_fn=CaptionCollate(pad_idx=pad_idx),
    **loader_options
)

# Save vocabulary
# NOTE(review): torch.save serialises the Vocabulary object itself, so loading
# this file later presumably needs the Vocabulary class to be importable -
# confirm before relying on it outside this notebook.
vocab_path = os.path.join(config['checkpoint_dir'], 'vocab.pth')
torch.save(vocab, vocab_path)
print(f'Vocabulary saved to {vocab_path}')

# Create model
print('Creating model...')
model_option_names = (
    'embed_dim',
    'num_decoder_layers',
    'num_heads',
    'forward_expansion',
    'dropout',
    'max_length',
    'pretrained_vit',
)
# The constructor's keyword names match the config keys one-to-one.
model = ImageCaptioningModel(
    vocab_size=len(vocab),
    **{name: config[name] for name in model_option_names}
)

# Count parameters
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, needs_grad in param_sizes if needs_grad)
print(f'Total parameters: {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')

# Create trainer and start training
trainer = Trainer(model, train_loader, val_loader, vocab, config)
trainer.train()

print("\n" + banner)
print("TRAINING COMPLETE!")
print(banner)

## After Training

Your trained model will be saved in:
- Best model: `/content/drive/MyDrive/image_captioning/checkpoints/best_model.pth`
- Latest checkpoint: `/content/drive/MyDrive/image_captioning/checkpoints/latest_checkpoint.pth`
- Vocabulary: `/content/drive/MyDrive/image_captioning/checkpoints/vocab.pth`
- Training history: `/content/drive/MyDrive/image_captioning/checkpoints/training_history.json`