In [None]:
from tbh.demographic_tools import get_population_over_time

iso3 = "KIR"  # ISO3 country code (Kiribati)
# Lower bounds (in years) of the modelled age groups, kept as strings
AGE_GROUPS = ["0", "5", "15", "30", "65"]


# Load population size and fertility data

In [None]:
import pandas as pd
from tbh.paths import DATA_FOLDER

fertility_data = pd.read_csv(
    DATA_FOLDER / f"un_fertility_rates_{iso3}.csv",
    index_col=0,
)
# Rescale each row (year) so it sums to one across the mothers'-age columns
normalised_fertility_data = fertility_data.pipe(lambda df: df.div(df.sum(axis=1), axis=0))

single_age_pop_df, grouped_pop_df = get_population_over_time(
    iso3,
    age_groups=AGE_GROUPS,
    scaling_factor=1.0,
)


# Some helper functions

In [None]:
def get_agegap_prob(normalised_fertility_data, age_gap, birth_year):
    """
    Retrieve the probability of a given parent-child age gap for a specific birth year.

    Parameters:
        normalised_fertility_data (pd.DataFrame): Fertility probabilities indexed by birth year,
            columns are mothers' ages. Column labels may be strings (e.g. "30", as read from
            a CSV header) or integers.
        age_gap (int): Age difference between parent and child.
        birth_year (int): Year of birth of the child. Years outside the data range are
            clamped to the nearest available year.

    Returns:
        float: Probability corresponding to the age gap and birth year, or 0.0 if the data
        has no column for that age (outside the observed age range, or missing within it).
    """
    # Clamp birth_year to the data range so early/late cohorts reuse the boundary year
    years = normalised_fertility_data.index
    birth_year = min(max(birth_year, years.min()), years.max())

    # Map integer ages to the actual column labels so both "30" and 30 labels work,
    # and so gaps in the age columns yield 0.0 instead of a KeyError.
    age_to_column = {int(col): col for col in normalised_fertility_data.columns}
    column = age_to_column.get(age_gap)
    if column is None:
        return 0.0
    return normalised_fertility_data.at[birth_year, column]

# Sanity check: probability of a 30-year parent-child age gap for a child born in 2000
get_agegap_prob(normalised_fertility_data, birth_year=2000, age_gap=30)

In [None]:

def get_relative_age_weight(age, lower_bound, upper_bound, single_age_pop_df, year):
    """
    Share of the population in the inclusive age band [lower_bound, upper_bound] aged exactly `age`.

    Parameters:
        age (int): Single year of age whose weight is requested; must lie within the band.
        lower_bound (int): Lowest age of the band (inclusive).
        upper_bound (int): Highest age of the band (inclusive).
        single_age_pop_df (pd.DataFrame): Population by single age, with "Time", "Age" and "Pop" columns.
        year (int): Calendar year; clamped to the years available in the data.

    Returns:
        float: Pop(age) divided by the total Pop over the band for that year, or 0.0 if the
        data has no row for `age` in that year.
    """
    assert lower_bound <= age <= upper_bound

    times = single_age_pop_df["Time"]
    year = min(max(year, times.min()), times.max())

    in_band = (times == year) & single_age_pop_df["Age"].between(lower_bound, upper_bound)
    band_df = single_age_pop_df.loc[in_band]

    age_pop = band_df.loc[band_df["Age"] == age, "Pop"]
    if age_pop.empty:
        # Ages absent from the data (e.g. older than the oldest age reported) carry no population,
        # instead of raising an opaque IndexError from .iloc[0].
        return 0.0
    return age_pop.iloc[0] / band_df["Pop"].sum()

# Sanity check: share of 11-13 year olds who are aged 12, in 2000
get_relative_age_weight(year=2000, age=12, lower_bound=11, upper_bound=13, single_age_pop_df=single_age_pop_df)


# Main mixing matrix building function

In [None]:
import numpy as np

def build_contact_matrix(normalised_fertility_data, single_age_pop_df, grouped_pop_df, age_groups, a_spread, pc_strength, time, max_age=120):
    """
    Build an age-structured contact matrix combining assortative and parent-child mixing.

    Parameters:
        normalised_fertility_data (pd.DataFrame): Fertility probabilities indexed by birth year,
            columns are mothers' ages (looked up via get_agegap_prob).
        single_age_pop_df (pd.DataFrame): Single-age population with "Time", "Age" and "Pop" columns.
        grouped_pop_df (pd.DataFrame): Population by age group, indexed by year; presumably one
            column per age group in the same order as age_groups (TODO confirm).
        age_groups (list of str): Lower bounds of the age groups, in increasing order.
        a_spread (float): Scale (in years) of the exponential decay of assortative contact with age gap.
        pc_strength (float): Weight of the parent-child component relative to the assortative one.
        time (int): Calendar year; clamped to the years available in each dataset.
        max_age (int): Upper age (inclusive) of the last, open-ended age group.

    Returns:
        tuple[np.ndarray, np.ndarray]: (S, normalised_C), where S is the symmetric per-pair
        contact matrix and normalised_C is S with row i multiplied by the population of group i,
        divided by its spectral radius.

    Raises:
        ValueError: If the population-scaled matrix has a spectral radius of zero.
    """
    # Convert age group labels to integers and derive inclusive upper bounds;
    # the last group is open-ended up to max_age.
    age_lb = [int(a) for a in age_groups]
    age_ub = [lb - 1 for lb in age_lb[1:]] + [max_age]
    n_groups = len(age_lb)

    # Within-group population weight of each single age, computed once per age rather than
    # once per pair of ages. Bounds are inclusive (matching get_relative_age_weight), so the
    # weights of a group cover exactly its own ages and not the next group's first age.
    group_weights = [
        {age: get_relative_age_weight(age, lb, ub, single_age_pop_df, time) for age in range(lb, ub + 1)}
        for lb, ub in zip(age_lb, age_ub)
    ]

    # Build symmetric per-pair contact matrix S, representing contact rate per pair of individuals
    S = np.zeros((n_groups, n_groups))
    for i in range(n_groups):
        for j in range(i, n_groups):
            assortative_component, pc_component = 0., 0.
            for age_i, age_i_weight in group_weights[i].items():
                for age_j, age_j_weight in group_weights[j].items():
                    pair_weight = age_i_weight * age_j_weight
                    age_gap = abs(age_i - age_j)

                    assortative_component += pair_weight * (1. / a_spread) * np.exp(-age_gap / a_spread)

                    # Parent-child link: the younger individual is treated as the child, born in
                    # year (time - child_age); the age gap is looked up as the parent's age at birth.
                    child_age = min(age_i, age_j)
                    pc_component += pair_weight * get_agegap_prob(normalised_fertility_data, age_gap, int(time - child_age))

            S[i, j] = assortative_component + pc_strength * pc_component
            S[j, i] = S[i, j]

    # Convert to asymmetric matrix C by multiplying row i by the population of group i
    earliest_year, latest_year = grouped_pop_df.index.min(), grouped_pop_df.index.max()
    ref_year = min(max(time, earliest_year), latest_year)
    group_pop = np.array(grouped_pop_df.loc[ref_year])
    C = S * group_pop[:, None]

    # Rescale C so its spectral radius is 1
    spectral_radius = np.max(np.abs(np.linalg.eigvals(C)))
    if spectral_radius == 0:
        raise ValueError("Contact matrix has zero spectral radius; cannot normalise it")
    normalised_C = C / spectral_radius

    return S, normalised_C

# Mixing matrix visualisation

In [None]:
import matplotlib.pyplot as plt
def plot_contact_matrix(M, age_groups, title, cmap="viridis", cbar_label="Contact rate"):
    """
    Plot a contact matrix as a heatmap with ticks aligned to cell centres.

    Parameters:
        M (array-like): Square matrix with one row/column per age group; rows are the
            contacted age group (i), columns the contacting age group (j).
        age_groups (list of str): Lower bounds of the age groups, in increasing order.
        title (str): Figure title.
        cmap (str): Matplotlib colormap name.
        cbar_label (str): Label for the colour bar.

    Raises:
        ValueError: If M is not an n x n matrix with n = len(age_groups).
    """
    M = np.asarray(M)
    n = len(age_groups)
    if M.shape != (n, n):
        raise ValueError(f"Expected a {n}x{n} matrix, got shape {M.shape}")

    fig, ax = plt.subplots(figsize=(7, 6))

    im = ax.imshow(
        M,
        origin="upper",
        cmap=cmap,
        aspect="auto",
        interpolation="none"
    )

    # Major ticks at cell centres, labelled "lb-ub" with an open-ended last group
    age_lb = [int(a) for a in age_groups]
    labels = (
        [f"{age_lb[i]}-{age_lb[i+1] - 1}" for i in range(len(age_lb) - 1)]
        + [f"{age_lb[-1]}+"]
    )

    ax.set_xticks(np.arange(n))
    ax.set_yticks(np.arange(n))
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)

    # Move x-axis to the top
    ax.xaxis.set_ticks_position("top")
    ax.xaxis.set_label_position("top")
    ax.tick_params(axis="x", top=True, bottom=False)

    # Set axis limits to match matrix extent exactly
    ax.set_xlim(-0.5, n - 0.5)
    ax.set_ylim(n - 0.5, -0.5)

    # Draw gridlines on cell boundaries. Minor ticks exist only to position the grid, so hide
    # their marks on every side; set_ticks_position("top") above moved them to the top axis.
    ax.set_xticks(np.arange(-0.5, n, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, n, 1), minor=True)
    ax.grid(which="minor", color="white", linestyle="-", linewidth=0.5)
    ax.tick_params(which="minor", top=False, bottom=False, left=False)

    ax.set_xlabel("Contacting individual age group (j)")
    ax.set_ylabel("Contacted individual age group (i)")
    ax.set_title(title)

    plt.setp(ax.get_xticklabels(), rotation=0, ha="center")

    cbar = fig.colorbar(im, ax=ax)
    cbar.ax.set_ylabel(cbar_label, rotation=270, labelpad=15)

    plt.tight_layout()
    plt.show()

# Example usage

In [None]:
# Build both matrices for the year 2000, then plot each with the same colormap
S, C = build_contact_matrix(
    normalised_fertility_data,
    single_age_pop_df,
    grouped_pop_df,
    AGE_GROUPS,
    a_spread=10.,
    pc_strength=1.5,
    time=2000,
)

cmap = 'viridis'
for matrix, matrix_title in (
    (S, "Symmetric per-pair contact matrix S"),
    (C, "Asymmetric 'who contacts whom' matrix C"),
):
    plot_contact_matrix(matrix, AGE_GROUPS, title=matrix_title, cmap=cmap)