
Commit

update
zheyuye committed Jul 9, 2020
1 parent 838be2a commit 31cb953
Showing 2 changed files with 86 additions and 71 deletions.
40 changes: 29 additions & 11 deletions scripts/question_answering/eval_utils.py
@@ -179,12 +179,30 @@ def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_h
    main_eval['best_f1_thresh'] = f1_thresh


-def get_revised_results(preds, na_probs, thresh):
-    results = copy.deepcopy(preds)
+def revise_unanswerable(preds, na_probs, na_prob_thresh):
+    """
+    Revise the prediction results, returning an empty string for each unanswerable
+    question whose unanswerable probability is above the threshold.
+
+    Parameters
+    ----------
+    preds: dict
+        A dictionary of full span predictions
+    na_probs: dict
+        A dictionary of unanswerable probabilities
+    na_prob_thresh: float
+        Threshold of the unanswerable probability
+
+    Returns
+    -------
+    revised: dict
+        A dictionary of revised predictions
+    """
+    revised = copy.deepcopy(preds)
    for q_id in na_probs.keys():
-        if na_probs[q_id] > thresh:
-            results[q_id] = ""
-    return results
+        if na_probs[q_id] > na_prob_thresh:
+            revised[q_id] = ""
+    return revised


def squad_eval(data_file, preds, na_probs, na_prob_thresh=0.0, revise=False):
@@ -197,18 +215,18 @@ def squad_eval(data_file, preds, na_probs, na_prob_thresh=0.0, revise=False):
    preds
        predictions dictionary
    na_probs
-        probabilities dict of unanswerable
+        probabilities dictionary of unanswerable
    na_prob_thresh
        threshold of unanswerable
    revise
        Whether to get the final predictions with impossible answers replaced
        with null string ''

    Returns
    -------
    out_eval
        A dictionary of output results
    (preds_out)
        A dictionary of final predictions
    """
    if isinstance(data_file, str):
        with open(data_file) as f:
@@ -243,7 +261,7 @@ def squad_eval(data_file, preds, na_probs, na_prob_thresh=0.0, revise=False):
    if revise:
        thresh = (out_eval['best_exact_thresh'] +
                  out_eval['best_f1_thresh']) * 0.5
-        preds_out = get_revised_results(preds, na_probs, thresh)
+        preds_out = revise_unanswerable(preds, na_probs, thresh)
        return out_eval, preds_out
    else:
        return out_eval, preds
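For context, a hedged sketch of how squad_eval might be invoked with the revise flag; the file name and prediction dictionaries below are placeholders, and in the real script preds and na_probs come from the model's inference pass in run_squad.py:

from eval_utils import squad_eval  # assumes scripts/question_answering is on the path

# Placeholder inputs: in practice these cover every question id in the dev file.
preds = {'some-question-id': 'Denver Broncos'}
na_probs = {'some-question-id': 0.03}

out_eval, preds_out = squad_eval('dev-v2.0.json', preds, na_probs, revise=True)
# With revise=True, preds_out maps to '' any question whose no-answer probability
# exceeds the averaged best_exact/best_f1 threshold tuned above.
print(out_eval['best_exact_thresh'], out_eval['best_f1_thresh'])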
117 changes: 57 additions & 60 deletions scripts/question_answering/run_squad.py
@@ -22,7 +22,7 @@
from squad_utils import SquadFeature, get_squad_examples, convert_squad_example_to_feature
from gluonnlp.models import get_backbone
from gluonnlp.utils.misc import repeat, grouper, set_seed, init_comm, \
-    parse_ctx, logging_config, count_parameters
+    logging_config, count_parameters, parse_ctx
from gluonnlp.initializer import TruncNorm
from gluonnlp.data.sampler import SplitSampler
from gluonnlp.utils.parameter import grad_global_norm, clip_grad_global_norm
@@ -305,6 +305,50 @@ def get_train(self, features, skip_unreliable=True):
return train_dataset, num_token_answer_mismatch, num_unreliable


+def get_squad_features(args, tokenizer, segment):
+    """
+    Get processed data features of SQuAD examples.
+
+    Parameters
+    ----------
+    args : argparse.Namespace
+    tokenizer:
+        Tokenizer instance
+    segment: str
+        train or dev
+
+    Returns
+    -------
+    data_features
+        The list of processed data features
+    """
+    data_cache_path = os.path.join(CACHE_PATH,
+                                   '{}_{}_squad_{}.ndjson'.format(segment, args.model_name,
+                                                                  args.version))
+    is_training = (segment == 'train')
+    if os.path.exists(data_cache_path) and not args.overwrite_cache:
+        data_features = []
+        with open(data_cache_path, 'r') as f:
+            for line in f:
+                data_features.append(SquadFeature.from_json(line))
+        logging.info('Found cached data features, load from {}'.format(data_cache_path))
+    else:
+        data_examples = get_squad_examples(args.data_dir, segment=segment, version=args.version)
+        start = time.time()
+        num_process = min(cpu_count(), 8)
+        logging.info('Tokenize Data:')
+        with Pool(num_process) as pool:
+            data_features = pool.map(functools.partial(convert_squad_example_to_feature,
+                                                       tokenizer=tokenizer,
+                                                       is_training=is_training), data_examples)
+        logging.info('Done! Time spent:{:.2f} seconds'.format(time.time() - start))
+        with open(data_cache_path, 'w') as f:
+            for feature in data_features:
+                f.write(feature.to_json() + '\n')
+
+    return data_features
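Both call sites later in this diff go through the new helper; in short, with args and tokenizer as set up inside train() and evaluate():

# Sketch only: args normally comes from the script's argument parser,
# and tokenizer from get_network().
train_features = get_squad_features(args, tokenizer, segment='train')  # used in train()
dev_features = get_squad_features(args, tokenizer, segment='dev')      # used in evaluate()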


def get_network(model_name,
ctx_l,
dropout=0.1,
@@ -399,41 +443,14 @@ def untune_params(model, untunable_depth, not_included=[]):
def train(args):
store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
args.comm_backend, args.gpus)

cfg, tokenizer, qa_net, use_segmentation = \
get_network(args.model_name, ctx_l,
args.classifier_dropout,
args.param_checkpoint,
args.backbone_path)
-    # Load the data
-    train_examples = get_squad_examples(args.data_dir, segment='train', version=args.version)
-    logging.info('Load data from {}, Version={}'.format(args.data_dir, args.version))
-    num_process = min(cpu_count(), 8)
-    train_cache_path = os.path.join(
-        CACHE_PATH, 'train_{}_squad_{}.ndjson'.format(
-            args.model_name, args.version))
-    if os.path.exists(train_cache_path) and not args.overwrite_cache:
-        train_features = []
-        with open(train_cache_path, 'r') as f:
-            for line in f:
-                train_features.append(SquadFeature.from_json(line))
-        logging.info('Found cached training features, load from {}'.format(train_cache_path))
-
-    else:
-        start = time.time()
-        logging.info('Tokenize Training Data:')
-        with Pool(num_process) as pool:
-            train_features = pool.map(
-                functools.partial(
-                    convert_squad_example_to_feature,
-                    tokenizer=tokenizer,
-                    is_training=True),
-                train_examples)
-        logging.info('Done! Time spent:{:.2f} seconds'.format(time.time() - start))
-        with open(train_cache_path, 'w') as f:
-            for feature in train_features:
-                f.write(feature.to_json() + '\n')

+    logging.info('Prepare training data')
+    train_features = get_squad_features(args, tokenizer, segment='train')
dataset_processor = SquadDatasetProcessor(tokenizer=tokenizer,
doc_stride=args.doc_stride,
max_seq_length=args.max_seq_length,
@@ -459,8 +476,7 @@ def train(args):
train_dataset,
batchify_fn=dataset_processor.BatchifyFunction,
batch_size=args.batch_size,
-num_workers=4,
-shuffle=True
+num_workers=0,
sampler=sampler)
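Presumably shuffle=True is dropped because Gluon's DataLoader does not accept shuffle together with an explicit sampler (shuffling is the sampler's job). A rough, self-contained sketch of the pattern, with a toy dataset and an assumed SplitSampler signature:

import mxnet as mx
from gluonnlp.data.sampler import SplitSampler

dataset = mx.gluon.data.SimpleDataset(list(range(32)))           # stand-in for the chunked features
sampler = SplitSampler(len(dataset), num_parts=1, part_index=0)  # shuffles within its shard
loader = mx.gluon.data.DataLoader(dataset,
                                  batch_size=8,
                                  num_workers=0,  # keep batching in the main process
                                  sampler=sampler)
for batch in loader:
    print(batch.shape)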
# Freeze parameters
if 'electra' in args.model_name:
@@ -612,8 +628,8 @@ def train(args):
param_dict.zero_grad()

# saving
-if local_rank == 0 and is_master_node and (
-        step_num + 1) % save_interval == 0 or (step_num + 1) >= num_train_steps:
+if local_rank == 0 and (step_num + 1) % save_interval == 0 or (
+        step_num + 1) >= num_train_steps:
version_prefix = 'squad' + args.version
ckpt_name = '{}_{}_{}.params'.format(args.model_name,
version_prefix,
Expand Down Expand Up @@ -650,9 +666,10 @@ def train(args):
num_samples_per_update = 0

if (step_num + 1) >= num_train_steps:
+toc = time.time()
logging.info(
'Finish training step: {} within {} hours'.format(
-step_num + 1, toc - global_tic))
+step_num + 1, (toc - global_tic) / 3600))
break

return params_saved
Expand Down Expand Up @@ -790,31 +807,13 @@ def predict_extended(original_feature,

def evaluate(args, last=True):
ctx_l = parse_ctx(args.gpus)
logging.info('Starting inference without horovod')

cfg, tokenizer, qa_net, use_segmentation = get_network(
args.model_name, ctx_l, args.classifier_dropout)
-    # Prepare dev set
-    dev_cache_path = os.path.join(CACHE_PATH,
-                                  'dev_{}_squad_{}.ndjson'.format(args.model_name,
-                                                                  args.version))
-    if os.path.exists(dev_cache_path) and not args.overwrite_cache:
-        dev_features = []
-        with open(dev_cache_path, 'r') as f:
-            for line in f:
-                dev_features.append(SquadFeature.from_json(line))
-        logging.info('Found cached dev features, load from {}'.format(dev_cache_path))
-    else:
-        dev_examples = get_squad_examples(args.data_dir, segment='dev', version=args.version)
-        start = time.time()
-        num_process = min(cpu_count(), 8)
-        logging.info('Tokenize Dev Data:')
-        with Pool(num_process) as pool:
-            dev_features = pool.map(functools.partial(convert_squad_example_to_feature,
-                                                      tokenizer=tokenizer,
-                                                      is_training=False), dev_examples)
-        logging.info('Done! Time spent:{:.2f} seconds'.format(time.time() - start))
-        with open(dev_cache_path, 'w') as f:
-            for feature in dev_features:
-                f.write(feature.to_json() + '\n')

+    logging.info('Prepare dev data')
+    dev_features = get_squad_features(args, tokenizer, segment='dev')
dev_data_path = os.path.join(args.data_dir, 'dev-v{}.json'.format(args.version))
dataset_processor = SquadDatasetProcessor(tokenizer=tokenizer,
doc_stride=args.doc_stride,
@@ -831,8 +830,6 @@ def eval_validation(ckpt_name, best_eval):
"""
Model inference during validation or final evaluation.
"""
-ctx_l = parse_ctx(args.gpus)
-# We process all the chunk features and also
dev_dataloader = mx.gluon.data.DataLoader(
dev_all_chunk_features,
batchify_fn=dataset_processor.BatchifyFunction,
