# Phenotype/SNP relation extraction from tables

In [38]:
# Reload edited project modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

import sys
import cPickle

# import snorkel and gwasdb
# (both are imported from local directories, which are appended to sys.path here)
sys.path.append('../snorkel')
sys.path.append('../src')
sys.path.append('../src/crawler')

# set up paths
# abstract_dir is passed as `path` to the XML document parser in the "Load corpus" section.
abstract_dir = '../data/db/papers'

# set up matplotlib
import numpy as np
import matplotlib
%matplotlib inline
# Default figure size used by every plot in this notebook.
matplotlib.rcParams['figure.figsize'] = (12,4)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Load corpus

In [20]:
# NOTE(review): XMLDocParser is imported but not used in this cell.
from snorkel.parser import XMLDocParser
from extractor.parser import UnicodeXMLTableDocParser

# Document parser over the files in abstract_dir.
# The XPath arguments select, per file: the document nodes ('./*'), the table
# elements used as text ('.//table'), and the PubMed id used as document id.
# Presumably keep_xml_tree=True retains the parsed XML for the table parser -- TODO confirm.
xml_parser = UnicodeXMLTableDocParser(
    path=abstract_dir,
    doc='./*',
    text='.//table',
    id='.//article-id[@pub-id-type="pmid"]/text()',
    keep_xml_tree=True)

In [21]:
from snorkel.parser import HTMLParser
from extractor.parser import UnicodeTableParser
from snorkel.parser import CorpusParser
import cPickle

table_parser = UnicodeTableParser()
html_parser = HTMLParser(path='../data/db/papers/')

corpus_name = 'gwas-table-corpus.pkl'

try:
    with open(corpus_name,"r") as pkl:
        corpus = cPickle.load(pkl)
except:
    cp = CorpusParser(xml_parser, table_parser, max_docs=15)
    %time corpus = cp.parse_corpus(name='GWAS Corpus')
    # pickling currently doesn't work...
#     with open(corpus_name,"w") as pkl:
#         corpus = cPickle.dump(corpus, pkl)

CPU times: user 56.3 s, sys: 3.34 s, total: 59.6 s
Wall time: 1min 35s


## Candidate extraction

### RSid Extraction

In [22]:
from snorkel.matchers import DictionaryMatch, RegexMatchSpan, Union
from snorkel.candidates import EntityExtractor
from snorkel.candidates import TableNgrams

from db.kb import KnowledgeBase

# Define a candidate space
# n_max=1: presumably candidates are single tokens taken from table cells -- TODO confirm
ngrams = TableNgrams(n_max=1)

# Get a list of all the RSids we know
kb = KnowledgeBase()
rs_ids = kb.get_rsid_candidates()

# Define matchers
# The union combines a dictionary matcher over the rsIDs returned by the knowledge
# base with a regex matcher for the pattern "rs" followed by digits.
dict_rsid_matcher = DictionaryMatch(d=rs_ids, longest_match_only=False)
regx_rsid_matcher = RegexMatchSpan(rgx=r'rs\d+')
rsid_matcher = Union(dict_rsid_matcher, regx_rsid_matcher)

rsid_extractor = EntityExtractor(ngrams, rsid_matcher)
# NOTE(review): the extraction below is commented out, so `rs_candidates` is never
# defined by this cell; later cells that read it fail on a fresh kernel.
# %time rs_candidates = rsid_extractor.extract(corpus.get_tables(), name='all')

In [23]:
# for cand in rs_candidates[:10]: 
#     print cand
# print "%s candidates extracted" % len(rs_candidates)
# print rs_candidates[0].context
# print rs_candidates[0].context.cell

NameError: name 'rs_candidates' is not defined

#### Statistics

In [None]:
# NOTE(review): gold_rsid_precision is imported but not used in this cell.
from extractor.util import gold_rsid_stats, gold_rsid_precision

# Gold set: (document name, rsID) pairs that the knowledge base returns for each
# document of the corpus; the document name is converted to int for the lookup.
gold_set = frozenset( [ (doc.name, rs_id) for doc in corpus.documents for rs_id in kb.rsids_by_pmid(int(doc.name)) ] )
# NOTE(review): gold_set_rsids is not used by any code visible in this notebook.
gold_set_rsids = [rs_id for doc_id, rs_id in gold_set]

# NOTE(review): `rs_candidates` is only assigned by a commented-out line in the
# RSid extraction cell, so this call raises NameError on a fresh kernel.
gold_rsid_stats(rs_candidates, gold_set)

Interesting: some SNPs seem never to be mentioned (e.g. rs12122100), while others (e.g. rs727153) appear only in the text.
Sometimes a SNP is not picked up for a different, strange reason: see rs13314993.

In [None]:
# Show the text of the cells that aligned_cells('row') returns for the first
# rsID candidate's cell.
# NOTE(review): depends on `rs_candidates`, whose assignment is commented out in
# the RSid extraction cell; `cells` is also rebound to a CellSpace in the
# Phenotypes section.
cells = rs_candidates[0].context.cell.aligned_cells('row')
[cell.text for cell in cells]

### Phenotypes

In [24]:
from snorkel.matchers import DictionaryMatch, Union, CellNameMatcher, CellDictNameMatcher
from snorkel.candidates import EntityExtractor
from snorkel.candidates import TableNgrams, CellSpace

# Define a candidate space
# NOTE(review): `ngrams` is rebound here (n_max=9) but only referenced by the
# commented-out matchers below; the extractor itself uses `cells`.
ngrams = TableNgrams(n_max=9)
cells = CellSpace()

# Create a list of possible words that could denote phenotypes
phen_words = ['trait', 'phenotype']

# Define matchers
# The commented-out lines are earlier matcher variants kept for reference.
# dict_row_matcher = DictionaryMatch(d=phen_words, longest_match_only=False, stemmer='porter')
# cell_row_matcher = CellNameMatcher(row_matcher=dict_row_matcher, cand_space=ngrams)
# dict_col_matcher = DictionaryMatch(d=phen_words, longest_match_only=False, stemmer='porter')
# cell_col_matcher = CellNameMatcher(col_matcher=dict_col_matcher, cand_space=ngrams)
# phen_matcher = Union(cell_row_matcher, cell_col_matcher)
# phen_matcher = CellNameMatcher(col_matcher=dict_col_matcher, cand_space=ngrams)
# Active matcher: presumably accepts a cell when its column header contains one
# of phen_words (case-insensitive) -- TODO confirm against CellDictNameMatcher.
phen_matcher = CellDictNameMatcher(axis='col', d=phen_words, n_max=3, ignore_case=True)

phen_extractor = EntityExtractor(cells, phen_matcher)
%time phen_candidates = phen_extractor.extract(corpus.get_tables(), name='all')

CPU times: user 22 s, sys: 295 ms, total: 22.3 s
Wall time: 22.8 s


In [25]:
# Report how many phenotype candidates were extracted and preview the first ten:
# document name, table and cell of each candidate, followed by the candidate itself.
print "%s candidates extracted" % len(phen_candidates)
for cand in phen_candidates[0:10]: 
    print cand.context.document.name, cand.context.table, cand.context.cell
    print unicode(cand)
#     print [span for span in cand.row_ngrams()]
#     print [span for span in cand.col_ngrams()]
#     print
print
# Details of the first candidate's context.
print phen_candidates[0].context
print phen_candidates[0].context.document.name, phen_candidates[0].context.table
print phen_candidates[0].context.cell

1851 candidates extracted
17903292 Table('17903292', 0) Cell('17903292', 0, 10, u'Serum Creatinine')
Span("Serum Creatinine", context=None, chars=[0,15], words=[0,1])
17903292 Table('17903292', 0) Cell('17903292', 0, 15, u'Change in serum creatinine')
Span("Change in serum creatinine", context=None, chars=[0,25], words=[0,3])
17903292 Table('17903292', 0) Cell('17903292', 0, 20, u'Glomerular Filtration Rate (GFR)')
Span("Glomerular Filtration Rate (GFR)", context=None, chars=[0,35], words=[0,5])
17903292 Table('17903292', 0) Cell('17903292', 0, 25, u'Chronic Kidney Disease')
Span("Chronic Kidney Disease", context=None, chars=[0,21], words=[0,2])
17903292 Table('17903292', 0) Cell('17903292', 0, 30, u'Cystatin C')
Span("Cystatin C", context=None, chars=[0,9], words=[0,1])
17903292 Table('17903292', 0) Cell('17903292', 0, 35, u'Uric acid')
Span("Uric acid", context=None, chars=[0,8], words=[0,1])
17903292 Table('17903292', 0) Cell('17903292', 0, 40, u'Urinary Albumin Excretion')
Span("Ur

### Relations

In [26]:
from snorkel.candidates import AlignedTableRelationExtractor
# Pairs rsID candidates (first extractor) with phenotype candidates (second
# extractor); axis='row' presumably restricts pairs to cells aligned on the same
# table row -- TODO confirm the meaning of `induced`.
relation_extractor = AlignedTableRelationExtractor(rsid_extractor, phen_extractor, axis='row', induced=True)
tables = corpus.get_tables()

# create smaller subsets for evaluation/debugging
# NOTE(review): the index 8, the document name '17903293' and the table index 2
# are hard-coded and depend on which documents the corpus contains.
easy_tables = [tables[8]]
# hard_tables = [t for t in tables if t.document.name=='17658951']
hard_doc = [d for d in corpus.documents if d.name == '17903293'][0]
hard_tables = [hard_doc.tables[2]]

In [118]:
# Extract relation candidates from all tables of the corpus (timed), then print
# how many were found and the first ten.
%time candidates = relation_extractor.extract(tables, name='all')
print "%s relations extracted, e.g." % len(candidates)
for cand in candidates[:10]: 
    print cand

1578 relations extracted, e.g.
SpanPair(Span("rs1158167", context=None, chars=[0,8], words=[0,0]), Span("CysC", context=None, chars=[0,3], words=[0,0]))
SpanPair(Span("rs1712790", context=None, chars=[0,8], words=[0,0]), Span("UAE", context=None, chars=[0,2], words=[0,0]))
SpanPair(Span("rs6977660", context=None, chars=[0,8], words=[0,0]), Span("TSH", context=None, chars=[0,2], words=[0,0]))
SpanPair(Span("rs9322817", context=None, chars=[0,8], words=[0,0]), Span("TSH", context=None, chars=[0,2], words=[0,0]))
SpanPair(Span("rs10499559", context=None, chars=[0,9], words=[0,0]), Span("TSH", context=None, chars=[0,2], words=[0,0]))
SpanPair(Span("rs9305354", context=None, chars=[0,8], words=[0,0]), Span("UAE", context=None, chars=[0,2], words=[0,0]))
SpanPair(Span("rs2145231", context=None, chars=[0,8], words=[0,0]), Span("CysC", context=None, chars=[0,3], words=[0,0]))
SpanPair(Span("rs723464", context=None, chars=[0,7], words=[0,0]), Span("UAE", context=None, chars=[0,2], words=[0,0]))

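Here, we remove nested candidates, i.e. pairs whose second span lies inside a different, larger span extracted from the same context.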

In [28]:
# load existing candidates into a dict
# key: string form of the span1 context; value: list of (char_start, char_end)
# intervals of every span1 extracted from that context
span_dict = dict()
for span_pair in candidates:
    span = span_pair.span1
    span_dict.setdefault(str(span.context), list()).append( (span.char_start, span.char_end) )

def nested(ivl1, ivl2):
    """Tell whether interval ivl1 lies inside a different interval ivl2.

    Both arguments are (start, end) pairs. Identical intervals are not
    considered nested; otherwise ivl1 is nested when
    ivl2[0] <= ivl1[0] <= ivl1[1] <= ivl2[1].
    """
    if ivl1 == ivl2:
        return False
    return ivl2[0] <= ivl1[0] <= ivl1[1] <= ivl2[1]

new_candidates = list()
for span_pair in candidates:
    span = span_pair.span1
    span_ivl = span.char_start, span.char_end
    span_name = str(span.context)
    if all([not nested(span_ivl, other_ivl) for other_ivl in span_dict[span_name]]):
        new_candidates.append(span_pair)
        
print len(candidates) - len(new_candidates), 'candidates dropped, now we have', len(new_candidates)
# phen_c = new_phen_c

0 candidates dropped, now we have 1578


## Learning the correctness of relations

### Creating a gold set

To create a gold set, we save all extracted relations to a tab-separated file. We annotate it manually and save the result to a second file. Each row of the annotated file identifies a pair of rsid and phenotype strings together with a label; if that file exists, we take these labels as ground truth.

In [31]:
# store relations to annotate
with open('rels.acroynms.unnanotated.tsv', 'w') as f:
    for span_pair in new_candidates:
        doc_id = span_pair.span0.context.document.name
        table_id = span_pair.span0.context.table.position
        row_num = span_pair.span0.context.cell.row_num
        str1 = span_pair.span0.get_span()
        str2 = span_pair.span1.get_span()
        try:
            f.write('%s\t%s\t%d\t%s\t%s\n' % (doc_id, table_id, row_num, str1, str2))
        except:
            continue

In [137]:
# load annotations
# Maps (doc_id, table_id, rs_id, phen) -> +1 for rows annotated with 1, -1 otherwise.
# Each non-empty line must hold six tab-separated fields; the third one (col_n)
# is read but not used.
annotations = dict()
with open('rels.acronyms.annotated.txt') as f:
    text = f.read()
    # splitlines() accepts \r, \n and \r\n line endings, so the file loads no
    # matter which editor or spreadsheet program saved it.
    for line in text.splitlines():
        line = line.strip()
        if not line:
            # skip blank lines (e.g. the one produced by a trailing line break)
            continue
        doc_id, table_id, col_n, rs_id, phen, res = line.split('\t')
        res = 1 if int(res) == 1 else -1
        annotations[(doc_id, table_id, rs_id, phen)] = res

### Classify correct relations

In [138]:
from snorkel.features import TableNgramPairFeaturizer

# Load a previously pickled featurizer if one is available; otherwise build a new
# one and fit it on all relation candidates.
# NOTE(review): this cell never writes the fitted featurizer to pkl_f, so the
# cache is only hit when the file was produced elsewhere.
pkl_f = 'acro_table_feats.pkl'
try:
    # NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
    with open(pkl_f, 'rb') as f:
        featurizer = cPickle.load(f)
except Exception:
    # `except Exception` still covers a missing or unreadable cache file but lets
    # KeyboardInterrupt/SystemExit propagate instead of triggering a refit.
    featurizer = TableNgramPairFeaturizer()
    featurizer.fit_transform(candidates)

Building feature index...
Extracting features...
0/22341
5000/22341
10000/22341
15000/22341
20000/22341


In [139]:
def spair2uid(span_pair):
    """Build the key under which a span pair is looked up in `annotations`.

    Returns the tuple (document name, table position as a string, text of
    span0, text of span1); document and table are taken from span0's context.
    """
    first, second = span_pair.span0, span_pair.span1
    first_context = first.context
    return (
        first_context.document.name,
        str(first_context.table.position),
        first.get_span(),
        second.get_span(),
    )

# Split into train and test set
training_candidates = []
gold_candidates     = []
gold_labels         = []
n_half = len(candidates)/2
for c in candidates[:n_half]:
    uid = spair2uid(c)
    if uid in annotations:
        gold_candidates.append(c)
        gold_labels.append(annotations[uid])
    else:
        training_candidates.append(c)
training_candidates.extend(candidates[n_half:])
gold_labels = np.array(gold_labels)
print "Training set size: %s" % len(training_candidates)
print "Gold set size: %s" % len(gold_candidates)
print "Positive labels in training set: %s" % len([c for c in training_candidates if annotations.get(spair2uid(c),0)==1])
print "Negative labels in training set: %s" % len([c for c in training_candidates if annotations.get(spair2uid(c),0)==-1])
print "Positive labels in gold set: %s" % len([c for c in gold_candidates if annotations[spair2uid(c)]==1])
print "Negative labels in gold set: %s" % len([c for c in gold_candidates if annotations[spair2uid(c)]==-1])

Training set size: 884
Gold set size: 694
Positive labels in training set: 789
Negative labels in training set: 0
Positive labels in gold set: 555
Negative labels in gold set: 139


In [180]:
# negative LFs
def LF_number(m):
    """Negative labeling function: vote -1 when span1 is mostly digits.

    Returns -1 if more than 60% of the characters of m.span1's text are
    digits, and 0 (abstain) otherwise. An empty span also abstains instead of
    raising ZeroDivisionError.
    """
    txt = m.span1.get_span()
    if not txt:
        # nothing to measure; abstain
        return 0
    frac_num = len([ch for ch in txt if ch.isdigit()]) / float(len(txt))
    return -1 if frac_num > 0.6 else 0

def LF_bad_phen_mentions(m):
    """Negative labeling function based on the cells column-aligned with span1.

    Collects the phrases of the cells returned by
    aligned_cells(axis='col', induced=True) for span1's cell (presumably the
    column headers -- TODO confirm). Abstains (0) when there are no phrases.
    Otherwise votes -1 unless some phrase of at most 25 characters has a word
    accepted by phen_matcher._f_span.
    """
    aligned = m.span1.context.cell.aligned_cells(axis='col', induced=True)
    phrases = [phrase for cell in aligned for phrase in cell.phrases]
    if not phrases:
        return 0
    # phrases in which at least one space-separated word is accepted by the matcher
    accepted = [
        phrase for phrase in phrases
        if any(phen_matcher._f_span(word) for word in phrase.text.split(' '))
    ]
    has_short_match = any(len(phrase.text) <= 25 for phrase in accepted)
    return 0 if has_short_match else -1

LF_tables_neg = [LF_number, LF_bad_phen_mentions]

# positive LFs
def LF_no_neg(m):
    """Positive labeling function: vote +1 only when no negative LF fires.

    Returns 0 (abstain) as soon as one function in LF_tables_neg returns a
    non-zero value for m, and +1 otherwise.
    """
    for negative_lf in LF_tables_neg:
        if negative_lf(m):
            return 0
    return +1

LF_tables_pos = [LF_no_neg]

# all labeling functions, negative ones first
LF_tables = LF_tables_neg + LF_tables_pos

In [None]:
from snorkel.snorkel import TrainingSet
# NOTE(review): NgramFeaturizer is imported but not used in this cell.
from snorkel.features import NgramFeaturizer

# Build the training set from the training candidates and all labeling functions.
# NOTE(review): a new TableNgramPairFeaturizer is created here; the `featurizer`
# object prepared in the earlier cell is not reused.
training_set = TrainingSet(training_candidates, LF_tables, featurizer=TableNgramPairFeaturizer())

Applying LFs...
Featurizing...
Building feature index...
Extracting features...
0/14220
5000/14220
10000/14220
LF Summary Statistics: 3 LFs applied to 884 candidates
------------------------------------------------------------
Coverage (candidates w/ > 0 labels):		100.00%
Overlap (candidates w/ > 1 labels):		0.00%
Conflict (candidates w/ conflicting labels):	0.00%


In [None]:
from snorkel.snorkel import Learner
import snorkel.learning
from snorkel.learning import LogReg

learner = Learner(training_set, model=LogReg())

# Splitting into CV and test set
# The first half of the gold candidates is held out for testing; the second half
# is passed to the grid search below.
# NOTE(review): n_half is rebound here (it held len(candidates)/2 before), and `/`
# only yields an int usable as a slice index under Python 2 integer division.
n_half = len(gold_candidates)/2
test_candidates = gold_candidates[:n_half]
test_labels     = gold_labels[:n_half]
cv_candidates   = gold_candidates[n_half:]
cv_labels       = gold_labels[n_half:]

from snorkel.learning_utils import GridSearch
# Grid over the parameters 'mu' (1e-5, 1e-7) and 'lf_w0' (1.0, 2.0).
gs       = GridSearch(learner, ['mu', 'lf_w0'], [[1e-5, 1e-7],[1.0,2.0]])
gs_stats = gs.fit(cv_candidates, cv_labels)

Testing mu = 1.00e-05, lf_w0 = 1.00e+00
Begin training for rate=0.01, mu=1e-05
	Learning epoch = 0	Gradient mag. = 0.027207
	Learning epoch = 250	Gradient mag. = 0.047458
	Learning epoch = 500	Gradient mag. = 0.076673
	Learning epoch = 750	Gradient mag. = 0.119269
Final gradient magnitude for rate=0.01, mu=1e-05: 0.180
Applying LFs...
Featurizing...
Testing mu = 1.00e-05, lf_w0 = 2.00e+00
Begin training for rate=0.01, mu=1e-05
	Learning epoch = 0	Gradient mag. = 0.052007
	Learning epoch = 250	Gradient mag. = 0.082656
	Learning epoch = 500	Gradient mag. = 0.117942
	Learning epoch = 750	Gradient mag. = 0.163135
Final gradient magnitude for rate=0.01, mu=1e-05: 0.210
Testing mu = 1.00e-07, lf_w0 = 1.00e+00
Begin training for rate=0.01, mu=1e-07
	Learning epoch = 0	Gradient mag. = 0.027207
	Learning epoch = 250	Gradient mag. = 0.047555
	Learning epoch = 500	Gradient mag. = 0.076892
	Learning epoch = 750	Gradient mag. = 0.119632
Final gradient magnitude for rate=0.01, mu=1e-07: 0.180
Testin

In [None]:
# Evaluate the learner on the held-out half of the gold set.
learner.test_wmv(test_candidates, test_labels)

In [None]:
# preds = learner.predict_wmv(candidates)
acronyms = [spair2uid(c) for (c, p) in zip(candidates, preds) if p == 1]
mislabeled_cand = [(c,p, annotations.get(spair2uid(c), None)) for c, p in zip(candidates, preds) if p != annotations.get(spair2uid(c), p)]
for (c,p,g) in mislabeled_cand[:20]:
    print c.span0.context.document.name, p, g
    print c.span0.context    
    print c.span0.get_span(), c.span1.get_span()
    txt = c.span1.get_span()
    top_cells = c.span1.context.cell.aligned_cells(axis='col', induced=True)
    top_phrases = [phrase for cell in top_cells for phrase in cell.phrases]
    print top_phrases
    matching_phrases = []
    for phrase in top_phrases:
        print [(word,phen_matcher._f_span(word)) for word in phrase.text.split(' ')]
        if any (phen_matcher._f_span(word) for word in phrase.text.split(' ')):
            matching_phrases.append(phrase)
            print phrase
#         print phrase, [phen_matcher._f_span(word) for word in phrase.text.split(' ')]
    print [LF(c) for LF in LF_tables]
    print

In [None]:
# Debugging aid: show what the matcher's _stem method returns for a header word
# with a trailing '*'.
print phen_matcher._stem('Phenotype*')

Save the results

In [None]:
preds = learner.predict_wmv(candidates)
rels = [(c.span0.context.document.name, c.span0.get_span(), c.span1.get_span()) for (c, p) in zip(candidates, preds) if p == 1]
print len(rels), 'relations extracted, e.g.:'
print rels[:10]

# store relations to annotate
with open('rels.acronyms.extracted.tsv', 'w') as f:
    for doc_id, str1, str2 in rels:
        try:
            out = u'{}\t{}\t{}\n'.format(doc_id, unicode(str1), str2)
            f.write(out.encode("UTF-8"))
        except:
            print 'Error in saving:', str1, str2

## Resolve acronyms based on ones extracted earlier

In [None]:
from extractor.dictionary import Dictionary, unravel

# Load the acronym definitions.
# NOTE(review): this reads 'acronyms.extracted.tsv', not the
# 'rels.acronyms.extracted.tsv' file written by the previous cell -- presumably it
# is produced by a separate acronym-extraction step; confirm.
D = Dictionary()
D.load('acronyms.extracted.tsv')
print len(D), 'definitions loaded'

Use dictionary to resolve acronyms

In [None]:
# Rebuild each relation with its phenotype string passed through unravel()
# together with the document id and the dictionary D.
new_rels = [ (doc_id, rs_id, unravel(doc_id, phen, D)) for doc_id, rs_id, phen in rels ]

In [None]:
# Print the first 1000 resolved relations for manual inspection.
for rel in new_rels[:1000]:
    print rel

In [None]:
# Debugging aid: show what the dictionary holds for document '17903292'.
print D.storage['17903292']
print D.evidence['17903292']

In [None]:
# Example: resolve the acronym u'CysC' for document '17903292'.
unravel('17903292', u'CysC', D)