In [None]:
#TODO:
# Data pipeline
# Training graph
# Loss function
# Kmeans - structure properly 
# Mask generation
# Metric

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
import tensorflow as tf
import sys
import math

PASCAL VOC colour map function which returns a colour map that enables us to match colours on the segmentation mask images with class labels.

In [None]:
def color_map(N=256, normalized=False):
    """Build the PASCAL VOC palette that maps label indices to RGB colours.

    Args:
        N: number of palette entries to generate.
        normalized: if True return float32 colours scaled by 1/255,
            otherwise return uint8 colours.

    Returns:
        Array of shape (N, 3) whose i-th row is the colour of label i.
    """
    palette = np.zeros((N, 3), dtype='float32' if normalized else 'uint8')

    for label in range(N):
        channels = [0, 0, 0]  # r, g, b
        remaining = label
        # Each group of three label bits contributes one bit to r, g and b,
        # filling the colour bytes from the most significant bit downwards.
        for shift in range(7, -1, -1):
            for ch in range(3):
                channels[ch] |= ((remaining >> ch) & 1) << shift
            remaining >>= 3
        palette[label] = channels

    if normalized:
        palette = palette / 255
    return palette

A Config object that will store parameters that need to be passed around.

A function to read an image, possibly downsample it and resize to specified dimensions with crop or pad (so as not to affect aspect ratio).
TODO: add augmentations

A function to read a mask. PASCAL VOC masks include a border around the segmentation of each object; pixels of predicted masks that fall within this border may be excluded when computing losses and metrics.

We will dynamically obtain bounding boxes from the masks for now. Later we can run this once and save the results. It is necessary to match each instance to its class. The class colours are known beforehand. The documentation claims that the instance colours can be used to identify the class but does not specify how.

In [None]:
def get_class_mask(mask_rgb):
    """One-hot encode an RGB class mask over the foreground classes.

    Compares every pixel colour with each colour in the module-level
    `config.clrs_fgd` (background and border colours are not part of it).

    Args:
        mask_rgb: (h, w, 3) class segmentation image.

    Returns:
        Float tensor (h, w, n_classes) that is 1 where the pixel colour equals
        the class colour and 0 elsewhere.
    """
    pixel_colours = mask_rgb[..., None, :]       # (h, w, 1, 3)
    class_colours = config.clrs_fgd[None, None]  # (1, 1, n_classes, 3)
    # A pixel belongs to a class only if all three channels match.
    colour_match = tf.equal(pixel_colours, class_colours)  # (h, w, n_classes, 3)
    is_class = tf.reduce_all(colour_match, axis=-1)        # (h, w, n_classes)
    return tf.to_float(is_class)

def rgb2val(img):
    """Collapse an RGB image into one scalar per pixel.

    Each pixel becomes r*256**2 + g*256 + b, so distinct colours map to
    distinct values.

    Args:
        img: tensor whose last axis holds the (r, g, b) channels.

    Returns:
        Tensor with the channel axis summed out.
    """
    channel_weights = tf.pow(256., [2, 1, 0])  # [256**2, 256, 1]
    weighted = img * channel_weights[None, None]
    return tf.reduce_sum(weighted, axis=-1)

def process_bboxes(bbox, config):
    """Convert corner-format boxes to centre/size format.

    Args:
        bbox: tensor of shape (..., 4) (e.g. batch_size x n_inst_max x 4) whose
            last axis holds the minimum corner followed by the maximum corner.
            Presumably (y1, x1, y2, x2) as stacked by `get_bboxes_from_masks`
            - confirm against the caller.
        config: unused; kept so the signature matches the other pipeline helpers.

    Returns:
        Float tensor of shape (..., 4): the box centre followed by the box
        dimensions, in the same axis order as the input corners.
    """
    # Cast first: integer boxes would otherwise give a float centre (true
    # division) and integer dims, which cannot be concatenated.
    bbox = tf.to_float(bbox)
    mins, maxs = bbox[..., :2], bbox[..., 2:]
    centre = (maxs + mins) / 2
    dims = maxs - mins
    # The original concatenated an undefined name `coords` here.
    return tf.concat([centre, dims], axis=-1)

def get_bboxes_from_masks(masks):
    """Compute a tight bounding box for every instance mask.

    Args:
        masks: tensor indexed as (height, width, instance); values greater
            than zero count as foreground.

    Returns:
        int32 tensor of shape (n_inst, 4) with rows (y1, x1, y2, x2). y1/x1 are
        the first foreground row/column; y2/x2 are one past the last foreground
        row/column. y2 and x2 are 0 for masks with no foreground.
    """
    foreground = tf.greater(masks, 0)

    dims = tf.to_int64(tf.shape(masks))
    n_rows, n_cols = dims[0], dims[1]

    # Reducing over the height axis marks the columns that contain foreground.
    col_hit = tf.to_float(tf.reduce_any(foreground, axis=0))  # w x n_inst
    # Reducing over the width axis marks the rows that contain foreground.
    row_hit = tf.to_float(tf.reduce_any(foreground, axis=1))  # h x n_inst

    # argmax returns the first maximal index, i.e. the first hit.
    left = tf.argmax(col_hit, axis=0)  # n_inst
    top = tf.argmax(row_hit, axis=0)   # n_inst

    # Searching the reversed vectors finds the last hit; subtracting from the
    # image size turns it into an exclusive end coordinate.
    right = n_cols - tf.argmax(col_hit[::-1], axis=0)   # n_inst
    bottom = n_rows - tf.argmax(row_hit[::-1], axis=0)  # n_inst

    # An empty mask would otherwise get the full image extent; zero it instead.
    has_cols = tf.reduce_any(tf.greater(col_hit, 0), axis=0)
    bottom = tf.where(has_cols, bottom, tf.zeros_like(bottom))
    has_rows = tf.reduce_any(tf.greater(row_hit, 0), axis=0)
    right = tf.where(has_rows, right, tf.zeros_like(right))

    corners = tf.stack([top, left, bottom, right], axis=-1)  # n_inst x 4
    return tf.to_int32(corners)
    
              
def get_masks_and_bboxes(class_mask_rgb,
                         inst_mask_rgb,
                         config):
    """Build the training targets for one image from its class and instance masks.

    Args:
        class_mask_rgb: (h, w, 3) class segmentation image.
        inst_mask_rgb: (h, w, 3) instance segmentation image.
        config: object providing `detection_mode` ('bbox' or 'semi_conv') and
            `max_inst` (maximum number of instances kept per image).

    Returns:
        In 'bbox' mode a tuple (inst_masks, class_labels, bboxes) of shapes
        (h, w, max_inst), (max_inst,) and (max_inst, 4), zero-padded up to
        max_inst instances.
        In 'semi_conv' mode a tuple (class_mask_sparse, inst_mask_sparse, counts):
        two (h, w) int64 label maps in which 0 means background / not selected,
        and a scalar equal to the number of selected instances minus one.

    Raises:
        ValueError: if `config.detection_mode` is not a supported mode.
    """
    mode = config.detection_mode
    # Fail early with a clear message; previously an unknown mode fell through
    # both branches and the function silently returned None.
    if mode not in ('bbox', 'semi_conv'):
        raise ValueError(
            "Unknown detection_mode {!r}; expected 'bbox' or 'semi_conv'".format(mode))

    #One-hot foreground class mask
    class_mask_one_hot = get_class_mask(class_mask_rgb) #(h, w, n_classes)

    #Convert rgb to scalar values in order to identify unique instances
    inst_mask = rgb2val(inst_mask_rgb) # (h, w)
    inst_vals_all = tf.unique(tf.reshape(inst_mask, [-1]))[0] # (n_unique,)
    #top_k sorts in descending order. The first and last entries are dropped on
    #the assumption that the largest value is the border colour and the
    #smallest is the background, and that both are always present.
    inst_vals = tf.nn.top_k(inst_vals_all, k=tf.size(inst_vals_all))[0][1:-1]
    #Randomly keep at most max_inst instances
    inst_vals = tf.random_shuffle(inst_vals)[:config.max_inst] # n_kept = min(n_inst, max_inst)

    #One-hot instance mask which is then used to match each instance with its class
    inst_mask_one_hot = tf.to_float(tf.equal(inst_mask[..., None],
                                inst_vals[None, None]))# (h, w, 1)==(1,1,n_kept) -> (h, w, n_kept)

    if mode == 'bbox':
        #Pad the instance axis up to max_inst so every example has the same shape
        paddings = [(0, config.max_inst - tf.shape(inst_vals)[0])]

        #Class index at each pixel of each instance
        inst_class_mask = tf.argmax(class_mask_one_hot[...,None,:]*inst_mask_one_hot[...,None], axis=-1)
                        # (h, w, 1, n_classes)*(h, w, n_kept, 1) -> (h, w, n_kept, n_classes) -> (h, w, n_kept)

        #Target label per instance: the largest class index found inside the
        #instance, shifted by one so that 0 is left for the padding entries
        class_labels = tf.reduce_max(inst_class_mask, axis=(0, 1)) + 1 # (n_kept,)
        class_labels = tf.pad(class_labels, paddings=paddings)

        bboxes = get_bboxes_from_masks(inst_mask_one_hot)
        bboxes = tf.pad(bboxes, paddings=paddings + [(0,0)])

        return tf.pad(inst_mask_one_hot, paddings=[(0,0), (0,0)] + paddings), class_labels, bboxes

    # mode == 'semi_conv'
    #Prepend an all-zero background channel and take argmax, so a pixel that
    #matches no class / instance gets label 0
    class_mask_sparse = tf.argmax(tf.pad(class_mask_one_hot, [(0, 0), (0, 0), (1, 0)]), axis=-1)
    inst_mask_sparse = tf.argmax(tf.pad(inst_mask_one_hot, [(0, 0), (0, 0), (1, 0)]), axis=-1)

    #When the number of instances is capped, also remove the dropped instances
    #from the class mask
    class_mask_sparse = class_mask_sparse*tf.to_int64(tf.greater(inst_mask_sparse, 0))

    counts = tf.shape(inst_vals)[0] - 1 #in [0, max_inst) when at least one instance is present
    return class_mask_sparse, inst_mask_sparse, counts
    
    
    
        
    

In [None]:
def read_img(img_file, img_type, config):
    """Read an image file into a float tensor, downsampled and cropped/padded.

    Args:
        img_file: path of the image file.
        img_type: 'jpeg' or 'png'; selects the decoder.
        config: object providing `downsample` (stride used to subsample rows
            and columns) and `height` / `width` (target size, or None to skip
            the crop-or-pad step).

    Returns:
        Float tensor with three channels. Its static shape is
        (height, width, 3) when both target dimensions are set, otherwise the
        spatial dimensions are left unknown.
    """
    assert img_type in ['jpeg', 'png']
    downsample = config.downsample
    height = config.height
    width = config.width

    img_string = tf.read_file(img_file)
    # Request three channels explicitly so the channel count is known and
    # matches the static shape set below.
    if img_type == 'jpeg':
        img = tf.image.decode_jpeg(img_string, channels=3)
    else:
        img = tf.image.decode_png(img_string, channels=3)

    # Strided slicing subsamples without interpolation.
    img = img[::downsample, ::downsample]

    if height is not None and width is not None:
        # Centre crop or zero pad so the aspect ratio is not distorted.
        img = tf.image.resize_image_with_crop_or_pad(img, height, width)
        static_shape = [height, width, 3]
    else:
        static_shape = [None, None, 3]

    img = tf.to_float(img)
    # Use the configured size; this was hard-coded to 256 x 256 before.
    img.set_shape(static_shape)
    return img

def read_masks(class_mask_file, inst_mask_file, config):
    """Read the class and instance segmentation masks of one example.

    Both files are decoded with `read_img` using `config.mask_type`.

    Returns:
        Tuple (class_mask_rgb, inst_mask_rgb).
    """
    mask_files = (class_mask_file, inst_mask_file)
    return tuple(read_img(path, config.mask_type, config) for path in mask_files)

def read_data(img_file, class_mask_file, inst_mask_file, config):
    """Load one example: the normalised image followed by its targets.

    Returns:
        Tuple whose first element is the image scaled to [0, 1] and whose
        remaining elements are the outputs of `get_masks_and_bboxes`.
    """
    image = read_img(img_file, config.img_type, config)
    class_rgb, inst_rgb = read_masks(class_mask_file, inst_mask_file, config)

    # Pixel values arrive in [0, 255]; rescale to [0, 1].
    image = image / 255

    #TODO: concatenate img and mask and apply augmentation, then split

    targets = get_masks_and_bboxes(class_rgb, inst_rgb, config)
    return (image,) + targets

In [None]:
class Config(object):
    """Dataset paths and hyper-parameters that are passed around the notebook.

    Args:
        root: directory of the VOC2012 dataset; it is expected to contain
            `ImageSets/Segmentation/{train,val}.txt`, `JPEGImages`,
            `SegmentationClass` and `SegmentationObject`. Defaults to the path
            this notebook was written against.

    Constructing the object reads the train/val split files from disk. The
    second half of the official validation split is moved into the training set.
    """

    def __init__(self, root='../input/voctrainval_11-may-2012/VOCdevkit/VOC2012'):
        # Image geometry used by read_img
        self.height = 256
        self.width = 256
        self.downsample = 2

        # Train / validation split
        self.filespath = root + '/ImageSets/Segmentation'
        self.train = self._read_ids(os.path.join(self.filespath, 'train.txt'))
        self.val = self._read_ids(os.path.join(self.filespath, 'val.txt'))
        mid = len(self.val)//2
        self.train = self.train + self.val[mid:]
        self.val = self.val[:mid]
        assert(len(set(self.train).intersection(set(self.val))) == 0)

        # File name templates and the resulting per-example file lists
        self.imgs_path = root + '/JPEGImages/{}.jpg'
        self.class_segs_path = root + '/SegmentationClass/{}.png'
        self.inst_segs_path = root + '/SegmentationObject/{}.png'
        templates = [self.imgs_path, self.class_segs_path, self.inst_segs_path]
        self.img_files, self.class_mask_files, self.inst_mask_files = \
            [[path.format(f) for f in self.train] for path in templates]
        self.valid_img_files, self.valid_class_mask_files, self.valid_inst_mask_files = \
            [[path.format(f) for f in self.val] for path in templates]

        # Training schedule
        self.n_epochs = 20
        self.batch_size = 4
        self.num_ex = len(self.img_files)
        self.num_valid = len(self.valid_img_files)
        self.batches_per_epoch = math.ceil(self.num_ex/self.batch_size)
        self.num_valid_batches = math.ceil(self.num_valid/self.batch_size)

        self.img_type = 'jpeg'
        self.mask_type = 'png'

        # Class labels and their palette colours
        self.cmap = color_map()
        self.labels = ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
                      'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
                      'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
                      'sofa', 'train', 'tvmonitor', 'void']

        self.upsample = 'simple'  # key into upsample_fn

        self.n_classes = len(self.labels) - 1 #exclude void

        # One colour per non-void label, followed by the last palette entry
        # (used for 'void')
        self.clrs = np.concatenate([self.cmap[:len(self.labels)-1], self.cmap[-1:]])
        # Foreground colours only: drop the first (background) and last (void)
        self.clrs_fgd = self.clrs[1:-1]

        self.max_inst = 20 #56
        self.detection_mode = 'semi_conv'

        # Optimizer class name looked up on tf.train, and its keyword arguments
        self.optimizer = 'AdamOptimizer'
        self.optim_kwargs = {'learning_rate':1e-3}

        self.n_features = 8  # channels of the semi-convolutional embedding

        self.n_units = [1024, 1024]  # hidden layer widths of count_net
        self.max_ramp_iter = 80  # iterations over which the count loss weight ramps up

    @staticmethod
    def _read_ids(path):
        """Return the non-empty, whitespace-stripped lines of a split file.

        The file is closed deterministically, and the last id is kept whether
        or not the file ends with a newline.
        """
        with open(path) as split_file:
            return [line.strip() for line in split_file if line.strip()]

In [None]:
# def kmeans_fn(x, k, tol=1e-6, max_iters=10):
#     # inputs: H*W x F
#     assert (tol is not None) or (max_iters is not None)
#     #vectors = tf.reshape(inputs, [-1, tf.shape(inputs)[-1]]) #[H*W, F]
#     centres = tf.random.shuffle(x)[:k]#[k, F]
#     iters = tf.constant(0)
    
#     def cluster_fn(clusters, mu, mu_prev, iters):
#         sq_diff = (mu[:, None] - x[None])**2 # (k x 1 x F - 1 x H*W x F)**2 -> k x H*W x F
#         dist = tf.reduce_sum(sq_diff, axis=-1) # k x H*W
#         clusters = tf.argmin(dist, axis=0) # H*W 
#         clusters_one_hot = tf.to_float(tf.one_hot(clusters, depth=k, axis=0)) # k x H*W
#         # k x H*W x 1 * 1 x H*W x F = k x H*W x F - > k x F
#         # k x H*W -> k -> k x 1
#         # k x F / k x 1 -> k x F
#         mu_new = tf.reduce_sum(clusters_one_hot[..., None]*x[None], axis=1)/tf.reduce_sum(clusters_one_hot, axis=1)[:, None]
#         return clusters, mu_new, mu, tf.add(iters, 1)
    
#     def cond(cl, m, mp, i):
#         not_converged = tf.greater(tf.reduce_mean((m-mp)**2), tol) if tol is not None else tf.constant(True)
#         not_converged = tf.logical_or(tf.equal(i, 0), not_converged)
#         below_max_iters = tf.less(i, max_iters) if max_iters is not None else tf.constant(True)
#         return tf.logical_and(not_converged, below_max_iters)
    
        
        
#     clusters, mu, mu_prev, n_iters  = tf.while_loop(cond=cond, body=cluster_fn, 
#                              loop_vars=[tf.zeros(tf.shape(x)[0], dtype=tf.int64), 
#                                        centres,
#                                        centres, iters])
    
#     inertia = tf.reduce_sum((x - tf.gather(mu, clusters)) ** 2)
    
#     return clusters, mu, mu_prev, inertia, n_iters
    
# def kmeans_clustering(inputs, k, tol=1e-6, max_iters=10, n_init=10):
#     clusters, mu, mu_prev, inertia, iters = tf.map_fn(elems=tf.range(n_init), 
#                                      fn=lambda _: kmeans_fn(inputs, k, tol, max_iters),
#                                      dtype=(tf.int64, tf.float32, tf.float32, tf.float32, tf.int32))
#     best = tf.argmin(inertia)
#     return clusters[best], mu[best], mu_prev[best], iters[best]

# def kmeans_layer(features, class_mask, k, tol=1e-6, max_iters=10, n_init=10):
#     where = tf.where(tf.greater(class_mask, 0))
#     vectors = tf.gather(features, where)
#     clusters, mu, _, _ = kmeans_clustering(vectors, km tol, max_iters, n_init)
#     instance_map = tf.scatter(clusters, where, tf.shape(class_mask))
#     return instance_map, mu

In [None]:
# #TODO: get rid of pp_in_layer if appropriate

# from tf.keras import models, layers
# # Build U-Net model
def upsample_conv(inputs, **conv_kwargs):
    """Upsample with a learned transposed convolution.

    All keyword arguments are forwarded to `tf.layers.conv2d_transpose`.
    """
    transposed_conv = tf.layers.conv2d_transpose
    return transposed_conv(inputs, **conv_kwargs)

def upsample_simple(inputs, **conv_kwargs):
    """Upsample by nearest-neighbour resizing (no learned parameters).

    Accepts the same keyword arguments as `upsample_conv` so the two are
    interchangeable, but only `strides` is used: the spatial dimensions of
    `inputs` are multiplied by it.
    """
    spatial_dims = tf.shape(inputs)[1:-1]
    target_size = tf.multiply(spatial_dims, conv_kwargs['strides'])
    return tf.image.resize_nearest_neighbor(inputs,
                                            size=target_size,
                                            align_corners=True)

def conv_relu_bn(inputs, training, **conv_kwargs):
    """2-D convolution with ReLU activation followed by batch normalization.

    Args:
        inputs: input feature map.
        training: boolean (or boolean tensor) selecting batch statistics
            (True) or moving averages (False) in the normalization layer.
        **conv_kwargs: forwarded to `tf.layers.conv2d`.

    Returns:
        The normalized activations.
    """
    conv = tf.layers.conv2d(inputs, activation=tf.nn.relu, **conv_kwargs)
    # `training` must be forwarded: it was accepted but ignored before, which
    # left the layer permanently in inference mode.
    return tf.layers.batch_normalization(conv, training=training)

def dense_relu_bn(inputs, training, **dense_kwargs):
    """Fully connected layer with ReLU activation followed by batch normalization.

    Args:
        inputs: input tensor.
        training: boolean (or boolean tensor) selecting batch statistics
            (True) or moving averages (False) in the normalization layer.
        **dense_kwargs: forwarded to `tf.layers.dense`.

    Returns:
        The normalized activations.
    """
    dense = tf.layers.dense(inputs, activation=tf.nn.relu, **dense_kwargs)
    # `training` must be forwarded: it was accepted but ignored before, which
    # left the layer permanently in inference mode.
    return tf.layers.batch_normalization(dense, training=training)

# Lookup table from the `Config.upsample` setting to the upsampling function
# that `model` passes to `unet`.
upsample_fn = {'deconv':upsample_conv, 'simple':upsample_simple}
    
# # input_img = layers.Input(t_x.shape[1:], name = 'RGB_Input')
# # pp_in_layer = input_img
# # if NET_SCALING is not None:
# #     pp_in_layer = layers.AvgPool2D(NET_SCALING)(pp_in_layer)
    
# # pp_in_layer = layers.GaussianNoise(GAUSSIAN_NOISE)(pp_in_layer)
# # pp_in_layer = layers.BatchNormalization()(pp_in_layer)

def unet(inputs, training, n_classes, upsample):
    """Small U-Net: four downsampling stages, a bottleneck, four upsampling stages.

    Args:
        inputs: input image batch.
        training: forwarded to every `conv_relu_bn` block.
        n_classes: number of output channels of the class map.
        upsample: callable taking (inputs, filters=, kernel_size=, strides=,
            padding=) that doubles the spatial resolution.

    Returns:
        Tuple (features, class_map): the output of the last decoder stage and
        the per-pixel class logits produced from it by a 1x1 convolution.
    """
    def conv_pair(tensor, n_filters):
        # Two consecutive 3x3 'same' conv + ReLU + BN blocks
        for _ in range(2):
            tensor = conv_relu_bn(tensor, training, filters=n_filters,
                                  kernel_size=(3, 3), padding='same')
        return tensor

    # Encoder: keep the output of every stage for the skip connections
    skips = []
    net = inputs
    for n_filters in (8, 16, 32, 64):
        net = conv_pair(net, n_filters)
        skips.append(net)
        net = tf.layers.max_pooling2d(net, pool_size=2, strides=2)

    # Bottleneck
    net = conv_pair(net, 128)

    # Decoder: upsample, concatenate the matching encoder output, convolve
    for n_filters, skip in zip((64, 32, 16, 8), reversed(skips)):
        net = upsample(net, filters=n_filters, kernel_size=(2, 2), strides=(2, 2), padding='same')
        net = tf.concat([net, skip], axis=-1)
        net = conv_pair(net, n_filters)

    class_map = tf.layers.conv2d(net, filters=n_classes, kernel_size=(1, 1))

    return net, class_map

def count_net(inputs, training, n_units, n_counts):
    """Predict instance-count logits from a feature map.

    The features are flattened and passed through two hidden
    `dense_relu_bn` layers of widths `n_units[0]` and `n_units[1]`, then
    through a final `dense_relu_bn` layer with `n_counts` outputs.
    """
    #conv = tf.layers.conv2d(inputs, kernel_size=(3, 3), filters=64)
    hidden = tf.layers.flatten(inputs)
    for width in (n_units[0], n_units[1]):
        hidden = dense_relu_bn(hidden, training, units=width)
    return dense_relu_bn(hidden, training, units=n_counts)

def add_position_info_2d(inputs):
    """Add normalised pixel coordinates to the first two channels.

    Args:
        inputs: tensor of shape (n, h, w, f).

    Returns:
        Tensor of the same shape in which channel 0 has row_index / h added and
        channel 1 has column_index / w added; remaining channels are unchanged.
    """
    spatial = tf.shape(inputs)[1:-1]  # [h, w]
    rows = tf.range(spatial[0])
    cols = tf.range(spatial[1])
    grid = tf.meshgrid(rows, cols, indexing='ij')  # [(h, w), (h, w)]
    # Stack into (h, w, 2) and divide each coordinate by its dimension
    coords = tf.to_float(tf.stack(grid, axis=-1)/spatial)
    positioned = inputs[..., :2] + coords[None]
    untouched = inputs[..., 2:]
    return tf.concat([positioned, untouched], axis=-1)

def semi_conv_layer(inputs, training, n_features):
    """Project features to `n_features` channels and inject pixel positions.

    A 1x1 convolution produces the embedding; `add_position_info_2d` then adds
    the pixel coordinates to its first two channels. `training` is unused.
    """
    embedding = tf.layers.conv2d(inputs, filters=n_features, kernel_size=(1, 1))
    return add_position_info_2d(embedding)

def model(inputs, training, config):
    """Assemble the full network.

    Returns:
        Tuple (class_map, semi_conv, counts_pred): per-pixel class logits, the
        position-aware pixel embedding, and the instance-count logits.
    """
    upsample = upsample_fn[config.upsample]
    feature_maps, class_map = unet(inputs, training, config.n_classes, upsample)

    # The embedding head sees both the decoder features and the class logits
    embedding_inputs = tf.concat([feature_maps, class_map], axis=-1)
    embeddings = semi_conv_layer(embedding_inputs, training, config.n_features)

    # The count head works on the decoder features only
    counts_pred = count_net(feature_maps, training, config.n_units, config.max_inst)
    return class_map, embeddings, counts_pred

def semi_conv_loss(y_true, y_pred):
    """
    Implements equation 5 from https://arxiv.org/abs/1807.10712 for a mini-batch of images.

    For every label present in an image, the loss is the mean Euclidean distance
    between the embeddings of the pixels carrying that label and their mean
    embedding. These per-label terms are summed over all labels and images.

    Args:
        y_true (Tensor): sparse integer label tensor of shape (batch_size x height x width),
                         with a separate number for each instance present in the image.
                         Requires that the values are consecutive integers starting from 0.
                         Label 0 is treated like any other label.
        y_pred (Tensor): embedding tensor of shape (batch_size x height x width x channels)

    Returns:
        semi-convolutional loss (scalar)
    """
    # Labels run from 0 to max(y_true) inclusive, so max + 1 one-hot slots are
    # needed. Using max alone sends the highest label out of range, which
    # tf.one_hot encodes as all zeros, silently dropping that instance.
    n_labels = tf.to_int32(tf.reduce_max(y_true)) + 1

    #batch_size x height x width -> batch_size x n_labels x height x width
    y_true_one_hot = tf.one_hot(y_true, depth=n_labels, axis=1)

    #results in tensor of shape batch_size x n_labels x height x width x channels
    y_pred_dense = y_true_one_hot[...,tf.newaxis]*y_pred[:,tf.newaxis]

    #reshape to (batch_size*n_labels) x height x width x channels
    y_pred_dense = tf.reshape(y_pred_dense,
                              tf.concat([[-1], tf.shape(y_pred_dense)[2:]], axis=0))

    #batch_size x n_labels x height x width -> (batch_size*n_labels) x height x width
    y_true_one_hot = tf.reshape(y_true_one_hot,
                                tf.concat([[-1], tf.shape(y_true_one_hot)[2:]], axis=0))

    #find number of pixels in each instance
    #(batch_size*n_labels) x height x width -> (batch_size*n_labels)
    n_inst_pixels = tf.reduce_sum(y_true_one_hot, axis=[1, 2])
    #images with fewer labels than the batch maximum have empty slots; drop them
    #so they do not cause a division by zero below
    has_inst_mask = tf.greater(n_inst_pixels, 0)

    #num_unpadded x height x width x channels
    y_pred_dense = tf.boolean_mask(y_pred_dense, has_inst_mask)

    #num_unpadded x height x width
    y_true_one_hot = tf.boolean_mask(y_true_one_hot, has_inst_mask)

    #num_unpadded
    n_inst_pixels = tf.boolean_mask(n_inst_pixels, has_inst_mask)

    #num_unpadded x height x width x channels -> num_unpadded x 1 x 1 x channels
    embeds_sum = tf.reduce_sum(y_pred_dense, axis=[1,2], keep_dims=True)
    #distance of every pixel to the mean embedding of the instance
    #num_unpadded x height x width x channels -> num_unpadded x height x width
    dist = tf.norm(y_pred_dense - embeds_sum/n_inst_pixels[:, None, None, None], axis=-1)
    #keep only the distances of pixels that belong to the instance
    dist_masked = dist*y_true_one_hot
    #num_unpadded x height x width -> num_unpadded
    dist_avg = tf.reduce_sum(dist_masked, axis=[1,2])/n_inst_pixels

    #dist_avg = tf.Print(dist_avg, data=[dist_avg])

    #num_unpadded -> scalar
    loss = tf.reduce_sum(dist_avg)

    return loss
    
#     #(batch_size*n_inst_max) x height x width x channels -> (batch_size*n_inst_max) x channels
#     embeds_sum = tf.reduce_sum(y_pred_dense, axis=[1,2], keep_dims=True)
#     #(batch_size*n_inst_max) x height x width x channels -> (batch_size*n_inst_max) x height x width
#     #dist = tf.norm(y_pred_dense*n_inst_pixels[:,None,None,None] - embeds_sum, axis=-1)
    
    
#     #keep only the distances for pixels that belong to the instance
#     dist_masked = dist*y_true_one_hot
    
#     #sum the losses for each instance
#     #(batch_size*n_inst_max) x height x width -> (batch_size*n_inst_max)
#     dist_sum = tf.reduce_sum(dist_masked, axis=[1,2])
#     has_inst_mask = tf.greater(n_inst_pixels, 0)
    
#     #select only the elements of dist that correspond to an instance 
#     losses = (tf.boolean_mask(dist_sum, has_inst_mask)/
#                     tf.boolean_mask(n_inst_pixels, has_inst_mask)**2)
    
#     loss = tf.reduce_sum(losses)
    
#     return loss, dist

def get_losses(class_pred, class_true, inst_maps, inst_true, counts_pred, counts_true, itr=None):
    """Compute the three training losses and their sum.

    Args:
        class_pred / class_true: class logits and sparse class labels.
        inst_maps / inst_true: pixel embeddings and sparse instance labels.
        counts_pred / counts_true: count logits and sparse count labels.
        itr: optional iteration number. When given, the count loss is scaled by
            exp(-5 * (1 - T)**2) with T = min(itr / config.max_ramp_iter, 1),
            where `config` is the module-level configuration object.

    Returns:
        Tuple (total_loss, (class_loss, inst_loss, count_loss)).
    """
    cross_entropy = tf.losses.sparse_softmax_cross_entropy

    with tf.variable_scope('class_loss'):
        class_loss = cross_entropy(logits=class_pred, labels=class_true)

    with tf.variable_scope('inst_loss'):
        inst_loss = semi_conv_loss(y_pred=inst_maps, y_true=inst_true)

    with tf.variable_scope('count_loss'):
        count_loss = cross_entropy(logits=counts_pred, labels=counts_true)
        # Make the count logits retrievable from the graph collection
        tf.add_to_collection('counts_pred', counts_pred)

    if itr is not None:
        ramp = tf.minimum(itr/float(config.max_ramp_iter), 1.)
        count_loss = tf.exp(-5*(1-ramp)**2)*count_loss

    losses = (class_loss, inst_loss, count_loss)
    return tf.reduce_sum(losses), losses

def get_train_op(loss, config):
    """Create the op that performs one optimisation step on `loss`.

    The optimizer class is looked up on `tf.train` by `config.optimizer` and
    constructed with `config.optim_kwargs`. The ops in the UPDATE_OPS
    collection are made control dependencies of the training step.
    """
    optimizer_cls = getattr(tf.train, config.optimizer)
    optimizer = optimizer_cls(**config.optim_kwargs)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        return optimizer.minimize(loss)

In [None]:
def create_data_pipeline(config, mode='train'):
    """Build the input pipeline for one split and register its iterator.

    The initializable iterator is stored in the graph collection
    'ITERATOR_<MODE>' (e.g. 'ITERATOR_TRAIN'); nothing is returned.

    Args:
        config: object providing the file lists (`img_files`, `class_mask_files`,
            `inst_mask_files` and their `valid_` counterparts), `batch_size`
            and `n_epochs`.
        mode: 'train' (shuffled, training files) or 'valid' (validation files).

    Raises:
        ValueError: if `mode` is neither 'train' nor 'valid'.
    """
    # Pick the file lists of the requested split. The training lists used to be
    # used for every mode, so "validation" ran on training data.
    if mode == 'train':
        file_lists = [config.img_files, config.class_mask_files, config.inst_mask_files]
    elif mode == 'valid':
        file_lists = [config.valid_img_files,
                      config.valid_class_mask_files,
                      config.valid_inst_mask_files]
    else:
        raise ValueError("Unknown mode {!r}; expected 'train' or 'valid'".format(mode))

    file_datasets = tuple(map(tf.data.Dataset.from_tensor_slices, file_lists))
    dataset = tf.data.Dataset.zip(file_datasets)
    if mode == 'train':
        # Buffer as large as the dataset gives a full shuffle of the file names
        dataset = dataset.shuffle(len(file_lists[0]))

    dataset = dataset.map(lambda x, y, z: read_data(x, y, z, config))
    dataset = dataset.batch(config.batch_size).repeat(config.n_epochs)
    tf.add_to_collection('ITERATOR_{}'.format(mode.upper()),
                         dataset.make_initializable_iterator())

def preproc(config):
    """Create the train and valid pipelines behind one feedable iterator.

    Returns:
        Tuple (next_element, dataset_handle): the tensors of the next batch and
        the string placeholder that selects which registered iterator supplies
        them.
    """
    dataset_handle = tf.placeholder(tf.string, shape=[])

    for mode in ('train', 'valid'):
        create_data_pipeline(config, mode)

    # The training iterator serves as the template for the element structure
    template = tf.get_collection('ITERATOR_TRAIN')[0]
    feedable = tf.data.Iterator.from_string_handle(dataset_handle,
                                                   output_types=template.output_types,
                                                   output_shapes=template.output_shapes)
    next_element = feedable.get_next()
    return next_element, dataset_handle

In [None]:
def init_iterators(sess):
    """Run the initializer of the registered train and valid iterators."""
    collection_keys = ['ITERATOR_{}'.format(mode.upper()) for mode in ('train', 'valid')]
    for key in collection_keys:
        iterator = tf.get_collection(key)[0]
        sess.run(iterator.initializer)
        
def get_itr_handles(sess):
    """Evaluate the string handles of the train and valid iterators.

    Returns:
        List [train_handle, valid_handle].
    """
    def evaluate_handle(mode):
        iterator = tf.get_collection('ITERATOR_{}'.format(mode.upper()))[0]
        return sess.run(iterator.string_handle())

    return [evaluate_handle(mode) for mode in ('train', 'valid')]

In [None]:
def add_metric_avg_op(metric, step, name):
    """Create ops that maintain and reset the running mean of a metric.

    Args:
        metric: scalar tensor to average.
        step: scalar tensor holding the 1-based number of values averaged so
            far, including the current one.
        name: name given to the update op.

    Returns:
        Tuple (avg, avg_reset): `avg` folds the current metric value into the
        running mean and evaluates to the new mean; `avg_reset` sets the mean
        back to zero.
    """
    # Keep a handle on the variable itself. The reset op was previously built on
    # the output of the update op, so running a reset also evaluated `metric`
    # (and whatever input pipeline it depends on).
    avg_var = tf.Variable(initial_value=0., trainable=False)
    # Incremental mean: new_mean = (old_mean * (n - 1) + x) / n
    avg = tf.assign(avg_var, (avg_var*(step-1) + metric)/step, name=name)
    avg_reset = tf.assign(avg_var, 0.)
    return avg, avg_reset

def get_ema_var(graph, name):
    """Fetch the tensor named '<name>/ExponentialMovingAverage:0' from `graph`."""
    tensor_name = '{name}/ExponentialMovingAverage:0'.format(name=name)
    return graph.get_tensor_by_name(tensor_name)

# ---- Build the training graph -------------------------------------------------
# Start from an empty default graph so re-running this cell does not add a
# second copy of every op.
tf.reset_default_graph()
#config = train_utils.Config()
config = Config()
# preproc returns the next-batch tensors and the handle placeholder. The batch
# is the image followed by the three tensors from get_masks_and_bboxes; the
# names used here correspond to the 'semi_conv' detection mode.
(images, class_masks, instance_masks, counts), dataset_handle = preproc(config)

#model = getattr(import_module(config.model_module), config.model_function)
# Boolean fed at run time to switch between training and evaluation behaviour
training = tf.placeholder(shape=[], dtype=tf.bool, name='training')
graph = tf.get_default_graph()
class_pred, inst_maps, counts_pred = model(images, training, config)
#itr = tf.placeholder(dtype=tf.float32, shape=[], name='itr')
# `itr` is not passed, so the count loss is not ramped
total_loss, losses = get_losses(class_pred, class_masks, inst_maps, instance_masks, counts_pred, counts)#, itr)

# `step` is the output of the assign_add op, so each evaluation that depends on
# it increments the counter by one.
# NOTE(review): `step_reset` is built on that output rather than on the variable
# itself, which appears to make every reset run the increment first - confirm.
step = tf.assign_add(tf.Variable(initial_value=0.), 1)
step_reset = tf.assign(step, 0)
# Running means of the total loss and of each loss component, with reset ops
loss_avg, loss_avg_reset = add_metric_avg_op(total_loss, step, 'loss_avg')
cl_loss_avg, cl_loss_avg_reset = add_metric_avg_op(losses[0], step, 'cl_loss_avg')
inst_loss_avg, inst_loss_avg_reset = add_metric_avg_op(losses[1], step, 'inst_loss_avg')
ct_loss_avg, ct_loss_avg_reset = add_metric_avg_op(losses[2], step, 'ct_loss_avg')

train_op = get_train_op(total_loss, config)
# ema = tf.train.ExponentialMovingAverage(config.ema_decay)
# ema_op = ema.apply([loss, dice])
# tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema_op)

# loss_ema, dice_ema = [get_ema_var(graph, metric) for metric in ['loss', 'dice']]

# Note that we display the running average metrics but tensorboard gets the raw ones
# train_summary_op = add_scalar_summaries({'loss': loss, 'dice': dice})
# valid_summary_op = add_scalar_summaries({'loss': loss_avg, 'dice': dice_avg}, postfix='val')
# img_summary_op = add_img_summary(images, masks, probs)
                                  
# update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# with tf.control_dependencies(update_ops):
#     optim = getattr(tf.train, config.optimizer)(config.lr)
#     train_step = optim.minimize(loss)

In [None]:
# Number of batches needed to see every training example once (the last batch
# may be partial), the total number of training iterations, and the number of
# batches in the validation split. Config.batches_per_epoch and
# Config.num_valid_batches hold the same two per-split counts.
steps_per_epoch = np.ceil(len(config.train)/config.batch_size).astype('int')
max_iters = steps_per_epoch*config.n_epochs
num_valid_steps = np.ceil(len(config.val)/config.batch_size).astype('int')

In [None]:
# Train for max_iters iterations; at the end of every epoch reset the running
# averages, run one pass over the validation batches, and reset again.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    init_iterators(sess)

    train_handle, valid_handle = get_itr_handles(sess)

    # Fetches and feeds shared by the training and validation steps
    avg_ops = [loss_avg, [cl_loss_avg, inst_loss_avg, ct_loss_avg]]
    reset_ops = [step_reset, loss_avg_reset, cl_loss_avg_reset, inst_loss_avg_reset, ct_loss_avg_reset]
    train_feed = {dataset_handle: train_handle, training: True}
    valid_feed = {dataset_handle: valid_handle, training: False}

    for it in range(1, max_iters+1):
        #tt = np.float32(it//100 + 1.)
        _, tl, ls = sess.run([train_op] + avg_ops, train_feed)#, itr:tt})
        sys.stdout.write('\riter: {}, total_loss: {:.4f}, cl_loss: {:.4f}, inst_loss: {:.4f}, ct_loss: {:.4f}'.format(it, tl, *ls))

        if (it%steps_per_epoch) == 0:
            print('Validation')
            # Clear the training averages so the validation averages start at zero
            sess.run(reset_ops, train_feed)
            # range's upper bound is exclusive: + 1 is needed to run all
            # num_valid_steps batches (the last one used to be skipped)
            for vt in range(1, num_valid_steps + 1):
                tl, ls = sess.run(avg_ops, valid_feed)
                sys.stdout.write('\rval iter: {}, total_loss: {:.4f}, cl_loss: {:.4f}, inst_loss: {:.4f}, ct_loss: {:.4f}'.format(vt, tl, *ls))
            # Clear the validation averages before training resumes
            sess.run(reset_ops, train_feed)

In [None]:
# Training loop with checkpointing and TensorBoard summaries.
# NOTE(review): this cell uses names that are not defined anywhere in the visible
# notebook: the Config attributes ckpt, logs_path, iters_done, n_iters,
# valid_every, valid_iters, save_path and model_name; the graph ops train_step,
# dice_avg, dice_avg_reset, train_summary_op, valid_summary_op and
# img_summary_op; and tqdm_notebook (never imported). It looks like it was
# carried over from another notebook and will fail as written - confirm before
# running.
saver = tf.train.Saver()
best_dice = 0
with tf.Session() as sess:
    # Resume from a checkpoint when one is configured, otherwise start fresh
    if config.ckpt is not None:
        print('Restoring weights from', config.ckpt)
        saver.restore(sess, config.ckpt)
    else:
        sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    init_iterators(sess)
    train_handle, valid_handle = get_itr_handles(sess)
    
    # Separate summary writers for training, validation and image summaries
    train_writer = tf.summary.FileWriter(
        os.path.join(config.logs_path, 'train'), graph=sess.graph)
    val_writer = tf.summary.FileWriter(
        os.path.join(config.logs_path, 'val'), graph=sess.graph)
    img_writer = tf.summary.FileWriter(
        os.path.join(config.logs_path, 'img'), graph=sess.graph)
    
    # Iteration numbers continue from the number of iterations already done
    tq_train = tqdm_notebook(range(config.iters_done+1, config.iters_done+config.n_iters+1), 
                             initial=config.iters_done+1)
    
    for it in tq_train:
        fetch = [train_step, loss_avg, dice_avg, train_summary_op, img_summary_op]
        fetch_vals = sess.run(fetch, {dataset_handle: train_handle, training:True})
        _, loss_val, dice_val, train_sum_str, img_sum_str = fetch_vals
        
        tq_train.set_postfix(loss=loss_val, dice=dice_val)
        train_writer.add_summary(train_sum_str, it)
        img_writer.add_summary(img_sum_str, it)
        
        
        # Periodic validation pass
        if (it%config.valid_every) == 0:
            sess.run([step_reset, loss_avg_reset, dice_avg_reset], {training:True, dataset_handle: train_handle})
            tq_valid = tqdm_notebook(range(1, config.valid_iters+1), initial=1)
            
            for val_iter in tq_valid:
                fetch_valid = [loss_avg, dice_avg, valid_summary_op, img_summary_op]
                fetch_valid_vals = sess.run(fetch_valid, {dataset_handle: valid_handle, training:False})
                loss_valid_val, dice_valid_val, val_sum_str, img_val_sum_str = fetch_valid_vals 
                # NOTE(review): this writes img_sum_str from the training step, not
                # the img_val_sum_str fetched just above - likely unintended
                img_writer.add_summary(img_sum_str, it)
                
                tq_valid.set_postfix(loss=loss_valid_val, dice=dice_valid_val)
            
            sess.run([step_reset, loss_avg_reset, dice_avg_reset], {training:False, dataset_handle: valid_handle})
            val_writer.add_summary(val_sum_str, it)
        
        
            # Keep the checkpoint with the best validation dice seen so far
            present_dice = dice_valid_val #avg_dict_val['dice']
            if present_dice > best_dice:
                print('Dice increased from {:.4f} to {:.4f}'.format(best_dice, present_dice))
                save_path = saver.save(sess=sess, 
                                       save_path='{}/{}_best'.format(config.save_path, config.model_name))
                print('Saving to {}'.format(save_path))
                best_dice = present_dice
        
    # Always save the final weights as well
    saver.save(sess=sess, 
               save_path='{}/{}_last'.format(config.save_path, config.model_name))
        
            
        

In [None]:
# Re-create the configuration object, rebinding the module-level `config`
# (this re-reads the train/val split files).
config = Config()