In [None]:
%matplotlib inline

import os
import shutil
import subprocess

import numpy as np
import scipy.stats as stats

from scipy.special import gammaln
from matplotlib import pyplot as plt

import delfi.generator as dg
import delfi.distribution as dd
from delfi.utils.viz import plot_pdf
import delfi.inference as infer
from delfi.summarystats.Identity import Identity

from lfimodels.rockpaperscissors.rps_sde import rps_sde

# Base RNG seed; the objects created further down derive their own seeds
# from it by fixed offsets (seed + 1, seed + 2, seed + 5, ...).
seed = 42

# Simulator settings, handed to rps_sde(...) when the model is built.
L = 100           # image rows/columns
dt = 1.0          # simulator time step (rps_sde dt)
duration = 100.0  # simulated duration (rps_sde duration)

In [None]:
# inference setup
# Keyword options for infer.SNPEC(...) (the inference object is built from
# these in a later cell). Each key must appear only once: a repeated key in a
# dict literal silently overrides the earlier entry.
setup_opts = {
    'n_components': 1,
    'n_bypass': 0,
    'filter_sizes': [3,3,3,3,2,2],
    'n_filters': (16,16,16,32,32,32),
    'pool_sizes': [1,3,2,2,2,1],
    'n_hiddens': [50, 50],
    'reg_lambda': 0.01,
    'pilot_samples': 1000,
    'verbose': True,
    'prior_norm': False,
    'init_norm': False,
    'svi': False,
    'seed': seed + 5,
    # one channel-first 3 x L x L image, the same layout showsim reshapes to
    'input_shape': (3,L,L),
}

# Keyword options for inf.run(...).
run_opts = {
    'n_train': 1000,
    'n_rounds': 2,
    'minibatch': 100,
    'epochs': 2000,
    'moo': 'resample',
    'proposal': 'gaussian',
    'n_null': None,
    'train_on_all': True,
    'max_norm': 0.1,
    'val_frac': 0.1,
    'silent_fail': False,
    'reuse_prior_samples': False,
}

In [None]:
# define a function for showing simulation results as images
def showsim(s, side=None, **kwargs):
    """Display a flattened 3-channel simulation result as an image.

    Parameters
    ----------
    s : array-like
        Simulation output or summary statistics holding ``3 * side * side``
        values laid out channel-first, i.e. reshapeable to (3, side, side).
    side : int, optional
        Number of image rows/columns. If None (default) it is inferred from
        ``s.size``, so the function does not depend on a global image size.
    **kwargs
        Forwarded to ``plt.imshow``. ``interpolation`` defaults to 'None' so
        pixels are drawn without interpolation.

    Returns
    -------
    matplotlib.image.AxesImage
        The image artist returned by ``plt.imshow``.

    Raises
    ------
    ValueError
        If ``s`` does not contain exactly ``3 * side * side`` values.
    """
    s = np.asarray(s)
    if side is None:
        side = int(round(np.sqrt(s.size / 3)))
    if 3 * side * side != s.size:
        raise ValueError(
            'showsim expects 3 * side * side values, got {0}'.format(s.size))
    kwargs.setdefault('interpolation', 'None')
    # imshow wants channel-last (rows, cols, 3); the data is channel-first.
    return plt.imshow(np.moveaxis(s.reshape(3, side, side), 0, -1), **kwargs)

In [None]:
# Uniform prior box over the three parameters (same order as pars_true and
# labels_params further down).
prior_lower = np.array([-1, -1, -6], dtype=float)
prior_upper = np.array([1, 1, -5], dtype=float)
p = dd.Uniform(lower=prior_lower, upper=prior_upper, seed=seed)

# Simulator, configured from the constants in the first cell.
m = rps_sde(dt=dt, duration=duration, L=L, seed=seed+1)

# Generator tying prior, simulator and an Identity() summary together.
g = dg.Default(model=m, prior=p, summary=Identity(), seed=seed+2)

In [None]:
# Ground-truth parameters (inside the prior bounds) used to simulate the
# "observed" data; obs_stats is what the inference object is conditioned on.
pars_true = np.array([-0.5, 0.5, -5.25])
obs = g.model.gen_single(pars_true)
obs_stats = g.summary.calc([obs])

# Title the figure and finish with plt.show() so the cell does not echo the
# AxesImage repr.
showsim(obs_stats)
plt.title('Observed data')
plt.show()

In [None]:
# Build the SNPE-C inference object for the observed statistics; every other
# setting comes from setup_opts.
inf = infer.SNPEC(obs=obs_stats, generator=g, **setup_opts)

In [None]:
# print the network structure. "None" indicates the batch dimension
# (layers whose name starts with 'mixture' are skipped)
for layer_name in inf.network.layer:
    if not layer_name.startswith('mixture'):
        print('{0}: {1}'.format(layer_name, inf.network.layer[layer_name].output_shape))

In [None]:
# Run inference with run_opts (run_opts['n_rounds'] rounds requested). The
# result is unpacked into: log (indexed per round below, log[r]['loss']),
# trn_data (not used in the cells shown here) and posteriors (posteriors[-1]
# is analysed below; presumably one entry per round — confirm).
log, trn_data, posteriors = inf.run(**run_opts)

In [None]:
# Training loss of every round actually present in `log`, shown on a linear
# and on a logarithmic x-axis. Iterating over `log` itself (rather than
# run_opts['n_rounds']) avoids an IndexError if fewer rounds were logged.
# x values are indices into log[r]['loss'] (presumably one entry per
# training update — TODO confirm).
for r, round_log in enumerate(log):
    fig_loss, (ax_lin, ax_log) = plt.subplots(1, 2, figsize=(16, 3))
    ax_lin.plot(round_log['loss'])
    ax_log.semilogx(round_log['loss'])
    for ax in (ax_lin, ax_log):
        ax.set_xlabel('iteration')
        ax.set_ylabel('loss')
    fig_loss.suptitle('Round {0}'.format(r + 1))
    plt.show()

In [None]:
# Mathtext axis labels, one per parameter in the same order as pars_true and
# the prior bounds. Raw strings keep the backslashes literal instead of relying
# on Python's deprecated treatment of unknown escapes such as "\l", "\m", "\s".
labels_params = [r'$\log_{10}(\mu)$', r'$\log_{10}(\sigma)$', r'$\log_{10}(D)$']
# all pairwise marginals of the fitted posterior (last entry of `posteriors`),
# with the true parameters passed as gt.
# NOTE(review): the log10(D) axis spans [-7, -5] while the prior bounds for
# that parameter are [-6, -5]; confirm the wider plotting range is intended.
fig_posterior, _ = plot_pdf(
    posteriors[-1],
    lims=[[-1, 1], [-1, 1], [-7, -5]],
    gt=pars_true.reshape(-1),
    labels_params=labels_params,
    ticks=True,
    figsize=(8, 8),
    resolution=100,
)

In [None]:
n_samples = 5

# Prior-predictive samples first; the order of the random draws below is kept
# as-is. Clearing g.proposal presumably makes g.gen draw parameters from the
# prior again (inf.run may have set a proposal) — confirm. The second element
# of g.gen's result is the simulated data shown in the next cell.
g.proposal = None
x_prior = g.gen(n_samples)[1]

# Posterior-predictive samples: parameters from the final posterior, simulated
# directly with the model; the 'data' of each first repetition is stacked.
theta_posterior = posteriors[-1].gen(n_samples)
sims_posterior = g.model.gen(theta_posterior)
x_posterior = np.stack([reps[0]['data'] for reps in sims_posterior], axis=0)

In [None]:
fig_samples = plt.figure(figsize=(8, 8))

# Top row: the observation, in the first column only.
plt.subplot(3, n_samples, 1)
showsim(obs_stats)
plt.ylabel('Observed data')

# Rows 2 and 3: posterior- and prior-predictive simulations, one sample per
# column; only the first column carries the row label.
sample_rows = [(1, x_posterior, 'Posterior samples'),
               (2, x_prior, 'Prior samples')]
for col in range(n_samples):
    for row, sims, row_label in sample_rows:
        plt.subplot(3, n_samples, row * n_samples + col + 1)
        showsim(sims[col])
        if col == 0:
            plt.ylabel(row_label)
    

In [None]:
# Save both figures and merge them into rps.pdf with pdftk. The per-figure
# PDFs are removed only after pdftk succeeded, so a missing or failing pdftk
# never deletes the only copies of the figures.
part_pdfs = ['rps_posterior.pdf', 'rps_samples.pdf']
merged_pdf = 'rps.pdf'

fig_posterior.savefig(part_pdfs[0])
fig_samples.savefig(part_pdfs[1])

if shutil.which('pdftk') is None:
    print('pdftk not found; keeping {0} unmerged'.format(', '.join(part_pdfs)))
else:
    # check=True raises CalledProcessError on failure, skipping the cleanup.
    subprocess.run(['pdftk'] + part_pdfs + ['cat', 'output', merged_pdf], check=True)
    for part in part_pdfs:
        os.remove(part)