# Task 4 IBM Reservoir Debug Viewer (exact current config)

This notebook visualizes Task 4 reservoir behavior **exactly as configured now** in
`example_tasks/task_4_learn_8_bit_with_ibm.py`.

It includes:
1. Tick-by-tick input stream (memory task channels).
2. Species occupancy + energy over simulation steps.
3. Interactive `show_tick(...)` step slider (species occupancy, energy, and resource R0).
4. Plotly animation (occupancy + energy).
5. Injection/extraction timing table for the 8-bit output window.


## Colab setup (if needed)


In [None]:
# If needed on fresh Colab runtime:
# !git clone https://github.com/danielriosgarza/computingMicrobiome.git
# %cd computingMicrobiome
# !uv pip install .[all]


## Imports


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# IBM-core imports (same style as ibm_simulation_demo)
from computingMicrobiome.ibm import (
    EnvParams,
    SpeciesParams,
    GridState,
    load_params,
    make_channel_to_resource_from_config,
    make_ibm_config_from_species,
)
from computingMicrobiome.ibm.state import make_zero_state
from computingMicrobiome.ibm.stepper import tick


## Local memory-task stream definition (inlined from benchmark)


In [None]:
# Channel layout for memory task: each name is one column of the input
# stream, and its position in CHANNEL_NAMES is its channel index.
CHANNEL_NAMES = ["bit", "bit_flip", "distractor", "cue"]
BIT, BIT_FLIP, DISTRACTOR, CUE = range(len(CHANNEL_NAMES))
N_CHANNELS = len(CHANNEL_NAMES)


def create_input_streams(bits_arr: np.ndarray, d_period: int) -> np.ndarray:
    """Build the per-tick input streams for the B-bit memory task.

    The stream has ``L = d_period + 2 * B`` ticks (``B = len(bits_arr)``) and
    four channels, in the order bit, bit_flip, distractor, cue:

    * ch0 (bit): ``bits_arr`` on the first ``B`` ticks, 0 afterwards.
    * ch1 (bit_flip): the complement of ``bits_arr`` on the first ``B`` ticks,
      0 afterwards.
    * ch2 (distractor): 1 on every tick except the ``B`` input ticks and the
      cue tick.
    * ch3 (cue): a single 1 at tick ``L - B - 1``, i.e. the tick immediately
      before the final ``B``-tick output window.

    Args:
        bits_arr: 1-D sequence of 0/1 values to memorize.
        d_period: Ticks between the end of the input bits and the start of the
            output window. The cue occupies the last of these ticks.

    Returns:
        ``int8`` array of shape ``(L, 4)``.

    Raises:
        ValueError: If ``bits_arr`` is not 1-D or ``d_period < 1``. With
            ``d_period == 0`` the cue tick would coincide with the last input
            bit (or fall outside an empty stream).
    """
    bits = np.asarray(bits_arr)
    if bits.ndim != 1:
        raise ValueError("bits_arr must be 1-D")
    if d_period < 1:
        raise ValueError("d_period must be >= 1 so the cue follows the input bits")

    n_bits = bits.shape[0]
    n_ticks = d_period + 2 * n_bits
    streams = np.zeros((n_ticks, 4), dtype=np.int8)

    # ch0 / ch1: the bits and their complement during the input phase.
    streams[:n_bits, 0] = bits
    streams[:n_bits, 1] = np.bitwise_xor(bits, 1)

    # ch2: distractor is on everywhere except the input phase and the cue tick.
    cue_idx = n_ticks - n_bits - 1
    streams[:, 2] = 1
    streams[:n_bits, 2] = 0
    streams[cue_idx, 2] = 0

    # ch3: single cue pulse right before the output window.
    streams[cue_idx, 3] = 1

    return streams


## Task 4 config (exact values from script)


In [None]:
# Exact Task 4 parameters (values copied from the Task 4 script named in the
# notebook header).
BITS = 8  # bits to memorize; also the length (in ticks) of the output window
BOUNDARY = "periodic"  # NOTE(review): not referenced by any cell in this notebook
RECURRENCE = 4  # number of grid segments each input channel is written into
ITR = 12  # ITER_BETWEEN below is ITR + 1
D_PERIOD = 8  # delay ticks between the input bits and the output window (cue on the last)
SEED = 0  # seeds the RNG used for input locations and every tick() call

# Passed to the config as "input_trace_depth".
# NOTE(review): presumably sized as (delay + input) ticks times steps per tick
# plus slack — confirm the intent against the Task 4 script.
TRACE_DEPTH = (D_PERIOD + BITS) * (ITR + 1) + 8
# Simulation steps per input tick: one packet is injected every ITER_BETWEEN steps.
ITER_BETWEEN = ITR + 1

IBM_DIFF_NUMER = 1
IBM_DILUTION_P = 0.02
IBM_INJECT_SCALE = 2.0  # multiplier applied to each injected channel value

# Three species (indices 0, 1, 2) on an 8 x 8 grid, with the Task 4 overrides.
IBM_CFG = make_ibm_config_from_species(
    species_indices=[0, 1, 2],
    height=8,
    width_grid=8,
    overrides={
        "state_width_mode": "raw",
        "input_trace_depth": TRACE_DEPTH,
        "input_trace_channels": N_CHANNELS,
        "input_trace_decay": 1.0,
        "inject_scale": IBM_INJECT_SCALE,
        "dilution_p": IBM_DILUTION_P,
        "diff_numer": IBM_DIFF_NUMER,
    },
)
# Channel -> resource mapping (presumably surfaced as env.channel_to_resource
# by load_params; confirm).
IBM_CFG["channel_to_resource"] = make_channel_to_resource_from_config(IBM_CFG, N_CHANNELS)
# "replace": injected values overwrite the target resource cell instead of adding to it.
IBM_CFG["inject_mode"] = "replace"

env, species = load_params(IBM_CFG)
# Total number of grid cells; used as the flat "width" when choosing input locations.
WIDTH = int(env.height) * int(env.width_grid)

print("grid:", env.height, "x", env.width_grid, "cells:", WIDTH)
print("species/resources:", env.n_species, env.n_resources)
print("inject_scale:", env.inject_scale)
print("dilution_p:", env.dilution_p)
print("diff_numer:", env.diff_numer)
print("inject_mode:", IBM_CFG["inject_mode"])
print("channel_to_resource:", IBM_CFG["channel_to_resource"])
print("iter_between:", ITER_BETWEEN)
print("trace_depth:", TRACE_DEPTH)


## Build one sample stream and extraction timing


In [None]:
# Example 8-bit memory input
bits_arr = np.array([1, 0, 1, 0, 1, 1, 0, 0], dtype=np.int8)

input_streams = create_input_streams(bits_arr, D_PERIOD)  # shape (L, 4)
L = input_streams.shape[0]  # number of input ticks (D_PERIOD + 2 * BITS)
T = L * ITER_BETWEEN  # total simulation steps (ITER_BETWEEN steps per tick)

# Task 4 extraction: output window is last BITS ticks
extract_ticks = np.arange(L - BITS, L)
# Simulation step at which each output-window tick's packet is injected (the
# loop below injects when step % ITER_BETWEEN == 0).
# NOTE(review): whether the backend reads the state at this step or after the
# tick's ITER_BETWEEN dynamics steps is not visible here — confirm.
extract_steps = extract_ticks * ITER_BETWEEN

print("L ticks:", L)
print("T steps:", T)
print("extract ticks:", extract_ticks.tolist())
print("extract steps:", extract_steps.tolist())


## Stream heatmap (channels x ticks)


In [None]:
# Channels-by-ticks view of the input stream (black = 1, white = 0).
fig, ax = plt.subplots(figsize=(12, 4.5))
ax.imshow(input_streams.T, aspect="auto", interpolation="nearest", cmap="Greys")
ax.set(
    title="Task 4 input stream (channels x ticks)",
    xlabel="tick",
    ylabel="channel",
)
ax.set_yticks(np.arange(N_CHANNELS))
ax.set_yticklabels(CHANNEL_NAMES)

# Shade the output window: the last BITS ticks of the stream.
window_lo = extract_ticks[0] - 0.5
window_hi = extract_ticks[-1] + 0.5
ax.axvspan(window_lo, window_hi, color="red", alpha=0.12, label="output window")

# Thin white lines between cells, drawn on minor ticks placed at cell edges.
ax.set_xticks(np.arange(-0.5, input_streams.shape[0], 1), minor=True)
ax.set_yticks(np.arange(-0.5, N_CHANNELS, 1), minor=True)
ax.grid(which="minor", color="white", linewidth=0.6)
ax.tick_params(which="minor", bottom=False, left=False)

ax.legend(loc="upper right")
plt.tight_layout()
plt.show()


## Helpers: input locations, init-state, injection


In [None]:
def create_input_locations(width: int, recurrence: int, input_channels: int, rng: np.random.Generator) -> np.ndarray:
    """Choose the flat cell indices that receive input.

    The ``width`` flat cells are split into ``recurrence`` contiguous
    segments whose sizes differ by at most one: the first
    ``width % recurrence`` segments get one extra cell. In each segment,
    ``input_channels`` distinct positions are drawn with
    ``rng.choice(..., replace=False)``. Segments are processed in order, so
    the same seed always yields the same locations.

    Args:
        width: Total number of flat cells (here ``height * width_grid``).
        recurrence: Number of segments. Must be >= 1.
        input_channels: Distinct positions to draw per segment.
        rng: Random generator consumed by the per-segment draws.

    Returns:
        Integer array of length ``recurrence * input_channels``. Entries
        ``[k * input_channels, (k + 1) * input_channels)`` lie in segment ``k``.

    Raises:
        ValueError: If ``recurrence < 1``, ``width < recurrence``, or the
            narrowest segment has fewer than ``input_channels`` cells.
    """
    if recurrence < 1:
        # Without this check, width // recurrence below would raise ZeroDivisionError.
        raise ValueError("recurrence must be >= 1")
    if width < recurrence:
        raise ValueError("width must be >= recurrence")
    single_min, rest = divmod(width, recurrence)
    if input_channels > single_min:
        raise ValueError("input_channels exceeds minimum segment width")

    seg_widths = np.full(recurrence, single_min, dtype=int)
    seg_widths[:rest] += 1

    locs: list[int] = []
    offset = 0
    for seg_w in seg_widths:
        seg_positions = rng.choice(seg_w, size=input_channels, replace=False)
        locs.extend((seg_positions + offset).tolist())
        offset += seg_w

    return np.array(locs, dtype=int)


def init_state_like_backend_zeros(env: EnvParams) -> GridState:
    """Build the initial grid state used before the first injection.

    * If ``env.basal_init`` is falsy, return ``make_zero_state(...)``.
    * Otherwise seed a deterministic basal community (no RNG is used):
      - species id per cell: ``row % n_species`` when ``basal_pattern`` is
        ``"stripes"``, otherwise ``(row + col) % n_species``;
      - occupied cells: every cell if ``basal_occupancy >= 1``, none if it is
        ``<= 0``, otherwise a fixed integer hash of (row, col) compared with
        ``basal_occupancy * 1000``;
      - energy ``basal_energy`` on occupied cells (cast to ``uint8``);
      - resources from ``basal_resource_vec`` (one value per resource,
        broadcast over the grid) if set, else ``basal_resource`` everywhere.

    Args:
        env: Environment parameters (grid size, species/resource counts and
            basal settings).

    Returns:
        A ``GridState``. In the basal branch: ``occ`` is int16 with -1 for
        empty cells, ``E`` is uint8, and ``R`` is uint8 with shape
        ``(n_resources, height, width_grid)``.
    """
    # Mirrors IBMReservoirBackend.reset(..., x0_mode='zeros')
    if not env.basal_init:
        return make_zero_state(height=env.height, width_grid=env.width_grid, n_resources=env.n_resources)

    rr, cc = np.indices((env.height, env.width_grid))
    if env.basal_pattern == "stripes":
        sid = (rr % env.n_species).astype(np.int16)
    else:
        sid = ((rr + cc) % env.n_species).astype(np.int16)

    if env.basal_occupancy >= 1.0:
        occupied = np.ones((env.height, env.width_grid), dtype=bool)
    elif env.basal_occupancy <= 0.0:
        occupied = np.zeros((env.height, env.width_grid), dtype=bool)
    else:
        # Fixed multipliers give a reproducible pseudo-random key in [0, 1000) per cell;
        # roughly basal_occupancy of the cells fall under the threshold.
        key = (rr * 73856093 + cc * 19349663) % 1000
        occupied = key < int(env.basal_occupancy * 1000.0)

    # -1 marks an empty cell; occupied cells take their pattern species id.
    occ = np.full((env.height, env.width_grid), -1, dtype=np.int16)
    occ[occupied] = sid[occupied]

    E = np.zeros((env.height, env.width_grid), dtype=np.uint8)
    if env.basal_energy > 0:
        # NOTE(review): np.uint8 cast assumes basal_energy <= 255 — confirm.
        E[occupied] = np.uint8(env.basal_energy)

    # Per-resource basal level if provided, else one shared scalar level.
    br_vec = getattr(env, "basal_resource_vec", None)
    if br_vec is not None:
        br = np.asarray(br_vec, dtype=np.uint8).reshape(env.n_resources)
        R = np.broadcast_to(br[:, None, None], (env.n_resources, env.height, env.width_grid)).copy()
    else:
        R = np.full((env.n_resources, env.height, env.width_grid), np.uint8(env.basal_resource), dtype=np.uint8)

    return GridState(occ=occ, E=E, R=R)


def inject_packet_into_state(
    state: GridState,
    env: EnvParams,
    input_values: np.ndarray,
    input_locations: np.ndarray,
    channel_idx: np.ndarray,
    *,
    inject_mode: str = "replace",
) -> None:
    """Write one input packet into ``state.R`` in place.

    For each entry ``k`` of ``input_locations``:

    * the flat location is wrapped modulo the number of cells and split into
      (row, col);
    * its channel is ``channel_idx[k]``. The injected amount is
      ``input_values[channel % input_values.size] * env.inject_scale``,
      rounded with ``np.rint`` and floored at 0;
    * the target resource is ``env.channel_to_resource[channel % size]`` when
      a mapping is set, else ``channel % env.n_resources``.

    With ``inject_mode == "replace"`` the amount overwrites
    ``R[resource, row, col]``. Any other mode adds it with ``np.add.at``, so
    repeated targets accumulate. The result is clipped to ``[0, env.Rmax]``
    and stored back as ``uint8``.

    Returns immediately if there are no locations or no values.
    """
    # Mirrors IBMReservoirBackend.inject
    if input_locations.size == 0 or input_values.size == 0:
        return

    # Flat location -> (row, col), wrapping out-of-range indices onto the grid.
    cells = env.height * env.width_grid
    loc = np.mod(input_locations.astype(np.int64, copy=False), cells)
    rr = (loc // env.width_grid).astype(np.int64, copy=False)
    cc = (loc % env.width_grid).astype(np.int64, copy=False)

    # Amount injected at each location, taken from that location's channel value.
    ch = channel_idx.astype(np.int64, copy=False)
    vals = input_values.reshape(-1)
    scaled = np.rint(vals[ch % vals.size].astype(np.float32) * env.inject_scale)
    add = np.maximum(scaled, 0.0).astype(np.int32)

    # Resource layer targeted by each location.
    if env.channel_to_resource is None:
        m_idx = np.mod(ch, env.n_resources).astype(np.int64)
    else:
        # NOTE(review): assumes channel_to_resource is an ndarray (uses .size / .astype).
        mapping = env.channel_to_resource
        m_idx = mapping[ch % mapping.size].astype(np.int64, copy=False)

    # Work in int32 so additions cannot wrap around uint8 before clipping.
    R_work = state.R.astype(np.int32, copy=True)
    if inject_mode == "replace":
        R_work[m_idx, rr, cc] = add
    else:
        np.add.at(R_work, (m_idx, rr, cc), add)
    # NOTE(review): the uint8 cast assumes env.Rmax <= 255 — confirm.
    state.R = np.clip(R_work, 0, env.Rmax).astype(np.uint8)


## Simulate step-by-step with Task 4 schedule


In [None]:
# Run the reservoir for T steps. One packet from input_streams is injected every
# ITER_BETWEEN steps, and every step is snapshotted for the viewers below.
rng = np.random.default_rng(SEED)  # used by create_input_locations, then by every tick()
state = init_state_like_backend_zeros(env)

input_locations = create_input_locations(WIDTH, RECURRENCE, N_CHANNELS, rng)
# Location k carries channel k % N_CHANNELS.
channel_idx = np.arange(input_locations.size) % N_CHANNELS
inject_mode = str(IBM_CFG.get("inject_mode", "replace")).strip().lower()

# Per-step snapshots (taken after any injection at that step, before tick()).
occ_frames_view = np.zeros((T, env.height, env.width_grid), dtype=np.int16)
E_frames_view = np.zeros((T, env.height, env.width_grid), dtype=np.uint8)
R0_frames_view = np.zeros((T, env.height, env.width_grid), dtype=np.uint8)  # resource 0 only
resource_total = np.zeros(T, dtype=np.int64)  # NOTE(review): recorded but not used by later cells

step_to_tick = np.full(T, -1, dtype=np.int32)  # tick injected at each step, -1 if none
step_packet = np.zeros((T, N_CHANNELS), dtype=np.int8)  # packet injected at each step (zeros if none)

injection_delta_total = []  # (step, tick, change in total resources caused by that injection)

tick_idx = 0
for step in range(T):
    if step % ITER_BETWEEN == 0:
        # Start of an input tick: inject its packet. Since T = L * ITER_BETWEEN,
        # tick_idx stays within 0 .. L - 1.
        packet = input_streams[tick_idx]
        before = int(state.R.astype(np.int64).sum())
        inject_packet_into_state(
            state,
            env,
            packet,
            input_locations,
            channel_idx,
            inject_mode=inject_mode,
        )
        after = int(state.R.astype(np.int64).sum())

        step_to_tick[step] = tick_idx
        step_packet[step] = packet
        injection_delta_total.append((step, tick_idx, after - before))
        tick_idx += 1

    # Snapshot after this step's injection (if any), before the dynamics step
    occ_frames_view[step] = state.occ
    E_frames_view[step] = state.E
    R0_frames_view[step] = state.R[0]
    resource_total[step] = int(state.R.astype(np.int64).sum())

    # Return value unused; presumably advances `state` in place — confirm in stepper.tick.
    tick(state, env, species, rng)

n_steps_viewer = T - 1  # last valid step index for the viewers
print("captured frames:", occ_frames_view.shape, E_frames_view.shape, R0_frames_view.shape)
print("")
print("first 12 injection deltas (change in total resources at inject steps):")
for row in injection_delta_total[:12]:
    print(row)


## Injection/extraction timing table (per bit)


In [None]:
# Bit k is written on tick k and read back on the k-th tick of the output
# window; step indices are tick indices scaled by ITER_BETWEEN.
print("bit_idx | write_tick | extract_tick | delay_ticks | write_step | extract_step | delay_steps")
print("-" * 95)
for bit_idx in range(BITS):
    write_tick, read_tick = bit_idx, int(extract_ticks[bit_idx])
    write_step, read_step = write_tick * ITER_BETWEEN, read_tick * ITER_BETWEEN
    print(
        f"{bit_idx:7d} | {write_tick:10d} | {read_tick:12d} | {read_tick - write_tick:11d}"
        f" | {write_step:10d} | {read_step:12d} | {read_step - write_step:11d}"
    )


## Interactive step viewer (species + energy + R0)


In [None]:
def _step_note(step: int) -> str:
    """Return a one-line caption for simulation ``step``.

    The caption names the injected tick and its active (== 1) channels when
    ``step`` is an injection step, and adds "EXTRACTION WINDOW" when ``step``
    is listed in ``extract_steps``. Parts are joined with " | ". Returns
    "no injection" when neither applies.
    """
    injected_tick = int(step_to_tick[step])
    notes = []
    if injected_tick >= 0:
        channels_on = np.flatnonzero(step_packet[step] == 1).tolist()
        notes.append(f"inject tick={injected_tick}, channels={channels_on}")
    if step in extract_steps.tolist():
        notes.append("EXTRACTION WINDOW")
    if not notes:
        return "no injection"
    return " | ".join(notes)


def show_tick(t: int = 0) -> None:
    """Plot species occupancy, energy and resource R0 at simulation step ``t``.

    ``t`` is a simulation *step* index, clamped to ``[0, n_steps_viewer]``,
    not a stream tick. Colour limits are the same for every step:

    * occupancy: -1 (empty) to ``n_species - 1``;
    * energy and R0: 0 to the maximum over all captured frames.

    Without fixed limits, ``imshow`` rescales each frame, so the same colour
    would encode different species or values on different steps. The figure
    title is the caption from ``_step_note(t)``.
    """
    t = int(np.clip(t, 0, n_steps_viewer))

    # (frames, title prefix, vmin, vmax, colorbar label). max(..., 0/1) keeps
    # each colour range non-degenerate.
    panels = [
        (occ_frames_view, "Species occupancy", -1, max(int(env.n_species) - 1, 0), "species (-1 = empty)"),
        (E_frames_view, "Energy per cell", 0, max(int(E_frames_view.max()), 1), "energy"),
        (R0_frames_view, "Resource R0", 0, max(int(R0_frames_view.max()), 1), "R0"),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(12, 3.2))
    for ax, (frames, label, vmin, vmax, cbar_label) in zip(axes, panels):
        im = ax.imshow(frames[t], interpolation="nearest", vmin=vmin, vmax=vmax)
        ax.set_title(f"{label} (step={t})")
        ax.set_xlabel("column")
        ax.set_ylabel("row")
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label=cbar_label)

    fig.suptitle(_step_note(t), fontsize=10)
    plt.tight_layout()
    plt.show()


# Step slider driving show_tick over every captured step (0 .. n_steps_viewer).
widgets.interact(
    show_tick,
    t=widgets.IntSlider(
        value=0,
        min=0,
        max=n_steps_viewer,
        step=1,
        description="step",
    ),
)


## Plotly animation (occupancy + energy)


In [None]:
# Colour for empty cells (occupancy == -1) in the Plotly occupancy panel.
OCC_EMPTY_COLOR = "#f5f5f5"
# Qualitative palette for species 0, 1, ...; get_species_colors repeats it
# when there are more species than palette entries.
OCC_SPECIES_COLORS = px.colors.qualitative.Set2
# Plotly colourscale name used for the energy panel.
ENERGY_COLORSCALE = "Viridis"


def get_species_colors(
    n_species: int,
    palette: "list[str] | None" = None,
    empty_color: "str | None" = None,
) -> list[str]:
    """Return one colour for empty cells followed by one colour per species.

    Index 0 is the empty-cell colour and index ``k + 1`` is species ``k``.
    Species colours come from ``palette`` and cycle when there are more
    species than palette entries. Only the species palette is cycled, so no
    species is ever drawn in the empty-cell colour.

    Args:
        n_species: Number of species. Negative values are treated as 0.
        palette: Species colours. Defaults to ``OCC_SPECIES_COLORS``.
        empty_color: Colour for empty cells. Defaults to ``OCC_EMPTY_COLOR``.

    Returns:
        List of ``max(n_species, 0) + 1`` colour strings.

    Raises:
        ValueError: If the palette is empty while ``n_species > 0``.
    """
    if palette is None:
        palette = OCC_SPECIES_COLORS
    if empty_color is None:
        empty_color = OCC_EMPTY_COLOR

    n = max(int(n_species), 0)
    species_palette = list(palette)
    if n > 0 and not species_palette:
        raise ValueError("palette must contain at least one colour")

    return [empty_color] + [species_palette[k % len(species_palette)] for k in range(n)]


def make_spatial_animation_plotly_from_frames(
    occ_frames: np.ndarray,
    E_frames: np.ndarray,
    *,
    n_species: int,
    inject_steps: np.ndarray,
    extract_steps: np.ndarray,
):
    """Build a Plotly animation of species occupancy and per-cell energy.

    The left subplot has one single-colour heatmap trace per occupancy value
    (-1 = empty, then each species id). Each trace shows 1.0 where the cell
    has that value and NaN elsewhere, so every class gets a flat colour. The
    right subplot shows per-cell energy.

    The energy colour range is fixed to ``[0, max(E_frames)]`` for every
    frame. Otherwise Plotly auto-scales ``zmin``/``zmax`` to each frame's data
    and colours are not comparable between steps. Each frame's title marks
    injection and extraction steps. A Play/Pause button drives the animation.

    Args:
        occ_frames: Occupancy snapshots indexed ``[step, row, col]``; -1 = empty.
        E_frames: Energy snapshots with the same indexing as ``occ_frames``.
        n_species: Number of species ids (0 .. n_species - 1) to draw.
        inject_steps: Per-step flags (truthy = injection step), with at least
            one entry per frame.
        extract_steps: Step indices to tag as extraction steps.

    Returns:
        A ``plotly.graph_objects.Figure`` with one animation frame per step.

    Raises:
        ValueError: If there are no frames or ``inject_steps`` is shorter than
            the number of frames.
    """
    n_frames = occ_frames.shape[0]
    if n_frames == 0:
        raise ValueError("occ_frames must contain at least one frame")
    inject_flags = np.asarray(inject_steps, dtype=bool).reshape(-1)
    if inject_flags.size < n_frames:
        raise ValueError("inject_steps must have at least one entry per frame")

    H, W = occ_frames.shape[1], occ_frames.shape[2]
    colors = get_species_colors(n_species)
    occ_values = [-1] + list(range(n_species))
    # Shared energy range for all frames; max(..., 1) avoids a degenerate
    # range when every cell has zero energy.
    e_max = max(int(np.max(E_frames)), 1)

    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Species occupancy", "Energy per cell"),
        horizontal_spacing=0.08,
    )

    # Static styling per occupancy class; only z changes between frames.
    occ_confs = [
        dict(
            x=list(range(W)),
            y=list(range(H)),
            colorscale=[[0.0, colors[idx]], [1.0, colors[idx]]],
            showscale=False,
            name="empty" if val == -1 else f"s{val}",
            hoverinfo="skip",
            opacity=0.95,
            xgap=1,
            ygap=1,
        )
        for idx, val in enumerate(occ_values)
    ]

    energy_conf = dict(
        x=list(range(W)),
        y=list(range(H)),
        colorscale=ENERGY_COLORSCALE,
        colorbar=dict(title="energy"),
        showscale=True,
        zmin=0,
        zmax=e_max,
        name="energy",
        xgap=1,
        ygap=1,
        showlegend=False,
    )

    def _frame_traces(step: int) -> list:
        """Heatmap traces for one step: occupancy masks first, energy last."""
        traces = [
            go.Heatmap(z=np.where(occ_frames[step] == val, 1.0, np.nan), **conf)
            for val, conf in zip(occ_values, occ_confs)
        ]
        traces.append(go.Heatmap(z=E_frames[step], **energy_conf))
        return traces

    # Initial traces: occupancy masks on the left subplot, energy on the right.
    initial = _frame_traces(0)
    for trace in initial[:-1]:
        fig.add_trace(trace, row=1, col=1)
    fig.add_trace(initial[-1], row=1, col=2)

    for c in (1, 2):
        fig.update_xaxes(title_text="column", row=1, col=c, dtick=1, range=[-0.5, W - 0.5], showgrid=True, gridcolor="white", gridwidth=1)
        fig.update_yaxes(title_text="row", row=1, col=c, dtick=1, range=[H - 0.5, -0.5], showgrid=True, gridcolor="white", gridwidth=1)

    extract_set = {int(x) for x in np.asarray(extract_steps).reshape(-1).tolist()}

    frames = []
    for step in range(n_frames):
        tags = []
        if inject_flags[step]:
            tags.append("inject")
        if step in extract_set:
            tags.append("extract")
        title = f"step={step}" + (" (" + ", ".join(tags) + ")" if tags else "")
        frames.append(go.Frame(data=_frame_traces(step), name=str(step), layout=go.Layout(title_text=title)))

    fig.frames = frames

    fig.update_layout(
        title_text="step=0",
        height=500,
        width=1000,
        legend_title_text="species",
        legend=dict(orientation="h", yanchor="top", y=-0.1, xanchor="left", x=0.0),
        template="plotly_white",
        updatemenus=[
            dict(
                type="buttons",
                showactive=False,
                buttons=[
                    dict(label="Play", method="animate", args=[None, {"frame": {"duration": 80, "redraw": True}, "fromcurrent": True}]),
                    dict(label="Pause", method="animate", args=[[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate"}]),
                ],
                x=0.0,
                y=-0.16,
            )
        ],
    )

    return fig


In [None]:
# Boolean mask over simulation steps: True where a packet was injected.
inject_steps = np.greater_equal(step_to_tick, 0)
fig = make_spatial_animation_plotly_from_frames(
    occ_frames=occ_frames_view,
    E_frames=E_frames_view,
    n_species=env.n_species,
    extract_steps=extract_steps,
    inject_steps=inject_steps,
)
fig


## Notes

This notebook intentionally mirrors the current Task 4 IBM config, including:

- `inject_mode="replace"`
- explicit `channel_to_resource` mapping
- non-zero `inject_scale`, `dilution_p`, and `diff_numer`

Use it to inspect exactly how packets are injected and how the state evolves before extraction.
