In [1]:
import os
import json
from datasets import load_from_disk, load_dataset, load_metric
import pickle
from rouge_score import rouge_scorer
from tqdm import tqdm
import numpy as np
from nltk.corpus import stopwords
from nltk import word_tokenize
from collections import Counter
import string
import random

In [2]:
from transformers import AutoTokenizer
from dataset import CitationTextGenerationRAGDataset

In [3]:
# Tokens ignored when building highlight vocabularies: NLTK's English
# stopword list plus every ASCII punctuation character.
stops = set(stopwords.words('english')) | set(string.punctuation)
# ROUGE metric object obtained through `datasets.load_metric`.
rouge = load_metric("rouge")

In [4]:
def clean_stop(text):
    """Return ``text`` with every token found in ``stops`` removed.

    The input is split on whitespace, tokens are compared against the
    module-level ``stops`` set exactly as written (no lowercasing), and the
    remaining tokens are joined back together with single spaces.
    """
    return " ".join(token for token in text.split() if token not in stops)

In [5]:
def remove_stop(words):
    """Lowercase each token in ``words`` and drop those present in ``stops``.

    Returns a list of the surviving lowercased tokens, in their original order.
    """
    lowered = (word.lower() for word in words)
    return [word for word in lowered if word not in stops]

In [6]:
def remove_citation_marks(text, citations):
    """Flatten newlines in ``text`` and delete every citation mark from it.

    ``citations`` is a single string holding the marks separated by ``#``.
    Each mark is removed wherever it occurs verbatim; newlines in ``text``
    are replaced by spaces before any mark is matched.
    """
    marks = citations.split("#")
    # str.split always yields at least one piece, so the newline flattening
    # happens even when ``citations`` is empty.
    flattened = text.replace("\n", " ")
    for mark in marks:
        flattened = flattened.replace(mark, "")
    return flattened

In [7]:
def highlight_words(cited_text, highlights):
    """Render ``cited_text`` as an HTML fragment with selected words emphasised.

    The text is tokenised with ``word_tokenize``; any token whose lowercase
    form is in ``highlights`` is wrapped in ``<b><i>...</i></b>``. Tokens are
    joined with single spaces and the fragment ends with ``<br>``.
    """
    marked = [
        "<b><i>" + token + "</i></b>" if token.lower() in highlights else token
        for token in word_tokenize(cited_text)
    ]
    return " ".join(marked) + "<br>"

In [8]:
# Dataset indexed by the ids the sentence-level retrieval runs report
# (rows expose "text" and "title" fields, as used further down).
sentence_cited_dataset = load_from_disk(
    os.path.join("cited_text_embeddings_sentence_better", "cited_papers")
)
# Dataset indexed by the ids the abstract retrieval run reports.
passages_path = "cited_text_embeddings_citation_mark/cited_papers"
paragraph_cited_dataset = load_from_disk(passages_path)

In [9]:
# ID of the single example that is removed from `integrated_outputs` before sampling.
exception = '173188413_1_0_2@52113465'

In [10]:
# Maps each system configuration name to the JSONL file holding its predictions.
# The commented-out "led_*" entries are currently disabled; the merging code
# below still has a branch for configurations whose name contains "led_".
prediction_files = {
    "fid_abstract": "FiD_CTS_RAG_span_generation_retrieved_abstract.jsonl",
    #"led_abstract": "LED_abstract_citation_span_generation.jsonl",
    "fid_context": "FiD_CTS_RAG_span_generation_sentence_pre_retrieval.jsonl",
    #"led_context":"LED_sentence_CTS_citation_span_generation.jsonl",
    "fid_oracle":"FiD_CTS_RAG_span_generation_sentence_oracle.jsonl",
    #"led_oracle":"LED_oracle_sentence_CTS_citation_span_generation.jsonl",
    "fid_keyword":"FiD_CTS_RAG_span_generation_sentence_keyword.jsonl",
    #"led_keyword": "keyword_sentence_CTS_citation_span_generation.jsonl",
}

In [11]:
# Read every prediction file into memory: config name -> list of parsed JSON records.
all_outputs = {}
for config, file in prediction_files.items():
    # JSON Lines is UTF-8; do not depend on the platform's default encoding.
    with open(file, encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline at end of file) instead of
        # letting json.loads fail on an empty string.
        outputs = [json.loads(line) for line in f if line.strip()]
    all_outputs[config] = outputs

In [12]:
# Regroup predictions by example id: id -> {config name: single-line prediction}.
integrated_predictions = {}
for config, outputs in all_outputs.items():
    # "led_*" runs store the prediction as a sequence whose first item is used.
    is_led = "led_" in config
    for example in outputs:
        raw = example["prediction"][0] if is_led else example["prediction"]
        pred = raw.replace("\n", " ")
        integrated_predictions.setdefault(example["id"], {})[config] = pred

In [13]:
# Spot-check: display every configuration's prediction for one example id.
integrated_predictions["202766392_0_0_5@19099243"]

{'fid_abstract': 'Bao et al. (2018) proposed a generative model to generate a natural language sentence describing a',
 'fid_context': 'Bao et al. (2018) proposed a generative model to generate a natural language sentence describing a table.',
 'fid_oracle': 'Bao et al. (2018) proposed a table-aware decoder to copy from the input.',
 'fid_keyword': 'Bao et al. (2018) proposed a copy mechanism that can copy from both the cells and attributes.'}

In [15]:
# Example id -> retrieved document ids reported by the "fid_abstract" run.
abstract_ids = dict(
    (record['id'], record['retrieved_doc_ids'])
    for record in all_outputs["fid_abstract"]
)

In [16]:
# Example id -> retrieved document ids reported by the "fid_context" run.
context_ids = dict(
    (record['id'], record['retrieved_doc_ids'])
    for record in all_outputs["fid_context"]
)

In [17]:
# Example id -> retrieved document ids reported by the "fid_oracle" run.
oracle_ids = dict(
    (record['id'], record['retrieved_doc_ids'])
    for record in all_outputs["fid_oracle"]
)

In [18]:
# Example id -> retrieved document ids reported by the "fid_keyword" run.
keyword_ids = dict(
    (record['id'], record['retrieved_doc_ids'])
    for record in all_outputs["fid_keyword"]
)

In [19]:
# Example id -> context text cut out of the "source" field of the oracle run.
contexts = {}
for example in all_outputs["fid_oracle"]:
    source = example['source']
    # Keep what follows the last "[E_Dominant]", then what follows the last
    # "[E_Reference]" inside that remainder (whole string if a marker is absent).
    after_dominant = source.rpartition("[E_Dominant]")[2]
    after_reference = after_dominant.rpartition("[E_Reference]")[2]
    # Stop at the first "[Dominant]" marker, then flatten newlines and trim.
    context = after_reference.partition("[Dominant]")[0].replace("\n", " ").strip()
    contexts[example['id']] = context

In [20]:
# Example id -> "target" field, taken from the oracle run's records.
targets = {example['id']: example['target'] for example in all_outputs["fid_oracle"]}

In [21]:
# Example id -> "citations" field, taken from the oracle run's records.
citations = {example['id']: example['citations'] for example in all_outputs["fid_oracle"]}

In [22]:
# Assemble one record per example id with its predictions, context, the
# retrieved ids of each run, and its citation marks.
integrated_outputs = {}
id_lookups = (
    ("abstract_ids", abstract_ids),
    ("context_ids", context_ids),
    ("oracle_ids", oracle_ids),
    ("keyword_ids", keyword_ids),
    ("citations", citations),
)
for ID, preds in integrated_predictions.items():
    # The target is stored next to the system predictions, i.e. the dicts
    # inside `integrated_predictions` are modified in place.
    preds["target"] = targets[ID]
    record = {"predictions": preds, "context": contexts[ID]}
    for key, lookup in id_lookups:
        record[key] = lookup[ID]
    integrated_outputs[ID] = record

In [23]:
# Names of the outputs shown for each example. Positions in this list are the
# integer indices that the per-example shuffled orderings refer to, so the
# order here must not change once orderings have been saved.
configs = [
    "target",
    "fid_abstract",
    #"led_abstract",
    "fid_context",
    #"led_context",
    "fid_oracle",
    #"led_oracle",
    "fid_keyword",
    #"led_keyword"
]

In [24]:
# Folder prefix (note the trailing slash) prepended to every generated HTML file name.
base_dir = "human_evaluation_html/"

In [None]:
# Remove the excluded example. `pop` with a default keeps this cell idempotent:
# re-running it (or running it when the id is absent) no longer raises KeyError.
integrated_outputs.pop(exception, None)

In [None]:
# Number of example ids drawn for the human-evaluation sample.
N = 50

In [None]:
# Draw the evaluation sample reproducibly: a dedicated, seeded generator is
# used so that re-running the notebook selects the same ids without touching
# the global `random` state.
SAMPLE_SEED = 0
_sampler = random.Random(SAMPLE_SEED)
_candidate_ids = sorted(integrated_outputs.keys())
# Never request more items than exist (random.sample would raise ValueError).
all_sampled_ids = sorted(_sampler.sample(_candidate_ids, min(N, len(_candidate_ids))))

In [None]:
# Keep the first 25 of the sorted sampled ids for this batch of pages.
# NOTE(review): the page-generation cell below saves its orderings to
# "group2_configs.json" — confirm that this slice is the one meant for that file.
sampled_ids = all_sampled_ids[:25]

In [None]:
# Build one blind-evaluation HTML page per sampled example and record the
# shuffled order in which the outputs appear on each page.
# NOTE(review): the two later cells that iterate over `group_config` repeat
# the same retrieval/highlighting steps; consider factoring them into one function.
os.makedirs(base_dir, exist_ok=True)  # the output folder may not exist yet
order_ids = {}
for ID in tqdm(sampled_ids):
    output = integrated_outputs[ID]

    # Highlight vocabulary: lowercased non-stopword tokens of every entry in
    # output["predictions"] (the target is stored there too), citation marks removed.
    pred = " ".join(output["predictions"].values())
    highlights = set(remove_stop(word_tokenize(remove_citation_marks(pred, output["citations"]))))

    # Union of the non-negative ids retrieved by the three sentence-level runs.
    # Negative ids are skipped (presumably padding — TODO confirm).
    retrieved_sentence_ids = sorted({
        idx
        for name in ["context_ids", "oracle_ids", "keyword_ids"]
        for idx in output[name]
        if idx >= 0
    })
    sentence_CTS = []
    for si in retrieved_sentence_ids:
        row = sentence_cited_dataset[si]  # fetch the row once instead of once per field
        sentence_CTS.append(row["title"] + " ## " + highlight_words(row["text"], highlights))
    output["sentence_CTS"] = sentence_CTS

    retrieved_paragraph_ids = sorted({idx for idx in output["abstract_ids"] if idx >= 0})
    abstracts = []
    for pi in retrieved_paragraph_ids:
        row = paragraph_cited_dataset[pi]
        abstracts.append(row["title"] + " ## " + highlight_words(row["text"], highlights))
    output["abstracts"] = abstracts

    # Random presentation order, stored as indices into `configs`.
    config_indices = list(range(len(configs)))
    random.shuffle(config_indices)
    order_ids[ID] = config_indices

    # Write as UTF-8 explicitly (cited text may contain non-ASCII characters)
    # and declare the charset so browsers decode the page the same way.
    with open(os.path.join(base_dir, ID + ".html"), "w", encoding="utf-8") as f:
        f.write('<meta charset="utf-8">')
        f.write("Span ID: "+ID+"<br>")
        f.write("<h2>Context (Up to 2 sentences before the target citation):</h2>")
        f.write(output["context"])
        f.write("<h2>Randomized System Outputs:</h2>")
        for i, idx in enumerate(config_indices):
            config = configs[idx]
            f.write(str(i+1)+". "+output["predictions"][config]+"<br>")
        f.write("<h2>Cited Abstracts:</h2>")
        for abstract in abstracts:
            f.write(abstract)
        f.write("<h2>Retrieved Body Sentences:</h2>")
        for cts in sentence_CTS:
            f.write(cts)

# Persist the per-example orderings so ratings can be mapped back to systems.
# NOTE(review): the file name says "group2" while `sampled_ids` is the first
# 25 sampled ids — confirm the intended group before re-running.
with open("group2_configs.json","w") as f:
    json.dump(order_ids,f)

In [None]:
# Write an empty rating sheet: one row per (example, displayed position),
# followed by a blank separator row after each example.
rating_columns = ["id","system#","fluency","coherence","relevance","overall"]
empty_scores = [""] * 4
with open("YourName.csv","w") as f:
    f.write(",".join(rating_columns)+"\n")
    for ID in sampled_ids:
        for position in range(1, len(configs) + 1):
            f.write(",".join([ID, str(position)] + empty_scores)+"\n")
        f.write(",".join([""] * 6)+"\n")

In [25]:
# Load the saved orderings for group 1 (presumably example id -> list of
# indices into `configs`, like the file written for group 2 — verify).
# NOTE(review): the next cell rebinds `group_config` from
# "group2_configs.json"; run only the cell for the group being regenerated.
with open("group1_configs.json") as f:
    group_config = json.load(f)

In [26]:
# Load the saved orderings for group 2 (example id -> list of indices into `configs`).
# NOTE(review): this overwrites the `group_config` loaded from
# "group1_configs.json" by the previous cell — confirm which group is intended.
with open("group2_configs.json") as f:
    group_config = json.load(f)

In [None]:
# Regenerate the pages for the examples in `group_config`, reusing each saved
# ordering and rendering the "target" entry in bold italics. The span id is
# written with a trailing "$".
# NOTE(review): the file name is the same one the blind-page cell uses, so this
# overwrites those pages unless `base_dir` was changed — confirm.
os.makedirs(base_dir, exist_ok=True)  # the output folder may not exist yet
for ID, config_indices in tqdm(group_config.items()):
    output = integrated_outputs[ID]

    # Highlight vocabulary: lowercased non-stopword tokens of every entry in
    # output["predictions"] (the target is stored there too), citation marks removed.
    pred = " ".join(output["predictions"].values())
    highlights = set(remove_stop(word_tokenize(remove_citation_marks(pred, output["citations"]))))

    # Union of the non-negative ids retrieved by the three sentence-level runs.
    # Negative ids are skipped (presumably padding — TODO confirm).
    retrieved_sentence_ids = sorted({
        idx
        for name in ["context_ids", "oracle_ids", "keyword_ids"]
        for idx in output[name]
        if idx >= 0
    })
    sentence_CTS = []
    for si in retrieved_sentence_ids:
        row = sentence_cited_dataset[si]  # fetch the row once instead of once per field
        sentence_CTS.append(row["title"] + " ## " + highlight_words(row["text"], highlights))
    output["sentence_CTS"] = sentence_CTS

    retrieved_paragraph_ids = sorted({idx for idx in output["abstract_ids"] if idx >= 0})
    abstracts = []
    for pi in retrieved_paragraph_ids:
        row = paragraph_cited_dataset[pi]
        abstracts.append(row["title"] + " ## " + highlight_words(row["text"], highlights))
    output["abstracts"] = abstracts

    # Write as UTF-8 explicitly (cited text may contain non-ASCII characters)
    # and declare the charset so browsers decode the page the same way.
    with open(os.path.join(base_dir, ID + ".html"), "w", encoding="utf-8") as f:
        f.write('<meta charset="utf-8">')
        f.write("Span ID: "+ID+"$"+"<br>")
        f.write("<h2>Context (Up to 2 sentences before the target citation):</h2>")
        f.write(output["context"])
        f.write("<h2>Randomized System Outputs:</h2>")
        for i, idx in enumerate(config_indices):
            config = configs[idx]
            if config == "target":
                f.write("<b><i>"+str(i+1)+". "+output["predictions"][config]+"</i></b><br>")
            else:
                f.write(str(i+1)+". "+output["predictions"][config]+"<br>")
        f.write("<h2>Cited Abstracts:</h2>")
        for abstract in abstracts:
            f.write(abstract)
        f.write("<h2>Retrieved Body Sentences:</h2>")
        for cts in sentence_CTS:
            f.write(cts)

In [None]:
# Write an empty rating sheet for the target-revealing pages into `base_dir`:
# ids carry the same trailing "$" as on those pages, with one row per
# displayed position and a blank separator row after each example.
rating_columns = ["id","system#","fluency","coherence","relevance","overall"]
empty_scores = [""] * 4
with open(base_dir + "YourName_target.csv","w") as f:
    f.write(",".join(rating_columns)+"\n")
    for ID in group_config:
        for position in range(1, len(configs) + 1):
            f.write(",".join([ID+"$", str(position)] + empty_scores)+"\n")
        f.write(",".join([""] * 6)+"\n")

In [32]:
# Combine the saved orderings of both groups into one mapping; entries from
# group 2 replace group 1 entries that share the same id.
with open("group1_configs.json") as group1_file, open("group2_configs.json") as group2_file:
    group_config = json.load(group1_file)
    group2_config = json.load(group2_file)
group_config.update(group2_config)

In [33]:
# Display the merged orderings (example id -> list of indices into `configs`).
group_config

{'102350797_0_0_1@3144258': [1, 3, 4, 0, 2],
 '128350532_2_0_2@13335042': [3, 4, 0, 2, 1],
 '166228482_2_0_3@8140780': [4, 3, 0, 1, 2],
 '173188413_2_0_0@44130060': [0, 2, 4, 1, 3],
 '173990267_1_0_4@5068376': [0, 4, 3, 1, 2],
 '174798410_0_0_0@21700944': [3, 0, 4, 2, 1],
 '174799580_0_0_4@52068673': [0, 4, 2, 1, 3],
 '184482991_3_0_0@3936688': [4, 1, 2, 0, 3],
 '184483889_2_0_0@4570064@1733167': [0, 4, 3, 2, 1],
 '189761997_0_0_3@7497218': [4, 2, 1, 0, 3],
 '189927896_2_0_1@9192723': [4, 3, 2, 1, 0],
 '195218693_1_0_2@59600034': [2, 3, 4, 0, 1],
 '195504787_0_0_1@21730715': [0, 1, 2, 3, 4],
 '196172757_1_0_2@3792324': [0, 1, 4, 3, 2],
 '196172757_2_0_7@53216389': [1, 4, 0, 3, 2],
 '196180835_0_0_1@3513372@15418780': [2, 3, 0, 1, 4],
 '196182403_1_0_1@67855531': [4, 0, 1, 3, 2],
 '196189186_1_0_4@1238927': [1, 0, 3, 2, 4],
 '196197006_3_0_0@52290656@29151507@59600051': [2, 4, 0, 3, 1],
 '196208296_0_0_1@16538528': [4, 2, 0, 3, 1],
 '197465409_0_0_1@7663461': [0, 3, 2, 1, 4],
 '19818482

In [35]:
# Regenerate the pages for every example in `group_config`, reusing each saved
# ordering and prefixing every output with its configuration name. The span id
# is written with a trailing "$".
# NOTE(review): the file name is the same one the earlier page-generation cells
# use, so this overwrites those pages unless `base_dir` was changed — confirm.
os.makedirs(base_dir, exist_ok=True)  # the output folder may not exist yet
for ID, config_indices in tqdm(group_config.items()):
    output = integrated_outputs[ID]

    # Highlight vocabulary: lowercased non-stopword tokens of every entry in
    # output["predictions"] (the target is stored there too), citation marks removed.
    pred = " ".join(output["predictions"].values())
    highlights = set(remove_stop(word_tokenize(remove_citation_marks(pred, output["citations"]))))

    # Union of the non-negative ids retrieved by the three sentence-level runs.
    # Negative ids are skipped (presumably padding — TODO confirm).
    retrieved_sentence_ids = sorted({
        idx
        for name in ["context_ids", "oracle_ids", "keyword_ids"]
        for idx in output[name]
        if idx >= 0
    })
    sentence_CTS = []
    for si in retrieved_sentence_ids:
        row = sentence_cited_dataset[si]  # fetch the row once instead of once per field
        sentence_CTS.append(row["title"] + " ## " + highlight_words(row["text"], highlights))
    output["sentence_CTS"] = sentence_CTS

    retrieved_paragraph_ids = sorted({idx for idx in output["abstract_ids"] if idx >= 0})
    abstracts = []
    for pi in retrieved_paragraph_ids:
        row = paragraph_cited_dataset[pi]
        abstracts.append(row["title"] + " ## " + highlight_words(row["text"], highlights))
    output["abstracts"] = abstracts

    # Write as UTF-8 explicitly (cited text may contain non-ASCII characters)
    # and declare the charset so browsers decode the page the same way.
    with open(os.path.join(base_dir, ID + ".html"), "w", encoding="utf-8") as f:
        f.write('<meta charset="utf-8">')
        f.write("Span ID: "+ID+"$"+"<br>")
        f.write("<h2>Context (Up to 2 sentences before the target citation):</h2>")
        f.write(output["context"])
        f.write("<h2>Randomized System Outputs:</h2>")
        for i, idx in enumerate(config_indices):
            config = configs[idx]
            f.write(str(i+1)+". "+config+": "+output["predictions"][config]+"<br>")
        f.write("<h2>Cited Abstracts:</h2>")
        for abstract in abstracts:
            f.write(abstract)
        f.write("<h2>Retrieved Body Sentences:</h2>")
        for cts in sentence_CTS:
            f.write(cts)

100%|███████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:01<00:00, 30.61it/s]
