# Implementing empiric survey data in a model
Next let's feed these empiric data into a model.
The easiest way to do this is to use mixing data collected from a population
that is the same as, or similar to, the one that we're modelling.
If we're doing this, then it's reasonable to just feed the mixing data straight in.
The POLYMOD surveys were performed from late 2005 through much of 2006,
so if we're using the matrix for Great Britain and simulating
an infectious disease in that population, then we should be fine,
provided we also use the age structure for that population.

Let's get started by using the same data as we used in the previous notebook.
Then we'll also use some population age structure data for the same setting,
i.e. the United Kingdom in 2006.

In [None]:
try:
    import google.colab
    %pip install summerepi2
except:
    pass

In [None]:
import pandas as pd
pd.options.plotting.backend = "plotly"
import plotly.graph_objects as go
from jax import numpy as jnp

from summer2 import CompartmentalModel, Stratification
from summer2.parameters import Parameter, DerivedOutput

In [None]:
def build_polymod_britain_matrix():
    """
    Assemble the hard-coded 15 x 15 POLYMOD contact table for Great Britain.

    Returns:
        The table as a jax array, transposed relative to the layout
        in which the values are written out below
    """

    contact_rows = (
        (1.92, 0.65, 0.41, 0.24, 0.46, 0.73, 0.67, 0.83, 0.24, 0.22, 0.36, 0.20, 0.20, 0.26, 0.13),
        (0.95, 6.64, 1.09, 0.73, 0.61, 0.75, 0.95, 1.39, 0.90, 0.16, 0.30, 0.22, 0.50, 0.48, 0.20),
        (0.48, 1.31, 6.85, 1.52, 0.27, 0.31, 0.48, 0.76, 1.00, 0.69, 0.32, 0.44, 0.27, 0.41, 0.33),
        (0.33, 0.34, 1.03, 6.71, 1.58, 0.73, 0.42, 0.56, 0.85, 1.16, 0.70, 0.30, 0.20, 0.48, 0.63),
        (0.45, 0.30, 0.22, 0.93, 2.59, 1.49, 0.75, 0.63, 0.77, 0.87, 0.88, 0.61, 0.53, 0.37, 0.33),
        (0.79, 0.66, 0.44, 0.74, 1.29, 1.83, 0.97, 0.71, 0.74, 0.85, 0.88, 0.87, 0.67, 0.74, 0.33),
        (0.97, 1.07, 0.62, 0.50, 0.88, 1.19, 1.67, 0.89, 1.02, 0.91, 0.92, 0.61, 0.76, 0.63, 0.27),
        (1.02, 0.98, 1.26, 1.09, 0.76, 0.95, 1.53, 1.50, 1.32, 1.09, 0.83, 0.69, 1.02, 0.96, 0.20),
        (0.55, 1.00, 1.14, 0.94, 0.73, 0.88, 0.82, 1.23, 1.35, 1.27, 0.89, 0.67, 0.94, 0.81, 0.80),
        (0.29, 0.54, 0.57, 0.77, 0.97, 0.93, 0.57, 0.80, 1.32, 1.87, 0.61, 0.80, 0.61, 0.59, 0.57),
        (0.33, 0.38, 0.40, 0.41, 0.44, 0.85, 0.60, 0.61, 0.71, 0.95, 0.74, 1.06, 0.59, 0.56, 0.57),
        (0.31, 0.21, 0.25, 0.33, 0.39, 0.53, 0.68, 0.53, 0.55, 0.51, 0.82, 1.17, 0.85, 0.85, 0.33),
        (0.26, 0.25, 0.19, 0.24, 0.19, 0.34, 0.40, 0.39, 0.47, 0.55, 0.41, 0.78, 0.65, 0.85, 0.57),
        (0.09, 0.11, 0.12, 0.20, 0.19, 0.22, 0.13, 0.30, 0.23, 0.13, 0.21, 0.28, 0.36, 0.70, 0.60),
        (0.14, 0.15, 0.21, 0.10, 0.24, 0.17, 0.15, 0.41, 0.50, 0.71, 0.53, 0.76, 0.47, 0.74, 1.47),
    )
    as_written = jnp.array(contact_rows)
    # Callers receive the transpose, so element [i, j] of the result
    # is the value written at row j, column i above.
    return jnp.transpose(as_written)

In [None]:
# Starting ages of the age groups: 0, 5, ..., 70
age_groups = list(range(0, 75, 5))

# Population of each age group, listed in the same order as age_groups
age_pops_list = [
    3458060, 3556024, 3824317, 3960916, 3911291,
    3762213, 4174675, 4695853, 4653082, 3986098,
    3620216, 3892985, 3124676, 2706365, 6961183,
]
age_pops = pd.Series(data=age_pops_list, index=age_groups)

# Report the overall total in millions, to one decimal place
total_in_millions = round(age_pops.sum() / 1e6, 1)
print(f"Total simulated population is {total_in_millions} million")
age_pops.plot.area(labels={"index": "age", "value": "population"})

In [None]:
def build_sir_model(
    config: dict,
) -> CompartmentalModel:
    """
    Construct the base (not yet stratified) model with one infection flow
    and one recovery flow, where infection uses frequency-dependent mixing.

    Args:
        config: Model construction settings; the keys read here are
            "compartments", "end_time", "population" and "seed"
    Returns:
        The model object
    """

    compartment_names = config["compartments"]
    start_time, end_time = 0.0, config["end_time"]
    model = CompartmentalModel(
        times=(start_time, end_time),
        compartments=compartment_names,
        infectious_compartments=("infectious",),
    )

    # Everyone starts susceptible apart from the infectious seed
    starting_pops = {
        "susceptible": config["population"] - config["seed"],
        "infectious": config["seed"],
    }
    model.set_initial_population(distribution=starting_pops)

    # Transmission: susceptible -> infectious
    infection_flow = {
        "name": "infection",
        "contact_rate": Parameter("risk_per_contact"),
        "source": "susceptible",
        "dest": "infectious",
    }
    model.add_infection_frequency_flow(**infection_flow)

    # Recovery: infectious -> recovered, at the reciprocal of the infectious period
    recovery_rate = 1. / Parameter("infectious_period")
    model.add_transition_flow(
        name="recovery",
        fractional_rate=recovery_rate,
        source="infectious",
        dest="recovered",
    )

    return model

In [None]:
def build_age_strat(
    compartments: list,
    age_pops: pd.Series,
    mixing_matrix: jnp.ndarray,
) -> Stratification:
    """
    Get a stratification that just divides the population into age groups
    as per the user request and implements a mixing matrix provided from outside this function.

    Args:
        compartments: The compartments to be stratified (all of them)
        age_pops: Series with indexes the age groups to be implemented and values the size of each age-specific population group
        mixing_matrix: The mixing matrix for this stratification
    Returns:
        The completed Stratification object
    """

    # Strata are named by the string form of each index value.
    # A list is built so the names can be iterated more than once.
    strata_names = [str(age) for age in age_pops.index]
    mix_strat = Stratification(
        "age",
        strata_names,
        compartments,
    )

    # Each group's share of the total; the total is computed once
    # rather than once per age group.
    total_pop = age_pops.sum()
    split = {str(age): pop / total_pop for age, pop in age_pops.items()}
    mix_strat.set_population_split(split)

    mix_strat.set_mixing_matrix(mixing_matrix)

    return mix_strat

In [None]:
def add_group_strat_prevalence_request(
    model: CompartmentalModel,
    compartments=None,
) -> None:
    """
    Request one prevalence output per stratum of the model's "age" stratification.

    For every age stratum, two intermediate outputs are requested without
    saving their results (the population summed over `compartments`, and the
    population of the "infectious" compartment), followed by a saved function
    output equal to the second divided by the first. The saved output is named
    by the stratum label, left-padded with zeros to at least two characters.

    Args:
        model: The model to request the outputs from; it must already have
            a stratification called "age"
        compartments: The compartments summed to get each group's total
            population. If omitted, the module-level
            `model_config["compartments"]` is used, as before.
    """

    if compartments is None:
        # Backward-compatible fallback for callers that pass only the model
        compartments = model_config["compartments"]

    for age_start in model.get_stratification("age").strata:
        # Zero-pad the label, e.g. "5" -> "05"; longer labels are left unchanged
        group = age_start.zfill(2)
        model.request_output_for_compartments(
            f"total_{group}",
            compartments,
            strata={"age": age_start},
            save_results=False,
        )
        model.request_output_for_compartments(
            f"n_infectious_{group}",
            "infectious",
            strata={"age": age_start},
            save_results=False,
        )
        # Prevalence = infectious population / total population of the group
        model.request_function_output(
            f"{group}",
            func=DerivedOutput(f"n_infectious_{group}") / DerivedOutput(f"total_{group}"),
        )

... and let's set some arbitrary parameters
for an imaginary infectious disease
outbreak in this population,
and run the model.

In [None]:
# Settings read when the model is constructed
model_config = dict(
    end_time=40.0,
    population=60e6,
    seed=100.0,
    compartments=("susceptible", "infectious", "recovered"),
)

# Values supplied to the model at run time
parameters = dict(
    risk_per_contact=0.1,
    infectious_period=4.0,
)

In [None]:
# Base model, then the age stratification carrying the POLYMOD mixing matrix
uk_model = build_sir_model(config=model_config)
mixing_matrix = build_polymod_britain_matrix()
mix_strat = build_age_strat(
    compartments=model_config["compartments"],
    age_pops=age_pops,
    mixing_matrix=mixing_matrix,
)
uk_model.stratify_with(mix_strat)

# Register the per-age prevalence outputs, run, and collect them as a dataframe
add_group_strat_prevalence_request(model=uk_model)
uk_model.run(parameters)
output_data = uk_model.get_derived_outputs_df()

Let's see what the results look like.
We probably don't expect them to show any particularly dramatic age effects,
because the only difference between the age groups is the contact rates,
and we haven't implemented any other epidemiological differences between
the age groups.
Nevertheless, we should see some small differences between the groups
because of the heterogeneous mixing.

In [None]:
# Keep rows 10 to 29 (by position) of the outputs, then flip the table
# so that its columns go on the x-axis and its index on the y-axis
plotting_data = output_data.iloc[10:30]
transposed_data = plotting_data.T

contour_trace = go.Contour(
    x=transposed_data.columns,
    y=transposed_data.index,
    z=transposed_data.values,
)
fig = go.Figure(data=[contour_trace])
fig.update_xaxes(title="time")
fig.update_yaxes(title="age group")
fig.update_traces(colorbar_title_text="prevalence")
fig.show()

In [None]:
# Same data as a 3D surface: columns on x, index on y, values on z
surface_trace = go.Surface(
    x=transposed_data.columns,
    y=transposed_data.index,
    z=transposed_data.values,
)
axis_titles = {
    "xaxis_title": "time",
    "yaxis_title": "age group",
    "zaxis_title": "prevalence",
}
fig = go.Figure(data=[surface_trace])
fig.update_layout(width=800, height=600, scene=axis_titles)
fig.show()

So transmission is a little more intense and occurs a little earlier in the younger age groups.