In [1]:
import tensorflow as tf
import numpy as np
from DL_app import tf_util

  return f(*args, **kwds)
  from ._conv import register_converters as _register_converters


In [2]:
# Model hyper-parameters, registered as TensorFlow command-line flags.
flags = tf.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer('num_classes', 10, 'the number of classes')
flags.DEFINE_float('dropout_rate', 0.95, 'the dropout rate of the CNN')

In [3]:
# CNN layer definition

def xavier_normal_dist_conv3d(shape):
    """Draw conv3d filter weights from a zero-mean truncated normal distribution.

    The standard deviation is sqrt(3 / (prod(shape[:3]) * sum(shape[3:]))), i.e. it
    shrinks with the product of the first three entries of `shape` times the sum of
    the remaining entries.

    Args:
        shape: filter shape; the first three entries are multiplied together and the
            remaining ones are summed to build the scaling denominator.

    Returns:
        A tensor of the requested shape holding the sampled weights.
    """
    # The reductions yield an integer tensor; it has to be cast to float32 before it
    # can divide a Python float (same approach as xavier_uniform_dist_conv3d).
    denominator = tf.cast(tf.reduce_prod(shape[:3]) * tf.reduce_sum(shape[3:]), tf.float32)
    # NOTE(review): the usual Glorot-normal variance is 2 / (fan_in + fan_out); the
    # numerator 3 is kept unchanged from the original code - confirm it is intended.
    return tf.truncated_normal(shape, mean=0, stddev=tf.sqrt(3. / denominator))

def xavier_uniform_dist_conv3d(shape):
    """Draw conv3d filter weights uniformly from [-lim, lim].

    lim = sqrt(6 / (prod(shape[:3]) * sum(shape[3:]))).

    Args:
        shape: filter shape; the first three entries are multiplied together and the
            remaining ones are summed to build the scaling denominator.

    Returns:
        A tensor of the requested shape holding the sampled weights.
    """
    with tf.variable_scope('xavier_glorot_initializer'):
        leading_product = tf.reduce_prod(shape[:3])
        trailing_sum = tf.reduce_sum(shape[3:])
        # Cast to float32 so the integer reduction result can divide a float.
        scale = tf.cast(leading_product * trailing_sum, tf.float32)
        bound = tf.sqrt(6. / scale)
        return tf.random_uniform(shape, minval=-bound, maxval=bound)

def convolution_3d(layer_input, filter, strides, padding='SAME'):
    """Apply a 3-D convolution with freshly created weights and biases.

    Args:
        layer_input: tensor handed to tf.nn.conv3d.
        filter: five-entry filter shape
            [filter_depth, filter_height, filter_width, in_channels, out_channels].
        strides: five strides, one per input dimension
            [batch, in_depth, in_height, in_width, in_channels].
        padding: 'VALID' or 'SAME'.

    Returns:
        The convolution output with the bias added.

    Raises:
        AssertionError: if `filter` or `strides` does not have five entries, or
            `padding` is not one of the two accepted values.
    """
    assert len(filter) == 5
    assert len(strides) == 5
    assert padding in ('VALID', 'SAME')

    out_channels = filter[-1]
    kernel = tf.Variable(initial_value=xavier_uniform_dist_conv3d(shape=filter), name='weights')
    # Biases start at 1.0, one per output channel.
    bias = tf.Variable(tf.constant(1.0, shape=[out_channels]), name='biases')

    convolved = tf.nn.conv3d(layer_input, kernel, strides, padding)
    return convolved + bias


def deconvolution_3d(layer_input, filter, output_shape, strides, padding='SAME'):
    """Apply a 3-D transposed convolution with freshly created weights and biases.

    Args:
        layer_input: tensor handed to tf.nn.conv3d_transpose.
        filter: five-entry filter shape
            [depth, height, width, output_channels, in_channels].
        output_shape: shape of the result, forwarded to tf.nn.conv3d_transpose.
        strides: five strides, one per input dimension
            [batch, depth, height, width, in_channels].
        padding: 'VALID' or 'SAME'.

    Returns:
        The transposed-convolution output with the bias added.

    Raises:
        AssertionError: if `filter` or `strides` does not have five entries, or
            `padding` is not one of the two accepted values.
    """
    assert len(filter) == 5
    assert len(strides) == 5
    assert padding in ('VALID', 'SAME')

    # For the transposed filter layout the output channel count is second to last.
    out_channels = filter[-2]
    kernel = tf.Variable(initial_value=xavier_uniform_dist_conv3d(shape=filter), name='weights')
    bias = tf.Variable(tf.constant(1.0, shape=[out_channels]), name='biases')

    upsampled = tf.nn.conv3d_transpose(layer_input, kernel, output_shape, strides, padding)
    return upsampled + bias



def CNN3d_layer(input, output_size, scope, activation_function=tf.tanh):
    """Create the 'weight' variable of a 3-D CNN layer inside `scope`.

    The visible body stops after creating the variable: nothing is convolved and
    nothing is returned (the call evaluates to None). `output_size` and
    `activation_function` are accepted but not used by the visible code.

    NOTE(review): this looks like an unfinished stub - confirm whether the rest of
    the layer lives elsewhere or still has to be written.

    Args:
        input: tensor whose static shape supplies two of the weight dimensions
            (`input.shape[1]` and `input.shape[-2]`).
        output_size: unused in the visible body.
        scope: variable scope name under which 'weight' is created.
        activation_function: unused in the visible body.
    """
    with tf.variable_scope(scope):
        # Weight shape: (1, input.shape[1], FLAGS.num_classes, input.shape[-2], 16);
        # the trailing 16 is hard-coded.
        # NOTE(review): FLAGS.stddev is read here, but no 'stddev' flag is defined in
        # the visible cells - confirm it is declared elsewhere before calling this.
        W = tf.get_variable('weight', shape=(1, input.shape[1].value, FLAGS.num_classes, input.shape[-2].value, 16),
                        dtype=tf.float32, initializer=tf.random_normal_initializer(stddev=FLAGS.stddev))
        
        

In [4]:
def convolution_block(layer_input, n_channels, num_convolutions):
    """Residual stack of 5x5x5, stride-1 convolutions that keeps `n_channels` channels.

    The first `num_convolutions - 1` convolutions each live in their own
    'conv_<k>' variable scope and are followed by a PReLU. The last convolution is
    created outside those scopes; `layer_input` is added to its output before the
    final PReLU.

    Args:
        layer_input: input tensor, also used as the residual shortcut.
        n_channels: channel count used for both filter input and output.
        num_convolutions: total number of convolutions in the block.

    Returns:
        PReLU of (last convolution output + layer_input).
    """
    kernel_shape = [5, 5, 5, n_channels, n_channels]
    unit_strides = [1, 1, 1, 1, 1]

    features = layer_input
    for conv_index in range(1, num_convolutions):
        with tf.variable_scope('conv_%d' % conv_index):
            features = prelu(convolution_3d(features, kernel_shape, unit_strides))

    residual = convolution_3d(features, kernel_shape, unit_strides)
    return prelu(residual + layer_input)

def convolution_block_2(layer_input, fine_grained_features, n_channels, num_convolutions):
    """Residual convolution stack that first fuses `fine_grained_features`.

    The two inputs are concatenated on the last axis and reduced from
    `n_channels * 2` to `n_channels` channels by the 'conv_1' convolution (no
    activation follows it). Convolutions 'conv_2' .. 'conv_<num_convolutions - 1>'
    are each followed by a PReLU. A last convolution is created outside those
    scopes; `layer_input` is added to its output before the final PReLU.

    Args:
        layer_input: tensor from the previous level, also the residual shortcut.
        fine_grained_features: tensor concatenated with `layer_input` on the last axis.
        n_channels: channel count produced by every convolution in the block.
        num_convolutions: controls how many scoped middle convolutions are added.

    Returns:
        PReLU of (last convolution output + layer_input).
    """
    unit_strides = [1, 1, 1, 1, 1]
    same_width_kernel = [5, 5, 5, n_channels, n_channels]

    merged = tf.concat((layer_input, fine_grained_features), axis=-1)

    with tf.variable_scope('conv_%d' % 1):
        features = convolution_3d(merged, [5, 5, 5, n_channels * 2, n_channels], unit_strides)

    for conv_index in range(2, num_convolutions):
        with tf.variable_scope('conv_%d' % conv_index):
            features = prelu(convolution_3d(features, same_width_kernel, unit_strides))

    residual = convolution_3d(features, same_width_kernel, unit_strides)
    return prelu(residual + layer_input)


def down_convolution(layer_input, in_channels):
    """Stride-2 2x2x2 convolution that doubles the channel count, followed by PReLU.

    Args:
        layer_input: tensor to downsample.
        in_channels: channel count of `layer_input`; the output has twice as many.

    Returns:
        The activated, downsampled tensor.
    """
    kernel_shape = [2, 2, 2, in_channels, in_channels * 2]
    strides = [1, 2, 2, 2, 1]
    with tf.variable_scope('down_convolution'):
        downsampled = convolution_3d(layer_input, kernel_shape, strides)
        return prelu(downsampled)


def up_convolution(layer_input, output_shape, in_channels):
    """Stride-2 2x2x2 transposed convolution that halves the channel count, then PReLU.

    Args:
        layer_input: tensor to upsample.
        output_shape: shape the transposed convolution should produce.
        in_channels: channel count of `layer_input`; the output has half as many.

    Returns:
        The activated, upsampled tensor.
    """
    # Transposed filter layout: [depth, height, width, output_channels, in_channels].
    kernel_shape = [2, 2, 2, in_channels // 2, in_channels]
    strides = [1, 2, 2, 2, 1]
    with tf.variable_scope('up_convolution'):
        upsampled = deconvolution_3d(layer_input, kernel_shape, output_shape, strides)
        return prelu(upsampled)
    
def v_net(tf_input, input_channels, output_channels=1, n_channels=16):
    """Build the V-Net graph: a five-level contracting path and a mirrored expanding path.

    Args:
        tf_input: 5-D input tensor (it is tiled with a five-entry multiples list,
            channels last).
        input_channels: number of channels in `tf_input`.
        output_channels: number of channels of the returned logits.
        n_channels: channel count at level 1; it doubles at every deeper level.

    Returns:
        The logits tensor produced by the final 1x1x1 convolution.
    """
    with tf.variable_scope('contracting_path'):

        # A single-channel input is simply tiled up to n_channels so the residual
        # addition in level_1 has matching shapes; a multi-channel input needs a
        # learned convolution to reach n_channels instead.
        if input_channels == 1:
            c0 = tf.tile(tf_input, [1, 1, 1, 1, n_channels])
        else:
            with tf.variable_scope('level_0'):
                c0 = prelu(convolution_3d(tf_input, [5, 5, 5, input_channels, n_channels], [1, 1, 1, 1, 1]))

        with tf.variable_scope('level_1'):
            c1 = convolution_block(c0, n_channels, 1)
            c12 = down_convolution(c1, n_channels)

        with tf.variable_scope('level_2'):
            c2 = convolution_block(c12, n_channels * 2, 2)
            c22 = down_convolution(c2, n_channels * 2)

        with tf.variable_scope('level_3'):
            c3 = convolution_block(c22, n_channels * 4, 3)
            c32 = down_convolution(c3, n_channels * 4)

        with tf.variable_scope('level_4'):
            c4 = convolution_block(c32, n_channels * 8, 3)
            c42 = down_convolution(c4, n_channels * 8)

        with tf.variable_scope('level_5'):
            c5 = convolution_block(c42, n_channels * 16, 3)
            # Upsample back to the spatial size (and channel count) of level 4.
            c52 = up_convolution(c5, tf.shape(c4), n_channels * 16)

    with tf.variable_scope('expanding_path'):

        with tf.variable_scope('level_4'):
            # Fuse the upsampled features with the level-4 skip connection the same
            # way the other expanding levels do. The result must have
            # n_channels * 8 channels because the following up_convolution builds
            # its filter for exactly that many input channels.
            e4 = convolution_block_2(c52, c4, n_channels * 8, 3)
            e42 = up_convolution(e4, tf.shape(c3), n_channels * 8)

        with tf.variable_scope('level_3'):
            e3 = convolution_block_2(e42, c3, n_channels * 4, 3)
            e32 = up_convolution(e3, tf.shape(c2), n_channels * 4)

        with tf.variable_scope('level_2'):
            e2 = convolution_block_2(e32, c2, n_channels * 2, 2)
            e22 = up_convolution(e2, tf.shape(c1), n_channels * 2)

        with tf.variable_scope('level_1'):
            e1 = convolution_block_2(e22, c1, n_channels, 1)
            with tf.variable_scope('output_layer'):
                # 1x1x1 convolution maps the features to per-voxel class logits.
                logits = convolution_3d(e1, [1, 1, 1, n_channels, output_channels], [1, 1, 1, 1, 1])

    return logits

In [5]:
def placeholder_inputs(input_batch_shape, output_batch_shape):
    """Create the float32 placeholders through which batches are fed to the graph.

    Args:
        input_batch_shape: full shape (batch dimension included) of the images
            placeholder.
        output_batch_shape: full shape (batch dimension included) of the labels
            placeholder.

    Returns:
        images_placeholder: placeholder named "images_placeholder".
        labels_placeholder: placeholder named "labels_placeholder".
    """
    placeholder_specs = (
        (input_batch_shape, "images_placeholder"),
        (output_batch_shape, "labels_placeholder"),
    )
    images_placeholder, labels_placeholder = [
        tf.placeholder(tf.float32, shape=batch_shape, name=placeholder_name)
        for batch_shape, placeholder_name in placeholder_specs
    ]
    return images_placeholder, labels_placeholder



In [5]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
import os
import math
import datetime

# Command-line flags for the training script, grouped by purpose.
FLAGS = tf.app.flags.FLAGS

# --- data and patch geometry ---
tf.app.flags.DEFINE_string('data_dir', './data', 'Directory of stored data.')
tf.app.flags.DEFINE_integer('batch_size', 1, 'Size of batch')
tf.app.flags.DEFINE_integer('patch_size', 128, 'Size of a data patch')
tf.app.flags.DEFINE_integer('patch_layer', 128, 'Number of layers in data patch')

# --- schedule and optimisation ---
tf.app.flags.DEFINE_integer('epochs', 2000, 'Number of epochs for training')
tf.app.flags.DEFINE_string(
    'log_dir', './tmp/log', 'Directory where to write training and testing event logs ')
tf.app.flags.DEFINE_float('init_learning_rate', 0.0001, 'Initial learning rate')
tf.app.flags.DEFINE_float('decay_factor', 0.01, 'Exponential decay learning rate factor')
tf.app.flags.DEFINE_integer(
    'decay_steps', 100, 'Number of epoch before applying one learning rate decay')
tf.app.flags.DEFINE_integer('display_step', 10, 'Display and logging interval (train steps)')

# --- checkpoints and saved models ---
tf.app.flags.DEFINE_integer('save_interval', 1, 'Checkpoint save interval (epochs)')
tf.app.flags.DEFINE_string('checkpoint_dir', './tmp/ckpt', 'Directory where to write checkpoint')
tf.app.flags.DEFINE_string('model_dir', './tmp/model', 'Directory to save model')
tf.app.flags.DEFINE_bool('restore_training', True, 'Restore training from last checkpoint')

# --- patch sampling ---
tf.app.flags.DEFINE_float(
    'drop_ratio', 0.5,
    'Probability to drop a cropped area if the label is empty. '
    'All empty patches will be droped for 0 and accept all cropped patches if set to 1')
tf.app.flags.DEFINE_integer('min_pixel', 10, 'Minimum non-zero pixels in the cropped label')
tf.app.flags.DEFINE_integer('shuffle_buffer_size', 5, 'Number of elements used in shuffle buffer')

In [7]:
def train():
    """Train the V-Net model and evaluate it after every epoch.

    Builds the whole graph (placeholders, NIfTI input pipelines, network, loss,
    metrics, summaries), optionally restores the latest checkpoint, then for each
    epoch runs one pass over the training dataset (saving a checkpoint at the end)
    followed by one pass over the testing dataset. Summaries are written to
    FLAGS.log_dir/train and FLAGS.log_dir/test.

    NOTE(review): `NiftiDataset` and `dice_coe` are not defined or imported in the
    visible cells; they must be in scope before this function is called.
    """
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # patch_shape(batch_size, height, width, depth, channels)
        input_batch_shape = (FLAGS.batch_size, FLAGS.patch_size, FLAGS.patch_size, FLAGS.patch_layer, 1)
        output_batch_shape = (FLAGS.batch_size, FLAGS.patch_size, FLAGS.patch_size, FLAGS.patch_layer, 1)

        images_placeholder, labels_placeholder = placeholder_inputs(input_batch_shape, output_batch_shape)

        # Index of the central slice along the 4th axis; only this slice is logged as an image.
        mid_layer = FLAGS.patch_layer // 2

        images_log = tf.cast(images_placeholder[:, :, :, mid_layer, :], dtype=tf.uint8)
        labels_log = tf.cast(tf.scalar_mul(255, labels_placeholder[:, :, :, mid_layer, :]), dtype=tf.uint8)

        tf.summary.image("image", images_log, max_outputs=FLAGS.batch_size)
        tf.summary.image("label", labels_log, max_outputs=FLAGS.batch_size)

        # Get images and labels
        train_data_dir = os.path.join(FLAGS.data_dir, 'training')
        test_data_dir = os.path.join(FLAGS.data_dir, 'testing')
        # support multiple image input, but here only use single channel,
        # label file should be a single file with different classes
        image_filename = 'img.nii.gz'
        label_filename = 'label.nii.gz'

        patch_shape = (FLAGS.patch_size, FLAGS.patch_size, FLAGS.patch_layer)

        # Force the input pipeline onto CPU:0 so its ops never land on the GPU and slow training down.
        with tf.device('/cpu:0'):
            # create transformations to image and labels
            trainTransforms = [
                NiftiDataset.Normalization(),
                NiftiDataset.Resample(0.4356),
                NiftiDataset.Padding(patch_shape),
                NiftiDataset.RandomCrop(patch_shape, FLAGS.drop_ratio, FLAGS.min_pixel),
                NiftiDataset.RandomNoise()
                ]

            TrainDataset = NiftiDataset.NiftiDataset(
                data_dir=train_data_dir,
                image_filename=image_filename,
                label_filename=label_filename,
                transforms=trainTransforms,
                train=True
                )

            trainDataset = TrainDataset.get_dataset()
            trainDataset = trainDataset.shuffle(buffer_size=FLAGS.shuffle_buffer_size)
            trainDataset = trainDataset.batch(FLAGS.batch_size)

            # Same transforms as training, minus the random noise augmentation.
            testTransforms = [
                NiftiDataset.Normalization(),
                NiftiDataset.Resample(0.4356),
                NiftiDataset.Padding(patch_shape),
                NiftiDataset.RandomCrop(patch_shape, FLAGS.drop_ratio, FLAGS.min_pixel)
                ]

            # Evaluation must read the held-out 'testing' folder, not the training one.
            # NOTE(review): train=True is kept from the original - presumably it makes
            # the dataset load labels, which evaluation needs; confirm against NiftiDataset.
            TestDataset = NiftiDataset.NiftiDataset(
                data_dir=test_data_dir,
                image_filename=image_filename,
                label_filename=label_filename,
                transforms=testTransforms,
                train=True
            )

            testDataset = TestDataset.get_dataset()
            testDataset = testDataset.shuffle(buffer_size=FLAGS.shuffle_buffer_size)
            testDataset = testDataset.batch(FLAGS.batch_size)

        train_iterator = trainDataset.make_initializable_iterator()
        next_element_train = train_iterator.get_next()

        test_iterator = testDataset.make_initializable_iterator()
        next_element_test = test_iterator.get_next()

        # Initialize the model with the v_net builder defined in this notebook.
        with tf.name_scope("vnet"):
            logits = v_net(images_placeholder, input_channels=input_batch_shape[4], output_channels=2)

        logits_log_0 = tf.cast(logits[:, :, :, mid_layer:mid_layer + 1, 0], dtype=tf.uint8)
        logits_log_1 = tf.cast(logits[:, :, :, mid_layer:mid_layer + 1, 1], dtype=tf.uint8)
        tf.summary.image("logits_0", logits_log_0, max_outputs=FLAGS.batch_size)
        tf.summary.image("logits_1", logits_log_1, max_outputs=FLAGS.batch_size)

        # A constant learning rate is used; the decay_* flags are currently not applied.
        with tf.name_scope("learning_rate"):
            learning_rate = FLAGS.init_learning_rate
        tf.summary.scalar('learning_rate', learning_rate)

        # softmax op for probability layer
        with tf.name_scope("softmax"):
            softmax_op = tf.nn.softmax(logits, name="softmax")
        softmax_log_0 = tf.cast(tf.scalar_mul(255, softmax_op[:, :, :, mid_layer:mid_layer + 1, 0]), dtype=tf.uint8)
        softmax_log_1 = tf.cast(tf.scalar_mul(255, softmax_op[:, :, :, mid_layer:mid_layer + 1, 1]), dtype=tf.uint8)

        tf.summary.image("softmax_0", softmax_log_0, max_outputs=FLAGS.batch_size)
        tf.summary.image("softmax_1", softmax_log_1, max_outputs=FLAGS.batch_size)

        # Op for calculating loss
        with tf.name_scope("cross_entropy"):
            # The sparse cross-entropy op only accepts integer class ids, while the
            # labels placeholder is float32: drop the channel axis and cast first.
            sparse_labels = tf.cast(tf.squeeze(labels_placeholder, axis=[4]), tf.int32)
            loss_op = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits,
                labels=sparse_labels))
        tf.summary.scalar('loss', loss_op)

        # Argmax Op to generate label from logits
        with tf.name_scope("predicted_label"):
            pred = tf.argmax(logits, axis=4, name="prediction")
        pred_log = tf.cast(tf.scalar_mul(255, pred[:, :, :, mid_layer:mid_layer + 1]), dtype=tf.uint8)
        tf.summary.image("pred", pred_log, max_outputs=FLAGS.batch_size)

        # Training Op
        with tf.name_scope("training"):
            optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
            train_op = optimizer.minimize(
                loss=loss_op,
                global_step=global_step)

        # Predictions and labels in the same int64, channel-last layout for the metrics below.
        pred_with_channel = tf.expand_dims(pred, -1)
        labels_int64 = tf.cast(labels_placeholder, dtype=tf.int64)

        # Accuracy of model
        with tf.name_scope("accuracy"):
            correct_pred = tf.equal(pred_with_channel, labels_int64)
            accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
        tf.summary.scalar('accuracy', accuracy)

        # Dice Similarity
        with tf.name_scope("dice"):
            sorensen = dice_coe(pred_with_channel, labels_int64, loss_type='sorensen')
            jaccard = dice_coe(pred_with_channel, labels_int64, loss_type='jaccard')
        tf.summary.scalar('sorensen', sorensen)
        tf.summary.scalar('jaccard', jaccard)

        # Epoch counter stored as a graph variable so it is saved in, and restored
        # from, the checkpoint together with the weights.
        start_epoch = tf.get_variable("start_epoch", shape=[1], initializer=tf.zeros_initializer, dtype=tf.int32)
        start_epoch_inc = start_epoch.assign(start_epoch + 1)

        # saver
        summary_op = tf.summary.merge_all()
        checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, "checkpoint")
        # Saver.save needs the target directory to exist already.
        if not os.path.isdir(FLAGS.checkpoint_dir):
            os.makedirs(FLAGS.checkpoint_dir)
        print("Setting up Saver...")
        saver = tf.train.Saver()

        # training cycle
        with tf.Session() as sess:
            # Initialize all variables
            sess.run(tf.global_variables_initializer())
            print("{}: Start training...".format(datetime.datetime.now()))

            # summary writer for tensorboard
            train_summary_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
            test_summary_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test', sess.graph)

            try:
                # restore from checkpoint
                if FLAGS.restore_training:
                    # check if checkpoint exists
                    if os.path.exists(checkpoint_prefix + "-latest"):
                        print("{}: Last checkpoint found at {}, loading...".format(datetime.datetime.now(), FLAGS.checkpoint_dir))
                        latest_checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir, latest_filename="checkpoint-latest")
                        saver.restore(sess, latest_checkpoint_path)

                # start_epoch has shape [1]; take the scalar so range() gets a plain int.
                first_epoch = int(start_epoch.eval()[0])
                print("{}: Last checkpoint epoch: {}".format(datetime.datetime.now(), first_epoch))
                print("{}: Last checkpoint global step: {}".format(datetime.datetime.now(), tf.train.global_step(sess, global_step)))

                # loop over epochs
                for epoch in range(first_epoch, FLAGS.epochs):
                    # initialize iterator in each new epoch
                    sess.run(train_iterator.initializer)
                    sess.run(test_iterator.initializer)
                    print("{}: Epoch {} starts".format(datetime.datetime.now(), epoch + 1))

                    # training phase: consume the training iterator until it is exhausted
                    while True:
                        try:
                            [image, label] = sess.run(next_element_train)

                            # add the trailing channel axis expected by the placeholders
                            image = image[:, :, :, :, np.newaxis]
                            label = label[:, :, :, :, np.newaxis]

                            _, summary = sess.run([train_op, summary_op], feed_dict={images_placeholder: image, labels_placeholder: label})
                            train_summary_writer.add_summary(summary, global_step=tf.train.global_step(sess, global_step))

                        except tf.errors.OutOfRangeError:
                            start_epoch_inc.op.run()
                            # save the model at end of each epoch training
                            print("{}: Saving checkpoint of epoch {} at {}...".format(datetime.datetime.now(), epoch + 1, FLAGS.checkpoint_dir))
                            saver.save(sess, checkpoint_prefix,
                                global_step=tf.train.global_step(sess, global_step),
                                latest_filename="checkpoint-latest")
                            print("{}: Saving checkpoint succeed".format(datetime.datetime.now()))
                            break

                    # testing phase: forward passes only, no train_op
                    print("{}: Training of epoch {} finishes, testing start".format(datetime.datetime.now(), epoch + 1))
                    while True:
                        try:
                            [image, label] = sess.run(next_element_test)

                            image = image[:, :, :, :, np.newaxis]
                            label = label[:, :, :, :, np.newaxis]

                            _, summary = sess.run([loss_op, summary_op], feed_dict={images_placeholder: image, labels_placeholder: label})
                            test_summary_writer.add_summary(summary, global_step=tf.train.global_step(sess, global_step))

                        except tf.errors.OutOfRangeError:
                            break
            finally:
                # close tensorboard summary writers even if training is interrupted
                train_summary_writer.close()
                test_summary_writer.close()