Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.3.5"
__version__ = "2.3.6"


# fundamental supporting modules
Expand Down Expand Up @@ -75,8 +75,7 @@
TwoEndConn as TwoEndConn,
CondNeuGroup as CondNeuGroup,
Channel as Channel)
from brainpy._src.dyn.transform import (NoSharedArg as NoSharedArg, # transformations
LoopOverTime as LoopOverTime,)
from brainpy._src.dyn.transform import (LoopOverTime as LoopOverTime,)
from brainpy._src.dyn.runners import (DSRunner as DSRunner) # runner
from brainpy._src.dyn.context import share, Delay

Expand Down Expand Up @@ -207,7 +206,6 @@
dyn.__dict__['TwoEndConn'] = TwoEndConn
dyn.__dict__['CondNeuGroup'] = CondNeuGroup
dyn.__dict__['Channel'] = Channel
dyn.__dict__['NoSharedArg'] = NoSharedArg
dyn.__dict__['LoopOverTime'] = LoopOverTime
dyn.__dict__['DSRunner'] = DSRunner

Expand Down
61 changes: 9 additions & 52 deletions brainpy/_src/dyn/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

__all__ = [
'LoopOverTime',
'NoSharedArg',
]


Expand Down Expand Up @@ -207,12 +206,13 @@ def __call__(
if isinstance(duration_or_xs, float):
shared = tools.DotDict()
if self.t0 is not None:
shared['t'] = jnp.arange(self.t0.value, duration_or_xs, self.dt)
shared['t'] = jnp.arange(0, duration_or_xs, self.dt) + self.t0.value
if self.i0 is not None:
shared['i'] = jnp.arange(self.i0.value, shared['t'].shape[0])
shared['i'] = jnp.arange(0, shared['t'].shape[0]) + self.i0.value
xs = None
if self.no_state:
raise ValueError('Under the `no_state=True` setting, input cannot be a duration.')
length = shared['t'].shape

else:
inp_err_msg = ('\n'
Expand Down Expand Up @@ -278,8 +278,8 @@ def __call__(

else:
shared = tools.DotDict()
shared['t'] = jnp.arange(self.t0.value, self.dt * length[0], self.dt)
shared['i'] = jnp.arange(self.i0.value, length[0])
shared['t'] = jnp.arange(0, self.dt * length[0], self.dt) + self.t0.value
shared['i'] = jnp.arange(0, length[0]) + self.i0.value

assert not self.no_state
results = bm.for_loop(functools.partial(self._run, self.shared_arg),
Expand All @@ -295,6 +295,10 @@ def __call__(

def reset_state(self, batch_size=None):
self.target.reset_state(batch_size)
if self.i0 is not None:
self.i0.value = jnp.asarray(0)
if self.t0 is not None:
self.t0.value = jnp.asarray(0.)

def _run(self, static_sh, dyn_sh, x):
share.save(**static_sh, **dyn_sh)
Expand All @@ -304,50 +308,3 @@ def _run(self, static_sh, dyn_sh, x):
self.target.clear_input()
return outs


class NoSharedArg(DynSysToBPObj):
"""Transform an instance of :py:class:`~.DynamicalSystem` into a callable
:py:class:`~.BrainPyObject` :math:`y=f(x)`.

.. note::

This object transforms a :py:class:`~.DynamicalSystem` into a :py:class:`~.BrainPyObject`.

If some children nodes need shared arguments, like :py:class:`~.Dropout` or
:py:class:`~.LIF` models, using ``NoSharedArg`` will cause errors.

Examples
--------

>>> import brainpy as bp
>>> import brainpy.math as bm
>>> l = bp.Sequential(bp.layers.Dense(100, 10),
>>> bm.relu,
>>> bp.layers.Dense(10, 2))
>>> l = bp.NoSharedArg(l)
>>> l(bm.random.random(256, 100))

Parameters
----------
target: DynamicalSystem
The target to transform.
name: str
The transformed object name.
"""

def __init__(self, target: DynamicalSystem, name: str = None):
super().__init__(target=target, name=name)
if isinstance(target, Sequential) and target.no_shared_arg:
raise ValueError(f'It is a {Sequential.__name__} object with `no_shared_arg=True`, '
f'which has already able to be called with `f(x)`. ')

def __call__(self, *args, **kwargs):
return self.target(tools.DotDict(), *args, **kwargs)

def reset(self, batch_size=None):
"""Reset function which reset the whole variables in the model.
"""
self.target.reset(batch_size)

def reset_state(self, batch_size=None):
self.target.reset_state(batch_size)
5 changes: 4 additions & 1 deletion brainpy/_src/math/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def clone(self):
return self.__class__()


def set_environment(
def set(
mode: modes.Mode = None,
dt: float = None,
x64: bool = None,
Expand Down Expand Up @@ -381,6 +381,9 @@ def set_environment(
set_complex(complex_)


set_environment = set


class environment(_DecoratorContextManager):
r"""Context-manager that sets a computing environment for brain dynamics computation.

Expand Down
Loading