Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
revise
Browse files: browse the repository at this point in the history
  • Loading branch information
zheyuye committed Jun 29, 2020
1 parent b460bbe commit 1450f5c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
5 changes: 3 additions & 2 deletions scripts/pretraining/run_electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ def parse_args():
help='If set, both training and dev samples are generated on-the-fly '
'from raw texts instead of pre-processed npz files. ')
parser.add_argument("--short_seq_prob", type=float, default=0.05,
help="The probability of sampling sequences shorter than the max_seq_length.")
help='The probability of sampling sequences '
'shorter than the max_seq_length.')
parser.add_argument("--cached_file_path", default=None,
help="Directory for saving preprocessed features")
help='Directory for saving preprocessed features')
parser.add_argument('--circle_length', type=int, default=2,
help='Number of files to be read for a single GPU at the same time.')
parser.add_argument('--repeat', type=int, default=8,
Expand Down
25 changes: 16 additions & 9 deletions scripts/question_answering/run_squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@


def parse_args():
parser = argparse.ArgumentParser(description='Question Answering example.'
' We fine-tune the pretrained model on SQuAD dataset.')
parser = argparse.ArgumentParser(
description='Question Answering example. '
'We fine-tune the pretrained model on SQuAD dataset.')
parser.add_argument('--model_name', type=str, default='google_albert_base_v2',
help='Name of the pretrained model.')
parser.add_argument('--do_train', action='store_true',
Expand Down Expand Up @@ -80,7 +81,8 @@ def parse_args():
help='warmup steps. Note that either warmup_steps or warmup_ratio is set.')
parser.add_argument('--wd', type=float, default=0.01, help='weight decay')
parser.add_argument('--layerwise_decay', type=float, default=-1, help='Layer-wise lr decay')
parser.add_argument('--untunable_depth', type=float, default=-1, help='Depth of untunable parameters')
parser.add_argument('--untunable_depth', type=float, default=-1,
help='Depth of untunable parameters')
parser.add_argument('--classifier_dropout', type=float, default=0.1,
help='dropout of classifier')
# Data pre/post processing
Expand Down Expand Up @@ -108,15 +110,16 @@ def parse_args():
'to a start position')
parser.add_argument('--n_best_size', type=int, default=20, help='Top N results written to file')
parser.add_argument('--max_answer_length', type=int, default=30,
help='The maximum length of an answer that can be generated. This is needed '
'because the start and end predictions are not conditioned on one another.'
' default is 30')
help='The maximum length of an answer that can be generated. This is '
'needed because the start and end predictions are not conditioned '
'on one another. default is 30')
parser.add_argument('--param_checkpoint', type=str, default=None,
help='The parameter checkpoint for evaluating the model')
parser.add_argument('--backbone_path', type=str, default=None,
help='The parameter checkpoint of backbone model')
parser.add_argument('--all_evaluate', action='store_true',
help='Whether to evaluate all intermediate checkpoints instead of only last one')
help='Whether to evaluate all intermediate checkpoints '
'instead of only last one')
parser.add_argument('--max_saved_ckpt', type=int, default=10,
help='The maximum number of saved checkpoints')
args = parser.parse_args()
Expand Down Expand Up @@ -326,7 +329,11 @@ def get_network(model_name,
backbone_params_path = backbone_path if backbone_path else download_params_path
if checkpoint_path is None:
# TODO(zheyuye): be careful with allow_missing, which is used to pass the mlm parameters in roberta
backbone.load_parameters(backbone_params_path, ignore_extra=True, allow_missing=True, ctx=ctx_l)
backbone.load_parameters(
backbone_params_path,
ignore_extra=True,
allow_missing=True,
ctx=ctx_l)
num_params, num_fixed_params = count_parameters(backbone.collect_params())
logging.info(
'Loading Backbone Model from {}, with total/fixd parameters={}/{}'.format(
Expand Down Expand Up @@ -710,7 +717,7 @@ def predict_extended(original_feature,
# TODO investigate the impact
token_max_context_score[i, j] = min(j - chunk_start,
chunk_start + chunk_length - 1 - j) \
+ 0.01 * chunk_length
+ 0.01 * chunk_length
token_max_chunk_id = token_max_context_score.argmax(axis=0)

for chunk_id, (result, chunk_feature) in enumerate(zip(results, chunked_features)):
Expand Down

0 comments on commit 1450f5c

Please sign in to comment.