# MUST RUN AT THE START OF EVERYTHING

In [None]:
# Automatically reload edited modules (e.g. utils.bigdata_utils) before each
# cell executes, and render matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Imports
# NOTE(review): csv and random are not used in the cells shown here; they may
# be used elsewhere in the notebook — confirm before removing.
import csv
import os
import random

import numpy as np
import pandas as pd
import tqdm

In [None]:
# Connection settings for the PostgreSQL database that backs Snorkel.
username, password, dbname = "danich1", "snorkel", "pubmeddb"

# Connect through the local Unix-domain socket directory /var/run/postgresql;
# this host path may differ on other operating systems.
database_str = "postgresql+psycopg2://{}:{}@/{}?host=/var/run/postgresql".format(
    username, password, dbname
)

# SNORKELDB is set before snorkel is imported — presumably snorkel reads it at
# import time to pick its database backend (confirm before reordering).
os.environ.update(SNORKELDB=database_str)

from snorkel import SnorkelSession
session = SnorkelSession()

In [None]:
from snorkel.candidates import PretaggedCandidateExtractor
from snorkel.models import Document, Sentence, candidate_subclass
from snorkel.parser import CorpusParser
from snorkel.viewer import SentenceNgramViewer
from utils.bigdata_utils import XMLMultiDocPreprocessor
from utils.bigdata_utils import Tagger

# Parse the Pubmed Abstracts

The code below reads and parses data gathered from PubTator. PubTator outputs its annotated text in XML format, so XML is the input format used here.

In [None]:
# Load the PubTator tag table (tab-separated, xz-compressed; pandas infers the
# compression from the extension) and report how long the load takes.
# NOTE: the path is machine-specific; change it for your setup.
%time filter_df = pd.read_table('/home/danich1/Documents/pubtator/data/pubtator-hetnet-tags.tsv.xz')

In [None]:
# Group the tag rows by PubMed ID; the resulting GroupBy object is what the
# Tagger is built from further down.
%time grouped = filter_df.groupby('pubmed_id')

In [None]:
# Point this at your local copy of the PubMed XML document.
working_path = '/home/danich1/Documents/Database/pubmed_docs.xml'

# XPath selectors handed to the preprocessor: one for each document element,
# one for its passage text, and one for its ID.
_xpaths = {
    'doc': './/document',
    'text': './/passage/text/text()',
    'id': './/id/text()',
}
xml_parser = XMLMultiDocPreprocessor(path=working_path, **_xpaths)

In [None]:
# Build the tagger from the per-PubMed-ID tag groups; its `tag` method is used
# as the CorpusParser callback when the documents are parsed.
dg_tagger = Tagger(grouped)

In [None]:
# Parse every XML document, using the tagger's `tag` method as the
# CorpusParser callback.
corpus_parser = CorpusParser(fn=dg_tagger.tag)
document_chunk = []

# Number of documents handed to CorpusParser per apply() call.
parse_chunk_size = 50000

for document in tqdm.tqdm(xml_parser.generate()):
    document_chunk.append(document)

    # Chunk the data because Snorkel cannot yet scale to the whole corpus in
    # a single apply() call. clear=False presumably keeps previously parsed
    # chunks in the database instead of wiping them — confirm.
    if len(document_chunk) >= parse_chunk_size:
        corpus_parser.apply(document_chunk, parallelism=5, clear=False)
        document_chunk = []

# Once the generator is exhausted, parse the remaining partial chunk.
if document_chunk:
    corpus_parser.apply(document_chunk, parallelism=5, clear=False)
    document_chunk = []

In [None]:
# Sanity-check the parse: report how many documents and sentences are stored.
for _template, _model in (("Documents: {}", Document), ("Sentences: {}", Sentence)):
    print(_template.format(session.query(_model).count()))

# Get each candidate relation

The code below randomly assigns each parsed sentence to the train, dev, or test set and extracts Disease–Gene candidate pairs from it. **Note**: this includes the title of each abstract.

In [None]:
# Number of documents fetched from the database per page (LIMIT/OFFSET) during
# candidate extraction. Kept as an int because it is a row count.
chunk_size = 100000

In [None]:
#This specifies that I want candidates that have a disease and gene mentioned in a given sentence
DiseaseGene = candidate_subclass('DiseaseGene', ['Disease', 'Gene'])
ce = PretaggedCandidateExtractor(DiseaseGene, ['Disease', 'Gene'])

In [None]:
# Divide the sentences into train, dev and test sets and extract DiseaseGene
# candidates from each, one page of documents at a time.

# Set the seed so the random split assignment is reproducible.
np.random.seed(100)

# Sentences collected for the current page of documents, per split.
train_sens = set()
dev_sens = set()
test_sens = set()

offset = 0
# Page through the documents, divide their sentences and insert the extracted
# candidates into the database.
while True:
    # ORDER BY gives LIMIT/OFFSET a stable row order. Without it PostgreSQL may
    # return rows in any order, so consecutive pages can overlap or skip
    # documents and the seeded split would not be reproducible.
    doc_page = (
        session.query(Document)
        .order_by(Document.id)
        .limit(chunk_size)
        .offset(offset)
        .all()
    )
    if not doc_page:
        break

    for doc in tqdm.tqdm(doc_page):
        for s in doc.sentences:
            # Randomly assign the sentence: ~70% train, ~20% dev, ~10% test.
            category = np.random.choice([0, 1, 2], 1, p=[0.7, 0.2, 0.1])

            if category == 0:
                train_sens.add(s)
            elif category == 1:
                dev_sens.add(s)
            else:
                test_sens.add(s)

    ce.apply(train_sens, split=0, parallelism=5, clear=False)
    ce.apply(dev_sens, split=1, parallelism=5, clear=False)
    ce.apply(test_sens, split=2, parallelism=5, clear=False)
    offset += chunk_size

    # Reset for the next page.
    train_sens = set()
    dev_sens = set()
    test_sens = set()

In [None]:
# Report how many DiseaseGene candidates landed in each split. The split IDs
# match the split= values used during extraction (0 train, 1 dev, 2 test).
for _split_id, _split_name in ((0, "train"), (1, "dev"), (2, "test")):
    _n_candidates = (
        session.query(DiseaseGene).filter(DiseaseGene.split == _split_id).count()
    )
    print("Number of Candidates ({}): {}".format(_split_name, _n_candidates))

# Look at the Potential Candidates

Jupyter lets you browse the extracted candidates interactively with Snorkel's `SentenceNgramViewer`. Run the cells below after everything above has finished running.

In [None]:
# Integer split IDs, matching the split= values passed during candidate
# extraction: 0 = train, 1 = development, 2 = test.
TRAINING_SET, DEVELOPMENT_SET, TEST_SET = range(3)

In [None]:
# Load the training-split candidates into Snorkel's interactive sentence viewer.
candidates = session.query(DiseaseGene).filter_by(split=TRAINING_SET)
sv = SentenceNgramViewer(candidates, session)

In [None]:
# Display the viewer: as the last expression in the cell, Jupyter renders it.
sv