
AttributeError on import: module 'ml_dtypes' has no attribute 'float8_e4m3b11' #17693

Closed
priyakasimbeg opened this issue Sep 21, 2023 · 8 comments
Labels
bug Something isn't working


@priyakasimbeg

Description

JAX requires ml_dtypes >= 0.2.0. Now ml_dtypes 0.3.0 is installed by default, which raises the following error:

Traceback (most recent call last):
  File "/home/runner/work/algorithmic-efficiency/algorithmic-efficiency/tests/reference_algorithm_tests.py", line 32, in <module>
    import flax
  File "/opt/hostedtoolcache/Python/3.9.18/x64/lib/python3.9/site-packages/flax/__init__.py", line 18, in <module>
    from .configurations import (
  File "/opt/hostedtoolcache/Python/3.9.18/x64/lib/python3.9/site-packages/flax/configurations.py", line 25, in <module>
    from jax import config as jax_config
  File "/opt/hostedtoolcache/Python/3.9.18/x64/lib/python3.9/site-packages/jax/__init__.py", line 35, in <module>
    from jax import config as _config_module
  File "/opt/hostedtoolcache/Python/3.9.18/x64/lib/python3.9/site-packages/jax/config.py", line 17, in <module>
    from jax._src.config import config  # noqa: F401
  File "/opt/hostedtoolcache/Python/3.9.18/x64/lib/python3.9/site-packages/jax/_src/config.py", line 24, in <module>
    from jax._src import lib
  File "/opt/hostedtoolcache/Python/3.9.18/x64/lib/python3.9/site-packages/jax/_src/lib/__init__.py", line 92, in <module>
    import jaxlib.xla_client as xla_client
  File "/opt/hostedtoolcache/Python/3.9.18/x64/lib/python3.9/site-packages/jaxlib/xla_client.py", line 225, in <module>
    float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11
AttributeError: module 'ml_dtypes' has no attribute 'float8_e4m3b11'

What jax/jaxlib version are you using?

jax 0.4.10, jaxlib 0.4.10+cuda11.cudnn86

Which accelerator(s) are you using?

GPU

Additional system info

No response

NVIDIA GPU info

No response

@jakevdp
Collaborator

jakevdp commented Sep 21, 2023

Hi, thanks for the report. float8_e4m3b11 was deprecated in ml_dtypes version 0.2.0 and removed in version 0.3.0; you should have been seeing a DeprecationWarning for the last several months when using version 0.2.0. You can use float8_e4m3b11fnuz instead. See the ml_dtypes CHANGELOG for details.
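If your own code referenced the old name, one way to follow this advice while still tolerating ml_dtypes releases that only have the old name is to look the type up by name, preferring float8_e4m3b11fnuz. This is a minimal sketch: the helper resolve_e4m3b11 is hypothetical (it is not part of ml_dtypes or JAX), and it cannot fix the reference inside jaxlib itself, which still needs a JAX upgrade or an ml_dtypes downgrade.

```python
from types import SimpleNamespace


def resolve_e4m3b11(module):
    """Return the e4m3b11 float8 type from an ml_dtypes-like module.

    Hypothetical helper: tries the current name first, then falls back to the
    old name that was removed in ml_dtypes 0.3.0.
    """
    for name in ("float8_e4m3b11fnuz", "float8_e4m3b11"):
        dtype = getattr(module, name, None)
        if dtype is not None:
            return dtype
    raise AttributeError(
        f"module {getattr(module, '__name__', module)!r} has neither "
        "float8_e4m3b11fnuz nor float8_e4m3b11"
    )


if __name__ == "__main__":
    # Stand-in modules for illustration only; with ml_dtypes installed you
    # would pass the real module instead: resolve_e4m3b11(ml_dtypes).
    only_old_name = SimpleNamespace(float8_e4m3b11="old name")
    only_new_name = SimpleNamespace(float8_e4m3b11fnuz="new name")
    print(resolve_e4m3b11(only_old_name))  # old name
    print(resolve_e4m3b11(only_new_name))  # new name
```

Trying the new name first means that, on a release where both names exist, the deprecated alias is never touched.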

@hawkinsp
Collaborator

Note that this error comes from JAX itself (jaxlib's xla_client.py references the removed name), so the fix is either to upgrade JAX or to downgrade ml_dtypes.

@peregilk

I can confirm the same error when installing jax-metal 0.0.4 in a new venv with Python 3.10.9. I can also confirm that pip install ml_dtypes==0.2.0 fixes the issue. I think jax-metal 0.0.4 is only compatible with jax v0.4.11.

@tdgfrost

Identical experience to @peregilk, and downgrading ml_dtypes to 0.2.0 fixed the issue (thanks!). I'm also reliant on jax-metal - I'll leave a mention on the Apple developer forum, as their dependencies should reflect this change.

@shuhand0
Collaborator

Thanks for reporting it. Please downgrade to ml-dtypes==0.2.0 for now. We will add this to the setup instructions and fix the dependency in the next release.

@apivovarov
Contributor

Sorry for the off-topic question, but what does b11 mean in the type name?

I found that:

  • E4M3FNUZ: 1 bit for the sign, 4 bits for the exponent, 3 bits for the mantissa, only finite values and NaN with no infinities (FN), and no negative zero (UZ).

But what does e4m3b11fnuz mean?

@jakevdp
Collaborator

jakevdp commented May 8, 2024

What b11 means in the type name?

It's described in the README of https://github.com/jax-ml/ml_dtypes: it refers to the exponent bias. For example, e4m3b11fnuz has an exponent bias of 11, whereas e4m3fnuz has an exponent bias of 8. Every value is scaled by a factor of 2 ** -bias, so raising the bias from 8 to 11 shrinks the representable range by a factor of 2 ** (11 - 8) = 8:

>>> import jax.numpy as jnp
>>> info=jnp.finfo(jnp.float8_e4m3b11fnuz); (info.min, info.max)
(-30, 30)
>>> info=jnp.finfo(jnp.float8_e4m3fnuz); (info.min, info.max)
(-240, 240)
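The following minimal pure-Python sketch decodes an 8-bit FNUZ bit pattern to make the effect of the bias concrete. It assumes the usual FNUZ encoding: 1 sign bit, 4 exponent bits, and 3 mantissa bits; the pattern that would otherwise be negative zero (0x80) encodes NaN; there are no infinities; and an exponent field of zero means a subnormal. The function decode_fnuz is hypothetical and is not how ml_dtypes implements these types. The only difference between the two formats is the bias argument.

```python
import math


def decode_fnuz(bits, exp_bits=4, man_bits=3, bias=8):
    """Decode a float8 FNUZ bit pattern into a Python float (illustrative sketch)."""
    if not 0 <= bits < (1 << (1 + exp_bits + man_bits)):
        raise ValueError(f"bit pattern out of range: {bits:#x}")
    sign = bits >> (exp_bits + man_bits)
    exponent = (bits >> man_bits) & ((1 << exp_bits) - 1)
    mantissa = bits & ((1 << man_bits) - 1)
    fraction = mantissa / (1 << man_bits)
    if sign and exponent == 0 and mantissa == 0:
        # FNUZ: the would-be negative zero encodes NaN (no -0.0, no infinities).
        return math.nan
    if exponent == 0:
        # Subnormal: no implicit leading 1, exponent fixed at 1 - bias.
        magnitude = fraction * 2.0 ** (1 - bias)
    else:
        # Normal: implicit leading 1; every value is scaled by 2 ** -bias.
        magnitude = (1.0 + fraction) * 2.0 ** (exponent - bias)
    return -magnitude if sign else magnitude


if __name__ == "__main__":
    # 0x7F is the largest positive pattern: exponent field 15, mantissa 7.
    for name, bias in (("e4m3fnuz", 8), ("e4m3b11fnuz", 11)):
        print(name, decode_fnuz(0x7F, bias=bias))
```

Running it prints 240.0 for bias 8 and 30.0 for bias 11. These match the finfo maxima above: (1 + 7/8) * 2 ** (15 - bias), a ratio of 2 ** 3 = 8.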

@jakevdp
Collaborator

jakevdp commented May 8, 2024

I'm going to close this because I think it has been resolved. If you run into this error, you can fix it either by upgrading JAX to a more recent version (v0.4.12 or newer should work) or by downgrading ml_dtypes to v0.2.0 (pip install ml_dtypes==0.2.0).
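To check whether an environment has the broken combination, you can compare the installed versions against the thresholds given in this thread: ml_dtypes 0.3.0 removed the old name, and JAX v0.4.12 or newer should work. This is a hedged sketch. The helper names are hypothetical, the 0.4.12 cutoff comes from the comment above rather than from testing every jaxlib build, and the version parsing is deliberately simplistic (it keeps only the leading dotted integers).

```python
import re
from importlib.metadata import PackageNotFoundError, version


def release_tuple(v):
    """Turn '0.4.10+cuda11.cudnn86' into (0, 4, 10) by keeping the leading dotted integers."""
    match = re.match(r"\d+(?:\.\d+)*", v)
    if match is None:
        raise ValueError(f"unrecognized version string: {v!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def hits_e4m3b11_error(jax_version, ml_dtypes_version):
    """True if this jax/ml_dtypes pair is expected to fail on float8_e4m3b11.

    Assumption taken from this thread: ml_dtypes >= 0.3.0 lacks the old name,
    and JAX releases before 0.4.12 still reference it.
    """
    return (release_tuple(ml_dtypes_version) >= (0, 3, 0)
            and release_tuple(jax_version) < (0, 4, 12))


if __name__ == "__main__":
    try:
        jax_v, ml_dtypes_v = version("jax"), version("ml_dtypes")
    except PackageNotFoundError as missing:
        print(f"package not installed: {missing}")
    else:
        if hits_e4m3b11_error(jax_v, ml_dtypes_v):
            print(f"jax {jax_v} + ml_dtypes {ml_dtypes_v}: upgrade JAX, "
                  "or run: pip install ml_dtypes==0.2.0")
        else:
            print(f"jax {jax_v} + ml_dtypes {ml_dtypes_v}: no known conflict")
```

Comparing plain integer tuples avoids a dependency on a version-parsing library. For strict PEP 440 handling (pre-releases, epochs), a dedicated parser would be more robust.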

@jakevdp jakevdp closed this as completed May 8, 2024
@jakevdp jakevdp self-assigned this May 8, 2024
7 participants