# Snorkel corpus preprocess
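This notebook takes Snorkel-compatible input files (corpus, entity tags, gold relation labels) and loads them into the Snorkel database (snorkel.db, or the database configured by `init.py`), where all of the above are persisted.
Re-running the notebook drops and re-creates the database contents.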

In [1]:
# `experiment_name` must be set BEFORE `%run init.py`: presumably init.py uses it to pick the
# database (the output shows postgres:///snorkel25similar) -- TODO confirm.
# NOTE(review): names used later but never defined in this notebook (session, Document,
# Sentence, Candidate, REGULATOR, path_pubmed_ids_pkl, path_candidate_dict_pkl) presumably
# come from init.py -- verify there.
experiment_name = '25similar'
# Reload edited helper modules (utils, sklearn_bridge, ...) automatically.
%load_ext autoreload
%autoreload 2
%matplotlib inline
%run init.py

  """)


Created snorkel session from  postgres:///snorkel25similar


In [2]:
# Connect with SQLite instead
# from snorkel import SnorkelSession
# session = SnorkelSession()
# from snorkel.models import  Document, Sentence

In [3]:
import pickle

In [4]:
import glob
import pandas as pd

In [5]:
# helpers for debugging
from utils import get_raw_document_txt

## Load the corpus

In [6]:
# Corpus = the unlabelled `similar25` documents from NCBI_parsed plus the train/dev/test
# documents from the goldset (both sets come from the same parsed source, which keeps the
# named entities consistent).
# Each glob is sorted: glob order is filesystem-dependent, and a deterministic document order
# keeps re-runs reproducible.
BIOCREATIVE_DIR = '/home/antonis/data/biocreative6'
txt_corpus = (sorted(glob.glob(BIOCREATIVE_DIR + '/NCBI_parsed/similar25/*.txt')) +
              sorted(glob.glob(BIOCREATIVE_DIR + '/goldset/*/*.txt')))

# To test candidate extraction + gold-label generation on the goldset only:
# txt_corpus = sorted(glob.glob(BIOCREATIVE_DIR + '/goldset/*/*.txt'))

# Fail fast on a wrong/missing data directory instead of silently building an empty DB.
assert len(txt_corpus) > 0, 'no .txt documents found under ' + BIOCREATIVE_DIR

txt_corpus = pd.Series(txt_corpus, name='paths')
print('%i documents' % len(txt_corpus))

40743 documents


In [7]:
# Write one document path per line, with no index column and no header row.
# The preprocessor below is built with default arguments. The earlier run parsed exactly as
# many documents as there are paths, which suggests every CSV row is read as a path --
# NOTE(review): confirm against CSVPathsPreprocessor.
# header=False is passed explicitly because pandas 1.0 changed Series.to_csv's default header
# from False to True. A 'paths' header row would otherwise be parsed as a document path.
txt_corpus.to_csv('full_corpus_paths.csv', index=False, header=False)

In [8]:
# Document preprocessor fed by the list of file paths written to CSV in the previous cell.
# (Alternative: TextDocPreprocessor('/home/antonis/data/biocreative6/corpus/training/')
#  reads a directory of text files directly.)
from snorkel.parser import TextDocPreprocessor
from snorkel.parser import CSVPathsPreprocessor

CORPUS_PATHS_CSV = 'full_corpus_paths.csv'
csv_preprocessor = CSVPathsPreprocessor(CORPUS_PATHS_CSV)

In [9]:
# Parse every document with spaCy. `fn=tagger_one.tag` is presumably applied to each parsed
# sentence to attach the pre-computed TaggerOne entity tags (and MeSH ids) loaded below --
# TODO confirm in snorkel.utils_cdr.
from snorkel.parser import CorpusParser
from snorkel.utils_cdr import TaggerOneTagger, CDRTagger
from snorkel.parser.spacy_parser import Spacy

ENTITIES_DIR = '/home/antonis/data/biocreative6/entities'
tagger_one = TaggerOneTagger(fname_tags=ENTITIES_DIR + '/unary_tags.pkl.bz2',
                             fname_mesh=ENTITIES_DIR + '/mesh_dict.pkl.bz2')

corpus_parser = CorpusParser(parser=Spacy(), fn=tagger_one.tag)
# apply() reports "Clearing existing..." first, so re-running replaces previously parsed docs.
corpus_parser.apply(list(csv_preprocessor), parallelism=6)

# Inspect DB contents. Each print gets a single pre-formatted argument: under this Python 2
# kernel, print("label", n) printed a tuple such as ('Documents:', 40743L).
print('Documents: %i' % session.query(Document).count())
print('Sentences: %i' % session.query(Sentence).count())

Clearing existing...
Running UDF...
('Documents:', 40743L)
('Sentences:', 448455L)


**Debug**


In [10]:
# # DebuG
# print("sample docs inserted: ")
# map(lambda x: x[1],csv_preprocessor)[0:10]

In [11]:
# # use Spacy as an alternative for now
# corpus_parser = CorpusParser( parser = Spacy() )
# corpus_parser.apply(doc_preprocessor)

In [12]:
# # Inspect given entities in sentence
# for i in range(10):
#     print ','.join(session.query(Sentence).all()[i].entity_types)
#     print '\n'

### Split dataset into train, validation, test

In [13]:
# Load the PubMed id lists that define the splits.
# pickle.load can execute arbitrary code -- only point this at a trusted, locally produced file.
with open(path_pubmed_ids_pkl, 'rb') as ids_file:
    pubmed_ids = pickle.load(ids_file)

In [14]:
# Names of the id sets available in the pickle.
list(pubmed_ids.keys())

['incoming_citations',
 'test_gs',
 'outgoing_citations',
 'train',
 'similar25',
 'validation']

In [15]:
# One set of document ids per split (sets give fast membership tests in the assignment loop).
# 'similar25' is the unlabelled split; an earlier version used pubmed_ids['chemdner_silver'].
train_ids, val_ids, test_ids, unlab_ids = [
    set(pubmed_ids[split_name])
    for split_name in ('train', 'validation', 'test_gs', 'similar25')
]

In [16]:
# split sentences

In [17]:
# Empty sentence buckets, one per split; filled by the assignment loop.
train_sents = set()
val_sents = set()
test_sents = set()
unlab_sents = set()

# All documents, ordered by name so they are visited in a deterministic order.
docs = session.query(Document).order_by(Document.name).all()

In [18]:
# Route every sentence to the split of its parent document.
# The split depends only on doc.name, so it is looked up once per document rather than once
# per sentence. A document whose id is in no set raises even if it has no sentences.
# Precedence (first match wins): train > validation > test > unlabelled.
split_buckets = [
    (train_ids, train_sents),
    (val_ids, val_sents),
    (test_ids, test_sents),
    (unlab_ids, unlab_sents),
]
for doc in docs:
    for id_set, bucket in split_buckets:
        if doc.name in id_set:
            bucket.update(doc.sentences)
            break
    else:  # no `break`: the id belongs to none of the splits
        raise Exception('ID <{0}> not found in any id set'.format(doc.name))

## Candidate extraction

In [19]:
# Moved into init.py
# from snorkel.models import Candidate, candidate_subclass
# REGULATOR = candidate_subclass('REGULATOR', ['Chemical', 'Gene'])

In [20]:
# Candidate extractor over the pre-tagged entity mentions: it presumably pairs Chemical and Gene
# mentions into REGULATOR candidates -- TODO confirm in snorkel.candidates.
# REGULATOR itself is presumably defined in init.py (see the note in the cell above).
from snorkel.candidates import PretaggedCandidateExtractor

ENTITY_TYPES = ['Chemical', 'Gene']
candidate_extractor = PretaggedCandidateExtractor(REGULATOR, ENTITY_TYPES)

In [21]:
# Number of sentences per split, in order: train, validation, test, unlabelled.
# (The split index is not needed here, so no enumerate.)
for sents in [train_sents, val_sents, test_sents, unlab_sents]:
    print(len(sents))

13542
3067
8171
423675


In [22]:
# Extract REGULATOR candidates split by split; split ids 0..3 = train, validation, test, unlabelled.
# apply() reports "Clearing existing..." first, presumably replacing that split's previous
# candidates on re-run -- TODO confirm.
for k, sents in enumerate([train_sents, val_sents, test_sents, unlab_sents]):
    candidate_extractor.apply(sents, split=k, parallelism=6)
    n_cands = session.query(REGULATOR).filter(REGULATOR.split == k).count()
    # Single formatted argument: print("label", n) prints a tuple under Python 2.
    print('Number of candidates (split=%i): %i' % (k, n_cands))

Clearing existing...
Running UDF...
('Number of candidates:', 21184L)
Clearing existing...
Running UDF...
('Number of candidates:', 5080L)
Clearing existing...
Running UDF...
('Number of candidates:', 13935L)
Clearing existing...
Running UDF...
('Number of candidates:', 130424L)


# Import gold labels

In [23]:
from utils import load_external_labels

In [24]:
# Load gold relation labels from the TSV and match them to REGULATOR candidates.
# With debug=True the helper reports, per labelled split, how many labels were created and how
# many could not be matched to a candidate.
# (An earlier call used FPATH='../../data/biocreative6/gold_rels_snorkel_format.tsv'.)
from snorkel.db_helpers import reload_annotator_labels

GOLD_RELS_TSV = '/home/antonis/data/biocreative6/entities/gold_rels_complete.tsv'
load_external_labels(session, REGULATOR, tsv_path=GOLD_RELS_TSV,
                     reload=True, debug=True)

AnnotatorLabels created: 13007
AnnotatorLabels not matched to candidates (split=0): 308536
AnnotatorLabels created: 3242
AnnotatorLabels not matched to candidates (split=1): 318301
AnnotatorLabels created: 8221
AnnotatorLabels not matched to candidates (split=2): 313322


In [25]:
# Inspect how candidates were mapped to gold labels

In [26]:
from snorkel.models import StableLabel
from sqlalchemy import and_

In [27]:
# Per split: total candidates, candidates with at least one gold label ("mapped") and without
# any ("un-mapped"), plus a check that the two groups partition the split.
# Both groups use `.any()` (EXISTS). Filtering on the bare relationship, as in
# filter(REGULATOR.gold_labels), becomes an implicit join that counts a candidate once per
# matching gold label.
for k in range(4):
    split_query = session.query(REGULATOR).filter(REGULATOR.split == k)
    n_total = split_query.count()
    n_mapped = split_query.filter(REGULATOR.gold_labels.any()).count()
    n_unmapped = split_query.filter(~REGULATOR.gold_labels.any()).count()
    print('split = %i' % k)
    print('Total cands: %i' % n_total)
    print('Mapped cands: %i' % n_mapped)
    print('Un-mapped cands: %i' % n_unmapped)
    print(n_total == n_mapped + n_unmapped)
    print('')


split =  0
Total cands: 21184
Mapped cands: 13007
Un-mapped cands: 8177
True

split =  1
Total cands: 5080
Mapped cands: 3242
Un-mapped cands: 1838
True

split =  2
Total cands: 13935
Mapped cands: 8221
Un-mapped cands: 5714
True

split =  3
Total cands: 130424
Mapped cands: 0
Un-mapped cands: 130424
True



In [28]:
# Collect the ids of candidates in the labelled splits (0=train, 1=validation, 2=test) that did
# not get any gold label; the next cell deletes them. Split 3 (unlabelled) is not touched.
to_drop = []
for k in range(3):
    unmapped = (session.query(REGULATOR)
                .filter(REGULATOR.split == k)
                .filter(~REGULATOR.gold_labels.any()))
    print('Adding %i candidates from split=%i in to_drop list' % (unmapped.count(), k))
    to_drop.extend([cand.id for cand in unmapped.all()])

Adding 8177 candidates from split=0 in to_drop list
Adding 1838 candidates from split=1 in to_drop list
Adding 5714 candidates from split=2 in to_drop list


# ~~~~~ STOPPED HERE

In [29]:
# Delete the un-mapped candidates collected above, then COMMIT. Without the commit, the bulk
# delete is only visible inside this session and is lost when the session/kernel ends, so it
# would never reach the persisted database.
# Rows are deleted from the base Candidate table. The REGULATOR subclass rows presumably go with
# them via an ON DELETE CASCADE foreign key -- TODO confirm (the recount below shows them gone).
# synchronize_session=False: objects already loaded in the Session are not updated by the delete.
query = session.query(Candidate).filter(Candidate.id.in_(to_drop))
print(query.count())
n_deleted = query.delete(synchronize_session=False)
session.commit()
n_deleted


15729


15729

In [30]:
# Sanity check: no candidate with a deleted id should remain (expected output: True).
not query.count()

True

In [31]:
# Recount after the drop: splits 0-2 should now contain only mapped candidates.
# Both groups use `.any()` (EXISTS). Filtering on the bare relationship, as in
# filter(REGULATOR.gold_labels), becomes an implicit join that counts a candidate once per
# matching gold label.
for k in range(4):
    split_query = session.query(REGULATOR).filter(REGULATOR.split == k)
    n_total = split_query.count()
    n_mapped = split_query.filter(REGULATOR.gold_labels.any()).count()
    n_unmapped = split_query.filter(~REGULATOR.gold_labels.any()).count()
    print('split = %i' % k)
    print('Total cands: %i' % n_total)
    print('Mapped cands: %i' % n_mapped)
    print('Un-mapped cands: %i' % n_unmapped)
    print(n_total == n_mapped + n_unmapped)
    print('')


split =  0
Total cands: 13007
Mapped cands: 13007
Un-mapped cands: 0
True

split =  1
Total cands: 3242
Mapped cands: 3242
Un-mapped cands: 0
True

split =  2
Total cands: 8221
Mapped cands: 8221
Un-mapped cands: 0
True

split =  3
Total cands: 130424
Mapped cands: 0
Un-mapped cands: 130424
True



# Exporting candidates from snorkel to sklearn for ML model training

In [32]:
from sklearn_bridge import export_snorkel_candidates

In [33]:
# Export the candidates of each split (0=train, 1=validation, 2=test, 3=unlabelled) for the
# sklearn models. `candidates` maps split id -> the exported object (which has .keys()).
# NOTE(review): the last argument is True for all four splits, including the unlabelled one,
# although a comment here claimed only train/dev/test export labelled candidates. Confirm what
# this flag controls in sklearn_bridge.export_snorkel_candidates.
candidates = {}
nr_cands_extracted = 0
for split in range(4):
    exported = export_snorkel_candidates(session, REGULATOR, split, True)
    candidates[split] = exported
    n_split = len(exported.keys())
    print('Extracted %i candidates from split = %i ' % (n_split, split))
    nr_cands_extracted += n_split

print('Extracted %i candidates in total' % nr_cands_extracted)

Extracted 13007 candidates from split = 0 
Extracted 3242 candidates from split = 1 
Extracted 8221 candidates from split = 2 
Extracted 130424 candidates from split = 3 
Extracted 154894 candidates in total


In [34]:
# Persist the exported candidates (dict: split id -> exported candidates) for the sklearn step.
# Protocol 2 is binary and much faster and smaller than Python 2's default text protocol 0 for
# ~155k candidates. It is readable from both Python 2 and 3, and pickle.load detects the
# protocol automatically. `candidates` is already a dict, so no defensive copy is needed.
with open(path_candidate_dict_pkl, 'wb') as f:
    pickle.dump(candidates, f, 2)

# #########################################
# Once this notebook has run, the results are persisted in the Snorkel database, so it does not need to be re-run unless more documents are added.
# #########################################