In [25]:
import sys

# Make the parent directory (relative to the current working directory)
# importable so `from utils... import ...` in the next cell resolves.
# Guarded so re-running this cell does not keep appending duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

In [26]:
import json
import pickle
from pathlib import Path

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
import torchvision.models as models

from utils.dataloader import get_transforms, load_split_ids, build_caption_dataset
from utils.caption_dataset import CaptionDataset

In [27]:
# Vocabulary (word -> index) and the per-image caption sequences, both
# produced by an earlier preprocessing step (already tokenized and cleaned).
# NOTE: unpickling can execute arbitrary code; only load files you generated yourself.
word2idx = json.loads(Path("../data/processed/word2idx.json").read_text())
image_caption_seqs = pickle.loads(Path("../data/processed/image_caption_seqs.pkl").read_bytes())

In [28]:
# Official Flickr8k split files (train / dev / test); "dev" is used as validation.
train_ids, val_ids, test_ids = (
    load_split_ids(f"../data/Flickr8k_text/Flickr_8k.{split}Images.txt")
    for split in ("train", "dev", "test")
)

In [29]:
# Folder holding the raw images. The string must match the on-disk folder name
# exactly, so do not "correct" its spelling without renaming the folder too.
image_folder = "../data/Flicker8k_Dataset"

# The training split gets the "train" transforms; val and test share "val".
transform_train, transform_val = get_transforms("train"), get_transforms("val")

# One dataset per split, all built through the shared util function.
train_dataset, val_dataset, test_dataset = (
    build_caption_dataset(split_ids, image_caption_seqs, word2idx, image_folder, split_transform)
    for split_ids, split_transform in (
        (train_ids, transform_train),
        (val_ids, transform_val),
        (test_ids, transform_val),
    )
)

In [30]:
class EncoderCNN(nn.Module):
    """ImageNet-pretrained ResNet-101 that returns a fixed-size spatial feature grid.

    The backbone's final average-pool and fully connected layers are dropped.
    The remaining feature map is resized with adaptive average pooling and
    returned channels-last so a decoder can attend over its grid cells.

    Args:
        encoded_image_size: side length of the pooled output grid.
        fine_tune: if True, backbone stages layer2-layer4 are trainable;
            otherwise the whole backbone is frozen.
    """

    def __init__(self, encoded_image_size=14, fine_tune=True):
        super().__init__()
        self.enc_image_size = encoded_image_size

        try:
            # torchvision >= 0.13 weights API. IMAGENET1K_V1 is the checkpoint the
            # deprecated `pretrained=True` flag loads, so the weights are unchanged.
            resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
        except AttributeError:
            # Older torchvision has no `ResNet101_Weights` enum, only `pretrained`.
            resnet = models.resnet101(pretrained=True)

        # Remove the classifier head (avgpool, fc); keep conv1 ... layer4.
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)

        # Adaptive pooling makes the output grid size independent of input resolution.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))

        self.fine_tune(fine_tune)

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: image batch, presumably (batch_size, 3, H, W) and normalized
                for ImageNet — TODO confirm against get_transforms.
        Returns:
            (batch_size, encoded_image_size, encoded_image_size, 2048) features.
        """
        out = self.resnet(images)               # (batch_size, 2048, H', W')
        out = self.adaptive_pool(out)           # (batch_size, 2048, encoded_image_size, encoded_image_size)
        out = out.permute(0, 2, 3, 1)           # (batch_size, enc_size, enc_size, 2048)
        return out

    def fine_tune(self, fine_tune=True):
        """Freeze the whole backbone, then optionally unfreeze its later stages.

        After the head is removed, the children of ``self.resnet`` are
        [conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4], so
        index 5 onward is layer2, layer3 and layer4.

        NOTE(review): requires_grad=False does not stop BatchNorm layers from
        updating their running statistics while the model is in train() mode.
        """
        for p in self.resnet.parameters():
            p.requires_grad = False

        # Unfreeze residual stages layer2-layer4; the stem and layer1 stay frozen.
        if fine_tune:
            for c in list(self.resnet.children())[5:]:
                for p in c.parameters():
                    p.requires_grad = True

In [31]:
class Attention(nn.Module):
    """Additive attention over a set of encoder feature vectors.

    Each encoder vector and the decoder hidden state are projected into a shared
    ``attention_dim`` space, summed, passed through ReLU and reduced to a scalar
    score. A softmax over the scores gives weights for averaging the encoder vectors.
    """

    def __init__(self, encoder_dim, hidden_dim, attention_dim):
        super().__init__()
        # Attribute names are kept exactly: they determine the state_dict keys.
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)   # projects image features
        self.decoder_att = nn.Linear(hidden_dim, attention_dim)    # projects decoder state
        self.full_att = nn.Linear(attention_dim, 1)                # collapses to one score
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """Score every encoder position against the decoder state.

        Args:
            encoder_out: (batch_size, num_pixels, encoder_dim)
            decoder_hidden: (batch_size, hidden_dim)
        Returns:
            context: (batch_size, encoder_dim) weighted sum of ``encoder_out``.
            alpha: (batch_size, num_pixels) softmax weights (each row sums to 1).
        """
        keys = self.encoder_att(encoder_out)                      # (B, P, A)
        query = self.decoder_att(decoder_hidden)[:, None, :]      # (B, 1, A)
        scores = self.full_att(self.relu(keys + query))[..., 0]   # (B, P)
        weights = self.softmax(scores)                            # (B, P)
        pooled = (encoder_out * weights[..., None]).sum(dim=1)    # (B, E)
        return pooled, weights

In [32]:
class DecoderRNNWithAttention(nn.Module):
    """LSTM caption decoder with soft attention over encoder features.

    At every step, the current word embedding is concatenated with an
    attention-weighted context vector and fed to an ``LSTMCell``. The new
    hidden state (after dropout) is projected to vocabulary logits. The initial
    LSTM state is computed from the mean of the encoder features.
    """

    def __init__(self, attention_dim, embed_dim, hidden_dim, vocab_size, encoder_dim=2048, dropout=0.5):
        super().__init__()

        self.encoder_dim = encoder_dim
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.dropout = dropout  # dropout probability (the module is dropout_layer)
        self.attention = Attention(encoder_dim, hidden_dim, attention_dim)

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout_layer = nn.Dropout(p=dropout)

        # Input at each step = word embedding ++ attention context.
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, hidden_dim, bias=True)
        self.init_h = nn.Linear(encoder_dim, hidden_dim)
        self.init_c = nn.Linear(encoder_dim, hidden_dim)

        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, encoder_out, captions):
        """Teacher-forced decoding over the whole input sequence.

        Args:
            encoder_out: (batch_size, enc_size, enc_size, encoder_dim) features.
                All dims between batch and encoder_dim are flattened into
                "pixels", so any such layout is accepted.
            captions: (batch_size, max_len) input token ids; step t consumes captions[:, t].
        Returns:
            (batch_size, max_len, vocab_size) logits, one row per input step.
        """
        batch_size, max_len = captions.shape

        # reshape (not view): also works if the encoder output is non-contiguous.
        encoder_out = encoder_out.reshape(batch_size, -1, self.encoder_dim)  # (B, num_pixels, encoder_dim)

        embeddings = self.embedding(captions)  # (batch_size, max_len, embed_dim)

        h, c = self.init_hidden_state(encoder_out.mean(dim=1))  # init with mean-pooled image

        # Allocate directly on the target device. Building it on CPU and then
        # calling .to() costs a host allocation plus a copy on every forward pass.
        outputs = torch.zeros(batch_size, max_len, self.vocab_size, device=captions.device)

        for t in range(max_len):
            context, _ = self.attention(encoder_out, h)
            lstm_input = torch.cat([embeddings[:, t, :], context], dim=1)
            h, c = self.decode_step(lstm_input, (h, c))
            outputs[:, t, :] = self.fc(self.dropout_layer(h))

        return outputs

    def init_hidden_state(self, mean_encoder_out):
        """Map (batch_size, encoder_dim) mean features to the initial (h, c) LSTM state."""
        h = self.init_h(mean_encoder_out)
        c = self.init_c(mean_encoder_out)
        return h, c

In [33]:
from nltk.translate.bleu_score import corpus_bleu
from tqdm import tqdm

def train_model(model, train_dataset, val_dataset, word2idx, device='cuda',
                batch_size=32, epochs=20, patience=3, lr=1e-4,
                checkpoint_path="best_model.pt"):
    """Train a captioning model with teacher forcing, LR decay and early stopping.

    Each epoch runs one pass over ``train_dataset`` and then one over
    ``val_dataset``. It reports the mean cross-entropy loss and teacher-forced
    BLEU-1..4. Whenever validation loss improves, the state dict is saved to
    ``checkpoint_path``. Training stops after ``patience`` epochs without
    improvement.

    Args:
        model: module called as ``model(images, captions)`` and returning logits
            of shape (batch, caption_len, vocab_size). It is NOT moved to
            ``device`` here; place it there beforehand.
        train_dataset, val_dataset: datasets yielding ``(image, caption, length)``
            items that the default DataLoader collation can batch.
        word2idx: vocabulary; must contain '<pad>', '<start>' and '<end>'.
        device: device each batch is moved to.
        batch_size, epochs, lr: usual optimisation settings (Adam).
        patience: epochs without validation-loss improvement before stopping.
        checkpoint_path: file the best state dict is written to.

    Returns:
        ``model`` with the best (lowest validation loss) weights reloaded from
        ``checkpoint_path``. If no checkpoint was ever saved, the final weights
        are returned.

    Note:
        BLEU here compares argmax predictions made under teacher forcing with
        the single ground-truth caption of each sample. It is a training
        diagnostic, not the free-running multi-reference BLEU usually reported.
    """
    from torch.utils.data import DataLoader

    pad_idx = word2idx['<pad>']
    # Token ids excluded from BLEU on both sides; built once, not per token.
    special_ids = {pad_idx, word2idx['<start>'], word2idx['<end>']}

    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=2, factor=0.5
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    best_val_loss = float('inf')
    patience_counter = 0
    saved_checkpoint = False

    for epoch in range(epochs):
        model.train()
        train_losses = []
        running_loss = 0.0

        tqdm_train = tqdm(train_loader, desc=f"Epoch {epoch+1} [Training]")
        for images, captions, _ in tqdm_train:
            images, captions = images.to(device), captions.to(device)

            optimizer.zero_grad()
            # Teacher forcing: inputs are tokens [0, L-1), targets are tokens [1, L).
            outputs = model(images, captions[:, :-1])
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), captions[:, 1:].reshape(-1))
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            train_losses.append(batch_loss)
            # A running sum keeps each progress-bar update O(1) instead of
            # re-averaging the whole list every batch.
            running_loss += batch_loss
            tqdm_train.set_postfix(loss=running_loss / len(train_losses))

        avg_train_loss = np.mean(train_losses)

        # Validation
        model.eval()
        val_losses = []
        references = []
        hypotheses = []

        tqdm_val = tqdm(val_loader, desc=f"Epoch {epoch+1} [Validation]")
        with torch.no_grad():
            for images, captions, _ in tqdm_val:
                images, captions = images.to(device), captions.to(device)
                targets = captions[:, 1:]
                outputs = model(images, captions[:, :-1])
                loss = criterion(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
                val_losses.append(loss.item())

                # BLEU prep. Predictions at padded target positions are never
                # trained (ignore_index), so each hypothesis is cut to its
                # target's non-pad length before special tokens are dropped.
                # Assumes captions are right-padded with <pad> — TODO confirm collate.
                decode_lengths = (targets != pad_idx).sum(dim=1).tolist()
                pred_rows = torch.argmax(outputs, dim=2).tolist()
                # Convert whole batches once rather than one row at a time.
                for ref, pred, length in zip(captions.tolist(), pred_rows, decode_lengths):
                    references.append([[w for w in ref if w not in special_ids]])
                    hypotheses.append([w for w in pred[:length] if w not in special_ids])

        avg_val_loss = np.mean(val_losses)
        scheduler.step(avg_val_loss)

        bleu1 = corpus_bleu(references, hypotheses, weights=(1, 0, 0, 0))
        bleu2 = corpus_bleu(references, hypotheses, weights=(0.5, 0.5, 0, 0))
        # Exact thirds: (0.33, 0.33, 0.33) sums to 0.99 and slightly skews BLEU-3.
        bleu3 = corpus_bleu(references, hypotheses, weights=(1/3, 1/3, 1/3, 0))
        bleu4 = corpus_bleu(references, hypotheses, weights=(0.25, 0.25, 0.25, 0.25))

        print(f"\nEpoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}")
        print(f"BLEU-1 = {bleu1:.4f}, BLEU-2 = {bleu2:.4f}, BLEU-3 = {bleu3:.4f}, BLEU-4 = {bleu4:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), checkpoint_path)
            saved_checkpoint = True
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

    if saved_checkpoint:
        # The last epoch is not necessarily the best one (e.g. after early
        # stopping), so hand back the checkpointed best weights.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    print("Training complete.")
    return model

In [34]:
# Pick the fastest available backend: Apple MPS first, then CUDA, else CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: mps


In [35]:
class CaptioningModel(nn.Module):
    """Encoder-decoder wrapper: encodes images, then decodes captions from the features."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder  # images -> feature tensor
        self.decoder = decoder  # (features, captions) -> per-step logits

    def forward(self, images, captions):
        """Run ``encoder`` on ``images`` and return ``decoder(features, captions)``."""
        return self.decoder(self.encoder(images), captions)

In [None]:
# Step 0: Seed the RNGs used by weight init, DataLoader shuffling, random
# transforms and dropout, so Restart-&-Run-All gives comparable runs
# (bit-exact results are still not guaranteed on GPU/MPS backends).
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Step 1: Set hyperparameters
embed_dim = 256
hidden_dim = 512
attention_dim = 256
dropout = 0.5
vocab_size = len(word2idx)

# Step 2: Instantiate encoder and decoder
encoder = EncoderCNN(encoded_image_size=14, fine_tune=True)
decoder = DecoderRNNWithAttention(
    attention_dim=attention_dim,
    embed_dim=embed_dim,
    hidden_dim=hidden_dim,
    vocab_size=vocab_size,
    encoder_dim=2048,  # channel count of the ResNet-101 feature map
    dropout=dropout
)

# Step 3: Wrap into model
model = CaptioningModel(encoder, decoder).to(device)

# Step 4: Train the model
trained_model = train_model(
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    word2idx=word2idx,
    device=device,
    batch_size=8,
    epochs=20,
    patience=3,
    lr=1e-4
)

Epoch 1 [Training]:  10%|█         | 377/3750 [03:33<30:48,  1.82it/s, loss=5.23]