### Creating the Dataset class and data loaders 

In [None]:
import math
import os
import pickle

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPImageProcessor, GPT2Tokenizer
from transformers import TimesformerConfig, TimesformerModel

class VQARADDataset(Dataset):
    """Map-style dataset pairing pre-extracted GIF features with question features and tokenized answers."""

    def __init__(self, pickle_dataset, gif_feat_dic):
        """
        Load the preprocessed question/answer table and the GIF feature dictionary.

        Args:
            pickle_dataset (str): Path to a pickled DataFrame holding question features, answers and other metadata.
            gif_feat_dic (str): Path to a pickled dict of pre-extracted GIF embeddings, keyed by GIF file name.

        Attributes:
            tgif_frame (pd.DataFrame): The loaded table (questions, answers and features).
            gif_feat_dict (dict): Maps "<name>.gif" to that GIF's feature embedding.
            gpt2_tokenizer (GPT2Tokenizer): DistilGPT2 tokenizer whose pad token is set to its EOS token.
        """
        # NOTE(review): pickle can execute arbitrary code while loading; only point this at trusted files.
        with open(gif_feat_dic, 'rb') as feature_file:
            features_by_gif = pickle.load(feature_file)
        self.tgif_frame = pd.read_pickle(pickle_dataset)
        self.gif_feat_dict = features_by_gif

        # GPT-2 ships without a pad token; reuse EOS so fixed-length padding works.
        tokenizer = GPT2Tokenizer.from_pretrained("distilbert/distilgpt2")
        tokenizer.pad_token = tokenizer.eos_token
        self.gpt2_tokenizer = tokenizer

    def __len__(self):
        """Number of rows in the underlying table."""
        return self.tgif_frame.shape[0]

    def __getitem__(self, idx):
        """Return the GIF embedding, question embedding and padded answer token ids for row *idx*."""
        cell = self.tgif_frame.iloc
        # Positional columns: 1 = GIF name (no extension), 3 = answer text, 4 = question features.
        gif_key = cell[idx, 1] + '.gif'
        question_embedding = cell[idx, 4]
        # The literal " <END>" marker tells the model where an answer stops; every answer is
        # padded/truncated to 37 tokens so answers batch into a single tensor.
        tokenized = self.gpt2_tokenizer(
            cell[idx, 3] + " <END>",
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=37,
        )
        return {
            'gif_embeddings': self.gif_feat_dict[gif_key],
            'question_embeddings': question_embedding,
            'answer': tokenized.input_ids.squeeze(0),
        }


def get_loaders(pickle_dataset='/kaggle/input/cleaned-gif-embeddings/questions_with_gpt2_embeddings_using_gpu_cleaned.pkl', gif_feat_dic='/kaggle/input/cleaned-gif-embeddings/combined_file.pkl', batch_size=32, split_ratio=(0.9, 0.09, 0.01), seed=None):
    """
    Returns DataLoaders for training, validation, and test splits.

    Args:
        pickle_dataset (str): Path to the pickled dataset file.
        gif_feat_dic (str): Path to the pickled dictionary containing GIF embeddings.
        batch_size (int): The batch size for the DataLoader.
        split_ratio (tuple): Three ratios for the training, validation, and test splits. Must sum to 1
            (compared with a floating-point tolerance, so e.g. (0.7, 0.2, 0.1) is accepted).
        seed (int, optional): If given, seeds the random split so it is reproducible across runs.
            If None (default), torch's global RNG is used, as before.

    Returns:
        tuple: A tuple containing:
            - train_loader (DataLoader): DataLoader for the training split (shuffled).
            - val_loader (DataLoader): DataLoader for the validation split.
            - test_loader (DataLoader): DataLoader for the test split.

    Raises:
        ValueError: If the split ratios do not sum to 1.
    """
    # Exact float equality would reject valid splits (0.7 + 0.2 + 0.1 == 0.9999999999999999).
    if not math.isclose(sum(split_ratio), 1.0):
        raise ValueError("Split ratios should sum to 1.")

    dataset = VQARADDataset(pickle_dataset, gif_feat_dic)

    total_size = len(dataset)
    train_size = int(split_ratio[0] * total_size)
    val_size = int(split_ratio[1] * total_size)
    # The test split absorbs any rounding remainder so the sizes always add up to total_size.
    test_size = total_size - train_size - val_size

    # Only pass a generator when a seed was requested, so the default behavior is unchanged.
    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size], **split_kwargs
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

### Creating the Model class

In [None]:
import torch
import torch.nn as nn
from transformers import CLIPProcessor, CLIPModel, GPT2LMHeadModel
from typing import Tuple
from PIL import Image

class VQAModel(nn.Module):
    def __init__(self, max_answer_length: int = 37):
        """
        Initializes the VQAModel by loading a pretrained GPT-2 model and its corresponding tokenizer.

        Args:
        - `max_answer_length`: Number of tokens `forward` decodes per sample. The default of 37 matches the
          `max_length` used to tokenize answers in `VQARADDataset`, so predicted and target lengths line up.

        Components:
        - `gpt2_model`: A pretrained DistilGPT2 model to generate textual answers.
        - `project_down`: A linear layer to project concatenated image and question features into the appropriate size for GPT-2.
        - `gpt2_tokenizer`: The tokenizer associated with the DistilGPT2 model, used for decoding the generated tokens.
        - `end_marker_ids` / `end_token_id`: Token ids of the "<END>" marker appended to answers, and its first id.

        Raises:
        - ValueError: If `max_answer_length` is less than 1.
        """
        super(VQAModel, self).__init__()
        if max_answer_length < 1:
            raise ValueError("max_answer_length must be at least 1")
        self.gpt2_model = GPT2LMHeadModel.from_pretrained("distilbert/distilgpt2")  # Load GPT-2 model for answer generation
        self.project_down = nn.Linear(768*2, 768)  # Linear layer to combine image and question embeddings into 768 dimensions
        self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained("distilbert/distilgpt2")  # Tokenizer for GPT-2 model
        self.max_answer_length = max_answer_length
        # "<END>" is plain text to GPT-2 (not a special token), so it may encode to several ids.
        # The first id marks where the model starts emitting the end marker.
        self.end_marker_ids = self.gpt2_tokenizer.encode("<END>", add_prefix_space=True)
        self.end_token_id = self.end_marker_ids[0]

    def forward(self, image_features: torch.Tensor, question_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the VQA model: greedily decode `max_answer_length` answer tokens.

        The projected image+question embedding is fed to GPT-2 as a one-position prefix. At each step the
        argmax token is embedded and fed back in, reusing GPT-2's key/value cache. The number of steps is
        fixed, so every call (and every DataParallel replica) returns tensors of the same shape.

        Args:
        - `image_features`: Tensor of features extracted from an image (assumed (batch, 768) — TODO confirm).
        - `question_features`: Tensor of features extracted from a text question (assumed (batch, 768) — TODO confirm).

        Returns:
        - `logits`: (batch, max_answer_length, vocab_size) next-token logits, one slice per decoding step.
        - `generated_sequence`: (batch, max_answer_length) greedy token ids. Positions beyond the
          "<END>" marker are replaced with the tokenizer's EOS id, so `decode(..., skip_special_tokens=True)`
          drops them.
        """
        # Concatenate image and question features along the last dimension, then project to GPT-2's width.
        combined_features = torch.cat((image_features, question_features), dim=-1)
        inputs_embeds = self.project_down(combined_features).unsqueeze(1)  # Add a sequence dimension

        embed_tokens = self.gpt2_model.get_input_embeddings()
        past_key_values = None
        step_logits = []
        step_tokens = []
        for _ in range(self.max_answer_length):
            outputs = self.gpt2_model(inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=True)
            next_logits = outputs.logits[:, -1, :]  # Logits for the next token
            next_tokens = next_logits.argmax(dim=-1)
            step_logits.append(next_logits)
            step_tokens.append(next_tokens)
            # Feed only the new token back; earlier positions are served from the cache.
            past_key_values = outputs.past_key_values
            inputs_embeds = embed_tokens(next_tokens).unsqueeze(1)

        logits = torch.stack(step_logits, dim=1)
        generated_sequence = self._pad_after_end_marker(torch.stack(step_tokens, dim=1))
        return logits, generated_sequence

    def _pad_after_end_marker(self, tokens: torch.Tensor) -> torch.Tensor:
        """Replace every token beyond the first end-marker span in each row with the tokenizer's EOS id."""
        marker_len = len(self.end_marker_ids)
        # True from the first occurrence of the end-marker's first token onward.
        started = (tokens == self.end_token_id).cumsum(dim=1) > 0
        # Shift right by the marker length so the marker tokens themselves are kept.
        beyond_marker = torch.zeros_like(started)
        if tokens.size(1) > marker_len:
            beyond_marker[:, marker_len:] = started[:, :-marker_len]
        return tokens.masked_fill(beyond_marker, self.gpt2_tokenizer.eos_token_id)


### Defining the training loop

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# GPT-2 tokenizer used by the evaluation code to turn generated token ids back into text.
tokenizer = GPT2Tokenizer.from_pretrained("distilbert/distilgpt2")

# ---- Hyperparameters ----
BATCH_SIZE = 16
EPOCHS = 50
LEARNING_RATE = 1e-4

# ---- Checkpoint destinations: best-so-far, after epoch 25, after epoch 50 ----
MODEL_PATH = '/kaggle/working/best_model.pth'
MODEL_PATH_2 = '/kaggle/working/25_model.pth'
MODEL_PATH_3 = '/kaggle/working/50_model.pth'

# Prefer the GPU whenever CUDA is present.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Only the training and validation splits are used here; the test split is discarded.
train_loader, val_loader, _ = get_loaders(batch_size=BATCH_SIZE)

# Build the model on `device`, then wrap it so batches are split across any available GPUs.
model = torch.nn.DataParallel(VQAModel().to(device))

# Token-level cross-entropy, optimized with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

def train(model, dataloader, criterion, optimizer, device):
    """
    Train the model for one epoch.

    Args:
    - model (nn.Module): The VQA model.
    - dataloader (DataLoader): DataLoader for training data.
    - criterion (nn.Module): Loss function (cross-entropy).
    - optimizer (optim.Optimizer): Optimizer for training (Adam).
    - device (torch.device): Device to perform training (CPU or GPU).

    Returns:
    - avg_loss (float): Mean per-batch training loss over the epoch (0.0 if `dataloader` yields no batches).
    """
    model.train()
    running_loss = 0.0
    num_batches = 0

    for batch in tqdm(dataloader, desc="Training"):
        # Unpack input data and move it to the target device
        image_features = batch['gif_embeddings'].to(device)
        question_features = batch['question_embeddings'].to(device)
        answers = batch['answer'].to(device)

        optimizer.zero_grad()

        # Forward pass
        logits, _ = model(image_features, question_features)

        # Trim whichever of predictions / targets is longer so positions line up
        seq_len = min(logits.size(1), answers.size(1))
        logits = logits[:, :seq_len]
        answers = answers[:, :seq_len]

        # Flatten to (batch*seq, vocab) and (batch*seq,) as CrossEntropyLoss expects
        loss = criterion(logits.reshape(-1, logits.size(-1)), answers.reshape(-1))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    return running_loss / num_batches if num_batches else 0.0

def compute_bleu(reference, hypothesis):
    """
    Return the smoothed sentence-level BLEU of `hypothesis` against a single `reference`.

    Both strings are tokenized by splitting on whitespace, and NLTK's smoothing method 1 is applied
    so that short answers with no higher-order n-gram overlap do not collapse to exactly zero.

    Args:
    - reference (str): The ground truth reference sentence.
    - hypothesis (str): The predicted hypothesis sentence.

    Returns:
    - bleu (float): BLEU score between 0 and 1.
    """
    return sentence_bleu(
        [reference.split()],
        hypothesis.split(),
        smoothing_function=SmoothingFunction().method1,
    )

def validate(model, dataloader, criterion, device):
    """
    Validate the model and compute loss and BLEU score.

    Args:
    - model (nn.Module): The VQA model.
    - dataloader (DataLoader): DataLoader for validation data.
    - criterion (nn.Module): Loss function (cross-entropy).
    - device (torch.device): Device to perform validation (CPU or GPU).

    Returns:
    - avg_loss (float): Mean per-batch validation loss (0.0 if `dataloader` yields no batches).
    - average_bleu (float): Sentence-level BLEU averaged over validation samples (0 if there are none).
    """
    model.eval()
    running_loss = 0.0
    num_batches = 0
    num_samples = 0
    total_bleu = 0.0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating"):
            image_features = batch['gif_embeddings'].to(device)
            question_features = batch['question_embeddings'].to(device)
            answers = batch['answer'].to(device)

            # Forward pass
            logits, gen_seq = model(image_features, question_features)

            # Decode generated and reference token sequences to strings
            generated_answers = [tokenizer.decode(g, skip_special_tokens=True) for g in gen_seq]
            reference_answers = [tokenizer.decode(a, skip_special_tokens=True) for a in answers]

            # BLEU is a per-sample score, so it is averaged over samples (not over target tokens)
            for true, generated in zip(reference_answers, generated_answers):
                total_bleu += compute_bleu(true.lower(), generated.lower())
                num_samples += 1

            # Trim whichever of predictions / targets is longer so positions line up
            seq_len = min(logits.size(1), answers.size(1))
            logits = logits[:, :seq_len]
            answers = answers[:, :seq_len]

            loss = criterion(logits.reshape(-1, logits.size(-1)), answers.reshape(-1))
            running_loss += loss.item()
            num_batches += 1

    avg_loss = running_loss / num_batches if num_batches else 0.0
    average_bleu = total_bleu / num_samples if num_samples else 0

    return avg_loss, average_bleu

# Training loop: train for one epoch, validate, then checkpoint.
best_val_loss = float('inf')

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")

    # Train the model and compute loss
    train_loss = train(model, train_loader, criterion, optimizer, device)

    # Validate the model: returns the average loss and the average BLEU score (a fraction in [0, 1])
    val_loss, val_bleu = validate(model, val_loader, criterion, device)

    # BLEU is reported as a 0-1 fraction, not as an accuracy percentage.
    print(f"Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}, Val BLEU: {val_bleu:.4f}")

    # Save the best model (lowest validation loss so far)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), MODEL_PATH)
        print(f"Model saved to {MODEL_PATH}")

    # Save the model at epoch 25 (epoch is 0-based)
    if epoch + 1 == 25:
        torch.save(model.state_dict(), MODEL_PATH_2)
        print(f"Model saved to {MODEL_PATH_2}")

    # Save the model at epoch 50 (epoch is 0-based)
    if epoch + 1 == 50:
        torch.save(model.state_dict(), MODEL_PATH_3)
        print(f"Model saved to {MODEL_PATH_3}")

print("Training complete.")
