# Figures

Note that for figure 2 of the manuscript, we manually merged panels from `2_glm` and `3_glm` after revisions. Results of the benchmark comparison between SNPE, Rejection ABC, and SMC-ABC (panel d) are in `benchmark/benchmark_results.zip`.

In [None]:
%matplotlib inline

import delfi.distribution as dd
import matplotlib as mpl
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pickle
import seaborn as sns

from model.GLM import GLM
from mpl_toolkits.axes_grid1 import make_axes_locatable

import sys; sys.path.append('../')
from common import col, svg, plot_pdf

!mkdir -p svg/

# Single-round inference result: a dict stored inside a .npy file. Indexing
# the loaded 0-d object array with `[()]` unwraps the dict itself.
# NOTE(review): allow_pickle=True unpickles the file contents; only load
# result files from a trusted source.
inference_result = np.load('results/single_round_lfs.npy', allow_pickle=True)[()]
prior = inference_result['prior']
posterior = inference_result['posterior']

# Observed summary statistics and the parameters that generated them.
ground_truth = np.load('results/ground_truth_data_lfs.npy', allow_pickle=True)[()]
obs_stats = ground_truth['obs_stats']
pars_true = ground_truth['pars_true']
labels_params = ground_truth['labels_params']

# Reference samples, transposed from the stored layout.
# NOTE(review): presumably the MCMC samples shown in the panels below, with
# one row per sample after the transpose; confirm against the file writer.
samples = np.load('results/sam_lfs.npz')['arr_0'].T

# Aliases used by the plotting cells.
plot_post = posterior
plot_prior = prior
pars_raw = pars_true

# Axis labels: one bias term followed by filter weights f_0 .. f_8. This
# replaces the labels read from the ground-truth file above.
params_ls = ['bias'] + ['$f_{}$'.format(i) for i in range(9)]
labels_params = params_ls

# Per-parameter [lower, upper] plotting limits, one row per parameter.
lims_samples = np.array([[-3] * 10,
                         [3] * 10]).T

post_lower = [-3., -2, -1, -1, -1, -3, -3, -3, -3, -3]
post_upper = [3., 2, 3, 3, 2, 3, 3, 3, 3, 3]
lims_post = np.array([post_lower, post_upper]).T

#obs_stats, pars_true, params_ls

## Model illustration

The following plot is used as the stimulus in the model illustration:

In [None]:
# Stimulus trace for the model illustration: the first 100 entries of the
# 'I' field of one simulated trial, drawn without axes.
with mpl.rc_context(fname='../.matplotlibrc'):
    stim_fig = plt.figure(figsize=(3, 2))

    # Fixed seed so the exported trace is reproducible.
    glm = GLM(seed=42)
    trial = glm.gen_single(pars_true)

    window = np.arange(100)
    dark_grey = [0.2, 0.2, 0.2]
    plt.plot(window, trial['I'][window], color=dark_grey)
    plt.axis('off')

    plt.savefig('illustration/stimulus.svg',
                facecolor=stim_fig.get_facecolor(),
                transparent=True)
    plt.show()

The following plot is used as the ground-truth filter in the model illustration:

In [None]:
# Ground-truth filter for the model illustration: every true parameter
# except the first (labelled 'bias'), plotted against its index 1..size-1.
with mpl.rc_context(fname='../.matplotlibrc'):
    filter_fig = plt.figure(figsize=(3, 2))

    positions = np.arange(1, pars_true.size)
    weights = pars_true.flatten()[1:]
    plt.plot(positions, weights, 'o-', color=col['GT'])

    tick_positions = [1, 5, 9]
    tick_labels = np.array(params_ls)[np.array(tick_positions)]
    plt.xticks(tick_positions, tick_labels)
    plt.yticks([0, 1])
    plt.axis([0.8, 9.2, -0.5, 2])
    # The axes (including the ticks set above) are hidden in the export.
    plt.axis('off')

    plt.savefig('illustration/filter.svg',
                facecolor=filter_fig.get_facecolor(),
                transparent=True)
    plt.show()

In [None]:
# Panel A is read from an existing file under illustration/ rather than being
# drawn by a plotting cell in this section.
# NOTE(review): `svg` is imported from `common`; presumably it displays the
# given SVG file inline in the notebook — confirm.
PANEL_A = 'illustration/model.svg'
svg(PANEL_A)

## Summary statistics

In [None]:
# Panel B: observed summary statistics, with every entry after the first
# divided by the first entry.
# NOTE(review): presumably the first statistic is a normalising count and the
# rest form the spike-triggered average named in the title — confirm.
with mpl.rc_context(fname='../.matplotlibrc'):
    panel_fig = plt.figure(figsize=(2.5, 1.6))

    flat_stats = obs_stats.flatten()
    lags = np.arange(obs_stats.size - 1)
    plt.plot(lags, flat_stats[1:] / flat_stats[0], 'o-', color=col['GT'])

    plt.title('spike-triggered average')
    plt.xlabel(r'$\Delta{}t$')
    plt.ylabel('value')
    plt.xticks([0, 4, 8])
    plt.yticks([-0.2, 0.5])

    sns.despine(offset=10, trim=True)
    plt.tight_layout()

    PANEL_B = 'svg/panel_b.svg'
    plt.savefig(PANEL_B,
                facecolor=panel_fig.get_facecolor(),
                transparent=True)
    plt.close()

svg(PANEL_B)

## Comparison on marginals

In [None]:
# Panel C: posterior mean +/- 2 standard deviations for each filter
# parameter, from the fitted posterior (SNPE) and from the reference samples
# (MCMC), together with the true parameter values.
with mpl.rc_context(fname='../.matplotlibrc'):
    fig = plt.figure(figsize=(3.7, 2.1))

    post_mean, post_cov = plot_post.calc_mean_and_cov()
    mcmc_mean = np.mean(samples, axis=0)
    mcmc_cov = np.cov(samples.T)
    truth = pars_true.copy()

    # Drop the first parameter (labelled 'bias') so only the filter remains.
    post_mean, mcmc_mean, truth = post_mean[1:], mcmc_mean[1:], truth[1:]
    post_cov, mcmc_cov = post_cov[1:, 1:], mcmc_cov[1:, 1:]
    n_filter = len(truth)

    positions = np.linspace(1, n_filter, n_filter)
    mcmc_band = 2 * np.sqrt(np.diag(mcmc_cov))
    post_band = 2 * np.sqrt(np.diag(post_cov))

    ax = plt.subplot(1, 1, 1)

    # Shaded +/- 2 s.d. bands first, so the mean lines are drawn on top.
    for centre, band, key in ((mcmc_mean, mcmc_band, 'MCMC'),
                              (post_mean, post_band, 'SNPE')):
        ax.fill_between(positions, centre - band, centre + band,
                        facecolor=col[key], alpha=0.3)

    # Line order also fixes the order of the legend entries.
    for values, key, label in ((truth, 'GT', 'true value'),
                               (mcmc_mean, 'MCMC', 'MCMC'),
                               (post_mean, 'SNPE', 'SNPE')):
        ax.plot(positions, values, '-o', color=col[key], label=label)

    plt.yticks([-2, -1, 0, 1, 2, 3])
    tick_positions = [1, 5, 9]
    plt.xticks(tick_positions, np.array(params_ls)[np.array(tick_positions)])
    plt.xlabel('filter parameter')
    plt.ylabel('value')
    ax.axis([0.9, n_filter + .1, -2.2, 3.3])
    ax.legend()

    sns.despine(offset=10, trim=True)
    plt.tight_layout()

    PANEL_C = 'svg/panel_c.svg'
    plt.savefig(PANEL_C, facecolor=fig.get_facecolor(), transparent=True)
    plt.close()

svg(PANEL_C)

## Covariance matrices

In [None]:
# Panel D: covariance matrices of the filter parameters, from the fitted
# posterior (SNPE, left) and from the reference samples (MCMC, right), drawn
# on one shared colour scale so the two images are directly comparable.
with mpl.rc_context(fname='../.matplotlibrc'):
    fig = plt.figure(figsize=(4.9, 2.3))

    S = plot_post.calc_mean_and_cov()[1]
    cov_samp = np.cov(samples.T)

    # Drop the first row and column (the parameter labelled 'bias').
    S, cov_samp = S[1:, 1:], cov_samp[1:, 1:]

    # Middle grid column is an empty spacer between the two images.
    gs = gridspec.GridSpec(1, 3, width_ratios=[1, 0.2, 1], height_ratios=[1])
    min_cov = np.min([np.min(cov_samp), np.min(S)])
    max_cov = np.max([np.max(cov_samp), np.max(S)])

    panels = ((0, S, ' SNPE covariance'),
              (2, cov_samp, ' MCMC covariance'))
    for grid_col, cov, cb_label in panels:
        ax = plt.subplot(gs[0, grid_col])
        # Colour limits are set on the image; the colorbar follows its
        # mappable. Colorbar.set_clim no longer exists in current matplotlib,
        # so it must not be called here.
        im = plt.imshow(cov, vmin=min_cov, vmax=max_cov)
        plt.axis('off')

        # Attach a narrow colorbar axis to the right of the image.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cb = plt.colorbar(im, cax=cax)
        cb.set_label(cb_label, rotation=90)
        cb.set_ticks([0, max_cov])
        cb.set_ticklabels(['0', str(np.round(max_cov, 2))])
        cb.outline.set_visible(False)

    PANEL_D = 'svg/panel_d.svg'
    plt.savefig(PANEL_D, facecolor=fig.get_facecolor(), transparent=True)
    plt.show()

#svg(PANEL_D)

## Partial posterior

## Compose figure

In [None]:
# Explicit names instead of a wildcard import, so svgutils does not silently
# add or shadow names in the notebook namespace.
from svgutils.compose import Figure, Panel, SVG, Text

# > Inkscape pixel is 1/90 of an inch, other software usually uses 1/72.
# > http://www.inkscapeforum.com/viewtopic.php?f=6&t=5964
svg_scale = 1.25  # set this to 1.25 for Inkscape, 1.0 otherwise

# Panel letters: Arial, 12pt, font weight 800.
kwargs_text = {'size': '12pt', 'font': 'Arial', 'weight': '800'}

# Reference canvas size in pixels; only the width is used below, to place
# the second panel at half the canvas width.
pxw = 720
pxh = 760

# Supplementary comparison figure: marginals (A) next to covariances (B).
# For layout debugging, import Grid from svgutils.compose and append
# Grid(10, 10) as a last argument to Figure.
f = Figure(
    "20.3cm", "4.8cm",
    Panel(
        SVG(PANEL_C).scale(svg_scale).move(0, 0),
        Text("A", 0, 13, **kwargs_text),
    ).move(0, 0),
    Panel(
        SVG(PANEL_D).scale(svg_scale).move(0, 0),
        Text("B", -10, 13, **kwargs_text),
    ).move(pxw*0.5, 0),
)

f.save("fig/fig2_glm_supp_comparison.svg")
svg('fig/fig2_glm_supp_comparison.svg')

### Supplementary figure

In [None]:
# Supplementary figure: full posterior plot produced by `plot_pdf` from
# `common`, given the fitted posterior, the reference samples and the true
# parameters.
# NOTE(review): the meaning of col1/col2/col4 is defined by `plot_pdf`;
# presumably they colour the samples, the posterior and the ground truth
# respectively — confirm in common.
with mpl.rc_context(fname='../.matplotlibrc'):

    pdf_options = dict(
        lims=lims_post,
        gt=pars_raw.reshape(-1),
        figsize=(12, 12),
        resolution=100,
        samples=samples.T,
        contours=True,
        levels=np.asarray([0.95]),
        col1=col['MCMC'],
        col2=col['SNPE'],
        col4=col['GT'],
        ticks=True,
        labels_params=labels_params,
    )
    fig, axes = plot_pdf(plot_post, **pdf_options)

    # On each diagonal axis, tick only the two plotting limits and hide y ticks.
    for i in range(plot_post.ndim):
        diag_ax = axes[i, i]
        lower, upper = lims_post[i]
        diag_ax.set_xticks([lower, upper])
        diag_ax.set_yticks([])

    sns.despine(offset=5, left=True)

    SUPP_1 = 'fig/fig2_glm_supp_posterior.svg'
    plt.savefig(SUPP_1, transparent=True)
    plt.close()

svg(SUPP_1)