In [None]:
from jax import numpy as jnp
import pandas as pd

from summer2 import CompartmentalModel, Stratification, Multiply
from summer2.parameters import Parameter, DerivedOutput, Function, Time

In [None]:
def build_hiv_model(
    config: dict,
    stratify: bool,
) -> CompartmentalModel:
    """Build a minimal one-compartment population model for comparison runs.

    The model has a single "Susceptible" compartment, no infectious
    compartments, and a replacement birth flow named "recruitment" that
    targets "Susceptible". It is intended to be built with and without an
    activity stratification so the two "total" outputs can be compared.

    Args:
        config: Model settings; the "end_time" and "total_population" keys
            are read.
        stratify: When True, stratify "Susceptible" by "activity" into the
            strata "High" and "Low".

    Returns:
        The (not yet run) model, with a derived output "total" requested
        over the "Susceptible" compartment.
    """
    base_compartments = ("Susceptible",)
    new_model = CompartmentalModel(
        times=(0., config["end_time"]),
        compartments=base_compartments,
        infectious_compartments=(),
    )
    starting_population = {"Susceptible": config["total_population"]}
    new_model.set_initial_population(distribution=starting_population)

    # Investigation note: enabling these universal death flows was observed
    # to make no difference, so they remain disabled.
    # new_model.add_universal_death_flows(
    #     "non_aids_mortality",
    #     1. / Parameter("expectancy_at_debut"),
    # )
    new_model.add_replacement_birth_flow("recruitment", "Susceptible")

    # Investigation note: toggling this stratification is what changes the
    # outcome between the two runs.
    if stratify:
        new_model.stratify_with(
            Stratification("activity", ["High", "Low"], base_compartments)
        )

    new_model.request_output_for_compartments("total", base_compartments)
    return new_model

In [None]:
# Structural settings consumed by build_hiv_model.
model_config = dict(
    total_population=1e4,  # initial size of the "Susceptible" compartment
    end_time=100.,         # model times span (0., end_time)
)

# Parameter values passed to model.run(). In this notebook,
# "expectancy_at_debut" is only referenced by the commented-out death flow
# inside build_hiv_model.
parameters = dict(
    expectancy_at_debut=35.,
)

In [None]:
# Run the model once with and once without the activity stratification and
# keep each run's "total" derived output, keyed by the stratify flag.
outputs = {}
for stratify in (True, False):
    hiv_model = build_hiv_model(model_config, stratify=stratify)
    hiv_model.run(parameters=parameters)
    derived_outputs = hiv_model.get_derived_outputs_df()
    outputs[stratify] = derived_outputs["total"]

In [None]:
# Side-by-side "total" series: one column per stratify flag (True / False).
pd.DataFrame(outputs)