In [None]:
from tbh.demographic_tools import (
    get_population_over_time, 
    build_agegap_lookup,     
    build_age_weight_lookup, 
)

# ISO3 country code and the lower bounds (as strings) of the 5-year age groups
# 0, 5, ..., 60; the last group is treated as open-ended when plotted.
iso3 = 'KIR'
AGE_GROUPS = list(map(str, range(0, 65, 5)))


# Load population size and fertility data

In [None]:
import pandas as pd
from tbh.paths import DATA_FOLDER

# UN fertility rates for the selected country; the CSV's first column is the index.
fertility_data = pd.read_csv(
    DATA_FOLDER / f"un_fertility_rates_{iso3}.csv", index_col=0
)
# Rescale every row so that its entries sum to one.
normalised_fertility_data = fertility_data.div(
    fertility_data.sum(axis="columns"), axis="index"
)

# Population tables at single-year and AGE_GROUPS resolution (per the helper's
# naming — TODO confirm against tbh.demographic_tools).
single_age_pop_df, grouped_pop_df = get_population_over_time(
    iso3, age_groups=AGE_GROUPS, scaling_factor=1.0
)


In [None]:
# Age-gap lookup built from the row-normalised fertility table; the two extra
# outputs are presumably the year/age offsets of the lookup (TODO confirm).
fert_probs, fert_year0, fert_age0 = build_agegap_lookup(
    normalised_fertility_data
)

In [None]:
# Age-weight lookup over AGE_GROUPS from the single-year population table;
# the second output is presumably the lookup's first year (TODO confirm).
age_weights_lookup, ageweights_year0 = build_age_weight_lookup(
    AGE_GROUPS, single_age_pop_df
)

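# Main mixing matrix building function

The builder is provided by `tbh.age_mixing.gen_mixing_matrix_func`; it is imported and called in the example below.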

# Mixing matrix visualisation

In [None]:
import matplotlib.pyplot as plt
import numpy as np
def plot_contact_matrix(M, age_groups, title, cmap="viridis"):
    """
    Plot a contact matrix as a heatmap with ticks aligned to cell centres.

    Rows are labelled as the contacted age group (i) and columns as the
    contacting age group (j); the x-axis is drawn along the top of the plot.

    Parameters
    ----------
    M : array-like of shape (n, n)
        Contact matrix; anything ``np.asarray`` accepts (e.g. a JAX array).
    age_groups : sequence of str or int
        Lower bounds of the n age groups in ascending order. Each group is
        labelled "lo-(next_lo - 1)"; the last one is labelled "lo+".
    title : str
        Figure title.
    cmap : str, optional
        Matplotlib colormap name (default "viridis").

    Raises
    ------
    ValueError
        If ``age_groups`` is empty or ``M`` is not an n-by-n matrix.
    """
    M = np.asarray(M)
    n = len(age_groups)
    if n == 0:
        raise ValueError("age_groups must contain at least one age group")
    if M.shape != (n, n):
        # A mismatched matrix would otherwise be drawn with misaligned labels.
        raise ValueError(
            f"M must have shape ({n}, {n}) to match age_groups, got {M.shape}"
        )

    fig, ax = plt.subplots(figsize=(7, 6))

    im = ax.imshow(
        M,
        origin="upper",
        cmap=cmap,
        aspect="auto",
        interpolation="none"
    )

    # Tick labels such as "0-4", "5-9", ..., "60+"
    age_lb = [int(a) for a in age_groups]
    labels = [f"{lo}-{hi - 1}" for lo, hi in zip(age_lb[:-1], age_lb[1:])]
    labels.append(f"{age_lb[-1]}+")

    # Major ticks at cell centres
    ax.set_xticks(np.arange(n))
    ax.set_yticks(np.arange(n))
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)

    # Move x-axis to the top
    ax.xaxis.set_ticks_position("top")
    ax.xaxis.set_label_position("top")
    ax.tick_params(axis="x", top=True, bottom=False)

    # Set axis limits to match matrix extent exactly
    ax.set_xlim(-0.5, n - 0.5)
    ax.set_ylim(n - 0.5, -0.5)

    # Draw gridlines on cell boundaries
    ax.set_xticks(np.arange(-0.5, n, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, n, 1), minor=True)
    ax.grid(which="minor", color="white", linestyle="-", linewidth=0.5)
    # The minor ticks only exist to place the gridlines, so hide their tick
    # marks on every side. `set_ticks_position("top")` above switches top ticks
    # on for minor ticks too, hence `top=False` is needed as well.
    ax.tick_params(which="minor", top=False, bottom=False, left=False, right=False)

    ax.set_xlabel("Contacting individual age group (j)")
    ax.set_ylabel("Contacted individual age group (i)")
    ax.set_title(title)

    plt.setp(ax.get_xticklabels(), rotation=0, ha="center")

    cbar = fig.colorbar(im, ax=ax)
    cbar.ax.set_ylabel("Contact rate", rotation=270, labelpad=15)

    plt.tight_layout()
    plt.show()

# Example usage

In [None]:
from tbh.age_mixing import gen_mixing_matrix_func
from jax import numpy as jnp

In [None]:
build_mixing_matrix = gen_mixing_matrix_func(
    grouped_pop_df,
    fert_probs, fert_year0, fert_age0,
    age_weights_lookup, ageweights_year0,
    AGE_GROUPS,
)

# Example parameter values passed positionally to the mixing-matrix builder
bg_mixing = 0.01
a_spread = 10.0
pc_strength = 0.8

# Evaluate the matrix for a single calendar year
time = jnp.array(2025.0)
M = build_mixing_matrix(bg_mixing, a_spread, pc_strength, time)

cmap = "viridis"
plot_contact_matrix(M, AGE_GROUPS, title="Asymmetric 'who contacts whom' matrix C", cmap=cmap)

In [None]:
# Row 0 / column 0: the 0-4 contacted group (i) vs the 0-4 contacting group (j),
# following the axis convention used by plot_contact_matrix.
M[0, 0]

# Animation

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import jax.numpy as jnp
from IPython.display import HTML

# Animation years: 1950, 1955, ..., 2095 (np.arange excludes the stop value).
years = np.arange(1950, 2100, 5)


def matrix_at(year):
    """Return the mixing matrix evaluated at calendar `year` as a NumPy array."""
    return np.asarray(
        build_mixing_matrix(bg_mixing, a_spread, pc_strength, jnp.array(float(year)))
    )


# --- Pre-compute every frame ---
# `im.set_data` (used in `update`) does not rescale the colour map. A colour
# scale fitted to the first frame alone would therefore saturate any later
# year with larger entries, and colours would not be comparable across years.
# Computing all frames first lets us fix one colour range for the whole
# animation. It also avoids rebuilding every matrix each time the animation
# is rendered or saved.
matrix_by_year = {year: matrix_at(year) for year in years}
vmin = min(float(m.min()) for m in matrix_by_year.values())
vmax = max(float(m.max()) for m in matrix_by_year.values())

TITLE_FMT = "Asymmetric 'who contacts whom' matrix C — {}"

# --- Initial figure ---
fig, ax = plt.subplots(figsize=(6, 5))

im = ax.imshow(matrix_by_year[years[0]], cmap=cmap, origin="upper", vmin=vmin, vmax=vmax)
cbar = fig.colorbar(im, ax=ax)
cbar.ax.set_ylabel("Contact rate", rotation=270, labelpad=15)
title = ax.set_title(TITLE_FMT.format(years[0]))

ax.set_xticks(range(len(AGE_GROUPS)))
ax.set_yticks(range(len(AGE_GROUPS)))
ax.set_xticklabels(AGE_GROUPS, rotation=90)
ax.set_yticklabels(AGE_GROUPS)
ax.set_xlabel("Contacting individual age group (j)")
ax.set_ylabel("Contacted individual age group (i)")

# Move x-axis to the top
ax.xaxis.set_ticks_position("top")
ax.xaxis.set_label_position("top")
ax.tick_params(axis="x", top=True, bottom=False)


# --- Update function ---
def update(year):
    """Show the pre-computed matrix for `year` (an element of `years`)."""
    im.set_data(matrix_by_year[year])
    title.set_text(TITLE_FMT.format(year))
    return im, title


ani = FuncAnimation(
    fig,
    update,
    frames=years,
    interval=100,   # ms between frames
    blit=False
)
HTML(ani.to_jshtml())

In [None]:
# Name the GIF after the years actually animated (years[0]..years[-1]); the
# frame rate matches the 100 ms interval of the in-notebook animation.
ani.save(
    f"contact_matrix_{years[0]}_{years[-1]}.gif",
    writer="pillow",
    fps=10,
)