# IBM microbiome simulation demo (standalone, non-reservoir)

This notebook demonstrates the individual-based microbiome (IBM) model used in `computingMicrobiome`, **only as a dynamical system**, not yet as a reservoir.

We will:
- Build and inspect small IBM grids (`GridState`).
- Run simulations over time using the low-level `tick` function.
- Visualize species and resource dynamics (time series and simple animations).
- Use interactive widgets to explore how changing parameters (number of species, diffusion, dilution, grid size, etc.) affects the dynamics.

The notebook is designed to run in **Google Colab** from a fresh runtime.

## Environment setup (Colab)

If you are running this in **Google Colab** from a fresh runtime, uncomment and run the following cell first to clone the repository and install the package.

> Note: If you are running locally inside the `computingMicrobiome` repo with the package already installed, you can skip this cell.
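
One way to decide whether the setup cell is needed is a small Colab check. The sketch below uses the common idiom of looking for the `google.colab` module among the modules that are already loaded. The name `IN_COLAB` is only illustrative, and the check assumes that a Colab kernel preloads that module.

```python
import sys

# Assumption: a Colab kernel has `google.colab` loaded at startup,
# while a local Jupyter kernel normally does not.
IN_COLAB = "google.colab" in sys.modules

if IN_COLAB:
    print("Colab detected: run the clone/install cell below.")
else:
    print("Local runtime: skip the setup cell if the package is installed.")
```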

In [None]:
# Uncomment this cell if running in Google Colab from a fresh runtime.
# It will clone the repo and install the package (including extras).

# !git clone https://github.com/danielriosgarza/computingMicrobiome.git
# %cd computingMicrobiome
# !uv pip install .[all]

## Imports and basic setup

We import the IBM core components (`GridState`, `EnvParams`, `SpeciesParams`, `load_params`, `make_ibm_config_from_species`, `make_center_column_state`, `make_zero_state`, `tick`) and the plotting / widget libraries.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
from numpy.random import default_rng

import ipywidgets as widgets
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

from computingMicrobiome.ibm import (
    EnvParams,
    SpeciesParams,
    GridState,
    load_params,
    make_ibm_config_from_species,
    make_center_column_state,
)
from computingMicrobiome.ibm.state import make_zero_state
from computingMicrobiome.ibm.stepper import tick

plt.style.use("seaborn-v0_8-darkgrid")
rng = default_rng(0)

In [None]:
# Shared Plotly styles for IBM grids

OCC_EMPTY_COLOR = "#f5f5f5"
OCC_SPECIES_COLORS = px.colors.qualitative.Set2
ENERGY_COLORSCALE = "Viridis"


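def get_species_colors(n_species: int) -> list[str]:
    """Return the empty-cell color followed by one color per species."""
    palette = list(OCC_SPECIES_COLORS)
    # Cycle only the species palette, so the empty-cell color is never
    # reused for a species when n_species exceeds the palette size.
    species_colors = [palette[i % len(palette)] for i in range(n_species)]
    return [OCC_EMPTY_COLOR] + species_colors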


def plot_occupancy_and_energy(env: EnvParams, state: GridState, title_suffix: str = "") -> None:
    """Two-panel Plotly figure: species occupancy + energy, with toggleable species."""
    n_species = env.n_species
    H, W = state.occ.shape
    colors = get_species_colors(n_species)

    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=(f"Species occupancy{title_suffix}", f"Energy per cell{title_suffix}"),
        horizontal_spacing=0.08,
    )

    for idx, val in enumerate([-1] + list(range(n_species))):
        mask = np.where(state.occ == val, 1.0, np.nan)
        label = "empty" if val == -1 else f"s{val}"
        fig.add_trace(
            go.Heatmap(
                z=mask,
                x=list(range(W)),
                y=list(range(H)),
                colorscale=[[0.0, colors[idx]], [1.0, colors[idx]]],
                showscale=False,
                name=label,
                hoverinfo="skip",
                opacity=0.95,
                xgap=1,
                ygap=1,
            ),
            row=1,
            col=1,
        )

    fig.add_trace(
        go.Heatmap(
            z=state.E,
            x=list(range(W)),
            y=list(range(H)),
            colorscale=ENERGY_COLORSCALE,
            colorbar=dict(title="energy"),
            showscale=True,
            name="energy",
            xgap=1,
            ygap=1,
            showlegend=False,
        ),
        row=1,
        col=2,
    )

    for c in (1, 2):
        fig.update_xaxes(
            title_text="column",
            row=1,
            col=c,
            dtick=1,
            range=[-0.5, W - 0.5],
            showgrid=True,
            gridcolor="white",
            gridwidth=1,
        )
        fig.update_yaxes(
            title_text="row",
            row=1,
            col=c,
            dtick=1,
            range=[H - 0.5, -0.5],
            showgrid=True,
            gridcolor="white",
            gridwidth=1,
        )

    fig.update_layout(
        height=500,
        width=1000,
        legend_title_text="species",
        legend=dict(
            orientation="h",
            yanchor="top",
            y=-0.1,
            xanchor="left",
            x=0.0,
        ),
        template="plotly_white",
        legend_itemclick="toggle",
        legend_itemdoubleclick="toggleothers",
    )
    fig.show()


In [None]:
def plot_resource_grid(env: EnvParams, state: GridState, max_resources: int = 3) -> None:
    """Plot one or a few resource fields as heatmaps (Plotly).

    Shows up to ``max_resources`` resource planes R[j, :, :] as separate
    panels with the same square-grid styling used for occupancy/energy.
    """
    H, W = state.occ.shape
    M = state.R.shape[0]
    m = min(max_resources, M)
    if m == 0:
        raise ValueError("No resources to plot (env.n_resources == 0)")

    fig = make_subplots(
        rows=1,
        cols=m,
        subplot_titles=[f"R{j}" for j in range(m)],
        horizontal_spacing=0.08,
    )

    for j in range(m):
        Rm = state.R[j]
        fig.add_trace(
            go.Heatmap(
                z=Rm,
                x=list(range(W)),
                y=list(range(H)),
                colorscale=ENERGY_COLORSCALE,
                showscale=(j == m - 1),
                colorbar=dict(title="amount") if j == m - 1 else None,
                xgap=1,
                ygap=1,
            ),
            row=1,
            col=j + 1,
        )

    for c in range(1, m + 1):
        fig.update_xaxes(
            title_text="column",
            row=1,
            col=c,
            dtick=1,
            range=[-0.5, W - 0.5],
            showgrid=True,
            gridcolor="white",
            gridwidth=1,
        )
        fig.update_yaxes(
            title_text="row",
            row=1,
            col=c,
            dtick=1,
            range=[H - 0.5, -0.5],
            showgrid=True,
            gridcolor="white",
            gridwidth=1,
        )

    fig.update_layout(
        height=350,
        width=350 * m,
        template="plotly_white",
    )
    fig.show()

## IBM state initialization helpers

We define helper functions to:
- Build default/environment parameters for a small grid.
- Initialize a `GridState` in different ways (empty, basal pattern, random).

These mirror the logic used in the IBM reservoir backend but keep everything focused on the raw IBM dynamics.
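
As a small illustration of the "basal pattern" mode, the sketch below reproduces the species-assignment rules and the deterministic occupancy mask of `init_state` on a toy grid. It uses plain NumPy arrays only; the grid size, species count, and occupancy fraction are arbitrary values chosen for the example.

```python
import numpy as np

height, width, n_species = 4, 6, 3
rr, cc = np.indices((height, width))

# "stripes": the species id depends only on the row index.
stripes = rr % n_species
# Any other pattern value: ids cycle along the diagonals.
diagonal = (rr + cc) % n_species

# Deterministic partial occupancy: hash each (row, col) into 0..999 and
# keep the cells whose key falls below the requested fraction.
basal_occupancy = 0.5
key = (rr * 73856093 + cc * 19349663) % 1000
occupied = key < int(basal_occupancy * 1000)

print(stripes)
print(diagonal)
print(occupied)
```

Because the mask depends only on the cell coordinates, repeated runs give the same initial occupancy without touching the random generator.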

In [None]:
def build_env_species(
    config_overrides: dict | None = None,
    species_indices: list[int] | None = None,
) -> tuple[EnvParams, SpeciesParams]:
    """Create `EnvParams` and `SpeciesParams` from the IBM universe.

    By default this selects the first four species (indices 0 to 3) from the
    global IBM universe (with 50 species and 100 resources) on a 16 x 32 grid.
    You can override the list of species indices via `species_indices` and the
    lattice-level parameters via `config_overrides`.
    """
    if species_indices is None:
        species_indices = [0, 1, 2, 3]

    base_cfg = make_ibm_config_from_species(
        species_indices=species_indices,
        height=16,
        width_grid=32,
    )
    if config_overrides is not None:
        base_cfg.update(config_overrides)
    env, species = load_params(base_cfg)
    return env, species


def init_state(env: EnvParams, mode: str = "basal", rng: np.random.Generator | None = None) -> GridState:
    """Initialize a `GridState` for the given environment.

    Modes:
    - "empty": all cells empty, resources zero.
    - "basal": deterministic pattern based on `basal_*` settings.
    - "random": random occupancy, energies, and resources.
    """
    if rng is None:
        rng = default_rng()

    if mode == "empty":
        return make_zero_state(
            height=env.height,
            width_grid=env.width_grid,
            n_resources=env.n_resources,
        )

    if mode == "basal":
        rr, cc = np.indices((env.height, env.width_grid))
        if env.basal_pattern == "stripes":
            sid = (rr % env.n_species).astype(np.int16)
        else:
            sid = ((rr + cc) % env.n_species).astype(np.int16)

        if env.basal_occupancy >= 1.0:
            occupied = np.ones((env.height, env.width_grid), dtype=bool)
        elif env.basal_occupancy <= 0.0:
            occupied = np.zeros((env.height, env.width_grid), dtype=bool)
        else:
            # Deterministic occupancy mask for reproducible runs.
            key = (rr * 73856093 + cc * 19349663) % 1000
            occupied = key < int(env.basal_occupancy * 1000.0)

        occ = np.full((env.height, env.width_grid), -1, dtype=np.int16)
        occ[occupied] = sid[occupied]

        E = np.zeros((env.height, env.width_grid), dtype=np.uint8)
        if env.basal_energy > 0:
            E[occupied] = np.uint8(env.basal_energy)

        # Initialize resources using per-resource basal levels if available,
        # otherwise fall back to the scalar basal_resource.
        br_vec = getattr(env, "basal_resource_vec", None)
        if br_vec is not None:
            br = np.asarray(br_vec, dtype=np.uint8).reshape(env.n_resources)
            R = np.broadcast_to(
                br[:, None, None],
                (env.n_resources, env.height, env.width_grid),
            ).copy()
        else:
            R = np.full(
                (env.n_resources, env.height, env.width_grid),
                np.uint8(env.basal_resource),
                dtype=np.uint8,
            )
        return GridState(occ=occ, E=E, R=R)

    if mode == "random":
        occ = np.full((env.height, env.width_grid), -1, dtype=np.int16)
        occupied = rng.random((env.height, env.width_grid)) < 0.5
        sid = rng.integers(
            0,
            env.n_species,
            size=(env.height, env.width_grid),
            dtype=np.int16,
        )
        occ[occupied] = sid[occupied]

        E = np.zeros((env.height, env.width_grid), dtype=np.uint8)
        if np.any(occupied):
            rand_e = rng.integers(
                1,
                env.Emax + 1,
                size=int(np.count_nonzero(occupied)),
                dtype=np.uint16,
            ).astype(np.uint8)
            E[occupied] = rand_e

        R = rng.integers(
            0,
            env.Rmax + 1,
            size=(env.n_resources, env.height, env.width_grid),
            dtype=np.uint16,
        ).astype(np.uint8)
        return GridState(occ=occ, E=E, R=R)

    raise ValueError("mode must be one of {'empty', 'basal', 'random'}")

### Initial condition used in the examples

In the examples below we initialize the lattice with:

- **Energy around 5 units per occupied cell**, by setting `basal_energy=5` in the environment and passing `energy_mean=5.0` to `make_center_column_state`.
- **A single central column seeded with multiple species** (cycling through a handful of species IDs down the column), so interactions are easier to see than with just one species.
- **Resource fields initialized from the per-resource basal vector** implied by the IBM universe, and visualized with a dedicated `plot_resource_grid` helper that shows a few representative resources.

This gives a compact but non-trivial starting state that is reused consistently for the static snapshot, the basic run, and the spatial animation.
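
The multi-species central column can be sketched on a bare occupancy array. This covers the occupancy part only: energies and resources are left to `make_center_column_state` and are not reproduced here. The grid size and species count are arbitrary example values.

```python
import numpy as np

H, W, n_species = 8, 16, 4
occ = np.full((H, W), -1, dtype=np.int16)  # -1 marks an empty cell

center = W // 2
band = min(n_species, 5)                   # cycle through at most five ids
occ[:, center] = np.arange(H) % band       # vectorized form of a per-row loop

print(occ[:, center])  # [0 1 2 3 0 1 2 3]
```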

## Inspect a small IBM state

We now create a small grid, initialize it, and look at the arrays that represent:
- `occ`: occupancy (species id or -1 for empty)
- `E`: per-cell energy
- `R`: per-resource field over the grid
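
To make the array layout concrete, here is a minimal sketch with a hypothetical stand-in container, `ToyState`, which is not part of `computingMicrobiome`. The dtypes follow those used in the `init_state` helper above; the values are made up.

```python
from typing import NamedTuple

import numpy as np


class ToyState(NamedTuple):
    """Hypothetical stand-in for GridState, used only for illustration."""

    occ: np.ndarray  # (H, W) species id per cell, -1 for empty
    E: np.ndarray    # (H, W) energy per cell
    R: np.ndarray    # (M, H, W) one plane per resource


H, W, M = 2, 3, 2
toy = ToyState(
    occ=np.array([[-1, 0, 1], [2, -1, 0]], dtype=np.int16),
    E=np.array([[0, 5, 4], [6, 0, 3]], dtype=np.uint8),
    R=np.full((M, H, W), 10, dtype=np.uint8),
)

r, c = 0, 1
print("species at (0, 1):", toy.occ[r, c])
print("energy at (0, 1):", toy.E[r, c])
print("resources at (0, 1):", toy.R[:, r, c])  # one value per resource
print("occupied cells:", np.count_nonzero(toy.occ >= 0))
```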

In [None]:
env, species = build_env_species({"height": 8, "width_grid": 16, "basal_energy": 5})
state = make_center_column_state(env, species_id=0, energy_mean=5.0, rng=rng)

print("Grid shape (H, W):", state.occ.shape)
print("Energy shape (H, W):", state.E.shape)
print("Resources shape (M, H, W):", state.R.shape)

# Enrich the initial condition: use multiple species in the central column.
H, W = state.occ.shape
center = W // 2
max_species_band = min(env.n_species, 5)
for r in range(H):
    state.occ[r, center] = r % max_species_band

# Plot occupancy + energy with the shared helper, and then a few resource grids.
plot_occupancy_and_energy(env, state, title_suffix=" (initial)")
plot_resource_grid(env, state, max_resources=3)

## Running the IBM over time

We define a helper to run the IBM for a number of ticks and record summary observables:
- Species counts over time.
- Total resources per resource type over time.

In [None]:
def run_simulation(
    env: EnvParams,
    species: SpeciesParams,
    state: GridState,
    *,
    n_steps: int = 200,
    rng: np.random.Generator | None = None,
):
    """Run the IBM for `n_steps`, returning species and resource trajectories.

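    `state` is updated in place by `tick`, so it holds the final configuration
    when this function returns.

    Returns
    -------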
    species_counts : array, shape (n_steps+1, n_species)
        Number of occupied cells per species at each time.
    resource_totals : array, shape (n_steps+1, n_resources)
        Total resource amount per resource type at each time.
    """
    if rng is None:
        rng = default_rng()

    species_counts = np.zeros((n_steps + 1, env.n_species), dtype=np.int32)
    resource_totals = np.zeros((n_steps + 1, env.n_resources), dtype=np.int32)

    def measure(t_idx: int) -> None:
        occ = state.occ
        for s in range(env.n_species):
            species_counts[t_idx, s] = np.count_nonzero(occ == s)
        R = state.R.reshape(env.n_resources, -1)
        resource_totals[t_idx] = R.sum(axis=1)

    measure(0)
    for t in range(1, n_steps + 1):
        tick(state, env, species, rng)
        measure(t)

    return species_counts, resource_totals


def plot_time_series(species_counts: np.ndarray, resource_totals: np.ndarray) -> None:
    """Plot species and resource trajectories over time using Plotly."""
    T = species_counts.shape[0] - 1
    t = np.arange(T + 1)

    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Species counts", "Resource totals"),
        horizontal_spacing=0.12,
    )

    # Species counts
    for s in range(species_counts.shape[1]):
        fig.add_trace(
            go.Scatter(
                x=t,
                y=species_counts[:, s],
                mode="lines",
                name=f"s{s}",
            ),
            row=1,
            col=1,
        )

    # Resource totals
    for m in range(resource_totals.shape[1]):
        fig.add_trace(
            go.Scatter(
                x=t,
                y=resource_totals[:, m],
                mode="lines",
                name=f"R{m}",
            ),
            row=1,
            col=2,
        )

    fig.update_xaxes(title_text="time step", row=1, col=1)
    fig.update_yaxes(title_text="# occupied cells", row=1, col=1)
    fig.update_xaxes(title_text="time step", row=1, col=2)
    fig.update_yaxes(title_text="total resource", row=1, col=2)

    fig.update_layout(
        height=400,
        width=900,
        template="plotly_white",
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0.0),
    )
    fig.show()

### Example: basic IBM run

We now run the IBM for a few hundred steps and plot the resulting trajectories.

In [None]:
env, species = build_env_species({"height": 8, "width_grid": 32, "basal_energy": 5})
state = make_center_column_state(env, species_id=0, energy_mean=5.0, rng=default_rng(1))

# Enrich the initial condition in the same way: multi-species central column.
H, W = state.occ.shape
center = W // 2
max_species_band = min(env.n_species, 5)
for r in range(H):
    state.occ[r, center] = r % max_species_band

species_counts, resource_totals = run_simulation(env, species, state, n_steps=200, rng=default_rng(2))
plot_time_series(species_counts, resource_totals)
# Optional: final snapshot with the same styling as the initial example
plot_occupancy_and_energy(env, state, title_suffix=" (final)")

## Animation of spatial dynamics

Next we build a simple animation that shows how occupancy and energy fields evolve over time on the lattice.
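
The occupancy panel is drawn as one heatmap layer per value of `occ`, so that each layer carries a single flat color. A minimal NumPy sketch of how such layers can be built from one frame, on toy data and without Plotly:

```python
import numpy as np

n_species = 2
occ = np.array([[-1, 0], [1, 0]], dtype=np.int16)

# One layer per value: 1.0 where the cell holds that value, NaN elsewhere.
# NaN marks cells that belong to a different layer.
layers = {
    val: np.where(occ == val, 1.0, np.nan)
    for val in [-1, *range(n_species)]
}

# Every cell is claimed by exactly one layer.
coverage = sum(np.nan_to_num(layer) for layer in layers.values())
print(coverage)
```

For an animation, the same masks are recomputed for every stored frame and swapped into the existing traces.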

In [None]:
def make_spatial_animation_plotly(
    env: EnvParams,
    species: SpeciesParams,
    *,
    n_steps: int = 100,
    seed: int = 0,
):
    """Create a Plotly animation of IBM spatial dynamics.

    Uses the same colors and layout as `plot_occupancy_and_energy`.
    """
    rng_local = default_rng(seed)
    state = make_center_column_state(env, species_id=0, energy_mean=5.0, rng=rng_local)

    # Enrich the initial condition: multi-species central column, as in the other examples.
    H, W = env.height, env.width_grid
    max_species_band = min(env.n_species, 5)
    center = W // 2
    for r in range(H):
        state.occ[r, center] = r % max_species_band

    n_species = env.n_species
    colors = get_species_colors(n_species)

    # Pre-compute frames
    occ_frames = np.zeros((n_steps + 1, H, W), dtype=np.int16)
    E_frames = np.zeros((n_steps + 1, H, W), dtype=np.uint8)

    def snapshot(t_idx: int) -> None:
        occ_frames[t_idx] = state.occ
        E_frames[t_idx] = state.E

    snapshot(0)
    for t in range(1, n_steps + 1):
        tick(state, env, species, rng_local)
        snapshot(t)

    # Base figure for frame 0
    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Species occupancy", "Energy per cell"),
        horizontal_spacing=0.08,
    )

    # Occupancy traces (one per species + empty) at t=0
    for idx, val in enumerate([-1] + list(range(n_species))):
        mask0 = np.where(occ_frames[0] == val, 1.0, np.nan)
        label = "empty" if val == -1 else f"s{val}"
        fig.add_trace(
            go.Heatmap(
                z=mask0,
                x=list(range(W)),
                y=list(range(H)),
                colorscale=[[0.0, colors[idx]], [1.0, colors[idx]]],
                showscale=False,
                name=label,
                hoverinfo="skip",
                opacity=0.95,
                xgap=1,
                ygap=1,
            ),
            row=1,
            col=1,
        )

    # Energy trace at t=0
    fig.add_trace(
        go.Heatmap(
            z=E_frames[0],
            x=list(range(W)),
            y=list(range(H)),
            colorscale=ENERGY_COLORSCALE,
            colorbar=dict(title="energy"),
            showscale=True,
            name="energy",
            xgap=1,
            ygap=1,
            showlegend=False,
        ),
        row=1,
        col=2,
    )

    # Axes styling
    for c in (1, 2):
        fig.update_xaxes(
            title_text="column",
            row=1,
            col=c,
            dtick=1,
            range=[-0.5, W - 0.5],
            showgrid=True,
            gridcolor="white",
            gridwidth=1,
        )
        fig.update_yaxes(
            title_text="row",
            row=1,
            col=c,
            dtick=1,
            range=[H - 0.5, -0.5],
            showgrid=True,
            gridcolor="white",
            gridwidth=1,
        )

    # Build frames: new z for every occupancy trace, then the energy trace,
    # in the same order as the traces were added to the figure.
    occ_values = [-1] + list(range(n_species))
    frames = []
    for t_idx in range(n_steps + 1):
        frame_data = [
            go.Heatmap(z=np.where(occ_frames[t_idx] == val, 1.0, np.nan))
            for val in occ_values
        ]
        frame_data.append(go.Heatmap(z=E_frames[t_idx]))
        frames.append(go.Frame(data=frame_data))

    fig.frames = frames

    fig.update_layout(
        height=500,
        width=1000,
        legend_title_text="species",
        legend=dict(
            orientation="h",
            yanchor="top",
            y=-0.1,
            xanchor="left",
            x=0.0,
        ),
        template="plotly_white",
        updatemenus=[
            dict(
                type="buttons",
                showactive=False,
                buttons=[
                    dict(
                        label="Play",
                        method="animate",
                        args=[[None], {"frame": {"duration": 100, "redraw": True}, "fromcurrent": True}],
                    ),
                    dict(
                        label="Pause",
                        method="animate",
                        args=[[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate"}],
                    ),
                ],
                x=0.0,
                y=-0.16,
            )
        ],
    )

    return fig

In [None]:
# Example animation for a small grid (may take a couple of seconds to render).
env_anim, species_anim = build_env_species({"height": 8, "width_grid": 32, "basal_energy": 5})
fig_anim = make_spatial_animation_plotly(env_anim, species_anim, n_steps=60, seed=3)
fig_anim.show()

## Interactive exploration with widgets

Finally, we expose a few key IBM parameters through widgets so you can quickly explore how they affect the dynamics:
- Number of species.
- Diffusion and dilution parameters.
- Grid size and number of steps.
- Initialization mode and random seed.

For responsiveness, we keep default runs relatively small.

In [None]:
def simulate_and_plot(
    height: int = 8,
    width_grid: int = 24,
    n_species: int = 4,
    diff_numer: int = 1,
    diff_denom: int = 8,
    dilution_p: float = 0.02,
    n_steps: int = 150,
    init_mode: str = "basal",
    seed: int = 0,
):
    # Select the first `n_species` from the global IBM universe and build a
    # config suitable for `load_params`.
    indices = list(range(int(n_species)))
    cfg = make_ibm_config_from_species(
        species_indices=indices,
        height=height,
        width_grid=width_grid,
        diff_numer=diff_numer,
        diff_denom=diff_denom,
        dilution_p=dilution_p,
    )
    env, species = load_params(cfg)
    rng_local = default_rng(seed)
    state = init_state(env, mode=init_mode, rng=rng_local)
    species_counts, resource_totals = run_simulation(env, species, state, n_steps=n_steps, rng=rng_local)
    plot_time_series(species_counts, resource_totals)


widgets.interact(
    simulate_and_plot,
    height=widgets.IntSlider(min=4, max=32, step=2, value=8, description="height"),
    width_grid=widgets.IntSlider(min=8, max=64, step=4, value=24, description="width"),
    n_species=widgets.IntSlider(min=1, max=10, step=1, value=4, description="#species"),
    diff_numer=widgets.IntSlider(min=0, max=4, step=1, value=1, description="diff_numer"),
    diff_denom=widgets.IntSlider(min=1, max=16, step=1, value=8, description="diff_denom"),
    dilution_p=widgets.FloatSlider(min=0.0, max=0.2, step=0.005, value=0.02, description="dilution_p"),
    n_steps=widgets.IntSlider(min=20, max=400, step=10, value=150, description="steps"),
    init_mode=widgets.Dropdown(options=["basal", "random", "empty"], value="basal", description="init"),
    seed=widgets.IntSlider(min=0, max=1000, step=1, value=0, description="seed"),
);

## Sanity check: total resources over time

As a simple sanity check on the default IBM universe and parameterization,
we can track the *total* amount of resources in the system over time. With a
chemostat-like picture (constant feed + dilution acting as a global outlet),
we expect total resources to grow initially and then roughly stabilize,
rather than diverging without bound.
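
To see why feed plus dilution should level off, consider a deliberately simplified, deterministic balance for a single well-mixed pool: each step a fraction `p` is washed out and a constant amount `feed` is added, so `R[t+1] = (1 - p) * R[t] + feed`, with fixed point `feed / p`. This toy recursion is an assumption made for illustration, not the IBM's actual update rule, and `feed` and `p` are hypothetical values.

```python
def toy_chemostat(feed: float, p: float, n_steps: int, r0: float = 0.0) -> list[float]:
    """Iterate R <- (1 - p) * R + feed and return the whole trajectory."""
    trajectory = [r0]
    for _ in range(n_steps):
        trajectory.append((1.0 - p) * trajectory[-1] + feed)
    return trajectory


feed, p = 1.0, 0.02      # hypothetical values, not taken from the IBM config
traj = toy_chemostat(feed, p, n_steps=400)
steady_state = feed / p  # fixed point of the recursion

print(f"after 400 steps: {traj[-1]:.2f}, predicted plateau: {steady_state:.2f}")
```

In this toy picture the plateau scales inversely with the dilution probability, which gives a rough expectation for how the total-resource curve should respond when `dilution_p` is changed.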

In [None]:
env_check, species_check = build_env_species({"height": 8, "width_grid": 32})
state_check = init_state(env_check, mode="basal", rng=default_rng(123))
species_counts_check, resource_totals_check = run_simulation(
    env_check,
    species_check,
    state_check,
    n_steps=400,
    rng=default_rng(321),
)
total_resources = resource_totals_check.sum(axis=1)
t = np.arange(total_resources.size)

plt.figure(figsize=(6, 4))
plt.plot(t, total_resources)
plt.xlabel("time step")
plt.ylabel("total resources (all metabolites)")
plt.title("Total resources vs time (default universe)")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

### Notes on initial conditions and resource grids

In the examples above we:
- Initialize cell energies around **5 units** (via `basal_energy` in the environment and `energy_mean=5.0` in the center-column helper).
- Seed the **central column with multiple species** (a cycling pattern over the species ids) to better highlight inter-species interactions.
- Visualize a few representative **resource fields** `R[j, :, :]` using `plot_resource_grid`, alongside the occupancy and energy plots, so spatial resource structure is visible on the same lattice.

## IBM internals recap

This notebook has treated the IBM as a dynamical system in its own right. The key components are:
- `GridState(occ, E, R)`: the discrete lattice state (occupancy, energy, resources).
- `EnvParams` / `SpeciesParams`: lattice and per-species parameters parsed from a config.
- `tick(state, env, species, rng)`: one synchronous IBM update (diffusion, dilution, metabolism, reproduction).

In the full `computingMicrobiome` project, the IBM is wrapped by a reservoir backend that encodes `GridState` into feature vectors for readouts. Here we focused purely on its **ecological dynamics**, which can be used as a starting point for new experiments or teaching demos.