In [None]:
import pandas as pd
import seaborn as sns
from scipy import stats
import pylab as pl
import numpy as np
from lbhb.psychometric import CachedStanModel
import arviz as az
%matplotlib inline

In [None]:
# Load the wide-format rate-level table and strip spaces from the column names.
rlf = pd.read_csv('rate_level_functions_for_bburan.csv')
rlf.columns = [s.replace(' ', '') for s in rlf.columns]
# Stub names of the repeated column groups; columns are expected to be named `<stub>_<idx>`.
cols = ['pupil', 'level', 'rlf_count', 'rlf_time', 'spont_count', 'spont_time']
# Reshape to long format (one row per cellid/idx) and drop rows with any missing value.
rlf = pd.wide_to_long(rlf, cols, 'cellid', 'idx', sep='_').dropna()
# Shift the pupil code down by one -- presumably 1/2 -> 0/1 so it matches the 0/1 keys
# used by `plot_raw_data`; TODO confirm against the CSV.
rlf['pupil'] -= 1

# Evoked data: spike count and duration indexed by (cellid, pupil, level).
# `verify_integrity=True` makes pandas raise if that index contains duplicates.
er = rlf.reset_index().set_index(['cellid', 'pupil', 'level'], verify_integrity=True)[['rlf_count', 'rlf_time']]
# Spontaneous data: first spont count/time found for each (cellid, pupil) pair.
# assumes the spont values are repeated on every level row -- TODO confirm
sr = rlf.groupby(['cellid', 'pupil'])[['spont_count', 'spont_time']].first()

# Histogram of spontaneous rate (count / time), one call per pupil group.
pl.figure()
sr.eval('spont_count/spont_time').groupby('pupil').hist()
# Evoked rate averaged over all cells and pupil conditions, as a function of level.
pl.figure()
er.eval('rlf_count/rlf_time').groupby('level').mean().plot()

In [None]:
# Unique cell identifiers, taken from the `cellid` level of the evoked-rate index.
cells = er.index.get_level_values('cellid').unique()

In [None]:
def plot_raw_data(e, s, ax):
    """Scatter the measured spontaneous and evoked rates of one cell.

    Parameters
    ----------
    e : DataFrame
        Evoked data with columns `pupil`, `level`, `rlf_count` and `rlf_time`.
    s : DataFrame
        Spontaneous data with columns `pupil`, `spont_count` and `spont_time`.
    ax : matplotlib Axes
        Axes that receives the scatter plot.

    Spontaneous points are drawn at level 0. Marker area is proportional to
    the measurement duration (scaled so the mean marker has size 100) and
    colour encodes the pupil code (0 -> seagreen, 1 -> orchid).
    """
    spont_rate = s.eval('spont_count/spont_time').tolist()
    evoked_rate = e.eval('rlf_count/rlf_time').tolist()

    # One x = 0 entry per spontaneous row, so x and y stay the same length
    # even when a cell does not have exactly two spontaneous measurements.
    x = [0] * len(spont_rate) + e['level'].tolist()
    y = spont_rate + evoked_rate

    size = np.array(s['spont_time'].tolist() + e['rlf_time'].tolist())
    size = 100 * size / size.mean()

    colors = {0: 'seagreen', 1: 'orchid'}
    pupil_codes = s['pupil'].tolist() + e['pupil'].tolist()
    color = [colors[code] for code in pupil_codes]

    ax.scatter(x, y, size, color, alpha=0.5)

def _hockey_stick(level, slope, threshold, sr):
    """Piecewise-linear rate: `sr` at or below `threshold`, linear with `slope` above it, floored at 0."""
    pred = np.where(level <= threshold, sr, slope * (level - threshold) + sr)
    # A firing rate cannot be negative, so floor the prediction at zero.
    return np.clip(pred, 0, np.inf)


def plot_fit(fit, ax):
    """Plot the posterior-mean hockey-stick curves of a fit.

    Parameters
    ----------
    fit : object
        Must provide `to_dataframe(permuted=True)` returning a frame with the
        columns `slope`, `threshold`, `sr`, `slope_pupil_delta`,
        `threshold_pupil_delta` and `sr_pupil_delta`.
    ax : matplotlib Axes
        Axes that receives the two curves.

    The seagreen curve uses the base parameters; the orchid curve adds the
    `*_pupil_delta` terms. Both curves are evaluated at integer levels from
    -10 to 79 and both are floored at zero (previously only the orchid curve
    was).
    """
    c = fit.to_dataframe(permuted=True).mean()
    level = np.arange(-10, 80)

    base = _hockey_stick(level, c['slope'], c['threshold'], c['sr'])
    ax.plot(level, base, color='seagreen')

    shifted = _hockey_stick(
        level,
        c['slope'] + c['slope_pupil_delta'],
        c['threshold'] + c['threshold_pupil_delta'],
        c['sr'] + c['sr_pupil_delta'],
    )
    ax.plot(level, shifted, color='orchid')
    
def fit_data(e, s):
    """Sample the global Stan `model` for one cell.

    Parameters
    ----------
    e : DataFrame
        Evoked data with columns `rlf_time`, `rlf_count`, `level` and `pupil`.
    s : DataFrame
        Spontaneous data with columns `spont_count` and `spont_time`.

    Returns
    -------
    Whatever `model.sampling` returns for the assembled data dictionary
    (20000 iterations, adapt_delta 0.999, max_treedepth 25).
    """
    def as_c_int(series):
        # Counts and pupil codes are passed to the model as C-int arrays.
        return series.astype('i').values

    data = dict(
        n=len(e),
        evoked_time=e['rlf_time'].values,
        evoked_count=as_c_int(e['rlf_count']),
        evoked_level=e['level'].values,
        pupil=as_c_int(e['pupil']),
        sr_count=as_c_int(s['spont_count']),
        sr_time=s['spont_time'].values,
    )
    sampler_control = {'adapt_delta': 0.999, 'max_treedepth': 25}
    return model.sampling(data, iter=20000, control=sampler_control)


In [None]:
# Build the Stan model from `hockey_stick.stan`. `fit_data` looks up this global
# `model` at call time, so this cell must run before any call to `fit_data`.
model = CachedStanModel('hockey_stick.stan')

In [None]:
# Fit one hand-picked cell and overlay the fitted curves on its raw data.
#cell = cells[0]
cell = 'TAR010c-15-4'
# Rows of this cell, with the remaining index levels turned back into columns.
e = er.loc[cell].reset_index()
s = sr.loc[cell].reset_index()
fit = fit_data(e, s)
display(fit)

ax = pl.gca()
plot_raw_data(e, s, ax)
plot_fit(fit, ax)

In [None]:
# Trace plots for the fit currently bound to `fit`.
# NOTE(review): arviz is already imported as `az` in the first cell; consider using
# that alias instead of a second, mid-notebook import.
import arviz
arviz.plot_trace(fit)

In [None]:
# Fit every cell and keep the results keyed by cell id.
# Each call runs `fit_data` (20000 iterations), so this loop is expensive.
# The loop variables `cell`, `e` and `s` stay bound to the last cell afterwards.
fits = {}
for cell in cells:
    e = er.loc[cell].reset_index()
    s = sr.loc[cell].reset_index()
    fits[cell] = fit_data(e, s)

In [None]:
# Inspect one fit from the batch: text summary plus trace plots.
fit = fits['BOL006b-18-1']
print(fit)
# Use the `az` alias from the import cell so this cell does not depend on a
# mid-notebook `import arviz` having been executed first.
az.plot_trace(fit);

In [None]:
# Grid of panels, one per cell: raw data with the fitted curves on top.
# `zip` stops at the shorter input, so at most 81 cells are shown.
f, axes = pl.subplots(9, 9, figsize=(20, 20))

for cell, ax in zip(cells, axes.ravel()):
    try:
        e = er.loc[cell].reset_index()
        s = sr.loc[cell].reset_index()
        fit = fits[cell]
        plot_raw_data(e, s, ax)
        plot_fit(fit, ax)
        ax.set_title(cell)
    except Exception as exc:
        # Keep going so one bad cell does not abort the figure, but report the
        # failure instead of hiding it. Catching `Exception` rather than using a
        # bare `except` lets KeyboardInterrupt still stop the loop.
        print(f'{cell}: skipped ({exc!r})')

f.tight_layout()

In [None]:
    
# NOTE(review): `fit_cell` is not defined anywhere in the visible notebook, so this cell
# raises NameError on a fresh kernel unless it is defined elsewhere -- confirm or remove.
# NOTE(review): `fits = {}` below discards any fits already stored in `fits`.
f, axes = pl.subplots(6, 6, figsize=(15, 15))
#f, axes = pl.subplots(5, 5, figsize=(10, 10), squeeze=False)

fits = {}
# `zip` stops at the shorter input, so at most 36 cells are processed.
for cell, ax in zip(cells, axes.ravel()):
    fits[cell] = fit_cell(cell, 1, ax)
    
f.tight_layout()

In [None]:
# Display the evoked-data frame `e` left bound by the most recently executed cell.
e

In [None]:
# 95% interval of the threshold samples for the first cell.
# `fits` is keyed by cell id, so look the fit up by `cells[0]` rather than by the integer 0,
# and use the symmetric 2.5 / 97.5 percentiles, as the interval plots further down do.
np.percentile(fits[cells[0]]['threshold'], [2.5, 97.5])

In [None]:
def get_trace(fit, variable):
    """Return the extracted samples of `variable` with at least two dimensions.

    The samples come from `fit.extract(variable, permuted=True)`. A
    one-dimensional result gets a trailing axis of length one; results of any
    other dimensionality are returned unchanged.
    """
    samples = fit.extract(variable, permuted=True)[variable]
    return samples[:, np.newaxis] if samples.ndim == 1 else samples

# Pull the posterior samples of each parameter out of the fit currently bound to `fit`.
# NOTE(review): the parameter names used here (`*_cell`, `*_cell_sd`, `*_mean`, ...) differ from
# the ones `plot_fit` reads (`slope`, `threshold`, `sr`, `*_pupil_delta`), so `fit` presumably
# comes from a different (population + per-cell) model -- confirm where it is created.
#sr_mean = get_trace(fit, 'sr_mean')
sr_pupil = get_trace(fit, 'sr_pupil')
sr_cell = get_trace(fit, 'sr_cell')
sr_cell_pupil = get_trace(fit, 'sr_cell_pupil')
#sr_cell_sd = get_trace(fit, 'sr_cell_sd')
sr_cell_pupil_sd = get_trace(fit, 'sr_cell_pupil_sd')

slope_mean = get_trace(fit, 'slope_mean')
slope_pupil = get_trace(fit, 'slope_pupil')
slope_cell = get_trace(fit, 'slope_cell')
slope_cell_pupil = get_trace(fit, 'slope_cell_pupil')
slope_cell_sd = get_trace(fit, 'slope_cell_sd')
slope_cell_pupil_sd = get_trace(fit, 'slope_cell_pupil_sd')

threshold_mean = get_trace(fit, 'threshold_mean')
threshold_pupil = get_trace(fit, 'threshold_pupil')
threshold_cell = get_trace(fit, 'threshold_cell')
threshold_cell_pupil = get_trace(fit, 'threshold_cell_pupil')
threshold_cell_sd = get_trace(fit, 'threshold_cell_sd')
threshold_cell_pupil_sd = get_trace(fit, 'threshold_cell_pupil_sd')

# Combine the pieces into per-cell parameters. From here on each `*_pupil` name holds
# baseline + pupil effect, no longer the pupil effect alone.
# NOTE(review): `sr` is rebound here and no longer refers to the spontaneous-rate DataFrame
# built in the data-loading cell; cells that call `sr.loc[...]` will fail if re-run afterwards.
sr = sr_cell
sr_pupil = sr + (sr_pupil + sr_cell_pupil * sr_cell_pupil_sd)

# Per-cell value = mean + standardised cell offset * sd (this reads like a non-centred
# parameterisation; confirm against the Stan model).
slope = slope_mean + (slope_cell * slope_cell_sd) 
slope_pupil = slope + (slope_pupil + slope_cell_pupil * slope_cell_pupil_sd)

threshold = threshold_mean + (threshold_cell * threshold_cell_sd) 
threshold_pupil = threshold + (threshold_pupil + threshold_cell_pupil * threshold_cell_pupil_sd)

# Plot ranges: sound level on x, spike rate on y.
level_start, level_stop = -20, 80
spike_start, spike_stop = 0, 110

# Two trailing axes so the level grid broadcasts against the sample arrays above.
x = np.arange(level_start, level_stop)[:, np.newaxis, np.newaxis]

# Predicted rate per level and sample, floored at the spontaneous rate.
# NOTE(review): unlike `plot_fit` and the posterior-line cell further down, this expression
# has no `+ sr` term before clipping -- confirm which form matches the model.
y = (x - threshold) * slope
y = np.clip(y, sr, np.inf)

y_pupil = (x - threshold_pupil) * slope_pupil
y_pupil = np.clip(y_pupil, sr_pupil, np.inf)

In [None]:
def get_density(y, bins):
    """Histogram `y` along axis 1 and return the normalised densities.

    Every 1-D slice of `y` taken along axis 1 is passed to `np.histogram`
    with the given `bins` and `density=True`; axis 1 of the result therefore
    holds one density value per bin.
    """
    def density_1d(values):
        density, _edges = np.histogram(values, bins, density=True)
        return density

    return np.apply_along_axis(density_1d, 1, y)

In [None]:
# Display whatever `sr` is currently bound to (the name is rebound more than once in this notebook).
sr

In [None]:
# Posterior density of the predicted rate-level curve for a few cells, with raw data on top.
# NOTE(review): `cellid`, `df` and `spont` are not defined in the visible notebook -- confirm
# where they come from before relying on this cell.
import itertools

cell_indices = [0, 1, 2, 3, 4]
f, ax = pl.subplots(3, 3, figsize=(10, 10))
# Flatten the 3x3 axes array row by row.
ax_iter = itertools.chain(*ax)
# Spike-rate bins used for the density image.
bins = np.arange(spike_start, spike_stop, 0.1)
# Choose which condition to draw; any value other than 'small' takes the `else` branch.
pupil = 'small'

# Note: the loop variable `ax` rebinds the name that held the axes array.
# Only five of the nine panels are filled because `zip` stops at `cell_indices`.
for ax, cell_index in zip(ax_iter, cell_indices):
    cell_name = cellid.categories[cell_index]
    if pupil == 'small':
        density = get_density(y[..., cell_index], bins)
        raw = df.loc[cell_name].loc[1]
        s_raw = spont.loc[cell_name].loc[1]
        s_fit = sr[:, cell_index].mean()
    else:
        density = get_density(y_pupil[..., cell_index], bins)
        raw = df.loc[cell_name].loc[2]
        s_raw = spont.loc[cell_name].loc[2]
        s_fit = sr_pupil[:, cell_index].mean()
        
    ax.plot(raw, 'o', color='white')
    ax.set_title(cell_name)
    # Density image, transposed before display, mapped onto the level / spike-rate extent.
    ax.imshow(density.T, origin='lower', aspect='auto', 
              extent=(level_start, level_stop, spike_start, spike_stop))
    # Black lines: posterior-mean threshold and fitted spontaneous rate; white line: `s_raw`.
    ax.axvline(threshold[:, cell_index].mean(), color='black')
    ax.axhline(s_raw, color='white')
    ax.axhline(s_fit, color='black')
    #ax.axhline(spont.loc[cell_name], color='white')


In [None]:
# Overlay individual posterior-sample curves on the raw data of one cell, per pupil condition.
# NOTE(review): `d` is not defined in the visible notebook, and `c` is assigned but never used
# below (the index lookup that used it is commented out in favour of the hard-coded `i = 56`).
c = 'BOL006b-21-1'
#c = 'TAR010c-24-2'
#c = 'gus021d-b1'
#i = cellid.categories.tolist().index(c)
i = 56

for i_cell in range(i, i+1):
    pl.figure()

    for i_pupil in (0, 1):
        # Baseline parameter plus the pupil term, switched on when i_pupil == 1.
        # NOTE(review): if the trace-extraction cell has run, the `*_pupil` arrays already hold
        # baseline + pupil effect, so adding them to the baseline here would double-count --
        # confirm which definition is intended.
        sr_i = sr[:, i_cell] + sr_pupil[:, i_cell] * i_pupil
        threshold_i = threshold[:, i_cell] + threshold_pupil[:, i_cell] * i_pupil
        slope_i = slope[:, i_cell] + slope_pupil[:, i_cell] * i_pupil
        # Intercept chosen so that the line equals `sr_i` at `threshold_i`.
        intercept_i = sr_i - threshold_i * slope_i
        
        # Mean and minimum of spontaneous rate, intercept and slope across samples.
        print(f'{sr_i.mean():.2f} {intercept_i.mean():.2f} {slope_i.mean():.2f}')
        print(f'{sr_i.min():.2f} {intercept_i.min():.2f} {slope_i.min():.2f}')

        # One predicted curve per sample, floored at that sample's spontaneous rate.
        # Note: this rebinds the notebook-level names `x` and `y`.
        x = np.arange(-20, 100)[..., np.newaxis]
        y = np.clip(slope_i * x + intercept_i, sr_i, np.inf)

        # Raw points for this cell and pupil condition; `d['pupil']` is compared against 1 / 2.
        m = (d['cell_code'] == i_cell) & (d['pupil'] == (i_pupil + 1))
        subset = d.loc[m]
        p, = pl.plot(subset['level'], subset['spikes'], 'o')

        # Draw every 100th sample curve, nearly transparent, in the colour of the raw points.
        pl.plot(x, y[:, ::100], '-', alpha=0.005, color=p.get_color());

In [None]:
# Trace plots restricted to the mean and pupil-effect parameters of `fit`.
#az.plot_forest(fit, kind='ridgeplot', var_names=['sp_mean'], combined=True, figsize=(4, 4))
az.plot_trace(fit, var_names=('sr_mean', 'sr_pupil', 'threshold_mean', 'threshold_pupil', 'slope_mean', 'slope_pupil'))

In [None]:
# Summary of selected parameters, shown as a DataFrame.
# NOTE(review): these parameter names (`tp_*`, `srp_*`, `sp_*`) do not appear in the other
# cells of this notebook -- confirm they exist in the model behind `fit`.
# Note: this rebinds `x`, which other cells use for the sound-level grid.
x = fit.summary(['tp_mean', 'tp_sd', 'srp_mean', 'srp_sd', 'sp_mean', 'sp_sd'])
pd.DataFrame(x['summary'], index=x['summary_rownames'], columns=x['summary_colnames'])

In [None]:
# Median and 95% interval of `threshold_pupil`, computed over the first (sample) axis.
# Re-extracting the trace rebinds `threshold_pupil` to the raw samples from `fit`.
threshold_pupil = get_trace(fit, 'threshold_pupil')
lb, median, ub = np.percentile(threshold_pupil, [2.5, 50.0, 97.5], 0)
i = np.arange(len(median))
ax = pl.gca()
# `errorbar` expects (lower, upper) distances from the median, both non-negative.
ax.errorbar(i, median, yerr=np.vstack((abs(lb-median), ub-median)))
ax.axhline(0, color='k')
ax.set_ylabel('Threshold (dB SPL)')
ax.xaxis.set_ticks([])
ax.grid()

# Positions of the smallest and largest median, then the interval at the smallest one.
# Note: `i` is reused here, first as the x positions and now as a single index.
i, j = median.argmin(), median.argmax()
print(i, j)
lb[i], ub[i]

In [None]:
# Median and 95% interval of `sr_pupil`, computed over the first (sample) axis.
# Re-extracting the trace rebinds `sr_pupil` to the raw samples from `fit`.
sr_pupil = get_trace(fit, 'sr_pupil')
lb, median, ub = np.percentile(sr_pupil, [2.5, 50.0, 97.5], 0)
i = np.arange(len(median))
ax = pl.gca()
# `errorbar` expects (lower, upper) distances from the median, both non-negative.
ax.errorbar(i, median, yerr=np.vstack((abs(lb-median), ub-median)))
ax.axhline(0, color='k')
ax.set_ylabel('SR (sp/sec)')
ax.xaxis.set_ticks([])
ax.grid()

In [None]:
# Median and 95% interval of `slope_pupil`, computed over the first (sample) axis.
# Re-extracting the trace rebinds `slope_pupil` to the raw samples from `fit`.
slope_pupil = get_trace(fit, 'slope_pupil')
lb, median, ub = np.percentile(slope_pupil, [2.5, 50.0, 97.5], 0)
i = np.arange(len(median))
ax = pl.gca()
# `errorbar` expects (lower, upper) distances from the median, both non-negative.
ax.errorbar(i, median, yerr=np.vstack((abs(lb-median), ub-median)))
ax.axhline(0, color='k')
ax.set_ylabel('Slope')
ax.xaxis.set_ticks([])
ax.grid()

In [None]:
# Median and 95% interval of `sr_pupil`, computed over the first (sample) axis.
# NOTE(review): this cell repeats the code of an earlier `sr_pupil` interval cell line for line;
# consider removing one copy or factoring the shared plotting code into a function.
sr_pupil = get_trace(fit, 'sr_pupil')
lb, median, ub = np.percentile(sr_pupil, [2.5, 50.0, 97.5], 0)
i = np.arange(len(median))
ax = pl.gca()
# `errorbar` expects (lower, upper) distances from the median, both non-negative.
ax.errorbar(i, median, yerr=np.vstack((abs(lb-median), ub-median)))
ax.axhline(0, color='k')
ax.set_ylabel('SR (sp/sec)')
ax.xaxis.set_ticks([])
ax.grid()

In [None]:
# Display the fit object currently bound to `fit`.
fit