# figures 2 (SLCP) and 3 (SLCP with added noise)
- loads an APT example fit as produced by SLCP.ipynb, and an SNL fit from the original SNL package (python-2) 
- loads evaluation metrics as computed by APT_eval.ipynb and SNL_eval.ipynb on model fits from SLCP_fit.ipynb and SLCP_addedNoise_fit.ipynb
- loads a panel (fig 2c) produced by the SNL package (python 2) 


In [None]:
%%capture
%matplotlib inline

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import timeit
import os

import delfi.distribution as dd
import delfi.inference as infer
import delfi.generator as dg

from delfi.simulator import TwoMoons
import delfi.summarystats as ds
from delfi.utils.viz import plot_pdf, probs2contours

from lfimodels.snl_exps.util import save_results, load_results
from lfimodels.snl_exps.util import draw_sample_uniform_prior_33, load_gt_gauss, init_g_gauss
from snl.util.plot import plot_hist_marginals

from inspect import getmembers, isclass

# Directory prefix for the individual panel SVGs written by the cells below.
fig_path = 'results/figs/'
# Font size used for the theta axis labels of the marginal-plot panels.
fontsize = 16

def rasterize_and_save(fname, rasterize_list=None, fig=None, dpi=None,
                       savefig_kw=None):
    """Save a figure with raster and vector components

    This function lets you specify which objects to rasterize at the export
    stage, rather than within each plotting call. Rasterizing certain
    components of a complex figure can significantly reduce file size.

    Inputs
    ------
    fname : str
        Output filename with extension
    rasterize_list : list (or object)
        List of objects to rasterize (or a single object to rasterize)
    fig : matplotlib figure object
        Defaults to current figure
    dpi : int
        Resolution (dots per inch) for rasterizing
    savefig_kw : dict
        Extra keywords to pass to the figure's ``savefig``. The dict is
        copied, so the caller's object is never modified. Defaults to no
        extra keywords.

    If rasterize_list is not specified, then all objects whose string
    representation contains 'QuadMesh', 'Contour' or 'collections'
    (e.g., ``scatter, fill_between`` etc) will be rasterized

    Note: does not work correctly with round=True in Basemap

    Example
    -------
    Rasterize the contour, pcolor, and scatter plots, but not the line

    >>> import matplotlib.pyplot as plt
    >>> from numpy.random import random
    >>> X, Y, Z = random((9, 9)), random((9, 9)), random((9, 9))
    >>> fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(ncols=2, nrows=2)
    >>> cax1 = ax1.contourf(Z)
    >>> cax2 = ax2.scatter(X, Y, s=Z)
    >>> cax3 = ax3.pcolormesh(Z)
    >>> cax4 = ax4.plot(Z[:, 0])
    >>> rasterize_list = [cax1, cax2, cax3]
    >>> rasterize_and_save('out.svg', rasterize_list, fig=fig, dpi=300)
    """

    # Behave like pyplot and act on current figure if no figure is specified
    fig = plt.gcf() if fig is None else fig

    # Work on a private copy: the dpi entry added below must leak neither
    # into the caller's dict nor into later calls through a shared default.
    savefig_kw = {} if savefig_kw is None else dict(savefig_kw)

    # Need to set_rasterization_zorder in order for rasterizing to work
    zorder = -5  # Somewhat arbitrary, just ensuring less than 0

    if rasterize_list is None:
        # Have a guess at stuff that should be rasterised
        types_to_raster = ['QuadMesh', 'Contour', 'collections']
        rasterize_list = []

        print("""
        No rasterize_list specified, so the following objects will
        be rasterized: """)
        # Get all axes, and then get objects within axes
        for ax in fig.get_axes():
            for item in ax.get_children():
                if any(x in str(item) for x in types_to_raster):
                    rasterize_list.append(item)
        print('\n'.join([str(x) for x in rasterize_list]))
    elif not isinstance(rasterize_list, list):
        # Allow rasterize_list to be input as an object to rasterize
        rasterize_list = [rasterize_list]

    # The set of patch classes does not depend on the item, so build it once
    # (and only when there is something to inspect).
    all_patch_types = (
        tuple(x[1] for x in getmembers(matplotlib.patches, isclass))
        if rasterize_list else ()
    )

    for item in rasterize_list:

        # Whether or not plot is a contour plot is important
        is_contour = (isinstance(item, matplotlib.contour.QuadContourSet) or
                      isinstance(item, matplotlib.tri.TriContourSet))

        # Whether or not current item is list of patches; objects that are
        # not subscriptable raise TypeError and are treated as single artists
        try:
            is_patch_list = isinstance(item[0], all_patch_types)
        except TypeError:
            is_patch_list = False

        # Convert to rasterized mode and then change zorder properties
        if is_contour:
            curr_ax = item.ax.axes
            curr_ax.set_rasterization_zorder(zorder)
            # For contour plots, need to set each part of the contour
            # collection individually
            # NOTE(review): `.collections` is deprecated in recent Matplotlib
            # releases -- confirm against the installed version.
            for contour_level in item.collections:
                contour_level.set_zorder(zorder - 1)
                contour_level.set_rasterized(True)
        elif is_patch_list:
            # For list of patches, need to set zorder for each patch
            for patch in item:
                curr_ax = patch.axes
                curr_ax.set_rasterization_zorder(zorder)
                patch.set_zorder(zorder - 1)
                patch.set_rasterized(True)
        else:
            # For all other objects, we can just do it all at once
            curr_ax = item.axes
            curr_ax.set_rasterization_zorder(zorder)
            item.set_rasterized(True)
            item.set_zorder(zorder - 1)

    # dpi is a savefig keyword argument, but treat it as special since it is
    # important to this function
    if dpi is not None:
        savefig_kw['dpi'] = dpi

    # Save resulting figure
    fig.savefig(fname, **savefig_kw)


In [None]:
%run -i ./common.ipynb

# panel (a)

In [None]:
load_gt_gauss?

In [None]:
# Panel (a): pairwise marginals of samples drawn from the last stored
# posterior for seed 42, with the ground-truth parameters overlaid.
save_path = 'results/gauss_validationset/'
plot_seed = 42
exp_id = 'seed{}'.format(plot_seed)

pars_true, obs_stats = load_gt_gauss(init_g_gauss(seed=plot_seed))
log, tds, posteriors, _ = load_results(exp_id=exp_id, path=save_path)

apt_samples = draw_sample_uniform_prior_33(posteriors[-1], 1000)
pnl_a = plot_hist_marginals(apt_samples, lims=[-5, 5],
                            gt=pars_true.flatten(), rasterized=False)
pnl_a.set_figwidth(6)
pnl_a.set_figheight(6)

# Strip all ticks; only the parameter names are shown.
for ax in pnl_a.axes:
    ax.set_xticks([])
    ax.set_yticks([])

# The last five axes carry the labels theta_1 ... theta_5.
for k, ax in enumerate(pnl_a.axes[-5:], start=1):
    ax.set_xlabel(r'$\theta_%d$' % k, fontsize=fontsize)

PANEL_2A = fig_path + 'fig2_a.svg'

savefig_kw = dict(
    facecolor=plt.gcf().get_facecolor(),
    transparent=True,
    bbox_inches='tight',
)

rasterize_and_save(PANEL_2A, rasterize_list=pnl_a.axes, fig=pnl_a, dpi=600,
                   savefig_kw=savefig_kw)

pnl_a.show()

# panel (b)

In [None]:
# Panel (b): pairwise marginals of the last 1000 stored SNL posterior samples.
# `pars_true` is taken from the panel (a) cell.
save_path = '../../../snl/data/results/seed_42/gauss/'

samples = np.load(save_path + 'final_snl_post_samples_N5000.npy')

snl_samples = samples[-1000:]
pnl_b = plot_hist_marginals(snl_samples, lims=[-5, 5], gt=pars_true.flatten(),
                            rasterized=False, upper=True)
pnl_b.set_figwidth(6)
pnl_b.set_figheight(6)

# Strip all ticks; only the parameter names are shown.
for ax in pnl_b.axes:
    ax.set_xticks([])
    ax.set_yticks([])

# Axes that receive the labels theta_1 ... theta_5, in this order.
for k, axis_idx in enumerate((0, 5, 9, 12, -1), start=1):
    pnl_b.axes[axis_idx].set_xlabel(r'$\theta_%d$' % k, fontsize=fontsize)

PANEL_2B = fig_path + 'fig2_b.svg'

savefig_kw = dict(
    facecolor=plt.gcf().get_facecolor(),
    transparent=True,
    bbox_inches='tight',
)

rasterize_and_save(PANEL_2B, rasterize_list=pnl_b.axes, fig=pnl_b, dpi=600,
                   savefig_kw=savefig_kw)

pnl_b.show()



# panel (c)

In [None]:
# Panel (c) is not drawn in this notebook; this is only the path of the SVG
# that is placed into the assembled figure (per the notebook header, produced
# by the SNL package).
PANEL_2C = fig_path +'gauss_mmd_validationset_large.svg'

# assemble figure

In [None]:
# FIGURE and GRID
# Assemble figure 2 from the panel SVGs (a-c), then export it as SVG and PDF.
# create_fig / add_svg / add_label / svg as well as PATH_DROPBOX_FIGS and
# INKSCAPE are not defined in the visible cells of this notebook -- presumably
# they come from common.ipynb (run in an earlier cell); confirm.
FIG_HEIGHT_MM = 60
FIG_WIDTH_MM = 160  # set in NIPS2017 notebook to a default value for all figures

# NOTE(review): FIG_N_ROWS and the ROW_* values are computed here but are not
# referenced again in this cell.
FIG_N_ROWS = 2
ROW_1_NCOLS = 2
ROW_1_HEIGHT_MM =      1.2 * (FIG_HEIGHT_MM / FIG_N_ROWS )
ROW_1_WIDTH_COL_1_MM =  1.5* (FIG_WIDTH_MM / ROW_1_NCOLS)

ROW_2_NCOLS = 2
ROW_2_HEIGHT_MM = 0.8 * FIG_HEIGHT_MM / FIG_N_ROWS
ROW_2_WIDTH_COL_1_MM = FIG_WIDTH_MM / ROW_2_NCOLS
ROW_2_WIDTH_COL_2_MM = FIG_WIDTH_MM / ROW_2_NCOLS


fig = create_fig(FIG_WIDTH_MM, FIG_HEIGHT_MM)


# The two numeric arguments of add_svg / add_label are presumably the x and y
# position of the element (units as defined by the helpers) -- TODO confirm.
fig = add_svg(fig, PANEL_2A, 
              3, 
              0)
fig = add_label(fig, 
                'a)', 
                0, 
                5)
fig = add_svg(fig, PANEL_2B, 
              28,
              0) 
fig = add_label(fig, 
                'b)', 
                25, 
                5)

fig = add_svg(fig, PANEL_2C, 
            90,
              -2) 
fig = add_label(fig, 
                'c)', 
                89, 
                5)

# NOTE(review): the SVG path is built by plain concatenation while the PDF
# target below inserts a '/', so PATH_DROPBOX_FIGS presumably ends with a path
# separator -- confirm.
PATH_SVG = PATH_DROPBOX_FIGS + 'fig2.svg'
fig.save(PATH_SVG)
svg(PATH_SVG)
# Shell escape: convert the assembled SVG to PDF with Inkscape.
!$INKSCAPE --export-pdf $PATH_DROPBOX_FIGS/fig2.pdf $PATH_SVG


# NOISE-dims figure (fig 3)

Somewhat confusing naming of experiments (it grew historically):
- 'v1' : d =  20 (m=12) aka 'this should kill SNL'
- 'v2' : d = 100 (m=92) aka 'this *has to* kill SNL'
- 'v3' : d =  60 (m=52) 
- 'v4' : d =  40 (m=32)

In [None]:
# prep: reference MMD levels that are drawn as horizontal lines in the panels
save_path = 'results/gauss_noisedims_v1/'
seeds = np.arange(43,62)

# Entry 0 of each stored SNL MMD array is taken as the per-seed prior value.
prior_mmds = np.zeros(len(seeds))
for i, seed in enumerate(seeds):
    exp_id = 'seed' + str(seed)
    mmd_file = os.path.join(save_path, exp_id, 'all_mmds_snl_N1000.npy')
    prior_mmds[i] = np.load(mmd_file)[0]

# Average in the square-root domain, then square back.
mean_root_prior_mmd = np.mean(np.sqrt(prior_mmds))
prior_mmd = mean_root_prior_mmd**2

# Alternative value, not used:
#   0.0009931369858651173 -- computed from MMD of N=1000 against 5000 ground-truth samples
posterior_mmd = 0.0017898981617274767 # computed from MMD of N=1000 against 4000 *other* ground-truth samples

In [None]:
# Panel d = 20 ('v1'): square-root MMD per round for APT and SNL, one curve
# per seed. `prior_mmd` and `posterior_mmd` come from the "prep" cell.
save_path = 'results/gauss_noisedims_v1/'

seeds = np.arange(52,57)
sq_mmds_snpe_all = np.zeros((seeds.size, 40))
sq_mmds_snl_all = []
for i, seed in enumerate(seeds):

    exp_id = 'seed' + str(seed)
    sq_mmds_snpe = np.load(os.path.join(save_path, exp_id, 'all_mmds_N' + str(5000)+'.npy'))
    sq_mmds_snpe_all[i,:] = sq_mmds_snpe

    # SNL results are optional: report a seed that cannot be loaded instead of
    # silently ignoring it (and let KeyboardInterrupt through).
    try:
        sq_mmds_snl = np.load(os.path.join(save_path, exp_id, 'all_mmds_snl_N' + str(1000)+'.npy'))
        # drop entry 0, which the "prep" cell uses as the prior value
        sq_mmds_snl_all.append(sq_mmds_snl[1:])
    except Exception as err:
        print('could not load SNL MMDs for ' + exp_id + ': ' + repr(err))


pnl = plt.figure(figsize=(3,3))

# x-axes: 1000 * round index (the SNL arrays have one extra leading entry)
x_snl = 1000*np.arange(1,sq_mmds_snl.size)
x_apt = 1000*np.arange(1,sq_mmds_snpe.size+1)

# Draw one curve per method first, so that the two legend entries given to
# plt.legend below are matched to a red (SNL) and a black (APT) line.
plt.semilogx(x_snl, np.sqrt(sq_mmds_snl_all)[0,:], 'rd-', linewidth=1.5)
plt.semilogx(x_apt, np.sqrt(sq_mmds_snpe_all[0]), 'kd-', linewidth=1.5)

# All seeds (columns after the transpose are the individual seeds)
plt.semilogx(x_snl, np.sqrt(sq_mmds_snl_all).T, 'rd-', linewidth=1.5,
        label='SNL')

plt.semilogx(x_apt, np.sqrt(sq_mmds_snpe_all).T, 'kd-', linewidth=1.5,
        label='APT')

# Reference levels: prior (dashed) and posterior (dotted)
plt.semilogx(x_apt, np.sqrt(prior_mmd) *np.ones(sq_mmds_snpe.size),
            'k--', linewidth=0.8)

plt.semilogx(x_apt, np.sqrt(posterior_mmd) *np.ones(sq_mmds_snpe.size),
            'k:', linewidth=0.8)


plt.legend(['SNL', 'APT'], fontsize=12)
plt.ylabel('Maximum mean discrepancy', fontsize=12)
plt.ylim([0., 1.1])
plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1.], fontsize=12)
plt.xticks([1000, 10000], fontsize=12)
plt.xlabel('Number of simulations (log scale)', fontsize=12)
plt.title('d = 20', fontsize=13)

PANEL_2nA = fig_path +'fig2n_a.svg'
plt.savefig(PANEL_2nA, facecolor=plt.gcf().get_facecolor(), transparent=True)

pnl.show()

In [None]:
# Panel d = 40 ('v4'): square-root MMD per round for APT and SNL, one curve
# per seed. `prior_mmd` and `posterior_mmd` come from the "prep" cell.
save_path = 'results/gauss_noisedims_v4/'

seeds = np.arange(52,57)
sq_mmds_snpe_all = np.zeros((seeds.size, 40))
sq_mmds_snl_all = []


pnl = plt.figure(figsize=(3,3))
for i, seed in enumerate(seeds):

    exp_id = 'seed' + str(seed)
    sq_mmds_snpe = np.load(os.path.join(save_path, exp_id, 'all_mmds_N' + str(1000)+'.npy'))
    sq_mmds_snpe_all[i,:] = sq_mmds_snpe

    # SNL results are optional: report a seed that cannot be loaded instead of
    # silently ignoring it (and let KeyboardInterrupt through).
    try:
        sq_mmds_snl = np.load(os.path.join(save_path, exp_id, 'all_mmds_snl_N' + str(1000)+'.npy'))
        # drop entry 0, which the "prep" cell uses as the prior value
        sq_mmds_snl_all.append(sq_mmds_snl[1:])
    except Exception as err:
        print('could not load SNL MMDs for ' + exp_id + ': ' + repr(err))


# x-axes: 1000 * round index (the SNL arrays have one extra leading entry)
x_snl = 1000*np.arange(1,sq_mmds_snl.size)
x_apt = 1000*np.arange(1,sq_mmds_snpe.size+1)

plt.semilogx(x_snl, np.sqrt(sq_mmds_snl_all).T, 'rd-', linewidth=1.5,
        label='SNL')

plt.semilogx(x_apt, np.sqrt(sq_mmds_snpe_all).T, 'kd-', linewidth=1.5,
        label='APT')

# Reference levels: prior (dashed) and posterior (dotted)
plt.semilogx(x_apt, np.sqrt(prior_mmd) *np.ones(sq_mmds_snpe.size),
            'k--', linewidth=0.8)

plt.semilogx(x_apt, np.sqrt(posterior_mmd) *np.ones(sq_mmds_snpe.size),
            'k:', linewidth=0.8)

plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1.], fontsize=12)
plt.xticks([1000, 10000], fontsize=12)
plt.xlabel('Number of simulations (log scale)', fontsize=12)
plt.title('d = 40', fontsize=13)
plt.ylim([0., 1.1])

PANEL_2nB = fig_path +'fig2n_b.svg'
plt.savefig(PANEL_2nB, facecolor=plt.gcf().get_facecolor(), transparent=True)

pnl.show()

In [None]:
# Panel d = 60 ('v3'): square-root MMD per round for APT and SNL, one curve
# per seed. `prior_mmd` and `posterior_mmd` come from the "prep" cell.
save_path = 'results/gauss_noisedims_v3/'

seeds = np.arange(52,57)
sq_mmds_snpe_all = np.zeros((seeds.size, 40))
sq_mmds_snl_all = []
for i, seed in enumerate(seeds):

    exp_id = 'seed' + str(seed)
    sq_mmds_snpe = np.load(os.path.join(save_path, exp_id, 'all_mmds_N' + str(1000)+'.npy'))
    sq_mmds_snpe_all[i,:] = sq_mmds_snpe

    # SNL results are optional: report a seed that cannot be loaded instead of
    # silently ignoring it (and let KeyboardInterrupt through).
    try:
        sq_mmds_snl = np.load(os.path.join(save_path, exp_id, 'all_mmds_snl_N' + str(1000)+'.npy'))
        # drop entry 0, which the "prep" cell uses as the prior value
        sq_mmds_snl_all.append(sq_mmds_snl[1:])
    except Exception as err:
        print('could not load SNL MMDs for ' + exp_id + ': ' + repr(err))

pnl = plt.figure(figsize=(3,3))

# x-axes: 1000 * round index (the SNL arrays have one extra leading entry)
x_snl = 1000*np.arange(1,sq_mmds_snl.size)
x_apt = 1000*np.arange(1,sq_mmds_snpe.size+1)

plt.semilogx(x_snl, np.sqrt(sq_mmds_snl_all).T, 'rd-', linewidth=1.5,
        label='SNL')

plt.semilogx(x_apt, np.sqrt(sq_mmds_snpe_all).T, 'kd-', linewidth=1.5,
        label='APT')

# Reference levels: prior (dashed) and posterior (dotted)
plt.semilogx(x_apt, np.sqrt(prior_mmd) *np.ones(sq_mmds_snpe.size),
            'k--', linewidth=0.8)

plt.semilogx(x_apt, np.sqrt(posterior_mmd) *np.ones(sq_mmds_snpe.size),
            'k:', linewidth=0.8)

plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1.], fontsize=12)
plt.xticks([1000, 10000], fontsize=12)
plt.xlabel('Number of simulations (log scale)', fontsize=12)
plt.title('d = 60', fontsize=13)
plt.ylim([0., 1.1])
PANEL_2nC = fig_path +'fig2n_c.svg'
plt.savefig(PANEL_2nC, facecolor=plt.gcf().get_facecolor(), transparent=True)

pnl.show()

In [None]:
# Panel d = 100 ('v2'): square-root MMD per round for APT and SNL, one curve
# per seed. `prior_mmd` and `posterior_mmd` come from the "prep" cell.
save_path = 'results/gauss_noisedims_v2/'

seeds = np.arange(52,57)
sq_mmds_snpe_all = np.zeros((seeds.size, 40))
sq_mmds_snl_all = []
for i, seed in enumerate(seeds):

    exp_id = 'seed' + str(seed)
    sq_mmds_snpe = np.load(os.path.join(save_path, exp_id, 'all_mmds_N' + str(1000)+'.npy'))
    sq_mmds_snpe_all[i,:] = sq_mmds_snpe

    # SNL results are optional: report a seed that cannot be loaded instead of
    # silently ignoring it (and let KeyboardInterrupt through).
    try:
        sq_mmds_snl = np.load(os.path.join(save_path, exp_id, 'all_mmds_snl_N' + str(1000)+'.npy'))
        # drop entry 0, which the "prep" cell uses as the prior value
        sq_mmds_snl_all.append(sq_mmds_snl[1:])
    except Exception as err:
        print('could not load SNL MMDs for ' + exp_id + ': ' + repr(err))

pnl = plt.figure(figsize=(3,3))

# x-axes: 1000 * round index (the SNL arrays have one extra leading entry)
x_snl = 1000*np.arange(1,sq_mmds_snl.size)
x_apt = 1000*np.arange(1,sq_mmds_snpe.size+1)

plt.semilogx(x_snl, np.sqrt(sq_mmds_snl_all).T, 'rd-', linewidth=1.5,
        label='SNL')

plt.semilogx(x_apt, np.sqrt(sq_mmds_snpe_all).T, 'kd-', linewidth=1.5,
        label='APT')

# Reference levels: prior (dashed) and posterior (dotted)
plt.semilogx(x_apt, np.sqrt(prior_mmd) *np.ones(sq_mmds_snpe.size),
            'k--', linewidth=0.8)

plt.semilogx(x_apt, np.sqrt(posterior_mmd) *np.ones(sq_mmds_snpe.size),
            'k:', linewidth=0.8)

plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1.], fontsize=12)
plt.xticks([1000, 10000], fontsize=12)
plt.xlabel('Number of simulations (log scale)', fontsize=12)
plt.title('d = 100', fontsize=13)
plt.ylim([0., 1.1])

PANEL_2nD = fig_path +'fig2n_d.svg'
plt.savefig(PANEL_2nD, facecolor=plt.gcf().get_facecolor(), transparent=True)

pnl.show()

In [None]:
%run -i ./common.ipynb

In [None]:
# FIGURE and GRID
# Assemble figure 3 from the four noise-dimension panels (d = 20, 40, 60, 100)
# placed side by side, then export it as SVG and PDF.
# create_fig / add_svg / add_label / add_grid / svg as well as
# PATH_DROPBOX_FIGS and INKSCAPE are not defined in the visible cells of this
# notebook -- presumably they come from common.ipynb (run in the cell above);
# confirm.
FIG_HEIGHT_MM = 38
FIG_WIDTH_MM = 160  # set in NIPS2017 notebook to a default value for all figures

# NOTE(review): FIG_N_ROWS and the ROW_* values are computed here but are not
# referenced again in this cell.
FIG_N_ROWS = 2
ROW_1_NCOLS = 2
ROW_1_HEIGHT_MM =      1.2 * (FIG_HEIGHT_MM / FIG_N_ROWS )
ROW_1_WIDTH_COL_1_MM =  1.5* (FIG_WIDTH_MM / ROW_1_NCOLS)

ROW_2_NCOLS = 2
ROW_2_HEIGHT_MM = 0.8 * FIG_HEIGHT_MM / FIG_N_ROWS
ROW_2_WIDTH_COL_1_MM = FIG_WIDTH_MM / ROW_2_NCOLS
ROW_2_WIDTH_COL_2_MM = FIG_WIDTH_MM / ROW_2_NCOLS


fig = create_fig(FIG_WIDTH_MM, FIG_HEIGHT_MM)


# Panels are placed at a horizontal step of 40 (3, 43, 83, 123); the panel
# labels are currently disabled.
fig = add_svg(fig, PANEL_2nA, 
              3, 
              0)
#fig = add_label(fig, 
#                'a)', 
#                0, 
#                3)

fig = add_svg(fig, PANEL_2nB, 
              43, 
              0)
#fig = add_label(fig, 
#                'b)', 
#                40, 
#                3)

fig = add_svg(fig, PANEL_2nC, 
              83, 
              0)
#fig = add_label(fig, 
#                'c)', 
#                80, 
#                3)

fig = add_svg(fig, PANEL_2nD, 
              123,
              0) 
#fig = add_label(fig, 
#                'd)', 
#                120, 
#                3)

#fig = add_label(fig, 
#                'b, 
#                60, 
#                10)


# Disabled layout aid: flip the condition to overlay helper grids.
if False:
    fig = add_grid(fig, 2, 2)
    fig = add_grid(fig, 160/3, 10, font_size_px=0.0001)

# NOTE(review): the SVG path is built by plain concatenation while the PDF
# target below inserts a '/', so PATH_DROPBOX_FIGS presumably ends with a path
# separator -- confirm.
PATH_SVG = PATH_DROPBOX_FIGS + 'fig3.svg'
fig.save(PATH_SVG)
svg(PATH_SVG)
# Shell escape: convert the assembled SVG to PDF with Inkscape.
!$INKSCAPE --export-pdf $PATH_DROPBOX_FIGS/fig3.pdf $PATH_SVG


# extra panel on MMDs at given rounds (currently unused)

In [None]:

save_paths = ['results/gauss_validationset',
              'results/gauss_noisedims_v1/',
              'results/gauss_noisedims_v4/',
              'results/gauss_noisedims_v3/',
              'results/gauss_noisedims_v2/']
dims = np.array([12, 20,40,60,100])


plot_idx = np.array([29])


seeds = np.arange(52,57)

sq_mmds_snpe_all = np.zeros((len(save_paths), len(seeds), len(plot_idx)))
sq_mmds_snl_all = np.zeros((len(save_paths), len(seeds), len(plot_idx)))


save_path = save_paths[0]
for i in range(len(seeds)):
    exp_id = 'seed' + str(seed-10) # 42-52 
    sq_mmds_snpe = np.load(os.path.join(save_path, exp_id, 'all_mmds_N' + str(5000)+'.npy'))
    sq_mmds_snpe_all[0,i,:] = sq_mmds_snpe.flatten()[plot_idx]
    
    sq_mmds_snpe = np.load(os.path.join(save_path, exp_id, 'all_mmds_N' + str(5000)+'.npy'))
    sq_mmds_snpe_all[0,i,:] = sq_mmds_snpe.flatten()[plot_idx]


for j in range(1,len(save_paths)):
    for i in range(len(seeds)):
        
        save_path = save_paths[j]
    
        seed = seeds[i]
        
        exp_id = 'seed' + str(seed)
        
        N = 5000 if j<2 else 1000
        sq_mmds_snpe = np.load(os.path.join(save_path, exp_id, 'all_mmds_N' + str(N)+'.npy'))
        sq_mmds_snpe_all[j,i,:] = sq_mmds_snpe.flatten()[plot_idx]

        sq_mmds_snl = np.load(os.path.join(save_path, exp_id, 'all_mmds_snl_N' + str(1000)+'.npy'))
        sq_mmds_snl_all[j,i,:] = sq_mmds_snl.flatten()[plot_idx+1]

        
#plt.plot(1000*(plot_idx+1), np.sqrt(sq_mmds_snpe_all), 'k', label='APT')
#plt.plot(1000*(plot_idx+1), np.sqrt(sq_mmds_snl_all), 'r', label='SNL')
#plt.show()


plt.figure(figsize=(3,3))
for i in range(len(plot_idx)):
    plt.subplot(1,len(plot_idx),i+1)

    plt.plot(dims, np.sqrt(sq_mmds_snpe_all[:,0,i]), 'k.')
    plt.plot(dims+0.5, np.sqrt(sq_mmds_snl_all[:,0,i]), 'r.')
    
    plt.plot(dims, np.sqrt(sq_mmds_snpe_all[:,:,i]), 'k.')
    plt.plot(dims+0.5, np.sqrt(sq_mmds_snl_all[:,:,i]), 'r.')
    for k in range(len(dims)):
        plt.plot([dims[k]-3, dims[k]+3], np.ones(2)*np.mean(np.sqrt(sq_mmds_snpe_all[k,:,i])), 'k-')    
    for k in range(len(dims)):
        plt.plot([dims[k]-3, dims[k]+3], np.ones(2)*np.mean(np.sqrt(sq_mmds_snl_all[k,:,i])), 'r-')
    plt.title('round ' + str(plot_idx[i]+1))
    plt.ylim([0.01, 1.1])
    plt.xlim([17, 105])
    plt.ylabel('Maximum mean discrepancy')
    plt.xlabel('dimensionality d')    
    plt.xticks(dims[1:])
    
    plt.legend(['APT', 'SNL'], loc=1)
plt.show()

# mean +/- std panel (currently unused)

In [None]:
# Mean +/- one standard deviation (over seeds) of the square-root MMD per
# round, for APT and SNL on the 'v1' experiment.
save_path = 'results/gauss_noisedims_v1/'

seeds = np.arange(52, 62)
sq_mmds_snpe_all = np.zeros((seeds.size, 40))
sq_mmds_snl_all = []
for i, seed in enumerate(seeds):

    exp_id = 'seed' + str(seed)
    sq_mmds_snpe = np.load(os.path.join(save_path, exp_id, 'all_mmds_N' + str(5000)+'.npy'))
    sq_mmds_snpe_all[i,:] = sq_mmds_snpe

    # SNL results are optional: report a seed that cannot be loaded instead of
    # silently ignoring it (and let KeyboardInterrupt through).
    try:
        sq_mmds_snl = np.load(os.path.join(save_path, exp_id, 'all_mmds_snl_N' + str(1000)+'.npy'))
        # drop entry 0, which the "prep" cell uses as the prior value
        sq_mmds_snl_all.append(sq_mmds_snl[1:])
    except Exception as err:
        print('could not load SNL MMDs for ' + exp_id + ': ' + repr(err))


# Per-round statistics over seeds, computed once
root_mmds_snl = np.sqrt(sq_mmds_snl_all)
snl_mean = np.mean(root_mmds_snl, axis=0)
snl_std = np.std(root_mmds_snl, axis=0)

root_mmds_snpe = np.sqrt(sq_mmds_snpe_all)
snpe_mean = np.mean(root_mmds_snpe, axis=0)
snpe_std = np.std(root_mmds_snpe, axis=0)

# x-axes: 1000 * round index (the SNL arrays have one extra leading entry)
x_snl = 1000*np.arange(1,sq_mmds_snl.size)
x_apt = 1000*np.arange(1,sq_mmds_snpe.size+1)

plt.figure(figsize=(4,4))

plt.semilogx(x_snl, snl_mean, 'rd-', linewidth=1.5,
        label='SNL')

plt.fill_between(x_snl, snl_mean - snl_std, snl_mean + snl_std,
                 color='r', alpha=0.2)

plt.semilogx(x_apt, snpe_mean, 'kd-', linewidth=1.5,
        label='APT')

plt.fill_between(x_apt, snpe_mean - snpe_std, snpe_mean + snpe_std,
                 color='k', alpha=0.5)

plt.legend()
plt.ylabel('Maximum mean discrepancy')
plt.xlabel('Number of simulations (log scale)')

PANEL_2D = fig_path +'fig2_d.svg'
plt.savefig(PANEL_2D, facecolor=plt.gcf().get_facecolor(), transparent=True)

plt.show()

# Sanity-check Gallery

In [None]:
# Sanity check: marginal plots of 1000 samples from a stored posterior for a
# few seeds of the 'v1' experiment.
from lfimodels.snl_exps.util import draw_sample_uniform_prior_33 as fast_sampler
save_path = 'results/gauss_noisedims_v1//'

for plot_seed in range(43, 47):
    exp_id = 'seed{}'.format(plot_seed)
    log, tds, posteriors, _ = load_results(exp_id=exp_id, path=save_path)

    # NOTE(review): this reads posteriors[1], whereas the panel (a) cell uses
    # posteriors[-1] -- confirm that this entry is the intended one.
    gallery_samples = fast_sampler(posteriors[1], 1000)
    pnl_a = plot_hist_marginals(gallery_samples, lims=[-5, 5])
    pnl_a.set_figwidth(12)
    pnl_a.set_figheight(12)

    # hide all axis decorations
    for ax in pnl_a.axes:
        ax.axis('off')

    pnl_a.show()

In [None]:
# Sanity check: for every seed of the 'v1' experiment, try to load the stored
# SNL models and plot the marginals of the last entry of the stored `ps` array.
import pickle
import os

save_path = 'results/gauss_noisedims_v1/'

for plot_seed in range(43,62):

    print('seed ' + str(plot_seed))
    exp_id = 'seed' + str(plot_seed)

    # NOTE(security): pickle.load can execute arbitrary code while loading;
    # only use it on result files from a trusted source.
    # Loading the models is best-effort: a failure is reported and the plot
    # below is still produced. Catching Exception (instead of a bare except)
    # keeps KeyboardInterrupt working.
    try:
        with open(os.path.join(save_path, exp_id, 'SNL_MAF.pkl'), 'rb') as f:
            learned_model = pickle.load(f)

        with open(os.path.join(save_path, exp_id, 'SNL_posteriors.pkl'), 'rb') as f:
            all_models = pickle.load(f)
    except Exception:
        print('failed to load MAFs')

    xs = np.load(os.path.join(save_path, exp_id, 'xs.npy'))
    ps = np.load(os.path.join(save_path, exp_id, 'ps.npy'))

    pnl_a = plot_hist_marginals(ps[-1], lims=[-5,5])
    pnl_a.set_figwidth(12)
    pnl_a.set_figheight(12)

    # hide all axis decorations
    for ax in pnl_a.axes:
        ax.axis('off')

    pnl_a.show()