In [None]:
from support import CachedStanModel, load_rates, load_stan_data, get_metric, forest_plot
import numpy as np
import pandas as pd
import pickle
import pylab as pl
import arviz as az

In [None]:
# `mode` selects how the spontaneous rate is combined with the Gaussian tuning
# curve when individual fits are plotted (see plot_fit in this notebook).
mode = 'additive'
filename = 'fits/hyena-ftc_sr_additive_exclude_silent_significant_only.pkl'

# The file holds three pickled objects written back to back, so they are read
# in that same order: the cell list, the model, and the fit.
# NOTE(review): pickle.load can execute arbitrary code from the file; only open
# fit files produced by a trusted pipeline.
with open(filename, 'rb') as fh:
    cells, model, fit = (pickle.load(fh) for _ in range(3))

In [None]:
# Optional diagnostics, currently disabled: uncomment any of the lines below to
# draw ArviZ trace plots of the listed population-level parameters of `fit`.
#az.plot_trace(fit, ['bf_mean', 'bw_mean', 'gain_mean', 'sr_mean'])
#az.plot_trace(fit, ['bf_sd', 'bw_sd', 'gain_sd', 'sr_sd'])
#az.plot_trace(fit, ['bf_delta_mean', 'bw_ratio_mean', 'gain_ratio_mean', 'sr_delta_mean'])

In [None]:
# Credible-interval width in percent; passed to az.summary as a fraction and to
# forest_plot as given.
ci = 90
summary = az.summary(fit, credible_interval=ci/100)

f, axes = pl.subplots(2, 2, figsize=(8, 8))

# One row per panel: axes position, per-cell metric, population metric, panel
# label, and any extra positional arguments for forest_plot. The ratio panels
# pass an additional 1 (presumably the "no change" reference for a ratio --
# confirm against forest_plot); the delta panels rely on forest_plot's default.
panels = [
    ((0, 0), 'sr_delta_cell', 'sr_delta_mean', 'sr', ()),
    ((0, 1), 'bf_delta_cell', 'bf_delta_mean', 'bf', ()),
    ((1, 0), 'gain_ratio_cell', 'gain_ratio_mean', 'gain', (1,)),
    ((1, 1), 'bw_ratio_cell', 'bw_ratio_mean', 'bw', (1,)),
]
for position, cell_name, pop_name, label, extra in panels:
    cell_metric = get_metric(summary, cell_name)
    pop_metric = get_metric(summary, pop_name)
    forest_plot(axes[position], cell_metric, pop_metric, label, ci, *extra)

f.tight_layout()

# Save the same figure in every format used downstream.
for path in (
    'reports/tuning/summary.eps',
    'reports/tuning/summary.pdf',
    'reports/tuning/summary.png',
):
    f.savefig(path)

In [None]:
# Population-level coefficients written to population_metrics.csv.
cols = [
    'sr_mean', 'bf_mean', 'bw_mean', 'gain_mean',
    'sr_delta_mean', 'bf_delta_mean', 'bw_delta_mean', 'gain_delta_mean',
    'bw_ratio_mean', 'gain_ratio_mean',
]
summary[cols].to_dataframe().T.to_csv('reports/tuning/population_metrics.csv')

# Per-cell coefficients written to cell_metrics.csv.
cols = [
    'sr_cell', 'bf_cell', 'bw_cell', 'gain_cell',
    'sr_delta_cell', 'bf_delta_cell', 'bw_delta_cell', 'gain_delta_cell',
    'bw_ratio_cell', 'gain_ratio_cell',
]
index = pd.Index(cells, name='cellid')

# Build one table per coefficient with the 'metric' level pivoted into columns.
# Rows are relabelled with the cell ids -- this assumes the rows of each table
# come back in the same order as `cells` (TODO confirm).
per_coefficient = {}
for c in cols:
    r = summary[c].to_series().unstack('metric')
    r.index = index
    per_coefficient[c] = r

# Stack the tables under an outer 'coefficient' index level and export them.
result = pd.concat(per_coefficient, names=['coefficient'])
result.to_csv('reports/tuning/cell_metrics.csv')

# Compact variant: only the 'mean' column, one column per coefficient.
x = result['mean'].unstack('coefficient')
x.to_csv('reports/tuning/cell_metrics_mean_only.csv')

In [None]:
# Spot check: the 'mean' entry of bw_ratio_cell for the row labelled 5.
get_metric(summary, 'bw_ratio_cell')['mean'].loc[5]

In [None]:
def plot_fit(ax, summary, i, cells, ftc):
    """Overlay one cell's fitted tuning curves on its measured rates.

    Two Gaussian curves are drawn from the 'mean' column of the summary: one
    built from the ``*_cell`` coefficients (sea green) and one from the
    ``*_cell_pupil`` coefficients (orchid). The measured rates for conditions
    0 and 1 in ``ftc`` are drawn as markers in the matching colours, and the
    per-cell change metrics are printed to the right of the axes, coloured by
    their ``change`` flag.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on. It is not cleared here.
    summary : object
        Summary accepted by ``get_metric``.
    i : int
        Position of the cell in ``cells``; also used as the row label in the
        tables returned by ``get_metric``.
    cells : sequence
        Cell identifiers; ``cells[i]`` selects the rows of ``ftc``.
    ftc : pandas object
        Indexed by (cell, condition) with ``count`` and ``time`` columns and a
        ``frequency`` index level. Presumably condition 0 is the baseline and
        condition 1 the pupil condition -- confirm against ``load_rates``.

    Notes
    -----
    Reads the notebook-level ``mode`` variable: ``'additive'`` adds the
    spontaneous rate to the Gaussian, ``'multiplicative'`` multiplies by it,
    and any other value leaves the bare Gaussian.
    """
    cell = cells[i]
    frequency = np.linspace(2.5, 17.5, 100)

    def fitted_curve(suffix):
        # Look up the four coefficients for this cell, then evaluate the
        # Gaussian on the shared frequency grid.
        sr, bw, bf, gain = (
            get_metric(summary, f'{name}_cell{suffix}').loc[i, 'mean']
            for name in ('sr', 'bw', 'bf', 'gain')
        )
        gauss = gain * np.exp(-0.5 * np.square((frequency - bf) / bw))
        if mode == 'additive':
            gauss = sr + gauss
        elif mode == 'multiplicative':
            gauss = sr * gauss
        return gauss

    # Fitted curves: plain coefficients first, then the pupil coefficients.
    ax.plot(frequency, fitted_curve(''), '-', color='seagreen')
    ax.plot(frequency, fitted_curve('_pupil'), '-', color='orchid')

    # Annotate the change metrics in a column to the right of the axes.
    o = 1
    color_map = {'-': 'r', '=': 'k', '+': 'g'}
    for metric in ('sr_delta', 'bf_delta', 'bw_ratio', 'gain_ratio'):
        # Deltas are compared against 0, ratios against 1.
        ref = 0 if 'delta' in metric else 1
        m = get_metric(summary, f'{metric}_cell', sig_ref=ref).loc[i]
        o -= 0.05
        c = color_map[m['change']]
        ax.text(1.1, o, f'{metric}: {m["mean"]:.2f}', transform=ax.transAxes, color=c)

    # Measured rates (count / time) per frequency for each condition.
    for condition, color in ((0, 'seagreen'), (1, 'orchid')):
        x = ftc.loc[cell, condition]
        x = x.eval('count/time').rename('rate').reset_index()
        ax.plot(x['frequency'], x['rate'], 'o', color=color)
    
    
# Render one figure per cell. A single figure/axes pair is reused and cleared
# at the start of each iteration, then saved as both PNG and PDF.
rates = load_rates()
f, ax = pl.subplots(1, 1, figsize=(5, 5))
for i, cell in enumerate(cells):
    ax.cla()
    plot_fit(ax, summary, i, cells, rates['ftc'])
    t = f'{cell}'
    ax.set(title=t, xlabel='Freq. (Hz)', ylabel='Rate (sp/sec)')
    for path in (
        f'reports/tuning/cells/{cell}.png',
        f'reports/tuning/cells/{cell}.pdf',
    ):
        f.savefig(path, bbox_inches='tight')

f.tight_layout()

In [None]:
# Compare, cell by cell, the 'mean' gain ratio against the 'mean' bandwidth-ratio
# metric (bw_ratio_cell).
gain_ratio = get_metric(summary, 'gain_ratio_cell')['mean']
bw_ratio = get_metric(summary, 'bw_ratio_cell')['mean']

# Draw on a dedicated figure. A bare pl.plot() targets whatever axes is
# current, which outside the inline backend is the per-cell figure left over
# from the previous cell.
ratio_fig, ratio_ax = pl.subplots(1, 1)
ratio_ax.plot(gain_ratio, bw_ratio, 'ko')
ratio_ax.set_xlabel('gain_ratio (mean)')
ratio_ax.set_ylabel('bw_ratio (mean)');