# Tagging genes with ddlite: learning and labeling function iteration

## Introduction
In this example **ddlite** app, we'll build a gene tagger from scratch. Domain-specific tagging systems take months or years to develop. They use hand-crafted model circuitry and accurate, hand-labeled training data. We'll start to build a pretty good one in a few minutes with none of those things. The generalized extraction and learning utilities provided by ddlite will allow us to turn a sampling of article abstracts and some basic domain knowledge into an automated tagging system. Specifically, we want an accurate tagger for genes in academic articles. We have comprehensive dictionaries of genes, but applying a simple matching rule might yield a lot of false positives. For example, "p53" might get tagged as a gene if it refers to a page number. Our goal is to use distant supervision to improve precision.
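
To make the false-positive problem concrete, here is a minimal sketch of naive dictionary matching. The dictionary contents, the `dictionary_matches` helper, and both sentences are made up for illustration; they are not part of ddlite.

```python
import re

# Hypothetical mini-dictionary of gene symbols (illustrative only).
GENE_DICTIONARY = {"p53", "BRCA1", "TP53"}


def dictionary_matches(sentence, dictionary=GENE_DICTIONARY):
    """Return every alphanumeric token of `sentence` found in `dictionary`."""
    tokens = re.findall(r"[A-Za-z0-9]+", sentence)
    return [tok for tok in tokens if tok in dictionary]


true_mention = "Mutations in p53 are common in human tumors."
page_reference = "See the supplementary table on p53 of the appendix."

# Both sentences produce the same match, although only the first one
# actually talks about a gene.
print(dictionary_matches(true_mention))
print(dictionary_matches(page_reference))
```

A plain matching rule cannot tell these two sentences apart, which is why the rest of the pipeline adds context-aware labeling functions on top of the dictionary.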

Here's the pipeline we'll follow:

1. Obtain and parse input data (relevant article abstracts from PubMed)
2. Extract candidates for tagging
3. Generate features
4. Create a test set
5. Write labeling functions
6. Learn the tagging model
7. Iterate on labeling functions

Parts 3 through 7 are covered in this notebook. It requires candidates extracted from `GeneTaggerExample_Extraction.ipynb`, which covers parts 1 and 2.

In [None]:
%reload_ext autoreload
%autoreload 2

import os
import sys
import cPickle
import numpy as np
import matplotlib
# print(os.environ['SNORKELDB'])
# Use production DB
from set_env import set_env
set_env() 
sys.path.insert(1, '../snorkel')

# Must set SNORKELDB before importing SnorkelSession
from snorkel import SnorkelSession
from snorkel.parser import TextDocPreprocessor
from snorkel.parser import CorpusParser
from snorkel.models import Document, Sentence, candidate_subclass
from snorkel.viewer import SentenceNgramViewer
session = SnorkelSession()

#np.random.seed(seed=1701)

%matplotlib inline
matplotlib.rcParams['figure.figsize'] = (18,6)

## Loading candidate extractions
First, we'll load the candidates that we created in the last notebook. We query them from the database by split and collect them in a `docs` list.

In [None]:
PhenoPair = candidate_subclass('ComplexPhenotypes', ['descriptor', 'entity'])

docs = session.query(PhenoPair).filter(PhenoPair.split == 3).all()  #should edit split to be 3.  

print "Documents:", session.query(Document).count()
print "Sentences:", session.query(Sentence).count()

##Once we get all the labels, for loop through all docs and split into train, dev, test. 

print 'Document set:\t{0} candidates'.format(len(docs))

In [None]:
dev_docs = session.query(PhenoPair).filter(PhenoPair.split == 4).all()  #should edit split to be 1. 
test_docs = session.query(PhenoPair).filter(PhenoPair.split == 5).all()  #should edit split to be 1.
print 'Dev set:\t{0} candidates'.format(len(dev_docs))
print 'Test set:\t{0} candidates'.format(len(test_docs))
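
The cells above read candidates from splits 3, 4 and 5. As a rough sketch of how documents could be assigned to those splits once labels are available, the snippet below shuffles document ids and partitions them. The `assign_splits` helper and the 80/10/10 proportions are assumptions for illustration, not something this notebook or Snorkel prescribes.

```python
import random

# Split ids used in this notebook: 3 = train, 4 = dev, 5 = test.
TRAIN, DEV, TEST = 3, 4, 5


def assign_splits(doc_ids, train_pct=80, dev_pct=10, seed=0):
    """Map each document id to a split id; the remainder goes to test."""
    ids = sorted(doc_ids)
    random.Random(seed).shuffle(ids)  # seeded, so the result is repeatable
    n_train = len(ids) * train_pct // 100
    n_dev = len(ids) * dev_pct // 100
    splits = {}
    for i, doc_id in enumerate(ids):
        if i < n_train:
            splits[doc_id] = TRAIN
        elif i < n_train + n_dev:
            splits[doc_id] = DEV
        else:
            splits[doc_id] = TEST
    return splits


splits = assign_splits(range(20))
for split_id in (TRAIN, DEV, TEST):
    print(split_id, sum(1 for s in splits.values() if s == split_id))
```

Splitting at the document level, rather than per candidate, keeps candidates from the same abstract out of both the training and the test set.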

In [None]:
from snorkel.models.context import TemporaryContext
import re

print docs

## Make Labeling Functions

In [None]:
import re
import os
from snorkel.lf_helpers import (
    get_left_tokens,
    get_between_tokens,
    get_right_tokens,
    contains_token,
    get_text_between,
    get_text_splits,
    get_tagged_text,
    is_inverted,
    rule_regex_search_tagged_text,
    rule_regex_search_btw_AB,
    rule_regex_search_btw_BA,
    rule_regex_search_before_A,
    rule_regex_search_before_B,
    
)


def LF_mutant(c):
    return 1 if ('mutant' in get_right_tokens(c, attrib='lemmas')) or ('mutant' in get_left_tokens(c, attrib='lemmas')) else 0
def LF_variant(c):
    return 1 if ('variant' in get_right_tokens(c, attrib='lemmas')) or ('variant' in get_left_tokens(c, attrib='lemmas')) else 0
def LF_express(c):
    return 1 if ('express' in get_right_tokens(c, attrib='lemmas')) or ('express' in get_left_tokens(c, attrib='lemmas')) else 0  
def LF_JJ(c):
    return 1 if 'JJ' in get_right_tokens(c, attrib='pos_tags') else 0
def LF_IN(c):
    return 1 if 'IN' in get_right_tokens(c, window=1, attrib='pos_tags') else 0
def LF_dna(c):
    return -1 if contains_token(c, 'DNA', attrib='words') else 0
def LF_rna(c):
    return -1 if contains_token(c, 'RNA', attrib='words') else 0
def LF_snp(c):
    return -1 if contains_token(c, 'SNP', attrib='words') else 0
def LF_protein(c):
    return -1 if 'protein' in get_left_tokens(c, attrib='lemmas') else 0
def LF_LRB(c):
    return -1 if '-LRB-' in get_right_tokens(c, window=1, attrib='pos_tags') else 0
def LF_RRB(c):
    return -1 if '-RRB-' in get_right_tokens(c, window=1, attrib='pos_tags') else 0    
def LF_NNP(c):
    return -1 if contains_token(c, 'NNP', attrib='pos_tags') else 0
def lfdistBtw0(c):
    return 1 if len(get_between_tokens(c, attrib="words")) == 0 else 0
def lfdistBtw(c):
    return 1 if len(get_between_tokens(c, attrib="words")) < 3 else 0
def lfdistBtwNeg(c):
    return -1 if len(get_between_tokens(c, attrib="words")) > 4 else 0
def lfdistBtwNeg2(c):
    return -1 if len(get_between_tokens(c, attrib="words")) > 8 else 0

action_link_words = set(['affect', 'lead', 'led', 'show', 'display', 'exhibit', 'cause', 'result in'])
mutant_words = set(['mutant', 'mutation', 'plant', 'line', 'phenotype', 'seedlings', 'variant'])
helper_vbs = set(['is', 'was', 'are', 'were', 'become', 'became'])
tester_words = set(['sequence', 'published', 'diagram', 'hypothesis', 'hypothesize', 'aim', 'goal', 'understand', 'examine', 'we', 'our', 'experiment', 'test', 'study', 'design', 'analyze', 'analysis', 'results', 'research'])
neg_words = set(['strategy', 'public', 'examine', 'measure', 'subject', 'statistic', 'instance'])

def lfnegWords(c):
    for word in neg_words:
        if contains_token(c, word, attrib='lemmas'): return -1
    return 0    

def resultsAlone(c):
    return -1 if (len(c[0].get_attrib_tokens('words')) == 1 and contains_token(c[0], 'result', attrib='lemmas')) or (len(c[1].get_attrib_tokens('lemmas')) == 1 and contains_token(c[1], 'result', attrib='lemmas')) else 0

def lfLenCand(c):
    return -1 if len(c[0].get_attrib_tokens('words')) == 1 or len(c[1].get_attrib_tokens('words')) == 1 else 0

def lf1(c):
    return 1 if 'in' in get_between_tokens(c, attrib='words') else 0

def lf2(c):
    return 1 if len(action_link_words.intersection(set(get_between_tokens(c, attrib='lemmas')))) > 0 else 0

def lf2a(c):
    for aw in action_link_words:
        if contains_token(c[1], aw, attrib='lemmas'): return 1
    return 0

def lf3(c):
    return 1 if contains_token(c[0], 'JJR', attrib='pos_tags') else 0

def lf4(c):
    return 1 if contains_token(c, r'fold') or contains_token(c, r'\d+(\.\d+)?%') or contains_token(c, 'percent') else 0

def lf5(c):
    return 1 if len(mutant_words.intersection(set(get_left_tokens(c, attrib='lemmas')))) > 0 or len(mutant_words.intersection(set(get_right_tokens(c, attrib='lemmas')))) > 0 else 0

def lf6(c):
    return 1 if len(helper_vbs.intersection(set(get_between_tokens(c, attrib='lemmas', n_max=3)))) > 0 else 0

def lf7(c):
    return -1 if 'not' in get_between_tokens(c) else 0

def lf8(c):
    return -1 if 'not' in get_left_tokens(c[0]) or 'not' in get_left_tokens(c[1]) else 0

def lf9(c):
    return -1 if 'level' in get_left_tokens(c[0], attrib='lemmas', n_max=2) or 'level' in get_right_tokens(c[0], attrib='lemmas', n_max=2) else 0

def lf10(c):
    return -1 if 'transcript' in get_left_tokens(c[0], attrib='lemmas', n_max=3) or 'transcript' in get_right_tokens(c[0], attrib='lemmas', n_max=2) else 0

def lf11(c):
    return -1 if not contains_token(c, 'JJR', attrib='pos_tags') and not contains_token(c, 'JJ', attrib='pos_tags') and not contains_token(c[1], 'VBN', attrib='pos_tags') else 0

def inverted(c):
    return 1 if is_inverted(c) else 0

def lf12(c):
    return 1 if inverted(c) and lf1(c) else 0

def lf13(c):
    return 1 if inverted(c) and 'IN' in get_between_tokens(c, attrib='pos_tags', n_max=4) else 0

def lf14(c):
    return 1 if contains_token(c, 'phenotype', attrib='lemmas') else 0

def lf15(c):
    return -1 if 'protein' in get_left_tokens(c[0], attrib='lemmas') or 'protein' in get_right_tokens(c[0], attrib='lemmas') else 0

def lf16(c):
    return -1 if 'activity' in get_left_tokens(c[0], attrib='lemmas', n_max=2) or 'level' in get_right_tokens(c[0], attrib='lemmas', n_max=1) else 0

def lf17(c):
    return -1 if len(tester_words.intersection(set(get_tagged_text(c).split()))) > 0 else 0

# def LF_gene_dp(c):
#     return 1 if 'gene' in get_left_tokens(c[0], window=2, attrib='lemmas') else 0
# def LF_genotype_dp(c):
#     return 1 if 'genotype' in get_left_tokens(c[0], window=2, attrib='lemmas') else 0
def LF_phenotype_dp(c):
    return 1 if 'phenotype' in get_right_tokens(c[1], window=2, attrib='lemmas') else 0
def LF_mutation(c):
    return 1 if ('mutation' in get_right_tokens(c[0], window=2, attrib='lemmas') or 'mutant' in get_right_tokens(c[0], window=2, attrib='lemmas') or 'mutations' in get_right_tokens(c[0], window=2, attrib='lemmas')) else 0

def LF_pheno(c):
    return 1 if contains_token(c, 'phenotype', attrib='words') else 0

def LF_dev_dp(c):
    return -1 if 'development' in get_right_tokens(c[1], window=2, attrib='lemmas')  else 0
def LF_protein_dp(c):
    return -1 if 'protein' in get_left_tokens(c[1], window=2, attrib='lemmas') or 'protein' in get_right_tokens(c[1], window=2, attrib='lemmas') else 0
def LF_network_dp(c):
    return -1 if 'network' in get_right_tokens(c[1], window=2, attrib='lemmas') else 0
def LF_JJ_dp(c):
    return -1 if 'JJ' in get_right_tokens(c[1], window=2, attrib='pos_tags') else 0

def lf_helpers(c):
    return 1 if any(word in get_left_tokens(c, window=2, attrib='words') for word in ['had', 'has', 'was', 'have', 'showed', 'were', 'is', 'are', 'results']) else 0

adj_words = set(['increase', 'lower', 'reduce', 'higher', 'less', 'more', 'elevate', 'decrease', 'insensitive', 'absence', 'inhibit', 'double',])
def lf_adjwords(c):
    for aw in adj_words:
        if contains_token(c, aw, attrib='lemmas'): return 1
    return 0
    
stats_words = set(['statistically', 'significant', 'quantitative', 'real-time', 'generated', 'exposed', 'stratified'])
def lf_statswords(c):
    for aw in stats_words:
        if contains_token(c, aw, attrib='lemmas'): return -1
    return 0
   
def lf_proteinIn(c):
    return -1 if 'protein' in get_left_tokens(c[1], attrib='lemmas') or 'protein' in get_right_tokens(c[1], attrib='lemmas') or contains_token(c[1], 'protein', attrib='lemmas') else 0

def lf18(c):
    return 1 if 'mutant' in get_tagged_text(c).split() or 'mutation' in get_text_splits(c) else 0

# def lf19(c):
#     lemmas = c[0].get_attrib_tokens('lemmas')
#     poses = c[0].get_attrib_tokens('pos_tags')
#     for i, w in enumerate(poses):
#         if w in ['NN', 'NNS', 'NNP', 'NNPS'] and not re.match(r'\w+(ion|ment|vity)', lemmas[i]):
#             return -1
#     return 0

def lf20(c):
    lemmas = c[0].get_attrib_tokens('lemmas')
    poses = c[0].get_attrib_tokens('pos_tags')
    result = 0
    # Iterate over the POS tags; lemmas[i] is the lemma at the same position
    for i, w in enumerate(poses):
        if w in ['NN', 'NNS', 'NNP', 'NNPS'] and not re.match(r'\w+(ion|ment|vity)', lemmas[i]):
            result = 0
        elif re.match(r'\w+(ion|ment|vity)', lemmas[i]):
            result = -1
    return result

def lf21(c):
    return rule_regex_search_btw_BA(c, '.* in .*', 1)

def lf22(c):
    return -1 if 'expression' in get_right_tokens(c[0], attrib='lemmas', window=2) or 'expression' in get_left_tokens(c[0], attrib='lemmas', window=2) else 0
    
def lf23(c):
    return -1 if not inverted(c) and len(helper_vbs.intersection(set(get_right_tokens(c[0], window = 1, attrib='lemmas')))) > 0 and ('VBN' == c[0].get_attrib_tokens('pos_tags')[0] or ('VBN' == c[0].get_attrib_tokens('pos_tags')[1] and 'RB' == c[0].get_attrib_tokens('pos_tags')[0])) else 0

def lf24(c):
    return -1 if not contains_token(c[0], 'VB', attrib='pos_tags') and not contains_token(c[0], 'VBZ', attrib='pos_tags') and not contains_token(c[0], 'VBD', attrib='pos_tags') else 0

def lf25(c):
    return -1 if 'IN' == c[0].get_attrib_tokens('pos_tags')[0] or 'TO' == c[0].get_attrib_tokens('pos_tags')[0] else 0

def lf26(c):
    if len(c[0].get_attrib_tokens('pos_tags')) < 2:
        return 0
    return -1 if 'JJR' == c[0].get_attrib_tokens('pos_tags')[0] and len(set(['NN', 'NNS', 'NNP', 'NNPS']).intersection(set(c[0].get_attrib_tokens('pos_tags')[1]))) == 0 else 0

def lf27(c):
    hasNoNoun = len(set(['NN', 'NNS', 'NNP', 'NNPS']).intersection(set(c[0].get_attrib_tokens('pos_tags')[0]))) == 0
    return -1 if (len(c[0]) < 3 and hasNoNoun) else 0
                  
def lf28(c):
    if len(c[0].get_attrib_tokens('pos_tags')) == 0:
        return 0
    if len(c[0].get_attrib_tokens('lemmas')) < 2:
        return 0
    lastWordAdj = True if c[0].get_attrib_tokens('pos_tags')[-1] in set(['JJ', 'JJR']) else False
    nextLastVrb = True if c[0].get_attrib_tokens('lemmas')[-2] in helper_vbs else False
    return -1 if not nextLastVrb and lastWordAdj else 0              
    
def lf29(c):                  #if pheno ends in VBG, its bad
    return -1 if c[0].get_attrib_tokens('pos_tags')[-1] == 'VBG' else 0

def lf30(c):                #if ends in prep, its bad
    return -1 if c[0].get_attrib_tokens('pos_tags')[-1] == 'IN' else 0

def lf29b(c):                  # if the second span (c[1]) ends in VBG, it's bad
    return -1 if c[1].get_attrib_tokens('pos_tags')[-1] == 'VBG' else 0

def lf30b(c):                  # if the second span (c[1]) ends in a preposition, it's bad
    return -1 if c[1].get_attrib_tokens('pos_tags')[-1] == 'IN' else 0

def lf30c(c):                  # if the second span (c[1]) ends in an adjective, it's bad
    return -1 if c[1].get_attrib_tokens('pos_tags')[-1] in ['JJ', 'JJR', 'JJS'] else 0



def lf31(c):
    return 1 if len(get_between_tokens(c, attrib='words')) <= 2 else 0

def lf32(c):
     return 1 if any([word in get_left_tokens(c, window=4, attrib='words') for word in ['is', 'are']]) else 0
        
def lf33(c):
    return 1 if any([word in get_left_tokens(c, window=4, attrib='words') for word in ['results', 'affected']]) else 0

def lf34(c):
    return -1 if len(get_between_tokens(c, attrib='words')) >= 5 else 0

def lf35(c):
    return 1 if any([word in get_left_tokens(c, window=4, attrib='words') for word in ['showed', 'were',]]) else 0


In [None]:
#HELPERS
def inverted(c):
    return 1 if is_inverted(c) else 0

#DISTANCE RULES
def lfdistBtw0(c):
    return 1 if len(get_between_tokens(c, attrib="words")) == 0 else 0
def lfdistBtwMax2(c):
    return 1 if len(get_between_tokens(c, attrib="words")) < 3 else 0
def lfdistBtwMin5(c):
    return -1 if len(get_between_tokens(c, attrib="words")) > 4 else 0
def lfdistBtwMin8(c):
    return -1 if len(get_between_tokens(c, attrib="words")) > 8 else 0
def lfdistBtwMax1(c):
    return 1 if len(get_between_tokens(c, attrib='words')) < 2 else 0
def lfdistBtwMin12(c):
    return -1 if len(get_between_tokens(c, attrib='words')) > 11 else 0
def lfdistBtwMin14(c):
    return -1 if len(get_between_tokens(c, attrib='words')) > 14 else 0

def lfLenCand(c):
    return -1 if len(c[0].get_attrib_tokens('words')) == 1 else 0

#WORD BASED RULES
cause_words = set(['affect', 'lead', 'led', 'show', 'display', 'exhibit', 'cause', 'result in'])
mutant_words = set(['mutant', 'mutation', 'plant', 'line', 'phenotype', 'seedling', 'variant'])
helper_vbs = set(['is', 'was', 'are', 'were', 'become', 'became', 'has', 'had'])
tester_words = set(['sequence', 'published', 'diagram', 'hypothesis', 'hypothesize', 'aim', 'goal', 'understand', 'examine', 'we', 'our', 'experiment', 'test', 'study', 'design', 'analyze', 'analysis', 'research'])
neg_words = set(['strategy', 'public', 'examine', 'measure', 'subject', 'statistic', 'instance'])
adj_words = set(['increase', 'low', 'reduce', 'high', 'less', 'more', 'elevate', 'decrease', 'insensitive', 'absence', 'inhibit', 'double'])   
stats_words = set(['statistically', 'quantitative', 'qualitative', 'real-time', 'generate', 'expose', 'stratify'])

#WORDS IN CAND RULES
def LF_dna(c):
    return -1 if contains_token(c, 'DNA', attrib='words') else 0
def LF_rna(c):
    return -1 if contains_token(c, 'RNA', attrib='words') else 0
def LF_snp(c):
    return -1 if contains_token(c, 'SNP', attrib='words') else 0

def lfwordis_result(c):
    return -1 if (len(c[0].get_attrib_tokens('words')) == 1 and contains_token(c[0], 'result', attrib='lemmas')) or (len(c[1].get_attrib_tokens('lemmas')) == 1 and contains_token(c[1], 'result', attrib='lemmas')) else 0

def lfwordsin_percent(c):
    return 1 if contains_token(c, r'fold') or contains_token(c, r'\d+(\.\d+)?%') or contains_token(c, 'percent') else 0

def lfwordsin_phenotype(c):
    return 1 if contains_token(c, 'phenotype', attrib='lemmas') else 0

def lfwordsin_testerwords(c):
    #return -1 if len(tester_words.intersection(set(get_tagged_text(c).split()))) > 0 else 0
    return -1 if len(tester_words.intersection(set(c.get_parent()._asdict()['text'].split()))) > 0 else 0

def lfwordsin_statistically(c):
    return 1 if 'statistically' in c.get_parent()._asdict()['text'].split() else 0

def lfwordsin_negwords(c):
    for word in neg_words:
        if contains_token(c, word, attrib='lemmas'): return -1
    return 0 
def lfwordsin_causewords(c):
    for aw in cause_words:
        if contains_token(c, aw, attrib='lemmas'): return 1
    return 0
def lfwordsin_adjwords(c):
    for aw in adj_words:
        if contains_token(c, aw, attrib='lemmas'): return 1
    return 0
def lfwordsin_statswords(c):
    for aw in stats_words:
        if contains_token(c, aw, attrib='lemmas'): return -1
    return 0

#WORDS IN CONTEXT
def lfwordscontext_mutant(c):
    return 1 if len(mutant_words.intersection(set(get_left_tokens(c[0], attrib='lemmas')))) > 0 or len(mutant_words.intersection(set(get_right_tokens(c[0], attrib='lemmas')))) > 0 or len(mutant_words.intersection(set(get_left_tokens(c[1], attrib='lemmas')))) > 0 or len(mutant_words.intersection(set(get_right_tokens(c[1], attrib='lemmas')))) > 0 else 0
def lfwordsbtwn_mutant(c):
    return 1 if len(mutant_words.intersection(set(get_between_tokens(c, attrib='lemmas', n_max=4)))) > 0 else 0

def LF_variant(c):
    return 1 if ('variant' in get_right_tokens(c, attrib='lemmas')) or ('variant' in get_left_tokens(c, attrib='lemmas')) else 0
def LF_express(c):
    return 1 if ('express' in get_right_tokens(c, attrib='lemmas')) or ('express' in get_left_tokens(c, attrib='lemmas')) else 0  
def lfLenCand(c):
    return -1 if len(c[0].get_attrib_tokens('words')) == 1 or len(c[1].get_attrib_tokens('words')) == 1 else 0


def lfwordscontext_protein_desc(c):
    return -1 if 'protein' in get_left_tokens(c[0], attrib='lemmas') or 'protein' in get_right_tokens(c[0], attrib='lemmas') else 0
def lfwordscontext_protein_ent(c):
    return -1 if 'protein' in get_left_tokens(c[1], window=2, attrib='lemmas') or 'protein' in get_right_tokens(c[1], window=2, attrib='lemmas') else 0
def lfwordsin_protein(c):
    return -1 if contains_token(c[1], 'protein', attrib='lemmas') or contains_token(c[0], 'protein', attrib='lemmas') else 0


# lf1 is still needed by lf12 below
def lf1(c):
    return 1 if 'in' in get_between_tokens(c, attrib='words') else 0
def lf21(c):
    return rule_regex_search_btw_BA(c, '.* in .*', 1)

def lf2(c):
    return 1 if len(cause_words.intersection(set(get_between_tokens(c, attrib='lemmas')))) > 0 else 0


def lf6(c):
    return 1 if len(helper_vbs.intersection(set(get_between_tokens(c, attrib='lemmas', n_max=3)))) > 0 else 0

#def lf7(c):
#    return -1 if 'not' in get_between_tokens(c) else 0

#def lf8(c):
#    return -1 if 'not' in get_left_tokens(c[0]) or 'not' in get_left_tokens(c[1]) else 0

#def lf9(c):
#    return -1 if 'level' in get_left_tokens(c[0], attrib='lemmas', n_max=2) or 'level' in get_right_tokens(c[0], attrib='lemmas', n_max=2) else 0

#def lf10(c):
#    return -1 if 'transcript' in get_left_tokens(c[0], attrib='lemmas', n_max=3) or 'transcript' in get_right_tokens(c[0], attrib='lemmas', n_max=2) else 0

def lf12(c):
    return 1 if inverted(c) and lf1(c) else 0



#def lf16(c):
#    return -1 if 'activity' in get_left_tokens(c[0], attrib='lemmas', n_max=2) or 'level' in get_right_tokens(c[0], attrib='lemmas', n_max=1) else 0


def LF_phenotype_dp(c):
    return 1 if 'phenotype' in get_right_tokens(c[1], window=2, attrib='lemmas') else 0

#def LF_dev_dp(c):
#    return -1 if 'development' in get_right_tokens(c[1], window=2, attrib='lemmas')  else 0
def LF_network_dp(c):
    return -1 if 'network' in get_right_tokens(c[1], window=2, attrib='lemmas') else 0

def lf_helpers(c):
    return 1 if any(word in get_left_tokens(c, window=2, attrib='words') for word in ['had', 'has', 'was', 'have', 'showed', 'were', 'is', 'are', 'results']) else 0

def lf22(c):
    return -1 if 'expression' in get_right_tokens(c[0], attrib='lemmas', window=2) or 'expression' in get_left_tokens(c[0], attrib='lemmas', window=2) else 0
    
def lf23(c):
    return -1 if not inverted(c) and len(helper_vbs.intersection(set(get_right_tokens(c[0], window = 1, attrib='lemmas')))) > 0 and ('VBN' == c[0].get_attrib_tokens('pos_tags')[0] or ('VBN' == c[0].get_attrib_tokens('pos_tags')[1] and 'RB' == c[0].get_attrib_tokens('pos_tags')[0])) else 0

def lf32(c):
     return 1 if any([word in get_left_tokens(c, window=4, attrib='words') for word in ['is', 'are']]) else 0
        
def lf33(c):
    return 1 if any([word in get_left_tokens(c, window=4, attrib='words') for word in ['results', 'affected']]) else 0

def lf35(c):
    return 1 if any([word in get_left_tokens(c, window=4, attrib='words') for word in ['showed', 'were', 'was']]) else 0

#POS

def LF_LRB_Context(c):
    return -1 if '-RRB-' in get_right_tokens(c[0], window=1, attrib='pos_tags') or '-RRB-' in get_right_tokens(c[1], window=1, attrib='pos_tags') else 0
def LF_LRB_Contains(c):
    return -1 if '-LRB-' == c[0].get_attrib_tokens('pos_tags')[0] or '-LRB-' == c[1].get_attrib_tokens('pos_tags')[0] else 0
def LF_RRB(c):
    return -1 if '-LRB-' in get_right_tokens(c[0], window=1, attrib='pos_tags') or '-LRB-' in get_right_tokens(c[1], window=1, attrib='pos_tags') else 0
def LF_JJR(c):
    return 1 if contains_token(c, 'JJR', attrib='pos_tags') else 0

def LF_ModPhrase(c):
    # Inverted pair whose second span starts with a helper verb followed by a modifier
    if is_inverted(c):
        lemmas = c[1].get_attrib_tokens('lemmas')
        pos_tags = c[1].get_attrib_tokens('pos_tags')
        if len(pos_tags) > 1 and lemmas[0] in helper_vbs and pos_tags[1] in ['JJR', 'VBN', 'JJ', 'RBR', 'RB']:
            return 1
    return 0

def LF_JJ(c):
    return 1 if 'JJ' in get_right_tokens(c, attrib='pos_tags') else 0
def LF_IN(c):
    return 1 if 'IN' in get_right_tokens(c, window=1, attrib='pos_tags') else 0
   
def LF_NNP(c):
    return -1 if contains_token(c, 'NNP', attrib='pos_tags') else 0


def lf13(c):
    return 1 if inverted(c) and 'IN' in get_between_tokens(c, attrib='pos_tags', n_max=4) else 0

def LF_JJ_dp(c):
    return -1 if 'JJ' in get_right_tokens(c[1], window=2, attrib='pos_tags') else 0

def lf20(c):
    lemmas = c[0].get_attrib_tokens('lemmas')
    poses = c[0].get_attrib_tokens('pos_tags')
    result = 0
    # Iterate over the POS tags; lemmas[i] is the lemma at the same position
    for i, w in enumerate(poses):
        if w in ['NN', 'NNS', 'NNP', 'NNPS'] and not re.match(r'\w+(ion|ment|vity)', lemmas[i]):
            result = 0
        elif re.match(r'\w+(ion|ment|vity)', lemmas[i]):
            result = -1
    return result


def lf24(c):
    return -1 if not contains_token(c[0], 'VB', attrib='pos_tags') and not contains_token(c[0], 'VBZ', attrib='pos_tags') and not contains_token(c[0], 'VBD', attrib='pos_tags') else 0

def lf25(c):
    return -1 if 'IN' == c[0].get_attrib_tokens('pos_tags')[0] or 'TO' == c[0].get_attrib_tokens('pos_tags')[0] else 0

def lf26(c):
    if len(c[0].get_attrib_tokens('pos_tags')) < 2:
        return 0
    return -1 if 'JJR' == c[0].get_attrib_tokens('pos_tags')[0] and len(set(['NN', 'NNS', 'NNP', 'NNPS']).intersection(set(c[0].get_attrib_tokens('pos_tags')[1]))) == 0 else 0

def lf27(c):
    hasNoNoun = len(set(['NN', 'NNS', 'NNP', 'NNPS']).intersection(set(c[0].get_attrib_tokens('pos_tags')[0]))) == 0
    return -1 if (len(c[0]) < 3 and hasNoNoun) else 0
                  
def lf28(c):
    if len(c[0].get_attrib_tokens('pos_tags')) == 0:
        return 0
    if len(c[0].get_attrib_tokens('lemmas')) < 2:
        return 0
    lastWordAdj = True if c[0].get_attrib_tokens('pos_tags')[-1] in set(['JJ', 'JJR']) else False
    nextLastVrb = True if c[0].get_attrib_tokens('lemmas')[-2] in helper_vbs else False
    return -1 if not nextLastVrb and lastWordAdj else 0              
    
def lf29(c):                  #if pheno ends in VBG, its bad
    return -1 if c[0].get_attrib_tokens('pos_tags')[-1] == 'VBG' else 0

def lf30(c):                #if ends in prep, its bad
    return -1 if c[0].get_attrib_tokens('pos_tags')[-1] == 'IN' else 0

def lf29b(c):                  # if the second span (c[1]) ends in VBG, it's bad
    return -1 if c[1].get_attrib_tokens('pos_tags')[-1] == 'VBG' else 0

def lf30b(c):                  # if the second span (c[1]) ends in a preposition, it's bad
    return -1 if c[1].get_attrib_tokens('pos_tags')[-1] == 'IN' else 0

def lf30c(c):                  # if the second span (c[1]) ends in an adjective, it's bad
    return -1 if c[1].get_attrib_tokens('pos_tags')[-1] in ['JJ', 'JJR', 'JJS'] else 0


## Test how to query candidates

In [None]:
from snorkel.models.context import TemporaryContext
import re

print docs[15]
sent = docs[15].get_parent()
print sent
text = sent._asdict()['text']
splt = text.split()
print splt
print splt[4:5]
print "\n"
print "REGEX VERSION: "

resplit = re.split(' ',text)
print resplit
print resplit[4:5]
print "\n"

print docs[15].get_contexts()
print (docs[15][0]).get_attrib_tokens('words')
print len((docs[15][0]).get_attrib_tokens('words'))
# print (docs[15][0]).get_attrib_tokens('dep_parents')

#print LF_DP(docs[0])
print get_text_splits(docs[15])

## Running LFs on the Training Set

In [None]:
LFs = [
    LF_mutant,
    LF_variant,
    LF_express,
    LF_JJ,
    LF_IN,
    LF_dna,
    LF_rna,
    LF_snp,
    LF_protein,
    LF_LRB,
    LF_RRB,
    LF_NNP,
    lfdistBtw0,
    lfdistBtw,
    lfdistBtwNeg,
    lfdistBtwNeg2,
    lf1,
    lf2,
    lf3,
    lf4,
    lf5,
    lf6,
    lf7,
    lf8,
    lf9,
    lf10,
    lf11,
    lf2a,
    lf12,
    lf13,
    lf14,
    lf15,
    lf16,
    lf17,
    lf18,
    lf20,
    lf21,
    lf22,
    lf23,
    lf24,
    lf25,
    lf26,
    lf27,
    lf29,
    lf30,
#     lf29b,
#     lf30b,
#     lf30c,
    lf31,
    lf32,
    lf33,
    lf34,
    lf35,
    LF_phenotype_dp,
    LF_mutation,
    LF_pheno,
    LF_dev_dp,
    LF_protein_dp,
    LF_network_dp,
    LF_JJ_dp,
    lf_helpers,
    lf_adjwords,
    lf_statswords,
    lf_proteinIn,
    lfnegWords,
    resultsAlone,
    lfLenCand,
]

In [None]:
import re
import os
from snorkel.lf_helpers import (
    get_left_tokens,
    get_between_tokens,
    get_right_tokens,
    contains_token,
    get_text_between,
    get_text_splits,
    get_tagged_text,
    is_inverted,
    rule_regex_search_tagged_text,
    rule_regex_search_btw_AB,
    rule_regex_search_btw_BA,
    rule_regex_search_before_A,
    rule_regex_search_before_B,    
)

def distance_btwn(c):
    # Number of words strictly between the two spans (0 if they overlap)
    span0 = c[0]
    span1 = c[1]
    indices0 = set(np.arange(span0.get_word_start(), span0.get_word_end() + 1))
    indices1 = set(np.arange(span1.get_word_start(), span1.get_word_end() + 1))
    if len(indices0.intersection(indices1)) > 0: return 0
    if span0.get_word_start() < span1.get_word_start():
        # span0 comes first in the sentence
        return span1.get_word_start() - span0.get_word_end() - 1
    else:
        # span1 comes first in the sentence
        return span0.get_word_start() - span1.get_word_end() - 1
    
def overlap(c):
    span0 = c[0]
    span1 = c[1]
    indices0 = set(np.arange(span0.get_word_start(), span0.get_word_end() + 1))
    indices1 = set(np.arange(span1.get_word_start(), span1.get_word_end() + 1))
    if len(indices0.intersection(indices1)) > 0: return 1
    return 0

#DISTANCE RULES
def LF(c):
    return 1 if distance_btwn(c) <2 or overlap(c) else -1
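
The distance rule in the cell above can be sanity-checked without a database by working on plain word-index pairs. The sketch below assumes each span is an inclusive `(start, end)` tuple of word positions; `token_distance` and `distance_label` are hypothetical stand-ins for `distance_btwn` and `LF`, not Snorkel functions.

```python
def token_distance(span_a, span_b):
    """Words strictly between two inclusive (start, end) spans; 0 on overlap."""
    a_start, a_end = span_a
    b_start, b_end = span_b
    if a_start <= b_end and b_start <= a_end:
        return 0  # the spans share at least one word
    if a_start < b_start:
        return b_start - a_end - 1
    return a_start - b_end - 1


def distance_label(span_a, span_b):
    """Vote +1 when the spans are fewer than two words apart, else -1."""
    return 1 if token_distance(span_a, span_b) < 2 else -1


print(token_distance((2, 3), (6, 8)))  # words 4 and 5 lie in between
print(token_distance((6, 8), (2, 3)))  # argument order does not matter
print(distance_label((2, 3), (4, 5)))  # adjacent spans
print(distance_label((2, 3), (9, 9)))  # far apart
```

Note that this rule never abstains: every candidate receives either +1 or -1, so its coverage is 1.0 by construction.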

In [None]:
LFs = [LF]

In [None]:
from snorkel.annotations import LabelAnnotator
import multiprocessing
labeler = LabelAnnotator(f=LFs)

In [None]:
%time L_train = labeler.apply(split=3, parallelism=multiprocessing.cpu_count())
L_train

In [None]:
L_train.lf_stats(session)

* <b>Coverage</b> is the fraction of candidates for which the labeling function emits a non-zero label.
* <b>Overlap</b> is the fraction of candidates for which the labeling function emits a non-zero label and at least one other labeling function also emits a non-zero label.
* <b>Conflict</b> is the fraction of candidates for which the labeling function emits a non-zero label and at least one other labeling function emits a different non-zero label.
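
These three statistics can be reproduced by hand on a small label matrix. The following is a minimal sketch, assuming a dense list-of-lists matrix with one row per candidate and one column per labeling function; `lf_summary` is a hypothetical helper, not the implementation behind `lf_stats`.

```python
def lf_summary(label_matrix):
    """Coverage, overlap and conflict per LF for labels in {-1, 0, 1}."""
    n_candidates = len(label_matrix)
    n_lfs = len(label_matrix[0])
    stats = []
    for j in range(n_lfs):
        covered = overlapped = conflicted = 0
        for row in label_matrix:
            if row[j] == 0:
                continue  # this LF abstained on the candidate
            covered += 1
            # Non-zero votes from every other LF on the same candidate
            others = [row[k] for k in range(n_lfs) if k != j and row[k] != 0]
            if others:
                overlapped += 1
            if any(vote != row[j] for vote in others):
                conflicted += 1
        stats.append({
            "coverage": covered / float(n_candidates),
            "overlap": overlapped / float(n_candidates),
            "conflict": conflicted / float(n_candidates),
        })
    return stats


# Toy matrix: 4 candidates (rows) x 3 labeling functions (columns)
L_toy = [
    [1, 1, 0],
    [1, -1, 0],
    [0, 0, -1],
    [1, 0, 0],
]
for j, s in enumerate(lf_summary(L_toy)):
    print(j, s)
```

An LF with high coverage but high conflict is a good first candidate for inspection in the Viewer.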

## Generative Model

In [None]:
from snorkel.learning.structure import DependencySelector
ds = DependencySelector()
deps = ds.select(L_train, threshold=0.1)
len(deps)

In [None]:
deps # (lf, lf, relationship_type)

In [None]:
from snorkel.learning import GenerativeModel

gen_model = GenerativeModel(lf_propensity=True)
gen_model.train(
    L_train, deps=deps, epochs=2, decay=0.95, step_size=0.1/L_train.shape[0],
    init_acc=2.0, reg_param=0.0,
    verbose=True
)

We now apply the generative model to the training candidates to get the noise-aware training label set. We'll refer to these as the training marginals:

In [None]:
train_marginals = gen_model.marginals(L_train)
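
To build intuition for what a marginal is, the sketch below combines labeling function votes under a deliberately simplified model: each LF is assumed to be independent, to have a known fixed accuracy, and the class prior is uniform. This is not how `GenerativeModel` works internally (it learns accuracies and can model dependencies); `naive_marginal` and the accuracy values are assumptions for illustration only.

```python
import math


def naive_marginal(labels, accuracies):
    """P(candidate is true) from LF votes in {-1, 0, 1}.

    Assumes independent LFs with known accuracies and a uniform prior,
    so each non-zero vote shifts the log-odds by log(acc / (1 - acc)).
    """
    log_odds = 0.0
    for label, acc in zip(labels, accuracies):
        log_odds += label * math.log(acc / (1.0 - acc))
    return 1.0 / (1.0 + math.exp(-log_odds))


toy_accuracies = [0.8, 0.7, 0.6]  # hypothetical LF accuracies
print(naive_marginal([0, 0, 0], toy_accuracies))    # all abstain
print(naive_marginal([1, 1, 0], toy_accuracies))    # two positive votes
print(naive_marginal([1, -1, 0], toy_accuracies))   # disagreement
```

Values near 0 or 1 mean the labeling functions agree strongly; values near 0.5 mean they abstain or cancel each other out, which is what the histogram of training marginals makes visible.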

In [None]:
import matplotlib.pyplot as plt
plt.hist(train_marginals, bins=20)
plt.show()

In [None]:
gen_model.weights.lf_accuracy()

In [None]:
from snorkel.annotations import save_marginals
save_marginals(session, L_train, train_marginals)

# START RUNNING HERE TO GET SAVED MARGINALS

In [None]:
from snorkel.annotations import load_marginals
train_marginals = load_marginals(session, split=3)

# TARA- DON'T RUN THESE CELLS

### Using the Model to Iterate on Labeling Functions
Now that we have learned the generative model, we can stop here and use this to potentially debug and/or improve our labeling function set. First, we apply the LFs to our development set:

In [None]:
L_dev = labeler.apply_existing(split=4, parallelism=multiprocessing.cpu_count())

In [None]:
from snorkel.annotations import load_gold_labels
L_gold_dev = load_gold_labels(session, annotator_name='gold_complex', split=4)

In [None]:
tp, fp, tn, fn = gen_model.score(session, L_dev, L_gold_dev)

In [None]:
from snorkel.viewer import SentenceNgramViewer

# NOTE: This if-then statement is only to avoid opening the viewer during automated testing of this notebook
# You should ignore this!
import os
if 'CI' not in os.environ:
    fpsv = SentenceNgramViewer(fp, session, height=400)
else:
    fpsv = None

In [None]:
fpsv

In [None]:
from snorkel.viewer import SentenceNgramViewer

# NOTE: This if-then statement is only to avoid opening the viewer during automated testing of this notebook
# You should ignore this!
import os
if 'CI' not in os.environ:
    fnsv = SentenceNgramViewer(fn, session, height=400)
else:
    fnsv = None

In [None]:
fnsv

In [None]:
from snorkel.viewer import SentenceNgramViewer

# NOTE: This if-then statement is only to avoid opening the viewer during automated testing of this notebook
# You should ignore this!
import os
if 'CI' not in os.environ:
    tpsv = SentenceNgramViewer(tp, session, height=400)
else:
    tpsv = None

In [None]:
tpsv

In [None]:
from snorkel.viewer import SentenceNgramViewer

# NOTE: This if-then statement is only to avoid opening the viewer during automated testing of this notebook
# You should ignore this!
import os
if 'CI' not in os.environ:
    tnsv = SentenceNgramViewer(tn, session, height=400)
else:
    tnsv = None

In [None]:
tnsv

At this point, we should be getting an F1 score of around 0.6 to 0.7 on the development set, which is pretty good! However, we should be very careful in interpreting this. Since we developed our labeling functions using this development set as a guide, and our generative model is composed of these labeling functions, we expect it to score very well here!
In fact, it is probably somewhat overfit to this set. However, this is fine: in the next tutorial, we'll train a more powerful end extraction model that will generalize beyond the development set, and that we will evaluate on a blind test set (i.e. one we never looked at during development).
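
For reference, the F1 score follows directly from the sizes of the error buckets. The sketch below uses made-up counts; `precision_recall_f1` is a hypothetical helper, and in the notebook the counts would come from `len(tp)`, `len(fp)` and `len(fn)`.

```python
def precision_recall_f1(n_tp, n_fp, n_fn):
    """Precision, recall and F1 from true/false positive and false negative counts."""
    predicted_pos = n_tp + n_fp
    actual_pos = n_tp + n_fn
    precision = n_tp / float(predicted_pos) if predicted_pos else 0.0
    recall = n_tp / float(actual_pos) if actual_pos else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Hypothetical bucket sizes, not results from this notebook
precision, recall, f1 = precision_recall_f1(n_tp=60, n_fp=30, n_fn=40)
print("P=%.3f R=%.3f F1=%.3f" % (precision, recall, f1))
```

True negatives do not enter the formula, which is why F1 is a useful summary when most candidates are negative.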

## Doing Some Error Analysis

At this point, we might want to look at some examples in one of the error buckets. For example, we can inspect the false positives, i.e. the candidates we labeled as true mentions that are not (pass `fn` instead of `fp` to inspect the false negatives). To do this, we can again just use the Viewer:

In [None]:
from snorkel.viewer import SentenceNgramViewer

# NOTE: This if-then statement is only to avoid opening the viewer during automated testing of this notebook
# You should ignore this!
import os
if 'CI' not in os.environ:
    sv = SentenceNgramViewer(fp, session, height=400)
else:
    sv = None

In [None]:
sv

In [None]:
from snorkel.lf_helpers import (
    get_left_tokens,
    get_between_tokens,
    get_right_tokens,
    contains_token,
    get_doc_candidate_spans,
    get_sent_candidate_spans,
    get_text_between,
    get_text_splits,
    get_tagged_text,
    is_inverted,
)
c = sv.get_selected() if sv else list(fp.union(fn))[0]
print(c)
print("\n")

c.labels
fp

## Automatically Creating Features

In [None]:
from snorkel.annotations import FeatureAnnotator
import multiprocessing
featurizer = FeatureAnnotator()

In [None]:
%time F_train = featurizer.apply(split=3, parallelism=multiprocessing.cpu_count())
F_train


Next, we apply the feature set we just got from the training set to the dev and test sets by using `apply_existing`:

In [None]:
%%time
F_dev  = featurizer.apply_existing(split=4, parallelism=multiprocessing.cpu_count())
F_test = featurizer.apply_existing(split=5, parallelism=multiprocessing.cpu_count())
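
The difference between `apply` and `apply_existing` can be pictured as fitting a feature vocabulary on the training split and then reusing it unchanged elsewhere. The sketch below is a plain-Python analogy under that assumption, not Snorkel's implementation; `fit_vocab` and `featurize_existing` are hypothetical helpers and the feature strings are invented.

```python
def fit_vocab(train_feature_lists):
    """Assign a column index to every feature name seen in the training split."""
    vocab = {}
    for feats in train_feature_lists:
        for name in feats:
            if name not in vocab:
                vocab[name] = len(vocab)
    return vocab

def featurize_existing(feature_lists, vocab):
    """Build sparse rows (dicts of column -> 1) using only features already in vocab."""
    rows = []
    for feats in feature_lists:
        rows.append({vocab[name]: 1 for name in feats if name in vocab})
    return rows

train = [["WORD_p53", "LEFT_gene"], ["WORD_BRCA1", "LEFT_gene"]]
dev = [["WORD_p53", "LEFT_page"]]  # "LEFT_page" never appeared in training

vocab = fit_vocab(train)
dev_rows = featurize_existing(dev, vocab)
print(len(vocab), dev_rows)
```

Features that never occurred in training are dropped, so the dev and test matrices share the training matrix's columns and a model trained on `F_train` can be applied to them directly.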

In [None]:
F_train = featurizer.load_matrix(session, split=3)
F_dev   = featurizer.load_matrix(session, split=4)
F_test  = featurizer.load_matrix(session, split=5)

## Training the Discriminative Model

We use the training marginals to train a discriminative model that classifies each Candidate as a true or false mention. We'll use a random hyperparameter search, evaluated on the development set labels, to find the best hyperparameters for our model. To run a hyperparameter search, we need labels for a development set. If they aren't already available, we can manually create labels using the Viewer.

In [None]:
from snorkel.learning import SparseLogisticRegression
disc_model = SparseLogisticRegression()
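
Training on marginals rather than hard labels means each candidate contributes to the loss according to how confident the generative model is about it. Here is a minimal sketch of that idea, assuming a logistic model and an expected log loss over soft labels. The function name and numbers are illustrative, and Snorkel's actual objective may differ in detail.

```python
import math

def soft_label_log_loss(marginal, score):
    """Expected logistic loss for one candidate.

    marginal: probability in [0, 1] that the candidate is a true mention.
    score:    the model's raw (pre-sigmoid) score for the candidate.
    """
    p = 1.0 / (1.0 + math.exp(-score))  # model's predicted probability
    return -(marginal * math.log(p) + (1.0 - marginal) * math.log(1.0 - p))

# The model predicts "probably false" (score = -2) for both candidates.
# A confident positive marginal is penalized more than an uncertain one.
confident = soft_label_log_loss(marginal=0.95, score=-2.0)
uncertain = soft_label_log_loss(marginal=0.55, score=-2.0)
print(confident > uncertain)
```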


Now we set up and run the hyperparameter search, training our model with different hyperparameters and picking the best model configuration to keep. We'll set the random seed to maintain reproducibility.
Note that we are fitting our model's parameters to the training set generated by our labeling functions, while we are picking hyperparameters with respect to the score on the development set labels, which we created by hand.

In [None]:
from snorkel.learning.utils import MentionScorer
from snorkel.learning import RandomSearch, ListParameter, RangeParameter

# Searching over learning rate and L1/L2 penalties, each on a log10 scale
rate_param = RangeParameter('lr', 1e-6, 1e-2, step=1, log_base=10)
l1_param  = RangeParameter('l1_penalty', 1e-6, 1e-2, step=1, log_base=10)
l2_param  = RangeParameter('l2_penalty', 1e-6, 1e-2, step=1, log_base=10)

searcher = RandomSearch(session, disc_model, F_train, train_marginals, [rate_param, l1_param, l2_param], n=20)
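
As a rough picture of what a random search over log-scaled ranges does, the sketch below samples configurations from a grid of powers of ten and keeps the best-scoring one. It is a stand-alone illustration and not Snorkel's `RandomSearch`: `log_grid`, `random_search`, and `toy_dev_score` are hypothetical names, and the toy objective merely stands in for "train on the training set, score on the dev set".

```python
import math
import random

def log_grid(lo_exp, hi_exp, base=10):
    """Values base**lo_exp, ..., base**hi_exp, one per integer exponent."""
    return [float(base) ** e for e in range(lo_exp, hi_exp + 1)]

def random_search(param_grids, evaluate, n, seed=1701):
    """Sample n configurations at random and keep the best-scoring one."""
    rng = random.Random(seed)
    best_score, best_config = None, None
    for _ in range(n):
        config = {name: rng.choice(values) for name, values in param_grids.items()}
        score = evaluate(config)
        if best_score is None or score > best_score:
            best_score, best_config = score, config
    return best_config, best_score

grids = {
    'lr': log_grid(-6, -2),
    'l1_penalty': log_grid(-6, -2),
    'l2_penalty': log_grid(-6, -2),
}

def toy_dev_score(config):
    """Toy stand-in for a dev-set score: prefers lr near 1e-3 (purely illustrative)."""
    return -abs(math.log10(config['lr']) + 3)

best_config, best_score = random_search(grids, toy_dev_score, n=20)
print(best_config, best_score)
```

Sampling on a log scale spreads trials evenly across orders of magnitude, which suits parameters such as learning rates and regularization strengths.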

Next, we'll load in our dev set labels. We will pick the optimal result from the hyperparameter search by testing against these labels:

In [None]:
from snorkel.annotations import load_gold_labels
L_gold_dev = load_gold_labels(session, annotator_name='gold_complex', split=4)

Finally, we run the hyperparameter search / train the end extraction model:

In [None]:
np.random.seed(1701)
searcher.fit(F_dev, L_gold_dev, n_epochs=50, rebalance=0.5, print_freq=25)


Note that to train a model without tuning any hyperparameters (at your own risk), just use the `train` method of the discriminative model. For instance, to train with 20 epochs and a learning rate of 0.001, you could run:

`disc_model.train(F_train, train_marginals, n_epochs=20, lr=0.001)`

We can analyze the learned model by examining the weights. For example, we can print out the features with the highest weights.

In [None]:
w, _ = disc_model.get_weights()
largest_idxs = reversed(np.argsort(np.abs(w))[-5:])
for i in largest_idxs:
    print('Feature: {0: <70}Weight: {1:.6f}'.format(F_train.get_key(session, i).name, w[i]))


In this last section of the tutorial, we'll get the score we've been after: the performance of the extraction model on the blind test set (split 5). First, we load the test set labels and gold candidates we made in Part III.

In [None]:
from snorkel.annotations import load_gold_labels
L_gold_test = load_gold_labels(session, annotator_name='gold_complex', split=5)

Now, we score using the discriminative model:

In [None]:
_, _, _, _ = disc_model.score(session, F_test, L_gold_test)


Note that if this is the test set on which you will report final numbers, you should not inspect individual results, to avoid biasing them. You can, however, run the model on your development set and, as we did in the previous part with the generative labeling function model, inspect examples to do error analysis.

## Training an LSTM extraction model

In the intro tutorial, we automatically featurized the candidates and trained a linear model over these features. Here, we'll train a more complicated model for relation extraction: an LSTM network. An LSTM is a type of recurrent neural network that automatically generates a numerical representation for the candidate based on the sentence text, so there is no need to featurize explicitly as in the intro tutorial. LSTMs take longer to train, and Snorkel doesn't currently support hyperparameter searches for them. We'll train a single model here, but feel free to try out other parameter sets. Just make sure to use the development set, and not the test set, for model selection.

In [None]:
dev_labels = (np.ravel(L_gold_dev.todense()) + 1) / 2
print(dev_labels)
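
The arithmetic in the cell above maps gold labels from the {-1, +1} convention of the label matrix to the {0, 1} convention that this notebook passes to the RNN as `dev_labels`. Here is a small sketch with made-up labels; a plain list stands in for the sparse matrix, and `to_binary` is a hypothetical helper.

```python
def to_binary(labels):
    """Map -1 -> 0 and +1 -> 1 via (x + 1) / 2, using integer division."""
    return [(x + 1) // 2 for x in labels]

gold = [1, -1, -1, 1]      # illustrative gold labels in {-1, +1}
print(to_binary(gold))     # [1, 0, 0, 1]
```

An unlabeled entry (0) would silently become 0 under integer division, or 0.5 under true division, so the conversion assumes every candidate in the split has a gold label.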

In [None]:
from snorkel.contrib.rnn import reRNN

train_kwargs = {
    'lr':         0.01,
    'dim':        100,
    'n_epochs':   50,
    'dropout':    0.5,
    'rebalance':  0.25,
    'print_freq': 5
}

lstm = reRNN(seed=1701, n_threads=None)
lstm.train(docs, train_marginals, dev_candidates=dev_docs, dev_labels=dev_labels, **train_kwargs)

### Trying to test... but it doesn't really work

In [None]:
from snorkel.annotations import load_gold_labels
L_gold_test = load_gold_labels(session, annotator_name='gold_complex', split=5)

test_labels = (np.ravel(L_gold_test.todense()) + 1) / 2
print(test_labels)
_, _, _, _ = lstm.score(session, test_docs, test_labels)