In [3]:
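# Preprocess the CNN/DailyMail (CNNDM) train / validation data: cluster NER entity mentions, attach SpanBERT coreference clusters, and write one .story file per (document, entity) pair.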

import csv
import hashlib
import json
import os
import random
from collections import deque
from typing import Any, Dict, List, Set, Tuple


def group_entities(entities: List[Any]) -> List[Any]:
    """Cluster entity mentions whose surface forms overlap.

    Two distinct (surface, type) pairs are linked when they have the same type
    and one surface string is a substring of the other. Clusters are the
    connected components of that graph, so linking is transitive
    ("New York City" ~ "York" ~ "York Minster" form one cluster).

    Args:
        entities: list of dicts with at least "surface" and "type" keys.
            Repeated (surface, type) pairs collapse into a single node.

    Returns:
        A list of tuples, one per cluster, each holding (surface, type) pairs.
        Clusters are ordered by the first appearance of any of their members in
        ``entities``, and members inside a cluster follow first-appearance
        order, so the result is identical across runs (it does not depend on
        set iteration / string hash randomization).
    """
    # Unique nodes in first-appearance order, with their position for sorting.
    order: Dict[Tuple[str, str], int] = {}
    for entity in entities:
        order.setdefault((entity["surface"], entity["type"]), len(order))
    nodes = list(order)

    # Constructs a graph where a node represents an entity and an edge represents
    # that the connected entities are (heuristically) the same. The substring
    # relation is symmetric, so each unordered pair is compared once.
    adj_list: Dict[Tuple[str, str], Set[Tuple[str, str]]] = {node: set() for node in nodes}
    for i, outer in enumerate(nodes):
        for inner in nodes[i + 1:]:
            if outer[1] == inner[1] and (outer[0] in inner[0] or inner[0] in outer[0]):
                adj_list[outer].add(inner)
                adj_list[inner].add(outer)

    # BFS over connected components. Nodes are marked visited when enqueued so
    # each one is processed exactly once.
    visited: Set[Tuple[str, str]] = set()
    clusters: List[Tuple[Tuple[str, str], ...]] = []
    for start in nodes:
        if start in visited:
            continue
        visited.add(start)
        queue = deque([start])
        component = []
        while queue:
            current = queue.popleft()
            component.append(current)
            for neighbor in adj_list[current]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(neighbor)
        component.sort(key=order.__getitem__)
        clusters.append(tuple(component))
    return clusters


def group_entities_wrapper(entities, sent_mapping):
    """Merge raw entity mentions into clustered entities annotated with sentence ids.

    Args:
        entities: raw mention dicts with "surface", "type" and
            "startCharOffset" keys.
        sent_mapping: list mapping a character offset in the article text to
            the index of the sentence containing that character.

    Returns:
        One dict per cluster produced by ``group_entities`` with keys:
        "sentences" (sorted list of sentence ids where any member mention
        starts), "surface" (member surface strings), "type" (type of the first
        member) and "max_surface" (longest member surface string).
    """
    # Sentence ids per (surface, type) key, built in one pass over the mentions
    # instead of rescanning every mention for every cluster.
    sentences_by_key = {}
    for mention in entities:
        key = (mention["surface"], mention["type"])
        sentences_by_key.setdefault(key, set()).add(sent_mapping[mention["startCharOffset"]])

    final_entities = []
    for cluster in group_entities(entities):
        surfaces = [member[0] for member in cluster]
        sentences = set()
        for member in cluster:
            sentences.update(sentences_by_key.get(member, ()))
        final_entities.append({
            "sentences": sorted(sentences),
            "surface": surfaces,
            "type": cluster[0][1],
            "max_surface": max(surfaces, key=len),
        })
    return final_entities


def extract_entity_sentences(fname):
    """Load an NER JSONL file and replace each document's raw mentions with entity clusters.

    Each non-blank line of ``fname`` is a JSON object with at least
    "article_text", "article_text_sentences" (list of sentence strings) and
    "article_text_entities" (raw mentions as consumed by
    ``group_entities_wrapper``).

    Returns:
        The list of parsed documents, where each document's
        "article_text_entities" has been replaced by the clustered entities.
    """
    with open(fname, encoding="utf-8") as f:
        all_info = [json.loads(line) for line in f if line.strip()]

    init_entities_count, final_entities_count = 0, 0
    for doc_num, task_info in enumerate(all_info):
        # Character offset -> sentence index. Every sentence after the first also
        # owns one extra leading character. NOTE(review): this assumes the
        # sentences were joined with a single separator char to form
        # "article_text" — confirm against the NER producer.
        sent_mapping = []
        for sent_ind, sent in enumerate(task_info["article_text_sentences"]):
            separator = 0 if sent_ind == 0 else 1
            sent_mapping.extend([sent_ind] * (len(sent) + separator))
        assert len(sent_mapping) >= len(task_info["article_text"])

        entities = task_info["article_text_entities"]
        final_entities = group_entities_wrapper(entities, sent_mapping)
        task_info["article_text_entities"] = final_entities

        init_entities_count += len(entities)
        final_entities_count += len(final_entities)
        if doc_num % 5000 == 0:
            print("Entity groups:", doc_num)

    # Guard the percentage against inputs with no entities at all.
    if init_entities_count:
        print(init_entities_count, final_entities_count, final_entities_count * 100.0 / init_entities_count)
    else:
        print(init_entities_count, final_entities_count)
    return all_info


def convert_mention(mention, output, comb_text):
    """Map a subtoken-level mention to a word-level span and its surface text.

    Args:
        mention: sequence whose first two items are the inclusive start and end
            subtoken indices of the mention.
        output: coreference record providing "subtoken_map" (subtoken -> word).
        comb_text: flattened list of subtokens for the whole document.

    Returns:
        ((word_start, word_end_exclusive), text), where text is the mention's
        subtokens joined by spaces with WordPiece " ##" continuations merged.
    """
    first, last = mention[0], mention[1]
    subtoken_map = output['subtoken_map']
    span = (subtoken_map[first], subtoken_map[last] + 1)
    spaced = ' '.join(comb_text[first:last + 1])
    return span, spaced.replace(" ##", "")


def get_spanbert_clusters(fname):
    """Load SpanBERT coreference output and summarize each predicted cluster.

    Each non-blank line of ``fname`` is a JSON object with "doc_id",
    "sentences" (lists of subtokens), "sentence_map" (subtoken -> sentence id),
    "subtoken_map" and "predicted_clusters" (lists of [start, end] inclusive
    subtoken indices).

    Returns:
        dict mapping doc_id -> list of (sentence_ids, mention_texts) pairs, both
        sets, one pair per predicted cluster.

    Raises:
        AssertionError: if a record lacks "predicted_clusters", a mention spans
            two sentences, or a doc_id appears more than once.
    """
    spanbert_clusters = {}
    with open(fname, encoding="utf-8") as f:
        records = (json.loads(line) for line in f if line.strip())
        for doc_num, coref_out in enumerate(records):
            comb_text = [word for sentence in coref_out['sentences'] for word in sentence]
            sent_numbers = coref_out["sentence_map"]

            assert 'predicted_clusters' in coref_out
            clusters = []
            for cluster in coref_out['predicted_clusters']:
                mapped_text, mapped_sents = set(), set()
                for mention in cluster:
                    _, text = convert_mention(mention, coref_out, comb_text)
                    # Each mention must lie within a single sentence.
                    assert sent_numbers[mention[0]] == sent_numbers[mention[1]]
                    mapped_text.add(text)
                    mapped_sents.add(sent_numbers[mention[0]])
                clusters.append((mapped_sents, mapped_text))

            doc_id = coref_out["doc_id"]
            assert doc_id not in spanbert_clusters
            spanbert_clusters[doc_id] = clusters

            if doc_num % 10000 == 0:
                print(doc_num)
    return spanbert_clusters
    
    
def hashhex(s):
    """Return the hexadecimal SHA-1 digest of ``s`` encoded as UTF-8."""
    digest = hashlib.sha1(s.encode('utf-8'))
    return digest.hexdigest()


In [1]:
def _highlight_sentence_ids(entity, article_clusters, max_highlights):
    """Return up to ``max_highlights`` sorted sentence ids that mention ``entity`` or a coreferent."""
    sent_ids = set(entity["sentences"])
    for cluster_sents, cluster_texts in article_clusters:
        # A coref cluster is attached when one of its mention texts contains one
        # of the entity's surface forms (substring match).
        if any(ename in name for name in cluster_texts for ename in entity["surface"]):
            sent_ids.update(cluster_sents)
    return sorted(sent_ids)[:max_highlights]


def generate_data(ner_article, coref_article, out_dir, out_dir_url, split,
                  max_sent_len=400, max_highlights=3, log_every=5000):
    """Write one ``.story`` file per (document, entity) pair, plus a mapping file.

    A document is used only when all of its sentences are shorter than
    ``max_sent_len`` characters. Within a used document, every entity cluster
    whose "max_surface" is not "CNN" gets a story file in ``out_dir`` that
    contains:

    * a header line ``"<surface> | <surface> ... =>"``,
    * every sentence of the article, and
    * up to ``max_highlights`` ``@highlight`` sentences, lowest ids first. These
      are the sentences mentioning the entity plus the sentences of any
      coreference cluster with a mention containing one of the entity's
      surface forms.

    The plain story name ``"<doc index>-<entity index>.<split>.story"`` is
    appended to ``out_dir_url``. The file on disk is named by its SHA-1 hex
    digest.

    Args:
        ner_article: NER JSONL path, read by ``extract_entity_sentences``.
        coref_article: coreference JSONL path, read by ``get_spanbert_clusters``.
        out_dir: existing directory that receives the ``.story`` files.
        out_dir_url: path of the mapping file listing the plain story names.
        split: split name embedded in every story name (e.g. "train").
        max_sent_len: documents with a sentence of this length or longer are skipped.
        max_highlights: maximum number of ``@highlight`` sentences per story.
        log_every: print a sample story every ``log_every`` documents.
    """
    all_info = extract_entity_sentences(ner_article)

    all_article_clusters = get_spanbert_clusters(coref_article)
    print("Coreferenced documents: ", len(all_article_clusters))
    print("Input file size: ", len(all_info))

    with open(out_dir_url, 'w', encoding="utf-8") as fpfiles:
        done = set()
        for ind, task_info in enumerate(all_info):
            # NOTE(review): coref clusters are keyed by the raw JSON "doc_id";
            # this lookup assumes those keys are ints as well — confirm.
            doc_id = int(task_info["doc_id"])
            entities = task_info['article_text_entities']
            doc_sents = task_info["article_text_sentences"]
            article_clusters = all_article_clusters.get(doc_id, [])

            if not all(len(sent) < max_sent_len for sent in doc_sents):
                continue

            # (surface forms, highlight ids) of the last story written for this
            # document. Stays None when nothing was written, so logging never
            # touches unbound or stale variables from another document.
            last_written = None
            for eind, entity in enumerate(entities):
                if entity["max_surface"] == "CNN":
                    continue
                article_sent_ids = _highlight_sentence_ids(entity, article_clusters, max_highlights)

                story_fname = str(ind) + "-" + str(eind) + "." + split + ".story"
                story_fname_hexname = hashhex(story_fname)
                fpfiles.write(story_fname + "\n")
                assert story_fname_hexname not in done
                done.add(story_fname_hexname)

                story_path = os.path.join(out_dir, story_fname_hexname + ".story")
                with open(story_path, "w", encoding="utf-8") as fp:
                    fp.write(" | ".join(entity["surface"]) + " =>\n\n")

                    for sent in doc_sents:
                        fp.write(sent.strip() + "\n\n")

                    for sent_id in article_sent_ids:
                        fp.write("@highlight\n\n")
                        fp.write(doc_sents[sent_id].strip() + "\n\n")

                last_written = (entity["surface"], article_sent_ids)

            if ind % log_every == 0 and last_written is not None:
                surface, sent_ids = last_written
                print(ind)
                print(surface)
                print([(doc_sents[sent_id], len(doc_sents[sent_id])) for sent_id in sent_ids])
                print()
                

if __name__ == "__main__":
    # Every split writes its story files into the same raw_data directory; only
    # the mapping file and the split tag embedded in story names differ.
    # NOTE(review): the "test" split is generated from the validation inputs,
    # exactly like "valid" — confirm this is intended.
    RAW_DATA_DIR = "bertsum/lead3_sents/raw_data"
    SPLIT_JOBS = [
        # (NER input, coref input, mapping file, split tag)
        ("ner/val.jsonl", "coref/output/val_article.jsonl",
         "bertsum/lead3_sents/urls/mapping_valid.txt", "valid"),
        ("ner/val.jsonl", "coref/output/val_article.jsonl",
         "bertsum/lead3_sents/urls/mapping_test.txt", "test"),
        ("ner/train.jsonl", "coref/output/train_article.jsonl",
         "bertsum/lead3_sents/urls/mapping_train.txt", "train"),
    ]
    for ner_path, coref_path, mapping_path, split_tag in SPLIT_JOBS:
        generate_data(ner_path, coref_path, RAW_DATA_DIR, mapping_path, split_tag)
    