>>> import flax
AttributeError                            Traceback (most recent call last)
<ipython-input-2-f5b294e0faf0> in <cell line: 2>()
      1 # Verify we can import everything.
----> 2 import flax
      3 from flax.training import (checkpoints, dynamic_scale, early_stopping, lr_schedule,
      4                            orbax_utils, prefetch_iterator, train_state, common_utils)
      5 from flax.metrics import tensorboard

2 frames
/usr/local/lib/python3.10/dist-packages/flax/core/frozen_dict.py in <module>
     48
     49
---> 50 @jax.tree_util.register_pytree_with_keys_class
     51 class FrozenDict(Mapping[K, V]):
     52   """An immutable variant of the Python dict."""

AttributeError: module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class'
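The import dies on the decorator jax.tree_util.register_pytree_with_keys_class, so one way to check whether the installed jax can support this flax release is to probe for that attribute before importing flax at all. The following is a minimal sketch; has_attr_path is a hypothetical helper written for this illustration and is not part of jax or flax.

```python
from __future__ import annotations

import importlib


def has_attr_path(module_name: str, dotted_attr: str) -> bool:
    """Return True if `module_name` imports and exposes `dotted_attr`.

    Hypothetical helper: a missing module or any missing attribute along
    the dotted path yields False instead of raising.
    """
    try:
        obj = importlib.import_module(module_name)
    except ImportError:  # also covers ModuleNotFoundError
        return False
    for part in dotted_attr.split("."):
        if not hasattr(obj, part):
            return False
        obj = getattr(obj, part)
    return True


if __name__ == "__main__":
    # Probe for the exact symbol flax/core/frozen_dict.py needs at import time.
    ok = has_attr_path("jax.tree_util", "register_pytree_with_keys_class")
    print("jax exposes register_pytree_with_keys_class:", ok)
```

On a runtime like the one above, the probe would be expected to print False, consistent with the AttributeError, without triggering the failing flax import.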
!pip freeze | egrep 'jax|flax'
flax==0.7.0
jax==0.3.25
jaxlib==0.3.25
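Both outputs above come from a fresh Colab TPU runtime. Note the versions: flax 0.7.0 is installed alongside jax and jaxlib 0.3.25, and that jax release does not provide jax.tree_util.register_pytree_with_keys_class, which flax/core/frozen_dict.py uses as a class decorator at import time.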
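To spot this kind of mismatch programmatically, one could parse the pip freeze lines and compare the jax release against a minimum. The sketch below does that with integer tuples. parse_pins, release_tuple, jax_too_old, and MIN_JAX are hypothetical names for this illustration, and the 0.4.6 threshold is an assumption about when the keyed pytree API entered jax, not a requirement read from flax's package metadata.

```python
from __future__ import annotations

import re

# Output of `pip freeze | egrep 'jax|flax'` from the Colab runtime above.
FREEZE_OUTPUT = """\
flax==0.7.0
jax==0.3.25
jaxlib==0.3.25
"""

# ASSUMPTION: illustrative minimum jax release exposing keyed pytree
# registration; confirm against the JAX changelog before relying on it.
MIN_JAX = (0, 4, 6)


def parse_pins(freeze_text: str) -> dict[str, str]:
    """Map each 'name==version' line to {name: version}; skip other lines."""
    pins: dict[str, str] = {}
    for line in freeze_text.splitlines():
        name, sep, ver = line.strip().partition("==")
        if sep:
            pins[name.strip().lower()] = ver.strip()
    return pins


def release_tuple(ver: str) -> tuple[int, ...]:
    """Leading numeric release as a tuple, e.g. '0.3.25' -> (0, 3, 25)."""
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"unrecognized version string: {ver!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def jax_too_old(pins: dict[str, str], minimum: tuple[int, ...] = MIN_JAX) -> bool:
    """True if jax is pinned below `minimum`; a missing jax counts as too old."""
    ver = pins.get("jax")
    return ver is None or release_tuple(ver) < minimum


if __name__ == "__main__":
    pins = parse_pins(FREEZE_OUTPUT)
    print(pins)
    print("jax older than", ".".join(map(str, MIN_JAX)), "->", jax_too_old(pins))
```

Comparing integer tuples rather than raw strings matters here: as strings, "0.4.13" < "0.4.6" is True, which would misclassify a newer jax as older. For pre-release and local version tags, packaging.version.Version is a more robust choice than this hand-rolled parser.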