In [None]:
import arviz as az
import numpy as np
import seaborn as sns
from scipy import stats
import pandas as pd
import pickle
import pylab as pl
%matplotlib inline

from support import get_metric, forest_plot, load_rates

In [None]:
filename = 'fits/rl_sr_exclude_silent_significant_only.pkl'

# The pickle holds three objects written back to back: the cell list, the
# compiled model and the fit. They are read sequentially in that same order.
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
with open(filename, mode='rb') as handle:
    cells, model, fit = [pickle.load(handle) for _ in range(3)]

In [None]:
# Population-level parameters to inspect visually in trace plots.
trace_vars = [
    'sr_mean', 'sr_delta_mean',
    'slope_mean', 'slope_delta_mean',
    'threshold_mean', 'threshold_delta_mean', 'threshold_delta_sd',
]
az.plot_trace(fit, trace_vars);

In [None]:
# Width of the posterior interval reported in the summary table.
# NOTE(review): newer ArviZ releases renamed `credible_interval` to `hdi_prob`;
# this call targets the ArviZ version the notebook was written against.
CREDIBLE_INTERVAL = 0.9
summary = az.summary(fit, credible_interval=CREDIBLE_INTERVAL)

In [None]:
# One forest-plot panel per coefficient, showing the per-cell delta estimates
# (`<coef>_delta_cell`) together with the population-level delta
# (`<coef>_delta_mean`).
SAVE_FIGURE = False  # set True to write the figure to 'rl_sr.eps'

coefficients = ['sr', 'slope', 'threshold']
f, axes = pl.subplots(1, len(coefficients), figsize=(4 * len(coefficients), 4))

for ax, coef in zip(axes, coefficients):
    cell_metric = get_metric(summary, f'{coef}_delta_cell')
    pop_metric = get_metric(summary, f'{coef}_delta_mean')
    forest_plot(ax, cell_metric, pop_metric, coef)

if SAVE_FIGURE:
    f.savefig('rl_sr.eps')

In [None]:
# Export posterior summaries to CSV. `summary` is used as an xarray Dataset
# with a 'metric' dimension (required by the `to_dataframe` / `unstack('metric')`
# calls below).

# Population-level coefficients: one row per coefficient after transposing.
population_cols = [
    'sr_mean',
    'slope_mean',
    'threshold_mean',
    'sr_delta_mean',
    'slope_delta_mean',
    'threshold_delta_mean',
]
summary[population_cols].to_dataframe().T.to_csv('rl_sr_population_metrics.csv')

# Per-cell coefficients. Each variable's cell axis is relabelled positionally
# with the cell IDs loaded alongside the fit, so this assumes the fit's cell
# order matches `cells` -- TODO confirm against the fitting code.
cell_cols = [
    'sr_cell',
    'slope_cell',
    'threshold_cell',
    'sr_delta_cell',
    'slope_delta_cell',
    'threshold_delta_cell',
]

index = pd.Index(cells, name='cellid')
cell_frames = {}
for col in cell_cols:
    frame = summary[col].to_series().unstack('metric')
    frame.index = index  # pandas raises ValueError on a length mismatch
    cell_frames[col] = frame

# Rows: (coefficient, cellid); columns: summary metrics.
result = pd.concat(cell_frames, names=['coefficient'])
result.to_csv('rl_sr_cell_metrics.csv')
result['mean'].unstack('coefficient').to_csv('rl_sr_cell_metrics_mean_only.csv')

In [None]:
def plot_raw_data(e, s, ax):
    """Scatter one cell's observed firing rate against stimulus level.

    Parameters
    ----------
    e : pandas.DataFrame
        Rows for a single cell with ``level``, ``count``, ``time`` and
        ``pupil`` columns. The plotted rate is ``count / time``.
    s : pandas.DataFrame
        Spontaneous-rate rows for the same cell. Currently unused; kept so
        the call signature stays stable for callers.
    ax : matplotlib Axes
        Axes to draw on.

    Points are coloured by ``pupil`` (0 -> seagreen, 1 -> orchid); any other
    pupil value raises ``KeyError``.
    """
    pupil_colors = {0: 'seagreen', 1: 'orchid'}
    level = e['level'].tolist()
    rate = e.eval('count/time').tolist()
    # Index the dict directly (not Series.map) so unknown pupil values still
    # raise KeyError instead of silently becoming NaN.
    point_colors = [pupil_colors[p] for p in e['pupil'].tolist()]
    # Third positional argument is the marker size.
    ax.scatter(level, rate, 10, point_colors, alpha=0.5)
    

def _piecewise_linear_rate(level, baseline, slope, threshold):
    """Flat-then-linear rate curve, clipped below at zero.

    Returns ``baseline`` where ``level <= threshold`` and
    ``slope * (level - threshold) + baseline`` above it, then clips at 0.
    """
    above = slope * (level - threshold) + baseline
    return np.clip(np.where(level <= threshold, baseline, above), 0, np.inf)


def plot_fit(er, sr, fit, i, cell, ax, posterior_means=None):
    """Plot raw rates and posterior-mean rate-level curves for one cell.

    Parameters
    ----------
    er, sr : pandas.DataFrame
        Rate tables indexed by cell id (``er.loc[cell]`` / ``sr.loc[cell]``).
        ``er`` rows are passed to ``plot_raw_data``. For ``sr``, a rate
        ``count / time`` is computed per ``pupil`` value; assumes exactly one
        row per pupil state (0 and 1) -- TODO confirm.
    fit : object
        Posterior fit exposing ``to_dataframe(diagnostics=False)`` with one
        column per parameter (e.g. ``sr_cell[1]``). Only used when
        ``posterior_means`` is not given.
    i : int
        0-based position of ``cell`` in the fit's cell ordering.
    cell : hashable
        Cell id to plot.
    ax : matplotlib Axes
        Axes to draw on.
    posterior_means : pandas.Series, optional
        Precomputed ``fit.to_dataframe(diagnostics=False).mean()``. Pass it
        when plotting many cells to avoid rebuilding the posterior frame on
        every call.
    """
    if posterior_means is None:
        posterior_means = fit.to_dataframe(diagnostics=False).mean()
    level = np.arange(-80, 80)

    e = er.loc[cell].reset_index()
    s = sr.loc[cell].reset_index()
    plot_raw_data(e, s, ax)

    # Spontaneous rate per pupil state, drawn as dotted reference lines.
    spont_rate = s.set_index('pupil').eval('count/time')
    ax.axhline(spont_rate.loc[0], ls=':', color='seagreen')
    ax.axhline(spont_rate.loc[1], ls=':', color='orchid')

    # Parameter columns use 1-based indices (e.g. 'sr_cell[1]'); `i` is 0-based.
    k = i + 1

    def cell_param(name):
        return posterior_means[f'{name}[{k}]']

    baseline = cell_param('sr_cell')
    slope = cell_param('slope_cell')
    threshold = cell_param('threshold_cell')
    baseline_delta = cell_param('sr_delta_cell')
    slope_delta = cell_param('slope_delta_cell')
    threshold_delta = cell_param('threshold_delta_cell')

    # pupil == 0 curve: base coefficients.
    ax.plot(level, _piecewise_linear_rate(level, baseline, slope, threshold),
            color='seagreen')

    # pupil == 1 curve: each coefficient shifted by its per-cell delta.
    ax.plot(level, _piecewise_linear_rate(level,
                                          baseline + baseline_delta,
                                          slope + slope_delta,
                                          threshold + threshold_delta),
            color='orchid')
    
    
# Grid of per-cell fits: raw rates, spontaneous-rate reference lines and
# posterior-mean rate-level curves for each pupil state.

# Load the rate tables here so this cell runs on a fresh kernel instead of
# relying on `er` / `sr` having been defined by another cell.
rates = load_rates()
er = rates['rlf']
sr = rates['sr']

# Size the grid to the number of cells (5 per row) instead of silently
# dropping cells beyond a fixed 5x5 layout; 25 cells gives the original grid.
n_cols = 5
n_rows = max(1, int(np.ceil(len(cells) / n_cols)))
f, axes = pl.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
flat_axes = axes.ravel()

for i, (cell, ax) in enumerate(zip(cells, flat_axes)):
    plot_fit(er, sr, fit, i, cell, ax)
    ax.set_title(f'{cell}')

# Hide panels left over when the cells do not fill the last row.
for ax in flat_axes[len(cells):]:
    ax.set_visible(False)

f.tight_layout()

In [None]:
# Posterior summary metrics for `threshold_sd` (presumably the across-cell
# spread of thresholds -- confirm against the model definition).
threshold_sd_summary = summary['threshold_sd'].to_series()
threshold_sd_summary

In [None]:
# Distribution of the per-cell threshold posterior means.
threshold_cell = summary['threshold_cell'].to_series().loc['mean']

f, ax = pl.subplots()
ax.hist(threshold_cell, bins=50)
ax.set(xlabel='threshold_cell (posterior mean)',
       ylabel='Number of cells',
       title='Per-cell threshold estimates');

In [None]:
# Rebuild per-cell index structures directly from the raw rate tables.
# NOTE(review): this cell rebinds `cells` from the *unfiltered* rate table,
# replacing the cell list loaded from the pickle. The fit file name suggests
# silent / non-significant cells were excluded when fitting, and the filtering
# below is commented out, so re-running cells that index the fit by position
# in `cells` after this one can pair cells with the wrong parameters.
# Consider a distinct name for this cell list.
rates = load_rates()
er = rates['rlf']
sr = rates['sr']

# Presumably the filtering applied when the fit was produced; disabled here
# (`exclude_silent` / `significant_only` are not defined in this notebook).
#if exclude_silent:
#    spike_counts = er['count'].groupby(['cellid', 'pupil']).sum()
#    m = spike_counts == 0
#    exclude = spike_counts.loc[m].unstack().index.values.tolist()
#    sr = sr.drop(exclude)
#    er = er.drop(exclude)
#
#if significant_only:
#    er = er.query('significant')
#    sr = sr.query('significant')

e = er.reset_index()
s = sr.reset_index()

# Map each cell id to a 1-based integer index (order of first appearance).
cells = e['cellid'].unique()
cell_map = {c: i+1 for i, c in enumerate(cells)}
e['cell_index'] = e['cellid'].apply(cell_map.get).values
s['cell_index'] = s['cellid'].map(cell_map.get)
# Spontaneous counts/times reshaped to rows = cell_index and
# columns = (count|time, pupil).
s = s.set_index(['cell_index', 'pupil']) \
    .sort_index()[['count', 'time']].unstack()

# Row position of the first occurrence of each unique (cell_index, pupil)
# pair. np.unique returns the pairs sorted, so these positions only delimit
# contiguous groups if `e` is already sorted by (cell_index, pupil) --
# NOTE(review): that ordering is not enforced here.
_, indices = np.unique(e[['cell_index', 'pupil']].values.tolist(), \
                       axis=0, return_index=True)
# NOTE(review): appending [len(e), -1] and reshaping to (-1, 2) pairs up
# consecutive group *starts* rather than (start, end) per group; the trailing
# -1 becomes 0 after the +1 shift, and the reshape fails when the number of
# groups is odd. Verify this matches what consumes `data_cell_index`.
indices = np.r_[indices, [len(e), -1]]
data_cell_index = np.array(indices).reshape((-1, 2)) + 1


In [None]:
# First-occurrence row positions of each unique (cell_index, pupil) pair in `e`.
pairs = e[['cell_index', 'pupil']].values.tolist()
_, i = np.unique(pairs, axis=0, return_index=True)