# Intro. to Snorkel: Extracting Spouse Relations from the News

## Part III: Training an End Extraction Model

In this final part of the tutorial, we'll use the noisy training labels generated in the previous part to train our end extraction model.

For this tutorial, we will train a Bi-LSTM, a bidirectional recurrent neural network. The `LSTM` class imported below comes from `snorkel.learning.pytorch`, Snorkel's [PyTorch](https://pytorch.org/) model package.

In [23]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os

# TO USE A DATABASE OTHER THAN SQLITE, USE THIS LINE
# Note that this is necessary for parallel execution amongst other things...
# os.environ['SNORKELDB'] = 'postgres:///snorkel-intro'

from snorkel import SnorkelSession
from snorkel.learning.pytorch import LSTM

session = SnorkelSession()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


We repeat our definition of the `Spouse` `Candidate` subclass:

In [3]:
from snorkel.models import candidate_subclass

Spouse = candidate_subclass('Spouse', ['person1', 'person2'])

In [15]:
from snorkel.learning import GenerativeModel

gen_model = GenerativeModel()

In [19]:
import scipy.sparse as sp
L_train = sp.load_npz('/Users/ineschami/Desktop/spouses_data/train_label_matrix.npz')
label_accuracies = [0.5332131048993087, 0.536837376460018, 0.5586031175059952,
                    0.5579946603381786, 0.5283189850475759, 0.8033016228315614,
                    0.5427369686044765, 0.5838184803702046, 0.5648341798625635,
                    0.5423094203983825]

In [22]:
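import numpy as np

def get_marginals(alphas, preds):
    w = np.log(alphas / (1. - alphas))  # log-odds weight per labeling function (alphas: NumPy array)
    scores = preds.dot(w)
    return 1. / (1. + np.exp(-scores))  # assumed completion: sigmoid of the weighted votes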

10

We reload the probabilistic training labels:

In [4]:
from snorkel.annotations import load_marginals

train_marginals = load_marginals(session, split=0)
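
To make "training on probabilistic labels" concrete, here is a minimal sketch of a cross-entropy loss computed against soft targets rather than hard 0/1 labels. It is an illustration only, not Snorkel's implementation; the function name `noise_aware_loss` and the example values are hypothetical.

```python
import numpy as np

def noise_aware_loss(marginals, predicted_probs, eps=1e-12):
    """Mean cross-entropy between soft targets and predicted probabilities.

    marginals:       P(y = 1) per candidate from the label model, in [0, 1]
    predicted_probs: the classifier's P(y = 1) for the same candidates
    """
    y = np.asarray(marginals, dtype=float)
    # Clip predictions away from 0 and 1 so the logarithms stay finite
    p = np.clip(np.asarray(predicted_probs, dtype=float), eps, 1.0 - eps)
    return float(np.mean(-(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))))

# Hypothetical values for illustration only
example_marginals = [0.9, 0.2, 0.5]
matched_probs = [0.9, 0.2, 0.5]        # predictions equal to the soft targets
overconfident_probs = [1.0, 0.0, 1.0]  # predictions pushed to hard 0/1

loss_matched = noise_aware_loss(example_marginals, matched_probs)
loss_overconfident = noise_aware_loss(example_marginals, overconfident_probs)
print(loss_matched, loss_overconfident)
```

With soft targets, the loss is smallest when the predictions match the marginals, so a model is penalized for being more certain than its training labels.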

We also reload the candidates:

In [5]:
train_cands = session.query(Spouse).filter(Spouse.split == 0).order_by(Spouse.id).all()
dev_cands   = session.query(Spouse).filter(Spouse.split == 1).order_by(Spouse.id).all()
test_cands  = session.query(Spouse).filter(Spouse.split == 2).order_by(Spouse.id).all()

In [6]:
x = dev_cands[0]

In [7]:
x.person1

Span("b'Pope Francis'", sentence=11836, chars=[50,61], words=[10,11])

In [8]:
x.person2

Span("b'Pope Francis'", sentence=11836, chars=[428,439], words=[81,82])

In [9]:
x.get_parent().text

'Hers was the first execution of a US inmate since Pope Francis called for the global abolition of the death penalty in his speech to the US Congress last week The pope’s personal representative sent a letter to Georgia’s parole board on Tuesday making “an urgent appeal” to commute Gissendaner’s sentence to “one that would better express both justice and mercy” “Please be assured of my prayers as you consider this request by Pope Francis for what I believe would be a just act of clemency”'

In [10]:
y = x.get_parent()

Finally, we load gold labels for evaluation:

In [11]:
from snorkel.annotations import load_gold_labels

L_gold_dev  = load_gold_labels(session, annotator_name='gold', split=1)
L_gold_test = load_gold_labels(session, annotator_name='gold', split=2)

import scipy.sparse as sp
sp.save_npz('/Users/ineschami/Desktop/spouses_data/dev_labels.npz', L_gold_dev)

Now we can set up our discriminative model. Here we specify the model and learning hyperparameters.

They can also be set automatically using a search based on the dev set with a [GridSearch](https://github.com/HazyResearch/snorkel/blob/master/snorkel/learning/utils.py) object.
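
As a rough illustration of what such a search does, the sketch below tries every combination of a small hyperparameter grid and keeps the best-scoring one. It does not use Snorkel's `GridSearch` class; `grid_search` and `toy_dev_f1` are hypothetical names, and the toy scoring function stands in for "train the model and return its dev-set F1".

```python
from itertools import product

def grid_search(train_and_score, param_grid):
    """Try every combination in param_grid; return (best_params, best_score)."""
    names = sorted(param_grid)
    best_params, best_score = None, float('-inf')
    for values in product(*(param_grid[name] for name in names)):
        params = dict(zip(names, values))
        score = train_and_score(**params)
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score

# Hypothetical stand-in for training a model and scoring it on the dev set.
# In practice this would train with the given hyperparameters and return dev F1.
def toy_dev_f1(lr, dropout):
    return 1.0 - abs(lr - 0.01) * 10 - abs(dropout - 0.25)

param_grid = {'lr': [0.001, 0.01, 0.1], 'dropout': [0.0, 0.25, 0.5]}
best_params, best_score = grid_search(toy_dev_f1, param_grid)
print(best_params, best_score)
```

An exhaustive grid trains one model per combination, so the cost grows multiplicatively with each added hyperparameter. With training runs of roughly 15 minutes, as in the log for this model, a small grid is advisable.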

In [12]:
dim = 50

train_kwargs = {
    'lr':            0.01,
    'embedding_dim': dim,
    'hidden_dim':    dim,
    'n_epochs':      10,
    'dropout':       0.25,
    'seed':          1701
}

lstm = LSTM(n_threads=None)
lstm.train(train_cands, train_marginals, X_dev=dev_cands, Y_dev=L_gold_dev, **train_kwargs)

[LSTM] Training model
[LSTM] n_train=17259  #epochs=10  batch size=64




[LSTM] Epoch 1 (83.68s)	Average loss=0.626622	Dev F1=0.00
[LSTM] Epoch 2 (169.77s)	Average loss=0.605108	Dev F1=37.11
[LSTM] Epoch 3 (268.44s)	Average loss=0.598997	Dev F1=36.27
[LSTM] Epoch 4 (360.29s)	Average loss=0.595865	Dev F1=39.21
[LSTM] Epoch 5 (455.22s)	Average loss=0.592701	Dev F1=38.00
[LSTM] Epoch 6 (551.68s)	Average loss=0.589949	Dev F1=38.26
[LSTM] Epoch 7 (640.61s)	Average loss=0.587997	Dev F1=38.51
[LSTM] Epoch 8 (747.17s)	Average loss=0.586224	Dev F1=36.29
[LSTM] Epoch 9 (851.04s)	Average loss=0.585057	Dev F1=40.97
[LSTM] Model saved as <LSTM>
[LSTM] Epoch 10 (943.24s)	Average loss=0.583684	Dev F1=37.08
[LSTM] Training done (949.21s)
[LSTM] Loaded model <LSTM>
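
The log shows a single save after epoch 9, the epoch with the highest dev F1, and a reload once training finishes. This is consistent with keeping the best-scoring checkpoint rather than the final one. A minimal sketch of that selection logic, using the F1 values copied from the log (the helper `best_epoch` is hypothetical and not part of Snorkel):

```python
# Dev F1 per epoch, copied from the training log above
dev_f1_by_epoch = {
    1: 0.00, 2: 37.11, 3: 36.27, 4: 39.21, 5: 38.00,
    6: 38.26, 7: 38.51, 8: 36.29, 9: 40.97, 10: 37.08,
}

def best_epoch(scores):
    """Return (epoch, score) with the highest score; the earliest epoch wins ties."""
    return max(sorted(scores.items()), key=lambda item: item[1])

epoch, f1 = best_epoch(dev_f1_by_epoch)
print("Best dev F1 {:.2f} at epoch {}".format(f1, epoch))
```

Selecting on dev F1 matters here because the score fluctuates between epochs even while the average training loss keeps decreasing.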


Now, we get the precision, recall, and F1 score from the discriminative model:

In [13]:
p, r, f1 = lstm.score(test_cands, L_gold_test)
print("Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(p, r, f1))

Prec: 0.375, Recall: 0.628, F1 Score: 0.470


We can also get the candidates returned in sets (true positives, false positives, true negatives, false negatives) as well as a more detailed score report:

In [14]:
tp, fp, tn, fn = lstm.error_analysis(session, test_cands, L_gold_test)

Scores (Un-adjusted)
Pos. class accuracy: 0.628
Neg. class accuracy: 0.908
Precision            0.375
Recall               0.628
F1                   0.47
----------------------------------------
TP: 137 | FP: 228 | TN: 2255 | FN: 81
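
The scores in this report follow directly from the four counts on the last line. A short sketch that recomputes them (the helper `binary_metrics` is hypothetical and assumes non-zero denominators):

```python
def binary_metrics(tp, fp, tn, fn):
    """Recompute the report's scores from confusion-matrix counts."""
    precision = tp / float(tp + fp)      # share of predicted positives that are correct
    recall = tp / float(tp + fn)         # share of true positives found (pos. class accuracy)
    f1 = 2 * precision * recall / (precision + recall)
    neg_accuracy = tn / float(tn + fp)   # share of true negatives found (neg. class accuracy)
    return precision, recall, f1, neg_accuracy

# Counts copied from the error analysis output above
precision, recall, f1, neg_accuracy = binary_metrics(tp=137, fp=228, tn=2255, fn=81)
print("Prec: {:.3f}, Recall: {:.3f}, F1: {:.3f}, Neg. acc.: {:.3f}".format(
    precision, recall, f1, neg_accuracy))
```

This gives 137/365 ≈ 0.375 for precision, 137/218 ≈ 0.628 for recall, an F1 of about 0.470, and 2255/2483 ≈ 0.908 for negative-class accuracy, matching the report.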



Note that if this is the test set on which you will report final numbers, you should not inspect individual results, since doing so can bias them. Instead, run the model on your _development set_ and inspect examples there for error analysis, as we did in the previous part with the generative labeling function model.

Training for more epochs may also improve performance: in the log above, the average loss is still decreasing at epoch 10, although the dev F1 fluctuates from epoch to epoch.

Finally, we can save the predictions of the model on the test set back to the database. (This also works for other candidate sets, such as unlabeled candidates.)

In [9]:
lstm.save_marginals(session, test_cands)

Saved 2424 marginals


##### More importantly, you completed the introduction to Snorkel! Give yourself a pat on the back!