In [None]:
import numpy as np
import adaptive_latents as al
import matplotlib.pyplot as plt
from sklearn.decomposition import NMF
from adaptive_latents.prediction_regression_run import pred_reg_run, defaults_per_dataset
import copy

# Seed the generator so any draws from `rng` are identical across kernel restarts
# (Restart & Run All must reproduce every output).
SEED = 0
rng = np.random.default_rng(SEED)

In [None]:
# Load the Odoherty21 dataset. `drop_third_coord=True` is passed through unchanged;
# NOTE(review): bin_width is presumably in seconds (the stimulus times later in the
# notebook look like seconds) — confirm against Odoherty21Dataset.
d = al.datasets.Odoherty21Dataset(
    drop_third_coord=True,
    bin_width=0.06,
)

In [None]:
%matplotlib inline
neural_data = d.neural_data.copy()


u,s,vh = np.linalg.svd(neural_data - neural_data.mean(axis=0), full_matrices=False)

nmf = NMF(n_components=2)
nmf.fit(neural_data)

def make_pc_ok(pc):
    return pc / np.abs(pc).max() * 5


response_time = np.arange(6)
response_decay = np.exp(-response_time/1.5) 

response_directions = {
    '99th quantile': np.quantile(neural_data, q=.99, axis=0) * .5,
    'mean': np.mean(neural_data, axis=0) * 5,
    'std': np.std(neural_data, axis=0) * 2,
    'pc1': make_pc_ok(vh[0]),
    'pc2': make_pc_ok(vh[1]),
    'pc15': make_pc_ok(vh[15]),
    'pc-1': make_pc_ok(vh[-1]),
    'nmf1': make_pc_ok(nmf.components_[0]),
    'nmf2': make_pc_ok(nmf.components_[1]),
}

stim_times = []
stim_batch_times = [80, 120, 160]
stim_batch_magnitudes = [2, 1, -1]
stim_batch_samples = [neural_data.time_to_sample(t) for t in stim_batch_times]

for batch_start, magnitude in zip(stim_batch_samples, stim_batch_magnitudes):
    stim_times.append([])
    for i, (k, response_direction) in enumerate(response_directions.items()):
        start = batch_start + i * len(response_time) * 2
        response = response_decay[:,None] @ response_direction[None,:]

        neural_data[start + response_time ,:] += response * magnitude
        stim_times[-1].append(neural_data.t[start])


fig, axs = plt.subplots()
stim_batch = 0
cutout = neural_data.slice_by_time(stim_batch_times[stim_batch]-1, stim_batch_times[stim_batch]+6)
axs.matshow(cutout, extent=[0,cutout.shape[1],cutout.t[-1],cutout.t[0]], vmin=0, vmax=d.neural_data.max()*1.4, aspect='auto', origin='upper')
axs.set_xlabel("neuron #")
axs.set_ylabel("time")
axs.set_title("stimuli visualized in firing rate matrix")
axs.set_yticks(stim_times[stim_batch])
axs.set_yticklabels(list(response_directions.keys()));


In [None]:
# One row per injected response direction, one column per neuron.
rarr = np.vstack(list(response_directions.values()))
fig, ax = plt.subplots()
im = ax.matshow(rarr, aspect='auto')
ax.set_yticks(np.arange(len(response_directions)))
ax.set_yticklabels(list(response_directions.keys()))
ax.set_xlabel("neuron #")
ax.set_title("injected response directions")
fig.colorbar(im, ax=ax);

In [None]:
def _run_sjpca(input_neural_data):
    """Run `pred_reg_run` with the sjPCA reduction and the odoherty21 defaults.

    Only the first argument varies between calls. The behavioral input is a
    zero-column slice of `d.behavioral_data`, and `d.neural_data` is always the
    third argument.
    """
    return pred_reg_run(
        input_neural_data,
        d.behavioral_data[:, :0],
        d.neural_data,
        dim_red_method='sjpca',
        **defaults_per_dataset['odoherty21'],
    )


run = _run_sjpca(neural_data)     # input contains the injected stimuli
run2 = _run_sjpca(d.neural_data)  # same call on the unstimulated recording


In [None]:
%matplotlib qt
fig, axs = plt.subplots(nrows=3, layout='tight')

latents = run.dim_reduced_data 
latents: al.ArrayWithTime

flat_stim_times = [leaf for tree in stim_times for leaf in tree]

for i, ax in enumerate(axs):
    ax.plot(latents.t, latents[:,:])
    ax.set_xticks(flat_stim_times)
    for t in flat_stim_times:
        ax.axvline(t, color='k', linestyle='--', alpha=.5)
    ax.set_xticklabels(list(response_directions.keys()) *3 )
    ax.set_xlim(np.array([-1, 7]) + stim_batch_times[i])
    ax.set_title(f"magnitude={stim_batch_magnitudes[i]}")
e = {k:al.utils.column_space_distance(v[:,None]/np.linalg.norm(v), run.pipeline.steps[-3].Q) * 180/np.pi for k,v in response_directions.items()}


In [None]:
# For each stimulus batch, estimate the latent-space direction evoked by one probe
# stimulus by least squares:
#     latents[onset + t] - latents[onset - 1]  ≈  response_decay[t] * w
probe_name = 'mean'
probe_index = list(response_directions.keys()).index(probe_name)

# The design matrix has one row per injected response sample. The data window must have
# exactly that many rows; a longer window or a transposed response makes lstsq's row
# counts disagree.
window_width = len(response_time)

ws = []
for onsets_in_batch in stim_times:
    onset = latents.time_to_sample(onsets_in_batch[probe_index])
    # Baseline: the latent state one sample before stimulus onset.
    derived_response = latents[onset:onset + window_width] - latents[onset - 1]
    # a: (window_width, 1); b: (window_width, n_latents)  ->  w: (1, n_latents)
    w, _, _, _ = np.linalg.lstsq(response_decay[:, None], np.asarray(derived_response), rcond=None)
    ws.append(w)

fig, ax = plt.subplots()
for w, magnitude in zip(ws, stim_batch_magnitudes):
    ax.plot(w[0], marker='o', label=f"magnitude={magnitude}")
ax.set_xlabel("latent dimension")
ax.set_ylabel("fitted response weight")
ax.set_title(f"latent response to the '{probe_name}' stimulus")
ax.legend();



In [None]:
# `e` maps each response-direction name to its angle, in degrees, from the pipeline's
# basis. Pass plain lists to `bar` rather than dict views.
direction_names = list(e.keys())
angles_deg = np.array(list(e.values()))

fig, ax = plt.subplots()
ax.bar(direction_names, angles_deg)
ax.set_xlabel('response direction')
ax.set_ylabel('angle from proSVD space (degrees)')
ax.set_title('angle between each stimulus direction and the proSVD subspace')
# Nine category labels overlap when drawn horizontally.
ax.tick_params(axis='x', labelrotation=45);
