In [None]:
from jax import numpy as jnp

from tbh.demographic_tools import get_population_over_time
from tbh.age_mixing import get_model_ready_age_mixing_matrix

# ISO3 code of the country to analyse.
iso3 = "KIR"

# Lower age bound of each model age group, as strings (the form the tbh helpers
# and the plotting function below consume).
# Alternative uniform 5-year bands: [str(a) for a in range(0, 65, 5)]
AGE_GROUPS = [str(lower) for lower in (0, 3, 5, 10, 15, 18, 40, 65)]

In [None]:
# Population data for iso3, presumably both by single year of age and aggregated
# into AGE_GROUPS (scaling_factor=1.0, i.e. no rescaling applied here).
single_age_pop_df, grouped_pop_df = get_population_over_time(
    iso3,
    age_groups=AGE_GROUPS,
    scaling_factor=1.0,
)

# Model-ready age-mixing matrix. Its `.func` attribute is the callable evaluated
# later as age_mixing_matrix_func(bg_mixing, a_spread, pc_strength, time).
age_mixing_matrix = get_model_ready_age_mixing_matrix(
    iso3,
    AGE_GROUPS,
    grouped_pop_df,
    single_age_pop_df,
)
age_mixing_matrix_func = age_mixing_matrix.func

# Mixing matrix visualisation

In [None]:
import matplotlib.pyplot as plt
import numpy as np
def plot_contact_matrix(M, age_groups, title, cmap="viridis", figsize=(7, 6)):
    """
    Plot a contact matrix as a heatmap with ticks aligned to cell centres.

    Parameters
    ----------
    M : array-like, shape (n, n)
        Square contact matrix. Row i / column j correspond to the i-th / j-th
        entry of ``age_groups``. Anything ``np.asarray`` accepts (e.g. a JAX
        array) is fine.
    age_groups : sequence of str
        Lower bounds of the age groups as integer strings, strictly increasing
        (e.g. ``["0", "5", "10"]``). Used to build tick labels such as
        ``"0-4"``, ``"5-9"`` and ``"10+"`` (the last group is open-ended).
    title : str
        Figure title.
    cmap : str, optional
        Matplotlib colormap name. Default ``"viridis"``.
    figsize : tuple of float, optional
        Figure size in inches. Default ``(7, 6)``.

    Raises
    ------
    ValueError
        If ``age_groups`` is empty or not strictly increasing, or if ``M`` is
        not an ``(n, n)`` matrix with ``n == len(age_groups)``.
    """
    n = len(age_groups)
    if n == 0:
        raise ValueError("age_groups must contain at least one age group")

    age_lb = [int(a) for a in age_groups]
    if any(lo >= hi for lo, hi in zip(age_lb, age_lb[1:])):
        raise ValueError(f"age_groups must be strictly increasing, got {list(age_groups)}")

    # Validate before plotting: a mismatched M would otherwise be silently
    # cropped by the axis limits below, or labelled with the wrong ages.
    matrix = np.asarray(M)
    if matrix.shape != (n, n):
        raise ValueError(
            f"M must have shape ({n}, {n}) to match age_groups, got {matrix.shape}"
        )

    # Inclusive range for each group; the last group is open-ended.
    labels = [f"{lo}-{hi - 1}" for lo, hi in zip(age_lb, age_lb[1:])] + [f"{age_lb[-1]}+"]

    fig, ax = plt.subplots(figsize=figsize)

    im = ax.imshow(
        matrix,
        origin="upper",
        cmap=cmap,
        aspect="auto",
        interpolation="none",
    )

    # Major ticks at cell centres, labelled with the age ranges
    ax.set_xticks(np.arange(n))
    ax.set_yticks(np.arange(n))
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)

    # Move x-axis ticks and label to the top (matrix-style layout)
    ax.xaxis.set_ticks_position("top")
    ax.xaxis.set_label_position("top")
    ax.tick_params(axis="x", top=True, bottom=False)

    # Set axis limits to match the matrix extent exactly
    ax.set_xlim(-0.5, n - 0.5)
    ax.set_ylim(n - 0.5, -0.5)

    # Minor ticks on cell boundaries exist only to anchor the gridlines.
    # set_ticks_position("top") above also enabled *minor* top ticks, so the
    # top (and right) sides must be switched off explicitly too.
    ax.set_xticks(np.arange(-0.5, n, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, n, 1), minor=True)
    ax.grid(which="minor", color="white", linestyle="-", linewidth=0.5)
    ax.tick_params(which="minor", top=False, bottom=False, left=False, right=False)

    ax.set_xlabel("Contacting individual age group (j)")
    ax.set_ylabel("Contacted individual age group (i)")
    ax.set_title(title)

    plt.setp(ax.get_xticklabels(), rotation=20, ha="center")

    cbar = fig.colorbar(im, ax=ax)
    cbar.ax.set_ylabel("Contact rate", rotation=270, labelpad=15)

    fig.tight_layout()
    plt.show()

# Example: asymmetric mixing matrix evaluated in 2025

In [None]:
# Mixing parameters passed to age_mixing_matrix_func
# (bg_mixing = 0.1 was noted as an alternative value to try).
bg_mixing = 0.0
a_spread = 10.0
pc_strength = 0.0

# Calendar year at which the matrix is evaluated, as a JAX scalar.
time = jnp.array(2025.0)

M = age_mixing_matrix_func(bg_mixing, a_spread, pc_strength, time)

cmap = "viridis"
plot_contact_matrix(M, AGE_GROUPS, title="Asymmetric matrix C", cmap=cmap)

# Animation of the mixing matrix over time (code below is currently commented out)

In [None]:
# import numpy as np
# import matplotlib.pyplot as plt
# from matplotlib.animation import FuncAnimation
# import jax.numpy as jnp
# from IPython.display import HTML

# years = np.arange(1950, 2100, 5)

# # --- Initial matrix ---
# time0 = jnp.array(float(years[0]))
# M0 = np.asarray(age_mixing_matrix_func(bg_mixing, a_spread, pc_strength, time0))

# fig, ax = plt.subplots(figsize=(6, 5))

# im = ax.imshow(M0, cmap=cmap, origin="upper")
# cbar = plt.colorbar(im, ax=ax)
# title = ax.set_title(f"Asymmetric contact matrix C — {years[0]}")

# ax.set_xticks(range(len(AGE_GROUPS)))
# ax.set_yticks(range(len(AGE_GROUPS)))
# ax.set_xticklabels(AGE_GROUPS, rotation=90)
# ax.set_yticklabels(AGE_GROUPS)
# ax.set_xlabel("Age of index individual")
# ax.set_ylabel("Age of contacted individuals")


# # Move x-axis to the top
# ax.xaxis.set_ticks_position("top")
# ax.xaxis.set_label_position("top")
# ax.tick_params(axis="x", top=True, bottom=False)


# # --- Update function ---
# def update(year):
#     time = jnp.array(float(year))
#     M = np.asarray(age_mixing_matrix_func(bg_mixing, a_spread, pc_strength, time))
#     im.set_data(M)
#     title.set_text(f"Asymmetric 'who contacts whom' matrix C — {year}")
#     return im, title

# ani = FuncAnimation(
#     fig,
#     update,
#     frames=years,
#     interval=100,   # ms between frames
#     blit=False
# )
# HTML(ani.to_jshtml())


In [None]:
# # Filename reflects the animated years: np.arange(1950, 2100, 5) -> 1950–2095
# ani.save(
#     "contact_matrix_1950_2095.gif",
#     writer="pillow",
#     fps=10
# )