# Neural Network Model

In [15]:
import numpy as np
import torch
import plotly.graph_objects as go

from models.reduced_rank_model import ReducedRankModel
from models.neural_network import NeuralNetworkModel
from models.zero_predictor import ZeroPredictor
from models.mean_predictor import MeanPredictor
from models.yizi_decoders import time_bin_wise_metrics, continuous_decoder, sliding_window_over_time, sliding_window_decoder

from utils.get_data import get_data, get_test_train_data, smooth_y
from utils.evaluation import evaluate_recording
from utils.training import train_one_epoch


DATA_PATH = 'raw_data/full_data/'

# RECORDING_IDS = [89, 90]
RECORDING_IDS = np.arange(10)
TOTAL_RECORDINGS = len(RECORDING_IDS) # In total there are 112 recordings
PLOT_RECORDINGS = 3
TIME_BINS = 40
TRAIN_SPLIT = 0.8
TRAINING_ITERATIONS = 20
RANK = 10
MEAN_SUBTRACTED = False
SMOOTHED = True
MODEL = "neural_network"
RANDOM_SEED = 42

In [16]:
# This code loads yizi's data
# neural_data = np.load(f"raw_data/yizi_data/neural_data.npy", allow_pickle=True).item()
# behavior_data = np.load(f"raw_data/yizi_data/behavior_data.npy", allow_pickle=True).item()

# # Probably recording: dab512bd-a02d-4c1f-8dbc-9155a163efc0
# X_list = [neural_data["po"]]
# Y_list = [behavior_data["wheel_speed"]]

# print(X_list[0].shape)

In [17]:
X_list, Y_list = get_data(DATA_PATH, RECORDING_IDS)

X_train, X_test, Y_train_regular, Y_test_regular, Y_train_mean_subtracted, Y_test_mean_subtracted = get_test_train_data(X_list, Y_list, TRAIN_SPLIT, random_seed=RANDOM_SEED)

Y_train_smoothed = smooth_y(Y_train_regular)
Y_test_smoothed = smooth_y(Y_test_regular)

Y_train_mean_subtracted_smoothed = smooth_y(Y_train_mean_subtracted)
Y_test_mean_subtracted_smoothed = smooth_y(Y_test_mean_subtracted)

if MEAN_SUBTRACTED:
    if SMOOTHED:
        Y_train = Y_train_mean_subtracted_smoothed
        Y_test = Y_test_mean_subtracted_smoothed
    else:
        Y_train = Y_train_mean_subtracted
        Y_test = Y_test_mean_subtracted
else:
    if SMOOTHED:
        Y_train = Y_train_smoothed
        Y_test = Y_test_smoothed
    else:
        Y_train = Y_train_regular
        Y_test = Y_test_regular

In [18]:
for t in range(3):
    fig = go.Figure()
    fig.add_trace(go.Scatter(y=Y_train_regular[0][t], name='Original'))
    fig.add_trace(go.Scatter(y=Y_train_smoothed[0][t], name='Smoothed'))
    fig.add_trace(go.Scatter(y=Y_train_mean_subtracted[0][t], name='Mean Subtracted'))
    fig.add_trace(go.Scatter(y=Y_train_mean_subtracted_smoothed[0][t], name='Mean Subtracted Smoothed'))
    fig.update_layout(title=f'Original, Smoothed, and Mean Subtracted Wheel Speeds, Recording 0, Trial {t}', xaxis_title='Time', yaxis_title='Value')
    fig.show()

# Training

In [19]:
loss_fn = torch.nn.MSELoss()

## Reduced Rank

In [20]:
reduced_rank_model = ReducedRankModel(
    n_recordings=TOTAL_RECORDINGS,
    n_neurons_per_recording=[X_train[i].shape[1] for i in range(TOTAL_RECORDINGS)], 
    n_time_bins=TIME_BINS, 
    rank=RANK
)

if RANDOM_SEED != -1:
    torch.manual_seed(RANDOM_SEED)

# optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01) 0.0455 best
optimizer = torch.optim.Adam(reduced_rank_model.parameters(), lr=0.001, weight_decay=0.01)
loss_fn = torch.nn.MSELoss()

train_loss = []
test_loss = []

for training_iteration in range(200):
    train_loss_iter, test_loss_iter = train_one_epoch(X_train, Y_train, X_test, Y_test, reduced_rank_model, optimizer, loss_fn)

    if training_iteration % 100 == 0:
        print(f'Iteration {training_iteration}, Train Loss {train_loss_iter}, Test Loss {test_loss_iter}')

    train_loss.append(train_loss_iter)
    test_loss.append(test_loss_iter)

Iteration 0, Train Loss [2.182859182357788, 2.3864684104919434, 0.7247431874275208, 29.242944717407227, 1.5197280645370483, 11.07209300994873, 14.782331466674805, 9.951565742492676, 3.563952684402466, 16.304889678955078], Test Loss [2.089543581008911, 2.461747169494629, 0.7366320490837097, 30.109355926513672, 1.4513477087020874, 10.822796821594238, 15.424304962158203, 9.456409454345703, 3.6558024883270264, 16.34398078918457]
Iteration 100, Train Loss [0.10954174399375916, 0.14291438460350037, 0.1513132005929947, 0.16874992847442627, 0.1665649116039276, 0.14084206521511078, 0.14235085248947144, 0.14369772374629974, 0.1636534035205841, 0.13291007280349731], Test Loss [0.1420314908027649, 0.15821321308612823, 0.14892609417438507, 0.16206631064414978, 0.1691424548625946, 0.141210675239563, 0.13259878754615784, 0.1725798100233078, 0.19688595831394196, 0.1379622220993042]


In [21]:
train_loss = np.array(train_loss)
test_loss = np.array(test_loss)

fig = go.Figure()
for recording in range(TOTAL_RECORDINGS):
    fig.add_trace(go.Scatter(y=test_loss[:, recording], name=f'{recording} - Test Recording'))
    fig.add_trace(go.Scatter(y=train_loss[:, recording], name=f'{recording} - Train Recording'))
fig.update_layout(title=f'Loss Curve Rank {RANK}', xaxis_title='Iteration', yaxis_title='Loss')
fig.show()

test_loss_reduced_rank = test_loss[-1]

## Neural Network

In [22]:
neural_network_model = NeuralNetworkModel(
    n_recordings=TOTAL_RECORDINGS,
    n_neurons_per_recording=[X_train[i].shape[1] for i in range(TOTAL_RECORDINGS)], 
    n_time_bins=TIME_BINS, 
    width=1000,
    hidden_layers=3,
    rank=RANK
)

if RANDOM_SEED != -1:
    torch.manual_seed(RANDOM_SEED)

# optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01) 0.0455 best
optimizer = torch.optim.Adam(neural_network_model.parameters(), lr=0.001, weight_decay=0.01)

train_loss = []
test_loss = []

for training_iteration in range(TRAINING_ITERATIONS):
    train_loss_iter, test_loss_iter = train_one_epoch(X_train, Y_train, X_test, Y_test, neural_network_model, optimizer, loss_fn)

    if training_iteration % 100 == 0:
        print(f'Iteration {training_iteration}, Train Loss {train_loss_iter}, Test Loss {test_loss_iter}')

    train_loss.append(train_loss_iter)
    test_loss.append(test_loss_iter)

Iteration 0, Train Loss [0.04368145763874054, 0.2794710397720337, 0.04422873258590698, 5.290378093719482, 0.0560423918068409, 1.6672255992889404, 0.9114725589752197, 0.95157790184021, 0.18406349420547485, 6.854626655578613], Test Loss [0.5450387597084045, 0.2752496302127838, 0.04928147792816162, 1.5378141403198242, 0.10560349375009537, 0.6041120290756226, 0.7634207606315613, 0.7432767748832703, 0.3207519054412842, 1.8562495708465576]


In [23]:
train_loss = np.array(train_loss)
test_loss = np.array(test_loss)

fig = go.Figure()
for recording in range(TOTAL_RECORDINGS):
    fig.add_trace(go.Scatter(y=test_loss[:, recording], name=f'{recording} - Test Recording'))
    fig.add_trace(go.Scatter(y=train_loss[:, recording], name=f'{recording} - Train Recording'))
fig.update_layout(title=f'Loss Curve Rank {RANK}', xaxis_title='Iteration', yaxis_title='Loss')
fig.show()

test_loss_neural_network = test_loss[-1]

## Continuous Decoder model

In [24]:
Y_pred_continuous_decoder = []
test_loss_continuous_decoder = []

for recording_session in range(TOTAL_RECORDINGS):
    y_pred, evaluation = continuous_decoder(X_train[recording_session], Y_train[recording_session], X_test[recording_session], Y_test[recording_session], time_independent=False)
    Y_pred_continuous_decoder.append(y_pred)
    test_loss_continuous_decoder.append(evaluation[2])

test_loss_continuous_decoder = np.array(test_loss_continuous_decoder)

time-dependent:
Chosen alpha: 1000
r2: 0.023703685907214056 corr: 0.602334713989529 mse: 0.02609984434030187
time-dependent:
Chosen alpha: 1000
r2: 0.5288564492666994 corr: 0.7459212248750491 mse: 0.010744662976213168
time-dependent:
Chosen alpha: 1000
r2: 0.40063648594147305 corr: 0.6336715965011659 mse: 0.009499207473756724
time-dependent:
Chosen alpha: 100
r2: 0.4416702067950975 corr: 0.7234400084282193 mse: 0.011385778440282907
time-dependent:
Chosen alpha: 1000
r2: 0.36401567702226034 corr: 0.605232782251363 mse: 0.013783027509796413
time-dependent:
Chosen alpha: 1000
r2: 0.40168433012938387 corr: 0.6432827369865283 mse: 0.008177470147140237
time-dependent:
Chosen alpha: 1000
r2: 0.11144423350470001 corr: 0.6011051910835543 mse: 0.008109921274342824
time-dependent:
Chosen alpha: 1000
r2: 0.39435689914926975 corr: 0.6750948398293547 mse: 0.02306805087832937
time-dependent:
Chosen alpha: 1000
r2: 0.37662918606336293 corr: 0.6877414311746699 mse: 0.017574195474948984
time-dependent:


## Zero Predictor

In [25]:
zero_predictor = ZeroPredictor()
test_loss_zero_predictor = []

for recording in range(TOTAL_RECORDINGS):
    loss_item = evaluate_recording(recording, X_train[recording], Y_train[recording], zero_predictor, loss_fn, plot_num=0)
    print(f'Performance of zero predictor on training data: {recording}, Loss {loss_item}')
    test_loss_zero_predictor.append(loss_item)

test_loss_zero_predictor = np.array(test_loss_zero_predictor)

Performance of zero predictor on training data: 0, Loss 0.013082008808851242
Performance of zero predictor on training data: 1, Loss 0.03412018343806267
Performance of zero predictor on training data: 2, Loss 0.04327685385942459
Performance of zero predictor on training data: 3, Loss 0.06633462011814117
Performance of zero predictor on training data: 4, Loss 0.053665388375520706
Performance of zero predictor on training data: 5, Loss 0.03388050198554993
Performance of zero predictor on training data: 6, Loss 0.03342223912477493
Performance of zero predictor on training data: 7, Loss 0.05492505058646202
Performance of zero predictor on training data: 8, Loss 0.048733361065387726
Performance of zero predictor on training data: 9, Loss 0.07775436341762543


## Mean Predictor

In [26]:
mean_predictor = MeanPredictor(Y_train)
test_loss_mean_predictor = []

for recording in range(TOTAL_RECORDINGS):
    loss_item = evaluate_recording(recording, X_train[recording], Y_train[recording], mean_predictor, loss_fn, plot_num=0)
    print(f'Performance of zero predictor on training data: {recording}, Loss {loss_item}')
    test_loss_mean_predictor.append(loss_item)

test_loss_mean_predictor = np.array(test_loss_mean_predictor)

Performance of zero predictor on training data: 0, Loss 0.004947812346117907
Performance of zero predictor on training data: 1, Loss 0.006720211784004756
Performance of zero predictor on training data: 2, Loss 0.008976126785998384
Performance of zero predictor on training data: 3, Loss 0.016920646140708224
Performance of zero predictor on training data: 4, Loss 0.014816321109973253
Performance of zero predictor on training data: 5, Loss 0.008239859915209913
Performance of zero predictor on training data: 6, Loss 0.008040015532512797
Performance of zero predictor on training data: 7, Loss 0.012619350564626373
Performance of zero predictor on training data: 8, Loss 0.009247272415757964
Performance of zero predictor on training data: 9, Loss 0.012588139701469358



Using a target size (torch.Size([620, 40])) that is different to the input size (torch.Size([40])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.


Using a target size (torch.Size([446, 40])) that is different to the input size (torch.Size([40])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.


Using a target size (torch.Size([372, 40])) that is different to the input size (torch.Size([40])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.


Using a target size (torch.Size([295, 40])) that is different to the input size (torch.Size([40])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.


Using a target size (torch.Size([363, 40])) that is different to the input size (torch.Size([40])). This will likely lead to incorrect results due to broadcasting. Please ensure t

In [27]:
fig = go.Figure()
fig.add_trace(go.Bar(x=np.arange(TOTAL_RECORDINGS), y=test_loss_zero_predictor, name='Zero Predictor'))
fig.add_trace(go.Bar(x=np.arange(TOTAL_RECORDINGS), y=test_loss_mean_predictor, name='Mean Predictor'))
fig.add_trace(go.Bar(x=np.arange(TOTAL_RECORDINGS), y=test_loss_reduced_rank, name='Reduced Rank'))
fig.add_trace(go.Bar(x=np.arange(TOTAL_RECORDINGS), y=test_loss_neural_network, name='Neural Network'))
fig.add_trace(go.Bar(x=np.arange(TOTAL_RECORDINGS), y=test_loss_continuous_decoder, name='Continuous Decoder'))
fig.update_layout(title=f'Loss Curve Rank {RANK}', xaxis_title='Recording Session', yaxis_title='Test MSE')
fig.show()

# Evaluation

In [28]:
# Print all of the model parameters
# for param in reduced_rank_model.parameters():
#     print(param.data.shape)
#     print(param.data)

# print(f"=== U matrices, length: {len(reduced_rank_model.Us)} ===")
# print(reduced_rank_model.Us)
# print(f"=== V matrices, shape: {reduced_rank_model.V.shape} ===")
# print(reduced_rank_model.V)
# print(f"=== bias, shape: {reduced_rank_model.bias.shape} ===")
# print(reduced_rank_model.bias)

# print(reduced_rank_model.parameters())

In [29]:

for recording_session in range(PLOT_RECORDINGS):
    for t in range(1):
        fig = go.Figure()
        fig.add_trace(go.Scatter(y=Y_test[recording_session][t], name='Original'))
        fig.add_trace(go.Scatter(y=Y_pred_continuous_decoder[recording_session][t], name='Prediction'))
        fig.update_layout(title=f'Original, Smoothed, and Mean Subtracted Wheel Speeds, Recording 0, Trial {t}', xaxis_title='Time', yaxis_title='Value')
        fig.show()


In [30]:
zero_predictor = ZeroPredictor()

for recording in range(PLOT_RECORDINGS):
    loss_item = evaluate_recording(recording, X_train[recording], Y_train[recording], zero_predictor, loss_fn, plot_num=0)
    print(f'Performance of zero predictor on training data: {recording}, Loss {loss_item}')

    loss_item = evaluate_recording(recording, X_test[recording], Y_test[recording], zero_predictor, loss_fn, plot_num=0)
    print(f'Performance of zero predictor on test data: {recording}, Loss {loss_item}')

Performance of zero predictor on training data: 0, Loss 0.013082008808851242
Performance of zero predictor on test data: 0, Loss 0.05173294246196747
Performance of zero predictor on training data: 1, Loss 0.03412018343806267
Performance of zero predictor on test data: 1, Loss 0.04684337601065636
Performance of zero predictor on training data: 2, Loss 0.04327685385942459
Performance of zero predictor on test data: 2, Loss 0.04240398108959198


In [31]:
for recording in range(PLOT_RECORDINGS):
    loss_item = evaluate_recording(recording, X_train[recording], Y_train[recording], mean_predictor, loss_fn, plot_num=0)
    print(f'Performance of mean predictor on training data: {recording}, Loss {loss_item}')

    loss_item = evaluate_recording(recording, X_test[recording], Y_test[recording], mean_predictor, loss_fn, plot_num=0)
    print(f'Performance of mean predictor on test data: {recording}, Loss {loss_item}')

Performance of mean predictor on training data: 0, Loss 0.004947812346117907
Performance of mean predictor on test data: 0, Loss 0.026562155120096288
Performance of mean predictor on training data: 1, Loss 0.006720211784004756
Performance of mean predictor on test data: 1, Loss 0.010936572292060834
Performance of mean predictor on training data: 2, Loss 0.008976126785998384
Performance of mean predictor on test data: 2, Loss 0.009500123974852177



Using a target size (torch.Size([155, 40])) that is different to the input size (torch.Size([40])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.


Using a target size (torch.Size([112, 40])) that is different to the input size (torch.Size([40])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.


Using a target size (torch.Size([93, 40])) that is different to the input size (torch.Size([40])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.



In [32]:
for recording in range(PLOT_RECORDINGS):
    loss_item = evaluate_recording(recording, X_train[recording], Y_train[recording], neural_network_model, loss_fn, plot_num=1)
    print(f'Training Evaluation, Recording {recording}, Loss {loss_item}')

Training Evaluation, Recording 0, Loss 0.005760076455771923


Training Evaluation, Recording 1, Loss 0.009752118028700352


Training Evaluation, Recording 2, Loss 0.026937242597341537


In [33]:
for recording in range(PLOT_RECORDINGS):
    loss_item = evaluate_recording(recording, X_test[recording], Y_test[recording], neural_network_model, loss_fn, plot_num=1)
    print(f'Testing Evaluation, Recording {recording}, Loss {loss_item}')

Testing Evaluation, Recording 0, Loss 0.02735517919063568


Testing Evaluation, Recording 1, Loss 0.015846578404307365


Testing Evaluation, Recording 2, Loss 0.026000484824180603
