In [None]:
%%capture
%matplotlib inline

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import timeit
import os

import delfi.distribution as dd
import delfi.inference as infer
import delfi.generator as dg

from delfi.simulator import TwoMoons
import delfi.summarystats as ds
from delfi.utils.viz import plot_pdf, probs2contours

from lfimodels.snl_exps.util import save_results, load_results
from lfimodels.snl_exps.util import draw_sample_uniform_prior_52, load_gt_lv, init_g_lv
from lfimodels.snl_exps.util import draw_sample_uniform_prior_33, load_gt_gauss, init_g_gauss

from snl.util.plot import plot_hist_marginals

from inspect import getmembers, isclass

# Directory that receives the individual panel SVGs, and the font size used
# for the theta axis labels in every panel below.
fig_path, fontsize = 'results/figs/', 16

def rasterize_and_save(fname, rasterize_list=None, fig=None, dpi=None,
                       savefig_kw=None):
    """Save a figure with raster and vector components.

    This function lets you specify which objects to rasterize at the export
    stage, rather than within each plotting call. Rasterizing certain
    components of a complex figure can significantly reduce file size.

    Parameters
    ----------
    fname : str
        Output filename with extension.
    rasterize_list : list (or object), optional
        List of objects to rasterize (or a single object to rasterize).
        Contour sets, lists of patches, and any object exposing ``axes``,
        ``set_rasterized`` and ``set_zorder`` (artists, Axes) are handled.
    fig : matplotlib figure object, optional
        Defaults to the current figure.
    dpi : int, optional
        Resolution (dots per inch) for rasterizing.
    savefig_kw : dict, optional
        Extra keywords to pass to ``fig.savefig``. The dict is copied before
        ``dpi`` is added, so the caller's dict is never modified.

    Notes
    -----
    If rasterize_list is not specified, then all contour, pcolor, and
    collection objects (e.g., ``scatter``, ``fill_between`` etc) will be
    rasterized.

    Does not work correctly with round=True in Basemap.

    Examples
    --------
    Rasterize the contour, pcolor, and scatter plots, but not the line

    >>> import matplotlib.pyplot as plt
    >>> from numpy.random import random
    >>> X, Y, Z = random((9, 9)), random((9, 9)), random((9, 9))
    >>> fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(ncols=2, nrows=2)
    >>> cax1 = ax1.contourf(Z)
    >>> cax2 = ax2.scatter(X, Y, s=Z)
    >>> cax3 = ax3.pcolormesh(Z)
    >>> cax4 = ax4.plot(Z[:, 0])
    >>> rasterize_list = [cax1, cax2, cax3]
    >>> rasterize_and_save('out.svg', rasterize_list, fig=fig, dpi=300)
    """
    # Behave like pyplot and act on current figure if no figure is specified
    fig = plt.gcf() if fig is None else fig

    # Private copy: adding 'dpi' below must not mutate the caller's dict or
    # leak into later calls through a shared default.
    savefig_kw = {} if savefig_kw is None else dict(savefig_kw)

    # Need to set_rasterization_zorder in order for rasterizing to work
    zorder = -5  # Somewhat arbitrary, just ensuring less than 0

    if rasterize_list is None:
        # Have a guess at stuff that should be rasterised
        types_to_raster = ['QuadMesh', 'Contour', 'collections']
        rasterize_list = []

        print("""
        No rasterize_list specified, so the following objects will
        be rasterized: """)
        # Get all axes, and then get objects within axes
        for ax in fig.get_axes():
            for item in ax.get_children():
                if any(x in str(item) for x in types_to_raster):
                    rasterize_list.append(item)
        print('\n'.join([str(x) for x in rasterize_list]))
    elif not isinstance(rasterize_list, list):
        # Allow rasterize_list to be input as an object to rasterize
        rasterize_list = [rasterize_list]

    if rasterize_list:
        # Every class reachable from matplotlib.patches, used to recognise a
        # list of patches; built once rather than once per item.
        all_patch_types = tuple(
            cls for _, cls in getmembers(matplotlib.patches, isclass))
        for item in rasterize_list:
            _rasterize_item(item, zorder, all_patch_types)

    # dpi is a savefig keyword argument, but treat it as special since it is
    # important to this function
    if dpi is not None:
        savefig_kw['dpi'] = dpi

    # Save resulting figure
    fig.savefig(fname, **savefig_kw)


def _rasterize_item(item, zorder, patch_types):
    """Mark one rasterize_list entry as rasterized and push it below zorder."""
    # Whether or not plot is a contour plot is important
    is_contour = isinstance(item, (matplotlib.contour.QuadContourSet,
                                   matplotlib.tri.TriContourSet))

    # LineCollections are deliberately not special-cased: we seldom want to
    # rasterize lines.

    # Whether or not current item is a list of patches
    try:
        is_patch_list = isinstance(item[0], patch_types)
    except TypeError:
        is_patch_list = False

    # Convert to rasterized mode and then change zorder properties
    if is_contour:
        curr_ax = item.ax.axes
        curr_ax.set_rasterization_zorder(zorder)
        # For contour plots, need to set each part of the contour
        # collection individually.
        # NOTE(review): ContourSet.collections is deprecated since
        # matplotlib 3.8; this branch targets older matplotlib releases.
        for contour_level in item.collections:
            contour_level.set_zorder(zorder - 1)
            contour_level.set_rasterized(True)
    elif is_patch_list:
        # For list of patches, need to set zorder for each patch
        for patch in item:
            curr_ax = patch.axes
            curr_ax.set_rasterization_zorder(zorder)
            patch.set_zorder(zorder - 1)
            patch.set_rasterized(True)
    else:
        # For all other objects, we can just do it all at once
        curr_ax = item.axes
        curr_ax.set_rasterization_zorder(zorder)
        item.set_rasterized(True)
        item.set_zorder(zorder - 1)

    
def plot_hist_marginals(data, weights=None, lims=None, gt=None, upper=False, rasterized=False):
    """
    Plots marginal histograms and pairwise scatter plots of a dataset.

    This local definition shadows the ``plot_hist_marginals`` imported from
    ``snl.util.plot`` at the top of the notebook.

    Parameters
    ----------
    data : array-like, shape (n_samples,) or (n_samples, n_dim)
        Samples to plot. 1-D input produces a single histogram.
    weights : array-like, shape (n_samples,), optional
        Per-sample weights. Used as histogram weights and, in the pairwise
        panels, as marker colours on the 'binary' colormap over
        [0, max(weights)].
    lims : array-like, optional
        Axis limits. A single (low, high) pair is shared by all dimensions;
        otherwise one pair per dimension.
    gt : array-like, optional
        Ground-truth values, drawn as red vertical lines on the histograms
        and red dots on the pairwise panels.
    upper : bool
        If True fill the upper triangle of the grid (column >= row) instead
        of the lower triangle (column <= row).
    rasterized : bool
        Forwarded to ``hist`` and ``scatter`` for the sample artists.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Axes are created row by row, left to right, so ``fig.axes`` follows
        that order (callers index into it to add labels/ticks).
    """

    data = np.asarray(data)
    # Square-root rule for the number of histogram bins.
    n_bins = int(np.sqrt(data.shape[0]))

    if data.ndim == 1:

        fig, ax = plt.subplots(1, 1)
        # ``density`` replaces ``normed`` (deprecated in matplotlib 2.1,
        # removed in 3.1).
        ax.hist(data, weights=weights, bins=n_bins, density=True, rasterized=rasterized)
        ax.set_ylim([0.0, ax.get_ylim()[1]])
        ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)
        if lims is not None: ax.set_xlim(lims)
        if gt is not None: ax.vlines(gt, 0, ax.get_ylim()[1], color='r')

    else:

        n_dim = data.shape[1]
        fig = plt.figure()

        if weights is None:
            col = 'k'
            vmin, vmax = None, None
        else:
            col = weights
            vmin, vmax = 0., np.max(weights)

        if lims is not None:
            lims = np.asarray(lims)
            # A single (low, high) pair is broadcast to every dimension.
            lims = np.tile(lims, [n_dim, 1]) if lims.ndim == 1 else lims

        for i in range(n_dim):
            for j in range(i, n_dim) if upper else range(i + 1):

                ax = fig.add_subplot(n_dim, n_dim, i * n_dim + j + 1)

                if i == j:
                    # Diagonal: marginal histogram of dimension i.
                    ax.hist(data[:, i], weights=weights, bins=n_bins, density=True, rasterized=rasterized)
                    ax.set_ylim([0.0, ax.get_ylim()[1]])
                    ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)
                    if i < n_dim - 1 and not upper: ax.tick_params(axis='x', which='both', labelbottom=False)
                    if lims is not None: ax.set_xlim(lims[i])
                    if gt is not None: ax.vlines(gt[i], 0, ax.get_ylim()[1], color='r')

                else:
                    # Off-diagonal: dimension j (x) against dimension i (y).
                    ax.scatter(data[:, j], data[:, i], c=col, s=3, marker='o', vmin=vmin, vmax=vmax, cmap='binary', edgecolors='none', rasterized=rasterized)
                    if i < n_dim - 1: ax.tick_params(axis='x', which='both', labelbottom=False)
                    if j > 0: ax.tick_params(axis='y', which='both', labelleft=False)
                    if j == n_dim - 1: ax.tick_params(axis='y', which='both', labelright=True)
                    if lims is not None:
                        ax.set_xlim(lims[j])
                        ax.set_ylim(lims[i])
                    if gt is not None: ax.scatter(gt[j], gt[i], c='r', s=20, marker='o', edgecolors='none')

    return fig
    
    

In [None]:
# Execute the shared notebook inside this notebook's namespace (-i). Later
# cells use names that are not defined anywhere in this notebook
# (create_fig, add_svg, add_label, svg, PATH_DROPBOX_FIGS, $INKSCAPE);
# presumably they are provided by common.ipynb -- TODO confirm.
%run -i ./common.ipynb

# Gauss

In [None]:
save_path = 'results/gauss_validationset/'

# Validation-set run for a single seed; only its final posterior is plotted.
plot_seed = 42
exp_id = 'seed{}'.format(plot_seed)
log, tds, posteriors, _ = load_results(exp_id=exp_id, path=save_path)

# Lower-triangular pairplot of the last 1000 posterior draws.
pnl_a = plot_hist_marginals(
    draw_sample_uniform_prior_33(posteriors[-1], 1000)[-1000:],
    lims=[-5, 5],
    rasterized=False,
)
pnl_a.set_figwidth(8)
pnl_a.set_figheight(8)

# Strip every tick; the panels are only annotated with parameter names.
for ax in pnl_a.axes:
    ax.set_xticks([])
    ax.set_yticks([])

# In the lower-triangular 5x5 layout the last five axes are the bottom row,
# one per parameter.
for dim, ax in enumerate(pnl_a.axes[-5:], start=1):
    ax.set_xlabel(r'$\theta_{}$'.format(dim), fontsize=fontsize)

PANEL_S1A = fig_path + 'figS1_a.svg'
plt.savefig(PANEL_S1A, facecolor=plt.gcf().get_facecolor(), transparent=True,
            bbox_inches='tight')

pnl_a.show()

In [None]:
save_path = '../../snl/data/results/seed_42/gauss/'

# Final SNL posterior samples stored by the snl repository.
samples = np.load(save_path + 'final_snl_post_samples_N5000.npy')

# Upper-triangular pairplot of the last 1000 samples.
pnl_b = plot_hist_marginals(samples[-1000:], lims=[-5, 5], rasterized=False,
                            upper=True)
pnl_b.set_figwidth(8)
pnl_b.set_figheight(8)
for ax in pnl_b.axes:
    ax.set_xticks([])
    ax.set_yticks([])

# Upper-triangular 5x5 layout: the diagonal (histogram) panels are axes
# 0, 5, 9, 12 and 14 (== -1), one per parameter.
for dim, idx in enumerate([0, 5, 9, 12, -1], start=1):
    pnl_b.axes[idx].set_xlabel(r'$\theta_{}$'.format(dim), fontsize=fontsize)

PANEL_S1B = fig_path + 'figS1_b.svg'
plt.savefig(PANEL_S1B, facecolor=plt.gcf().get_facecolor(), transparent=True,
            bbox_inches='tight')
pnl_b.show()



In [None]:

# Assemble supplementary figure S1 (Gaussian example) from the two panel
# SVGs written above, save it as SVG, then convert it to PDF with Inkscape.
# create_fig / add_svg / add_label / svg / PATH_DROPBOX_FIGS are not defined
# in this notebook; presumably they come from common.ipynb -- TODO confirm.

# FIGURE and GRID
FIG_HEIGHT_MM = 80
FIG_WIDTH_MM = 160  # originally a shared default from the NIPS2017 notebook; each figure cell here sets its own width

# NOTE(review): FIG_N_ROWS and the ROW_* values below are not read by any
# code in this cell.
FIG_N_ROWS = 2
ROW_1_NCOLS = 2
ROW_1_HEIGHT_MM =      1.2 * (FIG_HEIGHT_MM / FIG_N_ROWS )
ROW_1_WIDTH_COL_1_MM =  1.5* (FIG_WIDTH_MM / ROW_1_NCOLS)

ROW_2_NCOLS = 2
ROW_2_HEIGHT_MM = 0.8 * FIG_HEIGHT_MM / FIG_N_ROWS
ROW_2_WIDTH_COL_1_MM = FIG_WIDTH_MM / ROW_2_NCOLS
ROW_2_WIDTH_COL_2_MM = FIG_WIDTH_MM / ROW_2_NCOLS


fig = create_fig(FIG_WIDTH_MM, FIG_HEIGHT_MM)


# Panel a and its label; the two numeric arguments are presumably the x/y
# placement offsets (likely mm) -- TODO confirm against common.ipynb.
fig = add_svg(fig, PANEL_S1A, 
              3, 
              0)
fig = add_label(fig, 
                'a)', 
                0, 
                5)
# Panel b, placed to the right of panel a.
fig = add_svg(fig, PANEL_S1B, 
              40,
              0) 
fig = add_label(fig, 
                'b)', 
                37, 
                5)

PATH_SVG = PATH_DROPBOX_FIGS + 'figS1v2.svg'
fig.save(PATH_SVG)
svg(PATH_SVG)
# IPython shell escape: $NAME is expanded from the notebook namespace.
# NOTE(review): --export-pdf is the Inkscape 0.9x command-line syntax;
# Inkscape 1.x replaced it with --export-filename / --export-type.
!$INKSCAPE --export-pdf $PATH_DROPBOX_FIGS/figS1v2.pdf $PATH_SVG


# LV

In [None]:
# Per-parameter (low, high) axis limits for the four LV parameters.
lims = [
    [-5, -4.4],
    [-1.0, -0.4],
    [-0.5, 0.2],
    [-5, -4.35],
]

In [None]:
save_path = 'results/lv_validationset/'

# Validation-set run for one seed; the same seed selects the ground truth.
plot_seed = 43
exp_id = 'seed{}'.format(plot_seed)
log, tds, posteriors, _ = load_results(exp_id=exp_id, path=save_path)

pars_true, obs_stats = load_gt_lv(generator=init_g_lv(seed=plot_seed))
print('pars_true : ', pars_true)

# Lower-triangular pairplot of the final posterior; ground truth drawn in red.
samples = draw_sample_uniform_prior_52(posteriors[-1], 5000)
pnl_a = plot_hist_marginals(samples, gt=pars_true.flatten(), lims=lims,
                            rasterized=False)
pnl_a.set_figwidth(8)
pnl_a.set_figheight(8)
for ax in pnl_a.axes:
    ax.set_xticks([])
    ax.set_yticks([])

# The last four axes (6..9) form the bottom row: name each parameter and
# restore two reference ticks per panel.
for dim, (ax, ticks) in enumerate(
        zip(pnl_a.axes[-4:],
            ([-4.9, -4.5], [-0.9, -0.5], [-0.4, 0.1], [-4.9, -4.5])),
        start=1):
    ax.set_xlabel(r'$\theta_{}$'.format(dim), fontsize=fontsize)
    ax.set_xticks(ticks)

PANEL_S2A = fig_path + 'figS2_a.svg'
savefig_kw = dict(facecolor=plt.gcf().get_facecolor(), transparent=True,
                  bbox_inches='tight')

# Only the off-diagonal (scatter) axes 1, 3, 4, 6, 7, 8 are passed for
# rasterization.
rasterize_and_save(PANEL_S2A,
                   rasterize_list=[pnl_a.axes[i] for i in [1, 3, 4, 6, 7, 8]],
                   fig=pnl_a, dpi=600, savefig_kw=savefig_kw)

pnl_a.show()



In [None]:
save_path = '../../snl/data/results/seed_42/lotka_volterra/'

# Final SNL posterior samples stored by the snl repository.
samples = np.load(save_path + 'final_snl_post_samples_N5000.npy')

# NOTE(review): pars_true comes from the previous cell (plot_seed = 43),
# while these samples come from the seed_42 SNL run -- confirm both refer
# to the same observation before trusting the red ground-truth markers.
pnl_b = plot_hist_marginals(samples, gt=pars_true.flatten(), lims=lims,
                            rasterized=False, upper=True)
pnl_b.set_figwidth(8)
pnl_b.set_figheight(8)
for ax in pnl_b.axes:
    ax.set_xticks([])
    ax.set_yticks([])

# Upper-triangular 4x4 layout: the diagonal (histogram) panels are axes
# 0, 4, 7 and 9 (== -1). Name each parameter and restore two ticks.
for dim, (idx, ticks) in enumerate(
        [(0, [-4.8, -4.2]), (4, [-1, -0.6]), (7, [-0.4, 0]), (-1, [-4.9, -4.6])],
        start=1):
    pnl_b.axes[idx].set_xlabel(r'$\theta_{}$'.format(dim), fontsize=fontsize)
    pnl_b.axes[idx].set_xticks(ticks)

PANEL_S2B = fig_path + 'figS2_b.svg'
savefig_kw = dict(facecolor=plt.gcf().get_facecolor(), transparent=True,
                  bbox_inches='tight')

# Only the off-diagonal (scatter) axes 1, 2, 3, 5, 6, 8 are passed for
# rasterization.
rasterize_and_save(PANEL_S2B,
                   rasterize_list=[pnl_b.axes[i] for i in [1, 2, 3, 5, 6, 8]],
                   fig=pnl_b, dpi=600, savefig_kw=savefig_kw)

pnl_b.show()



In [None]:

# Assemble supplementary figure S2 (Lotka-Volterra example) from the two
# panel SVGs written above, save it as SVG, then convert it to PDF with
# Inkscape. The layout helpers are presumably defined in common.ipynb --
# TODO confirm.

# FIGURE and GRID
FIG_HEIGHT_MM = 82
FIG_WIDTH_MM = 119  # originally a shared default from the NIPS2017 notebook; each figure cell here sets its own width

# NOTE(review): FIG_N_ROWS and the ROW_* values below are not read by any
# code in this cell.
FIG_N_ROWS = 2
ROW_1_NCOLS = 2
ROW_1_HEIGHT_MM =      1.2 * (FIG_HEIGHT_MM / FIG_N_ROWS )
ROW_1_WIDTH_COL_1_MM =  1.5* (FIG_WIDTH_MM / ROW_1_NCOLS)

ROW_2_NCOLS = 2
ROW_2_HEIGHT_MM = 0.8 * FIG_HEIGHT_MM / FIG_N_ROWS
ROW_2_WIDTH_COL_1_MM = FIG_WIDTH_MM / ROW_2_NCOLS
ROW_2_WIDTH_COL_2_MM = FIG_WIDTH_MM / ROW_2_NCOLS


fig = create_fig(FIG_WIDTH_MM, FIG_HEIGHT_MM)


# Panel a and its label; numeric arguments presumably x/y placement offsets
# (likely mm) -- TODO confirm against common.ipynb.
fig = add_svg(fig, PANEL_S2A, 
              3, 
              0)
fig = add_label(fig, 
                'a)', 
                0, 
                5)
# Panel b, placed to the right of panel a.
fig = add_svg(fig, PANEL_S2B, 
              40,
              0) 
fig = add_label(fig, 
                'b)', 
                37, 
                5)

PATH_SVG = PATH_DROPBOX_FIGS + 'figS2.svg'
fig.save(PATH_SVG)
svg(PATH_SVG)
# IPython shell escape: $NAME is expanded from the notebook namespace.
# NOTE(review): --export-pdf is Inkscape 0.9x syntax (1.x uses
# --export-filename / --export-type).
!$INKSCAPE --export-pdf $PATH_DROPBOX_FIGS/figS2.pdf $PATH_SVG


# M/G/1 figure

In [None]:
# Assemble the M/G/1 supplementary figure from two pre-rendered SVGs.
# NOTE(review): no cell in this notebook writes these two SVG files;
# presumably they are produced elsewhere -- TODO confirm.
PANEL_4A = fig_path +'mg1_lprobs_validationset.svg'
PANEL_4B = fig_path +'mg1_dists_validationset.svg'

# FIGURE and GRID
FIG_HEIGHT_MM = 50
FIG_WIDTH_MM = 110  # originally a shared default from the NIPS2017 notebook; each figure cell here sets its own width

# NOTE(review): FIG_N_ROWS and the ROW_* values below are not read by any
# code in this cell.
FIG_N_ROWS = 2
ROW_1_NCOLS = 2
ROW_1_HEIGHT_MM =      1.2 * (FIG_HEIGHT_MM / FIG_N_ROWS )
ROW_1_WIDTH_COL_1_MM =  1.5* (FIG_WIDTH_MM / ROW_1_NCOLS)

ROW_2_NCOLS = 2
ROW_2_HEIGHT_MM = 0.8 * FIG_HEIGHT_MM / FIG_N_ROWS
ROW_2_WIDTH_COL_1_MM = FIG_WIDTH_MM / ROW_2_NCOLS
ROW_2_WIDTH_COL_2_MM = FIG_WIDTH_MM / ROW_2_NCOLS


fig = create_fig(FIG_WIDTH_MM, FIG_HEIGHT_MM)


# Two unlabelled panels side by side; numeric arguments presumably x/y
# placement offsets -- TODO confirm against common.ipynb.
fig = add_svg(fig, PANEL_4A, 
              5, 
              0)

fig = add_svg(fig, PANEL_4B, 
              60,
              0) 

# NOTE(review): the two-moons MMD cell further down also writes
# figS3.pdf, so whichever cell runs last determines the exported PDF.
PATH_SVG = PATH_DROPBOX_FIGS + 'figS3.svg'
fig.save(PATH_SVG)
svg(PATH_SVG)
# IPython shell escape: $NAME is expanded from the notebook namespace.
!$INKSCAPE --export-pdf $PATH_DROPBOX_FIGS/figS3.pdf $PATH_SVG


In [None]:
# Two-moons benchmark: final-round MMD per algorithm and seed, shown as a bar
# (mean over seeds) plus one jittered dot per seed. The stored values are
# treated as squared MMDs and square-rooted before plotting.
algs = ['MAF_SNL', '_MDN_SNPEA', '_MDN_SNPEB', '_continuous_MDN_SNPEC', '_discrete_MAF_SNPEC']
seeds = range(42, 52)


def _load_final_sq_mmd(exp_dir, prefix, idx=-1):
    """Return entry ``idx`` of ``<exp_dir>/<prefix>N5000.npy``.

    Falls back to ``<prefix>N1000.npy`` when the N5000 file cannot be opened.
    Only I/O errors trigger the fallback; any other error propagates instead
    of being silently swallowed.
    """
    try:
        values = np.load(os.path.join(exp_dir, prefix + 'N5000.npy'))
    except IOError:
        values = np.load(os.path.join(exp_dir, prefix + 'N1000.npy'))
    return values[idx]


model_id = 'two_moons_runs/'
n_algs, n_seeds = len(algs), len(seeds)
sq_mmds = np.zeros((n_algs, n_seeds))
for k, alg in enumerate(algs):
    # SNL results live directly under the model directory and use their own
    # file prefix; every other algorithm has a sub-directory named after it.
    save_path = 'results/' + model_id + ('' if k == 0 else alg)
    prefix = 'all_mmds_snl_' if k == 0 else 'all_mmds_'
    for i, seed in enumerate(seeds):
        exp_dir = os.path.join(save_path, 'seed' + str(seed))
        # For seeds 42 and 51, the first round is the final round for SNPE-A.
        idx = 0 if (alg == '_MDN_SNPEA' and seed in (42, 51)) else -1
        sq_mmds[k, i] = _load_final_sq_mmd(exp_dir, prefix, idx)

mmds = np.sqrt(sq_mmds).T  # rows: seeds, columns: algorithms
mmds[mmds == np.inf] = np.nan

print(mmds)
clrs = ['r', 'c', 'g', 'm', 'k']
x = np.arange(n_algs)

plt.figure(figsize=(6, 8))
# White bars outlined in each algorithm's colour: mean over seeds, NaNs ignored.
plt.bar(x, np.nanmean(mmds, axis=0), edgecolor=clrs, color='w')
plt.ylabel('Maximum mean discrepancy', fontsize=13)
plt.yticks([0, 0.2, 0.4], fontsize=12)
plt.xticks(x, ['SNL', 'SNPE-A', 'SNPE-B', 'MDN-APT', 'MAF-APT'], fontsize=12)
for k in range(n_algs):
    # Individual seeds, jittered horizontally so overlapping dots stay visible.
    jitter = np.random.normal(size=n_seeds) * 0.05
    plt.plot(k * np.ones(n_seeds) + jitter, mmds[:, k], 'o', color=clrs[k], markersize=5)

# Style the axes the bars were drawn into (rather than re-requesting a 1x1 subplot).
ax = plt.gca()
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)

# NOTE(review): the M/G/1 composite cell above also exports figS3.pdf;
# whichever cell runs last wins -- confirm which figure should be S3.
plt.savefig(PATH_DROPBOX_FIGS + 'figS3.pdf', bbox_inches='tight')
plt.show()

# GLM figure

In [None]:
# version 1 (quick for rebuttal with small networks: n_hiddens=[10,10], n_mades=2, n_components=1)

import numpy as np
import matplotlib.pyplot as plt

plt.figure(figsize=(4,4))

seeds = range(42, 52)
n = len(seeds) # number of seeds

# (key in the saved results dict, line format, colour, legend label).
# The list order is also the plotting and legend order.
glm_algs = [
    ('all_mmds_A', '>:', 'c', 'SNPE-A'),
    ('all_mmds_B', 'p:', 'g', 'SNPE-B'),
    ('all_mmds_L', 'o:', 'r', 'SNL'),
    ('all_mmds_C', 'd-', 'k', 'APT'),
]

# One row per seed, one column per round (5 rounds).
all_mmds = {key: np.zeros((n, 5)) for key, _, _, _ in glm_algs}
for i, seed in enumerate(seeds):
    # The file presumably holds a dict saved with np.save (recovered with
    # ``[()]``), i.e. a pickled object array: NumPy >= 1.16.3 refuses to load
    # it unless allow_pickle=True. Only use on trusted, locally produced files.
    tmp = np.load('results/glm/seed' + str(seed) + '/all_mmds.npy',
                  allow_pickle=True)[()]
    for key in all_mmds:
        all_mmds[key][i, :] = tmp[key]

x = np.arange(5)
for key, fmt, color, label in glm_algs:
    mmd = np.sqrt(all_mmds[key].T)  # rounds x seeds
    m, s = np.mean(mmd, axis=1), np.std(mmd, axis=1)
    plt.plot(m, fmt, color=color, label=label)
    # Shaded band: mean +/- standard error over seeds (np.std uses ddof=0).
    plt.fill_between(x, m - s / np.sqrt(n), m + s / np.sqrt(n), color=color, alpha=0.3)

plt.ylabel('Maximum Mean Discrepancy', fontsize=12)
plt.xlabel('Number of simulations', fontsize=12)
plt.legend(fontsize=12)

plt.yticks([0.5,0.3,0.1], fontsize=12)
plt.xticks([0,2,4], [5000, 10000, 25000], fontsize=12)
plt.axis([-0.1, 4.1, 0, 0.55])

plt.savefig(PATH_DROPBOX_FIGS + 'figS4.pdf', bbox_inches='tight')
plt.show()


In [None]:
# version 2 (larger networks: n_hiddens=[50,50], n_mades=5, n_components=8)

import numpy as np
import matplotlib.pyplot as plt

plt.figure(figsize=(9,6))

seeds = range(42, 52) #[42,43,45,46,48,49,50,51] # seeds done up to rebuttal deadline
n = len(seeds) # number of seeds

# (key in the saved results dict, line format, colour, legend label).
# The list order is also the plotting and legend order.
glm_algs = [
    ('all_mmds_A', '>:', 'c', 'SNPE-A'),
    ('all_mmds_B', 'p:', 'g', 'SNPE-B'),
    ('all_mmds_L', 'o:', 'r', 'SNL'),
    ('all_mmds_C', 'd-', 'k', 'APT'),
]

# Load every seed once; both panels below reuse the same arrays.
# One row per seed, one column per round. SNPE-B runs may hold fewer than
# 5 rounds, so its missing rounds stay NaN.
all_mmds = {key: np.zeros((n, 5)) for key, _, _, _ in glm_algs}
all_mmds['all_mmds_B'] = np.full((n, 5), np.nan)
for i, seed in enumerate(seeds):
    # Presumably a dict saved with np.save (pickled object array); NumPy >=
    # 1.16.3 needs allow_pickle=True. Only use on trusted, local result files.
    tmp = np.load('results/glm/seed' + str(seed) + '/all_mmds_v2.npy',
                  allow_pickle=True)[()]
    for key in all_mmds:
        if key == 'all_mmds_B':
            all_mmds[key][i, :len(tmp[key])] = tmp[key]
        else:
            all_mmds[key][i, :] = tmp[key]

x = np.arange(5)
for plot_indiv in [1, 0]:

    # Left panel: every seed individually; right panel: mean +/- s.e.m.
    plt.subplot(1, 2, 2 - plot_indiv)

    if plot_indiv:
        for key, fmt, color, _ in glm_algs:
            plt.plot(np.sqrt(all_mmds[key].T), fmt, color=color)
    else:
        for key, fmt, color, label in glm_algs:
            mmd = np.sqrt(all_mmds[key].T)  # rounds x seeds
            m, s = np.mean(mmd, axis=1), np.std(mmd, axis=1)
            plt.plot(m, fmt, color=color, label=label)
            # Standard error over seeds (np.std uses ddof=0).
            plt.fill_between(x, m - s / np.sqrt(n), m + s / np.sqrt(n), color=color, alpha=0.3)

    plt.xlabel('N')
    if plot_indiv:
        plt.ylabel('maximum mean discrepancy')
    else:
        plt.ylabel('average MMD')
        plt.legend()

    plt.xticks([0,2,4], [5000, 10000, 25000])
    plt.title('GLM (10-dim, 5 rounds with N=5k each)')
    plt.axis([-0.1, 4.1, 0, 0.65])
plt.show()
