@runame
Contributor

runame commented Sep 20, 2023

Resolves #436.

@pomonam Could you test this setup on your university's Slurm cluster?

@runame runame requested a review from a team as a code owner September 20, 2023 10:44
@github-actions

github-actions bot commented Sep 20, 2023

MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅

@runame
Contributor Author

runame commented Sep 20, 2023

The changes to the Dockerfile seem to break the tests. Locally building a docker container from the modified Dockerfile worked just fine. @priyakasimbeg Any idea what's going on?

@priyakasimbeg
Contributor

priyakasimbeg commented Sep 20, 2023

> The changes to the Dockerfile seem to break the tests. Locally building a docker container from the modified Dockerfile worked just fine. @priyakasimbeg Any idea what's going on?

@runame that is strange, these tests don't use Docker. It seems like the failure is coming from the flax import:

Traceback (most recent call last):
  File "/home/runner/work/algorithmic-efficiency/algorithmic-efficiency/tests/reference_algorithm_tests.py", line 32, in <module>
    import flax
  File "/opt/hostedtoolcache/Python/3.9.18/x64/lib/python3.9/site-packages/flax/__init__.py", line 18, in <module>
    from .configurations import (
  File "/opt/hostedtoolcache/Python/3.9.18/x64/lib/python3.9/site-packages/flax/configurations.py", line 25, in <module>
    from jax import config as jax_config
  File "/opt/hostedtoolcache/Python/3.9.18/x64/lib/python3.9/site-packages/jax/__init__.py", line 35, in <module>
    from jax import config as _config_module
  File "/opt/hostedtoolcache/Python/3.9.18/x64/lib/python3.9/site-packages/jax/config.py", line 17, in <module>
    from jax._src.config import config  # noqa: F401
  File "/opt/hostedtoolcache/Python/3.9.18/x64/lib/python3.9/site-packages/jax/_src/config.py", line 24, in <module>
    from jax._src import lib
  File "/opt/hostedtoolcache/Python/3.9.18/x64/lib/python3.9/site-packages/jax/_src/lib/__init__.py", line 92, in <module>
    import jaxlib.xla_client as xla_client
  File "/opt/hostedtoolcache/Python/3.9.18/x64/lib/python3.9/site-packages/jaxlib/xla_client.py", line 225, in <module>
    float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11
AttributeError: module 'ml_dtypes' has no attribute 'float8_e4m3b11'

These small tests run on GitHub-hosted runners with CPUs only. I suspect a transitive dependency has not been pinned. I'm also seeing this failure in my other PRs, so let's not block this PR on those tests.
We should just merge this once Juhan confirms the above procedure works for him too.
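
To check whether an environment is affected, a minimal sketch along the following lines could help. The traceback shows `jaxlib/xla_client.py` reading `ml_dtypes.float8_e4m3b11` at import time, so it is enough to test for that attribute without importing `jax` or `flax`. The helper name `check_attribute` is hypothetical and not part of this repository; it only relies on the standard library.

```python
import importlib
from importlib import metadata


def check_attribute(module_name: str, attribute: str) -> str:
    """Report whether `module_name` exposes `attribute`, plus its installed version.

    Hypothetical helper for illustration only; not part of algorithmic-efficiency.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return f"{module_name} is not installed"
    try:
        # Look up the installed distribution version, if one is registered.
        version = metadata.version(module_name)
    except metadata.PackageNotFoundError:
        version = "unknown"
    status = "present" if hasattr(module, attribute) else "MISSING"
    return f"{module_name} {version}: {attribute} is {status}"


if __name__ == "__main__":
    # The attribute name is taken from the last frame of the traceback above.
    print(check_attribute("ml_dtypes", "float8_e4m3b11"))
```

If the attribute is reported as missing, the installed `ml_dtypes` does not match what the installed `jaxlib` expects, which is consistent with the `AttributeError` above.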

@priyakasimbeg
Contributor

priyakasimbeg commented Sep 21, 2023

Regarding the failing tests: I filed jax-ml/jax#17693 with JAX. I think we should just pin ml_dtypes to 0.2.0 in the meantime while they fix their dependencies.

Update: I pinned the ml_dtypes package in #514. If you sync this branch to dev, the tests should pass.

@runame
Contributor Author

runame commented Sep 21, 2023

> Update: I pinned the ml_dtypes package in #514. If you sync this branch to dev, the tests should pass.

Cool, thanks! The tests pass now.

@priyakasimbeg priyakasimbeg merged commit 20376e0 into mlcommons:dev Sep 28, 2023
@github-actions github-actions bot locked and limited conversation to collaborators Sep 28, 2023