# Train the Generative Model for Accurate Labeling

This notebook trains the generative model that Snorkel uses to estimate the probability that each candidate is a true candidate (label of 1).

## MUST RUN AT THE START OF EVERYTHING

Import the necessary modules and set up the database connection.

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from collections import Counter, OrderedDict, defaultdict
import os
import tqdm

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import average_precision_score, precision_recall_curve, roc_curve, auc

In [2]:
#Set up the environment
username = "danich1"
password = "snorkel"
dbname = "pubmeddb"

#Path subject to change for different os
database_str = "postgresql+psycopg2://{}:{}@/{}?host=/var/run/postgresql".format(username, password, dbname)
os.environ['SNORKELDB'] = database_str

from snorkel import SnorkelSession
session = SnorkelSession()
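
The cell above hardcodes the database credentials. As a minimal sketch of one alternative, the same URL can be assembled from environment variables. The variable names `PUBMED_DB_USER`, `PUBMED_DB_PASSWORD`, and `PUBMED_DB_NAME` and the helper `build_database_url` are hypothetical and not part of this project; the fallback values simply mirror the cell above.

```python
import os

def build_database_url(env=os.environ):
    """Build the PostgreSQL URL from (hypothetical) environment variables."""
    # The variable names and fallback values are assumptions for this sketch.
    username = env.get("PUBMED_DB_USER", "danich1")
    password = env.get("PUBMED_DB_PASSWORD", "snorkel")
    dbname = env.get("PUBMED_DB_NAME", "pubmeddb")
    # Same Unix-socket URL layout as the notebook cell.
    return "postgresql+psycopg2://{}:{}@/{}?host=/var/run/postgresql".format(
        username, password, dbname
    )

# Passing a plain dict makes the helper easy to try without touching the shell.
database_url = build_database_url({"PUBMED_DB_USER": "alice"})
print(database_url)
```

The result could then be assigned to `os.environ['SNORKELDB']` exactly as before.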

In [3]:
from snorkel import SnorkelSession
from snorkel.annotations import FeatureAnnotator, LabelAnnotator, save_marginals
from snorkel.learning import GenerativeModel
from snorkel.learning.utils import MentionScorer
from snorkel.models import Candidate, FeatureKey, candidate_subclass, Label
from snorkel.utils import get_as_dict
from tree_structs import corenlp_to_xmltree
from treedlib import compile_relation_feature_generator

In [4]:
edge_type = "dg"

In [5]:
if edge_type == "dg":
    DiseaseGene = candidate_subclass('DiseaseGene', ['Disease', 'Gene'])
elif edge_type == "gg":
    GeneGene = candidate_subclass('GeneGene', ['Gene1', 'Gene2'])
elif edge_type == "cg":
    CompoundGene = candidate_subclass('CompoundGene', ['Compound', 'Gene'])
elif edge_type == "cd":
    CompoundDisease = candidate_subclass('CompoundDisease', ['Compound', 'Disease'])
else:
    print("Please pick a valid edge type")

# Load preprocessed data 

This code loads the label matrix that was generated in the previous notebook ([Notebook 2](2.data-labeler.ipynb)). **Disclaimer**: this block might break, which means that the installed version of Snorkel is still using its old code. The problem with the old code is that SQLAlchemy attempts to load all the labels into memory. That is fine while the number of labels is small, but it does not scale as the number of labels grows. The good news is that there is a pull request to fix this issue. [Check it out here!](https://github.com/HazyResearch/snorkel/pull/789)
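
To illustrate the memory concern, here is a minimal sketch of splitting a list of candidate IDs into batches so that each query touches a bounded number of rows. This is not the approach taken in the pull request; the helper `chunked` and the batch size are hypothetical.

```python
def chunked(ids, batch_size):
    """Yield successive lists of at most batch_size IDs."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(ids), batch_size):
        yield ids[start:start + batch_size]

candidate_ids = list(range(1, 11))  # stand-in for real candidate IDs
batches = list(chunked(candidate_ids, 4))
for batch in batches:
    # In the notebook, each batch could feed a query such as
    # session.query(Candidate.id).filter(Candidate.id.in_(batch)).
    print(len(batch), batch)
```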

In [6]:
from snorkel.annotations import load_gold_labels
#L_gold_train = load_gold_labels(session, annotator_name='danich1', split=0)
#annotated_cands_train_ids = list(map(lambda x: L_gold_train.row_index[x], L_gold_train.nonzero()[0]))

sql = '''
SELECT candidate_id FROM gold_label
'''
gold_cids = [x[0] for x in session.execute(sql)]
cids = session.query(Candidate.id).filter(Candidate.id.in_(gold_cids))

L_gold_dev = load_gold_labels(session, annotator_name='danich1', cids_query=cids)
annotated_cands_dev_ids = list(map(lambda x: L_gold_dev.row_index[x], L_gold_dev.nonzero()[0]))

In [7]:
L_gold_dev

<1028x1 sparse matrix of type '<class 'numpy.int64'>'
	with 1028 stored elements in Compressed Sparse Row format>

In [8]:
%%time
labeler = LabelAnnotator(lfs=[])

# Load the label matrix for every candidate in the training split
#cids = session.query(Candidate.id).filter(Candidate.id.in_(annotated_cands_train_ids))
L_train = labeler.load_matrix(session, split=0)

# Only grab candidates that have human labels
cids = session.query(Candidate.id).filter(Candidate.id.in_(gold_cids))
L_dev = labeler.load_matrix(session, cids_query=cids)

CPU times: user 5.21 s, sys: 276 ms, total: 5.48 s
Wall time: 7.36 s


In [9]:
type(L_train)

snorkel.annotations.csr_LabelMatrix

In [10]:
print("Total Data Shape:")
print(L_train.shape)

Total Data Shape:
(2667604, 12)


In [11]:
# Keep only the candidates that received at least one non-zero label
L_train = L_train[np.unique(L_train.nonzero()[0]), :]
print("Total Data Shape:")
print(L_train.shape)

Total Data Shape:
(50897, 12)
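
The drop from 2,667,604 rows to 50,897 comes from keeping only the rows that contain at least one non-zero entry. A small sketch of the same indexing idiom on a toy dense NumPy array follows; the notebook applies it to Snorkel's sparse label matrix, and the toy values here are made up.

```python
import numpy as np

# Toy label matrix: rows are candidates, columns are labeling functions.
# 0 means the labeling function abstained.
toy_labels = np.array([
    [0, 0, 0],
    [1, 0, -1],
    [0, 0, 0],
    [0, 1, 0],
])

# nonzero()[0] lists the row index of every non-zero entry;
# np.unique removes the repeats and sorts the indices.
labeled_rows = np.unique(toy_labels.nonzero()[0])
filtered = toy_labels[labeled_rows, :]

print(labeled_rows)
print(filtered.shape)
```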


In [12]:
type(L_train)

snorkel.annotations.csr_AnnotationMatrix

# Train the Generative Model

Here is the first classification step of this project, where we train a generative model to estimate the correct label for each candidate. Snorkel's generative model uses Gibbs sampling on a [factor graph](http://deepdive.stanford.edu/assets/factor_graph.pdf) to estimate the probability that a candidate is a true candidate (label of 1).
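
To build intuition for how labeling function votes turn into a probability, here is a deliberately simplified sketch. It is not Snorkel's implementation: it assumes the labeling function accuracies are already known, that the labeling functions are conditionally independent, that both classes are equally likely a priori, and that abstaining carries no information. Under those assumptions each vote shifts the log-odds by `log(acc / (1 - acc))`. The helper `naive_marginal` and the example accuracies are hypothetical.

```python
import math

def naive_marginal(votes, accuracies):
    """P(label = 1 | votes) under a simplified independence model.

    votes: 1 (positive), -1 (negative), or 0 (abstain) per labeling function.
    accuracies: assumed accuracy of each labeling function, strictly in (0, 1).
    """
    log_odds = 0.0
    for vote, acc in zip(votes, accuracies):
        # A vote of 0 contributes nothing; +1 and -1 push in opposite directions.
        log_odds += vote * math.log(acc / (1.0 - acc))
    return 1.0 / (1.0 + math.exp(-log_odds))

# Two agreeing votes versus two conflicting votes of equal accuracy.
print(naive_marginal([1, 1, 0], [0.8, 0.8, 0.6]))
print(naive_marginal([1, -1, 0], [0.8, 0.8, 0.6]))
```

Two equally accurate labeling functions that disagree cancel out, which is why the second call lands on 0.5.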

In [13]:
%%time
from snorkel.learning import GenerativeModel

gen_model = GenerativeModel()
gen_model.train(
    L_train,
    epochs=30,
    decay=0.95,
    step_size=0.1 / L_train.shape[0],
    reg_param=1e-6,
    threads=50,
    verbose=True
)

Inferred cardinality: 2
FACTOR 0: STARTED BURN-IN...
FACTOR 0: DONE WITH BURN-IN
FACTOR 0: STARTED LEARNING
FACTOR 0: EPOCH #0
Current stepsize = 1.9647523429671693e-06
Learning epoch took 0.000 sec.
Weights:
    weightId: 0
        isFixed: True
        weight:  1.0

    weightId: 1
        isFixed: False
        weight:  0.0

    weightId: 2
        isFixed: True
        weight:  1.0

    weightId: 3
        isFixed: False
        weight:  0.0

    weightId: 4
        isFixed: True
        weight:  1.0

    weightId: 5
        isFixed: False
        weight:  0.0

    weightId: 6
        isFixed: True
        weight:  1.0

    weightId: 7
        isFixed: False
        weight:  0.0

    weightId: 8
        isFixed: True
        weight:  1.0

    weightId: 9
        isFixed: False
        weight:  0.0

    weightId: 10
        isFixed: True
        weight:  1.0

    weightId: 11
        isFixed: False
        weight:  0.0

    weightId: 12
        isFixed: True
        weight:  1.0

  

FACTOR 0: EPOCH #5
Current stepsize = 1.5202879098964573e-06
Learning epoch took 0.484 sec.
Weights:
    weightId: 0
        isFixed: True
        weight:  1.0

    weightId: 1
        isFixed: False
        weight:  -0.194106031533

    weightId: 2
        isFixed: True
        weight:  1.0

    weightId: 3
        isFixed: False
        weight:  -0.323181032502

    weightId: 4
        isFixed: True
        weight:  1.0

    weightId: 5
        isFixed: False
        weight:  -0.221495082114

    weightId: 6
        isFixed: True
        weight:  1.0

    weightId: 7
        isFixed: False
        weight:  -0.398343089035

    weightId: 8
        isFixed: True
        weight:  1.0

    weightId: 9
        isFixed: False
        weight:  -0.228546087805

    weightId: 10
        isFixed: True
        weight:  1.0

    weightId: 11
        isFixed: False
        weight:  -0.395878801009

    weightId: 12
        isFixed: True
        weight:  1.0

    weightId: 13
        isFixed: Fals

FACTOR 0: EPOCH #10
Current stepsize = 1.1763698041895959e-06
Learning epoch took 0.497 sec.
Weights:
    weightId: 0
        isFixed: True
        weight:  1.0

    weightId: 1
        isFixed: False
        weight:  -0.292663142003

    weightId: 2
        isFixed: True
        weight:  1.0

    weightId: 3
        isFixed: False
        weight:  -0.485607876139

    weightId: 4
        isFixed: True
        weight:  1.0

    weightId: 5
        isFixed: False
        weight:  -0.331006513028

    weightId: 6
        isFixed: True
        weight:  1.0

    weightId: 7
        isFixed: False
        weight:  -0.595648054229

    weightId: 8
        isFixed: True
        weight:  1.0

    weightId: 9
        isFixed: False
        weight:  -0.349219646669

    weightId: 10
        isFixed: True
        weight:  1.0

    weightId: 11
        isFixed: False
        weight:  -0.60295162193

    weightId: 12
        isFixed: True
        weight:  1.0

    weightId: 13
        isFixed: Fals

FACTOR 0: EPOCH #15
Current stepsize = 9.102525299325166e-07
Learning epoch took 0.486 sec.
Weights:
    weightId: 0
        isFixed: True
        weight:  1.0

    weightId: 1
        isFixed: False
        weight:  -0.349872212367

    weightId: 2
        isFixed: True
        weight:  1.0

    weightId: 3
        isFixed: False
        weight:  -0.575189299788

    weightId: 4
        isFixed: True
        weight:  1.0

    weightId: 5
        isFixed: False
        weight:  -0.394560592386

    weightId: 6
        isFixed: True
        weight:  1.0

    weightId: 7
        isFixed: False
        weight:  -0.69939591362

    weightId: 8
        isFixed: True
        weight:  1.0

    weightId: 9
        isFixed: False
        weight:  -0.425611629169

    weightId: 10
        isFixed: True
        weight:  1.0

    weightId: 11
        isFixed: False
        weight:  -0.716937521764

    weightId: 12
        isFixed: True
        weight:  1.0

    weightId: 13
        isFixed: False

FACTOR 0: EPOCH #20
Current stepsize = 7.043360559729294e-07
Learning epoch took 0.486 sec.
Weights:
    weightId: 0
        isFixed: True
        weight:  1.0

    weightId: 1
        isFixed: False
        weight:  -0.385385789927

    weightId: 2
        isFixed: True
        weight:  1.0

    weightId: 3
        isFixed: False
        weight:  -0.628089989154

    weightId: 4
        isFixed: True
        weight:  1.0

    weightId: 5
        isFixed: False
        weight:  -0.433546049161

    weightId: 6
        isFixed: True
        weight:  1.0

    weightId: 7
        isFixed: False
        weight:  -0.757616594319

    weightId: 8
        isFixed: True
        weight:  1.0

    weightId: 9
        isFixed: False
        weight:  -0.477052007237

    weightId: 10
        isFixed: True
        weight:  1.0

    weightId: 11
        isFixed: False
        weight:  -0.785818469813

    weightId: 12
        isFixed: True
        weight:  1.0

    weightId: 13
        isFixed: Fals

FACTOR 0: EPOCH #25
Current stepsize = 5.450018137057857e-07
Learning epoch took 0.520 sec.
Weights:
    weightId: 0
        isFixed: True
        weight:  1.0

    weightId: 1
        isFixed: False
        weight:  -0.409708815903

    weightId: 2
        isFixed: True
        weight:  1.0

    weightId: 3
        isFixed: False
        weight:  -0.66171841337

    weightId: 4
        isFixed: True
        weight:  1.0

    weightId: 5
        isFixed: False
        weight:  -0.459073142185

    weightId: 6
        isFixed: True
        weight:  1.0

    weightId: 7
        isFixed: False
        weight:  -0.792522896221

    weightId: 8
        isFixed: True
        weight:  1.0

    weightId: 9
        isFixed: False
        weight:  -0.514057747461

    weightId: 10
        isFixed: True
        weight:  1.0

    weightId: 11
        isFixed: False
        weight:  -0.829101130468

    weightId: 12
        isFixed: True
        weight:  1.0

    weightId: 13
        isFixed: False

FACTOR 0: DONE WITH LEARNING
CPU times: user 1min 8s, sys: 128 ms, total: 1min 8s
Wall time: 27.3 s
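
The step sizes printed in the log are consistent with a geometric schedule: the initial value `0.1 / L_train.shape[0]` multiplied by `decay` once per epoch. The sketch below reproduces that arithmetic. The schedule is inferred from the logged values rather than taken from Snorkel's source, and the helper name `step_size_at` is hypothetical.

```python
def step_size_at(epoch, n_rows=50897, base=0.1, decay=0.95):
    """Step size after `epoch` epochs of geometric decay."""
    return (base / n_rows) * decay ** epoch

# Same epochs that appear in the training log.
for epoch in (0, 5, 10, 15, 20, 25):
    print(epoch, step_size_at(epoch))
```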


In [14]:
gen_model.weights.lf_accuracy

array([ 0.57421705,  0.31635338,  0.5233131 ,  0.18443137,  0.46022959,
        0.14095574,  0.18862038,  0.21427893,  0.12856828,  0.1330774 ,
        0.13257227,  0.13153372])

In [15]:
from utils.disease_gene_lf import LFS
learned_stats_df = gen_model.learned_lf_stats()
learned_stats_df.index = list(LFS)
learned_stats_df

| | Accuracy | Coverage | Precision | Recall |
|---|---|---|---|---|
| LF_HETNET_DISEASES | 0.755543 | 0.6991 | 0.756112 | 0.526657 |
| LF_HETNET_DOAF | 0.644343 | 0.6824 | 0.639368 | 0.443138 |
| LF_HETNET_DisGeNET | 0.739646 | 0.7002 | 0.738457 | 0.521995 |
| LF_HETNET_GWAS | 0.592681 | 0.6695 | 0.587654 | 0.397527 |
| LF_HETNET_ABSENT | 0.711108 | 0.6878 | 0.700117 | 0.483681 |
| LF_CHECK_GENE_TAG | 0.567519 | 0.6687 | 0.559432 | 0.383539 |
| LF_IS_BIOMARKER | 0.588699 | 0.6725 | 0.58637 | 0.401176 |
| LF_ASSOCIATION | 0.600765 | 0.6793 | 0.597742 | 0.407865 |
| LF_NO_ASSOCIATION | 0.573929 | 0.6628 | 0.566328 | 0.372998 |
| LF_NO_CONCLUSION | 0.571049 | 0.6777 | 0.565584 | 0.381107 |


In [16]:
%time train_marginals = gen_model.marginals(L_train)

CPU times: user 19.2 s, sys: 0 ns, total: 19.2 s
Wall time: 19.2 s


In [None]:
print(len(train_marginals[train_marginals > 0.5]))

In [None]:
plt.hist(train_marginals, bins=20)
plt.title("Training Marginals for Gibbs Sampler")
plt.show()

In [None]:
tp, fp, tn, fn = gen_model.error_analysis(session, L_dev, L_gold_dev)

In [None]:
dev_marginals = gen_model.marginals(L_dev)

In [None]:
fpr, tpr, threshold = roc_curve(L_gold_dev.todense(), dev_marginals)
plt.plot([0,1], [0,1])
plt.plot(fpr, tpr, label='AUC {:.2f}'.format(auc(fpr, tpr)))
plt.legend()

In [None]:
from snorkel.viewer import SentenceNgramViewer

# NOTE: This if-then statement is only to avoid opening the viewer during automated testing of this notebook
# You should ignore this!
import os
if 'CI' not in os.environ:
    sv = SentenceNgramViewer(fp, session)
else:
    sv = None

In [None]:
sv

In [None]:
c = sv.get_selected() if sv else list(fp.union(fn))[0]
c

In [None]:
rows = list()
for i in tqdm.tqdm(range(L_train.shape[0])):
    row = OrderedDict()
    candidate = L_train.get_candidate(session, i)
    row['candidate_id'] = candidate.id
    row['disease'] = candidate[0].get_span()
    row['gene'] = candidate[1].get_span()
    row['doid_id'] = candidate.Disease_cid
    row['entrez_gene_id'] = candidate.Gene_cid
    row['sentence'] = candidate.get_parent().text
    row['label'] = train_marginals[i]
    rows.append(row)
sentence_df = pd.DataFrame(rows)

sentence_df = pd.concat([
    sentence_df,
    pd.DataFrame(L_train.todense(), columns=list(LFS))
], axis='columns')

sentence_df.tail()

 89%|████████▉ | 45388/50897 [03:44<00:27, 202.30it/s]

In [None]:
writer = pd.ExcelWriter('data/sentence-labels.xlsx')
sentence_df.to_excel(writer, sheet_name='sentences', index=False)
if writer.engine == 'xlsxwriter':
    for sheet in writer.sheets.values():
        sheet.freeze_panes(1, 0)
writer.close()

In [None]:
sentence_df

In [None]:
c.labels

In [None]:
L_dev.get_row_index(c)

In [None]:
# Look up the marginal for the selected candidate instead of a hard-coded row
dev_marginals[L_dev.get_row_index(c)]

In [None]:
L_dev.lf_stats(session, L_gold_dev[L_gold_dev!=0].T, gen_model.learned_lf_stats()['Accuracy'])

# Save Training Marginals

Save the training marginals for [Notebook 4](4.data-disc-model).

In [None]:
np.savetxt("vanilla_lstm/lstm_disease_gene_holdout/subsampled/train_marginals_subsampled.txt", train_marginals)
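
As a quick sanity check on the file format, the following sketch writes a few made-up values with `np.savetxt` and reads them back with `np.loadtxt`. That the downstream notebook reads the file with `np.loadtxt` is an assumption; the temporary path is used only for this illustration.

```python
import os
import tempfile

import numpy as np

marginals = np.array([0.12, 0.5, 0.97])  # made-up stand-in values

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "train_marginals.txt")
    np.savetxt(path, marginals)  # plain text, one value per line
    restored = np.loadtxt(path)

print(np.allclose(marginals, restored))
```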

In [None]:
#%time save_marginals(session, L_train, train_marginals)