In [None]:
from IPython.display import display, Markdown, HTML
from tbh.paths import REPO_ROOT_PATH, DATA_FOLDER
# Output folder of the calibration run analysed in this notebook.
# Alternative run (disabled):
# analysis_path = REPO_ROOT_PATH / "remote_cluster" / "outputs" / "49574599_25sc_revised_se" / "task_1"
analysis_path = (
    REPO_ROOT_PATH
    / "remote_cluster"
    / "outputs"
    / "49954154_heterogenous_mixing"
    / "task_1"
)


In [None]:
import tbh.plotting as pl
import tbh.runner_tools as rt

import pandas as pd
import arviz as az
from matplotlib import pyplot as plt 
# Apply matplotlib's "ggplot" style sheet to all subsequently created figures.
plt.style.use(style="ggplot")


In [None]:
# Load the saved ArviZ InferenceData (including the posterior) for this run.
idata = az.from_netcdf(filename=analysis_path / "idata.nc")

In [None]:
import yaml

# details.yaml is a multi-document YAML stream; this notebook reads the model
# configuration from document 1 and the analysis configuration from document 2
# (document 0 is not used here).
with open(analysis_path / "details.yaml", "r", encoding="utf-8") as f:
    docs = list(yaml.safe_load_all(f))

if len(docs) < 3:
    raise ValueError(
        f"Expected at least 3 YAML documents in details.yaml, found {len(docs)}"
    )

model_config = docs[1]
analysis_config = docs[2]

In [None]:
# Discard the first `burn_in` draws of every chain. `isel` slices by position,
# so the result does not depend on how the "draw" coordinate is labelled
# (label-based `sel` is only correct when the labels run 0..n-1).
chain_length = idata.posterior.sizes["draw"]
burn_in = int(analysis_config['burn_in'])
if not 0 <= burn_in < chain_length:
    raise ValueError(
        f"burn_in={burn_in} must satisfy 0 <= burn_in < chain length ({chain_length})"
    )
burnt_idata = idata.isel(draw=slice(burn_in, None))

# Flatten (chain, draw) into a single "sample" axis. Both parameters are
# stacked the same way, so index i refers to the same posterior draw in each
# array (assumes both parameters are scalar per draw -- TODO confirm).
child_socialising = burnt_idata.posterior['child_socialising'].stack(sample=("chain", "draw")).values
elderly_socialising = burnt_idata.posterior['elderly_socialising'].stack(sample=("chain", "draw")).values

In [None]:
import numpy as np
from numpy.linalg import eigvals


def build_mixing_matrix(
    child_socialising,
    elderly_socialising,
    age_groups=None,
    child_age_limit=15,
    elderly_age_limit=65,
):
    """Build a rank-one age-mixing matrix rescaled to unit spectral radius.

    Each age group receives a socialising weight: ``child_socialising`` if its
    age (``int(age)``) is below ``child_age_limit``, otherwise
    ``elderly_socialising`` if it is at or above ``elderly_age_limit``,
    otherwise 1.0. The mixing matrix is the outer product of these weights,
    divided by its spectral radius.

    Parameters
    ----------
    child_socialising, elderly_socialising : float
        Socialising weights for the child and elderly age groups.
    age_groups : sequence, optional
        Age-group labels convertible with ``int`` (e.g. ``"0"`` or ``15``).
        Defaults to ``model_config['age_groups']``.
    child_age_limit : int, default 15
        Groups with age strictly below this use ``child_socialising``.
    elderly_age_limit : int, default 65
        Groups with age at or above this use ``elderly_socialising``
        (the child rule takes precedence if the two ranges overlap).

    Returns
    -------
    numpy.ndarray
        Symmetric matrix of shape ``(len(age_groups), len(age_groups))`` whose
        spectral radius is 1.

    Raises
    ------
    ValueError
        If every socialising weight is zero (the spectral radius is zero and
        the matrix cannot be rescaled), including when ``age_groups`` is empty.
    """
    if age_groups is None:
        age_groups = model_config['age_groups']

    ages = np.array([int(age) for age in age_groups])
    socialising = np.where(
        ages < child_age_limit,
        child_socialising,
        np.where(ages >= elderly_age_limit, elderly_socialising, 1.0),
    ).astype(float)

    # M = s s^T has rank one: its only non-zero eigenvalue is s . s, so the
    # spectral radius is ||s||^2 exactly -- no eigendecomposition needed.
    rho = float(socialising @ socialising)
    if rho == 0.0:
        raise ValueError(
            "All socialising weights are zero; cannot rescale the mixing matrix."
        )

    return np.outer(socialising, socialising) / rho

In [None]:
# M = build_mixing_matrix(child_socialising[0], elderly_socialising[0])

In [None]:
def plot_mixing_matrix_uncertainty(
    child_samples,
    elderly_samples,
    ci=(0.025, 0.975),
    cmap="viridis",
    width_cmap="magma",
):
    """
    Plot median mixing matrix and uncertainty from posterior samples.

    One mixing matrix is built per posterior sample with
    ``build_mixing_matrix``. The left panel shows the element-wise posterior
    median; the right panel shows the element-wise width of the credible
    interval bounded by the ``ci`` quantiles.

    Parameters
    ----------
    child_samples, elderly_samples : sequence of float
        Paired posterior samples of the child and elderly socialising
        parameters (same length, same ordering).
    ci : tuple of two floats, default (0.025, 0.975)
        Lower and upper quantiles defining the credible interval.
    cmap : str, default "viridis"
        Colormap for the median panel.
    width_cmap : str, default "magma"
        Colormap for the interval-width panel.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If the two sample sequences differ in length or are empty.
    """
    n_samples = len(child_samples)
    if len(elderly_samples) != n_samples:
        raise ValueError(
            f"child_samples ({n_samples}) and elderly_samples "
            f"({len(elderly_samples)}) must have the same length"
        )
    if n_samples == 0:
        raise ValueError("No posterior samples to summarise")

    # Stack to shape (n_samples, n_groups, n_groups); the matrix shape comes
    # from build_mixing_matrix itself rather than being pre-allocated.
    matrices = np.stack([
        build_mixing_matrix(child, elderly)
        for child, elderly in zip(child_samples, elderly_samples)
    ])

    # Element-wise posterior summaries
    M_median = np.median(matrices, axis=0)
    M_low, M_high = np.quantile(matrices, ci, axis=0)
    M_width = M_high - M_low

    labels = model_config['age_groups']
    ticks = range(len(labels))

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    panels = [
        (axes[0], M_median, cmap, "Median mixing matrix"),
        (axes[1], M_width, width_cmap, f"{ci[1] - ci[0]:.0%} credible interval width"),
    ]
    for ax, values, panel_cmap, title in panels:
        im = ax.imshow(values, cmap=panel_cmap)
        ax.set_title(title)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(labels, rotation=45, ha="right")
        ax.set_yticklabels(labels)
        ax.set_xlabel("Age group")
        ax.set_ylabel("Age group")
        fig.colorbar(im, ax=ax, fraction=0.046)

    fig.tight_layout()
    return fig

In [None]:
# Posterior median mixing matrix alongside the width of its 95% credible
# interval (default ci quantiles 0.025 / 0.975).
fig = plot_mixing_matrix_uncertainty(
    child_samples=child_socialising,
    elderly_samples=elderly_socialising,
)
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt





def plot_random_posterior_mixing_matrices(
    child_samples,
    elderly_samples,
    age_groups,
    n_draws=9,
    cmap="viridis",
    seed=123
):
    """
    Plot mixing matrices from randomly sampled posterior draws.

    Draws are chosen without replacement with a seeded NumPy generator. One
    matrix is built per draw with ``build_mixing_matrix``, and all panels share
    a single colour scale and a single colourbar.

    Parameters
    ----------
    child_samples, elderly_samples : sequence of float
        Paired posterior samples (same length, same ordering).
    age_groups : sequence
        Tick labels for the matrix rows and columns.
    n_draws : int, default 9
        Number of draws to plot; capped at the number of available samples.
    cmap : str, default "viridis"
        Colormap shared by all panels.
    seed : int, default 123
        Seed for the generator that selects the draws.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If the sample sequences differ in length or are empty, or if
        ``n_draws < 1``.
    """
    n_samples = len(child_samples)
    if len(elderly_samples) != n_samples:
        raise ValueError(
            f"child_samples ({n_samples}) and elderly_samples "
            f"({len(elderly_samples)}) must have the same length"
        )
    if n_samples == 0:
        raise ValueError("No posterior samples to plot")
    if n_draws < 1:
        raise ValueError(f"n_draws must be at least 1, got {n_draws}")
    # Sampling without replacement cannot pick more draws than exist;
    # without this cap rng.choice raises ValueError.
    n_draws = min(n_draws, n_samples)

    rng = np.random.default_rng(seed)
    idx = rng.choice(n_samples, size=n_draws, replace=False)

    # ---- Compute all matrices first so the colour scale can be shared ----
    matrices = [
        build_mixing_matrix(child_samples[i], elderly_samples[i])
        for i in idx
    ]
    vmin = min(M.min() for M in matrices)
    vmax = max(M.max() for M in matrices)

    # ---- Near-square grid of panels ----
    n_groups = len(age_groups)
    ncols = int(np.ceil(np.sqrt(n_draws)))
    nrows = int(np.ceil(n_draws / ncols))

    # Constrained layout reserves room for the colourbar shared by all axes;
    # tight_layout cannot account for a colourbar attached to a list of axes.
    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(3.5 * ncols, 3.5 * nrows),
        squeeze=False,
        constrained_layout=True,
    )

    # ---- Plot ----
    for ax, M, i in zip(axes.flat, matrices, idx):
        im = ax.imshow(M, cmap=cmap, vmin=vmin, vmax=vmax)
        ax.set_xticks(range(n_groups))
        ax.set_yticks(range(n_groups))
        ax.set_xticklabels(age_groups, rotation=45, ha="right")
        ax.set_yticklabels(age_groups)
        ax.set_title(f"Draw {i}")

    # Hide the unused axes left over in the last row
    for ax in axes.flat[len(matrices):]:
        ax.axis("off")

    # Single shared colorbar (n_draws >= 1 guarantees `im` is defined)
    cbar = fig.colorbar(
        im,
        ax=axes,
        orientation="horizontal",
        fraction=0.05,
        pad=0.08
    )
    cbar.set_label("Relative mixing intensity")

    return fig

# With n_draws=400 (and at least that many posterior samples) the grid is
# 20 x 20 panels, i.e. a 70 x 70 inch figure; lower n_draws for a quicker,
# more readable overview.
fig = plot_random_posterior_mixing_matrices(
    child_samples=child_socialising,
    elderly_samples=elderly_socialising,
    age_groups=model_config['age_groups'],
    n_draws=400
)
plt.show()