In [1]:
import os
import sys
import tensorflow as tf

# Make the project root importable (needed for `from src.model import ...`)
# without stacking a duplicate sys.path entry every time this cell is re-run.
PROJECT_ROOT = os.path.abspath('..')
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
# Dataset location, relative to this notebook's working directory.
DATA_PATH = os.path.join("..", "data", "airbus-ship-detection")

2024-02-29 22:35:00.231230: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


<h3>Explore the MobileNetV2 encoder blocks</h3>

In [2]:
# MobileNetV2 backbone without its classification head, built directly for
# 768x768 RGB images. No `weights` argument is passed, so Keras' default
# pretrained weights are used. The summary is printed to pick skip layers.
base_model = tf.keras.applications.MobileNetV2(
    include_top=False,
    input_shape=(768, 768, 3),
)
base_model.summary()



Model: "mobilenetv2_1.00_224"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 768, 768, 3)]        0         []                            
                                                                                                  
 Conv1 (Conv2D)              (None, 384, 384, 32)         864       ['input_1[0][0]']             
                                                                                                  
 bn_Conv1 (BatchNormalizati  (None, 384, 384, 32)         128       ['Conv1[0][0]']               
 on)                                                                                              
                                                                                                  
 Conv1_relu (ReLU)           (None, 384, 384, 32)         0         ['bn_Conv1[

In [3]:
# MobileNetV2 activations reused as U-Net skip connections, listed from the
# highest to the lowest spatial resolution for a 768x768 input. The last entry
# is the bottleneck the decoder starts upsampling from.
layer_names = [
    'block_1_expand_relu',   # 384x384
    'block_3_expand_relu',   # 192x192
    'block_6_expand_relu',   # 96x96
    'block_13_expand_relu',  # 48x48
    'block_16_project',      # 24x24
]
skip_layers = [base_model.get_layer(layer_name) for layer_name in layer_names]
base_model_outputs = [layer.output for layer in skip_layers]

# Encoder that exposes every selected activation as a separate output.
down_stack = tf.keras.Model(
    inputs=base_model.input,
    outputs=base_model_outputs,
)
down_stack.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 768, 768, 3)]        0         []                            
                                                                                                  
 Conv1 (Conv2D)              (None, 384, 384, 32)         864       ['input_1[0][0]']             
                                                                                                  
 bn_Conv1 (BatchNormalizati  (None, 384, 384, 32)         128       ['Conv1[0][0]']               
 on)                                                                                              
                                                                                                  
 Conv1_relu (ReLU)           (None, 384, 384, 32)         0         ['bn_Conv1[0][0]']        

                                                                                                  
 block_4_depthwise (Depthwi  (None, 96, 96, 192)          1728      ['block_4_expand_relu[0][0]'] 
 seConv2D)                                                                                        
                                                                                                  
 block_4_depthwise_BN (Batc  (None, 96, 96, 192)          768       ['block_4_depthwise[0][0]']   
 hNormalization)                                                                                  
                                                                                                  
 block_4_depthwise_relu (Re  (None, 96, 96, 192)          0         ['block_4_depthwise_BN[0][0]']
 LU)                                                                                              
                                                                                                  
 block_4_p

<h3>Build the U-Net model</h3>

In [4]:
from src.model import upscale_layer


# Decoder blocks, one per skip connection, ordered from the bottleneck upward.
# Intended spatial sizes per block: 24->48, 48->96, 96->192, 192->384.
# upscale_layer lives in src.model; its doubling behavior is not verified here.
up_stack = [
    upscale_layer(n_filters)
    for n_filters in (512, 256, 128, 64)
]

def unet_model(output_channels=1):
    """Build the U-Net segmentation model on top of the MobileNetV2 encoder.

    Wires the module-level ``down_stack`` (encoder returning the skip
    activations) to the module-level ``up_stack`` (decoder upsampling blocks).
    Each upsampled tensor is concatenated with the matching skip activation.

    NOTE: ``down_stack`` and ``up_stack`` are created once, outside this
    function, so every model returned here shares those layers and their
    weights. Re-run the cells that build them to get an independent model.

    Args:
        output_channels: Number of channels in the predicted mask. Defaults
            to 1, i.e. a single sigmoid-activated mask channel.

    Returns:
        A ``tf.keras.Model``. It takes images with the encoder's input shape
        and returns a sigmoid mask at twice the resolution of the
        highest-resolution skip activation.

    Raises:
        ValueError: If ``up_stack`` does not have exactly one block per skip
            connection. Otherwise ``zip`` would silently drop the extras and
            the mask would come out at the wrong resolution.
    """
    # Take the input shape from the encoder instead of repeating the literal.
    inputs = tf.keras.layers.Input(shape=tuple(down_stack.input.shape[1:]))
    # MobileNetV2's own input scaling (the truediv/subtract ops in the summary).
    # NOTE(review): presumably expects raw [0, 255] pixels -- confirm the loader.
    preprocessed_input = tf.keras.applications.mobilenet_v2.preprocess_input(inputs)

    skips = down_stack(preprocessed_input)
    x = skips[-1]  # lowest-resolution encoder output: the decoder's starting point
    # Remaining skips, lowest resolution first, so they pair with up_stack's order.
    skips = list(reversed(skips[:-1]))
    if len(up_stack) != len(skips):
        raise ValueError(
            f"up_stack has {len(up_stack)} blocks but there are {len(skips)} "
            "skip connections; they must match."
        )

    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    # Final stride-2 transpose conv: 384x384 -> 768x768 for the encoder above.
    last = tf.keras.layers.Conv2DTranspose(
        filters=output_channels, kernel_size=3, strides=2,
        padding='same', activation='sigmoid'
    )
    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)


In [5]:
# Instantiate the U-Net and configure training: Adam with binary cross-entropy
# on the per-pixel sigmoid mask.
# NOTE(review): pixel accuracy can look high when masks are mostly background;
# consider adding an IoU/Dice metric -- confirm against the mask statistics.
model = unet_model()
bce_loss = tf.keras.losses.BinaryCrossentropy()
model.compile(optimizer='adam', loss=bce_loss, metrics=['accuracy'])
model.summary()


Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_2 (InputLayer)        [(None, 768, 768, 3)]        0         []                            
                                                                                                  
 tf.math.truediv (TFOpLambd  (None, 768, 768, 3)          0         ['input_2[0][0]']             
 a)                                                                                               
                                                                                                  
 tf.math.subtract (TFOpLamb  (None, 768, 768, 3)          0         ['tf.math.truediv[0][0]']     
 da)                                                                                              
                                                                                            