# Neural Network Model

In [1]:
import numpy as np
import torch
import plotly.graph_objects as go

from models.reduced_rank_model import ReducedRankModel
from models.neural_network import NeuralNetworkModel, NeuralNetworkClassificationModel
from models.zero_predictor import ZeroPredictor
from utils.get_data import get_data, get_test_train_data, smooth_y, threshold_y
from utils.evaluation import evaluate_recording, evaluate_recording_classification
from utils.training import train_one_epoch

from models.yizi_decoders import time_bin_wise_metrics, continuous_decoder, sliding_window_over_time, sliding_window_decoder

DATA_PATH = 'raw_data/full_data/'

RECORDING_IDS = [89]
# RECORDING_IDS = np.arange(10)
TOTAL_RECORDINGS = len(RECORDING_IDS) # 112
# RECORDING_IDS = np.arange(TOTAL_RECORDINGS)
TIME_BINS = 40
TRAIN_SPLIT = 0.8
TRAINING_ITERATIONS = 35
RANK = 10
SMOOTHED = True
MODEL = "neural_network"
RANDOM_SEED = 42
THRESHOLD = 0.5

In [2]:
# neural_data = np.load(f"raw_data/yizi_data/neural_data.npy", allow_pickle=True).item()
# behavior_data = np.load(f"raw_data/yizi_data/behavior_data.npy", allow_pickle=True).item()

# # Probably recording: dab512bd-a02d-4c1f-8dbc-9155a163efc0
# X_list = [neural_data["po"]]
# Y_list = [behavior_data["wheel_speed"]]

# print(X_list[0].shape)

X_list, Y_list = get_data(DATA_PATH, RECORDING_IDS)

X_train, X_test, Y_train_regular, Y_test_regular, Y_train_mean_subtracted, Y_test_mean_subtracted = get_test_train_data(X_list, Y_list, TRAIN_SPLIT, random_seed=RANDOM_SEED)

Y_train_thresholded = threshold_y(Y_train_regular, THRESHOLD)
Y_test_thresholded = threshold_y(Y_test_regular, THRESHOLD)

Y_train_smoothed = smooth_y(Y_train_regular, normalize=False)
Y_test_smoothed = smooth_y(Y_test_regular, normalize=False)

Y_train_smoothed_thresholded = threshold_y(Y_train_smoothed, THRESHOLD)
Y_test_smoothed_thresholded = threshold_y(Y_test_smoothed, THRESHOLD)
print(Y_test_smoothed_thresholded[0])

if THRESHOLD is not None:
    if SMOOTHED:
        Y_train = Y_train_smoothed_thresholded
        Y_test = Y_test_smoothed_thresholded
    else:
        Y_train = Y_train_thresholded
        Y_test = Y_test_thresholded
else:
    if SMOOTHED:
        Y_train = Y_train_smoothed
        Y_test = Y_test_smoothed
    else:
        Y_train = Y_train_regular
        Y_test = Y_test_regular

[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 1 1 1]
 ...
 [0 0 0 ... 1 1 1]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]


In [3]:
# loop through Y_train_smoothed_thresholded and Y_train_thresholded and print the proportion of 1s
print("Smoothed")
for y in Y_train_smoothed_thresholded:
    print(np.sum(y) / (y.shape[0] * y.shape[1]))

print("Regular")
for y in Y_train_thresholded:
    print(np.sum(y) / (y.shape[0] * y.shape[1]))

Smoothed
0.6130546075085325
Regular
0.48916382252559726


In [4]:
for t in range(3):
    fig = go.Figure()
    fig.add_trace(go.Scatter(y=Y_train_regular[0][t], name='Original'))
    fig.add_trace(go.Scatter(y=Y_train_smoothed[0][t], name='Smoothed'))
    fig.add_trace(go.Scatter(y=Y_train_thresholded[0][t], name='Thresholded'))
    fig.add_trace(go.Scatter(y=Y_train_smoothed_thresholded[0][t], name='Smoothed and Thresholded'))
    fig.update_layout(title=f'Original, Smoothed, and Mean Subtracted Wheel Speeds, Recording 0, Trial {t}', xaxis_title='Time', yaxis_title='Value')
    fig.show()

# Training

## Reduced Rank

In [5]:
reduced_rank_model = ReducedRankModel(
    n_recordings=TOTAL_RECORDINGS,
    n_neurons_per_recording=[X_train[i].shape[1] for i in range(TOTAL_RECORDINGS)], 
    n_time_bins=TIME_BINS, 
    rank=RANK
)

if RANDOM_SEED != -1:
    torch.manual_seed(RANDOM_SEED)

# optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01) 0.0455 best
optimizer = torch.optim.Adam(reduced_rank_model.parameters(), lr=0.001, weight_decay=0.01)
loss_fn = torch.nn.MSELoss()

train_loss = []
test_loss = []

for training_iteration in range(200):
    train_loss_iter, test_loss_iter = train_one_epoch(X_train, Y_train, X_test, Y_test, reduced_rank_model, optimizer, loss_fn)

    if training_iteration % 100 == 0:
        print(f'Iteration {training_iteration}, Train Loss {train_loss_iter}, Test Loss {test_loss_iter}')

    train_loss.append(train_loss_iter)
    test_loss.append(test_loss_iter)

Iteration 0, Train Loss [38.77676773071289], Test Loss [34.53838348388672]
Iteration 100, Train Loss [10.08514404296875], Test Loss [9.265486717224121]


In [6]:
train_loss = np.array(train_loss)
test_loss = np.array(test_loss)

fig = go.Figure()
for recording in range(TOTAL_RECORDINGS):
    fig.add_trace(go.Scatter(y=test_loss[:, recording], name=f'{recording} - Test Recording'))
    fig.add_trace(go.Scatter(y=train_loss[:, recording], name=f'{recording} - Train Recording'))
fig.update_layout(title=f'Loss Curve Rank {RANK}', xaxis_title='Iteration', yaxis_title='Loss')
fig.show()

test_loss_reduced_rank = test_loss[-1]

## Neural Network

In [5]:
neural_network_model = NeuralNetworkClassificationModel(
    n_recordings=TOTAL_RECORDINGS,
    n_neurons_per_recording=[X_train[i].shape[1] for i in range(TOTAL_RECORDINGS)], 
    n_time_bins=TIME_BINS, 
    width=100,
    hidden_layers=3,
    rank=RANK
)

if RANDOM_SEED != -1:
    torch.manual_seed(RANDOM_SEED)

# optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01) 0.0455 best
optimizer = torch.optim.Adam(neural_network_model.parameters(), lr=0.001, weight_decay=0.01)
loss_fn = torch.nn.BCELoss()

train_loss = []
test_loss = []

for training_iteration in range(TRAINING_ITERATIONS):
    train_loss_iter, test_loss_iter = train_one_epoch(X_train, Y_train, X_test, Y_test, neural_network_model, optimizer, loss_fn)

    if training_iteration % 100 == 0:
        print(f'Iteration {training_iteration}, Train Loss {train_loss_iter}, Test Loss {test_loss_iter}')

    train_loss.append(train_loss_iter)
    test_loss.append(test_loss_iter)

Iteration 0, Train Loss [0.7864475250244141], Test Loss [0.657846212387085]


In [6]:
train_loss = np.array(train_loss)
test_loss = np.array(test_loss)

fig = go.Figure()
for recording in range(TOTAL_RECORDINGS):
    fig.add_trace(go.Scatter(y=test_loss[:, recording], name=f'{recording} - Test Recording'))
    fig.add_trace(go.Scatter(y=train_loss[:, recording], name=f'{recording} - Train Recording'))
fig.update_layout(title=f'Loss Curve Rank {RANK}', xaxis_title='Iteration', yaxis_title='Loss')
fig.show()

test_loss_neural_network = test_loss[-1]

## Continuous Decoder model

In [36]:
Y_pred_continuous_decoder = []
test_loss_continuous_decoder = []

for recording_session in range(TOTAL_RECORDINGS):
    y_pred, evaluation = continuous_decoder(X_train[recording_session], Y_train[recording_session], X_test[recording_session], Y_test[recording_session], time_independent=False)
    Y_pred_continuous_decoder.append(y_pred)
    test_loss_continuous_decoder.append(evaluation[2])

test_loss_continuous_decoder = np.array(test_loss_continuous_decoder)

time-dependent:
Chosen alpha: 1000
r2: 0.029060718347812764 corr: 0.18993592957096633 mse: 0.021751478977193657
time-dependent:
Chosen alpha: 1000
r2: 0.010576321595836657 corr: 0.10835228649920768 mse: 0.02162823151521842
time-dependent:
Chosen alpha: 1000
r2: -0.015670528993327437 corr: -0.07048391656033248 mse: 0.019447292867527603
time-dependent:
Chosen alpha: 100
r2: 0.2284144781621259 corr: 0.47959510023913576 mse: 0.012800555195702794
time-dependent:
Chosen alpha: 1000
r2: -0.006680139320656142 corr: 0.016882176397427485 mse: 0.01941848925887107
time-dependent:
Chosen alpha: 1000
r2: 0.022673661283033386 corr: 0.1717430588518267 mse: 0.01269161457334218
time-dependent:
Chosen alpha: 1000
r2: 0.02317415180428417 corr: 0.22492233853986565 mse: 0.007908085381200353
time-dependent:
Chosen alpha: 1000
r2: 0.03897716478051372 corr: 0.2030891043030955 mse: 0.031017187086813783
time-dependent:
Chosen alpha: 1000
r2: 0.039830397985079036 corr: 0.20376376095353654 mse: 0.02258653227185705

## Zero Predictor

In [7]:
zero_predictor = ZeroPredictor()
test_loss_zero_predictor = []

for recording in range(TOTAL_RECORDINGS):
    loss_item = evaluate_recording_classification(recording, X_train[recording], Y_train[recording], zero_predictor, loss_fn, plot_num=0)
    print(f'Performance of zero predictor on training data: {recording}, Loss {loss_item}')
    test_loss_zero_predictor.append(loss_item)

test_loss_zero_predictor = np.array(test_loss_zero_predictor)

Performance of zero predictor on training data: 0, Loss 61.30546188354492


In [8]:
fig = go.Figure()
fig.add_trace(go.Bar(x=np.arange(TOTAL_RECORDINGS), y=test_loss_zero_predictor, name='Zero Predictor'))
# fig.add_trace(go.Bar(x=np.arange(TOTAL_RECORDINGS), y=test_loss_reduced_rank, name='Reduced Rank'))
fig.add_trace(go.Bar(x=np.arange(TOTAL_RECORDINGS), y=test_loss_neural_network, name='Neural Network'))
# fig.add_trace(go.Bar(x=np.arange(TOTAL_RECORDINGS), y=test_loss_continuous_decoder, name='Continuous Decoder'))
fig.update_layout(title=f'Loss Curve Rank {RANK}', xaxis_title='Recording Session', yaxis_title='Test MSE')
fig.show()

# Evaluation

In [14]:
# Print all of the model parameters
# for param in reduced_rank_model.parameters():
#     print(param.data.shape)
#     print(param.data)

# print(f"=== U matrices, length: {len(reduced_rank_model.Us)} ===")
# print(reduced_rank_model.Us)
# print(f"=== V matrices, shape: {reduced_rank_model.V.shape} ===")
# print(reduced_rank_model.V)
# print(f"=== bias, shape: {reduced_rank_model.bias.shape} ===")
# print(reduced_rank_model.bias)

# print(reduced_rank_model.parameters())

In [40]:
for recording_session in range(TOTAL_RECORDINGS):
    for t in range(5):
        fig = go.Figure()
        fig.add_trace(go.Scatter(y=Y_test[recording_session][t], name='Original'))
        fig.add_trace(go.Scatter(y=Y_pred_continuous_decoder[recording_session][t], name='Prediction'))
        fig.update_layout(title=f'Original, Smoothed, and Mean Subtracted Wheel Speeds, Recording 0, Trial {t}', xaxis_title='Time', yaxis_title='Value')
        fig.show()


In [41]:
zero_predictor = ZeroPredictor()

for recording in range(TOTAL_RECORDINGS):
    loss_item = evaluate_recording_classification(recording, X_train[recording], Y_train[recording], zero_predictor, loss_fn, plot_num=1)
    print(f'Performance of zero predictor on training data: {recording}, Loss {loss_item}')

    loss_item = evaluate_recording_classification(recording, X_test[recording], Y_test[recording], zero_predictor, loss_fn, plot_num=1)
    print(f'Performance of zero predictor on test data: {recording}, Loss {loss_item}')

Performance of zero predictor on training data: 0, Loss 0.005544403567910194


Performance of zero predictor on test data: 0, Loss 0.02240804396569729


Performance of zero predictor on training data: 1, Loss 0.014304372482001781


Performance of zero predictor on test data: 1, Loss 0.021873731166124344


Performance of zero predictor on training data: 2, Loss 0.015348194167017937


Performance of zero predictor on test data: 2, Loss 0.01931684836745262


Performance of zero predictor on training data: 3, Loss 0.024416137486696243


Performance of zero predictor on test data: 3, Loss 0.016675101593136787


Performance of zero predictor on training data: 4, Loss 0.022469857707619667


Performance of zero predictor on test data: 4, Loss 0.019291868433356285


Performance of zero predictor on training data: 5, Loss 0.009893023408949375


Performance of zero predictor on test data: 5, Loss 0.012986313551664352


Performance of zero predictor on training data: 6, Loss 0.011806861497461796


Performance of zero predictor on test data: 6, Loss 0.008097667247056961


Performance of zero predictor on training data: 7, Loss 0.026756394654512405


Performance of zero predictor on test data: 7, Loss 0.03227612003684044


Performance of zero predictor on training data: 8, Loss 0.017567452043294907


Performance of zero predictor on test data: 8, Loss 0.02353309467434883


Performance of zero predictor on training data: 9, Loss 0.029393309727311134


Performance of zero predictor on test data: 9, Loss 0.025857416912913322


In [9]:
for recording in range(TOTAL_RECORDINGS):
    loss_item = evaluate_recording_classification(recording, X_train[recording], Y_train[recording], neural_network_model, loss_fn, plot_num=5)
    print(f'Training Evaluation, Recording {recording}, Loss {loss_item}')

Training Evaluation, Recording 0, Loss 0.1776430904865265


In [11]:
for recording in range(TOTAL_RECORDINGS):
    loss_item = evaluate_recording_classification(recording, X_test[recording], Y_test[recording], neural_network_model, loss_fn, plot_num=50)
    print(f'Testing Evaluation, Recording {recording}, Loss {loss_item}')

Testing Evaluation, Recording 0, Loss 0.40338587760925293
