Thanks for creating so many amazing libraries in the JAX ecosystem!
As an exploration, I am trying out equinox and want to annotate a loss function with jaxtyping.
While the runtime type check does catch the wrong tensor shape of my input, it flags the wrong argument as the cause of the error.
Here is a snippet of my code to highlight the issue (refer to my colab notebook below for the complete version):
On second thought, this is not a bug.
If x has shape (50, in_size), then batch is bound to 50.
The error just means that x and y have mismatched batch sizes, so it is correct to say that either x or y is wrong.
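To make that reasoning concrete, here is a toy sketch in plain Python of how a checker that binds named dimensions in argument order ends up blaming y. It is not jaxtyping's actual implementation: check_shapes, the dimension name out_size, and the concrete sizes (2, 32, 1) are all made up for illustration.

```python
def check_shapes(annotations, shapes):
    """Bind named dimensions in argument order.

    Returns the name of the first argument whose shape conflicts with an
    already-bound dimension, or None if everything is consistent.
    """
    bound = {}  # dimension name -> size, filled in as arguments are checked
    for arg_name, dim_names in annotations.items():
        shape = shapes[arg_name]
        if len(shape) != len(dim_names):
            return arg_name
        for dim_name, size in zip(dim_names, shape):
            # The first argument that mentions a name fixes its size;
            # later arguments must agree with it.
            if bound.setdefault(dim_name, size) != size:
                return arg_name
    return None


# Hypothetical annotations mirroring a loss function taking x and y.
annotations = {"x": ("batch", "in_size"), "y": ("batch", "out_size")}

# x is checked first and binds batch=50, so the mismatch surfaces at y.
flagged = check_shapes(annotations, {"x": (50, 2), "y": (32, 1)})
print(flagged)
```

Under this model neither argument is wrong by itself. The checker can only report the first argument that disagrees with what was already bound.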
Regarding your decorators, by the way: it's slightly better to decorate it as
@jax.jit
@jax.grad
@jaxtyped
@typechecked
def foo(...)
So that the overhead of runtime type checking only happens once, when the function is jit-compiled. (At the moment you're adding the type-checking after the JIT so it will happen every time you call the function.)
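A small sketch of why the order matters, assuming only that Python-level code inside a jitted function runs while JAX traces it. The counting_check decorator is a hypothetical stand-in for the jaxtyped/typechecked pair, not part of jaxtyping or typeguard. It counts how often the "check" runs.

```python
import jax
import jax.numpy as jnp

# How many times each Python-level check actually runs.
check_runs = {"inside_jit": 0, "outside_jit": 0}


def counting_check(label):
    """Hypothetical stand-in for a runtime type checker: counts its own calls."""
    def decorator(fn):
        def wrapper(x):
            check_runs[label] += 1
            if x.ndim != 1:  # works on both concrete arrays and tracers
                raise TypeError("expected a 1-D array")
            return fn(x)
        return wrapper
    return decorator


@jax.jit
@counting_check("inside_jit")   # check is traced along with the function
def f_inside(x):
    return jnp.sum(x ** 2)


@counting_check("outside_jit")  # check wraps the already-jitted function
@jax.jit
def f_outside(x):
    return jnp.sum(x ** 2)


x = jnp.ones(3)
for _ in range(5):
    f_inside(x)
    f_outside(x)

print(check_runs)
```

With the check inside jax.jit it runs once, when the function is traced for this input shape and dtype. With the check outside it runs on all five calls. A new input shape or dtype would trigger a retrace, and the inner check would run again.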
The error message should report that argument x is wrong, but here is the message I received:
Package
JAX version: 0.3.17
eqx version: 0.7.1
jaxtyping: 0.2.0
typeguard as the runtime type checker
Simple colab notebook for reproducing the bug
https://colab.research.google.com/drive/10rSs6IhNmU7lvhPxxlU2ext8JkgBLULI?usp=sharing
(As a side note, I am not sure whether annotating the model by its class is good practice; feel free to comment.)