# Model: VGG16-encoder U-Net for binary segmentation

In [13]:
import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Activation, ReLU
from tensorflow.keras.layers import BatchNormalization, Conv2DTranspose, Concatenate
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.utils import plot_model

class SMR_Model():
    ''' U-Net segmentation model whose encoder is a pre-trained VGG16.

    By default the leading VGG16 layers (input + blocks 1-4, including
    block4_pool) are frozen, so only VGG16 block 5 and the decoder train.

    Parameters
    ----------
    input_shape : tuple
        Shape of one input image, e.g. (224, 224, 3). Spatial dims should be
        divisible by 16 so each 2x up-sampling lines up with its skip tensor.
    model_path_to_file : str, optional
        Path of a saved model; stored as ``self.model_to_load`` (``None`` when
        not given).
    '''

    def __init__(self, input_shape, model_path_to_file=''):
        self.input_shape = input_shape
        # always define the attribute so later reads never raise AttributeError
        self.model_to_load = model_path_to_file if model_path_to_file else None

    def convolution_block(self, inputs, num_filters):
        ''' Two (3x3 Conv2D -> BatchNormalization -> ReLU) layers.

        padding='same' keeps the spatial size unchanged. There is no pooling
        here: this block is reused after every decoder up-sampling step.
        '''
        # convolution layer 1 of the block
        x = Conv2D(num_filters, (3, 3), padding='same')(inputs)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

        # convolution layer 2 of the block
        x = Conv2D(num_filters, (3, 3), padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

        return x

    def decoder_block(self, inputs, skip_tensor, num_filters):
        ''' U-Net decoder stage: 2x up-sample, concatenate the skip, convolve. '''
        # strides=2 doubles height and width so the result matches skip_tensor
        x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding='same')(inputs)
        x = Concatenate()([x, skip_tensor])     # bring in the encoder skip
        x = self.convolution_block(x, num_filters)

        return x

    def build_vgg16_unet(self, n_frozen_layers=15, verbose=True):
        ''' Build the VGG16-encoder U-Net.

        Parameters
        ----------
        n_frozen_layers : int
            Number of leading VGG16 layers to freeze. The default of 15 freezes
            the input layer and blocks 1-4 (up to block4_pool), leaving block 5
            trainable.
        verbose : bool
            If True, print index, name and trainable flag of every VGG16 layer.

        Returns
        -------
        Keras ``Model`` named 'first_VGG16_UNET' with a 1-channel sigmoid output.
        '''
        inputs = Input(self.input_shape)

        # see actual VGG-16 here: https://github.com/keras-team/keras/blob/v2.9.0/keras/applications/vgg16.py#L43-L227
        vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=inputs)
        for layer in vgg16.layers[:n_frozen_layers]:
            layer.trainable = False

        if verbose:
            # check that layers have been set accordingly
            for i, layer in enumerate(vgg16.layers):
                print(i, layer.name, layer.trainable)

        # Encoder skip tensors (resolution relative to the input). VGG16 already
        # contains its MaxPooling layers, so only the taps need to be selected.
        skip1 = vgg16.get_layer('block1_conv2').output  # 1/1 resolution,  64 filters
        skip2 = vgg16.get_layer('block2_conv2').output  # 1/2 resolution, 128 filters
        skip3 = vgg16.get_layer('block3_conv3').output  # 1/4 resolution, 256 filters
        skip4 = vgg16.get_layer('block4_conv3').output  # 1/8 resolution, 512 filters

        # Bridge: deepest VGG16 conv output (block5_pool is intentionally unused)
        bridge = vgg16.get_layer('block5_conv3').output  # 1/16 resolution, 512 filters

        # Decoder: each stage doubles the resolution and merges the matching skip
        d1 = self.decoder_block(bridge, skip4, 512)
        d2 = self.decoder_block(d1, skip3, 256)
        d3 = self.decoder_block(d2, skip2, 128)
        d4 = self.decoder_block(d3, skip1, 64)

        # sigmoid output: predictions are probabilities in [0, 1], which is what
        # loss_combo_dice_bce and the threshold-based metrics expect
        outputs = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(d4)

        model = Model(inputs, outputs, name='first_VGG16_UNET')

        return model

    def loss_combo_dice_bce(self, y_true, y_pred, smooth=1.0):
        ''' Binary cross-entropy + soft Dice loss.

        ``y_pred`` is the model output, which already passes through a sigmoid
        in ``build_vgg16_unet``; it is therefore treated as probabilities and
        no further sigmoid is applied.

        Parameters
        ----------
        y_true : tensor
            Ground-truth mask; cast to float32.
        y_pred : tensor
            Predicted probabilities, same shape as ``y_true``.
        smooth : float
            Added to the Dice numerator and denominator so empty masks do not
            produce 0/0.

        Returns
        -------
        Scalar tensor: mean BCE + (1 - Dice coefficient over the whole batch).
        '''
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)

        # BCE on probabilities (from_logits=False); Keras clips to avoid log(0)
        bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)

        intersection = tf.reduce_sum(y_true * y_pred)
        total = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
        dice = (2.0 * intersection + smooth) / (total + smooth)

        return tf.reduce_mean(bce) + (1.0 - dice)

    def compile_model(self, m, threshold=0.5, optimizer='adam'):
        ''' Compile ``m`` with the BCE + Dice loss.

        Metrics: accuracy, BinaryIoU on the foreground class (id 1) and AUC.

        Parameters
        ----------
        m : Keras Model
            Model to compile (typically from ``build_vgg16_unet``).
        threshold : float
            Probability cut-off BinaryIoU uses to binarise predictions.
        optimizer : str or optimizer instance
            Passed straight to ``Model.compile``.

        Returns
        -------
        The same model object, compiled in place.
        '''
        binary_iou = tf.keras.metrics.BinaryIoU(target_class_ids=[1], threshold=threshold)
        auc = tf.keras.metrics.AUC()

        m.compile(
            loss=self.loss_combo_dice_bce,
            optimizer=optimizer,
            metrics=['accuracy', binary_iou, auc],
        )
        return m

In [14]:

# Build the VGG16-encoder U-Net for 224x224 RGB inputs, then attach loss/metrics.
get_vgg16 = SMR_Model(input_shape=(224, 224, 3))
model_vgg16 = get_vgg16.build_vgg16_unet()
model_vgg16_compiled = get_vgg16.compile_model(m=model_vgg16)


0 input_4 False
1 block1_conv1 False
2 block1_conv2 False
3 block1_pool False
4 block2_conv1 False
5 block2_conv2 False
6 block2_pool False
7 block3_conv1 False
8 block3_conv2 False
9 block3_conv3 False
10 block3_pool False
11 block4_conv1 False
12 block4_conv2 False
13 block4_conv3 False
14 block4_pool False
15 block5_conv1 True
16 block5_conv2 True
17 block5_conv3 True
18 block5_pool True


In [8]:
# Show the layer table of the compiled model; training is done by the
# `fit(ds_train, validation_data=ds_val, ...)` cell further down.
model_vgg16_compiled.summary()

Model: "first_VGG16_UNET"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 block1_conv1 (Conv2D)          (None, 224, 224, 64  1792        ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 block1_conv2 (Conv2D)          (None, 224, 224, 64  36928       ['block1_conv1[0][0]']           
                                )                                                  

 activation_3 (Activation)      (None, 56, 56, 256)  0           ['batch_normalization_3[0][0]']  
                                                                                                  
 conv2d_transpose_2 (Conv2DTran  (None, 112, 112, 12  131200     ['activation_3[0][0]']           
 spose)                         8)                                                                
                                                                                                  
 concatenate_2 (Concatenate)    (None, 112, 112, 25  0           ['conv2d_transpose_2[0][0]',     
                                6)                                'block2_conv2[0][0]']           
                                                                                                  
 conv2d_4 (Conv2D)              (None, 112, 112, 12  295040      ['concatenate_2[0][0]']          
                                8)                                                                
          

In [None]:
# Train on the training split, validating each epoch on ds_val.
# `mc` / `es` are callbacks defined elsewhere (presumably ModelCheckpoint and
# EarlyStopping -- confirm against their defining cell).
history_vgg16 = model_vgg16_compiled.fit(
    ds_train,
    validation_data=ds_val,
    epochs=20,
    callbacks=[mc, es],
)

In [1]:
#dfv