Skip to content

[math & dyn] add brainpy.math.exprel, and change the code in the corresponding HH neuron models to improve numerical computation accuracy #557

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Dec 10, 2023
24 changes: 14 additions & 10 deletions brainpy/_src/dyn/neurons/hh.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,8 @@ def __init__(
self.reset_state(self.mode)

# m channel
m_alpha = lambda self, V: 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
# m_alpha = lambda self, V: 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
m_alpha = lambda self, V: 1. / bm.exprel(-(V + 40) / 10)
m_beta = lambda self, V: 4.0 * bm.exp(-(V + 65) / 18)
m_inf = lambda self, V: self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V))
dm = lambda self, m, t, V: self.m_alpha(V) * (1 - m) - self.m_beta(V) * m
Expand All @@ -360,7 +361,8 @@ def __init__(
dh = lambda self, h, t, V: self.h_alpha(V) * (1 - h) - self.h_beta(V) * h

# n channel
n_alpha = lambda self, V: 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
# n_alpha = lambda self, V: 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
n_alpha = lambda self, V: 0.1 / bm.exprel(-(V + 55) / 10)
n_beta = lambda self, V: 0.125 * bm.exp(-(V + 65) / 80)
n_inf = lambda self, V: self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V))
dn = lambda self, n, t, V: self.n_alpha(V) * (1 - n) - self.n_beta(V) * n
Expand All @@ -383,8 +385,9 @@ def reset_state(self, batch_size=None, **kwargs):

def dV(self, V, t, m, h, n, I):
I = self.sum_inputs(V, init=I)
I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
I_K = (self.gK * n ** 4.0) * (V - self.EK)
I_Na = (self.gNa * m * m * m * h) * (V - self.ENa)
n2 = n * n
I_K = (self.gK * n2 * n2) * (V - self.EK)
I_leak = self.gL * (V - self.EL)
dVdt = (- I_Na - I_K - I_leak + I) / self.C
return dVdt
Expand Down Expand Up @@ -516,8 +519,9 @@ class HH(HHLTC):
"""

def dV(self, V, t, m, h, n, I):
I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
I_K = (self.gK * n ** 4.0) * (V - self.EK)
I_Na = (self.gNa * m * m * m * h) * (V - self.ENa)
n2 = n * n
I_K = (self.gK * n2 * n2) * (V - self.EK)
I_leak = self.gL * (V - self.EL)
dVdt = (- I_Na - I_K - I_leak + I) / self.C
return dVdt
Expand Down Expand Up @@ -680,9 +684,7 @@ def update(self, x=None):
t = share.load('t')
dt = share.load('dt')
x = 0. if x is None else x

V, W = self.integral(self.V, self.W, t, x, dt)

spike = bm.logical_and(self.V < self.V_th, V >= self.V_th)
self.V.value = V
self.W.value = W
Expand Down Expand Up @@ -930,7 +932,8 @@ def reset_state(self, batch_size=None):
self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_size)

def m_inf(self, V):
alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1)
# alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1)
alpha = 1. / bm.exprel(-0.1 * (V + 35))
beta = 4. * bm.exp(-(V + 60.) / 18.)
return alpha / (alpha + beta)

Expand All @@ -941,7 +944,8 @@ def dh(self, h, t, V):
return self.phi * dhdt

def dn(self, n, t, V):
alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1)
# alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1)
alpha = 1. / bm.exprel(-0.1 * (V + 34))
beta = 0.125 * bm.exp(-(V + 44) / 80)
dndt = alpha * (1 - n) - beta * n
return self.phi * dndt
Expand Down
8 changes: 4 additions & 4 deletions brainpy/_src/dynold/synapses/abstract_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method=comp_method, sparse_data='csr')

# register delay
self.delay_step = self.pre.register_delay("spike", delay_step, self.pre.spike)
self.pre.register_local_delay("spike", self.name, delay_step)

def reset_state(self, batch_size=None):
self.output.reset_state(batch_size)
Expand All @@ -124,7 +124,7 @@ def reset_state(self, batch_size=None):
def update(self, pre_spike=None):
# pre-synaptic spikes
if pre_spike is None:
pre_spike = self.pre.get_delay_data("spike", self.delay_step)
pre_spike = self.pre.get_local_delay("spike", self.name)
pre_spike = bm.as_jax(pre_spike)
if self.stop_spike_gradient:
pre_spike = jax.lax.stop_gradient(pre_spike)
Expand Down Expand Up @@ -317,7 +317,7 @@ def __init__(
self.g = self.syn.g

# delay
self.delay_step = self.pre.register_delay("spike", delay_step, self.pre.spike)
self.pre.register_local_delay("spike", self.name, delay_step)

def reset_state(self, batch_size=None):
self.syn.reset_state(batch_size)
Expand All @@ -328,7 +328,7 @@ def reset_state(self, batch_size=None):
def update(self, pre_spike=None):
# delays
if pre_spike is None:
pre_spike = self.pre.get_delay_data("spike", self.delay_step)
pre_spike = self.pre.get_local_delay("spike", self.name)
pre_spike = bm.as_jax(pre_spike)
if self.stop_spike_gradient:
pre_spike = jax.lax.stop_gradient(pre_spike)
Expand Down
4 changes: 2 additions & 2 deletions brainpy/_src/dynold/synapses/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def __init__(
mode=mode)

# delay
self.delay_step = self.pre.register_delay("spike", delay_step, self.pre.spike)
self.pre.register_local_delay("spike", self.name, delay_step)

# synaptic dynamics
self.syn = syn
Expand All @@ -317,7 +317,7 @@ def __init__(

def update(self, pre_spike=None, stop_spike_gradient: bool = False):
if pre_spike is None:
pre_spike = self.pre.get_delay_data("spike", self.delay_step)
pre_spike = self.pre.get_local_delay("spike", self.name)
if stop_spike_gradient:
pre_spike = jax.lax.stop_gradient(pre_spike)
if self.stp is not None:
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/dynsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def _compatible_reset_state(self, *args, **kwargs):
the_top_layer_reset_state = True
warnings.warn(
'''
From version >= 2.4.6, the policy of ``.reset_state()`` has been changed. See https://brainpy.tech/docs/tutorial_toolbox/state_saving_and_loading.html for details.
From version >= 2.4.6, the policy of ``.reset_state()`` has been changed. See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_saving_and_loading.html for details.

1. If you are resetting all states in a network by calling "net.reset_state(*args, **kwargs)", please use
"bp.reset_state(net, *args, **kwargs)" function, or "net.reset(*args, **kwargs)".
Expand Down
6 changes: 1 addition & 5 deletions brainpy/_src/integrators/ode/exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,6 @@
.. [2] Hochbruck, M., & Ostermann, A. (2010). Exponential integrators. Acta Numerica, 19, 209-286.
"""

import logging

from functools import wraps
from brainpy import errors
from brainpy._src import math as bm
Expand Down Expand Up @@ -360,9 +358,7 @@ def integral(*args, **kwargs):
assert len(args) > 0
dt = kwargs.pop(C.DT, self.dt)
linear, derivative = value_and_grad(*args, **kwargs)
phi = bm.where(linear == 0.,
bm.ones_like(linear),
(bm.exp(dt * linear) - 1) / (dt * linear))
phi = bm.exprel(dt * linear)
return args[0] + dt * phi * derivative

return [(integral, vars, pars), ]
Expand Down
3 changes: 1 addition & 2 deletions brainpy/_src/integrators/sde/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,8 +626,7 @@ def integral(*args, **kwargs):
assert len(args) > 0
dt = kwargs.pop('dt', self.dt)
linear, derivative = value_and_grad(*args, **kwargs)
linear = bm.as_jax(linear)
phi = jnp.where(linear == 0., jnp.ones_like(linear), (jnp.exp(dt * linear) - 1) / (dt * linear))
phi = bm.as_jax(bm.exprel(dt * linear))
return args[0] + dt * phi * derivative

return [(integral, vars, pars), ]
Expand Down
13 changes: 7 additions & 6 deletions brainpy/_src/math/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class Array(object):

"""

__slots__ = ('_value', '_keep_sharding')
__slots__ = ('_value', )

def __init__(self, value, dtype: Any = None):
# array value
Expand Down Expand Up @@ -132,7 +132,7 @@ def value(self, value):
if value.dtype != self_value.dtype:
raise MathError(f"The dtype of the original data is {self_value.dtype}, "
f"while we got {value.dtype}.")
self._value = value.value if isinstance(value, Array) else value
self._value = value

def update(self, value):
"""Update the value of this Array.
Expand Down Expand Up @@ -1549,11 +1549,12 @@ def value(self):
Returns:
The stored data.
"""
v = self._value
# keep sharding constraints
if self._keep_sharding and hasattr(self._value, 'sharding') and (self._value.sharding is not None):
return jax.lax.with_sharding_constraint(self._value, self._value.sharding)
if self._keep_sharding and hasattr(v, 'sharding') and (v.sharding is not None):
return jax.lax.with_sharding_constraint(v, v.sharding)
# return the value
return self._value
return v

@value.setter
def value(self, value):
Expand All @@ -1574,6 +1575,6 @@ def value(self, value):
if value.dtype != self_value.dtype:
raise MathError(f"The dtype of the original data is {self_value.dtype}, "
f"while we got {value.dtype}.")
self._value = value.value if isinstance(value, Array) else value
self._value = value


39 changes: 38 additions & 1 deletion brainpy/_src/math/others.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
from jax.tree_util import tree_map

from brainpy import check, tools
from .compat_numpy import fill_diagonal
from .environment import get_dt, get_int
from .ndarray import Array
from .compat_numpy import fill_diagonal
from .interoperability import as_jax

__all__ = [
'shared_args_over_time',
'remove_diag',
'clip_by_norm',
'exprel',
]


Expand Down Expand Up @@ -82,3 +84,38 @@ def f(l):
return l * clip_norm / jnp.maximum(jnp.sqrt(jnp.sum(l * l, axis=axis, keepdims=True)), clip_norm)

return tree_map(f, t)


def _exprel(x, threshold):
def true_f(x):
x2 = x * x
return 1. + x / 2. + x2 / 6. + x2 * x / 24.0 # + x2 * x2 / 120.

def false_f(x):
return (jnp.exp(x) - 1) / x

# return jax.lax.cond(jnp.abs(x) < threshold, true_f, false_f, x)
return jnp.where(jnp.abs(x) <= threshold, 1. + x / 2. + x * x / 6., (jnp.exp(x) - 1) / x)


def exprel(x, threshold: float = None):
"""Relative error exponential, ``(exp(x) - 1)/x``.

When ``x`` is near zero, ``exp(x)`` is near 1, so the numerical calculation of ``exp(x) - 1`` can
suffer from catastrophic loss of precision. ``exprel(x)`` is implemented to avoid the loss of
precision that occurs when ``x`` is near zero.

Args:
x: ndarray. Input array. ``x`` must contain real numbers.
threshold: float.

Returns:
``(exp(x) - 1)/x``, computed element-wise.
"""
x = as_jax(x)
if threshold is None:
if hasattr(x, 'dtype') and x.dtype == jnp.float64:
threshold = 1e-8
else:
threshold = 1e-5
return _exprel(x, threshold)
21 changes: 21 additions & 0 deletions brainpy/_src/math/tests/test_others.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@

import brainpy.math as bm
from scipy.special import exprel

from unittest import TestCase


class Test_exprel(TestCase):
def test1(self):
for x in [1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9]:
print(f'{exprel(x)}, {bm.exprel(x)}, {exprel(x) - bm.exprel(x):.10f}')
# self.assertEqual(exprel(x))

def test2(self):
bm.enable_x64()
for x in [1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9]:
print(f'{exprel(x)}, {bm.exprel(x)}, {exprel(x) - bm.exprel(x):.10f}')
# self.assertEqual(exprel(x))



1 change: 1 addition & 0 deletions brainpy/math/others.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
shared_args_over_time as shared_args_over_time,
remove_diag as remove_diag,
clip_by_norm as clip_by_norm,
exprel as exprel,
)

from brainpy._src.math.object_transform.naming import (
Expand Down