In [1]:
import torch
import torch.nn as nn

In [2]:
from isaac.dataset import read_dataset, prepare_dataset
from isaac.constants import POSITION_COLS, MASS_CLASS_COLS, BASIC_TRAINING_COLS, FORCE_CLASS_COLS
from isaac.sanity import class_proportions
from isaac.models import initialise_model
from isaac.training import training_loop
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [3]:
from isaac.utils import get_cuda_device_if_available, create_directory
# Device handed to prepare_dataset and initialise_model further down
# (the recorded run printed cuda:0).
device = get_cuda_device_if_available()
print(device)

# Output folder for this notebook; the *_stats.h5 files are written into it.
directory = "cell_type_choice_plots/"
create_directory(directory)

cuda:0


In [4]:
# Settings shared by the mass and the force experiments.

# Columns requested from the trial files and used as training columns.
TR_COLS = BASIC_TRAINING_COLS
# Forwarded to prepare_dataset(normalise_data=...).
NORMALISE_DATA = True
# Forwarded to prepare_dataset(batch_size=...).
BATCH_SIZE = 128
# Number of epochs given to training_loop; also the length of the "Epoch" column in the saved stats.
EPOCHS = 25
# Forwarded to training_loop(step_size=...).
# NOTE(review): presumably the stride used when subsampling each sequence — confirm in isaac.training.
STEP_SIZE = 1
# Forwarded to training_loop(seq_end=...).
# NOTE(review): presumably the last timestep of each trial that is fed to the model — confirm.
SEQ_END = 2700

In [5]:
# Load the passive trials for both splits, asking read_dataset for the
# TR_COLS columns only: 3500 trials for training and 1000 for validation.
train_trials = read_dataset(
    "data/train_passive_trials.h5",
    n_trials=3500,
    cols=TR_COLS,
)
val_trials = read_dataset(
    "data/val_passive_trials.h5",
    n_trials=1000,
    cols=TR_COLS,
)

100%|██████████| 3500/3500 [00:33<00:00, 104.75it/s]
100%|██████████| 1000/1000 [00:09<00:00, 105.01it/s]


In [6]:
# Architecture shared by every cell type compared in this notebook.
INPUT_DIM = len(TR_COLS)    # input dimension: one input feature per training column
HIDDEN_DIM = 25  # hidden layer dimension
N_LAYERS = 4     # number of hidden layers
OUTPUT_DIM = 3   # output dimension: three classes, matching the three counts printed by class_proportions
# NOTE(review): presumably the dropout probability applied inside the model — confirm in isaac.models.
DROPOUT = 0.5

# Positional bundle passed as the first argument of initialise_model.
# NOTE(review): the element order is assumed to match what initialise_model unpacks — confirm.
network_params = (INPUT_DIM, HIDDEN_DIM, N_LAYERS, OUTPUT_DIM, DROPOUT)

# MASS TRAINING

In [7]:
# Build the data loaders for the mass task: same trials and training columns
# as everywhere else, labelled through MASS_CLASS_COLS. The returned scaler is
# kept in a variable but not used by the cells visible in this section.
loaders, scaler = prepare_dataset(
    [train_trials, val_trials],
    class_columns=MASS_CLASS_COLS,
    training_columns=TR_COLS,
    batch_size=BATCH_SIZE,
    normalise_data=NORMALISE_DATA,
    device=device,
)

100%|██████████| 3500/3500 [00:03<00:00, 996.91it/s] 
100%|██████████| 1000/1000 [00:00<00:00, 1092.02it/s]


In [8]:
# Loaders come back in the order the trial sets were passed to prepare_dataset
# (train first, then validation); the class counts printed by the next cell
# sum to 3500 and 1000 respectively, which matches that order.
train_loader, val_loader = loaders

In [9]:
# Sanity check of the mass labels: each call prints the per-class counts and the
# majority-class share, which is the accuracy a constant predictor would reach.
# The return value of the last call is what the cell displays as its output.
class_proportions(train_loader)
class_proportions(val_loader)

[1188 1174 1138]
Majority class:  0.3394285714285714
[340 349 311]
Majority class:  0.349


(array([340, 349, 311]), 0.349)

In [10]:
# Mass experiment: train every recurrent cell type with three seeds on the mass
# loaders, save the model returned by training_loop for each run and collect the
# per-epoch metrics of all runs into a single table.

# Display names of the compared cell types, in the same order as the nn classes below.
labels = ["GRU", "RNN", "LSTM"]

# One DataFrame of per-epoch metrics per (cell type, seed) run.
stats_dfs = []

for cell_type, cell_label in zip([nn.GRU, nn.RNN, nn.LSTM], labels):

    # Every cell type gets its own folder for the saved state dicts.
    cell_type_directory = "models/cell_type_%s/" % cell_label
    create_directory(cell_type_directory)

    for seed in [0, 42, 72]:
        df = pd.DataFrame(columns=["cell_type", "Epoch", "Loss"])

        model, error, optimizer = initialise_model(network_params, lr=0.01, cell_type=cell_type, seed=seed, device=device)
        epoch_losses, epoch_accuracies, best_model = training_loop(model, optimizer, error, train_loader,
                                                                   val_loader, EPOCHS, seq_end=SEQ_END,
                                                                   step_size=STEP_SIZE)

        # Only the weights of the model returned as `best_model` are stored, one file per seed.
        torch.save(best_model.state_dict(), cell_type_directory + "best_mass_model_seed_%d.pt" % seed)

        # Assumes training_loop returns one entry per epoch for the losses and for
        # both accuracy lists (index 0 is stored as train, index 1 as validation).
        df["Epoch"] = np.arange(EPOCHS)
        df["Loss"] = epoch_losses
        df["Train Accuracy"] = epoch_accuracies[0]
        df["Val Accuracy"] = epoch_accuracies[1]
        df["cell_type"] = cell_label
        # Record which run each row belongs to: without it the three seeds of a
        # cell type can only be told apart by row position after the concat below.
        # Appended last so the position of the existing columns does not change.
        df["seed"] = seed
        stats_dfs.append(df)

# Stack all runs (each keeps its own 0..EPOCHS-1 index) and store them next to the plots.
stats = pd.concat(stats_dfs)
stats.to_hdf(directory+"mass_stats.h5", key="stats")

Train_loss: ([0.65387614]) Train_acc: (72.77142857142857) Val_acc: ([55.9]): 100%|██████████| 25/25 [13:32<00:00, 32.76s/it] 
Train_loss: ([0.54862248]) Train_acc: (77.42857142857143) Val_acc: ([58.6]): 100%|██████████| 25/25 [13:32<00:00, 32.57s/it]
Train_loss: ([0.57437534]) Train_acc: (74.02857142857142) Val_acc: ([59.9]): 100%|██████████| 25/25 [13:13<00:00, 32.00s/it] 
Train_loss: ([1.10445257]) Train_acc: (32.51428571428571) Val_acc: ([31.1]): 100%|██████████| 25/25 [11:31<00:00, 27.39s/it] 
Train_loss: ([1.1071615]) Train_acc: (32.68571428571428) Val_acc: ([36.0]): 100%|██████████| 25/25 [11:13<00:00, 26.83s/it]  
Train_loss: ([1.10429655]) Train_acc: (34.371428571428574) Val_acc: ([35.4]): 100%|██████████| 25/25 [11:21<00:00, 27.25s/it]
Train_loss: ([1.07911892]) Train_acc: (40.285714285714285) Val_acc: ([36.6]): 100%|██████████| 25/25 [13:13<00:00, 31.82s/it]
Train_loss: ([1.09853391]) Train_acc: (33.94285714285714) Val_acc: ([34.0]): 100%|██████████| 25/25 [13:13<00:00, 31.78

# FORCE TRAINING

In [11]:
# Rebuild the data loaders for the force task: same trials and training columns,
# now labelled through FORCE_CLASS_COLS. This rebinds `loaders` and `scaler`,
# replacing the objects created for the mass task.
loaders, scaler = prepare_dataset(
    [train_trials, val_trials],
    class_columns=FORCE_CLASS_COLS,
    training_columns=TR_COLS,
    batch_size=BATCH_SIZE,
    normalise_data=NORMALISE_DATA,
    device=device,
)

100%|██████████| 3500/3500 [00:03<00:00, 946.91it/s]
100%|██████████| 1000/1000 [00:01<00:00, 995.72it/s]


In [12]:
# Rebinds train_loader / val_loader to the force-labelled loaders; from here on
# the mass loaders are no longer reachable through these names.
train_loader, val_loader = loaders

In [13]:
# Sanity check of the force labels: each call prints the per-class counts and the
# majority-class share, which is the accuracy a constant predictor would reach.
# The return value of the last call is what the cell displays as its output.
class_proportions(train_loader)
class_proportions(val_loader)

[1183 1169 1148]
Majority class:  0.338
[337 350 313]
Majority class:  0.35


(array([337, 350, 313]), 0.35)

In [14]:
# Force experiment: train every recurrent cell type with three seeds on the force
# loaders, save the model returned by training_loop for each run and collect the
# per-epoch metrics of all runs into a single table.

# Defined in this cell as well so that the force experiment can be run on a
# fresh kernel without first executing the (long) mass-training cell.
labels = ["GRU", "RNN", "LSTM"]

# One DataFrame of per-epoch metrics per (cell type, seed) run.
stats_dfs = []

for cell_type, cell_label in zip([nn.GRU, nn.RNN, nn.LSTM], labels):

    # Every cell type gets its own folder for the saved state dicts.
    cell_type_directory = "models/cell_type_%s/" % cell_label
    create_directory(cell_type_directory)

    for seed in [0, 42, 72]:
        df = pd.DataFrame(columns=["cell_type", "Epoch", "Loss"])
        model, error, optimizer = initialise_model(network_params, lr=0.01, cell_type=cell_type, seed=seed, device=device)
        epoch_losses, epoch_accuracies, best_model = training_loop(model, optimizer, error, train_loader,
                                                                   val_loader, EPOCHS, seq_end=SEQ_END,
                                                                   step_size=STEP_SIZE)

        # Only the weights of the model returned as `best_model` are stored, one file per seed.
        torch.save(best_model.state_dict(), cell_type_directory + "best_force_model_seed_%d.pt" % seed)

        # Assumes training_loop returns one entry per epoch for the losses and for
        # both accuracy lists (index 0 is stored as train, index 1 as validation).
        df["Epoch"] = np.arange(EPOCHS)
        df["Loss"] = epoch_losses
        df["Train Accuracy"] = epoch_accuracies[0]
        df["Val Accuracy"] = epoch_accuracies[1]
        df["cell_type"] = cell_label
        # Record which run each row belongs to: without it the three seeds of a
        # cell type can only be told apart by row position after the concat below.
        # Appended last so the position of the existing columns does not change.
        df["seed"] = seed
        stats_dfs.append(df)

# Stack all runs (each keeps its own 0..EPOCHS-1 index) and store them next to the plots.
stats = pd.concat(stats_dfs)
stats.to_hdf(directory+"force_stats.h5", key="stats")

Train_loss: ([0.67409354]) Train_acc: (60.17142857142857) Val_acc: ([53.2]): 100%|██████████| 25/25 [13:25<00:00, 32.86s/it] 
Train_loss: ([0.76193574]) Train_acc: (49.371428571428574) Val_acc: ([50.1]): 100%|██████████| 25/25 [13:39<00:00, 32.37s/it]
Train_loss: ([0.98409577]) Train_acc: (53.97142857142857) Val_acc: ([44.4]): 100%|██████████| 25/25 [13:41<00:00, 32.98s/it]
Train_loss: ([1.103018]) Train_acc: (33.48571428571429) Val_acc: ([32.0]): 100%|██████████| 25/25 [11:41<00:00, 28.06s/it]   
Train_loss: ([1.10386604]) Train_acc: (33.22857142857143) Val_acc: ([33.3]): 100%|██████████| 25/25 [11:38<00:00, 27.90s/it] 
Train_loss: ([1.10598423]) Train_acc: (34.17142857142857) Val_acc: ([33.7]): 100%|██████████| 25/25 [11:29<00:00, 27.61s/it] 
Train_loss: ([0.96951999]) Train_acc: (53.68571428571428) Val_acc: ([42.7]): 100%|██████████| 25/25 [13:01<00:00, 31.39s/it] 
Train_loss: ([0.99109282]) Train_acc: (51.2) Val_acc: ([41.4]): 100%|██████████| 25/25 [12:59<00:00, 32.13s/it]        