## Playing with ROUGE

In [109]:
from tensor2tensor.utils import rouge
import numpy as np
import os

Define ROUGE recall scoring functions, since tensor2tensor (t2t) only supplies F1 scoring. We borrow the helper functions from `tensor2tensor.utils.rouge` to define them.

In [110]:
def rouge_l_sentence_level_recall(eval_sentences, ref_sentences):
  """Computes the mean ROUGE-L recall (sentence level) over sentence pairs.

  Source: https://www.microsoft.com/en-us/research/publication/
  rouge-a-package-for-automatic-evaluation-of-summaries/

  For each (candidate, reference) pair, recall is calculated according to:
  R_lcs = LCS(X,Y)/m
  where:
  X = reference summary
  Y = candidate summary
  m = length of reference summary

  A pair whose reference is empty (m == 0) scores 0.0 instead of raising
  ZeroDivisionError, which matches how rouge_n_recall scores a reference
  that has no n-grams.

  Args:
    eval_sentences: The sentences that have been picked by the summarizer,
      each one a list of words.
    ref_sentences: The sentences from the reference set, each one a list of
      words, in the same order as eval_sentences.
  Returns:
    A np.float32: the mean of the per-pair recall_lcs values. Pairs are
    formed with zip, so surplus sentences in the longer collection are
    ignored.
  """

  recall_scores = []
  for eval_sentence, ref_sentence in zip(eval_sentences, ref_sentences):
    m = len(ref_sentence)
    if m == 0:
      # Nothing to recall from an empty reference; avoid dividing by zero.
      recall_scores.append(0.0)
      continue
    lcs = rouge._len_lcs(eval_sentence, ref_sentence)
    recall_scores.append(lcs / m)
  return np.mean(recall_scores, dtype=np.float32)

def rouge_n_recall(eval_sentences, ref_sentences, n=2):
  """Computes the mean ROUGE-N recall over pairs of sentences.

  Source: https://www.microsoft.com/en-us/research/publication/
  rouge-a-package-for-automatic-evaluation-of-summaries/

  Per pair, recall is the number of distinct n-grams shared by the candidate
  and the reference divided by the number of distinct reference n-grams. A
  reference with no n-grams contributes a recall of 0.0.

  Args:
    eval_sentences: The sentences that have been picked by the summarizer
    ref_sentences: The sentences from the reference set
    n: Size of ngram.  Defaults to 2.
  Returns:
    The per-pair recall values averaged into a single np.float32.
  """

  per_pair_recall = []
  for candidate, target in zip(eval_sentences, ref_sentences):
    candidate_ngrams = rouge._get_ngrams(n, candidate)
    target_ngrams = rouge._get_ngrams(n, target)

    total = len(target_ngrams)
    # n-grams present in both the candidate and the reference
    matched = len(candidate_ngrams.intersection(target_ngrams))

    per_pair_recall.append(matched / total if total != 0 else 0.0)

  return np.mean(per_pair_recall, dtype=np.float32)

### Prepare input

Each sentence should be a list of words. The scoring functions accept multiple sentences at once, so the overall input is a list of lists (one inner list per sentence).

In [111]:
# sample headlines for troubleshooting purposes
# generated = [["Hello", "World"], ["My", "name", "is", "Mark"]]
# reference = [["Hello", "people"], ["Your", "name", "might", "be", "Mark"]]

In [112]:
# Read the generated headlines (decoder or baseline output) and the reference
# headlines, one headline per line, splitting each line into a list of words.
method = "baseline" # "decoder" or "baseline"

if method == "baseline":
    gen_path = os.path.join("postprocess","baseline.txt")
elif method == "decoder":
    gen_path = os.path.join("decoder","decoder_test_sf.universal_transformer.universal_transformer_base.gavrilov.beam4.alpha0.6.decodes")
else:
    # Raising a plain string is itself a TypeError in Python 3; raise a real
    # exception so the message reaches the reader.
    raise ValueError("unknown method: {!r}".format(method))

# A blank generated line becomes [""] rather than [] -- presumably so the
# scorers never receive a zero-length candidate (TODO confirm).
with open(gen_path) as gen_file:
    generated = [line.split() or [""] for line in gen_file]

ref_path = os.path.join("postprocess","headlines.txt")
with open(ref_path) as ref_file:
    reference = [line.split() for line in ref_file]

### ROUGE-1 and ROUGE-2 F1 score

In [113]:
# ROUGE-1 (unigram) F1 from tensor2tensor, scaled by 100 to read as a percentage
rouge.rouge_n(generated, reference, n=1)*100

7.7437616884708405

In [114]:
# ROUGE-2 (bigram) F1 from tensor2tensor, scaled by 100 to read as a percentage
rouge.rouge_n(generated, reference, n=2)*100

1.6718987375497818

### ROUGE-1 and ROUGE-2 recall score

In [115]:
# Mean ROUGE-1 (unigram) recall over all headline pairs, as a percentage
rouge_n_recall(generated, reference, n=1)*100

8.389227092266083

In [116]:
# Mean ROUGE-2 (bigram) recall over all headline pairs, as a percentage
rouge_n_recall(generated, reference, n=2)*100

1.6853120177984238

### ROUGE-L F1 score

In [117]:
# Sentence-level ROUGE-L F1 from tensor2tensor, scaled by 100 to read as a percentage
rouge.rouge_l_sentence_level(generated, reference)*100

7.121086120605469

### ROUGE-L recall score

In [118]:
# Mean sentence-level ROUGE-L recall over all headline pairs, as a percentage
rouge_l_sentence_level_recall(generated, reference)*100

8.160179108381271