# Fully Convolutional Network (FCN)

In [17]:
import _02c_read_datasets
import _02_evaluate_model

import os
import numpy as np
import itertools
import matplotlib.pyplot as plt
import tensorflow as tf
from datetime import datetime
from tensorflow.keras.layers import *
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, jaccard_score
import pandas as pd
from tensorflow import keras
from tensorflow.keras import *

In [18]:
# ----------- experiment configuration and data loading
epochs = 20
batches = 16
input_width = 256
input_shape = (input_width, input_width, 4)
shuffled = True
augment = False  # {True, False}

# The augmentation settings are only populated when augmentation is switched
# on; otherwise None is kept so later cells can pass the value along as-is.
augmentation_settings = (
    {
        "flip_left_right": 0,
        "flip_up_down": 0,
        "gaussian_blur": 0.2,
        "random_noise": 0.0,
        "random_brightness": 0.5,
        "random_contrast": 0.5,
    }
    if augment
    else None
)

# Only the on/off flag is handed to the loader.
train_dataset, val_dataset, test_dataset = _02c_read_datasets.load_datasets(
    augmented=augment
)

In [19]:
def fcn_model(input_shape):
    """Build the convolutional encoder-decoder segmentation network.

    The encoder has four stages (32, 64, 128 and 256 filters), each made of
    two 3x3 convolutions with batch normalisation followed by 2x2 max
    pooling. A 512-filter bottleneck follows, and the decoder mirrors the
    encoder: every stage upsamples by 2, applies a 2x2 convolution,
    concatenates the matching encoder output along the channel axis and
    applies two more 3x3 convolutions. A final 1x1 convolution with a
    sigmoid activation produces a single-channel output.

    Args:
        input_shape: Shape handed to ``Input`` (without the batch dimension).

    Returns:
        The assembled, not yet compiled, ``Model``.
    """

    def double_conv(tensor, filters):
        # Two identical conv + batch-norm steps, created in call order.
        for _ in range(2):
            tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
            tensor = BatchNormalization()(tensor)
        return tensor

    def upsample_and_merge(tensor, skip, filters):
        # Double the spatial size, reduce channels, then attach the skip path.
        tensor = UpSampling2D(size=(2, 2))(tensor)
        tensor = Conv2D(filters, (2, 2), activation='relu', padding='same')(tensor)
        tensor = BatchNormalization()(tensor)
        return Concatenate(axis=3)([tensor, skip])

    inputs = Input(input_shape)

    # Encoding: remember each stage's pre-pooling output for the decoder.
    skip_outputs = []
    features = inputs
    for filters in (32, 64, 128, 256):
        features = double_conv(features, filters)
        skip_outputs.append(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)

    # Bottleneck
    features = double_conv(features, 512)

    # Decoding: walk the skip outputs from deepest to shallowest.
    for filters, skip in zip((256, 128, 64, 32), reversed(skip_outputs)):
        merged = upsample_and_merge(features, skip, filters)
        features = double_conv(merged, filters)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(features)
    return Model(inputs=[inputs], outputs=[outputs])


# Instantiate the network for the configured input shape and print its layers.
model = fcn_model(input_shape)
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 256, 256, 4  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_34 (Conv2D)             (None, 256, 256, 32  1184        ['input_3[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization_33 (BatchN  (None, 256, 256, 32  128        ['conv2d_34[0][0]']              
 ormalization)                  )                                                           

 batch_normalization_44 (BatchN  (None, 32, 32, 256)  1024       ['conv2d_45[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 conv2d_46 (Conv2D)             (None, 32, 32, 256)  590080      ['batch_normalization_44[0][0]'] 
                                                                                                  
 batch_normalization_45 (BatchN  (None, 32, 32, 256)  1024       ['conv2d_46[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 up_sampling2d_6 (UpSampling2D)  (None, 64, 64, 256)  0          ['batch_normalization_45[0][0]'] 
                                                                                                  
 conv2d_47

In [None]:
# ----------- create directories
# Each run writes into its own timestamped folder below ../results/.
run_stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
out_dir = '../results/' + run_stamp + '_FCN/'
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
    for subfolder in ('/plots', '/weights', '/predictions', '/bestweights'):
        os.makedirs(out_dir + subfolder)

# Path handed to the checkpoint callback (and reused later to reload weights).
checkpoint_path = out_dir + 'bestweights/'

# Keep only the weights of the epoch with the highest validation accuracy.
checkpoint = callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

model.compile(
    loss='binary_crossentropy',
    optimizer=keras.optimizers.Adam(),
    metrics=["accuracy"],
)
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=epochs,
    callbacks=[checkpoint],
)

In [None]:
# ---------------------- save results


# Load the saved, optimal weights written by the checkpoint callback
model.load_weights(checkpoint_path)

# Compile the model with the same optimizer and loss function used during training
model.compile(optimizer=keras.optimizers.Adam(),
    loss='binary_crossentropy',
    metrics=["accuracy"])
model.save_weights(out_dir+'model.hdf5')


# ----------- plot the training and validation loss
# Each plot gets its own figure: without this the accuracy curves below were
# drawn on top of the loss curves and accuracy.png contained all four lines.
plt.figure()
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='val loss')
plt.legend()
plt.savefig(out_dir + '/plots/' + 'loss.png')

# ----------- plot the training and validation accuracy
plt.figure()
plt.plot(history.history['accuracy'], label='train accuracy')
plt.plot(history.history['val_accuracy'], label='val accuracy')
plt.legend()
plt.savefig(out_dir + '/plots/' + 'accuracy.png')

# ----------- save the full model (not just the weights)
model.save(out_dir + '/weights/' + 'model.h5')

# ----------- save predictions
def visualize_predictions(index, test_dataset, out_dir):
    """Save a 2x2 comparison figure for one test sample.

    The sample comes from the ``index``-th batch yielded by ``test_dataset``
    (cycling back to the start when the dataset is exhausted), at position
    ``index % batches`` inside that batch. The figure shows the first three
    input channels as an RGB image, the ground truth, the raw prediction and
    the prediction thresholded at 0.5, and is written to
    ``<out_dir>/predictions/comparison_<index>.png``.

    Uses the notebook-level ``model`` and ``batches``.

    Args:
        index: Number of the sample to visualise; also used in the file name.
        test_dataset: Iterable yielding ``(image_batch, label_batch)`` pairs.
        out_dir: Run output folder; its ``predictions`` sub-folder must exist.
    """
    # Skip straight to the requested batch instead of stepping manually.
    batch_stream = itertools.islice(itertools.cycle(test_dataset), index, None)
    image_batch, label_batch = next(batch_stream)

    # A trailing partial batch can hold fewer than `batches` samples; wrap
    # once more so the position always falls inside the batch received.
    wrapped_index = (index % batches) % int(image_batch.shape[0])
    image = image_batch[wrapped_index].numpy()

    # Min-max scale each of the first three channels to 0..255 independently.
    rgb = image[:, :, :3]
    lowest = rgb.min(axis=(0, 1), keepdims=True)
    span = rgb.max(axis=(0, 1), keepdims=True) - lowest
    # A constant channel would give 0/0 -> NaN; divide by 1 so it renders as 0.
    span[span == 0] = 1
    image_rgb = ((rgb - lowest) * 255.0 / span).astype(np.uint8)

    prediction = model.predict(np.expand_dims(image, axis=0))[0]
    ground_truth = label_batch[wrapped_index].numpy()

    fig, ax = plt.subplots(2, 2, figsize=(10, 10))
    panels = (
        (ax[0, 0], image_rgb, "Input Image", None),
        (ax[0, 1], np.squeeze(ground_truth), "Ground Truth", 'gray'),
        (ax[1, 0], np.squeeze(prediction), "Prediction", 'gray'),
        (ax[1, 1], np.squeeze(prediction) > 0.5, "Prediction (binary)", 'gray'),
    )
    for axis, data, title, cmap in panels:
        axis.imshow(data, cmap=cmap)
        axis.set_title(title)
        axis.axis('off')

    fig.savefig(out_dir + '/predictions/' + 'comparison_' + str(index) + '.png')
    # Release the figure: this function is called many times in a row and
    # open figures otherwise accumulate in memory.
    plt.close(fig)


for i in range(80):
    visualize_predictions(i, test_dataset, out_dir)

# ----------- save metrics
# Evaluate on the test set with the run's settings, then tabulate the returned
# information and write it next to the other artefacts of this run.
model_info = _02_evaluate_model.evaluate_model(
    "FCN",
    test_dataset,
    model,
    input_shape,
    shuffled,
    batches,
    epochs,
    augmentation_settings=augmentation_settings,
    threshold=0.5,
)
metrics_path = os.path.join(out_dir, 'metrics.csv')
df = pd.DataFrame(model_info)
df.to_csv(metrics_path, index=False)

In [None]:
def visualize_predictions(index, test_dataset, out_dir, batches = 16):
    """Save the input, ground truth and predictions of one sample as images.

    The sample comes from the ``index``-th batch yielded by ``test_dataset``
    (cycling back to the start when the dataset is exhausted), at position
    ``index % batches`` inside that batch. Four PNG files named
    ``<index>.png`` are written below
    ``<out_dir>/predictions/image_<index>/`` in the sub-folders
    ``input_image``, ``ground_truth``, ``prediction`` and
    ``prediction_binary`` (prediction thresholded at 0.5).

    Uses the notebook-level ``model``.

    Args:
        index: Number of the sample to export; also used in the paths.
        test_dataset: Iterable yielding ``(image_batch, label_batch)`` pairs.
        out_dir: Run output folder.
        batches: Batch size used to pick the position inside the batch.
    """
    sample_dir = out_dir + '/predictions/' + "image_" + str(index)
    # exist_ok lets a partially created sample folder be completed on re-run.
    for subfolder in ('input_image', 'ground_truth', 'prediction', 'prediction_binary'):
        os.makedirs(sample_dir + '/' + subfolder, exist_ok=True)

    # Skip straight to the requested batch instead of stepping manually.
    batch_stream = itertools.islice(itertools.cycle(test_dataset), index, None)
    image_batch, label_batch = next(batch_stream)

    # Honour the `batches` argument (it used to be ignored in favour of a
    # literal 16) and wrap once more so a trailing partial batch cannot be
    # indexed past its end.
    wrapped_index = (index % batches) % int(image_batch.shape[0])
    image = image_batch[wrapped_index].numpy()

    # Min-max scale each of the first three channels to 0..255 independently.
    rgb = image[:, :, :3]
    lowest = rgb.min(axis=(0, 1), keepdims=True)
    span = rgb.max(axis=(0, 1), keepdims=True) - lowest
    # A constant channel would give 0/0 -> NaN; divide by 1 so it renders as 0.
    span[span == 0] = 1
    image_rgb = ((rgb - lowest) * 255.0 / span).astype(np.uint8)

    prediction = model.predict(np.expand_dims(image, axis=0))[0]
    plt.imsave(sample_dir + '/input_image/' + str(index) + '.png', image_rgb)

    ground_truth = label_batch[wrapped_index].numpy()
    plt.imsave(sample_dir + '/ground_truth/' + str(index) + '.png', np.squeeze(ground_truth), cmap='gray')

    plt.imsave(sample_dir + '/prediction/' + str(index) + '.png', np.squeeze(prediction), cmap='gray')

    prediction_binary = np.where(prediction > 0.5, 1, 0)
    plt.imsave(sample_dir + '/prediction_binary/' + str(index) + '.png', np.squeeze(prediction_binary), cmap='gray')


for i in range(80):
    visualize_predictions(i, test_dataset, out_dir)
