# Reduced Rank Model

In [1]:
import numpy as np
import torch
import plotly.graph_objects as go

from models.reduced_rank_model import ReducedRankModel
from models.neural_network import NeuralNetworkModel
from utils.get_data import get_data, get_test_train_data
from utils.evaluation import evaluate_recording
from utils.training import train_one_epoch

DATA_PATH = 'raw_data/full_data/'

TOTAL_RECORDINGS = 10 # 112
RECORDING_IDS = np.arange(TOTAL_RECORDINGS)
# RECORDING_IDS = [5, 6, 8]
TRAIN_SPLIT = 0.8
TRAINING_ITERATIONS = 501
RANK = 3
MEAN_SUBTRACTED = False
# MODEL = "neural_network"
MODEL = "reduced_rank"

In [2]:
X_list, Y_list = get_data(DATA_PATH, RECORDING_IDS)
X_train, X_test, Y_train_regular, Y_test_regular, Y_train_mean_subtracted, Y_test_mean_subtracted = get_test_train_data(X_list, Y_list, TRAIN_SPLIT)

if MEAN_SUBTRACTED:
    Y_train = Y_train_mean_subtracted
    Y_test = Y_test_mean_subtracted
else:
    Y_train = Y_train_regular
    Y_test = Y_test_regular

In [4]:
# from load_data.load_data import normalize_data
# X_train = [normalize_data(X_train_session) for X_train_session in X_train] 
# X_test = [normalize_data(X_test_session) for X_test_session in X_test]

  data_norm[:,t*n_units:(t+1)*n_units] = (data_norm[:,t*n_units:(t+1)*n_units] - mean_per_trial) / std_per_trial


In [5]:
for t in range(3):
    fig = go.Figure()
    fig.add_trace(go.Scatter(y=Y_train_regular[0][t], name='Original'))
    fig.add_trace(go.Scatter(y=Y_train_mean_subtracted[0][t], name='Mean Subtracted'))
    fig.update_layout(title=f'Original and Mean Subtracted Wheel Speeds, Recording 0, Trial {t}', xaxis_title='Time', yaxis_title='Value')
    fig.show()

# Training

In [6]:
model = ReducedRankModel(
    n_recordings=TOTAL_RECORDINGS,
    n_neurons_per_recording=[X_train[i].shape[1] for i in range(TOTAL_RECORDINGS)], 
    n_time_bins=40, 
    rank=RANK
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss()

train_loss = []
test_loss = []

for training_iteration in range(TRAINING_ITERATIONS):
    train_loss_iter, test_loss_iter = train_one_epoch(X_train, Y_train, X_test, Y_test, model, optimizer, loss_fn)

    if training_iteration % 100 == 0:
        print(f'Iteration {training_iteration}, Train Loss {train_loss_iter}, Test Loss {test_loss_iter}')

    train_loss.append(train_loss_iter)
    test_loss.append(test_loss_iter)

Iteration 0, Train Loss [3.87371563911438, 5.383478164672852, 2.7021484375, 13.774391174316406, 4.302082061767578, 10.224696159362793, 9.216532707214355, 3.199312448501587, 2.925170660018921, 4.580094337463379], Test Loss [3.846010684967041, 5.2899675369262695, 2.6242547035217285, 13.743958473205566, 4.232574939727783, 9.939427375793457, 9.182621955871582, 3.2080883979797363, 2.9873807430267334, 4.331789970397949]
Iteration 100, Train Loss [1.030358910560608, 1.2995128631591797, 1.1990480422973633, 2.3226685523986816, 0.8613584041595459, 0.8942620158195496, 0.6071637868881226, 0.9789215922355652, 0.7475394010543823, 1.579469919204712], Test Loss [0.9962438344955444, 1.3978055715560913, 1.1741418838500977, 2.1905243396759033, 0.8722450733184814, 0.8423254489898682, 0.5886803269386292, 1.0544482469558716, 0.7969605326652527, 1.3966110944747925]
Iteration 200, Train Loss [0.6944399476051331, 0.8449401259422302, 0.7697157263755798, 1.7877352237701416, 0.5503630042076111, 0.5403658747673035

In [7]:
print(np.argmax(train_loss[-1]))
print(np.argpartition(train_loss[-1], -5)[-5:])
print(np.partition(train_loss[-1], -5)[-5:])

3
[0 1 3 2 9]
[0.62233597 0.69022948 1.52071357 0.62463069 0.90230823]


In [8]:
train_loss = np.array(train_loss)
test_loss = np.array(test_loss)

fig = go.Figure()
for recording in range(TOTAL_RECORDINGS):
    fig.add_trace(go.Scatter(y=test_loss[:, recording], name=f'{recording} - Test Recording'))
    fig.add_trace(go.Scatter(y=train_loss[:, recording], name=f'{recording} - Train Recording'))
fig.update_layout(title=f'Loss Curve Rank {RANK}', xaxis_title='Iteration', yaxis_title='Loss')
fig.show()

# Evaluation

In [9]:
# Print all of the model parameters
# for param in reduced_rank_model.parameters():
#     print(param.data.shape)
#     print(param.data)

# print(f"=== U matrices, length: {len(reduced_rank_model.Us)} ===")
# print(reduced_rank_model.Us)
# print(f"=== V matrices, shape: {reduced_rank_model.V.shape} ===")
# print(reduced_rank_model.V)
# print(f"=== bias, shape: {reduced_rank_model.bias.shape} ===")
# print(reduced_rank_model.bias)

# print(reduced_rank_model.parameters())

In [10]:
for recording in range(TOTAL_RECORDINGS):
    loss_item = evaluate_recording(recording, X_train[recording], Y_train[recording], model, loss_fn, plot_num=1)
    print(f'Evaluation, Recording {recording}, Loss {loss_item}')

Evaluation, Recording 0, Loss 0.6223629117012024


Evaluation, Recording 1, Loss 0.6902154088020325


Evaluation, Recording 2, Loss 0.6245977878570557


Evaluation, Recording 3, Loss 1.5203088521957397


Evaluation, Recording 4, Loss 0.4898335933685303


Evaluation, Recording 5, Loss 0.44461116194725037


Evaluation, Recording 6, Loss 0.343426913022995


Evaluation, Recording 7, Loss 0.49229657649993896


Evaluation, Recording 8, Loss 0.3985899090766907


Evaluation, Recording 9, Loss 0.902259886264801


In [11]:
for recording in range(TOTAL_RECORDINGS):
    loss_item = evaluate_recording(recording, X_test[recording], Y_test[recording], model, loss_fn, plot_num=3)
    print(f'Evaluation, Recording {recording}, Loss {loss_item}')

Evaluation, Recording 0, Loss 0.5917466878890991


Evaluation, Recording 1, Loss 0.7562533020973206


Evaluation, Recording 2, Loss 0.6096684336662292


Evaluation, Recording 3, Loss 1.4721012115478516


Evaluation, Recording 4, Loss 0.4957953095436096


Evaluation, Recording 5, Loss 0.4179922938346863


Evaluation, Recording 6, Loss 0.315343976020813


Evaluation, Recording 7, Loss 0.5192296504974365


Evaluation, Recording 8, Loss 0.43697553873062134


Evaluation, Recording 9, Loss 0.7625113725662231
