# two Moons model (figs 1, 6)
- fits various algorithms (SNPE-A, SNPE-B, APT-MDN, APT-MAF, SNL) to the two-moons example
- originally used only for the multi-seed runs of figure 6 (fig 1 was run from a 'playground' version of this, i.e. with only one fixed seed)
- this notebook was written in haste (the day before submission) and was never fully uploaded; local copies on different machines contain the cells for different algorithms
- for evaluation (i.e. computation of MMDs), see twoMoons_allAlgs_eval.ipynb
- for plotting, see ICML_figure_1.ipynb and ICML_figure_supp.ipynb


- TODO: recover the cells for the algorithms SNPE-A, APT-MDN and APT-MAF from the local notebook copies on different machines

In [None]:
%%capture
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import timeit

import os
import sys
import pickle

import snl.inference.nde as nde
from snl.ml.models.mafs import ConditionalMaskedAutoregressiveFlow
from delfi.utils.delfi2snl import SNLprior, SNLmodel

import delfi.distribution as dd
import delfi.inference as infer
import delfi.generator as dg

from delfi.simulator import TwoMoons
import delfi.summarystats as ds
from delfi.utils.viz import plot_pdf, probs2contours

from lfimodels.snl_exps.util import save_results, load_results

# very basic approach to controlling generator seeds
def init_g(seed):
    """Build a freshly seeded generator for the two-moons experiment.

    The same ``seed`` is handed to both the simulator and the prior.

    Parameters
    ----------
    seed : int
        Seed passed to ``TwoMoons`` and to the uniform prior.

    Returns
    -------
    The ``delfi.generator.Default`` object combining the simulator, a
    uniform prior on [-1, 1] x [-1, 1] and identity summary statistics.
    """
    # Optional TwoMoons arguments deliberately left at their defaults:
    #   angle   -- rotation angle in radians (e.g. np.pi/4.0)
    #   mapfunc -- transforms the noise distribution
    #              (e.g. lambda theta, p: p + theta)
    simulator = TwoMoons(mean_radius=0.1, sd_radius=0.01, baseoffset=0.25,
                         seed=seed)
    prior = dd.Uniform(lower=[-1, -1], upper=[1, 1], seed=seed)
    return dg.Default(model=simulator, prior=prior, summary=ds.Identity())

# Default seed and generator; the per-seed fitting loops further down
# rebuild the generator for every seed they run.
seed = 42
g = init_g(seed=seed)

# observed summary statistics handed to the inference methods
obs_stats = np.array([[0., 0.]])

verbose = True

# Settings shared by all algorithms. Each fitting cell works on a copy of
# this dict and adds its own algorithm-specific entries.
setup_dict = {
    'seed': seed,
    'obs_stats': obs_stats,

    # training schedule
    'n_rounds': 10,
    'n_train': 1000,

    # fitting setup
    'minibatch': 100,
    'reg_lambda': 0.001,
    'pilot_samples': 0,
    'prior_norm': False,
    'init_norm': False,
}

exp_id = 'seed' + str(setup_dict['seed'])
save_path = 'results/two_moons_runs/'
    


# SNPE-B fits

In [None]:
# Fit SNPE-B (mixture-density network) once per seed and save each run.
seeds = np.arange(42, 52)

for seed in seeds:

    print('\n')
    print('\n')
    print('seed #' + str(seed))
    print('\n')

    # fit SNPE

    exp_id = 'seed' + str(seed)

    # algorithm-specific settings on top of the shared defaults
    setup_dict_ = setup_dict.copy()
    # Record the seed of *this* run: the copied defaults still carry the
    # global seed, which would otherwise be stored with every saved run.
    setup_dict_['seed'] = int(seed)
    setup_dict_['n_components'] = 20
    setup_dict_['n_hiddens'] = [50, 50]
    setup_dict_['svi'] = True
    setup_dict_['epochs'] = 500
    setup_dict_['round_cl'] = 1

    # generator (re-created so that simulator and prior use this seed)
    g = init_g(seed=seed)

    # inference object
    res_B = infer.SNPE(g,
                       obs=obs_stats,
                       n_hiddens=setup_dict_['n_hiddens'],
                       n_components=setup_dict_['n_components'],
                       seed=seed,
                       reg_lambda=setup_dict_['reg_lambda'],
                       pilot_samples=setup_dict_['pilot_samples'],
                       svi=setup_dict_['svi'],
                       verbose=verbose,
                       init_norm=setup_dict_['init_norm'],
                       prior_norm=setup_dict_['prior_norm'])

    # train; timeit.default_timer is the public clock of the timeit module
    t = timeit.default_timer()

    logs_B, tds_B, posteriors_B = res_B.run(
        n_train=setup_dict_['n_train'],
        n_rounds=setup_dict_['n_rounds'],
        minibatch=setup_dict_['minibatch'],
        round_cl=setup_dict_['round_cl'],
        epochs=setup_dict_['epochs'])

    # elapsed training time in seconds
    print(timeit.default_timer() - t)

    # NOTE(review): save_path ends with '/', so this resolves to
    # 'results/two_moons_runs/_MDN_SNPEB' -- confirm this is intended.
    path_B = save_path + '_MDN_SNPEB'

    try:
        save_results(logs=logs_B, tds=tds_B, posteriors=posteriors_B,
                     setup_dict=setup_dict_, exp_id=exp_id, path=path_B)
    except Exception as err:
        # Best effort: if saving with the training data fails, report why
        # and store everything else. `except Exception` (rather than a bare
        # `except`) lets KeyboardInterrupt / SystemExit stop the loop.
        print('saving with datasets failed: ' + repr(err))
        save_results(logs=logs_B, tds=[], posteriors=posteriors_B,
                     setup_dict=setup_dict_, exp_id=exp_id, path=path_B)
        print('could not save datasets')

# SNL fits

In [None]:
# Fit SNL (MAF likelihood model) once per seed and save each run to disk.
seeds = np.arange(42, 52)

for seed in seeds:

    print('\n')
    print('\n')
    print('seed #' + str(seed))
    print('\n')

    # fit SNL

    exp_id = 'seed' + str(seed)

    # MAF parameters
    setup_dict_ = setup_dict.copy()
    # keep the recorded seed in line with the seed of this run
    setup_dict_['seed'] = int(seed)
    setup_dict_['mode'] = 'random'
    setup_dict_['n_hiddens'] = [50, 50]
    setup_dict_['n_mades'] = 5
    setup_dict_['act_fun'] = 'tanh'
    setup_dict_['batch_norm'] = False  # batch-normalization currently not supported
    setup_dict_['train_on_all'] = True
    setup_dict_['thin'] = 10

    # control MAF seed (this seeds numpy's global random state)
    rng = np.random
    rng.seed(seed)

    # explicit call to MAF constructor
    # NOTE(review): `g` is still the generator left over from the previous
    # iteration (or an earlier cell); the draw is only used to read off the
    # parameter / data sizes. Left as is so that the random streams of
    # existing runs are not altered.
    theta, x = g.gen(1)
    # NOTE(review): for a model of x given theta one might expect
    # n_inputs=theta.size and n_outputs=x.size; presumably both sizes are 2
    # here so it makes no difference -- confirm against the MAF API.
    n_inputs, n_outputs = x.size, theta.size
    model = ConditionalMaskedAutoregressiveFlow(
        n_inputs=n_inputs,
        n_outputs=n_outputs,
        n_hiddens=setup_dict_['n_hiddens'],
        act_fun=setup_dict_['act_fun'],
        n_mades=setup_dict_['n_mades'],
        mode=setup_dict_['mode'],
        rng=rng)

    # generator (re-created so that simulator and prior use this seed)
    g = init_g(seed=seed)

    # inference object
    inf = nde.SequentialNeuralLikelihood(
        SNLprior(g.prior),                 # method to draw parameters
        SNLmodel(g.model, g.summary).gen   # method to draw summary stats
    )

    # train; timeit.default_timer is the public clock of the timeit module
    t = timeit.default_timer()

    rng = np.random  # control
    rng.seed(seed)   # MCMC seed
    model = inf.learn_likelihood(obs_stats.flatten(), model,
                                 n_samples=setup_dict_['n_train'],
                                 n_rounds=setup_dict_['n_rounds'],
                                 train_on_all=setup_dict_['train_on_all'],
                                 thin=setup_dict_['thin'],
                                 save_models=True,
                                 logger=sys.stdout, rng=rng)

    print(timeit.default_timer() - t)

    tds = (inf.all_ps, inf.all_xs)

    # Results are written by hand below instead of via
    # save_results(logs=[], tds=tds, posteriors=[model], setup_dict=setup_dict_,
    #              exp_id=exp_id, path=save_path + '_MAF_SNL')

    print(timeit.default_timer() - t)

    # one output directory per seed
    out_dir = os.path.join(save_path, exp_id)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # final model returned by learn_likelihood
    with open(os.path.join(out_dir, 'SNL_MAF.pkl'), 'wb') as f:
        pickle.dump(model, f)

    # inf.all_models, available because save_models=True was requested
    # (written once; the original wrote this identical file twice)
    with open(os.path.join(out_dir, 'SNL_posteriors.pkl'), 'wb') as f:
        pickle.dump(inf.all_models, f)

    # inf.all_ps / inf.all_xs, stored as <out_dir>/ps.npy and <out_dir>/xs.npy
    arrays = {
        'ps': inf.all_ps,
        'xs': inf.all_xs,
    }

    for varname, values in arrays.items():
        np.save(os.path.join(out_dir, varname), values)
