# Population interactions

## Rationale
So we've now reached the point where we should have
some idea of what we mean by the mixing matrices
that we might like to put into a model.
But what do we really mean epidemiologically by
different rates of contact between population groups?
In reality, people who share certain characteristics
often come into contact with one another
at rates that differ systematically
from the rates at which they contact people who don't share those characteristics.

For example, suppose we have two population groups,
and we want to consider that people from one population group
are more likely to contact other people from that same group
than they are to contact other people from the other group.
Let's call our two population groups `group1` and `group2`.
We've decided that the rate at which a person from 
`group1` contacts others from `group1` should be greater than the
rate at which someone from `group1` contacts people from `group2`.
However, to implement this in a model,
we'll need to think about every possible combination of people from one group
coming into contact with people from any other group
(including the person's own group).
So the number of possible group-group interactions we have will
be the square of the number of groups we're modelling.
This is four in the current example,
and we'll need to decide on values for each of the four cells of our $2\times 2$ matrix.
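To see where the squared count comes from, we can list every ordered pair of groups. This is a minimal sketch using only the Python standard library; the `groups` list just mirrors the names used in this notebook.

```python
from itertools import product

groups = ["group1", "group2"]

# Every ordered (contacting group, contacted group) pair, including same-group pairs
pairs = list(product(groups, repeat=2))
for source, target in pairs:
    print(f"{source} -> {target}")

# One matrix cell per ordered pair, so n groups need n ** 2 values
assert len(pairs) == len(groups) ** 2
```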

## Starting assumptions
It's probably slightly simpler to start looking at this under the density-dependent framework,
because we're dealing with _per capita per capita_ rates 
(i.e. _per capita_ with respect to both the susceptible and the infecting individual).
In this case, we might like to consider that the rate at which 
the first population group contacts the second is the same as the rate
at which the second contacts the first.
Provided we are thinking of the contents of the mixing matrix as 
rates of contact per individual per unit time, this is fine.
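As a quick check on why symmetry is reasonable under this interpretation, here's a small sketch with made-up, deliberately unequal group sizes (not values used anywhere else in this notebook). If $M_{ij}$ is a per capita per capita contact rate, the total number of contacts per unit time from group $i$ to group $j$ is $N_i M_{ij} N_j$, so a symmetric matrix gives the same total in both directions even when the groups differ in size.

```python
import numpy as np

# Hypothetical, deliberately unequal group sizes
group_sizes = np.array([0.3, 0.7])

# Symmetric per capita per capita contact rates
mixing = np.array(
    [
        [1.0, 0.5],
        [0.5, 1.0],
    ]
)

# Total contacts per unit time from group i to group j: N_i * M_ij * N_j
total_contacts = group_sizes[:, None] * mixing * group_sizes[None, :]
print(total_contacts)

# The group1 -> group2 total matches the group2 -> group1 total
assert np.isclose(total_contacts[0, 1], total_contacts[1, 0])
```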

For now, let's still assume that there are no epidemiological
differences between our groups, 
such as in susceptibility or infectiousness
(which we'll explore in the [following notebook](./15-susceptibility-infectiousness-matrices.ipynb)).
Sticking to this approach should make things easier,
because we can think of the mixing matrix as just containing
the rates at which people come into contact with one another
regardless of the risk of transmission from that contact.

To summarise, let's start off assuming:
- Two interacting population subgroups
- Density-dependent transmission
- Uniform susceptibility and infectiousness for both groups
- Our mixing matrix contains information on the rate of interaction between specific individuals from each possible combination of the two population subgroups being modelled

## Assortative mixing
The slightly more general assumption
that interactions with your own population group
happen more often than interactions with other population groups
is what is called "assortative mixing".
For example, our two population groups might represent
people living in two neighbouring towns,
and we might want to consider that the rate at which
people come into contact with other people from their own town
is greater than the rate at which they come into contact 
with people from the other town.

Let's define a function that can do this for us,
and build a really simple matrix to represent the idea.

In [None]:
import numpy as np
import pandas as pd
from jax import numpy as jnp
pd.options.plotting.backend = "plotly"
import plotly.express as px

from summer2 import CompartmentalModel, Stratification
from summer2.parameters import Parameter, Function, DerivedOutput

In [None]:
def build_assortative_mixing_matrix(
    intergroup_contact_rate: float, 
    jax: bool=True,
) -> "np.ndarray | jnp.ndarray":
    """
    Get a basic 2 x 2 matrix with ones on the diagonal
    and the user-submitted request on the off-diagonal.
    
    Args:
        intergroup_contact_rate: Value for the off-diagonal cells
        jax: Whether to return a jax numpy array (True) or a numpy array (False)
    Returns:
        The 2 x 2 mixing matrix
    """
    
    values = [
        [1.0, intergroup_contact_rate],
        [intergroup_contact_rate, 1.0],
    ]
    
    return jnp.array(values) if jax else np.array(values)

In [None]:
assortative_matrix = build_assortative_mixing_matrix(0.5, jax=False)
px.imshow(assortative_matrix)
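The same idea extends naturally beyond two groups. As a hedged sketch (the function `build_n_group_assortative_matrix` is hypothetical and isn't used by the models below), we can fill an $n \times n$ matrix with the intergroup value and then reset the diagonal to one. For two groups this reproduces the matrix above.

```python
import numpy as np


def build_n_group_assortative_matrix(n_groups: int, intergroup_contact_rate: float) -> np.ndarray:
    """
    Hypothetical generalisation of build_assortative_mixing_matrix:
    ones on the diagonal and a single shared value everywhere off the diagonal.
    """
    matrix = np.full((n_groups, n_groups), float(intergroup_contact_rate))
    np.fill_diagonal(matrix, 1.0)  # modifies the matrix in place
    return matrix


print(build_n_group_assortative_matrix(3, 0.5))
```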

OK, so we've demonstrated what a very simple assortative mixing matrix looks like.
Next, let's get this running in a `summer` model.

In [None]:
def build_sir_model(
    config: dict,
) -> CompartmentalModel:
    """
    This is the same model as introduced in the mixing-and-transmission notebook.
    
    Args:
        config: User requests to define model construction
    Returns:
        The very basic model object
    """
    compartments = config["compartments"]
    analysis_times = (0.0, config["end_time"])
    model = CompartmentalModel(
        times=analysis_times,
        compartments=compartments,
        infectious_compartments=("infectious",),
    )
    model.set_initial_population(
        distribution={
            "susceptible": config["population"] - config["seed"], 
            "infectious": config["seed"],
        }
    )
    
    model.add_infection_density_flow(
        name="infection",
        contact_rate=Parameter("risk_per_contact"),
        source="susceptible",
        dest="infectious",
    )
    model.add_transition_flow(
        name="recovery", 
        fractional_rate=1. / Parameter("infectious_period"),
        source="infectious", 
        dest="recovered",
    )
    
    return model

In [None]:
def build_simple_strat(
    compartments: list,
) -> Stratification:
    """
    Get a stratification that divides the modelled population into two groups
    and applies a symmetric assortative mixing matrix,
    built from the "intergroup_interaction" parameter when the model is run.
    
    Args:
        compartments: The compartments to be stratified (here all the model's compartments)
    Returns:
        The completed Stratification object
    """
                
    mix_strat = Stratification(
        "groups",
        ["group1", "group2"],
        compartments,
    )
    
    prop1 = Parameter("prop1")
    prop2 = 1. - prop1
    mix_strat.set_population_split(
        {
            "group1": prop1,
            "group2": prop2,
        }
    )

    mixing_matrix = Function(
        build_assortative_mixing_matrix, 
        (Parameter("intergroup_interaction"),),
    )
    
    mix_strat.set_mixing_matrix(mixing_matrix)

    return mix_strat

In [None]:
def add_group_strat_prevalence_request(
    model: CompartmentalModel, 
    model_name: str,
):
    """
    Track the prevalence of infection in the population
    of a previously defined model, as well as 
    the prevalence in each stratified sub-group of the model
    (provided the stratification is named "groups").
    
    Args:
        model: The model object to modify in-place
        model_name: A string to add for the name of the model
    """
    
    model.request_output_for_compartments(
        f"{model_name}_prevalence",
        "infectious",
    )
    
    for group in model.get_stratification("groups").strata:
        model.request_output_for_compartments(
            f"total_{group}",
            # Relies on the global model_config dict defined in a later cell
            model_config["compartments"],
            strata={"groups": group},
            save_results=False,
        )
        model.request_output_for_compartments(
            f"n_infectious_{group}",
            "infectious",
            strata={"groups": group},
            save_results=False,
        )
        model.request_function_output(
            f"{model_name}_prevalence_{group}",
            func=DerivedOutput(f"n_infectious_{group}") / DerivedOutput(f"total_{group}"),
        )

In [None]:
model_config = {
    "end_time": 40.0,
    "population": 1.0,
    "seed": 0.01,
    "compartments": ("susceptible", "infectious", "recovered"),
}

parameters = {
    "risk_per_contact": 0.5,
    "infectious_period": 4.0,
    "prop1": 0.5,
    "intergroup_interaction": 1.0,
}

### Unstratified model
Just for reference, let's first build a similar model without heterogeneous mixing.

In [None]:
unstratified_model = build_sir_model(model_config)
unstratified_model.request_output_for_compartments(
    "unstrat_prevalence",
    "infectious",
)
unstratified_model.run(parameters)

### Stratified model with homogeneous mixing

In [None]:
simple_strat_model = build_sir_model(model_config)
mix_strat = build_simple_strat(model_config["compartments"])
simple_strat_model.stratify_with(mix_strat)
add_group_strat_prevalence_request(simple_strat_model, "simple")
simple_strat_model.run(parameters)

At this stage, the model is still very simple
and the stratification isn't doing anything,
because the rate of transmission hasn't been adjusted:
the user request for the off-diagonal elements is one,
so the mixing matrix contains only ones.

In [None]:
build_assortative_mixing_matrix(parameters["intergroup_interaction"], jax=False)

### Stratified model with heterogeneous mixing
We'll now introduce a matrix that implements assortative mixing,
by setting the off-diagonal elements of the matrix to a value lower than one.
This is intended to represent the idea that the rate of interactions
between people from different population sub-groups is proportionately
lower than the base rate at which people interact with people from their own sub-group.

In [None]:
parameters.update({"intergroup_interaction": 0.5})
build_assortative_mixing_matrix(parameters["intergroup_interaction"], jax=False)

In [None]:
hetero_mix_model = build_sir_model(model_config)
mix_strat = build_simple_strat(model_config["compartments"])
hetero_mix_model.stratify_with(mix_strat)
add_group_strat_prevalence_request(hetero_mix_model, "hetero")

We now have assortative mixing in our model.
However, we've also reduced the total number of contacts
that each person has per unit time.
In the simple model, each person made one contact per day
(or time unit).
We've now reduced this to 0.75 contacts per day
($0.5 \times 1 + 0.5 \times 0.5 = 0.75$),
because in this model half of the population is assigned to each of the two
modelled groups and these proportions remain constant
(because there is no entry into or exit from the model,
and no transition between strata).
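The same calculation can be written as a matrix-vector product, which extends to any number of groups: under density dependence, a person in group $i$ makes $\sum_j M_{ij} N_j$ contacts per unit time. A minimal sketch with the values used here (total population 1.0, `prop1` of 0.5):

```python
import numpy as np

group_sizes = np.array([0.5, 0.5])  # prop1 = 0.5 of a total population of 1.0

homogeneous = np.array([[1.0, 1.0], [1.0, 1.0]])
assortative = np.array([[1.0, 0.5], [0.5, 1.0]])

# Contacts per person per unit time in group i: sum_j M_ij * N_j
print(homogeneous @ group_sizes)  # [1. 1.]
print(assortative @ group_sizes)  # [0.75 0.75]
```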

If we want to keep this level of assortativity,
but still have the same risk of transmission
per unit time for each susceptible individual,
we could scale our risk of transmission per contact parameter
(`risk_per_contact`) up to compensate
(equivalent to multiplying the whole matrix through by $\frac{4}{3}$,
because $0.75 \times \frac{4}{3} = 1$).
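As a sketch of that rescaling (the variable names here are illustrative only), we can compute the factor from the contacts per person and check that the per-person risk per unit time returns to its original value of $0.5 \times 1 = 0.5$:

```python
import numpy as np

group_sizes = np.array([0.5, 0.5])
assortative = np.array([[1.0, 0.5], [0.5, 1.0]])

# Contacts per person per unit time before rescaling (0.75 in each group)
contacts_per_person = assortative @ group_sizes

# Factor that restores one contact per person per unit time (4/3)
scale = 1.0 / contacts_per_person[0]
risk_per_contact = 0.5 * scale

# Scaling the risk per contact is equivalent to scaling the whole matrix
assert np.allclose((scale * assortative) @ group_sizes, 1.0)
print(scale, risk_per_contact)
```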

In [None]:
parameters["risk_per_contact"] *= 4.0 / 3.0
hetero_mix_model.run(parameters)

### Assortative model outputs
So now we have implemented assortative mixing in the model.
Let's see what that looks like.

In [None]:
pd.concat(
    (
        unstratified_model.get_derived_outputs_df(), 
        simple_strat_model.get_derived_outputs_df(), 
        hetero_mix_model.get_derived_outputs_df(),
    ), 
    axis=1,
).plot()

Perhaps we're driving you completely crazy now;
the assortative mixing clearly still hasn't done anything at all
(as you can show by selecting/unselecting the outputs
of the various models in the graph just above).
However, this illustrates an important point about heterogeneous mixing,
which is that the mixing itself may not have any important effects
unless the groups that are mixing heterogeneously have some sort 
of different epidemiological characteristics.
Let's implement that next, so that our assortative mixing model
is actually achieving something!
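Before doing that, a quick hand calculation shows why the assortative matrix changed nothing here. This sketch assumes a density-dependent force of infection of the form $\lambda_i = \beta \sum_j M_{ij} a_j I_j$, where $a_j$ is the relative infectiousness of group $j$; it is a simplified illustration, not `summer2`'s internal code. With two identical, equally sized groups holding equal numbers of infectious people, both groups face the same force of infection whatever value we put on the off-diagonal.

```python
import numpy as np


def force_of_infection(beta, mixing, infectious, infectiousness=None):
    """
    Density-dependent force of infection on each group,
    lambda_i = beta * sum_j M_ij * a_j * I_j
    (a simplified hand calculation, not summer2's internal code).
    """
    infectious = np.asarray(infectious, dtype=float)
    if infectiousness is None:
        infectiousness = np.ones_like(infectious)
    return beta * (np.asarray(mixing) @ (infectiousness * infectious))


# Seed of 0.01 split evenly between two identical, equally sized groups
infectious = np.array([0.005, 0.005])

foi_by_intergroup = {}
for intergroup in (1.0, 0.5, 0.0):
    mixing = np.array([[1.0, intergroup], [intergroup, 1.0]])
    foi_by_intergroup[intergroup] = force_of_infection(0.5, mixing, infectious)
    print(intergroup, foi_by_intergroup[intergroup])
```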

### Assortative mixing model with adjustment
Let's adjust something about one of our strata and re-run the model.
We could choose just about any epidemiologically important process here,
such as susceptibility, recovery rate or infectiousness,
but let's arbitrarily choose infectiousness for now,
which we'll double for the `group1` sub-population.
Because people from the `group1` population are both more infectious
and also interact preferentially with other `group1` individuals,
the prevalence of infection becomes greater in this group than in `group2`.

In [None]:
hetero_adj_model = build_sir_model(model_config)
mix_strat = build_simple_strat(model_config["compartments"])
mix_strat.add_infectiousness_adjustments(
    "infectious",
    {
        "group1": 2.0,
        "group2": None,
    }
)
hetero_adj_model.stratify_with(mix_strat)
add_group_strat_prevalence_request(hetero_adj_model, "hetero_adj")
hetero_adj_model.run(parameters)
hetero_adj_model.get_derived_outputs_df().plot()

Note that if we implement completely homogeneous mixing,
then the increased infectiousness 
will still have an effect on overall model dynamics,
because part of the population is just more infectious 
and so the force of infection is greater.
However, in this case, the force of infection
will be greater for all groups,
so the prevalence and profile of the epidemic
are identical in `group1` and `group2`
(and therefore also in the overall population).
Let's demonstrate that.
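Before re-running the model, the same kind of hand calculation shows the difference. It uses the same simplified density-dependent assumption, $\lambda_i = \beta \sum_j M_{ij} a_j I_j$, and is evaluated only at the start of the epidemic, when prevalence is equal in both groups. With assortative mixing, `group1` faces a higher force of infection; with every matrix element equal to one, both groups face the same force of infection.

```python
import numpy as np

beta = 0.5 * 4.0 / 3.0                 # risk_per_contact after the 4/3 rescaling
infectiousness = np.array([2.0, 1.0])  # group1 doubled, group2 unadjusted
infectious = np.array([0.005, 0.005])  # equal prevalence in both groups, as at the start

foi_by_intergroup = {}
for intergroup in (0.5, 1.0):
    mixing = np.array([[1.0, intergroup], [intergroup, 1.0]])
    # Density-dependent force of infection: lambda_i = beta * sum_j M_ij * a_j * I_j
    foi_by_intergroup[intergroup] = beta * (mixing @ (infectiousness * infectious))
    print(intergroup, foi_by_intergroup[intergroup])
```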

In [None]:
parameters.update({"intergroup_interaction": 1.0})
hetero_adj_model.run(parameters)
hetero_adj_model.get_derived_outputs_df().plot()

At the other extreme,
if we set the interaction between groups to zero,
then we have absolutely no interaction between the two populations.
This is equivalent to having two totally independent SIR models,
with no contact between the two groups that we are simulating.
In this case, the basic reproduction number for the epidemic
dynamics in `group1` is double that for the epidemic occurring in `group2`.
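One way to check this claim is with a next-generation matrix, which isn't used elsewhere in this notebook and is offered only as a sketch. Assuming the same simplified density-dependent force of infection, a nearly fully susceptible population, and the parameter values above (with `risk_per_contact` already rescaled by 4/3), the element $K_{ij} = \beta D N_i M_{ij} a_j$ gives the expected number of new infections in group $i$ caused by one infectious person from group $j$, where $D$ is the infectious period. With no intergroup contact, $K$ is diagonal and its entries are each group's own reproduction number.

```python
import numpy as np

beta = 0.5 * 4.0 / 3.0                 # risk_per_contact after the 4/3 rescaling
infectious_period = 4.0                # D
group_sizes = np.array([0.5, 0.5])     # N_i, with prop1 = 0.5
infectiousness = np.array([2.0, 1.0])  # a_j: group1 doubled


def next_generation_matrix(intergroup: float) -> np.ndarray:
    """K_ij = beta * D * N_i * M_ij * a_j, assuming S_i is approximately N_i."""
    mixing = np.array([[1.0, intergroup], [intergroup, 1.0]])
    return beta * infectious_period * group_sizes[:, None] * mixing * infectiousness[None, :]


# No intergroup contact: K is diagonal, so each entry is that group's own R0
k_isolated = next_generation_matrix(0.0)
print(np.diag(k_isolated))  # approximately [2.667, 1.333]

# With intergroup contact, the overall R0 is the dominant eigenvalue of K
r0_mixed = np.max(np.abs(np.linalg.eigvals(next_generation_matrix(0.5))))
print(r0_mixed)
```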

In [None]:
parameters.update({"intergroup_interaction": 0.0})
hetero_adj_model.run(parameters)
hetero_adj_model.get_derived_outputs_df().plot()

## Manipulating density-dependent transmission matrices
Under this density-dependent framework for thinking about heterogeneous transmission,
there is some equivalence between the opposite elements of our matrix.
If the elements of the matrix just represent the rates at which someone
from one of the population groups comes into contact with 
someone from another population group,
then we would expect the matrix to be symmetrical.
That is, the rate at which a specific person from `group1` comes into contact with
a specific person from `group2` should be equal to the rate at which
a specific person from `group2` comes into contact with a specific person from `group1`.
So the symmetric matrix we had been using above is a reasonable choice.
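A simple way to check whether a matrix has this property is to compare it with its transpose. Here's a minimal sketch (the helper name `is_symmetric` is our own), using `np.allclose` so that tiny floating-point differences don't count as asymmetry:

```python
import numpy as np


def is_symmetric(matrix) -> bool:
    """Check whether a square matrix equals its own transpose, within floating-point tolerance."""
    matrix = np.asarray(matrix, dtype=float)
    return bool(
        matrix.ndim == 2
        and matrix.shape[0] == matrix.shape[1]
        and np.allclose(matrix, matrix.T)
    )


print(is_symmetric([[1.0, 0.5], [0.5, 1.0]]))  # True
print(is_symmetric([[1.0, 0.4], [0.6, 1.0]]))  # False
```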

Next, suppose we have a mixing matrix that we want to use in a model which describes
the contact rates in these density-dependent terms, but is not symmetric.
This is a pretty theoretical situation and not really something we're likely to come across,
but let's consider it anyway.
In this case, it would probably be reasonable to assume that the
opposite elements of the matrix don't match because of some problem
with how we collected the data that we used to build the matrix.
If we accept that the rate at which `group1` contacts `group2` should be the same
as the rate at which `group2` contacts `group1`,
then we essentially have two estimates for the same quantity
(in the top-right and bottom-left cells of our $2 \times 2$ matrix).
So it would probably be reasonable to average out these two quantities
to find one value that we should use for both cells of our matrix.

Let's consider how that might look in code
(building a function that can handle square matrices of any size).
We'll get into this sort of thing in more detail in the next notebook.
For now, just note that this process is relatively simple 
while we're considering density-dependent transmission.

It may seem obvious that we could do things like this,
but as we'll see in the next notebook,
we may need to be more careful with frequency-dependent matrices.

In [None]:
def get_averaged_matrix(matrix):
    """
    Average out the corresponding opposite off-diagonal elements of the input matrix,
    keeping the diagonal elements unchanged.
    
    Args:
        matrix: The user-submitted matrix to manipulate
    Returns:
        Matrix with the same dimensions as the input, adapted as described
    """
    upper_part = np.triu(matrix, k=1)
    lower_part = np.tril(matrix, k=-1)
    upper_result = np.average((upper_part, lower_part.transpose()), axis=0)
    lower_result = upper_result.transpose()
    return upper_result + lower_result + np.diag(np.diag(matrix))

In [None]:
asymmetric_matrix = np.array(
    [
        [1.0, 0.4], 
        [0.6, 1.0],
    ]
)
px.imshow(asymmetric_matrix)

In [None]:
symmetric_matrix = get_averaged_matrix(asymmetric_matrix)
px.imshow(symmetric_matrix)
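Because each diagonal element is averaged with itself, averaging the matrix with its own transpose gives the same result as `get_averaged_matrix` in a single step. Here's a minimal alternative sketch (the name `symmetrise` is our own), assuming NumPy as elsewhere in this notebook; for the asymmetric example above it returns `[[1.0, 0.5], [0.5, 1.0]]`.

```python
import numpy as np


def symmetrise(matrix) -> np.ndarray:
    """
    Average each off-diagonal element with its mirror image.
    The diagonal is unchanged because it is averaged with itself.
    """
    matrix = np.asarray(matrix, dtype=float)
    return (matrix + matrix.T) / 2.0


asymmetric_matrix = np.array(
    [
        [1.0, 0.4],
        [0.6, 1.0],
    ]
)
print(symmetrise(asymmetric_matrix))
```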