In [None]:
pip install --quiet av sentence-transformers

In [None]:
import os
import av
import numpy as np
import pandas as pd
import pickle
import torch
import requests
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import AutoImageProcessor, TimesformerModel
from transformers import GPT2Tokenizer

class VQAGIFDataset(Dataset):
    """Dataset of (GIF, question, answer) rows backed by a pickled DataFrame.

    For each row, ``__getitem__`` downloads the GIF whose URL is in column 0,
    samples frames from it and returns a dict with:
      - ``'processed_gifs'``: pixel values from the image processor,
      - ``'question_embeddings'``: the precomputed question features (column 4),
      - ``'answer'``: GPT-2 token ids of the answer text (column 3) followed by
        ``" <END>"``, padded/truncated to 37 tokens.
    """

    # Default scratch directory for downloaded GIFs; overridable per instance.
    download_folder = "/kaggle/working/gifs"

    def __init__(self, pickle_dataset, download_folder=None):
        """
        Args:
            pickle_dataset (str): Path to the pickle file with urls, questions,
                answers and question text embeddings.
            download_folder (str, optional): Directory where GIFs are downloaded
                temporarily. Defaults to ``/kaggle/working/gifs``.
        """
        # NOTE: pd.read_pickle can execute arbitrary code; only load trusted files.
        self.tgif_frame = pd.read_pickle(pickle_dataset)
        if download_folder is not None:
            self.download_folder = download_folder

        self.preprocess = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
        self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2-medium")
        new_pad_token = "[PAD]"
        # GPT-2 has no pad token; register the same "[PAD]" token that VQAModel adds
        # so padded answer ids line up with the model's resized embedding table.
        if new_pad_token not in self.gpt2_tokenizer.get_vocab():
            self.gpt2_tokenizer.add_special_tokens({'pad_token': new_pad_token})

    def __len__(self):
        return len(self.tgif_frame)

    def download_gif(self, gif_url, output_path, timeout=30):
        """Download ``gif_url`` into ``output_path``.

        Returns:
            True when the server answers HTTP 200, False for any other status.
            Network errors (including the timeout) propagate as ``requests``
            exceptions.
        """
        with requests.get(gif_url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                return False
            with open(output_path, 'wb') as f:
                # Stream in chunks rather than buffering the whole body in memory.
                for chunk in response.iter_content(chunk_size=64 * 1024):
                    f.write(chunk)
        return True

    def read_video_pyav(self, container, indices):
        """Decode the frames at ``indices`` from a PyAV container.

        Returns one RGB frame per entry of ``indices``, in the same order.
        Repeated indices yield repeated frames. If the stream ends before an
        index is reached, the last decoded frame is used for it.

        Raises:
            ValueError: if ``indices`` is empty or no requested frame could be decoded.
        """
        wanted = {int(i) for i in indices}
        if not wanted:
            raise ValueError("No frame indices requested.")
        last_wanted = max(wanted)

        decoded = {}
        container.seek(0)
        for i, frame in enumerate(container.decode(video=0)):
            if i > last_wanted:
                break
            if i in wanted:
                decoded[i] = frame.to_ndarray(format="rgb24")

        if not decoded:
            raise ValueError("Could not decode any of the requested frames.")

        # Decoding is contiguous from frame 0, so a missing index can only lie past
        # the end of the stream; the highest decoded frame is the closest one.
        fallback = decoded[max(decoded)]
        return np.stack([decoded.get(int(i), fallback) for i in indices])

    def sample_frame_indices(self, clip_len, frame_sample_rate, seg_len):
        """Sample ``clip_len`` frame indices from a clip of ``seg_len`` frames.

        A window of ``int(clip_len * frame_sample_rate)`` frames is chosen at a
        random position. ``clip_len`` evenly spaced indices are drawn from it
        (duplicates are possible when the window is shorter than ``clip_len``).
        Short clips use a window that starts at frame 0.

        Raises:
            ValueError: if ``seg_len`` is smaller than 1.
        """
        if seg_len < 1:
            raise ValueError(f"Cannot sample frames from a clip with {seg_len} frames.")

        converted_len = max(1, int(clip_len * frame_sample_rate))
        if seg_len > converted_len:
            end_idx = np.random.randint(converted_len, seg_len)
        else:
            # Clip too short for a random window: use the whole clip.
            end_idx = seg_len
        start_idx = max(end_idx - converted_len, 0)
        indices = np.linspace(start_idx, end_idx, num=clip_len)
        indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
        return indices

    def process_gifs(self, gif_url):
        """Download a GIF, sample 8 frames from it and return them as a list of RGB arrays.

        The downloaded file is always deleted afterwards, even if decoding fails.

        Args:
            gif_url (str): URL of the GIF.

        Raises:
            RuntimeError: if the download does not return HTTP 200.
        """
        os.makedirs(self.download_folder, exist_ok=True)
        gif_path = os.path.join(self.download_folder, gif_url.split('/')[-1])

        if not self.download_gif(gif_url, gif_path):
            raise RuntimeError(f"Failed to download GIF: {gif_url}")

        try:
            with av.open(gif_path) as container:
                total_frames = container.streams.video[0].frames
                if total_frames <= 0:
                    # Frame count not reported by the container; count by decoding.
                    total_frames = sum(1 for _ in container.decode(video=0))
                indices = self.sample_frame_indices(clip_len=8, frame_sample_rate=0.5, seg_len=total_frames)
                video = self.read_video_pyav(container, indices)
        finally:
            if os.path.exists(gif_path):
                os.remove(gif_path)

        return list(video)

    def __getitem__(self, idx):
        # GIF features: download, sample frames, run the image processor.
        gif_url = self.tgif_frame.iloc[idx, 0]
        video_list = self.process_gifs(gif_url)
        inputs = self.preprocess(video_list, return_tensors="pt")
        pixel_values = inputs.get('pixel_values').squeeze(0)

        # Precomputed question features stored in the pickle.
        question_features = self.tgif_frame.iloc[idx, 4]

        # Answer tokens with a textual end marker, padded to a fixed length for batching.
        answers = self.tgif_frame.iloc[idx, 3] + " <END>"
        answer_id = self.gpt2_tokenizer(answers, return_tensors="pt", truncation=True, padding="max_length", max_length=37).input_ids.squeeze(0)

        return {'processed_gifs': pixel_values, 'question_embeddings': question_features, 'answer': answer_id}


def get_loaders(pickle_dataset='/kaggle/input/embedding-q-all/updated_final_df_with_q_embeddings.pkl', batch_size=16, split_ratio=(0.9, 0.09, 0.01), seed=None, test_dataset_path='/kaggle/working/test_dataset.pth'):
    """
    Returns training, validation, and test data loaders.

    Args:
        pickle_dataset (string): Path to the pickle file.
        batch_size (int): Batch size for DataLoader.
        split_ratio (tuple): Ratios for train, val, and test split. They should sum to 1.
        seed (int, optional): If given, makes the random split reproducible.
        test_dataset_path (str, optional): Where the test subset is saved with
            ``torch.save``; pass a falsy value to skip saving.

    Raises:
        ValueError: if ``split_ratio`` does not have three entries summing to 1.
    """
    if len(split_ratio) != 3:
        raise ValueError("split_ratio must contain exactly three values (train, val, test).")
    # Tolerance instead of exact equality: float ratios rarely sum to exactly 1.0.
    if abs(sum(split_ratio) - 1) > 1e-6:
        raise ValueError("Split ratios should sum to 1.")

    dataset = VQAGIFDataset(pickle_dataset)

    total_size = len(dataset)
    train_size = int(split_ratio[0] * total_size)
    val_size = int(split_ratio[1] * total_size)
    # The remainder goes to test so the three sizes always add up to total_size.
    test_size = total_size - train_size - val_size

    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size], **split_kwargs
    )
    if test_dataset_path:
        torch.save(test_dataset, test_dataset_path)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

In [None]:
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer

class CrossAttention(nn.Module):
    """Thin wrapper over ``nn.MultiheadAttention`` that returns only the attended output.

    ``batch_first`` is left at PyTorch's default (False), so inputs are laid out
    as (seq_len, batch, embed_dim).
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # The attribute name `attn` appears in state_dict keys; keep it stable.
        self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)

    def forward(self, query, key, value):
        """Attend from ``query`` (question features) over ``key``/``value`` (video features).

        The attention weights are discarded.
        """
        attended = self.attn(query, key, value)
        return attended[0]

class VQAModel(nn.Module):
    """TimeSformer video encoder + question cross-attention + GPT-2-medium decoder.

    The mean-pooled video embedding is attended by the projected question
    features. The result is concatenated with the question features and
    projected to GPT-2-medium's embedding width (1024). It is then fed to
    GPT-2 as ``inputs_embeds``.
    """

    def __init__(self, max_generated_length=16):
        """
        Args:
            max_generated_length (int): ``max_length`` passed to ``generate``.
                Generated sequences are right-padded to this length.
        """
        super(VQAModel, self).__init__()

        # For GIF Features
        self.gif_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
        self.Timesformer_model = TimesformerModel.from_pretrained("facebook/timesformer-base-finetuned-k400")

        # Cross attention and resizing functions
        self.cross_attention = CrossAttention(embed_dim=768, num_heads=8)
        self.question_projection = nn.Linear(1024, 768)
        self.project_down = nn.Linear(768 * 2, 1024)

        # GPT2-Medium with an added "[PAD]" token (GPT-2 ships without one).
        self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2-medium")
        new_pad_token = "[PAD]"
        if new_pad_token not in self.gpt2_tokenizer.get_vocab():
            self.gpt2_tokenizer.add_special_tokens({'pad_token': new_pad_token})
        self.gpt2_model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-medium")
        self.gpt2_model.resize_token_embeddings(len(self.gpt2_tokenizer))  # Resize embeddings to accommodate new pad token
        self.gpt2_model.config.pad_token_id = self.gpt2_tokenizer.get_vocab()[new_pad_token]

        self.max_generated_length = max_generated_length
        # NOTE(review): "<END>" is not registered as a special token, so BPE likely
        # splits it into several pieces; only the first piece is used as the stop token.
        # Computed once here instead of on every forward pass.
        self.end_token_id = self.gpt2_tokenizer.encode("<END>", add_prefix_space=True)[0]

    def forward(self, processed_gif, question_features, return_generated=None):
        """
        Args:
            processed_gif: pixel values for TimeSformer (presumably
                (batch, frames, channels, height, width) from the dataset's processor).
            question_features: question embeddings; assumed (batch, 1024) given
                ``question_projection`` and the ``unsqueeze(1)`` below — TODO confirm.
            return_generated (bool, optional): whether to run beam-search generation.
                Defaults to ``not self.training``, so training steps skip the costly,
                unused generation.

        Returns:
            (logits, generated_sequence). With 2-D question features GPT-2 sees a
            single prefix position, so ``logits`` has sequence length 1.
            ``generated_sequence`` is right-padded with the pad token to
            ``max_generated_length`` (fixed width so DataParallel can concatenate
            replica outputs). It is None when generation is skipped.
        """
        # Video embedding: mean over TimeSformer tokens -> (batch, 1, 768).
        outputs = self.Timesformer_model(processed_gif)
        image_features = outputs.last_hidden_state.mean(dim=1).unsqueeze(1)

        # Question embedding projected to 768 -> (batch, 1, 768).
        question_features = self.question_projection(question_features).unsqueeze(1)

        # nn.MultiheadAttention here expects (seq_len, batch, embed_dim).
        question_query = question_features.permute(1, 0, 2)
        image_key_value = image_features.permute(1, 0, 2)
        cross_attended_features = self.cross_attention(query=question_query, key=image_key_value, value=image_key_value)
        cross_attended_features = cross_attended_features.permute(1, 0, 2)

        # Concatenate and project down to GPT-2-medium's embedding size.
        combined_features = torch.cat((cross_attended_features, question_features), dim=-1)
        combined_features = self.project_down(combined_features)

        logits = self.gpt2_model(inputs_embeds=combined_features).logits

        if return_generated is None:
            return_generated = not self.training
        if not return_generated:
            return logits, None

        pad_token_id = self.gpt2_model.config.pad_token_id
        generated_sequence = self.gpt2_model.generate(
            inputs_embeds=combined_features,
            max_length=self.max_generated_length,
            pad_token_id=pad_token_id,
            eos_token_id=self.end_token_id,
            repetition_penalty=1.2,
            top_k=50,
            top_p=0.9,
            num_beams=5,
            early_stopping=True,
            no_repeat_ngram_size=2,
            temperature=0.7,
        )

        # Beam search may stop early; pad to a fixed width so outputs from different
        # DataParallel replicas can be concatenated along the batch dimension.
        missing = self.max_generated_length - generated_sequence.size(1)
        if missing > 0:
            generated_sequence = torch.nn.functional.pad(generated_sequence, (0, missing), value=pad_token_id)

        return logits, generated_sequence


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm 
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Tokenizer used for decoding in validation. It must register the same "[PAD]"
# token as VQAGIFDataset/VQAModel; otherwise the padded answer ids (and any pad
# ids in generated sequences) are unknown to it and cannot be decoded.
PAD_TOKEN = "[PAD]"
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2-medium")
if PAD_TOKEN not in tokenizer.get_vocab():
    tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})

BATCH_SIZE = 16
EPOCHS = 20
LEARNING_RATE = 1e-4
MODEL_PATH = '/kaggle/working/best_model.pth'
MODEL_PATH_2 = '/kaggle/working/10_model.pth'
MODEL_PATH_3 = '/kaggle/working/15_model.pth'
test_data_path = '/kaggle/working/test_data.csv'


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader, val_loader, test_loader = get_loaders(batch_size=BATCH_SIZE)
model = VQAModel().to(device)
model = torch.nn.DataParallel(model)
val_model = SentenceTransformer('all-MiniLM-L6-v2')
# Padding positions in the answers must not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

def train(model, dataloader, criterion, optimizer, device):
    """Run one training epoch over ``dataloader`` and return the mean batch loss.

    Logits and answer tokens are truncated to their common sequence length
    before computing the loss. With VQAModel's single prefix position this
    means only the first answer token is supervised — TODO confirm intended.
    Returns 0.0 for an empty dataloader.
    """
    model.train()
    running_loss = 0.0
    for batch in tqdm(dataloader, desc="Training"):
        images = batch['processed_gifs'].to(device)
        question_features = batch['question_embeddings'].to(device)
        answers = batch['answer'].to(device)

        optimizer.zero_grad()
        logits, _ = model(images, question_features)

        # Align the sequence dimensions of logits and targets.
        seq_len = min(logits.size(1), answers.size(1))
        logits = logits[:, :seq_len]
        answers = answers[:, :seq_len]

        loss = criterion(logits.reshape(-1, logits.size(-1)), answers.reshape(-1))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    num_batches = len(dataloader)
    return running_loss / num_batches if num_batches else 0.0

def validate(model, dataloader, criterion, tokenizer, val_model, device):
    """Evaluate ``model`` on ``dataloader``.

    Returns:
        (avg_loss, average_similarity). ``avg_loss`` is the mean batch loss
        (0.0 for an empty dataloader). ``average_similarity`` is the mean cosine
        similarity between ``val_model`` sentence embeddings of the reference and
        generated answers. It lies in [-1, 1], is not a percentage, and is 0 if
        nothing was generated.
    """
    model.eval()
    running_loss = 0.0
    similarity_scores = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating"):
            images = batch['processed_gifs'].to(device)
            question_features = batch['question_embeddings'].to(device)
            answer_tokens = batch['answer'].to(device)

            logits, gen_seq = model(images, question_features)

            if gen_seq is not None:
                # The dataset appends a literal " <END>" marker (not a special token,
                # so skip_special_tokens keeps it); strip it from the references.
                reference_answers = [
                    text.replace("<END>", "").strip()
                    for text in tokenizer.batch_decode(answer_tokens, skip_special_tokens=True)
                ]
                generated_answers = tokenizer.batch_decode(gen_seq, skip_special_tokens=True)

                reference_embeddings = val_model.encode(reference_answers, convert_to_tensor=False)
                model_embeddings = val_model.encode(generated_answers, convert_to_tensor=False)
                for ref_emb, mod_emb in zip(reference_embeddings, model_embeddings):
                    similarity_scores.append(cosine_similarity([ref_emb], [mod_emb])[0][0])

            # Align the sequence dimensions of logits and targets (same as in train()).
            seq_len = min(logits.size(1), answer_tokens.size(1))
            logits = logits[:, :seq_len]
            answer_tokens = answer_tokens[:, :seq_len]

            loss = criterion(logits.reshape(-1, logits.size(-1)), answer_tokens.reshape(-1))
            running_loss += loss.item()

    num_batches = len(dataloader)
    avg_loss = running_loss / num_batches if num_batches else 0.0
    average_similarity = float(sum(similarity_scores) / len(similarity_scores)) if similarity_scores else 0

    return avg_loss, average_similarity


# Early stopping: after epoch index EARLY_STOP_AFTER_EPOCH, stop once PATIENCE
# consecutive epochs fail to improve the best validation loss.
EARLY_STOP_AFTER_EPOCH = 15
PATIENCE = 5

best_val_loss = float('inf')
epochs_no_improve = 0
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    train_loss = train(model, train_loader, criterion, optimizer, device)
    val_loss, val_similarity = validate(model, val_loader, criterion, tokenizer, val_model, device)

    # validate() returns a cosine similarity in [-1, 1], not an accuracy percentage.
    print(f"Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}, Val similarity: {val_similarity:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), MODEL_PATH)
        print(f"Model {epoch} saved to {MODEL_PATH}")
    elif epoch > EARLY_STOP_AFTER_EPOCH:
        epochs_no_improve += 1
        if epochs_no_improve >= PATIENCE:
            print('Early stopping!')
            break

    # Fixed snapshots (0-based epoch indices 10 and 15).
    if epoch == 10:
        torch.save(model.state_dict(), MODEL_PATH_2)
        print(f"Model saved to {MODEL_PATH_2}")

    if epoch == 15:
        torch.save(model.state_dict(), MODEL_PATH_3)
        print(f"Model saved to {MODEL_PATH_3}")

print("Training complete.")