Revert "[s2s] rougeLSum expects \n between sentences (huggingface#7410)"
This reverts commit cb09fa6.
fabiocapsouza committed Nov 15, 2020
1 parent d216f1d commit a781224
Showing 7 changed files with 14 additions and 176 deletions.
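Background for the change being reverted (illustrative, not part of the commit): the rougeLsum metric in Google's rouge_score package only splits a summary into sentences at "\n" characters, so a multi-sentence summary must be newline-separated before the summary-level LCS is computed per sentence. A minimal sketch of the effect, with invented strings:

# Sketch only -- not part of this diff; the strings are invented.
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rougeLsum"], use_stemmer=True)

tgt = "the cat sat on the mat . the dog ran away ."
pred = "the dog ran away . the cat sat on the mat ."

# Flat strings: rougeLsum treats each side as one long sentence.
flat = scorer.score(tgt, pred)["rougeLsum"].fmeasure

# Newline-separated strings: rougeLsum computes a union LCS per sentence,
# which is what the reverted commit arranged via sentence splitting.
split = scorer.score(tgt.replace(" . ", " .\n"), pred.replace(" . ", " .\n"))["rougeLsum"].fmeasure

# The newline-separated pair scores higher here, since reordered
# sentences still match sentence-by-sentence.
print(flat, split)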
1 change: 0 additions & 1 deletion examples/requirements.txt
@@ -11,7 +11,6 @@ git-python==1.0.3
 faiss-cpu
 streamlit
 elasticsearch
-nltk
 pandas
 datasets
 fire
17 changes: 0 additions & 17 deletions examples/seq2seq/rouge_cli.py

This file was deleted.
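For reference, the 17 deleted lines exposed calculate_rouge on the command line. A plausible sketch of such a CLI, assuming the python-fire pattern used elsewhere in examples/seq2seq (fire appears in the requirements above; the deleted code itself is not shown on this page):

# Hypothetical reconstruction -- the actual deleted file is not shown here.
import fire

from utils import calculate_rouge


def calculate_rouge_path(pred_path, tgt_path, **kwargs):
    """Score line-separated predictions in pred_path against tgt_path."""
    pred_lns = [x.strip() for x in open(pred_path).readlines()]
    tgt_lns = [x.strip() for x in open(tgt_path).readlines()][: len(pred_lns)]
    return calculate_rouge(pred_lns, tgt_lns, **kwargs)


if __name__ == "__main__":
    fire.Fire(calculate_rouge_path)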

3 changes: 1 addition & 2 deletions examples/seq2seq/run_eval_search.py
@@ -7,14 +7,13 @@
 from collections import OrderedDict
 
 from run_eval import datetime_now, run_generate
-from utils import ROUGE_KEYS
 
 
 # A table of supported tasks and the list of scores in the order of importance to be sorted by.
 # To add a new task, simply list the score names that `run_eval.run_generate()` returns
 task_score_names = {
     "translation": ["bleu"],
-    "summarization": ROUGE_KEYS,
+    "summarization": ["rouge1", "rouge2", "rougeL"],
 }
21 changes: 0 additions & 21 deletions examples/seq2seq/sentence_splitter.py

This file was deleted.
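This file provided the add_newline_to_end_of_each_sentence helper whose import is dropped from utils.py further down. A rough sketch of such a helper, assuming nltk's punkt sentence tokenizer (nltk is likewise removed from requirements.txt above; the deleted code itself is not shown on this page):

# Hypothetical reconstruction -- the actual deleted file is not shown here.
import re

import nltk

nltk.download("punkt", quiet=True)  # sent_tokenize requires the punkt model


def add_newline_to_end_of_each_sentence(x: str) -> str:
    """Re-join a summary with a newline after each sentence, as rougeLsum expects."""
    x = re.sub("<n>", "", x)  # assumed cleanup of Pegasus's sentence-break token
    return "\n".join(nltk.sent_tokenize(x))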

80 changes: 0 additions & 80 deletions examples/seq2seq/test_calculate_rouge.py

This file was deleted.

4 changes: 2 additions & 2 deletions examples/seq2seq/test_seq2seq_examples.py
@@ -20,7 +20,7 @@
 from transformers import AutoConfig, AutoModelForSeq2SeqLM
 from transformers.hf_api import HfApi
 from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
-from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json
+from utils import label_smoothed_nll_loss, lmap, load_json
 
 
 logging.basicConfig(level=logging.DEBUG)
@@ -365,7 +365,7 @@ def test_run_eval_search(model):
         if "translation" in task:
             expected_strings.append("bleu")
         else:
-            expected_strings.extend(ROUGE_KEYS)
+            expected_strings.extend(["rouge1", "rouge2", "rougeL"])
         for w in expected_strings:
             assert w in cs.out
         for w in un_expected_strings:
64 changes: 11 additions & 53 deletions examples/seq2seq/utils.py
@@ -18,7 +18,6 @@
 from torch import nn
 from torch.utils.data import Dataset, Sampler
 
-from sentence_splitter import add_newline_to_end_of_each_sentence
 from transformers import BartTokenizer
 from transformers.file_utils import cached_property
 
@@ -379,63 +378,19 @@ def get_git_info():
     return repo_infos
 
 
-ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
+ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]
 
 
-def extract_rouge_mid_statistics(dct):
-    new_dict = {}
-    for k1, v1 in dct.items():
-        mid = v1.mid
-        new_dict[k1] = {stat: round(getattr(mid, stat), 4) for stat in ["precision", "recall", "fmeasure"]}
-    return new_dict
-
-
-def calculate_rouge(
-    pred_lns: List[str],
-    tgt_lns: List[str],
-    use_stemmer=True,
-    rouge_keys=ROUGE_KEYS,
-    return_precision_and_recall=False,
-    bootstrap_aggregation=True,
-    newline_sep=True,
-) -> Dict:
-    """Calculate rouge using the rouge_scorer package.
-
-    Args:
-        pred_lns: list of summaries generated by model
-        tgt_lns: list of groundtruth summaries (e.g. contents of val.target)
-        use_stemmer: Bool indicating whether Porter stemmer should be used to
-            strip word suffixes to improve matching.
-        rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum
-        return_precision_and_recall: (False) whether to also return precision and recall.
-        bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True; if False
-            this function returns a ``collections.defaultdict[metric: list of values for each observation for each subscore]``
-        newline_sep: (default=True) whether to add a newline between sentences. This is essential for calculating rougeL
-            on multi-sentence summaries (CNN/DM dataset).
-
-    Returns:
-        Dict[score: value] if aggregate else defaultdict(list) keyed by rouge_keys
-    """
-    scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer)
+def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
+    scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
     aggregator = scoring.BootstrapAggregator()
-    for pred, tgt in zip(tgt_lns, pred_lns):
-        # rougeLsum expects "\n" separated sentences within a summary
-        if newline_sep:
-            pred = add_newline_to_end_of_each_sentence(pred)
-            tgt = add_newline_to_end_of_each_sentence(tgt)
-        scores = scorer.score(pred, tgt)
-        aggregator.add_scores(scores)
-
-    if bootstrap_aggregation:
-        result = aggregator.aggregate()
-        if return_precision_and_recall:
-            return extract_rouge_mid_statistics(result)  # here we return dict
-        else:
-            return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
+    for reference_ln, output_ln in zip(reference_lns, output_lns):
+        scores = scorer.score(reference_ln, output_ln)
+        aggregator.add_scores(scores)
 
-    else:
-        return aggregator._scores  # here we return defaultdict(list)
+    result = aggregator.aggregate()
+    return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
 
 
 # Utilities for freezing parameters and checking whether they are frozen
@@ -468,6 +423,9 @@ def assert_not_all_frozen(model):
     assert any(model_grads), f"none of {npars} weights require grad"
 
 
+# CLI Parsing utils
+
+
 def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
     """
     Parse an argv list of unspecified command line args to a dict.
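With the revert applied, calculate_rouge is back to the three-argument form shown above. A short usage sketch (invented strings; note that the restored loop passes the reference first, matching rouge_scorer's score(target, prediction) signature):

# Illustrative usage of the restored helper; the example strings are invented.
from utils import calculate_rouge  # examples/seq2seq/utils.py

output_lns = ["the cat sat on the mat .", "the dog ran away ."]
reference_lns = ["the cat sat on the mat .", "a dog ran far away ."]

# Returns mid f-measures scaled to 0-100, e.g. {"rouge1": ..., "rouge2": ..., "rougeL": ...}
print(calculate_rouge(output_lns, reference_lns, use_stemmer=True))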
