In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import pytorch_lightning as pl
import wandb
from pytorch_lightning.loggers import WandbLogger


from DataLoader import VideoQADataLoader

In [2]:
!wandb login 3ed7a1bc59fad48beeadc999df34dbee428be831

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/jupyter/.netrc


In [None]:
wandb.init(project="video-qa-hcrn-recvis")

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,0.0
val_acc,0.55178
train_acc,0.9866
_timestamp,1609377727.0
train_loss,0.00203
_step,27700.0
_runtime,48464.0
val_loss,0.0997
train_step_loss,0.07567
test_acc,0.5547


0,1
test_acc,▁
epoch,▁
_step,▁
_runtime,▁
_timestamp,▁


## Recreating papers results on MSVD-QA, MSRVTT-QA, TGIF-QA_FrameQA

In [4]:
import model.HCRN as HCRN

In [5]:
class HCRN_base(pl.LightningModule):
    def __init__(self,glove_matrix,lr,model_kwargs):
        super().__init__()
        
        self.lr = lr
        self.criterion = nn.CrossEntropyLoss()
        
        self.train_acc = pl.metrics.Accuracy()
        self.valid_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()
        
        glove_matrix = torch.FloatTensor(glove_matrix)
        self.model = HCRN.HCRNNetwork(**model_kwargs)
        with torch.no_grad():
            self.model.linguistic_input_unit.encoder_embed.weight.set_(glove_matrix)
    
    def forward(self,ans_candidates, ans_candidates_len, video_appearance_feat, video_motion_feat, question,
                question_len):
        return self.model(ans_candidates, ans_candidates_len, video_appearance_feat, video_motion_feat, question,
                question_len)
    
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer,10,gamma=0.5)
        return [optimizer],[scheduler]
   
    def training_step(self,batch,batch_idx):
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        loss = self.criterion(logits, answers)
        acc = self.train_acc(logits,answers)
        self.log('train_step_loss',loss,prog_bar = True,logger=True)
        return {'loss': loss,'n_samples':len(answers)}
    
    def training_epoch_end(self, training_step_outputs):
        loss=0
        n_samples = 0
        for step_out in training_step_outputs:
            loss += step_out['loss']
            n_samples += step_out['n_samples']
        loss = loss/n_samples
        self.log('train_loss',loss,logger=True)
        self.log('train_acc',self.train_acc.compute(),logger=True,prog_bar = True)
        
    def validation_step(self,batch,batch_idx):
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        loss = self.criterion(logits, answers)
        acc = self.valid_acc(logits,answers)
        return {'val_loss': loss,'n_samples':len(answers)}
    
    def validation_epoch_end(self, val_step_outputs):
        loss=0
        n_samples = 0
        for step_out in val_step_outputs:
            loss += step_out['val_loss']
            n_samples += step_out['n_samples']
        loss = loss/n_samples
        val_acc = self.valid_acc.compute()
        self.log('val_loss',loss,logger=True)
        self.log('val_acc',val_acc,logger=True)
        
    def test_step(self,batch,batch_idx):
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        acc = self.test_acc(logits,answers)
    
    def test_epoch_end(self,test_step_outputs):
        test_acc = self.test_acc.compute()
        self.log('test_acc',test_acc,logger=True)
        
        
    

### MSVD-QA

In [5]:
### Dataloader params
dataset = 'MSVD-QA'
train_loader_kwargs = {
    'question_type': 'none',
    'question_pt': 'data/msvd-qa/glove_question_embedding/msvd-qa_train_questions.pt',
    'vocab_json': 'data/msvd-qa/msvd-qa_vocab.json',
    'appearance_feat': 'data/msvd-qa/msvd-qa_appearance_feat.h5',
    'motion_feat': 'data/msvd-qa/msvd-qa_motion_feat.h5',
    'batch_size': 32,
    'num_workers': 4,
    'shuffle': True
}
val_loader_kwargs = {
    'question_pt': 'data/msvd-qa/glove_question_embedding/msvd-qa_val_questions.pt',
    'vocab_json': 'data/msvd-qa/msvd-qa_vocab.json',
    'appearance_feat': 'data/msvd-qa/msvd-qa_appearance_feat.h5',
    'motion_feat': 'data/msvd-qa/msvd-qa_motion_feat.h5',
    'batch_size': 32,
    'num_workers': 4,
    'shuffle': False
}
test_loader_kwargs = {
    'question_pt': 'data/msvd-qa/glove_question_embedding/msvd-qa_test_questions.pt',
    'vocab_json': 'data/msvd-qa/msvd-qa_vocab.json',
    'appearance_feat': 'data/msvd-qa/msvd-qa_appearance_feat.h5',
    'motion_feat': 'data/msvd-qa/msvd-qa_motion_feat.h5',
    'batch_size': 32,
    'num_workers': 1,
    'shuffle': False
}

In [6]:
train_loader = VideoQADataLoader(**train_loader_kwargs)
val_loader = VideoQADataLoader(**val_loader_kwargs)
test_loader = VideoQADataLoader(**test_loader_kwargs)

loading vocab from data/msvd-qa/msvd-qa_vocab.json
loading questions from data/msvd-qa/glove_question_embedding/msvd-qa_train_questions.pt
loading appearance feature from data/msvd-qa/msvd-qa_appearance_feat.h5
loading motion feature from data/msvd-qa/msvd-qa_motion_feat.h5
loading vocab from data/msvd-qa/msvd-qa_vocab.json
loading questions from data/msvd-qa/glove_question_embedding/msvd-qa_val_questions.pt
loading appearance feature from data/msvd-qa/msvd-qa_appearance_feat.h5
loading motion feature from data/msvd-qa/msvd-qa_motion_feat.h5
loading vocab from data/msvd-qa/msvd-qa_vocab.json
loading questions from data/msvd-qa/glove_question_embedding/msvd-qa_test_questions.pt
loading appearance feature from data/msvd-qa/msvd-qa_appearance_feat.h5
loading motion feature from data/msvd-qa/msvd-qa_motion_feat.h5


In [7]:
model_kwargs = {
        'question_type': 'none',
        'vision_dim': 2048,
        'module_dim': 512,
        'word_dim': 300,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': train_loader.vocab
    }
max_epochs =25

In [9]:
wandb_logger = WandbLogger(name='HCRN-MSVD-base',project='video-qa-hcrn-recvis')
model = HCRN_base(glove_matrix=train_loader.glove_matrix,lr=0.0001,model_kwargs=model_kwargs)

In [10]:
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msvd/baseline',
    filename='msvd-base-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc]
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [11]:
trainer.fit(model,train_loader,val_loader)


  | Name      | Type             | Params
-----------------------------------------------
0 | criterion | CrossEntropyLoss | 0     
1 | train_acc | Accuracy         | 0     
2 | valid_acc | Accuracy         | 0     
3 | test_acc  | Accuracy         | 0     
4 | model     | HCRNNetwork      | 43.7 M
-----------------------------------------------
43.7 M    Trainable params
0         Non-trainable params
43.7 M    Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

In [17]:
trainer.test(model,test_dataloaders=test_loader)

HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…




--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.3596, device='cuda:0'),
 'val_acc': tensor(0.3359, device='cuda:0'),
 'val_loss': tensor(0.1351, device='cuda:0')}
--------------------------------------------------------------------------------


[{'val_loss': 0.1351281851530075,
  'val_acc': 0.3359314203262329,
  'test_acc': 0.35958045721054077}]

In [18]:
wandb.save('models_checkpoints/msvd/baseline/msvd-base-acc-epoch=12-val_acc=0.36.ckpt')

['/home/jupyter/video-qa-recvis/wandb/run-20201228_210833-dg5wtty9/files/models_checkpoints/msvd/baseline/msvd-base-acc-epoch=12-val_acc=0.36.ckpt']

### MSRVTT-QA

In [6]:
### Dataloader params
dataset = 'MSRVTT-QA'
train_loader_kwargs = {
    'question_type': 'none',
    'question_pt': 'data/msrvtt-qa/glove_question_embedding/msrvtt-qa_train_questions.pt',
    'vocab_json': 'data/msrvtt-qa/msrvtt-qa_vocab.json',
    'appearance_feat': 'data/msrvtt-qa/msrvtt-qa_appearance_feat.h5',
    'motion_feat': 'data/msrvtt-qa/msrvtt-qa_motion_feat.h5',
    'batch_size': 32,
    'num_workers': 4,
    'shuffle': True
}
val_loader_kwargs = {
    'question_pt': 'data/msrvtt-qa/glove_question_embedding/msrvtt-qa_val_questions.pt',
    'vocab_json': 'data/msrvtt-qa/msrvtt-qa_vocab.json',
    'appearance_feat': 'data/msrvtt-qa/msrvtt-qa_appearance_feat.h5',
    'motion_feat': 'data/msrvtt-qa/msrvtt-qa_motion_feat.h5',
    'batch_size': 32,
    'num_workers': 4,
    'shuffle': False
}
test_loader_kwargs = {
    'question_pt': 'data/msrvtt-qa/glove_question_embedding/msrvtt-qa_test_questions.pt',
    'vocab_json': 'data/msrvtt-qa/msrvtt-qa_vocab.json',
    'appearance_feat': 'data/msrvtt-qa/msrvtt-qa_appearance_feat.h5',
    'motion_feat': 'data/msrvtt-qa/msrvtt-qa_motion_feat.h5',
    'batch_size': 32,
    'num_workers': 1,
    'shuffle': False
}

In [7]:
train_loader = VideoQADataLoader(**train_loader_kwargs)
val_loader = VideoQADataLoader(**val_loader_kwargs)
test_loader = VideoQADataLoader(**test_loader_kwargs)

loading vocab from data/msrvtt-qa/msrvtt-qa_vocab.json
loading questions from data/msrvtt-qa/glove_question_embedding/msrvtt-qa_train_questions.pt
loading appearance feature from data/msrvtt-qa/msrvtt-qa_appearance_feat.h5
loading motion feature from data/msrvtt-qa/msrvtt-qa_motion_feat.h5
loading vocab from data/msrvtt-qa/msrvtt-qa_vocab.json
loading questions from data/msrvtt-qa/glove_question_embedding/msrvtt-qa_val_questions.pt
loading appearance feature from data/msrvtt-qa/msrvtt-qa_appearance_feat.h5
loading motion feature from data/msrvtt-qa/msrvtt-qa_motion_feat.h5
loading vocab from data/msrvtt-qa/msrvtt-qa_vocab.json
loading questions from data/msrvtt-qa/glove_question_embedding/msrvtt-qa_test_questions.pt
loading appearance feature from data/msrvtt-qa/msrvtt-qa_appearance_feat.h5
loading motion feature from data/msrvtt-qa/msrvtt-qa_motion_feat.h5


In [8]:
model_kwargs = {
        'vision_dim': 2048,
        'module_dim': 512,
        'word_dim': 300,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': train_loader.vocab
    }
max_epochs =25

In [12]:
wandb_logger = WandbLogger(name='HCRN-MSRVTT-base',project='video-qa-hcrn-recvis')
model = HCRN_base.load_from_checkpoint('models_checkpoints/msrvtt/baseline/msrvtt-base-epoch=09-val_acc=0.35.ckpt',glove_matrix=train_loader.glove_matrix,lr=0.0001,model_kwargs=model_kwargs)

In [13]:
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msrvtt/baseline',
    filename='msrvtt-base-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc]
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [None]:
trainer.fit(model,train_loader,val_loader)


  | Name      | Type             | Params
-----------------------------------------------
0 | criterion | CrossEntropyLoss | 0     
1 | train_acc | Accuracy         | 0     
2 | valid_acc | Accuracy         | 0     
3 | test_acc  | Accuracy         | 0     
4 | model     | HCRNNetwork      | 48.5 M
-----------------------------------------------
48.5 M    Trainable params
0         Non-trainable params
48.5 M    Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

In [14]:
trainer.test(model,test_dataloaders=test_loader)

HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.3534, device='cuda:0')}
--------------------------------------------------------------------------------


[{'test_acc': 0.3534420132637024}]

In [15]:
wandb.log({'test_acc':0.3534})

In [17]:
wandb.save('models_checkpoints/msrvtt/baseline/msrvtt-base-epoch=09-val_acc=0.35.ckpt')

['/home/jupyter/video-qa-recvis/wandb/run-20201229_110247-37puyeg2/files/models_checkpoints/msrvtt/baseline/msrvtt-base-epoch=09-val_acc=0.35.ckpt']

### TGIF-QA FrameQA

In [6]:
### Dataloader params
dataset = 'TGIF-QA FrameQA'
train_loader_kwargs = {
    'question_type': 'frameqa',
    'question_pt': 'data/tgif-qa_frameqa/glove_question_embedding/tgif-qa_frameqa_train_questions.pt',
    'vocab_json': 'data/tgif-qa_frameqa/tgif-qa_frameqa_vocab.json',
    'appearance_feat': 'data/tgif-qa_frameqa/tgif-qa_frameqa_appearance_feat.h5',
    'motion_feat': 'data/tgif-qa_frameqa/tgif-qa_frameqa_motion_feat.h5',
    'batch_size': 32,
    'num_workers': 4,
    'shuffle': True
}
val_loader_kwargs = {
    'question_type': 'frameqa',
    'question_pt': 'data/tgif-qa_frameqa/glove_question_embedding/tgif-qa_frameqa_val_questions.pt',
    'vocab_json': 'data/tgif-qa_frameqa/tgif-qa_frameqa_vocab.json',
    'appearance_feat': 'data/tgif-qa_frameqa/tgif-qa_frameqa_appearance_feat.h5',
    'motion_feat': 'data/tgif-qa_frameqa/tgif-qa_frameqa_motion_feat.h5',
    'batch_size': 32,
    'num_workers': 4,
    'shuffle': False
}
test_loader_kwargs = {
    'question_type': 'frameqa',
    'question_pt': 'data/tgif-qa_frameqa/glove_question_embedding/tgif-qa_frameqa_test_questions.pt',
    'vocab_json': 'data/tgif-qa_frameqa/tgif-qa_frameqa_vocab.json',
    'appearance_feat': 'data/tgif-qa_frameqa/tgif-qa_frameqa_appearance_feat.h5',
    'motion_feat': 'data/tgif-qa_frameqa/tgif-qa_frameqa_motion_feat.h5',
    'batch_size': 32,
    'num_workers': 1,
    'shuffle': False
}

In [7]:
train_loader = VideoQADataLoader(**train_loader_kwargs)
val_loader = VideoQADataLoader(**val_loader_kwargs)
test_loader = VideoQADataLoader(**test_loader_kwargs)

loading vocab from data/tgif-qa_frameqa/tgif-qa_frameqa_vocab.json
loading questions from data/tgif-qa_frameqa/glove_question_embedding/tgif-qa_frameqa_train_questions.pt
loading appearance feature from data/tgif-qa_frameqa/tgif-qa_frameqa_appearance_feat.h5
loading motion feature from data/tgif-qa_frameqa/tgif-qa_frameqa_motion_feat.h5
loading vocab from data/tgif-qa_frameqa/tgif-qa_frameqa_vocab.json
loading questions from data/tgif-qa_frameqa/glove_question_embedding/tgif-qa_frameqa_val_questions.pt
loading appearance feature from data/tgif-qa_frameqa/tgif-qa_frameqa_appearance_feat.h5
loading motion feature from data/tgif-qa_frameqa/tgif-qa_frameqa_motion_feat.h5
loading vocab from data/tgif-qa_frameqa/tgif-qa_frameqa_vocab.json
loading questions from data/tgif-qa_frameqa/glove_question_embedding/tgif-qa_frameqa_test_questions.pt
loading appearance feature from data/tgif-qa_frameqa/tgif-qa_frameqa_appearance_feat.h5
loading motion feature from data/tgif-qa_frameqa/tgif-qa_frameqa_m

In [8]:
model_kwargs = {
        'question_type': 'frameqa',
        'vision_dim': 2048,
        'module_dim': 512,
        'word_dim': 300,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': train_loader.vocab
    }
max_epochs =25

In [9]:
wandb_logger = WandbLogger(name='HCRN-TGIFQA-FrameQA-base',project='video-qa-hcrn-recvis')
model = HCRN_base.load_from_checkpoint("models_checkpoints/tgif-qa_frameqa/baseline/tgif-qa_frameqa-base-epoch=06-val_acc=0.56.ckpt",glove_matrix=train_loader.glove_matrix,lr=0.0001,model_kwargs=model_kwargs)

In [10]:
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/tgif-qa_frameqa/baseline',
    filename='tgif-qa_frameqa-base-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc]
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [None]:
trainer.fit(model,train_loader,val_loader)


  | Name      | Type             | Params
-----------------------------------------------
0 | criterion | CrossEntropyLoss | 0     
1 | train_acc | Accuracy         | 0     
2 | valid_acc | Accuracy         | 0     
3 | test_acc  | Accuracy         | 0     
4 | model     | HCRNNetwork      | 44.0 M
-----------------------------------------------
44.0 M    Trainable params
0         Non-trainable params
44.0 M    Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

In [11]:
trainer.test(model,test_dataloaders=test_loader)



HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.5547, device='cuda:0')}
--------------------------------------------------------------------------------


[{'test_acc': 0.5546709299087524}]

In [12]:
wandb.log({'test_acc':0.5547})

In [17]:
wandb.save('models_checkpoints/tgif-qa_frameqa/baseline/tgif-qa_frameqa-base-epoch=06-val_acc=0.56.ckpt')

['/home/jupyter/video-qa-recvis/wandb/run-20201229_110247-37puyeg2/files/models_checkpoints/msrvtt/baseline/msrvtt-base-epoch=09-val_acc=0.35.ckpt']

## Tentative to reduce overfitting using dropout on CRN units

In [4]:
import model.HCRN as HCRN

In [5]:
class HCRN_dropout(pl.LightningModule):
    def __init__(self,glove_matrix,lr,model_kwargs):
        super().__init__()
        
        self.lr = lr
        self.criterion = nn.CrossEntropyLoss()
        
        self.train_acc = pl.metrics.Accuracy()
        self.valid_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()
        
        glove_matrix = torch.FloatTensor(glove_matrix)
        self.model = HCRN.HCRNNetworkDropout(**model_kwargs)
        with torch.no_grad():
            self.model.linguistic_input_unit.encoder_embed.weight.set_(glove_matrix)
    
    def forward(self,ans_candidates, ans_candidates_len, video_appearance_feat, video_motion_feat, question,
                question_len):
        return self.model(ans_candidates, ans_candidates_len, video_appearance_feat, video_motion_feat, question,
                question_len)
    
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer,10,gamma=0.5)
        return [optimizer],[scheduler]
   
    def training_step(self,batch,batch_idx):
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        loss = self.criterion(logits, answers)
        acc = self.train_acc(logits,answers)
        self.log('train_step_loss',loss,prog_bar = True,logger=True)
        return {'loss': loss,'n_samples':len(answers)}
    
    def training_epoch_end(self, training_step_outputs):
        loss=0
        n_samples = 0
        for step_out in training_step_outputs:
            loss += step_out['loss']
            n_samples += step_out['n_samples']
        loss = loss/n_samples
        self.log('train_loss',loss,logger=True)
        self.log('train_acc',self.train_acc.compute(),logger=True,prog_bar = True)
        
    def validation_step(self,batch,batch_idx):
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        loss = self.criterion(logits, answers)
        acc = self.valid_acc(logits,answers)
        return {'val_loss': loss,'n_samples':len(answers)}
    
    def validation_epoch_end(self, val_step_outputs):
        loss=0
        n_samples = 0
        for step_out in val_step_outputs:
            loss += step_out['val_loss']
            n_samples += step_out['n_samples']
        loss = loss/n_samples
        val_acc = self.valid_acc.compute()
        self.log('val_loss',loss,logger=True)
        self.log('val_acc',val_acc,logger=True)
        
    def test_step(self,batch,batch_idx):
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        acc = self.test_acc(logits,answers)
    
    def test_epoch_end(self,test_step_outputs):
        test_acc = self.test_acc.compute()
        self.log('test_acc',test_acc,logger=True)
        
        
    

### MSVD-QA

In [5]:
### Dataloader params
dataset = 'MSVD-QA'
train_loader_kwargs = {
    'question_type': 'none',
    'question_pt': 'data/msvd-qa/glove_question_embedding/msvd-qa_train_questions.pt',
    'vocab_json': 'data/msvd-qa/msvd-qa_vocab.json',
    'appearance_feat': 'data/msvd-qa/msvd-qa_appearance_feat.h5',
    'motion_feat': 'data/msvd-qa/msvd-qa_motion_feat.h5',
    'batch_size': 32,
    'num_workers': 4,
    'shuffle': True
}
val_loader_kwargs = {
    'question_pt': 'data/msvd-qa/glove_question_embedding/msvd-qa_val_questions.pt',
    'vocab_json': 'data/msvd-qa/msvd-qa_vocab.json',
    'appearance_feat': 'data/msvd-qa/msvd-qa_appearance_feat.h5',
    'motion_feat': 'data/msvd-qa/msvd-qa_motion_feat.h5',
    'batch_size': 32,
    'num_workers': 4,
    'shuffle': False
}
test_loader_kwargs = {
    'question_pt': 'data/msvd-qa/glove_question_embedding/msvd-qa_test_questions.pt',
    'vocab_json': 'data/msvd-qa/msvd-qa_vocab.json',
    'appearance_feat': 'data/msvd-qa/msvd-qa_appearance_feat.h5',
    'motion_feat': 'data/msvd-qa/msvd-qa_motion_feat.h5',
    'batch_size': 32,
    'num_workers': 1,
    'shuffle': False
}

In [6]:
train_loader = VideoQADataLoader(**train_loader_kwargs)
val_loader = VideoQADataLoader(**val_loader_kwargs)
test_loader = VideoQADataLoader(**test_loader_kwargs)

loading vocab from data/msvd-qa/msvd-qa_vocab.json
loading questions from data/msvd-qa/glove_question_embedding/msvd-qa_train_questions.pt
loading appearance feature from data/msvd-qa/msvd-qa_appearance_feat.h5
loading motion feature from data/msvd-qa/msvd-qa_motion_feat.h5
loading vocab from data/msvd-qa/msvd-qa_vocab.json
loading questions from data/msvd-qa/glove_question_embedding/msvd-qa_val_questions.pt
loading appearance feature from data/msvd-qa/msvd-qa_appearance_feat.h5
loading motion feature from data/msvd-qa/msvd-qa_motion_feat.h5
loading vocab from data/msvd-qa/msvd-qa_vocab.json
loading questions from data/msvd-qa/glove_question_embedding/msvd-qa_test_questions.pt
loading appearance feature from data/msvd-qa/msvd-qa_appearance_feat.h5
loading motion feature from data/msvd-qa/msvd-qa_motion_feat.h5


In [7]:
model_kwargs = {
        'question_type': 'none',
        'vision_dim': 2048,
        'module_dim': 512,
        'word_dim': 300,
        'k_max_frame_level': 16,
        'k_max_clip_level': 8,
        'spl_resolution': 1,
        'vocab': train_loader.vocab
    }
max_epochs =25

In [9]:
wandb_logger = WandbLogger(name='HCRN-MSVD-base',project='video-qa-hcrn-recvis')
model = HCRN_dropout(glove_matrix=train_loader.glove_matrix,lr=0.0001,model_kwargs=model_kwargs)

In [10]:
checkpoint_callback_val_acc = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath='models_checkpoints/msvd/baseline',
    filename='msvd-base-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
trainer = pl.Trainer(
    gpus=1,
    max_epochs=max_epochs,
    logger = wandb_logger,
    callbacks =[checkpoint_callback_val_acc]
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [11]:
trainer.fit(model,train_loader,val_loader)


  | Name      | Type             | Params
-----------------------------------------------
0 | criterion | CrossEntropyLoss | 0     
1 | train_acc | Accuracy         | 0     
2 | valid_acc | Accuracy         | 0     
3 | test_acc  | Accuracy         | 0     
4 | model     | HCRNNetwork      | 43.7 M
-----------------------------------------------
43.7 M    Trainable params
0         Non-trainable params
43.7 M    Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

In [17]:
trainer.test(model,test_dataloaders=test_loader)

HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…




--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.3596, device='cuda:0'),
 'val_acc': tensor(0.3359, device='cuda:0'),
 'val_loss': tensor(0.1351, device='cuda:0')}
--------------------------------------------------------------------------------


[{'val_loss': 0.1351281851530075,
  'val_acc': 0.3359314203262329,
  'test_acc': 0.35958045721054077}]