# Combining Labeling Functions

This notebook combines the labeling functions and trains a generative Snorkel model on their outputs.

**Common mistakes**
- Double matching (when the software is fully cited)
- Medical kits, drugs, lab animals and other devices are often cited in the same context as software, and are therefore often matched as well
- Many three-letter abbreviations in the software list have a second meaning in the life sciences

In [None]:
%load_ext autoreload
%autoreload 2
import os
import numpy as np

# Prefix for files written by this notebook (e.g. the exported training-sample CSV).
BASE_NAME = 'sosci_ssc_0'
# Postgres database holding the Snorkel candidates, labels and marginals.
DATABASE_NAME = 'sosci_ssc_0'
# Annotation set name; not referenced in the cells of this notebook — confirm it is still needed.
LABELS_NAME = 'sosci_ssc_annotation'
# Set before the `snorkel` imports below, presumably because Snorkel reads SNORKELDB at import time.
# 'postgresql://' is SQLAlchemy's canonical dialect name; the legacy 'postgres://' alias was removed
# in SQLAlchemy 1.4 and fails there with NoSuchModuleError.
os.environ['SNORKELDB'] = 'postgresql://snorkel:snorkel@localhost/' + DATABASE_NAME

import spacy

from snorkel import SnorkelSession
from snorkel.models import candidate_subclass
from snorkel.annotations import load_gold_labels
from snorkel.learning.utils import MentionScorer
from snorkel.viewer import SentenceNgramViewer
from itertools import product
from functools import partial, update_wrapper 
from snorkel.annotations import save_marginals, LabelAnnotator

from learning_functions import (
    LF_pan_top_1, LF_pan_top_2, LF_pan_top_3, LF_pan_top_4, LF_pan_top_5, 
    LF_pan_top_6, LF_pan_top_7, LF_pan_top_8, LF_pan_top_9, LF_pan_top_10, 
    LF_software_head_nouns, LF_version_number, LF_url, LF_developer, LF_distant_supervision
)

In [None]:
# NOTE(review): not read anywhere in this notebook — confirm whether it is still needed.
convergence_testing = True
# Split names -> integer split ids as stored on Snorkel candidates.
set_mapping = {
    'train': 0,
    'test': 1,
    'new': 2,
}
session = SnorkelSession()
# Candidate type with a single 'software' span argument.
software = candidate_subclass('software', ['software'])
# Gold labels of the training split; the training split doubles as the development set here.
devel_gold_labels = load_gold_labels(session, annotator_name='gold', split=set_mapping['train'])

# Held-out candidates and their gold labels. These must come from the test split so that
# `scorer` measures performance on unseen data, as the names indicate.
test_cands = session.query(software).filter(software.split == set_mapping['test']).all()
test_labels = load_gold_labels(session, annotator_name="gold", split=set_mapping['test'])
scorer = MentionScorer(test_cands, test_labels)

In [None]:
%store -r known_software_01_09
%store -r known_software_22_11
%store -r duck_dict_first_char_upper
%store -r known_software_lower_01_09
%store -r known_software_lower_22_11
%store -r duck_dict_lower
%store -r acronym_dic
%store -r gen_seq_triplets
LF_dist = partial(LF_distant_supervision, 
                  software_dict=known_software_01_09,
                  software_dict_lower=known_software_lower_01_09,
                  english_dict=duck_dict_first_char_upper,
                  english_dict_lower=duck_dict_lower,
                  acronym_dict=acronym_dic,
                  gen_seqs=gen_seq_triplets)
update_wrapper(LF_dist, LF_distant_supervision)

# Load the English spaCy pipeline.
spacy_nlp = spacy.load('en')
# spaCy's English stop-word set, bound into every stop-word-aware LF below.
stopwords = spacy.lang.en.stop_words.STOP_WORDS
# Same set object as `stopwords` (an alias, not a copy).
stopwords_left_context = stopwords


def _bind_stopwords(lf):
    """Return ``lf`` with ``stopwords`` pre-bound, carrying over its __name__ and __doc__."""
    return update_wrapper(partial(lf, stopwords=stopwords), lf)


LF_pan_1 = _bind_stopwords(LF_pan_top_1)
LF_pan_2 = _bind_stopwords(LF_pan_top_2)
LF_pan_3 = _bind_stopwords(LF_pan_top_3)
LF_pan_4 = _bind_stopwords(LF_pan_top_4)
LF_pan_5 = _bind_stopwords(LF_pan_top_5)
LF_pan_6 = _bind_stopwords(LF_pan_top_6)
LF_pan_7 = _bind_stopwords(LF_pan_top_7)
LF_pan_8 = _bind_stopwords(LF_pan_top_8)
LF_pan_9 = _bind_stopwords(LF_pan_top_9)
LF_pan_10 = _bind_stopwords(LF_pan_top_10)
LF_head_nouns = _bind_stopwords(LF_software_head_nouns)

In [None]:
# Labeling functions handed to the LabelAnnotator.
LFs = [
    LF_pan_1,
    LF_pan_2,
    LF_pan_3,
    LF_pan_4,
    LF_pan_5,
    LF_pan_6,
    LF_pan_7,
    LF_pan_8,
    # NOTE(review): LF_pan_9 and LF_pan_10 are bound in an earlier cell but are not
    # included here — confirm this is intentional.
    LF_head_nouns,
    LF_version_number,
    LF_url,
    LF_developer,
    LF_dist,
]

In [None]:
# Annotator that runs every labeling function in `LFs`; applied per split in the cells below.
labeler = LabelAnnotator(lfs=LFs)

In [None]:
# Seed NumPy's global RNG before applying the LFs (presumably for reproducibility of any
# randomness inside them); the seed must be set before the apply call.
np.random.seed(42)
# Apply all LFs to the training split with parallelism=60, producing the label matrix L_train.
%time L_train = labeler.apply(split=set_mapping['train'], parallelism=60)

In [None]:
%%time
from snorkel.learning import GenerativeModel

gen_model = GenerativeModel()
gen_model.train(L_train, epochs=100, decay=0.95, step_size=0.1 / L_train.shape[0], reg_param=1e-6)

In [None]:
%%time
# Label matrix used for evaluation. The training split doubles as the development split
# in this notebook, so this evaluates on data the generative model was trained on.
L_dev = labeler.apply_existing(split=set_mapping['train'], parallelism=60)

Next, we score the generative model on the training split, which also serves as the development set here:

In [None]:
%%time
# Score the generative model on the development (training) split against its gold labels;
# the result is unpacked into true/false positives/negatives.
tp, fp, tn, fn = gen_model.error_analysis(session, L_dev, devel_gold_labels)

Now we score it on the held-out test split. This score estimates future performance, because the model was not trained on these candidates:

In [None]:
%%time
test_gold_labels = load_gold_labels(session, annotator_name='gold', split=set_mapping['test'])
L_test = labeler.apply_existing(split=set_mapping['test'], parallelism=60)
tp, fp, tn, fn = gen_model.error_analysis(session, L_test, test_gold_labels)

Now on to labeling new data:

In [None]:
%%time
# Re-seed NumPy's global RNG with the same value used for the training run; the seed
# must be set before the apply call.
np.random.seed(42)
# Apply all LFs to the 'new' split, i.e. the data to be labeled by the generative model.
L_new = labeler.apply(split=set_mapping['new'], parallelism=60)

In [None]:
# Generative-model marginals for the new split, one value per row of L_new (row i pairs
# with L_new.get_candidate(session, i) in the export cell).
new_marginals = gen_model.marginals(L_new)

Snorkel only stores its results in its own database, so we export the positively labeled candidates to a CSV file ourselves.

In [None]:
%%time
with open('{}_training_samples.csv'.format(BASE_NAME), 'w') as new_samples:
    header = 'span_id,span,beg_off,end_off,sent_id,sent,doc_id,doc,marg\n'
    new_samples.writelines(header)
    count = 0
    for i in range(len(new_marginals)):
        if new_marginals[i] > 0.5: # This is the fixed threshold snorkel applies, it can easily be adjusted
            cand = L_new.get_candidate(session, i) # We want all information from this candidate
            span_id = cand[0].id
            span = " ".join(cand[0].get_attrib_tokens(a="words"))
            span_off_beg = cand[0].char_start
            span_off_end = cand[0].char_end
            sentence_id = cand[0].sentence_id
            sentence = cand[0].sentence.text.rstrip('\n')
            doc_id = cand[0].sentence.document_id
            doc = cand[0].sentence.document.name
            marginal = round(new_marginals[i], 3)
            entry = '{},"{}",{},{},{},"{}",{},"{}",{}\n'.format(
                span_id,
                span,
                span_off_beg,
                span_off_end,
                sentence_id,
                sentence,
                doc_id,
                doc,
                marginal
            )
            new_samples.writelines(entry)
            count += 1 
            if count % 100 == 0:
                print('Processed {} new samples.'.format(count))

In [None]:
# Persist the new split's marginals through the Snorkel session (i.e. into its database).
%time save_marginals(session, L_new, new_marginals)

In [None]:
from snorkel.viewer import SentenceNgramViewer

# Browse the false positives from the most recent error_analysis call (the test split
# when the cells are run in order). The bare last expression displays the viewer.
fp_viewer = SentenceNgramViewer(fp, session)
fp_viewer