In [2]:
import csv
import tensorflow as tf
from tflearn.layers.conv import global_avg_pool
from tensorflow.contrib.layers import batch_norm, flatten
from tensorflow.contrib.layers import xavier_initializer
from tensorflow.contrib.framework import arg_scope
import numpy as np
import os
from dataset_pipeline import Dataset

In [3]:
# Name of the label column; it is used as the key into the label batch and
# appears in the checkpoint directory names.
feature = "Age"

# TFRecord inputs and the root directory for checkpoints.
config = dict(
    tf_train="....tfrecord",
    tf_val="....tfrecord",
    model_dir="model",
)

# Optimizer settings (Adam learning rate and epsilon) and the batch size.
batch_size = 25
init_learning_rate = 1e-4
epsilon = 1e-4

# Spatial size of the image placeholder.
height = width = 256

# Schedule: number of epochs and training steps per epoch.
# NOTE(review): 50000 is a hard-coded sample count; confirm it matches the
# size of the training TFRecord.
total_epochs = 200
iteration = 50000 // batch_size

In [4]:
def conv_layer(input, filter, kernel, stride=1, layer_name="conv"):
    """Apply a bias-free, SAME-padded 2-D convolution under the name scope `layer_name`.

    Args:
        input: tensor passed to `tf.layers.conv2d` as `inputs`.
        filter: number of output filters.
        kernel: kernel size forwarded as `kernel_size`.
        stride: convolution stride (default 1).
        layer_name: name of the enclosing `tf.name_scope`.

    Returns:
        The output tensor of `tf.layers.conv2d`.
    """
    with tf.name_scope(layer_name):
        conv_kwargs = dict(
            inputs=input,
            use_bias=False,
            filters=filter,
            kernel_size=kernel,
            strides=stride,
            padding='SAME',
            kernel_initializer=xavier_initializer(),
        )
        return tf.layers.conv2d(**conv_kwargs)

def Global_Average_Pooling(x, stride=1):
    """Run tflearn's `global_avg_pool` on `x` under the name 'Global_avg_pooling'.

    `stride` is accepted for signature compatibility but is not used.
    """
    pooled = global_avg_pool(x, name='Global_avg_pooling')
    return pooled

def Batch_Normalization(x, training, scope):
    """Batch-normalize `x`, selecting the branch at run time with `training`.

    Both `tf.cond` branches call `batch_norm` on the same `scope`: the first
    with `reuse=None`, the second with `reuse=True`. The shared defaults are
    applied through `arg_scope`.

    Args:
        x: input tensor.
        training: predicate handed to `tf.cond` and to `is_training`.
        scope: variable scope name for the batch-norm parameters.

    Returns:
        The tensor produced by `tf.cond`.
    """
    bn_defaults = dict(
        scope=scope,
        updates_collections=None,
        decay=0.9,
        center=True,
        scale=True,
        zero_debias_moving_mean=True,
    )

    def _training_branch():
        return batch_norm(inputs=x, is_training=training, reuse=None)

    def _inference_branch():
        return batch_norm(inputs=x, is_training=training, reuse=True)

    with arg_scope([batch_norm], **bn_defaults):
        return tf.cond(training, _training_branch, _inference_branch)
    
def Average_pooling(x, pool_size=[2,2], stride=2, padding='VALID'):
    """Average-pool `x` with `tf.layers.average_pooling2d`.

    The defaults give a 2x2 window, stride 2 and 'VALID' padding.
    """
    pool_kwargs = dict(inputs=x, pool_size=pool_size, strides=stride, padding=padding)
    return tf.layers.average_pooling2d(**pool_kwargs)


def Max_Pooling(x, pool_size=[3,3], stride=2, padding='VALID'):
    """Max-pool `x` with `tf.layers.max_pooling2d`.

    The defaults give a 3x3 window, stride 2 and 'VALID' padding.
    """
    pool_kwargs = dict(inputs=x, pool_size=pool_size, strides=stride, padding=padding)
    return tf.layers.max_pooling2d(**pool_kwargs)

def Concatenation(layers):
    """Concatenate the tensors in `layers` along axis 3."""
    merged = tf.concat(layers, axis=3)
    return merged

def Drop_out(x, rate, training):
    """Apply `tf.layers.dropout` to `x` with the given `rate` and `training` flag."""
    dropout_kwargs = dict(inputs=x, rate=rate, training=training)
    return tf.layers.dropout(**dropout_kwargs)

def Relu(x):
    """Return `tf.nn.relu` applied to `x`."""
    activated = tf.nn.relu(x)
    return activated

def Linear(x, units, name):
    """Apply a `tf.layers.dense` layer with `units` outputs, named `name`, to `x`."""
    dense_kwargs = dict(inputs=x, units=units, name=name)
    return tf.layers.dense(**dense_kwargs)

def Evaluate(sess):
    """Return the mean validation loss over one full pass of `testdata`.

    The module-level `testdata` iterator is re-initialized and consumed until
    it raises `tf.errors.OutOfRangeError`. Labels are divided by 100 before
    being fed, matching the training loop. Each batch loss is weighted by the
    number of labels in the batch, so a short final batch does not skew the
    result.

    Only `x`, `label` and `training_flag` are fed. `cost` does not depend on
    the learning-rate placeholder, so this function can be called before the
    training loop has defined `epoch_learning_rate`.

    Args:
        sess: the `tf.Session` holding the model variables.

    Returns:
        The sample-weighted mean of `cost` as a Python float, or NaN when the
        validation set yields no batches.
    """
    testdata.initialize()
    # NOTE(review): get_next() is called on every evaluation. Depending on how
    # dataset_pipeline.Dataset implements it, this may add new ops to the graph
    # each time; confirm, and hoist to module level if so.
    test_batch_x_, test_batch_y_ = testdata.get_next()

    loss_sum = 0.0
    sample_count = 0

    while True:
        try:
            test_batch_x, test_batch_y = sess.run(
                [test_batch_x_["fundus"], test_batch_y_[feature]])
        except tf.errors.OutOfRangeError:
            # Validation data exhausted: one full pass is done.
            break

        test_batch_y = test_batch_y / 100.

        test_feed_dict = {
            x: test_batch_x,
            label: test_batch_y,
            training_flag: False
        }
        loss_ = sess.run(cost, feed_dict=test_feed_dict)

        batch_count = len(test_batch_y)
        loss_sum += float(loss_) * batch_count
        sample_count += batch_count

    if sample_count == 0:
        # Explicit NaN instead of relying on numpy's empty-mean warning path.
        return float("nan")
    return loss_sum / sample_count



In [5]:
class VGG():
    """VGG-style convolutional regressor that outputs one scalar per input image.

    Constructing the object builds the graph immediately; the resulting
    tensor is stored in `self.model`.
    """

    def __init__(self, x,training):
        self.training = training
        self.model = self.VGG(x)

    def VGG(self, x):
        """Build the network on `x` and return the flattened regression output.

        Thirteen 3x3 convolutions are arranged in five stages. Every
        convolution except the last is followed by batch normalization and
        ReLU, and every stage except the last ends with 2x2 max pooling.
        Global average pooling and a single-unit dense layer named "regr"
        produce the prediction, which is reshaped to `[-1]`.
        """
        # (filters, number of convolutions) for each stage, in order.
        stages = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))
        final_conv = sum(count for _, count in stages)
        final_stage = len(stages) - 1

        net = x
        layer_no = 0
        for stage_no, (filters, count) in enumerate(stages):
            for _ in range(count):
                layer_no += 1
                net = conv_layer(net, filter=filters, kernel=[3,3],
                                 layer_name='conv%d' % layer_no, stride=1)
                # The last convolution feeds the pooling head directly.
                if layer_no < final_conv:
                    net = Batch_Normalization(net, training=self.training,
                                              scope='batch%d' % layer_no)
                    net = Relu(net)
            if stage_no < final_stage:
                net = Max_Pooling(net, pool_size=[2,2], stride=2)

        net = Global_Average_Pooling(net)
        net = flatten(net)
        regression = Linear(net, 1, "regr")

        return tf.reshape(regression, [-1])

In [6]:
# Graph inputs: a batch of height x width images with `img_channels` channels,
# and one float target per image.

img_channels = 3
x = tf.placeholder(tf.float32, shape=[None, height, width, img_channels])
label = tf.placeholder(tf.float32, shape=[None]) 

# Boolean fed at run time; the VGG model passes it to Batch_Normalization.
training_flag = tf.placeholder(tf.bool)


# The learning rate is a placeholder so the training loop can feed a new value each step.
learning_rate = tf.placeholder(tf.float32, name='learning_rate')

# Regression output (one value per image) and its mean-squared-error loss against `label`.
regr = VGG(x=x,training=training_flag).model
cost = tf.losses.mean_squared_error(label, regr)

# Ops registered under UPDATE_OPS; the training loop runs them alongside `train`.
# NOTE(review): Batch_Normalization passes updates_collections=None, so the
# batch-norm moving-average updates are presumably applied in place and this
# list may be empty - confirm.
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, epsilon=epsilon)
train = optimizer.minimize(cost)

# Created after `minimize` so tf.global_variables() already includes any
# variables the optimizer added.
saver = tf.train.Saver(tf.global_variables())

Instructions for updating:
Use keras.layers.conv2d instead.
Instructions for updating:
Use keras.layers.max_pooling2d instead.
Instructions for updating:
Use keras.layers.flatten instead.
Instructions for updating:
Use keras.layers.dense instead.
Instructions for updating:
Use tf.cast instead.
Instructions for updating:
Use tf.cast instead.


In [7]:
# Build the training and validation input pipelines from the TFRecord paths in `config`.
# Only `width` is passed as the image size.
# NOTE(review): the positional boolean flags are defined by
# dataset_pipeline.Dataset, which is not shown here. Training enables all six
# and validation only the first; presumably the others toggle
# shuffling/augmentation - confirm against the Dataset signature.
traindata = Dataset(config["tf_train"],batch_size,width,True,True,True,True,True,True,feature)
testdata = Dataset(config["tf_val"],batch_size,width,True,False,False,False,False,False,feature)

In [8]:
# Training driver: restore or initialize the model, then train for
# `total_epochs` epochs, evaluating and checkpointing after each one.
train_losses = []  # mean training loss of every epoch that ran at least one step
val_losses = []    # validation loss measured after each of those epochs

# A "_min" checkpoint is written only when the validation loss drops below this value.
# NOTE(review): labels are divided by 100 before the MSE, so 1.8 is likely far
# above any realistic loss and the first evaluated epoch will normally be saved - confirm.
min_loss = 1.8
log_device_placement = True

# Build both checkpoint directories once so the restore path and the save path
# cannot drift apart. The per-epoch checkpoint previously went to
# 'fundus_<feature>_/' while restore looked in 'fundus_<feature>_<batch_size>/'.
model_root = '../'+config["model_dir"]
ckpt_dir = model_root+'/fundus_'+feature+"_"+str(batch_size)
best_ckpt_dir = model_root+'/fundus_'+feature+'_min'
for ckpt_target in (ckpt_dir, best_ckpt_dir):
    if not os.path.isdir(ckpt_target):
        os.makedirs(ckpt_target)

# Decay a working copy so re-running this cell does not keep halving the
# configured `init_learning_rate`.
epoch_learning_rate = init_learning_rate

batch_x_,batch_y_ = traindata.get_next()
with tf.Session(config=tf.ConfigProto(log_device_placement=log_device_placement)) as sess:
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(tf.global_variables_initializer())

    # Inclusive upper bound: the log line reports "epoch/total_epochs".
    for epoch in range(1, total_epochs + 1):

        # Halve the learning rate every 100 epochs.
        if epoch % 100 == 0:
            epoch_learning_rate = epoch_learning_rate / 2.

        train_loss = 0.0
        steps_done = 0

        for step in range(1, iteration + 1):
            try:
                batch_x, batch_y = sess.run([batch_x_["fundus"], batch_y_[feature]])
            except tf.errors.OutOfRangeError:
                # The input pipeline is exhausted; retrying the remaining
                # steps would only raise again, so end the epoch here.
                print("outofRange")
                break

            # Same label scaling as Evaluate().
            batch_y = batch_y / 100.

            train_feed_dict = {
                x: batch_x,
                label: batch_y,
                learning_rate: epoch_learning_rate,
                training_flag : True
            }

            _, batch_loss, _ = sess.run([train, cost, extra_update_ops], feed_dict=train_feed_dict)

            train_loss += batch_loss
            steps_done += 1

        if steps_done == 0:
            # No batch could be read at all, so later epochs would do nothing.
            print("training data exhausted, stopping at epoch %d" % epoch)
            break

        # Average over the steps that actually ran (equals `iteration` in the normal case).
        train_loss /= steps_done

        test_loss = Evaluate(sess)

        if test_loss < min_loss :
            min_loss = test_loss
            saver.save(sess=sess, save_path=best_ckpt_dir+'/model.ckpt')

        line = "epoch: %d/%d, train_loss: %.4f, test_loss: %.4f \n" % (
            epoch, total_epochs, train_loss, test_loss)

        train_losses.append(train_loss)
        val_losses.append(test_loss)

        print(line)

        # Latest-state checkpoint, written where the restore logic looks for it.
        saver.save(sess=sess, save_path=ckpt_dir+'/model.ckpt')
            