In [None]:
import multiprocessing as mp

# Worker processes are started with "spawn" (this notebook previously used
# "forkserver"). force=True makes the cell safe to re-run in the same kernel:
# without it, set_start_method raises RuntimeError once a start method is set.
mp.set_start_method("spawn", force=True)

import pytensor

# Point PyTensor's C++ compilation at clang++.
# NOTE(review): absolute, machine-specific path — adjust on other hosts.
pytensor.config.cxx = "/usr/bin/clang++"

In [None]:
from pathlib import Path

import arviz as az
import matplotlib.pyplot as plt
import pandas as pd
from estival import targets as est
from estival.model import BayesianCompartmentalModel

from tbh import model as tbm
from tbh import plotting as pl
from tbh import runner_tools as rt

In [None]:
# Calibration targets: one observation per output, all in the same year, each
# wrapped in an estival NormalTarget with the given standard deviation.
target_year = 2020

# output name -> (observed value, standard deviation)
target_specs = {
    "tb_prevalence_per100k": (600, 100.),
    "tbi_prevalence_perc": (40, 5.),
    "perc_prev_subclinical": (50, 5.),
}

targets = [
    est.NormalTarget(
        name=name,
        data=pd.Series(data=[value], index=[target_year]),
        stdev=stdev,
    )
    for name, (value, stdev) in target_specs.items()
]

In [None]:
# Build the TB model. `tbm` comes from the imports cell; no re-import needed.
# get_parameters_and_priors returns parameter values, their priors, and
# tv_params (presumably time-varying parameters — confirm in runner_tools),
# which are passed to the model builder together with the default config.
params, priors, tv_params = rt.get_parameters_and_priors()
model_config = rt.DEFAULT_MODEL_CONFIG
model = tbm.get_tb_model(model_config, tv_params)

In [None]:
# Bundle the model, its parameters, the priors and the calibration targets into
# an estival BayesianCompartmentalModel; it is the object that is calibrated
# and used for the full runs and plots below.
bcm = BayesianCompartmentalModel(model, params, priors, targets)

In [None]:
# Metropolis config (passed to rt.run_metropolis_calibration)
n_cores = 1  # Requesting multiple cores won't work on a mac
tune = 500
draws = 2000

# Full runs config
# Discard the first half of the draws as burn-in (1000 with draws=2000).
burn_in = draws // 2
full_runs_samples = 1000

# Fail fast here rather than with empty selections in the diagnostics cell.
assert 0 <= burn_in < draws, "burn_in must leave at least one draw to analyse"

In [None]:
# Calibrate the model with the Metropolis sampler using the settings from the
# config cell; the resulting inference data feeds the diagnostics below.
idata = rt.run_metropolis_calibration(
    bcm,
    draws=draws,
    tune=tune,
    cores=n_cores,
)


In [None]:
# --- Convergence diagnostics --------------------------------------------------
# R-hat over the whole chains, and over only the draws kept after burn-in.
n_draws = idata.sample_stats.sizes["draw"]
rhats = az.rhat(idata)
burnt_rhats = az.rhat(idata.sel(draw=range(burn_in, n_draws)))
print(f"Max R_hat for full chains: {rhats.to_array().max().item()}")
print(f"Max R_hat for burnt chains: {burnt_rhats.to_array().max().item()}")

# --- Full model runs from the post-burn-in posterior --------------------------
full_runs, unc_df = rt.run_full_runs(bcm, idata, burn_in, full_runs_samples)

# --- Figures ------------------------------------------------------------------
# All figures are written to this folder; create it so savefig does not fail
# with FileNotFoundError on a fresh checkout.
output_folder_path = Path.cwd() / "test_outputs"
output_folder_path.mkdir(parents=True, exist_ok=True)

pl.plot_traces(idata, burn_in, output_folder_path)

pl.plot_post_prior_comparison(
    idata,
    burn_in,
    list(bcm.priors.keys()),
    list(bcm.priors.values()),
    n_col=4,
    output_folder_path=output_folder_path,
)

# One model-fit figure per calibration target, saved and then closed so
# figures do not accumulate in memory.
fit_x_min = 2010  # first year shown on the model-fit plots
selected_outputs = bcm.targets.keys()

for output in selected_outputs:
    fig, ax = plt.subplots()
    pl.plot_model_fit_with_uncertainty(ax, unc_df, output, bcm, x_min=fit_x_min)
    fig.savefig(output_folder_path / f"{output}.jpg", facecolor="white", bbox_inches="tight")
    plt.close(fig)