In [None]:
import time
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

In [None]:
# Graph-level random seed for TensorFlow.
seed = 777
tf.set_random_seed(seed)

# Example counts used to derive the number of iterations per pass
# (presumably the sizes of the train / validation TFRecord files -- confirm).
num_trainimgs = 234
num_valimgs = 26

# Optimisation settings.
batch_size = 1
num_epochs = 10
learning_rate = 0.01

# Learning-rate schedule: multiply the rate by `decay_rate` every `decay_step`
# optimizer steps, where `decay_step` is ten passes' worth of iterations.
decay_rate = 0.9
decay_step = int(num_trainimgs / batch_size * 10)

In [None]:
# All paths are resolved against the directory the notebook is started from.
cur_dir = os.getcwd()
data_dir = cur_dir

# TFRecord file names, joined with `data_dir` when the datasets are built.
TRAIN_FILE = 'train_images.tfrecords'
VALIDATION_FILE = 'val_images.tfrecords'

# Make sure the checkpoint directory exists before anything tries to write to it.
checkpoint_dir = os.path.join(cur_dir, 'checkpoints')
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

In [None]:
def print_tensor_shape(tensor, string):
    """Print `string` followed by the static shape reported by `tensor.get_shape()`.

    Args:
        tensor: any object exposing a ``get_shape()`` method.
        string: label printed in front of the shape.
    """
    static_shape = tensor.get_shape()
    print(string, static_shape)

In [None]:
def read_and_decode(tfrecord_serialized):
    """Parse one serialized example into a 256x256 image and a 256x256 label map.

    Args:
        tfrecord_serialized: scalar string tensor holding one serialized example
            with the byte-string features 'img_raw' and 'label_raw'.

    Returns:
        Tuple ``(image, label)``: the image as a float32 [256, 256] tensor
        (decoded int64 values multiplied by 1/1024) and the label map as a
        uint8 [256, 256] tensor.
    """
    feature_spec = {
        'img_raw': tf.FixedLenFeature([], tf.string),
        'label_raw': tf.FixedLenFeature([], tf.string),
    }
    example = tf.parse_single_example(tfrecord_serialized, feature_spec)

    # Image bytes are read as 65536 int64 values, i.e. one value per pixel of a
    # 256x256 frame, then rescaled to float32.
    pixels = tf.decode_raw(example['img_raw'], tf.int64)
    pixels.set_shape([65536])
    image_re = tf.cast(tf.reshape(pixels, [256, 256]), tf.float32) * (1. / 1024)

    # Label bytes are read as 65536 uint8 values and left as integers.
    classes = tf.decode_raw(example['label_raw'], tf.uint8)
    classes.set_shape([65536])
    label_re = tf.reshape(classes, [256, 256])

    return image_re, label_re

In [None]:
def make_dataset(batch_size, tfrecord_path):
    """Build a shuffled, batched, endlessly repeating dataset from a TFRecord file.

    Args:
        batch_size: number of examples per emitted batch.
        tfrecord_path: path (or list of paths) of the TFRecord file(s) to read.

    Returns:
        A ``tf.data.Dataset`` whose elements are ``(images, labels)`` batches as
        produced by ``read_and_decode``.
    """
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    dataset = dataset.map(read_and_decode, num_parallel_calls=8)
    dataset = dataset.shuffle(buffer_size=10000).batch(batch_size).repeat()
    # Prefetch as the final stage so that a fully assembled batch is prepared
    # in the background while the previous one is being consumed; prefetching
    # before `batch` only buffers individual examples.
    dataset = dataset.prefetch(buffer_size=1)
    return dataset

In [None]:
def network(images):
    """Build a two-layer fully connected network producing per-pixel 2-class logits.

    The input is flattened to 256*256 values, multiplied by a [65536, 512]
    weight matrix and then by a [512, 131072] weight matrix, and finally
    reshaped to one pair of logits per pixel.

    NOTE(review): there are no bias terms and no activation between the two
    matmuls, so the model is a single linear map -- confirm this is intended.

    Args:
        images: tensor that can be reshaped to [-1, 256, 256, 1].

    Returns:
        Float32 tensor of shape [-1, 256, 256, 2] holding the logits.
    """
    print_tensor_shape(images, 'input images shape')
    images_re = tf.reshape(images, [-1, 256, 256, 1])
    # Report the reshaped tensor, not the input, so the message is accurate.
    print_tensor_shape(images_re, 'input images shape after reshaping')

    # number of units in the hidden layer
    hidden = 512

    with tf.name_scope('Hidden'):
        w_fc1 = tf.Variable(
            tf.truncated_normal([256*256, hidden], stddev=0.1, dtype=tf.float32),
            name='w_fc1')
        print_tensor_shape(w_fc1, 'w_fc1 shape')

        flatten_input = tf.reshape(images_re, [-1, 256*256])
        print_tensor_shape(flatten_input, 'flattened input shape')

        net = tf.matmul(flatten_input, w_fc1)
        print_tensor_shape(net, 'hidden layer shape')

    with tf.name_scope('Final'):
        # `name` belongs to the Variable (as for w_fc1), not to the initializer
        # op; otherwise the variable gets an auto-generated name.
        w_fc2 = tf.Variable(
            tf.truncated_normal([hidden, 256*256*2], stddev=0.1, dtype=tf.float32),
            name='w_fc2')
        print_tensor_shape(w_fc2, 'w_fc2 shape')

        net = tf.matmul(net, w_fc2)
        print_tensor_shape(net, 'final layer shape')

        net = tf.reshape(net, [-1, 256, 256, 2])
        print_tensor_shape(net, 'output shape')

    return net

In [None]:
def loss(logits, labels):
    """Mean per-pixel sparse softmax cross-entropy between logits and labels.

    Args:
        logits: float tensor whose last dimension has size 2 (one logit per class).
        labels: integer tensor of class indices with one entry per logit pair.

    Returns:
        Scalar tensor: the cross-entropy averaged over every pixel in the batch.
    """
    labels = tf.to_int64(labels)
    print_tensor_shape(logits, 'logits shape before')
    print_tensor_shape(labels, 'labels shape before')

    # Flatten to one row per pixel so logits and labels line up element-wise.
    logits_re = tf.reshape(logits, [-1, 2])
    labels_re = tf.reshape(labels, [-1])
    print_tensor_shape(logits_re, 'logits shape after')
    print_tensor_shape(labels_re, 'labels shape after')

    # Feed the flattened tensors computed above (they were previously unused).
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels_re, logits=logits_re, name='cross_entropy')
    print_tensor_shape(cross_entropy, 'cross_entropy shape')

    # Local name differs from the function name to avoid shadowing it here.
    mean_loss = tf.reduce_mean(cross_entropy, name='simple_cross_entropy_mean')

    return mean_loss

In [None]:
def training(loss, learning_rate, decay_steps, decay_rate):
    """Create the gradient-descent training op with a staircase-decayed learning rate.

    Args:
        loss: scalar tensor to minimize.
        learning_rate: initial learning rate.
        decay_steps: number of optimizer steps between successive decays.
        decay_rate: multiplicative factor applied at each decay.

    Returns:
        The op that applies one gradient-descent update and increments the
        non-trainable 'global_step' counter.
    """
    step_counter = tf.Variable(0, name='global_step', trainable=False)
    decayed_lr = tf.train.exponential_decay(
        learning_rate,
        step_counter,
        decay_steps,
        decay_rate,
        staircase=True,
    )
    sgd = tf.train.GradientDescentOptimizer(decayed_lr)
    return sgd.minimize(loss, global_step=step_counter)

In [None]:
def evaluation(logits, labels):
    """Count how many pixels have the labelled class as their top-1 prediction.

    Args:
        logits: float tensor whose last dimension has size 2.
        labels: integer tensor of class indices with one entry per logit pair.

    Returns:
        Scalar int32 tensor: the number of correctly classified pixels in the
        batch (a count, not a ratio).
    """
    with tf.name_scope('eval'):
        labels = tf.to_int64(labels)
        print_tensor_shape(logits, 'logits eval shape before')
        print_tensor_shape(labels, 'labels eval shape before')

        logits_re = tf.reshape(logits, [-1, 2])
        labels_re = tf.reshape(labels, [-1])
        # Report the flattened tensors so the "after" messages are accurate.
        print_tensor_shape(logits_re, 'logits eval shape after')
        print_tensor_shape(labels_re, 'labels eval shape after')

        # True where the label is the highest-scoring class for that pixel.
        correct = tf.nn.in_top_k(logits_re, labels_re, 1)
        print_tensor_shape(correct, 'correct shape')

        return tf.reduce_sum(tf.cast(correct, tf.int32))

In [None]:
# Training split: resolve the TFRecord path, then build its input pipeline.
train_tfr_path = os.path.join(data_dir, TRAIN_FILE)
train_dataset = make_dataset(batch_size, train_tfr_path)

# Validation split: same pipeline, different file.
val_tfr_path = os.path.join(data_dir, VALIDATION_FILE)
val_dataset = make_dataset(batch_size, val_tfr_path)

In [None]:
# One iterator defined only by element types/shapes, so it can later be
# pointed at either dataset through an initializer op.
iterator = tf.data.Iterator.from_structure(
    output_types=train_dataset.output_types,
    output_shapes=train_dataset.output_shapes,
)
images, labels = iterator.get_next()

In [None]:
# Initializer ops that switch the shared iterator to the training or the
# validation dataset, respectively.
train_init, val_init = (
    iterator.make_initializer(split) for split in (train_dataset, val_dataset)
)

In [None]:
# Build the network on the iterator's image tensor; `logits` holds its output.
logits = network(images)

In [None]:
# Build the scalar loss tensor from the network logits and the labels.
# NOTE(review): this rebinds the name `loss` from the loss-building function to
# its return value, so running this cell again without first re-running the
# cell that defines `loss(logits, labels)` will presumably fail because the
# result is not callable. Later cells rely on `loss` being the tensor.
loss = loss(logits, labels)

In [None]:
# Optimizer step that minimizes `loss` using the configured decay schedule.
train_op = training(
    loss=loss,
    learning_rate=learning_rate,
    decay_steps=decay_step,
    decay_rate=decay_rate,
)

In [None]:
# Op returning the number of correctly classified pixels for the current batch.
eval_op = evaluation(logits, labels)

In [None]:
# Op that initializes every global variable defined in the graph up to this point.
init_op = tf.global_variables_initializer()

In [None]:
# Checkpoints are written under `checkpoint_dir` with the prefix 'model.ckpt'.
checkpoint_path = os.path.join(checkpoint_dir, 'model.ckpt')
saver = tf.train.Saver()

In [None]:
# Ask TensorFlow to grow its GPU memory allocation on demand.
gpu_options = tf.GPUOptions(allow_growth=True)
session_config = tf.ConfigProto(gpu_options=gpu_options)

sess = tf.Session(config=session_config)
sess.run(init_op)

In [None]:
print('Learning started. It takes sometime.')

# Iterations per pass do not change between epochs, so compute them once.
n_iter_train = int(num_trainimgs / batch_size)
n_iter_val = int(num_valimgs / batch_size)
# `eval_op` counts correct pixels over the whole batch, so the accuracy
# denominator must include the batch size as well as the 256x256 image area.
pixels_per_batch = batch_size * 256.0 * 256

for epoch in range(num_epochs):
    avg_loss = 0.
    avg_prec_train = 0.
    avg_prec_val = 0.

    sess.run(train_init)
    for i in range(n_iter_train):
        # Fetch everything in ONE run: each separate sess.run that depends on
        # the iterator pulls a fresh batch, which would discard data and score
        # accuracy on a different batch than the one trained on. The accuracy
        # may reflect the weights just before or just after this step's update.
        _, loss_val, prec = sess.run([train_op, loss, eval_op])
        avg_loss += loss_val / n_iter_train
        avg_prec_train += prec / (n_iter_train * pixels_per_batch)

    sess.run(val_init)
    for i in range(n_iter_val):
        # Single run so that images, labels, logits and the correct-pixel count
        # all come from the same validation batch (useful for later plotting).
        val_images, val_labels, val_logits, prec = sess.run(
            [images, labels, logits, eval_op])
        avg_prec_val += prec / (n_iter_val * pixels_per_batch)

    print('OUTPUT: epoch {}: loss = {:.5f}, train_precision = {:.3f}, val_precision = {:.3f}'.format(
        epoch+1, avg_loss, avg_prec_train, avg_prec_val ))
    # Overwrites the same checkpoint prefix after every epoch.
    saver.save(sess, checkpoint_path)
print('Done Training for {} epochs'.format(num_epochs))