diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 4dc23b352..f59927666 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -79,6 +79,8 @@ jobs: if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi pip uninstall brainpy -y python setup.py install + pip install jax==0.4.30 + pip install jaxlib==0.4.30 - name: Test with pytest run: | cd brainpy diff --git a/brainpy/_src/dnn/tests/test_activation.py b/brainpy/_src/dnn/tests/test_activation.py index 7a0fa57af..514706419 100644 --- a/brainpy/_src/dnn/tests/test_activation.py +++ b/brainpy/_src/dnn/tests/test_activation.py @@ -4,6 +4,7 @@ import brainpy.math as bm + class Test_Activation(parameterized.TestCase): @parameterized.product( diff --git a/brainpy/_src/dnn/tests/test_conv_layers.py b/brainpy/_src/dnn/tests/test_conv_layers.py index 05f523622..af38a355f 100644 --- a/brainpy/_src/dnn/tests/test_conv_layers.py +++ b/brainpy/_src/dnn/tests/test_conv_layers.py @@ -1,12 +1,17 @@ # -*- coding: utf-8 -*- +import platform import jax.numpy as jnp +import pytest from absl.testing import absltest from absl.testing import parameterized import brainpy as bp import brainpy.math as bm +if platform.system() == 'Darwin': + pytest.skip('skip Mac OS', allow_module_level=True) + class TestConv(parameterized.TestCase): def test_Conv2D_img(self): diff --git a/brainpy/_src/math/op_register/ad_support.py b/brainpy/_src/math/op_register/ad_support.py index 342093ea2..54a3c9be2 100644 --- a/brainpy/_src/math/op_register/ad_support.py +++ b/brainpy/_src/math/op_register/ad_support.py @@ -41,9 +41,14 @@ def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params): r = tuple(rule(t, *primals, **params)) tangents_out.append(r) assert tree_util.tree_structure(r) == tree - return val_out, functools.reduce(_add_tangents, + try: + return val_out, functools.reduce(_add_tangents, tangents_out, - tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out)) + tree_util.tree_map(lambda a: ad.Zero.from_primal_value(a), val_out)) + except: + return val_out, functools.reduce(_add_tangents, + tangents_out, + tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out)) def _add_tangents(xs, ys): diff --git a/brainpy/_src/math/pre_syn_post.py b/brainpy/_src/math/pre_syn_post.py index bc9785692..06976a35b 100644 --- a/brainpy/_src/math/pre_syn_post.py +++ b/brainpy/_src/math/pre_syn_post.py @@ -56,7 +56,7 @@ def pre2post_event_sum(events, for i in range(pre_num): if events[i]: for j in range(idnptr[i], idnptr[i+1]): - post_val[post_ids[i]] += values + post_val[post_ids[j]] += values When ``values`` is a vector (with the length of ``len(post_ids)``), this function is equivalent to @@ -70,7 +70,7 @@ def pre2post_event_sum(events, for i in range(pre_num): if events[i]: for j in range(idnptr[i], idnptr[i+1]): - post_val[post_ids[i]] += values[j] + post_val[post_ids[j]] += values[j] Parameters