In [None]:
import pandas as pd
import numpy as np
pd.options.plotting.backend = "plotly"
from jax import numpy as jnp

from summer2 import CompartmentalModel, Stratification, Multiply
from summer2.parameters import Parameter, DerivedOutput, Function

In [None]:
def build_sir_model(
    config: dict,
) -> CompartmentalModel:
    """Create the unstratified SIR model shared by every scenario.

    No infection flow is added here; the caller attaches whichever
    transmission flow it needs after this returns.

    Args:
        config: Mapping that must provide "compartments", "end_time",
            "population" and "seed".

    Returns:
        A CompartmentalModel running from time 0 to config["end_time"], with the
        initial population set, an "infectious" -> "recovered" recovery flow and
        a "prevalence" output over the "infectious" compartment.
    """
    total_pop = config["population"]
    seed_infections = config["seed"]

    sir_model = CompartmentalModel(
        times=(0., config["end_time"]),
        compartments=config["compartments"],
        infectious_compartments=("infectious",),
    )

    # Everyone not seeded as infectious starts out susceptible.
    starting_distribution = {
        "susceptible": total_pop - seed_infections,
        "infectious": seed_infections,
    }
    sir_model.set_initial_population(distribution=starting_distribution)

    # Recovery is common to all scenarios: fractional rate 1 / infectious_period.
    sir_model.add_transition_flow(
        name="recovery",
        fractional_rate=1. / Parameter("infectious_period"),
        source="infectious",
        dest="recovered",
    )

    sir_model.request_output_for_compartments("prevalence", "infectious")
    return sir_model

In [None]:
def build_location_strat(compartments, mixing_value):
    """Split the given compartments into two groups with a uniform mixing matrix.

    Args:
        compartments: Compartment names to stratify.
        mixing_value: Value placed in every cell of the 2x2 mixing matrix.

    Returns:
        A "groups" Stratification with strata "pop1" and "pop2". "pop1"
        receives the "prop1" parameter share of the population and "pop2"
        the remainder.
    """
    strata = ["pop1", "pop2"]
    group_strat = Stratification("groups", strata, compartments)

    # Only one share is a free parameter; the other is its complement.
    pop1_share = Parameter("prop1")
    group_strat.set_population_split(
        {"pop1": pop1_share, "pop2": 1. - pop1_share}
    )

    # Every cell holds the same value, so no group mixes preferentially.
    uniform_row = [mixing_value] * len(strata)
    group_strat.set_mixing_matrix(jnp.array([uniform_row] * len(strata)))
    return group_strat

In [None]:
model_config = {
    "end_time": 20.,
    "population": 1.0,
    "seed": 0.01,
    "compartments": ("susceptible", "infectious", "recovered"),
}
parameters = {
    "risk_per_contact": 1.0,
    "infectious_period": 2.0,
    "prop1": 0.5,
}

bool_options = (True, False)


def run_scenario(
    config: dict,
    params: dict,
    freq_dependent: bool,
    stratify: bool,
) -> pd.Series:
    """Build and run one SIR scenario and return its prevalence series.

    Args:
        config: Model configuration passed to ``build_sir_model``.
        params: Parameter values passed to ``model.run``.
        freq_dependent: Use a frequency-dependent infection flow if True,
            otherwise a density-dependent one.
        stratify: If True, split the model into two groups via
            ``build_location_strat`` with a uniform mixing matrix.

    Returns:
        The "prevalence" column of the model's derived outputs.
    """
    model = build_sir_model(config)

    # Both infection-flow methods take identical arguments, so select the
    # method once instead of duplicating the whole call in two branches.
    add_infection_flow = (
        model.add_infection_frequency_flow
        if freq_dependent
        else model.add_infection_density_flow
    )
    add_infection_flow(
        name="infection",
        contact_rate=Parameter("risk_per_contact"),
        source="susceptible",
        dest="infectious",
    )

    if stratify:
        # NOTE(review): 0.5 vs 1.0 presumably makes the uniformly mixed
        # two-group model reproduce the unstratified dynamics for each
        # transmission type (the comparison this cell sets up) — confirm
        # against summer2's mixing-matrix semantics.
        mixing_value = 0.5 if freq_dependent else 1.0
        model.stratify_with(
            build_location_strat(config["compartments"], mixing_value)
        )

    model.run(parameters=params)
    return model.get_derived_outputs_df()["prevalence"]


# Collect one prevalence series per scenario under a readable label, then
# build the frame in one step instead of filling and flattening an empty one.
prevalence_by_scenario = {}
for freq_dependent in bool_options:
    for stratify in bool_options:
        transmission = "frequency" if freq_dependent else "density"
        structure = "stratified" if stratify else "unstratified"
        prevalence_by_scenario[f"{transmission}, {structure}"] = run_scenario(
            model_config, parameters, freq_dependent, stratify
        )

outputs = pd.DataFrame(prevalence_by_scenario)

In [None]:
# Prevalence over time for every transmission/stratification scenario.
# With the plotly backend set in the import cell, .plot() returns a plotly
# Figure, so labels are applied with update_layout (which returns the figure).
(
    outputs.plot(title="SIR prevalence by transmission type and stratification")
    .update_layout(
        xaxis_title="Time",
        yaxis_title="Prevalence",
        legend_title_text="Scenario",
    )
)