# IBM microbiome simulation demo (standalone, non-reservoir)

This notebook demonstrates the individual-based microbiome (IBM) model used in `computingMicrobiome`, **only as a dynamical system**, not yet as a reservoir.

We will:
- Build and inspect small IBM grids (`GridState`).
- Run simulations over time using the low-level `tick` function.
- Visualize species and resource dynamics (time series and simple animations).
- Use interactive widgets to explore how changing parameters (number of species, resources, diffusion, dilution, etc.) affects the dynamics.

The notebook is designed to run in **Google Colab** from a fresh runtime.

## Environment setup (Colab)

If you are running this in **Google Colab** from a fresh runtime, run the following cell first to clone the repository and install the package.

> Note: If you are running locally inside the `computingMicrobiome` repo with the package already installed, you can skip this cell.

In [None]:
# Uncomment this cell if running in Google Colab from a fresh runtime.
# It will clone the repo and install the package (including extras).

# !git clone https://github.com/danielriosgarza/computingMicrobiome.git
# %cd computingMicrobiome
# !uv pip install .[all]

## Imports and basic setup

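We import the IBM core components (`GridState`, `EnvParams`, `SpeciesParams`, `load_params`, `make_zero_state`, `tick`), the configuration/initialization helpers (`make_ibm_config_from_species`, `make_center_column_state`, `CROSS_FEED_6_SPECIES`), and the plotting / widget libraries.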

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
from numpy.random import default_rng

import ipywidgets as widgets
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

from computingMicrobiome.ibm import (
    CROSS_FEED_6_SPECIES,
    EnvParams,
    SpeciesParams,
    GridState,
    load_params,
    make_ibm_config_from_species,
    make_center_column_state,
)
from computingMicrobiome.ibm.state import make_zero_state
from computingMicrobiome.ibm.stepper import tick

plt.style.use("seaborn-v0_8-darkgrid")
# Seeded notebook-wide generator (used by the first example cell below).
rng = default_rng(seed=0)

In [None]:
# Shared Plotly styles for IBM grids

# Colour used for empty cells (occ == -1) in occupancy plots.
OCC_EMPTY_COLOR = "#f5f5f5"
# Qualitative palette used for the species traces (s0, s1, ...).
OCC_SPECIES_COLORS = px.colors.qualitative.Set2
# Continuous Plotly colorscale name used for energy and resource heatmaps.
ENERGY_COLORSCALE = "Viridis"


def get_species_colors(n_species: int) -> list[str]:
    """Return occupancy colours: index 0 is "empty", index ``k + 1`` is species ``k``.

    Species colours cycle through ``OCC_SPECIES_COLORS`` when there are more
    species than palette entries. The empty colour is never reused for a
    species, so an occupied cell always stays distinguishable from an empty
    one. The list has at least ``n_species + 1`` entries; for small
    ``n_species`` it is ``[OCC_EMPTY_COLOR] + OCC_SPECIES_COLORS``.
    """
    palette = list(OCC_SPECIES_COLORS)
    if not palette:
        # Degenerate palette: nothing to cycle, fall back to the empty colour.
        return [OCC_EMPTY_COLOR] * (max(int(n_species), 0) + 1)
    # Cycle only the species palette; the previous version cycled the
    # combined [empty] + palette list, so every len(palette)+1-th species got
    # the empty colour.
    n_slots = max(int(n_species), len(palette))
    return [OCC_EMPTY_COLOR] + [palette[k % len(palette)] for k in range(n_slots)]


def plot_occupancy_and_energy(
    env: EnvParams,
    state: GridState,
    title_suffix: str = "",
    *,
    left_source_species: np.ndarray | None = None,
) -> None:
    """Show a two-panel Plotly figure: species occupancy (left), energy (right).

    Occupancy is drawn as one single-colour heatmap trace per value
    (``-1`` = empty, ``0 .. n_species-1`` = species). Each trace has its own
    legend entry, so a species can be hidden by clicking it in the legend or
    isolated by double-clicking.

    Parameters
    ----------
    env : EnvParams
        Only ``env.n_species`` is read.
    state : GridState
        ``state.occ`` and ``state.E`` are plotted; the state is not modified.
    title_suffix : str
        Appended to both subplot titles.
    left_source_species : array or None, keyword-only
        When given, column 0 is shaded and labelled as the fixed left source
        column. Only its presence matters; the values are not read.
    """
    n_species = env.n_species
    H, W = state.occ.shape
    colors = get_species_colors(n_species)
    xs, ys = list(range(W)), list(range(H))

    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=(f"Species occupancy{title_suffix}", f"Energy per cell{title_suffix}"),
        horizontal_spacing=0.08,
    )

    for idx, val in enumerate([-1] + list(range(n_species))):
        mask = np.where(state.occ == val, 1.0, np.nan)
        label = "empty" if val == -1 else f"s{val}"
        fig.add_trace(
            go.Heatmap(
                z=mask,
                x=xs,
                y=ys,
                colorscale=[[0.0, colors[idx]], [1.0, colors[idx]]],
                showscale=False,
                # Heatmap traces are hidden from the legend by default; show
                # them so the legend click/double-click toggling works.
                showlegend=True,
                name=label,
                hoverinfo="skip",
                opacity=0.95,
                xgap=1,
                ygap=1,
            ),
            row=1,
            col=1,
        )

    fig.add_trace(
        go.Heatmap(
            z=state.E,
            x=xs,
            y=ys,
            colorscale=ENERGY_COLORSCALE,
            colorbar=dict(title="energy"),
            showscale=True,
            name="energy",
            xgap=1,
            ygap=1,
            showlegend=False,
        ),
        row=1,
        col=2,
    )

    # Square cells, row 0 at the top, one tick per column/row on both panels.
    for c in (1, 2):
        fig.update_xaxes(
            title_text="column",
            row=1,
            col=c,
            dtick=1,
            range=[-0.5, W - 0.5],
            showgrid=True,
            gridcolor="white",
            gridwidth=1,
        )
        fig.update_yaxes(
            title_text="row",
            row=1,
            col=c,
            dtick=1,
            range=[H - 0.5, -0.5],
            showgrid=True,
            gridcolor="white",
            gridwidth=1,
        )

    if left_source_species is not None:
        for c in (1, 2):
            fig.add_vrect(
                x0=-0.5,
                x1=0.5,
                row=1,
                col=c,
                fillcolor="rgba(255, 140, 0, 0.12)",
                line_color="darkorange",
                line_width=1,
            )
        fig.add_annotation(
            x=0.01,
            y=1.08,
            xref="paper",
            yref="paper",
            text="left source column (fixed, migration-only)",
            showarrow=False,
            font=dict(color="darkorange"),
        )

    fig.update_layout(
        height=500,
        width=1000,
        legend_title_text="species",
        legend=dict(
            orientation="h",
            yanchor="top",
            y=-0.1,
            xanchor="left",
            x=0.0,
        ),
        template="plotly_white",
        legend_itemclick="toggle",
        legend_itemdoubleclick="toggleothers",
    )
    fig.show()


In [None]:
def plot_resource_grid(env: EnvParams, state: GridState, max_resources: int = 3) -> None:
    """Plot the first few resource fields ``R[j]`` as side-by-side heatmaps (Plotly).

    Shows planes ``R[0] .. R[m-1]`` with ``m = min(max_resources, M)``, using
    the same square-grid styling as the occupancy/energy plot. All panels
    share one colour range, ``[0, max over the shown planes]``, so the single
    colour bar drawn next to the last panel is valid for every panel.

    Parameters
    ----------
    env : EnvParams
        Not read; kept for signature parity with the other plot helpers.
    state : GridState
        ``state.occ`` (for the grid shape) and ``state.R`` are read.
    max_resources : int
        Maximum number of resource planes to draw; must be >= 1.

    Raises
    ------
    ValueError
        If the state has no resource planes, or ``max_resources < 1``.
    """
    H, W = state.occ.shape
    M = state.R.shape[0]
    if M == 0:
        raise ValueError("No resources to plot (env.n_resources == 0)")
    if max_resources < 1:
        raise ValueError("max_resources must be >= 1")
    m = min(max_resources, M)

    # Shared colour range: otherwise each heatmap autoscales on its own while
    # only the last panel carries a colour bar, which misrepresents the rest.
    zmax = max(float(state.R[:m].max()), 1.0)
    xs, ys = list(range(W)), list(range(H))

    fig = make_subplots(
        rows=1,
        cols=m,
        subplot_titles=[f"R{j}" for j in range(m)],
        horizontal_spacing=0.08,
    )

    for j in range(m):
        is_last = j == m - 1
        fig.add_trace(
            go.Heatmap(
                z=state.R[j],
                x=xs,
                y=ys,
                zmin=0,
                zmax=zmax,
                colorscale=ENERGY_COLORSCALE,
                showscale=is_last,
                colorbar=dict(title="amount") if is_last else None,
                xgap=1,
                ygap=1,
            ),
            row=1,
            col=j + 1,
        )

    for c in range(1, m + 1):
        fig.update_xaxes(
            title_text="column",
            row=1,
            col=c,
            dtick=1,
            range=[-0.5, W - 0.5],
            showgrid=True,
            gridcolor="white",
            gridwidth=1,
        )
        fig.update_yaxes(
            title_text="row",
            row=1,
            col=c,
            dtick=1,
            range=[H - 0.5, -0.5],
            showgrid=True,
            gridcolor="white",
            gridwidth=1,
        )

    fig.update_layout(
        height=350,
        width=350 * m,
        template="plotly_white",
    )
    fig.show()

## IBM state initialization helpers

We define helper functions to:
- Build environment and species parameters (`EnvParams`, `SpeciesParams`) for a small grid.
- Initialize a `GridState` in different ways (empty, basal pattern, random).

These mirror the logic used in the IBM reservoir backend but keep everything focused on the raw IBM dynamics.

In [None]:
def build_env_species(
    config_overrides: dict | None = None,
    species_indices: list[int] | None = None,
) -> tuple[EnvParams, SpeciesParams]:
    """Build ``(EnvParams, SpeciesParams)`` for a small IBM lattice.

    A base configuration is produced by ``make_ibm_config_from_species`` for
    a grid with ``height=16`` and ``width_grid=32``, using the requested
    species from the IBM universe (species 0-3 when ``species_indices`` is
    None). Any keys in ``config_overrides`` then replace the matching base
    entries before the configuration is passed to ``load_params``.
    """
    chosen = [0, 1, 2, 3] if species_indices is None else species_indices
    cfg = make_ibm_config_from_species(
        species_indices=chosen,
        height=16,
        width_grid=32,
    )
    if config_overrides is not None:
        cfg.update(config_overrides)
    env_params, species_params = load_params(cfg)
    return env_params, species_params


def init_state(env: EnvParams, mode: str = "basal", rng: np.random.Generator | None = None) -> GridState:
    """Initialize a ``GridState`` for the given environment.

    Modes
    -----
    "empty"
        Delegates to ``make_zero_state`` with ``env``'s grid size and
        resource count.
    "basal"
        Deterministic pattern driven by the ``basal_*`` settings:
        - Species ids are laid out as horizontal stripes
          (``basal_pattern == "stripes"``) or as diagonals (any other value).
        - A hash-based occupancy mask has density ``basal_occupancy``.
        - Occupied cells get ``basal_energy``, clipped to ``[0, Emax]``.
        - Resources come from ``basal_resource_vec`` when present, otherwise
          from the scalar ``basal_resource``, clipped to ``[0, Rmax]``.
    "random"
        - Each cell is occupied with probability 0.5 by a uniformly random
          species.
        - Energies are uniform in ``[1, Emax]``.
        - Resources are uniform in ``[0, Rmax]``.

    Parameters
    ----------
    env : EnvParams
        Grid size, species/resource counts and basal settings.
    mode : str
        One of ``"empty"``, ``"basal"``, ``"random"``.
    rng : numpy Generator, optional
        Only consumed by ``"random"``; a fresh unseeded generator if omitted.

    Raises
    ------
    ValueError
        For an unknown ``mode``, or when ``"basal"``/``"random"`` is
        requested with ``env.n_species < 1``.
    """
    if rng is None:
        rng = default_rng()

    H, W, M = env.height, env.width_grid, env.n_resources
    shape = (H, W)

    if mode == "empty":
        return make_zero_state(height=H, width_grid=W, n_resources=M)

    if mode not in ("basal", "random"):
        raise ValueError("mode must be one of {'empty', 'basal', 'random'}")
    if env.n_species < 1:
        # Without this, "basal" would silently place species 0 (numpy % 0 -> 0).
        raise ValueError(f"mode={mode!r} requires env.n_species >= 1")

    if mode == "basal":
        # int64 indices: the hash below overflows int32, which is the default
        # integer on some platforms.
        rr, cc = np.indices(shape, dtype=np.int64)
        if env.basal_pattern == "stripes":
            sid = (rr % env.n_species).astype(np.int16)
        else:
            sid = ((rr + cc) % env.n_species).astype(np.int16)

        if env.basal_occupancy >= 1.0:
            occupied = np.ones(shape, dtype=bool)
        elif env.basal_occupancy <= 0.0:
            occupied = np.zeros(shape, dtype=bool)
        else:
            # Deterministic hash-based mask so runs are reproducible without an RNG.
            key = (rr * 73856093 + cc * 19349663) % 1000
            occupied = key < int(env.basal_occupancy * 1000.0)

        occ = np.full(shape, -1, dtype=np.int16)
        occ[occupied] = sid[occupied]

        # Clip before the uint8 cast so energies respect Emax (as in "random"
        # mode) and never wrap around.
        E = np.zeros(shape, dtype=np.uint8)
        e0 = int(np.clip(env.basal_energy, 0, env.Emax))
        if e0 > 0:
            E[occupied] = e0

        # Per-resource basal levels if available, otherwise the scalar level;
        # both clipped to [0, Rmax] before the uint8 cast.
        br_vec = getattr(env, "basal_resource_vec", None)
        if br_vec is not None:
            levels = np.asarray(br_vec, dtype=np.float64).reshape(M)
        else:
            levels = np.full(M, float(env.basal_resource), dtype=np.float64)
        levels = np.clip(levels, 0, env.Rmax).astype(np.uint8)
        R = np.broadcast_to(levels[:, None, None], (M, H, W)).copy()
        return GridState(occ=occ, E=E, R=R)

    # mode == "random"
    occ = np.full(shape, -1, dtype=np.int16)
    occupied = rng.random(shape) < 0.5
    sid = rng.integers(0, env.n_species, size=shape, dtype=np.int16)
    occ[occupied] = sid[occupied]

    E = np.zeros(shape, dtype=np.uint8)
    n_occupied = int(np.count_nonzero(occupied))
    if n_occupied:
        E[occupied] = rng.integers(
            1, env.Emax + 1, size=n_occupied, dtype=np.uint16
        ).astype(np.uint8)

    R = rng.integers(0, env.Rmax + 1, size=(M, H, W), dtype=np.uint16).astype(np.uint8)
    return GridState(occ=occ, E=E, R=R)

### Initial condition used in the examples

In the examples below we initialize the lattice with:

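- **Energy around 5 units per occupied cell**: seeded cells draw their energies around `energy_mean=5.0` (via `make_center_column_state`), and the config's `basal_energy` is set to a similar level (5, or 4 in the animation example).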
- **A single central column seeded with multiple species** (cycling through a handful of species IDs down the column), so interactions are easier to see than with just one species.
- **Resource fields initialized from the per-resource basal vector** implied by the IBM universe, and visualized with a dedicated `plot_resource_grid` helper that shows a few representative resources.

This gives a compact but non-trivial starting state that is reused consistently for the static snapshot, the basic run, and the spatial animation.

## Inspect a small IBM state

We now create a small grid, initialize it, and look at the arrays that represent:
- `occ`: occupancy (species id or -1 for empty)
- `E`: per-cell energy
- `R`: per-resource field over the grid

In [None]:
env, species = build_env_species({"height": 8, "width_grid": 16, "basal_energy": 5})
state = make_center_column_state(env, species_id=0, energy_mean=5.0, rng=rng)

print("Grid shape (H, W):", state.occ.shape)
print("Energy shape (H, W):", state.E.shape)
print("Resources shape (M, H, W):", state.R.shape)

# Richer start: cycle several species ids down the central column so
# interactions are visible, not just a single species.
H, W = state.occ.shape
center = W // 2
max_species_band = min(env.n_species, 5)
state.occ[:, center] = [row % max_species_band for row in range(H)]

# Occupancy + energy, then a few resource planes.
plot_occupancy_and_energy(env, state, title_suffix=" (initial)")
plot_resource_grid(env, state, max_resources=3)

## Running the IBM over time

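We define a helper that runs the IBM for a given number of ticks, optionally with a fixed *left source column* whose species can colonize or outcompete cells in column 1. It records these summary observables: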
- Species counts over time.
- Total resources per resource type over time.

In [None]:
def _prepare_left_source_config(
    env: EnvParams,
    species: SpeciesParams,
    left_source_species=None,
    left_source_competition=None,
    left_source_settle_energy=None,
):
    """Validate and normalise the optional fixed left-source-column settings.

    Returns ``(source, competition, settle_energy)``, or
    ``(None, None, None)`` when ``left_source_species`` is None (no source
    column).

    - ``source``: int16 array of length ``env.height``. Entry ``r`` is the
      species pinned in row ``r`` of column 0, or ``-1`` for none.
    - ``competition``: int32 per-species score used when a source species
      tries to displace a resident. Defaults to
      ``max(yield_energy + birth_energy + uptake_rate - maint_cost, 0)``.
    - ``settle_energy``: int32 per-species energy given to a cell that a
      source species takes. Defaults to ``species.birth_energy`` and is
      clipped to ``[0, env.Emax]``.

    ``left_source_competition`` and ``left_source_settle_energy`` may each be
    a scalar (including a 0-d numpy array) or a length-``n_species`` sequence.

    Raises
    ------
    ValueError
        For a wrong-length or out-of-range ``left_source_species``, a
        wrong-length per-species argument, or negative scores/energies.
    """
    if left_source_species is None:
        return None, None, None

    source = np.asarray(left_source_species, dtype=np.int16).reshape(-1)
    if source.size != env.height:
        raise ValueError("left_source_species must have length env.height")
    if np.any((source < -1) | (source >= env.n_species)):
        raise ValueError("left_source_species values must be -1 or in [0, n_species)")

    def per_species(value, name: str) -> np.ndarray:
        # Broadcast a scalar, or validate a per-species vector. np.ndim == 0
        # also accepts 0-d arrays, which np.isscalar rejects.
        if np.ndim(value) == 0:
            return np.full(env.n_species, int(value), dtype=np.int32)
        arr = np.asarray(value, dtype=np.int32).reshape(-1)
        if arr.size != env.n_species:
            raise ValueError(f"{name} must be scalar or length n_species")
        return arr

    if left_source_competition is None:
        # int32 so the subtraction cannot wrap for unsigned species arrays.
        base = (
            species.yield_energy.astype(np.int32)
            + species.birth_energy.astype(np.int32)
            + species.uptake_rate.astype(np.int32)
            - species.maint_cost.astype(np.int32)
        )
        comp = np.maximum(base, 0)
    else:
        comp = per_species(left_source_competition, "left_source_competition")
    if np.any(comp < 0):
        raise ValueError("left_source_competition must be >= 0")

    if left_source_settle_energy is None:
        settle = species.birth_energy.astype(np.int32)
    else:
        settle = per_species(left_source_settle_energy, "left_source_settle_energy")
    if np.any(settle < 0):
        raise ValueError("left_source_settle_energy must be >= 0")
    settle = np.clip(settle, 0, int(env.Emax)).astype(np.int32)

    return source, comp.astype(np.int32), settle


def _enforce_left_source_column(state: GridState, source: np.ndarray) -> None:
    """Pin column 0 to ``source`` in place: set its species, zero its energy and every resource."""
    col = 0
    state.R[:, :, col] = 0
    state.E[:, col] = 0
    state.occ[:, col] = source


def _apply_left_source_migration(
    state: GridState,
    env: EnvParams,
    source: np.ndarray,
    source_competition: np.ndarray,
    source_settle_energy: np.ndarray,
    *,
    outcompete_margin: int = 0,
    colonize_empty: bool = True,
) -> None:
    """Let the fixed source column (column 0) seed or invade column 1, in place.

    Steps, in this order:

    1. Re-pin column 0 to ``source``, with zero energy and resources there.
    2. If ``colonize_empty``, every empty column-1 cell in a row with
       ``source[r] >= 0`` takes that source species, with energy
       ``source_settle_energy[species]``.
    3. Every column-1 cell held by a *different* species is taken over when
       ``source_competition[src] >= source_competition[resident] + outcompete_margin``.
       With margin 0 the source therefore wins ties. The taken cell's energy
       is reset to the settle energy.
    4. Energy of every empty cell in the grid is zeroed. All energies are
       clipped to ``[0, env.Emax]`` and stored as uint8, and column 0 is
       re-pinned.

    Cells colonised in step 2 already hold the source species, so step 3
    skips them.

    Parameters
    ----------
    state : GridState
        Modified in place (``occ`` entries; ``E`` is replaced, see note below).
    env : EnvParams
        Only ``env.Emax`` is read.
    source : int array of length H
        Species per row of the source column, ``-1`` for no source.
    source_competition, source_settle_energy : int arrays of length n_species
        As produced by ``_prepare_left_source_config``.
    outcompete_margin : int
        Extra score a source species needs over the resident to displace it.
    colonize_empty : bool
        Whether step 2 runs.
    """
    _enforce_left_source_column(state, source)

    # Nothing to migrate into when the grid has no column 1.
    if state.occ.shape[1] < 2:
        return

    tgt_col = 1
    occ = state.occ
    # int32 working copy; clipped to [0, Emax] and converted back to uint8 at the end.
    E_work = state.E.astype(np.int32, copy=True)

    # Step 2: source species settle into empty neighbouring cells.
    if colonize_empty:
        empty = occ[:, tgt_col] < 0
        can_seed = (source >= 0) & empty
        if np.any(can_seed):
            rows = np.where(can_seed)[0]
            src_species = source[rows]
            occ[rows, tgt_col] = src_species
            E_work[rows, tgt_col] = source_settle_energy[src_species]

    # Step 3: residents are read *after* colonisation, so cells seeded above
    # (same species as the source) are excluded here.
    tgt_species = occ[:, tgt_col]
    can_compete = (source >= 0) & (tgt_species >= 0) & (source != tgt_species)
    if np.any(can_compete):
        rows = np.where(can_compete)[0]
        src = source[rows]
        tgt = tgt_species[rows]
        src_score = source_competition[src]
        tgt_score = source_competition[tgt]
        # `>=`: with outcompete_margin == 0 the source wins ties.
        wins = src_score >= (tgt_score + int(outcompete_margin))
        if np.any(wins):
            rows_win = rows[wins]
            src_win = source[rows_win]
            occ[rows_win, tgt_col] = src_win
            E_work[rows_win, tgt_col] = source_settle_energy[src_win]

    # Step 4: empty cells carry no energy (applied to the whole grid, not only column 1).
    E_work[occ < 0] = 0
    # NOTE(review): this rebinds state.E to a new uint8 array rather than
    # writing in place; any external reference to the old array goes stale.
    state.E = np.clip(E_work, 0, int(env.Emax)).astype(np.uint8)
    _enforce_left_source_column(state, source)


def run_simulation(
    env: EnvParams,
    species: SpeciesParams,
    state: GridState,
    *,
    n_steps: int = 200,
    rng: np.random.Generator | None = None,
    left_source_species=None,
    left_source_competition=None,
    left_source_settle_energy=None,
    left_source_outcompete_margin: int = 0,
    left_source_colonize_empty: bool = True,
):
    """Advance ``state`` in place for ``n_steps`` ticks and record observables.

    Parameters
    ----------
    env, species : EnvParams, SpeciesParams
        Model parameters passed to ``tick``.
    state : GridState
        Modified in place; holds the final lattice on return.
    n_steps : int
        Number of ticks to run.
    rng : numpy Generator, optional
        Randomness for ``tick``; a fresh unseeded generator when omitted.
    left_source_species, left_source_competition, left_source_settle_energy
        Optional fixed left-source-column settings, validated by
        ``_prepare_left_source_config``. With ``left_source_species=None``
        the plain IBM is run.
    left_source_outcompete_margin, left_source_colonize_empty
        Forwarded to ``_apply_left_source_migration`` after every tick.

    Returns
    -------
    species_counts : int32 array, shape (n_steps+1, n_species)
        Cells occupied by each species at each time (row 0 = initial state).
        When a source column is used, its cells are included in the counts.
    resource_totals : int32 array, shape (n_steps+1, n_resources)
        Total amount of each resource over the grid at each time.
    """
    if rng is None:
        rng = default_rng()

    source, source_comp, source_settle = _prepare_left_source_config(
        env,
        species,
        left_source_species=left_source_species,
        left_source_competition=left_source_competition,
        left_source_settle_energy=left_source_settle_energy,
    )
    has_source = source is not None
    if has_source:
        _enforce_left_source_column(state, source)

    n_sp, n_res = env.n_species, env.n_resources
    species_counts = np.zeros((n_steps + 1, n_sp), dtype=np.int32)
    resource_totals = np.zeros((n_steps + 1, n_res), dtype=np.int32)

    def record(row: int) -> None:
        # Count species ids 0..n_sp-1; empty (-1) and masked (-2) cells are dropped.
        live = state.occ[state.occ >= 0]
        species_counts[row] = np.bincount(live, minlength=n_sp)[:n_sp]
        resource_totals[row] = state.R.reshape(n_res, -1).sum(axis=1)

    def advance() -> None:
        if not has_source:
            tick(state, env, species, rng)
            return
        # Mask the source column out of the tick: -2 occupancy, no energy, no
        # resources (presumably tick leaves -2 cells alone; confirm in stepper).
        state.occ[:, 0] = -2
        state.E[:, 0] = 0
        state.R[:, :, 0] = 0
        tick(state, env, species, rng)
        _apply_left_source_migration(
            state,
            env,
            source,
            source_comp,
            source_settle,
            outcompete_margin=left_source_outcompete_margin,
            colonize_empty=left_source_colonize_empty,
        )

    record(0)
    for step in range(1, n_steps + 1):
        advance()
        record(step)

    return species_counts, resource_totals


def plot_time_series(species_counts: np.ndarray, resource_totals: np.ndarray) -> None:
    """Show species counts (left) and resource totals (right) over time with Plotly.

    Both inputs are 2-D arrays indexed ``[time, item]``, as returned by
    ``run_simulation``; one line is drawn per column (``s<k>`` / ``R<k>``).
    """
    steps = np.arange(species_counts.shape[0])

    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Species counts", "Resource totals"),
        horizontal_spacing=0.12,
    )

    for k, series in enumerate(species_counts.T):
        fig.add_trace(
            go.Scatter(x=steps, y=series, mode="lines", name=f"s{k}"),
            row=1,
            col=1,
        )

    for k, series in enumerate(resource_totals.T):
        fig.add_trace(
            go.Scatter(x=steps, y=series, mode="lines", name=f"R{k}"),
            row=1,
            col=2,
        )

    for col, y_title in ((1, "# occupied cells"), (2, "total resource")):
        fig.update_xaxes(title_text="time step", row=1, col=col)
        fig.update_yaxes(title_text=y_title, row=1, col=col)

    fig.update_layout(
        height=400,
        width=900,
        template="plotly_white",
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0.0),
    )
    fig.show()


### Example: basic IBM run

We now run the IBM for 200 steps with a fixed left source column feeding species into column 1. We then plot the resulting trajectories and the final lattice state.

In [None]:
env, species = build_env_species(
    {"height": 8, "width_grid": 32, "basal_energy": 5, "dilution_p": 0.1}
)
state = make_center_column_state(env, species_id=0, energy_mean=5.0, rng=default_rng(1))

# Same enriched start as the snapshot: several species down the central column.
H, W = state.occ.shape
center = W // 2
max_species_band = min(env.n_species, 5)
state.occ[:, center] = [row % max_species_band for row in range(H)]

# Fixed left source column whose rows cycle through the same species band.
# These species sit outside the local dynamics but may enter column 1,
# by colonizing empty cells or by outcompeting the resident.
left_source_species = (np.arange(H) % max_species_band).astype(np.int16)

species_counts, resource_totals = run_simulation(
    env,
    species,
    state,
    n_steps=200,
    rng=default_rng(2),
    left_source_species=left_source_species,
    left_source_outcompete_margin=1,
    left_source_colonize_empty=True,
)
plot_time_series(species_counts, resource_totals)

# Final snapshot with the source column highlighted.
plot_occupancy_and_energy(
    env,
    state,
    title_suffix=" (final, with left-source migration)",
    left_source_species=left_source_species,
)


## Animation of spatial dynamics

Next we build a simple animation that shows how occupancy and energy fields evolve over time on the lattice.

In [None]:
def make_spatial_animation_plotly(
    env: EnvParams,
    species: SpeciesParams,
    *,
    n_steps: int = 100,
    seed: int = 0,
    left_source_species=None,
    left_source_competition=None,
    left_source_settle_energy=None,
    left_source_outcompete_margin: int = 0,
    left_source_colonize_empty: bool = True,
):
    """Simulate the IBM and return a Plotly animation of occupancy and energy.

    The run starts from ``make_center_column_state`` (species 0, energy mean
    5). The central column is then re-seeded with species ids cycling through
    ``0 .. min(n_species, 5) - 1``, and the model advances ``n_steps`` ticks
    with an RNG seeded by ``seed``. The optional ``left_source_*`` arguments
    add the same fixed left source column used by ``run_simulation``.

    The figure matches ``plot_occupancy_and_energy`` in colours and layout:

    - one occupancy trace per value (empty and each species, all listed in
      the legend, so they can be toggled);
    - one energy heatmap whose colour range is fixed to ``[0, Emax]``, so it
      does not jump between frames.

    Play/Pause buttons drive the ``n_steps + 1`` frames.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    rng_local = default_rng(seed)
    state = make_center_column_state(env, species_id=0, energy_mean=5.0, rng=rng_local)

    # Enrich the initial condition: multi-species central column.
    H, W = env.height, env.width_grid
    max_species_band = min(env.n_species, 5)
    center = W // 2
    for r in range(H):
        state.occ[r, center] = r % max_species_band

    source, source_comp, source_settle = _prepare_left_source_config(
        env,
        species,
        left_source_species=left_source_species,
        left_source_competition=left_source_competition,
        left_source_settle_energy=left_source_settle_energy,
    )
    if source is not None:
        _enforce_left_source_column(state, source)

    n_species = env.n_species
    colors = get_species_colors(n_species)
    xs, ys = list(range(W)), list(range(H))

    # Pre-compute frames
    occ_frames = np.zeros((n_steps + 1, H, W), dtype=np.int16)
    E_frames = np.zeros((n_steps + 1, H, W), dtype=np.uint8)

    def snapshot(t_idx: int) -> None:
        occ_frames[t_idx] = state.occ
        E_frames[t_idx] = state.E

    snapshot(0)
    for t in range(1, n_steps + 1):
        if source is None:
            tick(state, env, species, rng_local)
        else:
            # Mask the source column out of the tick, then re-apply migration.
            state.occ[:, 0] = -2
            state.E[:, 0] = 0
            state.R[:, :, 0] = 0
            tick(state, env, species, rng_local)
            _apply_left_source_migration(
                state,
                env,
                source,
                source_comp,
                source_settle,
                outcompete_margin=left_source_outcompete_margin,
                colonize_empty=left_source_colonize_empty,
            )
        snapshot(t)

    # Base figure for frame 0
    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Species occupancy", "Energy per cell"),
        horizontal_spacing=0.08,
    )

    # Occupancy traces (one per species + empty) at t=0
    for idx, val in enumerate([-1] + list(range(n_species))):
        mask0 = np.where(occ_frames[0] == val, 1.0, np.nan)
        label = "empty" if val == -1 else f"s{val}"
        fig.add_trace(
            go.Heatmap(
                z=mask0,
                x=xs,
                y=ys,
                colorscale=[[0.0, colors[idx]], [1.0, colors[idx]]],
                showscale=False,
                # Heatmaps are hidden from the legend by default; list them
                # so the "species" legend is populated and toggleable.
                showlegend=True,
                name=label,
                hoverinfo="skip",
                opacity=0.95,
                xgap=1,
                ygap=1,
            ),
            row=1,
            col=1,
        )

    # Energy trace at t=0
    # We fix zmin/zmax so the colorscale doesn't jump during animation
    fig.add_trace(
        go.Heatmap(
            z=E_frames[0],
            x=xs,
            y=ys,
            zmin=0,
            zmax=getattr(env, "Emax", 20),
            colorscale=ENERGY_COLORSCALE,
            colorbar=dict(title="energy"),
            showscale=True,
            name="energy",
            xgap=1,
            ygap=1,
            showlegend=False,
        ),
        row=1,
        col=2,
    )

    # Axes styling
    for c in (1, 2):
        fig.update_xaxes(
            title_text="column",
            row=1,
            col=c,
            dtick=1,
            range=[-0.5, W - 0.5],
            showgrid=True,
            gridcolor="white",
            gridwidth=1,
        )
        fig.update_yaxes(
            title_text="row",
            row=1,
            col=c,
            dtick=1,
            range=[H - 0.5, -0.5],
            showgrid=True,
            gridcolor="white",
            gridwidth=1,
        )

    if source is not None:
        for c in (1, 2):
            fig.add_vrect(
                x0=-0.5,
                x1=0.5,
                row=1,
                col=c,
                fillcolor="rgba(255, 140, 0, 0.12)",
                line_color="darkorange",
                line_width=1,
            )
        fig.add_annotation(
            x=0.01,
            y=1.08,
            xref="paper",
            yref="paper",
            text="left source column (fixed, migration-only)",
            showarrow=False,
            font=dict(color="darkorange"),
        )

    # Build frames: each frame's data list updates traces 0..n positionally
    # (occupancy traces in the same order as above, then energy).
    frames = []
    for t_idx in range(n_steps + 1):
        frame_traces = []
        for val in [-1] + list(range(n_species)):
            mask = np.where(occ_frames[t_idx] == val, 1.0, np.nan)
            frame_traces.append(go.Heatmap(z=mask))
        frame_traces.append(go.Heatmap(z=E_frames[t_idx]))
        frames.append(go.Frame(data=frame_traces, name=str(t_idx)))

    fig.frames = frames

    fig.update_layout(
        height=500,
        width=1000,
        legend_title_text="species",
        legend=dict(
            orientation="h",
            yanchor="top",
            y=-0.1,
            xanchor="left",
            x=0.0,
        ),
        template="plotly_white",
        updatemenus=[
            dict(
                type="buttons",
                showactive=False,
                buttons=[
                    dict(
                        label="Play",
                        method="animate",
                        # args[0] should be None (not [None]) to play all frames
                        args=[None, {"frame": {"duration": 100, "redraw": True}, "fromcurrent": True}],
                    ),
                    dict(
                        label="Pause",
                        method="animate",
                        args=[[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate"}],
                    ),
                ],
                x=0.0,
                y=-0.16,
            )
        ],
    )

    return fig


In [None]:
# Example animation for a small grid (may take a couple of seconds to render).
env_anim, species_anim = build_env_species(
    {"height": 8, "width_grid": 32, "basal_energy": 4, "dilution_p": 0.05}
)

# Source rows cycle through the first (up to) 5 species ids.
left_source_species_anim = np.arange(env_anim.height) % min(env_anim.n_species, 5)
left_source_species_anim = left_source_species_anim.astype(np.int16)

fig_anim = make_spatial_animation_plotly(
    env_anim,
    species_anim,
    n_steps=60,
    seed=3,
    left_source_species=left_source_species_anim,
    left_source_outcompete_margin=1,
    left_source_colonize_empty=True,
)
fig_anim.show()


## Bit pulse injection protocol

To feed inputs to the IBM as a **reservoir**, we encode each input **bit** (0 or 1) as a spatial pulse:

- **Bit 0**: clear a square region (empty cells, zero all resources), then set **toxin** to a given concentration. Cells that later colonize the region may die if toxin exceeds their tolerance.
- **Bit 1**: same clear region, then set the **popular metabolite** to a given concentration. Species prefer this resource first (metabolic switch); it drives growth before they use their secondary uptake.

The pulse is defined by a **center cell** $(r_0, c_0)$, a **radius**, and the toxin or popular-metabolite concentration. The radius is a Chebyshev distance, so the pulse covers a square of side $2 \cdot \text{radius} + 1$. All resources in the square are zeroed first, so the pulse is clean and the metabolic switch behaves as intended.

In [None]:
def square_mask(height: int, width: int, center_r: int, center_c: int, radius: int) -> np.ndarray:
    """Return a ``(height, width)`` bool array marking cells within Chebyshev distance ``radius`` of the center.

    The marked region is the square ``|r - center_r| <= radius`` and
    ``|c - center_c| <= radius``, clipped to the grid. A negative radius
    marks nothing.
    """
    rows, cols = np.ogrid[:height, :width]
    row_ok = np.abs(rows - center_r) <= radius
    col_ok = np.abs(cols - center_c) <= radius
    return row_ok & col_ok


def get_popular_resource_index(species: SpeciesParams) -> int:
    """Return the smallest (compact) resource index listed in any species' ``popular_uptake_list``.

    Raises
    ------
    ValueError
        If no species lists a popular resource.
    """
    indices = {
        int(m)
        for per_species in species.popular_uptake_list
        for m in per_species.tolist()
    }
    if not indices:
        raise ValueError("No popular_uptake_list resources found; config may omit popular metabolite.")
    return min(indices)


def inject_bit_into_state(
    state: GridState,
    env: EnvParams,
    species: SpeciesParams,
    bit: int,
    center_r: int,
    center_c: int,
    radius: int,
    toxin_conc: int = 200,
    popular_conc: int = 200,
) -> None:
    """Apply a clean bit pulse to ``state`` in place.

    Every cell within Chebyshev distance ``radius`` of ``(center_r, center_c)``
    is emptied (``occ = -1``, ``E = 0``) and all of its resources are zeroed.
    Then, inside that square only:

    - ``bit == 0``: resource ``env.toxin_resource_index`` is set to ``toxin_conc``;
    - ``bit == 1``: the popular metabolite (smallest index in any species'
      ``popular_uptake_list``) is set to ``popular_conc``.

    Concentrations are clipped to ``[0, env.Rmax]``.

    Raises
    ------
    ValueError
        In any of these cases:
        - ``bit`` is not 0 or 1;
        - a bit-0 pulse is requested but ``env.toxin_resource_index`` is None;
        - a bit-1 pulse is requested but no popular resource is configured.

        All checks run before the state is touched, so a failed call leaves
        ``state`` unchanged.
    """
    if bit not in (0, 1):
        raise ValueError(f"bit must be 0 or 1, got {bit!r}")

    # Resolve the target resource first: previously a missing toxin index
    # raised only after the region had already been cleared, and a bit-1
    # pulse needlessly required a toxin index.
    if bit == 0:
        if env.toxin_resource_index is None:
            raise ValueError("Config has no toxin_resource_index; required for bit-0 pulse.")
        target, conc = env.toxin_resource_index, toxin_conc
    else:
        target, conc = get_popular_resource_index(species), popular_conc

    mask = square_mask(env.height, env.width_grid, center_r, center_c, radius)

    # Clear region: empty cells and zero all resources for a clean pulse.
    state.occ[mask] = -1
    state.E[mask] = 0
    state.R[:, mask] = 0

    state.R[target, mask] = np.clip(conc, 0, env.Rmax).astype(np.uint8)

## Animation with bit pulses

The next animation runs the same spatial dynamics but **injects a sequence of bits** at scheduled time steps. Each injection clears a square, zeros all resources there, then adds either toxin (bit 0) or the popular metabolite (bit 1). You can compare how toxin pulses vs nutrient pulses propagate and affect occupancy and energy.

In [None]:
def _build_pulse_schedule(
    bit_sequence: list[int],
    injection_ticks: list[int] | None,
    inject_every_n: int,
    n_steps: int,
) -> dict[int, int]:
    """Return ``{step: bit}`` for every injection whose step lies in ``[0, n_steps]``.

    ``bit_sequence[i % len(bit_sequence)]`` is paired with ``injection_ticks[i]``.
    Here ``i`` is the position in the caller's list, before out-of-range steps
    are dropped. If a step appears more than once, its first occurrence wins.
    When ``injection_ticks`` is None, the steps are ``0, n, 2n, ...`` with
    ``n = inject_every_n``, one per entry of ``bit_sequence``.

    Raises:
        ValueError: if an in-range injection is scheduled but ``bit_sequence`` is empty.
    """
    if injection_ticks is None:
        injection_ticks = [inject_every_n * k for k in range(len(bit_sequence))]
    n_bits = len(bit_sequence)
    schedule: dict[int, int] = {}
    for i, step in enumerate(injection_ticks):
        if not 0 <= step <= n_steps:
            continue
        if n_bits == 0:
            raise ValueError("bit_sequence must be non-empty when injections are scheduled.")
        schedule.setdefault(step, int(bit_sequence[i % n_bits]))
    return schedule


def _pulse_animation_figure(
    occ_frames: np.ndarray,
    E_frames: np.ndarray,
    env: EnvParams,
    *,
    show_source_column: bool,
):
    """Build the two-panel (occupancy | energy) animated Plotly figure from recorded frames."""
    n_frames, H, W = occ_frames.shape
    n_species = env.n_species
    colors = get_species_colors(n_species)
    # One single-colour heatmap layer per occupancy value (-1 = empty, then species ids).
    occ_values = [-1] + list(range(n_species))

    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=("Species occupancy (bit pulses)", "Energy per cell (bit pulses)"),
        horizontal_spacing=0.08,
    )
    for idx, val in enumerate(occ_values):
        mask0 = np.where(occ_frames[0] == val, 1.0, np.nan)
        label = "empty" if val == -1 else f"s{val}"
        fig.add_trace(
            go.Heatmap(
                z=mask0, x=list(range(W)), y=list(range(H)),
                colorscale=[[0.0, colors[idx]], [1.0, colors[idx]]],
                showscale=False, name=label, hoverinfo="skip", opacity=0.95, xgap=1, ygap=1,
            ),
            row=1, col=1,
        )
    fig.add_trace(
        go.Heatmap(
            z=E_frames[0], x=list(range(W)), y=list(range(H)),
            zmin=0, zmax=getattr(env, "Emax", 255), colorscale=ENERGY_COLORSCALE,
            colorbar=dict(title="energy"), showscale=True, name="energy", xgap=1, ygap=1, showlegend=False,
        ),
        row=1, col=2,
    )
    for c in (1, 2):
        fig.update_xaxes(title_text="column", row=1, col=c, dtick=1, range=[-0.5, W - 0.5], showgrid=True, gridcolor="white", gridwidth=1)
        fig.update_yaxes(title_text="row", row=1, col=c, dtick=1, range=[H - 0.5, -0.5], showgrid=True, gridcolor="white", gridwidth=1)
    if show_source_column:
        for c in (1, 2):
            fig.add_vrect(x0=-0.5, x1=0.5, row=1, col=c, fillcolor="rgba(255, 140, 0, 0.12)", line_color="darkorange", line_width=1)

    frames = []
    for t_idx in range(n_frames):
        # Frame traces only update z; they map positionally onto the traces added above.
        frame_traces = [go.Heatmap(z=np.where(occ_frames[t_idx] == val, 1.0, np.nan)) for val in occ_values]
        frame_traces.append(go.Heatmap(z=E_frames[t_idx]))
        frames.append(go.Frame(data=frame_traces, name=str(t_idx)))
    fig.frames = frames
    fig.update_layout(
        height=500, width=1000, legend_title_text="species",
        legend=dict(orientation="h", yanchor="top", y=-0.1, xanchor="left", x=0.0),
        template="plotly_white",
        updatemenus=[
            dict(
                type="buttons", showactive=False,
                buttons=[
                    dict(label="Play", method="animate", args=[None, {"frame": {"duration": 100, "redraw": True}, "fromcurrent": True}]),
                    dict(label="Pause", method="animate", args=[[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate"}]),
                ],
                x=0.0, y=-0.16,
            )
        ],
    )
    return fig


def make_pulse_animation_plotly(
    env: EnvParams,
    species: SpeciesParams,
    *,
    bit_sequence: list[int],
    injection_ticks: list[int] | None = None,
    inject_every_n: int = 25,
    center_r: int | None = None,
    center_c: int | None = None,
    radius: int = 2,
    toxin_conc: int = 200,
    popular_conc: int = 200,
    n_steps: int = 100,
    seed: int = 0,
    left_source_species=None,
    left_source_competition=None,
    left_source_settle_energy=None,
    left_source_outcompete_margin: int = 0,
    left_source_colonize_empty: bool = True,
):
    """Create a Plotly animation of IBM dynamics with bit pulses (toxin vs popular metabolite).

    Each step ``t`` in ``0..n_steps`` follows the cadence inject -> record -> tick:

    1. Any pulse scheduled for step ``t`` is applied.
    2. Frame ``t`` is recorded.
    3. For ``t < n_steps``, the IBM is advanced one tick.

    Frame ``t`` therefore shows the state after ``t`` ticks, including the pulse
    injected at step ``t``. In particular, a pulse at step 0 is applied.

    Args:
        env, species: IBM environment and species parameters.
        bit_sequence: bits to inject. ``bit_sequence[i]`` is injected at
            ``injection_ticks[i]``; the sequence cycles if it is shorter.
        injection_ticks: step indices at which to inject. If None, uses
            ``0, inject_every_n, 2 * inject_every_n, ...`` (one per bit).
            Steps outside ``[0, n_steps]`` are ignored.
        inject_every_n: spacing of the default schedule.
        center_r, center_c: centre of the pulse square; default to the grid centre.
        radius: Chebyshev radius of the pulse square.
        toxin_conc, popular_conc: concentrations written for bit 0 / bit 1.
        n_steps: number of ticks to simulate (``n_steps + 1`` frames are recorded).
        seed: seed for the simulation's random generator.
        left_source_*: optional fixed migrant source in column 0, forwarded to
            ``_prepare_left_source_config`` / ``_apply_left_source_migration``.

    Returns:
        A Plotly figure with occupancy and energy heatmaps and one animation
        frame per recorded step.

    Raises:
        ValueError: if injections are scheduled but ``bit_sequence`` is empty
            (plus anything raised by ``inject_bit_into_state``).
    """
    H, W = env.height, env.width_grid
    if center_r is None:
        center_r = H // 2
    if center_c is None:
        center_c = W // 2

    schedule = _build_pulse_schedule(bit_sequence, injection_ticks, inject_every_n, n_steps)

    rng_local = default_rng(seed)
    state = make_center_column_state(env, species_id=0, energy_mean=5.0, rng=rng_local)
    # Re-seed the central column with a band cycling over up to 5 species ids.
    band_col = W // 2
    max_species_band = min(env.n_species, 5)
    for r in range(H):
        state.occ[r, band_col] = r % max_species_band

    source, source_comp, source_settle = _prepare_left_source_config(
        env,
        species,
        left_source_species=left_source_species,
        left_source_competition=left_source_competition,
        left_source_settle_energy=left_source_settle_energy,
    )
    if source is not None:
        _enforce_left_source_column(state, source)

    occ_frames = np.zeros((n_steps + 1, H, W), dtype=np.int16)
    E_frames = np.zeros((n_steps + 1, H, W), dtype=np.uint8)

    for t in range(n_steps + 1):
        bit = schedule.get(t)
        if bit is not None:
            inject_bit_into_state(
                state, env, species,
                bit=bit,
                center_r=center_r, center_c=center_c, radius=radius,
                toxin_conc=toxin_conc, popular_conc=popular_conc,
            )
        # Assigning into the preallocated arrays copies the current state.
        occ_frames[t] = state.occ
        E_frames[t] = state.E
        if t == n_steps:
            break  # last frame recorded; no tick needed after it
        if source is not None:
            # Mark the fixed source column (occ=-2, zero energy/resources) before
            # ticking; migration out of the source is applied after the tick.
            state.occ[:, 0] = -2
            state.E[:, 0] = 0
            state.R[:, :, 0] = 0
        tick(state, env, species, rng_local)
        if source is not None:
            _apply_left_source_migration(
                state, env, source, source_comp, source_settle,
                outcompete_margin=left_source_outcompete_margin,
                colonize_empty=left_source_colonize_empty,
            )

    return _pulse_animation_figure(occ_frames, E_frames, env, show_source_column=source is not None)

In [None]:
# Example: four pulses (bits 0, 1, 0, 1 -> toxin, popular, toxin, popular)
# injected at steps 20, 40, 60 and 80 of a 100-step run on an 8 x 32 grid.
pulse_overrides = dict(height=8, width_grid=32, basal_energy=4, dilution_p=0.05)
env_pulse, species_pulse = build_env_species(
    # Intended as the toxin + popular-metabolite cross-feeding set
    # (cf. CROSS_FEED_6_SPECIES; verify the indices match).
    species_indices=[0, 1, 20, 21, 40, 41],
    config_overrides=pulse_overrides,
)

fig_pulse = make_pulse_animation_plotly(
    env_pulse,
    species_pulse,
    bit_sequence=[0, 1, 0, 1],
    injection_ticks=list(range(20, 81, 20)),  # [20, 40, 60, 80]
    center_r=4,   # grid centre row for height 8
    center_c=16,  # grid centre column for width 32
    radius=2,
    toxin_conc=180,
    popular_conc=200,
    n_steps=100,
    seed=5,
)
fig_pulse.show()

## Task 4 (8-bit memory) with pulse injection

Bit injection is **spatially and temporally the same** as the benchmark (task_4_learn_8_bit_with_ibm.py). The only difference is **what happens at the injection site**: instead of writing a resource amount via `channel_to_resource`, we apply a **pulse** (clear a square, then add toxin for bit 0 or popular metabolite for bit 1).

- **Temporal:** Same schedule — inject at steps 0, 13, 26, …, 91 (first 8 ticks; one tick every `iter_between` steps).
- **Spatial:** Same sites — we use `create_input_locations(width, recurrence, N_CHANNELS, rng)` as in the benchmark; the **bit** (channel 0) is written to the same cell indices that would receive channel 0 there. At each of those locations we apply the pulse (square around the cell, toxin or popular).
- **At the site:** Pulse = clear square (occ, E, all R) then set toxin (bit 0) or popular metabolite (bit 1).

Below we run one episode, recording occupancy and energy frames at every step. We then print the benchmark-aligned injection sites and the pulse schedule.

In [None]:
# Task 4 parameters (match example_tasks/task_4_learn_8_bit_with_ibm.py)
# NOTE: this cell is order-sensitive. rng_t4 is consumed first by
# create_input_locations, then by make_center_column_state, then by every
# tick() call, so reordering statements changes the episode.
BITS_T4 = 8            # number of bits injected (and later recalled)
D_PERIOD_T4 = 8        # extra ticks in the episode; presumably the delay period between input and recall — confirm against the benchmark
ITR_T4 = 12            # ITER_BETWEEN_T4 = ITR_T4 + 1 simulation steps per tick
RECURRENCE_T4 = 4      # passed to create_input_locations
N_CHANNELS_T4 = 4      # input channels; only channel 0 (the bit) is used here
ITER_BETWEEN_T4 = ITR_T4 + 1
L_T4 = D_PERIOD_T4 + 2 * BITS_T4   # 24 ticks
T_T4 = L_T4 * ITER_BETWEEN_T4      # 312 steps

# Injection: first 8 ticks only; step index = tick_index * iter_between (same as benchmark)
injection_steps_t4 = [k * ITER_BETWEEN_T4 for k in range(BITS_T4)]
# Output window = last 8 ticks (for readout training)
output_window_ticks_t4 = np.arange(L_T4 - BITS_T4, L_T4)

# Example 8-bit pattern
bits_arr_t4 = np.array([0, 1, 0, 0, 1, 0, 1, 1], dtype=np.int8)

env_t4, species_t4 = build_env_species(
    species_indices=CROSS_FEED_6_SPECIES,
    config_overrides={"height": 8, "width_grid": 32, "basal_energy": 4, "dilution_p": 0.05, "basal_init": True},
)
H_t4, W_t4 = env_t4.height, env_t4.width_grid
width_t4 = H_t4 * W_t4  # number of cells; the "width" of the flattened grid
rng_t4 = default_rng(7)

# Same spatial schedule as benchmark: input_locations and channel 0 = bit injection sites
from computingMicrobiome.utils import create_input_locations
input_locations_t4 = create_input_locations(width_t4, RECURRENCE_T4, N_CHANNELS_T4, rng_t4)
# Channel of each location is taken as position % N_CHANNELS (locations
# interleaved by channel) — presumably matches the benchmark's encoding; verify.
channel_idx_t4 = np.arange(input_locations_t4.size) % N_CHANNELS_T4
bit_site_linear = input_locations_t4[channel_idx_t4 == 0]
# Linear index -> (row, col), assuming row-major flattening of the H x W grid.
bit_site_rc = [(int(loc // W_t4), int(loc % W_t4)) for loc in bit_site_linear]

radius_t4 = 2  # Chebyshev radius of each pulse square

# Fixed migrating species at the left edge (column 0); one species per row cycling through cross-feed set.
left_source_species_t4 = (np.arange(H_t4) % env_t4.n_species).astype(np.int16)
source_t4, source_comp_t4, source_settle_t4 = _prepare_left_source_config(
    env_t4, species_t4, left_source_species=left_source_species_t4,
)

# Initial state: centre-column seed, re-banded over up to 5 species, plus the fixed source column.
state_t4 = make_center_column_state(env_t4, species_id=0, energy_mean=5.0, rng=rng_t4)
center = W_t4 // 2
for r in range(H_t4):
    state_t4.occ[r, center] = r % min(env_t4.n_species, 5)
_enforce_left_source_column(state_t4, source_t4)

# One frame per simulation step 0..T_T4 (inclusive).
occ_frames_t4 = np.zeros((T_T4 + 1, H_t4, W_t4), dtype=np.int16)

# Print bit injection sites (same as benchmark channel-0 locations)
print("Bit injection sites (channel 0, linear -> (r,c)):", [(int(l), (int(l // W_t4), int(l % W_t4))) for l in bit_site_linear])
E_frames_t4 = np.zeros((T_T4 + 1, H_t4, W_t4), dtype=np.uint8)

def snapshot_t4(step: int) -> None:
    """Record the current occupancy and energy grids as frame ``step``."""
    occ_frames_t4[step] = state_t4.occ.copy()
    E_frames_t4[step] = state_t4.E.copy()

# Run episode: at steps 0, 13, 26, ..., 91 inject bit 0..7; then snapshot; then step (match run_reservoir_episode cadence).
# Left column is fixed: excluded from dynamics, then migration applied into column 1 after each tick.
for step in range(0, T_T4 + 1):
    if step in injection_steps_t4:
        tick_idx = step // ITER_BETWEEN_T4
        bit = int(bits_arr_t4[tick_idx])
        # The same bit is pulsed at every channel-0 site.
        for (cr, cc) in bit_site_rc:
            inject_bit_into_state(
                state_t4, env_t4, species_t4,
                bit=bit,
                center_r=cr, center_c=cc, radius=radius_t4,
                toxin_conc=180, popular_conc=200,
            )
    snapshot_t4(step)
    # No tick after the final frame (step == T_T4).
    if step < T_T4:
        # Blank the source column (occ=-2, zero E and R) before the tick; the
        # source's migration is re-applied afterwards.
        state_t4.occ[:, 0] = -2
        state_t4.E[:, 0] = 0
        state_t4.R[:, :, 0] = 0
        tick(state_t4, env_t4, species_t4, rng_t4)
        _apply_left_source_migration(
            state_t4, env_t4, source_t4, source_comp_t4, source_settle_t4,
            outcompete_margin=1, colonize_empty=True,
        )

# Timeline table: tick, step, bit (for first 8 ticks)
print("Task 4 pulse schedule (first 8 ticks)")
print("tick  step   bit  pulse_type")
for k in range(BITS_T4):
    step = injection_steps_t4[k]
    b = int(bits_arr_t4[k])
    print(f"  {k}    {step:3d}    {b}   {'toxin' if b == 0 else 'popular'}")
# Output-window step range: from the first output tick's first step to the last step of the last output tick.
print(f"\nOutput window: ticks {output_window_ticks_t4[0]}..{output_window_ticks_t4[-1]} (steps {output_window_ticks_t4[0] * ITER_BETWEEN_T4}..{output_window_ticks_t4[-1] * ITER_BETWEEN_T4 + ITER_BETWEEN_T4 - 1})")
print("bits to recall:", bits_arr_t4.tolist())

In [None]:
# Note: the Task 4 cell above only records occ_frames_t4 / E_frames_t4 and prints the injection schedule; it does not draw any Task 4 figures.

## Interactive exploration with widgets

Next, we expose a few key IBM parameters through widgets so you can quickly explore how they affect the dynamics:
- Number of species.
- Diffusion (`diff_numer`, `diff_denom`) and dilution (`dilution_p`) parameters.
- Grid size, number of steps, initial condition and random seed.

For responsiveness, we keep default runs relatively small.

In [None]:
def simulate_and_plot(
    height: int = 8,
    width_grid: int = 24,
    n_species: int = 4,
    diff_numer: int = 1,
    diff_denom: int = 8,
    dilution_p: float = 0.01,
    n_steps: int = 150,
    init_mode: str = "basal",
    seed: int = 0,
):
    """Build an IBM from the first ``n_species`` species, simulate it, and plot time series.

    Used as the widget callback below. Each call creates a fresh generator from
    ``seed``, which is used both for the initial state and for the simulation.
    """
    # Config from species ids 0..n_species-1 of the global IBM universe.
    config = make_ibm_config_from_species(
        species_indices=[*range(int(n_species))],
        height=height,
        width_grid=width_grid,
        diff_numer=diff_numer,
        diff_denom=diff_denom,
        dilution_p=dilution_p,
    )
    env, species = load_params(config)

    generator = default_rng(seed)
    counts, totals = run_simulation(
        env,
        species,
        init_state(env, mode=init_mode, rng=generator),
        n_steps=n_steps,
        rng=generator,
    )
    plot_time_series(counts, totals)


# One control per simulate_and_plot argument.
simulation_controls = {
    "height": widgets.IntSlider(min=4, max=32, step=2, value=8, description="height"),
    "width_grid": widgets.IntSlider(min=8, max=64, step=4, value=24, description="width"),
    "n_species": widgets.IntSlider(min=1, max=10, step=1, value=4, description="#species"),
    "diff_numer": widgets.IntSlider(min=0, max=4, step=1, value=1, description="diff_numer"),
    "diff_denom": widgets.IntSlider(min=1, max=16, step=1, value=8, description="diff_denom"),
    "dilution_p": widgets.FloatSlider(min=0.0, max=0.2, step=0.005, value=0.02, description="dilution_p"),
    "n_steps": widgets.IntSlider(min=20, max=400, step=10, value=150, description="steps"),
    "init_mode": widgets.Dropdown(options=["basal", "random", "empty"], value="basal", description="init"),
    "seed": widgets.IntSlider(min=0, max=1000, step=1, value=0, description="seed"),
}
widgets.interact(simulate_and_plot, **simulation_controls);

## Sanity check: total resources over time

As a simple sanity check on the default IBM universe and parameterization,
we can track the *total* amount of resources in the system over time. With a
chemostat-like picture (constant feed + dilution acting as a global outlet),
we expect total resources to grow initially and then roughly stabilize,
rather than diverging without bound.

In [None]:
env_check, species_check = build_env_species({"height": 8, "width_grid": 32})
state_check = init_state(env_check, mode="basal", rng=default_rng(123))
species_counts_check, resource_totals_check = run_simulation(
    env_check, species_check, state_check, n_steps=400, rng=default_rng(321)
)

# Sum across axis 1 to get one total per time point
# (assumes resource_totals_check is (time, n_resources) — confirm).
total_resources = resource_totals_check.sum(axis=1)
t = np.arange(total_resources.size)

fig_check, ax_check = plt.subplots(figsize=(6, 4))
ax_check.plot(t, total_resources)
ax_check.set(
    xlabel="time step",
    ylabel="total resources (all metabolites)",
    title="Total resources vs time (default universe)",
)
ax_check.grid(True, alpha=0.3)
fig_check.tight_layout()
plt.show()

### Notes on initial conditions and resource grids

In the examples above we:
- Initialize cell energies around **5 units** (via `basal_energy` in the environment and `energy_mean=5.0` in the center-column helper).
- Seed a **central band with multiple species** (a cycling pattern over the species ids in the central column) to better highlight inter-species interactions.
- Visualize a few representative **resource fields** `R[j, :, :]` using `plot_resource_grid`, alongside the occupancy and energy plots, so spatial resource structure is visible on the same lattice.

In [None]:
# Legacy Plotly animation helper (disabled).
# Kept as a placeholder so the notebook can Run All without errors.
# The helper's code is no longer in this cell; `pass` is used (rather than a
# bare expression) so the cell produces no output.

pass


In [None]:
# Legacy Plotly animation example (disabled).
# This cell is intentionally a no-op so Run All succeeds.
# `pass` keeps the cell free of output; it only preserves the cell layout.

pass


In [None]:
# Second legacy Plotly animation helper + example (disabled).
# Left here only for reference; does nothing on Run All.
# `pass` keeps the cell free of output; the original code is not in this cell.

pass


## IBM internals recap

This notebook has treated the IBM as a dynamical system in its own right. The key components are:
- `GridState(occ, E, R)`: the discrete lattice state (occupancy, energy, resources).
- `EnvParams` / `SpeciesParams`: lattice and per-species parameters parsed from a config.
- `tick(state, env, species, rng)`: one synchronous IBM update (diffusion, dilution, metabolism, reproduction).

In the full `computingMicrobiome` project, the IBM is wrapped by a reservoir backend that encodes `GridState` into feature vectors for readouts. Here we focused purely on its **ecological dynamics**, which can be used as a starting point for new experiments or teaching demos.

## Manual tick-through viewer (matplotlib)

The following cells precompute a short IBM run using the same initial
conditions as above (energy around 5, multi-species central column) and
let you scrub through individual ticks with a simple slider.

This avoids any reliance on Plotly's animation machinery and should work
reliably in most notebook environments.

In [None]:
# Precompute snapshots for a simple tick-through viewer

from typing import Tuple


def make_spatial_snapshots_for_viewer(
    env: EnvParams,
    species: SpeciesParams,
    *,
    n_steps: int = 60,
    seed: int = 0,
) -> Tuple[np.ndarray, np.ndarray]:
    """Simulate the IBM for ``n_steps`` ticks and return every occupancy/energy frame.

    Initial condition, shared with the other examples:

    - ``make_center_column_state`` is called with ``energy_mean=5.0``;
    - the central column is then re-seeded with a band cycling over the first
      ``min(n_species, 5)`` species ids.

    Returns:
        ``(occ_frames, E_frames)``, each of shape ``(n_steps + 1, height, width_grid)``
        with dtypes ``int16`` / ``uint8``. Index ``t`` holds the state after ``t`` ticks.
    """
    generator = default_rng(seed)
    state = make_center_column_state(env, species_id=0, energy_mean=5.0, rng=generator)

    rows, cols = env.height, env.width_grid
    mid_col = cols // 2
    band_size = min(env.n_species, 5)
    for row in range(rows):
        state.occ[row, mid_col] = row % band_size

    frame_shape = (n_steps + 1, rows, cols)
    occ_history = np.zeros(frame_shape, dtype=np.int16)
    energy_history = np.zeros(frame_shape, dtype=np.uint8)

    def record(index: int) -> None:
        # Slice assignment copies into the preallocated history arrays.
        occ_history[index] = state.occ
        energy_history[index] = state.E

    record(0)
    for step in range(1, n_steps + 1):
        tick(state, env, species, generator)
        record(step)

    return occ_history, energy_history


# Precompute a short run for the tick-through viewer (8 x 32 grid, basal energy 5).
n_steps_viewer = 60
viewer_overrides = {"height": 8, "width_grid": 32, "basal_energy": 5}
env_view, species_view = build_env_species(viewer_overrides)
occ_frames_view, E_frames_view = make_spatial_snapshots_for_viewer(
    env_view,
    species_view,
    n_steps=n_steps_viewer,
    seed=3,
)

In [None]:
# Simple matplotlib + slider viewer for individual ticks


def show_tick(t: int = 0) -> None:
    """Plot occupancy and energy of the precomputed viewer run at tick ``t``.

    ``t`` is clipped to ``[0, n_steps_viewer]``. Colour limits are fixed across
    all ticks: the global min/max of ``occ_frames_view`` and of
    ``E_frames_view`` respectively. A given colour therefore means the same
    occupancy value / energy level on every tick, and a colourbar on each
    panel shows that mapping.
    """
    t = int(np.clip(t, 0, n_steps_viewer))

    fig, (ax_occ, ax_energy) = plt.subplots(1, 2, figsize=(9, 3))

    panels = (
        (ax_occ, occ_frames_view, f"Species occupancy (t={t})", "occupancy (-1 = empty, else species id)"),
        (ax_energy, E_frames_view, f"Energy per cell (t={t})", "energy"),
    )
    for ax, frames, title, cbar_label in panels:
        # Without explicit vmin/vmax, imshow rescales every frame independently,
        # so colours would change meaning as species appear or go extinct.
        image = ax.imshow(
            frames[t],
            interpolation="nearest",
            vmin=frames.min(),
            vmax=frames.max(),
        )
        ax.set_title(title)
        ax.set_xlabel("column")
        ax.set_ylabel("row")
        fig.colorbar(image, ax=ax, label=cbar_label)

    plt.tight_layout()
    plt.show()


# Slider over every precomputed tick of the viewer run.
tick_slider = widgets.IntSlider(
    min=0,
    max=n_steps_viewer,
    step=1,
    value=0,
    description="tick",
)
widgets.interact(show_tick, t=tick_slider);