ImportError: cannot import name 'Array' from 'jaxtyping' #70

Open
danbider opened this issue Mar 20, 2023 · 8 comments
Labels
question User queries

Comments

@danbider

import jaxtyping
print(jaxtyping.__version__) # prints 0.2.14
# Import both the annotation and the `jaxtyped` decorator from `jaxtyping`
from jaxtyping import Array, Float32, jaxtyped

raises:

ImportError: cannot import name 'Array' from 'jaxtyping' (/home/jovyan/conda/lib/python3.8/site-packages/jaxtyping/__init__.py)
@patrick-kidger
Owner

You need to have JAX installed as well.

jaxtyping has JAX only as an optional dependency, so that it can also be used with PyTorch etc.
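A minimal sketch of the optional-dependency situation, assuming (as the traceback in the first comment suggests) that `from jaxtyping import Array` fails when JAX is absent. It uses `importlib.util.find_spec` to check for JAX without importing it; the helper name `installed` and the `None` placeholder are hypothetical, chosen only for illustration.

```python
import importlib.util


def installed(module_name: str) -> bool:
    """Return True if a top-level module can be found, without importing it."""
    return importlib.util.find_spec(module_name) is not None


if installed("jax") and installed("jaxtyping"):
    # With JAX present, jaxtyping can provide its JAX-backed `Array` type.
    from jaxtyping import Array
else:
    # Hypothetical placeholder: no JAX array type is available in this environment.
    Array = None
```

When `Array` is `None`, annotations would use the array type of whichever backend is actually in use (for example `torch.Tensor`).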

@MilesCranmer

MilesCranmer commented Jun 20, 2023

Ran into this as well with PyTorch. For me the solution, as described at https://docs.kidger.site/jaxtyping/api/array/#array, was to use torch.Tensor in place of jaxtyping.Array, like so:

from torch import Tensor
from jaxtyping import Float32

def f(x: Float32[Tensor, "dim1 dim2"]) -> Float32[Tensor, "dim1 dim2"]:
    return x
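The same substitution should carry over to the other array libraries jaxtyping supports. A hedged sketch with NumPy, assuming both `numpy` and `jaxtyping` are installed (the block skips the definition otherwise); `double` is a hypothetical example function. The annotation is not enforced at runtime unless a runtime type checker is attached; here it only documents the expected dtype and shape.

```python
import importlib.util

# Only define the example if both optional dependencies are present.
HAVE_DEPS = all(
    importlib.util.find_spec(name) is not None for name in ("numpy", "jaxtyping")
)

if HAVE_DEPS:
    import numpy as np
    from jaxtyping import Float32

    # Same pattern as the torch example above, with np.ndarray as the array type.
    def double(x: Float32[np.ndarray, "dim1 dim2"]) -> Float32[np.ndarray, "dim1 dim2"]:
        return 2 * x
```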

@danbider
Author

@MilesCranmer thanks. @patrick-kidger, was the JAX requirement relaxed? I don't see it in pyproject.toml anymore.
If so, I'll modify my code to use the syntax Miles suggested.

@patrick-kidger
Owner

Sorry, missed this question. Yes, jaxtyping no longer depends on JAX. The name is now for historical reasons only! The syntax Miles is using is correct.
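To check the pyproject.toml question against a specific installation, one can read the requirements the installed distribution declares. A hedged sketch using the standard library's `importlib.metadata`; the function names are hypothetical, and the marker handling is a simplification (any requirement whose marker mentions `extra` is treated as optional) rather than a full PEP 508 parser.

```python
import re
from importlib import metadata
from typing import Iterable, Optional, Set


def hard_dependencies(requirements: Iterable[str]) -> Set[str]:
    """Normalized names of requirements that are not gated behind an `extra`."""
    names = set()
    for req in requirements:
        spec, _, marker = req.partition(";")
        if "extra" in marker:
            continue  # optional: only installed via e.g. `pip install pkg[extra]`
        # Cut the name at the first space, version operator, extras bracket or paren.
        name = re.split(r"[\s<>=!~\[(]", spec.strip(), maxsplit=1)[0]
        # Rough PEP 503-style normalization: "typing_extensions" -> "typing-extensions".
        names.add(re.sub(r"[-_.]+", "-", name).lower())
    return names


def installed_hard_dependencies(dist: str) -> Optional[Set[str]]:
    """Hard dependencies of an installed distribution, or None if it is not installed."""
    try:
        return hard_dependencies(metadata.requires(dist) or [])
    except metadata.PackageNotFoundError:
        return None
```

For example, `deps = installed_hard_dependencies("jaxtyping")` followed by `deps is not None and "jax" in deps` answers the question for whichever jaxtyping version is installed locally.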

patrick-kidger added the question (User queries) label Aug 21, 2023
@pbsds

pbsds commented Oct 16, 2023

When authoring ML-runtime-agnostic tooling, such as a dataset, what is the correct array type to use? I cannot assume that torch, jax, or tensorflow is installed. I currently assume at least numpy and do the following, but it might not work for other use cases:

from typing import Union, TYPE_CHECKING
from jaxtyping import Float, Bool
if TYPE_CHECKING:
    from torch import Tensor
    from numpy import ndarray
    from jaxtyping import Array as JaxArray
    # TODO: tensorflow
    Array = Union[Tensor, ndarray, JaxArray]
else:
    from numpy import ndarray as Array

@patrick-kidger
Owner

Probably something like this:

from typing import Union, TYPE_CHECKING
if TYPE_CHECKING:
    from torch import Tensor as TorchTensor
    from numpy import ndarray
    from jaxtyping import Array as JaxArray
    from tensorflow import Tensor as TfTensor
    Array = Union[TorchTensor, ndarray, JaxArray, TfTensor]
else:
    arrays = []
    try:
        from torch import Tensor as TorchTensor
    except Exception:
        pass
    else:
        arrays.append(TorchTensor)
    try:
        from numpy import ndarray
    except Exception:
        pass
    else:
        arrays.append(ndarray)
    try:
        from jaxtyping import Array as JaxArray
    except Exception:
        pass
    else:
        arrays.append(JaxArray)
    try:
        from tensorflow import Tensor as TfTensor
    except Exception:
        pass
    else:
        arrays.append(TfTensor)
    Array = Union[tuple(arrays)]
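The four try/except blocks in the `else` branch follow one pattern, so they can be folded into a loop. A sketch assuming each backend is described by a `(module, attribute)` pair; `collect_array_types`, `make_array_union` and `DEFAULT_CANDIDATES` are hypothetical names, not part of jaxtyping. It only replaces the runtime branch: static type checkers generally cannot evaluate a `Union` assembled in a loop, so the `TYPE_CHECKING` branch still needs its explicit imports.

```python
import importlib
from typing import Iterable, Tuple, Union

# Hypothetical (module, attribute) pairs mirroring the backends in the snippet above.
DEFAULT_CANDIDATES = (
    ("torch", "Tensor"),
    ("numpy", "ndarray"),
    ("jaxtyping", "Array"),
    ("tensorflow", "Tensor"),
)


def collect_array_types(candidates: Iterable[Tuple[str, str]]) -> Tuple[type, ...]:
    """Return the array types whose modules import cleanly, in candidate order."""
    found = []
    for module_name, attr in candidates:
        try:
            found.append(getattr(importlib.import_module(module_name), attr))
        except Exception:
            # Broad on purpose, as in the snippet above: a backend may be missing
            # (ImportError), lack the attribute (AttributeError), or fail otherwise.
            continue
    return tuple(found)


def make_array_union(candidates: Iterable[Tuple[str, str]] = DEFAULT_CANDIDATES):
    """Build Union[...] over the array types that are available at runtime."""
    types = collect_array_types(candidates)
    if not types:
        # typing.Union rejects an empty tuple, so fail with a clearer message.
        raise RuntimeError("none of the candidate array libraries could be imported")
    return Union[types]
```

With this, the runtime branch reduces to `Array = make_array_union()`, which picks up whichever backends are installed and raises `RuntimeError` if none is.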

@pbsds

pbsds commented Oct 16, 2023

Neat! I'd go for except (ModuleNotFoundError, ImportError): 😉

And it doesn't exactly roll off the tongue. Any chance this could be added to jaxtyping?

@patrick-kidger
Owner

patrick-kidger commented Oct 16, 2023

Actually, the more general Exception is deliberate. There are cases where try-importing a module can fail with errors other than ImportError; cf. https://github.com/google/jaxtyping/blob/7a84b27da9e57c425ce4e6333121c3cdf2e07302/jaxtyping/_array_types.py#L33-L39
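A self-contained sketch of why catching only import errors can be too narrow: it writes a hypothetical module `broken_backend_demo` into a temporary directory that raises `RuntimeError` when imported, standing in for a library that is installed but fails at import time. The helper `try_import` is likewise hypothetical, and the specific failure modes handled in the linked jaxtyping source are not reproduced here.

```python
import importlib
import sys
import tempfile
from pathlib import Path
from types import ModuleType
from typing import Optional, Tuple, Type


def try_import(name: str, catch: Tuple[Type[BaseException], ...]) -> Optional[ModuleType]:
    """Import `name`, returning None if the import raises one of `catch`."""
    try:
        return importlib.import_module(name)
    except catch:
        return None


with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical stand-in for a backend that is installed but broken.
    Path(tmp, "broken_backend_demo.py").write_text("raise RuntimeError('incompatible build')\n")
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        # `except Exception` swallows the failure and reports the backend as unavailable.
        broad_result = try_import("broken_backend_demo", (Exception,))
        # Catching only import errors lets the RuntimeError escape.
        try:
            try_import("broken_backend_demo", (ModuleNotFoundError, ImportError))
            narrow_error = None
        except RuntimeError as err:
            narrow_error = err
    finally:
        sys.path.remove(tmp)
```

Since `ModuleNotFoundError` is a subclass of `ImportError`, the tuple `(ModuleNotFoundError, ImportError)` catches exactly what `ImportError` alone does.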

As for adding the above to jaxtyping: jaxtyping tries to be essentially backend-agnostic. In particular, I don't think I'd want to hardcode that it looks specifically for torch+numpy+tensorflow+jax and nothing else. As such, I think something like this is out of scope for jaxtyping, I'm afraid.

4 participants