# MUST RUN AT THE START OF EVERYTHING

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import csv
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import average_precision_score, precision_recall_curve, roc_curve, auc, f1_score
import tqdm

In [None]:
# Set up the environment: build the database connection string and publish it
# through the SNORKELDB environment variable before a session is created.
# NOTE(review): the credentials are hard-coded in the notebook; consider
# reading them from the environment or from a file kept out of version control.
username = "danich1"
password = "snorkel"
dbname = "pubmeddb"

# The socket directory (host=/var/run/postgresql) may differ between operating systems.
database_str = "postgresql+psycopg2://{}:{}@/{}?host=/var/run/postgresql".format(username, password, dbname)
os.environ['SNORKELDB'] = database_str

# Imported only after SNORKELDB is set -- presumably snorkel reads the variable
# at import time (TODO confirm), so keep this import below the assignment.
from snorkel import SnorkelSession
session = SnorkelSession()

In [None]:
from snorkel.annotations import FeatureAnnotator, LabelAnnotator, load_marginals
from snorkel.learning import SparseLogisticRegression
from snorkel.learning.disc_models.rnn import reRNN
from snorkel.learning.utils import RandomSearch
from snorkel.models import Candidate, FeatureKey, candidate_subclass
from snorkel.utils import get_as_dict
from tree_structs import corenlp_to_xmltree
from treedlib import compile_relation_feature_generator

In [None]:
# Edge (relation) type to analyse; the following cell recognises "dg", "gg", "cg" and "cd".
edge_type = "dg"

In [None]:
# Define the candidate class matching the selected edge type. Only the class
# for the chosen type is created; the names in the other branches stay undefined.
if edge_type == "dg":
    DiseaseGene = candidate_subclass('DiseaseGene', ['Disease', 'Gene'])
elif edge_type == "gg":
    GeneGene = candidate_subclass('GeneGene', ['Gene1', 'Gene2'])
elif edge_type == "cg":
    CompoundGene = candidate_subclass('CompoundGene', ['Compound', 'Gene'])
elif edge_type == "cd":
    CompoundDisease = candidate_subclass('CompoundDisease', ['Compound', 'Disease'])
else:
    # NOTE(review): an unrecognised edge type only prints a message; execution
    # continues without any candidate class having been defined.
    print("Please pick a valid edge type")

# Load the data

In [None]:
# Annotators used by the next cell to load stored matrices; the labeler is
# created with an empty list of labeling functions.
featurizer = FeatureAnnotator()
labeler = LabelAnnotator(lfs=[])

In [None]:
%%time
# Load the label and feature matrices for split 2.
# NOTE(review): split 2 is presumably the test split, going by the variable names -- confirm.
L_test = labeler.load_matrix(session,split=2)
F_test = featurizer.load_matrix(session, split=2)

In [None]:
# Discriminative model to evaluate. The RNN counterpart is currently disabled,
# so `lstm_model` does not exist unless the line below is re-enabled.
lr_model = SparseLogisticRegression()
#lstm_model = reRNN(seed=100, n_threads=4)

In [None]:
# Restore the logistic-regression model saved as "SparseLogisticRegression_1"
# under checkpoints/grid_search/. Loading of the RNN checkpoint is disabled.
lr_model.load(save_dir='checkpoints/grid_search/', model_name="SparseLogisticRegression_1")
#lstm_model.load(save_dir='checkpoints/rnn', model_name="RNN")

In [None]:
# Score the test feature matrix with the logistic-regression model and write
# the marginals to disc_marginals.csv so they can be reloaded from disk.
# The marginals are computed before they are displayed; previously the display
# cell came first and raised a NameError on a top-to-bottom run.
lr_marginals = lr_model.marginals(F_test)
# To include the RNN, enable the two lines below in place of the LR-only frame.
# (One column per model; passing a list of the two arrays would create rows instead.)
#rnn_marginals = lstm_model.marginals(F_test)
#marginal_df = pd.DataFrame({"LR_Marginals": lr_marginals, "RNN_marginals": rnn_marginals})
marginal_df = pd.DataFrame(lr_marginals.T, columns=["LR_Marginals"])
marginal_df.to_csv("disc_marginals.csv", index=False)

# Kept as the final expression so the notebook shows the marginals as cell output.
lr_marginals

In [None]:
# Reload the stored marginals and keep the ten highest- and ten lowest-scoring
# rows; the original row labels are retained as the index.
model_marginals = pd.read_csv("disc_marginals.csv")
top_pos_predict_model_marginals = model_marginals.sort_values(by="LR_Marginals", ascending=False).iloc[:10]
top_neg_predict_model_marginals = model_marginals.sort_values(by="LR_Marginals").iloc[:10]

In [None]:
# Scratch cell: show the feature-matrix row for one candidate.
# NOTE(review): `index` is not assigned in this cell; it relies on the value
# left behind by the feature-counting loops in the following cells, so it
# fails on a fresh top-to-bottom run.
F_test[index,:]
#F_test.getrow(index)

In [None]:
from collections import defaultdict
# Count how often each feature is active among the ten candidates with the
# highest LR marginals.
# NOTE(review): `lr_df` is only loaded in the "LR Model Details" section further
# down, so that cell has to be run before this one.
pos_feature_freq = defaultdict(int)
for index in tqdm.tqdm(top_pos_predict_model_marginals.index):
    # Second element of nonzero(): presumably the column indices of the
    # non-zero entries in this candidate's feature row.
    top_match_feat = F_test[index,:].nonzero()[1]
    # NOTE(review): assumes the row labels of lr_df line up with the column
    # indices of F_test -- confirm.
    for feature in lr_df["Feature"][top_match_feat]:
        pos_feature_freq[feature] += 1
pos_features_df = pd.DataFrame(pos_feature_freq.items(), columns=["Feature", "Frequency"])

In [None]:
from collections import defaultdict
# Count how often each feature is active among the ten candidates with the
# lowest LR marginals.
# NOTE(review): `lr_df` is only loaded in the "LR Model Details" section further
# down, so that cell has to be run before this one.
neg_feature_freq = defaultdict(int)
for index in tqdm.tqdm(top_neg_predict_model_marginals.index):
    # Second element of nonzero(): presumably the column indices of the
    # non-zero entries in this candidate's feature row.
    top_match_feat = F_test[index,:].nonzero()[1]
    # NOTE(review): assumes the row labels of lr_df line up with the column
    # indices of F_test -- confirm.
    for feature in lr_df["Feature"][top_match_feat]:
        neg_feature_freq[feature] += 1
neg_features_df = pd.DataFrame(neg_feature_freq.items(), columns=["Feature", "Frequency"])

In [None]:
# Write the positive-side feature counts to disk, most frequent first.
pos_features_df.sort_values(
    by="Frequency", ascending=False
).to_csv('POS_LR_Feat.csv', index=False)

In [None]:
# Write the negative-side feature counts to disk, most frequent first.
neg_features_df.sort_values(
    by="Frequency", ascending=False
).to_csv("NEG_LR_Feat.csv", index=False)

In [None]:
# Display the ten rows with the highest LR marginals.
top_pos_predict_model_marginals

In [None]:
# Show the id of the candidate that F_test returns for index 192190.
# NOTE(review): the index is hard-coded; presumably it was read off an earlier
# cell's output -- update it when the data changes.
F_test.get_candidate(session, 192190).id

In [None]:
# Inspect a single candidate: print it, then render the XML tree built from its
# parent with the word ranges of both candidate members highlighted.
cand = session.query(Candidate).filter(Candidate.id == 19841894).one()
# Function-call form prints the same text on Python 2 and also parses on Python 3.
print(cand)
xmltree = corenlp_to_xmltree(get_as_dict(cand.get_parent()))
# Each highlight is the list of word indices from start to end inclusive;
# list(...) keeps the value a plain list on both Python 2 and Python 3.
xmltree.render_tree(highlight=[
    list(range(cand[0].get_word_start(), cand[0].get_word_end() + 1)),
    list(range(cand[1].get_word_start(), cand[1].get_word_end() + 1)),
])

# Error Analysis

In [None]:
# Run the LR model's error analysis on the test matrices; the four returned
# values are discarded, so only whatever the call itself reports is shown.
_, _, _, _ = lr_model.error_analysis(session, F_test, L_test)

In [None]:
# Disabled: `lstm_model` is never created because the lines that construct and
# load the RNN are commented out, so this call raised a NameError.
# Re-enable it together with those lines.
#_, _, _, _ = lstm_model.error_analysis(session, F_test, L_test)

# Accuracy ROC

In [None]:
# Plot one ROC curve per available model on top of the chance diagonal.
lw = 2  # line width of the diagonal; it was referenced but never defined
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')

# The loop variable is named `marginal_col` so it does not overwrite the
# `model_marginals` DataFrame defined in an earlier cell.
for marginal_col, color in zip(["LR_Marginals", "RNN_marginals"], ["darkorange", "red"]):
    # Skip models whose marginals were not computed; the RNN column is absent
    # while the RNN lines are commented out.
    if marginal_col not in marginal_df.columns:
        continue
    # NOTE(review): assumes L_test holds a single gold-label column -- confirm.
    fpr, tpr, _ = roc_curve(L_test[0:].todense(), marginal_df[marginal_col])
    model_auc = auc(fpr, tpr)
    # The previous format string mixed "{}" with "{0.2f}", which str.format
    # rejects, and it never received the model name.
    plt.plot(fpr, tpr, color=color, label="{} curve (area = {:0.2f})".format(marginal_col, model_auc))

plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Accuracy ROC')
plt.legend(loc="lower right")

# Precision vs Recall Curve

In [None]:
# Plot one precision-recall curve per available model, labelled with its
# average precision (a summary of the area under that curve).
# The loop variable is named `marginal_col` so it does not overwrite the
# `model_marginals` DataFrame defined in an earlier cell.
for marginal_col, color in zip(["LR_Marginals", "RNN_marginals"], ["darkorange", "red"]):
    # Skip models whose marginals were not computed; the RNN column is absent
    # while the RNN lines are commented out.
    if marginal_col not in marginal_df.columns:
        continue
    # NOTE(review): assumes L_test holds a single gold-label column -- confirm.
    precision, recall, _ = precision_recall_curve(L_test[0:].todense(), marginal_df[marginal_col])
    # f1_score needs hard class predictions and fails on continuous marginals;
    # average precision works on the scores directly and matches the "area" label.
    model_ap = average_precision_score(L_test[0:].todense(), marginal_df[marginal_col])
    # Plot this curve's own recall/precision; the old code re-plotted the
    # fpr/tpr values left over from the ROC cell.
    plt.plot(recall, precision, color=color, label="{} curve (area = {:0.2f})".format(marginal_col, model_ap))

# Axis labels follow what is actually plotted: recall on x, precision on y.
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision vs Recall')
plt.legend(loc="lower right")

# LR Model Details

In [None]:
# Load the logistic-regression feature table from LR_model.csv; other cells
# read its "Feature" and "Weight" columns.
# NOTE(review): the feature-counting cells earlier in the notebook also use
# `lr_df`, so this cell has to be run before them.
lr_df = pd.read_csv("LR_model.csv")

In [None]:
# Order the features from largest to smallest weight (mergesort, so tied
# weights keep their original relative order) and show the first fifteen.
weight_df = lr_df.sort_values(by="Weight", ascending=False, kind='mergesort')
weight_df.iloc[:15]

In [None]:
# Histogram of the LR weights; the title call stays last so the cell's
# displayed value is unchanged.
n, bins, patches = plt.hist(weight_df["Weight"])
plt.ylabel('Count')
plt.xlabel('Weight')
plt.title('Distribution of LR Weights')

In [None]:
# Look up a specific candidate by id and print the matches together with the
# parent of the first one.
# NOTE(review): `.all()` returns a list, so `cand[0]` raises IndexError when no
# candidate has this id.
cand = session.query(Candidate).filter(Candidate.id==674118).all()
# Function-call form prints the same text on Python 2 and also parses on Python 3.
print(cand)
print(cand[0].get_parent())

# RNN Model Details

TBD