In [None]:
import torch
import sbi
import numpy as np
from scipy.integrate import odeint
from sbi.inference import SNRE, prepare_for_sbi, simulate_for_sbi
from sbi import utils as utils
from sbi import analysis as analysis
from sbi.inference.base import infer
import random
import time

In [None]:
# Ground-truth parameters (k1, k2). Each one enters `repressilator` as an
# exponent (10 ** k), so the tuple length must equal the number of parameters
# that model unpacks from theta: changing the length requires changing the model.
true_params = 0., 0.
prior_min = -3          # lower bound of the uniform prior on each parameter
prior_max = 3           # upper bound of the uniform prior on each parameter
num_timesteps = 100     # number of time points in each simulated trajectory
num_rounds = 3          # how many rounds of SNRE (5-10 is good)
num_simulations = 100   # how many simulations in each round

# Seed every RNG this notebook may draw from (Python, NumPy, PyTorch) so that
# Restart & Run All reproduces the same simulations, training and samples.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
def repressilator(variables, t, theta):
    """Right-hand side of the three-gene repressilator ODE system.

    Written in the ``f(y, t, *args)`` form used by ``scipy.integrate.odeint``.

    Args:
        variables: the six state values ordered ``(m1, p1, m2, p2, m3, p3)``
            (m_i / p_i pairs; mRNA / protein in the usual repressilator notation).
        t: current time; not used by the equations.
        theta: ``(k1, k2)``. The repressor ``p2`` acting on ``m1`` is scaled by
            ``10 ** k1``, and ``p3`` acting on ``m2`` by ``10 ** k2``. The scale
            of ``p1`` acting on ``m3`` is fixed at ``10 ** 0``.

    Returns:
        List of the six time derivatives, in the same order as ``variables``.
    """
    m1, p1, m2, p2, m3, p3 = variables
    k1, k2 = theta

    max_production = 10 ** 3   # numerator of the Hill repression term
    basal = 10 ** 0            # constant production added to every m_i
    relax = 10 ** 0            # rate at which p_i follows m_i

    def repressed(log_scale, repressor):
        # Hill repression (coefficient 2) by a repressor scaled by 10 ** log_scale.
        return max_production / (1 + (10 ** log_scale * repressor) ** 2)

    def follow(protein, mrna):
        return -relax * (protein - mrna)

    dm1 = -m1 + repressed(k1, p2) + basal
    dm2 = -m2 + repressed(k2, p3) + basal
    dm3 = -m3 + repressed(0, p1) + basal
    return [dm1, follow(p1, m1), dm2, follow(p2, m2), dm3, follow(p3, m3)]

In [None]:
# Time grid shared by every simulation: num_timesteps evenly spaced points on [0, 100].
t = np.linspace(0, 100, num_timesteps)
def my_simulator(theta):
    """Simulate the repressilator for a single parameter vector.

    Args:
        theta: the two model parameters ``(k1, k2)``. Any sequence of scalars
            works: a tuple, a NumPy array, or a torch tensor (presumably a
            tensor row when called through ``prepare_for_sbi`` — confirm).

    Returns:
        float32 torch tensor of shape ``[len(t) * 6]``: the trajectory of the
        six state variables, flattened time-major (all six values at t[0],
        then at t[1], ...).
    """
    # Convert the parameters to plain Python floats once. Otherwise the ODE
    # right-hand side, which odeint evaluates many times, would do torch-scalar
    # arithmetic (in float32 when theta comes from a float32 prior sample) on
    # every call, mixed into odeint's float64 state.
    # float() also works on tensors that require grad, where .numpy() would not.
    theta = tuple(float(value) for value in theta)
    initial_conditions = np.array([0, 2, 0, 1, 0, 3], dtype=np.float32)
    solution = odeint(repressilator, initial_conditions, t, args=(theta,))
    # odeint returns shape (len(t), 6); flatten row by row into one vector.
    return torch.tensor(solution, dtype=torch.float32).flatten()
# The specific observation we want to focus the inference on: one simulated
# trajectory at the ground-truth parameters (num_timesteps points for each of
# the 6 state variables, flattened).
x_o = my_simulator(true_params)
num_dim = len(true_params)
# Uniform prior on [prior_min, prior_max] for every parameter (each parameter
# enters the model as a power of 10).
prior = utils.BoxUniform(low=prior_min * torch.ones(num_dim), high=prior_max * torch.ones(num_dim))
simulator, prior = prepare_for_sbi(my_simulator, prior)

# Sequential NRE. Round one simulates from the prior. Every later round
# simulates from the previous round's posterior, conditioned on x_o.
# Each round appends its simulations to `inference` and retrains.
# num_rounds and num_simulations come from the configuration cell; they are
# deliberately not redefined here.
posteriors = []
proposal = prior
inference = SNRE(prior=prior)
for _ in range(num_rounds):
    theta, x = simulate_for_sbi(simulator, proposal, num_simulations=num_simulations)
    # In `SNLE` and `SNRE`, you should not pass the `proposal` to `.append_simulations()`
    density_estimator = inference.append_simulations(theta, x).train()
    posterior = inference.build_posterior(density_estimator)
    posteriors.append(posterior)
    proposal = posterior.set_default_x(x_o)

# Build the final posterior with an explicit MCMC sampler and draw samples
# conditioned on x_o. No estimator is passed, so sbi presumably uses the most
# recently trained one (its documented default) — confirm for the installed version.
sampling_algorithm = "mcmc"
mcmc_method = "slice_np"  # or nuts, or hmc
num_posterior_samples = 500
posterior = inference.build_posterior(sample_with=sampling_algorithm, mcmc_method=mcmc_method)
posterior_samples_sre = posterior.sample((num_posterior_samples,), x=x_o)