This might be a question for JAX, but I think it probably comes up in Haiku.
Suppose I have the following code within some `hk.Module`:
```python
out = input
# the code in each layer is identical, only the parameters differ
for layer in self.layers:
    out = layer(out)
return out
```
Assume also that each `layer` is an instance of the same `hk.Module` subclass, which calls `hk.get_parameter` inside its `__call__` method.
Since the code is identical in each layer, one could express it as a `jax.lax.fori_loop`, but that is quite awkward.
Would there be any efficiency gain in doing so? Or is the JAX compiler smart enough to effectively do this anyway?
```python
# parameters previously defined by hk.get_parameter in the above, merged across layers
all_layer_params = ...

def layer_fn(i, input):
    # select layer i's slice of each merged parameter array
    layer_params = jax.tree_util.tree_map(
        lambda p: jax.lax.dynamic_index_in_dim(p, i, keepdims=False), all_layer_params)
    # the code in any layer of the above self.layers
    ...

return jax.lax.fori_loop(0, num_layers, layer_fn, input)
```
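A runnable version of this `fori_loop` idea, written in plain JAX rather than Haiku. It assumes the per-layer parameters have already been stacked along a leading axis of size `num_layers` (the hypothetical `all_layer_params` dict below) and uses a hypothetical `layer_apply` as the shared per-layer computation. The unrolled Python loop is included only to check that both versions compute the same result.

```python
import jax
import jax.numpy as jnp


def layer_apply(params, x):
    # Hypothetical shared per-layer computation.
    return jnp.tanh(x @ params["w"] + params["b"])


def stack_fori(all_layer_params, x):
    num_layers = all_layer_params["w"].shape[0]  # static under jit

    def body(i, h):
        # Pick layer i's parameters out of every stacked array.
        layer_params = jax.tree_util.tree_map(
            lambda p: jax.lax.dynamic_index_in_dim(p, i, keepdims=False),
            all_layer_params,
        )
        return layer_apply(layer_params, h)

    return jax.lax.fori_loop(0, num_layers, body, x)


def stack_unrolled(all_layer_params, x):
    # Reference: plain Python loop over the same stacked parameters.
    for i in range(all_layer_params["w"].shape[0]):
        x = layer_apply(jax.tree_util.tree_map(lambda p: p[i], all_layer_params), x)
    return x


num_layers, width = 4, 8
key_w, key_b = jax.random.split(jax.random.PRNGKey(0))
all_layer_params = {
    "w": jax.random.normal(key_w, (num_layers, width, width)) / jnp.sqrt(width),
    "b": 0.1 * jax.random.normal(key_b, (num_layers, width)),
}
x = jnp.ones((2, width))
out = jax.jit(stack_fori)(all_layer_params, x)
```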
Is there an idiomatic way to do this in Haiku that takes advantage of the internal `hk.get_parameter` calls?
Thanks in advance!
Hey @hrbigelow, both versions should work and in theory should be equally efficient. However, we've seen a few cases (in particular with transformer models) where using structured control flow lets the XLA compiler do a better job of optimizing (in particular, reducing peak memory usage) and (sometimes) of overlapping communication with compute.
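One way to see the structural difference being described is to look at the program JAX hands to XLA. In the sketch below (plain JAX, with a hypothetical `layer_apply` and hypothetical stacked parameters), the Python loop is unrolled into one copy of the layer per iteration, while `jax.lax.scan` keeps a single loop body. This only inspects the traced program; it says nothing about the memory or speed of any particular model.

```python
import jax
import jax.numpy as jnp

NUM_LAYERS, WIDTH = 4, 8


def layer_apply(params, x):
    # Hypothetical shared per-layer computation.
    return jnp.tanh(x @ params["w"] + params["b"])


def unrolled(stacked, x):
    for i in range(NUM_LAYERS):
        x = layer_apply(jax.tree_util.tree_map(lambda p: p[i], stacked), x)
    return x


def scanned(stacked, x):
    # scan slices the leading axis of `stacked` and feeds one layer per step.
    x, _ = jax.lax.scan(lambda h, p: (layer_apply(p, h), None), x, stacked)
    return x


stacked = {"w": 0.1 * jnp.ones((NUM_LAYERS, WIDTH, WIDTH)),
           "b": jnp.zeros((NUM_LAYERS, WIDTH))}
x = jnp.ones((2, WIDTH))


def top_level_primitives(fn):
    # Names of the outermost operations in the traced program.
    return [eqn.primitive.name for eqn in jax.make_jaxpr(fn)(stacked, x).jaxpr.eqns]


unrolled_ops = top_level_primitives(unrolled)  # grows with NUM_LAYERS
scanned_ops = top_level_primitives(scanned)    # one `scan` whose body is a single layer
```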
The implementation of `hk.experimental.layer_stack` is fairly complex (it handles quite a few edge cases), but it basically boils down to correctly using `jax.lax.scan` for the per-layer init and apply functions.
Thanks, Tom. I'm actually looking at the examples for `hk.experimental.layer_stack`. Just to make sure I understand: it doesn't seem possible to use `layer_stack` the same way you would use an ordinary `hk.Module` that calls `hk.get_parameter`. Is that right?
Instead, if you wanted to build an `hk.Module` method that used `layer_stack`, you'd need to somehow obtain the pure function `f` to pass to `stack(f)(...)`.