Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions brainpy/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@
from . import surrogate
from .surrogate.compt import *

# JAX transformations for Variable and class objects
# Variable and Objects for object-oriented JAX transformations
from .object_transform import *


# environment settings
from .modes import *
from .environment import *
Expand Down
9 changes: 5 additions & 4 deletions brainpy/math/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@


import warnings
from typing import Optional, Tuple
from typing import Optional, Tuple as TupleType

import numpy as np
from jax import numpy as jnp
from jax.tree_util import register_pytree_node


from brainpy.errors import MathError

__all__ = [
Expand Down Expand Up @@ -997,7 +998,7 @@ def __init__(
f'but the batch axis is set to be {batch_axis}.')

@property
def shape_nb(self) -> Tuple[int, ...]:
def shape_nb(self) -> TupleType[int, ...]:
"""Shape without batch axis."""
shape = list(self.value.shape)
if self.batch_axis is not None:
Expand Down Expand Up @@ -1562,7 +1563,6 @@ class BatchVariable(Variable):
pass



class VariableView(Variable):
"""A view of a Variable instance.

Expand Down Expand Up @@ -1742,7 +1742,7 @@ def _jaxarray_unflatten(aux_data, flat_contents):


register_pytree_node(Array,
lambda t: ((t.value,), (t._transform_context, )),
lambda t: ((t.value,), (t._transform_context,)),
_jaxarray_unflatten)

register_pytree_node(Variable,
Expand All @@ -1756,3 +1756,4 @@ def _jaxarray_unflatten(aux_data, flat_contents):
register_pytree_node(Parameter,
lambda t: ((t.value,), None),
lambda aux_data, flat_contents: Parameter(*flat_contents))

6 changes: 6 additions & 0 deletions brainpy/math/object_transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,15 @@
+ controls.__all__
+ jit.__all__
+ function.__all__
+ base_object.__all__
+ base_transform.__all__
+ collector.__all__
)

from .autograd import *
from .controls import *
from .jit import *
from .function import *
from .base_object import *
from .base_transform import *
from .collector import *
131 changes: 119 additions & 12 deletions brainpy/math/object_transform/base_object.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,39 @@
# -*- coding: utf-8 -*-

import os
import logging
import warnings
from collections import namedtuple
from typing import Any, Tuple, Callable, Sequence, Dict, Union

import jax
import numpy as np
from jax._src.tree_util import _registry
from jax.tree_util import register_pytree_node
from jax.tree_util import register_pytree_node_class
from jax.util import safe_zip

from brainpy import errors
from .collector import Collector, ArrayCollector
from ..ndarray import Variable, VariableView, TrainVar
from ..ndarray import (Array,
Variable,
VariableView,
TrainVar)

StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])


__all__ = [
'check_name_uniqueness',
'get_unique_name',
'clear_name_cache',
# naming
'check_name_uniqueness', 'get_unique_name', 'clear_name_cache',

# objects
'BrainPyObject', 'Base', 'FunAsObject',

# variables
'numerical_seq', 'object_seq',
'numerical_dict', 'object_dict',
]

logger = logging.getLogger('brainpy.brainpy_object')

_name2id = dict()
_typed_names = {}
Expand Down Expand Up @@ -59,7 +72,7 @@ def clear_name_cache(ignore_warn=False):
_name2id.clear()
_typed_names.clear()
if not ignore_warn:
logger.warning(f'All named models and their ids are cleared.')
warnings.warn(f'All named models and their ids are cleared.', UserWarning)


class BrainPyObject(object):
Expand All @@ -78,6 +91,11 @@ class BrainPyObject(object):
_excluded_vars = ()

def __init__(self, name=None):
super().__init__()
cls = self.__class__
if cls not in _registry:
register_pytree_node_class(cls)

# check whether the object has a unique name.
self._name = None
self._name = self.unique_name(name=name)
Expand All @@ -91,15 +109,17 @@ def __init__(self, name=None):
# which cannot be accessed by self.xxx
self.implicit_nodes = Collector()

def __setattr__(self, key, value) -> None:
"""Overwrite __setattr__ method for non-changeable Variable setting.
def __setattr__(self, key: str, value: Any) -> None:
"""Overwrite `__setattr__` method for change Variable values.

.. versionadded:: 2.3.1

Parameters
----------
key: str
The attribute.
value: Any
The value.
"""
if key in self.__dict__:
val = self.__dict__[key]
Expand All @@ -109,19 +129,24 @@ def __setattr__(self, key, value) -> None:
super().__setattr__(key, value)

def tree_flatten(self):
"""
"""Flattens the object as a PyTree.

The flattening order is determined by attributes added order.

.. versionadded:: 2.3.1

Returns
-------

res: tuple
A tuple of dynamical values and static values.
"""
dts = (BrainPyObject,) + tuple(dynamical_types)
dynamic_names = []
dynamic_values = []
static_names = []
static_values = []
for k, v in self.__dict__.items():
if isinstance(v, (ArrayCollector, BrainPyObject, Variable)):
if isinstance(v, dts):
dynamic_names.append(k)
dynamic_values.append(v)
else:
Expand Down Expand Up @@ -531,3 +556,85 @@ def __repr__(self) -> str:
node_string = ", \n".join(nodes)
return (f'{name}(nodes=[{node_string}],\n' +
" " * (len(name) + 1) + f'num_of_vars={len(self.implicit_vars)})')


class numerical_seq(list):
"""A list to represent a dynamically changed numerical
sequence in which its element can be changed during JIT compilation.

.. note::
The element must be numerical, like ``bool``, ``int``, ``float``,
``jax.Array``, ``numpy.ndarray``, ``brainpy.math.Array``.
"""
def append(self, element):
if not isinstance(element, (bool, int, float, jax.Array, Array, np.ndarray)):
raise TypeError(f'Each element should be a numerical value.')

def extend(self, iterable) -> None:
for element in iterable:
self.append(element)


register_pytree_node(numerical_seq,
lambda x: (tuple(x), ()),
lambda _, values: numerical_seq(values))


class object_seq(list):
"""A list to represent a sequence of :py:class:`~.BrainPyObject`.

.. note::
The element must be :py:class:`~.BrainPyObject`.
"""
def append(self, element):
if not isinstance(element, BrainPyObject):
raise TypeError(f'Only support {BrainPyObject.__name__}')

def extend(self, iterable) -> None:
for element in iterable:
self.append(element)


register_pytree_node(object_seq,
lambda x: (tuple(x), ()),
lambda _, values: object_seq(values))


class numerical_dict(dict):
"""A dict to represent a dynamically changed numerical
dictionary in which its element can be changed during JIT compilation.

.. note::
Each key must be a string, and each value must be numerical, including
``bool``, ``int``, ``float``, ``jax.Array``, ``numpy.ndarray``,
``brainpy.math.Array``.
"""
def update(self, *args, **kwargs) -> 'numerical_dict':
super().update(*args, **kwargs)
return self


register_pytree_node(numerical_dict,
lambda x: (tuple(x.values()), tuple(x.keys())),
lambda keys, values: numerical_dict(safe_zip(keys, values)))


class object_dict(dict):
"""A dict to represent a dictionary of :py:class:`~.BrainPyObject`.

.. note::
Each key must be a string, and each value must be :py:class:`~.BrainPyObject`.
"""
def update(self, *args, **kwargs) -> 'object_dict':
super().update(*args, **kwargs)
return self


register_pytree_node(object_dict,
lambda x: (tuple(x.values()), tuple(x.keys())),
lambda keys, values: object_dict(safe_zip(keys, values)))

dynamical_types = [Variable,
numerical_seq, numerical_dict,
object_seq, object_dict]

Loading