
The implementation of layerwise learning rate decay #51

Closed

importpandas opened this issue Apr 30, 2020 · 2 comments

Comments

@importpandas

for layer in range(n_layers):
  key_to_depths["encoder/layer_" + str(layer) + "/"] = layer + 1
return {
    key: learning_rate * (layer_decay ** (n_layers + 2 - depth))
    for key, depth in key_to_depths.items()
}

According to the code here, assume n_layers=24. Then key_to_depths["encoder/layer_23/"] = 24, which is the depth of the last encoder layer, but the learning rate for that layer is:
learning_rate * (layer_decay ** (24 + 2 - 24)) = learning_rate * (layer_decay ** 2).

That's what confuses me. Why is the learning rate for the last layer learning_rate * (layer_decay ** 2) rather than learning_rate? Am I missing something?
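
A quick way to check this numerically (a minimal sketch: the initialization of key_to_depths outside the loop is assumed, and only the loop and return come from the snippet above):

```python
def layerwise_lrs(learning_rate=1.0, layer_decay=0.8, n_layers=24):
    # Assumed: embeddings sit at depth 0; only the encoder-layer entries
    # below are taken from the quoted snippet.
    key_to_depths = {"embeddings/": 0}
    for layer in range(n_layers):
        key_to_depths["encoder/layer_" + str(layer) + "/"] = layer + 1
    return {
        key: learning_rate * (layer_decay ** (n_layers + 2 - depth))
        for key, depth in key_to_depths.items()
    }

lrs = layerwise_lrs()
print(lrs["encoder/layer_23/"])  # 0.8 ** 2 = 0.64, not 0.8 or 1.0
```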

@clarkkev
Collaborator

clarkkev commented May 8, 2020

For the layerwise learning rate decay we count the task-specific layer added on top of the pre-trained transformer as an additional layer of the model, so the learning rate for the last transformer layer of ELECTRA should be learning_rate * 0.8. But you've still found a bug: instead it is learning_rate * 0.8^2.

The bug happened because there used to be a pooler layer in ELECTRA before we removed the next-sentence-prediction task. In that case the learning rates per layer were:

  • task-specific softmax: learning_rate
  • pooler: learning_rate * 0.8
  • transformer layer 24: learning_rate * 0.8^2
  • transformer layer 23: learning_rate * 0.8^3
  • ...

However, when we removed the pooling layer, we didn't fix the learning rates correspondingly. I guess in practice this didn't hurt performance much, so I'm leaving it as-is to keep results reproducible for now.
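
A sketch of what the corrected exponent could look like once the pooler is gone, assuming the task-specific layer is assigned depth n_layers + 1 (hypothetical key names; not the repository's actual code):

```python
def layerwise_lrs_fixed(learning_rate=1.0, layer_decay=0.8, n_layers=24):
    # Assumed depths: embeddings at 0, task-specific head at n_layers + 1.
    key_to_depths = {"embeddings/": 0, "task_specific/": n_layers + 1}
    for layer in range(n_layers):
        key_to_depths["encoder/layer_" + str(layer) + "/"] = layer + 1
    # With the pooler removed, the exponent drops by one level so the top
    # transformer layer gets learning_rate * layer_decay.
    return {
        key: learning_rate * (layer_decay ** (n_layers + 1 - depth))
        for key, depth in key_to_depths.items()
    }

lrs = layerwise_lrs_fixed()
print(lrs["task_specific/"])     # 1.0
print(lrs["encoder/layer_23/"])  # 0.8
print(lrs["encoder/layer_22/"])  # 0.8 ** 2
```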

@importpandas
Author

I got it, thanks for your explanation.
