In [17]:
import numpy as np
import time
import delfi
import L96
from delfi.summarystats import Identity
import delfi.generator

import delfi.inference

import matplotlib.pyplot as plt

from L96_summary import Summary_Schneider2017, Summary_convstats

from util import plot_hist_marginals

# Base random seed for the whole notebook. The simulator uses `seed` itself,
# the prior `seed+1` and SNPEC `seed+2`, so no two components share a seed.
seed: int = 42

# define simulation setup

In [2]:
# Two-level Lorenz-96 sizes (K slow variables X_k, J fast variables per X_k,
# as passed to L96.L96TwoSim) and the integration time step.
K, J, dt = 36, 10, 0.001

# observation points in space (k)
obs_X_grid = np.arange(0, K, 1)  # currently observe ALL X_k !
K_obs = len(obs_X_grid)

# Observation times: every `obs_every` integration steps on [t_burnin, t_end).
# Nothing before t_burnin is observed, i.e. [0, t_burnin) acts as burn-in.
t_burnin, t_end = 1., 11.
obs_every = 100                                            # integration steps between observations
obs_times = np.arange(t_burnin, t_end, obs_every * dt)     # t = 1.0, 1.1, ..., 10.9
obs_nsteps = len(obs_times)

# np.arange with a float step can gain/lose the endpoint through rounding;
# make sure we got exactly the intended number of observation times.
assert obs_nsteps == int(round((t_end - t_burnin) / (obs_every * dt))), obs_nsteps

In [3]:
# Two-level L96 simulator, recording the chosen X_k at the chosen times.
sim = L96.L96TwoSim(
    K=K, J=J, dt=dt,
    obs_X=obs_X_grid,      # observed slow-variable indices k
    obs_times=obs_times,   # observation times t
    seed=seed,
)

# define parameter prior

In [4]:
# Gaussian prior over theta = (F, h, b, log c). S is a diagonal matrix, so the
# parameters are independent a priori with variances 10, 1, 10 and 0.1.
prior = delfi.distribution.Gaussian(
    m=np.array([10., 0., 5., 2.]),
    S=np.diag([10., 1., 10., 0.1]),
    seed=seed + 1,
)

# define summary features

In [5]:
# Summary statistics of a simulated L96 run (presumably those of
# Schneider et al. 2017, going by the class name), sized for this K and J.
summary = Summary_Schneider2017(J=J, K=K)

# define ground truth parameters

In [10]:
# Parameter vectors, ordered (F, h, b, log c).
# pars_true generates the 'observed' data; pars_alt differs only in F and is
# kept for comparison.
pars_true = np.array([10., 1., 10., np.log(10.)])
pars_alt = pars_true.copy()
pars_alt[0] = 5.
# 'Observed' summary statistics: simulate once at pars_true (a one-element
# parameter list, so take element 0) and summarise that run.
obs_stats = summary.calc(
    sim.gen([pars_true])[0]
)

# set up data generation

In [None]:
# generator object (prior, simulator, summary statistics)
def notnan(x):
    """Acceptance test used by the rejection generator below.

    Returns True only when every entry of ``x`` is finite, so simulations that
    produced NaN (or +/-inf) are rejected -- just as insurance.
    """
    finite = np.isfinite(x)
    return finite.all()
    
# Rejection generator: draws parameters, simulates, summarises, and resamples
# whenever `notnan` returns False.
g = delfi.generator.RejKernel(
    model=sim,
    prior=prior,
    summary=summary,
    rej=notnan,
)

# smoke test: generate 8 (parameters, summary statistics) pairs
params, stats = g.gen(8)
(params.shape, stats.shape)

# set SNPE options

In [None]:
# SNPE-C options, passed to the delfi.inference.SNPEC constructor below.
setup_opts = {
    'density': 'maf',         # masked autoregressive flow density estimator
    'n_hiddens': [50, 50],    # hidden-layer sizes of each MADE
    'n_mades': 5,             # number of MADEs stacked in the flow
    'verbose': True,
    'prior_norm': False,      # NOTE(review): presumably no prior-based z-scoring of parameters
    'svi': False,
    'pilot_samples': 1000,    # pilot simulations (per delfi docs: for z-scoring summary stats)
}

# Training options, passed to res.run(...) below.
run_opts = {
    'n_train': 5000,          # training simulations per round
    'n_rounds': 2,
    'minibatch': 100,
    'epochs': 2000,
    'proposal': 'atomic',     # atomic-proposal (SNPE-C / APT) loss
    'max_norm': 0.1,          # NOTE(review): gradient-clipping norm per delfi docs — confirm
    'val_frac': 0.1,          # fraction of training data held out for validation
    'silent_fail': False,
}

# Run inference

In [None]:
# SNPE-C inference object for the observed summary statistics; seed offset +2
# keeps it distinct from the simulator (seed) and prior (seed+1).
res = delfi.inference.SNPEC(
    g,
    obs=obs_stats,
    seed=seed + 2,
    **setup_opts
)

In [None]:
# Train over all rounds. logs / trn_datasets / posteriors are presumably one
# entry per round (logs[0], logs[1] are plotted below).
logs, trn_datasets, posteriors = res.run(
    verbose=True,
    **run_opts
)

# Plot training curves

In [None]:
# One loss-curve figure per SNPE round. Iterating over `logs` keeps this in
# sync with run_opts['n_rounds'] instead of hard-coding rounds 0 and 1.
for round_idx, round_log in enumerate(logs):
    fig_loss, ax_loss = plt.subplots()
    ax_loss.plot(round_log['loss'])
    # NOTE(review): assumes one loss entry per training iteration — confirm
    ax_loss.set(title='SNPE round {}: training loss'.format(round_idx + 1),
                xlabel='iteration', ylabel='loss')
    plt.show()

# Show inferred posterior

In [None]:
# NOTE(review): this import shadows the `plot_hist_marginals` imported from
# `util` in the first cell. It is kept so the same function draws the figure
# as before; consider moving it to the import cell.
from snl.util.plot import plot_hist_marginals

# Condition the trained density on the observed summary statistics.
# (Alternative: posteriors[-1] as returned by res.run.)
pars = pars_true.copy()   # ground-truth marker in the plots
posterior = res.predict(obs_stats)

labels = ['F', 'h', 'b', 'log c']
n_dims = len(labels)

# Hand-tuned plot box per parameter (instead of prior mean +/- 3 std).
pl = [ 0, 0.5, -5, 1]
pu = [20, 1.5, 15, 3]

# Posterior samples, dropping those above the upper plot limit of
# h (dim 1) and log c (dim 3).
xs = posterior.gen(3000)
xs = xs[(xs[:, 1] < pu[1]) & (xs[:, 3] < pu[3])]

fig = plot_hist_marginals(xs, lims=[[lo, hi] for lo, hi in zip(pl, pu)],
                          gt=pars, upper=True)

# Tick positions per parameter: x-axis of diagonal panel i and y-axis of the
# rightmost panel in row i.
ticks = [[5, 10, 15], [0.5, 1, 1.5], [0, 5, 10], [1, 2, 3]]

for i in range(n_dims):
    # diagonal panel i: overlay the prior marginal (mean +/- 3 std) in grey
    plt.subplot(n_dims, n_dims, i * (n_dims + 1) + 1)
    xgrid = np.linspace(prior.mean[i] - 3 * prior.std[i],
                        prior.mean[i] + 3 * prior.std[i], 200)
    plt.plot(xgrid, prior.eval(xgrid, ii=i, log=False),
             color=[0.4, 0.4, 0.4], linewidth=2)
    plt.xlabel(labels[i], fontsize=20)
    plt.xticks(ticks[i])

for i in range(n_dims):
    plt.subplot(n_dims, n_dims, (i + 1) * n_dims)
    plt.yticks(ticks[i])

# fix the axis range of the bottom-right (log c) diagonal panel
plt.subplot(n_dims, n_dims, n_dims * n_dims)
plt.axis([pl[-1], pu[-1], 0, 1.5])

fig.set_size_inches(12, 12)

In [None]:
# posterior with limits
# NOTE(review): this cell currently reproduces the previous posterior plot.
# The only sample limits applied are the upper bounds on h and log c.

# Condition the trained density on the observed summary statistics.
# (Alternative: posteriors[-1] as returned by res.run.)
pars = pars_true.copy()   # ground-truth marker in the plots
posterior = res.predict(obs_stats)

labels = ['F', 'h', 'b', 'log c']
n_dims = len(labels)

# Hand-tuned plot box per parameter (instead of prior mean +/- 3 std).
pl = [ 0, 0.5, -5, 1]
pu = [20, 1.5, 15, 3]

# Posterior samples, dropping those above the upper plot limit of
# h (dim 1) and log c (dim 3).
xs = posterior.gen(3000)
xs = xs[(xs[:, 1] < pu[1]) & (xs[:, 3] < pu[3])]

fig = plot_hist_marginals(xs, lims=[[lo, hi] for lo, hi in zip(pl, pu)],
                          gt=pars, upper=True)

# Tick positions per parameter: x-axis of diagonal panel i and y-axis of the
# rightmost panel in row i.
ticks = [[5, 10, 15], [0.5, 1, 1.5], [0, 5, 10], [1, 2, 3]]

for i in range(n_dims):
    # diagonal panel i: overlay the prior marginal (mean +/- 3 std) in grey
    plt.subplot(n_dims, n_dims, i * (n_dims + 1) + 1)
    xgrid = np.linspace(prior.mean[i] - 3 * prior.std[i],
                        prior.mean[i] + 3 * prior.std[i], 200)
    plt.plot(xgrid, prior.eval(xgrid, ii=i, log=False),
             color=[0.4, 0.4, 0.4], linewidth=2)
    plt.xlabel(labels[i], fontsize=20)
    plt.xticks(ticks[i])

for i in range(n_dims):
    plt.subplot(n_dims, n_dims, (i + 1) * n_dims)
    plt.yticks(ticks[i])

# fix the axis range of the bottom-right (log c) diagonal panel
plt.subplot(n_dims, n_dims, n_dims * n_dims)
plt.axis([pl[-1], pu[-1], 0, 1.5])

fig.set_size_inches(12, 12)