# M/G/1 model

- The simulator is taken from https://github.com/mackelab/SNL_py3port, a Python 3 port of the original https://github.com/gpapamak/snl (converted with 2to3, with minimal edits: the generator-internal normalization of summary statistics is deactivated).


In [None]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
import timeit

import snl
import snl.simulators.mg1 as sim

import delfi.generator as dg
import delfi.distribution as dd
from delfi.utils.viz import plot_pdf
import delfi.inference as infer

seed = 42

# ---- SNPE parameters ----

# training schedule: simulations per round, number of rounds
n_train = 1000
n_rounds = 20

# network fitting: minibatch size and (initial) number of epochs
minibatch = 100
epochs = 500

# network architecture and weight regularization
n_hiddens = [50, 50]
reg_lambda = 0.01

# convenience options
pilot_samples = 1000
svi = False
verbose = True
prior_norm = False

# ---- SNPE-C parameters ----
n_null = 10

# ---- MAF parameters ----
mode = 'random'       # ordering of the input variables within the MADEs
n_mades = 5           # number of stacked MADEs
act_fun = 'tanh'      # activation function of the MADE hidden layers
batch_norm = False    # batch normalization is currently not supported
train_on_all = True   # forwarded to SNPEC.run; also switches on the decaying epoch schedule


# implement prior in DELFI
- Papamakarios' prior for this experiment is uniform within a non-rectangular region (theta2 - theta1, rather than theta2 itself, is uniformly distributed).
- Since the MAF has, in principle, infinite support, samples drawn from the proposals of later rounds can fall outside the prior support and must be rejected.
- DELFI implements this rejection in the generator object, but only for uniform priors with rectangular support. We therefore inherit from dd.Uniform and slightly adapt the generator.


In [None]:
class ShiftedUniform(dd.Uniform):
    """Shifted uniform prior from the SNL paper (M/G/1 model).

    Samples are drawn from a rectangular uniform box and the second
    coordinate is then shifted by the first one:

        theta1          ~ Unif[lower[0], upper[0]]   (paper: [0, 10])
        theta2 - theta1 ~ Unif[lower[1], upper[1]]   (paper: [0, 10])
        theta3          ~ Unif[lower[2], upper[2]]   (paper: [0, 1/3])

    The resulting support in (theta1, theta2, theta3) is a sheared,
    non-rectangular region.
    """

    def __init__(self, lower=0., upper=1., seed=None):
        """
        Parameters
        ----------
        lower : list, or np.array, 1d
            Lower bounds of the *unshifted* box (three entries).
        upper : list, or np.array, 1d
            Upper bounds of the *unshifted* box (three entries).
        seed : int or None
            If provided, random number generator will be seeded

        Raises
        ------
        ValueError
            If the bounds do not describe a 3-dimensional distribution.
            NOTE(review): the scalar defaults presumably yield a 1-d
            distribution and therefore always fail this check.
        """
        super().__init__(lower=lower, upper=upper, seed=seed)
        # Explicit check instead of `assert`, which is stripped under -O.
        if self.ndim != 3:
            raise ValueError(
                'ShiftedUniform requires 3-dimensional bounds, got ndim={}'
                .format(self.ndim))

    def gen(self, n_samples=1):
        """Draw samples; returns an array with one sample per row."""
        params = super().gen(n_samples=n_samples)
        # theta2 = theta1 + (theta2 - theta1)
        params[:, 1] += params[:, 0]
        return params

    def eval(self, x, ii=None, log=True):
        """Evaluate the (log-)density at x.

        Parameters
        ----------
        x : array-like
            Parameters in the shifted space: a single 3-vector or an array
            with one 3-vector per row.
        ii, log
            Passed through to dd.Uniform.eval.

        The shift is undone on a copy and the unshifted points are evaluated
        under the rectangular uniform. The map (a, b, c) -> (a, b - a, c) has
        unit Jacobian, so no density correction is needed.
        """
        # Float copy: never modifies the caller's array, and also accepts
        # lists and single (1-d) parameter vectors.
        x = np.array(x, dtype=float, ndmin=2)
        x[:, 1] -= x[:, 0]
        return super().eval(x=x, ii=ii, log=log)
        
# Sanity checks of the prior: joint scatter of (theta1, theta2), histogram of
# the shifted coordinate theta2 - theta1, and log-densities of the samples.
prior = ShiftedUniform(lower=[ 0, 0,  0  ],
                       upper=[10,10,1./3.])
params = prior.gen(10000)

plt.plot(params[:,0], params[:,1], '.')
plt.show()

# `normed` was removed from plt.hist in Matplotlib 3.1; `density` replaces it.
plt.hist(params[:,1] - params[:,0], density=True)
plt.show()

plt.plot(prior.eval(params, log=True))
plt.show()

print('isinstance(prior, dd.Uniform)', isinstance(prior, dd.Uniform))

### Override the generator to properly reject proposal draws outside the prior support

In [None]:
from delfi.generator.Default import Default


class GenMG1(Default):
    """Generator for the M/G/1 model that rejects proposals outside the prior.

    delfi's Default generator only rejects out-of-support proposals for
    rectangular dd.Uniform priors. The ShiftedUniform prior has a sheared
    support (theta2 - theta1 is bounded, not theta2), so proposed parameters
    are un-shifted before being compared against the box bounds.
    """

    def _feedback_proposed_param(self, param):
        """Decide whether a proposed parameter vector lies in the prior support.

        Parameters
        ----------
        param : np.array
            One proposed parameter vector (theta1, theta2, theta3), given in
            any shape holding exactly three values, e.g. (3,) or (1, 3).

        Returns
        -------
        str
            'resample' if the vector falls outside the prior support,
            'accept' otherwise.
        """
        # Flatten so that both (3,) and (1, 3) inputs are handled.
        param_ = np.array(param, dtype=float).reshape(-1)
        assert param_.size == 3
        assert isinstance(self.prior, ShiftedUniform)

        # Undo theta2 = theta1 + u so the box bounds apply component-wise.
        param_[1] -= param_[0]

        if np.any(param_ < self.prior.lower) or \
           np.any(param_ > self.prior.upper):
            return 'resample'

        return 'accept'


### generator construction with controlled seeds

In [None]:
from delfi.simulator import BaseSimulator


def init_g(seed):
    """Build the delfi generator for the M/G/1 experiment.

    Parameters
    ----------
    seed : int or None
        Seed for the prior and the simulator; the generator itself receives
        ``seed + 41``. If None, None is passed to every component (the
        previous ``seed + 41`` raised TypeError in that case).

    Returns
    -------
    GenMG1
        Generator combining the ShiftedUniform prior, the SNL M/G/1
        simulator and the SNL summary statistics.
    """
    # prior
    # sim.Prior() is deliberately not used: the proposal would happily sample
    # outside the prior support unless the prior is a dd.Uniform.
    prior = ShiftedUniform(lower=[ 0, 0,  0  ],
                           upper=[10,10,1./3.],
                           seed=seed)

    # model: wrap the SNL simulator as a delfi simulator
    model_snl = sim.Model()

    class MG1(BaseSimulator):
        """M/G/1 simulator

        Parameters
        ----------
        dim_param : int
            Number of dimensions of parameters
        seed : int or None
            If set, randomness is seeded
        """

        def gen_single(self, params):
            """Simulate one data set with the SNL simulator.

            params = (lower bound of server processing time,
                      upper bound of server processing time,
                      rate of customer arrivals)
            """
            return model_snl.sim(params, rng=self.rng)

    model = MG1(dim_param=3, seed=seed)

    # summary statistics
    summary = sim.Stats()

    # generator; its seed is offset from the others (presumably so that its
    # random stream differs from the prior's and the model's)
    gen_seed = None if seed is None else seed + 41
    g = GenMG1(prior=prior, model=model, summary=summary, seed=gen_seed)
    return g

In [None]:
g = init_g(seed=seed)

#pars_true = np.array([1, 5, 0.2])  # taken from SNL paper
#obs = g.model.gen_single(pars_true)  # should also recover
#obs_stats = g.summary.calc([obs])    # xo from SNL paper !

# load ground-truth xo and true parameter values theta* from disk
import os
data_dir = '/home/marcel/Desktop/Projects/Biophysicality/code/snl_snpec'
# The files store dicts inside 0-d object arrays (hence the `[()]` unwrap),
# i.e. pickled data: numpy >= 1.16.3 refuses to load them unless
# allow_pickle=True. Only load files from a trusted source.
gt = np.load(os.path.join(data_dir, 'gt_mg1.npy'),
             encoding='latin1', allow_pickle=True)[()]
whiten_params = np.load(os.path.join(data_dir, 'whiten_params_mg1.npy'),
                        encoding='latin1', allow_pickle=True)[()]
pars_true, obs_stats = np.array(gt['true_ps']), np.array(gt['obs_xs']).reshape(1,-1)

# un-whiten xo with retrieved whitening params (SNPE-C will apply its own z-scoring, but xo needs to match the x_n)
obs_stats = (obs_stats.flatten() / whiten_params['istds']).dot(whiten_params['U'].T) + whiten_params['means']
obs_stats = obs_stats.reshape(1,-1)

pars_true, obs_stats

# fit SNPE-C

In [None]:

# Decaying epoch schedule when training on all data: entry r (0-based) is
# epochs // (r + 1). The scalar check keeps this cell re-runnable; without it
# a second run would evaluate `list // int` and raise TypeError.
if train_on_all and np.isscalar(epochs):
    epochs = [epochs//(r+1) for r in range(n_rounds)]

# control MAF seed: the np.random module itself is passed as `rng`, so this
# seeds numpy's global random state
rng = np.random
rng.seed(seed)

# generator
g = init_g(seed=seed)

# inference object
res_C = infer.SNPEC(g,
                 obs=obs_stats,
                 n_hiddens=n_hiddens,
                 seed=seed,
                 reg_lambda=reg_lambda,
                 pilot_samples=pilot_samples,
                 svi=svi,
                 n_mades=n_mades, # providing this argument triggers usage of MAFs (vs. MDNs)
                 act_fun=act_fun,
                 mode=mode,
                 rng=rng,
                 batch_norm=batch_norm,
                 verbose=verbose,
                 prior_norm=prior_norm)

# train, measuring wall-clock time (use the time module directly rather than
# reaching it through timeit's internals)
import time
t = time.time()

logs_C, tds_C, posteriors_C = res_C.run(
                    n_train=n_train,
                    proposal='discrete',
                    moo='resample',
                    n_null = n_null,
                    n_rounds=n_rounds,
                    train_on_all=train_on_all,
                    minibatch=minibatch,
                    epochs=epochs)

print(time.time() - t)


## training losses across rounds

In [None]:
for round_idx in range(n_rounds):
    # one figure per round showing that round's training-loss trace
    round_loss = logs_C[round_idx]['loss']
    plt.plot(round_loss)
    plt.show()

# posterior estimates across rounds

In [None]:
# Placeholder density for plot_pdf: a numerically point-mass Gaussian,
# presumably so that only the drawn samples are visually informative.
# It does not depend on the round, so it is built once.
dummy_pdf = dd.Gaussian(m=0.00000123*np.ones(pars_true.size), S=1e-30*np.eye(pars_true.size))

for r in range(len(logs_C)):

    posterior_C = posteriors_C[r]

    # Draw through a fresh generator whose proposal is this round's estimate,
    # so draws outside the prior support are rejected. Uses the configured
    # `seed` rather than a hard-coded 42.
    g = init_g(seed=seed)
    g.proposal = posterior_C
    samples = np.array(g.draw_params(1000))

    # axis limits cover the prior support; theta2 = theta1 + (theta2 - theta1) <= 20
    fig,_ = plot_pdf(dummy_pdf,
                     samples=samples.T,
                     gt=pars_true,
                     lims=[[0,10],[0,20],[0., 1./3.]],
                     resolution=100,
                     ticks=True,
                     figsize=(16,16))

    fig.suptitle('SNPE-C posterior estimates, round r = '+str(r+1), fontsize=14)
    print('negative log-probability of ground-truth pars \n', -posterior_C.eval(pars_true, log=True))

# marginal over summary statistics (plus best-fitting Gaussian approx.)
- Note that plot_pdf chooses the axis limits automatically from the provided samples, i.e. the limits are determined by the outliers.
- Hence, large empty regions indicate that the simulator occasionally produces extreme outliers, which may adversely affect the z-scoring that SNPE and SNL use as a de-facto standard.

In [None]:
# tds_C[0][1]: presumably the summary statistics of the first round's
# training set (already z-scored, per the figure title)
stats = tds_C[0][1]
gauss_fit = dd.Gaussian(m=stats.mean(axis=0), S=np.cov(stats.T))
# observed statistics, standardized with the same mean/std as the training data
obs_z = ((obs_stats - res_C.stats_mean) / res_C.stats_std).flatten()

fig, _ = plot_pdf(gauss_fit,
                  samples=stats.T,
                  gt=obs_z,
                  ticks=True,
                  resolution=100,
                  figsize=(16, 16))
fig.suptitle('(pair-wise) marginal(s) over summary statistics from Gaussian model (already z-scored!)')
#fig.savefig('/home/marcel/Desktop/mg1_summary_stats_marginals.pdf')
fig.show()

# results evaluation
- adapted from the SNL evaluation code

In [None]:
import snl.pdfs as pdfs

# for mcmc
thin = 10
n_mcmc_samples = 5000
burnin = 100
# NOTE(review): thin and burnin are not used in this notebook's visible cells.

def calc_err(true_ps, samples, weights=None):
    """
    Calculates error (neg log prob of truth) for a set of possibly weighted samples.

    Parameters
    ----------
    true_ps : array-like, 1d
        Ground-truth parameter vector.
    samples : array-like, 2d
        One sample per row.
    weights : array-like or None
        Optional sample weights, passed on to pdfs.gaussian_kde.

    Returns
    -------
    The negated value of ``pdfs.gaussian_kde(...).eval(true_ps)``.
    """
    # Scott-type kernel bandwidth n ** (-1 / (d + 4)). n is the number of
    # samples actually passed in (not the global n_mcmc_samples), so the
    # bandwidth stays right when calc_err is called with another sample count.
    n_samples = np.shape(samples)[0]
    std = n_samples ** (-1.0 / (len(true_ps) + 4))

    return -pdfs.gaussian_kde(samples, weights, std).eval(true_ps)

# KDE bandwidth for n_mcmc_samples draws (same formula as used by calc_err)
std = n_mcmc_samples ** (-1.0 / (len(pars_true) + 4))


def _draw_within_prior(proposal, n_samples):
    """Draw n_samples parameter vectors from `proposal` through a freshly
    seeded GenMG1 generator, which rejects draws outside the prior support."""
    gen = init_g(seed=seed)
    gen.proposal = proposal
    return np.array(gen.draw_params(n_samples))


# errors of the intermediate proposals (all rounds but the last)
all_prop_errs = [calc_err(pars_true, _draw_within_prior(proposal, n_mcmc_samples))
                 for proposal in posteriors_C[:-1]]

# error of the final posterior estimate
samples = _draw_within_prior(posteriors_C[-1], n_mcmc_samples)
post_err = calc_err(pars_true, samples)

n_sims = np.arange(1, n_rounds+1) * n_train
plt.figure(figsize=(8,5))
plt.semilogx(n_sims, np.array(all_prop_errs + [post_err]), 'kd:')
# x-range follows the schedule; equals [600, 22000] for n_train=1000, n_rounds=20
plt.axis([0.6*n_train, 1.1*n_rounds*n_train, -1, 5])
plt.xlabel('Number of simulations (log scale)')
plt.ylabel('- log probability of true parameters')
plt.savefig('/home/marcel/Desktop/mg1_snpec_maf_n_null_10__v0.pdf')
plt.show()

In [None]:
# Same error curve, but sampling each round's estimate directly, i.e. without
# rejecting draws outside the prior support, to show the effect of truncation.
all_prop_errs_raw = [calc_err(pars_true, proposal.gen(n_mcmc_samples))
                     for proposal in posteriors_C[:-1]]

samples = posteriors_C[-1].gen(n_mcmc_samples)
post_err_raw = calc_err(pars_true, samples)

sims_axis = np.arange(1, n_rounds+1) * n_train
err_rejection = np.array(all_prop_errs + [post_err])
err_naive = np.array(all_prop_errs_raw + [post_err_raw])

plt.figure(figsize=(8, 5))
plt.semilogx(sims_axis, err_rejection, 'bd:')
plt.semilogx(sims_axis, err_naive, 'kd:')
plt.legend(['rej. sampling', 'naive sampling'])
plt.axis([600, 22000, -1, 0.8])
plt.xlabel('Number of simulations (log scale)')
plt.ylabel('- log probability of true parameters')
plt.title('effects of truncation on MAF')
plt.savefig('/home/marcel/Desktop/mg1_snpec_maf_n_null_10_N5000_MAF_truncation.pdf')
plt.show()