In [125]:
import joblib
import math
from math import ceil
import random
import peakutils
import pandas as pd
import contextlib
import torch
import sys
import sbi
import numpy as np
from numpy import fft, ndarray
from scipy.integrate import odeint
from scipy.stats import norm
import matplotlib.pyplot as plt
import os
import seaborn as sns
from pyro.infer.mcmc.api import MCMC
from warnings import warn
from torch import Tensor, split, randint, cat
from typing import Any, Callable, Optional, Tuple, Union, Dict
from joblib import Parallel, delayed
from tqdm import tqdm
from tqdm.auto import tqdm, trange
from pyro.infer.mcmc import HMC, NUTS
from sbi.inference import prepare_for_sbi, SNLE
from sbi import utils as utils
from sbi import analysis as analysis
from sbi.types import Shape, TorchTransform
from sbi.utils.get_nn_models import (likelihood_nn,)
from sbi.samplers.mcmc import SliceSamplerVectorized
from sbi.samplers.mcmc.slice_numpy import MCMCSampler
from sbi.utils import tensor2numpy

In [126]:
def seed_all_backends(seed: Optional[Union[int, Tensor]] = None) -> None:
    """Seed Python's ``random``, NumPy and PyTorch with one common value.

    Args:
        seed: Seed to apply; anything ``int()`` accepts (e.g. a one-element
            tensor). When ``None``, a seed is drawn from the current torch
            RNG with ``torch.randint(10_000_000, size=(1,))``.

    In addition to seeding, cuDNN is switched to deterministic mode and its
    benchmark mode is turned off.
    """
    # Normalise to a plain int first: `random.seed` does not accept a Tensor
    # on recent Python versions.
    if seed is not None:
        seed_value = int(seed)
    else:
        seed_value = int(torch.randint(10_000_000, size=(1,)))

    for apply_seed in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    ):
        apply_seed(seed_value)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True  # type: ignore
    cudnn.benchmark = False  # type: ignore

@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
    """Route joblib's progress reports into a tqdm bar for the ``with`` body.

    While the context is active, ``joblib.parallel.Parallel.print_progress``
    is replaced by a hook that advances ``tqdm_object`` up to the number of
    completed tasks. On exit the original method is put back and the bar is
    closed, even if the body raised.

    Yields:
        The ``tqdm_object`` that was passed in.
    """
    parallel_cls = joblib.parallel.Parallel
    saved_print_progress = parallel_cls.print_progress

    def _advance_bar(self):
        # `self` is the Parallel instance reporting progress; the bar is only
        # ever moved forward.
        pending = self.n_completed_tasks - tqdm_object.n
        if pending > 0:
            tqdm_object.update(n=pending)

    parallel_cls.print_progress = _advance_bar
    try:
        yield tqdm_object
    finally:
        parallel_cls.print_progress = saved_print_progress
        tqdm_object.close()

class SliceSampler(MCMCSampler):
    """Coordinate-wise slice sampler built on sbi's ``MCMCSampler``.

    The methods below read ``self.x``, ``self.lp_f``, ``self.n_dims`` and
    ``self.verbose`` without setting them here; presumably
    ``MCMCSampler.__init__`` provides them (verify against the installed sbi
    version).
    """
    def __init__(self, x, lp_f, max_width=float("inf"), init_width: Union[float, np.ndarray] = 0.05, thin=None, tuning: int = 50, verbose: bool = False,):
        """Store the bracket settings and delegate the rest to ``MCMCSampler``.

        Args:
            x: Initial state, forwarded to ``MCMCSampler.__init__``.
            lp_f: Log-probability function, forwarded to ``MCMCSampler.__init__``.
            max_width: Upper bound on how far a bracket end may move away
                from the current value while stepping out.
            init_width: Starting bracket width (scalar or per-dimension).
            thin: Forwarded to ``MCMCSampler.__init__``.
            tuning: Number of tuning sweeps run by ``_tune_bracket_width``.
            verbose: Forwarded to ``MCMCSampler.__init__``.
        """
        MCMCSampler.__init__(self, x, lp_f, thin, verbose=verbose)
        self.max_width = max_width
        self.init_width = init_width
        # `width` stays None until `_tune_bracket_width` has been run.
        self.width = None
        self.tuning = tuning
        
    def _tune_bracket_width(self, rng):
        """Set ``self.width`` to a running mean of the bracket widths seen.

        Runs ``self.tuning`` sweeps over all dimensions, starting from
        ``self.init_width`` in every dimension.

        Args:
            rng: Random source exposing ``shuffle`` and ``rand`` (callers in
                this file pass ``np.random``).
        """
        order = list(range(self.n_dims))
        # Local copy: the tuning sweeps do not write back into `self.x`.
        x = self.x.copy()

        self.width = np.full(self.n_dims, self.init_width)

        tbar = trange(self.tuning, miniters=2, disable=not self.verbose)
        tbar.set_description("Tuning bracket width...")
        for n in tbar:
            # for n in range(int(self.tuning)):
            # NOTE(review): `order` is shuffled but the inner loop walks
            # `range(self.n_dims)`, so the shuffle only advances `rng`.
            rng.shuffle(order)
            for i in range(self.n_dims):
                # NOTE(review): `_sample_from_conditional` builds the
                # conditional from `self.x`, not from the local copy `x`.
                x[i], wi = self._sample_from_conditional(i, x[i], rng)
                # Incremental mean; at n == 0 this overwrites the initial width.
                self.width[i] += (wi - self.width[i]) / (n + 1)

    def _sample_from_conditional(self, i: int, cxi, rng):
        """Draw a new value for coordinate ``i`` with the others held at ``self.x``.

        Args:
            i: Index of the coordinate to update.
            cxi: Current value of that coordinate.
            rng: Random source exposing ``rand``.

        Returns:
            Tuple ``(xi, ux - lx)``: the accepted value and the final bracket
            width.

        Raises:
            AssertionError: If ``self.width`` has not been set yet.
        """
        assert self.width is not None, "Chain not initialized."

        # conditional log prob
        Li = lambda t: self.lp_f(np.concatenate([self.x[:i], [t], self.x[i + 1 :]]))
        wi = self.width[i]

        # sample a slice uniformly
        logu = Li(cxi) + np.log(1.0 - rng.rand())

        # position the bracket randomly around the current sample
        lx = cxi - wi * rng.rand()
        ux = lx + wi
        
        # find lower bracket end
        while Li(lx) >= logu and cxi - lx < self.max_width:
            lx -= wi

        # find upper bracket end
        while Li(ux) >= logu and ux - cxi < self.max_width:
            ux += wi

        # sample uniformly from bracket
        xi = (ux - lx) * rng.rand() + lx

        # if outside slice, reject sample and shrink bracket
        while Li(xi) < logu:
            if xi < cxi:
                lx = xi
            else:
                ux = xi
            xi = (ux - lx) * rng.rand() + lx
       
        return xi, ux - lx
      
def run_fun(SliceSamplerSerial, num_samples, inits, seed, log_prob_fn: Callable, thin: Optional[int] = None, tuning: int = 50, verbose: bool = True, init_width: Union[float, np.ndarray] = 0.01,
            max_width: float = float("inf"), num_workers: int = 1, rng=np.random, show_info: bool = False, logger=sys.stdout) -> np.ndarray:
    """Run a single slice-sampling chain and return its samples.

    Args:
        SliceSamplerSerial: Not read by this function; accepted so the
            existing call sites keep working.
        num_samples: Number of samples to record (must be >= 0).
        inits: Initial state handed to ``SliceSampler``.
        seed: Value passed to ``np.random.seed`` before sampling.
        log_prob_fn: Log-probability function handed to ``SliceSampler``.
        thin: Forwarded to ``SliceSampler``; the sampler's resulting ``thin``
            attribute sets how many sweeps are run per recorded sample.
        tuning: Number of bracket-width tuning sweeps.
        verbose: Show progress bars; only honoured when ``num_workers == 1``.
        init_width: Initial bracket width.
        max_width: Maximum bracket extension.
        num_workers: Only used to decide whether progress bars are shown.
        rng: Random source exposing ``shuffle`` and ``rand``.
        show_info: If true, record the log-probability trace and plot it.
        logger: Unused. Kept for interface compatibility; the previous
            implementation opened ``os.devnull`` for ``logger=None`` and never
            closed it, which leaked a file handle per call.

    Returns:
        Array of shape ``(num_samples, n_dims)``.

    Raises:
        AssertionError: If ``num_samples`` is negative.
    """
    np.random.seed(seed)
    posterior_sampler = SliceSampler(
        inits,
        lp_f=log_prob_fn,
        max_width=max_width,
        init_width=init_width,
        thin=thin,
        tuning=tuning,
        # Progress bars are only shown when running in a single worker.
        verbose=num_workers == 1 and verbose,
    )

    assert num_samples >= 0, "number of samples can't be negative"

    order = list(range(posterior_sampler.n_dims))
    L_trace = []
    samples = np.empty([int(num_samples), int(posterior_sampler.n_dims)])

    if posterior_sampler.width is None:
        posterior_sampler._tune_bracket_width(rng)

    tbar = trange(int(num_samples), miniters=10, disable=not posterior_sampler.verbose)
    tbar.set_description("Generating samples")
    for n in tbar:
        # One recorded sample = `thin` full sweeps over all coordinates,
        # each sweep in a freshly shuffled order.
        for _ in range(posterior_sampler.thin):
            rng.shuffle(order)

            for i in order:
                posterior_sampler.x[i], _ = posterior_sampler._sample_from_conditional(i, posterior_sampler.x[i], rng)

        samples[n] = posterior_sampler.x.copy()

        # Keep the sampler's cached log-probability in sync with its state.
        posterior_sampler.L = posterior_sampler.lp_f(posterior_sampler.x)

        if show_info:
            L_trace.append(posterior_sampler.L)

    # show trace plot
    if show_info:
        fig, ax = plt.subplots(1, 1)
        ax.plot(L_trace)
        ax.set_ylabel("log probability")
        ax.set_xlabel("samples")
        plt.show(block=False)

    return samples

def run(SliceSamplerSerial, log_prob_fn: Callable, num_samples: int, init_params: np.ndarray, num_chains: int = 1, thin: Optional[int] = None, verbose: bool = True, num_workers: int = 1,) -> np.ndarray:
    """Run one slice-sampling chain per row of ``init_params`` via joblib.

    Args:
        SliceSamplerSerial: Passed through to ``run_fun`` unchanged.
        log_prob_fn: Log-probability function handed to every chain.
        num_samples: Number of samples each chain records before thinning.
        init_params: 2-D array; one row of initial parameters per chain.
        num_chains: Ignored; the chain count is always taken from
            ``init_params.shape[0]``.
        thin: Keep every ``thin``-th sample of each chain (``None`` keeps all).
        verbose: Show progress bars.
        num_workers: Number of joblib workers.

    Returns:
        ``float32`` array of shape ``(num_chains, kept_samples, dim)``.
    """
    num_chains, dim_samples = init_params.shape
    # Generate seeds for workers from current random state.
    seeds = torch.randint(high=1_000_000, size=(num_chains,))
    # Kept as in the original: the main process ends up seeded with the last
    # entry of `seeds`.
    for seed in seeds:
        seed_all_backends(seed)

    progress = tqdm(range(num_chains), disable=not verbose, desc=f"""Running {num_chains} MCMC chains with {num_workers} worker{"s" if num_workers>1 else ""}.""", total=num_chains,)
    with tqdm_joblib(progress):
        # Forward `verbose` and `num_workers` so the per-chain bars respect
        # `verbose=False` and are switched off when running in parallel;
        # previously `run_fun` always fell back to its own defaults.
        all_samples = Parallel(n_jobs=num_workers)(
            delayed(run_fun)(
                SliceSamplerSerial,
                num_samples,
                chain_init,
                chain_seed,
                log_prob_fn,
                verbose=verbose,
                num_workers=num_workers,
            )
            for chain_init, chain_seed in zip(init_params, seeds)
        )

    samples = np.stack(all_samples).astype(np.float32)
    samples = samples.reshape(num_chains, -1, dim_samples)  # chains, samples, dim
    samples = samples[:, ::thin, :]  # thin chains
    return samples

class SliceSamplerSerial:
    """Settings container for a multi-chain slice-sampling run.

    The constructor only records its arguments. ``get_samples`` reads
    ``self._samples``, which starts as ``None`` and is not assigned anywhere
    in this class.
    """

    def __init__(self, log_prob_fn: Callable, init_params: np.ndarray, num_chains: int = 1, thin: Optional[int] = None, tuning: int = 50, verbose: bool = True, init_width: Union[float, np.ndarray] = 0.01, max_width: float = float("inf"), num_workers: int = 1,):
        """Record the sampler configuration.

        Args:
            log_prob_fn: Log-probability function of the target.
            init_params: Initial parameters; ``n_dims`` is set to its ``size``.
            num_chains: Number of chains.
            thin: Thinning factor.
            tuning: Number of tuning steps.
            verbose: Whether progress output is wanted.
            init_width: Initial bracket width.
            max_width: Maximum bracket width.
            num_workers: Number of parallel workers.
        """
        self._samples = None
        self._log_prob_fn = log_prob_fn

        self.x = init_params
        self.n_dims = init_params.size

        self.num_chains = num_chains
        self.num_workers = num_workers
        self.thin = thin
        self.tuning = tuning
        self.verbose = verbose
        self.init_width = init_width
        self.max_width = max_width

    def get_samples(self, num_samples: Optional[int] = None, group_by_chain: bool = True) -> np.ndarray:
        """Return stored samples, optionally flattened and/or truncated.

        Args:
            num_samples: If given, only the last ``num_samples`` entries are
                returned (per chain when ``group_by_chain`` is true).
            group_by_chain: Keep the ``(chains, samples, dim)`` layout; when
                false the first two axes are merged.

        Raises:
            ValueError: If no samples have been stored.
        """
        stored = self._samples
        if stored is None:
            raise ValueError("No samples found from MCMC run.")

        if group_by_chain:
            samples = stored
        else:
            # Merge the chain axis into the sample axis.
            samples = stored.reshape(-1, stored.shape[2])

        if num_samples is None:
            return samples
        if group_by_chain:
            return samples[:, -num_samples:, :]
        return samples[-num_samples:, :]

##############################################################################################################################
        
def _maybe_use_dict_entry(default: Any, key: str, dict_to_check: Dict) -> Any:
    attribute = default if key not in dict_to_check.keys() else dict_to_check[key]
    return attribute

def _get_initial_params(proposal, init_strategy: str, num_chains: int, num_workers: int, show_progress_bars: bool, **kwargs,) -> Tensor:
    """Produce one set of initial MCMC parameters per chain.

    Args:
        proposal: Object exposing ``_build_mcmc_init_fn``, ``proposal``,
            ``potential_fn`` and ``theta_transform``.
        init_strategy: Name of the initialisation strategy.
        num_chains: Number of initial parameter sets to generate.
        num_workers: Worker count; parallel generation is only used for the
            ``"resample"`` and ``"sir"`` strategies.
        show_progress_bars: Whether to show the progress bar in parallel mode.
        **kwargs: Extra arguments forwarded to ``_build_mcmc_init_fn``.

    Returns:
        The per-chain results of the init function, joined with ``torch.cat``.
    """
    # Build init function
    init_fn = proposal._build_mcmc_init_fn(
        proposal.proposal,
        proposal.potential_fn,
        transform=proposal.theta_transform,
        init_strategy=init_strategy,
        **kwargs,
    )

    # Parallelize inits for resampling only.
    run_parallel = num_workers > 1 and init_strategy in ("resample", "sir")
    if not run_parallel:
        return torch.cat([init_fn() for _ in range(num_chains)])

    def seeded_init_fn(seed):
        # Each task seeds torch itself before drawing its init.
        torch.manual_seed(seed)
        return init_fn()

    seeds = torch.randint(high=10_000_000, size=(num_chains,))

    # Generate initial params parallelized over num_workers.
    progress = tqdm(range(num_chains), disable=not show_progress_bars, desc=f"""Generating {num_chains} MCMC inits with {num_workers} workers.""", total=num_chains,)
    with tqdm_joblib(progress):
        tasks = (delayed(seeded_init_fn)(seed) for seed in seeds)
        initial_params = torch.cat(Parallel(n_jobs=num_workers)(tasks))
    return initial_params
    
def _slice_np_mcmc(proposal, num_samples: int, potential_function: Callable, initial_params: Tensor, thin: int, warmup_steps: int, vectorized: bool = False, num_workers: int = 1, init_width: Union[float, ndarray] = 0.01, show_progress_bars: bool = True,) -> Tensor:
    """Draw ``num_samples`` samples with the NumPy slice sampler.

    Args:
        proposal: Posterior-like object; receives ``_posterior_sampler`` and
            ``_mcmc_init_params`` and supplies ``_device``.
        num_samples: Number of samples to return.
        potential_function: Log-probability function to sample from.
        initial_params: Tensor of shape ``(num_chains, dim)``.
        thin: Keep every ``thin``-th draw of each chain.
        warmup_steps: Number of (thinned) draws discarded at the start of
            each chain.
        vectorized: Selects which sampler class is instantiated and stored on
            ``proposal``. NOTE(review): sampling itself always goes through
            the module-level ``run`` function, whichever class is chosen.
        num_workers: Number of joblib workers used by ``run``.
        init_width: Initial bracket width stored on the sampler object.
        show_progress_bars: Whether to show progress bars.

    Returns:
        ``float32`` tensor of shape ``(num_samples, dim)`` on
        ``proposal._device``.
    """
    num_chains, dim_samples = initial_params.shape

    SliceSamplerMultiChain = SliceSamplerVectorized if vectorized else SliceSamplerSerial

    posterior_sampler = SliceSamplerMultiChain(
        init_params=tensor2numpy(initial_params),
        log_prob_fn=potential_function,
        num_chains=num_chains,
        thin=thin,
        verbose=show_progress_bars,
        num_workers=num_workers,
        init_width=init_width,
    )
    # Raw (un-thinned) draws needed per chain for warmup and for the samples.
    warmup_ = warmup_steps * thin
    num_samples_ = ceil((num_samples * thin) / num_chains)

    # Run mcmc including warmup. `thin`, `verbose` and `num_workers` must be
    # forwarded: `run` defaults to no thinning and one worker, in which case
    # the warmup slice below would drop only `warmup_steps` raw draws and the
    # flattened result would consist of un-thinned draws from the first chain.
    samples = run(
        posterior_sampler,
        log_prob_fn=potential_function,
        num_samples=(warmup_ + num_samples_),
        init_params=tensor2numpy(initial_params),
        thin=thin,
        verbose=show_progress_bars,
        num_workers=num_workers,
    )
    samples = samples[:, warmup_steps:, :]  # discard warmup steps
    samples = torch.from_numpy(samples)  # chains x samples x dim

    # Save posterior sampler.
    proposal._posterior_sampler = posterior_sampler

    # Save sample as potential next init (if init_strategy == 'latest_sample').
    proposal._mcmc_init_params = samples[:, -1, :].reshape(num_chains, dim_samples)

    # Collect samples from all chains.
    samples = samples.reshape(-1, dim_samples)[:num_samples, :]
    assert samples.shape[0] == num_samples
    return samples.type(torch.float32).to(proposal._device)

def sample_my_fun(
    proposal,
    sample_shape: Shape = torch.Size(),
    x: Optional[Tensor] = None,
    method: Optional[str] = None,
    thin: Optional[int] = None,
    warmup_steps: Optional[int] = None,
    num_chains: Optional[int] = None,
    init_strategy: Optional[str] = None,
    init_strategy_parameters: Optional[Dict[str, Any]] = None,
    init_strategy_num_candidates: Optional[int] = None,
    mcmc_parameters: Optional[Dict] = None,
    mcmc_method: Optional[str] = None,
    sample_with: Optional[str] = None,
    num_workers: Optional[int] = None,
    show_progress_bars: bool = True,
) -> Tensor:
    """Sample from an MCMC posterior using the samplers defined in this file.

    Every argument left as ``None`` falls back to the attribute of the same
    name on ``proposal``.

    Args:
        proposal: Posterior object providing ``potential_fn``,
            ``theta_transform``, ``_prepare_potential`` and the default
            sampling settings.
        sample_shape: Desired sample shape; its number of elements is the
            number of samples drawn.
        x: Observation to condition on; resolved through
            ``proposal._x_else_default_x``.
        method: One of ``"slice_np"``, ``"slice_np_vectorized"``, ``"hmc"``,
            ``"nuts"``, ``"slice"``.
        thin: Thinning factor.
        warmup_steps: Number of warmup steps.
        num_chains: Number of chains.
        init_strategy: Initialisation strategy name.
        init_strategy_parameters: Extra keyword arguments for the init function.
        init_strategy_num_candidates: Deprecated; triggers a warning.
        mcmc_parameters: Deprecated dict of overrides (``None`` or empty means
            no overrides).
        mcmc_method: Deprecated alias for ``method``.
        sample_with: No longer supported; any non-``None`` value raises.
        num_workers: Number of parallel workers.
        show_progress_bars: Whether to show progress bars.

    Returns:
        Samples reshaped to ``(*sample_shape, -1)``.

    Raises:
        ValueError: If ``sample_with`` is given.
        NameError: If ``method`` is not one of the supported names.
    """
    proposal.potential_fn.set_x(proposal._x_else_default_x(x))

    # Replace arguments that were not passed with their default.
    method = proposal.method if method is None else method
    thin = proposal.thin if thin is None else thin
    warmup_steps = proposal.warmup_steps if warmup_steps is None else warmup_steps
    num_chains = proposal.num_chains if num_chains is None else num_chains
    init_strategy = proposal.init_strategy if init_strategy is None else init_strategy
    num_workers = proposal.num_workers if num_workers is None else num_workers
    init_strategy_parameters = (proposal.init_strategy_parameters if init_strategy_parameters is None else init_strategy_parameters)

    if init_strategy_num_candidates is not None:
        warn("""Passing `init_strategy_num_candidates` is deprecated as of sbi v0.19.0. Instead, use e.g.,`init_strategy_parameters={"num_candidate_samples": 1000}`""")
        proposal.init_strategy_parameters["num_candidate_samples"] = (init_strategy_num_candidates)
    if sample_with is not None:
        raise ValueError(f"You set `sample_with={sample_with}`. As of sbi v0.18.0, setting `sample_with` is no longer supported. You have to rerun `.build_posterior(sample_with={sample_with}).`")
    if mcmc_method is not None:
        warn("You passed `mcmc_method` to `.sample()`. As of sbi v0.18.0, this is deprecated and will be removed in a future release. Use `method` instead of `mcmc_method`.")
        method = mcmc_method
    if mcmc_parameters:
        warn("You passed `mcmc_parameters` to `.sample()`. As of sbi v0.18.0, this is deprecated and will be removed in a future release. Instead, pass the variable to `.sample()` directly, e.g. `posterior.sample((1,), num_chains=5)`.")

    # The following lines are only for backwards compatibility with sbi v0.17.2 or older.
    # A fresh dict per call replaces the former shared mutable `{}` default.
    m_p = mcmc_parameters if mcmc_parameters is not None else {}
    method = _maybe_use_dict_entry(method, "mcmc_method", m_p)
    thin = _maybe_use_dict_entry(thin, "thin", m_p)
    warmup_steps = _maybe_use_dict_entry(warmup_steps, "warmup_steps", m_p)
    num_chains = _maybe_use_dict_entry(num_chains, "num_chains", m_p)
    init_strategy = _maybe_use_dict_entry(init_strategy, "init_strategy", m_p)
    proposal.potential_ = proposal._prepare_potential(method)  # type: ignore

    initial_params = _get_initial_params(proposal, init_strategy, num_chains, num_workers, show_progress_bars, **init_strategy_parameters,)
    num_samples = torch.Size(sample_shape).numel()

    # Only the gradient-based samplers need autograd enabled.
    track_gradients = method in ("hmc", "nuts")
    with torch.set_grad_enabled(track_gradients):
        if method in ("slice_np", "slice_np_vectorized"):
            transformed_samples = _slice_np_mcmc(
                proposal,
                num_samples=num_samples,
                potential_function=proposal.potential_,
                initial_params=initial_params,
                thin=thin,
                warmup_steps=warmup_steps,
                vectorized=(method == "slice_np_vectorized"),
                num_workers=num_workers,
                show_progress_bars=show_progress_bars,
            )
        elif method in ("hmc", "nuts", "slice"):
            # NOTE(review): `_pyro_mcmc` is not defined in the visible part of
            # this file; confirm it exists before relying on these methods.
            transformed_samples = _pyro_mcmc(
                proposal,
                num_samples=num_samples,
                potential_function=proposal.potential_,
                initial_params=initial_params,
                mcmc_method=method,
                thin=thin,
                warmup_steps=warmup_steps,
                num_chains=num_chains,
                show_progress_bars=show_progress_bars,
            )
        else:
            raise NameError(f"Unknown MCMC method: {method!r}.")

    samples = proposal.theta_transform.inv(transformed_samples)

    return samples.reshape((*sample_shape, -1))  # type: ignore

#######################################################################################################################################

def simulator_seeded(simulator: Callable, theta: Tensor, seed: int) -> Tensor:
    """Evaluate ``simulator(theta)`` under a temporary torch seed.

    The torch RNG is forked for the duration of the call (``devices=[]``), so
    seeding here does not disturb the caller's CPU random stream. cuDNN is
    switched to deterministic, non-benchmark mode as a side effect.

    Args:
        simulator: Callable taking ``theta``.
        theta: Parameters handed to the simulator.
        seed: Seed applied with ``torch.manual_seed`` inside the forked RNG.

    Returns:
        Whatever ``simulator(theta)`` returns.
    """
    # Function-local import, presumably so the function is self-contained
    # when shipped to joblib workers.
    import torch

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(seed)
        result = simulator(theta)
    return result

def simulate_in_batches(simulator: Callable, theta: Tensor, sim_batch_size: int = 1, num_workers: int = 1 , seed: Optional[int] = None, show_progress_bars: bool = True, ) -> Tensor:
    """Run ``simulator`` on ``theta``, optionally in batches and in parallel.

    Args:
        simulator: Callable mapping a batch of parameters to simulator output.
        theta: Parameters; the first dimension is the number of simulations.
        sim_batch_size: Rows per batch. Batching is only used when this is not
            ``None`` and smaller than the number of simulations.
        num_workers: Number of joblib workers; ``1`` runs the batches serially.
        seed: Seed for all backends; ``None`` lets ``seed_all_backends`` draw one.
        show_progress_bars: Whether to show a progress bar.

    Returns:
        Simulator outputs concatenated along dimension 0 (an empty tensor
        when ``theta`` has no rows).
    """
    num_sims, *_ = theta.shape
    seed_all_backends(seed)

    if num_sims == 0:
        return torch.tensor([])
    if sim_batch_size is None or sim_batch_size >= num_sims:
        # No batching requested or needed: one direct simulator call.
        return simulator(theta)

    batches = split(theta, sim_batch_size, dim=0)
    # One seed per batch, drawn from the RNG seeded above. The serial path
    # used to pass `seed` itself to every batch, which raised a TypeError in
    # `torch.manual_seed` for the default `seed=None` and otherwise reused the
    # same seed for all batches.
    batch_seeds = randint(high=1_000_000, size=(len(batches),)).tolist()

    if num_workers != 1:
        with tqdm_joblib(tqdm(batches, disable=not show_progress_bars, total = len(batches), desc=f"Running {num_sims} simulations in {len(batches)} batches ({num_workers} cores)",)) as _:
            simulation_outputs = Parallel(n_jobs=num_workers)(
                delayed(simulator_seeded)(simulator, batch, batch_seed)
                for batch, batch_seed in zip(batches, batch_seeds)
            )
    else:
        with tqdm(total=num_sims, disable=not show_progress_bars, desc=f"Running {num_sims} simulations.", ) as pbar:
            simulation_outputs = []
            for batch, batch_seed in zip(batches, batch_seeds):
                simulation_outputs.append(simulator_seeded(simulator, batch, batch_seed))
                # The final batch may be shorter than `sim_batch_size`.
                pbar.update(len(batch))

    return cat(simulation_outputs, dim=0)

def simulate_for_sbi(round_idx: int, simulator: Callable, proposal: Any, num_simulations: int, num_workers: int = 1, simulation_batch_size: int = 1, seed: Optional[int] = None, show_progress_bar: bool = True)-> Tuple[Tensor, Tensor]:
    """Draw parameters from ``proposal`` and simulate data for them.

    Args:
        round_idx: Round number; ``0`` uses ``proposal.sample`` directly, any
            other value samples through ``sample_my_fun`` with four chains.
        simulator: Simulator callable.
        proposal: Distribution to draw parameters from.
        num_simulations: Number of parameter sets to draw and simulate.
        num_workers: Worker count for sampling and simulation.
        simulation_batch_size: Batch size for ``simulate_in_batches``.
        seed: Seed forwarded to ``simulate_in_batches``.
        show_progress_bar: Whether to show the simulation progress bar.

    Returns:
        Tuple ``(theta, x)`` of sampled parameters and simulator outputs.
    """
    sample_shape = (num_simulations,)

    # Only in the first round is the proposal the BoxUniform prior; afterwards
    # it is an MCMC posterior object and is sampled with the custom sampler.
    if round_idx != 0:
        theta = sample_my_fun(proposal, sample_shape, num_workers=num_workers, num_chains=4)
    else:
        theta = proposal.sample(sample_shape)

    x = simulate_in_batches(
        simulator=simulator,
        theta=theta,
        sim_batch_size=simulation_batch_size,
        num_workers=num_workers,
        seed=seed,
        show_progress_bars=show_progress_bar,
    )
    return theta, x

In [127]:
# Names of the parameters to be inferred.
headers = ["k1", "k2", "k3"]
num_timesteps = 100

# --- SNLE settings ---------------------------------------------------------
# Prior bounds, shared by all inferred parameters.
prior_min = 0.01
prior_max = 250
# num_rounds = 3                        # how many rounds of SNLE
# num_simulations = 3000               # in each round

# Each entry is one hyperparameter set:
# [num_rounds, num_simulations per round, simulation_batch_size, CPUs_to_use]
hyperparameters = [
    [10, 1500, 20, 75],
    [5, 3000, 40, 75],
    [3, 5000, 100, 50],
]

# To simulate in batches, simulation_batch_size must not be None and must be
# smaller than num_simulations. To parallelise, set the number of CPUs to use
# (run os.cpu_count() to see how many are available).
# simulation_batch_size = 8
# CPUs_to_use = 50

# --- MCMC settings (currently disabled) ------------------------------------
# num_iterations = 10_000 # total in all chains
# interval_to_calculate_acceptance_rate = 100
# burn_in_fraction = 0.3
# num_chains = 4

# Time grid on which the ODE solution is evaluated.
t = np.linspace(0, 100, num_timesteps)

# Reference ("true") value of every model parameter.
param_dict = {
    'k1': 246.96291990024542, 'k2': 246.96291990024542, 'k3': 246.96291990024542,
    'n1': 5, 'n2': 5, 'n3': 5,
    'dm1': 1.143402097500176, 'dm2': 1.143402097500176, 'dm3': 1.143402097500176,
    'dp1': 0.7833664565550977, 'dp2': 0.7833664565550977, 'dp3': 0.7833664565550977,
    'a1': 24.78485282457379, 'a2': 24.78485282457379, 'a3': 24.78485282457379,
    'g1': 0.024884149937163258, 'g2': 0.024884149937163258, 'g3': 0.024884149937163258,
    'b1': 33.82307682700831, 'b2': 33.82307682700831, 'b3': 33.82307682700831,
}

all_params = (
    'a1', 'a2', 'a3',
    'g1', 'g2', 'g3',
    'dm1', 'dm2', 'dm3',
    'dp1', 'dp2', 'dp3',
    'b1', 'b2', 'b3',
    'n1', 'n2', 'n3',
    'k1', 'k2', 'k3',
)

# Template for the simulator: fixed parameters carry their reference value,
# inferred ones carry their own name as a placeholder.
new_param_dictx = {}
for param in all_params:
    new_param_dictx[param] = param if param in headers else param_dict[param]

def my_simulator(theta):
    """Integrate the six-variable ODE system for the given parameter values.

    Args:
        theta: Indexable with one scalar value per entry of ``headers``
            (e.g. a tuple of floats or a 1-D tensor), in ``headers`` order.

    Returns:
        Flattened ``float32`` tensor holding the solution for the variables
        ``m1, p1, m2, p2, m3, p3`` at every time point of ``t``.
    """
    def model(variables, t, p):
        """Right-hand side of the ODE system for ``odeint``."""
        m1, p1, m2, p2, m3, p3 = variables

        dm1dt = -p['dm1']*m1 + (p['a1'] / (1 + (p2/p['k1'])**p['n1'])) + p['g1']
        dp1dt = (p['b1']*m1) - (p['dp1']*p1)
        dm2dt = -p['dm2']*m2 + (p['a2'] / (1 + (p3/p['k2'])**p['n2'])) + p['g2']
        dp2dt = (p['b2']*m2) - (p['dp2']*p2)
        dm3dt = -p['dm3']*m3 + (p['a3'] / (1 + (p1/p['k3'])**p['n3'])) + p['g3']
        dp3dt = (p['b3']*m3) - (p['dp3']*p3)

        return [dm1dt, dp1dt, dm2dt, dp2dt, dm3dt, dp3dt]

    # Work on a private copy: the module-level template `new_param_dictx` used
    # to be overwritten on every call, leaking state between calls.
    params = dict(new_param_dictx)
    for i, name in enumerate(headers):
        # Plain floats keep the ODE right-hand side free of 0-d tensors.
        params[name] = float(theta[i])

    initial_conditions = np.array([0, 2, 0, 1, 0, 3], dtype=np.float32)
    solution = odeint(model, initial_conditions, t, args=(params,))
    # Flatten (time, variable) into one vector, e.g. size [600] for 100 steps.
    return torch.tensor(solution, dtype=torch.float32).flatten()

# Reference values of the inferred parameters and the trajectory they produce;
# the latter serves as the observation.
true_params = tuple(param_dict[name] for name in headers)
true_solutions = my_simulator(true_params)

#####################################################################################################################

num_dim = len(true_params)

# Box-uniform prior with the same bounds in every dimension.
prior = utils.BoxUniform(
    low=prior_min * torch.ones(num_dim),
    high=prior_max * torch.ones(num_dim),
)
simulator, prior = prepare_for_sbi(my_simulator, prior)

my_density_estimator = likelihood_nn(
    model="maf",
    hidden_features=50,
    num_transforms=3,
)

# Initialise inference.
inference = SNLE(prior=prior, density_estimator=my_density_estimator)
# Collects the posterior built after each round.
posteriors = []
# The first round samples from the prior; later rounds replace the proposal
# with the latest posterior (sequential inference).
proposal = prior

# Names of the ODE state variables, in the column order of a trajectory.
variables = ['m1', 'p1', 'm2', 'p2', 'm3', 'p3']


def simulate_sample(batch, seed):
    """Simulate every parameter set in ``batch``; one (time, variable) array each."""
    np.random.seed(seed)
    return [my_simulator(sample).reshape(num_timesteps, len(variables)) for sample in batch]


for hyperparameter_set in hyperparameters:
    num_rounds, num_simulations, simulation_batch_size, num_workers = hyperparameter_set # num_workers = CPUs_to_use
    print(f"SNLE FOR {num_dim} PARAMETER{'S' if len(headers)>1 else ''} ({', '.join(headers)}), {num_rounds} ROUND{'S' if num_rounds>1 else ''}{', EACH' if num_rounds>1 else ''} OF {num_simulations} SIMULATIONS (using {num_workers} cores)")

    # --- sequential rounds: simulate, train, rebuild the posterior ---------
    for round_idx in range(num_rounds):
        print(f"Round {round_idx+1}")
        theta, x = simulate_for_sbi(round_idx, simulator, proposal, num_simulations = num_simulations, simulation_batch_size = simulation_batch_size, num_workers = num_workers)
        density_estimator = inference.append_simulations(theta, x).train()
        posterior = inference.build_posterior(density_estimator)
        posteriors.append(posterior)
        proposal = posterior.set_default_x(true_solutions)
        print("\n")

    # --- posterior samples and pair plots -----------------------------------
    posterior_samples = sample_my_fun(posterior, (9975,), num_chains = 1) # sample to plot the posteriors

    snle_data = pd.DataFrame(data=posterior_samples, columns=headers)
    snle_data.to_csv(f'{num_dim}p-{num_rounds}*{num_simulations}.csv')

    # Calculate quantiles of posterior samples
    posterior_quantiles = np.percentile(posterior_samples, [1, 99], axis=0)

    # Define custom limits slightly larger than the range of the central 98% of the posterior samples
    lower_q, upper_q = posterior_quantiles
    margin = 0.2 * (upper_q - lower_q)
    custom_limits = [(lower_q[i] - margin[i], upper_q[i] + margin[i]) for i in range(num_dim)]

    # Plot pair plots with custom limits
    _ = analysis.pairplot(posterior_samples, limits=custom_limits, figsize=(8, 8), labels=headers)

    plt.savefig(f'SNLE-{num_dim}p-{num_rounds}*{num_simulations}-P_customlimits.png')

    _ = analysis.pairplot(posterior_samples, limits=[[prior_min, prior_max]]*num_dim, figsize=(8, 8), labels=headers)

    plt.savefig(f'SNLE-{num_dim}p-{num_rounds}*{num_simulations}-P_priorlimits.png')

    # --- posterior-predictive trajectories ----------------------------------
    # Use joblib to parallelize the simulation. The split size is clamped to
    # 1 so it cannot become 0 when there are more workers than samples.
    batch_size = max(1, len(posterior_samples) // num_workers)
    batches = split(posterior_samples, batch_size, dim =0)

    seeds = randint(high=1_000_000, size=(len(batches),))
    with tqdm_joblib(tqdm(batches, total = len(batches), desc=f"Running {len(posterior_samples)} simulations in {len(batches)} batches ({num_workers} cores)",)) as _:
        results = Parallel(n_jobs=num_workers)(delayed(simulate_sample)(batch, seed) for batch, seed in zip(batches, seeds))

    # Collect the per-batch results into one (sample, time, variable) array.
    raw_trajectories = np.zeros([len(posterior_samples), num_timesteps, len(variables)])
    trajectories = (trajectory for batch_result in results for trajectory in batch_result)
    for index, trajectory in enumerate(trajectories):
        raw_trajectories[index] = trajectory

    # 95% band across posterior samples at every time point.
    tr = np.percentile(raw_trajectories, [2.5, 97.5], axis=0)
    true_trajectories = true_solutions.reshape(num_timesteps, len(variables))

    fig, ax = plt.subplots(3, 2, figsize=(30,18))
    ax = ax.ravel()
    for i in range(len(variables)):
        # Every curve is drawn against `t` so the lines share the x-axis of
        # the shaded band; they were previously plotted against the index.
        for j in range(2):
            ax[i].plot(t, tr[j,:,i],alpha=0.4,linestyle='dotted',linewidth=1, color='black')
        ax[i].plot(t, true_trajectories[:,i],linewidth=1,color='black', label = 'true')
        ax[i].fill_between(t, tr[0, :, i],tr[1, :, i], alpha=0.4, color='skyblue')
        ax[i].legend()

    plt.savefig(f'SNLE-{num_dim}p-{num_rounds}*{num_simulations}-T.png')

SNLE FOR 3 PARAMETERS (k1, k2, k3), 10 ROUNDS, EACH OF 1500 SIMULATIONS
Round 1


Running 1500 simulations in 188 batches (8 cores):   4%|▍         | 8/188 [00:53<20:06,  6.70s/it]  


KeyboardInterrupt: 

In [129]:
# Print every divisor of 5000 between 1 and 80 (candidate batch sizes /
# worker counts that split 5000 simulations evenly).
for i in range(1, 81):
    if 5000 % i != 0:
        continue
    print(i)

1
2
4
5
8
10
20
25
40
50


From ChatGPT:
1. The neural network (self._neural_net) is set to training mode (self._neural_net.train()). This ensures that layers like dropout and batch normalization behave differently during training compared to evaluation.

2. The loop iterates over the training data in batches (for batch in train_loader). Each batch typically consists of input samples (x) and their corresponding parameters (theta).

3. The gradients of the model parameters are reset to zero (self.optimizer.zero_grad()) before computing the gradients for the current batch. This prevents gradient accumulation across batches.

4. The model is then applied to the input samples to compute the losses (train_losses) using the loss function (self._loss). This loss function typically scores the model's predictions for x given theta against the observed data. The loss function returns the negative log probability, i.e. the negative log-likelihood -log p(x∣θ), of the samples. It is computed as follows:
                
    i. Separates the input data x into continuous and discrete components using a helper function _separate_x(x)
    
    ii. The log probability for each part is computed using a neural network with theta as context (the continuous part may be transformed into log-space if needed)
    
    iii. The log probabilities for the discrete and continuous parts are combined into a joint log probability (log_probs_combined)
    
    iv. If the continuous data is transformed to log-space, the log absolute determinant Jacobian of the transformation is subtracted from the joint log probability
    
    v. Returns the joint log probability of p(x∣θ) for each sample in the batch

5. The losses are averaged across the batch to obtain a single scalar value representing the overall loss for the batch (train_loss).

6. The gradients of the loss with respect to the model parameters are computed using backpropagation (train_loss.backward()).

7. Optionally, the gradients are clipped to prevent them from growing too large and causing instability during training (clip_grad_norm_()).

8. An optimizer (self.optimizer) updates the model parameters based on the computed gradients (self.optimizer.step()).

In [None]:
# def getDif(indexes, arrayData):	
#     arrLen = len(indexes)
#     sum = 0
#     for i, ind in enumerate(indexes):
#         if i == arrLen - 1:
#             break
#         sum += arrayData[ind] - arrayData[indexes[i + 1]]
        
#     #add last peak - same as subtracting it from zero 
#     sum += arrayData[indexes[-1:]]  
#     return sum   
    
# def getSTD(indexes, arrayData, window):
#     numPeaks = len(indexes)
#     arrLen = len(arrayData)
#     sum = 0
#     for ind in indexes:
#         minInd = max(0, ind - window)
#         maxInd = min(arrLen, ind + window)
#         sum += np.std(arrayData[minInd:maxInd])  
        
#     sum = sum/numPeaks 	#The 1/P factor
#     return sum
    
# def getFrequencies(y):
#     res = abs(fft.rfft(y))  #Real FT
#     #normalize the amplitudes 
#     #res = res/math.ceil(1/2) #Normalise with a factor of 1/2
#     return res

# def costTwo(Y, getAmplitude = False): #Yes
#     p1 = Y[:,1]  #Get the first column
#     fftData = getFrequencies(p1)    #Get frequencies of FFT of the first column  
#     fftData = np.array(fftData) 
#     #find peaks using very low threshold and minimum distance
#     indexes = peakutils.indexes(fftData, thres=0.02/max(fftData), min_dist=1)  #Just find peaks
#     #in case of no oscillations return 0 
#     if len(indexes) == 0:     
#         return 0
#     #if amplitude is greater than 400nM
#     #global amp
#     #amp = np.max(fftData[indexes])
#     #if amp > 400: #If bigger than 400, then cost is 0, not viable
#       #  return 0, 
#     fitSamples = fftData[indexes]
#     std = getSTD(indexes, fftData, 1)  #get sd of peaks at a window of 1 (previous peak)
#     diff = getDif(indexes, fftData)  #Get differences in peaks
#     cost = std + diff #Sum them
#     #print(cost)   
#     if getAmplitude:
#         return cost, amp
#     return int(cost)

# def euclidean_distance_multiple_trajectories(truth, simulations):
#     num_trajectories = len(simulations)
#     total_distance = 0.0

#     for i in range(num_trajectories):
#         observed_data = truth[i]
#         simulated_data = simulations[i]

#         # Calculate the Euclidean distance between observed and simulated data
#         euclidean_distance = np.linalg.norm(observed_data - simulated_data)

#         # Accumulate the distances
#         total_distance += euclidean_distance

#     # Average the distances over all trajectories
#     average_distance = total_distance / num_trajectories

#     return average_distance

# def get_distance(truth, simulation):
#     timepoints = int(len(truth))
#     third = int(timepoints / 3)
#     observed = truth[third:timepoints]
#     simulated = simulation[third:timepoints] # Discard the first third
#     euclidean_distance = euclidean_distance_multiple_trajectories(observed, simulated)
#     penalising_factor = np.abs(np.abs(costTwo(simulation)) - 200)
#     if costTwo(simulation) >= 200:
#         return euclidean_distance
#     else:
#         if penalising_factor < 1:
#             penalising_factor = 1
#         return euclidean_distance * penalising_factor

# ##################################################################################################################
    
# # create priors from posterior of SNLE
    
# posterior_samples_np = posterior_samples.numpy()
# parameter_means = np.mean(posterior_samples_np, axis = 0)
# parameter_stdevs = np.std(posterior_samples_np, axis = 0)

# prior_samples = np.zeros((num_chains, posterior_samples.shape[1]))
# for param_idx in range(posterior_samples.shape[1]):
#     prior_samples[:, param_idx] = np.random.normal(loc = parameter_means[param_idx], scale = parameter_stdevs[param_idx], size = num_chains)

# #####################################################################################################################
    
# def abc_mcmc_single_chain_seeded(chain, true_params, seed, chain_idx):
#     random.seed(seed)
#     stored_accepted_data = []
#     params = []
#     distances = []
#     current_simulation = [] # initialised as a list, but converted to a torch tensor once populated with simulator results using current_params, which allows reshaping;
#                             # could also be initialised as torch.tensor([])
#     starting_params = prior_samples[chain_idx]
#     current_params = tuple(starting_params) # required manipulation of the type of object (np.ndarray --> tuple)
#     parameter_traces = [current_params]  # Collect parameter traces to plot
#     accepted_count = 0
#     acceptance_rates = []
#     # with tqdm(total=len(chain), desc=f'Chain {chain_idx+1}', position=chain_idx) as pbar:
#     for _ in chain:
#         proposal_scale = 2.38*(len(true_params) ** -1/2)

# # For a Gaussian target distribution with unit variance in each dimension, the optimal proposal standard deviation is σ_d ≈ 2.38 d^{-1/2}
# # (Gelman, Roberts, and Gilks 1996), where d is the dimension of the parameter space. For other Gaussian targets, scale accordingly: multiply the
# # target's covariance matrix by the square of σ_d given above to obtain the proposal covariance.
# # NOTE: in Python, `d ** -1/2` evaluates as `(d ** -1) / 2`, not d^{-1/2}; the formula above corresponds to `d ** -0.5`.
        
#         while True:
#             proposed_params = current_params + np.random.normal(scale=proposal_scale, size=len(true_params))
#             if max(proposed_params) > np.max(np.max(posterior_samples_np, axis = 0), axis=0) or min(proposed_params) < np.min(np.min(posterior_samples_np, axis = 0), axis=0):
#                 continue  # repeat the current iteration
#             break  # exit the while loop if conditions are satisfied
#         proposed_simulation = my_simulator_euler(proposed_params)

#         proposed_simulation_unflattened = unflatten(proposed_simulation)           
#         true_solutions_unflattened = unflatten(true_solutions)

#         p_proposed = get_distance(true_solutions_unflattened, proposed_simulation_unflattened)

#         if len(current_simulation) == 0:  # if the list is still empty, i.e. the first iteration of each chain
#             current_simulation = my_simulator_euler(current_params)

#         current_simulation_unflattened = unflatten(current_simulation)
#         p_current = get_distance(true_solutions_unflattened, current_simulation_unflattened)

#         if p_proposed < 300:
#             acceptance_prob = min(1, -np.exp(-p_current / p_proposed))
                                                
#             if np.random.rand() < acceptance_prob:
#                 accepted_count += 1
#                 current_params = proposed_params
#                 p_current = p_proposed
#                 current_simulation = proposed_simulation
#                 stored_accepted_data.append(current_simulation_unflattened)
        
#         distances.append(p_current)
#         params.append(current_params)
#         parameter_traces.append(current_params)

#         if _ % interval_to_calculate_acceptance_rate == 0:
#             acceptance_rates.append(accepted_count / _ * 100)
            
#         # pbar.update(1)  # Update progress bar
    
#     return params, stored_accepted_data, parameter_traces, accepted_count, distances, acceptance_rates

# def abc_mcmc_multiple_chains(true_params, num_iterations, num_chains, show_progress_bars: bool = True):
#     chain_seeds = random.sample(range(0, 1_000_000), num_chains)
#     iteration_index = np.arange(1, num_iterations + 1)
#     chains = np.array_split(iteration_index, num_chains)
#     with tqdm(total=num_iterations, desc=f'Running {num_chains} chains', disable = not show_progress_bars) as pbar: results = Parallel(n_jobs=num_workers)(delayed(abc_mcmc_single_chain_seeded)(chain, true_params, chain_seed, chain_idx) 
#         for chain_idx, (chain, chain_seed) in enumerate(zip(chains, chain_seeds)))
    
#     all_params, all_stored_accepted_data, all_parameter_traces, all_accepted_counts, all_distances, all_acceptance_rates = zip(*results)

#     return (list(all_params), list(all_stored_accepted_data), list(all_parameter_traces), list(all_accepted_counts), list(all_distances), list(all_acceptance_rates))

# all_accepted_parameters, all_accepted_data, all_parameter_traces, all_accepted_counts, all_distances, all_acceptance_rates= abc_mcmc_multiple_chains(true_params, num_iterations, num_chains)

# ##################################################################################################################
# ##################################################################### TRAJECTORIES ################################

# def plot_trajectories(data):
#     variables = ['m1', 'p1', 'm2', 'p2', 'm3', 'p3']

#     # Initialize lists to store quartiles for each variable
#     quartiles = [[] for _ in range(len(variables))]

#     # Iterate over each variable
#     for variable_idx in range(len(variables)):
#         variable_list = []
#         # Iterate over each timestep
#         for timestep in range(num_timesteps):
#             timestep_list = []
#             # Iterate over each chain and simulation
#             for chain in data:
#                 for simulation in chain:
#                     timestep_list.append(simulation[timestep][variable_idx])
#             variable_list.append(timestep_list)

#         # Calculate quartiles for the current variable
#         q1_list = []
#         q3_list = []
#         for timestep_data in variable_list:
#             q3, q1 = np.percentile(timestep_data, [75, 25])
#             q1_list.append(q1)
#             q3_list.append(q3)

#         # Append quartile lists for the current variable
#         quartiles[variable_idx].extend([q1_list, q3_list])

#     # Output quartiles for each variable
#     for idx, variable in enumerate(variables):
#         q1_list = quartiles[idx][0]
#         q3_list = quartiles[idx][1]

#     fig, axes = plt.subplots(2, 3, figsize=(15, 10))

#     for idx, ax in enumerate(axes.flat):
#         q1_list = quartiles[idx][0]
#         q3_list = quartiles[idx][1]
#         true_variable_data = true_solutions_unflattened[:, idx]  # Extract true data for the current variable
#         ax.plot(range(num_timesteps), q1_list, label='Q1', color='blue')
#         ax.plot(range(num_timesteps), q3_list, label='Q3', color='red')
#         ax.fill_between(range(num_timesteps), q1_list, q3_list, alpha=0.3)
#         ax.plot(range(num_timesteps), true_variable_data, label='True Data', color='green')
#         ax.set_xlabel('t')
#         ax.set_ylabel(variables[idx])
#         ax.legend()

#     plt.tight_layout(rect=[0, 0.03, 1, 0.95])
#     plt.show()

# plot_trajectories(all_accepted_data)

# ############################################ PARAMETER TRACES ########################################

# fig, axs = plt.subplots(len(true_params), 1, figsize=(15, 3 * len(true_params)), sharex=True)

# for i in range(len(true_params)):
#     for chain in range(num_chains):
#         parameter_trace = np.array(all_parameter_traces[chain])  # Convert to numpy array
#         axs[i].plot(parameter_trace[:, i], alpha=1, label=f"Chain {chain+1}")
#         axs[i].legend()  # Add legend for each subplot
    
#     axs[i].set_ylabel(headers[i])

# axs[-1].set_xlabel('Iteration')
# plt.tight_layout()
# plt.show()

# ####################################################### PAIRPLOTS ###################################

# def plot_pairplots(data):
#     burn_in_iterations = int(num_iterations / num_chains * burn_in_fraction)
#     parameter_accepts = [[] for _ in range(len(headers))]  # List of lists for parameter acceptances
#     chain_labels = []

#     # Iterate over each chain
#     for chain in range(num_chains):
#         num_samples = len(data[chain])
#         for i in range(burn_in_iterations, num_samples):  # Start from burn-in samples
#             accepted_params = data[chain][i]
#             for param_idx, param_accept_list in enumerate(parameter_accepts):
#                 param_accept_list.append(accepted_params[param_idx])
#             chain_labels.append(f"Chain {chain + 1}")

#     # Construct DataFrame using parameter acceptances and chain labels
#     data_dict = {header: accept_list for header, accept_list in zip(headers, parameter_accepts)}
#     data_dict["Chain"] = chain_labels
#     plot_data = pd.DataFrame(data_dict)

#     g = sns.pairplot(plot_data, kind="kde", hue="Chain", palette="Set1")

#     # Add vertical/horizontal lines for true parameter values
#     for i, j in zip(*np.tril_indices_from(g.axes, -1)):
#         for idx, header in enumerate(headers):
#             true_value = true_params[idx]
#             ax = g.axes[i, j]
#             ax.axvline(x=true_value, color='k', linestyle='--')
#             if i != j:  # For off-diagonal plots
#                 ax.axhline(y=true_value, color='k', linestyle='--')

#     for i in range(len(headers)):
#         ax = g.axes[i, i]
#         for idx, header in enumerate(headers):
#             true_value = true_params[idx]
#             ax.axvline(x=true_value, color='k', linestyle='--')

#     plt.show()

# plot_pairplots(all_accepted_parameters)

# ############################################################# DISTANCES OVER ITERATION ##############################

# fig, axs = plt.subplots(num_chains, 1, figsize=(15, 3 * num_chains), sharex=True)

# for chain in range(num_chains):
#     axs[chain].plot(all_distances[chain], label=f"Chain {chain+1}")
#     axs[chain].set_ylabel('Distance')
#     axs[chain].set_xlim([0, int(num_iterations/num_chains)])

# axs[-1].set_xlabel('Iteration')
# plt.tight_layout()
# plt.show()
# for i in num_chains:
#     print(f"Final distance of Chain {i+1} = {all_distances[i][-1]}")

# ####################################################### ACCEPTANCE RATE #####################################################

# plt.figure(figsize=(10, 6))
# iterations = range(0, int(num_iterations/num_chains), interval_to_calculate_acceptance_rate)
# for chain in range(num_chains):
#     plt.plot(iterations, (all_acceptance_rates[chain]))

# plt.xlabel('Iteration')
# plt.ylabel('Acceptance Rate (%)')
# plt.grid(True)
# plt.show()