In [None]:
import numpy as np
import jax
from matplotlib import pyplot as plt
from adaptive_latents.input_sources.autoregressor import AR_K
from tqdm.notebook import tqdm
import pandas as pd
from typing import Literal

jax.config.update('jax_enable_x64', True)
import adaptive_latents as al
# Seed the generator so that every randomised step below (the random time windows
# drawn for the 'rand' group) is reproducible under Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)

In [None]:
# Load the dataset object used throughout the notebook. Later cells read
# `d.opto_stimulations` and `d.visual_stimuli` from it.
# NOTE(review): the argument `1` presumably selects a sub-dataset/session — confirm.
d = al.datasets.Naumann24uDataset(1)

In [None]:
# NOTE(review): the meaning of the argument 70 is defined by the dataset class — confirm.
a = d.get_rectangular_block(70)


def _one_hot_events(events, column, categories, data):
    """One-hot encode stimulus events onto the sample grid of `data`.

    Parameters
    ----------
    events : DataFrame with a `time` column and a `column` column, one row per event.
    column : name of the column holding each event's category.
    categories : 1-D array of the distinct category values; fixes the column order.
    data : array whose `.t` attribute holds the sample times and whose first axis
        has one entry per sample.

    Returns
    -------
    Array of shape (data.shape[0], categories.size) that is 1 at
    (sample of the event, index of the event's category) and 0 elsewhere.
    """
    encoded = np.zeros((data.shape[0], categories.size))
    for _, row in events.iterrows():
        at_event = (data.t == row.time)
        # Every event must coincide with exactly one sample time.
        assert at_event.sum() == 1, f'event at time {row.time} matches {at_event.sum()} samples'
        category_index = np.nonzero(categories == row[column])[0][0]
        encoded[at_event, category_index] = 1
    return encoded


targets = np.sort(np.unique(d.opto_stimulations.target_neuron))
opto_stims = _one_hot_events(d.opto_stimulations, 'target_neuron', targets, a)

angles = np.sort(np.unique(d.visual_stimuli.l_angle))
visual_stimuli = _one_hot_events(d.visual_stimuli, 'l_angle', angles, a)

# Column layout of `stims`: one column per opto target first, then one per visual angle.
stims = al.ArrayWithTime(np.hstack([opto_stims, visual_stimuli]), a.t)

ar = AR_K(k=20, rank_limit=None)
ar.fit(a, stims)

In [None]:
def evaluate(ar, window_time=13, n_targets=12):
    """Mean squared error of `ar`'s predictions over opto-stimulation trials.

    For each of the first `n_targets` entries of the global `targets`, and for each
    stimulation time of that target in `d.opto_stimulations`, the model is started
    from the last `ar.k` samples of `a` before the stimulation, driven by a stimulus
    matrix that is zero except for a single 1 at row `ar.k` in the target's column,
    and its prediction is compared with the observed window of `a`.

    Reads the notebook globals `a`, `d`, `targets` and `stims`.

    Parameters
    ----------
    ar : fitted model exposing `k` and `predict(starting_state, stims, n_steps=...)`.
    window_time : length of the evaluated window after each stimulation (seconds).
    n_targets : how many entries of `targets` to evaluate, starting from index 0.

    Returns
    -------
    Mean of the squared differences between observed and predicted windows,
    pooled over all trials.
    """
    n_steps = int(window_time//a.dt) + 1
    errors = []

    for targets_index in range(n_targets):
        is_target = d.opto_stimulations.target_neuron == targets[targets_index]
        for start_t in d.opto_stimulations[is_target].time:
            trial = a.slice_by_time(start_t, start_t + window_time)

            # Take a slightly longer history than needed, then keep the last k samples.
            pre_trial = a.slice_by_time(start_t - ar.k * a.dt - 1, start_t)
            starting_state = pre_trial[-ar.k:]

            # Opto columns come first in `stims`, so the target index is also its column.
            new_stims = np.zeros((n_steps+ar.k, stims.shape[1]))
            new_stims[ar.k, targets_index] = 1

            prediction = ar.predict(starting_state, new_stims, n_steps=n_steps)
            errors.append(trial - prediction)
    return ((np.array(errors)**2).mean())


In [None]:
# Sweep the autoregression order k and record the evaluation error of each fit.
ks = np.arange(1, 30)
k_results = []

for k in tqdm(ks):
    ar = AR_K(rank_limit=None, k=k)
    ar.fit(a, stims)
    k_results += [evaluate(ar)]

# Refit at the order with the lowest error; its error is the reference level
# that the later plots draw as a horizontal line.
best_k = ks[np.argmin(k_results)]
ar = AR_K(rank_limit=None, k=best_k)
ar.fit(a, stims)
full_rank_baseline = evaluate(ar)



In [None]:
fig, ax = plt.subplots()

ax.plot(ks, k_results, label='full-rank model')
# Reference level: error of the full-rank model refit at the best k.
ax.axhline(y=full_rank_baseline, color='k', label=f'best k = {best_k}')
ax.set_xlabel('number of autoregression steps')
ax.set_ylabel('mse')
ax.set_title('prediction error vs. autoregression order')
ax.legend();


In [None]:
# Sweep the rank constraint at the previously selected order `best_k`
# and record the evaluation error of each fit.
rank_limits = np.arange(1, 30)
rl_results = []

for rl in tqdm(rank_limits):
    ar = AR_K(rank_limit=rl, k=best_k)
    ar.fit(a, stims)
    rl_results += [evaluate(ar)]



In [None]:
fig, ax = plt.subplots()
ax.plot(rank_limits, rl_results, label='rank-limited model')

# Reference level: error of the unconstrained (rank_limit=None) model at the same k.
ax.axhline(y=full_rank_baseline, color='k', label=f'full rank (k = {best_k})')
ax.set_xlabel('rank constraint')
ax.set_ylabel('mse')
ax.set_title('prediction error vs. rank constraint')
ax.legend();



In [None]:
# Refit a full-rank model with a fixed order; the cells below keep using this `ar`.
# NOTE(review): k=30 is hard-coded here rather than `best_k` — confirm this is intended.
ar = AR_K(k=30, rank_limit=None)
ar.fit(a, stims)

window_time = 13 # seconds
n_steps = int(window_time//a.dt) + 1
targets_index = 7
trial_times = d.opto_stimulations[d.opto_stimulations.target_neuron == targets[targets_index]].time

# One row per trial. Sizing the grid from the data (instead of a fixed 5 rows) avoids an
# IndexError when the target has more trials, and squeeze=False keeps `axs` two-dimensional
# even when there is a single trial.
n_rows = max(len(trial_times), 1)
fig, axs = plt.subplots(nrows=n_rows, ncols=2, figsize=(10, 2*n_rows), tight_layout=True, sharey=True, squeeze=False)
axs[0, 0].set_title('observed')
axs[0, 1].set_title('predicted')
errors = []
truths = []

for idx, start_t in enumerate(trial_times):
    trial = a.slice_by_time(start_t, start_t + window_time)

    # Start the model from the last k samples before the stimulation and drive it with a
    # stimulus matrix containing only this target's stimulation, at row k.
    pre_trial = a.slice_by_time(start_t - ar.k * a.dt - 1, start_t)
    new_stims = np.zeros((n_steps+ar.k, stims.shape[1]))
    new_stims[ar.k, targets_index] = 1
    starting_state = pre_trial[-ar.k:]
    prediction = ar.predict(starting_state, new_stims, n_steps=n_steps)

    axs[idx,0].plot(trial)
    axs[idx,1].plot(prediction)

    truths.append(trial)
    errors.append(prediction - trial)

In [None]:
# Length of the response window read by get_responses; other cells in this notebook
# annotate the same value as seconds.
window_time = 13

# Allowed values for the `group` arguments below: optogenetic stimulation trials,
# visual stimulus trials, or randomly drawn times.
_groups = Literal['opto', 'vis', 'rand']
def get_responses(n=None, group:_groups='opto', neuron=0, estimated=False):
    """Collect one neuron's response windows for a set of event times.

    The event times depend on `group`:
    'opto' uses the stimulation times of `targets[n]`, 'vis' uses the presentation
    times of `angles[n]`, and 'rand' draws 5 uniformly random times from `rng`
    (`n` is not used in that case).

    With `estimated=False` the windows are cut from the recording `a`; with
    `estimated=True` they are predicted by the global model `ar`, started from the
    `ar.k` samples before each event and driven by the recorded `stims`.

    Returns an array with one column per event time.
    Raises ValueError for an unrecognised `group`.
    """
    if group == 'opto':
        is_selected = d.opto_stimulations.target_neuron == targets[n]
        times = d.opto_stimulations[is_selected].time
    elif group == 'vis':
        is_selected = d.visual_stimuli.l_angle == angles[n]
        times = d.visual_stimuli[is_selected].time
    elif group == 'rand':
        # Leave room for k samples of history before, and a full window after, each time.
        times = rng.uniform(low=a.t[0] + ar.k * a.dt, high=a.t[-1] - window_time, size=5)
    else:
        raise ValueError(f'Unknown group {group}')

    n_steps = np.floor(window_time/a.dt).astype(int)

    if not estimated:
        observed_columns = [a.slice_by_time(t, t + window_time)[:n_steps,neuron] for t in times]
        return np.column_stack(observed_columns)

    predicted_columns = []
    for t in times:
        try:
            idx = int(np.nonzero(a.t == t)[0][0])
        except IndexError:
            # No sample lies exactly at t (e.g. random times): use the nearest one.
            idx = np.argmin(np.abs(a.t - t))
        starting_state = a.slice(idx-ar.k, idx)[-ar.k:]
        new_stims = stims.slice(idx-ar.k, int(idx + window_time//a.dt))
        prediction = ar.predict(starting_state, new_stims, n_steps=n_steps)
        predicted_columns.append(prediction[:,neuron])
    return np.column_stack(predicted_columns)

def modulation_statistic(x, min_split=6, max_split=20):
    """Largest after-minus-before difference of means over a range of split points.

    For every split index s in [min_split, max_split) that leaves at least one row
    after the split, computes `x[s:].mean() - x[:s].mean()` and returns the maximum.

    Parameters
    ----------
    x : array indexed by time along its first axis (1-D, or 2-D with extra columns
        that are averaged together).
    min_split, max_split : half-open range of split indices to try.

    Returns
    -------
    The maximum difference, or NaN when `x` is too short to allow any split.
    """
    # Clamp to the input length so no split produces an empty "after" segment;
    # an empty mean would be NaN and emit a RuntimeWarning.
    last_split = min(max_split, len(x))
    if last_split <= min_split:
        return float('nan')
    return max(x[split:].mean() - x[:split].mean() for split in range(min_split, last_split))


def get_modulation(group:_groups='opto', neuron=0, estimated=False, n_null=200):
    """Rank a neuron's stimulus modulation against a null built from random times.

    The statistic is the maximum of `modulation_statistic` over all conditions of
    `group` (every entry of `targets` for 'opto', of `angles` for 'vis'). The null
    distribution repeats the same computation `n_null` times with the 'rand' group,
    using the same number of conditions.

    Parameters
    ----------
    group : 'opto' or 'vis'.
    neuron : column index of the neuron in `a`.
    estimated : forwarded to get_responses (model predictions vs. recorded data).
    n_null : number of null samples to draw.

    Returns
    -------
    Fraction of null samples strictly below the observed statistic.

    Raises
    ------
    ValueError if `group` has no associated condition list (e.g. 'rand').
    """
    conditions_by_group = {'opto':targets, 'vis':angles}
    if group not in conditions_by_group:
        raise ValueError(f'Unknown group {group}')
    n_conditions = len(conditions_by_group[group])

    def best_statistic(g):
        return max(
            modulation_statistic(get_responses(n=i, group=g, neuron=neuron, estimated=estimated))
            for i in range(n_conditions)
        )

    stat = best_statistic(group)
    null_samples = np.array([best_statistic('rand') for _ in range(n_null)])
    return (null_samples < stat).mean()

# Spot-check on a single neuron (column 31), using model-predicted responses to opto stimuli.
get_modulation(group='opto', neuron=31, estimated=True)

In [None]:
# One row per neuron; columns follow the nesting order below:
# (opto, estimated), (opto, real), (vis, estimated), (vis, real).
modulations = np.array([
    [
        get_modulation(group=group, neuron=neuron, estimated=estimated)
        for group in ['opto', 'vis']
        for estimated in [True, False]
    ]
    for neuron in tqdm(range(a.shape[1]))
])


In [None]:
# Name the four columns of `modulations` (one row per neuron).
_modulation_columns = ('opto_est', 'opto_real', 'vis_est', 'vis_real')
df = pd.DataFrame({
    column_name: modulations[:, column_index]
    for column_index, column_name in enumerate(_modulation_columns)
})

In [None]:
%matplotlib inline

plt.hist(df.opto_est, bins=20);

In [None]:
%matplotlib inline
fig, axs = plt.subplots(ncols=2, figsize=(10,5))
axs[0].scatter(df.opto_real, df.vis_real)
axs[0].set_xlabel('modulation to opto stimuli')
axs[0].set_ylabel('modulation to visual stimuli')
axs[0].set_title("'real' modulations")

axs[1].scatter(df.opto_est, df.vis_est)
axs[1].set_xlabel('modulation to opto stimuli')
axs[1].set_title('modulations from simulations')



In [None]:
%matplotlib inline
fig, ax = plt.subplots(figsize=(5,5))

ax.scatter(df.opto_real, df.vis_real)
ax.scatter(df.opto_est, df.vis_est)
ax.plot(np.vstack([df.opto_real,df.opto_est]), np.vstack([df.vis_real,df.vis_est]), color='k', alpha=.1);


In [None]:
%matplotlib inline

fig, ax = plt.subplots(figsize=(15,3))
ax.plot(a.t, a[:,np.argsort(df.opto_real)[:5]])
for t in d.opto_stimulations.time:
    ax.axvline(t, c='k')

for t in d.visual_stimuli.time:
    ax.axvline(t, c='gray')

ax.set_xlim([890, 1510])



In [None]:

def get_errors(group:_groups='opto'):
    """Log of the mean normalised prediction error of each neuron.

    For every neuron (column of `a`) and every condition of `group`, compares the
    model-predicted responses with the recorded ones (both from get_responses) and
    divides their mean squared difference by the variance of the recorded responses,
    floored at 0.01. The normalised errors are averaged over conditions and logged.

    For 'rand' the number of conditions is 3 * max(len(targets), len(angles)).

    Returns a 1-D array with one value per neuron.
    Raises ValueError for an unrecognised `group`.
    """
    conditions_by_group = {'opto':targets, 'vis':angles, 'rand': range(3*max(len(targets), len(angles)))}
    if group not in conditions_by_group:
        raise ValueError(f'Unknown group {group}')
    n_conditions = len(conditions_by_group[group])

    normalised_errors = []
    for neuron in range(a.shape[1]):
        per_condition = []
        for n in range(n_conditions):
            # get_responses draws fresh random times on every 'rand' call. Rewind the
            # generator between the two calls so the prediction and the observation
            # refer to the same time windows; for 'opto'/'vis' this is a no-op.
            rng_state = rng.bit_generator.state
            estimate = get_responses(n, group=group, neuron=neuron, estimated=True)
            rng.bit_generator.state = rng_state
            observed = get_responses(n, group=group, neuron=neuron, estimated=False)

            # Floor the variance so nearly constant responses do not blow up the ratio.
            denominator = max(((observed.mean() - observed)**2).mean(), 0.01)
            per_condition.append(((estimate - observed)**2).mean()/denominator)
        normalised_errors.append(per_condition)

    normalised_errors = np.array(normalised_errors)
    return np.log(normalised_errors.mean(axis=1))

# Add one error column per group, in the order opto, vis, rand.
for _error_column, _error_group in (('opto_errors', 'opto'), ('vis_errors', 'vis'), ('rand_errors', 'rand')):
    df[_error_column] = get_errors(group=_error_group)


In [None]:
# Per-neuron prediction error: random windows (x) against opto stimulation trials (y).
fig, ax = plt.subplots()
ax.scatter(df.rand_errors, df.opto_errors)
ax.set_xlabel('error over random times')
ax.set_ylabel('error for stim trials')


In [None]:
# The y axis is labelled as the prediction error for random times, so plot
# `rand_errors` there. The original plotted `vis_real`, which did not match the
# label and repeated the opto/vis scatter already shown above.
plt.scatter(df.opto_real, df.rand_errors)
plt.xlabel('modulation for optogenetic stimulations')
plt.ylabel('prediction error for random times')
