In [None]:
%matplotlib inline

import sys
sys.path.append('/home/mackelab/Desktop/Projects/Biophysicality/code/snl_py3/') # SNL does not have a setup.py

In [None]:
import os

# Work from the SNL checkout; override its location with the SNL_PATH
# environment variable. os.chdir replaces the IPython automagic `cd`, which
# only works while automagic is enabled and is not plain Python.
# NOTE(review): presumably needed because SNL resolves some files relative to
# the working directory — TODO confirm.
os.chdir(os.environ.get('SNL_PATH', '/home/mackelab/Desktop/Projects/Biophysicality/code/snl_py3'))

In [None]:
import numpy as np
from matplotlib import pyplot as plt

import simulators.lotka_volterra as sim

import delfi.generator as dg
import delfi.distribution as dd
from delfi.utils.viz import plot_pdf
import delfi.inference as infer

seed = 42  # global seed: SNPEC uses it directly, the generator uses seed + 41

# ---- SNPE settings ----

# training schedule: simulations per round, number of rounds
n_train, n_rounds = 3000, 5

# fitting: minibatch size, training epochs
minibatch, epochs = 100, 50

# network: hidden-layer sizes and the reg_lambda regularisation parameter
n_hiddens = [50]
reg_lambda = 0.01

# convenience options for the SNPEC call
pilot_samples = 0
svi = False
verbose = True
prior_norm = False


In [None]:
from delfi.simulator import BaseSimulator

# prior: uniform box over the four log-parameters. It must be 4-dimensional to
# match Lotka_volterra(dim_param=4) and pars_true; an empty lower/upper would
# give a 0-dimensional prior that cannot supply simulator parameters.
# Bounds [-5, 2]^4 follow the SNL paper's Lotka-Volterra prior on log-parameters
# (the counterpart of SNL's sim.Prior) — TODO(review) confirm against
# snl_py3/simulators/lotka_volterra.py.
prior = dd.Uniform(lower=np.full(4, -5.0), upper=np.full(4, 2.0))

# SNL's Lotka-Volterra simulator, wrapped for delfi by Lotka_volterra
model_snl = sim.Model()
class Lotka_volterra(BaseSimulator):
    """delfi simulator that delegates to SNL's Lotka-Volterra model.

    Parameters
    ----------
    dim_param : int
        Number of parameters (4 for this model, see ``gen_single``).
    seed : int or None
        Forwarded to ``BaseSimulator`` (this class defines no ``__init__``).
    """

    def gen_single(self, params):
        """Simulate one dataset for a single parameter vector.

        params = (predator births, predator deaths,
                  prey births, predator-prey interactions)
        In this notebook they are supplied as log-values (cf. ``pars_true``).

        ``params`` is passed unchanged to the module-level SNL model
        ``model_snl`` and its output is returned as-is.

        NOTE(review): no RNG is passed to ``model_snl.sim``, so ``seed``
        presumably does not control these draws — TODO confirm.
        """
        simulate = model_snl.sim
        return simulate(params)
        
model = Lotka_volterra(dim_param=4)  # delfi-facing simulator with 4 parameters

# SNL's Lotka-Volterra summary statistics, used as the delfi summary object
summary = sim.Stats()

def rej(x):
    """Return True to keep a simulation result, False to reject it.

    The Lotka-Volterra simulator is not guaranteed to behave well, so (based on
    inspection of the simulator code) a result is rejected when it is None or
    when any of its entries is None. An empty iterable is accepted.

    x : presumably the summary statistics handed over by delfi's RejKernel —
        TODO confirm.
    """
    if x is None:
        return False
    return not any(entry is None for entry in x)

# delfi generator with a rejection hook `rej` (presumably draws it rejects are
# resampled — TODO confirm); its seed is offset from the global `seed`
g = dg.RejKernel(
    prior=prior,
    model=model,
    summary=summary,
    rej=rej,
    seed=seed + 41,
)

In [None]:
# ground-truth parameters from the SNL paper, in log space; inference should recover them
pars_true = np.log([0.01, 0.5, 1.0, 0.01])

# simulate an "observed" dataset at the true parameters and summarise it.
# NOTE(review): this is a fresh simulation, so it equals the SNL paper's xo
# only if the simulator draw matches — TODO confirm.
obs = g.model.gen_single(pars_true)
obs_stats = g.summary.calc([obs])

In [None]:
# SNPE-C run, configured from the settings cell. n_components=1 uses a single
# mixture component; verbose now honours the `verbose` config flag instead of
# a hardcoded True.
inf = infer.SNPEC(generator=g, obs=obs_stats, prior_norm=prior_norm, init_norm=False,
                  pilot_samples=pilot_samples, seed=seed, reg_lambda=reg_lambda, svi=svi,
                  n_components=1, n_hiddens=n_hiddens, verbose=verbose)

log, trn_data, posteriors = inf.run(n_train=n_train, epochs=epochs, minibatch=minibatch,
                                    n_rounds=n_rounds, moo='p_tilda')

In [None]:
# shape of the first element of each round's training data
# (presumably the parameter samples — TODO confirm)
for round_idx in range(n_rounds):
    first_entry = trn_data[round_idx][0]
    print(first_entry.shape)

In [None]:
# Posterior from the first entry of `posteriors` over the four log-parameters.
# lims must cover the ground truth: log(0.01) ≈ -4.6 lies outside [-2, 2], which
# would hide two of the four true-value markers. [-5, 2] spans the SNL paper's
# log-parameter prior box.
fig, _ = plot_pdf(posteriors[0],
                  lims=[-5, 2],
                  gt=pars_true,
                  resolution=100,
                  figsize=(16, 16))
fig.suptitle('SNPE-C posterior (posteriors[0]) with true parameters');