In [None]:
from glob import glob
import os 
import tensorflow as tf
from unet_model_tf import unet_inference
from sklearn.model_selection import train_test_split
import itertools
from skimage import io, transform
import numpy as np


In [None]:
# --- data locations ---
IMAGE_FOLDER = './images'
MASK_FOLDER = './masks'
OUTPUT_DIR = './output'

# --- input pipeline ---
# For this example the images are resized to IMAGE_SIZE x IMAGE_SIZE;
# usually we would want to train on image patches instead.
IMAGE_SIZE = 128
BATCH_SIZE = 6

# --- network ---
KERNEL_NUM = 12
DOUBLE_CONV = True

# seed for the train/validation split
SEED = 42

# keyword arguments for the normalizer (batch norm)
NORMALIZER_PARAMS = dict(
    decay=0.9,      # decay for the moving averages
    epsilon=0.001,  # epsilon to prevent 0s in variance
    scale=True,
    center=True,
    renorm=False,
)

# make sure the output directory exists (no error if it already does)
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# Collect all image files. glob returns paths in arbitrary, OS-dependent
# order, so sort them: otherwise the seeded train/validation split would
# differ between machines even though the seed is fixed.
all_images = sorted(glob(os.path.join(IMAGE_FOLDER, '*.png')))
# image id = file name without directory and extension
ids = [os.path.splitext(os.path.basename(x))[0] for x in all_images]
# show some ids
ids[:4]

In [None]:
def augment(image, mask):
    """Randomly transform an image and its mask in lockstep.

    Three independent coin flips (p = 0.5 each) decide, in this order,
    whether to apply a left/right flip, an up/down flip and a rotation by
    a random integer angle drawn from [-180, 180) degrees. Every selected
    transformation is applied identically to both arrays.

    Usually we want more functions here, for example zooms or color
    distortions.

    Returns:
        The (image, mask) tuple after augmentation.
    """
    def coin_flip():
        # one Bernoulli draw: 1 -> apply the transformation, 0 -> skip it
        return np.random.binomial(1, 0.5)

    # horizontal mirror
    if coin_flip():
        image, mask = np.fliplr(image), np.fliplr(mask)

    # vertical mirror
    if coin_flip():
        image, mask = np.flipud(image), np.flipud(mask)

    # rotation; the angle is drawn once so image and mask stay aligned
    if coin_flip():
        angle = np.random.randint(-180, 180)
        image = transform.rotate(image, angle)
        mask = transform.rotate(mask, angle)

    return image, mask

def get_data_generator(ids, image_folder, mask_folder, image_size, batch_size, is_training):
    """Build a generator function yielding (image, mask) pairs forever.

    Args:
        ids: non-empty sequence of image ids; the image is read from
            ``<image_folder>/<id>.png`` and its mask from
            ``<mask_folder>/<id>_mask.png``.
        image_folder: directory containing the images.
        mask_folder: directory containing the masks.
        image_size: images and masks are resized to
            ``(image_size, image_size)``.
        batch_size: not used by this function; kept so existing callers
            keep working.
        is_training: if true, ``augment`` is applied to every pair.

    Returns:
        A zero-argument function that returns a fresh, endless generator
        of ``(image, mask)`` arrays, each of shape
        ``(image_size, image_size, 1)``. The image is shifted by -0.5 and
        the mask contains only 0.0 and 1.0.

    Raises:
        ValueError: if ``ids`` is empty.
    """
    # Snapshot the ids so the generator is not affected by later mutation,
    # and fail early instead of crashing inside the input pipeline.
    ids = list(ids)
    if not ids:
        raise ValueError('ids must contain at least one image id')

    def data_generator():
        # the generator required by tensorflow just loops forever,
        # starting with the first id
        for idx in itertools.cycle(ids):
            image_path = os.path.join(image_folder, idx + '.png')
            mask_path = os.path.join(mask_folder, idx + '_mask.png')

            # read image and mask as grayscale and resize them;
            # builtin float replaces the np.float alias removed in NumPy 1.24
            image = io.imread(image_path, as_gray=True)
            image = transform.resize(image, (image_size, image_size)).astype(float)
            mask = io.imread(mask_path, as_gray=True)
            mask = transform.resize(mask, (image_size, image_size)).astype(float)

            if is_training:
                image, mask = augment(image, mask)

            # shift the image; a new array is created because augment may
            # have returned a flipped view
            image = image - 0.5
            # we binarize because resizing (and rotating) interpolates
            mask = (mask > 0.5).astype(float)

            # append the channel axis
            yield np.expand_dims(image, -1), np.expand_dims(mask, -1)

    return data_generator



In [None]:
# split the ids into a training and a validation part; the fixed seed
# makes the split repeatable for a given order of ids
train_ids, val_ids = train_test_split(ids, random_state=SEED)
print('Trainids length: ', len(train_ids))
print('Eval length: ', len(val_ids))

# training generator: augmentation switched on
train_generator = get_data_generator(
    ids=train_ids,
    image_folder=IMAGE_FOLDER,
    mask_folder=MASK_FOLDER,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    is_training=True,
)
# validation generator: no augmentation
val_generator = get_data_generator(
    ids=val_ids,
    image_folder=IMAGE_FOLDER,
    mask_folder=MASK_FOLDER,
    image_size=IMAGE_SIZE,
    batch_size=3,
    is_training=False,
)

In [None]:
# metric for evaluation
# dice coefficient
def dice_tf(prediction, label, smooth=1e-7):
    """Mean Dice coefficient of rounded predictions against rounded labels.

    Both tensors are rounded and flattened to one row per sample. Per
    sample the score is ``(2 * |P * L| + smooth) / (|P| + |L| + smooth)``,
    and the mean over all samples is returned.

    Args:
        prediction: tensor of predicted probabilities, batch axis first.
        label: tensor of ground-truth masks with the same shape as
            ``prediction``.
        smooth: small constant added to numerator and denominator. Without
            it a sample whose prediction and label are both empty yields
            0/0 = NaN and turns the whole batch mean into NaN; with it such
            a sample scores 1.

    Returns:
        Scalar tensor holding the mean Dice coefficient.
    """
    prediction = tf.layers.flatten(tf.round(prediction))
    label = tf.layers.flatten(tf.round(label))
    # per-sample overlap and per-sample total foreground size
    intersection = tf.reduce_sum(prediction * label, -1)
    total = tf.reduce_sum(prediction, -1) + tf.reduce_sum(label, -1)
    dices = (2 * intersection + smooth) / (total + smooth)
    return tf.reduce_mean(dices)

In [None]:


# training hyper-parameters
LEARNING_RATE = 1e-3
NUM_STEPS = 10000

with tf.Graph().as_default():
    # seed graph-level randomness (e.g. weight initialization)
    tf.set_random_seed(SEED)

    # Create tensorflow datasets. from_generator is a static method, so it
    # is called on the class: instantiating the abstract Dataset class
    # first fails on TensorFlow versions that enforce the abstract base.
    output_types = (tf.float32, tf.float32)
    output_shapes = ((IMAGE_SIZE, IMAGE_SIZE, 1), (IMAGE_SIZE, IMAGE_SIZE, 1))
    ds = tf.data.Dataset.from_generator(train_generator, output_types, output_shapes)
    ds_val = tf.data.Dataset.from_generator(val_generator, output_types, output_shapes)

    # global step counts number of batches fed through network
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)

    # Load unet definition for training
    image, label = ds.prefetch(100).batch(BATCH_SIZE).make_one_shot_iterator().get_next()
    model, _ = unet_inference(image, kernel_num=KERNEL_NUM,
                              is_training=True,
                              reuse=None, style='default',
                              normalizer_params=NORMALIZER_PARAMS,
                              double_conv=DOUBLE_CONV,
                              normalizer_fn=tf.contrib.layers.batch_norm)

    # Load unet definition for validation; reuse=True so the validation
    # network reuses the variables created for the training network
    image_val, label_val = ds_val.prefetch(20).batch(BATCH_SIZE).make_one_shot_iterator().get_next()
    model_val, _ = unet_inference(image_val, kernel_num=KERNEL_NUM,
                                  is_training=False,
                                  reuse=True, style='default',
                                  normalizer_params=NORMALIZER_PARAMS,
                                  double_conv=DOUBLE_CONV,
                                  normalizer_fn=tf.contrib.layers.batch_norm)

    logits = model['logits']
    mask_pred = model['probs']
    logits_val = model_val['logits']
    mask_pred_val = model_val['probs']

    # define loss
    loss_op = tf.losses.sigmoid_cross_entropy(logits=logits, multi_class_labels=label)
    loss_op_val = tf.losses.sigmoid_cross_entropy(logits=logits_val, multi_class_labels=label_val)

    # metrics
    dice = dice_tf(mask_pred, label)
    dice_val = dice_tf(mask_pred_val, label_val)

    # needed for batch norm, see tf documentation of batch norm:
    # the moving-average updates must run together with the train op
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
        train_op = optimizer.minimize(loss_op, global_step=global_step)

    # define some summaries
    with tf.name_scope('train'):
        tf.summary.image('prediction', mask_pred, 1)
        tf.summary.image('label', label, 1)
        tf.summary.image('image', image, 1)
        tf.summary.scalar('loss', loss_op)
        tf.summary.scalar('dice', dice)

    with tf.name_scope('val'):
        tf.summary.image('prediction_val', mask_pred_val, 1)
        tf.summary.image('label_val', label_val, 1)
        tf.summary.image('image_val', image_val, 1)
        tf.summary.scalar('loss_val', loss_op_val)
        tf.summary.scalar('dice_val', dice_val)

    # the supervisor saves checkpoints and summaries to OUTPUT_DIR
    sv = tf.train.Supervisor(logdir=OUTPUT_DIR,
                             global_step=global_step,
                             save_model_secs=300,
                             save_summaries_secs=20)

    with sv.managed_session() as sess:
        for step in range(NUM_STEPS):
            # check for a stop request before spending time on a step
            if sv.should_stop():
                break
            # only the train op is run; predictions are inspected through
            # the summaries instead of being fetched every step
            sess.run(train_op)
                                



# Visualization
Use TensorBoard to track training progress and to visualize the computation graph (expand the unet block):

```tensorboard --logdir=output```