In [None]:
import tensorflow as tf
import sys
import random
import time

import labelreg.helpers as helper
import labelreg.networks as network
import labelreg.utils as util
import labelreg.losses as loss

NB: For clarity, this tutorial does not use the ini config file; all settings are hard-coded in the cells below.

In [None]:
# Intentionally disabled: this tutorial hard-codes its settings in the cells below
# instead of reading them from the ini config file via helper.ConfigParser.
# config = helper.ConfigParser(sys.argv, 'training')

Set up the data readers used during training by specifying the folders that contain the images and labels.

In [None]:
# Training data folders, in the same order as the readers unpacked below:
# moving images, fixed images, moving labels, fixed labels
# (moving = MR, fixed = US, judging by the folder names).
train_data_dirs = ('./data/train/mr_images',
                   './data/train/us_images',
                   './data/train/mr_labels',
                   './data/train/us_labels')
reader_moving_image, reader_fixed_image, reader_moving_label, reader_fixed_label = \
    helper.get_data_readers(*train_data_dirs)

Define the placeholders for the moving and fixed images (the label placeholders are defined after the network is built); <br/>
The on-the-fly data augmentation uses random affine transformations, drawn independently for the moving and fixed data; <br/>
The minibatch size is 4 here; <br/>
The parameters of the 12-degrees-of-freedom affine transformations are fed through their own placeholders and used to augment the input data.

In [None]:
# Leading dimension 4 is the minibatch size; the trailing singleton axis is
# presumably the image channel (TODO confirm against labelreg.networks).
moving_image_shape = [4] + reader_moving_image.data_shape + [1]
fixed_image_shape = [4] + reader_fixed_image.data_shape + [1]
# One 1x12 row of affine parameters per minibatch element.
affine_shape = [4, 1, 12]

ph_moving_image = tf.placeholder(tf.float32, moving_image_shape)
ph_fixed_image = tf.placeholder(tf.float32, fixed_image_shape)
ph_moving_affine = tf.placeholder(tf.float32, affine_shape)
ph_fixed_affine = tf.placeholder(tf.float32, affine_shape)

# On-the-fly data augmentation: moving and fixed images are each warped by their own affine.
input_moving_image = util.warp_image_affine(ph_moving_image, ph_moving_affine)
input_fixed_image = util.warp_image_affine(ph_fixed_image, ph_fixed_affine)

Now build an instance of the "local" network, a single U-Net-like encoder-decoder that predicts a dense displacement field (DDF) from the moving and fixed images.

In [None]:
# 'local' selects the single encoder-decoder network whose predicted DDF (reg_net.ddf)
# is used to warp the moving label further below.
reg_net = network.build_network(
    image_moving=input_moving_image,
    image_fixed=input_fixed_image,
    network_type='local',
    minibatch_size=4,  # must match the leading dimension of the placeholders
)

In [None]:
# Label placeholders are sized from the *image* readers' grid shapes.
# NOTE(review): this assumes every label volume shares its image's data_shape — confirm.
moving_label_shape = [4] + reader_moving_image.data_shape + [1]
fixed_label_shape = [4] + reader_fixed_image.data_shape + [1]
ph_moving_label = tf.placeholder(tf.float32, moving_label_shape)
ph_fixed_label = tf.placeholder(tf.float32, fixed_label_shape)

# Each label is warped with the same affine as its image, so image/label pairs stay aligned
# after augmentation.
input_moving_label = util.warp_image_affine(ph_moving_label, ph_moving_affine)
input_fixed_label = util.warp_image_affine(ph_fixed_label, ph_fixed_affine)

Warp the moving label with the predicted dense displacement field (DDF).

In [None]:
# Resample the (augmented) moving label with the DDF predicted by reg_net; the result is
# compared against the fixed label in the similarity loss.
warped_moving_label = reg_net.warp_image(input_moving_label)  

Compute the loss: <br/>
* The label similarity (a multi-scale Dice loss) between the warped moving labels and the fixed labels; <br/>
* The weighted (here, 0.5) deformation regularisation (bending energy) on the predicted DDFs.

In [None]:
# Scales passed to the multi-scale label loss.
# NOTE(review): presumably smoothing kernel sizes — confirm in labelreg.losses.multi_scale_loss.
label_loss_scales = [0, 1, 2, 4, 8]
# Weight of the bending-energy deformation regulariser.
regulariser_weight = 0.5

loss_similarity = tf.reduce_mean(
    loss.multi_scale_loss(input_fixed_label, warped_moving_label, 'dice', label_loss_scales))
loss_regulariser = tf.reduce_mean(
    loss.local_displacement_energy(reg_net.ddf, 'bending', regulariser_weight))

Build the training op: an Adam optimiser with a learning rate of 1e-04 that minimises the sum of the similarity loss and the regulariser.

In [None]:
# Adam minimises the sum of the label-similarity loss and the deformation regulariser.
optimizer = tf.train.AdamOptimizer(learning_rate=1e-04)
loss_total = loss_similarity + loss_regulariser
train_op = optimizer.minimize(loss_total)

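These two utility nodes report the binary Dice score and the centroid distance between the warped moving labels and the fixed labels during training; they are only monitored and do not contribute to the loss.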

In [None]:
# Monitoring-only metrics (they are not part of the optimised loss), both comparing the
# warped moving label with the fixed label.
metric_args = (warped_moving_label, input_fixed_label)
dice = util.compute_binary_dice(*metric_args)
dist = util.compute_centroid_distance(*metric_args)

Set up the training session.

In [None]:
# Number of full minibatches (of 4 cases) per pass over the training data. Cases left over
# when num_data is not a multiple of 4 are skipped in that epoch; the training loop reshuffles
# train_indices every epoch, so they are still visited over time.
num_minibatch = reader_moving_label.num_data // 4
if num_minibatch == 0:
    # Fail early with a clear message; otherwise the training loop dies on a zero step/modulus.
    raise ValueError('Need at least 4 training cases to fill one minibatch, found %d.'
                     % reader_moving_label.num_data)
train_indices = list(range(reader_moving_label.num_data))

# Keep only the most recent checkpoint on disk.
saver = tf.train.Saver(max_to_keep=1)
sess = tf.Session()
sess.run(tf.global_variables_initializer())

Start training! <br/>
Progress is printed every 10 iterations, and the model is saved to ./data/model.ckpt every 100 iterations.

In [None]:
# Training schedule (the minibatch size of 4 is fixed by the placeholders' leading dimension).
num_steps = 10000
print_every = 10   # iterations between progress reports
save_every = 100   # iterations between checkpoints

for step in range(num_steps):

    # Start of a new epoch: reshuffle the order in which cases are visited.
    if step % num_minibatch == 0:
        random.shuffle(train_indices)

    # Next 4 cases of the current epoch, plus one randomly chosen label index per case;
    # the same index selects both the moving and the fixed label of that case.
    minibatch_idx = step % num_minibatch
    case_indices = train_indices[minibatch_idx*4:(minibatch_idx+1)*4]
    label_indices = [random.randrange(reader_moving_label.num_labels[i]) for i in case_indices]

    # Moving and fixed data get independently drawn random affines (data augmentation).
    trainFeed = {ph_moving_image: reader_moving_image.get_data(case_indices),
                 ph_fixed_image: reader_fixed_image.get_data(case_indices),
                 ph_moving_label: reader_moving_label.get_data(case_indices, label_indices),
                 ph_fixed_label: reader_fixed_label.get_data(case_indices, label_indices),
                 ph_moving_affine: helper.random_transform_generator(4),
                 ph_fixed_affine: helper.random_transform_generator(4)}

    sess.run(train_op, feed_dict=trainFeed)

    if step % print_every == 0:
        # Re-evaluate losses and metrics on the same minibatch, after this step's update.
        current_time = time.asctime(time.gmtime())
        loss_similarity_train, loss_regulariser_train, dice_train, dist_train = sess.run(
            [loss_similarity, loss_regulariser, dice, dist],
            feed_dict=trainFeed)

        # The reported 'similarity' is 1 - loss_similarity.
        print('Step %d [%s]: Loss=%f (similarity=%f, regulariser=%f)' %
              (step,
               current_time,
               loss_similarity_train+loss_regulariser_train,
               1-loss_similarity_train,
               loss_regulariser_train))
        print('  Dice: %s' % dice_train)
        print('  Distance: %s' % dist_train)
        print('  Image-label indices: %s - %s' % (case_indices, label_indices))

    # Checkpoint periodically AND after the final step: the last step (num_steps - 1) is not a
    # multiple of save_every, so without the second condition the final iterations would be lost.
    if step % save_every == 0 or step == num_steps - 1:
        save_path = saver.save(sess, './data/model.ckpt', write_meta_graph=False)
        print("Model saved in: %s" % save_path)