In [None]:
import pandas as pd
import seaborn as sns
from scipy import stats
import pylab as pl
import numpy as np
from lbhb.psychometric import CachedStanModel
import arviz as az
import pickle
%matplotlib inline

# Project-local loader for the firing-rate data.
# NOTE(review): `rates` is not used by any visible cell, while later cells use
# `er` and `sr`, which no visible cell defines. Presumably they should be
# unpacked from load_rates(); confirm and bind them here so the notebook
# survives Restart & Run All.
from support import load_rates
rates = load_rates()


In [None]:
def plot_raw_data(e, s, ax):
    """Scatter one cell's observed firing rates on `ax`.

    Spontaneous rates (rows of `s`) are drawn at level 0; rate-level-function
    rates (rows of `e`) are drawn at their `level`. Marker area is proportional
    to each point's recording time (normalised so the mean area is 100), and
    colour encodes the `pupil` flag (0 -> seagreen, 1 -> orchid).

    Parameters
    ----------
    e : DataFrame with columns 'level', 'rlf_count', 'rlf_time', 'pupil'.
    s : DataFrame with columns 'spont_count', 'spont_time', 'pupil'.
    ax : matplotlib Axes to draw on.
    """
    pupil_colors = {0: 'seagreen', 1: 'orchid'}
    # One level-0 position per spontaneous row, so x always has the same
    # length as y (previously hard-coded to exactly two rows).
    x = [0] * len(s) + e['level'].tolist()
    y = s.eval('spont_count/spont_time').tolist() + e.eval('rlf_count/rlf_time').tolist()
    size = np.array(s['spont_time'].tolist() + e['rlf_time'].tolist())
    size = 100 * size / size.mean()
    color = [pupil_colors[p] for p in s['pupil'].tolist() + e['pupil'].tolist()]
    ax.scatter(x, y, size, color, alpha=0.5)

In [None]:
# Stan model defined in rl_nosr.stan. CachedStanModel (lbhb.psychometric)
# presumably reuses a previously compiled build. `model.sampling(...)` is called
# in the fitting cell; when loading a cached fit, that cell rebinds `model`
# from the pickle instead.
model = CachedStanModel('rl_nosr.stan')

In [None]:
# Run configuration.
refit = False      # True: resample the Stan model; False: load the cached fit
sig_only = False   # True: restrict the fit to rows flagged 'significant'

# Cache file for the fit; each data selection gets its own file.
if sig_only:
    pkl_name = 'sig_fit_nosr.pkl'
else:
    pkl_name = 'fit_nosr.pkl'

In [None]:
# Select the rows to model and build the cell index. This runs on every
# execution, not only when refitting: later cells use `cells` and `cell_map`
# to look up per-cell parameters, including when the fit is loaded from disk.
# They must be rebuilt from the same data the cached fit was trained on.
# NOTE(review): `er` and `sr` are not defined in any visible cell; presumably
# they come from `load_rates()`. Confirm and bind them explicitly.
if sig_only:
    e = er.loc[er['significant'] == True].reset_index()
    s = sr.loc[sr['significant'] == True].reset_index()
else:
    e = er.reset_index()
    s = sr.reset_index()

cells = e['cellid'].unique()
# Stan arrays are 1-based, so the first cell maps to index 1.
cell_map = {c: i + 1 for i, c in enumerate(cells)}

if refit:
    cell_index = e['cellid'].apply(cell_map.get).values

    # Spontaneous rate per cell, with one column per pupil state (0 / 1) and
    # rows ordered by cell_index so they line up with the Stan cell arrays.
    s['cell_index'] = s['cellid'].map(cell_map.get)
    sr_by_pupil = (
        s.set_index(['pupil', 'cell_index'])[['spont_count', 'spont_time']]
        .sort_index()
        .eval('spont_count/spont_time')
        .unstack('pupil')
        .reset_index()
    )
    sr_cell = sr_by_pupil[0].values
    sr_cell_pupil = sr_by_pupil[1].values
    # Every modelled cell needs exactly one spontaneous rate per pupil state.
    assert len(sr_by_pupil) == len(cells), 'spontaneous rates do not line up with cells'
    assert not (np.isnan(sr_cell).any() or np.isnan(sr_cell_pupil).any()), \
        'missing spontaneous rate for some cell/pupil state'

    stan_data = {
        'n': len(e),
        'n_cells': len(cells),
        'cell_index': cell_index.astype('i'),
        'level': e['level'].values,
        'pupil': e['pupil'].values.astype('i'),
        'time': e['rlf_time'].values,
        'count': e['rlf_count'].values.astype('i'),
        'sr': sr_cell,
        'sr_pupil': sr_cell_pupil,
    }

    fit = model.sampling(stan_data, n_jobs=8, iter=20000,
                         control={'adapt_delta': 0.9, 'max_treedepth': 20})

    # The model and fit are stored in this order; the load branch reads
    # them back in the same order.
    with open(pkl_name, 'wb') as fh:
        pickle.dump(model, fh)
        pickle.dump(fit, fh)

else:
    # Unpickling can execute arbitrary code: only load files this notebook wrote.
    with open(pkl_name, 'rb') as fh:
        model = pickle.load(fh)
        fit = pickle.load(fh)

In [None]:
# Trace plots for the mean/sd hyperparameters and their pupil deltas
# ('slope_delta_meana' was a typo for 'slope_delta_mean').
az.plot_trace(fit, ['slope_mean', 'slope_sd', 'slope_delta_mean', 'slope_delta_sd',
                    'threshold_mean', 'threshold_sd', 'threshold_delta_mean', 'threshold_delta_sd']);

In [None]:
# Posterior summary with 90% intervals. Later cells read its 'mean',
# 'hpd 5.00%', 'hpd 95.00%' and 'gelman-rubin statistic' entries via
# get_metric/plot.
# NOTE(review): `credible_interval` and those labels come from older ArviZ
# releases (newer ones use `hdi_prob`, 'hdi_5%', 'r_hat'). Pin/confirm the version.
summary = az.summary(fit, credible_interval=0.9)

In [None]:
# Joint posterior density (2-D KDE) of `slope_mean` and `threshold_mean`.
az.plot_kde(fit['slope_mean'], fit['threshold_mean'])

In [None]:
# Per-cell posterior means of the pupil-related changes in slope and threshold,
# averaged directly over the draws (fit[name] is a (draws, n_cells) array).
# Reading them from `summary` via `get_metric` would fail on Restart & Run All,
# because `get_metric` is only defined in a later cell.
slope = fit['slope_delta_cell'].mean(axis=0)
threshold = fit['threshold_delta_cell'].mean(axis=0)
# 2-D density of the per-cell estimates, with the individual cells overlaid.
az.plot_kde(slope, threshold)
pl.plot(slope, threshold, 'wo')
# Correlation between each cell's slope change and its threshold change.
stats.pearsonr(slope, threshold)

In [None]:
def get_color(row, lb_label, ub_label):
    """Pick a plotting colour for one posterior-summary row.

    Parameters
    ----------
    row : mapping (e.g. a pandas Series) holding 'gelman-rubin statistic'
        and the interval bounds named by `lb_label` / `ub_label`.
    lb_label, ub_label : keys of the lower and upper interval bounds.

    Returns
    -------
    'red' if the Gelman-Rubin statistic exceeds 1.1 (poor convergence),
    otherwise 'green' if the interval excludes zero, otherwise 'gray'.
    """
    # A convergence problem takes precedence over significance.
    if row['gelman-rubin statistic'] > 1.1:
        return 'red'
    interval_excludes_zero = (row[lb_label] > 0) or (row[ub_label] < 0)
    return 'green' if interval_excludes_zero else 'gray'


def plot(ax, cell_metric, pop_metric, measure):
    """Forest plot of per-cell posterior changes plus the population estimate.

    Cells are drawn one per row, ordered by posterior mean, as a 90% HPD
    interval line with a mean marker. The population 90% HPD interval is a
    shaded vertical span and its mean a vertical line. Colours come from
    `get_color`: red = Gelman-Rubin > 1.1, green = interval excludes 0,
    gray otherwise.

    NOTE(review): a later cell defines another function named `plot` with a
    different signature, which shadows this one once that cell has run.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    cell_metric : DataFrame with one row per cell and columns 'mean',
        'hpd 5.00%', 'hpd 95.00%' and 'gelman-rubin statistic'.
    pop_metric : Series with the same labels for the population parameter.
    measure : parameter name used in the x-axis label.

    Returns
    -------
    The same `ax`.
    """
    lb_label, ub_label = 'hpd 5.00%', 'hpd 95.00%'
    ci_label = [lb_label, ub_label]
    ordered = cell_metric.sort_values('mean')

    pop_color = get_color(pop_metric, lb_label, ub_label)
    ax.axvspan(*pop_metric[ci_label], facecolor=pop_color, alpha=0.5)
    ax.axvline(pop_metric['mean'], color=pop_color)

    sig_count = 0
    for row_pos, (_, cell_row) in enumerate(ordered.iterrows()):
        cell_color = get_color(cell_row, lb_label, ub_label)
        if cell_color == 'green':
            sig_count += 1
        ax.plot(cell_row[ci_label], [row_pos, row_pos], '-', color=cell_color, lw=0.5)
        ax.plot(cell_row[['mean']], [row_pos], 'o', color=cell_color)

    title = f'Change in {measure} (lg. re sm. pupil)'
    sig_text = f'{sig_count} sig. out of {len(ordered)}'
    pop_text = f'Mean change {pop_metric["mean"]:.2f} (90% CI {pop_metric[lb_label]:.2f} to {pop_metric[ub_label]:.2f})'
    ax.set_xlabel('\n'.join([title, sig_text, pop_text]))

    # Keep only the bottom spine and drop the (meaningless) y ticks.
    sns.despine(ax=ax, top=True, left=True, right=True, bottom=False)
    ax.yaxis.set_ticks_position('none')
    ax.yaxis.set_ticks([])
    ax.grid()
    return ax


def get_metric(summary, metric, index=None):
    """Extract one variable's statistics from an ArviZ summary.

    Parameters
    ----------
    summary : summary object indexable by variable name, where each variable
        supports `.to_series()` and has an index level named 'metric'
        (assumed to be ArviZ's xarray-form summary; confirm).
    metric : variable name, e.g. 'slope_delta_cell'.
    index : optional labels to assign to the result's index.

    Returns
    -------
    For a two-level index (element x 'metric'), a DataFrame with one row per
    element and one column per statistic; otherwise the Series unchanged.
    """
    values = summary[metric].to_series()
    has_element_level = values.index.nlevels == 2
    result = values.unstack('metric') if has_element_level else values
    if index is not None:
        result.index = index
    return result

In [None]:
# NOTE(review): this cell only imports importlib and displays the module
# object; it has no effect on the analysis. Presumably a leftover from an
# `importlib.reload(...)` of a local module. Delete or complete it.
import importlib
importlib 

In [None]:
# NOTE(review): repeats the load from the first cell and rebinds `data`, the
# name the fitting cell used for the Stan input dict. The result is not used
# by any visible cell; consider removing this cell.
from support import load_rates
data = load_rates()

In [None]:
# Forest plots of the pupil-related change in each parameter, from `summary`.
# Uses the 4-argument `plot` defined earlier. Running this cell after the
# later cell that redefines `plot` will fail.
f, axes = pl.subplots(1, 2, figsize=(8, 3))

for ax, measure in zip(axes, ['slope', 'threshold']):
    cell_metric = get_metric(summary, f'{measure}_delta_cell')
    pop_metric = get_metric(summary, f'{measure}_delta_mean')
    plot(ax, cell_metric, pop_metric, measure)

In [None]:
def _predict_rlf(level, slope, threshold, spont_rate):
    """Piecewise-linear rate-level curve: flat at `spont_rate` up to
    `threshold`, rising with `slope` above it, clipped below at 0."""
    pred = slope * (level - threshold) + spont_rate
    pred[level <= threshold] = spont_rate
    return np.clip(pred, 0, np.inf)


def plot_fit(er, sr, fit, cell_map, cell, ax):
    """Draw one cell's raw data and posterior-mean rate-level fits on `ax`.

    The seagreen curve uses the cell's baseline slope/threshold and its
    pupil == 0 spontaneous rate. The orchid curve adds the cell's pupil deltas
    and uses its pupil == 1 spontaneous rate.

    Parameters
    ----------
    er, sr : frames indexed by cell id (selected with `.loc[cell]`) holding
        the columns `plot_raw_data` expects.
    fit : posterior where `fit[name]` is a (draws, n_cells) array for the
        per-cell parameters.
    cell_map : cell id -> 1-based Stan cell index.
    cell : cell id to plot.
    ax : matplotlib Axes to draw on.
    """
    level = np.arange(-10, 80)
    # Stan index is 1-based; the draws arrays are 0-based.
    col = cell_map[cell] - 1

    e = er.loc[cell].reset_index()
    s = sr.loc[cell].reset_index()
    plot_raw_data(e, s, ax)

    spont = s.set_index('pupil').eval('spont_count/spont_time')
    sr_regular = spont.loc[0]
    sr_pupil = spont.loc[1]

    # Posterior means of this cell's parameters. Reading single columns avoids
    # converting the whole fit to a DataFrame on every call.
    slope = fit['slope_cell'][:, col].mean()
    threshold = fit['threshold_cell'][:, col].mean()
    slope_pupil_delta = fit['slope_delta_cell'][:, col].mean()
    threshold_pupil_delta = fit['threshold_delta_cell'][:, col].mean()

    ax.plot(level, _predict_rlf(level, slope, threshold, sr_regular), color='seagreen')

    slope_pupil = slope + slope_pupil_delta
    threshold_pupil = threshold + threshold_pupil_delta
    ax.plot(level, _predict_rlf(level, slope_pupil, threshold_pupil, sr_pupil), color='orchid')
    
    
# One panel per modelled cell, five per row. The grid is sized to the number
# of cells, so none are dropped and no empty panels remain.
n_cols = 5
n_rows = max(1, int(np.ceil(len(cells) / n_cols)))
f, axes = pl.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
for cell, ax in zip(cells, axes.ravel()):
    plot_fit(er, sr, fit, cell_map, cell, ax)
    # Title: 2.5% | median | 97.5% of this cell's threshold change.
    i = cell_map[cell] - 1  # 1-based Stan index -> 0-based draws column
    lb, m, ub = np.percentile(fit['threshold_delta_cell'][:, i], [2.5, 50.0, 97.5])
    ax.set_title(f'{cell}\n({lb:.0f} | {m:.0f} | {ub:.0f})')

# Hide the panels left over in the last row.
for ax in axes.ravel()[len(cells):]:
    ax.set_visible(False)

f.tight_layout()

In [None]:
import matplotlib as mp
from scipy import stats

def plot(ax, lb, m, ub, measure, ensemble_lb, ensemble_m, ensemble_ub):
    """Forest plot of per-cell posterior intervals plus the population estimate.

    NOTE(review): this redefines the 4-argument `plot` from an earlier cell with
    a different signature. Re-running that earlier plotting cell after this one
    will fail; consider giving one of them a distinct name.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    lb, m, ub : 1-D numpy arrays (one entry per cell) with the lower bound,
        centre and upper bound of each cell's posterior.
    measure : parameter name used in the x-axis label.
    ensemble_lb, ensemble_m, ensemble_ub : scalar lower bound, centre and
        upper bound of the population-level parameter.

    A cell (or the population) counts as significant when its interval
    excludes 0. The x label also reports `scipy.stats.wilcoxon` on the
    per-cell centres.

    Returns
    -------
    The same `ax`.
    """
    from matplotlib.patches import Rectangle

    stat = stats.wilcoxon(m)

    # Order cells by their centre so the lowest value is drawn at the bottom.
    order = np.argsort(m)
    lb, m, ub = lb[order], m[order], ub[order]

    # Population estimate: a vertical line plus a full-height shaded band
    # (x in data coordinates, y in axes coordinates 0..1).
    t = ax.get_xaxis_transform()
    ensemble_sig = (ensemble_lb > 0) or (ensemble_ub < 0)
    fc = 'lightgreen' if ensemble_sig else 'gray'
    c = 'seagreen' if ensemble_sig else 'k'
    ax.axvline(ensemble_m, color=c)
    r = Rectangle((ensemble_lb, 0), ensemble_ub - ensemble_lb, 1,
                  transform=t, ec='none', fc=fc)
    ax.add_patch(r)

    n_sig = 0
    for row, (a, d, b) in enumerate(zip(lb, m, ub)):
        cell_sig = (b < 0) or (a > 0)
        n_sig += int(cell_sig)
        c = 'g' if cell_sig else 'k'
        ax.plot([a, b], [row, row], '-', color=c, lw=0.5)
        ax.plot([d], [row], 'o', color=c)

    ax.set_xlabel(f'Change in {measure} (lg. re sm. pupil)\n{n_sig} sig. out of {len(lb)}\nMedian change {ensemble_m:.2f}\nWilcoxon p={stat.pvalue:0.4f}')
    # Keep only the bottom spine and drop the (meaningless) y ticks.
    sns.despine(ax=ax, top=True, left=True, right=True, bottom=False)
    ax.yaxis.set_ticks_position('none')
    ax.yaxis.set_ticks([])
    ax.grid()
    return ax

# Optional exclusion of cells whose posterior-mean threshold lies outside
# THRESHOLD_RANGE. Off by default, so every cell is plotted.
EXCLUDE_EXTREME_THRESHOLDS = False
THRESHOLD_RANGE = (-20, 65)

f, axes = pl.subplots(1, 2, figsize=(10, 5))

# Posterior-mean threshold per cell (columns of the (draws, n_cells) array).
thresholds = fit['threshold_cell'].mean(axis=0)
if EXCLUDE_EXTREME_THRESHOLDS:
    lo, hi = THRESHOLD_RANGE
    mask = ~((thresholds < lo) | (thresholds > hi))
else:
    mask = np.ones(thresholds.shape[0], dtype=bool)

# Median and 90% interval of each cell's change and of the population change.
pct = [5, 50, 95]
for ax, measure in zip(axes, ['slope', 'threshold']):
    cell_draws = fit[f'{measure}_delta_cell'][:, mask]
    pop_draws = fit[f'{measure}_delta_mean']
    lb, m, ub = np.percentile(cell_draws, pct, axis=0)
    elb, em, eub = np.percentile(pop_draws, pct, axis=0)
    plot(ax, lb, m, ub, measure, elb, em, eub)

In [None]:
# Trace plots for the population-level pupil deltas. 'slope_pupil_delta_mean'
# and 'threshold_pupil_delta_mean' match no variable used elsewhere; the
# model's names are 'slope_delta_mean' / 'threshold_delta_mean'.
az.plot_trace(fit, ['slope_delta_mean', 'threshold_delta_mean']);