In [None]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
import keras.backend as K

In [None]:
def _bernoulli(shape, mean):
    """Sample a float32 tensor of 0.0/1.0 values, each 1.0 with probability ``mean``.

    Args:
        shape: shape of the sample (list/tuple or 1-D integer tensor).
        mean: success probability; an entry is 1.0 where its U[0, 1) draw is below it.

    Returns:
        float32 tensor of the requested shape containing only 0.0 and 1.0.
    """
    noise = tf.random.uniform(shape, minval=0, maxval=1, dtype=tf.float32)
    hits = noise < mean
    return tf.cast(hits, tf.float32)
class DropBlock2D(tf.keras.layers.Layer):
    """DropBlock regularisation for 4-D ``(batch, height, width, channels)`` tensors.

    In training mode, square ``block_size x block_size`` regions are zeroed
    independently for every sample and channel. Block seeds are drawn with rate
    ``gamma`` (see ``_compute_gamma``) so that roughly ``1 - keep_prob`` of the
    activations end up dropped. Outside training, or when ``keep_prob == 1``, the
    input is returned unchanged. Only the channels-last layout is supported.

    Args:
        keep_prob: probability of keeping an activation; clipped to ``[0, 1]``.
        block_size: edge length of each dropped square (``>= 1`` and not larger
            than the input's height or width).
        scale: if true, kept activations are multiplied by
            ``mask.size / mask.sum()`` to compensate for the dropped ones.
            A Python bool or a scalar boolean tensor.
        name: optional layer name; ``None`` lets Keras generate a unique one.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, keep_prob, block_size, scale=True, name=None, **kwargs):
        # Initialise the base layer exactly once so a caller-supplied `name` is honoured.
        super(DropBlock2D, self).__init__(name=name, **kwargs)
        self.keep_prob = min(1.0, max(float(keep_prob), 0))
        self.block_size = int(block_size)
        if self.block_size < 1:
            raise ValueError("block_size must be >= 1, got %d" % self.block_size)
        self.names = name  # retained for code that reads this attribute
        # Plain Python copy of `scale` (None when a tensor was given) for get_config.
        self._scale_flag = scale if isinstance(scale, bool) else None
        self.scale = tf.constant(scale, dtype=tf.bool) if isinstance(scale, bool) else scale

    def get_config(self):
        """Serialisable config; the base config already carries the real layer name."""
        config = super(DropBlock2D, self).get_config().copy()
        config.update({"block_size": self.block_size, "keep_prob": self.keep_prob})
        # A tensor-valued `scale` cannot be serialised; from_config then uses the default.
        if self._scale_flag is not None:
            config["scale"] = self._scale_flag
        return config

    def compute_output_shape(self, input_shape):
        return input_shape

    def build(self, input_shape):
        """Record the static spatial/channel sizes and precompute padding and gamma.

        Raises:
            ValueError: if the input is not 4-D, has unknown height/width/channels,
                or is smaller than ``block_size`` in height or width.
        """
        shape = tf.TensorShape(input_shape)  # Keras may pass a tuple or a TensorShape
        if shape.rank != 4:
            raise ValueError("DropBlock2D expects a 4-D input, got shape %s" % (shape,))
        _, self.h, self.w, self.channel = shape.as_list()
        if None in (self.h, self.w, self.channel):
            raise ValueError("DropBlock2D needs static height, width and channels, "
                             "got shape %s" % (shape,))
        if self.block_size > min(self.h, self.w):
            raise ValueError("block_size (%d) must not exceed the spatial size (%d, %d)"
                             % (self.block_size, self.h, self.w))
        # Split the (block_size - 1) padding so that, after the stride-1 'SAME'
        # max-pool in _create_mask, every block grown from a seed lies inside the map.
        p1 = (self.block_size - 1) // 2
        p0 = (self.block_size - 1) - p1
        self.padding = [[0, 0], [p0, p1], [p0, p1], [0, 0]]
        self.gamma = self._compute_gamma()
        super(DropBlock2D, self).build(input_shape)

    def call(self, inputs, training=None, **kwargs):
        # keep_prob is a Python float fixed at construction, so the no-op case is
        # decided here without building the dropping sub-graph.
        if self.keep_prob >= 1.0:
            return tf.identity(inputs)

        def drop():
            keep_mask = self._create_mask(tf.shape(inputs))
            kept = tf.reduce_sum(keep_mask)
            total = tf.cast(tf.size(keep_mask), dtype=tf.float32)
            # divide_no_nan: if every activation was dropped the factor is 0 rather
            # than inf, so the result is all zeros instead of 0 * inf = NaN.
            scaled_mask = tf.cond(self.scale,
                                  true_fn=lambda: keep_mask * tf.math.divide_no_nan(total, kept),
                                  false_fn=lambda: keep_mask)
            # Match the input dtype (e.g. float16 under mixed precision).
            return inputs * tf.cast(scaled_mask, inputs.dtype)

        if training is None:
            # Use the tf.keras backend, matching the tf.keras base class of this layer.
            training = tf.keras.backend.learning_phase()
        training = tf.cast(training, tf.bool)
        return tf.cond(training, true_fn=drop, false_fn=lambda: inputs)

    def _compute_gamma(self):
        """Seed rate for dropped blocks (DropBlock's gamma).

        (1 - keep_prob) / block_size**2, times the ratio between the full map area
        and the area where a whole block fits.
        """
        valid_h = self.h - self.block_size + 1
        valid_w = self.w - self.block_size + 1
        return ((1. - self.keep_prob) * (self.w * self.h) / (self.block_size ** 2)
                / (valid_w * valid_h))

    def _create_mask(self, input_shape):
        """Return a float32 keep-mask (1 = keep, 0 = drop) shaped like ``input_shape``."""
        self.gamma = self._compute_gamma()
        # Seeds are only drawn where a full block fits.
        sampling_mask_shape = tf.stack([input_shape[0],
                                        self.h - self.block_size + 1,
                                        self.w - self.block_size + 1,
                                        self.channel])
        mask = _bernoulli(sampling_mask_shape, self.gamma)
        mask = tf.pad(mask, self.padding)
        # Grow every seed into a block_size x block_size square of ones.
        mask = tf.nn.max_pool(mask, [1, self.block_size, self.block_size, 1], [1, 1, 1, 1], 'SAME')
        return 1 - mask

class DropBlock3D(tf.keras.layers.Layer):
    """DropBlock regularisation for 5-D ``(batch, depth, height, width, channels)`` tensors.

    In training mode, cubic ``block_size x block_size x block_size`` regions are
    zeroed independently for every sample and channel. Block seeds are drawn with
    rate ``gamma`` (see ``_compute_gamma``) so that roughly ``1 - keep_prob`` of the
    activations end up dropped. Outside training, or when ``keep_prob == 1``, the
    input is returned unchanged. Only the channels-last layout is supported.

    Args:
        keep_prob: probability of keeping an activation; clipped to ``[0, 1]``.
        block_size: edge length of each dropped cube (``>= 1`` and not larger than
            the input's depth, height or width).
        scale: if true, kept activations are multiplied by
            ``mask.size / mask.sum()`` to compensate for the dropped ones.
            A Python bool or a scalar boolean tensor.
        name: optional layer name; ``None`` lets Keras generate a unique one.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, keep_prob, block_size, scale=True, name=None, **kwargs):
        # Initialise the base layer exactly once so a caller-supplied `name` is honoured.
        super(DropBlock3D, self).__init__(name=name, **kwargs)
        self.keep_prob = min(1.0, max(float(keep_prob), 0))
        self.block_size = int(block_size)
        if self.block_size < 1:
            raise ValueError("block_size must be >= 1, got %d" % self.block_size)
        self.names = name  # retained for code that reads this attribute
        # Plain Python copy of `scale` (None when a tensor was given) for get_config.
        self._scale_flag = scale if isinstance(scale, bool) else None
        self.scale = tf.constant(scale, dtype=tf.bool) if isinstance(scale, bool) else scale

    def get_config(self):
        """Serialisable config; the base config already carries the real layer name."""
        config = super(DropBlock3D, self).get_config().copy()
        config.update({"block_size": self.block_size, "keep_prob": self.keep_prob})
        # A tensor-valued `scale` cannot be serialised; from_config then uses the default.
        if self._scale_flag is not None:
            config["scale"] = self._scale_flag
        return config

    def compute_output_shape(self, input_shape):
        return input_shape

    def build(self, input_shape):
        """Record the static spatial/channel sizes and precompute padding and gamma.

        Raises:
            ValueError: if the input is not 5-D, has unknown depth/height/width/channels,
                or is smaller than ``block_size`` along any spatial axis.
        """
        shape = tf.TensorShape(input_shape)  # Keras may pass a tuple or a TensorShape
        if shape.rank != 5:
            raise ValueError("DropBlock3D expects a 5-D input, got shape %s" % (shape,))
        _, self.d, self.h, self.w, self.channel = shape.as_list()
        if None in (self.d, self.h, self.w, self.channel):
            raise ValueError("DropBlock3D needs static depth, height, width and channels, "
                             "got shape %s" % (shape,))
        if self.block_size > min(self.d, self.h, self.w):
            raise ValueError("block_size (%d) must not exceed the spatial size (%d, %d, %d)"
                             % (self.block_size, self.d, self.h, self.w))
        # Split the (block_size - 1) padding so that, after the stride-1 'SAME'
        # max-pool in _create_mask, every block grown from a seed lies inside the volume.
        p1 = (self.block_size - 1) // 2
        p0 = (self.block_size - 1) - p1
        self.padding = [[0, 0], [p0, p1], [p0, p1], [p0, p1], [0, 0]]
        self.gamma = self._compute_gamma()
        super(DropBlock3D, self).build(input_shape)

    def call(self, inputs, training=None, **kwargs):
        # keep_prob is a Python float fixed at construction, so the no-op case is
        # decided here without building the dropping sub-graph.
        if self.keep_prob >= 1.0:
            return tf.identity(inputs)

        def drop():
            keep_mask = self._create_mask(tf.shape(inputs))
            kept = tf.reduce_sum(keep_mask)
            total = tf.cast(tf.size(keep_mask), dtype=tf.float32)
            # divide_no_nan: if every activation was dropped the factor is 0 rather
            # than inf, so the result is all zeros instead of 0 * inf = NaN.
            scaled_mask = tf.cond(self.scale,
                                  true_fn=lambda: keep_mask * tf.math.divide_no_nan(total, kept),
                                  false_fn=lambda: keep_mask)
            # Match the input dtype (e.g. float16 under mixed precision).
            return inputs * tf.cast(scaled_mask, inputs.dtype)

        if training is None:
            # Use the tf.keras backend, matching the tf.keras base class of this layer.
            training = tf.keras.backend.learning_phase()
        training = tf.cast(training, tf.bool)
        return tf.cond(training, true_fn=drop, false_fn=lambda: inputs)

    def _compute_gamma(self):
        """Seed rate for dropped blocks (DropBlock's gamma).

        (1 - keep_prob) / block_size**3, times the ratio between the full volume
        and the volume where a whole block fits.
        """
        valid_d = self.d - self.block_size + 1
        valid_h = self.h - self.block_size + 1
        valid_w = self.w - self.block_size + 1
        return ((1. - self.keep_prob) * (self.d * self.w * self.h) / (self.block_size ** 3)
                / (valid_d * valid_w * valid_h))

    def _create_mask(self, input_shape):
        """Return a float32 keep-mask (1 = keep, 0 = drop) shaped like ``input_shape``."""
        self.gamma = self._compute_gamma()
        # Seeds are only drawn where a full block fits.
        sampling_mask_shape = tf.stack([input_shape[0],
                                        self.d - self.block_size + 1,
                                        self.h - self.block_size + 1,
                                        self.w - self.block_size + 1,
                                        self.channel])
        mask = _bernoulli(sampling_mask_shape, self.gamma)
        mask = tf.pad(mask, self.padding)
        # Grow every seed into a block_size^3 cube of ones.
        mask = tf.nn.max_pool3d(mask, [1, self.block_size, self.block_size, self.block_size, 1],
                                [1, 1, 1, 1, 1], 'SAME')
        return 1 - mask

In [None]:
# import numpy as np
# import tensorflow as tf
# # only support `channels_last` data format
# a = Input(shape=(4,4,4,10))
# b = DropBlock3D(block_size=3, keep_prob=0)(a)

# model = Model(a,b)

# for layer in model.layers:
#     if isinstance(layer, DropBlock3D):
#         print(layer.gamma)
# output = model(np.ones([1,4,4,4,1]), training = True)
# output

In [None]:
def cbam_block(cbam_feature, ratio=8):
    """Apply a Convolutional Block Attention Module (CBAM) to ``cbam_feature``.

    Channel attention (bottleneck reduction ``ratio``) is applied first and its
    result is then passed through spatial attention, as described in
    https://arxiv.org/abs/1807.06521.
    """
    return spatial_attention(channel_attention(cbam_feature, ratio))

def channel_attention(input_feature, ratio=8):
    """CBAM channel-attention branch for 5-D channels-last features.

    A global average-pooled and a global max-pooled descriptor (over axes 1-3)
    are each passed through a 1x1x1 bottleneck (Conv3D -> Swish -> Conv3D).
    The two results are summed and passed through a sigmoid, and the input is
    rescaled channel-wise by the result.

    Args:
        input_feature: Keras tensor; the last axis is treated as channels.
        ratio: channel reduction factor of the bottleneck (at least 1 filter is kept).

    Returns:
        Keras tensor with the same shape as ``input_feature``.
    """
    channel = input_feature.shape[-1]
    num_reduced_filters = max(1, int(channel // ratio))

    def _bottleneck(pooled):
        # Builds fresh layers on each call, so the two descriptors get separate weights.
        # NOTE(review): the CBAM paper shares one MLP between both descriptors; kept
        # separate here so the parameter layout of existing checkpoints is unchanged.
        x = Conv3D(num_reduced_filters, kernel_size=(1, 1, 1), kernel_initializer='he_normal',
                   padding='same', use_bias=True)(pooled)
        x = Swish()(x)
        return Conv3D(channel, kernel_size=(1, 1, 1), kernel_initializer='he_normal',
                      padding='same', use_bias=True)(x)

    avg_pool = Lambda(lambda a: K.mean(a, axis=[1, 2, 3], keepdims=True))(input_feature)
    avg_pool = _bottleneck(avg_pool)

    # Global *max* pool: CBAM's second descriptor (this branch must not be another mean).
    max_pool = Lambda(lambda a: K.max(a, axis=[1, 2, 3], keepdims=True))(input_feature)
    max_pool = _bottleneck(max_pool)

    cbam_feature = Add()([avg_pool, max_pool])
    cbam_feature = Activation('sigmoid')(cbam_feature)
    return Multiply()([input_feature, cbam_feature])

def spatial_attention(input_feature, kernel_size=7):
    """CBAM spatial-attention branch.

    The input is pooled over its last (channel) axis with mean and max. The
    resulting 2-channel map is convolved by a single sigmoid Conv3D filter, and
    the input is multiplied by the resulting per-position weights.

    Args:
        input_feature: Keras tensor; assumed 5-D channels-last since Conv3D is applied.
        kernel_size: size of the attention convolution (default 7).

    Returns:
        Keras tensor with the same shape as ``input_feature``.
    """
    avg_pool = Lambda(lambda x: K.mean(x, axis=-1, keepdims=True))(input_feature)
    max_pool = Lambda(lambda x: K.max(x, axis=-1, keepdims=True))(input_feature)
    concat = Concatenate()([avg_pool, max_pool])
    attention = Conv3D(filters=1, kernel_size=kernel_size, strides=1, padding='same',
                       activation='sigmoid', kernel_initializer='he_normal',
                       use_bias=False)(concat)
    return Multiply()([input_feature, attention])

In [None]:
class Swish(tf.keras.layers.Layer):
    """Parameter-free layer applying ``tf.nn.swish`` element-wise."""

    def __init__(self, name=None, **kwargs):
        super().__init__(name=name, **kwargs)

    def call(self, inputs, **kwargs):
        return tf.nn.swish(inputs)

    def get_config(self):
        # Base config plus an explicit 'name' entry set to this layer's name.
        return {**super().get_config(), 'name': self.name}
def convBn(inputs, filters, kernel_size=(3, 3, 3), block_size=1, keep_prob=1):
    """Conv3D -> DropBlock3D -> BatchNormalization -> Swish -> CBAM.

    Args:
        inputs: Keras tensor to transform.
        filters: number of Conv3D output channels.
        kernel_size: Conv3D kernel size ('same' padding, no bias).
        block_size: DropBlock3D block edge length.
        keep_prob: DropBlock3D keep probability.

    Returns:
        Keras tensor with ``filters`` channels.
    """
    pipeline = (
        Conv3D(filters, kernel_size, padding="same", use_bias=False,
               kernel_initializer='he_normal'),
        DropBlock3D(block_size=block_size, keep_prob=keep_prob),
        BatchNormalization(),
        Swish(),
    )
    features = inputs
    for stage in pipeline:
        features = stage(features)
    return cbam_block(features)
def downSampleBlock(inputs, block_size=1, keep_prob=1):
    """Halve the channel count with a 1x1x1 convBn, then apply a stride-2 Conv3D.

    With 'same' padding and stride 2, each spatial size n becomes ceil(n / 2).
    """
    half_channels = inputs.shape[-1] // 2
    reduced = convBn(inputs, half_channels, (1, 1, 1),
                     block_size=block_size, keep_prob=keep_prob)
    strided_conv = Conv3D(half_channels, (2, 2, 2), strides=(2, 2, 2), padding="same",
                          use_bias=False, kernel_initializer='he_normal')
    return strided_conv(reduced)
def attention_module(inputs, skip, n_filter):
    """Gate the skip connection ``skip`` using the coarser feature map ``inputs``.

    Both tensors are projected to ``n_filter`` channels with a 1x1x1 Conv3D + BN.
    The projection of ``inputs`` is upsampled 2x with a Conv3DTranspose. The two
    projections are summed and passed through Swish and CBAM. The result is
    reduced to a single sigmoid channel, multiplied into ``skip``, and the
    product is passed through CBAM again.

    Returns:
        Keras tensor shaped like ``skip``.
    """
    def _project(tensor):
        # 1x1x1 projection to n_filter channels followed by batch norm.
        projected = Conv3D(n_filter, (1, 1, 1), padding='same', kernel_initializer='he_normal',
                           use_bias=False)(tensor)
        return BatchNormalization()(projected)

    gating = _project(inputs)
    gating = Conv3DTranspose(n_filter, (2, 2, 2), strides=(2, 2, 2), padding='same',
                             kernel_initializer='he_normal')(gating)
    skip_proj = _project(skip)

    gate = Swish()(skip_proj + gating)
    gate = cbam_block(gate)
    gate = Conv3D(1, (1, 1, 1), padding='same', kernel_initializer='he_normal',
                  use_bias=False)(gate)
    gate = BatchNormalization()(gate)
    gate = Activation('sigmoid')(gate)

    gated_skip = Multiply()([gate, skip])
    return cbam_block(gated_skip)
def ResidualBlock(inputs, block_size=1, keep_prob=1):
    """Bottleneck residual unit added back onto ``inputs``.

    Branch: convBn 1x1x1 (C) -> convBn 3x3x3 (C // 2) -> convBn 1x1x1 (C),
    where C is the channel count of ``inputs``.
    """
    channels = inputs.shape[-1]
    plan = (
        (channels, (1, 1, 1)),
        (channels // 2, (3, 3, 3)),
        (channels, (1, 1, 1)),
    )
    branch = inputs
    for n_filters, kernel in plan:
        branch = convBn(branch, n_filters, kernel_size=kernel,
                        block_size=block_size, keep_prob=keep_prob)
    return Add()([inputs, branch])
def decoder_block(inputs, n_filter, skip=None, block_size=1, keep_prob=1, types="encoder"):
    """Upsample ``inputs`` 2x and optionally merge an attention-gated skip connection.

    Args:
        inputs: coarse Keras tensor to upsample.
        n_filter: output channels of the transposed conv (and of convBn when used).
        skip: optional same-resolution skip tensor; gated by ``attention_module``
            and concatenated to the upsampled tensor.
        block_size, keep_prob: DropBlock settings for the refinement stage.
        types: ``"encoder"`` refines with a single convBn; any other value uses
            a ResidualBlock.
    """
    upsampled = Conv3DTranspose(n_filter, (2, 2, 2), strides=(2, 2, 2), padding='same',
                                kernel_initializer='he_normal')(inputs)
    if skip is None:
        merged = upsampled
    else:
        gated_skip = attention_module(inputs, skip, n_filter)
        merged = Concatenate()([upsampled, gated_skip])

    if types != "encoder":
        return ResidualBlock(merged, block_size=block_size, keep_prob=keep_prob)
    return convBn(merged, n_filter, block_size=block_size, keep_prob=keep_prob)

In [None]:
def _rsu_block(input_feature, block_sizes, mid_ch, out_ch, keep_prob):
    """Build a residual U-block (RSU) whose encoder has ``len(block_sizes)`` levels.

    ``block_sizes[i]`` is the DropBlock size used at encoder level ``i`` (outermost
    first). The decoder revisits the levels deepest-first with the matching sizes.
    Layers are created in the same order as the original hand-unrolled versions.

    Args:
        input_feature: Keras tensor entering the block.
        block_sizes: non-empty sequence of DropBlock sizes, outermost level first.
        mid_ch: channels of the inner encoder/decoder levels.
        out_ch: channels of the input projection and of the final decoder level.
        keep_prob: DropBlock keep probability for every stage.

    Returns:
        Keras tensor: the U-Net output added to the input projection.

    Raises:
        ValueError: if ``block_sizes`` is empty.
    """
    block_sizes = tuple(block_sizes)
    if not block_sizes:
        raise ValueError("block_sizes must contain at least one level")

    # Input projection; also serves as the residual added at the end.
    residual = convBn(input_feature, out_ch, block_size=block_sizes[0], keep_prob=keep_prob)

    # Encoder: keep a skip at each level, then halve the resolution.
    skips = []
    output = residual
    for size in block_sizes:
        skip = convBn(output, mid_ch, block_size=size, keep_prob=keep_prob)
        skips.append(skip)
        output = downSampleBlock(skip, block_size=size, keep_prob=keep_prob)

    # Bottleneck (DropBlock block_size left at convBn's default).
    output = convBn(output, mid_ch, keep_prob=keep_prob)
    output = convBn(output, mid_ch, (1, 1, 1), keep_prob=keep_prob)

    # Decoder: deepest level first; only the outermost level widens to out_ch.
    last_level = len(block_sizes) - 1
    for level, (skip, size) in enumerate(zip(reversed(skips), reversed(block_sizes))):
        filters = out_ch if level == last_level else mid_ch
        output = decoder_block(output, filters, skip=skip,
                               block_size=size, keep_prob=keep_prob)
    return Add()([output, residual])


def RSU4times(input_feature, mid_ch=16, out_ch=64, keep_prob=1):
    """4-level RSU with encoder DropBlock sizes 7, 5, 3, 2."""
    return _rsu_block(input_feature, (7, 5, 3, 2), mid_ch, out_ch, keep_prob)


def RSU3times(input_feature, mid_ch=16, out_ch=64, keep_prob=1):
    """3-level RSU with encoder DropBlock sizes 5, 3, 2."""
    return _rsu_block(input_feature, (5, 3, 2), mid_ch, out_ch, keep_prob)


def RSU2times(input_feature, mid_ch=16, out_ch=64, keep_prob=1):
    """2-level RSU with encoder DropBlock sizes 3, 2."""
    return _rsu_block(input_feature, (3, 2), mid_ch, out_ch, keep_prob)


def RSU1times(input_feature, mid_ch=16, out_ch=64, keep_prob=1):
    """1-level RSU with encoder DropBlock size 2."""
    return _rsu_block(input_feature, (2,), mid_ch, out_ch, keep_prob)

In [None]:
def seg_net(input_shape=(32, 32, 32, 1), out_channels=4, keep_prob=1):
    """Two-input (T1, T2) 3-D segmentation network.

    Each input gets its own convBn(16). The two results are concatenated and
    fused with convBn(32), then passed through four RSU encoder stages (each
    followed by downsampling), a 256-channel bottleneck, and four attention-gated
    decoder blocks. A 1x1x1 Conv3D produces ``out_channels`` logits.

    Returns:
        tf.keras Model with inputs ``[inputT1, inputT2]`` and two outputs named
        "active" and "levelset". Both apply the same activation (softmax when
        ``out_channels > 1``, else sigmoid) to the same logits.
    """
    inputs = [Input(shape=input_shape, name="inputT1"),
              Input(shape=input_shape, name="inputT2")]
    per_modality = [convBn(tensor, 16) for tensor in inputs]
    output = convBn(Concatenate()(per_modality), 32)

    # (RSU builder, mid_ch, out_ch, DropBlock size used by the following down-sampling)
    encoder_stages = ((RSU4times, 16, 64, 7),
                      (RSU3times, 32, 64, 5),
                      (RSU2times, 32, 128, 3),
                      (RSU1times, 64, 128, 2))
    skips = []
    for rsu, mid, out, size in encoder_stages:
        skip = rsu(output, mid_ch=mid, out_ch=out, keep_prob=keep_prob)
        skips.append(skip)
        output = downSampleBlock(skip, block_size=size, keep_prob=keep_prob)

    output = convBn(output, 256, keep_prob=keep_prob)
    output = convBn(output, 256, (1, 1, 1), keep_prob=keep_prob)

    # Decoder walks the encoder stages deepest-first, reusing their DropBlock sizes.
    decoder_filters = (128, 64, 64, 32)
    for filters, skip, stage in zip(decoder_filters, reversed(skips), reversed(encoder_stages)):
        output = decoder_block(output, filters, skip=skip,
                               block_size=stage[-1], keep_prob=keep_prob)

    logits = Conv3D(out_channels, (1, 1, 1), kernel_initializer='he_normal')(output)
    activation = 'softmax' if out_channels > 1 else 'sigmoid'
    heads = [Activation(activation, name=head)(logits) for head in ("active", "levelset")]
    return Model(inputs, heads)

In [None]:
# S=seg_net(keep_prob=1)
# S.summary()