<a href="https://colab.research.google.com/github/lhchem/Tools/blob/master/Diffusion_reaction_simulation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install numpy matplotlib numba tqdm

In [None]:
#!/usr/bin/env python3
"""
Author: HM Lai
"""

# ────────── 0. imports ─────────────────────────────────────────
import json, pathlib, shutil, time
from copy import deepcopy

import numpy as np
import matplotlib.pyplot as plt
from numba import njit
from tqdm import tqdm

# ────────── 1. default parameters ─────────────────────────────
# Baseline parameter set. Every scenario in SIMS is laid over a copy of this
# dict, so only the keys that differ need to be listed there.
# Concentrations are in mM and durations in hours unless noted otherwise.
DEFAULTS = {
    # --- geometry ---------------------------------------------------------
    "R_t": 5e-3,          # tissue radius (grid length unit; plotted x1e3 as mm)
    "x_ratio": 50,        # bath outer radius is R_t * sqrt(1 + x_ratio)
    "length": 5.0,        # NOTE(review): not referenced by run_one()

    # --- solvent, per stage (viscosity in Pa s, temperature in K) ---------
    "eta_stage1": 1.0e-3,
    "T_stage1": 298.15,
    "eta_stage2": 3.0e-3,
    "T_stage2": 298.15,

    # --- bath stirring: bath-node diffusivities are multiplied by
    #     stir_factor when stirred_bath is True --------------------------
    "stirred_bath": True,
    "stir_factor": 1000.0,

    # --- molecular weights (Da, i.e. g mol^-1) -----------------------------
    # A primary carrying k secondaries diffuses as MW_Ab + k * MW_sAb.
    "MW_B": 181,
    "MW_C": 1_673,
    "MW_Ab": 150_000,
    "MW_sAb": 50_000,

    # --- hindrance: multiplier applied to diffusivity inside the tissue ---
    "h_B": 0.9,
    "h_C": 0.8,           # also used for the BC complex
    "h_Ab": 0.5,
    "h_sAb": 0.75,

    # --- initial concentrations (mM) --------------------------------------
    "B0_tissue": 250.0,
    "B0_bath": 250.0,
    "Ab0_bath": 0.0005,   # primary reagent, bath only
    "sAb0_bath": 0.0015,  # free secondary reagent, bath only

    # --- immobilised affinity target, uniform over the tissue (mM) ---------
    "Ag_in_mM": 0.005,

    # --- bath composition installed at the start of stage 2 (mM) -----------
    "C0_bath": 250.0,
    "B_bath_stage2": 0.0,
    "Ab_bath_stage2": 0.0,  # split evenly over all secondary-occupancy states

    # --- primary (Ab-Ag) binding kinetics -----------------------------------
    "k_on0_M": 1e8,       # M^-1 s^-1 (converted to mM^-1 s^-1 in run_one)
    "k_off0": 1e-7,       # s^-1

    # --- secondary (sAb-Ab) binding kinetics ---------------------------------
    "k_on0_s_M": 1e6,     # M^-1 s^-1
    "k_off0_s": 1e-5,     # s^-1

    # --- modulation by B, shared by both binding reactions:
    #     kon  = kon0 * exp(-tau_on * B)
    #     koff = koff0 + k_amp * (1 - exp(-tau_off * B)) --------------------
    "k_amp": 1.0,         # s^-1, maximum extra dissociation rate
    "tau_on": 100.0,      # mM^-1
    "tau_off": 2.0,       # mM^-1

    # --- B + C ⇌ BC host-guest "switch" reaction ----------------------------
    "k_on_BC_M": 30000,   # M^-1 s^-1
    "k_off_BC": 1e-2,     # s^-1

    # maximum number of secondary reagents bound per primary
    "n_sAb": 3,

    # --- simulated durations (h) --------------------------------------------
    "t_stage1_h": 72,
    "t_stage2_h": 48,

    # --- numerics (changing these can affect stability) ----------------------
    "Nr": 401,            # grid nodes
    "save_each": 200,     # time steps between stored snapshots
    "tag": None,          # output folder name; None -> auto "sim-NNN"
}

# ────────── 2. scenarios ──────────────────────────────────────
# Each scenario lists only the keys that differ from DEFAULTS; main() lays
# it over a fresh copy of DEFAULTS. To add one, append an entry such as
#     {"tag": "my_scenario", "t_stage2_h": 24},
SIMS = [
    {"tag": "demo"},
    # No B or C at any time (so no B-driven modulation of binding),
    # stage-2 viscosity equal to stage 1, and a 120 h first stage.
    {
        "tag": "no_reaction_suppression",
        "eta_stage2": 1.0e-3,
        "B0_tissue": 0,
        "B0_bath": 0,
        "C0_bath": 0,
        "t_stage1_h": 120,
    },
]

# ────────── 3. utilities ──────────────────────────────────────
k_B = 1.380649e-23  # Boltzmann constant, J K⁻¹


def MW_to_D(MW, T, eta):
    """Estimate a translational diffusion coefficient from molecular weight.

    The hydrodynamic radius is approximated as ``0.066 * MW**(1/3)`` nm and
    inserted into the Stokes–Einstein relation ``D = k_B T / (6 π η R_h)``.

    Parameters
    ----------
    MW : float or ndarray
        Molecular weight (Da).
    T : float
        Absolute temperature (K).
    eta : float
        Solvent viscosity (SI units, Pa s, to give D in m² s⁻¹).

    Returns
    -------
    float or ndarray
        Diffusion coefficient.
    """
    radius_nm = 0.066 * MW ** (1/3)
    # Stokes friction coefficient 6πηR, with the radius converted nm -> m.
    friction = 6.0 * np.pi * eta * radius_nm * 1e-9
    return k_B * T / friction


def make_diffuser(D_arr, dr, dt):
    """Return a compiled explicit diffusion step for a fixed diffusivity profile.

    The returned ``step(C)`` advances ``C`` by one forward-Euler step of
    ``dC/dt = d/dr (D dC/dr)`` on a uniform grid with spacing ``dr``.
    Face diffusivities are the arithmetic mean of the two neighbouring
    nodes. Zero-gradient (no-flux) ends are imposed by copying the first and
    last interior values onto the end nodes.

    NOTE(review): this is the Cartesian 1-D operator (no 1/r cylindrical
    term), while run_one sizes the bath as R_t * sqrt(1 + x_ratio), which
    suggests cylindrical geometry. Confirm which geometry is intended.

    Parameters
    ----------
    D_arr : ndarray
        Diffusivity at each grid node (same length as the arrays passed to
        ``step``).
    dr : float
        Grid spacing.
    dt : float
        Time step. For this explicit scheme it should satisfy
        ``dt <= dr**2 / (2 * max(D_arr))``.

    Returns
    -------
    callable
        Numba-jitted ``step(C) -> ndarray`` that returns a new array and
        leaves ``C`` untouched.
    """
    # The face diffusivities depend only on D_arr, so compute them once here
    # instead of on every call to step() (called every time step).
    D_arr = np.asarray(D_arr, dtype=np.float64)
    Df = 0.5 * (D_arr + np.roll(D_arr, -1))   # face between node i and i+1
    Db = 0.5 * (D_arr + np.roll(D_arr, 1))    # face between node i-1 and i

    @njit(cache=True, fastmath=True)
    def step(C):
        dF = (np.roll(C, -1) - C) / dr
        dB = (C - np.roll(C, 1)) / dr
        Cn = C + dt * (Df * dF - Db * dB) / dr
        # np.roll wraps periodically, which only corrupts the two end nodes;
        # they are overwritten here by the no-flux condition.
        Cn[0], Cn[-1] = Cn[1], Cn[-2]
        return Cn

    return step


def make_reactor(n_sAb):
    """Return a compiled function that applies one reaction sub-step in place.

    The returned ``react(...)`` visits every node index in ``idx`` and
    updates, by explicit Euler with step ``dt``:

    * R1: Ag + Ab_k ⇌ AgAb_k for each secondary-occupancy state
      k = 0..n_sAb (rates ``kon_ag[i]``, ``koff_ag[i]``);
    * R2: Ab_k + sAb ⇌ Ab_{k+1}, and R3: AgAb_k + sAb ⇌ AgAb_{k+1},
      for k = 0..n_sAb-1 (rates ``kon_s[i]``, ``koff_s[i]``);
    * R4: B + C ⇌ BC (scalar rates ``kon_BC``, ``koff_BC``).

    Each reaction extent is limited to what its reactants (forward) or its
    product (reverse) can supply. Without this, a step with kon*conc*dt > 1
    would drive a reactant negative. The later clamp would then zero that
    reactant while its partner and product kept the full, oversized change,
    which creates mass.

    Parameters
    ----------
    n_sAb : int
        Maximum number of secondary reagents per primary. ``Ab_st`` and
        ``AgAb_st`` must have ``n_sAb + 1`` rows (state) indexed by node
        in the second axis.

    Returns
    -------
    callable
        Numba-jitted ``react(idx, B, C, BC, Ag, sAb, Ab_st, AgAb_st,
        kon_ag, koff_ag, kon_s, koff_s, kon_BC, koff_BC, dt)``. It mutates
        its concentration arrays and returns None.
    """
    @njit(cache=True, fastmath=True)
    def react(idx, B, C, BC, Ag, sAb,
              Ab_st, AgAb_st,
              kon_ag, koff_ag,
              kon_s,  koff_s,
              kon_BC, koff_BC, dt):

        nmax = n_sAb

        for i in idx:
            # R1: Ag + Ab_k ⇌ AgAb_k
            for k in range(nmax + 1):
                d = dt * (kon_ag[i] * Ab_st[k, i] * Ag[i]
                          - koff_ag[i] * AgAb_st[k, i])
                if d > 0.0:
                    d = min(d, Ab_st[k, i], Ag[i])      # cannot bind more than exists
                else:
                    d = max(d, -AgAb_st[k, i])          # cannot release more than bound
                Ab_st[k, i]   -= d
                AgAb_st[k, i] += d
                Ag[i]         -= d

            # R2/R3: secondary attachment to free (R2) and bound (R3) primaries
            for k in range(nmax):
                d1 = dt * (kon_s[i] * Ab_st[k, i] * sAb[i]
                           - koff_s[i] * Ab_st[k + 1, i])
                if d1 > 0.0:
                    d1 = min(d1, Ab_st[k, i], sAb[i])
                else:
                    d1 = max(d1, -Ab_st[k + 1, i])
                Ab_st[k, i]     -= d1
                Ab_st[k + 1, i] += d1
                sAb[i]          -= d1

                d2 = dt * (kon_s[i] * AgAb_st[k, i] * sAb[i]
                           - koff_s[i] * AgAb_st[k + 1, i])
                if d2 > 0.0:
                    d2 = min(d2, AgAb_st[k, i], sAb[i])
                else:
                    d2 = max(d2, -AgAb_st[k + 1, i])
                AgAb_st[k, i]     -= d2
                AgAb_st[k + 1, i] += d2
                sAb[i]            -= d2

            # R4: host-guest reaction B + C ⇌ BC
            dBC = dt * (kon_BC * B[i] * C[i] - koff_BC * BC[i])
            if dBC > 0.0:
                dBC = min(dBC, B[i], C[i])
            else:
                dBC = max(dBC, -BC[i])
            B[i]  -= dBC
            C[i]  -= dBC
            BC[i] += dBC

            # positivity clamp: safety net only (e.g. rounding or negative
            # inputs); the limiters above already keep valid inputs >= 0
            if Ag[i]  < 0: Ag[i]  = 0.
            if sAb[i] < 0: sAb[i] = 0.
            if B[i]   < 0: B[i]   = 0.
            if C[i]   < 0: C[i]   = 0.
            if BC[i]  < 0: BC[i]  = 0.
            for k in range(nmax + 1):
                if Ab_st[k, i]   < 0: Ab_st[k, i]   = 0.
                if AgAb_st[k, i] < 0: AgAb_st[k, i] = 0.

    return react


# ────────── 4. simulation core ────────────────────────────────
def _plot_profiles(frames, R_t, R_b, tag, png_path):
    """Save space-time heat maps of every stored field to *png_path*.

    Each image has grid position on the x axis (r, plotted x1e3 as mm) and
    snapshot time on the y axis (h, increasing downward). A dashed white
    line marks the tissue edge ``R_t``.
    NOTE(review): the y extent assumes snapshots are evenly spaced in time.
    The final end-of-run snapshot may break that spacing slightly.

    Parameters
    ----------
    frames : dict of ndarray
        Stacked snapshots as built by ``run_one`` (already ``np.asarray``-ed).
    R_t, R_b : float
        Tissue radius and outer bath radius, in grid units.
    tag : str
        Figure title.
    png_path : pathlib.Path
        Destination image file.
    """
    t_arr = frames["time"] / 3600.0
    # Collapse the n_sAb + 1 secondary-occupancy states into one total per node.
    Ab_tot   = frames["Ab"].sum(axis=1)
    AgAb_tot = frames["AgAb"].sum(axis=1)

    to_plot = [("B",          frames["B"],         "viridis"),
               ("C",          frames["C"],         "cividis"),
               ("BC",         frames["BC"],        "YlOrRd"),
               ("Ab",         Ab_tot,              "plasma"),
               ("Ag",         frames["Ag"],        "YlGn"),
               ("AgAb",       AgAb_tot,            "magma"),
               ("sAb-free",   frames["sAb_free"],  "Blues"),
               ("sAb-bound",  frames["sAb_bound"], "PuRd")]

    ncol = 4
    nrow = int(np.ceil(len(to_plot) / ncol))
    # squeeze=False guarantees a 2-D array of Axes, so ravel() always works.
    fig, axs = plt.subplots(nrow, ncol,
                            figsize=(4*ncol, 3.5*nrow),
                            sharex=True, sharey=True, squeeze=False)
    axs = axs.ravel()
    for ax, (lbl, data, cmap) in zip(axs, to_plot):
        im = ax.imshow(data,
                       extent=(0, R_b*1e3, t_arr[-1], 0),
                       aspect="auto", cmap=cmap)
        ax.axvline(R_t*1e3, ls="--", color="w")
        ax.set_title(lbl + " (mM)")
        ax.set_xlabel("r (mm)")
        ax.set_ylabel("time (h)")
        fig.colorbar(im, ax=ax, shrink=0.8)
    for ax in axs[len(to_plot):]:
        ax.axis("off")
    fig.suptitle(tag)
    fig.tight_layout()
    fig.savefig(png_path, dpi=150)
    plt.close(fig)


def run_one(P, out_dir: pathlib.Path):
    """Simulate one two-stage diffusion-reaction scenario and save the results.

    A uniform 1-D grid runs from r = 0 to the bath edge R_b. Nodes with
    r <= R_t are tissue and the rest are bath. Stage 1 lasts ``t_stage1_h``
    hours. When it ends, the bath nodes are reset to the stage-2 composition
    and diffusivities switch to the stage-2 temperature and viscosity. The
    run then continues for ``t_stage2_h`` hours. Reactions are applied on
    tissue nodes only.

    Writes three files to *out_dir*:

    * ``frames.npz``: snapshots at t = 0, every ``save_each`` steps and at
      the end, plus the grid ``r``;
    * ``meta.json``: the parameter dict ``P``;
    * ``profiles.png``: space-time heat maps.

    Parameters
    ----------
    P : dict
        Complete parameter set (DEFAULTS overlaid with one scenario).
        ``P["tag"]`` labels the progress bar and the figure.
    out_dir : pathlib.Path
        Output directory; created (with parents) if missing.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    nSto = P["n_sAb"]
    reactor = make_reactor(nSto)

    # timeline (s)
    t1    = P["t_stage1_h"] * 3600.0
    t_end = t1 + P["t_stage2_h"] * 3600.0

    # grid: R_b is chosen so that R_b**2 - R_t**2 == x_ratio * R_t**2
    R_t = P["R_t"]
    R_b = R_t * np.sqrt(1.0 + P["x_ratio"])
    r   = np.linspace(0.0, R_b, P["Nr"])
    dr  = r[1] - r[0]

    m_t = r <= R_t            # tissue nodes
    m_b = ~m_t                # bath nodes
    idx_t = np.where(m_t)[0]
    bath_mult = P["stir_factor"] if P["stirred_bath"] else 1.0

    # diffusion ------------------------------------------------
    def build_D(T, eta):
        """Per-node diffusivities (B, C, BC, sAb, Ab_k) for one stage."""
        DB = MW_to_D(P["MW_B"], T, eta)
        DC = MW_to_D(P["MW_C"], T, eta)
        Ds = MW_to_D(P["MW_sAb"], T, eta)

        # tissue: scaled by the hindrance factor; bath: scaled by stirring
        D_B  = np.where(m_t, P["h_B"]  * DB, DB * bath_mult)
        D_C  = np.where(m_t, P["h_C"]  * DC, DC * bath_mult)
        D_BC = D_C                 # BC is given the same diffusivity as C
        D_s  = np.where(m_t, P["h_sAb"] * Ds, Ds * bath_mult)

        # a primary carrying k secondaries diffuses as MW_Ab + k * MW_sAb
        D_Ab = np.empty((nSto + 1, r.size))
        for k in range(nSto + 1):
            Dk = MW_to_D(P["MW_Ab"] + k * P["MW_sAb"], T, eta)
            D_Ab[k] = np.where(m_t, P["h_Ab"] * Dk, Dk * bath_mult)
        return D_B, D_C, D_BC, D_s, D_Ab

    def build_diffusers(D_set):
        """Compiled diffusion steps for a tuple returned by build_D."""
        D_B, D_C, D_BC, D_s, D_Ab = D_set
        return (make_diffuser(D_B,  dr, dt),
                make_diffuser(D_C,  dr, dt),
                make_diffuser(D_BC, dr, dt),
                make_diffuser(D_s,  dr, dt),
                [make_diffuser(D_Ab[k], dr, dt) for k in range(nSto + 1)])

    D_stage1 = build_D(P["T_stage1"], P["eta_stage1"])
    D_stage2 = build_D(P["T_stage2"], P["eta_stage2"])

    # A single explicit time step is used for the whole run, so it must meet
    # the stability limit of BOTH stages. Stage 2 diffuses faster if it is
    # warmer or less viscous than stage 1.
    D_max = max(float(D.max()) for stage in (D_stage1, D_stage2)
                for D in stage)
    dt = 0.4 * dr**2 / (2.0 * D_max)
    Nt = int(np.ceil(t_end / dt))

    diff_B, diff_C, diff_BC, diff_s, diff_Ab = build_diffusers(D_stage1)

    # ───── initial fields ─────────────────────────────────────
    # float() keeps every field float64 even when a scenario passes ints
    # (e.g. B0_tissue=0), so the compiled steps see a single dtype.
    B   = np.where(m_t, float(P["B0_tissue"]), float(P["B0_bath"]))
    C   = np.zeros_like(r)
    BC  = np.zeros_like(r)
    sAb = np.where(m_t, 0.0, float(P["sAb0_bath"]))

    # uniform immobilised antigen (tissue only) ------------------
    Ag = np.where(m_t, float(P["Ag_in_mM"]), 0.0)

    # primaries, indexed [number of bound secondaries, node]
    Ab_st   = np.zeros((nSto + 1, r.size))
    AgAb_st = np.zeros_like(Ab_st)
    Ab_st[0, m_b] = P["Ab0_bath"]

    # storage --------------------------------------------------
    frames = dict(time=[], B=[], C=[], BC=[],
                  sAb_free=[], sAb_bound=[],
                  Ag=[], Ab=[], AgAb=[])

    def snap(t_now):
        """Append copies of the current fields to ``frames``."""
        # secondaries bound = sum over states of k * (free + antigen-bound)
        kvec = np.arange(1, nSto + 1).reshape(nSto, 1)
        sAb_bound = (kvec * Ab_st[1:]).sum(axis=0) + \
                    (kvec * AgAb_st[1:]).sum(axis=0)
        frames["time"].append(t_now)
        frames["B"].append(B.copy())
        frames["C"].append(C.copy())
        frames["BC"].append(BC.copy())
        frames["sAb_free"].append(sAb.copy())
        frames["sAb_bound"].append(sAb_bound)
        frames["Ag"].append(Ag.copy())
        frames["Ab"].append(Ab_st.copy())
        frames["AgAb"].append(AgAb_st.copy())

    snap(0.0)

    # kinetics: convert association constants from M^-1 s^-1 to mM^-1 s^-1
    kon0_ag = P["k_on0_M"]   / 1e3
    kon0_s  = P["k_on0_s_M"] / 1e3
    kon_BC  = P["k_on_BC_M"] / 1e3

    koff0_ag = P["k_off0"]
    koff0_s  = P["k_off0_s"]
    koff_BC  = P["k_off_BC"]

    switched = False
    nan_flag = False

    # ───── time loop ─────────────────────────────────────────
    for n in tqdm(range(1, Nt + 1), desc=P["tag"], unit="step",
                  miniters=max(Nt // 100, 1)):
        t = n * dt

        # stage change: new diffusivities and a fresh bath; tissue untouched
        if (not switched) and (t >= t1):
            diff_B, diff_C, diff_BC, diff_s, diff_Ab = build_diffusers(D_stage2)

            B[m_b]   = P["B_bath_stage2"]
            C[m_b]   = P["C0_bath"]
            sAb[m_b] = 0.0
            Ab_st[:, m_b] = P["Ab_bath_stage2"] / (nSto + 1)
            switched = True

        # diffusion -------------------------------------------
        Bn   = diff_B(B)
        Cn   = diff_C(C)
        BCn  = diff_BC(BC)
        sAbn = diff_s(sAb)
        Agn  = Ag.copy()              # immobile
        AgAbn = AgAb_st.copy()        # antigen-bound complexes are immobile
        Abn  = np.empty_like(Ab_st)
        for k in range(nSto + 1):
            Abn[k] = diff_Ab[k](Ab_st[k])

        # quick NaN diagnostics (reported once per run) --------
        checked = (("B", Bn), ("C", Cn), ("BC", BCn), ("sAb", sAbn),
                   ("Ab", Abn))
        if (not nan_flag) and any(np.isnan(a).any() for _, a in checked):
            for name, arr in checked:
                bad = np.isnan(arr)
                if bad.any():
                    # the last axis is always the grid index (Ab is 2-D)
                    i_bad = np.argwhere(bad)[0][-1]
                    print(f"⚠ NaN in {name} at step {n}, "
                          f"t={t/3600:.2f} h, r={r[i_bad]*1e3:.3f} mm")
                    break
            nan_flag = True

        # reaction rates: B suppresses association, enhances dissociation
        mod_on  = np.exp(-P["tau_on"]  * Bn)
        mod_off = 1.0 - np.exp(-P["tau_off"] * Bn)

        kon_ag  = kon0_ag * mod_on
        kon_s   = kon0_s  * mod_on

        koff_ag = koff0_ag + P["k_amp"] * mod_off
        koff_s  = koff0_s  + P["k_amp"] * mod_off

        # sub-cycling Δt --------------------------------------
        # NOTE: the sub-step is sized from dissociation rates only; the
        # reactor bounds each extent so fast association cannot overshoot.
        max_koff = max(np.max(koff_ag[idx_t]), np.max(koff_s[idx_t]), koff_BC)
        dt_r = min(0.1 / max_koff, dt)
        n_sub = max(1, int(np.ceil(dt / dt_r)))
        dt_r  = dt / n_sub

        # reactions (tissue nodes only, in place) -------------
        for _ in range(n_sub):
            reactor(idx_t, Bn, Cn, BCn, Agn, sAbn,
                    Abn, AgAbn,
                    kon_ag, koff_ag,
                    kon_s,  koff_s,
                    kon_BC, koff_BC, dt_r)

        # commit ----------------------------------------------
        B, C, BC, sAb = Bn, Cn, BCn, sAbn
        Ab_st, AgAb_st, Ag = Abn, AgAbn, Agn

        if n % P["save_each"] == 0:
            snap(t)

    snap(t_end)

    # ───── save ----------------------------------------------
    for k in frames:
        frames[k] = np.asarray(frames[k])
    np.savez_compressed(out_dir / "frames.npz", **frames, r=r)
    with open(out_dir / "meta.json", "w") as fh:
        json.dump(P, fh, indent=2)

    # quick-look plots ---------------------------------------
    _plot_profiles(frames, R_t, R_b, P["tag"], out_dir / "profiles.png")


# ────────── 5. batch driver ────────────────────────────────────
def main():
    """Run every scenario in SIMS, writing each one into ``results/<tag>/``.

    Each scenario's keys are laid over a fresh deep copy of DEFAULTS.
    Scenarios without a tag are named ``sim-001``, ``sim-002``, … by their
    1-based position. An existing output folder for a tag is deleted before
    the run starts.
    """
    results_root = pathlib.Path("results")
    results_root.mkdir(exist_ok=True)

    for sim_no, overrides in enumerate(SIMS, start=1):
        params = {**deepcopy(DEFAULTS), **overrides}
        if params["tag"] is None:
            params["tag"] = f"sim-{sim_no:03d}"

        target = results_root / params["tag"]
        if target.exists():
            shutil.rmtree(target)

        print(f"\n=== running {params['tag']} ===")
        started = time.time()
        run_one(params, target)
        print(f"finished in {time.time()-started:.1f}s → {target}")

    print("\nAll simulations complete ✔")


if __name__ == "__main__":
    main()