In [None]:
# Plain 3D U-Net model.
# Module-level switch read by UNetwork.__init__ when a model is built:
# True  -> three-level network (up-samples straight from level two),
# False -> four-level network with an extra level-three bottleneck.
simpleUNet: bool = False

class UNetwork():
    """3D U-Net (TF1 graph mode) for voxel-wise multi-class segmentation.

    Constructing an instance builds, inside the ``3DuNet`` variable scope:
    placeholders (``training``, ``model_input``, ``model_labels``), the network,
    ``predictions`` (argmax class per voxel), a class-weighted softmax
    cross-entropy ``loss`` and an Adam ``train_op``. Pooling and up-sampling use
    kernel/stride [1, 2, 2], so the first spatial axis is never resampled by
    them. The module-level ``simpleUNet`` flag selects a three-level (True) or
    four-level (False) network at construction time.
    """

    def conv_batch_relu(self, tensor, filters, kernel=(3, 3, 3), stride=(1, 1, 1), is_training=True):
        """Apply conv3d -> batch normalisation -> ReLU.

        Padding is 'same' when ``self.should_pad`` is truthy, otherwise 'valid'.
        ``is_training`` may be a Python bool or a boolean tensor/placeholder.
        """
        padding = 'same' if self.should_pad else 'valid'
        conv = tf.layers.conv3d(tensor, filters, kernel_size=kernel, strides=stride, padding=padding,
                                kernel_initializer=self.base_init, kernel_regularizer=self.reg_init)
        conv = tf.layers.batch_normalization(conv, training=is_training)
        return tf.nn.relu(conv)

    def upconvolve(self, tensor, filters, kernel=2, stride=2, scale=4, activation=None):
        """Learned up-sampling with a bias-free 3D transposed convolution.

        ``scale`` and ``activation`` are accepted for interface compatibility
        but are not used.
        """
        padding = 'same' if self.should_pad else 'valid'
        # use_bias=False kept from the original, which attributes it to a TensorFlow bug.
        return tf.layers.conv3d_transpose(tensor, filters, kernel_size=kernel, strides=stride,
                                          padding=padding, use_bias=False,
                                          kernel_initializer=self.base_init,
                                          kernel_regularizer=self.reg_init)

    def _centre_crop(self, tensor, reference):
        """Centre-crop spatial axes 1-3 of ``tensor`` to the sizes of ``reference``.

        Batch and channel axes of ``tensor`` are kept. ``tensor`` is returned
        unchanged when the spatial sizes already agree (e.g. with 'same'
        padding). Requires static spatial sizes, with ``tensor`` at least as
        large as ``reference`` on every spatial axis.
        """
        t_shape = tensor.get_shape().as_list()
        r_shape = reference.get_shape().as_list()
        if t_shape[1:4] == r_shape[1:4]:
            return tensor
        offsets = [0] + [(t_shape[i] - r_shape[i]) // 2 for i in (1, 2, 3)] + [0]
        size = [-1] + r_shape[1:4] + [t_shape[4]]
        return tf.slice(tensor, offsets, size)

    def centre_crop_and_concat(self, prev_conv, up_conv):
        """Crop the skip tensor ``prev_conv`` to ``up_conv``'s spatial size and
        concatenate both along the channel axis (skip features first).

        The crop is only needed with 'valid' padding.
        """
        return tf.concat((self._centre_crop(prev_conv, up_conv), up_conv), 4)

    def __init__(self, base_filt=8, in_depth=INPUT_DEPTH, out_depth=OUTPUT_DEPTH,
                 in_size=INPUT_SIZE, out_size=OUTPUT_SIZE, num_classes=OUTPUT_CLASSES,
                 learning_rate=0.001, print_shapes=True, drop=0.2, should_pad=False,
                 loss_weights=None):
        """Build the model graph.

        Args:
            base_filt: number of base conv filters; deeper levels use multiples of it.
            in_depth, in_size: input volume depth and (square) in-plane size.
            out_depth, out_size: label volume depth and (square) in-plane size.
            num_classes: number of classes (one-hot width and number of logits).
            learning_rate: Adam learning rate.
            print_shapes: print tensor shapes while building (debug aid).
            drop: dropout rate (proportion of activations dropped).
            should_pad: use 'same' instead of 'valid' padding in convolutions.
            loss_weights: per-class cross-entropy weights of length
                ``num_classes``; defaults to the original 3-class weights.

        Raises:
            ValueError: if ``len(loss_weights) != num_classes``.
        """
        if loss_weights is None:
            # See section 1.4; previously [1, 150, 100, 1.0].
            loss_weights = [0.00439314, 0.68209101, 0.31351585]
        if len(loss_weights) != num_classes:
            raise ValueError('loss_weights has {} entries but num_classes is {}'.format(
                len(loss_weights), num_classes))

        self.base_init = tf.truncated_normal_initializer(stddev=0.1)  # weight initialiser
        # NOTE(review): kernel_regularizer only records the L2 penalty in the
        # REGULARIZATION_LOSSES collection; self.loss below never adds it
        # (e.g. tf.losses.get_regularization_loss()), so it does not affect training.
        self.reg_init = tf.contrib.layers.l2_regularizer(scale=0.1)
        self.should_pad = should_pad
        self.drop = drop

        with tf.variable_scope('3DuNet'):
            self.training = tf.placeholder(tf.bool)
            self.do_print = print_shapes
            # Placeholders fed through feed_dict.
            self.model_input = tf.placeholder(tf.float32, shape=(None, in_depth, in_size, in_size, 1))
            self.model_labels = tf.placeholder(tf.int32, shape=(None, out_depth, out_size, out_size, 1))
            # one_hot appends the class axis after the trailing singleton; squeeze that singleton away.
            labels_one_hot = tf.squeeze(tf.one_hot(self.model_labels, num_classes, axis=-1), axis=-2)

            if self.do_print:
                print('Input features shape', self.model_input.get_shape())
                print('Labels shape', labels_one_hot.get_shape())

            # Level zero
            conv_0_1 = self.conv_batch_relu(self.model_input, base_filt, is_training=self.training)
            conv_0_2 = self.conv_batch_relu(conv_0_1, base_filt*2, is_training=self.training)
            # Level one (pool kernel/stride [1,2,2]; previously [2,2,2])
            max_1_1 = tf.layers.max_pooling3d(conv_0_2, [1, 2, 2], [1, 2, 2])
            conv_1_1 = self.conv_batch_relu(max_1_1, base_filt*2, is_training=self.training)
            conv_1_2 = self.conv_batch_relu(conv_1_1, base_filt*4, is_training=self.training)
            conv_1_2 = tf.layers.dropout(conv_1_2, rate=self.drop, training=self.training)
            # Level two
            max_2_1 = tf.layers.max_pooling3d(conv_1_2, [1, 2, 2], [1, 2, 2])
            conv_2_1 = self.conv_batch_relu(max_2_1, base_filt*4, is_training=self.training)
            conv_2_2 = self.conv_batch_relu(conv_2_1, base_filt*8, is_training=self.training)
            conv_2_2 = tf.layers.dropout(conv_2_2, rate=self.drop, training=self.training)

            if simpleUNet:
                # Three-level network: up-sample straight from level two.
                up_conv_2_1 = self.upconvolve(conv_2_2, base_filt*8, kernel=2, stride=[1, 2, 2])
            else:
                # Level three (bottleneck)
                max_3_1 = tf.layers.max_pooling3d(conv_2_2, [1, 2, 2], [1, 2, 2])
                conv_3_1 = self.conv_batch_relu(max_3_1, base_filt*8, is_training=self.training)
                conv_3_2 = self.conv_batch_relu(conv_3_1, base_filt*16, is_training=self.training)
                conv_3_2 = tf.layers.dropout(conv_3_2, rate=self.drop, training=self.training)
                # Level two (decoder)
                up_conv_3_2 = self.upconvolve(conv_3_2, base_filt*16, kernel=2, stride=[1, 2, 2])
                concat_2_1 = self.centre_crop_and_concat(conv_2_2, up_conv_3_2)
                conv_2_3 = self.conv_batch_relu(concat_2_1, base_filt*8, is_training=self.training)
                conv_2_4 = self.conv_batch_relu(conv_2_3, base_filt*8, is_training=self.training)
                conv_2_4 = tf.layers.dropout(conv_2_4, rate=self.drop, training=self.training)
                up_conv_2_1 = self.upconvolve(conv_2_4, base_filt*8, kernel=2, stride=[1, 2, 2])

            # Level one (decoder)
            concat_1_1 = self.centre_crop_and_concat(conv_1_2, up_conv_2_1)
            conv_1_3 = self.conv_batch_relu(concat_1_1, base_filt*4, is_training=self.training)
            conv_1_4 = self.conv_batch_relu(conv_1_3, base_filt*4, is_training=self.training)
            conv_1_4 = tf.layers.dropout(conv_1_4, rate=self.drop, training=self.training)
            # Level zero (decoder)
            up_conv_1_0 = self.upconvolve(conv_1_4, base_filt*4, kernel=2, stride=[1, 2, 2])
            concat_0_1 = self.centre_crop_and_concat(conv_0_2, up_conv_1_0)
            conv_0_3 = self.conv_batch_relu(concat_0_1, base_filt*2, is_training=self.training)
            conv_0_4 = self.conv_batch_relu(conv_0_3, base_filt*2, is_training=self.training)
            conv_0_4 = tf.layers.dropout(conv_0_4, rate=self.drop, training=self.training)
            # 1x1x1 conv to per-class logits; must match num_classes used for the labels.
            conv_out = tf.layers.conv3d(conv_0_4, num_classes, [1, 1, 1], [1, 1, 1], padding='same')
            self.predictions = tf.expand_dims(tf.argmax(conv_out, axis=-1), -1)

            if self.do_print:
                print('Model Convolution output shape', conv_out.get_shape())
                print('Model Argmax output shape', self.predictions.get_shape())

            # Weighted softmax cross entropy, adapted from
            # https://stackoverflow.com/questions/44560549/unbalanced-data-and-weighted-cross-entropy
            do_weight = True
            ce_loss = tf.nn.softmax_cross_entropy_with_logits_v2(logits=conv_out, labels=labels_one_hot)
            if do_weight:
                # Scale each voxel's loss by the weight of its true class.
                class_weights = tf.reshape(tf.constant(loss_weights, dtype=tf.float32),
                                           [1, 1, 1, 1, num_classes])
                voxel_weights = tf.reduce_sum(class_weights * labels_one_hot, axis=-1)
                ce_loss = ce_loss * voxel_weights
            self.loss = tf.reduce_mean(ce_loss)

            self.trainer = tf.train.AdamOptimizer(learning_rate=learning_rate)
            # Run the batch-norm moving-statistics updates with every training step.
            self.extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(self.extra_update_ops):
                self.train_op = self.trainer.minimize(self.loss)

In [None]:
# Model with attention gates and dense blocks.
# Module-level switch read by UNetwork.__init__ when a model is built:
# True  -> three-level network, False -> four-level network with a bottleneck.

simpleUNet: bool = False

class UNetwork():
    """3D U-Net variant with dense blocks and attention gates (TF1 graph mode).

    Every encoder/decoder level adds a ``dense_conv_1`` block. Skip connections
    go through the attention gate in ``centre_crop_and_concat``. The four-level
    variant also gates the bottleneck with a self-attention map before
    up-sampling. Constructing an instance builds, inside the ``3DuNet`` variable
    scope: placeholders (``training``, ``model_input``, ``model_labels``), the
    network, ``predictions``, a class-weighted cross-entropy ``loss`` and an Adam
    ``train_op``. The module-level ``simpleUNet`` flag is read at construction.
    """

    def conv_batch_relu(self, tensor, filters, kernel=(3, 3, 3), stride=(1, 1, 1), is_training=True):
        """Apply conv3d -> batch normalisation -> leaky ReLU.

        Padding is 'same' when ``self.should_pad`` is truthy, otherwise 'valid'.
        """
        padding = 'same' if self.should_pad else 'valid'
        conv = tf.layers.conv3d(tensor, filters, kernel_size=kernel, strides=stride, padding=padding,
                                kernel_initializer=self.base_init, kernel_regularizer=self.reg_init)
        conv = tf.layers.batch_normalization(conv, training=is_training)
        return tf.nn.leaky_relu(conv)

    def _centre_crop(self, tensor, reference):
        """Centre-crop spatial axes 1-3 of ``tensor`` to the sizes of ``reference``.

        Batch and channel axes of ``tensor`` are kept. ``tensor`` is returned
        unchanged when the spatial sizes already agree (e.g. with 'same'
        padding). Requires static spatial sizes, with ``tensor`` at least as
        large as ``reference`` on every spatial axis.
        """
        t_shape = tensor.get_shape().as_list()
        r_shape = reference.get_shape().as_list()
        if t_shape[1:4] == r_shape[1:4]:
            return tensor
        offsets = [0] + [(t_shape[i] - r_shape[i]) // 2 for i in (1, 2, 3)] + [0]
        size = [-1] + r_shape[1:4] + [t_shape[4]]
        return tf.slice(tensor, offsets, size)

    def dense_conv_1(self, tensor, filters, kernel=(3, 3, 3), stride=(1, 1, 1), is_training=True):
        """Densely connected block of four conv pairs, each followed by dropout.

        With pair outputs c4..c7: c5 <- c4, c6 <- concat[c4, c5],
        c7 <- concat[c5, c4, c6]; c7 (``filters`` channels) is returned.
        ``kernel``, ``stride`` and ``is_training`` are forwarded to every conv
        (the original ignored them and always used the defaults and
        ``self.training``). Earlier outputs are centre-cropped before each
        concat, so the block also builds with 'valid' padding.
        """
        def conv_pair(x):
            x = self.conv_batch_relu(x, filters, kernel, stride, is_training=is_training)
            x = self.conv_batch_relu(x, filters, kernel, stride, is_training=is_training)
            return tf.layers.dropout(x, rate=self.drop, training=is_training)

        conv_4 = conv_pair(tensor)
        conv_5 = conv_pair(conv_4)
        conv_6 = conv_pair(tf.concat((self._centre_crop(conv_4, conv_5), conv_5), 4))
        # Same channel order as before: [conv_5, conv_4, conv_6].
        merged = tf.concat((self._centre_crop(conv_5, conv_6),
                            self._centre_crop(conv_4, conv_6),
                            conv_6), 4)
        return conv_pair(merged)

    def upconvolve(self, tensor, filters, kernel=2, stride=2, scale=4, activation=None):
        """Learned up-sampling with a bias-free 3D transposed convolution.

        ``scale`` and ``activation`` are accepted for interface compatibility
        but are not used.
        """
        padding = 'same' if self.should_pad else 'valid'
        # use_bias=False kept from the original, which attributes it to a TensorFlow bug.
        return tf.layers.conv3d_transpose(tensor, filters, kernel_size=kernel, strides=stride,
                                          padding=padding, use_bias=False,
                                          kernel_initializer=self.base_init,
                                          kernel_regularizer=self.reg_init)

    def centre_crop_and_concat(self, prev_conv, up_conv, filters):
        """Attention-gate ``up_conv`` using the skip connection ``prev_conv``.

        ``prev_conv`` is centre-cropped to ``up_conv``'s spatial size. Both
        inputs are projected by 1x1x1 convs to ``filters`` channels, summed,
        and passed through leaky ReLU. A further 1x1x1 conv + sigmoid then
        yields coefficients in (0, 1), which multiply ``up_conv`` element-wise.
        ``filters`` must equal ``up_conv``'s channel count. Despite the name,
        nothing is concatenated: the skip features only enter via the
        attention map.
        """
        prev_conv_crop = self._centre_crop(prev_conv, up_conv)
        skip_proj = tf.layers.conv3d(prev_conv_crop, filters, [1, 1, 1], [1, 1, 1], padding='same')
        up_proj = tf.layers.conv3d(up_conv, filters, [1, 1, 1], [1, 1, 1], padding='same')
        hidden = tf.nn.leaky_relu(tf.add(skip_proj, up_proj))
        attention = tf.nn.sigmoid(tf.layers.conv3d(hidden, filters, [1, 1, 1], [1, 1, 1], padding='same'))
        # Gate by multiplication, as the bottleneck gate in __init__ does; the
        # previous tf.maximum merely floored up_conv at the attention value.
        return tf.multiply(attention, up_conv)

    def __init__(self, base_filt=16, in_depth=INPUT_DEPTH, out_depth=OUTPUT_DEPTH,
                 in_size=INPUT_SIZE, out_size=OUTPUT_SIZE, num_classes=OUTPUT_CLASSES,
                 learning_rate=0.001, print_shapes=True, drop=0.2, should_pad=False,
                 loss_weights=None):
        """Build the model graph.

        Args:
            base_filt: number of base conv filters; deeper levels use multiples of it.
            in_depth, in_size: input volume depth and (square) in-plane size.
            out_depth, out_size: label volume depth and (square) in-plane size.
            num_classes: number of classes (one-hot width and number of logits).
            learning_rate: Adam learning rate.
            print_shapes: print tensor shapes while building (debug aid).
            drop: dropout rate (proportion of activations dropped).
            should_pad: use 'same' instead of 'valid' padding in convolutions.
            loss_weights: per-class cross-entropy weights of length
                ``num_classes``; defaults to the original 3-class weights.

        Raises:
            ValueError: if ``len(loss_weights) != num_classes``.
        """
        if loss_weights is None:
            # See section 1.4; previously [1, 150, 100, 1.0].
            loss_weights = [0.00439314, 0.68209101, 0.31351585]
        if len(loss_weights) != num_classes:
            raise ValueError('loss_weights has {} entries but num_classes is {}'.format(
                len(loss_weights), num_classes))

        self.base_init = tf.truncated_normal_initializer(stddev=0.1)  # weight initialiser
        # NOTE(review): kernel_regularizer only records the L2 penalty in the
        # REGULARIZATION_LOSSES collection; self.loss below never adds it
        # (e.g. tf.losses.get_regularization_loss()), so it does not affect training.
        self.reg_init = tf.contrib.layers.l2_regularizer(scale=0.1)
        self.should_pad = should_pad
        self.drop = drop

        with tf.variable_scope('3DuNet'):
            self.training = tf.placeholder(tf.bool)
            self.do_print = print_shapes
            # Placeholders fed through feed_dict.
            self.model_input = tf.placeholder(tf.float32, shape=(None, in_depth, in_size, in_size, 1))
            self.model_labels = tf.placeholder(tf.int32, shape=(None, out_depth, out_size, out_size, 1))
            # one_hot appends the class axis after the trailing singleton; squeeze that singleton away.
            labels_one_hot = tf.squeeze(tf.one_hot(self.model_labels, num_classes, axis=-1), axis=-2)

            if self.do_print:
                print('Input features shape', self.model_input.get_shape())
                print('Labels shape', labels_one_hot.get_shape())

            # Level zero
            conv_0_1 = self.conv_batch_relu(self.model_input, base_filt, is_training=self.training)
            conv_0_2 = self.conv_batch_relu(conv_0_1, base_filt*2, is_training=self.training)
            dense_1 = self.dense_conv_1(conv_0_2, base_filt*2, is_training=self.training)
            # Level one (pool kernel/stride [1,2,2]; previously [2,2,2])
            max_1_1 = tf.layers.max_pooling3d(dense_1, [1, 2, 2], [1, 2, 2])
            conv_1_1 = self.conv_batch_relu(max_1_1, base_filt*2, is_training=self.training)
            conv_1_2 = self.conv_batch_relu(conv_1_1, base_filt*4, is_training=self.training)
            dense_2 = self.dense_conv_1(conv_1_2, base_filt*4, is_training=self.training)
            conv_1_2 = tf.layers.dropout(dense_2, rate=self.drop, training=self.training)
            # Level two
            max_2_1 = tf.layers.max_pooling3d(conv_1_2, [1, 2, 2], [1, 2, 2])
            conv_2_1 = self.conv_batch_relu(max_2_1, base_filt*4, is_training=self.training)
            conv_2_2 = self.conv_batch_relu(conv_2_1, base_filt*8, is_training=self.training)
            dense_3 = self.dense_conv_1(conv_2_2, base_filt*8, is_training=self.training)
            conv_2_2 = tf.layers.dropout(dense_3, rate=self.drop, training=self.training)

            if simpleUNet:
                # Three-level network: up-sample straight from level two.
                up_conv_2_1 = self.upconvolve(conv_2_2, base_filt*8, kernel=2, stride=[1, 2, 2])
            else:
                # Level three (bottleneck)
                max_3_1 = tf.layers.max_pooling3d(conv_2_2, [1, 2, 2], [1, 2, 2])
                conv_3_1 = self.conv_batch_relu(max_3_1, base_filt*8, is_training=self.training)
                conv_3_2 = self.conv_batch_relu(conv_3_1, base_filt*16, is_training=self.training)
                dense_4 = self.dense_conv_1(conv_3_2, base_filt*16, is_training=self.training)
                conv_3_2 = tf.layers.dropout(dense_4, rate=self.drop, training=self.training)
                # Self-attention gate on the bottleneck features.
                conv_atten = tf.layers.conv3d(conv_3_2, base_filt*16, [1, 1, 1], [1, 1, 1], padding='same')
                conv_atten = tf.nn.leaky_relu(conv_atten)
                conv_atten = tf.layers.conv3d(conv_atten, base_filt*16, [1, 1, 1], [1, 1, 1], padding='same')
                gated_3_2 = tf.multiply(tf.nn.sigmoid(conv_atten), conv_3_2)
                # Level two (decoder): up-sample the gated bottleneck (the gate
                # was previously computed but never used).
                up_conv_3_2 = self.upconvolve(gated_3_2, base_filt*16, kernel=2, stride=[1, 2, 2])
                concat_2_1 = self.centre_crop_and_concat(conv_2_2, up_conv_3_2, base_filt*16)
                conv_2_3 = self.conv_batch_relu(concat_2_1, base_filt*8, is_training=self.training)
                conv_2_4 = self.conv_batch_relu(conv_2_3, base_filt*8, is_training=self.training)
                dense_5 = self.dense_conv_1(conv_2_4, base_filt*8, is_training=self.training)
                conv_2_4 = tf.layers.dropout(dense_5, rate=self.drop, training=self.training)
                up_conv_2_1 = self.upconvolve(conv_2_4, base_filt*8, kernel=2, stride=[1, 2, 2])

            # Level one (decoder)
            concat_1_1 = self.centre_crop_and_concat(conv_1_2, up_conv_2_1, base_filt*8)
            conv_1_3 = self.conv_batch_relu(concat_1_1, base_filt*4, is_training=self.training)
            conv_1_4 = self.conv_batch_relu(conv_1_3, base_filt*4, is_training=self.training)
            dense_6 = self.dense_conv_1(conv_1_4, base_filt*4, is_training=self.training)
            conv_1_4 = tf.layers.dropout(dense_6, rate=self.drop, training=self.training)
            # Level zero (decoder)
            up_conv_1_0 = self.upconvolve(conv_1_4, base_filt*4, kernel=2, stride=[1, 2, 2])
            # NOTE(review): this skip uses conv_0_2 (before dense_1), while levels one
            # and two use their post-dense tensors -- confirm this is intended.
            concat_0_1 = self.centre_crop_and_concat(conv_0_2, up_conv_1_0, base_filt*4)
            conv_0_3 = self.conv_batch_relu(concat_0_1, base_filt*2, is_training=self.training)
            conv_0_4 = self.conv_batch_relu(conv_0_3, base_filt*2, is_training=self.training)
            dense_7 = self.dense_conv_1(conv_0_4, base_filt*2, is_training=self.training)
            conv_0_4 = tf.layers.dropout(dense_7, rate=self.drop, training=self.training)
            # 1x1x1 conv to per-class logits; must match num_classes used for the labels.
            conv_out = tf.layers.conv3d(conv_0_4, num_classes, [1, 1, 1], [1, 1, 1], padding='same')
            self.predictions = tf.expand_dims(tf.argmax(conv_out, axis=-1), -1)

            if self.do_print:
                print('Model Convolution output shape', conv_out.get_shape())
                print('Model Argmax output shape', self.predictions.get_shape())

            # Weighted softmax cross entropy, adapted from
            # https://stackoverflow.com/questions/44560549/unbalanced-data-and-weighted-cross-entropy
            do_weight = True
            ce_loss = tf.nn.softmax_cross_entropy_with_logits_v2(logits=conv_out, labels=labels_one_hot)
            if do_weight:
                # Scale each voxel's loss by the weight of its true class.
                class_weights = tf.reshape(tf.constant(loss_weights, dtype=tf.float32),
                                           [1, 1, 1, 1, num_classes])
                voxel_weights = tf.reduce_sum(class_weights * labels_one_hot, axis=-1)
                ce_loss = ce_loss * voxel_weights
            self.loss = tf.reduce_mean(ce_loss)

            self.trainer = tf.train.AdamOptimizer(learning_rate=learning_rate)
            # Run the batch-norm moving-statistics updates with every training step.
            self.extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(self.extra_update_ops):
                self.train_op = self.trainer.minimize(self.loss)

In [None]:
# Model with attention gates on the skip connections.
# Module-level switch read by UNetwork.__init__ when a model is built:
# True  -> three-level network, False -> four-level network with a bottleneck.

simpleUNet: bool = False

class UNetwork():
    """3D U-Net with attention-gated skip connections (TF1 graph mode).

    Constructing an instance builds, inside the ``3DuNet`` variable scope:
    placeholders (``training``, ``model_input``, ``model_labels``), the network,
    ``predictions`` (argmax class per voxel), a class-weighted softmax
    cross-entropy ``loss`` and an Adam ``train_op``. The module-level
    ``simpleUNet`` flag selects a three-level (True) or four-level (False)
    network at construction time.
    """

    def conv_batch_relu(self, tensor, filters, kernel=(3, 3, 3), stride=(1, 1, 1), is_training=True):
        """Apply conv3d -> batch normalisation -> leaky ReLU.

        Padding is 'same' when ``self.should_pad`` is truthy, otherwise 'valid'.
        """
        padding = 'same' if self.should_pad else 'valid'
        conv = tf.layers.conv3d(tensor, filters, kernel_size=kernel, strides=stride, padding=padding,
                                kernel_initializer=self.base_init, kernel_regularizer=self.reg_init)
        conv = tf.layers.batch_normalization(conv, training=is_training)
        return tf.nn.leaky_relu(conv)

    def upconvolve(self, tensor, filters, kernel=2, stride=2, scale=4, activation=None):
        """Learned up-sampling with a bias-free 3D transposed convolution.

        ``scale`` and ``activation`` are accepted for interface compatibility
        but are not used.
        """
        padding = 'same' if self.should_pad else 'valid'
        # use_bias=False kept from the original, which attributes it to a TensorFlow bug.
        return tf.layers.conv3d_transpose(tensor, filters, kernel_size=kernel, strides=stride,
                                          padding=padding, use_bias=False,
                                          kernel_initializer=self.base_init,
                                          kernel_regularizer=self.reg_init)

    def _centre_crop(self, tensor, reference):
        """Centre-crop spatial axes 1-3 of ``tensor`` to the sizes of ``reference``.

        Batch and channel axes of ``tensor`` are kept. ``tensor`` is returned
        unchanged when the spatial sizes already agree (e.g. with 'same'
        padding). Requires static spatial sizes, with ``tensor`` at least as
        large as ``reference`` on every spatial axis.
        """
        t_shape = tensor.get_shape().as_list()
        r_shape = reference.get_shape().as_list()
        if t_shape[1:4] == r_shape[1:4]:
            return tensor
        offsets = [0] + [(t_shape[i] - r_shape[i]) // 2 for i in (1, 2, 3)] + [0]
        size = [-1] + r_shape[1:4] + [t_shape[4]]
        return tf.slice(tensor, offsets, size)

    def centre_crop_and_concat(self, prev_conv, up_conv, filters):
        """Attention-gate ``up_conv`` using the skip connection ``prev_conv``.

        ``prev_conv`` is centre-cropped to ``up_conv``'s spatial size. Both
        inputs are projected by 1x1x1 convs to ``filters`` channels, summed,
        and passed through leaky ReLU. A further 1x1x1 conv + sigmoid then
        yields coefficients in (0, 1), which multiply ``up_conv`` element-wise.
        ``filters`` must equal ``up_conv``'s channel count. Despite the name,
        nothing is concatenated.
        """
        prev_conv_crop = self._centre_crop(prev_conv, up_conv)
        skip_proj = tf.layers.conv3d(prev_conv_crop, filters, [1, 1, 1], [1, 1, 1], padding='same')
        up_proj = tf.layers.conv3d(up_conv, filters, [1, 1, 1], [1, 1, 1], padding='same')
        hidden = tf.nn.leaky_relu(tf.add(skip_proj, up_proj))
        attention = tf.nn.sigmoid(tf.layers.conv3d(hidden, filters, [1, 1, 1], [1, 1, 1], padding='same'))
        # Element-wise gating. The previous tf.matmul treated the last two axes of
        # these 5-D tensors as matrices, which fails unless width == channels.
        return tf.multiply(attention, up_conv)

    def __init__(self, base_filt=16, in_depth=INPUT_DEPTH, out_depth=OUTPUT_DEPTH,
                 in_size=INPUT_SIZE, out_size=OUTPUT_SIZE, num_classes=OUTPUT_CLASSES,
                 learning_rate=0.001, print_shapes=True, drop=0.2, should_pad=False,
                 loss_weights=None):
        """Build the model graph.

        Args:
            base_filt: number of base conv filters; deeper levels use multiples of it.
            in_depth, in_size: input volume depth and (square) in-plane size.
            out_depth, out_size: label volume depth and (square) in-plane size.
            num_classes: number of classes (one-hot width and number of logits).
            learning_rate: Adam learning rate.
            print_shapes: print tensor shapes while building (debug aid).
            drop: dropout rate (proportion of activations dropped).
            should_pad: use 'same' instead of 'valid' padding in convolutions.
            loss_weights: per-class cross-entropy weights of length
                ``num_classes``; defaults to the original 3-class weights.

        Raises:
            ValueError: if ``len(loss_weights) != num_classes``.
        """
        if loss_weights is None:
            # See section 1.4; previously [1, 150, 100, 1.0].
            loss_weights = [0.00439314, 0.68209101, 0.31351585]
        if len(loss_weights) != num_classes:
            raise ValueError('loss_weights has {} entries but num_classes is {}'.format(
                len(loss_weights), num_classes))

        self.base_init = tf.truncated_normal_initializer(stddev=0.1)  # weight initialiser
        # NOTE(review): kernel_regularizer only records the L2 penalty in the
        # REGULARIZATION_LOSSES collection; self.loss below never adds it
        # (e.g. tf.losses.get_regularization_loss()), so it does not affect training.
        self.reg_init = tf.contrib.layers.l2_regularizer(scale=0.1)
        self.should_pad = should_pad
        self.drop = drop

        with tf.variable_scope('3DuNet'):
            self.training = tf.placeholder(tf.bool)
            self.do_print = print_shapes
            # Placeholders fed through feed_dict.
            self.model_input = tf.placeholder(tf.float32, shape=(None, in_depth, in_size, in_size, 1))
            self.model_labels = tf.placeholder(tf.int32, shape=(None, out_depth, out_size, out_size, 1))
            # one_hot appends the class axis after the trailing singleton; squeeze that singleton away.
            labels_one_hot = tf.squeeze(tf.one_hot(self.model_labels, num_classes, axis=-1), axis=-2)

            if self.do_print:
                print('Input features shape', self.model_input.get_shape())
                print('Labels shape', labels_one_hot.get_shape())

            # Level zero
            conv_0_1 = self.conv_batch_relu(self.model_input, base_filt, is_training=self.training)
            conv_0_2 = self.conv_batch_relu(conv_0_1, base_filt*2, is_training=self.training)
            # Level one (pool kernel/stride [1,2,2]; previously [2,2,2])
            max_1_1 = tf.layers.max_pooling3d(conv_0_2, [1, 2, 2], [1, 2, 2])
            conv_1_1 = self.conv_batch_relu(max_1_1, base_filt*2, is_training=self.training)
            conv_1_2 = self.conv_batch_relu(conv_1_1, base_filt*4, is_training=self.training)
            conv_1_2 = tf.layers.dropout(conv_1_2, rate=self.drop, training=self.training)
            # Level two
            max_2_1 = tf.layers.max_pooling3d(conv_1_2, [1, 2, 2], [1, 2, 2])
            conv_2_1 = self.conv_batch_relu(max_2_1, base_filt*4, is_training=self.training)
            conv_2_2 = self.conv_batch_relu(conv_2_1, base_filt*8, is_training=self.training)
            conv_2_2 = tf.layers.dropout(conv_2_2, rate=self.drop, training=self.training)

            if simpleUNet:
                # Three-level network: up-sample straight from level two.
                up_conv_2_1 = self.upconvolve(conv_2_2, base_filt*8, kernel=2, stride=[1, 2, 2])
            else:
                # Level three (bottleneck)
                max_3_1 = tf.layers.max_pooling3d(conv_2_2, [1, 2, 2], [1, 2, 2])
                conv_3_1 = self.conv_batch_relu(max_3_1, base_filt*8, is_training=self.training)
                conv_3_2 = self.conv_batch_relu(conv_3_1, base_filt*16, is_training=self.training)
                conv_3_2 = tf.layers.dropout(conv_3_2, rate=self.drop, training=self.training)
                # Level two (decoder)
                up_conv_3_2 = self.upconvolve(conv_3_2, base_filt*16, kernel=2, stride=[1, 2, 2])
                concat_2_1 = self.centre_crop_and_concat(conv_2_2, up_conv_3_2, base_filt*16)
                conv_2_3 = self.conv_batch_relu(concat_2_1, base_filt*8, is_training=self.training)
                conv_2_4 = self.conv_batch_relu(conv_2_3, base_filt*8, is_training=self.training)
                conv_2_4 = tf.layers.dropout(conv_2_4, rate=self.drop, training=self.training)
                up_conv_2_1 = self.upconvolve(conv_2_4, base_filt*8, kernel=2, stride=[1, 2, 2])

            # Level one (decoder)
            concat_1_1 = self.centre_crop_and_concat(conv_1_2, up_conv_2_1, base_filt*8)
            conv_1_3 = self.conv_batch_relu(concat_1_1, base_filt*4, is_training=self.training)
            conv_1_4 = self.conv_batch_relu(conv_1_3, base_filt*4, is_training=self.training)
            conv_1_4 = tf.layers.dropout(conv_1_4, rate=self.drop, training=self.training)
            # Level zero (decoder)
            up_conv_1_0 = self.upconvolve(conv_1_4, base_filt*4, kernel=2, stride=[1, 2, 2])
            concat_0_1 = self.centre_crop_and_concat(conv_0_2, up_conv_1_0, base_filt*4)
            conv_0_3 = self.conv_batch_relu(concat_0_1, base_filt*2, is_training=self.training)
            conv_0_4 = self.conv_batch_relu(conv_0_3, base_filt*2, is_training=self.training)
            conv_0_4 = tf.layers.dropout(conv_0_4, rate=self.drop, training=self.training)
            # 1x1x1 conv to per-class logits; must match num_classes used for the labels.
            conv_out = tf.layers.conv3d(conv_0_4, num_classes, [1, 1, 1], [1, 1, 1], padding='same')
            self.predictions = tf.expand_dims(tf.argmax(conv_out, axis=-1), -1)

            if self.do_print:
                print('Model Convolution output shape', conv_out.get_shape())
                print('Model Argmax output shape', self.predictions.get_shape())

            # Weighted softmax cross entropy, adapted from
            # https://stackoverflow.com/questions/44560549/unbalanced-data-and-weighted-cross-entropy
            do_weight = True
            ce_loss = tf.nn.softmax_cross_entropy_with_logits_v2(logits=conv_out, labels=labels_one_hot)
            if do_weight:
                # Scale each voxel's loss by the weight of its true class.
                class_weights = tf.reshape(tf.constant(loss_weights, dtype=tf.float32),
                                           [1, 1, 1, 1, num_classes])
                voxel_weights = tf.reduce_sum(class_weights * labels_one_hot, axis=-1)
                ce_loss = ce_loss * voxel_weights
            self.loss = tf.reduce_mean(ce_loss)

            self.trainer = tf.train.AdamOptimizer(learning_rate=learning_rate)
            # Run the batch-norm moving-statistics updates with every training step.
            self.extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(self.extra_update_ops):
                self.train_op = self.trainer.minimize(self.loss)

In [None]:
# 3D U-Net whose encoder follows a ResNet-34 style design (per the author).
# Module-level switch read when the model below is built:
# True  -> three-level network, False -> four-level network.

simpleUNet: bool = False

class UNetwork():

    def conv_batch_relu(self, tensor, filters, kernel = [3,3,3], stride = [1,1,1], is_training = True):
        """Apply a 3D convolution, then batch normalisation, then ReLU.

        Padding is 'same' when self.should_pad is truthy, otherwise 'valid'.
        `is_training` is forwarded to the batch-norm layer; callers pass either a
        Python bool (the default) or the boolean training placeholder.
        """
        mode = 'same' if self.should_pad else 'valid'
        features = tf.layers.conv3d(
            tensor, filters,
            kernel_size = kernel, strides = stride, padding = mode,
            kernel_initializer = self.base_init, kernel_regularizer = self.reg_init,
        )
        normalised = tf.layers.batch_normalization(features, training = is_training)
        return tf.nn.relu(normalised)
    def conv_pool(self, tensor, filters, kernel = [3,3,3], stride = [1,1,1], is_training = True):
        """Projection shortcut for a downsampling step: plain 3D conv followed by max-pooling.

        The convolution has no batch norm or activation (both were deliberately left
        disabled). Max-pooling uses pool size and stride [1,2,2], so axis 1 is kept
        and axes 2 and 3 are downsampled by 2.

        Args:
            tensor: 5-D input tensor.
            filters: number of output channels of the convolution.
            kernel, stride: passed to tf.layers.conv3d.
            is_training: unused; kept so the signature matches the other conv helpers.
        Returns:
            The pooled convolution output.
        """
        padding = 'same' if self.should_pad else 'valid'

        conv = tf.layers.conv3d(tensor, filters, kernel_size = kernel, strides = stride, padding = padding,
                                kernel_initializer = self.base_init, kernel_regularizer = self.reg_init)
        # Named `pooled` (not `max`) so the builtin max() is not shadowed.
        pooled = tf.layers.max_pooling3d(conv, [1,2,2], [1,2,2])
        return pooled

    def conv_batch_relu_conv(self, tensor, filters, kernel = [3,3,3], stride = [1,1,1], is_training = True):
        """Pre-activation residual branch: (BN -> ReLU -> conv) applied twice.

        The result ends on a convolution, with no trailing BN/ReLU. Adding the
        identity path is not done here; the caller does that with tf.add where needed.
        Padding is 'same' when self.should_pad is truthy, otherwise 'valid'.
        """
        mode = 'same' if self.should_pad else 'valid'
        out = tensor
        for _ in range(2):
            out = tf.layers.batch_normalization(out, training = is_training)
            out = tf.nn.relu(out)
            out = tf.layers.conv3d(
                out, filters,
                kernel_size = kernel, strides = stride, padding = mode,
                kernel_initializer = self.base_init, kernel_regularizer = self.reg_init,
            )
        return out
    def conv_relu(self, tensor, filters, kernel = [3,3,3], stride = [1,1,1], is_training = True):
        """Two chained 3D convolution + ReLU layers, without batch normalisation.

        Used on the decoder path after each upsampling / skip concatenation.

        Bug fix: the second convolution previously read `tensor` again instead of the
        first layer's output. The first conv + ReLU was therefore built but never used,
        and its kernel received no gradient. The two layers are now chained.

        Args:
            tensor: 5-D input tensor.
            filters: output channels of both convolutions.
            kernel, stride: passed to both tf.layers.conv3d calls.
            is_training: unused (batch norm is disabled here); kept so the signature
                matches the other conv helpers.
        Returns:
            The ReLU-activated output of the second convolution.
        """
        padding = 'same' if self.should_pad else 'valid'

        conv = tensor
        for _ in range(2):
            conv = tf.layers.conv3d(conv, filters, kernel_size = kernel, strides = stride, padding = padding,
                                    kernel_initializer = self.base_init, kernel_regularizer = self.reg_init)
            conv = tf.nn.relu(conv)
        return conv


    def upconvolve(self, tensor, filters, kernel = 2, stride = 2, scale = 4, activation = None):
        """Learned upsampling via a bias-free 3D transposed convolution.

        Padding is 'same' when self.should_pad is truthy, otherwise 'valid'.
        `scale` and `activation` are accepted but unused. They are left over from an
        UpSampling3D + conv3d alternative that the author disabled.
        """
        # Bias is switched off; the original author notes this works around a TensorFlow bug.
        return tf.layers.conv3d_transpose(
            tensor, filters,
            kernel_size = kernel, strides = stride,
            padding = 'same' if self.should_pad else 'valid',
            use_bias = False,
            kernel_initializer = self.base_init, kernel_regularizer = self.reg_init,
        )

    def centre_crop_and_concat(self, prev_conv, up_conv):
        """Centre-crop `prev_conv` to `up_conv`'s size on axes 1-3, then concatenate on axis 4.

        This is only needed when convolutions are unpadded, so the encoder tensor
        is larger than the upsampled one. The batch axis (0) and the channel axis (4)
        of `prev_conv` are kept whole.
        """
        prev_shape = prev_conv.get_shape()
        up_shape = up_conv.get_shape()
        spatial_axes = (1, 2, 3)
        begin = [0] + [(prev_shape[ax] - up_shape[ax]) // 2 for ax in spatial_axes] + [0]
        extent = [-1] + [up_shape[ax] for ax in spatial_axes] + [prev_shape[4]]
        cropped = tf.slice(prev_conv,
                           np.array(begin, dtype = np.int32),
                           np.array(extent, dtype = np.int32))
        return tf.concat((cropped, up_conv), 4)
        
    def __init__(self, base_filt = 8, in_depth = INPUT_DEPTH, out_depth = OUTPUT_DEPTH,
                 in_size = INPUT_SIZE, out_size = OUTPUT_SIZE, num_classes = OUTPUT_CLASSES,
                 learning_rate = 0.001, print_shapes = True, drop = 0.2, should_pad = True):
        """Build the residual-encoder 3D U-Net graph, a weighted cross-entropy loss and an Adam train op.

        Args:
            base_filt: filters in the first conv layer; deeper levels use 2x, 4x, 8x and 16x this value.
            in_depth, in_size: depth and (square) height/width of the single-channel input placeholder.
            out_depth, out_size: depth and height/width of the integer label placeholder.
            num_classes: number of segmentation classes (logit channels and one-hot depth).
            learning_rate: Adam learning rate.
            print_shapes: print tensor shapes while the graph is built (debugging aid).
            drop: stored as self.drop; this architecture currently applies no dropout.
            should_pad: use 'same' padding when True, 'valid' otherwise.

        Sets the feed placeholders self.training, self.model_input and self.model_labels;
        self.predictions (per-voxel argmax with a trailing axis of size 1); self.loss; and self.train_op.
        """
        self.base_init = tf.truncated_normal_initializer(stddev=0.1) # Initialise weights
        # NOTE(review): kernel_regularizer only registers L2 terms in tf.GraphKeys.REGULARIZATION_LOSSES.
        # self.loss below never adds them, so this regulariser currently has no effect on training.
        # Confirm whether that is intended before "fixing" it: at scale=0.1 it would dominate the loss.
        self.reg_init = tf.contrib.layers.l2_regularizer(scale=0.1)

        # NOTE(review): should_pad=False looks unsupported here. Unpadded 3x3x3 convs shrink the residual
        # branch, so the tf.add shortcuts and the un-cropped decoder concats would fail on shape mismatch.
        self.should_pad = should_pad
        self.drop = drop # Kept for API parity with the earlier model; no dropout layer is built below

        with tf.variable_scope('3DuNet'):
            self.training = tf.placeholder(tf.bool)
            self.do_print = print_shapes
            # Define placeholders for feed_dict
            self.model_input = tf.placeholder(tf.float32, shape = (None, in_depth, in_size, in_size, 1))
            self.model_labels = tf.placeholder(tf.int32, shape = (None, out_depth, out_size, out_size, 1))
            # Labels carry a trailing size-1 axis: one_hot -> (..., 1, C), then squeeze it -> (..., C)
            labels_one_hot = tf.squeeze(tf.one_hot(self.model_labels, num_classes, axis = -1), axis = -2)

            if self.do_print:
                print('Input features shape', self.model_input.get_shape())
                print('Labels shape', labels_one_hot.get_shape())

            # Level zero: two conv-BN-ReLU layers at full resolution
            conv_0_1 = self.conv_batch_relu(self.model_input, base_filt, is_training = self.training)
            conv_0_2 = self.conv_batch_relu(conv_0_1, base_filt*2, is_training = self.training)

            # Residual encoder. All pooling is [1,2,2]: axes 2-3 shrink by 2 per stage, axis 1 (depth) is kept.
            # Each _downsample also returns its pre-pool widened tensor as the decoder skip connection.
            stage_1 = tf.layers.max_pooling3d(conv_0_2, [1,2,2], [1,2,2])
            stage_1 = self._residual_blocks(stage_1, base_filt*2, 3)
            skip_1, stage_2 = self._downsample(stage_1, base_filt*4)
            stage_2 = self._residual_blocks(stage_2, base_filt*4, 3)
            skip_2, stage_3 = self._downsample(stage_2, base_filt*8)
            stage_3 = self._residual_blocks(stage_3, base_filt*8, 5)
            skip_3, stage_4 = self._downsample(stage_3, base_filt*16)
            stage_4 = self._residual_blocks(stage_4, base_filt*16, 2)
            # The pre-activation blocks end on a conv, so close the encoder with BN + ReLU
            bottleneck = tf.nn.relu(tf.layers.batch_normalization(stage_4, training = self.training))

            # Decoder: transposed conv doubles axes 2-3, concat the matching encoder skip, then two conv+ReLU
            up_3 = self.upconvolve(bottleneck, base_filt*16, kernel = 2, stride = [1,2,2])
            dec_3 = self.conv_relu(tf.concat((up_3, skip_3), 4), base_filt*16, is_training = self.training)
            up_2 = self.upconvolve(dec_3, base_filt*8, kernel = 2, stride = [1,2,2])
            dec_2 = self.conv_relu(tf.concat((up_2, skip_2), 4), base_filt*8, is_training = self.training)
            up_1 = self.upconvolve(dec_2, base_filt*4, kernel = 2, stride = [1,2,2])
            dec_1 = self.conv_relu(tf.concat((up_1, skip_1), 4), base_filt*4, is_training = self.training)
            # Last upsampling back to full resolution has no skip connection (conv_0_2 is not reused)
            up_0 = self.upconvolve(dec_1, base_filt*2, kernel = 2, stride = [1,2,2])
            dec_0 = self.conv_relu(up_0, base_filt*2, is_training = self.training)

            # 1x1x1 conv to per-class logits. Use num_classes (previously the global OUTPUT_CLASSES) so the
            # logits always agree with labels_one_hot and with the loss-weight reshape below.
            # NOTE(review): logits have the input's spatial shape, so out_depth/out_size presumably
            # must equal in_depth/in_size (and in_size be divisible by 16) — confirm.
            conv_out = tf.layers.conv3d(dec_0, num_classes, [1,1,1], [1,1,1], padding = 'same')
            self.predictions = tf.expand_dims(tf.argmax(conv_out, axis = -1), -1)

            if self.do_print:
                # Previously an unconditional print of one mid-stage shape; now gated on print_shapes
                print('Encoder stage shapes', stage_1.get_shape(), stage_2.get_shape(),
                      stage_3.get_shape(), stage_4.get_shape())
                print('Model Convolution output shape', conv_out.get_shape())
                print('Model Argmax output shape', self.predictions.get_shape())

            do_weight = True
            # Per-class weights for exactly 3 classes; the reshape below fails if num_classes differs
            loss_weights = [0.00439314, 0.68209101, 0.31351585] # see section 1.4 # instead of [1, 150, 100, 1.0]
            # Weighted cross entropy: approach adapts following code: https://stackoverflow.com/questions/44560549/unbalanced-data-and-weighted-cross-entropy
            ce_loss = tf.nn.softmax_cross_entropy_with_logits_v2(logits=conv_out, labels=labels_one_hot)
            if do_weight:
                weighted_loss = tf.reshape(tf.constant(loss_weights), [1, 1, 1, 1, num_classes]) # Broadcastable over voxels
                # Pick each voxel's weight from its true class via the one-hot mask
                weighted_one_hot = tf.reduce_sum(weighted_loss*labels_one_hot, axis = -1)
                ce_loss = ce_loss * weighted_one_hot
            self.loss = tf.reduce_mean(ce_loss) # Get loss

            self.trainer = tf.train.AdamOptimizer(learning_rate=learning_rate)

            self.extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # Ensure correct ordering for batch-norm to work
            with tf.control_dependencies(self.extra_update_ops):
                self.train_op = self.trainer.minimize(self.loss)

    def _residual_blocks(self, tensor, filters, num_blocks):
        # Stack `num_blocks` pre-activation residual blocks at a fixed width: x <- x + conv_batch_relu_conv(x).
        for _ in range(num_blocks):
            branch = self.conv_batch_relu_conv(tensor, filters, is_training = self.training)
            tensor = tf.add(tensor, branch)
        return tensor

    def _downsample(self, tensor, filters):
        # Widen to `filters` channels and halve axes 2-3. Returns (pre-pool widened tensor for the decoder
        # skip, pooled tensor + conv_pool projection shortcut).
        widened = self.conv_batch_relu_conv(tensor, filters, is_training = self.training)
        pooled = tf.layers.max_pooling3d(widened, [1,2,2], [1,2,2])
        shortcut = self.conv_pool(tensor, filters, is_training = self.training)
        return widened, tf.add(pooled, shortcut)