
fix roberta
zheyuye committed Jun 28, 2020
1 parent 86702fe commit 2b7f7a3
Showing 2 changed files with 84 additions and 44 deletions.
scripts/conversion_toolkits/convert_fairseq_roberta.py (39 changes: 20 additions & 19 deletions)
@@ -37,7 +37,7 @@ def convert_vocab(args, fairseq_model):
fairseq_vocab = fairseq_model.task.dictionary
# bos_word attr missing in fairseq_vocab
fairseq_vocab.bos_word = fairseq_vocab[fairseq_vocab.bos_index]

assert os.path.exists(fairseq_dict_path), \
'{} not found'.format(fairseq_dict_path)
from mxnet.gluon.utils import download
@@ -63,7 +63,7 @@ def convert_vocab(args, fairseq_model):
inter_vocab = list(inter_vocab.items())
inter_vocab = sorted(inter_vocab, key=lambda x : x[1])
tokens = [e[0] for e in inter_vocab]

tail = [fairseq_vocab[-4],
fairseq_vocab[-3],
fairseq_vocab[-2],
@@ -84,15 +84,15 @@ def convert_vocab(args, fairseq_model):
gluon_vocab.save(vocab_save_path)
os.remove(temp_vocab_file)
os.remove(temp_merges_file)

gluon_tokenizer = HuggingFaceByteBPETokenizer(
merges_save_path,
vocab_save_path
)

if args.test:
test_vocab(fairseq_model, gluon_tokenizer)

vocab_size = len(fairseq_vocab)
print('| converted dictionary: {} types'.format(vocab_size))
return vocab_size
@@ -103,14 +103,14 @@ def test_vocab(fairseq_model, gluon_tokenizer, check_all_tokens=False):
gluon_vocab = gluon_tokenizer.vocab
assert len(fairseq_vocab) == \
len(gluon_vocab)

# assert all_tokens
# roberta with gpt2 bytebpe bpe does not provide all tokens directly
if check_all_tokens:
for i in range(len(fairseq_vocab)):
assert fairseq_vocab[i] == gluon_vocab.all_tokens[i], \
'{}, {}, {}'.format(i, fairseq_vocab[i], gluon_vocab.all_tokens[i])

# assert special tokens
for special_tokens in ['unk', 'pad', 'eos', 'bos']:
assert getattr(fairseq_vocab, special_tokens + '_index') == \
@@ -121,7 +121,7 @@ def test_vocab(fairseq_model, gluon_tokenizer, check_all_tokens=False):
assert fairseq_vocab[-1] == \
gluon_vocab.all_tokens[-1] == \
'<mask>'

sentence = "Hello, y'all! How are you Ⅷ 😁 😁 😁 ?" + \
'GluonNLP is great!!!!!!' + \
"GluonNLP-Amazon-Haibin-Leonard-Sheng-Shuai-Xingjian...../:!@# 'abc'"
@@ -131,7 +131,7 @@ def test_vocab(fairseq_model, gluon_tokenizer, check_all_tokens=False):
# Notice: we may append bos and eos
# manually after tokenizing sentences
assert fs_tokens.numpy().tolist()[1:-1] == gl_tokens

# assert decode
fs_sentence = fairseq_model.decode(fs_tokens)
gl_sentence = gluon_tokenizer.decode(gl_tokens)
@@ -170,7 +170,8 @@ def convert_params(fairseq_model,
fairseq_prefix = 'model.decoder.'
gluon_model = gluon_model_cls.from_cfg(
gluon_cfg,
return_all_hiddens=True,
use_pooler=False,
output_all_encodings=True,
prefix=gluon_prefix
)
gluon_model.initialize(ctx=ctx)
@@ -196,7 +197,7 @@ def convert_params(fairseq_model,
np.concatenate([fs_q_weight, fs_k_weight, fs_v_weight], axis=0))
gl_qkv_bias.set_data(
np.concatenate([fs_q_bias, fs_k_bias, fs_v_bias], axis=0))

for k, v in [
('self_attn.out_proj.weight', 'proj_weight'),
('self_attn.out_proj.bias', 'proj_bias'),
@@ -230,20 +231,20 @@ def convert_params(fairseq_model,
gl_name = gluon_prefix + v
gluon_params[gl_name].set_data(
fairseq_params[fs_name].cpu().numpy())

# position embed weight
padding_idx = fairseq_model.task.dictionary.pad_index
fs_pos_embed_name = fairseq_prefix + 'sentence_encoder.embed_positions.weight'
gl_pos_embed_name = gluon_prefix + 'pos_embed_embed_weight'
gluon_params[gl_pos_embed_name].set_data(
fairseq_params[fs_pos_embed_name].cpu().numpy()[padding_idx + 1:,:])

# assert untie=False
assert np.array_equal(
fairseq_params[fairseq_prefix + 'sentence_encoder.embed_tokens.weight'].cpu().numpy(),
fairseq_params[fairseq_prefix + 'lm_head.weight'].cpu().numpy()
)

return gluon_model

def test_model(fairseq_model, gluon_model, gpu):
@@ -272,16 +273,16 @@ def test_model(fairseq_model, gluon_model, gpu):
fs_input_ids = torch.from_numpy(input_ids).cuda(gpu)
if gpu is not None:
fs_input_ids = fs_input_ids.cuda(gpu)

fairseq_model.model.eval()
gl_x, gl_all_hiddens = \

gl_all_hiddens, gl_x = \
gluon_model(gl_input_ids, gl_valid_length)

fs_x, fs_extra = \
fairseq_model.model.cuda(gpu)(fs_input_ids, return_all_hiddens=True)
fs_all_hiddens = fs_extra['inner_states']

num_layers = fairseq_model.args.encoder_layers
for i in range(num_layers + 1):
gl_hidden = gl_all_hiddens[i].asnumpy()
@@ -317,7 +318,7 @@ def rename(save_dir):
new_name = '{file_prefix}-{short_hash}.{file_sufix}'.format(
file_prefix=file_prefix,
short_hash=long_hash[:8],
file_sufix=file_sufix)
file_sufix=file_sufix)
new_path = os.path.join(save_dir, new_name)
shutil.move(old_path, new_path)
file_size = os.path.getsize(new_path)
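
A minimal usage sketch (editorial, not part of the commit) of loading back the tokenizer artifacts written by convert_vocab() and sanity-checking them the way test_vocab() does above. The import path, the artifact file names, and the encode() call are assumptions; the constructor argument order and the vocab/decode() usage follow the diff.

    from gluonnlp.data.tokenizers import HuggingFaceByteBPETokenizer  # import path assumed

    # Illustrative file names; convert_vocab() above writes a merges file and a vocab file.
    merges_save_path = 'fairseq_roberta_base-merges.txt'
    vocab_save_path = 'fairseq_roberta_base-vocab.json'

    gluon_tokenizer = HuggingFaceByteBPETokenizer(merges_save_path, vocab_save_path)
    gluon_vocab = gluon_tokenizer.vocab

    print(len(gluon_vocab))            # should equal len(fairseq_vocab)
    print(gluon_vocab.all_tokens[-1])  # '<mask>' sits at the end of the converted vocabulary

    sentence = "Hello, y'all! How are you?"
    token_ids = gluon_tokenizer.encode(sentence, output_type=int)  # encode() signature assumed
    print(gluon_tokenizer.decode(token_ids))                       # round-trips the sentence
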
src/gluonnlp/models/roberta.py (89 changes: 64 additions & 25 deletions)
@@ -130,7 +130,7 @@ def __init__(self,
use_mlm=True,
untie_weight=False,
encoder_normalize_before=True,
return_all_hiddens=False,
output_all_encodings=False,
prefix=None,
params=None):
"""
@@ -160,7 +160,7 @@ def __init__(self,
untie_weight
Whether to untie weights between embeddings and classifiers
encoder_normalize_before
return_all_hiddens
output_all_encodings
prefix
params
"""
@@ -182,7 +182,7 @@ def __init__(self,
self.use_mlm = use_mlm
self.untie_weight = untie_weight
self.encoder_normalize_before = encoder_normalize_before
self.return_all_hiddens = return_all_hiddens
self.output_all_encodings = output_all_encodings
with self.name_scope():
self.tokens_embed = nn.Embedding(
input_dim=self.vocab_size,
@@ -220,7 +220,7 @@ def __init__(self,
bias_initializer=bias_initializer,
activation=self.activation,
dtype=self.dtype,
return_all_hiddens=self.return_all_hiddens
output_all_encodings=self.output_all_encodings
)
self.encoder.hybridize()

@@ -237,25 +237,63 @@ def __init__(self,
bias_initializer=bias_initializer
)
self.lm_head.hybridize()
# TODO support use_pooler

def hybrid_forward(self, F, tokens, valid_length):
x = self.tokens_embed(tokens)
outputs = []
embedding = self.get_initial_embedding(F, tokens)

inner_states = self.encoder(x, valid_length)
if self.output_all_encodings:
contextual_embeddings = inner_states[-1]
else:
contextual_embeddings = inner_states
outputs.append(contextual_embeddings)

if self.use_pooler:
pooled_out = self.apply_pooling(contextual_embeddings)
outputs.append(pooled_out)

if self.use_mlm:
mlm_output = self.lm_head(contextual_embeddings)
outputs.append(mlm_output)
return tuple(outputs) if len(outputs) > 1 else outputs[0]

def get_initial_embedding(self, F, inputs):
"""Get the initial token embeddings that considers the token type and positional embeddings
Parameters
----------
F
inputs
Shape (batch_size, seq_length)
Returns
-------
embedding
The initial embedding that will be fed into the encoder
"""
embedding = self.tokens_embed(inputs)
if self.pos_embed_type:
positional_embedding = self.pos_embed(F.npx.arange_like(x, axis=1))
positional_embedding = self.pos_embed(F.npx.arange_like(inputs, axis=1))
positional_embedding = F.np.expand_dims(positional_embedding, axis=0)
x = x + positional_embedding
embedding = embedding + positional_embedding
if self.embed_ln:
x = self.embed_ln(x)
x = self.embed_dropout(x)
inner_states = self.encoder(x, valid_length)
x = inner_states[-1]
if self.use_mlm:
x = self.lm_head(x)
if self.return_all_hiddens:
return x, inner_states
else:
return x
embedding = self.embed_ln(embedding)
embedding = self.embed_dropout(embedding)

def apply_pooling(self, sequence):
"""Generate the representation given the inputs.
This is used for pre-training or fine-tuning a RoBERTa model.
Get the first token of the whole sequence, which is [CLS].
sequence:
Shape (batch_size, sequence_length, units)
return:
Shape (batch_size, units)
"""
outputs = sequence[:, 0, :]
return outputs

@staticmethod
def get_cfg(key=None):
@@ -271,7 +309,7 @@ def from_cfg(cls,
use_mlm=True,
untie_weight=False,
encoder_normalize_before=True,
return_all_hiddens=False,
output_all_encodings=False,
prefix=None,
params=None):
cfg = RobertaModel.get_cfg().clone_merge(cfg)
@@ -298,7 +336,7 @@ def from_cfg(cls,
use_mlm=use_mlm,
untie_weight=untie_weight,
encoder_normalize_before=encoder_normalize_before,
return_all_hiddens=return_all_hiddens,
output_all_encodings=output_all_encodings,
prefix=prefix,
params=params)

@@ -316,7 +354,7 @@ def __init__(self,
bias_initializer='zeros',
activation='gelu',
dtype='float32',
return_all_hiddens=False,
output_all_encodings=False,
prefix='encoder_',
params=None):
super(RobertaEncoder, self).__init__(prefix=prefix, params=params)
@@ -329,7 +367,7 @@ def __init__(self,
self.layer_norm_eps = layer_norm_eps
self.activation = activation
self.dtype = dtype
self.return_all_hiddens = return_all_hiddens
self.output_all_encodings = output_all_encodings
with self.name_scope():
self.all_layers = nn.HybridSequential(prefix='layers_')
with self.all_layers.name_scope():
@@ -358,8 +396,8 @@ def hybrid_forward(self, F, x, valid_length):
layer = self.all_layers[layer_idx]
x, _ = layer(x, atten_mask)
inner_states.append(x)
if not self.return_all_hiddens:
inner_states = [x]
if not self.output_all_encodings:
inner_states = x
return inner_states

@use_np
@@ -419,7 +457,8 @@ def list_pretrained_roberta():

def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base',
root: str = get_model_zoo_home_dir(),
load_backbone: bool = True) \
load_backbone: bool = True,
load_mlm: bool = False) \
-> Tuple[CN, HuggingFaceByteBPETokenizer, str]:
"""Get the pretrained RoBERTa weights
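
To recap the calling convention introduced in roberta.py above, here is a minimal sketch (editorial, not part of the commit). It assumes the from_cfg keywords shown in this diff, that use_pooler can be passed the same way convert_fairseq_roberta.py does, and uses hypothetical dummy inputs. With use_mlm=True and use_pooler=False the forward pass returns the pair (contextual embeddings, MLM scores); enabling use_pooler inserts the pooled first-token representation between them, and output_all_encodings=True makes the encoder expose every layer's hidden states.

    import mxnet as mx
    from gluonnlp.models.roberta import RobertaModel  # import path assumed

    mx.npx.set_np()  # the model is written against the MXNet numpy interface

    cfg = RobertaModel.get_cfg()
    model = RobertaModel.from_cfg(
        cfg,
        use_mlm=True,                # append lm_head scores to the outputs
        use_pooler=False,            # skip apply_pooling() on the first token
        output_all_encodings=False,  # encoder returns only the final hidden states
    )
    model.initialize()

    tokens = mx.np.ones((2, 8), dtype='int32')         # (batch_size, seq_length), dummy ids
    valid_length = mx.np.array([8, 6], dtype='int32')

    contextual_embeddings, mlm_scores = model(tokens, valid_length)
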
