In [None]:
import pandas as pd
import numpy as np
pd.options.plotting.backend = "plotly"
from jax import numpy as jnp

from summer2 import CompartmentalModel, Stratification, Multiply
from summer2.parameters import Parameter, DerivedOutput, Function

In [None]:
def build_sir_model(
    config: dict,
) -> CompartmentalModel:
    """Build the shared SIR skeleton used by every scenario in this notebook.

    The model runs from time 0 to ``config["end_time"]``. It starts with
    ``config["seed"]`` infectious and the remainder of ``config["population"]``
    susceptible. Infectious people recover at fractional rate
    ``1 / Parameter("infectious_period")``, and the infectious compartment is
    requested as the "prevalence" derived output. No infection flow is added
    here: the caller attaches either a frequency- or a density-dependent one.

    Args:
        config: Mapping with the keys "compartments", "end_time", "population"
            and "seed". The compartment names used below ("susceptible",
            "infectious", "recovered") presumably must appear in
            ``config["compartments"]``.

    Returns:
        The configured, not-yet-run CompartmentalModel.
    """
    model = CompartmentalModel(
        times=(0., config["end_time"]),
        compartments=config["compartments"],
        infectious_compartments=("infectious",),
    )

    # Everyone who is not part of the initial infectious seed starts susceptible.
    seed_size = config["seed"]
    starting_population = {
        "susceptible": config["population"] - seed_size,
        "infectious": seed_size,
    }
    model.set_initial_population(distribution=starting_population)

    # Recovery does not depend on the transmission or stratification scenario.
    recovery_rate = 1. / Parameter("infectious_period")
    model.add_transition_flow(
        name="recovery",
        fractional_rate=recovery_rate,
        source="infectious",
        dest="recovered",
    )

    model.request_output_for_compartments("prevalence", "infectious")
    return model

In [None]:
def build_location_strat(compartments, mixing_matrix):
    """Create a two-group ("pop1"/"pop2") stratification with a mixing matrix.

    The population split is parameterised: "pop1" receives
    ``Parameter("prop1")`` and "pop2" receives ``1 - Parameter("prop1")``, so
    the two shares always sum to one.

    Args:
        compartments: Names of the compartments to stratify.
        mixing_matrix: Between-group mixing matrix, passed through unchanged.
            Presumably it is square, with one row and column per stratum
            (2x2 here), ordered as ["pop1", "pop2"]. TODO: confirm this
            against summer2.

    Returns:
        The configured Stratification, ready for ``model.stratify_with``.
    """
    share_first = Parameter("prop1")
    location_strat = Stratification("groups", ["pop1", "pop2"], compartments)
    location_strat.set_population_split({"pop1": share_first, "pop2": 1. - share_first})
    location_strat.set_mixing_matrix(mixing_matrix)
    return location_strat

In [None]:
model_config = {
    "end_time": 20.,
    "population": 1.0,
    "seed": 0.01,
    "compartments": ("susceptible", "infectious", "recovered"),
}
parameters = {
    "risk_per_contact": 1.0,
    "infectious_period": 2.0,
    "prop1": 0.7,
}

transmission_options = ("freq", "dens")
stratification_options = ("stratified", "unstratified")

# Name of the CompartmentalModel method that adds the infection flow for each
# transmission assumption.
infection_flow_methods = {
    "freq": "add_infection_frequency_flow",
    "dens": "add_infection_density_flow",
}

# Mixing matrices used when the model is stratified into "pop1"/"pop2".
# Each is chosen so that the stratified model should reproduce the unstratified one:
# - frequency dependence: every row equals the population split
#   [prop1, 1 - prop1] (proportionate mixing). Each group's force of infection
#   is then the population-weighted prevalence. This is derived from
#   `parameters`, not hard-coded, so it stays in sync with "prop1".
# - density dependence: all ones, so every group is exposed to the total
#   infectious count.
# NOTE(review): assumes summer2 forms each group's force of infection as
# mixing_matrix @ (per-group infectious values), i.e. rows index the group
# being infected. Confirm against the summer2 docs.
prop1_value = parameters["prop1"]
stratified_mixing_matrices = {
    "freq": jnp.array(
        [
            [prop1_value, 1.0 - prop1_value],
            [prop1_value, 1.0 - prop1_value],
        ]
    ),
    "dens": jnp.ones((2, 2)),
}

# Collect one prevalence series per scenario, then build the DataFrame once.
# This avoids filling an empty frame column by column.
prevalence_by_scenario = {}
for trans_opt in transmission_options:
    for stratify_opt in stratification_options:
        model = build_sir_model(model_config)

        add_infection_flow = getattr(model, infection_flow_methods[trans_opt])
        add_infection_flow(
            name="infection",
            contact_rate=Parameter("risk_per_contact"),
            source="susceptible",
            dest="infectious",
        )

        if stratify_opt == "stratified":
            mix_strat = build_location_strat(
                model_config["compartments"],
                stratified_mixing_matrices[trans_opt],
            )
            model.stratify_with(mix_strat)

        model.run(parameters=parameters)
        # Single-level "<transmission>_<stratification>" names (e.g.
        # "freq_stratified") keep the columns flat for the plotly plotting
        # backend selected in the imports cell.
        prevalence_by_scenario[f"{trans_opt}_{stratify_opt}"] = (
            model.get_derived_outputs_df()["prevalence"]
        )

outputs = pd.DataFrame(prevalence_by_scenario)

In [None]:
# One line per scenario. With the plotly backend, extra keyword arguments are
# forwarded to plotly express, whose wide-form defaults label the y axis
# "value" and the legend "variable".
outputs.plot(
    title="Prevalence by transmission assumption and stratification",
    labels={"value": "prevalence", "variable": "scenario"},
)