In [1]:
import os


def enter_project_root(marker: str = 'nn_datasets') -> None:
    """
    Move the working directory one level up unless it already is the
    project root.

    The notebook reads files through paths relative to the project root
    (they start with ``nn_datasets/``). A bare ``os.chdir('..')`` climbs one
    more level every time the cell is re-run; checking for the marker
    directory first makes the cell idempotent.

    Args:
        marker (str, optional): directory name whose presence in the
            current working directory means no move is needed.
            Defaults to 'nn_datasets'.
    """
    if not os.path.isdir(marker):
        os.chdir('..')


enter_project_root()

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Subset, DataLoader

from sklearn.model_selection import train_test_split

import pytorch_lightning as pl

from nn_datasets.rnn import RnnDataset

In [3]:
# Location of the pre-computed dataset, relative to the working directory.
DATASET_PATH = 'nn_datasets/precalculated_datasets/rnn_dataset.pt'

# NOTE(review): torch.load deserialises with pickle by default, so only
# files from a trusted source should be loaded here.
dataset = torch.load(DATASET_PATH)

In [4]:
class RNNClassifier(pl.LightningModule):
    """
    Binary classifier with RNN module.

    The input batch is fed through ``nn.RNN``; the RNN output at the last
    time step is projected to two logits by a linear layer and the logits
    are passed through ``nn.BatchNorm1d``. Training, validation and test
    steps all use cross-entropy loss and log the batch accuracy.
    """
    def __init__(self, seq_len: int, input_size: int,
                 hidden_layers: int = 1, hidden_size: int = 10,
                 learning_rate: float = 1e-3,
                 weight_decay: float = 1e-3) -> None:
        """
        Args:
            seq_len (int): sequence len i.e. max length of sentence.
                Stored on the module; none of the layers depends on it.
            input_size (int): size of input vector
            hidden_layers (int, optional): number of hidden layers.
                Defaults to 1.
            hidden_size (int, optional): hidden layer size.
                Defaults to 10.
            learning_rate (float, optional): Adam learning rate.
                Defaults to 1e-3.
            weight_decay (float, optional): Adam weight decay.
                Defaults to 1e-3.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.hidden_layers = hidden_layers
        self.seq_len = seq_len
        self.input_size = input_size

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        # batch_first=True: the RNN indexes its input as (batch, step, feature).
        self.rnn = nn.RNN(input_size=self.input_size,
                          hidden_size=self.hidden_size,
                          num_layers=self.hidden_layers,
                          batch_first=True)
        self.linear = nn.Linear(self.hidden_size, 2)
        self.norm = nn.BatchNorm1d(2)

    def forward(self, inputs):
        """
        Compute the two normalised class logits for a batch of sequences.

        Args:
            inputs: batch passed straight to ``nn.RNN`` (batch-first).

        Returns:
            Tensor with one row of two logits per batch element.
        """
        # Only the per-step outputs are needed; the final hidden state is
        # discarded.
        out, _ = self.rnn(inputs)
        # Classify from the output of the last time step only.
        output = self.linear(out[:, -1, :])

        # NOTE(review): BatchNorm1d in training mode rejects a batch with a
        # single sample - confirm the training loader never yields one.
        return self.norm(output)

    def _shared_step(self, batch):
        """
        Loss/accuracy computation shared by the train, val and test steps.

        Args:
            batch: pair ``(X, y)`` of inputs and class labels.

        Returns:
            Tuple ``(loss, accu)``: cross-entropy loss and the mean accuracy
            over the batch.
        """
        X, y = batch
        target = y.long()  # cross_entropy needs integer class indices
        logits = self(X)
        loss = F.cross_entropy(logits, target)
        accu = (target == torch.argmax(logits, dim=1)).float().mean()
        return loss, accu

    def training_step(self, batch, batch_idx):
        """Return the training loss and log training loss and accuracy."""
        loss, accu = self._shared_step(batch)
        self.log('train loss', loss)
        self.log('train accuracy', accu, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the validation loss and log validation loss and accuracy."""
        loss, accu = self._shared_step(batch)
        self.log('validation loss', loss, prog_bar=True)
        self.log('validation accuracy', accu, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Return the test loss and log test loss and accuracy."""
        loss, accu = self._shared_step(batch)
        self.log('test loss', loss, prog_bar=True)
        self.log('test accuracy', accu, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters with the configured lr and weight decay."""
        return torch.optim.Adam(self.parameters(),
                                lr=self.learning_rate,
                                weight_decay=self.weight_decay)

In [5]:
# Fixed seed so a fresh "Restart & Run All" reproduces the same split.
SPLIT_SEED = 42
BATCH_SIZE = 32

# 60% train; the remaining 40% is halved into test and validation (20% each).
train_indices, holdout_indices = train_test_split(list(range(len(dataset))),
                                                  test_size=.4,
                                                  random_state=SPLIT_SEED)
test_indices, val_indices = train_test_split(holdout_indices,
                                             test_size=.5,
                                             random_state=SPLIT_SEED)

train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
test_dataset = Subset(dataset, test_indices)

# Reshuffle the training samples every epoch; evaluation loaders keep a
# fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [None]:
# Train for at most ten epochs.
trainer = pl.Trainer(max_epochs=10)

# The model dimensions come from the first element of the first dataset item:
# its leading two dimensions are used as sequence length and input size.
sample_shape = dataset[0][0].shape
model = RNNClassifier(
    seq_len=sample_shape[0],
    input_size=sample_shape[1],
    hidden_layers=3,
    hidden_size=10,
    learning_rate=1e-3,
)

# Fit on the training loader, validating on the validation loader.
trainer.fit(model, train_dataloader, val_dataloader)

In [7]:
# Evaluate the fitted model on the test dataloader. The call is left as the
# cell's last expression so the returned list of metric dicts is displayed.
trainer.test(model, test_dataloader)

  rank_zero_warn(


Testing:  64%|██████▎   | 28/44 [00:00<00:00, 139.99it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test accuracy': 0.7875980138778687, 'test loss': 0.49262693524360657}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 44/44 [00:00<00:00, 150.28it/s]


[{'test loss': 0.49262693524360657, 'test accuracy': 0.7875980138778687}]