Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[examples/s2s] add test set predictions #10085

Merged
merged 3 commits into from
Feb 9, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
65 changes: 62 additions & 3 deletions examples/seq2seq/run_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,22 @@ class DataTrainingArguments:
"value if set."
},
)
max_test_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
"value if set."
},
)
source_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
target_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
eval_beams: Optional[int] = field(default=None, metadata={"help": "Number of beams to use for evaluation."})
num_beams: Optional[int] = field(
default=None,
metadata={
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
"which is used during ``evaluate`` and ``predict``."
},
)
ignore_pad_token_for_loss: bool = field(
default=True,
metadata={
Expand Down Expand Up @@ -336,8 +349,13 @@ def main():
# We need to tokenize inputs and targets.
if training_args.do_train:
column_names = datasets["train"].column_names
else:
elif training_args.do_eval:
column_names = datasets["validation"].column_names
elif training_args.do_predict:
column_names = datasets["test"].column_names
else:
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
return

# For translation we set the codes of our source and target languages (only useful for mBART, the others will
# ignore those attributes).
Expand Down Expand Up @@ -440,6 +458,19 @@ def preprocess_function(examples):
load_from_cache_file=not data_args.overwrite_cache,
)

if training_args.do_predict:
max_target_length = data_args.val_max_target_length
test_dataset = datasets["test"]
if data_args.max_test_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples))
test_dataset = test_dataset.map(
preprocess_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)

# Data collator
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
if data_args.pad_to_max_length:
Expand Down Expand Up @@ -523,7 +554,7 @@ def compute_metrics(eval_preds):
if training_args.do_eval:
logger.info("*** Evaluate ***")

results = trainer.evaluate()
results = trainer.evaluate(max_length=data_args.val_max_target_length, num_beams=data_args.num_beams)

output_eval_file = os.path.join(training_args.output_dir, "eval_results_seq2seq.txt")
if trainer.is_world_process_zero():
Expand All @@ -533,6 +564,34 @@ def compute_metrics(eval_preds):
logger.info(f" {key} = {value}")
writer.write(f"{key} = {value}\n")

if training_args.do_predict:
logger.info("*** Test ***")

test_results = trainer.predict(
test_dataset,
metric_key_prefix="test",
max_length=data_args.val_max_target_length,
num_beams=data_args.num_beams,
)
test_metrics = test_results.metrics

output_test_result_file = os.path.join(training_args.output_dir, "test_results_seq2seq.txt")
if trainer.is_world_process_zero():
with open(output_test_result_file, "w") as writer:
logger.info("***** Test results *****")
for key, value in sorted(test_metrics.items()):
logger.info(f" {key} = {value}")
writer.write(f"{key} = {value}\n")

if training_args.predict_with_generate:
test_preds = tokenizer.batch_decode(
test_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
test_preds = [pred.strip() for pred in test_preds]
output_test_preds_file = os.path.join(training_args.output_dir, "test_preds_seq2seq.txt")
with open(output_test_preds_file, "w") as writer:
writer.write("\n".join(test_preds))

return results


Expand Down