Skip to content

Commit

Permalink
disable_ jit support in brainpy.math.scan (#606)
Browse files Browse the repository at this point in the history
* [math] support disable jit in `brainpy.math.scan`

* [math] support brainpy array in `cond`, `ifelse`, `scan` transformations

* fix tests
  • Loading branch information
chaoming0625 committed Jan 30, 2024
1 parent 16cf74a commit bde7f8a
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 17 deletions.
55 changes: 38 additions & 17 deletions brainpy/_src/math/object_transform/controls.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-

import functools
import numbers
from typing import Union, Sequence, Any, Dict, Callable, Optional
Expand All @@ -12,7 +13,7 @@

from brainpy import errors, tools
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import (Array, )
from brainpy._src.math.ndarray import (Array, _as_jax_array_)
from .base import BrainPyObject, ObjectTransform
from .naming import (
get_unique_name,
Expand Down Expand Up @@ -421,11 +422,27 @@ def call(pred, x=None):
return ControlObject(call, dyn_vars, repr_fun={'true_fun': true_fun, 'false_fun': false_fun})


@functools.cache
def _warp(f):
@functools.wraps(f)
def new_f(*args, **kwargs):
return jax.tree_map(_as_jax_array_, f(*args, **kwargs), is_leaf=lambda a: isinstance(a, Array))

return new_f


def _warp_data(data):
def new_f(*args, **kwargs):
return jax.tree_map(_as_jax_array_, data, is_leaf=lambda a: isinstance(a, Array))

return new_f


def _check_f(f):
if callable(f):
return f
return _warp(f)
else:
return (lambda *args, **kwargs: f)
return _warp_data(f)


def _check_sequence(a):
Expand Down Expand Up @@ -557,7 +574,7 @@ def _if_else_return2(conditions, branches):
return branches[-1]


def all_equal(iterator):
def _all_equal(iterator):
iterator = iter(iterator)
try:
first = next(iterator)
Expand Down Expand Up @@ -671,7 +688,7 @@ def ifelse(
else:
rets = [jax.eval_shape(branch, *operands) for branch in branches]
trees = [jax.tree_util.tree_structure(ret) for ret in rets]
if not all_equal(trees):
if not _all_equal(trees):
msg = 'All returns in branches should have the same tree structure. But we got:\n'
for tree in trees:
msg += f'- {tree}\n'
Expand Down Expand Up @@ -914,12 +931,14 @@ def fun2scan(carry, x):
carry, results = body_fun(carry, x)
if progress_bar:
id_tap(lambda *arg: bar.update(), ())
carry = jax.tree_map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array))
return (dyn_vars.dict_data(), carry), results

if remat:
fun2scan = jax.checkpoint(fun2scan)

def call(init, operands):
init = jax.tree_map(_as_jax_array_, init, is_leaf=lambda a: isinstance(a, Array))
return jax.lax.scan(f=fun2scan,
init=(dyn_vars.dict_data(), init),
xs=operands,
Expand Down Expand Up @@ -991,19 +1010,21 @@ def scan(
bar = tqdm(total=num_total)

dyn_vars = get_stack_cache(body_fun)
if dyn_vars is None:
with new_transform('scan'):
with VariableStack() as dyn_vars:
transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll)
if current_transform_number() > 1:
rets = transform(init, operands)
else:
rets = jax.eval_shape(transform, init, operands)
cache_stack(body_fun, dyn_vars) # cache
if current_transform_number():
return rets[0][1], rets[1]
del rets
if not jax.config.jax_disable_jit:
if dyn_vars is None:
with new_transform('scan'):
with VariableStack() as dyn_vars:
transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll)
if current_transform_number() > 1:
rets = transform(init, operands)
else:
rets = jax.eval_shape(transform, init, operands)
cache_stack(body_fun, dyn_vars) # cache
if current_transform_number():
return rets[0][1], rets[1]
del rets

dyn_vars = VariableStack() if dyn_vars is None else dyn_vars
transform = _get_scan_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll)
(dyn_vals, carry), out_vals = transform(init, operands)
for key in dyn_vars.keys():
Expand Down
28 changes: 28 additions & 0 deletions brainpy/_src/math/object_transform/tests/test_controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,34 @@ def f_outer(carray, x):
expected = bm.expand_dims(expected, axis=-1)
self.assertTrue(bm.allclose(outs, expected))

def test_disable_jit(self):
def cumsum(res, el):
res = res + el
print(res)
return res, res # ("carryover", "accumulated")

a = bm.array([1, 2, 3, 5, 7, 11, 13, 17]).value
result_init = 0
with jax.disable_jit():
final, result = jax.lax.scan(cumsum, result_init, a)

b = bm.array([1, 2, 3, 5, 7, 11, 13, 17])
result_init = 0
with jax.disable_jit():
final, result = bm.scan(cumsum, result_init, b)

bm.clear_buffer_memory()

def test_array_aware_of_bp_array(self):
def cumsum(res, el):
res = bm.asarray(res + el)
return res, res # ("carryover", "accumulated")

b = bm.array([1, 2, 3, 5, 7, 11, 13, 17])
result_init = 0
with jax.disable_jit():
final, result = bm.scan(cumsum, result_init, b)


class TestCond(unittest.TestCase):
def test1(self):
Expand Down

0 comments on commit bde7f8a

Please sign in to comment.