[s2s] run_eval.py QOL improvements and cleanup #6746

Merged
merged 5 commits on Aug 26, 2020
58 changes: 38 additions & 20 deletions examples/seq2seq/run_eval.py
@@ -1,17 +1,23 @@
import argparse
import json
import time
import warnings
from logging import getLogger
from pathlib import Path
from typing import Dict, List

import torch
from tqdm import tqdm

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


logger = getLogger(__name__)

try:
from .utils import calculate_bleu, calculate_rouge, trim_batch, use_task_specific_params
from .utils import calculate_bleu, calculate_rouge, use_task_specific_params
except ImportError:
from utils import calculate_bleu, calculate_rouge, trim_batch, use_task_specific_params
from utils import calculate_bleu, calculate_rouge, use_task_specific_params

DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

@@ -23,44 +29,47 @@ def chunks(lst, n):


def generate_summaries_or_translations(
examples: list,
examples: List[str],
out_file: str,
model_name: str,
batch_size: int = 8,
device: str = DEFAULT_DEVICE,
fp16=False,
task="summarization",
decoder_start_token_id=None,
**gen_kwargs,
) -> None:
**generate_kwargs,
) -> Dict:
"""Save model.generate results to <out_file>, and return how long it took."""
fout = Path(out_file).open("w", encoding="utf-8")
model_name = str(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
if fp16:
model = model.half()
if decoder_start_token_id is None:
decoder_start_token_id = gen_kwargs.pop("decoder_start_token_id", None)

tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info(f"Inferred tokenizer type: {tokenizer.__class__}") # if this is wrong, check config.model_type.

# update config with summarization specific params
start_time = time.time()
# update config with task specific params
use_task_specific_params(model, task)

for batch in tqdm(list(chunks(examples, batch_size))):
for examples_chunk in tqdm(list(chunks(examples, batch_size))):
if "t5" in model_name:
batch = [model.config.prefix + text for text in batch]
batch = tokenizer(batch, return_tensors="pt", truncation=True, padding="max_length").to(device)
input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
examples_chunk = [model.config.prefix + text for text in examples_chunk]
batch = tokenizer(examples_chunk, return_tensors="pt", truncation=True, padding="longest").to(device)
summaries = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
input_ids=batch.input_ids,
attention_mask=batch.attention_mask,
decoder_start_token_id=decoder_start_token_id,
**gen_kwargs,
**generate_kwargs,
)
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
for hypothesis in dec:
fout.write(hypothesis + "\n")
fout.flush()
fout.close()
runtime = time.time() - start_time
n_obs = len(examples)
return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))


def run_generate():
@@ -70,7 +79,13 @@ def run_generate():
parser.add_argument("save_path", type=str, help="where to save summaries")

parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
parser.add_argument("--score_path", type=str, required=False, help="where to save the rouge score in json format")
parser.add_argument(
"--score_path",
type=str,
required=False,
default="metrics.json",
help="where to save the rouge score in json format",
)
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
parser.add_argument("--task", type=str, default="summarization", help="typically translation or summarization")
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
@@ -79,7 +94,7 @@ def run_generate():
type=int,
default=None,
required=False,
help="decoder_start_token_id (otherwise will look at config)",
help="Defaults to using config",
)
parser.add_argument(
"--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all."
@@ -90,7 +105,9 @@ def run_generate():
if args.n_obs > 0:
examples = examples[: args.n_obs]
Path(args.save_path).parent.mkdir(exist_ok=True)
generate_summaries_or_translations(
if args.reference_path is None and Path(args.score_path).exists():
warnings.warn(f"score_path {args.score_path} will be overwritten unless you type ctrl-c.")
runtime_metrics = generate_summaries_or_translations(
examples,
args.save_path,
args.model_name,
@@ -107,9 +124,10 @@ def run_generate():
output_lns = [x.rstrip() for x in open(args.save_path).readlines()]
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)]
scores: dict = score_fn(output_lns, reference_lns)
scores.update(runtime_metrics)
print(scores)
if args.score_path is not None:
json.dump(scores, open(args.score_path, "w+"))
json.dump(scores, open(args.score_path, "w"))
return scores
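
The replacement of trim_batch with padding="longest" earlier in this diff means the tokenizer pads each batch only to its longest member, so no manual trimming of pad tokens is needed afterwards. A minimal sketch of that behaviour, using an example checkpoint name that is not taken from this PR:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")  # example checkpoint
batch = tok(
    ["a short source text", "a noticeably longer source text that sets the width of this batch"],
    return_tensors="pt",
    truncation=True,
    padding="longest",  # pad to the longest sequence in this batch, not to the model max length
)
print(batch.input_ids.shape)  # second dimension = tokenized length of the longest example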


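Per the new signature shown above, generate_summaries_or_translations now returns a dict of runtime metrics instead of None. A rough sketch of calling it directly; the import assumes examples/seq2seq is the working directory, and the paths and checkpoint are illustrative only:

from run_eval import generate_summaries_or_translations

examples = ["New York (CNN) -- a short article to summarize."]
metrics = generate_summaries_or_translations(
    examples,
    out_file="hypotheses.txt",             # hypothetical output file, one generation per line
    model_name="facebook/bart-large-cnn",  # example summarization checkpoint
    batch_size=8,
    task="summarization",
)
print(metrics)  # per the diff: {"n_obs": ..., "runtime": ..., "seconds_per_sample": ...}
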
15 changes: 13 additions & 2 deletions examples/seq2seq/test_seq2seq_examples.py
@@ -252,13 +252,24 @@ def _test_distiller_cli(self, updates, check_contents=True):


@pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
def test_run_eval_bart(model):
def test_run_eval(model):
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
output_file_name = input_file_name.parent / "utest_output.txt"
assert not output_file_name.exists()
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
_dump_articles(input_file_name, articles)
testargs = ["run_eval.py", model, str(input_file_name), str(output_file_name)] # TODO: test score_path
score_path = str(Path(tempfile.mkdtemp()) / "scores.json")
task = "translation_en_to_de" if model == T5_TINY else "summarization"
testargs = [
"run_eval.py",
model,
str(input_file_name),
str(output_file_name),
"--score_path",
score_path,
"--task",
task,
]
with patch.object(sys, "argv", testargs):
run_generate()
assert Path(output_file_name).exists()
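
The updated test drives run_generate by patching sys.argv, and the same pattern works outside the test suite. A sketch under the assumption that examples/seq2seq is the working directory; the checkpoint and file names are illustrative, and a reference file is supplied so the scoring block shown in the run_eval.py diff has something to compare against:

import json
import sys
from unittest.mock import patch

from run_eval import run_generate

argv = [
    "run_eval.py",
    "facebook/bart-large-cnn",  # example checkpoint
    "test.source",              # one source text per line
    "hypotheses.txt",           # where generations are written
    "--reference_path", "test.target",
    "--score_path", "metrics.json",
    "--task", "summarization",
    "--bs", "8",
]
with patch.object(sys, "argv", argv):
    run_generate()

# metrics.json now holds ROUGE (or BLEU for translation tasks) plus the runtime metrics added in this PR.
print(json.load(open("metrics.json")))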