# M/G/1 model

- The simulator is taken from https://github.com/mackelab/SNL_py3port, which contains the original code from https://github.com/gpapamak/snl after a 2to3 conversion with minimal edits (the generator-internal normalization of summary statistics is deactivated).


In [None]:
%matplotlib inline
import timeit

import numpy as np
from matplotlib import pyplot as plt

import delfi.distribution as dd
import delfi.inference as infer
from delfi.utils.viz import plot_pdf

from lfimodels.snl_exps.util import init_g_mg1 as init_g

# Random seed shared by the generator and the inference object below.
seed = 42

# ---------------------------------------------------------------------------
# SNPE parameters
# ---------------------------------------------------------------------------

# Training schedule: `n_train` is handed to run() for each of `n_rounds` rounds.
n_train, n_rounds = 1000, 20

# Fitting setup, both forwarded to run().
minibatch, epochs = 100, 500

# Network setup (presumably two hidden layers of 50 units each and the
# strength of the weight regularizer -- confirm against infer.SNPEC).
n_hiddens = [50, 50]
reg_lambda = 0.01

# Convenience switches forwarded unchanged to infer.SNPEC.
pilot_samples = 1000
svi = False
verbose = True
prior_norm = False

# ---------------------------------------------------------------------------
# SNPE-C parameters
# ---------------------------------------------------------------------------
n_null = 10

# ---------------------------------------------------------------------------
# MAF parameters
# ---------------------------------------------------------------------------
mode = 'random'      # ordering of variables for MADEs
n_mades = 5          # number of MADEs
act_fun = 'tanh'
batch_norm = False   # batch-normalization currently not supported
train_on_all = True  # now supported feature

# simulation setup: build the M/G/1 generator with the shared seed
g = init_g(seed=seed)

# locations of the stored ground truth and of the whitening parameters
gt_file = '/home/marcel/Desktop/Projects/Biophysicality/code/snl_snpec/gt_mg1.npy'
whiten_file = '/home/marcel/Desktop/Projects/Biophysicality/code/snl_snpec/whiten_params_mg1.npy'

try:
    # load ground-truth xo and true parameter values theta* from disk.
    # Each file is indexed with `[()]` and then used like a dict, i.e. it
    # presumably stores a dict wrapped in a 0-d object array. Such files are
    # pickled, and recent numpy versions refuse to read them unless
    # allow_pickle=True is passed -- only do this for files you trust.
    gt = np.load(gt_file, encoding='latin1', allow_pickle=True)[()]
    whiten_params = np.load(whiten_file, encoding='latin1', allow_pickle=True)[()]
    pars_true, obs_stats = np.array(gt['true_ps']), np.array(gt['obs_xs']).reshape(1, -1)

    # un-whiten xo with retrieved whitening params (SNPE-C will apply its own
    # z-scoring, but xo needs to match the x_n)
    obs_stats = (obs_stats.flatten() / whiten_params['istds']).dot(whiten_params['U'].T) + whiten_params['means']
    obs_stats = obs_stats.reshape(1, -1)

except Exception as err:
    # Deliberate best-effort fallback: if anything goes wrong while loading or
    # un-whitening, sample a fresh xo instead. `Exception` (rather than a bare
    # `except:`) lets KeyboardInterrupt / SystemExit propagate, and the cause
    # is reported below instead of being silently dropped.
    pars_true = np.array([1, 5, 0.2])  # taken from SNL paper
    obs = g.model.gen_single(pars_true)  # simulate one data set at pars_true
    obs_stats = g.summary.calc([obs])    # summary statistics of that data set

    print('\n WARNING: could not load ground-truth data and parameters from disk! \n Sampling xo instead !')
    print(' reason : ', repr(err))

print('pars_true : ', pars_true)
print('obs_stats : ', obs_stats)

# fit SNPE-C

In [None]:

# When training on all data, turn the scalar epoch count into a per-round
# schedule: round r (1-based) gets epochs // r epochs.
# The scalar check keeps this cell re-runnable: once `epochs` has become a
# list, a second execution would otherwise fail on `list // int`.
if train_on_all and np.isscalar(epochs):
    epochs = [epochs // (r + 1) for r in range(n_rounds)]

# control MAF seed (note: this seeds numpy's global random state)
rng = np.random
rng.seed(seed)

# generator
g = init_g(seed=seed)

# inference object
res_C = infer.SNPEC(g,
                 obs=obs_stats,
                 n_hiddens=n_hiddens,
                 seed=seed,
                 reg_lambda=reg_lambda,
                 pilot_samples=pilot_samples,
                 svi=svi,
                 n_mades=n_mades, # providing this argument triggers usage of MAFs (vs. MDNs)
                 act_fun=act_fun,
                 mode=mode,
                 rng=rng,
                 batch_norm=batch_norm,
                 verbose=verbose,
                 prior_norm=prior_norm)

# train, measuring the elapsed time with timeit's documented default timer
t = timeit.default_timer()

logs_C, tds_C, posteriors_C = res_C.run(
                    n_train=n_train,
                    proposal='discrete',
                    moo='resample',
                    n_null = n_null,
                    n_rounds=n_rounds,
                    train_on_all=train_on_all,
                    minibatch=minibatch,
                    epochs=epochs)

print('fitting time : ', timeit.default_timer() - t)


# inspect results

In [None]:
# Training-loss curve of every round, one figure per round. Iterating over
# the logs themselves (instead of range(n_rounds)) stays correct even if the
# run returned a different number of rounds than requested.
for log in logs_C:
    plt.plot(log['loss'])
    plt.show()

# Posterior estimate after each round.
for r, posterior_C in enumerate(posteriors_C):

    # fresh generator (same seed as everywhere else in the notebook) whose
    # proposal is replaced by this round's posterior, so that draw_params
    # yields samples from that posterior
    g = init_g(seed=seed)
    g.proposal = posterior_C
    samples = np.array(g.draw_params(1000))

    # NOTE(review): the near-degenerate Gaussian (tiny mean, covariance
    # 1e-30 * I) looks like a placeholder pdf so that effectively only the
    # samples are shown -- confirm against plot_pdf.
    fig, _ = plot_pdf(dd.Gaussian(m=0.00000123*np.ones(pars_true.size), S=1e-30*np.eye(pars_true.size)),
                      samples=samples.T,
                      gt=pars_true,
                      lims=[[0, 10], [0, 20], [0., 1./3.]],
                      resolution=100,
                      ticks=True,
                      figsize=(16, 16))

    fig.suptitle('SNPE-C posterior estimates, round r = '+str(r+1), fontsize=14)
    print('negative log-probability of ground-truth pars \n', -posterior_C.eval(pars_true, log=True))

# marginal over summary statistics (plus best-fitting Gaussian approx.)
- Note that plot_pdf automatically chooses the axis limits from the range of the provided samples, so the limits are determined by the outliers.
- Hence, large empty regions indicate that the simulator sometimes produces extreme outliers, which may negatively affect the de-facto standard z-scoring done by SNPE and SNL.

In [None]:
# Second entry of the first round's training data (presumably the summary
# statistics, already z-scored as the figure title states -- confirm).
stats = tds_C[0][1]

# Gaussian with the empirical mean and covariance of those statistics.
gauss_fit = dd.Gaussian(m=stats.mean(axis=0), S=np.cov(stats.T))

# Observed statistics, standardized with the inference object's mean / std
# so that they are on the same scale as `stats`.
xo_standardized = ((obs_stats - res_C.stats_mean) / res_C.stats_std).flatten()

fig, _ = plot_pdf(gauss_fit,
                  samples=stats.T,
                  gt=xo_standardized,
                  ticks=True,
                  resolution=100,
                  figsize=(16, 16))
fig.suptitle('(pair-wise) marginal(s) over summary statistics from Gaussian model (already z-scored!)')
#fig.savefig('/home/marcel/Desktop/mg1_summary_stats_marginals.pdf')
fig.show()

# results evaluation
- copy-paste from SNL code

In [None]:
# NOTE(review): calc_all_lprob_errs, post_err and post_err_raw are not defined
# anywhere in the visible notebook -- they presumably have to be copied over
# from the SNL code before this cell can run. Confirm.

# Arguments shared by both evaluations; only the `rej` flag differs.
shared_eval_args = dict(n_samples=5000,
                        posteriors=posteriors_C,
                        init_g=init_g)

# rej=True is labelled 'rej. sampling' in the legend below,
# rej=False is labelled 'naive sampling'.
all_prop_errs = calc_all_lprob_errs(pars_true, rej=True, **shared_eval_args)
all_prop_errs_raw = calc_all_lprob_errs(pars_true, rej=False, **shared_eval_args)

# x-axis: cumulative number of simulations after each round
n_sims = np.arange(1, n_rounds+1) * n_train

# NOTE(review): each y series has one more entry than its all_prop_errs* list;
# verify that this matches the n_rounds points on the x-axis.
plt.figure(figsize=(8,5))
plt.semilogx(n_sims, np.array(all_prop_errs + [post_err]), 'bd:')
plt.semilogx(n_sims, np.array(all_prop_errs_raw + [post_err_raw]), 'kd:')
plt.legend(['rej. sampling', 'naive sampling'])
plt.axis([600, 22000, -1, 2])
plt.xlabel('Number of simulations (log scale)')
plt.ylabel('- log probability of true parameters')
plt.title('effects of truncation on MAF')
#plt.savefig('/home/marcel/Desktop/mg1_snpec_maf_n_null_10_N5000_MAF_truncation.pdf')
plt.show()