Merge pull request #536 from chaoming0625/master
[doc] update documentations
chaoming0625 committed Nov 5, 2023
2 parents 1d35e2e + 2970399 commit 687744d
Showing 6 changed files with 367 additions and 243 deletions.
48 changes: 24 additions & 24 deletions brainpy/_src/dyn/channels/sodium_compatible.py
(The changes in this file only re-indent the argument lists of the super().__init__ calls; the line contents are otherwise unchanged, so the removed and added lines in the hunks below read identically.)
@@ -65,9 +65,9 @@ def __init__(
       mode: bm.Mode = None,
   ):
     super().__init__(size=size,
-                     keep_size=keep_size,
-                     name=name,
-                     mode=mode)
+                     keep_size=keep_size,
+                     name=name,
+                     mode=mode)

     # parameters
     self.E = parameter(E, self.varshape, allow_none=False)

@@ -174,13 +174,13 @@ def __init__(
       mode: bm.Mode = None,
   ):
     super().__init__(size,
-                     keep_size=keep_size,
-                     name=name,
-                     method=method,
-                     phi=3 ** ((T - 36) / 10),
-                     g_max=g_max,
-                     E=E,
-                     mode=mode)
+                     keep_size=keep_size,
+                     name=name,
+                     method=method,
+                     phi=3 ** ((T - 36) / 10),
+                     g_max=g_max,
+                     E=E,
+                     mode=mode)
     self.T = parameter(T, self.varshape, allow_none=False)
     self.V_sh = parameter(V_sh, self.varshape, allow_none=False)

@@ -261,13 +261,13 @@ def __init__(
       mode: bm.Mode = None,
   ):
     super().__init__(size,
-                     keep_size=keep_size,
-                     name=name,
-                     method=method,
-                     E=E,
-                     phi=phi,
-                     g_max=g_max,
-                     mode=mode)
+                     keep_size=keep_size,
+                     name=name,
+                     method=method,
+                     E=E,
+                     phi=phi,
+                     g_max=g_max,
+                     mode=mode)
     self.V_sh = parameter(V_sh, self.varshape, allow_none=False)

   def f_p_alpha(self, V):

@@ -348,13 +348,13 @@ def __init__(
       mode: bm.Mode = None,
   ):
     super().__init__(size,
-                     keep_size=keep_size,
-                     name=name,
-                     method=method,
-                     E=E,
-                     phi=phi,
-                     g_max=g_max,
-                     mode=mode)
+                     keep_size=keep_size,
+                     name=name,
+                     method=method,
+                     E=E,
+                     phi=phi,
+                     g_max=g_max,
+                     mode=mode)
     self.V_sh = parameter(V_sh, self.varshape, allow_none=False)

   def f_p_alpha(self, V):

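The phi = 3 ** ((T - 36) / 10) argument forwarded in the second hunk above (@@ -174,13 +174,13 @@) is a Q10-style temperature factor (Q10 = 3, reference temperature 36 degC). A minimal stand-alone sketch of how it scales, not part of this commit:

# Q10-style temperature scaling used for `phi` above: 3 ** ((T - 36) / 10)
for T in (26.0, 36.0, 46.0):
  phi = 3 ** ((T - 36) / 10)
  print(f"T = {T:.0f} degC -> phi = {phi:.3f}")
# prints 0.333, 1.000 and 3.000: the rate factor triples for every 10 degC above 36 degC
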
8 changes: 4 additions & 4 deletions brainpy/_src/dynold/synapses/abstract_models.py
@@ -114,7 +114,7 @@ def __init__(
     self.g_max, self.conn_mask = self._init_weights(g_max, comp_method=comp_method, sparse_data='csr')

     # register delay
-    self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)
+    self.delay_step = self.pre.register_delay("spike", delay_step, self.pre.spike)

   def reset_state(self, batch_size=None):
     self.output.reset_state(batch_size)

@@ -124,7 +124,7 @@ def reset_state(self, batch_size=None):
   def update(self, pre_spike=None):
     # pre-synaptic spikes
     if pre_spike is None:
-      pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
+      pre_spike = self.pre.get_delay_data("spike", self.delay_step)
     pre_spike = bm.as_jax(pre_spike)
     if self.stop_spike_gradient:
       pre_spike = jax.lax.stop_gradient(pre_spike)

@@ -317,7 +317,7 @@ def __init__(
     self.g = self.syn.g

     # delay
-    self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)
+    self.delay_step = self.pre.register_delay("spike", delay_step, self.pre.spike)

   def reset_state(self, batch_size=None):
     self.syn.reset_state(batch_size)

@@ -328,7 +328,7 @@ def reset_state(self, batch_size=None):
   def update(self, pre_spike=None):
     # delays
     if pre_spike is None:
-      pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
+      pre_spike = self.pre.get_delay_data("spike", self.delay_step)
     pre_spike = bm.as_jax(pre_spike)
     if self.stop_spike_gradient:
       pre_spike = jax.lax.stop_gradient(pre_spike)

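The substantive change in this file (and in base.py below) is where the presynaptic spike delay is registered: instead of the synapse registering the delay on itself under the key f"{self.pre.name}.spike", it now registers and reads the delay on the presynaptic group self.pre under the plain key "spike". A rough sketch of the new pattern in isolation, assuming a brainpy 2.4-era API in which neuron groups expose register_delay and get_delay_data; the LIF group and the concrete delay step are illustrative, not from this commit:

import brainpy as bp

# Presynaptic group used only for illustration.
pre = bp.neurons.LIF(4)

# New pattern from this commit: the delay is owned by the presynaptic group
# and keyed by the plain identifier "spike".
delay_step = pre.register_delay("spike", 2, pre.spike)

# Later (e.g. inside the synapse's update()), the delayed spikes are read back
# with the same identifier:
delayed_spikes = pre.get_delay_data("spike", delay_step)
print(delayed_spikes)
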
4 changes: 2 additions & 2 deletions brainpy/_src/dynold/synapses/base.py
@@ -296,7 +296,7 @@ def __init__(
                      mode=mode)

     # delay
-    self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)
+    self.delay_step = self.pre.register_delay("spike", delay_step, self.pre.spike)

     # synaptic dynamics
     self.syn = syn

@@ -317,7 +317,7 @@ def __init__(

   def update(self, pre_spike=None, stop_spike_gradient: bool = False):
     if pre_spike is None:
-      pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
+      pre_spike = self.pre.get_delay_data("spike", self.delay_step)
     if stop_spike_gradient:
       pre_spike = jax.lax.stop_gradient(pre_spike)
     if self.stp is not None:

4 changes: 2 additions & 2 deletions docs/index.rst
@@ -104,15 +104,15 @@ Installation

       .. code-block:: bash

-         pip install -U "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+         pip install -U "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
          pip install -U brainpy brainpylib-cu11x # only on linux

    .. tab-item:: GPU (CUDA-12x)

       .. code-block:: bash

-         pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+         pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
          pip install -U brainpy brainpylib-cu12x # only on linux

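The installation lines above switch the JAX extras from cuda11_pip/cuda12_pip to cuda11_local/cuda12_local, i.e. JAX now relies on a locally installed CUDA toolkit instead of pip-provided CUDA wheels. A quick sanity check, not part of this commit, that the resulting install actually sees the GPU:

# Confirm that JAX picks up the local CUDA installation.
import jax

print(jax.default_backend())  # expected: 'gpu' on a working CUDA setup
print(jax.devices())          # expected: a non-empty list of GPU devices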
