# SNPE & RF

Running this notebook requires the following repositories to be installed locally and checked out on the specified branches:
- delfi 
- lfimodels / maprf_elife 
- maprf_mn  / elife      


In [None]:
#%%capture
%matplotlib inline

import delfi.distribution as dd
from delfi.utils.viz import plot_pdf, plot_marg_axes
from lfimodels.maprf.utils import setup_sim, get_data_o

import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches

# Execute common.ipynb inside this kernel's namespace (-i). Several names are used
# below without being defined in this notebook -- np, plt, PATH_DROPBOX,
# FIG_WIDTH_MM, COL, svg, create_fig, add_svg, add_label, add_grid, INKSCAPE --
# presumably they are provided by that notebook; TODO confirm.
%run -i common.ipynb

#import matplotlib.pyplot as plt
#import numpy as np

# Folder below which all mapRF result files read by this notebook are located.
root_path = PATH_DROPBOX + 'materials/fig2/mapRF'

In [None]:
# FIGURE AND GRID
# Layout constants (names suffixed _MM) for a grid with one row and two columns.
# The ROW_1_* values are assigned again in the "compose figure" section before
# they are used to place any panel.
FIG_HEIGHT_MM = 65
# NOTE(review): this rescales the inherited value in place, so re-running this cell
# without re-running common.ipynb multiplies the width by 1.15 once more.
FIG_WIDTH_MM = 1.15*FIG_WIDTH_MM  # set in common notebook to a default value for all figures
FIG_N_ROWS = 1
ROW_1_NCOLS = 2
ROW_1_HEIGHT_MM = FIG_HEIGHT_MM / FIG_N_ROWS
# the two columns share the row width in a 1.1 : 0.9 ratio
ROW_1_WIDTH_COL_1_MM = 1.1*FIG_WIDTH_MM / ROW_1_NCOLS
ROW_1_WIDTH_COL_2_MM = 0.9*FIG_WIDTH_MM / ROW_1_NCOLS
# ROW_2_NCOLS = 1
# ROW_2_HEIGHT_MM = FIG_HEIGHT_MM / FIG_N_ROWS
# ROW_2_WIDTH_COL_1_MM = FIG_WIDTH_MM / ROW_2_NCOLS

# Folder into which every panel SVG and the composed figure are written.
PATH_DROPBOX_FIGS = PATH_DROPBOX + 'figs/'

# PATHS
# Panel A is only read by this notebook (displayed and placed), never written.
PANEL_A = PATH_DROPBOX_FIGS + 'fig2_a.svg'


In [None]:
# Display the existing panel-A SVG; `svg` is not defined in this notebook (presumably supplied by common.ipynb).
svg(PANEL_A)

# temporal GLM results

In [None]:
# Locations of the four panel-B SVGs (temporal GLM results). This notebook only
# places these files in the composed figure; it does not create them.
PANEL_B_1, PANEL_B_2, PANEL_B_3, PANEL_B_4 = tuple(
    PATH_DROPBOX_FIGS + svg_name
    for svg_name in ('fig2_b_1.svg', 'fig2_b_2.svg', 'fig2_b_3.svg', 'fig2_b_4.svg')
)


# mapRF results

In [None]:
## training data and true parameters, data, statistics
seed = 42
idx_cell = 6 # load toy cell number i 

tmp = np.load(root_path+'/results/SNPE/toycell_6/ground_truth_data.npy')[()]
obs_stats, pars_true, rf = tmp['obs_stats'],  tmp['pars_true'], tmp['rf']

sim_info = np.load(root_path +'/results/sim_info.npy')[()]
d, params_ls = sim_info['d'], sim_info['params_ls']

assert obs_stats[0,-1] == 307 # the cell we want to work with should have this number of spikes

labels_params = ['bias', 'gain', 'phase', 'freq', 'angle', 'ratio', 'width', 'xo', 'yo']


# load SNPE results

In [None]:
# Stored SNPE result for the selected toy cell. allow_pickle=True is required by
# NumPy >= 1.16.3 because the file is read as a pickled dict (see the string keys below).
# NOTE(review): unpickling can execute arbitrary code -- only load files you trust.
tmp = np.load(root_path + '/results/SNPE/toycell_' + str(idx_cell)
              + '/maprf_100k_elife_prior01_run_9_round4_param9_nosvi_CDELFI_posterior.npy',
              allow_pickle=True)[()]
posterior, proposal, prior = tmp['posterior'], tmp['proposal'], tmp['prior']

# Per-parameter transform specification in the order
# bias, gain, phase, freq, angle, ratio, width, xo, yo.
# Defined once so that prior and posterior cannot drift apart.
# NOTE(review): presumably flag 0 = untransformed, 1 = log, 2 = logit onto
# [lower, upper] -- TODO confirm against the delfi branch in use.
par_flags = [0,0,2,1,2,1,1,2,2]
par_lower = [0,0,0,0,0,0,0,-1,-1]
par_upper = [0,0,np.pi,0,2*np.pi,0,0,1,1]

# Each constructor gets its own list copies, exactly as with the original literals.
plot_prior = dd.TransformedNormal(m=prior.m, S = prior.S,
                            flags=list(par_flags),
                            lower=list(par_lower), upper=list(par_upper)) 

plot_post = dd.mixture.TransformedGaussianMixture.MoTG(
                            ms= [posterior.xs[i].m for i in range(posterior.n_components)], 
                            Ss =[posterior.xs[i].S for i in range(posterior.n_components)],
                            a = posterior.a, 
                            flags=list(par_flags),
                            lower=list(par_lower), upper=list(par_upper)) 

# plotting limits, one (lower, upper) row per parameter after the transpose
lims_post = np.array([[-1.5, -1.1, .001,         0,          .001, 0, 0, -.999, -.999], 
                 [ 1.5,  1.1, .999*np.pi, 2.5,   1.999*np.pi, 2, 4., .999,   .999]]).T

# load comparison against MCMC sampler

In [None]:
# Load the stored MCMC run for the selected toy cell.
n_samples=1000000
path = root_path + '/results/MCMC/'
savefile = path + 'toycell_' + str(idx_cell) + '/maprf_MCMC_prior01_run_1_'+ str(n_samples)+'samples_param9_5min.npy'
# allow_pickle=True is required by NumPy >= 1.16.3: the file is read as a pickled dict.
# NOTE(review): unpickling can execute arbitrary code -- only load files you trust.
tmp = np.load(savefile, allow_pickle=True)[()]

T, params_dict_true = tmp['T'], tmp['params_dict_true']

# Column order of `samples`; everything indexed per parameter below relies on it.
params_ls = ['bias', 'gain', 'phase', 'freq','angle','ratio','width', 'xo', 'yo']
# atleast_2d(...).T turns each 1-D trace into a column before stacking side by side.
samples = np.hstack([np.atleast_2d(T[key].T).T for key in params_ls])

def symmetrize_sample_modes(samples):
    """Spread samples over the mirrored modes of the 9-parameter model.

    Two mirroring passes are applied. In each pass a transformed copy is stacked
    under the current samples and every second row of the stack is kept, so the
    result has as many rows as the input:

    1. angle (column 4) shifted by pi: values above pi are lowered by pi, values
       below pi are raised by pi, values exactly equal to pi are left unchanged;
    2. gain (column 1) negated together with phase (column 2) mapped to
       pi - phase.

    Parameters
    ----------
    samples : np.ndarray, shape (n_samples, 9)
        Columns ordered bias, gain, phase, freq, angle, ratio, width, xo, yo.

    Returns
    -------
    np.ndarray, shape (n_samples, 9)
        Mix of original and mirrored rows. The input array is not modified.

    Raises
    ------
    AssertionError
        If the shape is wrong, any phase lies outside [0, pi], any angle lies
        outside [0, 2*pi], or any freq / ratio / width is negative.
    """

    assert samples.ndim==2 and samples.shape[1] == 9 

    # assumes phase in [0, pi]
    # (the upper bound must be tested on the maximum itself; np.max(x <= pi) would
    #  only check that at least one sample is in range)
    assert np.min(samples[:,2]) >= 0. and np.max(samples[:,2]) <= np.pi
    # assumes angle in [0, 2*pi]
    assert np.min(samples[:,4]) >= 0. and np.max(samples[:,4]) <= 2*np.pi
    # assumes freq, ratio and width > 0
    assert np.all(np.min(samples[:,np.array([3,5,6])], axis=0) >= 0.)

    # pass 1: angle -> angle +/- pi; both masks refer to the ORIGINAL angles so that
    # no row is shifted twice
    samples1 = samples.copy()
    idx = np.where( samples[:,4] > np.pi )[0]
    samples1[idx,4] = samples1[idx,4] - np.pi
    idx = np.where( samples[:,4] < np.pi )[0]
    samples1[idx,4] = samples1[idx,4] + np.pi
    # keep every second row of the stack -> row count stays at n_samples
    samples_all = np.vstack((samples, samples1))[::2, :]

    # pass 2: (gain, phase) -> (-gain, pi - phase)
    samples1 = samples_all.copy()
    samples1[:,1] = - samples1[:,1] 
    samples1[:,2] = np.pi - samples1[:,2] 
    samples_all = np.vstack((samples_all, samples1))[::2, :]

    return samples_all

# Spread the MCMC samples over the mirrored modes (row count is unchanged).
samples = symmetrize_sample_modes(samples)

# Ground-truth parameters in the SAME column order as `samples`, `labels_params`
# and `lims_samples`: bias, gain, phase, freq, angle, ratio, width, xo, yo.
# freq must precede angle; listing them the other way round puts the two
# ground-truth markers on each other's axes.
pars_raw = np.array([ params_dict_true['glm']['bias'],
                      params_dict_true['kernel']['s']['gain'],
                      params_dict_true['kernel']['s']['phase'] + 0.05, # remove phase a bit from left interval border
                      params_dict_true['kernel']['s']['freq'],
                      params_dict_true['kernel']['s']['angle'],
                      params_dict_true['kernel']['s']['ratio'],
                      params_dict_true['kernel']['s']['width'],
                      params_dict_true['kernel']['l']['xo'],
                      params_dict_true['kernel']['l']['yo'] ])

# plotting limits for the sample-based plots, one (lower, upper) row per parameter
lims_samples = np.array([[-1.5, -1.1, .00001*np.pi, 0, 0.301*np.pi, 0, 0, -0.5, -0.5], 
                         [ 1.5,  1.1, .99999*np.pi, 3, 1.699*np.pi, 3, 5, 0.5, 0.5]]).T

## panel for summary statistics

In [None]:
# Show the observed summary statistics as a grey-scale image titled 'STA': every entry
# of obs_stats[0] except the last one, reshaped to d x d. (The last entry is the value
# compared against the expected spike count when the data are loaded.)
plt.figure(figsize=(2, 2))
plt.imshow(obs_stats[0,:-1].reshape(d,d), interpolation='None', cmap='gray')
plt.title('STA')
plt.tight_layout()
plt.axis('off')

# option to add contours of ground-truth RF
# NOTE(review): `g` and `lvls` are only assigned in the later "panel for posterior
# samples" cell, so switching add_gt to True raises NameError on a fresh kernel
# unless that cell has been run first.
add_gt = False
if add_gt:    
    rfm = g.model.params_to_rf(pars_true.reshape(-1))[0]
    plt.contour(rfm, levels=[lvls[0]*rfm.min(), lvls[1]*rfm.max()], colors='r')

# Written to disk so that the "compose figure" cell can place it.
PANEL_C_1 = PATH_DROPBOX_FIGS + 'fig2_c_1.svg'
plt.savefig(PANEL_C_1, facecolor=plt.gcf().get_facecolor(), transparent=True)


## panel for parameters

In [None]:
# Draw the ground-truth receptive field `rf` as a small axis-free image and write it
# to disk as panel C_2 for the composed figure.
fig_gt, ax_gt = plt.subplots(figsize=(2, 2))
ax_gt.imshow(rf, interpolation='None')
ax_gt.set_title('ground-truth')
fig_gt.tight_layout()
ax_gt.axis('off')

PANEL_C_2 = PATH_DROPBOX_FIGS + 'fig2_c_2.svg'
fig_gt.savefig(PANEL_C_2, facecolor=fig_gt.get_facecolor(), transparent=True)
plt.show()


## panel for (partial) posterior

In [None]:
# Restrict posterior and prior to a subset of parameters and plot that partial posterior.
idx = np.array([1,2,4]) # gain, phase and angle (entries 1, 2 and 4 of labels_params)

labels_params_select = np.array(labels_params)[idx]

# Posterior over the selected parameters: per component keep the selected mean
# entries and the matching covariance sub-block; mixture weights are unchanged.
# NOTE(review): the class is reached as dd.mixture.MoTG here but as
# dd.mixture.TransformedGaussianMixture.MoTG where plot_post is built -- confirm
# that both paths resolve to the same class.
plot_post_select = dd.mixture.MoTG(ms=[x.m[idx] for x in plot_post.xs],
                                  Ss=[x.S[idx][:,idx] for x in plot_post.xs],
                                  a=plot_post.a,
                                  flags=plot_post.flags[idx],
                                  lower=plot_post.lower[idx],
                                  upper=plot_post.upper[idx]                                  
                                 )

# Prior over the same parameters. It is currently not drawn: the pdf2 argument of
# the plot_pdf call below is commented out.
plot_prior_select = dd.TransformedNormal(m=plot_prior.m[idx], S = plot_prior.S[idx][:,idx],
                            flags=plot_prior.flags[idx],
                            lower=plot_prior.lower[idx], 
                            upper=plot_prior.upper[idx]) 

# Posterior density together with the MCMC samples and the ground-truth values,
# all restricted to the selected parameters.
fig, _ = plot_pdf(plot_post_select,  #pdf2=plot_prior_select, 
                  lims=lims_samples[idx], gt=pars_raw.reshape(-1)[idx], 
                  figsize=(5,5), resolution=100, 
                  samples=samples[:,idx],
                  contour_levels=[0.68, 0.95],
                  contour_colors=('w', 'y'),
                  hist_color='orange',
                  pdf1_color=COL['SNPE'],
                  labels_params=labels_params_select);
fig.tight_layout()

# NOTE: the "all marginals" cell below saves to this same file name.
PANEL_C_3 = PATH_DROPBOX_FIGS + 'fig2_c_3.svg'
fig.savefig(PANEL_C_3, facecolor=plt.gcf().get_facecolor(), transparent=True)

## version with all marginals + selected pair-wise marginals

In [None]:
# Alternative panel C_3: all nine 1-D marginals side by side (left) plus four
# selected pairwise marginals (right).
plt.figure(figsize=(7.5,4))

# dx  : width reserved along the x-axis for each parameter's marginal
# M   : number of parameters (lims_samples has one row per parameter)
# nx  : number of histogram bin edges on the normalised [0, 1] axis
# nxp : number of grid points at which the densities are evaluated
# (M and nxp have to be defined here; otherwise this cell only runs when a stale
#  kernel still holds them from another cell.)
dx, M, nx, nxp = 10, len(lims_samples), 70, 200

fontsize = 14

# display order of the marginals: slot k shows parameter idx[k]
idx = np.array([0,1,6,3,5,4,2,8,7])
# parameter-index pairs for the pairwise panels
pairs = [[0,1], # bias  vs gain
         [1,2], # gain  vs phase
         [2,4], # phase vs angle
         [7,8]] # xo vs yo
         #[5,6]] # ratio vs width


# pairwise marginals fill a 2 x 2 block in grid columns 1 and 2
gs = gridspec.GridSpec(2, 3, width_ratios=[2.5, 1, 1], height_ratios=[1,1])
for j in range(4):

    ax = plt.subplot(gs[(j%2),1+(j>1)])
    plot_marg_axes(ax=ax, 
                   i=pairs[j][0], j=pairs[j][1],
                   pdf=plot_post, 
                   samples=samples, 
                   lims=lims_samples, 
                   gt=pars_raw, 
                   bins=100, 
                   resolution=100,
                   cmap=None,
                   contours=True,
                   contour_levels=(0.68, 0.95),
                   contour_colors=('w','y'),
                   scatter=False,
                   scatter_color='gray',
                   scatter_alpha=0.2)
    plt.xlabel(labels_params[pairs[j][1]], fontsize=fontsize)
    plt.ylabel(labels_params[pairs[j][0]], fontsize=fontsize)
plt.tight_layout()    



# 1-D marginals share one axis; slot k is drawn at x-offset (M - k) * dx
gs = gridspec.GridSpec(1, 3, width_ratios=[3, 1, 1], height_ratios=[1])
ax = plt.subplot(gs[0,0])

xh = np.linspace(0, 1, nx) 
for slot, par in enumerate(idx):
    offset = (M - slot) * dx

    S = samples[:,par].copy()
    mm, MM = lims_post[par] #S.min(), S.max()
    yy = np.linspace(mm, MM, nxp)
    S = (S - mm) / (MM - mm)       # samples rescaled to [0, 1] within the limits
    xx = np.linspace(0, 1, nxp) 
    
    # `density` replaces the `normed` keyword, which newer NumPy no longer accepts;
    # both normalise identically for equal-width bins
    h, xh_ = np.histogram(S, bins=xh, density=True)
    Mh = h.max() / (0.9*dx)        # scale so the tallest bar fills 90% of a slot
    h = h / Mh
    # positional (y, width) works across Matplotlib versions (the first argument was
    # renamed from `bottom` to `y`); xh[:-1] are left bin edges, hence align='edge'
    ax.barh(xh[:-1], h, height=1.1*np.diff(xh)[0], left=offset, color='orange', align='edge')
    
    # densities are multiplied by (MM - mm) to match the rescaled axis
    ff = plot_post.eval(yy.reshape(-1,1), ii=[par], log=False)*(MM - mm)
    ax.plot(ff/Mh + offset, xx, color=COL['SNPE'], linewidth=2.5)
    ff = plot_prior.eval(yy.reshape(-1,1), ii=[par], log=False)*(MM - mm)
    ax.plot(ff/Mh + offset, xx, color=COL['IBEA'], linewidth=1.5)

    ax.text(offset, -0.05, params_ls[par], fontsize=fontsize, rotation=45)
    
    # ground-truth value on the rescaled axis
    ax.plot(offset, (pars_raw[par]-mm)/(MM-mm), 'r*', markersize=9)
    
ax.axis('off')    

"""
# in case variables are grouped by transformation ( idx = [1,0,6,3,5,4,2,7,8] )
plt.plot([0.9*dx, 4.5*dx], [-0.1, -0.1], 'k')
plt.text(2.4*dx, -0.17, 'logit-normal', fontsize=fontsize)
plt.plot([4.9*dx, 7.5*dx], [-0.12, -0.12], 'k')
plt.text(5.9*dx, -0.17, 'log-normal', fontsize=fontsize)
plt.plot([7.9*dx, 9.5*dx], [-0.11, -0.11], 'k')
plt.text(8.4*dx, -0.17, 'normal', fontsize=fontsize)
"""

# NOTE: same file name as the partial-posterior panel saved in the previous cell;
# whichever of the two cells runs last determines the file on disk.
PANEL_C_3 = PATH_DROPBOX_FIGS + 'fig2_c_3.svg'
plt.savefig(PANEL_C_3, facecolor=plt.gcf().get_facecolor(), transparent=True)

plt.show()


### (vertical version)

In [None]:

# Canvas for the vertical layout: stacked 1-D marginals in the left grid column, pairwise marginals in the right one.
plt.figure(figsize=(8,5))

def to_texstring(s):
    """Return `s` with every '<', '>' and '|' enclosed in '$' math delimiters."""
    for plain, wrapped in (("<", r"$<$"), (">", r"$>$"), ("|", r"$|$")):
        s = s.replace(plain, wrapped)
    return s

styles = mpatches.ArrowStyle.get_styles()
# Pick the arrow style BEFORE the annotation loop below uses it (first style name
# in sorted order); assigning it after the loop fails on a fresh kernel.
stylename, styleclass = sorted(styles.items())[0]            


# dx  : height reserved along the y-axis for each parameter's marginal
# M   : number of parameters (lims_samples has one row per parameter)
# nx  : number of histogram bin edges on the normalised [0, 1] axis
# nxp : number of grid points at which the densities are evaluated
dx, M, nx, nxp = 10, len(lims_samples), 70, 200
M_ = M * 1.15   # denominator used to convert a slot number into a figure fraction

fontsize = 14

# display order of the marginals: slot k shows parameter idx[k]
idx = np.array([0,1,2,4,3,5,6,7,8])
# pairs of SLOT numbers (positions in the stack of marginals), not parameter indices
pairs = [[0,1], # bias  vs gain
         [1,2], # gain  vs phase
         [2,3], # phase vs angle
         [5,6]] # ratio vs width



gs = gridspec.GridSpec(1, 2, width_ratios=[1.5, 1], height_ratios=[1])
ax = plt.subplot(gs[0,0])

xh = np.linspace(0, 1, nx) 
for slot, par in enumerate(idx):
    offset = (M - slot) * dx

    S = samples[:,par].copy()
    mm, MM = lims_post[par] #S.min(), S.max()
    yy = np.linspace(mm, MM, nxp)
    S = (S - mm) / (MM - mm)       # samples rescaled to [0, 1] within the limits
    xx = np.linspace(0, 1, nxp) 
    
    # `density` replaces the `normed` keyword, which newer NumPy no longer accepts;
    # both normalise identically for equal-width bins
    h, xh_ = np.histogram(S, bins=xh, density=True)
    Mh = h.max() / (0.9*dx)        # scale so the tallest bar fills 90% of a slot
    h = h / Mh
    # positional (x, height) works across Matplotlib versions (the first argument was
    # renamed from `left` to `x`); xh[:-1] are left bin edges, hence align='edge'
    ax.bar(xh[:-1], h, width=np.diff(xh)[0], bottom=offset, color='orange', align='edge')
    # densities are multiplied by (MM - mm) to match the rescaled axis
    ff = plot_post.eval(yy.reshape(-1,1), ii=[par], log=False)*(MM - mm)
    ax.plot(xx, ff/Mh + offset, color=COL['SNPE'], linewidth=2.5)
    ff = plot_prior.eval(yy.reshape(-1,1), ii=[par], log=False)*(MM - mm)
    ax.plot(xx, ff/Mh + offset, color=COL['IBEA'], linewidth=1.5)

    ax.text(-0.25, dx*(M-slot+0.35) , params_ls[par], fontsize=fontsize, rotation=45)
    
    # ground-truth value on the rescaled axis
    ax.plot((pars_raw[par]-mm)/(MM-mm), offset, 'r*', markersize=9)    
    #plt.plot([i*dx, i*dx], [0, 1], 'k')
    
"""
# in case variables are grouped by transformation ( idx = [1,0,6,3,5,4,2,7,8] )
ax.plot([-0.15, -0.15], [0.9*dx, 4.5*dx], 'k')
ax.text(-0.23, 3.4*dx, 'logit-normal', fontsize=fontsize, rotation=90)
ax.plot([-0.17, -0.17], [4.9*dx, 7.5*dx], 'k')
ax.text(-0.23, 6.9*dx, 'log-normal', fontsize=fontsize, rotation=90)
ax.plot([-0.16, -0.16], [7.9*dx, 9.5*dx], 'k')
ax.text(-0.23, 8.9*dx, 'normal', fontsize=fontsize, rotation=90)
"""
ax.axis('off')
plt.title('marginal distributions (prior vs. SNPE posterior estimate vs. MCMC)', fontsize=fontsize)


gs = gridspec.GridSpec(4, 2, width_ratios=[1.5, 1], height_ratios=[1,1,1,1])
for j in range(4):
    ax = plt.subplot(gs[j,1])
    # Translate slot numbers into parameter indices: slot 3 shows parameter 4
    # (angle), so passing the raw pair [2,3] would plot phase vs freq instead of
    # the intended phase vs angle.
    plot_marg_axes(ax=ax, 
                   i=idx[pairs[j][0]], j=idx[pairs[j][1]],
                   pdf=plot_post, 
                   samples=samples, 
                   lims=lims_samples, 
                   gt=pars_raw, 
                   bins=100, 
                   resolution=100,
                   cmap=None,
                   contours=True,
                   contour_levels=(0.68, 0.95),
                   contour_colors=('w','y'),
                   scatter=False,
                   scatter_color='gray',
                   scatter_alpha=0.2)
    # connector lines (in figure fractions) from the two marginals of the pair to
    # this pairwise panel: two short ticks, a vertical bracket, a stub, a diagonal
    for k in range(2):
        plt.annotate('', (0.65,  (M-pairs[j][k])/M_), (0.64, (M-pairs[j][k])/M_),
                xycoords="figure fraction", textcoords="figure fraction",            
                arrowprops=dict(arrowstyle=stylename, fc="k", ec="k", shrinkA=0, shrinkB=0),)
    plt.annotate('', (0.65, (M-pairs[j][0])/M_), (0.65,  (M-pairs[j][1])/M_),
            xycoords="figure fraction", textcoords="figure fraction",            
            arrowprops=dict(arrowstyle=stylename, fc="k", ec="k", shrinkA=0, shrinkB=0),)
    plt.annotate('', (0.65, (M-np.mean(pairs[j]))/M_), (0.7,  (M-np.mean(pairs[j]))/M_),
            xycoords="figure fraction", textcoords="figure fraction",            
            arrowprops=dict(arrowstyle=stylename, fc="k", ec="k", shrinkA=0, shrinkB=0),)
    plt.annotate('', (0.7, (M-np.mean(pairs[j]))/M_), (0.82,  (4-j)/4*(M/M_)),
            xycoords="figure fraction", textcoords="figure fraction",            
            arrowprops=dict(arrowstyle=stylename, fc="k", ec="k", shrinkA=0, shrinkB=0),)


plt.tight_layout()
#plt.savefig('marginal_plots_v0.pdf')
plt.show()

## panel for posterior samples 

In [None]:

# this snippet of code requires the mapRF repository (to instantiate g.model)
g, prior, d = setup_sim(seed, path=root_path)
filename = root_path + '/results/toy_cells/toy_cell_' + str(idx_cell) + '.npy'
obs_stats, pars_true = get_data_o(filename, g, seed)
rf = g.model.params_to_rf(pars_true)[0]

lvls, n_draws=[0.5, 0.5], 10 
plt.figure(figsize=(4,4))
plt.imshow(obs_stats[0,:-1].reshape(d,d), interpolation='None', cmap='gray')
for i in range(n_draws):
    rfm = g.model.params_to_rf(posterior.gen().reshape(-1))[0]
    plt.contour(rfm, levels=[lvls[0]*rfm.min(), lvls[1]*rfm.max()], colors=[COL['SNPE']], linewidth=2)
    plt.hold(True)
plt.title('RF posterior draws')

rfm = g.model.params_to_rf(pars_true.reshape(-1))[0]
plt.contour(rfm, levels=[lvls[0]*rfm.min(), lvls[1]*rfm.max()], colors='r', linewidth=2.5)
plt.tight_layout()
plt.axis('off')

PANEL_C_4 = PATH_DROPBOX_FIGS + 'fig2_c_4.svg'
plt.savefig(PANEL_C_4, facecolor=plt.gcf().get_facecolor(), transparent=True)
plt.show()


# compose figure

In [None]:
# FIGURE and GRID
# Layout constants for the composed figure; they replace the FIG_* / ROW_1_* values
# assigned in the first layout cell near the top of the notebook.
FIG_HEIGHT_MM = 200
# NOTE(review): self-assignment, i.e. a no-op. The value is whatever the kernel
# currently holds, which includes the 1.15 rescaling applied in the first layout
# cell -- not the untouched default the trailing comment suggests.
FIG_WIDTH_MM = FIG_WIDTH_MM  # set in NIPS2017 notebook to a default value for all figures
FIG_N_ROWS = 4
ROW_1_NCOLS = 1
ROW_1_HEIGHT_MM =     0.6 * (FIG_HEIGHT_MM / FIG_N_ROWS )
ROW_1_WIDTH_COL_1_MM =  3 *(FIG_WIDTH_MM / ROW_1_NCOLS)

ROW_2_NCOLS = 3
ROW_2_HEIGHT_MM = 1.1 * FIG_HEIGHT_MM / FIG_N_ROWS
# the three column factors 0.27 + 1.63 + 1.1 add up to 3, i.e. the full width
ROW_2_WIDTH_COL_1_MM = 0.27*FIG_WIDTH_MM / ROW_2_NCOLS
ROW_2_WIDTH_COL_2_MM = 1.63*FIG_WIDTH_MM / ROW_2_NCOLS
ROW_2_WIDTH_COL_3_MM = 1.1 *FIG_WIDTH_MM / ROW_2_NCOLS
# NOTE(review): ROW_3_NCOLS is 1, yet three column widths are defined, each equal
# to the full figure width -- confirm whether ROW_3_NCOLS should be 3.
ROW_3_NCOLS = 1
ROW_3_HEIGHT_MM = 1.1 * FIG_HEIGHT_MM / FIG_N_ROWS
ROW_3_WIDTH_COL_1_MM = FIG_WIDTH_MM / ROW_3_NCOLS
ROW_3_WIDTH_COL_2_MM = FIG_WIDTH_MM / ROW_3_NCOLS
ROW_3_WIDTH_COL_3_MM = FIG_WIDTH_MM / ROW_3_NCOLS

In [None]:
# Assemble the final figure: place the panel SVGs and the panel letters on one
# canvas, write it as fig2.svg, display it and convert it to PDF with Inkscape.
# create_fig, add_svg, add_label, add_grid and svg are not defined in this notebook
# (presumably supplied by common.ipynb). The two numeric arguments of add_svg /
# add_label are an x and a y position, presumably in mm -- TODO confirm.
# xoffset / yoffset are re-assigned several times below, so statement order matters.
fig = create_fig(FIG_WIDTH_MM, FIG_HEIGHT_MM)

# --- panel A (top row)
yoffset = -2.2
xoffset = -1.
fig = add_svg(fig, PANEL_A, 
              5 + xoffset, 
              0 + yoffset)


# --- panels B_1 .. B_4 (temporal GLM results)
yoffset = 30
fig = add_svg(fig, PANEL_B_1, 
              0 + xoffset, 
              4 + yoffset)
fig = add_svg(fig, PANEL_B_2, 
              0 + xoffset , 
              ROW_1_HEIGHT_MM + yoffset) 
fig = add_svg(fig, PANEL_B_3, 
              ROW_2_WIDTH_COL_1_MM + xoffset + 10,                        
              yoffset + 2.1)  
fig = add_svg(fig, PANEL_B_4, 
              ROW_2_WIDTH_COL_1_MM + ROW_2_WIDTH_COL_2_MM + xoffset - 5, 
              yoffset + 10)

# --- letters for the first two rows
yoffset = 2.5
fig = add_label(fig, 'A', 
                0, 
                0 + yoffset)
yoffset = 5
fig = add_label(fig, 'B', 
                0, 
                ROW_1_HEIGHT_MM + yoffset)
fig = add_label(fig, 'C', 
                ROW_2_WIDTH_COL_1_MM + 5, 
                ROW_1_HEIGHT_MM + yoffset)
fig = add_label(fig, 'D', 
                ROW_2_WIDTH_COL_1_MM + ROW_2_WIDTH_COL_2_MM, 
                ROW_1_HEIGHT_MM + yoffset)

# --- panels C_1 .. C_4 (mapRF results); yoffset is still 5 from the labels above
fig = add_svg(fig, PANEL_C_2, 
              0 + xoffset, 
              0 + yoffset + ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + 2)
fig = add_svg(fig, PANEL_C_1, 
              0 + xoffset, 
              ROW_1_HEIGHT_MM + ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + yoffset - 2.5) 
fig = add_svg(fig, PANEL_C_3, 
              ROW_2_WIDTH_COL_1_MM + xoffset + 10,                   
              yoffset + ROW_2_HEIGHT_MM+ ROW_1_HEIGHT_MM + 2.1)  
fig = add_svg(fig, PANEL_C_4, 
              ROW_2_WIDTH_COL_1_MM + ROW_2_WIDTH_COL_2_MM + xoffset, 
              yoffset + ROW_2_HEIGHT_MM + ROW_1_HEIGHT_MM + 1.8)

# --- letters for the third row
yoffset = 11
fig = add_label(fig, 'E', 
                0,                        
                ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + yoffset)
fig = add_label(fig, 'F', 
                ROW_2_WIDTH_COL_1_MM + 5, 
                ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + yoffset)
fig = add_label(fig, 'G', 
                ROW_2_WIDTH_COL_1_MM + ROW_2_WIDTH_COL_2_MM, 
                ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + yoffset)



# --- letters for a fourth row
# NOTE(review): no add_svg call in this cell places panels below the third row,
# so the letters H, I and J appear without a panel -- confirm they are intended.
yoffset = 21
fig = add_label(fig, 'H', 
                0,                        
                ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + ROW_3_HEIGHT_MM + yoffset)
fig = add_label(fig, 'I', 
                ROW_2_WIDTH_COL_1_MM + 5, 
                ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + ROW_3_HEIGHT_MM + yoffset)
fig = add_label(fig, 'J', 
                ROW_2_WIDTH_COL_1_MM + ROW_2_WIDTH_COL_2_MM, 
                ROW_1_HEIGHT_MM + ROW_2_HEIGHT_MM + ROW_3_HEIGHT_MM + yoffset)


# Alignment grid for manual positioning; permanently disabled.
if False:
    fig = add_grid(fig, 2, 2)
    fig = add_grid(fig, 160/3, 10, font_size_px=0.0001)

PATH_SVG = PATH_DROPBOX_FIGS + 'fig2.svg'
fig.save(PATH_SVG)
svg(PATH_SVG)
# Shell call with IPython variable expansion. PATH_DROPBOX_FIGS already ends with
# '/', so the output path contains a double slash.
# NOTE(review): --export-pdf is the option name of Inkscape releases before 1.0 --
# confirm against the installed version.
!$INKSCAPE --export-pdf $PATH_DROPBOX_FIGS/fig2.pdf $PATH_SVG

# supplementary figure: full posterior

In [None]:
# Supplementary figure: the full SNPE posterior (pdf1) against the prior (pdf2) on
# the limits lims_post, without MCMC samples. The ground-truth marker is pars_true
# passed through plot_post._f (presumably the map into the plotted parameter
# space -- TODO confirm).
gt_full_posterior = plot_post._f(pars_true.reshape(1,-1)).reshape(-1)
full_posterior_style = dict(figsize=(16,16),
                            resolution=100,
                            contour_levels=[0.68, 0.95],
                            contour_colors=('w', 'y'),
                            hist_color='orange',
                            pdf1_color=COL['SNPE'],
                            pdf2_color=COL['GT'],
                            labels_params=labels_params)
fig, _ = plot_pdf(plot_post, pdf2=plot_prior, lims=lims_post, gt=gt_full_posterior,
                  **full_posterior_style)


In [None]:
# Same full-posterior plot, this time on the limits lims_samples with the MCMC
# samples overlaid and pars_raw as ground truth; the result is saved as a PDF
# relative to the current working directory.
full_samples_style = dict(figsize=(16,16),
                          resolution=100,
                          samples=samples,
                          contour_levels=[0.68, 0.95],
                          contour_colors=('w', 'y'),
                          hist_color='orange',
                          pdf1_color=COL['SNPE'],
                          pdf2_color=COL['GT'],
                          labels_params=labels_params)
fig, _ = plot_pdf(plot_post, pdf2=plot_prior, lims=lims_samples, gt=pars_raw.reshape(-1),
                  **full_samples_style);
fig.savefig('posterior_final_full.pdf')