In [45]:
import pandas as pd
import numpy as np
import tensorflow as tf
import time
import datetime
from pprint import pprint as pp

# `sys` is not among the imports above, so it is imported here before use;
# without it this cell raises NameError on a fresh kernel.
import sys

# Make the project-local `load_data` module importable.
# NOTE(review): this is a machine-specific absolute path; consider deriving it
# from a configurable project root so the notebook runs on other machines.
sys.path.insert(0, '/home/molly/Desktop/DeepTCGA/')
import load_data

In [41]:
def fc_layer(A_prev, size_in, size_out, name="fully-connected"):
    """Create an affine (fully connected) layer without a nonlinearity.

    Args:
        A_prev: input tensor; it is matrix-multiplied with a
            ``[size_in, size_out]`` weight, so its last dimension is
            expected to be ``size_in``.
        size_in: number of input units.
        size_out: number of output units.
        name: name scope under which the layer's ops are created.

    Returns:
        Tuple ``(act, w, b)``: the pre-activation ``A_prev @ w + b``, the
        weight variable and the bias variable.
    """
    with tf.name_scope(name):
        # Weights start from a truncated normal (stddev 0.1), biases at 0.1.
        initial_weights = tf.truncated_normal([size_in, size_out], stddev=0.1)
        weights = tf.Variable(initial_weights)
        initial_biases = tf.constant(0.1, shape=[size_out])
        biases = tf.Variable(initial_biases)
        pre_activation = tf.matmul(A_prev, weights) + biases

        # Register one histogram summary per tensor of interest.
        tracked = (
            ("weights", weights),
            ("biases", biases),
            ("activations", pre_activation),
        )
        for tag, tensor in tracked:
            tf.summary.histogram(tag, tensor)
    return pre_activation, weights, biases

    
def build_model(x, N_IN, N_HIDDEN):
    """Build a one-hidden-layer autoencoder graph on top of ``x``.

    The encoder is ``fc_layer`` "fc1" (N_IN -> N_HIDDEN) followed by tanh;
    the decoder is ``fc_layer`` "fc2" (N_HIDDEN -> N_IN) with no
    nonlinearity.

    Args:
        x: input tensor fed to the encoder.
        N_IN: number of input (and reconstructed) features.
        N_HIDDEN: number of hidden units.

    Returns:
        Tuple ``(x_recon, parameters)`` where ``parameters`` maps
        ``"w1"``, ``"b1"``, ``"w2"``, ``"b2"`` to the layer variables.
    """
    pre_hidden, enc_w, enc_b = fc_layer(x, N_IN, N_HIDDEN, name="fc1")
    code = tf.nn.tanh(pre_hidden)
    reconstruction, dec_w, dec_b = fc_layer(code, N_HIDDEN, N_IN, name="fc2")
    parameters = {"w1": enc_w, "b1": enc_b, "w2": dec_w, "b2": dec_b}
    return reconstruction, parameters


def back_prop(x, x_recon, learning_rate):
    """Create the training op for the reconstruction objective.

    The loss is the mean of the squared difference between ``x_recon`` and
    ``x``; it is minimized with Adam.

    Args:
        x: target tensor (the original input).
        x_recon: reconstruction tensor produced by the model.
        learning_rate: learning rate passed to ``tf.train.AdamOptimizer``.

    Returns:
        The op returned by the optimizer's ``minimize`` call.
    """
    squared_error = tf.square(x_recon - x)
    loss = tf.reduce_mean(squared_error)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    return optimizer.minimize(loss)


def feed_forward(x, parameters):
    """Run the autoencoder forward pass using existing parameters.

    Args:
        x: input to encode and reconstruct.
        parameters: mapping with keys ``"w1"``, ``"b1"`` (encoder) and
            ``"w2"``, ``"b2"`` (decoder).

    Returns:
        Tuple ``(x_recon, hidden)``: the linear decoder output and the
        tanh hidden representation.
    """
    # Encoder: affine map followed by tanh.
    enc_w, enc_b = (parameters[key] for key in ("w1", "b1"))
    code = tf.nn.tanh(tf.matmul(x, enc_w) + enc_b)

    # Decoder: affine map only.
    dec_w, dec_b = (parameters[key] for key in ("w2", "b2"))
    reconstruction = tf.matmul(code, dec_w) + dec_b
    return reconstruction, code


def mse(x, x_recon, name=""):
    """Build the mean-squared-error tensor and a scalar summary for it.

    Args:
        x: target tensor.
        x_recon: reconstruction tensor.
        name: prefix for the summary tag; the tag is ``name + "mse"`` with
            no separator in between.

    Returns:
        Tuple ``(mse, mse_summary)``: the error tensor and the result of
        ``tf.summary.scalar`` for it.
    """
    residual = x_recon - x
    error = tf.reduce_mean(tf.square(residual))
    tag = name + "mse"
    summary_op = tf.summary.scalar(tag, error)
    return error, summary_op

In [42]:
def train_model(data, batch_size=128, num_epoch=1000, learning_rate=1e-3,
                extra="", num_steps=10000, log_every=1000):
    """Train the autoencoder on ``data`` and log reconstruction error.

    Builds a fresh graph, runs ``num_steps`` optimizer steps, and every
    ``log_every`` steps evaluates the MSE tensors for the full training and
    validation inputs, prints them and writes the two scalar summaries to
    the TensorBoard event file. After training, one more batch is fetched
    and its input, hidden code, reconstruction and the learned parameters
    are printed.

    Args:
        data: object exposing ``train.num_features`` and
            ``prep_batch(batch_size=...)``; the latter must return the
            6-tuple ``(train_batch, train_iter, val_all, val_iter,
            train_all, train_iter_all)`` where the batch objects are
            indexable by ``"X"`` and the iterators have ``.initializer``.
        batch_size: forwarded to ``data.prep_batch``.
        num_epoch: accepted for backward compatibility; not used. The amount
            of training is controlled by ``num_steps``.
        learning_rate: learning rate for the Adam optimizer.
        extra: suffix for the run's log directory name (``ae_<extra>``).
        num_steps: number of optimizer steps to run.
        log_every: evaluate/print/log every this many steps; a falsy value
            (0 or None) disables periodic logging.

    Returns:
        None.
    """
    tf.reset_default_graph()
    # One directory per day, with a separate "ae_<extra>" sub-directory per run.
    logdir = "/tmp/tcga_{0}/ae_{1}".format(
        datetime.datetime.today().date(), extra)
    n_in = data.train.num_features
    n_hidden = int(n_in / 2)

    # train step
    (train_batch, train_iter, val_all, val_iter,
        train_all, train_iter_all) = data.prep_batch(batch_size=batch_size)
    x = train_batch["X"]
    x_recon, parameters = build_model(x, n_in, n_hidden)
    train_step = back_prop(x, x_recon, learning_rate)

    # mse on the full training / validation inputs, sharing the trained weights
    x_train, x_val = train_all["X"], val_all["X"]
    x_train_recon, _ = feed_forward(x_train, parameters)
    x_val_recon, _ = feed_forward(x_val, parameters)
    train_mse, train_summ = mse(x_train, x_train_recon, name="train")
    val_mse, val_summ = mse(x_val, x_val_recon, name="validation")

    # run
    sess = tf.Session()
    try:
        writer = tf.summary.FileWriter(logdir)
        try:
            writer.add_graph(sess.graph)
            sess.run(tf.global_variables_initializer())
            sess.run([train_iter.initializer, val_iter.initializer,
                      train_iter_all.initializer])
            for i in range(num_steps):
                sess.run(train_step)
                if log_every and i % log_every == 0:
                    [train_error, train_s, val_error, val_s] = sess.run(
                        [train_mse, train_summ, val_mse, val_summ])
                    # Persist the scalar summaries so the curves show up in
                    # TensorBoard, not only on stdout.
                    writer.add_summary(train_s, i)
                    writer.add_summary(val_s, i)
                    print("step", i)
                    print("training mse:", train_error,
                          "validation mse", val_error)

            # Inspect the trained model on one more batch.
            x_batch = sess.run(x)
            x_batch_recon, hidden = sess.run(
                feed_forward(x_batch, parameters))
            print(x_batch)
            print(hidden)
            print(x_batch_recon)
            pp(sess.run(parameters))
        finally:
            # Flush pending events and release the file handle.
            writer.close()
    finally:
        # Release session resources even if training raised.
        sess.close()

In [43]:
# Input CSV, given relative to the notebook's working directory.
# NOTE(review): the file name suggests PCA-reduced (0.9 variance),
# standard-scaled mRNA features — confirm against the preprocessing step.
filename = "./data/mRNA_PCA_0.9_variance_StandardScaled.csv"
# `read_data_sets` comes from the project-local `load_data` module; the result
# is passed to `train_model`, which reads `train.num_features` and calls
# `prep_batch` on it.
tcga = load_data.read_data_sets(filename)

In [44]:
# Train the autoencoder on the loaded data; batch_size is forwarded to
# `data.prep_batch` (128 is also the function's default).
train_model(tcga, batch_size=128)

step 0
training mse: 3.6210618 validation mse 3.6267076
step 1000
training mse: 0.5373182 validation mse 0.6909121
step 2000
training mse: 0.46958587 validation mse 0.67098707
step 3000
training mse: 0.41388822 validation mse 0.6599569
step 4000
training mse: 0.40571782 validation mse 0.6788304
step 5000
training mse: 0.40113533 validation mse 0.6861478
step 6000
training mse: 0.39863923 validation mse 0.68911123
step 7000
training mse: 0.39800736 validation mse 0.69227225
step 8000
training mse: 0.396873 validation mse 0.69318444
step 9000
training mse: 0.39615357 validation mse 0.6944781
[[ 0.05460615  2.0871959  -1.4225044  ... -0.3922105  -1.3891143
  -0.9487117 ]
 [-1.8024722   2.3156471  -0.6538204  ... -0.4953093  -1.1372011
  -0.20151384]
 [ 0.625159    1.4412061  -0.36655846 ... -0.7830135  -0.06554774
  -1.852687  ]
 ...
 [ 0.83134836 -0.63173866  0.08419194 ... -1.4867461  -1.1253355
   0.37700656]
 [ 1.8663054   1.0647461  -1.6047138  ...  0.8292623   0.29505697
   0.551434