In [1]:
import torch
import torch.nn as nn

In [2]:
from isaac.dataset import read_dataset, prepare_dataset
from isaac.constants import POSITION_COLS, MASS_CLASS_COLS, BASIC_TRAINING_COLS, FORCE_CLASS_COLS
from isaac.sanity import class_proportions
from isaac.models import initialise_model
from isaac.training import training_loop
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [3]:
from isaac.utils import get_cuda_device_if_available, create_directory
# Select the compute device via the isaac.utils helper and print it so the run records it.
device = get_cuda_device_if_available()
print(device)

# Folder name referenced by the (currently commented-out) stats.to_hdf calls further down.
directory = "cell_type_choice_plots/"
create_directory(directory)

cuda:0


In [5]:
# --- Data preparation (consumed by read_dataset / prepare_dataset) ---
TR_COLS = BASIC_TRAINING_COLS   # columns read from the trials and used as network inputs
NORMALISE_DATA = True           # forwarded as normalise_data=...
BATCH_SIZE = 128                # forwarded as batch_size=...

# --- Training loop (consumed by training_loop) ---
EPOCHS = 100
SEQ_END = 1800                  # forwarded as seq_end=...
STEP_SIZE = 1                   # forwarded as step_size=...

In [6]:
# Read the train and validation passive-trial files, keeping only the TR_COLS columns.
# The (path, n_trials) pairs are processed in order: train first, then validation.
train_trials, val_trials = [
    read_dataset(path, n_trials=n_trials, cols=TR_COLS)
    for path, n_trials in [
        ("data/train_passive_trials.h5", 3500),
        ("data/val_passive_trials.h5", 1000),
    ]
]

100%|██████████| 10/10 [00:00<00:00, 110.57it/s]
100%|██████████| 10/10 [00:00<00:00, 114.52it/s]


In [7]:
# Architecture settings handed to initialise_model for every cell type below.
INPUT_DIM = len(TR_COLS)  # one network input per training column
OUTPUT_DIM = 3            # size of the output layer
HIDDEN_DIM = 25           # size of each hidden layer
N_LAYERS = 4              # how many hidden layers
DROPOUT = 0.5

# Positional bundle; order is (input, hidden, layers, output, dropout).
network_params = (INPUT_DIM, HIDDEN_DIM, N_LAYERS, OUTPUT_DIM, DROPOUT)

# MASS TRAINING

In [8]:
# Build the loaders for the mass task: targets come from MASS_CLASS_COLS,
# inputs from TR_COLS. `scaler` is returned alongside the loaders.
loaders, scaler = prepare_dataset(
    [train_trials, val_trials],
    class_columns=MASS_CLASS_COLS,
    training_columns=TR_COLS,
    batch_size=BATCH_SIZE,
    normalise_data=NORMALISE_DATA,
    device=device,
)

100%|██████████| 10/10 [00:00<00:00, 1016.04it/s]
100%|██████████| 10/10 [00:00<00:00, 1011.55it/s]


In [9]:
# `loaders` follows the order of the trial list given to prepare_dataset: train first, validation second.
# NOTE(review): these names are rebound later for the force task, so cell order matters.
train_loader, val_loader = loaders

In [10]:
# Sanity check on label balance for each split. Judging by the recorded output,
# this prints per-class counts and the majority-class share (the chance-level baseline).
class_proportions(train_loader)
class_proportions(val_loader)

[1 5 4]
Majority class:  0.5
[3 3 4]
Majority class:  0.4


In [11]:
# Mass task: train every recurrent cell type with several seeds and collect the
# per-epoch loss / accuracy curves into one long-format DataFrame (`stats`).
labels = ["GRU", "RNN", "LSTM"]
cell_types = [nn.GRU, nn.RNN, nn.LSTM]  # must stay index-aligned with `labels`
assert len(cell_types) == len(labels), "every cell type needs exactly one label"

stats_dfs = []

for cell_type, cell_label in zip(cell_types, labels):
    print(cell_type)
    for seed in [0, 42, 72]:
        model, error, optimizer = initialise_model(network_params, lr=0.01, cell_type=cell_type, seed=seed, device=device)
        epoch_losses, epoch_accuracies, best_model = training_loop(model, optimizer, error, train_loader,
                                                                   val_loader, EPOCHS, seq_end=SEQ_END,
                                                                   step_size=STEP_SIZE)

        # Build the frame in one go. The epoch index is sized from what
        # training_loop actually returned rather than from EPOCHS, so a run that
        # yields a different number of entries does not fail with a length
        # mismatch. epoch_accuracies is indexed as [0] = train, [1] = validation.
        df = pd.DataFrame({
            "cell_type": cell_label,
            "Epoch": np.arange(len(epoch_losses)),
            "Loss": epoch_losses,
            "Train Accuracy": epoch_accuracies[0],
            "Val Accuracy": epoch_accuracies[1],
        })
        stats_dfs.append(df)

# One block of rows per (cell type, seed) run, stacked in training order.
stats = pd.concat(stats_dfs)
# Persistence is intentionally disabled; re-enable to keep the mass results,
# because `stats` is rebound by the force-training section later on.
# stats.to_hdf(directory+"mass_stats.h5", key="stats")

  0%|          | 0/2 [00:00<?, ?it/s]

<class 'torch.nn.modules.rnn.GRU'>


Train_loss (1.11)	 Train_acc (70.00)	 Val_acc (30.00): 100%|██████████| 2/2 [00:05<00:00,  2.60s/it] 
Train_loss (1.16)	 Train_acc (60.00)	 Val_acc (30.00): 100%|██████████| 2/2 [00:05<00:00,  2.66s/it] 
Train_loss (1.15)	 Train_acc (50.00)	 Val_acc (40.00): 100%|██████████| 2/2 [00:05<00:00,  2.68s/it] 
  0%|          | 0/2 [00:00<?, ?it/s]

<class 'torch.nn.modules.rnn.RNN'>


Train_loss (0.97)	 Train_acc (50.00)	 Val_acc (30.00): 100%|██████████| 2/2 [00:02<00:00,  1.02s/it] 
Train_loss (0.98)	 Train_acc (40.00)	 Val_acc (40.00): 100%|██████████| 2/2 [00:02<00:00,  1.00s/it] 
Train_loss (1.01)	 Train_acc (50.00)	 Val_acc (30.00): 100%|██████████| 2/2 [00:02<00:00,  1.01s/it] 
  0%|          | 0/2 [00:00<?, ?it/s]

<class 'torch.nn.modules.rnn.LSTM'>


Train_loss (1.16)	 Train_acc (50.00)	 Val_acc (30.00): 100%|██████████| 2/2 [00:05<00:00,  2.82s/it] 
Train_loss (1.13)	 Train_acc (40.00)	 Val_acc (40.00): 100%|██████████| 2/2 [00:05<00:00,  2.84s/it] 
Train_loss (1.09)	 Train_acc (50.00)	 Val_acc (30.00): 100%|██████████| 2/2 [00:05<00:00,  2.87s/it] 


# FORCE TRAINING

In [12]:
# Rebuild the loaders for the force task: same trials and inputs as before,
# but targets now come from FORCE_CLASS_COLS. Rebinds `loaders` and `scaler`.
loaders, scaler = prepare_dataset(
    [train_trials, val_trials],
    class_columns=FORCE_CLASS_COLS,
    training_columns=TR_COLS,
    batch_size=BATCH_SIZE,
    normalise_data=NORMALISE_DATA,
    device=device,
)

100%|██████████| 10/10 [00:00<00:00, 466.68it/s]
100%|██████████| 10/10 [00:00<00:00, 685.83it/s]


In [13]:
# Rebinds train_loader / val_loader to the force-task loaders (train first, validation second).
# NOTE(review): this overwrites the mass-task loaders, so cells must be run in order.
train_loader, val_loader = loaders

In [14]:
# Sanity check on label balance for the force-task splits. Judging by the recorded
# output, this prints per-class counts and the majority-class share.
class_proportions(train_loader)
class_proportions(val_loader)

[3 3 4]
Majority class:  0.4
[3 4 3]
Majority class:  0.4


In [15]:
# Force task: train every recurrent cell type with several seeds and collect the
# per-epoch loss / accuracy curves into one long-format DataFrame (`stats`).
#
# Each class is paired with its own label explicitly. Zipping against the
# `labels` list from the mass section (["GRU", "RNN", "LSTM"]) with the class
# order used here (RNN, LSTM, GRU) recorded every run under the wrong cell type.
cell_types = [(nn.RNN, "RNN"), (nn.LSTM, "LSTM"), (nn.GRU, "GRU")]

stats_dfs = []

for cell_type, cell_label in cell_types:
    print(cell_type)
    for seed in [0, 42, 72]:
        model, error, optimizer = initialise_model(network_params, lr=0.01, cell_type=cell_type, seed=seed, device=device)
        epoch_losses, epoch_accuracies, best_model = training_loop(model, optimizer, error, train_loader,
                                                                   val_loader, EPOCHS, seq_end=SEQ_END,
                                                                   step_size=STEP_SIZE)

        # Build the frame in one go. The epoch index is sized from what
        # training_loop actually returned rather than from EPOCHS, so a run that
        # yields a different number of entries does not fail with a length
        # mismatch. epoch_accuracies is indexed as [0] = train, [1] = validation.
        df = pd.DataFrame({
            "cell_type": cell_label,
            "Epoch": np.arange(len(epoch_losses)),
            "Loss": epoch_losses,
            "Train Accuracy": epoch_accuracies[0],
            "Val Accuracy": epoch_accuracies[1],
        })
        stats_dfs.append(df)

# One block of rows per (cell type, seed) run, stacked in training order.
# NOTE: this rebinds `stats`, replacing the mass-task results held in memory.
stats = pd.concat(stats_dfs)
# Persistence is intentionally disabled; re-enable to keep the force results.
# stats.to_hdf(directory+"force_stats.h5", key="stats")

  0%|          | 0/2 [00:00<?, ?it/s]

<class 'torch.nn.modules.rnn.RNN'>


Train_loss (1.18)	 Train_acc (50.00)	 Val_acc (30.00): 100%|██████████| 2/2 [00:02<00:00,  1.08s/it] 
Train_loss (1.05)	 Train_acc (40.00)	 Val_acc (30.00): 100%|██████████| 2/2 [00:02<00:00,  1.02s/it] 
Train_loss (1.14)	 Train_acc (60.00)	 Val_acc (40.00): 100%|██████████| 2/2 [00:02<00:00,  1.00it/s] 
  0%|          | 0/2 [00:00<?, ?it/s]

<class 'torch.nn.modules.rnn.LSTM'>


Train_loss (1.12)	 Train_acc (30.00)	 Val_acc (30.00): 100%|██████████| 2/2 [00:05<00:00,  2.88s/it] 
Train_loss (1.10)	 Train_acc (40.00)	 Val_acc (30.00): 100%|██████████| 2/2 [00:05<00:00,  2.93s/it] 
Train_loss (1.10)	 Train_acc (40.00)	 Val_acc (30.00): 100%|██████████| 2/2 [00:05<00:00,  2.93s/it] 
  0%|          | 0/2 [00:00<?, ?it/s]

<class 'torch.nn.modules.rnn.GRU'>


Train_loss (1.09)	 Train_acc (60.00)	 Val_acc (30.00): 100%|██████████| 2/2 [00:05<00:00,  2.77s/it] 
Train_loss (1.14)	 Train_acc (60.00)	 Val_acc (20.00): 100%|██████████| 2/2 [00:05<00:00,  2.72s/it] 
Train_loss (1.12)	 Train_acc (70.00)	 Val_acc (20.00): 100%|██████████| 2/2 [00:06<00:00,  3.02s/it] 
