Correct way to transform and init a hk.Module with non-default parameter? #218
Comments
Hi @IanQS, the second issue does not seem to be related to passing the input shape; it looks like it comes from later in your code (more on that below).

With regard to your question about how to pass values in, there are a few options.

(Option 1) For read-only hyperparameters, I think many people prefer to declare them using flags or as constants in the file, and then close over their value:

```python
INPUT_DIMS = ...

def f(x):
  m = Model(INPUT_DIMS)
  return m(x)

f = hk.transform(f)
params = f.init(rng, x)
out = f.apply(params, rng, x)
```

I think our most complete example for this type of thing is the resnet+imagenet example, which includes both closing over hyperparameters as flags, as well as passing a boolean value into the transformed function.

(Option 2) You can pass any arguments you like into the function you transform with Haiku:

```python
def f(x, input_dims):
  m = Model(input_dims)
  return m(x)

f = hk.transform(f)
params = f.init(rng, x, input_dims)
out = f.apply(params, rng, x, input_dims)
```
Thanks Tom! I'm a fan of Option 2 (hoping to use this in areas where the shape isn't specified until run-time).

Is there a function signature somewhere for …? Is that correct? It's not clear to me why I would pass in ….

Oh, is that not what we are supposed to do? It looks like that's what is happening here: VAE example. I've attached my colab link here. Thank you for your help!
Given a function `f`, `hk.transform(f)` returns a pair of pure functions: `init(rng, *args)`, which creates the initial params, and `apply(params, rng, *args)`, which runs `f` using those params. If you want to get a deeper understanding of what happens inside transform, take a look here: https://dm-haiku.readthedocs.io/en/latest/notebooks/build_your_own_haiku.html

By the way, there is another common option, which is to transform a method on a regular object:

```python
class Model:
  def __init__(self, input_dims):
    self.input_dims = input_dims
    self.init, self.apply = hk.transform(self._forward)

  def _forward(self, x):
    m = hk.Linear(self.input_dims)
    return m(x)

m = Model(input_dims=10)
params = m.init(rng, x)
out = m.apply(params, rng, x)
```
So the key issue here is that a function was being passed into a JAX transformation where an array was expected; functions are not valid JAX types:

```python
>>> jax.grad(lambda x: x)(lambda: None)
...
TypeError: Argument '<function <lambda> at 0x7f2116434290>' of type <class 'function'> is not a valid JAX type.
```

I've modified your colab to use the pattern described above (have a regular Python object holding the Haiku- and JAX-transformed methods) and the loss seems to go in the right direction: https://colab.research.google.com/gist/tomhennigan/456984830510eded8f1675476bf1ff8f/haiku_ianqs_dbg.ipynb
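To make that concrete, here is a small standalone sketch (not from the thread's colab): `jax.grad` happily differentiates through array arguments, but raises the same `TypeError` when handed a function:

```python
import jax
import jax.numpy as jnp

g = jax.grad(lambda x: jnp.sum(x ** 2))

# Arrays are valid JAX types, so this works: d/dx of x^2 at x = 3.0.
print(g(jnp.array(3.0)))  # 6.0

# Functions are not valid JAX types, so this raises a TypeError.
try:
    g(lambda: None)
except TypeError as e:
    print("TypeError:", e)
```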
Ahh! Thank you! I'll have to keep that in mind... I'm sure it'll end up biting me again at some point.

Follow-up question for you: I've seen 3 ways of defining models in Haiku:

Out of the 3, do you see any weaknesses among them? I know that for method 1) you can only define simple sequential models.
I don't think this is true; you can have arbitrarily complex wiring of modules inside your transformed function. There is no requirement in Haiku for there to be a single top-level module inside the transform, or for modules to be called in a particular order. For example, you might want to add a residual inside your sequential stack:
It might be worth taking a look at this visualized to get a better intuition for what is happening: https://colab.research.google.com/gist/tomhennigan/cbb2297aac0093088936da4f53c97577/example-of-non-trivial-network-without-top-level-class.ipynb
I think the differences are just cosmetic. I work primarily with ML researchers, and the advice I give them is to use the approach that will let them explore their ideas as quickly as possible. Usually this means having very few layers of abstraction and indirection. This may sound like bad advice from a software-engineering point of view, since it tends to lead to unmaintainable or hard-to-test code; however, most research code ends up being thrown away, and for the code that is not, it is usually straightforward to make it more robust later on. I suspect the best approach for you depends on what you will use the code for (e.g. how many configuration parameters you require, how you want to drive those config params [e.g. flags or arguments], and whether you want to write unit tests covering them).
I don't think you can go wrong with either library. One useful thing about Flax is that it has a larger GitHub community, so there are likely more examples or projects you can fork. It has been a while since I looked at Flax, but IIRC it also has abstractions for training inside the library, which may be useful (in Haiku these are out of scope; we just offer NN modules and examples of how to write training loops). Haiku's core design (e.g. which modules are in the core library, default values for initializers, etc.) follows on from Sonnet, which DM used for a number of years. Haiku itself has been used by DM research for 18+ months and has held up well for us, from small projects all the way up to AlphaFold and friends, so if you decide to use it you can be confident that it should work well (if your needs are similar to ours). Either way, it should be straightforward to move code between the two libraries if you change your mind later. When porting code I use the following checklist:
Ahh, sorry, I was conflating
Thanks! I'm looking at it for research, and also in terms of spinning up a startup on it. I'm mostly evaluating which library makes the most sense: which has the clearest documentation, is better tested, has fewer (or more explicitly documented) sharp edges, and so forth.
Hey all!
I'm trying to run a linear regression example and I've got the following
So I'm running into an issue where I'm not able to specify the model shape. If I do not specify it, as in the above, I get the error:

```
__init__() missing 1 required positional argument: 'input_dims'
```

but if I do specify the shape, I get:

```
Argument '<function without_state.<locals>.init_fn at 0x7f1e5c616430>' of type <class 'function'> is not a valid JAX type.
```
What is the recommended way of addressing this? I'm reading through `hk.transform` but I'm not sure. Looking at the code examples, there are `__init__` functions without default args, so I know it's possible.