- This notebook is the second step of the pipeline: it loads the processed files and converts them into a task-specific input format.
- The output targets the span prediction model (PURE).

In [1]:
import os
import re
import json
from collections import defaultdict, OrderedDict
import statistics

from glob import glob
import ssplit
from tokenization_bert import BasicTokenizer

## Main part - Create two types of datasets for Span Detection Task

In [4]:
# Root folder of the processed brat files produced by the first step.
# Swap in the commented-out path to use the variant with all 15 entity types.
processed_file_path = "../data/brat/processed_files"
# processed_file_path = "../data/brat/processed_files_full_entities"

# One sub-folder per section: Title+abstract and Results
title_abs_path, results_path = (
    os.path.join(processed_file_path, section) for section in ("title_abs", "results")
)

In [5]:
# Collect the .ann / .txt file names of each section, sorted for a stable order
section_dirs = {"Title-abs": title_abs_path, "Results": results_path}
ann_files = {section: [] for section in section_dirs}
txt_files = {section: [] for section in section_dirs}

for section, directory in section_dirs.items():
    for file_name in os.listdir(directory):
        if file_name.endswith(".ann"):
            ann_files[section].append(file_name)
        elif file_name.endswith(".txt"):
            txt_files[section].append(file_name)
    ann_files[section].sort()
    txt_files[section].sort()

print(len(ann_files["Title-abs"]))
print(len(txt_files["Title-abs"]))

165
165


In [19]:
# Shared tokenizer from `tokenization_bert` (do_lower_case=False — presumably keeps the
# original casing; confirm against that module). Used during preprocessing to split each
# whitespace-separated word into sub-tokens.
BASIC_TOKENIZER = BasicTokenizer(do_lower_case=False)

def generate_sentence_boundaries(doc):
    """Return the (start, end) character offsets of the non-empty sentences of `doc`.

    Candidate boundaries come from `ssplit.regex_sentence_boundary_gen`.
    Whitespace-only segments are dropped. Leading and trailing space characters
    (" " only) are trimmed from every kept span; `end` is exclusive.
    """
    spans = []
    for begin, finish in ssplit.regex_sentence_boundary_gen(doc):
        if not doc[begin:finish].strip():
            continue  # whitespace-only segment, e.g. an empty line
        while doc[begin] == " ":
            begin += 1
        while doc[finish - 1] == " ":
            finish -= 1
        assert begin < finish
        spans.append((begin, finish))
    return spans

def norm_path(*paths):
    """Join `paths` onto the current directory, normalise, and return it relative to the CWD."""
    absolute = os.path.normpath(os.path.join(os.getcwd(), *paths))
    return os.path.relpath(absolute)

def make_dirs(*paths):
    """Create the directory named by `paths` (and its parents); existing directories are fine."""
    target = norm_path(*paths)
    os.makedirs(target, exist_ok=True)

def read_file(filename):
    """Return the raw bytes of `filename` (path resolved through `norm_path`)."""
    path = norm_path(filename)
    with open(path, "rb") as handle:
        data = handle.read()
    return data

def read_text(filename, encoding=None):
    """Return the full content of text file `filename`, decoded with `encoding`."""
    path = norm_path(filename)
    with open(path, "r", encoding=encoding) as handle:
        content = handle.read()
    return content

def write_text(text, filename, encoding="UTF-8"):
    """Write `text` to `filename`, creating its parent directory first."""
    parent_dir = os.path.dirname(filename)
    make_dirs(parent_dir)
    target = norm_path(filename)
    with open(target, "w", encoding=encoding) as handle:
        handle.write(text)

def read_lines(filename, encoding=None):
    """Lazily yield the lines of `filename` with trailing CR/LF/VT characters removed."""
    with open(norm_path(filename), "r", encoding=encoding) as handle:
        yield from (raw.rstrip("\r\n\v") for raw in handle)

def write_lines(lines, filename, linesep="\n", encoding="UTF-8"):
    """Write each string in `lines` followed by `linesep`, creating the parent directory first."""
    make_dirs(os.path.dirname(filename))
    with open(norm_path(filename), "w", encoding=encoding) as handle:
        handle.writelines(entry + linesep for entry in lines)

def read_json(filename, encoding=None):
    """Parse `filename` as JSON (content read through `read_text`)."""
    raw = read_text(filename, encoding=encoding)
    return json.loads(raw)


def write_json(obj, filename, indent=2, encoding="UTF-8"):
    """Serialise `obj` as indented JSON (non-ASCII kept as-is) and write it via `write_text`."""
    payload = json.dumps(obj, indent=indent, ensure_ascii=False)
    write_text(payload, filename, encoding=encoding)
    
def extend_offset(offset, doc, reverse=False):
    """Move `offset` across an adjacent run of letters/digits in `doc` (underscore excluded).

    By default it walks left, stopping at the first position whose preceding
    character is not alphanumeric. With ``reverse=True`` it walks right, stopping
    at the first non-alphanumeric character or the end of `doc`.
    """
    def is_alnum(char):
        return re.match(r"[^\W_]", char) is not None

    if not reverse:
        while offset > 0 and is_alnum(doc[offset - 1]):
            offset -= 1
        return offset

    while offset < len(doc) and is_alnum(doc[offset]):
        offset += 1
    return offset

def parse_standoff_file(standoff_file, text_file, encoding=None):
    """Parse a brat standoff (.ann) file together with its source text.

    Handled annotation lines:
      * ``T`` - text-bound entities (trigger annotations whose id starts with
        "T" match too). Discontinuous spans (``;``-separated offsets) are
        reported and skipped.
      * ``E`` - events, stored in ``events`` with ``trigger_type``,
        ``trigger_id`` and a list of ``{"role", "id"}`` arguments.
      * ``A`` - expected as ``A<id> <type> <target-id>``, stored in ``modalities``.
      * ``N`` - stored verbatim (id and value) in ``attributes``.
    Blank lines are ignored. Any other line is reported as unexpected.
    Entities sharing an identical span, and events with identical arguments,
    are printed as a warning but still kept.

    Args:
        standoff_file: path to the ``.ann`` file.
        text_file: path to the matching ``.txt`` file.
        encoding: encoding used to read both files.

    Returns:
        ``(reference, entities, relations, events, modalities, attributes)``.
        ``reference`` is the full text. The rest are ``OrderedDict`` objects keyed
        by annotation id. ``relations`` is always empty; it is kept so callers can
        unpack six values.

    Raises:
        AssertionError: if either file does not exist.
        ValueError: if a line does not have the expected number of fields or
            offsets are not integers.
    """
    assert os.path.exists(standoff_file), "Standoff file not found: " + standoff_file
    assert os.path.exists(text_file), "Text file not found: " + text_file

    entities = OrderedDict()
    relations = OrderedDict()  # never populated: "E" lines are stored in `events`
    events = OrderedDict()
    modalities = OrderedDict()
    attributes = OrderedDict()

    # Look-up tables used only to report duplicates (same span / same arguments)
    # without rescanning every previously parsed annotation.
    entities_by_span = defaultdict(list)
    events_by_args = defaultdict(list)

    # Full source text; entity surface strings are sliced from it
    reference = read_text(text_file, encoding=encoding)

    for line in read_lines(standoff_file, encoding=encoding):
        # Trim surrounding whitespace
        line = line.strip()
        if not line:
            continue  # blank lines carry no annotation

        if line.startswith("T"):  # Entities (T); triggers (TR) also match this prefix
            entity_id, entity_annotation, _entity_text = line.split("\t")
            entity_id = entity_id.strip()

            annotation_elements = entity_annotation.strip().split(";")
            entity_type, *first_offset_pair = annotation_elements[0].split()

            # More than one ";"-separated offset pair means a discontinuous span
            if len(annotation_elements) > 1:
                print(
                    "## Discontinuous entity found (will be excluded for this task): {} in {}".format(
                        entity_id, standoff_file
                    )
                )
                continue

            start_offset, end_offset = (int(offset) for offset in first_offset_pair)
            entity = {
                "id": entity_id,
                "type": entity_type,
                "start": start_offset,
                "end": end_offset,
                "ref": reference[start_offset:end_offset],
            }

            span = (start_offset, end_offset)
            for duplicate in entities_by_span[span]:
                print(standoff_file)
                print(duplicate)
                print(entity)
            entities_by_span[span].append(entity)
            entities[entity_id] = entity

        elif line.startswith("E"):  # Events: "<type>:<trigger_id> <role>:<arg_id> ..."
            event_id, event_annotation = line.split("\t")
            event_id = event_id.strip()

            trigger, *raw_args = event_annotation.strip().split()
            trigger_type, trigger_id = trigger.split(":")
            args = [
                {"role": arg_role, "id": arg_id}
                for arg_role, arg_id in (arg.split(":") for arg in raw_args)
            ]
            event = {
                "id": event_id,
                "trigger_type": trigger_type,
                "trigger_id": trigger_id,
                "args": args,
            }

            args_key = tuple((arg["role"], arg["id"]) for arg in args)
            for duplicate in events_by_args[args_key]:
                print(duplicate)
                print(event)
            events_by_args[args_key].append(event)
            events[event_id] = event

        elif line.startswith("A"):
            modal_id, modal_type, reference_id = line.split()
            modalities[modal_id] = {
                "id": modal_id,
                "type": modal_type,
                "reference_ids": reference_id,
            }
        elif line.startswith("N"):
            attribute_id, attribute_value = line.split("\t", 1)
            attributes[attribute_id] = {
                "id": attribute_id.strip(),
                "value": attribute_value.strip(),
            }
        else:
            print(
                "## Unexpected annotation found: {} in {}".format(line, standoff_file)
            )

    return reference, entities, relations, events, modalities, attributes

In [None]:
# Gather the entity / trigger type inventories from the title+abstract annotations.
# Entity types that are not all-uppercase count as entity types; events contribute
# their trigger types.
defined_types = defaultdict(set)

for txt_path in glob(os.path.join(title_abs_path, "**/*.txt"), recursive=True):
    ann_path = os.path.splitext(txt_path)[0] + ".ann"
    if not os.path.exists(ann_path):
        continue
    _, entities, _, events, _, _ = parse_standoff_file(
        ann_path, txt_path, encoding="UTF-8"
    )
    defined_types["entity_types"].update(
        ent["type"] for ent in entities.values() if not ent["type"].isupper()
    )
    defined_types["trigger_types"].update(ev["trigger_type"] for ev in events.values())

In [21]:
def _merge_sentence_spans(original_doc, entities):
    """Split `original_doc` into sentence spans and merge sentences that an entity crosses.

    Args:
        original_doc: raw document text.
        entities: mapping of entity dicts with character "start"/"end" offsets.

    Returns:
        list of {"start", "end"} character spans (end exclusive), in document order.
    """
    sentence_spans = []
    char_to_line = {}
    for line_idx, (start, end) in enumerate(generate_sentence_boundaries(original_doc)):
        sentence_spans.append({"start": start, "end": end})
        # `end` itself is included so an exclusive entity end still resolves to a line
        for offset in range(start, end + 1):
            char_to_line[offset] = line_idx

    # For every entity that crosses a sentence boundary, record on its first
    # sentence the furthest sentence index it reaches.
    for entity in entities.values():
        entity_start, entity_end = entity["start"], entity["end"]
        while original_doc[entity_start] == " ":
            entity_start += 1
        while original_doc[entity_end - 1] == " ":
            entity_end -= 1

        left_line = char_to_line[entity_start]
        right_line = char_to_line[entity_end]
        if left_line != right_line:
            first_span = sentence_spans[min(left_line, right_line)]
            first_span["broken"] = max(first_span.get("broken", -1), left_line, right_line)

    # Follow the "broken" links to stretch a span over every sentence it must absorb
    merged_spans = []
    idx = 0
    while idx < len(sentence_spans):
        start, end = sentence_spans[idx]["start"], sentence_spans[idx]["end"]
        while idx < len(sentence_spans) and "broken" in sentence_spans[idx]:
            idx = sentence_spans[idx]["broken"]
            end = sentence_spans[idx]["end"]
        merged_spans.append({"start": start, "end": end})
        idx += 1
    return merged_spans


def _tokenize_sentences(sentences):
    """Whitespace-split each sentence, then run `BASIC_TOKENIZER` on every word.

    Returns one list of sub-token strings per sentence.
    """
    return [
        [subtoken for word in sentence.split() for subtoken in BASIC_TOKENIZER.tokenize(word)]
        for sentence in sentences
    ]


def _build_offset_maps(original_text, normalized_text):
    """Align two texts character by character, treating " " as skippable on either side.

    Both inputs are expected to have every whitespace character replaced by " ".

    Returns:
        (offset_map, inverse_offset_map): original -> normalised and
        normalised -> original positions. Each non-empty map also gets one
        sentinel entry just past its last key.

    Raises:
        ValueError: if the texts differ in a non-space character. Without this
            check the alignment loop would never advance.
    """
    offset_map = {}
    inverse_offset_map = {}
    orig_pos = norm_pos = 0

    while orig_pos < len(original_text) and norm_pos < len(normalized_text):
        orig_char = original_text[orig_pos]
        norm_char = normalized_text[norm_pos]
        if orig_char == norm_char:
            offset_map[orig_pos] = norm_pos
            inverse_offset_map[norm_pos] = orig_pos
            orig_pos += 1
            norm_pos += 1
        elif orig_char == " ":
            offset_map[orig_pos] = norm_pos
            orig_pos += 1
        elif norm_char == " ":
            inverse_offset_map[norm_pos] = orig_pos
            norm_pos += 1
        else:
            raise ValueError(
                "Cannot align original and normalised text at positions "
                f"{orig_pos}/{norm_pos}: {orig_char!r} vs {norm_char!r}"
            )

    if offset_map:
        offset_map[max(offset_map) + 1] = max(offset_map.values()) + 1
    if inverse_offset_map:
        inverse_offset_map[max(inverse_offset_map) + 1] = max(inverse_offset_map.values()) + 1
    return offset_map, inverse_offset_map


def _build_char2token_map(text, tokens):
    """Map each non-space character position of `text` to the index of its token.

    Characters are accumulated greedily until they spell the current token in
    the flat `tokens` sequence. Mapping stops once every token has been consumed.
    """
    char2token_map = {}
    if not tokens:
        return char2token_map

    token_idx = 0
    token_cache = ""
    for pos, char in enumerate(text):
        if char == " ":
            continue
        char2token_map[pos] = token_idx
        token_cache += char
        if token_cache == tokens[token_idx]:
            token_idx += 1
            if token_idx == len(tokens):
                break
            token_cache = ""
    return char2token_map


def _sentence_token_bounds(sentence_spans, char2token_map):
    """Convert character sentence spans to token spans {"start", "end"} (end exclusive)."""
    token_bounds = []
    for span in sentence_spans:
        char_start = span["start"]
        # The character span end is exclusive: begin at the last character inside it
        char_end = span["end"] - 1
        # Bounded searches: raise KeyError below instead of looping forever
        while char_start < span["end"] and char_start not in char2token_map:
            char_start += 1
        while char_end > span["start"] and char_end not in char2token_map:
            char_end -= 1
        token_bounds.append(
            {"start": char2token_map[char_start], "end": char2token_map[char_end] + 1}
        )
    return token_bounds


def _token_span(annotation, text, char2token_map):
    """Inclusive (first, last) token indices of an annotation's [start, end) character span.

    Spaces at either edge of the span are skipped, because `char2token_map` has
    no entries for space characters.
    """
    char_start, char_end = annotation["start"], annotation["end"]
    while char_start < char_end - 1 and text[char_start] == " ":
        char_start += 1
    while char_end - 1 > char_start and text[char_end - 1] == " ":
        char_end -= 1
    return char2token_map[char_start], char2token_map[char_end - 1]


def _sentence_index(token_idx, token_bounds):
    """Index of the first sentence whose token span contains `token_idx`, or None."""
    for sent_idx, bound in enumerate(token_bounds):
        if bound["start"] <= token_idx < bound["end"]:
            return sent_idx
    return None


def preprocess(datapath, corpora_dir="./data/Reconcile_Final6/processed_files/corpora"):
    """Convert every brat document under `datapath` into a PURE-style record.

    For each ``*.txt`` file (searched recursively, sorted) with a matching
    ``.ann`` file, the function:
      1. parses the annotations with ``parse_standoff_file``;
      2. splits the text into sentences, merging sentences an entity crosses;
      3. tokenises each sentence;
      4. builds character maps between the original text and the normalised
         text (one sentence per line, tokens joined by spaces), plus a
         character-to-token map. These are written under ``corpora_dir`` as
         ``<doc_key>.map``, ``<doc_key>.txt`` and ``<doc_key>_char2token.map``;
      5. assigns entities, triggers (all-uppercase types) and relations to
         sentences using token offsets.
    Documents without an ``.ann`` file are reported and skipped.

    Args:
        datapath: directory containing the processed brat files.
        corpora_dir: output directory for the per-document map/text files.
            The default is the path this notebook originally used.

    Returns:
        ``(all_documents, stats)``. ``all_documents`` is a list of dicts with
        keys ``doc_key``, ``sentences``, ``ner``, ``triggers``, ``relations``,
        ``triplets``, ``num_entities``, ``num_triggers`` and ``num_relations``.
        ``stats`` maps ``"Inter_sent_Relations"`` to per-type counts of
        relations whose second argument lies in another sentence, and
        ``"Modalities"`` to modality counts.

    Raises:
        ValueError: if the original and normalised text cannot be aligned.
        KeyError: if an event has no modality attribute, or a span edge cannot
            be mapped to a token.
        AssertionError: if a relation does not have exactly two arguments or an
            internal consistency check fails.
        statistics.StatisticsError: if no document was processed.
    """
    all_documents = []

    num_sents = []
    num_tokens = []
    num_entities = []
    num_triggers = []
    num_relations = []
    num_modalities = []

    inter_sent_relations = defaultdict(int)
    modality_dict = defaultdict(int)

    # Sorted so document order (and the written files) is reproducible across file systems
    for fn in sorted(glob(os.path.join(datapath, "**/*.txt"), recursive=True)):
        print(">> Processing: " + fn)

        basename, _ = os.path.splitext(fn)
        # os.path.basename instead of str.rstrip('.txt'), which strips characters, not a suffix
        doc_key = os.path.basename(basename)

        ann_file = basename + ".ann"
        if not os.path.exists(ann_file):
            # Skip instead of continuing with undefined or previous-document annotations
            print("Ann file missing: " + ann_file)
            continue
        # The 4th returned value holds the "E" annotations, used here as relations
        _, entities, _, relations, modalities, _ = parse_standoff_file(
            ann_file, fn, encoding="UTF-8"
        )

        num_relations.append(len(relations))
        num_modalities.append(len(modalities))

        original_doc = read_text(fn, encoding="utf-8")

        sentence_spans = _merge_sentence_spans(original_doc, entities)
        doc_tokens = _tokenize_sentences(
            original_doc[span["start"]:span["end"]] for span in sentence_spans
        )
        num_sents.append(len(doc_tokens))
        num_tokens.append(sum(len(tokens) for tokens in doc_tokens))

        # One sentence per line, tokens separated by single spaces
        normalized_doc = "\n".join(" ".join(tokens) for tokens in doc_tokens)

        print(">> Building offset map...")

        # Compare both texts with every whitespace character reduced to " "
        _original_doc = re.sub(r"\s", " ", original_doc)  # Address special blank characters
        _normalized_doc = normalized_doc.replace("\r", " ").replace("\n", " ")
        offset_map, inverse_offset_map = _build_offset_maps(_original_doc, _normalized_doc)

        doc_tokens_flattened = [token for sent_tokens in doc_tokens for token in sent_tokens]
        char2token_map = _build_char2token_map(_original_doc, doc_tokens_flattened)

        # To ensure the alignment covered the whole normalised text
        assert max(offset_map.values()) == len(_normalized_doc) \
            and max(inverse_offset_map) == len(_normalized_doc)

        doc_fn = os.path.join(corpora_dir, doc_key)
        write_json(offset_map, doc_fn + ".map")
        write_text(normalized_doc, doc_fn + ".txt")
        write_json(char2token_map, doc_fn + "_char2token.map")

        token_bounds = _sentence_token_bounds(sentence_spans, char2token_map)

        # Allocate entities / triggers to the sentence containing their first token
        sent_level_entities = [[] for _ in token_bounds]
        sent_level_triggers = [[] for _ in token_bounds]
        num_ent = 0
        num_trg = 0
        for annotation in entities.values():
            ent_start, ent_end = _token_span(annotation, _original_doc, char2token_map)
            sent_idx = _sentence_index(ent_start, token_bounds)
            if sent_idx is None:
                continue
            new_annotation = [ent_start, ent_end, annotation["type"]]
            # All-uppercase types are trigger annotations; the rest are entities
            if annotation["type"].isupper():
                num_trg += 1
                sent_level_triggers[sent_idx].append(new_annotation)
            else:
                num_ent += 1
                sent_level_entities[sent_idx].append(new_annotation)
        num_entities.append(num_ent)
        num_triggers.append(num_trg)

        # Attach each modality attribute ("A" line) to the event it refers to
        for modality_ann in modalities.values():
            relations[modality_ann["reference_ids"]]["modality"] = modality_ann["type"]

        # Allocate relations to the sentence containing the first argument's first token
        sent_level_relations = [[] for _ in token_bounds]
        sent_level_triplets = [[] for _ in token_bounds]
        num_rel = 0
        for annotation in relations.values():
            trg_type = annotation["trigger_type"]
            trg_id = annotation["trigger_id"]
            modality = annotation["modality"]  # KeyError if the event has no modality

            arguments = annotation["args"]
            assert len(arguments) == 2, annotation

            arg1, arg2 = arguments
            # An "Agent" argument always goes first
            if arg2["role"] == "Agent":
                arg1, arg2 = arg2, arg1

            # Exclude relations involving dropped (e.g. discontinuous) entities or triggers
            if arg1["id"] not in entities or arg2["id"] not in entities or trg_id not in entities:
                continue

            ent1_start, ent1_end = _token_span(entities[arg1["id"]], _original_doc, char2token_map)
            ent2_start, ent2_end = _token_span(entities[arg2["id"]], _original_doc, char2token_map)
            trg_start, trg_end = _token_span(entities[trg_id], _original_doc, char2token_map)

            sent_idx = _sentence_index(ent1_start, token_bounds)
            if sent_idx is None:
                continue

            sent_level_relations[sent_idx].append(
                [ent1_start, ent1_end, ent2_start, ent2_end, trg_type, modality]
            )
            sent_level_triplets[sent_idx].append(
                [ent1_start, ent1_end, ent2_start, ent2_end, trg_start, trg_end, trg_type]
            )
            num_rel += 1  # the reversed copy added below is not counted
            modality_dict[modality] += 1

            # Without an Agent argument the relation is also stored in reverse order
            if arg1["role"] != "Agent":
                sent_level_relations[sent_idx].append(
                    [ent2_start, ent2_end, ent1_start, ent1_end, trg_type, modality]
                )
                sent_level_triplets[sent_idx].append(
                    [ent2_start, ent2_end, ent1_start, ent1_end, trg_start, trg_end, trg_type]
                )

            # Inter-sentential: the second argument starts outside this sentence
            bound = token_bounds[sent_idx]
            if not bound["start"] <= ent2_start < bound["end"]:
                inter_sent_relations[trg_type] += 1

        doc_output = {
            "doc_key": doc_key,
            "sentences": doc_tokens,
            "ner": sent_level_entities,
            "triggers": sent_level_triggers,
            "relations": sent_level_relations,
            "triplets": sent_level_triplets,
            "num_entities": num_ent,
            "num_triggers": num_trg,
            "num_relations": num_rel,
        }

        assert len(doc_tokens) == len(sent_level_entities) \
            == len(sent_level_triggers) == len(sent_level_relations), (
                len(doc_tokens), len(sent_level_entities),
                len(sent_level_triggers), len(sent_level_relations), len(sentence_spans),
            )

        all_documents.append(doc_output)

    section_name = os.path.basename(os.path.normpath(datapath))
    print(f"\n=== SUMMARY for {section_name} ===")
    print(f"=== NUMBER OF SENTS: {sum(num_sents)}({statistics.median(num_sents)}) ===")
    print(f"=== NUMBER OF TOKENS: {sum(num_tokens)}({statistics.median(num_tokens)}) ===")
    print(f"=== NUMBER OF ENTITIES: {sum(num_entities)}({statistics.median(num_entities)}) ===")
    print(f"=== NUMBER OF TRIGGERS: {sum(num_triggers)}({statistics.median(num_triggers)}) ===")
    print(f"=== NUMBER OF RELATIONS: {sum(num_relations)}({statistics.median(num_relations)}) ===")
    print(f"=== NUMBER OF MODALITIES: {sum(num_modalities)}({statistics.median(num_modalities)}) ===\n\n")

    stats = {
        "Inter_sent_Relations": inter_sent_relations,
        "Modalities": modality_dict
    }

    return all_documents, stats

In [None]:
# Convert both sections; each call returns (documents, stats)
title_abs_files, title_abs_stats = preprocess(datapath=title_abs_path)
result_files, result_stats = preprocess(datapath=results_path)

In [25]:
def count_ent_rel(ner_labels, rel_labels, all_documents):
    """Count entity, trigger and relation labels over PURE-style documents.

    Args:
        ner_labels: entity labels to report (each starts at 0).
        rel_labels: trigger/relation labels to report (each starts at 0 in both
            the trigger and the relation counters).
        all_documents: documents with per-sentence "ner", "triggers" and
            "relations" lists. The label is the last element of an entity or
            trigger entry, and the second-to-last element of a relation entry.

    Returns:
        (ent_cnt, trg_cnt, rel_cnt) dicts of label -> count. Labels found in the
        data but absent from `ner_labels` / `rel_labels` are added instead of
        raising KeyError (e.g. when the label sets come from a different
        document subset than `all_documents`).
    """
    ent_cnt = dict.fromkeys(ner_labels, 0)
    trg_cnt = dict.fromkeys(rel_labels, 0)
    rel_cnt = dict.fromkeys(rel_labels, 0)

    for doc in all_documents:
        for sent in doc['ner']:
            for ent in sent:
                if ent:
                    ent_cnt[ent[-1]] = ent_cnt.get(ent[-1], 0) + 1
        for sent in doc['triggers']:
            for trg in sent:
                if trg:
                    trg_cnt[trg[-1]] = trg_cnt.get(trg[-1], 0) + 1
        for sent in doc['relations']:
            for rel in sent:
                if rel:
                    rel_cnt[rel[-2]] = rel_cnt.get(rel[-2], 0) + 1

    return ent_cnt, trg_cnt, rel_cnt

In [26]:
title_abs_ent_cnt, title_abs_trg_cnt, title_abs_rel_cnt = count_ent_rel(
    ner_labels=defined_types['entity_types'],
    rel_labels=defined_types['trigger_types'],
    all_documents=title_abs_files,
)

# Print each counter with its total, separated by ruler lines
for position, (counts, label) in enumerate([
    (title_abs_ent_cnt, "| # ENTITY TOTAL:"),
    (title_abs_trg_cnt, "| # TRIGGER TOTAL:"),
    (title_abs_rel_cnt, "| # RELATION TOTAL:"),
]):
    if position:
        print("-"*79)
    print(counts, label, sum(counts.values()))

{'DietPattern': 437, 'Gene': 61, 'Enzyme': 68, 'Physiology': 1028, 'Methodology': 530, 'Nutrient': 1206, 'Microorganism': 801, 'Measurement': 366, 'Disease': 403, 'Food': 685, 'Chemical': 505, 'Metabolite': 636, 'DiversityMetric': 60, 'Biospecimen': 235, 'Population': 612} | # ENTITY TOTAL: 7633
-------------------------------------------------------------------------------
{'DECREASES': 292, 'IMPROVES': 156, 'POS_ASSOCIATED_WITH': 81, 'NEG_ASSOCIATED_WITH': 44, 'WORSENS': 15, 'PREVENTS': 27, 'INCREASES': 402, 'AFFECTS': 276, 'INTERACTS_WITH': 18, 'CAUSES': 21, 'PREDISPOSES': 14, 'HAS_COMPONENT': 95, 'ASSOCIATED_WITH': 95} | # TRIGGER TOTAL: 1536
-------------------------------------------------------------------------------
{'DECREASES': 466, 'IMPROVES': 211, 'POS_ASSOCIATED_WITH': 144, 'NEG_ASSOCIATED_WITH': 108, 'WORSENS': 21, 'PREVENTS': 39, 'INCREASES': 770, 'AFFECTS': 584, 'INTERACTS_WITH': 50, 'CAUSES': 30, 'PREDISPOSES': 28, 'HAS_COMPONENT': 152, 'ASSOCIATED_WITH': 303} | # REL

In [27]:
res_ent_cnt, res_trg_cnt, res_rel_cnt = count_ent_rel(
    ner_labels=defined_types['entity_types'],
    rel_labels=defined_types['trigger_types'],
    all_documents=result_files,
)

# Print each counter with its total, separated by ruler lines
for position, (counts, label) in enumerate([
    (res_ent_cnt, "| # ENTITY TOTAL:"),
    (res_trg_cnt, "| # TRIGGER TOTAL:"),
    (res_rel_cnt, "| # RELATION TOTAL:"),
]):
    if position:
        print("-"*79)
    print(counts, label, sum(counts.values()))

{'DietPattern': 398, 'Gene': 105, 'Enzyme': 167, 'Physiology': 598, 'Methodology': 412, 'Nutrient': 821, 'Microorganism': 1130, 'Measurement': 429, 'Disease': 201, 'Food': 347, 'Chemical': 608, 'Metabolite': 513, 'DiversityMetric': 166, 'Biospecimen': 70, 'Population': 568} | # ENTITY TOTAL: 6533
-------------------------------------------------------------------------------
{'DECREASES': 148, 'IMPROVES': 23, 'POS_ASSOCIATED_WITH': 64, 'NEG_ASSOCIATED_WITH': 39, 'WORSENS': 10, 'PREVENTS': 0, 'INCREASES': 194, 'AFFECTS': 125, 'INTERACTS_WITH': 3, 'CAUSES': 3, 'PREDISPOSES': 1, 'HAS_COMPONENT': 37, 'ASSOCIATED_WITH': 53} | # TRIGGER TOTAL: 700
-------------------------------------------------------------------------------
{'DECREASES': 285, 'IMPROVES': 35, 'POS_ASSOCIATED_WITH': 166, 'NEG_ASSOCIATED_WITH': 103, 'WORSENS': 10, 'PREVENTS': 0, 'INCREASES': 315, 'AFFECTS': 283, 'INTERACTS_WITH': 8, 'CAUSES': 3, 'PREDISPOSES': 1, 'HAS_COMPONENT': 64, 'ASSOCIATED_WITH': 181} | # RELATION TOTAL

In [28]:
def count_cross_sents(rel_cross_sents: dict, rel_cnt: dict):
    """Print and return an upper-bound score per relation type, given cross-sentence counts.

    Cross-sentence relations are treated as unreachable, so the bound is
    ``1 - cross_sentence_count / total_count``, rounded to 4 decimals.

    Args:
        rel_cross_sents: relation type -> number of cross-sentence relations.
        rel_cnt: relation type -> total number of relations.

    Returns:
        dict mapping every relation type in `rel_cnt` to its bound. Types with a
        zero count (and an empty `rel_cnt` overall) get 1.0 instead of raising
        ZeroDivisionError.

    NOTE(review): when `rel_cnt` comes from `count_ent_rel`, it also counts the
    reversed copies that `preprocess` adds for relations without an Agent, while
    the cross-sentence counts include each relation once — confirm that this
    is the intended denominator.
    """
    total_cross = sum(rel_cross_sents.values())
    total_rel = sum(rel_cnt.values())

    print(rel_cross_sents, "| TOTAL:", total_cross)
    print("-"*79)
    # Round after subtracting so values print without float artefacts (0.09999999999999998)
    total_upper = round(1 - total_cross / total_rel, 4) if total_rel else 1.0
    print("TOTAL UPPER BOUND: ", total_upper)
    print("-"*79)

    result = {}
    for rel, cnt in rel_cnt.items():
        cross = rel_cross_sents.get(rel, 0)
        upper = round(1 - cross / cnt, 4) if cnt else 1.0
        print(f"UPPER BOUND for {rel}: {upper}")
        result[rel] = upper
    return result

In [29]:
abs_cross_sents_dict = count_cross_sents(
    rel_cross_sents=title_abs_stats["Inter_sent_Relations"],
    rel_cnt=title_abs_rel_cnt,
)

defaultdict(<class 'int'>, {'AFFECTS': 114, 'DECREASES': 57, 'POS_ASSOCIATED_WITH': 5, 'INCREASES': 82, 'IMPROVES': 21, 'ASSOCIATED_WITH': 15, 'HAS_COMPONENT': 4, 'PREVENTS': 2, 'PREDISPOSES': 3, 'WORSENS': 2}) | TOTAL: 305
-------------------------------------------------------------------------------
TOTAL UPPER BOUND:  0.895
-------------------------------------------------------------------------------
UPPER BOUND for DECREASES: 0.8777
UPPER BOUND for IMPROVES: 0.9005
UPPER BOUND for POS_ASSOCIATED_WITH: 0.9653
UPPER BOUND for NEG_ASSOCIATED_WITH: 1.0
UPPER BOUND for WORSENS: 0.9048
UPPER BOUND for PREVENTS: 0.9487
UPPER BOUND for INCREASES: 0.8935
UPPER BOUND for AFFECTS: 0.8048
UPPER BOUND for INTERACTS_WITH: 1.0
UPPER BOUND for CAUSES: 1.0
UPPER BOUND for PREDISPOSES: 0.8929
UPPER BOUND for HAS_COMPONENT: 0.9737
UPPER BOUND for ASSOCIATED_WITH: 0.9505


In [30]:
res_cross_sents_dict = count_cross_sents(
    rel_cross_sents=result_stats["Inter_sent_Relations"],
    rel_cnt=res_rel_cnt,
)

defaultdict(<class 'int'>, {'AFFECTS': 96, 'DECREASES': 38, 'INCREASES': 53, 'ASSOCIATED_WITH': 3, 'POS_ASSOCIATED_WITH': 2}) | TOTAL: 192
-------------------------------------------------------------------------------
TOTAL UPPER BOUND:  0.868
-------------------------------------------------------------------------------
UPPER BOUND for DECREASES: 0.8667
UPPER BOUND for IMPROVES: 1.0
UPPER BOUND for POS_ASSOCIATED_WITH: 0.988
UPPER BOUND for NEG_ASSOCIATED_WITH: 1.0
UPPER BOUND for WORSENS: 1.0
UPPER BOUND for PREVENTS: 1.0
UPPER BOUND for INCREASES: 0.8317
UPPER BOUND for AFFECTS: 0.6608
UPPER BOUND for INTERACTS_WITH: 1.0
UPPER BOUND for CAUSES: 1.0
UPPER BOUND for PREDISPOSES: 1.0
UPPER BOUND for HAS_COMPONENT: 1.0
UPPER BOUND for ASSOCIATED_WITH: 0.9834


In [31]:
# Modality counts: title+abstract first, then Results
for section_stats in (title_abs_stats, result_stats):
    print(section_stats['Modalities'])

defaultdict(<class 'int'>, {'Factual': 2110, 'Unknown': 364, 'Negated': 241})
defaultdict(<class 'int'>, {'Factual': 1107, 'Negated': 196, 'Unknown': 15})


## Dataset Split only for Title-abs version
- Load the predefined list of file names for each split (train/dev/test).

In [209]:
# Predefined doc_key lists per split, stored next to this notebook
split_filenames_dir = './split_filenames.json'
with open(split_filenames_dir) as split_file:
    split_fn = json.loads(split_file.read())

In [210]:
def split_dataset(title_abs_files, split_fn, result_files=None):
    """Split title+abstract documents into train/dev/test by their "doc_key".

    Args:
        title_abs_files: documents produced by `preprocess`, each with a "doc_key".
        split_fn: mapping with "train" and "dev" collections of doc keys.
            Documents listed in neither go to test.
        result_files: optional Results-section documents, appended to the
            training set only.

    Returns:
        (train, train_all, dev, test). `train_all` is `train` followed by
        `result_files`, or a copy of `train` when there are no Results documents.
        The lists share the same document dicts, so mutating a document in one
        list is visible in the others.
    """
    train_keys = set(split_fn['train'])
    dev_keys = set(split_fn['dev'])

    train, dev, test = [], [], []
    for doc in title_abs_files:
        doc_key = doc["doc_key"]
        if doc_key in train_keys:
            train.append(doc)
        elif doc_key in dev_keys:
            dev.append(doc)
        else:
            test.append(doc)

    # Previously left empty when result_files was None/empty, dropping all training data
    train_all = train + list(result_files or [])

    return train, train_all, dev, test

In [211]:
train_abs, train_all, dev, test = split_dataset(title_abs_files, split_fn, result_files=result_files)
print(*(len(split) for split in (train_abs, train_all, dev, test)))

109 139 19 37


In [212]:
# Detailed Stats for Train, Dev, and Test
def dataset_stat(data):
    sent = 0
    ner = 0
    rel = 0
    for d in data:
        for _, entity, relation in zip(d["sentences"], d["ner"], d["relations"]):
            sent += 1
            ner += len(entity)
            rel += len(relation)
    return sent, ner, rel

# Order: train (title+abs), dev, test, train (title+abs + Results)
for split_docs in (train_abs, dev, test, train_all):
    print(dataset_stat(split_docs))

(1460, 4283, 2082)
(233, 652, 285)
(494, 1321, 539)
(3722, 9766, 3536)


In [213]:
# delete unnecessary keys
for data in [train_all, dev, test]:
    for d in data:
        del d["num_entities"]
        del d["num_triggers"]
        del d["num_relations"]

## Below is the dataset for Span Prediction Task

In [214]:
# Set your own path to store input files for PURE training
outputpath = "../data/DiMB-RE/ner_reduced_v6.1_trg_abs"

if not os.path.exists(outputpath):
    os.mkdir(outputpath)

with open(os.path.join(outputpath, "train.json"), 'w') as f:
    for docu in train_abs:
        f.write(json.dumps(docu) + '\n')
with open(os.path.join(outputpath, "dev.json"), 'w') as f:
    for docu in dev:
        f.write(json.dumps(docu) + '\n')
with open(os.path.join(outputpath, "test.json"), 'w') as f:
    for docu in test:
        f.write(json.dumps(docu) + '\n')

In [215]:
# Set your own path to store input files for PURE training
outputpath = "../data/DiMB-RE/ner_reduced_v6.1_trg_abs_result"

if not os.path.exists(outputpath):
    os.mkdir(outputpath)

with open(os.path.join(outputpath, "train.json"), 'w') as f:
    for docu in train_all:
        f.write(json.dumps(docu) + '\n')
with open(os.path.join(outputpath, "dev.json"), 'w') as f:
    for docu in dev:
        f.write(json.dumps(docu) + '\n')
with open(os.path.join(outputpath, "test.json"), 'w') as f:
    for docu in test:
        f.write(json.dumps(docu) + '\n')