In [11]:
from sumtool.storage import get_summaries
from datasets import load_dataset

In [13]:
def load_data(dataset: str, split: str, model_summaries: str):
    data = load_dataset(dataset)[split]

    summaries_by_id = get_summaries(dataset, model_summaries)
    source_docs_by_id = {doc["id"]: doc["document"] for doc in data}

    data = []
    for doc_id, summary in summaries_by_id.items():
        data.append({
            "document": source_docs_by_id[doc_id],
            "summary": summary["summary"],
            "id": doc_id
        })

    return data

# Pair the stored "facebook-bart-large-xsum" summaries with the documents of
# the "xsum" test split; the cell displays how many pairs were built.
data = load_data(
    dataset="xsum",
    split="test",
    model_summaries="facebook-bart-large-xsum",
)
len(data)

Using custom data configuration default
Reusing dataset xsum (/Users/anton164/.cache/huggingface/datasets/xsum/default/1.2.0/32c23220eadddb1149b16ed2e9430a05293768cfffbdfd151058697d4c11f934)
100%|██████████| 3/3 [00:00<00:00, 130.31it/s]


11334

In [None]:
%%capture
import stanza
from stanza import Document
# stanza pipeline for language code 'en'; later cells call it as `nlp(text)`
# on both the source documents and the generated summaries.
nlp = stanza.Pipeline('en')

In [35]:
# Window of `data` handled by this run: 500 consecutive rows starting at
# start_idx. start_idx and end_idx are reused by the export cell to name the
# output file, so they describe the requested window.
start_idx = 4500
end_idx = start_idx + 500
subset = data[start_idx:end_idx]
# Display the window size; slicing yields fewer rows if it runs past the end of `data`.
len(subset)

500

In [36]:
from preprocessing.make_entity_perturbations import make_perturbations
from tqdm import tqdm

# Build the inference set: for every row in `subset`, run the stanza pipeline
# over the source document and the generated summary, then hand both entity
# lists to make_perturbations to obtain perturbed variants of the summary.
test_data = []
failed_ids = []  # ids of rows that raised during processing, kept for inspection

for row in tqdm(subset):
    try:
        source = row["document"]
        gen_summary = row["summary"]

        src_doc = nlp(source)
        src_doc.build_ents()

        tgt_doc = nlp(gen_summary)
        tgt_doc.build_ents()

        neg_examples, changed_list = make_perturbations(target_text=tgt_doc._text,
                                                        target_ents=tgt_doc.ents,
                                                        source_ents=src_doc.ents,
                                                        is_training_mode=False,
                                                        max_perturbation_per_example=10)
    except Exception as exc:
        # Best-effort: one bad row must not abort the whole run. Catching
        # Exception (rather than a bare `except:`) lets KeyboardInterrupt and
        # SystemExit propagate, so the loop can still be interrupted.
        failed_ids.append(row["id"])
        print("Failed to process", row["id"], "-", repr(exc))
    else:
        # Only rows that were processed end to end are exported.
        test_data.append({
            "source_text": source,
            "positive_examples": [gen_summary],
            "negative_examples": neg_examples,
            "changed_list": changed_list
        })

100%|██████████| 500/500 [39:28<00:00,  4.74s/it]  


In [37]:
import json
# Export the processed rows as JSON Lines (one JSON object per line); the file
# name records the start_idx/end_idx window this run covered.
output_path = f"data/inference/test.bart.{start_idx}-{end_idx}.jsonl"
with open(output_path, "w") as jsonl_file:
    for row in test_data:
        print(json.dumps(row), file=jsonl_file)