In [None]:
import multiprocessing as mp
import platform

if platform.system() != "Windows":
    # Use forkserver worker processes on POSIX (Windows only supports "spawn").
    # force=True makes the cell safe to re-run in the same kernel; without it
    # a second call raises RuntimeError("context has already been set").
    mp.set_start_method('forkserver', force=True)
    
import numpy as np
import pandas as pd

from tb_incubator.constants import set_project_base_path
from tb_incubator.constants import compartments, infectious_compartments, model_times, age_strata
from tb_incubator.model import build_model
from tb_incubator.input import load_targets, load_param_info

from estival import targets as est
from estival import priors as esp
from estival.model import BayesianCompartmentalModel
from estival.wrappers import pymc as epm
import pymc as pm

import xarray as xr
import arviz as az
from estival.sampling.tools import likelihood_extras_for_idata
from estival.utils.parallel import map_parallel

# Optional: uncomment to switch pandas plotting to interactive plotly figures.
# pd.options.plotting.backend = "plotly"

# tb_incubator package directory, relative to the notebook's working directory.
PROJECT_BASE_PATH = "../tb_incubator/"
project_paths = set_project_base_path(PROJECT_BASE_PATH)


In [2]:
param_info = load_param_info()
params = param_info["value"]

# Assemble the model from the project constants and the parameter values
# returned by load_param_info().
build_args = (compartments, infectious_compartments, age_strata, params, model_times)
model, desc = build_model(*build_args)

# Run once with those values; the result is not stored, so this only checks
# that the model runs before calibration.
model.run(params)

In [3]:
all_targets = load_targets()

# Fail here, with a clear message, if a target series used by the calibration
# targets below is missing (they read the "prevalence" and "notif" keys).
_missing_targets = [key for key in ("prevalence", "notif") if key not in all_targets]
if _missing_targets:
    raise KeyError(f"load_targets() is missing target series: {_missing_targets}")

In [4]:
def _truncnormal_target(target_name, data, rel_sd_max=0.1, sd_min=0.1):
    """Build a non-negative truncated-normal target with a uniform dispersion prior.

    Args:
        target_name: model output name the target is compared against (also the
            key used in ``bcm.targets``); the dispersion prior is named
            ``f"{target_name}_dispersion"``.
        data: observed target series.
        rel_sd_max: dispersion upper bound as a fraction of ``data.max()``.
        sd_min: dispersion lower bound.

    Returns:
        An ``est.TruncatedNormalTarget`` truncated to ``(0, inf)``.

    Raises:
        ValueError: if the upper bound does not exceed ``sd_min`` (e.g. when the
            data maximum is <= 1), which would give an empty uniform prior.
    """
    sd_max = data.max() * rel_sd_max
    # `not >` also catches NaN maxima (all-missing data).
    if not sd_max > sd_min:
        raise ValueError(
            f"Dispersion prior for {target_name!r} is empty: upper bound "
            f"{sd_max} <= lower bound {sd_min}"
        )
    return est.TruncatedNormalTarget(
        target_name,
        data,
        (0.0, np.inf),
        esp.UniformPrior(f"{target_name}_dispersion", (sd_min, sd_max)),
    )


targets = [
    _truncnormal_target("prevalence", all_targets["prevalence"]),
    # The notification series is stored under the shorter "notif" key.
    _truncnormal_target("notification", all_targets["notif"]),
]

priors = [
    esp.UniformPrior("contact_rate", (0.01, 3.0)),
    esp.UniformPrior("self_recovery_rate", (0.05, 0.30)),
    esp.UniformPrior("screening_scaleup_shape", (0.01, 0.3)),
    esp.UniformPrior("screening_inflection_time", (1990.0, 2018.0)),
    esp.UniformPrior("time_to_screening_end_asymp", (0.1, 20.0)),
]

bcm = BayesianCompartmentalModel(model, params, priors, targets)

In [None]:
SAMPLING_SEED = 42  # fixed so the chains are reproducible across re-runs

# Bound to `pymc_model` (not `model`) so the compartmental model returned by
# build_model earlier in the notebook is not shadowed.
with pm.Model() as pymc_model:
    variables = epm.use_model(bcm)
    idata = pm.sample(
        step=[pm.DEMetropolis(variables)],
        draws=4000,
        tune=0,
        cores=8,
        chains=8,
        random_seed=SAMPLING_SEED,
    )

In [None]:
# Per-parameter posterior summary statistics and convergence diagnostics.
az.summary(data=idata)

In [None]:
# Uncompacted trace plot; figure height scales with the number of posterior
# variables (~3.2 in per variable).
az.plot_trace(idata, figsize=(16, 3.2 * len(idata.posterior)), compact=False);

In [None]:
# Marginal posterior for each parameter; `;` suppresses the Axes-array repr.
az.plot_posterior(idata);

In [None]:
# Per-draw likelihood quantities; the `logposterior` and `loglikelihood`
# columns are used by the cells below.
likelihood_df = likelihood_extras_for_idata(idata, bcm)
likelihood_df.head()

In [None]:
# Examine the performance of chains over time
ldf_pivot = likelihood_df.reset_index(level="chain").pivot(columns=["chain"])
ldf_pivot["logposterior"].plot()

In [None]:
# Rank draws by log-posterior; the top-ranked draw is the sample-based MAP.
ldf_sorted = likelihood_df.sort_values(by="logposterior", ascending=False)
best_draw = ldf_sorted.index[0]
posterior_df = idata.posterior.to_dataframe()
map_params = posterior_df.loc[best_draw].to_dict()

map_params

In [None]:
# Sanity check: re-evaluating the likelihood at the MAP parameters should
# match the value recorded for that draw.
recomputed_ll = bcm.loglikelihood(**map_params)
recorded_ll = ldf_sorted.iloc[0]["loglikelihood"]
recomputed_ll, recorded_ll

In [32]:
# Re-run the model at the MAP parameter set; its derived outputs are plotted
# against the targets below.
map_res = bcm.run(map_params)

In [None]:
variable = "prevalence"

# map_params is the highest-logposterior draw, so this is a MAP (not MLE) fit.
# Overlay the observed target points on the same axes.
ax = pd.Series(map_res.derived_outputs[variable]).plot(title=f"{variable} (MAP)")
bcm.targets[variable].data.plot(ax=ax, style='.');

In [None]:
variable = "notification"

# map_params is the highest-logposterior draw, so this is a MAP (not MLE) fit.
# Overlay the observed target points on the same axes.
ax = pd.Series(map_res.derived_outputs[variable]).plot(title=f"{variable} (MAP)")
bcm.targets[variable].data.plot(ax=ax, style='.');

In [17]:

N_POSTERIOR_SAMPLES = 400
SUBSAMPLE_SEED = 42  # makes the random posterior subsample reproducible

# Random subset of posterior draws (chains combined) used for the uncertainty runs.
sample_idata = az.extract(
    idata,
    num_samples=N_POSTERIOR_SAMPLES,
    rng=np.random.default_rng(SUBSAMPLE_SEED),
)
# One row of parameter values per sample; the chain/draw columns are dropped.
samples_df = sample_idata.to_dataframe().drop(columns=["chain", "draw"])

In [19]:
def run_sample(idx_sample):
    """Run the calibrated model for one posterior sample.

    Args:
        idx_sample: an ``(index, parameters)`` pair, as yielded by
            ``DataFrame.iterrows()``.

    Returns:
        ``(index, result)``, where ``result`` is ``bcm.run(parameters)``; the
        index is passed through so results can be matched back to samples.
    """
    sample_index, sample_params = idx_sample
    return sample_index, bcm.run(sample_params)

In [None]:
# Run the model for every retained posterior sample; list() lets the results
# be iterated more than once below.
sample_res = list(map_parallel(run_sample, samples_df.iterrows(), n_workers=4))

# The DataArray below labels results purely by position, so check that they
# came back in submission order and share the MAP run's output columns.
if [idx for idx, _ in sample_res] != list(samples_df.index):
    raise RuntimeError("map_parallel results are not in submission order")
if not all(res.derived_outputs.columns.equals(map_res.derived_outputs.columns) for _, res in sample_res):
    raise RuntimeError("sample runs returned derived outputs with different columns")

# Build a DataArray out of our results, then assign coords for indexing
xres = xr.DataArray(
    np.stack([res.derived_outputs for _, res in sample_res]),
    dims=["sample", "time", "variable"],
)
xres = xres.assign_coords(
    sample=sample_idata.coords["sample"],
    time=map_res.derived_outputs.index,
    variable=map_res.derived_outputs.columns,
)

quantiles = (0.01, 0.05, 0.25, 0.5, 0.75, 0.95, 0.99)
xquantiles = xres.quantile(quantiles, dim=["sample"])
# Wide frame indexed by time, with (variable, quantile) column levels.
uncertainty_df = (
    xquantiles.to_dataframe(name="value")
    .reset_index()
    .set_index("time")
    .pivot(columns=("variable", "quantile"))["value"]
)

In [None]:
variable = "notification"

# DataFrame.plot returns an Axes (the old name `fig` was misleading). Draw the
# posterior quantile lines, then the MAP trajectory (dashed) and the observed
# targets (dots) on the same axes.
ax = uncertainty_df[variable].plot(title=variable, alpha=0.7)
pd.Series(map_res.derived_outputs[variable]).plot(ax=ax, style='--')
bcm.targets[variable].data.plot(ax=ax, style='.', color="black", ms=3, alpha=0.8);