In [1]:
# Reference paper for this implementation: https://arxiv.org/pdf/2404.00308

In [2]:
# from google.colab import drive
# drive.mount('/content/drive')

In [1]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
!pip install evaluate
!pip install faiss-gpu
!python -m pip install scikit-image
!pip install pillow
!pip install wandb
!pip install git+https://github.com/openai/CLIP.git
!pip install peft
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git

In [4]:
import sys

# Kaggle input directory expected to contain the `retriever` module imported in a later cell.
RETRIEVER_DIR = '/kaggle/input/yesr/pytorch/default/1'
# Guard keeps the cell idempotent: re-running it must not append duplicate sys.path entries.
if RETRIEVER_DIR not in sys.path:
    sys.path.append(RETRIEVER_DIR)

In [None]:
# Display the import search path to confirm the retriever directory is on it.
list(sys.path)

In [6]:
from retriever import GIFFrameRetriever

# Frame retriever; GQADataset calls its retrieve_images_from_gif(url, clip model, preprocess, question, nf).
clip_cut = GIFFrameRetriever()

In [None]:
import clip
import torch

# CLIP ViT-B/32 model and its preprocess transform, handed to the frame retriever.
# Fall back to CPU (as `device` does later) instead of hard-failing when no GPU is present.
clipm, clipp = clip.load("ViT-B/32", device="cuda" if torch.cuda.is_available() else "cpu")

In [8]:
import math
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import CLIPImageProcessor, GPT2Tokenizer, CLIPProcessor, CLIPModel, GPT2LMHeadModel, Blip2QFormerConfig, Blip2QFormerModel
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from tqdm import tqdm
from sklearn.metrics import accuracy_score, roc_auc_score
from nltk.translate.bleu_score import sentence_bleu
from evaluate import load
import collections
from torch.cuda.amp import autocast
import numpy as np
from nltk.tokenize import TreebankWordTokenizer
from peft import LoraConfig, get_peft_model, TaskType

In [9]:
# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [10]:
class GQADataset(torch.utils.data.Dataset):
    """Question-answering dataset over GIFs: CLIP-preprocessed frames plus GPT-2 prompt tokens.

    Each item retrieves ``nf`` frames from the GIF at ``final_data['url'][idx]`` via
    ``clip_cut.retrieve_images_from_gif`` (given the question), preprocesses them with the
    CLIP image processor and builds the prompt described in :meth:`pad_sequences`.

    Args:
        final_data: table with 'question', 'url' and 'answer' columns, indexed by position
            0..len-1 (assumed to be a DataFrame with a RangeIndex — confirm).
        clip_cut: frame retriever exposing ``retrieve_images_from_gif``.
        nf: number of frames per sample == number of visual placeholder tokens.
        train_setting: if True, append the padded answer and EOS (teacher forcing).
        clip_model, clip_preprocess: CLIP model / preprocess passed to the retriever;
            default to the notebook-level ``clipm`` / ``clipp``.
    """

    def __init__(self, final_data, clip_cut, nf, train_setting=True, clip_model=None, clip_preprocess=None):
        self.questions = final_data['question']
        self.image_urls = final_data['url']
        self.answers = final_data['answer']
        self.clip_cut = clip_cut
        # Fall back to the notebook globals so existing call sites keep working.
        self.clipm = clip_model if clip_model is not None else clipm
        self.clipp = clip_preprocess if clip_preprocess is not None else clipp
        self.processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        self.train_setting = train_setting
        self.nf = nf

    def __len__(self):
        return len(self.image_urls)

    def __getitem__(self, idx):
        """Return a dict with 'images', 'tokens', 'mask', 'c_len', 't_len', 'answers', 'questions'."""
        question = self.questions[idx]
        answer = self.answers[idx]
        url = self.image_urls[idx]

        # Use the dataset's own frame count, not a notebook-level `nf` variable.
        keyframes = self.clip_cut.retrieve_images_from_gif(url, self.clipm, self.clipp, question, self.nf)

        processed_keyframes = torch.stack([
            self.processor(frame, return_tensors='pt')['pixel_values'].squeeze(0)
            for frame in keyframes
        ])

        tokens, mask, c_len, t_len = self.pad_sequences(question, answer, self.nf)

        return {
            'images': processed_keyframes,
            'tokens': tokens.long(),
            'mask': mask.long(),
            'c_len': c_len,
            't_len': t_len,
            'answers': answer,
            'questions': question,
        }

    def pad_sequences(self, questions, answers, nf):
        """Build prompt token ids and attention mask for one sample.

        Layout (ids are float here; ``__getitem__`` casts them to long):
            ' context:' | nf x id 1 (visual placeholders) | 'question: ' | question (<= 32 tokens)
            | 'answer ' | [answer + '<|endoftext|>' + zero padding]   <- train_setting only

        Args:
            questions: question text (converted with ``str``).
            answers: answer text (converted with ``str``); unused when ``train_setting`` is False.
            nf: number of visual placeholder tokens.

        Returns:
            (tokens, mask, c_len, t_len): ``c_len`` is the index of the first visual placeholder;
            ``t_len`` is the index of the first 'answer ' prompt token.
        """
        encode = self.tokenizer.encode
        question_prefix_ids = torch.tensor(encode('question: '))
        context_ids = torch.tensor(encode(' context:'))
        answer_prefix_ids = torch.tensor(encode('answer '))
        eos_ids = torch.tensor(encode('<|endoftext|>'))

        question_prefix_mask = torch.ones(len(question_prefix_ids))
        context_mask = torch.ones(len(context_ids))
        answer_prefix_mask = torch.ones(len(answer_prefix_ids))
        eos_mask = torch.zeros(len(eos_ids))

        c_len = context_ids.size(0)
        q = torch.tensor(encode(str(questions)))

        if self.train_setting:
            a = torch.tensor(encode(str(answers)))
            q, q_mask, leftover_tokens = self.make_padding(32, q, question=True)
            t_len = question_prefix_ids.size(0) + nf + q.size(0) + context_ids.size(0)
            a, a_mask, _ = self.make_padding(32, a, leftover_tokens=leftover_tokens)

            # Insert EOS right after the real answer tokens, i.e. before the zero padding.
            # NOTE(review): id 0 is also a real GPT-2 token; an answer containing it would
            # get EOS inserted at that position.
            zero_positions = (a == 0).nonzero()
            if len(zero_positions) == 0:
                a = torch.cat((a, eos_ids))
            else:
                pad_start = int(zero_positions[0])
                a = torch.cat((a[:pad_start], eos_ids, a[pad_start:]))

            tokens = torch.cat((context_ids, torch.ones(nf), question_prefix_ids, q, answer_prefix_ids, a))
            # NOTE(review): answer and EOS positions get attention-mask 0, so no position can
            # attend to earlier answer tokens during training — confirm this is intended.
            mask = torch.cat((context_mask, torch.ones(nf), question_prefix_mask, q_mask,
                              answer_prefix_mask, a_mask, eos_mask))
            return tokens, mask, c_len, t_len

        q, q_mask, _ = self.make_padding_test_setting(32, q)
        t_len = question_prefix_ids.size(0) + nf + q.size(0) + context_ids.size(0)
        tokens = torch.cat((context_ids, torch.ones(nf), question_prefix_ids, q, answer_prefix_ids))
        mask = torch.cat((context_mask, torch.ones(nf), question_prefix_mask, q_mask, answer_prefix_mask))
        return tokens, mask, c_len, t_len

    def make_padding(self, max_len, tokens, question=False, leftover_tokens=0):
        """Truncate/pad question or answer ids for training.

        question=True: truncate to ``max_len``, never pad; all-ones mask. If the question is
        shorter than ``max_len`` the unused budget is returned as ``leftover_tokens``.
        question=False (answer): truncate to ``max_len`` and append zeros so the result has
        length ``max_len + leftover_tokens``; the mask is all zeros.

        Returns:
            (tokens, mask, leftover_tokens)
        """
        padding = max_len - tokens.size(0)
        if padding > 0:
            if question:
                leftover_tokens = padding
                mask = torch.ones(tokens.size(0))
            else:
                tokens = torch.cat((tokens, torch.zeros(padding + leftover_tokens)))
                mask = torch.zeros(max_len + leftover_tokens)

        elif padding == 0:
            if question:
                mask = torch.ones(tokens.size(0))
            else:
                mask = torch.zeros(tokens.size(0) + leftover_tokens)
                tokens = torch.cat((tokens, torch.zeros(leftover_tokens)))

        elif padding < 0:
            if question:
                tokens = tokens[:max_len]
                mask = torch.ones(max_len)
            else:
                tokens = torch.cat((tokens[:max_len], torch.zeros(leftover_tokens)))
                mask = torch.zeros(max_len + leftover_tokens)

        return tokens, mask, leftover_tokens

    def make_padding_test_setting(self, max_len, tokens, do_padding=False):
        """Truncate question ids to ``max_len``; zero-pad up to ``max_len`` only if ``do_padding``.

        Returns:
            (tokens, mask, padding_len): mask is 1 for real tokens and 0 for padding;
            padding_len is the number of zeros appended (0 unless ``do_padding``).
        """
        padding = max_len - tokens.size(0)
        padding_len = 0
        if padding > 0:
            if do_padding:
                mask = torch.cat((torch.ones(tokens.size(0)), torch.zeros(padding)))
                tokens = torch.cat((tokens, torch.zeros(padding)))
                padding_len = padding
            else:
                mask = torch.ones(tokens.size(0))
        elif padding == 0:
            mask = torch.ones(max_len)
        elif padding < 0:
            tokens = tokens[:max_len]
            mask = torch.ones(max_len)
        return tokens, mask, padding_len

In [11]:
class MLP(nn.Module):
    """Fully connected stack with a ReLU between consecutive Linear layers.

    Args:
        sizes: layer widths ``(in, hidden_1, ..., out)``. ``len(sizes) - 1`` Linear layers are
            created; no activation follows the final layer.
    """

    def __init__(self, sizes):
        super().__init__()
        widths = list(sizes)
        stack = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            # Every Linear except the first is preceded by a ReLU.
            if depth > 0:
                stack.append(nn.ReLU())
            stack.append(nn.Linear(fan_in, fan_out))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the stack; the last dimension of ``x`` must equal ``sizes[0]``."""
        return self.model(x)

In [12]:
# LoRA adapter settings for the GPT-2 causal LM: rank 8, lora_alpha 32, lora_dropout 0.1.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)

In [13]:
class GQAModel(nn.Module):
    """CLIP + BLIP-2 Q-Former + PEFT-wrapped GPT-2 model for question answering over GIF frames.

    Pipeline as implemented:
      1. Frozen CLIP encodes every frame; ``image_proj`` maps CLIP's image embedding to the
         Q-Former's ``encoder_hidden_size``.
      2. A frozen Q-Former with trainable ``query_tokens`` (``nf`` queries) cross-attends to the
         per-frame features.
      3. ``llm_proj`` maps the Q-Former outputs to GPT-2's hidden size; they overwrite the
         ``nf`` placeholder embeddings starting at ``c_len`` in the prompt.
      4. ``forward`` additionally builds a randomly masked copy of the visual tokens for the
         masked-visual-modelling objective.

    Args:
        peft_config: adapter config applied to GPT-2 with ``get_peft_model``.
        nf: number of frames / visual tokens; also the Q-Former query count.
        num_query_token: unused, kept for interface compatibility (the query count is ``nf``
            because each query output fills one visual placeholder slot).
        cross_attention_freq: unused, kept for interface compatibility.
    """

    def __init__(self, peft_config, nf, num_query_token=3, cross_attention_freq=2):
        super().__init__()
        # Frozen CLIP image encoder.
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        for param in self.clip.parameters():
            param.requires_grad = False

        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.llm = GPT2LMHeadModel.from_pretrained("gpt2")
        self.peft_config = peft_config
        self.llm = get_peft_model(self.llm, self.peft_config)
        self.nf = nf

        self.clip_vision_output_dim = self.clip.config.vision_config.hidden_size

        # NOTE(review): `mapper` and `dropout` are not used by forward/generate; they are kept so
        # state_dict keys (and saved checkpoints) stay unchanged.
        self.mapper = MLP(sizes=(self.clip_vision_output_dim, 768, 768))
        self.dropout = nn.Dropout(0.1)
        self.Qformer, self.query_tokens = self.init_Qformer(nf)
        for param in self.Qformer.parameters():
            param.requires_grad = False

        self.llm_proj = nn.Linear(self.Qformer.config.hidden_size, self.llm.config.hidden_size)
        # get_image_features returns CLIP's projected embedding of size projection_dim
        # (512 for clip-vit-base-patch32) rather than a hard-coded width.
        self.image_proj = nn.Linear(self.clip.config.projection_dim, self.Qformer.config.encoder_hidden_size)

    def init_Qformer(self, num_query_token=3):
        """Load the BLIP-2 Q-Former and create ``num_query_token`` freshly initialised query embeddings.

        NOTE(review): confirm the Q-Former weights really load from the full BLIP-2 checkpoint
        (check the transformers log for newly initialised weights) — it is frozen afterwards.
        """
        qformer_config = Blip2QFormerConfig.from_pretrained("Salesforce/blip2-opt-2.7b")
        qformer_config.query_length = num_query_token

        Qformer = Blip2QFormerModel.from_pretrained("Salesforce/blip2-opt-2.7b", config=qformer_config)
        query_tokens = nn.Parameter(torch.zeros(1, num_query_token, qformer_config.hidden_size))
        query_tokens.data.normal_(mean=0.0, std=qformer_config.initializer_range)

        return Qformer, query_tokens

    def _encode_visual_tokens(self, images):
        """Map frames of shape (batch, num_frames, C, H, W) to LLM-space visual tokens.

        Returns ``llm_proj`` applied to the Q-Former's last hidden state (one row per query token).
        """
        batch_size, num_frames, c, h, w = images.size()
        image_features = self.clip.get_image_features(images.reshape(-1, c, h, w))
        image_features = self.image_proj(image_features).view(batch_size, num_frames, -1)

        image_atts = torch.ones(image_features.size()[:-1], dtype=torch.long, device=images.device)
        query_tokens = self.query_tokens.expand(batch_size, -1, -1)
        query_output = self.Qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_features,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        return self.llm_proj(query_output.last_hidden_state)

    def _insert_visual_tokens(self, embeds, visual, c_len):
        """Overwrite ``embeds[b, c_len[b]:c_len[b] + nf]`` with ``visual[b]`` in place and return ``embeds``."""
        for b in range(embeds.shape[0]):
            start = int(c_len[b])
            embeds[b, start:start + self.nf, :] = visual[b]
        return embeds

    def forward(self, images, tokens, mask, c_len):
        """Run the LLM on the prompt with masked and with unmasked visual tokens.

        Returns:
            (masked LLM output, last hidden state of the masked pass, last hidden state of the
            unmasked pass, dynamic_mask (True = feature kept), rho (drop probability)).
        """
        V = self._encode_visual_tokens(images)

        # rho ~ N(0.5, 0.1^2) clipped to [0.3, 0.7]; each feature is zeroed with probability rho.
        sigma = 0.1
        rho = 0.5 + sigma * torch.randn(1).item()
        rho = min(max(rho, 0.3), 0.7)
        dynamic_mask = torch.rand(V.shape, device=V.device) > rho
        V_masked = V * dynamic_mask.float()

        caption_features = self.llm.transformer.wte(tokens)
        # Separate copies: the in-place fills must not share storage, otherwise the second fill
        # overwrites the first and both LLM passes receive the same (unmasked) input.
        caption_features_unmasked = self._insert_visual_tokens(caption_features.clone(), V, c_len)
        caption_features_masked = self._insert_visual_tokens(caption_features.clone(), V_masked, c_len)

        llm_output_unmasked = self.llm(inputs_embeds=caption_features_unmasked, attention_mask=mask, output_hidden_states=True)
        llm_output_masked = self.llm(inputs_embeds=caption_features_masked, attention_mask=mask, output_hidden_states=True)

        masked_out = llm_output_masked.hidden_states[-1]
        unmasked_out = llm_output_unmasked.hidden_states[-1]

        return llm_output_masked, masked_out, unmasked_out, dynamic_mask, rho

    def generate(self, images, tokens, mask, c_len):
        """Return prompt input embeddings with the (unmasked) visual tokens inserted.

        ``mask`` is accepted for interface compatibility but not used; decoding is left to the caller.
        """
        V = self._encode_visual_tokens(images)
        embedding = self.llm.transformer.wte(tokens)
        return self._insert_visual_tokens(embedding, V, c_len)

In [14]:
DATA_CSV = '/kaggle/input/dataset/data.csv'
df = pd.read_csv(DATA_CSV)

In [None]:
# Peek at the first rows to check the 'question', 'url' and 'answer' columns.
df.head(5)

In [16]:
# Small random subsample for quick experiments. The fixed seed draws the same rows on every
# run; drop=True avoids carrying the old row labels along as an extra 'index' column.
TRAIN_FRACTION = 0.0068
SAMPLE_SEED = 42
train_data = df.sample(frac=TRAIN_FRACTION, random_state=SAMPLE_SEED).reset_index(drop=True)

In [None]:
# --- Hyper-parameters ---
batch_size = 16
epochs = 10
learning_rate = 1e-3
nf = 5  # frames per GIF == visual placeholder tokens in the prompt

# --- Data ---
train_dataset = GQADataset(train_data, clip_cut, nf, train_setting=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# --- Model, optimiser, LR schedule ---
model = GQAModel(peft_config, nf).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=2,
    verbose=True,
)
MODEL_PATH = '/kaggle/working//best_model.pth'

In [113]:
# NOTE: not referenced by train(), which calls F.cross_entropy directly.
criterion = nn.CrossEntropyLoss(reduction='mean')

In [114]:
def calculate_mvm_loss(masked_out, unmasked_out, dynamic_mask, c_len, rho, num_frames):
    """Masked-visual-modelling loss between hidden states of the masked and unmasked LLM passes.

    For each sample ``b`` this takes the ``num_frames`` visual-token positions starting at
    ``c_len[b]`` from both hidden-state tensors. It sums the squared error over the entries
    where ``dynamic_mask[b]`` is False (the features zeroed in the masked input). It divides
    that sum by ``(1 - rho) * K * T`` (K visual positions, T hidden size) and averages over
    the batch.

    NOTE(review): GQAModel.forward zeroes a feature with probability ``rho``, so the expected
    number of selected entries is ``rho * K * T`` rather than ``(1 - rho) * K * T`` — confirm
    the intended normalisation.

    Returns:
        Scalar loss tensor.
    """
    batch_size = masked_out.shape[0]
    mse_loss = 0

    for b in range(batch_size):
        start = c_len[b]
        masked_visual = masked_out[b, start:start + num_frames]
        unmasked_visual = unmasked_out[b, start:start + num_frames]

        dropped = ~dynamic_mask[b]
        K = masked_visual.size(0)
        T = masked_visual.size(1)

        # No per-sample .item() bookkeeping here: it forced a device sync and was never used.
        mse_loss += F.mse_loss(masked_visual[dropped], unmasked_visual[dropped], reduction='sum') / ((1 - rho) * K * T)

    return mse_loss / batch_size

In [None]:
import wandb

# Start a Weights & Biases run and record the main hyper-parameters with it.
run_config = {
    "epochs": epochs,
    "batch_size": batch_size,
    "learning_rate": learning_rate,
}
wandb.init(project="first run", config=run_config)

In [116]:
# If wandb asks for an API key, copy it from https://wandb.ai/authorize

In [117]:
# import shutil

def _answer_lm_loss(logits, tokens, t_len):
    """Next-token cross-entropy over the answer span of every sample in a batch.

    For sample ``b``, ``logits[b, t_len[b]:-1]`` predict ``tokens[b, t_len[b] + 1:]``;
    target id 0 (the padding id used by GQADataset) is ignored.

    Returns:
        (batch mean of the per-sample mean losses, summed token NLL as a float,
         number of non-ignored target tokens).
    """
    vocab_size = logits.size(-1)
    per_sample_losses = []
    nll_sum = 0.0
    n_tokens = 0
    for b in range(logits.size(0)):
        start = int(t_len[b])
        targets = tokens[b, start + 1:]
        sample_loss = F.cross_entropy(
            logits[b, start:-1].reshape(-1, vocab_size), targets.flatten(), ignore_index=0
        )
        count = int((targets != 0).sum())
        per_sample_losses.append(sample_loss)
        if count:
            nll_sum += sample_loss.item() * count
            n_tokens += count
    return torch.stack(per_sample_losses).mean(), nll_sum, n_tokens


def train(model, dataloader, optimizer, scheduler, device, num_epochs=None):
    """Train with answer LM loss + masked-visual-modelling loss; checkpoint and log every epoch.

    Args:
        model: GQAModel (must expose ``nf``); its forward returns
            (llm_output, masked_out, unmasked_out, dynamic_mask, rho).
        dataloader: yields dicts with 'images', 'tokens', 'mask', 'c_len', 't_len'.
        optimizer: optimiser over the model parameters.
        scheduler: ReduceLROnPlateau, stepped once per epoch on the average training loss.
        device: device every batch tensor is moved to.
        num_epochs: number of epochs; defaults to the notebook-level ``epochs``.

    Side effects: saves ``checkpoint_{epoch}.pt`` in the working directory and logs
    ``avg_loss``, ``avg_perp`` and ``epoch`` to wandb after every epoch.
    """
    n_epochs = epochs if num_epochs is None else num_epochs
    nf = model.nf

    for epoch in range(n_epochs):
        model.train()
        running_loss = 0.0
        nll_sum = 0.0
        total_tokens = 0
        progress = tqdm(dataloader, desc="Training")
        for batch in progress:
            images = batch['images'].to(device)
            tokens = batch['tokens'].to(device)
            mask = batch['mask'].to(device)
            c_len = batch['c_len'].to(device)
            t_len = batch['t_len'].to(device)

            optimizer.zero_grad()
            llm_output, masked_out, unmasked_out, dynamic_mask, rho = model(images, tokens, mask, c_len)

            mvm_loss = calculate_mvm_loss(masked_out, unmasked_out, dynamic_mask, c_len, rho, nf)
            lm_loss, batch_nll, batch_tokens = _answer_lm_loss(llm_output.logits, tokens, t_len)
            loss = lm_loss + mvm_loss

            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            running_loss += loss_value
            nll_sum += batch_nll
            total_tokens += batch_tokens
            progress.set_postfix(loss=f"{loss_value:.4f}")

        print(f"Train loss: {running_loss:.4f}")
        avg_loss = running_loss / max(len(dataloader), 1)
        # Token-level perplexity of the answer LM loss only (the MVM term is not a likelihood).
        perplexity = math.exp(nll_sum / total_tokens) if total_tokens else float('nan')
        # ReduceLROnPlateau expects one metric per epoch; stepping on noisy per-batch losses
        # with patience=2 can cut the LR after only a few batches.
        scheduler.step(avg_loss)

        checkpoint_path = f"checkpoint_{epoch}.pt"
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': avg_loss,
        }, checkpoint_path)
        wandb.log({"avg_loss": avg_loss, 'avg_perp': perplexity, "epoch": epoch + 1})

In [None]:
train(
    model=model,
    dataloader=train_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
)

In [None]:

# Persist the final weights (full state_dict, including the frozen CLIP / Q-Former parts).
torch.save(model.state_dict(), MODEL_PATH)
print(f"Model saved to {MODEL_PATH}\nTraining complete.")