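# Neural Network Model

Trains a decoder from per-recording neural activity to wheel speed across several recordings (a `NeuralNetworkModel`, or alternatively a `ReducedRankModel`). It then sweeps the rank of the neural-network model and reports per-recording MSE on the train and test splits.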

In [40]:
import numpy as np
import torch
import plotly.graph_objects as go

from models.reduced_rank_model import ReducedRankModel
from models.neural_network import NeuralNetworkModel
from utils.get_data import get_data, get_test_train_data
from utils.evaluation import evaluate_recording
from utils.training import train_one_epoch

DATA_PATH = 'raw_data/full_data/'

# Recording (session) IDs to load. TOTAL_RECORDINGS is derived from this list
# so the two can never disagree.
RECORDING_IDS = [3, 101, 72, 110, 59]
TOTAL_RECORDINGS = len(RECORDING_IDS)

TRAIN_SPLIT = 0.8            # passed to get_test_train_data; presumably the fraction of trials used for training
TRAINING_ITERATIONS = 2001   # calls to train_one_epoch; 2001 so iteration 2000 is included in the every-100 log
RANK = 16                    # rank passed to the model constructors
MEAN_SUBTRACTED = True       # train/evaluate on the mean-subtracted targets instead of the raw ones
MODEL = "neural_network"     # one of "neural_network", "reduced_rank"

# Seed the RNGs so weight initialisation (and anything else random) is reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [41]:
# Load every recording, then split each into train/test. Both the raw and the
# mean-subtracted target variants are kept so they can be compared below.
X_list, Y_list = get_data(DATA_PATH, RECORDING_IDS)
(
    X_train,
    X_test,
    Y_train_regular,
    Y_test_regular,
    Y_train_mean_subtracted,
    Y_test_mean_subtracted,
) = get_test_train_data(X_list, Y_list, TRAIN_SPLIT)

# The targets actually used for training/evaluation are chosen by MEAN_SUBTRACTED.
Y_train, Y_test = (
    (Y_train_mean_subtracted, Y_test_mean_subtracted)
    if MEAN_SUBTRACTED
    else (Y_train_regular, Y_test_regular)
)

In [42]:
# Sanity-check the layout of the first recording's training inputs. The model
# cells read axis 1 as the per-recording neuron count.
first_recording_shape = X_train[0].shape
print(first_recording_shape)

(295, 211, 40)


In [43]:
# Overlay the raw and mean-subtracted wheel-speed targets for the first three
# training trials of recording 0.
target_variants = (
    (Y_train_regular, 'Original'),
    (Y_train_mean_subtracted, 'Mean Subtracted'),
)
for trial in range(3):
    fig = go.Figure()
    for targets, label in target_variants:
        fig.add_trace(go.Scatter(y=targets[0][trial], name=label))
    fig.update_layout(
        title=f'Original and Mean Subtracted Wheel Speeds, Recording 0, Trial {trial}',
        xaxis_title='Time',
        yaxis_title='Value',
    )
    fig.show()

## Training

In [45]:
# Build the model selected by MODEL and train it, recording per-recording
# train/test losses after every call to train_one_epoch.

# Derive sizes from the loaded data so the model always matches it.
n_neurons_per_recording = [x.shape[1] for x in X_train]
n_recordings = len(n_neurons_per_recording)

if MODEL == "reduced_rank":
    model = ReducedRankModel(
        n_recordings=n_recordings,
        n_neurons_per_recording=n_neurons_per_recording,
        n_time_bins=40,  # appears to equal X_train[i].shape[2] — TODO confirm and derive from data
        rank=RANK,
    )
elif MODEL == "neural_network":
    model = NeuralNetworkModel(
        n_recordings=n_recordings,
        n_neurons_per_recording=n_neurons_per_recording,
        n_time_bins=40,  # appears to equal X_train[i].shape[2] — TODO confirm and derive from data
        width=4000,
        hidden_layers=4,
        rank=RANK,
    )
else:
    # Without this, a typo in MODEL would silently reuse a stale `model`
    # left in the kernel by an earlier run.
    raise ValueError(f'Unknown MODEL {MODEL!r}; expected "reduced_rank" or "neural_network"')

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss()

# One entry per iteration; each entry is whatever train_one_epoch returns
# (the logged output shows one loss per recording).
train_loss = []
test_loss = []

for training_iteration in range(TRAINING_ITERATIONS):
    train_loss_iter, test_loss_iter = train_one_epoch(X_train, Y_train, X_test, Y_test, model, optimizer, loss_fn)

    if training_iteration % 100 == 0:
        print(f'Iteration {training_iteration}, Train Loss {train_loss_iter}, Test Loss {test_loss_iter}')

    train_loss.append(train_loss_iter)
    test_loss.append(test_loss_iter)

Iteration 0, Train Loss [1.4316602945327759, 2.4385828971862793, 318970.4375, 1269.7012939453125, 357.39697265625], Test Loss [226.06541442871094, 2957.45556640625, 3820.262451171875, 1063.4532470703125, 165.81626892089844]
Iteration 100, Train Loss [2.420283794403076, 3.7620651721954346, 2.711750030517578, 2.9282426834106445, 3.2694172859191895], Test Loss [12.663140296936035, 6.484308242797852, 8.33734130859375, 46.172607421875, 10.527585983276367]


KeyboardInterrupt: 

In [None]:
# Rank sweep for the neural-network model at a fixed width.
# Uses its own variable names (sweep_model, sweep_optimizer, sweep_loss_fn) so
# it does not overwrite `model`/`optimizer`/`loss_fn` from the main training
# cell, which the evaluation cells use.
SWEEP_RANKS = [2, 4, 8, 16]
SWEEP_WIDTH = 1000

sweep_n_neurons = [x.shape[1] for x in X_train]

train_losses = {}   # rank -> list of per-iteration train losses
test_losses = {}    # rank -> list of per-iteration test losses
sweep_models = {}   # rank -> trained model, kept for later inspection

for rank in SWEEP_RANKS:
    sweep_model = NeuralNetworkModel(
        n_recordings=len(sweep_n_neurons),
        n_neurons_per_recording=sweep_n_neurons,
        n_time_bins=40,  # appears to equal X_train[i].shape[2] — TODO confirm and derive from data
        width=SWEEP_WIDTH,
        hidden_layers=4,
        rank=rank,
    )

    sweep_optimizer = torch.optim.Adam(sweep_model.parameters(), lr=0.001)
    sweep_loss_fn = torch.nn.MSELoss()

    train_losses[rank] = []
    test_losses[rank] = []

    for training_iteration in range(TRAINING_ITERATIONS):
        train_loss_iter, test_loss_iter = train_one_epoch(
            X_train, Y_train, X_test, Y_test, sweep_model, sweep_optimizer, sweep_loss_fn
        )

        if training_iteration % 100 == 0:
            # Include the rank: otherwise the interleaved logs of different
            # ranks are indistinguishable.
            print(f'Rank {rank}, Iteration {training_iteration}, Train Loss {train_loss_iter}, Test Loss {test_loss_iter}')

        train_losses[rank].append(train_loss_iter)
        test_losses[rank].append(test_loss_iter)

    sweep_models[rank] = sweep_model

Iteration 0, Train Loss [0.5648131370544434, 0.6291618943214417], Test Loss [0.8256673812866211, 0.791092038154602]
Iteration 100, Train Loss [0.322302907705307, 0.2544831335544586], Test Loss [0.713965654373169, 0.7019364237785339]
Iteration 200, Train Loss [0.07598769664764404, 0.04456452652812004], Test Loss [0.9142297506332397, 0.9619506001472473]
Iteration 300, Train Loss [0.014486381784081459, 0.016259051859378815], Test Loss [0.9953017830848694, 1.0209685564041138]
Iteration 400, Train Loss [0.012728437781333923, 0.0118115758523345], Test Loss [1.021017074584961, 1.0464943647384644]
Iteration 500, Train Loss [0.008318386040627956, 0.009678296744823456], Test Loss [1.0226885080337524, 1.042962670326233]
Iteration 0, Train Loss [0.5687515735626221, 0.7071180939674377], Test Loss [0.977382242679596, 0.9625833630561829]
Iteration 100, Train Loss [0.04103058949112892, 0.03797721117734909], Test Loss [0.7219447493553162, 0.8358548283576965]
Iteration 200, Train Loss [0.007002207916229

In [None]:
# Per-recording loss curves from the main training run.
# Converted into new names so the `train_loss`/`test_loss` lists built by the
# training cell are left untouched (re-running this cell is then harmless).
# Assumes each iteration's loss is a per-recording sequence, giving arrays of
# shape (n_iterations, n_recordings) — matches the logged output.
train_loss_curve = np.asarray(train_loss)
test_loss_curve = np.asarray(test_loss)

if train_loss_curve.ndim != 2 or test_loss_curve.ndim != 2:
    raise ValueError('train_loss/test_loss are empty or not per-recording; run the training cell first')

fig = go.Figure()
for recording in range(train_loss_curve.shape[1]):
    fig.add_trace(go.Scatter(y=test_loss_curve[:, recording], name=f'Test Recording {recording}'))
    fig.add_trace(go.Scatter(y=train_loss_curve[:, recording], name=f'Train Recording {recording}'))
# Log y-axis: the first iterations' losses (up to ~1e5 in the logged run)
# would otherwise flatten the rest of the curve.
fig.update_layout(
    title=f'Loss Curve Rank {RANK}',
    xaxis_title='Iteration',
    yaxis_title='Loss',
    yaxis_type='log',
)
fig.show()

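## Evaluation

Per-recording loss of the trained model, first on the training split and then on the held-out test split.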

In [None]:
# Summarise the trained model's parameters. Only names and shapes are printed,
# since the full tensors of a wide network would flood the notebook.
# Assumes `model` is a torch.nn.Module (it exposes .parameters(), used for the optimizer).
for param_name, param in model.named_parameters():
    print(f'{param_name}: {tuple(param.shape)}')

print(f'Total parameters: {sum(p.numel() for p in model.parameters())}')

In [None]:
# Evaluate the main model on each recording's TRAINING split.
# plot_num is passed through to evaluate_recording (its meaning is defined there).
# Results are also collected so they can be compared with the test split.
train_eval_losses = {}
for recording, (X_rec, Y_rec) in enumerate(zip(X_train, Y_train)):
    loss_item = evaluate_recording(recording, X_rec, Y_rec, model, loss_fn, plot_num=2)
    train_eval_losses[recording] = loss_item
    # Label the split: otherwise this output is identical in form to the test-split cell's.
    print(f'Train evaluation, Recording {recording}, Loss {loss_item}')

Evaluation, Recording 0, Loss 0.04796317592263222


Evaluation, Recording 1, Loss 0.06785038858652115


Evaluation, Recording 2, Loss 0.017681101337075233


Evaluation, Recording 3, Loss 0.022932397201657295


Evaluation, Recording 4, Loss 0.04650332033634186


In [None]:
# Evaluate the main model on each recording's held-out TEST split.
# plot_num is passed through to evaluate_recording (its meaning is defined there).
# Results are also collected so they can be compared with the train split.
test_eval_losses = {}
for recording, (X_rec, Y_rec) in enumerate(zip(X_test, Y_test)):
    loss_item = evaluate_recording(recording, X_rec, Y_rec, model, loss_fn, plot_num=3)
    test_eval_losses[recording] = loss_item
    # Label the split: otherwise this output is identical in form to the train-split cell's.
    print(f'Test evaluation, Recording {recording}, Loss {loss_item}')

Evaluation, Recording 0, Loss 3.797199249267578


Evaluation, Recording 1, Loss 1.9361345767974854


Evaluation, Recording 2, Loss 2.9157681465148926


Evaluation, Recording 3, Loss 2.2847325801849365


Evaluation, Recording 4, Loss 2.0194709300994873
