In [None]:
pip install --quiet av sentence-transformers

# CUSTOM DATASET AND DATALOADER FOR OUR DATA

In [None]:
import os
import av  # PyAV for handling video files
import numpy as np
import pandas as pd
import pickle
import torch
import requests  # For downloading GIFs
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import AutoImageProcessor, TimesformerModel
from transformers import GPT2Tokenizer

class VQAGIFDataset(Dataset):
    """
    A custom Dataset class for handling Video QA (VQA) GIF data.

    Every item is built on demand: the GIF referenced by the row is downloaded
    to a scratch folder, a few frames are decoded and passed through the image
    processor, and the answer text is tokenized with the GPT-2 tokenizer.

    Args:
        pickle_dataset (str): Path to the pickled DataFrame containing the dataset.
            Columns are read by position: 0 = GIF URL, 3 = answer text,
            4 = question features.
    """

    # Scratch folder for downloaded GIFs (each file is removed after decoding)
    DOWNLOAD_FOLDER = "/kaggle/working/gifs"
    # Seconds to wait for the server before giving up on a download
    DOWNLOAD_TIMEOUT = 30

    def __init__(self, pickle_dataset):  # transform=None
        # Load the dataset from a pickle file.
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        self.tgif_frame = pd.read_pickle(pickle_dataset)

        # Initialize the image processor from the pre-trained VideoMAE model
        self.preprocess = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")

        # Initialize the GPT-2 tokenizer
        self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2-medium")

        # GPT-2 ships without a padding token, but __getitem__ pads answers to a
        # fixed length, so register one if the vocabulary does not have it yet.
        new_pad_token = "[PAD]"
        if new_pad_token not in self.gpt2_tokenizer.get_vocab():
            self.gpt2_tokenizer.add_special_tokens({'pad_token': new_pad_token})

    def __len__(self):
        """
        Returns the total number of samples in the dataset.

        Returns:
            int: Number of rows in the underlying DataFrame.
        """
        return len(self.tgif_frame)

    def download_gif(self, gif_url, output_path):
        """
        Downloads a GIF from a given URL and saves it to the specified path.

        Args:
            gif_url (str): URL of the GIF to download.
            output_path (str): Local path where the GIF will be saved.

        Returns:
            bool: True if download was successful, False otherwise (non-200
            status, timeout, or any other network error).
        """
        try:
            # The context manager releases the connection on every path,
            # including the early return for a non-200 status.
            with requests.get(gif_url, stream=True, timeout=self.DOWNLOAD_TIMEOUT) as response:
                if response.status_code != 200:
                    return False
                with open(output_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=65536):
                        f.write(chunk)
        except requests.RequestException:
            # Drop any partially written file so it is never decoded later
            if os.path.exists(output_path):
                os.remove(output_path)
            return False
        return True

    def read_video_pyav(self, container, indices):
        """
        Decodes video frames from a PyAV container based on specified indices.

        Each decoded frame is collected at most once, so repeated values in
        ``indices`` do not produce repeated frames in the output.

        Args:
            container (av.container.input.InputContainer): PyAV container for the video.
            indices (sequence of int): Frame indices to extract; the first and
                last entries bound the range that is scanned.

        Returns:
            numpy.ndarray: Array of extracted frames in RGB format, stacked
            along a new leading axis.
        """
        frames = []
        container.seek(0)  # Seek to the beginning of the video
        start_index = indices[0]
        end_index = indices[-1]
        wanted = {int(i) for i in indices}  # O(1) membership test per frame

        # Iterate through decoded frames and collect those at the specified indices
        for i, frame in enumerate(container.decode(video=0)):
            if i > end_index:
                break  # Stop if we've passed the last desired frame
            if i >= start_index and i in wanted:
                frames.append(frame)

        # Convert frames to numpy arrays in RGB format and stack them
        return np.stack([x.to_ndarray(format="rgb24") for x in frames])

    def sample_frame_indices(self, clip_len, frame_sample_rate, seg_len):
        """
        Samples a set of frame indices from the video for processing.

        A window of ``int(clip_len * frame_sample_rate)`` consecutive frames is
        placed at a random position and ``clip_len`` indices are spread evenly
        across it. With a rate below 1 the window is shorter than ``clip_len``,
        so some indices repeat.

        Args:
            clip_len (int): Number of indices to return.
            frame_sample_rate (float): Window length as a multiple of ``clip_len``.
            seg_len (int): Total number of frames in the video.

        Returns:
            numpy.ndarray: int64 array of ``clip_len`` non-decreasing frame indices.

        Raises:
            ValueError: If the video does not have more frames than the window.
        """
        converted_len = int(clip_len * frame_sample_rate)

        # np.random.randint would fail here with an opaque "low >= high";
        # fail with a message that names the real problem instead.
        if seg_len <= converted_len:
            raise ValueError(
                f"Video has {seg_len} frame(s) but more than {converted_len} are "
                "required to sample a clip (a reported frame count of 0 may mean "
                "the container does not know its length)."
            )

        # Randomly choose an end index ensuring the clip fits within the segment length
        end_idx = np.random.randint(converted_len, seg_len)
        start_idx = end_idx - converted_len

        # Generate evenly spaced indices between start and end
        indices = np.linspace(start_idx, end_idx, num=clip_len)

        # Ensure indices are within valid range and convert to integers
        indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)

        return indices

    def process_gifs(self, gif_url):
        """
        Downloads and processes a GIF to extract video frames.

        The GIF is stored temporarily in ``DOWNLOAD_FOLDER`` and removed again
        once decoding has finished, whether or not decoding succeeded.

        Args:
            gif_url (str): URL of the GIF to process.

        Returns:
            list of numpy.ndarray: Decoded RGB frames, or an empty list when the
            download failed.
        """
        download_folder = self.DOWNLOAD_FOLDER

        # Create the download folder if it doesn't exist
        os.makedirs(download_folder, exist_ok=True)

        # Extract the GIF filename from the URL
        gif_name = gif_url.split('/')[-1]
        gif_path = os.path.join(download_folder, gif_name)

        if not self.download_gif(gif_url, gif_path):
            return []

        try:
            # Open the downloaded GIF using PyAV
            container = av.open(gif_path)
            try:
                total_frames = container.streams.video[0].frames

                # Request 8 indices from a 4-frame window (rate 0.5)
                indices = self.sample_frame_indices(clip_len=8, frame_sample_rate=0.5, seg_len=total_frames)
                video = self.read_video_pyav(container, indices)
            finally:
                # Close the container to release resources even if decoding fails
                container.close()
        finally:
            # Delete the GIF file after processing to save space
            if os.path.exists(gif_path):
                os.remove(gif_path)

        return list(video)

    def __getitem__(self, idx):
        """
        Retrieves a single sample from the dataset.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            dict: ``processed_gifs`` (image-processor output for the frames),
            ``question_embeddings`` (column 4 of the row, passed through as is)
            and ``answer`` (answer token ids padded/truncated to 37).

        Raises:
            RuntimeError: If the GIF for this row could not be downloaded.
        """
        # Get the GIF URL from the dataset
        gif_url = self.tgif_frame.iloc[idx, 0]

        # Process the GIF to extract video frames
        video_list = self.process_gifs(gif_url)
        if not video_list:
            # Fail here with the URL rather than inside the image processor
            raise RuntimeError(f"Could not download GIF: {gif_url}")

        # Preprocess the video frames using the image processor
        inputs = self.preprocess(video_list, return_tensors="pt")
        pixel_values = inputs.get('pixel_values')
        pixel_values = pixel_values.squeeze(0)  # Remove batch dimension

        # Extract question features and answer from the dataset
        question_features = self.tgif_frame.iloc[idx, 4]
        answers = self.tgif_frame.iloc[idx, 3] + " <END>"  # Append end marker to the answer

        # Tokenize the answer using GPT-2 tokenizer
        answer_id = self.gpt2_tokenizer(
            answers,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=37
        ).input_ids.squeeze(0)  # Remove batch dimension

        # Create a sample dictionary
        sample = {
            'processed_gifs': pixel_values,          # Tensor of processed GIF frames
            'question_embeddings': question_features, # Precomputed question embeddings
            'answer': answer_id                      # Tokenized answer
        }

        return sample


def get_loaders(
    pickle_dataset='/kaggle/input/embedding-q-all/updated_final_df_with_q_embeddings.pkl',
    batch_size=32,
    split_ratio=(0.9, 0.09, 0.01)
):
    """
    Creates and returns DataLoader objects for training, validation, and testing.

    The dataset is split randomly; the test split is also written to
    '/kaggle/working/test_dataset.pth' with ``torch.save``.

    Args:
        pickle_dataset (str, optional): Path to the pickled dataset. Defaults to
            '/kaggle/input/embedding-q-all/updated_final_df_with_q_embeddings.pkl'.
        batch_size (int, optional): Number of samples per batch. Defaults to 32.
        split_ratio (tuple, optional): Ratios for train, validation, and test splits.
            Should sum to 1. Defaults to (0.9, 0.09, 0.01).

    Returns:
        tuple: DataLoader objects for training, validation, and testing.
    """
    import math  # local import keeps this function self-contained

    # Compare with a tolerance: an exact == rejects valid ratios such as
    # (0.7, 0.2, 0.1), whose floating-point sum is 0.9999999999999999.
    assert math.isclose(sum(split_ratio), 1.0), "Split ratios should sum to 1."

    # Initialize the custom dataset
    dataset = VQAGIFDataset(pickle_dataset)  # , transform=transform

    # Calculate sizes for each split; the test split absorbs the rounding remainder
    total_size = len(dataset)
    train_size = int(split_ratio[0] * total_size)
    val_size = int(split_ratio[1] * total_size)
    test_size = total_size - train_size - val_size

    # Split the dataset into training, validation, and testing
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size]
    )

    # Save the test dataset for later use (optional)
    torch.save(test_dataset, '/kaggle/working/test_dataset.pth')

    # Create DataLoader objects for each split; only the training data is shuffled
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader


# PREPARING THE MODEL

In [None]:
import os
import av  # PyAV for handling video files
import numpy as np
import pandas as pd
import pickle
import torch
import requests  # For downloading GIFs
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import AutoImageProcessor, TimesformerModel
from transformers import GPT2Tokenizer

class VQAGIFDataset(Dataset):
    """
    A custom Dataset class for handling Video QA (VQA) GIF data.

    Every item is built on demand: the GIF referenced by the row is downloaded
    to a scratch folder, a few frames are decoded and passed through the image
    processor, and the answer text is tokenized with the GPT-2 tokenizer.

    Args:
        pickle_dataset (str): Path to the pickled DataFrame containing the dataset.
            Columns are read by position: 0 = GIF URL, 3 = answer text,
            4 = question features.
    """

    # Scratch folder for downloaded GIFs (each file is removed after decoding)
    DOWNLOAD_FOLDER = "/kaggle/working/gifs"
    # Seconds to wait for the server before giving up on a download
    DOWNLOAD_TIMEOUT = 30

    def __init__(self, pickle_dataset):  # transform=None
        # Load the dataset from a pickle file.
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        self.tgif_frame = pd.read_pickle(pickle_dataset)

        # Initialize the image processor from the pre-trained VideoMAE model
        self.preprocess = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")

        # Initialize the GPT-2 tokenizer
        self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2-medium")

        # GPT-2 ships without a padding token, but __getitem__ pads answers to a
        # fixed length, so register one if the vocabulary does not have it yet.
        new_pad_token = "[PAD]"
        if new_pad_token not in self.gpt2_tokenizer.get_vocab():
            self.gpt2_tokenizer.add_special_tokens({'pad_token': new_pad_token})

    def __len__(self):
        """
        Returns the total number of samples in the dataset.

        Returns:
            int: Number of rows in the underlying DataFrame.
        """
        return len(self.tgif_frame)

    def download_gif(self, gif_url, output_path):
        """
        Downloads a GIF from a given URL and saves it to the specified path.

        Args:
            gif_url (str): URL of the GIF to download.
            output_path (str): Local path where the GIF will be saved.

        Returns:
            bool: True if download was successful, False otherwise (non-200
            status, timeout, or any other network error).
        """
        try:
            # The context manager releases the connection on every path,
            # including the early return for a non-200 status.
            with requests.get(gif_url, stream=True, timeout=self.DOWNLOAD_TIMEOUT) as response:
                if response.status_code != 200:
                    return False
                with open(output_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=65536):
                        f.write(chunk)
        except requests.RequestException:
            # Drop any partially written file so it is never decoded later
            if os.path.exists(output_path):
                os.remove(output_path)
            return False
        return True

    def read_video_pyav(self, container, indices):
        """
        Decodes video frames from a PyAV container based on specified indices.

        Each decoded frame is collected at most once, so repeated values in
        ``indices`` do not produce repeated frames in the output.

        Args:
            container (av.container.input.InputContainer): PyAV container for the video.
            indices (sequence of int): Frame indices to extract; the first and
                last entries bound the range that is scanned.

        Returns:
            numpy.ndarray: Array of extracted frames in RGB format, stacked
            along a new leading axis.
        """
        frames = []
        container.seek(0)  # Seek to the beginning of the video
        start_index = indices[0]
        end_index = indices[-1]
        wanted = {int(i) for i in indices}  # O(1) membership test per frame

        # Iterate through decoded frames and collect those at the specified indices
        for i, frame in enumerate(container.decode(video=0)):
            if i > end_index:
                break  # Stop if we've passed the last desired frame
            if i >= start_index and i in wanted:
                frames.append(frame)

        # Convert frames to numpy arrays in RGB format and stack them
        return np.stack([x.to_ndarray(format="rgb24") for x in frames])

    def sample_frame_indices(self, clip_len, frame_sample_rate, seg_len):
        """
        Samples a set of frame indices from the video for processing.

        A window of ``int(clip_len * frame_sample_rate)`` consecutive frames is
        placed at a random position and ``clip_len`` indices are spread evenly
        across it. With a rate below 1 the window is shorter than ``clip_len``,
        so some indices repeat.

        Args:
            clip_len (int): Number of indices to return.
            frame_sample_rate (float): Window length as a multiple of ``clip_len``.
            seg_len (int): Total number of frames in the video.

        Returns:
            numpy.ndarray: int64 array of ``clip_len`` non-decreasing frame indices.

        Raises:
            ValueError: If the video does not have more frames than the window.
        """
        converted_len = int(clip_len * frame_sample_rate)

        # np.random.randint would fail here with an opaque "low >= high";
        # fail with a message that names the real problem instead.
        if seg_len <= converted_len:
            raise ValueError(
                f"Video has {seg_len} frame(s) but more than {converted_len} are "
                "required to sample a clip (a reported frame count of 0 may mean "
                "the container does not know its length)."
            )

        # Randomly choose an end index ensuring the clip fits within the segment length
        end_idx = np.random.randint(converted_len, seg_len)
        start_idx = end_idx - converted_len

        # Generate evenly spaced indices between start and end
        indices = np.linspace(start_idx, end_idx, num=clip_len)

        # Ensure indices are within valid range and convert to integers
        indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)

        return indices

    def process_gifs(self, gif_url):
        """
        Downloads and processes a GIF to extract video frames.

        The GIF is stored temporarily in ``DOWNLOAD_FOLDER`` and removed again
        once decoding has finished, whether or not decoding succeeded.

        Args:
            gif_url (str): URL of the GIF to process.

        Returns:
            list of numpy.ndarray: Decoded RGB frames, or an empty list when the
            download failed.
        """
        download_folder = self.DOWNLOAD_FOLDER

        # Create the download folder if it doesn't exist
        os.makedirs(download_folder, exist_ok=True)

        # Extract the GIF filename from the URL
        gif_name = gif_url.split('/')[-1]
        gif_path = os.path.join(download_folder, gif_name)

        if not self.download_gif(gif_url, gif_path):
            return []

        try:
            # Open the downloaded GIF using PyAV
            container = av.open(gif_path)
            try:
                total_frames = container.streams.video[0].frames

                # Request 8 indices from a 4-frame window (rate 0.5)
                indices = self.sample_frame_indices(clip_len=8, frame_sample_rate=0.5, seg_len=total_frames)
                video = self.read_video_pyav(container, indices)
            finally:
                # Close the container to release resources even if decoding fails
                container.close()
        finally:
            # Delete the GIF file after processing to save space
            if os.path.exists(gif_path):
                os.remove(gif_path)

        return list(video)

    def __getitem__(self, idx):
        """
        Retrieves a single sample from the dataset.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            dict: ``processed_gifs`` (image-processor output for the frames),
            ``question_embeddings`` (column 4 of the row, passed through as is)
            and ``answer`` (answer token ids padded/truncated to 37).

        Raises:
            RuntimeError: If the GIF for this row could not be downloaded.
        """
        # Get the GIF URL from the dataset
        gif_url = self.tgif_frame.iloc[idx, 0]

        # Process the GIF to extract video frames
        video_list = self.process_gifs(gif_url)
        if not video_list:
            # Fail here with the URL rather than inside the image processor
            raise RuntimeError(f"Could not download GIF: {gif_url}")

        # Preprocess the video frames using the image processor
        inputs = self.preprocess(video_list, return_tensors="pt")
        pixel_values = inputs.get('pixel_values')
        pixel_values = pixel_values.squeeze(0)  # Remove batch dimension

        # Extract question features and answer from the dataset
        question_features = self.tgif_frame.iloc[idx, 4]
        answers = self.tgif_frame.iloc[idx, 3] + " <END>"  # Append end marker to the answer

        # Tokenize the answer using GPT-2 tokenizer
        answer_id = self.gpt2_tokenizer(
            answers,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=37
        ).input_ids.squeeze(0)  # Remove batch dimension

        # Create a sample dictionary
        sample = {
            'processed_gifs': pixel_values,          # Tensor of processed GIF frames
            'question_embeddings': question_features, # Precomputed question embeddings
            'answer': answer_id                      # Tokenized answer
        }

        return sample


def get_loaders(
    pickle_dataset='/kaggle/input/embedding-q-all/updated_final_df_with_q_embeddings.pkl',
    batch_size=32,
    split_ratio=(0.9, 0.09, 0.01)
):
    """
    Creates and returns DataLoader objects for training, validation, and testing.

    The dataset is split randomly; the test split is also written to
    '/kaggle/working/test_dataset.pth' with ``torch.save``.

    Args:
        pickle_dataset (str, optional): Path to the pickled dataset. Defaults to
            '/kaggle/input/embedding-q-all/updated_final_df_with_q_embeddings.pkl'.
        batch_size (int, optional): Number of samples per batch. Defaults to 32.
        split_ratio (tuple, optional): Ratios for train, validation, and test splits.
            Should sum to 1. Defaults to (0.9, 0.09, 0.01).

    Returns:
        tuple: DataLoader objects for training, validation, and testing.
    """
    import math  # local import keeps this function self-contained

    # Compare with a tolerance: an exact == rejects valid ratios such as
    # (0.7, 0.2, 0.1), whose floating-point sum is 0.9999999999999999.
    assert math.isclose(sum(split_ratio), 1.0), "Split ratios should sum to 1."

    # Initialize the custom dataset
    dataset = VQAGIFDataset(pickle_dataset)  # , transform=transform

    # Calculate sizes for each split; the test split absorbs the rounding remainder
    total_size = len(dataset)
    train_size = int(split_ratio[0] * total_size)
    val_size = int(split_ratio[1] * total_size)
    test_size = total_size - train_size - val_size

    # Split the dataset into training, validation, and testing
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size]
    )

    # Save the test dataset for later use (optional)
    torch.save(test_dataset, '/kaggle/working/test_dataset.pth')

    # Create DataLoader objects for each split; only the training data is shuffled
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader


# TRAINING FUNCTION FOR OUR MODEL

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm  # For displaying progress bars during iteration
from sentence_transformers import SentenceTransformer  # For embedding sentences
from sklearn.metrics.pairwise import cosine_similarity  # For computing similarity between embeddings

from transformers import GPT2Tokenizer  # GPT-2 tokenizer

# -----------------------------------------------------------------------------------
# Configuration and Setup
# -----------------------------------------------------------------------------------

# Initialize the GPT-2 tokenizer from the specified pre-trained model
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2-medium")

# The dataset pads answers with a "[PAD]" token that it registers on its own
# tokenizer instance. Register the same token here so the padded answer ids can
# be decoded (and dropped via skip_special_tokens) during validation.
if "[PAD]" not in tokenizer.get_vocab():
    tokenizer.add_special_tokens({'pad_token': "[PAD]"})

# Define training hyperparameters
BATCH_SIZE = 16
EPOCHS = 2
LEARNING_RATE = 1e-4

# Define paths for saving models and test data
MODEL_PATH = '/kaggle/working/best_model.pth'    # best validation loss so far
MODEL_PATH_2 = '/kaggle/working/10_model.pth'    # intermediate checkpoint
MODEL_PATH_3 = '/kaggle/working/15_model.pth'    # intermediate checkpoint
test_data_path = '/kaggle/working/test_data.csv'

# Determine the computing device: use GPU if available, else fallback to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize data loaders for training, validation, and testing
train_loader, val_loader, test_loader = get_loaders(batch_size=BATCH_SIZE)

# Initialize the VQA model and move it to the designated device
model = VQAModel().to(device)

# Wrap the model with DataParallel (done unconditionally, whatever the GPU count).
# NOTE: state_dict keys of the wrapped model carry a "module." prefix, so
# checkpoints saved from it must be loaded into a wrapped model as well.
model = torch.nn.DataParallel(model)

# Initialize the SentenceTransformer model for validation embeddings
val_model = SentenceTransformer('all-MiniLM-L6-v2')

# Define the loss function as Cross Entropy Loss
# NOTE(review): padded answer positions are not ignored by this loss — confirm
# that is intended.
criterion = nn.CrossEntropyLoss()

# Initialize the optimizer as Adam with the specified learning rate
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# -----------------------------------------------------------------------------------
# Training and Validation Functions
# -----------------------------------------------------------------------------------

def train(model, dataloader, criterion, optimizer, device):
    """
    Trains the model for one epoch.

    Args:
        model (torch.nn.Module): The VQA model to train. Called as
            ``model(images, question_features)`` and expected to return a
            ``(logits, generated_sequence)`` pair.
        dataloader (DataLoader): DataLoader for the training dataset.
        criterion (torch.nn.Module): Loss function.
        optimizer (torch.optim.Optimizer): Optimizer for updating model weights.
        device (torch.device): The device to perform computations on.

    Returns:
        float: The average training loss over the epoch (0.0 if the loader
        yields no batches).
    """
    model.train()  # Set the model to training mode
    running_loss = 0.0  # Initialize running loss

    # Iterate over the loader that was passed in (not a module-level global)
    for batch in tqdm(dataloader, desc="Training"):
        # Extract data from the batch
        images = batch['processed_gifs'].to(device)
        question_features = batch['question_embeddings'].to(device)
        answers = batch['answer'].to(device)

        optimizer.zero_grad()  # Reset gradients

        # Forward pass: compute logits from the model
        logits, _ = model(images, question_features)

        # Truncate whichever of logits/answers is longer along dimension 1 so
        # that both cover the same number of positions
        if logits.size(1) < answers.size(1):
            answers = answers[:, :logits.size(1)]
        elif logits.size(1) > answers.size(1):
            logits = logits[:, :answers.size(1)]

        # Flatten to (positions, classes) and (positions,) for the loss
        logits_reshaped = logits.contiguous().view(-1, logits.size(-1))
        answers_reshaped = answers.contiguous().view(-1)

        # Compute loss
        loss = criterion(logits_reshaped, answers_reshaped)
        loss.backward()  # Backpropagate loss
        optimizer.step()  # Update model parameters

        running_loss += loss.item()  # Accumulate loss

    # Calculate average loss over the epoch, guarding against an empty loader
    num_batches = len(dataloader)
    avg_loss = running_loss / num_batches if num_batches else 0.0
    return avg_loss


def validate(model, dataloader, criterion, tokenizer, device):
    """
    Validates the model on the validation dataset.

    Besides the loss, the decoded reference and generated answers are embedded
    with the module-level ``val_model`` and compared by cosine similarity.

    Args:
        model (torch.nn.Module): The VQA model to validate. Called as
            ``model(images, question_features)`` and expected to return a
            ``(logits, generated_sequence)`` pair.
        dataloader (DataLoader): DataLoader for the validation dataset.
        criterion (torch.nn.Module): Loss function.
        tokenizer (GPT2Tokenizer): Tokenizer for decoding answers.
        device (torch.device): The device to perform computations on.

    Returns:
        tuple:
            float: The average validation loss (0.0 if the loader is empty).
            float: The average cosine similarity between reference and generated
                answers (0 if nothing was scored).
    """
    model.eval()  # Set the model to evaluation mode
    running_loss = 0.0  # Initialize running loss
    similarity_scores = []  # List to store cosine similarity scores

    with torch.no_grad():  # Disable gradient computation
        # Iterate over batches in the validation DataLoader with a progress bar
        for batch in tqdm(dataloader, desc="Validating"):
            # Extract data from the batch
            images = batch['processed_gifs'].to(device)
            question_features = batch['question_embeddings'].to(device)
            answer_tokens = batch['answer'].to(device)

            # Forward pass: compute logits and generated sequences from the model
            logits, gen_seq = model(images, question_features)

            # Decode reference answers and generated sequences to strings
            # (done before any truncation below, so full answers are compared)
            reference_answers = [tokenizer.decode(a, skip_special_tokens=True) for a in answer_tokens]
            generated_answers = [tokenizer.decode(g, skip_special_tokens=True) for g in gen_seq]

            # Encode reference and generated answers using the validation SentenceTransformer model
            reference_embeddings = val_model.encode(reference_answers, convert_to_tensor=False)
            model_embeddings = val_model.encode(generated_answers, convert_to_tensor=False)

            # Compute cosine similarity for each pair of reference and generated embeddings
            for ref_emb, mod_emb in zip(reference_embeddings, model_embeddings):
                similarity = cosine_similarity([ref_emb], [mod_emb])[0][0]
                similarity_scores.append(similarity)

            # Truncate whichever of logits/answers is longer along dimension 1
            # so that both cover the same number of positions
            if logits.size(1) < answer_tokens.size(1):
                answer_tokens = answer_tokens[:, :logits.size(1)]
            elif logits.size(1) > answer_tokens.size(1):
                logits = logits[:, :answer_tokens.size(1)]

            # Flatten to (positions, classes) and (positions,) for the loss
            logits_reshaped = logits.contiguous().view(-1, logits.size(-1))
            answers_reshaped = answer_tokens.contiguous().view(-1)

            # Compute loss
            loss = criterion(logits_reshaped, answers_reshaped)
            running_loss += loss.item()  # Accumulate loss

    # Calculate average loss, guarding against an empty loader in the same way
    # the similarity average below already does
    num_batches = len(dataloader)
    avg_loss = running_loss / num_batches if num_batches else 0.0

    # Calculate average cosine similarity, handling the case of no similarity scores
    average_similarity = sum(similarity_scores) / len(similarity_scores) if similarity_scores else 0

    return avg_loss, average_similarity

# -----------------------------------------------------------------------------------
# Training Loop
# -----------------------------------------------------------------------------------

# Initialize the best validation loss to infinity for tracking improvements
best_val_loss = float('inf')

# Iterate over the number of epochs
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")

    # Train the model for one epoch and retrieve the average training loss
    train_loss = train(model, train_loader, criterion, optimizer, device)

    # Validate the model. The second return value is the mean cosine similarity
    # between reference and generated answers, not a percentage accuracy; the
    # variable keeps its original name so other cells can still reference it.
    val_loss, val_accuracy = validate(model, val_loader, criterion, tokenizer, device)

    # Display the training and validation metrics
    print(f"Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}, Val similarity: {val_accuracy:.4f}")

    # Save the model if the validation loss has improved
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), MODEL_PATH)
        # Report the 1-based epoch number, matching the "Epoch x/y" header above
        print(f"Model {epoch+1} saved to {MODEL_PATH}")

    # Save intermediate models at specific epochs (epoch is 0-based here)
    if epoch == 1:
        torch.save(model.state_dict(), MODEL_PATH_2)
        print(f"Model saved to {MODEL_PATH_2}")

    # NOTE(review): this branch only runs when EPOCHS is at least 3
    if epoch == 2:
        torch.save(model.state_dict(), MODEL_PATH_3)
        print(f"Model saved to {MODEL_PATH_3}")

# Indicate that training has completed
print("Training complete.")
