In [1]:
import os

import pandas as pd

import dist_sup_lib.utils as utils
import dist_sup_lib.rel_ext as rel_ext

from dist_sup_lib.rel_ext import RelSetup

from src.utils import read_json_examples
from src.utils import read_kb_triples
from src.utils import read_kb_triples_json

from src.rel_extract_extend.data import DatasetExt
from src.rel_extract_extend.kfold import make_kfold_val

from src.rel_extract_extend.featurizers import start_bag_of_words_featurizer
from src.rel_extract_extend.featurizers import middle_bag_of_words_featurizer
from src.rel_extract_extend.featurizers import end_bag_of_words_featurizer

In [2]:
utils.fix_random_seeds()
rel_ext_data_sents = os.path.join('data', 'featurized_sentences')
rel_ext_data_kb = os.path.join("data", "knowledge_base")

# The featurized sentence files are updated constantly and their names change
# by year: for files downloaded in 2021 set the prefix to
# "featurized_sents_pubmed21n". Keeping the prefix and the index range here
# means a new download only requires editing these three values.
PUBMED_FILE_PREFIX = "featurized_sents_pubmed20n"
FIRST_FILE_INDEX = 1
LAST_FILE_INDEX = 199  # inclusive

example_data = []

for index in range(FIRST_FILE_INDEX, LAST_FILE_INDEX + 1):
    # File names carry the index zero-padded to four digits (e.g. ...n0001.json).
    tagged_sent_file = f"{PUBMED_FILE_PREFIX}{index:04d}.json"
    file_path = os.path.join(rel_ext_data_sents, tagged_sent_file)
    example_data.extend(read_json_examples(file_path))

kb_triples = read_kb_triples_json(os.path.join(rel_ext_data_kb, "rel_drug_react_triple_occ_all.json"))
kb = rel_ext.KB(kb_triples)

# Combine the loaded examples and the knowledge base into the dataset object
# consumed by the cross-validation cell.
corpus = rel_ext.Corpus(example_data)
dataset = DatasetExt(corpus, kb)

In [3]:
k = 5  # value forwarded as the `k` argument of make_kfold_val

# Featurizers handed to make_kfold_val, in the same order as before.
bow_featurizers = [
    start_bag_of_words_featurizer,
    middle_bag_of_words_featurizer,
    end_bag_of_words_featurizer,
]

# avg_results=False: later cells iterate over `results` and aggregate it
# themselves.
results, train_setups, test_setups = make_kfold_val(
    dataset,
    bow_featurizers,
    k=k,
    sampling_rate=0.5,
    avg_results=False,
)

{'0': Corpus with 135,453 examples; KB with 10,700 triples, '1': Corpus with 146,804 examples; KB with 9,650 triples, '2': Corpus with 108,424 examples; KB with 11,019 triples, '3': Corpus with 249,476 examples; KB with 10,295 triples, '4': Corpus with 244,064 examples; KB with 13,403 triples, 'all': Corpus with 884,221 examples; KB with 55,067 triples}




relation              precision     recall    f-score    support       size
------------------    ---------  ---------  ---------  ---------  ---------
acquired                  1.000      1.000      1.000          3      10655
aggravated                0.343      0.130      0.189         92      10744
altered                   0.000      0.000      0.000         23      10675
caused                    0.651      0.868      0.744       8924      19576
changed                   0.684      0.565      0.619         23      10675
decreased                 0.462      0.033      0.061        734      11386
delayed                   1.000      0.692      0.818         13      10665
discoloured               0.250      0.071      0.111         28      10680
impaired                  0.167      0.067      0.095         15      10667
increased                 0.686      0.031      0.060        765      11417
infected                  0.000      0.000      0.000         12      10664
lowered     

  _warn_prf(average, modifier, msg_start, len(result))


relation              precision     recall    f-score    support       size
------------------    ---------  ---------  ---------  ---------  ---------
acquired                  1.000      0.500      0.667          2      10601
aggravated                0.360      0.127      0.188         71      10670
altered                   0.250      0.056      0.091         18      10617
caused                    0.627      0.867      0.727       8117      18716
changed                   0.706      0.706      0.706         17      10616
decreased                 0.684      0.079      0.141        659      11258
delayed                   1.000      1.000      1.000          8      10607
discoloured               0.583      0.609      0.596         23      10622
impaired                  0.600      0.643      0.621         14      10613
increased                 0.627      0.071      0.128        659      11258
infected                  0.000      0.000      0.000          3      10602
prolonged   

In [4]:
# Dump the raw `results` object returned by make_kfold_val (avg_results=False).
print(results)

[{'acquired': [1.0, 1.0, 1.0, 3, 10655], 'aggravated': [0.34285714285714286, 0.13043478260869565, 0.1889763779527559, 92, 10744], 'altered': [0.0, 0.0, 0.0, 23, 10675], 'caused': [0.6510670475550328, 0.8683325862841775, 0.7441659464131374, 8924, 19576], 'changed': [0.6842105263157895, 0.5652173913043478, 0.6190476190476191, 23, 10675], 'decreased': [0.46153846153846156, 0.0326975476839237, 0.061068702290076333, 734, 11386], 'delayed': [1.0, 0.6923076923076923, 0.8181818181818181, 13, 10665], 'discoloured': [0.25, 0.07142857142857142, 0.11111111111111112, 28, 10680], 'impaired': [0.16666666666666666, 0.06666666666666667, 0.09523809523809522, 15, 10667], 'increased': [0.6857142857142857, 0.03137254901960784, 0.06, 765, 11417], 'infected': [0.0, 0.0, 0.0, 12, 10664], 'lowered': [0.0, 0.0, 0.0, 1, 10653], 'prolonged': [0.0, 0.0, 0.0, 19, 10671], 'reduced': [0.625, 0.7142857142857143, 0.6666666666666666, 14, 10666], 'ruptured': [0.0, 0.0, 0.0, 5, 10657], 'shortened': [0.0, 0.0, 0.0, 29, 106

In [5]:
columns = ["precision", "recall", "f-score", "support", "size"]

# Build one DataFrame per KB relation, with one row per entry of `results`.
# An entry that has no (or an empty) value for the relation contributes a row
# of five zeros.
total_results = {}
for key in kb.all_relations:
    rows = [part_result.get(key) or [0] * 5 for part_result in results]
    total_results[key] = pd.DataFrame(data=rows, columns=columns)

In [6]:
# Per-relation averages over the rows of each DataFrame in `total_results`.
# The F-score is derived from the averaged precision and recall (it is not the
# mean of the per-row f-scores) and is 0 when either average is 0.
avg_vals = {}
for rel, rel_frame in total_results.items():
    avg_precision = rel_frame["precision"].mean()
    avg_recall = rel_frame["recall"].mean()

    if avg_precision and avg_recall:
        avg_fscore = 2 * avg_precision * avg_recall/(avg_precision + avg_recall)
    else:
        avg_fscore = 0

    avg_vals[rel] = {
        "avg_precision": avg_precision,
        "avg_recall": avg_recall,
        "avg_fscore": avg_fscore,
    }

In [7]:
# Flatten the per-relation averages into three parallel lists that follow the
# iteration order of `avg_vals`.
prec_vals = [entry["avg_precision"] for entry in avg_vals.values()]
rec_vals = [entry["avg_recall"] for entry in avg_vals.values()]
fscore_vals = [entry["avg_fscore"] for entry in avg_vals.values()]

# Number of relations whose averaged metric is exactly zero.
prec_zeros = sum(1 for value in prec_vals if value == 0)
rec_zeros = sum(1 for value in rec_vals if value == 0)
fscore_zeros = sum(1 for value in fscore_vals if value == 0)

In [8]:
# The smallest of the three zero counts is used as the number of relations to
# leave out of the "zeros excluded" denominators.
zero_vals = min(prec_zeros, rec_zeros, fscore_zeros)

metric_vals = {"precision": prec_vals, "recall": rec_vals, "fscore": fscore_vals}

# Plain macro average: unweighted mean over all relations.
macro_avg = {
    metric: sum(vals)/len(vals)
    for metric, vals in metric_vals.items()
}

# Same sums, but divided by the relation count reduced by `zero_vals`.
macro_avg_zeros_excluded = {
    metric: sum(vals)/(len(vals) - zero_vals)
    for metric, vals in metric_vals.items()
}

In [9]:
# Macro averages of precision, recall and f-score over all relations.
print(macro_avg)

{'precision': 0.4376116779395039, 'recall': 0.30366229100376135, 'fscore': 0.33433616386200227}


In [10]:
# Inspect the averaged precision per relation (same order as `avg_vals`).
prec_vals

[0.0,
 1.0,
 0.3764855144855145,
 0.22999999999999998,
 0.6609287757061668,
 0.7845120823448998,
 0.5532025528716983,
 0.96,
 0.41174055829228245,
 0.47523809523809524,
 0.0,
 0.6076182091325159,
 0.0,
 0.0,
 0.37366310160427807,
 0.6767948717948717,
 0.2,
 0.5668264414407472]

In [11]:
# Macro averages whose denominators omit `zero_vals` relations.
print(macro_avg_zeros_excluded)

{'precision': 0.5626435859222193, 'recall': 0.39042294557626456, 'fscore': 0.4298607821082886}


In [12]:
# Fixed-width layouts: one for textual rows (header, separators), one for
# numeric rows (per-relation metrics and macro averages).
text_row = '{:20s} {:>10s} {:>10s} {:>10s}'
number_row = '{:20s} {:10.3f} {:10.3f} {:10.3f}'
separator = text_row.format('-' * 18, '-' * 9, '-' * 9, '-' * 9)

print(text_row.format('relation', 'precision', 'recall', 'f-score'))
print(separator)

# The metric lists are parallel to the keys of `total_results`.
for rel, precision, recall, fscore in zip(total_results.keys(), prec_vals, rec_vals, fscore_vals):
    print(number_row.format(rel, precision, recall, fscore))

print(separator)
for label, averages in (
    ("macro avg", macro_avg),
    ("macro avg (-0 vals)", macro_avg_zeros_excluded),
):
    print(number_row.format(
        label, averages["precision"], averages["recall"], averages["fscore"]))

relation              precision     recall    f-score
------------------    ---------  ---------  ---------
accelerated               0.000      0.000      0.000
acquired                  1.000      0.833      0.909
aggravated                0.376      0.117      0.179
altered                   0.230      0.047      0.079
caused                    0.661      0.843      0.741
changed                   0.785      0.572      0.662
decreased                 0.553      0.077      0.136
delayed                   0.960      0.723      0.825
discoloured               0.412      0.322      0.361
impaired                  0.475      0.394      0.431
improved                  0.000      0.000      0.000
increased                 0.608      0.063      0.115
infected                  0.000      0.000      0.000
lowered                   0.000      0.000      0.000
prolonged                 0.374      0.252      0.301
reduced                   0.677      0.680      0.679
ruptured                  0.

In [13]:
all_relations = set(dataset.kb.all_relations)

print("number setups:", len(train_setups))
print("#################")

# Give every train setup an entry for each KB relation: relations a setup
# lacks are added as RelSetup(relation, 0, 0), then the setup is sorted in
# place so all setups list the relations in the same order.
for i, setup in enumerate(train_setups):
    print("-------------------------")
    present_relations = {entry.relation for entry in setup}
    setup.extend(
        RelSetup(missing, 0, 0)
        for missing in all_relations - present_relations
    )
    setup.sort()
    print(f"setup {i}, number setups:", len(setup))

number setups: 5
#################
-------------------------
setup 0, number setups: 18
-------------------------
setup 1, number setups: 18
-------------------------
setup 2, number setups: 18
-------------------------
setup 3, number setups: 18
-------------------------
setup 4, number setups: 18


In [14]:
# NOTE(review): `columns` is assigned here but not passed to the DataFrame, so
# the table below gets default integer column labels — confirm whether these
# names were meant to be used.
columns = ["relation", "pos_examples", "neg_examples"]

# One row per train setup, one cell per entry of that setup.
train_setup_table = pd.DataFrame(train_setups)
train_setup_table

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,"(accelerated, 1, 40770)","(acquired, 23, 40770)","(aggravated, 344, 40770)","(altered, 103, 40770)","(caused, 37103, 40770)","(changed, 72, 40770)","(decreased, 3048, 40770)","(delayed, 41, 40770)","(discoloured, 126, 40770)","(impaired, 72, 40770)","(improved, 1, 40770)","(increased, 3112, 40770)","(infected, 20, 40770)","(lowered, 1, 40770)","(prolonged, 96, 40770)","(reduced, 67, 40770)","(ruptured, 14, 40770)","(shortened, 123, 40770)"
1,"(accelerated, 1, 40843)","(acquired, 24, 40843)","(aggravated, 365, 40843)","(altered, 108, 40843)","(caused, 37910, 40843)","(changed, 78, 40843)","(decreased, 3123, 40843)","(delayed, 46, 40843)","(discoloured, 131, 40843)","(impaired, 73, 40843)","(improved, 1, 40843)","(increased, 3218, 40843)","(infected, 29, 40843)","(lowered, 2, 40843)","(prolonged, 98, 40843)","(reduced, 65, 40843)","(ruptured, 17, 40843)","(shortened, 128, 40843)"
2,"(accelerated, 1, 39965)","(acquired, 25, 39965)","(aggravated, 345, 39965)","(altered, 98, 39965)","(caused, 36851, 39965)","(changed, 76, 39965)","(decreased, 3015, 39965)","(delayed, 45, 39965)","(discoloured, 113, 39965)","(impaired, 62, 39965)","(improved, 0, 0)","(increased, 3096, 39965)","(infected, 29, 39965)","(lowered, 2, 39965)","(prolonged, 90, 39965)","(reduced, 65, 39965)","(ruptured, 18, 39965)","(shortened, 117, 39965)"
3,"(accelerated, 1, 41048)","(acquired, 21, 41048)","(aggravated, 338, 41048)","(altered, 96, 41048)","(caused, 37492, 41048)","(changed, 76, 41048)","(decreased, 3053, 41048)","(delayed, 43, 41048)","(discoloured, 126, 41048)","(impaired, 70, 41048)","(improved, 1, 41048)","(increased, 3127, 41048)","(infected, 28, 41048)","(lowered, 2, 41048)","(prolonged, 95, 41048)","(reduced, 66, 41048)","(ruptured, 17, 41048)","(shortened, 120, 41048)"
4,"(accelerated, 0, 0)","(acquired, 11, 41013)","(aggravated, 352, 41013)","(altered, 99, 41013)","(caused, 34752, 41013)","(changed, 78, 41013)","(decreased, 2889, 41013)","(delayed, 41, 41013)","(discoloured, 120, 41013)","(impaired, 71, 41013)","(improved, 1, 41013)","(increased, 2955, 41013)","(infected, 22, 41013)","(lowered, 1, 41013)","(prolonged, 81, 41013)","(reduced, 61, 41013)","(ruptured, 10, 41013)","(shortened, 120, 41013)"


In [15]:
all_relations = set(dataset.kb.all_relations)

print("number setups:", len(test_setups))
print("#################")

# Give every test setup an entry for each KB relation: relations a setup
# lacks are added as RelSetup(relation, 0, 0), then the setup is sorted in
# place so all setups list the relations in the same order.
for i, setup in enumerate(test_setups):
    print("-------------------------")
    present_relations = {entry.relation for entry in setup}
    setup.extend(
        RelSetup(missing, 0, 0)
        for missing in all_relations - present_relations
    )
    setup.sort()
    print(f"setup {i}, number setups:", len(setup))

number setups: 5
#################
-------------------------
setup 0, number setups: 18
-------------------------
setup 1, number setups: 18
-------------------------
setup 2, number setups: 18
-------------------------
setup 3, number setups: 18
-------------------------
setup 4, number setups: 18


In [16]:
# NOTE(review): `columns` is assigned here but not passed to the DataFrame, so
# the table below gets default integer column labels — confirm whether these
# names were meant to be used.
columns = ["relation", "pos_examples", "neg_examples"]

# One row per test setup, one cell per entry of that setup.
test_setup_table = pd.DataFrame(test_setups)
test_setup_table

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,"(accelerated, 0, 0)","(acquired, 3, 10652)","(aggravated, 92, 10652)","(altered, 23, 10652)","(caused, 8924, 10652)","(changed, 23, 10652)","(decreased, 734, 10652)","(delayed, 13, 10652)","(discoloured, 28, 10652)","(impaired, 15, 10652)","(improved, 0, 0)","(increased, 765, 10652)","(infected, 12, 10652)","(lowered, 1, 10652)","(prolonged, 19, 10652)","(reduced, 14, 10652)","(ruptured, 5, 10652)","(shortened, 29, 10652)"
1,"(accelerated, 0, 0)","(acquired, 2, 10599)","(aggravated, 71, 10599)","(altered, 18, 10599)","(caused, 8117, 10599)","(changed, 17, 10599)","(decreased, 659, 10599)","(delayed, 8, 10599)","(discoloured, 23, 10599)","(impaired, 14, 10599)","(improved, 0, 0)","(increased, 659, 10599)","(infected, 3, 10599)","(lowered, 0, 0)","(prolonged, 17, 10599)","(reduced, 16, 10599)","(ruptured, 2, 10599)","(shortened, 24, 10599)"
2,"(accelerated, 0, 0)","(acquired, 1, 11444)","(aggravated, 91, 11444)","(altered, 28, 11444)","(caused, 9176, 11444)","(changed, 19, 11444)","(decreased, 767, 11444)","(delayed, 9, 11444)","(discoloured, 41, 11444)","(impaired, 25, 11444)","(improved, 1, 11444)","(increased, 781, 11444)","(infected, 3, 11444)","(lowered, 0, 0)","(prolonged, 25, 11444)","(reduced, 16, 11444)","(ruptured, 1, 11444)","(shortened, 35, 11444)"
3,"(accelerated, 0, 0)","(acquired, 5, 10334)","(aggravated, 98, 10334)","(altered, 30, 10334)","(caused, 8535, 10334)","(changed, 19, 10334)","(decreased, 729, 10334)","(delayed, 11, 10334)","(discoloured, 28, 10334)","(impaired, 17, 10334)","(improved, 0, 0)","(increased, 750, 10334)","(infected, 4, 10334)","(lowered, 0, 0)","(prolonged, 20, 10334)","(reduced, 15, 10334)","(ruptured, 2, 10334)","(shortened, 32, 10334)"
4,"(accelerated, 1, 10398)","(acquired, 15, 10398)","(aggravated, 84, 10398)","(altered, 27, 10398)","(caused, 11275, 10398)","(changed, 17, 10398)","(decreased, 893, 10398)","(delayed, 13, 10398)","(discoloured, 34, 10398)","(impaired, 16, 10398)","(improved, 0, 0)","(increased, 922, 10398)","(infected, 10, 10398)","(lowered, 1, 10398)","(prolonged, 34, 10398)","(reduced, 20, 10398)","(ruptured, 9, 10398)","(shortened, 32, 10398)"
