In [3]:
import pandas as pd
import numpy as np
import tensorflow as tf
import time
import datetime

import load_data

In [27]:
def fc_layer(A_prev, size_in, size_out, name="fully-connected"):
    """Build one affine layer (``A_prev @ w + b``) inside a TensorFlow name scope.

    Args:
        A_prev: Input tensor, used as the left operand of ``tf.matmul``.
        size_in: Number of rows of the weight matrix.
        size_out: Number of columns of the weight matrix and length of the bias.
        name: Name scope under which the layer's ops are created.

    Returns:
        Tuple ``(act, w, b)``: the affine output (no nonlinearity is applied
        here), the weight variable and the bias variable.
    """
    with tf.name_scope(name):
        # Weights start from a truncated normal (stddev 0.1); biases start at 0.1.
        initial_weights = tf.truncated_normal([size_in, size_out], stddev=0.1)
        weights = tf.Variable(initial_weights)
        initial_biases = tf.constant(0.1, shape=[size_out])
        biases = tf.Variable(initial_biases)
        affine_out = tf.matmul(A_prev, weights) + biases

        # Register one histogram summary per tensor produced by this layer.
        for tag, tensor in (("weights", weights),
                            ("biases", biases),
                            ("activations", affine_out)):
            tf.summary.histogram(tag, tensor)
    return affine_out, weights, biases

    
def build_model(x, N_IN, N_HIDDEN):
    """Assemble the autoencoder graph: fc1 -> ReLU -> fc2.

    Args:
        x: Input tensor fed to the first layer.
        N_IN: Input width; also the width of the reconstruction.
        N_HIDDEN: Width of the hidden layer.

    Returns:
        Tuple ``(x_recon, parameters)`` where ``x_recon`` is the output of the
        second layer and ``parameters`` maps ``"a1"`` (first-layer affine
        output), ``"w1"``, ``"b1"``, ``"w2"`` and ``"b2"`` to their tensors.
    """
    # Encoder: affine layer followed by a ReLU.
    enc_out, enc_w, enc_b = fc_layer(x, N_IN, N_HIDDEN, name="fc1")
    hidden = tf.nn.relu(enc_out, name="hidden")

    # Decoder: affine layer back to the input width, no nonlinearity.
    x_recon, dec_w, dec_b = fc_layer(hidden, N_HIDDEN, N_IN, name="fc2")

    parameters = {
        "a1": enc_out,
        "w1": enc_w,
        "b1": enc_b,
        "w2": dec_w,
        "b2": dec_b,
    }
    return x_recon, parameters


def back_prop(x, x_recon, learning_rate):
    """Create the training op that minimizes reconstruction MSE with Adam.

    Args:
        x: Target tensor (the original input).
        x_recon: Reconstruction tensor produced by the model.
        learning_rate: Learning rate passed to ``tf.train.AdamOptimizer``.

    Returns:
        The op returned by ``AdamOptimizer.minimize`` for the mean squared
        difference between ``x_recon`` and ``x``.
    """
    squared_error = tf.square(x_recon - x)
    loss = tf.reduce_mean(squared_error)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    return optimizer.minimize(loss)


def feed_forward(x, parameters):
    """Run the autoencoder forward pass on ``x`` using existing parameters.

    Args:
        x: Input tensor.
        parameters: Mapping with keys ``"w1"``, ``"b1"``, ``"w2"`` and ``"b2"``.

    Returns:
        The reconstruction ``relu(x @ w1 + b1) @ w2 + b2``.
    """
    def _affine(inputs, weight_key, bias_key):
        # Look up both parameters first, then apply the affine map.
        weights, biases = parameters[weight_key], parameters[bias_key]
        return tf.matmul(inputs, weights) + biases

    hidden = tf.nn.relu(_affine(x, "w1", "b1"))
    return _affine(hidden, "w2", "b2")


def mse(x, x_recon, name=""):
    """Build the mean-squared-error tensor and a scalar summary for it.

    Args:
        x: Target tensor.
        x_recon: Reconstruction tensor.
        name: Prefix for the summary tag; the tag is ``name + "_mse"``.

    Returns:
        Tuple ``(mse, mse_summary)``: the MSE tensor and the op returned by
        ``tf.summary.scalar`` for it.
    """
    squared_error = tf.square(x_recon - x)
    mean_squared_error = tf.reduce_mean(squared_error)
    summary_tag = name + "_mse"
    scalar_summary = tf.summary.scalar(summary_tag, mean_squared_error)
    return mean_squared_error, scalar_summary

In [43]:
def train_model(data, batch_size=128, num_epoch=1000, learning_rate=1e-3, extra=""):
    """Train the one-hidden-layer autoencoder and log per-epoch MSE for TensorBoard.

    The default graph is reset, the model is built on the mini-batch input, and
    after every pass over the training iterator the reconstruction MSE is
    evaluated on the ``train_all`` and ``val_all`` inputs and written as scalar
    summaries.

    Args:
        data: Dataset object exposing ``data.train.num_features`` and
            ``data.prep_train_batch(batch_size=...)``. The latter must return
            ``(train_batch, train_iter, val_all, val_iter, train_all,
            train_iter_all)``, where the batch objects are indexable by ``"X"``
            and the iterators expose ``.initializer``.
        batch_size: Forwarded to ``data.prep_train_batch``.
        num_epoch: Number of passes over the training iterator.
        learning_rate: Learning rate for the Adam optimizer.
        extra: Run label; summaries are written to
            ``/tmp/tcga_<today>/ae_<extra>``.

    Returns:
        None. Progress is printed every 10 epochs.
    """
    tf.reset_default_graph()
    LOGDIR = "/tmp/tcga_{0}".format(datetime.datetime.today().date())
    N_IN = data.train.num_features
    N_HIDDEN = 1000

    # train step
    (train_batch, train_iter, val_all, val_iter,
        train_all, train_iter_all) = data.prep_train_batch(batch_size=batch_size)
    x = train_batch["X"]
    x_recon, parameters = build_model(x, N_IN, N_HIDDEN)
    train_step = back_prop(x, x_recon, learning_rate)

    # mse on the full training and validation inputs, reusing the trained weights
    x_train, x_val = train_all["X"], val_all["X"]
    x_train_recon = feed_forward(x_train, parameters)
    x_val_recon = feed_forward(x_val, parameters)
    train_mse, train_summ = mse(x_train, x_train_recon, name="train")
    val_mse, val_summ = mse(x_val, x_val_recon, name="validation")

    # Each run gets its own sub-directory of LOGDIR, so pointing TensorBoard
    # at LOGDIR shows all of the day's runs side by side.
    # NOTE: only the two MSE scalars are written; the histogram summaries
    # registered by fc_layer are not evaluated here.
    writer = tf.summary.FileWriter(LOGDIR + "/ae_{0}".format(extra))
    try:
        # The context manager closes the session even if training is
        # interrupted (e.g. KeyboardInterrupt from the notebook).
        with tf.Session() as sess:
            writer.add_graph(sess.graph)
            sess.run(tf.global_variables_initializer())

            for epoch in range(num_epoch):
                sess.run(train_iter.initializer)
                t0 = time.time()
                try:
                    while True:
                        sess.run(train_step)
                except tf.errors.OutOfRangeError:
                    # Training iterator exhausted: this epoch is complete.
                    pass

                # Rewind the evaluation iterators and compute both errors.
                sess.run([train_iter_all.initializer, val_iter.initializer])
                train_error, train_s, val_error, val_s = sess.run(
                    [train_mse, train_summ, val_mse, val_summ])
                writer.add_summary(train_s, epoch)
                writer.add_summary(val_s, epoch)
                if epoch % 10 == 0:
                    print("epoch", epoch)
                    print("training mse:", train_error, "validation mse", val_error)
                    # Includes the evaluation above, not just the training steps.
                    print("epoch time:", time.time()-t0)
    finally:
        # Flush pending events to disk and release the file handle.
        writer.close()

In [7]:
# Load the dataset from the CSV with the project-local load_data helper.
# NOTE(review): the file name suggests log-normalized, min-max-scaled mRNA
# values — confirm against load_data.
tcga = load_data.read_data_sets("./data/mRNA_lognorm_MinMaxScaled.csv")

In [44]:
# Train with the default hyperparameters; `extra` labels this run's log directory.
train_model(tcga, extra="first")

epoch 0
training mse: 0.18025485 validation mse 0.18044284
epoch time: 6.865894079208374
epoch 10
training mse: 0.0397139 validation mse 0.039752603
epoch time: 1.537994384765625
epoch 20
training mse: 0.016892394 validation mse 0.016865209
epoch time: 1.5532119274139404
epoch 30
training mse: 0.015188688 validation mse 0.015145953
epoch time: 1.5515069961547852
epoch 40
training mse: 0.015113662 validation mse 0.015068741
epoch time: 1.5514235496520996
epoch 50
training mse: 0.015111879 validation mse 0.015066693
epoch time: 1.5401272773742676
epoch 60
training mse: 0.0151119325 validation mse 0.015066264
epoch time: 1.555967092514038
epoch 70
training mse: 0.015111965 validation mse 0.015066518
epoch time: 1.5533232688903809
epoch 80
training mse: 0.015112203 validation mse 0.015067463
epoch time: 1.5356624126434326
epoch 90
training mse: 0.015112378 validation mse 0.015066416
epoch time: 1.5592448711395264
epoch 100
training mse: 0.015112734 validation mse 0.015066557
epoch time: 1.

KeyboardInterrupt: 