In [None]:
import numpy as np

def build_contact_matrix(age_groups, pop_sizes: dict, a_spread, pc_strength, pc_gap, pc_spread):
    """
    Build an age-structured contact matrix from a parametric mixing model.

    Each age group is represented by a single age: the midpoint of its integer
    age band, or (lower bound + 10) for the final, open-ended band. The
    per-pair contact rate between groups i and j is

        S[i, j] = exp(-d / a_spread)
                  + pc_strength * exp(-|d - pc_gap| / pc_spread)

    where d is the absolute difference between the two representative ages.

    Parameters
    ----------
    age_groups : sequence
        Lower bound of each age band; every element must be convertible with
        int(). The midpoint formula assumes ascending order (not checked).
    pop_sizes : dict
        Maps each element of `age_groups` to that group's population size.
    a_spread : float
        Decay scale (in years) of the assortative term. Must be > 0.
    pc_strength : float
        Weight of the parent-child term relative to the assortative term.
    pc_gap : float
        Age difference (in years) at which the parent-child term peaks.
    pc_spread : float
        Decay scale (in years) of the parent-child term. Must be > 0.

    Returns
    -------
    S : numpy.ndarray, shape (n, n)
        Symmetric per-pair contact matrix.
    normalised_C : numpy.ndarray, shape (n, n)
        S with row i multiplied by the population of group i, then divided by
        the spectral radius of that product so its spectral radius is 1.

    Raises
    ------
    ValueError
        If `age_groups` is empty, if a spread parameter is not strictly
        positive, or if the population-weighted matrix has a zero or
        non-finite spectral radius (e.g. all populations are zero).
    KeyError
        If an age group is missing from `pop_sizes`.
    """
    if len(age_groups) == 0:
        raise ValueError("age_groups must contain at least one age group")
    # `not x > 0` (rather than `x <= 0`) also rejects NaN.
    if not a_spread > 0 or not pc_spread > 0:
        raise ValueError("a_spread and pc_spread must be strictly positive")

    # Convert age group labels to integers
    age_lb = np.array([int(a) for a in age_groups])

    # Representative ages (midpoints of [lb_i, lb_{i+1} - 1]); the last group
    # is open-ended "A+" and is represented by A + 10
    mid = np.append(0.5 * (age_lb[:-1] + age_lb[1:] - 1), age_lb[-1] + 10.0)

    # Pairwise absolute age differences; symmetric by construction, so S is too
    age_diff = np.abs(mid[:, None] - mid[None, :])

    # Assortative mixing: contact decays with age difference
    assortative = np.exp(-age_diff / a_spread)

    # Parent–child mixing: contact peaks at an age difference of pc_gap
    parent_child = pc_strength * np.exp(-np.abs(age_diff - pc_gap) / pc_spread)

    # Symmetric per-pair contact matrix
    S = assortative + parent_child

    # Convert to asymmetric matrix C: row i of S is scaled by the size of group i
    pop = np.array([pop_sizes[age] for age in age_groups], dtype=float)
    C = S * pop[:, None]

    # Rescale C so its spectral radius is 1
    spectral_radius = np.max(np.abs(np.linalg.eigvals(C)))
    if not np.isfinite(spectral_radius) or spectral_radius == 0:
        raise ValueError(
            "cannot normalise contact matrix: spectral radius is zero or non-finite"
        )
    normalised_C = C / spectral_radius

    return S, normalised_C

In [None]:
import matplotlib.pyplot as plt
def plot_contact_matrix(M, age_groups, title, cmap="viridis"):
    """
    Plot a contact matrix as a heatmap with ticks aligned to cell centres.

    Rows are drawn top to bottom and the x-axis is placed on top, so the
    figure reads like the printed matrix. The figure is shown with
    ``plt.show()``; nothing is returned.

    Parameters
    ----------
    M : array-like, shape (n, n)
        Matrix to display; entry (i, j) is drawn at row i, column j.
    age_groups : sequence
        Lower bound of each age band (convertible with int()). Tick labels
        are "lb-(next_lb - 1)" for each band and "lb+" for the last one.
    title : str
        Figure title.
    cmap : str, optional
        Matplotlib colormap name passed to ``imshow``.

    Raises
    ------
    ValueError
        If the first two dimensions of `M` are not (n, n), where n is the
        number of age groups.
    """
    n = len(age_groups)
    if np.shape(M)[:2] != (n, n):
        raise ValueError(
            f"M has shape {np.shape(M)} but {n} age groups were given"
        )

    fig, ax = plt.subplots(figsize=(7, 6))

    im = ax.imshow(
        M,
        origin="upper",
        cmap=cmap,
        aspect="auto",
        interpolation="none"
    )

    # Major ticks at cell centres
    age_lb = [int(a) for a in age_groups]
    labels = (
        [f"{age_lb[i]}-{age_lb[i+1] - 1}" for i in range(len(age_lb) - 1)]
        + [f"{age_lb[-1]}+"]
    )

    ax.set_xticks(np.arange(n))
    ax.set_yticks(np.arange(n))
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)

    # Move x-axis to the top
    ax.xaxis.set_ticks_position("top")
    ax.xaxis.set_label_position("top")
    ax.tick_params(axis="x", top=True, bottom=False)

    # Set axis limits to match matrix extent exactly
    ax.set_xlim(-0.5, n - 0.5)
    ax.set_ylim(n - 0.5, -0.5)

    # Draw gridlines on cell boundaries
    ax.set_xticks(np.arange(-0.5, n, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, n, 1), minor=True)
    ax.grid(which="minor", color="white", linestyle="-", linewidth=0.5)
    # Minor ticks exist only to position the gridlines, so hide their marks.
    # The x-axis ticks live on the top edge, so `top` must be disabled too.
    ax.tick_params(which="minor", top=False, bottom=False, left=False)

    ax.set_xlabel("Contacting individual age group (j)")
    ax.set_ylabel("Contacted individual age group (i)")
    ax.set_title(title)

    # Labels sit above the axes: anchor each label's left end at its tick so
    # the rotated text rises away from the cell centre instead of drifting left.
    plt.setp(ax.get_xticklabels(), rotation=20, ha="left", rotation_mode="anchor")

    cbar = fig.colorbar(im, ax=ax)
    cbar.ax.set_ylabel("Contact rate", rotation=270, labelpad=15)

    plt.tight_layout()
    plt.show()

In [None]:
# Age bands are identified by their lower bound; the last band ("65") is open-ended.
age_groups = ["0", "3", "5", "10", "15", "18", "40", "65"]

# Population of each band, listed in the same order as `age_groups`.
pop_sizes = dict(zip(age_groups, [1000, 1200, 1500, 2000, 1800, 2200, 2500, 1600]))

# S: symmetric per-pair matrix; C: population-weighted matrix with spectral radius 1.
S, C = build_contact_matrix(
    age_groups,
    pop_sizes,
    a_spread=5,
    pc_strength=1.0,
    pc_gap=25,
    pc_spread=5,
)

# One heatmap per matrix.
plot_contact_matrix(S, age_groups, "Symmetric per-pair contact matrix S")
plot_contact_matrix(C, age_groups, "Asymmetric 'who contacts whom' matrix C")