This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit: fix
zheyuye committed Jun 28, 2020
1 parent 8645115 commit eead164
Showing 4 changed files with 28 additions and 21 deletions.
1 change: 1 addition & 0 deletions scripts/conversion_toolkits/convert_fairseq_roberta.py
@@ -171,6 +171,7 @@ def convert_params(fairseq_model,
     fairseq_prefix = 'model.decoder.'
     gluon_model = gluon_model_cls.from_cfg(
         gluon_cfg,
+        use_mlm=True,
         use_pooler=False,
         output_all_encodings=True,
         prefix=gluon_prefix
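For orientation, this is roughly how the converter now builds the Gluon-side model so that the fairseq MLM head weights have somewhere to go. A minimal sketch only: the RobertaModel class name and import path are assumptions, and gluon_cfg / gluon_prefix are placeholders for the converter's own values.

# Sketch, not the converter's exact code; RobertaModel and its import path are assumed.
from gluonnlp.models.roberta import RobertaModel

gluon_model = RobertaModel.from_cfg(
    gluon_cfg,                    # config derived from the fairseq checkpoint
    use_mlm=True,                 # build the MLM head so its weights can be copied over
    use_pooler=False,
    output_all_encodings=True,
    prefix=gluon_prefix)
gluon_model.initialize()          # parameters are then overwritten with the fairseq weights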
8 changes: 6 additions & 2 deletions scripts/question_answering/models.py
@@ -15,10 +15,11 @@ class ModelForQABasic(HybridBlock):
     """
     def __init__(self, backbone, weight_initializer=None, bias_initializer=None,
-                 prefix=None, params=None):
+                 use_segmentation=True, prefix=None, params=None):
         super().__init__(prefix=prefix, params=params)
         with self.name_scope():
             self.backbone = backbone
+            self.use_segmentation = use_segmentation
             self.qa_outputs = nn.Dense(units=2, flatten=False,
                                        weight_initializer=weight_initializer,
                                        bias_initializer=bias_initializer,
@@ -53,7 +54,10 @@ def hybrid_forward(self, F, tokens, token_types, valid_length, p_mask):
             The log-softmax scores that each position is the end position.
         """
         # Get the contextual embeddings with shape (batch_size, sequence_length, C)
-        contextual_embedding = self.backbone(tokens, token_types, valid_length)
+        if self.use_segmentation:
+            contextual_embeddings = self.backbone(tokens, token_types, valid_length)
+        else:
+            contextual_embeddings = self.backbone(tokens, valid_length)
         scores = self.qa_outputs(contextual_embeddings)
         start_scores = scores[:, :, 0]
         end_scores = scores[:, :, 1]
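The new use_segmentation flag only controls whether segment (token-type) ids are forwarded to the backbone: BERT-style backbones expect them, while RoBERTa/XLM-R backbones do not. A stand-alone sketch of the same branching; the names here are illustrative, not the repository's.

# Illustrative sketch of the branching above; `backbone` is any callable encoder.
def encode(backbone, tokens, token_types, valid_length, use_segmentation=True):
    if use_segmentation:
        # BERT-style signature: (tokens, token_types, valid_length)
        return backbone(tokens, token_types, valid_length)
    # RoBERTa/XLM-R-style signature: (tokens, valid_length), no segment ids
    return backbone(tokens, valid_length)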
24 changes: 13 additions & 11 deletions scripts/question_answering/run_squad.py
@@ -226,8 +226,7 @@ def process_sample(self, feature: SquadFeature):
             doc_stride=self._doc_stride,
             max_chunk_length=self._max_seq_length - len(truncated_query_ids) - 3)
         for chunk in chunks:
-            data = np.array([self.cls_id] + truncated_query_ids +
-                            [self.sep_id] +
+            data = np.array([self.cls_id] + truncated_query_ids + [self.sep_id] +
                             feature.context_token_ids[chunk.start:(chunk.start + chunk.length)] +
                             [self.sep_id], dtype=np.int32)
             valid_length = len(data)
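The reflowed np.array call builds the usual [CLS] query [SEP] context-chunk [SEP] layout; the "- 3" in max_chunk_length above reserves exactly those three special tokens. A toy illustration with made-up ids:

# Toy numbers only; real ids come from the tokenizer's vocabulary.
import numpy as np

cls_id, sep_id = 0, 2
truncated_query_ids = [11, 12, 13]            # query tokens
context_chunk_ids = [21, 22, 23, 24, 25]      # one doc_stride-sized slice of the context

data = np.array([cls_id] + truncated_query_ids + [sep_id] +
                context_chunk_ids +
                [sep_id], dtype=np.int32)
valid_length = len(data)                      # 1 + 3 + 1 + 5 + 1 = 11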
@@ -315,6 +314,7 @@ def get_network(model_name,
     cfg
     tokenizer
     qa_net
+    use_segmentation
     """
     # Create the network
     use_segmentation = 'roberta' not in model_name and 'xlmr' not in model_name
@@ -326,7 +326,8 @@

     backbone_params_path = backbone_path if backbone_path else download_params_path
     if checkpoint_path is None:
-        backbone.load_parameters(backbone_params_path, ignore_extra=True, ctx=ctx_l)
+        # TODO(zheyuye): be careful with allow_missing, which is used to pass the MLM parameters in RoBERTa
+        backbone.load_parameters(backbone_params_path, ignore_extra=True, allow_missing=True, ctx=ctx_l)
         num_params, num_fixed_params = count_parameters(backbone.collect_params())
         logging.info(
             'Loading Backbone Model from {}, with total/fixed parameters={}/{}'.format(
@@ -344,7 +345,7 @@
         qa_net.load_parameters(checkpoint_path, ctx=ctx_l, cast_dtype=True)
     qa_net.hybridize()

-    return cfg, tokenizer, qa_net
+    return cfg, tokenizer, qa_net, use_segmentation


 def untune_params(model, untunable_depth, not_included=[]):
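For readers unfamiliar with the two flags on Gluon's load_parameters, an annotated restatement of the call above; the variables come from get_network and are assumed to be in scope, and this is not new code in the commit.

# Annotated restatement of the call introduced above.
backbone.load_parameters(
    backbone_params_path,
    ignore_extra=True,     # skip parameters in the file that the backbone does not define,
                           # e.g. an MLM head saved by the RoBERTa conversion script
    allow_missing=True,    # do not fail if the backbone defines parameters the file lacks
    ctx=ctx_l)             # load directly onto the training context(s)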
@@ -414,10 +415,11 @@ def apply_layerwise_decay(model, layerwise_decay, not_included=[]):

 def train(args):
     ctx_l = parse_ctx(args.gpus)
-    cfg, tokenizer, qa_net = get_network(args.model_name, ctx_l,
-                                         args.classifier_dropout,
-                                         args.param_checkpoint,
-                                         args.backbone_path)
+    cfg, tokenizer, qa_net, use_segmentation = \
+        get_network(args.model_name, ctx_l,
+                    args.classifier_dropout,
+                    args.param_checkpoint,
+                    args.backbone_path)
     # Load the data
     train_examples = get_squad_examples(args.data_dir, segment='train', version=args.version)
     logging.info('Load data from {}, Version={}'.format(args.data_dir, args.version))
@@ -558,7 +560,7 @@ def train(args):
                 log_sample_num += len(tokens)
                 epoch_sample_num += len(tokens)
                 num_samples_per_update += len(tokens)
-                segment_ids = sample.segment_ids.as_in_ctx(ctx)
+                segment_ids = sample.segment_ids.as_in_ctx(ctx) if use_segmentation else None
                 valid_length = sample.valid_length.as_in_ctx(ctx)
                 p_mask = sample.masks.as_in_ctx(ctx)
                 gt_start = sample.gt_start.as_in_ctx(ctx)
@@ -786,7 +788,7 @@ def predict_extended(original_feature,

 def evaluate(args, last=True):
     ctx_l = parse_ctx(args.gpus)
-    cfg, tokenizer, qa_net = get_network(
+    cfg, tokenizer, qa_net, use_segmentation = get_network(
         args.model_name, ctx_l, args.classifier_dropout)
     # Prepare dev set
     dev_cache_path = os.path.join(CACHE_PATH,
@@ -852,7 +854,7 @@ def eval_validation(ckpt_name, best_eval):
             tokens = sample.data.as_in_ctx(ctx)
             total_num += len(tokens)
             log_num += len(tokens)
-            segment_ids = sample.segment_ids.as_in_ctx(ctx)
+            segment_ids = sample.segment_ids.as_in_ctx(ctx) if use_segmentation else None
             valid_length = sample.valid_length.as_in_ctx(ctx)
             p_mask = sample.masks.as_in_ctx(ctx)
             p_mask = 1 - p_mask  # In the network, we use 1 --> no_mask, 0 --> mask
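Taken together, the run_squad.py changes thread use_segmentation from get_network into both the training and evaluation loops. A condensed sketch of the resulting data flow, assuming the ModelForQABasic signature shown earlier; train_dataloader and ctx are stand-ins for the script's own variables.

# Condensed sketch of the pattern above (not a runnable training loop).
cfg, tokenizer, qa_net, use_segmentation = get_network(args.model_name, ctx_l,
                                                       args.classifier_dropout,
                                                       args.param_checkpoint,
                                                       args.backbone_path)
for sample in train_dataloader:                    # same idea inside eval_validation()
    tokens = sample.data.as_in_ctx(ctx)
    # RoBERTa/XLM-R backbones take no token_types, so the segment ids become None.
    segment_ids = sample.segment_ids.as_in_ctx(ctx) if use_segmentation else None
    valid_length = sample.valid_length.as_in_ctx(ctx)
    p_mask = sample.masks.as_in_ctx(ctx)
    start_scores, end_scores = qa_net(tokens, segment_ids, valid_length, p_mask)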
16 changes: 8 additions & 8 deletions src/gluonnlp/models/roberta.py
@@ -127,7 +127,7 @@ def __init__(self,
                  bias_initializer='zeros',
                  dtype='float32',
                  use_pooler=False,
-                 use_mlm=True,
+                 use_mlm=False,
                  untie_weight=False,
                  encoder_normalize_before=True,
                  output_all_encodings=False,
output_all_encodings=False,
Expand Down Expand Up @@ -242,9 +242,10 @@ def hybrid_forward(self, F, tokens, valid_length):
outputs = []
embedding = self.get_initial_embedding(F, tokens)

inner_states = self.encoder(embedding, valid_length)
outputs.append(inner_states)
contextual_embeddings = inner_states[-1]
contextual_embeddings = self.encoder(embedding, valid_length)
outputs.append(contextual_embeddings)
if self.output_all_encodings:
contextual_embeddings = contextual_embeddings[-1]

if self.use_pooler:
pooled_out = self.apply_pooling(contextual_embeddings)
@@ -305,7 +306,7 @@ def get_cfg(key=None):
     def from_cfg(cls,
                  cfg,
                  use_pooler=False,
-                 use_mlm=True,
+                 use_mlm=False,
                  untie_weight=False,
                  encoder_normalize_before=True,
                  output_all_encodings=False,
@@ -396,7 +397,7 @@ def hybrid_forward(self, F, x, valid_length):
             x, _ = layer(x, atten_mask)
             inner_states.append(x)
         if not self.output_all_encodings:
-            inner_states = [x]
+            inner_states = x
         return inner_states


 @use_np
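With this change the encoder returns either the final hidden state directly or, when output_all_encodings is on, the full list of per-layer states, instead of always wrapping the result in a one-element list. A plain-Python stand-in for the convention (illustrative only, not the repository's code):

# Stand-in for the return convention above; `layers` is any sequence of callables.
def encoder_forward(layers, x, output_all_encodings=False):
    inner_states = []
    for layer in layers:
        x = layer(x)
        inner_states.append(x)     # keep every layer's output
    if not output_all_encodings:
        return x                   # the last layer's output, as a single array
    return inner_states            # one entry per layer

Callers such as RobertaModel.hybrid_forward in the earlier hunk therefore only index with [-1] when all encodings were requested.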
@@ -456,8 +457,7 @@ def list_pretrained_roberta():

 def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base',
                            root: str = get_model_zoo_home_dir(),
-                           load_backbone: bool = True,
-                           load_mlm: bool = False) \
+                           load_backbone: bool = True) \
         -> Tuple[CN, HuggingFaceByteBPETokenizer, str]:
     """Get the pretrained RoBERTa weights
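With load_mlm gone, a hedged usage sketch of the simplified entry point; the import path, the meaning of the returned path string, and the loading flags are assumptions based on the annotations above rather than documented behaviour.

# Hedged sketch; names and return values beyond (cfg, tokenizer, path) are assumptions.
from gluonnlp.models.roberta import get_pretrained_roberta, RobertaModel

cfg, tokenizer, params_path = get_pretrained_roberta('fairseq_roberta_base')
model = RobertaModel.from_cfg(cfg)       # use_mlm now defaults to False
model.load_parameters(params_path, ignore_extra=True, allow_missing=True)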
