Strange behavior with eqx.Module #100

Closed
knyazer opened this issue Aug 13, 2023 · 4 comments

Comments

@knyazer
Contributor

knyazer commented Aug 13, 2023

I encountered some strange behavior when using install_import_hook on a module that contains a class inheriting from eqx.Module: a warning appears saying that the typechecker cannot wrap the __init__ method. This seems odd to me, since eqx.Module is (promised to be) a dataclass.

Warning message:

/home/user/.local/lib/python3.10/site-packages/jaxtyping/_decorator.py:139: InstrumentationWarning: @typechecked only supports instrumenting functions wrapped with @classmethod, @staticmethod or @property -- not typechecking source.A.__init__
  init = jaxtyped(typechecker(kls.__init__))

Minimal example:

##### main.py #####
from jaxtyping import install_import_hook
with install_import_hook("source", typechecker="typeguard.typechecked"):
    from source import *

##### source.py #####
import equinox as eqx

class A(eqx.Module):
    x: int
    def __init__(self, y: int):
        self.x = y

I suspect there is a good chance I am simply using the hook incorrectly. If so, I would appreciate any general feedback, or a short example of the correct usage. The problem seems to be directly related to issue 46.

@patrick-kidger
Owner

What version of typeguard do you have installed?

@knyazer
Contributor Author

knyazer commented Aug 14, 2023

pip3 freeze | grep -E "typeguard|equinox|jax|jaxtyping"

gives

equinox==0.10.11
jax==0.4.14
jaxlib==0.4.14+cuda12.cudnn89
jaxopt==0.7
jaxtyping==0.2.20
typeguard==4.1.0

I believe 4.1.0 is the latest version of typeguard available on PyPI. I am using Python 3.10.4 on Ubuntu 22.04.
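The same version check can also be done from Python itself, which is handy when several environments are involved. This is a rough sketch using only the standard library's importlib.metadata; the helper name installed_versions is hypothetical and is not part of any of the packages discussed here.

```python
import sys
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Distributions mentioned in this thread (the grep above also matched jaxopt).
PACKAGES = ("typeguard", "equinox", "jax", "jaxlib", "jaxtyping")


def installed_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Hypothetical helper: map each distribution to its installed version, or None."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # Not installed in the current environment.
            result[name] = None
    return result


if __name__ == "__main__":
    print("python", ".".join(map(str, sys.version_info[:3])))
    for name, ver in installed_versions().items():
        print(name, ver if ver is not None else "(not installed)")
```

Unlike grep on pip freeze output, this only reports the exact distribution names asked for, so similarly named packages do not show up.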

@patrick-kidger
Owner

Thanks! Okay, so I think this is a typeguard bug. Despite the warning, the type check still works: for example, A("hi") correctly raises a type error.

For now I'd suggest either ignoring the warning, or switching to typeguard version 2.*, or using beartype instead.
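For the first option, a minimal sketch using only the standard library's warnings module. It silences just the typeguard message quoted in the warning above, matched by its text, so unrelated warnings still appear. The helper name silence_typeguard_instrumentation_warning is hypothetical, and the sketch assumes the message prefix stays as shown for typeguard 4.1.0.

```python
import warnings

# Start of the InstrumentationWarning text quoted in this issue (typeguard 4.1.0).
# `message` is a regular expression matched against the beginning of the warning.
TYPEGUARD_INIT_WARNING = r"@typechecked only supports instrumenting"


def silence_typeguard_instrumentation_warning() -> None:
    """Hypothetical helper: ignore only this typeguard warning, keep all others."""
    warnings.filterwarnings("ignore", message=TYPEGUARD_INIT_WARNING)


# Usage in main.py (assumes jaxtyping and typeguard are installed). The filter
# must be installed *before* the import, since the warning fires at import time:
#
#     from jaxtyping import install_import_hook
#     silence_typeguard_instrumentation_warning()
#     with install_import_hook("source", typechecker="typeguard.typechecked"):
#         from source import *
#
# Using beartype instead only changes the typechecker string:
#
#     with install_import_hook("source", typechecker="beartype.beartype"):
#         from source import *
```

Filtering by message rather than by category avoids importing typeguard just to reference its warning class, at the cost of depending on the exact wording of the message.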

@knyazer
Contributor Author

knyazer commented Aug 14, 2023

Thanks for debugging the problem for me. Yes, even in my original code the warning seems to be spurious, since both jaxtyping array-type mismatches and plain Python type mismatches are raised correctly.
