In [None]:
# Imports and global configuration for the AR_K stimulus-response analysis.
import numpy as np
import jax
from matplotlib import pyplot as plt
from adaptive_latents.input_sources.autoregressor import AR_K

# Enable 64-bit (float64/int64) types in JAX, which otherwise defaults to 32-bit.
# NOTE(review): AR_K is imported above *before* this flag is set, while `adaptive_latents`
# is imported after it — if the package creates JAX arrays at import time this order may matter; confirm.
jax.config.update('jax_enable_x64', True)
import adaptive_latents as al
# NOTE(review): unseeded generator, and not used in the cells shown here; pass a fixed seed
# (e.g. np.random.default_rng(SEED)) if any stochastic step comes to depend on it.
rng = np.random.default_rng()

In [None]:
# Load the Naumann24u dataset; `d` is used by the cells below (opto_stimulations, visual_stimuli, blocks).
# NOTE(review): the meaning of the positional argument `1` (session/animal index?) is not visible here — confirm.
d = al.datasets.Naumann24uDataset(1)

In [None]:
def one_hot_at_times(sample_times, event_times, event_labels):
    """Mark discrete, labelled events on a sampling grid.

    Parameters
    ----------
    sample_times : array-like, shape (n_samples,)
        Time stamp of every sample (row) of the output matrix.
    event_times : array-like, shape (n_events,)
        Time of each event; each must equal exactly one entry of `sample_times`.
    event_labels : array-like, shape (n_events,)
        Label of each event (e.g. a target neuron id or a stimulus angle).

    Returns
    -------
    labels : ndarray
        Sorted unique event labels; ``labels[j]`` names column ``j`` of `matrix`.
    matrix : ndarray, shape (n_samples, labels.size)
        ``matrix[i, j] == 1`` if an event labelled ``labels[j]`` occurred at
        ``sample_times[i]``, otherwise 0.

    Raises
    ------
    AssertionError
        If an event time matches zero or several sample times (exact equality is used).
    """
    sample_times = np.asarray(sample_times)
    event_labels = np.asarray(event_labels)
    labels = np.unique(event_labels)  # np.unique already returns sorted values
    matrix = np.zeros((len(sample_times), labels.size))
    for event_time, label in zip(np.asarray(event_times), event_labels):
        matches = sample_times == event_time
        n_matches = matches.sum()
        assert n_matches == 1, f"event at t={event_time} matches {n_matches} samples (expected exactly 1)"
        # labels is sorted and unique, so searchsorted returns the exact column of `label`
        matrix[matches, np.searchsorted(labels, label)] = 1
    return labels, matrix


# NOTE(review): the meaning of the argument 60 to get_rectangular_block is not visible here — confirm.
a = d.get_rectangular_block(60)

# Stimulus design matrix: one column per opto-stimulated target neuron, followed by
# one column per visual-stimulus angle (later cells index stims by this column order).
targets, opto_stims = one_hot_at_times(a.t, d.opto_stimulations.time, d.opto_stimulations.target_neuron)
angles, visual_stimuli = one_hot_at_times(a.t, d.visual_stimuli.time, d.visual_stimuli.l_angle)
stims = np.hstack([opto_stims, visual_stimuli])

# k is the number of history rows the model uses (later cells slice a[-ar.k:]).
# NOTE(review): rank_limit=None presumably means no low-rank constraint on the coefficients — confirm.
ar = AR_K(k=7, rank_limit=None)
ar.fit(a, stims)


In [None]:
# Impulse-response simulation: roll the fitted AR model forward from a constant starting
# state while a single stimulus column is switched on for one row of the input.
n_steps = 50
IMPULSE_ROW = 1      # row of new_stims that carries the impulse
IMPULSE_COLUMN = 10  # column of stims: opto targets come first, then visual angles

# Describe which stimulus the chosen column corresponds to (column order set by np.hstack above).
if IMPULSE_COLUMN < targets.size:
    stim_label = f"opto target {targets[IMPULSE_COLUMN]}"
else:
    stim_label = f"visual angle {angles[IMPULSE_COLUMN - targets.size]}"

# NOTE(review): new_stims has n_steps + k rows and the impulse is at row 1; confirm whether
# ar.predict aligns the first k rows with starting_state (i.e. whether the impulse lands in the history).
new_stims = np.zeros((n_steps + ar.k, stims.shape[1]))
new_stims[IMPULSE_ROW, IMPULSE_COLUMN] = 1

# Every one of the last-k rows is replaced by ar.v; `a[-ar.k:] * 0` only supplies the shape/container.
# NOTE(review): presumably ar.v is the fitted offset, giving a steady-state start — confirm.
starting_state = a[-ar.k:] * 0 + ar.v
out = ar.predict(starting_state, new_stims, n_steps=n_steps)

fig, ax = plt.subplots()
ax.plot(out)
ax.set(xlabel="step", ylabel="predicted value", title=f"AR impulse response: {stim_label}");
# ax.set_ylim(-5, 5)  # uncomment to zoom in on the response


In [None]:
# Visualize the fitted AR coefficient matrices ar.As stacked vertically into a single image.
# NOTE(review): presumably ar.As holds one matrix per lag (k of them) — confirm the block order.
coef_image = plt.matshow(np.vstack(ar.As))
plt.colorbar(coef_image)
plt.title("AR coefficients (np.vstack(ar.As))");

In [None]:
# Stimulus-triggered traces for every opto stimulation of the first target neuron.
# Each row is one trial over [start_t, start_t + TRIAL_WINDOW], with the trial's first
# sample subtracted so that all traces start at zero.
TRIAL_WINDOW = 20  # same units as a.t / opto_stimulations.time — TODO confirm (seconds?)

first_target_times = d.opto_stimulations[d.opto_stimulations.target_neuron == targets[0]].time
n_trials = len(first_target_times)
assert n_trials > 0, f"no opto stimulations found for target {targets[0]}"

# One row per trial (a hard-coded 5 raised IndexError for more trials); squeeze=False keeps
# axs 2-D even for a single trial, and sharex is valid because time is relative to onset.
fig, axs = plt.subplots(
    nrows=n_trials,
    figsize=(5, max(5, n_trials)),
    tight_layout=True,
    squeeze=False,
    sharex=True,
)

for idx, start_t in enumerate(first_target_times):
    trial = a.slice_by_time(start_t, start_t + TRIAL_WINDOW)
    axs[idx, 0].plot(trial.t - start_t, trial.as_array() - trial[0].as_array())

axs[-1, 0].set_xlabel("time since stimulation onset")
fig.suptitle(f"opto target {targets[0]}: baseline-subtracted trials");