In [61]:
import math
import jax
from jax import jit, vmap, value_and_grad,lax
import jax.numpy as jnp
import optax
import numpy as np
import jaxley as jx
import jaxley.optimize.transforms as jt
import jax.scipy as jsp
import matplotlib.pyplot as plt
from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "gpu")

In [62]:
SAFE_Z   = 1e-4
EXP_CLIP = 60.0

def safe_exp(x): return jnp.exp(jnp.clip(x, -EXP_CLIP, EXP_CLIP))

def vtrap(dx, k):
    z = dx / k
    taylor = 1.0 - z/2.0 + (z*z)/12.0 - (z**4)/720.0
    return jnp.where(jnp.abs(z) < SAFE_Z, k * taylor, dx / jnp.expm1(z))

def alpha_m(v): return 0.1  * vtrap(25.0 - v, 10.0)
def beta_m(v):  return 4.0  * safe_exp(-v / 18.0)
def alpha_h(v): return 0.07 * safe_exp(-v / 20.0)
def beta_h(v):  return 1.0  / (safe_exp((30.0 - v)/10.0) + 1.0)
def alpha_n(v): return 0.01 * vtrap(10.0 - v, 10.0)
def beta_n(v):  return 0.125 * safe_exp(-v / 80.0)

def x_inf(alpha_fn, beta_fn, v):
    a, b = alpha_fn(v), beta_fn(v)
    return a / (a + b + 1e-12)

def gate_update(x, a, b, dt):
    inv_tau = a + b
    tau     = 1.0 / jnp.maximum(inv_tau, 1e-12)
    x_inf   = a * tau
    return jnp.clip(x_inf - (x_inf - x) * jnp.exp(-dt / tau), 0.0, 1.0)

# Cable coupling (axial G matrix)
def build_tridiagonal_gax(n_comp, g_link):
    """Line cable; uniform axial conductance g_link (S/cm^2 equiv)."""
    G = jnp.zeros((n_comp, n_comp))
    idx = jnp.arange(n_comp - 1)
    G = G.at[idx, idx+1].set(-g_link)
    G = G.at[idx+1, idx].set(-g_link)
    G = G.at[jnp.arange(n_comp), jnp.arange(n_comp)].set(-jnp.sum(G, axis=1))
    return G  # symmetric PSD

# One cable step (implicit V, RL gates)
@jit
def hh_step_cable(state, I_inj_vec, params, dt, Gax):
    """
    state: (v, m, h, n), each (N,)
    I_inj_vec: (N,)
    params: dict with per-comp arrays (N,) or scalars broadcastable
    Gax: (N,N) axial conductance matrix
    """
    v, m, h, n = state

    # rates at V_t
    am, bm = alpha_m(v), beta_m(v)
    ah, bh = alpha_h(v), beta_h(v)
    an, bn = alpha_n(v), beta_n(v)

    # Rush–Larsen gates → t+1
    m_new = gate_update(m, am, bm, dt)
    h_new = gate_update(h, ah, bh, dt)
    n_new = gate_update(n, an, bn, dt)

    gNa, gK, gL = params["HH_gNa"], params["HH_gK"], params["HH_gL"]
    ENa, EK, EL = params["HH_ENa"], params["HH_EK"], params["HH_EL"]
    Cm          = params["C_m"]

    # ionic conductances at t+1 gates (canonical m^3 h)
    gNa_eff = gNa * (m_new**3) * h_new
    gK_eff  = gK  * (n_new**4)
    gion    = gNa_eff + gK_eff + gL            # (N,)
    Irev    = gNa_eff*ENa + gK_eff*EK + gL*EL  # (N,)

    # implicit Euler: (C/dt I + diag(gion) + Gax) v_{t+1} = (C/dt) v_t + Irev + I_inj
    A = jnp.diag(Cm / dt + gion) + Gax
    b = (Cm/dt) * v + Irev + I_inj_vec
    v_new = jsp.linalg.solve(A, b)


    return (v_new, m_new, h_new, n_new), v_new

def integrate_cable(state0, currents, params, dt, Gax):
    """currents: (T,N)"""
    def step(carry, I_t):
        return hh_step_cable(carry, I_t, params, dt, Gax)
    final_state, v_trace = lax.scan(step, state0, currents)
    return final_state, v_trace  # (T,N)

In [63]:
class MultiCompHH:
    def __init__(self, n_comp=5, default_params=None, g_link=0.2):
        self.n = n_comp
        base = {
            "HH_gNa": 120.0*jnp.ones(n_comp),
            "HH_gK":   36.0*jnp.ones(n_comp),
            "HH_gL":    0.3*jnp.ones(n_comp),
            "HH_ENa":  50.0*jnp.ones(n_comp),
            "HH_EK":  -77.0*jnp.ones(n_comp),
            "HH_EL":  -54.387*jnp.ones(n_comp),
            "C_m":      1.0*jnp.ones(n_comp),
        }
        if default_params is not None:
            base.update(default_params)
        self.params = base
        self.g_link = g_link
        self.Gax    = build_tridiagonal_gax(n_comp, g_link)

    def simulate(self, params_physical, currents, dt=0.025, t_equil_ms=25.0):
        """
        currents: (T,N), inject at soma via currents[:,0], others 0
        returns (T,N)
        """
        v0 = params_physical["HH_EL"]  # (N,)
        m0 = x_inf(alpha_m, beta_m, v0)
        h0 = x_inf(alpha_h, beta_h, v0)
        n0 = x_inf(alpha_n, beta_n, v0)
        state0 = (v0, m0, h0, n0)

        _, V = integrate_cable(state0, currents, params_physical, dt, self.Gax)
        return V


In [64]:

N      = 5
dt     = 0.025   # ms
t_max  = 50.0    # ms
T      = int(np.ceil(t_max / dt))

# Model instance (match your old value)
model  = MultiCompHH(n_comp=N, g_link=0.2)
model.n_comp = model.n   # alias so existing code works


# Choose params to visualize (true or perturbed)
phys_true    = model.params
phys_learner = {k: v.copy() for k, v in phys_true.items()}
phys_learner["HH_gNa"] *= 0.8
phys_learner["HH_gK"]  *= 1.1
phys_learner["HH_EL"]  += 2.0

phys_params_for_viz = phys_learner   # or phys_true


In [65]:
def make_step_injection(T, N, dt, inj_idx, delay=5.0, dur=20.0, amp=8.0):
    t = jnp.arange(T) * dt
    mask = ((t >= delay) & (t < delay + dur)).astype(jnp.float64)  # (T,)
    I = jnp.zeros((T, N), dtype=jnp.float64)
    return I.at[:, inj_idx].set(amp * mask)  # (T, N)

# step parameters (match your previous runs)
delay, dur, amp = 5.0, 20.0, 8.0

# build one current per injection site i=0..4
inj_currents = [make_step_injection(T, N, dt, i, delay, dur, amp) for i in range(N)]


In [66]:
def simulate_all_injections(model, phys_params, inj_currents, dt):
    volts_by_inj = []
    for I_i in inj_currents:
        V = model.simulate(phys_params, I_i, dt=dt)  # (T, C) JAX array
        volts_by_inj.append(np.asarray(V))
    return volts_by_inj

In [67]:
volts_by_inj = simulate_all_injections(model, phys_params_for_viz, inj_currents, dt)
print("[info] volts_by_inj:", len(volts_by_inj), volts_by_inj[0].shape)

[info] volts_by_inj: 5 (2000, 5)


In [68]:
# 1) Build jitted simulators that CAPTURE model & dt (never pass the model into jit)
if not hasattr(model, "n_comp"):
    model.n_comp = getattr(model, "n", len(model.params["HH_EL"]))

def make_simulators(model, dt):
    import jax, jax.numpy as jnp
    @jax.jit
    def sim_full(p, I):      # (T,C)
        return model.simulate(p, I, dt=dt)
    @jax.jit
    def sim_soma(p, I):      # (T,)
        return model.simulate(p, I, dt=dt)[:, 0]
    return sim_full, sim_soma

sim_full, sim_soma = make_simulators(model, dt)

# 2) All-columns Jacobian using JVP + vmap (no linearize)
import jax, jax.numpy as jnp
import numpy as np

def dVfull_dgvec_jvp(phys_params, key, I):
    """
    Return J(t,c,j) = ∂V_c(t)/∂(key[j]) for all j, using a batched JVP.
    Shapes:
      phys_params: pytree of JAX arrays (same as model.params)
      I: (T,C) JAX array
    Output: (T, C, Np)
    """
    f = lambda p: sim_full(p, I)                     # close over sim_full & I
    zero_tang = jax.tree_map(jnp.zeros_like, phys_params)

    def push(e):                                     # e: (Np,)
        # create tangent tree with basis 'e' in the requested key
        tang = {**zero_tang, key: e}
        _, ydot = jax.jvp(f, (phys_params,), (tang,))  # (T,C)
        return ydot

    Np = phys_params[key].shape[0]
    E  = jnp.eye(Np)                                 # (Np,Np)
    Ycols = jax.vmap(push)(E)                        # (Np, T, C)
    return np.asarray(jnp.transpose(Ycols, (1, 2, 0)))  # → (T,C,Np)

def _model_n(model):
    return getattr(model, "n_comp", getattr(model, "n", len(model.params["HH_EL"])))

def compute_Jdict_for_injection(phys_params, inj_I):
    """
    For ONE injection current inj_I: return {param_key: J(t,c,j)} for all
    1D per-compartment parameters of length n.
    """
    n = _model_n(model)
    I = jnp.asarray(inj_I)                           # ensure JAX array
    Jdict = {}
    for key, val in phys_params.items():
        if hasattr(val, "ndim") and val.ndim == 1 and val.shape[0] == n:
            Jdict[key] = dVfull_dgvec_jvp(phys_params, key, I)  # (T,C,n)
    return Jdict


In [69]:
def detect_spikes(Vsoma, dt, thr=0.0, min_isi_ms=2.0):
    up = (Vsoma[1:] >= thr) & (Vsoma[:-1] < thr)
    idx = np.where(up)[0] + 1
    if idx.size == 0:
        return idx
    min_gap = int(round(min_isi_ms / dt))
    keep = [idx[0]]
    for k in idx[1:]:
        if k - keep[-1] >= min_gap:
            keep.append(k)
    return np.asarray(keep, dtype=int)


In [70]:
def bap_metrics(V_tc, spike_idx, dt, pre_ms=1.0, post_ms=10.0, baseline="pre_mean"):
    """
    V_tc: (T, C) for ONE injection.
    Returns dict with arrays (S, C): 'amp' (peak depol above baseline), 'ttp' (ms), 'base'.
    """
    T, C = V_tc.shape
    S = len(spike_idx)
    pre_s  = max(1, int(round(pre_ms / dt)))
    post_s = max(1, int(round(post_ms / dt)))

    amps = np.zeros((S, C))
    ttp  = np.zeros((S, C))
    base = np.zeros((S, C))
    for s, si in enumerate(spike_idx):
        lo = max(0, si - pre_s)
        hi = min(T, si + post_s + 1)
        if baseline == "pre_mean":
            base_s = V_tc[lo:si].mean(axis=0) if si > lo else V_tc[si]
        elif baseline == "at_spike":
            base_s = V_tc[si]
        else:
            raise ValueError("baseline in {'pre_mean','at_spike'}")
        post_seg = V_tc[si:hi] - base_s                # (len, C)
        k = np.argmax(post_seg, axis=0)                # per-compartment
        amps[s] = post_seg[k, np.arange(C)]
        ttp[s]  = k * dt
        base[s] = base_s
    return {"amp": amps, "ttp": ttp, "base": base}


In [71]:
def grad_windows_all_params(Jdict_tcj, spike_idx, dt, pre_ms=1.0, post_ms=10.0):
    """
    Jdict_tcj: {key: J(t,c,j)} for ONE injection.
    Returns: {key: {'pre_mean': (S,C,J), 'post_mean': (S,C,J)}}.
    """
    S = len(spike_idx)
    out = {}
    pre_s  = max(1, int(round(pre_ms / dt)))
    post_s = max(1, int(round(post_ms / dt)))

    for key, J in Jdict_tcj.items():
        T, C, Jn = J.shape
        pre_mean  = np.zeros((S, C, Jn))
        post_mean = np.zeros((S, C, Jn))
        for s, si in enumerate(spike_idx):
            lo = max(0, si - pre_s)
            hi = min(T, si + post_s)
            pre_mean[s]  = J[lo:si].mean(axis=0) if si > lo else J[si:si+1].mean(axis=0)
            post_mean[s] = J[si:hi].mean(axis=0)
        out[key] = {"pre_mean": pre_mean, "post_mean": post_mean}
    return out


In [72]:
i = 0
V_i   = volts_by_inj[i]
Vsoma = V_i[:, 0]

spk = detect_spikes(Vsoma, dt, thr=0.0, min_isi_ms=2.0)
print(f"[inj {i}] spikes:", spk, " times (ms):", np.round(spk*dt, 3))

bap = bap_metrics(V_i, spk, dt, pre_ms=1.0, post_ms=10.0)
print("bAP amp:", bap["amp"].shape, " ttp:", bap["ttp"].shape)

Jdict = compute_Jdict_for_injection(phys_params_for_viz, inj_currents[i])
for k, J in Jdict.items():
    print(f"{k}: J shape {J.shape}")  # expect (T, C, n)


gw = grad_windows_all_params(Jdict, spk, dt, pre_ms=1.0, post_ms=10.0)
for k, d in gw.items():
    print(k, "pre_mean:", d["pre_mean"].shape, " post_mean:", d["post_mean"].shape)


[inj 0] spikes: []  times (ms): []
bAP amp: (0, 5)  ttp: (0, 5)
HH_gNa: J shape (2000, 5, 5)
HH_gK: J shape (2000, 5, 5)
HH_gL: J shape (2000, 5, 5)
HH_ENa: J shape (2000, 5, 5)
HH_EK: J shape (2000, 5, 5)
HH_EL: J shape (2000, 5, 5)
C_m: J shape (2000, 5, 5)
HH_gNa pre_mean: (0, 5, 5)  post_mean: (0, 5, 5)
HH_gK pre_mean: (0, 5, 5)  post_mean: (0, 5, 5)
HH_gL pre_mean: (0, 5, 5)  post_mean: (0, 5, 5)
HH_ENa pre_mean: (0, 5, 5)  post_mean: (0, 5, 5)
HH_EK pre_mean: (0, 5, 5)  post_mean: (0, 5, 5)
HH_EL pre_mean: (0, 5, 5)  post_mean: (0, 5, 5)
C_m pre_mean: (0, 5, 5)  post_mean: (0, 5, 5)
