Skip to content

Commit

Permalink
[s2s] rougeLSum expects \n between sentences (#7410)
Browse files Browse the repository at this point in the history
Co-authored-by: Swetha Mandava <smandava@nvidia.com>
  • Loading branch information
sshleifer and Swetha Mandava committed Sep 27, 2020
1 parent eab5f59 commit 7296fea
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 14 deletions.
1 change: 1 addition & 0 deletions examples/requirements.txt
Expand Up @@ -11,6 +11,7 @@ git-python==1.0.3
faiss-cpu
streamlit
elasticsearch
nltk
pandas
datasets
fire
Expand Down
17 changes: 17 additions & 0 deletions examples/seq2seq/rouge_cli.py
@@ -0,0 +1,17 @@
import fire

from utils import calculate_rouge, save_json


def calculate_rouge_path(pred_path, tgt_path, save_path=None, **kwargs):
    """Compute ROUGE between two line-aligned text files.

    Args:
        pred_path: path of a file holding one prediction per line.
        tgt_path: path of a file holding one reference per line. Only the first
            ``len(predictions)`` lines are used.
        save_path: when not None, the metrics are also written there with ``save_json``.
        **kwargs: forwarded unchanged to ``calculate_rouge``.

    Returns:
        Whatever ``calculate_rouge`` returns for the two lists of stripped lines.
    """
    # Context managers close both handles deterministically instead of leaving that to the GC.
    with open(pred_path) as pred_file:
        pred_lns = [line.strip() for line in pred_file]
    with open(tgt_path) as tgt_file:
        tgt_lns = [line.strip() for line in tgt_file][: len(pred_lns)]
    metrics = calculate_rouge(pred_lns, tgt_lns, **kwargs)
    if save_path is not None:
        save_json(metrics, save_path)
    return metrics  # these print nicely


if __name__ == "__main__":
    # Hand the function to fire so that its arguments become the command line interface.
    fire.Fire(calculate_rouge_path)
3 changes: 2 additions & 1 deletion examples/seq2seq/run_eval_search.py
Expand Up @@ -7,13 +7,14 @@
from collections import OrderedDict

from run_eval import datetime_now, run_generate
from utils import ROUGE_KEYS


# A table of supported tasks and the list of scores in the order of importance to be sorted by.
# To add a new task, simply list the score names that `run_eval.run_generate()` returns
task_score_names = {
    "translation": ["bleu"],
    # Reuse the shared ROUGE_KEYS list rather than a hard-coded copy. A second "summarization"
    # key in the same literal would be silently overridden by the last one, so only one is kept.
    "summarization": ROUGE_KEYS,
}


Expand Down
21 changes: 21 additions & 0 deletions examples/seq2seq/sentence_splitter.py
@@ -0,0 +1,21 @@
import re


# Probe for nltk once at import time; NLTK_AVAILABLE records the outcome for later checks.
try:
    import nltk
except ImportError:  # ModuleNotFoundError is a subclass of ImportError, so it is covered too
    NLTK_AVAILABLE = False
else:
    NLTK_AVAILABLE = True
    try:
        nltk.download("punkt", quiet=True)
    except FileExistsError:  # multiprocessing race condition
        pass


def add_newline_to_end_of_each_sentence(x: str) -> str:
    """Return ``x`` with its sentences separated by newlines.

    The pegasus newline marker ``<n>`` is stripped first, then the text is split with
    ``nltk.sent_tokenize`` and the sentences are joined with a newline character.

    Args:
        x: the text (e.g. a summary) to split.

    Returns:
        The sentences of ``x`` joined by a newline character.

    Raises:
        AssertionError: if nltk could not be imported.
    """
    assert NLTK_AVAILABLE, "nltk must be installed to separate newlines betwee sentences. (pip install nltk)"
    # Strings are immutable: the result of the substitution must be rebound, otherwise the
    # marker is still present when the text reaches the sentence tokenizer.
    x = x.replace("<n>", "")  # remove pegasus newline char
    return "\n".join(nltk.sent_tokenize(x))
80 changes: 80 additions & 0 deletions examples/seq2seq/test_calculate_rouge.py
@@ -0,0 +1,80 @@
from collections import defaultdict
from pathlib import Path

import pandas as pd

from rouge_cli import calculate_rouge_path
from utils import calculate_rouge


# Three multi-sentence summaries passed as the predictions (first argument) to calculate_rouge
# by the tests in this file.
PRED = [
'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the final seconds on board Flight 9525. The Germanwings co-pilot says he had a "previous episode of severe depression" German airline confirms it knew of Andreas Lubitz\'s depression years before he took control.',
"The Palestinian Authority officially becomes the 123rd member of the International Criminal Court. The formal accession was marked with a ceremony at The Hague, in the Netherlands. The Palestinians signed the ICC's founding Rome Statute in January. Israel and the United States opposed the Palestinians' efforts to join the body.",
"Amnesty International releases its annual report on the death penalty. The report catalogs the use of state-sanctioned killing as a punitive measure across the globe. At least 607 people were executed around the world in 2014, compared to 778 in 2013. The U.S. remains one of the worst offenders for imposing capital punishment.",
]

# The summaries passed as targets (second argument) alongside PRED; entry i pairs with PRED[i].
TGT = [
'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports . Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz had informed his Lufthansa training school of an episode of severe depression, airline says .',
"Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June . Israel and the United States opposed the move, which could open the door to war crimes investigations against Israelis .",
"Amnesty's annual death penalty report catalogs encouraging signs, but setbacks in numbers of those sentenced to death . Organization claims that governments around the world are using the threat of terrorism to advance executions . The number of executions worldwide has gone down by almost 22% compared with 2013, but death sentences up by 28% .",
]


def test_disaggregated_scores_are_determinstic():
    """The mean per-example rouge2 f-measure must not depend on which other keys were requested."""
    two_keys = calculate_rouge(PRED, TGT, bootstrap_aggregation=False, rouge_keys=["rouge2", "rougeL"])
    assert isinstance(two_keys, defaultdict)
    only_r2 = calculate_rouge(PRED, TGT, bootstrap_aggregation=False, rouge_keys=["rouge2"])

    mean_f_two_keys = pd.DataFrame(two_keys["rouge2"]).fmeasure.mean()
    mean_f_only_r2 = pd.DataFrame(only_r2["rouge2"]).fmeasure.mean()
    assert mean_f_two_keys == mean_f_only_r2


def test_newline_cnn_improvement():
    """rougeLsum on the multi-sentence fixtures must be higher with newline separation than without."""
    key = "rougeLsum"
    with_sep = calculate_rouge(PRED, TGT, newline_sep=True, rouge_keys=[key])
    without_sep = calculate_rouge(PRED, TGT, newline_sep=False, rouge_keys=[key])
    assert with_sep[key] > without_sep[key]


def test_newline_irrelevant_for_other_metrics():
    """rouge1, rouge2 and rougeL must come out identical with and without newline separation."""
    keys = ["rouge1", "rouge2", "rougeL"]
    with_sep, without_sep = (
        calculate_rouge(PRED, TGT, newline_sep=flag, rouge_keys=keys) for flag in (True, False)
    )
    assert with_sep == without_sep


def test_single_sent_scores_dont_depend_on_newline_sep():
    """With the default rouge keys, these short example pairs must score the same either way."""
    single_sentence_preds = [
        "Her older sister, Margot Frank, died in 1945, a month earlier than previously thought.",
        'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports .',
    ]
    single_sentence_tgts = [
        "Margot Frank, died in 1945, a month earlier than previously thought.",
        'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the final seconds on board Flight 9525.',
    ]
    with_sep = calculate_rouge(single_sentence_preds, single_sentence_tgts, newline_sep=True)
    without_sep = calculate_rouge(single_sentence_preds, single_sentence_tgts, newline_sep=False)
    assert with_sep == without_sep


def test_pegasus_newline():
    """A prediction containing the pegasus ``<n>`` marker must score higher on rougeLsum with the default
    newline handling than with ``newline_sep=False``."""
    key = "rougeLsum"
    pegasus_pred = [
        """" "a person who has such a video needs to immediately give it to the investigators," prosecutor says .<n> "it is a very disturbing scene," editor-in-chief of bild online tells "erin burnett: outfront" """
    ]
    reference = [
        """ Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports . Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz had informed his Lufthansa training school of an episode of severe depression, airline says ."""
    ]

    score_without_sep = calculate_rouge(pegasus_pred, reference, rouge_keys=[key], newline_sep=False)[key]
    score_default = calculate_rouge(pegasus_pred, reference, rouge_keys=[key])[key]
    assert score_default > score_without_sep


def test_rouge_cli():
    """calculate_rouge_path returns a dict when aggregating and a defaultdict when aggregation is off."""
    data_dir = Path("examples/seq2seq/test_data/wmt_en_ro")
    source_file = data_dir / "test.source"
    target_file = data_dir / "test.target"

    aggregated = calculate_rouge_path(source_file, target_file)
    assert isinstance(aggregated, dict)

    per_example = calculate_rouge_path(source_file, target_file, bootstrap_aggregation=False)
    assert isinstance(per_example, defaultdict)
4 changes: 2 additions & 2 deletions examples/seq2seq/test_seq2seq_examples.py
Expand Up @@ -20,7 +20,7 @@
from transformers import AutoConfig, AutoModelForSeq2SeqLM
from transformers.hf_api import HfApi
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
from utils import label_smoothed_nll_loss, lmap, load_json
from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json


logging.basicConfig(level=logging.DEBUG)
Expand Down Expand Up @@ -365,7 +365,7 @@ def test_run_eval_search(model):
if "translation" in task:
expected_strings.append("bleu")
else:
expected_strings.extend(["rouge1", "rouge2", "rougeL"])
expected_strings.extend(ROUGE_KEYS)
for w in expected_strings:
assert w in cs.out
for w in un_expected_strings:
Expand Down
64 changes: 53 additions & 11 deletions examples/seq2seq/utils.py
Expand Up @@ -18,6 +18,7 @@
from torch import nn
from torch.utils.data import Dataset, Sampler

from sentence_splitter import add_newline_to_end_of_each_sentence
from transformers import BartTokenizer
from transformers.file_utils import cached_property

Expand Down Expand Up @@ -378,19 +379,63 @@ def get_git_info():
return repo_infos


# Metrics computed by default. rougeLsum is included; it expects "\n" separated sentences
# within a summary.
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"]


def extract_rouge_mid_statistics(dct):
    """Reduce aggregated ROUGE scores to their rounded ``mid`` statistics.

    Args:
        dct: mapping of metric name to an object whose ``mid`` attribute exposes ``precision``,
            ``recall`` and ``fmeasure`` (presumably the result of ``BootstrapAggregator.aggregate()``,
            TODO confirm against the caller).

    Returns:
        Dict mapping each metric name to ``{"precision": ..., "recall": ..., "fmeasure": ...}``
        with every value rounded to 4 decimal places (values are not rescaled).
    """
    stats = ("precision", "recall", "fmeasure")
    return {
        metric: {stat: round(getattr(score.mid, stat), 4) for stat in stats} for metric, score in dct.items()
    }


def calculate_rouge(
    pred_lns: List[str],
    tgt_lns: List[str],
    use_stemmer=True,
    rouge_keys=ROUGE_KEYS,
    return_precision_and_recall=False,
    bootstrap_aggregation=True,
    newline_sep=True,
) -> Dict:
    """Calculate rouge using rouge_scorer package.

    Args:
        pred_lns: list of summaries generated by model
        tgt_lns: list of groundtruth summaries (e.g. contents of val.target)
        use_stemmer: Bool indicating whether Porter stemmer should be used to
            strip word suffixes to improve matching.
        rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum
        return_precision_and_recall: (False) whether to also return precision and recall. Only consulted
            when ``bootstrap_aggregation`` is True.
        bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True, if False
            this function returns a ``collections.defaultdict[metric: list of values for each observation]``
        newline_sep: (default=True) whether to add newline between sentences. This is essential for calculating
            rougeLsum on multi sentence summaries (CNN/DM dataset).

    Returns:
        With ``bootstrap_aggregation=True``: ``Dict[metric: mid fmeasure * 100, rounded to 4 decimals]``, or, when
        ``return_precision_and_recall`` is set, the dict built by ``extract_rouge_mid_statistics`` (precision,
        recall and fmeasure rounded to 4 decimals, not multiplied by 100).
        With ``bootstrap_aggregation=False``: the aggregator's ``defaultdict(list)`` of per-example scores, keyed
        by ``rouge_keys``.
    """
    scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()
    # zip stops at the shorter list, so unmatched trailing predictions/targets are not scored.
    for tgt, pred in zip(tgt_lns, pred_lns):
        # rougeLsum expects "\n" separated sentences within a summary
        if newline_sep:
            pred = add_newline_to_end_of_each_sentence(pred)
            tgt = add_newline_to_end_of_each_sentence(tgt)
        # The target is passed first and the prediction second, matching the order the
        # original call resolved to.
        scores = scorer.score(tgt, pred)
        aggregator.add_scores(scores)

    if not bootstrap_aggregation:
        return aggregator._scores  # here we return defaultdict(list)

    result = aggregator.aggregate()
    if return_precision_and_recall:
        return extract_rouge_mid_statistics(result)  # here we return dict
    return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}


# Utilities for freezing parameters and checking whether they are frozen
Expand Down Expand Up @@ -423,9 +468,6 @@ def assert_not_all_frozen(model):
assert any(model_grads), f"none of {npars} weights require grad"


# CLI Parsing utils


def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
"""
Parse an argv list of unspecified command line args to a dict.
Expand Down

0 comments on commit 7296fea

Please sign in to comment.