Skip to content

Conversation

init-22
Copy link
Contributor

@init-22 init-22 commented Sep 21, 2025

CIFAR workload error resolution: #889

@init-22 init-22 requested a review from a team as a code owner September 21, 2025 08:46
Copy link

github-actions bot commented Sep 21, 2025

MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅

@davidtweedle
Copy link
Contributor

Hi Isaac,
Thank you for doing this. When I tried it in my docker container it said

File "/algorithmic-efficiency/algoperf/workloads/cifar/cifar_jax/workload.py", line 223, in _normalize_eval_metrics
    return jax.tree_map(lambda x: x / num_examples, total_metrics)
           ^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(message)
AttributeError: jax.tree_map was removed in JAX v0.6.0: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).

When I change this line in workload.py to
jax.tree.map(....)
everything works fine. If it is just that I am using the wrong jax version, then you can just ignore this problem.
Anyways, thanks again - David

@init-22
Copy link
Contributor Author

init-22 commented Sep 25, 2025

Oh thanks for pointing this out, I've updated the code which uses jax.tree.map now, can you please try it again?

@davidtweedle
Copy link
Contributor

davidtweedle commented Sep 25, 2025 via email

@davidtweedle
Copy link
Contributor

davidtweedle commented Sep 25, 2025 via email

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants