In [None]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
import timeit

from delfi.utils.viz import plot_pdf
import delfi.inference as infer
import delfi.distribution as dd

from lfimodels.snl_exps.util import save_results, load_results
from lfimodels.snl_exps.util import init_g_gauss as init_g
from lfimodels.snl_exps.util import load_setup_gauss as load_setup
from lfimodels.snl_exps.util import load_gt_gauss as load_gt
from lfimodels.snl_exps.util import calc_all_lprob_errs

# Global experiment seed; reused below for the generator, the networks and the samplers.
seed = 42

# Identifiers used to build the location/name under which results would be stored.
model_id = 'gauss'
save_path = 'results/{}_box_validationset'.format(model_id)
exp_id = 'seed{}'.format(seed)


# simulation setup
setup_dict = load_setup()

# Ground-truth parameters and the corresponding observed summary statistics,
# obtained from a freshly initialised generator with the experiment seed.
pars_true, obs_stats = load_gt(generator=init_g(seed=seed))
print('pars_true : ', pars_true)
print('obs_stats : ', obs_stats)

from delfi.simulator import NoiseDataDimensions
import delfi.distribution as dd
from delfi.summarystats.BaseSummaryStats import BaseSummaryStats

class NoiseStats(BaseSummaryStats):
    """Summary statistics that pad each simulated data vector with noise.

    Every repetition passed to `calc` is concatenated with one draw from
    `noise_source`, so the resulting summary vector consists of the original
    data followed by `noise_source.ndim` additional noise entries.
    """

    def __init__(self, noise_source, n_noise=None, seed=None):
        """Store the noise source and initialise the base class.

        Parameters
        ----------
        noise_source : object
            Must expose an `ndim` attribute and a `gen(n)` method that returns
            an array indexable as `[row, :]` with `n` rows.
        n_noise : optional
            Accepted but not used: the number of noise dimensions is always
            taken from `noise_source.ndim`.
        seed : optional
            Forwarded unchanged to `BaseSummaryStats.__init__`.
        """
        self.noise_source = noise_source
        # the width of the noise part is dictated by the noise source itself
        self.n_noise = self.noise_source.ndim
        super().__init__(seed=seed)

    def calc(self, repetition_list):
        """Return a (n_reps, n_summary) matrix of data with appended noise.

        Parameters
        ----------
        repetition_list : sequence of array-like
            One entry per repetition; the first entry's `.size` determines the
            width of the data part. Must be non-empty.

        Returns
        -------
        np.ndarray
            Row i is `repetition_list[i]` followed by the i-th noise draw.

        Side effects
        ------------
        Sets `self.n_summary` to data size + number of noise dimensions.
        """
        n_reps = len(repetition_list)

        # total width = raw data entries + appended noise entries
        self.n_summary = repetition_list[0].size + self.n_noise

        # one independent noise draw per repetition
        noise_draws = self.noise_source.gen(n_reps)

        stats = np.zeros((n_reps, self.n_summary))
        for row in range(n_reps):
            stats[row, :] = np.hstack((repetition_list[row], noise_draws[row, :]))

        return stats
        
# Dimensionality of the nuisance noise appended to the summary statistics and
# number of Student's-t components in the noise mixture.
noise_dim = 10
n_noise_comps = 20

# Dedicated, seeded RNG for the mixture parameters. Drawing them from the
# (at this point unseeded) global np.random made the noise distribution, and
# with it obs_stats, differ between runs with the same `seed`. A local
# RandomState also leaves the global np.random state untouched.
noise_param_rng = np.random.RandomState(seed + 1)

# NOTE: this prior is not used for the component means below (they are drawn
# as 15 * N(0, I)); the object is only kept so the name stays defined.
noise_means_prior = dd.Gaussian(m = np.zeros(noise_dim), S=np.eye(noise_dim), seed=seed+1)

# Random lower-triangular factors; ch @ ch.T gives a positive semi-definite
# scale matrix for each mixture component.
cholesky_factors = [np.tril(noise_param_rng.normal(size=(noise_dim, noise_dim)))
                    + np.diag(np.exp(noise_param_rng.normal(size=noise_dim)))
                    for _ in range(n_noise_comps)]
noise_Ss = [3 * np.dot(ch, ch.T) / noise_dim for ch in cholesky_factors]
noise_ms = [15 * noise_param_rng.normal(size=noise_dim) for _ in range(n_noise_comps)]
noise_dofs = [2 for _ in range(n_noise_comps)]

# Equal-weight mixture of Student's-t components; seeded so that the noise
# samples (including the noise part of obs_stats) are reproducible.
noise_distribution = dd.MoT(a=np.ones(n_noise_comps) / n_noise_comps, ms=noise_ms, Ss=noise_Ss,
                            dofs=noise_dofs, seed=seed+2)
# NoiseStats reads `ndim` from its noise source, so set it explicitly here.
noise_distribution.ndim = noise_dim

# generator
g = init_g(seed=seed)
g.summary = NoiseStats(noise_source=noise_distribution)
# 8 is the size of the un-augmented summary statistics; the assertion below
# verifies this against the augmented obs_stats.
g.summary.n_summary = noise_dim + 8
g.model.dim_param = 5

# Append one noise draw to the observed statistics so that they match the
# augmented summary-statistics layout (data first, noise last).
obs_stats = np.hstack((obs_stats, noise_distribution.gen(1).reshape(1,-1)))
assert obs_stats.size ==  g.summary.n_summary

In [None]:
from snl.util.plot import plot_hist_marginals
# Visual check of the nuisance noise: marginal histograms of 1000 draws.
fig = plot_hist_marginals(noise_distribution.gen(1000))
fig.set_figwidth(16)
fig.set_figheight(16)
# plt.show() instead of fig.show(): with the inline backend fig.show() only
# emits a "non-GUI backend" warning; plt.show() is also what the per-round
# plotting loop at the end of the notebook uses.
plt.show()

In [None]:
# Override the number of rounds from the loaded setup; this value is passed
# as `n_rounds` to the SNPE-C run below and sizes its epoch schedule.
setup_dict['n_rounds'] = 20

In [None]:
# Override the 'pilot_samples' setting that is handed to infer.SNPEC below.
setup_dict['pilot_samples'] = 1000

In [None]:
# Epoch schedule: when training on all rounds' data, round r (0-based) uses
# epochs // (r + 1) epochs; otherwise the same epoch count is used throughout.
if setup_dict['train_on_all']:
    epochs=[setup_dict['epochs']//(r+1) for r in range(setup_dict['n_rounds'])]
else:
    epochs=setup_dict['epochs']

# control MAF seed
# (the numpy.random module itself serves as the rng object, so seeding it
# here resets the global numpy random state)
rng = np.random
rng.seed(seed)

res_C = infer.SNPEC(g,
                    obs=obs_stats,
                    n_hiddens=setup_dict['n_hiddens'],
                    seed=seed,
                    reg_lambda=setup_dict['reg_lambda'],
                    pilot_samples=setup_dict['pilot_samples'],
                    svi=setup_dict['svi'],
                    n_mades=setup_dict['n_mades'],
                    act_fun=setup_dict['act_fun'],
                    mode=setup_dict['mode'],
                    rng=rng,
                    batch_norm=setup_dict['batch_norm'],
                    verbose=setup_dict['verbose'],
                    #upper=setup_dict['upper'], # box-constraints 
                    #lower=setup_dict['lower'], # for MAF support
                    prior_norm=setup_dict['prior_norm'])


# train
# timeit.default_timer() is the documented clock of the timeit module; the
# previous timeit.time.time() relied on timeit's internal `time` import.
t_start = timeit.default_timer()

print('fitting model with SNPC-C')
logs_C, tds_C, posteriors_C = res_C.run(
                    n_train=setup_dict['n_train'],
                    proposal=setup_dict['proposal'],
                    moo=setup_dict['moo'],
                    n_null = setup_dict['n_null'],
                    n_rounds=setup_dict['n_rounds'],
                    train_on_all=setup_dict['train_on_all'],
                    minibatch=setup_dict['minibatch'],
                    epochs=epochs, 
                    silent_fail=False)

print('fitting time : ', timeit.default_timer() - t_start)

In [None]:
# Call the generator for a single sample; the return value is displayed as
# the cell output (quick check that the augmented generator runs).
g.gen(1)

In [None]:
from lfimodels.snl_exps.util import draw_sample_uniform_prior_33

# Optional diagnostic (disabled): plt.plot(logs_C[r]['loss']) for each round r.

# One figure per SNPE-C round: samples from that round's posterior estimate
# together with the ground-truth parameters.
for r in range(len(logs_C)):

    posterior_C = posteriors_C[r]

    samples = draw_sample_uniform_prior_33(posterior_C, 5000)

    # plot_pdf is given a near-degenerate Gaussian as its pdf argument;
    # presumably a placeholder so that effectively only `samples` and `gt`
    # are visible in the plot -- TODO confirm.
    n_pars = pars_true.size
    placeholder_pdf = dd.Gaussian(m=0.00000123*np.ones(n_pars), S=1e-30*np.eye(n_pars))

    # same [-3, 3] axis range for each of the five parameters
    axis_lims = [[-3, 3] for _ in range(5)]

    fig, _ = plot_pdf(placeholder_pdf,
                      samples=samples.T,
                      gt=pars_true,
                      lims=axis_lims,
                      resolution=100,
                      ticks=True,
                      figsize=(16,16))

    fig.suptitle('SNPE-C posterior estimates, round r = '+str(r+1), fontsize=14)
    print('negative log-probability of ground-truth pars \n', -posterior_C.eval(pars_true, log=True))

In [None]:
# Print the shape of every network parameter; parameters whose leading
# dimension is 18 are additionally printed in full. (18 equals
# noise_dim + 8, the summary-statistics size set above -- presumably these
# are the parameters attached to the network input; verify.)
for parm in res_C.network.parms:
    value = parm.get_value()
    print(value.shape)
    if value.shape[0] == 18:
        print(value)

# SNL fit

In [None]:
import snl.simulators.gaussian as sim_gauss

# Simulator, prior and noise-augmenting summary statistics for the SNL fit.
model = sim_gauss.Model()
prior = sim_gauss.Prior()
stats = NoiseStats(noise_source=noise_distribution)


def sim_model(ps, rng):
    """Simulate data for parameters `ps` and append noise dimensions.

    Runs `model.sim(ps, rng=rng)` and passes the result through
    `stats.calc`, returning the noise-augmented summary statistics.
    """
    return stats.calc(model.sim(ps, rng=rng))
# to run without summary stats: sim_model = model.sim !

In [None]:
# Number of rounds for the SNL fit below (same value as set for SNPE-C above).
setup_dict['n_rounds'] = 20

In [None]:
from snl.inference.nde import SequentialNeuralLikelihood
from snl.ml.models.mafs import ConditionalMaskedAutoregressiveFlow
import sys

# NOTE: this rebinds `infer`, which until here referred to the
# `delfi.inference` module imported at the top of the notebook. Re-running
# the SNPE-C cell after this one therefore requires re-running the imports.
infer = SequentialNeuralLikelihood(prior=prior, sim_model=sim_model)

# control MAF seed
# (the numpy.random module itself serves as the rng object)
rng = np.random
rng.seed(seed)

# Conditional MAF modelling the (noise-augmented) summary statistics given
# the parameters: conditioned on prior.n_dims inputs, obs_stats.size outputs.
maf = ConditionalMaskedAutoregressiveFlow(n_inputs=prior.n_dims, 
                                          n_outputs=obs_stats.size, 
                                          n_hiddens=setup_dict['n_hiddens'], 
                                          act_fun=setup_dict['act_fun'], 
                                          n_mades=setup_dict['n_mades'], 
                                          
                                          batch_norm=False,           # these differ for 
                                          output_order='sequential', # the usage of our 
                                          mode='sequential',         # MAFs...
                                          
                                          input=None, 
                                          output=None, rng=rng)


# control sampler seed
# (`rng` is still numpy.random, so this re-seeds the global state used for
# simulation and MCMC during training)
rng.seed(seed+1)

# timeit.default_timer() is the documented clock of the timeit module; the
# previous timeit.time.time() relied on timeit's internal `time` import.
t_start = timeit.default_timer()

learned_model = infer.learn_likelihood(obs_xs=obs_stats.flatten(), 
                       model=maf, 
                       n_samples=setup_dict['n_train'], 
                       n_rounds=setup_dict['n_rounds'],
                       train_on_all=setup_dict['train_on_all'],
                       thin=10, 
                       save_models=False, 
                       logger=sys.stdout, 
                       rng=rng)

# elapsed fitting time in seconds
print(timeit.default_timer() - t_start)


In [None]:
from snl.util.plot import plot_hist_marginals
import snl.inference.mcmc as mcmc

thin = 10

# The observation is fixed, so flatten it once instead of on every
# log-posterior evaluation inside the sampler.
obs_flat = obs_stats.flatten()


def log_posterior(theta):
    """Unnormalised log posterior: learned log-likelihood of the observation
    given `theta` plus the log prior of `theta`."""
    return learned_model.eval([theta, obs_flat]) + prior.eval(theta)


# Slice sampler started from a prior draw.
sampler = mcmc.SliceSampler(prior.gen(), log_posterior, thin=thin)

sampler.gen(max(200 // thin, 1), rng=rng)  # burn in

# pass the same rng as for the burn-in so the random source is explicit
samples = sampler.gen(1000, rng=rng)

fig = plot_hist_marginals(data=samples, lims=[-3,3])
fig.set_figwidth(16)
fig.set_figheight(16)
# plt.show() instead of fig.show(): with the inline backend fig.show() only
# emits a "non-GUI backend" warning; plt.show() matches the loop below.
plt.show()

In [None]:
# Marginal histograms of the parameter samples stored for each SNL round.
# Iterating over infer.all_ps directly (instead of a hard-coded range(20))
# keeps this cell correct when the number of rounds is changed.
for round_ps in infer.all_ps:
    fig = plot_hist_marginals(round_ps, lims=[-3,3])
    fig.set_figwidth(16)
    fig.set_figheight(16)
    plt.show()