Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
re-upload roberta
Browse files Browse the repository at this point in the history
  • Loading branch information
zheyuye committed Jul 13, 2020
1 parent 5811d40 commit 8ed8a72
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 48 deletions.
12 changes: 1 addition & 11 deletions scripts/conversion_toolkits/convert_fairseq_roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,17 +322,7 @@ def test_model(fairseq_model, gluon_model, gpu):
1E-3,
1E-3
)

gl_mlm_scores = gl_mlm_scores.asnumpy()
fs_mlm_scores = fs_mlm_scores.transpose(0, 1)
fs_mlm_scores = fs_mlm_scores.detach().cpu().numpy()
for j in range(batch_size):
assert_allclose(
gl_mlm_scores[j, :valid_length[j], :],
fs_mlm_scores[j, :valid_length[j], :],
1E-3,
1E-3
)
#TODO(zheyuye), checking the masking scores

def rename(save_dir):
"""Rename converted files with hash"""
Expand Down
35 changes: 22 additions & 13 deletions scripts/conversion_toolkits/convert_fairseq_xlmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import mxnet as mx

from gluonnlp.utils.misc import logging_config
from gluonnlp.models.xlmr import XLMRModel as gluon_XLMRModel
from gluonnlp.models.xlmr import XLMRModel, XLMRForMLM
from gluonnlp.third_party import sentencepiece_model_pb2
from fairseq.models.roberta import XLMRModel as fairseq_XLMRModel
from convert_fairseq_roberta import rename, test_model, test_vocab, convert_config, convert_params
Expand Down Expand Up @@ -88,23 +88,32 @@ def convert_fairseq_model(args):
vocab_size = convert_vocab(args, fairseq_xlmr)

gluon_cfg = convert_config(fairseq_xlmr.args, vocab_size,
gluon_XLMRModel.get_cfg().clone())
XLMRModel.get_cfg().clone())
with open(os.path.join(args.save_dir, 'model.yml'), 'w') as of:
of.write(gluon_cfg.dump())

ctx = mx.gpu(args.gpu) if args.gpu is not None else mx.cpu()
gluon_xlmr = convert_params(fairseq_xlmr,
gluon_cfg,
gluon_XLMRModel,
ctx,
gluon_prefix='xlmr_')
if args.test:
test_model(fairseq_xlmr, gluon_xlmr, args.gpu)
for is_mlm in [False, True]:
gluon_xlmr = convert_params(fairseq_roberta,
gluon_cfg,
ctx,
is_mlm=is_mlm,
gluon_prefix='roberta_')

if is_mlm:
if args.test:
test_model(fairseq_roberta, gluon_xlmr, args.gpu)

gluon_xlmr.save_parameters(os.path.join(args.save_dir, 'model_mlm.params'), deduplicate=True)
logging.info('Convert the RoBERTa MLM model in {} to {}'.
format(os.path.join(args.fairseq_model_path, 'model.pt'), \
os.path.join(args.save_dir, 'model_mlm.params')))
else:
gluon_xlmr.save_parameters(os.path.join(args.save_dir, 'model.params'), deduplicate=True)
logging.info('Convert the RoBERTa backbone model in {} to {}'.
format(os.path.join(args.fairseq_model_path, 'model.pt'), \
os.path.join(args.save_dir, 'model.params')))

gluon_xlmr.save_parameters(os.path.join(args.save_dir, 'model.params'), deduplicate=True)
logging.info('Convert the XLM-R model in {} to {}'.
format(os.path.join(args.fairseq_model_path, 'model.pt'), \
os.path.join(args.save_dir, 'model.params')))
logging.info('Conversion finished!')
logging.info('Statistics:')
rename(args.save_dir)
Expand Down
22 changes: 10 additions & 12 deletions src/gluonnlp/models/model_zoo_checksums/roberta.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
fairseq_roberta_base/model-565d1db7.yml 565d1db71b0452fa2c28f155b8e9d90754f4f40a 401
fairseq_roberta_base/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318
fairseq_roberta_base/gpt2-f1335494.vocab f1335494f47917829e3b1d08e579ff2c3fe4fd60 558231
fairseq_roberta_base/model-98b4532f.params 98b4532fe59e6fd755422057fde4601b3eb8fbf0 498792661
fairseq_roberta_large/model-6e66dc4a.yml 6e66dc4a450560a93aaf3d0ba9e0d447495d778a 402
fairseq_roberta_large/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318
fairseq_roberta_large/gpt2-f1335494.vocab f1335494f47917829e3b1d08e579ff2c3fe4fd60 558231
fairseq_roberta_large/model-e3f578dc.params e3f578dc669cf36fa5b6730b0bbee77c980276d7 1421659773
fairseq_roberta_large_mnli/model-6e66dc4a.yml 6e66dc4a450560a93aaf3d0ba9e0d447495d778a 402
fairseq_roberta_large_mnli/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318
fairseq_roberta_large_mnli/gpt2-f1335494.vocab f1335494f47917829e3b1d08e579ff2c3fe4fd60 558231
fairseq_roberta_large_mnli/model-5288bb09.params 5288bb09db89b7900e85c9d673686f748f0abd56 1421659773
fairseq_roberta_base/model-565d1db7.yml 565d1db71b0452fa2c28f155b8e9d90754f4f40a 401
fairseq_roberta_base/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318
fairseq_roberta_base/gpt2-f1335494.vocab f1335494f47917829e3b1d08e579ff2c3fe4fd60 558231
fairseq_roberta_base/model-09a1520a.params 09a1520adf652468c07e43a6ed28908418fa58a7 496222787
fairseq_roberta_base/model_mlm-29889e2b.params 29889e2b4ef20676fda117bb7b754e1693d0df25 498794868
fairseq_roberta_large/model-6b043b91.params 6b043b91a6a781a12ea643d0644d32300db38ec8 1417251819
fairseq_roberta_large/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318
fairseq_roberta_large/model-6e66dc4a.yml 6e66dc4a450560a93aaf3d0ba9e0d447495d778a 402
fairseq_roberta_large/gpt2-f1335494.vocab f1335494f47917829e3b1d08e579ff2c3fe4fd60 558231
fairseq_roberta_large/model_mlm-119f38e1.params 119f38e1249bd28bea7dd2e90c09b8f4b879fa19 1421664140
14 changes: 5 additions & 9 deletions src/gluonnlp/models/roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
}
"""

__all__ = ['RobertaModel', 'list_pretrained_roberta', 'get_pretrained_roberta']
__all__ = ['RobertaModel', 'RobertaForMLM', 'list_pretrained_roberta', 'get_pretrained_roberta']

import os
from typing import Tuple
Expand All @@ -54,20 +54,16 @@
'cfg': 'fairseq_roberta_base/model-565d1db7.yml',
'merges': 'fairseq_roberta_base/gpt2-396d4d8e.merges',
'vocab': 'fairseq_roberta_base/gpt2-f1335494.vocab',
'params': 'fairseq_roberta_base/model-98b4532f.params'
'params': 'fairseq_roberta_base/model-09a1520a.params'
'mlm_params': 'google_uncased_mobilebert/model_mlm-29889e2b.params',
},
'fairseq_roberta_large': {
'cfg': 'fairseq_roberta_large/model-6e66dc4a.yml',
'merges': 'fairseq_roberta_large/gpt2-396d4d8e.merges',
'vocab': 'fairseq_roberta_large/gpt2-f1335494.vocab',
'params': 'fairseq_roberta_large/model-e3f578dc.params'
'params': 'fairseq_roberta_large/model-6b043b91.params',
'mlm_params': 'fairseq_roberta_large/model_mlm-119f38e1.params'
},
'fairseq_roberta_large_mnli': {
'cfg': 'fairseq_roberta_large_mnli/model-6e66dc4a.yml',
'merges': 'fairseq_roberta_large_mnli/gpt2-396d4d8e.merges',
'vocab': 'fairseq_roberta_large_mnli/gpt2-f1335494.vocab',
'params': 'fairseq_roberta_large_mnli/model-5288bb09.params'
}
}

FILE_STATS = load_checksum_stats(os.path.join(get_model_zoo_checksum_dir(), 'roberta.txt'))
Expand Down
8 changes: 5 additions & 3 deletions src/gluonnlp/models/xlmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
}
"""

__all__ = ['XLMRModel', 'list_pretrained_xlmr', 'get_pretrained_xlmr']
__all__ = ['XLMRModel', 'XLMRForMLM', 'list_pretrained_xlmr', 'get_pretrained_xlmr']

from typing import Tuple
import os
from mxnet import use_np
from .roberta import RobertaModel, roberta_base, roberta_large
from .roberta import RobertaModel, RobertaForMLM roberta_base, roberta_large
from ..base import get_model_zoo_home_dir, get_repo_model_zoo_url, get_model_zoo_checksum_dir
from ..utils.config import CfgNode as CN
from ..utils.registry import Registry
Expand Down Expand Up @@ -82,7 +82,9 @@ def get_cfg(key=None):
return xlmr_cfg_reg.create(key)
else:
return xlmr_base()

@use_np
class XLMRForMLM(RobertaForMLM):
super().__init__()

def list_pretrained_xlmr():
return sorted(list(PRETRAINED_URL.keys()))
Expand Down

0 comments on commit 8ed8a72

Please sign in to comment.