In [None]:
import pickle

import arviz as az
import numpy as np
import pandas as pd
import pylab as pl
import seaborn as sns
from scipy import stats

from lbhb.psychometric import CachedStanModel


In [None]:
# Load per-cell frequency tuning data and reshape it from wide format
# (columns named '<stub>_<idx>') to long format indexed by (cellid, idx).
#sig = pd.read_csv('psth_sig_cellids.csv')['cellid'].unique()
df = pd.read_csv('frequency_tuning_curves_for_bburan.csv')
# Remove spaces from the headers so they match the stub names below.
df = df.rename(columns=lambda name: name.replace(' ', ''))
cols = ['pupil', 'frequency', 'ftc_count', 'ftc_time', 'spont_count', 'spont_time']
df = (
    pd.wide_to_long(df, cols, 'cellid', 'idx', sep='_')
    .dropna()
    # Shift pupil codes down by one (downstream code indexes the two states
    # as 0 and 1) and move frequency onto a natural-log scale.
    .assign(pupil=lambda d: d['pupil'] - 1,
            frequency=lambda d: np.log(d['frequency']))
)

# Optional restriction to significant cells (currently disabled):
#mask = df.apply(lambda x: x.name[0] in sig, axis=1)
#df = df.loc[mask]

# Spontaneous counts/times: one row per (cellid, pupil).
sr = (
    df.groupby(['cellid', 'pupil'])[['spont_count', 'spont_time']]
    .first()
    .sort_index()
)
# Evoked counts/times indexed by (cellid, pupil, frequency); keep only points
# with a positive sampling time.
ftc = (
    df.reset_index()
    .set_index(['cellid', 'pupil', 'frequency'])[['ftc_count', 'ftc_time']]
    .sort_index()
)
m = ftc['ftc_time'] > 0
ftc = ftc[m]

cells = ftc.index.get_level_values('cellid').unique()

In [None]:
def get_cell_data(cell, ftc, sr):
    """Assemble the Stan input dict for a single cell.

    Parameters
    ----------
    cell : label of the cell in the first index level of ``ftc`` and ``sr``.
    ftc : DataFrame indexed by (cellid, pupil, frequency) with ``ftc_count``
        and ``ftc_time`` columns.
    sr : DataFrame indexed by (cellid, pupil) with ``spont_count`` and
        ``spont_time`` columns.

    Returns
    -------
    dict with keys ``n`` (number of tuning-curve points), ``freq``,
    ``spike_count`` (cast to dtype ``'i'``), ``sample_time``, ``sr`` and
    ``sr_lg`` (spontaneous rate for pupil state 0 and 1), and ``pupil``
    (per-point pupil state, dtype ``'i'``).

    Raises
    ------
    KeyError if ``cell`` or either pupil state is missing from the inputs.
    """
    curve = ftc.loc[cell].reset_index()
    spont_rate = sr.loc[cell].eval('spont_count/spont_time')

    return {
        'n': len(curve),
        'freq': curve['frequency'].values,
        'spike_count': curve['ftc_count'].values.astype('i'),
        'sample_time': curve['ftc_time'].values,
        'sr': spont_rate.loc[0],
        'sr_lg': spont_rate.loc[1],
        'pupil': curve['pupil'].astype('i').values,
    }


In [None]:
# Fit the Stan tuning-curve model to each cell. A cell whose data cannot be
# assembled or whose sampling raises gets a None placeholder in `fits`
# (downstream code must skip these). The exception is kept in `fit_errors`
# so failures can be inspected instead of vanishing silently.
model = CachedStanModel('gaussian_FTC_single_cell.stan')
fits = {}
fit_errors = {}
for cell in cells:
    try:
        data = get_cell_data(cell, ftc, sr)
        fits[cell] = model.sampling(data, iter=20000, n_jobs=8, control={'adapt_delta': 0.99})
    except Exception as exc:
        # Catch Exception rather than using a bare `except:` so that
        # KeyboardInterrupt (stopping a long sampling run) still propagates.
        fits[cell] = None
        fit_errors[cell] = exc

print(f'{len(fit_errors)} of {len(fits)} cells failed to fit')

In [None]:
# Import locally so this cell does not depend on a later cell having run
# first (previously `pickle` was only imported further down).
import pickle

# Persist the compiled Stan model on its own.
with open('ftc_model.pkl', 'wb') as fh:
    pickle.dump(model, fh)

In [None]:
import pickle

# Write the model and then the dict of fits (None for failed cells) into one
# file. NOTE(review): PyStan documents that a fit can only be unpickled after
# its model has been unpickled, so keep this model-then-fits order when loading.
with open('ftc_fits.pkl', 'wb') as fits_file:
    for obj in (model, fits):
        pickle.dump(obj, fits_file)

In [None]:
def plot_fit(ax, fit, data):
    """Overlay fitted Gaussian tuning curves (additive pupil deltas) on one cell's data.

    Small pupil: ``sr + gain * exp(-0.5 * ((log f - bf) / bw)**2)`` using the
    posterior means of ``bf``, ``gain``, ``bw`` and ``data['sr']`` as offset.
    Large pupil: same form with the posterior means of the ``*_pupil_delta``
    parameters added to each and ``data['sr_lg']`` as offset.
    The model works on log frequency; curves and data are exponentiated and
    drawn on a log x-axis.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    fit : posterior draws indexable by parameter name (e.g. a Stan fit).
    data : dict as returned by ``get_cell_data``.
    """
    def gaussian_curve(log_freq, best, amplitude, width, baseline):
        return baseline + amplitude * np.exp(-0.5 * np.square((log_freq - best) / width))

    log_grid = np.arange(3, 11, 0.1)
    freq_grid = np.exp(log_grid)

    best = fit['bf'].mean()
    amplitude = fit['gain'].mean()
    width = fit['bw'].mean()

    # Small-pupil model curve and spontaneous-rate baseline.
    baseline = data['sr']
    ax.plot(freq_grid, gaussian_curve(log_grid, best, amplitude, width, baseline),
            ':', color='orchid', label='Sm. pupil')
    ax.axhline(baseline, color='orchid')

    # Large-pupil model curve: shift each parameter by its mean pupil delta.
    best = best + fit['bf_pupil_delta'].mean()
    amplitude = amplitude + fit['gain_pupil_delta'].mean()
    width = width + fit['bw_pupil_delta'].mean()
    baseline = data['sr_lg']
    ax.plot(freq_grid, gaussian_curve(log_grid, best, amplitude, width, baseline),
            ':', color='seagreen', label='Lg. pupil')
    ax.axhline(baseline, color='seagreen')

    # Observed evoked rates, one line per pupil state.
    pupil = data['pupil']
    log_freq = data['freq']
    evoked_rate = data['spike_count'] / data['sample_time']
    for state, color in ((0, 'orchid'), (1, 'seagreen')):
        in_state = pupil == state
        ax.plot(np.exp(log_freq[in_state]), evoked_rate[in_state], 'o-', color=color)

    ax.set_xscale('log')

# One panel per successfully fitted cell. Failed fits are stored as None and
# would crash plot_fit, so they are skipped.
plotted = [(cell, fit) for cell, fit in fits.items() if fit is not None]

f, axes = pl.subplots(10, 12, figsize=(20, 20))
if len(plotted) > axes.size:
    # zip() below silently drops cells that do not fit on the grid.
    print(f'Only the first {axes.size} of {len(plotted)} fitted cells fit on the grid')

for ax, (cell, fit) in zip(axes.ravel(), plotted):
    data = get_cell_data(cell, ftc, sr)
    plot_fit(ax, fit, data)

# Legend on the last populated panel; skipped when nothing was plotted.
if plotted:
    ax.legend()

In [None]:
def plot_fit(ax, fit, data):
    """Overlay fitted Gaussian tuning curves (multiplicative pupil ratios) on one cell's data.

    Small pupil: ``sr * gain * exp(-0.5 * ((log f - bf) / bw)**2)`` using the
    posterior means of ``bf``, ``gain``, ``bw`` and ``data['sr']``.
    Large pupil: same form with ``data['sr_lg']`` and the posterior means of
    the per-draw quantities ``bf + bf_delta``, ``gain * gain_ratio`` and
    ``bw * bw_ratio``. Averaging the per-draw product gives the posterior mean
    of the large-pupil parameter itself. Multiplying two separate posterior
    means does not, whenever the factors are correlated across draws.

    NOTE(review): assumes ``fit[name]`` returns draws in the same order for
    every parameter (PyStan's permuted extraction shares one permutation).
    Verify for other fit objects. If draws are not aligned, the result
    approaches the product of the posterior means.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    fit : posterior draws indexable by parameter name (e.g. a Stan fit).
    data : dict as returned by ``get_cell_data``.
    """
    def gaussian_curve(log_freq, best, amplitude, width, baseline):
        # In this parameterization the baseline scales the whole curve.
        return baseline * amplitude * np.exp(-0.5 * np.square((log_freq - best) / width))

    log_grid = np.arange(3, 11, 0.1)
    freq_grid = np.exp(log_grid)

    bf_draws = np.asarray(fit['bf'])
    gain_draws = np.asarray(fit['gain'])
    bw_draws = np.asarray(fit['bw'])

    # Small-pupil model curve and spontaneous-rate baseline.
    baseline = data['sr']
    ax.plot(freq_grid,
            gaussian_curve(log_grid, bf_draws.mean(), gain_draws.mean(), bw_draws.mean(), baseline),
            ':', color='orchid', label='Sm. pupil')
    ax.axhline(baseline, color='orchid')

    # Large-pupil model curve: combine parameters per draw, then average.
    bf_lg = np.mean(bf_draws + np.asarray(fit['bf_delta']))
    gain_lg = np.mean(gain_draws * np.asarray(fit['gain_ratio']))
    bw_lg = np.mean(bw_draws * np.asarray(fit['bw_ratio']))
    baseline = data['sr_lg']
    ax.plot(freq_grid, gaussian_curve(log_grid, bf_lg, gain_lg, bw_lg, baseline),
            ':', color='seagreen', label='Lg. pupil')
    ax.axhline(baseline, color='seagreen')

    # Observed evoked rates, one line per pupil state.
    pupil = data['pupil']
    log_freq = data['freq']
    evoked_rate = data['spike_count'] / data['sample_time']
    for state, color in ((0, 'orchid'), (1, 'seagreen')):
        in_state = pupil == state
        ax.plot(np.exp(log_freq[in_state]), evoked_rate[in_state], 'o-', color=color)

    ax.set_xscale('log')

# One panel per successfully fitted cell. Failed fits are stored as None and
# would crash plot_fit, so they are skipped.
plotted = [(cell, fit) for cell, fit in fits.items() if fit is not None]

f, axes = pl.subplots(10, 12, figsize=(20, 20))
if len(plotted) > axes.size:
    # zip() below silently drops cells that do not fit on the grid.
    print(f'Only the first {axes.size} of {len(plotted)} fitted cells fit on the grid')

for ax, (cell, fit) in zip(axes.ravel(), plotted):
    data = get_cell_data(cell, ftc, sr)
    plot_fit(ax, fit, data)

# Legend on the last populated panel; skipped when nothing was plotted.
if plotted:
    ax.legend()

In [None]:
# Posterior summaries (90% interval) for every successful fit, stacked into one
# frame with an outer 'cell' index level. Failed fits are None and would make
# az.summary raise, so they are skipped; keys are taken from the same filtered
# dict so they stay aligned with the summaries.
# NOTE(review): `.to_dataframe()` assumes az.summary returns an xarray Dataset
# in the installed ArviZ (so parameter names become columns, as the cells
# below expect). Newer ArviZ returns a DataFrame by default and renames
# `credible_interval` to `hdi_prob`. Confirm against the pinned version.
successful_fits = {cell: fit for cell, fit in fits.items() if fit is not None}
results = [az.summary(fit, credible_interval=0.9).to_dataframe() for fit in successful_fits.values()]
results = pd.concat(results, keys=list(successful_fits), names=['cell'])

In [None]:
# Peek at the first rows of the stacked posterior summaries.
results.head(5)

In [None]:
def plot(ax, df, measure):
    """Draw a per-row forest plot of one pupil-effect parameter.

    Each row of ``df`` is drawn at y = row position (first row at the bottom)
    as a horizontal line spanning the 'hpd 5.00%' to 'hpd 95.00%' interval,
    with a dot at 'mean'. Colour coding:
    red if 'gelman-rubin statistic' > 1.1, green if the interval excludes
    zero, black otherwise. The x-label reports how many rows are green (out
    of all rows) and the p-value of ``scipy.stats.wilcoxon`` on all means.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    df : DataFrame with 'mean', 'hpd 5.00%', 'hpd 95.00%' and
        'gelman-rubin statistic' columns (presumably one row per cell).
    measure : parameter name used in the x-label.

    Returns
    -------
    The same ``ax``.
    """
    means = df['mean']
    lower = df['hpd 5.00%']
    upper = df['hpd 95.00%']
    r_hat = df['gelman-rubin statistic']
    stat = stats.wilcoxon(means)

    n_sig = 0
    for row, (lo, mu, hi, rh) in enumerate(zip(lower, means, upper, r_hat)):
        if rh > 1.1:
            color = 'r'          # sampler not converged
        elif hi < 0 or lo > 0:
            color = 'g'          # interval excludes zero
            n_sig += 1
        else:
            color = 'k'
        ax.plot([lo, hi], [row, row], '-', color=color, lw=0.5)
        ax.plot([mu], [row], 'o', color=color)

    ax.set_xlabel(f'Change in {measure} (lg. re sm. pupil)\n{n_sig} sig. out of {len(lower)}\nWilcoxon p={stat.pvalue:0.4f}')
    sns.despine(ax=ax, top=True, left=True, right=True, bottom=False)
    ax.yaxis.set_ticks_position('none')
    ax.yaxis.set_ticks([])
    ax.grid()
    return ax


import seaborn as sns
from scipy import stats
# Per-cell summary tables for each pupil effect (rows presumably cells,
# columns the summary metrics), sorted by posterior mean so each forest
# plot is ordered from most negative (bottom) to most positive (top).
bf_pupil_delta, bw_pupil_delta, gain_pupil_delta = (
    results[param].unstack().sort_values('mean')
    for param in ('bf_pupil_delta', 'bw_pupil_delta', 'gain_pupil_delta')
)

f, axes = pl.subplots(1, 3, figsize=(9, 3))

plot(axes[0], bf_pupil_delta, 'BF')
plot(axes[1], bw_pupil_delta, 'BW')
plot(axes[2], gain_pupil_delta, 'gain')

In [None]:
# Median posterior-mean bandwidth change over rows whose
# 'gelman-rubin statistic' is below 1.1 (the same threshold the forest plot
# uses to mark non-converged fits).
m = bw_pupil_delta['gelman-rubin statistic'] < 1.1
np.median(bw_pupil_delta['mean'][m])