In [1]:
# Automatically reload edited modules (e.g. the project's snorkel code) before
# each cell executes, so library changes take effect without a kernel restart.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np
from snorkel import SnorkelSession

# Database session used by the candidate queries and gold-label loads below.
session = SnorkelSession()

In [2]:
from snorkel.models import candidate_subclass

# Relation candidates with two arguments: a chemical and a disease.
ChemicalDisease = candidate_subclass('ChemicalDisease', ['chemical', 'disease'])


def _candidates_in_split(split_id):
    """Return all ChemicalDisease candidates whose `split` column equals `split_id`."""
    query = session.query(ChemicalDisease)
    return query.filter(ChemicalDisease.split == split_id).all()


# Split ids used throughout this notebook: 0 = train, 1 = dev, 2 = test.
train, dev, test = (_candidates_in_split(split_id) for split_id in (0, 1, 2))

for template, candidates in (
    ('Training set:\t{0} candidates', train),
    ('Dev set:\t{0} candidates', dev),
    ('Test set:\t{0} candidates', test),
):
    print(template.format(len(candidates)))

Training set:	8433 candidates
Dev set:	920 candidates
Test set:	4683 candidates


In [3]:
# Training candidates for the discriminative models: all train candidates
# followed by all dev candidates, as a new list (train/dev are not modified).
# NOTE(review): `dev` is also passed as X_dev when training below, so any dev
# candidate with a usable marginal is both trained on and used for Dev F1.
total = train + dev

In [4]:
def load_marginals(path, expected_len=None):
    """Load a 1-D float64 array of training marginals from `path`.

    `np.fromfile` with its default `sep=''` reads raw binary (as written by
    `ndarray.tofile`), not text, despite the `.txt` suffix of these files.
    NOTE(review): confirm the files were produced with `tofile`.

    Args:
        path: file to read.
        expected_len: if given, the number of values the file must contain
            (one marginal per training candidate).

    Returns:
        numpy float64 array of marginals.

    Raises:
        ValueError: if the length differs from `expected_len`, or if any value
            lies outside [0, 1] (NaN included). A value outside [0, 1] is the
            usual symptom of reading a text file as raw binary.
    """
    marginals = np.fromfile(path)
    if expected_len is not None and marginals.size != expected_len:
        raise ValueError(
            "{0}: expected {1} marginals, found {2}".format(path, expected_len, marginals.size)
        )
    if marginals.size and not ((marginals >= 0) & (marginals <= 1)).all():
        raise ValueError("{0}: marginals must lie in [0, 1]".format(path))
    return marginals


# One marginal per candidate in `total` (train + dev), which is what the
# training cell pairs these arrays with.
train_marginals_orig = load_marginals("train_marginals_orig.txt", len(total))
train_marginals_random = load_marginals("train_marginals_random.txt", len(total))
train_marginals_random_low = load_marginals("train_marginals_random_low.txt", len(total))
train_marginals_lookup = load_marginals("train_marginals_lookup.txt", len(total))

In [5]:
from snorkel.annotations import load_gold_labels

# Gold labels from the 'gold' annotator for the dev (split 1) and test (split 2) sets.
L_gold_dev, L_gold_test = (
    load_gold_labels(session, annotator_name='gold', split=split_id)
    for split_id in (1, 2)
)

In [6]:
from snorkel.learning.pytorch import LSTM_orig

In [9]:
# Hyperparameters shared by every run, so the only thing that differs between
# the four models is the set of training marginals.
train_kwargs = {
    'lr':              0.01,
    'embedding_dim':   100,
    'hidden_dim':      100,
    'n_epochs':        20,
    'dropout':         0.5,
    'rebalance':       0.25,
    'print_freq':      5,
    'seed':            1701
}


def train_and_save(train_marginals, model_name):
    """Train a fresh LSTM_orig on `total` with `train_marginals`, save it, and return it.

    Dev F1 during training is computed on `dev` against `L_gold_dev`.
    NOTE(review): `total` is train + dev, so dev candidates may also be trained
    on; the reported Dev F1 is then not a held-out estimate. The test-set scores
    computed later are unaffected.
    """
    model = LSTM_orig(n_threads=None)
    model.train(total, train_marginals, X_dev=dev, Y_dev=L_gold_dev, **train_kwargs)
    model.save(model_name=model_name)
    return model


lstm_orig = train_and_save(train_marginals_orig, "orig_3")
lstm_random = train_and_save(train_marginals_random, "random_3")
lstm_random_low = train_and_save(train_marginals_random_low, "random_low_3")
lstm_lookup = train_and_save(train_marginals_lookup, "lookup_3")


[LSTM_orig] Training model
[LSTM_orig] n_train=4112  #epochs=20  batch size=64
[LSTM_orig] Epoch 1 (17.83s)	Average loss=0.689812	Dev F1=51.52
[LSTM_orig] Epoch 6 (109.31s)	Average loss=0.668251	Dev F1=53.25
[LSTM_orig] Epoch 11 (197.10s)	Average loss=0.664024	Dev F1=52.07
[LSTM_orig] Epoch 16 (285.51s)	Average loss=0.665312	Dev F1=53.74
[LSTM_orig] Epoch 20 (356.87s)	Average loss=0.661678	Dev F1=53.99
[LSTM_orig] Model saved as <LSTM_orig>
[LSTM_orig] Training done (357.99s)
[LSTM_orig] Loaded model <LSTM_orig>
[LSTM_orig] Model saved as <orig_3>
[LSTM_orig] Training model
[LSTM_orig] n_train=6193  #epochs=20  batch size=64
[LSTM_orig] Epoch 1 (25.66s)	Average loss=0.697609	Dev F1=20.93
[LSTM_orig] Epoch 6 (155.16s)	Average loss=0.652933	Dev F1=40.65
[LSTM_orig] Epoch 11 (285.96s)	Average loss=0.611956	Dev F1=46.38
[LSTM_orig] Epoch 16 (411.62s)	Average loss=0.577380	Dev F1=38.34
[LSTM_orig] Epoch 20 (516.87s)	Average loss=0.558556	Dev F1=44.35
[LSTM_orig] Model saved as <LSTM_orig>
[

In [10]:
def load_lstm(model_name):
    """Create an LSTM_orig and restore the model saved under `model_name`."""
    model = LSTM_orig(n_threads=None)
    model.load(model_name=model_name)
    return model


# Reload the checkpoints written by the training cell, so evaluation can run
# without retraining.
lstm_orig = load_lstm("orig_3")
lstm_random = load_lstm("random_3")
lstm_random_low = load_lstm("random_low_3")
lstm_lookup = load_lstm("lookup_3")

[LSTM_orig] Loaded model <orig_3>
[LSTM_orig] Loaded model <random_3>
[LSTM_orig] Loaded model <random_low_3>
[LSTM_orig] Loaded model <lookup_3>


In [12]:
# Test-set (precision, recall, F1) for the model trained on the original marginals.
orig_test_scores = lstm_orig.score(test, L_gold_test)
orig_test_scores

(0.38788065210704398, 0.83289299867899602, 0.52927597061909748)

In [13]:
# Test-set (precision, recall, F1) for the model trained on the "random" marginals.
random_test_scores = lstm_random.score(test, L_gold_test)
random_test_scores

(0.33285233285233284, 0.4570673712021136, 0.38519343167269687)

In [14]:
# Test-set (precision, recall, F1) for the model trained on the "random_low" marginals.
random_low_test_scores = lstm_random_low.score(test, L_gold_test)
random_low_test_scores

(0.39768076398362895, 0.77014531043593126, 0.52451641925326142)

In [15]:
# Test-set (precision, recall, F1) for the model trained on the "lookup" marginals.
lookup_test_scores = lstm_lookup.score(test, L_gold_test)
lookup_test_scores

(0.39720062208398133, 0.84346103038309117, 0.5400718968069359)

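From these test-set results (precision, recall, F1), we can see that Snorkel is moderately effective at ignoring bad labeling functions. Adding a labeling function that emits random labels on ~20% of candidates barely changes F1 (0.525 vs. 0.529 for the original labeling functions). A random labeling function with 100% coverage, however, drops F1 to 0.385, mostly through lost recall (0.457 vs. 0.833). The results also show that the discriminative model learns more when a perfect labeling function with low coverage is included: the lookup model reaches the best F1, 0.540. Each number comes from a single training run with one seed, so small gaps such as 0.525 vs. 0.529 may be within run-to-run noise.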

In [15]:
# Per-candidate test-set marginals from each model, in the order
# orig, random, random_low, lookup (the confusion-count cell thresholds these at 0.5).
orig_preds, random_preds, random_low_preds, lookup_preds = (
    model.marginals(test)
    for model in (lstm_orig, lstm_random, lstm_random_low, lstm_lookup)
)


In [25]:
def confusion_counts(preds, gold, threshold=0.5):
    """Count (tp, fp, tn, fn) for thresholded predictions against gold labels.

    Args:
        preds: per-candidate scores; a candidate is predicted positive when its
            score is strictly greater than `threshold`.
        gold: gold-label matrix (scipy sparse or array-like), aligned row-by-row
            with `preds`. Its first column holds each candidate's label. Only
            a label equal to 1 counts as positive; every other value counts as
            negative. NOTE(review): that includes 0, which may mean "unlabeled"
            rather than negative (TODO confirm).
        threshold: decision threshold (default 0.5).

    Returns:
        Tuple of Python ints (tp, fp, tn, fn).

    Raises:
        ValueError: if `preds` and `gold` have different numbers of rows.
    """
    gold = gold.toarray() if hasattr(gold, "toarray") else np.asarray(gold)
    if gold.ndim == 1:
        gold = gold[:, None]
    gold_pos = gold[:, 0] == 1
    pred_pos = np.asarray(preds).ravel() > threshold
    if pred_pos.shape[0] != gold_pos.shape[0]:
        raise ValueError(
            "preds has {0} rows but gold has {1}".format(pred_pos.shape[0], gold_pos.shape[0])
        )
    tp = int(np.sum(pred_pos & gold_pos))
    fp = int(np.sum(pred_pos & ~gold_pos))
    tn = int(np.sum(~pred_pos & ~gold_pos))
    fn = int(np.sum(~pred_pos & gold_pos))
    return tp, fp, tn, fn


# Confusion counts for the lookup model. tp/fp/tn/fn remain globals because the
# precision/recall/F1 cells below read them.
# NOTE(review): the saved outputs of those later cells (counts 1261 1990 1179 253,
# F1 0.5293 = lstm_orig's test F1) were produced from orig_preds in an earlier
# kernel state, not from lookup_preds; re-run them after this cell.
tp, fp, tn, fn = confusion_counts(lookup_preds, L_gold_test)
n_scored = tp + fp + tn + fn
print((tp + tn) / n_scored if n_scored else float('nan'))  # accuracy
print(tp, fp, tn, fn)

0.535554131966688
1277 1938 1231 237


In [17]:
# Re-print the confusion counts (tp, fp, tn, fn) currently held by the kernel.
# NOTE(review): the saved output below differs from the counts printed by the
# confusion-count cell, so it reflects an earlier kernel state.
confusion = (tp, fp, tn, fn)
print(*confusion)

1261 1990 1179 253


In [18]:
# Precision and recall from the confusion counts. An empty denominator (no
# predicted positives, or no gold positives) yields 0.0 instead of raising
# ZeroDivisionError.
prec = tp / (tp + fp) if (tp + fp) else 0.0
rec = tp / (tp + fn) if (tp + fn) else 0.0

In [19]:
# F1 = harmonic mean of precision and recall, written as 2PR / (P + R) so that
# P = 0 or R = 0 gives 0.0 instead of ZeroDivisionError.
2 * prec * rec / (prec + rec) if (prec + rec) else 0.0

0.5292759706190976