In [None]:
import matplotlib.pyplot as plt
from jax import numpy as jnp
import pandas as pd

from summer2 import CompartmentalModel, Stratification, AgeStratification
from summer2.parameters import Parameter, DerivedOutput

In [None]:
# Natural history of infection (durations in days)
latent_period = 8.0
infectious_period = 7.0
r0 = 13.0

# Demography and initial conditions
life_expectancy = 70.0  # in years; converted to a per-day rate below
population = 1e5
seed = 1.0  # number of people starting in the infectious compartment

# Per-day rates derived from the inputs above
recovery = 1.0 / infectious_period
progression = 1.0 / latent_period
beta = r0 / infectious_period
mortality = 1.0 / life_expectancy / 365.0

In [None]:
def build_demog_model(
    vacc_coverage: float,
    pre_vacc_period: float,
    duration: float,
) -> "CompartmentalModel":
    """Build an age-structured SEIR model with births, deaths and newborn vaccination.

    The model runs from ``pre_vacc_period`` years before time zero until
    ``duration`` years after it (both are converted to days here). From time
    zero onwards, a proportion ``vacc_coverage`` of births enters the
    vaccinated susceptible stratum, whose infection flow is scaled to zero.

    The rates ``beta``, ``progression``, ``recovery`` and ``mortality`` and the
    initial conditions ``population`` and ``seed`` are read from the
    notebook-level variables of the same names.

    Args:
        vacc_coverage: Proportion of newborns vaccinated once time > 0,
            between 0 and 1 inclusive.
        pre_vacc_period: Years simulated before time zero.
        duration: Years simulated after time zero.

    Returns:
        The stratified model (not yet run), with the derived outputs
        ``total_population``, ``incidence`` and ``incidence_rate`` requested.

    Raises:
        ValueError: If ``vacc_coverage`` is a plain number outside [0, 1].
    """
    # Fail fast on an impossible coverage: a value outside [0, 1] would make one
    # of the two birth adjustments below negative. Only plain numbers are
    # checked; any other kind of object is passed through untouched.
    if isinstance(vacc_coverage, (int, float)) and not 0. <= vacc_coverage <= 1.:
        raise ValueError(
            f"vacc_coverage must be between 0 and 1, got {vacc_coverage!r}"
        )

    # Create the SEIR model with demographic processes
    compartments = (
        "Susceptible",
        "Pre-infectious",
        "Infectious",
        "Immune"
    )
    model = CompartmentalModel(
        times=(
            -pre_vacc_period * 365.,
            duration * 365.,
        ),
        compartments=compartments,
        infectious_compartments=["Infectious"],
        timestep=0.01,
    )
    model.set_initial_population(
        distribution={
            "Susceptible": population - seed,
            "Infectious": seed,
        }
    )
    model.add_infection_frequency_flow(
        name="infection",
        contact_rate=beta,
        source="Susceptible",
        dest="Pre-infectious"
    )
    model.add_transition_flow(
        name="progression",
        fractional_rate=progression,
        source="Pre-infectious",
        dest="Infectious"
    )
    model.add_transition_flow(
        name="recovery",
        fractional_rate=recovery,
        source="Infectious",
        dest="Immune",
    )
    model.add_universal_death_flows(
        "universal_death",
        death_rate=mortality,
    )
    # Births use the same rate as deaths and all enter the susceptible compartment
    model.add_crude_birth_flow(
        "births",
        mortality,
        "Susceptible",
    )

    # Outputs: population size, new infectious cases, and cases per 100,000
    model.request_output_for_compartments(
        "total_population",
        compartments,
    )

    # Incidence is tracked as the flow into the infectious compartment.
    # NOTE(review): confirm whether summer2 scales this flow output by the model
    # timestep; the notebook plots the resulting rate as a daily figure.
    model.request_output_for_flow(
        name="incidence",
        flow_name="progression",
    )

    model.request_function_output(
        name="incidence_rate",
        func=DerivedOutput("incidence") / DerivedOutput("total_population") * 1e5
    )

    # Age stratification, with breakpoints at 0 and 15
    age_strat = AgeStratification(
        "age",
        [0., 15.],
        compartments,
    )

    # No mixing matrix is set on the age stratification. Passing one through
    # age_strat.set_mixing_matrix (tried with all-ones, and with
    # [[1.8e-8, 5.02e-9], [5.02e-9, 5.02e-9]]) ran but produced strange outputs.

    model.stratify_with(age_strat)

    # Vaccination stratification, applied to the susceptible compartment only
    vacc_strat = Stratification(
        "vaccination",
        ["vaccinated", "unvaccinated"],
        ["Susceptible"],
    )

    # Start everyone out unvaccinated
    vacc_strat.set_population_split(
        {
            "vaccinated": 0.,
            "unvaccinated": 1.,
        }
    )

    # Vaccinated are completely immune
    vacc_strat.set_flow_adjustments(
        flow_name="infection",
        adjustments={
            "vaccinated": 0.,
            "unvaccinated": 1.,
        },
    )

    # Split of births between the two strata: all unvaccinated up to and
    # including time zero, then `vacc_coverage` vaccinated afterwards.
    def step_up(time, values):
        return jnp.where(time > 0., vacc_coverage, 0.)

    def step_down(time, values):
        return jnp.where(time > 0., 1. - vacc_coverage, 1.)

    vacc_strat.set_flow_adjustments(
        flow_name="births",
        adjustments={
            "vaccinated": step_up,
            "unvaccinated": step_down,
        },
    )
    model.stratify_with(vacc_strat)

    return model

In [None]:
fig, (left_ax, right_ax) = plt.subplots(nrows=1, ncols=2, figsize=(14, 5))
coverage = 0.72
duration = 10.

# Vaccination coverage is fixed when the model is built; it should become a
# model parameter, but that has not been achieved yet.
vacc_model = build_demog_model(
    vacc_coverage=coverage,
    pre_vacc_period=100.,
    duration=duration,
)
vacc_model.run()
outputs = vacc_model.get_derived_outputs_df()["incidence_rate"]

# Model times are in days; plot them in years relative to time zero
years = outputs.index / 365.
left_ax.plot(years, outputs, color="k")

left_ax.set(
    xlabel="Years since the introduction of vaccination of newborns",
    xlim=(-50., duration),
    ylabel="Daily number of new infectious per 100,000",
    ylim=(0., 5.),
)
left_ax.grid(axis="y")