ImportError: cannot import name 'partial' from 'jax.util' #60

Closed
chriscarmona opened this issue Oct 15, 2021 · 1 comment · Fixed by #61

chriscarmona commented Oct 15, 2021

Hi,
I am getting an ImportError just from importing the distrax module. I installed the latest version with pip in a clean virtual environment:

pip install -U distrax

and then simply

>>> import distrax
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/chris/.virtualenvs/foo/lib/python3.9/site-packages/distrax/__init__.py", line 18, in <module>
    from distrax._src.bijectors.bijector import Bijector
  File "/Users/chris/.virtualenvs/foo/lib/python3.9/site-packages/distrax/_src/bijectors/bijector.py", line 26, in <module>
    tfb = tfp.bijectors
  File "/Users/chris/.virtualenvs/foo/lib/python3.9/site-packages/tensorflow_probability/python/internal/lazy_loader.py", line 57, in __getattr__
    module = self._load()
  File "/Users/chris/.virtualenvs/foo/lib/python3.9/site-packages/tensorflow_probability/python/internal/lazy_loader.py", line 44, in _load
    module = importlib.import_module(self.__name__)
  File "/usr/local/Cellar/python@3.9/3.9.7_1/Frameworks/Python.framework/Versions/3.9/lib/python3.9/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "/Users/chris/.virtualenvs/foo/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/__init__.py", line 45, in <module>
    from tensorflow_probability.substrates.jax import bijectors
  File "/Users/chris/.virtualenvs/foo/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/__init__.py", line 23, in <module>
    from tensorflow_probability.substrates.jax.bijectors.absolute_value import AbsoluteValue
  File "/Users/chris/.virtualenvs/foo/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/bijectors/absolute_value.py", line 21, in <module>
    from tensorflow_probability.python.internal.backend.jax.compat import v2 as tf
  File "/Users/chris/.virtualenvs/foo/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/__init__.py", line 23, in <module>
    from tensorflow_probability.python.internal.backend.jax import compat
  File "/Users/chris/.virtualenvs/foo/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/compat.py", line 21, in <module>
    from tensorflow_probability.python.internal.backend.jax import v1
  File "/Users/chris/.virtualenvs/foo/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/v1.py", line 34, in <module>
    from tensorflow_probability.python.internal.backend.jax.random_generators import set_seed
  File "/Users/chris/.virtualenvs/foo/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/random_generators.py", line 167, in <module>
    from jax.util import partial  # pylint: disable=g-import-not-at-top
ImportError: cannot import name 'partial' from 'jax.util' (/Users/chris/.virtualenvs/foo/lib/python3.9/site-packages/jax/util.py)

I suspect a version conflict: the installed jax no longer provides partial in jax.util, but the JAX substrate of tensorflow_probability still imports it from there.

Best,
Chris
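To check the suspected version conflict in a given environment, a small script can report the installed versions and try the failing import directly. This is a sketch, not part of distrax or tensorflow_probability. The helper names installed_version and jax_util_has_partial are hypothetical. The script also assumes that jax.util.partial was only a re-export of the standard-library functools.partial, so code you control can import it from functools instead.

```python
from functools import partial  # stdlib replacement for the old jax.util re-export (assumption)
from importlib import metadata


def installed_version(dist_name: str) -> str:
    """Return the installed version of a distribution, or 'not installed'."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return "not installed"


def jax_util_has_partial() -> bool:
    """Check whether `from jax.util import partial` still succeeds."""
    try:
        from jax.util import partial as _jax_partial  # noqa: F401
    except ImportError:  # also catches ModuleNotFoundError when JAX is absent
        return False
    return True


if __name__ == "__main__":
    # Print the versions that matter for this traceback.
    for name in ("jax", "jaxlib", "tensorflow-probability", "distrax"):
        print(f"{name}: {installed_version(name)}")
    print("jax.util.partial available:", jax_util_has_partial())
```

If the last line prints False while tensorflow-probability is installed, the traceback above is expected. A common workaround is to install a jax/tensorflow-probability pair that was released together, but check each project's release notes for the exact versions.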

chriscarmona (Author) commented

Thanks for the quick fix!
