fix comment1
zheyuye committed Jul 29, 2020
Parent 6533601, commit 1b5fa7b
Showing 3 changed files with 29 additions and 49 deletions.
scripts/conversion_toolkits/convert_fairseq_bart.py: 25 additions & 30 deletions
@@ -40,7 +40,7 @@ def convert_config(fairseq_cfg, vocab_size, cfg):
cfg.MODEL.shared_embed = fairseq_cfg.share_all_embeddings
cfg.MODEL.scale_embed = not fairseq_cfg.no_scale_embedding
cfg.MODEL.tie_weights = fairseq_cfg.share_decoder_input_output_embed
- cfg.MODEL.layernorm_embedding = fairseq_cfg.layernorm_embedding
+ cfg.MODEL.data_norm = fairseq_cfg.layernorm_embedding
cfg.MODEL.pooler_activation = fairseq_cfg.pooler_activation_fn
cfg.MODEL.layer_norm_eps = 1E-5
cfg.MODEL.dropout = fairseq_cfg.dropout
@@ -111,26 +111,6 @@ def convert_attention(num_layers,
gl_qkv_bias.set_data(
np.concatenate([fs_q_bias, fs_k_bias, fs_v_bias], axis=0))

- def convert_embeddings(fairseq_prefix, gluon_prefix):
- for k, v in [
- ('.embed_tokens.weight', '_embed_layer.weight'),
- ('.layernorm_embedding.weight', '_embed_ln.gamma'),
- ('.layernorm_embedding.bias', '_embed_ln.beta'),
- ]:
- fs_name = fairseq_prefix + k
- gl_name = gluon_prefix + v
- all_keys.remove(gl_name)
- gluon_params[gl_name].set_data(
- fairseq_params[fs_name].cpu().numpy())
-
- # position embed weight
- padding_idx = fairseq_model.task.dictionary.pad_index
- fs_pos_embed_name = fairseq_prefix + '.embed_positions.weight'
- gl_pos_embed_name = gluon_prefix + '_pos_embed_layer._embed.weight'
- all_keys.remove(gl_pos_embed_name)
- gluon_params[gl_pos_embed_name].set_data(
- fairseq_params[fs_pos_embed_name].cpu().numpy()[padding_idx + 1:, :])
-
def convert_ffn(num_layers, fairseq_prefix, gluon_prefix):
# convert feed forward layer in encoder
for layer_id in range(num_layers):
@@ -150,11 +130,33 @@ def convert_ffn(num_layers, fairseq_prefix, gluon_prefix):
gluon_params[gl_name].set_data(
fairseq_params[fs_name].cpu().numpy())

+ print('converting embedding params')
+ padding_idx = fairseq_model.task.dictionary.pad_index
+ for fs_name, gl_name in [
+ ('model.encoder.embed_tokens.weight', 'src_embed_layer.weight'),
+ ('model.encoder.embed_positions.weight', 'src_pos_embed_layer._embed.weight'),
+ ('model.encoder.layernorm_embedding.weight', 'encoder.ln_data.gamma'),
+ ('model.encoder.layernorm_embedding.bias', 'encoder.ln_data.beta'),
+ ('model.decoder.embed_tokens.weight', 'tgt_embed_layer.weight'),
+ ('model.decoder.embed_positions.weight', 'tgt_pos_embed_layer._embed.weight'),
+ ('model.decoder.layernorm_embedding.weight', 'decoder.ln_data.gamma'),
+ ('model.decoder.layernorm_embedding.bias', 'decoder.ln_data.beta'),
+ # final projection in decoder
+ ('model.decoder.output_projection.weight', 'tgt_final_layer.weight'),
+ ]:
+ all_keys.remove(gl_name)
+ if 'embed_positions' in fs_name:
+ # position embed weight
+ gluon_params[gl_name].set_data(
+ fairseq_params[fs_name].cpu().numpy()[padding_idx + 1:, :])
+ else:
+ gluon_params[gl_name].set_data(
+ fairseq_params[fs_name].cpu().numpy())

+ print('converting encoder params')
encoder_num_layers = gluon_cfg.MODEL.ENCODER.num_layers
convert_attention(encoder_num_layers, 'model.encoder', 'encoder')
convert_ffn(encoder_num_layers, 'model.encoder', 'encoder')
- convert_embeddings('model.encoder', 'src')
for layer_id in range(encoder_num_layers):
for k, v in [
('self_attn.out_proj.weight', 'attention_proj.weight'),
@@ -170,6 +172,7 @@ def convert_ffn(num_layers, fairseq_prefix, gluon_prefix):
gluon_params[gl_name].set_data(
fairseq_params[fs_name].cpu().numpy())

+ print('converting decoder params')
decoder_num_layers = gluon_cfg.MODEL.DECODER.num_layers
convert_attention(decoder_num_layers, 'model.decoder', 'decoder',
gluon_attn_prefix='attn_in_qkv')
@@ -201,14 +204,6 @@ def convert_ffn(num_layers, fairseq_prefix, gluon_prefix):
gluon_params[gl_name].set_data(
fairseq_params[fs_name].cpu().numpy())

- convert_embeddings('model.decoder', 'tgt')
- # final projection in decoder
- for fs_name, gl_name in [
- ('model.decoder.output_projection.weight', 'tgt_final_layer.weight'),
- ]:
- all_keys.remove(gl_name)
- gluon_params[gl_name].set_data(
- fairseq_params[fs_name].cpu().numpy())
assert len(all_keys) == 0, 'parameters missing from tensorflow checkpoint'

# check parameters sharing if share_decoder_input_output_embed is true
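A note on the position-embedding handling in the new consolidated embedding loop above (the [padding_idx + 1:, :] slice): fairseq's learned position table reserves its first padding_idx + 1 rows (positions are offset so that index padding_idx marks padding), while Gluon's _pos_embed_layer indexes absolute positions from 0. A minimal numpy sketch of that mapping; the shapes and the padding_idx value below are illustrative assumptions, not taken from an actual checkpoint (in the script, padding_idx comes from fairseq_model.task.dictionary.pad_index):

import numpy as np

# Illustrative values only.
padding_idx = 1
max_positions, units = 1024, 768

# fairseq stores (max_positions + padding_idx + 1) rows; rows 0..padding_idx are reserved.
fs_pos_embed = np.random.randn(max_positions + padding_idx + 1, units).astype('float32')

# Gluon indexes absolute positions from 0, so the reserved rows are dropped.
gl_pos_embed = fs_pos_embed[padding_idx + 1:, :]
assert gl_pos_embed.shape == (max_positions, units)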
src/gluonnlp/models/bart.py: 2 additions & 2 deletions
@@ -67,7 +67,7 @@ def bart_base():
cfg.MODEL.dropout = 0.0
cfg.MODEL.layer_norm_eps = 1E-5
cfg.MODEL.pooler_activation = 'tanh'
- cfg.MODEL.layernorm_embedding = True
+ cfg.MODEL.data_norm = True
cfg.MODEL.layout = 'NT'
cfg.MODEL.dtype = 'float32'

@@ -285,13 +285,13 @@ def from_cfg(cls, cfg,
pos_embed_type=cfg.MODEL.pos_embed_type,
shared_embed=cfg.MODEL.shared_embed,
tie_weights=cfg.MODEL.tie_weights,
+ data_norm=cfg.MODEL.data_norm,
use_pooler=use_pooler,
attention_dropout=cfg.MODEL.attention_dropout,
activation_dropout=cfg.MODEL.activation_dropout,
dropout=cfg.MODEL.dropout,
pooler_activation=cfg.MODEL.pooler_activation,
layer_norm_eps=cfg.MODEL.layer_norm_eps,
- layernorm_embedding=cfg.MODEL.layernorm_embedding,
enc_num_layers=cfg.MODEL.ENCODER.num_layers,
enc_units=cfg.MODEL.ENCODER.units,
enc_num_heads=cfg.MODEL.ENCODER.num_heads,
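The net effect in bart.py is a config-key rename: MODEL.layernorm_embedding becomes MODEL.data_norm, and from_cfg now forwards it as data_norm. A minimal usage sketch, assuming bart_base() returns a yacs-style config as defined in this file and that BartModel.from_cfg and initialize() are available; the import path and the defrost/freeze calls are assumptions, not verified against the repo:

# Sketch only, not a definitive API example.
from gluonnlp.models.bart import BartModel, bart_base   # import path assumed

cfg = bart_base()
cfg.defrost()
cfg.MODEL.data_norm = True          # renamed key; formerly cfg.MODEL.layernorm_embedding
cfg.freeze()

model = BartModel.from_cfg(cfg)     # from_cfg now reads cfg.MODEL.data_norm (hunk above)
model.initialize()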
src/gluonnlp/models/transformer.py: 2 additions & 17 deletions
@@ -440,8 +440,7 @@ def __init__(self, units: int = 512,
num_heads=num_heads,
attention_dropout=self._attention_dropout,
dtype=dtype,
- layout=attention_layout,
- layout='NTK')
+ layout=attention_layout)
self.proj_in = nn.Dense(units=units, in_units=units, flatten=False, use_bias=True,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
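The hunk above is a straight bug fix: the removed pair of lines passed the layout keyword to the attention cell twice (layout=attention_layout, layout='NTK')), which Python rejects at parse time, so the module could not even be imported; only the configurable layout=attention_layout survives. A tiny illustration of why the old call could never parse; the call target f is a stand-in, not the real attention cell:

# Duplicate keyword arguments fail at compile time, before anything runs.
try:
    compile("f(layout=attention_layout, layout='NTK')", "<sketch>", "exec")
except SyntaxError as err:
    print(err)   # keyword argument repeated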
@@ -914,7 +913,6 @@ def __init__(self, src_vocab_size: int,
max_tgt_length: Optional[int] = None,
scale_embed: bool = True,
pos_embed_type="sinusoidal",
- layernorm_embedding: bool = False,
shared_embed: bool = True,
tie_weights: bool = True,
activation_dropout: float = 0.0,
@@ -959,8 +957,6 @@ def __init__(self, src_vocab_size: int,
Whether to multiply the src and dst embeddings by sqrt(units)
pos_embed_type
Type of the positional embedding
- layernorm_embedding
- Wether to layer normalize the embedding
shared_embed
Whether to share the embedding of the src and tgt language
tie_weights
@@ -1027,7 +1023,6 @@ def __init__(self, src_vocab_size: int,
self._tgt_vocab_size = tgt_vocab_size
self.tie_weights = tie_weights
self.pos_embed_type = pos_embed_type
- self.layernorm_embedding = layernorm_embedding
self.scaled_embed = scale_embed
self.enc_units = enc_units
self.dec_units = dec_units
@@ -1063,11 +1058,6 @@ def __init__(self, src_vocab_size: int,
max_length=max_tgt_length,
dtype=self._dtype,
method=pos_embed_type)
- if layernorm_embedding:
- self.src_embed_ln = nn.LayerNorm(epsilon=layer_norm_eps,
- in_channels=enc_units)
- self.tgt_embed_ln = nn.LayerNorm(epsilon=layer_norm_eps,
- in_channels=dec_units)
self.encoder = TransformerEncoder(num_layers=enc_num_layers,
recurrent=enc_recurrent,
units=enc_units,
@@ -1164,8 +1154,6 @@ def encode(self, F, src_data, src_valid_length):
else:
src_data = src_data + F.np.expand_dims(self.src_pos_embed_layer(
F.npx.arange_like(src_data, axis=0)), axis=1)
- if self.layernorm_embedding:
- src_data = self.src_embed_ln(src_data)

enc_out = self.encoder(src_data, src_valid_length)
return enc_out
@@ -1209,8 +1197,7 @@ def decode_seq(self, F, tgt_data, tgt_valid_length, mem_data, mem_valid_length):
else:
tgt_data = tgt_data + F.np.expand_dims(self.tgt_pos_embed_layer(
F.npx.arange_like(tgt_data, axis=0)), axis=1)
- if self.layernorm_embedding:
- tgt_data = self.tgt_embed_ln(tgt_data)

dec_out = self.decoder(tgt_data, tgt_valid_length, mem_data, mem_valid_length)
return dec_out

@@ -1403,8 +1390,6 @@ def hybrid_forward(self, F, step_data, states):
step_data = step_data * np.sqrt(self.model.dec_units)
if self.model.pos_embed_type is not None:
step_data = step_data + self.model.tgt_pos_embed_layer(position)
- if self.model.layernorm_embedding:
- step_data = self.tgt_embed_ln(step_data)
out, new_states =\
self.model.decoder.incremental_decode(F, step_data, dec_states,
mem_data, mem_valid_length)
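With these removals, TransformerNMTModel no longer owns the embedding LayerNorm: src_embed_ln / tgt_embed_ln and the layernorm_embedding flag are gone from __init__, encode, decode_seq, and the incremental decoder. The conversion script above now maps fairseq's layernorm_embedding.{weight,bias} to encoder.ln_data.gamma/beta and decoder.ln_data.gamma/beta, i.e. the normalization is expected to live inside the encoder/decoder behind the data_norm switch. A minimal Gluon sketch of that pattern; this is not the repo's TransformerEncoder, and the class and argument names are illustrative:

from mxnet.gluon import nn

class TinyEncoder(nn.HybridBlock):
    """Sketch: the encoder owns the embedding LayerNorm (cf. encoder.ln_data above)."""
    def __init__(self, units=768, data_norm=True, layer_norm_eps=1e-5, **kwargs):
        super().__init__(**kwargs)
        self._data_norm = data_norm
        if data_norm:
            self.ln_data = nn.LayerNorm(epsilon=layer_norm_eps, in_channels=units)

    def hybrid_forward(self, F, data):
        # 'data' is the summed token + positional embedding from the seq2seq wrapper,
        # which no longer applies src_embed_ln / tgt_embed_ln itself.
        if self._data_norm:
            data = self.ln_data(data)   # fairseq's layernorm_embedding ends up here
        return data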
