Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
fix xlmr
Browse files Browse the repository at this point in the history
  • Loading branch information
zheyuye committed Jun 28, 2020
1 parent 2070b86 commit ff7aae8
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/gluonnlp/models/roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def list_pretrained_roberta():

def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base',
root: str = get_model_zoo_home_dir(),
load_backbone=True) \
load_backbone: bool = True) \
-> Tuple[CN, HuggingFaceByteBPETokenizer, str]:
"""Get the pretrained RoBERTa weights
Expand Down
17 changes: 13 additions & 4 deletions src/gluonnlp/models/xlmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def list_pretrained_xlmr():


def get_pretrained_xlmr(model_name: str = 'fairseq_xlmr_base',
root: str = get_model_zoo_home_dir()) \
root: str = get_model_zoo_home_dir(),
load_backbone: bool = True) \
-> Tuple[CN, SentencepieceTokenizer, str]:
"""Get the pretrained XLM-R weights
Expand All @@ -99,6 +100,8 @@ def get_pretrained_xlmr(model_name: str = 'fairseq_xlmr_base',
The name of the xlmr model.
root
The downloading root
load_backbone
Whether to load the weights of the backbone network
Returns
-------
Expand All @@ -115,14 +118,20 @@ def get_pretrained_xlmr(model_name: str = 'fairseq_xlmr_base',
sp_model_path = PRETRAINED_URL[model_name]['sentencepiece.model']
params_path = PRETRAINED_URL[model_name]['params']
local_paths = dict()
for k, path in [('cfg', cfg_path), ('sentencepiece.model', sp_model_path), \
('params', params_path)]:
for k, path in [('cfg', cfg_path), ('sentencepiece.model', sp_model_path)]:
local_paths[k] = download(url=get_repo_model_zoo_url() + path,
path=os.path.join(root, path),
sha1_hash=FILE_STATS[path])
if load_backbone:
local_params_path = download(url=get_repo_model_zoo_url() + params_path,
path=os.path.join(root, params_path),
sha1_hash=FILE_STATS[params_path])
else:
local_params_path = None

tokenizer = SentencepieceTokenizer(local_paths['sentencepiece.model'])
cfg = XLMRModel.get_cfg().clone_merge(local_paths['cfg'])
return cfg, tokenizer, local_paths['params']
return cfg, tokenizer, local_params_path


BACKBONE_REGISTRY.register('xlmr', [XLMRModel,
Expand Down

0 comments on commit ff7aae8

Please sign in to comment.