In [None]:
%%capture
# %%capture must remain the first line of this cell: it captures (hides) the cell's stdout/stderr.
%matplotlib inline

# NOTE(review): timeit, dd, infer, dg, TwoMoons, ds, plot_pdf and probs2contours are not
# used by any of the cells visible here -- confirm before pruning.
import numpy as np
import matplotlib.pyplot as plt
import timeit
import os

import delfi.distribution as dd
import delfi.inference as infer
import delfi.generator as dg

from delfi.simulator import TwoMoons
import delfi.summarystats as ds
from delfi.utils.viz import plot_pdf, probs2contours

from lfimodels.snl_exps.util import save_results, load_results
from snl.util.plot import plot_hist_marginals

# Directory for the individual panel SVGs. File names are appended by plain string
# concatenation further down, so the trailing '/' is required.
fig_path = 'results/figs/'

# panel (a)

In [None]:
# Panel (a): 1-D/2-D marginals of the last-round samples for one validation-set run.
save_path = 'results/gauss_validationset/'

plot_seed = 42
N_PLOT_SAMPLES = 300  # number of (most recent) samples shown in the marginals

exp_id = 'seed' + str(plot_seed)
log, tds, posteriors, _ = load_results(exp_id=exp_id, path=save_path)

# tds[-1][0] is presumably the parameter samples of the final round -- TODO confirm
# against load_results.
pnl_a = plot_hist_marginals(tds[-1][0][-N_PLOT_SAMPLES:], lims=[-5, 5])
pnl_a.set_figwidth(8)
pnl_a.set_figheight(8)
for ax in pnl_a.axes:
    # Ticks are dropped because the panel is embedded in a composite figure.
    ax.set_xticks([])
    ax.set_yticks([])

if not os.path.isdir(fig_path):
    os.makedirs(fig_path)

PANEL_2A = fig_path + 'fig2_a.svg'
# Save the marginals figure explicitly instead of relying on plt.gcf().
pnl_a.savefig(PANEL_2A, facecolor=pnl_a.get_facecolor(), transparent=True)

# plt.show() instead of Figure.show(): the latter warns under the inline backend.
plt.show()

# panel (b)

In [None]:
# Panel (b): placeholder cell. It is meant to load the MMD results but contains no
# code yet (TODO).

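# panel (c)

Panel (c) is not plotted here: the next cell only records the path of its pre-rendered SVG, which the figure-assembly step loads.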

In [None]:
# Path of the pre-rendered panel (c) SVG used when assembling the figure; it is
# presumably produced elsewhere -- make sure it exists before assembling.
PANEL_2C = ''.join([fig_path, 'gauss_mmd_validationset.svg'])

# panel (d)

In [None]:
# Panel (d): MMD vs. number of simulations, mean +/- 1 std across seeds, SNL vs. APT.
save_path = 'results/gauss_noisedims_v1/'

SIMS_PER_ROUND = 1000  # x-axis value of round k is k * SIMS_PER_ROUND simulations
seeds = np.arange(43, 62)  # seeds 43..61

if not os.path.isdir(fig_path):
    os.makedirs(fig_path)

# Load per-round squared MMDs for every seed. APT (SNPE) results are required;
# SNL results may be missing for some seeds, which are reported rather than silently
# dropped.
sq_mmds_snpe_runs = []
sq_mmds_snl_all = []
missing_snl_seeds = []
for seed in seeds:
    exp_id = 'seed' + str(seed)
    exp_dir = os.path.join(save_path, exp_id)
    sq_mmds_snpe_runs.append(np.load(os.path.join(exp_dir, 'all_mmds_N' + str(5000) + '.npy')))
    try:
        sq_mmds_snl_all.append(np.load(os.path.join(exp_dir, 'all_mmds_snl_N' + str(1000) + '.npy')))
    except (IOError, OSError):
        missing_snl_seeds.append(int(seed))

if missing_snl_seeds:
    print('no SNL MMDs for seeds ' + str(missing_snl_seeds))

# Shape (n_seeds, n_rounds). np.stack raises if runs differ in length, instead of
# assuming a hard-coded number of rounds.
sq_mmds_snpe_all = np.stack(sq_mmds_snpe_runs)


def _mmd_mean_std(sq_mmds):
    """Return mean and std across seeds (axis 0) of the square-rooted squared MMDs."""
    # NOTE(review): if these squared-MMD estimates can be negative (unbiased
    # estimator), np.sqrt yields NaN -- confirm how they were computed.
    mmds = np.sqrt(np.asarray(sq_mmds))
    return mmds.mean(axis=0), mmds.std(axis=0)


def _plot_mmd_band(ax, sq_mmds, fmt, color, alpha, label):
    """Plot mean MMD against simulation count (log x-axis) with a +/- 1 std band."""
    mean, std = _mmd_mean_std(sq_mmds)
    # Each method uses its own number of rounds for its x-axis.
    n_sims = SIMS_PER_ROUND * np.arange(1, mean.size + 1)
    ax.semilogx(n_sims, mean, fmt, linewidth=1.5, label=label)
    ax.fill_between(n_sims, mean - std, mean + std, color=color, alpha=alpha)


fig_2d, ax_2d = plt.subplots(figsize=(4, 4))
if sq_mmds_snl_all:  # with no SNL results at all, plot only the APT curve
    _plot_mmd_band(ax_2d, sq_mmds_snl_all, 'rd-', 'r', 0.2, 'SNL')
_plot_mmd_band(ax_2d, sq_mmds_snpe_all, 'kd-', 'k', 0.5, 'APT')

ax_2d.legend()
ax_2d.set_ylabel('Maximum mean discrepancy')
ax_2d.set_xlabel('Number of simulations (log scale)')

PANEL_2D = fig_path + 'fig2_d.svg'
fig_2d.savefig(PANEL_2D, facecolor=fig_2d.get_facecolor(), transparent=True)

plt.show()

# assemble figure

In [None]:
# Run the shared figure helpers with access to this notebook's namespace (-i).
# The assembly cell below uses names it does not define itself -- create_fig,
# add_svg, add_grid, svg, PATH_DROPBOX_FIGS, INKSCAPE. They are presumably provided
# by common.ipynb; confirm.
%run -i ./common.ipynb

In [None]:
# FIGURE and GRID
FIG_HEIGHT_MM = 90
FIG_WIDTH_MM = 140  # set in NIPS2017 notebook to a default value for all figures

FIG_N_ROWS = 2
ROW_1_NCOLS = 2
ROW_1_HEIGHT_MM =      1.2 * (FIG_HEIGHT_MM / FIG_N_ROWS )
ROW_1_WIDTH_COL_1_MM =  1.5* (FIG_WIDTH_MM / ROW_1_NCOLS)

ROW_2_NCOLS = 2
ROW_2_HEIGHT_MM = 0.8 * FIG_HEIGHT_MM / FIG_N_ROWS
ROW_2_WIDTH_COL_1_MM = FIG_WIDTH_MM / ROW_2_NCOLS
ROW_2_WIDTH_COL_2_MM = FIG_WIDTH_MM / ROW_2_NCOLS


fig = create_fig(FIG_WIDTH_MM, FIG_HEIGHT_MM)


fig = add_svg(fig, PANEL_2A, 
              0, 
              0)
fig = add_svg(fig, PANEL_2C, 
              85,
              0) 

yoffset = 5
#fig = add_label(fig, 'C', 
#                0, 
#                ROW_1_HEIGHT_MM + yoffset)
fig = add_svg(fig, PANEL_2D, 
                90, 
                45)


if False:
    fig = add_grid(fig, 2, 2)
    fig = add_grid(fig, 160/3, 10, font_size_px=0.0001)


PATH_SVG = PATH_DROPBOX_FIGS + 'fig2.svg'
fig.save(PATH_SVG)
svg(PATH_SVG)
!$INKSCAPE --export-pdf $PATH_DROPBOX_FIGS/fig2.pdf $PATH_SVG


# Sanity-check Gallery

In [None]:
from lfimodels.snl_exps.util import draw_sample_uniform_prior_33 as fast_sampler

# Sanity check: marginals of samples from posteriors[1] for a few seeds.
save_path = 'results/gauss_noisedims_v1/'
GALLERY_SEEDS = range(43, 47)
N_GALLERY_SAMPLES = 1000

for plot_seed in GALLERY_SEEDS:
    exp_id = 'seed' + str(plot_seed)
    log, tds, posteriors, _ = load_results(exp_id=exp_id, path=save_path)
    # posteriors[1] is presumably the posterior after the second round -- confirm.
    # Separate name so the panel (a) figure `pnl_a` is not overwritten.
    fig_gallery = plot_hist_marginals(fast_sampler(posteriors[1], N_GALLERY_SAMPLES), lims=[-5, 5])
    fig_gallery.set_figwidth(12)
    fig_gallery.set_figheight(12)
    fig_gallery.suptitle(exp_id)  # axes are hidden, so label which seed this is
    for ax in fig_gallery.axes:
        ax.axis('off')

    # plt.show() instead of Figure.show(): the latter warns under the inline backend.
    plt.show()

In [None]:
import pickle
import os

# Sanity check: marginals of the last entry of ps for every seed, after trying to load
# the SNL MAF / posterior pickles (kept in learned_model / all_models).
save_path = 'results/gauss_noisedims_v1/'

for plot_seed in range(43, 62):
    print('seed ' + str(plot_seed))
    exp_id = 'seed' + str(plot_seed)
    exp_dir = os.path.join(save_path, exp_id)

    # Reset first, so a failed load cannot leave the previous seed's models in place.
    learned_model, all_models = None, None
    # NOTE: pickle.load can execute arbitrary code; only load trusted, locally produced files.
    try:
        with open(os.path.join(exp_dir, 'SNL_MAF.pkl'), 'rb') as f:
            learned_model = pickle.load(f)
        with open(os.path.join(exp_dir, 'SNL_posteriors.pkl'), 'rb') as f:
            all_models = pickle.load(f)
    except Exception as err:  # best effort: the plot below does not use these models
        print('failed to load MAFs: ' + repr(err))

    xs = np.load(os.path.join(exp_dir, 'xs.npy'))
    ps = np.load(os.path.join(exp_dir, 'ps.npy'))

    fig_gallery = plot_hist_marginals(ps[-1], lims=[-5, 5])
    fig_gallery.set_figwidth(12)
    fig_gallery.set_figheight(12)
    fig_gallery.suptitle(exp_id)  # axes are hidden, so label which seed this is
    for ax in fig_gallery.axes:
        ax.axis('off')

    # plt.show() instead of Figure.show(): the latter warns under the inline backend.
    plt.show()