Commit

fix roberta
zheyuye committed Jun 28, 2020
1 parent ee1f0e3 commit 2070b86
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions src/gluonnlp/models/roberta.py
@@ -133,7 +133,7 @@ def __init__(self,
                  return_all_hiddens=False,
                  prefix=None,
                  params=None):
-        """
+        """
         Parameters
         ----------
@@ -155,7 +155,7 @@ def __init__(self,
         dtype
         use_pooler
             Whether to use classification head
-        use_mlm
+        use_mlm
             Whether to use lm head, if False, forward return hidden states only
         untie_weight
             Whether to untie weights between embeddings and classifiers
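
These flags decide what the forward pass returns. A minimal usage sketch, assuming RobertaModel's remaining constructor arguments all have defaults (the diff only shows a few of them):

    from gluonnlp.models.roberta import RobertaModel

    # Backbone only: drop the pooler and the LM head, so the forward
    # pass yields hidden states rather than MLM scores.
    backbone = RobertaModel(use_pooler=False, use_mlm=False)
    backbone.initialize()
    backbone.hybridize()
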
@@ -207,7 +207,7 @@ def __init__(self,
             method=pos_embed_type,
             prefix='pos_embed_'
         )
-
+
         self.encoder = RobertaEncoder(
             units=self.units,
             hidden_size=self.hidden_size,
@@ -223,7 +223,7 @@ def __init__(self,
             return_all_hiddens=self.return_all_hiddens
         )
         self.encoder.hybridize()
-
+
         if self.use_mlm:
             embed_weight = None if untie_weight else \
                 self.tokens_embed.collect_params('.*weight')
@@ -303,7 +303,7 @@ def from_cfg(cls,
                        params=params)

 @use_np
-class RobertaEncoder(HybridBlock):
+class RobertaEncoder(HybridBlock):
     def __init__(self,
                  units=768,
                  hidden_size=3072,
@@ -418,7 +418,8 @@ def list_pretrained_roberta():


 def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base',
-                           root: str = get_model_zoo_home_dir()) \
+                           root: str = get_model_zoo_home_dir(),
+                           load_backbone=True) \
         -> Tuple[CN, HuggingFaceByteBPETokenizer, str]:
     """Get the pretrained RoBERTa weights
@@ -428,6 +429,8 @@ def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base',
         The name of the RoBERTa model.
     root
         The downloading root
+    load_backbone
+        Whether to load the weights of the backbone network
     Returns
     -------
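
The new load_backbone flag (default True, per the signature above) lets callers fetch only the config, vocab, and merges files; as the final hunk below shows, the returned parameter path is then None. A sketch of the lighter call:

    from gluonnlp.models.roberta import get_pretrained_roberta

    # Tokenizer/config-only workflow: skip downloading the backbone weights.
    cfg, tokenizer, params_path = get_pretrained_roberta(
        'fairseq_roberta_base', load_backbone=False)
    assert params_path is None
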
@@ -446,13 +449,20 @@ def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base',
     params_path = PRETRAINED_URL[model_name]['params']
     local_paths = dict()
     for k, path in [('cfg', cfg_path), ('vocab', vocab_path),
-                    ('merges', merges_path), ('params', params_path)]:
+                    ('merges', merges_path)]:
         local_paths[k] = download(url=get_repo_model_zoo_url() + path,
                                   path=os.path.join(root, path),
                                   sha1_hash=FILE_STATS[path])
+    if load_backbone:
+        local_params_path = download(url=get_repo_model_zoo_url() + params_path,
+                                     path=os.path.join(root, params_path),
+                                     sha1_hash=FILE_STATS[params_path])
+    else:
+        local_params_path = None

     tokenizer = HuggingFaceByteBPETokenizer(local_paths['merges'], local_paths['vocab'])
     cfg = RobertaModel.get_cfg().clone_merge(local_paths['cfg'])
-    return cfg, tokenizer, local_paths['params']
+    return cfg, tokenizer, local_params_path


 BACKBONE_REGISTRY.register('roberta', [RobertaModel,
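
Putting it together, an end-to-end sketch of the default path (from_cfg is the constructor shown earlier in this file; load_parameters is the standard Gluon Block loader):

    from gluonnlp.models.roberta import RobertaModel, get_pretrained_roberta

    cfg, tokenizer, params_path = get_pretrained_roberta('fairseq_roberta_base')
    model = RobertaModel.from_cfg(cfg)
    # With load_backbone=True (the default), params_path points at a local file.
    model.load_parameters(params_path)

Downloading the parameter file only when it is needed keeps config- and tokenizer-only workflows cheap while leaving the default behaviour unchanged.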
