# Susceptibility to infection and mixing matrices
From the previous notebook,
we should now have some idea of what the mixing matrices
used in our models represent.
This is essential background knowledge
before we start manipulating them in any way.

Let's start thinking about adapting our mixing matrix,
beginning with a simple density-dependent transmission model
stratified into two arbitrary categories
(similar to one of the models from the previous notebook).
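
As a reminder of how a mixing matrix enters a density-dependent model,
here is a minimal NumPy sketch of the force of infection on each stratum.
It assumes the convention used later in this notebook
(row i relates to the infection *of* stratum i) and ignores any other adjustments.
The function name `density_force_of_infection` is hypothetical,
and this is not summer's internal implementation.

```python
import numpy as np


def density_force_of_infection(contact_rate, mixing_matrix, infectious):
    """Force of infection on each stratum under density-dependent mixing.

    Row i of the mixing matrix describes the exposure of stratum i to the
    infectious population of each stratum j, so the force of infection on
    stratum i is contact_rate * sum_j M[i, j] * I[j].
    """
    mixing_matrix = np.asarray(mixing_matrix, dtype=float)
    infectious = np.asarray(infectious, dtype=float)
    return contact_rate * (mixing_matrix @ infectious)


# Illustrative infectious population sizes for group1 and group2
infectious = np.array([0.25, 0.5])

# Mostly within-group mixing, with some contact between the groups
mixing_matrix = np.array(
    [
        [0.8, 0.2],
        [0.2, 0.8],
    ]
)
foi = density_force_of_infection(1.0, mixing_matrix, infectious)
print(foi)  # each group is exposed mainly to its own infectious population
```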

To get started,
let's define a simple framework for a model
that we can add more features to later.
Also, let's define a simple two-stratum stratification object
that will expect a mixing matrix later on.

In [None]:
import pandas as pd
import numpy as np
pd.options.plotting.backend = "plotly"
from jax import numpy as jnp

from summer2 import CompartmentalModel, Stratification, Multiply
from summer2.parameters import Parameter, DerivedOutput, Function

In [None]:
def build_sir_model(
    config: dict,
) -> CompartmentalModel:

    # Model characteristics
    compartments = config["compartments"]
    analysis_times = (0.0, config["end_time"])
    model = CompartmentalModel(
        times=analysis_times,
        compartments=compartments,
        infectious_compartments=("infectious",),
    )
    model.set_initial_population(
        distribution={
            "susceptible": config["population"] - config["seed"],
            "infectious": config["seed"],
        }
    )
    
    # Transitions
    model.add_infection_density_flow(
        name="infection", 
        contact_rate=Parameter("risk_per_contact"),
        source="susceptible", 
        dest="infectious",
    )
    model.add_transition_flow(
        name="recovery", 
        fractional_rate=1. / Parameter("infectious_period"),
        source="infectious", 
        dest="recovered",
    )
    
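    # Output: prevalence is the total size of the infectious compartment(s)
    model.request_output_for_compartments(
        "prevalence",
        ["infectious"],  # a list of compartment names, not a bare string
    )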
    
    return model

In [None]:
def build_simple_strat(
    compartments: list,
    mixing_matrix=None,  # an array, a summer2 Function returning one, or None
) -> Stratification:
    mix_strat = Stratification(
        "groups",
        ["group1", "group2"],
        compartments,
    )
    
    prop1 = Parameter("prop1")
    prop2 = 1. - prop1
    mix_strat.set_population_split(
        {
            "group1": prop1,
            "group2": prop2,
        }
    )

    if mixing_matrix is not None:
        mix_strat.set_mixing_matrix(mixing_matrix)

    return mix_strat

In [None]:
model_config = {
    "end_time": 20.0,
    "population": 1.0,
    "seed": 0.01,
    "compartments": ("susceptible", "infectious", "recovered"),
}
parameters = {
    "risk_per_contact": 1.0,
    "infectious_period": 2.0,
    "prop1": np.random.uniform(),
    "susceptibility": 2.0,
}

## Comparison model
As a comparison,
let's quickly build and run an unstratified model
with the base parameters.

In [None]:
base_model = build_sir_model(model_config)
base_model.run(parameters=parameters)
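
For reference, the unstratified comparison model corresponds to
the density-dependent SIR equations in the docstring below.
This is a minimal plain-NumPy sketch with a fixed-step forward-Euler integrator;
the function `run_sir_euler` and the step size `dt` are hypothetical choices for illustration.
summer uses its own numerical solver,
so these numbers will not match `base_model` exactly.

```python
import numpy as np


def run_sir_euler(risk_per_contact, infectious_period, population, seed,
                  end_time, dt=0.01):
    """Forward-Euler sketch of an unstratified density-dependent SIR model.

    dS/dt = -beta * S * I
    dI/dt =  beta * S * I - I / infectious_period
    dR/dt =  I / infectious_period
    where beta is risk_per_contact. Returns the times and the size of the
    infectious compartment (prevalence) at each step.
    """
    n_steps = int(round(end_time / dt))
    times = np.linspace(0.0, n_steps * dt, n_steps + 1)
    s, i, r = population - seed, seed, 0.0
    prevalence = np.empty(n_steps + 1)
    prevalence[0] = i
    for step in range(1, n_steps + 1):
        infection = risk_per_contact * s * i * dt  # density-dependent flow
        recovery = i / infectious_period * dt
        s, i, r = s - infection, i + infection - recovery, r + recovery
        prevalence[step] = i
    return times, prevalence


# Same values as model_config and parameters above
times, prevalence = run_sir_euler(
    risk_per_contact=1.0,
    infectious_period=2.0,
    population=1.0,
    seed=0.01,
    end_time=20.0,
)
```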

## Increased susceptibility for a sub-population
Next, let's increase the susceptibility of one stratum of the model
by adjusting the rate of infection for one of our model sub-groups
(using `summer`'s flow adjustment structures).
This scales the rate of infection for `group1` so that its effective contact rate is
the product of the `risk_per_contact` and `susceptibility` parameters.
Increasing the effective contact rate of the infection process
for a population stratum means that its members experience a greater
force of infection, so they can be thought of as more susceptible.
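
To make the arithmetic concrete,
the sketch below computes the infection flow into each stratum at a single instant
under homogeneous mixing, using the notebook's `risk_per_contact` and `susceptibility` values
but hypothetical compartment sizes.
It treats the flow adjustment as a simple multiplier on the infection flow,
which is an assumption about how the adjustment acts rather than a description of summer's code.

```python
import numpy as np

# Parameter values from the notebook's `parameters` dictionary
risk_per_contact = 1.0
susceptibility = 2.0

# Hypothetical compartment sizes for (group1, group2) at one instant
susceptible = np.array([0.3, 0.6])
infectious = np.array([0.05, 0.05])

# Flow adjustments: multiply group1's infection flow by susceptibility and
# leave group2's unchanged (the `None` adjustment)
adjustments = np.array([susceptibility, 1.0])

# Homogeneous mixing: every stratum is exposed to all infectious people
force_of_infection = risk_per_contact * infectious.sum()

# Infection flow (new infections per unit time) and per-capita rate by stratum
infection_flow = adjustments * force_of_infection * susceptible
per_capita_rate = infection_flow / susceptible
print(per_capita_rate)  # group1's rate is susceptibility times group2's
```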

Alternatively, we could multiply the first row of the mixing matrix
(relating to the infection of the `group1` population)
through by the same value.
The `susceptibility` parameter then increases the force of infection
for the `group1` population by multiplying both contributions
to its force of infection (from the infectious `group1` and `group2` populations)
through by the same value.
Let's check that these two approaches are equivalent.
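
Before running the models, the equivalence can be checked at a single instant.
This NumPy sketch uses an arbitrary random mixing matrix and infectious sizes
(hypothetical values, not the notebook's)
and treats the flow adjustment as a multiplier on `group1`'s force of infection.
Both approaches reduce to left-multiplying the matrix by `diag([susceptibility, 1])`.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

contact_rate = 1.0
susceptibility = 2.0
mixing_matrix = rng.uniform(size=(2, 2))  # an arbitrary non-negative matrix
infectious = rng.uniform(size=2)          # arbitrary infectious sizes

# Approach 1: compute the force of infection with the original matrix, then
# multiply group1's value by susceptibility (as a flow adjustment would)
adjusted_foi = contact_rate * (mixing_matrix @ infectious)
adjusted_foi = adjusted_foi * np.array([susceptibility, 1.0])

# Approach 2: multiply group1's row of the matrix through by susceptibility
scaled_matrix = mixing_matrix.copy()
scaled_matrix[0, :] *= susceptibility
matrix_foi = contact_rate * (scaled_matrix @ infectious)

# Both equal diag([susceptibility, 1]) @ M @ I
diag_foi = contact_rate * (np.diag([susceptibility, 1.0]) @ mixing_matrix @ infectious)
print(np.allclose(adjusted_foi, matrix_foi), np.allclose(matrix_foi, diag_foi))
```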

In [None]:
suscept_matrix_model = build_sir_model(model_config)
mixing_matrix = jnp.array(
    [
        [1.0, 1.0],
        [1.0, 1.0],
    ]
)
suscept_param_strat = build_simple_strat(model_config["compartments"], mixing_matrix)
suscept_param_strat.set_flow_adjustments(
    "infection",
    {
        "group1": Parameter("susceptibility"),  # Increased susceptibility for group1
        "group2": None,  # No change for group2
    },
)
suscept_matrix_model.stratify_with(suscept_param_strat)
suscept_matrix_model.run(parameters=parameters)

In this case, we don't actually need the mixing matrix:
a matrix of ones is equivalent to homogeneous mixing,
so we can achieve the same effect by adjusting susceptibility alone,
provided the sub-populations are implemented as strata.
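
The matrix is redundant here because, under density-dependent transmission,
a matrix of ones gives every stratum the same force of infection
as an unstratified model.
A minimal NumPy check, with hypothetical infectious sizes:

```python
import numpy as np

contact_rate = 1.0
infectious = np.array([0.02, 0.07])  # hypothetical infectious sizes by stratum

# All-ones mixing matrix: each row adds up every stratum's infectious people
ones_matrix_foi = contact_rate * (np.ones((2, 2)) @ infectious)

# Homogeneous (unstratified) mixing: force of infection from the total
homogeneous_foi = contact_rate * infectious.sum()

print(ones_matrix_foi, homogeneous_foi)
```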

In [None]:
suscept_model = build_sir_model(model_config)
suscept_param_strat = build_simple_strat(model_config["compartments"], None)
suscept_param_strat.set_flow_adjustments(
    "infection",
    {
        "group1": Parameter("susceptibility"),
        "group2": None,
    },
)
suscept_model.stratify_with(suscept_param_strat)
suscept_model.run(parameters=parameters)

We can also achieve the same thing by adjusting our mixing matrix.
To do this, we multiply the row of the matrix corresponding to `group1`
through by the `susceptibility` value.
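
The row scaling generalizes to any number of strata.
The helper below is a hypothetical NumPy sketch (`scale_mixing_rows` is not part of summer)
that multiplies each row by its own susceptibility value.
With `jax.numpy` arrays, which are immutable,
the same broadcasting expression works,
or a single row can be updated with `matrix.at[row].multiply(value)`.

```python
import numpy as np


def scale_mixing_rows(mixing_matrix, row_multipliers):
    """Return a copy of mixing_matrix with row i multiplied by row_multipliers[i].

    Row i relates to the infection of stratum i, so this changes the
    susceptibility of each stratum; it equals diag(row_multipliers) @ M.
    """
    matrix = np.asarray(mixing_matrix, dtype=float)
    multipliers = np.asarray(row_multipliers, dtype=float)
    # A column vector broadcasts across each row of the matrix
    return multipliers[:, np.newaxis] * matrix


# Scale group1's row by a susceptibility of 2.0 and leave group2 unchanged
scaled = scale_mixing_rows(np.ones((2, 2)), [2.0, 1.0])
print(scaled)
```

With a `susceptibility` of 2.0, `scaled` has the same values as the matrix built in the next cell.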

In [None]:
strat_model = build_sir_model(model_config)

def build_suscept_mixing_matrix(susceptibility):
    return jnp.array(
        [
            [susceptibility, susceptibility],  # Multiply group1 row of matrix by the susceptibility value
            [1.0, 1.0],
        ]
    )

# Wrap the builder in a summer2 Function so that the matrix is computed
# from the "susceptibility" parameter when the model is run
mixing_matrix = Function(
    build_suscept_mixing_matrix,
    (Parameter("susceptibility"),),
)
suscept_matrix_strat = build_simple_strat(model_config["compartments"], mixing_matrix)
strat_model.stratify_with(suscept_matrix_strat)
strat_model.run(parameters=parameters)

In [None]:
outputs = pd.DataFrame(
    {
        "stratified, adjusted matrix": strat_model.get_derived_outputs_df()["prevalence"],
        "stratified, increased susceptibility": suscept_matrix_model.get_derived_outputs_df()["prevalence"],
        "stratified, increased susceptibility, no matrix": suscept_model.get_derived_outputs_df()["prevalence"],
    }
)
differences = outputs.min(axis=1) - outputs.max(axis=1)
assert all(abs(differences) < 1e-8), "There's a discrepancy"
outputs["base comparison"] = base_model.get_derived_outputs_df()["prevalence"]
outputs.plot()
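
The same comparison can be reproduced outside summer,
which may help when debugging a mixing matrix.
This is a hypothetical forward-Euler sketch (`run_two_group_sir` is not a summer function)
of the two-stratum model with the notebook's parameter values and a fixed `prop1`.
It assumes the population split applies to the seeded infectious population
as well as the susceptibles, and it will not reproduce summer's solver output exactly.

```python
import numpy as np


def run_two_group_sir(mixing_matrix, adjustments, prop1, beta=1.0,
                      infectious_period=2.0, seed=0.01, end_time=20.0,
                      dt=0.01):
    """Forward-Euler sketch of a two-stratum density-dependent SIR model.

    The infection flow into stratum i is
        adjustments[i] * beta * S[i] * sum_j M[i, j] * I[j].
    Returns total prevalence (I summed over strata) at each step.
    Recovered people are not tracked because prevalence does not need them.
    """
    matrix = np.asarray(mixing_matrix, dtype=float)
    adj = np.asarray(adjustments, dtype=float)
    split = np.array([prop1, 1.0 - prop1])
    # Split both the susceptible and the seeded infectious population
    s = (1.0 - seed) * split
    i = seed * split
    n_steps = int(round(end_time / dt))
    prevalence = np.empty(n_steps + 1)
    prevalence[0] = i.sum()
    for step in range(1, n_steps + 1):
        infection = adj * beta * s * (matrix @ i) * dt
        recovery = i / infectious_period * dt
        s = s - infection
        i = i + infection - recovery
        prevalence[step] = i.sum()
    return prevalence


susceptibility = 2.0
prop1 = 0.4  # fixed here, rather than random, so the run is reproducible

# Approach 1: all-ones matrix plus a susceptibility adjustment for group1
via_adjustment = run_two_group_sir(np.ones((2, 2)), [susceptibility, 1.0], prop1)

# Approach 2: group1's matrix row multiplied through, no adjustment
via_matrix = run_two_group_sir(
    [[susceptibility, susceptibility], [1.0, 1.0]], [1.0, 1.0], prop1
)
print(np.max(np.abs(via_adjustment - via_matrix)))
```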

# Adjusting infectiousness