TODO:
- Build out the validation loop and metric logging, with TensorBoard support.

In [18]:
%load_ext autoreload
%autoreload 2

In [42]:
import torch 
import torch.nn as nn
import pytorch_lightning as pl
import constants
from torch.utils.data import DataLoader, TensorDataset

import numpy as np
from sklearn.model_selection import train_test_split

In [43]:
class DataModule(pl.LightningDataModule):
    """Lightning data module serving randomly generated placeholder data.

    Holds 100 sequences of ``constants.SEQ_LEN`` embeddings (each
    ``constants.EMBED_LEN`` wide) with one binary label per time step, and
    splits them 70/15/15 into train/validation/test sets.

    Attributes:
        x: float array of shape (100, SEQ_LEN, EMBED_LEN) with the embeddings.
        y: int array of shape (100, SEQ_LEN) with 0/1 labels.
        batch_size: batch size used by every dataloader.
        pos_weight: 1-element float32 tensor to pass as ``pos_weight`` to
            ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self, batch_size):
        """
        args:
            - batch_size: number of sequences per batch for all dataloaders
        """
        super().__init__()
        # 100 datapoints of SEQ_LEN consecutive EMBED_LEN-d embeddings
        self.x = np.random.randn(100, constants.SEQ_LEN, constants.EMBED_LEN)
        # 100 datapoints of SEQ_LEN consecutive binary labels
        self.y = np.random.randint(2, size=(100, constants.SEQ_LEN))
        self.batch_size = batch_size
        self.pos_weight = self.get_pos_weight()

    def get_pos_weight(self):
        """Compute the positive-class weight for ``nn.BCEWithLogitsLoss``.

        The loss expects ``pos_weight = #negatives / #positives`` so that a
        rare positive class is weighted *up*; the inverse ratio would push the
        model further towards predicting the majority class.

        returns:
            - float32 tensor of shape (1,). Falls back to 1.0 (no re-weighting)
              when either class is absent, since the ratio is then undefined
              or would zero out the loss.
        """
        labels = self.y.ravel()
        num_positive = int(labels.sum())
        num_negative = labels.size - num_positive
        if num_positive == 0 or num_negative == 0:
            return torch.ones(1, dtype=torch.float32)
        return torch.tensor([num_negative / num_positive], dtype=torch.float32)

    def setup(self, stage=None):
        """Split the data 70/15/15 and wrap each split in a ``TensorDataset``.

        args:
            - stage: stage name passed in by Lightning; unused because all
              three splits are always built.
        """
        # First carve off 30% of the data, then halve it into validation and test.
        X_train, X_test, y_train, y_test = train_test_split(self.x, self.y, test_size=0.30, random_state=42)
        X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=0.50, random_state=42)
        # Labels are cast to float because BCE-style losses need float targets.
        self.train = TensorDataset(torch.from_numpy(X_train).float(), torch.from_numpy(y_train).float())
        self.val = TensorDataset(torch.from_numpy(X_val).float(), torch.from_numpy(y_val).float())
        self.test = TensorDataset(torch.from_numpy(X_test).float(), torch.from_numpy(y_test).float())

    def train_dataloader(self):
        """Return the dataloader over the training split."""
        return DataLoader(self.train, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return the dataloader over the validation split."""
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the dataloader over the test split."""
        return DataLoader(self.test, batch_size=self.batch_size)

In [58]:
class LipReader(pl.LightningModule):
    """GRU sequence tagger that emits one binary logit per time step.

    Input sequences of ``constants.EMBED_LEN``-d embeddings are run through a
    (optionally bidirectional) GRU, and a linear head maps every GRU output
    to a single logit. Training uses ``nn.BCEWithLogitsLoss`` with a
    positive-class weight.
    """

    def __init__(self, hidden_size, num_layers, dropout_rate, bidirectional, pos_weight):
        """
        args:
            - hidden_size: GRU hidden size per direction
            - num_layers: number of stacked GRU layers
            - dropout_rate: dropout probability between stacked GRU layers
              (PyTorch applies it after every layer except the last, so it
              only has an effect when num_layers > 1)
            - bidirectional: whether the GRU runs in both directions
            - pos_weight: positive-class weight for the BCE loss (tensor,
              number, or None for no re-weighting)
        """
        super().__init__()
        num_directions = 2 if bidirectional else 1
        self.gru = nn.GRU(
            # Must agree with the input-shape assertion in forward().
            input_size=constants.EMBED_LEN,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout_rate,
        )
        self.linear = nn.Linear(in_features=num_directions * hidden_size, out_features=1)

        # Registered as a buffer so the weight follows the module across
        # devices; non-persistent so the state_dict layout does not change.
        if pos_weight is not None:
            pos_weight = torch.as_tensor(pos_weight, dtype=torch.float32)
        self.register_buffer("pos_weight", pos_weight, persistent=False)

        # For debugging
        self.num_directions = num_directions
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def forward(self, x):
        """
        args:
            - x: shape (batch_size, constants.SEQ_LEN, constants.EMBED_LEN)

        returns:
            - logits (not probabilities): shape (batch_size, constants.SEQ_LEN)
        """
        batch_size = x.shape[0]
        assert x.shape == (batch_size, constants.SEQ_LEN, constants.EMBED_LEN), x.shape

        output, _ = self.gru(x)
        assert output.shape == (batch_size, constants.SEQ_LEN, self.num_directions * self.hidden_size), output.shape

        score = self.linear(output)
        assert score.shape == (batch_size, constants.SEQ_LEN, 1), score.shape

        # Drop only the trailing singleton axis. A bare squeeze() would also
        # remove the batch axis for a batch of one sample.
        score = score.squeeze(-1)
        assert score.shape == (batch_size, constants.SEQ_LEN), score.shape

        return score

    def training_step(self, batch, batch_idx):
        """Compute the weighted BCE-with-logits loss for one batch.

        args:
            - batch: (x, y) pair; y holds float 0/1 targets shaped like the logits
            - batch_idx: index of the batch (unused)

        returns:
            - scalar loss tensor
        """
        x, y = batch
        y_hat = self(x)
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)
        loss = loss_fn(y_hat, y)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with default settings."""
        return torch.optim.Adam(self.parameters())

In [61]:
# Build the data module.
# NOTE(review): BATCH_SIZE is not defined in any cell shown here; presumably it
# comes from another cell or leftover kernel state. Confirm it is defined before
# this cell when running on a fresh kernel.
dm = DataModule(batch_size=BATCH_SIZE)
dm.setup('fit') # author's open question: try without fit?

# pos_weight is computed in DataModule.__init__, so it can be handed to the model here.
model = LipReader(hidden_size=32, num_layers=1, dropout_rate=0, bidirectional=True, pos_weight=dm.pos_weight)
# Trainer is created with its default configuration (no arguments overridden).
trainer = pl.Trainer()
trainer.fit(model, dm)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name   | Type   | Params
----------------------------------
0 | gru    | GRU    | 104 K 
1 | linear | Linear | 65    
----------------------------------
104 K     Trainable params
0         Non-trainable params
104 K     Total params
0.420     Total estimated model params size (MB)
Epoch 13:   0%|          | 0/2 [03:35<?, ?it/s, loss=0.543, v_num=6]
