In [None]:
from tbh.demographic_tools import (
    get_population_over_time, 
    build_agegap_lookup, get_agegap_prob_jax,
    build_age_weight_lookup, get_age_weight_jax
)

# ISO3 country code: selects the fertility CSV file and the population data loaded below
iso3 = 'KIR'
# Lower bounds of the age groups, kept as strings. Each group runs up to the next
# lower bound minus one; the last group is open-ended (capped at max_age when the
# contact matrix is built).
AGE_GROUPS = ["0", "5", "15", "30", "65"]


# Load population size and fertility data

In [None]:
import pandas as pd
from tbh.paths import DATA_FOLDER

# Fertility table for the selected country; the first CSV column becomes the index
fertility_data = pd.read_csv(
    DATA_FOLDER / f"un_fertility_rates_{iso3}.csv",
    index_col=0,
)
# Divide every row by its own total so that each row sums to one
normalised_fertility_data = fertility_data.div(
    fertility_data.sum(axis="columns"),
    axis="index",
)

# Population over time, both by single age and aggregated into AGE_GROUPS
single_age_pop_df, grouped_pop_df = get_population_over_time(
    iso3,
    age_groups=AGE_GROUPS,
    scaling_factor=1.0,
)


In [None]:
# Build the lookup that build_contact_matrix forwards to get_agegap_prob_jax.
# NOTE(review): fert_year0 / fert_age0 are presumably the first year and first age
# covered by the table — confirm in tbh.demographic_tools.
fert_probs, fert_year0, fert_age0 = build_agegap_lookup(normalised_fertility_data)

In [None]:
# Build the lookup that build_contact_matrix forwards to get_age_weight_jax.
# NOTE(review): ageweights_year0 is presumably the first year covered by the
# lookup — confirm in tbh.demographic_tools.
age_weights_lookup, ageweights_year0 = build_age_weight_lookup(AGE_GROUPS, single_age_pop_df)

# Main mixing matrix building function

In [None]:
import numpy as np

def build_contact_matrix(grouped_pop_df, fert_probs, fert_year0, fert_age0, age_weights_lookup, ageweights_year0, age_groups, a_spread, pc_strength, time, max_age=120):
    """
    Build the age-group contact matrices for a given time.

    For every pair of single ages (one from each group) two terms are accumulated,
    each weighted by the product of the two ages' `get_age_weight_jax` values:

    * an assortative term, ``(1 / a_spread) * exp(-age_gap / a_spread)``;
    * a term given by ``get_agegap_prob_jax`` evaluated at ``time - younger_age``
      and the age gap, scaled by ``pc_strength`` (presumably parent-child mixing
      driven by the fertility lookup — confirm against tbh.demographic_tools).

    Args:
        grouped_pop_df: Table indexed by year. The row selected for the reference
            year must hold one value per age group (assumed to be in the same order
            as `age_groups` — TODO confirm).
        fert_probs, fert_year0, fert_age0: Forwarded unchanged to `get_agegap_prob_jax`.
        age_weights_lookup, ageweights_year0: Forwarded unchanged to `get_age_weight_jax`.
        age_groups: Age-group lower bounds, convertible with `int`. Each group ends
            one below the next lower bound; the last group ends at `max_age`.
        a_spread: Scale of the exponential decay of the assortative term.
        pc_strength: Multiplier applied to the `get_agegap_prob_jax` term.
        time: Time at which the matrices are evaluated. It is clamped to the range of
            `grouped_pop_df.index` before the population row is looked up with `.loc`,
            so the clamped value must be an existing index label.
        max_age: Inclusive upper age of the last group.

    Returns:
        tuple: ``(S, normalised_C)`` where ``S`` is the symmetric per-pair matrix and
        ``normalised_C`` is ``S`` with row ``i`` multiplied by the size of group ``i``,
        rescaled so that its spectral radius is 1.

    Raises:
        ValueError: If the unnormalised matrix has a spectral radius of zero, in which
            case it cannot be rescaled.
    """
    # Inclusive single-age range covered by each age group
    lower_bounds = [int(a) for a in age_groups]
    upper_bounds = [next_lb - 1 for next_lb in lower_bounds[1:]] + [max_age]
    group_ages = [range(lb, ub + 1) for lb, ub in zip(lower_bounds, upper_bounds)]
    n_groups = len(lower_bounds)

    # The age weight depends only on the single age and the time, so evaluate it once
    # per age instead of twice per age pair (assumes get_age_weight_jax is a pure lookup).
    group_weights = [
        [get_age_weight_jax(age, time, age_weights_lookup, ageweights_year0) for age in ages]
        for ages in group_ages
    ]

    inv_spread = 1. / a_spread

    # Build symmetric per-pair contact matrix S, representing contact rate per pair of individuals
    S = np.zeros((n_groups, n_groups))
    for i in range(n_groups):
        # Only the upper triangle is computed; the lower triangle is mirrored
        for j in range(i, n_groups):
            assortative_component, pc_component = 0., 0.
            for age_i, weight_i in zip(group_ages[i], group_weights[i]):
                for age_j, weight_j in zip(group_ages[j], group_weights[j]):
                    age_gap = abs(age_i - age_j)
                    ageweight_prod = weight_i * weight_j
                    assortative_component += ageweight_prod * inv_spread * np.exp(-age_gap / a_spread)

                    # The lookup is evaluated at the time minus the younger individual's age
                    child_age = min(age_i, age_j)
                    pc_component += ageweight_prod * get_agegap_prob_jax(fert_probs, fert_year0, fert_age0, time - child_age, age_gap)

            S[i, j] = assortative_component + pc_strength * pc_component
            S[j, i] = S[i, j]

    # Convert to asymmetric matrix C, representing number of contacts per individual using pop size by agegroup.
    # The requested time is clamped to the years available in the population table.
    earliest_year, latest_year = grouped_pop_df.index.min(), grouped_pop_df.index.max()
    ref_year = min(max(time, earliest_year), latest_year)
    C = S * np.array(grouped_pop_df.loc[ref_year])[:, None]

    # Rescale C so its spectral radius is 1
    spectral_radius = np.max(np.abs(np.linalg.eigvals(C)))
    if spectral_radius == 0:
        raise ValueError(
            "Cannot normalise the contact matrix: its spectral radius is zero "
            "(all contact rates or all population sizes are zero)."
        )
    normalised_C = C / spectral_radius

    return S, normalised_C

# Mixing matrix visualisation

In [None]:
import matplotlib.pyplot as plt
def plot_contact_matrix(M, age_groups, title, cmap="viridis"):
    """
    Draw a matrix indexed by age group as a heatmap and show it.

    Ticks sit at the centre of each cell and are labelled with the age range of the
    group ("lower-upper", or "lower+" for the last group). The x axis is placed on
    top, white gridlines mark the cell boundaries and a colourbar labelled
    "Contact rate" is attached. The figure is displayed with plt.show(); nothing is
    returned.

    Args:
        M: Matrix to display (passed directly to imshow).
        age_groups: Age-group lower bounds, convertible with int.
        title: Figure title.
        cmap: Matplotlib colormap name.
    """
    n_cells = len(age_groups)

    figure, axes = plt.subplots(figsize=(7, 6))
    heatmap = axes.imshow(M, origin="upper", cmap=cmap, aspect="auto", interpolation="none")

    # Age-range labels: every group but the last is closed, the last one is open-ended
    bounds = list(map(int, age_groups))
    range_labels = [f"{lower}-{next_lower - 1}" for lower, next_lower in zip(bounds, bounds[1:])]
    range_labels.append(f"{bounds[-1]}+")

    # Major ticks at cell centres
    centres = np.arange(n_cells)
    axes.set_xticks(centres)
    axes.set_yticks(centres)
    axes.set_xticklabels(range_labels)
    axes.set_yticklabels(range_labels)

    # Put the x axis (ticks and label) along the top edge
    axes.xaxis.set_ticks_position("top")
    axes.xaxis.set_label_position("top")
    axes.tick_params(axis="x", top=True, bottom=False)

    # Limits match the matrix extent exactly; the y axis runs downwards
    axes.set_xlim(-0.5, n_cells - 0.5)
    axes.set_ylim(n_cells - 0.5, -0.5)

    # Minor ticks on the cell edges carry the white gridlines
    edges = np.arange(-0.5, n_cells, 1)
    axes.set_xticks(edges, minor=True)
    axes.set_yticks(edges, minor=True)
    axes.grid(which="minor", color="white", linestyle="-", linewidth=0.5)
    axes.tick_params(which="minor", bottom=False, left=False)

    axes.set_xlabel("Contacting individual age group (j)")
    axes.set_ylabel("Contacted individual age group (i)")
    axes.set_title(title)

    plt.setp(axes.get_xticklabels(), rotation=0, ha="center")

    colourbar = figure.colorbar(heatmap, ax=axes)
    colourbar.ax.set_ylabel("Contact rate", rotation=270, labelpad=15)

    plt.tight_layout()
    plt.show()

# Example implementation

In [None]:
# Build both matrices for the year 2000 with illustrative mixing parameters
S, C = build_contact_matrix(
    grouped_pop_df=grouped_pop_df,
    fert_probs=fert_probs,
    fert_year0=fert_year0,
    fert_age0=fert_age0,
    age_weights_lookup=age_weights_lookup,
    ageweights_year0=ageweights_year0,
    age_groups=AGE_GROUPS,
    a_spread=10.,
    pc_strength=1.5,
    time=2000,
)

# Same colormap for both figures so they are visually comparable
cmap = 'viridis'
plot_contact_matrix(S, AGE_GROUPS, title="Symmetric per-pair contact matrix S", cmap=cmap)
plot_contact_matrix(C, AGE_GROUPS, title="Asymmetric 'who contacts whom' matrix C", cmap=cmap)