# One-Box Model: Mass Balance

In [3]:
# @title
# %pip install ipywidgets
#%matplotlib widget

import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import (
    VBox, HBox, Layout, Dropdown, FloatSlider, Checkbox, Text, HTML, interactive_output
)

# ---------------- Units helpers ----------------
TIME_UNITS = {"s": 1.0, "min": 60.0, "h": 3600.0, "day": 86400.0}  # seconds per unit
MASS_UNITS = {"kg": 1.0, "g": 1e-3, "mg": 1e-6, "µg": 1e-9}        # kilograms per unit


def to_base_time(val, tu):
    """Convert a duration given in time unit ``tu`` to seconds."""
    seconds_per_unit = TIME_UNITS[tu]
    return val * seconds_per_unit


def to_base_rate_per_time(k, tu):
    """Convert a rate in 1/``tu`` to 1/s."""
    seconds_per_unit = TIME_UNITS[tu]
    return k / seconds_per_unit


def to_base_S(S, mu, tu):
    """Convert a source rate in ``mu``/``tu`` to kg/s."""
    kg_per_unit, seconds_per_unit = MASS_UNITS[mu], TIME_UNITS[tu]
    return S * kg_per_unit / seconds_per_unit


def from_base_mass(m_kg, mu):
    """Convert a mass in kg to mass unit ``mu``."""
    kg_per_unit = MASS_UNITS[mu]
    return m_kg / kg_per_unit

def fmt(x, nd=3):
    """Format a scalar compactly for the HTML help text.

    Non-finite values get a symbol ("∞", "-∞" or "nan"), zero is "0",
    magnitudes >= 1e3 or < 1e-2 use scientific notation with ``nd`` digits
    after the point, and everything else uses ``nd`` significant digits.
    """
    # NaN is checked first: np.isfinite(nan) is False, so a plain finiteness
    # test would wrongly label NaN as infinite.
    if np.isnan(x):
        return "nan"
    if np.isinf(x):
        return "∞" if x > 0 else "-∞"
    if x == 0:
        return "0"
    m = abs(x)
    return f"{x:.{nd}e}" if (m >= 1e3 or m < 1e-2) else f"{x:.{nd}g}"

# ---------------- Physics ----------------
def m_solution(t_s, S_kg_s, k_s, m0_kg):
    """Analytic solution of the one-box balance dm/dt = S - k m, m(0) = m0.

    Parameters
    ----------
    t_s : array-like
        Times [s]. Lists and integer arrays are accepted.
    S_kg_s : float
        Source rate [kg/s].
    k_s : float
        First-order loss rate [1/s]. ``k_s <= 0`` is treated as "no loss".
    m0_kg : float
        Initial mass [kg].

    Returns
    -------
    (m, dm, tau, m_star)
        Mass [kg] and dm/dt [kg/s] as float arrays shaped like ``t_s``,
        lifetime tau = 1/k [s] and steady state m* = S/k [kg].
        When ``k_s <= 0`` the last two are ``(inf, nan)``.
    """
    # Force floating point: with an integer t_s, np.full_like below would
    # inherit the int dtype and truncate a fractional source rate.
    t_s = np.asarray(t_s, dtype=float)
    if k_s <= 0.0:
        # No loss: linear accumulation at a constant rate.
        m = m0_kg + S_kg_s * t_s
        dm = np.full_like(t_s, S_kg_s)
        return m, dm, np.inf, np.nan
    e = np.exp(-k_s * t_s)
    m_star = S_kg_s / k_s
    m = m0_kg * e + m_star * (1.0 - e)
    dm = S_kg_s - k_s * m
    tau = 1.0 / k_s
    return m, dm, tau, m_star

# ---------------- Global controls ----------------
# Top-row selectors: number of species plotted and the display units.
species_count = Dropdown(
    options=[("One species", 1), ("Two species", 2), ("Three species", 3)],
    value=1,
    description="Species:",
)
time_unit = Dropdown(options=list(TIME_UNITS), value="h", description="Time unit:")
mass_unit = Dropdown(options=list(MASS_UNITS), value="µg", description="Mass unit:")
# Simulation horizon, read in the currently selected time unit.
tmax = FloatSlider(value=48.0, min=1.0, max=240.0, step=1.0,
                   description="t_max", readout_format=".0f")

# Overlay toggles; all start switched on.
show_tau, show_thalf, show_3tau, show_dm = (
    Checkbox(value=True, description=label)
    for label in ("Show τ", "Show t½", "Show 3τ", "Show dm/dt panel")
)

# ---------------- Species parameter blocks (k OR lifetime τ) ----------------
def species_block(default_name):
    """Create the input widgets for one species and return them in a dict.

    The loss rate is entered either directly as ``k`` or as a lifetime ``τ``.
    Only the slider matching the selected rate mode is displayed.

    Returned keys: "name", "S", "mode", "k", "tau", "m0" (input widgets),
    "row" (the assembled VBox) and "help" (an HTML widget for derived values).
    """
    widgets = {
        "name": Text(value=default_name, description="Name:"),
        "S": FloatSlider(value=0.0, min=0.0, max=1e4, step=1.0, description="S"),
        "mode": Dropdown(options=[("k (1/time)", "k"), ("lifetime τ", "tau")],
                         value="k", description="Rate input:"),
        "k": FloatSlider(value=0.10, min=0.0, max=2.0, step=0.01, description="k"),
        "tau": FloatSlider(value=0.5, min=0.1, max=1e2, step=0.1, description="τ"),
        "m0": FloatSlider(value=0.0, min=0.0, max=1e4, step=1.0, description="m0"),
    }
    help_html = HTML(value="", layout=Layout(width="100%"))

    def _sync_rate_visibility(*_):
        # Exactly one of the two rate sliders is visible at a time.
        using_k = widgets["mode"].value == "k"
        widgets["k"].layout.display = "flex" if using_k else "none"
        widgets["tau"].layout.display = "none" if using_k else "flex"

    widgets["mode"].observe(_sync_rate_visibility, names="value")
    _sync_rate_visibility()

    inputs = HBox([widgets[key] for key in ("name", "S", "mode", "k", "tau", "m0")],
                  layout=Layout(gap="10px"))
    widgets["row"] = VBox([inputs, help_html])
    widgets["help"] = help_html
    return widgets

# One parameter block per species; render() draws B and C only when selected.
A, B, C = (species_block(f"Species {letter}") for letter in "ABC")

def get_k_in_user_units(blk):
    """Return the loss rate k in 1/(time_unit) for a species block.

    In "k" mode the k slider value is returned unchanged. Otherwise k = 1/τ,
    so τ = ∞ gives 0.0, and a non-positive τ returns ``np.inf``.

    NOTE(review): ``m_solution`` has no special case for an infinite k
    (exp(-inf * 0) is NaN). This helper is also not referenced by the
    visible code.
    """
    if blk["mode"].value != "k":
        lifetime = blk["tau"].value
        return np.inf if lifetime <= 0 else 1.0 / lifetime
    return blk["k"].value

# helper for compact text labels
def _label(ax, x, y, txt, color):
    ax.annotate(txt, xy=(x, y), xytext=(4, 4), textcoords="offset points",
                color=color, fontsize=12, fontweight="bold",
                ha="left", va="bottom",
                bbox=dict(boxstyle="round,pad=0.15", fc="none", ec="none"))


# ---------------- Dynamic labels + helper text ----------------
def update_labels_and_help(*, _time_unit=time_unit, _mass_unit=mass_unit,
                           _tmax=tmax, _blocks=(A, B, C)):
    """Refresh unit labels and each species' derived-value help line.

    Rewrites the descriptions of the S, k, τ and m0 sliders of every species
    block, and of the t_max slider, for the selected units. It then fills
    each block's "help" HTML with S, the rate input and m0. When k > 0 it
    also adds m* = S/k, τ = 1/k and t½ = ln2/k.

    The widgets are bound as keyword-only defaults when this function is
    defined. The Urban Column cell later rebinds the global names
    ``time_unit`` and ``mass_unit`` to its own dropdowns. A call-time global
    lookup would then label this model's sliders with the other model's
    units. Existing callers pass no arguments, so behaviour for them is
    unchanged.
    """
    tu = _time_unit.value
    mu = _mass_unit.value
    for blk in _blocks:
        blk["S"].description  = f"S [{mu}/{tu}]"
        blk["k"].description  = f"k [1/{tu}]"
        blk["tau"].description= f"τ [{tu}]"
        blk["m0"].description = f"m0 [{mu}]"
    _tmax.description = f"t_max [{tu}]"

    def make_help(Sv, mode, kv, tauv, m0v):
        # Rate in 1/tu from the selected input mode (τ = ∞ means no loss),
        # then convert everything to SI for m_solution.
        k_user = kv if mode=="k" else (0.0 if tauv==np.inf else (1.0/tauv))
        kps   = to_base_rate_per_time(k_user, tu)
        Skgps = to_base_S(Sv, mu, tu)
        m0kg  = m0v * MASS_UNITS[mu]
        _, _, tau, mstar = m_solution(np.array([0.0]), Skgps, kps, m0kg)
        # tau is finite only when k > 0, i.e. when a steady state exists.
        mstar_out = from_base_mass(mstar, mu) if np.isfinite(tau) else np.nan
        thalf = (np.log(2)/k_user) if (k_user > 0) else np.inf

        parts = [
            f"<b>S</b> = {fmt(Sv)} {mu}/{tu}",
            f"<b>mode</b> = {mode}",
            f"<b>k</b> = {fmt(k_user)} 1/{tu}" if mode=="k" else f"<b>τ</b> = {fmt(tauv)} {tu}",
            f"<b>m₀</b> = {fmt(m0v)} {mu}",
        ]
        if np.isfinite(tau):
            parts += [
                f"<b>m*</b> = S/k = {fmt(mstar_out)} {mu}",
                f"<b>τ</b> = 1/k = {fmt(1.0/k_user)} {tu}",
                f"<b>t½</b> = ln2/k = {fmt(thalf)} {tu}",
            ]
        else:
            parts += ["<b>m*</b>: undefined (k=0)", "<b>τ</b>: ∞ (k=0)", "<b>t½</b>: ∞ (k=0)"]
        return " | ".join(parts)

    for blk in _blocks:
        blk["help"].value = make_help(
            blk["S"].value, blk["mode"].value, blk["k"].value, blk["tau"].value, blk["m0"].value
        )

def wire_help_observers():
    """Re-run update_labels_and_help whenever a unit or species input changes."""
    watched = [time_unit, mass_unit]
    for blk in (A, B, C):
        watched.extend(blk[key] for key in ("S", "mode", "k", "tau", "m0"))

    def _on_change(_change):
        update_labels_and_help()

    for widget in watched:
        widget.observe(_on_change, names="value")


wire_help_observers()
update_labels_and_help()

# ---------------- Rendering ----------------
def render(species_count, time_unit, mass_unit, tmax,
           A_name, A_S, A_mode, A_k, A_tau, A_m0,
           B_name, B_S, B_mode, B_k, B_tau, B_m0,
           C_name, C_S, C_mode, C_k, C_tau, C_m0,
           show_tau, show_thalf, show_3tau, show_dm):
    """Plot m(t), and optionally dm/dt, for the selected number of species.

    Called by ``interactive_output`` with the current value of each mapped
    widget; parameter names match the keys of that mapping. The parameters
    shadow the global widgets of the same name, so ``time_unit`` and
    ``mass_unit`` here are plain unit strings. Species B and C are drawn only
    when ``species_count`` is at least 2 and 3 respectively. Inputs are
    converted to SI (s, kg, 1/s), solved with ``m_solution`` and converted
    back to the display units for plotting.
    """

    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

    # 900-point time grid from 0 to t_max: seconds for the solver, display
    # units for the x-axis.
    t_end_s = to_base_time(tmax, time_unit)
    t_s = np.linspace(0.0, t_end_s, 900)
    t_out = t_s / TIME_UNITS[time_unit]

    # Either a single mass panel, or mass on top and dm/dt below on a shared x-axis.
    fig = plt.figure(constrained_layout=True, figsize=(9, 6 if show_dm else 4.2))
    if show_dm:
        gs = fig.add_gridspec(2, 1, height_ratios=[2.0, 1.2])
        ax_m  = fig.add_subplot(gs[0,0]); ax_dm = fig.add_subplot(gs[1,0], sharex=ax_m)
    else:
        ax_m = fig.add_subplot(1,1,1); ax_dm = None

    def plot_species(idx, name, S_val, mode, k_val, tau_val, m0_val):
        """Draw one species' m(t), its optional markers, and its dm/dt curve."""
        col = colors[idx % len(colors)]
        # derive k in user units then to base 1/s
        # (τ = ∞ maps to k = 0, which m_solution treats as the no-loss case)
        k_user = k_val if mode=="k" else (0.0 if tau_val==np.inf else (1.0/tau_val))
        Skgs = to_base_S(S_val, mass_unit, time_unit)
        kps  = to_base_rate_per_time(k_user, time_unit)
        m0kg = m0_val * MASS_UNITS[mass_unit]

        mkg, dmg, tau, mstar = m_solution(t_s, Skgs, kps, m0kg)
        m_out = from_base_mass(mkg, mass_unit)

        ax_m.plot(t_out, m_out, lw=2, label=name, color=col)
        # Zero reference line; redrawn once per species, which is harmless.
        ax_m.axhline(0, ls="--", lw=0.8, color='k', alpha=1)

        # m_solution returns a finite tau only when k > 0, so m* exists here.
        if np.isfinite(tau):
            mstar_out = from_base_mass(mstar, mass_unit)
            # Dashed steady-state level m* = S/k.
            ax_m.axhline(mstar_out, ls="--", lw=1.2, color=col, alpha=0.6)
            # tau is in seconds; compare with the end of the time grid.
            if show_tau and tau <= t_s[-1]:
                t_tau = tau / TIME_UNITS[time_unit]
                m_tau = from_base_mass(m_solution(np.array([tau]), Skgs, kps, m0kg)[0][0], mass_unit)
                ax_m.axvline(t_tau, ls=":", lw=1.0, color=col, alpha=0.8)
                #ax_m.plot([t_tau], [m_tau], marker="o", color=col)
                _label(ax_m, t_tau, m_tau, "τ", col)

            # kps > 0 always holds inside this branch; the guard is redundant.
            th_base = (np.log(2)/kps) if kps>0 else np.inf
            if show_thalf and np.isfinite(th_base) and th_base <= t_s[-1]:
                th = th_base / TIME_UNITS[time_unit]
                # Level halfway between m0 and m* (both in display mass units),
                # i.e. m(t½) = m* + (m0 - m*)/2.
                target = mstar_out + 0.5*(m0_val - mstar_out)
                ax_m.axvline(th, ls=":", lw=1.0, color=col, alpha=0.8)
                ax_m.axhline(target, ls=":", lw=1.0, color=col, alpha=0.8)
                #ax_m.plot([th], [target], marker="s", color=col)
                _label(ax_m, th, target, "1/2", col)

            # 3τ gets only a vertical line, with no text label.
            if show_3tau and 3*tau <= t_s[-1]:
                ax_m.axvline((3*tau)/TIME_UNITS[time_unit], ls=":", lw=1.0, color=col, alpha=0.8)

        if ax_dm is not None:
            # kg/s -> display mass per second, then multiplied by seconds per
            # display time unit -> display mass per display time unit.
            dmg_out = from_base_mass(dmg, mass_unit) / (1.0 / TIME_UNITS[time_unit])
            ax_dm.plot(t_out, dmg_out, lw=1.7, label=name, color=col)
            ax_dm.axhline(0, ls="--", lw=0.8, color='k', alpha=1)

    plot_species(0, A_name, A_S, A_mode, A_k, A_tau, A_m0)
    if species_count >= 2:
        plot_species(1, B_name, B_S, B_mode, B_k, B_tau, B_m0)
    if species_count >= 3:
        plot_species(2, C_name, C_S, C_mode, C_k, C_tau, C_m0)

    ax_m.grid(alpha=0.3); ax_m.set_xlabel(f"Time [{time_unit}]"); ax_m.set_ylabel(f"Mass m(t) [{mass_unit}]"); ax_m.legend()
    if ax_dm is not None:
        ax_dm.grid(alpha=0.3); ax_dm.set_xlabel(f"Time [{time_unit}]"); ax_dm.set_ylabel(f"dm/dt [{mass_unit}/{time_unit}]"); ax_dm.legend()
    # Explicit show so the figure is rendered inside the interactive_output area.
    plt.show()

# ---------------- Layout ----------------
row0 = HBox([species_count, time_unit, mass_unit, tmax], layout=Layout(gap="10px"))
rowA = A["row"]; rowB = B["row"]; rowC = C["row"]
opts = HBox([show_tau, show_thalf, show_3tau, show_dm], layout=Layout(gap="14px"))
ui = VBox([row0, rowA, rowB, rowC, opts])


def _species_rows_visibility(*_):
    """Show the B/C parameter rows only when that many species are plotted.

    render() ignores species B and C below the selected species count.
    Hiding their rows avoids offering inputs that have no visible effect.
    """
    n = species_count.value
    rowB.layout.display = "flex" if n >= 2 else "none"
    rowC.layout.display = "flex" if n >= 3 else "none"


species_count.observe(_species_rows_visibility, names="value")
_species_rows_visibility()

out = interactive_output(
    render,
    {
        "species_count": species_count, "time_unit": time_unit, "mass_unit": mass_unit, "tmax": tmax,
        "A_name": A["name"], "A_S": A["S"], "A_mode": A["mode"], "A_k": A["k"], "A_tau": A["tau"], "A_m0": A["m0"],
        "B_name": B["name"], "B_S": B["S"], "B_mode": B["mode"], "B_k": B["k"], "B_tau": B["tau"], "B_m0": B["m0"],
        "C_name": C["name"], "C_S": C["S"], "C_mode": C["mode"], "C_k": C["k"], "C_tau": C["tau"], "C_m0": C["m0"],
        "show_tau": show_tau, "show_thalf": show_thalf, "show_3tau": show_3tau, "show_dm": show_dm,
    }
)
display(ui, out)


VBox(children=(HBox(children=(Dropdown(description='Species:', options=(('One species', 1), ('Two species', 2)…

Output()

# Urban Column Model (1D)

In [4]:
# @title

# %pip install ipywidgets
#%matplotlib widget

import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import (
    VBox, HBox, Layout, Dropdown, FloatSlider, Checkbox, IntSlider, HTML, interactive_output, BoundedFloatText
)

# ---------- Units ----------
TIME_UNITS = {"s": 1.0, "min": 60.0, "h": 3600.0, "day": 86400.0}  # seconds per unit
LENGTH_UNITS = {"m": 1.0, "km": 1000.0}                            # metres per unit
MASS_UNITS = {"kg": 1.0, "g": 1e-3, "mg": 1e-6, "µg": 1e-9}        # kilograms per unit


def to_base_time(v, tu):
    """Duration in ``tu`` -> seconds."""
    return v * TIME_UNITS[tu]


def to_base_len(v, lu):
    """Length in ``lu`` -> metres."""
    return v * LENGTH_UNITS[lu]


def to_base_speed(v, lu, tu):
    """Speed in ``lu``/``tu`` -> m/s."""
    metres, seconds = LENGTH_UNITS[lu], TIME_UNITS[tu]
    return v * metres / seconds


def to_base_k(v, tu):
    """Rate in 1/``tu`` -> 1/s."""
    return v / TIME_UNITS[tu]


def to_base_E(v, mu, lu, tu):
    """Emission flux in ``mu``/``lu``²/``tu`` -> kg m^-2 s^-1."""
    area_m2 = LENGTH_UNITS[lu] ** 2
    return v * MASS_UNITS[mu] / area_m2 / TIME_UNITS[tu]


def from_base_conc(x, mu, lu):
    """Concentration in kg/m³ -> ``mu``/``lu``³."""
    volume_m3 = LENGTH_UNITS[lu] ** 3
    return x / MASS_UNITS[mu] * volume_m3

# ---------- Equations ----------
def X_inside_constE(x_m, x0_m, X0, Ei_kg_m2_s, h_m, k_s, U_m_s):
    """Solve U dX/dx = E/h - k X on one segment with constant E.

    The segment starts at ``x0_m`` with X(x0) = X0. For k > 0 and U > 0 the
    result is X0 a + E/(h k) (1 - a) with a = exp(-k (x - x0) / U). Otherwise
    it falls back to linear accumulation X0 + E/(h U) (x - x0), with U
    floored at 1e-12 to avoid division by zero.
    """
    if k_s <= 0 or U_m_s <= 0:
        return X0 + (Ei_kg_m2_s/h_m)/max(U_m_s,1e-12) * (x_m - x0_m)

    a = np.exp(-k_s*(x_m - x0_m)/U_m_s)
    return X0*a + (Ei_kg_m2_s/(h_m*k_s))*(1.0 - a)


def solve_piecewise_inside(x_m, L_m, breaks_m, E_list_kg_m2_s, h_m, k_s, U_m_s):
    """Integrate U dX/dx = E(x)/h - k X across the city, segment by segment.

    Parameters
    ----------
    x_m : array-like
        Full x-grid [m] (0..xmax).
    L_m : float
        City length [m].
    breaks_m : array-like
        Internal zone breakpoints [m], in any order. Values outside the open
        interval (0, L) and duplicates are ignored.
    E_list_kg_m2_s : sequence of float
        Emission flux per segment [kg m^-2 s^-1], in order from x = 0.
        Extra entries are ignored.
    h_m, k_s, U_m_s : float
        Column height [m], loss rate [1/s], wind speed [m/s].

    Returns
    -------
    (X_full, X_L)
        Concentration [kg/m³] on ``x_m`` (zero outside [0, L]) and the value
        at x = L, which seeds the downwind solution.

    Raises
    ------
    ValueError
        If fewer emission values than segments are supplied.
    """
    # Float grid so integer inputs do not truncate concentrations; breaks may
    # arrive as a list.
    x_m = np.asarray(x_m, dtype=float)
    breaks_m = np.asarray(breaks_m, dtype=float)

    mask_in = (x_m >= 0.0) & (x_m <= L_m)
    xi = x_m[mask_in]
    if xi.size == 0:
        return np.zeros_like(x_m), 0.0

    # Segment edges: 0, the interior breaks, L; sorted and de-duplicated.
    interior = breaks_m[(breaks_m > 0.0) & (breaks_m < L_m)]
    edges = np.unique(np.clip(np.concatenate(([0.0], interior, [L_m])), 0.0, L_m))
    nseg = len(edges) - 1
    if len(E_list_kg_m2_s) < nseg:
        raise ValueError(
            f"need {nseg} emission values for {nseg} segments, "
            f"got {len(E_list_kg_m2_s)}"
        )

    X_in = np.zeros_like(xi)
    X0 = 0.0  # X = 0 at the upwind edge x = 0
    for x0, xL, Ei in zip(edges[:-1], edges[1:], E_list_kg_m2_s):
        seg_mask = (xi >= x0) & (xi <= xL)
        X_in[seg_mask] = X_inside_constE(xi[seg_mask], x0, X0, Ei, h_m, k_s, U_m_s)
        # Carry the end-of-segment value into the next segment.
        X0 = X_inside_constE(np.array([xL]), x0, X0, Ei, h_m, k_s, U_m_s)[0]

    X_full = np.zeros_like(x_m)
    X_full[mask_in] = X_in
    return X_full, X0  # X0 now holds X(L)

def X_downwind(x_m, L_m, X_L, k_s, U_m_s):
    """Concentration downwind of the city (x > L), where only decay acts.

    Returns an array shaped like ``x_m`` that is zero for x <= L. Without
    decay or advection (k <= 0 or U <= 0) the value X_L is carried unchanged.
    """
    downwind = x_m > L_m
    X = np.zeros_like(x_m)
    if not (k_s > 0 and U_m_s > 0):
        X[downwind] = X_L
    else:
        distance = x_m[downwind] - L_m
        X[downwind] = X_L * np.exp(-k_s * distance / U_m_s)
    return X

# ---------- Global controls ----------
# Display-unit selectors; every physical slider below is read in these units.
len_unit, time_unit, mass_unit = (
    Dropdown(options=list(table), value=default, description=label)
    for table, default, label in (
        (LENGTH_UNITS, "km", "Length unit:"),
        (TIME_UNITS, "h", "Time unit:"),
        (MASS_UNITS, "µg", "Mass unit:"),
    )
)

# Physical parameters; refresh_labels() appends the current units to each label.
U = FloatSlider(value=5.0, min=0.1, max=40.0, step=0.1,
                description="U")          # wind speed
h = FloatSlider(value=0.5, min=0.05, max=2.0, step=0.05,
                description="h")          # column height
L = FloatSlider(value=20.0, min=0.5, max=200.0, step=0.5,
                description="L")          # city length
k = FloatSlider(value=0.10, min=0.0, max=2.0, step=0.001,
                description="k")          # first-order loss rate
xmax = FloatSlider(value=100.0, min=10.0, max=6000.0, step=1.0,
                   description="x_max")   # end of the plotted domain

# Emission flux used when zones are disabled.
E_uniform = FloatSlider(value=50.0, min=0.0, max=5e3, step=10.0, description="E (uniform)")

# Zone mode: up to 5 emission zones inside the city.
use_zones = Checkbox(value=False, description="Use emission zones")
n_zones = IntSlider(value=1, min=1, max=5, step=1, description="Em. zones #")

# Zone breakpoints as fractions of L. The widgets do NOT enforce ordering:
# render() sorts and de-duplicates them, and uses only the first n_zones - 1.
bp_widgets = [
    BoundedFloatText(value=frac, min=0.0, max=1.0, step=0.01, description=f"b{i}/L")
    for i, frac in enumerate((0.3, 0.6, 0.8, 0.9), start=1)
]
bp1, bp2, bp3, bp4 = bp_widgets

# Per-zone emission fluxes; segments counted from x = 0 take E1, E2, ... in order.
Ez_widgets = [
    FloatSlider(value=flux, min=0.0, max=5e3, step=10.0, description=f"E{i}")
    for i, flux in enumerate((80.0, 50.0, 20.0, 50.0, 10.0), start=1)
]
Ez1, Ez2, Ez3, Ez4, Ez5 = Ez_widgets

# Plot overlays; all start switched on.
show_L, show_xe, shade_city = (
    Checkbox(value=True, description=text)
    for text in ("Show x = L", "Show x = U/k", "Shade [0,L]")
)

# Receives a one-line summary of derived quantities from render().
helper = HTML(value="", layout=Layout(width="100%"))

def refresh_labels(*_):
    """Append the selected units to every physical slider's description."""
    lu, tu, mu = len_unit.value, time_unit.value, mass_unit.value
    flux_units = f"[{mu}/{lu}²/{tu}]"
    U.description = f"U [{lu}/{tu}]"
    h.description = f"h [{lu}]"
    L.description = f"L [{lu}]"
    k.description = f"k [1/{tu}]"
    xmax.description = f"x_max [{lu}]"
    E_uniform.description = f"E (uniform) {flux_units}"
    for i, Ez in enumerate(Ez_widgets, 1):
        Ez.description = f"E{i} {flux_units}"


refresh_labels()
for w in (len_unit, time_unit, mass_unit):
    w.observe(refresh_labels, names="value")

def render(len_unit, time_unit, mass_unit, U, h, L, k, xmax,
           use_zones, n_zones, b1, b2,
           b3, b4,
           E_uniform, E1, E2, E3,
           E4, E5,
           show_L, show_xe, shade_city):
    """Solve the urban column model and plot [X](x) and U d[X]/dx.

    Called by ``interactive_output`` with the current value of each mapped
    widget; parameter names match the keys of that mapping. Inputs are in the
    selected display units. The model is solved in SI (m, s, kg) and the
    results are converted back for plotting. Also updates ``helper``.
    """
    col = plt.rcParams['axes.prop_cycle'].by_key()['color'][0]

    # Convert inputs to SI.
    U_ms   = to_base_speed(U, len_unit, time_unit)
    h_m    = to_base_len(h, len_unit)
    L_m    = to_base_len(L, len_unit)
    k_s    = to_base_k(k, time_unit)
    xmax_m = to_base_len(xmax, len_unit)

    # x-grid
    x_m = np.linspace(0.0, xmax_m, 1400)
    x_out = x_m / LENGTH_UNITS[len_unit]

    # Zone geometry, computed once. The sorted, de-duplicated break fractions
    # drive the solver, the E(x) profile and the zone markers alike.
    if use_zones:
        fracs = np.unique(np.clip(np.array([b1, b2, b3, b4])[:max(n_zones - 1, 0)], 0.0, 1.0))
        E_user = [E1, E2, E3, E4, E5][:n_zones]
    else:
        fracs = np.array([])
        E_user = [E_uniform]
    breaks_m = fracs * L_m
    E_list_kg_m2_s = [to_base_E(v, mass_unit, len_unit, time_unit) for v in E_user]

    X_in, X_L = solve_piecewise_inside(x_m, L_m, breaks_m, E_list_kg_m2_s, h_m, k_s, U_ms)
    X = X_in + X_downwind(x_m, L_m, X_L, k_s, U_ms)

    # E(x) on the grid with the same segment edges the solver uses: breaks at
    # 0 or L are dropped, so E1 always starts at x = 0. This keeps the lower
    # panel consistent with the solved concentration.
    interior = breaks_m[(breaks_m > 0.0) & (breaks_m < L_m)]
    edges = np.unique(np.clip(np.concatenate(([0.0], interior, [L_m])), 0.0, L_m))
    Ex = np.zeros_like(x_m)
    for x_lo, x_hi, Ei in zip(edges[:-1], edges[1:], E_list_kg_m2_s):
        Ex[(x_m >= x_lo) & (x_m <= x_hi)] = Ei

    # U dX/dx directly from the ODE. Ex is zero downwind of L, leaving -k X there.
    UXp = Ex / h_m - k_s * X

    # Convert outputs
    X_out   = from_base_conc(X,   mass_unit, len_unit)
    UXp_out = from_base_conc(UXp, mass_unit, len_unit) * TIME_UNITS[time_unit]  # per chosen time

    # Characteristic e-fold distance
    xe_out = np.inf if (k_s<=0 or U_ms<=0) else (U_ms/k_s) / LENGTH_UNITS[len_unit]

    # Plot
    fig = plt.figure(constrained_layout=True, figsize=(9.5,4))
    gs = fig.add_gridspec(2,1, height_ratios=[2.0,1.2])
    axX  = fig.add_subplot(gs[0,0]); axUX = fig.add_subplot(gs[1,0], sharex=axX)

    axX.plot(x_out, X_out, lw=2, color=col, label="[X](x)")
    axX.axhline(0, color='k', lw=0.8, ls='--')

    if shade_city:
        axX.axvspan(0, L, color=col, alpha=0.08)
    if show_L:
        axX.axvline(L, ls="--", lw=1.2, color=col)
        axX.plot([L], [from_base_conc(X_L, mass_unit, len_unit)], marker="o", color=col)
    if show_xe and np.isfinite(xe_out) and xe_out <= xmax:
        axX.axvline(xe_out, ls=":", lw=1.0, color=col)
        axX.text(xe_out, axX.get_ylim()[1]*0.9, "x = U/k", color=col, va="top")

    # Zone breaks (empty unless zones are enabled with n_zones > 1)
    for f in fracs:
        axX.axvline(f*L, ls="-.", lw=0.8, color=col, alpha=0.6)

    axX.set_ylabel(f"[X] [{mass_unit}/{len_unit}³]")
    axX.set_title("Urban column with piecewise-constant emissions inside the city")
    axX.grid(alpha=0.3); axX.legend()

    axUX.plot(x_out, UXp_out, lw=1.7, color=col, label="U d[X]/dx = E(x)/h - k[X]")
    axUX.axhline(0, color='k', lw=0.8, ls='--')
    if shade_city:
        axUX.axvspan(0, L, color=col, alpha=0.08)
    if show_L:
        axUX.axvline(L, ls="--", lw=1.0, color=col)
    if show_xe and np.isfinite(xe_out) and xe_out <= xmax:
        axUX.axvline(xe_out, ls=":", lw=1.0, color=col)
    for f in fracs:
        axUX.axvline(f*L, ls="-.", lw=0.8, color=col, alpha=0.6)

    axUX.set_xlabel(f"x [{len_unit}]")
    axUX.set_ylabel(f"U d[X]/dx [{mass_unit}/{len_unit}³ per {time_unit}]")
    axUX.grid(alpha=0.3); axUX.legend()

    helper.value = (
        f"<b>Derived:</b> e-fold distance U/k = {xe_out:.3g} {len_unit}; "
        f"zone count = {n_zones if use_zones else 1}"
    )
    # Without an explicit show, a figure created in a widget callback may not
    # be rendered into the interactive_output area (the one-box cell does this too).
    plt.show()

# ---- widget layout ----
ui_units = HBox([len_unit, time_unit, mass_unit], layout=Layout(gap="10px"))
ui_phys1 = HBox([U, h, L], layout=Layout(gap="10px"))
ui_phys2 = HBox([k, xmax], layout=Layout(gap="10px"))

ui_zctl  = HBox([use_zones, n_zones], layout=Layout(gap="20px"))
ui_breaks= HBox(bp_widgets, layout=Layout(gap="10px"))
ui_Es    = HBox(Ez_widgets, layout=Layout(gap="10px"))


def zones_visibility(*_):
    """Show only the emission inputs that render() actually reads.

    Zones off: the uniform E slider is shown. Zones on: the first n_zones
    flux sliders and the first n_zones - 1 breakpoints are shown. The
    breakpoint row is hidden entirely when there is a single zone.
    """
    zones_on = use_zones.value
    nz = n_zones.value
    ui_breaks.layout.display = "flex" if zones_on and nz > 1 else "none"
    ui_Es.layout.display     = "flex" if zones_on else "none"
    E_uniform.layout.display = "none" if zones_on else "flex"
    for i, w in enumerate(bp_widgets):
        w.layout.display = "flex" if i < nz - 1 else "none"
    for i, w in enumerate(Ez_widgets):
        w.layout.display = "flex" if i < nz else "none"


zones_visibility()
use_zones.observe(zones_visibility, names="value")
n_zones.observe(zones_visibility, names="value")

ui = VBox([
    ui_units,
    ui_phys1, ui_phys2,
    E_uniform,
    ui_zctl, ui_breaks, ui_Es,
    HBox([show_L, show_xe, shade_city], layout=Layout(gap="14px")),
    helper
])

out = interactive_output(
    render,
    {
        "len_unit": len_unit, "time_unit": time_unit, "mass_unit": mass_unit,
        "U": U, "h": h, "L": L, "k": k, "xmax": xmax,
        "use_zones": use_zones, "n_zones": n_zones,
        "b1": bp1, "b2": bp2,
        "b3": bp3, "b4": bp4,
        "E_uniform": E_uniform,
        "E1": Ez1, "E2": Ez2, "E3": Ez3,
        "E4": Ez4, "E5": Ez5,
        "show_L": show_L, "show_xe": show_xe, "shade_city": shade_city
    }
)

display(ui, out)


VBox(children=(HBox(children=(Dropdown(description='Length unit:', index=1, options=('m', 'km'), value='km'), …

Output()