Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions brainpy/base/function.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# -*- coding: utf-8 -*-

from typing import Callable, Sequence, Dict, Union
from typing import Callable, Sequence, Dict, Union, TypeVar

from brainpy.base.base import BrainPyObject
from brainpy.types import ArrayType


Variable = TypeVar('Variable')


__all__ = [
'FunAsObject',
Expand All @@ -28,7 +31,7 @@ class FunAsObject(BrainPyObject):
def __init__(self,
f: Callable,
child_objs: Union[BrainPyObject, Sequence[BrainPyObject], Dict[dict, BrainPyObject]] = None,
dyn_vars: Union[ArrayType, Sequence[ArrayType], Dict[dict, ArrayType]] = None,
dyn_vars: Union[Variable, Sequence[Variable], Dict[dict, Variable]] = None,
name: str = None):
super(FunAsObject, self).__init__(name=name)
self._f = f
Expand Down
2 changes: 1 addition & 1 deletion brainpy/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,7 +1304,7 @@ def load(
gda_manager: Optional[Any] = None,
allow_partial_mpa_restoration: bool = False,
) -> PyTree:
"""Load last or best checkpoint from the given checkpoint path.
"""Load last or best checkpoint from the given checkpoint path.

Sorts the checkpoint files naturally, returning the highest-valued
file, e.g.:
Expand Down
45 changes: 41 additions & 4 deletions brainpy/math/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import warnings
from typing import Any, Callable, TypeVar, cast

from jax import dtypes, config, numpy as jnp, devices
from jax import config, numpy as jnp, devices
from jax.lib import xla_bridge

from . import modes
Expand Down Expand Up @@ -329,6 +329,7 @@ def clone(self):
def set_environment(
mode: modes.Mode = None,
dt: float = None,
x64: bool = None,
complex_: type = None,
float_: type = None,
int_: type = None,
Expand All @@ -342,6 +343,8 @@ def set_environment(
The computing mode.
dt: float
The numerical integration precision.
x64: bool
Enable x64 computation.
complex_: type
The complex data type.
float_
Expand All @@ -359,6 +362,10 @@ def set_environment(
assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.'
set_mode(mode)

if x64 is not None:
assert isinstance(x64, bool), f'"x64" must be a bool.'
set_x64(x64)

if float_ is not None:
assert isinstance(float_, type), '"float_" must a float.'
set_float(float_)
Expand Down Expand Up @@ -402,8 +409,9 @@ class environment(_DecoratorContextManager):

def __init__(
self,
dt: float = None,
mode: modes.Mode = None,
dt: float = None,
x64: bool = None,
complex_: type = None,
float_: type = None,
int_: type = None,
Expand All @@ -412,6 +420,7 @@ def __init__(
super().__init__()
self.old_dt = get_dt()
self.old_mode = get_mode()
self.old_x64 = config.read("jax_enable_x64")
self.old_int = get_int()
self.old_bool = get_bool()
self.old_float = get_float()
Expand All @@ -421,6 +430,8 @@ def __init__(
assert isinstance(dt, float), '"dt" must a float.'
if mode is not None:
assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.'
if x64 is not None:
assert isinstance(x64, bool), f'"x64" must be a bool.'
if float_ is not None:
assert isinstance(float_, type), '"float_" must a float.'
if int_ is not None:
Expand All @@ -431,6 +442,7 @@ def __init__(
assert isinstance(complex_, type), '"complex_" must a type.'
self.dt = dt
self.mode = mode
self.x64 = x64
self.complex_ = complex_
self.float_ = float_
self.int_ = int_
Expand All @@ -439,6 +451,7 @@ def __init__(
def __enter__(self) -> 'environment':
if self.dt is not None: set_dt(self.dt)
if self.mode is not None: set_mode(self.mode)
if self.x64 is not None: set_x64(self.x64)
if self.float_ is not None: set_float(self.float_)
if self.int_ is not None: set_int(self.int_)
if self.complex_ is not None: set_complex(self.complex_)
Expand All @@ -448,6 +461,7 @@ def __enter__(self) -> 'environment':
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
if self.dt is not None: set_dt(self.old_dt)
if self.mode is not None: set_mode(self.old_mode)
if self.x64 is not None: set_x64(self.old_x64)
if self.int_ is not None: set_int(self.old_int)
if self.float_ is not None: set_float(self.old_float)
if self.complex_ is not None: set_complex(self.old_complex)
Expand All @@ -456,6 +470,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
def clone(self):
return self.__class__(dt=self.dt,
mode=self.mode,
x64=self.x64,
bool_=self.bool_,
complex_=self.complex_,
float_=self.float_,
Expand All @@ -468,6 +483,7 @@ class training_environment(environment):
This is a short-cut context setting for an environment with the training mode.
It is equivalent to::

>>> import brainpy.math as bm
>>> with bm.environment(mode=bm.training_mode):
>>> pass

Expand All @@ -476,11 +492,17 @@ class training_environment(environment):

def __init__(self,
dt: float = None,
x64: bool = None,
complex_: type = None,
float_: type = None,
int_: type = None,
bool_: type = None):
super().__init__(dt=dt, complex_=complex_, float_=float_, int_=int_, bool_=bool_,
super().__init__(dt=dt,
x64=x64,
complex_=complex_,
float_=float_,
int_=int_,
bool_=bool_,
mode=modes.TrainingMode())


Expand All @@ -490,6 +512,7 @@ class batching_environment(environment):
This is a short-cut context setting for an environment with the batching mode.
It is equivalent to::

>>> import brainpy.math as bm
>>> with bm.environment(mode=bm.batching_mode):
>>> pass

Expand All @@ -498,11 +521,17 @@ class batching_environment(environment):

def __init__(self,
dt: float = None,
x64: bool = None,
complex_: type = None,
float_: type = None,
int_: type = None,
bool_: type = None):
super().__init__(dt=dt, complex_=complex_, float_=float_, int_=int_, bool_=bool_,
super().__init__(dt=dt,
x64=x64,
complex_=complex_,
float_=float_,
int_=int_,
bool_=bool_,
mode=modes.BatchingMode())


Expand All @@ -520,6 +549,14 @@ def disable_x64():
set_complex(jnp.complex64)


def set_x64(enable: bool):
assert isinstance(enable, bool)
if enable:
enable_x64()
else:
disable_x64()


def set_platform(platform: str):
"""
Changes platform to CPU, GPU, or TPU. This utility only takes
Expand Down
82 changes: 44 additions & 38 deletions brainpy/math/fft.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# -*- coding: utf-8 -*-

from typing import Optional
import jax.numpy.fft

from brainpy.math.ndarray import Array
from brainpy.math.numpy_ops import _remove_brainpy_array
from brainpy.math.numpy_ops import _as_jax_array_

__all__ = [
"fft", "fft2", "fftfreq", "fftn", "fftshift", "hfft",
Expand All @@ -12,89 +12,95 @@
]


def fft(a, n=None, axis=-1, norm=None):
a = _remove_brainpy_array(a)
return Array(jax.numpy.fft.fft(a=a, n=n, axis=axis, norm=norm))
def fft(a,
n: Optional[int] = None,
axis: int = -1,
norm: Optional[str] = None):
a = _as_jax_array_(a)
return jax.numpy.fft.fft(a=a, n=n, axis=axis, norm=norm)


def fft2(a, s=None, axes=(-2, -1), norm=None):
a = _remove_brainpy_array(a)
return Array(jax.numpy.fft.fft2(a=a, s=s, axes=axes, norm=norm))
a = _as_jax_array_(a)
return jax.numpy.fft.fft2(a=a, s=s, axes=axes, norm=norm)


def fftfreq(n, d=1.0):
return Array(jax.numpy.fft.fftfreq(n=n, d=d))
return jax.numpy.fft.fftfreq(n=n, d=d)


def fftn(a, s=None, axes=None, norm=None):
a = _remove_brainpy_array(a)
return Array(jax.numpy.fft.fftn(a=a, s=s, axes=axes, norm=norm))
a = _as_jax_array_(a)
return jax.numpy.fft.fftn(a=a, s=s, axes=axes, norm=norm)


def fftshift(x, axes=None):
x = _remove_brainpy_array(x)
return Array(jax.numpy.fft.fftshift(x=x, axes=axes))
x = _as_jax_array_(x)
return jax.numpy.fft.fftshift(x=x, axes=axes)


def hfft(a, n=None, axis=-1, norm=None):
a = _remove_brainpy_array(a)
return Array(jax.numpy.fft.hfft(a=a, n=n, axis=axis, norm=norm))
a = _as_jax_array_(a)
return jax.numpy.fft.hfft(a=a, n=n, axis=axis, norm=norm)


def ifft(a, n=None, axis=-1, norm=None):
a = _remove_brainpy_array(a)
return Array(jax.numpy.fft.ifft(a=a, n=n, axis=axis, norm=norm))
def ifft(a,
n: Optional[int] = None,
axis: int = -1,
norm: Optional[str] = None):
a = _as_jax_array_(a)
return jax.numpy.fft.ifft(a=a, n=n, axis=axis, norm=norm)


def ifft2(a, s=None, axes=(-2, -1), norm=None):
a = _remove_brainpy_array(a)
return Array(jax.numpy.fft.ifft2(a=a, s=s, axes=axes, norm=norm))
a = _as_jax_array_(a)
return jax.numpy.fft.ifft2(a=a, s=s, axes=axes, norm=norm)


def ifftn(a, s=None, axes=None, norm=None):
a = _remove_brainpy_array(a)
return Array(jax.numpy.fft.ifftn(a=a, s=s, axes=axes, norm=norm))
a = _as_jax_array_(a)
return jax.numpy.fft.ifftn(a=a, s=s, axes=axes, norm=norm)


def ifftshift(x, axes=None):
x = _remove_brainpy_array(x)
return Array(jax.numpy.fft.ifftshift(x=x, axes=axes))
x = _as_jax_array_(x)
return jax.numpy.fft.ifftshift(x=x, axes=axes)


def ihfft(a, n=None, axis=-1, norm=None):
a = _remove_brainpy_array(a)
return Array(jax.numpy.fft.ihfft(a=a, n=n, axis=axis, norm=norm))
a = _as_jax_array_(a)
return jax.numpy.fft.ihfft(a=a, n=n, axis=axis, norm=norm)


def irfft(a, n=None, axis=-1, norm=None):
a = _remove_brainpy_array(a)
return Array(jax.numpy.fft.irfft(a=a, n=n, axis=axis, norm=norm))
a = _as_jax_array_(a)
return jax.numpy.fft.irfft(a=a, n=n, axis=axis, norm=norm)


def irfft2(a, s=None, axes=(-2, -1), norm=None):
a = _remove_brainpy_array(a)
return Array(jax.numpy.fft.irfft2(a=a, s=s, axes=axes, norm=norm))
a = _as_jax_array_(a)
return jax.numpy.fft.irfft2(a=a, s=s, axes=axes, norm=norm)


def irfftn(a, s=None, axes=None, norm=None):
a = _remove_brainpy_array(a)
return Array(jax.numpy.fft.irfftn(a=a, s=s, axes=axes, norm=norm))
a = _as_jax_array_(a)
return jax.numpy.fft.irfftn(a=a, s=s, axes=axes, norm=norm)


def rfft(a, n=None, axis=-1, norm=None):
a = _remove_brainpy_array(a)
return Array(jax.numpy.fft.rfft(a=a, n=n, axis=axis, norm=norm))
a = _as_jax_array_(a)
return jax.numpy.fft.rfft(a=a, n=n, axis=axis, norm=norm)


def rfft2(a, s=None, axes=(-2, -1), norm=None):
a = _remove_brainpy_array(a)
return Array(jax.numpy.fft.rfft2(a=a, s=s, axes=axes, norm=norm))
a = _as_jax_array_(a)
return jax.numpy.fft.rfft2(a=a, s=s, axes=axes, norm=norm)


def rfftfreq(n, d=1.0):
return Array(jax.numpy.fft.rfftfreq(n=n, d=d))
return jax.numpy.fft.rfftfreq(n=n, d=d)


def rfftn(a, s=None, axes=None, norm=None):
a = _remove_brainpy_array(a)
return Array(jax.numpy.fft.rfftn(a=a, s=s, axes=axes, norm=norm))
a = _as_jax_array_(a)
return jax.numpy.fft.rfftn(a=a, s=s, axes=axes, norm=norm)
Loading