In [1]:
# WARNING: `fuser -k` sends SIGKILL to every process that currently has /dev/nvidia0 open.
# On a shared machine that includes other users' jobs. Only GPU 0's device node is targeted.
!fuser -k /dev/nvidia0

In [2]:
# Show GPU utilisation / memory / processes (presumably to confirm the GPU is free before training).
!nvidia-smi

In [2]:
import argparse
import sys, os
import tensorflow as tf
sys.path.append("../")
import tools.model
import json
from functools import partial

In [3]:
def expend_as(tensor, rep, name):
    """Repeat ``tensor`` ``rep`` times along axis 3 (the channel axis of a 4-D channels-last map).

    Used to spread a single-channel attention map over every channel of the
    feature map it gates. The operation is wrapped in a named ``Lambda`` layer
    so it can be used inside a functional Keras model.

    :param tensor: Keras tensor with at least 4 axes; axis 3 is repeated
    :param rep: number of repetitions along axis 3
    :param name: suffix for the layer name (the layer is called ``'psi_up' + name``)
    :return: Keras tensor whose axis 3 is ``rep`` times larger
    """
    # `tf.repeat` is used instead of `K.repeat_elements` because `K` (keras.backend)
    # is never imported in this notebook, which made the lambda fail with NameError.
    # Semantics are identical: each slice is repeated in place ([a, b] -> [a, a, b, b]).
    my_repeat = tf.keras.layers.Lambda(
        lambda x, repnum: tf.repeat(x, repnum, axis=3),
        arguments={'repnum': rep},
        name='psi_up' + name)(tensor)
    return my_repeat

def AttnGatingBlock(x, g, inter_shape, name):
    '''Additive attention gate built with the functional API.

    ``g`` is the spatially smaller gating signal and ``x`` the spatially larger
    skip features. Steps:

    1. theta_x: 2x2 conv with stride 2 on ``x`` -> ``inter_shape`` channels at half of x's resolution.
    2. phi_g: 1x1 conv on ``g`` -> ``inter_shape`` channels, then a strided
       transposed conv upsamples it to theta_x's resolution.
    3. relu(theta_x + phi_g) -> 1x1 conv to a single channel -> sigmoid; these
       are the attention coefficients.
    4. Upsample the coefficients to x's resolution, repeat them over x's
       channels, multiply with ``x``, then apply a 1x1 conv + batch norm.

    :param x: skip-connection tensor, indexed as channels-last (b, h, w, c)
    :param g: gating tensor, indexed as channels-last (b, h', w', c')
    :param inter_shape: number of channels of the intermediate theta/phi projections
    :param name: suffix appended to the names of the named layers (must be unique per call)
    :return: gated ``x`` after the 1x1 conv + batch norm, with x's channel count
    '''
    # Static shapes (ints or None). `K.int_shape` was used before, but `K` is never
    # imported in this notebook; `.shape` gives the same per-axis values.
    shape_x = tuple(x.shape)
    shape_g = tuple(g.shape)

    theta_x = tf.keras.layers.Conv2D(inter_shape, (2, 2), strides=(2, 2), padding='same', name='xl' + name)(x)
    shape_theta_x = tuple(theta_x.shape)

    phi_g = tf.keras.layers.Conv2D(inter_shape, (1, 1), padding='same')(g)
    # Stride = ratio between theta_x's and g's spatial sizes, so the result lands on theta_x's grid.
    upsample_g = tf.keras.layers.Conv2DTranspose(
        inter_shape, (3, 3),
        strides=(shape_theta_x[1] // shape_g[1], shape_theta_x[2] // shape_g[2]),
        padding='same', name='g_up' + name)(phi_g)

    # Additive attention: element-wise sum (not a concatenation) of both projections.
    sum_xg = tf.keras.layers.add([upsample_g, theta_x])
    act_xg = tf.keras.layers.Activation('relu')(sum_xg)
    psi = tf.keras.layers.Conv2D(1, (1, 1), padding='same', name='psi' + name)(act_xg)
    sigmoid_xg = tf.keras.layers.Activation('sigmoid')(psi)
    shape_sigmoid = tuple(sigmoid_xg.shape)
    # Bring the single-channel coefficients back to x's spatial resolution.
    upsample_psi = tf.keras.layers.UpSampling2D(
        size=(shape_x[1] // shape_sigmoid[1], shape_x[2] // shape_sigmoid[2]))(sigmoid_xg)

    # Repeat the coefficients across x's channels so the multiply is element-wise.
    upsample_psi = expend_as(upsample_psi, shape_x[3], name)
    y = tf.keras.layers.multiply([upsample_psi, x], name='q_attn' + name)

    result = tf.keras.layers.Conv2D(shape_x[3], (1, 1), padding='same', name='q_attn_conv' + name)(y)
    result_bn = tf.keras.layers.BatchNormalization(name='q_attn_bn' + name)(result)
    return result_bn

def UnetGatingSignal(input, is_batchnorm, name):
    '''Gating signal for an attention gate: 1x1 conv -> optional batch norm -> ReLU.

    The 1x1 convolution (stride 1, 'same' padding) keeps the input's channel
    count, so the output has the same shape as ``input``.

    :param input: channels-last Keras tensor (axis 3 is read as the channel count)
    :param is_batchnorm: when truthy, a BatchNormalization layer follows the conv
    :param name: prefix for the layer names (``name + 'gatingconv'``, ``'gatingbn'``, ``'gatingact'``)
    :return: Keras tensor with the same shape as ``input``
    '''
    # Fully-qualified layer names: the bare `K`, `Conv2D`, `BatchNormalization`
    # and `Activation` used previously are never imported in this notebook.
    channels = input.shape[3]
    x = tf.keras.layers.Conv2D(channels, (1, 1), strides=(1, 1), padding="same",
                               kernel_initializer='glorot_normal',
                               name=name + 'gatingconv')(input)
    if is_batchnorm:
        x = tf.keras.layers.BatchNormalization(name=name + 'gatingbn')(x)
    x = tf.keras.layers.Activation('relu', name=name + 'gatingact')(x)
    return x

In [4]:
class _GridAttentionBlockND(tf.keras.layers.Layer):
    """Grid (additive) attention gate for 2-D or 3-D feature maps.

    Called on ``[x, g]``, where ``x`` is the skip-connection feature map and
    ``g`` the gating signal. Shapes are indexed channels-last (``shape[-1]`` =
    channels, ``shape[1:-1]`` = spatial axes). The gate computes

        theta(x)   strided conv; downsamples x by ``sub_sample_factor``
        phi(g)     1x1 conv, then upsampled to theta(x)'s spatial size
        a = sigmoid(psi(relu(theta(x) + phi(g))))    one-channel attention map
        y = upsample(a) * x                          a resized back to x's grid
        W_y = W(y)                                   1x1 conv + batch norm

    and ``call`` returns ``(W_y, upsampled a)``.

    Only ``mode='concatenation'`` is implemented. The other accepted mode
    names raise NotImplementedError. Extra keyword arguments (e.g. ``name``)
    are forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, gating_channels, inter_channels=None,
                 dimension=2, mode='concatenation', sub_sample_factor=(2, 2),
                 **kwargs):
        super(_GridAttentionBlockND, self).__init__(**kwargs)

        assert dimension in [2, 3]
        assert mode in ['concatenation', 'concatenation_debug',
                        'concatenation_residual']

        # Downsampling rate (per spatial axis) applied to x by theta.
        if isinstance(sub_sample_factor, tuple):
            self.sub_sample_factor = sub_sample_factor
        elif isinstance(sub_sample_factor, list):
            self.sub_sample_factor = tuple(sub_sample_factor)
        else:
            self.sub_sample_factor = tuple([sub_sample_factor]) * dimension

        # Default parameter set
        self.mode = mode
        self.dimension = dimension
        self.sub_sample_kernel_size = self.sub_sample_factor

        # Number of channels of the gating signal (stored; not used in the computation).
        self.gating_channels = gating_channels

        assert inter_channels is not None
        self.inter_channels = inter_channels

        if dimension == 3:
            conv_nd = partial(tf.keras.layers.Conv3D,
                              kernel_initializer=tf.keras.initializers.HeNormal())
            self.upsample = tf.keras.layers.UpSampling3D
        elif dimension == 2:
            conv_nd = partial(tf.keras.layers.Conv2D,
                              kernel_initializer=tf.keras.initializers.HeNormal())
            self.upsample = tf.keras.layers.UpSampling2D
        else:
            # `raise NotImplemented` would raise TypeError (NotImplemented is not an exception).
            raise NotImplementedError(f'Unsupported dimension: {dimension}')
        self.conv_nd = conv_nd

        # Theta^T * x_ij + Phi^T * gating_signal + bias
        self.theta = conv_nd(self.inter_channels,
                             kernel_size=self.sub_sample_kernel_size,
                             strides=self.sub_sample_factor,
                             padding='same', use_bias=False)
        self.phi = conv_nd(self.inter_channels, kernel_size=1, strides=1,
                           padding='same', use_bias=True)
        self.psi = conv_nd(1, kernel_size=1, strides=1, padding='same',
                           use_bias=True)
        # Created in build(), once the input shapes are known.
        self.W = None
        self.phi_upsample = None
        self.sigm_psi_upsample = None

        # Define the operation
        if mode == 'concatenation':
            self.operation_function = self._concatenation
        elif mode in ('concatenation_debug', 'concatenation_residual'):
            # These modes pass the assert above but have no implementation in this
            # port; fail clearly instead of with an AttributeError.
            raise NotImplementedError(f"mode '{mode}' is not implemented.")
        else:
            raise NotImplementedError('Unknown operation function.')

    def build(self, inputs):
        # `inputs` holds the input *shapes*: [x_shape, g_shape].
        x_shape, g_shape = inputs[0], inputs[1]
        self.W = tf.keras.models.Sequential([
            self.conv_nd(filters=x_shape[-1], kernel_size=1, strides=1,
                         padding='same'),
            tf.keras.layers.BatchNormalization()])
        # theta(x) is x downsampled by sub_sample_factor on each spatial axis, while
        # phi(g) keeps g's size; upsample phi(g) per axis to theta(x)'s grid.
        phi_factor = tuple(
            x_shape[axis + 1] // factor // g_shape[axis + 1]
            for axis, factor in enumerate(self.sub_sample_factor))
        if any(f < 1 for f in phi_factor):
            raise ValueError(
                f'Gating signal of shape {g_shape} is spatially larger than the '
                f'downsampled input {x_shape} / {self.sub_sample_factor}.')
        self.phi_upsample = self.upsample(phi_factor, name='upsample_phi')
        # Undo theta's downsampling so the attention map matches x's grid.
        self.sigm_psi_upsample = self.upsample(self.sub_sample_factor,
                                               name='upsample_sigm_psi')

    def call(self, inputs):
        x, g = inputs
        return self.operation_function(x, g)

    def _concatenation(self, x, g):
        # theta: (b, h, w, c) -> (b, h/s1, w/s2, i_c)
        theta_x = self.theta(x)

        # phi: (b, h', w', g_c) -> (b, h', w', i_c), upsampled to theta_x's grid;
        # relu(theta_x + phi_g) -> f = (b, h/s1, w/s2, i_c)
        phi_g = self.phi_upsample(self.phi(g))
        f = tf.keras.activations.relu(theta_x + phi_g)

        # psi^T * f -> (b, h/s1, w/s2, 1), squashed to (0, 1)
        sigm_psi_f = tf.keras.activations.sigmoid(self.psi(f))

        # Upsample the attention map back to x's grid and gate x with it.
        sigm_psi_f = self.sigm_psi_upsample(sigm_psi_f)
        y = sigm_psi_f * x
        W_y = self.W(y)
        return W_y, sigm_psi_f
    
class GridAttentionBlock2D(_GridAttentionBlockND):
    """2-D grid attention gate: `_GridAttentionBlockND` with ``dimension`` fixed to 2."""

    def __init__(self, gating_channels, inter_channels=None, mode='concatenation',
                 sub_sample_factor=(2, 2)):
        super().__init__(gating_channels=gating_channels,
                         inter_channels=inter_channels,
                         mode=mode,
                         dimension=2,
                         sub_sample_factor=sub_sample_factor)
        
class ConvBlock(tf.keras.layers.Layer):
    """Conv2D ('same' padding, no activation) -> BatchNormalization -> ReLU.

    :param num_filters: number of convolution filters
    :param name: layer name
    :param kernel_size: convolution kernel size (default 3)
    """

    def __init__(self, num_filters, name, kernel_size=3):
        super().__init__(name=name)
        self.conv_2d = tf.keras.layers.Conv2D(num_filters,
                                              kernel_size=kernel_size,
                                              padding='same',
                                              activation='linear')
        self.batch_norm = tf.keras.layers.BatchNormalization()

    def call(self, inputs):
        normalised = self.batch_norm(self.conv_2d(inputs))
        return tf.keras.activations.relu(normalised)

In [5]:
def attention_gate(g, s, num_filters):
    """Additive attention gate: re-weight the skip features ``s`` using the signal ``g``.

    ``g`` and ``s`` are each projected by a 1x1 conv + batch norm to
    ``num_filters`` channels. The projections are summed and passed through
    ReLU. A further 1x1 conv + batch norm + sigmoid then yields coefficients
    that multiply ``s`` element-wise.

    :param g: gating tensor; must have the same spatial size as ``s``
    :param s: skip-connection tensor to gate
    :param num_filters: channels of the projections and of the coefficients
        (the final product needs this to be broadcast-compatible with ``s``'s channels)
    :return: ``coefficients * s``
    """
    # Project g first, then s (same layer creation order as before).
    projections = []
    for source in (g, s):
        projected = tf.keras.layers.Conv2D(num_filters, 1, padding="same")(source)
        projections.append(tf.keras.layers.BatchNormalization()(projected))
    proj_g, proj_s = projections

    coefficients = tf.keras.layers.Activation("relu")(proj_g + proj_s)
    coefficients = tf.keras.layers.Conv2D(num_filters, 1, padding="same")(coefficients)
    coefficients = tf.keras.layers.BatchNormalization()(coefficients)
    coefficients = tf.keras.layers.Activation("sigmoid")(coefficients)

    return coefficients * s

In [6]:
def encoder_mini_block(inputs, n_filters=32, activation="relu", dropout_prob=0.3, max_pooling=True, name=""):
    """
    Encoder stage of the U-Net: two (3x3 conv -> batch norm -> activation) stages,
    optional dropout, and optional 2x2 max pooling.

    The skip connection is the feature map after the conv stages and dropout,
    taken before pooling. Layers are named "eblock" + name + <suffix>.

    :param inputs: Input tensor to the block
    :param n_filters: Number of filters for both convolutional layers
    :param activation: Activation applied after each batch normalisation
    :param dropout_prob: Dropout rate after the second stage (0 disables dropout)
    :param max_pooling: If True, the returned next-layer tensor is max-pooled 2x2
    :param name: Suffix used in the layer names (Optional)
    :return: Tuple (next-layer tensor, skip-connection tensor)
    """
    prefix = "eblock" + name
    features = inputs
    for stage in ("1", "2"):
        features = tf.keras.layers.Conv2D(n_filters,
                                          3,  # filter size
                                          activation="linear",
                                          padding='same',
                                          kernel_initializer='HeNormal',
                                          name=prefix + "conv" + stage)(features)
        features = tf.keras.layers.BatchNormalization(name=prefix + "norm" + stage)(features)
        features = tf.keras.layers.Activation(activation=activation,
                                              name=prefix + activation + stage)(features)

    if dropout_prob > 0:
        features = tf.keras.layers.Dropout(dropout_prob, name=prefix + "drop")(features)

    skip_connection = features
    if not max_pooling:
        return features, skip_connection
    pooled = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), name=prefix + "pool")(features)
    return pooled, skip_connection


def decoder_mini_block(prev_layer_input, skip_layer_input, n_filters=32, activation="relu", dropout_prob=0.3,
                       max_pooling=True, attention=True, name=""):
    """
    Decoder stage of the U-Net. It runs these steps in order:

    1. Optional bilinear 2x upsampling of the previous layer.
    2. Optional attention gating of the skip connection.
    3. Concatenation of the two tensors.
    4. Two (3x3 conv -> batch norm -> activation) stages.
    5. Optional dropout.

    Layers are named "dblock" + name + <suffix>.

    :param prev_layer_input: Tensor coming from the previous (deeper) layer
    :param skip_layer_input: Tensor from the corresponding encoder skip connection
    :param n_filters: Number of filters for both convolutional layers (and the attention gate)
    :param activation: Activation applied after each batch normalisation
    :param dropout_prob: Dropout rate after the second stage (0 disables dropout)
    :param max_pooling: True when the mirrored encoder stage pooled; enables the upsampling
    :param attention: Gate the skip connection with `attention_gate` (only applied when max_pooling is True)
    :param name: Suffix used in the layer names (Optional)
    :return: The output tensor of the block
    """
    prefix = "dblock" + name
    upstream = prev_layer_input
    skip = skip_layer_input

    if max_pooling:
        # Undo the mirrored encoder stage's 2x2 pooling.
        upstream = tf.keras.layers.UpSampling2D(interpolation="bilinear")(upstream)
        if attention:
            skip = attention_gate(upstream, skip, n_filters)

    features = tf.keras.layers.concatenate([upstream, skip], name=prefix + "concat")
    for stage in ("1", "2"):
        features = tf.keras.layers.Conv2D(n_filters,
                                          3,  # filter size
                                          activation="linear",
                                          padding='same',
                                          kernel_initializer='HeNormal',
                                          name=prefix + "conv" + stage)(features)
        features = tf.keras.layers.BatchNormalization(name=prefix + "norm" + stage)(features)
        features = tf.keras.layers.Activation(activation=activation,
                                              name=prefix + activation + stage)(features)

    if dropout_prob > 0:
        features = tf.keras.layers.Dropout(dropout_prob, name=prefix + "drop")(features)
    return features


def unet_model(input_size, arhitecture, attention=True):
    """
    U-Net model for semantic segmentation, built from an architecture dictionary.

    The encoder downsamples the input image and extracts features. The decoder
    upsamples them and produces the segmentation mask. Skip connections
    concatenate encoder features with decoder features. The input is batch-
    normalised before the first encoder stage.

    :param input_size: Shape of one input image, without the batch axis (e.g. ``(128, 128, 1)``)
    :param arhitecture: Dictionary describing the network (read only; it is not modified).
        Encoder keys, one entry per level: "downFilters", "downActivation",
        "downDropout", "downMaxPool". Decoder keys, one entry per level with the
        deepest level first: "upFilters", "upActivation", "upDropout".
        Optional "attention": per-decoder-level booleans (all False if absent).
    :param attention: Global switch for attention gates. True (default) uses the
        per-level "attention" flags; False disables attention at every decoder level.
    :return: Uncompiled ``tf.keras.Model`` named "AsteroidNET" with one sigmoid output channel
    """
    n_up = len(arhitecture["upFilters"])
    # Read the optional flags without mutating the caller's dictionary.
    attention_flags = arhitecture.get("attention", [False] * n_up)
    if not attention:
        attention_flags = [False] * n_up

    inputs = tf.keras.layers.Input(input_size, name="input")
    # Keep `inputs` bound to the Input tensor: the Model must be built from it,
    # not from the normalisation layer's output.
    layer = tf.keras.layers.BatchNormalization(name="inputnormalisation")(inputs)

    # Encoder
    skip_connections = []
    for i in range(len(arhitecture["downFilters"])):
        layer, skip = encoder_mini_block(layer,
                                         n_filters=arhitecture["downFilters"][i],
                                         activation=arhitecture["downActivation"][i],
                                         dropout_prob=arhitecture["downDropout"][i],
                                         max_pooling=arhitecture["downMaxPool"][i],
                                         name=str(i))
        skip_connections.append(skip)

    # Decoder: step i mirrors encoder level n_up - 1 - i (deepest first).
    for i in range(n_up):
        level = n_up - 1 - i
        layer = decoder_mini_block(layer,
                                   skip_connections[level],
                                   n_filters=arhitecture["upFilters"][i],
                                   activation=arhitecture["upActivation"][i],
                                   attention=attention_flags[i],
                                   dropout_prob=arhitecture["upDropout"][i],
                                   max_pooling=arhitecture["downMaxPool"][level],
                                   name=str(level))

    outputs = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid', name="output")(layer)

    model = tf.keras.Model(inputs=[inputs], outputs=[outputs], name="AsteroidNET")
    return model

In [7]:
# U-Net layout: six encoder levels (16 -> 512 filters) mirrored by six decoder levels
# (deepest first). "attention" enables the attention gate per decoder level.
# NOTE(review): sigmoid is used as a hidden activation on alternating levels; it saturates
# easily and can slow training -- confirm this is intended.
arhitecture = dict(
    downFilters=[16, 32, 64, 128, 256, 512],
    downActivation=["relu", "sigmoid"] * 3,
    downDropout=[0.1] * 6,
    downMaxPool=[True] * 6,
    upFilters=[512, 256, 128, 64, 32, 16],
    upActivation=["sigmoid", "relu"] * 3,
    upDropout=[0.1] * 6,
    attention=[True] * 5 + [False],
)

In [8]:
batch_size = 128
mirrored_strategy = tf.distribute.MirroredStrategy()

# Training split: read the records, infer the (square) image shape from them,
# then parse, shuffle, batch and prefetch.
dataset_train = tf.data.TFRecordDataset(["../DATA/train1.tfrecord"])
tfrecord_shape = tools.model.get_shape_of_quadratic_image_tfrecord(dataset_train)
dataset_train = (
    dataset_train
    .map(tools.model.parse_function(img_shape=tfrecord_shape, test=False))
    .shuffle(5 * batch_size)
    .batch(batch_size)
    .prefetch(2)
)

# Validation split: same parsing (shape taken from the training records), no shuffling.
dataset_val = (
    tf.data.TFRecordDataset(["../DATA/test.tfrecord"])
    .map(tools.model.parse_function(img_shape=tfrecord_shape, test=False))
    .batch(batch_size)
    .prefetch(2)
)

# Callbacks: stop on a NaN loss; multiply the LR by 0.75 when the *training* loss
# has not improved for 2 epochs (with a 2-epoch cooldown after each reduction).
terminateonnan_kb = tf.keras.callbacks.TerminateOnNaN()
reducelronplateau_kb = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='loss', factor=0.75, patience=2, cooldown=2, verbose=1)
kb = [terminateonnan_kb, reducelronplateau_kb]

In [None]:
# `tools.metrics` is used below, but the import cell only imports `tools.model`;
# import it explicitly so this cell does not depend on a transitive import.
import tools.metrics

# NOTE(review): FE is not passed to compile() below (the model trains with FT);
# it is kept only as an alternative loss.
FE = tf.keras.losses.BinaryFocalCrossentropy(apply_class_balancing=True, alpha=0.95)
FT = tools.metrics.FocalTversky(alpha=0.7, gamma=1/5)
with mirrored_strategy.scope():
    # NOTE(review): the input shape is hard-coded; confirm it matches `tfrecord_shape`
    # inferred from the training records.
    model = unet_model((128, 128, 1), arhitecture)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                  loss=FT,
                  metrics=["Precision", "Recall", tools.metrics.F1_Score()])
results = model.fit(dataset_train, epochs=128, validation_data=dataset_val,
                    callbacks=kb, verbose=1)

In [None]:
# Persist the trained model in the native Keras (.keras) format.
model.save(filepath="../DATA/Model_test_5.keras")

In [None]:
# Presumably used to release the GPU memory held by this kernel after training.
# NOTE(review): numba's `Device.reset()` destroys the CUDA context of the current device
# (and every allocation in it). TensorFlow in this same kernel is then most likely
# unusable -- restart the kernel before running any further TF cells.
from numba import cuda
device = cuda.get_current_device()
device.reset()