In [1]:
import torch
import torch.nn as nn

In [2]:
from isaac.dataset import read_dataset, prepare_dataset
from isaac.constants import MASS_CLASS_COLS, FORCE_CLASS_COLS, BASIC_TRAINING_COLS, XY_RTHETA_COLS, RTHETA_COLS, XY_VXVY_RTHETA_COLS
from isaac.sanity import class_proportions
from isaac.models import initialise_model
from isaac.training import training_loop
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

In [3]:
from isaac.utils import get_cuda_device_if_available, create_directory
# Resolve the compute device once; it is passed to prepare_dataset and
# initialise_model in the training cell. The print records which device
# this run actually used.
device = get_cuda_device_if_available()
print(device)

cuda:0


In [4]:
# --- Data preparation: forwarded to prepare_dataset in the training cell ---
BATCH_SIZE = 128        # batch_size=
NORMALISE_DATA = True   # normalise_data=

# --- Training schedule: forwarded to training_loop in the training cell ---
EPOCHS = 60             # also the length of the per-run "Epoch" stats column
SEQ_END = 2700          # seq_end=  (presumably the last time step used per trial; confirm in isaac.training)
STEP_SIZE = 3           # step_size= (presumably a temporal subsampling stride; confirm in isaac.training)

In [5]:
# Network hyper-parameters. Together with the number of input features they are
# packed, in this order, into the `network_params` tuple handed to initialise_model.

# Hidden layer dimension.
HIDDEN_DIM = 25
# Number of hidden layers.
N_LAYERS = 4
# Output dimension.
OUTPUT_DIM = 3
# Dropout setting (presumably a drop probability; confirm in isaac.models).
DROPOUT = 0.5

# TRAINING

In [6]:
def _first_combined_solution(trials):
    """Return, for every trial, the first entry of its `combined_solution` column."""
    return [one_trial.combined_solution.iloc[0] for one_trial in trials]


# Read the full train / validation trial collections from disk.
train_trials = read_dataset("data/r_train_trials.h5")
val_trials = read_dataset("data/r_val_trials.h5")

# Per-trial labels used only to stratify the subsampling below.
train_classes = _first_combined_solution(train_trials)
val_classes = _first_combined_solution(val_trials)

# Both splits are called without `random_state`, so they draw from NumPy's global
# RNG seeded here. Keep the seed and the train-then-val call order unchanged to
# reproduce the same subsets.
np.random.seed(37)
train_trials, _ = train_test_split(
    train_trials,
    train_size=4000,
    stratify=train_classes,
)
val_trials, _ = train_test_split(
    val_trials,
    train_size=1000,
    stratify=val_classes,
)

100%|██████████| 100/100 [00:00<00:00, 122.79it/s]
100%|██████████| 100/100 [00:00<00:00, 120.93it/s]


In [7]:
# Number of independently seeded models trained for every (target, feature set) pair.
N_MODELS = 25

# Root folders: one for model checkpoints, one for per-epoch training statistics.
model_base_directory = "models/GRU_singlebranch/"
data_base_directory = "GRU_singlebranch/"

# Parallel lists: feature_sets[i] is labelled feature_sets_names[i] and
# class_columns[j] is labelled class_names[j]; the loops below pair them with zip.
feature_sets = [BASIC_TRAINING_COLS, XY_RTHETA_COLS, RTHETA_COLS, XY_VXVY_RTHETA_COLS]
feature_sets_names = ["xy_vxvy", "xy_rtheta", "rtheta", "xy_vxvy_rtheta"]
class_columns = [MASS_CLASS_COLS, FORCE_CLASS_COLS]
class_names = ["mass", "force"]

for cl_cols, cl_name in zip(class_columns, class_names):
    for features, name in zip(feature_sets, feature_sets_names):
        # Build the loaders for this target / feature-set combination.
        # NOTE(review): `scaler` is returned but never used or saved in this cell;
        # confirm that later evaluation does not need the fitted scaler.
        loaders, scaler = prepare_dataset(
            [train_trials, val_trials],
            class_columns=cl_cols,
            training_columns=features,
            batch_size=BATCH_SIZE,
            normalise_data=NORMALISE_DATA,
            device=device,
        )
        # Same loaders for every seed: index 0 is passed to training_loop first,
        # index 1 second.
        train_loader = loaders[0]
        val_loader = loaders[1]

        # Input size follows the number of feature columns; the rest is fixed config.
        network_params = (len(features), HIDDEN_DIM, N_LAYERS, OUTPUT_DIM, DROPOUT)

        # Folders are keyed by feature set only; the target name goes into the file names.
        full_model_directory = model_base_directory + name + "/"
        full_data_directory = data_base_directory + name + "/"
        create_directory(full_model_directory)
        create_directory(full_data_directory)

        stats_dfs = []
        for seed in range(N_MODELS):
            model, error, optimizer = initialise_model(
                network_params, lr=0.01, seed=seed, device=device
            )
            epoch_losses, epoch_accuracies, best_model = training_loop(
                model, optimizer, error, train_loader, val_loader, EPOCHS,
                seq_end=SEQ_END, step_size=STEP_SIZE,
            )

            # Save the weights of the model returned as `best_model` for this seed.
            checkpoint_name = "best_" + cl_name + "_model_seed_%d.pt" % seed
            torch.save(best_model.state_dict(), full_model_directory + checkpoint_name)

            # One row per epoch. "Epoch" must be assigned first because it gives the
            # empty frame its length; the scalar seed is then broadcast to every row.
            df = pd.DataFrame(columns=["seed", "Epoch", "Loss"])
            for column, values in (
                ("Epoch", np.arange(EPOCHS)),
                ("Loss", epoch_losses),
                ("Train Accuracy", epoch_accuracies[0]),
                ("Val Accuracy", epoch_accuracies[1]),
                ("seed", seed),
            ):
                df[column] = values
            stats_dfs.append(df)

        # Stack the per-seed histories and store them next to this feature set's data.
        stats = pd.concat(stats_dfs)
        stats.to_hdf(full_data_directory + cl_name + "_stats.h5", key="stats")

100%|██████████| 80/80 [00:00<00:00, 1036.17it/s]
100%|██████████| 80/80 [00:00<00:00, 981.20it/s]
Train_loss: ([0.99237096]) Train_acc: (61.25) Val_acc: ([33.75]): 100%|██████████| 3/3 [00:01<00:00,  2.69it/s]
Train_loss: ([1.00497615]) Train_acc: (55.0) Val_acc: ([42.5]): 100%|██████████| 3/3 [00:01<00:00,  2.99it/s]  
Train_loss: ([0.99458659]) Train_acc: (50.0) Val_acc: ([45.0]): 100%|██████████| 3/3 [00:01<00:00,  2.95it/s] 
Train_loss: ([0.97988415]) Train_acc: (51.25) Val_acc: ([45.0]): 100%|██████████| 3/3 [00:01<00:00,  2.98it/s]
Train_loss: ([0.98236549]) Train_acc: (50.0) Val_acc: ([45.0]): 100%|██████████| 3/3 [00:01<00:00,  2.95it/s]
Train_loss: ([0.99404019]) Train_acc: (48.75) Val_acc: ([45.0]): 100%|██████████| 3/3 [00:00<00:00,  3.04it/s]
Train_loss: ([1.00499225]) Train_acc: (50.0) Val_acc: ([45.0]): 100%|██████████| 3/3 [00:00<00:00,  3.02it/s]
Train_loss: ([0.96785414]) Train_acc: (51.25) Val_acc: ([45.0]): 100%|██████████| 3/3 [00:01<00:00,  2.86it/s]
Train_loss: (

Train_loss: ([1.00330997]) Train_acc: (50.0) Val_acc: ([45.0]): 100%|██████████| 3/3 [00:00<00:00,  3.04it/s]
Train_loss: ([1.01604199]) Train_acc: (50.0) Val_acc: ([45.0]): 100%|██████████| 3/3 [00:00<00:00,  3.11it/s]
Train_loss: ([0.99915016]) Train_acc: (53.75) Val_acc: ([46.25]): 100%|██████████| 3/3 [00:01<00:00,  2.99it/s]
Train_loss: ([0.98182374]) Train_acc: (50.0) Val_acc: ([45.0]): 100%|██████████| 3/3 [00:01<00:00,  2.90it/s] 
100%|██████████| 80/80 [00:00<00:00, 757.48it/s]
100%|██████████| 80/80 [00:00<00:00, 727.99it/s]
Train_loss: ([1.0109539]) Train_acc: (53.75) Val_acc: ([45.0]): 100%|██████████| 3/3 [00:00<00:00,  3.02it/s] 
Train_loss: ([0.99969578]) Train_acc: (55.0) Val_acc: ([41.25]): 100%|██████████| 3/3 [00:01<00:00,  2.97it/s]
Train_loss: ([0.99079692]) Train_acc: (61.25) Val_acc: ([41.25]): 100%|██████████| 3/3 [00:01<00:00,  2.95it/s]
Train_loss: ([1.03209734]) Train_acc: (55.0) Val_acc: ([42.5]): 100%|██████████| 3/3 [00:01<00:00,  2.61it/s] 
Train_loss: ([

Train_loss: ([1.08227277]) Train_acc: (46.25) Val_acc: ([35.0]): 100%|██████████| 3/3 [00:01<00:00,  2.78it/s]
Train_loss: ([1.06024408]) Train_acc: (48.75) Val_acc: ([23.75]): 100%|██████████| 3/3 [00:01<00:00,  2.74it/s]
Train_loss: ([1.05133891]) Train_acc: (50.0) Val_acc: ([28.75]): 100%|██████████| 3/3 [00:01<00:00,  2.74it/s] 
Train_loss: ([1.0662744]) Train_acc: (46.25) Val_acc: ([28.75]): 100%|██████████| 3/3 [00:01<00:00,  2.93it/s]
Train_loss: ([1.08423996]) Train_acc: (50.0) Val_acc: ([30.0]): 100%|██████████| 3/3 [00:00<00:00,  3.03it/s] 
Train_loss: ([1.07708132]) Train_acc: (48.75) Val_acc: ([27.5]): 100%|██████████| 3/3 [00:01<00:00,  2.96it/s]
Train_loss: ([1.05660892]) Train_acc: (52.5) Val_acc: ([30.0]): 100%|██████████| 3/3 [00:01<00:00,  2.99it/s] 
Train_loss: ([1.06978226]) Train_acc: (43.75) Val_acc: ([28.75]): 100%|██████████| 3/3 [00:01<00:00,  2.97it/s]
Train_loss: ([1.05338597]) Train_acc: (46.25) Val_acc: ([27.5]): 100%|██████████| 3/3 [00:01<00:00,  2.97it/s