In [1]:
import os
import json
from datasets import load_from_disk, load_dataset, load_metric
import pickle
from tqdm import tqdm
import numpy as np
from nltk.corpus import stopwords

In [3]:
from transformers import AutoTokenizer
from dataset import CitationTextGenerationRAGDataset

In [4]:
stops = set(stopwords.words('english'))
rouge = load_metric("rouge")

In [5]:
def clean_stop(text):
    cleaned = []
    for word in text.split():
        if word not in stops:
            cleaned.append(word)
    return " ".join(cleaned)

In [6]:
def remove_citation_marks(text, citations):
    text = text.replace("\n"," ")
    for citation in citations.split("#"):
        for token in citation.split():
            text = text.replace(token,"")
    return text

In [7]:
def ROUGE(predictions, references, use_stemmer=True):
    result = rouge.compute(predictions=predictions, references=references, use_stemmer=use_stemmer)
    return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}

In [8]:
def ROUGE_recall(predictions, references, use_stemmer=True):
    result = rouge.compute(predictions=predictions, references=references, use_stemmer=use_stemmer)
    return {k: round(v.mid.recall * 100, 4) for k, v in result.items()}

In [37]:
passages_path = os.path.join("cited_text_embeddings_sentence_better", "cited_papers") 
#passages_path = "cited_text_embeddings_citation_mark/cited_papers"
cited_dataset = load_from_disk(passages_path)

In [38]:
outputs = []
with open("LED_sentence_CTS_citation_span_generation.jsonl") as f:
    for line in f:
        outputs.append(json.loads(line))

In [40]:
keyword_ids = {}
# sorted_sentence_ROUGE_test
# abstract_test_doc_ids
# test_retrieved_sentence_CTS
with open("test_retrieved_sentence_CTS.jsonl") as f:
    for line in f:
        obj = json.loads(line)
        keyword_ids[obj["id"]] = obj["cited_indices"][:10]

In [41]:
use_stemmer = True
cleaned_predictions = []
cleaned_targets = []
predictions = []
targets = []
prediction_no_stop = []
target_no_stop = []
all_retrieved = []
for candidate in tqdm(outputs):
    prediction = remove_citation_marks(candidate["prediction"][0], candidate["citations"])
    #prediction = candidate["prediction"][0].replace("\n"," ")
    target = remove_citation_marks(candidate["target"], candidate["citations"])
    #target = candidate["gold_label"][0]
    cleaned_predictions.append(prediction)
    cleaned_targets.append(target)
    predictions.append(candidate["prediction"][0])
    targets.append(candidate["target"])
    all_retrieved_texts = []
    for idx in keyword_ids[candidate["id"]]:
        if idx >= 0:
            cited_text = cited_dataset[idx]["text"]
            all_retrieved_texts.append(cited_text)
    cleaned_target = clean_stop(candidate["target"])
    target_no_stop.append(cleaned_target)
    cleaned_prediction = clean_stop(candidate["prediction"][0])
    prediction_no_stop.append(cleaned_prediction)
    all_retrieved.append(" ".join(all_retrieved_texts))


100%|██████████████████████████████████████████████████████████████| 1206/1206 [00:06<00:00, 196.27it/s]


In [42]:
print(ROUGE(predictions, targets, use_stemmer=use_stemmer))
print(ROUGE(cleaned_predictions, cleaned_targets, use_stemmer=use_stemmer))
print(ROUGE_recall(all_retrieved, prediction_no_stop, use_stemmer=use_stemmer))

INFO:absl:Using default tokenizer.
INFO:absl:Using default tokenizer.


{'rouge1': 36.598, 'rouge2': 17.7807, 'rougeL': 31.2337, 'rougeLsum': 32.0303}


INFO:absl:Using default tokenizer.


{'rouge1': 24.4983, 'rouge2': 6.7446, 'rougeL': 19.4368, 'rougeLsum': 19.4503}
{'rouge1': 75.4517, 'rouge2': 39.171, 'rougeL': 62.6386, 'rougeLsum': 62.6028}
