# Google DeepLabV3+

In [43]:
import _02c_read_datasets
import _02_evaluate_model

import os
import numpy as np
import itertools
import matplotlib.pyplot as plt
import tensorflow as tf
from datetime import datetime
from tensorflow.keras.layers import Conv2D, MaxPooling2D, concatenate, UpSampling2D, BatchNormalization, Activation
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, jaccard_score
import pandas as pd
from tensorflow import keras
from tensorflow.keras import *

In [45]:
# ----------- experiment settings and data loading
epochs = 20          # number of training epochs
batches = 16         # batch size assumed by visualize_predictions; also reported to evaluate_model
input_width = 256    # input width in pixels (assumed square tiles)
shuffled = True      # reported to evaluate_model
augment = True       # {True, False}; forwarded to load_datasets

# Augmentation parameters reported to evaluate_model (None when augmentation is off).
# NOTE(review): only `augment` reaches load_datasets - confirm these values match
# what _02c_read_datasets actually applies.
augmentation_settings = {
    "flip_left_right": 0,
    "flip_up_down": 0,
    "gaussian_blur": 0.2,
    "random_noise": 0.0,
    "random_brightness": 0.5,
    "random_contrast": 0.5,
} if augment else None

train_dataset, val_dataset, test_dataset = _02c_read_datasets.load_datasets(augmented=augment)

## ResNet50 backbone

ResNet is a powerful, deep CNN that has been shown to work well as a feature extractor for image segmentation, which is why DeepLabV3 typically uses it as a backbone. Here, ResNet50 is applied as in https://keras.io/examples/vision/deeplabv3_plus/#inference-on-validation-images, mainly to serve as a benchmark. Note that the ImageNet-pretrained ResNet50 only accepts three input bands, so `eliminate_last_channel` drops the last band to return to an RGB (3-band) format. After that, `convolution_block` and `DilatedSpatialPyramidPooling` are the building blocks of the DeepLabV3+ model `DeeplabV3Plus_ResNet50`, which uses ResNet50 as its backbone.

In [28]:
# The ImageNet-pretrained ResNet50 only accepts three input channels.
def eliminate_last_channel(x, y):
    """Keep only the first three bands of an image tensor and pass the label through.

    The last axis is sliced with ``...``, so this works for batched
    ``(batch, H, W, C)`` elements as well as unbatched ``(H, W, C)`` ones.
    NOTE(review): assumes the first three bands are the R, G, B bands the
    pretrained backbone expects - confirm against the dataset band order.

    Args:
        x: Image tensor/array whose last axis holds the bands.
        y: Label/mask, returned unchanged.

    Returns:
        Tuple ``(x[..., :3], y)``.
    """
    return x[..., :3], y

# Drop the extra band from every split so the data fits the 3-band ResNet50 input.
train_dataset_1, val_dataset_1, test_dataset_1 = (
    split.map(eliminate_last_channel)
    for split in (train_dataset, val_dataset, test_dataset)
)

In [29]:
def convolution_block(
    block_input,
    num_filters=256,
    kernel_size=3,
    dilation_rate=1,
    padding="same",
    use_bias=False,
):
    """Conv2D -> BatchNormalization -> ReLU building block used by DeepLabV3+.

    Args:
        block_input: Input feature tensor.
        num_filters: Number of convolution filters (output channels).
        kernel_size: Convolution kernel size.
        dilation_rate: Dilation (atrous) rate of the convolution.
        padding: Padding mode forwarded to ``Conv2D``. The default ``"same"``
            keeps the spatial size, which the ASPP and decoder concatenations
            rely on.
        use_bias: Whether the convolution adds a bias term.

    Returns:
        The ReLU-activated output tensor.
    """
    x = layers.Conv2D(
        num_filters,
        kernel_size=kernel_size,
        dilation_rate=dilation_rate,
        padding=padding,  # honour the caller's choice instead of hard-coding "same"
        use_bias=use_bias,
        kernel_initializer=keras.initializers.HeNormal(),
    )(block_input)
    x = layers.BatchNormalization()(x)
    return tf.nn.relu(x)


def DilatedSpatialPyramidPooling(dspp_input):
    """Apply the DeepLabV3 atrous spatial pyramid pooling (ASPP) head.

    Five parallel branches are concatenated along the channel axis and fused
    by a 1x1 convolution block:
      * image pooling: average-pool over the full spatial extent, a 1x1 conv
        block (with bias), then bilinear upsampling back to the input size;
      * a 1x1 conv block;
      * three 3x3 conv blocks with dilation rates 6, 12 and 18.

    Args:
        dspp_input: Feature tensor indexed as (batch, height, width, channels)
            with static height and width (they are used as pool/upsample sizes).

    Returns:
        Tensor with the input's spatial size and 256 channels (the default
        filter count of ``convolution_block``).
    """
    height, width = dspp_input.shape[-3], dspp_input.shape[-2]

    # Image-level branch: global context pooled to 1x1, then broadcast back.
    pooled = layers.AveragePooling2D(pool_size=(height, width))(dspp_input)
    pooled = convolution_block(pooled, kernel_size=1, use_bias=True)
    image_branch = layers.UpSampling2D(
        size=(height // pooled.shape[1], width // pooled.shape[2]),
        interpolation="bilinear",
    )(pooled)

    # Atrous branches, created in a fixed order (Keras auto-names layers by
    # creation order).
    atrous_branches = [
        convolution_block(dspp_input, kernel_size=kernel, dilation_rate=rate)
        for kernel, rate in ((1, 1), (3, 6), (3, 12), (3, 18))
    ]

    merged = layers.Concatenate(axis=-1)([image_branch, *atrous_branches])
    return convolution_block(merged, kernel_size=1)

In [30]:
def DeeplabV3Plus_ResNet50(image_size, num_classes):
    """Build DeepLabV3+ with an ImageNet-pretrained ResNet50 encoder (3-band input).

    Args:
        image_size: Height and width of the square input. Presumably it must
            be a multiple of 16 so the integer upsampling factors line up.
        num_classes: Number of output channels. For ``num_classes == 1`` a
            sigmoid is applied, so the output is a per-pixel probability. This
            matches the ``'binary_crossentropy'`` loss (from_logits=False) and
            the 0.5 thresholds used in this notebook. Otherwise raw logits are
            returned.

    Returns:
        An uncompiled ``keras.Model``.
    """
    model_input = keras.Input(shape=(image_size, image_size, 3))
    # NOTE(review): the ImageNet weights were trained on inputs processed by
    # keras.applications.resnet50.preprocess_input - confirm the dataset scaling.
    resnet50 = keras.applications.ResNet50(
        weights="imagenet", include_top=False, input_tensor=model_input,
    )

    # Encoder: conv4-stage features (1/16 of the input resolution) feed the ASPP head.
    x = resnet50.get_layer("conv4_block6_2_relu").output
    x = DilatedSpatialPyramidPooling(x)

    # Decoder: upsample to 1/4 resolution and fuse with low-level conv2-stage features.
    input_a = layers.UpSampling2D(
        size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
        interpolation="bilinear",
    )(x)
    input_b = resnet50.get_layer("conv2_block3_2_relu").output
    input_b = convolution_block(input_b, num_filters=48, kernel_size=1)

    x = layers.Concatenate(axis=-1)([input_a, input_b])
    x = convolution_block(x)
    x = convolution_block(x)
    x = layers.UpSampling2D(
        size=(image_size // x.shape[1], image_size // x.shape[2]),
        interpolation="bilinear",
    )(x)
    # 'binary_crossentropy' with from_logits=False clips its input to (0, 1);
    # raw logits outside that range would receive no gradient, so emit
    # probabilities for the single-class (binary) case.
    output_activation = "sigmoid" if num_classes == 1 else None
    model_output = layers.Conv2D(
        num_classes, kernel_size=(1, 1), padding="same", activation=output_activation
    )(x)
    return keras.Model(inputs=model_input, outputs=model_output)


# Build the single-class model at the configured input width (instead of a duplicated literal).
model = DeeplabV3Plus_ResNet50(image_size=input_width, num_classes=1)
model.summary()

Model: "model_5"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_4 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv1_pad (ZeroPadding2D)      (None, 262, 262, 3)  0           ['input_4[0][0]']                
                                                                                                  
 conv1_conv (Conv2D)            (None, 128, 128, 64  9472        ['conv1_pad[0][0]']              
                                )                                                                 
                                                                                            

                                                                                                  
 conv2_block3_1_relu (Activatio  (None, 64, 64, 64)  0           ['conv2_block3_1_bn[0][0]']      
 n)                                                                                               
                                                                                                  
 conv2_block3_2_conv (Conv2D)   (None, 64, 64, 64)   36928       ['conv2_block3_1_relu[0][0]']    
                                                                                                  
 conv2_block3_2_bn (BatchNormal  (None, 64, 64, 64)  256         ['conv2_block3_2_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv2_block3_2_relu (Activatio  (None, 64, 64, 64)  0           ['conv2_block3_2_bn[0][0]']      
 n)       

                                                                                                  
 conv3_block3_1_relu (Activatio  (None, 32, 32, 128)  0          ['conv3_block3_1_bn[0][0]']      
 n)                                                                                               
                                                                                                  
 conv3_block3_2_conv (Conv2D)   (None, 32, 32, 128)  147584      ['conv3_block3_1_relu[0][0]']    
                                                                                                  
 conv3_block3_2_bn (BatchNormal  (None, 32, 32, 128)  512        ['conv3_block3_2_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv3_block3_2_relu (Activatio  (None, 32, 32, 128)  0          ['conv3_block3_2_bn[0][0]']      
 n)       

                                                                                                  
 conv4_block2_1_bn (BatchNormal  (None, 16, 16, 256)  1024       ['conv4_block2_1_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv4_block2_1_relu (Activatio  (None, 16, 16, 256)  0          ['conv4_block2_1_bn[0][0]']      
 n)                                                                                               
                                                                                                  
 conv4_block2_2_conv (Conv2D)   (None, 16, 16, 256)  590080      ['conv4_block2_1_relu[0][0]']    
                                                                                                  
 conv4_block2_2_bn (BatchNormal  (None, 16, 16, 256)  1024       ['conv4_block2_2_conv[0][0]']    
 ization) 

 conv4_block5_1_conv (Conv2D)   (None, 16, 16, 256)  262400      ['conv4_block4_out[0][0]']       
                                                                                                  
 conv4_block5_1_bn (BatchNormal  (None, 16, 16, 256)  1024       ['conv4_block5_1_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv4_block5_1_relu (Activatio  (None, 16, 16, 256)  0          ['conv4_block5_1_bn[0][0]']      
 n)                                                                                               
                                                                                                  
 conv4_block5_2_conv (Conv2D)   (None, 16, 16, 256)  590080      ['conv4_block5_1_relu[0][0]']    
                                                                                                  
 conv4_blo

                                                                                                  
 concatenate_6 (Concatenate)    (None, 16, 16, 1280  0           ['up_sampling2d_9[0][0]',        
                                )                                 'tf.nn.relu_28[0][0]',          
                                                                  'tf.nn.relu_29[0][0]',          
                                                                  'tf.nn.relu_30[0][0]',          
                                                                  'tf.nn.relu_31[0][0]']          
                                                                                                  
 conv2d_47 (Conv2D)             (None, 16, 16, 256)  327680      ['concatenate_6[0][0]']          
                                                                                                  
 batch_normalization_36 (BatchN  (None, 16, 16, 256)  1024       ['conv2d_47[0][0]']              
 ormalizat

In [None]:
# ----------- create directories
out_dir = '../results/' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + '_DeepLabV3+_ResNet50/'
# makedirs also creates out_dir as a parent. exist_ok ensures missing
# sub-directories are still created if out_dir already exists.
for sub_dir in ('plots', 'weights', 'predictions', 'bestweights'):
    os.makedirs(os.path.join(out_dir, sub_dir), exist_ok=True)

# Location of the best weights; reloaded with model.load_weights after training.
checkpoint_path = out_dir + 'bestweights/'

# Keep only the weights with the highest validation accuracy.
checkpoint = callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    save_best_only=True,
    monitor='val_accuracy',
    mode='max',
    verbose=1,
)

# NOTE: 'binary_crossentropy' (from_logits=False) expects the model to output
# probabilities, i.e. to end in a sigmoid.
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss='binary_crossentropy',
    metrics=["accuracy"],
)
history = model.fit(
    train_dataset_1,
    validation_data=val_dataset_1,
    epochs=epochs,
    callbacks=[checkpoint],
)

In [None]:
# ---------------------- save results

# Restore the best (highest val_accuracy) weights written by the checkpoint callback.
model.load_weights(checkpoint_path)

# Recompile with the training configuration (a fresh Adam resets optimizer state).
model.compile(optimizer=keras.optimizers.Adam(),
    loss='binary_crossentropy',
    metrics=["accuracy"])
model.save_weights(out_dir+'model.hdf5')

# ----------- plot training and validation curves
# Start a fresh figure per metric; otherwise accuracy.png also contains the
# loss curves drawn onto the same axes.
for metric in ('loss', 'accuracy'):
    plt.figure()
    plt.plot(history.history[metric], label='train ' + metric)
    plt.plot(history.history['val_' + metric], label='val ' + metric)
    plt.xlabel('epoch')
    plt.legend()
    plt.savefig(os.path.join(out_dir, 'plots', metric + '.png'))

# ----------- save weights
model.save(out_dir + '/weights/' + 'model.h5')

# ----------- save predictions
def visualize_predictions(index, test_dataset, out_dir):
    """Save a 2x2 comparison figure for one test sample.

    Panels: the min-max stretched first three bands, the ground-truth mask,
    the raw model output, and that output thresholded at 0.5. Uses the
    module-level ``model`` and ``batches``.

    The sample is taken from batch number ``index`` of ``test_dataset``
    (wrapping around via ``itertools.cycle`` when ``index`` exceeds the number
    of batches), at position ``index % batches`` within that batch. Only the
    first three bands are fed to the model, matching the 3-band ResNet50 input.

    NOTE(review): the dataset is re-iterated from the start on every call; if
    it reshuffles per iteration, ``index`` does not identify a fixed sample.

    Args:
        index: Sample index; also used in the output filename.
        test_dataset: Batched dataset yielding ``(images, labels)`` pairs.
        out_dir: Run directory; the figure is written to
            ``<out_dir>/predictions/comparison_<index>.png``.
    """
    # Element `index` of the cycled dataset, i.e. batch `index` modulo the batch count.
    image_batch, label_batch = next(
        itertools.islice(itertools.cycle(test_dataset), index, None)
    )

    # Position inside the batch; the extra modulo keeps a shorter final batch
    # from raising IndexError.
    wrapped_index = (index % batches) % int(image_batch.shape[0])
    image = image_batch[wrapped_index].numpy()[:, :, :3]

    def _stretch_to_255(channel):
        # Min-max stretch to [0, 255]; a constant channel becomes 0 instead of
        # dividing by zero.
        low, high = np.min(channel), np.max(channel)
        if high == low:
            return np.zeros_like(channel)
        return (channel - low) * 255.0 / (high - low)

    image_rgb = np.stack(
        [_stretch_to_255(image[:, :, c]) for c in range(3)], axis=-1
    ).astype(np.uint8)

    prediction = model.predict(np.expand_dims(image, axis=0), verbose=0)[0]
    ground_truth = label_batch[wrapped_index].numpy()

    fig, ax = plt.subplots(2, 2, figsize=(10, 10))
    ax[0, 0].imshow(image_rgb)
    ax[0, 0].set_title("Input Image")
    ax[0, 1].imshow(np.squeeze(ground_truth), cmap='gray')
    ax[0, 1].set_title("Ground Truth")
    ax[1, 0].imshow(np.squeeze(prediction), cmap='gray')
    ax[1, 0].set_title("Prediction")
    ax[1, 1].imshow(np.squeeze(prediction) > 0.5, cmap='gray')
    ax[1, 1].set_title("Prediction (binary)")
    for axis in ax.flat:
        axis.axis('off')

    fig.savefig(os.path.join(out_dir, 'predictions', 'comparison_' + str(index) + '.png'))
    # Called in a loop: close the figure so open figures do not accumulate.
    plt.close(fig)
# Write comparison figures for the first 80 test samples.
for i in range(80):
    visualize_predictions(i, test_dataset, out_dir)

# ----------- save metrics (evaluated on the 3-band test split)
model_info = _02_evaluate_model.evaluate_model(
    "DeepLabV3+ with ResNet50 backbone",
    test_dataset_1,
    model,
    (256, 256, 3),
    shuffled,
    batches,
    epochs,
    augmentation_settings=augmentation_settings,
    threshold=0.5,
)
df = pd.DataFrame(data=model_info)
df.to_csv(path_or_buf=os.path.join(out_dir, 'metrics.csv'), index=False)

## Xception

The original Xception code is derived from:

https://colab.research.google.com/github/mavenzer/Autism-Detection-Using_YOLO/blob/master/Tutorial_implementing_Xception_in_TensorFlow_2_0_using_the_Functional_API.ipynb#scrollTo=_cwgleGGqE9T

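Here, we adjust this code to follow the modified Xception described in the DeepLabV3+ paper (Chen et al., 2018, "Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation"): max-pooling operations are replaced by depthwise separable convolutions with striding. Corrections are indicated by `#Correction suggested by paper`.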

In [46]:
def entry_flow(inputs):
    """Modified-Xception entry flow (DeepLabV3+ variant).

    A stride-2 3x3 conv (32 filters) and a 3x3 conv (64 filters), each with
    BatchNorm + ReLU, followed by three residual blocks (128, 256 and 728
    filters). Each block ends in a stride-2 separable convolution instead of
    max pooling, so the total downsampling factor is 16.

    The BatchNormalization layers are named ``block{2,3,4}_sepconv{1,2,3}_bn``.
    The DeepLabV3+ builders look up ``block3_sepconv2_bn`` by name for the
    decoder's low-level features, so these names must stay stable.

    Args:
        inputs: Input image tensor.

    Returns:
        Feature tensor at 1/16 of the input resolution with 728 channels.
    """
    # Entry block
    x = layers.Conv2D(32, 3, strides=2, padding='same')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(64, 3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    previous_block_activation = x  # Set aside residual

    # Blocks 2, 3, 4 are identical apart from the feature depth. enumerate
    # yields the block number directly (list.index would be O(n) and wrong if
    # two blocks ever shared a depth).
    for idx, size in enumerate([128, 256, 728], start=2):
        # For block 2 this ReLU repeats the one above (a no-op on non-negative input).
        x = layers.Activation('relu')(x)
        x = layers.SeparableConv2D(size, 3, padding='same')(x)
        x = layers.BatchNormalization(name="block" + str(idx) + "_sepconv1_bn")(x)

        x = layers.Activation('relu')(x)
        x = layers.SeparableConv2D(size, 3, padding='same')(x)
        x = layers.BatchNormalization(name="block" + str(idx) + "_sepconv2_bn")(x)

        # Correction suggested by paper: strided separable conv replaces
        # MaxPooling2D(3, strides=2, padding='same').
        x = layers.SeparableConv2D(size, 3, padding='same', strides=2)(x)
        x = layers.BatchNormalization(name="block" + str(idx) + "_sepconv3_bn")(x)
        x = layers.Activation('relu')(x)

        # Project residual to the new depth/resolution with a strided 1x1 conv.
        residual = layers.Conv2D(
            size, 1, strides=2, padding='same')(previous_block_activation)
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    return x


def middle_flow(x, num_blocks=16):
    """Modified-Xception middle flow: ``num_blocks`` identity-residual blocks.

    Each block applies three (ReLU -> 3x3 SeparableConv2D(728) -> BatchNorm)
    units and adds the block input back. BatchNorm layers are named
    ``block{5..4+num_blocks}_sepconv{1,2,3}_bn``.

    Args:
        x: Feature tensor with 728 channels (as produced by ``entry_flow``).
        num_blocks: Number of residual blocks.

    Returns:
        Feature tensor with the same shape as ``x``.
    """
    shortcut = x
    for block_number in range(5, 5 + num_blocks):
        prefix = f"block{block_number}"
        for unit in (1, 2, 3):
            x = layers.Activation('relu')(x)
            x = layers.SeparableConv2D(728, 3, padding='same')(x)
            x = layers.BatchNormalization(name=f"{prefix}_sepconv{unit}_bn")(x)
        x = layers.add([x, shortcut])
        shortcut = x
    return x

def exit_flow(x, num_classes=1):
    """Modified-Xception exit flow plus a global-pooling classification head.

    Residual block 21 (728 -> 1024 -> strided 1024 separable convs with a
    strided 1x1 projection shortcut), then block 22 (1536, strided 1536, 2048
    separable convs), global average pooling and a Dense layer. DeepLabV3+
    only uses ``block21_sepconv2_bn`` from here, so the head is not part of the
    segmentation model.

    NOTE(review): the stride-2 on the second 1536 conv is not in the
    DeepLabV3+ exit-flow figure as far as I can tell - harmless here because
    that tail is unused, but confirm before reusing this as a classifier.

    Args:
        x: Feature tensor from ``middle_flow``.
        num_classes: Dense units; sigmoid if 1, otherwise softmax.

    Returns:
        Class-score tensor of shape (batch, num_classes).
    """
    def sep_bn(tensor, filters, bn_name, strides=1):
        # 3x3 separable conv followed by a named BatchNormalization.
        tensor = layers.SeparableConv2D(filters, 3, padding='same', strides=strides)(tensor)
        return layers.BatchNormalization(name=bn_name)(tensor)

    shortcut = x
    y = layers.Activation('relu')(x)
    y = sep_bn(y, 728, "block21_sepconv1_bn")
    y = layers.Activation('relu')(y)
    y = sep_bn(y, 1024, "block21_sepconv2_bn")
    # Correction suggested by paper: strided separable conv replaces max pooling.
    y = sep_bn(y, 1024, "block21_sepconv3_bn", strides=2)
    y = layers.Activation('relu')(y)

    shortcut = layers.Conv2D(1024, 1, strides=2, padding='same')(shortcut)
    y = layers.add([y, shortcut])

    for unit, (filters, strides) in enumerate(((1536, 1), (1536, 2), (2048, 1)), start=1):
        y = sep_bn(y, filters, f"block22_sepconv{unit}_bn", strides)
        y = layers.Activation('relu')(y)

    y = layers.GlobalAveragePooling2D()(y)
    head_activation = 'sigmoid' if num_classes == 1 else 'softmax'
    return layers.Dense(num_classes, activation=head_activation)(y)

In [47]:
def DeeplabV3Plus_Xception(image_size, num_classes):
    """Build DeepLabV3+ with a randomly initialised modified-Xception encoder (4-band input).

    Args:
        image_size: Height and width of the square input. Presumably it must
            be a multiple of 16 so the integer upsampling factors line up.
        num_classes: Number of output channels. For ``num_classes == 1`` a
            sigmoid is applied, so the output is a per-pixel probability. This
            matches the ``'binary_crossentropy'`` loss (from_logits=False) and
            the 0.5 thresholds used in this notebook. Otherwise raw logits are
            returned.

    Returns:
        An uncompiled ``keras.Model``.
    """
    model_input = keras.Input(shape=(image_size, image_size, 4))

    # The full Xception graph is built only to look up intermediate layers by
    # name; its classification head is not an ancestor of the returned output.
    outputs_xception = exit_flow(middle_flow(entry_flow(model_input)))
    xception = keras.Model(model_input, outputs_xception)

    # Encoder: features before the exit-flow stride, i.e. output stride 16
    # (ratio of input resolution to feature-map resolution).
    x = xception.get_layer("block21_sepconv2_bn").output
    x = DilatedSpatialPyramidPooling(x)

    # Decoder: upsample to 1/4 resolution and fuse with low-level block3 features.
    input_a = layers.UpSampling2D(
        size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
        interpolation="bilinear",
    )(x)
    input_b = xception.get_layer("block3_sepconv2_bn").output
    input_b = convolution_block(input_b, num_filters=48, kernel_size=1)

    x = layers.Concatenate(axis=-1)([input_a, input_b])
    x = convolution_block(x)
    x = convolution_block(x)
    x = layers.UpSampling2D(
        size=(image_size // x.shape[1], image_size // x.shape[2]),
        interpolation="bilinear",
    )(x)
    # 'binary_crossentropy' with from_logits=False clips its input to (0, 1);
    # raw logits outside that range would receive no gradient, so emit
    # probabilities for the single-class (binary) case.
    output_activation = "sigmoid" if num_classes == 1 else None
    model_output = layers.Conv2D(
        num_classes, kernel_size=(1, 1), padding="same", activation=output_activation
    )(x)
    return keras.Model(inputs=model_input, outputs=model_output)


# Build the single-class model at the configured input width (instead of a duplicated literal).
model = DeeplabV3Plus_Xception(image_size=input_width, num_classes=1)
model.summary()

(None, 16, 16, 728)
Model: "model_11"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_7 (InputLayer)           [(None, 256, 256, 4  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_84 (Conv2D)             (None, 128, 128, 32  1184        ['input_7[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization_62 (BatchN  (None, 128, 128, 32  128        ['conv2d_84[0][0]']              
 ormalization)                  )                                      

                                                                                                  
 activation_268 (Activation)    (None, 32, 32, 256)  0           ['add_81[0][0]']                 
                                                                                                  
 separable_conv2d_258 (Separabl  (None, 32, 32, 728)  189400     ['activation_268[0][0]']         
 eConv2D)                                                                                         
                                                                                                  
 block4_sepconv1_bn (BatchNorma  (None, 32, 32, 728)  2912       ['separable_conv2d_258[0][0]']   
 lization)                                                                                        
                                                                                                  
 activation_269 (Activation)    (None, 32, 32, 728)  0           ['block4_sepconv1_bn[0][0]']     
          

                                                                                                  
 activation_277 (Activation)    (None, 16, 16, 728)  0           ['add_84[0][0]']                 
                                                                                                  
 separable_conv2d_267 (Separabl  (None, 16, 16, 728)  537264     ['activation_277[0][0]']         
 eConv2D)                                                                                         
                                                                                                  
 block7_sepconv1_bn (BatchNorma  (None, 16, 16, 728)  2912       ['separable_conv2d_267[0][0]']   
 lization)                                                                                        
                                                                                                  
 activation_278 (Activation)    (None, 16, 16, 728)  0           ['block7_sepconv1_bn[0][0]']     
          

                                                                                                  
 separable_conv2d_276 (Separabl  (None, 16, 16, 728)  537264     ['activation_286[0][0]']         
 eConv2D)                                                                                         
                                                                                                  
 block10_sepconv1_bn (BatchNorm  (None, 16, 16, 728)  2912       ['separable_conv2d_276[0][0]']   
 alization)                                                                                       
                                                                                                  
 activation_287 (Activation)    (None, 16, 16, 728)  0           ['block10_sepconv1_bn[0][0]']    
                                                                                                  
 separable_conv2d_277 (Separabl  (None, 16, 16, 728)  537264     ['activation_287[0][0]']         
 eConv2D) 

 eConv2D)                                                                                         
                                                                                                  
 block13_sepconv1_bn (BatchNorm  (None, 16, 16, 728)  2912       ['separable_conv2d_285[0][0]']   
 alization)                                                                                       
                                                                                                  
 activation_296 (Activation)    (None, 16, 16, 728)  0           ['block13_sepconv1_bn[0][0]']    
                                                                                                  
 separable_conv2d_286 (Separabl  (None, 16, 16, 728)  537264     ['activation_296[0][0]']         
 eConv2D)                                                                                         
                                                                                                  
 block13_s

 ormalization)                                                                                    
                                                                                                  
 tf.nn.relu_61 (TFOpLambda)     (None, 64, 64, 256)  0           ['batch_normalization_71[0][0]'] 
                                                                                                  
 conv2d_98 (Conv2D)             (None, 64, 64, 256)  589824      ['tf.nn.relu_61[0][0]']          
                                                                                                  
 batch_normalization_72 (BatchN  (None, 64, 64, 256)  1024       ['conv2d_98[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 tf.nn.relu_62 (TFOpLambda)     (None, 64, 64, 256)  0           ['batch_normalization_72[0][0]'] 
          

In [None]:
# ----------- create directories
out_dir = '../results/' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + '_DeepLabV3+_Xception/'
# makedirs also creates out_dir as a parent. exist_ok ensures missing
# sub-directories are still created if out_dir already exists.
for sub_dir in ('plots', 'weights', 'predictions', 'bestweights'):
    os.makedirs(os.path.join(out_dir, sub_dir), exist_ok=True)

# Location of the best weights; reloaded with model.load_weights after training.
checkpoint_path = out_dir + 'bestweights/'

# Keep only the weights with the highest validation accuracy.
checkpoint = callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    save_best_only=True,
    monitor='val_accuracy',
    mode='max',
    verbose=1,
)

# NOTE: 'binary_crossentropy' (from_logits=False) expects the model to output
# probabilities, i.e. to end in a sigmoid.
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss='binary_crossentropy',
    metrics=["accuracy"],
)
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=epochs,
    callbacks=[checkpoint],
)

In [None]:
# ---------------------- save results

# Restore the best (highest val_accuracy) weights written by the checkpoint callback.
model.load_weights(checkpoint_path)
# Recompile with the training configuration (a fresh Adam resets optimizer state).
model.compile(optimizer=keras.optimizers.Adam(),
    loss='binary_crossentropy',
    metrics=["accuracy"])
model.save_weights(out_dir+'model.hdf5')

# ----------- plot training and validation curves
# Start a fresh figure per metric; otherwise accuracy.png also contains the
# loss curves drawn onto the same axes.
for metric in ('loss', 'accuracy'):
    plt.figure()
    plt.plot(history.history[metric], label='train ' + metric)
    plt.plot(history.history['val_' + metric], label='val ' + metric)
    plt.xlabel('epoch')
    plt.legend()
    plt.savefig(os.path.join(out_dir, 'plots', metric + '.png'))

# ----------- save weights
model.save(out_dir + '/weights/' + 'model.h5')

# ----------- save predictions
def visualize_predictions(index, test_dataset, out_dir):
    """Save a 2x2 comparison figure for one test sample.

    Panels: the min-max stretched first three bands, the ground-truth mask,
    the raw model output, and that output thresholded at 0.5. Uses the
    module-level ``model`` and ``batches``.

    The sample is taken from batch number ``index`` of ``test_dataset``
    (wrapping around via ``itertools.cycle`` when ``index`` exceeds the number
    of batches), at position ``index % batches`` within that batch. All bands
    are fed to the model; only the first three are displayed.

    NOTE(review): the dataset is re-iterated from the start on every call; if
    it reshuffles per iteration, ``index`` does not identify a fixed sample.

    Args:
        index: Sample index; also used in the output filename.
        test_dataset: Batched dataset yielding ``(images, labels)`` pairs.
        out_dir: Run directory; the figure is written to
            ``<out_dir>/predictions/comparison_<index>.png``.
    """
    # Element `index` of the cycled dataset, i.e. batch `index` modulo the batch count.
    image_batch, label_batch = next(
        itertools.islice(itertools.cycle(test_dataset), index, None)
    )

    # Position inside the batch; the extra modulo keeps a shorter final batch
    # from raising IndexError.
    wrapped_index = (index % batches) % int(image_batch.shape[0])
    image = image_batch[wrapped_index].numpy()

    def _stretch_to_255(channel):
        # Min-max stretch to [0, 255]; a constant channel becomes 0 instead of
        # dividing by zero.
        low, high = np.min(channel), np.max(channel)
        if high == low:
            return np.zeros_like(channel)
        return (channel - low) * 255.0 / (high - low)

    image_rgb = np.stack(
        [_stretch_to_255(image[:, :, c]) for c in range(3)], axis=-1
    ).astype(np.uint8)

    prediction = model.predict(np.expand_dims(image, axis=0), verbose=0)[0]
    ground_truth = label_batch[wrapped_index].numpy()

    fig, ax = plt.subplots(2, 2, figsize=(10, 10))
    ax[0, 0].imshow(image_rgb)
    ax[0, 0].set_title("Input Image")
    ax[0, 1].imshow(np.squeeze(ground_truth), cmap='gray')
    ax[0, 1].set_title("Ground Truth")
    ax[1, 0].imshow(np.squeeze(prediction), cmap='gray')
    ax[1, 0].set_title("Prediction")
    ax[1, 1].imshow(np.squeeze(prediction) > 0.5, cmap='gray')
    ax[1, 1].set_title("Prediction (binary)")
    for axis in ax.flat:
        axis.axis('off')

    fig.savefig(os.path.join(out_dir, 'predictions', 'comparison_' + str(index) + '.png'))
    # Called in a loop: close the figure so open figures do not accumulate.
    plt.close(fig)
# Write comparison figures for the first 80 test samples.
for i in range(80):
    visualize_predictions(i, test_dataset, out_dir)

# ----------- save metrics (evaluated on the 4-band test split)
model_info = _02_evaluate_model.evaluate_model(
    "DeepLabV3+ with Xception backbone",
    test_dataset,
    model,
    (256, 256, 4),
    shuffled,
    batches,
    epochs,
    augmentation_settings=augmentation_settings,
    threshold=0.5,
)
df = pd.DataFrame(data=model_info)
df.to_csv(path_or_buf=os.path.join(out_dir, 'metrics.csv'), index=False)

## Xception with three bands

In [39]:
def DeeplabV3Plus_Xception_3(image_size, num_classes, activation="sigmoid"):
    """Build a DeepLabV3+ segmentation model on an Xception backbone for 3-band input.

    The backbone is assembled from this notebook's ``entry_flow``,
    ``middle_flow`` and ``exit_flow`` helpers. The ``block21_sepconv2_bn``
    output feeds ``DilatedSpatialPyramidPooling`` (encoder). The
    ``block3_sepconv2_bn`` output is the low-level skip connection (decoder).

    Args:
        image_size: Height and width of the square input; the channel count is
            fixed at 3.
        num_classes: Number of channels produced by the final 1x1 convolution.
        activation: Activation of the final 1x1 convolution. Defaults to
            ``"sigmoid"`` so the output is a per-pixel probability. This is what
            ``loss='binary_crossentropy'`` (``from_logits=False``) and the 0.5
            thresholds used in evaluation/visualisation expect. Pass ``None``
            to obtain raw logits.

    Returns:
        A ``keras.Model`` taking ``(image_size, image_size, 3)`` inputs.
    """
    model_input = keras.Input(shape=(image_size, image_size, 3))

    # The backbone is wrapped in a Model only so that intermediate layers can be
    # looked up by name below.
    outputs = exit_flow(middle_flow(entry_flow(model_input)))
    xception = keras.Model(model_input, outputs)

    # Encoder
    x = xception.get_layer("block21_sepconv2_bn").output  # output stride=16 (fast fully convolutional mode)
    x = DilatedSpatialPyramidPooling(x)

    # Decoder: upsample the ASPP features to 1/4 of the input resolution ...
    input_a = layers.UpSampling2D(
        size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
        interpolation="bilinear",
    )(x)
    # ... and fuse them with low-level backbone features reduced to 48 channels.
    input_b = xception.get_layer("block3_sepconv2_bn").output
    input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
    x = layers.Concatenate(axis=-1)([input_a, input_b])
    x = convolution_block(x)
    x = convolution_block(x)
    x = layers.UpSampling2D(
        size=(image_size // x.shape[1], image_size // x.shape[2]),
        interpolation="bilinear",
    )(x)
    # Without an activation this layer would emit logits. The loss here is
    # binary_crossentropy with from_logits=False, which expects probabilities.
    model_output = layers.Conv2D(
        num_classes, kernel_size=(1, 1), padding="same", activation=activation
    )(x)
    return keras.Model(inputs=model_input, outputs=model_output)


# 256x256 three-band input, one output channel (binary segmentation mask).
model = DeeplabV3Plus_Xception_3(num_classes=1, image_size=256)
model.summary()

Model: "model_9"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_6 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_68 (Conv2D)             (None, 128, 128, 32  896         ['input_6[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization_51 (BatchN  (None, 128, 128, 32  128        ['conv2d_68[0][0]']              
 ormalization)                  )                                                           

                                                                                                  
 activation_203 (Activation)    (None, 32, 32, 256)  0           ['add_61[0][0]']                 
                                                                                                  
 separable_conv2d_195 (Separabl  (None, 32, 32, 728)  189400     ['activation_203[0][0]']         
 eConv2D)                                                                                         
                                                                                                  
 block4_sepconv1_bn (BatchNorma  (None, 32, 32, 728)  2912       ['separable_conv2d_195[0][0]']   
 lization)                                                                                        
                                                                                                  
 activation_204 (Activation)    (None, 32, 32, 728)  0           ['block4_sepconv1_bn[0][0]']     
          

                                                                                                  
 activation_212 (Activation)    (None, 16, 16, 728)  0           ['add_64[0][0]']                 
                                                                                                  
 separable_conv2d_204 (Separabl  (None, 16, 16, 728)  537264     ['activation_212[0][0]']         
 eConv2D)                                                                                         
                                                                                                  
 block7_sepconv1_bn (BatchNorma  (None, 16, 16, 728)  2912       ['separable_conv2d_204[0][0]']   
 lization)                                                                                        
                                                                                                  
 activation_213 (Activation)    (None, 16, 16, 728)  0           ['block7_sepconv1_bn[0][0]']     
          

                                                                                                  
 separable_conv2d_213 (Separabl  (None, 16, 16, 728)  537264     ['activation_221[0][0]']         
 eConv2D)                                                                                         
                                                                                                  
 block10_sepconv1_bn (BatchNorm  (None, 16, 16, 728)  2912       ['separable_conv2d_213[0][0]']   
 alization)                                                                                       
                                                                                                  
 activation_222 (Activation)    (None, 16, 16, 728)  0           ['block10_sepconv1_bn[0][0]']    
                                                                                                  
 separable_conv2d_214 (Separabl  (None, 16, 16, 728)  537264     ['activation_222[0][0]']         
 eConv2D) 

 eConv2D)                                                                                         
                                                                                                  
 block13_sepconv1_bn (BatchNorm  (None, 16, 16, 728)  2912       ['separable_conv2d_222[0][0]']   
 alization)                                                                                       
                                                                                                  
 activation_231 (Activation)    (None, 16, 16, 728)  0           ['block13_sepconv1_bn[0][0]']    
                                                                                                  
 separable_conv2d_223 (Separabl  (None, 16, 16, 728)  537264     ['activation_231[0][0]']         
 eConv2D)                                                                                         
                                                                                                  
 block13_s

 block16_sepconv1_bn (BatchNorm  (None, 16, 16, 728)  2912       ['separable_conv2d_231[0][0]']   
 alization)                                                                                       
                                                                                                  
 activation_240 (Activation)    (None, 16, 16, 728)  0           ['block16_sepconv1_bn[0][0]']    
                                                                                                  
 separable_conv2d_232 (Separabl  (None, 16, 16, 728)  537264     ['activation_240[0][0]']         
 eConv2D)                                                                                         
                                                                                                  
 block16_sepconv2_bn (BatchNorm  (None, 16, 16, 728)  2912       ['separable_conv2d_232[0][0]']   
 alization)                                                                                       
          

                                                                                                  
 activation_249 (Activation)    (None, 16, 16, 728)  0           ['block19_sepconv1_bn[0][0]']    
                                                                                                  
 separable_conv2d_241 (Separabl  (None, 16, 16, 728)  537264     ['activation_249[0][0]']         
 eConv2D)                                                                                         
                                                                                                  
 block19_sepconv2_bn (BatchNorm  (None, 16, 16, 728)  2912       ['separable_conv2d_241[0][0]']   
 alization)                                                                                       
                                                                                                  
 activation_250 (Activation)    (None, 16, 16, 728)  0           ['block19_sepconv2_bn[0][0]']    
          

                                                                                                  
 batch_normalization_55 (BatchN  (None, 16, 16, 256)  1024       ['conv2d_76[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 batch_normalization_56 (BatchN  (None, 16, 16, 256)  1024       ['conv2d_77[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 batch_normalization_57 (BatchN  (None, 16, 16, 256)  1024       ['conv2d_78[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 up_sampli

In [None]:
# ----------- create directories
out_dir = '../results/' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + '_DeepLabV3+_Xception_3bands/'
if not os.path.exists(out_dir):
    # The run folder itself first, then one sub-folder per kind of artefact.
    for sub_dir in ('', '/plots', '/weights', '/predictions', '/bestweights'):
        os.makedirs(out_dir + sub_dir)

# Best-so-far weights (by validation accuracy) are written under this path
# during training; the results cell reloads them from here.
checkpoint_path = out_dir + 'bestweights/'
checkpoint = callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

model.compile(
    loss='binary_crossentropy',
    optimizer=keras.optimizers.Adam(),
    metrics=["accuracy"],
)
history = model.fit(
    train_dataset_1,
    validation_data=val_dataset_1,
    epochs=epochs,
    callbacks=[checkpoint],
)

In [None]:
# ---------------------- save results

# Restore the best weights written by the ModelCheckpoint callback during fit.
model.load_weights(checkpoint_path)

# Compile the model with the same optimizer and loss function used during training
model.compile(optimizer=keras.optimizers.Adam(),
    loss='binary_crossentropy',
    metrics=["accuracy"])

# ----------- plot the training and validation loss
# Each plot gets its own figure. Drawing both on the shared current figure
# would put the loss curves into accuracy.png as well.
plt.figure()
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='val loss')
plt.legend()
plt.savefig(out_dir + '/plots/' + 'loss.png')
plt.close()

# ----------- plot the training and validation accuracy
plt.figure()
plt.plot(history.history['accuracy'], label='train accuracy')
plt.plot(history.history['val_accuracy'], label='val accuracy')
plt.legend()
plt.savefig(out_dir + '/plots/' + 'accuracy.png')
plt.close()

# ----------- save weights
model.save(out_dir + '/weights/' + 'model.h5')

# ----------- save predictions
def visualize_predictions(index, test_dataset, out_dir):
    test_data_iter = iter(itertools.cycle(test_dataset))

    for i in range(index + 1):
        image_batch, label_batch = next(test_data_iter)

    wrapped_index = index % batches
    image = image_batch[wrapped_index].numpy()
    image = image[:,:,:3]
    image_rgb = np.stack(
        (
            (image[:,:,0] - np.min(image[:,:,0])) * 255.0 / (np.max(image[:,:,0]) - np.min(image[:,:,0])),
            (image[:,:,1] - np.min(image[:,:,1])) * 255.0 / (np.max(image[:,:,1]) - np.min(image[:,:,1])),
            (image[:,:,2] - np.min(image[:,:,2])) * 255.0 / (np.max(image[:,:,2]) - np.min(image[:,:,2]))
        ),
        axis=-1
    ).astype(np.uint8)
    prediction = model.predict(np.expand_dims(image, axis=0))[0]
    ground_truth = label_batch[wrapped_index].numpy()

    fig, ax = plt.subplots(2, 2, figsize=(10, 10))
    ax[0,0].imshow(image_rgb)
    ax[0,0].set_title("Input Image")
    ax[0,1].imshow(np.squeeze(ground_truth), cmap='gray')
    ax[0,1].set_title("Ground Truth")
    ax[1,0].imshow(np.squeeze(prediction), cmap='gray')
    ax[1,0].set_title("Prediction")
    ax[1,1].imshow(np.squeeze(prediction) > 0.5, cmap='gray')
    ax[1,1].set_title("Prediction (binary)")

    for i in range(2):
        for j in range(2):
            ax[i,j].axis('off')

    plt.savefig(out_dir + '/predictions/' + 'comparison_' + str(index) + '.png')
    #plt.show()
# Write comparison figures for the first 80 test samples.
for sample_idx in range(80):
    visualize_predictions(sample_idx, test_dataset, out_dir)

# ----------- save metrics

# `augmentation_settings` is defined in the data-loading cell alongside
# `augment`. It is passed by keyword so it is bound to the intended parameter
# regardless of evaluate_model's positional order.
model_info = _02_evaluate_model.evaluate_model(
    "DeepLabV3+ with Xception backbone (3 bands)",
    test_dataset_1,
    model,
    (256, 256, 3),
    shuffled,
    batches,
    epochs,
    augmentation_settings=augmentation_settings,
    threshold=0.5,
)
df = pd.DataFrame(model_info)
df.to_csv(os.path.join(out_dir, 'metrics.csv'), index=False)