- This notebook file is the first step to preprocess BRAT format files and split them into Abstracts-only and Results-only files

In [1]:
import os
import re
from tqdm import tqdm
import json
from collections import defaultdict
import statistics
import pandas as pd

import ssplit
from tokenization_bert import BasicTokenizer

In [3]:
# Input raw files
raw_file_path = "../data/brat/raw_files"

# Load raw files
ann_files = []
txt_files = []
for file_name in os.listdir(raw_file_path):
    if file_name.endswith(".ann"):
        ann_files.append(file_name)
    elif file_name.endswith(".txt"):
        txt_files.append(file_name)

ann_files.sort()
txt_files.sort()
print(len(ann_files), (len(txt_files)))

165
165


## Filter BRAT annotation errors and Split raw dataset into 1) Title-Abstract and 2) Result part

In [4]:
# Load offset information for fully annotated text
offset_path = "../data/brat/AnnotationPMC.csv"
offset_dict = {}
with open(offset_path, "r") as f:
    next(f)
    for line in f:
        line = line.strip().rstrip(",")
        PMCID, OFFSETS = line.split(",", 1)
        if len(OFFSETS.split(",")) > 1:
            ABS_END, RES_OFFSET, RES_END = OFFSETS.split(",")
        else:
            ABS_END = OFFSETS
            
        offset_dict[PMCID] = {}
        offset_dict[PMCID]["Abs_end"] = int(ABS_END)
        if RES_OFFSET:
            offset_dict[PMCID]["Res_offset"] = int(RES_OFFSET)
            offset_dict[PMCID]["Res_end"] = int(RES_END)
            
        ABS_END = 0
        RES_OFFSET = 0
        RES_END = 0

In [34]:
# Switch if you want to include all 15 types of entities
processed_file_path = "../data/brat/processed_files"
# processed_file_path = "../data/brat/processed_files_full_entities"

# Output files
title_abs_path = os.path.join(processed_file_path, "title_abs")
results_path = os.path.join(processed_file_path, "results")

os.makedirs(title_abs_path, exist_ok=True)
os.makedirs(results_path, exist_ok=True)

In [35]:
BASIC_TOKENIZER = BasicTokenizer(do_lower_case=False)

In [36]:
# Create 30 'results-only' txt files

num_sents_abs = []
num_tokens_abs = []
num_sents_res = []
num_tokens_res = []
num_tables_res = 0

for txt in tqdm(txt_files):
    doc_key = txt.strip(".txt")
    
    doc_sents_abs = 0
    doc_tokens_abs = 0
    doc_sents_res = 0
    doc_tokens_res = 0
    
    with open(os.path.join(raw_file_path, txt), encoding="utf-8") as f:
        doc = f.read()
        
    if doc_key.startswith('PMC'):
        offsets = offset_dict[doc_key]
        if 'Res_offset' in offset_dict[doc_key].keys():
            result = doc[offsets["Res_offset"]:offsets["Res_end"]]
            
            # Count number of words and sentences
            sentences = re.split(r"[\.!\?]\s(?=[A-Z])|\n+", result)
            for sent in sentences:
                
                # Detect table format
                if '\t' in sent:
                    num_tables_res += 1
                else:
                    doc_sents_res += 1
                    sent += "." if sent[-1] != "." else sent
                    doc_tokens_res += len(BASIC_TOKENIZER.tokenize(sent))

            result_fname = doc_key + "-RESULT.txt"
            with open(os.path.join(results_path, result_fname), "w", encoding="utf-8") as f:
                f.write(result)

        result = doc[:offsets["Abs_end"]]

        # Count number of words and sentences
        sentences = re.split(r"[\.!\?]\s(?=[A-Z])|\n+", result)
        for sent in sentences:
            doc_sents_abs += 1
            sent += "." if sent[-1] != "." else sent
            doc_tokens_abs += len(BASIC_TOKENIZER.tokenize(sent))

        result_fname = doc_key + ".txt"
        with open(os.path.join(title_abs_path, result_fname), "w", encoding="utf-8") as f:
            f.write(result)
            
        num_sents_abs.append(doc_sents_abs)
        num_tokens_abs.append(doc_tokens_abs)
        num_sents_res.append(doc_sents_res)
        num_tokens_res.append(doc_tokens_res)
            
    else:  # just re-write for Non-PMC files            
        # Count number of words and sentences
        sentences = re.split(r"[\.!\?]\s(?=[A-Z])|\n+", doc)
        for sent in sentences:
            doc_sents_abs += 1
            sent += "." if sent[-1] != "." else sent
            doc_tokens_abs += len(BASIC_TOKENIZER.tokenize(sent))
            
        result_fname = doc_key + ".txt"
        with open(os.path.join(title_abs_path, result_fname), "w", encoding="utf-8") as f:
            f.write(doc)
            
        num_sents_abs.append(doc_sents_abs)
        num_tokens_abs.append(doc_tokens_abs)
            
print(f"# Sentences in 165 Abs corpus: {sum(num_sents_abs)} | Med.: {statistics.median(num_sents_abs)}")
print(f"# Words in 165 Abs corpus: {sum(num_tokens_abs)} | Med.: {statistics.median(num_tokens_abs)}")
print(f"# Sentences in 30 Res corpus: {sum(num_sents_res)} | Med.: {statistics.median(num_sents_res)}")
print(f"# Words in 30 Res corpus: {sum(num_tokens_res)} | Med.: {statistics.median(num_tokens_res)}")
print(f"# Tables in 30 Res corpus: {num_tables_res} | Mean: {num_tables_res/30:.2f}")

100%|████████████████████████████████████████| 165/165 [00:01<00:00, 145.70it/s]

# Sentences in 165 Abs corpus: 2181 | Med.: 13
# Words in 165 Abs corpus: 68136 | Med.: 411
# Sentences in 30 Res corpus: 2129 | Med.: 28.5
# Words in 30 Res corpus: 68911 | Med.: 944.5
# Tables in 30 Res corpus: 58 | Mean: 1.93





In [37]:
def custom_sort_key(item):
    order = {'T': 1, 'E': 2, 'A': 3, 'N': 4}
    tag_id, _ = item.split("\t", 1)
    return order[tag_id[0]], int(tag_id[1:])

In [None]:
# For full text annotation, split annotation for 1) abstract part and 2) Result part

# To quantify our annotations
entity_type_stats = defaultdict(int)
relation_type_stats = defaultdict(int)
factuality_type_stats = defaultdict(dict)
normalization_stats = defaultdict(dict)

disjoint_entity_stats = defaultdict(int)
entity_pair_4_directional_rel_stats = defaultdict(dict)
entity_pair_4_bidirectional_rel_stats = defaultdict(dict)

entity_per_doc_abs_stats = {}
entity_per_doc_result_stats = {}
trigger_per_doc_abs_stats = {}
trigger_per_doc_result_stats = {}
relation_per_doc_abs_stats = {}
relation_per_doc_result_stats = {}
factuality_per_doc_abs_stats = {}
factuality_per_doc_result_stats = {}
normalization_per_doc_abs_stats = {}
normalization_per_doc_result_stats = {}

IGNORED_CLASSES = ["Methodology", "Biospecimen", "Population"]
# IGNORED_CLASSES = []

# This is to for 3-way classification of certainty: Factual / Unknown / Negated
factuality_unknown = ['Probable', 'Possible', 'Doubtful']
# factuality_unknown = []

for ann in ann_files:
    
    doc_key = ann.strip(".ann")
    print(f"### {doc_key} ###")
    
    entity_per_doc_abs_stats[doc_key] = defaultdict(int)
    entity_per_doc_result_stats[doc_key] = defaultdict(int)
    trigger_per_doc_abs_stats[doc_key] = defaultdict(int)
    trigger_per_doc_result_stats[doc_key] = defaultdict(int)
    relation_per_doc_abs_stats[doc_key] = defaultdict(int)
    relation_per_doc_result_stats[doc_key] = defaultdict(int)
    factuality_per_doc_abs_stats[doc_key] = defaultdict(int)
    factuality_per_doc_result_stats[doc_key] = defaultdict(int)
    normalization_per_doc_abs_stats[doc_key] = defaultdict(int)
    normalization_per_doc_result_stats[doc_key] = defaultdict(int)
    
    entities = {}
    relations = {}
    modalities = {}
    normalizations = {}
    
    duplicates = defaultdict(list)
    title_abs_doc = ""
    result_doc = ""

    with open(os.path.join(raw_file_path, ann), encoding="utf-8") as f:
        
        lines = f.read().splitlines()
        lines = sorted(lines, key=custom_sort_key)  # T -> E -> A -> N

        # Tag information
        title_abs_tags = []
        result_tags = []
        
        for line in lines:
            if line.startswith("T"):
                entity_id, entity_annotation, entity_mention = line.split("\t")
                
                entity_id = entity_id.strip()
                entity_annotation = entity_annotation.strip()
                entity_mention = entity_mention.strip()
                
                annotation_elements = entity_annotation.split(";")
                entity_type, *first_offset_pair = annotation_elements[0].split()
                
                entity_type_stats[entity_type] += 1
                
                # Re-categorization: Need to be fixed following the dataset versions
                if entity_type in IGNORED_CLASSES:
                    continue
                
                offset_pairs = [first_offset_pair] + [
                    offset_pair.split() for offset_pair in annotation_elements[1:]
                ]
                
                if len(offset_pairs) > 1:
#                     print(f"## Discontinuous entity found: {entity_id} in {doc_key}")
                    disjoint_entity_stats[entity_type] += 1
                    
                value = {
                    "type": entity_type,
                    "spans": offset_pairs,
                    "mention": entity_mention
                }
                
                # Check duplication
                if value in entities.values():
                    for k, v in entities.items():
                        if v == value:
                            duplicates[k].append(entity_id)
                            entity_type_stats[entity_type] -= 1
                            print(f"## {entity_id} was replaced by {k} in {doc_key}")
                            break
                    continue
                else:
                    entities[entity_id] = value

                abs_offset_pairs = []
                result_offset_pairs = []
                # Split into abs & result
                if doc_key.startswith('PMC') and ('Res_offset' in offset_dict[doc_key]):                  
                    for span_pair in offset_pairs:
                        start, end = span_pair
                        start, end = int(start), int(end)
                        # Title-abs part
                        if start < offset_dict[doc_key]['Res_offset']:
                            abs_offset_pairs.append(f"{start} {end}")
                        # Result part -> we need to modify the span information
                        else:
                            OFFSET = offset_dict[doc_key]['Res_offset']
                            start -= OFFSET
                            end -= OFFSET
                            result_offset_pairs.append(f"{start} {end}")

                    if not result_offset_pairs:        
                        new_offset = ";".join(abs_offset_pairs)
                        new_annotation = " ".join([entity_type, new_offset])
                        newline = "\t".join([entity_id, new_annotation, entity_mention])                      
                        title_abs_doc += newline + "\n"
                        title_abs_tags.append(entity_id)
                        if not entity_type.isupper():
                            entity_per_doc_abs_stats[doc_key][entity_type] += 1
                        else:
                            trigger_per_doc_abs_stats[doc_key][entity_type] += 1
                    else:
                        new_offset = ";".join(result_offset_pairs)
                        new_annotation = " ".join([entity_type, new_offset])
                        newline = "\t".join([entity_id, new_annotation, entity_mention])
                        result_doc += newline + "\n"
                        result_tags.append(entity_id)
                        if not entity_type.isupper():
                            entity_per_doc_result_stats[doc_key][entity_type] += 1
                        else:
                            trigger_per_doc_result_stats[doc_key][entity_type] += 1
                else:
                    for span_pair in offset_pairs:
                        start, end = span_pair
                        abs_offset_pairs.append(f"{start} {end}")
                        
                    new_offset = ";".join(abs_offset_pairs)
                    new_annotation = " ".join([entity_type, new_offset])
                    newline = "\t".join([entity_id, new_annotation, entity_mention])    
                    title_abs_doc += newline + "\n"
                    title_abs_tags.append(entity_id)
                    if not entity_type.isupper():
                        entity_per_doc_abs_stats[doc_key][entity_type] += 1
                    else:
                        trigger_per_doc_abs_stats[doc_key][entity_type] += 1
                                  
            elif line.startswith("E"):
                relation_id, relation_annotation = line.split("\t")

                relation_id = relation_id.strip()
                relation_annotation = relation_annotation.strip()

                assert len(relation_annotation.split()) == 3, f"##{relation_id} in {doc_key} without two args"
            
                trigger, arg1, arg2 = relation_annotation.split()

                trigger_type, trigger_id = trigger.split(":")
                arg1_role, arg1_id = arg1.split(":")
                arg2_role, arg2_id = arg2.split(":")
                
                if entities[arg1_id]["type"] in ["Methodology", "Biospecimen", "Population"]:
                    print(line)
                    print(arg1_id, ":", entities[arg1_id]["type"])
                if entities[arg2_id]["type"] in ["Methodology", "Biospecimen", "Population"]:
                    print(line)
                    print(arg2_id, ":", entities[arg2_id]["type"])
                
                assert arg1_id in entities, print(line)
                assert arg2_id in entities, print(line)
                assert sorted([arg1_role, arg2_role]) in [["Agent", "Theme"], ["Theme", "Theme2"]], print(line)
                
                relation_type_stats[trigger_type] += 1
                
                bidirectional = False
                if arg1_role == "Agent":
                    agent_entity_type = entities[arg1_id]["type"]
                    theme_entity_type = entities[arg2_id]["type"]
                elif arg2_role == "Agent":
                    agent_entity_type = entities[arg2_id]["type"]
                    theme_entity_type = entities[arg1_id]["type"]
                else:  # bidirectional rel
                    bidirectional = True
                    if not trigger_type in ["INTERACTS_WITH", "ASSOCIATED_WITH",
                                           "POS_ASSOCIATED_WITH", "NEG_ASSOCIATED_WITH"]:
                        print("## Inappropriate Arguments")
                        print(line, "\n")
                        
                    if arg1_role == "Theme":
                        agent_entity_type = entities[arg1_id]["type"]
                        theme_entity_type = entities[arg2_id]["type"]
                    elif arg1_role == "Theme2":
                        agent_entity_type = entities[arg2_id]["type"]
                        theme_entity_type = entities[arg1_id]["type"]
                        
                if not bidirectional:
                    if trigger_type == "INTERACTS_WITH":
                        print("## Inappropriate Arguments")
                        print(line, "\n")
                
                entity_pair = (agent_entity_type, theme_entity_type)
                if not bidirectional:
                    if trigger_type not in entity_pair_4_directional_rel_stats[entity_pair]:
                        entity_pair_4_directional_rel_stats[entity_pair][trigger_type] = [(doc_key, relation_id)]
                    else:
                        entity_pair_4_directional_rel_stats[entity_pair][trigger_type].append([(doc_key, relation_id)])
                else:
                    if trigger_type not in entity_pair_4_bidirectional_rel_stats[entity_pair]:
                        entity_pair_4_bidirectional_rel_stats[entity_pair][trigger_type] = [(doc_key, relation_id)]
                    else:    
                        entity_pair_4_bidirectional_rel_stats[entity_pair][trigger_type].append([(doc_key, relation_id)])
                
                # replace removed entity ids
                for key_id, sub_ids in duplicates.items():
                    if trigger_id in sub_ids:
                        print(f"## REL: {trigger_id} was replaced by {key_id} in {doc_key}")
                        trigger_id = key_id
                    if arg1_id in sub_ids:
                        print(f"## REL: {arg1_id} was replaced by {key_id} in {doc_key}")
                        arg1_id = key_id
                    if arg2_id in sub_ids:
                        print(f"## REL: {arg2_id} was replaced by {key_id} in {doc_key}")
                        arg2_id = key_id

                args = [
                    {"role": arg1_role, "id": arg1_id},
                    {"role": arg2_role, "id": arg2_id}
                ]

                value = {
                    "trigger_type": trigger_type,
                    "trigger_id": trigger_id,
                    "args": args,
                }
                
                # check duplication
                if value in relations.values():
                    for k, v in relations.items():
                        if v == value:
                            duplicates[k].append(relation_id)
                            print(f"## REL: {relation_id} was replaced by {k} in {doc_key}")
                            break
                    continue
                else:
                    relations[relation_id] = value
                
                new_trigger = [f"{trigger_type}:{trigger_id}"]
                new_args = [f"{arg['role']}:{arg['id']}" for arg in args]
                new_relation_annotation = " ".join(new_trigger + new_args)
                newline = "\t".join([relation_id, new_relation_annotation])
                
                if trigger_id in title_abs_tags:
                    title_abs_doc += newline + "\n"
                    title_abs_tags.append(relation_id)
                    relation_per_doc_abs_stats[doc_key][trigger_type] += 1
                else:
                    result_doc += newline + "\n"
                    result_tags.append(relation_id)
                    relation_per_doc_result_stats[doc_key][trigger_type] += 1
                
                
            elif line.startswith("A"):
                modal_id, modal_annotation = line.split("\t")
                _, reference_id, modal_type = modal_annotation.split(" ")
                
                assert modal_type in ["Probable", "Possible", "Negated", "Doubtful", "Unknown"], print(doc_key, modal_id)
                
                if modal_type in factuality_unknown:
                    modal_type = 'Unknown'
                
                for key_id, sub_ids in duplicates.items():
                    if reference_id in sub_ids:
                        print(f"## MOD: {reference_id} was replaced by {key_id} in {doc_key}")
                        reference_id = key_id
                
                value = {"type": modal_type, "reference_id": reference_id}
                
                trigger_type = relations[reference_id]["trigger_type"]
                if modal_type not in factuality_type_stats[trigger_type]:
                    factuality_type_stats[trigger_type][modal_type] = 1
                else:
                    factuality_type_stats[trigger_type][modal_type] += 1
                
                # check duplication
                if value in modalities.values():
                    for k, v in modalities.items():
                        if v == value:
                            duplicates[k].append(modal_id)
                            print(f"## MOD: {modal_id} was replaced by {k} in {doc_key}")
                            break
                else:
                    modalities[modal_id] = value
                    
                new_modal_annotation = " ".join([modal_type, reference_id])
                newline = "\t".join([modal_id, new_modal_annotation])
                    
                if reference_id in title_abs_tags:
                    title_abs_doc += newline + "\n"
                    factuality_per_doc_abs_stats[doc_key][modal_type] += 1
                else:
                    result_doc += newline + "\n"
                    factuality_per_doc_result_stats[doc_key][modal_type] += 1
                    
            
            elif line.startswith("N"):
                normalization_id, normalization_annotation, normalized_name = line.split("\t")
                _, reference_id, normalization_src_id = normalization_annotation.split(" ")
                ontology_src, ontology_id = normalization_src_id.split(":")
                
                # Just to process edge case
                if not ontology_src.isupper():
                    ontology_src = ontology_src.upper()
                
                if reference_id in entities:  # Consider target entity types only
                    entity_type = entities[reference_id]["type"]
                else:
                    continue
                    
                if ontology_src not in normalization_stats[entity_type]:
                    normalization_stats[entity_type][ontology_src] = 1
                else:
                    normalization_stats[entity_type][ontology_src] += 1
                
                for key_id, sub_ids in duplicates.items():
                    if reference_id in sub_ids:
                        print(f"## MOD: {reference_id} was replaced by {key_id} in {doc_key}")
                        reference_id = key_id
                        
                value = {
                    "source": ontology_src,
                    "id": ontology_id,
                    "reference_id": reference_id,
                    "normalized_name": normalized_name
                }

                if value in normalizations.values():
                    for k, v in normalizations.items():
                        if v == value:
                            duplicates[k].append(normalization_id)
                            print(f"## MOD: {normalization_id} was replaced by {k} in {doc_key}")
                            break
                else:
                    normalizations[normalization_id] = value
                    
                new_norm_annotation = " ".join([normalization_src_id, reference_id])
                newline = "\t".join([normalization_id, new_norm_annotation, normalized_name])
                
                if reference_id in title_abs_tags:
                    title_abs_doc += newline + "\n"
                    normalization_per_doc_abs_stats[doc_key][ontology_src] += 1
                else:
                    result_doc += newline + "\n"
                    normalization_per_doc_result_stats[doc_key][ontology_src] += 1

    # Update 'Factual'
    valid_reference_ids = []
    for modal_id in modalities:
        valid_reference_ids.append(modalities[modal_id]["reference_id"])
    
    if modalities.keys():
        last_modal_int = int(sorted(modalities.keys(), key=lambda x: int(x[1:]))[-1][1:])
    else:
        last_modal_int = 0

    for trigger_id in relations:
        if trigger_id not in valid_reference_ids:
            valid_reference_ids.append(trigger_id)
            last_modal_int += 1
            modal_id = f"A{str(last_modal_int)}"
            modal_type = "Factual"
            value = {"type": modal_type, "reference_id": trigger_id}

            trigger_type = relations[trigger_id]["trigger_type"]
            if modal_type not in factuality_type_stats[trigger_type]:
                factuality_type_stats[trigger_type][modal_type] = 1
            else:
                factuality_type_stats[trigger_type][modal_type] += 1

            # check duplication
            if value in modalities.values():
                for k, v in modalities.items():
                    if v == value:
                        duplicates[k].append(modal_id)
                        print(f"## MOD: {modal_id} was replaced by {k} in {doc_key}")
                        break
            else:
                modalities[modal_id] = value

            new_modal_annotation = " ".join([modal_type, trigger_id])
            newline = "\t".join([modal_id, new_modal_annotation])

            if trigger_id in title_abs_tags:
                title_abs_doc += newline + "\n"
                factuality_per_doc_abs_stats[doc_key][modal_type] += 1
            else:
                result_doc += newline + "\n"
                factuality_per_doc_result_stats[doc_key][modal_type] += 1
                    
    title_abs_fname = doc_key + ".ann"
    with open(os.path.join(title_abs_path, title_abs_fname), "w", encoding="utf-8") as f:
        f.write(title_abs_doc)
        print(f"{title_abs_fname} was saved in Title-abs folder")

    if result_doc:
        result_fname = doc_key + "-RESULT.ann"   
        with open(os.path.join(results_path, result_fname), "w", encoding="utf-8") as f:
            f.write(result_doc)
            print(f"{result_fname} was saved in Result folder")
    
    print()


In [39]:
entity_stats = {k:v for k, v in entity_type_stats.items() if not k.isupper()}
trigger_stats = {k:v for k, v in entity_type_stats.items() if k.isupper()}

print(entity_stats)
print()
print(trigger_stats)

num_entity = sum(entity_stats.values())
num_trigger = sum(trigger_stats.values())
        
print()        
print(f"# of Entity: {num_entity}, # of Trigger: {num_trigger}")

{'Food': 1047, 'Metabolite': 1181, 'Population': 1235, 'Measurement': 814, 'Microorganism': 1938, 'Chemical': 1119, 'Methodology': 951, 'Physiology': 1716, 'Nutrient': 2045, 'Biospecimen': 307, 'Disease': 604, 'DietPattern': 851, 'Enzyme': 239, 'Gene': 173, 'DiversityMetric': 231}

{'IMPROVES': 179, 'WORSENS': 25, 'HAS_COMPONENT': 132, 'AFFECTS': 402, 'INCREASES': 596, 'DECREASES': 440, 'ASSOCIATED_WITH': 148, 'INTERACTS_WITH': 21, 'POS_ASSOCIATED_WITH': 145, 'PREDISPOSES': 15, 'CAUSES': 24, 'PREVENTS': 27, 'NEG_ASSOCIATED_WITH': 83}

# of Entity: 14451, # of Trigger: 2237


In [11]:
print(relation_type_stats)
print()        
print(f"# of Relation: {sum(relation_type_stats.values())}")

defaultdict(<class 'int'>, {'IMPROVES': 257, 'HAS_COMPONENT': 218, 'WORSENS': 36, 'AFFECTS': 942, 'INCREASES': 1128, 'DECREASES': 774, 'ASSOCIATED_WITH': 250, 'INTERACTS_WITH': 29, 'POS_ASSOCIATED_WITH': 282, 'PREDISPOSES': 31, 'CAUSES': 33, 'PREVENTS': 39, 'NEG_ASSOCIATED_WITH': 187})

# of Relation: 4206


In [12]:
print(factuality_type_stats)

defaultdict(<class 'dict'>, {'WORSENS': {'Negated': 6, 'Factual': 28, 'Unknown': 2}, 'AFFECTS': {'Negated': 422, 'Unknown': 283, 'Factual': 237}, 'IMPROVES': {'Factual': 205, 'Unknown': 44, 'Negated': 8}, 'HAS_COMPONENT': {'Factual': 215, 'Negated': 3}, 'INCREASES': {'Factual': 1117, 'Negated': 2, 'Unknown': 9}, 'DECREASES': {'Factual': 757, 'Negated': 9, 'Unknown': 8}, 'ASSOCIATED_WITH': {'Factual': 208, 'Negated': 21, 'Unknown': 21}, 'INTERACTS_WITH': {'Factual': 28, 'Unknown': 1}, 'POS_ASSOCIATED_WITH': {'Factual': 266, 'Negated': 2, 'Unknown': 14}, 'PREDISPOSES': {'Unknown': 6, 'Factual': 25}, 'CAUSES': {'Factual': 32, 'Negated': 1}, 'PREVENTS': {'Unknown': 15, 'Factual': 24}, 'NEG_ASSOCIATED_WITH': {'Factual': 187}})


In [13]:
factuality_count_dict = defaultdict(int)    
for rel_type, factualities in factuality_type_stats.items():
    for fact_type, cnt in factualities.items():
        factuality_count_dict[fact_type] += cnt
        
print(factuality_count_dict)
print(sum(factuality_count_dict.values()))

defaultdict(<class 'int'>, {'Negated': 474, 'Factual': 3329, 'Unknown': 403})
4206


In [14]:
print(normalization_stats)

defaultdict(<class 'dict'>, {'Food': {'MESH': 124, 'FOODON': 491, 'CHEBI': 12, 'NCIT': 215, 'OCHV': 50}, 'Metabolite': {'CHEBI': 1000, 'NCIT': 24, 'MESH': 73}, 'Microorganism': {'NCBITAXON': 1780, 'OCHV': 19, 'NCIT': 60, 'MESH': 5}, 'Chemical': {'NCIT': 486, 'CHEBI': 398, 'MESH': 148, 'OCHV': 28}, 'Physiology': {'NCIT': 658, 'OCHV': 156, 'MESH': 259, 'CHEBI': 10}, 'Nutrient': {'FOODON': 34, 'NCIT': 602, 'CHEBI': 732, 'MESH': 305, 'NCBITAXON': 12, 'OCHV': 6}, 'Measurement': {'NCIT': 426, 'OCHV': 39, 'MESH': 93, 'CHEBI': 1}, 'Disease': {'NCIT': 472, 'MESH': 119, 'OCHV': 4}, 'DietPattern': {'OCHV': 225, 'MESH': 119, 'NCIT': 182, 'CHEBI': 1, 'FOODON': 72}, 'Enzyme': {'NCIT': 77, 'MESH': 120, 'CHEBI': 1}, 'Gene': {'GENE': 121, 'NCIT': 10}, 'DiversityMetric': {'NCIT': 24}})


In [15]:
normalization_count_dict = defaultdict(int)    
for ent_type, sources in normalization_stats.items():
    for source, cnt in sources.items():
        normalization_count_dict[source] += cnt
        
print(normalization_count_dict)
print(sum(normalization_count_dict.values()))

defaultdict(<class 'int'>, {'MESH': 1365, 'FOODON': 597, 'CHEBI': 2155, 'NCIT': 3236, 'OCHV': 527, 'NCBITAXON': 1792, 'GENE': 121})
9793


In [16]:
print(disjoint_entity_stats, "\n")
print(f"% of disjoint entities: {round(sum(disjoint_entity_stats.values())/num_entity*100, 2)}%")

defaultdict(<class 'int'>, {'Food': 15, 'Physiology': 90, 'Measurement': 19, 'Microorganism': 7, 'Metabolite': 32, 'Nutrient': 18, 'DietPattern': 16, 'Gene': 7, 'DiversityMetric': 5, 'Enzyme': 4, 'AFFECTS': 1, 'Chemical': 6}) 

% of disjoint entities: 1.52%


In [17]:
n_entity_abs = 0
for k, v in entity_per_doc_abs_stats.items():
    n_entity_abs += sum(v.values())
    
n_entity_res = 0
for k, v in entity_per_doc_result_stats.items():
    n_entity_res += sum(v.values())
    
n_trigger_abs = 0
for k, v in trigger_per_doc_abs_stats.items():
    n_trigger_abs += sum(v.values())
    
n_trigger_res = 0
for k, v in trigger_per_doc_result_stats.items():
    n_trigger_res += sum(v.values())
    
n_relation_abs = 0
for k, v in relation_per_doc_abs_stats.items():
    n_relation_abs += sum(v.values())
    
n_relation_res = 0
for k, v in relation_per_doc_result_stats.items():
    n_relation_res += sum(v.values())
    
n_fact_abs_dict = defaultdict(int)
for k, v in factuality_per_doc_abs_stats.items():
    for k_, v_ in v.items():
        n_fact_abs_dict[k_] += v_
    
n_fact_res_dict = defaultdict(int)
for k, v in factuality_per_doc_result_stats.items():
    for k_, v_ in v.items():
        n_fact_res_dict[k_] += v_
    
n_norm_abs_dict = defaultdict(int)
for k, v in normalization_per_doc_abs_stats.items():
    for k_, v_ in v.items():
        n_norm_abs_dict[k_] += v_
    
n_norm_res_dict = defaultdict(int)
for k, v in normalization_per_doc_result_stats.items():
    for k_, v_ in v.items():
        n_norm_res_dict[k_] += v_
    
print(n_entity_abs, n_entity_res, n_trigger_abs, n_trigger_res, n_relation_abs, n_relation_res)
print()
print(n_fact_abs_dict, sum(n_fact_abs_dict.values()))
print(n_fact_res_dict, sum(n_fact_res_dict.values()))
print()
print(n_norm_abs_dict, sum(n_norm_abs_dict.values()))
print(n_norm_res_dict, sum(n_norm_res_dict.values()))

6388 5570 1536 701 2831 1375

defaultdict(<class 'int'>, {'Negated': 256, 'Unknown': 388, 'Factual': 2187}) 2831
defaultdict(<class 'int'>, {'Negated': 218, 'Factual': 1142, 'Unknown': 15}) 1375

defaultdict(<class 'int'>, {'MESH': 754, 'CHEBI': 1126, 'FOODON': 328, 'NCBITAXON': 724, 'NCIT': 1675, 'OCHV': 306, 'GENE': 34}) 4947
defaultdict(<class 'int'>, {'CHEBI': 1029, 'NCBITAXON': 1068, 'OCHV': 221, 'NCIT': 1561, 'MESH': 611, 'GENE': 87, 'FOODON': 269}) 4846


In [18]:
doc_abs_stats = {}
for k, v in relation_per_doc_abs_stats.items():
    doc_abs_stats[k] = v
    doc_abs_stats[k].update(entity_per_doc_abs_stats[k])
    doc_abs_stats[k].update(factuality_per_doc_abs_stats[k])

## Split dataset and store the filenames
- Stratify classes with 7:1:2

In [None]:
doc_count_per_ner_rel = {}
for label in relation_type_stats:
    doc_count_per_ner_rel[label] = defaultdict(int)
for label in entity_type_stats:
    if label not in IGNORED_CLASSES:
        doc_count_per_ner_rel[label] = defaultdict(int)
    
for doc_id, counts in doc_abs_stats.items():
    for label, cnt in counts.items():
        if label in ['Factual', 'Negated', 'Unknown']:
            continue
        doc_count_per_ner_rel[label][doc_id] = cnt

In [22]:
import random

def manual_split(inv_stats):

    predefined_split = {
        'train': [],
        'dev': ['19224658', '22099384', '23576043', '29665619', '30916575', 
                '34004416', '34143954', '34256014', '35654220', '36067589', 
                '36076452', 'PMC4994979', 'PMC5131798', 'PMC8839280'],
        'test': ['12004211', '20113315', '20591206', '25527750', '27052535', 
                 '31085979', '35613674', '35643872', '35748920', '35833889', 
                 'PMC3869907', 'PMC8942430', 'PMC9182596', 'PMC9183096', 'PMC9199182', 
                 'PMC9239261', 'PMC9279853'],
    }
    
    rare_labels = ['WORSENS', 'INTERACTS_WITH', 'PREDISPOSES', 'CAUSES', 'PREVENTS',
                  'Enzyme', 'Gene', 'DiversityMetric']
    
    for label in rare_labels:
        for fn in inv_stats[label]:
            if fn not in predefined_split['train'] + predefined_split['dev'] + predefined_split['test']:
                predefined_split['train'].append(fn)
    
    return predefined_split

def split_dataset(stats, inv_stats, result_stats, seed=42):
    
    predefined_split = manual_split(inv_stats)
    
    filenames = list(stats.keys())
    
    total_rel = 0
    for k, v in stats.items():
        total_rel += sum(v.values())
    
    train_rel = 0
    dev_rel = 0
    test_rel = 0
    train_res_rel = 0
    
    train_rel_dict = defaultdict(int)
    dev_rel_dict = defaultdict(int)
    test_rel_dict = defaultdict(int)
    result_rel_dict = defaultdict(int)
    
    split_fn = {"train":[], "dev":[], "test":[], "train_result":[]}
    for key, files in predefined_split.items():
        for fn in files:
            if key == "train":
                train_rel += sum(stats[fn].values())
                for k, v in stats[fn].items():
                    train_rel_dict[k] += v
            elif key == "dev":
                dev_rel += sum(stats[fn].values())
                for k, v in stats[fn].items():
                    dev_rel_dict[k] += v
            else:
                test_rel += sum(stats[fn].values())
                for k, v in stats[fn].items():
                    test_rel_dict[k] += v
            split_fn[key].append(fn)
            filenames.remove(fn)
    
#     print(f"## {len(filenames)} are left after processing predefined split for dev and test ##")
    
    random.seed(seed)
    random.shuffle(filenames)
    
    for fn in filenames:
        if dev_rel < total_rel*0.1:
            split_fn['dev'].append(fn)
            dev_rel += sum(stats[fn].values())
            for k, v in stats[fn].items():
                dev_rel_dict[k] += v
        elif test_rel < total_rel*0.2:
            split_fn['test'].append(fn)
            test_rel += sum(stats[fn].values())
            for k, v in stats[fn].items():
                test_rel_dict[k] += v
        else:
            split_fn['train'].append(fn)
            train_rel += sum(stats[fn].values())
            for k, v in stats[fn].items():
                train_rel_dict[k] += v
                
    train_rel_dict = dict(sorted([(k, v) for k, v in train_rel_dict.items()], key=lambda x: x[0]))
    dev_rel_dict = dict(sorted([(k, v) for k, v in dev_rel_dict.items()], key=lambda x: x[0]))
    test_rel_dict = dict(sorted([(k, v) for k, v in test_rel_dict.items()], key=lambda x: x[0]))
            
    # For RESULT part
    for fn, v in result_stats.items():
        if v:
            split_fn['train_result'].append(fn)
            train_res_rel += sum(result_stats[fn].values())
            for k_, v_ in result_stats[fn].items():
                result_rel_dict[k_] += v_
    
    result_rel_dict = dict(sorted([(k, v) for k, v in result_rel_dict.items()], key=lambda x: x[0]))

    return split_fn, train_rel_dict, dev_rel_dict, test_rel_dict, result_rel_dict

In [23]:
flag = True
for seed in range(5):
    split_filenames, train_rel_dict, dev_rel_dict, test_rel_dict, result_rel_dict = split_dataset(
        doc_abs_stats, doc_count_per_ner_rel, relation_per_doc_result_stats, seed=seed
    )
    try:
        for k in train_rel_dict:
            total = train_rel_dict[k]+dev_rel_dict[k]+test_rel_dict[k]
            if train_rel_dict[k]/total < 0.6 or dev_rel_dict[k]/total < 0.03:
                flag = False
                break
    except:
#         print(k)
        flag = False
        continue
    if flag:
        print("## All labels satisfy the threshold for dataset split ##")
        print(seed)
        break

0


In [24]:
for k, v in split_filenames.items():
    v.sort()

In [26]:
print(train_rel_dict, '\n')
print(dev_rel_dict, '\n')
print(test_rel_dict, '\n')

{'AFFECTS': 442, 'ASSOCIATED_WITH': 123, 'CAUSES': 21, 'Chemical': 343, 'DECREASES': 322, 'DietPattern': 355, 'Disease': 245, 'DiversityMetric': 42, 'Enzyme': 46, 'Factual': 1566, 'Food': 447, 'Gene': 48, 'HAS_COMPONENT': 107, 'IMPROVES': 145, 'INCREASES': 603, 'INTERACTS_WITH': 19, 'Measurement': 258, 'Metabolite': 468, 'Microorganism': 523, 'NEG_ASSOCIATED_WITH': 59, 'Negated': 176, 'Nutrient': 801, 'POS_ASSOCIATED_WITH': 107, 'PREDISPOSES': 18, 'PREVENTS': 28, 'Physiology': 771, 'Unknown': 268, 'WORSENS': 16} 

{'AFFECTS': 87, 'ASSOCIATED_WITH': 7, 'CAUSES': 3, 'Chemical': 70, 'DECREASES': 57, 'DietPattern': 42, 'Disease': 56, 'DiversityMetric': 7, 'Enzyme': 7, 'Factual': 203, 'Food': 83, 'Gene': 5, 'HAS_COMPONENT': 6, 'IMPROVES': 21, 'INCREASES': 71, 'INTERACTS_WITH': 2, 'Measurement': 31, 'Metabolite': 66, 'Microorganism': 92, 'NEG_ASSOCIATED_WITH': 15, 'Negated': 36, 'Nutrient': 120, 'POS_ASSOCIATED_WITH': 6, 'PREDISPOSES': 3, 'PREVENTS': 4, 'Physiology': 88, 'Unknown': 47, 'WORS

In [40]:
for k in train_rel_dict:
    total = train_rel_dict[k]+dev_rel_dict[k]+test_rel_dict[k]
    print(f"{k}({total}) >> TRAIN {train_rel_dict[k]}({train_rel_dict[k]*100/total:.2f}%) \
    | DEV {dev_rel_dict[k]}({dev_rel_dict[k]*100/total:.2f}%) | \
    TEST {test_rel_dict[k]}({test_rel_dict[k]*100/total:.2f}%)")

AFFECTS(636) >> TRAIN 442(69.50%)     | DEV 87(13.68%) |     TEST 107(16.82%)
ASSOCIATED_WITH(158) >> TRAIN 123(77.85%)     | DEV 7(4.43%) |     TEST 28(17.72%)
CAUSES(30) >> TRAIN 21(70.00%)     | DEV 3(10.00%) |     TEST 6(20.00%)
Chemical(507) >> TRAIN 343(67.65%)     | DEV 70(13.81%) |     TEST 94(18.54%)
DECREASES(480) >> TRAIN 322(67.08%)     | DEV 57(11.88%) |     TEST 101(21.04%)
DietPattern(446) >> TRAIN 355(79.60%)     | DEV 42(9.42%) |     TEST 49(10.99%)
Disease(403) >> TRAIN 245(60.79%)     | DEV 56(13.90%) |     TEST 102(25.31%)
DiversityMetric(61) >> TRAIN 42(68.85%)     | DEV 7(11.48%) |     TEST 12(19.67%)
Enzyme(69) >> TRAIN 46(66.67%)     | DEV 7(10.14%) |     TEST 16(23.19%)
Factual(2187) >> TRAIN 1566(71.60%)     | DEV 203(9.28%) |     TEST 418(19.11%)
Food(700) >> TRAIN 447(63.86%)     | DEV 83(11.86%) |     TEST 170(24.29%)
Gene(67) >> TRAIN 48(71.64%)     | DEV 5(7.46%) |     TEST 14(20.90%)
HAS_COMPONENT(154) >> TRAIN 107(69.48%)     | DEV 6(3.90%) |     TEST 4

In [None]:
# Save file names for split dataset
# Designate your own file name
split_filenames_dir = './split_filenames.json'
with open(split_filenames_dir, 'w') as json_file:
    json.dump(split_filenames, json_file, indent=4)