In [1]:
import pickle
import numpy as np
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchmetrics
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint


# Load the preprocessed recordings. The pickle holds four objects that the
# following cells index trial-by-trial with a shared index.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
with open("data/processed_monkey_dataset.pkl", "rb") as dataset_file:
    loaded = pickle.load(dataset_file)

target, count_spikes, x_vel, y_vel = loaded

print(len(target))

800


In [2]:
# Fixed seed so the train/valid/test partition is identical after every
# kernel restart; without it each run evaluates on a different test set.
SPLIT_SEED = 42
N_TEST = 100
N_VALID = 100

assert len(target) > N_TEST + N_VALID, "dataset too small for the requested split sizes"

# Shuffle trial indices once, then carve out disjoint test / valid / train slices.
index = np.random.default_rng(SPLIT_SEED).permutation(len(target))
test_index = index[:N_TEST]
valid_index = index[N_TEST:N_TEST + N_VALID]
train_index = index[N_TEST + N_VALID:]


def get_sampled_data(index, target, count_spikes, x_vel, y_vel):
    """Pick the entries listed in ``index`` from four parallel sequences.

    Args:
        index: iterable of positions to select, in the order they should appear.
        target, count_spikes, x_vel, y_vel: indexable containers that share
            the same positional indexing.

    Returns:
        A 4-tuple of lists ``(target, count_spikes, x_vel, y_vel)``, each
        containing the selected entries in ``index`` order.
    """
    selected = []
    for series in (target, count_spikes, x_vel, y_vel):
        selected.append([series[position] for position in index])

    return tuple(selected)


# Materialise each split by selecting its trial indices from all four sequences.
train_target, train_count_spikes, train_x_vel, train_y_vel = get_sampled_data(
    train_index, target, count_spikes, x_vel, y_vel
)
valid_target, valid_count_spikes, valid_x_vel, valid_y_vel = get_sampled_data(
    valid_index, target, count_spikes, x_vel, y_vel
)
test_target, test_count_spikes, test_x_vel, test_y_vel = get_sampled_data(
    test_index, target, count_spikes, x_vel, y_vel
)

In [3]:
class NeuralDataset(Dataset):
    """Fixed-length dataset of spike counts (inputs) and 2-D velocities (targets).

    Every per-trial array is left-padded with zeros along its last axis up to
    ``max_seq_len`` so that all trials can be stacked into one tensor.

    Args:
        spike_counts: sequence of 2-D arrays; the last axis is the one padded.
            # assumes layout (channels, time) — TODO confirm against the preprocessing
        vel_x: sequence of 1-D arrays, padded the same way.
        vel_y: sequence of 1-D arrays, padded the same way.
        max_seq_len: length every trial is padded to.

    Raises:
        ValueError: if an array is not 1-D or 2-D, or is longer than
            ``max_seq_len`` along its last axis.
    """

    def __init__(self, spike_counts, vel_x, vel_y, max_seq_len=100):
        self.spike_counts = self._left_pad(spike_counts, max_seq_len)
        self.vel_x = self._left_pad(vel_x, max_seq_len)
        self.vel_y = self._left_pad(vel_y, max_seq_len)

    @staticmethod
    def _left_pad(sequences, max_seq_len):
        """Zero-pad each array at the start of its last axis and stack into a float tensor."""
        padded = []

        for position, sequence in enumerate(sequences):
            array = np.asarray(sequence)

            # Dispatch on dimensionality explicitly instead of relying on a
            # bare ``except`` to tell 1-D from 2-D input.
            if array.ndim not in (1, 2):
                raise ValueError(
                    f"sequence {position} has {array.ndim} dimensions; expected 1 or 2"
                )

            n_missing = max_seq_len - array.shape[-1]
            if n_missing < 0:
                raise ValueError(
                    f"sequence {position} has length {array.shape[-1]}, "
                    f"which exceeds max_seq_len={max_seq_len}"
                )

            # Pad only the last axis, and only on the left.
            pad_width = [(0, 0)] * (array.ndim - 1) + [(n_missing, 0)]
            padded.append(np.pad(array, pad_width, mode='constant', constant_values=0))

        if not padded:
            return torch.tensor([], dtype=torch.float)

        # Stack in NumPy first: building a tensor from a list of ndarrays is
        # slow and triggers a PyTorch warning.
        return torch.tensor(np.stack(padded), dtype=torch.float)

    def __len__(self):
        """Number of trials."""
        return self.spike_counts.shape[0]

    def __getitem__(self, idx):
        """Return ``(x, y)`` for one trial.

        ``x`` is the transposed padded spike-count matrix; ``y`` stacks the x and
        y velocities along a new last axis, with NaNs replaced by
        ``torch.nan_to_num``.
        """
        x = torch.t(self.spike_counts[idx])
        vel_x = self.vel_x[idx]
        vel_y = self.vel_y[idx]

        y = torch.nan_to_num(torch.stack((vel_x, vel_y), 1))

        return x, y
    

# Shared settings for all three splits.
MAX_SEQ_LEN = 100
BATCH_SIZE = 32

train_dataset = NeuralDataset(train_count_spikes, train_x_vel, train_y_vel, max_seq_len=MAX_SEQ_LEN)
# Shuffle the training set so each epoch sees different batches; without it
# every epoch replays the same batches in the same order.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

# Evaluation loaders keep a fixed order.
valid_dataset = NeuralDataset(valid_count_spikes, valid_x_vel, valid_y_vel, max_seq_len=MAX_SEQ_LEN)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, num_workers=0)

test_dataset = NeuralDataset(test_count_spikes, test_x_vel, test_y_vel, max_seq_len=MAX_SEQ_LEN)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=0)

In [4]:
class SimpleRNN(nn.Module):
    """Stacked GRU followed by ReLU and a linear readout applied at every time step.

    Args:
        channels: size of the last dimension of the input.
        hidden_size: width of the GRU hidden state (default 256).
        num_layers: number of stacked GRU layers (default 2).
        out_features: size of the last dimension of the output (default 2).

    The GRU is built with ``batch_first=True``, so ``forward`` expects input of
    shape ``(batch, seq, channels)`` and returns ``(batch, seq, out_features)``.
    """

    def __init__(self, channels=98, hidden_size=256, num_layers=2, out_features=2):
        super().__init__()
        self.rnn = nn.GRU(channels, hidden_size, num_layers=num_layers, batch_first=True)
        self.out = nn.Linear(hidden_size, out_features)
        # Out-of-place ReLU: an in-place one would overwrite the GRU's output
        # tensor, which autograd may still need for the backward pass (e.g.
        # with fused RNN kernels). The computed values are identical.
        self.activation = nn.ReLU()

    def forward(self, x):
        """Map ``(batch, seq, channels)`` inputs to ``(batch, seq, out_features)`` predictions."""
        x, _ = self.rnn(x)  # keep per-step outputs, discard the final hidden state
        x = self.activation(x)
        x = self.out(x)

        return x


In [5]:
class Regressor(pl.LightningModule):
    """LightningModule that fits ``model`` with an MSE loss and logs MAE per split.

    Args:
        model: module called as ``model(x)`` to produce predictions.
        learning_rate: learning rate passed to the Adam optimizer.
    """

    def __init__(self, model, learning_rate):
        super().__init__()
        self.learning_rate = learning_rate
        self.model = model

        # One metric object per split so their accumulated states never mix.
        self.train_mae = torchmetrics.MeanAbsoluteError()
        self.valid_mae = torchmetrics.MeanAbsoluteError()
        self.test_mae = torchmetrics.MeanAbsoluteError()

    def _predict_and_score(self, batch):
        """Run the model on one batch; return ``(predictions, targets, mse_loss)``."""
        inputs, targets = batch
        predictions = self.model(inputs)
        return predictions, targets, F.mse_loss(predictions, targets)

    def training_step(self, batch, batch_idx):
        """Log the training loss and MAE, and return the loss for optimisation."""
        predictions, targets, loss = self._predict_and_score(batch)

        self.log("train_loss", loss)
        self.log(
            "train_mae",
            self.train_mae(predictions, targets),
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and epoch-level MAE."""
        predictions, targets, loss = self._predict_and_score(batch)

        self.log("valid_loss", loss)
        self.log(
            "valid_mae",
            self.valid_mae(predictions, targets),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

    def test_step(self, batch, batch_idx):
        """Log the test loss and epoch-level MAE."""
        predictions, targets, loss = self._predict_and_score(batch)

        self.log("test_loss", loss)
        self.log(
            "test_mae",
            self.test_mae(predictions, targets),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

    def configure_optimizers(self):
        """Adam plus a ReduceLROnPlateau scheduler driven by ``valid_loss``."""
        adam = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(adam, patience=5)

        return dict(optimizer=adam, lr_scheduler=plateau, monitor="valid_loss")

callbacks = [
    # Stop once valid_loss has failed to improve for 3 consecutive validation runs.
    EarlyStopping(monitor="valid_loss", min_delta=0.00, patience=3),
    # Keep the checkpoint with the lowest valid_loss. Without `monitor`,
    # ModelCheckpoint only retains the most recent epoch, which is not
    # necessarily the best one despite the metric appearing in the filename.
    ModelCheckpoint(
        dirpath='weights/RNN',
        filename='{epoch}-{valid_loss:.8f}-{valid_mae:.8f}',
        monitor="valid_loss",
        mode="min",
    ),
]

In [6]:
# Build the network and wrap it in the Lightning training module.
model = SimpleRNN()
regressor = Regressor(model, learning_rate=0.0001)

# CPU-only run, capped at 20 epochs; the callbacks list supplies early
# stopping and checkpointing.
trainer = pl.Trainer(
    accelerator="cpu",
    devices=1,
    max_epochs=20,
    callbacks=callbacks,
    enable_progress_bar=True,
)
trainer.fit(regressor, train_dataloader, valid_dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name      | Type              | Params
------------------------------------------------
0 | model     | SimpleRNN         | 668 K 
1 | train_mae | MeanAbsoluteError | 0     
2 | valid_mae | MeanAbsoluteError | 0     
3 | test_mae  | MeanAbsoluteError | 0     
------------------------------------------------
668 K     Trainable params
0         Non-trainable params
668 K     Total params
2.675     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [13]:
# Evaluate the checkpoint tracked by the ModelCheckpoint callback. A hard-coded
# filename embeds one run's epoch and validation loss, so it would not exist
# after a fresh Restart & Run All.
trainer.test(regressor, test_dataloader, ckpt_path="best")

Restoring states from the checkpoint path at weights/RNN/epoch=19-valid_loss=0.00002286-valid_mae=0.00282729.ckpt
Loaded model weights from checkpoint at weights/RNN/epoch=19-valid_loss=0.00002286-valid_mae=0.00282729.ckpt


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 2.197964749939274e-05, 'test_mae': 0.0028090246487408876}
--------------------------------------------------------------------------------


[{'test_loss': 2.197964749939274e-05, 'test_mae': 0.0028090246487408876}]