In [1]:
%load_ext autoreload
%autoreload 2

"""SNPE: RNN stable amplification. """

from neural_circuits.LRRNN import get_W_eigs_np
import numpy as np
import os
import pickle
#import matplotlib.pyplot as plt

import delfi
from delfi.simulator.BaseSimulator import BaseSimulator
import delfi.distribution as dd
from delfi.summarystats.BaseSummaryStats import BaseSummaryStats
from scipy import stats as spstats
import delfi.generator as dg
import delfi.inference as infer

DTYPE = np.float32

N = 2
n_train = 1000
n_mades = 1
n_atoms = 25
g = 0.
K = 1
rs = 1



In [2]:
print('Running SNPE on RNN conditioned on stable amplification with:')
print('N = %d, n_train = %d, n_mades = %d, n_atoms = %d, seed=%d' \
      % (N, n_train, n_mades, n_atoms, rs))

base_path = os.path.join("data", "snpe")
save_dir = "SNPE_RNN_stab_amp_N=%d_ntrain=%dk_nmades=%d_natoms=%d_rs=%d" \
        % (N, n_train//1000, n_mades, n_atoms, rs)

save_path = os.path.join(base_path, save_dir)
if not os.path.exists(save_path):
    os.makedirs(save_path)

if os.path.exists(os.path.join(base_path, save_dir, "optim.pkl")):
    print("SNPE optimization already run. Exitting.")
    exit()

_W_eigs = get_W_eigs_np(g, K)

def W_eigs(params, seed=None):
    """Calculates Jeigs.

        Parameters
        ----------
        params : np.array, 1d of length dim_param
            Parameter vector
        seed : int
        """

    if seed is not None:
        rng = np.random.RandomState(seed=seed)
    else:
        rng = np.random.RandomState()

    params = params

    U = np.reshape(params[0,:(2*N)], (N,2))
    V = np.reshape(params[0,(2*N):], (N,2))

    x = _W_eigs(U, V)
    return x

class RNN(BaseSimulator):
    def __init__(self, N, seed=None):
        """Hodgkin-Huxley simulator

        Parameters
        ----------
        N : int or None
            Number of neurons.
        seed : int or None
            If set, randomness across runs is disabled
        """
        self.N = N
        self.r = 2
        dim_param = self.N*self.r*2

        super().__init__(dim_param=dim_param, seed=seed)
        self.Jeigs = W_eigs

    def gen_single(self, params):
        """Forward model for simulator for single parameter set

        Parameters
        ----------
        params : list or np.array, 1d of length dim_param
            Parameter vector

        Returns
        -------
        dict : dictionary with data
            The dictionary must contain a key data that contains the results of
            the forward run. Additional entries can be present.
        """
        params = np.asarray(params)

        assert params.ndim == 1, 'params.ndim must be 1'

        Jeig_seed = self.gen_newseed()

        states = self.Jeigs(params.reshape(1, -1), seed=Jeig_seed)

        return {'data': states}

seed_p = 1
prior_min = -np.ones((4*N,))
prior_max = np.ones((4*N,))
prior = dd.Uniform(lower=prior_min, upper=prior_max,seed=seed_p)

class RNNStats(BaseSummaryStats):
    """Moment based SummaryStats class for the Hodgkin-Huxley model

    Calculates summary statistics
    """
    def __init__(self, seed=None):
        """See SummaryStats.py for docstring"""
        super(RNNStats, self).__init__(seed=seed)

    def calc(self, repetition_list):
        """Calculate summary statistics

        Parameters
        ----------
        repetition_list : list of dictionaries, one per repetition
            data list, returned by `gen` method of Simulator instance

        Returns
        -------
        np.array, 2d with n_reps x n_summary
        """
        stats = []
        if len(repetition_list) > 1:
            print(repetition_list)
            raise NotImplementedError()
        for r in range(len(repetition_list)):
            x = repetition_list[r]

            stats.append(x['data'])
        return np.asarray(stats)

seed = 0
# define model, prior, summary statistics and generator classes
m = RNN(N=N)
s = RNNStats()
g = dg.Default(model=m, prior=prior, summary=s)

n_processes = 4

seeds_m = np.arange(1,n_processes+1,1)
m = []
for i in range(n_processes):
    m.append(RNN(N=N, seed=seeds_m[i]))
g = dg.MPGenerator(models=m, prior=prior, summary=s)

# true parameters and respective labels
true_params = np.random.uniform(-1., 1., (4*N,))
labels_params = [r'$T_1$', r'$T_2$']

# observed data: simulation given true parameters
obs = m[0].gen_single(true_params)

obs_stats = np.array([0.5, 1.5])

pilot_samples = n_train

# training schedule
n_rounds = 2

# fitting setup
minibatch = 100
epochs = 100
val_frac = 0.05

# network setup
n_hiddens = [50,50]

# convenience
prior_norm = False

# MAF parameters
density = 'maf'

# inference object
res = infer.SNPEC(g,
                obs=obs_stats,
                n_hiddens=n_hiddens,
                pilot_samples=pilot_samples,
                n_mades=n_mades,
                prior_norm=prior_norm,
                density=density,
                seed=rs)

# train
logs, trn_datasets, posteriors = res.run(
                    n_train=n_train,
                    n_rounds=n_rounds,
                    n_atoms=n_atoms,
                    minibatch=minibatch,
                    epochs=epochs,
                    silent_fail=False,
                    proposal='prior',
                    val_frac=val_frac,
                    verbose=True,)

optim = {'logs':logs,
         'trn_datasets':trn_datasets,
         'times':times}
nets = {'posteriors':posteriors}

base_path = os.path.join("data", "snpe")
save_dir = "SNPE_RNN_stab_amp_N=%d_ntrain=%dk_nmades=%d_natoms=%d_rs=%d" \
        % (N, n_train//1000, n_mades, n_atoms, rs)

save_path = os.path.join(base_path, save_dir)
if not os.path.exists(save_path):
    os.makedirs(save_path)

print('Saving', save_path, '...')
with open(os.path.join(base_path, save_dir, "optim.pkl"), "wb") as f:
    pickle.dump(optim, f)
with open(os.path.join(base_path, save_dir, "networks.pkl"), "wb") as f:
    pickle.dump(nets, f)
print('done.')

Running SNPE on RNN conditioned on stable amplification with:
N = 2, n_train = 1000, n_mades = 1, n_atoms = 25, seed=1


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=100000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=100000), HTML(value='')))




NameError: name 'times' is not defined

In [None]:
print(posteriors)