# 3D UNET model with deep supervision loss and attention gate


In [2]:
import tensorflow as tf
from tensorflow.keras import backend as K

In [None]:
# Empty Keras' global custom-object registry before the classes below are
# registered with `register_keras_serializable`.
# NOTE(review): presumably this avoids stale/duplicate registrations when the
# notebook cells are re-run - confirm. Be aware that it also removes custom
# objects registered by any other code running in the same process.
tf.keras.saving.get_custom_objects().clear()

In [None]:
@tf.keras.saving.register_keras_serializable(name="Tversky_Index")
class Tversky_Index(tf.keras.metrics.Metric):
    """Streaming Tversky index for binary segmentation masks.

    Predictions are binarised at 0.5, the index
    ``(tp + smooth) / (tp + alpha * fn + beta * fp + smooth)`` is computed per
    sample/channel and averaged, and the running mean over all
    ``update_state`` calls is reported by ``result``.

    Args:
        alpha: Weight applied to false negatives.
        beta: Weight applied to false positives.
        smooth: Small constant added to numerator and denominator so that the
            ratio stays defined when a mask is empty.
        name: Metric name.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.
    """

    def __init__(
        self, alpha=0.5, beta=0.5, smooth=0.000001, name="tversky_index", **kwargs
    ):
        # Zero-argument super(): the original referenced a non-existent
        # `TverskyIndex` class and raised NameError on construction.
        super().__init__(name=name, **kwargs)
        # Running sum of per-batch scores and number of batches seen.
        self.ti = self.add_weight(name="ti", initializer="zeros")
        self.count = self.add_weight(name="ti_c", initializer="zeros")
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth

    def identify_axis(self, shape):
        """Return the axes to reduce over for a tensor of the given shape.

        Rank 5 maps to ``[1, 2, 3]`` and rank 4 to ``[1, 2]``.
        NOTE(review): this assumes a channels-last layout with a leading
        batch axis - confirm against the data pipeline.

        Raises:
            ValueError: If the rank is neither 4 nor 5.
        """
        # Plain Python helper: ranks are static, so no tf.function is needed.
        rank = len(shape)
        if rank == 5:
            return [1, 2, 3]
        if rank == 4:
            return [1, 2]
        raise ValueError("Metric: Shape of tensor is neither 2D or 3D.")

    @tf.function
    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the Tversky index of one batch.

        Args:
            y_true: Ground-truth mask (rank 4 or 5).
            y_pred: Predicted probabilities, same shape as ``y_true``.
            sample_weight: Accepted for API compatibility; not used.
        """
        dtype = self.ti.dtype
        # Cast so integer/bool masks can be multiplied with the predictions.
        y_true = tf.cast(y_true, dtype)
        axis = self.identify_axis(y_true.shape)
        # Binarise the predictions at the 0.5 threshold.
        y_pred = tf.cast(y_pred >= 0.5, dtype)
        # True positives, false negatives and false positives per reduced slice.
        tp = tf.reduce_sum(y_true * y_pred, axis=axis)
        fn = tf.reduce_sum(y_true * (1 - y_pred), axis=axis)
        fp = tf.reduce_sum((1 - y_true) * y_pred, axis=axis)
        ti_class = (tp + self.smooth) / (
            tp + self.alpha * fn + self.beta * fp + self.smooth
        )
        # Average the remaining axes into one scalar for this batch.
        self.ti.assign_add(tf.reduce_mean(ti_class))
        self.count.assign_add(1.0)

    def result(self):
        """Return the mean score over all batches seen (0 if none were seen)."""
        # divide_no_nan yields 0 when count is 0, matching the original
        # early-return without a Python `if` on a tensor value.
        return tf.math.divide_no_nan(self.ti, self.count)

    def reset_state(self):
        """Reset the accumulators (done by Keras at the start of each epoch)."""
        self.count.assign(0.0)
        self.ti.assign(0.0)

    def get_config(self):
        """Return the constructor arguments needed to re-create the metric."""
        # The original read `self.delta`, which this class never defines.
        config = {"alpha": self.alpha, "beta": self.beta, "smooth": self.smooth}
        base_config = super().get_config()
        return {**base_config, **config}

In [None]:
@tf.keras.saving.register_keras_serializable(name="Focal_Tversky_Loss")
class Focal_Tversky_Loss(tf.keras.losses.Loss):
    """Focal Tversky loss: ``mean((1 - tversky) ** gamma)``.

    The Tversky index is
    ``(tp + smooth) / (tp + delta * fn + (1 - delta) * fp + smooth)`` and is
    computed on the raw (un-thresholded) predictions.

    Args:
        delta: Weight of false negatives; false positives get ``1 - delta``.
        gamma: Focal exponent applied to ``1 - tversky``.
        smooth: Small constant keeping the ratio defined for empty masks.
        name: Loss name.
        **kwargs: Forwarded to ``tf.keras.losses.Loss``.
    """

    def __init__(self, delta=0.7, gamma=0.75, smooth=0.000001, name="ftl", **kwargs):
        # Zero-argument super(): the original referenced a non-existent
        # `FocalTverskyLoss` class and raised NameError on construction.
        super().__init__(name=name, **kwargs)
        self.delta = delta
        self.gamma = gamma
        self.smooth = smooth

    def identify_axis(self, shape):
        """Return the axes to reduce over for a tensor of the given shape.

        Rank 5 maps to ``[1, 2, 3]`` and rank 4 to ``[1, 2]``.
        NOTE(review): this assumes a channels-last layout with a leading
        batch axis - confirm against the data pipeline.

        Raises:
            ValueError: If the rank is neither 4 nor 5.
        """
        # Plain Python helper: ranks are static, so no tf.function is needed.
        rank = len(shape)
        if rank == 5:
            return [1, 2, 3]
        if rank == 4:
            return [1, 2]
        raise ValueError("Loss: Shape of tensor is neither 2D or 3D.")

    def call(self, y_true, y_pred):
        """Compute the scalar focal Tversky loss for one batch."""
        y_pred = tf.convert_to_tensor(y_pred)
        # Cast so integer/bool masks can be multiplied with the predictions.
        y_true = tf.cast(y_true, y_pred.dtype)
        axis = self.identify_axis(y_true.shape)
        # Soft true positives, false negatives and false positives.
        tp = tf.reduce_sum(y_true * y_pred, axis=axis)
        fn = tf.reduce_sum(y_true * (1 - y_pred), axis=axis)
        fp = tf.reduce_sum((1 - y_true) * y_pred, axis=axis)

        tversky_class = (tp + self.smooth) / (
            tp + self.delta * fn + (1 - self.delta) * fp + self.smooth
        )
        # Average the focal term over the remaining axes.
        return tf.reduce_mean(tf.pow(1 - tversky_class, self.gamma))

    def get_config(self):
        """Return the constructor arguments needed to re-create the loss."""
        config = {"delta": self.delta, "gamma": self.gamma, "smooth": self.smooth}
        base_config = super().get_config()
        return {**base_config, **config}

In [14]:
@tf.keras.saving.register_keras_serializable(name="Conv_Block")
class Conv_Block(tf.keras.layers.Layer):
    """Two stacked ``Conv3D -> BatchNormalization -> ReLU`` stages.

    An optional ``Dropout`` is applied to the output when ``dropout > 0``.

    Args:
        num_filters: Number of output channels of both convolutions.
        dropout: Dropout rate; ``0`` disables the dropout layer entirely.
        *args, **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_filters=1, dropout=0, *args, **kwargs):
        # The original signature spelled `*kwargs` (a second star-args), which
        # is a SyntaxError; it must be `**kwargs`.
        super().__init__(*args, **kwargs)
        self.num_filters = num_filters
        self.dropout = dropout

        # Sub-layers are prefixed with this layer's name.
        self.conv1 = tf.keras.layers.Conv3D(
            num_filters,
            (3, 3, 3),
            padding="same",
            kernel_initializer="he_normal",
            name=f"{self.name}__conv1",
        )
        self.bn1 = tf.keras.layers.BatchNormalization(axis=-1, name=f"{self.name}__bn1")
        self.relu1 = tf.keras.layers.ReLU(name=f"{self.name}__relu1")

        self.conv2 = tf.keras.layers.Conv3D(
            num_filters,
            (3, 3, 3),
            padding="same",
            kernel_initializer="he_normal",
            name=f"{self.name}__conv2",
        )
        self.bn2 = tf.keras.layers.BatchNormalization(axis=-1, name=f"{self.name}__bn2")
        self.relu2 = tf.keras.layers.ReLU(name=f"{self.name}__relu2")

        # The dropout layer only exists when a positive rate was requested.
        if self.dropout > 0:
            self.do = tf.keras.layers.Dropout(self.dropout, name=f"{self.name}__do")

    def call(self, inputs):
        """Apply both conv stages (and dropout, if enabled) to ``inputs``."""
        x = self.relu1(self.bn1(self.conv1(inputs)))
        x = self.relu2(self.bn2(self.conv2(x)))
        if self.dropout > 0:
            x = self.do(x)
        return x

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = {
            "num_filters": self.num_filters,
            "dropout": self.dropout,
        }
        base_config = super().get_config()
        return {**base_config, **config}

In [None]:
@tf.keras.saving.register_keras_serializable(name="Res_Conv_Block")
class Res_Conv_Block(tf.keras.layers.Layer):
    """Residual 3D convolution block.

    Main path: ``Conv3D -> BN -> ReLU -> Conv3D -> BN`` (plus optional
    dropout). Shortcut path: ``1x1x1 Conv3D -> BN`` on the block input. Both
    paths are added and passed through a final ReLU.

    Args:
        num_filters: Number of output channels of every convolution.
        dropout: Dropout rate for the main path; ``0`` disables dropout.
        *args, **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_filters=1, dropout=0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_filters = num_filters
        # NOTE: the original also did `self.kernel_initializer =
        # kernel_initializer`, but no such parameter exists (NameError). The
        # initializer is fixed to "he_normal" below.
        self.dropout = dropout

        # Main path. Sub-layer names consistently use the "<name>__<part>"
        # pattern (the original mixed "_" and "__" separators).
        self.conv1 = tf.keras.layers.Conv3D(
            num_filters,
            (3, 3, 3),
            padding="same",
            kernel_initializer="he_normal",
            name=f"{self.name}__conv1",
        )
        self.bn1 = tf.keras.layers.BatchNormalization(axis=-1, name=f"{self.name}__bn1")
        self.relu1 = tf.keras.layers.ReLU(name=f"{self.name}__relu1")

        self.conv2 = tf.keras.layers.Conv3D(
            num_filters,
            (3, 3, 3),
            padding="same",
            kernel_initializer="he_normal",
            name=f"{self.name}__conv2",
        )
        self.bn2 = tf.keras.layers.BatchNormalization(axis=-1, name=f"{self.name}__bn2")

        # The dropout layer only exists when a positive rate was requested.
        if self.dropout > 0:
            self.do = tf.keras.layers.Dropout(self.dropout, name=f"{self.name}__do")

        # Shortcut path: the 1x1x1 convolution brings the input to
        # `num_filters` channels so it can be added to the main path.
        self.shortcut_conv = tf.keras.layers.Conv3D(
            num_filters,
            (1, 1, 1),
            padding="same",
            kernel_initializer="he_normal",
            name=f"{self.name}__shortcut_conv",
        )
        self.shortcut_bn = tf.keras.layers.BatchNormalization(
            axis=-1, name=f"{self.name}__shortcut_bn"
        )

        self.add = tf.keras.layers.Add(name=f"{self.name}__add")
        self.add_relu = tf.keras.layers.ReLU(name=f"{self.name}__add_relu")

    def call(self, inputs):
        """Return ``ReLU(main_path(inputs) + shortcut(inputs))``."""
        x = self.relu1(self.bn1(self.conv1(inputs)))
        x = self.bn2(self.conv2(x))
        if self.dropout > 0:
            x = self.do(x)
        shortcut = self.shortcut_bn(self.shortcut_conv(inputs))
        return self.add_relu(self.add([x, shortcut]))

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = {
            "num_filters": self.num_filters,
            "dropout": self.dropout,
        }
        base_config = super().get_config()
        return {**base_config, **config}

In [None]:
@tf.keras.saving.register_keras_serializable(name="Encoder")
class Encoder(tf.keras.layers.Layer):
    """Encoder stage: residual conv block followed by 2x2x2 max pooling.

    Args:
        num_filters: Number of channels produced by the residual block.
        *args, **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_filters=1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_filters = num_filters
        # The original stored the builtin `id` function and then tried to
        # overwrite `self.name` with "encoder__<built-in function id>". The
        # layer name is managed by Keras (pass `name=` to the constructor), so
        # both statements are removed.
        self.conv_block = Res_Conv_Block(
            num_filters=num_filters, dropout=0, name=f"{self.name}__conv_block"
        )
        self.max_pool = tf.keras.layers.MaxPooling3D(pool_size=(2, 2, 2))

    def call(self, inputs):
        """Return ``(features, pooled)``.

        ``features`` is the residual block output (used as the skip
        connection) and ``pooled`` is its max-pooled version.
        """
        features = self.conv_block(inputs)
        pooled = self.max_pool(features)
        return features, pooled

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = {
            "num_filters": self.num_filters,
        }
        base_config = super().get_config()
        return {**base_config, **config}

In [None]:
@tf.keras.saving.register_keras_serializable(name="Attention_Gate")
class Attention_Gate(tf.keras.layers.Layer):
    """Additive attention gate applied to a skip connection.

    Given ``[g, s]`` (gating signal and skip features), both are projected
    with 1x1x1 convolutions + batch norm, summed, passed through ReLU, a
    further 1x1x1 convolution and a sigmoid; the resulting coefficients are
    multiplied element-wise with ``s``.

    Args:
        num_filters: Number of channels of every 1x1x1 convolution.
        apply_out_batch_norm: Whether to batch-normalise the gated output.
        *args, **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_filters, apply_out_batch_norm=True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_filters = num_filters
        # NOTE: the original also did `self.activation = activation`, but no
        # such parameter exists (NameError); the line is removed.
        self.apply_out_batch_norm = apply_out_batch_norm

        # Projection of the gating signal.
        self.Wg_conv = tf.keras.layers.Conv3D(
            self.num_filters, (1, 1, 1), padding="same", name=f"{self.name}__wg_conv"
        )
        self.Wg_bn = tf.keras.layers.BatchNormalization(name=f"{self.name}__wg_bn")
        # Projection of the skip connection.
        self.Ws_conv = tf.keras.layers.Conv3D(
            self.num_filters, (1, 1, 1), padding="same", name=f"{self.name}__ws_conv"
        )
        self.Ws_bn = tf.keras.layers.BatchNormalization(name=f"{self.name}__ws_bn")
        self.relu = tf.keras.layers.ReLU(name=f"{self.name}__relu")
        # Attention coefficients.
        self.out_conv = tf.keras.layers.Conv3D(
            self.num_filters, (1, 1, 1), padding="same", name=f"{self.name}__out_conv"
        )
        self.out_batch_norm = tf.keras.layers.BatchNormalization(
            name=f"{self.name}__out_bn"
        )
        self.out_activation = tf.keras.activations.sigmoid

    def call(self, inputs):
        """Return the skip features ``s`` re-weighted by the attention map.

        Args:
            inputs: A pair ``[g, s]``. ``g`` and ``s`` are added after their
                projections and the coefficients are multiplied with ``s``,
                so their shapes must be compatible for those operations.
        """
        g, s = inputs
        Wg = self.Wg_bn(self.Wg_conv(g))
        Ws = self.Ws_bn(self.Ws_conv(s))

        out = self.relu(Wg + Ws)
        out = self.out_conv(out)
        out = self.out_activation(out)
        out = out * s
        # The original tested the bare name `apply_out_batch_norm`, which is
        # undefined inside `call`; the flag lives on the instance.
        if self.apply_out_batch_norm:
            out = self.out_batch_norm(out)

        return out

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = {
            "num_filters": self.num_filters,
            "apply_out_batch_norm": self.apply_out_batch_norm,
        }
        base_config = super().get_config()
        return {**base_config, **config}

In [None]:
@tf.keras.saving.register_keras_serializable(name="Decoder")
class Decoder(tf.keras.layers.Layer):
    """Decoder stage: upsample, attention-gate the skip, concatenate, convolve.

    Args:
        num_filters: Number of channels used by the transposed convolution,
            the attention gate and the residual block.
        *args, **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_filters=1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_filters = num_filters
        # The original passed `num_layers=num_layers`: an undefined name and
        # not a parameter of Res_Conv_Block.
        self.conv_block = Res_Conv_Block(
            num_filters=num_filters, dropout=0, name=f"{self.name}__res_conv_block"
        )
        # Stride-2 transposed convolution used for upsampling.
        self.transpose = tf.keras.layers.Conv3DTranspose(
            filters=num_filters,
            kernel_size=(2, 2, 2),
            strides=(2, 2, 2),
            padding="same",
            name=f"{self.name}__transpose",
        )
        self.concat = tf.keras.layers.Concatenate(name=f"{self.name}__concat")
        # Attention_Gate's parameter is `num_filters`, not `filters`.
        self.attention_gate = Attention_Gate(
            num_filters=num_filters, name=f"{self.name}__att_gate"
        )

    def call(self, inputs):
        """Decode one level.

        Args:
            inputs: A pair ``[x, skip]`` where ``x`` is the feature map from
                the level below and ``skip`` the encoder features to merge.

        Returns:
            The residual block output computed on the concatenation of the
            upsampled ``x`` and the attention-gated ``skip``.
        """
        x, skip = inputs
        # Upsample only the lower-level features (the original passed the
        # whole `inputs` pair to the transposed convolution).
        x = self.transpose(x)
        gated_skip = self.attention_gate([x, skip])
        # Concatenate expects a single list of tensors.
        x = self.concat([x, gated_skip])
        return self.conv_block(x)

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        # Only `num_filters` is a constructor argument; the original listed
        # attributes (num_layers, kernel_size, padding, ...) that this class
        # never defines, so get_config raised AttributeError.
        config = {
            "num_filters": self.num_filters,
        }
        base_config = super().get_config()
        return {**base_config, **config}

In [3]:
@tf.keras.saving.register_keras_serializable(name="Model_Unet_3D_DSL")
class Model_Unet_3D_DSL(tf.keras.Model):
    """3D attention U-Net with a deep-supervision output.

    The network has four encoder stages, a middle conv block and four decoder
    stages. ``call`` returns ``[output, pre_output]`` where ``output`` comes
    from the last decoder and ``pre_output`` from the second-to-last one.
    ``train_step`` combines the loss on both outputs using
    ``self.loss_weights``.

    Args:
        input_shape: Shape passed to ``self.build`` to create the weights.
        *args: Forwarded to ``tf.keras.Model``.
        filters: Five channel counts: encoder/decoder levels 1-4 and the
            middle block. The default keeps the literal ``1`` the original
            code passed to every block.
            NOTE(review): the original inline comments suggested
            32/64/128/200/256 - confirm the intended widths.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).

    Raises:
        ValueError: If ``filters`` does not contain exactly five entries.
    """

    def __init__(self, input_shape, *args, filters=(1, 1, 1, 1, 1), **kwargs):
        # Zero-argument super(): the original referenced a non-existent
        # `Model_Unet_3D` class and raised NameError on construction.
        super().__init__(*args, **kwargs)
        filters = tuple(filters)
        if len(filters) != 5:
            raise ValueError(
                "filters must contain 5 values (4 encoder/decoder levels + middle), "
                f"got {len(filters)}."
            )
        # `input_shape` is a reserved Keras attribute, so the constructor
        # argument is kept under a separate name for get_config().
        self.unet_input_shape = tuple(input_shape)
        self.filters = filters
        self.loss_weights = [1.0, 0.5]  # [final output, deep-supervision output]
        self.dice = Tversky_Index(alpha=0.5, beta=0.5, name="dice")
        self.jacard = Tversky_Index(alpha=1, beta=1, name="jacard")
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")

        # Sub-layers are created once here so their weights are tracked and
        # reused; the original instantiated layers inside call() and relied on
        # undefined encoder_block/conv_block/decoder_block helpers.
        f1, f2, f3, f4, f_mid = filters
        self.enc1 = Encoder(num_filters=f1, name="enc1")
        self.enc2 = Encoder(num_filters=f2, name="enc2")
        self.enc3 = Encoder(num_filters=f3, name="enc3")
        self.enc4 = Encoder(num_filters=f4, name="enc4")
        self.mid = Conv_Block(num_filters=f_mid, name="mid")
        # Each decoder uses the width of the encoder whose skip it merges.
        self.dec1 = Decoder(num_filters=f4, name="dec1")
        self.dec2 = Decoder(num_filters=f3, name="dec2")
        self.dec3 = Decoder(num_filters=f2, name="dec3")
        self.dec4 = Decoder(num_filters=f1, name="dec4")
        self.pre_output_conv = tf.keras.layers.Conv3D(
            1, (1, 1, 1), activation="sigmoid", name="preOutput"
        )
        self.output_conv = tf.keras.layers.Conv3D(
            1, (1, 1, 1), activation="sigmoid", name="finalOutput"
        )

        self.build(self.unet_input_shape)  # TO SETUP INPUT SHAPE

    def call(self, inputs, training=False):
        """Run the U-Net and return ``[output, pre_output]``.

        ``training`` is not forwarded explicitly; Keras propagates the mode of
        the outer call to nested layers.
        """
        encoded_1, pooled_1 = self.enc1(inputs)
        encoded_2, pooled_2 = self.enc2(pooled_1)
        encoded_3, pooled_3 = self.enc3(pooled_2)
        encoded_4, pooled_4 = self.enc4(pooled_3)
        mid = self.mid(pooled_4)
        decoded_1 = self.dec1([mid, encoded_4])
        decoded_2 = self.dec2([decoded_1, encoded_3])
        decoded_3 = self.dec3([decoded_2, encoded_2])
        # Deep-supervision head on the second-to-last decoder.
        pre_output = self.pre_output_conv(decoded_3)
        decoded_4 = self.dec4([decoded_3, encoded_1])
        output = self.output_conv(decoded_4)
        return [output, pre_output]

    def _update_metrics(self, loss, y, y_pred):
        """Feed the loss tracker with ``loss`` and all other metrics with ``(y, y_pred)``."""
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)

    def train_step(self, data):
        """One optimisation step.

        Args:
            data: A triple ``(x, y, z)``: inputs, target for the final output
                and target for the deep-supervision output.

        Returns:
            Dict mapping metric names to their current values.
        """
        x, y, z = data

        with tf.GradientTape() as tape:
            y_preds = self(x, training=True)
            if isinstance(y_preds, (list, tuple)):
                y_pred, z_pred = y_preds
                loss = self.loss_weights[0] * self.compute_loss(y=y, y_pred=y_pred)
                # Deep Supervision Loss
                loss += self.loss_weights[1] * self.compute_loss(y=z, y_pred=z_pred)
            else:
                y_pred = y_preds
                loss = self.compute_loss(y=y, y_pred=y_pred)

        # Metrics only look at the final output.
        self._update_metrics(loss, y, y_pred)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """One evaluation step on the final output only.

        Args:
            data: ``(x, y)``; any further elements (such as the
                deep-supervision target) are ignored.

        Returns:
            Dict mapping metric names to their current values.
        """
        x, y = data[0], data[1]
        y_pred, _ = self(x, training=False)
        loss = self.compute_loss(y=y, y_pred=y_pred)
        self._update_metrics(loss, y, y_pred)

        return {m.name: m.result() for m in self.metrics}

    def get_config(self):
        """Return the constructor arguments needed to re-create the model.

        The metrics are rebuilt by ``__init__`` and therefore not serialised
        (the original put the live metric objects into the config and omitted
        the required ``input_shape``).
        """
        config = {
            "input_shape": list(self.unet_input_shape),
            "filters": list(self.filters),
        }
        base_config = super().get_config()
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config):
        """Re-create the model from the dict produced by ``get_config``."""
        # The original referenced undefined `*_config` names and called `cls`
        # with positional metric objects the constructor does not accept.
        return cls(**config)

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        # If you don't implement this property, you have to call
        # `reset_states()` yourself at the time of your choosing.
        return [self.loss_tracker, self.dice, self.jacard]