In [1]:
import os
import random

from collections import Counter
from collections import defaultdict

import dist_sup_lib.rel_ext as rel_ext
import dist_sup_lib.utils as utils

from src.rel_extract_extend.experiment import get_new_relation_instances

from src.utils import read_json_examples
from src.utils import read_kb_triples
from src.utils import read_kb_triples_json

In [2]:
# Set all the random seeds for reproducibility. Only the
# system seed is relevant for this notebook.

utils.fix_random_seeds()

In [3]:
rel_ext_data_sents = os.path.join('data', 'featurized_sentences')
rel_ext_data_kb = os.path.join("data", "knowledge_base")

In [4]:
example_data = []

In [5]:
for index in range(1, 200):
    s_i = str(index)
    zeros = "0" * (4 - len(s_i))
    # Files updated constantly and the names changes by year
    # for files downloaded in 2021 file_naming has to be changed to
    # featurized_sents_pubmed21n
    tagged_sent_file = f"featurized_sents_pubmed20n{zeros + s_i}.json"
    file_path = os.path.join(rel_ext_data_sents, tagged_sent_file)
    example_data.extend(read_json_examples(file_path))

In [6]:
corpus = rel_ext.Corpus(example_data)

In [7]:
kb_triples = read_kb_triples_json(os.path.join(rel_ext_data_kb, "rel_drug_react_triple_occ_all.json"))


In [8]:
random.seed(100)
num_list = random.sample(range(0, len(kb_triples)), len(kb_triples))
split_n = int(len(kb_triples)/4) * 3

# kb_train_idxs = set(kb_triples[:split_n])
kb_eval_idxs = set(num_list[split_n:])

kb_train = []
kb_eval = []

for i, kb_triple in enumerate(kb_triples):
    if i in kb_eval_idxs:
        kb_eval.append(kb_triple)
    else:
        kb_train.append(kb_triple)
print(len(kb_train))
print(len(kb_eval))

41298
13769


In [9]:
kb = rel_ext.KB(kb_train)

In [10]:
dataset = rel_ext.Dataset(corpus, kb)

In [11]:
splits = dataset.build_splits()

splits

{'tiny': Corpus with 4,007 examples; KB with 745 triples,
 'train': Corpus with 669,482 examples; KB with 28,380 triples,
 'dev': Corpus with 210,732 examples; KB with 12,173 triples,
 'all': Corpus with 884,221 examples; KB with 41,298 triples}

### Negative instances

By joining the corpus to the KB, we can obtain abundant positive instances for each relation. But a classifier cannot be trained on positive instances alone. In order to apply the distant supervision paradigm, we will also need some negative instances — that is, entity pairs which do not belong to any known relation. If you like, you can think of these entity pairs as being assigned to a special relation called `NO_RELATION`. We can find plenty of such pairs by searching for examples in the corpus which contain two entities which do not belong to any relation in the KB.

In [12]:
dataset.find_unrelated_pairs()

{('dextromethorphan', 'poisoning'),
 ('cardiac_output', 'diazepam'),
 ('metronidazole', 'superinfection'),
 ('surgery', 'lignocaine'),
 ('obesity', 'beta-carotene'),
 ('gammopathy', 'concomitant'),
 ('injection', 'concomitant'),
 ('phagocytosis', 'mycelium'),
 ('brain_death', 'methotrexate'),
 ('pain', 'ornidazole'),
 ('dopamine', 'neurogenic_shock'),
 ('arachnoid_cyst', 'submucosal'),
 ('ear_infection', 'tetracycline'),
 ('sodium', 'vitamin_b1'),
 ('nortriptyline', 'paresis'),
 ('endoscopy', 'nasal'),
 ('vitamin_d', 'cardiac'),
 ('renal_failure', 'urapidil'),
 ('adverse_event', 'lovastatin'),
 ('intranasal', 'q_fever'),
 ('visual_analogue_scale', 'concomitant'),
 ('blood_gases', 'alfentanil'),
 ('propranolol', 'radioactive_iodine_therapy'),
 ('portal_hypertension', 'prednisolone'),
 ('angina_pectoris', 'diphenylhydantoin'),
 ('glucose_tolerance', 'medium'),
 ('hepatitis', 'corynebacterium'),
 ('immunoglobulin', 'colon_cancer'),
 ('magnetic_resonance_imaging', 'tamoxifen'),
 ('polyhydr

Let's determine how many examples we have for each triple in the KB. We'll compute averages per relation.

In [13]:
dataset.count_examples()

                                             examples
relation               examples    triples    /triple
--------               --------    -------    -------
acquired                     20         17       1.18
aggravated                 1392        322       4.32
altered                     894         97       9.22
caused                   170646      34601       4.93
changed                     101         70       1.44
decreased                 44333       2820      15.72
delayed                      25         30       0.83
discoloured                 395        114       3.46
impaired                    321         71       4.52
improved                      0          1       0.00
increased                 48708       2846      17.11
infected                    155         23       6.74
lowered                       1          2       0.50
prolonged                   230         85       2.71
reduced                      82         63       1.30
ruptured                    

In [14]:
for key, val in dataset.kb.kb_triples_by_relation.items():
    print(key, len(val))

acquired 17
aggravated 322
altered 97
caused 34601
changed 70
decreased 2820
delayed 30
discoloured 114
impaired 71
improved 1
increased 2846
infected 23
lowered 2
prolonged 85
reduced 63
ruptured 14
shortened 122


In [15]:
for key, val in dataset.kb.kb_triples_by_relation.items():
    print(key, len(val))

acquired 17
aggravated 322
altered 97
caused 34601
changed 70
decreased 2820
delayed 30
discoloured 114
impaired 71
improved 1
increased 2846
infected 23
lowered 2
prolonged 85
reduced 63
ruptured 14
shortened 122


In [16]:
print(len(dataset.corpus.examples_by_entities))

4454


In [18]:
dataset.count_relation_combinations()

The most common relation combinations are:
       995 ('decreased', 'increased')
        59 ('caused', 'decreased', 'increased')
        42 ('caused', 'increased')
        42 ('caused', 'decreased')
        27 ('discoloured', 'increased')
        22 ('decreased', 'shortened')
        20 ('prolonged', 'shortened')
        20 ('changed', 'decreased', 'shortened')
        17 ('caused', 'infected')
        11 ('changed', 'decreased')
         9 ('caused', 'ruptured')
         8 ('decreased', 'delayed', 'shortened')
         8 ('changed', 'decreased', 'increased')
         6 ('acquired', 'caused')
         4 ('changed', 'shortened')
         3 ('decreased', 'increased', 'shortened')
         2 ('changed', 'increased')
         2 ('changed', 'decreased', 'increased', 'shortened')
         2 ('changed', 'decreased', 'delayed', 'increased', 'prolonged', 'shortened')
         1 ('increased', 'prolonged', 'shortened')
         1 ('decreased', 'prolonged')
         1 ('decreased', 'increased', 'p

### Featurizers

Featurizers are functions which define the feature representation for our model. The primary input to a featurizer will be the `KBTriple` for which we are generating features. But since our features will be derived from corpus examples containing the entities of the `KBTriple`, we must also pass in a reference to a `Corpus`. And in order to make it easy to combine different featurizers, we'll also pass in a feature counter to hold the results.

Here's an implementation for a very simple bag-of-words featurizer. It finds all the corpus examples containing the two entities in the `KBTriple`, breaks the phrase appearing between the two entity mentions into words, and counts the words. Note that it makes no distinction between "forward" and "reverse" examples.


In [19]:
def simple_bag_of_words_featurizer(kbt, corpus, feature_counter):
    for ex in corpus.get_examples_for_entities(kbt.sbj, kbt.obj):
        for word in ex.middle.split(' '):
            feature_counter[word] += 1
    for ex in corpus.get_examples_for_entities(kbt.obj, kbt.sbj):
        for word in ex.middle.split(' '):
            feature_counter[word] += 1
    return feature_counter

def count_words(sent_part: str, feature_counter: Counter):
    for word in sent_part.split(" "):
        feature_counter[word] += 1

def middle_bag_of_words_featurizer(kbt, corpus, feature_counter):
    for ex in corpus.get_examples_for_entities(kbt.sbj, kbt.obj):
        count_words(ex.middle, feature_counter)
    for ex in corpus.get_examples_for_entities(kbt.obj, kbt.sbj):
        count_words(ex.middle, feature_counter)
    return feature_counter

def start_bag_of_words_featurizer(kbt, corpus, feature_counter):
    for ex in corpus.get_examples_for_entities(kbt.sbj, kbt.obj):
        count_words(ex.left, feature_counter)
    for ex in corpus.get_examples_for_entities(kbt.obj, kbt.sbj):
        count_words(ex.left, feature_counter)
    return feature_counter

def end_bag_of_words_featurizer(kbt, corpus, feature_counter):
    for ex in corpus.get_examples_for_entities(kbt.sbj, kbt.obj):
        count_words(ex.right, feature_counter)
    for ex in corpus.get_examples_for_entities(kbt.obj, kbt.sbj):
        count_words(ex.right, feature_counter)
    return feature_counter
    

In [20]:
kbt = kb.kb_triples[5]

kbt

KBTriple(rel='acquired', sbj='atazanavir', obj='lipodystrophy')

In [21]:
print(len(corpus.examples_by_entities))

4454


In [22]:
print(len(corpus.examples_by_entities["sodium"]))

617


In [23]:
for i in range(0, len(kb.kb_triples)):
    res = simple_bag_of_words_featurizer(kb.kb_triples[i], corpus, Counter())
    if res:
        print(res)

Counter({'': 17, 'the': 12, 'in': 10, 'and': 8, 'of': 7, 'with': 6, 'by': 4, 'renal': 4, 'a': 4, 'patients': 3, 'seven': 2, 'isolated': 2, 'injection': 2, 'induced': 2, 'dog,': 2, 'we': 2, 'membrane': 2, 'to': 2, 'minutes': 2, 'after': 2, 'increased': 2, 'all': 2, 'hyperenzymurias': 2, 'rapidly,': 2, 'progressively,': 2, 'maleate': 1, '(200-400': 1, 'mg/kg)': 1, 'into': 1, 'rats': 1, 'produces': 1, 'aminoaciduria': 1, 'along': 1, 'glycosuria': 1, 'phosphaturia,': 1, 'resembling': 1, 'reabsorption': 1, 'proximal': 1, 'distal': 1, 'tubule': 1, 'was': 1, 'evaluated': 1, 'different': 1, 'methods': 1, 'healthy': 1, 'subjects,': 1, 'recurrent': 1, 'calcium': 1, 'nephrolithiasis,': 1, 'five': 1, 'glucosuria': 1, 'three': 1, ',': 1, 'decline': 1, 'intestinal': 1, 'mucosa': 1, '(Na+-K+)-ATPase': 1, 'simultaneous': 1, 'decrease': 1, 'produced': 1, 'single': 1, 'Basenji': 1, 'have': 1, 'used': 1, 'brush': 1, 'border': 1, 'vesicles': 1, 'examine': 1, 'two': 1, 'factors': 1, 'that': 1, 'influence':

Counter({'': 63, 'ventilation': 19, 'in': 17, 'with': 14, 'of': 12, 'pressure': 11, 'the': 11, 'and': 10, 'positive': 9, 'a': 8, 'is': 5, 'mask': 4, 'support': 4, 'acute': 4, 'chronic': 4, 'ventilatory': 4, 'patients': 4, 'intermittent': 4, 'been': 4, 'to': 4, 'it': 4, 'mechanical': 4, 'by': 4, 'treatment': 3, 'for': 3, 'has': 3, 'used': 3, 'airway': 3, 'nocturnal': 3, 'more': 3, 'oxygen': 3, 'was': 3, 'treated': 3, 'hypercapnic': 2, 'effective': 2, 'successfully': 2, 'restrictive': 2, 'as': 2, 'being': 2, 'intubation': 2, '60': 2, 'which': 2, 'be': 2, 'presence': 2, 'cardiac': 2, ',': 2, 'system': 2, 'via': 2, 'failure': 2, 'than': 2, 'were': 2, 'managing': 1, 'exacerbation': 1, 'safe': 1, 'modality': 1, 'some': 1, '(nIPPV)': 1, 'management': 1, '(NMV)': 1, 'custom': 1, 'molded': 1, '(NIPPV-C': 1, 'Mclermott,': 1, '1989),': 1, 'well': 1, 'NV': 1, 'tracheostomy': 1, '(TIPPV)': 1, 'male': 1, 'patient': 1, 'limb-girdle': 1, 'muscular': 1, 'dystrophy': 1, 'who': 1, 'had': 1, 'developed': 

Counter({'': 34, 'and': 8, 'of': 6, 'by': 6, 'in': 4, 'induced': 3, 'the': 3, 'caused': 3, 'was': 2, 'to': 2, 'erythrocytes': 2, 'a': 2, 'determined': 2, 'for': 2, 'influx': 2, 'with': 2, 'mM': 2, 'or': 2, 'penicillin': 1, 'conjugated': 1, 'sheep': 1, 'optimal': 1, 'quantities,': 1, 'added': 1, '5%': 1, 'SRBC': 1, 'suspension,': 1, 'were': 1, 'haemagglutination': 1, '(12-5': 1, 'mg/ml)': 1, '(S)': 1, 'salts': 1, 'chenodeoxycholate': 1, '(CDC),': 1, 'deoxycholate': 1, '(DC)': 1, 'cholate': 1, '(C)': 1, 'their': 1, 'glycine': 1, '(G)': 1, 'taurine': 1, '(T)': 1, 'conjugates)': 1, 'LPC': 1, 'using': 1, 'erythrocyte': 1, 'increased': 1, 'eight': 1, 'nine': 1, 'HE': 1, 'patients': 1, 'overt': 1, 'various': 1, 'mammalian': 1, '2': 1, '1.4': 1, 'decomposition': 1, 'samples,': 1, 'common': 1, 'anticoagulants': 1, 'absorption': 1, 'irrigating': 1, 'water': 1, 'such': 1, 'as': 1, 'normal': 1, 'ammonia': 1, 'an': 1, 'isotonic': 1, 'NH4Cl': 1, 'medium,': 1, 'but': 1, 'protected': 1, 'cells': 1, 'f

Counter({'': 5, 'in': 5, 'the': 2, 'of': 2, 'patients': 2, 'and': 2, 'treatment': 1, "Raynaud's": 1, 'phenomenon': 1, 'was': 1, 'assessed': 1, 'a': 1, 'prospective': 1, 'double-blind': 1, 'randomised': 1, 'cross-over': 1, 'trial': 1, '16': 1, '(7': 1, 'progressive': 1, 'systemic': 1, 'sclerosis,': 1, '2': 1, 'intensive': 1, 'care': 1, 'patients,': 1, 'AAG': 1, 'concentrations': 1, 'are': 1, 'increased,': 1, 'with': 1, 'myocardial': 1, 'infarction': 1, 'an': 1, 'increased': 1, 'binding': 1})
Counter({'': 7})
Counter({'': 10, 'and': 3, 'was': 2, 'treatment': 1, 'also': 1, 'markedly': 1, 'reduced': 1, 'the': 1, 'incidence': 1, 'magnitude': 1, 'of': 1, 'proteinuria': 1, 'prevented': 1, 'due': 1, 'to': 1, 'pump': 1, 'failure': 1, 'found': 1, 'considerably': 1, 'less': 1, 'often': 1, 'on': 1, 'rate': 1, 'similar': 1, 'in': 1, 'both': 1, 'drug': 1, 'groups,': 1, 'at': 1, '17.5': 1, '(enalapril)': 1, '24.0': 1, '(': 1})
Counter({'': 5, 'pantothenic': 2, 'acid': 2, 'antagonist,': 2, 'induced': 

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



### Experiments

In [24]:
train_result = rel_ext.train_models(
    splits, 
    featurizers=[start_bag_of_words_featurizer, middle_bag_of_words_featurizer, end_bag_of_words_featurizer],
    split_name="all"
)



Next comes `predict()`. This function takes as arguments a dictionary of data splits, the outputs of `train_models()`, and the name of the split for which to make predictions. It returns two parallel dictionaries: one holding the predictions (grouped by relation), the other holding the true labels (again, grouped by prediction).

## Analysis

### Examining the trained models

One important way to gain understanding of our trained model is to inspect the model weights. What features are strong positive indicators for each relation, and what features are strong negative indicators?

In [25]:
rel_ext.examine_model_weights(train_result)

Highest and lowest feature weights for relation acquired:

     1.462 renal
     1.046 transport
     0.918 Fanconi
     ..... .....
    -0.478 (
    -0.807 ,
    -2.232 

Highest and lowest feature weights for relation aggravated:

     3.993 condition
     3.640 condition.
     1.834 this
     ..... .....
    -0.725 (
    -0.762 )
    -0.896 

Highest and lowest feature weights for relation altered:

     2.547 mood
     1.938 saliva
     1.749 lithium
     ..... .....
    -0.509 patients
    -0.759 ,
    -0.949 

Highest and lowest feature weights for relation caused:

     0.301 mass
     0.245 calcium
     0.234 pain
     ..... .....
    -0.274 somatostatin
    -0.308 endocrine
    -0.443 concomitant

Highest and lowest feature weights for relation changed:

     2.479 therapeutic
     1.399 concentration
     1.357 response.
     ..... .....
    -0.850 .
    -1.113 ,
    -1.359 

Highest and lowest feature weights for relation decreased:

     0.539 immunoglobulins
     0.466 mol

### Discovering new relation instances

Another way to gain insight into our trained models is to use them to discover new relation instances that don't currently appear in the KB. In fact, this is the whole point of building a relation extraction system: to extend an existing KB (or build a new one) using knowledge extracted from natural language text at scale. Can the models we've trained do this effectively?

Because the goal is to discover new relation instances which are _true_ but _absent from the KB_, we can't evaluate this capability automatically. But we can generate candidate KB triples and manually evaluate them for correctness.

To do this, we'll start from corpus examples containing pairs of entities which do not belong to any relation in the KB (earlier, we described these as "negative examples"). We'll then apply our trained models to each pair of entities, and sort the results by probability assigned by the model, in order to find the most likely new instances for each relation.

# Create and Evaluate New Instances

In 3 Steps we evaluate the top 10, top 50 and top 100

**Attention**: k is set to the double value, because it creates instances for two directions – sbj-obj and obj.sbj

In [44]:
new_instances = get_new_relation_instances(
    dataset,
    featurizers=[simple_bag_of_words_featurizer],
    k=20
)

eval_relations = defaultdict(set)

for kb_triple in kb_eval:
    eval_relations[kb_triple.rel].add((kb_triple.rel, kb_triple.sbj, kb_triple.obj))
    
suggestions = defaultdict(set)

for rel, rel_inst in new_instances.items():
    for instance in rel_inst:
        suggestions[rel].add((instance[1].rel, instance[1].sbj, instance[1].obj))

train_x_rel = None
train_y_rel = None
Highest probability examples for relation acquired:

     0.191 KBTriple(rel='acquired', sbj='corynebacterium', obj='faeces')
     0.191 KBTriple(rel='acquired', sbj='faeces', obj='corynebacterium')
     0.163 KBTriple(rel='acquired', sbj='tac', obj='toxoplasmosis')
     0.163 KBTriple(rel='acquired', sbj='toxoplasmosis', obj='tac')
     0.092 KBTriple(rel='acquired', sbj='microangiopathy', obj='somatostatin')
     0.092 KBTriple(rel='acquired', sbj='somatostatin', obj='microangiopathy')
     0.070 KBTriple(rel='acquired', sbj='tac', obj='progesterone')
     0.070 KBTriple(rel='acquired', sbj='progesterone', obj='tac')
     0.062 KBTriple(rel='acquired', sbj='stress', obj='ondansetron')
     0.062 KBTriple(rel='acquired', sbj='ondansetron', obj='stress')
     0.058 KBTriple(rel='acquired', sbj='econazole', obj='haemoglobin')
     0.058 KBTriple(rel='acquired', sbj='haemoglobin', obj='econazole')
     0.057 KBTriple(rel='acquired', sbj='ammonium', o

     0.895 KBTriple(rel='discoloured', sbj='diazepam', obj='faeces')
     0.895 KBTriple(rel='discoloured', sbj='faeces', obj='diazepam')
     0.810 KBTriple(rel='discoloured', sbj='imipenem', obj='obstruction')
     0.810 KBTriple(rel='discoloured', sbj='obstruction', obj='imipenem')
     0.705 KBTriple(rel='discoloured', sbj='diclofenac', obj='urine_flow')
     0.705 KBTriple(rel='discoloured', sbj='urine_flow', obj='diclofenac')
     0.667 KBTriple(rel='discoloured', sbj='urine_flow', obj='ammonium')
     0.667 KBTriple(rel='discoloured', sbj='ammonium', obj='urine_flow')
     0.556 KBTriple(rel='discoloured', sbj='leprosy', obj='tac')
     0.556 KBTriple(rel='discoloured', sbj='tac', obj='leprosy')
     0.434 KBTriple(rel='discoloured', sbj='sodium', obj='urine_calcium')
     0.434 KBTriple(rel='discoloured', sbj='urine_calcium', obj='sodium')
     0.343 KBTriple(rel='discoloured', sbj='phenylbutazone', obj='saliva')
     0.343 KBTriple(rel='discoloured', sbj='saliva', obj='phenylb

In [43]:
print('{:15s} {:>15s} {:>15s} {:>15s}'.format(
    '', 'examples', 'examples', 'examples'))
print('{:15s} {:>15s} {:>15s} {:>15s}'.format(
    'relation', 'eval set', 'suggestions', 'intersection'))
print('{:15s} {:>15s} {:>15s} {:>15s}'.format(
    '--------', '--------', '-------', '-------'))

for rel, instances in eval_relations.items():
    print('{:15s} {:>15s} {:>15s} {:>15s}'.format(
        rel, str(len(instances)), 
        str(int(len(suggestions[rel])/2)), 
        str(len(instances.intersection(suggestions[rel])))))

                       examples        examples        examples
relation               eval set     suggestions    intersection
--------               --------         -------         -------
accelerated                   1               0               0
acquired                      9              10               0
aggravated                  114              10               0
altered                      29              10               0
caused                    11426              10               6
changed                      25              10               0
decreased                   962              10               1
delayed                      24              10               0
discoloured                  40              10               1
impaired                     16              10               0
increased                  1031              10               1
infected                      9              10               0
prolonged                    30         

In [46]:
new_instances = get_new_relation_instances(
    dataset,
    featurizers=[simple_bag_of_words_featurizer],
    k=100
)

eval_relations = defaultdict(set)

for kb_triple in kb_eval:
    eval_relations[kb_triple.rel].add((kb_triple.rel, kb_triple.sbj, kb_triple.obj))
    
suggestions = defaultdict(set)

for rel, rel_inst in new_instances.items():
    for instance in rel_inst:
        suggestions[rel].add((instance[1].rel, instance[1].sbj, instance[1].obj))

train_x_rel = None
train_y_rel = None
Highest probability examples for relation acquired:

     0.191 KBTriple(rel='acquired', sbj='corynebacterium', obj='faeces')
     0.191 KBTriple(rel='acquired', sbj='faeces', obj='corynebacterium')
     0.163 KBTriple(rel='acquired', sbj='tac', obj='toxoplasmosis')
     0.163 KBTriple(rel='acquired', sbj='toxoplasmosis', obj='tac')
     0.092 KBTriple(rel='acquired', sbj='microangiopathy', obj='somatostatin')
     0.092 KBTriple(rel='acquired', sbj='somatostatin', obj='microangiopathy')
     0.070 KBTriple(rel='acquired', sbj='tac', obj='progesterone')
     0.070 KBTriple(rel='acquired', sbj='progesterone', obj='tac')
     0.062 KBTriple(rel='acquired', sbj='stress', obj='ondansetron')
     0.062 KBTriple(rel='acquired', sbj='ondansetron', obj='stress')
     0.058 KBTriple(rel='acquired', sbj='econazole', obj='haemoglobin')
     0.058 KBTriple(rel='acquired', sbj='haemoglobin', obj='econazole')
     0.057 KBTriple(rel='acquired', sbj='ammonium', o

     1.000 KBTriple(rel='caused', sbj='fall', obj='pindolol')
     1.000 KBTriple(rel='caused', sbj='pindolol', obj='fall')
     1.000 KBTriple(rel='caused', sbj='pindolol', obj='blood_pressure')
     1.000 KBTriple(rel='caused', sbj='blood_pressure', obj='pindolol')
     1.000 KBTriple(rel='caused', sbj='enzyme_activity', obj='pindolol')
     1.000 KBTriple(rel='caused', sbj='pindolol', obj='enzyme_activity')
     1.000 KBTriple(rel='caused', sbj='tachycardia', obj='propranolol')
     1.000 KBTriple(rel='caused', sbj='propranolol', obj='tachycardia')
     1.000 KBTriple(rel='caused', sbj='magnesium', obj='infusion')
     1.000 KBTriple(rel='caused', sbj='infusion', obj='magnesium')
     1.000 KBTriple(rel='caused', sbj='sodium', obj='hypercalciuria')
     1.000 KBTriple(rel='caused', sbj='hypercalciuria', obj='sodium')
     1.000 KBTriple(rel='caused', sbj='headache', obj='sumatriptan')
     1.000 KBTriple(rel='caused', sbj='sumatriptan', obj='headache')
     1.000 KBTriple(rel='cause

     0.154 KBTriple(rel='delayed', sbj='gamma-globulin', obj='cholangitis')
     0.154 KBTriple(rel='delayed', sbj='cholangitis', obj='gamma-globulin')
     0.133 KBTriple(rel='delayed', sbj='solar_urticaria', obj='cimetidine')
     0.133 KBTriple(rel='delayed', sbj='cimetidine', obj='solar_urticaria')
     0.107 KBTriple(rel='delayed', sbj='anaplastic_astrocytoma', obj='methylprednisolone')
     0.107 KBTriple(rel='delayed', sbj='methylprednisolone', obj='astrocytoma')
     0.107 KBTriple(rel='delayed', sbj='methylprednisolone', obj='anaplastic_astrocytoma')
     0.107 KBTriple(rel='delayed', sbj='astrocytoma', obj='methylprednisolone')
     0.091 KBTriple(rel='delayed', sbj='ejection_fraction', obj='gallopamil')
     0.091 KBTriple(rel='delayed', sbj='gallopamil', obj='ejection_fraction')
     0.087 KBTriple(rel='delayed', sbj='azatadine', obj='rhinitis')
     0.087 KBTriple(rel='delayed', sbj='rhinitis', obj='azatadine')
     0.083 KBTriple(rel='delayed', sbj='anoxia', obj='nadolol'

     0.698 KBTriple(rel='increased', sbj='oxygen_consumption', obj='diltiazem')
     0.696 KBTriple(rel='increased', sbj='renal_hypertrophy', obj='polypeptide')
     0.696 KBTriple(rel='increased', sbj='polypeptide', obj='renal_hypertrophy')
     0.694 KBTriple(rel='increased', sbj='tremor', obj='betaxolol')
     0.694 KBTriple(rel='increased', sbj='betaxolol', obj='tremor')
     0.685 KBTriple(rel='increased', sbj='weight', obj='gamma-globulin')
     0.685 KBTriple(rel='increased', sbj='gamma-globulin', obj='weight')
     0.681 KBTriple(rel='increased', sbj='ibuprofen', obj='renin')
     0.681 KBTriple(rel='increased', sbj='renin', obj='ibuprofen')
     0.674 KBTriple(rel='increased', sbj='epinephrine', obj='diazepam')
     0.674 KBTriple(rel='increased', sbj='diazepam', obj='epinephrine')
     0.661 KBTriple(rel='increased', sbj='terazosin', obj='fasting')
     0.661 KBTriple(rel='increased', sbj='fasting', obj='terazosin')
     0.657 KBTriple(rel='increased', sbj='cardiac_index', ob

     0.861 KBTriple(rel='ruptured', sbj='strontium', obj='osteolysis')
     0.861 KBTriple(rel='ruptured', sbj='osteolysis', obj='strontium')
     0.238 KBTriple(rel='ruptured', sbj='renal_hypertrophy', obj='polypeptide')
     0.238 KBTriple(rel='ruptured', sbj='polypeptide', obj='renal_hypertrophy')
     0.104 KBTriple(rel='ruptured', sbj='econazole', obj='erythema')
     0.104 KBTriple(rel='ruptured', sbj='erythema', obj='econazole')
     0.069 KBTriple(rel='ruptured', sbj='immunglobulin', obj='transplant')
     0.069 KBTriple(rel='ruptured', sbj='transplant', obj='immunglobulin')
     0.058 KBTriple(rel='ruptured', sbj='coma', obj='methylprednisolone')
     0.058 KBTriple(rel='ruptured', sbj='methylprednisolone', obj='coma')
     0.049 KBTriple(rel='ruptured', sbj='acyclovir', obj='seizure')
     0.049 KBTriple(rel='ruptured', sbj='seizure', obj='acyclovir')
     0.045 KBTriple(rel='ruptured', sbj='enzyme_activity', obj='pindolol')
     0.045 KBTriple(rel='ruptured', sbj='pindolol',

In [47]:
print('{:15s} {:>15s} {:>15s} {:>15s}'.format(
    '', 'examples', 'examples', 'examples'))
print('{:15s} {:>15s} {:>15s} {:>15s}'.format(
    'relation', 'eval set', 'suggestions', 'intersection'))
print('{:15s} {:>15s} {:>15s} {:>15s}'.format(
    '--------', '--------', '-------', '-------'))

for rel, instances in eval_relations.items():
    print('{:15s} {:>15s} {:>15s} {:>15s}'.format(
        rel, str(len(instances)), 
        str(int(len(suggestions[rel])/2)), 
        str(len(instances.intersection(suggestions[rel])))))

                       examples        examples        examples
relation               eval set     suggestions    intersection
--------               --------         -------         -------
accelerated                   1               0               0
acquired                      9              50               0
aggravated                  114              50               0
altered                      29              50               0
caused                    11426              50              15
changed                      25              50               0
decreased                   962              50               4
delayed                      24              50               0
discoloured                  40              50               1
impaired                     16              50               0
increased                  1031              50               7
infected                      9              50               0
prolonged                    30         

In [48]:
new_instances = get_new_relation_instances(
    dataset,
    featurizers=[simple_bag_of_words_featurizer],
    k=200
)

eval_relations = defaultdict(set)

for kb_triple in kb_eval:
    eval_relations[kb_triple.rel].add((kb_triple.rel, kb_triple.sbj, kb_triple.obj))
    
suggestions = defaultdict(set)

for rel, rel_inst in new_instances.items():
    for instance in rel_inst:
        suggestions[rel].add((instance[1].rel, instance[1].sbj, instance[1].obj))

train_x_rel = None
train_y_rel = None
Highest probability examples for relation acquired:

     0.191 KBTriple(rel='acquired', sbj='corynebacterium', obj='faeces')
     0.191 KBTriple(rel='acquired', sbj='faeces', obj='corynebacterium')
     0.163 KBTriple(rel='acquired', sbj='tac', obj='toxoplasmosis')
     0.163 KBTriple(rel='acquired', sbj='toxoplasmosis', obj='tac')
     0.092 KBTriple(rel='acquired', sbj='microangiopathy', obj='somatostatin')
     0.092 KBTriple(rel='acquired', sbj='somatostatin', obj='microangiopathy')
     0.070 KBTriple(rel='acquired', sbj='tac', obj='progesterone')
     0.070 KBTriple(rel='acquired', sbj='progesterone', obj='tac')
     0.062 KBTriple(rel='acquired', sbj='stress', obj='ondansetron')
     0.062 KBTriple(rel='acquired', sbj='ondansetron', obj='stress')
     0.058 KBTriple(rel='acquired', sbj='econazole', obj='haemoglobin')
     0.058 KBTriple(rel='acquired', sbj='haemoglobin', obj='econazole')
     0.057 KBTriple(rel='acquired', sbj='ammonium', o

     1.000 KBTriple(rel='caused', sbj='fall', obj='pindolol')
     1.000 KBTriple(rel='caused', sbj='pindolol', obj='fall')
     1.000 KBTriple(rel='caused', sbj='pindolol', obj='blood_pressure')
     1.000 KBTriple(rel='caused', sbj='blood_pressure', obj='pindolol')
     1.000 KBTriple(rel='caused', sbj='enzyme_activity', obj='pindolol')
     1.000 KBTriple(rel='caused', sbj='pindolol', obj='enzyme_activity')
     1.000 KBTriple(rel='caused', sbj='tachycardia', obj='propranolol')
     1.000 KBTriple(rel='caused', sbj='propranolol', obj='tachycardia')
     1.000 KBTriple(rel='caused', sbj='magnesium', obj='infusion')
     1.000 KBTriple(rel='caused', sbj='infusion', obj='magnesium')
     1.000 KBTriple(rel='caused', sbj='sodium', obj='hypercalciuria')
     1.000 KBTriple(rel='caused', sbj='hypercalciuria', obj='sodium')
     1.000 KBTriple(rel='caused', sbj='headache', obj='sumatriptan')
     1.000 KBTriple(rel='caused', sbj='sumatriptan', obj='headache')
     1.000 KBTriple(rel='cause

     0.154 KBTriple(rel='delayed', sbj='gamma-globulin', obj='cholangitis')
     0.154 KBTriple(rel='delayed', sbj='cholangitis', obj='gamma-globulin')
     0.133 KBTriple(rel='delayed', sbj='solar_urticaria', obj='cimetidine')
     0.133 KBTriple(rel='delayed', sbj='cimetidine', obj='solar_urticaria')
     0.107 KBTriple(rel='delayed', sbj='anaplastic_astrocytoma', obj='methylprednisolone')
     0.107 KBTriple(rel='delayed', sbj='methylprednisolone', obj='astrocytoma')
     0.107 KBTriple(rel='delayed', sbj='methylprednisolone', obj='anaplastic_astrocytoma')
     0.107 KBTriple(rel='delayed', sbj='astrocytoma', obj='methylprednisolone')
     0.091 KBTriple(rel='delayed', sbj='ejection_fraction', obj='gallopamil')
     0.091 KBTriple(rel='delayed', sbj='gallopamil', obj='ejection_fraction')
     0.087 KBTriple(rel='delayed', sbj='azatadine', obj='rhinitis')
     0.087 KBTriple(rel='delayed', sbj='rhinitis', obj='azatadine')
     0.083 KBTriple(rel='delayed', sbj='anoxia', obj='nadolol'

     1.000 KBTriple(rel='increased', sbj='renin', obj='sodium')
     1.000 KBTriple(rel='increased', sbj='sodium', obj='renin')
     0.999 KBTriple(rel='increased', sbj='pectin', obj='weight')
     0.999 KBTriple(rel='increased', sbj='weight', obj='pectin')
     0.999 KBTriple(rel='increased', sbj='propranolol', obj='venous_pressure')
     0.999 KBTriple(rel='increased', sbj='venous_pressure', obj='propranolol')
     0.991 KBTriple(rel='increased', sbj='weight', obj='diltiazem')
     0.991 KBTriple(rel='increased', sbj='diltiazem', obj='weight')
     0.986 KBTriple(rel='increased', sbj='cardiac_index', obj='propranolol')
     0.986 KBTriple(rel='increased', sbj='propranolol', obj='cardiac_index')
     0.966 KBTriple(rel='increased', sbj='blood_pressure', obj='urapidil')
     0.966 KBTriple(rel='increased', sbj='urapidil', obj='blood_pressure')
     0.965 KBTriple(rel='increased', sbj='weight', obj='pyrantel')
     0.965 KBTriple(rel='increased', sbj='pyrantel', obj='weight')
     0.959

     0.948 KBTriple(rel='reduced', sbj='diazepam', obj='central_cord_syndrome')
     0.948 KBTriple(rel='reduced', sbj='central_cord_syndrome', obj='diazepam')
     0.946 KBTriple(rel='reduced', sbj='sodium', obj='central_cord_syndrome')
     0.946 KBTriple(rel='reduced', sbj='central_cord_syndrome', obj='sodium')
     0.916 KBTriple(rel='reduced', sbj='microangiopathy', obj='somatostatin')
     0.916 KBTriple(rel='reduced', sbj='somatostatin', obj='microangiopathy')
     0.903 KBTriple(rel='reduced', sbj='urine_flow', obj='ammonium')
     0.903 KBTriple(rel='reduced', sbj='ammonium', obj='urine_flow')
     0.826 KBTriple(rel='reduced', sbj='thiazide', obj='crystalluria')
     0.826 KBTriple(rel='reduced', sbj='crystalluria', obj='thiazide')
     0.789 KBTriple(rel='reduced', sbj='ammonium', obj='nephrocalcinosis')
     0.789 KBTriple(rel='reduced', sbj='nephrocalcinosis', obj='ammonium')
     0.734 KBTriple(rel='reduced', sbj='pindolol', obj='dry_mouth')
     0.734 KBTriple(rel='reduc

In [49]:
print('{:15s} {:>15s} {:>15s} {:>15s}'.format(
    '', 'examples', 'examples', 'examples'))
print('{:15s} {:>15s} {:>15s} {:>15s}'.format(
    'relation', 'eval set', 'suggestions', 'intersection'))
print('{:15s} {:>15s} {:>15s} {:>15s}'.format(
    '--------', '--------', '-------', '-------'))

for rel, instances in eval_relations.items():
    print('{:15s} {:>15s} {:>15s} {:>15s}'.format(
        rel, str(len(instances)), 
        str(int(len(suggestions[rel])/2)), 
        str(len(instances.intersection(suggestions[rel])))))

                       examples        examples        examples
relation               eval set     suggestions    intersection
--------               --------         -------         -------
accelerated                   1               0               0
acquired                      9             100               0
aggravated                  114             100               0
altered                      29             100               0
caused                    11426             100              26
changed                      25             100               0
decreased                   962             100               8
delayed                      24             100               0
discoloured                  40             100               1
impaired                     16             100               0
increased                  1031             100               9
infected                      9             100               0
prolonged                    30         