In [32]:
import pandas as pd
from pandas.errors import EmptyDataError
import numpy as np
import difflib as df
import spacy
from transformers import pipeline
import spacy_component
import glob
import operator
from collections import Counter

In [2]:
nlp = spacy.load('en_core_web_sm')

# Create Dict

In [3]:
def create_tagged_dict(tagged_token):
    tagged_dict = {} # creating a data dict using the tagged dataframe
    prev_sen = 0
    word_count = 0
    error_flag = 0
    for index, row in tagged_token.iterrows():
        current_sen = row["sent_index"]
        if prev_sen != current_sen:
            word_count = 0
        key = str(row["sent_index"]) + "_" + str(row["word_index"])
        tagged_dict[key] = [row["pos"], row["tag"], row["dep"], row["entity type"], row["entity mention ID"], str(row["text"])]
        word_count += 1
        prev_sen = current_sen
            
    return tagged_dict

# Sentence Entity

In [4]:
def sentence_entity(tagged_dict, sentence_index, sentence_nlp, enhanced_bool = False):
    
    entities = []
    entity = ""
    entity_index = ""
    prev_token = ""
    prev_token_index = ""
    indices = []
    entities_types = []
    key_str = []
    
    
    if enhanced_bool == True:
        for index, token in enumerate(sentence_nlp):

            if token.text == "''":
                token_text = "'"
            else:
                token_text = token.text

            key = str(sentence_index) + "_" + str(index)

            pos = tagged_dict[key][0]
            dep = str(tagged_dict[key][2])
            entity_type = tagged_dict[key][3]

            if not pd.isnull(entity_type):
                entity = entity + " " + token_text
                if entity_index != "":
                    entity_index = entity_index + "_" + str(index)
                else:
                    entity_index = str(index)

            elif pd.isnull(entity_type) and (dep == "compound" or dep.endswith("mod")):
                prev_token = prev_token + " " + token_text
                if prev_token_index != "":
                    prev_token_index = prev_token_index + "_" + str(index)
                else:
                    prev_token_index = str(index)

            elif pd.isnull(entity_type) and dep != "compound" and dep.endswith("mod") == False and entity == "":
                prev_token = ""
                prev_token_index = ""

            entity = entity.strip()

            if pd.isnull(entity_type) and entity != "":

                if prev_token != "":
                    entity = prev_token + " " + entity
                    entity_index = prev_token_index + "_" + entity_index

                entities.append(entity.strip())
                indices.append(entity_index)
                entities_types.append(entity_type.strip())
                entity = ""
                entity_index = ""
                prev_token = ""
                prev_token_index = ""
    else:
        for key in tagged_dict.keys():
            
            key_split = key.split("_")
            sentence_id = key_split[0]
            key_id = key_split[1]
            
            if str(sentence_id) == str(sentence_index):
                entity_type = tagged_dict[key][3]
                entity_id = tagged_dict[key][4]
                
                if not pd.isnull(entity_type):
                    
                    if len(indices) != 0 and indices[-1] == entity_id:
                        entities[-1] = entities[-1] + " " + tagged_dict[key][5]
                        key_str[-1] = key_str[-1] + "-" + str(key_id)
                        
                    else:
                        key_str.append(str(key_id))
                        entities.append(tagged_dict[key][5])
                        indices.append(entity_id)
                        entities_types.append(entity_type.strip())              
    return entities, indices, key_str, entities_types

# Raw Relationships

In [17]:
def find_relation(tagged_dict, sentence_index, entity_list):
    
    entities_relations = []
    
    sentence_entities = entity_list[1]
    
    entities_text = sentence_entities[0]
    entities_indices = sentence_entities[1]
    entities_key = sentence_entities[2]
    entities_type = sentence_entities[3]
    entities_len = len(entities_text)
    
    for index, entity_text in enumerate(entities_text):
        
        entity1 = ""
        entity1_index = ""
        entity2 = ""
        entity2_index = ""
        
        entity_pair_relation = ""
        entity_pair_relation_pos_pattern = ""
        
        if index == entities_len - 1:
            continue
        else:
            
            entity1 = entity_text
            entity1_index = entities_key[index].split('-')[-1]
            entity1_id = entities_indices[index]
            entity1_type = entities_type[index]
            
            entity2 = entities_text[index + 1]
            entity2_index = entities_key[index + 1].split('-')[0]
            entity2_id = entities_indices[index + 1]
            entity2_type = entities_type[index + 1]
            
            for i in range(int(entity1_index) + 1, int(entity2_index)):
                
                token_text = tagged_dict[str(sentence_index) + "_" + str(i)][5]
                
                entity_pair_relation = entity_pair_relation + " " + token_text
                
                if entity_pair_relation_pos_pattern != "" and entity_pair_relation_pos_pattern != 'PUNCT':
                    tmp = tagged_dict[str(sentence_index) + "_" + str(i)][0]
                    if tmp != 'PUNCT':
                        entity_pair_relation_pos_pattern = entity_pair_relation_pos_pattern + "-" + tagged_dict[str(sentence_index) + "_" + str(i)][0]
                else:
                    tmp = tagged_dict[str(sentence_index) + "_" + str(i)][0]
                    if tmp != 'PUNCT':
                        entity_pair_relation_pos_pattern = tagged_dict[str(sentence_index) + "_" + str(i)][0]
                    
        entity_pair_relation = entity1 + " " + entity_pair_relation.strip() + " " + entity2
        entity_pair_type = entity1_type + "-" + entity2_type
        entities_relations.append([[entity1, entity2], [entity1_id, entity2_id, entity_pair_type], [entity_pair_relation, entity_pair_relation_pos_pattern]])
    
    return entities_relations

# Global Snowball

In [35]:
def cal_pos_pattern_count(snowball_index, pos_pattern_list):
    for index, pos_pattern in enumerate(pos_pattern_list):
        if pos_pattern in snowball_index:
            snowball_index[pos_pattern] += 1
        elif pos_pattern != pos_pattern:
            pass
        else:
            snowball_index[pos_pattern] = 1

    return snowball_index

In [20]:
def generate_snowball(files):
    
    snowball_index = {}

    for file in files:
        try:
            file_tagged_truth = pd.read_csv(file, sep = "\t")
            if len(file_tagged_truth) != 0:
                pos_pattern_list = file_tagged_truth.pos_pattern.tolist()
                snowball_index = cal_pos_pattern_count(snowball_index, pos_pattern_list)
        except EmptyDataError:
            print(f"No columns to parse from file {file}")
            
    return snowball_index

In [21]:
def cal_rank(snowball_index):
    
    snowball_pos_pattern = {}
    total_index = 0
    
    for value in snowball_index.values():
        total_index += value
    
    for key, value in snowball_index.items():
        snowball_pos_pattern[key] = value/total_index
    
    return snowball_pos_pattern, total_index

In [22]:
def sort_dict(snowball_pos_pattern):
    return dict(sorted(snowball_pos_pattern.items(), key=operator.itemgetter(1), reverse=True))

In [23]:
def remove_out(snowball_pos_pattern):
    snowball_pos_pattern_values = np.array(list(snowball_pos_pattern.values()))
    snowball_pos_pattern_keys = list(snowball_pos_pattern.keys())
    
    mean = np.mean(snowball_pos_pattern_values, axis=0)
    sd = np.std(snowball_pos_pattern_values, axis=0)
    
    final_list = [x for x in snowball_pos_pattern_values if (x >= mean - 2 * sd)]
    final_list = [x for x in final_list if (x <= mean + 2 * sd)]
    
    snowball_pos_pattern_keys = snowball_pos_pattern_keys[:len(final_list)]
    
    snowball_pos_pattern_filter = dict(zip(snowball_pos_pattern_keys, final_list))
    
    return snowball_pos_pattern_filter

In [40]:
def re_rank_snowball(snowball_pos_pattern, snowball_key, total_index):
    
    for key in snowball_pos_pattern.keys():
        
        if snowball_key == key:
            snowball_pos_pattern[key] = ((snowball_pos_pattern[key]*(total_index)) + 1)/(total_index + 1)
        else:
            snowball_pos_pattern[key] = (snowball_pos_pattern[key]*(total_index))/(total_index + 1)
            
    return snowball_pos_pattern
    

In [41]:
def check_pos_pattern_and_update(snowball_pos_pattern, pos_pattern, total_index, re_compute = False):
    
    if pos_pattern in snowball_pos_pattern.keys():
        score = snowball_pos_pattern[pos_pattern]
        if re_compute == True:
            print("\n=====Matched with SnowBall -- Re-Computing Snowball=====\n")
            snowball_pos_pattern = re_rank_snowball(snowball_pos_pattern, pos_pattern, total_index)
        return score, snowball_pos_pattern
    else:
        return False, snowball_pos_pattern

# Main Script

In [6]:
train_annotated_0_truth = pd.read_csv("../Data/relex_processed_data/relex/docred/ground_truth/train_annotated_0.csv", sep = "\t")
train_annotated_0_tagged = pd.read_csv("../Data/relex_processed_data/relex/docred/tagged_tokens/train_annotated_0.csv", sep = "\t")

In [7]:
train_annotated_0_text = open("../Data/relex_processed_data/relex/docred/raw_text/train_annotated_0.txt", "r", encoding="utf8")

In [8]:
train_annotated_0_text_list = train_annotated_0_text.readlines()

In [24]:
train_annotated_0_text_list = [sentence for sentence in train_annotated_0_text_list if sentence != '\n']
train_annotated_0_text_list = [sentence.replace('\n', '') for sentence in train_annotated_0_text_list]

In [10]:
train_annotated_0_dict = create_tagged_dict(train_annotated_0_tagged)

In [11]:
sentences_entity_list = []

for index, sentence in enumerate(train_annotated_0_text_list):
    sentence_nlp = nlp(sentence)
    entities, indices, key_str, entities_types = sentence_entity(train_annotated_0_dict, index, sentence_nlp, False)
    sentences_entity_list.append([index, [entities, indices, key_str, entities_types]])

In [13]:
for index, entity_list in enumerate(sentences_entity_list):
    print(entity_list[0])
    print(train_annotated_0_text_list[entity_list[0]])
    print(entity_list[1])
    print("\n")

0
Zest Airways , Inc. operated as AirAsia Zest ( formerly Asian Spirit and Zest Air ) , was a low - cost airline based at the Ninoy Aquino International Airport in Pasay City , Metro Manila in the Philippines .
[['Zest Airways , Inc.', 'AirAsia Zest', 'Asian Spirit and Zest Air', 'Ninoy Aquino International Airport', 'Pasay City', 'Metro Manila', 'Philippines'], ['train_annotated_0_0_0', 'train_annotated_0_0_2', 'train_annotated_0_0_1', 'train_annotated_0_1_1', 'train_annotated_0_2_0', 'train_annotated_0_3_0', 'train_annotated_0_4_0'], ['0-1-2-3', '6-7', '10-11-12-13-14', '26-27-28-29', '31-32', '34-35', '38'], ['ORG', 'ORG', 'ORG', 'LOC', 'LOC', 'LOC', 'LOC']]


1
It operated scheduled domestic and international tourist services , mainly feeder services linking Manila and Cebu with 24 domestic destinations in support of the trunk route operations of other airlines .
[['Manila', 'Cebu', '24'], ['train_annotated_0_5_0', 'train_annotated_0_6_0', 'train_annotated_0_7_0'], ['53', '55', '57

## Generate Snowball

In [26]:
docred_truth_dir = "../Data/relex_processed_data/relex/docred/ground_truth/*.csv"
docred_truth_dir_files = glob.glob(docred_truth_dir)

In [36]:
snowball_index = generate_snowball(docred_truth_dir_files)
snowball_pos_pattern, total_index = cal_rank(snowball_index)
snowball_pos_pattern = sort_dict(snowball_pos_pattern)
snowball_pos_pattern_filter = remove_out(snowball_pos_pattern)

In [39]:
snowball_pos_pattern_filter

{'ADP': 0.003048780487804878,
 'PROPN': 0.0029563932002956393,
 'VERB': 0.0029563932002956393,
 'NOUN-ADP': 0.002771618625277162,
 'NOUN': 0.0025868440502586844,
 'ADP-DET': 0.0024944567627494456,
 'VERB-DET': 0.0024944567627494456,
 'PART': 0.002402069475240207,
 'NUM-SYM': 0.002402069475240207,
 'VERB-VERB-ADP': 0.002402069475240207,
 'CCONJ': 0.0022172949002217295,
 'ADP-PROPN': 0.0021249076127124907,
 'VERB-ADP': 0.0021249076127124907,
 'PROPN-PROPN': 0.0020325203252032522,
 'NUM-PROPN': 0.0019401330376940134,
 'NOUN-NOUN': 0.0019401330376940134,
 'VERB-DET-NOUN-ADP': 0.0018477457501847746,
 'PROPN-ADP': 0.001755358462675536,
 'DET-NOUN-ADP': 0.0016629711751662971,
 'PART-NOUN': 0.0016629711751662971,
 'ADP-DET-NOUN-ADP': 0.0016629711751662971,
 'ADP-PROPN-PROPN': 0.0015705838876570585,
 'NOUN-ADP-DET': 0.0014781966001478197,
 'NOUN-ADP-PROPN': 0.0014781966001478197,
 'VERB-NUM-ADP': 0.0014781966001478197,
 'DET': 0.001385809312638581,
 'VERB-DET-NOUN-ADP-DET': 0.001293422025129342

In [49]:
valid_relationships = []
valid_ids = []

for index, entity_list in enumerate(sentences_entity_list):
    print("==========")
    entities_relations = find_relation(train_annotated_0_dict, entity_list[0], entity_list)
    print(train_annotated_0_text_list[entity_list[0]])
    print(entity_list[1])
    print("------")
    for entity_relation in entities_relations:
        print(entity_relation)
        score, snowball_pos_pattern = check_pos_pattern_and_update(snowball_pos_pattern_filter, entity_relation[2][1], total_index)
        print("POS Pattern Score: ", score)
        if score != False:
            valid_relationships.append(entity_relation[2][0])
            valid_ids.append(entity_relation[1][:2])
    print("------")
    print("==========")
    print("\n")

Zest Airways , Inc. operated as AirAsia Zest ( formerly Asian Spirit and Zest Air ) , was a low - cost airline based at the Ninoy Aquino International Airport in Pasay City , Metro Manila in the Philippines .
[['Zest Airways , Inc.', 'AirAsia Zest', 'Asian Spirit and Zest Air', 'Ninoy Aquino International Airport', 'Pasay City', 'Metro Manila', 'Philippines'], ['train_annotated_0_0_0', 'train_annotated_0_0_2', 'train_annotated_0_0_1', 'train_annotated_0_1_1', 'train_annotated_0_2_0', 'train_annotated_0_3_0', 'train_annotated_0_4_0'], ['0-1-2-3', '6-7', '10-11-12-13-14', '26-27-28-29', '31-32', '34-35', '38'], ['ORG', 'ORG', 'ORG', 'LOC', 'LOC', 'LOC', 'LOC']]
------
[['Zest Airways , Inc.', 'AirAsia Zest'], ['train_annotated_0_0_0', 'train_annotated_0_0_2', 'ORG-ORG'], ['Zest Airways , Inc. operated as AirAsia Zest', 'VERB-ADP']]
POS Pattern Score:  0.0021249076127124907
[['AirAsia Zest', 'Asian Spirit and Zest Air'], ['train_annotated_0_0_2', 'train_annotated_0_0_1', 'ORG-ORG'], ['Air

# Classify Relationships

In [52]:
nlp_re = spacy.load('en_core_web_sm')
DEVICE = -1
nlp_re.add_pipe("rebel", config={
    'device':DEVICE, # Number of the GPU, -1 if want to use CPU
    'model_name':'Babelscape/rebel-large'} # Model used, will default to 'Babelscape/rebel-large' if not given
    )

<spacy_component.RebelComponent at 0x1c784d06eb0>

In [53]:
valid_rel = []
for index, relationships in enumerate(valid_relationships):
    doc = nlp_re(relationships)
    rel = []
    for value, rel_dict in doc._.rel.items():
        print(f"{relationships} : {value}: {rel_dict['relation']}")
        rel.append(rel_dict['relation'])
    valid_rel.append(rel)

Zest Airways , Inc. operated as AirAsia Zest : (6, 0): subsidiary
AirAsia Zest ( formerly Asian Spirit and Zest Air : (0, 4): replaces
AirAsia Zest ( formerly Asian Spirit and Zest Air : (4, 0): replaced by
Ninoy Aquino International Airport in Pasay City : (0, 5): place served by transport hub
Metro Manila in the Philippines : (0, 4): country
Metro Manila in the Philippines : (4, 0): contains administrative territorial entity
Manila and Cebu : (0, 2): twinned administrative body
Manila and Cebu : (2, 0): twinned administrative body
Cebu with 24 : (0, 2): population
Asian Spirit , the first airline in the Philippines : (0, 8): country
AirAsia and Zest Air : (0, 2): subsidiary
AirAsia and Zest Air : (2, 0): parent organization
AirAsia Philippines in January 2016 : (0, 3): inception


In [55]:
valid_filt_ids = []
valid_filt_rel = []
valid_filt_relationships = []
for index, rels in enumerate(valid_rel):
    for rel in enumerate(rels):
        print(valid_ids[index], rel[1])
        valid_filt_ids.append(valid_ids[index])
        valid_filt_rel.append(rel[1])
        valid_filt_relationships.append(valid_relationships[index])

['train_annotated_0_0_0', 'train_annotated_0_0_2'] subsidiary
['train_annotated_0_0_2', 'train_annotated_0_0_1'] replaces
['train_annotated_0_0_2', 'train_annotated_0_0_1'] replaced by
['train_annotated_0_1_1', 'train_annotated_0_2_0'] place served by transport hub
['train_annotated_0_3_0', 'train_annotated_0_4_0'] country
['train_annotated_0_3_0', 'train_annotated_0_4_0'] contains administrative territorial entity
['train_annotated_0_5_0', 'train_annotated_0_6_0'] twinned administrative body
['train_annotated_0_5_0', 'train_annotated_0_6_0'] twinned administrative body
['train_annotated_0_6_0', 'train_annotated_0_7_0'] population
['train_annotated_0_10_0', 'train_annotated_0_4_1'] country
['train_annotated_0_14_0', 'train_annotated_0_12_1'] subsidiary
['train_annotated_0_14_0', 'train_annotated_0_12_1'] parent organization
['train_annotated_0_15_0', 'train_annotated_0_16_0'] inception


# Benchmark

In [59]:
def sort_entity_ids(entity_ids):
    
    for ids in entity_ids:
        ids.sort()
        
    return entity_ids

In [60]:
def precision(true_positives, false_positives):
    print("Precision: ", true_positives/(true_positives + false_positives))
    return true_positives/(true_positives + false_positives)

In [61]:
def recall(true_positives, false_negatives):
    print("Recall: ", true_positives/(true_positives + false_negatives))
    return true_positives/(true_positives + false_negatives)

In [70]:
def benchmark(valid_relationships, valid_ids, valid_rel, ground_truth, rel_type_out):
    true_positives = 0
    false_positives = 0
    false_negatives = 0
    
    valid_relationships = list(map(str, valid_relationships))
    ground_truth = ground_truth[ground_truth["pos_pattern"].notna()]
    ground_truth_relationships = list(map(str, ground_truth["span"].tolist()))
    ground_truth_relationships_type = list(map(str, ground_truth["rel type"].tolist()))
    
    ground_truth_id = ground_truth[["entity 1 mention ID", "entity 2 mention ID"]].values.tolist()
    ground_truth_id = sort_entity_ids(ground_truth_id)
    valid_fixed_ids = []
    for ids in valid_ids: 
        valid_fixed_ids.append(ids[:2])
    valid_ids = sort_entity_ids(valid_fixed_ids)
    pred_rel = ''
    ground_truth_rel = ''
    
    print("\n")
    print("DETECTED RELATIONSHIPS: ")
    [print(relationships) for relationships in valid_relationships]
    print("\n")
    print("GROUND TRUTH RELATIONSHIPS: ")
    [print(relationships) for relationships in ground_truth_relationships]
    print("\n")

    pred_grd_rel_type = []
    for index, ids in enumerate(valid_ids):
        if ids in ground_truth_id:
            if len(valid_rel) != 0:
                id_index = ground_truth_id.index(ids)
                ground_truth_rel = ground_truth_relationships_type[id_index]
                pred_rel = valid_rel[index]
                pred_grd_rel_type.append([pred_rel, ground_truth_rel])
                if ground_truth_rel not in rel_type_out:
                    if ground_truth_rel == pred_rel:
                        print("\nMatched Relationship: %s with relationship: %s" %(valid_relationships[index], pred_rel))
                        true_positives += 1
                    else:
                        print("\nIndex matched for ''%s'' but relationship type not matched with PREDICTED as: %s and ACTUAL as: %s" %(valid_relationships[index], pred_rel, ground_truth_rel))
                        false_positives += 1
                else:
                    print("Ignoring less significant relationalship type ''%s''" %(ground_truth_rel))
            else:
                print("\nMatched Relationship: %s" %(valid_relationships[index]))
                true_positives += 1
        else:
            pred_grd_rel_type.append([valid_rel[index], 'NA'])
            false_positives += 1
    
    for index, ids in enumerate(ground_truth_id):
        if ids not in valid_ids:
            false_negatives += 1
            
    return true_positives, false_positives, false_negatives, pred_grd_rel_type

In [93]:
def benchmark_dataset(tagged_token_files, ground_truth_files, raw_text_files, snowball_pos_pattern, total_index, re_compute = False, relations = False):
    
    all_precision = []
    all_recall = []
    
    all_valid_relationships = []
    all_valid_ids = []
    all_valid_rel = []
    all_valid_ents = []
    all_valid_ents_type = []
    all_pred_grd_rel_type = []
    
    for index, file in enumerate(tagged_token_files):
        
        try: 
            print("==================Start of File %d==================" % index)
            print("Reading Files: " + raw_text_files[index] + "\n- " + file + "\n- " + ground_truth_files[index])
            print("\n")
            tagged_token = pd.read_csv(file, sep = "\t")
            ground_truth = pd.read_csv(ground_truth_files[index], sep = "\t")
            raw_text = open(raw_text_files[index], "r", encoding="utf8")
            raw_text_list = raw_text.readlines()
            raw_text_list = [sentence for sentence in raw_text_list if sentence != '\n']
            raw_text_list = [sentence.replace(' \n', '') for sentence in raw_text_list]
            
            tagged_token_dict = create_tagged_dict(tagged_token)

            sentences_entity_list = []

            for index, sentence in enumerate(raw_text_list):
                sentence_nlp = nlp(sentence)
                
                entities, indices, key_str, entities_types = sentence_entity(tagged_token_dict, index, sentence_nlp, False)
                # sent_roots, sent_roots_index = extract_root(tagged_token_dict, index, sentence_nlp)
                sentences_entity_list.append([index, [entities, indices, key_str, entities_types]])
                
            # sentences_entity_list = filter_sentences(sentences_entity_list)

            valid_relationships = []
            valid_ids = []
            valid_pos = []
            valid_rel = []
            valid_ents = []
            valid_ents_typ = []
            
            
            for index, entity_list in enumerate(sentences_entity_list):
                entities_relations = find_relation(tagged_token_dict, entity_list[0], entity_list)
                print("===============")
                print(raw_text_list[entity_list[0]])
                # print(entity_list[1])
                print("+++++")
                for entity_relation in entities_relations:
                    # print(entity_relation)
                    score, snowball_pos_pattern = check_pos_pattern_and_update(snowball_pos_pattern, entity_relation[2][1], total_index, re_compute)
                    # print("POS Pattern Score: ", score)
                    if score != False:
                        valid_ents.append([entity_relation[0][0], entity_relation[0][1]])
                        valid_ents_typ.append(entity_relation[1][2])
                        
                        valid_pos.append(entity_relation[2][1])
                        valid_ids.append(entity_relation[1][:2])
                        relationships = entity_relation[2][0]
                        valid_relationships.append(relationships)
                        if relations == True:
                            doc = nlp_re(relationships)
                            rel = []
                            for value, rel_dict in doc._.rel.items():
                                print(f"Detected relationship in '{relationships}' of category '{rel_dict['relation']}'")
                                rel.append(rel_dict['relation'])
                            valid_rel.append(rel)
                            
                print("===============\n")
            valid_filt_ids = []
            valid_filt_rel = []
            valid_filt_ent = []
            valid_filt_ent_typ = []
            valid_filt_relationships = []
            for index, rels in enumerate(valid_rel):
                for rel in enumerate(rels):
                    # print(valid_ids[index], rel[1])
                    valid_filt_ids.append(valid_ids[index])
                    valid_filt_rel.append(rel[1])
                    valid_filt_relationships.append(valid_relationships[index])
                    valid_filt_ent.append(valid_ents[index])
                    valid_filt_ent_typ.append(valid_ents_typ[index])
            

            if "span" in ground_truth:
                true_positives, false_positives, false_negatives, _ = benchmark(valid_filt_relationships, valid_filt_ids, valid_filt_rel, ground_truth[["rel type", "entity 1 mention ID", "entity 2 mention ID", "pos_pattern", "span"]], [])

                print("\n")
                print("BENCHMARK:")
                # print(len(pred_grd_rel_type))

                print("True Positives: %d" % true_positives)
                print("False Positives: %d" % false_positives)
                print("False Negatives: %d" % false_negatives)

                print("\n")
                
                all_valid_relationships.append(valid_filt_relationships)
                all_valid_ids.append(valid_filt_ids)
                all_valid_rel.append(valid_filt_rel)
                all_valid_ents.append(valid_filt_ent)
                all_valid_ents_type.append(valid_filt_ent_typ)
                # all_pred_grd_rel_type.append(pred_grd_rel_type)

                if((true_positives + false_positives) != 0):
                    all_precision.append(precision(true_positives, false_positives))
                else:
                     all_precision.append(0)

                if((true_positives + false_negatives) != 0):
                    all_recall.append(recall(true_positives, false_negatives))
                else:
                    all_recall.append(0)
                    
            else:
                print("\nCan't benchmark ground truth file empty")


            print("==================End of File==================")

            print("\n")
        
        except EmptyDataError:
            print("One of the csv file is empty")
        
    if(len(all_precision) > 0 and len(all_recall) > 0):
        
        print("==================Final Result==================")

        print("\n")

        print("Average Precision %f" %np.mean(all_precision))
        print("Average Recall %f" %np.mean(all_recall))
        
        return all_precision, all_recall, all_valid_relationships, all_valid_ids, all_valid_rel, all_valid_ents, all_valid_ents_type

        print("\n")
        
        print("===============================================")

In [66]:
true_positives, false_positives, false_negatives, _ = benchmark(valid_filt_relationships, valid_filt_ids, valid_filt_rel, train_annotated_0_truth[["rel type", "entity 1 mention ID", "entity 2 mention ID", "pos_pattern", "span"]], [])



DETECTED RELATIONSHIPS: 
Zest Airways , Inc. operated as AirAsia Zest
AirAsia Zest ( formerly Asian Spirit and Zest Air
AirAsia Zest ( formerly Asian Spirit and Zest Air
Ninoy Aquino International Airport in Pasay City
Metro Manila in the Philippines
Metro Manila in the Philippines
Manila and Cebu
Manila and Cebu
Cebu with 24
Asian Spirit , the first airline in the Philippines
AirAsia and Zest Air
AirAsia and Zest Air
AirAsia Philippines in January 2016


GROUND TRUTH RELATIONSHIPS: 
Pasay City , Metro Manila in the Philippines
Metro Manila in the Philippines
Asian Spirit , the first airline in the Philippines



Index matched for ''Metro Manila in the Philippines'' but relationship type not matched with PREDICTED as: country and ACTUAL as: contains administrative territorial entity

Matched Relationship: Metro Manila in the Philippines with relationship: contains administrative territorial entity

Matched Relationship: Asian Spirit , the first airline in the Philippines with relatio

In [67]:
precision(true_positives, false_positives)

Precision:  0.15384615384615385


0.15384615384615385

In [68]:
recall(true_positives, false_negatives)

Recall:  0.6666666666666666


0.6666666666666666

# Benchmarking docred dataset

In [69]:
tagged_token_files_docred = glob.glob("../Data/relex_processed_data/relex/docred/tagged_tokens/*")
ground_truth_files_docred = glob.glob("../Data/relex_processed_data/relex/docred/ground_truth/*")
raw_text_files_docred = glob.glob("../Data/relex_processed_data/relex/docred/raw_text/*")

## 60% Sampling of all files

In [73]:
tagged_token_files_docred_sample = tagged_token_files_docred[:int(len(tagged_token_files_docred) * .60)]
ground_truth_files_docred_sample = ground_truth_files_docred[:int(len(ground_truth_files_docred) * .60)]
raw_text_files_docred_sample = raw_text_files_docred[:int(len(raw_text_files_docred) * .60)]

## 40% Sampling of all files for testing

In [74]:
tagged_token_files_docred_sample_test = tagged_token_files_docred[int(len(tagged_token_files_docred) * .60):]
ground_truth_files_docred_sample_test = ground_truth_files_docred[int(len(ground_truth_files_docred) * .60):]
raw_text_files_docred_sample_test = raw_text_files_docred[int(len(raw_text_files_docred) * .60):]

## 70% sampled SnowBall from ground truth

In [75]:
docred_file_dir_files = glob.glob("../Data/relex_processed_data/relex/docred/ground_truth/*.csv")

In [76]:
docred_file_dir_files_sample = docred_file_dir_files[:int(len(docred_file_dir_files) * .70)]
snowball_index_docred = generate_snowball(docred_file_dir_files_sample)
snowball_pos_pattern_docred, docred_total_index  = cal_rank(snowball_index_docred)
snowball_pos_pattern_docred = sort_dict(snowball_pos_pattern_docred)
snowball_pos_pattern_docred_filter = remove_out(snowball_pos_pattern_docred)

## Benchmarking on Trained

In [94]:
all_precision_docred, all_recall_docred, all_valid_relationships, all_valid_ids, all_valid_rel, all_valid_ents, all_valid_ents_type = benchmark_dataset(tagged_token_files_docred_sample, ground_truth_files_docred_sample, raw_text_files_docred_sample, snowball_pos_pattern_docred_filter, docred_total_index, False, True)

Reading Files: ../Data/relex_processed_data/relex/docred/raw_text\train_annotated_0.txt
- ../Data/relex_processed_data/relex/docred/tagged_tokens\train_annotated_0.csv
- ../Data/relex_processed_data/relex/docred/ground_truth\train_annotated_0.csv


Zest Airways , Inc. operated as AirAsia Zest ( formerly Asian Spirit and Zest Air ) , was a low - cost airline based at the Ninoy Aquino International Airport in Pasay City , Metro Manila in the Philippines .

+++++
Detected relationship in 'Zest Airways , Inc. operated as AirAsia Zest' of category 'subsidiary'
Detected relationship in 'AirAsia Zest ( formerly Asian Spirit and Zest Air' of category 'replaces'
Detected relationship in 'AirAsia Zest ( formerly Asian Spirit and Zest Air' of category 'replaced by'
Detected relationship in 'Ninoy Aquino International Airport in Pasay City' of category 'place served by transport hub'
Detected relationship in 'Metro Manila in the Philippines' of category 'country'
Detected relationship in 'Metro Ma

KeyboardInterrupt: 

In [None]:
all_precision_docred

In [None]:
all_recall_docred