In [1]:
import jax.numpy as np
import tree_math
import dataclasses

from jaxsnn.base import funcutils
from jaxsnn.base import implicit
from jaxsnn.tree_solver import ArrayLike, tree_solve, tree_matmul

In [None]:
@tree_math.struct
class HodgekinHuxleyState:
    """Dynamical state of the Hodgkin-Huxley neuron model defined in this file.

    The field order is (v, m, h, n); positional construction must follow it.
    Units are not fixed by this file -- presumably mV for ``v``, matching the
    constants in the rate functions (TODO confirm against callers).
    """
    v: ArrayLike  # membrane potential; reversal potentials are subtracted from it in I_Na / I_K / I_L
    m: ArrayLike  # gating variable entering the sodium current as m**3 (see I_Na)
    h: ArrayLike  # gating variable multiplying m**3 in the sodium current (see I_Na)
    n: ArrayLike  # gating variable entering the potassium current as n**4 (see I_K)

@tree_math.struct
class HodgekinHuxleyParameters:
    """Parameters of the Hodgkin-Huxley neuron model defined in this file.

    Each current has the form ``g * gating * (v - E)``; see I_K, I_Na and I_L.
    Units are not fixed by this file (TODO confirm against callers).
    """
    g_K: ArrayLike  # conductance factor of the potassium current I_K
    E_K: ArrayLike  # potential subtracted from v in I_K (reversal potential)
    g_Na: ArrayLike  # conductance factor of the sodium current I_Na
    E_Na: ArrayLike  # potential subtracted from v in I_Na (reversal potential)
    g_L: ArrayLike  # conductance factor of the leak current I_L
    E_L: ArrayLike  # potential subtracted from v in I_L (reversal potential)
    C_m_inv: ArrayLike  # multiplies the summed currents in dv/dt; presumably 1 / membrane capacitance

def alpha_n(v : ArrayLike):
    """Alpha (opening) rate of the potassium gating variable ``n``.

    Evaluates the classic Hodgkin-Huxley expression

        alpha_n(v) = 0.01 * (10 - v) / (exp((10 - v) / 10) - 1)

    The prefactor is 0.01, as in the original Hodgkin-Huxley rate function.
    The expression is 0/0 at ``v == 10``; its limit there is 0.1, which is
    returned explicitly instead of NaN.

    Args:
        v: Membrane potential (scalar or array). Presumably in mV relative
            to rest, as in the classic formulation -- confirm against callers.

    Returns:
        The rate evaluated element-wise at ``v``.
    """
    delta = 10 - v
    at_pole = delta == 0
    # Substitute a harmless argument at the pole so that the unused branch of
    # the final ``where`` never evaluates 0/0.
    safe_delta = np.where(at_pole, 1.0, delta)
    # expm1 keeps precision when delta / 10 is small.
    rate = 0.01 * safe_delta / np.expm1(safe_delta / 10)
    return np.where(at_pole, 0.1, rate)

def beta_n(v : ArrayLike):
    """Beta (closing) rate of the potassium gating variable ``n``.

    Returns ``0.125 * exp(-v / 80)`` evaluated element-wise at ``v``.
    """
    return 0.125 * np.exp(- v / 80)

def alpha_h(v : ArrayLike):
    """Alpha rate of the sodium gating variable ``h``.

    Returns ``0.07 * exp(-v / 20)`` evaluated element-wise at ``v``.
    """
    decay = np.exp(- v / 20)
    return 0.07 * decay

def beta_h(v : ArrayLike):
    """Beta rate of the sodium gating variable ``h``.

    Returns ``1 / (exp((30 - v) / 10) + 1)`` evaluated element-wise at ``v``.
    """
    return 1 / (np.exp((30 - v) / 10) + 1)

def alpha_m(v : ArrayLike):
    """Alpha (opening) rate of the sodium gating variable ``m``.

    Evaluates the classic Hodgkin-Huxley expression

        alpha_m(v) = 0.1 * (25 - v) / (exp((25 - v) / 10) - 1)

    The expression is 0/0 at ``v == 25``; its limit there is 1.0, which is
    returned explicitly instead of NaN.

    Args:
        v: Membrane potential (scalar or array). Presumably in mV relative
            to rest, as in the classic formulation -- confirm against callers.

    Returns:
        The rate evaluated element-wise at ``v``.
    """
    delta = 25 - v
    at_pole = delta == 0
    # Substitute a harmless argument at the pole so that the unused branch of
    # the final ``where`` never evaluates 0/0.
    safe_delta = np.where(at_pole, 1.0, delta)
    # expm1 keeps precision when delta / 10 is small.
    rate = 0.1 * safe_delta / np.expm1(safe_delta / 10)
    return np.where(at_pole, 1.0, rate)

def beta_m(v : ArrayLike):
    """Beta (closing) rate of the sodium gating variable ``m``.

    Returns ``4 * exp(-v / 18)`` evaluated element-wise at ``v``.
    """
    decay = np.exp(- v/18)
    return 4 * decay

def I_Na(s : HodgekinHuxleyState, p : HodgekinHuxleyParameters):
    """Sodium current ``g_Na * m**3 * h * (v - E_Na)`` for state ``s`` and parameters ``p``."""
    conductance = p.g_Na * s.m**3 * s.h
    return conductance * (s.v - p.E_Na)

def I_K(s : HodgekinHuxleyState, p: HodgekinHuxleyParameters):
    """Potassium current ``g_K * n**4 * (v - E_K)`` for state ``s`` and parameters ``p``."""
    conductance = p.g_K * s.n**4
    return conductance * (s.v - p.E_K)

def I_L(s : HodgekinHuxleyState, p: HodgekinHuxleyParameters):
    """Leak current ``g_L * (v - E_L)`` for state ``s`` and parameters ``p``."""
    driving_force = s.v - p.E_L
    return p.g_L * driving_force

def channel_dynamics(alpha, beta):
    """Build the first-order kinetics of a single gating variable.

    The returned function evaluates

        dx/dt = alpha(v) * (1 - x) - beta(v) * x

    so that, at fixed ``v``, ``x`` relaxes towards
    ``alpha(v) / (alpha(v) + beta(v))``. The beta term is subtracted: it is
    the rate at which the open fraction ``x`` closes.

    Args:
        alpha: Callable ``v -> rate`` multiplying the closed fraction ``1 - x``.
        beta: Callable ``v -> rate`` multiplying the open fraction ``x``.

    Returns:
        A function ``dynamics(x, v)`` returning the time derivative of ``x``.
    """
    def dynamics(x : ArrayLike, v : ArrayLike):
        opening = alpha(v) * (1 - x)
        closing = beta(v) * x
        return opening - closing
    return dynamics

def hodgekin_huxley_dynamics(p : HodgekinHuxleyParameters):
    """Build the right-hand side of the Hodgkin-Huxley equations for parameters ``p``.

    Args:
        p: Model parameters (conductance factors, reversal potentials and
            ``C_m_inv``).

    Returns:
        A function ``dynamics(s)`` mapping a ``HodgekinHuxleyState`` to a
        ``HodgekinHuxleyState`` holding the time derivative of each field:

        * ``v``: ``-C_m_inv * (I_K + I_Na + I_L)``; no external input current
          is included here.
        * ``m``, ``h``, ``n``: gating kinetics built from the matching
          ``alpha_*`` / ``beta_*`` rate functions.
    """
    m_dynamics = channel_dynamics(alpha_m, beta_m)
    h_dynamics = channel_dynamics(alpha_h, beta_h)
    n_dynamics = channel_dynamics(alpha_n, beta_n)

    def dynamics(s : HodgekinHuxleyState):
        # Keyword arguments tie each derivative to its own field, so the
        # result cannot drift out of sync with the (v, m, h, n) field order.
        return HodgekinHuxleyState(
            v=-p.C_m_inv * (I_K(s, p) + I_Na(s, p) + I_L(s, p)),
            m=m_dynamics(s.m, s.v),
            h=h_dynamics(s.h, s.v),
            n=n_dynamics(s.n, s.v),
        )
    return dynamics