# Tagging phenotypes: learning and labeling function iteration

## Introduction
In this notebook, we build a phenotype tagger from scratch.

Here's the pipeline we'll follow:

1. Load extracted candidates for tagging
2. Write labeling functions
3. Learn the tagging (generative) model
4. Iterate on labeling functions
5. Generate features
6. Learn two discriminative models - LogReg and LSTM


This notebook requires candidates extracted from `Complex_Pheno_Extraction.ipynb` and gold labels extracted from `Complex_Pheno_BRAT_Import.ipynb`.

In [None]:
%reload_ext autoreload
%autoreload 2

# Notebook setup: standard imports, database selection via set_env, and a Snorkel session.
import os
import sys
import cPickle
import numpy as np
import matplotlib
# print(os.environ['SNORKELDB'])
# Use production DB
# set_env() presumably exports SNORKELDB for the production database — TODO confirm in set_env.py.
from set_env import set_env
set_env() 
# Make the local snorkel checkout importable.
sys.path.insert(1, '../snorkel')

# Must set SNORKELDB before importing SnorkelSession
from snorkel import SnorkelSession
from snorkel.parser import TextDocPreprocessor
from snorkel.parser import CorpusParser
from snorkel.models import Document, Sentence, candidate_subclass
from snorkel.viewer import SentenceNgramViewer
session = SnorkelSession()

# NOTE(review): the global numpy seed is not set here; the LSTM cell seeds with np.random.seed(432).
#np.random.seed(seed=1701)

%matplotlib inline
matplotlib.rcParams['figure.figsize'] = (18,6)

## Loading candidate extractions
First, we'll load the candidates that we created in `Complex_Pheno_Extraction.ipynb` and split them into train, dev and test sets.

In [None]:
# Candidate type: a pair of spans named 'descriptor' and 'entity'. The labeling
# functions below index them as c[0] and c[1] (presumably in this order — confirm).
PhenoPair = candidate_subclass('ComplexPhenotypes', ['descriptor', 'entity'])

In [None]:
#should ultimately edit splits to be 0,1,and,2.
train = session.query(PhenoPair).filter(PhenoPair.split == 3).all()  
dev = session.query(PhenoPair).filter(PhenoPair.split == 4).all()
test = session.query(PhenoPair).filter(PhenoPair.split == 5).all()

print "Documents:", session.query(Document).count()
print "Sentences:", session.query(Sentence).count()

print 'Train Document Set:\t{0} candidates'.format(len(train))
print 'Dev Document Set:\t{0} candidates'.format(len(dev))
print 'Test Document Set:\t{0} candidates'.format(len(test))

## Make Labeling Functions

In [None]:
from snorkel.models.context import TemporaryContext
import re
import os
from snorkel.lf_helpers import (
    get_left_tokens,
    get_between_tokens,
    get_right_tokens,
    contains_token,
    get_text_between,
    get_text_splits,
    get_tagged_text,
    is_inverted,
    get_tagged_text,
    rule_regex_search_tagged_text,
    rule_regex_search_btw_AB,
    rule_regex_search_btw_BA,
    rule_regex_search_before_A,
    rule_regex_search_before_B,    
)

#DICTIONARIES
# Word/lemma sets consulted by the labeling functions below.
# Note: 'result in' is a two-word entry and cannot equal a single lemma token.
cause_words = {'affect', 'lead', 'led', 'show', 'display', 'exhibit', 'cause', 'result in'}
mutant_words = {'mutant', 'mutation', 'plant', 'line', 'phenotype', 'seedling', 'variant'}
helper_vbs = {'is', 'was', 'are', 'were', 'become', 'became', 'has', 'had'}
tester_words = {
    'sequence', 'published', 'diagram', 'hypothesis', 'hypothesize', 'aim',
    'goal', 'understand', 'examine', 'we', 'our', 'experiment', 'test',
    'study', 'design', 'analyze', 'analysis', 'research',
}
neg_words = {'strategy', 'public', 'examine', 'measure', 'subject', 'statistic', 'instance'}
adj_words = {
    'increase', 'low', 'reduce', 'high', 'less', 'more', 'elevate',
    'decrease', 'insensitive', 'absence', 'inhibit', 'double',
}
stats_words = {'statistically', 'quantitative', 'qualitative', 'real-time', 'generate', 'expose', 'stratify'}
cc_words = {'while', 'but', 'however', 'whereas'}
comp_words = {'compare', 'relative', 'than', 'same', 'different', 'relatively', 'contrast', 'similar'}

#HELPERS
def inverted(c):
    """Return snorkel's is_inverted(c) as an int: 1 when truthy, otherwise 0."""
    return int(bool(is_inverted(c)))

def distance_btwn(c):
    """Number of word tokens strictly between the two argument spans of c.

    Returns 0 when the spans overlap (share a word index) or are adjacent.
    Works regardless of which span comes first in the sentence.
    """
    start0, end0 = c[0].get_word_start(), c[0].get_word_end()
    start1, end1 = c[1].get_word_start(), c[1].get_word_end()
    # Inclusive word ranges [start, end] intersect iff each starts before the other ends.
    # This replaces building two numpy index sets; the unused `left_span` local is gone.
    if start0 <= end1 and start1 <= end0:
        return 0
    if start0 < start1:
        return start1 - end0 - 1
    return start0 - end1 - 1
    
def overlap(c):
    """Return 1 if the two argument spans of c share at least one word index, else 0."""
    first_words = set(np.arange(c[0].get_word_start(), c[0].get_word_end() + 1))
    second_words = set(np.arange(c[1].get_word_start(), c[1].get_word_end() + 1))
    return 1 if first_words & second_words else 0

def ends_in(ci, val, attrib):
    """True iff the last `attrib` token (e.g. 'words', 'lemmas') of span ci equals val."""
    tokens = ci.get_attrib_tokens(attrib)
    return val == tokens[-1]

def starts_with(ci, val, attrib):
    """True iff the first `attrib` token of span ci equals val."""
    tokens = ci.get_attrib_tokens(attrib)
    return val == tokens[0]


#DISTANCE RULES
def lfdistBtw0(c):
    """+1 when no tokens separate the two spans (adjacent or overlapping), else 0."""
    return int(distance_btwn(c) == 0)
def lfdistBtwMax1(c):
    """+1 when at most one token separates the spans, else 0."""
    return int(distance_btwn(c) <= 1)
def lfdistBtwMax2(c):
    """+1 when at most two tokens separate the spans, else 0."""
    return int(distance_btwn(c) <= 2)
def lfdistBtwnOverlap(c):
    """+1 when the spans share a word position, else 0 (delegates to overlap)."""
    return overlap(c)

# Negative distance rules: "MinN" votes -1 when at least N tokens separate the spans,
# matching the inclusive convention of MaxN (distance <= N) above.
def lfdistBtwMin5(c):
    """-1 when at least 5 tokens separate the spans, else 0."""
    return -1 if distance_btwn(c) >= 5 else 0
def lfdistBtwMin8(c):
    """-1 when at least 8 tokens separate the spans, else 0."""
    # Was `> 8` (i.e. >= 9), inconsistent with Min5 (> 4) and Min12 (> 11).
    return -1 if distance_btwn(c) >= 8 else 0
def lfdistBtwMin12(c):
    """-1 when at least 12 tokens separate the spans, else 0."""
    return -1 if distance_btwn(c) >= 12 else 0
def lfdistBtwMin14(c):
    """-1 when at least 14 tokens separate the spans, else 0."""
    # Was `> 14` (i.e. >= 15); aligned with the name like the other MinN rules.
    return -1 if distance_btwn(c) >= 14 else 0

#LENGTH RULES
def lfLenCand(c):
    """-1 when the first argument span c[0] is a single word, else 0."""
    n_words = len(c[0].get_attrib_tokens('words'))
    if n_words != 1:
        return 0
    return -1

#PAIRWISE RULES
def lfend_prep(c):
    """-1 when, for non-overlapping spans, the leftmost span ends in a preposition (IN)
    and none of the rightmost span's first two POS tags is a noun or determiner.
    Returns 0 otherwise, including whenever the spans overlap."""
    if overlap(c):
        return 0
    first_is_left = c[0].get_word_start() < c[1].get_word_start()
    left, right = (c[0], c[1]) if first_is_left else (c[1], c[0])
    if left.get_attrib_tokens('pos_tags')[-1] != 'IN':
        return 0
    right_head = set(right.get_attrib_tokens('pos_tags')[:2])
    return 0 if right_head & set(['NN', 'NNS', 'NNP', 'NNPS', 'DT']) else -1
    
def lfend_det(c):
    """-1 when, for non-overlapping spans, the leftmost span ends in a determiner (DT)
    and the rightmost span does not start with a noun tag; 0 otherwise."""
    if overlap(c):
        return 0
    first_is_left = c[0].get_word_start() < c[1].get_word_start()
    left, right = (c[0], c[1]) if first_is_left else (c[1], c[0])
    if left.get_attrib_tokens('pos_tags')[-1] != 'DT':
        return 0
    return 0 if right.get_attrib_tokens('pos_tags')[0] in ['NN', 'NNS', 'NNP', 'NNPS'] else -1

  
def lfend_adj(c):
    """-1 when the leftmost span ends in an adjective/participle (JJ, JJR, VBN) and no
    noun tag follows it in get_right_tokens(..., n_max=2); 0 otherwise.

    Unlike lfend_prep/lfend_det, overlapping spans are not skipped.
    """
    # The former `right` local was computed but never used; removed.
    left = c[0] if c[0].get_word_start() < c[1].get_word_start() else c[1]
    if left.get_attrib_tokens('pos_tags')[-1] not in ['JJ', 'JJR', 'VBN']:
        return 0
    # NOTE(review): if Snorkel's get_right_tokens lowercases tokens by default
    # (case_sensitive=False), these uppercase noun tags can never match — confirm.
    next_tags = set(get_right_tokens(left, attrib='pos_tags', n_max=2))
    if next_tags & set(['NN', 'NNS', 'NNP', 'NNPS']):
        return 0
    return -1
        
#CONTAINS RULES
def lfin_fig(c):
    """-1 when the first span contains a figure-reference word token (Fig, FIG., ...), else 0."""
    # 'Fig' used to be tested twice; each variant is now checked once, in the original order.
    fig_tokens = ('Fig', 'FIG', 'fig', 'FIG.', 'fig.', 'Fig.')
    return -1 if any(contains_token(c[0], tok, attrib='words') for tok in fig_tokens) else 0

def lfin_num(c):
    """+1 when either span contains a cardinal-number (CD) POS tag, else 0."""
    if contains_token(c[0], 'CD', attrib='pos_tags'):
        return 1
    return 1 if contains_token(c[1], 'CD', attrib='pos_tags') else 0

def lfin_equals(c):
    """-1 when either span contains an '=' word token, else 0."""
    if contains_token(c[0], '=', attrib='words'):
        return -1
    return -1 if contains_token(c[1], '=', attrib='words') else 0

def lfin_comp_adj(c):
    """+1 when either span contains a comparative adjective (JJR), else 0."""
    if contains_token(c[0], 'JJR', attrib='pos_tags'):
        return 1
    return 1 if contains_token(c[1], 'JJR', attrib='pos_tags') else 0
    
def lfin_comp_advb(c):
    """+1 when either span contains a comparative adverb (RBR), else 0."""
    if contains_token(c[0], 'RBR', attrib='pos_tags'):
        return 1
    return 1 if contains_token(c[1], 'RBR', attrib='pos_tags') else 0

#BETWEEN RULES
def lfbtwn_is(c):
    """+1 when a helper-verb lemma (helper_vbs) appears among the tokens between the spans
    (get_between_tokens with n_max=4), else 0."""
    # A second return (requiring distance_btwn(c) == 1) followed here and was unreachable; removed.
    between = set(get_between_tokens(c, attrib='lemmas', n_max=4))
    return 1 if helper_vbs & between else 0
                                             
def lfbtwm_comma(c):
    """-1 when a comma or semicolon lies between the two spans (in either order), else 0."""
    # The previous pattern '*[,;]*' is glob-style: in a regex a leading '*' is a
    # quantifier, not a wildcard. It could not express "anything, a comma, anything".
    # Presumably the rule_regex_search_btw_* helpers embed this pattern between the
    # {{A}}/{{B}} tags of the tagged sentence — confirm in snorkel.lf_helpers.
    pattern = '.*[,;].*'
    if rule_regex_search_btw_AB(c, pattern, -1) == -1:
        return -1
    if rule_regex_search_btw_BA(c, pattern, -1) == -1:
        return -1
    return 0
                                             
def lfbtwn_parenthesis(c):
    """-1 when the tagged sentence text matches one of the patterns below, where a
    parenthesised region holds one tagged span ({{A}}/{{B}}) and the other span lies
    outside it (roughly); 0 otherwise. Matching is case-insensitive."""
    tagged = get_tagged_text(c)
    patterns = (
        r'\([^\)]*{{A}}.*\).*{{B}}',
        r'\([^\)]*{{B}}.*\).*{{A}}',
        r'\{{A}}.*\([^\)]*{{B}}.*\)',
        r'\{{B}}.*\([^\)]*{{A}}.*\)',
    )
    for pattern in patterns:
        if re.search(pattern, tagged, flags=re.I):
            return -1
    return 0
    
                                             
#WORD BASED RULES


#WORDS IN CAND RULES
def LF_dna(c):
    """-1 when any span of c contains the word 'DNA', else 0."""
    if contains_token(c, 'DNA', attrib='words'):
        return -1
    return 0
def LF_rna(c):
    """-1 when any span of c contains the word 'RNA', else 0."""
    if contains_token(c, 'RNA', attrib='words'):
        return -1
    return 0
def LF_snp(c):
    """-1 when any span of c contains the word 'SNP', else 0."""
    if contains_token(c, 'SNP', attrib='words'):
        return -1
    return 0

def lfwordis_result(c):
    """-1 when either span is a single token whose lemma is 'result', else 0."""
    first = c[0]
    if len(first.get_attrib_tokens('words')) == 1 and contains_token(first, 'result', attrib='lemmas'):
        return -1
    second = c[1]
    if len(second.get_attrib_tokens('lemmas')) == 1 and contains_token(second, 'result', attrib='lemmas'):
        return -1
    return 0

def lfwordsin_percent(c):
    """+1 when the candidate mentions a fold change or a percentage, else 0.

    'fold' and 'percent' are checked as whole tokens via contains_token. A
    percentage is any word token matching digits, an optional decimal part, and '%'.
    """
    if contains_token(c, r'fold') or contains_token(c, 'percent'):
        return 1
    # contains_token matches whole tokens (presumably by equality — confirm in
    # snorkel.lf_helpers), so the percentage regex that used to be passed to it
    # only matched a token spelled exactly like the pattern. Apply it as a regex.
    percent_re = re.compile(r'\d+(\.\d+)?%$')
    for span in c.get_contexts():
        for word in span.get_attrib_tokens('words'):
            if percent_re.match(word):
                return 1
    return 0

def lfwordsin_phenotype(c):
    """+1 when any span of c contains the lemma 'phenotype', else 0."""
    has_phenotype = contains_token(c, 'phenotype', attrib='lemmas')
    return 1 if has_phenotype else 0

def lfwordsin_testerwords(c):
    """-1 when the candidate's sentence contains any study/experiment-design word
    from tester_words (case-insensitive, punctuation ignored), else 0."""
    text = c.get_parent()._asdict()['text']
    # Previously text.split() was compared case-sensitively with punctuation still
    # attached, so e.g. sentence-initial "We" or "study," never matched.
    words = set(re.findall(r"[\w\-]+", text.lower()))
    return -1 if tester_words & words else 0

#def lfwordsin_statistically(c):
#    return 1 if 'statistically' in c.get_parent()._asdict()['text'].split() else 0

def lfwordsin_compwords(c):
    """+1 when any comp_words lemma occurs in the candidate, else 0."""
    return 1 if any(contains_token(c, w, attrib='lemmas') for w in comp_words) else 0

def lfwordsin_negwords(c):
    """-1 when any neg_words lemma occurs in the candidate, else 0."""
    return -1 if any(contains_token(c, w, attrib='lemmas') for w in neg_words) else 0
def lfwordsin_causewords(c):
    """+1 when any cause_words lemma occurs in the candidate, else 0."""
    return 1 if any(contains_token(c, w, attrib='lemmas') for w in cause_words) else 0
def lfwordsin_adjwords(c):
    """+1 when any adj_words lemma occurs in the candidate, else 0."""
    return 1 if any(contains_token(c, w, attrib='lemmas') for w in adj_words) else 0
def lfwordsin_statswords(c):
    """-1 when any stats_words lemma occurs in the candidate, else 0."""
    return -1 if any(contains_token(c, w, attrib='lemmas') for w in stats_words) else 0

#WORDS IN CONTEXT
def lfwordscontext_mutant(c):
    """+1 when a mutant_words lemma is in the default left or right window of either span."""
    for span in (c[0], c[1]):
        for window_of in (get_left_tokens, get_right_tokens):
            if mutant_words.intersection(window_of(span, attrib='lemmas')):
                return 1
    return 0
def lfwordsbtwn_mutant(c):
    """+1 when a mutant_words lemma appears between the spans (n_max=4), else 0."""
    between = set(get_between_tokens(c, attrib='lemmas', n_max=4))
    return 1 if mutant_words & between else 0

def LF_variant(c):
    """+1 when 'variant' is a lemma in the right or left context window of c, else 0."""
    if 'variant' in get_right_tokens(c, attrib='lemmas'):
        return 1
    return 1 if 'variant' in get_left_tokens(c, attrib='lemmas') else 0
def LF_express(c):
    """+1 when 'express' is a lemma in the right or left context window of c, else 0."""
    if 'express' in get_right_tokens(c, attrib='lemmas'):
        return 1
    return 1 if 'express' in get_left_tokens(c, attrib='lemmas') else 0
#def lfLenCand(c):
#    return -1 if len(c[0].get_attrib_tokens('words')) == 1 or len(c[1].get_attrib_tokens('words')) == 1 else 0


def lfwordscontext_protein_desc(c):
    """-1 when 'protein' is a lemma in the default left/right window of the first span."""
    first = c[0]
    if 'protein' in get_left_tokens(first, attrib='lemmas'):
        return -1
    return -1 if 'protein' in get_right_tokens(first, attrib='lemmas') else 0
def lfwordscontext_protein_ent(c):
    """-1 when 'protein' is among the two lemmas left or right of the second span."""
    second = c[1]
    if 'protein' in get_left_tokens(second, window=2, attrib='lemmas'):
        return -1
    return -1 if 'protein' in get_right_tokens(second, window=2, attrib='lemmas') else 0
def lfwordsin_protein(c):
    """-1 when either span contains the lemma 'protein', else 0."""
    if contains_token(c[1], 'protein', attrib='lemmas'):
        return -1
    return -1 if contains_token(c[0], 'protein', attrib='lemmas') else 0


#def lf1(c):
    #return 1 if 'in' in get_between_tokens(c, attrib='words') else 0
#def lf21(c):
    #return rule_regex_search_btw_BA(c, '.* in .*', 1)

def lf2(c):
    """+1 when a cause_words lemma appears between the two spans, else 0."""
    between = set(get_between_tokens(c, attrib='lemmas'))
    return 1 if between & cause_words else 0


def lf6(c):
    """+1 when a helper_vbs lemma appears between the spans (n_max=3), else 0."""
    between = set(get_between_tokens(c, attrib='lemmas', n_max=3))
    return 1 if between & helper_vbs else 0

#def lf7(c):
#    return -1 if 'not' in get_between_tokens(c) else 0

#def lf8(c):
#    return -1 if 'not' in get_left_tokens(c[0]) or 'not' in get_left_tokens(c[1]) else 0

#def lf9(c):
#    return -1 if 'level' in get_left_tokens(c[0], attrib='lemmas', n_max=2) or 'level' in get_right_tokens(c[0], attrib='lemmas', n_max=2) else 0

#def lf10(c):
#    return -1 if 'transcript' in get_left_tokens(c[0], attrib='lemmas', n_max=3) or 'transcript' in get_right_tokens(c[0], attrib='lemmas', n_max=2) else 0

#def lf12(c):
#    return 1 if inverted(c) and lf1(c) else 0



#def lf16(c):
#    return -1 if 'activity' in get_left_tokens(c[0], attrib='lemmas', n_max=2) or 'level' in get_right_tokens(c[0], attrib='lemmas', n_max=1) else 0


def LF_phenotype_dp(c):
    """+1 when 'phenotype' is among the two lemmas right of the second span, else 0."""
    following = get_right_tokens(c[1], window=2, attrib='lemmas')
    return 1 if 'phenotype' in following else 0

#def LF_dev_dp(c):
#    return -1 if 'development' in get_right_tokens(c[1], window=2, attrib='lemmas')  else 0
#def LF_network_dp(c):
#    return -1 if 'network' in get_right_tokens(c[1], window=2, attrib='lemmas') else 0

def lf_helpers(c):
    """+1 when a helper/reporting verb word occurs in the two words left of c, else 0."""
    left_words = get_left_tokens(c, window=2, attrib='words')
    cues = ['had', 'has', 'was', 'have', 'showed', 'were', 'is', 'are', 'results']
    return 1 if any(cue in left_words for cue in cues) else 0

#def lf22(c):
#    return -1 if 'expression' in get_right_tokens(c[0], attrib='lemmas', window=2) or 'expression' in get_left_tokens(c[0], attrib='lemmas', window=2) else 0
    
def lf23(c):
    """-1 for a non-inverted candidate whose first span is directly followed by a helper
    verb lemma and starts with a past participle (VBN), optionally preceded by an
    adverb (RB VBN); 0 otherwise."""
    if inverted(c):
        return 0
    if not helper_vbs.intersection(get_right_tokens(c[0], window=1, attrib='lemmas')):
        return 0
    tags = c[0].get_attrib_tokens('pos_tags')
    # Length guards: the old one-liner read tags[1] even for single-token spans
    # whose tag was not VBN, raising IndexError.
    starts_vbn = len(tags) > 0 and tags[0] == 'VBN'
    adverb_vbn = len(tags) > 1 and tags[0] == 'RB' and tags[1] == 'VBN'
    return -1 if starts_vbn or adverb_vbn else 0

def lf32(c):
    """+1 when 'is' or 'are' occurs among the four words left of c, else 0."""
    left_words = get_left_tokens(c, window=4, attrib='words')
    return 1 if 'is' in left_words or 'are' in left_words else 0
        
def lf33(c):
    """+1 when 'results' or 'affected' occurs among the four words left of c, else 0."""
    left_words = get_left_tokens(c, window=4, attrib='words')
    return 1 if 'results' in left_words or 'affected' in left_words else 0

def lf35(c):
    """+1 when 'showed', 'were' or 'was' occurs among the four words left of c, else 0."""
    left_words = get_left_tokens(c, window=4, attrib='words')
    return 1 if any(cue in left_words for cue in ('showed', 'were', 'was')) else 0

#POS

# NOTE(review): names and checks disagree. LF_LRB_Context looks for a closing
# bracket tag (-RRB-) and LF_RRB for an opening one (-LRB-) right after either span.
# Also, if Snorkel's get_right_tokens lowercases tokens by default, these uppercase
# tag comparisons never match — confirm before re-enabling these rules.
def LF_LRB_Context(c):
    """-1 when the token right after either span is tagged -RRB-, else 0."""
    for span in (c[0], c[1]):
        if '-RRB-' in get_right_tokens(span, window=1, attrib='pos_tags'):
            return -1
    return 0
def LF_LRB_Contains(c):
    """-1 when either span starts with a -LRB- tagged token, else 0."""
    for span in (c[0], c[1]):
        if span.get_attrib_tokens('pos_tags')[0] == '-LRB-':
            return -1
    return 0
def LF_RRB(c):
    """-1 when the token right after either span is tagged -LRB-, else 0."""
    for span in (c[0], c[1]):
        if '-LRB-' in get_right_tokens(span, window=1, attrib='pos_tags'):
            return -1
    return 0
def LF_JJR(c):
    """+1 when any span of c contains a comparative adjective (JJR), else 0."""
    if contains_token(c, 'JJR', attrib='pos_tags'):
        return 1
    return 0

def LF_ModPhrase(c):
    """+1 for an inverted candidate whose second span starts with a helper-verb lemma
    followed by a JJR/VBN/JJ/RBR/RB tag; 0 otherwise."""
    if not is_inverted(c):
        return 0
    lemmas = c[1].get_attrib_tokens('lemmas')
    tags = c[1].get_attrib_tokens('pos_tags')
    # Guard: a single-token second span starting with a helper verb used to raise
    # IndexError on tags[1].
    if not lemmas or len(tags) < 2:
        return 0
    if lemmas[0] in helper_vbs and tags[1] in ['JJR', 'VBN', 'JJ', 'RBR', 'RB']:
        return 1
    return 0

# NOTE(review): if get_right_tokens lowercases tokens by default, the uppercase tag
# comparisons in LF_JJ / LF_IN never match — confirm.
def LF_JJ(c):
    """+1 when a JJ tag is in the default right-context window of c, else 0."""
    right_tags = get_right_tokens(c, attrib='pos_tags')
    return 1 if 'JJ' in right_tags else 0
def LF_IN(c):
    """+1 when the token right after c is tagged IN, else 0."""
    right_tags = get_right_tokens(c, window=1, attrib='pos_tags')
    return 1 if 'IN' in right_tags else 0
   
def LF_NNP(c):
    """-1 when any span of c contains a proper-noun (NNP) tag, else 0."""
    has_nnp = contains_token(c, 'NNP', attrib='pos_tags')
    return -1 if has_nnp else 0


def lf13(c):
    """+1 for an inverted candidate with an IN tag between its spans (n_max=4), else 0."""
    if not inverted(c):
        return 0
    return 1 if 'IN' in get_between_tokens(c, attrib='pos_tags', n_max=4) else 0

def LF_JJ_dp(c):
    """-1 when a JJ tag is among the two tokens right of the second span, else 0."""
    following_tags = get_right_tokens(c[1], window=2, attrib='pos_tags')
    return -1 if 'JJ' in following_tags else 0

# def lf20(c):
#     lemmas = c[0].get_attrib_tokens('lemmas')
#     poses = c[0].get_attrib_tokens('pos_tags')
#     result = 0
#     for i, w in enumerate(lemmas):
#         if w in ['NN', 'NNS', 'NNP', 'NNPS'] and not re.match(r'\w+(ion|ment|vity)', lemmas[i]):
#             result = 0
#         elif re.match(r'\w+(ion|ment|vity)', lemmas[i]):
#             result = -1
#     return result

def lf20(c):
    """+1 when any lemma of the first span looks like a nominalization, else 0.

    A lemma qualifies when the pattern below matches at its start: one or more word
    characters followed by 'ion', 'ment' or 'vity' (e.g. 'mutation', 'development',
    'activity').
    """
    # Replaces an index loop that also built unused `poses`/`result` locals.
    nominalization = re.compile(r'\w+(ion|ment|vity)')
    lemmas = c[0].get_attrib_tokens('lemmas')
    return 1 if any(nominalization.match(lemma) for lemma in lemmas) else 0


def lf24(c):
    """-1 when the first span contains none of the verb tags VB, VBZ or VBD, else 0."""
    has_verb = any(contains_token(c[0], tag, attrib='pos_tags') for tag in ('VB', 'VBZ', 'VBD'))
    return 0 if has_verb else -1

def lf25(c):
    """-1 when the first span begins with a preposition (IN) or 'to' (TO), else 0."""
    first_tag = c[0].get_attrib_tokens('pos_tags')[0]
    return -1 if first_tag in ('IN', 'TO') else 0

def lf26(c):
    """-1 when the first span starts with a comparative adjective (JJR) whose next token
    is not a noun (NN, NNS, NNP, NNPS); 0 otherwise, including spans under two tokens."""
    tags = c[0].get_attrib_tokens('pos_tags')
    if len(tags) < 2:
        return 0
    # Bug fix: this used set(tags[1]), i.e. the *characters* of the tag (e.g. {'N'}),
    # which never intersects the noun-tag set. Every JJR-initial span was therefore
    # labeled -1. Compare the whole tag instead.
    return -1 if tags[0] == 'JJR' and tags[1] not in ('NN', 'NNS', 'NNP', 'NNPS') else 0

def lfnonoun(c):
    """-1 when neither span contains a noun tag (NN, NNS, NNP, NNPS), else 0."""
    noun_tags = set(['NN', 'NNS', 'NNP', 'NNPS'])
    tags = set(c[0].get_attrib_tokens('pos_tags')) | set(c[1].get_attrib_tokens('pos_tags'))
    # An unreachable second return, which referenced an undefined `hasNoNoun`, was removed.
    return 0 if noun_tags & tags else -1
                  
def lf28(c):
    """-1 when the first span ends in an adjective (JJ/JJR) whose preceding lemma is not
    a helper verb; 0 otherwise, including spans without tags or with fewer than 2 lemmas."""
    tags = c[0].get_attrib_tokens('pos_tags')
    if not tags:
        return 0
    lemmas = c[0].get_attrib_tokens('lemmas')
    if len(lemmas) < 2:
        return 0
    ends_adjective = tags[-1] in ('JJ', 'JJR')
    after_helper = lemmas[-2] in helper_vbs
    if ends_adjective and not after_helper:
        return -1
    return 0
    
def _last_pos_tag(span):
    """Return the POS tag of the final token of span."""
    return span.get_attrib_tokens('pos_tags')[-1]

def lf29(c):
    """-1 when the first span ends in a gerund/present participle (VBG), else 0."""
    return -1 if _last_pos_tag(c[0]) == 'VBG' else 0

def lf30(c):
    """-1 when the first span ends in a preposition (IN), else 0."""
    return -1 if _last_pos_tag(c[0]) == 'IN' else 0

def lf29b(c):
    """-1 when the second span ends in a gerund/present participle (VBG), else 0."""
    return -1 if _last_pos_tag(c[1]) == 'VBG' else 0

def lf30b(c):
    """-1 when the second span ends in a preposition (IN), else 0."""
    return -1 if _last_pos_tag(c[1]) == 'IN' else 0

def lf30c(c):
    """-1 when the second span ends in an adjective (JJ, JJR or JJS), else 0."""
    return -1 if _last_pos_tag(c[1]) in ('JJ', 'JJR', 'JJS') else 0


In [None]:
# Active labeling functions passed to LabelAnnotator below. Commented-out entries are
# rules defined above that are currently disabled. Rules absent from this list
# (e.g. lfin_comp_advb) are likewise not applied.
LFs = [
    lfdistBtw0,
    lfdistBtwMax1,
    lfdistBtwMax2,
    lfdistBtwnOverlap,
    lfdistBtwMin5,
    lfdistBtwMin8,
    lfdistBtwMin12,
    lfdistBtwMin14,
    lfLenCand,
    lfend_prep,
    lfend_det,
    lfend_adj,
    lfin_fig,
    lfin_num,
    lfin_equals,
    lfin_comp_adj,
    lfbtwn_is,
    lfbtwm_comma,
    lfbtwn_parenthesis,
    LF_dna,
    LF_rna,
    LF_snp,
    lfwordis_result,
    lfwordsin_percent,
    lfwordsin_phenotype,
    lfwordsin_testerwords,
    lfwordsin_compwords,
    lfwordsin_negwords,
    lfwordsin_causewords,
    lfwordsin_adjwords,
    lfwordsin_statswords,
    lfnonoun,
#     lfwordscontext_mutant,
#     lfwordsbtwn_mutant,
#     LF_variant,
#     LF_express,
#     lfLenCand,
#     lfwordscontext_protein_desc,
#     lfwordscontext_protein_ent,
#     lfwordsin_protein,
     lf2,
     lf6,
     LF_phenotype_dp,
     lf_helpers,
#     lf23,
     lf32,
     lf33,
     lf35,
# #     LF_LRB_Context,
# #     LF_LRB_Contains,
# #     LF_RRB,
#     LF_JJR,
#     LF_ModPhrase,
#     LF_JJ,
#     LF_IN,
#     LF_NNP,
# #    lf13,
# #    LF_JJ_dp,
     lf20,
     lf24,
#     lf25,
     lf26
# #    lf27,
#     lf28,
#     lf29,
#     lf30
# #     lf29b,
# #     lf30b,
# #     lf30c
]

### testing how to query candidates - no need to run this

In [None]:
from snorkel.models.context import TemporaryContext
import re

print docs[15]
sent = docs[15].get_parent()
print sent
text = sent._asdict()['text']
splt = text.split()
print splt
print splt[4:5]
print "\n"
print "REGEX VERSION: "

resplit = re.split(' ',text)
print resplit
print resplit[4:5]
print "\n"

print docs[15].get_contexts()
print (docs[15][0]).get_attrib_tokens('words')
print len((docs[15][0]).get_attrib_tokens('words'))
# print (docs[15][0]).get_attrib_tokens('dep_parents')

#print LF_DP(docs[0])
print get_text_splits(docs[15])

# Running LFs on the Training Set

In [None]:
from snorkel.annotations import LabelAnnotator
import multiprocessing
from snorkel.annotations import load_gold_labels

In [None]:
# Labeler that applies every function in LFs to each candidate.
labeler = LabelAnnotator(f=LFs)

In [None]:
# Apply the LFs to the training split (split 3) using all CPU cores; %time reports the cost.
%time L_train = labeler.apply(split=3, parallelism=multiprocessing.cpu_count())
L_train

In [None]:
# Per-LF coverage / overlap / conflict statistics on the training split.
L_train.lf_stats(session)

* <b>Coverage</b> is the fraction of candidates for which the labeling function emits a non-zero label.
* <b>Overlap</b> is the fraction of candidates for which the labeling function emits a non-zero label and at least one other labeling function also emits a non-zero label.
* <b>Conflict</b> is the fraction of candidates for which the labeling function emits a non-zero label and at least one other labeling function emits a conflicting non-zero label.

In [None]:
# Hand-labeled gold labels for the dev split (split 4).
L_gold_dev = load_gold_labels(session, annotator_name='gold_complex', split=4)

# Apply the same LF set to the dev split (apply_existing rather than apply).
L_dev = labeler.apply_existing(split=4, parallelism=multiprocessing.cpu_count())

## Single LF Baseline 

In [None]:
#function returns +1 if candidates separated by at most 1 token and -1 otherwise
def LF1(c):
    return 1 if distance_btwn(c)<2 or overlap(c) else -1

single_LF = [LF1]
single_labeler = LabelAnnotator(f=LFs)
%time single_L_train = single_labeler.apply(split=3, parallelism=multiprocessing.cpu_count())

In [None]:
# Apply the baseline labeler to the dev split (split 4).
single_L_dev = single_labeler.apply_existing(split=4, parallelism=multiprocessing.cpu_count())

In [None]:
from scipy.sparse import csr_matrix
tlabs = (single_L_dev+L_gold_dev)/2
tlabs = csr_matrix.toarray(tlabs).reshape(tlabs.shape[0])
flabs = (single_L_dev-L_gold_dev)/2
flabs = csr_matrix.toarray(flabs).reshape(flabs.shape[0])

tp = tlabs[tlabs==1].shape[0]
tn = tlabs[tlabs==-1].shape[0]
fp = flabs[flabs==1].shape[0]
fn = flabs[flabs==-1].shape[0]

prec = float(tp) / (tp+fp)
rec = float(tp) / (tp+fn)
f1 = (2.0*prec*rec)/(prec+rec)
print 'Precision:', prec
print 'Recall:', rec
print 'F1:', f1
print '========================='

print 'True Positives:', tp
print 'False Positives:', fp
print 'True Negatives:', tn
print 'False Negatives:', fn

## Majority Vote Baseline 

In [None]:
summed = np.clip(np.sum(L_dev, axis=1), -1, 1)
eqs = (L_gold_dev == summed)
eqs_1 = (eqs == summed)
print eqs_1.shape
# print (summed[summed == 0].shape)
# print (summed[summed == 1].shape)
print L_gold_dev[L_gold_dev == 1].shape

In [None]:
print(L_gold_dev.shape)
print(summed.shape)

# Shape of the selection of agreeing entries (majority vote == gold label).
print eqs[eqs == True].shape
from snorkel.learning.utils import MentionScorer
# Dev candidates in L_dev row order, so they line up with `summed`.
dev_candidates = [L_dev.get_candidate(session, i) for i in xrange(L_dev.shape[0])]
s = MentionScorer(dev_candidates, L_gold_dev)
# Score the majority vote; b=0 is presumably the decision threshold on `summed` — confirm.
s.score(summed, b=0)

# Generative Model

In [None]:
from snorkel.learning.structure import DependencySelector
ds = DependencySelector()
# Candidate LF dependencies selected from the training label matrix.
# NOTE(review): `deps` is only displayed; the GenerativeModel.train call below does not pass it.
deps = ds.select(L_train, threshold=0.1)
len(deps)

In [None]:
# Display the selected dependencies.
deps # (lf, lf, relationship_type)

In [None]:
from snorkel.learning import GenerativeModel

In [None]:
# Fit the generative label model on the training label matrix. `deps` is not passed,
# so this is the independent ("non-dep") model.
gen_model = GenerativeModel(lf_propensity=True)
gen_model.train(
    # Bug fix: step_size was 0.0/L_train.shape[0] == 0, so training never updated
    # the LF weights away from init_acc. Use a step scaled by the number of candidates.
    L_train, epochs=2, decay=0.975, step_size=0.1/L_train.shape[0],
    init_acc=2.0, reg_param=0.0, burn_in = 10,
    verbose=True
 )

We now apply the generative model to the training candidates to get the noise-aware training label set. We'll refer to these as the training marginals:

In [None]:
# Training marginals from the generative model, one per training candidate (row of L_train).
train_marginals = gen_model.marginals(L_train)

In [None]:
import matplotlib.pyplot as plt

# Histogram of the training marginals; a labeled figure so it stands on its own.
fig, ax = plt.subplots()
ax.hist(train_marginals, bins=20)
ax.set(title='Generative-model training marginals',
       xlabel='marginal probability', ylabel='number of candidates')
plt.show()

In [None]:
# Learned accuracy weights of the LFs in the generative model.
gen_model.weights.lf_accuracy()

In [None]:
from snorkel.annotations import save_marginals
# Persist the training marginals to the database so they can be reloaded with load_marginals.
save_marginals(session, L_train, train_marginals)

In [None]:
from snorkel.annotations import load_marginals
# Reload the saved training marginals for split 3 (replaces the in-memory train_marginals).
train_marginals = load_marginals(session, split=3)

### Using the Model to Iterate on Labeling Functions
Now that we have learned the generative model, we can use it to debug and/or improve our labeling function set. First, we score the generative model on the development set, using the dev label matrix `L_dev` computed earlier and the dev gold labels:

In [None]:
# Score the generative model on dev; the four error buckets feed the viewers below.
tp_gen, fp_gen, tn_gen, fn_gen = gen_model.score(session, L_dev, L_gold_dev)

## Doing Some Error Analysis

At this point, we might want to look at some examples in each of the error buckets — for example, the false negatives, i.e. true mentions that we failed to label as such. To do this, we can again use the Viewer:

### View False Positives 

In [None]:
from snorkel.viewer import SentenceNgramViewer

# Bug fix: `fp` here is the integer false-positive count from the single-LF baseline
# cell, not a set of candidates. Show the generative model's dev false positives (fp_gen).
# Under CI (automated runs of this notebook) no viewer is opened.
import os
if 'CI' not in os.environ:
    fpsv = SentenceNgramViewer(fp_gen, session, height=400)
else:
    fpsv = None

In [None]:
# Show the false-positive viewer (None under CI).
fpsv

### View False Negatives 

In [None]:
from snorkel.viewer import SentenceNgramViewer

# Bug fix: `fn` here is the integer false-negative count from the single-LF baseline
# cell, not a set of candidates. Show the generative model's dev false negatives (fn_gen).
# Under CI (automated runs of this notebook) no viewer is opened.
import os
if 'CI' not in os.environ:
    fnsv = SentenceNgramViewer(fn_gen, session, height=400)
else:
    fnsv = None

In [None]:
# Show the false-negative viewer (None under CI).
fnsv

### View True Positives 

In [None]:
from snorkel.viewer import SentenceNgramViewer

# Bug fix: `tp` here is the integer true-positive count from the single-LF baseline
# cell, not a set of candidates. Show the generative model's dev true positives (tp_gen).
# Under CI (automated runs of this notebook) no viewer is opened.
import os
if 'CI' not in os.environ:
    tpsv = SentenceNgramViewer(tp_gen, session, height=400)
else:
    tpsv = None

In [None]:
# Show the true-positive viewer (None under CI).
tpsv

### View True Negatives 

In [None]:
from snorkel.viewer import SentenceNgramViewer

# Bug fix: `tn` here is the integer true-negative count from the single-LF baseline
# cell, not a set of candidates. Show the generative model's dev true negatives (tn_gen).
# Under CI (automated runs of this notebook) no viewer is opened.
import os
if 'CI' not in os.environ:
    tnsv = SentenceNgramViewer(tn_gen, session, height=400)
else:
    tnsv = None

In [None]:
# Show the true-negative viewer (None under CI).
tnsv

## Automatically Creating Features

In [None]:
from snorkel.annotations import FeatureAnnotator
import multiprocessing
# Feature annotator with its default feature generators.
featurizer = FeatureAnnotator()

In [None]:
# Featurize the training split (split 3); this defines the feature key set.
%time F_train = featurizer.apply(split=3, parallelism=multiprocessing.cpu_count())
F_train


Next, we apply the feature set we just got from the training set to the dev and test sets by using apply_existing:

In [None]:
%%time
# Reuse the training feature set for dev (split 4) and test (split 5) via apply_existing.
F_dev  = featurizer.apply_existing(split=4, parallelism=multiprocessing.cpu_count())
F_test = featurizer.apply_existing(split=5, parallelism=multiprocessing.cpu_count())

In [None]:
# Reload the feature matrices from the database (presumably so later cells can run
# without recomputing the featurization — confirm).
F_train = featurizer.load_matrix(session, split=3)
F_dev   = featurizer.load_matrix(session, split=4)
F_test  = featurizer.load_matrix(session, split=5)

## Training the Discriminative Model

We use the training marginals to train a discriminative model that classifies each Candidate as a true or false mention. We'll use a random hyperparameter search, evaluated on the development set labels, to find the best hyperparameters for our model. To run a hyperparameter search, we need labels for a development set. If they aren't already available, we can manually create labels using the Viewer.

In [None]:
from snorkel.learning import SparseLogisticRegression
# Discriminative end model trained on the generative-model marginals below.
disc_model = SparseLogisticRegression()


Now we set up the hyperparameter search, which would train our model with different hyperparameters and keep the best configuration; we set the random seed to maintain reproducibility. (Note: the training cell further down currently trains with fixed hyperparameters rather than running this search.)
Note that we fit our model's parameters to the training set generated by our labeling functions, while we pick hyperparameters based on the score over the hand-labeled development set.

In [None]:
from snorkel.learning.utils import MentionScorer
from snorkel.learning import ListParameter, RangeParameter

# Search ranges for the learning rate and the L1/L2 penalties, 1e-6 to 1e-2
# (presumably powers of ten given step=1, log_base=10 — confirm).
# NOTE(review): these parameters are not used below; disc_model.train uses fixed values.
rate_param = RangeParameter('lr', 1e-6, 1e-2, step=1, log_base=10)
l1_param  = RangeParameter('l1_penalty', 1e-6, 1e-2, step=1, log_base=10)
l2_param  = RangeParameter('l2_penalty', 1e-6, 1e-2, step=1, log_base=10)

Next, we'll load in our dev set labels. We will pick the optimal result from the hyperparameter search by testing against these labels:

In [None]:
from snorkel.annotations import load_gold_labels
# Reload the dev gold labels (same call as earlier in the notebook).
L_gold_dev = load_gold_labels(session, annotator_name='gold_complex', split=4)

Finally, we run the hyperparameter search / train the end extraction model:

In [None]:
# Train the sparse logistic regression on the training features and generative-model
# marginals, using fixed hyperparameters.
disc_model.train(
    F_train, train_marginals,
    n_epochs=50, lr=0.0001, batch_size=100,
    l1_penalty=0.0001, l2_penalty=0.01,
    rebalance=0.5, seed=432, print_freq=25,
)

In [None]:
w, _ = disc_model.get_weights()
largest_idxs = reversed(np.argsort(np.abs(w))[-5:])
for i in largest_idxs:
    print 'Feature: {0: <70}Weight: {1:.6f}'.format(F_train.get_key(session, i).name, w[i])


In this last section, we'll get the score we've been after: the performance of the extraction model on the blind test set (split 5). First, we load the test set gold labels imported in `Complex_Pheno_BRAT_Import.ipynb`.

In [None]:
from snorkel.annotations import load_gold_labels
# Hand-labeled gold labels for the blind test split (split 5).
L_gold_test = load_gold_labels(session, annotator_name='gold_complex', split=5)

Now, we score using the discriminative model:

In [None]:
# Score the logistic regression on dev; the buckets feed the log-reg dev viewers below.
tp_lr, fp_lr, tn_lr, fn_lr = disc_model.score(session, F_dev, L_gold_dev)

In [None]:
# Score the logistic regression on the blind test split.
tp_lr_test, fp_lr_test, tn_lr_test, fn_lr_test = disc_model.score(session, F_test, L_gold_test)

# LSTM

TODO: Implement a version of an LSTM w/ a decayed learning rate

In [None]:
# Map dev gold labels from {-1, +1} to {0, 1} for the LSTM (assumes gold labels are ±1).
lstm_dev_labels = (np.ravel(L_gold_dev.todense()) + 1) / 2

In [None]:
from snorkel.contrib.rnn import reRNN

# Seed numpy's global RNG; reRNN below also receives its own seed (1701).
np.random.seed(432)

train_kwargs = {
    'lr':   0.01,
    'dim':        50,
    'n_epochs':   42,
    'dropout':    0.5,
    'rebalance':  0.5,
    'print_freq': 1
}

lstm = reRNN(seed=1701, n_threads=None)
# NOTE(review): assumes `train` (session query order) and `train_marginals` (loaded from
# the DB for split 3) list candidates in the same order; the same for dev and
# lstm_dev_labels — confirm.
lstm.train(train, train_marginals, dev_candidates=dev, dev_labels=lstm_dev_labels, **train_kwargs)

In [None]:
from snorkel.contrib.rnn.utils import f1_score
dev_data, _ = lstm._preprocess_data(dev, extend=False)
labels = np.ravel(dev_labels)
dev_p = lstm._marginals_preprocessed(dev_data)
f1, p, r = f1_score(dev_p, labels)
b=0.5
tp_lstm = np.sum((dev_p > b) * (labels > b))
fp_lstm = np.sum((dev_p > b) * (labels <= b))
tn_lstm = np.sum((dev_p <= b) * (labels <= b))
fn_lstm = np.sum((dev_p <= b) * (labels > b))
print 'Precision:', p
print 'Recall:', r
print 'F1:', f1
print '============================'
print 'True Positives:', tp_lstm
print 'False Positives:', fp_lstm
print 'True Negatives:', tn_lstm
print 'False Negatives:', fn_lstm

# Viewers for Error analysis

## Non-dep Gen Model - f1=.605 

In [None]:
import os
from snorkel.viewer import SentenceNgramViewer

# Viewer over the generative model's dev false positives; left as None when the CI
# environment variable is set (automated runs of this notebook).
fpsv_gen = None if 'CI' in os.environ else SentenceNgramViewer(fp_gen, session, height=400)

In [None]:
# Show the generative-model false-positive viewer (None under CI).
fpsv_gen

In [None]:
import os
from snorkel.viewer import SentenceNgramViewer

# Viewer over the generative model's dev false negatives; left as None when the CI
# environment variable is set (automated runs of this notebook).
fnsv_gen = None if 'CI' in os.environ else SentenceNgramViewer(fn_gen, session, height=400)

In [None]:
# Show the generative-model false-negative viewer (None under CI).
fnsv_gen

In [None]:
import os
from snorkel.viewer import SentenceNgramViewer

# Viewer over the generative model's dev true positives; left as None when the CI
# environment variable is set (automated runs of this notebook).
tpsv_gen = None if 'CI' in os.environ else SentenceNgramViewer(tp_gen, session, height=400)

In [None]:
# Show the generative-model true-positive viewer (None under CI).
tpsv_gen

In [None]:
import os
from snorkel.viewer import SentenceNgramViewer

# Viewer over the generative model's dev true negatives; left as None when the CI
# environment variable is set (automated runs of this notebook).
tnsv_gen = None if 'CI' in os.environ else SentenceNgramViewer(tn_gen, session, height=400)

In [None]:
# Show the generative-model true-negative viewer (None under CI).
tnsv_gen

## Non-dep log-reg dev set -f1=.61

In [None]:
import os
from snorkel.viewer import SentenceNgramViewer

# Viewer over the logistic regression's dev false positives; None under CI.
fpsv_lr = None if 'CI' in os.environ else SentenceNgramViewer(fp_lr, session, height=400)

In [None]:
# Show the log-reg dev false-positive viewer (None under CI).
fpsv_lr

In [None]:
import os
from snorkel.viewer import SentenceNgramViewer

# Viewer over the logistic regression's dev false negatives; None under CI.
fnsv_lr = None if 'CI' in os.environ else SentenceNgramViewer(fn_lr, session, height=400)

In [None]:
# Show the log-reg dev false-negative viewer (None under CI).
fnsv_lr

In [None]:
import os
from snorkel.viewer import SentenceNgramViewer

# Viewer over the logistic regression's dev true positives; None under CI.
tpsv_lr = None if 'CI' in os.environ else SentenceNgramViewer(tp_lr, session, height=400)

In [None]:
# Show the log-reg dev true-positive viewer (None under CI).
tpsv_lr

In [None]:
import os
from snorkel.viewer import SentenceNgramViewer

# Viewer over the logistic regression's dev true negatives; None under CI.
tnsv_lr = None if 'CI' in os.environ else SentenceNgramViewer(tn_lr, session, height=400)

In [None]:
# Show the log-reg dev true-negative viewer (None under CI).
tnsv_lr

## Non-dep log-reg test - f1=.64 

In [None]:
import os
from snorkel.viewer import SentenceNgramViewer

# Viewer over the logistic regression's test-set false positives; None under CI.
fpsv_lr_test = None if 'CI' in os.environ else SentenceNgramViewer(fp_lr_test, session, height=400)

In [None]:
# Show the log-reg test false-positive viewer (None under CI).
fpsv_lr_test

In [None]:
import os
from snorkel.viewer import SentenceNgramViewer

# Viewer over the logistic regression's test-set false negatives; None under CI.
fnsv_lr_test = None if 'CI' in os.environ else SentenceNgramViewer(fn_lr_test, session, height=400)

In [None]:
# Show the log-reg test false-negative viewer (None under CI).
fnsv_lr_test

In [None]:
import os
from snorkel.viewer import SentenceNgramViewer

# Viewer over the logistic regression's test-set true positives; None under CI.
tpsv_lr_test = None if 'CI' in os.environ else SentenceNgramViewer(tp_lr_test, session, height=400)

In [None]:
# Show the log-reg test true-positive viewer (None under CI).
tpsv_lr_test

In [None]:
import os
from snorkel.viewer import SentenceNgramViewer

# Viewer over the logistic regression's test-set true negatives; None under CI.
tnsv_lr_test = None if 'CI' in os.environ else SentenceNgramViewer(tn_lr_test, session, height=400)

In [None]:
# Show the log-reg test true-negative viewer (None under CI).
tnsv_lr_test

## non-dep lstm dev 

In [None]:
from snorkel.viewer import SentenceNgramViewer

# Bug fix: fp_lstm is an integer count (np.sum in the LSTM scoring cell), not a set of
# candidates. Rebuild the dev false positives from the same marginals (dev_p), labels
# and threshold (b), assuming dev, dev_p and labels share an ordering as that cell does.
# Under CI (automated runs of this notebook) no viewer is opened.
import os
fp_lstm_cands = [cand for cand, p, y in zip(dev, dev_p, labels) if p > b and y <= b]
if 'CI' not in os.environ:
    fpsv_lstm = SentenceNgramViewer(fp_lstm_cands, session, height=400)
else:
    fpsv_lstm = None

In [None]:
# Show the LSTM dev false-positive viewer (None under CI).
fpsv_lstm

In [None]:
from snorkel.viewer import SentenceNgramViewer

# Bug fix: fn_lstm is an integer count, not a set of candidates. Rebuild the dev false
# negatives from dev_p, labels and b (same ordering assumption as the scoring cell).
# Under CI (automated runs of this notebook) no viewer is opened.
import os
fn_lstm_cands = [cand for cand, p, y in zip(dev, dev_p, labels) if p <= b and y > b]
if 'CI' not in os.environ:
    fnsv_lstm = SentenceNgramViewer(fn_lstm_cands, session, height=400)
else:
    fnsv_lstm = None

In [None]:
# Show the LSTM dev false-negative viewer (None under CI).
fnsv_lstm

In [None]:
from snorkel.viewer import SentenceNgramViewer

# Bug fix: tp_lstm is an integer count, not a set of candidates. Rebuild the dev true
# positives from dev_p, labels and b (same ordering assumption as the scoring cell).
# Under CI (automated runs of this notebook) no viewer is opened.
import os
tp_lstm_cands = [cand for cand, p, y in zip(dev, dev_p, labels) if p > b and y > b]
if 'CI' not in os.environ:
    tpsv_lstm = SentenceNgramViewer(tp_lstm_cands, session, height=400)
else:
    tpsv_lstm = None

In [None]:
# Show the LSTM dev true-positive viewer (None under CI).
tpsv_lstm

In [None]:
from snorkel.viewer import SentenceNgramViewer

# Bug fix: tn_lstm is an integer count, not a set of candidates. Rebuild the dev true
# negatives from dev_p, labels and b (same ordering assumption as the scoring cell).
# Under CI (automated runs of this notebook) no viewer is opened.
import os
tn_lstm_cands = [cand for cand, p, y in zip(dev, dev_p, labels) if p <= b and y <= b]
if 'CI' not in os.environ:
    tnsv_lstm = SentenceNgramViewer(tn_lstm_cands, session, height=400)
else:
    tnsv_lstm = None

In [None]:
# Show the LSTM dev true-negative viewer (None under CI).
tnsv_lstm