In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from summer2 import CompartmentalModel, Stratification, Multiply
from summer2.parameters import Parameter, DerivedOutput

In [None]:
def _get_activity_mixing(
    high_prop: float,
    low_partner_change: float,
    average_partner_change: float,
) -> tuple:
    """Derive the high-stratum partner change rate and the activity mixing matrix.

    Mixing is proportionate: the chance a partnership is formed with a given
    stratum is that stratum's share of all partner changes, so both rows of
    the returned matrix are identical.

    Args:
        high_prop: Proportion of the population in the "High" stratum.
        low_partner_change: Partner change rate of the "Low" stratum.
        average_partner_change: Population-average partner change rate.

    Returns:
        (high_partner_change, mixing_matrix), where mixing_matrix is a 2x2
        array with rows and columns ordered ("High", "Low").

    Raises:
        ValueError: If high_prop is not in (0, 1], or if the inputs imply a
            negative partner change rate for the "High" stratum.
    """
    if not 0. < high_prop <= 1.:
        raise ValueError(f"high_prop must be in (0, 1], got {high_prop}")
    low_prop = 1. - high_prop

    high_partner_change = (average_partner_change - low_partner_change * low_prop) / high_prop  # Equation 8.15
    if high_partner_change < 0.:
        raise ValueError(
            "low_partner_change * (1 - high_prop) exceeds average_partner_change, "
            "implying a negative partner change rate for the high stratum"
        )
    high_change_rate_abs = high_partner_change * high_prop  # Absolute rate of partner changes for high stratum
    low_change_rate_abs = low_partner_change * low_prop  # Absolute rate of partner changes for low stratum
    # Absolute total rate of partner changes in the population (equals average_partner_change by construction)
    total_change_rate = high_change_rate_abs + low_change_rate_abs

    # The "g" values: each stratum's share of all partner changes
    high_change_prop = high_change_rate_abs / total_change_rate  # Equation 8.20
    low_change_prop = low_change_rate_abs / total_change_rate

    # Proportionate mixing, so the same row is used for both strata (square array)
    mixing_matrix = np.repeat(np.array([[high_change_prop, low_change_prop]]), 2, axis=0)
    return high_partner_change, mixing_matrix


def build_sis_model(
    model_config: dict,
) -> CompartmentalModel:
    """Build an SIS model stratified into "High" and "Low" activity strata.

    Susceptible people become infectious through a frequency-dependent
    infection flow and return to susceptible through a recovery flow. The
    infection flow in each activity stratum is multiplied by that stratum's
    partner change rate, and the strata mix proportionately.

    Args:
        model_config: Settings read from the keys "start_time", "end_time",
            "population", "seed" (initial infectious population), "high_prop"
            (proportion in the "High" stratum), "average_partner_change" and
            "low_partner_change".

    Returns:
        The model, not yet run. Running it requires the parameters
        "contact_rate" and "recovery". It requests the derived outputs
        "Overall", "High" and "Low" (infectious / total population) and
        "incidence" (the infection flow). Intermediate compartment counts are
        not saved.

    Raises:
        ValueError: From the activity mixing calculation, if "high_prop" is
            not in (0, 1] or the partner change settings are inconsistent.
    """
    # Model characteristics
    compartments = (
        "susceptible",
        "infectious",
    )
    analysis_times = (
        model_config["start_time"],
        model_config["end_time"],
    )
    model = CompartmentalModel(
        times=analysis_times,
        compartments=compartments,
        infectious_compartments=["infectious"],
    )
    model.set_initial_population(
        distribution=
        {
            "susceptible": model_config["population"] - model_config["seed"],
            "infectious": model_config["seed"],
        }
    )

    # Transitions
    model.add_infection_frequency_flow(
        name="infection",
        contact_rate=Parameter("contact_rate"),
        source="susceptible",
        dest="infectious",
    )
    model.add_transition_flow(
        name="recovery",
        fractional_rate=Parameter("recovery"),
        source="infectious",
        dest="susceptible",
    )

    # Activity rate stratification
    activity_strat = Stratification(
        "activity",
        ["High", "Low"],
        compartments,
    )

    # All stratification settings come from the model_config argument
    high_prop = model_config["high_prop"]
    low_prop = 1. - high_prop
    low_partner_change = model_config["low_partner_change"]
    high_partner_change, mixing_matrix = _get_activity_mixing(
        high_prop,
        low_partner_change,
        model_config["average_partner_change"],
    )

    activity_strat.set_population_split(
        {
            "High": high_prop,
            "Low": low_prop,
        }
    )

    # The "c" values
    activity_strat.set_flow_adjustments(
        "infection",
        {
            "High": Multiply(high_partner_change),  # Or multiply top row of matrix by this
            "Low": Multiply(low_partner_change),  # Or multiply bottom row of matrix by this
        },
    )

    activity_strat.set_mixing_matrix(mixing_matrix)
    model.stratify_with(activity_strat)

    # Outputs to track
    model.request_output_for_compartments(
        "infectious",
        ["infectious"],
        save_results=False,
    )
    model.request_output_for_compartments(
        "total",
        compartments,
        save_results=False,
    )
    model.request_function_output(
        "Overall",
        DerivedOutput("infectious") /
        DerivedOutput("total")
    )

    # Proportion infectious within each activity stratum
    for stratum in ["High", "Low"]:
        model.request_output_for_compartments(
            f"infectiousX{stratum}",
            ["infectious"],
            strata={"activity": stratum},
            save_results=False,
        )
        model.request_output_for_compartments(
            f"totalX{stratum}",
            compartments,
            strata={"activity": stratum},
            save_results=False,
        )
        model.request_function_output(
            stratum,
            DerivedOutput(f"infectiousX{stratum}") /
            DerivedOutput(f"totalX{stratum}")
        )

    model.request_output_for_flow(
        "incidence",
        "infection",
    )

    return model

In [None]:
# Annual rates divided by 365 because the model's time unit is days
parameters = dict(
    recovery=6. / 365.,
    contact_rate=0.75 / 365.,
)

# Ten-year run of a unit-size population seeded with a tiny infectious fraction
config = dict(
    start_time=0.,
    end_time=10. * 365.,
    population=1.,
    seed=1e-6,
    high_prop=0.02,
    average_partner_change=2.,
    low_partner_change=1.4,
)

sis_model = build_sis_model(config)
sis_model.run(parameters=parameters)

In [None]:
# Rescale time from days to years and express every output as a percentage.
# set_axis returns a new frame, so the frame returned by the model is never
# mutated and re-running this cell cannot rescale the index a second time.
raw_outputs = sis_model.get_derived_outputs_df()
derived_outputs = raw_outputs.set_axis(raw_outputs.index / 365., axis="index") * 100.

In [None]:
# Figure 8.8a: modelled prevalence (left) next to the reference image (right).
fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(14, 5))

# Select the prevalence outputs by name rather than by column position, so the
# plot does not depend on the order in which the model returns its outputs.
prevalence_line_styles = {
    "Overall": "--",
    "High": ":",
    "Low": "-",
}
for output, line_style in prevalence_line_styles.items():
    left_ax.plot(
        derived_outputs.index,
        derived_outputs[output],
        color="k",
        label=output,
        linestyle=line_style,
    )
left_ax.legend()
left_ax.set_title("Prevalence by activity stratum")
left_ax.set_ylabel("Prevalence of infection (%)")
left_ax.set_ylim(0., 50.)
left_ax.set_xlabel("Time, years")
left_ax.set_xlim(0., 10.)
left_ax.spines.top.set_visible(False)
left_ax.spines.right.set_visible(False)

image = plt.imread("./figures/fig_8_8a.jpg")
right_ax.axis("off")
right_ax.imshow(image, aspect="auto");

In [None]:
# Figure 8.8b: incidence per year (daily flow x 365) beside the reference image
fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(14, 5))

annual_incidence = derived_outputs["incidence"] * 365.
left_ax.plot(derived_outputs.index, annual_incidence, color="k")
left_ax.set(
    ylabel="Incidence (%/year)",
    ylim=(0., 50.),
    xlabel="Time, years",
    xlim=(0., 10.),
)
for side in ("top", "right"):
    left_ax.spines[side].set_visible(False)

image = plt.imread("./figures/fig_8_8b.jpg")
right_ax.axis("off")
right_ax.imshow(image, aspect="auto");