# In [4]:
"""Contains code to generate training data for GSUM"""

import csv
import json
import random
import argparse
from fuzzywuzzy import fuzz
from typing import Any, Dict, List, Set, Tuple


def group_entities(entities: List[Any]) -> List[Any]:
    """Cluster entity mentions that appear to name the same thing.

    Every distinct ``(surface, type)`` pair is a graph node. Two nodes are
    linked when they share a type and one surface is a substring of the
    other; each connected component of that graph becomes one cluster.

    Args:
        entities: Dicts that each provide a ``"surface"`` and a ``"type"`` key.

    Returns:
        A list of clusters, each a tuple of ``(surface, type)`` pairs.
        Clusters are ordered by the first appearance of their first member in
        ``entities`` and members are listed in breadth-first discovery order,
        so the result does not depend on set/hash iteration order.
    """
    # A dict keeps first-seen order and collapses repeated mentions into one node.
    adj_list: Dict[Tuple[str, str], List[Tuple[str, str]]] = {
        (entity["surface"], entity["type"]): [] for entity in entities
    }
    nodes = list(adj_list)

    # The substring relation is symmetric, so each unordered pair is tested once
    # and the edge is recorded on both endpoints.
    for idx, outer in enumerate(nodes):
        for inner in nodes[idx + 1:]:
            if outer[1] == inner[1] and (outer[0] in inner[0] or inner[0] in outer[0]):
                adj_list[outer].append(inner)
                adj_list[inner].append(outer)

    # BFS to get groups of entities. `members` doubles as the queue: the loop
    # keeps consuming items that are appended while it runs.
    visited: Set[Tuple[str, str]] = set()
    clusters = []
    for start in nodes:
        if start in visited:
            continue
        visited.add(start)
        members = [start]
        for current in members:
            for neighbor in adj_list[current]:
                if neighbor not in visited:
                    # Mark on discovery so a node is never queued twice.
                    visited.add(neighbor)
                    members.append(neighbor)
        clusters.append(tuple(members))
    return clusters


def group_entities_wrapper(entities, sent_mapping):
    """Merge raw entity mentions into clusters annotated with sentence ids.

    Args:
        entities: Mention dicts. This function reads the ``"surface"``,
            ``"type"`` and ``"startCharOffset"`` keys of each one.
        sent_mapping: Sequence indexed by ``startCharOffset`` whose values are
            the sentence ids recorded for the mention.

    Returns:
        One dict per cluster produced by ``group_entities`` with the keys
        ``"sentences"`` (sorted, de-duplicated sentence ids of all mentions in
        the cluster), ``"surface"`` (list of the cluster's surface strings),
        ``"type"`` (type of the cluster's first member) and ``"max_surface"``
        (the longest surface string).
    """
    # Index sentence ids by mention key in a single pass so that each cluster
    # can be assembled without rescanning the full mention list.
    sentences_by_key = {}
    for org_entity in entities:
        key = (org_entity["surface"], org_entity["type"])
        sentence_id = sent_mapping[org_entity["startCharOffset"]]
        sentences_by_key.setdefault(key, set()).add(sentence_id)

    final_entities = []
    for cluster in group_entities(entities):
        surfaces = [member[0] for member in cluster]
        sentences = set()
        for member in cluster:
            # Cluster members are built from the same (surface, type) keys,
            # so every member has an entry in the index.
            sentences.update(sentences_by_key[member])
        final_entities.append({
            "sentences": sorted(sentences),
            "surface": surfaces,
            "type": cluster[0][1],
            "max_surface": max(surfaces, key=len),
        })
    return final_entities


def extract_entity_sentences(fname):
    """Load NER records and replace each record's entities with grouped entities.

    Args:
        fname: Path to a JSON-lines file. Each record is expected to provide
            ``"article_text_sentences"``, ``"article_text"`` and
            ``"article_text_entities"``. Blank lines are ignored.

    Returns:
        The list of decoded records, where ``"article_text_entities"`` has been
        replaced by the output of ``group_entities_wrapper``.
    """
    # Close the file deterministically instead of relying on garbage collection.
    with open(fname) as fp:
        all_info = [json.loads(line.strip()) for line in fp if line.strip()]

    init_entities_count, final_entities_count = 0, 0
    for extract, task_info in enumerate(all_info):

        # Build a per-character lookup from offset to sentence index. Every
        # sentence after the first gets one extra slot.
        # NOTE(review): presumably that slot is the separator joining sentences
        # in "article_text" -- confirm against the data producer.
        sent_mapping = []
        for ind, sent in enumerate(task_info["article_text_sentences"]):
            span = len(sent) if ind == 0 else len(sent) + 1
            sent_mapping.extend([ind] * span)

        assert len(sent_mapping) >= len(task_info["article_text"])

        entities = task_info["article_text_entities"]
        final_entities = group_entities_wrapper(entities, sent_mapping)
        task_info["article_text_entities"] = final_entities

        init_entities_count += len(entities)
        final_entities_count += len(final_entities)
        if extract % 5000 == 0:
            print("Entity groups:", extract)

    # Guard the ratio: an empty file (or records without entities) would
    # otherwise raise ZeroDivisionError after all the work is done.
    if init_entities_count:
        percent_kept = final_entities_count * 100.0 / init_entities_count
    else:
        percent_kept = 0.0
    print(init_entities_count, final_entities_count, percent_kept)
    return all_info


def convert_mention(mention, output, comb_text):
    """Translate a subtoken mention span into a mapped span plus its text.

    Args:
        mention: Indexable whose first two items are the inclusive first and
            last subtoken positions of the mention.
        output: Mapping holding a ``'subtoken_map'`` sequence indexed by
            subtoken position (presumably word indices -- confirm upstream).
        comb_text: Flat sequence of subtoken strings.

    Returns:
        A pair ``((start, end), text)``: ``start`` is the mapped value of the
        first subtoken, ``end`` is the mapped value of the last subtoken plus
        one, and ``text`` is the mention's subtokens joined back together.
    """
    first, last = mention[0], mention[1]
    subtoken_map = output['subtoken_map']
    mapped_span = (subtoken_map[first], subtoken_map[last] + 1)

    # Join the pieces with spaces, then drop every " ##" so that continuation
    # pieces are glued onto the piece before them.
    pieces = comb_text[first:last + 1]
    surface = ' '.join(pieces).replace(" ##", "")
    return (mapped_span, surface)


def get_spanbert_clusters(fname):
    """Read coreference predictions and summarise each cluster per document.

    Args:
        fname: Path to a JSON-lines file. Each record is expected to provide
            ``"doc_id"``, ``"sentences"``, ``"sentence_map"``,
            ``"subtoken_map"`` (read by ``convert_mention``) and
            ``"predicted_clusters"``. Blank lines are ignored.

    Returns:
        Dict mapping each record's ``"doc_id"`` (as stored in the file) to a
        list with one ``(sentence_ids, mention_texts)`` pair of sets per
        predicted cluster.
    """
    spanbert_clusters = {}
    # Close the file deterministically instead of relying on garbage collection.
    with open(fname) as fp:
        for raw_line in fp:
            line = raw_line.strip()
            if not line:
                # A trailing or stray blank line is not a record.
                continue
            coref_out = json.loads(line)

            comb_text = [word for sentence in coref_out['sentences'] for word in sentence]
            sent_numbers = coref_out["sentence_map"]

            clusters = []
            assert 'predicted_clusters' in coref_out
            for cluster in coref_out['predicted_clusters']:
                mapped_text, mapped_sents = set(), set()
                for mention in cluster:
                    _, text = convert_mention(mention, coref_out, comb_text)
                    # A mention is expected to start and end in the same sentence.
                    assert sent_numbers[mention[0]] == sent_numbers[mention[1]]
                    mapped_text.add(text)
                    mapped_sents.add(sent_numbers[mention[0]])

                clusters.append((mapped_sents, mapped_text))

            # The result dict itself tracks which documents were already seen.
            doc_id = coref_out["doc_id"]
            assert doc_id not in spanbert_clusters
            spanbert_clusters[doc_id] = clusters

    return spanbert_clusters
    


# In [1]:
"""Generate CNNDM / NYT data for GSUM experiments that involved training one entity at a time."""


def _coref_sentence_ids(clusters, surfaces):
    """Return the sentence ids of every cluster with a mention containing one of ``surfaces``."""
    sent_ids = set()
    for cluster_sents, cluster_names in clusters:
        if any(ename in name for name in cluster_names for ename in surfaces):
            sent_ids.update(cluster_sents)
    return sent_ids


def generate_data(ner_article, ner_abstract_fname, out_dir, coref_article, coref_abstract, guidance):
    """Write one (source, guidance, target) example per document entity.

    For every grouped entity of every document (entities whose longest surface
    is exactly "CNN" are skipped) the target is built from up to three
    reference sentences that mention the entity, topped up with article
    sentences mentioning the entity that are sufficiently different from the
    target collected so far. Examples whose target contains " . " are dropped.

    Args:
        ner_article: JSON-lines NER file passed to ``extract_entity_sentences``.
        ner_abstract_fname: Accepted for interface compatibility; not read.
        out_dir: Output path prefix; ``.source``, ``.target`` and ``.z`` are
            appended to it to form the three output files.
        coref_article: Coreference JSON-lines file for the articles.
        coref_abstract: Coreference JSON-lines file for the abstracts.
        guidance: ``"entity"`` writes the entity surfaces joined by " | " to
            the ``.z`` file, ``"lead3"`` writes the first three article
            sentences that mention the entity, and any other value writes all
            article sentences that mention the entity.
    """
    all_info = extract_entity_sentences(ner_article)
    random.shuffle(all_info)

    all_article_clusters = get_spanbert_clusters(coref_article)
    all_abstract_clusters = get_spanbert_clusters(coref_abstract)
    print("Coreferenced documents: ", len(all_article_clusters), len(all_abstract_clusters))
    print("Input file size: ", len(all_info))

    with open(out_dir + ".source", 'w') as fps, \
            open(out_dir + ".target", 'w') as fpd, \
            open(out_dir + ".z", 'w') as fpz:
        done = set()
        abstract_num = 0
        for ind, task_info in enumerate(all_info):
            doc_id = int(task_info["doc_id"])
            assert doc_id not in done
            done.add(doc_id)

            entities = task_info['article_text_entities']
            doc_sents = task_info['article_text_sentences']
            abstract_sents = task_info['abstract_text_sentences']
            # NOTE(review): the lookup uses an int doc_id; this only matches if
            # the coreference files also store "doc_id" as an int -- confirm.
            article_clusters = all_article_clusters.get(doc_id, [])
            abstract_clusters = all_abstract_clusters.get(doc_id, [])

            # Remembered only for the progress print below; reset per document
            # so the print never shows values from an earlier document.
            last_example = None

            for entity in entities:
                if entity["max_surface"] == "CNN":
                    continue
                surfaces = entity["surface"]

                # Sentences in the reference containing the entity and its coreferences.
                abstract_ids = _coref_sentence_ids(abstract_clusters, surfaces)
                for sent_id, sent in enumerate(abstract_sents):
                    # Plain string matching is only applied to surfaces longer than 5 characters.
                    if any(len(surface) > 5 and surface in sent for surface in surfaces):
                        abstract_ids.add(sent_id)
                abstract_sent_ids = sorted(abstract_ids)

                if abstract_sent_ids:
                    abstract_num += 1

                # Sentences in the source document containing the entity and its coreferences.
                article_sent_ids = sorted(
                    set(entity["sentences"]) | _coref_sentence_ids(article_clusters, surfaces)
                )

                # Take up to 3 reference sentences that mention the entity, then
                # fill the remaining slots with article sentences mentioning the
                # entity that are not too similar to the target built so far.
                final_sents = [abstract_sents[idx] for idx in abstract_sent_ids[:3]]
                for idx in article_sent_ids:
                    if len(final_sents) >= 3:
                        break
                    if fuzz.ratio(" ".join(final_sents), doc_sents[idx]) < 60:
                        final_sents.append(doc_sents[idx])

                assert 0 < len(final_sents) <= 3
                last_example = (entity, final_sents, abstract_sent_ids)

                final_sents_text = " ".join(final_sents)
                if final_sents and " . " not in final_sents_text:
                    fps.write(" ".join(doc_sents) + "\n")
                    if guidance == "entity":
                        fpz.write(" | ".join(surfaces) + "\n")
                    elif guidance == "lead3":
                        fpz.write(" ".join([doc_sents[idx] for idx in article_sent_ids[:3]]) + "\n")
                    else:
                        fpz.write(" ".join([doc_sents[idx] for idx in article_sent_ids]) + "\n")
                    fpd.write(final_sents_text + "\n")

            if ind % 5000 == 0:
                print(ind)
                # A document may have no usable entity; only show an example
                # when one was actually processed for this document.
                if last_example is not None:
                    shown_entity, shown_sents, shown_abstract_ids = last_example
                    print(shown_entity["surface"])
                    print(shown_sents)
                    if shown_abstract_ids:
                        # These ids index the reference sentences, not the article.
                        print(abstract_sents[shown_abstract_ids[0]])
                print()

        print(abstract_num)


if __name__ == "__main__":
    GUIDANCE = "entity"
    TYPE = "one_entity_name_guidance"

    # Each pair is (output split name, input split name).
    # NOTE(review): the "test" output is generated from the "val" inputs --
    # confirm this is intended rather than a copy-paste slip.
    for out_split, in_split in (("train", "train"), ("val", "val"), ("test", "val")):
        ner_path = "cnndm/ner/" + in_split + ".jsonl"
        generate_data(
            ner_path,
            ner_path,
            "cnndm/gsum/" + TYPE + "/raw/" + out_split,
            "cnndm/coref/output/" + in_split + "_article.jsonl",
            "cnndm/coref/output/" + in_split + "_abstract.jsonl",
            GUIDANCE,
        )