updates (#550)
* [running] fix multiprocessing bugs

* fix tests

* [doc] update doc

* update

* [math] add `brainpy.math.gpu_memory_preallocation()` for controlling GPU memory preallocation

* [math] `clear_buffer_memory` now supports clearing both device arrays and the compilation cache

* [dyn] keep compatibility with the old version of the `.reset_state()` function

* [setup] update installation info
chaoming0625 committed Nov 28, 2023
1 parent bd04b90 commit 6c599a7
Showing 7 changed files with 93 additions and 47 deletions.
52 changes: 29 additions & 23 deletions brainpy/_src/dynsys.py
@@ -2,8 +2,8 @@
 
 import collections
 import inspect
-import warnings
 import numbers
+import warnings
 from typing import Union, Dict, Callable, Sequence, Optional, Any
 
 import numpy as np
@@ -13,7 +13,7 @@
 from brainpy._src.deprecations import _update_deprecate_msg
 from brainpy._src.initialize import parameter, variable_
 from brainpy._src.mixin import SupportAutoDelay, Container, SupportInputProj, DelayRegister, _get_delay_tool
-from brainpy.errors import NoImplementationError, UnsupportedError, APIChangedError
+from brainpy.errors import NoImplementationError, UnsupportedError
 from brainpy.types import ArrayType, Shape
 
 __all__ = [
@@ -27,9 +27,9 @@
   'Dynamic', 'Projection',
 ]
 
-
 IonChaDyn = None
 SLICE_VARS = 'slice_vars'
+the_top_layer_reset_state = True
 
 
 def not_implemented(fun):
@@ -138,16 +138,12 @@ def update(self, *args, **kwargs):
     """
     raise NotImplementedError('Must implement the "update" function in a subclass.')
 
-  def reset(self, *args, include_self: bool = False, **kwargs):
+  def reset(self, *args, **kwargs):
     """Reset function which resets all variables in the model (including its children models).
 
     ``reset()`` is a collective behavior which resets all states in this model.
     See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details.
-    Args::
-      include_self: bool. Reset states including the node self. Please turn on this if the node has
-        implemented its ".reset_state()" function.
     """
     from brainpy._src.helpers import reset_state
     reset_state(self, *args, **kwargs)
@@ -162,19 +158,6 @@ def reset_state(self, *args, **kwargs):
     """
     pass
 
-    # raise APIChangedError(
-    #   '''
-    #   From version >= 2.4.6, the policy of ``.reset_state()`` has been changed.
-    #
-    #   1. If you are resetting all states in a network by calling "net.reset_state()", please use
-    #      "bp.reset_state(net)" function. ".reset_state()" only defines the resetting of local states
-    #      in a local node (excluded its children nodes).
-    #
-    #   2. If you does not customize "reset_state()" function for a local node, please implement it in your subclass.
-    #
-    #   '''
-    # )
 
   def clear_input(self, *args, **kwargs):
     """Clear the input at the current time step."""
     pass
@@ -344,14 +327,37 @@ def _compatible_update(self, *args, **kwargs):
       return ret
     return update_fun(*args, **kwargs)
 
+  def _compatible_reset_state(self, *args, **kwargs):
+    global the_top_layer_reset_state
+    the_top_layer_reset_state = False
+    try:
+      self.reset(*args, **kwargs)
+    finally:
+      the_top_layer_reset_state = True
+    warnings.warn(
+      '''
+From version >= 2.4.6, the policy of ``.reset_state()`` has been changed. See https://brainpy.tech/docs/tutorial_toolbox/state_saving_and_loading.html for details.
+
+1. If you are resetting all states in a network by calling "net.reset_state(*args, **kwargs)", please use
+   the "bp.reset_state(net, *args, **kwargs)" function, or "net.reset(*args, **kwargs)".
+   ".reset_state()" only defines the resetting of local states in a local node (excluding its children nodes).
+
+2. If you do not customize the "reset_state()" function for a local node, please implement it in your subclass.
+      ''',
+      DeprecationWarning
+    )
 
   def _get_update_fun(self):
     return object.__getattribute__(self, 'update')
 
   def __getattribute__(self, item):
     if item == 'update':
       return self._compatible_update  # update function compatible with the previous ``update()`` signature
-    else:
-      return super().__getattribute__(item)
+    if item == 'reset_state':
+      if the_top_layer_reset_state:
+        return self._compatible_reset_state  # reset_state function compatible with the previous ``reset_state()`` behavior
+    return super().__getattribute__(item)
 
   def __repr__(self):
     return f'{self.name}(mode={self.mode})'
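
Taken together, the changes above keep old-style ``net.reset_state()`` calls working (via ``_compatible_reset_state`` and the module-level ``the_top_layer_reset_state`` flag) while steering users to the new collective API. A minimal sketch of the intended usage, assuming a toy two-layer network (the ``bp.dyn.Expon`` layers and their sizes are illustrative placeholders, not part of this diff):

import brainpy as bp

class Net(bp.DynamicalSystem):
  def __init__(self):
    super().__init__()
    self.l1 = bp.dyn.Expon(10)  # placeholder child node
    self.l2 = bp.dyn.Expon(10)

  def update(self, x):
    return self.l2(self.l1(x))

net = Net()

# New policy (brainpy >= 2.4.6): reset the whole network collectively.
bp.reset_state(net)  # or equivalently: net.reset()

# Old-style top-level call: still works, but is rerouted through
# _compatible_reset_state(), which calls net.reset() and emits a
# DeprecationWarning pointing at the new API.
net.reset_state()
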
48 changes: 41 additions & 7 deletions brainpy/_src/math/environment.py
@@ -9,6 +9,7 @@
 import warnings
 from typing import Any, Callable, TypeVar, cast
 
+import jax
 from jax import config, numpy as jnp, devices
 from jax.lib import xla_bridge
 
@@ -682,7 +683,11 @@ def set_host_device_count(n):
   os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(n)] + xla_flags)
 
 
-def clear_buffer_memory(platform=None):
+def clear_buffer_memory(
+    platform: str = None,
+    array: bool = True,
+    compilation: bool = False
+):
   """Clear all on-device buffers.
 
   This function will be very useful when you call models in a Python loop,
@@ -697,18 +702,47 @@
   ----------
   platform: str
     The device whose memory should be cleared.
+  array: bool
+    Whether to clear all buffer arrays.
+  compilation: bool
+    Whether to clear the compilation cache.
 
   """
-  for buf in xla_bridge.get_backend(platform=platform).live_buffers():
-    buf.delete()
+  if array:
+    for buf in xla_bridge.get_backend(platform=platform).live_buffers():
+      buf.delete()
+  if compilation:
+    jax.clear_caches()
 
 
-def disable_gpu_memory_preallocation():
-  """Disable pre-allocating the GPU memory."""
+def disable_gpu_memory_preallocation(release_memory: bool = True):
+  """Disable pre-allocating the GPU memory.
+
+  This disables the preallocation behavior. JAX will instead allocate GPU memory as needed,
+  potentially decreasing the overall memory usage. However, this behavior is more prone to
+  GPU memory fragmentation, meaning a JAX program that uses most of the available GPU memory
+  may OOM with preallocation disabled.
+
+  Args:
+    release_memory: bool. Whether to release memory during the computation.
+  """
   os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
-  os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
+  if release_memory:
+    os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
 
 
 def enable_gpu_memory_preallocation():
   """Enable pre-allocating the GPU memory."""
   os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'true'
-  os.environ.pop('XLA_PYTHON_CLIENT_ALLOCATOR')
+  os.environ.pop('XLA_PYTHON_CLIENT_ALLOCATOR', None)
+
+
+def gpu_memory_preallocation(percent: float):
+  """Set the fraction of total GPU memory that JAX preallocates.
+
+  If preallocation is enabled, this makes JAX preallocate ``percent`` of the total GPU memory,
+  instead of the default 75%. Lowering the amount preallocated can fix OOMs that occur when the JAX program starts.
+  """
+  assert 0. <= percent < 1., f'GPU memory preallocation must be in [0., 1.). But we got {percent}.'
+  os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = str(percent)

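
The three helpers above compose naturally when a model is rebuilt and rerun many times in one process. A hedged sketch of typical usage (``run_one_trial`` is a hypothetical stand-in for your simulation; the environment-variable setters must run before JAX first touches the GPU):

import brainpy.math as bm

# Configure allocation behavior before any GPU computation happens.
bm.disable_gpu_memory_preallocation()  # allocate on demand, release via the 'platform' allocator
# Alternatively, keep preallocation but cap it at 40% of the GPU memory:
# bm.enable_gpu_memory_preallocation()
# bm.gpu_memory_preallocation(0.4)

for trial in range(100):
  run_one_trial(trial)  # hypothetical: build and run one model
  # Drop live device arrays *and* the jit compilation cache between trials,
  # so buffers from earlier trials cannot accumulate.
  bm.clear_buffer_memory(platform='gpu', array=True, compilation=True)
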
13 changes: 0 additions & 13 deletions brainpy/_src/mixin.py
@@ -519,19 +519,6 @@ def __subclasscheck__(self, subclass):
     return all([issubclass(subclass, cls) for cls in self.__bases__])
 
 
-class UnionType2(MixIn):
-  """Union type for multiple types.
-
-  >>> import brainpy as bp
-  >>>
-  >>> isinstance(bp.dyn.Expon(1), JointType[bp.DynamicalSystem, bp.mixin.ParamDesc, bp.mixin.SupportAutoDelay])
-  """
-
-  @classmethod
-  def __class_getitem__(cls, types: Union[type, Sequence[type]]) -> type:
-    return _MetaUnionType('UnionType', types, {})
-
-
 if sys.version_info.minor > 8:
   class _JointGenericAlias(_UnionGenericAlias, _root=True):
     def __subclasscheck__(self, subclass):
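
``UnionType2`` duplicated what ``JointType`` already provides, so it is removed. The surviving pattern, adapted from the deleted docstring (assuming ``JointType`` is exposed as ``bp.mixin.JointType`` alongside ``ParamDesc`` and ``SupportAutoDelay``):

import brainpy as bp

# An object satisfies a JointType when it is an instance of every member type.
obj = bp.dyn.Expon(1)
print(isinstance(obj, bp.mixin.JointType[bp.DynamicalSystem,
                                         bp.mixin.ParamDesc,
                                         bp.mixin.SupportAutoDelay]))  # True
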
1 change: 1 addition & 0 deletions brainpy/math/environment.py
@@ -30,6 +30,7 @@
   clear_buffer_memory as clear_buffer_memory,
   enable_gpu_memory_preallocation as enable_gpu_memory_preallocation,
   disable_gpu_memory_preallocation as disable_gpu_memory_preallocation,
+  gpu_memory_preallocation as gpu_memory_preallocation,
   ditype as ditype,
   dftype as dftype,
 )
2 changes: 1 addition & 1 deletion docs/tutorial_advanced/operator_custom_with_numba.ipynb
@@ -6,7 +6,7 @@
    "collapsed": true
   },
   "source": [
-    "# Operator Customization with Numba"
+    "# CPU Operator Customization with Numba"
   ]
  },
  {
11 changes: 10 additions & 1 deletion docs/tutorial_advanced/operator_custom_with_taichi.ipynb
@@ -4,9 +4,18 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-    "# Operator Customization with Taichi"
+    "# CPU and GPU Operator Customization with Taichi"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "source": [
+   "This functionality is only available for ``brainpylib>=0.2.0``. "
+  ],
+  "metadata": {
+   "collapsed": false
+  }
+ },
  {
   "cell_type": "markdown",
   "metadata": {},
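
Because the Taichi path needs ``brainpylib>=0.2.0``, the notebook could open with a version guard. A small sketch, assuming ``brainpylib`` exposes ``__version__`` and the ``packaging`` library is installed (both are assumptions, not part of this diff):

from packaging.version import Version  # assumption: packaging is available
import brainpylib

# The Taichi customization described in this notebook assumes brainpylib >= 0.2.0.
assert Version(brainpylib.__version__) >= Version('0.2.0'), \
    f'This notebook needs brainpylib>=0.2.0, found {brainpylib.__version__}'
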
13 changes: 11 additions & 2 deletions setup.py
@@ -39,7 +39,6 @@
 # installation packages
 packages = find_packages(exclude=['lib*', 'docs', 'tests'])
 
-
 # setup
 setup(
   name='brainpy',
@@ -51,13 +50,23 @@
   author_email='chao.brain@qq.com',
   packages=packages,
   python_requires='>=3.8',
-  install_requires=['numpy>=1.15', 'jax', 'tqdm', 'msgpack', 'numba'],
+  install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm', 'msgpack', 'numba'],
   url='https://github.com/brainpy/BrainPy',
   project_urls={
     "Bug Tracker": "https://github.com/brainpy/BrainPy/issues",
     "Documentation": "https://brainpy.readthedocs.io/",
     "Source Code": "https://github.com/brainpy/BrainPy",
   },
+  dependency_links=[
+    'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html',
+  ],
+  extras_require={
+    'cpu': ['jaxlib>=0.4.13', 'brainpylib'],
+    'cuda': ['jax[cuda]', 'brainpylib-cu11x'],
+    'cuda11': ['jax[cuda11_local]', 'brainpylib-cu11x'],
+    'cuda12': ['jax[cuda12_local]', 'brainpylib-cu12x'],
+    'tpu': ['jax[tpu]'],
+  },
   keywords=('computational neuroscience, '
             'brain-inspired computation, '
             'dynamical systems, '
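
In practice, the new ``extras_require`` groups map to install commands along these lines (shown as comments because per-platform wheel availability is an assumption, not part of this diff):

# pip install brainpy[cpu]     # CPU: jaxlib>=0.4.13 + brainpylib
# pip install brainpy[cuda11]  # CUDA 11: jax[cuda11_local] + brainpylib-cu11x
# pip install brainpy[cuda12]  # CUDA 12: jax[cuda12_local] + brainpylib-cu12x
# pip install brainpy[tpu]     # TPU: jax[tpu]
# CUDA builds of jax come from the extra wheel index listed in dependency_links.
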
