In [None]:
%load_ext autoreload
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
import h5py
import pytorch_lightning as pl
import wandb
from pytorch_lightning.loggers import WandbLogger
import transformers
import pickle
import json

from DataLoader import VideoQADataModule, TVQADataModule
import preprocess.msvd_text_prep as msvd_text_prep
import preprocess.tgif_frameqa_text_prep as tgif_frameqa_text_prep

%autoreload 2

## Load Data

In [None]:
# Data modules built with text_embedding_model='glove' (all currently disabled).
# NOTE: the baseline and AdamW sections below reference msvd_glove_data_module,
# msrvtt_glove_data_module and tgif_glove_data_module, so the matching line must be
# uncommented before those sections can run on a fresh kernel.
# msvd_glove_data_module = VideoQADataModule('data','msvd-qa',batch_size=32,text_embedding_model='glove',num_workers=8)
# msrvtt_glove_data_module = VideoQADataModule('data','msrvtt-qa',batch_size=32,text_embedding_model='glove',num_workers=8)
# tgif_glove_data_module = VideoQADataModule('data','tgif-qa_frameqa',batch_size=32,text_embedding_model='glove',num_workers=8)

In [None]:
# Data modules built with text_embedding_model='bert'. Only the MSRVTT-QA one is active.
# NOTE: later sections also reference msvd_bert_data_module and
# tgifqa_frameqa_bert_data_module; uncomment the matching line before running them.
# msvd_bert_data_module = VideoQADataModule('data','msvd-qa',batch_size=32,text_embedding_model='bert',num_workers=8)
msrvtt_bert_data_module = VideoQADataModule('data','msrvtt-qa',batch_size=32,text_embedding_model='bert',num_workers=8)
# tgifqa_frameqa_bert_data_module = VideoQADataModule('data','tgif-qa_frameqa',batch_size=32,text_embedding_model='bert')

## Defining the base model

In [None]:
import model.HCRN as HCRN

In [None]:
class HCRN_glove(pl.LightningModule):
    """Lightning wrapper that trains ``HCRN.HCRNNetworkGlove`` with cross-entropy loss.

    Args:
        glove_matrix: Embedding weights; converted with ``torch.FloatTensor`` and
            installed as the weight of ``linguistic_input_unit.encoder_embed``.
        lr: Learning rate handed to the optimizer.
        model_kwargs: Keyword arguments forwarded to ``HCRN.HCRNNetworkGlove``.
        optimizer: ``'Adam'`` or ``'AdamW'``. Any other value makes
            ``configure_optimizers`` raise ``ValueError``.
    """

    def __init__(self,glove_matrix,lr,model_kwargs,optimizer = 'AdamW'):
        super().__init__()

        self.lr = lr
        self.optimizer = optimizer
        self.criterion = nn.CrossEntropyLoss()

        # One accuracy metric per stage so their accumulated states never mix.
        self.train_acc = pl.metrics.Accuracy()
        self.valid_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()

        glove_matrix = torch.FloatTensor(glove_matrix)
        self.model = HCRN.HCRNNetworkGlove(**model_kwargs)
        # Replace the embedding weights with the supplied matrix without tracking gradients
        # for the assignment itself.
        with torch.no_grad():
            self.model.linguistic_input_unit.encoder_embed.weight.set_(glove_matrix)

    @staticmethod
    def _mean_loss(step_outputs, loss_key):
        """Return the per-sample mean loss over an epoch.

        Each step reports the *mean* loss of its batch (``nn.CrossEntropyLoss`` is built
        with its default reduction), so every batch loss is weighted by its batch size
        before dividing by the total number of samples.
        """
        weighted_sum = 0.0
        n_samples = 0
        for step_out in step_outputs:
            weighted_sum = weighted_sum + step_out[loss_key] * step_out['n_samples']
            n_samples += step_out['n_samples']
        return weighted_sum / n_samples

    def forward(self,ans_candidates, ans_candidates_len, video_appearance_feat, video_motion_feat, question,
                question_len):
        """Pass all inputs straight through to the wrapped HCRN network and return its output."""
        return self.model(ans_candidates, ans_candidates_len, video_appearance_feat, video_motion_feat, question,
                question_len)

    def configure_optimizers(self):
        """Build the selected optimizer plus a StepLR scheduler (step size 10, gamma 0.5).

        Raises:
            ValueError: if ``self.optimizer`` is neither ``'Adam'`` nor ``'AdamW'``.
        """
        if self.optimizer == 'Adam':
            optimizer = optim.Adam(self.parameters(), lr=self.lr)
        elif self.optimizer == 'AdamW':
            optimizer = optim.AdamW(self.parameters(), lr=self.lr)
        else:
            # Raising a plain string is itself a TypeError in Python 3; raise a real exception.
            raise ValueError(f"Optimizer not supported: {self.optimizer!r}")
        scheduler = optim.lr_scheduler.StepLR(optimizer,10,gamma=0.5)
        return [optimizer],[scheduler]

    def training_step(self,batch,batch_idx):
        """Compute the batch loss, update train accuracy and report the batch size."""
        # The first two batch entries are not used here; the third is the target and the
        # remainder is forwarded positionally to ``forward``.
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        loss = self.criterion(logits, answers)
        self.train_acc(logits,answers)
        self.log('step_loss',loss,prog_bar = True,logger=False)
        return {'loss': loss,'n_samples':len(answers)}

    def training_epoch_end(self, training_step_outputs):
        """Log the epoch's sample-weighted mean training loss and the training accuracy."""
        self.log('train_loss',self._mean_loss(training_step_outputs,'loss'),logger=True)
        self.log('train_acc',self.train_acc.compute(),logger=True)

    def validation_step(self,batch,batch_idx):
        """Compute the batch loss, update validation accuracy and report the batch size."""
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        loss = self.criterion(logits, answers)
        self.valid_acc(logits,answers)
        return {'val_loss': loss,'n_samples':len(answers)}

    def validation_epoch_end(self, val_step_outputs):
        """Log validation accuracy and the sample-weighted mean validation loss."""
        val_acc = self.valid_acc.compute()
        self.log('val_acc',val_acc,prog_bar = True,logger=True)
        self.log('val_loss',self._mean_loss(val_step_outputs,'val_loss'),logger=True)

    def test_step(self,batch,batch_idx):
        """Update the test accuracy metric with this batch's detached logits."""
        _, _, answers, *batch_input = batch
        logits = self(*batch_input).detach()
        self.test_acc(logits,answers)

    def test_epoch_end(self,test_step_outputs):
        """Print and log the accumulated test accuracy."""
        test_acc = self.test_acc.compute()
        print(f"The test accuracy is {test_acc}")
        self.log('test_acc',test_acc,logger=True)
        
        
    

## Recreating the paper's results on MSVD-QA, MSRVTT-QA, and TGIF-QA FrameQA

### MSVD-QA

In [None]:
# Baseline HCRN on MSVD-QA: GloVe embeddings, Adam optimizer.
# Requires msvd_glove_data_module, whose constructor call is commented out in "Load Data".
# NOTE(review): the other baseline sections train for 25 epochs — confirm 1 is intended.
max_epochs = 1

# Forwarded by HCRN_glove to HCRN.HCRNNetworkGlove as keyword arguments.
model_kwargs = dict(
    question_type=msvd_glove_data_module.question_type,
    vision_dim=2048,
    module_dim=512,
    word_dim=300,
    k_max_frame_level=16,
    k_max_clip_level=8,
    spl_resolution=1,
    vocab=msvd_glove_data_module.vocab,
)
model = HCRN_glove(
    glove_matrix=msvd_glove_data_module.glove_matrix,
    lr=0.0001,
    model_kwargs=model_kwargs,
    optimizer='Adam',
)

# Keep only the single checkpoint with the highest logged 'val_acc'.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    dirpath='models_checkpoints/msvd/baseline',
    filename='msvd-base-{epoch:02d}-{val_acc:.2f}',
    monitor='val_acc',
    mode='max',
    save_top_k=1,
)
wandb_logger = WandbLogger(project='video-qa-hcrn-recvis', name='HCRN-MSVD-base')

trainer = pl.Trainer(
    callbacks=[checkpoint_callback_val_acc],
    logger=wandb_logger,
    max_epochs=max_epochs,
    gpus=1,
)

In [None]:
# Train the MSVD-QA baseline model configured in the previous cell.
trainer.fit(model,msvd_glove_data_module)

In [None]:
# Run the test loop without passing a model or dataloader; results are kept in test_res.
test_res = trainer.test(verbose=False)

In [None]:
# Log the 'test_acc' entry of the first test result to the active wandb run.
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
# Upload the file at the checkpoint callback's best_model_path to wandb.
wandb.save(trainer.checkpoint_callback.best_model_path)

### MSRVTT-QA

In [None]:
# Baseline HCRN on MSRVTT-QA: GloVe embeddings, Adam optimizer.
# Requires msrvtt_glove_data_module, whose constructor call is commented out in "Load Data".
max_epochs = 25

# Forwarded by HCRN_glove to HCRN.HCRNNetworkGlove as keyword arguments.
model_kwargs = dict(
    question_type=msrvtt_glove_data_module.question_type,
    vision_dim=2048,
    module_dim=512,
    word_dim=300,
    k_max_frame_level=16,
    k_max_clip_level=8,
    spl_resolution=1,
    vocab=msrvtt_glove_data_module.vocab,
)
model = HCRN_glove(
    glove_matrix=msrvtt_glove_data_module.glove_matrix,
    lr=0.0001,
    model_kwargs=model_kwargs,
    optimizer='Adam',
)

# Keep only the single checkpoint with the highest logged 'val_acc'.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    dirpath='models_checkpoints/msrvtt/baseline',
    filename='msrvtt-base-{epoch:02d}-{val_acc:.2f}',
    monitor='val_acc',
    mode='max',
    save_top_k=1,
)
wandb_logger = WandbLogger(project='video-qa-hcrn-recvis', name='HCRN-MSRVTT-base')

trainer = pl.Trainer(
    callbacks=[checkpoint_callback_val_acc],
    logger=wandb_logger,
    max_epochs=max_epochs,
    gpus=1,
)

In [None]:
# Train the MSRVTT-QA baseline model configured in the previous cell.
trainer.fit(model,msrvtt_glove_data_module)

In [None]:
# Run the test loop (no model/dataloader passed) and log the first result's 'test_acc' to wandb.
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
# Upload the file at the checkpoint callback's best_model_path to wandb.
wandb.save(trainer.checkpoint_callback.best_model_path)

### TGIF-QA FrameQA

In [None]:
# Baseline HCRN on TGIF-QA FrameQA: GloVe embeddings, Adam optimizer.
# Requires tgif_glove_data_module, whose constructor call is commented out in "Load Data".
max_epochs = 25

# Forwarded by HCRN_glove to HCRN.HCRNNetworkGlove as keyword arguments.
model_kwargs = {
        'question_type': tgif_glove_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'word_dim': 300,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': tgif_glove_data_module.vocab
    }
model = HCRN_glove(
    glove_matrix=tgif_glove_data_module.glove_matrix,
    lr=0.0001,
    model_kwargs=model_kwargs,
    optimizer='Adam'
)

# Keep only the single checkpoint with the highest logged 'val_acc'.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/baseline',
    filename='tgif-qa_frameqa-base-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
# The run name previously said MSRVTT (copy-paste); label the run with the dataset it trains on.
wandb_logger = WandbLogger(name='HCRN-TGIF-FrameQA-base',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc]
)

In [None]:
# Train the TGIF-QA FrameQA baseline model configured in the previous cell.
trainer.fit(model,tgif_glove_data_module)

In [None]:
# Run the test loop (no model/dataloader passed) and log the first result's 'test_acc' to wandb.
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
# Upload the file at the checkpoint callback's best_model_path to wandb.
wandb.save(trainer.checkpoint_callback.best_model_path)

## Improving using AdamW

### MSVD-QA

In [None]:
# HCRN on MSVD-QA with GloVe embeddings, trained with AdamW instead of Adam.
# Requires msvd_glove_data_module, whose constructor call is commented out in "Load Data".
max_epochs = 15

# Forwarded by HCRN_glove to HCRN.HCRNNetworkGlove as keyword arguments.
model_kwargs = {
        'question_type': msvd_glove_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'word_dim': 300,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msvd_glove_data_module.vocab
    }
model = HCRN_glove(
    glove_matrix=msvd_glove_data_module.glove_matrix,
    lr=0.0001,
    model_kwargs=model_kwargs,
    optimizer='AdamW'
)

# Keep only the single checkpoint with the highest logged 'val_acc'.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msvd/baseline',
    filename='msvd-adamw-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
# Use a run name distinct from the Adam baseline ('HCRN-MSVD-base') so the two runs can
# be told apart in wandb; it mirrors the 'adamw' tag already used in the checkpoint name.
wandb_logger = WandbLogger(name='HCRN-MSVD-adamw',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc]
)

In [None]:
# Train the MSVD-QA AdamW model configured in the previous cell.
trainer.fit(model,msvd_glove_data_module)

In [None]:
# Run the test loop (no model/dataloader passed) and log the first result's 'test_acc' to wandb.
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
# Upload the file at the checkpoint callback's best_model_path to wandb.
wandb.save(trainer.checkpoint_callback.best_model_path)

### MSRVTT-QA

In [None]:
# HCRN on MSRVTT-QA with GloVe embeddings, trained with AdamW instead of Adam.
# Requires msrvtt_glove_data_module, whose constructor call is commented out in "Load Data".
max_epochs = 15

# Forwarded by HCRN_glove to HCRN.HCRNNetworkGlove as keyword arguments.
model_kwargs = {
        'question_type': msrvtt_glove_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'word_dim': 300,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msrvtt_glove_data_module.vocab
    }
model = HCRN_glove(
    glove_matrix=msrvtt_glove_data_module.glove_matrix,
    lr=0.0001,
    model_kwargs=model_kwargs,
    optimizer='AdamW'
)

# Keep only the single checkpoint with the highest logged 'val_acc'.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msrvtt/baseline',
    filename='msrvtt-adamw-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
# Use a run name distinct from the Adam baseline ('HCRN-MSRVTT-base') so the two runs can
# be told apart in wandb; it mirrors the 'adamw' tag already used in the checkpoint name.
wandb_logger = WandbLogger(name='HCRN-MSRVTT-adamw',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc]
)

In [None]:
# Train the MSRVTT-QA AdamW model configured in the previous cell.
trainer.fit(model,msrvtt_glove_data_module)

In [None]:
# Test the in-memory model on an explicitly supplied test dataloader, then log 'test_acc'.
# NOTE(review): most other test cells call trainer.test() without a model — confirm that
# evaluating the current weights (rather than a saved checkpoint) is intended here.
test_res = trainer.test(model,msrvtt_glove_data_module.test_dataloader())
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
# Upload whichever checkpoint the callback recorded as best. The file name embeds the
# epoch and val_acc of a particular run, so a hard-coded name goes stale on every re-run;
# this matches what the other sections' save cells do.
wandb.save(trainer.checkpoint_callback.best_model_path)

### TGIF-QA FrameQA

In [None]:
# HCRN on TGIF-QA FrameQA with GloVe embeddings and AdamW, starting from a saved checkpoint.
# Requires tgif_glove_data_module, whose constructor call is commented out in "Load Data".
max_epochs = 15

# Forwarded by HCRN_glove to HCRN.HCRNNetworkGlove as keyword arguments.
model_kwargs = {
        'question_type': tgif_glove_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'word_dim': 300,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': tgif_glove_data_module.vocab
    }
# NOTE(review): this depends on a checkpoint file from an earlier run being present on disk;
# the other sections build a fresh HCRN_glove instead — confirm resuming is intended.
model = HCRN_glove.load_from_checkpoint(
    'models_checkpoints/tgif-qa_frameqa/baseline/tgif-qa_frameqa-adamW-epoch=07-val_acc=0.56.ckpt',
    glove_matrix=tgif_glove_data_module.glove_matrix,
    lr=0.0001,
    model_kwargs=model_kwargs,
    optimizer='AdamW'
)

# Keep only the single checkpoint with the highest logged 'val_acc'. The file prefix used
# to say 'base' (copy-paste from the Adam baseline); use the 'adamW' tag that the
# checkpoint loaded above already carries so the two variants are not confused on disk.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/baseline',
    filename='tgif-qa_frameqa-adamW-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
# The run name previously said MSRVTT-base; label it with the actual dataset and optimizer.
wandb_logger = WandbLogger(name='HCRN-TGIF-FrameQA-adamw',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc]
)

In [None]:
# Train the TGIF-QA FrameQA AdamW model configured in the previous cell.
trainer.fit(model,tgif_glove_data_module)

In [None]:
# Test the in-memory model on an explicitly supplied test dataloader, then log 'test_acc'.
# NOTE(review): most other test cells call trainer.test() without a model — confirm that
# evaluating the current weights (rather than a saved checkpoint) is intended here.
test_res = trainer.test(model,test_dataloaders = tgif_glove_data_module.test_dataloader())
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
# Upload the file at the checkpoint callback's best_model_path to wandb.
wandb.save(trainer.checkpoint_callback.best_model_path)

## Create BERT question datasets

### MSVD-QA

In [None]:
import pandas as pd

# Convert each raw MSVD-QA question split from JSON to a tab-separated file beside it.
for _split in ('train', 'val', 'test'):
    _qa_frame = pd.read_json(f'data/msvd-qa/raw_questions/{_split}_qa.json')
    _qa_frame.to_csv(f'data/msvd-qa/raw_questions/{_split}_qa.csv', sep='\t')

In [None]:
# Build the MSVD-QA vocabulary from the raw training JSON; vocab_path names the JSON file
# in the bert_question_embedding directory that is also passed to process_questions below.
msvd_text_prep.create_vocab('data/msvd-qa/raw_questions/train_qa.json',vocab_path='data/msvd-qa/bert_question_embedding/msvd-qa_vocab_bert.json')

In [None]:
# Run the MSVD-QA question preprocessing: the three split CSVs go in, and the output
# locations for the fine-tuned model and the per-split .pt files are given explicitly.
_msvd_raw_dir = 'data/msvd-qa/raw_questions'
_msvd_bert_dir = 'data/msvd-qa/bert_question_embedding'
msvd_text_prep.process_questions(
    train_csv=f'{_msvd_raw_dir}/train_qa.csv',
    val_csv=f'{_msvd_raw_dir}/val_qa.csv',
    test_csv=f'{_msvd_raw_dir}/test_qa.csv',
    fine_tune_out_path=f'{_msvd_bert_dir}/question_finetuned_model',
    train_output=f'{_msvd_bert_dir}/msvd-qa_train_questions.pt',
    val_output=f'{_msvd_bert_dir}/msvd-qa_val_questions.pt',
    test_output=f'{_msvd_bert_dir}/msvd-qa_test_questions.pt',
    vocab_path=f'{_msvd_bert_dir}/msvd-qa_vocab_bert.json',
)

### TGIFQA FrameQA

In [None]:
# Build the TGIF-QA FrameQA vocabulary from the raw training CSV; vocab_path names the JSON
# file in the bert_question_embedding directory that is also passed to process_questions below.
tgif_frameqa_text_prep.create_vocab('data/tgif-qa_frameqa/raw_questions/train_qa.csv',vocab_path='data/tgif-qa_frameqa/bert_question_embedding/tgif-qa_frameqa_vocab_bert.json')

In [None]:
# Run the TGIF-QA FrameQA question preprocessing: the three split CSVs go in, and the output
# locations for the fine-tuned model and the per-split .pt files are given explicitly.
_tgif_raw_dir = 'data/tgif-qa_frameqa/raw_questions'
_tgif_bert_dir = 'data/tgif-qa_frameqa/bert_question_embedding'
tgif_frameqa_text_prep.process_questions(
    train_csv=f'{_tgif_raw_dir}/train_qa.csv',
    val_csv=f'{_tgif_raw_dir}/val_qa.csv',
    test_csv=f'{_tgif_raw_dir}/test_qa.csv',
    fine_tune_out_path=f'{_tgif_bert_dir}/question_finetuned_model',
    train_output=f'{_tgif_bert_dir}/tgif-qa_frameqa_train_questions.pt',
    val_output=f'{_tgif_bert_dir}/tgif-qa_frameqa_val_questions.pt',
    test_output=f'{_tgif_bert_dir}/tgif-qa_frameqa_test_questions.pt',
    vocab_path=f'{_tgif_bert_dir}/tgif-qa_frameqa_vocab_bert.json',
)

## Bert text embeddings

In [None]:
import model.HCRN as HCRN

In [None]:
class HCRNBert(pl.LightningModule):
    """Lightning wrapper that trains ``HCRN.HCRNNetworkBert`` with cross-entropy loss.

    Parameters whose name starts with ``model.linguistic_input_unit.bert`` are optimised
    separately from the rest of the network, each group with its own learning rate.

    Args:
        lr: Learning rate for every parameter outside the BERT sub-module.
        model_kwargs: Keyword arguments forwarded to ``HCRN.HCRNNetworkBert``.
        optimizer: ``'Adam'`` or ``'AdamW'`` (default, the previously hard-coded choice).
            Any other value makes ``configure_optimizers`` raise ``ValueError``.
        bert_lr: Learning rate for the BERT parameters (default ``1e-5``, the previously
            hard-coded value).
    """

    def __init__(self, lr, model_kwargs, optimizer='AdamW', bert_lr=1e-5):
        super().__init__()

        self.lr = lr
        self.bert_lr = bert_lr
        self.optimizer = optimizer
        self.criterion = nn.CrossEntropyLoss()

        # One accuracy metric per stage so their accumulated states never mix.
        self.train_acc = pl.metrics.Accuracy()
        self.valid_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()

        self.model = HCRN.HCRNNetworkBert(**model_kwargs)

        # Partition parameters by name prefix so the two groups get separate optimizers.
        self.bert_params = []
        self.rest_params = []
        for name, param in self.named_parameters():
            if name.startswith('model.linguistic_input_unit.bert'):
                self.bert_params.append(param)
            else:
                self.rest_params.append(param)

    @staticmethod
    def _mean_loss(step_outputs, loss_key):
        """Return the per-sample mean loss over an epoch.

        Each step reports the *mean* loss of its batch (``nn.CrossEntropyLoss`` is built
        with its default reduction), so every batch loss is weighted by its batch size
        before dividing by the total number of samples.
        """
        weighted_sum = 0.0
        n_samples = 0
        for step_out in step_outputs:
            weighted_sum = weighted_sum + step_out[loss_key] * step_out['n_samples']
            n_samples += step_out['n_samples']
        return weighted_sum / n_samples

    def forward(self,ans_candidates_tokens, ans_candidates_attention_mask, ans_candidates_token_type_ids, video_appearance_feat, video_motion_feat, question_tokens,question_attention_masks,question_token_type_ids):
        """Pass all inputs straight through to the wrapped HCRN network and return its output."""
        return self.model(ans_candidates_tokens, ans_candidates_attention_mask, ans_candidates_token_type_ids, video_appearance_feat, video_motion_feat, question_tokens,question_attention_masks,question_token_type_ids)

    def configure_optimizers(self):
        """Return two optimizers: one for the non-BERT parameters, one for the BERT parameters.

        Raises:
            ValueError: if ``self.optimizer`` is neither ``'Adam'`` nor ``'AdamW'``.
        """
        if self.optimizer == 'Adam':
            optimizer_cls = optim.Adam
        elif self.optimizer == 'AdamW':
            optimizer_cls = optim.AdamW
        else:
            raise ValueError(f"Optimizer not supported: {self.optimizer!r}")
        optimizer_model = optimizer_cls(self.rest_params, lr=self.lr)
        optimizer_bert = optimizer_cls(self.bert_params, lr=self.bert_lr)

        return {'optimizer': optimizer_model},{'optimizer': optimizer_bert}

    def training_step(self,batch,batch_idx,optimizer_idx):
        """Compute the batch loss, update train accuracy and report the batch size.

        ``optimizer_idx`` is accepted because two optimizers are configured; it is not
        used inside the step.
        """
        # The first two batch entries are not used here; the third is the target and the
        # remainder is forwarded positionally to ``forward``.
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        loss = self.criterion(logits, answers)
        self.train_acc(logits,answers)
        self.log('step_loss',loss,prog_bar = True,logger=False)
        return {'loss': loss,'n_samples':len(answers)}

    def training_epoch_end(self, training_step_outputs):
        """Log the epoch's sample-weighted mean training loss and the training accuracy."""
        # Only the first entry of the outputs is aggregated — presumably the step outputs
        # recorded for the first optimizer; verify against the Lightning version in use.
        self.log('train_loss',self._mean_loss(training_step_outputs[0],'loss'),logger=True)
        self.log('train_acc',self.train_acc.compute(),logger=True)

    def validation_step(self,batch,batch_idx):
        """Compute the batch loss, update validation accuracy and report the batch size."""
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        loss = self.criterion(logits, answers)
        self.valid_acc(logits,answers)
        return {'val_loss': loss,'n_samples':len(answers)}

    def validation_epoch_end(self, val_step_outputs):
        """Log validation accuracy and the sample-weighted mean validation loss."""
        val_acc = self.valid_acc.compute()
        self.log('val_acc',val_acc,prog_bar = True,logger=True)
        self.log('val_loss',self._mean_loss(val_step_outputs,'val_loss'),logger=True)

    def test_step(self,batch,batch_idx):
        """Update the test accuracy metric with this batch's logits."""
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        self.test_acc(logits,answers)

    def test_epoch_end(self,test_step_outputs):
        """Print and log the accumulated test accuracy."""
        test_acc = self.test_acc.compute()
        print(f"The test accuracy is {test_acc}")
        self.log('test_acc',test_acc,logger=True)

### No training

#### Bert uncased single layer embedding

In [None]:
# Start a wandb run for the single-layer-embedding BERT experiment.
wandb.init(project="video-qa-hcrn-recvis",name='bert fextract hf-model single layer embed')

In [None]:
# HCRN with 'bert-base-uncased' text features, mult_embedding disabled.
# Requires msvd_bert_data_module, whose constructor call is commented out in "Load Data".
# NOTE(review): the data module is MSVD-QA and the logger name says MSVD, but the checkpoint
# directory and file prefix say tgif — confirm which dataset this run is meant for.
max_epochs = 25

# Forwarded by HCRNBert to HCRN.HCRNNetworkBert as keyword arguments.
model_kwargs = dict(
    question_type=msvd_bert_data_module.question_type,
    vision_dim=2048,
    module_dim=512,
    k_max_frame_level=16,
    k_max_clip_level=8,
    spl_resolution=1,
    vocab=msvd_bert_data_module.vocab,
    transformer_path='bert-base-uncased',
    mult_embedding=False,
)
model = HCRNBert(lr=0.0001, model_kwargs=model_kwargs)

# Keep only the single checkpoint with the highest logged 'val_acc'.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    dirpath='models_checkpoints/tgif-qa_frameqa/bert',
    filename='tgif-bert-pretrained-1layer-nograd-{epoch:02d}-{val_acc:.2f}',
    monitor='val_acc',
    mode='max',
    save_top_k=1,
)
wandb_logger = WandbLogger(project='video-qa-hcrn-recvis', name='HCRN-MSVD-bert')

trainer = pl.Trainer(
    callbacks=[checkpoint_callback_val_acc],
    logger=wandb_logger,
    max_epochs=max_epochs,
    gpus=1,
)

In [None]:
# Train the model configured in the previous cell on the MSVD-QA BERT data module.
trainer.fit(model,msvd_bert_data_module)

In [None]:
# Run the test loop (no model/dataloader passed) and log the first result's 'test_acc' to wandb.
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
# Upload the file at the checkpoint callback's best_model_path to wandb.
wandb.save(trainer.checkpoint_callback.best_model_path)

#### Last 4 layers, BERT base uncased

In [None]:
# Start a wandb run for the 4-layer-embedding BERT experiment.
wandb.init(project="video-qa-hcrn-recvis",name='bert fextract hf-model 4 layer embed')

In [None]:
# HCRN with 'bert-base-uncased' text features, mult_embedding enabled.
# Requires msvd_bert_data_module, whose constructor call is commented out in "Load Data".
# NOTE(review): the data module is MSVD-QA and the logger name says MSVD, but the checkpoint
# directory and file prefix say tgif — confirm which dataset this run is meant for.
max_epochs = 25

# Forwarded by HCRNBert to HCRN.HCRNNetworkBert as keyword arguments.
model_kwargs = {
        'question_type': msvd_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msvd_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'mult_embedding': True
    }
# HCRNBert.__init__(lr, model_kwargs) has no 'optimizer' parameter, so passing one raised
# TypeError; the class already builds AdamW optimizers, so the argument is simply dropped.
model = HCRNBert(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the single checkpoint with the highest logged 'val_acc'.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/bert',
    filename='tgif-bert-pretrained-4layer-nograd-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
wandb_logger = WandbLogger(name='HCRN-MSVD-bert',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc]
)

In [None]:
# Train the model configured in the previous cell on the MSVD-QA BERT data module.
trainer.fit(model,msvd_bert_data_module)

In [None]:
# Run the test loop (no model/dataloader passed) and log the first result's 'test_acc' to wandb.
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
# Upload the file at the checkpoint callback's best_model_path to wandb.
wandb.save(trainer.checkpoint_callback.best_model_path)

#### Finetuned on questions 1 layer

In [None]:
# Start a wandb run for the question-fine-tuned, single-layer-embedding experiment.
wandb.init(project="video-qa-hcrn-recvis",name='bert fextract question-tuned-model single layer embed')

In [None]:
# HCRN with text features from the question-fine-tuned transformer, mult_embedding disabled.
# Requires msvd_bert_data_module, whose constructor call is commented out in "Load Data".
# NOTE(review): the data module is MSVD-QA, but the transformer path, checkpoint directory
# and file prefix all point at tgif-qa_frameqa — confirm which dataset this run is meant for.
max_epochs = 25

# Forwarded by HCRNBert to HCRN.HCRNNetworkBert as keyword arguments.
model_kwargs = {
        'question_type': msvd_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msvd_bert_data_module.vocab,
        'transformer_path': 'data/tgif-qa_frameqa/bert_question_embedding/question_finetuned_model',
        'mult_embedding': False
    }
# HCRNBert.__init__(lr, model_kwargs) has no 'optimizer' parameter, so passing one raised
# TypeError; the class already builds AdamW optimizers, so the argument is simply dropped.
model = HCRNBert(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the single checkpoint with the highest logged 'val_acc'.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/bert',
    filename='tgif-bert-finetuned-1layer-nograd-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
wandb_logger = WandbLogger(name='HCRN-MSVD-bert',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc]
)

In [None]:
# Train the model configured in the previous cell on the MSVD-QA BERT data module.
trainer.fit(model,msvd_bert_data_module)

In [None]:
# Run the test loop (no model/dataloader passed) and log the first result's 'test_acc' to wandb.
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
# Upload the file at the checkpoint callback's best_model_path to wandb.
wandb.save(trainer.checkpoint_callback.best_model_path)

#### Fine-tuned on questions, 4 layers

In [None]:
# Start a wandb run for the question-fine-tuned, 4-layer-embedding experiment. The name
# previously duplicated the single-layer run's name, making the two indistinguishable.
wandb.init(project="video-qa-hcrn-recvis",name='bert fextract question-tuned-model 4 layer embed')

In [None]:
# HCRN with text features from the question-fine-tuned transformer, mult_embedding enabled.
# Requires msvd_bert_data_module, whose constructor call is commented out in "Load Data".
# NOTE(review): the data module is MSVD-QA, but the transformer path, checkpoint directory
# and file prefix all point at tgif-qa_frameqa — confirm which dataset this run is meant for.
max_epochs = 25

# Forwarded by HCRNBert to HCRN.HCRNNetworkBert as keyword arguments.
model_kwargs = {
        'question_type': msvd_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msvd_bert_data_module.vocab,
        'transformer_path': 'data/tgif-qa_frameqa/bert_question_embedding/question_finetuned_model',
        'mult_embedding': True
    }
# HCRNBert.__init__(lr, model_kwargs) has no 'optimizer' parameter, so passing one raised
# TypeError; the class already builds AdamW optimizers, so the argument is simply dropped.
model = HCRNBert(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the single checkpoint with the highest logged 'val_acc'.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/bert',
    filename='tgif-bert-finetuned-4layer-nograd-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
wandb_logger = WandbLogger(name='HCRN-MSVD-bert',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc]
)

In [None]:
# Train the model configured in the previous cell on the MSVD-QA BERT data module.
trainer.fit(model,msvd_bert_data_module)

In [None]:
# Run the test loop (no model/dataloader passed) and log the first result's 'test_acc' to wandb.
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
# Upload the file at the checkpoint callback's best_model_path to wandb.
wandb.save(trainer.checkpoint_callback.best_model_path)

### Tune all MSVD

In [None]:
# Start a wandb run for the MSVD experiment that trains all BERT parameters.
wandb.init(project="video-qa-hcrn-recvis",name='msvd bert train all hf-model pooler')

In [None]:
# HCRN on MSVD-QA with 'bert-base-uncased' and train_bert='All'.
# Requires msvd_bert_data_module, whose constructor call is commented out in "Load Data".
max_epochs = 15

# Forwarded by HCRNBert to HCRN.HCRNNetworkBert as keyword arguments.
model_kwargs = {
        'question_type': msvd_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msvd_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All'
    }
model = HCRNBert(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the single checkpoint with the highest logged 'val_acc'. The file prefix used
# to say 'tgif' (copy-paste) although this run trains on MSVD and saves under msvd/bert.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msvd/bert',
    filename='msvd-bert-pretrained-1layer-train-all-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
# Record learning rates at every step.
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
wandb_logger = WandbLogger(name='HCRN-MSVD-bert',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
# Train the model configured in the previous cell on the MSVD-QA BERT data module.
trainer.fit(model,msvd_bert_data_module)

In [None]:
# Run the test loop (no model/dataloader passed) and log the first result's 'test_acc' to wandb.
test_res = trainer.test()
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
# Upload the file at the checkpoint callback's best_model_path to wandb.
wandb.save(trainer.checkpoint_callback.best_model_path)

### Tune all MSRVTT

In [None]:
# Start a wandb run for the MSRVTT experiment that trains all BERT parameters.
wandb.init(project="video-qa-hcrn-recvis",name='msrvtt bert train all hf-model pooler')

In [None]:
# HCRN on MSRVTT-QA with 'bert-base-uncased' and train_bert='All'.
max_epochs = 15

# Forwarded by HCRNBert to HCRN.HCRNNetworkBert as keyword arguments.
model_kwargs = {
        'question_type': msrvtt_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msrvtt_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All'
    }
model = HCRNBert(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the single checkpoint with the highest logged 'val_acc'. The file prefix used
# to say 'tgif' (copy-paste) although this run trains on MSRVTT and saves under msrvtt/bert.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msrvtt/bert',
    filename='msrvtt-bert-pretrained-1layer-train-all-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
# Record learning rates at every step.
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# The logger name previously said MSVD (copy-paste); label the run with its real dataset.
wandb_logger = WandbLogger(name='HCRN-MSRVTT-bert',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
# Train the model configured in the previous cell on the MSRVTT-QA BERT data module.
trainer.fit(model,msrvtt_bert_data_module)

In [None]:
# Run the test loop (no model/dataloader passed) and log the first result's 'test_acc' to wandb.
test_res = trainer.test()
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
# Upload the file at the checkpoint callback's best_model_path to wandb.
wandb.save(trainer.checkpoint_callback.best_model_path)

### Tune TGIF QA

#### Train All

In [None]:
# Start a wandb run for the TGIF-QA experiment that trains all BERT parameters.
wandb.init(project="video-qa-hcrn-recvis",name='tgif-qa bert train all hf-model pooler')

In [None]:
# HCRN on TGIF-QA FrameQA with 'bert-base-uncased' and train_bert='All'.
# Requires tgifqa_frameqa_bert_data_module, whose constructor call is commented out in "Load Data".
max_epochs = 25

# Forwarded by HCRNBert to HCRN.HCRNNetworkBert as keyword arguments.
model_kwargs = {
        'question_type': tgifqa_frameqa_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': tgifqa_frameqa_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All'
    }
model = HCRNBert(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the single checkpoint with the highest logged 'val_acc'.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/bert',
    filename='tgif-bert-pretrained-1layer-train-all-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
# Record learning rates at every step.
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# The logger name previously said MSVD (copy-paste); label the run with its real dataset.
wandb_logger = WandbLogger(name='HCRN-TGIF-bert',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
# Train the model configured in the previous cell on the TGIF-QA FrameQA BERT data module.
trainer.fit(model,tgifqa_frameqa_bert_data_module)

In [None]:
# Test using an explicitly named checkpoint file, then log the first result's 'test_acc'.
# NOTE(review): the file name embeds the epoch and val_acc of one particular run, so it will
# not match after re-training — consider trainer.checkpoint_callback.best_model_path.
test_res = trainer.test(ckpt_path='models_checkpoints/tgif-qa_frameqa/bert/tgif-bert-pretrained-1layer-train-all-epoch=08-val_acc=0.55.ckpt',verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
# Upload the file at the checkpoint callback's best_model_path to wandb.
wandb.save(trainer.checkpoint_callback.best_model_path)

#### Train 4 last layers

In [None]:
# Start a wandb run for the TGIF-QA experiment with train_bert='last-4'.
wandb.init(project="video-qa-hcrn-recvis",name='tgif-qa bert train 4 last hf-model pooler')

In [None]:
# HCRN on TGIF-QA FrameQA with 'bert-base-uncased' and train_bert='last-4'.
# Requires tgifqa_frameqa_bert_data_module, whose constructor call is commented out in "Load Data".
max_epochs = 25

# Forwarded by HCRNBert to HCRN.HCRNNetworkBert as keyword arguments.
model_kwargs = {
        'question_type': tgifqa_frameqa_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': tgifqa_frameqa_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'last-4'
    }
# HCRNBert.__init__ does not take 'nb_train_steps' or 'max_epochs'; passing them raised
# TypeError. They are dropped here; the epoch budget is still set on the Trainer below.
model = HCRNBert(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the single checkpoint with the highest logged 'val_acc'. The file name used to
# reuse the 'train-all' tag of the previous experiment; tag it with this run's setting.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/bert',
    filename='tgif-bert-pretrained-1layer-train-last4-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
# Record learning rates once per epoch.
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='epoch')
# The logger name previously said MSVD (copy-paste); label the run with its real dataset.
wandb_logger = WandbLogger(name='HCRN-TGIF-bert',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
# Train the model configured in the previous cell on the TGIF-QA FrameQA BERT data module.
trainer.fit(model,tgifqa_frameqa_bert_data_module)

In [None]:
# Run the test loop (no model/dataloader passed) and log the first result's 'test_acc' to wandb.
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
# Upload the file at the checkpoint callback's best_model_path to wandb.
wandb.save(trainer.checkpoint_callback.best_model_path)

#### Train 2 last layers

In [None]:
# Start a wandb run for this section's experiment.
# NOTE(review): the section heading says "2 last layers" but the run name says "4 last",
# identical to the previous section's run — confirm the intended name.
wandb.init(project="video-qa-hcrn-recvis",name='tgif-qa bert train 4 last hf-model pooler')

In [None]:
# HCRN on TGIF-QA FrameQA with 'bert-base-uncased' and partial BERT training.
# Requires tgifqa_frameqa_bert_data_module, whose constructor call is commented out in "Load Data".
max_epochs = 25

# Forwarded by HCRNBert to HCRN.HCRNNetworkBert as keyword arguments.
# NOTE(review): the section heading says "2 last layers" but 'train_bert' is still 'last-4',
# identical to the previous section. The values HCRNNetworkBert accepts are not visible
# here, so it is left unchanged — confirm and update.
model_kwargs = {
        'question_type': tgifqa_frameqa_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': tgifqa_frameqa_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'last-4'
    }
# HCRNBert.__init__ does not take 'nb_train_steps' or 'max_epochs'; passing them raised
# TypeError. They are dropped here; the epoch budget is still set on the Trainer below.
model = HCRNBert(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the single checkpoint with the highest logged 'val_acc'.
# NOTE(review): the file name still carries the 'train-all' tag of an earlier experiment;
# rename it once the intended train_bert setting above is settled.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/bert',
    filename='tgif-bert-pretrained-1layer-train-all-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
# Record learning rates once per epoch.
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='epoch')
# The logger name previously said MSVD (copy-paste); label the run with its real dataset.
wandb_logger = WandbLogger(name='HCRN-TGIF-bert',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
# Train the model configured in the previous cell on the TGIF-QA FrameQA BERT data module.
trainer.fit(model,tgifqa_frameqa_bert_data_module)

In [None]:
# Run the test loop (no model/dataloader passed) and log the first result's 'test_acc' to wandb.
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
# Upload the file at the checkpoint callback's best_model_path to wandb.
wandb.save(trainer.checkpoint_callback.best_model_path)

## DistilBERT Text Embedding

#### Train All

In [None]:
# Start a wandb run for the TGIF-QA experiment that uses 'distilbert-base-uncased'.
wandb.init(project="video-qa-hcrn-recvis",name='tgif-qa distillbert train all hf-model pooler')

In [None]:
# HCRN on TGIF-QA FrameQA with 'distilbert-base-uncased' and train_bert='All'.
# Requires tgifqa_frameqa_bert_data_module, whose constructor call is commented out in "Load Data".
max_epochs = 25

# Forwarded by HCRNBert to HCRN.HCRNNetworkBert as keyword arguments.
model_kwargs = {
        'question_type': tgifqa_frameqa_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': tgifqa_frameqa_bert_data_module.vocab,
        'transformer_path': 'distilbert-base-uncased',
        'train_bert': 'All'
    }
model = HCRNBert(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the single checkpoint with the highest logged 'val_acc'.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/bert',
    filename='tgif-distillbert-pretrained-1layer-train-all-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
# Record learning rates at every step.
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# The logger name previously said MSVD (copy-paste); this run trains on TGIF-QA FrameQA.
wandb_logger = WandbLogger(name='HCRN-TGIF-distillbert',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
# Train the model configured in the previous cell on the TGIF-QA FrameQA BERT data module.
trainer.fit(model,tgifqa_frameqa_bert_data_module)

In [None]:
# Run the test loop (no model/dataloader passed) and log the first result's 'test_acc' to wandb.
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
# NOTE(review): stand-alone load of the pretrained DistilBERT model; the result is only
# displayed, never assigned or used by other visible cells — likely a leftover check.
from transformers import DistilBertModel
DistilBertModel.from_pretrained('distilbert-base-uncased')

In [None]:
# Upload the file at the checkpoint callback's best_model_path to wandb.
wandb.save(trainer.checkpoint_callback.best_model_path)

## Bert Ablation

In [None]:
import model.HCRN as HCRN

In [None]:
class HCRNBertAblation(pl.LightningModule):
    """Lightning wrapper around ``HCRN.HCRNNetworkBertAblation``.

    Parameters whose name starts with ``model.linguistic_input_unit.bert`` are
    optimized with a fixed learning rate of 1e-5; every other parameter is
    optimized with ``lr``.

    Args:
        lr: learning rate for all parameters outside the BERT sub-module.
        model_kwargs: keyword arguments forwarded unchanged to
            ``HCRN.HCRNNetworkBertAblation``.
    """

    def __init__(self, lr, model_kwargs):
        super().__init__()

        self.lr = lr
        self.criterion = nn.CrossEntropyLoss()

        self.train_acc = pl.metrics.Accuracy()
        self.valid_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()

        self.model = HCRN.HCRNNetworkBertAblation(**model_kwargs)

        # Split the parameters into two groups so that each group can be
        # handed to its own optimizer in configure_optimizers.
        self.bert_params = []
        self.rest_params = []
        for name, param in self.named_parameters():
            if name.startswith('model.linguistic_input_unit.bert'):
                self.bert_params.append(param)
            else:
                self.rest_params.append(param)

    def forward(self,ans_candidates_tokens, ans_candidates_attention_mask, ans_candidates_token_type_ids, video_appearance_feat, video_motion_feat, question_tokens,question_attention_masks,question_token_type_ids):
        """Pass all inputs, in the same order, to the wrapped HCRN network and return its output."""
        return self.model(ans_candidates_tokens, ans_candidates_attention_mask, ans_candidates_token_type_ids, video_appearance_feat, video_motion_feat, question_tokens,question_attention_masks,question_token_type_ids)

    def configure_optimizers(self):
        """Return two AdamW optimizers: one for the non-BERT parameters, one for the BERT parameters."""
        # NOTE(review): with two optimizers Lightning presumably calls
        # training_step once per optimizer for every batch - confirm this is
        # intended rather than a single optimizer with two parameter groups.
        optimizer_model = optim.AdamW(self.rest_params, lr=self.lr)
        optimizer_bert = optim.AdamW(self.bert_params, lr=1e-5)
        return {'optimizer': optimizer_model},{'optimizer': optimizer_bert}

    @staticmethod
    def _sample_weighted_mean_loss(step_outputs, loss_key):
        """Average per-batch mean losses into a per-sample mean over all batches.

        ``nn.CrossEntropyLoss()`` uses its default reduction, so each stored
        loss is already a mean over its batch. Each one is therefore scaled
        back by its batch size before dividing by the total sample count.
        """
        total_loss = 0
        total_samples = 0
        for step_out in step_outputs:
            total_loss += step_out[loss_key] * step_out['n_samples']
            total_samples += step_out['n_samples']
        return total_loss / total_samples

    def training_step(self,batch,batch_idx,optimizer_idx):
        """Compute the loss for one batch and update the running training accuracy."""
        # The first two batch entries are unused here; the third holds the
        # targets and the remaining entries are the positional inputs of forward.
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        loss = self.criterion(logits, answers)
        self.train_acc(logits,answers)
        self.log('step_loss',loss,prog_bar = True,logger=False)
        return {'loss': loss,'n_samples':len(answers)}

    def training_epoch_end(self, training_step_outputs):
        """Log the epoch-level training loss and accuracy."""
        # Only the step outputs stored under index 0 are aggregated.
        loss = self._sample_weighted_mean_loss(training_step_outputs[0], 'loss')
        self.log('train_loss',loss,logger=True)
        self.log('train_acc',self.train_acc.compute(),logger=True)

    def validation_step(self,batch,batch_idx):
        """Compute the loss for one validation batch and update the running validation accuracy."""
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        loss = self.criterion(logits, answers)
        self.valid_acc(logits,answers)
        return {'val_loss': loss,'n_samples':len(answers)}

    def validation_epoch_end(self, val_step_outputs):
        """Log the epoch-level validation accuracy and loss."""
        loss = self._sample_weighted_mean_loss(val_step_outputs, 'val_loss')
        val_acc = self.valid_acc.compute()
        self.log('val_acc',val_acc,prog_bar = True,logger=True)
        self.log('val_loss',loss,logger=True)

    def test_step(self,batch,batch_idx):
        """Update the running test accuracy with one batch; nothing is returned."""
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        self.test_acc(logits,answers)

    def test_epoch_end(self,test_step_outputs):
        """Print and log the accuracy accumulated over the test batches."""
        test_acc = self.test_acc.compute()
        print(f"The test accuracy is {test_acc}")
        self.log('test_acc',test_acc,logger=True)

### MSVD

#### Motion Ablation

##### Full Ablation

In [None]:
wandb.init(project="video-qa-hcrn-recvis",name='msvd-qa bert train all hf-model pooler no-motion')

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': msvd_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msvd_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All',
        'ablated_features':['clip_motion','video_motion']
    }
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the checkpoint with the best validation accuracy. The filename names the
# dataset and the ablation so runs sharing this directory stay distinguishable.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msvd/bert',
    filename='msvd-bert-pretrained-1layer-train-all-no-motion-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# Logger run name mirrors the dataset and ablation of this experiment.
wandb_logger = WandbLogger(name='HCRN-MSVD-bert-no-motion',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
trainer.fit(model,msvd_bert_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

##### Video level Ablation

In [None]:
wandb.init(project="video-qa-hcrn-recvis",name='msvd bert train all hf-model pooler no-video-motion')

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': msvd_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msvd_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All',
        'ablated_features':['video_motion']
    }
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the checkpoint with the best validation accuracy. The filename names the
# dataset and the ablation so runs sharing this directory stay distinguishable.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msvd/bert',
    filename='msvd-bert-pretrained-1layer-train-all-no-video-motion-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# Logger run name mirrors the dataset and ablation of this experiment.
wandb_logger = WandbLogger(name='HCRN-MSVD-bert-no-video-motion',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
trainer.fit(model,msvd_bert_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

##### Clip level Ablation

In [None]:
wandb.init(project="video-qa-hcrn-recvis",name='msvd bert train all hf-model pooler no-clip-motion')

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': msvd_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msvd_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All',
        'ablated_features':['clip_motion']
    }
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the checkpoint with the best validation accuracy. The filename names the
# dataset and the ablation so runs sharing this directory stay distinguishable.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msvd/bert',
    filename='msvd-bert-pretrained-1layer-train-all-no-clip-motion-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# Logger run name mirrors the dataset and ablation of this experiment.
wandb_logger = WandbLogger(name='HCRN-MSVD-bert-no-clip-motion',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
trainer.fit(model,msvd_bert_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

#### Question Ablation

##### Full Ablation

In [None]:
wandb.init(project="video-qa-hcrn-recvis",name='msvd bert train all hf-model pooler no-question-features')

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': msvd_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msvd_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All',
        'ablated_features':['clip_question','video_question']
    }
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the checkpoint with the best validation accuracy. The filename names the
# dataset and the ablation so runs sharing this directory stay distinguishable.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msvd/bert',
    filename='msvd-bert-pretrained-1layer-train-all-no-question-features-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# Logger run name mirrors the dataset and ablation of this experiment.
wandb_logger = WandbLogger(name='HCRN-MSVD-bert-no-question-features',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
trainer.fit(model,msvd_bert_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

##### Video level Ablation

In [None]:
wandb.init(project="video-qa-hcrn-recvis",name='msvd train all hf-model pooler no-video-question-features')

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': msvd_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msvd_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All',
        'ablated_features':['video_question']
    }
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the checkpoint with the best validation accuracy. The filename names the
# dataset and the ablation so runs sharing this directory stay distinguishable.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msvd/bert',
    filename='msvd-bert-pretrained-1layer-train-all-no-video-question-features-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# Logger run name mirrors the dataset and ablation of this experiment.
wandb_logger = WandbLogger(name='HCRN-MSVD-bert-no-video-question-features',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
trainer.fit(model,msvd_bert_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

##### Clip level Ablation

In [None]:
wandb.init(project="video-qa-hcrn-recvis",name='msvd bert train all hf-model pooler no-clip-question-features')

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': msvd_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msvd_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All',
        'ablated_features':['clip_question']
    }
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the checkpoint with the best validation accuracy. The filename names the
# dataset and the ablation so runs sharing this directory stay distinguishable.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msvd/bert',
    filename='msvd-bert-pretrained-1layer-train-all-no-clip-question-features-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# Logger run name mirrors the dataset and ablation of this experiment.
wandb_logger = WandbLogger(name='HCRN-MSVD-bert-no-clip-question-features',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
trainer.fit(model,msvd_bert_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

#### No visual features

##### No Visual

In [None]:
wandb.init(project="video-qa-hcrn-recvis",name='msvd bert train all hf-model pooler no-visual-features')

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': msvd_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msvd_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All',
        'ablated_features':['clip_motion','video_motion','video_appearance_ablation']
    }
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the checkpoint with the best validation accuracy. The filename names the
# dataset and the ablation so runs sharing this directory stay distinguishable.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msvd/bert',
    filename='msvd-bert-pretrained-1layer-train-all-no-visual-features-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# Logger run name mirrors the dataset and ablation of this experiment.
wandb_logger = WandbLogger(name='HCRN-MSVD-bert-no-visual-features',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
trainer.fit(model,msvd_bert_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

##### No appearance

In [None]:
wandb.init(project="video-qa-hcrn-recvis",name='msvd bert train all hf-model pooler no-appearance-features')

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': msvd_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msvd_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All',
        'ablated_features':['video_appearance_ablation']
    }
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the checkpoint with the best validation accuracy. The filename names the
# dataset and the ablation so runs sharing this directory stay distinguishable.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msvd/bert',
    filename='msvd-bert-pretrained-1layer-train-all-no-appearance-features-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# Logger run name mirrors the dataset and ablation of this experiment.
wandb_logger = WandbLogger(name='HCRN-MSVD-bert-no-appearance-features',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
trainer.fit(model,msvd_bert_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

### MSRVTT

#### Motion Ablation

##### Full Ablation

In [None]:
wandb.init(project="video-qa-hcrn-recvis",name='msrvtt bert train all hf-model pooler no-motion')

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': msrvtt_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msrvtt_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All',
        'ablated_features':['clip_motion','video_motion']
    }
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the checkpoint with the best validation accuracy. The filename names the
# dataset and the ablation so runs sharing this directory stay distinguishable.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msrvtt/bert',
    filename='msrvtt-bert-pretrained-1layer-train-all-no-motion-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# This run trains on MSRVTT-QA, so the logger is labelled MSRVTT rather than MSVD.
wandb_logger = WandbLogger(name='HCRN-MSRVTT-bert-no-motion',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
trainer.fit(model,msrvtt_bert_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

##### Video level Ablation

In [None]:
wandb.init(project="video-qa-hcrn-recvis",name='msrvtt bert train all hf-model pooler no-video-motion')

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': msrvtt_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msrvtt_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All',
        'ablated_features':['video_motion']
    }
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the checkpoint with the best validation accuracy. The filename names the
# dataset and the ablation so runs sharing this directory stay distinguishable.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msrvtt/bert',
    filename='msrvtt-bert-pretrained-1layer-train-all-no-video-motion-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# This run trains on MSRVTT-QA, so the logger is labelled MSRVTT rather than MSVD.
wandb_logger = WandbLogger(name='HCRN-MSRVTT-bert-no-video-motion',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
trainer.fit(model,msrvtt_bert_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

##### Clip level Ablation

In [None]:
wandb.init(project="video-qa-hcrn-recvis",name='msrvtt bert train all hf-model pooler no-clip-motion')

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': msrvtt_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msrvtt_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All',
        'ablated_features':['clip_motion']
    }
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the checkpoint with the best validation accuracy. The filename names the
# dataset and the ablation so runs sharing this directory stay distinguishable.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msrvtt/bert',
    filename='msrvtt-bert-pretrained-1layer-train-all-no-clip-motion-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# This run trains on MSRVTT-QA, so the logger is labelled MSRVTT rather than MSVD.
wandb_logger = WandbLogger(name='HCRN-MSRVTT-bert-no-clip-motion',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
trainer.fit(model,msrvtt_bert_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

#### Question Ablation

##### Full Ablation

In [None]:
wandb.init(project="video-qa-hcrn-recvis",name='msrvtt bert train all hf-model pooler no-question-features')

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': msrvtt_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msrvtt_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All',
        'ablated_features':['clip_question','video_question']
    }
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the checkpoint with the best validation accuracy. The filename names the
# dataset and the ablation so runs sharing this directory stay distinguishable.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msrvtt/bert',
    filename='msrvtt-bert-pretrained-1layer-train-all-no-question-features-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# This run trains on MSRVTT-QA, so the logger is labelled MSRVTT rather than MSVD.
wandb_logger = WandbLogger(name='HCRN-MSRVTT-bert-no-question-features',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
trainer.fit(model,msrvtt_bert_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

##### Video level Ablation

In [None]:
wandb.init(project="video-qa-hcrn-recvis",name='msrvtt train all hf-model pooler no-video-question-features')

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': msrvtt_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msrvtt_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All',
        'ablated_features':['video_question']
    }
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the checkpoint with the best validation accuracy. The filename names the
# dataset and the ablation so runs sharing this directory stay distinguishable.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msrvtt/bert',
    filename='msrvtt-bert-pretrained-1layer-train-all-no-video-question-features-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# This run trains on MSRVTT-QA, so the logger is labelled MSRVTT rather than MSVD.
wandb_logger = WandbLogger(name='HCRN-MSRVTT-bert-no-video-question-features',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
trainer.fit(model,msrvtt_bert_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

##### Clip level Ablation

In [None]:
wandb.init(project="video-qa-hcrn-recvis",name='msrvtt bert train all hf-model pooler no-clip-question-features')

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': msrvtt_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msrvtt_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All',
        'ablated_features':['clip_question']
    }
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the checkpoint with the best validation accuracy. The filename names the
# dataset and the ablation so runs sharing this directory stay distinguishable.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msrvtt/bert',
    filename='msrvtt-bert-pretrained-1layer-train-all-no-clip-question-features-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# This run trains on MSRVTT-QA, so the logger is labelled MSRVTT rather than MSVD.
wandb_logger = WandbLogger(name='HCRN-MSRVTT-bert-no-clip-question-features',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
trainer.fit(model,msrvtt_bert_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

#### No visual features

##### No Visual

In [None]:
wandb.init(project="video-qa-hcrn-recvis",name='msrvtt bert train all hf-model pooler no-visual-features')

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': msrvtt_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msrvtt_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All',
        'ablated_features':['clip_motion','video_motion','video_appearance_ablation']
    }
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the checkpoint with the best validation accuracy. The filename names the
# dataset and the ablation so runs sharing this directory stay distinguishable.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msrvtt/bert',
    filename='msrvtt-bert-pretrained-1layer-train-all-no-visual-features-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# This run trains on MSRVTT-QA, so the logger is labelled MSRVTT rather than MSVD.
wandb_logger = WandbLogger(name='HCRN-MSRVTT-bert-no-visual-features',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
trainer.fit(model,msrvtt_bert_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

##### No appearance

In [None]:
wandb.init(project="video-qa-hcrn-recvis",name='msrvtt bert train all hf-model pooler no-appearance-features')

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': msrvtt_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msrvtt_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All',
        'ablated_features':['video_appearance_ablation']
    }
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the checkpoint with the best validation accuracy. The filename names the
# dataset and the ablation so runs sharing this directory stay distinguishable.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msrvtt/bert',
    filename='msrvtt-bert-pretrained-1layer-train-all-no-appearance-features-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# This run trains on MSRVTT-QA, so the logger is labelled MSRVTT rather than MSVD.
wandb_logger = WandbLogger(name='HCRN-MSRVTT-bert-no-appearance-features',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
trainer.fit(model,msrvtt_bert_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

### TGIF-QA

#### Motion Ablation

##### Full Ablation

In [None]:
wandb.init(project="video-qa-hcrn-recvis",name='tgif-qa bert train all hf-model pooler no-motion')

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': tgifqa_frameqa_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': tgifqa_frameqa_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All',
        'ablated_features':['clip_motion','video_motion']
    }
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the checkpoint with the best validation accuracy. The filename names the
# dataset and the ablation so runs sharing this directory stay distinguishable.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/bert',
    filename='tgif-bert-pretrained-1layer-train-all-no-motion-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# This run trains on TGIF-QA (FrameQA), so the logger is labelled TGIF rather than MSVD.
wandb_logger = WandbLogger(name='HCRN-TGIF-bert-no-motion',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
trainer.fit(model,tgifqa_frameqa_bert_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

##### Video level Ablation

In [None]:
wandb.init(project="video-qa-hcrn-recvis",name='tgif-qa bert train all hf-model pooler no-video-motion')

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': tgifqa_frameqa_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': tgifqa_frameqa_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All',
        'ablated_features':['video_motion']
    }
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the checkpoint with the best validation accuracy. The filename names the
# dataset and the ablation so runs sharing this directory stay distinguishable.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/bert',
    filename='tgif-bert-pretrained-1layer-train-all-no-video-motion-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# This run trains on TGIF-QA (FrameQA), so the logger is labelled TGIF rather than MSVD.
wandb_logger = WandbLogger(name='HCRN-TGIF-bert-no-video-motion',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
trainer.fit(model,tgifqa_frameqa_bert_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

##### Clip level Ablation

In [None]:
wandb.init(project="video-qa-hcrn-recvis",name='tgif-qa bert train all hf-model pooler no-clip-motion')

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': tgifqa_frameqa_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': tgifqa_frameqa_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All',
        'ablated_features':['clip_motion']
    }
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the checkpoint with the best validation accuracy. The filename names the
# dataset and the ablation so runs sharing this directory stay distinguishable.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/bert',
    filename='tgif-bert-pretrained-1layer-train-all-no-clip-motion-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# This run trains on TGIF-QA (FrameQA), so the logger is labelled TGIF rather than MSVD.
wandb_logger = WandbLogger(name='HCRN-TGIF-bert-no-clip-motion',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
trainer.fit(model,tgifqa_frameqa_bert_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

#### Question Ablation

##### Full Ablation

In [None]:
wandb.init(project="video-qa-hcrn-recvis",name='tgif-qa bert train all hf-model pooler no-question-features')

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': tgifqa_frameqa_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': tgifqa_frameqa_bert_data_module.vocab,
        'transformer_path': 'bert-base-uncased',
        'train_bert': 'All',
        'ablated_features':['clip_question','video_question']
    }
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the checkpoint with the best validation accuracy. The filename names the
# dataset and the ablation so runs sharing this directory stay distinguishable.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/bert',
    filename='tgif-bert-pretrained-1layer-train-all-no-question-features-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# This run trains on TGIF-QA (FrameQA), so the logger is labelled TGIF rather than MSVD.
wandb_logger = WandbLogger(name='HCRN-TGIF-bert-no-question-features',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc,lr_logger_callback]
)

In [None]:
# Train the question-ablated model on the TGIF-QA FrameQA data module.
trainer.fit(model,tgifqa_frameqa_bert_data_module)

In [None]:
# Run the test loop with the trainer used for fitting, then mirror the
# reported 'test_acc' of the first result entry into the wandb run.
test_res = trainer.test(verbose=False)
wandb.log(dict(test_acc=test_res[0]['test_acc']))

In [None]:
# Hand the best checkpoint path tracked by the trainer's checkpoint callback to wandb.save.
wandb.save(trainer.checkpoint_callback.best_model_path)

##### Video level Ablation

In [None]:
# Open the wandb run for the video-level question-ablation experiment.
wandb.init(
    project="video-qa-hcrn-recvis",
    name='tgif-qa bert train all hf-model pooler no-video-question-features',
)

In [None]:
# Video-level question ablation on TGIF-QA FrameQA: only 'video_question' is
# listed in `ablated_features`.
# NOTE(review): how HCRNBertAblation interprets this name is defined outside this cell — confirm.
max_epochs = 15

model_kwargs = {
    'question_type': tgifqa_frameqa_bert_data_module.question_type,
    'vision_dim': 2048,
    'module_dim': 512,
    'k_max_frame_level': 16,
    'k_max_clip_level': 8,
    'spl_resolution': 1,
    'vocab': tgifqa_frameqa_bert_data_module.vocab,
    'transformer_path': 'bert-base-uncased',
    'train_bert': 'All',
    'ablated_features': ['video_question'],
}
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the best checkpoint by validation accuracy. The file name carries
# an ablation tag because every ablation writes into the same directory and
# would otherwise produce indistinguishable checkpoint names.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/bert',
    filename='tgif-bert-pretrained-1layer-train-all-no-video-question-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# Same name as the wandb.init call of this experiment (was the copy-pasted 'HCRN-MSVD-bert').
wandb_logger = WandbLogger(
    name='tgif-qa bert train all hf-model pooler no-video-question-features',
    project='video-qa-hcrn-recvis',
)

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger=wandb_logger,
    callbacks=[checkpoint_callback_val_acc, lr_logger_callback],
)

In [None]:
# Train the video-level question-ablated model on the TGIF-QA FrameQA data module.
trainer.fit(model,tgifqa_frameqa_bert_data_module)

In [None]:
# Run the test loop with the trainer used for fitting, then mirror the
# reported 'test_acc' of the first result entry into the wandb run.
test_res = trainer.test(verbose=False)
wandb.log(dict(test_acc=test_res[0]['test_acc']))

In [None]:
# Hand the best checkpoint path tracked by the trainer's checkpoint callback to wandb.save.
wandb.save(trainer.checkpoint_callback.best_model_path)

##### Clip level Ablation

In [None]:
# Open the wandb run for the clip-level question-ablation experiment.
wandb.init(
    project="video-qa-hcrn-recvis",
    name='tgif-qa bert train all hf-model pooler no-clip-question-features',
)

In [None]:
# Clip-level question ablation on TGIF-QA FrameQA: only 'clip_question' is
# listed in `ablated_features`.
# NOTE(review): how HCRNBertAblation interprets this name is defined outside this cell — confirm.
max_epochs = 15

model_kwargs = {
    'question_type': tgifqa_frameqa_bert_data_module.question_type,
    'vision_dim': 2048,
    'module_dim': 512,
    'k_max_frame_level': 16,
    'k_max_clip_level': 8,
    'spl_resolution': 1,
    'vocab': tgifqa_frameqa_bert_data_module.vocab,
    'transformer_path': 'bert-base-uncased',
    'train_bert': 'All',
    'ablated_features': ['clip_question'],
}
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the best checkpoint by validation accuracy. The file name carries
# an ablation tag because every ablation writes into the same directory and
# would otherwise produce indistinguishable checkpoint names.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/bert',
    filename='tgif-bert-pretrained-1layer-train-all-no-clip-question-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# Same name as the wandb.init call of this experiment (was the copy-pasted 'HCRN-MSVD-bert').
wandb_logger = WandbLogger(
    name='tgif-qa bert train all hf-model pooler no-clip-question-features',
    project='video-qa-hcrn-recvis',
)

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger=wandb_logger,
    callbacks=[checkpoint_callback_val_acc, lr_logger_callback],
)

In [None]:
# Train the clip-level question-ablated model on the TGIF-QA FrameQA data module.
trainer.fit(model,tgifqa_frameqa_bert_data_module)

In [None]:
# Run the test loop with the trainer used for fitting, then mirror the
# reported 'test_acc' of the first result entry into the wandb run.
test_res = trainer.test(verbose=False)
wandb.log(dict(test_acc=test_res[0]['test_acc']))

In [None]:
# Hand the best checkpoint path tracked by the trainer's checkpoint callback to wandb.save.
wandb.save(trainer.checkpoint_callback.best_model_path)

#### No visual features

##### No Visual

In [None]:
# Open the wandb run for the experiment without visual features.
wandb.init(
    project="video-qa-hcrn-recvis",
    name='tgif-qa bert train all hf-model pooler no-visual-features',
)

In [None]:
# Visual ablation on TGIF-QA FrameQA: 'clip_motion', 'video_motion' and
# 'video_appearance_ablation' are listed in `ablated_features`.
# NOTE(review): 'video_appearance_ablation' is the only entry with an
# '_ablation' suffix — confirm it matches the name HCRNBertAblation expects.
max_epochs = 15

model_kwargs = {
    'question_type': tgifqa_frameqa_bert_data_module.question_type,
    'vision_dim': 2048,
    'module_dim': 512,
    'k_max_frame_level': 16,
    'k_max_clip_level': 8,
    'spl_resolution': 1,
    'vocab': tgifqa_frameqa_bert_data_module.vocab,
    'transformer_path': 'bert-base-uncased',
    'train_bert': 'All',
    'ablated_features': ['clip_motion', 'video_motion', 'video_appearance_ablation'],
}
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the best checkpoint by validation accuracy. The file name carries
# an ablation tag because every ablation writes into the same directory and
# would otherwise produce indistinguishable checkpoint names.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/bert',
    filename='tgif-bert-pretrained-1layer-train-all-no-visual-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# Same name as the wandb.init call of this experiment (was the copy-pasted 'HCRN-MSVD-bert').
wandb_logger = WandbLogger(
    name='tgif-qa bert train all hf-model pooler no-visual-features',
    project='video-qa-hcrn-recvis',
)

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger=wandb_logger,
    callbacks=[checkpoint_callback_val_acc, lr_logger_callback],
)

In [None]:
# Train the visually-ablated model on the TGIF-QA FrameQA data module.
trainer.fit(model,tgifqa_frameqa_bert_data_module)

In [None]:
# Run the test loop with the trainer used for fitting, then mirror the
# reported 'test_acc' of the first result entry into the wandb run.
test_res = trainer.test(verbose=False)
wandb.log(dict(test_acc=test_res[0]['test_acc']))

In [None]:
# Hand the best checkpoint path tracked by the trainer's checkpoint callback to wandb.save.
wandb.save(trainer.checkpoint_callback.best_model_path)

##### No appearance

In [None]:
# Open the wandb run for the experiment without appearance features.
wandb.init(
    project="video-qa-hcrn-recvis",
    name='tgif-qa bert train all hf-model pooler no-appearance-features',
)

In [None]:
# Appearance ablation on TGIF-QA FrameQA: only 'video_appearance_ablation' is
# listed in `ablated_features`.
# NOTE(review): this entry carries an '_ablation' suffix unlike the question/motion
# feature names used in the other ablation cells — confirm it matches what HCRNBertAblation expects.
max_epochs = 15

model_kwargs = {
    'question_type': tgifqa_frameqa_bert_data_module.question_type,
    'vision_dim': 2048,
    'module_dim': 512,
    'k_max_frame_level': 16,
    'k_max_clip_level': 8,
    'spl_resolution': 1,
    'vocab': tgifqa_frameqa_bert_data_module.vocab,
    'transformer_path': 'bert-base-uncased',
    'train_bert': 'All',
    'ablated_features': ['video_appearance_ablation'],
}
model = HCRNBertAblation(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the best checkpoint by validation accuracy. The file name carries
# an ablation tag because every ablation writes into the same directory and
# would otherwise produce indistinguishable checkpoint names.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/bert',
    filename='tgif-bert-pretrained-1layer-train-all-no-appearance-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
lr_logger_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
# Same name as the wandb.init call of this experiment (was the copy-pasted 'HCRN-MSVD-bert').
wandb_logger = WandbLogger(
    name='tgif-qa bert train all hf-model pooler no-appearance-features',
    project='video-qa-hcrn-recvis',
)

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger=wandb_logger,
    callbacks=[checkpoint_callback_val_acc, lr_logger_callback],
)

In [None]:
# Train the appearance-ablated model on the TGIF-QA FrameQA data module.
trainer.fit(model,tgifqa_frameqa_bert_data_module)

In [None]:
# Run the test loop with the trainer used for fitting, then mirror the
# reported 'test_acc' of the first result entry into the wandb run.
test_res = trainer.test(verbose=False)
wandb.log(dict(test_acc=test_res[0]['test_acc']))

In [None]:
# Hand the best checkpoint path tracked by the trainer's checkpoint callback to wandb.save.
wandb.save(trainer.checkpoint_callback.best_model_path)

## TVQA

In [None]:
import model.HCRN as HCRN

In [None]:
class HCRNSubtitles(pl.LightningModule):
    """Lightning wrapper around ``HCRN.HCRNNetworkTVQA`` trained with two AdamW optimizers.

    Parameters whose qualified name starts with
    ``model.linguistic_input_unit.bert`` are optimised with a fixed learning
    rate of 1e-5; every other parameter uses ``lr``.

    Args:
        lr: learning rate for the non-BERT parameters.
        model_kwargs: keyword arguments forwarded verbatim to
            ``HCRN.HCRNNetworkTVQA``.
    """

    def __init__(self, lr, model_kwargs):
        super().__init__()

        self.lr = lr
        # Default reduction: every step loss is the mean over its batch.
        self.criterion = nn.CrossEntropyLoss()

        self.train_acc = pl.metrics.Accuracy()
        self.valid_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()

        self.model = HCRN.HCRNNetworkTVQA(**model_kwargs)

        # Split the parameters by name prefix so that each group can get its
        # own optimizer in configure_optimizers.
        self.bert_params = []
        self.rest_params = []
        for name, param in self.named_parameters():
            if name.startswith('model.linguistic_input_unit.bert'):
                self.bert_params.append(param)
            else:
                self.rest_params.append(param)

    def forward(self, ans_candidates_tokens, ans_candidates_attention_mask, ans_candidates_token_type_ids,
                video_appearance_feat, question_tokens, question_attention_masks, question_token_type_ids,
                subtitles_tokens, subtitles_attention_mask, subtitles_token_type_ids):
        """Pass all inputs, unchanged and in the same order, to the wrapped network."""
        return self.model(ans_candidates_tokens, ans_candidates_attention_mask,
                          ans_candidates_token_type_ids, video_appearance_feat,
                          question_tokens, question_attention_masks, question_token_type_ids,
                          subtitles_tokens, subtitles_attention_mask, subtitles_token_type_ids)

    def configure_optimizers(self):
        """Return one AdamW optimizer for the non-BERT and one for the BERT parameters."""
        optimizer_model = optim.AdamW(self.rest_params, lr=self.lr)
        optimizer_bert = optim.AdamW(self.bert_params, lr=1e-5)

        return {'optimizer': optimizer_model}, {'optimizer': optimizer_bert}

    @staticmethod
    def _weighted_mean_loss(step_outputs, loss_key):
        """Average per-batch mean losses over an epoch, weighting each by its sample count.

        Each entry of ``step_outputs`` holds a batch-mean loss under
        ``loss_key`` and the batch size under ``'n_samples'``. Multiplying the
        two recovers the per-batch loss sum, so the result is the true
        per-sample mean even when batches differ in size.
        """
        total_loss = 0
        total_samples = 0
        for step_out in step_outputs:
            total_loss += step_out[loss_key] * step_out['n_samples']
            total_samples += step_out['n_samples']
        return total_loss / total_samples

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the loss for one batch and update the running train accuracy.

        The first two entries of ``batch`` are ignored, the third holds the
        answers and the remaining entries are forwarded to ``forward``.
        """
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        loss = self.criterion(logits, answers)
        self.train_acc(logits, answers)
        # Progress bar only; the epoch-level loss is logged in training_epoch_end.
        self.log('step_loss', loss, prog_bar=True, logger=False)
        return {'loss': loss, 'n_samples': len(answers)}

    def training_epoch_end(self, training_step_outputs):
        """Log the sample-weighted mean train loss and the epoch train accuracy."""
        # NOTE(review): index 0 is assumed to select the per-batch outputs of the
        # first optimizer; the nesting order depends on the Lightning version — confirm.
        loss = self._weighted_mean_loss(training_step_outputs[0], 'loss')
        self.log('train_loss', loss, logger=True)
        self.log('train_acc', self.train_acc.compute(), logger=True)

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one validation batch and update the running accuracy."""
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        loss = self.criterion(logits, answers)
        self.valid_acc(logits, answers)
        return {'val_loss': loss, 'n_samples': len(answers)}

    def validation_epoch_end(self, val_step_outputs):
        """Log the validation accuracy and the sample-weighted mean validation loss."""
        loss = self._weighted_mean_loss(val_step_outputs, 'val_loss')
        val_acc = self.valid_acc.compute()
        self.log('val_acc', val_acc, prog_bar=True, logger=True)
        self.log('val_loss', loss, logger=True)

    def test_step(self, batch, batch_idx):
        """Update the running test accuracy with one batch; returns nothing."""
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        self.test_acc(logits, answers)

    def test_epoch_end(self, test_step_outputs):
        """Print and log the accumulated test accuracy."""
        test_acc = self.test_acc.compute()
        print(f"The test accuracy is {test_acc}")
        self.log('test_acc', test_acc, logger=True)

In [None]:
# TVQA data module with DistilBERT text embeddings, batch size 24 and 8 workers.
# The two positional arguments are presumably the data root and the dataset
# name — verify against TVQADataModule.
tvqa_bert_data_module = TVQADataModule(
    'data',
    'tvqa',
    batch_size=24,
    text_embedding_model='distilbert',
    num_workers=8,
)

In [5]:
# Open the wandb run for the TVQA subtitles experiment.
wandb.init(
    project="video-qa-hcrn-recvis",
    name='tvqa subtitles pre conditioning bert ',
)

[34m[1mwandb[0m: Currently logged in as: [33mnicolas-dufour[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.14 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [6]:
# HCRN with subtitles on TVQA, using DistilBERT as the text encoder.
max_epochs = 10

model_kwargs = {
    'question_type': tvqa_bert_data_module.question_type,
    'vision_dim': 2048,
    'module_dim': 512,
    'k_max_frame_level': 16,
    'k_max_clip_level': 8,
    'spl_resolution': 1,
    'transformer_path': 'distilbert-base-uncased',
    'train_bert': 'all',
}

model = HCRNSubtitles(
    lr=0.0001,
    model_kwargs=model_kwargs,
)

# Keep only the best checkpoint by validation accuracy. The file name now
# identifies the TVQA subtitles model instead of the copy-pasted
# 'tgif-bert-pretrained-1layer' stem.
# NOTE(review): dirpath still points at the tgif-qa_frameqa directory (the
# 2-stream TVQA cell uses the same directory) — consider a TVQA-specific one.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/bert',
    filename='tvqa-subtitles-train-all-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
# Named after this experiment's wandb.init run (was the copy-pasted 'HCRN-MSVD-bert').
wandb_logger = WandbLogger(
    name='tvqa subtitles pre conditioning bert',
    project='video-qa-hcrn-recvis',
)

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger=wandb_logger,
    callbacks=[checkpoint_callback_val_acc],
    precision=16,
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


In [None]:
# Train the subtitles model on the TVQA data module.
trainer.fit(model,tvqa_bert_data_module)

loading appearance feature from data/tvqa/tvqa_appearance_feat.h5
Loading subtitles from data/tvqa/tvqa_subtitles_splitted.pt



  | Name      | Type             | Params
-----------------------------------------------
0 | criterion | CrossEntropyLoss | 0     
1 | train_acc | Accuracy         | 0     
2 | valid_acc | Accuracy         | 0     
3 | test_acc  | Accuracy         | 0     
4 | model     | HCRNNetworkTVQA  | 105 M 
-----------------------------------------------
105 M     Trainable params
0         Non-trainable params
105 M     Total params


loading questions from data/tvqa/distilbert_question_embedding/tvqa_val_questions.pt


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

loading questions from data/tvqa/distilbert_question_embedding/tvqa_train_questions.pt


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

In [None]:
# Run the test loop with the trainer used for fitting, then mirror the
# reported 'test_acc' of the first result entry into the wandb run.
test_res = trainer.test(verbose=False)
wandb.log(dict(test_acc=test_res[0]['test_acc']))

In [None]:
# Hand the best checkpoint path tracked by the trainer's checkpoint callback to wandb.save.
wandb.save(trainer.checkpoint_callback.best_model_path)

## 2 Stream HCRN

In [2]:
import model.HCRN as HCRN

In [3]:
class HCRNSubtitles2Streams(pl.LightningModule):
    """Lightning wrapper around ``HCRN.HCRNNetworkTVQA2Stream`` trained with two AdamW optimizers.

    Parameters whose qualified name starts with
    ``model.linguistic_input_unit.bert`` are optimised with a fixed learning
    rate of 1e-5; every other parameter uses ``lr``.

    Args:
        lr: learning rate for the non-BERT parameters.
        model_kwargs: keyword arguments forwarded verbatim to
            ``HCRN.HCRNNetworkTVQA2Stream``.
    """

    def __init__(self, lr, model_kwargs):
        super().__init__()

        self.lr = lr
        # Default reduction: every step loss is the mean over its batch.
        self.criterion = nn.CrossEntropyLoss()

        self.train_acc = pl.metrics.Accuracy()
        self.valid_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()

        self.model = HCRN.HCRNNetworkTVQA2Stream(**model_kwargs)

        # Split the parameters by name prefix so that each group can get its
        # own optimizer in configure_optimizers.
        self.bert_params = []
        self.rest_params = []
        for name, param in self.named_parameters():
            if name.startswith('model.linguistic_input_unit.bert'):
                self.bert_params.append(param)
            else:
                self.rest_params.append(param)

    def forward(self, ans_candidates_tokens, ans_candidates_attention_mask, ans_candidates_token_type_ids,
                video_appearance_feat, question_tokens, question_attention_masks, question_token_type_ids,
                subtitles_tokens, subtitles_attention_mask, subtitles_token_type_ids):
        """Pass all inputs, unchanged and in the same order, to the wrapped network."""
        return self.model(ans_candidates_tokens, ans_candidates_attention_mask,
                          ans_candidates_token_type_ids, video_appearance_feat,
                          question_tokens, question_attention_masks, question_token_type_ids,
                          subtitles_tokens, subtitles_attention_mask, subtitles_token_type_ids)

    def configure_optimizers(self):
        """Return one AdamW optimizer for the non-BERT and one for the BERT parameters."""
        optimizer_model = optim.AdamW(self.rest_params, lr=self.lr)
        optimizer_bert = optim.AdamW(self.bert_params, lr=1e-5)

        return {'optimizer': optimizer_model}, {'optimizer': optimizer_bert}

    @staticmethod
    def _weighted_mean_loss(step_outputs, loss_key):
        """Average per-batch mean losses over an epoch, weighting each by its sample count.

        Each entry of ``step_outputs`` holds a batch-mean loss under
        ``loss_key`` and the batch size under ``'n_samples'``. Multiplying the
        two recovers the per-batch loss sum, so the result is the true
        per-sample mean even when batches differ in size.
        """
        total_loss = 0
        total_samples = 0
        for step_out in step_outputs:
            total_loss += step_out[loss_key] * step_out['n_samples']
            total_samples += step_out['n_samples']
        return total_loss / total_samples

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the loss for one batch and update the running train accuracy.

        The first two entries of ``batch`` are ignored, the third holds the
        answers and the remaining entries are forwarded to ``forward``.
        """
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        loss = self.criterion(logits, answers)
        self.train_acc(logits, answers)
        # Progress bar only; the epoch-level loss is logged in training_epoch_end.
        self.log('step_loss', loss, prog_bar=True, logger=False)
        return {'loss': loss, 'n_samples': len(answers)}

    def training_epoch_end(self, training_step_outputs):
        """Log the sample-weighted mean train loss and the epoch train accuracy."""
        # NOTE(review): index 0 is assumed to select the per-batch outputs of the
        # first optimizer; the nesting order depends on the Lightning version — confirm.
        loss = self._weighted_mean_loss(training_step_outputs[0], 'loss')
        self.log('train_loss', loss, logger=True)
        self.log('train_acc', self.train_acc.compute(), logger=True)

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one validation batch and update the running accuracy."""
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        loss = self.criterion(logits, answers)
        self.valid_acc(logits, answers)
        return {'val_loss': loss, 'n_samples': len(answers)}

    def validation_epoch_end(self, val_step_outputs):
        """Log the validation accuracy and the sample-weighted mean validation loss."""
        loss = self._weighted_mean_loss(val_step_outputs, 'val_loss')
        val_acc = self.valid_acc.compute()
        self.log('val_acc', val_acc, prog_bar=True, logger=True)
        self.log('val_loss', loss, logger=True)

    def test_step(self, batch, batch_idx):
        """Update the running test accuracy with one batch; returns nothing."""
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        self.test_acc(logits, answers)

    def test_epoch_end(self, test_step_outputs):
        """Print and log the accumulated test accuracy."""
        test_acc = self.test_acc.compute()
        print(f"The test accuracy is {test_acc}")
        self.log('test_acc', test_acc, logger=True)

In [4]:
# TVQA data module with DistilBERT text embeddings, batch size 25 and 8 workers.
# The two positional arguments are presumably the data root and the dataset
# name — verify against TVQADataModule.
tvqa_bert_data_module = TVQADataModule(
    'data',
    'tvqa',
    batch_size=25,
    text_embedding_model='distilbert',
    num_workers=8,
)

In [5]:
# Open the wandb run for the 2-stream TVQA experiment.
wandb.init(
    project="video-qa-hcrn-recvis",
    name='tvqa subtitles 2 streams ',
)

[34m[1mwandb[0m: Currently logged in as: [33mnicolas-dufour[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.15 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [6]:
# 2-stream HCRN on TVQA, continued from a previously saved checkpoint.
max_epochs = 10

model_kwargs = {
    'question_type': tvqa_bert_data_module.question_type,
    'vision_dim': 2048,
    'module_dim': 512,
    'k_max_frame_level': 16,
    'k_max_clip_level': 8,
    'spl_resolution': 1,
    'transformer_path': 'distilbert-base-uncased',
    'train_bert': 'all',
}

# One definition of the checkpoint path, used both to build the model and to
# resume the trainer, so the two can never drift apart.
resume_ckpt_path = 'models_checkpoints/tgif-qa_frameqa/bert/tvqa-2streams-train-all-epoch=04-val_acc=0.41.ckpt'

# NOTE(review): the trainer also resumes from this checkpoint; if that restores
# the optimizer state, the lr passed here may be overridden — confirm.
model = HCRNSubtitles2Streams.load_from_checkpoint(
    resume_ckpt_path,
    lr=0.00001,
    model_kwargs=model_kwargs,
)

# Keep only the best checkpoint by validation accuracy.
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/bert',
    filename='tvqa-2streams-train-all-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
# Named after this experiment's wandb.init run (was the copy-pasted 'HCRN-MSVD-bert').
wandb_logger = WandbLogger(
    name='tvqa subtitles 2 streams',
    project='video-qa-hcrn-recvis',
)

trainer = pl.Trainer(
    resume_from_checkpoint=resume_ckpt_path,
    gpus=1,
    max_epochs=max_epochs,
    logger=wandb_logger,
    callbacks=[checkpoint_callback_val_acc],
    precision=16,
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


In [7]:
# Continue training the 2-stream model on the TVQA data module.
trainer.fit(model,tvqa_bert_data_module)

loading appearance feature from data/tvqa/tvqa_appearance_feat.h5
Loading subtitles from data/tvqa/tvqa_subtitles_splitted.pt



  | Name      | Type                   | Params
-----------------------------------------------------
0 | criterion | CrossEntropyLoss       | 0     
1 | train_acc | Accuracy               | 0     
2 | valid_acc | Accuracy               | 0     
3 | test_acc  | Accuracy               | 0     
4 | model     | HCRNNetworkTVQA2Stream | 88.0 M
-----------------------------------------------------
88.0 M    Trainable params
0         Non-trainable params
88.0 M    Total params


loading questions from data/tvqa/distilbert_question_embedding/tvqa_val_questions.pt


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

loading questions from data/tvqa/distilbert_question_embedding/tvqa_train_questions.pt


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…






1

In [8]:
# Run the test loop with the trainer used for fitting, then mirror the
# reported 'test_acc' of the first result entry into the wandb run.
# NOTE(review): the recorded execution of this cell ended with
# "OSError: [Errno 12] Cannot allocate memory"; the data module's worker count
# may be a factor — confirm before re-running.
test_res = trainer.test(verbose=False)
wandb.log(dict(test_acc=test_res[0]['test_acc']))

loading questions from data/tvqa/distilbert_question_embedding/tvqa_test_questions.pt


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

OSError: [Errno 12] Cannot allocate memory

In [None]:
# Hand the best checkpoint path tracked by the trainer's checkpoint callback to wandb.save.
wandb.save(trainer.checkpoint_callback.best_model_path)