In [1]:
from estival.model import BayesianCompartmentalModel
from estival import targets as est

import tbh.runner_tools as rt
from tbh.model import get_tb_model
from tbh.plotting import visualise_mle_params, plot_single_fit, title_lookup

In [2]:
params, priors, tv_params = rt.get_parameters_and_priors()


In [None]:
priors = priors[1:4]

[UniformPrior bg_mixing {bounds: (0.01, 0.05)},
 UniformPrior a_spread {bounds: (5.0, 15.0)},
 UniformPrior pc_strength {bounds: (0.1, 2.0)}]

In [None]:
model = get_tb_model(rt.DEFAULT_MODEL_CONFIG, tv_params)
bcm = BayesianCompartmentalModel(model, params, priors, rt.targets)

In [None]:
import nevergrad as ng
from estival.wrappers.nevergrad import optimize_model

In [None]:
opt_class = ng.optimizers.NGOpt
orunner = optimize_model(bcm, opt_class=opt_class)
rec = orunner.minimize(2000)

In [None]:
mle_params = rec.value[1]
mle_params
res = bcm.run(mle_params)

In [None]:
plot_single_fit(bcm, mle_params)

In [None]:
mle_params

In [None]:
visualise_mle_params(bcm.priors, mle_params)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines

def _age_group_label(agegroups, index):
    """Return the x-tick label for ``agegroups[index]``.

    Labelling rules:
    - ``"a_b"`` -> ``"a-b"`` (explicit range);
    - ``"a+"`` -> ``"a+"`` (open-ended group);
    - bare lower bound ``"a"`` -> ``"a-(next lower bound - 1)"``, or ``"a+"``
      when it is the last group.
    """
    age = agegroups[index]
    if "_" in age:
        return age.replace("_", "-")
    if age.endswith("+"):
        return age
    if index < len(agegroups) - 1:
        # The group ends just before the next group's lower bound.
        next_lower = agegroups[index + 1].split("_")[0].rstrip("+")
        return f"{age}-{int(next_lower) - 1}"
    return f"{age}+"


def plot_age_spec_tbi_prev(derived_outputs, bcm, agegroups=("3_9", "10", "15+")):
    """
    Plot modelled vs observed age-specific TST positivity for a single model run.

    For each age group ``age`` the output/target named ``f"tst_posXage_{age}_perc"``
    is used. The observed value is the first data point of that target, and the
    modelled value is read from ``derived_outputs`` at the same index entry (year).

    Args:
        derived_outputs: Derived outputs of one run (e.g. ``res.derived_outputs``),
            indexable by output name and then by year via ``.loc``.
        bcm: Calibration model whose ``targets`` mapping holds one target per
            age-specific output.
        agegroups: Age-group suffixes ordered youngest to oldest. Defaults to
            ``("3_9", "10", "15+")``; e.g. ``["3_9", "10", "15", "65"]`` also works.

    Returns:
        The matplotlib Figure.
    """
    agegroups = list(agegroups)

    model_values, targets, years = [], [], []
    for age in agegroups:
        output_name = f"tst_posXage_{age}_perc"
        target_data = bcm.targets[output_name].data

        # Compare the model at the (first) year for which an observation exists.
        year = target_data.index[0]
        model_values.append(derived_outputs[output_name].loc[year])
        targets.append(target_data.iloc[0])
        years.append(year)

    x_positions = list(range(len(agegroups)))
    fig, ax = plt.subplots(figsize=(8, 5))

    # Both series are drawn as markers; the legend is built from these labelled
    # artists so its symbols match what is actually plotted.
    ax.scatter(x_positions, model_values, color='lightblue', edgecolor='navy', alpha=0.6, label='Modelled')
    ax.scatter(x_positions, targets, color='red', marker='x', s=80, label='Observed')

    ax.set_xticks(x_positions)
    ax.set_xticklabels([_age_group_label(agegroups, i) for i in x_positions])

    ax.set_xlabel("Age group (years)")
    ax.set_ylabel(title_lookup["tst_pos_perc"])
    # Report every distinct target year (in first-seen order), not just the last one.
    year_text = ", ".join(str(y) for y in dict.fromkeys(years))
    ax.set_title(f"Observed vs modelled TST positivity fraction by age group in {year_text}")
    ax.legend(loc='best')
    ax.grid(alpha=0.3)

    ax.set_ylim(bottom=0)  # ensures y-axis starts at 0

    fig.tight_layout()
    return fig

fig =plot_age_spec_tbi_prev(res.derived_outputs, bcm)

In [None]:
res.derived_outputs["tst_posXage_3_9_perc"]