# Neural Network Model

In [10]:
import numpy as np
import torch
import plotly.graph_objects as go

from sklearn.linear_model import ARDRegression
from sklearn.multioutput import RegressorChain
from sklearn.metrics import mean_squared_error

from models.reduced_rank_model import ReducedRankModel
from models.neural_network import NeuralNetworkModel
from models.zero_predictor import ZeroPredictor
from models.mean_predictor import MeanPredictor
from models.yizi_decoders import time_bin_wise_metrics, continuous_decoder, sliding_window_over_time, sliding_window_decoder

from utils.get_data import get_data, get_test_train_data, smooth_x, smooth_y
from utils.evaluation import evaluate_recording
from utils.training import train_one_epoch


DATA_PATH = 'raw_data/full_data/'

# RECORDING_IDS = [0, 1, 2, 4, 7, 8, 11, 12, 13, 21]
RECORDING_IDS = np.arange(112)
TOTAL_RECORDINGS = len(RECORDING_IDS) # In total there are 112 recordings
PLOT_RECORDINGS = 3
TIME_BINS = 40
TRAIN_SPLIT = 0.8
TRAINING_ITERATIONS = 20
RANK = 10
MEAN_SUBTRACTED = False
SMOOTHED = True
SMOOTHED_X = True
MODEL = "neural_network"
RANDOM_SEED = 39

In [11]:
# This code loads yizi's data
# neural_data = np.load(f"raw_data/yizi_data/neural_data.npy", allow_pickle=True).item()
# behavior_data = np.load(f"raw_data/yizi_data/behavior_data.npy", allow_pickle=True).item()

# # Probably recording: dab512bd-a02d-4c1f-8dbc-9155a163efc0
# X_list = [neural_data["po"]]
# Y_list = [behavior_data["wheel_speed"]]

# print(X_list[0].shape)

In [12]:
X_list, Y_list = get_data(DATA_PATH, RECORDING_IDS)

if SMOOTHED_X:
    X_list = smooth_x(X_list)

if SMOOTHED:
    Y_list = smooth_y(Y_list)

X_train, X_test, Y_train_regular, Y_test_regular, Y_train_mean_subtracted, Y_test_mean_subtracted = get_test_train_data(X_list, Y_list, TRAIN_SPLIT, random_seed=RANDOM_SEED)

Y_train = Y_train_regular
Y_test = Y_test_regular

In [13]:
for t in range(3):
    fig = go.Figure()
    fig.add_trace(go.Scatter(y=Y_train_regular[0][t], name='Original'))
    # fig.add_trace(go.Scatter(y=Y_train_smoothed[0][t], name='Smoothed'))
    fig.add_trace(go.Scatter(y=Y_train_mean_subtracted[0][t], name='Mean Subtracted'))
    # fig.add_trace(go.Scatter(y=Y_train_mean_subtracted_smoothed[0][t], name='Mean Subtracted Smoothed'))
    fig.update_layout(title=f'Original, Smoothed, and Mean Subtracted Wheel Speeds, Recording 0, Trial {t}', xaxis_title='Time', yaxis_title='Value')
    fig.show()

# Training

## Continuous Decoder model

In [14]:
Y_pred_continuous_decoder = []
test_loss_continuous_decoder = []

for recording_session in range(TOTAL_RECORDINGS):
    y_pred, evaluation = continuous_decoder(X_train[recording_session], Y_train[recording_session], X_test[recording_session], Y_test[recording_session], time_independent=False)
    Y_pred_continuous_decoder.append(y_pred)
    test_loss_continuous_decoder.append(evaluation[2])

test_loss_continuous_decoder = np.array(test_loss_continuous_decoder)

time-dependent:
Chosen alpha: 10
r2: 0.37469026154955454 corr: 0.613661821834778 mse: 0.004923691485981752
time-dependent:
Chosen alpha: 10
r2: 0.6874958854025703 corr: 0.8302528594835916 mse: 0.004831400075319194
time-dependent:
Chosen alpha: 1000
r2: 0.43352008602205916 corr: 0.6607186325498811 mse: 0.009061546373849094
time-dependent:
Chosen alpha: 100
r2: 0.5274701572073083 corr: 0.7357956725568784 mse: 0.009589157698551534
time-dependent:
Chosen alpha: 1000
r2: 0.3765411921762172 corr: 0.6192024331769077 mse: 0.011012472719409652
time-dependent:
Chosen alpha: 10
r2: 0.4051722061901888 corr: 0.6374605154803972 mse: 0.008254861966284893
time-dependent:
Chosen alpha: 100
r2: 0.4324168678382778 corr: 0.6620101226271273 mse: 0.0030833903182076673
time-dependent:
Chosen alpha: 10
r2: 0.518705939437358 corr: 0.7214000280949533 mse: 0.010229705684500713
time-dependent:
Chosen alpha: 10
r2: 0.5310814786777189 corr: 0.7292841015809571 mse: 0.007823511640387209
time-dependent:
Chosen alpha: 

## Mean Predictor

In [None]:
mean_predictor = MeanPredictor(Y_train)
test_loss_mean_predictor = []

for recording in range(TOTAL_RECORDINGS):
    loss_item = evaluate_recording(recording, X_train[recording], Y_train[recording], mean_predictor, loss_fn, plot_num=0)
    print(f'Performance of mean predictor on training data: {recording}, Loss {loss_item}')
    test_loss_mean_predictor.append(loss_item)

test_loss_mean_predictor = np.array(test_loss_mean_predictor)

Performance of mean predictor on training data: 0, Loss 0.004946375985719935
Performance of mean predictor on training data: 1, Loss 0.006960836492010732
Performance of mean predictor on training data: 2, Loss 0.008347617841664982
Performance of mean predictor on training data: 3, Loss 0.010550672449753287
Performance of mean predictor on training data: 4, Loss 0.014227278827092
Performance of mean predictor on training data: 5, Loss 0.007083147684356542
Performance of mean predictor on training data: 6, Loss 0.004660368844305392
Performance of mean predictor on training data: 7, Loss 0.012584671640805345
Performance of mean predictor on training data: 8, Loss 0.009245093873122747
Performance of mean predictor on training data: 9, Loss 0.012998579483699035
Performance of mean predictor on training data: 10, Loss 0.012998579483699035
Performance of mean predictor on training data: 11, Loss 0.00954133256053246
Performance of mean predictor on training data: 12, Loss 0.011747041777461362


In [None]:
interesting_recordings = np.argwhere(test_loss_continuous_decoder < test_loss_mean_predictor).squeeze()
print(interesting_recordings.shape)
print(interesting_recordings)

test_loss_continuous_decoder_interesting = test_loss_continuous_decoder[interesting_recordings]
test_loss_mean_predictor_interesting = test_loss_mean_predictor[interesting_recordings]

# 42: [  2   8  13  14  15  17  25  26  32  36  39  45  52  57  62  69  73  74
#   78  88  89  91  93  95 100 101 106 107 110 111]
# 41: [  4   9  10  11  12  21  23  26  30  32  36  37  40  44  49  50  51  53
#   54  55  56  57  58  65  67  70  71  75  76  77  83  88  89  93 101 106
#  111]
# 40: [  4  17  19  20  26  35  36  39  41  53  55  56  57  61  68  73  74  97
#  100 105 111]
# 39: [ 11  12  14  23  25  26  27  29  30  36  55  57  58  59  61  64  67  82
#   89  90  96 100 106 107 110]
# 38: [  3  14  17  20  23  26  42  45  49  53  58  60  66  68  71  76  83  88
#   90  93 100 103 108 111]

# 38: [  0   1   6   9  10  11  12  13  14  15  16  17  18  19  20  21  23  24
#   25  26  27  28  29  30  31  33  34  35  37  38  39  40  41  43  44  45
#   46  47  48  49  50  52  53  54  55  56  57  58  59  60  61  62  63  64
#   68  69  70  71  72  73  74  75  76  78  79  80  82  83  87  88  89  90
#   91  92  93  95  96 100 101 102 103 104 105 106 107 108 109 110 111]


(89,)
[  0   1   6   9  10  11  12  13  14  15  16  17  18  19  20  21  23  24
  25  26  27  28  29  30  31  33  34  35  37  38  39  40  41  43  44  45
  46  47  48  49  50  52  53  54  55  56  57  58  59  60  61  62  63  64
  68  69  70  71  72  73  74  75  76  78  79  80  82  83  87  88  89  90
  91  92  93  95  96 100 101 102 103 104 105 106 107 108 109 110 111]


In [None]:
fig = go.Figure()
# fig.add_trace(go.Bar(x=np.arange(TOTAL_RECORDINGS), y=test_loss_zero_predictor, name='Zero Predictor'))
fig.add_trace(go.Bar(x=np.arange(TOTAL_RECORDINGS), y=test_loss_mean_predictor_interesting, name='Mean Predictor'))
fig.add_trace(go.Bar(x=np.arange(TOTAL_RECORDINGS), y=test_loss_continuous_decoder_interesting, name='Continuous Decoder'))
fig.update_layout(title=f'Loss Curve Rank {RANK}', xaxis_title='Recording Session', yaxis_title='Test MSE')
fig.show()

# Analysis

In [None]:
interesting_decoding_sessions = []

interesting_decoding_sessions.append([2,8,13,14,15,17,25,26,32,36,39,45,52,57,62,69,73,74,78,88,89,91,93,95,100,101,106,107,110,111])
interesting_decoding_sessions.append([4,9,10,11,12,21,23,26,30,32,36,37,40,44,49,50,51,53,54,55,56,57,58,65,67,70,71,75,76,77,83,88,89,93,101,106,111])
interesting_decoding_sessions.append([4,17,19,20,26,35,36,39,41,53,55,56,57,61,68,73,74,97,100,105,111])
interesting_decoding_sessions.append([11,12,14,23,25,26,27,29,30,36,55,57,58,59,61,64,67,82,89,90,96,100,106,107,110])
interesting_decoding_sessions.append([3,14,17,20,23,26,42,45,49,53,58,60,66,68,71,76,83,88,90,93,100,103,108,111])

# flatten interesting_decoding_sessions
interesting_decoding_sessions = [item for sublist in interesting_decoding_sessions for item in sublist]
interesting_decoding_sessions = np.array(interesting_decoding_sessions)

# print elements of interesting_decoding_sessions that occur 5 times
values, occurances = np.unique(interesting_decoding_sessions, return_counts=True)

print(f"4 occurance sessions {values[np.argwhere(occurances==4).flatten()]}")
print(f"3 occurance sessions {values[np.argwhere(occurances==3).flatten()]}")
print(f"2 occurance sessions {values[np.argwhere(occurances==2).flatten()]}")
print(f"1 occurance sessions {values[np.argwhere(occurances==1).flatten()]}")

4 occurance sessions [ 36  57 100 111]
3 occurance sessions [ 14  17  23  53  55  58  88  89  93 106]
2 occurance sessions [  4  11  12  20  25  30  32  39  45  49  56  61  67  68  71  73  74  76
  83  90 101 107 110]
1 occurance sessions [  2   3   8   9  10  13  15  19  21  27  29  35  37  40  41  42  44  50
  51  52  54  59  60  62  64  65  66  69  70  75  77  78  82  91  95  96
  97 103 105 108]
