In [None]:
from support import CachedStanModel, load_rates, load_stan_data, get_metric, forest_plot
import pandas as pd
import numpy as np
import pylab as pl
import arviz as az

In [None]:
# Standard-library module reload helper (the original cell misspelled the
# module name, which made this cell fail on a fresh kernel).
import importlib

In [None]:
# Load the rate tables and the Stan input for the 'ftc' dataset.
# NOTE(review): `cells` is rebound a few cells later when the pickled fit is
# loaded, so the value returned here is only used until then.
rates = load_rates()
cells, data = load_stan_data(
    'ftc',
    significant_only=True,
    exclude_silent=True,
)

In [None]:
import pickle

# Three objects are read back-to-back from the same file handle, in this
# order: cells, model, fit.
# NOTE(review): presumably the model has to be unpickled before the fit
# (PyStan 2 convention) -- keep the order as is.
# NOTE(review): pickle.load executes arbitrary code; only open trusted files.
with open('fits/marmot-ftc_sr_multiplicative_exclude_silent_significant_only.pkl', 'rb') as fh:
    cells, model, fit = [pickle.load(fh) for _ in range(3)]

In [None]:
# Trace plots for the population-level bandwidth parameters.
az.plot_trace(fit, var_names=['bw_mean', 'bw_sd'])

In [None]:
# Trace plots for the four population-level `*_mean` parameters.
az.plot_trace(
    fit,
    var_names=['bf_mean', 'bw_mean', 'gain_mean', 'sr_mean'],
)

In [None]:
# Trace plots for the population-level delta / ratio parameters.
az.plot_trace(
    fit,
    var_names=['bf_delta_mean', 'bw_ratio_mean', 'gain_ratio_mean', 'sr_delta_mean'],
)

In [None]:
def plot_fit(ax, c, i, cells, ftc):
    """Draw the fitted Gaussian tuning curves and observed rates for one cell.

    Two curves of the form ``sr * gain * exp(-0.5 * ((f - bf) / bw) ** 2)``
    are drawn on a fixed 100-point grid from 2.5 to 17.5: one from the
    ``*_cell[...]`` parameters (red line) and one from the
    ``*_cell_pupil[...]`` parameters (green line).  Observed rates
    (``count / time``) are then overlaid as red markers for ``ftc.loc[cell, 0]``
    and green markers for ``ftc.loc[cell, 1]``.  The axes title is the
    ``bw_cell_pupil`` value formatted to two decimals.

    Parameters
    ----------
    ax
        Object providing ``plot(x, y, fmt)`` and ``set_title(text)``
        (a matplotlib Axes in this notebook).
    c
        Mapping from parameter name (e.g. ``'sr_cell[1]'``) to a scalar value.
    i
        Zero-based position of the cell in ``cells``; parameter names use
        ``i + 1``.
    cells
        Sequence of cell labels used as the first key into ``ftc``.
    ftc
        Table supporting ``.loc[cell, k]`` for ``k`` in ``{0, 1}``, yielding
        rows with ``count`` and ``time`` columns and a ``frequency`` index.
        NOTE(review): the second key presumably encodes the pupil condition
        -- confirm against ``load_rates``.

    Returns
    -------
    None
    """
    cell = cells[i]
    stan_index = i + 1  # parameter names are numbered from 1
    freq_grid = np.linspace(2.5, 17.5, 100)

    # Fitted curves first (red, then green) so the draw order is unchanged.
    fitted = (
        (('sr_cell', 'bf_cell', 'bw_cell', 'gain_cell'), 'r-'),
        (('sr_cell_pupil', 'bf_cell_pupil', 'bw_cell_pupil', 'gain_cell_pupil'), 'g-'),
    )
    for names, line_style in fitted:
        sr, bf, bw, gain = (c[f'{name}[{stan_index}]'] for name in names)
        curve = sr * gain * np.exp(-0.5 * np.square((freq_grid - bf) / bw))
        ax.plot(freq_grid, curve, line_style)

    # Observed rates for the two values of the second index key.
    for level, marker_style in ((0, 'ro'), (1, 'go')):
        observed = ftc.loc[cell, level].eval('count/time')
        observed = observed.rename('rate').reset_index()
        ax.plot(observed['frequency'], observed['rate'], marker_style)

    # `bw` still holds the value from the last (``_pupil``) parameter set.
    ax.set_title(f'{bw:0.2f}')
    
# Column-wise mean of the fit's sample table; plot_fit looks parameters up
# in it by name (e.g. 'sr_cell[1]').
c = fit.to_dataframe(diagnostics=False).mean()

# 5 x 5 grid: one panel per cell, in the order given by `cells`.
f, axes = pl.subplots(5, 5, figsize=(10, 10))
panels = axes.ravel()

# Never index past the end of `cells`: with fewer than 25 cells the original
# loop raised IndexError on the first surplus panel.
n_shown = min(len(cells), len(panels))
for i, ax in enumerate(panels[:n_shown]):
    plot_fit(ax, c, i, cells, rates['ftc'])

# Blank out panels that have no cell to show.
for ax in panels[n_shown:]:
    ax.set_visible(False)

f.tight_layout()

In [None]:
# Convert the fit to an ArviZ object and report two sampler diagnostics:
# the largest recorded tree depth and whether any draw was flagged as diverging.
inference = az.from_pystan(fit=fit)
sampler_stats = inference.sample_stats
print('Max treedepth', sampler_stats['treedepth'].max())
print('Diverging', sampler_stats['diverging'].any())

In [None]:
import support
import importlib
# Re-import `support` so edits to the module are picked up without a restart.
importlib.reload(support)

# Credible-interval width in percent; az.summary receives it as a fraction.
ci = 90
summary = az.summary(fit, credible_interval=ci / 100)

f, axes = pl.subplots(2, 2, figsize=(8, 8))

# One row per panel: (axes, per-cell metric, population metric, label,
# extra positional args for forest_plot).
# NOTE(review): the trailing 1 passed for the ratio metrics is presumably the
# no-effect reference value -- confirm against support.forest_plot.
panel_specs = (
    (axes[0, 0], 'sr_delta_cell', 'sr_delta_mean', 'sr', ()),
    (axes[0, 1], 'bf_delta_cell', 'bf_delta_mean', 'bf', ()),
    (axes[1, 0], 'gain_ratio_cell', 'gain_ratio_mean', 'gain', (1,)),
    (axes[1, 1], 'bw_ratio_cell', 'bw_ratio_mean', 'bw', (1,)),
)
for panel, cell_key, pop_key, label, extra_args in panel_specs:
    cell_metric = support.get_metric(summary, cell_key)
    pop_metric = support.get_metric(summary, pop_key)
    support.forest_plot(panel, cell_metric, pop_metric, label, ci, *extra_args)

f.tight_layout()

# Save the same figure in each output format.
for out_path in (
    'reports/tuning/summary.eps',
    'reports/tuning/summary.pdf',
    'reports/tuning/summary.png',
):
    f.savefig(out_path)