# SLCP model with added noise dimensions (fig 3)

- fitting APT, SNL, SNPE-B and SNPE-A for figure 3 of the APT paper
- for evaluation of fits, see APT_eval.ipynb and SNL_eval.ipynb 
- for plotting, see ICML_figure3.ipynb

In [None]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
import timeit

from delfi.utils.viz import plot_pdf
import delfi.inference as infer
import delfi.distribution as dd
from delfi.summarystats.BaseSummaryStats import BaseSummaryStats

from lfimodels.snl_exps.util import save_results, load_results
from lfimodels.snl_exps.util import init_g_gauss as init_g
from lfimodels.snl_exps.util import load_setup_gauss as load_setup
from lfimodels.snl_exps.util import load_gt_gauss as load_gt
from lfimodels.snl_exps.util import calc_all_lprob_errs

import snl.simulators.gaussian as sim_gauss
from snl.inference.nde import SequentialNeuralLikelihood
from snl.ml.models.mafs import ConditionalMaskedAutoregressiveFlow
import sys
import os
import pickle

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
model_id = 'gauss'

# number of noise dimensions appended to the summary statistics
noise_dim = 52
# number of mixture components of the noise distribution
n_noise_comps = 20

# Settings that were mapped out manually per noise_dim: noise_dim -> (scale, version).
#
# scale   : STABILITY ISSUES WITH NOISE GENERATION FOR LARGE noise_dim VALUES!
#           scale * np.eye(noise_dim) is added to the Cholesky factors that
#           build the noise covariance matrices (see init_noise_g), with the
#           scale depending on noise_dim.
# version : suffix of the save path used for this noise_dim.
#
# Keeping both in one table means scale and save path cannot get out of sync
# with noise_dim when only one of them is edited.
_NOISE_DIM_SETTINGS = {
    92: (7., '_v2'),
    52: (3., '_v3'),
    32: (3., '_v4'),
    12: (1., '_v1'),
}

if noise_dim not in _NOISE_DIM_SETTINGS:
    raise ValueError(
        'no (scale, version) entry for noise_dim = ' + str(noise_dim)
        + '; map out a stable scale and add it to _NOISE_DIM_SETTINGS'
    )

scale, _version_suffix = _NOISE_DIM_SETTINGS[noise_dim]

save_path = 'results/' + model_id + '_noisedims' + _version_suffix



class NoiseStats(BaseSummaryStats):
    """Summary statistics that append noise dimensions and shuffle the columns.

    Each simulated data vector is concatenated with one draw from
    ``noise_source``; the columns of the resulting matrix are then reordered
    by a fixed random permutation ``self.idx`` drawn at construction time.

    Parameters
    ----------
    noise_source : object
        Must expose ``ndim`` (number of noise dimensions) and ``gen(n)``,
        whose result is indexed as ``[row, :]`` with one row per repetition.
    n_signal : int
        Number of signal (non-noise) summary statistics. Required, despite
        the ``None`` default kept for signature compatibility.
    seed : int or None
        Passed to the base class and used to seed the global numpy RNG.

    Raises
    ------
    TypeError
        If ``n_signal`` is not given.
    """

    def __init__(self, noise_source, n_signal = None, seed=None):

        if n_signal is None:
            # previously failed further down with an opaque `None + int` TypeError
            raise TypeError('n_signal (number of signal statistics) must be given')

        # NOTE: this seeds the *global* numpy RNG as a side effect; the
        # permutation below is drawn from it.
        rng = np.random
        rng.seed(seed)

        self.noise_source = noise_source

        self.n_signal = n_signal
        self.n_noise = self.noise_source.ndim

        super().__init__(seed=seed)

        # set after the base-class constructor ran
        self.n_summary = n_signal + self.n_noise

        # fixed random ordering of the signal and noise columns
        self.idx = rng.permutation(self.n_summary)

    def calc(self, repetition_list):
        """Return an (n_reps, n_summary) matrix of signal + noise statistics.

        Parameters
        ----------
        repetition_list : sequence
            One entry per repetition; each entry is stacked in front of one
            row of noise. Only the first entry's ``size`` is checked.

        Returns
        -------
        np.ndarray
            Matrix with one row per repetition, columns permuted by
            ``self.idx`` (unpermuted if ``self.idx`` is None).
        """

        # get the number of samples contained
        n_reps = len(repetition_list)

        # get the size of the data inside a sample
        assert self.n_summary == repetition_list[0].size + self.n_noise, \
            'size of the data does not match n_signal'

        # build a matrix of n_reps x n_summary
        data_matrix = np.zeros((n_reps, self.n_summary))
        noise_matrix = self.noise_source.gen(n_reps)
        for rep_idx, rep_val in enumerate(repetition_list):
            data_matrix[rep_idx, :] = np.hstack((rep_val, noise_matrix[rep_idx, :]))

        return data_matrix if self.idx is None else data_matrix[:, self.idx]


def init_noise_g(seed, obs_stats):
    """Build the noisy generator and the matching noisy observed statistics.

    Seeds the global numpy RNG, constructs a mixture-of-t noise distribution
    with ``n_noise_comps`` components over ``noise_dim`` dimensions, attaches a
    ``NoiseStats`` summary to a fresh generator, and extends ``obs_stats`` by
    one noise draw, permuted with the summary's column permutation.

    Parameters
    ----------
    seed : int
        Seed for the global RNG, the distributions and the generator.
    obs_stats : np.ndarray
        Observed signal statistics; stacked horizontally with a (1, -1)
        reshaped noise draw.

    Returns
    -------
    g, obs_stats
        The generator and the permuted signal + noise observation.
    """
    rng = np.random
    rng.seed(seed)

    component_ids = range(n_noise_comps)

    # NOTE(review): these prior draws are overwritten by the 15 * normal means
    # further down. Kept so the sequence of random draws stays as it was --
    # confirm they are really unused before removing.
    noise_means_prior = dd.Gaussian(m=np.zeros(noise_dim), S=np.eye(noise_dim), seed=seed + 1)
    noise_ms = []
    for _ in component_ids:
        noise_ms.append(noise_means_prior.gen(1).reshape(-1))

    # covariances from random lower-triangular factors; scale * identity is
    # added to the factor (see the stability note in the configuration)
    noise_Ss = []
    for _ in component_ids:
        lower = np.tril(rng.normal(size=(noise_dim, noise_dim)))
        positive_diag = np.diag(np.exp(rng.normal(size=noise_dim)))
        ch = lower + positive_diag + scale * np.eye(noise_dim)
        noise_Ss.append(3 * np.dot(ch, ch.T) / noise_dim)

    # component means (drawn after all covariance factors)
    noise_ms = []
    for _ in component_ids:
        noise_ms.append(15 * rng.normal(size=noise_dim))

    noise_dofs = [2] * n_noise_comps
    weights = np.ones(n_noise_comps) / n_noise_comps

    noise_distribution = dd.MoT(a=weights, ms=noise_ms, Ss=noise_Ss,
                                dofs=noise_dofs, seed=seed)
    noise_distribution.ndim = noise_dim

    # generator
    g = init_g(seed=seed)
    g.summary = NoiseStats(noise_source=noise_distribution, n_signal=8, seed=seed)
    g.model.dim_param = 5

    # one noise draw for the observation, permuted like the simulated data
    noise_obs = noise_distribution.gen(1).reshape(1, -1)
    obs_stats = np.hstack((obs_stats, noise_obs))[0, g.summary.idx]
    assert obs_stats.size == g.summary.n_summary

    print('permutation indices', g.summary.idx)

    return g, obs_stats


# fit SNL

In [None]:
# Fit SNL (sequential neural likelihood) once per seed and store the learned
# models, the simulated parameters / data and the noisy observation.
seeds = np.arange(52,62)

for seed in seeds:

    print('\n')
    print('\n')
    print('seed #' + str(seed))
    print('\n')

    setup_dict = load_setup()
    pars_true, obs_stats = load_gt(generator=init_g(seed=seed))
    print('pars_true : ', pars_true)
    print('obs_stats : ', obs_stats)

    exp_id = 'seed'+str(seed)

    # noisy generator; only its noise source and column permutation are used
    # here (init_noise_g reseeds the global RNG itself)
    g, obs_stats_noise = init_noise_g(seed, obs_stats)

    model = sim_gauss.Model()
    prior = sim_gauss.Prior()
    stats = NoiseStats(noise_source=g.summary.noise_source, n_signal=8, seed=seed)
    # use the same column permutation as the observed statistics
    stats.idx = g.summary.idx

    print('permutation indices', stats.idx)

    # control model seed
    rng = np.random
    rng.seed(seed)

    def sim_model(ps, rng):
        # to run without summary stats: sim_model = model.sim !
        return stats.calc(model.sim(ps, rng=rng))

    inf = SequentialNeuralLikelihood(prior=prior, sim_model=sim_model)

    # control MAF seed
    rng = np.random
    rng.seed(seed)

    maf = ConditionalMaskedAutoregressiveFlow(n_inputs=prior.n_dims,
                                              n_outputs=obs_stats_noise.size,
                                              n_hiddens=setup_dict['n_hiddens'],
                                              act_fun=setup_dict['act_fun'],
                                              n_mades=setup_dict['n_mades'],

                                              batch_norm=False,           # these differ for
                                              output_order='sequential', # the usage of our
                                              mode=setup_dict['mode'],         # MAFs...

                                              input=None,
                                              output=None, rng=rng)

    # control sampler seed
    rng = np.random
    rng.seed(seed+1)

    t = timeit.default_timer()

    learned_model = inf.learn_likelihood(obs_xs=obs_stats_noise.flatten(),
                           model=maf,
                           n_samples=setup_dict['n_train'],
                           n_rounds=setup_dict['n_rounds'],
                           train_on_all=setup_dict['train_on_all'],
                           thin=10,
                           save_models=True,
                           logger=sys.stdout,
                           rng=rng)

    print(timeit.default_timer() - t)

    # store results under <save_path>/<exp_id>/
    out_dir = os.path.join(save_path, exp_id)
    os.makedirs(out_dir, exist_ok=True)

    with open(os.path.join(out_dir, 'SNL_MAF.pkl'), 'wb') as f:
        pickle.dump(learned_model, f)

    with open(os.path.join(out_dir, 'SNL_posteriors.pkl'), 'wb') as f:
        pickle.dump(inf.all_models, f)

    np.save(os.path.join(out_dir, 'ps'), inf.all_ps)
    np.save(os.path.join(out_dir, 'xs'), inf.all_xs)
    np.save(os.path.join(out_dir, 'obs_stats_noise'), obs_stats_noise)

# fit APT

In [None]:
# Fit APT (infer.SNPEC) once per seed and store logs, posteriors and the
# noisy observation.
seeds = np.arange(52,62)

for seed in seeds:

    print('\n')
    print('\n')
    print('seed #' + str(seed))
    print('\n')

    setup_dict = load_setup()
    pars_true, obs_stats = load_gt(generator=init_g(seed=seed))
    print('pars_true : ', pars_true)
    print('obs_stats : ', obs_stats)

    exp_id = 'seed'+str(seed)

    # generator with noise dimensions and the matching noisy observation
    g, obs_stats_noise = init_noise_g(seed, obs_stats)

    # when training on all rounds' data, the number of epochs shrinks per round
    if setup_dict['train_on_all']:
        epochs=[setup_dict['epochs']//(r+1) for r in range(setup_dict['n_rounds'])]
    else:
        epochs=setup_dict['epochs']

    # control MAF seed
    rng = np.random
    rng.seed(seed)

    res_C = infer.SNPEC(g,
                        obs=obs_stats_noise,
                        n_hiddens=setup_dict['n_hiddens'],
                        seed=seed,
                        reg_lambda=setup_dict['reg_lambda'],
                        pilot_samples=setup_dict['pilot_samples'],
                        svi=setup_dict['svi'],
                        n_mades=setup_dict['n_mades'],
                        act_fun=setup_dict['act_fun'],
                        mode=setup_dict['mode'],
                        rng=rng,
                        batch_norm=setup_dict['batch_norm'],
                        verbose=setup_dict['verbose'],
                        #upper=setup_dict['upper'], # box-constraints
                        #lower=setup_dict['lower'], # for MAF support
                        prior_norm=setup_dict['prior_norm'])

    # train
    t = timeit.default_timer()

    print('fitting model with SNPE-C')
    logs_C, tds_C, posteriors_C = res_C.run(
                        n_train=setup_dict['n_train'],
                        proposal=setup_dict['proposal'],
                        moo=setup_dict['moo'],
                        n_null = setup_dict['n_null'],
                        n_rounds=setup_dict['n_rounds'],
                        train_on_all=setup_dict['train_on_all'],
                        minibatch=setup_dict['minibatch'],
                        epochs=epochs,
                        verbose=False,
                        val_frac=0.1,
                        silent_fail=True,
                        stop_on_nan=True)

    print('fitting time : ', timeit.default_timer() - t)

    # the training data (tds_C) is not passed on to save_results
    save_results(logs=logs_C, tds=[], posteriors=posteriors_C,
                 setup_dict=setup_dict, exp_id=exp_id, path=save_path)

    fn = os.path.join(save_path, exp_id, 'obs_stats_noise')
    np.save(fn, obs_stats_noise)
        

# fit SNPE-B

In [None]:
# Fit SNPE-B (infer.SNPE) once per seed. A failing seed is reported and
# skipped so that the remaining seeds still run.
seeds = np.arange(52,62)

for seed in seeds:

    try:

        print('\n')
        print('\n')
        print('seed #' + str(seed))
        print('\n')

        setup_dict = load_setup()
        pars_true, obs_stats = load_gt(generator=init_g(seed=seed))
        print('pars_true : ', pars_true)
        print('obs_stats : ', obs_stats)

        exp_id = 'seed'+str(seed) + '_SNPEB'

        # generator with noise dimensions and the matching noisy observation
        g, obs_stats_noise = init_noise_g(seed, obs_stats)

        # control network seed
        rng = np.random
        rng.seed(seed)

        # SNPE-B settings
        setup_dict['n_components'] = 8
        setup_dict['svi'] = True
        setup_dict['epochs'] = 500 # SNPE-B gradient descent becomes unstable after ~700 epochs ! (-> NaN weights)

        res_B = infer.SNPE(g,
                            obs=obs_stats_noise.reshape(1,-1),
                            n_hiddens=setup_dict['n_hiddens'],
                            seed=seed,
                            reg_lambda=setup_dict['reg_lambda'],
                            pilot_samples=setup_dict['pilot_samples'],
                            svi=setup_dict['svi'],
                            n_components=setup_dict['n_components'],
                            prior_norm=setup_dict['prior_norm'])

        # train
        t = timeit.default_timer()

        print('fitting model with SNPE-B')
        logs_B, tds_B, posteriors_B = res_B.run(
                            n_train=setup_dict['n_train'],
                            n_rounds=setup_dict['n_rounds'],
                            minibatch=setup_dict['minibatch'],
                            epochs=setup_dict['epochs'],
                            stop_on_nan=True
                            #verbose=True,
                            #val_frac=0.1,
                            #silent_fail=True,
                            )

        print('fitting time : ', timeit.default_timer() - t)

        # the training data (tds_B) is not passed on to save_results
        save_results(logs=logs_B, tds=[], posteriors=posteriors_B,
                     setup_dict=setup_dict, exp_id=exp_id, path=save_path)

        fn = os.path.join(save_path, exp_id, 'obs_stats_noise')
        np.save(fn, obs_stats_noise)

    except Exception as err:
        # Best-effort: keep going with the next seed, but say what went wrong.
        # Catching Exception (not a bare except) lets KeyboardInterrupt /
        # SystemExit stop the loop.
        print('\n')
        print('\n')
        print('\n')
        print('SEED FAILED !')
        print('seed #' + str(seed) + ' : ' + repr(err))
        print('\n')
        print('\n')
        print('\n')
        

# fit SNPE-A

In [None]:
# Fit SNPE-A (infer.CDELFI) once per seed and store logs, posteriors and the
# noisy observation.
seeds = np.arange(52,62)

for seed in seeds:

    print('\n')
    print('\n')
    print('seed #' + str(seed))
    print('\n')

    setup_dict = load_setup()
    pars_true, obs_stats = load_gt(generator=init_g(seed=seed))
    print('pars_true : ', pars_true)
    print('obs_stats : ', obs_stats)

    exp_id = 'seed'+str(seed) + '_SNPEA'

    # generator with noise dimensions and the matching noisy observation
    g, obs_stats_noise = init_noise_g(seed, obs_stats)

    # control network seed
    rng = np.random
    rng.seed(seed)

    # SNPE-A settings
    setup_dict['n_components'] = 8
    setup_dict['svi'] = True
    setup_dict['epochs'] = 1000

    res_A = infer.CDELFI(g,
                        obs=obs_stats_noise.reshape(1,-1),
                        n_hiddens=setup_dict['n_hiddens'],
                        seed=seed,
                        reg_lambda=setup_dict['reg_lambda'],
                        pilot_samples=setup_dict['pilot_samples'],
                        svi=setup_dict['svi'],
                        n_components=setup_dict['n_components'],
                        prior_norm=setup_dict['prior_norm'])

    # train
    t = timeit.default_timer()

    print('fitting model with SNPE-A')
    logs_A, tds_A, posteriors_A = res_A.run(
                        n_train=setup_dict['n_train'],
                        n_rounds=setup_dict['n_rounds'],
                        minibatch=setup_dict['minibatch'],
                        epochs=setup_dict['epochs']
                        #verbose=True,
                        #val_frac=0.1,
                        #silent_fail=True,
                        )

    print('fitting time : ', timeit.default_timer() - t)

    # the training data (tds_A) is not passed on to save_results
    save_results(logs=logs_A, tds=[], posteriors=posteriors_A,
                 setup_dict=setup_dict, exp_id=exp_id, path=save_path)

    fn = os.path.join(save_path, exp_id, 'obs_stats_noise')
    np.save(fn, obs_stats_noise)
        