TLDR: when you disable `jit` with `config.update('jax_disable_jit', True)`, `apply_fn` from `nvt_nose_hoover` throws the following error: `AttributeError: 'numpy.ndarray' object has no attribute 'at'`. At some point along the way, the jax array gets converted to a numpy array. More specifically, the call to `dataclasses.astuple(state)` in `update_chain_mass_fn` (called by `chain_fns.update_mass(chain, _kT)` in `apply_fn`) converts the jax numpy arrays to normal numpy arrays.
Full story:
I have been experimenting with `nvt_nose_hoover` recently. However, perhaps due to its chaining method, it takes a long time to compile. To speed up debugging, I wanted to turn off `jit` globally, so I set `config.update('jax_disable_jit', True)`. (As an aside, it can be confusing that some things are jitted by default even when the user never calls `jit` themselves; this makes normal debugging tools like `pdb` confusing to use.) This broke the code. I did a bit of digging and figured out why: `dataclasses.astuple` converts jax arrays to normal numpy arrays. I suspect this is an issue with other integrators as well.
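To illustrate the mechanism, here is a minimal sketch using only the standard library. `dataclasses.astuple` passes every non-container field through `copy.deepcopy`, so the returned values are copies rather than the original objects. The conversion to numpy reported above happens on that copy path. `FakeArray`, `ChainState`, and `shallow_astuple` are hypothetical names made up for this example and are not part of JAX or JAX MD; `FakeArray` stands in for an array type so the snippet runs without JAX installed. One possible workaround, assuming the caller only needs the top-level fields, is to build the tuple without copying:

```python
import dataclasses


class FakeArray:
    """Hypothetical stand-in for an array type; not part of JAX or JAX MD."""

    def __init__(self, values):
        self.values = list(values)


@dataclasses.dataclass
class ChainState:
    """Hypothetical, simplified state; the real chain state has more fields."""
    position: FakeArray
    momentum: FakeArray


def shallow_astuple(obj):
    """Return the dataclass fields as a tuple without copying or recursing."""
    return tuple(getattr(obj, f.name) for f in dataclasses.fields(obj))


state = ChainState(FakeArray([0.0, 1.0]), FakeArray([2.0, 3.0]))

deep = dataclasses.astuple(state)   # each field goes through copy.deepcopy
shallow = shallow_astuple(state)    # each field is the original object

print(deep[0] is state.position)     # False: astuple returned a copy
print(shallow[0] is state.position)  # True: the original object is preserved
```

Because the shallow version never copies the fields, whatever array type went in comes back out unchanged. Whether this is the right fix for `update_chain_mass_fn` depends on whether any caller relies on the copies that `astuple` makes, which I have not checked.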
A simple way to reproduce this is to run a Nose-Hoover test case from the rigid body tests with `jit` disabled. It might be wise to run all tests with `jit` disabled to identify where else this could be a problem.