# Let's See How the Disc Models Performed

This notebook analyzes the performance of the disc (discriminative) models and aims to answer one question: does a Long Short-Term Memory neural network (LSTM) outperform SparseLogisticRegression (SLR)?

## MUST RUN AT THE START OF EVERYTHING

Load the database and other helpful functions for analysis.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import csv
import os

from IPython.core.display import display, HTML
import matplotlib.pyplot as plt
import numpy as np
import re
import pandas as pd
import seaborn as sns
from sklearn.metrics import average_precision_score, precision_recall_curve, roc_curve, auc, f1_score
import tqdm

In [None]:
# Set up the environment.
# Connection settings can be overridden through environment variables so that
# credentials do not have to be edited into (and committed with) the notebook.
# The fallbacks are the values this notebook has always used.
username = os.environ.get("SNORKEL_DB_USER", "danich1")
password = os.environ.get("SNORKEL_DB_PASSWORD", "snorkel")
dbname = os.environ.get("SNORKEL_DB_NAME", "pubmeddb")

# Path subject to change for different os (host points at /var/run/postgresql).
database_str = "postgresql+psycopg2://{}:{}@/{}?host=/var/run/postgresql".format(username, password, dbname)
os.environ['SNORKELDB'] = database_str

# Imported only after SNORKELDB has been set.
# NOTE(review): presumably snorkel reads SNORKELDB at import time — confirm before moving this import.
from snorkel import SnorkelSession
session = SnorkelSession()

In [None]:
from snorkel.annotations import FeatureAnnotator, LabelAnnotator, load_marginals
from snorkel.learning import SparseLogisticRegression
from snorkel.learning.disc_models.rnn import reRNN
from snorkel.learning.utils import RandomSearch
from snorkel.models import Candidate, FeatureKey, candidate_subclass
from snorkel.utils import get_as_dict
from snorkel.viewer import SentenceNgramViewer
from tree_structs import corenlp_to_xmltree
from treedlib import compile_relation_feature_generator

In [None]:
# Relationship type to analyze; the next cell recognizes "dg", "gg", "cg" and "cd".
edge_type = "dg"

In [None]:
# Register the candidate class that matches the selected edge type.
# Each call passes the class name and the names of its two arguments.
if edge_type == "dg":
    DiseaseGene = candidate_subclass('DiseaseGene', ['Disease', 'Gene'])
elif edge_type == "gg":
    GeneGene = candidate_subclass('GeneGene', ['Gene1', 'Gene2'])
elif edge_type == "cg":
    CompoundGene = candidate_subclass('CompoundGene', ['Compound', 'Gene'])
elif edge_type == "cd":
    CompoundDisease = candidate_subclass('CompoundDisease', ['Compound', 'Disease'])
else:
    # Fail loudly: a printed message alone lets "Run All" carry on without any
    # candidate class having been registered.
    raise ValueError(
        "Please pick a valid edge type (got {!r}; expected 'dg', 'gg', 'cg' or 'cd')".format(edge_type)
    )

# Load the data

Here is where we load the test dataset in conjunction with the previously trained disc models. Each algorithm will output a probability of a candidate being a true candidate.

In [None]:
# Annotator objects used to load matrices through the session.
# featurizer is only referenced by the commented-out F_test load in the next cell.
featurizer = FeatureAnnotator()
# No labeling functions are passed (lfs=[]); below, the labeler is only used to call load_matrix.
labeler = LabelAnnotator(lfs=[])

In [None]:
%%time
# Load the label matrix for split 2 (described above as the test dataset).
L_test = labeler.load_matrix(session,split=2)
# Feature-matrix load is disabled; F_test is still referenced by the LR feature-weight cell further down.
#F_test = featurizer.load_matrix(session, split=2)

In [None]:
# Dimensions of the loaded label matrix.
L_test.shape

In [None]:
# Table of per-candidate model outputs; later cells read its "True Labels",
# "*_Marginals" and "*_Predictions" columns.
model_marginals = pd.read_csv("Experiment_2/experiment_2.csv")

# Grab the features of the Logistic Regression Model
# (disabled; lr_df is still referenced by the LR feature-weight cell further down)
#lr_df = pd.read_csv("Experiment 1/LR_model.csv")

# Accuracy ROC

From the probabilities loaded above, we can plot a [Receiver Operating Characteristic](http://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html) (ROC) curve, which shows the true positive rate against the false positive rate at each probability threshold.

In [None]:
# One (legend label, marginal column, line color) entry per plotted model.
# Previously compared as well (currently disabled):
#   ("LogReg", "LR_Marginals", "darkorange"), ("RNN_100%", "RNN_Full_Marginals", "magenta"),
#   with RNN_1% drawn in red and RNN_10% in green.
curve_specs = [
    ("RNN_1%", "RNN_1_Marginals", "green"),
    ("RNN_10%", "RNN_10_Marginals", "magenta"),
]
# Keep the three parallel lists available under their usual names.
model_labels, models, model_colors = [list(field) for field in zip(*curve_specs)]

true_labels = model_marginals["True Labels"]

fig, ax = plt.subplots()
# Diagonal reference line from (0, 0) to (1, 1).
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')

for curve_label, marginal_column, curve_color in curve_specs:
    fpr, tpr, _ = roc_curve(true_labels, model_marginals[marginal_column])
    model_auc = auc(fpr, tpr)
    legend_entry = "{} (area = {:0.2f})".format(curve_label, model_auc)
    ax.plot(fpr, tpr, color=curve_color, label=legend_entry)

ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Accuracy ROC')
ax.legend(loc="lower right")

# Precision vs Recall Curve

This code produces a [Precision-Recall](http://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html) curve, which shows the trade-off between [precision and recall](https://en.wikipedia.org/wiki/Precision_and_recall) at each probability threshold.

In [None]:
# One (legend label, marginal column, line color) entry per plotted model.
# Previously compared as well (currently disabled):
#   ("LogReg", "LR_Marginals", "darkorange"), ("RNN_100%", "RNN_Full_Marginals", "magenta"),
#   with RNN_1% drawn in red and RNN_10% in green.
curve_specs = [
    ("RNN_1%", "RNN_1_Marginals", "green"),
    ("RNN_10%", "RNN_10_Marginals", "magenta"),
]
# Keep the three parallel lists available under their usual names.
model_labels, models, model_colors = [list(field) for field in zip(*curve_specs)]

true_labels = model_marginals["True Labels"]

fig, ax = plt.subplots()
for curve_label, marginal_column, curve_color in curve_specs:
    scores = model_marginals[marginal_column]
    precision, recall, _ = precision_recall_curve(true_labels, scores)
    # Average precision is the number reported as "area" in the legend.
    model_precision = average_precision_score(true_labels, scores)
    legend_entry = "{} curve (area = {:0.2f})".format(curve_label, model_precision)
    ax.plot(recall, precision, color=curve_color, label=legend_entry)

ax.set_ylabel('Precision')
ax.set_xlabel('Recall')
ax.set_title('Precision vs Recall')
ax.set_xlim([0, 1.01])
ax.set_ylim([0, 1.05])
ax.legend(loc="lower right")

# Error Analysis

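This section selects the candidates that fall into one result category (true positives, false positives, true negatives, or false negatives) so that the highest-scoring candidates of each model can be inspected. Set `result_category` to `tp`, `fp`, `tn`, or `fn`.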

In [None]:
result_category = "tp"

# (predicted label, true label) for each category; labels are compared against 1 and -1.
category_labels = {
    "tp": (1, 1),
    "fp": (1, -1),
    "tn": (-1, -1),
    "fn": (-1, 1),
}

if result_category in category_labels:
    predicted_label, true_label = category_labels[result_category]
    truth_mask = model_marginals["True Labels"] == true_label
    # Boolean row masks, one per model. The LR and full-RNN masks
    # (lr_cond / rnn100_cond on "LR_Predictions" / "RNN_Full_Predictions") are disabled.
    rnn1_cond = (model_marginals["RNN_1_Predictions"] == predicted_label) & truth_mask
    rnn10_cond = (model_marginals["RNN_10_Predictions"] == predicted_label) & truth_mask
else:
    print ("Please re-run cell with correct options")

In [None]:
# Columns shown in the per-model tables below.
# Variant that also includes the LR and full-RNN marginals (disabled):
#display_columns = ["LR_Marginals", "RNN_1_Marginals", "RNN_10_Marginals", "RNN_Full_Marginals", "True Labels"]
display_columns = ["RNN_1_Marginals", "RNN_10_Marginals","True Labels"]

## LR

In [None]:
# Top 10 rows of the selected category, ordered by LR marginal (highest first).
# NOTE(review): lr_cond is only assigned in commented-out lines of the Error Analysis cell,
# so as far as this notebook shows the name is undefined here — confirm whether the LR section should stay.
model_marginals[lr_cond].sort_values("LR_Marginals", ascending=False).head(10)[display_columns]

In [None]:
# Row labels of the 10 highest LR marginals within the selected category.
# NOTE(review): depends on lr_cond, which is only assigned in commented-out code above — confirm.
cand_index = list(model_marginals[lr_cond].sort_values("LR_Marginals", ascending=False).head(10).index)
# Look each label up in the test label matrix to get the candidate object.
lr_cands = [L_test.get_candidate(session, i) for i in cand_index]

In [None]:
# Print every selected candidate together with its sentence. The first and
# second mention are wrapped in --[[...]]D-- and --[[...]]G-- respectively.
# Single-argument print(...) calls behave the same under Python 2 and 3;
# print("") replaces the bare `print` that emitted an empty line.
print("Category: {}".format(result_category))
print("")
for cand, cand_ind in zip(lr_cands, cand_index):
    text = cand[0].get_parent().text
    # Replace literally: a span is data, not a pattern, so characters such as
    # "(", "+" or "." must not be interpreted as regular-expression syntax.
    first_span = cand[0].get_span()
    text = text.replace(first_span, "--[[{}]]D--".format(first_span))
    second_span = cand[1].get_span()
    text = text.replace(second_span, "--[[{}]]G--".format(second_span))
    print(cand_ind)
    print("Candidate:  {}".format(cand))
    print("")
    print("Text: \"{}\"".format(text))
    print("")
    print("--------------------------------------------------------------------------------------------")
    print("")

In [None]:
# Row label (in model_marginals) of the candidate inspected in the following cells.
F_cand_index = 137865
# Single-argument print(...) works under both Python 2 and 3.
print("Confidence Level:  {}".format(model_marginals["LR_Marginals"][F_cand_index]))

In [None]:
# Rows of lr_df selected by the non-zero column positions of row F_cand_index in F_test,
# ordered by "Weight" (largest first).
# NOTE(review): lr_df and F_test are only assigned in commented-out lines earlier in the notebook,
# so as far as this notebook shows both names are undefined here — confirm.
F_cand_index = 137865
lr_df.iloc[F_test[F_cand_index, :].nonzero()[1]].sort_values("Weight", ascending=False)

In [None]:
# Re-query the candidate stored at row 137865 of L_test by its database id.
cand = session.query(Candidate).filter(Candidate.id == L_test.get_candidate(session, 137865).id).one()
# print(...) with one argument works under both Python 2 and 3.
print(cand)
# Convert the candidate's parent context to an XML tree and render it with both
# mentions highlighted; the word-end index is treated as inclusive, hence the + 1.
xmltree = corenlp_to_xmltree(get_as_dict(cand.get_parent()))
xmltree.render_tree(highlight=[range(cand[0].get_word_start(), cand[0].get_word_end() + 1), range(cand[1].get_word_start(), cand[1].get_word_end()+1)])

## LSTM 1% Sub-Sampling

In [None]:
# Top 10 rows of the selected category, ordered by the RNN 1% marginal (highest first).
model_marginals[rnn1_cond].sort_values("RNN_1_Marginals", ascending=False).head(10)[display_columns]

In [None]:
# Row labels of the 10 highest RNN 1% marginals within the selected category.
cand_index = list(model_marginals[rnn1_cond].sort_values("RNN_1_Marginals", ascending=False).head(10).index)
# NOTE: the name lr_cands is reused from the LR section; here it holds the RNN 1% candidates.
lr_cands = [L_test.get_candidate(session, i) for i in cand_index]

In [None]:
# Print every selected candidate together with its sentence. The first and
# second mention are wrapped in --[[...]]D-- and --[[...]]G-- respectively.
# Single-argument print(...) calls behave the same under Python 2 and 3;
# print("") replaces the bare `print` that emitted an empty line.
print("Category: {}".format(result_category))
print("")
for cand in lr_cands:
    text = cand[0].get_parent().text
    # Replace literally: a span is data, not a pattern, so characters such as
    # "(", "+" or "." must not be interpreted as regular-expression syntax.
    first_span = cand[0].get_span()
    text = text.replace(first_span, "--[[{}]]D--".format(first_span))
    second_span = cand[1].get_span()
    text = text.replace(second_span, "--[[{}]]G--".format(second_span))
    print("Candidate:  {}".format(cand))
    print("")
    print("Text: \"{}\"".format(text))
    print("")
    print("--------------------------------------------------------------------------------------------")
    print("")

## LSTM 10% Sub-Sampling

In [None]:
# Top 10 rows of the selected category, ordered by the RNN 10% marginal (highest first).
model_marginals[rnn10_cond].sort_values("RNN_10_Marginals", ascending=False).head(10)[display_columns]

In [None]:
# Row labels of the 10 highest RNN 10% marginals within the selected category.
cand_index = list(model_marginals[rnn10_cond].sort_values("RNN_10_Marginals", ascending=False).head(10).index)
# NOTE: the name lr_cands is reused from the LR section; here it holds the RNN 10% candidates,
# which the HTML viewer cells below also read.
lr_cands = [L_test.get_candidate(session, i) for i in cand_index]

In [None]:
# Print every selected candidate together with its sentence. The first and
# second mention are wrapped in --[[...]]D-- and --[[...]]G-- respectively.
# Single-argument print(...) calls behave the same under Python 2 and 3;
# print("") replaces the bare `print` that emitted an empty line.
print("Category: {}".format(result_category))
print("")
for cand in lr_cands:
    text = cand[0].get_parent().text
    # Replace literally: a span is data, not a pattern, so characters such as
    # "(", "+" or "." must not be interpreted as regular-expression syntax.
    first_span = cand[0].get_span()
    text = text.replace(first_span, "--[[{}]]D--".format(first_span))
    second_span = cand[1].get_span()
    text = text.replace(second_span, "--[[{}]]G--".format(second_span))
    print("Candidate:  {}".format(cand))
    print("")
    print("Text: \"{}\"".format(text))
    print("")
    print("--------------------------------------------------------------------------------------------")
    print("")

In [None]:
def insert(x, g_start, g_end, d_start, d_end, proba, d_cid, g_cid):
    """Return the HTML fragment for one character of a candidate's sentence.

    The character that starts a mention opens a ``<span>`` whose background is
    green (alpha = ``proba``) when ``proba > 0.5`` and red (alpha =
    ``1 - proba``) otherwise; the character that ends a mention closes it.
    All other characters are returned on their own.

    Args:
        x: ``(index, character)`` pair, as produced by ``enumerate(text)``.
        g_start, g_end: index of the first / last character of the gene mention.
        d_start, d_end: index of the first / last character of the disease mention.
        proba: probability used to pick the highlight color and opacity.
        d_cid, g_cid: identifiers placed in the span's ``title`` attribute for
            the disease / gene mention.

    Returns:
        The HTML string for this character.
    """
    index = x[0]
    # Escape markup characters so sentence text such as "p < 0.05" cannot break the page.
    char = x[1].replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")

    if index == d_start or index == g_start:
        # The disease mention takes precedence when both mentions start here.
        starts_disease = index == d_start
        cid = d_cid if starts_disease else g_cid
        if proba > 0.5:
            rgb, alpha = "0,255,0", proba
        else:
            rgb, alpha = "255,0,0", 1 - proba
        fragment = "<span title=\"{}\" style=\"background-color: rgba({},{})\">{}".format(cid, rgb, alpha, char)
        # A one-character mention starts and ends on the same index; close the
        # span right away, otherwise it would stay open for the rest of the sentence.
        if index == (d_end if starts_disease else g_end):
            fragment += "</span>"
        return fragment
    if index == d_end or index == g_end:
        return "{}</span>".format(char)
    return char

In [None]:
# Build one <div> per candidate: every character of the sentence is passed
# through insert() so that both mentions end up wrapped in colored spans.
html_fragments = []
for proba_index, cand in zip(cand_index, lr_cands):
    disease_mention, gene_mention = cand[0], cand[1]
    gene_start, gene_end = gene_mention.char_start, gene_mention.char_end
    disease_start, disease_end = disease_mention.char_start, disease_mention.char_end
    proba = model_marginals["RNN_10_Marginals"].iloc[proba_index]

    letters = [
        insert(x, gene_start, gene_end, disease_start, disease_end, proba, cand.Disease_cid, cand.Gene_cid)
        for x in enumerate(disease_mention.get_parent().text)
    ]
    html_fragments.append("<div title=\"{}\">{}</div><br />".format(proba, ''.join(letters)))

html_string = "".join(html_fragments)

In [None]:
# Read the viewer template, substitute the generated markup into its
# positional placeholder, and render the result inline.
with open("html/candidate_viewer.html", 'r') as template_file:
    template = template_file.read()
display(HTML(template.format(html_string)))

# FULL LSTM

In [None]:
# Top 10 rows of the selected category, ordered by the full-RNN marginal (highest first).
# NOTE(review): rnn100_cond is only assigned in commented-out lines of the Error Analysis cell,
# so as far as this notebook shows the name is undefined here — confirm whether this section should stay.
model_marginals[rnn100_cond].sort_values("RNN_Full_Marginals", ascending=False).head(10)[display_columns]

In [None]:
# Row labels of the 10 highest full-RNN marginals within the selected category.
# NOTE(review): depends on rnn100_cond, which is only assigned in commented-out code above — confirm.
cand_index = list(model_marginals[rnn100_cond].sort_values("RNN_Full_Marginals", ascending=False).head(10).index)
# NOTE: the name lr_cands is reused from the LR section; here it holds the full-RNN candidates.
lr_cands = [L_test.get_candidate(session, i) for i in cand_index]

In [None]:
# Print every selected candidate together with its sentence. The first and
# second mention are wrapped in --[[...]]D-- and --[[...]]G-- respectively.
# Single-argument print(...) calls behave the same under Python 2 and 3;
# print("") replaces the bare `print` that emitted an empty line.
print("Category: {}".format(result_category))
print("")
for cand in lr_cands:
    text = cand[0].get_parent().text
    # Replace literally: a span is data, not a pattern, so characters such as
    # "(", "+" or "." must not be interpreted as regular-expression syntax.
    first_span = cand[0].get_span()
    text = text.replace(first_span, "--[[{}]]D--".format(first_span))
    second_span = cand[1].get_span()
    text = text.replace(second_span, "--[[{}]]G--".format(second_span))
    print("Candidate:  {}".format(cand))
    print("")
    print("Text: \"{}\"".format(text))
    print("")
    print("--------------------------------------------------------------------------------------------")
    print("")

# Write Results to TSV

In [None]:
# Export one row per row of model_marginals: both mentions' ids and character
# offsets, the sentence text, and the RNN 10% marginal.
field_names = ["Disease ID", "Disease Char Start", "Disease Char End", "Gene ID", "Gene Char Start", "Gene Char End", "Sentence", "Prediction"]
# Read the prediction column once instead of materialising a full row per candidate.
predictions = model_marginals["RNN_10_Marginals"]
with open("Experiment_2/LSTM_10_results.tsv", "w") as f:
    # The target is a .tsv file, so the fields must be tab-delimited;
    # DictWriter separates with commas unless told otherwise.
    writer = csv.DictWriter(f, fieldnames=field_names, delimiter="\t")
    writer.writeheader()
    for i in tqdm.tqdm(model_marginals.index):
        cand = L_test.get_candidate(session, i)
        row = {
            "Disease ID": cand.Disease_cid,
            "Disease Char Start": cand[0].char_start,
            "Disease Char End": cand[0].char_end,
            "Gene ID": cand.Gene_cid,
            "Gene Char Start": cand[1].char_start,
            "Gene Char End": cand[1].char_end,
            "Sentence": cand.get_parent().text,
            "Prediction": predictions.iloc[i],
        }
        writer.writerow(row)