Skip to content

Conversation

@chaoming0625
Copy link
Member

This PR supports JAX transformation contexts for JaxArray. This context will be used for debugging. Notablely, if some JaxArray are updated in the JAX transformations, it will cause errors:

>>> @bm.jit
>>> def f1(a):
>>>    a[:] = 1.
>>>    return a
>>> 
>>> a = bm.zeros(10)
>>> f1(a)
brainpy.errors.MathError: JaxArray created outside of the transformation functions (_jax_functional_jit_0) cannot be updated. You should mark it as a Variable instead.
>>> @bm.jit
>>> def f1(a):
>>>   b = a + 1
>>> 
>>>   @bm.jit
>>>   def f2(x):
>>>     x.value = 1.
>>>     return x
>>> 
>>> return f2(b)
>>> 
>>> f1(bm.ones(2))
brainpy.errors.MathError: JaxArray context "_jax_functional_jit_0" differs from the JAX transformation context "_jax_functional_jit_1"
JaxArray created outside of the transformation functions cannot be updated. You should mark it as a Variable instead.
>>> @bm.jit
>>> def f1(a):
>>>   return a + 1
>>> 
>>> @bm.jit
>>> def f2(b):
>>>   b[:] = 1.
>>>   return b
>>>
>>> f2(f1(bm.ones(2)))
brainpy.errors.MathError: JaxArray context "_jax_functional_jit_0" differs from the JAX transformation context "_jax_functional_jit_1"

JaxArray created outside of the transformation functions cannot be updated. You should mark it as a Variable instead.

@chaoming0625 chaoming0625 requested a review from ztqakita October 16, 2022 15:06
Copy link
Collaborator

@ztqakita ztqakita left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good! It has passed the tests.

@chaoming0625 chaoming0625 merged commit 7693f14 into brainpy:master Oct 17, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants