In [1]:
%load_ext autoreload
%reload_ext autoreload
%autoreload 2

# data prep
from dataset import (
    build_vocabulary,
    get_splits,
    get_loaders,
    get_image_transformations,
    convert_captions_to_sequences,
)
from cfg import flickr_image_path, flickr_text_path, device
from torchvision.models import resnet50, ResNet50_Weights
import torch
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn
import torch.optim as optim

from evaluation import (
    calculate_bleu_score,
    calculate_meteor_score,
    prepare_image2captions,
    evaluate
)

word2idx, idx2word, image_captions = build_vocabulary(flickr_text_path)
captions_seqs, max_seq_len = convert_captions_to_sequences(image_captions, word2idx)
train_images, val_images, test_images = get_splits(list(image_captions.keys()))

# Hyperparameters
num_epochs = 1  # Number of epochs for training
embed_size = 256  # Embedding size
hidden_size = 512  # Hidden size of the LSTM
vocab_size = len(word2idx)  # Vocabulary size
num_layers = 2  # Number of layers in LSTM


train_loader, val_loader, test_loader = get_loaders(
    img_dir=flickr_image_path,
    caption_path=flickr_text_path,
    transform=get_image_transformations(),
)

Using device: cpu


  Referenced from: <CAF361F5-1CAC-3EBE-9FC4-4B823D275CAA> /opt/miniconda3/envs/image-captioning-project/lib/python3.8/site-packages/torchvision/image.so
  warn(


In [2]:
# Defining the Encoder using Pre-trained ResNet-50
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        # Load the pre-trained ResNet-50 model
        weights = ResNet50_Weights.DEFAULT
        self.resnet = resnet50(weights=weights)
        # Get the number of features in the last layer
        self.in_features = self.resnet.fc.in_features
        # Replace the last fully connected layer with an identity layer (effectively removing it)
        self.resnet.fc = nn.Identity()
        # Add a fully connected layer to transform features to the desired embedding size
        self.embed = nn.Linear(self.in_features, embed_size)
        # Add batch normalization
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

        # Unfreeze the last convolutional block (layer4) for fine-tuning
        for name, param in self.resnet.named_parameters():
            if "layer4" in name:
                param.requires_grad = True  # Fine-tune layer4 parameters
            else:
                param.requires_grad = False  # Freeze other layers

    def forward(self, images):
        # Extract feature vectors from images using ResNet-50
        features = self.resnet(images)
        # Flatten the features
        features = features.view(features.size(0), -1)
        # Transform features to the embedding size
        features = self.embed(features)
        # Apply batch normalization
        features = self.bn(features)
        return features

# Defining the Decoder (RNN)
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=2, dropout=0.5):
        super(DecoderRNN, self).__init__()
        # Embedding layer to convert word indices to embeddings
        self.embed = nn.Embedding(vocab_size, embed_size)
        # LSTM layer
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True, dropout=dropout)
        # Linear layer to map hidden states to vocabulary scores
        self.linear = nn.Linear(hidden_size, vocab_size)
        # Dropout layer for regularization
        self.dropout = nn.Dropout(dropout)
        # Initialize weights
        self.init_weights()

    def init_weights(self):
        # Initialize embedding and linear layers with uniform distribution
        nn.init.uniform_(self.embed.weight, -0.1, 0.1)
        nn.init.uniform_(self.linear.weight, -0.1, 0.1)
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, features, captions):
        # Embed the captions (excluding the last token)
        embeddings = self.embed(captions[:, :-1])
        # Concatenate the image features and embeddings
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        # Apply dropout to embeddings
        embeddings = self.dropout(embeddings)
        # Pass through the LSTM
        hiddens, _ = self.lstm(embeddings)
        # Generate output scores (logits) for each time step
        outputs = self.linear(hiddens)
        return outputs

In [3]:
# Initialize the encoder and decoder models
encoder = EncoderCNN(embed_size).to(device)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers, dropout=0.6).to(
    device
)

# Define the loss function (ignore padding index)
criterion = nn.CrossEntropyLoss(ignore_index=word2idx["<pad>"])

# Define the optimizer (train both encoder and decoder)
# Use different learning rates for different parameter groups
optimizer = optim.Adam(
    [
        {"params": decoder.parameters()},  # Decoder parameters
        {"params": encoder.embed.parameters()},  # Encoder embedding layer
        {"params": encoder.bn.parameters()},  # Encoder batch normalization
        {
            "params": encoder.resnet.layer4.parameters(),
            "lr": 1e-5,
        },  # Fine-tuned encoder parameters with smaller LR
    ],
    lr=1e-4,
    weight_decay=1e-4,
)

# Learning rate scheduler to reduce LR after certain epochs
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)


# Function to generate captions for an image
def generate_caption_ids(decoder, features, word2idx, max_length=100):
    """
    Generate a caption for an image given the extracted features.
    """
    sampled_ids = []
    inputs = features.unsqueeze(1)  # Start with the image features as input
    states = None  # Initial hidden states

    for _ in range(max_length):
        # Pass through the LSTM
        hiddens, states = decoder.lstm(inputs, states)
        # Pass through the linear layer to get scores over the vocabulary
        outputs = decoder.linear(hiddens.squeeze(1))
        # Get the most probable word index
        predicted = outputs.argmax(1)
        sampled_ids.append(predicted.item())
        # If the <end> token is generated, stop
        if predicted.item() == word2idx["<end>"]:
            break
        # Prepare input for the next time step
        inputs = decoder.embed(predicted)
        inputs = inputs.unsqueeze(1)

    return sampled_ids


In [4]:
# Prepare validation and test image-to-captions mappings
val_image2captions = prepare_image2captions(val_images, captions_seqs, idx2word)
test_image2captions = prepare_image2captions(test_images, captions_seqs, idx2word)

In [5]:
# Lists to store training and validation metrics
train_losses = []
val_losses = []
val_bleu_scores = []
val_meteor_scores = []

for epoch in range(num_epochs):
    encoder.train()
    decoder.train()
    total_loss = 0

    for i, (images, captions, lengths) in enumerate(train_loader):
        images = images.to(device)
        captions = captions.to(device)
        lengths = torch.tensor(lengths)
        adjusted_lengths = lengths - 1  # Adjust lengths for excluding the last token

        # Prepare targets by excluding the <start> token
        targets = nn.utils.rnn.pack_padded_sequence(
            captions[:, 1:], adjusted_lengths, batch_first=True, enforce_sorted=False
        )[0]

        # Forward pass
        features = encoder(images)
        outputs = decoder(features, captions)
        # Pack the outputs
        outputs = nn.utils.rnn.pack_padded_sequence(
            outputs, adjusted_lengths, batch_first=True, enforce_sorted=False
        )[0]

        # Compute loss
        loss = criterion(outputs, targets)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients to prevent exploding gradients
        nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=5)
        nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=5)

        optimizer.step()

        total_loss += loss.item()

        # Print training info every 100 steps
        if i % 100 == 0:
            print(
                f"Epoch [{epoch+1}/{num_epochs}], Step [{i}/{len(train_loader)}], Loss: {loss.item():.4f}"
            )

    # Compute average training loss
    avg_train_loss = total_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}")

    # Evaluate on validation set
    val_loss = evaluate(
        encoder,
        decoder,
        val_loader,
        criterion=nn.CrossEntropyLoss(ignore_index=word2idx["<pad>"]),
    )
    val_losses.append(val_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}], Validation Loss: {val_loss:.4f}")

    # Calculate BLEU and METEOR scores on validation set
    bleu_score = calculate_bleu_score(
        encoder,
        decoder,
        flickr_image_path,
        val_images,
        val_image2captions,
        get_image_transformations(),
        word2idx,
        idx2word,
    )
    val_bleu_scores.append(bleu_score)

    meteor = calculate_meteor_score(
        encoder,
        decoder,
        flickr_image_path,
        val_images,
        val_image2captions,
        get_image_transformations(),
        word2idx,
        idx2word,
    )
    val_meteor_scores.append(meteor)

    print(
        f"Epoch [{epoch+1}/{num_epochs}], Validation BLEU Score: {bleu_score:.4f}, METEOR Score: {meteor:.4f}"
    )

    # Update learning rate scheduler
    scheduler.step()

Using device: cpu


  Referenced from: <CAF361F5-1CAC-3EBE-9FC4-4B823D275CAA> /opt/miniconda3/envs/image-captioning-project/lib/python3.8/site-packages/torchvision/image.so
  warn(


Using device: cpu


  Referenced from: <CAF361F5-1CAC-3EBE-9FC4-4B823D275CAA> /opt/miniconda3/envs/image-captioning-project/lib/python3.8/site-packages/torchvision/image.so
  warn(


Epoch [1/1], Step [0/885], Loss: 9.0393
Epoch [1/1], Step [100/885], Loss: 6.2295
