Correct way to transform and init a hk.Module with non-default parameter? #218

Closed · IanQS opened this issue Sep 19, 2021 · 6 comments

IanQS commented Sep 19, 2021

Hey all!

I'm trying to run a linear regression example and I've got the following

import jax.numpy as jnp
from sklearn.datasets import load_boston
import haiku as hk
import optax
import jax


X, y = load_boston(return_X_y=True)
train_X = jnp.asarray(X.tolist())
train_y = jnp.asarray(y.tolist())
    
class Model(hk.Module):
    def __init__(self, input_dims):
        super().__init__()
        self.input_dims = input_dims
    
    def __call__(self, X: jnp.ndarray) -> jnp.ndarray:
        l1 = hk.Linear(self.input_dims)
        return l1(X)
    
model = hk.transform(lambda x: Model()(x))  # <-- where I would specify the model shape if at all? 

So I'm running into an issue where I'm not able to specify the model shape. If I don't specify it, as in the above, I get the following error:

__init__() missing 1 required positional argument: 'input_dims'

but if I do specify the shape via

model = hk.transform(lambda x: Model(train_X.shape[1])(x))

I get

Argument '<function without_state.<locals>.init_fn at 0x7f1e5c616430>' of type <class 'function'> is not a valid JAX type.


What is the recommended way of addressing this? I'm reading through hk.transform but I'm not sure. Looking at the code examples, there are __init__ functions without default args so I know it's possible.

tomhennigan (Collaborator) commented

Hi @IanQS, the second issue does not seem to be related to passing the input shape; it perhaps comes from later in your code. In particular, it seems to be related to use of the init function returned from Haiku (e.g. model.init in your code). My guess is that you are passing that as an argument to a JAX transformed function somewhere? If you are able to reproduce this issue in a Colab notebook, I'd be happy to debug it further.

With regard to your question about how to pass values in, there are a few options. Your train_X.shape[1] approach LGTM (we refer to this as "closing over" the value, and it is the most common pattern I see at DM). I think there are basically two common patterns in our code:

(Option 1): For read-only hyperparameters, I think many people prefer to declare them using flags or as constants in the file, and then close over their value:

INPUT_DIMS = ...

def f(x):
  m = Model(INPUT_DIMS)
  return m(x)

f = hk.transform(f)

params = f.init(rng, x)
out = f.apply(params, rng, x)

I think our most complete example for this type of thing is the resnet+imagenet example, which includes both closing over hyperparameters as flags and passing a boolean value (is_training) into the transform.
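For reference, here's a minimal sketch of that combined pattern (the flag name, sizes, and the BatchNorm module are illustrative, not taken from the actual example):

from absl import flags
import haiku as hk

flags.DEFINE_integer('hidden_size', 128, 'Hypothetical hyperparameter flag.')
FLAGS = flags.FLAGS

def forward(x, is_training):
  # The hyperparameter is closed over via the flag; is_training is passed
  # through the transform to the modules that need it.
  x = hk.Linear(FLAGS.hidden_size)(x)
  x = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.999)(
      x, is_training=is_training)
  return x

# BatchNorm tracks moving averages as state, hence transform_with_state.
forward = hk.transform_with_state(forward)
params, state = forward.init(rng, x, is_training=True)
out, state = forward.apply(params, state, rng, x, is_training=True)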

(Option 2): You can pass any arguments you like into the function you transform with Haiku:

def f(x, input_dims):
  m = Model(input_dims)
  return m(x)

f = hk.transform(f)

params = f.init(rng, x, input_dims)
out = f.apply(params, rng, x, input_dims)

IanQS commented Sep 19, 2021

Thanks Tom!

I'm a fan of Option 2 (hoping to use this in areas where the shape isn't specified until run-time)

Is there a function signature somewhere for f.init? It wasn't clear to me that this is where I would pass in input_dims. I'm guessing it's along the lines of

PRNGKey = jnp.ndarray

def init(rng: PRNGKey, input_data: jnp.ndarray, *args):
...

Is that correct? It's not clear to me why I would pass in x, as in your example:

def f(x, input_dims):
  m = Model(input_dims)
  return m(x)

f = hk.transform(f)

params = f.init(rng, x, input_dims)           # <---- HERE
out = f.apply(params, rng, x, input_dims)

My guess is that you are passing that as an argument to a JAX transformed function somewhere

Oh, is that not what we are supposed to do? It looks like that's what is happening here: VAE example

I've attached my colab link here. Thank you for your help!

tomhennigan (Collaborator) commented

Is there a function signature somewhere for f.init?

Given a function f(*a, **k) -> out, hk.transform(f) gives you back a pair of functions: f.init(rng, *a, **k) -> params and f.apply(params, rng, *a, **k) -> out.
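For example, a minimal sketch of how the extra arguments thread through (the scale argument is just illustrative):

def f(x, scale=1.0):                      # f(*a, **k) -> out
  return hk.Linear(8)(x) * scale

f = hk.transform(f)
params = f.init(rng, x, scale=2.0)        # f.init(rng, *a, **k) -> params
out = f.apply(params, rng, x, scale=2.0)  # f.apply(params, rng, *a, **k) -> out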

If you wanted to get a deeper understanding of what happens inside transform then take a look here: https://dm-haiku.readthedocs.io/en/latest/notebooks/build_your_own_haiku.html

By the way, there is another common option, which is to transform a method on a regular object:

class Model:
  def __init__(self, input_dims):
    self.input_dims = input_dims
    self.init, self.apply = hk.transform(self._forward)

  def _forward(self, x):
    m = hk.Linear(self.input_dims)
    return m(x)

m = Model(input_dims=10)
params = m.init(rng, x)
out = m.apply(params, rng, x)

I've attached my colab link here. Thank you for your help!

So the key issue here is that jax.grad requires all arguments to be JAX Arrays. You can reproduce your error with the following minimal code:

>>> jax.grad(lambda x: x)(lambda: None)
...
TypeError: Argument '<function <lambda> at 0x7f2116434290>' of type <class 'function'> is not a valid JAX type.
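The fix is to pass only arrays (or pytrees of arrays, such as params) as arguments to the function you differentiate, and to close over model.apply rather than passing it in. A minimal sketch, assuming a squared-error loss for your regression:

def loss_fn(params, x, y):
  preds = model.apply(params, None, x)  # rng=None is fine for a deterministic apply
  return jnp.mean((preds - y) ** 2)

# params is a pytree of jnp arrays, so it is a valid argument to jax.grad.
grads = jax.grad(loss_fn)(params, train_X, train_y)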

I've modified your colab to use the pattern described above (a regular Python object holding the Haiku- and JAX-transformed methods) and the loss seems to go in the right direction: https://colab.research.google.com/gist/tomhennigan/456984830510eded8f1675476bf1ff8f/haiku_ianqs_dbg.ipynb

IanQS commented Sep 19, 2021

So the key issue here is that jax.grad requires all arguments to be JAX Arrays. You can reproduce your error with the following minimal code:

Ahh! Thank you! I'll have to keep that in mind... I'm sure it'll end up biting me again at some point


Follow-up question for you: I've seen 3 ways of defining models in Haiku:

  1. function definitions

  2. objects inheriting hk.Module

  3. transform a method on a regular object (what you helped transform the colab notebook into)

Out of the 3, do you see any weaknesses between them? I know that for method 1), you can only define simple hk.Sequential models (linear pipeline). Is there a fundamental difference between 2 and 3?

tomhennigan (Collaborator) commented

I know that for method 1), you can only define simple hk.Sequential models (linear pipeline).

I don't think this is true; you can drive modules in arbitrarily complex ways inside your transformed function. There is no requirement in Haiku for there to be a single top-level module inside the transform, or for modules to be called in a particular order. For example, you might want to add a residual connection inside your sequential stack:

def f(x):
  x = jnp.tanh(hk.Linear(...)(x))
  shortcut = x
  x = jnp.tanh(hk.Linear(...)(x))
  x += shortcut
  x = hk.Linear(...)(x)
  return x

It might be worth taking a look at this visualized to get a better intuition for what is happening: https://colab.research.google.com/gist/tomhennigan/cbb2297aac0093088936da4f53c97577/example-of-non-trivial-network-without-top-level-class.ipynb

Out of the 3, do you see any weaknesses between them?

I think the differences are just cosmetic. I work primarily with ML researchers, and the advice I give them is to use the approach that will let them explore their ideas as quickly as possible. Usually this means having very few layers of abstraction and indirection. This may sound like bad advice from a software engineering point of view, since it tends to lead to unmaintainable or hard-to-test/debug code; however, most research code ends up being thrown away, and for the code that is not, it is often straightforward to make it more robust later on.

I suspect the best approach for you depends on what you will use the code for (e.g. how many configuration parameters you require, how you want to drive those config params [e.g. flags or arguments], and whether you want to write unit tests covering them).

I'm trying to learn idiomatic haiku as I'm trying to compare and contrast it with flax to allow me to better choose between one or the other for some things I'm building

I don't think you can go wrong with either library. One useful thing about Flax is that it has a larger GitHub community, so there are likely more examples or projects you can fork. It has been a while since I looked at Flax, but IIRC they also had abstractions for training inside the library, which may also be useful (in Haiku these are out of scope; we just offer NN modules and examples of how to write training loops).

Haiku's core design (e.g. which modules are in the core library, default values for initializers, etc.) follows on from Sonnet, which DM used for a number of years. Haiku itself has been used by DM research for 18+ months and has held up well for us, from small projects all the way up to AlphaFold and friends, so if you decide to use it you can be confident that it should work well (if your needs are similar to ours).

Either way you go, it should be straightforward to move code between the two libraries if you change your mind later. When porting code, I use the following checklist:

  • Are parameter initializers the same between these libraries? (They are often subtly different.)
  • Are other defaults (e.g. epsilons, betas) the same?
  • Is the implementation of the module equivalent (e.g. is this the "fast" implementation or the numerically stable one)?
  • How will I migrate a checkpoint from old to new?

IanQS commented Sep 20, 2021

I don't think this is true; you can drive modules in arbitrarily complex ways inside your transformed function. There is no requirement in Haiku for there to be a single top-level module inside the transform, or for modules to be called in a particular order. For example, you might want to add a residual connection inside your sequential stack:

Ahh, sorry, I was conflating hk.Sequential with defining the computation graph in a function. Yeah, that makes sense: defining the graph in a function is just another way of expressing it compared to a module. It probably all boils down to the same thing.

…so if you decide to use it you can be confident that it should work well (if your needs are similar to ours).

Thanks! I'm looking at it for research, and also with a view to spinning up a startup on it. I'm mostly evaluating which library makes the most sense: which has the clearest documentation, is better tested, has fewer (or more explicitly documented) sharp edges, and so forth.

IanQS closed this as completed Sep 24, 2021