# Phenotype/SNP relation extraction from tables

In [1]:
%load_ext autoreload
%autoreload 2

import sys
import cPickle

# import snorkel and gwasdb
sys.path.append('../snorkel')
sys.path.append('../src')
sys.path.append('../src/crawler')

# set up paths
abstract_dir = '../data/db/papers'

# set up matplotlib
import numpy as np
import matplotlib
%matplotlib inline
matplotlib.rcParams['figure.figsize'] = (12,4)

## Load corpus

In [2]:
from snorkel.parser import XMLDocParser
from extractor.parser import UnicodeXMLTableDocParser

# XPath-driven parser over the files in abstract_dir: the text to parse comes
# from each document's <table> elements and the document id is its PubMed id.
# The XML tree is kept (keep_xml_tree=True).
xml_parser = UnicodeXMLTableDocParser(
    path=abstract_dir,
    doc='./*',
    id='.//article-id[@pub-id-type="pmid"]/text()',
    text='.//table',
    keep_xml_tree=True,
)

In [12]:
from snorkel.parser import HTMLParser
from extractor.parser import UnicodeTableParser
from snorkel.parser import CorpusParser
import cPickle

table_parser = UnicodeTableParser()
html_parser = HTMLParser(path='../data/db/papers/')

corpus_name = 'gwas-table-corpus.pkl'

try:
    with open(corpus_name,"r") as pkl:
        corpus = cPickle.load(pkl)
except:
    cp = CorpusParser(xml_parser, table_parser)
    %time corpus = cp.parse_corpus(name='GWAS Corpus')
    # pickling currently doesn't work...
#     with open(corpus_name,"w") as pkl:
#         corpus = cPickle.dump(corpus, pkl)

CPU times: user 15min 33s, sys: 48.9 s, total: 16min 22s
Wall time: 23min 42s


In [5]:
# corpus15 = corpus

## Candidate extraction

### RSid Extraction

In [13]:
from snorkel.matchers import DictionaryMatch, RegexMatchSpan, Union
from snorkel.candidates import EntityExtractor, TableNgrams

from db.kb import KnowledgeBase

# Every RSid known to the knowledge base.
kb = KnowledgeBase()
rs_ids = kb.get_rsid_candidates()

# An RSid mention is either a known RSid or any span matching 'rs<digits>'.
dict_rsid_matcher = DictionaryMatch(d=rs_ids, longest_match_only=False)
regx_rsid_matcher = RegexMatchSpan(rgx=r'rs\d+')
rsid_matcher = Union(dict_rsid_matcher, regx_rsid_matcher)

# Candidate space over the tables (n_max=1, presumably single tokens).
ngrams = TableNgrams(n_max=1)
rsid_extractor = EntityExtractor(ngrams, rsid_matcher)
# Extraction over all tables is slow; left disabled here:
# %time rs_candidates = rsid_extractor.extract(corpus.get_tables(), name='all')

In [8]:
# for cand in rs_candidates[:10]: 
#     print cand
# print "%s candidates extracted" % len(rs_candidates)
# print rs_candidates[0].context
# print rs_candidates[0].context.cell

#### Statistics

In [226]:
from extractor.util import gold_rsid_stats, gold_rsid_precision

gold_set = frozenset( [ (doc.name, rs_id) for doc in corpus.documents for rs_id in kb.rsids_by_pmid(int(doc.name)) ] )
gold_set_rsids = [rs_id for doc_id, rs_id in gold_set]

gold_rsid_stats(rs_candidates, gold_set)

NameError: name 'rs_candidates' is not defined

Interesting: some SNPs never seem to be mentioned (e.g. rs12122100), while others (e.g. rs727153) appear only in the text.
Others are missed for a different, still unexplained reason: see rs13314993.

In [None]:
# text of every cell aligned with the first RSid candidate along the row axis
cells = rs_candidates[0].context.cell.aligned_cells('row')
list(cell.text for cell in cells)

### Phenotypes

In [14]:
from snorkel.matchers import DictionaryMatch, Union, CellDictNameMatcher
from snorkel.candidates import EntityExtractor
from snorkel.candidates import TableNgrams, CellSpace

# Words that, in a column name, suggest the column holds phenotypes.
phen_words = ['trait', 'phenotype']

# Candidate spaces: table n-grams (n_max=9) and whole cells.
# NOTE(review): `ngrams` was only used by earlier DictionaryMatch +
# CellNameMatcher row/column experiments, which this single matcher replaced.
ngrams = TableNgrams(n_max=9)
cells = CellSpace()

# Presumably matches cells whose column name contains one of phen_words
# (case-insensitive, n-grams up to 3 tokens) -- confirm against snorkel.
phen_matcher = CellDictNameMatcher(axis='col', d=phen_words, n_max=3, ignore_case=True)
phen_extractor = EntityExtractor(cells, phen_matcher)
# %time phen_candidates = phen_extractor.extract(corpus.get_tables(), name='all')

In [15]:
from db.kb import KnowledgeBase
from snorkel.utils import slice_into_ngrams
from snorkel.matchers import FullCellDictMatcher

def make_ngrams(L, n_max=10, n_min=3, delim=' '):
    """Yield each phrase of L, each followed by the n-grams of its tokens.

    Each phrase is stripped, split on `delim`, and the tokens are handed to
    snorkel's slice_into_ngrams with the given n_min / n_max; everything that
    call returns is yielded right after the phrase itself. The results feed
    the phenotype dictionary, presumably so sub-phrases of knowledge-base
    names can match too.
    """
    for phrase in L:
        yield phrase
        tokens = phrase.strip().split(delim)
        ngram_iter = slice_into_ngrams(tokens, n_max=n_max, n_min=n_min, delim=delim)
        for ngram in ngram_iter:
            yield ngram

# collect phenotype list
kb = KnowledgeBase()
# efo phenotypes
efo_phenotype_list0 = kb.get_phenotype_candidates(source='efo', peek=True) # TODO: remove peaking
efo_phenotype_list = list(make_ngrams(efo_phenotype_list0))
# mesh diseases
mesh_phenotype_list0 = kb.get_phenotype_candidates(source='mesh')
mesh_phenotype_list = list(make_ngrams(mesh_phenotype_list0))
# mesh chemicals
chem_phenotype_list = kb.get_phenotype_candidates(source='chemical')

phenotype_names = efo_phenotype_list + mesh_phenotype_list + chem_phenotype_list
full_cell_matcher = FullCellDictMatcher(d=phenotype_names, ignore_case=True, stemmer='porter')

cells = CellSpace()
phen_extractor2 = EntityExtractor(cells, full_cell_matcher)
select_tables = [table for table in corpus.get_tables() if table.document.name == '19197348']
# %time phen_candidates = phen_extractor2.extract(select_tables, name='all')

In [9]:
for c in (phen_candidates):
    print c

Span("BMI", context=None, chars=[0,2], words=[0,0])
Span("Height", context=None, chars=[0,5], words=[0,0])
Span("Waist Circumference", context=None, chars=[0,18], words=[0,1])
Span("Weight", context=None, chars=[0,5], words=[0,0])
Span("Fasting plasma glucose", context=None, chars=[0,21], words=[0,2])
Span("Thyroid Stimulating Hormone", context=None, chars=[0,26], words=[0,2])
Span("BMI", context=None, chars=[0,2], words=[0,0])
Span("Height", context=None, chars=[0,5], words=[0,0])
Span("Fasting plasma glucose", context=None, chars=[0,21], words=[0,2])
Span("Thyroid Stimulating Hormone", context=None, chars=[0,26], words=[0,2])


In [10]:
print "%s candidates extracted" % len(phen_candidates)
for cand in phen_candidates[0:10]: 
    print cand.context.document.name, cand.context.table, cand.context.cell
    print unicode(cand)
#     print [span for span in cand.row_ngrams()]
#     print [span for span in cand.col_ngrams()]
#     print
print
print phen_candidates[0].context
print phen_candidates[0].context.document.name, phen_candidates[0].context.table
print phen_candidates[0].context.cell

10 candidates extracted
19197348 Table('19197348', 2) Cell('19197348', 2, 11, u'BMI')
Span("BMI", context=None, chars=[0,2], words=[0,0])
19197348 Table('19197348', 2) Cell('19197348', 2, 23, u'Height')
Span("Height", context=None, chars=[0,5], words=[0,0])
19197348 Table('19197348', 2) Cell('19197348', 2, 46, u'Waist Circumference')
Span("Waist Circumference", context=None, chars=[0,18], words=[0,1])
19197348 Table('19197348', 2) Cell('19197348', 2, 102, u'Weight')
Span("Weight", context=None, chars=[0,5], words=[0,0])
19197348 Table('19197348', 2) Cell('19197348', 2, 540, u'Fasting plasma glucose')
Span("Fasting plasma glucose", context=None, chars=[0,21], words=[0,2])
19197348 Table('19197348', 2) Cell('19197348', 2, 552, u'Thyroid Stimulating Hormone')
Span("Thyroid Stimulating Hormone", context=None, chars=[0,26], words=[0,2])
19197348 Table('19197348', 3) Cell('19197348', 3, 22, u'BMI')
Span("BMI", context=None, chars=[0,2], words=[0,0])
19197348 Table('19197348', 3) Cell('191973

### Relations

In [16]:
from snorkel.candidates import AlignedTableRelationExtractor, SpanningTableRelationExtractor

# RSid x phenotype relation extractors along the row axis: the first pairs RSids
# with the column-name-matched phenotypes (phen_extractor, induced alignment),
# the second with the dictionary-matched phenotype cells (phen_extractor2).
relation_extractor = AlignedTableRelationExtractor(rsid_extractor, phen_extractor, axis='row', induced=True)
relation_extractor2 = SpanningTableRelationExtractor(rsid_extractor, phen_extractor2, axis='row')

tables = corpus.get_tables()

# Small subsets for evaluation/debugging: one arbitrary table, and the table
# at index 2 of paper 17903293.
easy_tables = [tables[8]]
hard_doc = [d for d in corpus.documents if d.name == '17903293'][0]
hard_tables = [hard_doc.tables[2]]

In [12]:
# Debug run of the spanning extractor on one table of the debug paper.
# NOTE(review): this binds `candidates`, which the full extraction further
# down overwrites.
%time candidates = relation_extractor2.extract([select_tables[2]], name='all')

2 3 1 0
2 3 3 0
4 3 1 0
4 3 3 0
5 3 1 0
5 3 3 0
CPU times: user 10.6 s, sys: 405 ms, total: 11 s
Wall time: 10.7 s


In [13]:
for cand in candidates[:10]: 
    print cand

In [17]:
%time candidates = relation_extractor.extract(tables, name='all')
print "%s relations extracted, e.g." % len(candidates)
for cand in candidates[:10]: 
    print cand

CPU times: user 2h 46min 52s, sys: 11min 26s, total: 2h 58min 18s
Wall time: 3h 36min 15s
3561 relations extracted, e.g.
SpanPair(Span("rs1158167", context=None, chars=[0,8], words=[0,0]), Span("CysC", context=None, chars=[0,3], words=[0,0]))
SpanPair(Span("rs1712790", context=None, chars=[0,8], words=[0,0]), Span("UAE", context=None, chars=[0,2], words=[0,0]))
SpanPair(Span("rs6977660", context=None, chars=[0,8], words=[0,0]), Span("TSH", context=None, chars=[0,2], words=[0,0]))
SpanPair(Span("rs9322817", context=None, chars=[0,8], words=[0,0]), Span("TSH", context=None, chars=[0,2], words=[0,0]))
SpanPair(Span("rs10499559", context=None, chars=[0,9], words=[0,0]), Span("TSH", context=None, chars=[0,2], words=[0,0]))
SpanPair(Span("rs9305354", context=None, chars=[0,8], words=[0,0]), Span("UAE", context=None, chars=[0,2], words=[0,0]))
SpanPair(Span("rs2145231", context=None, chars=[0,8], words=[0,0]), Span("CysC", context=None, chars=[0,3], words=[0,0]))
SpanPair(Span("rs723464", con

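Here, we remove nested candidates: a candidate is dropped if its phenotype span lies inside another phenotype span from the same context.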

In [15]:
# Map each phenotype span's context (as a string) to the (char_start, char_end)
# intervals of every phenotype span (span1) extracted from that context.
span_dict = {}
for span_pair in candidates:
    phen_span = span_pair.span1
    span_dict.setdefault(str(phen_span.context), []).append((phen_span.char_start, phen_span.char_end))

def nested(ivl1, ivl2):
    """Return True if interval ivl1 lies within interval ivl2 and differs from it.

    Intervals are (start, end) pairs; boundaries may coincide, but two
    identical intervals are not considered nested.
    """
    if ivl1 == ivl2:
        return False
    return ivl2[0] <= ivl1[0] <= ivl1[1] <= ivl2[1]

new_candidates = list()
for span_pair in candidates:
    span = span_pair.span1
    span_ivl = span.char_start, span.char_end
    span_name = str(span.context)
    if all([not nested(span_ivl, other_ivl) for other_ivl in span_dict[span_name]]):
        new_candidates.append(span_pair)
        
print len(candidates) - len(new_candidates), 'candidates dropped, now we have', len(new_candidates)
# phen_c = new_phen_c

0 candidates dropped, now we have 1835


In [16]:
# candidates15 = candidates

## Learning the correctness of relations

### Creating a gold set

To create a gold set, we save all extracted relations into a tab-separated (TSV) file. We annotate it manually and save the result to a second file, which contains the rsid and phenotype strings of each relation together with a correct/incorrect label; if that file exists, we take these labels as gold truth.

In [17]:
# store relations to annotate
with open('rels.acroynms.unnanotated.tsv', 'w') as f:
    for span_pair in new_candidates:
        doc_id = span_pair.span0.context.document.name
        table_id = span_pair.span0.context.table.position
        row_num = span_pair.span0.context.cell.row_num
        str1 = span_pair.span0.get_span()
        str2 = span_pair.span1.get_span()
        try:
            f.write('%s\t%s\t%d\t%s\t%s\n' % (doc_id, table_id, row_num, str1, str2))
        except:
            continue

In [18]:
# Load the manual annotations: one tab-separated record per line with
# doc_id, table_id, a third (unused) field, rsid, phenotype and label.
# Label 1 maps to +1; any other value maps to -1.
annotations = dict()
with open('rels.acronyms.annotated.txt') as f:
    # The previous code split on '\r' only; splitlines() also handles '\n' and
    # '\r\n'. Blank lines (e.g. a trailing line break) are skipped instead of
    # crashing the 6-field unpacking below.
    for line in f.read().splitlines():
        if not line.strip():
            continue
        doc_id, table_id, col_n, rs_id, phen, res = line.strip().split('\t')
        res = 1 if int(res) == 1 else -1
        annotations[(doc_id, table_id, rs_id, phen)] = res

### Classify correct relations

In [19]:
from snorkel.features import TableNgramPairFeaturizer

pkl_f = 'acro_table_feats.pkl'
try:
    with open(pkl_f, 'rb') as f:
        featurizer = cPickle.load(f)
except:
    featurizer = TableNgramPairFeaturizer()
    featurizer.fit_transform(candidates)

Building feature index...
Extracting features...
0/29079
5000/29079
10000/29079
15000/29079
20000/29079
25000/29079


In [20]:
def spair2uid(span_pair):
    """Return a hashable key for a relation candidate.

    The key is (document name, table position as a string, span0 text,
    span1 text), the same shape as the keys of `annotations`.
    """
    first, second = span_pair.span0, span_pair.span1
    ctx = first.context
    return (ctx.document.name, str(ctx.table.position), first.get_span(), second.get_span())

# Split into train and test set
training_candidates = []
gold_candidates     = []
gold_labels         = []
n_half = len(candidates)/2
for c in candidates[:n_half]:
    uid = spair2uid(c)
    if uid in annotations:
        gold_candidates.append(c)
        gold_labels.append(annotations[uid])
    else:
        training_candidates.append(c)
training_candidates.extend(candidates[n_half:])
gold_labels = np.array(gold_labels)
print "Training set size: %s" % len(training_candidates)
print "Gold set size: %s" % len(gold_candidates)
print "Positive labels in training set: %s" % len([c for c in training_candidates if annotations.get(spair2uid(c),0)==1])
print "Negative labels in training set: %s" % len([c for c in training_candidates if annotations.get(spair2uid(c),0)==-1])
print "Positive labels in gold set: %s" % len([c for c in gold_candidates if annotations[spair2uid(c)]==1])
print "Negative labels in gold set: %s" % len([c for c in gold_candidates if annotations[spair2uid(c)]==-1])

Training set size: 1013
Gold set size: 822
Positive labels in training set: 661
Negative labels in training set: 0
Positive labels in gold set: 683
Negative labels in gold set: 139


In [21]:
bad_words = ['rs number', 'rs id', 'rsid']

# negative LFs
def LF_number(m):
    """Label -1 if the phenotype span (m.span1) looks like a number, else 0.

    The span counts as numeric if it is longer than 5 characters and more than
    40% of them are digits, or if more than 60% of its characters are digits.
    An empty span gets 0 (abstain) instead of raising ZeroDivisionError.
    """
    txt = m.span1.get_span()
    if not txt:
        return 0
    frac_num = sum(1 for ch in txt if ch.isdigit()) / float(len(txt))
    return -1 if (len(txt) > 5 and frac_num > 0.4) or frac_num > 0.6 else 0

def LF_bad_phen_mentions(m):
    """Label -1 unless a short column-aligned phrase contains a matching word.

    Collects the phrases of the cells aligned with the phenotype cell along the
    column axis (induced). Abstains (0) when there are no such phrases;
    otherwise returns 0 if some phrase of at most 25 characters has a word
    accepted by `phen_matcher._f_span` (presumably one of the phen_words),
    and -1 if none does.
    """
    col_cells = m.span1.context.cell.aligned_cells(axis='col', induced=True)
    col_phrases = [phrase for cell in col_cells for phrase in cell.phrases]
    if not col_phrases:
        return 0
    for phrase in col_phrases:
        if len(phrase.text) > 25:
            continue
        if any(phen_matcher._f_span(word) for word in phrase.text.split(' ')):
            return 0
    return -1

def LF_bad_word(m):
    """Label -1 if the phenotype span contains any phrase from `bad_words`, else 0.

    The test is a case-sensitive substring check.
    NOTE(review): this LF is not in LF_tables_neg, so it is never applied --
    confirm whether that is intentional.
    """
    txt = m.span1.get_span()
    for bad in bad_words:
        if bad in txt:
            return -1
    return 0

LF_tables_neg = [LF_number, LF_bad_phen_mentions]

# positive LFs
def LF_no_neg(m):
    """Label +1 if none of the negative LFs fires on m; abstain (0) otherwise."""
    for neg_lf in LF_tables_neg:
        if neg_lf(m):
            return 0
    return +1

LF_tables_pos = [LF_no_neg]

# every LF applied by the training set: negatives first, then positives
LF_tables = LF_tables_neg + LF_tables_pos

In [22]:
from snorkel.snorkel import TrainingSet
from snorkel.features import NgramFeaturizer

# Apply the LFs to the training candidates and featurize them with a freshly
# created table n-gram-pair featurizer.
fresh_featurizer = TableNgramPairFeaturizer()
training_set = TrainingSet(training_candidates, LF_tables, featurizer=fresh_featurizer)

because the backend has already been chosen;
matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
or matplotlib.backends is imported for the first time.



Applying LFs...
Featurizing...
Building feature index...
Extracting features...
0/19661
5000/19661
10000/19661
15000/19661
LF Summary Statistics: 3 LFs applied to 1013 candidates
------------------------------------------------------------
Coverage (candidates w/ > 0 labels):		100.00%
Overlap (candidates w/ > 1 labels):		0.00%
Conflict (candidates w/ conflicting labels):	0.00%


In [23]:
from snorkel.snorkel import Learner
import snorkel.learning
from snorkel.learning import LogReg
from snorkel.learning_utils import GridSearch

learner = Learner(training_set, model=LogReg())

# Split the gold set in half: the first half for testing, the second half for
# selecting hyper-parameters.
n_half = len(gold_candidates) // 2
test_candidates, cv_candidates = gold_candidates[:n_half], gold_candidates[n_half:]
test_labels, cv_labels = gold_labels[:n_half], gold_labels[n_half:]

# Grid over mu in {1e-5, 1e-7} and lf_w0 in {1.0, 2.0}, fit on the CV half.
gs = GridSearch(learner, ['mu', 'lf_w0'], [[1e-5, 1e-7], [1.0, 2.0]])
gs_stats = gs.fit(cv_candidates, cv_labels)

Testing mu = 1.00e-05, lf_w0 = 1.00e+00
Begin training for rate=0.01, mu=1e-05
	Learning epoch = 0	Gradient mag. = 0.027386
	Learning epoch = 250	Gradient mag. = 0.046833
	Learning epoch = 500	Gradient mag. = 0.074606
	Learning epoch = 750	Gradient mag. = 0.114826
Final gradient magnitude for rate=0.01, mu=1e-05: 0.172
Applying LFs...
Featurizing...
Testing mu = 1.00e-05, lf_w0 = 2.00e+00
Begin training for rate=0.01, mu=1e-05
	Learning epoch = 0	Gradient mag. = 0.053098
	Learning epoch = 250	Gradient mag. = 0.084777
	Learning epoch = 500	Gradient mag. = 0.121184
	Learning epoch = 750	Gradient mag. = 0.168109
Final gradient magnitude for rate=0.01, mu=1e-05: 0.220
Testing mu = 1.00e-07, lf_w0 = 1.00e+00
Begin training for rate=0.01, mu=1e-07
	Learning epoch = 0	Gradient mag. = 0.027386
	Learning epoch = 250	Gradient mag. = 0.046932
	Learning epoch = 500	Gradient mag. = 0.074838
	Learning epoch = 750	Gradient mag. = 0.115220
Final gradient magnitude for rate=0.01, mu=1e-07: 0.172
Testin

In [24]:
# Evaluate the learner on the held-out test half of the gold set.
learner.test_wmv(test_candidates, test_labels)

Applying LFs...
Featurizing...
Test set size:	411
----------------------------------------
Precision:	1.0
Recall:		1.0
F1 Score:	1.0
----------------------------------------
TP: 310 | FP: 0 | TN: 101 | FN: 0


In [25]:
# preds = learner.predict_wmv(candidates)
acronyms = [spair2uid(c) for (c, p) in zip(candidates, preds) if p == 1]
mislabeled_cand = [(c,p, annotations.get(spair2uid(c), None)) for c, p in zip(candidates, preds) if p != annotations.get(spair2uid(c), p)]
for (c,p,g) in mislabeled_cand[:20]:
    print c.span0.context.document.name, p, g
    print c.span0.context    
    print c.span0.get_span(), c.span1.get_span()
    txt = c.span1.get_span()
    top_cells = c.span1.context.cell.aligned_cells(axis='col', induced=True)
    top_phrases = [phrase for cell in top_cells for phrase in cell.phrases]
    print top_phrases
    matching_phrases = []
    for phrase in top_phrases:
        print [(word,phen_matcher._f_span(word)) for word in phrase.text.split(' ')]
        if any (phen_matcher._f_span(word) for word in phrase.text.split(' ')):
            matching_phrases.append(phrase)
            print phrase
#         print phrase, [phen_matcher._f_span(word) for word in phrase.text.split(' ')]
    print [LF(c) for LF in LF_tables]
    print

NameError: name 'preds' is not defined

Save the results

In [None]:
preds = learner.predict_wmv(candidates)
rels = [(c.span0.context.document.name, c.span0.get_span(), c.span1.get_span()) for (c, p) in zip(candidates, preds) if p == 1]
print len(rels), 'relations extracted, e.g.:'
print rels[:10]

# store relations to annotate
with open('rels.acronyms.extracted.tsv', 'w') as f:
    for doc_id, str1, str2 in rels:
        try:
            out = u'{}\t{}\t{}\n'.format(doc_id, unicode(str1), str2)
            f.write(out.encode("UTF-8"))
        except:
            print 'Error in saving:', str1, str2

## Resolve acronyms using definitions extracted earlier

In [None]:
from extractor.dictionary import Dictionary, unravel

D = Dictionary()
D.load('acronyms.extracted.tsv')
print len(D), 'definitions loaded'

Use the dictionary to resolve acronyms in the extracted phenotypes

In [None]:
# resolve acronyms in each extracted phenotype, keeping the doc id and rsid
new_rels = [(doc_id, rs_id, unravel(doc_id, phen, D))
            for (doc_id, rs_id, phen) in rels]

## Evaluate extracted relations

Let's first evaluate the recall w.r.t. GWAS Central

In [None]:
for doc in corpus.documents:
    assocs = [assoc for assoc in kb.assoc_by_pmid(doc.name) if assoc.source == 'gwas_central' and assoc.pvalue < 1e-5]
    print doc.name, len(assocs), len([(pmid, rsid, phen) for pmid, rsid, phen in new_rels if pmid == doc.name])
    

In [None]:
print ([(pmid, rsid, phen) for pmid, rsid, phen in new_rels if pmid == '17903305'])

In [None]:
pmids = sorted(list({pmid for pmid, _, _ in new_rels}))

from db.kb import KnowledgeBase
kb = KnowledgeBase()
assocs = [assoc for pmid in pmids for assoc in kb.assoc_by_pmid(pmid) if assoc.source == 'gwas_central' and assoc.pvalue < 1e-5]
print len(pmids), len(assocs)

In [None]:
# the papers we extracted relations from
print pmids

In [None]:
# our resolved phenotypes per (pmid, rsid)
rel_dict = {}
for (pmid, rsid, phen) in new_rels:
    rel_dict.setdefault((pmid, rsid), set()).add(phen)

# GWAS Central phenotypes per (pubmed_id, rs_id)
gold_rel_dict = {}
for a in assocs:
    gold_rel_dict.setdefault((a.paper.pubmed_id, a.snp.rs_id), set()).add(a.phenotype.name)

First, evaluate recall: how many associations in GWAS Central can we recover?

In [None]:
for a in assocs[:500]:
    s1 = gold_rel_dict[(a.paper.pubmed_id, a.snp.rs_id)]
    s2 = rel_dict.get((str(a.paper.pubmed_id), a.snp.rs_id), {})
    if len(s1) != 1 or len(s2) != 1:
        print a.paper.pubmed_id, a.snp.rs_id, a.source
        print 'GWC:', gold_rel_dict[(a.paper.pubmed_id, a.snp.rs_id)]
        print 'US: ', rel_dict.get((str(a.paper.pubmed_id), a.snp.rs_id), None)
        print

Second question: can we extract any SNPs beyond those already in GWAS Central?

In [None]:
pmids = sorted(list({pmid for pmid, _, _ in new_rels if int(pmid) < 17903297}))

from db.kb import KnowledgeBase
kb = KnowledgeBase()
assocs = [assoc for pmid in pmids for assoc in kb.assoc_by_pmid(pmid) if assoc.source == 'gwas_central']
print len(assocs)

In [None]:
for a in assocs:
    s1 = gold_rel_dict[(a.paper.pubmed_id, a.snp.rs_id)]
    s2 = rel_dict.get((str(a.paper.pubmed_id), a.snp.rs_id), {})
    print a.paper.pubmed_id, a.snp.rs_id, a.source
    print 'GWC:', gold_rel_dict[(a.paper.pubmed_id, a.snp.rs_id)]
    print 'US: ', rel_dict.get((str(a.paper.pubmed_id), a.snp.rs_id), None)
    print

## Combine with extracted pvalue/rsid relations

In [None]:
# Read extracted p-values, one tab-separated record per line:
# pmid, rsid, table_id, row_id, col_id, pval. Build
#   pval_rsid_dict[pmid][(rsid, table_id, row_id)] -> smallest p-value in that table row
#   pval_dict[pmid][rsid] -> smallest p-value for that SNP anywhere in the paper
pval_rsid_sets = dict()
pval_sets = dict()
with open('pval-rsid.raw.tsv') as f:
    for line in f:
        pmid, rsid, table_id, row_id, col_id, pval = line.strip().split('\t')
        pval = float(pval)
        table_id, row_id, col_id = int(table_id), int(row_id), int(col_id)
        pval_rsid_sets.setdefault(pmid, {}).setdefault((rsid, table_id, row_id), set()).add(pval)
        pval_sets.setdefault(pmid, {}).setdefault(rsid, set()).add(pval)

pval_dict0 = {pmid: {rsid: min(vals) for rsid, vals in by_rsid.items()}
              for pmid, by_rsid in pval_sets.items()}
pval_rsid_dict0 = {pmid: {key: min(vals) for key, vals in by_key.items()}
                   for pmid, by_key in pval_rsid_sets.items()}
pval_dict = pval_dict0
pval_rsid_dict = pval_rsid_dict0

Plan: if a phenotype/rsid pair has been extracted from a table, take its p-value from `pval_rsid_dict`.

If not, we assume that the paper has only one phenotype and take the smallest p-value reported in the paper.

For now, our goal is just to keep the phenotype/rsid relations with p-value ≤ 1e-5.

#### Save all relations with sufficiently small p-values

In [None]:
# preds = learner.predict_wmv(candidates)
# predicted_candidates = [c for (c, p) in zip(candidates, preds) if p == 1]

import re
import unicodedata
def _normalize_str(s):
    try:
        s = s.encode('utf-8')
        return s
    except UnicodeEncodeError: 
        pass
    try:
        s = s.decode('utf-8')
        return s
    except UnicodeDecodeError: 
        pass    
    raise Exception()

# Relations the learner predicts correct. `preds` is computed in the
# "Save the results" cell; this list was previously only built by a
# commented-out line, so the cell failed on a fresh kernel.
predicted_candidates = [c for (c, p) in zip(candidates, preds) if p == 1]

with open('phen-rsid.table.rel.tsv', 'w') as f:
    for c in predicted_candidates:
        pmid = c.span0.context.document.name
        rsid = c.span0.get_span()
        table_id = c.span0.context.table.position
        row_num = c.span0.context.cell.row_num
        col_num = c.span0.context.cell.col_num  # column of the rsid cell

        # resolve acronyms in the phenotype, then write it as UTF-8 bytes
        phen = unravel(pmid, c.span1.get_span(), D)
        if isinstance(phen, unicode):
            phen = phen.encode('utf-8')

        # Smallest p-value for this rsid in the same table row. Papers without
        # any extracted p-value used to raise KeyError here.
        # NOTE(review): a missing p-value is -1, which passes the filter below,
        # so relations without a p-value are kept (written with pval -1).
        pval = pval_rsid_dict.get(pmid, {}).get((rsid, table_id, row_num), -1)
        if pval > 1e-5:
            continue

        out_str = '{pmid}\t{rsid}\t{phen}\t{pval}\ttable\t{table_id}\t{row}\t{col}\n'.format(
                    pmid=pmid, rsid=rsid, phen=phen, pval=pval, table_id=table_id, row=row_num, col=col_num)
        f.write(out_str)

In [None]:
print [(c, c.span0.context.cell.row_num, unravel(c.span0.context.document.name, c.span1.get_span(), D)) for c in candidates if c.span0.get_span() == 'rs10500631']

In [None]:
# Debug: smallest p-value for rs10500631 in table 1, row 5 of paper 17903294 (-1 if absent).
pval_rsid_dict['17903294'].get(('rs10500631', 1, 5), -1)

In [None]:
for x in pval_rsid_dict['17903294']:    
    print x, pval_rsid_dict['17903294'][x]