
Commit

layer-wise decay
zheyuye committed Jul 2, 2020
1 parent 07186d5 commit 4bc3a96
Showing 2 changed files with 34 additions and 40 deletions.
41 changes: 1 addition & 40 deletions scripts/question_answering/run_squad.py
@@ -380,45 +380,6 @@ def untune_params(model, untunable_depth, not_included=[]):
                    continue
            value.grad_req = 'null'


def apply_layerwise_decay(model, layerwise_decay, not_included=[]):
    """Apply the layer-wise gradient decay

    .. math::
        lr = lr * layerwise_decay^(max_depth - layer_depth)

    Parameters
    ----------
    model
        The qa_net model whose backbone parameters are rescaled.
    layerwise_decay: float
        layer-wise decay factor
    not_included: list of str
        A list of parameter names that are not included in the layer-wise decay
    """
    # Treat the task-specific fine-tuning layer as the last layer, followed by the pooler.
    # Under this setting, the embedding parameters get the smallest learning rate.
    all_layers = model.backbone.encoder.all_encoder_layers
    max_depth = len(all_layers)
    if 'pool' in model.collect_params().keys():
        max_depth += 1
    for key, value in model.collect_params().items():
        if 'scores' in key:
            value.lr_mult = layerwise_decay**(0)
        if 'pool' in key:
            value.lr_mult = layerwise_decay**(1)
        if 'embed' in key:
            value.lr_mult = layerwise_decay**(max_depth + 1)

    for (layer_depth, layer) in enumerate(all_layers):
        layer_params = layer.collect_params()
        for key, value in layer_params.items():
            for pn in not_included:
                if pn in key:
                    continue
            value.lr_mult = layerwise_decay**(max_depth - layer_depth)


def train(args):
    ctx_l = parse_ctx(args.gpus)
    cfg, tokenizer, qa_net, use_segmentation = \
@@ -486,7 +447,7 @@ def train(args):
    if args.untunable_depth > 0:
        untune_params(qa_net, args.untunable_depth)
    if args.layerwise_decay > 0:
        apply_layerwise_decay(qa_net, args.layerwise_decay)
        qa_net.backbone.apply_layerwise_decay(args.layerwise_decay)

    # Do not apply weight decay to any LayerNorm or bias parameters
    for _, v in qa_net.collect_params('.*beta|.*gamma|.*bias').items():
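The relocated call works through MXNet Gluon's per-parameter lr_mult: the Trainer scales its base learning rate by each parameter's lr_mult, so assigning depth-dependent multipliers implements layer-wise learning-rate decay. Below is a minimal standalone sketch of that mechanism (an illustration under plain Gluon, not code from this commit; the toy Dense layer and the 0.5 multiplier are assumptions):

from mxnet.gluon import Trainer, nn

# Toy block standing in for a single encoder layer (illustrative only).
layer = nn.Dense(4)
layer.initialize()

# Train this layer's parameters at half the base learning rate.
for _, param in layer.collect_params().items():
    param.lr_mult = 0.5

trainer = Trainer(layer.collect_params(), 'adam', {'learning_rate': 2e-4})
# Effective learning rate for these parameters: 2e-4 * 0.5 = 1e-4.

apply_layerwise_decay only sets these multipliers; the Trainer constructed later in train() picks them up automatically.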
33 changes: 33 additions & 0 deletions src/gluonnlp/models/electra.py
@@ -229,6 +229,8 @@ def __init__(self,
        self.pos_embed_type = pos_embed_type
        self.num_token_types = num_token_types
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.embed_size = embed_size
        self.units = units
        self.max_length = max_length
@@ -372,6 +374,37 @@ def get_initial_embedding(self, F, inputs, token_types=None):
        embedding = self.embed_dropout(embedding)
        return embedding


    def apply_layerwise_decay(self, layerwise_decay, not_included=[]):
        """Apply the layer-wise gradient decay

        .. math::
            lr = lr * layerwise_decay^(max_depth - layer_depth)

        Parameters
        ----------
        layerwise_decay: float
            layer-wise decay factor
        not_included: list of str
            A list of parameter names that are not included in the layer-wise decay
        """

        # Treat the task-specific fine-tuning layer as the last layer, followed by the pooler.
        # Under this setting, the embedding parameters get the smallest learning rate.
        max_depth = self.num_layers
        for key, value in self.collect_params().items():
            if 'embed' in key:
                value.lr_mult = layerwise_decay**(max_depth + 1)

        for (layer_depth, layer) in enumerate(self.encoder.all_encoder_layers):
            layer_params = layer.collect_params()
            for key, value in layer_params.items():
                for pn in not_included:
                    if pn in key:
                        continue
                value.lr_mult = layerwise_decay**(max_depth - layer_depth)


    @staticmethod
    def get_cfg(key=None):
        if key is None:
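To make the formula in the docstring above concrete, here is a small worked example of the multipliers it assigns (the 12-layer depth and the 0.8 decay factor are illustrative assumptions, not values fixed by the commit):

layerwise_decay = 0.8
max_depth = 12  # e.g. self.num_layers for an ELECTRA-base-sized backbone

for layer_depth in range(max_depth):
    lr_mult = layerwise_decay ** (max_depth - layer_depth)
    print('layer {:2d}: lr_mult = {:.4f}'.format(layer_depth, lr_mult))

# The bottom layer (layer 0) gets 0.8**12 ~ 0.069, the top layer gets 0.8**1 = 0.8,
# and the embedding parameters get 0.8**(max_depth + 1) ~ 0.055.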
