
jax callback change break tqdm updates #19

Closed
mdmould opened this issue May 14, 2024 · 1 comment

Comments

@mdmould
Contributor

mdmould commented May 14, 2024

As of jax 0.4.27, arguments to jax.debug.callback are now jax.Array rather than np.ndarray (see https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-27-may-7-2024).

This breaks the tqdm updates, which now fail with the error TypeError: unsupported type for timedelta seconds component: jaxlib.xla_extension.ArrayImpl.

It looks like the fix would be to convert the callback arguments back to NumPy arrays with jax.tree.map(np.asarray, args) before using them.
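A minimal sketch of the suggested conversion, assuming jax 0.4.27 or later (jax.tree.map itself needs 0.4.25+). The wrapper as_numpy_callback and the update_progress function are hypothetical stand-ins for the host-side progress-bar code, not the project's actual fix.

```python
import jax
import jax.numpy as jnp
import numpy as np


def as_numpy_callback(fn):
    """Hypothetical wrapper: convert every array argument to np.ndarray.

    Since jax 0.4.27, jax.debug.callback passes jax.Array values to the host
    function; jax.tree.map(np.asarray, ...) turns them back into NumPy arrays.
    """
    def wrapped(*args, **kwargs):
        args, kwargs = jax.tree.map(np.asarray, (args, kwargs))
        return fn(*args, **kwargs)
    return wrapped


received = []  # (type, value) pairs seen by the host-side function


def update_progress(iter_num):
    # Hypothetical stand-in for a tqdm update; it just records what it got.
    received.append((type(iter_num), int(iter_num)))


@jax.jit
def step(i):
    # The callback runs on the host with the converted (NumPy) argument.
    jax.debug.callback(as_numpy_callback(update_progress), i)
    return i + 1


step(jnp.int32(3)).block_until_ready()
jax.effects_barrier()  # make sure the debug callback has finished running
```

Converting at the host boundary keeps the host-side code unchanged whether the installed jax passes np.ndarray or jax.Array.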

@zombie-einstein
Collaborator

Closing, as this is now fixed in the latest release.
