In [None]:
# numpy is needed in this first cell already (np.random below, and the
# np.array data literals that follow); the main imports cell runs later, so
# without this import a fresh-kernel "Run All" fails with NameError.
import numpy as np

seed = 57
use_maf = False

# simulation / density-estimator setup, passed as **setup_opts to infer.SNPEC
setup_opts = {
    'n_hiddens': [50, 50],
    'reg_lambda': 0.01,
    'pilot_samples': 1000,
    'verbose': True,
    'prior_norm': False,
    'svi': False,
    'n_rnn': 100,
    'n_inputs_rnn': 2
}

# training options, passed as **run_opts to res.run
run_opts = {
    'n_train': 1000,
    'n_rounds': 15,
    'minibatch': 100,
    'epochs': 2000,
    'moo': 'resample',
    'proposal': 'gaussian',
    'n_null': None,
    'train_on_all': True,
    'max_norm': 0.1,
    'val_frac': 0.1,
}

if run_opts['train_on_all']:
    # per-round epoch budget: round r+1 gets epochs // (r+1)
    run_opts['epochs'] = [run_opts['epochs'] // (r + 1) for r in range(run_opts['n_rounds'])]

if use_maf:
    # MAF randomness is drawn from this object; init_all() reseeds it.
    # NOTE: np.random is numpy's *global* RNG module, so reseeding it also
    # affects any other code using np.random.
    rng_maf = np.random

    setup_opts.update({
        'mode': 'random',
        'n_mades': 5,
        'rng': rng_maf,
        'act_fun': 'tanh',
        'batch_norm': False
    })

# Ground-truth parameters (from the SNL paper). Numerically these equal
# log([0.01, 0.5, 1.0, 0.01]), so they appear to be log-parameters
# (consistent with the Uniform(-5, 2) prior used later) -- confirm against the simulator.
pars_true = np.array([-4.60517019, -0.69314718,  0.        , -4.60517019])  # from SNL paper
# x0 (stored as raw_data below) was generated by simulation using the true parameters from the SNL paper, initial state removed, dt=0.2
# raw_data holds 300 values. init_all() observes raw_data[:2 * n_timepts_used],
# i.e. two values per time point (presumably the two interleaved population
# counts; which species comes first is not shown here -- TODO confirm).
raw_data = np.array([ 59., 104.,  66., 117.,  81., 133.,  97., 131., 109., 135., 127.,
       132., 148., 135., 163., 125., 188., 108., 217.,  95., 239.,  75.,
       236.,  60., 235.,  46., 228.,  34., 217.,  28., 211.,  18., 193.,
        12., 175.,   8., 171.,   9., 162.,   5., 151.,   5., 136.,   6.,
       128.,   6., 111.,   8., 110.,   8.,  96.,   8.,  95.,   9.,  93.,
         9.,  81.,  10.,  80.,  11.,  78.,  17.,  74.,  17.,  72.,  19.,
        66.,  20.,  67.,  25.,  64.,  29.,  54.,  30.,  55.,  45.,  54.,
        45.,  53.,  48.,  48.,  50.,  54.,  65.,  61.,  75.,  64.,  77.,
        65.,  95.,  73., 106.,  85., 107.,  88., 114., 104., 113., 130.,
       114., 159., 110., 173.,  95., 196.,  77., 220.,  72., 219.,  49.,
       212.,  33., 209.,  24., 196.,  18., 188.,  18., 176.,  17., 166.,
        15., 158.,  14., 154.,  10., 138.,  13., 128.,  11., 119.,  11.,
       107.,   9.,  99.,   8.,  94.,  11.,  86.,  11.,  81.,  12.,  78.,
        15.,  67.,  19.,  63.,  16.,  58.,  20.,  58.,  28.,  57.,  36.,
        53.,  39.,  55.,  44.,  53.,  44.,  55.,  44.,  54.,  49.,  54.,
        66.,  54.,  77.,  53.,  93.,  64.,  98.,  76., 100.,  94., 110.,
       116., 112., 136., 102., 149., 105., 173., 103., 200.,  95., 208.,
        69., 224.,  54., 232.,  47., 217.,  40., 206.,  29., 202.,  24.,
       196.,  24., 181.,  23., 167.,  13., 158.,  14., 147.,   9., 135.,
         8., 121.,   9., 111.,   8.,  95.,  10.,  84.,  10.,  73.,  11.,
        68.,  14.,  61.,  16.,  57.,  20.,  50.,  19.,  46.,  25.,  43.,
        33.,  41.,  38.,  42.,  40.,  45.,  50.,  48.,  57.,  48.,  60.,
        47.,  63.,  49.,  68.,  48.,  66.,  42.,  71.,  43.,  84.,  46.,
        81.,  40.,  89.,  45., 121.,  55., 134.,  67., 136.,  83., 135.,
       100., 142., 129., 135., 148., 134., 163., 130., 182., 111., 195.,
       106., 206.,  93., 216.,  67., 224.,  47., 207.,  33., 202.,  24.,
       193.,  20., 182.,  14., 174.,   8., 161.,   8., 152.,   4., 140.,
         6., 131.,   5.])

In [None]:
%matplotlib inline
import timeit
from copy import deepcopy

import numpy as np
from matplotlib import pyplot as plt

from delfi.utils.viz import plot_pdf
import delfi.inference as infer
import delfi.distribution as dd
import delfi.generator
from delfi.summarystats import Identity

from lfimodels.snl_exps.util import save_results, load_results, StubbornGenerator, stubborn_defaultrej, load_results_byname
from lfimodels.snl_exps.util import save_results_byname
from lfimodels.snl_exps.util import init_g_lv as init_g
from lfimodels.snl_exps.util import load_setup_lv as load_setup
from lfimodels.snl_exps.util import load_gt_lv as load_gt
from lfimodels.snl_exps.util import calc_all_lprob_errs
from lfimodels.snl_exps.LotkaVolterra import LotkaVolterra, LotkaVolterraStats
import snl.simulators.lotka_volterra as sim_lv

# simulator time step; the comment on raw_data says it was generated with dt=0.2
dt = 0.2

# where results of this notebook are written
model_id = 'lv'
save_path = ''.join(('results/', model_id))

print('pars_true : ', pars_true)

In [None]:
# run the RNN with no rejection based on summary stats, as we don't have any
exp_id = 'nostats_rnn_nostatrej_seed' + str(seed)

def init_all(seed, n_timepts_used=raw_data.shape[0] // 2):
    """Build a Lotka-Volterra generator and an SNPE-C inference object.

    The observation is the first ``2 * n_timepts_used`` entries of ``raw_data``
    and the simulator is configured with ``duration = dt * n_timepts_used``
    (presumably so simulated and observed traces have the same length).

    Parameters
    ----------
    seed : int
        Seed given to the simulator, prior, summary, generator and SNPEC
        (and used to reseed ``setup_opts['rng']`` when ``use_maf`` is set).
    n_timepts_used : int
        Number of observed time points to condition on; defaults to all
        of ``raw_data``.

    Returns
    -------
    g : StubbornGenerator
    res : infer.SNPEC

    Raises
    ------
    ValueError
        If ``n_timepts_used`` is not in ``[1, raw_data.shape[0] // 2]``.
    """
    max_timepts = raw_data.shape[0] // 2
    # also reject n_timepts_used <= 0, which would give an empty observation
    if not 1 <= n_timepts_used <= max_timepts:
        raise ValueError('n_timepts_used must be in [1, {0}], got {1}'.format(
            max_timepts, n_timepts_used))

    model = LotkaVolterra(dt=dt, duration=dt * n_timepts_used, seed=seed)
    prior = dd.Uniform(lower=[-5, -5, -5, -5], upper=[2, 2, 2, 2], seed=seed)
    summary = Identity(seed=seed)
    g = StubbornGenerator(model=model, prior=prior, summary=summary, seed=seed)

    if use_maf:
        # reseed the MAF rng stored in setup_opts (np.random when use_maf is set)
        setup_opts['rng'].seed(seed)
    # initialize inference object
    res = infer.SNPEC(g, obs=raw_data[:n_timepts_used * 2], seed=seed, **setup_opts)
    return g, res

In [None]:
# Fit SNPE-C once per observed data length (number of time points).
data_lengths = [5, 10, 25, 60, 100, 150]
L, TD, P = [], [], []

for i, dl in enumerate(data_lengths):
    # a distinct seed per data length
    g, res = init_all(seed=seed + i, n_timepts_used=dl)
    # train; timeit.default_timer is the public timer (timeit.time is an internal detail)
    t_start = timeit.default_timer()
    logs, tds, posteriors = res.run(**run_opts, verbose=False)
    L.append(logs)
    TD.append(tds)
    P.append(posteriors)
    print('fitting time for data length {0}: '.format(dl), timeit.default_timer() - t_start)

# keep the no-rejection results under their own names (was `TP`, an undefined name)
L_rnn_norej, TD_rnn_norej, P_rnn_norej = L, TD, P

In [None]:
# Persist logs, training data and posteriors for all data lengths, together
# with the options used, under exp_id in save_path.
save_results_byname(logs=L, tds=TD, posteriors=P,
                    setup_opts=setup_opts, run_opts=run_opts,
                    exp_id=exp_id, path=save_path)

In [None]:
# run the RNN on the raw data, but while rejecting any theta x pair for which the LV summary stats are undefined
# Simulate as many time points as raw_data holds (two values per time point),
# the same duration rule init_all() uses, instead of a hard-coded 30.0.
model = LotkaVolterra(dt=dt, duration=dt * (raw_data.shape[0] // 2), seed=seed)
prior = dd.Uniform(lower=[-5, -5, -5, -5], upper=[2, 2, 2, 2], seed=seed)
summary = Identity(seed=seed)
g = StubbornGenerator(model=model, prior=prior, summary=summary, seed=seed)

In [103]:
LVstatsobj = LotkaVolterraStats()


def LVstatsok(raw_data):
    """Feedback hook: reject a sample whose LV summary statistics contain NaN.

    Installed as ``g._feedback_summary_stats``. Note the argument shadows the
    module-level ``raw_data``; here it is whatever the generator passes in.

    Returns 'discard' if any computed statistic is NaN, otherwise 'accept'.
    """
    stats_undefined = np.isnan(LVstatsobj._calc_one_datapoint(raw_data)).any()
    return 'discard' if stats_undefined else 'accept'


g._feedback_summary_stats = LVstatsok

# NOTE(review): the recorded run of this cell ended with
# "ValueError: response not supported" -- confirm which feedback responses
# StubbornGenerator accepts (and what shape it passes to the hook).
g.gen(1)

HBox(children=(IntProgress(value=0, max=1), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

[[52.  0. 49.  0. 45.  0. 40.  0. 37.  0. 36.  0. 33.  0. 33.  0. 29.  0.
  26.  0. 24.  0. 24.  0. 22.  0. 19.  0. 19.  0. 18.  0. 17.  0. 14.  0.
  12.  0. 11.  0. 11.  0. 10.  0.  9.  0.  9.  0.  7.  0.  7.  0.  7.  0.
   7.  0.  6.  0.  6.  0.  5.  0.  4.  0.  4.  0.  2.  0.  1.  0.  1.  0.
   1.  0.  1.  0.  1.  0.  1.  0.  1.  0.  1.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.  0.  0.  0. 

ValueError: response not supported

In [None]:
# show the effect of summary stat-based rejection on the prior
model = LotkaVolterra(dt=dt, duration=dt * n_timepts_used, seed=seed)
prior = dd.Uniform(lower= [-5,-5,-5,-5], upper = [2,2,2,2], seed=seed)
summary = Identity(seed=seed)
g = StubbornGenerator(model=model, prior=prior, summary=summary, seed=seed)

In [None]:
# -log posterior density of the true parameters after each round,
# one line per observed data length T = dl * dt.
fig, ax = plt.subplots()
h = []
for i, dl in enumerate(data_lengths):
    rounds = np.arange(len(L[i])) + 1
    neg_logp = [-pp.eval(pars_true.reshape(1, -1), log=True) for pp in P[i]]
    h.append(ax.plot(rounds, neg_logp, label='T = {0}'.format(dl * dt))[0])
# handles must be passed by keyword: a single positional argument to legend()
# is interpreted as the list of *labels*, which would print the Line2D reprs.
ax.legend(handles=h)
ax.set_xlabel('round')
ax.set_ylabel('-log p of true params');

In [None]:
# Posterior marginals after each round for the longest data length, then the
# -log density of the true parameters across rounds.
posteriors_C = P[-1]
logs = L[-1]
for r in range(len(logs)):
    # index this data length's posteriors explicitly instead of reading the
    # leftover `posteriors` variable from the training loop (hidden state)
    posterior_C = posteriors_C[r]

    # NOTE: uses the module-level `g` as a template; which generator that is
    # depends on the cell that last assigned `g`.
    g2 = deepcopy(g)
    g2.proposal = posterior_C
    samples = np.array(g2.draw_params(5000))

    # NOTE(review): this near-point-mass Gaussian looks like a placeholder pdf
    # so that plot_pdf mainly shows the histograms of `samples` -- confirm.
    placeholder_pdf = dd.Gaussian(m=0.00000123 * np.ones(pars_true.size),
                                  S=1e-30 * np.eye(pars_true.size))
    fig, _ = plot_pdf(placeholder_pdf,
                      samples=samples.T,
                      gt=pars_true,
                      lims=[[-5, 2], [-5, 2], [-5, 2], [-5, 2]],
                      resolution=100,
                      ticks=True,
                      figsize=(16, 16))

    fig.suptitle('SNPE-C posterior estimates, round r = ' + str(r + 1), fontsize=14)

plt.figure()
plt.plot(np.arange(len(logs)) + 1,
         [-post.eval(pars_true.reshape(1, -1), log=True) for post in posteriors_C])
plt.ylabel('-log p of true params')
plt.xlabel('round');