How do I alter the parameter cotangents in a custom derivative? #98
Comments
Hi Neil, I wonder if it would be easier to apply this to the result of `hk.transform`?

In general, Haiku is designed to make it easy to define neural networks and pass state around, but when (1) working with JAX transforms or (2) trying to integrate Haiku with another library, it is usually easier and safer to work with the pure functions that Haiku gives you (the init and apply functions from transform). The main benefit of this approach is that you and your users are free to swap Haiku out for any other NN library in the future if you prefer (since the only coupling is in the signature of the pure function and perhaps the structure of the params dictionary). Additionally, it is usually much easier to reason about pure functions, though (subjectively) it is more difficult to describe a neural network using pure functions and combinators.
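For illustration, a minimal sketch of the init/apply workflow described above, with a made-up one-layer network:

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    return hk.Linear(4)(x)

# hk.transform turns the stateful definition into a pair of pure functions.
init, apply = hk.transform(forward)

rng = jax.random.PRNGKey(0)
x = jnp.ones([2, 3])
params = init(rng, x)        # explicit params dictionary
y = apply(params, None, x)   # pure: params in, activations out
```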
Hi Tom, thanks a lot for looking into this.
I see what you're getting at, but the problem is that this "module" (and its custom VJP) is buried deep within the network. I can't just call transform on it alone and then filter and merge results. I need the underlying module with its custom VJP to produce the correct cotangents. The fundamental problem is that such a VJP has to be written with explicit parameters.
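For background, a backward rule defined with `jax.custom_vjp` returns exactly one cotangent per explicit primal input, so anything that needs a cotangent has to appear as an input; a toy illustration (not code from this thread):

```python
import jax

@jax.custom_vjp
def scale(w, x):
    return w * x

def scale_fwd(w, x):
    return w * x, (w, x)

def scale_bwd(res, t):
    w, x = res
    # One cotangent per primal input (w, x). If w were read implicitly
    # instead of being passed in, there would be nowhere to put w's cotangent.
    return (t * x, t * w)

scale.defvjp(scale_fwd, scale_bwd)

print(jax.grad(scale)(2.0, 3.0))  # 3.0
```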
I agree, and that is essentially what my proposal is aiming to do. Whereas …
Yes, I see; you mean from the JAX side of things. I'm not making any proposal about changing that side. My problem is happening on the Haiku part of my code (within `hk.transform`).
We do have an experimental feature called `hk.experimental.lift`. I've knocked up an example; I wonder if it will be sufficient for your use case (making the relevant parameters explicit inside a Haiku transform).
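This is not the example from the thread, but a sketch of the pattern `lift` enables, assuming a toy inner network: `lift` hands back the inner parameters as an explicit dictionary, which can then be routed through a `jax.custom_vjp` function that controls their cotangents.

```python
import haiku as hk
import jax
import jax.numpy as jnp

def inner_fn(x):
    return hk.Linear(8)(x)

inner = hk.transform(inner_fn)

@jax.custom_vjp
def call_inner(params, x):
    return inner.apply(params, None, x)

def call_inner_fwd(params, x):
    y, inner_vjp = jax.vjp(lambda p, a: inner.apply(p, None, a), params, x)
    return y, inner_vjp

def call_inner_bwd(inner_vjp, y_bar):
    params_bar, x_bar = inner_vjp(y_bar)
    # params_bar is an explicit cotangent for the inner parameters and can
    # be altered here before it is returned.
    return params_bar, x_bar

call_inner.defvjp(call_inner_fwd, call_inner_bwd)

def outer_fn(x):
    # lift registers the inner transform's parameters with the outer one
    # and returns them as an explicit dictionary.
    init_rng = hk.next_rng_key() if hk.running_init() else None
    params = hk.experimental.lift(inner.init)(init_rng, x)
    return call_inner(params, x)

outer = hk.transform(outer_fn)
```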
I'm pretty sure that's exactly what I want! Would you mind if I left the issue open for a few more days, until I have my code running and I'm sure this works?
Of course, feel free to keep this open as long as it's useful for you.
Yesterday, I started learning Haiku in order to port my codebase over, but I'm running into some showstoppers and I'm wondering if anyone could offer some helpful pointers. My main issue right now is how to port over a custom gradient that has this form in my code:
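In spirit it is the following pattern, shown here with made-up names rather than the original code: an inner VJP from `jax.vjp` is stashed in the residuals of an outer `jax.custom_vjp`.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(w, x):
    return jnp.tanh(x @ w)

def f_fwd(w, x):
    # Run an ordinary VJP and store the returned VJP function itself
    # in the residuals for use on the backward pass.
    y, inner_vjp = jax.vjp(lambda w, x: jnp.tanh(x @ w), w, x)
    return y, inner_vjp

def f_bwd(inner_vjp, y_bar):
    w_bar, x_bar = inner_vjp(y_bar)  # cotangents for the weights and input
    return w_bar, x_bar

f.defvjp(f_fwd, f_bwd)
```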
(This ability to store one VJP in the residuals of another custom VJP was something that I added to JAX in google/jax#3705.)
The problem with porting this over to Haiku is that the forward pass is not an explicit function of the weights, and so the backward pass doesn't have the opportunity to pass cotangents to the weights.
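Concretely, a typical Haiku module reads its weights from the surrounding transform instead of taking them as arguments; a toy module, for illustration only:

```python
import haiku as hk

class MyLayer(hk.Module):
    def __call__(self, x):
        # The weight is created/fetched by name from Haiku's internal frame,
        # not passed in as an argument, so a custom VJP wrapped around this
        # call has no input slot through which to return its cotangent.
        w = hk.get_parameter("w", shape=[x.shape[-1], 8],
                             init=hk.initializers.TruncatedNormal())
        return x @ w
```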
I'm new to Haiku, but I wonder if it would be possible to do something like this:
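A hypothetical sketch of the idea, reconstructed from the description that follows; `get_relevant_parameters` is not a real Haiku API, and everything else here is made up for illustration:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(params, x):
    # Forward pass written against an explicit dict of just this module's
    # parameters, so the backward pass can return a cotangent for them.
    return jnp.tanh(x @ params["w"])

def fwd(params, x):
    return f(params, x), (params, x)

def bwd(res, y_bar):
    params, x = res
    g = y_bar * (1.0 - jnp.tanh(x @ params["w"]) ** 2)
    return {"w": x.T @ g}, g @ params["w"].T  # cotangents for params and x

f.defvjp(fwd, bwd)

def get_relevant_parameters():
    # Hypothetical helper: inside a transform, return only the parameters
    # belonging to the current module, as an explicit dictionary.
    raise NotImplementedError

def forward(x):  # called inside hk.transform
    return f(get_relevant_parameters(), x)
```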
Basically, `get_relevant_parameters` would be something like the helper in the sketch above. Alternatively, I could just pass in `current_frame().params` to `f`, but it would be annoying to have to pass `None` as the corresponding cotangents for all the parameters in the model in `bwd`.

I'm going to keep working on this, but I thought I'd file the issue early in the very likely case I'm missing something. Thanks a lot, and great project by the way!