ImportError: cannot import name 'Array' from 'jaxtyping' #70
You need to have JAX installed as well. jaxtyping has JAX only as an optional dependency, so that it can also be used with PyTorch etc.
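One way to find out which package is missing before importing `jaxtyping.Array` is to ask `importlib` whether each module can be located, without actually importing it. This is a minimal sketch; the helper name `missing_modules` is hypothetical:

```python
import importlib.util


def missing_modules(*names: str) -> list[str]:
    """Return the top-level module names that cannot be found on this interpreter."""
    # find_spec returns None (rather than raising) for a missing top-level module.
    return [name for name in names if importlib.util.find_spec(name) is None]


# jaxtyping.Array is only importable when JAX itself is installed, so check both.
missing = missing_modules("jax", "jaxtyping")
if missing:
    print(f"Install {', '.join(missing)} to use jaxtyping.Array")
else:
    from jaxtyping import Array
```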
Ran into this as well with PyTorch. For me, the solution (as described at https://docs.kidger.site/jaxtyping/api/array/#array) was to use:

```python
from torch import Tensor
from jaxtyping import Float32

def f(x: Float32[Tensor, "dim1 dim2"]) -> Float32[Tensor, "dim1 dim2"]:
    return x
```
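The `"dim1 dim2"` string names each axis, and a name that appears more than once must refer to the same size everywhere it is used (here: the input and the return value). On their own, such annotations are not enforced at runtime; jaxtyping pairs them with a runtime type checker. As a rough pure-Python illustration of the consistency rule only (the helper `bind_dims` is hypothetical and is not how jaxtyping is implemented):

```python
def bind_dims(spec: str, shape: tuple[int, ...], memo: dict[str, int]) -> None:
    """Check `shape` against a space-separated axis spec such as "dim1 dim2".

    `memo` records the size first seen for each axis name, so later checks
    that reuse a name must match it.
    """
    names = spec.split()
    if len(names) != len(shape):
        raise TypeError(f"expected {len(names)} axes ({spec!r}), got shape {shape}")
    for name, size in zip(names, shape):
        bound = memo.setdefault(name, size)  # first occurrence binds the size
        if bound != size:
            raise TypeError(f"axis {name!r} was bound to {bound}, got {size}")


# Input and output of f share one memo, mirroring the shared "dim1 dim2" spec.
memo: dict[str, int] = {}
bind_dims("dim1 dim2", (3, 4), memo)  # input binds dim1=3, dim2=4
bind_dims("dim1 dim2", (3, 4), memo)  # output with the same shape passes
```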
@MilesCranmer thanks. @patrick-kidger the …
Sorry, missed this question. Yes, jaxtyping no longer depends on JAX. The name is now for historical reasons only! The syntax Miles is using is correct.
When authoring ML-runtime-agnostic tooling, such as a dataset, what is the correct array type to use? I cannot assume that any of torch, jax or tensorflow is installed. I currently assume at least numpy and do the following, but it might not work for other use cases:

```python
from typing import Union, TYPE_CHECKING
from jaxtyping import Float, Bool

if TYPE_CHECKING:
    from torch import Tensor
    from numpy import ndarray
    from jaxtyping import Array as JaxArray
    # TODO: tensorflow
    Array = Union[Tensor, ndarray, JaxArray]
else:
    from numpy import ndarray as Array
```
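To see what the `TYPE_CHECKING` split does in practice, here is a small sketch with standard-library stand-ins (assumed for illustration only: `Decimal` plays the role of an optional backend's array type, `float` plays the role of `ndarray`). The alias a type checker sees and the alias the interpreter binds differ, so runtime checks such as `isinstance` only recognise the runtime fallback:

```python
from decimal import Decimal
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Only static checkers (mypy, pyright, ...) evaluate this branch.
    Number = Union[int, float, Decimal]
else:
    # This is what the interpreter actually binds at runtime.
    Number = float


def describe(x: Number) -> str:
    # A type checker accepts Decimal here, but at runtime Number is just float.
    return "runtime match" if isinstance(x, Number) else "no runtime match"
```

This is why the numpy-only fallback can be fine for annotations yet surprising for code that inspects types at runtime.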
Probably something like this:

```python
from typing import Union, TYPE_CHECKING

if TYPE_CHECKING:
    # Static checkers see every backend, whether or not it is installed.
    from torch import Tensor as TorchTensor
    from numpy import ndarray
    from jaxtyping import Array as JaxArray
    from tensorflow import Tensor as TfTensor

    Array = Union[TorchTensor, ndarray, JaxArray, TfTensor]
else:
    # At runtime, include only the backends that can actually be imported.
    arrays = []
    try:
        from torch import Tensor as TorchTensor
    except Exception:
        pass
    else:
        arrays.append(TorchTensor)
    try:
        from numpy import ndarray
    except Exception:
        pass
    else:
        arrays.append(ndarray)
    try:
        from jaxtyping import Array as JaxArray
    except Exception:
        pass
    else:
        arrays.append(JaxArray)
    try:
        from tensorflow import Tensor as TfTensor
    except Exception:
        pass
    else:
        arrays.append(TfTensor)
    # Union[()] raises TypeError, so at least one backend must be installed.
    Array = Union[tuple(arrays)]
```
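The four near-identical `try`/`except` blocks can also be driven by a table of `(module, attribute)` pairs. The sketch below is one possible variant, not something jaxtyping provides: the names `optional_types`, `ARRAY_BACKENDS` and `is_array` are hypothetical, and `jax.Array` is listed on the assumption that it is the type `jaxtyping.Array` refers to. It returns a plain tuple, which `isinstance` accepts directly and which, unlike `Union[()]`, also works when no backend is installed:

```python
import importlib
from typing import Any


def optional_types(candidates: list[tuple[str, str]]) -> tuple[type, ...]:
    """Import each (module, attribute) pair, skipping backends that are unavailable."""
    found: list[type] = []
    for module_name, attr in candidates:
        try:
            module = importlib.import_module(module_name)
        except Exception:  # broad on purpose, matching the version above
            continue
        obj = getattr(module, attr, None)
        if isinstance(obj, type):  # ignore missing or non-class attributes
            found.append(obj)
    return tuple(found)


ARRAY_BACKENDS = [
    ("torch", "Tensor"),
    ("numpy", "ndarray"),
    ("jax", "Array"),
    ("tensorflow", "Tensor"),
]
array_types = optional_types(ARRAY_BACKENDS)


def is_array(x: Any) -> bool:
    # isinstance with an empty tuple simply returns False.
    return isinstance(x, array_types)
```

Adding a backend is then a one-line change to the table rather than another `try` block.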
Neat! I'd go for … And it doesn't exactly roll off the tongue. Any chance this could be added to …?
Actually, the more general … As for adding the above to jaxtyping: jaxtyping tries to be essentially backend-agnostic. In particular, I don't want to hardcode that it looks specifically for torch+numpy+tensorflow+jax and nothing else. As such, I think something like this is out of scope for jaxtyping, I'm afraid.