In [None]:
from importlib import reload
import model as m

# Re-import the local `model` module so edits to model.py are picked up without restarting the kernel
m = reload(m)

In [None]:
# Model-level settings passed to m.get_tb_model: simulation window bounds and a seed value
model_config = dict(
    start_time=1850,
    end_time=2050,
    seed=100,
)

In [None]:
# One entry per study; each study name is used as the "X<study>" suffix of the
# study-specific parameter names (e.g. transmission_rateXmajuro).
studies_dict = {
    study_name: {"pop_size": pop_size}
    for study_name, pop_size in [("majuro", 27797), ("study_2", 50000)]
}

In [None]:
# Build the TB model for every study in studies_dict with the shared model_config
model = m.get_tb_model(model_config, studies_dict)

In [None]:
# Hand-picked baseline parameter values, also passed to BayesianCompartmentalModel in the calibration section
params = dict(
    # Study-specific parameters (suffix "X<study>" matches the keys of studies_dict)
    transmission_rateXmajuro=10,
    transmission_rateXstudy_2=10,
    lifelong_activation_riskXmajuro=.15,
    lifelong_activation_riskXstudy_2=.10,
    prop_early_among_activatorsXmajuro=.90,
    prop_early_among_activatorsXstudy_2=.90,
    # Universal parameters (shared by all studies)
    current_passive_detection_rate=1.,
    mean_duration_early_latent=.5,
    rr_reinfection_latent_late=.2,
    rr_reinfection_recovered=1.,
    self_recovery_rate=.2,
    tb_death_rate=.2,
    tx_duration=.5,
    tx_prop_death=.04,
)
# Run the model once with the hand-picked baseline parameter values in `params`
model.run(params)

In [None]:
# Derived outputs of the baseline run; columns are named "<output>X<study>" (e.g. ltbi_propXmajuro)
do_df = model.get_derived_outputs_df()
# Optional quick checks across both studies:
# do_df[['ltbi_propXmajuro', 'ltbi_propXstudy_2']].plot()
# do_df[['populationXmajuro', 'populationXstudy_2']].plot()

In [None]:
# Baseline-run LTBI proportion for Majuro
do_df.loc[:, ["ltbi_propXmajuro"]].plot()

# Calibration

In [None]:
from estival import priors as esp
from estival import targets as est
from estival.model import BayesianCompartmentalModel

In [None]:
def get_priors(studies_dict, lifelong_sd=.01, transmission_range=(1., 15.)):
    """Build the list of estival priors used for calibration.

    The list starts with the universal priors (shared by all studies) and the
    hyper-prior, followed by the study-specific priors for every study in
    ``studies_dict``. Each study's lifelong activation risk is a truncated
    normal centred on the shared hyper-prior ``hyper_mean_lifelong``, which is
    what links the studies to one another.

    Args:
        studies_dict: Mapping whose keys are study names; each name is used as
            the "X<study>" suffix of the study-specific parameter names.
        lifelong_sd: Standard deviation of each study's lifelong activation
            risk around the shared hyper-mean (default 0.01, the value that was
            previously hard-coded).
        transmission_range: (lower, upper) bounds of the uniform prior on each
            study's transmission rate (default 1 to 15, as previously hard-coded).

    Returns:
        List of estival prior objects, universal priors first.
    """
    # Shared hyper-prior on the mean lifelong activation risk. Hyper-priors on
    # its SD and on prop_early_among_activators are not used yet; the SD is
    # fixed through `lifelong_sd` instead.
    hyper_mean_lifelong = esp.UniformPrior("hyper_mean_lifelong", [0., 1.])

    # Universal priors and the hyper-prior itself
    priors = [
        esp.UniformPrior("current_passive_detection_rate", [.1, 10.]),
        hyper_mean_lifelong,
    ]

    # Study-specific priors
    for study in studies_dict:
        priors.extend(
            [
                esp.UniformPrior(f"transmission_rateX{study}", list(transmission_range)),
                # Linked across studies through the shared hyper-mean
                esp.TruncNormalPrior(
                    f"lifelong_activation_riskX{study}", hyper_mean_lifelong, lifelong_sd, [0., 1.]
                ),
            ]
        )
    return priors

In [None]:
priors = get_priors(studies_dict)
# Names of the calibrated parameters, in the same order as `priors`
prior_list = list(map(lambda prior: prior.name, priors))

In [None]:
# Display the names of the calibrated parameters
prior_list

In [None]:
import pandas as pd
# Calibration targets: single Majuro observations. Each target's likelihood
# stdev is given its own uniform prior rather than a fixed value.
targets = [
    est.NormalTarget(
        output_name,
        data=pd.Series(data=[observed], index=[year]),
        stdev=esp.UniformPrior(sd_name, sd_bounds),
    )
    for output_name, observed, year, sd_name, sd_bounds in (
        ("ltbi_propXmajuro", .38, 2018, "std_ltbi", [.001, .1]),
        ("tb_prevalence_per100kXmajuro", 1366, 2018, "std_tb", [10., 250.]),
        ("raw_notificationsXmajuro", 100, 2015, "std_not", [1., 25.]),
    )
]

In [None]:
import nevergrad as ng
from estival.wrappers.nevergrad import optimize_model

def calibrate_with_opti(bcm, n_iter, opt_class=ng.optimizers.NGOpt):
    """Optimise the calibration objective of ``bcm`` with nevergrad.

    Args:
        bcm: BayesianCompartmentalModel to optimise.
        n_iter: Budget passed to the runner's ``minimize``.
        opt_class: nevergrad optimiser class (NGOpt by default).

    Returns:
        Element 1 of the recommendation's ``value`` (presumably the
        keyword-argument dict of parameter values; used as ``mle_params``).
    """
    recommendation = optimize_model(bcm, opt_class=opt_class).minimize(n_iter)
    return recommendation.value[1]

import matplotlib.pyplot as plt

def visualise_mle_fit(bcm, mle_params):
    """Run the model at ``mle_params`` and plot each target against its model output.

    Draws one figure per calibration target, containing the observed data
    points and the matching derived output of the run.

    Args:
        bcm: BayesianCompartmentalModel whose ``targets`` is indexed by output name.
        mle_params: Parameter values to run, e.g. the result of ``calibrate_with_opti``.
    """
    print("Running model with MLE parameters...")
    res = bcm.run(mle_params)
    print("... run completed.")
    for target_name in bcm.targets:
        # Explicit axes: both series go on this panel regardless of pyplot's current-figure state
        _, ax = plt.subplots()
        bcm.targets[target_name].data.plot(ax=ax, style='.', label="observations")
        res.derived_outputs[target_name].plot(ax=ax, label="model")
        ax.set_title(target_name)
        ax.legend()

In [None]:
# Build a fresh TB model for calibration (NOTE(review): presumably so the baseline
# run above leaves no state in the object wrapped by the BCM — confirm)
model = m.get_tb_model(model_config, studies_dict)
# Combine the model, baseline params, priors and targets into estival's Bayesian model
bcm = BayesianCompartmentalModel(model, params, priors, targets)

In [None]:
# mle_params = calibrate_with_opti(bcm, 200)

In [None]:
# visualise_mle_fit(bcm, mle_params)

In [None]:

from estival.wrappers import pymc as epm

import pymc as pm

In [None]:
# Sample the posterior with PyMC's DEMetropolisZ step over the variables estival
# registers for `bcm`. The PyMC context is named `pymc_model` (not `model`) so the
# TB model object built earlier is not silently overwritten.
with pm.Model() as pymc_model:
    variables = epm.use_model(bcm)
    idata = pm.sample(
        step=[pm.DEMetropolisZ(variables)],
        draws=20000,
        tune=2000,
        cores=4,
        chains=4,
    )

# Visualise traces and posteriors

In [None]:
# Post-processing settings: draws discarded per chain as burn-in, and the number of
# posterior samples re-run through the model for the uncertainty plots
burn_in, full_runs_samples = 10000, 1000

In [None]:
from pathlib import Path
import arviz as az


def _save_and_close(output_folder_path, filename):
    """Save the current pyplot figure to output_folder_path/filename and close it; no-op when no folder is set."""
    if output_folder_path is None:
        return
    plt.savefig(output_folder_path / filename, facecolor="white", bbox_inches='tight')
    plt.close()


def make_post_mc_plots(idata, burn_in, output_folder=None):
    """Plot MCMC traces and posterior densities after discarding burn-in.

    Args:
        idata: arviz InferenceData returned by the sampler.
        burn_in: Number of initial draws per chain to discard.
        output_folder: If given, figures are saved under ``<output_folder>/mc_outputs``
            and closed; otherwise they stay open for inline display.

    Raises:
        ValueError: If ``burn_in`` is negative or would discard every draw.
    """
    chain_length = idata.sample_stats.sizes['draw']
    # Validate before any side effects; an out-of-range burn-in otherwise
    # surfaces as a KeyError (negative) or an empty selection (too large).
    if not 0 <= burn_in < chain_length:
        raise ValueError(
            f"burn_in must be in [0, {chain_length}) for chains of {chain_length} draws; got {burn_in}"
        )

    # Raise arviz's subplot cap so all parameters appear in the trace plots
    # (this modifies a global arviz setting)
    az.rcParams["plot.max_subplots"] = 60

    output_folder_path = None
    if output_folder:
        output_folder_path = Path(output_folder) / "mc_outputs"
        output_folder_path.mkdir(exist_ok=True, parents=True)

    burnt_idata = idata.sel(draw=range(burn_in, chain_length))  # Discard burn-in

    # Traces (after burn-in)
    az.plot_trace(burnt_idata, figsize=(16, 3.0 * len(idata.posterior)), compact=False)
    plt.subplots_adjust(hspace=.7)
    _save_and_close(output_folder_path, "mc_traces_postburnin.jpg")

    # Posteriors (after burn-in)
    az.plot_posterior(burnt_idata)
    _save_and_close(output_folder_path, "mc_posteriors_postburnin.png")

In [None]:
# Trace and posterior plots, displayed inline (no output folder)
make_post_mc_plots(idata, burn_in=burn_in)

# Plot outputs with uncertainty

In [None]:
from estival.sampling import tools as esamp

def extract_sample_subset(idata, n_samples, burn_in, chain_filter: list = None):
    """Randomly draw posterior samples after discarding burn-in.

    Args:
        idata: arviz InferenceData from the sampler.
        n_samples: Number of samples, passed to ``az.extract(num_samples=...)``.
        burn_in: Number of initial draws per chain to discard.
        chain_filter: Optional list of chain coordinate labels to keep
            (e.g. ``[0, 2]``); ``None`` keeps all chains. This parameter was
            previously accepted but ignored.

    Returns:
        The result of ``az.extract`` on the retained chains and draws.
    """
    chain_length = idata.sample_stats.sizes['draw']
    burnt_idata = idata.sel(draw=range(burn_in, chain_length))  # Discard burn-in
    if chain_filter is not None:
        # e.g. to exclude chains that did not converge
        burnt_idata = burnt_idata.sel(chain=chain_filter)

    return az.extract(burnt_idata, num_samples=n_samples)

In [None]:


# Re-run the model for a random subset of post-burn-in posterior samples
full_run_params = extract_sample_subset(idata, full_runs_samples, burn_in)
full_runs = esamp.model_results_for_samples(full_run_params, bcm, include_extras=False)

# Quantiles of the run results (presumably per derived output and time); these are
# exactly the quantile columns plot_model_fit_with_uncertainty reads.
unc_df = esamp.quantiles_for_results(full_runs.results, [.025, .25, .5, .75, .975])


In [None]:
from copy import copy

def plot_model_fit_with_uncertainty(axis, uncertainty_df, output_name, bcm, include_legend=True):
    """Plot the median, IQR and 95% interval of one model output on ``axis``.

    If ``output_name`` is a calibration target of ``bcm``, the observations are
    overlaid as black points.

    Args:
        axis: Matplotlib axes to draw on.
        uncertainty_df: Quantile table; ``uncertainty_df[output_name]`` must have
            columns 0.025, 0.25, 0.5, 0.75 and 0.975 indexed by time (e.g. the
            result of ``esamp.quantiles_for_results``).
        output_name: Name of the derived output to plot; also used as y-label.
        bcm: BayesianCompartmentalModel providing the calibration targets.
        include_legend: Whether to draw a legend on ``axis``.
    """
    quantiles = uncertainty_df[output_name]
    time = quantiles.index
    colour = (0.2, 0.2, 0.8)

    if output_name in bcm.targets:
        observations = bcm.targets[output_name].data  # only read, so no copy needed
        axis.scatter(list(observations.index), observations, marker=".", color='black', label='observations', zorder=11, s=5.)

    axis.plot(time, quantiles[0.5], color=colour, zorder=10, label="model (median)")
    axis.fill_between(
        time,
        quantiles[0.25], quantiles[0.75],
        color=colour,
        alpha=0.5,
        edgecolor=None,
        label="model (IQR)",
    )
    axis.fill_between(
        time,
        quantiles[0.025], quantiles[0.975],
        color=colour,
        alpha=0.3,
        edgecolor=None,
        label="model (95% CI)",
    )

    # NOTE(review): no output of this name is visible in this notebook; kept from the
    # original, which floors this output's y-axis at zero — confirm before removing.
    if output_name == "transformed_random_process":
        axis.set_ylim((0., axis.get_ylim()[1]))

    axis.set_ylabel(output_name)

    if include_legend:
        # Legend on the given axes, not on pyplot's "current" axes (matters for multi-panel figures)
        axis.legend(markerscale=2.)

In [None]:
import matplotlib.pyplot as plt

selected_outputs = [target.name for target in targets]

# One figure per calibration target: posterior uncertainty bands plus observations
for output_name in selected_outputs:
    output_figure, output_axis = plt.subplots()
    plot_model_fit_with_uncertainty(output_axis, unc_df, output_name, bcm)


In [None]:
import numpy as np


def plot_post_prior_comparison(
    idata: az.InferenceData,
    req_vars: list,  # List[str]
    priors: list,  # estival priors, in the same order as req_vars
    req_grid=None,
    req_size=None,
) -> plt.Figure:
    """Plot calibration posterior densities with the prior densities overlaid.

    Args:
        idata: Calibration inference data.
        req_vars: Names of the parameters to plot.
        priors: Prior objects for the parameters, in the same order as ``req_vars``.
        req_grid: Subplot grid dimensions; defaults to one row with one panel per parameter.
        req_size: Figure size; a falsy value lets arviz choose.

    Returns:
        The matplotlib figure containing the panels.
    """
    grid = req_grid if req_grid else [1, len(req_vars)]
    size = req_size if req_size else None
    axes = np.ravel(az.plot_density(idata, var_names=req_vars, shade=0.3, grid=grid, figsize=size, hdi_prob=1.))

    # Match priors to panels by parameter name (panel title); fall back to panel position
    prior_by_name = dict(zip(req_vars, priors))
    for i_ax, ax in enumerate(axes):
        param = ax.title.get_text().split("\n")[0]
        if not param:
            continue  # unused grid cell
        distri = prior_by_name.get(param, priors[i_ax] if i_ax < len(priors) else None)
        # TruncNormal priors are skipped, as in the original (presumably because their
        # mean is a hyper-prior, so there is no fixed density to draw)
        if distri is None or isinstance(distri, esp.TruncNormalPrior):
            continue
        x_vals = np.linspace(*ax.get_xlim(), 50)
        y_vals = np.exp(distri.logpdf(x_vals))
        ax.fill_between(x_vals, y_vals, color="k", alpha=0.2, linewidth=2)

    figure = axes[0].figure
    figure.tight_layout()  # returns None, so return the figure itself
    return figure

In [None]:
# Posterior vs prior for every calibrated parameter; trailing ";" suppresses the return-value repr
plot_post_prior_comparison(
    idata,
    req_vars=list(bcm.priors),
    priors=list(bcm.priors.values()),
    req_grid=[3, 4],
);