In [1]:
import os
import json
import pickle
import random
from collections import defaultdict, Counter

from indra.literature.adeft_tools import universal_extract_text
from indra.databases.hgnc_client import get_hgnc_name, get_hgnc_id

from adeft.discover import AdeftMiner
from adeft.gui import ground_with_gui
from adeft.modeling.label import AdeftLabeler
from adeft.modeling.classify import AdeftClassifier
from adeft.disambiguate import AdeftDisambiguator, load_disambiguator

from indra_db_lite.api import get_entrez_pmids_for_hgnc
from indra_db_lite.api import get_entrez_pmids_for_uniprot
from indra_db_lite.api import get_plaintexts_for_text_ref_ids
from indra_db_lite.api import get_text_ref_ids_for_agent_text
from indra_db_lite.api import get_text_ref_ids_for_pmids


from adeft_indra.grounding import AdeftGrounder
from adeft_indra.s3 import model_to_s3
from adeft_indra.model_building.escape import escape_filename

In [2]:
def get_text_ref_ids_for_entity(ns, id_):
    """Return text ref ids for the papers linked to an entity.

    PMIDs are looked up with the namespace-specific helper
    (get_entrez_pmids_for_hgnc / get_entrez_pmids_for_uniprot) and then
    converted to text ref ids via get_text_ref_ids_for_pmids.

    Parameters
    ----------
    ns : str
        Namespace of the entity; only 'HGNC' and 'UP' are supported.
    id_ : str
        Identifier of the entity within ``ns`` (e.g. '3590' for HGNC:3590).

    Returns
    -------
    list
        The values of the mapping returned by get_text_ref_ids_for_pmids.

    Raises
    ------
    ValueError
        If ``ns`` is not a supported namespace (otherwise ``pmids`` would be
        unbound and the call would fail with an UnboundLocalError).
    """
    if ns == 'HGNC':
        pmids = get_entrez_pmids_for_hgnc(id_)
    elif ns == 'UP':
        pmids = get_entrez_pmids_for_uniprot(id_)
    else:
        raise ValueError(f"Unsupported namespace {ns!r}; expected 'HGNC' or 'UP'")
    return list(get_text_ref_ids_for_pmids(pmids).values())

In [6]:
# Grounder used below to propose groundings for each mined longform; each
# result is read as a dict with 'grounding' and 'name' keys.
adeft_grounder = AdeftGrounder()

In [7]:
# Shortforms to disambiguate. The model name is their escaped forms, sorted
# and joined with ':'; results go under ../../results/<model_name>.
shortforms = ['FAF', 'FAFs']
escaped_shortforms = sorted(map(escape_filename, shortforms))
model_name = ':'.join(escaped_shortforms)
results_path = os.path.abspath(
    os.path.join('../../', 'results', model_name)
)

In [8]:
# Mine candidate longforms for each shortform from the texts that mention it.
miners = dict()
all_texts = {}
for shortform in shortforms:
    text_ref_ids = get_text_ref_ids_for_agent_text(shortform)
    # presumably `contains` keeps only texts mentioning one of the shortforms
    content = get_plaintexts_for_text_ref_ids(text_ref_ids, contains=shortforms)
    text_dict = content.flatten()  # text_ref_id -> plaintext
    miners[shortform] = AdeftMiner(shortform)
    miners[shortform].process_texts(text_dict.values())
    all_texts.update(text_dict)

# Keep only longforms whose count * score exceeds this weight.
MIN_LONGFORM_WEIGHT = 2
longform_dict = {}
for shortform in shortforms:
    longform_dict[shortform] = [
        (longform, count, score)
        for longform, count, score in miners[shortform].get_longforms()
        if count * score > MIN_LONGFORM_WEIGHT
    ]

# Sum each longform's count across all shortforms.
combined_longforms = Counter()
for longform_rows in longform_dict.values():
    combined_longforms.update({longform: count for longform, count, _ in longform_rows})

# Provisional grounding: take the grounder's top hit for each longform.
grounding_map = {}
names = {}
for longform in combined_longforms:
    groundings = adeft_grounder.ground(longform)
    if groundings:
        top_hit = groundings[0]
        grounding_map[longform] = top_hit['grounding']
        names[top_hit['grounding']] = top_hit['name']

# Longforms and counts, most frequent first, for the grounding GUI. zip(*[])
# yields nothing to unpack, so handle the no-longforms case explicitly instead
# of failing with "not enough values to unpack".
most_common_longforms = combined_longforms.most_common()
if most_common_longforms:
    longforms, counts = zip(*most_common_longforms)
else:
    longforms, counts = (), ()
pos_labels = []

In [9]:
# Show each mined longform with its combined count, most frequent first.
list(zip(longforms, counts))

[('fundus autofluorescence', 61),
 ('fibroblast activating factor', 6),
 ('fantastic four', 4),
 ('familial amyloidosis of finnish type', 4),
 ('finnish type familial amyloidosis', 3)]

In [10]:
# Review and edit the provisional groundings in the adeft grounding GUI
# (no_browser=True, port=8890). It returns the curated grounding map,
# display names and positive labels.
grounding_map, names, pos_labels = ground_with_gui(
    longforms,
    counts,
    grounding_map=grounding_map,
    names=names,
    pos_labels=pos_labels,
    no_browser=True,
    port=8890,
)

In [11]:
# Bundle the GUI results so they can be displayed in the next cell.
result = [grounding_map, names, pos_labels]

In [12]:
# Display the curated grounding map, names and positive labels.
result

[{'familial amyloidosis of finnish type': 'MESH:D028226',
  'fantastic four': 'IP:IPR021410',
  'fibroblast activating factor': 'HGNC:3590',
  'finnish type familial amyloidosis': 'MESH:D028226',
  'fundus autofluorescence': 'NCIT:NCIT:C162465'},
 {'MESH:D028226': 'Amyloidosis, Familial',
  'IP:IPR021410': 'The fantastic four family',
  'HGNC:3590': 'FAP',
  'NCIT:NCIT:C162465': 'Fundus Autofluorescence Imaging'},
 []]

In [44]:
# Hardcoded copy of the curated groundings, so later cells do not depend on
# interactive GUI state. It matches the GUI output above, except that
# MESH:D028226 is marked as a positive label.
# NOTE(review): the 'NCIT:NCIT:' double prefix comes from the grounder output
# above; confirm it is the expected NCIT grounding format.
grounding_map = {
    'familial amyloidosis of finnish type': 'MESH:D028226',
    'fantastic four': 'IP:IPR021410',
    'fibroblast activating factor': 'HGNC:3590',
    'finnish type familial amyloidosis': 'MESH:D028226',
    'fundus autofluorescence': 'NCIT:NCIT:C162465',
}
names = {
    'MESH:D028226': 'Amyloidosis, Familial',
    'IP:IPR021410': 'The fantastic four family',
    'HGNC:3590': 'FAP',
    'NCIT:NCIT:C162465': 'Fundus Autofluorescence Imaging',
}
pos_labels = ['MESH:D028226']

In [45]:
# Longforms to leave out of grounding_dict; none are excluded for this model.
excluded_longforms = []

In [46]:
# Per-shortform map from each retained longform to its curated grounding.
# Longforms with no grounding, or listed in excluded_longforms, are skipped.
grounding_dict = {
    shortform: {
        longform: grounding_map[longform]
        for longform, _, _ in longform_rows
        if longform in grounding_map and longform not in excluded_longforms
    }
    for shortform, longform_rows in longform_dict.items()
}
result = [grounding_dict, names, pos_labels]

# makedirs also creates missing parents (e.g. ../../results) and tolerates an
# already-existing directory; plain os.mkdir fails if the parent is absent.
os.makedirs(results_path, exist_ok=True)
with open(os.path.join(results_path, f'{model_name}_preliminary_grounding_info.json'), 'w') as f:
    json.dump(result, f)

In [47]:
# Entities added as extra candidate senses of FAF, stored as
# curie -> [display name, filter terms]. The filter terms are passed as
# `contains` when fetching each entity's training texts further down.
# NOTE(review): FANCF's filter terms do not include 'FAF' itself; confirm
# that 'FAC' is intended here.
additional_entities = dict([
    ('HGNC:12632', ['USP9X', ['FAF', 'fat facets-like', 'USP9X']]),
    ('HGNC:3587', ['FANCF', ['FAC', 'fanconi anemia', 'FANCF', 'FA complementation']]),
    ('HGNC:3590', ['FAP', ['FAF', 'fibroblast activating']]),
])

In [48]:
# Entity -> [name, agent texts] for senses found via unambiguous agent texts;
# none are used for this model.
unambiguous_agent_texts = {}

In [49]:
# Label the mined texts using the longform groundings. The corpus is a list of
# (text, label, text_ref_id) rows; later cells append more rows to it.
labeler = AdeftLabeler(grounding_dict)
corpus = labeler.build_from_texts(
    (body, trid) for trid, body in all_texts.items()
)

# label -> text ref ids of the corpus rows that carry that label
agent_text_text_ref_id_map = defaultdict(list)
for _row_text, row_label, row_trid in corpus:
    agent_text_text_ref_id_map[row_label].append(row_trid)

# curie -> set of text ref ids linked to that additional entity. Sets are used
# so the overlaps can be computed below.
entity_text_ref_id_map = {
    curie: set(get_text_ref_ids_for_entity(*curie.split(':', maxsplit=1)))
    for curie in additional_entities
}

In [50]:
# Number of shared text ref ids for every ordered pair of additional
# entities, including each entity paired with itself.
intersection1 = []
entity_items = list(entity_text_ref_id_map.items())
for entity1, trids1 in entity_items:
    for entity2, trids2 in entity_items:
        overlap = trids1 & trids2
        intersection1.append((entity1, entity2, len(overlap)))

In [51]:
# Number of shared text ref ids between the corpus rows for each
# longform-derived label and the papers linked to each additional entity.
# Each entity's own id set is intersected here. Previously the loop bound it to
# `pmids2` but used `trids2`, a variable left over from the previous cell.
intersection2 = []
for label, label_trids in agent_text_text_ref_id_map.items():
    label_trid_set = set(label_trids)  # built once per label, not per pair
    for entity, entity_trids in entity_text_ref_id_map.items():
        intersection2.append((label, entity, len(label_trid_set & entity_trids)))

In [52]:
# Show the pairwise overlaps between the additional entities' papers.
intersection1

[('HGNC:12632', 'HGNC:12632', 229),
 ('HGNC:12632', 'HGNC:3587', 2),
 ('HGNC:12632', 'HGNC:3590', 6),
 ('HGNC:3587', 'HGNC:12632', 2),
 ('HGNC:3587', 'HGNC:3587', 74),
 ('HGNC:3587', 'HGNC:3590', 4),
 ('HGNC:3590', 'HGNC:12632', 6),
 ('HGNC:3590', 'HGNC:3587', 4),
 ('HGNC:3590', 'HGNC:3590', 153)]

In [53]:
# Show the overlaps between labelled corpus texts and each additional entity's papers.
intersection2

[('NCIT:NCIT:C162465', 'HGNC:12632', 0),
 ('NCIT:NCIT:C162465', 'HGNC:3587', 0),
 ('NCIT:NCIT:C162465', 'HGNC:3590', 0),
 ('HGNC:3590', 'HGNC:12632', 0),
 ('HGNC:3590', 'HGNC:3587', 0),
 ('HGNC:3590', 'HGNC:3590', 0),
 ('MESH:D028226', 'HGNC:12632', 0),
 ('MESH:D028226', 'HGNC:3587', 0),
 ('MESH:D028226', 'HGNC:3590', 0),
 ('IP:IPR021410', 'HGNC:12632', 0),
 ('IP:IPR021410', 'HGNC:3587', 0),
 ('IP:IPR021410', 'HGNC:3590', 0)]

In [54]:
# Extend the corpus with texts for senses that were not reached through
# the mined longforms. Texts already in all_texts are always skipped, and
# texts shorter than 5 characters are dropped.
all_used_trids = set()
for entity, agent_texts in unambiguous_agent_texts.items():
    search_texts = agent_texts[1]
    # Within one entity, each text ref id is fetched only once across its
    # agent texts.
    used_trids = set()
    for agent_text in search_texts:
        candidate_trids = set(get_text_ref_ids_for_agent_text(agent_text))
        new_trids = list(candidate_trids - all_texts.keys() - used_trids)
        text_dict = get_plaintexts_for_text_ref_ids(
            new_trids, contains=search_texts
        ).flatten()
        corpus.extend(
            (text, entity, trid)
            for trid, text in text_dict.items()
            if len(text) >= 5
        )
        used_trids.update(new_trids)
    all_used_trids.update(used_trids)

# Papers linked to each additional entity, filtered by that entity's terms.
# NOTE(review): all_used_trids is not updated in this loop, so a paper linked
# to several additional entities may be added once per entity label.
for entity, trids in entity_text_ref_id_map.items():
    new_trids = list(set(trids) - all_texts.keys() - all_used_trids)
    _, contains = additional_entities[entity]
    text_dict = get_plaintexts_for_text_ref_ids(
        new_trids, contains=contains
    ).flatten()
    corpus.extend(
        (text, entity, trid)
        for trid, text in text_dict.items()
        if len(text) >= 5
    )

In [55]:
# Register display names for the added entities and mark all of them as
# positive labels. The labels are sorted so their order does not depend on
# string-hash randomization between runs (iterating a set is unordered).
names.update({key: value[0] for key, value in additional_entities.items()})
names.update({key: value[0] for key, value in unambiguous_agent_texts.items()})
pos_labels = sorted(
    set(pos_labels) | additional_entities.keys() | unambiguous_agent_texts.keys()
)

In [56]:
%%capture

# Cross-validate the classifier on the labelled corpus, using a single-point
# parameter grid (logged as logit__C and tfidf__max_features).
classifier = AdeftClassifier(
    shortforms, pos_labels=pos_labels, random_state=1729
)
param_grid = dict(C=[100.0], max_features=[10000])
# Transpose the (text, label, text_ref_id) rows into parallel columns. Despite
# its name, `pmids` holds text ref ids.
texts, labels, pmids = zip(*corpus)
classifier.cv(texts, labels, param_grid, cv=5, n_jobs=5)

INFO: [2021-10-07 15:49:56] /adeft/Py/adeft/adeft/modeling/classify.py - Beginning grid search in parameter space:
{'C': [100.0], 'max_features': [10000]}
INFO: [2021-10-07 15:49:59] /adeft/Py/adeft/adeft/modeling/classify.py - Best f1 score of 1.0 found for parameter values:
{'logit__C': 100.0, 'tfidf__max_features': 10000}


In [57]:
# Cross-validation results: label distribution plus global and per-label
# F1 / precision / recall.
classifier.stats

{'label_distribution': {'NCIT:NCIT:C162465': 40,
  'HGNC:3590': 5,
  'MESH:D028226': 5,
  'IP:IPR021410': 3,
  'HGNC:12632': 89,
  'HGNC:3587': 41},
 'f1': {'mean': 1.0, 'std': 0.0},
 'precision': {'mean': 1.0, 'std': 0.0},
 'recall': {'mean': 1.0, 'std': 0.0},
 'HGNC:12632': {'f1': {'mean': 1.0, 'std': 0.0},
  'pr': {'mean': 1.0, 'std': 0.0},
  'rc': {'mean': 1.0, 'std': 0.0}},
 'HGNC:3587': {'f1': {'mean': 1.0, 'std': 0.0},
  'pr': {'mean': 1.0, 'std': 0.0},
  'rc': {'mean': 1.0, 'std': 0.0}},
 'HGNC:3590': {'f1': {'mean': 1.0, 'std': 0.0},
  'pr': {'mean': 1.0, 'std': 0.0},
  'rc': {'mean': 1.0, 'std': 0.0}},
 'IP:IPR021410': {'f1': {'mean': 0.6, 'std': 0.489898},
  'pr': {'mean': 0.6, 'std': 0.489898},
  'rc': {'mean': 0.6, 'std': 0.489898}},
 'MESH:D028226': {'f1': {'mean': 1.0, 'std': 0.0},
  'pr': {'mean': 1.0, 'std': 0.0},
  'rc': {'mean': 1.0, 'std': 0.0}},
 'NCIT:NCIT:C162465': {'f1': {'mean': 1.0, 'std': 0.0},
  'pr': {'mean': 1.0, 'std': 0.0},
  'rc': {'mean': 1.0, 'std': 0

In [58]:
# Combine the trained classifier, the grounding dictionary and the display
# names into the FAF disambiguation model.
disamb = AdeftDisambiguator(classifier, grounding_dict, names)

In [59]:
# Save the model under results_path using model_name (see
# AdeftDisambiguator.dump for the on-disk layout).
disamb.dump(model_name, results_path)

In [60]:
# Summary of the model: its groundings and cross-validation metrics.
print(disamb.info())

Disambiguation model for FAF, and FAFs

Produces the disambiguations:
	Amyloidosis, Familial*	MESH:D028226
	FANCF*	HGNC:3587
	FAP*	HGNC:3590
	Fundus Autofluorescence Imaging	NCIT:NCIT:C162465
	The fantastic four family	IP:IPR021410
	USP9X*	HGNC:12632

Class level metrics:
--------------------
Grounding                      	Count	F1     
                          USP9X*	89	    1.0
                          FANCF*	41	    1.0
Fundus Autofluorescence Imaging	40	    1.0
                            FAP*	 5	    1.0
          Amyloidosis, Familial*	 5	    1.0
      The fantastic four family	 3	    0.6

Global Metrics:
-----------------
	F1 score:	1.0
	Precision:	1.0
	Recall:		1.0

* Positive labels
See Docstring for explanation



In [61]:
# Publish the model via adeft_indra.s3.model_to_s3 (presumably an upload to
# shared S3 storage). NOTE(review): this has side effects outside the notebook;
# re-running the notebook repeats it.
model_to_s3(disamb)

In [26]:
from adeft.disambiguate import load_disambiguator

In [27]:
# Load the separately published "BAL" model.
# NOTE(review): this rebinds `disamb`, discarding the FAF model built above.
# Under Restart & Run All, any later use of `disamb` sees the BAL model;
# consider a distinct variable name.
disamb = load_disambiguator("BAL")

In [28]:
# Display the object currently bound to `disamb`.
disamb

<adeft.disambiguate.AdeftDisambiguator at 0x7f4f001b33a0>

In [29]:
# Refer to the model by name instead of the IPython output cache (_28). _28
# only holds this object when the cells run with the original session's
# execution counts.
print(disamb.info())

Disambiguation model for BAL

Produces the disambiguations:
	Bronchoalveolar Lavage	MESH:D018893
	CEL*	HGNC:1848
	Liver, Artificial	MESH:D019164
	benzaldehyde lyase*	MESH:C059416
	betaine aldehyde*	CHEBI:CHEBI:15710
	dimercaprol*	CHEBI:CHEBI:64198

Class level metrics:
--------------------
Grounding             	Count	F1     
Bronchoalveolar Lavage	1259	 0.9929
                   CEL*	  36	    1.0
     Liver, Artificial	  18	0.83619
            Ungrounded	  17	   0.65
           dimercaprol*	   8	    0.4
    benzaldehyde lyase*	   3	    0.2
      betaine aldehyde*	   2	    0.2

Global Metrics:
-----------------
	F1 score:	0.90773
	Precision:	1.0
	Recall:		0.83293

* Positive labels
See Docstring for explanation



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
