In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from jax import numpy as jnp

from summer2 import CompartmentalModel, Stratification, Multiply
from summer2.parameters import Parameter, DerivedOutput, Function

In [None]:
def build_sis_model(
    config: dict,
) -> CompartmentalModel:
    """Build an activity-stratified compartmental model from a config dict.

    The base model has three compartments: susceptible, infectious and
    recovered. A density-dependent infection flow moves people from
    susceptible to infectious. A fractional recovery flow moves them from
    infectious to recovered. Every compartment is then split into "High" and
    "Low" activity strata.

    NOTE(review): despite the ``sis`` in the name, recovery leads to a
    separate ``recovered`` compartment and nothing flows back to
    ``susceptible``, so the structure as written is SIR-like. Confirm which
    structure is intended.

    Args:
        config: Must contain ``"end_time"`` (upper bound of the time window,
            which starts at 0), ``"population"`` (total starting population)
            and ``"seed"`` (number initially infectious; the remainder starts
            susceptible).

    Returns:
        The stratified ``CompartmentalModel``. It references the parameters
        ``contact_rate``, ``recovery`` and ``high_prop``, which presumably
        must be supplied when the model is run.
    """
    comps = ("susceptible", "infectious", "recovered")

    model = CompartmentalModel(
        times=(0., config["end_time"]),
        compartments=comps,
        infectious_compartments=("infectious",),
    )

    # Everyone not seeded as infectious starts susceptible; no initial value
    # is given for "recovered".
    n_total = config["population"]
    n_seed = config["seed"]
    model.set_initial_population(
        distribution={"susceptible": n_total - n_seed, "infectious": n_seed},
    )

    # Transmission (S -> I), density-dependent, scaled by contact_rate.
    model.add_infection_density_flow(
        name="infection",
        contact_rate=Parameter("contact_rate"),
        source="susceptible",
        dest="infectious",
    )
    # Recovery (I -> R) at the fractional rate given by the recovery parameter.
    model.add_transition_flow(
        name="recovery",
        fractional_rate=Parameter("recovery"),
        source="infectious",
        dest="recovered",
    )

    # Activity-level stratification, applied only after the flows above exist.
    strata = ["High", "Low"]
    activity_strat = Stratification("activity", strata, comps)

    # high_prop of the population goes to "High", the complement to "Low".
    high_share = Parameter("high_prop")
    activity_strat.set_population_split(
        {"High": high_share, "Low": 1. - high_share}
    )

    # Uniform all-ones mixing matrix, one row/column per activity stratum.
    n_strata = len(strata)
    activity_strat.set_mixing_matrix(jnp.ones((n_strata, n_strata)))

    model.stratify_with(activity_strat)
    return model