# Neural Network Model

In [1]:
import numpy as np
import torch
import plotly.graph_objects as go

from sklearn.linear_model import ARDRegression
from sklearn.multioutput import RegressorChain
from sklearn.metrics import mean_squared_error

from models.reduced_rank_model import ReducedRankModel
from models.neural_network import NeuralNetworkModel
from models.zero_predictor import ZeroPredictor
from models.mean_predictor import MeanPredictor
from models.yizi_decoders import time_bin_wise_metrics, continuous_decoder, sliding_window_over_time, sliding_window_decoder

from utils.get_data import get_data, get_test_train_data, smooth_x, smooth_y
from utils.evaluation import evaluate_recording
from utils.training import train_one_epoch


DATA_PATH = 'raw_data/full_data/'

RECORDING_IDS = [36, 57, 100, 111] #+ [14, 17, 23, 53, 55, 58, 88, 89, 93, 106]
# RECORDING_IDS = np.arange(10)
TOTAL_RECORDINGS = len(RECORDING_IDS) # In total there are 112 recordings
PLOT_RECORDINGS = 3
TIME_BINS = 40
TRAIN_SPLIT = 0.8
TRAINING_ITERATIONS = 20
RANK = 10
MEAN_SUBTRACTED = False
SMOOTHED = True
SMOOTHED_X = True
MODEL = "neural_network"
RANDOM_SEED = 41

In [2]:
# This code loads yizi's data
# neural_data = np.load(f"raw_data/yizi_data/neural_data.npy", allow_pickle=True).item()
# behavior_data = np.load(f"raw_data/yizi_data/behavior_data.npy", allow_pickle=True).item()

# # Probably recording: dab512bd-a02d-4c1f-8dbc-9155a163efc0
# X_list = [neural_data["po"]]
# Y_list = [behavior_data["wheel_speed"]]

# print(X_list[0].shape)

In [3]:
# X_list, Y_list = get_data(DATA_PATH, RECORDING_IDS)



# X_train, X_test, Y_train_regular, Y_test_regular, Y_train_mean_subtracted, Y_test_mean_subtracted = get_test_train_data(X_list, Y_list, TRAIN_SPLIT, random_seed=RANDOM_SEED)

# if SMOOTHED_X:
#     X_train = smooth_x(X_train)
#     X_test = smooth_x(X_test)

# Y_train_smoothed = smooth_y(Y_train_regular)
# Y_test_smoothed = smooth_y(Y_test_regular)

# Y_train_mean_subtracted_smoothed = smooth_y(Y_train_mean_subtracted)
# Y_test_mean_subtracted_smoothed = smooth_y(Y_test_mean_subtracted)

# if MEAN_SUBTRACTED:
#     if SMOOTHED:
#         Y_train = Y_train_mean_subtracted_smoothed
#         Y_test = Y_test_mean_subtracted_smoothed
#     else:
#         Y_train = Y_train_mean_subtracted
#         Y_test = Y_test_mean_subtracted
# else:
#     if SMOOTHED:
#         Y_train = Y_train_smoothed
#         Y_test = Y_test_smoothed
#     else:
#         Y_train = Y_train_regular
#         Y_test = Y_test_regular

In [4]:
X_list, Y_list = get_data(DATA_PATH, RECORDING_IDS)

if SMOOTHED_X:
    X_list = smooth_x(X_list)

if SMOOTHED:
    Y_list = smooth_y(Y_list)

X_train, X_test, Y_train_regular, Y_test_regular, Y_train_mean_subtracted, Y_test_mean_subtracted = get_test_train_data(X_list, Y_list, TRAIN_SPLIT, random_seed=RANDOM_SEED)

Y_train = Y_train_regular
Y_test = Y_test_regular
# if SMOOTHED_X:
#     X_train = smooth_x(X_train)
#     X_test = smooth_x(X_test)

# Y_train_smoothed = smooth_y(Y_train_regular)
# Y_test_smoothed = smooth_y(Y_test_regular)

# Y_train_mean_subtracted_smoothed = smooth_y(Y_train_mean_subtracted)
# Y_test_mean_subtracted_smoothed = smooth_y(Y_test_mean_subtracted)

# if MEAN_SUBTRACTED:
#     if SMOOTHED:
#         Y_train = Y_train_mean_subtracted_smoothed
#         Y_test = Y_test_mean_subtracted_smoothed
#     else:
#         Y_train = Y_train_mean_subtracted
#         Y_test = Y_test_mean_subtracted
# else:
#     if SMOOTHED:
#         Y_train = Y_train_smoothed
#         Y_test = Y_test_smoothed
#     else:
#         Y_train = Y_train_regular
#         Y_test = Y_test_regular

In [5]:
for t in range(3):
    fig = go.Figure()
    fig.add_trace(go.Scatter(y=Y_train_regular[0][t], name='Original'))
    # fig.add_trace(go.Scatter(y=Y_train_smoothed[0][t], name='Smoothed'))
    fig.add_trace(go.Scatter(y=Y_train_mean_subtracted[0][t], name='Mean Subtracted'))
    # fig.add_trace(go.Scatter(y=Y_train_mean_subtracted_smoothed[0][t], name='Mean Subtracted Smoothed'))
    fig.update_layout(title=f'Original, Smoothed, and Mean Subtracted Wheel Speeds, Recording 0, Trial {t}', xaxis_title='Time', yaxis_title='Value')
    fig.show()

# Training

In [6]:
loss_fn = torch.nn.MSELoss()

## Reduced Rank

In [31]:
reduced_rank_model = ReducedRankModel(
    n_recordings=TOTAL_RECORDINGS,
    n_neurons_per_recording=[X_train[i].shape[1] for i in range(TOTAL_RECORDINGS)], 
    n_time_bins=TIME_BINS, 
    rank=3
)

if RANDOM_SEED != -1:
    torch.manual_seed(RANDOM_SEED)

# optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01) 0.0455 best
optimizer = torch.optim.Adam(reduced_rank_model.parameters(), lr=0.001, weight_decay=0.01)
loss_fn = torch.nn.MSELoss()

train_loss = []
test_loss = []

for training_iteration in range(1001):
    train_loss_iter, test_loss_iter = train_one_epoch(X_train, Y_train, X_test, Y_test, reduced_rank_model, optimizer, loss_fn)

    if training_iteration % 100 == 0:
        print(f'Iteration {training_iteration}, Train Loss {train_loss_iter}, Test Loss {test_loss_iter}')

    train_loss.append(train_loss_iter)
    test_loss.append(test_loss_iter)

Iteration 0, Train Loss [1.0105645656585693, 1.0185211896896362, 1.0103423595428467, 1.0331549644470215], Test Loss [1.0034688711166382, 1.0143344402313232, 1.011672019958496, 1.043847680091858]
Iteration 100, Train Loss [0.5057018399238586, 0.5139608383178711, 0.5050218105316162, 0.5299863815307617], Test Loss [0.5014842748641968, 0.5122342109680176, 0.5062246322631836, 0.5392560362815857]
Iteration 200, Train Loss [0.2570652961730957, 0.25589582324028015, 0.2544521391391754, 0.2687816321849823], Test Loss [0.25398266315460205, 0.2556680738925934, 0.2554248571395874, 0.275619775056839]
Iteration 300, Train Loss [0.1296914517879486, 0.12568168342113495, 0.12687796354293823, 0.13515403866767883], Test Loss [0.12734614312648773, 0.12635040283203125, 0.1275702565908432, 0.14018060266971588]
Iteration 400, Train Loss [0.06627660989761353, 0.06245025619864464, 0.06378761678934097, 0.06892738491296768], Test Loss [0.0643608421087265, 0.06365904211997986, 0.06419983506202698, 0.07276225835084

In [32]:
train_loss = np.array(train_loss)
test_loss = np.array(test_loss)

fig = go.Figure()
for recording in range(TOTAL_RECORDINGS):
    fig.add_trace(go.Scatter(y=test_loss[:, recording], name=f'{recording} - Test Recording'))
    fig.add_trace(go.Scatter(y=train_loss[:, recording], name=f'{recording} - Train Recording'))
fig.update_layout(title=f'Loss Curve Rank {RANK}', xaxis_title='Iteration', yaxis_title='Loss')
fig.show()

test_loss_reduced_rank = test_loss[-1]

## Neural Network

In [7]:
neural_network_model = NeuralNetworkModel(
    n_recordings=TOTAL_RECORDINGS,
    n_neurons_per_recording=[X_train[i].shape[1] for i in range(TOTAL_RECORDINGS)], 
    n_time_bins=TIME_BINS, 
    width=2000,
    hidden_layers=3,
    rank=RANK
)

if RANDOM_SEED != -1:
    torch.manual_seed(RANDOM_SEED)

# optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01) 0.0455 best
# optimizer = torch.optim.Adam(neural_network_model.parameters(), lr=0.001, weight_decay=0.01)
# optimizer = torch.optim.Adam(neural_network_model.parameters(), lr=0.00004, weight_decay=0.01) # Really good for rank 10, 4 data
optimizer = torch.optim.Adam(neural_network_model.parameters(), lr=0.0001, weight_decay=0.001)
# optimizer = torch.optim.Adam(neural_network_model.parameters(), lr=0.001)
# nn.BatchNorm1d(100)
train_loss = []
test_loss = []

for training_iteration in range(151):
    train_loss_iter, test_loss_iter = train_one_epoch(X_train, Y_train, X_test, Y_test, neural_network_model, optimizer, loss_fn)

    if training_iteration % 10 == 0:
        print(f'Iteration {training_iteration}, Train Loss {train_loss_iter}, Test Loss {test_loss_iter}')

    train_loss.append(train_loss_iter)
    test_loss.append(test_loss_iter)

Iteration 0, Train Loss [0.061459124088287354, 0.04750856012105942, 0.03511972725391388, 0.028620097786188126], Test Loss [0.030916480347514153, 0.029504599049687386, 0.023675836622714996, 0.021991580724716187]
Iteration 10, Train Loss [0.014829352498054504, 0.013867842964828014, 0.014402356930077076, 0.013751332648098469], Test Loss [0.013329054228961468, 0.015683412551879883, 0.014175760559737682, 0.014935337007045746]
Iteration 20, Train Loss [0.0142504358664155, 0.013208480551838875, 0.01403845939785242, 0.013187987729907036], Test Loss [0.013041217811405659, 0.014827979728579521, 0.013916726224124432, 0.014403438195586205]
Iteration 30, Train Loss [0.013840504921972752, 0.012821287848055363, 0.013790631666779518, 0.012788757681846619], Test Loss [0.012933078221976757, 0.014350904151797295, 0.013785751536488533, 0.01396091841161251]
Iteration 40, Train Loss [0.013513792306184769, 0.012511787936091423, 0.013544772751629353, 0.01239020749926567], Test Loss [0.012905958108603954, 0.01

In [8]:
train_loss = np.array(train_loss)
test_loss = np.array(test_loss)

fig = go.Figure()
for recording in range(TOTAL_RECORDINGS):
    fig.add_trace(go.Scatter(y=test_loss[:, recording], name=f'{recording} - Test Recording'))
    fig.add_trace(go.Scatter(y=train_loss[:, recording], name=f'{recording} - Train Recording'))
fig.update_layout(title=f'Loss Curve Rank {RANK}', xaxis_title='Iteration', yaxis_title='Loss')
# Show the y axis from 0 to 0.1
fig.update_yaxes(range=[0, 0.04])
fig.show()

test_loss_neural_network = test_loss[-1]

## Continuous Decoder model

In [9]:
Y_pred_continuous_decoder = []
test_loss_continuous_decoder = []

for recording_session in range(TOTAL_RECORDINGS):
    y_pred, evaluation = continuous_decoder(X_train[recording_session], Y_train[recording_session], X_test[recording_session], Y_test[recording_session], time_independent=False)
    Y_pred_continuous_decoder.append(y_pred)
    test_loss_continuous_decoder.append(evaluation[2])

test_loss_continuous_decoder = np.array(test_loss_continuous_decoder)

time-dependent:
Chosen alpha: 10
r2: 0.5514621486197606 corr: 0.7460602962810368 mse: 0.011435355949357679
time-dependent:
Chosen alpha: 10
r2: 0.5347374102375595 corr: 0.7316881544919058 mse: 0.011373395842066729
time-dependent:
Chosen alpha: 10
r2: 0.5719607446430899 corr: 0.7579690478073614 mse: 0.01052688932962336
time-dependent:
Chosen alpha: 10
r2: 0.6376397243344992 corr: 0.8051757804612079 mse: 0.011428955669119755


## ARD Regression Model

In [15]:
Y_pred_ARD = []
test_loss_ARD = []
ARD_model = RegressorChain(ARDRegression(alpha_1=1e-4, alpha_2=1e-4, lambda_1=1, lambda_2=0.5))

for recording in range(TOTAL_RECORDINGS):
    X_train_new = X_train[recording].reshape(X_train[recording].shape[0], -1)
    X_test_new = X_test[recording].reshape(X_test[recording].shape[0], -1)

    ARD_model.fit(X_train_new, Y_train[recording])
    Y_pred = ARD_model.predict(X_test_new)
    Y_pred_ARD.append(Y_pred)
    
    test_loss = mean_squared_error(Y_test[recording], Y_pred)
    print(f'Performance of ARD model on test data: {recording}, Loss {test_loss}')
    test_loss_ARD.append(test_loss)

test_loss_ARD = np.array(test_loss_ARD)

Performance of ARD model on test data: 0, Loss 0.030607482200984658
Performance of ARD model on test data: 1, Loss 0.03861754108817529


: 

In [118]:
# ARD Performance:
test_loss_ARD = np.array([
    0.047269990033470675, 
    0.018843322495421387, 
    0.022582635558847583,
    0.0468420420424728,
    0.057981851816811925,
    0.03037443153550907,
    0.023371255594955443,
    0.022045891482480172,
    0.05732039933265805,
    0.035073908627002826
])

## Zero Predictor

In [119]:
zero_predictor = ZeroPredictor()
test_loss_zero_predictor = []

for recording in range(TOTAL_RECORDINGS):
    loss_item = evaluate_recording(recording, X_train[recording], Y_train[recording], zero_predictor, loss_fn, plot_num=0)
    print(f'Performance of zero predictor on training data: {recording}, Loss {loss_item}')
    test_loss_zero_predictor.append(loss_item)

test_loss_zero_predictor = np.array(test_loss_zero_predictor)

Performance of zero predictor on training data: 0, Loss 0.06247894465923309
Performance of zero predictor on training data: 1, Loss 0.05807190015912056
Performance of zero predictor on training data: 2, Loss 0.054529912769794464
Performance of zero predictor on training data: 3, Loss 0.06047986075282097


## Mean Predictor

In [13]:
mean_predictor = MeanPredictor(Y_train)
test_loss_mean_predictor = []

for recording in range(TOTAL_RECORDINGS):
    loss_item = evaluate_recording(recording, X_test[recording], Y_test[recording], mean_predictor, loss_fn, plot_num=0)
    print(f'Performance of zero predictor on training data: {recording}, Loss {loss_item}')
    test_loss_mean_predictor.append(loss_item)

test_loss_mean_predictor = np.array(test_loss_mean_predictor)

Performance of zero predictor on training data: 0, Loss 0.014219517880003755
Performance of zero predictor on training data: 1, Loss 0.017322901560449005
Performance of zero predictor on training data: 2, Loss 0.014173219471063727
Performance of zero predictor on training data: 3, Loss 0.015911425953747955


In [14]:
fig = go.Figure()
# fig.add_trace(go.Bar(x=np.arange(TOTAL_RECORDINGS), y=test_loss_zero_predictor, name='Zero Predictor'))
fig.add_trace(go.Bar(x=np.arange(TOTAL_RECORDINGS), y=test_loss_mean_predictor, name='Mean Predictor'))
# fig.add_trace(go.Bar(x=np.arange(TOTAL_RECORDINGS), y=test_loss_reduced_rank, name='Reduced Rank'))
# fig.add_trace(go.Bar(x=np.arange(TOTAL_RECORDINGS), y=test_loss_ARD, name='ARD Model'))
fig.add_trace(go.Bar(x=np.arange(TOTAL_RECORDINGS), y=test_loss_neural_network, name='Neural Network'))
fig.add_trace(go.Bar(x=np.arange(TOTAL_RECORDINGS), y=test_loss_continuous_decoder, name='Continuous Decoder'))
fig.update_layout(title=f'Loss Curve Rank {RANK}', xaxis_title='Recording Session', yaxis_title='Test MSE')
fig.show()

# Evaluation

In [None]:
# Print all of the model parameters
# for param in reduced_rank_model.parameters():
#     print(param.data.shape)
#     print(param.data)

# print(f"=== U matrices, length: {len(reduced_rank_model.Us)} ===")
# print(reduced_rank_model.Us)
# print(f"=== V matrices, shape: {reduced_rank_model.V.shape} ===")
# print(reduced_rank_model.V)
# print(f"=== bias, shape: {reduced_rank_model.bias.shape} ===")
# print(reduced_rank_model.bias)

# print(reduced_rank_model.parameters())

In [40]:
for recording_session in range(PLOT_RECORDINGS):
    for t in range(2):
        fig = go.Figure()
        fig.add_trace(go.Scatter(y=Y_test[recording_session][t], name='Original'))
        fig.add_trace(go.Scatter(y=Y_pred_continuous_decoder[recording_session][t], name='Prediction'))
        fig.update_layout(title=f'Original, Smoothed, and Mean Subtracted Wheel Speeds, Recording 0, Trial {t}', xaxis_title='Time', yaxis_title='Value')
        fig.show()


In [None]:
zero_predictor = ZeroPredictor()

for recording in range(PLOT_RECORDINGS):
    loss_item = evaluate_recording(recording, X_train[recording], Y_train[recording], zero_predictor, loss_fn, plot_num=0)
    print(f'Performance of zero predictor on training data: {recording}, Loss {loss_item}')

    loss_item = evaluate_recording(recording, X_test[recording], Y_test[recording], zero_predictor, loss_fn, plot_num=0)
    print(f'Performance of zero predictor on test data: {recording}, Loss {loss_item}')

Performance of zero predictor on training data: 0, Loss 0.06247894465923309
Performance of zero predictor on test data: 0, Loss 0.07374325394630432
Performance of zero predictor on training data: 1, Loss 0.05807190015912056
Performance of zero predictor on test data: 1, Loss 0.0756937637925148
Performance of zero predictor on training data: 2, Loss 0.054529912769794464
Performance of zero predictor on test data: 2, Loss 0.06270706653594971


In [88]:
for recording in range(PLOT_RECORDINGS):
    loss_item = evaluate_recording(recording, X_train[recording], Y_train[recording], mean_predictor, loss_fn, plot_num=0)
    print(f'Performance of mean predictor on training data: {recording}, Loss {loss_item}')

    loss_item = evaluate_recording(recording, X_test[recording], Y_test[recording], mean_predictor, loss_fn, plot_num=2)
    print(f'Performance of mean predictor on test data: {recording}, Loss {loss_item}')

Performance of mean predictor on training data: 0, Loss 0.01631386089600843


Performance of mean predictor on test data: 0, Loss 0.015349475338455812
Performance of mean predictor on training data: 1, Loss 0.016039364283480385


Performance of mean predictor on test data: 1, Loss 0.020524314189927523
Performance of mean predictor on training data: 2, Loss 0.014791066972857309


Performance of mean predictor on test data: 2, Loss 0.016578440945140993


In [95]:
for recording in range(PLOT_RECORDINGS):
    loss_item = evaluate_recording(recording, X_train[recording], Y_train[recording], neural_network_model, loss_fn, plot_num=1)
    print(f'Training Evaluation, Recording {recording}, Loss {loss_item}')

Training Evaluation, Recording 0, Loss 0.000462967756902799


Training Evaluation, Recording 1, Loss 0.0006469715153798461


Training Evaluation, Recording 2, Loss 0.00063310656696558


In [96]:
for recording in range(PLOT_RECORDINGS):
    loss_item = evaluate_recording(recording, X_test[recording], Y_test[recording], neural_network_model, loss_fn, plot_num=2)
    print(f'Testing Evaluation, Recording {recording}, Loss {loss_item}')

Testing Evaluation, Recording 0, Loss 0.02064559981226921


Testing Evaluation, Recording 1, Loss 0.023233434185385704


Testing Evaluation, Recording 2, Loss 0.01827158033847809
