In [1]:
import os

import pandas as pd

import dist_sup_lib.utils as utils
import dist_sup_lib.rel_ext as rel_ext

from dist_sup_lib.rel_ext import RelSetup

from src.utils import read_json_examples
from src.utils import read_kb_triples
from src.utils import read_kb_triples_json

from src.rel_extract_extend.data import DatasetExt
from src.rel_extract_extend.kfold import make_kfold_val

from src.rel_extract_extend.featurizers import start_bag_of_words_featurizer
from src.rel_extract_extend.featurizers import middle_bag_of_words_featurizer
from src.rel_extract_extend.featurizers import end_bag_of_words_featurizer

In [2]:
utils.fix_random_seeds()

rel_ext_data_sents = os.path.join("data", "featurized_sentences")
rel_ext_data_kb = os.path.join("data", "knowledge_base")

# The sentence files are named after the PubMed release they were built from,
# and that name changes every year: files downloaded in 2021 are named
# "featurized_sents_pubmed21n....json", so change this prefix to match the
# files on disk.
PUBMED_PREFIX = "pubmed20n"
# File indices to load (1..199 inclusive, as in the original run).
FILE_INDICES = range(1, 200)

example_data = []
for index in FILE_INDICES:
    # The index is zero-padded to four digits, e.g. 7 -> "0007".
    tagged_sent_file = f"featurized_sents_{PUBMED_PREFIX}{index:04d}.json"
    file_path = os.path.join(rel_ext_data_sents, tagged_sent_file)
    example_data.extend(read_json_examples(file_path))

kb_triples = read_kb_triples_json(
    os.path.join(rel_ext_data_kb, "rel_drug_react_triple_occ_all.json")
)
kb = rel_ext.KB(kb_triples)

corpus = rel_ext.Corpus(example_data)
dataset = DatasetExt(corpus, kb)

In [3]:
# Number of cross-validation folds.
k = 5

# Bag-of-words features built from three parts of each example
# (start, middle and end, per the featurizer names).
featurizers = [
    start_bag_of_words_featurizer,
    middle_bag_of_words_featurizer,
    end_bag_of_words_featurizer,
]

# avg_results=False keeps per-fold results: `results` comes back as a list
# with one {relation: [precision, recall, f-score, support, size]} dict per fold.
# sampling_rate is presumably the fraction of examples sampled per fold
# (TODO confirm in make_kfold_val).
results, train_setups, test_setups = make_kfold_val(
    dataset,
    featurizers,
    k=k,
    avg_results=False,
    sampling_rate=0.1,
)

{'0': Corpus with 135,453 examples; KB with 10,700 triples, '1': Corpus with 146,804 examples; KB with 9,650 triples, '2': Corpus with 108,424 examples; KB with 11,019 triples, '3': Corpus with 249,476 examples; KB with 10,295 triples, '4': Corpus with 244,064 examples; KB with 13,403 triples, 'all': Corpus with 884,221 examples; KB with 55,067 triples}




relation              precision     recall    f-score    support       size
------------------    ---------  ---------  ---------  ---------  ---------
acquired                  1.000      1.000      1.000          3       2133
aggravated                0.753      0.761      0.757         92       2222
altered                   1.000      0.043      0.083         23       2153
caused                    0.826      0.979      0.896       8924      11054
changed                   0.875      0.609      0.718         23       2153
decreased                 0.592      0.140      0.227        734       2864
delayed                   0.900      0.692      0.783         13       2143
discoloured               0.815      0.786      0.800         28       2158
impaired                  0.650      0.867      0.743         15       2145
increased                 0.670      0.101      0.175        765       2895
infected                  0.000      0.000      0.000         12       2142
lowered     

  _warn_prf(average, modifier, msg_start, len(result))


relation              precision     recall    f-score    support       size
------------------    ---------  ---------  ---------  ---------  ---------
acquired                  0.500      0.500      0.500          2       2121
aggravated                0.458      0.155      0.232         71       2190
altered                   0.750      0.167      0.273         18       2137
caused                    0.813      0.978      0.888       8117      10236
changed                   0.800      0.706      0.750         17       2136
decreased                 0.747      0.170      0.277        659       2778
delayed                   0.800      1.000      0.889          8       2127
discoloured               0.778      0.609      0.683         23       2142
impaired                  0.750      0.643      0.692         14       2133
increased                 0.671      0.149      0.243        659       2778
infected                  0.000      0.000      0.000          3       2122
prolonged   

In [4]:
# Raw per-fold results: a list with one {relation: [p, r, f, support, size]} dict per fold.
print(repr(results))

[{'acquired': [1.0, 1.0, 1.0, 3, 2133], 'aggravated': [0.7526881720430108, 0.7608695652173914, 0.7567567567567567, 92, 2222], 'altered': [1.0, 0.043478260869565216, 0.08333333333333333, 23, 2153], 'caused': [0.8256143667296786, 0.9788211564320932, 0.8957136997538966, 8924, 11054], 'changed': [0.875, 0.6086956521739131, 0.717948717948718, 23, 2153], 'decreased': [0.5919540229885057, 0.14032697547683923, 0.22687224669603523, 734, 2864], 'delayed': [0.9, 0.6923076923076923, 0.7826086956521738, 13, 2143], 'discoloured': [0.8148148148148148, 0.7857142857142857, 0.7999999999999999, 28, 2158], 'impaired': [0.65, 0.8666666666666667, 0.7428571428571429, 15, 2145], 'increased': [0.6695652173913044, 0.10065359477124183, 0.17500000000000002, 765, 2895], 'infected': [0.0, 0.0, 0.0, 12, 2142], 'lowered': [0.0, 0.0, 0.0, 1, 2131], 'prolonged': [0.9285714285714286, 0.6842105263157895, 0.7878787878787878, 19, 2149], 'reduced': [0.8571428571428571, 0.8571428571428571, 0.8571428571428571, 14, 2144], 'rup

In [5]:
columns = ["precision", "recall", "f-score", "support", "size"]

# Fill value for every column when a relation is absent from a fold's results.
# This happens when the fold has no test examples for that relation (e.g.
# "accelerated" is absent from fold 0, where its test setup is (0, 0)).
# The default of 0 reproduces the original behaviour but drags the
# per-relation means down. Set it to float("nan") to make pandas' mean()
# skip such folds; a relation missing from *every* fold then averages to NaN.
MISSING_FOLD_VALUE = 0

total_results = {}
for key in kb.all_relations:
    # One row per fold for this relation; missing or empty entries are filled.
    rel_res = [
        part_result.get(key) or [MISSING_FOLD_VALUE] * len(columns)
        for part_result in results
    ]
    total_results[key] = pd.DataFrame(data=rel_res, columns=columns)

In [6]:
# Average precision and recall over the folds for each relation.
# The F-score is the harmonic mean of those *averaged* precision/recall
# values, not the mean of the per-fold "f-score" column. It is 0 when either
# average is 0, which also avoids dividing by zero.
avg_vals = {}
for rel, rel_frame in total_results.items():
    avg_precision = rel_frame["precision"].mean()
    avg_recall = rel_frame["recall"].mean()
    if avg_precision and avg_recall:
        avg_fscore = 2 * avg_precision * avg_recall / (avg_precision + avg_recall)
    else:
        avg_fscore = 0
    avg_vals[rel] = {
        "avg_precision": avg_precision,
        "avg_recall": avg_recall,
        "avg_fscore": avg_fscore,
    }

In [7]:
# Parallel lists of the per-relation averages, in avg_vals order.
prec_vals = [scores["avg_precision"] for scores in avg_vals.values()]
rec_vals = [scores["avg_recall"] for scores in avg_vals.values()]
fscore_vals = [scores["avg_fscore"] for scores in avg_vals.values()]

# How many relations score exactly 0 on each metric.
prec_zeros = prec_vals.count(0)
rec_zeros = rec_vals.count(0)
fscore_zeros = fscore_vals.count(0)

In [8]:
def _macro_mean(values, n_excluded=0):
    """Mean of `values` over ``len(values) - n_excluded`` entries.

    Zero-valued entries add nothing to the sum, so passing their count as
    `n_excluded` yields the mean over the non-zero entries only.
    Returns NaN (instead of raising ZeroDivisionError) when no entries remain.
    """
    denominator = len(values) - n_excluded
    return sum(values) / denominator if denominator > 0 else float("nan")


# Macro average over every relation, zero scores included.
macro_avg = {
    "precision": _macro_mean(prec_vals),
    "recall": _macro_mean(rec_vals),
    "fscore": _macro_mean(fscore_vals),
}

# Macro average over the relations with a non-zero score. Each metric
# excludes its *own* zero count; previously the smallest of the three counts
# was applied to all metrics.
macro_avg_zeros_excluded = {
    "precision": _macro_mean(prec_vals, prec_zeros),
    "recall": _macro_mean(rec_vals, rec_zeros),
    "fscore": _macro_mean(fscore_vals, fscore_zeros),
}

In [9]:
# Macro averages over all relations (zero scores included).
print(repr(macro_avg))

{'precision': 0.6073301135431264, 'recall': 0.4275845011611931, 'fscore': 0.48081226211361716}


In [10]:
# Per-relation fold-averaged precision (same order as avg_vals).
list(prec_vals)

[0.0,
 0.9,
 0.6770306129747514,
 0.9400000000000001,
 0.8268978422232838,
 0.853961038961039,
 0.6700084836038062,
 0.8955555555555555,
 0.8097069243156201,
 0.7532467532467533,
 0.0,
 0.6611344144284833,
 0.2,
 0.0,
 0.8338180008845644,
 0.8847619047619049,
 0.2,
 0.8258205128205128]

In [11]:
# Macro averages with zero-scoring relations left out of the denominator.
print(repr(macro_avg_zeros_excluded))

{'precision': 0.7287961362517517, 'recall': 0.5131014013934317, 'fscore': 0.5769747145363406}


In [12]:
# Plain-text summary table: one row per relation, then the two macro averages.
header_fmt = '{:20s} {:>10s} {:>10s} {:>10s}'
row_fmt = '{:20s} {:10.3f} {:10.3f} {:10.3f}'
separator = header_fmt.format('-' * 18, '-' * 9, '-' * 9, '-' * 9)

print(header_fmt.format('relation', 'precision', 'recall', 'f-score'))
print(separator)
# total_results and the *_vals lists share the same relation order.
for rel, precision, recall, fscore in zip(total_results, prec_vals, rec_vals, fscore_vals):
    print(row_fmt.format(rel, precision, recall, fscore))
print(separator)
for label, averages in (("macro avg", macro_avg),
                        ("macro avg (-0 vals)", macro_avg_zeros_excluded)):
    print(row_fmt.format(label, averages["precision"], averages["recall"], averages["fscore"]))

relation              precision     recall    f-score
------------------    ---------  ---------  ---------
accelerated               0.000      0.000      0.000
acquired                  0.900      0.833      0.865
aggravated                0.677      0.527      0.593
altered                   0.940      0.214      0.349
caused                    0.827      0.969      0.892
changed                   0.854      0.591      0.699
decreased                 0.670      0.285      0.400
delayed                   0.896      0.723      0.800
discoloured               0.810      0.630      0.709
impaired                  0.753      0.677      0.713
improved                  0.000      0.000      0.000
increased                 0.661      0.157      0.253
infected                  0.200      0.040      0.067
lowered                   0.000      0.000      0.000
prolonged                 0.834      0.622      0.713
reduced                   0.885      0.721      0.795
ruptured                  0.

In [13]:
all_relations = set(dataset.kb.all_relations)

print("number setups:", len(train_setups))
print("#################")

# Pad each fold's training setup so it lists every KB relation. A relation
# missing from a fold gets a RelSetup(relation, 0, 0) placeholder. The lists
# are modified in place and sorted, so every fold has the relations in the
# same order. Re-running the cell adds nothing, because the padding is
# already present.
for i, setup in enumerate(train_setups):
    print("-------------------------")
    present = {rel_setup.relation for rel_setup in setup}
    setup.extend(RelSetup(relation, 0, 0) for relation in all_relations - present)
    setup.sort()
    # len(setup) is the number of relations in this fold's setup.
    print(f"setup {i}, number relations:", len(setup))

number setups: 5
#################
-------------------------
setup 0, number setups: 18
-------------------------
setup 1, number setups: 18
-------------------------
setup 2, number setups: 18
-------------------------
setup 3, number setups: 18
-------------------------
setup 4, number setups: 18


In [14]:
# One row per fold. After the padding and sorting in the previous cell,
# column j holds the same relation in every fold. Each cell shows a RelSetup
# tuple such as (caused, 37103, 8154); its fields are presumably
# (relation, positive examples, negative examples), so confirm against RelSetup.
# The unused `columns` list that used to be assigned here was dropped: it was
# never passed to the DataFrame and it clobbered the metrics `columns` list.
train_setup_table = pd.DataFrame(data=train_setups)
train_setup_table

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,"(accelerated, 1, 8154)","(acquired, 23, 8154)","(aggravated, 344, 8154)","(altered, 103, 8154)","(caused, 37103, 8154)","(changed, 72, 8154)","(decreased, 3048, 8154)","(delayed, 41, 8154)","(discoloured, 126, 8154)","(impaired, 72, 8154)","(improved, 1, 8154)","(increased, 3112, 8154)","(infected, 20, 8154)","(lowered, 1, 8154)","(prolonged, 96, 8154)","(reduced, 67, 8154)","(ruptured, 14, 8154)","(shortened, 123, 8154)"
1,"(accelerated, 1, 8168)","(acquired, 24, 8168)","(aggravated, 365, 8168)","(altered, 108, 8168)","(caused, 37910, 8168)","(changed, 78, 8168)","(decreased, 3123, 8168)","(delayed, 46, 8168)","(discoloured, 131, 8168)","(impaired, 73, 8168)","(improved, 1, 8168)","(increased, 3218, 8168)","(infected, 29, 8168)","(lowered, 2, 8168)","(prolonged, 98, 8168)","(reduced, 65, 8168)","(ruptured, 17, 8168)","(shortened, 128, 8168)"
2,"(accelerated, 1, 7993)","(acquired, 25, 7993)","(aggravated, 345, 7993)","(altered, 98, 7993)","(caused, 36851, 7993)","(changed, 76, 7993)","(decreased, 3015, 7993)","(delayed, 45, 7993)","(discoloured, 113, 7993)","(impaired, 62, 7993)","(improved, 0, 0)","(increased, 3096, 7993)","(infected, 29, 7993)","(lowered, 2, 7993)","(prolonged, 90, 7993)","(reduced, 65, 7993)","(ruptured, 18, 7993)","(shortened, 117, 7993)"
3,"(accelerated, 1, 8209)","(acquired, 21, 8209)","(aggravated, 338, 8209)","(altered, 96, 8209)","(caused, 37492, 8209)","(changed, 76, 8209)","(decreased, 3053, 8209)","(delayed, 43, 8209)","(discoloured, 126, 8209)","(impaired, 70, 8209)","(improved, 1, 8209)","(increased, 3127, 8209)","(infected, 28, 8209)","(lowered, 2, 8209)","(prolonged, 95, 8209)","(reduced, 66, 8209)","(ruptured, 17, 8209)","(shortened, 120, 8209)"
4,"(accelerated, 0, 0)","(acquired, 11, 8202)","(aggravated, 352, 8202)","(altered, 99, 8202)","(caused, 34752, 8202)","(changed, 78, 8202)","(decreased, 2889, 8202)","(delayed, 41, 8202)","(discoloured, 120, 8202)","(impaired, 71, 8202)","(improved, 1, 8202)","(increased, 2955, 8202)","(infected, 22, 8202)","(lowered, 1, 8202)","(prolonged, 81, 8202)","(reduced, 61, 8202)","(ruptured, 10, 8202)","(shortened, 120, 8202)"


In [15]:
all_relations = set(dataset.kb.all_relations)

print("number setups:", len(test_setups))
print("#################")

# Pad each fold's test setup so it lists every KB relation. A relation
# missing from a fold gets a RelSetup(relation, 0, 0) placeholder. The lists
# are modified in place and sorted, so every fold has the relations in the
# same order. Re-running the cell adds nothing, because the padding is
# already present.
for i, setup in enumerate(test_setups):
    print("-------------------------")
    present = {rel_setup.relation for rel_setup in setup}
    setup.extend(RelSetup(relation, 0, 0) for relation in all_relations - present)
    setup.sort()
    # len(setup) is the number of relations in this fold's setup.
    print(f"setup {i}, number relations:", len(setup))

number setups: 5
#################
-------------------------
setup 0, number setups: 18
-------------------------
setup 1, number setups: 18
-------------------------
setup 2, number setups: 18
-------------------------
setup 3, number setups: 18
-------------------------
setup 4, number setups: 18


In [16]:
# One row per fold. After the padding and sorting in the previous cell,
# column j holds the same relation in every fold. Each cell shows a RelSetup
# tuple such as (caused, 8924, 2130); its fields are presumably
# (relation, positive examples, negative examples), so confirm against RelSetup.
# The unused `columns` list that used to be assigned here was dropped: it was
# never passed to the DataFrame and it clobbered the metrics `columns` list.
test_setup_table = pd.DataFrame(data=test_setups)
test_setup_table

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,"(accelerated, 0, 0)","(acquired, 3, 2130)","(aggravated, 92, 2130)","(altered, 23, 2130)","(caused, 8924, 2130)","(changed, 23, 2130)","(decreased, 734, 2130)","(delayed, 13, 2130)","(discoloured, 28, 2130)","(impaired, 15, 2130)","(improved, 0, 0)","(increased, 765, 2130)","(infected, 12, 2130)","(lowered, 1, 2130)","(prolonged, 19, 2130)","(reduced, 14, 2130)","(ruptured, 5, 2130)","(shortened, 29, 2130)"
1,"(accelerated, 0, 0)","(acquired, 2, 2119)","(aggravated, 71, 2119)","(altered, 18, 2119)","(caused, 8117, 2119)","(changed, 17, 2119)","(decreased, 659, 2119)","(delayed, 8, 2119)","(discoloured, 23, 2119)","(impaired, 14, 2119)","(improved, 0, 0)","(increased, 659, 2119)","(infected, 3, 2119)","(lowered, 0, 0)","(prolonged, 17, 2119)","(reduced, 16, 2119)","(ruptured, 2, 2119)","(shortened, 24, 2119)"
2,"(accelerated, 0, 0)","(acquired, 1, 2288)","(aggravated, 91, 2288)","(altered, 28, 2288)","(caused, 9176, 2288)","(changed, 19, 2288)","(decreased, 767, 2288)","(delayed, 9, 2288)","(discoloured, 41, 2288)","(impaired, 25, 2288)","(improved, 1, 2288)","(increased, 781, 2288)","(infected, 3, 2288)","(lowered, 0, 0)","(prolonged, 25, 2288)","(reduced, 16, 2288)","(ruptured, 1, 2288)","(shortened, 35, 2288)"
3,"(accelerated, 0, 0)","(acquired, 5, 2066)","(aggravated, 98, 2066)","(altered, 30, 2066)","(caused, 8535, 2066)","(changed, 19, 2066)","(decreased, 729, 2066)","(delayed, 11, 2066)","(discoloured, 28, 2066)","(impaired, 17, 2066)","(improved, 0, 0)","(increased, 750, 2066)","(infected, 4, 2066)","(lowered, 0, 0)","(prolonged, 20, 2066)","(reduced, 15, 2066)","(ruptured, 2, 2066)","(shortened, 32, 2066)"
4,"(accelerated, 1, 2079)","(acquired, 15, 2079)","(aggravated, 84, 2079)","(altered, 27, 2079)","(caused, 11275, 2079)","(changed, 17, 2079)","(decreased, 893, 2079)","(delayed, 13, 2079)","(discoloured, 34, 2079)","(impaired, 16, 2079)","(improved, 0, 0)","(increased, 922, 2079)","(infected, 10, 2079)","(lowered, 1, 2079)","(prolonged, 34, 2079)","(reduced, 20, 2079)","(ruptured, 9, 2079)","(shortened, 32, 2079)"
