jaxtyping does not play nicely with inheritance #43

Closed
murphyk opened this issue Nov 13, 2022 · 2 comments

Comments

@murphyk
Contributor

murphyk commented Nov 13, 2022

The code below fails with both typeguard and beartype, raising this error:

    class LinearGaussianConjugateSSM(LinearGaussianSSM):
TypeError: __init__() takes 2 positional arguments but 4 were given

However, if I omit the outer `@jaxtyped` decorator, it works.

```python
from jaxtyping import jaxtyped
# from beartype import beartype as typechecker
from typeguard import typechecked as typechecker


@jaxtyped  # removing this decorator makes the error go away
@typechecker
class LinearGaussianSSM:
    def __init__(
        self,
        state_dim: int,
        emission_dim: int,
        input_dim: int = 0,
        has_dynamics_bias: bool = True,
        has_emissions_bias: bool = True,
    ):
        self.state_dim = state_dim
        self.emission_dim = emission_dim
        self.input_dim = input_dim
        self.has_dynamics_bias = has_dynamics_bias
        self.has_emissions_bias = has_emissions_bias


# The TypeError is raised here, when the class statement itself runs.
@typechecker
class LinearGaussianConjugateSSM(LinearGaussianSSM):
    def __init__(
        self,
        state_dim,
        emission_dim,
        input_dim=0,
        has_dynamics_bias=True,
        has_emissions_bias=True,
    ):
        super().__init__(
            state_dim=state_dim,
            emission_dim=emission_dim,
            input_dim=input_dim,
            has_dynamics_bias=has_dynamics_bias,
            has_emissions_bias=has_emissions_bias,
        )
```

@slinderman

@patrick-kidger
Owner

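Right, this is because `jaxtyped` returns a new object that is different from the underlying class, so `LinearGaussianConjugateSSM` ends up inheriting from that wrapper object rather than from `LinearGaussianSSM` itself.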

(It was only designed to work on functions.)
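The shape of the error message follows from how Python builds a class whose base is not itself a class: with no explicit metaclass, it takes the metaclass from the type of the base and calls it as `metaclass(name, bases, namespace)`. The sketch below reproduces this without jaxtyping. `FunctionWrapper` and `wrapping_decorator` are hypothetical stand-ins, assuming (as the error message suggests) that the object returned by `jaxtyped` was an instance of a class whose `__init__` accepts only the wrapped callable.

```python
class FunctionWrapper:
    """Hypothetical stand-in for a decorator result that wraps, rather than
    returns, the object it was given."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


def wrapping_decorator(cls):
    # Returns a new object: the module-level name now refers to the wrapper.
    return FunctionWrapper(cls)


@wrapping_decorator
class Base:
    def __init__(self, x: int):
        self.x = x


# Instantiation still works, because the wrapper forwards calls to the class.
print(Base(1).x)

message = None
try:
    # Python picks type(Base), i.e. FunctionWrapper, as the metaclass and calls
    # FunctionWrapper("Child", (Base,), namespace): self plus 3 args = 4.
    class Child(Base):
        pass
except TypeError as err:
    message = str(err)

print(message)
```

Calling the decorated name still works, which is why the problem only shows up once something tries to subclass it.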

We could probably extend it to classes in the following way:

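```python
import inspect

def jaxtyped(fn):
    if inspect.isclass(fn):
        # Wrap __init__ and patch it onto the class in place, then return the
        # original class so that it can still be subclassed.
        init = jaxtyped(fn.__init__)
        fn.__init__ = init
        return fn
    else:
        ...  # existing implementation
```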

Could you give that a try and see if it works? If it does, I'd be happy to accept it as a PR.
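Outside jaxtyping, the same pattern can be written as a small generic helper. This is a minimal sketch, not jaxtyping's code: `trace_calls` is a hypothetical function decorator standing in for the existing function-level logic, and `class_aware` is a hypothetical name for the lifting step. The property that matters is that the class branch returns the original class object, so subclassing keeps working.

```python
import functools
import inspect


def trace_calls(fn):
    """Hypothetical function-level decorator: records each call's qualname."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        trace_calls.log.append(fn.__qualname__)
        return fn(*args, **kwargs)
    return wrapper


trace_calls.log = []


def class_aware(decorator):
    """Lift a function decorator so that it can also be applied to classes."""
    @functools.wraps(decorator)
    def apply(obj):
        if inspect.isclass(obj):
            # Patch __init__ in place and return the *same* class object.
            obj.__init__ = decorator(obj.__init__)
            return obj
        return decorator(obj)
    return apply


@class_aware(trace_calls)
class Base:
    def __init__(self, x: int):
        self.x = x


class Child(Base):  # fine: Base is still a real class
    def __init__(self, x: int):
        super().__init__(x)


child = Child(3)
print(type(Base).__name__, child.x, trace_calls.log)
```

Only `Base.__init__` is wrapped here; a subclass that defines its own `__init__` is not decorated unless the decorator is applied to it as well.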

@murphyk
Contributor Author

murphyk commented Nov 13, 2022

Yes, that works!


```python
import inspect

from jaxtyping import jaxtyped
# from beartype import beartype as typechecker
from typeguard import typechecked as typechecker


def jaxtyped2(fn):
    if inspect.isclass(fn):
        # Decorate __init__ in place and return the original class.
        init = jaxtyped(fn.__init__)
        fn.__init__ = init
        return fn
    else:
        # Functions: defer to the existing implementation and return its result.
        return jaxtyped(fn)


@jaxtyped2
@typechecker
class LinearGaussianSSM:
    def __init__(
        self,
        state_dim: int,
        emission_dim: int,
        input_dim: int = 0,
        has_dynamics_bias: bool = True,
        has_emissions_bias: bool = True,
    ):
        self.state_dim = state_dim
        self.emission_dim = emission_dim
        self.input_dim = input_dim
        self.has_dynamics_bias = has_dynamics_bias
        self.has_emissions_bias = has_emissions_bias


@typechecker
class LinearGaussianConjugateSSM(LinearGaussianSSM):
    def __init__(
        self,
        state_dim,
        emission_dim,
        input_dim=0,
        has_dynamics_bias=True,
        has_emissions_bias=True,
    ):
        super().__init__(
            state_dim=state_dim,
            emission_dim=emission_dim,
            input_dim=input_dim,
            has_dynamics_bias=has_dynamics_bias,
            has_emissions_bias=has_emissions_bias,
        )
```
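One detail worth checking in any wrapper like this: the non-class branch must `return` the decorated function. A decorator that calls another decorator but discards its result silently replaces the decorated function with `None`. A minimal, self-contained illustration; all the decorator names here are hypothetical.

```python
import inspect


def noop_decorator(fn):
    # Hypothetical stand-in for a function-level decorator.
    return fn


def missing_return(obj):
    if inspect.isclass(obj):
        return obj
    else:
        noop_decorator(obj)  # result discarded, so the function becomes None


def with_return(obj):
    if inspect.isclass(obj):
        return obj
    return noop_decorator(obj)


@missing_return
def f():
    return 1


@with_return
def g():
    return 1


print(f, g())
```

The class branch is unaffected either way, which is why the problem would only surface once the helper is applied to a plain function.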
