In [None]:
import arviz as az
import numpy as np
import os
import pandas as pd
import pickle
import pylab as pl
import seaborn as sns
from scipy import stats
%matplotlib inline

from support import load_rates, get_metric, forest_plot

In [None]:
print(model.model_code)

In [None]:
#filename = 'fits/hyena-rl_sr_th_bound_delta_band-2000.pkl'
#filename = 'fits/hyena-rl_sr_th_delta_band-2000.pkl'
#filename = 'fits/hyena-rl_sr_th_bound_prestim_band-2000.pkl'
#filename = 'fits/marmot-rl_sr_prestim_band-2000.pkl'
#filename = 'fits/marmot-rl_sr_th_delta_unconstrained_prestim_band-2000.pkl'
filename = 'fits/hyena-rl_sr_th_delta_prestim_band-2000.pkl'
folder = 'rlf_' + filename.split('-')[1][6:]
print(folder)

#folder = 'rate_level_band' if 'band' in filename else 'rate_level'
which = 'rlf_band' if 'band' in filename else 'rlf'

with open(filename, 'rb') as fh:
    cells = pickle.load(fh)
    model = pickle.load(fh)
    fit = pickle.load(fh)
    
os.makedirs(f'reports/{folder}/cells', exist_ok=True)

In [None]:
# Trace plots for the population-level parameters, split over two figures.
slope_threshold_params = ['slope_mean', 'slope_sd', 'threshold_beta', 'threshold_alpha']
threshold_delta_params = ['threshold_mean', 'threshold_delta_mean', 'threshold_delta_sd']

az.plot_trace(fit, var_names=slope_threshold_params)
az.plot_trace(fit, var_names=threshold_delta_params)

In [None]:
# Credible-interval width in percent. Later cells read 'hpd 5%' / 'hpd 95%'
# columns, which presumably correspond to this 90% interval -- keep them in
# sync if this value changes.
ci = 90
summary = az.summary(
    fit,
    fmt='xarray',
    credible_interval=ci / 100,
)

In [None]:
# One forest plot per coefficient change: (per-cell variable, population
# variable, axis label). Labels are raw strings so that `\D` in `$\Delta$` is
# not parsed as an (invalid) Python escape sequence.
delta_panels = [
    ('sr_delta_cell', 'sr_delta_mean', r'$\Delta$ SR (lg. re sm.)'),
    ('slope_delta_cell', 'slope_delta_mean', r'$\Delta$ slope (lg. re sm.)'),
    ('threshold_delta_cell', 'threshold_delta_mean', r'$\Delta$ threshold (lg. re sm.)'),
]

f, axes = pl.subplots(1, len(delta_panels), figsize=(12, 4))

for panel_ax, (cell_key, pop_key, label) in zip(axes, delta_panels):
    cell_metric = get_metric(summary, cell_key)
    pop_metric = get_metric(summary, pop_key)
    forest_plot(panel_ax, cell_metric, pop_metric, label, ci)

# bbox_inches='tight' matches how the coef_summary figure is saved, so both
# report figures are cropped the same way.
for path in (
    f'reports/{folder}/coef_delta_summary.eps',
    f'reports/{folder}/coef_delta_summary.pdf',
    f'reports/{folder}/coef_delta_summary.png',
):
    f.savefig(path, bbox_inches='tight')

In [None]:
ci = 90

# One forest plot per baseline coefficient: (per-cell variable, population
# variable, axis label). No reference line is drawn (ref=None).
baseline_panels = (
    ('sr_cell', 'sr_mean', 'SR (sm. pupil)'),
    ('slope_cell', 'slope_mean', 'slope (sm. pupil)'),
    ('threshold_cell', 'threshold_mean', 'threshold (sm. pupil)'),
)

f, axes = pl.subplots(1, 3, figsize=(12, 4))

for panel_index, (cell_key, pop_key, label) in enumerate(baseline_panels):
    cell_metric = get_metric(summary, cell_key)
    pop_metric = get_metric(summary, pop_key)
    forest_plot(axes[panel_index], cell_metric, pop_metric, label, ci, ref=None)

for path in (
    f'reports/{folder}/coef_summary.eps',
    f'reports/{folder}/coef_summary.pdf',
    f'reports/{folder}/coef_summary.png',
):
    f.savefig(path, bbox_inches='tight')

In [None]:
# --- Population-level coefficients -> population_metrics.csv ---------------
population_coefs = [
    'sr_mean',
    'slope_mean',
    'threshold_mean',
    'sr_delta_mean',
    'slope_delta_mean',
    'threshold_delta_mean',
]
population_metrics = summary[population_coefs].to_dataframe().T
population_metrics.to_csv(f'reports/{folder}/population_metrics.csv')

# --- Per-cell coefficients -> cell_metrics*.csv -----------------------------
cell_coefs = [
    'sr_cell',
    'slope_cell',
    'threshold_cell',
    'sr_delta_cell',
    'slope_delta_cell',
    'threshold_delta_cell',
]

# Rows of each per-cell table are relabelled with the cell ids.
# NOTE(review): this assumes `cells` is in the same order as the cell
# dimension of `summary` -- confirm against the fitting script.
cell_index = pd.Index(cells, name='cellid')

per_coef_tables = {}
for coef in cell_coefs:
    table = summary[coef].to_series().unstack('metric')
    table.index = cell_index
    per_coef_tables[coef] = table

# Stack the tables under an outer 'coefficient' index level.
result = pd.concat(per_coef_tables, names=['coefficient'])
result.to_csv(f'reports/{folder}/cell_metrics.csv')

# Posterior means only, one column per coefficient.
cell_means = result['mean'].unstack('coefficient')
cell_means.to_csv(f'reports/{folder}/cell_metrics_mean_only.csv')

In [None]:
# Three side-by-side panels, filled in by the plot_corr calls at the end of this cell.
f, axes = pl.subplots(nrows=1, ncols=3, figsize=(12, 4))

def _interval_errors(frame, lower, upper):
    """Return a (2, n) array of |bound - mean| for the `lower`/`upper` columns."""
    offsets = frame[[lower, upper]].values - frame[['mean']].values
    return np.abs(offsets.T)


def plot_corr(ax, t, td, errors=True, lower='hpd 5%', upper='hpd 95%'):
    """Plot the posterior means of `td` against those of `t`, one point per row.

    Rows whose `td` interval excludes zero (lower bound above 0 or upper bound
    below 0) are drawn in green; all other rows are drawn in black.

    Parameters
    ----------
    ax : object with an ``errorbar`` method (e.g. a matplotlib Axes)
    t, td : DataFrame
        Must share the same index and provide a ``'mean'`` column. The
        `lower` and `upper` columns are read from `td` always, and from `t`
        only when `errors` is true.
    errors : bool
        If true, draw asymmetric error bars spanning the interval bounds on
        both axes; otherwise only the markers are drawn.
    lower, upper : str
        Names of the interval-bound columns. The defaults match the column
        names the rest of this notebook reads from the summary.
    """
    significant = (td[lower] > 0) | (td[upper] < 0)

    for mask, fmt in ((~significant, 'ko'), (significant, 'go')):
        x = t.loc[mask]
        y = td.loc[mask]
        if x.empty:
            # Nothing to draw for this group; skip the empty errorbar call.
            continue
        if errors:
            xerr = _interval_errors(x, lower, upper)
            yerr = _interval_errors(y, lower, upper)
        else:
            xerr = yerr = None
        ax.errorbar(x['mean'], y['mean'], xerr=xerr, yerr=yerr, fmt=fmt, alpha=0.25)
    
# One panel per coefficient: (baseline rows, delta rows, x label, y label,
# x scale). Delta labels are raw strings so that `\D` in `$\Delta$` is not
# parsed as an (invalid) Python escape sequence.
corr_panels = [
    ('threshold_cell', 'threshold_delta_cell', 'Threshold (dB SPL)', r'$\Delta$ threshold (dB)', None),
    ('sr_cell', 'sr_delta_cell', 'SR', r'$\Delta$ SR', 'log'),
    ('slope_cell', 'slope_delta_cell', 'Slope', r'$\Delta$ slope', None),
]

for panel_ax, (base_key, delta_key, xlabel, ylabel, xscale) in zip(axes, corr_panels):
    t = result.loc[base_key]
    td = result.loc[delta_key]
    plot_corr(panel_ax, t, td)
    if xscale is not None:
        panel_ax.set_xscale(xscale)
    panel_ax.set_xlabel(xlabel)
    panel_ax.set_ylabel(ylabel)

f.tight_layout()

In [None]:
td = result.loc['threshold_delta_cell']
sd = result.loc['slope_delta_cell']

f, ax = pl.subplots(1, 1, figsize=(5, 5))
# plot_corr colours points by its second frame, so green marks cells whose
# slope-delta interval excludes zero. Error bars are switched off here.
plot_corr(ax, td, sd, errors=False)
# Raw strings keep `\D` in `$\Delta$` from being parsed as a Python escape.
ax.set_xlabel(r'$\Delta$ threshold')
ax.set_ylabel(r'$\Delta$ slope')

In [None]:
def plot_raw_data(e, s, ax):
    """Scatter the measured rates in `e` against stimulus level.

    `e` must provide `level`, `count`, `time` and `pupil` fields; the plotted
    rate is `count/time`. Points with pupil code 0 are drawn in seagreen and
    code 1 in orchid; any other code raises KeyError. `s` is accepted but not
    used.
    """
    palette = {0: 'seagreen', 1: 'orchid'}
    levels = e['level'].tolist()
    rates = e.eval('count/time').tolist()
    point_colors = [palette[state] for state in e['pupil'].tolist()]
    ax.scatter(levels, rates, s=10, c=point_colors, alpha=0.5)
    

def _rate_level_prediction(level, sr, slope, threshold):
    """Evaluate the threshold-linear rate-level curve at each entry of `level`.

    The rate is `sr` at or below `threshold` and grows linearly with `slope`
    above it; negative values are clipped to zero.
    """
    pred = slope * (level - threshold) + sr
    pred[level <= threshold] = sr
    return np.clip(pred, 0, np.inf)


def plot_fit(er, sr, summary, i, cells, ax):
    """Draw the raw rates and the fitted rate-level curves for one cell.

    Parameters
    ----------
    er : DataFrame
        Rate table indexed by cell id; after `reset_index` it must provide the
        `level`, `count`, `time` and `pupil` fields used by `plot_raw_data`.
    sr : DataFrame
        Table indexed by cell id with `pupil`, `count` and `time`. Its
        `count/time` for pupil codes 0 and 1 is drawn as dotted horizontal
        lines.
    summary : posterior summary passed through to `get_metric`.
    i : int
        Position of the cell in `cells`, also used as the row label in the
        per-cell metric tables.
    cells : sequence of cell ids.
    ax : matplotlib Axes to draw on.

    The curve built from the `*_cell` coefficients is drawn in seagreen and
    the one built from the `*_cell_pupil` coefficients in orchid, matching the
    colours used for pupil codes 0 and 1 in the raw data. The mean of each
    `*_delta_cell` coefficient is written to the right of the axes, coloured
    by the `change` flag returned by `get_metric`.
    """
    cell = cells[i]
    level = np.arange(0, 80)

    evoked = er.loc[cell].reset_index()
    baseline = sr.loc[cell].reset_index()
    plot_raw_data(evoked, baseline, ax)

    baseline_rate = baseline.set_index('pupil').eval('count/time')
    ax.axhline(baseline_rate.loc[0], ls=':', color='seagreen')
    ax.axhline(baseline_rate.loc[1], ls=':', color='orchid')

    # Fetch every coefficient before drawing so a missing variable fails early.
    curves = []
    for suffix, color in (('', 'seagreen'), ('_pupil', 'orchid')):
        sr_fit = get_metric(summary, f'sr_cell{suffix}').loc[i, 'mean']
        slope_fit = get_metric(summary, f'slope_cell{suffix}').loc[i, 'mean']
        threshold_fit = get_metric(summary, f'threshold_cell{suffix}').loc[i, 'mean']
        curves.append((sr_fit, slope_fit, threshold_fit, color))

    for sr_fit, slope_fit, threshold_fit, color in curves:
        pred = _rate_level_prediction(level, sr_fit, slope_fit, threshold_fit)
        ax.plot(level, pred, color=color)

    # Annotations are stacked downwards from the top-right corner, in axes
    # coordinates, just outside the plotting area (x = 1.1).
    text_y = 1
    color_map = {'-': 'r', '=': 'k', '+': 'g'}
    for metric in ('sr_delta', 'slope_delta', 'threshold_delta'):
        # Delta coefficients are compared against 0; anything else against 1.
        ref = 0 if 'delta' in metric else 1
        m = get_metric(summary, f'{metric}_cell', sig_ref=ref).loc[i]
        text_y -= 0.05
        c = color_map[m['change']]
        ax.text(1.1, text_y, f'{metric}: {m["mean"]:.2f}', transform=ax.transAxes, color=c)

In [None]:
rates = load_rates()
er = rates[which]

# Choose the table used for the dotted reference lines in plot_fit.
if 'prestim' not in folder:
    sr = rates['sr']
else:
    # Pool the prestimulus table per cell and pupil state. After summing,
    # clamp 'significant' back into the range [0, 1].
    sr = rates['rlf_band_prestim'].groupby(['cellid', 'pupil']).sum()
    sr['significant'] = sr['significant'].clip(0, 1)

In [None]:
# A single figure is reused for every cell: clear it, redraw, save.
f, ax = pl.subplots(1, 1, figsize=(5, 5))

for i, cell in enumerate(cells):
    ax.cla()
    plot_fit(er, sr, summary, i, cells, ax)
    ax.set_title(f'{cell}')
    ax.set_xlabel('Stim. level (dB SPL)')
    ax.set_ylabel('Rate (sp/sec)')
    for path in (
        f'reports/{folder}/cells/{cell}.png',
        f'reports/{folder}/cells/{cell}.pdf',
    ):
        f.savefig(path, bbox_inches='tight')

# Echo which fit and rate table these reports were generated from.
print(folder)
print(which)