In [1]:
import torch
import torch.nn as nn
# Sanity check: is a CUDA device visible before the long training runs below?
gpu_available = torch.cuda.is_available()
gpu_available

True

In [2]:
from isaac.dataset import read_dataset, prepare_dataset
from isaac.constants import POSITION_COLS, MASS_CLASS_COLS, BASIC_TRAINING_COLS, FORCE_CLASS_COLS
from isaac.sanity import class_proportions
from isaac.models import initialise_model
from isaac.training import training_loop
import matplotlib.pyplot as plt
import numpy as np
import matplotlib2tikz
import pandas as pd

In [3]:
# Run configuration shared by the mass and force experiments.
# -- data preparation (read_dataset / prepare_dataset) --
TR_COLS = BASIC_TRAINING_COLS   # columns read from the HDF5 files and used as network inputs
NORMALISE_DATA = True           # forwarded as prepare_dataset(normalise_data=...)
BATCH_SIZE = 128                # forwarded as prepare_dataset(batch_size=...)
# -- training loop --
EPOCHS = 25
SEQ_END = 1800                  # forwarded as training_loop(seq_end=...)
STEP_SIZE = 1                   # forwarded as training_loop(step_size=...)

In [4]:
# Source files and how many trials to read from each split.
TRAIN_PATH, N_TRAIN_TRIALS = "data/train_passive_trials.h5", 3500
VAL_PATH, N_VAL_TRIALS = "data/val_passive_trials.h5", 500

train_trials = read_dataset(TRAIN_PATH, n_trials=N_TRAIN_TRIALS, cols=TR_COLS)
val_trials = read_dataset(VAL_PATH, n_trials=N_VAL_TRIALS, cols=TR_COLS)

100%|██████████| 3500/3500 [00:30<00:00, 114.55it/s]
100%|██████████| 500/500 [00:04<00:00, 101.07it/s]


In [5]:
# Recurrent-network architecture, shared by every cell type compared below.
INPUT_DIM = len(TR_COLS)  # one input feature per training column
HIDDEN_DIM = 25           # hidden layer dimension
N_LAYERS = 4              # number of hidden layers
OUTPUT_DIM = 3            # output dimension (class_proportions reports three classes)
DROPOUT = 0.5             # dropout value handed to initialise_model — TODO confirm where it is applied

# Passed to initialise_model as a single tuple; presumably consumed positionally, so keep this order.
network_params = (
    INPUT_DIM,
    HIDDEN_DIM,
    N_LAYERS,
    OUTPUT_DIM,
    DROPOUT,
)

# Mass training: compare GRU, RNN and LSTM cells (three seeds each) on the mass labels

In [6]:
# Build train/val loaders whose targets are the mass classes.
loaders, scaler = prepare_dataset(
    [train_trials, val_trials],
    class_columns=MASS_CLASS_COLS,
    training_columns=TR_COLS,
    batch_size=BATCH_SIZE,
    normalise_data=NORMALISE_DATA,
)

100%|██████████| 3500/3500 [00:03<00:00, 1115.94it/s]
100%|██████████| 500/500 [00:00<00:00, 1054.88it/s]


In [7]:
# prepare_dataset returns one loader per input split, in the order [train, val].
[train_loader, val_loader] = loaders

In [8]:
# Print class counts and the majority-class baseline for each split.
for split_loader in (train_loader, val_loader):
    class_proportions(split_loader)

[1188 1174 1138]
Majority class:  0.3394285714285714
[185 157 158]
Majority class:  0.37


In [9]:
# Train every recurrent cell type with several seeds on the mass task and
# collect per-epoch loss/accuracy into one long-format DataFrame (`stats`).
MASS_CELL_TYPES = [nn.GRU, nn.RNN, nn.LSTM]
SEEDS = [0, 42, 72]
# Labels are derived from the classes themselves (nn.GRU.__name__ == "GRU"), so
# they cannot drift out of sync with the class that was actually trained.
labels = [cell_type.__name__ for cell_type in MASS_CELL_TYPES]  # ["GRU", "RNN", "LSTM"]

stats_dfs = []

for cell_type in MASS_CELL_TYPES:
    cell_label = cell_type.__name__
    print(cell_label)
    for seed in SEEDS:
        model, error, optimizer = initialise_model(network_params, lr=0.01, cell_type=cell_type, seed=seed)
        epoch_losses, epoch_accuracies, best_model = training_loop(model, optimizer, error, train_loader,
                                                                   val_loader, EPOCHS, seq_end=SEQ_END,
                                                                   step_size=STEP_SIZE)

        # One row per epoch; scalar columns (cell_type, seed) are broadcast by pandas.
        # `seed` is recorded so individual runs stay distinguishable after concatenation.
        df = pd.DataFrame({
            "cell_type": cell_label,
            "Epoch": np.arange(EPOCHS),
            "Loss": epoch_losses,
            "Train Accuracy": epoch_accuracies[0],
            "Val Accuracy": epoch_accuracies[1],
            "seed": seed,
        })
        stats_dfs.append(df)

# ignore_index gives every row a unique index instead of repeating 0..EPOCHS-1 per run.
stats = pd.concat(stats_dfs, ignore_index=True)
# The force cell reassigns `stats`; keep a task-specific handle so these results are not lost.
mass_stats = stats
# stats.to_hdf("cell_type_choice_plots/mass_stats.h5", key="stats")

<class 'torch.nn.modules.rnn.GRU'>


Train_loss (0.90)	 Train_acc (58.17)	 Val_acc (51.40): 100%|██████████| 25/25 [08:45<00:00, 21.07s/it]
Train_loss (0.67)	 Train_acc (71.77)	 Val_acc (54.60): 100%|██████████| 25/25 [08:56<00:00, 21.51s/it]
Train_loss (0.82)	 Train_acc (62.34)	 Val_acc (56.00): 100%|██████████| 25/25 [08:59<00:00, 21.43s/it]
  0%|          | 0/25 [00:00<?, ?it/s]

<class 'torch.nn.modules.rnn.RNN'>


Train_loss (1.10)	 Train_acc (33.43)	 Val_acc (32.20): 100%|██████████| 25/25 [07:40<00:00, 18.42s/it]
Train_loss (1.11)	 Train_acc (31.17)	 Val_acc (31.00): 100%|██████████| 25/25 [07:42<00:00, 18.57s/it]
Train_loss (1.10)	 Train_acc (33.51)	 Val_acc (30.40): 100%|██████████| 25/25 [07:41<00:00, 18.53s/it]
  0%|          | 0/25 [00:00<?, ?it/s]

<class 'torch.nn.modules.rnn.LSTM'>


Train_loss (1.09)	 Train_acc (38.09)	 Val_acc (35.60): 100%|██████████| 25/25 [08:26<00:00, 20.17s/it]
Train_loss (1.08)	 Train_acc (37.66)	 Val_acc (38.00): 100%|██████████| 25/25 [08:25<00:00, 20.27s/it]
Train_loss (0.93)	 Train_acc (57.49)	 Val_acc (35.20): 100%|██████████| 25/25 [08:25<00:00, 20.25s/it]


# Force training: repeat the GRU/RNN/LSTM comparison (three seeds each) on the force labels

In [10]:
# Rebuild train/val loaders, this time targeting the force classes.
loaders, scaler = prepare_dataset(
    [train_trials, val_trials],
    class_columns=FORCE_CLASS_COLS,
    training_columns=TR_COLS,
    batch_size=BATCH_SIZE,
    normalise_data=NORMALISE_DATA,
)

100%|██████████| 3500/3500 [00:03<00:00, 995.21it/s] 
100%|██████████| 500/500 [00:00<00:00, 1027.18it/s]


In [11]:
# Same split order as the inputs to prepare_dataset: [train, val].
[train_loader, val_loader] = loaders

In [12]:
# Class balance of the force labels, per split.
for split_loader in (train_loader, val_loader):
    class_proportions(split_loader)

[1183 1169 1148]
Majority class:  0.338
[164 174 162]
Majority class:  0.348


In [13]:
# Train every recurrent cell type with several seeds on the force task and
# collect per-epoch loss/accuracy into one long-format DataFrame (`stats`).
#
# The label is taken from the class itself (nn.RNN.__name__ == "RNN") instead of
# zipping against the GRU/RNN/LSTM-ordered `labels` list. That zip mislabelled
# these runs: nn.RNN as "GRU", nn.LSTM as "RNN", and nn.GRU as "LSTM".
FORCE_CELL_TYPES = [nn.RNN, nn.LSTM, nn.GRU]
SEEDS = [0, 42, 72]

stats_dfs = []

for cell_type in FORCE_CELL_TYPES:
    cell_label = cell_type.__name__
    print(cell_label)  # makes it clear which progress bars belong to which cell type
    for seed in SEEDS:
        model, error, optimizer = initialise_model(network_params, lr=0.01, cell_type=cell_type, seed=seed)
        epoch_losses, epoch_accuracies, best_model = training_loop(model, optimizer, error, train_loader,
                                                                   val_loader, EPOCHS, seq_end=SEQ_END,
                                                                   step_size=STEP_SIZE)

        # One row per epoch; scalar columns (cell_type, seed) are broadcast by pandas.
        df = pd.DataFrame({
            "cell_type": cell_label,
            "Epoch": np.arange(EPOCHS),
            "Loss": epoch_losses,
            "Train Accuracy": epoch_accuracies[0],
            "Val Accuracy": epoch_accuracies[1],
            "seed": seed,
        })
        stats_dfs.append(df)

# ignore_index gives every row a unique index instead of repeating 0..EPOCHS-1 per run.
stats = pd.concat(stats_dfs, ignore_index=True)
force_stats = stats
# stats.to_hdf("cell_type_choice_plots/force_stats.h5", key="stats")

Train_loss (1.10)	 Train_acc (34.83)	 Val_acc (32.20): 100%|██████████| 25/25 [07:40<00:00, 18.56s/it]
Train_loss (1.10)	 Train_acc (33.40)	 Val_acc (30.80): 100%|██████████| 25/25 [07:43<00:00, 18.49s/it]
Train_loss (1.11)	 Train_acc (34.06)	 Val_acc (30.60): 100%|██████████| 25/25 [07:40<00:00, 18.31s/it]
Train_loss (0.91)	 Train_acc (55.71)	 Val_acc (51.00): 100%|██████████| 25/25 [08:27<00:00, 20.42s/it]
Train_loss (0.97)	 Train_acc (49.60)	 Val_acc (44.80): 100%|██████████| 25/25 [08:32<00:00, 20.31s/it]
Train_loss (0.92)	 Train_acc (52.31)	 Val_acc (46.00): 100%|██████████| 25/25 [08:29<00:00, 20.44s/it]
Train_loss (0.65)	 Train_acc (57.26)	 Val_acc (49.40): 100%|██████████| 25/25 [08:47<00:00, 21.00s/it]
Train_loss (0.65)	 Train_acc (57.06)	 Val_acc (58.20): 100%|██████████| 25/25 [08:44<00:00, 21.10s/it]
Train_loss (0.67)	 Train_acc (60.09)	 Val_acc (48.60): 100%|██████████| 25/25 [08:46<00:00, 21.14s/it]
