From b59940caf512b4ef6cb50ca7584ff0d9784c901a Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 11 Aug 2022 21:51:43 +0800 Subject: [PATCH] update apis and tests --- brainpy/base/base.py | 1 - brainpy/base/tests/test_collector.py | 2 + brainpy/dyn/neurons/biological_models.py | 80 +++++++++++++++--------- brainpy/visualization/base.py | 5 ++ brainpy/visualization/plots.py | 10 +++ 5 files changed, 66 insertions(+), 32 deletions(-) diff --git a/brainpy/base/base.py b/brainpy/base/base.py index f55e25ec1..ef1da758b 100644 --- a/brainpy/base/base.py +++ b/brainpy/base/base.py @@ -124,7 +124,6 @@ def vars(self, method='absolute', level=-1, include_self=True): v = getattr(node, k) if isinstance(v, math.Variable): if k not in node._excluded_vars: - # if not k.startswith('_') and not k.endswith('_'): gather[f'{node_path}.{k}' if node_path else k] = v gather.update({f'{node_path}.{k}': v for k, v in node.implicit_vars.items()}) return gather diff --git a/brainpy/base/tests/test_collector.py b/brainpy/base/tests/test_collector.py index fb4dfb4f8..041a305ba 100644 --- a/brainpy/base/tests/test_collector.py +++ b/brainpy/base/tests/test_collector.py @@ -273,6 +273,8 @@ def test_net_vars_2(): def test_hidden_variables(): class BPClass(bp.base.Base): + _excluded_vars = ('_rng_', ) + def __init__(self): super(BPClass, self).__init__() diff --git a/brainpy/dyn/neurons/biological_models.py b/brainpy/dyn/neurons/biological_models.py index 23c737c49..b0ef1852f 100644 --- a/brainpy/dyn/neurons/biological_models.py +++ b/brainpy/dyn/neurons/biological_models.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from typing import Union, Callable +from typing import Union, Callable, Optional import brainpy.math as bm from brainpy.dyn.base import NeuGroup @@ -204,9 +204,9 @@ def __init__( V_th: Union[float, Tensor, Initializer, Callable] = 20., C: Union[float, Tensor, Initializer, Callable] = 1.0, V_initializer: Union[Initializer, Callable, Tensor] = Uniform(-70, -60.), - m_initializer: 
Union[Initializer, Callable, Tensor] = OneInit(0.5), - h_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.6), - n_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.32), + m_initializer: Optional[Union[Initializer, Callable, Tensor]] = None, + h_initializer: Optional[Union[Initializer, Callable, Tensor]] = None, + n_initializer: Optional[Union[Initializer, Callable, Tensor]] = None, noise: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', name: str = None, @@ -233,9 +233,9 @@ def __init__( self.noise = init_noise(noise, self.varshape, num_vars=4) # initializers - check_initializer(m_initializer, 'm_initializer', allow_none=False) - check_initializer(h_initializer, 'h_initializer', allow_none=False) - check_initializer(n_initializer, 'n_initializer', allow_none=False) + check_initializer(m_initializer, 'm_initializer', allow_none=True) + check_initializer(h_initializer, 'h_initializer', allow_none=True) + check_initializer(n_initializer, 'n_initializer', allow_none=True) check_initializer(V_initializer, 'V_initializer', allow_none=False) self._m_initializer = m_initializer self._h_initializer = h_initializer @@ -243,10 +243,19 @@ def __init__( self._V_initializer = V_initializer # variables - self.m = variable(self._m_initializer, mode, self.varshape) - self.h = variable(self._h_initializer, mode, self.varshape) - self.n = variable(self._n_initializer, mode, self.varshape) self.V = variable(self._V_initializer, mode, self.varshape) + if self._m_initializer is None: + self.m = bm.Variable(self.m_inf(self.V.value)) + else: + self.m = variable(self._m_initializer, mode, self.varshape) + if self._h_initializer is None: + self.h = bm.Variable(self.h_inf(self.V.value)) + else: + self.h = variable(self._h_initializer, mode, self.varshape) + if self._n_initializer is None: + self.n = bm.Variable(self.n_inf(self.V.value)) + else: + self.n = variable(self._n_initializer, mode, self.varshape) self.input = variable(bm.zeros, 
mode, self.varshape) self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape) @@ -256,32 +265,41 @@ def __init__( else: self.integral = sdeint(method=method, f=self.derivative, g=self.noise) + # m channel + m_alpha = lambda self, V: 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10)) + m_beta = lambda self, V: 4.0 * bm.exp(-(V + 65) / 18) + m_inf = lambda self, V: self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V)) + dm = lambda self, m, t, V: self.m_alpha(V) * (1 - m) - self.m_beta(V) * m + + # h channel + h_alpha = lambda self, V: 0.07 * bm.exp(-(V + 65) / 20.) + h_beta = lambda self, V: 1 / (1 + bm.exp(-(V + 35) / 10)) + h_inf = lambda self, V: self.h_alpha(V) / (self.h_alpha(V) + self.h_beta(V)) + dh = lambda self, h, t, V: self.h_alpha(V) * (1 - h) - self.h_beta(V) * h + + # n channel + n_alpha = lambda self, V: 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10)) + n_beta = lambda self, V: 0.125 * bm.exp(-(V + 65) / 80) + n_inf = lambda self, V: self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V)) + dn = lambda self, n, t, V: self.n_alpha(V) * (1 - n) - self.n_beta(V) * n + def reset_state(self, batch_size=None): - self.m.value = variable(self._m_initializer, batch_size, self.varshape) - self.h.value = variable(self._h_initializer, batch_size, self.varshape) - self.n.value = variable(self._n_initializer, batch_size, self.varshape) self.V.value = variable(self._V_initializer, batch_size, self.varshape) + if self._m_initializer is None: + self.m.value = self.m_inf(self.V.value) + else: + self.m.value = variable(self._m_initializer, batch_size, self.varshape) + if self._h_initializer is None: + self.h.value = self.h_inf(self.V.value) + else: + self.h.value = variable(self._h_initializer, batch_size, self.varshape) + if self._n_initializer is None: + self.n.value = self.n_inf(self.V.value) + else: + self.n.value = variable(self._n_initializer, batch_size, self.varshape) self.input.value = variable(bm.zeros, batch_size, self.varshape) self.spike.value = 
variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) - def dm(self, m, t, V): - alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10)) - beta = 4.0 * bm.exp(-(V + 65) / 18) - dmdt = alpha * (1 - m) - beta * m - return dmdt - - def dh(self, h, t, V): - alpha = 0.07 * bm.exp(-(V + 65) / 20.) - beta = 1 / (1 + bm.exp(-(V + 35) / 10)) - dhdt = alpha * (1 - h) - beta * h - return dhdt - - def dn(self, n, t, V): - alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10)) - beta = 0.125 * bm.exp(-(V + 65) / 80) - dndt = alpha * (1 - n) - beta * n - return dndt - def dV(self, V, t, m, h, n, I_ext): I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa) I_K = (self.gK * n ** 4.0) * (V - self.EK) diff --git a/brainpy/visualization/base.py b/brainpy/visualization/base.py index 1e48874c9..36a67ea7c 100644 --- a/brainpy/visualization/base.py +++ b/brainpy/visualization/base.py @@ -93,6 +93,11 @@ def animate_2D(values, frame_delay=frame_delay, frame_step=frame_step, title_size=title_size, figsize=figsize, gif_dpi=gif_dpi, video_fps=video_fps, save_path=save_path, show=show) + @staticmethod + def remove_axis(ax, *pos): + from .plots import remove_axis + return remove_axis(ax, *pos) + @staticmethod def plot_style1(fontsize=22, axes_edgecolor='black', diff --git a/brainpy/visualization/plots.py b/brainpy/visualization/plots.py index 141d961e8..4045579b4 100644 --- a/brainpy/visualization/plots.py +++ b/brainpy/visualization/plots.py @@ -17,6 +17,7 @@ 'raster_plot', 'animate_2D', 'animate_1D', + 'remove_axis', ] @@ -504,3 +505,12 @@ def frame(t): else: anim_result.save(save_path + '.mp4', writer='ffmpeg', fps=video_fps, bitrate=3000) return fig + + +def remove_axis(ax, *pos): + for p in pos: + if p not in ['left', 'right', 'top', 'bottom']: + raise ValueError + ax.spines[p].set_visible(False) + 