In [42]:
import numpy as np
import torch
import plotly.graph_objs as go
from sklearn.metrics import mean_squared_error, r2_score
from sklearn import linear_model
from models.reduced_rank_model import ReducedRankModel

## Processing Data

In [43]:
# Read the data-file manifest. Each line starts with a file name; only the part
# before the first comma is kept (the rest presumably describes the array shape).
with open('./raw_data/full_data/names_and_shapes.txt') as file:
    lines = []
    for line in file:
        name = line.split(',', 1)[0].strip()
        # Skip blank lines (e.g. a trailing newline at EOF). With str.find, a
        # comma-less line gave sep == -1 and line[:-1] silently chopped a
        # character, adding a bogus entry that shifts the X/Y pairing below.
        if name:
            lines.append(name)

In [44]:
# The manifest lists all input (X) files first and then all target (Y) files
# in the same order, so entry i is paired with entry i + len(lines) // 2.
assert len(lines) % 2 == 0, f'Expected an even number of data files, got {len(lines)}'
n_pairs = len(lines) // 2

X_list = []
Y_list = []

for i in range(n_pairs):
    X_i = np.load('./raw_data/full_data/' + lines[i])
    Y_i = np.load('./raw_data/full_data/' + lines[i + n_pairs])
    # Inputs and targets are split and fed to the model trial-by-trial along
    # axis 0, so each pair must contain the same number of trials.
    assert X_i.shape[0] == Y_i.shape[0], (
        f'Trial count mismatch for pair {i}: {X_i.shape[0]} vs {Y_i.shape[0]}'
    )
    X_list.append(X_i)
    Y_list.append(Y_i)

In [86]:
# Per-session average over axis 0 of each input array (the axis that the
# train/test split later cuts along, i.e. trials).
avg_spikes_lst = [np.mean(session_X, axis=0) for session_X in X_list]

In [87]:
np.shape(avg_spikes_lst[0])  # shape of session 0's axis-0-averaged array

(25, 40)

In [88]:
# Column 0 (the first time bin) of session 0's averaged array, one value per row.
session1 = avg_spikes_lst[0].T[0]

In [89]:
# Bar plot of session 1's first-time-bin values, ranked in descending order.
# Sort into a new name so `session1` keeps its original row ordering and the
# cell stays idempotent.
session1_ranked = np.sort(session1)[::-1]
fig = go.Figure(data=[go.Bar(x=np.arange(len(session1_ranked)), y=session1_ranked)])
fig.update_layout(
    title='Session 1',
    xaxis_title='Rank',
    yaxis_title='Mean activity (time bin 0)',
)
fig.show()

In [46]:
# Sessions used by the reduced-rank model: capped at MAX_SESSIONS to keep
# training fast, but never more than the number of sessions actually loaded.
MAX_SESSIONS = 3
num_sessions = min(MAX_SESSIONS, len(X_list))
time_bins = X_list[0].shape[2]
# The model takes a single `time_bins` and the appendix stacks all sessions
# into (rows, time_bins) matrices, so every session must share this axis.
assert all(X.shape[2] == time_bins for X in X_list), 'Sessions have differing numbers of time bins'
train_split = 0.8   # fraction of leading trials (axis 0) used for training
train_iters = 1001  # 1001 so iteration 1000 appears in the every-100 log

In [47]:
def _split_trials(arrays, frac):
    """Split each array along axis 0 into a leading train part and a trailing test part.

    The cut for each array is int(frac * n_rows); there is no shuffling, so the
    first trials go to training and the remaining trials to testing.
    """
    cut_points = [int(frac * arr.shape[0]) for arr in arrays]
    heads = [arr[:cut] for arr, cut in zip(arrays, cut_points)]
    tails = [arr[cut:] for arr, cut in zip(arrays, cut_points)]
    return heads, tails


X_train, X_test = _split_trials(X_list, train_split)
Y_train, Y_test = _split_trials(Y_list, train_split)

## Decoding with Reduced Rank Model

In [48]:
def train_step(session, X_train_session, Y_train_session, model, optimizer, loss_fn):
    """Run one optimisation step of `model` on a single session's training data.

    Args:
        session: Session index, forwarded as ``model(session, X)``.
        X_train_session: NumPy array of inputs for this session (cast to float32).
        Y_train_session: NumPy array of targets for this session (cast to float32).
        model: torch module called as ``model(session, X)``.
        optimizer: torch optimizer over the model's parameters.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar tensor.

    Returns:
        The loss of this step as a Python float.
    """
    # Evaluation code may have switched the model to eval mode; always train in
    # train mode (relevant if the model has dropout / batch-norm layers).
    model.train()
    optimizer.zero_grad()

    X_tensor = torch.from_numpy(X_train_session).float()
    Y_tensor = torch.from_numpy(Y_train_session).float()

    Y_pred = model(session, X_tensor)

    loss = loss_fn(Y_pred, Y_tensor)
    loss.backward()
    optimizer.step()

    return loss.item()

In [59]:
# Seed torch's RNG so any random parameter initialisation (and hence the loss
# curve below) is reproducible under Restart & Run All.
torch.manual_seed(0)

reduced_rank_model = ReducedRankModel(
    num_sessions,
    [X_train[i].shape[1] for i in range(num_sessions)],
    time_bins,
    rank=3
)

optimizer = torch.optim.Adam(reduced_rank_model.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss()

# loss[k][s] is the training loss of session s at iteration k.
loss = []

for training_iteration in range(train_iters):
    loss_iter = []

    # One optimizer step per session per iteration, all sharing the same Adam optimizer.
    for session in range(num_sessions):
        loss_item = train_step(session, X_train[session], Y_train[session], reduced_rank_model, optimizer, loss_fn)
        loss_iter.append(loss_item)
        if training_iteration % 100 == 0:
            print(f'Iteration {training_iteration}, Recording {session}, Loss {loss_item}')

    loss.append(loss_iter)

Iteration 0, Recording 0, Loss 1.6467673778533936
Iteration 0, Recording 1, Loss 2.9828641414642334
Iteration 0, Recording 2, Loss 1.560168743133545
Iteration 100, Recording 0, Loss 1.2008883953094482
Iteration 100, Recording 1, Loss 1.8504019975662231
Iteration 100, Recording 2, Loss 1.5476949214935303
Iteration 200, Recording 0, Loss 1.0082918405532837
Iteration 200, Recording 1, Loss 1.5960330963134766
Iteration 200, Recording 2, Loss 1.545701503753662
Iteration 300, Recording 0, Loss 0.9401167035102844
Iteration 300, Recording 1, Loss 1.5635188817977905
Iteration 300, Recording 2, Loss 1.5457500219345093
Iteration 400, Recording 0, Loss 0.9208054542541504
Iteration 400, Recording 1, Loss 1.5661983489990234
Iteration 400, Recording 2, Loss 1.5458753108978271
Iteration 500, Recording 0, Loss 0.9157947301864624
Iteration 500, Recording 1, Loss 1.569416880607605
Iteration 500, Recording 2, Loss 1.5459182262420654
Iteration 600, Recording 0, Loss 0.9145590662956238
Iteration 600, Record

In [60]:
# Plot the per-recording training loss curves using plotly graph objects.
# Convert into a separate array rather than rebinding `loss`, so the training
# history keeps the type the training cell gave it (a list of per-iteration lists).
loss_curve = np.asarray(loss)

fig = go.Figure()
for session in range(num_sessions):
    fig.add_trace(go.Scatter(y=loss_curve[:, session], name=f'Recording {session}'))
fig.update_layout(title=f'Loss Curve Rank {reduced_rank_model.rank}', xaxis_title='Iteration', yaxis_title='Loss')
fig.show()

In [51]:
def evaluate_recording(recording, X_test_recording, Y_test_recording, model, loss_fn, plot=False, num_plots=3):
    """Compute the test loss of `model` on one recording, optionally plotting trials.

    Args:
        recording: Recording/session index, forwarded as ``model(recording, X)``.
        X_test_recording: NumPy array of test inputs (cast to float32).
        Y_test_recording: NumPy array of test targets (cast to float32).
        model: torch module called as ``model(recording, X)``.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar tensor.
        plot: If True, show actual vs. predicted traces for the first trials.
        num_plots: Maximum number of trials to plot; capped at the number of
            test trials available so short recordings do not raise IndexError.

    Returns:
        The test loss as a Python float. The model's train/eval mode is
        restored to what it was before the call.
    """
    # Remember the caller's mode so evaluation does not leave the model in eval mode.
    was_training = model.training
    model.eval()
    try:
        X_tensor = torch.from_numpy(X_test_recording).float()
        Y_tensor = torch.from_numpy(Y_test_recording).float()
        # Gradients are not needed for evaluation.
        with torch.no_grad():
            Y_pred = model(recording, X_tensor)
            loss = loss_fn(Y_pred, Y_tensor)
    finally:
        model.train(was_training)

    if plot:
        # Hand plotly plain NumPy arrays rather than torch tensors.
        Y_actual_np = Y_tensor.numpy()
        Y_pred_np = Y_pred.detach().numpy()
        for t in range(min(num_plots, len(Y_actual_np))):
            fig = go.Figure()
            fig.add_trace(go.Scatter(y=Y_actual_np[t], name='Actual'))
            fig.add_trace(go.Scatter(y=Y_pred_np[t], name='Predicted'))
            fig.update_layout(title=f'Actual and Predicted Wheel Speeds, Recording {recording}, Trial {t}', xaxis_title='Time', yaxis_title='Value')
            fig.show()

    return loss.item()

In [52]:
# Test-set loss (and example trial plots) for every recording used in training.
for recording in range(num_sessions):
    loss_item = evaluate_recording(
        recording,
        X_test[recording],
        Y_test[recording],
        reduced_rank_model,
        loss_fn,
        plot=True,
    )
    print(f'Evaluation, Recording {recording}, Loss {loss_item}')

Evaluation, Recording 0, Loss 1.6864758729934692


Evaluation, Recording 1, Loss 1.378225326538086


Evaluation, Recording 2, Loss 1.6727323532104492


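## Decoding with an Order-1 Autoregressive (AR(1)) Model

*TODO: not yet implemented.*

## Decoding with an ARD (Automatic Relevance Determination) Prior

*TODO: not yet implemented.*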

# Appendix

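## Encoding with Poisson Regression

Predict per-trial spike summaries (mean and max over axis 1 of each session array) from wheel speed, fitting one Poisson regressor per time bin.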

In [None]:
def create_avg_and_max_matrices(X):
    """Stack per-trial mean and max summaries of every session into two matrices.

    For each session array in `X`, the mean and the max are taken over axis 1,
    giving one row per entry along axis 0 (trials, judging by the train/test
    split). Rows from all sessions are stacked in input order.

    Args:
        X: Iterable of per-session arrays; each is presumably
            (trials, units, time_bins) — TODO confirm axis 1 is units.

    Returns:
        (avg_matrix, max_matrix): two float64 arrays with one row per trial.
        For empty input, two (0, time_bins) arrays (uses the notebook-level
        `time_bins`).
    """
    session_list = list(X)
    if not session_list:
        return np.empty((0, time_bins)), np.empty((0, time_bins))

    # Collect all rows and stack once. vstack-ing inside the loop re-copies the
    # growing matrix on every iteration (quadratic in the number of sessions).
    # Cast to float64 to match the original result, which was stacked onto a
    # float64 np.empty seed.
    avg_matrix = np.vstack([np.mean(s, axis=1) for s in session_list]).astype(np.float64)
    max_matrix = np.vstack([np.max(s, axis=1) for s in session_list]).astype(np.float64)
    return avg_matrix, max_matrix

In [None]:
# Per-trial spike summaries (mean / max over axis 1) for the train and test splits.
(train_avg_spikes, train_max_spikes), (test_avg_spikes, test_max_spikes) = (
    create_avg_and_max_matrices(X_train),
    create_avg_and_max_matrices(X_test),
)

In [None]:
def create_wheel_speed_matrix(Y):
    """Stack every session's wheel-speed array into one float64 matrix.

    Args:
        Y: Iterable of per-session target arrays. Each is presumably
            (trials, time_bins) — TODO confirm.

    Returns:
        The row-wise concatenation of all sessions, as float64. For empty input,
        a (0, time_bins) array (uses the notebook-level `time_bins`).
    """
    session_list = list(Y)
    if not session_list:
        return np.empty((0, time_bins))

    # Stack once instead of re-copying a growing matrix inside a loop. Cast to
    # float64 to match the original, which stacked onto a float64 np.empty seed.
    return np.vstack(session_list).astype(np.float64)

In [None]:
# Wheel-speed targets stacked across sessions; rows line up with the spike
# matrices as long as each X/Y session pair has the same number of trials.
train_wheel_speeds, test_wheel_speeds = (
    create_wheel_speed_matrix(Y_train),
    create_wheel_speed_matrix(Y_test),
)

In [None]:
def train_and_predict(train_spikes, test_spikes, train_speeds, test_speeds):
    """Fit one Poisson-regression encoder per column and predict test spikes.

    For each column t, a ``linear_model.PoissonRegressor`` is fitted to predict
    ``train_spikes[:, t]`` from the single feature ``train_speeds[:, t]``. It is
    then applied to ``test_speeds[:, t]``.

    Args:
        train_spikes: 2-D array (rows, bins) of non-negative training targets.
        test_spikes: 2-D array of test targets; only its shape is used here.
        train_speeds: 2-D array with the same number of columns as train_spikes.
        test_speeds: 2-D array with the same number of columns as train_spikes.

    Returns:
        (encoders, spikes_pred): the fitted regressors (one per column) and an
        array shaped like `test_spikes` holding the predictions.
    """
    # Use the data's own column count instead of the notebook-global
    # `time_bins`, so the function works on any matrices with matching columns.
    n_bins = train_spikes.shape[1]
    encoders = []
    spikes_pred = np.zeros(test_spikes.shape)

    for t in range(n_bins):
        encoder = linear_model.PoissonRegressor()
        # sklearn expects a 2-D (n_samples, n_features) design matrix.
        encoder.fit(train_speeds[:, t].reshape(-1, 1), train_spikes[:, t])
        spikes_pred[:, t] = encoder.predict(test_speeds[:, t].reshape(-1, 1))
        encoders.append(encoder)

    return encoders, spikes_pred

In [None]:
# Fit per-time-bin encoders for each spike summary: the mean first, then the max.
avg_encoders, avg_spikes_pred = train_and_predict(
    train_spikes=train_avg_spikes,
    test_spikes=test_avg_spikes,
    train_speeds=train_wheel_speeds,
    test_speeds=test_wheel_speeds,
)
max_encoders, max_spikes_pred = train_and_predict(
    train_spikes=train_max_spikes,
    test_spikes=test_max_spikes,
    train_speeds=train_wheel_speeds,
    test_speeds=test_wheel_speeds,
)

In [None]:
# Test-set error of the Poisson encoders for each spike summary.
mse_avg = mean_squared_error(test_avg_spikes, avg_spikes_pred)
mse_max = mean_squared_error(test_max_spikes, max_spikes_pred)
r2_avg = r2_score(test_avg_spikes, avg_spikes_pred)
r2_max = r2_score(test_max_spikes, max_spikes_pred)

print(f'MSE for average spikes: {mse_avg}')
print(f'MSE for maximum spikes: {mse_max}')
print(f'R-squared for average spikes: {r2_avg}')
print(f'R-squared for maximum spikes: {r2_max}')

MSE for average spikes: 0.03327826055121139
MSE for maximum spikes: 4.692961438400262
R-squared for average spikes: -0.0008471324065031149
R-squared for maximum spikes: 0.0001365736782785204
