This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Rename copy_initial_weights to something more intuitive, and replace copy with detach where appropriate. #54

Open
egrefen opened this issue May 14, 2020 · 2 comments
Labels
enhancement New feature or request

Comments

@egrefen
Contributor

egrefen commented May 14, 2020

As pointed out in #30, the kwarg copy_initial_weights is hard to understand.

First, we should investigate whether it would be sufficient simply to detach when branching from the outer-loop model during unrolling. Second, we should come up with a kwarg name and docs that are more intuitive. Third, we should illustrate its use in a tutorial example.

@egrefen egrefen added the enhancement New feature or request label May 14, 2020
@egrefen egrefen self-assigned this May 14, 2020
@egrefen egrefen added this to In progress in higher development May 14, 2020
@renesax14
Contributor

renesax14 commented Jun 9, 2020

@egrefen I want to help improve the naming and documentation for copy_initial_weights. I've spent several hours writing down, in precise notation, exactly what my confusions are and what I think your intended semantics are. To get there I read #30 and several meta-learning papers in depth, including, of course, MAML.

I will start with something very simple like MAML, where the meta-parameters of the meta-learner are only the initialization of the base model (in the original paper, the weights of a 4-layer CNN with 32 filters).

Let T be the number of inner-loop steps, eta_i the inner-loop step size (learning rate), and eta_o the outer-loop step size. Both step sizes are constants.

To clarify the semantics I will start by describing what I believe copy_initial_weights = False should do.

Notation:
- w^<t_o,t_i> := the weights at outer time step t_o and inner step t_i.
- To distinguish the train and test data sets from the meta-train set and meta-test set, I will use S_t for the support set and Q_t for the query set of task t (usually a class label) in each episode. I will sample only one task in my description and suppress t, writing S and Q. Note, though, that an episode usually samples a batch of tasks.

If we want MAML, the loop invariants that should hold in higher are as follows:

For the outer loop step/update (assuming it's SGD and not Adam) we want:

$$
w^{\langle t_o+1,0\rangle} := w^{\langle t_o,0\rangle} - \eta_o \, \mathrm{Grad}_{w^{\langle t_o,0\rangle}}\left( L\left( w^{\langle t_o,T\rangle}, Q \right) \right)
$$

i.e. after each outer time step we want the gradient to be taken with respect to the current initialization w^<t_o,0>. Note that in plain PyTorch this update is done with an in-place operation.
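To make this invariant concrete, here is a minimal pure-Python sketch. It is not higher's API, and it rests on a toy assumption: a scalar weight with the quadratic loss L(w, c) = 0.5 (w - c)^2, where c stands for the target of S or Q. The helper names (`unroll`, `meta_gradient`, `outer_step`) are hypothetical. The chain rule through the T inner steps gives the meta-gradient with respect to the current initialization w^<t_o,0>, and a finite difference checks it.

```python
# Toy assumption: scalar weight w, loss L(w, c) = 0.5 * (w - c)**2,
# where c is the target of the support set S or the query set Q.

def loss(w, c):
    return 0.5 * (w - c) ** 2

def grad(w, c):
    # dL/dw for the quadratic toy loss
    return w - c

def unroll(w0, s, eta_i, T):
    """Run T inner SGD steps on the support target s, starting from w0."""
    w = w0
    for _ in range(T):
        w = w - eta_i * grad(w, s)
    return w

def meta_gradient(w0, s, q, eta_i, T):
    """Grad_{w^<t_o,0>} L(w^<t_o,T>, Q) via the chain rule.

    Each inner step w -> w - eta_i * (w - s) has derivative (1 - eta_i),
    so d w^<t_o,T> / d w^<t_o,0> = (1 - eta_i) ** T.
    """
    wT = unroll(w0, s, eta_i, T)
    return grad(wT, q) * (1 - eta_i) ** T

def outer_step(w0, s, q, eta_i, eta_o, T):
    # w^<t_o+1,0> := w^<t_o,0> - eta_o * Grad_{w^<t_o,0>} L(w^<t_o,T>, Q)
    return w0 - eta_o * meta_gradient(w0, s, q, eta_i, T)

# Sanity check: chain-rule gradient vs. a central finite difference.
w0, s, q, eta_i, T = 0.5, 1.0, 3.0, 0.1, 5
eps = 1e-6
fd = (loss(unroll(w0 + eps, s, eta_i, T), q)
      - loss(unroll(w0 - eps, s, eta_i, T), q)) / (2 * eps)
print(meta_gradient(w0, s, q, eta_i, T), fd)
```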

For the inner loop step (assuming SGD) we want:

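$$
w^{\langle t_o,t_i+1\rangle} := w^{\langle t_o,t_i\rangle} - \eta_i \, \mathrm{Grad}_{w^{\langle t_o,t_i\rangle}}\left( L\left( w^{\langle t_o,t_i\rangle}, S \right) \right)
$$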

Note that in the code the initialization parameters of the base model are likely stored in a variable w. Higher should make Grad_{param} differentiable and track gradients correctly, e.g. without over-tracking, which would accumulate the inner gradients into the outer gradients.
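"Grad_{param} is differentiable" can be pictured as follows: every value in the inner loop carries its derivative with respect to w^<t_o,0>, so the inner gradient step is itself differentiated. The sketch below makes two assumptions. It uses forward-mode dual numbers, whereas higher builds on PyTorch's reverse-mode autograd. It also reuses the toy quadratic loss. `Dual` and `inner_loop` are hypothetical names.

```python
from dataclasses import dataclass

@dataclass
class Dual:
    """A value paired with its derivative w.r.t. the initial weight w^<t_o,0>."""
    val: float
    dot: float  # d(val) / d w^<t_o,0>

    def __sub__(self, other):
        if isinstance(other, Dual):
            return Dual(self.val - other.val, self.dot - other.dot)
        return Dual(self.val - other, self.dot)  # subtracting a constant

    def __rmul__(self, k):
        # constant * dual number
        return Dual(k * self.val, k * self.dot)

def grad_support(w, s):
    # Gradient of 0.5 * (w - s)**2, computed on dual numbers so the
    # gradient itself stays differentiable w.r.t. w^<t_o,0>.
    return w - s

def inner_loop(w0_value, s, eta_i, T):
    """Return [w^<t_o,0>, ..., w^<t_o,T>], each carrying d/dw^<t_o,0>."""
    w = Dual(w0_value, 1.0)  # seed: d w^<t_o,0> / d w^<t_o,0> = 1
    trajectory = [w]
    for _ in range(T):
        w = w - eta_i * grad_support(w, s)
        trajectory.append(w)
    return trajectory

traj = inner_loop(0.5, s=1.0, eta_i=0.1, T=5)
print(traj[-1])  # final weight and its derivative (1 - eta_i) ** T
```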

At the end of the outer step we want:

$$
\texttt{w.grad} = \mathrm{Grad}_{w^{\langle t_o,0\rangle}}\left( L\left( w^{\langle t_o,T\rangle}, Q \right) \right)
$$

and nothing else: the inner gradients Grad_{w^<t_o,t_i>}( L( w^<t_o,t_i>, S ) ) should not accidentally be accumulated there.
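Here is a hypothetical sketch of that bookkeeping, again for the toy scalar quadratic loss. The inner-loop gradients are returned functionally and never written to `.grad`, and only the outer backward pass accumulates into it. `Param` stands in for a leaf tensor and is not a real PyTorch class.

```python
class Param:
    """Hypothetical stand-in for a leaf tensor: a value plus an accumulating .grad."""
    def __init__(self, value):
        self.value = value
        self.grad = 0.0

def maml_outer_backward(w, s, q, eta_i, T):
    # Inner loop: gradients are returned (functional style) and never
    # accumulated into w.grad.
    wt, dwt_dw0 = w.value, 1.0
    for _ in range(T):
        g = wt - s                   # Grad_{w^<t_o,t_i>} L(w^<t_o,t_i>, S)
        wt = wt - eta_i * g          # inner SGD step
        dwt_dw0 *= (1 - eta_i)       # chain rule through one inner step
    # Outer backward: the only place w.grad is written.
    w.grad += (wt - q) * dwt_dw0     # Grad_{w^<t_o,0>} L(w^<t_o,T>, Q)

w = Param(0.5)
maml_outer_backward(w, s=1.0, q=3.0, eta_i=0.1, T=5)
print(w.grad)
```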

At the end of the inner step we want:

$$
\texttt{w.grad} = \mathrm{Grad}_{w^{\langle t_o,t_i\rangle}}\left( L\left( w^{\langle t_o,t_i\rangle}, Q \right) \right)
$$

assuming w always keeps track of the base model parameters.

Now let's get back to the docs to figure out the intended semantics. The docs say:

> If this (copy_initial_weights) is set to False, the actual module weights will be the initial weights of the patched module. This is useful when doing MAML, for example.

I think it's important to precisely define "initial weights". The first time I read that I thought it meant w^<0,0>. If that were the meaning then the update steps would be:

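$$
\begin{aligned}
w^{\langle t_o+1,0\rangle} &:= w^{\langle t_o,0\rangle} - \eta_o \, \mathrm{Grad}_{w^{\langle 0,0\rangle}}\left( L\left( w^{\langle t_o,T\rangle}, Q \right) \right) \\
w^{\langle t_o,t_i+1\rangle} &:= w^{\langle t_o,t_i\rangle} - \eta_i \, \mathrm{Grad}_{w^{\langle 0,0\rangle}}\left( L\left( w^{\langle t_o,t_i\rangle}, S \right) \right)
\end{aligned}
$$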

which cannot possibly be correct. So I will assume that by "initial weights" we mean the weights before an update of the outer loop. No other definition makes sense to me. So the initial weight is w or, to be precise, w^<t_o,0>. This means the initial weight tracks the changing value of the initialization.
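Under this reading, a toy MAML loop looks like the sketch below. It assumes a scalar weight, the quadratic loss, and plain SGD in both loops, and all names are hypothetical. Each outer step re-runs the inner loop from the current w^<t_o,0> and differentiates only with respect to it. w^<0,0> is just the first entry of `history`.

```python
def inner_adapt(w0, s, eta_i, T):
    """Return (w^<t_o,T>, d w^<t_o,T> / d w^<t_o,0>) for the toy quadratic loss."""
    w, dw = w0, 1.0
    for _ in range(T):
        w = w - eta_i * (w - s)   # inner SGD step on the support target
        dw *= (1 - eta_i)         # derivative of that step w.r.t. its input
    return w, dw

def maml(w_init, s, q, eta_i, eta_o, T, outer_steps):
    w = w_init                    # w^<0,0>
    history = [w]
    for _ in range(outer_steps):
        wT, dwT = inner_adapt(w, s, eta_i, T)  # inner loop starts at w^<t_o,0>
        meta_grad = (wT - q) * dwT             # gradient w.r.t. w^<t_o,0>, not w^<0,0>
        w = w - eta_o * meta_grad              # w^<t_o+1,0>
        history.append(w)
    return history

hist = maml(0.5, s=1.0, q=3.0, eta_i=0.1, eta_o=0.5, T=5, outer_steps=50)
print(hist[0], hist[-1])  # the "initial weight" changes from one outer step to the next
```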

@renesax14
Contributor

renesax14 commented Jun 10, 2020

Now with the notation being clear in my head I think I can provide actionable feedback (I hope).

Before I give more feedback, I think it’s important to clarify what exactly “initial” means. I believe it means the weights after an update from the outer optimizer has been done (note that the outer optimization process is not unrolled or differentiated through). Thus, in the notation I introduced, the loop invariant for the word “initial” should be w^<t_o,0>. Note that “initial” does not necessarily mean “initialization of the base model”, which is another source of confusion, since that initialization can also be called its “initial weights”. The distinction matters because (in principle) there are meta-learners that do not train the initialization (i.e. the base model’s “initial weights”; note the potential source of confusion). So I believe the intended meaning of “initial” is

"the value of the meta-parameters initially before inner loop adaptation.”

To give concrete feedback on the current wording:

> if True, the weights of the patched module are copied to form the initial weights of the patched module,

According to my previous clarification, I believe this means that the patched module (which really means the differentiable module, i.e. the module as a functional object) gets the weights w^<t_o,0> before inner-loop adaptation. I also believe the word “copy” probably means “deep copy”, i.e. a separate set of weights. I’m not sure why this is useful, but at least the terms are clear. If this is done, my suspicion is that the outer optimizer will not see the original weights as part of the forward computation (after unrolling the gradient path of the inner loop), and thus the outer gradients would be zero, as I’ve outlined in #58 and seen in my own code.
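Here is a toy illustration of that suspicion. It assumes “copy” means a detached copy: same value, but no gradient path back to w. The derivative of such a copy with respect to w is 0, so the chain rule through the inner loop also gives 0. This is a pure-Python sketch on the scalar quadratic loss, not a reproduction of higher’s internals, and the function names are hypothetical.

```python
def adapted_weight_and_derivative(start_value, start_derivative, s, eta_i, T):
    """Unroll T inner SGD steps, carrying d(w_t)/d(w) for the *original* weight w."""
    wt, dwt = start_value, start_derivative
    for _ in range(T):
        wt = wt - eta_i * (wt - s)
        dwt = dwt * (1 - eta_i)
    return wt, dwt

def outer_gradient(w, s, q, eta_i, T, copy_weights):
    # copy_weights=True models a detached copy: same value as w, but its
    # derivative with respect to w is 0 because it is a separate leaf.
    start_derivative = 0.0 if copy_weights else 1.0
    wT, dwT = adapted_weight_and_derivative(w, start_derivative, s, eta_i, T)
    return (wT - q) * dwT  # dL(w^<t_o,T>, Q) / dw

print(outer_gradient(0.5, 1.0, 3.0, 0.1, 5, copy_weights=True))   # zero: no path back to w
print(outer_gradient(0.5, 1.0, 3.0, 0.1, 5, copy_weights=False))  # non-zero meta-gradient
```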

In the True case, I’d be curious to know what would happen to the weights of a parametrized optimizer (as in the meta-LSTM optimizer paper by Ravi and Larochelle) if those are part of the meta-parameters.

Next:

> and thus are not part of the gradient tape when unrolling the patched module.

I hope this doesn’t seem pedantic, but since at some point I was confused into thinking one could accidentally unroll both the outer and inner loops, I believe it’s a useful reminder that usually only the inner loop is unrolled. Thus, “unrolling” means that we make the inner optimization part of the forward pass (and thus differentiate through it) when using the outer optimizer. The outer optimizer should not allow further chaining, because that’s how standard PyTorch optimizers work.

I’ve never heard the term “gradient tape” (and I’ve read a handful of learning-to-learn papers), so it remains ambiguous to me, but I believe it means the “unrolled path through the gradient operations of the inner-loop optimization”.
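Under that interpretation, a “gradient tape” can be sketched as a record of the inner-loop steps that is replayed in reverse to apply the chain rule. In the hypothetical sketch below, each outer step gets a fresh tape and the outer update is never recorded. That is what “only the inner loop is unrolled” amounts to here (toy scalar quadratic loss again; `Tape` and `outer_step` are illustrative names).

```python
class Tape:
    """Minimal reverse-mode 'gradient tape' for the scalar inner loop (illustrative only)."""
    def __init__(self):
        self.local_derivs = []  # d w_{t+1} / d w_t for each recorded inner step

    def record_inner_step(self, w, s, eta_i):
        # w_{t+1} = w_t - eta_i * (w_t - s)  =>  d w_{t+1} / d w_t = 1 - eta_i
        self.local_derivs.append(1 - eta_i)
        return w - eta_i * (w - s)

    def backward(self, upstream):
        # Replay the tape in reverse, multiplying in each local derivative.
        g = upstream
        for d in reversed(self.local_derivs):
            g *= d
        return g

def outer_step(w0, s, q, eta_i, eta_o, T):
    tape = Tape()                     # fresh tape: only this inner loop is unrolled
    w = w0
    for _ in range(T):
        w = tape.record_inner_step(w, s, eta_i)
    meta_grad = tape.backward(w - q)  # dL(w^<t_o,T>, Q) / d w^<t_o,0>
    # The outer update itself is not recorded, so it is never differentiated through.
    return w0 - eta_o * meta_grad, len(tape.local_derivs)

new_w, n_recorded = outer_step(0.5, 1.0, 3.0, eta_i=0.1, eta_o=0.1, T=5)
print(new_w, n_recorded)
```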

> If this is set to False, the actual module weights will be the initial weights of the patched module. This is useful when doing MAML, for example.

I guess this means that the weights of the differentiable module (i.e. the patched module) will be w^<t_o,0> and not w^<t_o,0>.deep_copy(). Perhaps a better wording would be:

> The patched differentiable module will have the original weights w^<t_o,0> (i.e. the weights right after the in-place outer-loop update) before inner-loop adaptation, rather than a deep copy. This is useful if one wants to train the initialization of the base model.

Now, after going through the wording in more detail, I fail to see the use case for copy_initial_weights = True. I always want to train the weights of a potentially parametrized inner optimizer, and it’s very common to want to train the initialization, so for now I can’t see why one wouldn’t want to do this.

As a renaming suggestion, I propose “copy_weights_before_inner_loop”. I suggest dropping the word “initial”, since it is easily confused with initialization (e.g. it could also mean the initial weights of the parametrized optimizer before inner-loop adaptation). Alternatively, at least clarify what “initial” means, since I believe it means w^<t_o,0>, i.e. the weights before the first inner-loop step and after the in-place outer-loop update.
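If the kwarg is renamed, the old name could be kept as a deprecated alias for a while. Below is a hypothetical sketch. The function name `innerloop_ctx_sketch` and its default value are illustrative assumptions, not higher’s actual signature. It only resolves which flag to use and returns it, where a real context manager would go on to build the patched module.

```python
import warnings

_UNSET = object()  # sentinel: distinguishes "not passed" from an explicit False

def innerloop_ctx_sketch(model, opt, copy_weights_before_inner_loop=True,
                         copy_initial_weights=_UNSET):
    """Hypothetical signature illustrating a backward-compatible kwarg rename.

    The default for copy_weights_before_inner_loop is chosen only for illustration.
    """
    if copy_initial_weights is not _UNSET:
        warnings.warn(
            "copy_initial_weights is deprecated; use copy_weights_before_inner_loop",
            DeprecationWarning,
            stacklevel=2,
        )
        copy_weights_before_inner_loop = copy_initial_weights
    # A real implementation would create the patched module here.
    return copy_weights_before_inner_loop

print(innerloop_ctx_sketch(None, None, copy_weights_before_inner_loop=False))
```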

I realize this is hard, because learning to learn and meta-learning involve an optimizer of an optimizer of an optimizer, gradients of gradients, ad nauseam. That’s a joke, but it’s why this is difficult to explain cleanly. I hope this helps.


No branches or pull requests

2 participants