# ClonalOrigin model and simulation-based inference

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from sbi.utils.torchutils import BoxUniform
from sbi.inference import NPE_C, simulate_for_sbi
from sbi.utils.user_input_checks import (
    check_sbi_inputs,
    process_prior,
    process_simulator,
)
import sys
sys.path.append('../pysimARG')
from clonal_genealogy import ClonalTree
from ClonalOrigin_simulator import ClonalOrigin_simulator

# Device passed to every tensor, the prior and the NPE_C inference object
# created in this notebook.
torch_device = "cpu"

  from .autonotebook import tqdm as notebook_tqdm


## Observation data, simulator and prior

**Simulate observed data**

The observed data are simulated with the following true parameter values:
* $\rho_s = 0.02$,
* $\delta = 300$,
* $\theta_s = 0.05$.

In [None]:
# Seed NumPy's global RNG before building the tree.
# NOTE(review): presumably ClonalTree draws from NumPy's global RNG, so this
# makes the genealogy reproducible — confirm against clonal_genealogy.
np.random.seed(100)
# This single tree is reused for the observed data and for every training
# simulation below.
# NOTE(review): n=10 is presumably the number of sampled genomes — confirm.
tree = ClonalTree(n=10)

In [None]:
# Ground-truth parameter values and sequence length used to generate the
# "observed" data set.
L = 100000
delta = 300
rho_site = 0.02
theta_site = 0.05

# Run the simulator once at the true parameters and flatten the resulting
# summary statistics into a 1-D tensor on the chosen device.
x_o = torch.tensor(
    ClonalOrigin_simulator(
        tree, rho_site, theta_site, L, delta, N=2000, k_vec=[50, 200, 2000]
    ),
    device=torch_device,
).flatten()

# NumPy version of the observation.
x_o_numpy = x_o.cpu().numpy()

In [None]:
# Display the flattened observed summary statistics.
x_o

**Define prior distribution**

We place an independent uniform prior on each parameter, with the following ranges:
* $[0, 0.2]$ for $\rho_s$,
* $[1, 500]$ for $\delta$,
* $[0, 0.2]$ for $\theta_s$.

In [None]:
# Lower and upper bounds of the uniform prior, ordered as
# [rho_s, delta, theta_s].
prior_low = torch.tensor([0.0, 1.0, 0.0], device=torch_device)
prior_high = torch.tensor([0.2, 500.0, 0.2], device=torch_device)

prior = BoxUniform(low=prior_low, high=prior_high, device=torch_device)

**Set the simulator with tensor output**

In [None]:
def simulator(theta):
    """Run one ClonalOrigin simulation and return its summary statistics.

    Args:
        theta: Tensor holding one parameter set ordered as
            ``[rho_s, delta, theta_s]`` (the order of the prior bounds).
            It is flattened first, so a leading batch dimension of size one
            is accepted.

    Returns:
        1-D tensor of summary statistics on ``torch_device``.
    """
    params = theta.reshape(-1)
    sim_rho = params[0].item()
    sim_delta = params[1].item()
    sim_theta = params[2].item()

    # ClonalOrigin_simulator takes theta_site before L and delta, so the
    # arguments are not passed in prior order.
    summary_stats = ClonalOrigin_simulator(
        tree,
        sim_rho,
        sim_theta,
        L,
        sim_delta,
        N=2000,
        k_vec=[50, 200, 2000],
    )

    # Flatten so that simulated data has the same layout as the observation
    # x_o, which is flattened after simulation; this is a no-op when the
    # simulator output is already one-dimensional.
    return torch.tensor(summary_stats, device=torch_device).flatten()

In [None]:
# process_prior returns the processed prior together with the number of
# parameters and a flag saying whether the prior returns NumPy output.
prior, num_parameters, prior_returns_numpy = process_prior(prior)
# NOTE: this rebinds `simulator` to the processed version, so re-running this
# cell without re-running the definition above processes it a second time.
simulator = process_simulator(simulator, prior, prior_returns_numpy)

In [None]:
# Sanity-check that the processed simulator and prior are compatible before
# spending the simulation budget.
check_sbi_inputs(simulator, prior)

## NPE-C

In [None]:
# Run configuration.
seed = 100
simulation_budget = 5000
num_posterior_samples = 1000
learning_rate = 0.0005

# NPE-C with a neural spline flow ("nsf") as density estimator.
inference = NPE_C(
    prior=prior,
    density_estimator="nsf",
    device=torch_device,
)

# Seed both RNGs used downstream (PyTorch and NumPy).
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Draw `simulation_budget` parameter sets from the prior and simulate each one,
# spreading the work over 10 workers.
theta, x = simulate_for_sbi(
    simulator=simulator,
    proposal=prior,
    num_simulations=simulation_budget,
    num_workers=10,
)

In [None]:
# Register the simulated (theta, x) pairs, then train for at most 100 epochs.
trainer = inference.append_simulations(theta, x)
density_estimator = trainer.train(
    max_num_epochs=100,
    learning_rate=learning_rate,
)

# Build the posterior and make the observation its default conditioning value.
posterior = inference.build_posterior(density_estimator)
posterior = posterior.set_default_x(x_o)

In [None]:
# Draw posterior samples conditioned on the observation x_o.
theta_trained = posterior.sample((num_posterior_samples,), x=x_o)
# One row per sample, one column per parameter. The prior has three
# parameters (rho_s, delta, theta_s), so a hard-coded column count of 2 cannot
# hold the samples and makes reshape raise; -1 lets PyTorch infer the count.
theta_trained = theta_trained.reshape((num_posterior_samples, -1))