In [39]:
import numpy as np
import plotly.graph_objs as go
import torch
from sklearn import linear_model
from sklearn.linear_model import Ridge, ARDRegression
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.multioutput import RegressorChain

from models.reduced_rank_model import ReducedRankModel

In [40]:
# Notebook-wide constants.
NEURON_CLUSTERS = 3  # number of firing-rate clusters passed to get_thresholds / cluster_neurons
RANK = 3  # `rank` argument given to ReducedRankModel
RANDOM_SEED = 42  # seed value intended for reproducible runs

## Processing Data

In [41]:
# Read the array file names from the manifest. Only the text before the first
# comma on each line is kept as the file name.
lines = []
with open('./raw_data/full_data/names_and_shapes.txt') as manifest:
    for line in manifest:
        # partition() keeps the whole line when no comma is present; slicing with
        # find()'s -1 result would instead silently drop the last character.
        name = line.partition(',')[0].strip()
        if name:  # blank lines must not turn into empty file names
            lines.append(name)

In [42]:
# Entry k of `lines` is loaded into X_list and entry k + half into Y_list, so
# the first half of the manifest supplies X arrays and the second half Y arrays.
data_dir = './raw_data/full_data/'
half = len(lines) // 2

X_list = [np.load(data_dir + x_name) for x_name in lines[:half]]
Y_list = [np.load(data_dir + y_name) for y_name in lines[half:2 * half]]

In [43]:
# num_sessions = len(X_list)
num_sessions = 5  # sessions used by the reduced-rank model cells; hard-coded rather than len(X_list)
time_bins = X_list[0].shape[2]  # size of axis 2 of the first loaded X array
train_split = 0.8  # leading fraction of axis 0 of every array that is used for training
train_iters = 1001  # number of outer iterations of the training loop

In [44]:
# Per-session split along axis 0 without shuffling: the leading `train_split`
# fraction becomes training data, the remainder test data.
X_cuts = [int(train_split * X_session.shape[0]) for X_session in X_list]
Y_cuts = [int(train_split * Y_session.shape[0]) for Y_session in Y_list]

X_train = [X_session[:cut] for X_session, cut in zip(X_list, X_cuts)]
X_test = [X_session[cut:] for X_session, cut in zip(X_list, X_cuts)]
Y_train = [Y_session[:cut] for Y_session, cut in zip(Y_list, Y_cuts)]
Y_test = [Y_session[cut:] for Y_session, cut in zip(Y_list, Y_cuts)]

### Set neuron cluster thresholds based on average spikes per session

In [45]:
# Average spikes per neuron per session:
# sum over axis 2 (time bins), then average over axis 0 of the training split.
avg_spikes_lst = [
    np.mean(np.sum(X_session, axis=2), axis=0) for X_session in X_train
]

# Pool the per-session vectors into one flat array covering every neuron
avg_spikes = np.concatenate(avg_spikes_lst, axis=0)

In [46]:
# Bar plot of avg_spikes, ranked in descending order.
# avg_spikes is rebound to its sorted version and reused in that order below.
avg_spikes = np.sort(avg_spikes)[::-1]

neuron_rank = np.arange(len(avg_spikes))
fig = go.Figure(data=[go.Bar(x=neuron_rank, y=avg_spikes)])
fig.update_traces(
    marker_color='rgb(158,202,225)',
    marker_line_color='rgb(8,48,107)',
    marker_line_width=1.5,
    opacity=0.6,
)
fig.update_layout(
    title='Average Spikes per Neuron per Session',
    xaxis_title='Neuron',
    yaxis_title='Average Spikes',
)
fig.show()

In [47]:
# Create list of avg spikes thresholds for neuron clusters
def get_thresholds(avg_spikes, num_clusters, splits=None):
    """Return ``num_clusters + 1`` descending thresholds delimiting neuron clusters.

    The values are ranked in descending order inside the function, so the
    result does not depend on the caller having sorted ``avg_spikes`` first.
    The input is not modified.

    Args:
        avg_spikes: 1-D array-like of average spike counts, in any order.
        num_clusters: Number of clusters the thresholds should delimit.
        splits: Optional sequence of ``num_clusters - 1`` fractions in [0, 1].
            Each fraction is a position in the descending ranking at which a
            cluster boundary is placed. If ``None`` or of the wrong length,
            boundaries are placed at equal-count positions instead.

    Returns:
        List ``[max + 1, boundary_1, ..., min]``. The first entry is the largest
        value plus one so that a strict ``<`` comparison still includes the
        maximum; the last entry is the smallest value.
    """
    ranked = np.sort(np.asarray(avg_spikes))[::-1]
    n = len(ranked)

    if splits is not None and len(splits) == num_clusters - 1:
        # Clamp so that a fraction of 1.0 maps to the last element instead of
        # indexing one past the end.
        cut_positions = [0] + [min(int(frac * n), n - 1) for frac in splits]
    else:
        cut_positions = [int(i * n / num_clusters) for i in range(num_clusters)]

    thresholds = [ranked[pos] for pos in cut_positions]
    thresholds.append(ranked[-1])
    thresholds[0] += 1
    return thresholds

### Cluster neurons in training data per thresholds

In [48]:
# Cluster neurons based on avg spikes per neuron
def cluster_neurons(X_old, thresholds, num_clusters):
    """Replace the neuron axis of every session by ``num_clusters`` cluster means.

    For each array the per-neuron average is computed as the sum over axis 2
    followed by the mean over axis 0. Neuron ``n`` belongs to cluster ``k`` when
    ``thresholds[k + 1] <= avg[n] < thresholds[k]``. The first cluster has no
    upper bound and the last cluster has no lower bound, so a neuron whose
    average falls outside the threshold range (possible when the thresholds
    were derived from different data, e.g. the training split) is still
    assigned to an edge cluster instead of being dropped.

    Membership is recomputed from the data passed in, so the same neuron may
    land in different clusters for different splits.

    Args:
        X_old: Iterable of 3-D arrays indexed as (axis 0, neuron, axis 2).
        thresholds: Descending sequence with at least ``num_clusters + 1``
            entries, as returned by ``get_thresholds``.
        num_clusters: Number of clusters.

    Returns:
        List with one array per input of shape
        ``(X.shape[0], num_clusters, X.shape[2])``. A cluster with no members
        is left as zeros.
    """
    X_new = []
    last = num_clusters - 1

    for X in X_old:
        avg_spikes_per_neuron = np.mean(np.sum(X, axis=2), axis=0)
        X_ = np.zeros((X.shape[0], num_clusters, X.shape[2]))

        for k in range(num_clusters):
            in_cluster = np.ones(avg_spikes_per_neuron.shape, dtype=bool)
            if k > 0:  # the first cluster is open-ended at the top
                in_cluster &= avg_spikes_per_neuron < thresholds[k]
            if k < last:  # the last cluster is open-ended at the bottom
                in_cluster &= avg_spikes_per_neuron >= thresholds[k + 1]

            members = np.where(in_cluster)[0]
            if len(members) > 0:
                X_[:, k, :] = np.mean(X[:, members, :], axis=1)

        X_new.append(X_)

    return X_new

In [49]:
# Thresholds come from the pooled training averages; the clustered sessions
# are then stacked along axis 0.
avg_spikes_thresholds = get_thresholds(avg_spikes, NEURON_CLUSTERS)
X_train_new = np.concatenate(
    cluster_neurons(X_train, avg_spikes_thresholds, NEURON_CLUSTERS),
    axis=0,
)
Y_train_new = np.concatenate(Y_train, axis=0)

## Decoding with Reduced Rank Model

In [50]:
def train_step(session, X_train_session, Y_train_session, model, optimizer, loss_fn):
    """Run one optimisation step on a single session and return its loss.

    Args:
        session: Session identifier forwarded as the first argument of ``model``.
        X_train_session: Inputs for that session (NumPy array or tensor);
            converted to a float32 tensor.
        Y_train_session: Targets for that session (NumPy array or tensor);
            converted to a float32 tensor.
        model: Module called as ``model(session, X)``.
        optimizer: Optimizer holding the model's parameters.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar tensor.

    Returns:
        The loss of this step as a Python float.
    """
    # evaluate_recording() switches the model to eval mode; switch back
    # explicitly so that training after an evaluation runs in training mode.
    model.train()
    optimizer.zero_grad()

    X_train_session = torch.as_tensor(X_train_session, dtype=torch.float32)
    Y_train_session = torch.as_tensor(Y_train_session, dtype=torch.float32)

    Y_pred = model(session, X_train_session)

    loss = loss_fn(Y_pred, Y_train_session)
    loss.backward()
    optimizer.step()

    return loss.item()

In [51]:
# Seed torch before the model is built so that runs are repeatable under
# Restart & Run All. RANDOM_SEED is defined in the constants cell.
# NOTE(review): assumes ReducedRankModel draws its initial weights from torch's
# global RNG; confirm against models/reduced_rank_model.py.
torch.manual_seed(RANDOM_SEED)

reduced_rank_model = ReducedRankModel(
    num_sessions,
    [X_train[i].shape[1] for i in range(num_sessions)],
    time_bins,
    rank=RANK
)

optimizer = torch.optim.Adam(reduced_rank_model.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss()

# loss[k][s] is the training loss of session s at iteration k
loss = []

for training_iteration in range(train_iters):
    loss_iter = []

    for session in range(num_sessions):
        loss_item = train_step(session, X_train[session], Y_train[session], reduced_rank_model, optimizer, loss_fn)
        loss_iter.append(loss_item)
        # Report every 100th iteration only, to keep the output readable
        if training_iteration % 100 == 0:
            print(f'Iteration {training_iteration}, Recording {session}, Loss {loss_item}')

    loss.append(loss_iter)

Iteration 0, Recording 0, Loss 1.4582931995391846
Iteration 0, Recording 1, Loss 3.116133689880371
Iteration 0, Recording 2, Loss 1.5478532314300537
Iteration 0, Recording 3, Loss 14.562522888183594
Iteration 0, Recording 4, Loss 1.293709635734558
Iteration 100, Recording 0, Loss 1.0496759414672852
Iteration 100, Recording 1, Loss 2.023165225982666
Iteration 100, Recording 2, Loss 1.5464152097702026
Iteration 100, Recording 3, Loss 2.7189645767211914
Iteration 100, Recording 4, Loss 1.0631664991378784
Iteration 200, Recording 0, Loss 1.032011866569519
Iteration 200, Recording 1, Loss 1.795853853225708
Iteration 200, Recording 2, Loss 1.5460048913955688
Iteration 200, Recording 3, Loss 1.9497973918914795
Iteration 200, Recording 4, Loss 1.0168259143829346
Iteration 300, Recording 0, Loss 1.0344609022140503
Iteration 300, Recording 1, Loss 1.7766733169555664
Iteration 300, Recording 2, Loss 1.5459978580474854
Iteration 300, Recording 3, Loss 1.9318690299987793
Iteration 300, Recording 4,

In [52]:
# Plot the loss curve using plotly graphical objects.
# After conversion, rows are training iterations and columns are sessions.
loss = np.array(loss)

fig = go.Figure(
    data=[
        go.Scatter(y=loss[:, session], name=f'Recording {session}')
        for session in range(num_sessions)
    ]
)
fig.update_layout(
    title=f'Loss Curve Rank {reduced_rank_model.rank}',
    xaxis_title='Iteration',
    yaxis_title='Loss',
)
fig.show()

In [53]:
def evaluate_recording(recording, X_test_recording, Y_test_recording, model, loss_fn, plot=False):
    """Compute the loss of ``model`` on one recording's held-out data.

    The model is switched to eval mode and left in that mode. The forward pass
    runs under ``torch.no_grad()`` so no autograd graph is built.

    Args:
        recording: Recording identifier forwarded as the first argument of ``model``.
        X_test_recording: Inputs (NumPy array or tensor); converted to float32.
        Y_test_recording: Targets (NumPy array or tensor); converted to float32.
        model: Module called as ``model(recording, X)``.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar tensor.
        plot: If True, show actual vs. predicted traces for up to the first
            three entries along axis 0.

    Returns:
        The loss as a Python float.
    """
    model.eval()

    X_test_recording = torch.as_tensor(X_test_recording, dtype=torch.float32)
    Y_test_recording = torch.as_tensor(Y_test_recording, dtype=torch.float32)

    with torch.no_grad():
        Y_pred = model(recording, X_test_recording)
        loss = loss_fn(Y_pred, Y_test_recording)

    if plot:
        # Cap at the number of available trials so short recordings do not
        # raise IndexError.
        for t in range(min(3, len(Y_test_recording))):
            fig = go.Figure()
            # Hand plotly plain NumPy arrays rather than torch tensors
            fig.add_trace(go.Scatter(y=Y_test_recording[t].detach().cpu().numpy(), name='Actual'))
            fig.add_trace(go.Scatter(y=Y_pred[t].detach().cpu().numpy(), name='Predicted'))
            fig.update_layout(title=f'Actual and Predicted Wheel Speeds, Recording {recording}, Trial {t}', xaxis_title='Time', yaxis_title='Value')
            fig.show()

    return loss.item()

In [54]:
# Evaluate the trained reduced-rank model on each recording's held-out trials
for recording in range(num_sessions):
    X_rec, Y_rec = X_test[recording], Y_test[recording]
    loss_item = evaluate_recording(
        recording, X_rec, Y_rec, reduced_rank_model, loss_fn, plot=True
    )
    print(f'Evaluation, Recording {recording}, Loss {loss_item}')

Evaluation, Recording 0, Loss 1.66038179397583


Evaluation, Recording 1, Loss 1.399110198020935


Evaluation, Recording 2, Loss 1.6727325916290283


Evaluation, Recording 3, Loss 3.0916919708251953


Evaluation, Recording 4, Loss 0.8285379409790039


## Decoding with Clustered Neurons by Ridge Regression

In [13]:
# Perform Ridge regression on X_train_new and Y_train_new.
# Each trial's (cluster, time) block is flattened into one feature row first.
n_train_trials = X_train_new.shape[0]
X_train_new = np.reshape(X_train_new, (n_train_trials, -1))

reg = Ridge(alpha=0.5)
reg.fit(X_train_new, Y_train_new)

In [14]:
# Build test features with the thresholds derived from the training data,
# stack the sessions and flatten each trial to one row.
X_test_new = np.concatenate(
    cluster_neurons(X_test, avg_spikes_thresholds, NEURON_CLUSTERS),
    axis=0,
)
X_test_new = np.reshape(X_test_new, (X_test_new.shape[0], -1))
Y_test_new = np.concatenate(Y_test, axis=0)

In [15]:
# Evaluate Ridge regression model: mean squared error of the predictions on the
# pooled test trials.
# NOTE: this rebinds `loss`, which the reduced-rank cells used for the
# per-iteration training-loss array.
Y_pred = reg.predict(X_test_new)
loss = mean_squared_error(Y_test_new, Y_pred)
print(f'Evaluation, Ridge Regression, Loss {loss}')

Evaluation, Ridge Regression, Loss 0.6480733012442628


In [16]:
# Plot the predicted wheel speeds against the actual wheel speeds
# (first three rows of the pooled test set)
for t in range(3):
    fig = go.Figure()
    fig.add_traces([
        go.Scatter(y=Y_test_new[t], name='Actual'),
        go.Scatter(y=Y_pred[t], name='Predicted'),
    ])
    fig.update_layout(
        title=f'Actual and Predicted Wheel Speeds, Ridge Regression, Trial {t}',
        xaxis_title='Time',
        yaxis_title='Value',
    )
    fig.show()

In [17]:
# Compute MSE for different number of neuron clusters (1 through 10).
# The loop rebinds the notebook-level feature/target variables on every pass,
# so after the cell they hold the 10-cluster versions.
mse_lst = []
for n_clusters in range(1, 11):
    avg_spikes_thresholds = get_thresholds(avg_spikes, n_clusters)

    # Cluster, pool the sessions and flatten every trial to one feature row
    X_train_new = np.concatenate(cluster_neurons(X_train, avg_spikes_thresholds, n_clusters), axis=0)
    X_test_new = np.concatenate(cluster_neurons(X_test, avg_spikes_thresholds, n_clusters), axis=0)
    X_train_new = X_train_new.reshape(X_train_new.shape[0], -1)
    X_test_new = X_test_new.reshape(X_test_new.shape[0], -1)

    Y_train_new = np.concatenate(Y_train, axis=0)
    Y_test_new = np.concatenate(Y_test, axis=0)

    reg = Ridge(alpha=0.5)
    reg.fit(X_train_new, Y_train_new)

    Y_pred = reg.predict(X_test_new)
    mse_lst.append(mean_squared_error(Y_test_new, Y_pred))

In [18]:
# Plot mse against number of neuron clusters, highlighting bar height differences
bar_positions = np.arange(len(mse_lst))
fig = go.Figure(data=[go.Bar(x=bar_positions, y=mse_lst)])
fig.update_traces(
    marker_color='rgb(158,202,225)',
    marker_line_color='rgb(8,48,107)',
    marker_line_width=1.5,
    opacity=0.6,
)

# Zoom the y-axis onto the observed MSE range so small differences are visible
low, high = min(mse_lst), max(mse_lst)
fig.update_yaxes(range=[low - 0.0005, high + 0.0005])

# Bars sit at positions 0..9 but are labelled with the cluster counts 1..10
fig.update_layout(
    xaxis_ticktext=np.arange(1, 11),
    xaxis_tickvals=np.arange(10),
    title='MSE vs Number of Neuron Clusters',
    xaxis_title='Number of Neuron Clusters',
    yaxis_title='MSE',
)
fig.show()

## Decoding with ARD Prior

In [19]:
# Rebuild the NEURON_CLUSTERS-cluster features (the sweep above rebinds these
# variables): cluster, pool the sessions, flatten every trial to one row.
avg_spikes_thresholds = get_thresholds(avg_spikes, NEURON_CLUSTERS)

X_train_new = np.concatenate(
    cluster_neurons(X_train, avg_spikes_thresholds, NEURON_CLUSTERS), axis=0
)
X_test_new = np.concatenate(
    cluster_neurons(X_test, avg_spikes_thresholds, NEURON_CLUSTERS), axis=0
)
X_train_new = X_train_new.reshape(X_train_new.shape[0], -1)
X_test_new = X_test_new.reshape(X_test_new.shape[0], -1)

Y_train_new = np.concatenate(Y_train, axis=0)
Y_test_new = np.concatenate(Y_test, axis=0)

In [20]:
# Perform ARD regression with RegressorChain on X_train_new and Y_train_new.
# The hyperparameters alpha_1, alpha_2, lambda_1 and lambda_2 are set explicitly
# here; the cluster-count sweep further down uses ARDRegression() defaults.
reg = RegressorChain(ARDRegression(alpha_1=1e-4, alpha_2=1e-4, lambda_1=1, lambda_2=0.5))
reg.fit(X_train_new, Y_train_new)

In [21]:
# Evaluate ARD regression model: mean squared error of the predictions on the
# pooled test trials. Rebinds `Y_pred` and `loss` from the Ridge section.
Y_pred = reg.predict(X_test_new)
loss = mean_squared_error(Y_test_new, Y_pred)
print(f'Evaluation, ARD Regression, Loss {loss}')

Evaluation, ARD Regression, Loss 0.6480559616654439


In [22]:
# Plot the predicted wheel speeds against the actual wheel speeds
# (first three rows of the pooled test set)
for t in range(3):
    fig = go.Figure()
    fig.add_traces([
        go.Scatter(y=Y_test_new[t], name='Actual'),
        go.Scatter(y=Y_pred[t], name='Predicted'),
    ])
    fig.update_layout(
        title=f'Actual and Predicted Wheel Speeds, ARD Regression, Trial {t}',
        xaxis_title='Time',
        yaxis_title='Value',
    )
    fig.show()

In [23]:
# Compute MSE for different number of neuron clusters (1 through 10).
# NOTE(review): the sweep fits ARDRegression() with default hyperparameters,
# unlike the single ARD fit above, which sets alpha_1/alpha_2/lambda_1/lambda_2.
# Confirm this difference is intended.
mse_lst = []

for n_clusters in range(1, 11):
    avg_spikes_thresholds = get_thresholds(avg_spikes, n_clusters)

    # Cluster, pool the sessions and flatten every trial to one feature row
    X_train_new = np.concatenate(cluster_neurons(X_train, avg_spikes_thresholds, n_clusters), axis=0)
    X_test_new = np.concatenate(cluster_neurons(X_test, avg_spikes_thresholds, n_clusters), axis=0)
    X_train_new = X_train_new.reshape(X_train_new.shape[0], -1)
    X_test_new = X_test_new.reshape(X_test_new.shape[0], -1)

    Y_train_new = np.concatenate(Y_train, axis=0)
    Y_test_new = np.concatenate(Y_test, axis=0)

    reg = RegressorChain(ARDRegression())
    reg.fit(X_train_new, Y_train_new)

    Y_pred = reg.predict(X_test_new)
    mse_lst.append(mean_squared_error(Y_test_new, Y_pred))

In [24]:
# Plot mse against number of neuron clusters, highlighting bar height differences
bar_positions = np.arange(len(mse_lst))
fig = go.Figure(data=[go.Bar(x=bar_positions, y=mse_lst)])
fig.update_traces(
    marker_color='rgb(158,202,225)',
    marker_line_color='rgb(8,48,107)',
    marker_line_width=1.5,
    opacity=0.6,
)

# Zoom the y-axis onto the observed MSE range so small differences are visible
low, high = min(mse_lst), max(mse_lst)
fig.update_yaxes(range=[low - 0.0005, high + 0.0005])

# Bars sit at positions 0..9 but are labelled with the cluster counts 1..10
fig.update_layout(
    xaxis_ticktext=np.arange(1, 11),
    xaxis_tickvals=np.arange(10),
    title='MSE of ARD Regression Model',
    xaxis_title='Number of Neuron Clusters',
    yaxis_title='MSE',
)
fig.show()

# Appendix

## Decoding with Autoregressive Model of Order 1 

In [None]:
import pymc as pm

In [25]:
# Rebuild the NEURON_CLUSTERS-cluster data for the AR(1) model. Unlike in the
# regression sections, the clustered arrays are NOT flattened here.
avg_spikes_thresholds = get_thresholds(avg_spikes, NEURON_CLUSTERS)

X_train_new = np.concatenate(
    cluster_neurons(X_train, avg_spikes_thresholds, NEURON_CLUSTERS), axis=0
)
X_test_new = np.concatenate(
    cluster_neurons(X_test, avg_spikes_thresholds, NEURON_CLUSTERS), axis=0
)

Y_train_new = np.concatenate(Y_train, axis=0)
Y_test_new = np.concatenate(Y_test, axis=0)

In [26]:
# Perform autoregressive regression on X_train_new and Y_train_new with pymc ar1 model
with pm.Model() as ar1:
    # NOTE(review): 'X' is registered as data but never enters the likelihood
    # below; the AR process is fitted to 'y' alone. Confirm this is intended.
    X = pm.MutableData('X', X_train_new)
    y = pm.MutableData('y', Y_train_new)
    rho = pm.Normal('rho', 0.0, 1.0)
    sigma = pm.HalfNormal('sigma', 1.0)
    pm.AR('obs', rho=rho, sigma=sigma, init_dist=pm.Normal.dist(0, 10), observed=y)
    # Seed the sampler so the posterior draws are repeatable across runs
    trace = pm.sample(random_seed=RANDOM_SEED)

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [rho, sigma]


Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 124 seconds.


In [27]:
# Generate predictions from the posterior: swap in the test arrays, then draw
# posterior-predictive samples of 'obs'.
with ar1:
    pm.set_data({'X': X_test_new, 'y': Y_test_new})
    # Seeded so the predictive draws, and the MSE computed from them, are repeatable.
    # NOTE(review): re-running this cell on an already-extended trace may keep the
    # earlier posterior_predictive group; confirm InferenceData.extend semantics.
    trace.extend(pm.sample_posterior_predictive(trace, random_seed=RANDOM_SEED))

Sampling: [obs]


In [36]:
# Point prediction: average the posterior-predictive 'obs' samples over axis 1
# and then over axis 0 (presumably the draw and chain dimensions; TODO confirm).
y_pred = np.mean(np.mean(trace.posterior_predictive['obs'], axis=1), axis=0)

In [38]:
# Evaluate AR1 model: mean squared error between the test targets and the
# averaged posterior-predictive samples. Rebinds `loss` once more.
loss = mean_squared_error(Y_test_new, y_pred)
print(f'Evaluation, AR1, Loss {loss}')

Evaluation, AR1, Loss 1.3254480734895129


## Encoding with Poisson regression

In [None]:
def create_avg_and_max_matrices(X):
    """Stack per-session mean and max over axis 1 into two 2-D matrices.

    For every array in ``X`` the mean and the maximum are taken over axis 1,
    and the per-session results are stacked row-wise in input order.

    Args:
        X: Iterable of 3-D arrays that share the size of their last axis.

    Returns:
        Tuple ``(avg_matrix, max_matrix)`` of float64 arrays, each with one row
        per entry along axis 0 of the inputs. For empty input both have shape
        ``(0, time_bins)``, using the notebook-level ``time_bins``.
    """
    sessions = list(X)
    if not sessions:
        return np.empty((0, time_bins)), np.empty((0, time_bins))

    # Take the width from the data instead of the global `time_bins`, and stack
    # once rather than re-copying the growing matrix for every session.
    n_bins = np.shape(sessions[0])[-1]
    # The empty float64 seed keeps the result dtype float64, as before.
    seed = np.empty((0, n_bins))

    avg_matrix = np.vstack([seed] + [np.mean(session, axis=1) for session in sessions])
    max_matrix = np.vstack([seed] + [np.max(session, axis=1) for session in sessions])

    return avg_matrix, max_matrix

In [None]:
# Per-trial spike summaries (mean and max over axis 1) for both splits
train_avg_spikes, train_max_spikes = create_avg_and_max_matrices(X_train)
test_avg_spikes, test_max_spikes = create_avg_and_max_matrices(X_test)

In [None]:
def create_wheel_speed_matrix(Y):
    """Stack the per-session arrays in ``Y`` row-wise into one 2-D matrix.

    Args:
        Y: Iterable of arrays that share the size of their last axis.

    Returns:
        float64 array containing the rows of all sessions in input order. For
        empty input the shape is ``(0, time_bins)``, using the notebook-level
        ``time_bins``.
    """
    sessions = list(Y)
    if not sessions:
        return np.empty((0, time_bins))

    # Take the width from the data instead of the global `time_bins`, and stack
    # once rather than re-copying the growing matrix for every session. The
    # empty float64 seed keeps the result dtype float64, as before.
    n_bins = np.shape(sessions[0])[-1]
    return np.vstack([np.empty((0, n_bins))] + sessions)

In [None]:
# Row-wise stack of every session's targets for both splits
train_wheel_speeds = create_wheel_speed_matrix(Y_train)
test_wheel_speeds = create_wheel_speed_matrix(Y_test)

In [None]:
def train_and_predict(train_spikes, test_spikes, train_speeds, test_speeds):
    """Fit one Poisson regressor per column and predict the test spikes.

    For column ``t`` a ``linear_model.PoissonRegressor`` is fitted with
    ``train_speeds[:, t]`` as its single feature and ``train_spikes[:, t]`` as
    the target, then used to predict from ``test_speeds[:, t]``.

    Args:
        train_spikes: 2-D array of training targets, one column per time bin.
        test_spikes: 2-D array; only its shape is used, to size the output.
        train_speeds: 2-D array of training features, same column layout.
        test_speeds: 2-D array of test features, same column layout.

    Returns:
        Tuple ``(encoders, spikes_pred)``: the list of fitted regressors, one
        per column, and the predictions with the shape of ``test_spikes``.
    """
    # Use the column count of the data rather than the notebook-level
    # `time_bins`, so the function does not depend on hidden global state.
    n_bins = train_spikes.shape[1]
    encoders = [linear_model.PoissonRegressor() for _ in range(n_bins)]
    spikes_pred = np.zeros(test_spikes.shape)

    for t, encoder in enumerate(encoders):
        # The regressor expects 2-D features, so keep a trailing axis of size 1
        train_speeds_dim = np.expand_dims(train_speeds[:, t], axis=1)
        test_speeds_dim = np.expand_dims(test_speeds[:, t], axis=1)
        encoder.fit(train_speeds_dim, train_spikes[:, t])
        spikes_pred[:, t] = encoder.predict(test_speeds_dim)

    return encoders, spikes_pred

In [None]:
# Fit the per-time-bin encoders twice: once for the mean spike summaries and
# once for the max spike summaries.
avg_encoders, avg_spikes_pred = train_and_predict(train_avg_spikes, test_avg_spikes, train_wheel_speeds, test_wheel_speeds) 
max_encoders, max_spikes_pred = train_and_predict(train_max_spikes, test_max_spikes, train_wheel_speeds, test_wheel_speeds)

In [None]:
# Score the encoders on the held-out spike summaries
avg_mse = mean_squared_error(test_avg_spikes, avg_spikes_pred)
max_mse = mean_squared_error(test_max_spikes, max_spikes_pred)
avg_r2 = r2_score(test_avg_spikes, avg_spikes_pred)
max_r2 = r2_score(test_max_spikes, max_spikes_pred)

print(f'MSE for average spikes: {avg_mse}')
print(f'MSE for maximum spikes: {max_mse}')
print(f'R-squared for average spikes: {avg_r2}')
print(f'R-squared for maximum spikes: {max_r2}')