## Figure: inference in the Hodgkin–Huxley (HH) model on Allen Cell Types Database recordings

In [None]:
import delfi.utils.io as io
import delfi.utils.viz as viz
import matplotlib as mpl
import matplotlib.pyplot as plt
import model.utils as utils
import numpy as np
import seaborn as sns

from model import utils
from model.HodgkinHuxley import HodgkinHuxley
from model.HodgkinHuxleyStatsMoments import HodgkinHuxleyStatsMoments
from scipy.special import erf

import sys; sys.path.append('../')
from common import col, svg, unit, plot_pdf, samples_nd


# Output directory for the exported SVG panels (IPython shell escape; relies on a POSIX `mkdir`).
!mkdir -p svg/  # create svg subfolder if it does not exist

# matplotlib rc file, applied through `mpl.rc_context(fname=MPL_RC)` when drawing the figures.
MPL_RC = '../.matplotlibrc'

# Output path of the Panel A figure.
PANEL_A = 'svg/fig_allen_a.svg'

# Automatically reload edited modules (e.g. `model.utils`, `common`) before executing each cell.
%load_ext autoreload
%autoreload 2

## Load data

In [None]:
true_params, labels_params = utils.obs_params()

seed = 1
prior_uniform = True
prior_log = False
prior_extent = True
n_xcorr = 0
n_mom = 4
cython = True
n_summary = 7
summary_stats = 1

# One entry per recording: [ephys_cell (specimen ID), sweep_number, A_soma].
# A_soma is forwarded to utils.allen_obs_data (presumably the soma area --
# confirm its meaning/units in model.utils).
list_cells_AllenDB = [[518290966,57,0.0234/126],[509881736,39,0.0153/184],[566517779,46,0.0195/198],
                      [567399060,38,0.0259/161],[569469018,44,0.033/403],[532571720,42,0.0139/127],
                      [555060623,34,0.0294/320],[534524026,29,0.027/209],[532355382,33,0.0199/230],
                      [526950199,37,0.0186/218]]


n_post = len(list_cells_AllenDB)

# define prior
p = utils.prior(true_params=true_params, prior_uniform=prior_uniform,
                prior_extent=prior_extent, prior_log=prior_log, seed=seed)

# SNPE parameters
n_components = 1
n_sims = 125000
n_rounds = 2
svi = False
svi_flag = '_svi' if svi else '_nosvi'

# IBEA parameters (not referenced by the cells below)
algo = 'ibea'
offspring_size = 500
max_ngen = 100

# Offset added to every recorded voltage trace; identical for all cells,
# so it is defined once outside the loop.
junction_potential = -14

# Per-cell results. All lists are appended in list_cells_AllenDB order so that
# later cells can index them with the same cell number.
obs_ls = []
I_ls = []
dt_ls = []
t_on_ls = []
t_off_ls = []
obs_stats_ls = []
m_ls = []
s_ls = []
posterior_ls = []
for ephys_cell, sweep_number, A_soma in list_cells_AllenDB:
    # Recorded trace, shifted by the junction potential.
    obs = utils.allen_obs_data(ephys_cell=ephys_cell, sweep_number=sweep_number, A_soma=A_soma)
    obs['data'] = obs['data'] + junction_potential
    I = obs['I']
    dt = obs['dt']
    t_on = obs['t_on']
    t_off = obs['t_off']

    obs_ls.append(obs)
    I_ls.append(I)
    dt_ls.append(dt)
    t_on_ls.append(t_on)
    t_off_ls.append(t_off)

    obs_stats = utils.allen_obs_stats(data=obs, ephys_cell=ephys_cell, sweep_number=sweep_number,
                                      n_xcorr=n_xcorr, n_mom=n_mom,
                                      summary_stats=summary_stats, n_summary=n_summary)
    obs_stats_ls.append(obs_stats[0])

    # Simulator driven by this cell's stimulus, started from the first recorded voltage.
    m = HodgkinHuxley(I=I, dt=dt, V0=obs['data'][0], seed=seed, cython=cython, prior_log=prior_log)
    m_ls.append(m)

    # Summary-statistics object using this cell's stimulus on/off times.
    s = HodgkinHuxleyStatsMoments(t_on=t_on, t_off=t_off, n_xcorr=n_xcorr, n_mom=n_mom, n_summary=n_summary)
    s_ls.append(s)

    ##############################################################################
    # SNPE results
    # NOTE(review): the file name hard-codes 50000 simulations, whereas
    # n_sims * n_rounds above gives 250000; the stored results were presumably
    # produced with other settings -- confirm before relying on n_sims/n_rounds.
    filename1 = ('./results/allen_' + str(ephys_cell) + '_' + str(sweep_number) +
                 '_run_1_round2_prior0013_param8' + svi_flag + '_ncomp' + str(n_components) +
                 '_nsims' + str(50000) + '_snpe.pkl')  # str(n_sims*n_rounds)

    try:
        _, _, posterior = io.load_pkl(filename1)
        final_posterior = posterior[-1]  # last stored posterior (presumably the final round)
    except Exception as err:
        # Do not skip the cell: posterior_ls must stay index-aligned with
        # obs_ls / m_ls, which later cells index with the same cell number.
        # A None placeholder makes any later use of this cell fail loudly
        # instead of silently pairing it with another cell's posterior.
        print('Could not load SNPE results from {}: {!r}'.format(filename1, err))
        final_posterior = None
    posterior_ls.append(final_posterior)

## Panel A

In [None]:
# Figure size in inches (width, height). Later cells re-assign fig_inches
# before each plot, so this value is only a placeholder.
fig_inches = 8.503941600000001, 2.48474543625

In [None]:
# Mode sample
import cma

# (n_params, 2) array of [lower, upper] prior bounds per parameter.
prior_lims = np.vstack((p.lower, p.upper)).T


def mode_sample(posterior_, init=None, rate=0.01):
    """Search for the mode of ``posterior_`` with CMA-ES.

    Minimises ``-posterior_.eval(x)``, i.e. maximises the value returned by the
    posterior's ``eval`` (presumably its log density -- confirm for the
    posterior class in use).

    Parameters
    ----------
    posterior_ : object with an ``eval(x)`` method
        Posterior to maximise, e.g. an entry of ``posterior_ls``.
    init : array-like, optional
        Initial CMA-ES mean. Defaults to the midpoint of the prior box,
        ``(p.lower + p.upper) / 2``, computed at call time.
    rate : float
        Initial CMA-ES step size; per-parameter steps are scaled by the prior
        width ``p.upper - p.lower``.

    Returns
    -------
    ndarray
        Best parameter vector found (``es.best.x``).
    """
    if init is None:
        # Centre of the prior box. The former default, (upper - lower) / 2, is
        # the box's half-width rather than its centre and can lie outside it.
        init = (prior_lims[:, 0] + prior_lims[:, 1]) / 2
    es = cma.CMAEvolutionStrategy(
        init,
        rate,
        {'scaling_of_variables': (p.upper - p.lower)})
    es.optimize(lambda x: -1. * posterior_.eval(x))
    es.result_pretty()  # print a summary of the optimisation run
    return es.best.x


In [None]:
# Inspect a single cell: locate its posterior mode with CMA-ES and simulate it.
fig_inches = (3.8, 3.8)
cell_num = 0
mn_post = mode_sample(posterior_ls[cell_num])

# Simulated voltage trace at the mode, plotted against sample index.
x_post = m_ls[cell_num].gen_single(mn_post)
plt.plot(
    x_post['data'],
    color=col['CONSISTENT1'],
    lw=1.5,
    label='mode',
    clip_on=False,
    zorder=100,
)

In [None]:
# Plot 5000 samples from the selected cell's posterior via samples_nd, with
# the CMA-ES mode (mn_post) overlaid as a point. Both axes span the prior box.
fig, axes = samples_nd(
    posterior_ls[cell_num].gen(5000),
    limits=prior_lims,
    ticks=prior_lims,
    fig_size=fig_inches,
    diag='kde',
    upper='hist',
    hist_diag={'bins': 50},
    hist_offdiag={'bins': 50},
    kde_diag={'bins': 50, 'color': col['SNPE']},
    kde_offdiag={'bins': 50},
    points=[mn_post],
    points_offdiag={'markersize': 5},
    points_colors=[col['CONSISTENT1']],
)

In [None]:
# Cells considered for Panel A: every recording, in list_cells_AllenDB order.
cell_ls = list(range(n_post))

# Mean (.m) of the first mixture component (xs[0]) of each posterior. With a
# single component (n_components = 1) this presumably coincides with the
# posterior mode, so no CMA-ES search is run here -- revisit if that changes.
mn_post_ls = [posterior_ls[cell_num].xs[0].m for cell_num in cell_ls]

# One simulated trace per cell at that parameter vector (same order as cell_ls).
x_post_ls = [m_ls[cell_num].gen_single(mn)
             for cell_num, mn in zip(cell_ls, mn_post_ls)]

In [None]:
# number of cell recordings plotted
num_rec = 8

# matplotlib takes figsize specified as inches
fig_inches = (8.5, 2.0)


def _style_trace_axis(ax, xlim, ylim):
    """Remove ticks from ``ax``, colour its bottom/left spines white and set its limits."""
    ax.spines['bottom'].set_color('w')
    ax.spines['left'].set_color('w')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)


with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    # c: column / position in cell_ls; cell_num: index into the per-cell lists.
    for c, cell_num in enumerate(cell_ls[:num_rec]):
        y_obs = obs_ls[cell_num]['data']
        t = obs_ls[cell_num]['time']
        duration = np.max(t)
        xlim_int = [-100, duration]
        # Shared y-range for both rows: 25 below the trace's first sample up to 40.
        ylim_int = [-25 + y_obs[0], 40]

        # x_post_ls is built in cell_ls order, so it is indexed by position c.
        x_post = x_post_ls[c]

        # Top row: recorded trace.
        plt.subplot(2, num_rec, c + 1)
        # Title with the ID of the cell actually plotted (cell_num). Indexing
        # by c only matched while cell_ls happened to be 0..N-1.
        plt.title('ID ' + str(list_cells_AllenDB[cell_num][0]))
        plt.plot(t, y_obs, color=col['GT'], lw=1.2, label='Allen Cell Types Database',
                 clip_on=False, zorder=100)
        _style_trace_axis(plt.gca(), xlim_int, ylim_int)

        # Bottom row: simulated trace, with the same axis limits.
        plt.subplot(2, num_rec, c + 1 + num_rec)
        plt.plot(t, x_post['data'], color=col['CONSISTENT1'], lw=1.2, label='mode',
                 clip_on=False, zorder=100)
        _style_trace_axis(plt.gca(), xlim_int, ylim_int)

        if c == 0:
            # Scale bars at the lower-left corner of the first bottom-row panel only.
            delta_bar_x = 500
            bar_x = [xlim_int[0], xlim_int[0] + delta_bar_x]
            delta_bar_y = 80
            bar_y = [ylim_int[0], ylim_int[0] + delta_bar_y]

            plt.plot(bar_x, [ylim_int[0], ylim_int[0]], 'k', linewidth=1.5, clip_on=False, zorder=100)
            plt.xlabel(str(delta_bar_x) + ' ms', x=.17)
            plt.plot([xlim_int[0], xlim_int[0]], bar_y, 'k', linewidth=1.5, clip_on=False, zorder=100)
            plt.ylabel(str(delta_bar_y) + ' mV', y=.30)

    # Lay out once, after every subplot exists; calling this on each loop
    # iteration only repeated the work.
    plt.tight_layout(h_pad=0.01, w_pad=0.01, pad=0.01)

    plt.savefig(PANEL_A, facecolor='None', transparent=True)  # the figure is saved as svg
    plt.show()
    plt.close()
    # Optional post-processing with common.svg, currently disabled:
    #svg(PANEL_A)

## Posteriors for the supplementary figure

In [None]:
import seaborn as sns

LABELS_HH =[r'$g_{Na}$', r'$g_{K}$', r'$g_{l}$', r'$g_{M}$', r'$\tau_{max}$',
            '-'+ r'$V_{T}$', r'$\sigma$', '-'+r'$E_{l}$']

num_post = 8

MPL_RC = '../.matplotlibrc'
with mpl.rc_context(fname=MPL_RC):
    fig, axes = samples_nd([posterior_ls[i].gen(10000) for i in range(num_post)], 
               samples_colors = sns.light_palette(col['SNPE'], n_colors=num_post, reverse=False).as_hex(),
               upper='contour', 
               diag='kde',
               kde_offdiag={'bins': 200},
               title='',
               limits=prior_lims,
               ticks=prior_lims,
               labels=LABELS_HH,
               contour_offdiag={'levels': [0.68]},
               fig_size=(12, 12));

    plt.savefig('svg/posterior_supp_allen.svg', facecolor='None', transparent=True)  # the figure is saved as svg
    plt.show()
    plt.close()

!cp svg/posterior_supp_allen.svg ../5_hh/fig/fig5_hh_supp_posteriors_allen.svg