## Environment Setup


In [None]:
import copy
import os
from typing import Dict, List, Tuple

from IPython.display import Video
import matplotlib.animation as animation
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import math
import numpy as np
from tqdm.auto import tqdm


In [None]:
# Record the run environment and imported-package versions (IPython watermark extension).
%load_ext watermark
%watermark -diwmuv -iv


## Simulation Implementation


In [None]:
# State Constants: host status codes stored in the ``host_statuses`` arrays.
# S = susceptible, I = infected, R = recovered.
S: int = 0
I: int = S + 1  # noqa: E741
R: int = I + 1


def simulate_with_frames(
    N_SITES: int = 2,
    POP_SIZE: int = 100_000,
    BASE_B: float = 0.3,
    CONTACT_RATE: float = 0.5,
    RECOVERY_RATE: float = 0.1,
    MUTATION_RATE: float = 1e-3,
    WANING_RATE: float = 0.016,
    IMMUNE_STRENGTH: float = 0.7,
    WANED_STRENGTH: float = 0.05,
    N_STEPS: int = 300,
    seed: int = 1,
) -> List[Dict]:
    """Run the multi-strain ABM and collect one summary frame per step.

    A pathogen genome is an ``N_SITES``-bit integer; bit ``k`` selects allele
    ``2 * k + bit`` out of ``2 * N_SITES`` alleles. Each host stores one
    immunity level per allele: 3 = naive (no protection), 2 = full protection
    granted on recovery, 1 = full protection one waning step later,
    0 = waned (protection weight ``WANED_STRENGTH``).

    Every step applies, in order: recoveries, infections (one uniformly random
    contact per infected host), and immunity waning; a frame is then recorded.
    Recovered (R) hosts stay eligible for reinfection, damped by their immunity.

    Args:
        N_SITES: Number of biallelic antigenic sites (at most 8, because
            genomes are stored as ``uint8``).
        POP_SIZE: Number of hosts.
        BASE_B: Per-site transmission coefficient; total is ``N_SITES * BASE_B``.
        CONTACT_RATE: Multiplier on the per-contact transmission probability.
        RECOVERY_RATE: Per-step probability that an infected host recovers.
        MUTATION_RATE: Per-transmission probability of flipping one random site.
        WANING_RATE: Per-step, per-allele probability of a protected allele
            (level 2 or 1) dropping one level.
        IMMUNE_STRENGTH: Each strain allele multiplies the target's
            susceptibility by ``1 - IMMUNE_STRENGTH * kappa``.
        WANED_STRENGTH: ``kappa`` of a waned (level-0) allele.
        N_STEPS: Number of steps, and therefore number of frames.
        seed: Seed for NumPy's global RNG.

    Returns:
        One dict per step with keys ``"step"``, ``"status_dist"`` (S/I/R
        fractions of the population), ``"strain_prev"`` (fraction of hosts
        infected by each of the ``2**N_SITES`` strains), ``"allele_susc"``
        (mean ``1 - IMMUNE_STRENGTH * kappa`` per allele) and
        ``"immunity_dist"`` (4 x ``2 * N_SITES`` fraction of hosts per level).

    Raises:
        ValueError: If ``N_SITES`` is larger than 8.
    """
    if N_SITES > 8:
        # uint8 genomes cannot hold more than 8 sites: mutations at higher bits
        # would be truncated away and silently do nothing.
        raise ValueError(
            f"N_SITES must be at most 8 (genomes are uint8), got {N_SITES}"
        )
    np.random.seed(seed)

    num_strains = 2**N_SITES
    num_alleles = 2 * N_SITES

    def initialize_pop() -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Create an all-susceptible population with naive (level-3) immunity."""
        pathogen_genomes = np.zeros(shape=POP_SIZE, dtype=np.uint8)
        host_immunities = np.full(
            shape=(POP_SIZE, num_alleles), fill_value=3, dtype=np.int8
        )
        host_statuses = np.full(shape=POP_SIZE, fill_value=S, dtype=np.uint8)
        return host_statuses, pathogen_genomes, host_immunities

    def infect_initial(
        host_statuses: np.ndarray,
        pathogen_genomes: np.ndarray,
        seed_count: int = 100,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Infect the first ``seed_count`` hosts with strain 0."""
        host_statuses[:seed_count] = I
        pathogen_genomes[:seed_count] = 0
        return host_statuses, pathogen_genomes

    def get_kappa(host_immunities: np.ndarray) -> np.ndarray:
        """Map immunity levels to protection weights (1/2 -> 1.0, 0 -> waned, 3 -> 0)."""
        kappas = np.zeros_like(host_immunities, dtype=np.float16)
        kappas[(host_immunities == 2) | (host_immunities == 1)] = 1.0
        kappas[host_immunities == 0] = WANED_STRENGTH
        return kappas

    def update_waning(host_immunities: np.ndarray) -> np.ndarray:
        """Drop each protected allele (level 2 or 1) by one level with prob WANING_RATE.

        Both masks are computed from the pre-step levels, so an allele falls at
        most one level per step. A level-2 allele that just dropped to 1 must
        not be re-drawn at level 1 in the same step. The two random draws
        happen in the same order as a level-2-then-level-1 sweep.
        """
        shape = host_immunities.shape
        fresh = (host_immunities == 2) & (np.random.rand(*shape) < WANING_RATE)
        older = (host_immunities == 1) & (np.random.rand(*shape) < WANING_RATE)
        host_immunities[fresh | older] -= 1
        return host_immunities

    def update_recoveries(
        host_statuses: np.ndarray,
        pathogen_genomes: np.ndarray,
        host_immunities: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Recover infected hosts with prob RECOVERY_RATE.

        A recovering host gets level-2 immunity to its strain's alleles and is
        marked R.
        """
        inf_mask = host_statuses == I
        rec_mask = inf_mask & (np.random.rand(POP_SIZE) < RECOVERY_RATE)
        indices = np.where(rec_mask)[0]

        if indices.size > 0:
            # Decode each recovering host's genome into its per-site allele index.
            g = pathogen_genomes[indices][:, None]
            shifts = np.arange(N_SITES, dtype=np.uint64)
            bits = (g >> shifts) & np.uint8(1)
            allele_indices = (2 * np.arange(N_SITES) + bits).astype(int)

            row_idx = np.repeat(indices, N_SITES)
            col_idx = allele_indices.flatten()
            host_immunities[row_idx, col_idx] = 2
            host_statuses[indices] = R
        return host_statuses, host_immunities

    def update_infections(
        host_statuses: np.ndarray,
        pathogen_genomes: np.ndarray,
        host_immunities: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Let each infected host attempt to infect one uniformly random host.

        A transmission to a non-infected target succeeds with probability
        ``N_SITES * BASE_B * CONTACT_RATE * prod(1 - IMMUNE_STRENGTH * kappa)``
        over the strain's alleles. The transmitted genome mutates (one random
        bit flip) with prob MUTATION_RATE.
        """
        infector_mask = host_statuses == I
        num_infectors = int(np.sum(infector_mask))
        if num_infectors == 0:
            return host_statuses, pathogen_genomes, host_immunities

        targets = np.random.randint(
            low=0, high=POP_SIZE, size=num_infectors, dtype=np.uint32
        )
        inf_genomes = pathogen_genomes[infector_mask]

        bits = (
            inf_genomes[:, None] >> np.arange(N_SITES, dtype=np.uint8)
        ) & np.uint8(1)
        allele_indices = (2 * np.arange(N_SITES) + bits).astype(int)

        # Target's protection against exactly the alleles this strain carries.
        kappas = get_kappa(host_immunities[targets])
        target_kappas = np.take_along_axis(kappas, allele_indices, axis=1)
        susc_factor = np.prod(1.0 - (IMMUNE_STRENGTH * target_kappas), axis=1)

        total_b = N_SITES * BASE_B
        prob = total_b * CONTACT_RATE * susc_factor

        success = (np.random.rand(num_infectors) < prob) & (
            host_statuses[targets] != I
        )
        new_inf_idx = targets[success]

        if new_inf_idx.size > 0:
            host_statuses[new_inf_idx] = I
            new_genomes = inf_genomes[success]

            mut_mask = np.random.rand(new_inf_idx.size) < MUTATION_RATE
            if np.any(mut_mask):
                num_mut = int(np.sum(mut_mask))
                flip_pos = np.random.randint(
                    low=0, high=N_SITES, size=num_mut
                ).astype(np.uint64)
                new_genomes[mut_mask] ^= np.uint64(1) << flip_pos

            # If several infectors hit the same target, the last write wins.
            pathogen_genomes[new_inf_idx] = new_genomes

        return host_statuses, pathogen_genomes, host_immunities

    def collect_frame_data(
        t: int,
        host_statuses: np.ndarray,
        pathogen_genomes: np.ndarray,
        host_immunities: np.ndarray,
    ) -> Dict:
        """Collect population-level summaries for a single animation frame."""
        # Host status distribution (S, I, R)
        status_counts = np.array(
            [
                np.sum(host_statuses == S),
                np.sum(host_statuses == I),
                np.sum(host_statuses == R),
            ]
        ) / POP_SIZE

        # Strain prevalence over all 2**N_SITES strains (only infected hosts count).
        strain_counts = (
            np.bincount(
                pathogen_genomes[host_statuses == I], minlength=num_strains
            )
            / POP_SIZE
        )

        # Mean susceptibility contributed by each allele across the population.
        pop_kappas = get_kappa(host_immunities)
        allele_susc = np.mean(1.0 - (IMMUNE_STRENGTH * pop_kappas), axis=0)

        # Immunity level distribution per allele (4 levels x num_alleles)
        immunity_dist = np.zeros((4, num_alleles))
        for level in range(4):
            immunity_dist[level] = np.mean(host_immunities == level, axis=0)

        return {
            "step": t,
            "status_dist": status_counts,
            "strain_prev": strain_counts,
            "allele_susc": allele_susc,
            "immunity_dist": immunity_dist,
        }

    # Run simulation
    host_statuses, pathogen_genomes, host_immunities = initialize_pop()
    host_statuses, pathogen_genomes = infect_initial(
        host_statuses, pathogen_genomes
    )
    frames: List[Dict] = []

    for t in tqdm(range(N_STEPS), desc="Simulating"):
        host_statuses, host_immunities = update_recoveries(
            host_statuses, pathogen_genomes, host_immunities
        )
        host_statuses, pathogen_genomes, host_immunities = update_infections(
            host_statuses, pathogen_genomes, host_immunities
        )
        host_immunities = update_waning(host_immunities)

        frames.append(
            collect_frame_data(
                t, host_statuses, pathogen_genomes, host_immunities
            )
        )

    return frames


## Run Simulation


In [None]:
# Two biallelic sites -> 2**2 = 4 strains; N_SITES is reused by the rendering cells below.
N_SITES = 2
frames = simulate_with_frames(
    seed=42,
    N_SITES=N_SITES,
    N_STEPS=300,
    POP_SIZE=100_000,
    MUTATION_RATE=1e-3,
)


## Create Animation


In [None]:
def create_animation(
    frames: List[Dict],
    n_sites: int,
    output_path: str,
    fps: int = 15,
) -> str:
    """Render per-step summary frames into an MP4 with four heatmap panels.

    Panels:
      * host S/I/R fractions,
      * strain prevalence (one cell per strain, labelled with its binary genome),
      * mean susceptibility per allele, laid out as variant (rows) x site (columns),
      * fraction of hosts at each immunity level per allele.

    Args:
        frames: Output of ``simulate_with_frames`` (one dict per step).
        n_sites: Number of antigenic sites used in the simulation.
        output_path: Destination file, written with Matplotlib's FFmpeg writer.
        fps: Frames per second of the output video.

    Returns:
        ``output_path``.
    """
    num_strains = 2**n_sites
    num_alleles = 2 * n_sites

    def allele_grid(allele_values: np.ndarray) -> np.ndarray:
        # Allele index is 2 * site + variant, so a C-order reshape gives
        # (site, variant); transpose to rows = variant, columns = site so the
        # data matches the "Allele"/"Site" axis labels for any n_sites.
        return np.asarray(allele_values).reshape(n_sites, 2).T

    # Determine strain grid dimensions
    strain_rows = int(np.ceil(np.sqrt(num_strains)))
    strain_cols = int(np.ceil(num_strains / strain_rows))

    # Create figure with subplots for each matrix
    fig = plt.figure(figsize=(14, 8))
    try:
        gs = fig.add_gridspec(
            2, 3, width_ratios=[1.5, 1, 1.5], height_ratios=[1, 1], hspace=0.3, wspace=0.3
        )

        # Create axes for each matrix
        ax_status = fig.add_subplot(gs[0, 0])
        ax_strain = fig.add_subplot(gs[0, 1])
        ax_susc = fig.add_subplot(gs[0, 2])
        ax_immunity = fig.add_subplot(gs[1, :])

        # Host status (discrete: S, I, R)
        status_data = frames[0]["status_dist"].reshape(1, 3)
        im_status = ax_status.imshow(
            status_data, cmap="RdYlGn", vmin=0, vmax=1, aspect="auto"
        )
        ax_status.set_title("Host Status Distribution", fontweight="bold")
        ax_status.set_xticks([0, 1, 2])
        ax_status.set_xticklabels(["S", "I", "R"])
        ax_status.set_yticks([])

        # Strain prevalence (continuous); unused trailing grid cells stay 0.
        strain_data = np.zeros((strain_rows, strain_cols))
        strain_data.flat[: len(frames[0]["strain_prev"])] = frames[0]["strain_prev"]
        im_strain = ax_strain.imshow(
            strain_data, cmap="viridis", vmin=0, vmax=0.01, aspect="equal"
        )
        ax_strain.set_title("Strain Prevalence", fontweight="bold")
        # Label strains with binary codes
        for i in range(num_strains):
            row, col = divmod(i, strain_cols)
            binary_label = format(i, f"0{n_sites}b")
            ax_strain.text(
                col, row, binary_label, ha="center", va="center",
                fontsize=8, color="white", fontweight="bold"
            )
        ax_strain.set_xticks([])
        ax_strain.set_yticks([])

        # Allele susceptibility (continuous), rows = allele variant, cols = site.
        im_susc = ax_susc.imshow(
            allele_grid(frames[0]["allele_susc"]),
            cmap="plasma", vmin=0, vmax=1, aspect="auto",
        )
        ax_susc.set_title("Allele Susceptibility", fontweight="bold")
        ax_susc.set_xlabel("Site")
        ax_susc.set_ylabel("Allele")
        ax_susc.set_xticks(range(n_sites))
        ax_susc.set_yticks([0, 1])
        ax_susc.set_yticklabels(["0", "1"])

        # Immunity level distribution (discrete levels 0-3)
        immunity_data = frames[0]["immunity_dist"]
        im_immunity = ax_immunity.imshow(
            immunity_data, cmap="coolwarm", vmin=0, vmax=1, aspect="auto"
        )
        ax_immunity.set_title("Immunity Level Distribution", fontweight="bold")
        ax_immunity.set_xlabel("Allele Index")
        ax_immunity.set_ylabel("Immunity Level")
        ax_immunity.set_xticks(range(num_alleles))
        ax_immunity.set_yticks([0, 1, 2, 3])
        ax_immunity.set_yticklabels(["Waned", "Full-2", "Full-1", "Naive"])

        # Add step counter
        step_text = fig.suptitle("Step: 0", fontsize=14, fontweight="bold")

        def update(frame_idx):
            """Push frame ``frame_idx`` into every panel."""
            frame = frames[frame_idx]

            im_status.set_data(frame["status_dist"].reshape(1, 3))

            strain_grid = np.zeros((strain_rows, strain_cols))
            strain_grid.flat[: len(frame["strain_prev"])] = frame["strain_prev"]
            # Rescale per frame so low prevalences stay visible (floor 0.001).
            max_prev = max(0.001, np.max(frame["strain_prev"]))
            im_strain.set_data(strain_grid)
            im_strain.set_clim(0, max_prev)

            im_susc.set_data(allele_grid(frame["allele_susc"]))

            im_immunity.set_data(frame["immunity_dist"])

            step_text.set_text(f"Step: {frame['step']}")

            return [im_status, im_strain, im_susc, im_immunity, step_text]

        anim = animation.FuncAnimation(
            fig,
            update,
            frames=len(frames),
            interval=1000 // fps,
            blit=False,
        )

        writer = animation.FFMpegWriter(fps=fps, bitrate=1800)
        anim.save(output_path, writer=writer)
    finally:
        # Close even when rendering fails (e.g. FFmpeg missing) so the figure
        # is not leaked or auto-displayed by the notebook backend.
        plt.close(fig)

    return output_path


In [None]:
# All outputs for this notebook go under teeplots/<notebook_name>/.
notebook_name = "2025-02-04-abm-animation"
output_dir = f"teeplots/{notebook_name}"
os.makedirs(output_dir, exist_ok=True)  # no-op if the folder already exists

output_path = f"{output_dir}/combined-abm-animation.mp4"
create_animation(frames, N_SITES, output_path, fps=15)
print(f"Animation saved to: {output_path}")


## Display Animation


In [None]:
# Show the combined animation inline; embed=True stores the video data in the notebook output.
Video(output_path, embed=True, width=800)


## Create Granular Animation


In [None]:
def simulate_raw_frames(
    N_SITES: int = 2,
    POP_SIZE: int = 100_000,
    BASE_B: float = 0.3,
    CONTACT_RATE: float = 0.5,
    RECOVERY_RATE: float = 0.1,
    MUTATION_RATE: float = 1e-3,
    WANING_RATE: float = 0.016,
    IMMUNE_STRENGTH: float = 0.7,
    WANED_STRENGTH: float = 0.05,
    N_STEPS: int = 300,
    seed: int = 1,
) -> List[Dict]:
    """
    Run simulation and collect RAW state arrays (deep copies) for every frame.

    The model and the per-step order (recoveries, infections, waning) are the
    same as ``simulate_with_frames``, written inline. Parameters have the same
    meaning as there.

    Note: This consumes more memory than the aggregate version. Each frame
    holds copies of all three state arrays, about
    ``POP_SIZE * (2 + 2 * N_SITES)`` bytes per step.

    Returns:
        One dict per step with keys ``"step"``, ``"host_statuses"`` (uint8,
        S/I/R codes), ``"pathogen_genomes"`` (uint8 genome per host, only
        updated on infection) and ``"host_immunities"`` (int8,
        POP_SIZE x 2*N_SITES immunity levels 0-3).

    Raises:
        ValueError: If ``N_SITES`` is larger than 8 (genomes are uint8).
    """
    if N_SITES > 8:
        # uint8 genomes cannot hold more than 8 sites; higher-bit mutations
        # would be truncated away silently.
        raise ValueError(
            f"N_SITES must be at most 8 (genomes are uint8), got {N_SITES}"
        )
    np.random.seed(seed)

    num_alleles = 2 * N_SITES

    # 1. Initialize: everyone susceptible with naive (level-3) immunity.
    pathogen_genomes = np.zeros(shape=POP_SIZE, dtype=np.uint8)
    host_immunities = np.full(shape=(POP_SIZE, num_alleles), fill_value=3, dtype=np.int8)
    host_statuses = np.full(shape=POP_SIZE, fill_value=S, dtype=np.uint8)

    # 2. Infect Initial: the first seed_count hosts carry strain 0.
    seed_count = 100
    host_statuses[:seed_count] = I
    pathogen_genomes[:seed_count] = 0

    def get_kappa_local(h_imm):
        """Map immunity levels to protection weights (1/2 -> 1.0, 0 -> waned, 3 -> 0)."""
        k = np.zeros_like(h_imm, dtype=np.float16)
        k[(h_imm == 2) | (h_imm == 1)] = 1.0
        k[h_imm == 0] = WANED_STRENGTH
        return k

    frames = []

    for t in tqdm(range(N_STEPS), desc="Simulating Raw Data"):
        # Recoveries: grant level-2 immunity to the carried strain's alleles.
        inf_mask = host_statuses == I
        rec_mask = inf_mask & (np.random.rand(POP_SIZE) < RECOVERY_RATE)
        indices = np.where(rec_mask)[0]
        if indices.size > 0:
            g = pathogen_genomes[indices][:, None]
            shifts = np.arange(N_SITES, dtype=np.uint64)
            bits = (g >> shifts) & np.uint8(1)
            allele_indices = (2 * np.arange(N_SITES) + bits).astype(int)
            row_idx = np.repeat(indices, N_SITES)
            col_idx = allele_indices.flatten()
            host_immunities[row_idx, col_idx] = 2
            host_statuses[indices] = R

        # Infections: each infected host contacts one uniformly random host.
        infector_mask = host_statuses == I
        num_infectors = int(np.sum(infector_mask))
        if num_infectors > 0:
            targets = np.random.randint(0, POP_SIZE, size=num_infectors, dtype=np.uint32)
            inf_genomes = pathogen_genomes[infector_mask]

            # Target susceptibility to exactly the alleles this strain carries.
            bits = (inf_genomes[:, None] >> np.arange(N_SITES, dtype=np.uint8)) & np.uint8(1)
            allele_indices = (2 * np.arange(N_SITES) + bits).astype(int)
            kappas = get_kappa_local(host_immunities[targets])
            target_kappas = np.take_along_axis(kappas, allele_indices, axis=1)
            susc_factor = np.prod(1.0 - (IMMUNE_STRENGTH * target_kappas), axis=1)

            prob = (N_SITES * BASE_B) * CONTACT_RATE * susc_factor
            success = (np.random.rand(num_infectors) < prob) & (host_statuses[targets] != I)
            new_inf_idx = targets[success]

            if new_inf_idx.size > 0:
                host_statuses[new_inf_idx] = I
                new_genomes = inf_genomes[success]
                # Mutation: flip one random site with prob MUTATION_RATE.
                mut_mask = np.random.rand(new_inf_idx.size) < MUTATION_RATE
                if np.any(mut_mask):
                    num_mut = int(np.sum(mut_mask))
                    flip_pos = np.random.randint(0, N_SITES, size=num_mut).astype(np.uint64)
                    new_genomes[mut_mask] ^= np.uint64(1) << flip_pos
                pathogen_genomes[new_inf_idx] = new_genomes

        # Waning: both masks come from the pre-step levels so an allele drops at
        # most one level per step (a 2 -> 1 allele must not also go 1 -> 0 now).
        imm_shape = host_immunities.shape
        fresh = (host_immunities == 2) & (np.random.rand(*imm_shape) < WANING_RATE)
        older = (host_immunities == 1) & (np.random.rand(*imm_shape) < WANING_RATE)
        host_immunities[fresh | older] -= 1

        # Copy the arrays: they keep being mutated in place by later steps.
        frames.append({
            "step": t,
            "host_statuses": host_statuses.copy(),
            "pathogen_genomes": pathogen_genomes.copy(),
            "host_immunities": host_immunities.copy()
        })

    return frames


In [None]:
def reshape_to_square_with_nan(flat_arr: np.ndarray) -> np.ndarray:
    """
    Lay a 1D array out row-major on the smallest square grid that holds it.

    Cells past the end of the input are NaN, so the result is always float32.
    """
    count = flat_arr.size
    edge = math.ceil(math.sqrt(count))
    grid = np.empty((edge, edge), dtype=np.float32)
    grid.fill(np.nan)
    # reshape(-1) of a fresh contiguous array is a view, so this writes into grid.
    grid.reshape(-1)[:count] = flat_arr
    return grid

def create_masked_matrix_animation(
    frames: list,
    n_sites: int,
    output_path: str,
    fps: int = 15
) -> str:
    """Render raw per-host state as an MP4 of square "one pixel per host" maps.

    Layout on a black figure. NaN padding from ``reshape_to_square_with_nan``
    is drawn black through each colormap's ``set_bad``.
      * Row 1: host status (S/I/R) and stored pathogen strain ID per host.
      * Row 2: host immunity level for each of the ``2 * n_sites`` alleles.
      * Row 3: whether each host's stored pathogen genome carries that allele.

    Note: the pathogen panels read ``pathogen_genomes`` for every host,
    whatever its status. The simulators in this notebook only overwrite an
    entry on infection, so non-infected hosts show their last (or initial, 0)
    strain.

    Args:
        frames: Output of ``simulate_raw_frames`` (dicts with ``"step"``,
            ``"host_statuses"``, ``"pathogen_genomes"``, ``"host_immunities"``).
        n_sites: Number of antigenic sites used in the simulation.
        output_path: Destination file, written with Matplotlib's FFmpeg writer.
        fps: Frames per second of the output video.

    Returns:
        ``output_path``.
    """
    num_alleles = 2 * n_sites
    # Strain IDs range over 0 .. 2**n_sites - 1. Keep the original 16-entry
    # palette for n_sites <= 4 and widen it for larger genomes so IDs above 15
    # do not all saturate to the top colour.
    n_strain_colors = max(16, 2**n_sites)

    # --- Figure Setup ---
    fig = plt.figure(figsize=(4 * num_alleles, 12), facecolor='black') # Black bg for consistency
    try:
        gs = fig.add_gridspec(3, num_alleles, height_ratios=[1.2, 1, 1], hspace=0.4)

        # --- Axes Setup ---
        # Axes are turned off (style_ax), so the black figure shows through the gaps.
        mid_point = num_alleles // 2
        ax_status = fig.add_subplot(gs[0, :mid_point])
        ax_genome = fig.add_subplot(gs[0, mid_point:])

        imm_axes = [fig.add_subplot(gs[1, i]) for i in range(num_alleles)]
        path_allele_axes = [fig.add_subplot(gs[2, i]) for i in range(num_alleles)]

        # Helper to style axes text for black background
        def style_ax(ax, title):
            ax.set_title(title, fontsize=10, fontfamily='monospace', color='white')
            ax.axis('off')

        # --- Colormaps with Masking (NaN padding -> black) ---

        # 1. Status: 0=S, 1=I, 2=R
        cmap_status = copy.copy(mcolors.ListedColormap(['lightgray', 'crimson', 'mediumseagreen']))
        cmap_status.set_bad(color='black')

        # 2. Genome: Integer IDs. matplotlib.cm.get_cmap was removed in
        # Matplotlib 3.9; pyplot.get_cmap(name, lut) is the supported equivalent.
        cmap_genome = copy.copy(plt.get_cmap("tab20", n_strain_colors))
        cmap_genome.set_bad(color='black')

        # 3. Immunity levels: 0=Waned, 1 and 2=full protection (2 is fresher), 3=Naive
        cmap_imm = copy.copy(mcolors.ListedColormap(['salmon', 'gold', 'cornflowerblue', 'whitesmoke']))
        cmap_imm.set_bad(color='black')

        # 4. Pathogen Bits: 0=Absent, 1=Present
        cmap_bit = copy.copy(mcolors.ListedColormap(['whitesmoke', 'rebeccapurple']))
        cmap_bit.set_bad(color='black')

        # --- Helper Logic ---
        def get_pathogen_allele_presence(genomes, site_idx, allele_variant):
            bits = (genomes >> site_idx) & 1
            return (bits == allele_variant).astype(np.float32) # Convert to float for NaN padding later

        # --- Initialization ---
        f0 = frames[0]

        # Row 1
        im_status = ax_status.imshow(reshape_to_square_with_nan(f0["host_statuses"]), cmap=cmap_status, vmin=0, vmax=2)
        style_ax(ax_status, "Host Statuses\n(S/I/R)")

        im_genome = ax_genome.imshow(
            reshape_to_square_with_nan(f0["pathogen_genomes"]),
            cmap=cmap_genome, vmin=0, vmax=n_strain_colors - 1,
        )
        style_ax(ax_genome, "Pathogen Genomes\n(Strain ID)")

        # Rows 2 & 3: one panel per allele; allele index i = 2 * site + variant.
        im_imm_list = []
        im_path_list = []

        for i in range(num_alleles):
            site_idx = i // 2
            variant = i % 2

            # Immunity
            imm_data = reshape_to_square_with_nan(f0["host_immunities"][:, i])
            im_imm = imm_axes[i].imshow(imm_data, cmap=cmap_imm, vmin=0, vmax=3)
            style_ax(imm_axes[i], f"Host Imm\nSite {site_idx} Allele {variant}")
            im_imm_list.append(im_imm)

            # Pathogen Allele
            pres_data = get_pathogen_allele_presence(f0["pathogen_genomes"], site_idx, variant)
            path_data = reshape_to_square_with_nan(pres_data)
            im_path = path_allele_axes[i].imshow(path_data, cmap=cmap_bit, vmin=0, vmax=1)
            style_ax(path_allele_axes[i], f"Pathogen Allele\nSite {site_idx} Allele {variant}")
            im_path_list.append(im_path)

        step_text = fig.suptitle(f"Step: {f0['step']}", fontsize=16, fontweight='bold', color='white')

        # --- Update Loop ---
        def update(frame_idx):
            frame = frames[frame_idx]

            im_status.set_data(reshape_to_square_with_nan(frame["host_statuses"]))
            im_genome.set_data(reshape_to_square_with_nan(frame["pathogen_genomes"]))

            for i in range(num_alleles):
                # Immunity
                im_imm_list[i].set_data(reshape_to_square_with_nan(frame["host_immunities"][:, i]))

                # Pathogen
                site_idx = i // 2
                variant = i % 2
                pres_data = get_pathogen_allele_presence(frame["pathogen_genomes"], site_idx, variant)
                im_path_list[i].set_data(reshape_to_square_with_nan(pres_data))

            step_text.set_text(f"Step: {frame['step']}")
            return [im_status, im_genome, step_text] + im_imm_list + im_path_list

        anim = animation.FuncAnimation(
            fig, update, frames=len(frames), interval=1000 // fps, blit=False
        )

        writer = animation.FFMpegWriter(fps=fps, bitrate=3000)
        anim.save(output_path, writer=writer)
    finally:
        # Close even if rendering fails (e.g. FFmpeg missing) so the figure is not leaked.
        plt.close(fig)
    return output_path


## Run Granular Animation


In [None]:
# Use the shared N_SITES for both simulation and rendering so the genome /
# allele layout the renderer assumes matches what was simulated.
raw_frames = simulate_raw_frames(N_SITES=N_SITES, POP_SIZE=100_000, N_STEPS=300, seed=42)

output_masked_path = f"{output_dir}/masked_matrix_animation.mp4"
print(f"Rendering masked animation to {output_masked_path}...")
create_masked_matrix_animation(raw_frames, N_SITES, output_masked_path, fps=15)


## Display Granular Animation


In [None]:
# Show the per-host animation inline; embed=True stores the video data in the notebook output.
Video(output_masked_path, embed=True, width=1000)
