In [6]:
import os
import json
import pickle
import random
from collections import defaultdict, Counter

from indra.literature.adeft_tools import universal_extract_text
from indra.databases.hgnc_client import get_hgnc_name, get_hgnc_id

from adeft.discover import AdeftMiner
from adeft.gui import ground_with_gui
from adeft.modeling.label import AdeftLabeler
from adeft.modeling.classify import AdeftClassifier
from adeft.disambiguate import AdeftDisambiguator, load_disambiguator

from indra_db_lite.api import get_entrez_pmids_for_hgnc
from indra_db_lite.api import get_entrez_pmids_for_uniprot
from indra_db_lite.api import get_plaintexts_for_text_ref_ids
from indra_db_lite.api import get_text_ref_ids_for_agent_text
from indra_db_lite.api import get_text_ref_ids_for_pmids


from adeft_indra.grounding import AdeftGrounder
from adeft_indra.s3 import model_to_s3
from adeft_indra.model_building.escape import escape_filename

In [7]:
def get_text_ref_ids_for_entity(ns, id_):
    """Look up text ref ids for an entity via its Entrez PMIDs.

    Parameters
    ----------
    ns : str
        Name space of the entity. Only 'HGNC' and 'UP' are supported.
    id_ : str
        Identifier of the entity within ``ns``.

    Returns
    -------
    list
        The values of the mapping returned by
        ``get_text_ref_ids_for_pmids`` for the PMIDs found for the entity.

    Raises
    ------
    ValueError
        If ``ns`` is not one of the supported name spaces.
    """
    if ns == 'HGNC':
        pmids = get_entrez_pmids_for_hgnc(id_)
    elif ns == 'UP':
        pmids = get_entrez_pmids_for_uniprot(id_)
    else:
        # Without this branch an unsupported name space fell through to an
        # UnboundLocalError on ``pmids`` in the return statement.
        raise ValueError(
            f"Unsupported name space {ns!r}; expected 'HGNC' or 'UP'."
        )
    return list(get_text_ref_ids_for_pmids(pmids).values())

In [8]:
# Grounder instance; its ground() method is called on every mined longform
# in the mining cell below to propose an initial grounding and name.
adeft_grounder = AdeftGrounder()

In [9]:
# Shortforms handled by this model. The model name is the ':'-joined, sorted
# list of their filename-escaped forms.
shortforms = ['MCD', 'MCDs']
model_name = ':'.join(sorted(map(escape_filename, shortforms)))
# Results directory, resolved relative to the current working directory.
results_path = os.path.abspath(
    os.path.join('../../', 'results', model_name)
)

In [10]:
# Mine candidate longforms for every shortform and propose initial groundings.
#
# miners    : shortform -> AdeftMiner fed with the texts fetched for it.
# all_texts : union of the flattened text dicts of all shortforms (later cells
#             iterate it as text_ref_id -> text).
miners = dict()
all_texts = {}
for shortform in shortforms:
    text_ref_ids = get_text_ref_ids_for_agent_text(shortform)
    # Note: the `contains` filter is the full shortforms list, not only the
    # shortform of the current iteration.
    content = get_plaintexts_for_text_ref_ids(text_ref_ids, contains=shortforms)
    text_dict = content.flatten()
    miners[shortform] = AdeftMiner(shortform)
    miners[shortform].process_texts(text_dict.values())
    all_texts.update(text_dict)

# Per shortform, keep only (longform, count, score) rows with count*score > 2.
longform_dict = {}
for shortform in shortforms:
    longforms = miners[shortform].get_longforms()
    longforms = [(longform, count, score) for longform, count, score in longforms
                 if count*score > 2]
    longform_dict[shortform] = longforms
    
# Sum the counts of each longform across shortforms (Counter.update adds).
combined_longforms = Counter()
for longform_rows in longform_dict.values():
    combined_longforms.update({longform: count for longform, count, score
                               in longform_rows})
# Pre-fill grounding_map / names with the first result the grounder returns
# for each longform; longforms with no result are left out of the map.
grounding_map = {}
names = {}
for longform in combined_longforms:
    groundings = adeft_grounder.ground(longform)
    if groundings:
        grounding = groundings[0]['grounding']
        grounding_map[longform] = grounding
        names[grounding] = groundings[0]['name']
# `longforms` is rebound here to a tuple of longform strings ordered by
# descending combined count. The unpacking raises ValueError if no longform
# survived the filter above.
longforms, counts = zip(*combined_longforms.most_common())
# No positive labels yet; they are chosen in the curation step.
pos_labels = []

In [11]:
# Display the mined longforms next to their combined counts.
list(zip(longforms, counts))

[('methionine choline deficient', 286),
 ('methionine and choline deficient', 209),
 ('methyl β cyclodextrin', 106),
 ('minimal change disease', 102),
 ('magnetic circular dichroism', 64),
 ('malonyl coa decarboxylase', 61),
 ('multicentric castleman s disease', 44),
 ('methionine and choline deficient diet', 41),
 ('multicentric castleman disease', 31),
 ('methionine choline deficient diet', 31),
 ('methyl beta cyclodextrin', 24),
 ('malformations of cortical development', 24),
 ('multicentric cd', 19),
 ('macular corneal dystrophy', 16),
 ('methyl cyclodextrin', 11),
 ('medullary collecting duct', 10),
 ('methionine and choline', 10),
 ('mast cells degranulating', 7),
 ('mast cells density', 7),
 ('mast cells degranulating peptide', 7),
 ('monte carlo dropout', 6),
 ('mature c terminal domain', 6),
 ('minimum cylindrical diameter', 5),
 ('moral case deliberation', 4),
 ('mean consecutive different', 4),
 ('microvascular coronary dysfunction', 4),
 ('monochlorodimedone', 4),
 ('malony

In [72]:
# Curate the proposed groundings with adeft's ground_with_gui helper, seeded
# with the automatic grounding_map / names / pos_labels and rebinding all three
# to what it returns.
# NOTE(review): the returned values presumably depend on manual choices made in
# the GUI, so this cell is not reproducible from code alone — confirm.
grounding_map, names, pos_labels = ground_with_gui(longforms, counts, 
                                                   grounding_map=grounding_map,
                                                   names=names, pos_labels=pos_labels, no_browser=True, port=8890)

In [73]:
# Bundle the curated objects so they can be displayed together.
result = [grounding_map, names, pos_labels]

In [74]:
# Display the curated [grounding_map, names, pos_labels] bundle.
result

[{'macular corneal dystrophy': 'HP:HP:0001131',
  'magnetic cd': 'MESH:D002942',
  'magnetic circular dichroism': 'MESH:D002942',
  'malformations of cortical development': 'MESH:D054220',
  'malonyl coa decarboxylase': 'HGNC:7150',
  'malonyl coenzyme a coa decarboxylase': 'HGNC:7150',
  'malonyl coenzyme a decarboxylase': 'HGNC:7150',
  'mast cells degranulating': 'ungrounded',
  'mast cells degranulating peptide': 'ungrounded',
  'mast cells density': 'ungrounded',
  'maternal calving difficulty': 'ungrounded',
  'mature c terminal domain': 'ungrounded',
  'mean consecutive different': 'ungrounded',
  'medullary collecting duct': 'ungrounded',
  'methionie and choline deficient diet': 'methionine_choline_deficient',
  'methionine and choline': 'methionine_choline_deficient',
  'methionine and choline deficient': 'methionine_choline_deficient',
  'methionine and choline deficient diet': 'methionine_choline_deficient',
  'methionine choline deficient': 'methionine_choline_deficient',


In [75]:
# Hard-coded [grounding_map, names, pos_labels]. The entries visible in the
# displayed `result` output above match the start of this literal, so this is
# presumably a pasted copy of the curation result that lets the notebook be
# re-run without the GUI step — confirm.
#   grounding_map : longform -> grounding label ('ungrounded' where none applies)
#   names         : grounding label -> display name
#   pos_labels    : grounding labels treated as positive labels
# NOTE(review): the key 'methionie and choline deficient diet' looks misspelled,
# but it is a runtime lookup key (checked against the mined longforms when
# grounding_dict is built), so it must not be "corrected" here — confirm.
grounding_map, names, pos_labels = [{'macular corneal dystrophy': 'HP:HP:0001131',
  'magnetic cd': 'MESH:D002942',
  'magnetic circular dichroism': 'MESH:D002942',
  'malformations of cortical development': 'MESH:D054220',
  'malonyl coa decarboxylase': 'HGNC:7150',
  'malonyl coenzyme a coa decarboxylase': 'HGNC:7150',
  'malonyl coenzyme a decarboxylase': 'HGNC:7150',
  'mast cells degranulating': 'ungrounded',
  'mast cells degranulating peptide': 'ungrounded',
  'mast cells density': 'ungrounded',
  'maternal calving difficulty': 'ungrounded',
  'mature c terminal domain': 'ungrounded',
  'mean consecutive different': 'ungrounded',
  'medullary collecting duct': 'ungrounded',
  'methionie and choline deficient diet': 'methionine_choline_deficient',
  'methionine and choline': 'methionine_choline_deficient',
  'methionine and choline deficient': 'methionine_choline_deficient',
  'methionine and choline deficient diet': 'methionine_choline_deficient',
  'methionine choline deficient': 'methionine_choline_deficient',
  'methionine choline deficient diet': 'methionine_choline_deficient',
  'methyl beta cyclodextrin': 'CHEBI:CHEBI:133151',
  'methyl cyclodextrin': 'CHEBI:CHEBI:133151',
  'methyl β cyclodextrin': 'CHEBI:CHEBI:133151',
  'microvascular coronary dysfunction': 'ungrounded',
  'minimal change disease': 'DOID:DOID:10966',
  'minimum cylindrical diameter': 'ungrounded',
  'mitotic cells death': 'ungrounded',
  'monochlorodimedone': 'MESH:C006991',
  'monte carlo dropout': 'ungrounded',
  'moral case deliberation': 'ungrounded',
  'multicentric castleman disease': 'DOID:DOID:0111152',
  'multicentric castleman s disease': 'DOID:DOID:0111152',
  'multicentric cd': 'DOID:DOID:0111152'},
 {'HP:HP:0001131': 'Corneal dystrophy',
  'MESH:D002942': 'Circular Dichroism',
  'MESH:D054220': 'Malformations of Cortical Development',
  'HGNC:7150': 'MLYCD',
  'methionine_choline_deficient': 'methionine_choline_deficient',
  'CHEBI:CHEBI:133151': 'methyl beta-cyclodextrin',
  'DOID:DOID:10966': 'lipoid nephrosis',
  'MESH:C006991': 'chlorodimedone',
  'DOID:DOID:0111152': 'multicentric Castleman disease'},
 ['CHEBI:CHEBI:133151',
  'DOID:DOID:0111152',
  'DOID:DOID:10966',
  'HGNC:7150',
  'HP:HP:0001131',
  'MESH:D002942',
  'MESH:D054220']]

In [76]:
# Longforms to leave out of grounding_dict even if they have a grounding;
# none are excluded for this model.
excluded_longforms = []

In [77]:
# For each shortform, keep the curated grounding of every longform mined for
# it, skipping longforms without a grounding or listed in excluded_longforms.
grounding_dict = {
    shortform: {
        longform: grounding_map[longform]
        for longform, _, _ in longform_rows
        if longform in grounding_map and longform not in excluded_longforms
    }
    for shortform, longform_rows in longform_dict.items()
}
result = [grounding_dict, names, pos_labels]

# makedirs(exist_ok=True) also creates a missing parent directory and avoids
# the check-then-create race of os.path.exists() followed by os.mkdir().
os.makedirs(results_path, exist_ok=True)
with open(os.path.join(results_path, f'{model_name}_preliminary_grounding_info.json'), 'w') as f:
    json.dump(result, f)

In [78]:
# Extra entities to add as labels. As read by later cells: keys are 'NS:ID'
# strings (split on the first ':'), values are 2-item sequences whose first
# item is the display name and whose second item is passed as `contains` when
# fetching texts. Empty for this model.
additional_entities = {}

In [79]:
# Extra labels backed by unambiguous agent texts. As read by later cells:
# value[0] is the display name and value[1] is an iterable of agent texts used
# both to look up text ref ids and as the `contains` filter. Empty for this model.
unambiguous_agent_texts = {}

In [80]:
# Build the labeled corpus from all_texts; each corpus item is unpacked below
# as (text, label, id_).
labeler = AdeftLabeler(grounding_dict)
corpus = labeler.build_from_texts(
    (text, text_ref_id) for text_ref_id, text in all_texts.items()
)
# label -> list of the ids of the corpus texts carrying that label.
agent_text_text_ref_id_map = defaultdict(list)
for text, label, id_ in corpus:
    agent_text_text_ref_id_map[label].append(id_)

# entity -> set of text ref ids, for every key of additional_entities. The key
# is split on its first ':' into (name space, id).
entity_text_ref_id_map = {
    entity: set(
        get_text_ref_ids_for_entity(*entity.split(':', maxsplit=1))
    )
    for entity in additional_entities
}

In [81]:
# Size of the text ref id overlap for every ordered pair of additional
# entities (self-pairs included).
# Note: the loop variables (entity1, trids1, entity2, trids2) remain defined in
# the notebook namespace after this cell runs.
intersection1 = []
for entity1, trids1 in entity_text_ref_id_map.items():
    for entity2, trids2 in entity_text_ref_id_map.items():
        intersection1.append((entity1, entity2, len(trids1 & trids2)))

In [82]:
# Size of the text ref id overlap between each corpus label and each
# additional entity.
intersection2 = []
for entity1, trids1 in agent_text_text_ref_id_map.items():
    # The inner loop must bind `trids2`. It previously bound `pmids2` while the
    # body read `trids2`, i.e. a stale value left over from another cell.
    for entity2, trids2 in entity_text_ref_id_map.items():
        # trids1 is a list (the map is a defaultdict(list)); trids2 is a set.
        intersection2.append((entity1, entity2, len(set(trids1) & trids2)))

In [83]:
intersection1

[]

In [84]:
intersection2

[]

In [85]:
# Extend the corpus with texts for the labels declared in
# unambiguous_agent_texts and additional_entities. Both dicts are empty in
# this notebook, so neither loop body runs here.
all_used_trids = set()
for entity, agent_texts in unambiguous_agent_texts.items():
    used_trids = set()
    # agent_texts[1] holds the agent texts for this label.
    for agent_text in agent_texts[1]:
        trids = set(get_text_ref_ids_for_agent_text(agent_text))
        # Skip ids already in all_texts or already used for this label.
        # NOTE(review): ids used for a *previous* label (all_used_trids) are not
        # excluded here, so one text could be added under two labels — confirm
        # whether that is intended.
        new_trids = list(trids - all_texts.keys() - used_trids)
        content = get_plaintexts_for_text_ref_ids(new_trids, contains=agent_texts[1])
        text_dict = content.flatten()
        corpus.extend(
            [
                # Texts shorter than 5 characters are dropped.
                (text, entity, trid) for trid, text in text_dict.items() if len(text) >= 5
            ]
        )
        used_trids.update(new_trids)
    all_used_trids.update(used_trids)
        
for entity, trids in entity_text_ref_id_map.items():
    # Skip ids already in all_texts or consumed by the loop above.
    new_trids = list(set(trids) - all_texts.keys() - all_used_trids)
    # The second item of the additional_entities value is the `contains` filter.
    _, contains = additional_entities[entity]
    content = get_plaintexts_for_text_ref_ids(new_trids, contains=contains)
    text_dict = content.flatten()
    corpus.extend(
        [
            # Texts shorter than 5 characters are dropped.
            (text, entity, trid) for trid, text in text_dict.items() if len(text) >= 5
        ]
    )

In [86]:
# Register display names for the manually added labels; the first item of each
# value is the name.
names.update({key: value[0] for key, value in additional_entities.items()})
names.update({key: value[0] for key, value in unambiguous_agent_texts.items()})
# Every manually added label is a positive label. The union is sorted so that
# pos_labels has a reproducible order; the order of list(set(...)) over strings
# can differ between interpreter sessions.
pos_labels = sorted(set(pos_labels) | additional_entities.keys() |
                    unambiguous_agent_texts.keys())

In [87]:
%%capture

# Train the classifier with a fixed random_state for repeatable results.
classifier = AdeftClassifier(shortforms, pos_labels=pos_labels, random_state=1729)
# Single-point grid: only C=100.0 with max_features=10000 is evaluated.
param_grid = {'C': [100.0], 'max_features': [10000]}
# Corpus items are (text, label, id) triples. The third element is named
# `pmids` here, but the corpus is built from text ref ids in the cells above.
texts, labels, pmids = zip(*corpus)
# presumably 5-fold cross-validation on 5 parallel workers — TODO confirm
# against AdeftClassifier.cv.
classifier.cv(texts, labels, param_grid, cv=5, n_jobs=5)

INFO: [2021-10-07 21:29:30] /adeft/Py/adeft/adeft/modeling/classify.py - Beginning grid search in parameter space:
{'C': [100.0], 'max_features': [10000]}
INFO: [2021-10-07 21:29:44] /adeft/Py/adeft/adeft/modeling/classify.py - Best f1 score of 0.9762313609319021 found for parameter values:
{'logit__C': 100.0, 'tfidf__max_features': 10000}


In [88]:
# Display the statistics recorded on the classifier by the cv run.
classifier.stats

{'label_distribution': {'HGNC:7150': 55,
  'methionine_choline_deficient': 360,
  'MESH:D002942': 46,
  'CHEBI:CHEBI:133151': 103,
  'DOID:DOID:10966': 69,
  'ungrounded': 43,
  'DOID:DOID:0111152': 69,
  'MESH:D054220': 21,
  'MESH:C006991': 4,
  'HP:HP:0001131': 12},
 'f1': {'mean': 0.976231, 'std': 0.008883},
 'precision': {'mean': 0.968982, 'std': 0.016206},
 'recall': {'mean': 0.983962, 'std': 0.01567},
 'CHEBI:CHEBI:133151': {'f1': {'mean': 0.958178, 'std': 0.024719},
  'pr': {'mean': 0.98, 'std': 0.024495},
  'rc': {'mean': 0.939818, 'std': 0.053147}},
 'DOID:DOID:0111152': {'f1': {'mean': 1.0, 'std': 0.0},
  'pr': {'mean': 1.0, 'std': 0.0},
  'rc': {'mean': 1.0, 'std': 0.0}},
 'DOID:DOID:10966': {'f1': {'mean': 0.978289, 'std': 0.017752},
  'pr': {'mean': 0.971429, 'std': 0.034993},
  'rc': {'mean': 0.986667, 'std': 0.026667}},
 'HGNC:7150': {'f1': {'mean': 0.973913, 'std': 0.052174},
  'pr': {'mean': 0.981818, 'std': 0.036364},
  'rc': {'mean': 0.966667, 'std': 0.066667}},
 'H

In [89]:
# Combine the trained classifier with the grounding dict and display names
# into a disambiguator object.
disamb = AdeftDisambiguator(classifier, grounding_dict, names)

In [90]:
# presumably writes the model files under results_path using model_name —
# TODO confirm against AdeftDisambiguator.dump.
disamb.dump(model_name, results_path)

In [91]:
# Print the text summary of the model (labels and per-class metrics).
print(disamb.info())

Disambiguation model for MCD, and MCDs

Produces the disambiguations:
	Circular Dichroism*	MESH:D002942
	Corneal dystrophy*	HP:HP:0001131
	MLYCD*	HGNC:7150
	Malformations of Cortical Development*	MESH:D054220
	chlorodimedone	MESH:C006991
	lipoid nephrosis*	DOID:DOID:10966
	methionine_choline_deficient	methionine_choline_deficient
	methyl beta-cyclodextrin*	CHEBI:CHEBI:133151
	multicentric Castleman disease*	DOID:DOID:0111152

Class level metrics:
--------------------
Grounding                            	Count	F1     
         methionine_choline_deficient	360	0.99312
             methyl beta-cyclodextrin*	103	0.95818
                     lipoid nephrosis*	 69	0.97829
       multicentric Castleman disease*	 69	    1.0
                                MLYCD*	 55	0.97391
                   Circular Dichroism*	 46	0.97895
                           Ungrounded	 43	0.83703
Malformations of Cortical Development*	 21	    1.0
                    Corneal dystrophy*	 12	   0.96
                   

In [92]:
# NOTE(review): presumably uploads the model to S3, i.e. an external side
# effect on every run — confirm before using Run All.
model_to_s3(disamb)

In [26]:
from adeft.disambiguate import load_disambiguator

In [27]:
# NOTE(review): this rebinds `disamb`, replacing the MCD disambiguator built
# above with the "BAL" model. The execution counts of these cells (In [26]-[29])
# are also lower than those of the preceding cells, so they look like leftovers
# from another session — confirm whether they belong in this notebook.
disamb = load_disambiguator("BAL")

In [28]:
# Display the loaded disambiguator object.
disamb

<adeft.disambiguate.AdeftDisambiguator at 0x7f4f001b33a0>

In [29]:
# Use the variable directly. The IPython output-history name `_28` only exists
# if the cell that displayed `disamb` happened to get execution count 28, so it
# breaks on Restart & Run All.
print(disamb.info())

Disambiguation model for BAL

Produces the disambiguations:
	Bronchoalveolar Lavage	MESH:D018893
	CEL*	HGNC:1848
	Liver, Artificial	MESH:D019164
	benzaldehyde lyase*	MESH:C059416
	betaine aldehyde*	CHEBI:CHEBI:15710
	dimercaprol*	CHEBI:CHEBI:64198

Class level metrics:
--------------------
Grounding             	Count	F1     
Bronchoalveolar Lavage	1259	 0.9929
                   CEL*	  36	    1.0
     Liver, Artificial	  18	0.83619
            Ungrounded	  17	   0.65
           dimercaprol*	   8	    0.4
    benzaldehyde lyase*	   3	    0.2
      betaine aldehyde*	   2	    0.2

Global Metrics:
-----------------
	F1 score:	0.90773
	Precision:	1.0
	Recall:		0.83293

* Positive labels
See Docstring for explanation



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
