In [None]:
import arviz as az
import numpy as np
import seaborn as sns
from scipy import stats
import pandas as pd
import pickle
import pylab as pl
%matplotlib inline

from support import load_rates, get_metric, forest_plot

In [None]:
# Pickled fit produced elsewhere; the file holds three objects written in the
# order (cells, model, fit), so they must be read back in that same order.
# NOTE: pickle.load can execute arbitrary code -- only open fits you trust.
filename = 'fits/marmot-rl_sr.pkl'

with open(filename, 'rb') as fh:
    cells, model, fit = [pickle.load(fh) for _ in range(3)]

In [None]:
# Trace plots for the listed parameters of the fit.
trace_vars = [
    'sr_mean',
    'sr_delta_mean',
    'slope_mean',
    'slope_delta_mean',
    'threshold_mean',
    'threshold_delta_mean',
    'threshold_delta_sd',
]
az.plot_trace(fit, trace_vars);

In [None]:
# Sampler diagnostics: largest value of the 'treedepth' statistic and whether
# any entry of the 'diverging' statistic is set.
inference = az.from_pystan(fit=fit)
sample_stats = inference.sample_stats
print('Max treedepth', sample_stats['treedepth'].max())
print('Diverging', sample_stats['diverging'].any())

In [None]:
# Width of the credible interval in percent; az.summary takes it as a fraction.
ci = 90
summary = az.summary(fit, credible_interval=ci/100)

f, axes = pl.subplots(1, 3, figsize=(12, 4))

# One panel per coefficient: the per-cell '<name>_delta_cell' entries are drawn
# together with the population-level '<name>_delta_mean' entry.
for panel, name in zip(axes, ('sr', 'slope', 'threshold')):
    cell_metric = get_metric(summary, f'{name}_delta_cell')
    pop_metric = get_metric(summary, f'{name}_delta_mean')
    forest_plot(panel, cell_metric, pop_metric, name, ci)

# Same figure written in three formats.
for ext in ('eps', 'pdf', 'png'):
    f.savefig(f'reports/rate_level/summary.{ext}')

In [None]:
# Export the posterior summary to CSV: one table of population-level
# coefficients and two tables of per-cell coefficients.

# The code below selects summary entries by variable name and unstacks a
# 'metric' dimension, which only works on the xarray form of the summary.
# az.summary returns a wide pandas DataFrame unless fmt='xarray' is requested,
# so rebuild the summary in xarray form when that is what we were handed
# (otherwise summary[cols] fails on the DataFrame's metric columns).
if isinstance(summary, pd.DataFrame):
    summary_xr = az.summary(fit, fmt='xarray', credible_interval=ci/100)
else:
    summary_xr = summary

# --- population-level coefficients: one row per coefficient ---------------
cols = [
    'sr_mean',
    'slope_mean',
    'threshold_mean',
    'sr_delta_mean',
    'slope_delta_mean',
    'threshold_delta_mean',
]
x = summary_xr[cols].to_dataframe().T
x.to_csv('reports/rate_level/population_metrics.csv')

# --- per-cell coefficients ------------------------------------------------
cols = [
    'sr_cell',
    'slope_cell',
    'threshold_cell',
    'sr_delta_cell',
    'slope_delta_cell',
    'threshold_delta_cell',
]

# Rows of each per-cell table are relabelled with the cell ids loaded from the
# fit file; this assumes the fit's cell order matches `cells` (a length
# mismatch raises when the index is assigned).
index = pd.Index(cells, name='cellid')
per_coefficient = {}
for c in cols:
    r = summary_xr[c].to_series().unstack('metric')
    r.index = index
    per_coefficient[c] = r

# Stack the per-coefficient tables under an outer 'coefficient' index level.
result = pd.concat(per_coefficient, names=['coefficient'])
result.to_csv('reports/rate_level/cell_metrics.csv')

# Compact view: posterior mean only, one row per cell, one column per coefficient.
x = result['mean'].unstack('coefficient')
x.to_csv('reports/rate_level/cell_metrics_mean_only.csv')

In [None]:
def plot_raw_data(e, s, ax):
    """Scatter the observed rates in `e` against stimulus level.

    Parameters
    ----------
    e : pandas.DataFrame
        Needs `level`, `count`, `time` and `pupil` columns. Each row is drawn
        at (level, count/time) and coloured by its `pupil` code, which must be
        0 or 1.
    s : object
        Accepted for call compatibility only; it is not used here.
    ax : matplotlib axes
        Axes that receive the scatter plot.
    """
    palette = {0: 'seagreen', 1: 'orchid'}
    levels = e['level'].tolist()
    observed_rates = e.eval('count/time').tolist()
    point_colors = [palette[code] for code in e['pupil'].tolist()]
    # Marker size is fixed at 10; only the colour varies between points.
    ax.scatter(levels, observed_rates, 10, point_colors, alpha=0.5)
    

def plot_fit(er, sr, fit, i, cell, ax):
    """Draw one cell's raw rates together with its fitted rate-level curves.

    Parameters
    ----------
    er : pandas.DataFrame
        Table whose `.loc[cell]` rows are handed to `plot_raw_data`.
    sr : pandas.DataFrame
        Table whose `.loc[cell]` rows carry `pupil`, `count` and `time`; the
        `count/time` value for pupil codes 0 and 1 is drawn as a dotted
        horizontal line.
    fit : object
        Must provide `to_dataframe(diagnostics=False)`; the column means of
        that frame are used as the coefficients.
    i : int
        Position of the cell; `i + 1` is the number used in the coefficient
        column names (presumably the fit numbers cells from 1 -- confirm).
    cell : hashable
        Label used to select rows from `er` and `sr`.
    ax : matplotlib axes
        Axes to draw on.
    """
    level = np.arange(0, 80)

    def rate_level(baseline, gain, knee):
        # Flat at `baseline` up to and including `knee`, linear above it,
        # and never below zero.
        curve = np.where(level <= knee, baseline, gain * (level - knee) + baseline)
        return np.clip(curve, 0, np.inf)

    coefs = fit.to_dataframe(diagnostics=False).mean()

    evoked = er.loc[cell].reset_index()
    spont = sr.loc[cell].reset_index()
    plot_raw_data(evoked, spont, ax)

    # One dotted reference line per pupil code, in the colour of that code.
    spont_rate = spont.set_index('pupil').eval('count/time')
    for code, colour in ((0, 'seagreen'), (1, 'orchid')):
        ax.axhline(spont_rate.loc[code], ls=':', color=colour)

    # Look up all six coefficients before drawing any curve.
    n = i + 1
    names = ('sr', 'slope', 'threshold')
    base = [coefs[f'{name}_cell[{n}]'] for name in names]
    delta = [coefs[f'{name}_delta_cell[{n}]'] for name in names]

    # Curve for the base coefficients, then for base + delta.
    ax.plot(level, rate_level(*base), color='seagreen')
    shifted = [b + d for b, d in zip(base, delta)]
    ax.plot(level, rate_level(*shifted), color='orchid')
    
# Render one figure per cell. A single figure/axes pair is reused: the axes
# are cleared at the start of every iteration and the result is written to
# disk under the cell's name before moving on.
rates = load_rates()
f, ax = pl.subplots(1, 1, figsize=(5, 5))
for i, cell in enumerate(cells):
    ax.cla()
    plot_fit(rates['rlf'], rates['sr'], fit, i, cell, ax)
    ax.set_title(f'{cell}')
    ax.set_xlabel('Stim. level (dB SPL)')
    ax.set_ylabel('Rate (sp/sec)')
    for ext in ('png', 'pdf'):
        f.savefig(f'reports/rate_level/cells/{cell}.{ext}')
    