In [None]:
from tbh.demographic_tools import (
    get_population_over_time, 
    build_agegap_lookup, get_agegap_prob_jax,
    build_age_weight_lookup, get_age_weight_jax
)

# Country (ISO3 code) analysed in this notebook, and the lower bound (in years)
# of each age group used throughout; labels are kept as strings.
iso3 = "KIR"
AGE_GROUPS = [str(lower) for lower in (0, 5, 15, 30, 65)]


# Load population size and fertility data

In [None]:
import pandas as pd
from tbh.paths import DATA_FOLDER

fertility_csv = DATA_FOLDER / f"un_fertility_rates_{iso3}.csv"
fertility_data = pd.read_csv(fertility_csv, index_col=0)

# Rescale every row so its entries sum to one (rows are presumably years,
# given index_col=0 — confirm against the CSV layout).
row_totals = fertility_data.sum(axis=1)
normalised_fertility_data = fertility_data.div(row_totals, axis=0)

single_age_pop_df, grouped_pop_df = get_population_over_time(
    iso3,
    age_groups=AGE_GROUPS,
    scaling_factor=1.0,
)


In [None]:
# Build the age-gap lookup from the row-normalised fertility table, then unpack
# its three components for use by the mixing-matrix builder.
agegap_lookup = build_agegap_lookup(normalised_fertility_data)
fert_probs, fert_year0, fert_age0 = agegap_lookup

In [None]:
# Build the per-age weight lookup from the single-year population table, then
# unpack it for use by the mixing-matrix builder.
age_weight_lookup_result = build_age_weight_lookup(AGE_GROUPS, single_age_pop_df)
age_weights_lookup, ageweights_year0 = age_weight_lookup_result

# Main mixing matrix building function

In [None]:
from jax import numpy as jnp
from jax import vmap

def gen_mixing_matrix_func(grouped_pop_df, fert_probs, fert_year0, fert_age0, age_weights_lookup, ageweights_year0, age_groups, max_age=None):
    """Return a closure that builds a normalised age-group contact matrix.

    Parameters
    ----------
    grouped_pop_df : pandas.DataFrame
        Population by age group, indexed by (numeric) year. Each row's values
        are used as the per-group population sizes.
    fert_probs, fert_year0, fert_age0
        Age-gap lookup (as produced by ``build_agegap_lookup``); passed
        through unchanged to ``get_agegap_prob_jax``.
    age_weights_lookup, ageweights_year0
        Age-weight lookup (as produced by ``build_age_weight_lookup``); passed
        through unchanged to ``get_age_weight_jax``.
    age_groups : sequence of str
        Ascending lower bounds of the age groups, e.g. ``["0", "5", "15"]``.
    max_age : int, optional
        Inclusive upper age of the last (open-ended) group. If omitted, falls
        back to ``single_age_pop_df["Age"].max()`` from the notebook's global
        namespace (the previous, implicit behaviour).

    Returns
    -------
    callable
        ``build_mixing_matrix(a_spread, pc_strength, time)`` returning an
        ``(n_groups, n_groups)`` JAX array.
    """
    if max_age is None:
        # Backward-compatible fallback: earlier versions always read this from the
        # notebook-global ``single_age_pop_df``. Pass ``max_age`` explicitly to
        # avoid depending on hidden notebook state.
        max_age = single_age_pop_df["Age"].max()

    # Single-year ages covered by each group: group k spans [lb_k, lb_{k+1} - 1];
    # the last group runs up to ``max_age``. These do not depend on the closure's
    # arguments, so they are built once here rather than on every call.
    lower_bounds = [int(a) for a in age_groups]
    upper_bounds = [lb - 1 for lb in lower_bounds[1:]] + [int(max_age)]
    group_ages = [jnp.arange(lb, ub + 1) for lb, ub in zip(lower_bounds, upper_bounds)]
    n_groups = len(group_ages)

    def build_mixing_matrix(a_spread, pc_strength, time):
        """Build the contact matrix for one parameter set at calendar time ``time``.

        For groups i and j the symmetric kernel is

            S[i, j] = sum_{a in i, b in j} w(a) w(b) * (
                exp(-|a - b| / a_spread) / a_spread
                + pc_strength * P(time - min(a, b), |a - b|)
            )

        with ``w = get_age_weight_jax(., time, ...)`` and
        ``P = get_agegap_prob_jax(...)``. Row i is then scaled by the population
        of group i in the reference year, ``C[i, j] = S[i, j] * N_i``, and C is
        divided by its spectral radius, so the result has spectral radius 1.

        Parameters
        ----------
        a_spread : float
            Length scale of the exponential assortative kernel.
        pc_strength : float
            Multiplier applied to the parent-child (age-gap) component.
        time : number
            Calendar time. It is passed to the lookups and used to choose the
            population row: the latest index year at or before ``time``, after
            clamping ``time`` to the data range.
        """
        def age_weights(ages):
            # Per-age weights at ``time`` for a vector of single-year ages.
            return vmap(lambda a: get_age_weight_jax(a, time, age_weights_lookup, ageweights_year0))(ages)

        # Weights depend only on the group (and time): compute each group once
        # instead of once per (i, j) pair.
        group_weights = [age_weights(ages) for ages in group_ages]

        S = jnp.zeros((n_groups, n_groups))
        for i in range(n_groups):
            ages_i, w_i = group_ages[i], group_weights[i]
            for j in range(i, n_groups):
                ages_j, w_j = group_ages[j], group_weights[j]

                # Outer product of weights over all (age in i, age in j) pairs.
                weight_prod = w_i[:, None] * w_j[None, :]

                # Age gap and younger ("child") age for each pair.
                age_gap_mat = jnp.abs(ages_i[:, None] - ages_j[None, :])
                child_age_mat = jnp.minimum(ages_i[:, None], ages_j[None, :])

                # Assortative component: exponential kernel in the age gap.
                assortative_component = jnp.sum(weight_prod * (1.0 / a_spread) * jnp.exp(-age_gap_mat / a_spread))

                # Parent-child component, evaluated at the younger member's birth year.
                pc_component = jnp.sum(weight_prod * get_agegap_prob_jax(fert_probs, fert_year0, fert_age0, time - child_age_mat, age_gap_mat))

                # S is symmetric: fill both (i, j) and (j, i).
                value = assortative_component + pc_strength * pc_component
                S = S.at[i, j].set(value)
                S = S.at[j, i].set(value)

        # Scale row i by the size of group i. Select the latest available year at
        # or before ``time`` (clamped to the data range) with a mask instead of
        # ``.loc[time]``, which raised KeyError for times not in the index
        # (e.g. fractional years).
        pop_years = grouped_pop_df.index
        clamped_time = min(max(time, pop_years.min()), pop_years.max())
        ref_year = pop_years[pop_years <= clamped_time].max()
        # NOTE(review): assumes the columns of grouped_pop_df are the age groups in
        # the same order as ``age_groups`` — confirm against get_population_over_time.
        pop_sizes = jnp.asarray(grouped_pop_df.loc[ref_year].values)
        C = S * pop_sizes[:, None]  # asymmetric contacts

        # Normalise C by its spectral radius.
        eigvals = jnp.linalg.eigvals(C)
        spectral_radius = jnp.max(jnp.abs(eigvals))
        return C / spectral_radius

    return build_mixing_matrix

# Mixing matrix visualisation

In [None]:
import matplotlib.pyplot as plt
import numpy as np
def plot_contact_matrix(M, age_groups, title, cmap="viridis"):
    """
    Plot a contact matrix as a heatmap with ticks aligned to cell centres.

    Parameters
    ----------
    M : array-like, shape (n, n)
        Matrix to plot. Row i is labelled with ``age_groups[i]`` (y-axis) and
        column j with ``age_groups[j]`` (x-axis, drawn along the top).
    age_groups : sequence of str
        Ascending lower bounds of the age groups, e.g. ``["0", "5", "15"]``.
        Labels are rendered as ``"lb-(next_lb - 1)"``, with the last group as ``"lb+"``.
    title : str
        Figure title.
    cmap : str, default "viridis"
        Matplotlib colormap name.

    Raises
    ------
    ValueError
        If ``age_groups`` is empty or ``M`` is not an (n, n) matrix for n age groups.
    """
    n = len(age_groups)
    if n == 0:
        raise ValueError("age_groups must contain at least one age group")
    if np.shape(M) != (n, n):
        raise ValueError(f"M has shape {np.shape(M)}, expected ({n}, {n}) for {n} age groups")

    fig, ax = plt.subplots(figsize=(7, 6))

    im = ax.imshow(
        M,
        origin="upper",
        cmap=cmap,
        aspect="auto",
        interpolation="none"
    )

    # Major ticks at cell centres, labelled with each group's age range.
    age_lb = [int(a) for a in age_groups]
    labels = [f"{lb}-{next_lb - 1}" for lb, next_lb in zip(age_lb, age_lb[1:])] + [f"{age_lb[-1]}+"]

    ax.set_xticks(np.arange(n))
    ax.set_yticks(np.arange(n))
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)

    # Move x-axis to the top
    ax.xaxis.set_ticks_position("top")
    ax.xaxis.set_label_position("top")
    ax.tick_params(axis="x", top=True, bottom=False)

    # Set axis limits to match matrix extent exactly
    ax.set_xlim(-0.5, n - 0.5)
    ax.set_ylim(n - 0.5, -0.5)

    # Draw gridlines on cell boundaries (minor ticks) but hide the minor tick
    # marks themselves. set_ticks_position("top") enabled top ticks for both major
    # and minor ticks, so ``top`` must be switched off here as well.
    ax.set_xticks(np.arange(-0.5, n, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, n, 1), minor=True)
    ax.grid(which="minor", color="white", linestyle="-", linewidth=0.5)
    ax.tick_params(which="minor", top=False, bottom=False, left=False, right=False)

    ax.set_xlabel("Contacting individual age group (j)")
    ax.set_ylabel("Contacted individual age group (i)")
    ax.set_title(title)

    plt.setp(ax.get_xticklabels(), rotation=0, ha="center")

    cbar = fig.colorbar(im, ax=ax)
    cbar.ax.set_ylabel("Contact rate", rotation=270, labelpad=15)

    plt.tight_layout()
    plt.show()

# Example usage

In [None]:
build_mixing_matrix = gen_mixing_matrix_func(
    grouped_pop_df,
    fert_probs, fert_year0, fert_age0,
    age_weights_lookup, ageweights_year0,
    AGE_GROUPS,
)

# Example parameters: a_spread is the length scale of the assortative age-gap
# kernel, pc_strength scales the parent-child component, time is the calendar
# year. NOTE: `time` shadows the stdlib module name if it is imported elsewhere.
a_spread = 10.0
pc_strength = 1.5
time = 1970
M = build_mixing_matrix(a_spread, pc_strength, time)

cmap = "viridis"
plot_contact_matrix(M, AGE_GROUPS, title="Asymmetric 'who contacts whom' matrix C", cmap=cmap)