## Environment Setup


In [1]:
import os
from typing import Dict, List, Tuple

from IPython.display import Video
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Record the run timestamp, interpreter, machine and imported-package versions
# in the cell output, so the saved results can be tied to an environment.
%load_ext watermark
%watermark -diwmuv -iv


Last updated: 2026-02-04T17:50:23.304328+00:00

Python implementation: CPython
Python version       : 3.10.12
IPython version      : 7.31.1

Compiler    : GCC 11.4.0
OS          : Linux
Release     : 6.8.0-1044-azure
Machine     : x86_64
Processor   : x86_64
CPU cores   : 4
Architecture: 64bit

numpy     : 2.1.2
matplotlib: 3.9.2

Watermark: 2.4.3



## Simulation Implementation


In [3]:
# Host state codes stored in the per-host status array.
#   S: never infected; I: currently infected;
#   R: recovered (not currently infected; may be reinfected, modulated by immunity).
S: int = 0
I: int = 1  # noqa: E741
R: int = 2


def simulate_with_frames(
    N_SITES: int = 2,
    POP_SIZE: int = 100_000,
    BASE_B: float = 0.3,
    CONTACT_RATE: float = 0.5,
    RECOVERY_RATE: float = 0.1,
    MUTATION_RATE: float = 1e-3,
    WANING_RATE: float = 0.016,
    IMMUNE_STRENGTH: float = 0.7,
    WANED_STRENGTH: float = 0.05,
    N_STEPS: int = 300,
    seed: int = 1,
) -> List[Dict]:
    """Run the multi-strain agent-based model and collect per-step frame data.

    A pathogen genome is an ``N_SITES``-bit integer (``2**N_SITES`` strains).
    Bit ``k`` of a genome selects immunity slot ``2 * k + bit``, so each host
    carries ``2 * N_SITES`` slots. Slot levels: 3 = naive (no protection),
    2 = set on recovery, 1 = after one waning step (levels 2 and 1 both give
    full weight), 0 = waned (weight ``WANED_STRENGTH``).

    Each step, in order: infected hosts recover, every infected host contacts
    one uniformly random host, then immunity slots wane.

    Args:
        N_SITES: Number of binary genome sites.
        POP_SIZE: Number of hosts.
        BASE_B: Per-site transmissibility; the total is ``N_SITES * BASE_B``.
        CONTACT_RATE: Multiplier on the per-contact infection probability.
        RECOVERY_RATE: Per-step probability that an infected host recovers.
        MUTATION_RATE: Probability that a newly transmitted genome has one
            uniformly chosen site flipped.
        WANING_RATE: Per-step probability that a slot at level 2 or 1 drops
            exactly one level.
        IMMUNE_STRENGTH: Per-site susceptibility is
            ``1 - IMMUNE_STRENGTH * kappa`` where kappa is the slot weight.
        WANED_STRENGTH: Slot weight at level 0 (levels 1-2 weigh 1, naive 0).
        N_STEPS: Number of simulated steps (one frame each).
        seed: Seed for NumPy's global RNG (reseeded on every call).

    Returns:
        One dict per step with keys ``"step"``, ``"status_dist"`` (S/I/R
        fractions, shape ``(3,)``), ``"strain_prev"`` (fraction of all hosts
        infected with each strain, shape ``(2**N_SITES,)``), ``"allele_susc"``
        (mean susceptibility per slot, shape ``(2*N_SITES,)``) and
        ``"immunity_dist"`` (fraction of hosts at each level 0-3 per slot,
        shape ``(4, 2*N_SITES)``).
    """
    np.random.seed(seed)

    num_strains = 2**N_SITES
    num_alleles = 2 * N_SITES
    # Smallest unsigned dtype able to hold every N_SITES-bit genome; a fixed
    # uint8 would silently drop mutations on sites >= 8.
    genome_dtype = np.min_scalar_type(num_strains - 1)
    one = genome_dtype.type(1)
    site_shifts = np.arange(N_SITES, dtype=genome_dtype)
    site_offsets = 2 * np.arange(N_SITES)

    def allele_indices_of(genomes: np.ndarray) -> np.ndarray:
        """Map genomes of shape (n,) to slot indices (n, N_SITES): 2*site + bit."""
        bits = (genomes[:, None] >> site_shifts) & one
        return site_offsets + bits.astype(np.intp)

    def initialize_pop() -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return fresh (statuses, genomes, immunities): all S, genome 0, all naive."""
        pathogen_genomes = np.zeros(shape=POP_SIZE, dtype=genome_dtype)
        host_immunities = np.full(
            shape=(POP_SIZE, num_alleles), fill_value=3, dtype=np.int8
        )
        host_statuses = np.full(shape=POP_SIZE, fill_value=S, dtype=np.uint8)
        return host_statuses, pathogen_genomes, host_immunities

    def infect_initial(
        host_statuses: np.ndarray,
        pathogen_genomes: np.ndarray,
        seed_count: int = 100,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Infect the first ``seed_count`` hosts with genome 0 (in place)."""
        host_statuses[:seed_count] = I
        pathogen_genomes[:seed_count] = 0
        return host_statuses, pathogen_genomes

    def get_kappa(host_immunities: np.ndarray) -> np.ndarray:
        """Slot weights: 1.0 at levels 1-2, WANED_STRENGTH at 0, 0.0 when naive (3)."""
        # float64 rather than float16: float16 rounds IMMUNE_STRENGTH * kappa
        # (e.g. 0.7 -> 0.7002) and the error compounds in per-site products.
        kappas = np.zeros(host_immunities.shape, dtype=np.float64)
        kappas[(host_immunities == 2) | (host_immunities == 1)] = 1.0
        kappas[host_immunities == 0] = WANED_STRENGTH
        return kappas

    def update_waning(host_immunities: np.ndarray) -> np.ndarray:
        """Drop each slot at level 2 or 1 by one level with prob. WANING_RATE (in place)."""
        # Decide every waning event before decrementing, so a slot can fall at
        # most one level per step (a 2 -> 1 slot is not re-tested as a 1).
        wanes = ((host_immunities == 2) | (host_immunities == 1)) & (
            np.random.rand(*host_immunities.shape) < WANING_RATE
        )
        host_immunities[wanes] -= 1
        return host_immunities

    def update_recoveries(
        host_statuses: np.ndarray,
        pathogen_genomes: np.ndarray,
        host_immunities: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Recover infected hosts; set the slots matching their strain to level 2."""
        recovering = (host_statuses == I) & (
            np.random.rand(POP_SIZE) < RECOVERY_RATE
        )
        idx = np.flatnonzero(recovering)

        if idx.size > 0:
            slots = allele_indices_of(pathogen_genomes[idx])
            # idx[:, None] broadcasts against slots -> one write per (host, site).
            host_immunities[idx[:, None], slots] = 2
            host_statuses[idx] = R
        return host_statuses, host_immunities

    def update_infections(
        host_statuses: np.ndarray,
        pathogen_genomes: np.ndarray,
        host_immunities: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Each infected host contacts one random host and may transmit (with mutation)."""
        infector_mask = host_statuses == I
        num_infectors = int(np.count_nonzero(infector_mask))
        if num_infectors == 0:
            return host_statuses, pathogen_genomes, host_immunities

        targets = np.random.randint(
            low=0, high=POP_SIZE, size=num_infectors, dtype=np.uint32
        )
        inf_genomes = pathogen_genomes[infector_mask]

        # Target susceptibility = product over sites of (1 - strength * weight)
        # for the target's slots matching the infector's alleles.
        slots = allele_indices_of(inf_genomes)
        target_kappas = np.take_along_axis(
            get_kappa(host_immunities[targets]), slots, axis=1
        )
        susc_factor = np.prod(1.0 - (IMMUNE_STRENGTH * target_kappas), axis=1)

        total_b = N_SITES * BASE_B
        prob = total_b * CONTACT_RATE * susc_factor

        # Hosts that are already infected cannot be superinfected.
        success = (np.random.rand(num_infectors) < prob) & (
            host_statuses[targets] != I
        )
        new_inf_idx = targets[success]

        if new_inf_idx.size > 0:
            host_statuses[new_inf_idx] = I
            new_genomes = inf_genomes[success]

            mut_mask = np.random.rand(new_inf_idx.size) < MUTATION_RATE
            if np.any(mut_mask):
                num_mut = int(np.count_nonzero(mut_mask))
                flip_pos = np.random.randint(
                    low=0, high=N_SITES, size=num_mut
                ).astype(genome_dtype)
                # Same dtype on both sides: no cast can truncate high bits.
                new_genomes[mut_mask] ^= one << flip_pos

            # A host hit by several infectors this step keeps one of their genomes.
            pathogen_genomes[new_inf_idx] = new_genomes

        return host_statuses, pathogen_genomes, host_immunities

    def collect_frame_data(
        t: int,
        host_statuses: np.ndarray,
        pathogen_genomes: np.ndarray,
        host_immunities: np.ndarray,
    ) -> Dict:
        """Summarize the population state at step ``t`` for one animation frame."""
        # Statuses only take the values S=0, I=1, R=2.
        status_counts = np.bincount(host_statuses, minlength=3) / POP_SIZE

        # Genomes of non-infected hosts are stale, so count infected hosts only.
        infected = host_statuses == I
        strain_counts = (
            np.bincount(
                pathogen_genomes[infected].astype(np.intp), minlength=num_strains
            )
            / POP_SIZE
        )

        allele_susc = np.mean(
            1.0 - (IMMUNE_STRENGTH * get_kappa(host_immunities)), axis=0
        )

        # Row `level` = fraction of hosts whose slot is at that level (0-3).
        immunity_dist = np.zeros((4, num_alleles))
        for level in range(4):
            immunity_dist[level] = np.mean(host_immunities == level, axis=0)

        return {
            "step": t,
            "status_dist": status_counts,
            "strain_prev": strain_counts,
            "allele_susc": allele_susc,
            "immunity_dist": immunity_dist,
        }

    # Run simulation
    host_statuses, pathogen_genomes, host_immunities = initialize_pop()
    host_statuses, pathogen_genomes = infect_initial(
        host_statuses, pathogen_genomes
    )
    frames: List[Dict] = []

    for t in tqdm(range(N_STEPS), desc="Simulating"):
        host_statuses, host_immunities = update_recoveries(
            host_statuses, pathogen_genomes, host_immunities
        )
        host_statuses, pathogen_genomes, host_immunities = update_infections(
            host_statuses, pathogen_genomes, host_immunities
        )
        host_immunities = update_waning(host_immunities)

        frames.append(
            collect_frame_data(
                t, host_statuses, pathogen_genomes, host_immunities
            )
        )

    return frames


## Run Simulation


In [4]:
# Two genome sites -> 2**2 = 4 strains; parameters not listed keep their defaults.
N_SITES = 2
frames = simulate_with_frames(
    seed=42,
    N_STEPS=300,
    POP_SIZE=100_000,
    MUTATION_RATE=1e-3,
    N_SITES=N_SITES,
)


Simulating: 100%|██████████| 300/300 [00:09<00:00, 32.20it/s]


## Create Animation


In [5]:
def create_animation(
    frames: List[Dict],
    n_sites: int,
    output_path: str,
    fps: int = 15,
) -> str:
    """Create animated visualization of the ABM simulation."""
    num_strains = 2**n_sites
    num_alleles = 2 * n_sites

    # Determine strain grid dimensions
    strain_rows = int(np.ceil(np.sqrt(num_strains)))
    strain_cols = int(np.ceil(num_strains / strain_rows))

    # Create figure with subplots for each matrix
    fig = plt.figure(figsize=(14, 8))
    gs = fig.add_gridspec(
        2, 3, width_ratios=[1.5, 1, 1.5], height_ratios=[1, 1], hspace=0.3, wspace=0.3
    )

    # Create axes for each matrix
    ax_status = fig.add_subplot(gs[0, 0])
    ax_strain = fig.add_subplot(gs[0, 1])
    ax_susc = fig.add_subplot(gs[0, 2])
    ax_immunity = fig.add_subplot(gs[1, :])

    # Initialize plots
    # Host status (discrete: S, I, R)
    status_data = frames[0]["status_dist"].reshape(1, 3)
    im_status = ax_status.imshow(
        status_data, cmap="RdYlGn", vmin=0, vmax=1, aspect="auto"
    )
    ax_status.set_title("Host Status Distribution", fontweight="bold")
    ax_status.set_xticks([0, 1, 2])
    ax_status.set_xticklabels(["S", "I", "R"])
    ax_status.set_yticks([])

    # Strain prevalence (continuous)
    strain_data = np.zeros((strain_rows, strain_cols))
    strain_data.flat[: len(frames[0]["strain_prev"])] = frames[0]["strain_prev"]
    im_strain = ax_strain.imshow(
        strain_data, cmap="viridis", vmin=0, vmax=0.01, aspect="equal"
    )
    ax_strain.set_title("Strain Prevalence", fontweight="bold")
    # Label strains with binary codes
    for i in range(num_strains):
        row, col = divmod(i, strain_cols)
        binary_label = format(i, f"0{n_sites}b")
        ax_strain.text(
            col, row, binary_label, ha="center", va="center",
            fontsize=8, color="white", fontweight="bold"
        )
    ax_strain.set_xticks([])
    ax_strain.set_yticks([])

    # Allele susceptibility (continuous)
    susc_data = frames[0]["allele_susc"].reshape(2, n_sites)
    im_susc = ax_susc.imshow(
        susc_data, cmap="plasma", vmin=0, vmax=1, aspect="auto"
    )
    ax_susc.set_title("Allele Susceptibility", fontweight="bold")
    ax_susc.set_xlabel("Site")
    ax_susc.set_ylabel("Allele")
    ax_susc.set_xticks(range(n_sites))
    ax_susc.set_yticks([0, 1])
    ax_susc.set_yticklabels(["0", "1"])

    # Immunity level distribution (discrete levels 0-3)
    immunity_data = frames[0]["immunity_dist"]
    im_immunity = ax_immunity.imshow(
        immunity_data, cmap="coolwarm", vmin=0, vmax=1, aspect="auto"
    )
    ax_immunity.set_title("Immunity Level Distribution", fontweight="bold")
    ax_immunity.set_xlabel("Allele Index")
    ax_immunity.set_ylabel("Immunity Level")
    ax_immunity.set_xticks(range(num_alleles))
    ax_immunity.set_yticks([0, 1, 2, 3])
    ax_immunity.set_yticklabels(["Waned", "Full-2", "Full-1", "Naive"])

    # Add step counter
    step_text = fig.suptitle("Step: 0", fontsize=14, fontweight="bold")

    def update(frame_idx):
        """Update function for animation."""
        frame = frames[frame_idx]

        # Update host status
        im_status.set_data(frame["status_dist"].reshape(1, 3))

        # Update strain prevalence
        strain_data = np.zeros((strain_rows, strain_cols))
        strain_data.flat[: len(frame["strain_prev"])] = frame["strain_prev"]
        # Dynamic scaling for strain prevalence
        max_prev = max(0.001, np.max(frame["strain_prev"]))
        im_strain.set_data(strain_data)
        im_strain.set_clim(0, max_prev)

        # Update allele susceptibility
        im_susc.set_data(frame["allele_susc"].reshape(2, n_sites))

        # Update immunity distribution
        im_immunity.set_data(frame["immunity_dist"])

        # Update step counter
        step_text.set_text(f"Step: {frame['step']}")

        return [im_status, im_strain, im_susc, im_immunity, step_text]

    # Create animation
    anim = animation.FuncAnimation(
        fig,
        update,
        frames=len(frames),
        interval=1000 // fps,
        blit=False,
    )

    # Save animation
    writer = animation.FFMpegWriter(fps=fps, bitrate=1800)
    anim.save(output_path, writer=writer)
    plt.close(fig)

    return output_path


In [6]:
# Write the MP4 under teeplots/<notebook_name>/, creating the folder first.
notebook_name = "2025-02-04-abm-animation"
output_dir = f"teeplots/{notebook_name}"
output_path = f"{output_dir}/combined-abm-animation.mp4"
os.makedirs(output_dir, exist_ok=True)
create_animation(frames=frames, n_sites=N_SITES, output_path=output_path, fps=15)
print(f"Animation saved to: {output_path}")


Animation saved to: teeplots/2025-02-04-abm-animation/combined-abm-animation.mp4


## Display Animation


In [7]:
# Show the saved animation inline; embed=True inlines the file as a data URI.
Video(output_path, width=800, embed=True)
