From 47b61deaa2af970df2d3ede1bfecb10d7ce111cf Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 16 Oct 2022 22:56:11 +0800 Subject: [PATCH 1/2] Support JAX transformation contexts for all JaxArray --- brainpy/math/autograd.py | 10 +- brainpy/math/controls.py | 84 ++++++----- brainpy/math/jaxarray.py | 131 ++++++++++++------ brainpy/math/jit.py | 22 +-- .../math/tests/test_transformation_context.py | 47 +++++++ 5 files changed, 204 insertions(+), 90 deletions(-) create mode 100644 brainpy/math/tests/test_transformation_context.py diff --git a/brainpy/math/autograd.py b/brainpy/math/autograd.py index 02c028da7..200514e72 100644 --- a/brainpy/math/autograd.py +++ b/brainpy/math/autograd.py @@ -16,7 +16,9 @@ from jax.util import safe_map from brainpy import errors -from brainpy.math.jaxarray import JaxArray +from brainpy.base.naming import get_unique_name +from brainpy.math.jaxarray import JaxArray, add_context, del_context + __all__ = [ 'grad', # gradient of scalar function @@ -28,20 +30,26 @@ def _make_cls_call_func(grad_func, grad_tree, grad_vars, dyn_vars, argnums, return_value, has_aux): + name = get_unique_name('_brainpy_object_oriented_grad_') + # outputs def call_func(*args, **kwargs): old_grad_vs = [v.value for v in grad_vars] old_dyn_vs = [v.value for v in dyn_vars] try: + add_context(name) grads, (outputs, new_grad_vs, new_dyn_vs) = grad_func(old_grad_vs, old_dyn_vs, *args, **kwargs) + del_context(name) except UnexpectedTracerError as e: + del_context(name) for v, d in zip(grad_vars, old_grad_vs): v._value = d for v, d in zip(dyn_vars, old_dyn_vs): v._value = d raise errors.JaxTracerError(variables=dyn_vars + grad_vars) from e except Exception as e: + del_context(name) for v, d in zip(grad_vars, old_grad_vs): v._value = d for v, d in zip(dyn_vars, old_dyn_vs): v._value = d raise e diff --git a/brainpy/math/controls.py b/brainpy/math/controls.py index ab1d1b923..efbee4b5d 100644 --- a/brainpy/math/controls.py +++ b/brainpy/math/controls.py @@ -13,9 +13,10 @@ from jax.core import UnexpectedTracerError from brainpy import errors +from brainpy.base.naming import get_unique_name from brainpy.math.jaxarray import (JaxArray, Variable, - turn_on_global_jit, - turn_off_global_jit) + add_context, + del_context) from brainpy.math.numpy_ops import as_device_array __all__ = [ @@ -158,17 +159,19 @@ def make_loop(body_fun, dyn_vars, out_vars=None, has_return=False): out_vars=out_vars, has_return=has_return) + name = get_unique_name('_brainpy_object_oriented_make_loop_') + # functions if has_return: def call(xs=None, length=None): init_values = [v.value for v in dyn_vars] try: - turn_on_global_jit() + add_context(name) dyn_values, (out_values, results) = lax.scan( f=fun2scan, init=init_values, xs=xs, length=length) - turn_off_global_jit() + del_context(name) except UnexpectedTracerError as e: - turn_off_global_jit() + del_context(name) for v, d in zip(dyn_vars, init_values): v._value = d raise errors.JaxTracerError(variables=dyn_vars) from e for v, d in zip(dyn_vars, dyn_values): v._value = d @@ -178,15 +181,15 @@ def call(xs=None, length=None): def call(xs): init_values = [v.value for v in dyn_vars] try: - turn_on_global_jit() + add_context(name) dyn_values, out_values = lax.scan(f=fun2scan, init=init_values, xs=xs) - turn_off_global_jit() + del_context(name) except UnexpectedTracerError as e: - turn_off_global_jit() + del_context(name) for v, d in zip(dyn_vars, init_values): v._value = d raise errors.JaxTracerError(variables=dyn_vars) from e except Exception as e: - turn_off_global_jit() + del_context(name) for v, d in zip(dyn_vars, init_values): v._value = d raise e for v, d in zip(dyn_vars, dyn_values): v._value = d @@ -255,20 +258,22 @@ def _cond_fun(op): for v, d in zip(dyn_vars, dyn_values): v._value = d return as_device_array(cond_fun(static_values)) + name = get_unique_name('_brainpy_object_oriented_make_while_') + def call(x=None): dyn_init = [v.value for v in dyn_vars] try: - turn_on_global_jit() + add_context(name) dyn_values, _ = lax.while_loop(cond_fun=_cond_fun, body_fun=_body_fun, init_val=(dyn_init, x)) - turn_off_global_jit() + del_context(name) except UnexpectedTracerError as e: - turn_off_global_jit() + del_context(name) for v, d in zip(dyn_vars, dyn_init): v._value = d raise errors.JaxTracerError(variables=dyn_vars) from e except Exception as e: - turn_off_global_jit() + del_context(name) for v, d in zip(dyn_vars, dyn_init): v._value = d raise e for v, d in zip(dyn_vars, dyn_values): v._value = d @@ -330,6 +335,8 @@ def make_cond(true_fun, false_fun, dyn_vars=None): if not isinstance(v, JaxArray): raise ValueError(f'Only support {JaxArray.__name__}, but got {type(v)}') + name = get_unique_name('_brainpy_object_oriented_make_cond_') + if len(dyn_vars) > 0: def _true_fun(op): dyn_vals, static_vals = op @@ -348,15 +355,15 @@ def _false_fun(op): def call(pred, x=None): old_values = [v.value for v in dyn_vars] try: - turn_on_global_jit() + add_context(name) dyn_values, res = lax.cond(pred, _true_fun, _false_fun, (old_values, x)) - turn_off_global_jit() + del_context(name) except UnexpectedTracerError as e: - turn_off_global_jit() + del_context(name) for v, d in zip(dyn_vars, old_values): v._value = d raise errors.JaxTracerError(variables=dyn_vars) from e except Exception as e: - turn_off_global_jit() + del_context(name) for v, d in zip(dyn_vars, old_values): v._value = d raise e for v, d in zip(dyn_vars, dyn_values): v._value = d @@ -364,9 +371,9 @@ def call(pred, x=None): else: def call(pred, x=None): - turn_on_global_jit() + add_context(name) res = lax.cond(pred, true_fun, false_fun, x) - turn_off_global_jit() + del_context(name) return res return call @@ -445,6 +452,8 @@ def cond( if not isinstance(v, Variable): raise ValueError(f'Only support {Variable.__name__}, but got {type(v)}') + name = get_unique_name('_brainpy_object_oriented_cond_') + # calling the model if len(dyn_vars) > 0: def _true_fun(op): @@ -463,25 +472,25 @@ def _false_fun(op): old_values = [v.value for v in dyn_vars] try: - turn_on_global_jit() + add_context(name) dyn_values, res = lax.cond(pred=pred, true_fun=_true_fun, false_fun=_false_fun, operand=(old_values, operands)) - turn_off_global_jit() + del_context(name) except UnexpectedTracerError as e: - turn_off_global_jit() + del_context(name) for v, d in zip(dyn_vars, old_values): v._value = d raise errors.JaxTracerError(variables=dyn_vars) from e except Exception as e: - turn_off_global_jit() + del_context(name) for v, d in zip(dyn_vars, old_values): v._value = d raise e for v, d in zip(dyn_vars, dyn_values): v._value = d else: - turn_on_global_jit() + add_context(name) res = lax.cond(pred, true_fun, false_fun, operands) - turn_off_global_jit() + del_context(name) return res @@ -591,7 +600,11 @@ def ifelse( if show_code: print(codes) exec(compile(codes.strip(), '', 'exec'), code_scope) f = code_scope['f'] - return f(operands) + name = get_unique_name('_brainpy_object_oriented_ifelse_') + add_context(name) + r = f(operands) + del_context(name) + return r def for_loop(body_fun: Callable, @@ -694,22 +707,24 @@ def fun2scan(dyn_vals, x): results = body_fun(*x) return [v.value for v in dyn_vars], results + name = get_unique_name('_brainpy_object_oriented_for_loop_') + # functions init_vals = [v.value for v in dyn_vars] try: - turn_on_global_jit() + add_context(name) dyn_vals, out_vals = lax.scan(f=fun2scan, init=init_vals, xs=operands, reverse=reverse, unroll=unroll) - turn_off_global_jit() + del_context(name) except UnexpectedTracerError as e: - turn_off_global_jit() + del_context(name) for v, d in zip(dyn_vars, init_vals): v._value = d raise errors.JaxTracerError(variables=dyn_vars) from e except Exception as e: - turn_off_global_jit() + del_context(name) for v, d in zip(dyn_vars, init_vals): v._value = d raise e for v, d in zip(dyn_vars, dyn_vals): v._value = d @@ -797,19 +812,20 @@ def _cond_fun(op): r = cond_fun(*static_vals) return r if isinstance(r, JaxArray) else r + name = get_unique_name('_brainpy_object_oriented_while_loop_') dyn_init = [v.value for v in dyn_vars] try: - turn_on_global_jit() + add_context(name) dyn_values, out = lax.while_loop(cond_fun=_cond_fun, body_fun=_body_fun, init_val=(dyn_init, operands)) - turn_off_global_jit() + del_context(name) except UnexpectedTracerError as e: - turn_off_global_jit() + del_context(name) for v, d in zip(dyn_vars, dyn_init): v._value = d raise errors.JaxTracerError(variables=dyn_vars) from e except Exception as e: - turn_off_global_jit() + del_context(name) for v, d in zip(dyn_vars, dyn_init): v._value = d raise e for v, d in zip(dyn_vars, dyn_values): v._value = d diff --git a/brainpy/math/jaxarray.py b/brainpy/math/jaxarray.py index 02b79d381..640376c8a 100644 --- a/brainpy/math/jaxarray.py +++ b/brainpy/math/jaxarray.py @@ -33,20 +33,53 @@ msg = ('JaxArray created outside of the jit function ' 'cannot be updated in JIT mode. You should ' 'mark it as brainpy.math.Variable instead.') -_global_jit_mode = False +_jax_transformation_context_ = [] -def turn_on_global_jit(): - """Turn on the global JIT mode to declare - all instantiated JaxArray cannot be updated.""" - global _global_jit_mode - _global_jit_mode = True +def add_context(name): + _jax_transformation_context_.append(name) -def turn_off_global_jit(): - """Turn off the global JIT mode.""" - global _global_jit_mode - _global_jit_mode = False + +def del_context(name=None): + try: + context = _jax_transformation_context_.pop(-1) + if name is not None: + if context != name: + raise MathError('Transformation context is different!') + # warnings.warn(, UserWarning) + except IndexError: + raise MathError('No transformation context!') + # warnings.warn('No transformation context!', UserWarning) + + +def get_context(): + if len(_jax_transformation_context_) > 0: + return _jax_transformation_context_[-1] + else: + return None + + +def check_context(arr_context): + if arr_context is None: + if len(_jax_transformation_context_) > 0: + raise MathError(f'JaxArray created outside of the transformation functions ' + f'({_jax_transformation_context_[-1]}) cannot be updated. ' + f'You should mark it as a Variable instead.') + return True + else: + return False + else: + if len(_jax_transformation_context_) > 0: + if arr_context != _jax_transformation_context_[-1]: + raise MathError(f'JaxArray context "{arr_context}" differs from the JAX ' + f'transformation context "{_jax_transformation_context_[-1]}"' + '\n\n' + 'JaxArray created outside of the transformation functions ' + 'cannot be updated. You should mark it as a Variable instead.') + return True + else: + return False def _check_input_array(array): @@ -61,7 +94,7 @@ def _check_input_array(array): class JaxArray(object): """Multiple-dimensional array in JAX backend. """ - __slots__ = ("_value", "_outside_global_jit") + __slots__ = ("_value", "_transform_context") def __init__(self, value, dtype=None): # array value @@ -73,7 +106,7 @@ def __init__(self, value, dtype=None): value = jnp.asarray(value, dtype=dtype) self._value = value # jit mode - self._outside_global_jit = False if _global_jit_mode else True + self._transform_context = get_context() @property def value(self): @@ -86,7 +119,7 @@ def value(self, value): def update(self, value): """Update the value of this JaxArray. """ - if self._outside_global_jit and _global_jit_mode: + if check_context(self._transform_context): raise MathError(msg) if isinstance(value, JaxArray): value = value.value @@ -189,7 +222,7 @@ def __getitem__(self, index): return self.value[index] def __setitem__(self, index, value): - if self._outside_global_jit and _global_jit_mode: + if check_context(self._transform_context): raise MathError(msg) # value is JaxArray @@ -260,7 +293,7 @@ def __radd__(self, oc): def __iadd__(self, oc): # a += b - if self._outside_global_jit and _global_jit_mode: + if check_context(self._transform_context): raise MathError(msg) self._value += _check_input_array(oc) return self @@ -273,7 +306,7 @@ def __rsub__(self, oc): def __isub__(self, oc): # a -= b - if self._outside_global_jit and _global_jit_mode: + if check_context(self._transform_context): raise MathError(msg) self._value = self._value - _check_input_array(oc) return self @@ -286,7 +319,7 @@ def __rmul__(self, oc): def __imul__(self, oc): # a *= b - if self._outside_global_jit and _global_jit_mode: + if check_context(self._transform_context): raise MathError(msg) self._value = self._value * _check_input_array(oc) return self @@ -302,7 +335,7 @@ def __rtruediv__(self, oc): def __itruediv__(self, oc): # a /= b - if self._outside_global_jit and _global_jit_mode: + if check_context(self._transform_context): raise MathError(msg) self._value = self._value / _check_input_array(oc) return self @@ -315,7 +348,7 @@ def __rfloordiv__(self, oc): def __ifloordiv__(self, oc): # a //= b - if self._outside_global_jit and _global_jit_mode: + if check_context(self._transform_context): raise MathError(msg) self._value = self._value // _check_input_array(oc) return self @@ -334,7 +367,7 @@ def __rmod__(self, oc): def __imod__(self, oc): # a %= b - if self._outside_global_jit and _global_jit_mode: + if check_context(self._transform_context): raise MathError(msg) self._value = self._value % _check_input_array(oc) return self @@ -347,7 +380,7 @@ def __rpow__(self, oc): def __ipow__(self, oc): # a **= b - if self._outside_global_jit and _global_jit_mode: + if check_context(self._transform_context): raise MathError(msg) self._value = self._value ** _check_input_array(oc) return self @@ -360,7 +393,7 @@ def __rmatmul__(self, oc): def __imatmul__(self, oc): # a @= b - if self._outside_global_jit and _global_jit_mode: + if check_context(self._transform_context): raise MathError(msg) self._value = self._value @ _check_input_array(oc) return self @@ -373,7 +406,7 @@ def __rand__(self, oc): def __iand__(self, oc): # a &= b - if self._outside_global_jit and _global_jit_mode: + if check_context(self._transform_context): raise MathError(msg) self._value = self._value & _check_input_array(oc) return self @@ -386,7 +419,7 @@ def __ror__(self, oc): def __ior__(self, oc): # a |= b - if self._outside_global_jit and _global_jit_mode: + if check_context(self._transform_context): raise MathError(msg) self._value = self._value | _check_input_array(oc) return self @@ -399,7 +432,7 @@ def __rxor__(self, oc): def __ixor__(self, oc): # a ^= b - if self._outside_global_jit and _global_jit_mode: + if check_context(self._transform_context): raise MathError(msg) self._value = self._value ^ _check_input_array(oc) return self @@ -412,7 +445,7 @@ def __rlshift__(self, oc): def __ilshift__(self, oc): # a <<= b - if self._outside_global_jit and _global_jit_mode: + if check_context(self._transform_context): raise MathError(msg) self._value = self._value << _check_input_array(oc) return self @@ -425,7 +458,7 @@ def __rrshift__(self, oc): def __irshift__(self, oc): # a >>= b - if self._outside_global_jit and _global_jit_mode: + if check_context(self._transform_context): raise MathError(msg) self._value = self._value >> _check_input_array(oc) return self @@ -547,7 +580,7 @@ def dot(self, b): def fill(self, value): """Fill the array with a scalar value.""" - if self._outside_global_jit and _global_jit_mode: + if check_context(self._transform_context): raise MathError(msg) self._value = jnp.ones_like(self.value) * value @@ -675,7 +708,7 @@ def sort(self, axis=-1, kind='quicksort', order=None): but unspecified fields will still be used, in the order in which they come up in the dtype, to break ties. """ - if self._outside_global_jit and _global_jit_mode: + if check_context(self._transform_context): raise MathError(msg) self._value = self.value.sort(axis=axis, kind=kind, order=order) @@ -1513,23 +1546,6 @@ def __init__(self, value_or_size, dtype=None, batch_axis: int = None): super(Parameter, self).__init__(value_or_size, dtype=dtype, batch_axis=batch_axis) -register_pytree_node(JaxArray, - lambda t: ((t.value,), None), - lambda aux_data, flat_contents: JaxArray(*flat_contents)) - -register_pytree_node(Variable, - lambda t: ((t.value,), None), - lambda aux_data, flat_contents: Variable(*flat_contents)) - -register_pytree_node(TrainVar, - lambda t: ((t.value,), None), - lambda aux_data, flat_contents: TrainVar(*flat_contents)) - -register_pytree_node(Parameter, - lambda t: ((t.value,), None), - lambda aux_data, flat_contents: Parameter(*flat_contents)) - - class VariableView(Variable): """A view of a Variable instance. @@ -1559,6 +1575,7 @@ class VariableView(Variable): Moreover, it's worthy to note that ``VariableView`` is not a PyTree. """ + def __init__(self, value: Variable, index): self.index = index if not isinstance(value, Variable): @@ -1700,3 +1717,25 @@ def value(self, value): f"while we got {value.dtype}.") self._value[self.index] = value.value if isinstance(value, JaxArray) else value + +def _jaxarray_unflatten(aux_data, flat_contents): + r = JaxArray(*flat_contents) + r._transform_context = aux_data[0] + return r + + +register_pytree_node(JaxArray, + lambda t: ((t.value,), (t._transform_context, )), + _jaxarray_unflatten) + +register_pytree_node(Variable, + lambda t: ((t.value,), None), + lambda aux_data, flat_contents: Variable(*flat_contents)) + +register_pytree_node(TrainVar, + lambda t: ((t.value,), None), + lambda aux_data, flat_contents: TrainVar(*flat_contents)) + +register_pytree_node(Parameter, + lambda t: ((t.value,), None), + lambda aux_data, flat_contents: Parameter(*flat_contents)) diff --git a/brainpy/math/jit.py b/brainpy/math/jit.py index 9e22d7dd0..01836de2c 100644 --- a/brainpy/math/jit.py +++ b/brainpy/math/jit.py @@ -15,12 +15,13 @@ try: from jax.errors import UnexpectedTracerError, ConcretizationTypeError except ImportError: - from jax.core import UnexpectedTracerError + from jax.core import UnexpectedTracerError, ConcretizationTypeError from brainpy import errors from brainpy.base.base import Base +from brainpy.base.naming import get_unique_name from brainpy.base.collector import TensorCollector -from brainpy.math.jaxarray import JaxArray, turn_on_global_jit, turn_off_global_jit +from brainpy.math.jaxarray import JaxArray, add_context, del_context from brainpy.tools.codes import change_func_name __all__ = [ @@ -38,22 +39,24 @@ def jitted_func(variable_data, *args, **kwargs): changes = vars.dict() return out, changes + name = get_unique_name('_brainpy_object_oriented_jit_') + def call(*args, **kwargs): variable_data = vars.dict() try: - turn_on_global_jit() + add_context(name) out, changes = jitted_func(variable_data, *args, **kwargs) - turn_off_global_jit() + del_context(name) except UnexpectedTracerError as e: - turn_off_global_jit() + del_context(name) for key, v in vars.items(): v._value = variable_data[key] raise errors.JaxTracerError(variables=vars) from e except ConcretizationTypeError as e: - turn_off_global_jit() + del_context(name) for key, v in vars.items(): v._value = variable_data[key] raise errors.ConcretizationTypeError() from e except Exception as e: - turn_off_global_jit() + del_context(name) for key, v in vars.items(): v._value = variable_data[key] raise e for key, v in vars.items(): v._value = changes[key] @@ -64,11 +67,12 @@ def call(*args, **kwargs): def _make_jit_without_vars(func, static_argnames=None, device=None, f_name=None): jit_f = jax.jit(func, static_argnames=static_argnames, device=device) + name = get_unique_name('_jax_functional_jit_') def call(*args, **kwargs): - turn_on_global_jit() + add_context(name) r = jit_f(*args, **kwargs) - turn_off_global_jit() + del_context(name) return r return change_func_name(name=f_name, f=call) if f_name else call diff --git a/brainpy/math/tests/test_transformation_context.py b/brainpy/math/tests/test_transformation_context.py new file mode 100644 index 000000000..26ca0c862 --- /dev/null +++ b/brainpy/math/tests/test_transformation_context.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- + + +import unittest + +import brainpy as bp +import brainpy.math as bm + + +class TestJIT(unittest.TestCase): + def test1(self): + @bm.jit + def f1(a): + a[:] = 1. + return a + + a = bm.zeros(10) + with self.assertRaises(bp.errors.MathError): + print(f1(a)) + + def test2(self): + @bm.jit + def f1(a): + b = a + 1 + + @bm.jit + def f2(x): + x.value = 1. + return x + + return f2(b) + + with self.assertRaises(bp.errors.MathError): + print(f1(bm.ones(2))) + + def test3(self): + @bm.jit + def f1(a): + return a + 1 + + @bm.jit + def f2(b): + b[:] = 1. + return b + + with self.assertRaises(bp.errors.MathError): + print(f2(f1(bm.ones(2)))) From a7e505302cb3b2b7e70b63e563f445f77d41a2f7 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 16 Oct 2022 23:09:43 +0800 Subject: [PATCH 2/2] Update JaxArray transformation error message --- brainpy/math/jaxarray.py | 7 ++++--- brainpy/math/tests/test_transformation_context.py | 9 +++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/brainpy/math/jaxarray.py b/brainpy/math/jaxarray.py index 640376c8a..3ab28adcd 100644 --- a/brainpy/math/jaxarray.py +++ b/brainpy/math/jaxarray.py @@ -65,7 +65,7 @@ def check_context(arr_context): if len(_jax_transformation_context_) > 0: raise MathError(f'JaxArray created outside of the transformation functions ' f'({_jax_transformation_context_[-1]}) cannot be updated. ' - f'You should mark it as a Variable instead.') + f'You should mark it as a brainpy.math.Variable instead.') return True else: return False @@ -75,8 +75,9 @@ def check_context(arr_context): raise MathError(f'JaxArray context "{arr_context}" differs from the JAX ' f'transformation context "{_jax_transformation_context_[-1]}"' '\n\n' - 'JaxArray created outside of the transformation functions ' - 'cannot be updated. You should mark it as a Variable instead.') + 'JaxArray created in one transformation function ' + 'cannot be updated another transformation function. ' + 'You should mark it as a brainpy.math.Variable instead.') return True else: return False diff --git a/brainpy/math/tests/test_transformation_context.py b/brainpy/math/tests/test_transformation_context.py index 26ca0c862..2732afa83 100644 --- a/brainpy/math/tests/test_transformation_context.py +++ b/brainpy/math/tests/test_transformation_context.py @@ -45,3 +45,12 @@ def f2(b): with self.assertRaises(bp.errors.MathError): print(f2(f1(bm.ones(2)))) + + def test4(self): + @bm.jit + def f2(a): + b = bm.ones(1) + b += 10 + return a + b + + print(f2(bm.ones(1)))