In [None]:
import delfi.distribution as dd
import delfi.utils.io as io
import matplotlib as mpl
import matplotlib.pyplot as plt
import model.utils as utils
import numpy as np
import seaborn as sns

from model.HodgkinHuxley import HodgkinHuxley
from model.HodgkinHuxleyStatsMoments import HodgkinHuxleyStatsMoments

import sys; sys.path.append('../')
from common import col, svg, plot_pdf, samples_nd

# Re-import edited project modules automatically (autoreload mode 2 reloads all
# modules before executing code), so changes in model/ or common are picked up.
%load_ext autoreload
%autoreload 2

# Output directory for the individual panel SVGs written by the cells below.
!mkdir -p svg/


# Panel sources/targets. Panel A is a pre-drawn illustration; the rest are generated here.
PANEL_A = 'illustration/panel_a.svg'
PANEL_B, PANEL_C = 'svg/panel_b.svg', 'svg/panel_c.svg'
PANEL_D1, PANEL_D2, PANEL_D3 = 'svg/panel_d1.svg', 'svg/panel_d2.svg', 'svg/panel_d3.svg'
PANEL_E = 'svg/panel_e.svg'

SUPP_POST_1 = 'svg/'

# Axis labels for the 8 model parameters (V_T and E_l are shown with a leading minus).
LABELS_HH = [
    r'$g_{Na}$', r'$g_{K}$', r'$g_{l}$', r'$g_{M}$', r'$\tau_{max}$',
    '-' + r'$V_{T}$', r'$\sigma$', '-' + r'$E_{l}$',
]

# Labels for spike-shape summary statistics.
LABELS_HH_SPIKE_SUMSTATS = [
    r'$f\_rat$', r'$ap\_l$', r'$ap\_o$', r'$r\_pot$',
    r'$\sigma_{r\_pot}$', r'$ahd$', r'$a\_ind$', r'$sp\_w$',
]

# Labels for the 7 summary statistics used in panel C.
LABELS_HH_SUMSTATS = [
    r'$sp$', r'$rpot$', r'$\sigma_{rpot}$',
    '$m_1$', '$m_2$', '$m_3$', '$m_4$',
]

### Panel A

In [None]:
# Panel A is a pre-drawn illustration; `svg` (from common) presumably renders it inline.
svg(PANEL_A)

### Panel B

In [None]:
# Ground-truth parameters used to generate the synthetic observation.
true_params, labels_params = utils.obs_params()

seed = 1
prior_uniform = True
prior_log = False
prior_extent = True
n_xcorr = 0
n_mom = 4
cython = True
summary_stats = 1

# Numbers of summary features for which a trained posterior is loaded below.
# The last entry is the full feature set: the observed summary statistics are computed
# with it so they line up with the statistics produced by s_ls[-1] (used in panel C).
n_summary_ls = [1, 4, 7]
n_post = len(n_summary_ls)

I, t_on, t_off, dt = utils.syn_current()
A_soma = np.pi*((70.*1e-4)**2)  # cm2

obs = utils.syn_obs_data(I, dt, true_params, seed=seed, cython=cython)
obs_stats = utils.syn_obs_stats(data=obs, I=I, t_on=t_on, t_off=t_off, dt=dt, params=true_params,
                                seed=seed, n_xcorr=n_xcorr, n_mom=n_mom, cython=cython,
                                summary_stats=summary_stats, n_summary=n_summary_ls[-1])
y_obs = obs['data']
t = obs['time']
duration = np.max(t)

m = HodgkinHuxley(I, dt, V0=obs['data'][0], seed=seed, cython=cython, prior_log=prior_log)

p = utils.prior(true_params=true_params, prior_uniform=prior_uniform,
                prior_extent=prior_extent, prior_log=prior_log, seed=seed)

# One summary-statistics object and one stored posterior per feature count.
s_ls = []
posterior_ls = []
for nsum in n_summary_ls:
    s = HodgkinHuxleyStatsMoments(t_on=t_on, t_off=t_off, n_xcorr=n_xcorr, n_mom=n_mom, n_summary=nsum)
    s_ls.append(s)
    # alternative run name: apt_mdn_nsum_{nsum}_mogp_lfs
    filename1 = './results/posterior_{nsum}_single_round_maf_lfs.pkl'.format(nsum=nsum)
    # NOTE(review): load_pkl presumably unpickles; only load result files produced locally.
    _, _, posterior = io.load_pkl(filename1)
    posterior_ls.append(posterior)

In [None]:
# Display the ground-truth parameter vector for reference.
true_params

In [None]:
# Mode sample
import cma


def mode_sample(posterior_, init=None, sigma0=0.001):
    """Approximate the mode of a posterior by maximising ``posterior_.eval`` with CMA-ES.

    Parameters
    ----------
    posterior_ : object with an ``eval(x)`` method
        Posterior to maximise. ``eval`` presumably returns a (log-)density
        (delfi distribution) -- TODO confirm which.
    init : array-like, optional
        Starting point of the search. Defaults to the global ``true_params``,
        the start point this notebook has always used.
    sigma0 : float, optional
        Initial CMA-ES step size; the effective per-parameter step is
        ``sigma0 * (p.upper - p.lower)`` because of ``scaling_of_variables``.

    Returns
    -------
    ndarray
        Best parameter vector found by CMA-ES (``es.best.x``).
    """
    if init is None:
        init = true_params
    # Scale each dimension by the prior width so a single step size suits all parameters.
    es = cma.CMAEvolutionStrategy(
        init,
        sigma0,
        {'scaling_of_variables': (p.upper - p.lower)})
    # CMA-ES minimises, so negate the posterior evaluation to find its maximum.
    es.optimize(lambda x: -1. * posterior_.eval(x))
    es.result_pretty()
    return es.best.x



In [None]:
# Shared settings for panels B/C: figure size, prior box, and the two reference points
# (the mode of the 7-feature posterior and a low-probability point at 0.6 x the mode).
fig_inches = (3.8, 3.8)
lower_upper = np.vstack((p.lower, p.upper))
prior_lims = lower_upper.T
mn_post = mode_sample(posterior_ls[2][-1])
post_low = mn_post * 0.6

In [None]:
MPL_RC = '../.matplotlibrc'

# Panel B: pairwise marginals of the 7-feature posterior with the ground truth,
# the mode and the low-probability point overlaid.
panel_b_style = dict(
    limits=prior_lims,
    ticks=prior_lims,
    labels=LABELS_HH,
    fig_size=fig_inches,
    diag='kde',
    upper='kde',
    hist_diag={'bins': 50},
    hist_offdiag={'bins': 50},
    kde_diag={'bins': 50, 'color': col['SNPE']},
    kde_offdiag={'bins': 50},
    points=[true_params, mn_post, post_low],
    points_offdiag={'markersize': 5},
    points_colors=[col['GT'], col['CONSISTENT1'], col['INCONSISTENT']],
)

with mpl.rc_context(fname=MPL_RC):
    fig, axes = samples_nd(posterior_ls[2][-1].gen(10000), **panel_b_style)
    plt.savefig(PANEL_B, facecolor='None', transparent=True)
    plt.show()

In [None]:
# For supplement: one pairwise-marginal figure per posterior (1, 4 and 7 features).
fig_inches = (5.8, 5.8)

prior_lims = np.vstack((p.lower, p.upper)).T
MPL_RC = '../.matplotlibrc'

for i, posterior in enumerate(posterior_ls):
    with mpl.rc_context(fname=MPL_RC):
        # Use the loop variable directly rather than re-indexing posterior_ls.
        fig, axes = samples_nd(posterior[-1].gen(10000),
                               limits=prior_lims,
                               ticks=prior_lims,
                               labels=LABELS_HH,
                               fig_size=fig_inches,
                               diag='kde',
                               upper='kde',
                               hist_diag={'bins': 50},
                               hist_offdiag={'bins': 50},
                               kde_diag={'bins': 50, 'color': col['SNPE']},
                               kde_offdiag={'bins': 50},
                               points=[true_params],
                               points_offdiag={'markersize': 5},
                               points_colors=[col['GT']],
        )

        plt.savefig('svg/posterior_supp_{i}.svg'.format(i=i), facecolor='None', transparent=True)
        plt.show()
    # Release each figure explicitly so repeated iterations don't accumulate open figures.
    plt.close(fig)

In [None]:
# For supplement
fig_inches = (12, 12)

prior_lims = np.vstack((p.lower, p.upper)).T

MPL_RC = '../.matplotlibrc'
with mpl.rc_context(fname=MPL_RC):
    fig, axes = samples_nd([posterior_ls[i][-1].gen(10000) for i in range(3)],
                           limits=prior_lims,
                           ticks=prior_lims,
                           labels=LABELS_HH,
                           fig_size=fig_inches,
                           title='',
                           diag='kde',
                           upper='contour',
                           samples_colors = sns.light_palette(col['SNPE'], n_colors=3, reverse=False).as_hex(),
                           hist_diag={'bins': 50},
                           hist_offdiag={'bins': 50},
                           kde_diag={'bins': 50, 'color': col['SNPE']},
                           kde_offdiag={'bins': 200},
                           points=[true_params],
                           points_offdiag={'markersize': 5},
                           points_colors=[col['GT']],
    );

    plt.savefig('svg/posterior_supp_combined.svg', facecolor='None', transparent=True)
    plt.show()
    
!cp svg/posterior_supp_combined.svg fig/fig5_hh_supp_posteriors.svg

### Panel C

In [None]:
# Sample from the prior, simulate each sample, and compute the per-feature standard
# deviation of the summary statistics (panel C normalises features by it).
# This step takes around 20 seconds when run on 5 parallel processes.

from multiprocessing import Pool

num_rep_prior = 1000

# sample from prior
params_prior = p.gen(num_rep_prior)

# simulate
n_processes = 5
seeds_model = np.arange(1, num_rep_prior+1, 1)  # one simulator seed per simulation


def sim_f(param):
    """Simulate one prior sample.

    ``param`` is one row of ``params_seed``: the simulator seed in column 0 followed by
    the model parameters. Returns the result of ``HodgkinHuxley.gen_single``.
    """
    # Pass prior_log like the main model `m` does, so parameters are interpreted the
    # same way here as in the rest of the notebook.
    sim = HodgkinHuxley(I, dt, V0=obs['data'][0], seed=int(param[0]), cython=cython,
                        prior_log=prior_log)
    return sim.gen_single(param[1:])


params_seed = np.concatenate((seeds_model.reshape(-1, 1), params_prior), axis=1)

# The context manager tears the workers down even if a simulation raises; map() has
# already collected every result by the time the block exits.
with Pool(n_processes) as pool:
    data_prior = pool.map(sim_f, params_seed)

# compute summary statistics (takes a few secs)
std_stats_prior = np.nanstd(s_ls[-1].calc(data_prior), axis=0)

In [None]:
# Simulate num_rep times at the posterior mode (mn_post) and at the low-probability
# point (post_low), collecting traces and summary features for both.
num_rep = 100

stats_fn = s_ls[-1]
x_high_ls, sum_stats_high_ls = [], []
x_low_ls, sum_stats_low_ls = [], []
targets = ((mn_post, x_high_ls, sum_stats_high_ls),
           (post_low, x_low_ls, sum_stats_low_ls))

for rep in range(num_rep):
    # Keep the high/low simulations interleaved: m.gen_single may consume shared random
    # state, so the call order matters for reproducibility.
    for theta, traces, feats in targets:
        trace = m.gen_single(theta)
        traces.append(trace)
        feats.append(stats_fn.calc([trace])[0])

# Last simulated trace/features of each kind (the traces are plotted in panel C).
x_high, x_low = x_high_ls[-1], x_low_ls[-1]
sum_stats_high, sum_stats_low = sum_stats_high_ls[-1], sum_stats_low_ls[-1]

mn_stats_high, std_stats_high = (np.nanmean(sum_stats_high_ls, axis=0),
                                 np.nanstd(sum_stats_high_ls, axis=0))
mn_stats_low, std_stats_low = (np.nanmean(sum_stats_low_ls, axis=0),
                               np.nanstd(sum_stats_low_ls, axis=0))

In [None]:
n_summary_stats = n_summary_ls[-1]

# Express every feature in units of its standard deviation under the prior, so
# features with very different scales can share one axis.
obs_stats_norm_mat = obs_stats[0]/std_stats_prior
arg_sort_stats = np.arange(n_summary_stats)  # display order of the features (identity)
LABELS_HH_SUMSTATS1 = np.array(LABELS_HH_SUMSTATS)

mn_stats_high_norm = mn_stats_high/std_stats_prior
mn_stats_low_norm = mn_stats_low/std_stats_prior

std_stats_high_norm = std_stats_high/std_stats_prior
std_stats_low_norm = std_stats_low/std_stats_prior

# matplotlib takes figsize specified as inches
fig_inches = (5.7, 1.5)

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    # Left: observed trace vs. traces simulated at the mode and at the low-probability point.
    ax = plt.subplot(121)
    plt.plot(t, obs['data'], color = col['GT'], lw=2, label='observation')
    plt.plot(t, x_high['data'], color = col['CONSISTENT1'], lw=2, label='mode')
    plt.plot(t, x_low['data'], color = col['INCONSISTENT'], lw=2, label='low prob')
    plt.xlabel('time (ms)')
    plt.ylabel('voltage (mV)')
    sns.despine(offset=10, trim=True)

    ax.set_xticks([0, duration/2, duration])
    ax.set_yticks([-80, -20, 40])

    # Right: normalised features, three bars per feature (observation, mode, low probability).
    width = 0.3
    feat_pos = np.arange(n_summary_stats)
    ax = plt.subplot(122)
    plt.bar(feat_pos, obs_stats_norm_mat[arg_sort_stats],
            width, color=col['GT'], label='observation', edgecolor='w')
    plt.bar(feat_pos+width, mn_stats_high_norm[arg_sort_stats],
            width, color=col['CONSISTENT1'], yerr=std_stats_high_norm[arg_sort_stats], label='mode', edgecolor='w')
    plt.bar(feat_pos+2*width, mn_stats_low_norm[arg_sort_stats],
            width, color=col['INCONSISTENT'], yerr=std_stats_low_norm[arg_sort_stats], label='low probability', edgecolor='w')
    ax.set_xlim(-1.5*width, n_summary_stats+width/2)
    # Bars are centred on their x positions (matplotlib's default align='center'), so the
    # middle of each three-bar group -- where its label belongs -- is at feat_pos + width.
    ax.set_xticks(feat_pos+width)
    ax.set_xticklabels(LABELS_HH_SUMSTATS1[arg_sort_stats])
    plt.ylabel(r'$\frac{f}{\sigma_{f \ PRIOR}}$')
    plt.legend(bbox_to_anchor=(1.75, 0.5))

    ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%g'))
    plt.ylim([-5, 1])
    ax.set_yticks([-5, -2, 1])
    sns.despine(offset=10, trim=True)

    plt.savefig(PANEL_C, facecolor='None', transparent=True)  # the figure is saved as svg
    plt.show()

### Panels D and E

In [None]:
# Posterior mode for each of the three posteriors (1, 4 and 7 features); all modes are
# found first, then one simulation is run at each mode.
final_posteriors = [posterior_ls[0][-1], posterior_ls[1][-1], posterior_ls[-1][-1]]
modes = [mode_sample(post) for post in final_posteriors]
mode_feat1, mode_feat4, mode_feat7 = modes

# One column per posterior.
post_modes = np.concatenate([mode.reshape(-1, 1) for mode in modes], axis=1)

x_feat1, x_feat4, x_feat7 = (m.gen_single(mode) for mode in modes)

In [None]:
fig_inches = (1.9, 1.9)

# Title suffix per posterior: singular for the first, plural for the others.
label_feature = [' feature'] + [' features'] * (n_post - 1)

PANEL_D_ls = [PANEL_D1, PANEL_D2, PANEL_D3]

for i in range(n_post):
    with mpl.rc_context(fname='../.matplotlibrc'):
        # Only show parameters 0, 1, 2 and 5 (g_Na, g_K, g_l and V_T).
        partial_ls = [0, 1, 2, 5]
        idx = partial_ls
        labels_params_select = np.array(LABELS_HH)[idx]

        plot_post = posterior_ls[i][-1]

        fig, axes = samples_nd(plot_post.gen(100000),
                               limits=prior_lims,
                               ticks=prior_lims,
                               labels=LABELS_HH,
                               fig_size=fig_inches,
                               diag='kde',
                               upper='kde',
                               hist_diag={'bins': 50},
                               hist_offdiag={'bins': 50},
                               kde_diag={'bins': 50, 'color': col['SNPE']},
                               kde_offdiag={'bins': 50},
                               points=[true_params],
                               points_offdiag={'markersize': 5},
                               points_colors=[col['GT']],
                               subset=partial_ls,
        )

        # Title position is given in multiples of the current axes' data range
        # (hand-tuned); the first title is longer, hence its larger left shift.
        x0, xmax = plt.xlim()
        y0, ymax = plt.ylim()
        data_width = xmax - x0
        data_height = ymax - y0

        if i == 0:
            title = '1 feature (spike count)'
            x_shift = -2.9
        else:
            title = str(n_summary_ls[i]) + str(label_feature[i])
            x_shift = -1.4
        plt.text(x0 + data_width*x_shift, y0 + data_height*4.8, title, fontsize=8)

        plt.savefig(PANEL_D_ls[i], facecolor='None', transparent=True)  # the figure is saved as svg
        plt.show()

In [None]:
# Panel E: observed trace vs. one simulation at the mode of each posterior.
# matplotlib takes figsize specified as inches
fig_inches = (2.6, 1.3)

with mpl.rc_context(fname=MPL_RC):
    fig = plt.figure(figsize=fig_inches)

    plt.plot(t, obs['data'], color=col['GT'], lw=2, label='observation')
    # Line style distinguishes the number of features: solid 7, dashed 4, dotted 1.
    plt.plot(t, x_feat7['data'], color=col['CONSISTENT1'], alpha=1, lw=1.1, label='7 features')
    plt.plot(t, x_feat4['data'], '--', color=col['CONSISTENT1'], alpha=1, lw=1.1, label='4 features')
    plt.plot(t, x_feat1['data'], ':', color=col['CONSISTENT1'], alpha=1, lw=1.1, label='1 feature')
    plt.xlabel('time (ms)')
    plt.ylabel('voltage (mV)')
    plt.legend(bbox_to_anchor=(1.1, -0.5), ncol=2)

    ax = plt.gca()
    ax.set_xticks([0, duration/2, duration])
    ax.set_yticks([-80, -10, 60])

    sns.despine(offset=10, trim=True)

    plt.savefig(PANEL_E, facecolor='None', transparent=True)  # the figure is saved as svg
    plt.show()

### Compose figure

In [None]:
from svgutils.compose import *

# > Inkscape pixel is 1/90 of an inch, other software usually uses 1/72.
# > http://www.inkscapeforum.com/viewtopic.php?f=6&t=5964
svg_scale = 1.25  # set this to 1.25 for Inkscape, 1.0 otherwise

# Panel letters in Helvetica Neue, 12pt, Medium
kwargs_text = {'size': '12pt', 'font': 'Arial', 'weight': '800'}

f = Figure("20.3cm", "18.9cm",

    Panel(
          SVG(PANEL_A).scale(svg_scale).scale(0.58).move(0, 15),
          Text("A", 0, 13, **kwargs_text),
    ).move(0, 0),

    Panel(
          SVG(PANEL_B).scale(svg_scale).move(-10,0),
          Text("B", -20, 13, **kwargs_text),
    ).move(480, 0),

    Panel(
          SVG(PANEL_C).scale(svg_scale).move(0, 0),
          Text("C", 0, 13, **kwargs_text),
    ).move(0, 160),

    Panel(
        SVG(PANEL_D1).scale(svg_scale).move(0, 0),
        SVG(PANEL_D2).scale(svg_scale).move(170, 0),
        SVG(PANEL_D3).scale(svg_scale).move(340, 0),
        Text("D", 0, 13, **kwargs_text),
    ).move(0, 325),

    Panel(
        SVG(PANEL_E).scale(svg_scale).move(-8, 5),
        Text("E", 0, 13, **kwargs_text),
    ).move(520, 325),

    Panel(
        SVG('../6_allen/svg/fig_allen_a.svg').scale(svg_scale).move(0, 0),
        Text("F", 0, 15, **kwargs_text),
    ).move(0, 525),
           
    #Grid(50, 50),
)

!mkdir -p fig
f.save("fig/fig5_hh.svg")
svg('fig/fig5_hh.svg')