In [None]:
import pandas as pd
import seaborn as sns
from scipy import stats
import pylab as pl
import numpy as np
from lbhb.psychometric import CachedStanModel
import arviz as az
%matplotlib inline

# Sound level (dB) at which each cell/pupil spontaneous rate is appended to its
# rate-level function as a pseudo-observation.
SPONT_LEVEL = -10

df = pd.read_csv('rate_level_functions_for_bburan.csv')
sig_cells = pd.read_csv('psth_sig_cellids.csv')['cellid'].tolist()
# Keep only the cells listed in psth_sig_cellids.csv. `isin` is vectorized and
# avoids scanning the list once per row.
mask = df['cellid'].isin(sig_cells)
df = df.loc[mask]
# Strip spaces from column names, presumably so they match the
# `<stub>_<idx>` pattern that wide_to_long expects below.
df = df.rename(columns=lambda x: x.replace(' ', ''))
df = pd.wide_to_long(df, ['pupil', 'level', 'spikes', 'spont'], 'cellid', 'idx', sep='_')
df = df.dropna()
# Spont is presumably repeated across levels for a cell/pupil pair; take the first.
spont = df.groupby(['cellid', 'pupil'])['spont'].first().sort_index()
df = (
    df.reset_index()
    .set_index(['cellid', 'pupil', 'level'], verify_integrity=True)['spikes']
    .sort_index()
)

# Append the spontaneous rate as an extra observation at SPONT_LEVEL.
s = spont.reset_index()
s['level'] = SPONT_LEVEL
s = s.set_index(['cellid', 'pupil', 'level'])['spont']
# verify_integrity guards against SPONT_LEVEL colliding with a real level.
df = pd.concat((df, s), verify_integrity=True).sort_index()
df.name = 'spikes'
df.name = 'spikes'

In [None]:
def sr_dist(sr, x_max=50, prior_alpha=0.5, prior_beta=0.1):
    """Plot the spontaneous-rate distribution with two gamma densities.

    Draws a normalised histogram of `sr`, overlays a gamma fitted to `sr`
    (location fixed at 0, black) and a reference gamma with shape
    `prior_alpha` and rate `prior_beta` (red), then prints the fitted shape
    (alpha) and rate (beta = 1 / scale).

    NOTE(review): the reference gamma presumably mirrors the SR prior in the
    Stan model -- confirm before changing the defaults.

    Parameters
    ----------
    sr : array-like
        Spontaneous rates (sp/sec).
    x_max : float
        Upper end of the grid on which the gamma densities are drawn.
    prior_alpha, prior_beta : float
        Shape and rate of the reference gamma.
    """
    x = np.arange(0, x_max, 0.1)
    a, loc, scale = stats.gamma.fit(sr, floc=0)

    pl.hist(sr, density=True, bins=100, label='data')
    pl.plot(x, stats.gamma.pdf(x, a, loc, scale), 'k-', label='gamma fit')
    # scipy parameterises the gamma by scale, i.e. 1 / rate.
    pl.plot(x, stats.gamma.pdf(x, prior_alpha, 0, 1 / prior_beta), 'r-', label='reference')
    pl.xlabel('SR (sp/sec)')
    pl.ylabel('Density')
    pl.legend()
    print(f'alpha = {a:.2f}, beta = {1/scale:.2f}')


sr_dist(spont)

In [None]:
# Mean response at each sound level, pooled over cells and pupil states
# (includes the spontaneous-rate pseudo-level appended during loading).
ax = df.groupby(level='level').mean().plot()
ax.set_xlabel('Level (dB SPL)')
ax.set_ylabel('Mean spikes');

In [None]:
# Seed for the sampler so the posterior draws are reproducible.
# NOTE(review): assumes CachedStanModel.sampling forwards `seed` to the
# underlying PyStan StanModel.sampling -- confirm.
SEED = 20240101

model = CachedStanModel('hierarchial_hockey_stick_V3f.stan')

# Fit all cells jointly.
d = df.reset_index()

cellid = pd.Categorical(d['cellid'])
# Category codes are 0-based; shift by one, presumably because the Stan model
# indexes cells 1..n_cells.
cell = cellid.codes + 1
n_cells = cell.max()
d['cell_code'] = cellid.codes

# Pupil is coded 1 (small) / 2 (large) in the data; the model expects
# 0 for small and 1 for large.
pupil_code = d['pupil'].values.astype('i') - 1
assert set(np.unique(pupil_code)) <= {0, 1}, 'unexpected pupil codes'

data = {
    'n': len(d),
    'n_cells': n_cells,
    'x': d['level'].values,
    'y': d['spikes'].values,
    'pupil': pupil_code,
    'cell': cell,
}
fit = model.sampling(data, iter=50000, seed=SEED)
fit

In [None]:
def get_trace(fit, variable):
    """Return the posterior draws of `variable` with the draw axis first.

    1-D draws (e.g. a scalar parameter) get a trailing singleton axis so
    callers can always index `[:, k]` and broadcast against per-cell
    parameters; higher-rank draws are returned unchanged.
    """
    samples = fit.extract(variable, permuted=True)[variable]
    return samples[:, np.newaxis] if samples.ndim == 1 else samples

# Pull posterior draws for every model component. Per-cell values are built
# below as `mean + offset * sd` (non-centred form; presumably mirrors the
# Stan model).
sr_mean, sr_pupil, sr_cell, sr_cell_pupil, sr_cell_sd, sr_cell_pupil_sd = (
    get_trace(fit, name) for name in (
        'sr_mean', 'sr_pupil', 'sr_cell', 'sr_cell_pupil',
        'sr_cell_sd', 'sr_cell_pupil_sd'))

slope_mean, slope_pupil, slope_cell, slope_cell_pupil, slope_cell_sd, slope_cell_pupil_sd = (
    get_trace(fit, name) for name in (
        'slope_mean', 'slope_pupil', 'slope_cell', 'slope_cell_pupil',
        'slope_cell_sd', 'slope_cell_pupil_sd'))

threshold_mean, threshold_pupil, threshold_cell, threshold_cell_pupil, threshold_cell_sd, threshold_cell_pupil_sd = (
    get_trace(fit, name) for name in (
        'threshold_mean', 'threshold_pupil', 'threshold_cell', 'threshold_cell_pupil',
        'threshold_cell_sd', 'threshold_cell_pupil_sd'))

# Small-pupil per-cell parameters (presumably draws x cells).
sr = sr_mean + (sr_cell * sr_cell_sd)
slope = slope_mean + (slope_cell * slope_cell_sd)
threshold = threshold_mean + (threshold_cell * threshold_cell_sd)

# WARNING: the *_pupil names are rebound here from the raw fitted pupil terms
# to the ABSOLUTE large-pupil per-cell parameters (small-pupil value + pupil
# shift). The plotting cells below read them in this absolute form.
sr_pupil = sr + (sr_pupil + sr_cell_pupil * sr_cell_pupil_sd)
slope_pupil = slope + (slope_pupil + slope_cell_pupil * slope_cell_pupil_sd)
threshold_pupil = threshold + (threshold_pupil + threshold_cell_pupil * threshold_cell_pupil_sd)

level_start, level_stop = -20, 80
spike_start, spike_stop = 0, 110

# Level grid on a leading axis so it broadcasts against the parameter arrays,
# giving predictions shaped (levels, draws, cells) if parameters are (draws, cells).
x = np.arange(level_start, level_stop)[:, np.newaxis, np.newaxis]

# Hockey-stick prediction: linear in level, floored at the spontaneous rate.
y = np.maximum((x - threshold) * slope, sr)
y_pupil = np.maximum((x - threshold_pupil) * slope_pupil, sr_pupil)

In [None]:
def get_density(y, bins):
    """Histogram every 1-D slice of `y` taken along axis 1.

    Each slice `y[i, :, ...]` is replaced by its normalised histogram
    (`density=True`) over `bins`, so axis 1 of the result has
    `len(bins) - 1` entries and all other axes are unchanged.
    """
    # Bring the histogrammed axis to the end and flatten the rest into rows.
    rows_last = np.moveaxis(y, 1, -1)
    flat = rows_last.reshape(-1, rows_last.shape[-1])
    hists = np.array([np.histogram(row, bins, density=True)[0] for row in flat])
    # Restore the original leading/trailing axes, with the bins at axis 1.
    return np.moveaxis(hists.reshape(rows_last.shape[:-1] + (-1,)), -1, 1)

In [None]:
# Posterior predictive density (heat map) for a few cells, with the raw data
# overlaid (white points) and the posterior-mean threshold (white line).
cell_indices = [0, 1, 2, 3, 4]
pupil = 'large'  # 'small' or 'large'
bins = np.arange(spike_start, spike_stop, 0.1)

# Per pupil state: predicted curves, pupil code in `df` (1 = small,
# 2 = large) and threshold draws. The threshold line must come from the same
# pupil state as the curves it is drawn over.
pupil_settings = {
    'small': (y, 1, threshold),
    'large': (y_pupil, 2, threshold_pupil),
}
predicted, raw_pupil, thresholds = pupil_settings[pupil]

fig, axes = pl.subplots(3, 3, figsize=(10, 10))
for ax, cell_index in zip(axes.flat, cell_indices):
    cell_name = cellid.categories[cell_index]
    density = get_density(predicted[..., cell_index], bins)
    raw = df.loc[cell_name].loc[raw_pupil]

    ax.plot(raw, 'o', color='white')
    ax.set_title(cell_name)
    ax.imshow(density.T, origin='lower', aspect='auto',
              extent=(level_start, level_stop, spike_start, spike_stop))
    ax.axvline(thresholds[:, cell_index].mean(), color='white')

# Hide grid panels that have no cell assigned.
for ax in axes.flat[len(cell_indices):]:
    ax.set_visible(False)


In [None]:
# Index into cellid.categories of the cell to plot.
# NOTE(review): hard-coded; confirm which cell index 56 corresponds to.
i = 56

# Posterior samples for each pupil state (0 = small, 1 = large). `sr_pupil`,
# `threshold_pupil` and `slope_pupil` were rebound to the ABSOLUTE
# large-pupil per-cell parameters when the posterior curves were built, so
# they replace -- rather than add to -- the small-pupil values.
pupil_params = {
    0: (sr, threshold, slope),
    1: (sr_pupil, threshold_pupil, slope_pupil),
}
levels = np.arange(-20, 100)[..., np.newaxis]

for i_cell in range(i, i + 1):
    pl.figure()

    for i_pupil, (sr_all, threshold_all, slope_all) in pupil_params.items():
        sr_i = sr_all[:, i_cell]
        threshold_i = threshold_all[:, i_cell]
        slope_i = slope_all[:, i_cell]
        intercept_i = sr_i - threshold_i * slope_i

        print(f'{sr_i.mean():.2f} {intercept_i.mean():.2f} {slope_i.mean():.2f}')
        print(f'{sr_i.min():.2f} {intercept_i.min():.2f} {slope_i.min():.2f}')

        # NOTE(review): this curve is sr + slope*(x - threshold) floored at sr,
        # whereas y/y_pupil use slope*(x - threshold) floored at sr. Confirm
        # which form matches the Stan model.
        rates = np.clip(slope_i * levels + intercept_i, sr_i, np.inf)

        m = (d['cell_code'] == i_cell) & (d['pupil'] == (i_pupil + 1))
        subset = d.loc[m]
        p, = pl.plot(subset['level'], subset['spikes'], 'o')

        # Every 100th posterior draw, drawn faintly in the data's colour.
        pl.plot(levels, rates[:, ::100], '-', alpha=0.005, color=p.get_color());

In [None]:
# MCMC trace plots for the non-cell-specific mean and pupil terms.
az.plot_trace(
    fit,
    var_names=(
        'sr_mean', 'sr_pupil',
        'threshold_mean', 'threshold_pupil',
        'slope_mean', 'slope_pupil',
    ),
)

In [None]:
# Tabulate posterior summaries for selected parameters.
# NOTE(review): the tp_*/srp_*/sp_* names are not extracted anywhere else in
# this notebook, which uses threshold_*/sr_*/slope_* names with the V3f
# model; they may be left over from an earlier model version -- confirm
# they exist in 'hierarchial_hockey_stick_V3f.stan'.
stan_summary = fit.summary(['tp_mean', 'tp_sd', 'srp_mean', 'srp_sd', 'sp_mean', 'sp_sd'])
pd.DataFrame(stan_summary['summary'],
             index=stan_summary['summary_rownames'],
             columns=stan_summary['summary_colnames'])

In [None]:
# Posterior median and 95% interval for each column of the threshold pupil
# term. A separate name keeps `threshold_pupil` (the absolute large-pupil
# threshold used by the plotting cells above) intact on re-runs.
threshold_pupil_trace = get_trace(fit, 'threshold_pupil')
lb, median, ub = np.percentile(threshold_pupil_trace, [2.5, 50.0, 97.5], 0)
idx = np.arange(len(median))
fig, ax = pl.subplots()
# fmt='o': the points are separate parameters, so don't join them with a line.
ax.errorbar(idx, median, yerr=np.vstack((median - lb, ub - median)), fmt='o')
ax.axhline(0, color='k')
ax.set_ylabel('Threshold (dB SPL)')
ax.xaxis.set_ticks([])
ax.grid()

# Columns with the most negative / most positive median effect, and the
# 95% interval of the most negative one.
i, j = median.argmin(), median.argmax()
print(i, j)
lb[i], ub[i]

In [None]:
# Posterior median and 95% interval for each column of the SR pupil term.
# A separate name keeps `sr_pupil` (the absolute large-pupil SR used by the
# plotting cells above) intact on re-runs.
sr_pupil_trace = get_trace(fit, 'sr_pupil')
lb, median, ub = np.percentile(sr_pupil_trace, [2.5, 50.0, 97.5], 0)
idx = np.arange(len(median))
fig, ax = pl.subplots()
# fmt='o': the points are separate parameters, so don't join them with a line.
ax.errorbar(idx, median, yerr=np.vstack((median - lb, ub - median)), fmt='o')
ax.axhline(0, color='k')
ax.set_ylabel('SR (sp/sec)')
ax.xaxis.set_ticks([])
ax.grid()

In [None]:
# Posterior median and 95% interval for each column of the slope pupil term.
# A separate name keeps `slope_pupil` (the absolute large-pupil slope used by
# the plotting cells above) intact on re-runs.
slope_pupil_trace = get_trace(fit, 'slope_pupil')
lb, median, ub = np.percentile(slope_pupil_trace, [2.5, 50.0, 97.5], 0)
idx = np.arange(len(median))
fig, ax = pl.subplots()
# fmt='o': the points are separate parameters, so don't join them with a line.
ax.errorbar(idx, median, yerr=np.vstack((median - lb, ub - median)), fmt='o')
ax.axhline(0, color='k')
ax.set_ylabel('Slope')
ax.xaxis.set_ticks([])
ax.grid()

In [None]:
# NOTE(review): this cell duplicates the earlier SR pupil-effect plot; it was
# possibly meant for a different parameter -- confirm or delete.
# A separate name keeps `sr_pupil` (the absolute large-pupil SR used by the
# plotting cells above) intact on re-runs.
sr_pupil_trace = get_trace(fit, 'sr_pupil')
lb, median, ub = np.percentile(sr_pupil_trace, [2.5, 50.0, 97.5], 0)
idx = np.arange(len(median))
fig, ax = pl.subplots()
# fmt='o': the points are separate parameters, so don't join them with a line.
ax.errorbar(idx, median, yerr=np.vstack((median - lb, ub - median)), fmt='o')
ax.axhline(0, color='k')
ax.set_ylabel('SR (sp/sec)')
ax.xaxis.set_ticks([])
ax.grid()

In [None]:
# Display the fit object again (it is also displayed at the end of the sampling cell).
fit