In [None]:
import numpy as np
import matplotlib.pyplot as plt

from summer2 import CompartmentalModel, Stratification, Multiply
from summer2.parameters import Parameter, DerivedOutput

In [None]:
def build_hiv_model(
    config: dict,
) -> CompartmentalModel:
    """Build a Susceptible/Infectious/AIDS model stratified by sexual activity.

    The unstratified model has frequency-dependent infection (Susceptible ->
    Infectious), progression (Infectious -> AIDS), a universal non-AIDS death
    flow on every compartment, AIDS-specific mortality and a replacement birth
    flow into Susceptible. Every compartment is then split into "High" and "Low"
    activity strata, with a mixing matrix whose rows are identical (each row
    holds the share of partnerships formed with each stratum).

    Args:
        config: Structural settings. Keys read:
            "end_time": final integration time (integration starts at 0).
            "total_population": initial total population.
            "infectious_seed": initial number infectious (subtracted from the
                total to give the initial susceptibles).
            "high_prop": proportion of the population in the "High" stratum.
            "high_partner_change_rate", "low_partner_change_rate": multipliers
                applied to the infection flow in each stratum.
            "high_partner_change_prop", "low_partner_change_prop" (optional):
                share of all partnerships formed with each stratum, used as the
                mixing-matrix columns. If absent they are derived from the
                partner change rates weighted by stratum size (equation 8.20).

    Returns:
        The stratified, un-run model. The parameters "contact_rate",
        "infectious_period", "expectancy_at_debut" and "aids_period" must be
        supplied when it is run.
    """
    # Unstratified model structure
    compartments = (
        "Susceptible",
        "Infectious",
        "AIDS",
    )
    model = CompartmentalModel(
        times=(0., config["end_time"]),
        compartments=compartments,
        infectious_compartments=("Infectious",),
    )
    # The seed is taken out of the total; AIDS is not listed (presumably it
    # starts at zero -- confirm against summer2 defaults).
    model.set_initial_population(
        distribution={
            "Susceptible": config["total_population"] - config["infectious_seed"],
            "Infectious": config["infectious_seed"],
        }
    )
    model.add_infection_frequency_flow(
        name="infection",
        contact_rate=Parameter("contact_rate"),
        source="Susceptible",
        dest="Infectious",
    )
    # Rates are the reciprocal of the mean durations supplied as parameters
    model.add_transition_flow(
        name="progression",
        fractional_rate=1. / Parameter("infectious_period"),
        source="Infectious",
        dest="AIDS",
    )
    model.add_universal_death_flows(
        "non_aids_mortality",
        1. / Parameter("expectancy_at_debut"),
    )
    model.add_replacement_birth_flow(
        "recruitment",
        "Susceptible",
    )
    model.add_death_flow(
        "aids_mortality",
        1. / Parameter("aids_period"),
        "AIDS",
    )

    # Activity rate stratification
    high_prop = config["high_prop"]
    low_prop = 1. - high_prop
    activity_strat = Stratification(
        "activity",
        ["High", "Low"],
        compartments,
    )
    activity_strat.set_population_split(
        {
            "High": high_prop,
            "Low": low_prop,
        }
    )
    # Recruits enter the strata in the same proportions as the initial split
    activity_strat.set_flow_adjustments(
        "recruitment",
        adjustments={
            "High": high_prop,
            "Low": low_prop,
        },
    )

    # Share of all partnerships formed with each stratum (equation 8.20):
    # partner change rate weighted by stratum size, normalised. Use the
    # caller's precomputed values when provided.
    if "high_partner_change_prop" in config:
        high_partnership_prop = config["high_partner_change_prop"]
    else:
        high_partnerships = config["high_partner_change_rate"] * high_prop
        low_partnerships = config["low_partner_change_rate"] * low_prop
        high_partnership_prop = high_partnerships / (high_partnerships + low_partnerships)
    low_partnership_prop = config.get("low_partner_change_prop", 1. - high_partnership_prop)

    # Both rows are identical, so the partner mix is the same whichever
    # stratum the row refers to.
    mixing_matrix = np.repeat(
        np.array([[high_partnership_prop, low_partnership_prop]]),
        2,
        axis=0,
    )

    # Scale infection in each stratum by its partner change rate
    # (equivalently, each row of the mixing matrix could be scaled instead).
    activity_strat.set_flow_adjustments(
        "infection",
        {
            "High": Multiply(config["high_partner_change_rate"]),
            "Low": Multiply(config["low_partner_change_rate"]),
        },
    )
    activity_strat.set_mixing_matrix(
        mixing_matrix,
    )
    model.stratify_with(activity_strat)

    # Outputs
    model.request_output_for_compartments(
        "infectious",
        ["Infectious"],
        save_results=False,
    )
    model.request_output_for_compartments(
        "total",
        compartments,
        save_results=False,
    )
    model.request_function_output(
        "Prevalence",
        DerivedOutput("infectious") / DerivedOutput("total") * 100.,
    )
    # Incidence of HIV infection counts new infections, i.e. the "infection"
    # flow (Susceptible -> Infectious), not progression to AIDS.
    model.request_output_for_flow(
        "abs_incidence",
        "infection",
    )
    model.request_function_output(
        "Incidence",
        DerivedOutput("abs_incidence") / DerivedOutput("total") * 100.,
    )
    model.request_output_for_flow(
        "mortality",
        "aids_mortality",
    )
    model.request_cumulative_output(
        "Cumulative mortality",
        "mortality",
    )

    return model

In [None]:
# Structural settings for the activity-stratified HIV model
model_config = dict(
    high_prop=0.15,
    high_partner_change_rate=8.,
    low_partner_change_rate=0.2,
    total_population=1e4,
    infectious_seed=100.,
    end_time=100.,
)

# Share of all partnerships formed with each activity group (equation 8.20):
# each group's partner change rate weighted by its population share, normalised.
model_config["high_partner_change_prop"] = (
    model_config["high_partner_change_rate"] * model_config["high_prop"]
) / (
    model_config["high_partner_change_rate"] * model_config["high_prop"]
    + model_config["low_partner_change_rate"] * (1. - model_config["high_prop"])
)
model_config["low_partner_change_prop"] = 1. - model_config["high_partner_change_prop"]

# Epidemiological parameters passed to the model when it is run
parameters = dict(
    infectious_period=9.,
    expectancy_at_debut=35.,
    aids_period=1.,
    contact_rate=0.05,
)

In [None]:
# Build the stratified model from the structural settings, integrate it with
# the epidemiological parameters, then collect the requested derived outputs.
hiv_model = build_hiv_model(config=model_config)
hiv_model.run(parameters=parameters)
outputs = hiv_model.get_derived_outputs_df()

In [None]:
# Left panel: model outputs; right panel: the reference figure image for comparison
fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(14, 5))

# Derive the time axis from the simulated period instead of hard-coding it
end_time = model_config["end_time"]

# Line style per output; dict order fixes plotting and legend order
output_styles = {"Prevalence": "--", "Incidence": ":"}
for output, linestyle in output_styles.items():
    left_ax.plot(
        hiv_model.times,
        outputs[output],
        label=output,
        linestyle=linestyle,
        color="k",
    )
left_ax.set_xlim(0., end_time)
left_ax.set_xlabel("Time (years)")
# Upper bound is inclusive of end_time so no tick falls outside the x-limits
left_ax.set_xticks(range(0, int(end_time) + 1, 10))
left_ax.legend(loc=2)
left_ax.set_ylim(0., 14.)
left_ax.set_ylabel("Prevalence (%) and incidence (%/year) of HIV infection")
left_ax.spines.top.set_visible(False)

# Secondary y-axis for cumulative AIDS deaths, shown in thousands
twin_ax = left_ax.twinx()
twin_ax.plot(
    hiv_model.times,
    outputs["Cumulative mortality"] / 1e3,
    label="Cumulative AIDS deaths",
    color="k",
)
twin_ax.legend(loc=1)
twin_ax.set_ylabel("Cumulative AIDS deaths (thousands)")
twin_ax.set_ylim(bottom=0.)
twin_ax.set_yticks(range(0, 8))
twin_ax.spines.top.set_visible(False)

image = plt.imread("./figures/fig_8_20a.jpg")
right_ax.axis("off")
right_ax.imshow(image, aspect="auto");