In [None]:
import pandas as pd
import numpy as np
pd.options.plotting.backend = "plotly"
from jax import numpy as jnp

from summer2 import CompartmentalModel, Stratification, Multiply
from summer2.parameters import Parameter, DerivedOutput, Function

In [None]:
def build_sir_model(
    config: dict,
    stratify: bool,
) -> CompartmentalModel:
    """Assemble the base (unstratified) compartmental model from a config dict.

    Args:
        config: Must provide the keys "compartments", "end_time",
            "population" and "seed".
        stratify: Part of the signature but not read inside this function.

    Returns:
        The model with an "infection" flow (susceptible -> infectious), a
        "recovery" flow (infectious -> recovered) and an "n_infectious"
        compartment output requested.
    """
    # Model characteristics: the run always starts at time zero
    sir_model = CompartmentalModel(
        times=(0., config["end_time"]),
        compartments=config["compartments"],
        infectious_compartments=("infectious",),
    )

    # Everyone who is not part of the infectious seed starts susceptible
    starting_population = {
        "susceptible": config["population"] - config["seed"],
        "infectious": config["seed"],
    }
    sir_model.set_initial_population(distribution=starting_population)

    # Transitions
    sir_model.add_infection_density_flow(
        name="infection",
        contact_rate=Parameter("contact_rate"),
        source="susceptible",
        dest="infectious",
    )
    # The recovery rate is the reciprocal of the "infectious_period" parameter
    recovery_rate = 1. / Parameter("infectious_period")
    sir_model.add_transition_flow(
        name="recovery",
        fractional_rate=recovery_rate,
        source="infectious",
        dest="recovered",
    )

    # Output tracked for later comparison between model variants
    sir_model.request_output_for_compartments("n_infectious", "infectious")

    return sir_model

In [None]:
def build_simple_strat(compartments):
    """Create the two-stratum ("urban" / "rural") stratification named "activity".

    Args:
        compartments: Passed straight through to ``Stratification`` as the
            compartments argument.

    Returns:
        The ``Stratification`` object, with its population split driven by the
        "urban_prop" parameter and a 2x2 mixing matrix whose entries are all one.
    """
    strat = Stratification(
        "activity",
        ["urban", "rural"],
        compartments,
    )

    # The rural share is whatever remains after the urban share is taken out
    urban_share = Parameter("urban_prop")
    population_split = {
        "urban": urban_share,
        "rural": 1. - urban_share,
    }
    strat.set_population_split(population_split)

    # Identical entries everywhere: no stratum pair is weighted differently
    uniform_mixing = jnp.array([[1., 1.], [1., 1.]])
    strat.set_mixing_matrix(uniform_mixing)

    return strat

In [None]:
# Fixed structural settings shared by both model variants
model_config = dict(
    end_time=20.,
    population=1.,
    seed=0.01,
    compartments=("susceptible", "infectious", "recovered"),
)
# Values supplied at run time for the named Parameter objects
parameters = dict(
    contact_rate=1.,
    infectious_period=2.,
    urban_prop=0.1,
)

# Run the model once with and once without the stratification applied,
# keeping the "n_infectious" series of each run under its boolean flag
infectious_by_flag = {}
for stratify in (True, False):
    heterogeneous_mixing_model = build_sir_model(model_config, stratify)
    if stratify:
        mix_strat = build_simple_strat(model_config["compartments"])
        heterogeneous_mixing_model.stratify_with(mix_strat)

    heterogeneous_mixing_model.run(parameters=parameters)
    derived_outputs = heterogeneous_mixing_model.get_derived_outputs_df()
    infectious_by_flag[stratify] = derived_outputs["n_infectious"]

# Columns are the boolean flags, in the order True then False
outputs = pd.DataFrame(infectious_by_flag)

In [None]:
# Pair the stratified and unstratified values row by row for side-by-side inspection
[*zip(outputs[True], outputs[False])]

In [None]:
# Check, row by row, that the two runs agree to within an absolute tolerance.
# The absolute value matters: without it, any row where the stratified value is
# far *below* the unstratified one gives a negative difference and would
# wrongly be reported as matching.
[abs(i - j) < 0.0001 for i, j in zip(outputs[True], outputs[False])]

In [None]:
# Draw both runs on one figure so any divergence between them is visible
comparison_fig = outputs.plot()
comparison_fig