In [1]:
import os
import json
import pickle
import random
from collections import defaultdict, Counter

from indra.literature.adeft_tools import universal_extract_text
from indra.databases.hgnc_client import get_hgnc_name, get_hgnc_id

from adeft.discover import AdeftMiner
from adeft.gui import ground_with_gui
from adeft.modeling.label import AdeftLabeler
from adeft.modeling.classify import AdeftClassifier
from adeft.disambiguate import AdeftDisambiguator, load_disambiguator

from indra_db_lite.api import get_entrez_pmids_for_hgnc
from indra_db_lite.api import get_entrez_pmids_for_uniprot
from indra_db_lite.api import get_plaintexts_for_text_ref_ids
from indra_db_lite.api import get_text_ref_ids_for_agent_text
from indra_db_lite.api import get_text_ref_ids_for_pmids


from adeft_indra.grounding import AdeftGrounder
from adeft_indra.s3 import model_to_s3
from adeft_indra.model_building.escape import escape_filename

In [2]:
def get_text_ref_ids_for_entity(ns, id_):
    """Return the text ref ids linked to an entity through its Entrez PMIDs.

    Parameters
    ----------
    ns : str
        Namespace of the entity. Only 'HGNC' and 'UP' are supported.
    id_ :
        Identifier handed unchanged to the namespace-specific PMID lookup.

    Returns
    -------
    list
        The values of the mapping returned by
        ``get_text_ref_ids_for_pmids`` for the entity's PMIDs.

    Raises
    ------
    ValueError
        If ``ns`` is not one of the supported namespaces.
    """
    if ns == 'HGNC':
        pmids = get_entrez_pmids_for_hgnc(id_)
    elif ns == 'UP':
        pmids = get_entrez_pmids_for_uniprot(id_)
    else:
        # Without this branch an unknown namespace fell through and failed
        # later with an UnboundLocalError on `pmids`.
        raise ValueError(
            f"Unsupported namespace {ns!r}; expected 'HGNC' or 'UP'"
        )
    return list(get_text_ref_ids_for_pmids(pmids).values())

In [3]:
# Grounder used further down to propose an initial grounding for each mined longform.
adeft_grounder = AdeftGrounder()

In [4]:
# Shortform(s) covered by the model built in this notebook.
shortforms = ['CAS']
# Model name: the filename-escaped shortforms, sorted and joined with ':'.
model_name = ':'.join(sorted(map(escape_filename, shortforms)))
# Absolute path of results/<model_name>, resolved two directories above the
# current working directory.
results_path = os.path.abspath(
    os.path.join('../../', 'results', model_name)
)

In [5]:
# Collect the texts mentioning each shortform and feed them to a per-shortform
# AdeftMiner; `all_texts` accumulates the flattened text dicts of all shortforms.
miners = dict()
all_texts = {}
for shortform in shortforms:
    text_ref_ids = get_text_ref_ids_for_agent_text(shortform)
    content = get_plaintexts_for_text_ref_ids(text_ref_ids, contains=shortforms)
    text_dict = content.flatten()
    miners[shortform] = AdeftMiner(shortform)
    miners[shortform].process_texts(text_dict.values())
    all_texts.update(text_dict)

# Keep only the mined longforms whose count * score exceeds 2.
longform_dict = {}
for shortform in shortforms:
    mined_longforms = miners[shortform].get_longforms()
    longform_dict[shortform] = [
        (longform, count, score)
        for longform, count, score in mined_longforms
        if count*score > 2
    ]

# Sum the counts of each longform over all shortforms.
combined_longforms = Counter()
for longform_rows in longform_dict.values():
    combined_longforms.update({longform: count for longform, count, score
                               in longform_rows})

# Propose a grounding for each longform; only the first result returned by the
# grounder is used, and longforms without any result get no entry.
grounding_map = {}
names = {}
for longform in combined_longforms:
    groundings = adeft_grounder.ground(longform)
    if groundings:
        grounding = groundings[0]['grounding']
        grounding_map[longform] = grounding
        names[grounding] = groundings[0]['name']

# Longforms and counts as parallel tuples, most common first. zip(*[]) yields
# nothing to unpack, so fall back to empty tuples when no longform survived.
if combined_longforms:
    longforms, counts = zip(*combined_longforms.most_common())
else:
    longforms, counts = (), ()
pos_labels = []

In [6]:
# Show the retained longforms with their counts, most common first.
list(zip(longforms, counts))

[('carotid artery stenting', 149),
 ('coronary artery spasm', 68),
 ('crk associated substrate', 37),
 ('caspofungin', 37),
 ('casein', 26),
 ('carotid artery stenosis', 23),
 ('childhood apraxia of speech', 21),
 ('contralateral acoustic stimulation', 19),
 ('chrome azurol s', 18),
 ('coronary artery stenosis', 17),
 ('condomless anal sex', 15),
 ('carotid angioplasty and stenting', 13),
 ('clinical activity score', 12),
 ('catalytic anionic site', 12),
 ('cycloartenol synthase', 11),
 ('cellular apoptosis susceptibility protein', 11),
 ('cyanoalanine synthase', 10),
 ('catalytic activity site', 9),
 ('calcific aortic stenosis', 8),
 ('carotid stenting', 8),
 ('cas', 6),
 ('cellular apoptosis susceptibility', 6),
 ('castanospermine', 6),
 ('chronic adolescent stress', 6),
 ('carotid atherosclerosis', 6),
 ('clean air scenario', 5),
 ('computer aided surgery', 5),
 ('central auditory system', 5),
 ('contact activity system', 5),
 ('central anticholinergic syndrome', 5),
 ('coronavirus 

In [49]:
# Pass the automatic groundings through adeft's grounding GUI and rebind the
# three objects to whatever it returns. Per the keyword arguments the GUI uses
# port 8890 and no browser is opened — presumably the curation is done by hand.
grounding_map, names, pos_labels = ground_with_gui(longforms, counts, 
                                                   grounding_map=grounding_map,
                                                   names=names, pos_labels=pos_labels, no_browser=True, port=8890)

In [50]:
# Bundle the grounding map, names and positive labels into one list.
result = [grounding_map, names, pos_labels]

In [51]:
# Display the bundled grounding map, names and positive labels.
result

[{'calcific aortic stenosis': 'MESH:D001024',
  'calcium sensing receptor': 'ungrounded',
  'carotid angioplasty and stenting': 'carotid_artery_stenting',
  'carotid angioplasty with stenting': 'carotid_artery_stenting',
  'carotid artery stenosis': 'MESH:D016893',
  'carotid artery stenting': 'carotid_artery_stenting',
  'carotid atherosclerosis': 'EFO:0009783',
  'carotid stenting': 'carotid_artery_stenting',
  'cas': 'ungrounded',
  'casein': 'MESH:D002364',
  'caspofungin': 'CHEBI:CHEBI:474180',
  'castanospermine': 'CHEBI:CHEBI:27860',
  'catalytic activity site': 'ungrounded',
  'catalytic anionic site': 'ungrounded',
  'catalytic site': 'ungrounded',
  'celebrity attitude scale': 'ungrounded',
  'cellular apoptosis susceptibility': 'ungrounded',
  'cellular apoptosis susceptibility protein': 'HGNC:2431',
  'central airway stenosis': 'MESH:D003251',
  'central anticholinergic syndrome': 'MESH:D064807',
  'central auditory system': 'ungrounded',
  'centrosome amplification score':

In [52]:
# Grounding information written out as a literal (presumably pasted from the
# displayed `result` so the GUI step need not be repeated — confirm):
#   1. grounding_map: longform -> grounding string, or 'ungrounded'
#   2. names: grounding string -> display name
#   3. pos_labels: groundings used as positive labels
grounding_map, names, pos_labels = [{'calcific aortic stenosis': 'MESH:D001024',
  'calcium sensing receptor': 'ungrounded',
  'carotid angioplasty and stenting': 'carotid_artery_stenting',
  'carotid angioplasty with stenting': 'carotid_artery_stenting',
  'carotid artery stenosis': 'MESH:D016893',
  'carotid artery stenting': 'carotid_artery_stenting',
  'carotid atherosclerosis': 'EFO:0009783',
  'carotid stenting': 'carotid_artery_stenting',
  'cas': 'ungrounded',
  'casein': 'MESH:D002364',
  'caspofungin': 'CHEBI:CHEBI:474180',
  'castanospermine': 'CHEBI:CHEBI:27860',
  'catalytic activity site': 'ungrounded',
  'catalytic anionic site': 'ungrounded',
  'catalytic site': 'ungrounded',
  'celebrity attitude scale': 'ungrounded',
  'cellular apoptosis susceptibility': 'ungrounded',
  'cellular apoptosis susceptibility protein': 'HGNC:2431',
  'central airway stenosis': 'MESH:D003251',
  'central anticholinergic syndrome': 'MESH:D064807',
  'central auditory system': 'ungrounded',
  'centrosome amplification score': 'ungrounded',
  'childhood apraxia of speech': 'DOID:DOID:0111275',
  'chinese academy of sciences': 'ungrounded',
  'chromazurol s': 'MESH:C015076',
  'chrome azurol s': 'MESH:C015076',
  'chronic adolescent stress': 'ungrounded',
  'cis acting sequences': 'ungrounded',
  'clavaminate synthase': 'ungrounded',
  'clean air scenario': 'ungrounded',
  'clinical activity score': 'ungrounded',
  'cognitive attentional syndrome': 'ungrounded',
  'color analog scale': 'ungrounded',
  'compensatory auditory stimulation': 'MESH:D000161',
  'complete artificial saliva': 'MESH:D012464',
  'complex adapter system': 'ungrounded',
  'composite appetite score': 'ungrounded',
  'computer aided surgery': 'MESH:D025321',
  'condomless anal sex': 'ungrounded',
  'contact activity system': 'ungrounded',
  'contralateral acoustic stimulation': 'MESH:D000161',
  'coronary artery spasm': 'MESH:D003329',
  'coronary artery stenosis': 'MESH:D023921',
  'coronary atherosclerosis': 'MESH:D003324',
  'coronavirus anxiety scale': 'ungrounded',
  'critical asthma syndrome': 'ungrounded',
  'crk associated substrate': 'HGNC:971',
  'crude antimicrobial substance': 'ungrounded',
  'cutaneous angiosarcoma': 'MESH:D006394',
  'cyanoalanine synthase': 'ungrounded',
  'cycloartenol synthase': 'UP:P38605',
  'cytolytic activity score': 'ungrounded',
  'scs casein percentage': 'ungrounded'},
 {'MESH:D001024': 'Aortic Valve Stenosis',
  'carotid_artery_stenting': 'carotid_artery_stenting',
  'MESH:D016893': 'Carotid Stenosis',
  'EFO:0009783': 'carotid atherosclerosis',
  'MESH:D002364': 'Caseins',
  'CHEBI:CHEBI:474180': 'caspofungin',
  'CHEBI:CHEBI:27860': 'castanospermine',
  'HGNC:2431': 'CSE1L',
  'MESH:D003251': 'Constriction, Pathologic',
  'MESH:D064807': 'Anticholinergic Syndrome',
  'DOID:DOID:0111275': 'speech-language disorder-1',
  'MESH:C015076': 'chrome azurol S',
  'MESH:D000161': 'Acoustic Stimulation',
  'MESH:D012464': 'Saliva, Artificial',
  'MESH:D025321': 'Surgery, Computer-Assisted',
  'MESH:D003329': 'Coronary Vasospasm',
  'MESH:D023921': 'Coronary Stenosis',
  'MESH:D003324': 'Coronary Artery Disease',
  'HGNC:971': 'BCAR1',
  'MESH:D006394': 'Hemangiosarcoma',
  'UP:P38605': 'Cycloartenol synthase'},
 ['CHEBI:CHEBI:474180',
  'DOID:DOID:0111275',
  'HGNC:2431',
  'HGNC:971',
  'MESH:C015076',
  'MESH:D000161',
  'MESH:D002364',
  'MESH:D003329']]

In [53]:
# Longforms left out of the grounding dict even when they have a grounding_map entry.
excluded_longforms = ['cas']

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(a

In [54]:
# For each shortform, map its retained longforms to their groundings, skipping
# longforms that have no grounding_map entry or are explicitly excluded.
grounding_dict = {
    shortform: {
        longform: grounding_map[longform]
        for longform, _, _ in longforms
        if longform in grounding_map and longform not in excluded_longforms
    }
    for shortform, longforms in longform_dict.items()
}
result = [grounding_dict, names, pos_labels]

# makedirs with exist_ok=True also creates a missing parent directory and has
# no gap between the existence check and the creation.
os.makedirs(results_path, exist_ok=True)
with open(os.path.join(results_path, f'{model_name}_preliminary_grounding_info.json'), 'w') as f:
    json.dump(result, f)

In [55]:
# Extra entities to add as labels (none for this model). Later cells split each
# key on the first ':' into (namespace, id), use value[0] as the entity's name,
# and unpack each value into two items, the second being passed as `contains`
# to the text lookup.
additional_entities = {}

In [56]:
# Entities with agent texts used to fetch extra labeled texts (none for this
# model). Later cells use value[0] as the entity's name and iterate value[1] as
# the agent texts to look up.
unambiguous_agent_texts = {}

In [57]:
# Build the labeled corpus from every collected text; the labeler is fed
# (text, text_ref_id) pairs.
labeler = AdeftLabeler(grounding_dict)
text_id_pairs = ((text, text_ref_id) for text_ref_id, text in all_texts.items())
corpus = labeler.build_from_texts(text_id_pairs)

# Group the ids of the corpus rows by the label each row carries.
agent_text_text_ref_id_map = defaultdict(list)
for _text, label, trid in corpus:
    agent_text_text_ref_id_map[label].append(trid)

# Text ref ids of each additional entity; keys are split on the first ':'
# into the (namespace, id) arguments of the lookup.
entity_text_ref_id_map = {}
for entity in additional_entities:
    ns_and_id = entity.split(':', maxsplit=1)
    entity_text_ref_id_map[entity] = set(get_text_ref_ids_for_entity(*ns_and_id))

In [58]:
# Size of the overlap between the text ref id sets of every ordered pair of
# additional entities (each entity is also paired with itself).
intersection1 = []
for entity1, trids1 in entity_text_ref_id_map.items():
    for entity2, trids2 in entity_text_ref_id_map.items():
        intersection1.append((entity1, entity2, len(trids1 & trids2)))

In [59]:
# Size of the overlap between the text ref ids labeled with each corpus label
# and the text ref ids of each additional entity.
intersection2 = []
for entity1, trids1 in agent_text_text_ref_id_map.items():
    # The inner loop variable must be `trids2`: it was previously bound as
    # `pmids2`, so the body read a stale or undefined `trids2`.
    for entity2, trids2 in entity_text_ref_id_map.items():
        intersection2.append((entity1, entity2, len(set(trids1) & trids2)))

In [60]:
# Display the entity-vs-entity overlap counts.
intersection1

[]

In [61]:
# Display the label-vs-entity overlap counts.
intersection2

[]

In [62]:
# Extend the corpus with texts found through unambiguous agent texts, labeling
# each added text with the entity it was found for.
all_used_trids = set()
for entity, agent_texts in unambiguous_agent_texts.items():
    used_trids = set()
    entity_agent_texts = agent_texts[1]
    for agent_text in entity_agent_texts:
        trids = set(get_text_ref_ids_for_agent_text(agent_text))
        # Skip ids already collected for the shortforms or already fetched
        # for this entity.
        new_trids = list(trids - all_texts.keys() - used_trids)
        content = get_plaintexts_for_text_ref_ids(new_trids, contains=entity_agent_texts)
        text_dict = content.flatten()
        labeled_rows = [
            (text, entity, trid)
            for trid, text in text_dict.items()
            if len(text) >= 5
        ]
        corpus.extend(labeled_rows)
        used_trids.update(new_trids)
    all_used_trids |= used_trids

# Extend the corpus with texts of the additional entities, skipping ids already
# collected for the shortforms or fetched in the loop above.
for entity, trids in entity_text_ref_id_map.items():
    new_trids = list(set(trids) - all_texts.keys() - all_used_trids)
    _, contains = additional_entities[entity]
    content = get_plaintexts_for_text_ref_ids(new_trids, contains=contains)
    text_dict = content.flatten()
    labeled_rows = [
        (text, entity, trid)
        for trid, text in text_dict.items()
        if len(text) >= 5
    ]
    corpus.extend(labeled_rows)

In [63]:
# Register the display name (first item of each entry) of every additional
# entity and every entity with unambiguous agent texts.
names.update({key: value[0] for key, value in additional_entities.items()})
names.update({key: value[0] for key, value in unambiguous_agent_texts.items()})
# Those entities also become positive labels. Sorting gives a reproducible
# order; list(set(...)) order can differ from one kernel session to the next.
pos_labels = sorted(set(pos_labels) | additional_entities.keys() |
                    unambiguous_agent_texts.keys())

In [64]:
%%capture

# Cross-validate the classifier (cv=5, 5 parallel jobs) over a parameter grid
# that holds a single candidate value per parameter.
classifier = AdeftClassifier(shortforms, pos_labels=pos_labels, random_state=1729)
param_grid = {'C': [100.0], 'max_features': [10000]}
# Corpus rows are (text, label, id) triples; the ids are unpacked into `pmids`
# and are not passed to the cv call.
texts, labels, pmids = zip(*corpus)
classifier.cv(texts, labels, param_grid, cv=5, n_jobs=5)

INFO: [2021-10-05 21:07:28] /adeft/Py/adeft/adeft/modeling/classify.py - Beginning grid search in parameter space:
{'C': [100.0], 'max_features': [10000]}
INFO: [2021-10-05 21:07:41] /adeft/Py/adeft/adeft/modeling/classify.py - Best f1 score of 0.9102677425734932 found for parameter values:
{'logit__C': 100.0, 'tfidf__max_features': 10000}


In [45]:
# Display the statistics stored on the classifier.
classifier.stats

{'label_distribution': {'carotid_artery_stenting': 121,
  'HGNC:971': 32,
  'MESH:D000161': 14,
  'EFO:0009783': 3,
  'ungrounded': 85,
  'MESH:C015076': 18,
  'MESH:D003329': 42,
  'MESH:D001024': 4,
  'MESH:D025321': 3,
  'MESH:D002364': 21,
  'UP:P38605': 8,
  'CHEBI:CHEBI:474180': 25,
  'HGNC:2431': 9,
  'CHEBI:CHEBI:27860': 5,
  'MESH:D023921': 8,
  'MESH:D012464': 3,
  'MESH:D003251': 2,
  'MESH:D016893': 17,
  'DOID:DOID:0111275': 9,
  'MESH:D064807': 5,
  'MESH:D003324': 3,
  'MESH:D006394': 4},
 'f1': {'mean': 0.917081, 'std': 0.009011},
 'precision': {'mean': 0.933816, 'std': 0.026684},
 'recall': {'mean': 0.901935, 'std': 0.020463},
 'CHEBI:CHEBI:27860': {'f1': {'mean': 0.4, 'std': 0.489898},
  'pr': {'mean': 0.4, 'std': 0.489898},
  'rc': {'mean': 0.4, 'std': 0.489898}},
 'CHEBI:CHEBI:474180': {'f1': {'mean': 0.941818, 'std': 0.079169},
  'pr': {'mean': 0.96, 'std': 0.08},
  'rc': {'mean': 0.926667, 'std': 0.090431}},
 'DOID:DOID:0111275': {'f1': {'mean': 0.8, 'std': 0.4},


In [65]:
# Combine the classifier with the grounding dict and names into a disambiguator.
disamb = AdeftDisambiguator(classifier, grounding_dict, names)

In [66]:
# Dump the disambiguator under the model name; presumably the files are written
# inside results_path — confirm against AdeftDisambiguator.dump.
disamb.dump(model_name, results_path)

In [67]:
# Print the disambiguator's summary text.
print(disamb.info())

Disambiguation model for CAS

Produces the disambiguations:
	Acoustic Stimulation*	MESH:D000161
	Anticholinergic Syndrome	MESH:D064807
	Aortic Valve Stenosis	MESH:D001024
	BCAR1*	HGNC:971
	CSE1L*	HGNC:2431
	Carotid Stenosis	MESH:D016893
	Caseins*	MESH:D002364
	Constriction, Pathologic	MESH:D003251
	Coronary Artery Disease	MESH:D003324
	Coronary Stenosis	MESH:D023921
	Coronary Vasospasm*	MESH:D003329
	Cycloartenol synthase	UP:P38605
	Hemangiosarcoma	MESH:D006394
	Saliva, Artificial	MESH:D012464
	Surgery, Computer-Assisted	MESH:D025321
	carotid atherosclerosis	EFO:0009783
	carotid_artery_stenting	carotid_artery_stenting
	caspofungin*	CHEBI:CHEBI:474180
	castanospermine	CHEBI:CHEBI:27860
	chrome azurol S*	MESH:C015076
	speech-language disorder-1*	DOID:DOID:0111275

Class level metrics:
--------------------
Grounding                 	Count	F1     
   carotid_artery_stenting	121	0.94106
                Ungrounded	 85	0.77288
        Coronary Vasospasm*	 42	 0.9317
                     BCAR1

In [68]:
# Hand the disambiguator to adeft_indra's model_to_s3 helper (presumably an
# upload to S3, per its name — a side effect outside this notebook).
model_to_s3(disamb)

In [26]:
from adeft.disambiguate import load_disambiguator

In [27]:
# NOTE(review): this rebinds `disamb`, which until here held the disambiguator
# built in this notebook, to the model loaded under the name "BAL" — confirm
# this is intended.
disamb = load_disambiguator("BAL")

In [28]:
# Display the loaded disambiguator object.
disamb

<adeft.disambiguate.AdeftDisambiguator at 0x7f4f001b33a0>

In [29]:
# Use the variable itself: the output-history name `_28` exists only when the
# cell that displays `disamb` happened to run with execution count 28.
print(disamb.info())

Disambiguation model for BAL

Produces the disambiguations:
	Bronchoalveolar Lavage	MESH:D018893
	CEL*	HGNC:1848
	Liver, Artificial	MESH:D019164
	benzaldehyde lyase*	MESH:C059416
	betaine aldehyde*	CHEBI:CHEBI:15710
	dimercaprol*	CHEBI:CHEBI:64198

Class level metrics:
--------------------
Grounding             	Count	F1     
Bronchoalveolar Lavage	1259	 0.9929
                   CEL*	  36	    1.0
     Liver, Artificial	  18	0.83619
            Ungrounded	  17	   0.65
           dimercaprol*	   8	    0.4
    benzaldehyde lyase*	   3	    0.2
      betaine aldehyde*	   2	    0.2

Global Metrics:
-----------------
	F1 score:	0.90773
	Precision:	1.0
	Recall:		0.83293

* Positive labels
See Docstring for explanation



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
