# Text Generation Evaluation - Example

In [1]:
!pip install transformers
!pip install pytorch_transformers
!pip install datasets
!pip install bert_score


In [None]:
import torch

import pandas as pd
import numpy as np

import datasets
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## 1. Generate Translations Using a Transformer Model

1. Load a German-to-English transformer model
    * Model: https://huggingface.co/google/bert2bert_L-24_wmt_de_en
    * The model has already been fine-tuned on a German-to-English translation dataset
2. Load a test dataset:
    * Dataset: https://huggingface.co/datasets/wmt16
    * WMT16 dataset (the same dataset the model has been fine-tuned on)
    * Available on Hugging Face (via the `datasets` library)
    * Contains parallel sentences for several language pairs
3. Tokenize the German sentences
4. Generate the English translations

In [None]:
# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("google/bert2bert_L-24_wmt_de_en", pad_token="<pad>", eos_token="</s>", bos_token="<s>")
model = AutoModelForSeq2SeqLM.from_pretrained("google/bert2bert_L-24_wmt_de_en")
model.to(device)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


EncoderDecoderModel(
  (encoder): BertGenerationEncoder(
    (embeddings): BertGenerationEmbeddings(
      (word_embeddings): Embedding(31950, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
    

In [None]:
# load german to english dataset
wmt16 = datasets.load_dataset("wmt16", "de-en")

Downloading and preparing dataset wmt16/de-en (download: 1.57 GiB, generated: 1.28 GiB, post-processed: Unknown size, total: 2.85 GiB) to /root/.cache/huggingface/datasets/wmt16/de-en/1.0.0/9e0038fe4cc117bd474d2774032cc133e355146ed0a47021b2040ca9db4645c0...


Dataset wmt16 downloaded and prepared to /root/.cache/huggingface/datasets/wmt16/de-en/1.0.0/9e0038fe4cc117bd474d2774032cc133e355146ed0a47021b2040ca9db4645c0. Subsequent calls will reuse this data.


In [None]:
n_examples = 10
input = [wmt16["test"][i]["translation"]["de"] for i in range(n_examples)]
references = [wmt16["test"][i]["translation"]["en"] for i in range(n_examples)]

In [None]:
input

['Obama empfängt Netanyahu',
 'Das Verhältnis zwischen Obama und Netanyahu ist nicht gerade freundschaftlich.',
 'Die beiden wollten über die Umsetzung der internationalen Vereinbarung sowie über Teherans destabilisierende Maßnahmen im Nahen Osten sprechen.',
 'Bei der Begegnung soll es aber auch um den Konflikt mit den Palästinensern und die diskutierte Zwei-Staaten-Lösung gehen.',
 'Das Verhältnis zwischen Obama und Netanyahu ist seit Jahren gespannt.',
 'Washington kritisiert den andauernden Siedlungsbau Israels und wirft Netanyahu mangelnden Willen beim Friedensprozess vor.',
 'Durch den von Obama beworbenen Deal um das iranische Atomprogramm hat sich die Beziehung der beiden weiter verschlechtert.',
 'Im März hatte Netanyahu auf Einladung der Republikaner vor dem US-Kongress eine umstrittene Rede gehalten, die teils als Affront gegen Obama gewertet wurde.',
 'Die Rede war mit Obama nicht abgesprochen, ein Treffen hatte dieser mit Hinweis auf die seinerzeit bevorstehende Wahl in Is

In [None]:
references

['Obama receives Netanyahu',
 'The relationship between Obama and Netanyahu is not exactly friendly.',
 "The two wanted to talk about the implementation of the international agreement and about Teheran's destabilising activities in the Middle East.",
 'The meeting was also planned to cover the conflict with the Palestinians and the disputed two state solution.',
 'Relations between Obama and Netanyahu have been strained for years.',
 'Washington criticises the continuous building of settlements in Israel and accuses Netanyahu of a lack of initiative in the peace process.',
 "The relationship between the two has further deteriorated because of the deal that Obama negotiated on Iran's atomic programme, .",
 'In March, at the invitation of the Republicans, Netanyahu made a controversial speech to the US Congress, which was partly seen as an affront to Obama.',
 'The speech had not been agreed with Obama, who had rejected a meeting with reference to the election that was at that time impen

In [None]:
# tokenize the German sentences (padded so they can be batched; 512 = number of encoder positions)
encoded = tokenizer(
    input,
    return_tensors="pt",
    add_special_tokens=False,
    padding=True,
    truncation=True,
    max_length=512).to(device)
# generate English translations; the attention mask keeps padding tokens from being attended to
output_ids = model.generate(encoded.input_ids, attention_mask=encoded.attention_mask)
# decode translations
translations = [
    tokenizer.decode(out_ids, skip_special_tokens=True) for out_ids in output_ids
]

In [None]:
translations

['amama received netanyahu',
 'the relationship between obama and netanyahu is not friendly.',
 'the two wanted to speak about the implementation of the international agreement and about teherans destabilizing measures in the near east.',
 'At the meeting, however, there will also be the conflict with the Palatinians and the two-state-unit discussed.',
 'the relationship between obama and netanyahu has been tense for years.',
 "Tonton criticizes Israel's ongoing settlement-building and accuses netanyahu of lack of will in the peace process.",
 "The two men's relationship has deteriorated further due to obama's proposed deal with Iran's nuclear program.",
 'At the marz, netanyahu had, at the invitation of the republicans, made a controversial speech in front of the us-congress, which was partly interpreted as an affront to obama.',
 'The hotel was very clean and the staff were friendly and helpful. The room was clean and comfortable.',
 "in an emergency call professor confesses that he 

## 2. Use a Metric to Evaluate the Generated Translations

* Metric: BERTScore, using the Hugging Face wrapper: https://huggingface.co/metrics/bertscore
* How to load and use metrics with `datasets`: https://huggingface.co/docs/datasets/how_to_metrics
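Under the hood, BERTScore greedily matches every token of one sentence to its most similar token in the other sentence, using cosine similarity between contextual embeddings: recall averages the best match of each reference token, precision averages the best match of each candidate token, and F1 combines the two. The NumPy sketch below shows only this matching step, assuming the token embeddings have already been computed (the toy 2-d vectors stand in for the RoBERTa-large layer-17 embeddings named in the `hashcode` further below) and leaving out idf weighting and baseline rescaling. `greedy_match_scores` is a hypothetical helper, not part of the `bert_score` API.

```python
import numpy as np


def greedy_match_scores(cand_vecs: np.ndarray, ref_vecs: np.ndarray):
    """BERTScore-style precision, recall and F1 from token embeddings.

    cand_vecs: (n_candidate_tokens, dim), ref_vecs: (n_reference_tokens, dim).
    Hypothetical helper: no idf weighting, no baseline rescaling.
    """
    # L2-normalise every token vector so that dot products are cosine similarities
    cand = cand_vecs / np.linalg.norm(cand_vecs, axis=1, keepdims=True)
    ref = ref_vecs / np.linalg.norm(ref_vecs, axis=1, keepdims=True)
    sim = ref @ cand.T                  # shape (n_reference_tokens, n_candidate_tokens)
    recall = sim.max(axis=1).mean()     # best candidate match for each reference token
    precision = sim.max(axis=0).mean()  # best reference match for each candidate token
    f1 = 2 * precision * recall / (precision + recall)
    return float(precision), float(recall), float(f1)


# toy "embeddings": the candidate contains one extra token that only partially matches
ref_vecs = np.array([[1.0, 0.0], [0.0, 1.0]])
cand_vecs = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
p, r, f = greedy_match_scores(cand_vecs, ref_vecs)
print(round(p, 3), round(r, 3), round(f, 3))
```

In the toy example recall is 1.0 because every reference token has an exact match, while precision drops to about 0.90 because the extra candidate token only partially matches a reference token.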

In [None]:
bert_score = datasets.load_metric("bertscore")

In [None]:
print(bert_score.inputs_description)


BERTScore Metrics with the hashcode from a source against one or more references.

Args:
    predictions (list of str): Prediction/candidate sentences.
    references (list of str or list of list of str): Reference sentences.
    lang (str): Language of the sentences; required (e.g. 'en').
    model_type (str): Bert specification, default using the suggested
        model for the target language; has to specify at least one of
        `model_type` or `lang`.
    num_layers (int): The layer of representation to use,
        default using the number of layers tuned on WMT16 correlation data.
    verbose (bool): Turn on intermediate status update.
    idf (bool or dict): Use idf weighting; can also be a precomputed idf_dict.
    device (str): On which the contextual embedding model will be allocated on.
        If this argument is None, the model lives on cuda:0 if cuda is available.
    nthreads (int): Number of threads.
    batch_size (int): Bert score processing batch size,
        at

In [None]:
score = bert_score.compute(
    predictions=translations,
    references=references,
    lang="en"
)

In [None]:
score

{'f1': [0.9031886458396912,
  0.951582133769989,
  0.955406129360199,
  0.904647946357727,
  0.9517983198165894,
  0.938241720199585,
  0.9381374716758728,
  0.9342036843299866,
  0.8394179940223694,
  0.9090743660926819],
 'hashcode': 'roberta-large_L17_no-idf_version=0.3.11(hug_trans=4.19.2)',
 'precision': [0.8884567618370056,
  0.948771595954895,
  0.9553632140159607,
  0.8928921818733215,
  0.9421710968017578,
  0.931533932685852,
  0.9393641948699951,
  0.9220834374427795,
  0.8543077707290649,
  0.9077369570732117],
 'recall': [0.9184173941612244,
  0.9544093608856201,
  0.9554489850997925,
  0.9167172908782959,
  0.9616243839263916,
  0.9450467824935913,
  0.9369138479232788,
  0.9466466903686523,
  0.8250383138656616,
  0.9104157090187073]}

In [None]:
np.mean(score["precision"]), np.mean(score["recall"]), np.mean(score["f1"])

(0.9182681143283844, 0.9270678758621216, 0.9225698411464691)

## 3. Evaluate the Metric on a Benchmark

### Benchmark: WMT18
* WMT18 was also used in the BERTScore paper
* It is a metric evaluation dataset (Ma et al., 2018), which contains predictions of 149 translation systems across 14 language pairs, gold references, and two types of human judgment scores:
    * Segment-level human judgments assign a score to each reference-candidate pair.
    * System-level human judgments associate each system with a single score based on all pairs in the test set.
* Download links:
    * http://ufallab.ms.mff.cuni.cz/~bojar/wmt18-metrics-task-package.tgz
    * http://ufallab.ms.mff.cuni.cz/~bojar/wmt18/wmt18-metrics-task-nohybrids.tgz
    * Download script: https://github.com/Tiiiger/bert_score/blob/master/reproduce/download_wmt18.sh
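For the system-level judgments, the WMT18 metrics task compares one metric score per system with the human system score using Pearson's correlation. A minimal sketch, assuming that a system's metric score is the mean of its segment-level scores, with made-up numbers for three hypothetical systems (the helper name and all data below are illustrative, not taken from WMT18):

```python
import numpy as np


def system_level_pearson(segment_scores: dict, human_scores: dict) -> float:
    """Pearson correlation between per-system metric scores and human scores.

    segment_scores: system name -> list of segment-level metric scores
    human_scores:   system name -> single human (system-level) score
    Assumption: a system's metric score is the mean of its segment scores.
    """
    systems = sorted(human_scores)
    metric = np.array([np.mean(segment_scores[s]) for s in systems])
    human = np.array([human_scores[s] for s in systems])
    return float(np.corrcoef(metric, human)[0, 1])


# hypothetical toy data, not taken from WMT18
segment_scores = {
    "sys-A": [0.91, 0.93, 0.92],
    "sys-B": [0.88, 0.90, 0.89],
    "sys-C": [0.85, 0.86, 0.84],
}
human_scores = {"sys-A": 0.30, "sys-B": 0.10, "sys-C": -0.20}
print(system_level_pearson(segment_scores, human_scores))
```

Unlike the segment-level evaluation used in the rest of this notebook, only the ranking and spacing of whole systems matter here, so a metric can correlate well at the system level while still misordering many individual sentence pairs.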

### Evaluation (Segment level)

* Reference implementation: https://github.com/Tiiiger/bert_score/blob/master/reproduce/get_wmt18_seg_results.py

Each segment-level annotation states, for one source sentence and two translation systems, which of the two systems produced the better translation.

1. For each annotated pair of systems:
    1. Compute BERTScore for the translation of the better system
    2. Compute BERTScore for the translation of the worse system
2. Measure how often BERTScore(better system) > BERTScore(worse system), summarized as a Kendall's tau-like rank correlation (https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient)
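Written as a formula, let $C$ be the number of annotated pairs in which BERTScore gives the human-preferred (better) translation the strictly higher score, and $D$ the number of remaining pairs (metric ties count as discordant, as in the evaluation code below). With $N = C + D$ annotated pairs, the segment-level score is:

$$
\tau = \frac{C - D}{C + D} = \frac{2C}{N} - 1
$$

For example, the cs-en result of 0.414 reported below corresponds to BERTScore agreeing with the human judgment on roughly 70.7% of the pairs, since $(0.414 + 1)/2 \approx 0.707$.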

In [None]:
# download the WMT18 metrics-task data
# (adapted from https://github.com/Tiiiger/bert_score/blob/master/reproduce/download_wmt18.sh)
# note: `!cd` only changes directory inside a throwaway subshell, so everything is
# downloaded into the current directory and moved into wmt18/ at the end
!wget http://ufallab.ms.mff.cuni.cz/~bojar/wmt18-metrics-task-package.tgz
!tar -axvf wmt18-metrics-task-package.tgz

!wget http://ufallab.ms.mff.cuni.cz/~bojar/wmt18/wmt18-metrics-task-nohybrids.tgz
!tar -axvf wmt18-metrics-task-nohybrids.tgz

!mv wmt18-metrics-task-nohybrids wmt18-metrics-task-package/input
!mkdir -p wmt18
!mv wmt18-metrics-task-package wmt18

!rm -f *.tgz

--2022-06-02 17:40:24--  http://ufallab.ms.mff.cuni.cz/~bojar/wmt18-metrics-task-package.tgz
Resolving ufallab.ms.mff.cuni.cz (ufallab.ms.mff.cuni.cz)... 195.113.18.181
Connecting to ufallab.ms.mff.cuni.cz (ufallab.ms.mff.cuni.cz)|195.113.18.181|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 234960189 (224M) [application/x-gzip]
Saving to: ‘wmt18-metrics-task-package.tgz’


2022-06-02 17:40:48 (9.80 MB/s) - ‘wmt18-metrics-task-package.tgz’ saved [234960189/234960189]

wmt18-metrics-task-package/
wmt18-metrics-task-package/submissions-as-received/
wmt18-metrics-task-package/submissions-as-received/ITER.seg.score.gz
wmt18-metrics-task-package/submissions-as-received/BLEND.sys.score.gz
wmt18-metrics-task-package/submissions-as-received/CharacTER.testsuites.segment.result.gz
wmt18-metrics-task-package/submissions-as-received/CharacTER.newstest2018.nohybrid.system.result.gz
wmt18-metrics-task-package/submissions-as-received/BLEND.seg.score.gz
wmt18-metrics-task-pac

In [None]:
!ls wmt18/wmt18-metrics-task-package

creating-hybrids     input		README	 source-system-outputs
final-metric-scores  manual-evaluation	results  submissions-as-received


In [None]:
# segment level results
wmt18 = pd.read_csv("wmt18/wmt18-metrics-task-package/manual-evaluation/RR-seglevel.csv", sep=' ')
wmt18

|       | LP | SID | BETTER | WORSE |
|-------|----|-----|--------|-------|
| cs-en | newstest2018 | 593 | CUNI-Transformer.5560 | online-B.0 |
| cs-en | newstest2018 | 344 | CUNI-Transformer.5560 | online-B.0 |
| cs-en | newstest2018 | 345 | online-G.0 | CUNI-Transformer.5560 |
| cs-en | newstest2018 | 346 | uedin.5561 | online-A.0 |
| cs-en | newstest2018 | 342 | online-B.0 | CUNI-Transformer.5560 |
| ... | ... | ... | ... | ... |
| en-zh | newstest2018 | 1369 | UMD.5680 | online-F.0 |
| en-zh | newstest2018 | 1369 | Alibaba-ensemble-system-with-reranking.5738 | online-F.0 |
| en-zh | newstest2018 | 1369 | Alibaba-ensemble-system-with-reranking.5738 | Tencent-ensemble-system.5760 |
| en-zh | newstest2018 | 1369 | Alibaba-ensemble-system-with-reranking.5738 | Alibaba-General-System.5743 |


In [None]:
# keep only the Czech-to-English pairs (the language pair was read into the DataFrame index, see the output above)
wmt18_cs_en = wmt18.loc[wmt18.index == "cs-en", :]

In [None]:
wmt18_cs_en

|       | LP | SID | BETTER | WORSE |
|-------|----|-----|--------|-------|
| cs-en | newstest2018 | 593 | CUNI-Transformer.5560 | online-B.0 |
| cs-en | newstest2018 | 344 | CUNI-Transformer.5560 | online-B.0 |
| cs-en | newstest2018 | 345 | online-G.0 | CUNI-Transformer.5560 |
| cs-en | newstest2018 | 346 | uedin.5561 | online-A.0 |
| cs-en | newstest2018 | 342 | online-B.0 | CUNI-Transformer.5560 |
| ... | ... | ... | ... | ... |
| cs-en | newstest2018 | 479 | online-A.0 | online-G.0 |
| cs-en | newstest2018 | 1368 | uedin.5561 | online-B.0 |
| cs-en | newstest2018 | 1368 | CUNI-Transformer.5560 | online-B.0 |
| cs-en | newstest2018 | 1368 | online-G.0 | online-B.0 |


In [None]:
# output sentences produced by the different translation systems
!ls wmt18/wmt18-metrics-task-package/input/wmt18-metrics-task-nohybrids/system-outputs/newstest2018/cs-en

newstest2018.CUNI-Transformer.5560.cs-en  newstest2018.online-G.0.cs-en
newstest2018.online-A.0.cs-en		  newstest2018.uedin.5561.cs-en
newstest2018.online-B.0.cs-en


In [None]:
systems = list(set(wmt18_cs_en["BETTER"]).union(set(wmt18_cs_en["WORSE"])))
systems

['online-B.0',
 'CUNI-Transformer.5560',
 'online-G.0',
 'uedin.5561',
 'online-A.0']

In [None]:
# read the Czech-to-English output of each system (one translated sentence per line)
translations = {}
for system in systems:
    with open(f"wmt18/wmt18-metrics-task-package/input/wmt18-metrics-task-nohybrids/system-outputs/newstest2018/cs-en/newstest2018.{system}.cs-en", encoding="utf-8") as f:
        translations[system] = f.read().split("\n")

In [None]:
translations[systems[0]][:10]

['The Civil Rights Movement issued a travel alert for Missouri',
 '"The NAACP Travel Statement for Missouri, with effect from August 28, 2017, urges African-American travelers, visitors and residents of Missouri to pay increased attention when traveling across the country due to a series of controversial racially-motivated incidents that are currently occurring across the nation," the statement said.',
 "The NAACP has stated that this step has been taken on the basis of Missouri's current laws that make it more difficult to defend racially motivated acts before a court, as well as the conduct of law enforcement bodies that over-target minorities.",
 '"There is a violation of civil rights.',
 'People are stopped by policemen just because of their skin color, they are attacked or killed, "said President of NAACP for Missouri Rod Capel for Kansas City Star.',
 '"We have accumulated the largest number of complaints so far."',
 'Among the incidents cited by the organization were racially mo

In [None]:
# reference sentences in the different languages
!ls wmt18/wmt18-metrics-task-package/input/wmt18-metrics-task-nohybrids/references

newstest2018-csen-ref.en  newstest2018-zhen-ref.en
newstest2018-deen-ref.en  out_of_domain-entr-ref.tr
newstest2018-encs-ref.cs  out_of_domain-tren-ref.en
newstest2018-ende-ref.de  prepositions_encs-csen-ref.en
newstest2018-enet-ref.et  prepositions_encs-encs-ref.cs
newstest2018-enfi-ref.fi  prepositions_ende-deen-ref.en
newstest2018-enru-ref.ru  prepositions_ende-ende-ref.de
newstest2018-entr-ref.tr  pronoun_evaluation-ende-ref.de
newstest2018-enzh-ref.zh  some_syntax_phenomena-csen-ref.en
newstest2018-eten-ref.en  some_syntax_phenomena-encs-ref.cs
newstest2018-fien-ref.en  tur_morph_pt1-tren-ref.en
newstest2018-ruen-ref.en  wsd-deen-ref.en
newstest2018-tren-ref.en


In [None]:
# read in the English reference sentences
with open("wmt18/wmt18-metrics-task-package/input/wmt18-metrics-task-nohybrids/references/newstest2018-csen-ref.en", encoding="utf-8") as f:
    references = f.read().split("\n")

In [None]:
references[:10]

 "The National Association for the Advancement of Colored People has put out an alert for people of color traveling to Missouri because of the state's discriminatory policies and racist attacks.",
 '"The NAACP Travel Advisory for the state of Missouri, effective through August 28th, 2017, calls for African American travelers, visitors and Missourians to pay special attention and exercise extreme caution when traveling throughout the state given the series of questionable, race-based incidents occurring statewide recently, and noted therein," the group\'s statement reads.',
 "A recent Missouri law making it harder for people to win discrimination lawsuits, as well as the state's law enforcement disproportionately targeting minorities prompted the group to issue the travel alert, the NAACP said.",
 '"You have violations of civil rights that are happening to people.',
 'They\'re being pulled over because of their skin color, they\'re being beaten up or killed," the president of the Missou

In [None]:
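# for every annotated pair, collect the better/worse system output and the reference
# (SID is 1-based, the Python lists are 0-based)
ref = []
translations_better = []
translations_worse = []
for _, row in wmt18_cs_en.iterrows():
    sid = row["SID"] - 1
    translations_better.append(translations[row["BETTER"]][sid])
    translations_worse.append(translations[row["WORSE"]][sid])
    ref.append(references[sid])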

In [None]:
# compute bert scores
scores_better = bert_score.compute(
    predictions=translations_better,
    references=ref,
    lang="en"
)
scores_worse = bert_score.compute(
    predictions=translations_worse,
    references=ref,
    lang="en"
)

In [None]:
# compute Kendall's tau-like score (based on F1); ties count as discordant
# https://github.com/Tiiiger/bert_score/blob/master/reproduce/get_wmt18_seg_results.py#L53

def kendall_score(scores_better, scores_worse):
    """(concordant - discordant) / total, where a pair is concordant if the
    human-preferred translation gets the strictly higher metric score."""
    total = len(scores_better)
    correct = np.sum(np.array(scores_better) > np.array(scores_worse))
    incorrect = total - correct
    return (correct - incorrect) / total

ks = kendall_score(scores_better["f1"], scores_worse["f1"])

In [None]:
ks

0.4140900195694716

In [None]:
# model used
scores_better["hashcode"]

'roberta-large_L17_no-idf_version=0.3.11(hug_trans=4.19.2)'

* Kendall correlation of the reproduction (cs-en): 0.414
* Reported: 0.404 (I think; see the BERTScore paper, Table 17, F_RoBERTa-LARGE)


## 4. Create a Custom Metric Using the Hugging Face Wrapper

* Hugging Face allows loading metrics from a custom loading script: https://huggingface.co/docs/datasets/master/en/how_to_metrics#custom-metric-loading-script
* Template: https://github.com/huggingface/datasets/blob/master/templates/new_metric_script.py

In [2]:
# official SUPERT repository
!git clone https://github.com/yg211/acl20-ref-free-eval.git

Cloning into 'acl20-ref-free-eval'...
remote: Enumerating objects: 371, done.[K
remote: Counting objects: 100% (77/77), done.[K
remote: Compressing objects: 100% (68/68), done.[K
remote: Total 371 (delta 19), reused 60 (delta 9), pack-reused 294[K
Receiving objects: 100% (371/371), 195.25 KiB | 11.48 MiB/s, done.
Resolving deltas: 100% (135/135), done.


In [7]:
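# the repo ships its own copy of sentence_transformers; get_token_vecs() below relies on
# its `encode(..., token_vecs=True)` option, which the standard package does not offer
!mkdir -p supert
!cp -r acl20-ref-free-eval/sentence_transformers supert
!cp -r acl20-ref-free-eval/data .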

In [4]:
%%writefile supert/supert.py
"""SUPERT: Towards New Frontiers in Unsupervised Evaluation Metrics for Multi-Document Summarization"""

import functools

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

from tqdm.auto import tqdm

import nltk
from nltk.corpus import stopwords
from nltk.tokenize import sent_tokenize

import datasets


_CITATION = """\
@misc{https://doi.org/10.48550/arxiv.2005.03724,
  doi = {10.48550/ARXIV.2005.03724},
  url = {https://arxiv.org/abs/2005.03724},
  author = {Gao, Yang and Zhao, Wei and Eger, Steffen},
  keywords = {Computation and Language (cs.CL), Information Retrieval (cs.IR), FOS: Computer and information sciences, FOS: Computer and information sciences},
  title = {SUPERT: Towards New Frontiers in Unsupervised Evaluation Metrics for Multi-Document Summarization},
  publisher = {arXiv},
  year = {2020},
  copyright = {arXiv.org perpetual, non-exclusive license}
}
"""


_DESCRIPTION = """\
Unsupervised multi-document summarization evaluation metric.
"""


_KWARGS_DESCRIPTION = """
Calculates SUPERT scores for summaries given their source documents (no human reference summaries are needed).
Args:
    predictions: nested list, where each entry is a list of strings containing at least
        one summary of the respective source document(s)
    source_documents: nested list, where each entry is a list of strings containing at
        least one source document of the respective summaries
    model_type: SBERT model to use (default: "bert-large-nli-stsb-mean-tokens")
    top_n: number of leading sentences of each source document used as the pseudo reference (default: 15)
Returns:
    nested list of the SUPERT scores for each summary
Examples:
    >>> supert = datasets.load_metric("supert")
    >>> source_documents = [
        [
            "Long source document about a topic.",
            "Another long source document about the same topic."
        ],
        [
            "Long source document about another topic.",
            "Another document about the second topic."
        ]
    ]
    >>> predictions = [
        ["A summary of the documents about the first topic."],
        [
            "A summary of the documents about the second topic.",
            "Another summary of the documents about the second topic."
        ]
    ]
    >>> results = supert.compute(source_documents=source_documents, predictions=predictions)
    >>> print(results)
    [[0.3677717150290458], [0.6020039781738254, 0.6592496765919262]]
"""


LANGUAGE = "english"
    

def get_ref_sents(source_docs, top_n):
    ref_sents = []
    for doc in source_docs:
        ref_sents.append(doc[:top_n])
    return ref_sents

def get_token_vecs(model, sents, remove_stopwords=True):
    vecs, tokens = model.encode(sents, token_vecs=True)
    vecs = functools.reduce(lambda a, b: a+b.tolist(), vecs, [])
    tokens = functools.reduce(lambda a, b: a+b, tokens)
    assert len(vecs) == len(tokens)
    if remove_stopwords:
        clean_vecs = []
        clean_tokens = []
        mystopwords = list(set(stopwords.words(LANGUAGE)))
        mystopwords.extend(["[cls]","[sep]"])
        for i, t in enumerate(tokens):
            if t.lower() not in mystopwords: 
                clean_vecs.append(vecs[i])
                clean_tokens.append(t)
        assert len(clean_vecs) == len(clean_tokens)
        return np.array(clean_vecs)
    return np.array(vecs)

def get_sbert_score(ref_token_vecs, summ_token_vecs):
    f1_list = []
    for i, rvecs in enumerate(ref_token_vecs):
        r_f1_list = []
        for j, svecs in enumerate(summ_token_vecs):
            sim_matrix = cosine_similarity(rvecs, svecs)
            recall = np.mean(np.max(sim_matrix, axis=1))
            precision = np.mean(np.max(sim_matrix, axis=0))
            f1 = 2. * recall * precision / (recall + precision)
            r_f1_list.append(f1)
        f1_list.append(r_f1_list)
    f1_list = np.array(f1_list)
    scores = []
    for i in range(len(summ_token_vecs)):
        scores.append(np.mean(f1_list[:,i]))
    return scores


@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Supert(datasets.Metric):
    """SUPERT: Unsupervised multi-document summarization evaluation metric"""

    def _info(self):
        return datasets.MetricInfo(
            # This is the description that will appear on the metrics page.
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # This defines the format of each prediction and reference
            features=datasets.Features({
                "predictions": datasets.Sequence(datasets.Value("string")),
                "source_documents": datasets.Sequence(datasets.Value("string"))
            }),
            # Homepage of the metric for documentation
            homepage="https://github.com/yg211/acl20-ref-free-eval",
            # Additional links to the codebase or references
            codebase_urls=["https://github.com/yg211/acl20-ref-free-eval"],
            reference_urls=["https://arxiv.org/abs/2005.03724"]
        )

    def _download_and_prepare(self, dl_manager):
        try:
            nltk.data.find("tokenizers/punkt")
        except LookupError:
            nltk.download("punkt")
        try:
            nltk.data.find("corpora/stopwords")
        except LookupError:
            nltk.download("stopwords")

    def _compute(self, predictions, source_documents, model_type="bert-large-nli-stsb-mean-tokens", top_n=15):
        """Returns the scores"""
        from sentence_transformers import SentenceTransformer
        
        assert len(predictions) == len(source_documents), "predictions and source documents need to be nested list of same length"

        predictions = [list(map(lambda x: sent_tokenize(x, LANGUAGE), s)) for s in predictions]
        source_documents = [list(map(lambda x: sent_tokenize(x, LANGUAGE), s)) for s in source_documents]

        model = SentenceTransformer(model_type)
                   
        scores = []
        for i, source_docs in enumerate(tqdm(source_documents)):
            summaries = predictions[i]
            ref_sents = get_ref_sents(source_docs, top_n)
            ref_vecs = []
            for ref in ref_sents:
                ref_vecs.append(get_token_vecs(model, ref))
            summ_vecs = []
            for summ in summaries:
                summ_vecs.append(get_token_vecs(model, summ))
            scores.append(get_sbert_score(ref_vecs, summ_vecs))
        return scores

Writing supert/supert.py
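The `_compute` method above boils down to three steps: take the first `top_n` sentences of each source document as a pseudo reference, score each summary against every pseudo reference with a BERTScore-style soft F1 over token vectors, and average those F1 values over the source documents. The sketch below reproduces just this aggregation with toy token vectors instead of SBERT embeddings (and without the stop-word filtering done in `get_token_vecs`), so it runs without downloading a model; `soft_f1` and `supert_like_score` are hypothetical names.

```python
import numpy as np


def soft_f1(ref_vecs: np.ndarray, summ_vecs: np.ndarray) -> float:
    """Greedy token-alignment F1 between two sets of token vectors (one per row)."""
    ref = ref_vecs / np.linalg.norm(ref_vecs, axis=1, keepdims=True)
    summ = summ_vecs / np.linalg.norm(summ_vecs, axis=1, keepdims=True)
    sim = ref @ summ.T                  # cosine similarities, (n_ref_tokens, n_summary_tokens)
    recall = sim.max(axis=1).mean()     # best summary match for each pseudo-reference token
    precision = sim.max(axis=0).mean()  # best pseudo-reference match for each summary token
    return float(2 * recall * precision / (recall + precision))


def supert_like_score(doc_token_vecs: list, summ_vecs: np.ndarray) -> float:
    """Average the soft F1 of one summary over all pseudo references.

    doc_token_vecs: one array per source document holding the token vectors of
    that document's first top_n sentences (its pseudo reference).
    """
    return float(np.mean([soft_f1(ref, summ_vecs) for ref in doc_token_vecs]))


# toy token vectors for two source documents and one summary
doc_token_vecs = [
    np.array([[1.0, 0.0], [0.0, 1.0]]),
    np.array([[1.0, 0.0], [1.0, 1.0]]),
]
summary = np.array([[1.0, 0.0], [0.0, 1.0]])
print(supert_like_score(doc_token_vecs, summary))
```

Averaging over documents means a summary is rewarded for covering the leading content of every source document, not just the one it matches best.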


In [5]:
from datasets import load_metric
# load own implementation
metric = load_metric("supert/supert.py")

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


In [9]:
# read example data from repo
%cd acl20-ref-free-eval
from utils.data_reader import CorpusReader

# read docs and summaries
reader = CorpusReader('data/topic_1')
source_docs = reader()
summaries = reader.readSummaries()

/content/acl20-ref-free-eval


In [13]:
source_documents = [" ".join(s[1]) for s in source_docs]
source_documents

['Ten people were killed and more than 100 injured Wednesday when two commuter trains slammed into each other after one of them hit a car left on the lines by a suicidal man, police said. A murder investigation was launched as the death toll rose following the horrific early morning crash involving three trains and a sports utility vehicle in the heavily-populated Los Angeles suburb of Glendale. Rescuers were picking through the mangled wreckage of the trains that collided with an explosive crash just after 6:00 am (1400 GMT) after one of them derailed as it ploughed into the abandoned vehicle. "We have 10 fatalities so far and well over 100 injured," Glendale Police Chief Randy Adams told reporters, adding that the probe into the accident had turned into a homicide investigation. "This whole incident was started by a deranged individual who was suicidal," he said announcing that the depressed driver, identified as Juan Manuel Alvarez, 26, was unhurt and was in custody. The driver alle

In [14]:
summaries

['Juan Miguel Alvarez, charged with murder with special circumstances in the deaths of 11 Metrolink passengers, slashed and stabbed himself after seeing the horrific train crash he caused, sources close to the investigation said Thursday. Alvarez, 25, despondent over the breakup of his marriage, had planned to kill himself when he drove his green Jeep Grand Cherokee in the path of a Metrolink train, officials said, but he bolted from the vehicle at the last moment. As he watched, southbound Train No.\n',
 'Until Wednesday, Juan Manuel Alvarez was living the average life of an obscure and troubled man.It had to be the worst multicasualty incident Ive been to, Los Angeles Fire Capt. Rick Godinez said.David Morrison, 47, an attorney, was heading to downtown Los Angeles on his regular morning commute. With his tire, apparently caught between the tracks, Mr. Alvarez jumped out of the Jeep and ran. But, Mrs. Alvarez was tracked to a modest home in the north section of Compton. In filing the 

In [18]:
metric.compute(predictions=[summaries], source_documents=[source_documents])

[[0.47249419506571255,
  0.38462301527605525,
  0.49472848573043066,
  0.5013696327685596,
  0.5108337603817136]]

In [19]:
# original implementation
from ref_free_metrics.supert import Supert

# compute the Supert scores
supert = Supert(source_docs, ref_metric='top15') 
scores = supert(summaries)

In [20]:
scores

[0.4724941877785051,
 0.3846230242320096,
 0.4947284814499466,
 0.5013696401682224,
 0.5108337581248206]
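The two score lists differ by less than 1e-8, so the custom metric reproduces the original implementation. A quick check that hard-codes the values printed by the two cells above (in a live session you would assign both results to variables and compare them directly):

```python
import numpy as np

# values printed above by the custom `datasets` metric and by the original SUPERT implementation
custom_scores = [0.47249419506571255, 0.38462301527605525, 0.49472848573043066,
                 0.5013696327685596, 0.5108337603817136]
original_scores = [0.4724941877785051, 0.3846230242320096, 0.4947284814499466,
                   0.5013696401682224, 0.5108337581248206]

max_abs_diff = float(np.max(np.abs(np.array(custom_scores) - np.array(original_scores))))
print(f"max |difference| = {max_abs_diff:.1e}")
print("match:", np.allclose(custom_scores, original_scores, rtol=0.0, atol=1e-6))
```

The remaining differences are at the level of floating-point noise, plausibly from the order in which the embedding computations are batched.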