Wrong error message when using jaxtyping with equinox #17

Closed
riven314 opened this issue Sep 13, 2022 · 2 comments

Comments

riven314 commented Sep 13, 2022

Thanks for creating so many amazing libraries in the JAX ecosystem!

As an exploration, I am trying out equinox and want to annotate a loss function with jaxtyping.
The runtime type check does catch the wrong tensor shape of my input, but it flags the wrong argument as the cause of the error.

Here is a snippet of my code to highlight the issue (refer to my colab notebook below for the complete version):

@jaxtyped
@typechecked
@jax.jit
@jax.grad
def loss_fn(
    model: Linear,
    x: Float[Array, "batch in_dim"],
    y: Float[Array, "batch out_dim"]
) -> Linear:
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)

batch_size, in_size, out_size = 32, 2, 3
wrong_x = jax.numpy.zeros((50, in_size))
grads = loss_fn(model, x=wrong_x, y=y)

The error message should say that argument x is wrong, but here is the message I received:

TypeError                                 Traceback (most recent call last)
<ipython-input-9-6d78a92b3a76> in <module>
      1 wrong_x = jax.numpy.zeros((50, in_size))
----> 2 grads = loss_fn(model, x=wrong_x, y=y)

2 frames
/usr/local/lib/python3.7/dist-packages/typeguard/__init__.py in check_argument_types(memo)
    873                 check_type(description, value, expected_type, memo)
    874             except TypeError as exc:  # suppress unnecessarily long tracebacks
--> 875                 raise TypeError(*exc.args) from None
    876 
    877     return True

TypeError: type of argument "y" must be jaxtyping.array_types.Float[ndarray, 'batch out_dim']; got jaxlib.xla_extension.DeviceArray instead

Package versions
JAX: 0.3.17
equinox: 0.7.1
jaxtyping: 0.2.0
typeguard (used for runtime type checking)

Simple colab notebook for reproducing the bug
https://colab.research.google.com/drive/10rSs6IhNmU7lvhPxxlU2ext8JkgBLULI?usp=sharing

(As a side note, I am not sure whether annotating the model by its class is good practice; feel free to comment.)

riven314 changed the title from "Wrong error message when using with equinox" to "Wrong error message when using jaxtyping with equinox" on Sep 13, 2022

riven314 (Author) commented

On second thought, this is not a bug.
If x has shape (50, in_size), then batch is bound to 50.
The error just means that x and y have mismatched batch sizes, so it is correct to say that either x or y is wrong.
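
To make the reasoning above concrete, here is a minimal pure-Python sketch of how named dimensions could be bound across arguments. It is a toy model, not jaxtyping's actual implementation: the function `check_shapes`, its error wording, and the y shape of (32, 3) (taken from batch_size and out_size in the snippet) are all assumptions made for illustration. Arguments are checked in order, the first use of a name such as `batch` fixes its size, and a later argument that disagrees is the one that gets reported.

```python
def check_shapes(shapes, annotations):
    """Toy model of named-dimension checking (hypothetical, not jaxtyping's code)."""
    bound = {}  # dimension name -> size, shared by every argument of one call
    for name, shape in shapes.items():  # dicts keep insertion order: x before y
        dims = annotations[name].split()
        if len(dims) != len(shape):
            raise TypeError(f'argument "{name}" has the wrong number of dimensions')
        for dim, size in zip(dims, shape):
            # The first argument to use a name fixes its size for the whole call.
            if bound.setdefault(dim, size) != size:
                raise TypeError(
                    f'argument "{name}": {dim}={size}, '
                    f"but {dim} was already bound to {bound[dim]}"
                )
    return bound


annotations = {"x": "batch in_dim", "y": "batch out_dim"}

# Matching batch sizes: every name binds consistently.
ok = check_shapes({"x": (32, 2), "y": (32, 3)}, annotations)

# x arrives first and binds batch=50, so the mismatch is reported against y.
try:
    check_shapes({"x": (50, 2), "y": (32, 3)}, annotations)
except TypeError as exc:
    message = str(exc)
```

In this sketch the checker cannot know which of the two batch sizes was intended, so it reports the argument at which the inconsistency first becomes visible.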

patrick-kidger (Owner) commented

Yep, exactly!

Regarding your decorators, by the way: it's slightly better to decorate it as

@jax.jit
@jax.grad
@jaxtyped
@typechecked
def foo(...)

So that the overhead of runtime type checking only happens once, when the function is jit-compiled. (At the moment you're adding the type-checking after the JIT so it will happen every time you call the function.)
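
The difference between the two orderings can be sketched with a counter. This is only an illustration under stated assumptions: `counting_check` is a hypothetical stand-in for `@jaxtyped`/`@typechecked` that records how often its Python-level wrapper runs, and `mse` is a made-up loss. It relies on the fact that `jax.jit` runs the wrapped Python function only while tracing and reuses the compiled result for later calls with the same input shapes and dtypes.

```python
import functools

import jax
import jax.numpy as jnp

calls = {"inside_jit": 0, "outside_jit": 0}


def counting_check(label):
    """Hypothetical stand-in for a runtime type checker: counts wrapper executions."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            calls[label] += 1  # Python-level work, like a runtime type check
            return fn(*args, **kwargs)
        return wrapper
    return decorator


def mse(pred, target):
    return jnp.mean((target - pred) ** 2)


# Check wrapped by jit: the wrapper runs only while the function is traced.
checked_then_jitted = jax.jit(counting_check("inside_jit")(mse))
# Check wrapping jit: the wrapper runs on every call.
jitted_then_checked = counting_check("outside_jit")(jax.jit(mse))

pred = jnp.zeros((32, 3))
target = jnp.ones((32, 3))
for _ in range(3):
    checked_then_jitted(pred, target)
    jitted_then_checked(pred, target)
```

After the loop, `calls["inside_jit"]` is 1 and `calls["outside_jit"]` is 3. Calling the jitted function with a new input shape triggers a retrace, so the inner check would run again for that shape. Note also that with the check innermost, the annotations describe the undifferentiated function, so a return annotation written for the gradient would presumably need to describe the scalar loss instead.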
