# Train the Discriminator for Candidate Classification on the Sentence Level

This notebook is designed to train two machine learning models for candidate classification: a Long Short-Term Memory neural network (LSTM) and a sparse logistic regression (SLR) model.

## MUST RUN AT THE START OF EVERYTHING

Set up the database connection for data extraction and load the Candidate subclass used by the models below.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import csv
import os

import numpy as np
import pandas as pd
import tqdm

In [None]:
#Set up the environment
username = "danich1"
password = "snorkel"
dbname = "pubmeddb"

# The socket path (host=...) may differ on other operating systems
database_str = "postgresql+psycopg2://{}:{}@/{}?host=/var/run/postgresql".format(username, password, dbname)
os.environ['SNORKELDB'] = database_str

from snorkel import SnorkelSession
session = SnorkelSession()
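The connection string above hard-codes the credentials, and it would break if the password contained URL-reserved characters such as `@` or `:`. Below is a minimal sketch of a safer builder. It assumes Python 3's `urllib.parse` (the notebook itself appears to target Python 2), and `make_snorkel_db_url` and the `SNORKEL_DB_PASSWORD` environment variable are hypothetical names.

```python
import os
from urllib.parse import quote_plus


def make_snorkel_db_url(username, password, dbname, socket_dir="/var/run/postgresql"):
    """Build a SQLAlchemy PostgreSQL URL that connects over a Unix socket.

    The empty host between '@' and '/' plus the `host=` query parameter
    tells psycopg2 to use the socket directory instead of TCP.
    """
    # Percent-encode credentials so characters like '@' or ':' don't break the URL
    user = quote_plus(username)
    pwd = quote_plus(password)
    return "postgresql+psycopg2://{}:{}@/{}?host={}".format(user, pwd, dbname, socket_dir)


# Hypothetical usage: read the password from an environment variable
# (SNORKEL_DB_PASSWORD is an assumed name) instead of hard-coding it.
url = make_snorkel_db_url("danich1", os.environ.get("SNORKEL_DB_PASSWORD", "snorkel"), "pubmeddb")
```

The resulting string can be assigned to `os.environ['SNORKELDB']` exactly as in the cell above.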

In [None]:
from snorkel.annotations import FeatureAnnotator, LabelAnnotator, load_marginals
from snorkel.learning import SparseLogisticRegression
from snorkel.learning.disc_models.rnn import reRNN
from snorkel.learning.utils import RandomSearch, RangeParameter
from snorkel.models import Candidate, FeatureKey, candidate_subclass

In [None]:
edge_type = "dg"

In [None]:
if edge_type == "dg":
    DiseaseGene = candidate_subclass('DiseaseGene', ['Disease', 'Gene'])
elif edge_type == "gg":
    GeneGene = candidate_subclass('GeneGene', ['Gene1', 'Gene2'])
elif edge_type == "cg":
    CompoundGene = candidate_subclass('CompoundGene', ['Compound', 'Gene'])
elif edge_type == "cd":
    CompoundDisease = candidate_subclass('CompoundDisease', ['Compound', 'Disease'])
else:
    raise ValueError("Please pick a valid edge type: dg, gg, cg or cd")
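The if/elif chain above maps each two-letter edge type to a candidate class name and its argument names. One alternative is a lookup table that fails loudly on an unknown edge type instead of continuing with an undefined class. This is a minimal sketch. `EDGE_TYPES` and `candidate_spec` are hypothetical names, and the snorkel call is left as a comment so the block runs without a database.

```python
# Hypothetical lookup table: edge type -> (candidate class name, argument names).
# The entries mirror the four branches of the if/elif chain above.
EDGE_TYPES = {
    "dg": ("DiseaseGene", ["Disease", "Gene"]),
    "gg": ("GeneGene", ["Gene1", "Gene2"]),
    "cg": ("CompoundGene", ["Compound", "Gene"]),
    "cd": ("CompoundDisease", ["Compound", "Disease"]),
}


def candidate_spec(edge_type):
    """Return (class_name, arg_names) for an edge type, raising on unknown types."""
    if edge_type not in EDGE_TYPES:
        raise ValueError(
            "Unknown edge type {!r}; expected one of {}".format(edge_type, sorted(EDGE_TYPES))
        )
    return EDGE_TYPES[edge_type]


class_name, arg_names = candidate_spec("dg")
# With snorkel available, the subclass could then be created with:
# CandidateClass = candidate_subclass(class_name, arg_names)
```

Note that later cells refer to `DiseaseGene` by name, so adopting this pattern would also mean switching those cells to a generic variable.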

## Load Preprocessed Data

This code loads the label and feature matrices generated in the [previous notebook](2.data-labeler.ipynb) for the training (split 0), development (split 1), and test (split 2) sets.

In [None]:
%%time
labeler = LabelAnnotator(lfs=[])

L_train = labeler.load_matrix(session, split=0)
L_dev = labeler.load_matrix(session, split=1)
L_test = labeler.load_matrix(session, split=2)

In [None]:
print("Label matrix shapes (candidates x labeling functions):")
print("Train: {}".format(L_train.shape))
print("Dev:   {}".format(L_dev.shape))
print("Test:  {}".format(L_test.shape))
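Each label matrix has one row per candidate and one column per labeling function. A quick way to sanity-check such a matrix is to compute each function's coverage (the fraction of candidates it labels), overlap (labels a candidate that another function also labels), and conflict (labels a candidate that received both positive and negative votes). The sketch below assumes the common snorkel encoding of +1/-1 for labels and 0 for abstain. `summarize_label_matrix` is a hypothetical helper, not part of snorkel's API.

```python
import numpy as np


def summarize_label_matrix(L):
    """Per-labeling-function coverage, overlap, and conflict.

    Assumes one row per candidate, one column per labeling function,
    and entries in {-1, 0, +1} where 0 means the function abstained.
    """
    # snorkel's matrices are scipy.sparse; densify (best done on a slice, e.g. L_train[:10000])
    L = L.toarray() if hasattr(L, "toarray") else np.asarray(L)
    labeled = L != 0
    n_labels_per_row = labeled.sum(axis=1)
    # A row conflicts if it received at least one positive and one negative vote
    row_conflict = (L > 0).any(axis=1) & (L < 0).any(axis=1)
    coverage = labeled.mean(axis=0)
    overlap = (labeled & (n_labels_per_row > 1)[:, None]).mean(axis=0)
    conflict = (labeled & row_conflict[:, None]).mean(axis=0)
    return coverage, overlap, conflict


# Toy matrix: 4 candidates x 3 labeling functions
L_toy = np.array([[1, 0, -1],
                  [1, 1, 0],
                  [0, 0, 0],
                  [0, -1, -1]])
coverage, overlap, conflict = summarize_label_matrix(L_toy)
```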

In [None]:
%%time
featurizer = FeatureAnnotator()

F_train = featurizer.load_matrix(session, split=0)
F_dev = featurizer.load_matrix(session, split=1)
F_test = featurizer.load_matrix(session, split=2)

In [None]:
print("Feature matrix shapes (candidates x features):")
print("Train: {}".format(F_train.shape))
print("Dev:   {}".format(F_dev.shape))
print("Test:  {}".format(F_test.shape))

## Train a Sparse Logistic Regression Discriminative Model

Here we train an SLR model. To find good hyperparameter settings, this code uses a [random search](http://scikit-learn.org/stable/modules/grid_search.html), which evaluates a fixed number of randomly sampled configurations instead of every possible combination of parameters. The best model found is saved in the checkpoints folder so it can be loaded in the [next notebook](5.data-analysis.ipynb). Its predicted marginals on the test set are also written to a CSV file for later analysis.
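The idea behind random search can be shown with a small, library-free sketch over a log-spaced grid from 1e-6 to 1e-2 for each hyperparameter. This is not snorkel's `RandomSearch`. Here `random_search` is a hypothetical function, and `score_fn` stands in for training a model and scoring it on the development set.

```python
import numpy as np


def random_search(score_fn, param_grid, n=5, seed=100):
    """Evaluate `n` randomly sampled configurations and return the best one.

    param_grid maps each hyperparameter name to a list of candidate values;
    score_fn takes a config dict and returns a score where higher is better.
    """
    rng = np.random.RandomState(seed)
    best_config, best_score = None, -np.inf
    for _ in range(n):
        # Sample each hyperparameter independently (configurations may repeat)
        config = {name: values[rng.randint(len(values))] for name, values in param_grid.items()}
        score = score_fn(config)
        if score > best_score:
            best_config, best_score = config, score
    return best_config, best_score


# Log-spaced grid: 1e-6, 1e-5, 1e-4, 1e-3, 1e-2
log_grid = list(10.0 ** np.arange(-6, -1))
param_grid = {"lr": log_grid, "l1_penalty": log_grid, "l2_penalty": log_grid}

# Toy objective standing in for dev-set performance
best, score = random_search(lambda c: -c["lr"], param_grid, n=50)
```

With 5 values per hyperparameter there are 125 combinations. Sampling only `n=5` of them trades thoroughness for speed, which is the motivation for random search over a full grid.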

In [None]:
%time train_marginals = load_marginals(session, split=0)

In [None]:
# Search over the learning rate and the L1/L2 penalties on a log scale:
# each RangeParameter spans 1e-6 to 1e-2 in powers of 10
# (equivalent to the older explicit grid [1e-2, 1e-3, 1e-4, 1e-5, 1e-6] per parameter)
rate_parameters = [
    RangeParameter('lr', 1e-6, 1e-2, step=1, log_base=10),
    RangeParameter('l1_penalty', 1e-6, 1e-2, step=1, log_base=10),
    RangeParameter('l2_penalty', 1e-6, 1e-2, step=1, log_base=10),
]

# Sample n=5 random configurations from this search space
searcher = RandomSearch(SparseLogisticRegression, rate_parameters, F_train,
                        Y_train=train_marginals, n=5)

In [None]:
%%time
np.random.seed(100)
disc_model, run_stats = searcher.fit(F_dev, L_dev, n_threads=4, n_epochs=50, rebalance=0.5, print_freq=25)

In [None]:
LR_marginals = disc_model.marginals(F_test)
LR_marginals

In [None]:
filename = "stratified_data/lstm_disease_gene_holdout/LR_data/LR_test_marginals.csv"
pd.DataFrame(LR_marginals, columns=["LR_Marginals"]).to_csv(filename, index=False)
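For a binary model, each saved marginal is the predicted probability that a candidate is positive. For downstream analysis, the marginals can be converted into hard labels. The sketch below assumes a 0.5 threshold, which could instead be tuned on the development set. `marginals_to_labels` is a hypothetical helper, and the toy DataFrame uses the same single `LR_Marginals` column written above.

```python
import numpy as np
import pandas as pd


def marginals_to_labels(marginals, threshold=0.5):
    """Convert P(y = +1) marginals into hard labels in {-1, +1}."""
    marginals = np.asarray(marginals, dtype=float)
    return np.where(marginals > threshold, 1, -1)


# In practice: df = pd.read_csv(filename)
df = pd.DataFrame({"LR_Marginals": [0.91, 0.12, 0.5, 0.73]})
labels = marginals_to_labels(df["LR_Marginals"])
print(labels)  # [ 1 -1 -1  1]
```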

## Train an LSTM Discriminative Model

This section prepares the data for training an LSTM. An LSTM is a type of recurrent neural network that retains a memory of past inputs over long stretches of a sequence ([further explanation here](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)). The problem with training in this notebook is that SQLAlchemy runs out of memory on my computer during the preprocessing step. As a consequence, we have to resort to training on the University of Pennsylvania's Performance Computing Cluster. The preprocessed data are exported to text files, which are then shipped to the cluster.
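The "memory" of an LSTM is its cell state, which gates update at every token. The NumPy sketch below shows one forward step of a standard LSTM cell as described in the linked post. The gate ordering `[i, f, o, g]`, the weight shapes, and the function names are assumptions for illustration. This is not snorkel's `reRNN` implementation.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_step(x, h_prev, c_prev, W, U, b):
    """One forward step of a standard LSTM cell.

    x: input vector (d_in,); h_prev, c_prev: previous hidden/cell state (d_h,)
    W: (4*d_h, d_in), U: (4*d_h, d_h), b: (4*d_h,), gates stacked as [i, f, o, g].
    """
    z = W.dot(x) + U.dot(h_prev) + b
    i, f, o, g = np.split(z, 4)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)  # input, forget, output gates
    g = np.tanh(g)                                # candidate memory update
    c = f * c_prev + i * g   # keep part of the old memory, add new information
    h = o * np.tanh(c)       # expose a gated view of the memory
    return h, c


# Run a toy sequence of 5 random 3-dimensional "word embeddings" through the cell
rng = np.random.RandomState(0)
d_in, d_h = 3, 4
W = rng.randn(4 * d_h, d_in) * 0.1
U = rng.randn(4 * d_h, d_h) * 0.1
b = np.zeros(4 * d_h)
h, c = np.zeros(d_h), np.zeros(d_h)
for x in rng.randn(5, d_in):
    h, c = lstm_step(x, h, c, W, U, b)
```

The forget gate `f` lets information in `c` survive across many steps, which is what distinguishes an LSTM from a plain recurrent network.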

In [None]:
directory = 'stratified_data/lstm_disease_gene_holdout'

In [None]:
%time train_marginals = load_marginals(session, split=0)
np.savetxt("{}/train_marginals".format(directory), train_marginals)

In [None]:
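%%time
# Hyperparameters intended for training on the cluster; kept here for reference
train_kwargs = {
    'lr':         0.001,
    'dim':        100,
    'n_epochs':   10,
    'dropout':    0.5,
    'print_freq': 1,
    'max_sentence_length': 1000,
}

# The model is only instantiated here so its preprocessing step and word
# dictionary can be used to export the data; training happens on the cluster
lstm = reRNN(seed=100, n_threads=4)
#lstm.train(train_cands, train_marginals[0:10], X_dev=dev_cands, Y_dev=L_dev[0:10], **train_kwargs)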
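The export cells below rely on `lstm._preprocess_data`, a private `reRNN` method, to turn each candidate into a sequence of integer word ids. As we understand it, `extend=True` (used for training) adds unseen words to `lstm.word_dict`, while dev and test words are only looked up. The simplified sketch below illustrates that idea. The marker strings, the `<PAD>`/`<UNK>` ids, and the helpers `mark_entities` and `encode` are hypothetical, and a plain dict stands in for snorkel's word dictionary.

```python
def mark_entities(tokens, spans):
    """Surround entity word ranges with marker tokens.

    spans: list of (start, end, idx) with inclusive word indices.
    """
    starts, ends = {}, {}
    for start, end, idx in spans:
        starts.setdefault(start, []).append("~~[[{}".format(idx))
        ends.setdefault(end, []).append("{}]]~~".format(idx))
    marked = []
    for i, tok in enumerate(tokens):
        marked.extend(starts.get(i, []))
        marked.append(tok)
        marked.extend(ends.get(i, []))
    return marked


def encode(tokens, vocab, extend):
    """Map tokens to ids; unseen tokens get a new id only if extend=True.

    vocab must already contain "<UNK>" for lookups of unseen tokens.
    """
    ids = []
    for tok in tokens:
        if tok not in vocab:
            if extend:
                vocab[tok] = len(vocab)
            else:
                tok = "<UNK>"
        ids.append(vocab[tok])
    return ids


vocab = {"<PAD>": 0, "<UNK>": 1}
tokens = ["BRCA1", "mutations", "cause", "breast", "cancer"]
# Disease is argument 1 and Gene argument 2, as in DiseaseGene(Disease, Gene)
marked = mark_entities(tokens, [(3, 4, 1), (0, 0, 2)])
train_ids = encode(marked, vocab, extend=True)                          # training: vocab grows
dev_ids = encode(["~~[[2", "TP53", "2]]~~"], vocab, extend=False)       # dev: lookup only
```

Freezing the vocabulary after training keeps dev and test ids consistent with the embeddings learned on the training set.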

### Write the Training Data to an External File

In [None]:
%%time
field_names = [
    "disease_id", "disease_name", "disease_char_start", "disease_char_end",
    "gene_id", "gene_name", "gene_char_start", "gene_char_end",
    "sentence", "pubmed"
]
chunksize = 100000
start = 0

with open('{}/train_candidates_ends.csv'.format(directory), 'wb') as g, \
        open('{}/train_candidates_offsets.csv'.format(directory), 'wb') as f, \
        open('{}/train_candidates_sentences.csv'.format(directory), 'wb') as h:
    output = csv.writer(f)
    writer = csv.DictWriter(h, fieldnames=field_names)
    writer.writeheader()

    # Page through the training candidates in chunks to limit memory use
    while True:
        train_cands = (
            session
            .query(DiseaseGene)
            .filter(DiseaseGene.split == 0)
            .order_by(DiseaseGene.id)
            .limit(chunksize)
            .offset(start)
            .all()
        )

        if not train_cands:
            break

        for c in tqdm.tqdm(train_cands):
            # extend=True lets the word dictionary grow with unseen training words
            data, ends = lstm._preprocess_data([c], extend=True)
            # Word-id sequence goes to the offsets file, its end value to the ends file
            output.writerow(data[0])
            g.write("{}\n".format(ends[0]))

            writer.writerow({
                "disease_id": c.Disease_cid, "disease_name": c[0].get_span(),
                "disease_char_start": c[0].char_start, "disease_char_end": c[0].char_end,
                "gene_id": c.Gene_cid, "gene_name": c[1].get_span(),
                "gene_char_start": c[1].char_start, "gene_char_end": c[1].char_end,
                "sentence": c.get_parent().text,
                "pubmed": c.get_parent().get_parent().name,
            })

        start += chunksize
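The export loop above pages through candidates with `.limit(chunksize).offset(start)`. With a large table, every OFFSET query still has to step over all of the rows it skips, so later chunks get slower. An alternative worth considering is keyset pagination, where each query resumes after the last id already seen. The self-contained sketch below uses Python's built-in `sqlite3` with a hypothetical two-column `candidate` table standing in for the snorkel schema.

```python
import sqlite3


def iter_chunks_keyset(conn, chunksize):
    """Yield training rows (split = 0) in id order, `chunksize` at a time.

    Each query resumes after the last id seen instead of using OFFSET.
    """
    last_id = -1
    while True:
        rows = conn.execute(
            "SELECT id, split FROM candidate WHERE split = 0 AND id > ? ORDER BY id LIMIT ?",
            (last_id, chunksize),
        ).fetchall()
        if not rows:
            break
        yield rows
        last_id = rows[-1][0]


# Toy in-memory table (hypothetical schema)
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE candidate (id INTEGER PRIMARY KEY, split INTEGER)")
conn.executemany("INSERT INTO candidate VALUES (?, ?)", [(i, i % 3) for i in range(20)])
chunks = list(iter_chunks_keyset(conn, chunksize=3))
```

In SQLAlchemy, the same pattern would be `.filter(DiseaseGene.split == 0, DiseaseGene.id > last_id).order_by(DiseaseGene.id).limit(chunksize)`. Calling `session.expunge_all()` after each chunk may also help release ORM objects that are no longer needed.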

### Save the Word Dictionary to an External File

In [None]:
%%time
with open("{}/train_word_dict.csv".format(directory), 'w') as f:
    output = csv.DictWriter(f, fieldnames=["Key", "Value"])
    output.writeheader()
    for key in tqdm.tqdm(lstm.word_dict.d):
        output.writerow({'Key':key, 'Value': lstm.word_dict.d[key]})
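Because CSV stores everything as text, the word ids need to be converted back to integers when the dictionary is reloaded, for example on the cluster. The sketch below uses the standard `csv` module and assumes Python 3 and the `Key`/`Value` header written above. `load_word_dict` is a hypothetical helper, and the in-memory file stands in for `train_word_dict.csv`.

```python
import csv
import io


def load_word_dict(f):
    """Read a Key,Value CSV back into a {token: id} dict."""
    reader = csv.DictReader(f)
    return {row["Key"]: int(row["Value"]) for row in reader}


# Demo on an in-memory file; csv quoting preserves tokens that contain commas
demo = io.StringIO('Key,Value\n"a,b",4\ncancer,3\n')
word_dict = load_word_dict(demo)
```

On disk, the file would be opened with `open(path, newline="")`, as the `csv` documentation recommends.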

### Save the Development Candidates to an External File

In [None]:
dev_cands = (
        session
        .query(DiseaseGene)
        .filter(DiseaseGene.split == 1)
        .order_by(DiseaseGene.id)
        .all()
)

dev_cand_labels = pd.read_csv("stratified_data/dev_set.csv")
hetnet_set = set(map(tuple,dev_cand_labels[dev_cand_labels["hetnet"] == 1][["disease_ontology", "gene_id"]].values))
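The dev and test export cells assign each candidate a label of 1 when its (disease, gene) pair is in `hetnet_set` (rows with `hetnet == 1`) and -1 otherwise. The function below expresses the same logic in a small, testable form. The function name and the toy identifiers are hypothetical, and the column names match those read from `dev_set.csv` above.

```python
import pandas as pd


def hetnet_labels(candidate_pairs, labels_df):
    """Label (disease_id, gene_id) pairs +1 if marked hetnet == 1, else -1."""
    positives = labels_df.loc[labels_df["hetnet"] == 1, ["disease_ontology", "gene_id"]]
    # Cast gene ids to int so string ids from candidates match numeric ids from the CSV
    hetnet_set = set((d, int(g)) for d, g in zip(positives["disease_ontology"], positives["gene_id"]))
    return [1 if (d, int(g)) in hetnet_set else -1 for d, g in candidate_pairs]


# Toy data, not real annotations
labels_df = pd.DataFrame({
    "disease_ontology": ["DOID:1", "DOID:1", "DOID:2"],
    "gene_id": [10, 20, 30],
    "hetnet": [1, 0, 1],
})
pairs = [("DOID:1", "10"), ("DOID:1", "20"), ("DOID:2", "30"), ("DOID:2", "1")]
print(hetnet_labels(pairs, labels_df))  # [1, -1, 1, -1]
```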

In [None]:
%%time
field_names = [
    "disease_id", "disease_name", "disease_char_start", "disease_char_end",
    "gene_id", "gene_name", "gene_char_start", "gene_char_end",
    "sentence", "pubmed"
]

with open('{}/dev_candidates_offset.csv'.format(directory), 'wb') as g, \
        open('{}/dev_candidates_labels.csv'.format(directory), 'wb') as f, \
        open('{}/dev_candidates_sentences.csv'.format(directory), 'wb') as h:

    output = csv.writer(g)
    label_output = csv.writer(f)
    writer = csv.DictWriter(h, fieldnames=field_names)
    writer.writeheader()

    for c in tqdm.tqdm(dev_cands):
        # No extend=True: the word dictionary built on the training set is reused
        data, ends = lstm._preprocess_data([c])
        output.writerow(data[0])
        # Label is 1 if the (disease, gene) pair is in the hetnet set, -1 otherwise
        label_output.writerow([1 if (c.Disease_cid, int(c.Gene_cid)) in hetnet_set else -1])

        writer.writerow({
            "disease_id": c.Disease_cid, "disease_name": c[0].get_span(),
            "disease_char_start": c[0].char_start, "disease_char_end": c[0].char_end,
            "gene_id": c.Gene_cid, "gene_name": c[1].get_span(),
            "gene_char_start": c[1].char_start, "gene_char_end": c[1].char_end,
            "sentence": c.get_parent().text,
            "pubmed": c.get_parent().get_parent().name,
        })

### Save the Test Candidates to an External File

In [None]:
test_cands = (
        session
        .query(DiseaseGene)
        .filter(DiseaseGene.split == 2)
        .order_by(DiseaseGene.id)
        .all()
)

test_cand_labels = pd.read_csv("stratified_data/test_set.csv")
hetnet_set = set(map(tuple,test_cand_labels[test_cand_labels["hetnet"] == 1][["disease_ontology", "gene_id"]].values))

In [None]:
%%time
field_names = [
    "disease_id", "disease_name", "disease_char_start", "disease_char_end",
    "gene_id", "gene_name", "gene_char_start", "gene_char_end",
    "sentence", "pubmed"
]

with open('{}/test_candidates_offset.csv'.format(directory), 'wb') as g, \
        open('{}/test_candidates_labels.csv'.format(directory), 'wb') as f, \
        open('{}/test_candidates_sentences.csv'.format(directory), 'wb') as h:

    output = csv.writer(g)
    label_output = csv.writer(f)
    writer = csv.DictWriter(h, fieldnames=field_names)
    writer.writeheader()

    for c in tqdm.tqdm(test_cands):
        # No extend=True: the word dictionary built on the training set is reused
        data, ends = lstm._preprocess_data([c])
        output.writerow(data[0])
        # Label is 1 if the (disease, gene) pair is in the hetnet set, -1 otherwise
        label_output.writerow([1 if (c.Disease_cid, int(c.Gene_cid)) in hetnet_set else -1])

        writer.writerow({
            "disease_id": c.Disease_cid, "disease_name": c[0].get_span(),
            "disease_char_start": c[0].char_start, "disease_char_end": c[0].char_end,
            "gene_id": c.Gene_cid, "gene_name": c[1].get_span(),
            "gene_char_start": c[1].char_start, "gene_char_end": c[1].char_end,
            "sentence": c.get_parent().text,
            "pubmed": c.get_parent().get_parent().name,
        })