# Prepare

In [1]:
#!pip install -e ..
from altair.vegalite.v4.api import FacetChart, Chart, LayerChart
from sklearn.preprocessing import LabelEncoder
import bulwark.checks as ck
from bayes_window.generative_models import generate_fake_spikes
from bayes_window.visualization import fake_spikes_explore, plot_data_and_posterior
from bayes_window import models
from bayes_window.fitting import fit_numpyro
from bayes_window.utils import add_data_to_posterior

# Shorthand for integer-encoding a categorical column (used below on neuron and mouse ids).
# Note: every call refits the same shared encoder on whatever column it is given.
_label_encoder = LabelEncoder()
trans = _label_encoder.fit_transform



# Generate synthetic spike data


In [2]:
# Ask the generative model for a small synthetic dataset
# (n_mice=4, n_neurons=8, n_trials=2, dur=7; units of `dur` per generate_fake_spikes).
df, df_monster, index_cols, firing_rates = generate_fake_spikes(
    n_trials=2,
    n_neurons=8,
    n_mice=4,
    dur=7,
)

In [3]:
import numpy as np

# Log-transform the inter-spike intervals; `log_isi` is the response fitted later.
# np.log10 returns -inf for 0 and NaN for negative inputs without raising, and such
# values would silently corrupt the model fit, so fail loudly here instead.
# (Missing/NaN ISIs compare False and are deliberately passed through unchanged.)
assert not (df['isi'] <= 0).any(), 'non-positive ISI found; log10 would produce -inf/NaN'
df['log_isi'] = np.log10(df['isi'])

In [14]:
import altair as alt
from importlib import reload

from bayes_window import utils, visualization

# Pick up edits to the project modules without restarting the kernel.
reload(visualization)
reload(utils)

# Response column analysed by every following cell.
y = 'log_isi'

# Cast neuron ids to int in place; later cells (encoding, merging) read this column.
df['neuron'] = df['neuron'].astype(int)

# Fold change of `y` across the `stim` condition, computed per (stim, mouse_code, neuron)
# with do_take_mean=True. Returns a new frame and (presumably) the name of the derived column.
ddf, dy = utils.make_fold_change(
    df,
    y=y,
    index_cols=('stim', 'mouse_code', 'neuron'),
    condition_name='stim',
    do_take_mean=True,
)

visualization.plot_data(ddf, x='neuron', y=dy, color='mouse_code', add_box=True)

# Estimate model

In [15]:
# Fit the hierarchical normal stim model: `y` is the response, `stim` the binary
# condition, and neuron / mouse ids are integer-encoded with `trans`.
# NOTE(review): `subject` is built from df['mouse'] while the other cells group by
# 'mouse_code' -- confirm both columns encode the same mouse grouping.
trace = fit_numpyro(
    y=df[y].values,
    stim_on=df['stim'].astype(int).values,
    treat=trans(df['neuron']),
    subject=trans(df['mouse']),
    progress_bar=True,
    model=models.model_hier_normal_stim,
    n_draws=100,
    num_chains=1,
)

sample: 100%|██████████| 1100/1100 [00:06<00:00, 173.39it/s, 31 steps of size 1.87e-01. acc. prob=0.90]


n(Divergences) = 0


# Combine observed data with the posterior

In [6]:
reload(utils)
# Combine the observed data with the posterior draws of `b_stim_per_condition`
# (grouped by neuron) so both can be plotted from one frame.
df_both = utils.add_data_to_posterior(
    df,
    trace=trace,
    y=y,
    index_cols=['neuron', 'stim', 'mouse_code'],
    condition_name='stim',
    b_name='b_stim_per_condition',  # posterior variable to read
    group_name='neuron',  # posterior grouping dimension
)

# Plot data and posterior

In [7]:
reload(visualization)
# Posterior of the `<y> diff` column per neuron, coloured by mouse.
visualization.plot_posterior(
    df_both,
    alt.Chart(df_both),
    y=f'{y} diff',
    x='neuron',
    color='mouse_code',
)

In [8]:
# Overlay the posterior plot and the raw-data plot (with boxes); assuming both helpers
# return Altair charts, `+` layers them into a single chart.
(
    visualization.plot_posterior(df_both, alt.Chart(df_both), y=f'{y} diff', x='neuron', color='mouse_code')
    + visualization.plot_data(df_both, alt.Chart(df_both), x='neuron', y=f'{y} diff', color='mouse_code', add_box=True)
)

In [9]:
reload(visualization)
# Data and posterior in one call via the combined helper (presumably equivalent to the
# manual overlay above), titled with the response variable name.
visualization.plot_data_and_posterior(
    df_both,
    y=f'{y} diff',
    x='neuron',
    color='mouse_code',
    title=y,
    hold_for_facet=False,
    add_box=True,
)