In [None]:
import pandas as pd
import seaborn as sns
from scipy import stats
import pylab as pl
import numpy as np
from lbhb.psychometric import CachedStanModel
import arviz as az
import pickle
%matplotlib inline

# Set to True to re-sample the hierarchical model and overwrite fit.pkl;
# with False the model and fit are loaded from the existing fit.pkl instead.
refit = False

In [None]:
# Load the wide-format rate-level table and reshape it to one row per
# (cellid, idx) measurement.
rlf = pd.read_csv('rate_level_functions_for_bburan.csv')
# Remove spaces from the column names so the stub names below match.
rlf.columns = [s.replace(' ', '') for s in rlf.columns]
cols = ['pupil', 'level', 'rlf_count', 'rlf_time', 'spont_count', 'spont_time']
# Columns named '<stub>_<idx>' become rows indexed by (cellid, idx); rows with
# any missing value are dropped.
rlf = pd.wide_to_long(rlf, cols, 'cellid', 'idx', sep='_').dropna()
# Shift the pupil code down by one.
# NOTE(review): presumably converts 1/2 to the 0/1 codes that plot_raw_data's
# colour table expects -- confirm against the CSV.
rlf['pupil'] -= 1

# Evoked counts and durations, indexed by (cellid, pupil, level);
# verify_integrity raises if that combination is not unique.
er = rlf.reset_index().set_index(['cellid', 'pupil', 'level'], verify_integrity=True)[['rlf_count', 'rlf_time']]
# Spontaneous counts and durations: the first row found for each (cellid, pupil).
sr = rlf.groupby(['cellid', 'pupil'])[['spont_count', 'spont_time']].first()

# Quick look: histogram of spontaneous rate per pupil code.
pl.figure()
sr.eval('spont_count/spont_time').groupby('pupil').hist()
# Quick look: evoked rate averaged across all rows at each level.
pl.figure()
er.eval('rlf_count/rlf_time').groupby('level').mean().plot()

er.sort_index(inplace=True)
sr.sort_index(inplace=True)

# Unique cell ids in sorted-index order, and a 1-based integer code for each.
# NOTE(review): the +1 presumably matches Stan's 1-based indexing -- confirm.
cells = er.index.get_level_values('cellid').unique()
cell_map = {c: i+1 for i, c in enumerate(cells)}

In [None]:
# Spontaneous-rate-only model. Only the pupil == 0 column of each measure is
# passed in; the pupil == 1 inputs are not wired up.
sr_model = CachedStanModel('sr_fit.stan')
sr_unstacked = sr.unstack('pupil')
data = dict(
    n_cells=len(sr_unstacked),
    spike_count=sr_unstacked['spont_count'][0].astype('i').values,
    sample_time=sr_unstacked['spont_time'][0].values,
)
print(data)

In [None]:
# Call `sampling` on the spontaneous-rate model with the data dict built in the
# previous cell. The result is not assigned, so it is only shown as cell output.
sr_model.sampling(data)

In [None]:
def plot_raw_data(e, s, ax):
    """Scatter the measured firing rates of one cell onto ``ax``.

    Parameters
    ----------
    e : DataFrame
        Evoked rows with columns 'level', 'rlf_count', 'rlf_time' and 'pupil'.
    s : DataFrame
        Spontaneous rows with columns 'spont_count', 'spont_time' and 'pupil'.
    ax : object with a matplotlib-style ``scatter`` method
        Target axes.

    Rates are count / time. Spontaneous rows are drawn at x = 0, evoked rows
    at their level. Marker area is proportional to the sampling time (scaled
    so the mean marker has area 100) and colour encodes the pupil code
    (0 -> 'seagreen', 1 -> 'orchid'; any other code raises KeyError).
    """
    # One x = 0 entry per spontaneous row. The previous hard-coded [0, 0] only
    # worked when `s` had exactly two rows; any other count produced x and y
    # of different lengths.
    x = [0] * len(s) + e['level'].tolist()
    y = s.eval('spont_count/spont_time').tolist() + e.eval('rlf_count/rlf_time').tolist()

    size = np.array(s['spont_time'].tolist() + e['rlf_time'].tolist())
    size = 100 * size/size.mean()

    colors = {0: 'seagreen', 1: 'orchid'}
    color = [colors[code] for code in s['pupil'].tolist() + e['pupil'].tolist()]

    ax.scatter(x, y, size, color, alpha=0.5)

In [None]:
# Build the model object from hockey_stick.stan. When `refit` is False, the
# next cell replaces `model` with the instance stored in fit.pkl.
model = CachedStanModel('hockey_stick.stan')

In [None]:
# Either re-sample the model and cache the result, or load the cached result.
if refit:
    # Turn the index levels of both tables back into ordinary columns.
    e = er.reset_index()
    s = sr.reset_index()
    # 1-based cell code (from cell_map) for every evoked row.
    cell_index = e['cellid'].apply(cell_map.get).values

    s['cell_index'] = s['cellid'].map(cell_map.get)
    # Order the spontaneous rows by pupil code, then by cell code.
    # NOTE(review): presumably hockey_stick.stan relies on this ordering of
    # sr_count / sr_time -- confirm against the Stan program.
    s = s.set_index(['pupil', 'cell_index'])[['spont_count', 'spont_time']] \
        .sort_index().reset_index()

    # Inputs handed to the Stan model; astype('i') casts to a C int.
    data = {
        'n': len(e),
        'n_cells': len(cells),
        'cell_index': cell_index.astype('i'),
        'evoked_level': e['level'].values,
        'pupil': e['pupil'].values.astype('i'),
        'evoked_time': e['rlf_time'].values,
        'evoked_count': e['rlf_count'].values.astype('i'),
        'sr_count': s['spont_count'].values.astype('i'),
        'sr_time': s['spont_time'].values,
    }

    fit = model.sampling(data, iter=10000)

    # The model is written before the fit; the loader below reads them back in
    # the same order.
    with open('fit.pkl', 'wb') as f:
        pickle.dump(model, f)
        pickle.dump(fit, f)

else:
    # NOTE(review): pickle.load can execute arbitrary code; only load a fit.pkl
    # that this notebook produced.
    with open('fit.pkl', 'rb') as fh:
        model = pickle.load(fh)
        fit = pickle.load(fh)

In [None]:
def plot_fit(er, sr, fit, cell_map, cell, ax):
    """Draw one cell's raw rates and its two fitted rate-level curves on ``ax``.

    Parameters
    ----------
    er, sr : DataFrame
        Evoked and spontaneous tables whose first index level is the cell id.
    fit : object with ``to_dataframe(diagnostics=False)``
        Posterior draws with one column per parameter, named like 'slope[3]'.
    cell_map : mapping
        Cell id -> integer used inside the parameter column names.
    cell : hashable
        Cell id to plot.
    ax : matplotlib-style axes
        Target axes.

    The posterior mean of every parameter is used. Each curve is flat at its
    baseline rate up to the threshold, rises linearly above it, and is clipped
    at zero. The baseline condition is drawn in 'seagreen', the pupil
    condition in 'orchid'.
    """
    posterior_mean = fit.to_dataframe(diagnostics=False).mean()
    levels = np.arange(-10, 80)
    i = cell_map[cell]

    plot_raw_data(er.loc[cell].reset_index(), sr.loc[cell].reset_index(), ax)

    def draw_curve(gain, knee, baseline, color):
        # Linear above the knee, constant baseline at or below it, never negative.
        rate = gain * (levels - knee) + baseline
        rate[levels <= knee] = baseline
        ax.plot(levels, np.clip(rate, 0, np.inf), color=color)

    base_rate = posterior_mean[f'sr[{i}]']
    pupil_rate = posterior_mean[f'sr_pupil[{i}]']
    base_slope = posterior_mean[f'slope[{i}]']
    base_threshold = posterior_mean[f'threshold[{i}]']
    slope_shift = posterior_mean[f'slope_pupil_delta[{i}]']
    threshold_shift = posterior_mean[f'threshold_pupil_delta[{i}]']

    draw_curve(base_slope, base_threshold, base_rate, 'seagreen')
    # The pupil condition is the baseline slope/threshold plus its delta.
    draw_curve(base_slope + slope_shift, base_threshold + threshold_shift,
               pupil_rate, 'orchid')
    
    
# One panel per cell on a 5 x 5 grid; zip() stops after the 25 available axes.
f, axes = pl.subplots(5, 5, figsize=(10, 10))
# Iterate over `cells` directly. This cell used to index with `mask`, which is
# only created in a later cell, so it raised NameError on a fresh kernel.
# That later cell forces `mask` to all-True, so `cells[mask]` selected the same
# cells as `cells` anyway.
threshold_delta = fit['threshold_pupil_delta']
for cell, ax in zip(cells, axes.ravel()):
    plot_fit(er, sr, fit, cell_map, cell, ax)
    # cell_map is 1-based; the posterior array columns are 0-based.
    i = cell_map[cell] - 1
    lb, m, ub = np.percentile(threshold_delta[:, i], [2.5, 50.0, 97.5])

    # Title: cell id plus the 2.5 / 50 / 97.5 percentiles of its threshold delta.
    ax.set_title(f'{cell}\n({lb:.0f} | {m:.0f} | {ub:.0f})')

f.tight_layout()

In [None]:
# Per-draw ratio alpha / beta of the spontaneous-rate hyperparameters, for the
# baseline and the pupil condition.
# NOTE(review): presumably the mean of a gamma(shape, rate) prior -- confirm
# against hockey_stick.stan.
sr_mean = fit['sr_alpha'] / fit['sr_beta']
sr_pupil_mean = fit['sr_pupil_alpha'] / fit['sr_pupil_beta']

# Posterior mean of each hyperparameter, one per line.
print(
    *(fit[key].mean(axis=0)
      for key in ('sr_alpha', 'sr_pupil_alpha', 'sr_beta', 'sr_pupil_beta')),
    sep='\n',
)

# Median and 95% interval of each ratio.
print('SR', np.percentile(sr_mean, q=[2.5, 50.0, 97.5], axis=0))
print('SR pupil', np.percentile(sr_pupil_mean, q=[2.5, 50.0, 97.5], axis=0))

In [None]:
# Display the raw 'threshold_pupil_delta' draws. Other cells index this array
# as [:, cell], i.e. one column per cell.
fit['threshold_pupil_delta']

In [None]:
import matplotlib as mp
from scipy import stats

def plot(ax, lb, m, ub, measure, ensemble_lb, ensemble_m, ensemble_ub):
    """Forest plot of per-cell intervals with the population interval behind it.

    Parameters
    ----------
    ax : matplotlib axes
        Target axes.
    lb, m, ub : 1-D arrays of equal length
        Lower bound, central value and upper bound for each cell.
    measure : str
        Name of the quantity, used in the x-axis label.
    ensemble_lb, ensemble_m, ensemble_ub : float
        Lower bound, central value and upper bound of the population estimate.

    Returns
    -------
    The same ``ax``.

    Cells are drawn bottom-to-top in order of increasing ``m``. An interval
    counts as significant when it excludes zero; significant cells are drawn
    in green and a significant population band is shaded light green. The
    x-axis label reports the number of significant cells, ``ensemble_m`` and
    the p-value of a Wilcoxon signed-rank test on ``m``.
    """
    stat = stats.wilcoxon(m)

    # Sort all three arrays by the central value.
    order = np.argsort(m)
    lb, m, ub = lb[order], m[order], ub[order]

    # Population band: x in data coordinates, y spanning the full axes height.
    t = ax.get_xaxis_transform()
    ensemble_sig = (ensemble_lb > 0) or (ensemble_ub < 0)
    fc = 'lightgreen' if ensemble_sig else 'gray'
    c = 'seagreen' if ensemble_sig else 'k'
    ax.axvline(ensemble_m, color=c)
    r = mp.patches.Rectangle((ensemble_lb, 0), ensemble_ub-ensemble_lb, 1,
                             transform=t, ec='none', fc=fc)
    ax.add_patch(r)

    n_sig = 0
    for row, (a, d, b) in enumerate(zip(lb, m, ub)):
        significant = (b < 0) or (a > 0)
        if significant:
            n_sig += 1
        color = 'g' if significant else 'k'
        ax.plot([a, b], [row, row], '-', color=color, lw=0.5)
        ax.plot([d], [row], 'o', color=color)

    ax.set_xlabel(f'Change in {measure} (lg. re sm. pupil)\n{n_sig} sig. out of {len(lb)}\nMedian change {ensemble_m:.2f}\nWilcoxon p={stat.pvalue:0.4f}')
    sns.despine(ax=ax, top=True, left=True, right=True, bottom=False)
    ax.yaxis.set_ticks_position('none')
    ax.yaxis.set_ticks([])
    ax.grid()
    return ax

# Three forest plots: pupil effect on slope, threshold and spontaneous rate.
f, axes = pl.subplots(1, 3, figsize=(15, 5))

# Per-cell posterior-mean threshold, used to build an exclusion mask for cells
# whose threshold falls outside [-20, 65].
thresholds = pd.Series(fit['threshold'].mean(axis=0), index=cell_map.keys())
mask = (thresholds < -20) | (thresholds > 65)
mask = ~mask
# hack: the exclusion is currently disabled, so every cell is kept.
mask[:] = True

# 5 / 50 / 95 percentiles, i.e. a 90% interval around the median.
pct = [5, 50, 95]

x = fit['slope_pupil_delta'][:, mask]
e = fit['slope_pupil_delta_mean']
lb, m, ub = np.percentile(x, pct, axis=0)
elb, em, eub = np.percentile(e, pct, axis=0)
plot(axes[0], lb, m, ub, 'slope', elb, em, eub)

x = fit['threshold_pupil_delta'][:, mask]
e = fit['threshold_pupil_delta_mean']
lb, m, ub = np.percentile(x, pct, axis=0)
elb, em, eub = np.percentile(e, pct, axis=0)
plot(axes[1], lb, m, ub, 'threshold', elb, em, eub)

x = fit['sr_pupil'] - fit['sr']
x = x[:, mask]
# Population-level change in spontaneous rate, as a DIFFERENCE of alpha / beta
# between conditions. This used to be their ratio, which is on a different
# scale from the per-cell differences drawn on the same axis and, being
# positive whenever the hyperparameters are, always passed plot()'s
# "interval excludes 0" check.
# NOTE(review): confirm a difference (not a ratio) is the intended summary.
e = fit['sr_pupil_alpha']/fit['sr_pupil_beta'] - fit['sr_alpha']/fit['sr_beta']
lb, m, ub = np.percentile(x, pct, axis=0)
elb, em, eub = np.percentile(e, pct, axis=0)

plot(axes[2], lb, m, ub, 'SR', elb, em, eub)

In [None]:
# Wilcoxon signed-rank test on whatever `m` currently holds.
# NOTE(review): `m` is a leftover global. If the cells ran in order it is the
# per-cell median SR change from the three-panel figure -- confirm that is the
# intended input.
from scipy import stats
stats.wilcoxon(m)

In [None]:
# Pairwise relationships between the per-cell posterior-mean pupil effects.
x = fit['threshold_pupil_delta'].mean(axis=0)
y = fit['slope_pupil_delta'].mean(axis=0)
z = (fit['sr_pupil'] - fit['sr']).mean(axis=0)

f, ax = pl.subplots(1, 3, figsize=(15, 5))
ax[0].plot(x, y, 'k.')
ax[1].plot(x, z, 'k.')
ax[2].plot(y, z, 'k.')

# Label the axes so each panel can be read on its own.
ax[0].set(xlabel='Change in threshold', ylabel='Change in slope')
ax[1].set(xlabel='Change in threshold', ylabel='Change in SR')
ax[2].set(xlabel='Change in slope', ylabel='Change in SR')
f.tight_layout()



In [None]:
# Plot the cell whose posterior-mean threshold change is the most negative.
# Columns of the posterior array follow the order of `cells`, because cell_map
# numbers the cells by enumerating `cells`.
x = fit['threshold_pupil_delta'].mean(axis=0)
min_cell = cells[x.argmin()]

print(min_cell)
plot_fit(er, sr, fit, cell_map, min_cell, pl.gca())

In [None]:

# One panel per selected cell on a 5 x 5 grid; zip() stops after 25 axes.
# `mask` comes from the three-panel figure cell, which must have run first.
f, axes = pl.subplots(5, 5, figsize=(10, 10))
for cell, ax in zip(cells[mask], axes.ravel()):
    plot_fit(er, sr, fit, cell_map, cell, ax)
    # cell_map is 1-based; the posterior array columns are 0-based.
    i = cell_map[cell] - 1
    lb, m, ub = np.percentile(fit['threshold_pupil_delta'][:, i], [2.5, 50.0, 97.5])

    # Title: cell id plus the 2.5 / 50 / 97.5 percentiles of its threshold delta.
    ax.set_title(f'{cell}\n({lb:.0f} | {m:.0f} | {ub:.0f})')

f.tight_layout()

In [None]:
# Shape of the 'threshold' draws after averaging over axis 0.
fit['threshold'].mean(axis=0).shape

In [None]:
import seaborn as sns

    


In [None]:
# Draw the baseline (seagreen) and pupil (orchid) hockey-stick predictions from
# un-indexed parameters in `c`.
# NOTE(review): `level` is never assigned at notebook scope in the visible
# cells, and `c` only exists there as a string set in a later cell; both are
# locals of plot_fit. 'sr_pupil_delta' is also not a key any other visible cell
# reads. This looks like a leftover from an earlier model version -- confirm
# whether it can be removed.
pred = c['slope'] * (level - c['threshold']) + c['sr']
pred[level <= c['threshold']] = c['sr']
pred = np.clip(pred, 0, np.inf)
ax.plot(level, pred, color='seagreen')

pred = (c['slope'] + c['slope_pupil_delta']) * (level - (c['threshold'] + c['threshold_pupil_delta'])) + (c['sr'] + c['sr_pupil_delta'])
pred[level <= c['threshold'] + c['threshold_pupil_delta']] = (c['sr'] + c['sr_pupil_delta'])
pred = np.clip(pred, 0, np.inf)
ax.plot(level, pred, color='orchid')

In [None]:
# Trace plot of every parameter in the fit. Use the `az` alias from the first
# cell: the bare name `arviz` is only imported by a later cell, so
# `arviz.plot_trace` raised NameError on a fresh kernel.
az.plot_trace(fit);

In [None]:
# Shape of the 'slope' draws. Other cells index such arrays as [:, cell].
fit['slope'].shape

In [None]:
# Inspect a single, hand-picked cell.
cell = 'TAR010c-15-4'
# `e` and `s` are kept because a later cell displays `e`.
e = er.loc[cell].reset_index()
s = sr.loc[cell].reset_index()
display(fit)

ax = pl.gca()
# plot_fit takes (er, sr, fit, cell_map, cell, ax) and draws the raw data
# itself, so the old `plot_raw_data(e, s, ax); plot_fit(fit, ax)` pair is
# replaced by one call. The two-argument call raised TypeError.
plot_fit(er, sr, fit, cell_map, cell, ax)

In [None]:
# NOTE(review): arviz is already imported as `az` in the first cell. This second
# import binds the full name `arviz`, which other cells calling
# `arviz.plot_trace` rely on.
import arviz
arviz.plot_trace(fit)

In [None]:
# Call fit_data once per cell on that cell's evoked and spontaneous rows and
# keep the results keyed by cell id.
# NOTE(review): `fit_data` is not defined in the visible cells; this cell fails
# on a fresh kernel unless it is defined elsewhere.
fits = {}
for cell in cells:
    e = er.loc[cell].reset_index()
    s = sr.loc[cell].reset_index()
    fits[cell] = fit_data(e, s)

In [None]:
# Inspect the per-cell fit of one cell.
# NOTE(review): this rebinds the global `fit`, which earlier cells use for the
# all-cells fit loaded from fit.pkl; re-running those cells afterwards will see
# this per-cell object instead.
fit = fits['BOL006b-18-1']
print(fit)
arviz.plot_trace(fit);

In [None]:
# Raw data plus per-cell fit for up to 81 cells on a 9 x 9 grid.
f, axes = pl.subplots(9, 9, figsize=(20, 20))

for cell, ax in zip(cells, axes.ravel()):
    try:
        e = er.loc[cell].reset_index()
        s = sr.loc[cell].reset_index()
        # NOTE(review): this rebinds the global `fit` on every iteration.
        fit = fits[cell]
        plot_raw_data(e, s, ax)
        # NOTE(review): the plot_fit defined in this notebook takes
        # (er, sr, fit, cell_map, cell, ax); this two-argument call presumably
        # targets an older helper -- confirm which one is intended.
        plot_fit(fit, ax)
        ax.set_title(cell)
    except Exception as exc:
        # Keep drawing the remaining panels, but report what went wrong. The
        # previous bare `except: pass` hid every failure and also swallowed
        # KeyboardInterrupt.
        print(f'{cell}: not plotted ({type(exc).__name__}: {exc})')

f.tight_layout()

In [None]:
    
# Call fit_cell for up to 36 cells, one axes each, and keep the results keyed
# by cell id. This replaces any `fits` dict built earlier.
# NOTE(review): `fit_cell` is not defined in the visible cells, and the meaning
# of its second argument (1) cannot be determined from here.
f, axes = pl.subplots(6, 6, figsize=(15, 15))
#f, axes = pl.subplots(5, 5, figsize=(10, 10), squeeze=False)

fits = {}
for cell, ax in zip(cells, axes.ravel()):
    fits[cell] = fit_cell(cell, 1, ax)
    
f.tight_layout()

In [None]:
# Display `e`, the evoked-rate rows left behind by whichever cell assigned it last.
e

In [None]:
# 2nd and 97th percentiles of the threshold draws for the first cell.
# `fits` is keyed by cell id, so the old `fits[0]` raised KeyError.
# NOTE(review): [2, 97] is asymmetric; other cells use [2.5, 97.5] -- confirm.
np.percentile(fits[cells[0]]['threshold'], [2, 97])

In [None]:
def get_trace(fit, variable):
    """Return the permuted draws of ``variable`` with at least two dimensions.

    ``fit.extract(variable, permuted=True)`` must return a mapping containing
    ``variable``. A 1-D result gains a trailing axis of length one, so a
    scalar parameter comes back with shape (n, 1). Arrays of any other rank
    are returned unchanged.
    """
    trace = fit.extract(variable, permuted=True)[variable]
    return trace[:, np.newaxis] if trace.ndim == 1 else trace

# Pull the raw parameter draws out of `fit` and assemble per-cell rate-level
# predictions for both pupil conditions.
# NOTE(review): the parameter names requested here (e.g. 'sr_cell',
# 'slope_mean', 'threshold_cell_sd') differ from the ones earlier cells read
# from `fit` (e.g. 'sr', 'slope', 'slope_pupil_delta'); this cell may target a
# different model version -- confirm.

#sr_mean = get_trace(fit, 'sr_mean')
sr_pupil = get_trace(fit, 'sr_pupil')
sr_cell = get_trace(fit, 'sr_cell')
sr_cell_pupil = get_trace(fit, 'sr_cell_pupil')
#sr_cell_sd = get_trace(fit, 'sr_cell_sd')
sr_cell_pupil_sd = get_trace(fit, 'sr_cell_pupil_sd')

slope_mean = get_trace(fit, 'slope_mean')
slope_pupil = get_trace(fit, 'slope_pupil')
slope_cell = get_trace(fit, 'slope_cell')
slope_cell_pupil = get_trace(fit, 'slope_cell_pupil')
slope_cell_sd = get_trace(fit, 'slope_cell_sd')
slope_cell_pupil_sd = get_trace(fit, 'slope_cell_pupil_sd')

threshold_mean = get_trace(fit, 'threshold_mean')
threshold_pupil = get_trace(fit, 'threshold_pupil')
threshold_cell = get_trace(fit, 'threshold_cell')
threshold_cell_pupil = get_trace(fit, 'threshold_cell_pupil')
threshold_cell_sd = get_trace(fit, 'threshold_cell_sd')
threshold_cell_pupil_sd = get_trace(fit, 'threshold_cell_pupil_sd')

# NOTE(review): this rebinds `sr`, which the data-loading cell defined as the
# spontaneous-rate DataFrame. Any cell that passes `sr` to plot_fit will receive
# this array if it is run after this point.
sr = sr_cell
# From here on `sr_pupil`, `slope_pupil` and `threshold_pupil` hold the value
# in the pupil condition (baseline + population effect + scaled per-cell
# effect), not the raw effect trace extracted above.
sr_pupil = sr + (sr_pupil + sr_cell_pupil * sr_cell_pupil_sd)

# Per-cell value = population mean + per-cell offset scaled by its sd.
slope = slope_mean + (slope_cell * slope_cell_sd) 
slope_pupil = slope + (slope_pupil + slope_cell_pupil * slope_cell_pupil_sd)

threshold = threshold_mean + (threshold_cell * threshold_cell_sd) 
threshold_pupil = threshold + (threshold_pupil + threshold_cell_pupil * threshold_cell_pupil_sd)

# Plotting ranges; spike_start / spike_stop are used by a later cell for the
# histogram bins and image extent.
level_start, level_stop = -20, 80
spike_start, spike_stop = 0, 110

# Levels with shape (100, 1, 1) so they broadcast against the trace arrays.
# NOTE(review): presumably the result is (level, draw, cell) -- confirm.
x = np.arange(level_start, level_stop)[:, np.newaxis, np.newaxis]

# Linear response, floored at the spontaneous rate of the same condition.
y = (x - threshold) * slope
y = np.clip(y, sr, np.inf)

y_pupil = (x - threshold_pupil) * slope_pupil
y_pupil = np.clip(y_pupil, sr_pupil, np.inf)

In [None]:
def get_density(y, bins):
    """Histogram ``y`` along axis 1 and return the per-bin densities.

    Every 1-D slice of ``y`` taken along axis 1 is passed to ``np.histogram``
    with ``density=True``. Axis 1 of the result therefore has one entry per
    bin instead of one per sample.
    """
    def slice_density(values):
        density, _edges = np.histogram(values, bins, density=True)
        return density

    return np.apply_along_axis(slice_density, 1, y)

In [None]:
# Display `sr`. If the trace-extraction cell has run, this is the 'sr_cell'
# draw array rather than the spontaneous-rate DataFrame from the loading cell.
sr

In [None]:
# Posterior predictive density of the rate-level function for a few cells,
# with the raw data and the spontaneous rate drawn on top.
# NOTE(review): `cellid`, `df` and `spont` are not defined in the visible cells;
# this cell fails on a fresh kernel unless they are defined elsewhere.
import itertools

cell_indices = [0, 1, 2, 3, 4]
f, ax = pl.subplots(3, 3, figsize=(10, 10))
# Flatten the 3 x 3 axes array row by row; only the first five axes are used
# because zip() stops at the end of cell_indices.
ax_iter = itertools.chain(*ax)
bins = np.arange(spike_start, spike_stop, 0.1)
pupil = 'small'

# The loop variable reuses the name `ax`, so afterwards `ax` is the last single
# axes rather than the array returned by subplots.
for ax, cell_index in zip(ax_iter, cell_indices):
    cell_name = cellid.categories[cell_index]
    # NOTE(review): the raw tables are indexed with pupil codes 1 and 2 here,
    # whereas the loading cell shifts pupil to start at 0 -- confirm.
    if pupil == 'small':
        density = get_density(y[..., cell_index], bins)
        raw = df.loc[cell_name].loc[1]
        s_raw = spont.loc[cell_name].loc[1]
        s_fit = sr[:, cell_index].mean()
    else:
        density = get_density(y_pupil[..., cell_index], bins)
        raw = df.loc[cell_name].loc[2]
        s_raw = spont.loc[cell_name].loc[2]
        s_fit = sr_pupil[:, cell_index].mean()
        
    ax.plot(raw, 'o', color='white')
    ax.set_title(cell_name)
    # Transpose so that level runs along x and spike rate along y.
    ax.imshow(density.T, origin='lower', aspect='auto', 
              extent=(level_start, level_stop, spike_start, spike_stop))
    # Black: fitted mean threshold and spontaneous rate; white: measured rate.
    ax.axvline(threshold[:, cell_index].mean(), color='black')
    ax.axhline(s_raw, color='white')
    ax.axhline(s_fit, color='black')
    #ax.axhline(spont.loc[cell_name], color='white')


In [None]:
# Overlay posterior draws of the fitted line on the raw data of one cell, for
# both pupil conditions.
# NOTE(review): `c` is assigned but not used below; the cell is selected by the
# hard-coded index `i = 56`. `d` is not defined in the visible cells.
c = 'BOL006b-21-1'
#c = 'TAR010c-24-2'
#c = 'gus021d-b1'
#i = cellid.categories.tolist().index(c)
i = 56

# range(i, i+1) visits exactly one cell.
for i_cell in range(i, i+1):
    pl.figure()

    for i_pupil in (0, 1):
        # The pupil term is multiplied by i_pupil, so it only contributes in
        # the second condition.
        sr_i = sr[:, i_cell] + sr_pupil[:, i_cell] * i_pupil
        threshold_i = threshold[:, i_cell] + threshold_pupil[:, i_cell] * i_pupil
        slope_i = slope[:, i_cell] + slope_pupil[:, i_cell] * i_pupil
        # Intercept chosen so that the line equals sr_i at the threshold.
        intercept_i = sr_i - threshold_i * slope_i
        
        print(f'{sr_i.mean():.2f} {intercept_i.mean():.2f} {slope_i.mean():.2f}')
        print(f'{sr_i.min():.2f} {intercept_i.min():.2f} {slope_i.min():.2f}')

        # Levels as a column so they broadcast against the draws; the line is
        # floored at the spontaneous rate.
        x = np.arange(-20, 100)[..., np.newaxis]
        y = np.clip(slope_i * x + intercept_i, sr_i, np.inf)

        # Raw data for this cell and condition (pupil is coded i_pupil + 1 in `d`).
        m = (d['cell_code'] == i_cell) & (d['pupil'] == (i_pupil + 1))
        subset = d.loc[m]
        p, = pl.plot(subset['level'], subset['spikes'], 'o')

        # Every 100th column, nearly transparent, in the colour of the raw points.
        pl.plot(x, y[:, ::100], '-', alpha=0.005, color=p.get_color());

In [None]:
# Trace plots restricted to six named parameters.
# NOTE(review): these names ('sr_mean', 'threshold_mean', ...) differ from the
# ones the first plotting cells read from `fit` -- confirm they exist in the
# fit that is loaded at this point.
#az.plot_forest(fit, kind='ridgeplot', var_names=['sp_mean'], combined=True, figsize=(4, 4))
az.plot_trace(fit, var_names=('sr_mean', 'sr_pupil', 'threshold_mean', 'threshold_pupil', 'slope_mean', 'slope_pupil'))

In [None]:
# Summary statistics for six parameters, shown as a labelled DataFrame built
# from the 'summary', 'summary_rownames' and 'summary_colnames' entries.
# NOTE(review): these parameter names are not read by any other visible cell --
# confirm they exist in the current model.
x = fit.summary(['tp_mean', 'tp_sd', 'srp_mean', 'srp_sd', 'sp_mean', 'sp_sd'])
pd.DataFrame(x['summary'], index=x['summary_rownames'], columns=x['summary_colnames'])

In [None]:
# Median and 95% interval of each column of the 'threshold_pupil' trace.
threshold_pupil = get_trace(fit, 'threshold_pupil')
lb, median, ub = np.percentile(threshold_pupil, [2.5, 50.0, 97.5], axis=0)
i = np.arange(median.shape[0])

ax = pl.gca()
# errorbar takes the lower and upper distances from the median as a 2 x N array.
ax.errorbar(i, median, yerr=np.vstack((median - lb, ub - median)))
ax.axhline(0, color='k')
ax.set_ylabel('Threshold (dB SPL)')
ax.xaxis.set_ticks([])
ax.grid()

# Positions of the smallest and largest median, then the interval at the smallest.
i, j = median.argmin(), median.argmax()
print(i, j)
lb[i], ub[i]

In [None]:
# Median and 95% interval of each column of the 'sr_pupil' trace.
sr_pupil = get_trace(fit, 'sr_pupil')
lb, median, ub = np.percentile(sr_pupil, [2.5, 50.0, 97.5], axis=0)
i = np.arange(median.shape[0])

ax = pl.gca()
# errorbar takes the lower and upper distances from the median as a 2 x N array.
ax.errorbar(i, median, yerr=np.vstack((median - lb, ub - median)))
ax.axhline(0, color='k')
ax.set_ylabel('SR (sp/sec)')
ax.xaxis.set_ticks([])
ax.grid()

In [None]:
# Median and 95% interval of each column of the 'slope_pupil' trace.
slope_pupil = get_trace(fit, 'slope_pupil')
lb, median, ub = np.percentile(slope_pupil, [2.5, 50.0, 97.5], axis=0)
i = np.arange(median.shape[0])

ax = pl.gca()
# errorbar takes the lower and upper distances from the median as a 2 x N array.
ax.errorbar(i, median, yerr=np.vstack((median - lb, ub - median)))
ax.axhline(0, color='k')
ax.set_ylabel('Slope')
ax.xaxis.set_ticks([])
ax.grid()

In [None]:
# Same figure as the earlier 'sr_pupil' error-bar cell: median and 95%
# interval of each column of the 'sr_pupil' trace.
sr_pupil = get_trace(fit, 'sr_pupil')
lb, median, ub = np.percentile(sr_pupil, [2.5, 50.0, 97.5], axis=0)
i = np.arange(median.shape[0])

ax = pl.gca()
# errorbar takes the lower and upper distances from the median as a 2 x N array.
ax.errorbar(i, median, yerr=np.vstack((median - lb, ub - median)))
ax.axhline(0, color='k')
ax.set_ylabel('SR (sp/sec)')
ax.xaxis.set_ticks([])
ax.grid()

In [None]:
# Display whatever `fit` currently refers to. Earlier cells rebind it to a
# per-cell fit, so this is not necessarily the object loaded from fit.pkl.
fit