Possible bug when using torch.nn.Module and @jaxtyping + @typechecker #71

Closed
PaulScemama opened this issue Mar 27, 2023 · 5 comments

@PaulScemama

I could be missing something, but I think something problematic may be happening when decorating methods with @jaxtyped and @typechecker.

Below is a minimal example showing that when forward() is decorated with @jaxtyped and @typechecker (where typechecker is beartype), model.forward is no longer a types.MethodType. This causes errors when training a torch.compile(model), because torch.compile asserts that model.forward is a types.MethodType.

import torch
import types
import torch.nn as nn
from jaxtyping import jaxtyped, Float
from beartype import beartype as typechecker

class model_without_jaxtyping(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x
    
model_without_jaxtyping_instance = model_without_jaxtyping()
print(type(model_without_jaxtyping_instance.forward) == types.MethodType)
# 'True'

class model_with_jaxtyping(nn.Module):
    def __init__(self):
        super().__init__()

    @jaxtyped
    @typechecker
    def forward(self, x: Float[torch.Tensor, "..."]):
        return x
    
model_with_jaxtyping_instance = model_with_jaxtyping()
print(type(model_with_jaxtyping_instance.forward) == types.MethodType)
# 'False'

@patrick-kidger
Owner

So jaxtyped actually happens to return a class instance, rather than a function. That's the immediate reason for what we're observing here.
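The underlying Python mechanism can be illustrated on its own. Plain functions implement the descriptor protocol (__get__), and that is what turns model.forward into a bound types.MethodType. An ordinary callable class instance used as a decorator does not bind this way unless it defines __get__ itself. The sketch below is a minimal illustration of that general behaviour: CallableWrapper, BindingWrapper and function_wrapper are hypothetical names and are not jaxtyping's actual internals.

```python
import functools
import types


class CallableWrapper:
    """Hypothetical stand-in for a decorator that returns a class instance."""

    def __init__(self, fn):
        functools.update_wrapper(self, fn)
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


class BindingWrapper(CallableWrapper):
    """Same wrapper, but also implementing the descriptor protocol."""

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        # Bind the way a plain function would: produce a real bound method.
        return types.MethodType(self, instance)


def function_wrapper(fn):
    """A decorator that returns a plain function instead of an instance."""

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapped


class Model:
    @CallableWrapper
    def forward_instance(self, x):
        return x

    @BindingWrapper
    def forward_bound(self, x):
        return x

    @function_wrapper
    def forward_function(self, x):
        return x


m = Model()
print(type(m.forward_instance) == types.MethodType)  # False
print(type(m.forward_bound) == types.MethodType)     # True
print(type(m.forward_function) == types.MethodType)  # True
```

In this sketch, m.forward_instance(3) would fail because self is never passed in; the unbound wrapper would have to be called as m.forward_instance(m, 3).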

For your downstream use case with torch.compile, can you check:

  • Whether

    @typechecker
    def forward(...)

    works?

  • And whether

    import functools

    def wrapper(fn):
        # Return a plain function, so that forward still binds as a method.
        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            return jaxtyped(typechecker(fn))(*args, **kwargs)
        return wrapped
    
    class Model(nn.Module):
        @wrapper
        def forward(...): ...

    works?

That'll give some insight into how we might fix this.
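One detail of the second snippet: it calls jaxtyped(typechecker(fn)) inside wrapped, so the decorators are re-applied on every call. A variant that decorates once, at class-definition time, might look like the following sketch. as_function_decorator and CountingWrapper are hypothetical helpers; CountingWrapper stands in for any decorator that returns a class instance, and counts how often it is constructed.

```python
import functools
import types


def as_function_decorator(decorator):
    """Hypothetical helper: make `decorator` always produce a plain function."""

    def outer(fn):
        inner = decorator(fn)  # decorate once, at class-definition time

        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            # `wrapped` is a real function, so it binds as a method.
            return inner(*args, **kwargs)

        return wrapped

    return outer


class CountingWrapper:
    """Hypothetical decorator returning a class instance; counts constructions."""

    constructed = 0

    def __init__(self, fn):
        CountingWrapper.constructed += 1
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


class Model:
    @as_function_decorator(CountingWrapper)
    def forward(self, x):
        return x


m = Model()
m.forward(1)
m.forward(2)
print(type(m.forward) == types.MethodType)  # True
print(CountingWrapper.constructed)          # 1
```

With the decorators from this thread, this would presumably be applied as @as_function_decorator(lambda f: jaxtyped(typechecker(f))), though that exact combination is an assumption.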

@PaulScemama
Author

They both work.

@patrick-kidger
Owner

Great. Can you try installing from the no-more-jaxtyped branch (cf. #72) and see if this works for your downstream use case? If so, I'll merge that PR.

@st--

st-- commented Apr 17, 2023

I had just run into what is probably the same issue in a different guise, when trying to subclass tensorflow_probability's Distribution type. #72 resolved that too!

@patrick-kidger
Owner

Closing as I think this issue is resolved.
