In [1]:
import os
import random

from collections import Counter
from collections import defaultdict

import dist_sup_lib.rel_ext as rel_ext
import dist_sup_lib.utils as utils

from src.utils import read_json_examples
from src.utils import read_kb_triples
from src.utils import read_kb_triples_json

In [2]:
# Set all the random seeds for reproducibility. Only the
# system seed is relevant for this notebook.

utils.fix_random_seeds()

In [3]:
rel_ext_data_sents = os.path.join('data', 'featurized_sentences')
rel_ext_data_kb = os.path.join("data", "knowledge_base")

In [4]:
example_data = []

In [5]:
for index in range(1, 200):
    s_i = str(index)
    zeros = "0" * (4 - len(s_i))
    # Files updated constantly and the names changes by year
    # for files downloaded in 2021 file_naming has to be changed to
    # featurized_sents_pubmed21n
    tagged_sent_file = f"featurized_sents_pubmed20n{zeros + s_i}.json"
    file_path = os.path.join(rel_ext_data_sents, tagged_sent_file)
    example_data.extend(read_json_examples(file_path))

In [6]:
# read data
# example_data = read_examples(os.path.join(rel_ext_data_sents, "tagged_medline_sents_files_1-20.tsv"))

In [6]:
print(type(example_data))
print(example_data[:3])
print(len(example_data))

<class 'list'>
[Example(entity_1='weight', entity_2='sodium', left='Determinations of the molecular ', mention_1='weight', middle=' of the enzyme based on its amino acid composition, sedimentation velocity, and ', mention_2='sodium', right=' dodecyl sulfate gel electrophoresis gave values of 17680, 17470 and 18300, respectively.', left_POS='Determinations/NNS of/IN the/DT molecular/JJ', mention_1_POS='weight/NN', middle_POS='of/IN the/DT enzyme/NN based/VBN on/IN its/PRP$ amino/NN acid/NN composition/NN ,/, sedimentation/NN velocity/NN ,/, and/CC', mention_2_POS='sodium/NN', right_POS='dodecyl/NN sulfate/NN gel/JJ electrophoresis/NN gave/VBD values/NNS of/IN 17680/CD ,/, 17470/CD and/CC 18300/CD ,/, respectively/RB ./.'), Example(entity_1='growth', entity_2='neostigmine', left='The enzymes from both species are inhibited by the anti-cholinesterases neostigmine, physostigmine, and 284c51 and by AMO-1618, a plant ', mention_1='growth', middle='', mention_2='neostigmine', right=', physost

In [7]:
corpus = rel_ext.Corpus(example_data)

In [8]:
kb_triples = read_kb_triples_json(os.path.join(rel_ext_data_kb, "rel_drug_react_triple_occ_all.json"))
kb = rel_ext.KB(kb_triples)

In [9]:
dataset = rel_ext.Dataset(corpus, kb)

In [10]:
splits = dataset.build_splits()

splits

{'tiny': Corpus with 1,428 examples; KB with 674 triples,
 'train': Corpus with 671,169 examples; KB with 35,548 triples,
 'dev': Corpus with 211,624 examples; KB with 16,678 triples,
 'all': Corpus with 884,221 examples; KB with 52,900 triples}

### Negative instances

By joining the corpus to the KB, we can obtain abundant positive instances for each relation. But a classifier cannot be trained on positive instances alone. In order to apply the distant supervision paradigm, we will also need some negative instances — that is, entity pairs which do not belong to any known relation. If you like, you can think of these entity pairs as being assigned to a special relation called `NO_RELATION`. We can find plenty of such pairs by searching for examples in the corpus which contain two entities which do not belong to any relation in the KB.

In [11]:
dataset.find_unrelated_pairs()

{('atypical_fracture', 'concomitant'),
 ('cardiac', 'hand_deformity'),
 ('pentoxifylline', 'infertility'),
 ('seminoma', 'gallium'),
 ('phenytoin', 'myotonia'),
 ('vasopressin', 'fluid_replacement'),
 ('piroxicam', 'colorectal_cancer'),
 ('potassium', 'renin'),
 ('caesarean_section', 'pancuronium'),
 ('verapamil', 'colorectal_cancer'),
 ('polydipsia', 'adriamycin'),
 ('saliva', 'phencyclidine'),
 ('cardiac', 'convalescent'),
 ('alfentanil', 'colonoscopy'),
 ('energy', 'mitoxantrone'),
 ('basal', 'skin_swelling'),
 ('hyperoxaluria', 'furosemide'),
 ('dermatosis', 'ammonium'),
 ('concomitant', 'mammoplasty'),
 ('chemotherapy', 'chromium'),
 ('medium', 'thyroglobulin'),
 ('droperidol', 'infusion'),
 ('carboplatin', 'nephropathy'),
 ('streptococcal_endocarditis', 'clostridium'),
 ('proparacaine', 'growth'),
 ('anaphylactic_shock', 'polypeptide'),
 ('monosodium', 'vitamin_a_deficiency'),
 ('immunoglobulin', 'performance_status'),
 ('pityriasis', 'potassium'),
 ('massage', 'acetazolamide'),


Let's determine how many examples we have for each triple in the KB. We'll compute averages per relation.

In [12]:
dataset.count_examples()

                                             examples
relation               examples    triples    /triple
--------               --------    -------    -------
accelerated                  44          1      44.00
acquired                     23         25       0.92
aggravated                 2226        421       5.29
altered                    1014        120       8.45
caused                   220957      44243       4.99
changed                     142         92       1.54
decreased                 56771       3628      15.65
delayed                      39         47       0.83
discoloured                 487        149       3.27
impaired                    332         77       4.31
improved                      0          1       0.00
increased                 63399       3720      17.04
infected                     99         29       3.41
lowered                       1          2       0.50
prolonged                   365        112       3.26
reduced                     

In [13]:
print(dir(dataset))

['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'build_dataset', 'build_splits', 'corpus', 'count_examples', 'count_relation_combinations', 'featurize', 'find_unrelated_pairs', 'kb']


In [14]:
for key, val in dataset.kb.kb_triples_by_relation.items():
    print(key, len(val))

accelerated 1
acquired 25
aggravated 421
altered 120
caused 44243
changed 92
decreased 3628
delayed 47
discoloured 149
impaired 77
improved 1
increased 3720
infected 29
lowered 2
prolonged 112
reduced 73
ruptured 17
shortened 143


In [15]:
for key, val in dataset.kb.kb_triples_by_relation.items():
    print(key, len(val))

accelerated 1
acquired 25
aggravated 421
altered 120
caused 44243
changed 92
decreased 3628
delayed 47
discoloured 149
impaired 77
improved 1
increased 3720
infected 29
lowered 2
prolonged 112
reduced 73
ruptured 17
shortened 143


In [16]:
print(len(dataset.corpus.examples_by_entities))

4454


In [17]:
print(dataset.corpus.examples_by_entities.keys())

dict_keys(['weight', 'growth', 'dependence', 'fluoride', 'cardiac', 'cardiac_output', 'anaesthesia', 'fall', 'stress', 'adrenalectomy', 'gastric_ulcer', 'ulcer', 'sodium', 'acidosis', 'metabolic_acidosis', 'extracorporeal_circulation', 'investigation', 'influenza', 'hypertension', 'minoxidil', 'fluid_retention', 'death', 'infection', 'infusion', 'pco2', 'epinephrine', 'glucocorticoids', 'mass', 'parathyroidectomy', 'surgery', 'lorazepam', 'dissociation', 'vasospasm', 'tachycardia', 'hypothermia', 'renin', 'injection', 'histology', 'necrosis', 'dialysis', 'haemodialysis', 'laparotomy', 'intranasal', 'double-blind', 'penicillin', 'schizophrenia', 'l-dopa', 'blunted_affect', 'anxiety', 'shock', 'psychotherapy', 'blood_pressure', 'heart_rate', 'propranolol', 'blood_ph', 'hypotension', 'hypertrophy', 'cardiac_arrest', 'manganese', 'hypoxia', 'pyelonephritis', 'alkalosis', 'respiratory_acidosis', 'ischaemia', 'body_temperature', 'enzyme_activity', 'tension', 'myocardial_ischaemia', 'sedation

In [18]:
dataset.count_relation_combinations()

The most common relation combinations are:
      1737 ('decreased', 'increased')
       102 ('caused', 'decreased', 'increased')
        51 ('caused', 'increased')
        48 ('discoloured', 'increased')
        39 ('caused', 'decreased')
        30 ('prolonged', 'shortened')
        30 ('decreased', 'shortened')
        29 ('caused', 'infected')
        26 ('changed', 'decreased', 'shortened')
        15 ('changed', 'decreased', 'increased')
        15 ('caused', 'ruptured')
        12 ('changed', 'decreased')
         9 ('changed', 'decreased', 'delayed', 'shortened')
         8 ('changed', 'decreased', 'increased', 'shortened')
         7 ('changed', 'decreased', 'delayed', 'increased', 'shortened')
         7 ('acquired', 'caused')
         5 ('changed', 'decreased', 'delayed', 'increased', 'prolonged', 'shortened')
         3 ('decreased', 'delayed', 'shortened')
         3 ('decreased', 'delayed')
         2 ('decreased', 'increased', 'shortened')
         2 ('changed', 'decrease

### Featurizers

Featurizers are functions which define the feature representation for our model. The primary input to a featurizer will be the `KBTriple` for which we are generating features. But since our features will be derived from corpus examples containing the entities of the `KBTriple`, we must also pass in a reference to a `Corpus`. And in order to make it easy to combine different featurizers, we'll also pass in a feature counter to hold the results.

Here's an implementation for a very simple bag-of-words featurizer. It finds all the corpus examples containing the two entities in the `KBTriple`, breaks the phrase appearing between the two entity mentions into words, and counts the words. Note that it makes no distinction between "forward" and "reverse" examples.


In [19]:
def simple_bag_of_words_featurizer(kbt, corpus, feature_counter):
    for ex in corpus.get_examples_for_entities(kbt.sbj, kbt.obj):
        for word in ex.middle.split(' '):
            feature_counter[word] += 1
    for ex in corpus.get_examples_for_entities(kbt.obj, kbt.sbj):
        for word in ex.middle.split(' '):
            feature_counter[word] += 1
    return feature_counter

def count_words(sent_part: str, feature_counter: Counter):
    for word in sent_part.split(" "):
        feature_counter[word] += 1

def middle_bag_of_words_featurizer(kbt, corpus, feature_counter):
    for ex in corpus.get_examples_for_entities(kbt.sbj, kbt.obj):
        count_words(ex.middle, feature_counter)
    for ex in corpus.get_examples_for_entities(kbt.obj, kbt.sbj):
        count_words(ex.middle, feature_counter)
    return feature_counter

def start_bag_of_words_featurizer(kbt, corpus, feature_counter):
    for ex in corpus.get_examples_for_entities(kbt.sbj, kbt.obj):
        count_words(ex.left, feature_counter)
    for ex in corpus.get_examples_for_entities(kbt.obj, kbt.sbj):
        count_words(ex.left, feature_counter)
    return feature_counter

def end_bag_of_words_featurizer(kbt, corpus, feature_counter):
    for ex in corpus.get_examples_for_entities(kbt.sbj, kbt.obj):
        count_words(ex.right, feature_counter)
    for ex in corpus.get_examples_for_entities(kbt.obj, kbt.sbj):
        count_words(ex.right, feature_counter)
    return feature_counter
    

In [20]:
kbt = kb.kb_triples[5]

kbt

KBTriple(rel='acquired', sbj='metformin', obj='lipodystrophy')

In [21]:
corpus.get_examples_for_entities(kbt.sbj, kbt.obj)

[]

In [22]:
print(len(corpus.examples_by_entities))

4454


In [23]:
print(corpus.examples_by_entities.keys())

dict_keys(['weight', 'growth', 'dependence', 'fluoride', 'cardiac', 'cardiac_output', 'anaesthesia', 'fall', 'stress', 'adrenalectomy', 'gastric_ulcer', 'ulcer', 'sodium', 'acidosis', 'metabolic_acidosis', 'extracorporeal_circulation', 'investigation', 'influenza', 'hypertension', 'minoxidil', 'fluid_retention', 'death', 'infection', 'infusion', 'pco2', 'epinephrine', 'glucocorticoids', 'mass', 'parathyroidectomy', 'surgery', 'lorazepam', 'dissociation', 'vasospasm', 'tachycardia', 'hypothermia', 'renin', 'injection', 'histology', 'necrosis', 'dialysis', 'haemodialysis', 'laparotomy', 'intranasal', 'double-blind', 'penicillin', 'schizophrenia', 'l-dopa', 'blunted_affect', 'anxiety', 'shock', 'psychotherapy', 'blood_pressure', 'heart_rate', 'propranolol', 'blood_ph', 'hypotension', 'hypertrophy', 'cardiac_arrest', 'manganese', 'hypoxia', 'pyelonephritis', 'alkalosis', 'respiratory_acidosis', 'ischaemia', 'body_temperature', 'enzyme_activity', 'tension', 'myocardial_ischaemia', 'sedation

In [24]:
print(len(corpus.examples_by_entities["sodium"]))

617


In [26]:
simple_bag_of_words_featurizer(kb.kb_triples[0], corpus, Counter())

Counter({'': 61,
         'treatment': 4,
         'by': 5,
         'daily': 2,
         'injections': 2,
         'did': 3,
         'not': 3,
         'suppress': 2,
         'levels': 3,
         'of': 12,
         'growth': 2,
         'hormone,': 1,
         'prolactin,': 1,
         'or': 2,
         'epidermal': 1,
         'in': 5,
         'this': 1,
         'tumor': 2,
         'model': 1,
         'had': 1,
         'no': 2,
         'direct': 1,
         'inhibitory': 2,
         'effect': 2,
         'and': 9,
         'cause': 1,
         'an': 2,
         'endocrine': 1,
         'inhibition': 1,
         'mammary': 1,
         'represents': 1,
         'a': 3,
         'significant': 2,
         'advance': 1,
         'the': 5,
         'hormone-releasing': 2,
         'hormone': 4,
         '(GHRH)': 2,
         'syndrome,': 2,
         'Sandostatin': 1,
         'is': 1,
         'unequivocally': 1,
         'effective': 1,
         'and,': 1,
         'ectopic': 1,

### Experiments

In [27]:
train_result = rel_ext.train_models(
    splits, 
    featurizers=[start_bag_of_words_featurizer, middle_bag_of_words_featurizer, end_bag_of_words_featurizer],
    sampling_rate=0.5
)

#####################
relation: accelerated 
number positive examples: 1
relation examples:
 [KBTriple(rel='accelerated', sbj='sandostatin', obj='growth')]
-----------------------
number unrelated pairs: 39715
unrelated examples:
 [('buspirone', 'ulcer'), ('international_normalised_ratio', 'acenocoumarol'), ('basal', 'choriocarcinoma'), ('ampicillin', 'bile_duct_stone'), ('cholesteatoma', 'medium'), ('disability', 'acetazolamide'), ('pco2', 'acetazolamide'), ('diazepam', 'ischaemia'), ('cephalosporin', 'enterocolitis'), ('urapidil', 'bradycardia')]
#####################
relation: acquired 
number positive examples: 16
relation examples:
 [KBTriple(rel='acquired', sbj='ranitidine', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='metformin', obj='lipodystrophy'), KBTriple(rel='acquired', sbj='lamivudine', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='nasal', obj='dacryostenosis'), KBTriple(rel='acquired', sbj='lamivudine', obj='lipodystrophy'), KBTriple(rel='acquired', sb

#####################
relation: discoloured 
number positive examples: 103
relation examples:
 [KBTriple(rel='discoloured', sbj='valsartan', obj='faeces'), KBTriple(rel='discoloured', sbj='nasal', obj='faeces'), KBTriple(rel='discoloured', sbj='immunoglobulin', obj='sputum'), KBTriple(rel='discoloured', sbj='penicillin', obj='sputum'), KBTriple(rel='discoloured', sbj='prednisolone', obj='sputum'), KBTriple(rel='discoloured', sbj='nac', obj='sputum'), KBTriple(rel='discoloured', sbj='cholesterol', obj='faeces'), KBTriple(rel='discoloured', sbj='sildenafil', obj='faeces'), KBTriple(rel='discoloured', sbj='nasal', obj='sputum'), KBTriple(rel='discoloured', sbj='cholesterol', obj='sputum')]
-----------------------
number unrelated pairs: 39715
unrelated examples:
 [('buspirone', 'ulcer'), ('international_normalised_ratio', 'acenocoumarol'), ('basal', 'choriocarcinoma'), ('ampicillin', 'bile_duct_stone'), ('cholesteatoma', 'medium'), ('disability', 'acetazolamide'), ('pco2', 'acetazolamide'



Next comes `predict()`. This function takes as arguments a dictionary of data splits, the outputs of `train_models()`, and the name of the split for which to make predictions. It returns two parallel dictionaries: one holding the predictions (grouped by relation), the other holding the true labels (again, grouped by prediction).

In [28]:
predictions, true_labels, predict_setup = rel_ext.predict(
    splits, train_result, split_name='dev',
    sampling_rate=0.5
)

#####################
relation: acquired 
number positive examples: 9
relation examples:
 [KBTriple(rel='acquired', sbj='indinavir', obj='lipodystrophy'), KBTriple(rel='acquired', sbj='sodium', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='efavirenz', obj='lipodystrophy'), KBTriple(rel='acquired', sbj='potassium', obj='dacryostenosis'), KBTriple(rel='acquired', sbj='abacavir', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='potassium', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='abacavir', obj='lipodystrophy'), KBTriple(rel='acquired', sbj='methylprednisolone', obj='spinal_fusion'), KBTriple(rel='acquired', sbj='zidovudine', obj='lipodystrophy')]
-----------------------
number unrelated pairs: 12069
unrelated examples:
 [('dopamine', 'renal_artery_occlusion'), ('nifedipine', 'hiccups'), ('hypotension', 'frusemide'), ('phenytoin', 'diabetic_foot'), ('stress', 'desipramine'), ('discharge', 'methicillin'), ('enzyme_activity', 'calcium-magnesium'), ('bleomycin', 

Now `evaluate_predictions()`. This function takes as arguments the parallel dictionaries of predictions and true labels produced by `predict()`. It prints summary statistics for each relation, including precision, recall, and F<sub>0.5</sub>-score, and it returns the macro-averaged F<sub>0.5</sub>-score.

In [29]:
rel_ext.evaluate_predictions(predictions, true_labels)

relation              precision     recall    f-score    support       size
------------------    ---------  ---------  ---------  ---------  ---------
acquired                  1.000      0.778      0.875          9      12078
aggravated                0.000      0.000      0.000        117      12186
altered                   0.500      0.056      0.100         36      12105
caused                    0.471      0.096      0.159      13953      26022
changed                   0.565      0.464      0.510         28      12097
decreased                 0.636      0.073      0.132       1143      13212
delayed                   1.000      0.429      0.600         14      12083
discoloured               0.118      0.044      0.065         45      12114
impaired                  0.000      0.000      0.000         19      12088
increased                 0.649      0.042      0.079       1194      13263
infected                  0.000      0.000      0.000          6      12075
lowered     

  _warn_prf(average, modifier, msg_start, len(result))


0.2832997106706182

Finally, we introduce `rel_ext.experiment()`, which basically chains together `rel_ext.train_models()`, `rel_ext.predict()`, and `rel_ext.evaluate_predictions()`. For convenience, this function returns the output of `rel_ext.train_models()` as its result.

In [30]:
_ = rel_ext.experiment(
    splits,
    featurizers=[
        start_bag_of_words_featurizer, middle_bag_of_words_featurizer, 
        end_bag_of_words_featurizer],
    train_sampling_rate=0.5,
    test_sampling_rate=0.5
)  # [simple_bag_of_words_featurizer])

#####################
relation: accelerated 
number positive examples: 1
relation examples:
 [KBTriple(rel='accelerated', sbj='sandostatin', obj='growth')]
-----------------------
number unrelated pairs: 39715
unrelated examples:
 [('buspirone', 'ulcer'), ('international_normalised_ratio', 'acenocoumarol'), ('basal', 'choriocarcinoma'), ('ampicillin', 'bile_duct_stone'), ('cholesteatoma', 'medium'), ('disability', 'acetazolamide'), ('pco2', 'acetazolamide'), ('diazepam', 'ischaemia'), ('cephalosporin', 'enterocolitis'), ('urapidil', 'bradycardia')]
#####################
relation: acquired 
number positive examples: 16
relation examples:
 [KBTriple(rel='acquired', sbj='ranitidine', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='metformin', obj='lipodystrophy'), KBTriple(rel='acquired', sbj='lamivudine', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='nasal', obj='dacryostenosis'), KBTriple(rel='acquired', sbj='lamivudine', obj='lipodystrophy'), KBTriple(rel='acquired', sb

#####################
relation: lowered 
number positive examples: 1
relation examples:
 [KBTriple(rel='lowered', sbj='ranitidine', obj='convulsive_threshold')]
-----------------------
number unrelated pairs: 39715
unrelated examples:
 [('buspirone', 'ulcer'), ('international_normalised_ratio', 'acenocoumarol'), ('basal', 'choriocarcinoma'), ('ampicillin', 'bile_duct_stone'), ('cholesteatoma', 'medium'), ('disability', 'acetazolamide'), ('pco2', 'acetazolamide'), ('diazepam', 'ischaemia'), ('cephalosporin', 'enterocolitis'), ('urapidil', 'bradycardia')]
#####################
relation: prolonged 
number positive examples: 71
relation examples:
 [KBTriple(rel='prolonged', sbj='ibuprofen', obj='prothrombin_time'), KBTriple(rel='prolonged', sbj='quetiapine', obj='therapeutic_response'), KBTriple(rel='prolonged', sbj='phenylbutazone', obj='prothrombin_time'), KBTriple(rel='prolonged', sbj='furosemide', obj='therapeutic_response'), KBTriple(rel='prolonged', sbj='atenolol', obj='prothrombin_t



#####################
relation: acquired 
number positive examples: 9
relation examples:
 [KBTriple(rel='acquired', sbj='indinavir', obj='lipodystrophy'), KBTriple(rel='acquired', sbj='sodium', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='efavirenz', obj='lipodystrophy'), KBTriple(rel='acquired', sbj='potassium', obj='dacryostenosis'), KBTriple(rel='acquired', sbj='abacavir', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='potassium', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='abacavir', obj='lipodystrophy'), KBTriple(rel='acquired', sbj='methylprednisolone', obj='spinal_fusion'), KBTriple(rel='acquired', sbj='zidovudine', obj='lipodystrophy')]
-----------------------
number unrelated pairs: 12069
unrelated examples:
 [('dopamine', 'renal_artery_occlusion'), ('nifedipine', 'hiccups'), ('hypotension', 'frusemide'), ('phenytoin', 'diabetic_foot'), ('stress', 'desipramine'), ('discharge', 'methicillin'), ('enzyme_activity', 'calcium-magnesium'), ('bleomycin', 

relation              precision     recall    f-score    support       size
------------------    ---------  ---------  ---------  ---------  ---------
acquired                  1.000      0.778      0.875          9      12078
aggravated                0.000      0.000      0.000        117      12186
altered                   0.500      0.056      0.100         36      12105
caused                    0.471      0.096      0.159      13953      26022
changed                   0.565      0.464      0.510         28      12097
decreased                 0.636      0.073      0.132       1143      13212
delayed                   1.000      0.429      0.600         14      12083
discoloured               0.118      0.044      0.065         45      12114
impaired                  0.000      0.000      0.000         19      12088
increased                 0.649      0.042      0.079       1194      13263
infected                  0.000      0.000      0.000          6      12075
lowered     

## Analysis

### Examining the trained models

One important way to gain understanding of our trained model is to inspect the model weights. What features are strong positive indicators for each relation, and what features are strong negative indicators?

In [31]:
rel_ext.examine_model_weights(train_result)

Highest and lowest feature weights for relation accelerated:

     1.193 growth
     0.957 somatostatin
     0.943 pancreatic
     ..... .....
    -0.412 with
    -0.578 ,
    -1.486 

Highest and lowest feature weights for relation acquired:

     0.910 mucous
     0.688 glucosuria
     0.688 nephrolithiasis,
     ..... .....
    -0.681 to
    -0.899 ,
    -2.613 

Highest and lowest feature weights for relation aggravated:

     0.727 condition
     0.490 this
     0.366 condition.
     ..... .....
    -0.221 and
    -0.237 ,
    -0.285 

Highest and lowest feature weights for relation altered:

     2.555 mood
     1.942 saliva
     1.518 lithium
     ..... .....
    -0.556 patients
    -0.718 ,
    -0.811 

Highest and lowest feature weights for relation caused:

     0.185 calcium
     0.136 furosemide
     0.133 mass
     ..... .....
    -0.153 peptide
    -0.165 basal
    -0.207 concomitant

Highest and lowest feature weights for relation changed:

     2.570 therapeutic
     1.

### Discovering new relation instances

Another way to gain insight into our trained models is to use them to discover new relation instances that don't currently appear in the KB. In fact, this is the whole point of building a relation extraction system: to extend an existing KB (or build a new one) using knowledge extracted from natural language text at scale. Can the models we've trained do this effectively?

Because the goal is to discover new relation instances which are _true_ but _absent from the KB_, we can't evaluate this capability automatically. But we can generate candidate KB triples and manually evaluate them for correctness.

To do this, we'll start from corpus examples containing pairs of entities which do not belong to any relation in the KB (earlier, we described these as "negative examples"). We'll then apply our trained models to each pair of entities, and sort the results by probability assigned by the model, in order to find the most likely new instances for each relation.

In [32]:
rel_ext.find_new_relation_instances(
    dataset,
    featurizers=[simple_bag_of_words_featurizer])

#####################
relation: accelerated 
number positive examples: 1
relation examples:
 [KBTriple(rel='accelerated', sbj='sandostatin', obj='growth')]
-----------------------
number unrelated pairs: 7943
unrelated examples:
 [('buspirone', 'ulcer'), ('international_normalised_ratio', 'acenocoumarol'), ('basal', 'choriocarcinoma'), ('ampicillin', 'bile_duct_stone'), ('cholesteatoma', 'medium'), ('disability', 'acetazolamide'), ('pco2', 'acetazolamide'), ('diazepam', 'ischaemia'), ('cephalosporin', 'enterocolitis'), ('urapidil', 'bradycardia')]
#####################
relation: acquired 
number positive examples: 16
relation examples:
 [KBTriple(rel='acquired', sbj='ranitidine', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='metformin', obj='lipodystrophy'), KBTriple(rel='acquired', sbj='lamivudine', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='nasal', obj='dacryostenosis'), KBTriple(rel='acquired', sbj='lamivudine', obj='lipodystrophy'), KBTriple(rel='acquired', sbj

#####################
relation: acquired 
number positive examples: 9
relation examples:
 [KBTriple(rel='acquired', sbj='indinavir', obj='lipodystrophy'), KBTriple(rel='acquired', sbj='sodium', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='efavirenz', obj='lipodystrophy'), KBTriple(rel='acquired', sbj='potassium', obj='dacryostenosis'), KBTriple(rel='acquired', sbj='abacavir', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='potassium', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='abacavir', obj='lipodystrophy'), KBTriple(rel='acquired', sbj='methylprednisolone', obj='spinal_fusion'), KBTriple(rel='acquired', sbj='zidovudine', obj='lipodystrophy')]
-----------------------
number unrelated pairs: 24138
unrelated examples:
 [('dopamine', 'renal_artery_occlusion'), ('nifedipine', 'hiccups'), ('hypotension', 'frusemide'), ('phenytoin', 'diabetic_foot'), ('stress', 'desipramine'), ('discharge', 'methicillin'), ('enzyme_activity', 'calcium-magnesium'), ('bleomycin', 

#####################
relation: ruptured 
number positive examples: 8
relation examples:
 [KBTriple(rel='ruptured', sbj='methotrexate', obj='ovarian_cyst'), KBTriple(rel='ruptured', sbj='sodium', obj='ovarian_cyst'), KBTriple(rel='ruptured', sbj='methylprednisolone', obj='aneurysm'), KBTriple(rel='ruptured', sbj='methotrexate', obj='aneurysm'), KBTriple(rel='ruptured', sbj='potassium', obj='liver_carcinoma'), KBTriple(rel='ruptured', sbj='rosuvastatin', obj='aneurysm'), KBTriple(rel='ruptured', sbj='sodium', obj='aneurysm'), KBTriple(rel='ruptured', sbj='heparin', obj='aneurysm')]
-----------------------
number unrelated pairs: 24138
unrelated examples:
 [('dopamine', 'renal_artery_occlusion'), ('nifedipine', 'hiccups'), ('hypotension', 'frusemide'), ('phenytoin', 'diabetic_foot'), ('stress', 'desipramine'), ('discharge', 'methicillin'), ('enzyme_activity', 'calcium-magnesium'), ('bleomycin', 'resuscitation'), ('polyuria', 'enflurane'), ('alendronate', 'fasting')]
#####################

     0.250 KBTriple(rel='infected', sbj='technetium', obj='chronic_hepatitis')
     0.250 KBTriple(rel='infected', sbj='chronic_hepatitis', obj='technetium')
     0.230 KBTriple(rel='infected', sbj='technetium', obj='hepatitis')
     0.230 KBTriple(rel='infected', sbj='hepatitis', obj='technetium')
     0.121 KBTriple(rel='infected', sbj='lithiasis', obj='technetium')
     0.121 KBTriple(rel='infected', sbj='technetium', obj='lithiasis')
     0.077 KBTriple(rel='infected', sbj='sulindac', obj='urine_output')
     0.077 KBTriple(rel='infected', sbj='urine_output', obj='sulindac')
     0.072 KBTriple(rel='infected', sbj='aspiration', obj='oxytetracycline')
     0.072 KBTriple(rel='infected', sbj='oxytetracycline', obj='aspiration')

Highest probability examples for relation lowered:

     0.045 KBTriple(rel='lowered', sbj='tracheal_injury', obj='submucosal')
     0.045 KBTriple(rel='lowered', sbj='submucosal', obj='tracheal_injury')
     0.040 KBTriple(rel='lowered', sbj='autopsy', obj='

### Evaluating a random-guessing strategy

In order to validate our evaluation framework, and to set a floor under expected results for future evaluations, let's implement and evaluate a random-guessing strategy. The random guesser is a classifier which completely ignores its input, and simply flips a coin.

In [33]:
def lift(f):
    return lambda xs: [f(x) for x in xs]

def make_random_classifier(p=0.50):
    def random_classify(kb_triple):
        return random.random() < p
    return lift(random_classify)

In [34]:
rel_ext.evaluate(splits, make_random_classifier(), sampling_rate=0.5)

#####################
relation: acquired 
number positive examples: 9
relation examples:
 [KBTriple(rel='acquired', sbj='indinavir', obj='lipodystrophy'), KBTriple(rel='acquired', sbj='sodium', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='efavirenz', obj='lipodystrophy'), KBTriple(rel='acquired', sbj='potassium', obj='dacryostenosis'), KBTriple(rel='acquired', sbj='abacavir', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='potassium', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='abacavir', obj='lipodystrophy'), KBTriple(rel='acquired', sbj='methylprednisolone', obj='spinal_fusion'), KBTriple(rel='acquired', sbj='zidovudine', obj='lipodystrophy')]
-----------------------
number unrelated pairs: 12069
unrelated examples:
 [('dopamine', 'renal_artery_occlusion'), ('nifedipine', 'hiccups'), ('hypotension', 'frusemide'), ('phenytoin', 'diabetic_foot'), ('stress', 'desipramine'), ('discharge', 'methicillin'), ('enzyme_activity', 'calcium-magnesium'), ('bleomycin', 

#####################
relation: shortened 
number positive examples: 41
relation examples:
 [KBTriple(rel='shortened', sbj='methylprednisolone', obj='prothrombin_time'), KBTriple(rel='shortened', sbj='sodium', obj='coagulation_time'), KBTriple(rel='shortened', sbj='sirolimus', obj='therapeutic_response'), KBTriple(rel='shortened', sbj='clonazepam', obj='therapeutic_response'), KBTriple(rel='shortened', sbj='metronidazole', obj='therapeutic_response'), KBTriple(rel='shortened', sbj='octreotide', obj='therapeutic_response'), KBTriple(rel='shortened', sbj='potassium', obj='coagulation_time'), KBTriple(rel='shortened', sbj='sodium', obj='prothrombin_time'), KBTriple(rel='shortened', sbj='alendronate', obj='therapeutic_response'), KBTriple(rel='shortened', sbj='rosuvastatin', obj='therapeutic_response')]
-----------------------
number unrelated pairs: 12069
unrelated examples:
 [('dopamine', 'renal_artery_occlusion'), ('nifedipine', 'hiccups'), ('hypotension', 'frusemide'), ('phenytoin', 'd

0.048778928430296374

## A simple baseline model

It shouldn't be too hard to do better than random guessing. But for now, let's aim low — let's use the data we have in the easiest and most obvious way, and see how far that gets us.

We start from the intuition that the words between two entity mentions frequently tell us how they're related. For example, in the phrase "SpaceX was founded by Elon Musk", the words "was founded by" indicate that the `founders` relation holds between the first entity mentioned and the second. Likewise, in the phrase "Elon Musk established SpaceX", the word "established" indicates the `founders` relation holds between the second entity mentioned and the first.

So let's write some code to find the most common phrases that appear between the two entity mentions for each relation. As the examples illustrate, we need to make sure to consider both directions: that is, where the subject of the relation appears as the first mention, and where it appears as the second.

In [35]:
def find_common_middles(split, top_k=5, show_output=False):
    corpus = split.corpus
    kb = split.kb
    mids_by_rel = {
        'fwd': defaultdict(lambda: defaultdict(int)),
        'rev': defaultdict(lambda: defaultdict(int))}
    for rel in kb.all_relations:
        for kbt in kb.get_triples_for_relation(rel):
            for ex in corpus.get_examples_for_entities(kbt.sbj, kbt.obj):
                mids_by_rel['fwd'][rel][ex.middle] += 1
            for ex in corpus.get_examples_for_entities(kbt.obj, kbt.sbj):
                mids_by_rel['rev'][rel][ex.middle] += 1
    def most_frequent(mid_counter):
        return sorted([(cnt, mid) for mid, cnt in mid_counter.items()], reverse=True)[:top_k]
    for rel in kb.all_relations:
        for dir in ['fwd', 'rev']:
            top = most_frequent(mids_by_rel[dir][rel])
            if show_output:
                for cnt, mid in top:
                    print('{:20s} {:5s} {:10d} {:s}'.format(rel, dir, cnt, mid))
            mids_by_rel[dir][rel] = set([mid for cnt, mid in top])
    return mids_by_rel

_ = find_common_middles(splits['train'], show_output=False)

In [36]:
def train_top_k_middles_classifier(top_k=100):
    split = splits['train']
    corpus = split.corpus
    top_k_mids_by_rel = find_common_middles(split=split, top_k=top_k)
    def classify(kb_triple):
        fwd_mids = top_k_mids_by_rel['fwd'][kb_triple.rel]
        rev_mids = top_k_mids_by_rel['rev'][kb_triple.rel]
        for ex in corpus.get_examples_for_entities(kb_triple.sbj, kb_triple.obj):
            if ex.middle in fwd_mids:
                return True
        for ex in corpus.get_examples_for_entities(kb_triple.obj, kb_triple.sbj):
            if ex.middle in rev_mids:
                return True
        return False
    return lift(classify)

In [37]:
rel_ext.evaluate(splits, train_top_k_middles_classifier())

#####################
relation: acquired 
number positive examples: 9
relation examples:
 [KBTriple(rel='acquired', sbj='indinavir', obj='lipodystrophy'), KBTriple(rel='acquired', sbj='sodium', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='efavirenz', obj='lipodystrophy'), KBTriple(rel='acquired', sbj='potassium', obj='dacryostenosis'), KBTriple(rel='acquired', sbj='abacavir', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='potassium', obj='fanconi_syndrome'), KBTriple(rel='acquired', sbj='abacavir', obj='lipodystrophy'), KBTriple(rel='acquired', sbj='methylprednisolone', obj='spinal_fusion'), KBTriple(rel='acquired', sbj='zidovudine', obj='lipodystrophy')]
-----------------------
number unrelated pairs: 2413
unrelated examples:
 [('dopamine', 'renal_artery_occlusion'), ('nifedipine', 'hiccups'), ('hypotension', 'frusemide'), ('phenytoin', 'diabetic_foot'), ('stress', 'desipramine'), ('discharge', 'methicillin'), ('enzyme_activity', 'calcium-magnesium'), ('bleomycin', '

0.0