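# Reduced Rank Model

Fits a single model shared across recordings (reduced-rank by default, selectable via `MODEL`) that predicts wheel speed (`Y`) from each recording's neural activity (`X`), and reports per-recording train/test MSE.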

In [2]:
import numpy as np
import os
import torch
import plotly.graph_objects as go

from models.reduced_rank_model import ReducedRankModel
from models.neural_network import NeuralNetworkModel

DATA_PATH = 'raw_data/full_data/'

TOTAL_RECORDINGS = 10
# Indices (into the sorted X*/Y* file lists) of the recordings to load.
RECORDING_IDS = np.arange(TOTAL_RECORDINGS)
# RECORDING_IDS = [5, 6, 8]  # example subset
TRAIN_SPLIT = 0.8            # fraction of each recording's trials used for training
TRAINING_ITERATIONS = 1001   # passes over all recordings (one optimizer step per recording per pass)
RANK = 3                     # passed to the model as `rank`
MEAN_SUBTRACTED = False      # if True, train/test on mean-subtracted targets
MODEL = "reduced_rank"       # "reduced_rank" or "neural_network"

# Seed every RNG the notebook relies on (trial shuffling, weight initialisation)
# so that Restart Kernel -> Run All reproduces the same results.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Load the input (X*.npy) and target (Y*.npy) arrays for the selected recordings.
# Both lists are sorted lexicographically, so the i-th X file is paired with the i-th Y file.
# NOTE(review): this assumes X/Y file names differ only in their leading letter — confirm.
# Lexicographic order also puts e.g. "X10" before "X2".
files = os.listdir(DATA_PATH)
X_files = np.array(sorted(f for f in files if f.endswith('.npy') and f.startswith('X')))
Y_files = np.array(sorted(f for f in files if f.endswith('.npy') and f.startswith('Y')))

# Pairing by sorted position is only meaningful if every X file has a Y counterpart.
assert len(X_files) == len(Y_files), (
    f"Found {len(X_files)} X files but {len(Y_files)} Y files in {DATA_PATH}"
)

# os.path.join works whether or not DATA_PATH ends with a separator.
X_list = [np.load(os.path.join(DATA_PATH, f)) for f in X_files[RECORDING_IDS]]
Y_list = [np.load(os.path.join(DATA_PATH, f)) for f in Y_files[RECORDING_IDS]]

In [4]:
# Sanity check: how many recordings were loaded, and the array shape of the first one.
first_recording_shape = X_list[0].shape
print(len(X_list))
print(first_recording_shape)

10
(775, 25, 40)


In [5]:
# Shuffle the trials of each recording and split them into train/test sets.
# The same permutation is applied to X and Y so that trials stay paired.
SPLIT_SEED = 0  # fixed so the train/test split is identical on every Restart & Run All
rng = np.random.default_rng(SPLIT_SEED)

# Mean over all trials (axis 0) of each recording; only used for the visual check below,
# never for the training targets.
Y_list_mean_subtracted = [Y - np.mean(Y, axis=0, keepdims=True) for Y in Y_list]

X_list_shuffled = []
Y_list_shuffled = []
Y_list_mean_subtracted_shuffled = []
X_train, X_test, Y_train, Y_test = [], [], [], []

# Iterate over the recordings actually loaded (not TOTAL_RECORDINGS) so that
# a custom RECORDING_IDS subset works.
for X_rec, Y_rec in zip(X_list, Y_list):
    assert X_rec.shape[0] == Y_rec.shape[0], "X and Y must contain the same number of trials"

    order = rng.permutation(X_rec.shape[0])
    X_shuffled = X_rec[order]
    Y_shuffled = Y_rec[order]
    n_train = int(TRAIN_SPLIT * X_rec.shape[0])

    # Centre with the *training* mean only, so no test-set information leaks into the targets.
    Y_train_mean = np.mean(Y_shuffled[:n_train], axis=0, keepdims=True)
    Y_centered = Y_shuffled - Y_train_mean

    X_list_shuffled.append(X_shuffled)
    Y_list_shuffled.append(Y_shuffled)
    Y_list_mean_subtracted_shuffled.append(Y_centered)

    Y_source = Y_centered if MEAN_SUBTRACTED else Y_shuffled
    X_train.append(X_shuffled[:n_train])
    X_test.append(X_shuffled[n_train:])
    Y_train.append(Y_source[:n_train])
    Y_test.append(Y_source[n_train:])

In [6]:
# Visual check of the mean subtraction: raw vs. mean-subtracted wheel speed
# for the first few trials of recording 0.
for t in range(min(6, len(Y_list[0]))):
    fig = go.Figure()
    fig.add_trace(go.Scatter(y=Y_list[0][t], name='Raw'))
    fig.add_trace(go.Scatter(y=Y_list_mean_subtracted[0][t], name='Mean-subtracted'))
    fig.update_layout(title=f'Raw and Mean-Subtracted Wheel Speeds, Recording 0, Trial {t}', xaxis_title='Time', yaxis_title='Value')
    fig.show()

In [7]:
# Sanity check after splitting: number of recordings and the training-array shape of the first one.
first_train_shape = X_train[0].shape
print(len(X_train))
print(first_train_shape)

10
(620, 25, 40)


# Training

In [8]:
def evaluate_recording(recording, X_test_recording, Y_test_recording, model, optimizer, loss_fn, plot=False):
    """Compute the loss of `model` on one recording's held-out data, optionally plotting predictions.

    Args:
        recording: Recording index, forwarded to the model as `model(recording, X)`.
        X_test_recording: NumPy array of inputs for this recording (converted to float32).
        Y_test_recording: NumPy array of targets for this recording (converted to float32).
        model: torch module called as `model(recording, X)`.
        optimizer: Unused; kept so the signature mirrors `train_step`.
        loss_fn: Callable `loss_fn(prediction, target)` returning a scalar tensor.
        plot: If True, show actual vs. predicted traces for up to the first 6 trials.

    Returns:
        The loss as a Python float.

    Note:
        Leaves the model in eval mode; code that trains afterwards must call `model.train()`.
    """
    model.eval()

    X_test_tensor = torch.from_numpy(X_test_recording).float()
    Y_test_tensor = torch.from_numpy(Y_test_recording).float()

    # Evaluation needs no gradients; skipping autograd avoids building a graph on every call.
    with torch.no_grad():
        Y_pred = model(recording, X_test_tensor)
        loss = loss_fn(Y_pred, Y_test_tensor)

    if plot:
        # Plotly expects array-likes, so convert tensors to NumPy explicitly.
        Y_actual_np = Y_test_tensor.numpy()
        Y_pred_np = Y_pred.detach().cpu().numpy()
        for t in range(min(6, len(Y_actual_np))):
            fig = go.Figure()
            fig.add_trace(go.Scatter(y=Y_actual_np[t], name='Actual'))
            fig.add_trace(go.Scatter(y=Y_pred_np[t], name='Predicted'))
            fig.update_layout(title=f'Actual and Predicted Wheel Speeds, Recording {recording}, Trial {t}', xaxis_title='Time', yaxis_title='Value')
            fig.show()

    return loss.item()

In [9]:
def train_step(recording, X_train_recording, Y_train_recording, model, optimizer, loss_fn):
    """Run a single optimizer step on one recording's training data.

    Args:
        recording: Recording index, forwarded to the model as `model(recording, X)`.
        X_train_recording: NumPy array of training inputs (converted to float32).
        Y_train_recording: NumPy array of training targets (converted to float32).
        model: torch module called as `model(recording, X)`.
        optimizer: torch optimizer over `model`'s parameters.
        loss_fn: Callable `loss_fn(prediction, target)` returning a scalar tensor.

    Returns:
        The pre-update training loss as a Python float.
    """
    # evaluate_recording() switches the model to eval mode; switch back so layers such as
    # dropout/batch-norm (if the model has any) behave as in training.
    model.train()
    optimizer.zero_grad()

    X_train_tensor = torch.from_numpy(X_train_recording).float()
    Y_train_tensor = torch.from_numpy(Y_train_recording).float()

    Y_pred = model(recording, X_train_tensor)

    loss = loss_fn(Y_pred, Y_train_tensor)
    loss.backward()
    optimizer.step()

    return loss.item()

In [10]:
# Model sizes are taken from the split data rather than TOTAL_RECORDINGS / literals,
# so a custom RECORDING_IDS subset works end to end.
n_recordings = len(X_train)
n_neurons_per_recording = [X.shape[1] for X in X_train]
# NOTE(review): assumes X is (trials, neurons, time_bins), consistent with the
# (620, 25, 40) shape printed above and the former literal n_time_bins=40 — confirm.
n_time_bins = X_train[0].shape[2]
assert all(X.shape[2] == n_time_bins for X in X_train), "recordings disagree on n_time_bins"

if MODEL == "reduced_rank":
    model = ReducedRankModel(
        n_recordings=n_recordings,
        n_neurons_per_recording=n_neurons_per_recording,
        n_time_bins=n_time_bins,
        rank=RANK
    )
elif MODEL == "neural_network":
    model = NeuralNetworkModel(
        n_recordings=n_recordings,
        n_neurons_per_recording=n_neurons_per_recording,
        n_time_bins=n_time_bins,
        width=1000,
        hidden_layers=3,
        rank=RANK
    )
else:
    # Fail here rather than with a NameError on `model` further down.
    raise ValueError(f"Unknown MODEL {MODEL!r}; expected 'reduced_rank' or 'neural_network'")

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss()

# train_loss[i][r] / test_loss[i][r]: loss on recording r at iteration i.
train_loss = []
test_loss = []

for training_iteration in range(TRAINING_ITERATIONS):
    train_loss_iter = []
    test_loss_iter = []

    # One optimizer step per recording, each followed by an evaluation on that recording's test set.
    for recording in range(n_recordings):
        train_loss_item = train_step(recording, X_train[recording], Y_train[recording], model, optimizer, loss_fn)
        test_loss_item = evaluate_recording(recording, X_test[recording], Y_test[recording], model, optimizer, loss_fn, plot=False)

        train_loss_iter.append(train_loss_item)
        test_loss_iter.append(test_loss_item)

        if training_iteration % 100 == 0:
            print(f'Iteration {training_iteration}, Recording {recording}, Train Loss {train_loss_item}, Test Loss {test_loss_item}')

    train_loss.append(train_loss_iter)
    test_loss.append(test_loss_iter)

Iteration 0, Recording 0, Train Loss 2.1033756732940674, Test Loss 2.0401949882507324
Iteration 0, Recording 1, Train Loss 2.363130569458008, Test Loss 2.4854278564453125
Iteration 0, Recording 2, Train Loss 1.952628493309021, Test Loss 2.008145570755005
Iteration 0, Recording 3, Train Loss 7.470195770263672, Test Loss 7.285292148590088
Iteration 0, Recording 4, Train Loss 1.5973690748214722, Test Loss 1.63424813747406
Iteration 0, Recording 5, Train Loss 3.0236616134643555, Test Loss 2.9634509086608887
Iteration 0, Recording 6, Train Loss 2.2433576583862305, Test Loss 2.206555128097534
Iteration 0, Recording 7, Train Loss 2.4192299842834473, Test Loss 2.4096193313598633
Iteration 0, Recording 8, Train Loss 1.3870102167129517, Test Loss 1.358349084854126
Iteration 0, Recording 9, Train Loss 8.921563148498535, Test Loss 9.044267654418945


Iteration 100, Recording 0, Train Loss 0.9792956113815308, Test Loss 0.8998295664787292
Iteration 100, Recording 1, Train Loss 1.1609747409820557, Test Loss 1.2882964611053467
Iteration 100, Recording 2, Train Loss 1.0303266048431396, Test Loss 1.0596227645874023
Iteration 100, Recording 3, Train Loss 1.9413015842437744, Test Loss 1.6681915521621704
Iteration 100, Recording 4, Train Loss 0.6754927039146423, Test Loss 0.704535961151123
Iteration 100, Recording 5, Train Loss 0.7833033800125122, Test Loss 0.7398509383201599
Iteration 100, Recording 6, Train Loss 0.47209829092025757, Test Loss 0.4367552399635315
Iteration 100, Recording 7, Train Loss 0.7978923320770264, Test Loss 0.787980318069458
Iteration 100, Recording 8, Train Loss 0.6128144264221191, Test Loss 0.6079303026199341
Iteration 100, Recording 9, Train Loss 1.1255220174789429, Test Loss 1.1733214855194092
Iteration 200, Recording 0, Train Loss 0.7274807691574097, Test Loss 0.6508477330207825
Iteration 200, Recording 1, Train

In [11]:
# Stack the per-iteration losses into (iterations, recordings) arrays and plot one
# train and one test curve per recording.
train_loss = np.array(train_loss)
test_loss = np.array(test_loss)

fig = go.Figure()
for recording in range(test_loss.shape[1]):
    # legendgroup lets a recording's train and test curves be toggled together.
    fig.add_trace(go.Scatter(y=test_loss[:, recording], name=f'Test Recording {recording}', legendgroup=f'rec{recording}'))
    fig.add_trace(go.Scatter(y=train_loss[:, recording], name=f'Train Recording {recording}', legendgroup=f'rec{recording}'))
fig.update_layout(title=f'Loss Curve Rank {RANK}', xaxis_title='Iteration', yaxis_title='Loss')
fig.show()

# Evaluation

In [13]:
# Inspect the learned parameters: name and shape of every registered tensor in the model.
# (Print `param.data` as well to see the values.)
for name, param in model.named_parameters():
    print(f"{name}: {tuple(param.shape)}")


In [15]:
# Final held-out loss per recording, plotting the first few test trials of each.
# Iterates over the actual test data so it matches however many recordings were loaded.
for recording, (X_test_recording, Y_test_recording) in enumerate(zip(X_test, Y_test)):
    loss_item = evaluate_recording(recording, X_test_recording, Y_test_recording, model, optimizer, loss_fn, plot=True)
    print(f'Evaluation, Recording {recording}, Loss {loss_item}')

Evaluation, Recording 0, Loss 0.5705170631408691


Evaluation, Recording 1, Loss 0.809992790222168


Evaluation, Recording 2, Loss 0.6201502084732056


Evaluation, Recording 3, Loss 1.2699768543243408


Evaluation, Recording 4, Loss 0.5200156569480896


Evaluation, Recording 5, Loss 0.41322213411331177


Evaluation, Recording 6, Loss 0.339739590883255


Evaluation, Recording 7, Loss 0.5004708170890808


Evaluation, Recording 8, Loss 0.4013534486293793


Evaluation, Recording 9, Loss 0.8772998452186584
