In [None]:
import pickle

import arviz as az
import numpy as np
import pandas as pd
import pylab as pl

from lbhb.psychometric import CachedStanModel


In [None]:
# Optional restriction to cells with a significant PSTH response: set this to the path of a
# CSV with a 'cellid' column (e.g. 'psth_sig_cellids.csv') to enable it; None keeps every cell.
SIG_CELLIDS_CSV = None

df = pd.read_csv('frequency_tuning_curves_for_bburan.csv')
# Strip spaces from column names so they match the stub names used by wide_to_long below.
df.columns = [s.replace(' ', '') for s in df.columns]
cols = ['pupil', 'frequency', 'ftc_count', 'ftc_time', 'spont_count', 'spont_time']
# Wide -> long: one row per (cellid, idx) measurement, where idx is the numeric suffix after
# '_' in the wide column names; rows with any missing value are dropped.
df = pd.wide_to_long(df, cols, 'cellid', 'idx', sep='_').dropna()
# NOTE(review): presumably recodes pupil state 1/2 -> 0/1 (small/large) — confirm.
df['pupil'] -= 1
# Work in natural-log frequency; plot_fit maps back with np.exp.
df['frequency'] = np.log(df['frequency'])

if SIG_CELLIDS_CSV is not None:
    sig = pd.read_csv(SIG_CELLIDS_CSV)['cellid'].unique()
    # Vectorized membership test on the 'cellid' index level (replaces a row-wise apply).
    df = df.loc[df.index.get_level_values('cellid').isin(sig)]

# One spontaneous count/time per (cellid, pupil). Assumes the value is repeated across
# measurement rows, so the first one is representative — TODO confirm.
sr = df.groupby(['cellid', 'pupil'])[['spont_count', 'spont_time']].first().sort_index()
ftc = df.reset_index().set_index(['cellid', 'pupil', 'frequency'])[['ftc_count', 'ftc_time']].sort_index()
# Drop tuning-curve bins with no sampling time (their rate is undefined).
ftc = ftc.loc[ftc['ftc_time'] > 0]

cells = ftc.index.get_level_values('cellid').unique()
assert len(cells) > 0, 'no cells left after loading/filtering'

In [None]:
# Build the single-cell Gaussian frequency-tuning-curve Stan model.
# NOTE(review): CachedStanModel presumably reuses a cached compilation of the .stan file —
# confirm in lbhb.psychometric.
stan_model_file = 'gaussian_FTC_single_cell.stan'
model = CachedStanModel(stan_model_file)

In [None]:
def get_cell_data(cell, ftc, sr):
    """Assemble the Stan data dictionary for one cell.

    Parameters
    ----------
    cell : hashable
        Value of the outer 'cellid' index level to select.
    ftc : pandas.DataFrame
        Evoked responses indexed by (cellid, pupil, frequency) with columns
        'ftc_count' and 'ftc_time'.
    sr : pandas.DataFrame
        Spontaneous responses indexed by (cellid, pupil) with columns
        'spont_count' and 'spont_time'.

    Returns
    -------
    dict
        Keys 'n', 'freq', 'spike_count', 'sample_time', 'spont_count',
        'spont_time', 'pupil'. Count arrays are cast to 'i' (C int), which
        presumably matches integer declarations in the Stan data block.
        Spontaneous entries follow the row order of ``sr`` for this cell.
    """
    evoked = ftc.loc[cell].reset_index()
    spont = sr.loc[cell].reset_index()

    return dict(
        n=len(evoked),
        freq=evoked['frequency'].to_numpy(),
        spike_count=evoked['ftc_count'].to_numpy().astype('i'),
        sample_time=evoked['ftc_time'].to_numpy(),
        spont_count=spont['spont_count'].to_numpy().astype('i'),
        spont_time=spont['spont_time'].to_numpy(),
        pupil=evoked['pupil'].to_numpy(),
    )


# Sampler settings. The original values are kept; SEED makes the posterior draws reproducible.
N_ITER = 10000
ADAPT_DELTA = 0.99  # high target acceptance rate: smaller steps, fewer divergent transitions
SEED = 42

# Fit the model independently for every cell; fits maps cellid -> fit object.
fits = {}
for cell in cells:
    data = get_cell_data(cell, ftc, sr)
    # NOTE(review): assumes CachedStanModel.sampling forwards `seed` like PyStan 2's
    # StanModel.sampling (it already accepts PyStan-style `iter` and `control`) — confirm.
    fits[cell] = model.sampling(data, iter=N_ITER, seed=SEED, control={'adapt_delta': ADAPT_DELTA})

In [None]:
# Persist the compiled model. This cell uses the `import pickle` at the top of the notebook;
# previously pickle was only imported in the following cell, so Run All on a fresh kernel
# failed here with NameError. Assuming PyStan 2 semantics, pickled fits can only be loaded
# after this model has been unpickled, so keep this file next to ftc_fits.pkl.
with open('ftc_model.pkl', 'wb') as fh:
    pickle.dump(model, fh)

In [None]:
import pickle

# Serialize every per-cell fit (dict cellid -> fit) into a single file.
fits_path = 'ftc_fits.pkl'
with open(fits_path, mode='wb') as fits_file:
    pickle.dump(fits, fits_file)

In [None]:
# Number of cells that received a fit.
n_fitted_cells = len(fits)
n_fitted_cells

In [None]:
def _gaussian_tuning(log_freq, bf, gain, bw, offset):
    """Evaluate offset + gain * exp(-0.5 * ((log_freq - bf) / bw) ** 2)."""
    return offset + gain * np.exp(-0.5 * np.square((log_freq - bf) / bw))


def plot_fit(ax, fit, data, freq_grid=None):
    """Overlay posterior-mean tuning curves on the observed rates for one cell.

    Draws, per pupil state (orchid = small/0, seagreen = large/1):
    a dotted fitted Gaussian curve, a solid horizontal line at the fitted offset,
    the observed evoked rates as connected markers, and a thick dotted horizontal
    line at the observed spontaneous rate. The x axis is set to log scale.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into.
    fit : mapping-like of posterior samples
        ``fit[name]`` must return an array for 'bf', 'gain', 'bw', 'offset' and
        the matching '<name>_pupil_delta' entries. Large-pupil parameters are the
        small-pupil posterior means plus the delta posterior means. Curves use
        posterior means of the parameters, not the posterior mean curve.
    data : dict
        Stan data dictionary as produced by ``get_cell_data``: 'pupil' (0/1),
        'freq' (log frequency), 'spike_count', 'sample_time', 'spont_count',
        'spont_time'. Spontaneous entries are assumed ordered small then large pupil.
    freq_grid : array-like, optional
        Log-frequency values at which to evaluate the fitted curves.
        Defaults to ``np.arange(3, 11, 0.1)``.
    """
    if freq_grid is None:
        freq_grid = np.arange(3, 11, 0.1)
    freq_grid = np.asarray(freq_grid, dtype=float)

    names = ('bf', 'gain', 'bw', 'offset')
    small = {name: fit[name].mean() for name in names}
    large = {name: small[name] + fit[name + '_pupil_delta'].mean() for name in names}

    for params, color, label in ((small, 'orchid', 'Sm. pupil'), (large, 'seagreen', 'Lg. pupil')):
        curve = _gaussian_tuning(freq_grid, params['bf'], params['gain'], params['bw'], params['offset'])
        ax.plot(np.exp(freq_grid), curve, ':', color=color, label=label)
        ax.axhline(params['offset'], color=color)

    pupil = np.asarray(data['pupil'])
    log_freq = np.asarray(data['freq'])
    evoked_rate = np.asarray(data['spike_count']) / np.asarray(data['sample_time'])
    spont_rate = np.asarray(data['spont_count']) / np.asarray(data['spont_time'])

    for state, color in ((0, 'orchid'), (1, 'seagreen')):
        mask = pupil == state
        ax.plot(np.exp(log_freq[mask]), evoked_rate[mask], 'o-', color=color)

    ax.axhline(spont_rate[0], color='orchid', ls=':', lw=2)
    ax.axhline(spont_rate[1], color='seagreen', ls=':', lw=2)
    ax.set_xscale('log')

# One small panel per fitted cell, 12 per row. The grid is sized from the number of fits
# instead of a fixed 10x12 layout that silently dropped every cell past the 120th.
n_cells = len(fits)
n_cols = 12
n_rows = max(1, (n_cells + n_cols - 1) // n_cols)
f, axes = pl.subplots(n_rows, n_cols, figsize=(20, 2 * n_rows), squeeze=False)
axes_flat = axes.ravel()

for ax, (cell, fit) in zip(axes_flat, fits.items()):
    data = get_cell_data(cell, ftc, sr)
    plot_fit(ax, fit, data)
    ax.set_title(str(cell), fontsize='x-small')

# Hide the unused panels at the end of the last row.
for ax in axes_flat[n_cells:]:
    ax.set_visible(False)

# Every panel carries the same two labelled curves, so one legend suffices.
if n_cells:
    axes_flat[0].legend(fontsize='x-small')

In [None]:
# Summarize the bandwidth posterior of one example fit (the last cell in `fits`). The fit
# is selected explicitly rather than relying on the loop variable leaked by the plotting cell.
# NOTE(review): plot_fit reads 'bw'/'bw_pupil_delta', not 'bandwidth*'. These names look like
# they come from a different (hierarchical?) model — confirm they exist in
# gaussian_FTC_single_cell.stan.
fit = fits[list(fits)[-1]]
# All three values are displayed; previously the first expression's result was discarded.
(
    fit['bandwidth'].mean(axis=0).mean(),
    fit['bandwidth_mean'].mean(),
    fit['bandwidth_sd'].mean(),
)

In [None]:
# Trace plots for one example fit (the last cell in `fits`). The fit is selected explicitly so
# this cell does not depend on a loop variable leaked from an earlier cell. The arviz alias is
# `az` (the original `av` was undefined).
# NOTE(review): these parameter names differ from those plot_fit reads ('bf', 'gain', 'bw',
# 'offset', ...) — confirm they exist in the compiled model.
trace_fit = fits[list(fits)[-1]]
az.plot_trace(trace_fit, var_names=['bandwidth_mean', 'bandwidth_sd', 'offset_alpha', 'offset_beta']);