# Gaussian model

- Simulator taken from https://github.com/mackelab/SNL_py3port, a Python 3 port of the original https://github.com/gpapamak/snl produced with 2to3 and minimal edits (the generator-internal normalization of summary statistics is deactivated).
- WIP

In [None]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
import timeit

import snl
import snl.simulators.gaussian as sim

import delfi.generator as dg
import delfi.distribution as dd
from delfi.utils.viz import plot_pdf
import delfi.inference as infer

seed = 42  # base random seed used by the cells below

# ---------------- SNPE parameters ----------------

# training schedule: simulations per round, number of rounds
n_train, n_rounds = 1000, 10

# fitting setup: minibatch size, (base) number of training epochs
minibatch, epochs = 100, 500

# network setup: hidden-layer widths, regularization strength
n_hiddens = [50, 50]
reg_lambda = 0.01

# convenience
pilot_samples = 1000
svi, verbose, prior_norm = False, True, False

# ---------------- SNPE-C parameters ----------------
n_null = 10

# ---------------- MAF parameters ----------------
mode = 'random'      # ordering of variables for the MADEs
n_mades = 5          # number of MADEs
act_fun = 'tanh'     # activation function
batch_norm = False   # batch normalization is currently not supported
train_on_all = True  # now a supported feature


In [None]:
from delfi.simulator import BaseSimulator


def init_g(seed):
    """Build a delfi generator for the SNL Gaussian model.

    Parameters
    ----------
    seed : int
        Base seed; the generator itself is seeded with ``seed + 41``.

    Returns
    -------
    dg.Default
        Generator combining a uniform prior on [-3, 3] in each of the 5
        parameter dimensions, the SNL Gaussian simulator, and the SNL
        summary statistics.
    """
    # params = (m1, m2, std1, std2, arctanh(corr_scaling)), see gen_single
    n_params = 5
    prior = dd.Uniform(lower=[-3] * n_params, upper=[3] * n_params)

    # model
    model_snl = sim.Model()

    class Gaussian(BaseSimulator):
        """Thin delfi wrapper around the SNL Gaussian simulator.

        Parameters
        ----------
        dim_param : int
            Number of dimensions of parameters
        seed : int or None
            If set, randomness is seeded
        """

        def gen_single(self, params):
            """ params = (m1, m2, std1, std2, arctanh(corr_scaling))

            NOTE(review): no seeded rng is passed to model_snl.sim, so its
            randomness presumably comes from numpy's global RNG -- confirm.
            """
            return model_snl.sim(params)

    # dim_param must match the prior dimension. It was 4, which is
    # inconsistent with the 5-dim prior and the 5 parameters of gen_single.
    model = Gaussian(dim_param=n_params)

    # summary statistics
    summary = sim.Stats()

    # generator
    g = dg.Default(prior=prior, model=model, summary=summary, seed=seed + 41)
    return g


In [None]:
# setup
g = init_g(seed=seed)

# Ground truth (true parameters and observed summary statistics). The file
# is assumed to hold a dict pickled into a 0-d object array, hence
# allow_pickle=True (mandatory on numpy >= 1.16.3) and .item().
# Only load trusted files this way: unpickling can execute arbitrary code.
gt_path = '/home/marcel/Desktop/Projects/Biophysicality/code/snl_snpec/gt_gauss.npy'
gt = np.load(gt_path, encoding='latin1', allow_pickle=True).item()
pars_true = np.array(gt['true_ps'])
obs_stats = np.array(gt['obs_xs']).reshape(1, -1)  # a single observation as one row

pars_true, obs_stats

In [None]:

# Per-round epoch schedule: with train_on_all, round r (0-based) trains for
# epochs // (r + 1) epochs. A separate name is used so that `epochs` stays a
# scalar and re-running this cell does not try to integer-divide a list.
if train_on_all:
    epochs_schedule = [epochs // (r + 1) for r in range(n_rounds)]
else:
    epochs_schedule = epochs

# control MAF seed. Note this seeds numpy's *global* RNG, so anything else
# drawing from np.random (presumably including the simulator) is seeded too;
# switching to a private RandomState would change that.
rng = np.random
rng.seed(seed)

# generator
g = init_g(seed=seed)

# inference object
res_C = infer.SNPEC(g,
                    obs=obs_stats,
                    n_hiddens=n_hiddens,
                    seed=seed,
                    reg_lambda=reg_lambda,
                    pilot_samples=pilot_samples,
                    svi=svi,
                    n_mades=n_mades,  # providing this argument triggers usage of MAFs (vs. MDNs)
                    act_fun=act_fun,
                    mode=mode,
                    rng=rng,
                    batch_norm=batch_norm,
                    verbose=verbose,
                    prior_norm=prior_norm)

# train (wall-clock timing via the documented timeit.default_timer)
t = timeit.default_timer()

logs_C, tds_C, posteriors_C = res_C.run(
    n_train=n_train,
    proposal='discrete',
    moo='resample',
    n_null=n_null,
    n_rounds=n_rounds,
    train_on_all=train_on_all,
    minibatch=minibatch,
    epochs=epochs_schedule)

print(timeit.default_timer() - t)


In [None]:
# Training loss of each round, one figure per round. Iterates over the logs
# actually returned, so a run with fewer rounds does not raise IndexError.
for r, log in enumerate(logs_C):
    plt.plot(log['loss'])
    plt.title('SNPE-C training loss, round r = ' + str(r + 1))
    plt.ylabel('loss')
    plt.show()

In [None]:
# Posterior estimate after each round. Samples are drawn through a fresh
# generator with the round's posterior set as proposal; pairwise marginals
# are then plotted against the ground truth.
n_dim = pars_true.size
for r, posterior_C in enumerate(posteriors_C):

    # literal 42 kept, so evaluation draws do not change if `seed` is changed
    g = init_g(seed=42)
    g.proposal = posterior_C
    samples = np.array(g.draw_params(5000))

    # plot_pdf requires a pdf argument; a near-degenerate Gaussian (tiny
    # covariance) is passed as a dummy, presumably so that essentially only
    # the samples show up in the plot.
    dummy_pdf = dd.Gaussian(m=0.00000123 * np.ones(n_dim), S=1e-30 * np.eye(n_dim))
    fig, _ = plot_pdf(dummy_pdf,
                      samples=samples.T,
                      gt=pars_true,
                      lims=[[-3, 3]] * n_dim,  # prior support of init_g in every dimension
                      resolution=100,
                      ticks=True,
                      figsize=(16, 16))

    fig.suptitle('SNPE-C posterior estimates, round r = ' + str(r + 1), fontsize=14)
    print('negative log-probability of ground-truth pars \n',
          -posterior_C.eval(pars_true, log=True))

# Marginals over summary statistics (plus best-fitting Gaussian approximation)

In [None]:
# Summary statistics of the first round's training data (tds_C[0][1],
# presumably the stats array) together with a moment-matched Gaussian.
stats = tds_C[0][1]
fig, _ = plot_pdf(dd.Gaussian(m=stats.mean(axis=0), S=np.cov(stats.T)),
                  samples=stats.T,
                  # observed stats z-scored with res_C's stored mean/std,
                  # presumably the moments used to normalize training stats
                  gt=((obs_stats - res_C.stats_mean) / res_C.stats_std).flatten(),
                  ticks=True,
                  resolution=100,
                  figsize=(16, 16))
fig.suptitle('(pair-wise) marginal(s) over summary statistics from Gaussian model (already z-scored!)')
# Figure.show() only warns on the non-GUI inline backend; pyplot.show()
# displays the figure.
plt.show()

# Results evaluation

In [None]:
import snl.pdfs as pdfs

# sampling settings for the error evaluation; thin and burnin are not
# referenced elsewhere in the visible notebook
thin, n_mcmc_samples, burnin = 10, 5000, 100

def calc_err(true_ps, samples, weights=None):
    """
    Calculates error (neg log prob of truth) for a set of possibly weighted samples.

    A Gaussian KDE (snl.pdfs.gaussian_kde) is fit to ``samples`` with
    bandwidth ``n ** (-1 / (d + 4))``, where n is the number of samples and
    d = len(true_ps). The negated KDE evaluation at ``true_ps`` is returned.

    Parameters
    ----------
    true_ps : array-like
        Ground-truth parameter vector (length d).
    samples : array-like
        Samples, assumed one per row, i.e. shape (n, d) -- TODO confirm
        against snl.pdfs.gaussian_kde.
    weights : array-like or None
        Optional sample weights, passed through to the KDE.

    Returns
    -------
    Negative of ``gaussian_kde(...).eval(true_ps)`` (a negative log-density,
    per the summary line above).

    Raises
    ------
    ValueError
        If ``samples`` is empty (the bandwidth would be undefined).
    """
    n_samples = len(samples)
    if n_samples == 0:
        raise ValueError('calc_err needs at least one sample')

    # Bandwidth from the actual sample count. It previously used the global
    # n_mcmc_samples, which is only right when exactly that many are passed.
    std = n_samples ** (-1.0 / (len(true_ps) + 4))

    return -pdfs.gaussian_kde(samples, weights, std).eval(true_ps)

# Error (negative log-prob of the true parameters) after each round. Samples
# are drawn through a fresh generator with that round's posterior set as
# proposal. Rounds 1..R-1 are the proposals; the last one is the posterior.
all_errs = []
for posterior in posteriors_C:
    # literal 42 kept, so evaluation draws do not change if `seed` is changed
    g = init_g(seed=42)
    g.proposal = posterior
    samples = np.array(g.draw_params(n_mcmc_samples))
    all_errs.append(calc_err(pars_true, samples))

# keep the names used by the comparison cell below
all_prop_errs, post_err = all_errs[:-1], all_errs[-1]

# cumulative number of simulations after each round
n_sims = np.arange(1, len(all_errs) + 1) * n_train

plt.figure(figsize=(8, 5))
plt.semilogx(n_sims, np.array(all_errs), 'kd:')
plt.axis([600, 22000, 0, 5])
plt.xlabel('Number of simulations (log scale)')
plt.ylabel('- log probability of true parameters')
plt.savefig('/home/marcel/Desktop/gauss_snpec_maf_n_null_10__v0.pdf')
plt.show()

In [None]:
# Same error, but with samples drawn directly from each posterior via .gen()
# instead of through the generator's draw_params. The previous cell's curve
# is labelled 'rej. sampling' -- presumably draw_params rejects samples
# outside the prior support; confirm in delfi.
all_errs_raw = [calc_err(pars_true, proposal.gen(n_mcmc_samples))
                for proposal in posteriors_C]
all_prop_errs_raw, post_err_raw = all_errs_raw[:-1], all_errs_raw[-1]

# cumulative number of simulations after each round
n_sims = np.arange(1, len(all_errs_raw) + 1) * n_train

plt.figure(figsize=(8, 5))
plt.semilogx(n_sims, np.array(all_prop_errs + [post_err]), 'bd:')
plt.semilogx(n_sims, np.array(all_errs_raw), 'kd:')
plt.legend(['rej. sampling', 'naive sampling'])
plt.axis([600, 22000, 0, 10])
plt.xlabel('Number of simulations (log scale)')
plt.ylabel('- log probability of true parameters')
plt.title('effects of truncation on MAF')
plt.savefig('/home/marcel/Desktop/Gauss_snpec_maf_n_null_10_N5000_MAF_truncation.pdf')
plt.show()