Trains a simple fully-connected neural network on electronic 
health records using treatment assignment as target, and
extracts activations from the last layer.

This notebook also does some hyperparameter tuning on the
network.

In [1]:
import tensorflow as tf
import numpy as np

from utils.data import load_data
from utils.nn import add_fully_connected

In [2]:
# Load data
datasets = load_data()
dim = datasets.dimension

In [4]:
# Instantiate session and create base variables
sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, shape=[None, dim])
y_ = tf.placeholder(tf.float32, shape=[None, 2])

In [5]:
# Hidden layers dimensions
hidden1 = 100
hidden2 = 100

In [6]:
# Actual network creation
h1 = add_fully_connected(x, dim, hidden1)
h2 = add_fully_connected(h1, hidden1, hidden2)
y_scores = add_fully_connected(h2, hidden2, 1)

In [7]:
# Define softmax and accuracy
cross_entropy = tf.reduce_mean(
    - tf.reduce_sum(y_ * tf.log(y_scores), reduction_indices=[1])
)
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

correct_prediction = tf.equal(tf.argmax(y_scores, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

In [9]:
# Actually train data
sess.run(tf.initialize_all_variables())

for i in range(10000):
    batch = datasets.train.next_batch(50)
    if i % 1000 == 0:
        train_accuracy = accuracy.eval(
            feed_dict={x: batch[0], y_: batch[1]}
        )
        print("Step %d, training accuracy %g" % (i, train_accuracy))
    train_step.run(feed_dict={x: batch[0], y_: batch[1]})

print("Final test accuracy %g"%accuracy.eval(
    feed_dict={x: datasets.val1._patients, y_: datasets.val1._labels}
))

Step 0, training accuracy 0.7
Step 1000, training accuracy 0.74
Step 2000, training accuracy 0.68
Step 3000, training accuracy 0.66
Step 4000, training accuracy 0.92
Step 5000, training accuracy 0.74
Step 6000, training accuracy 0.78
Step 7000, training accuracy 0.8
Step 8000, training accuracy 0.82
Step 9000, training accuracy 0.78
Final test accuracy 0.76691
