TPU Colab fails with AttributeError: module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class' #3239

@andsteing

On a fresh Colab TPU runtime, `import flax` fails:

>>> import flax
AttributeError                            Traceback (most recent call last)
<ipython-input-2-f5b294e0faf0> in <cell line: 2>()
      1 # Verify we can import everything.
----> 2 import flax
      3 from flax.training import (checkpoints, dynamic_scale, early_stopping, lr_schedule,
      4                            orbax_utils, prefetch_iterator, train_state, common_utils)
      5 from flax.metrics import tensorboard

/usr/local/lib/python3.10/dist-packages/flax/core/frozen_dict.py in <module>
     48 
     49 
---> 50 @jax.tree_util.register_pytree_with_keys_class
     51 class FrozenDict(Mapping[K, V]):
     52   """An immutable variant of the Python dict."""

AttributeError: module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class'
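The traceback shows that `flax/core/frozen_dict.py` applies `jax.tree_util.register_pytree_with_keys_class` as a decorator at import time, so any runtime whose `jax.tree_util` lacks that name fails on `import flax`. A minimal pre-flight check for this, sketched with only the standard library; the helper name `module_has_attr` is hypothetical and not part of JAX or Flax:

```python
import importlib


def module_has_attr(module_name: str, attr: str) -> bool:
    """Return True if `module_name` imports cleanly and exposes `attr`.

    Returns False instead of raising when the module is missing
    (ModuleNotFoundError is a subclass of ImportError) or when it
    lacks the attribute, so it can gate an import without crashing.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    return hasattr(module, attr)


if __name__ == "__main__":
    # Run this on the affected runtime before `import flax`.
    if module_has_attr("jax.tree_util", "register_pytree_with_keys_class"):
        print("jax.tree_util provides register_pytree_with_keys_class")
    else:
        print("jax.tree_util lacks register_pytree_with_keys_class; "
              "`import flax` is expected to fail as in the traceback")
```

Checking for the specific attribute, rather than a version number, tests exactly the condition the traceback reports.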

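Note the versions: the runtime has `jax`/`jaxlib` 0.3.25, whose `jax.tree_util` does not provide `register_pytree_with_keys_class`, while `flax` 0.7.0 uses it in `flax/core/frozen_dict.py`: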

!pip freeze | egrep 'jax|flax'
flax==0.7.0
jax==0.3.25
jaxlib==0.3.25
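The same mismatch can be detected from the installed versions, as `pip freeze` shows above. A minimal sketch using `importlib.metadata` from the standard library; the helper names and the `MIN_JAX` threshold are hypothetical placeholders, and the real lower bound should be taken from the requirements of the Flax release being installed:

```python
import re
from importlib import metadata
from typing import Optional, Tuple

# Leading numeric release segment, e.g. "0.4.0" out of "0.4.0rc1".
_RELEASE = re.compile(r"\d+(?:\.\d+)*")


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def version_tuple(version: str) -> Tuple[int, ...]:
    """Parse "0.3.25" into (0, 3, 25).

    Pre-release and local suffixes are dropped ("0.4.0rc1" -> (0, 4, 0)),
    so this does not implement full PEP 440 ordering.
    """
    match = _RELEASE.match(version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def is_at_least(installed: str, minimum: str) -> bool:
    """Compare releases numerically, so "0.4.13" ranks above "0.4.2"."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so "0.4" and "0.4.0" compare as equal.
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


if __name__ == "__main__":
    MIN_JAX = "0.4.0"  # hypothetical placeholder, not Flax's documented bound
    jax_version = installed_version("jax")
    if jax_version is None:
        print("jax is not installed")
    elif not is_at_least(jax_version, MIN_JAX):
        print(f"jax {jax_version} is older than {MIN_JAX}")
    else:
        print(f"jax {jax_version} satisfies >= {MIN_JAX}")
```

Comparing integer tuples avoids the string-comparison trap where `"0.4.13" < "0.4.2"`. For full PEP 440 semantics (pre-releases, epochs), `packaging.version.Version` is the usual choice; this sketch avoids the extra dependency.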

Metadata

Assignees

Labels

Priority: P1 - soon (Response within 5 business days. Resolution within 30 days. Assignee required)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests