In [1]:
import os
import json
import pickle
import random
from collections import defaultdict, Counter

from indra.literature.adeft_tools import universal_extract_text
from indra.databases.hgnc_client import get_hgnc_name, get_hgnc_id

from adeft.discover import AdeftMiner
from adeft.gui import ground_with_gui
from adeft.modeling.label import AdeftLabeler
from adeft.modeling.classify import AdeftClassifier
from adeft.disambiguate import AdeftDisambiguator, load_disambiguator

from indra_db_lite.api import get_entrez_pmids_for_hgnc
from indra_db_lite.api import get_entrez_pmids_for_uniprot
from indra_db_lite.api import get_plaintexts_for_text_ref_ids
from indra_db_lite.api import get_text_ref_ids_for_agent_text
from indra_db_lite.api import get_text_ref_ids_for_pmids


from adeft_indra.grounding import AdeftGrounder
from adeft_indra.s3 import model_to_s3
from adeft_indra.model_building.escape import escape_filename

In [2]:
def get_text_ref_ids_for_entity(ns, id_):
    """Look up the PMIDs associated with an entity and map them to text_ref_ids.

    Parameters
    ----------
    ns : str
        Namespace of the entity. 'HGNC' is resolved through
        `get_entrez_pmids_for_hgnc` and 'UP' through
        `get_entrez_pmids_for_uniprot`.
    id_ : str
        Identifier of the entity within `ns`.

    Returns
    -------
    list
        The values of the mapping returned by `get_text_ref_ids_for_pmids`
        for the PMIDs found for the entity.

    Raises
    ------
    ValueError
        If `ns` is neither 'HGNC' nor 'UP'.
    """
    if ns == 'HGNC':
        pmids = get_entrez_pmids_for_hgnc(id_)
    elif ns == 'UP':
        pmids = get_entrez_pmids_for_uniprot(id_)
    else:
        # Without this branch `pmids` would be unbound below and the call
        # would die with an opaque UnboundLocalError.
        raise ValueError(
            f'Unsupported namespace {ns!r}; expected one of: HGNC, UP'
        )
    return list(get_text_ref_ids_for_pmids(pmids).values())

In [3]:
# Grounder whose `.ground(longform)` method is used further down to propose
# an initial grounding for each mined longform.
adeft_grounder = AdeftGrounder()

In [4]:
# Shortform(s) that this notebook builds a disambiguation model for.
shortforms = ['MST']
# Model name: the filename-escaped shortforms, sorted and joined with ':'.
model_name = ':'.join(sorted(map(escape_filename, shortforms)))
# Absolute path of results/<model_name>, resolved two directory levels above
# the current working directory.
results_path = os.path.abspath(os.path.join('../../', 'results', model_name))

In [5]:
# Step 1: for every shortform, fetch the plaintexts that mention it and feed
# them to a dedicated AdeftMiner. All fetched texts are pooled in `all_texts`,
# keyed as returned by `content.flatten()`.
miners = {}
all_texts = {}
for shortform in shortforms:
    text_ref_ids = get_text_ref_ids_for_agent_text(shortform)
    content = get_plaintexts_for_text_ref_ids(text_ref_ids, contains=shortforms)
    text_dict = content.flatten()
    miner = AdeftMiner(shortform)
    miner.process_texts(text_dict.values())
    miners[shortform] = miner
    all_texts.update(text_dict)

# Step 2: keep only the mined longforms whose count * score exceeds 2.
longform_dict = {}
for shortform in shortforms:
    longforms = [
        (longform, count, score)
        for longform, count, score in miners[shortform].get_longforms()
        if count * score > 2
    ]
    longform_dict[shortform] = longforms

# Step 3: pool the longform counts across all shortforms.
combined_longforms = Counter()
for longform_rows in longform_dict.values():
    combined_longforms.update(
        {longform: count for longform, count, score in longform_rows}
    )

# Step 4: propose a grounding for each longform, taking the first entry
# returned by the grounder; longforms with no result are left unmapped.
grounding_map = {}
names = {}
for longform in combined_longforms:
    groundings = adeft_grounder.ground(longform)
    if not groundings:
        continue
    top_hit = groundings[0]
    grounding = top_hit['grounding']
    grounding_map[longform] = grounding
    names[grounding] = top_hit['name']

# Parallel tuples of longforms and counts, most frequent first.
longforms, counts = zip(*combined_longforms.most_common())
pos_labels = []

In [6]:
list(zip(longforms, counts))

[('median survival time', 141),
 ('mean survival time', 58),
 ('microscale thermophoresis', 30),
 ('multisystemic therapy', 20),
 ('magnetic seizure therapy', 18),
 ('maximal strength training', 17),
 ('mercaptopyruvate sulfurtransferase', 15),
 ('minimum spanning tree', 14),
 ('microbial source tracking', 10),
 ('mean skin temperature', 10),
 ('mnemonic strategy training', 10),
 ('military sexual trauma', 9),
 ('motor sequence task', 6),
 ('multi scale test', 6),
 ('monosodium titanate', 6),
 ('masitinib', 6),
 ('mammalian ste20 like kinase', 6),
 ('medial superior temporal', 5),
 ('mitotic somal translocation', 5),
 ('mental stress test', 5),
 ('mammalian ste20 like', 5),
 ('mendez santiago teja', 4),
 ('molecular subtype', 4),
 ('music supported therapy', 4),
 ('montenegro s skin test', 4),
 ('multiple subpial transection', 3),
 ('muscle sparing thoracotomy', 3),
 ('mediastinal stromal tissue', 3),
 ('mitochondrial stress test', 3),
 ('montenegro skin test', 3)]

In [7]:
# Manual curation step: review the proposed groundings, names and positive
# labels in the Adeft GUI (served on port 8890, no browser window opened).
grounding_map, names, pos_labels = ground_with_gui(
    longforms,
    counts,
    grounding_map=grounding_map,
    names=names,
    pos_labels=pos_labels,
    no_browser=True,
    port=8890,
)

In [8]:
# Bundle the three curated objects so they can be displayed together.
result = [grounding_map, names, pos_labels]

In [9]:
result

[{'magnetic seizure therapy': 'OMIT:OMIT:0026527',
  'mammalian ste20 like': 'HGNC:11408',
  'mammalian ste20 like kinase': 'HGNC:11408',
  'masitinib': 'CHEBI:CHEBI:63450',
  'maximal strength training': 'MESH:D055070',
  'mean skin temperature': 'MESH:D012881',
  'mean survival time': 'NCIT:NCIT:C94477',
  'medial superior temporal': 'ungrounded',
  'median survival time': 'NCIT:NCIT:C94594',
  'mediastinal stromal tissue': 'ungrounded',
  'mendez santiago teja': 'ungrounded',
  'mental stress test': 'ungrounded',
  'mercaptopyruvate sulfurtransferase': 'HGNC:7223',
  'microbial source tracking': 'ungrounded',
  'microscale thermophoresis': 'ECO:ECO:0006261',
  'military sexual trauma': 'MESH:D000082002',
  'minimum spanning tree': 'ungrounded',
  'mitochondrial stress test': 'ungrounded',
  'mitotic somal translocation': 'GO:GO:0021802',
  'mnemonic strategy training': 'ungrounded',
  'molecular subtype': 'ungrounded',
  'monosodium titanate': 'PUBCHEM:73555913',
  'montenegro s ski

In [10]:
# Hard-coded curation results: [longform -> grounding, grounding -> name,
# positive labels]. The entries match the `result` value displayed above;
# presumably frozen here so the notebook can be re-run without repeating the
# interactive GUI step -- confirm.
grounding_map, names, pos_labels = [{'magnetic seizure therapy': 'OMIT:OMIT:0026527',
  'mammalian ste20 like': 'HGNC:11408',
  'mammalian ste20 like kinase': 'HGNC:11408',
  'masitinib': 'CHEBI:CHEBI:63450',
  'maximal strength training': 'MESH:D055070',
  'mean skin temperature': 'MESH:D012881',
  'mean survival time': 'NCIT:NCIT:C94477',
  'medial superior temporal': 'ungrounded',
  'median survival time': 'NCIT:NCIT:C94594',
  'mediastinal stromal tissue': 'ungrounded',
  'mendez santiago teja': 'ungrounded',
  'mental stress test': 'ungrounded',
  'mercaptopyruvate sulfurtransferase': 'HGNC:7223',
  'microbial source tracking': 'ungrounded',
  'microscale thermophoresis': 'ECO:ECO:0006261',
  'military sexual trauma': 'MESH:D000082002',
  'minimum spanning tree': 'ungrounded',
  'mitochondrial stress test': 'ungrounded',
  'mitotic somal translocation': 'GO:GO:0021802',
  'mnemonic strategy training': 'ungrounded',
  'molecular subtype': 'ungrounded',
  'monosodium titanate': 'PUBCHEM:73555913',
  'montenegro s skin test': 'ungrounded',
  'montenegro skin test': 'ungrounded',
  'motor sequence task': 'ungrounded',
  'multi scale test': 'ungrounded',
  'multiple subpial transection': 'ungrounded',
  'multisystemic therapy': 'multisystemic_therapy',
  'muscle sparing thoracotomy': 'MESH:D013908',
  'music supported therapy': 'ungrounded'},
 {'OMIT:OMIT:0026527': 'Magnetic Field Therapy',
  'HGNC:11408': 'STK4',
  'CHEBI:CHEBI:63450': 'masitinib',
  'MESH:D055070': 'Resistance Training',
  'MESH:D012881': 'Skin Temperature',
  'NCIT:NCIT:C94477': 'Mean Survival Time',
  'NCIT:NCIT:C94594': 'Median Survival Time',
  'HGNC:7223': 'MPST',
  'ECO:ECO:0006261': 'microscale thermophoresis evidence',
  'MESH:D000082002': 'Sexual Trauma',
  'GO:GO:0021802': 'somal translocation',
  'PUBCHEM:73555913': 'Titanate (Ti2(OH)O41-), sodium (1:1)',
  'multisystemic_therapy': 'multisystemic_therapy',
  'MESH:D013908': 'Thoracotomy'},
 ['HGNC:11408', 'HGNC:7223', 'MESH:D055070', 'OMIT:OMIT:0026527']]

In [11]:
# Longforms listed here are left out when `grounding_dict` is built; empty
# means every grounded longform is kept.
excluded_longforms = []

In [12]:
# Per-shortform mapping of longform -> grounding, restricted to longforms that
# have an entry in `grounding_map` and are not explicitly excluded.
grounding_dict = {
    shortform: {
        longform: grounding_map[longform]
        for longform, _, _ in longforms
        if longform in grounding_map and longform not in excluded_longforms
    }
    for shortform, longforms in longform_dict.items()
}
result = [grounding_dict, names, pos_labels]

# makedirs(exist_ok=True) replaces the exists()/mkdir() pair: it also creates
# a missing parent directory and does not fail if the directory appears
# between the check and the creation.
os.makedirs(results_path, exist_ok=True)
with open(os.path.join(results_path, f'{model_name}_preliminary_grounding_info.json'), 'w') as f:
    json.dump(result, f)

In [16]:
# Extra entities to add as labels. Each key is an 'NS:ID' grounding (split on
# the first ':' when fetching text_ref_ids); each value is
# [display name, list passed as the `contains` filter when fetching texts].
additional_entities = {
    'HGNC:29678': ['MSTO1', ['MST', 'MSTO1']],
    'HGNC:6849': ['MAP3K10', ['MST', 'MAP3K10']],
    'HGNC:7223': ['MPST', ['MST', 'MPST']],
}

In [17]:
# Entity -> [display name, list of agent texts]. Downstream cells read
# value[0] as the name and iterate value[1] as agent texts; empty here, so no
# extra texts are fetched this way.
unambiguous_agent_texts = {}

In [18]:
# Label the mined texts using the curated grounding dict. The labeler is fed
# (text, text_ref_id) pairs and the corpus rows are unpacked below as
# (text, label, text_ref_id).
labeler = AdeftLabeler(grounding_dict)
corpus = labeler.build_from_texts(
    (text, trid) for trid, text in all_texts.items()
)

# label -> list of text_ref_ids of the corpus rows carrying that label.
agent_text_text_ref_id_map = defaultdict(list)
for _text, label, trid in corpus:
    agent_text_text_ref_id_map[label].append(trid)

# entity ('NS:ID') -> set of text_ref_ids looked up for that entity.
entity_text_ref_id_map = {}
for entity in additional_entities:
    entity_text_ref_id_map[entity] = set(
        get_text_ref_ids_for_entity(*entity.split(':', maxsplit=1))
    )

In [19]:
# Pairwise overlap between the text_ref_id sets of the additional entities.
# Each row is (entity1, entity2, number of shared text_ref_ids); rows where
# entity1 == entity2 give the size of that entity's own set.
# Note: the loop variables remain bound in the kernel namespace afterwards.
intersection1 = []
for entity1, trids1 in entity_text_ref_id_map.items():
    for entity2, trids2 in entity_text_ref_id_map.items():
        intersection1.append((entity1, entity2, len(trids1 & trids2)))

In [20]:
# Overlap between the texts labelled from the mined corpus and the texts
# looked up for each additional entity. Each row is
# (corpus label, entity, number of shared text_ref_ids).
intersection2 = []
for entity1, trids1 in agent_text_text_ref_id_map.items():
    # The inner loop variable must be `trids2`, the name used in the body;
    # binding it under another name made the body read a stale `trids2` left
    # over from an earlier cell, so every row used the same set.
    for entity2, trids2 in entity_text_ref_id_map.items():
        intersection2.append((entity1, entity2, len(set(trids1) & trids2)))

In [21]:
intersection1

[('HGNC:29678', 'HGNC:29678', 36),
 ('HGNC:29678', 'HGNC:6849', 4),
 ('HGNC:29678', 'HGNC:7223', 4),
 ('HGNC:6849', 'HGNC:29678', 4),
 ('HGNC:6849', 'HGNC:6849', 42),
 ('HGNC:6849', 'HGNC:7223', 2),
 ('HGNC:7223', 'HGNC:29678', 4),
 ('HGNC:7223', 'HGNC:6849', 2),
 ('HGNC:7223', 'HGNC:7223', 33)]

In [22]:
intersection2

[('NCIT:NCIT:C94594', 'HGNC:29678', 0),
 ('NCIT:NCIT:C94594', 'HGNC:6849', 0),
 ('NCIT:NCIT:C94594', 'HGNC:7223', 0),
 ('ungrounded', 'HGNC:29678', 0),
 ('ungrounded', 'HGNC:6849', 0),
 ('ungrounded', 'HGNC:7223', 0),
 ('PUBCHEM:73555913', 'HGNC:29678', 0),
 ('PUBCHEM:73555913', 'HGNC:6849', 0),
 ('PUBCHEM:73555913', 'HGNC:7223', 0),
 ('OMIT:OMIT:0026527', 'HGNC:29678', 0),
 ('OMIT:OMIT:0026527', 'HGNC:6849', 0),
 ('OMIT:OMIT:0026527', 'HGNC:7223', 0),
 ('NCIT:NCIT:C94477', 'HGNC:29678', 0),
 ('NCIT:NCIT:C94477', 'HGNC:6849', 0),
 ('NCIT:NCIT:C94477', 'HGNC:7223', 0),
 ('HGNC:11408', 'HGNC:29678', 0),
 ('HGNC:11408', 'HGNC:6849', 0),
 ('HGNC:11408', 'HGNC:7223', 0),
 ('multisystemic_therapy', 'HGNC:29678', 0),
 ('multisystemic_therapy', 'HGNC:6849', 0),
 ('multisystemic_therapy', 'HGNC:7223', 0),
 ('MESH:D000082002', 'HGNC:29678', 0),
 ('MESH:D000082002', 'HGNC:6849', 0),
 ('MESH:D000082002', 'HGNC:7223', 0),
 ('MESH:D012881', 'HGNC:29678', 0),
 ('MESH:D012881', 'HGNC:6849', 0),
 ('MES

In [23]:
# Extend the corpus with texts for entities that have unambiguous agent
# texts. Texts already present in `all_texts`, or already fetched for the
# same entity, are skipped; texts shorter than 5 characters are dropped.
all_used_trids = set()
for entity, agent_texts in unambiguous_agent_texts.items():
    entity_agent_texts = agent_texts[1]
    used_trids = set()
    for agent_text in entity_agent_texts:
        trids = set(get_text_ref_ids_for_agent_text(agent_text))
        new_trids = list(trids - all_texts.keys() - used_trids)
        content = get_plaintexts_for_text_ref_ids(new_trids, contains=entity_agent_texts)
        text_dict = content.flatten()
        labeled_rows = [
            (text, entity, trid) for trid, text in text_dict.items() if len(text) >= 5
        ]
        corpus.extend(labeled_rows)
        used_trids.update(new_trids)
    all_used_trids |= used_trids

# Extend the corpus with texts looked up for the additional entities,
# skipping anything in `all_texts` or fetched by the loop above.
for entity, trids in entity_text_ref_id_map.items():
    new_trids = list(set(trids) - all_texts.keys() - all_used_trids)
    _, contains = additional_entities[entity]
    content = get_plaintexts_for_text_ref_ids(new_trids, contains=contains)
    text_dict = content.flatten()
    labeled_rows = [
        (text, entity, trid) for trid, text in text_dict.items() if len(text) >= 5
    ]
    corpus.extend(labeled_rows)

In [24]:
# Register display names for the extra entities (value[0] is the name).
names.update({key: value[0] for key, value in additional_entities.items()})
names.update({key: value[0] for key, value in unambiguous_agent_texts.items()})
# Every extra entity counts as a positive label. sorted() gives a stable order;
# list(set(...)) order depends on string hashing and can change between
# kernel restarts.
pos_labels = sorted(set(pos_labels) | additional_entities.keys() |
                    unambiguous_agent_texts.keys())

In [25]:
%%capture

# Single-point parameter grid, so the grid search evaluates one configuration.
param_grid = {'C': [100.0], 'max_features': [10000]}
# Corpus rows are (text, label, text_ref_id); despite its name, `pmids`
# receives the text_ref_ids.
texts, labels, pmids = zip(*corpus)
# Fixed random_state so repeated runs use the same seed.
classifier = AdeftClassifier(shortforms, pos_labels=pos_labels, random_state=1729)
# 5-fold cross-validation using 5 parallel jobs.
classifier.cv(texts, labels, param_grid, cv=5, n_jobs=5)

INFO: [2021-10-08 16:29:18] /adeft/Py/adeft/adeft/modeling/classify.py - Beginning grid search in parameter space:
{'C': [100.0], 'max_features': [10000]}
INFO: [2021-10-08 16:29:26] /adeft/Py/adeft/adeft/modeling/classify.py - Best f1 score of 0.8534589593413123 found for parameter values:
{'logit__C': 100.0, 'tfidf__max_features': 10000}


In [26]:
classifier.stats

{'label_distribution': {'NCIT:NCIT:C94594': 119,
  'ungrounded': 43,
  'PUBCHEM:73555913': 4,
  'OMIT:OMIT:0026527': 11,
  'NCIT:NCIT:C94477': 49,
  'HGNC:11408': 9,
  'multisystemic_therapy': 10,
  'MESH:D000082002': 6,
  'MESH:D012881': 3,
  'ECO:ECO:0006261': 23,
  'HGNC:7223': 18,
  'GO:GO:0021802': 3,
  'MESH:D055070': 6,
  'MESH:D013908': 3,
  'CHEBI:CHEBI:63450': 1,
  'HGNC:29678': 4,
  'HGNC:6849': 6},
 'f1': {'mean': 0.853459, 'std': 0.16059},
 'precision': {'mean': 0.922857, 'std': 0.111245},
 'recall': {'mean': 0.800505, 'std': 0.190599},
 'CHEBI:CHEBI:63450': {'f1': {'mean': 0.0, 'std': 0.0},
  'pr': {'mean': 0.0, 'std': 0.0},
  'rc': {'mean': 0.0, 'std': 0.0}},
 'ECO:ECO:0006261': {'f1': {'mean': 0.926984, 'std': 0.060734},
  'pr': {'mean': 0.91, 'std': 0.111355},
  'rc': {'mean': 0.96, 'std': 0.08}},
 'GO:GO:0021802': {'f1': {'mean': 0.4, 'std': 0.489898},
  'pr': {'mean': 0.4, 'std': 0.489898},
  'rc': {'mean': 0.4, 'std': 0.489898}},
 'HGNC:11408': {'f1': {'mean': 0.933

In [32]:
# Wrap the cross-validated classifier together with the per-shortform
# grounding dict and the grounding -> name map into a disambiguator.
disamb = AdeftDisambiguator(classifier, grounding_dict, names)

In [33]:
# NOTE(review): presumably writes the model files for `model_name` under
# `results_path` -- confirm against AdeftDisambiguator.dump.
disamb.dump(model_name, results_path)

In [34]:
# info() returns a multi-line text summary (disambiguations and metrics);
# print() is used so the line breaks render.
print(disamb.info())

Disambiguation model for MST

Produces the disambiguations:
	MAP3K10*	HGNC:6849
	MPST*	HGNC:7223
	MSTO1*	HGNC:29678
	Magnetic Field Therapy*	OMIT:OMIT:0026527
	Mean Survival Time	NCIT:NCIT:C94477
	Median Survival Time	NCIT:NCIT:C94594
	Resistance Training*	MESH:D055070
	STK4*	HGNC:11408
	Sexual Trauma	MESH:D000082002
	Skin Temperature	MESH:D012881
	Thoracotomy	MESH:D013908
	Titanate (Ti2(OH)O41-), sodium (1:1)	PUBCHEM:73555913
	masitinib	CHEBI:CHEBI:63450
	microscale thermophoresis evidence	ECO:ECO:0006261
	multisystemic_therapy	multisystemic_therapy
	somal translocation	GO:GO:0021802

Class level metrics:
--------------------
Grounding                           	Count	F1     
                Median Survival Time	119	0.79233
                  Mean Survival Time	 49	0.40412
                          Ungrounded	 43	0.80104
  microscale thermophoresis evidence	 23	0.92698
                                MPST*	 18	0.85143
              Magnetic Field Therapy*	 11	   0.76
               mul

In [35]:
# NOTE(review): presumably uploads the disambiguator to S3 (imported from
# adeft_indra.s3), i.e. a side effect outside this notebook -- confirm before
# re-running.
model_to_s3(disamb)

In [26]:
from adeft.disambiguate import load_disambiguator

In [27]:
# NOTE(review): this rebinds `disamb` to a stored model named "BAL", while
# the cells above build the model for 'MST'. The execution counts here
# (26-29) also precede those of the cells above (32-35), which looks like
# leftover scratch work -- confirm whether this cell belongs here.
disamb = load_disambiguator("BAL")

In [28]:
disamb

<adeft.disambiguate.AdeftDisambiguator at 0x7f4f001b33a0>

In [29]:
# Reference the variable directly rather than the IPython output-history
# name `_28`, which only exists when a cell happened to run with execution
# count 28 and is therefore undefined after Restart & Run All.
print(disamb.info())

Disambiguation model for BAL

Produces the disambiguations:
	Bronchoalveolar Lavage	MESH:D018893
	CEL*	HGNC:1848
	Liver, Artificial	MESH:D019164
	benzaldehyde lyase*	MESH:C059416
	betaine aldehyde*	CHEBI:CHEBI:15710
	dimercaprol*	CHEBI:CHEBI:64198

Class level metrics:
--------------------
Grounding             	Count	F1     
Bronchoalveolar Lavage	1259	 0.9929
                   CEL*	  36	    1.0
     Liver, Artificial	  18	0.83619
            Ungrounded	  17	   0.65
           dimercaprol*	   8	    0.4
    benzaldehyde lyase*	   3	    0.2
      betaine aldehyde*	   2	    0.2

Global Metrics:
-----------------
	F1 score:	0.90773
	Precision:	1.0
	Recall:		0.83293

* Positive labels
See Docstring for explanation



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [31]:
classifier.confusion_info

{'CHEBI:CHEBI:63450': {'CHEBI:CHEBI:63450': [0, 0, 0, 0, 0],
  'ECO:ECO:0006261': [0, 0, 0, 0, 0],
  'GO:GO:0021802': [0, 0, 0, 0, 0],
  'HGNC:11408': [0, 0, 0, 0, 0],
  'HGNC:29678': [0, 0, 0, 0, 0],
  'HGNC:6849': [0, 0, 0, 0, 0],
  'HGNC:7223': [0, 0, 0, 0, 0],
  'MESH:D000082002': [0, 0, 0, 0, 0],
  'MESH:D012881': [0, 0, 0, 0, 0],
  'MESH:D013908': [0, 0, 0, 0, 0],
  'MESH:D055070': [0, 0, 0, 0, 0],
  'NCIT:NCIT:C94477': [0, 0, 0, 0, 0],
  'NCIT:NCIT:C94594': [0, 0, 0, 0, 0],
  'OMIT:OMIT:0026527': [0, 0, 0, 0, 0],
  'PUBCHEM:73555913': [0, 0, 0, 0, 0],
  'multisystemic_therapy': [0, 0, 0, 0, 0],
  'ungrounded': [0, 0, 1, 0, 0]},
 'ECO:ECO:0006261': {'CHEBI:CHEBI:63450': [0, 0, 0, 0, 0],
  'ECO:ECO:0006261': [4, 5, 4, 3, 5],
  'GO:GO:0021802': [0, 0, 0, 0, 0],
  'HGNC:11408': [0, 0, 0, 0, 0],
  'HGNC:29678': [0, 0, 0, 0, 0],
  'HGNC:6849': [0, 0, 0, 0, 0],
  'HGNC:7223': [0, 0, 0, 0, 0],
  'MESH:D000082002': [0, 0, 0, 0, 0],
  'MESH:D012881': [0, 0, 0, 0, 0],
  'MESH:D013908': [0,