In [1]:
import os
import json
import pickle
import random
from collections import defaultdict, Counter

from indra.literature.adeft_tools import universal_extract_text
from indra.databases.hgnc_client import get_hgnc_name, get_hgnc_id

from adeft.discover import AdeftMiner
from adeft.gui import ground_with_gui
from adeft.modeling.label import AdeftLabeler
from adeft.modeling.classify import AdeftClassifier
from adeft.disambiguate import AdeftDisambiguator, load_disambiguator

from indra_db_lite.api import get_entrez_pmids_for_hgnc
from indra_db_lite.api import get_entrez_pmids_for_uniprot
from indra_db_lite.api import get_plaintexts_for_text_ref_ids
from indra_db_lite.api import get_text_ref_ids_for_agent_text
from indra_db_lite.api import get_text_ref_ids_for_pmids


from adeft_indra.grounding import AdeftGrounder
from adeft_indra.s3 import model_to_s3
from adeft_indra.model_building.escape import escape_filename

In [2]:
def get_text_ref_ids_for_entity(ns, id_):
    if ns == 'HGNC':
        pmids = get_entrez_pmids_for_hgnc(id_)
    elif ns == 'UP':
        pmids = get_entrez_pmids_for_uniprot(id_)
    return list(get_text_ref_ids_for_pmids(pmids).values())

In [3]:
adeft_grounder = AdeftGrounder()

In [4]:
shortforms = ['CAR']
model_name = ':'.join(sorted(escape_filename(shortform) for shortform in shortforms))
results_path = os.path.abspath(os.path.join('../../', 'results', model_name))

In [5]:
miners = dict()
all_texts = {}
for shortform in shortforms:
    text_ref_ids = get_text_ref_ids_for_agent_text(shortform)
    content = get_plaintexts_for_text_ref_ids(text_ref_ids, contains=shortforms)
    text_dict = content.flatten()
    miners[shortform] = AdeftMiner(shortform)
    miners[shortform].process_texts(text_dict.values())
    all_texts.update(text_dict)

longform_dict = {}
for shortform in shortforms:
    longforms = miners[shortform].get_longforms()
    longforms = [(longform, count, score) for longform, count, score in longforms
                 if count*score > 2]
    longform_dict[shortform] = longforms
    
combined_longforms = Counter()
for longform_rows in longform_dict.values():
    combined_longforms.update({longform: count for longform, count, score
                               in longform_rows})
grounding_map = {}
names = {}
for longform in combined_longforms:
    groundings = adeft_grounder.ground(longform)
    if groundings:
        grounding = groundings[0]['grounding']
        grounding_map[longform] = grounding
        names[grounding] = groundings[0]['name']
longforms, counts = zip(*combined_longforms.most_common())
pos_labels = []

In [6]:
list(zip(longforms, counts))

[('chimeric antigen receptor', 1959),
 ('cortisol awakening response', 127),
 ('conditioned avoidance response', 69),
 ('cxcl12 abundant reticular', 35),
 ('carvacrol', 34),
 ('carrageenan', 31),
 ('contractile actomyosin ring', 22),
 ('cancer associated retinopathy', 21),
 ('carnosine', 19),
 ('central african republic', 18),
 ('carvedilol', 17),
 ('c reactive protein albumin ratio', 15),
 ('cumulative attack rate', 14),
 ('carbachol', 14),
 ('carsknkdc', 12),
 ('conditioned avoidance responding', 11),
 ('car 19 il 15', 11),
 ('crp to albumin ratio', 10),
 ('contractile actin ring', 9),
 ('carboxylic acid reductase', 8),
 ('carnitine', 7),
 ('circadian activity rhythm', 7),
 ('chimeric antigen receptor modified', 6),
 ('cholinergic anti inflammatory response', 6),
 ('car19 il 15', 6),
 ('carbazole', 5),
 ('carnosol', 5),
 ('cilia associated respiratory', 5),
 ('cardamonin', 5),
 ('cariprazine', 5),
 ('carotenoid', 5),
 ('cxadr', 5),
 ('cxcl 12 abundant reticular', 5),
 ('car t cells',

In [7]:
grounding_map, names, pos_labels = ground_with_gui(longforms, counts, 
                                                   grounding_map=grounding_map,
                                                   names=names, pos_labels=pos_labels, no_browser=True, port=8890)

In [8]:
result = [grounding_map, names, pos_labels]

In [9]:
result

[{'c reactive protein albumin ratio': 'ungrounded',
  'cab and regulated': 'ungrounded',
  'cancer associated retinopathy': 'MESH:D012164',
  'car 19 il 15': 'MESH:D000076962',
  'car t cells': 'MESH:D000076962',
  'car19 il 15': 'MESH:D000076962',
  'carbachol': 'CHEBI:CHEBI:3385',
  'carbapenem': 'CHEBI:CHEBI:46765',
  'carbazole': 'CHEBI:CHEBI:3391',
  'carboxylic acid reductase': 'ungrounded',
  'cardamonin': 'MESH:C436747',
  'cariprazine': 'CHEBI:CHEBI:90933',
  'carnitine': 'MESH:D002331',
  'carnosine': 'CHEBI:CHEBI:15727',
  'carnosol': 'CHEBI:CHEBI:3429',
  'carotenoid': 'CHEBI:CHEBI:23044',
  'carrageenan': 'CHEBI:CHEBI:3435',
  'carsknkdc': 'ungrounded',
  'carvacrol': 'CHEBI:CHEBI:3440',
  'carvedilol': 'CHEBI:CHEBI:3441',
  'central african republic': 'MESH:D002488',
  'cerebral autoregulation': 'MESH:D006706',
  'chimeric antigen receptor': 'MESH:D000076962',
  'chimeric antigen receptor modified': 'MESH:D000076962',
  'chimeric receptor antibody': 'MESH:D000076962',
  '

In [27]:
grounding_map, names, pos_labels = [{'c reactive protein albumin ratio': 'ungrounded',
  'cab and regulated': 'ungrounded',
  'cancer associated retinopathy': 'MESH:D012164',
  'car 19 il 15': 'MESH:D000076962',
  'car t cells': 'MESH:D000076962',
  'car19 il 15': 'MESH:D000076962',
  'carbachol': 'CHEBI:CHEBI:3385',
  'carbapenem': 'CHEBI:CHEBI:46765',
  'carbazole': 'CHEBI:CHEBI:3391',
  'carboxylic acid reductase': 'ungrounded',
  'cardamonin': 'MESH:C436747',
  'cariprazine': 'CHEBI:CHEBI:90933',
  'carnitine': 'MESH:D002331',
  'carnosine': 'CHEBI:CHEBI:15727',
  'carnosol': 'CHEBI:CHEBI:3429',
  'carotenoid': 'CHEBI:CHEBI:23044',
  'carrageenan': 'CHEBI:CHEBI:3435',
  'carsknkdc': 'ungrounded',
  'carvacrol': 'CHEBI:CHEBI:3440',
  'carvedilol': 'CHEBI:CHEBI:3441',
  'central african republic': 'MESH:D002488',
  'cerebral autoregulation': 'MESH:D006706',
  'chimeric antigen receptor': 'MESH:D000076962',
  'chimeric antigen receptor modified': 'MESH:D000076962',
  'chimeric receptor antibody': 'MESH:D000076962',
  'cholinergic anti inflammatory response': 'ungrounded',
  'cilia associated respiratory': 'ungrounded',
  'circadian activity rhythm': 'GO:GO:0048511',
  'coincidence to accidental ratio': 'ungrounded',
  'conditioned autoregression': 'ungrounded',
  'conditioned avoidance responding': 'ungrounded',
  'conditioned avoidance response': 'ungrounded',
  'constitutive androstane': 'CHEBI:CHEBI:35509',
  'contractile actin ring': 'GO:GO:0005826',
  'contractile actomyosin ring': 'GO:GO:0005826',
  'cortisol awakening response': 'ungrounded',
  'crp alb ratio': 'ungrounded',
  'crp albumin ratio': 'ungrounded',
  'crp to albumin ratio': 'ungrounded',
  'cumulative attack rate': 'ungrounded',
  'cxadr': 'HGNC:2559',
  'cxcl 12 abundant reticular': 'NCIT:C114786',
  'cxcl12 abundant reticular': 'NCIT:C114786',
  'cytoplasmic accumulation region': 'ungrounded'},
 {'MESH:D012164': 'Retinal Diseases',
  'MESH:D000076962': 'Receptors, Chimeric Antigen',
  'CHEBI:CHEBI:3385': 'carbachol',
  'CHEBI:CHEBI:46765': 'carbapenem',
  'CHEBI:CHEBI:3391': 'carbazole',
  'MESH:C436747': 'cardamonin',
  'CHEBI:CHEBI:90933': 'cariprazine',
  'MESH:D002331': 'Carnitine',
  'CHEBI:CHEBI:15727': 'carnosine',
  'CHEBI:CHEBI:3429': 'Carnosol',
  'CHEBI:CHEBI:23044': 'carotenoid',
  'CHEBI:CHEBI:3435': 'carrageenan',
  'CHEBI:CHEBI:3440': 'carvacrol',
  'CHEBI:CHEBI:3441': 'carvedilol',
  'MESH:D002488': 'Central African Republic',
  'MESH:D006706': 'Homeostasis',
  'GO:GO:0048511': 'rhythmic process',
  'CHEBI:CHEBI:35509': 'androstane',
  'GO:GO:0005826': 'actomyosin contractile ring',
  'HGNC:2559': 'CXADR',
  'NCIT:C114786': 'CXCL12-Abundant Reticular Cell'},
 ['CHEBI:CHEBI:3435',
  'GO:GO:0005826',
  'MESH:D000076962',
  'NCIT:C114786']]

In [28]:
excluded_longforms = []

In [29]:
grounding_dict = {shortform: {longform: grounding_map[longform] 
                              for longform, _, _ in longforms if longform in grounding_map
                              and longform not in excluded_longforms}
                  for shortform, longforms in longform_dict.items()}
result = [grounding_dict, names, pos_labels]

if not os.path.exists(results_path):
    os.mkdir(results_path)
with open(os.path.join(results_path, f'{model_name}_preliminary_grounding_info.json'), 'w') as f:
    json.dump(result, f)

In [30]:
additional_entities = {}

In [31]:
unambiguous_agent_texts = {}

In [32]:
labeler = AdeftLabeler(grounding_dict)
corpus = labeler.build_from_texts(
    (text, text_ref_id) for text_ref_id, text in all_texts.items()
)
agent_text_text_ref_id_map = defaultdict(list)
for text, label, id_ in corpus:
    agent_text_text_ref_id_map[label].append(id_)

entity_text_ref_id_map = {
    entity: set(
        get_text_ref_ids_for_entity(*entity.split(':', maxsplit=1))
    )
    for entity in additional_entities
}

In [33]:
intersection1 = []
for entity1, trids1 in entity_text_ref_id_map.items():
    for entity2, trids2 in entity_text_ref_id_map.items():
        intersection1.append((entity1, entity2, len(trids1 & trids2)))

In [34]:
intersection2 = []
for entity1, trids1 in agent_text_text_ref_id_map.items():
    for entity2, pmids2 in entity_text_ref_id_map.items():
        intersection2.append((entity1, entity2, len(set(trids1) & trids2)))

In [35]:
intersection1

[]

In [36]:
intersection2

[]

In [37]:
all_used_trids = set()
for entity, agent_texts in unambiguous_agent_texts.items():
    used_trids = set()
    for agent_text in agent_texts[1]:
        trids = set(get_text_ref_ids_for_agent_text(agent_text))
        new_trids = list(trids - all_texts.keys() - used_trids)
        content = get_plaintexts_for_text_ref_ids(new_trids, contains=agent_texts[1])
        text_dict = content.flatten()
        corpus.extend(
            [
                (text, entity, trid) for trid, text in text_dict.items() if len(text) >= 5
            ]
        )
        used_trids.update(new_trids)
    all_used_trids.update(used_trids)
        
for entity, trids in entity_text_ref_id_map.items():
    new_trids = list(set(trids) - all_texts.keys() - all_used_trids)
    _, contains = additional_entities[entity]
    content = get_plaintexts_for_text_ref_ids(new_trids, contains=contains)
    text_dict = content.flatten()
    corpus.extend(
        [
            (text, entity, trid) for trid, text in text_dict.items() if len(text) >= 5
        ]
    )

In [38]:
names.update({key: value[0] for key, value in additional_entities.items()})
names.update({key: value[0] for key, value in unambiguous_agent_texts.items()})
pos_labels = list(set(pos_labels) | additional_entities.keys() |
                  unambiguous_agent_texts.keys())

In [39]:
%%capture

classifier = AdeftClassifier(shortforms, pos_labels=pos_labels, random_state=1729)
param_grid = {'C': [100.0], 'max_features': [10000]}
texts, labels, pmids = zip(*corpus)
classifier.cv(texts, labels, param_grid, cv=5, n_jobs=5)

INFO: [2021-10-05 17:36:34] /adeft/Py/adeft/adeft/modeling/classify.py - Beginning grid search in parameter space:
{'C': [100.0], 'max_features': [10000]}
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))


INFO: [2021-10-05 17:37:45] /adeft/Py/adeft/adeft/modeling/classify.py - Best f1 score of 0.9790689704682312 found for parameter values:
{'logit__C': 100.0, 'tfidf__max_features': 10000}


In [40]:
classifier.stats

{'label_distribution': {'MESH:D000076962': 1274,
  'CHEBI:CHEBI:3429': 4,
  'ungrounded': 183,
  'CHEBI:CHEBI:3435': 27,
  'CHEBI:CHEBI:3391': 4,
  'MESH:D012164': 14,
  'GO:GO:0005826': 21,
  'NCIT:C114786': 29,
  'GO:GO:0048511': 3,
  'CHEBI:CHEBI:23044': 5,
  'CHEBI:CHEBI:3440': 20,
  'CHEBI:CHEBI:3385': 10,
  'MESH:D002488': 6,
  'CHEBI:CHEBI:15727': 17,
  'CHEBI:CHEBI:3441': 12,
  'CHEBI:CHEBI:46765': 2,
  'HGNC:2559': 2,
  'MESH:D002331': 6,
  'MESH:D006706': 2,
  'CHEBI:CHEBI:35509': 3,
  'CHEBI:CHEBI:90933': 2,
  'MESH:C436747': 4},
 'f1': {'mean': 0.979069, 'std': 0.006787},
 'precision': {'mean': 0.972289, 'std': 0.007443},
 'recall': {'mean': 0.985953, 'std': 0.006744},
 'CHEBI:CHEBI:15727': {'f1': {'mean': 0.396667, 'std': 0.143139},
  'pr': {'mean': 0.366667, 'std': 0.15456},
  'rc': {'mean': 0.516667, 'std': 0.280872}},
 'CHEBI:CHEBI:23044': {'f1': {'mean': 0.2, 'std': 0.4},
  'pr': {'mean': 0.2, 'std': 0.4},
  'rc': {'mean': 0.2, 'std': 0.4}},
 'CHEBI:CHEBI:3385': {'f1':

In [41]:
disamb = AdeftDisambiguator(classifier, grounding_dict, names)

In [42]:
disamb.dump(model_name, results_path)

In [43]:
print(disamb.info())

Disambiguation model for CAR

Produces the disambiguations:
	CXADR	HGNC:2559
	CXCL12-Abundant Reticular Cell*	NCIT:C114786
	Carnitine	MESH:D002331
	Carnosol	CHEBI:CHEBI:3429
	Central African Republic	MESH:D002488
	Homeostasis	MESH:D006706
	Receptors, Chimeric Antigen*	MESH:D000076962
	Retinal Diseases	MESH:D012164
	actomyosin contractile ring*	GO:GO:0005826
	androstane	CHEBI:CHEBI:35509
	carbachol	CHEBI:CHEBI:3385
	carbapenem	CHEBI:CHEBI:46765
	carbazole	CHEBI:CHEBI:3391
	cardamonin	MESH:C436747
	cariprazine	CHEBI:CHEBI:90933
	carnosine	CHEBI:CHEBI:15727
	carotenoid	CHEBI:CHEBI:23044
	carrageenan*	CHEBI:CHEBI:3435
	carvacrol	CHEBI:CHEBI:3440
	carvedilol	CHEBI:CHEBI:3441
	rhythmic process	GO:GO:0048511

Class level metrics:
--------------------
Grounding                     	Count	F1     
   Receptors, Chimeric Antigen*	1274	 0.9876
                    Ungrounded	 183	0.89526
CXCL12-Abundant Reticular Cell*	  29	0.86444
                   carrageenan*	  27	0.72308
   actomyosin contract

In [44]:
model_to_s3(disamb)

In [26]:
from adeft.disambiguate import load_disambiguator

In [27]:
disamb = load_disambiguator("BAL")

In [28]:
disamb

<adeft.disambiguate.AdeftDisambiguator at 0x7f4f001b33a0>

In [29]:
print(_28.info())

Disambiguation model for BAL

Produces the disambiguations:
	Bronchoalveolar Lavage	MESH:D018893
	CEL*	HGNC:1848
	Liver, Artificial	MESH:D019164
	benzaldehyde lyase*	MESH:C059416
	betaine aldehyde*	CHEBI:CHEBI:15710
	dimercaprol*	CHEBI:CHEBI:64198

Class level metrics:
--------------------
Grounding             	Count	F1     
Bronchoalveolar Lavage	1259	 0.9929
                   CEL*	  36	    1.0
     Liver, Artificial	  18	0.83619
            Ungrounded	  17	   0.65
           dimercaprol*	   8	    0.4
    benzaldehyde lyase*	   3	    0.2
      betaine aldehyde*	   2	    0.2

Global Metrics:
-----------------
	F1 score:	0.90773
	Precision:	1.0
	Recall:		0.83293

* Positive labels
See Docstring for explanation



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
