Actually, I'm not sure if we should call this a 'bug', since this is an internal API. Is there an issue you're seeing in a public API?
This is one of the main reasons partial_eval_jaxpr_stateful exists. The two functions may be merged someday, but for now this is essentially expected behavior for trace_to_jaxpr_nounits.
Can you say more about what problem you're trying to solve?
Description
prints
I expected jaxpr_unknown to contain the effects.
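For reference, the effects a jaxpr carries can be inspected through the public API. This is a minimal sketch, not the original reproducer (which is not shown here). It assumes `jax.debug.print` as a stand-in effectful operation and uses `jax.make_jaxpr` rather than the internal partial-eval functions mentioned in this thread. The function names `pure` and `effectful` are hypothetical:

```python
import jax

def pure(x):
    # No side effects: the traced jaxpr should carry an empty effect set.
    return x * 2.0

def effectful(x):
    # jax.debug.print stands in for an effectful primitive (an assumption;
    # the issue's own example is not reproduced here).
    jax.debug.print("x = {}", x)
    return x * 2.0

# make_jaxpr only traces; the debug print is not executed here.
pure_effects = jax.make_jaxpr(pure)(1.0).effects
effectful_effects = jax.make_jaxpr(effectful)(1.0).effects

print(pure_effects)       # an empty effect set
print(effectful_effects)  # a non-empty set holding a debug-callback effect
```

A jaxpr traced from an effectful function should report a non-empty effect set like this. The issue reports that, by contrast, the unknown jaxpr from partial evaluation does not carry the effects.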
Possible workaround: wrap the unknown arguments using a jaxpr trace with dynamic=False. I also had to change how pjit_p is handled.
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.28
jaxlib: 0.4.28
numpy: 1.26.4
python: 3.12.2 (main, Feb 6 2024, 20:19:44) [GCC 13.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1