In [None]:
# Render matplotlib figures inline below each cell (already the default in
# modern Jupyter/IPython; kept explicit so older kernels behave the same).
%matplotlib inline

In [None]:
import os
os.environ['MKL_THREADING_LAYER'] = 'GNU'
import pandas as pd
import numpy as np
import arviz as az
import pylab as pl

In [None]:
# Load the tuning-curve table and reshape it from wide format (numerically
# suffixed columns such as `ftc_count_1`, keyed by `cellid`) into long format
# indexed by (cellid, idx). Rows containing any missing value are dropped.
cols = ['pupil', 'frequency', 'ftc_count', 'ftc_time', 'spont_count', 'spont_time']

df = (
    pd.read_csv('frequency_tuning_curves_for_bburan.csv')
      .rename(columns=lambda name: name.replace(' ', ''))  # strip spaces from headers
)
df = pd.wide_to_long(df, stubnames=cols, i='cellid', j='idx', sep='_').dropna()

# Spontaneous count/time: one value per (cell, pupil); keep the first row of each group.
sr = (
    df.groupby(['cellid', 'pupil'])[['spont_count', 'spont_time']]
      .first()
      .sort_index()
)

# Tuning-curve count/time indexed by (cell, pupil, frequency).
ftc = (
    df.reset_index()
      .set_index(['cellid', 'pupil', 'frequency'])[['ftc_count', 'ftc_time']]
      .sort_index()
)

In [None]:
# Quick look at the raw tuning curves (spike rate = count / time versus tone
# frequency) for every cell in the pupil == 1 condition, one panel per cell.
x = ftc.xs(1, level='pupil').eval('ftc_count/ftc_time')

# Derive the cell list here instead of relying on `cells`, which is only
# defined further down the notebook (NameError on Restart & Run All).
plot_cells = x.index.get_level_values('cellid').unique()

# Size the grid to the number of cells (the old fixed 10x10 grid silently
# dropped any cell beyond the 100th); 1.5 inches per row as before.
n_cols = 10
n_rows = max(1, -(-len(plot_cells) // n_cols))  # ceiling division
fig, axes = pl.subplots(n_rows, n_cols, figsize=(15, 1.5 * n_rows), squeeze=False)

for c, ax in zip(plot_cells, axes.ravel()):
    x.loc[c].plot(ax=ax)  # Series indexed by frequency
    ax.set_xscale("log")
    ax.set_title(c, fontsize=6)

In [None]:
import matplotlib as mp

# Exploratory check: run matplotlib's colour-spec validator on a made-up
# string and display the boolean it returns.
mp.colors.is_color_like(c='ewoijgw')

In [None]:
# Select one pupil condition and flatten it into the 1-D arrays the PyMC3
# model consumes: one entry per (cell, frequency) tuning-curve row in `e`,
# and one spontaneous-rate entry per cell in `s`.
pupil = 1  # pupil condition to fit (value on the 'pupil' index level)

e = ftc.xs(pupil, level='pupil').reset_index()
s = sr.xs(pupil, level='pupil').reset_index()

# Integer code for every cell, in order of first appearance in `e`.
cells = e['cellid'].unique()
cell_map = {c: i for i, c in enumerate(cells)}
cell_index = e['cellid'].map(cell_map).values

n = len(e)
n_cells = len(cell_map)

# Put the spontaneous rows in cell_map order so that position k of
# spont_count / spont_time belongs to cell k; downstream code may index these
# arrays positionally against per-cell parameter vectors.
s['cell_index'] = s['cellid'].map(cell_map)
assert s['cell_index'].notna().all(), 'spontaneous data for a cell without tuning-curve data'
s = s.sort_values('cell_index').reset_index(drop=True)
s['cell_index'] = s['cell_index'].astype(int)
assert np.array_equal(s['cell_index'].values, np.arange(n_cells)), \
    'expected exactly one spontaneous-rate row per cell'

frequency = np.log(e['frequency'].values)  # the model works in natural-log frequency
spike_count = e['ftc_count'].values.astype('i')
sample_time = e['ftc_time'].values
spont_count = s['spont_count'].values.astype('i')
spont_time = s['spont_time'].values

In [None]:
import pymc3 as pm

# Hierarchical Gaussian tuning-curve model, fit jointly to every cell's
# tuning-curve spike counts and spontaneous spike counts:
#   rate(f) = offset + gain * exp(-(log f - BF)^2 / (2 * bandwidth^2))
# Per-cell parameters are partially pooled through population-level
# hyperparameters (gain, gain_sd, bandwidth_sd, offset_sd).
with pm.Model() as model:
    # Best frequency, in the same natural-log units as `frequency`.
    BF = pm.Normal('BF_cell', mu=7.5, sd=2, shape=n_cells)
    # TODO: pupil-dependent BF shift. The former `BF_pupil` prior was never
    # connected to the likelihood (an orphan free variable), so it was dropped.

    # Tuning width: SD of the Gaussian in log-frequency. The former Gamma
    # hyperprior ('bandwidth', 'bandwidth_alpha', 'bandwidth_beta') was
    # immediately overwritten below and never reached the likelihood; removed.
    bandwidth_sd = pm.HalfNormal('bandwidth_sd', sd=1)
    bandwidth = pm.HalfNormal('bandwidth_cell', sd=bandwidth_sd, shape=n_cells)

    gain_mean = pm.Normal('gain', mu=10, sd=20)
    gain_sd = pm.HalfNormal('gain_sd', sd=20)
    gain = pm.Normal('gain_cell', mu=gain_mean, sd=gain_sd, shape=n_cells)

    # Baseline rate, constrained by both the spontaneous and tuning-curve data.
    offset_sd = pm.HalfNormal('offset_sd', sd=2)
    offset = pm.HalfNormal('offset_cell', sd=offset_sd, shape=n_cells)
    # Pair each spontaneous row with its own cell's offset instead of relying
    # on the rows of `s` happening to be in cell_map order.
    spont_cell_index = s['cell_index'].values.astype('int64')
    sr_obs = pm.Poisson('sr_obs',
                        mu=offset[spont_cell_index] * np.asarray(spont_time),
                        observed=np.asarray(spont_count))

    fc = BF[cell_index]
    o = offset[cell_index]
    bw = bandwidth[cell_index]
    g = gain[cell_index]
    # pm.math.exp is the symbolic (Theano) exponential; NumPy ufuncs are not
    # designed to operate on symbolic tensors.
    rate = o + g * pm.math.exp(-((frequency - fc) ** 2 / (2 * bw ** 2)))
    ftc_obs = pm.Poisson('ftc_obs', mu=rate * sample_time, observed=spike_count)

    # Fixed seed so Restart & Run All reproduces the same posterior draws.
    fit = pm.sample(random_seed=42)

In [None]:
# Trace plot of the per-cell baseline-rate posterior, for a visual convergence check.
az.plot_trace(fit, var_names=['offset_cell'])

In [None]:
# Posterior-mean point estimate of each per-cell parameter (mean over the
# draw axis, axis 0). Assumes `fit[name]` stacks the draws of all chains
# along axis 0 -- confirm for the installed PyMC3 version.
BF, gain, offset, bandwidth = (
    fit[var_name].mean(axis=0)
    for var_name in ('BF_cell', 'gain_cell', 'offset_cell', 'bandwidth_cell')
)

In [None]:
# Overlay each cell's fitted Gaussian tuning curve (posterior means, evaluated
# in log-frequency space exactly like the model's `rate`) on its observed
# spike rates.

# Loop-invariant evaluation grid: log-spaced frequencies in Hz, plus their
# natural log for evaluating the model.
freq_grid = np.geomspace(100, 50000, 500)
log_freq_grid = np.log(freq_grid)

fig, axes = pl.subplots(10, 10, figsize=(20, 20))

for cell, ax in zip(cells, axes.ravel()):
    i = cell_map[cell]

    rate = offset[i] + gain[i] * np.exp(-((log_freq_grid - BF[i]) ** 2 / (2 * bandwidth[i] ** 2)))
    # Plot against frequency in Hz on a log axis. The previous version plotted
    # log(Hz) on an axis that was then log-scaled as well (double log).
    ax.plot(freq_grid, rate)

    observed = e[e['cellid'] == cell]  # avoids f-string query quoting issues
    ax.plot(observed['frequency'], observed.eval('ftc_count/ftc_time'))
    #ax.axhline(s.query(f'cellid == "{cell}"').eval('spont_count/spont_time').iloc[0])
    ax.set_xscale('log')
    ax.set_title(cell, fontsize=6)