In [None]:
# ! pip install -q flash-attn --no-build-isolation

In [4]:
# Central configuration for the notebook. Later cells read their settings
# (seed, learning rate, generation length, save location, ...) from this mapping.
# NOTE(review): the two paths below are relative to the current working directory,
# whereas the data-loading cell reads from "../../flickr_data/..." — confirm which
# location is the intended one.
hyperparams = dict(
    # --- data ---
    dataset="Flickr8k",  # one of 'Flickr8k' / 'Flickr30k'
    image_dir="./flickr_data/Flickr8k_Dataset/Images",  # folder holding the images
    captions_file="./flickr_data/Flickr8k_Dataset/captions.txt",  # captions CSV
    vocab_size=5000,  # upper bound on the vocabulary size
    # --- model sizes ---
    embed_size=256,  # embedding width (optional when no separate embeddings are used)
    hidden_size=512,  # decoder hidden width (not directly used with Hugging Face models)
    # --- optimisation ---
    batch_size=32,  # samples per batch
    num_epochs=10,  # passes over the training set
    learning_rate=5e-5,  # optimizer step size
    weight_decay=1e-4,  # optimizer weight decay
    # --- caption generation ---
    max_length=50,  # longest caption produced at generation time
    num_beams=5,  # beam width for beam search
    # --- bookkeeping ---
    save_dir="models/",  # where models and plots are written
    seed=42,  # random seed for reproducibility
)

In [12]:
import os
import random
import time
import argparse

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision import transforms
from PIL import Image

from transformers import (
    VisionEncoderDecoderModel,
    GPT2Tokenizer,
    AdamW,
    get_linear_schedule_with_warmup,
)

import nltk
from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu
from nltk.translate.meteor_score import single_meteor_score

from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer

from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from pathlib import Path
import sys

# Make the project's top-level packages importable from this notebook.
notebook_dir = Path.cwd().resolve()  # directory the kernel was started in
project_root = notebook_dir.parents[1]  # two levels up; adjust the index if the notebook moves
if str(project_root) not in sys.path:  # re-running the cell must not stack duplicate entries
    sys.path.append(str(project_root))

print(f"Project root: {project_root}")

# Download the NLTK resources the evaluation code below depends on:
#   - "punkt":     tokenizer models behind word_tokenize
#   - "punkt_tab": the name newer NLTK releases look up for the same tokenizer
#                  (presumably just reported as unknown by older releases — verify)
#   - "wordnet":   corpus the METEOR implementation uses for synonym matching
for nltk_resource in ("punkt", "punkt_tab", "wordnet"):
    nltk.download(nltk_resource)


# Set random seeds for reproducibility
def set_seed(seed):
    """
    Seed every random number generator this notebook relies on.

    Seeds Python's `random`, NumPy's global generator, and PyTorch (CPU and all
    CUDA devices), then sets the cuDNN `deterministic` / `benchmark` flags.

    Args:
        seed: Integer seed handed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Optional extra settings aimed at full reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


# Apply the configured seed to all random number generators.
set_seed(hyperparams["seed"])

# Select the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Project root: /Users/ivankoh/Personal/image-captioning-project
Using device: cpu


[nltk_data] Downloading package punkt to /Users/ivankoh/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [14]:
from data.dataset import *
from data.preprocessing import *


# Dataset name comes from the shared configuration ("Flickr8k" or "Flickr30k").
dataset = hyperparams["dataset"]

# Paths, anchored at project_root (two levels above the notebook directory) so
# they do not depend on the kernel's current working directory.
dataset_dir = f"{project_root}/flickr_data/{dataset}_Dataset/Images"
captions_file = f"{project_root}/flickr_data/{dataset}_Dataset/captions.txt"
image_dir = dataset_dir

# Per-epoch metric histories filled in by the training cell.
train_losses = []
val_losses = []
bleu_scores = []
meteor_scores = []
cider_scores = []

# Load captions (rows with missing values and exact duplicates are dropped)
caption_df = pd.read_csv(captions_file).dropna().drop_duplicates()
print(f"Total captions loaded: {len(caption_df)}")

# Build vocabulary
# NOTE(review): the cap used here (10000) differs from hyperparams["vocab_size"]
# (5000), which this cell does not read — confirm which value is intended.
word2idx, idx2word, image_captions = build_vocabulary(caption_df, vocab_size=10000)
print(f"Vocabulary size: {len(word2idx)}")

# Convert captions to sequences
captions_seqs, max_length = convert_captions_to_sequences(image_captions, word2idx)
print(f"Maximum caption length: {max_length}")

# Get data transformations
train_transform = get_transform(train=True)
val_transform = get_transform(train=False)

# Split data into training and validation sets (the third split is not used here)
image_names = list(image_captions.keys())
train_images, val_images, _ = get_splits(image_names, test_size=0.2)
print(f"Training samples: {len(train_images)}")
print(f"Validation samples: {len(val_images)}")

# Create datasets and data loaders
train_dataset = FlickrDataset(
    image_dir, train_images, captions_seqs, transform=train_transform
)
val_dataset = FlickrDataset(
    image_dir, val_images, captions_seqs, transform=val_transform
)
train_loader = DataLoader(
    train_dataset,
    batch_size=hyperparams["batch_size"],
    shuffle=True,  # reshuffle the training samples every epoch
    collate_fn=collate_fn,
    num_workers=2,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=hyperparams["batch_size"],
    shuffle=False,  # keep validation order fixed
    collate_fn=collate_fn,
    num_workers=2,
)
print(f"Number of training batches: {len(train_loader)}")
print(f"Number of validation batches: {len(val_loader)}")

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Total captions loaded: 40445
Vocabulary size: 8921
Maximum caption length: 40
Training samples: 6472
Validation samples: 1457
Number of training batches: 1011
Number of validation batches: 228
Using device: cpu


In [15]:
# Initialize GPT2 tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Register explicit start / end / padding tokens on the tokenizer.
special_tokens_dict = {
    "bos_token": "<start>",
    "eos_token": "<end>",
    "pad_token": "<pad>",
}
num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
if num_added_tokens > 0:
    print(f"Added {num_added_tokens} special tokens to the tokenizer.")

# Initialize VisionEncoderDecoderModel: ViT image encoder + GPT-2 text decoder
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224", "gpt2"
)

# Resize token embeddings to accommodate the new special tokens
model.decoder.resize_token_embeddings(len(tokenizer))

# Special token ids on the model config (read by the training forward pass
# when it derives decoder inputs from `labels`).
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# Generation defaults live on `generation_config`; steering `generate()` through
# `model.config` is the deprecated route, so set them here explicitly.
generation_config = model.generation_config
generation_config.decoder_start_token_id = tokenizer.bos_token_id
generation_config.eos_token_id = tokenizer.eos_token_id
generation_config.pad_token_id = tokenizer.pad_token_id
generation_config.max_length = hyperparams["max_length"]
generation_config.num_beams = hyperparams["num_beams"]

# Move model to device
model.to(device)

# Print model summary
print(model)

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Added 3 special tokens to the tokenizer.


config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.crossattention.c_attn.bias', 'h.0.crossattention.c_attn.weight', 'h.0.crossattention.c_proj.bias', 'h.0.crossattention.c_proj.weight', 'h.0.crossattention.q_attn.bias', 'h.0.crossattention.q_attn.weight', 'h.0.ln_cross_attn.bias', 'h.0.ln_cross_attn.weight', 'h.1.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.weight', 'h.1.crossattention.c_proj.bias', 'h.1.crossattention.c_proj.weight', 'h.1.crossattention.q_attn.bias', 'h.1.crossattention.q_attn.weight', 'h.1.ln_cross_attn.bias', 'h.1.ln_cross_attn.weight', 'h.10.crossattention.c_attn.bias', 'h.10.crossattention.c_attn.weight', 'h.10.crossattention.c_proj.bias', 'h.10.crossattention.c_proj.weight', 'h.10.crossattention.q_attn.bias', 'h.10.crossattention.q_attn.weight', 'h.10.ln_cross_attn.bias', 'h.10.ln_cross_attn.weight', 'h.11.crossattention.c_attn.bias', 'h.11.crossattention.c_attn.weight', 'h.11.crossat

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


VisionEncoderDecoderModel(
  (encoder): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(i

In [16]:
def evaluate_model(model, data_loader, criterion, device, tokenizer):
    """
    Compute the mean per-batch loss of a VisionEncoderDecoderModel on a data loader.

    The loss is the one the model returns itself when `labels` is passed
    (`outputs.loss`). `criterion` and `tokenizer` are not used; they remain in
    the signature so existing call sites keep working.

    Args:
        model: VisionEncoderDecoderModel (put into eval mode by this call).
        data_loader: Iterable yielding `(images, captions, lengths)` batches.
        criterion: Unused.
        device: Computation device (CPU or GPU) the batches are moved to.
        tokenizer: Unused.
    Returns:
        average_loss: Sum of the per-batch losses divided by the number of
            batches, or `float("nan")` when the loader yields no batches.
    """
    model.eval()  # the caller has to switch back to train mode itself
    total_loss = 0.0
    num_batches = 0  # batches, not samples: each batch contributes one loss value

    with torch.no_grad():  # no gradients needed for evaluation
        for images, captions, _lengths in tqdm(data_loader, desc="Evaluating"):
            images = images.to(device)
            captions = captions.to(device)

            # NOTE(review): assumes `captions` holds ids from the decoder's own
            # vocabulary, padded in a way the model's loss ignores — TODO confirm
            # against collate_fn.
            outputs = model(pixel_values=images, labels=captions)
            total_loss += outputs.loss.item()
            num_batches += 1

    if num_batches == 0:
        # An empty loader has no meaningful loss; avoid ZeroDivisionError.
        return float("nan")
    return total_loss / num_batches


def calculate_bleu_score(
    model,
    image_dir,
    image_ids,
    image2captions,
    transform,
    tokenizer,
    device,
):
    """
    Calculate the corpus BLEU score of greedily generated captions.

    For every image id that has at least one reference caption, the image is
    loaded from `image_dir`, a caption is generated with greedy search, and the
    lower-cased, word-tokenized caption is scored against the lower-cased,
    word-tokenized references. Images without references are skipped because
    `corpus_bleu` cannot score a hypothesis against an empty reference list.

    Args:
        model: VisionEncoderDecoderModel (put into eval mode by this call).
        image_dir: Directory containing images.
        image_ids: List of image IDs (file names joined onto `image_dir`).
        image2captions: Dictionary mapping image IDs to reference captions;
            each reference is a sequence of string tokens.
        transform: Preprocessing transformation applied to each PIL image.
        tokenizer: Tokenizer used to decode generated ids.
        device: Computation device (CPU or GPU).
    Returns:
        bleu_score: Corpus BLEU score with method4 smoothing, or 0.0 when no
            image could be scored.
    """
    model.eval()
    references = []
    hypotheses = []

    with torch.no_grad():
        for img_id in tqdm(image_ids, desc="Calculating BLEU"):
            ref_captions = image2captions.get(img_id, [])
            if not ref_captions:
                continue  # nothing to compare against; skip before doing any work

            img_path = os.path.join(image_dir, img_id)
            # Context manager closes the file handle once the pixels are converted.
            with Image.open(img_path) as raw_image:
                image = transform(raw_image.convert("RGB")).unsqueeze(0).to(device)

            # Generate caption with greedy search (num_beams=1)
            generated_ids = model.generate(
                pixel_values=image,
                max_length=hyperparams["max_length"],
                num_beams=1,  # Greedy search
                do_sample=False,
            )
            generated_caption = tokenizer.decode(
                generated_ids[0], skip_special_tokens=True
            )

            # Hypothesis and references are tokenized the same way.
            hypotheses.append(word_tokenize(generated_caption.lower()))
            references.append(
                [word_tokenize(" ".join(ref).lower()) for ref in ref_captions]
            )

    if not hypotheses:
        return 0.0

    # Compute corpus BLEU score
    smoothie = SmoothingFunction().method4
    return corpus_bleu(references, hypotheses, smoothing_function=smoothie)


def calculate_meteor_score(
    model,
    image_dir,
    image_ids,
    image2captions,
    transform,
    tokenizer,
    device,
):
    """
    Calculate the average METEOR score of greedily generated captions.

    For every image id that has at least one reference caption, a caption is
    generated with greedy search and scored against each reference separately;
    the image's score is the best of those per-reference scores. Both sides are
    lower-cased and word-tokenized, because `single_meteor_score` takes one
    pre-tokenized reference and one pre-tokenized hypothesis. Images without
    references are skipped.

    NOTE(review): METEOR's synonym matching relies on the NLTK WordNet corpus —
    confirm it is downloaded in the environment running this notebook.

    Args:
        model: VisionEncoderDecoderModel (put into eval mode by this call).
        image_dir: Directory containing images.
        image_ids: List of image IDs (file names joined onto `image_dir`).
        image2captions: Dictionary mapping image IDs to reference captions;
            each reference is a sequence of string tokens.
        transform: Preprocessing transformation applied to each PIL image.
        tokenizer: Tokenizer used to decode generated ids.
        device: Computation device (CPU or GPU).
    Returns:
        average_meteor: Mean of the per-image METEOR scores, or 0.0 when no
            image could be scored.
    """
    model.eval()
    per_image_scores = []

    with torch.no_grad():
        for img_id in tqdm(image_ids, desc="Calculating METEOR"):
            ref_captions = image2captions.get(img_id, [])
            if not ref_captions:
                continue  # nothing to compare against; skip before doing any work

            img_path = os.path.join(image_dir, img_id)
            # Context manager closes the file handle once the pixels are converted.
            with Image.open(img_path) as raw_image:
                image = transform(raw_image.convert("RGB")).unsqueeze(0).to(device)

            # Generate caption with greedy search (num_beams=1)
            generated_ids = model.generate(
                pixel_values=image,
                max_length=hyperparams["max_length"],
                num_beams=1,  # Greedy search
                do_sample=False,
            )
            generated_caption = tokenizer.decode(
                generated_ids[0], skip_special_tokens=True
            )

            # Token lists, not strings: the scorer expects pre-tokenized input.
            hypothesis_tokens = word_tokenize(generated_caption.lower())
            reference_token_lists = [
                word_tokenize(" ".join(ref).lower()) for ref in ref_captions
            ]

            # Multi-reference METEOR: keep the best single-reference score.
            best_score = max(
                single_meteor_score(reference_tokens, hypothesis_tokens)
                for reference_tokens in reference_token_lists
            )
            per_image_scores.append(best_score)

    if not per_image_scores:
        return 0.0
    return sum(per_image_scores) / len(per_image_scores)


def calculate_cider_score(
    model,
    image_dir,
    image_ids,
    image2captions,
    transform,
    tokenizer,
    device,
):
    """
    Calculate the CIDEr score of greedily generated captions.

    For every image id that has at least one reference caption, a caption is
    generated with greedy search. References and generated captions are then
    passed through `PTBTokenizer` and scored with `Cider`. Images without
    references are skipped so every scored image has a non-empty reference
    list.

    Args:
        model: VisionEncoderDecoderModel (put into eval mode by this call).
        image_dir: Directory containing images.
        image_ids: List of image IDs (file names joined onto `image_dir`).
        image2captions: Dictionary mapping image IDs to reference captions;
            each reference is a sequence of string tokens.
        transform: Preprocessing transformation applied to each PIL image.
        tokenizer: Tokenizer used to decode generated ids.
        device: Computation device (CPU or GPU).
    Returns:
        cider_score: CIDEr score for the generated captions, or 0.0 when no
            image could be scored.
    """
    model.eval()
    gts = {}  # Ground truth captions, keyed by image id
    res = {}  # Generated captions, keyed by image id

    with torch.no_grad():
        for img_id in tqdm(image_ids, desc="Calculating CIDEr"):
            references = [
                " ".join(ref).lower() for ref in image2captions.get(img_id, [])
            ]
            if not references:
                continue  # nothing to compare against; skip before doing any work

            img_path = os.path.join(image_dir, img_id)
            # Context manager closes the file handle once the pixels are converted.
            with Image.open(img_path) as raw_image:
                image = transform(raw_image.convert("RGB")).unsqueeze(0).to(device)

            # Generate caption with greedy search (num_beams=1)
            generated_ids = model.generate(
                pixel_values=image,
                max_length=hyperparams["max_length"],
                num_beams=1,  # Greedy search
                do_sample=False,
            )
            generated_caption = tokenizer.decode(
                generated_ids[0], skip_special_tokens=True
            )

            # Re-joining the word tokens also collapses any whitespace or
            # newlines in the decoded text into single spaces.
            sampled_caption = " ".join(word_tokenize(generated_caption.lower()))

            gts[img_id] = [{"caption": ref} for ref in references]
            res[img_id] = [{"caption": sampled_caption}]

    if not res:
        return 0.0

    # Tokenize captions (the tokenizer is only created once there is work for it)
    tokenizer_cider = PTBTokenizer()
    gts = tokenizer_cider.tokenize(gts)
    res = tokenizer_cider.tokenize(res)

    # Compute CIDEr score; the per-image scores returned alongside are not used
    cider_score, _ = Cider().compute_score(gts, res)
    return cider_score

In [None]:
# Initialize metrics storage (one entry appended per completed epoch)
train_losses = []
val_losses = []
bleu_scores = []
meteor_scores = []
cider_scores = []

# Define optimizer. PyTorch's AdamW is used; the AdamW exported by
# `transformers` is deprecated in favour of this one.
optimizer = optim.AdamW(
    model.parameters(),
    lr=hyperparams["learning_rate"],
    weight_decay=hyperparams["weight_decay"],
)

# Total training steps
num_epochs = hyperparams["num_epochs"]
total_steps = num_epochs * len(train_loader)


# Learning rate scheduler: halve the learning rate when the validation loss has
# not improved for `patience` epochs. The deprecated `verbose` flag is not
# passed; changes are reported explicitly after each scheduler step below.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",  # We want to minimize the validation loss
    factor=0.5,  # Factor by which the learning rate will be reduced
    patience=2,  # Number of epochs with no improvement after which learning rate will be reduced
)

# Prepare image to captions mapping for evaluation
val_image2captions = prepare_image2captions(val_images, captions_seqs, idx2word)
# Loss function (passed to evaluate_model, which relies on the model's own loss)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
# The evaluation transform does not change between epochs; build it once.
eval_transform = get_transform(train=False)

# Training loop
for epoch in range(num_epochs):
    start_time = time.time()
    model.train()
    total_train_loss = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch in pbar:
        images, captions, lengths = batch
        images = images.to(device)
        captions = captions.to(device)

        # Forward pass
        # NOTE(review): the loaders are built from `captions_seqs` (ids from
        # `word2idx`), while the decoder and `tokenizer.decode` work with GPT-2
        # token ids — confirm the labels fed here are in the decoder's id space.
        outputs = model(pixel_values=images, labels=captions)
        loss = outputs.loss
        batch_loss = loss.item()
        total_train_loss += batch_loss

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Update parameters
        optimizer.step()

        pbar.set_postfix({"loss": batch_loss})

    # Calculate average training loss for the epoch
    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Validation
    val_loss = evaluate_model(model, val_loader, criterion, device, tokenizer)
    lr_before = optimizer.param_groups[0]["lr"]
    scheduler.step(val_loss)
    lr_after = optimizer.param_groups[0]["lr"]
    if lr_after != lr_before:
        print(f"Learning rate changed from {lr_before:.2e} to {lr_after:.2e}")
    val_losses.append(val_loss)

    # Calculate evaluation metrics.
    # Images are read from `image_dir`, the same directory the datasets above
    # were built from, not from the separately configured hyperparams["image_dir"].
    # Each helper generates captions for every validation image itself, so
    # generation runs three times per epoch.
    metric_args = (
        model,
        image_dir,
        val_images,
        val_image2captions,
        eval_transform,
        tokenizer,
        device,
    )

    bleu = calculate_bleu_score(*metric_args)
    bleu_scores.append(bleu)

    meteor = calculate_meteor_score(*metric_args)
    meteor_scores.append(meteor)

    cider = calculate_cider_score(*metric_args)
    cider_scores.append(cider)

    # Print epoch summary
    epoch_duration = time.time() - start_time
    print(
        f"\nEpoch [{epoch+1}/{num_epochs}] completed in {epoch_duration:.2f}s"
        f"\nTraining Loss: {avg_train_loss:.4f}, Validation Loss: {val_loss:.4f}"
        f"\nBLEU Score: {bleu:.4f}, METEOR Score: {meteor:.4f}, CIDEr Score: {cider:.4f}\n"
    )

In [None]:
# Create save directory if it doesn't exist
os.makedirs(hyperparams["save_dir"], exist_ok=True)

# Plot Training and Validation Loss.
# The x-axis is derived from each recorded series rather than from num_epochs,
# so the cell still works when training stopped before the last epoch.
fig_loss, ax_loss = plt.subplots(figsize=(10, 5))
for series_label, series in (
    ("Training Loss", train_losses),
    ("Validation Loss", val_losses),
):
    ax_loss.plot(range(1, len(series) + 1), series, label=series_label)
ax_loss.set_xlabel("Epoch")
ax_loss.set_ylabel("Loss")
ax_loss.set_title("Training vs Validation Loss")
ax_loss.legend()
ax_loss.grid(True)
loss_plot_path = os.path.join(hyperparams["save_dir"], "loss_plot.png")
fig_loss.savefig(loss_plot_path)  # save before show() so the file is not blank
plt.show()
print(f"Loss plot saved to {loss_plot_path}")

# Plot Evaluation Metrics
fig_metrics, ax_metrics = plt.subplots(figsize=(10, 5))
for series_label, series in (
    ("BLEU Score", bleu_scores),
    ("METEOR Score", meteor_scores),
    ("CIDEr Score", cider_scores),
):
    ax_metrics.plot(range(1, len(series) + 1), series, label=series_label)
ax_metrics.set_xlabel("Epoch")
ax_metrics.set_ylabel("Score")
ax_metrics.set_title("Evaluation Metrics over Epochs")
ax_metrics.legend()
ax_metrics.grid(True)
metrics_plot_path = os.path.join(hyperparams["save_dir"], "metrics_plot.png")
fig_metrics.savefig(metrics_plot_path)
plt.show()
print(f"Metrics plot saved to {metrics_plot_path}")

# Saving the model

# Save the trained model's state dictionary (weights only; the architecture has
# to be rebuilt before these weights can be loaded again)
model_save_path = os.path.join(
    hyperparams["save_dir"], f"vision_encoder_decoder_{hyperparams['dataset']}.pth"
)
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

# Save the tokenizer (it carries the special tokens added earlier)
tokenizer_save_path = os.path.join(hyperparams["save_dir"], "tokenizer")
tokenizer.save_pretrained(tokenizer_save_path)
print(f"Tokenizer saved to {tokenizer_save_path}")

In [None]:
dataset = "Flickr8k"

captions_file_path = f"{project_root}/flickr_data/{dataset}_Dataset/captions.txt"
image_dir = f"{project_root}/flickr_data/{dataset}_Dataset/Images"

# Load captions (rows with missing values and exact duplicates are dropped)
caption_df = pd.read_csv(captions_file_path).dropna().drop_duplicates()

# Build vocabulary with the same size cap as the training data cell
# (vocab_size=10000), so this cell's word2idx / idx2word agree with the ones
# the training data was encoded with.
word2idx, idx2word, image_captions = build_vocabulary(caption_df, vocab_size=10000)

# Convert captions to sequences
captions_seqs, max_length = convert_captions_to_sequences(image_captions, word2idx)

# Get data transformations
test_transform = get_transform(train=False)

# Split data into training, validation, and test sets; only the test split is used here.
# NOTE(review): presumably get_splits is deterministic, so this reproduces the
# split used for training — verify, otherwise test images may overlap training images.
image_names = list(image_captions.keys())
_, _, test_images = get_splits(image_names, test_size=0.2)

# Prepare image to captions mapping for ground truth captions
test_image2captions = prepare_image2captions(test_images, captions_seqs, idx2word)

# Create test dataset and data loader
test_dataset = FlickrDataset(
    image_dir, test_images, captions_seqs, transform=test_transform, mode='test'
)
test_loader = DataLoader(
    test_dataset,
    batch_size=1,  # Process one image at a time
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=2,
)

# Function to generate and display captions for a given number of test images
def generate_captions(model, test_loader, image2captions, transform, tokenizer, device, num_images=6):
    """
    Generate captions for test images using greedy search and print them alongside
    the ground truth captions.

    At most `num_images` batches are taken from `test_loader`; iteration stops
    right after the last requested batch, so no extra batch is loaded. Only the
    first item of each batch is decoded and printed (the loader built in this
    notebook uses batch_size=1).

    Args:
        model: VisionEncoderDecoderModel (put into eval mode by this call).
        test_loader: DataLoader yielding `(images, captions, image_ids)` batches.
        image2captions: Dictionary mapping image IDs to reference captions;
            each reference is a sequence of string tokens.
        transform: Unused; kept so existing call sites keep working.
        tokenizer: Tokenizer used to decode generated ids.
        device: Computation device (CPU or GPU).
        num_images (int): Number of test images to process.
    """
    model.eval()
    if num_images <= 0:
        return  # nothing requested; do not pull a batch from the loader

    with torch.no_grad():
        for batch_index, (images, _captions, image_ids) in enumerate(test_loader):
            images = images.to(device)

            # Generate caption with greedy search (num_beams=1)
            generated_ids = model.generate(
                pixel_values=images,
                max_length=hyperparams['max_length'],
                num_beams=1,  # Greedy search
                do_sample=False
            )
            generated_caption = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

            # Retrieve ground truth captions for the first image of the batch
            image_name = image_ids[0]
            gt_captions = image2captions.get(image_name, [])

            # Print generated and ground truth captions
            print(f"Image ID: {image_name}")
            print(f"Generated Caption: {generated_caption}")
            print("Ground Truth Captions:")
            for gt_caption in gt_captions:
                gt_caption_str = ' '.join(gt_caption)
                print(f"- {gt_caption_str}")
            print('-' * 50)

            if batch_index + 1 >= num_images:
                break  # stop before the loader prepares a batch nobody will use

# How many test images to caption in the qualitative check below
num_test_images = 6

# Print generated vs. ground-truth captions for the first few test images
generate_captions(
    model=model,
    test_loader=test_loader,
    image2captions=test_image2captions,
    transform=get_transform(train=False),
    tokenizer=tokenizer,
    device=device,
    num_images=num_test_images,
)