In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import pytorch_lightning as pl
import wandb
from pytorch_lightning.loggers import WandbLogger


from DataLoader import VideoQADataLoader

ModuleNotFoundError: No module named 'torch'

In [6]:
# Authenticate with Weights & Biases without embedding the API key in the
# notebook. wandb.login() picks the key up from the WANDB_API_KEY environment
# variable or an existing netrc entry, and prompts for it otherwise.
# NOTE(review): the key that used to be hardcoded in this cell has been
# exposed in the notebook history and should be revoked and regenerated.
wandb.login()

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/jupyter/.netrc


In [None]:
# Start a Weights & Biases run in the "video-qa-hcrn-recvis" project.
wandb.init(project="video-qa-hcrn-recvis")

## Recreating the paper's results on MSVD-QA, MSRVTT-QA, TGIF-QA_FrameQA

In [None]:
import model.HCRN as HCRN

### MSVD-QA

In [None]:
### Dataloader params
# Keyword arguments for the train and validation loaders of this dataset.
# The four path entries are empty strings here and look like placeholders —
# TODO fill them in before running.
dataset = 'MSVD-QA'

train_loader_args = dict(
    question_pt='',
    vocab_json='',
    appearance_feat='',
    motion_feat='',
    batch_size=32,
    num_workers=4,
    shuffle=True,
)

# Same settings as the training loader, except that the data is not shuffled.
val_loader_args = dict(
    question_pt='',
    vocab_json='',
    appearance_feat='',
    motion_feat='',
    batch_size=32,
    num_workers=4,
    shuffle=False,
)

In [None]:
# Build the loaders from the argument dicts defined in the params cell.
# Those dicts are named `*_loader_args`; the previous `*_loader_kwargs`
# names were never defined and raised NameError.
train_loader = VideoQADataLoader(**train_loader_args)
val_loader = VideoQADataLoader(**val_loader_args)

In [None]:
# Constructor arguments later unpacked into HCRN.HCRNNetwork(**model_kwargs).
# The vocabulary is taken from the training loader so the model and the data
# share the same token mapping.
# NOTE(review): confirm this key set matches the HCRNNetwork signature in
# model/HCRN.py (e.g. whether a question-type argument is also required).
model_kwargs = dict(
    vision_dim=2048,
    module_dim=512,
    word_dim=300,
    k_max_frame_level=16,
    k_max_clip_level=8,
    spl_resolution=1,
    vocab=train_loader.vocab,
)

# Number of epochs handed to the Lightning Trainer.
max_epochs = 25

In [None]:
class HCRN_base(pl.LightningModule):
    """Lightning wrapper that trains ``HCRN.HCRNNetwork`` as an answer classifier.

    The network is built from the notebook-level ``model_kwargs`` and its word
    embedding is initialised from a pre-trained matrix. Optimisation uses Adam
    with a step learning-rate decay; the loss is cross-entropy between the
    network logits and the answer labels.

    Args:
        glove_matrix: Pre-trained word-embedding matrix, in any form accepted
            by ``torch.FloatTensor``. Presumably (vocab_size, word_dim) —
            TODO confirm against the data loader.
        lr: Learning rate passed to Adam.
    """

    def __init__(self, glove_matrix, lr):
        super().__init__()

        self.lr = lr
        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = pl.metrics.Accuracy()

        self.model = HCRN.HCRNNetwork(**model_kwargs)
        glove_matrix = torch.FloatTensor(glove_matrix)
        with torch.no_grad():
            # Write through self.model: a bare `model` name is not defined
            # while the module is being constructed.
            self.model.linguistic_input_unit.encoder_embed.weight.set_(glove_matrix)

    def forward(self, ans_candidates, ans_candidates_len, video_appearance_feat,
                video_motion_feat, question, question_len):
        """Forward all inputs unchanged to the wrapped network and return its output."""
        return self.model(ans_candidates, ans_candidates_len, video_appearance_feat,
                          video_motion_feat, question, question_len)

    def configure_optimizers(self):
        """Return Adam plus a StepLR schedule that halves the LR every 10 epochs."""
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        # StepLR lives in the lr_scheduler submodule, not directly on torch.optim.
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
        return [optimizer], [scheduler]

    def _shared_step(self, batch):
        """Run one batch through the model and compute loss and hit counts.

        The batch is unpacked as ``(_, _, answers, *model_inputs)``: the first
        two entries are ignored and everything after ``answers`` is forwarded
        positionally to ``forward``.

        Returns:
            Tuple ``(logits, answers, loss, correct_preds, num_preds)``.
        """
        _, _, answers, *batch_input = batch
        logits = self(*batch_input)
        loss = self.criterion(logits, answers)
        # Compare predicted class indices to the labels; comparing raw logits
        # to label indices never yields a meaningful count.
        # Assumes logits are (batch, num_classes) — TODO confirm.
        correct_preds = (logits.argmax(dim=1) == answers).sum()
        num_preds = len(answers)
        return logits, answers, loss, correct_preds, num_preds

    @staticmethod
    def _aggregate_epoch(step_outputs, loss_key):
        """Reduce per-step outputs to ``(mean_loss, accuracy)`` over the epoch.

        Each step loss is weighted by its number of predictions so the result
        is a per-sample mean rather than a sum that grows with the number of
        batches. Returns ``None`` when no predictions were recorded.
        """
        total_loss = 0.0
        total_correct = 0
        total_preds = 0
        for step_out in step_outputs:
            batch_size = step_out['num_preds']
            total_loss += float(step_out[loss_key]) * batch_size
            total_correct += int(step_out['correct_preds'])
            total_preds += batch_size
        if total_preds == 0:
            return None
        return total_loss / total_preds, total_correct / total_preds

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and log step-level metrics."""
        logits, answers, loss, correct_preds, num_preds = self._shared_step(batch)
        acc = self.accuracy(logits, answers)
        self.log('train_step_loss', loss, prog_bar=True, logger=True)
        self.log('train_step_acc', acc, prog_bar=True, logger=False)
        return {'loss': loss, 'correct_preds': correct_preds, 'num_preds': num_preds}

    def training_epoch_end(self, training_step_outputs):
        """Log the epoch-level mean training loss and accuracy."""
        summary = self._aggregate_epoch(training_step_outputs, 'loss')
        if summary is None:
            return
        train_loss, train_acc = summary
        self.log('train_loss', train_loss, logger=True)
        self.log('train_acc', train_acc, logger=True)

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and hit counts for one batch."""
        _, _, loss, correct_preds, num_preds = self._shared_step(batch)
        return {'val_loss': loss, 'correct_preds': correct_preds, 'num_preds': num_preds}

    def validation_epoch_end(self, training_step_outputs):
        """Log the epoch-level mean validation loss and accuracy.

        The parameter keeps its original name for interface compatibility; it
        receives the outputs of ``validation_step``.
        """
        summary = self._aggregate_epoch(training_step_outputs, 'val_loss')
        if summary is None:
            return
        val_loss, val_acc = summary
        self.log('val_loss', val_loss, logger=True)
        self.log('val_acc', val_acc, logger=True)
        
    

In [None]:
wandb_logger = WandbLogger(name='HCRN-MSVD-base', project='video-qa-hcrn-recvis')

# HCRN_base requires both `glove_matrix` and `lr`; calling it with no
# arguments raises TypeError.
# NOTE(review): assumes the training loader exposes the pre-trained embedding
# matrix as `glove_matrix` (it already exposes `vocab`) — confirm against
# DataLoader.py.
# NOTE(review): the notebook never defined a learning rate; 1e-4 is a
# placeholder to be confirmed against the intended training setup.
learning_rate = 1e-4
model = HCRN_base(glove_matrix=train_loader.glove_matrix, lr=learning_rate)

trainer = pl.Trainer(gpus=1, max_epochs=max_epochs, logger=wandb_logger)
trainer.fit(model, train_loader, val_loader)