In [1]:
import tensorflow as tf
import sys
import random
import time

import labelreg.helpers as helper
import labelreg.networks as network
import labelreg.utils as util
import labelreg.losses as loss

The INI file contains the configuration options; the parsed values are printed below.

In [2]:
# Parse the 'training' configuration into a nested mapping indexed as
# config['Section']['option'] (sections Data, Network, Loss and Train are used below).
# NOTE(review): inside a Jupyter kernel sys.argv holds the kernel launcher's own
# arguments rather than user options; the output below shows the parser fell back
# to the default config file here - confirm this is the intended behaviour.
config = helper.ConfigParser(sys.argv, 'training')

Reading default config file in: C:\Users\yhu\Git\tutorials2019\weakly\config_demo.ini.

[Data]: dir_moving_image: ./data/train/mr_images
[Data]: dir_fixed_image: ./data/train/us_images
[Data]: dir_moving_label: ./data/train/mr_labels
[Data]: dir_fixed_label: ./data/train/us_labels
[Network]: network_type: local
[Loss]: similarity_type: dice
[Loss]: similarity_scales: [0, 1, 2, 4, 8]
[Loss]: regulariser_type: bending
[Loss]: regulariser_weight: 0.5
[Train]: total_iterations: 10000
[Train]: learning_rate: 1e-05
[Train]: minibatch_size: 4
[Train]: freq_info_print: 10
[Train]: freq_model_save: 50
[Train]: file_model_save: ./data/model.ckpt



Set up the data readers that feed images and labels to the training.

In [3]:
# One reader per directory listed in the [Data] section, passed positionally in the
# order get_data_readers takes them: moving image, fixed image, moving label, fixed label.
reader_moving_image, reader_fixed_image, reader_moving_label, reader_fixed_label = helper.get_data_readers(
    *[config['Data'][key]
      for key in ('dir_moving_image', 'dir_fixed_image', 'dir_moving_label', 'dir_fixed_label')])

Placeholders for the moving and fixed images and for their random affine parameters (the label placeholders are defined further down).
On-the-fly data augmentation uses random affine transformations, drawn independently for the moving and fixed data.

In [4]:
# Raw-minibatch placeholders. Images are shaped
# [minibatch_size] + reader.data_shape + [1] (a trailing dimension of size 1);
# each example also receives a [1, 12] vector of affine parameters.
ph_moving_image = tf.placeholder(
    tf.float32, [config['Train']['minibatch_size']] + reader_moving_image.data_shape + [1])
ph_fixed_image = tf.placeholder(
    tf.float32, [config['Train']['minibatch_size']] + reader_fixed_image.data_shape + [1])
ph_moving_affine = tf.placeholder(tf.float32, [config['Train']['minibatch_size'], 1, 12])
ph_fixed_affine = tf.placeholder(tf.float32, [config['Train']['minibatch_size'], 1, 12])

# Data augmentation: warp each image with its own affine placeholder, so the
# moving and fixed images are augmented separately.
input_moving_image, input_fixed_image = (
    util.warp_image_affine(image, affine)
    for image, affine in ((ph_moving_image, ph_moving_affine),
                          (ph_fixed_image, ph_fixed_affine)))

Instructions for updating:
Use tf.cast instead.


Now build the registration network, which predicts a displacement field (DDF) from the augmented image pair.

In [5]:
# Registration network on the augmented image pair; its predicted displacement
# field is consumed further down as ``reg_net.ddf``.
reg_net = network.build_network(
    network_type=config['Network']['network_type'],
    minibatch_size=config['Train']['minibatch_size'],
    image_moving=input_moving_image,
    image_fixed=input_fixed_image,
)


For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Colocations handled automatically by placer.


In [6]:
# Label placeholders use the same shapes as the corresponding image placeholders.
# NOTE(review): the spatial shape is taken from the image readers, which presumes each
# label volume is voxel-aligned with its image - confirm against helper.get_data_readers.
ph_moving_label = tf.placeholder(
    tf.float32, [config['Train']['minibatch_size']] + reader_moving_image.data_shape + [1])
ph_fixed_label = tf.placeholder(
    tf.float32, [config['Train']['minibatch_size']] + reader_fixed_image.data_shape + [1])

# Augment each label with the SAME affine placeholder as its image, so the
# label undergoes the same random transform as the image it belongs to.
input_moving_label, input_fixed_label = (
    util.warp_image_affine(label, affine)
    for label, affine in ((ph_moving_label, ph_moving_affine),
                          (ph_fixed_label, ph_fixed_affine)))

In [7]:
# Resample the augmented moving label with the network's prediction; the loss and
# monitoring metrics below compare this warped label against the fixed label.
warped_moving_label = reg_net.warp_image(input_moving_label)  # warp the moving label with the predicted ddf

In [8]:
# Weakly-supervised loss built only from labels and the predicted field:
#   - a similarity term between the warped moving label and the fixed label,
#     configured by similarity_type / similarity_scales;
#   - a regulariser on reg_net.ddf, configured by regulariser_type / regulariser_weight.
loss_similarity, loss_regulariser = loss.build_loss(
    similarity_type=config['Loss']['similarity_type'],
    similarity_scales=config['Loss']['similarity_scales'],
    regulariser_type=config['Loss']['regulariser_type'],
    regulariser_weight=config['Loss']['regulariser_weight'],
    label_moving=warped_moving_label,
    label_fixed=input_fixed_label,
    network_type=config['Network']['network_type'],
    ddf=reg_net.ddf,
)

In [9]:
# A single Adam update on the sum of both loss terms.
# NOTE(review): regulariser_weight is passed into build_loss, so loss_regulariser is
# presumably already weighted - confirm before adding another weight here.
train_op = (tf.train.AdamOptimizer(learning_rate=config['Train']['learning_rate'])
            .minimize(loss=loss_similarity + loss_regulariser))

Instructions for updating:
Use tf.cast instead.


These are two utility nodes that report the Dice score and the centroid distance during the training iterations; they are not part of the loss.

In [10]:
# Monitoring-only nodes, evaluated per example in the minibatch: binary Dice
# overlap and centroid distance between the warped moving and fixed labels.
dice, dist = (
    util.compute_binary_dice(warped_moving_label, input_fixed_label),
    util.compute_centroid_distance(warped_moving_label, input_fixed_label),
)

Set up the training session: minibatch bookkeeping, a checkpoint saver, and a freshly initialised TensorFlow session.

In [11]:
# Number of full minibatches per pass over the training cases; leftover cases
# (num_data not divisible by minibatch_size) are not visited in that pass.
num_minibatch = reader_moving_label.num_data // config['Train']['minibatch_size']
if num_minibatch < 1:
    # Fail early with a clear message instead of an opaque ZeroDivisionError /
    # "range() arg 3 must not be zero" inside the training loop.
    raise ValueError('Need at least minibatch_size=%s training cases, found %s.'
                     % (config['Train']['minibatch_size'], reader_moving_label.num_data))
train_indices = list(range(reader_moving_label.num_data))  # case order, shuffled during training

saver = tf.train.Saver(max_to_keep=1)  # keep only the most recent checkpoint
sess = tf.Session()
sess.run(tf.global_variables_initializer())

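Start the training. Each iteration trains on one randomly augmented minibatch; metrics are printed every `freq_info_print` steps and the model is checkpointed every `freq_model_save` steps.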

In [None]:
# Training loop: each step trains on one minibatch of cases with freshly drawn
# random affine augmentations, then periodically logs metrics and saves a checkpoint.
total_iterations = config['Train']['total_iterations']
minibatch_size = config['Train']['minibatch_size']

for step in range(total_iterations):

    # Reshuffle the case order at the start of every pass over the data.
    if step % num_minibatch == 0:
        random.shuffle(train_indices)

    minibatch_idx = step % num_minibatch
    case_indices = train_indices[minibatch_idx*minibatch_size:(minibatch_idx+1)*minibatch_size]
    # A case can have several labels: sample one label index per case and use the
    # same index for both the moving and the fixed label of that case.
    label_indices = [random.randrange(reader_moving_label.num_labels[i]) for i in case_indices]

    trainFeed = {ph_moving_image: reader_moving_image.get_data(case_indices),
                 ph_fixed_image: reader_fixed_image.get_data(case_indices),
                 ph_moving_label: reader_moving_label.get_data(case_indices, label_indices),
                 ph_fixed_label: reader_fixed_label.get_data(case_indices, label_indices),
                 ph_moving_affine: helper.random_transform_generator(minibatch_size),
                 ph_fixed_affine: helper.random_transform_generator(minibatch_size)}

    sess.run(train_op, feed_dict=trainFeed)

    if step % config['Train']['freq_info_print'] == 0:
        current_time = time.asctime(time.gmtime())
        # Evaluated after the update above, on the same minibatch and augmentation.
        loss_similarity_train, loss_regulariser_train, dice_train, dist_train = sess.run(
            [loss_similarity, loss_regulariser, dice, dist],
            feed_dict=trainFeed)

        # "similarity" is reported as 1 - loss_similarity (higher is better),
        # presumably because the similarity loss is a dissimilarity - confirm.
        print('Step %d [%s]: Loss=%f (similarity=%f, regulariser=%f)' %
              (step,
               current_time,
               loss_similarity_train+loss_regulariser_train,
               1-loss_similarity_train,
               loss_regulariser_train))
        print('  Dice: %s' % dice_train)
        print('  Distance: %s' % dist_train)
        print('  Image-label indices: %s - %s' % (case_indices, label_indices))

    # Save periodically AND at the final step; otherwise the iterations after the
    # last periodic save would be lost.
    if step % config['Train']['freq_model_save'] == 0 or step == total_iterations - 1:
        save_path = saver.save(sess, config['Train']['file_model_save'], write_meta_graph=False)
        print("Model saved in: %s" % save_path)

Step 0 [Tue Oct 15 08:54:09 2019]: Loss=0.772278 (similarity=0.227760, regulariser=0.000038)
  Dice: [0.         0.46399537 0.         0.5488827 ]
  Distance: [ 9.623273   8.272579  12.3605175 10.924269 ]
  Image-label indices: [1, 4, 0, 6] - [2, 0, 1, 0]
Model saved in: ./data/model.ckpt
Step 10 [Tue Oct 15 09:01:45 2019]: Loss=0.589504 (similarity=0.412633, regulariser=0.002137)
  Dice: [0.516828   0.         0.73716086 0.64209586]
  Distance: [8.238906 9.745856 8.24459  8.262364]
  Image-label indices: [5, 1, 2, 6] - [0, 2, 0, 0]
