In [173]:
%load_ext autoreload
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
import h5py
import pytorch_lightning as pl
import wandb
from pytorch_lightning.loggers import WandbLogger
import pickle
import json

from DataLoader import VideoQADataModule
from preprocess.msvd_text_prep import create_vocab, process_questions
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
wandb.init(project="video-qa-hcrn-recvis")

[34m[1mwandb[0m: Currently logged in as: [33mnicolas-dufour[0m (use `wandb login --relogin` to force relogin)


## Load Data

In [208]:
msvd_glove_data_module = VideoQADataModule('data','msvd-qa',batch_size=32,text_embedding_method='glove',num_workers=8)
msrvtt_glove_data_module = VideoQADataModule('data','msrvtt-qa',batch_size=32,text_embedding_method='glove',num_workers=8)
tgif_glove_data_module = VideoQADataModule('data','tgif-qa_frameqa',batch_size=32,text_embedding_method='glove',num_workers=8)

In [1]:
%%time
msvd_glove_data_module.setup()
loader = msvd_glove_data_module.train_dataloader()
next(iter(loader))[-2].size()

NameError: name 'msvd_glove_data_module' is not defined

In [210]:
%time next(iter(loader))[-2].size()

CPU times: user 33.7 ms, sys: 282 ms, total: 315 ms
Wall time: 671 ms


torch.Size([32, 21])

## Defining the base model

In [211]:
import model.HCRN as HCRN

In [215]:
class HCRN_glove(pl.LightningModule):
    def __init__(self,glove_matrix,lr,model_kwargs,optimizer = 'AdamW'):
        super().__init__()
        
        self.lr = lr
        self.optimizer = optimizer
        self.criterion = nn.CrossEntropyLoss()
        
        self.train_acc = pl.metrics.Accuracy()
        self.valid_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()
        
        glove_matrix = torch.FloatTensor(glove_matrix)
        self.model = HCRN.HCRNNetworkGlove(**model_kwargs)
        with torch.no_grad():
            self.model.linguistic_input_unit.encoder_embed.weight.set_(glove_matrix)
    
    def forward(self,ans_candidates, ans_candidates_len, video_appearance_feat, video_motion_feat, question,
                question_len):
        return self.model(ans_candidates, ans_candidates_len, video_appearance_feat, video_motion_feat, question,
                question_len)
    
    def configure_optimizers(self):
        if(self.optimizer == 'Adam'):
            optimizer = optim.Adam(self.parameters(), lr=self.lr)
        elif(self.optimizer == 'AdamW'):
            optimizer = optim.AdamW(self.parameters(), lr=self.lr)
        else:
            raise "Optimizer not supported"
        scheduler = optim.lr_scheduler.StepLR(optimizer,10,gamma=0.5)
        return [optimizer],[scheduler]
   
    def training_step(self,batch,batch_idx):
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        loss = self.criterion(logits, answers)
        acc = self.train_acc(logits,answers)
        self.log('step_loss',loss,prog_bar = True,logger=False)
        return {'loss': loss,'n_samples':len(answers)}
    
    def training_epoch_end(self, training_step_outputs):
        loss=0
        n_samples = 0
        for step_out in training_step_outputs:
            loss += step_out['loss']
            n_samples += step_out['n_samples']
        loss = loss/n_samples
        self.log('train_loss',loss,logger=True)
        self.log('train_acc',self.train_acc.compute(),logger=True)
        
    def validation_step(self,batch,batch_idx):
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        loss = self.criterion(logits, answers)
        acc = self.valid_acc(logits,answers)
        return {'val_loss': loss,'n_samples':len(answers)}
    
    def validation_epoch_end(self, val_step_outputs):
        loss=0
        n_samples = 0
        for step_out in val_step_outputs:
            loss += step_out['val_loss']
            n_samples += step_out['n_samples']
        loss = loss/n_samples
        val_acc = self.valid_acc.compute()
        self.log('val_acc',val_acc,prog_bar = True,logger=True)
        self.log('val_loss',loss,logger=True)
        
    def test_step(self,batch,batch_idx):
        _, _, answers, *batch_input = batch
        logits = self(*batch_input).detach()
        acc = self.test_acc(logits,answers)
    
    def test_epoch_end(self,test_step_outputs):
        test_acc = self.test_acc.compute()
        print(f"The test accuracy is {test_acc}")
        self.log('test_acc',test_acc,logger=True)
        
        
    

## Recreating papers results on MSVD-QA, MSRVTT-QA, TGIF-QA_FrameQA

### MSVD-QA

In [216]:
max_epochs = 1

model_kwargs = {
        'question_type': msvd_glove_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'word_dim': 300,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msvd_glove_data_module.vocab
    }
model = HCRN_glove(
    glove_matrix=msvd_glove_data_module.glove_matrix,
    lr=0.0001,
    model_kwargs=model_kwargs,
    optimizer='Adam'
)

checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msvd/baseline',
    filename='msvd-base-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
wandb_logger = WandbLogger(name='HCRN-MSVD-base',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc]
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [217]:
trainer.fit(model,msvd_glove_data_module)


  | Name      | Type             | Params
-----------------------------------------------
0 | criterion | CrossEntropyLoss | 0     
1 | train_acc | Accuracy         | 0     
2 | valid_acc | Accuracy         | 0     
3 | test_acc  | Accuracy         | 0     
4 | model     | HCRNNetworkGlove | 43.7 M
-----------------------------------------------
43.7 M    Trainable params
0         Non-trainable params
43.7 M    Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

In [None]:
test_res = trainer.test(verbose=False)

In [None]:
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

### MSRVTT-QA

In [None]:
max_epochs = 25

model_kwargs = {
        'question_type': msrvtt_glove_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'word_dim': 300,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msrvtt_glove_data_module.vocab
    }
model = HCRN_glove(
    glove_matrix=msrvtt_glove_data_module.glove_matrix,
    lr=0.0001, 
    model_kwargs=model_kwargs, 
    optimizer='Adam'
)

checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msrvtt/baseline',
    filename='msrvtt-base-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
wandb_logger = WandbLogger(name='HCRN-MSRVTT-base',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc]
)

In [None]:
trainer.fit(model,msrvtt_glove_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

### TGIF-QA FrameQA

In [19]:
max_epochs = 25

model_kwargs = {
        'question_type': tgif_glove_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'word_dim': 300,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': tgif_glove_data_module.vocab
    }
model = HCRN_glove(
    glove_matrix=tgif_glove_data_module.glove_matrix,
    lr=0.0001, 
    model_kwargs=model_kwargs, 
    optimizer='Adam'
)

checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/baseline',
    filename='tgif-qa_frameqa-base-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
wandb_logger = WandbLogger(name='HCRN-MSRVTT-base',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc]
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [None]:
trainer.fit(model,tgif_glove_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

## Improving using AdamW

### MSVD-QA

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': msvd_glove_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'word_dim': 300,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msvd_glove_data_module.vocab
    }
model = HCRN_glove(
    glove_matrix=msvd_glove_data_module.glove_matrix,
    lr=0.0001,
    model_kwargs=model_kwargs,
    optimizer='AdamW'
)

checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msvd/baseline',
    filename='msvd-adamw-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
wandb_logger = WandbLogger(name='HCRN-MSVD-base',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc]
)

In [None]:
trainer.fit(model,msvd_glove_data_module)

In [None]:
test_res = trainer.test(verbose=False)
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

### MSRVTT-QA

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': msrvtt_glove_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'word_dim': 300,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msrvtt_glove_data_module.vocab
    }
model = HCRN_glove(
    glove_matrix=msrvtt_glove_data_module.glove_matrix,
    lr=0.0001, 
    model_kwargs=model_kwargs, 
    optimizer='AdamW'
)

checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msrvtt/baseline',
    filename='msrvtt-adamw-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
wandb_logger = WandbLogger(name='HCRN-MSRVTT-base',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc]
)

In [None]:
trainer.fit(model,msrvtt_glove_data_module)

In [None]:
test_res = trainer.test(model,msrvtt_glove_data_module.test_dataloader())
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save('models_checkpoints/msrvtt/baseline/msrvtt-adamw-epoch=09-val_acc=0.35.ckpt')

### TGIF-QA FrameQA

In [None]:
max_epochs = 15

model_kwargs = {
        'question_type': tgif_glove_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'word_dim': 300,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': tgif_glove_data_module.vocab
    }
model = HCRN_glove.load_from_checkpoint(
    'models_checkpoints/tgif-qa_frameqa/baseline/tgif-qa_frameqa-adamW-epoch=07-val_acc=0.56.ckpt',
    glove_matrix=tgif_glove_data_module.glove_matrix,
    lr=0.0001, 
    model_kwargs=model_kwargs, 
    optimizer='AdamW'
)

checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/baseline',
    filename='tgif-qa_frameqa-base-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
wandb_logger = WandbLogger(name='HCRN-MSRVTT-base',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc]
)

In [None]:
trainer.fit(model,tgif_glove_data_module)

In [None]:
test_res = trainer.test(model,test_dataloaders = tgif_glove_data_module.test_dataloader())
wandb.log({'test_acc':test_res[0]['test_acc']})

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

## 16 bit precision

### TGIF-QA FrameQA

In [7]:
max_epochs = 15

model_kwargs = {
        'question_type': tgif_glove_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'word_dim': 300,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': tgif_glove_data_module.vocab
    }
model = HCRN_glove(
    glove_matrix=tgif_glove_data_module.glove_matrix,
    lr=0.0001, 
    model_kwargs=model_kwargs, 
    optimizer='AdamW'
)

checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/baseline',
    filename='tgif-qa_frameqa-base-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
wandb_logger = WandbLogger(name='HCRN-MSRVTT-base',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    resume_from_checkpoint = 'models_checkpoints/tgif-qa_frameqa/baseline/tgif-qa_frameqa-base-epoch=08-val_acc=0.56.ckpt',
    precision = 16,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc]
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


In [None]:
trainer.fit(model,tgif_glove_data_module)


  | Name      | Type             | Params
-----------------------------------------------
0 | criterion | CrossEntropyLoss | 0     
1 | train_acc | Accuracy         | 0     
2 | valid_acc | Accuracy         | 0     
3 | test_acc  | Accuracy         | 0     
4 | model     | HCRNNetworkGlove | 44.0 M
-----------------------------------------------
44.0 M    Trainable params
0         Non-trainable params
44.0 M    Total params


loading questions from data/tgif-qa_frameqa/glove_question_embedding/tgif-qa_frameqa_val_questions.pt
loading appearance feature from data/tgif-qa_frameqa/tgif-qa_frameqa_appearance_feat.h5
loading motion feature from data/tgif-qa_frameqa/tgif-qa_frameqa_motion_feat.h5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

loading questions from data/tgif-qa_frameqa/glove_question_embedding/tgif-qa_frameqa_train_questions.pt
loading appearance feature from data/tgif-qa_frameqa/tgif-qa_frameqa_appearance_feat.h5
loading motion feature from data/tgif-qa_frameqa/tgif-qa_frameqa_motion_feat.h5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

In [15]:
test_res = trainer.test()
wandb.log({'test_acc':test_res[0]['test_acc']})

FileNotFoundError: [Errno 2] No such file or directory: '/home/jupyter/video-qa-recvis/models_checkpoints/tgif-qa_frameqa/baseline/tgif-qa_frameqa-base-epoch=06-val_acc=0.56.ckpt'

In [None]:
wandb.save(trainer.checkpoint_callback.best_model_path)

## Create Bert Questions datasets

### MSVD-QA

In [12]:
import pandas as pd
pd.read_json('data/msvd-qa/raw_questions/train_qa.json').to_csv('data/msvd-qa/raw_questions/train_qa.csv',sep= '\t')
pd.read_json('data/msvd-qa/raw_questions/val_qa.json').to_csv('data/msvd-qa/raw_questions/val_qa.csv',sep= '\t')
pd.read_json('data/msvd-qa/raw_questions/test_qa.json').to_csv('data/msvd-qa/raw_questions/test_qa.csv',sep= '\t')

In [None]:
create_vocab('data/msvd-qa/raw_questions/train_qa.json',vocab_path='data/msvd-qa/bert_question_embedding/msvd-qa_vocab_bert.json')

In [17]:
process_questions(
    train_csv ='data/msvd-qa/raw_questions/train_qa.csv', 
    val_csv = 'data/msvd-qa/raw_questions/val_qa.csv',
    test_csv = 'data/msvd-qa/raw_questions/test_qa.csv',
    fine_tune_out_path ='data/msvd-qa/bert_question_embedding/question_finetuned_model',
    train_output = 'data/msvd-qa/bert_question_embedding/msvd-qa_train_questions.pt',
    val_output = 'data/msvd-qa/bert_question_embedding/msvd-qa_val_questions.pt',
    test_output = 'data/msvd-qa/bert_question_embedding/msvd-qa_test_questions.pt',
    vocab_path='data/msvd-qa/bert_question_embedding/msvd-qa_vocab_bert.json'
)

Loading tokenizer
Load data


Using custom data configuration default
Reusing dataset csv (/home/jupyter/.cache/huggingface/datasets/csv/default-3addb679453b5dc1/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2)


Tokenizing questions


Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/csv/default-3addb679453b5dc1/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2/cache-5aa4907cf407eca9.arrow
Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/csv/default-3addb679453b5dc1/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2/cache-c5d2ea6d0b52d4fa.arrow
Loading cached processed dataset at /home/jupyter/.cache/huggingface/datasets/csv/default-3addb679453b5dc1/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2/cache-61cbdec8e8d6fd3d.arrow


Load Vocab
Tokenizing answers


HBox(children=(FloatProgress(value=0.0, max=30933.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=6415.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13157.0), HTML(value='')))


Renaming fields


HBox(children=(FloatProgress(value=0.0, max=31.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=14.0), HTML(value='')))


Saving datasets
Finetuning Masked LM Bert model with train questions


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).




Epoch,Training Loss,Validation Loss
1,No log,1.084632
2,1.094516,1.05775
3,0.640716,1.046123


Model finetuned with validation perpexity of 1.034027338027954




Saving Model


### MSRVTT-QA

In [3]:
create_vocab('data/msvd-qa/raw_questions/train_qa.json',vocab_path='data/msvd-qa/bert_question_embedding/msvd-qa_vocab_bert.json')

NameError: name 'create_vocab' is not defined

In [None]:
process_questions(
    train_csv ='data/msvd-qa/raw_questions/train_qa.csv', 
    val_csv = 'data/msvd-qa/raw_questions/val_qa.csv',
    test_csv = 'data/msvd-qa/raw_questions/test_qa.csv',
    fine_tune_out_path ='data/msvd-qa/bert_question_embedding/bert_model',
    train_output = 'data/msvd-qa/bert_question_embedding/msvd-qa_train_questions.pt',
    val_output = 'data/msvd-qa/bert_question_embedding/msvd-qa_val_questions.pt',
    test_output = 'data/msvd-qa/bert_question_embedding/msvd-qa_test_questions.pt',
    vocab_path='data/msvd-qa/bert_question_embedding/msvd-qa_vocab_bert.json'
)

In [6]:
with open('data/msvd-qa/bert_question_embedding/msvd-qa_train_questions.pt', 'rb') as f:
    obj = pickle.load(f)

In [7]:
obj

Dataset({
    features: ['Unnamed: 0', 'answer_token', 'question_attention_mask', 'question_id', 'question_token_type_ids', 'question_tokens', 'video_id'],
    num_rows: 30933
})

In [5]:
msvd_bert_data_module = VideoQADataModule('data','msvd-qa',batch_size=32,text_embedding_method='bert')

In [6]:
import model.HCRN as HCRN

In [7]:
class HCRNBert(pl.LightningModule):
    def __init__(self,lr,model_kwargs,optimizer = 'AdamW'):
        super().__init__()
        
        self.lr = lr
        self.optimizer = optimizer
        self.criterion = nn.CrossEntropyLoss()
        
        self.train_acc = pl.metrics.Accuracy()
        self.valid_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()
        
        self.model = HCRN.HCRNNetworkBert(**model_kwargs)
    
    def forward(self,ans_candidates, ans_candidates_len, video_appearance_feat, video_motion_feat, question_tokens,question_attention_masks,question_token_type_ids):
        return self.model(ans_candidates, ans_candidates_len, video_appearance_feat, video_motion_feat, question_tokens,question_attention_masks,question_token_type_ids)
    
    def configure_optimizers(self):
        if(self.optimizer == 'Adam'):
            optimizer = optim.Adam(self.parameters(), lr=self.lr)
        elif(self.optimizer == 'AdamW'):
            optimizer = optim.AdamW(self.parameters(), lr=self.lr)
        else:
            raise "Optimizer not supported"
        scheduler = optim.lr_scheduler.StepLR(optimizer,10,gamma=0.5)
        return [optimizer],[scheduler]
   
    def training_step(self,batch,batch_idx):
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        loss = self.criterion(logits, answers)
        acc = self.train_acc(logits,answers)
        self.log('step_loss',loss,prog_bar = True,logger=False)
        return {'loss': loss,'n_samples':len(answers)}
    
    def training_epoch_end(self, training_step_outputs):
        loss=0
        n_samples = 0
        for step_out in training_step_outputs:
            loss += step_out['loss']
            n_samples += step_out['n_samples']
        loss = loss/n_samples
        self.log('train_loss',loss,logger=True)
        self.log('train_acc',self.train_acc.compute(),logger=True)
        
    def validation_step(self,batch,batch_idx):
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        loss = self.criterion(logits, answers)
        acc = self.valid_acc(logits,answers)
        return {'val_loss': loss,'n_samples':len(answers)}
    
    def validation_epoch_end(self, val_step_outputs):
        loss=0
        n_samples = 0
        for step_out in val_step_outputs:
            loss += step_out['val_loss']
            n_samples += step_out['n_samples']
        loss = loss/n_samples
        val_acc = self.valid_acc.compute()
        self.log('val_acc',val_acc,prog_bar = True,logger=True)
        self.log('val_loss',loss,logger=True)
        
    def test_step(self,batch,batch_idx):
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        acc = self.test_acc(logits,answers)
    
    def test_epoch_end(self,test_step_outputs):
        test_acc = self.test_acc.compute()
        print(f"The test accuracy is {test_acc}")
        self.log('test_acc',test_acc,logger=True)

In [8]:
max_epochs = 15

model_kwargs = {
        'question_type': msvd_bert_data_module.question_type,
        'vision_dim': 2048,
        'module_dim': 512,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': msvd_bert_data_module.vocab,
        'bert_path': 'bert-base-uncased',
        'mult_embedding': True
    }
model = HCRNBert(
    lr=0.0001,
    model_kwargs=model_kwargs,
    optimizer='AdamW'
)

checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msvd/bert',
    filename='msvd-adamw-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
wandb_logger = WandbLogger(name='HCRN-MSVD-bert',project='video-qa-hcrn-recvis')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc]
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [9]:
trainer.fit(model,msvd_bert_data_module)


  | Name      | Type             | Params
-----------------------------------------------
0 | criterion | CrossEntropyLoss | 0     
1 | train_acc | Accuracy         | 0     
2 | valid_acc | Accuracy         | 0     
3 | test_acc  | Accuracy         | 0     
4 | model     | HCRNNetworkBert  | 151 M 
-----------------------------------------------
42.7 M    Trainable params
108 M     Non-trainable params
151 M     Total params


loading questions from data/msvd-qa/bert_question_embedding/msvd-qa_val_questions.pt
loading appearance feature from data/msvd-qa/msvd-qa_appearance_feat.h5
loading motion feature from data/msvd-qa/msvd-qa_motion_feat.h5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

loading questions from data/msvd-qa/bert_question_embedding/msvd-qa_train_questions.pt
loading appearance feature from data/msvd-qa/msvd-qa_appearance_feat.h5
loading motion feature from data/msvd-qa/msvd-qa_motion_feat.h5


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…






1