hk.Module should be an abstract Callable #52

Closed
cgarciae opened this issue Jul 4, 2020 · 2 comments

Comments

cgarciae commented Jul 4, 2020

Hey, I use pyright / pylance for type checking, and they are pretty unhappy that hk.Module doesn't define an abstract __call__ method. I get type errors all over the place when writing code that takes arbitrary hk.Modules. Given that most of Haiku is already typed, this would be a nice addition.

tomhennigan (Collaborator) commented

Thanks for the feature request! In Haiku we don't special-case __call__ (or any other method), meaning that module instances don't actually have to be callable. For example, for a VAE you may want a single module that defines def encode(self, x) and def decode(self, z) but not __call__.
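To make the VAE case concrete, here is a minimal, dependency-free sketch. `Module` is a hypothetical stand-in for hk.Module (a real Haiku module has to be constructed inside hk.transform), and `VAE`, `encode` and `decode` use toy arithmetic in place of real layers.

```python
class Module:
    """Hypothetical stand-in for hk.Module: the base class defines no __call__."""

    def __init__(self, name=None):
        self.name = name or type(self).__name__.lower()


class VAE(Module):
    """Exposes encode/decode but deliberately no __call__."""

    def encode(self, x: float) -> float:
        return x / 2.0  # toy "encoder" standing in for a real network

    def decode(self, z: float) -> float:
        return z * 2.0  # toy "decoder", the inverse of the toy encoder


vae = VAE()
print(vae.decode(vae.encode(3.0)))  # 3.0
print(callable(vae))  # False: nothing in the class hierarchy defines __call__
```

The module is perfectly usable through its named methods, yet it is not callable, which is exactly the case an abstract __call__ on the base class would misdescribe.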

It is common for modules to be callable, of course, and we have considered adding def __call__(self, *a, **k) -> Any: raise NotImplementedError to the Module base class (purely for type hints). However, our current thinking is that this is not actually more useful as a type hint than users writing Callable[..., Any], and it might even be harmful: modules that aren't actually callable would pass static analysis. Another downside, IMHO, is that we cannot define a calling convention other than *a, **k -> Any, because users can (and do) do anything with their __call__ method.
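The harm described above can be shown with a small sketch. The `Module` and `Encoder` classes are hypothetical stand-ins, not Haiku code. The base class carries the considered `__call__` stub. A subclass that only defines `encode` then looks callable to a type checker and to `callable()`, but calling it fails at runtime.

```python
from typing import Any


class Module:
    """Hypothetical base class with the considered __call__ stub added."""

    def __call__(self, *a: Any, **k: Any) -> Any:
        # Exists purely for type hints and carries no real calling convention.
        raise NotImplementedError


class Encoder(Module):
    """Only defines encode(); it never overrides __call__."""

    def encode(self, x: float) -> float:
        return x / 2.0


enc = Encoder()
print(callable(enc))  # True: the inherited stub makes it look callable
try:
    enc(1.0)  # a static checker accepts this call...
except NotImplementedError:
    print("NotImplementedError")  # ...but it only fails at runtime
```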

Using Callable instead of hk.Module may have other benefits. For example, in most places where you could pass a callable module you could also pass a function. As a concrete example, our Sequential module takes a list of callables, which means users can pass lambda x: x, a module instance, a JAX function, etc.
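A dependency-free sketch of why accepting plain callables is flexible. `sequential`, `Scale` and `relu` are hypothetical names that mimic the idea of Sequential without Haiku or JAX, using floats in place of arrays. The same list mixes a callable object, a plain function and a lambda.

```python
from typing import Callable, Sequence

Layer = Callable[[float], float]


def sequential(layers: Sequence[Layer]) -> Layer:
    """Hypothetical helper that chains callables, in the spirit of Sequential."""

    def apply(x: float) -> float:
        for layer in layers:
            x = layer(x)
        return x

    return apply


class Scale:
    """A callable 'module' stand-in: instances define __call__."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def __call__(self, x: float) -> float:
        return x * self.factor


def relu(x: float) -> float:
    return max(x, 0.0)  # a plain function works just as well as a module


net = sequential([Scale(2.0), relu, lambda x: x + 1.0])
print(net(3.0))   # 7.0
print(net(-3.0))  # 1.0
```

With Haiku itself, the analogous list would be passed to something like hk.Sequential([hk.Linear(8), jax.nn.relu, lambda x: x]) inside a function given to hk.transform. The stand-in above only avoids that dependency.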

Where possible, we're keen to encourage JAX code to be decoupled from Haiku. We feel that overall this is best for the ecosystem, and it means users will not be locked into a particular way of using JAX by our libraries.

Concretely, if you're thinking about requiring hk.Module in your type signatures, I would suggest requiring Callable[..., Any] instead, or even better, defining the calling convention you require: Callable[[jnp.ndarray], jnp.ndarray].
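The two suggested signatures might look like this in a downstream library. `run_any`, `ArrayFn` and `run_twice` are hypothetical names, and `float` stands in for jnp.ndarray so the sketch runs without JAX installed.

```python
from typing import Any, Callable


# Loosest option: accept any callable at all.
def run_any(fn: Callable[..., Any], *args: Any) -> Any:
    return fn(*args)


# Stricter option: spell out the calling convention you need.
# float stands in for jnp.ndarray to keep this sketch dependency-free.
ArrayFn = Callable[[float], float]


def run_twice(fn: ArrayFn, x: float) -> float:
    """Hypothetical framework entry point: applies fn two times."""
    return fn(fn(x))


print(run_any(max, 1, 5))                 # 5
print(run_twice(lambda x: x * 3.0, 2.0))  # 18.0
```

The stricter alias lets a type checker reject, for example, a two-argument function passed to `run_twice`, which neither Callable[..., Any] nor a base-class `__call__(*a, **k)` stub can do.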

cgarciae (Author) commented Jul 7, 2020

I see. What you are saying makes sense in that context.

On my side, I am creating a framework on top of Haiku, so I have to assume that users will create generic callable Modules. I guess I need to create a type intersection like this: python/typing#213 (comment)

Thanks!
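Python's typing module has no Intersection type, so one common workaround (not necessarily the one in the linked comment) is an intermediate abstract class that is both a Module and requires `__call__`. In this sketch `Module` is a hypothetical stand-in for hk.Module, and `CallableModule`, `Double` and `NotCallable` are illustrative names. With the real hk.Module you would subclass it instead; that its metaclass composes with abc.ABC is an assumption worth checking.

```python
import abc
from typing import Any, Optional


class Module:
    """Hypothetical stand-in for hk.Module."""

    def __init__(self, name: Optional[str] = None) -> None:
        self.name = name or type(self).__name__.lower()


class CallableModule(Module, abc.ABC):
    """A Module that must also be callable: a poor man's intersection type."""

    @abc.abstractmethod
    def __call__(self, x: Any) -> Any:
        """Subclasses must define their forward pass."""


class Double(CallableModule):
    def __call__(self, x: Any) -> Any:
        return x * 2


class NotCallable(CallableModule):
    pass  # forgets __call__, so it cannot be instantiated


print(Double()(21))  # 42
try:
    NotCallable()
except TypeError:
    print("TypeError: abstract __call__ not implemented")
```

Framework code can then annotate parameters as `CallableModule`, and a missing `__call__` is caught both by the type checker and at instantiation time rather than at the first call.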
