Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
tiny update on run_squad
Browse files Browse the repository at this point in the history
  • Loading branch information
zheyuye committed Jul 14, 2020
1 parent 4defc7a commit dc55fc9
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
3 changes: 2 additions & 1 deletion scripts/question_answering/run_squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,9 +808,10 @@ def evaluate(args, last=True):
args.comm_backend, args.gpus)
# only evaluate once
if rank != 0:
logging.info('Skipping node {}'.format(rank))
return
ctx_l = parse_ctx(args.gpus)
logging.info('Srarting inference without horovod on the first node')
logging.info('Srarting inference without horovod on the first node on device {}'.format(str(ctx_l)))

cfg, tokenizer, qa_net, use_segmentation = get_network(
args.model_name, ctx_l, args.classifier_dropout)
Expand Down
10 changes: 6 additions & 4 deletions src/gluonnlp/models/roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@
'merges': 'fairseq_roberta_large/gpt2-396d4d8e.merges',
'vocab': 'fairseq_roberta_large/gpt2-f1335494.vocab',
'params': 'fairseq_roberta_large/model-6b043b91.params',
'mlm_params': 'fairseq_roberta_large/model_mlm-119f38e1.params'
},
'mlm_params': 'fairseq_roberta_large/model_mlm-119f38e1.params',
}
}

FILE_STATS = load_checksum_stats(os.path.join(get_model_zoo_checksum_dir(), 'roberta.txt'))
Expand Down Expand Up @@ -524,11 +524,13 @@ def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base',
"""
assert model_name in PRETRAINED_URL, '{} is not found. All available are {}'.format(
model_name, list_pretrained_roberta())
cfg_path = PRETRAINED_URL[model_name]['cfg']
cfg_path = PRETRAINED_URL[model_name
]['cfg']
merges_path = PRETRAINED_URL[model_name]['merges']
vocab_path = PRETRAINED_URL[model_name]['vocab']
params_path = PRETRAINED_URL[model_name]['params']
mlm_params_path = PRETRAINED_URL[model_name]['mlm_params']

local_paths = dict()
for k, path in [('cfg', cfg_path), ('vocab', vocab_path),
('merges', merges_path)]:
Expand All @@ -541,7 +543,7 @@ def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base',
sha1_hash=FILE_STATS[params_path])
else:
local_params_path = None
if load_mlm:
if load_mlm and mlm_params_path is not None:
local_mlm_params_path = download(url=get_repo_model_zoo_url() + mlm_params_path,
path=os.path.join(root, mlm_params_path),
sha1_hash=FILE_STATS[mlm_params_path])
Expand Down
8 changes: 4 additions & 4 deletions src/gluonnlp/models/xlmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,19 @@
'cfg': 'fairseq_xlmr_base/model-b893d178.yml',
'sentencepiece.model': 'fairseq_xlmr_base/sentencepiece-18e17bae.model',
'params': 'fairseq_xlmr_base/model-3fa134e9.params',
'mlm_params': 'model_mlm-86e37954.params'
'mlm_params': 'model_mlm-86e37954.params',
},
'fairseq_xlmr_large': {
'cfg': 'fairseq_xlmr_large/model-01fc59fb.yml',
'sentencepiece.model': 'fairseq_xlmr_large/sentencepiece-18e17bae.model',
'params': 'fairseq_xlmr_large/model-b62b074c.params',
'mlm_params': 'model_mlm-887506c2.params'
'mlm_params': 'model_mlm-887506c2.params',

}
}

FILE_STATS = load_checksum_stats(os.path.join(get_model_zoo_checksum_dir(), 'xlmr.txt'))
xlmr_cfg_reg = Registry('roberta_cfg')
xlmr_cfg_reg = Registry('xlmr_cfg')


@xlmr_cfg_reg.register()
Expand Down Expand Up @@ -139,7 +139,7 @@ def get_pretrained_xlmr(model_name: str = 'fairseq_xlmr_base',
sha1_hash=FILE_STATS[params_path])
else:
local_params_path = None
if load_mlm:
if load_mlm and mlm_params_path is not None:
local_mlm_params_path = download(url=get_repo_model_zoo_url() + mlm_params_path,
path=os.path.join(root, mlm_params_path),
sha1_hash=FILE_STATS[mlm_params_path])
Expand Down

0 comments on commit dc55fc9

Please sign in to comment.