In [108]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy import stats as spstats
import pickle
from simulate_diffusion import simulate_diffusion2c_p

In [166]:
from delfi.simulator.BaseSimulator import BaseSimulator
import delfi.distribution as dd
from delfi.summarystats.BaseSummaryStats import BaseSummaryStats
import delfi.generator as dg
import delfi.inference as infer

# Define prior

In [162]:
# Box-uniform prior over the 6 simulator parameters: one lower and one upper
# bound per dimension, in the same order as the parameter vector.
seed_p = 2  # seed handed to the prior distribution
prior_min = np.array([0.0, -6.0, 0.3, 0.6, 0.3, 1.0])  # per-parameter lower bounds
prior_max = np.array([6.0, 0.0, 0.7, 3.0, 0.7, 2.0])  # per-parameter upper bounds
prior = dd.Uniform(lower=prior_min, upper=prior_max,seed=seed_p)

# Define simulator

In [163]:
class LevyFlight(BaseSimulator):
    """Levy Flight Simulator.

    Thin adapter that lets a plain function act as a delfi simulator.

    Parameters
    ----------
    sim_fun : callable, optional
        Forward model, invoked as ``sim_fun(params, n_points)``.
    dim_param : int
        Number of model parameters; forwarded to ``BaseSimulator``.
    n_points : int
        Second positional argument handed to ``sim_fun`` on every call.
    """

    def __init__(self, sim_fun=None, dim_param=6, n_points=500):
        super().__init__(dim_param=dim_param)
        self.n_points = n_points
        self.sim_fun = sim_fun

    def gen_single(self, params):
        """Run the forward model once for a single parameter vector.

        Parameters
        ----------
        params : list or np.array, 1d of length dim_param
            Parameter vector

        Returns
        -------
        dict
            Dictionary whose ``'data'`` entry holds whatever ``sim_fun``
            returned for this parameter vector.
        """
        theta = np.asarray(params)
        # Only a single, flat parameter vector is accepted here.
        assert theta.ndim == 1, 'params.ndim must be 1'
        return {'data': self.sim_fun(theta, self.n_points)}

# Summary stats - MMD

In [174]:
class LevyFlightStats(BaseSummaryStats):
    """Moment-based summary statistics for the Levy flight simulator output.

    Each repetition's ``'data'`` entry is indexed as a 2-d array whose first
    two columns hold signed values. Every column is split by sign into two
    groups (the positive entries, and the negated negative entries; exact
    zeros belong to neither group), giving four groups in the order
    ``[col0 > 0, col0 < 0, col1 > 0, col1 < 0]``.

    The 28 summaries per repetition are laid out as:

    * 4 group means (``-1`` marks an empty group),
    * 4 x 5 central moments of order 2..6 (all zeros for an empty group),
    * 4 group proportions in the order
      ``[col0 > 0, col1 < 0, col0 < 0, col1 > 0]``, which the code labels
      accuracy 1, accuracy 2, error rate 1 and error rate 2.

    Parameters
    ----------
    n_trials : int or None, optional
        Denominator used to turn group counts into proportions. Defaults to
        500, the value that was previously hard-coded. Pass ``None`` to use
        the number of rows of each repetition's data array instead.
    """

    # Orders of the central moments computed for every group.
    _MOMENT_ORDERS = (2, 3, 4, 5, 6)

    def __init__(self, n_trials=500):
        """See SummaryStats.py for docstring"""

        super(LevyFlightStats, self).__init__()
        self.n_trials = n_trials
        # 4 means + 4 groups * 5 moments + 4 proportions
        self.n_summary = 28

    def calc(self, repetition_list):
        """Calculate summary statistics

        Parameters
        ----------
        repetition_list : list of dictionaries, one per repetition
            data list, returned by `gen` method of Simulator instance

        Returns
        -------
        np.array, 2d with n_reps x n_summary
        """

        stats = np.zeros((len(repetition_list), self.n_summary))
        for row, rep in enumerate(repetition_list):
            data = rep['data']
            col0, col1 = data[:, 0], data[:, 1]

            # Split each column by sign; negative entries are flipped so that
            # every group holds positive values.
            rt00 = col0[col0 > 0]
            rt01 = -col0[col0 < 0]
            rt10 = col1[col1 > 0]
            rt11 = -col1[col1 < 0]
            groups = [rt00, rt01, rt10, rt11]

            # Compute all four means and indicate empty slice with -1
            means = np.array([np.mean(rt) if rt.size != 0 else -1. for rt in groups])

            # Second to sixth central moments per group; an empty group
            # contributes zeros so the summary vector keeps a fixed length.
            moments = [
                np.array([spstats.moment(rt, moment=order) for order in self._MOMENT_ORDERS])
                if rt.size != 0 else np.zeros(len(self._MOMENT_ORDERS))
                for rt in groups
            ]

            # Compute accuracies and error rates for each condition
            n_trials = data.shape[0] if self.n_trials is None else self.n_trials
            acc_1 = rt00.shape[0] / n_trials
            acc_2 = rt11.shape[0] / n_trials
            err_1 = rt01.shape[0] / n_trials
            err_2 = rt10.shape[0] / n_trials

            # Stack all summaries
            stats[row, :] = np.concatenate(
                [means, np.concatenate(moments), [acc_1, acc_2, err_1, err_2]]
            )
        return stats

# Create objects

In [187]:
# Wire the simulator, the prior and the summary statistics into one delfi
# generator; `g` is what the inference objects below receive.
m = LevyFlight(sim_fun=simulate_diffusion2c_p)  # uses the defaults dim_param=6, n_points=500
s = LevyFlightStats()
g = dg.Default(model=m, prior=prior, summary=s)

# Load data

In [176]:
# true parameters and respective labels
# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
# Read-only mode is enough here, and the context manager closes the handle
# even if unpickling fails.
with open('levy_true.pkl', 'rb') as fh:
    data = pickle.load(fh)
theta_true = data['theta']
X_true = data['X']

In [177]:
# Summary statistics of the observed data set, wrapped as a single repetition;
# passed as `obs` to the inference objects below.
obs_stats = s.calc([{'data': X_true}])

In [186]:
# Hyperparameters for the SNPE-C run in the next cell.
seed_inf = 1  # seed passed to the inference object

pilot_samples = 5000  # forwarded as `pilot_samples` to the inference constructor

# training schedule
n_train = 5000
n_rounds = 1

# fitting setup
minibatch = 100
epochs = 100
val_frac = 0.05

# network setup
n_hiddens = [128, 128, 128]

# convenience
prior_norm = True

# MAF parameters
density = 'maf'
n_mades = 5 

# Train SNPEC

In [188]:
%%time
# inference object
# SNPE-C conditioned on obs_stats, configured with density='maf' and n_mades
# from the configuration cell.
res = infer.SNPEC(g,
                obs=obs_stats,
                n_hiddens=n_hiddens,
                seed=seed_inf,
                pilot_samples=pilot_samples,
                n_mades=n_mades,
                prior_norm=prior_norm,
                density=density)

# train
# res.run returns three values; the middle one is discarded here.
# posterior_c is indexed with [0] by the following cell (presumably one entry
# per round -- n_rounds is 1 here).
log, _, posterior_c = res.run(
                    n_train=n_train,
                    n_rounds=n_rounds,
                    minibatch=minibatch,
                    epochs=epochs,
                    silent_fail=False,
                    proposal='prior',
                    val_frac=val_frac,
                    verbose=True)

HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=500000.0), HTML(value='')))


CPU times: user 1h 14min 13s, sys: 1h 12min 23s, total: 2h 26min 37s
Wall time: 20min 48s


In [189]:
# Draw 2000 samples from the first entry of posterior_c (the SNPE-C result).
posterior_samples_c = posterior_c[0].gen(2000)

In [193]:
# Persist the SNPE-C posterior samples. The context manager guarantees the
# file is flushed and closed once the dump finishes.
with open('snpec_posteriors.pkl', 'wb') as fh:
    pickle.dump(posterior_samples_c, fh)

# Train SNPEA

In [201]:
# Hyperparameters for the SNPE-A run in the next cell; the values repeat those
# of the SNPE-C configuration cell.
# NOTE(review): val_frac is set here but is not passed to res.run in the
# SNPE-A training cell -- confirm whether that is intentional.
seed_inf = 1  # seed passed to the inference object

pilot_samples = 5000  # forwarded as `pilot_samples` to the inference constructor

# training schedule
n_train = 5000
n_rounds = 1

# fitting setup
minibatch = 100
epochs = 100
val_frac = 0.05

# network setup
n_hiddens = [128, 128, 128]

# convenience
prior_norm = True

In [202]:
%%time
# inference object
# SNPE-A conditioned on obs_stats; unlike the SNPE-C cell, no density or
# n_mades arguments are passed.
res = infer.SNPEA(g,
                obs=obs_stats,
                n_hiddens=n_hiddens,
                seed=seed_inf,
                pilot_samples=pilot_samples,
                prior_norm=prior_norm)

# train
# res.run returns three values; the middle one is discarded here.
# Only n_train, n_rounds, minibatch and epochs are forwarded (no val_frac).
log, _, posterior_a = res.run(
                    n_train=n_train,
                    n_rounds=n_rounds,
                    minibatch=minibatch,
                    epochs=epochs)


HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=500000.0), HTML(value='')))


CPU times: user 2h 22min, sys: 2h 26min 20s, total: 4h 48min 21s
Wall time: 39min 31s


In [203]:
# Draw 2000 samples from the first entry of posterior_a (the SNPE-A result).
posterior_samples_a = posterior_a[0].gen(2000)

In [204]:
# Per-parameter mean of the SNPE-A posterior samples (averaged over the sample axis).
posterior_samples_a.mean(axis=0)

array([ 1.01568621, -2.14283793,  0.6138257 ,  1.45096817,  0.59045946,
        1.14927037])

In [205]:
# Persist the SNPE-A posterior samples. The context manager guarantees the
# file is flushed and closed once the dump finishes.
with open('snpea_posteriors.pkl', 'wb') as fh:
    pickle.dump(posterior_samples_a, fh)

# Train SNPEB

In [209]:
# Hyperparameters for the SNPE-B run in the next cell; the values repeat those
# of the SNPE-C and SNPE-A configuration cells.
# NOTE(review): val_frac is set here but is not passed to res.run in the
# SNPE-B training cell -- confirm whether that is intentional.
seed_inf = 1  # seed passed to the inference object

pilot_samples = 5000  # forwarded as `pilot_samples` to the inference constructor

# training schedule
n_train = 5000
n_rounds = 1

# fitting setup
minibatch = 100
epochs = 100
val_frac = 0.05

# network setup
n_hiddens = [128, 128, 128]

# convenience
prior_norm = True

In [210]:
%%time
# inference object
# SNPE-B conditioned on obs_stats; prior_mixin is fixed to 0.1 inline rather
# than taken from the configuration cell.
res = infer.SNPEB(g,
                obs=obs_stats,
                n_hiddens=n_hiddens,
                seed=seed_inf,
                pilot_samples=pilot_samples,
                prior_mixin=0.1,
                prior_norm=prior_norm)

# train
# res.run returns three values; the middle one is discarded here.
# Only n_train, n_rounds, minibatch and epochs are forwarded (no val_frac).
log, _, posterior_b = res.run(
                    n_train=n_train,
                    n_rounds=n_rounds,
                    minibatch=minibatch,
                    epochs=epochs)


HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=500000.0), HTML(value='')))


CPU times: user 2h 22min 45s, sys: 2h 27min 34s, total: 4h 50min 20s
Wall time: 39min 44s


In [211]:
# Draw 2000 samples from the first entry of posterior_b (the SNPE-B result).
posterior_samples_b = posterior_b[0].gen(2000)

In [212]:
# Persist the SNPE-B posterior samples. The context manager guarantees the
# file is flushed and closed once the dump finishes.
with open('snpeb_posteriors.pkl', 'wb') as fh:
    pickle.dump(posterior_samples_b, fh)