In [1]:
import os
os.chdir('..')

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader, Subset

from sklearn.model_selection import train_test_split

import pytorch_lightning as pl

from nn_datasets.rnn import RnnDataset

In [3]:
dataset = torch.load('nn_datasets/precalculated_datasets/rnn_dataset.pt')

In [4]:
class LSTMClassifier(pl.LightningModule):
    def __init__(self, seq_len, input_size, 
                 hidden_layers = 1, hidden_size = 10, learning_rate = 1e-2):

        super(LSTMClassifier, self).__init__()

        self.learning_rate = learning_rate
        self.hidden_size = hidden_size
        self.hidden_layers = hidden_layers
        self.seq_len = seq_len
        self.input_size = input_size

        self.lstm = nn.LSTM(input_size = self.input_size,
                            hidden_size  = self.hidden_size,
                            num_layers = self.hidden_layers,
                            batch_first = True, dropout = .8)

        self.classifier = nn.Sequential(
            nn.BatchNorm1d(self.hidden_size),
            nn.Linear(self.hidden_size, 2),
            nn.BatchNorm1d(2)
        )                  


    def forward(self, inputs):

        out, (h_t, c_t) = self.lstm(inputs)
        out = self.classifier(out[:, -1, :])
        return out

    def training_step(self, batch, batch_idx):
        X, y = batch
        out = self.forward(X)
        loss = F.cross_entropy(out, y.long())
        accu = (y.long() == torch.argmax(out, axis = 1)).float().mean()
        self.log('train accuracy', accu, prog_bar = True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        X, y = batch
        out = self.forward(X)
        loss = F.cross_entropy(out, y.long())
        accu = (y.long() == torch.argmax(out, axis = 1)).float().mean()
        self.log('validation loss', loss, prog_bar = True)
        self.log('validation accuracy', accu, prog_bar = True)
        return loss

    def test_step(self, batch, batch_idx):
        X, y = batch
        out = self.forward(X)
        loss = F.cross_entropy(out, y.long())
        accu = (y.long() == torch.argmax(out, axis = 1)).float().mean()
        self.log('test loss', loss, prog_bar = True)
        self.log('test accuracy', accu, prog_bar = True)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), self.learning_rate, weight_decay = 1e-2)


In [5]:
train_indices, test_indices = train_test_split(list(range(len(dataset))), test_size = .4)
test_indices, val_indices = train_test_split(test_indices, test_size = .5)

train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
test_dataset = Subset(dataset, test_indices)

train_dataloader = DataLoader(train_dataset, batch_size = 32)
val_dataloader = DataLoader(val_dataset, batch_size = 32)
test_dataloader = DataLoader(test_dataset, batch_size = 32)

In [None]:
trainer = pl.Trainer(
    max_epochs = 5
)

model = LSTMClassifier(seq_len = dataset[0][0].shape[0],
                      input_size = dataset[0][0].shape[1],
                      hidden_layers = 2,
                      hidden_size = 10,
                      learning_rate = 1e-2)

trainer.fit(model, train_dataloader, val_dataloader)

In [10]:
trainer.test(model, test_dataloader)

  rank_zero_warn(


Testing:  84%|████████▍ | 37/44 [00:00<00:00, 92.58it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test accuracy': 0.7783321738243103, 'test loss': 0.503904402256012}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 44/44 [00:00<00:00, 88.28it/s]


[{'test loss': 0.503904402256012, 'test accuracy': 0.7783321738243103}]