# Setup (Run These Cells First)

In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from collections import defaultdict
import re
import os
import operator


import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sqlalchemy import and_
import tqdm
from wordcloud import WordCloud

In [None]:
# Set up the database connection used by snorkel.
# Credentials default to the values previously hard-coded here, but can be
# overridden through environment variables so they need not live in the notebook.
username = os.environ.get("SNORKEL_DB_USER", "danich1")
password = os.environ.get("SNORKEL_DB_PASSWORD", "snorkel")
dbname = os.environ.get("SNORKEL_DB_NAME", "pubmeddb")

# The unix-socket host path is installation/OS specific; adjust as needed.
database_str = "postgresql+psycopg2://{}:{}@/{}?host=/var/run/postgresql".format(username, password, dbname)
os.environ['SNORKELDB'] = database_str

# Import snorkel only after SNORKELDB is set (presumably snorkel reads it at
# import time, which is why the import follows the assignment).
from snorkel import SnorkelSession
session = SnorkelSession()

In [None]:
from snorkel.annotations import FeatureAnnotator, LabelAnnotator
from snorkel.models import candidate_subclass
from snorkel.viewer import SentenceNgramViewer
from snorkel.models import Candidate

In [None]:
# Relation type to process: "dg" (disease-gene), "gg" (gene-gene),
# "cg" (compound-gene) or "cd" (compound-disease).
# debug=True makes LF_DEBUG the only labeling function.
edge_type, debug = "dg", False

In [None]:
# Register the candidate class for the chosen relation type and record the
# edge name. An unknown edge_type is an error: later cells use the candidate
# class, so fail here with a clear message instead of a NameError later.
if edge_type == "dg":
    DiseaseGene = candidate_subclass('DiseaseGene', ['Disease', 'Gene'])
    edge = "disease_gene"
elif edge_type == "gg":
    GeneGene = candidate_subclass('GeneGene', ['Gene1', 'Gene2'])
    edge = "gene_gene"
elif edge_type == "cg":
    CompoundGene = candidate_subclass('CompoundGene', ['Compound', 'Gene'])
    edge = "compound_gene"
elif edge_type == "cd":
    CompoundDisease = candidate_subclass('CompoundDisease', ['Compound', 'Disease'])
    edge = "compound_disease"
else:
    raise ValueError("Please pick a valid edge type ('dg', 'gg', 'cg' or 'cd'); got {!r}".format(edge_type))

In [None]:
# Split indices used to select the training, development and test candidates.
TRAIN, DEV, TEST = range(3)

# Look at Potential Candidates

Use this viewer to inspect loaded candidates from a given split. The constants `TRAIN`, `DEV`, and `TEST` are the split indices for the training, development, and test sets.

In [None]:
# Preview the first 100 training candidates in the interactive sentence viewer.
train_query = session.query(DiseaseGene).filter(DiseaseGene.split == TRAIN)
candidates = train_query.limit(100)
sv = SentenceNgramViewer(candidates, session)

In [None]:
# Display the SentenceNgramViewer (Jupyter renders a cell's final bare expression).
sv

# Trigger Words for Label Function Design

This section looks for "trigger" words that help distinguish true candidate relations from the background. It counts verbs that appear between the two entity mentions and adjectives that appear near each mention.

In [None]:
def update_freq_map(freq_map, pos_tag, pos_array, words):
    """Increment freq_map[word] for every word whose POS tag contains pos_tag.

    pos_array and words are parallel sequences; the match is a substring test,
    so "VB" also matches "VBZ", "VBD", etc. freq_map is updated in place.
    """
    matching = (idx for idx, tags in enumerate(pos_array) if pos_tag in tags)
    for idx in matching:
        freq_map[words[idx]] += 1

In [None]:
# Count candidate trigger words over the training split:
#   * verbs (tags containing "VB") strictly between the disease and gene mentions;
#   * adjectives (tags containing "JJ") within WINDOW tokens before/after each mention.
# get_word_end() is treated as the index of the mention's last word
# (inclusive) -- TODO confirm against the snorkel Span API.
# Windows are in token units (the earlier check against len(sentence.text)
# mixed characters and tokens), the "after" windows start after the mention
# rather than at its first word, and mentions near the sentence start still
# get a (shorter) left window.
WINDOW = 3

trigger_verbs = defaultdict(int)
gene_trigger_adj = defaultdict(int)
disease_trigger_adj = defaultdict(int)


def _count_tagged(freq_map, pos_tag, sentence, lo, hi):
    """Count words in sentence tokens [lo, hi) whose POS tag contains pos_tag.

    lo is clamped at 0 because a negative slice start would wrap around to the
    end of the sentence; hi past the end is fine since slicing truncates.
    """
    lo = max(0, lo)
    update_freq_map(freq_map, pos_tag, sentence.pos_tags[lo:hi], sentence.words[lo:hi])


candidates = session.query(DiseaseGene).filter(DiseaseGene.split == TRAIN)
for c in tqdm.tqdm(candidates):
    candidate_context = c.get_contexts()
    # Context 0 is the Disease span and context 1 the Gene span, matching the
    # field order of candidate_subclass('DiseaseGene', ['Disease', 'Gene']).
    disease_start, disease_end = candidate_context[0].get_word_start(), candidate_context[0].get_word_end()
    gene_start, gene_end = candidate_context[1].get_word_start(), candidate_context[1].get_word_end()
    sentence = c.get_parent()

    # Verbs between the two mentions, whichever comes first.
    if gene_start < disease_start:
        _count_tagged(trigger_verbs, "VB", sentence, gene_end + 1, disease_start)
    else:
        _count_tagged(trigger_verbs, "VB", sentence, disease_end + 1, gene_start)

    # Adjectives in the WINDOW tokens before and after each mention.
    _count_tagged(gene_trigger_adj, "JJ", sentence, gene_start - WINDOW, gene_start)
    _count_tagged(gene_trigger_adj, "JJ", sentence, gene_end + 1, gene_end + 1 + WINDOW)
    _count_tagged(disease_trigger_adj, "JJ", sentence, disease_start - WINDOW, disease_start)
    _count_tagged(disease_trigger_adj, "JJ", sentence, disease_end + 1, disease_end + 1 + WINDOW)

In [None]:
# Verbs by descending frequency. The counts live in `trigger_verbs`;
# `fixed_trigger_verbs` is not defined before this cell (NameError).
sorted_verbs = sorted(trigger_verbs.items(), key=operator.itemgetter(1), reverse=True)

In [None]:
# Gene-adjacent adjectives by descending frequency. The counts live in
# `gene_trigger_adj`; `fixed_gene_trigger_adj` is not defined before this cell.
sorted_gene_adj = sorted(gene_trigger_adj.items(), key=operator.itemgetter(1), reverse=True)

In [None]:
# Disease-adjacent adjectives by descending frequency. The counts live in
# `disease_trigger_adj`; `fixed_disease_trigger_adj` is not defined before this cell.
sorted_disease_adj = sorted(disease_trigger_adj.items(), key=operator.itemgetter(1), reverse=True)

In [None]:
# Write each (word, freq) list to its own CSV, without the DataFrame index.
for rows, out_path in ((sorted_verbs, "verb_freq.csv"),
                       (sorted_gene_adj, "gadj_freq.csv"),
                       (sorted_disease_adj, "dadj_freq.csv")):
    frame = pd.DataFrame(rows, columns=["word", "freq"])
    frame.to_csv(out_path, index=False)

# Word Cloud

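Show a word cloud of the trigger-word frequencies for a particular edge type (the cell below loads the disease-adjective counts from `dadj_freq.csv`).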

In [None]:
# Load one of the frequency tables written above (columns: word, freq).
word_df = pd.read_csv(filepath_or_buffer="dadj_freq.csv")

In [None]:
# word -> frequency mapping; each 'split'-orientation data row is a [word, freq] pair.
word_freq = dict((word, freq) for word, freq in word_df.to_dict('split')['data'])

In [None]:
# Render the word cloud from the frequency mapping (generate_from_frequencies
# returns the WordCloud itself), then show it as an image with axes hidden.
wordcloud_obj = WordCloud(colormap="autumn", background_color='black').generate_from_frequencies(word_freq)

plt.imshow(wordcloud_obj)
plt.axis("off")
plt.show()

# Label Functions

This is the fundamental part of the project. Below are the label functions used to give each candidate a label of 1, 0, or -1, corresponding to a correct relation, not sure, and an incorrect relation, respectively. The goal is to develop functions that can label as many candidates as possible.

In [None]:
# Pull in the label functions and their helpers (get_lfs, get_text_between,
# ...) for the selected relation type. An unknown edge_type is an error:
# later cells call these helpers, so fail here with a clear message.
if edge_type == "dg":
    from utils.disease_gene_lf import *
elif edge_type == "gg":
    from utils.gene_gene_lf import *
elif edge_type == "cg":
    from utils.compound_gene_lf import *
elif edge_type == "cd":
    from utils.compound_disease_lf import *
else:
    raise ValueError("Please pick a valid edge type ('dg', 'gg', 'cg' or 'cd'); got {!r}".format(edge_type))

# Debug Label Function

In [None]:
def LF_DEBUG(C):
    """Print what the context helper functions return for candidate C, then abstain.

    Used as the only labeling function when `debug` is True, to inspect the
    inputs the real LFs see.

    Returns:
        0 ("not sure"), so this LF never votes.
    """
    # Use the candidate passed in. The previous body read a global `c`
    # (whatever a loop in another cell last bound), i.e. the wrong candidate.
    tagged_text = get_tagged_text(C)
    text_between = get_text_between(C)

    print("Left Tokens")
    print(get_left_tokens(C, window=3))
    print("")
    print("Right Tokens")
    print(get_right_tokens(C))
    print("")
    print("Between Tokens")
    print(get_between_tokens(C))
    print("")
    print("Tagged Text")
    print(tagged_text)
    # Braces are matched literally here ("{{B}}" is not a repetition count).
    print(re.search(r'{{B}} .* is a .* {{A}}', tagged_text))
    print("")
    print("Get between Text")
    print(text_between)
    print(len(text_between))
    print("")
    print("Parent Text")
    print(C.get_parent())
    print("")
    return 0

In [None]:
# In debug mode run only the printing LF; otherwise use the real label functions.
if debug:
    LFs = [LF_DEBUG]
else:
    LFs = get_lfs()

# Test out Label Functions

In [None]:
labeled = []
# Pull a single training candidate and print what the LF helpers see for it.
sample_query = session.query(DiseaseGene).filter(DiseaseGene.split == TRAIN)
candidates = sample_query.limit(1).all()

for c in candidates:
    gene_span = c[1]  # second field of DiseaseGene is the Gene span
    print(c)
    print(get_text_between(c))
    print(gene_span.sentence.entity_cids[gene_span.get_word_start()])

# Label The Candidates

This block of code runs the label functions over every candidate in the training, development, and test splits.

In [None]:
labeler = LabelAnnotator(lfs=LFs)

cids = session.query(Candidate.id).filter(Candidate.split==0)
%time L_train = labeler.apply(split=0, cids_query=cids, parallelism=5)

cids = session.query(Candidate.id).filter(Candidate.split==1)
%time L_dev = labeler.apply_existing(split=1, cids_query=cids, parallelism=5, clear=False)

cids = session.query(Candidate.id).filter(Candidate.split==2)
%time L_test = labeler.apply_existing(split=2, cids_query=cids, parallelism=5, clear=False)

# Generate Candidate Features

This block of code generates the features that some machine-learning algorithms use for classification.

In [None]:
%%time
featurizer = FeatureAnnotator()
featurizer.apply(split=0, clear=False)

In [None]:
%time F_dev = featurizer.apply_existing(split=1, parallelism=5, clear=False)
%time F_test = featurizer.apply_existing(split=2, parallelism=5, clear=False)

# Workaround for the Featurizer

The code below is a workaround for the featurizer, which takes far too long to run. (TODO: debug the featurizer, or check Snorkel's GitHub for updates that address it.) In the meantime, the code below writes the feature rows to text files, from which psql copies the data into the database.

In [None]:
import csv

# Load feature keys already exported to feature_key_fixed.sql: tab-separated
# rows of (group, name, id); the first line (presumably the COPY header) is
# skipped. QUOTE_NONNUMERIC turns unquoted fields into floats, so ids are
# converted back to int here; otherwise later cells write ids like "123.0".
feature_key_hash = {}
max_key_id = -1
with open('feature_key_fixed.sql', 'rb') as d:
    d.readline()
    feature_key_reader = csv.reader(d, delimiter='\t', quoting=csv.QUOTE_NONNUMERIC)
    for row in tqdm.tqdm(feature_key_reader):
        if len(row) < 3:
            print(row)  # malformed row: report it and skip
            continue
        key_id = int(row[2])
        feature_key_hash[row[1]] = key_id
        max_key_id = max(max_key_id, key_id)

# Next unused id: one past the largest existing id. Using the last row's id
# itself made the first newly created key collide with an existing key.
# An empty file starts numbering at 0 -- NOTE(review): confirm the expected
# starting id for the feature_key table.
feat_counter = max_key_id + 1
print(feat_counter)

In [None]:
from snorkel.features import get_span_feats

# Append span features to feature.sql and any newly seen feature keys to
# feature_key_fixed.sql as tab-separated rows for a psql COPY ... FROM stdin.
# When needed, the COPY headers have the form:
#   COPY feature_key("group", name, id) from stdin with CSV DELIMITER '<TAB>' QUOTE '"';
#   COPY feature(value, candidate_id, key_id) from stdin with CSV DELIMITER '<TAB>' QUOTE '"';
group = 0
seen = set()


def _write_span_features(candidates, allow_new_keys):
    """Write one feature row per distinct feature name of each candidate.

    Feature names missing from feature_key_hash are skipped unless
    allow_new_keys is true; in that case they take the next id from the
    global feat_counter and a feature_key row is written for them.
    """
    global feat_counter
    for cand in tqdm.tqdm(candidates):
        written = set()
        for name, value in get_span_feats(cand):
            if name not in feature_key_hash:
                if not allow_new_keys:
                    continue
                feature_key_hash[name] = feat_counter
                feat_counter = feat_counter + 1
                feature_key_writer.writerow([group, name, feature_key_hash[name]])
            if name not in written:
                feature_writer.writerow([value, cand.id, feature_key_hash[name]])
                written.add(name)


with open('feature_key_fixed.sql', 'ab') as f:
    with open('feature.sql', 'ab') as g:
        feature_key_writer = csv.writer(f, delimiter='\t', quoting=csv.QUOTE_NONNUMERIC)
        feature_writer = csv.writer(g, delimiter='\t', quoting=csv.QUOTE_NONNUMERIC)

        # Training split: resumes after a fixed offset and may create new keys.
        # NOTE(review): the query has no ORDER BY, so PostgreSQL does not
        # guarantee which rows OFFSET skips -- confirm this resume point.
        _write_span_features(
            session.query(Candidate).filter(Candidate.split == 0).offset(2508430).all(),
            allow_new_keys=True)
        # Dev and test splits: only features whose keys already exist.
        _write_span_features(
            session.query(Candidate).filter(Candidate.split == 1).all(),
            allow_new_keys=False)
        _write_span_features(
            session.query(Candidate).filter(Candidate.split == 2).all(),
            allow_new_keys=False)

In [None]:
# Rewrite feature_key_fixed.sql into feature_key.sql with normalized key ids:
#   * a fractional part is stripped from the id column ("123.0" -> "123"),
#     presumably left over from ids that were floats when written;
#   * if an id was already emitted, it is replaced by id + 1.
# NOTE(review): the +1 renumbering is not propagated to feature.sql, so
# feature rows referencing a renumbered key will point at a different key.
# A bumped id can also equal an id appearing later in the file, which is then
# bumped in turn. Splitting on "\t" ignores CSV quoting, so a name containing
# a tab would misalign the columns -- confirm names never contain tabs.
import hashlib
import re
seen = set()
with open('feature_key_fixed.sql', 'rb') as f:
    with open('feature_key.sql', 'wb') as g:
        # Copy the first line (presumably the COPY header) unchanged.
        g.write(f.readline())
        for line in tqdm.tqdm(f):
            # Columns: group, name, id; the id field still carries the line terminator.
            data = line.split("\t")
            data[2] = re.sub(r'\.\d+','',data[2])
            # Track ids already written (hashes of the id field including its terminator).
            md5hash = hashlib.md5(data[2]).hexdigest()
            if md5hash not in seen:
                seen.add(md5hash)
                g.write("\t".join(data))
            else:
                # Duplicate id: bump it by one and re-terminate the line with CRLF.
                ids = re.search(r'\d+', data[2]).group(0)
                ids = int(ids) + 1
                data[2] = "{}\r\n".format(ids)
                seen.add(hashlib.md5(data[2]).hexdigest())
                g.write("\t".join(data))

In [3]:
import hashlib

# Deduplicate feature.sql into feature_2.sql. Each row is normalized first:
# rows with more than three tab-separated fields keep fields 1-3 (the first
# field is dropped), and a fractional part is stripped from the third (key id)
# field, e.g. "123.0" -> "123". The normalized row is both what is compared and
# what is written; previously the cleaned fields were computed but the raw line
# was written, so the cleanup had no effect.
seen = set()
with open('feature.sql', 'rb') as f:
    with open('feature_2.sql', 'wb') as g:
        # Copy the first line unchanged -- NOTE(review): assumes feature.sql
        # starts with a header (e.g. a COPY statement); confirm, since the
        # header writes in the feature-writing cell are commented out.
        g.write(f.readline())
        for line in tqdm.tqdm(f):
            fields = line.rstrip("\r\n").split("\t")
            if len(fields) > 3:
                fields = fields[1:4]
            if len(fields) < 3:
                print(repr(line))  # malformed row: report it and skip
                continue
            fields[2] = re.sub(r'\.\d+', '', fields[2])
            cleaned = "\t".join(fields) + "\r\n"
            # Raw 16-byte digests keep this very large set smaller than hex strings.
            digest = hashlib.md5(cleaned).digest()
            if digest not in seen:
                seen.add(digest)
                g.write(cleaned)

246125907it [07:05, 577922.96it/s]


# Generate Coverage Stats

Before passing our labels to a machine-learning algorithm, take a look at some quick statistics. The code below shows the coverage of each label function along with other summary statistics.

In [None]:
# Per-LF statistics for the training label matrix.
print(L_train.lf_stats(session))

In [None]:
# Look up the candidate behind row 21 of the training label matrix and show its sentence.
sample_candidate = L_train.get_candidate(session, 21)
print(sample_candidate)
print(sample_candidate.get_parent())

In [None]:
# Matrix shape, shape of the entries selected by L_train < 0, and the first column.
print(L_train.shape)
negative_entries = L_train[L_train < 0]
print(negative_entries.shape)
print(L_train[:, 0])

In [None]:
# Per-LF statistics for the development label matrix.
print(L_dev.lf_stats(session))