In [46]:
import warnings
warnings.filterwarnings("ignore")

from sklearn.preprocessing import LabelEncoder, LabelBinarizer
import tensorflow as tf
import pandas as pd
import numpy as np
import time
import datetime

import load_data

### One hidden layer

In [63]:
def weight_variable(shape):
    initial=tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


def fc_layer(input, size_in, size_out, name="fc"):
    with tf.name_scope(name):
        w = weight_variable([size_in, size_out])
        b = bias_variable([size_out])
        act = tf.matmul(input, w) + b
#         tf.summary.histogram("weights", w)
#         tf.summary.histogram("biases", b)
#         tf.summary.histogram("activations", act)
        return act
    

def train_model(data, label="tissue", learning_rate=1e-2):
    tf.reset_default_graph()
    LOGDIR = "/tmp/tcga_{0}/".format(str(datetime.datetime.today().date()))
    N_IN = data.train.X.shape[1]
    N_OUT = data.train.y[label].shape[1]
    N_HIDDEN = int(np.mean(N_IN + N_OUT))
    
    x = tf.placeholder(tf.float32, [None, N_IN], name="x")
    y_true = tf.placeholder(tf.float32, [None, N_OUT], name="labels")
    hidden = tf.nn.relu(fc_layer(x, N_IN, N_HIDDEN), name="hidden")
    y_pred = fc_layer(hidden, N_HIDDEN, N_OUT, name="softmax")
    
    xent = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            labels=y_true, logits=y_pred), name="xent")
    tf.summary.scalar("xent", xent)

    train_step = tf.train.AdamOptimizer(learning_rate).minimize(xent)

    correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y_true, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    train_accu_summ = tf.summary.scalar("train_accuracy", accuracy)
    test_accu_summ = tf.summary.scalar("test_accuracy", accuracy)

    sess = tf.Session()
    summ = tf.summary.merge_all()
    writer = tf.summary.FileWriter(LOGDIR + label)
    writer.add_graph(sess.graph)
    
    # training
    t0 = time.time()
    sess.run(tf.global_variables_initializer())

    for i in range(10001):
        batch_x, batch_y = data.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_x, y_true: batch_y[label]})
        if i % 5 == 0:
            [train_accuracy, train_s] = sess.run([accuracy, train_accu_summ],
                feed_dict={x: tcga.train.X, y_true: tcga.train.y[label]})
            [test_accuracy, test_s] = sess.run([accuracy, test_accu_summ],
                feed_dict={x: tcga.test.X, y_true: tcga.test.y[label]})
            writer.add_summary(train_s, i)
            writer.add_summary(test_s, i)
            if i % 100 == 0:
                print("step", i, "training accuracy", train_accuracy, 
                     "test_accuracy", test_accuracy)                
    print("training time:", time.time() - t0)

In [64]:
tcga = load_data.read_data_sets("./data/mRNA_PCA_0.6_variance_MinMaxScaled.csv")
for label_name in ["tissue", "gender", "tumor"]:
    print(label_name)
    tcga.train.reset_epoch()
    train_model(tcga, label=label_name)

tissue
step 0 training accuracy 0.0715685 test_accuracy 0.0719557
step 100 training accuracy 0.746802 test_accuracy 0.76476
step 200 training accuracy 0.823441 test_accuracy 0.819188
step 300 training accuracy 0.866083 test_accuracy 0.878229
step 400 training accuracy 0.883485 test_accuracy 0.885609
step 500 training accuracy 0.89063 test_accuracy 0.898524
step 600 training accuracy 0.886597 test_accuracy 0.885609
step 700 training accuracy 0.893166 test_accuracy 0.894834
step 800 training accuracy 0.898467 test_accuracy 0.904059
step 900 training accuracy 0.899043 test_accuracy 0.895756
step 1000 training accuracy 0.899735 test_accuracy 0.903137
step 1100 training accuracy 0.901925 test_accuracy 0.895756
step 1200 training accuracy 0.905613 test_accuracy 0.911439
step 1300 training accuracy 0.905958 test_accuracy 0.909594
step 1400 training accuracy 0.903769 test_accuracy 0.904059
step 1500 training accuracy 0.908724 test_accuracy 0.911439
step 1600 training accuracy 0.899274 test_acc

step 3700 training accuracy 0.708425 test_accuracy 0.694649
step 3800 training accuracy 0.712689 test_accuracy 0.704797
step 3900 training accuracy 0.713495 test_accuracy 0.714945
step 4000 training accuracy 0.719488 test_accuracy 0.72786
step 4100 training accuracy 0.710499 test_accuracy 0.723247
step 4200 training accuracy 0.696669 test_accuracy 0.684502
step 4300 training accuracy 0.704852 test_accuracy 0.688192
step 4400 training accuracy 0.718566 test_accuracy 0.726937
step 4500 training accuracy 0.719949 test_accuracy 0.729705
step 4600 training accuracy 0.719719 test_accuracy 0.72786
step 4700 training accuracy 0.716261 test_accuracy 0.72417
step 4800 training accuracy 0.718682 test_accuracy 0.726937
step 4900 training accuracy 0.719949 test_accuracy 0.722325
step 5000 training accuracy 0.718797 test_accuracy 0.723247
step 5100 training accuracy 0.719373 test_accuracy 0.726937
step 5200 training accuracy 0.719719 test_accuracy 0.714022
step 5300 training accuracy 0.721448 test_a

step 7400 training accuracy 0.975222 test_accuracy 0.97786
step 7500 training accuracy 0.973032 test_accuracy 0.976937
step 7600 training accuracy 0.974991 test_accuracy 0.976937
step 7700 training accuracy 0.975568 test_accuracy 0.975092
step 7800 training accuracy 0.975913 test_accuracy 0.978782
step 7900 training accuracy 0.972686 test_accuracy 0.969557
step 8000 training accuracy 0.975452 test_accuracy 0.976015
step 8100 training accuracy 0.97672 test_accuracy 0.976937
step 8200 training accuracy 0.973032 test_accuracy 0.976015
step 8300 training accuracy 0.974415 test_accuracy 0.97417
step 8400 training accuracy 0.976259 test_accuracy 0.971402
step 8500 training accuracy 0.976029 test_accuracy 0.97786
step 8600 training accuracy 0.960585 test_accuracy 0.9631
step 8700 training accuracy 0.975452 test_accuracy 0.978782
step 8800 training accuracy 0.973724 test_accuracy 0.976937
step 8900 training accuracy 0.97649 test_accuracy 0.976937
step 9000 training accuracy 0.975107 test_accur