Possible bug when using `torch.nn.Module` and `@jaxtyping` + `@typechecker` #71
Comments
So for your downstream use-case with …
That'll give some insight on how we might fix this.
They both work.
Great. Can you try installing from the …
I had just come across what is probably the same issue in different clothes when trying to subclass tensorflow_probability's …
Closing as I think this issue is resolved.
I could be missing something, but I think something problematic may be happening when decorating functions with `@jaxtyping` and `@typechecker`.

A minimal code example shows that when using the decorators `@jaxtyping` and `@typechecker`, where `typechecker` is `beartype`, `forward()` is no longer a `types.MethodType`. This causes bugs when trying to train a `torch.compile(model)`, because `torch.compile` needs to assert that `model.forward` is of type `types.MethodType`.