In [23]:
# %pip install ipywidgets  # if needed; enable widgets in classic notebook: jupyter nbextension enable --py widgetsnbextension
#%matplotlib widget

import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import (
    VBox, HBox, Layout, Dropdown, FloatSlider, Checkbox, Text, HTML, interactive_output
)

# ---------------- Units helpers ----------------
# Factors converting each unit to SI: seconds per time unit, kilograms per mass unit.
TIME_UNITS = {
    "s": 1.0,
    "min": 60.0,
    "h": 3600.0,
    "day": 86400.0,
}
MASS_UNITS = {
    "kg": 1.0,
    "g": 1e-3,
    "mg": 1e-6,
    "µg": 1e-9,
}


def to_base_time(val, tu):
    """Return the duration *val*, given in time unit *tu*, in seconds."""
    seconds_per_unit = TIME_UNITS[tu]
    return val * seconds_per_unit


def to_base_rate_per_time(k, tu):
    """Return the rate *k*, given per time unit *tu*, as a rate per second."""
    seconds_per_unit = TIME_UNITS[tu]
    return k / seconds_per_unit


def to_base_S(S, mu, tu):
    """Return the source rate *S*, given in mass unit *mu* per time unit *tu*, in kg/s."""
    kg_per_unit = MASS_UNITS[mu]
    seconds_per_unit = TIME_UNITS[tu]
    return S * kg_per_unit / seconds_per_unit


def from_base_mass(m_kg, mu):
    """Return the mass *m_kg*, given in kilograms, in mass unit *mu*."""
    kg_per_unit = MASS_UNITS[mu]
    return m_kg / kg_per_unit

def fmt(x, nd=3):
    """Format a scalar compactly for the helper text.

    Non-finite values get dedicated symbols: NaN -> "NaN", +inf -> "∞",
    -inf -> "-∞" (previously every non-finite value, including NaN, was
    shown as "∞").  Zero is shown as "0".  Magnitudes >= 1e3 or < 1e-2 use
    scientific notation with *nd* digits after the decimal point; all other
    values use general format with *nd* significant digits.
    """
    if np.isnan(x):
        return "NaN"
    if np.isinf(x):
        return "∞" if x > 0 else "-∞"
    if x == 0:
        return "0"
    mag = abs(x)
    if mag >= 1e3 or mag < 1e-2:
        return f"{x:.{nd}e}"
    return f"{x:.{nd}g}"

# ---------------- Physics ----------------
def m_solution(t_s, S_kg_s, k_s, m0_kg):
    """Closed-form solution of dm/dt = S - k*m with m(0) = m0, in SI units.

    Parameters
    ----------
    t_s : array_like
        Times in seconds (converted to a float array).
    S_kg_s : float
        Constant source rate in kg/s.
    k_s : float
        First-order loss rate in 1/s.  Values <= 0 are treated as "no loss":
        m grows linearly and there is no steady state.
    m0_kg : float
        Initial mass in kg.

    Returns
    -------
    m : ndarray
        Mass in kg at each time.
    dm : ndarray
        dm/dt in kg/s at each time.
    tau : float
        Time constant 1/k in seconds (inf when k <= 0).
    m_star : float
        Steady-state mass S/k in kg (nan when k <= 0).
    """
    t = np.asarray(t_s, dtype=float)
    if k_s <= 0.0:
        m = m0_kg + S_kg_s * t
        # Explicit float dtype: full_like(int_array, S) would truncate S.
        dm = np.full(t.shape, S_kg_s, dtype=float)
        return m, dm, np.inf, np.nan
    x = -k_s * t
    e = np.exp(x)
    # -expm1(x) == 1 - exp(x) without cancellation when k*t is tiny.
    one_minus_e = -np.expm1(x)
    m_star = S_kg_s / k_s
    m = m0_kg * e + m_star * one_minus_e
    # Analytic derivative (S - k*m0) * exp(-k t); equals S - k*m but avoids
    # the cancellation of that difference near steady state.
    dm = (S_kg_s - k_s * m0_kg) * e
    tau = 1.0 / k_s
    return m, dm, tau, m_star

# ---------------- Global controls ----------------
# Global controls shared by all species.
species_count = Dropdown(
    options=[(f"{word} species", n) for n, word in enumerate(("One", "Two", "Three"), start=1)],
    value=1,
    description="Species:",
)
time_unit = Dropdown(options=list(TIME_UNITS), value="h", description="Time unit:")
mass_unit = Dropdown(options=list(MASS_UNITS), value="µg", description="Mass unit:")
# End of the time axis, expressed in the currently selected time unit.
tmax = FloatSlider(
    value=48.0,
    min=1.0,
    max=240.0,
    step=1.0,
    description="t_max",
    readout_format=".0f",
)

# Display toggles, all enabled by default.
show_tau, show_thalf, show_3tau, show_dm = (
    Checkbox(value=True, description=text)
    for text in ("Show τ", "Show t½", "Show 3τ", "Show dm/dt panel")
)

# ---------------- Species parameter blocks with helper text ----------------
def species_block(default_name):
    """Build the widget group for one species.

    Returns a dict holding the name field and the S, k, m0 sliders (keys
    "name", "S", "k", "m0"), the assembled container ("row"), and the HTML
    helper line shown beneath the controls ("help").
    """
    widgets = {
        "name": Text(value=default_name, description="Name:"),
        "S": FloatSlider(value=10.0, min=0.0, max=1e4, step=1.0, description="S"),
        "k": FloatSlider(value=0.10, min=0.0, max=2.0, step=0.01, description="k"),
        "m0": FloatSlider(value=0.0, min=0.0, max=1e5, step=1.0, description="m0"),
    }
    help_line = HTML(value="", layout=Layout(width="100%"))
    controls = HBox(
        [widgets[key] for key in ("name", "S", "k", "m0")],
        layout=Layout(gap="10px"),
    )
    widgets["row"] = VBox([controls, help_line])
    widgets["help"] = help_line
    return widgets

# One parameter group per species, default-named "Species A/B/C".
A, B, C = (species_block(f"Species {tag}") for tag in "ABC")
species_blocks = [A, B, C]

# ---------------- Dynamic unit labels & helper text ----------------
def update_labels_and_help():
    """Refresh unit-bearing widget labels and each species' helper line.

    Reads the selected time and mass units, writes them into the S/k/m0
    slider descriptions of every species and into the t_max description,
    then rebuilds the HTML helper text for species A, B and C.
    """
    tu = time_unit.value
    mu = mass_unit.value

    def make_help(Sv, kv, m0v):
        # S is in mu/tu and k in 1/tu, so S/k is already in mu, and 1/k and
        # ln2/k are already in tu; no round trip through SI units is needed.
        parts = [
            f"<b>S</b> = {fmt(Sv)} {mu}/{tu}",
            f"<b>k</b> = {fmt(kv)} 1/{tu}",
            f"<b>m₀</b> = {fmt(m0v)} {mu}",
        ]
        if kv > 0:
            parts += [
                f"<b>m*</b> = S/k = {fmt(Sv / kv)} {mu}",
                f"<b>τ</b> = 1/k = {fmt(1.0 / kv)} {tu}",
                f"<b>t½</b> = ln2/k = {fmt(np.log(2) / kv)} {tu}",
            ]
        else:
            # Same k <= 0 case that m_solution treats as pure accumulation.
            parts += [
                "<b>m*</b>: undefined (k=0)",
                "<b>τ</b>: ∞ (k=0)",
                "<b>t½</b>: ∞ (k=0)",
            ]
        return " | ".join(parts)

    # Slider descriptions carry the explicit units.
    for blk in species_blocks:
        blk["S"].description = f"S [{mu}/{tu}]"
        blk["k"].description = f"k [1/{tu}]"
        blk["m0"].description = f"m0 [{mu}]"
    tmax.description = f"t_max [{tu}]"

    for blk in species_blocks:
        blk["help"].value = make_help(blk["S"].value, blk["k"].value, blk["m0"].value)

# Observe changes to refresh labels/help
def wire_help_observers():
    """Re-run update_labels_and_help() whenever a unit or a species' S/k/m0 changes."""
    def _refresh(_change):
        update_labels_and_help()

    watched = [time_unit, mass_unit]
    for blk in (A, B, C):
        watched.extend(blk[key] for key in ("S", "k", "m0"))
    for widget in watched:
        widget.observe(_refresh, names="value")


wire_help_observers()
update_labels_and_help()

# ---------------- Rendering ----------------
def render(species_count, time_unit, mass_unit, tmax,
           A_name, A_S, A_k, A_m0,
           B_name, B_S, B_k, B_m0,
           C_name, C_S, C_k, C_m0,
           show_tau, show_thalf, show_3tau, show_dm):
    """Plot m(t), and optionally dm/dt, for the selected species.

    All inputs arrive in the user-selected display units (``time_unit`` and
    ``mass_unit`` are keys of TIME_UNITS / MASS_UNITS).  They are converted to
    SI (kg, s) for m_solution and converted back for plotting.

    species_count : species A is always drawn; B is added when it is >= 2 and
        C when it is >= 3.
    tmax : end of the time axis, in ``time_unit``.
    X_name, X_S, X_k, X_m0 : legend label, source rate, rate constant and
        initial mass for species X in {A, B, C}.
    show_tau, show_thalf, show_3tau : toggle the τ, t½ and 3τ markers, which
        are only drawn for k > 0 and when they fall inside the time axis.
    show_dm : add a second panel with dm/dt.
    """
    # Each species keeps the colour of its fixed index in the default cycle.
    palette = plt.rcParams['axes.prop_cycle'].by_key()['color']
    sec_per_unit = TIME_UNITS[time_unit]

    # Computation grid in seconds; the plotted axis is in the display unit.
    t_s = np.linspace(0.0, to_base_time(tmax, time_unit), 900)
    t_out = t_s / sec_per_unit
    t_last = t_s[-1]

    if show_dm:
        fig = plt.figure(constrained_layout=True, figsize=(9, 6.6))
        grid = fig.add_gridspec(2, 1, height_ratios=[2.0, 1.2])
        ax_m = fig.add_subplot(grid[0, 0])
        ax_dm = fig.add_subplot(grid[1, 0], sharex=ax_m)
    else:
        fig = plt.figure(constrained_layout=True, figsize=(9, 4.8))
        ax_m = fig.add_subplot(1, 1, 1)
        ax_dm = None

    def draw(color, name, S_val, k_val, m0_val):
        S_base = to_base_S(S_val, mass_unit, time_unit)
        k_base = to_base_rate_per_time(k_val, time_unit)
        m0_base = m0_val * MASS_UNITS[mass_unit]
        m_base, dm_base, tau, m_star = m_solution(t_s, S_base, k_base, m0_base)

        ax_m.plot(t_out, from_base_mass(m_base, mass_unit), lw=2, label=name, color=color)

        if np.isfinite(tau):
            # Steady-state level m* as a dashed line.
            m_star_out = from_base_mass(m_star, mass_unit)
            ax_m.axhline(m_star_out, ls="--", lw=1.2, color=color, alpha=0.6)
            guide = {"ls": ":", "lw": 1.0, "color": color, "alpha": 0.8}

            if show_tau and tau <= t_last:
                x_tau = tau / sec_per_unit
                y_tau = from_base_mass(
                    m_solution(np.array([tau]), S_base, k_base, m0_base)[0][0], mass_unit
                )
                ax_m.axvline(x_tau, **guide)
                ax_m.plot([x_tau], [y_tau], marker="o", color=color)

            half_life = np.log(2) / k_base
            if show_thalf and half_life <= t_last:
                x_half = half_life / sec_per_unit
                # Halfway between m0 and m* (both in display mass units).
                y_half = m_star_out + 0.5 * (m0_val - m_star_out)
                ax_m.axvline(x_half, **guide)
                ax_m.axhline(y_half, **guide)
                ax_m.plot([x_half], [y_half], marker="s", color=color)

            if show_3tau and 3 * tau <= t_last:
                ax_m.axvline((3 * tau) / sec_per_unit, **guide)

        if ax_dm is not None:
            # kg/s -> mass_unit/s, then scale to mass_unit/time_unit.
            dm_out = from_base_mass(dm_base, mass_unit) / (1.0 / sec_per_unit)
            ax_dm.plot(t_out, dm_out, lw=1.7, label=name, color=color)

    species = [
        (A_name, A_S, A_k, A_m0),
        (B_name, B_S, B_k, B_m0),
        (C_name, C_S, C_k, C_m0),
    ]
    n_drawn = 1 + (species_count >= 2) + (species_count >= 3)
    for idx, params in enumerate(species[:n_drawn]):
        draw(palette[idx % len(palette)], *params)

    panels = (
        (ax_m, f"Mass m(t) [{mass_unit}]"),
        (ax_dm, f"dm/dt [{mass_unit}/{time_unit}]"),
    )
    for ax, y_label in panels:
        if ax is None:
            continue
        ax.grid(alpha=0.3)
        ax.set_xlabel(f"Time [{time_unit}]")
        ax.set_ylabel(y_label)
        ax.legend()

    plt.show()

# ---------------- Layout wiring ----------------
# Explicit import: the bare ``display`` name only exists when IPython injects it.
from IPython.display import display

# Layout: global controls on top, one parameter row per species, then toggles.
row0 = HBox([species_count, time_unit, mass_unit, tmax], layout=Layout(gap="10px"))
rowA, rowB, rowC = (blk["row"] for blk in species_blocks)
opts = HBox([show_tau, show_thalf, show_3tau, show_dm], layout=Layout(gap="14px"))

ui = VBox([row0, rowA, rowB, rowC, opts])


def _sync_species_rows(_change=None):
    """Show only the parameter rows of species that render() will draw.

    Uses the same rule as render(): species A is always drawn, B is added for
    two or more species and C for three.  Hidden rows keep their values.
    """
    n_drawn = 1 + (species_count.value >= 2) + (species_count.value >= 3)
    for idx, blk in enumerate(species_blocks):
        # None restores the default CSS display; "none" hides the row.
        blk["row"].layout.display = None if idx < n_drawn else "none"


species_count.observe(_sync_species_rows, names="value")
_sync_species_rows()

# Map each render() keyword argument to the widget that supplies its value.
render_controls = {
    "species_count": species_count,
    "time_unit": time_unit,
    "mass_unit": mass_unit,
    "tmax": tmax,
    "show_tau": show_tau,
    "show_thalf": show_thalf,
    "show_3tau": show_3tau,
    "show_dm": show_dm,
}
for tag, blk in zip("ABC", species_blocks):
    for key in ("name", "S", "k", "m0"):
        render_controls[f"{tag}_{key}"] = blk[key]

out = interactive_output(render, render_controls)

display(ui, out)


VBox(children=(HBox(children=(Dropdown(description='Species:', options=(('One species', 1), ('Two species', 2)…

Output()