# Train the Generative Model for Accurate Labeling

This notebook is designed to run the generative model snorkel uses for estimating the probability of each candidate being a true candidate (label of 1). 

## MUST RUN AT THE START OF EVERYTHING

Import the necessary modules and set up the database for database operations.

In [25]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from collections import Counter, OrderedDict, defaultdict
import os
import tqdm

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import average_precision_score, precision_recall_curve, roc_curve, auc

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [26]:
#Set up the environment
username = "danich1"
password = "snorkel"
dbname = "pubmeddb"

#Path subject to change for different os
database_str = "postgresql+psycopg2://{}:{}@/{}?host=/var/run/postgresql".format(username, password, dbname)
os.environ['SNORKELDB'] = database_str

from snorkel import SnorkelSession
session = SnorkelSession()

In [27]:
from snorkel import SnorkelSession
from snorkel.annotations import FeatureAnnotator, LabelAnnotator, save_marginals
from snorkel.learning import GenerativeModel
from snorkel.learning.utils import MentionScorer
from snorkel.models import Candidate, FeatureKey, candidate_subclass, Label
from snorkel.utils import get_as_dict
from tree_structs import corenlp_to_xmltree
from treedlib import compile_relation_feature_generator

In [28]:
edge_type = "dg"

In [29]:
if edge_type == "dg":
    DiseaseGene = candidate_subclass('DiseaseGene', ['Disease', 'Gene'])
elif edge_type == "gg":
    GeneGene = candidate_subclass('GeneGene', ['Gene1', 'Gene2'])
elif edge_type == "cg":
    CompoundGene = candidate_subclass('CompoundGene', ['Compound', 'Gene'])
elif edge_type == "cd":
    CompoundDisease = candidate_subclass('CompoundDisease', ['Compound', 'Disease'])
else:
    print("Please pick a valid edge type")

# Load preprocessed data 

This code will load the label matrix that was generated in the previous notebook ([Notebook 2](2.data-labeler.ipynb)). **Disclaimer**: this block might break, which means that the snorkel code is still using its old code. The problem with the old code is that sqlalchemy will attempt to load all the labels into memory. Doesn't sound bad if you keep the amount of labels small, but doesn't scale when the amount of labels increases exponentially. Good news is that there is a pull request to fix this issue. [Check it out here!](https://github.com/HazyResearch/snorkel/pull/789)

In [30]:
from snorkel.annotations import load_gold_labels
#L_gold_train = load_gold_labels(session, annotator_name='danich1', split=0)
#annotated_cands_train_ids = list(map(lambda x: L_gold_train.row_index[x], L_gold_train.nonzero()[0]))

sql = '''
SELECT candidate_id FROM gold_label
'''
gold_cids = [x[0] for x in session.execute(sql)]
cids = session.query(Candidate.id).filter(Candidate.id.in_(gold_cids))

L_gold_dev = load_gold_labels(session, annotator_name='danich1', cids_query=cids)
annotated_cands_dev_ids = list(map(lambda x: L_gold_dev.row_index[x], L_gold_dev.nonzero()[0]))

In [31]:
L_gold_dev

<1028x1 sparse matrix of type '<class 'numpy.int64'>'
	with 1028 stored elements in Compressed Sparse Row format>

In [32]:
train_candidate_ids = np.loadtxt('data/labeled_candidates.txt').astype(int).tolist()
train_candidate_ids

[9951794,
 904609,
 5192262,
 14552559,
 16277239,
 7513663,
 26709637,
 18498661,
 31276326,
 7508019,
 21182051,
 8718860,
 29420557,
 4576271,
 16265780,
 14579509,
 28546887,
 903561,
 2554892,
 11357489,
 6320225,
 10340995,
 1234711,
 27189858,
 23931639,
 27643583,
 8312730,
 24407247,
 12368020,
 17597476,
 20275693,
 23479706,
 28127286,
 34028805,
 33602727,
 25305237,
 16280703,
 33582365,
 2508141,
 34965519,
 5925834,
 26266918,
 24408282,
 17606676,
 9530917,
 35862276,
 28546709,
 4565185,
 7910524,
 23028505,
 18512660,
 27199812,
 29002515,
 19823705,
 35883032,
 1236055,
 34493648,
 33118337,
 4842917,
 32245680,
 31265266,
 33580470,
 32240426,
 23947258,
 13230693,
 22092561,
 22583261,
 7508178,
 6331841,
 33571610,
 35826537,
 9121725,
 4589207,
 27213249,
 33562081,
 2868939,
 31282135,
 9554292,
 25760465,
 16758106,
 34031511,
 24852520,
 4845785,
 7901313,
 27162063,
 30814233,
 31307332,
 28997960,
 20748470,
 8705830,
 29878788,
 7509402,
 17162052,
 1160875

In [33]:
dev_candidate_ids = np.loadtxt('data/labeled_dev_candidates.txt').astype(int).tolist()
dev_candidate_ids

[11196390,
 6691154,
 3718558,
 32199854,
 10365368,
 26731372,
 34517639,
 31299484,
 28967497,
 6716878,
 31770736,
 22539824,
 30827155,
 19827256,
 36390261,
 2176714,
 34075897,
 24867987,
 32229006,
 20288074,
 27674474,
 5194638,
 34481280,
 30796279,
 20283686,
 19353970,
 31287759,
 28544792,
 29430883,
 25321861,
 31267939,
 35826272,
 32205988,
 35392803,
 29892314,
 15833209,
 30352907,
 25755526,
 1416538,
 30825579,
 22569655,
 33611834,
 23045128,
 20730698,
 14548973,
 36393022,
 15848313,
 29466814,
 16751671,
 25757579,
 29417504,
 19372251,
 17614435,
 27635071,
 31735057,
 2268481,
 29002941,
 4240356,
 2178759,
 24393749,
 21639315,
 18489895,
 22096868,
 1227925,
 2551318,
 19380308,
 32663841,
 26739201,
 34040283,
 34933749,
 2508582,
 11330039,
 27674759,
 7911940,
 28978390,
 34032626,
 31300733,
 16740194,
 33153301,
 15367137,
 2885715,
 15406889,
 912245,
 1079320,
 21211616,
 17161344,
 22549772,
 33574221,
 18069263,
 19839727,
 33111294,
 5184666,
 48592

In [34]:
%%time
labeler = LabelAnnotator(lfs=[])

# Only grab candidates that have human labels
cids = session.query(Candidate.id).filter(Candidate.id.in_(train_candidate_ids))
L_train = labeler.load_matrix(session, split=0) #

cids = session.query(Candidate.id).filter(Candidate.id.in_(dev_candidate_ids))
L_dev = labeler.load_matrix(session,cids_query=cids)

CPU times: user 6.33 s, sys: 235 ms, total: 6.56 s
Wall time: 8.54 s


In [35]:
print("Total Data Shape:")
print(L_train.shape)

Total Data Shape:
(2667604, 12)


In [36]:
L_train = L_train[np.unique(L_train.nonzero()[0]), :]
print("Total Data Shape:")
print(L_train.shape)

Total Data Shape:
(60713, 12)


In [37]:
L_dev.shape

(10000, 12)

# Train the Generative Model

Here is the first step of classification step of this project, where we train a gnerative model to discriminate the correct label each candidate will receive. Snorkel's generative model uses a Gibbs Sampling on a [factor graph](http://deepdive.stanford.edu/assets/factor_graph.pdf), to generate the probability of a potential candidate being a true candidate (label of 1).

In [38]:
%%time
from snorkel.learning import GenerativeModel

gen_model = GenerativeModel()
gen_model.train(
    L_train,
    epochs=30,
    decay=0.95,
    step_size=0.1 / L_train.shape[0],
    reg_param=1e-6,
    threads=50,
    verbose=True
)

Inferred cardinality: 2
FACTOR 0: STARTED BURN-IN...
FACTOR 0: DONE WITH BURN-IN
FACTOR 0: STARTED LEARNING
FACTOR 0: EPOCH #0
Current stepsize = 1.647093703160773e-06
Learning epoch took 0.000 sec.
Weights:
    weightId: 0
        isFixed: True
        weight:  1.0

    weightId: 1
        isFixed: False
        weight:  0.0

    weightId: 2
        isFixed: True
        weight:  1.0

    weightId: 3
        isFixed: False
        weight:  0.0

    weightId: 4
        isFixed: True
        weight:  1.0

    weightId: 5
        isFixed: False
        weight:  0.0

    weightId: 6
        isFixed: True
        weight:  1.0

    weightId: 7
        isFixed: False
        weight:  0.0

    weightId: 8
        isFixed: True
        weight:  1.0

    weightId: 9
        isFixed: False
        weight:  0.0

    weightId: 10
        isFixed: True
        weight:  1.0

    weightId: 11
        isFixed: False
        weight:  0.0

    weightId: 12
        isFixed: True
        weight:  1.0

   

FACTOR 0: EPOCH #5
Current stepsize = 1.2744897097820895e-06
Learning epoch took 0.561 sec.
Weights:
    weightId: 0
        isFixed: True
        weight:  1.0

    weightId: 1
        isFixed: False
        weight:  -0.207619853788

    weightId: 2
        isFixed: True
        weight:  1.0

    weightId: 3
        isFixed: False
        weight:  -0.328775674686

    weightId: 4
        isFixed: True
        weight:  1.0

    weightId: 5
        isFixed: False
        weight:  -0.229950323347

    weightId: 6
        isFixed: True
        weight:  1.0

    weightId: 7
        isFixed: False
        weight:  -0.400937297833

    weightId: 8
        isFixed: True
        weight:  1.0

    weightId: 9
        isFixed: False
        weight:  -0.232874897344

    weightId: 10
        isFixed: True
        weight:  1.0

    weightId: 11
        isFixed: False
        weight:  -0.397515748921

    weightId: 12
        isFixed: True
        weight:  1.0

    weightId: 13
        isFixed: Fals

FACTOR 0: EPOCH #10
Current stepsize = 9.86175842469288e-07
Learning epoch took 0.566 sec.
Weights:
    weightId: 0
        isFixed: True
        weight:  1.0

    weightId: 1
        isFixed: False
        weight:  -0.304086121386

    weightId: 2
        isFixed: True
        weight:  1.0

    weightId: 3
        isFixed: False
        weight:  -0.491463653317

    weightId: 4
        isFixed: True
        weight:  1.0

    weightId: 5
        isFixed: False
        weight:  -0.33708415281

    weightId: 6
        isFixed: True
        weight:  1.0

    weightId: 7
        isFixed: False
        weight:  -0.596765009296

    weightId: 8
        isFixed: True
        weight:  1.0

    weightId: 9
        isFixed: False
        weight:  -0.355910736343

    weightId: 10
        isFixed: True
        weight:  1.0

    weightId: 11
        isFixed: False
        weight:  -0.602086883264

    weightId: 12
        isFixed: True
        weight:  1.0

    weightId: 13
        isFixed: False


FACTOR 0: EPOCH #15
Current stepsize = 7.630840679257377e-07
Learning epoch took 0.567 sec.
Weights:
    weightId: 0
        isFixed: True
        weight:  1.0

    weightId: 1
        isFixed: False
        weight:  -0.359963924932

    weightId: 2
        isFixed: True
        weight:  1.0

    weightId: 3
        isFixed: False
        weight:  -0.580056237959

    weightId: 4
        isFixed: True
        weight:  1.0

    weightId: 5
        isFixed: False
        weight:  -0.398430096947

    weightId: 6
        isFixed: True
        weight:  1.0

    weightId: 7
        isFixed: False
        weight:  -0.699382173428

    weightId: 8
        isFixed: True
        weight:  1.0

    weightId: 9
        isFixed: False
        weight:  -0.431782138958

    weightId: 10
        isFixed: True
        weight:  1.0

    weightId: 11
        isFixed: False
        weight:  -0.717578379896

    weightId: 12
        isFixed: True
        weight:  1.0

    weightId: 13
        isFixed: Fals

FACTOR 0: EPOCH #20
Current stepsize = 5.904599054708908e-07
Learning epoch took 0.589 sec.
Weights:
    weightId: 0
        isFixed: True
        weight:  1.0

    weightId: 1
        isFixed: False
        weight:  -0.395400868764

    weightId: 2
        isFixed: True
        weight:  1.0

    weightId: 3
        isFixed: False
        weight:  -0.632254520835

    weightId: 4
        isFixed: True
        weight:  1.0

    weightId: 5
        isFixed: False
        weight:  -0.436901850296

    weightId: 6
        isFixed: True
        weight:  1.0

    weightId: 7
        isFixed: False
        weight:  -0.758131755152

    weightId: 8
        isFixed: True
        weight:  1.0

    weightId: 9
        isFixed: False
        weight:  -0.482721755809

    weightId: 10
        isFixed: True
        weight:  1.0

    weightId: 11
        isFixed: False
        weight:  -0.78612076049

    weightId: 12
        isFixed: True
        weight:  1.0

    weightId: 13
        isFixed: False

FACTOR 0: EPOCH #25
Current stepsize = 4.568866192114272e-07
Learning epoch took 0.594 sec.
Weights:
    weightId: 0
        isFixed: True
        weight:  1.0

    weightId: 1
        isFixed: False
        weight:  -0.418209800574

    weightId: 2
        isFixed: True
        weight:  1.0

    weightId: 3
        isFixed: False
        weight:  -0.66503597585

    weightId: 4
        isFixed: True
        weight:  1.0

    weightId: 5
        isFixed: False
        weight:  -0.462593567438

    weightId: 6
        isFixed: True
        weight:  1.0

    weightId: 7
        isFixed: False
        weight:  -0.794257864933

    weightId: 8
        isFixed: True
        weight:  1.0

    weightId: 9
        isFixed: False
        weight:  -0.518388726075

    weightId: 10
        isFixed: True
        weight:  1.0

    weightId: 11
        isFixed: False
        weight:  -0.829766099942

    weightId: 12
        isFixed: True
        weight:  1.0

    weightId: 13
        isFixed: False

FACTOR 0: DONE WITH LEARNING
CPU times: user 1min 20s, sys: 73.1 ms, total: 1min 20s
Wall time: 31.7 s


In [39]:
gen_model.weights.lf_accuracy

array([ 0.56592394,  0.31341156,  0.51959085,  0.18237051,  0.45504092,
        0.14102511,  0.19244823,  0.21139602,  0.1286075 ,  0.13312386,
        0.13173062,  0.13265822])

In [40]:
from utils.disease_gene_lf import LFS
learned_stats_df = gen_model.learned_lf_stats()
learned_stats_df.index = list(LFS)
learned_stats_df

Unnamed: 0,Accuracy,Coverage,Precision,Recall
LF_HETNET_DISEASES,0.754964,0.695,0.745596,0.524584
LF_HETNET_DOAF,0.645704,0.6774,0.637664,0.434783
LF_HETNET_DisGeNET,0.733629,0.6979,0.732708,0.507924
LF_HETNET_GWAS,0.596017,0.6629,0.592514,0.398822
LF_HETNET_ABSENT,0.717085,0.6889,0.707895,0.491873
LF_CHECK_GENE_TAG,0.56829,0.6692,0.561961,0.384193
LF_IS_BIOMARKER,0.592631,0.6704,0.584324,0.389273
LF_ASSOCIATION,0.601773,0.6657,0.590622,0.399228
LF_NO_ASSOCIATION,0.567875,0.6674,0.556822,0.387241
LF_NO_CONCLUSION,0.567699,0.6588,0.561908,0.361438


In [41]:
%time train_marginals = gen_model.marginals(L_train)

CPU times: user 22.5 s, sys: 0 ns, total: 22.5 s
Wall time: 22.5 s


In [None]:
print(len(train_marginals[train_marginals > 0.5]))

In [None]:
plt.hist(train_marginals, bins=20)
plt.title("Training Marginals for Gibbs Sampler")
plt.show()

## ROC of Generative Model

In [42]:
dev_marginals = gen_model.marginals(L_dev)

In [None]:
fpr, tpr, threshold = roc_curve(L_gold_dev.todense(), dev_marginals)
plt.plot([0,1], [0,1])
plt.plot(fpr, tpr, label='AUC {:.2f}'.format(auc(fpr, tpr)))
plt.legend()

In [None]:
L_dev.lf_stats(session, L_gold_dev[L_gold_dev!=0].T, gen_model.learned_lf_stats()['Accuracy'])

## Individual Candidate Error Analysis

In [None]:
tp, fp, tn, fn = gen_model.error_analysis(session, L_dev, L_gold_dev)

In [None]:
from snorkel.viewer import SentenceNgramViewer

# NOTE: This if-then statement is only to avoid opening the viewer during automated testing of this notebook
# You should ignore this!
import os
if 'CI' not in os.environ:
    sv = SentenceNgramViewer(fp, session)
else:
    sv = None

In [None]:
sv

In [None]:
c = sv.get_selected() if sv else list(fp.union(fn))[0]
c

In [None]:
c.labels

## Generate Excel File of Train Data

In [43]:
pair_df = pd.read_csv("data/disease-gene-pairs-association.csv.xz", compression='xz')
pair_df.head(2)

  interactivity=interactivity, compiler=compiler, result=result)


Unnamed: 0,entrez_gene_id,gene_symbol,doid_id,doid_name,sources,hetionet,n_sentences,has_sentence,partition_rank,split
0,1,A1BG,DOID:2531,hematologic cancer,,0,8,1,0.8586,1
1,1,A1BG,DOID:1319,brain cancer,,0,0,0,0.36785,0


In [44]:
rows = list()
for i in tqdm.tqdm(range(L_dev.shape[0])):
    row = OrderedDict()
    candidate = L_dev.get_candidate(session, i)
    row['candidate_id'] = candidate.id
    row['disease'] = candidate[0].get_span()
    row['gene'] = candidate[1].get_span()
    row['doid_id'] = candidate.Disease_cid
    row['entrez_gene_id'] = candidate.Gene_cid
    row['sentence'] = candidate.get_parent().text
    row['label'] = train_marginals[i]
    rows.append(row)
sentence_df = pd.DataFrame(rows)
sentence_df['entrez_gene_id'] = sentence_df.entrez_gene_id.astype(int)
sentence_df.head(2)

100%|██████████| 10000/10000 [00:49<00:00, 200.47it/s]


Unnamed: 0,candidate_id,disease,gene,doid_id,entrez_gene_id,sentence,label
0,21619,schizophrenic,GMP,DOID:5419,22978,"However, Parkinson patients had a 40-50% reduc...",0.191722
1,22186,erythroid colonies,EPO),DOID:2355,2056,Addition of 4 units of purified erythropoietin...,0.897618


In [45]:
sentence_df = pd.merge(
    sentence_df,
    pair_df[["doid_id", "entrez_gene_id", "doid_name", "gene_symbol"]],
    on=["doid_id", "entrez_gene_id"],
    how="left"
)
sentence_df.head(2)

Unnamed: 0,candidate_id,disease,gene,doid_id,entrez_gene_id,sentence,label,doid_name,gene_symbol
0,21619,schizophrenic,GMP,DOID:5419,22978,"However, Parkinson patients had a 40-50% reduc...",0.191722,schizophrenia,NT5C2
1,22186,erythroid colonies,EPO),DOID:2355,2056,Addition of 4 units of purified erythropoietin...,0.897618,anemia,EPO


In [46]:
sentence_df = pd.concat([
    sentence_df,
    pd.DataFrame(L_dev.todense(), columns=list(LFS))
], axis='columns')

sentence_df.tail()

Unnamed: 0,candidate_id,disease,gene,doid_id,entrez_gene_id,sentence,label,doid_name,gene_symbol,LF_HETNET_DISEASES,...,LF_HETNET_DisGeNET,LF_HETNET_GWAS,LF_HETNET_ABSENT,LF_CHECK_GENE_TAG,LF_IS_BIOMARKER,LF_ASSOCIATION,LF_NO_ASSOCIATION,LF_NO_CONCLUSION,LF_DG_DISTANCE,LF_NO_VERB
9995,36402966,TNBC),urokinase-type plasminogen activator receptor,DOID:1612,5329,UNASSIGNED: 150 Background: Triple-negative br...,0.589854,breast cancer,PLAUR,1,...,0,0,0,0,1,0,0,0,0,0
9996,36403056,myeloid leukemia,Smac),DOID:2531,56616,The antiproliferative activity of isoimperator...,0.232878,hematologic cancer,DIABLO,0,...,0,0,-1,0,1,0,0,0,0,0
9997,36403200,osteoarthritis,NLRC3,DOID:8398,197358,METHODS: Gene expression and protein levels of...,0.236215,osteoarthritis,NLRC3,0,...,0,0,-1,0,1,0,0,0,0,0
9998,36403765,MLL,THP-1,DOID:2531,2736,In order to define a core set of MLL rearrange...,0.948656,hematologic cancer,GLI2,0,...,0,0,-1,-1,0,0,0,0,0,0
9999,36404198,OA,COMP,DOID:8398,1311,RESULTS: From 44 pairs of samples which divide...,0.286983,osteoarthritis,COMP,1,...,0,0,0,0,1,1,0,0,0,0


In [47]:
writer = pd.ExcelWriter('data/sentence-labels-dev.xlsx')
(sentence_df
    .sample(frac=1, random_state=100)
    .to_excel(writer, sheet_name='sentences', index=False)
)
if writer.engine == 'xlsxwriter':
    for sheet in writer.sheets.values():
        sheet.freeze_panes(1, 0)
writer.close()

# Save Training Marginals

Save the training marginals for [Notebook 4](4.data-disc-model).

In [None]:
np.savetxt("vanilla_lstm/lstm_disease_gene_holdout/subsampled/train_marginals_subsampled.txt", train_marginals)

In [None]:
#%time save_marginals(session, L_train, train_marginals)