
How do I alter the parameter cotangents in a custom derivative? #98

Closed
NeilGirdhar opened this issue Jan 31, 2021 · 5 comments

Comments

@NeilGirdhar
Contributor

NeilGirdhar commented Jan 31, 2021

Yesterday, I started learning Haiku in order to port my codebase over, but I'm running into some showstoppers and I'm wondering if anyone could offer some helpful pointers. My main issue right now is how to port over a custom gradient that has this form in my code:

@custom_vjp
def f(..., weights): ...

def fwd(..., weights):
    _, internal_vjp = vjp(g, weights)  # vjp returns a (primal_out, vjp_fn) pair
    return f(...), internal_vjp

def bwd(residuals, y_bar):
    internal_vjp = residuals
    weights_bar, = internal_vjp(y_bar)  # the vjp function returns a tuple of cotangents
    # In fact I split y_bar into a variety of pieces, and then vmap internal_vjp over those different pieces,
    # and finally I assemble the weight cotangents.  This allows me to funnel different cotangents from
    # y_bar to different parameters.
    return ..., weights_bar

f.defvjp(fwd, bwd)

(This ability to store one VJP in the residuals of another custom VJP was something that I added to JAX google/jax#3705.)
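For concreteness, here is a minimal runnable version of that pattern, with a toy g standing in for my real code (the names and the toy computation are illustrative only, and it relies on being able to keep the vjp function in the residuals):

import jax
import jax.numpy as jnp
from jax import custom_vjp, vjp

def g(weights, x):
    # Toy stand-in for the real internal function.
    return jnp.tanh(weights @ x)

@custom_vjp
def f(x, weights):
    return g(weights, x)

def fwd(x, weights):
    # Keep the inner vjp function itself as the residual.
    primal, internal_vjp = vjp(g, weights, x)
    return primal, internal_vjp

def bwd(residuals, y_bar):
    internal_vjp = residuals
    # In the real code, y_bar is split and internal_vjp is vmapped over the pieces.
    weights_bar, x_bar = internal_vjp(y_bar)
    return x_bar, weights_bar

f.defvjp(fwd, bwd)

x = jnp.ones(3)
weights = jnp.eye(3)
print(jax.grad(lambda w: f(x, w).sum())(weights))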

The problem with porting this over to Haiku is that the forward pass is not an explicit function of the weights, and so the backward pass doesn't have the opportunity to pass cotangents to the weights.

I'm new to Haiku, but I wonder if it would be possible to do something like this:

def f(...):
    return internal_f(..., hk.get_relevant_parameters(g))  # g is the internal function that implements f.

@custom_vjp
def internal_f(..., weights: hk.Params):
    t = hk.transform(g)
    return t.apply(weights, None)  # apply's signature is (params, rng, *args)

def fwd(..., weights: hk.Params):
    t = hk.transform(g)
    primal, internal_vjp = vjp(lambda w: t.apply(w, None), weights)
    return primal, internal_vjp

def bwd(residuals, y_bar):
    internal_vjp = residuals
    weights_bar, = internal_vjp(y_bar)
    # Fortunately, I can assemble weights_bar thanks to the filter and merge functions in hk.data_structure.
    return ..., weights_bar

internal_f.defvjp(fwd, bwd)

Basically, get_relevant_parameters would be something like:

def get_relevant_parameters(f: Callable[..., Any], *args: Any, **kwargs: Any) -> Params:
  parameters = transform(f).init(jax.random.PRNGKey(0), *args, **kwargs)  # The RNG seed is irrelevant here: we only need the parameter names that init returns, not their values.
  # Return parameters from the current frame that are needed by f.
  return {k: {l: current_frame().params[k][l] for l, _ in bundle.items()}
          for k, bundle in parameters.items()}

Alternatively, I could just pass in current_frame().params to f, but it would be annoying to have to pass None as the corresponding cotangents for all the parameters in the model in bwd.

I'm going to keep working on this, but I thought I'd file the issue early in the very likely case I'm missing something. Thanks a lot, and great project by the way!

@tomhennigan
Collaborator

Hi Neil, I wonder if it would be easier to apply this to the result of transform. For example, hk.transform(f).apply is an explicit function of your parameters (its signature is apply(params, optional_rng, *args, **kwargs)). If you need to additionally filter the parameters dictionary down to a subset of parameters, you can use our hk.data_structures.{merge,partition,..} utilities; we have an example here.
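For example, something like this (a toy module and an illustrative predicate, not your actual setup):

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    return hk.Linear(4, name="inner")(x)

f = hk.transform(forward)
x = jnp.ones([2, 3])
params = f.init(jax.random.PRNGKey(0), x)

# Keep only the parameters of interest, differentiate with respect to them,
# and merge the rest back in before calling apply.
inner, rest = hk.data_structures.partition(
    lambda module_name, name, value: module_name == "inner", params)

def loss(inner_params):
    merged = hk.data_structures.merge(inner_params, rest)
    return jnp.sum(f.apply(merged, None, x))

inner_grads = jax.grad(loss)(inner)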

In general Haiku is designed to make it easy to define neural networks and pass state around, but I think when (1) working with JAX transforms or (2) trying to integrate Haiku with another library, it is usually easier and safer to work with the pure functions that Haiku gives you (the init and apply functions from transform).

The main benefit of this approach is that you and your users are free to swap Haiku out for any other NN library in the future if you prefer (since the only coupling is the signature of the pure function and perhaps the structure of the params dictionary). Additionally, it is usually much easier to reason about pure functions, although (subjectively) it is more difficult to describe a neural network using pure functions and combinators.

@NeilGirdhar
Contributor Author

Hi Tom, thanks a lot for looking into this.

I wonder if it would be easier to apply this to the result of transform…

I see what you're getting at, but the problem is that this "module" (and its custom VJP) is buried deep within the network. I can't just call transform on it alone and then filter and merge results. I need the underlying module with its custom VJP to produce the correct cotangents. The fundamental problem is that such a VJP has to be written with explicit parameters.

trying to integrate Haiku with another library it is usually easier and safer to work with the pure functions…

I agree, and that is essentially what my proposal is aiming to do. Whereas transform allows you to call Haiku code (wherein the state is implicit) from JAX (wherein the state is explicit), get_relevant_parameters allows you to call JAX code from Haiku.

users are free to swap Haiku out for any other NN library in the future if you prefer (since the only coupling is in the signature of the pure function)…

Yes, I see: you mean from the JAX side of things. I'm not making any proposal about changing that side. My problem is on the Haiku side of my code (within transform, where parameters are implicit). I'm not sure if I'm explaining this well; please let me know what you think.

@tomhennigan
Collaborator

We do have an experimental feature called lift. This enables you to make the parameters explicit for a specific function call inside a transform. It is not very well documented but a few advanced users are making heavy use of it internally (e.g. to make it easier to scan over the application of modules).

I've knocked up an example here:

https://colab.research.google.com/gist/tomhennigan/6f1237b5fb268a3d6d2391329ba2d051/example-of-using-hk-experimental-lift.ipynb

I wonder if this will be sufficient for your use case (making relevant parameters explicit inside a haiku transform).
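Condensed, the notebook does something along these lines (the custom VJP half is my guess at how it would plug into your pattern, not something the notebook itself shows):

import haiku as hk
import jax
import jax.numpy as jnp

def inner_fn(x):
    return hk.Linear(4)(x)

inner = hk.transform(inner_fn)

@jax.custom_vjp
def explicit_apply(params, x):
    # A pure function of the (now explicit) inner parameters.
    return inner.apply(params, None, x)

def fwd(params, x):
    primal, inner_vjp = jax.vjp(lambda p, xx: inner.apply(p, None, xx), params, x)
    return primal, inner_vjp

def bwd(inner_vjp, y_bar):
    # This is the point where the parameter cotangents can be rerouted/reassembled.
    params_bar, x_bar = inner_vjp(y_bar)
    return params_bar, x_bar

explicit_apply.defvjp(fwd, bwd)

def outer_fn(x):
    # lift registers the inner parameters with the outer transform while
    # handing them back to us explicitly.
    params = hk.experimental.lift(inner.init)(hk.next_rng_key(), x)
    return explicit_apply(params, x)

outer = hk.transform(outer_fn)
x = jnp.ones([2, 3])
params = outer.init(jax.random.PRNGKey(0), x)
grads = jax.grad(lambda p: jnp.sum(outer.apply(p, jax.random.PRNGKey(1), x)))(params)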

@NeilGirdhar
Contributor Author

I'm pretty sure that's exactly what I want!!

Would you mind if I left the issue open for a few more days until I have my code running and I'm sure this works?

@tomhennigan
Collaborator

Of course, feel free to keep this open as long as is useful for you.
