Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion brainpy/math/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from jax.util import safe_map

from brainpy import errors
from brainpy.math.jaxarray import JaxArray
from brainpy.base.naming import get_unique_name
from brainpy.math.jaxarray import JaxArray, add_context, del_context


__all__ = [
'grad', # gradient of scalar function
Expand All @@ -28,20 +30,26 @@

def _make_cls_call_func(grad_func, grad_tree, grad_vars, dyn_vars,
argnums, return_value, has_aux):
name = get_unique_name('_brainpy_object_oriented_grad_')

# outputs
def call_func(*args, **kwargs):
old_grad_vs = [v.value for v in grad_vars]
old_dyn_vs = [v.value for v in dyn_vars]
try:
add_context(name)
grads, (outputs, new_grad_vs, new_dyn_vs) = grad_func(old_grad_vs,
old_dyn_vs,
*args,
**kwargs)
del_context(name)
except UnexpectedTracerError as e:
del_context(name)
for v, d in zip(grad_vars, old_grad_vs): v._value = d
for v, d in zip(dyn_vars, old_dyn_vs): v._value = d
raise errors.JaxTracerError(variables=dyn_vars + grad_vars) from e
except Exception as e:
del_context(name)
for v, d in zip(grad_vars, old_grad_vs): v._value = d
for v, d in zip(dyn_vars, old_dyn_vs): v._value = d
raise e
Expand Down
84 changes: 50 additions & 34 deletions brainpy/math/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from jax.core import UnexpectedTracerError

from brainpy import errors
from brainpy.base.naming import get_unique_name
from brainpy.math.jaxarray import (JaxArray, Variable,
turn_on_global_jit,
turn_off_global_jit)
add_context,
del_context)
from brainpy.math.numpy_ops import as_device_array

__all__ = [
Expand Down Expand Up @@ -158,17 +159,19 @@ def make_loop(body_fun, dyn_vars, out_vars=None, has_return=False):
out_vars=out_vars,
has_return=has_return)

name = get_unique_name('_brainpy_object_oriented_make_loop_')

# functions
if has_return:
def call(xs=None, length=None):
init_values = [v.value for v in dyn_vars]
try:
turn_on_global_jit()
add_context(name)
dyn_values, (out_values, results) = lax.scan(
f=fun2scan, init=init_values, xs=xs, length=length)
turn_off_global_jit()
del_context(name)
except UnexpectedTracerError as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, init_values): v._value = d
raise errors.JaxTracerError(variables=dyn_vars) from e
for v, d in zip(dyn_vars, dyn_values): v._value = d
Expand All @@ -178,15 +181,15 @@ def call(xs=None, length=None):
def call(xs):
init_values = [v.value for v in dyn_vars]
try:
turn_on_global_jit()
add_context(name)
dyn_values, out_values = lax.scan(f=fun2scan, init=init_values, xs=xs)
turn_off_global_jit()
del_context(name)
except UnexpectedTracerError as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, init_values): v._value = d
raise errors.JaxTracerError(variables=dyn_vars) from e
except Exception as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, init_values): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_values): v._value = d
Expand Down Expand Up @@ -255,20 +258,22 @@ def _cond_fun(op):
for v, d in zip(dyn_vars, dyn_values): v._value = d
return as_device_array(cond_fun(static_values))

name = get_unique_name('_brainpy_object_oriented_make_while_')

def call(x=None):
dyn_init = [v.value for v in dyn_vars]
try:
turn_on_global_jit()
add_context(name)
dyn_values, _ = lax.while_loop(cond_fun=_cond_fun,
body_fun=_body_fun,
init_val=(dyn_init, x))
turn_off_global_jit()
del_context(name)
except UnexpectedTracerError as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, dyn_init): v._value = d
raise errors.JaxTracerError(variables=dyn_vars) from e
except Exception as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, dyn_init): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_values): v._value = d
Expand Down Expand Up @@ -330,6 +335,8 @@ def make_cond(true_fun, false_fun, dyn_vars=None):
if not isinstance(v, JaxArray):
raise ValueError(f'Only support {JaxArray.__name__}, but got {type(v)}')

name = get_unique_name('_brainpy_object_oriented_make_cond_')

if len(dyn_vars) > 0:
def _true_fun(op):
dyn_vals, static_vals = op
Expand All @@ -348,25 +355,25 @@ def _false_fun(op):
def call(pred, x=None):
old_values = [v.value for v in dyn_vars]
try:
turn_on_global_jit()
add_context(name)
dyn_values, res = lax.cond(pred, _true_fun, _false_fun, (old_values, x))
turn_off_global_jit()
del_context(name)
except UnexpectedTracerError as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, old_values): v._value = d
raise errors.JaxTracerError(variables=dyn_vars) from e
except Exception as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, old_values): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_values): v._value = d
return res

else:
def call(pred, x=None):
turn_on_global_jit()
add_context(name)
res = lax.cond(pred, true_fun, false_fun, x)
turn_off_global_jit()
del_context(name)
return res

return call
Expand Down Expand Up @@ -445,6 +452,8 @@ def cond(
if not isinstance(v, Variable):
raise ValueError(f'Only support {Variable.__name__}, but got {type(v)}')

name = get_unique_name('_brainpy_object_oriented_cond_')

# calling the model
if len(dyn_vars) > 0:
def _true_fun(op):
Expand All @@ -463,25 +472,25 @@ def _false_fun(op):

old_values = [v.value for v in dyn_vars]
try:
turn_on_global_jit()
add_context(name)
dyn_values, res = lax.cond(pred=pred,
true_fun=_true_fun,
false_fun=_false_fun,
operand=(old_values, operands))
turn_off_global_jit()
del_context(name)
except UnexpectedTracerError as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, old_values): v._value = d
raise errors.JaxTracerError(variables=dyn_vars) from e
except Exception as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, old_values): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_values): v._value = d
else:
turn_on_global_jit()
add_context(name)
res = lax.cond(pred, true_fun, false_fun, operands)
turn_off_global_jit()
del_context(name)
return res


Expand Down Expand Up @@ -591,7 +600,11 @@ def ifelse(
if show_code: print(codes)
exec(compile(codes.strip(), '', 'exec'), code_scope)
f = code_scope['f']
return f(operands)
name = get_unique_name('_brainpy_object_oriented_ifelse_')
add_context(name)
r = f(operands)
del_context(name)
return r


def for_loop(body_fun: Callable,
Expand Down Expand Up @@ -694,22 +707,24 @@ def fun2scan(dyn_vals, x):
results = body_fun(*x)
return [v.value for v in dyn_vars], results

name = get_unique_name('_brainpy_object_oriented_for_loop_')

# functions
init_vals = [v.value for v in dyn_vars]
try:
turn_on_global_jit()
add_context(name)
dyn_vals, out_vals = lax.scan(f=fun2scan,
init=init_vals,
xs=operands,
reverse=reverse,
unroll=unroll)
turn_off_global_jit()
del_context(name)
except UnexpectedTracerError as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, init_vals): v._value = d
raise errors.JaxTracerError(variables=dyn_vars) from e
except Exception as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, init_vals): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_vals): v._value = d
Expand Down Expand Up @@ -797,19 +812,20 @@ def _cond_fun(op):
r = cond_fun(*static_vals)
return r if isinstance(r, JaxArray) else r

name = get_unique_name('_brainpy_object_oriented_while_loop_')
dyn_init = [v.value for v in dyn_vars]
try:
turn_on_global_jit()
add_context(name)
dyn_values, out = lax.while_loop(cond_fun=_cond_fun,
body_fun=_body_fun,
init_val=(dyn_init, operands))
turn_off_global_jit()
del_context(name)
except UnexpectedTracerError as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, dyn_init): v._value = d
raise errors.JaxTracerError(variables=dyn_vars) from e
except Exception as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, dyn_init): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_values): v._value = d
Expand Down
Loading