In [None]:
import runner_tools as rt 
from importlib import reload

# Re-import runner_tools so that edits made to the module are picked up
# without restarting the kernel (trailing ';' suppresses the module repr).
reload(rt);

# Manual run

In [None]:
# One entry per study, keyed by study name; each entry carries that study's population size.
studies_dict = dict(
    majuro=dict(pop_size=27797),
    study_2=dict(pop_size=50000),
)

# NOTE(review): this is a reference to the module-level defaults, not a copy —
# confirm nothing downstream mutates `params` in place.
params = rt.DEFAULT_PARAMS

In [None]:
# Single run with the default model config and the studies/parameters defined above.
# `do_df` is indexed by column name and plotted in the next cell.
model, do_df = rt.model_single_run(rt.DEFAULT_MODEL_CONFIG, studies_dict, params)

In [None]:
# Plot the 'ltbi_propXmajuro' output of the single run.
# Trailing ';' keeps the Axes repr out of the cell output (same convention as the other plotting calls).
do_df[['ltbi_propXmajuro']].plot();

# Calibration and full runs

In [None]:
# Metropolis settings, forwarded to rt.run_metropolis_calibration
tune, draws = 10, 100

# Full-run settings: `burn_in` draws are dropped from the start of each chain
# before plotting, and both values are forwarded to rt.run_full_runs
burn_in, full_runs_samples = 50, 50

In [None]:
# Build the calibration object from the same config, studies and parameters as the manual run.
# Later cells read `bcm.targets` and `bcm.priors` from it.
bcm = rt.get_bcm_object(rt.DEFAULT_MODEL_CONFIG, studies_dict, params)

In [None]:
# Run the Metropolis calibration with the settings above.
# `idata` is handled as an ArviZ InferenceData by the plotting helpers below.
idata = rt.run_metropolis_calibration(bcm, draws=draws, tune=tune)

In [None]:
# Run the full model runs from the calibration output.
# `unc_df` is what the uncertainty plots further down read (indexed by output name).
full_runs, unc_df = rt.run_full_runs(bcm, idata, burn_in, full_runs_samples)

# Visualise traces and posteriors

In [None]:
from pathlib import Path
import arviz as az


def make_post_mc_plots(idata, burn_in, output_folder=None):
    """Plot post-burn-in traces and posterior densities for a calibration run.

    Args:
        idata: ArviZ inference data with a ``posterior`` group
        burn_in: Number of draws to drop from the start of the chains before plotting
        output_folder: Optional folder. When given, the figures are saved under
            ``<output_folder>/mc_outputs`` and closed; otherwise they are left
            open so the notebook displays them inline.

    Raises:
        ValueError: If ``burn_in`` is negative or leaves no draws to plot.
    """
    # Imported here so the helper does not depend on a later notebook cell
    # having imported pyplot (it is called before that cell on a fresh run).
    import matplotlib.pyplot as plt

    # Global ArviZ setting: make sure all parameters are included in trace plots
    az.rcParams["plot.max_subplots"] = 60

    output_folder_path = None
    if output_folder:
        output_folder_path = Path(output_folder) / "mc_outputs"
        output_folder_path.mkdir(exist_ok=True, parents=True)

    # Read the chain length from the group that is actually plotted
    chain_length = idata.posterior.sizes["draw"]
    if not 0 <= burn_in < chain_length:
        raise ValueError(
            f"burn_in must be between 0 and {chain_length - 1} (got {burn_in})"
        )

    # Discard burn-in
    burnt_idata = idata.sel(draw=range(burn_in, chain_length))

    # Traces (after burn-in)
    az.plot_trace(burnt_idata, figsize=(16, 3.0 * len(idata.posterior)), compact=False)
    plt.subplots_adjust(hspace=0.7)
    if output_folder_path is not None:
        plt.savefig(output_folder_path / "mc_traces_postburnin.jpg", facecolor="white", bbox_inches="tight")
        plt.close()

    # Posteriors (excluding burn-in)
    az.plot_posterior(burnt_idata)
    if output_folder_path is not None:
        plt.savefig(output_folder_path / "mc_posteriors_postburnin.png", facecolor="white", bbox_inches="tight")
        plt.close()

    # NOTE: pre-burn-in traces, ESS and R-hat outputs are not produced here.
    # If they are needed, az.plot_trace(idata), az.ess(burnt_idata) and
    # az.rhat(burnt_idata) are the starting points.

In [None]:
# No output_folder is passed, so the figures are shown inline and nothing is written to disk.
make_post_mc_plots(idata, burn_in)

# Plot outputs with uncertainty

In [None]:
from copy import copy

def plot_model_fit_with_uncertainty(axis, uncertainty_df, output_name, bcm, include_legend=True):
    """Draw one model output with its uncertainty bands on the given axes.

    Plots the median line, the interquartile band and the 95% band, and
    overlays the observed data when ``output_name`` is one of ``bcm.targets``.

    Args:
        axis: Matplotlib axes to draw on
        uncertainty_df: Indexable by output name; each entry must have the
            columns 0.025, 0.25, 0.5, 0.75 and 0.975, with the x-values as index
        output_name: Name of the output to plot (also used as the y-axis label)
        bcm: Object whose ``targets`` mapping is keyed by output name, each
            target exposing its observations as ``.data``
        include_legend: Whether to add a legend to ``axis``
    """
    df = uncertainty_df[output_name]

    # Observations, when this output has a target
    if output_name in bcm.targets:
        t = copy(bcm.targets[output_name].data)
        axis.scatter(list(t.index), t, marker=".", color="black", label="observations", zorder=11, s=5.0)

    colour = (0.2, 0.2, 0.8)
    time = df.index

    axis.plot(time, df[0.5], color=colour, zorder=10, label="model (median)")

    axis.fill_between(
        time,
        df[0.25],
        df[0.75],
        color=colour,
        alpha=0.5,
        edgecolor=None,
        label="model (IQR)",
    )
    axis.fill_between(
        time,
        df[0.025],
        df[0.975],
        color=colour,
        alpha=0.3,
        edgecolor=None,
        label="model (95% CI)",
    )

    # Anchor the y-axis at zero for this particular output
    if output_name == "transformed_random_process":
        axis.set_ylim((0.0, axis.get_ylim()[1]))

    axis.set_ylabel(output_name)

    if include_legend:
        # Attach the legend to the axes that were passed in, not to whichever
        # axes pyplot currently considers active.
        axis.legend(markerscale=2.0)

In [None]:
import matplotlib.pyplot as plt

# Plot every output that has a calibration target. The names are taken from
# `bcm.targets` (keyed by output name) so the cell does not depend on a
# separate `targets` variable being left over in the kernel.
selected_outputs = list(bcm.targets.keys())

for output in selected_outputs:
    fig, ax = plt.subplots()
    plot_model_fit_with_uncertainty(ax, unc_df, output, bcm)


In [None]:
import numpy as np


def plot_post_prior_comparison(
    idata: az.InferenceData,
    req_vars: list,  # List[str]
    priors: list,  # prior objects exposing logpdf()
    req_grid=None,
    req_size=None,
) -> plt.Figure:
    """Plot comparison of calibration posterior estimates
    for parameters against their prior distributions.

    Each panel shows the posterior density drawn by ArviZ, with the prior
    density (``exp(prior.logpdf(x))``) shaded in grey over the panel's x-range.

    Args:
        idata: Calibration inference data
        req_vars: Names of the parameters to plot
        priors: Prior distributions for the parameters, in the same order as req_vars
        req_grid: Dimensions of the subplot grid (defaults to a single row)
        req_size: Figure size request

    Returns:
        The figure. In a notebook, end the calling line with ``;`` if the
        figure should not be echoed a second time as the cell's result.
    """
    grid = req_grid or [1, len(req_vars)]
    axes = az.plot_density(
        idata,
        var_names=req_vars,
        shade=0.3,
        grid=grid,
        figsize=req_size or None,
        hdi_prob=1.0,
    )

    # Match priors to panels by parameter name; position is only a fallback
    prior_by_name = dict(zip(req_vars, priors))

    for i_ax, ax in enumerate(axes.ravel()):
        param = ax.title.get_text().split("\n")[0]
        if not param:
            continue  # panel of the grid that holds no parameter

        distri = prior_by_name.get(param)
        if distri is None and i_ax < len(priors):
            distri = priors[i_ax]
        if distri is None:
            continue

        # Truncated-normal priors are not overlaid. The class is matched by
        # name because no `esp` alias is imported in this notebook.
        if type(distri).__name__ == "TruncNormalPrior":
            continue

        x_vals = np.linspace(*ax.get_xlim(), 50)
        y_vals = np.exp(distri.logpdf(x_vals))
        ax.fill_between(x_vals, y_vals, color="k", alpha=0.2, linewidth=2)

    # tight_layout() returns None, so apply it first and return the figure itself
    figure = axes.ravel()[0].figure
    figure.tight_layout()
    return figure

In [None]:
# Compare every prior held by `bcm` with its posterior, laid out on a 3 x 4 grid of panels.
plot_post_prior_comparison(idata, list(bcm.priors.keys()), list(bcm.priors.values()), req_grid=[3, 4])