Skip to content

Commit

Permalink
[math] fix brainpy.math.scan (#604)
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jan 29, 2024
1 parent 7e8dd81 commit 16cf74a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 9 deletions.
4 changes: 3 additions & 1 deletion brainpy/_src/math/object_transform/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,8 @@ def scan(
):
"""``scan`` control flow with :py:class:`~.Variable`.
Similar to ``jax.lax.scan``.
.. versionadded:: 2.4.7
All returns in body function will be gathered
Expand Down Expand Up @@ -999,7 +1001,7 @@ def scan(
rets = jax.eval_shape(transform, init, operands)
cache_stack(body_fun, dyn_vars) # cache
if current_transform_number():
return rets[1]
return rets[0][1], rets[1]
del rets

transform = _get_scan_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll)
Expand Down
28 changes: 20 additions & 8 deletions brainpy/_src/math/object_transform/tests/test_controls.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
# -*- coding: utf-8 -*-
import sys
import tempfile
import unittest
from functools import partial

import jax
from jax import vmap

from absl.testing import parameterized
from jax._src import test_util as jtu
from jax import vmap

import brainpy as bp
import brainpy.math as bm
Expand Down Expand Up @@ -147,6 +144,25 @@ def f(carray, x):
expected = bm.expand_dims(expected, axis=-1)
self.assertTrue(bm.allclose(outs, expected))

def test2(self):
a = bm.Variable(1)

def f(carray, x):
carray += x
a.value += 1.
return carray, a

@bm.jit
def f_outer(carray, x):
carry, outs = bm.scan(f, carray, x, unroll=2)
return carry, outs

carry, outs = f_outer(bm.zeros(2), bm.arange(10))
self.assertTrue(bm.allclose(carry, 45.))
expected = bm.arange(1, 11).astype(outs.dtype)
expected = bm.expand_dims(expected, axis=-1)
self.assertTrue(bm.allclose(outs, expected))


class TestCond(unittest.TestCase):
def test1(self):
Expand Down Expand Up @@ -234,7 +250,6 @@ def F2(x):
self.assertTrue(bm.grad(F2)(9.0) == 18.)
self.assertTrue(bm.grad(F2)(11.0) == 1.)


def test_grad2(self):
def F3(x):
return bm.ifelse(conditions=(x >= 10, x >= 0),
Expand Down Expand Up @@ -519,6 +534,3 @@ def body(a):
file.seek(0)
out6 = file.read().strip()
self.assertTrue(out5 == out6)



1 change: 1 addition & 0 deletions docs/apis/brainpy.math.oo_transform.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Object-oriented Transformations
ifelse
for_loop
while_loop
scan
jit
cls_jit
to_object
Expand Down

0 comments on commit 16cf74a

Please sign in to comment.