In [2]:
import random 
import pytorch_lightning as pl
import albumentations as alb

%run VideoEncoder.ipynb
%run TextDecoder.ipynb

class VideoCaptioner(pl.LightningModule):
    """Lightning module that pairs a ``VideoEncoder`` with a ``TextDecoder``.

    The encoded video is used as the decoder's initial state and the decoder
    is then unrolled one token at a time.  Training and validation both
    minimise a cross-entropy loss that ignores the ``<pad>`` token.
    """

    def __init__(self, textTokenizer, embedding_size: int, state_size: int, encoding_size: int, val_data = None):
        """Build the encoder, the decoder and the loss.

        Args:
            textTokenizer: object exposing ``vocab`` (sized with ``len``) and
                ``vocab.stoi``, which must contain the key ``"<pad>"``.
            embedding_size: forwarded to ``TextDecoder``.
            state_size: forwarded to ``TextDecoder``.
            encoding_size: forwarded to ``VideoEncoder``.
            val_data: stored as ``self.val_data``; not read anywhere else in
                this class.
        """
        super().__init__()

        self.text_tokenizer = textTokenizer
        self.val_data = val_data

        vocabulary = textTokenizer.vocab
        self.vocabulary_size = len(vocabulary)
        self.padding_token_id = vocabulary.stoi["<pad>"]

        # Sub-modules are created in the order encoder -> decoder -> loss.
        self.video_encoder = VideoEncoder(encoding_size=encoding_size)
        self.text_decoder = TextDecoder(
            embedding_size, state_size, self.vocabulary_size
        )
        # Padding positions do not contribute to the loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.padding_token_id)

        # One learning rate per optimizer (see ``configure_optimizers``).
        self.video_encoder_learning_rate = 1e-3
        self.text_decoder_learning_rate = 1e-3

    def forward(self, video, text, lengths, teacher_forcing = 1.0):
        """Unroll the decoder over the caption and collect per-step scores.

        Args:
            video: input handed to ``self.video_encoder``.
            text: token ids indexed as ``text[:, position]``; position 0 is
                expected to hold the ``<start>`` symbol.
            lengths: caption lengths; only ``max(lengths)`` is used, to fix
                the number of decoding steps.
            teacher_forcing: probability, per step, of feeding the ground
                truth token instead of the decoder's own best guess.

        Returns:
            A pair ``(scores, lengths)``.  ``scores`` stacks the step outputs
            along dimension 1 (``max(lengths) - 1`` steps); ``lengths`` is
            returned unchanged.
        """
        video_code = self.video_encoder(video)

        # The encoded video fills both entries of the initial decoder state,
        # and the first input token is text[:, 0].
        step_scores, state = self.text_decoder((video_code, video_code), text[:, 0])
        score_steps = [step_scores]

        # Positions 1 .. max(lengths) - 2 are the remaining decoder inputs;
        # the last caption position is only ever a target, never an input.
        for position in range(1, max(lengths) - 1):
            if random.random() < teacher_forcing:
                next_input = text[:, position]
            else:
                # Feed back the highest-scoring token; detached so no
                # gradient flows through the choice.
                next_input = step_scores.max(dim=1)[1].detach()
            step_scores, state = self.text_decoder(state, next_input)
            score_steps.append(step_scores)

        # Stacking on dim 1 puts the step axis second (batch_first layout).
        return torch.stack(score_steps, 1), lengths

    def _caption_loss(self, batch, teacher_forcing):
        """Run the model on ``batch`` and return the cross-entropy loss.

        ``batch`` unpacks to ``(videos, texts, lengths)``.  The targets are
        ``texts`` without its first column, i.e. everything but ``<start>``.
        """
        videos, texts, lengths = batch
        scores, _ = self(videos, texts, lengths, teacher_forcing=teacher_forcing)
        targets = texts[:, 1:].contiguous()
        return self.criterion(
            scores.view(-1, self.vocabulary_size), targets.view(-1)
        )

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute and log the fully teacher-forced loss for one batch."""
        loss = self._caption_loss(batch, teacher_forcing=1.0)
        self.log('train_loss', loss, on_epoch=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        """Compute and log the loss with teacher forcing switched off."""
        loss = self._caption_loss(batch, teacher_forcing=0.0)
        self.log('val_loss', loss, on_epoch=True)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        """Print and return the mean of the per-batch validation losses."""
        batch_losses = [step['val_loss'] for step in outputs]
        epoch_loss = torch.stack(batch_losses).mean()
        print('Validation loss %.2f' % epoch_loss)

        return {'val_loss': epoch_loss}

    def training_epoch_end(self, outputs):
        """Print the mean training loss taken from ``outputs[0]``.

        NOTE(review): ``outputs[0]`` is presumably the list of step results
        for the first optimizer -- confirm against the Lightning version in
        use.
        """
        batch_losses = [step['loss'] for step in outputs[0]]
        epoch_loss = torch.stack(batch_losses).mean()
        print('Training loss %.2f' % epoch_loss)

    def configure_optimizers(self):
        """Return the two optimizers and an empty second list.

        SGD updates only ``video_encoder.base_network.fc`` and
        ``video_encoder.bn``; Adam updates every ``text_decoder`` parameter.
        The second returned list is empty (presumably the slot Lightning
        uses for LR schedulers).

        An earlier variant, for an encoder exposing ``fc1``/``bn1``/``fc2``/
        ``bn2``/``fc3``, passed those layers to SGD instead.
        """
        encoder = self.video_encoder
        encoder_parameters = [
            *encoder.base_network.fc.parameters(),
            *encoder.bn.parameters(),
        ]
        encoder_optimizer = torch.optim.SGD(
            encoder_parameters, lr=self.video_encoder_learning_rate
        )
        decoder_optimizer = torch.optim.Adam(
            self.text_decoder.parameters(), lr=self.text_decoder_learning_rate
        )
        return [encoder_optimizer, decoder_optimizer], []