# Evaluating green cover and open spaces in informal settlements of Mumbai using deep learning

# Importing Libraries

In [None]:
import pickle
import visualkeras
from tqdm import tqdm
import numpy as np
import pandas as pd
import skimage
from skimage import io
import tifffile as tifi
from skimage.io import imread_collection
import albumentations as A
from IPython.display import SVG
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
from PIL import Image, ImageFont
from collections import defaultdict
import os, re, sys, random, shutil, cv2

from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, precision_score, recall_score, f1_score

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import *
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import applications, optimizers
from tensorflow.keras.applications import VGG16, VGG19, DenseNet121, InceptionResNetV2, ResNet50, MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.keras.utils import model_to_dot, plot_model
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, CSVLogger, LearningRateScheduler, TensorBoard

# Working with the Dataset

In [None]:
# Root of the pre-split dataset; every split directory below lives inside it.
data_dir = "../Kaggle_Data/split_data_v8/"

# Each path keeps its trailing slash because later cells append the split
# sub-folder name ('train', 'val', 'test') by plain string concatenation.
train_images = f"{data_dir}train_images/"
train_masks = f"{data_dir}train_masks/"
val_images = f"{data_dir}val_images/"
val_masks = f"{data_dir}val_masks/"
test_images = f"{data_dir}test_images/"
test_masks = f"{data_dir}test_masks/"

In [None]:
# Report how many files each split holds; image and mask counts should agree.
for split_label, sub_dir, image_dir, mask_dir in (
    ('training', 'train', train_images, train_masks),
    ('validation', 'val', val_images, val_masks),
    ('testing', 'test', test_images, test_masks),
):
    print(f'Number of images in {split_label} set: ', len(os.listdir(image_dir + sub_dir)))
    print(f'Number of masks in {split_label} set: ', len(os.listdir(mask_dir + sub_dir)))

# Satellite Images and Semantic Segmentation Masks (Ground Truths)

In [None]:
files = ['tile_1.0_1', 'tile_1.1_1', 'tile_3.7_4', 'tile_3.7_49', 'tile_4.34_74', 'tile_4.35_21', 'tile_5.18_60', 'tile_5.26_80', 'tile_6.10_45']

def show_data(files, images_dir, masks_dir):
    """Plot satellite tiles next to their ground-truth masks in a 3 x 6 grid.

    Each row holds three (image, mask) pairs, so at most the first nine names in
    ``files`` are shown. Images are read from ``{images_dir}train/<name>.tif``;
    masks from ``{masks_dir}train/<name>.png``, except for the names in
    ``tif_mask_names`` whose masks are stored as ``.tif``. Every panel is
    converted BGR -> RGB and resized to 128 x 128 for display. The figure is
    saved to ``./Output/sample_data`` and shown.

    Raises:
        FileNotFoundError: if an image or mask file cannot be read.
    """
    tif_mask_names = {'image_1', 'image_2', 'image_6'}
    label_font = {'fontsize': 14, 'fontweight': 'medium'}

    def _load_rgb(path):
        # cv2.imread returns None (instead of raising) for missing/unreadable files.
        bgr = cv2.imread(path)
        if bgr is None:
            raise FileNotFoundError(f'Could not read image: {path}')
        return cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), (128, 128))

    fig, axs = plt.subplots(3, 6, figsize=(22, 12), constrained_layout=True)
    sns.set_style("ticks")
    fig.suptitle('Training Set Images & Masks\n', fontsize=18, fontweight='medium')

    # Pair k goes to row k // 3, columns 2*(k % 3) (image) and 2*(k % 3) + 1 (mask).
    for idx, name in enumerate(files[:9]):
        row, col = divmod(idx, 3)
        image_ax = axs[row][2 * col]
        mask_ax = axs[row][2 * col + 1]
        mask_ext = 'tif' if name in tif_mask_names else 'png'

        image_ax.imshow(_load_rgb(f'{images_dir}train/{name}.tif'))
        image_ax.set_title(f'Satellite Image: {name}.tif', fontdict=label_font)
        mask_ax.imshow(_load_rgb(f'{masks_dir}train/{name}.{mask_ext}'))
        mask_ax.set_title(f'Ground Truth: {name}.{mask_ext}', fontdict=label_font)
        for panel in (image_ax, mask_ax):
            panel.grid(False)
            panel.axis(True)

    plt.savefig('./Output/sample_data', facecolor= 'w', transparent= False, bbox_inches= 'tight', dpi= 300)
    plt.show()

show_data(files, train_images, train_masks)

# Reading RGB Color Codes for Labels

In [None]:
# Load the class lookup table; later cells read its name, r, g and b columns
# (one row per segmentation class, row order = class id).
class_dict_df = pd.read_csv(f'{data_dir}class_dict.csv', index_col=False, skipinitialspace=True)
# Bare trailing expression so Jupyter renders the table.
class_dict_df

In [None]:
# Class names and their (r, g, b) mask colours, in CSV row order; the row
# position becomes the class id used by the one-hot encoding.
label_names = list(class_dict_df.name)
r = np.asarray(class_dict_df.r)
g = np.asarray(class_dict_df.g)
b = np.asarray(class_dict_df.b)
label_codes = list(zip(r, g, b))

label_codes, label_names

# Create Useful Label & Code Conversion Dictionaries

These will be used for:

* One hot encoding the mask labels for model training

* Decoding the predicted labels for interpretation and visualization

In [None]:
# Bidirectional lookups: class id <-> RGB colour and class id <-> class name.
id2code = dict(enumerate(label_codes))
code2id = {code: class_id for class_id, code in id2code.items()}

id2name = dict(enumerate(label_names))
name2id = {name: class_id for class_id, name in id2name.items()}

In [None]:
# Display the class id -> RGB colour mapping.
id2code

In [None]:
# Display the class id -> class name mapping.
id2name

# Define Functions for One Hot Encoding RGB Labels & Decoding Encoded Predictions

In [None]:
def rgb_to_onehot(rgb_image, colormap = id2code):
    '''One-hot encode an RGB segmentation mask.

        Inputs:
            rgb_image - mask array whose last axis holds 3 colour channels
                        (eg. 256 x 256 x 3 dimension numpy ndarray)
            colormap - dict mapping class id (0 .. num_classes-1) to its RGB colour
        Output: int8 array of shape (height x width x num_classes), where
                num_classes = len(colormap). Channel k is 1 where the pixel colour
                equals colormap[k] and 0 elsewhere; a pixel matching no colour is
                all zeros.
    '''
    num_classes = len(colormap)
    height_width = rgb_image.shape[:2]
    encoded_image = np.zeros(height_width + (num_classes,), dtype=np.int8)
    # Flatten once; each class test is then a single vectorised colour compare.
    pixels = rgb_image.reshape((-1, 3))
    for class_id, color in colormap.items():
        encoded_image[:, :, class_id] = np.all(pixels == color, axis=1).reshape(height_width)
    return encoded_image

def onehot_to_rgb(onehot, colormap = id2code):
    '''Decode a per-pixel one-hot / class-score array into an RGB mask.

        Inputs:
            onehot - array of shape (height x width x num_classes); the class with
                     the highest value at each pixel wins (argmax over the last axis)
            colormap - dict mapping class id to its RGB colour
        Output: uint8 array of shape (height x width x 3). Pixels whose winning
                class id is not a key of colormap stay (0, 0, 0).
    '''
    class_map = np.argmax(onehot, axis=-1)
    rgb = np.zeros(onehot.shape[:2] + (3,))
    for class_id, color in colormap.items():
        rgb[class_map == class_id] = color
    return rgb.astype(np.uint8)

# Creating Custom Image Data Generators

## Defining Data Generators

In [None]:
# Only the satellite frames are rescaled to [0, 1]; masks keep their raw RGB
# values because those colours encode the class labels.
data_gen_args = dict(rescale=1./255)
mask_gen_args = dict()

train_frames_datagen, val_frames_datagen, test_frames_datagen = (
    ImageDataGenerator(**data_gen_args) for _ in range(3)
)
train_masks_datagen, val_masks_datagen, test_masks_datagen = (
    ImageDataGenerator(**mask_gen_args) for _ in range(3)
)

# Shared seed so the image and mask iterators shuffle in the same order.
seed = 1

# Custom Image Data Generators for Creating Batches of Frames and Masks

In [None]:
def TrainAugmentGenerator(train_images_dir, train_masks_dir, seed = 1, batch_size = 8, target_size = (512, 512)):
    '''Endless generator of (image batch, one-hot mask batch) pairs for training.

        Despite the name, no random augmentation happens here: frames go through
        train_frames_datagen (rescale only) and masks through train_masks_datagen
        (no transforms).

        Inputs:
            train_images_dir - directory handed to flow_from_directory for the images
                               (the files sit in a sub-folder such as 'train')
            train_masks_dir - same layout for the RGB masks
            seed - seed given to both flow_from_directory calls so images and masks
                   are shuffled in the same order
            batch_size - number of samples per batch
            target_size - tuple of integers (height, width) every file is resized to

        Yields: (images, masks) - the frame batch, and an int8 array of shape
                (batch, height, width, len(id2code)) built with rgb_to_onehot.
                NOTE(review): pairing assumes both directories hold files with
                matching names so the seeded shuffles line up - confirm.
    '''
    train_image_generator = train_frames_datagen.flow_from_directory(
        train_images_dir,
        batch_size = batch_size,
        seed = seed,
        target_size = target_size)

    train_mask_generator = train_masks_datagen.flow_from_directory(
        train_masks_dir,
        batch_size = batch_size,
        seed = seed,
        target_size = target_size)

    while True:
        # Each step returns (data, folder-class labels); only the data is used.
        images = next(train_image_generator)[0]
        rgb_masks = next(train_mask_generator)[0]

        # One-hot encode every RGB mask in the batch.
        mask_encoded = [rgb_to_onehot(rgb_mask, id2code) for rgb_mask in rgb_masks]

        yield images, np.asarray(mask_encoded)

def ValAugmentGenerator(val_images_dir, val_masks_dir, seed = 1, batch_size = 8, target_size = (512, 512)):
    '''Endless generator of (image batch, one-hot mask batch) pairs for validation.

        No augmentation is applied: frames go through val_frames_datagen (rescale
        only) and masks through val_masks_datagen (no transforms).

        Inputs:
            val_images_dir - directory handed to flow_from_directory for the images
                             (the files sit in a sub-folder such as 'val')
            val_masks_dir - same layout for the RGB masks
            seed - seed given to both flow_from_directory calls so images and masks
                   are shuffled in the same order
            batch_size - number of samples per batch
            target_size - tuple of integers (height, width) every file is resized to

        Yields: (images, masks) - the frame batch, and an int8 array of shape
                (batch, height, width, len(id2code)) built with rgb_to_onehot.
                NOTE(review): pairing assumes both directories hold files with
                matching names so the seeded shuffles line up - confirm.
    '''
    val_image_generator = val_frames_datagen.flow_from_directory(
        val_images_dir,
        batch_size = batch_size,
        seed = seed,
        target_size = target_size)

    val_mask_generator = val_masks_datagen.flow_from_directory(
        val_masks_dir,
        batch_size = batch_size,
        seed = seed,
        target_size = target_size)

    while True:
        # Each step returns (data, folder-class labels); only the data is used.
        images = next(val_image_generator)[0]
        rgb_masks = next(val_mask_generator)[0]

        # One-hot encode every RGB mask in the batch.
        mask_encoded = [rgb_to_onehot(rgb_mask, id2code) for rgb_mask in rgb_masks]

        yield images, np.asarray(mask_encoded)
        
def TestAugmentGenerator(test_images_dir, test_masks_dir, seed = 1, batch_size = 8, target_size = (512, 512)):
    '''Endless generator of (image batch, one-hot mask batch) pairs for testing.

        No augmentation is applied: frames go through test_frames_datagen (rescale
        only) and masks through test_masks_datagen (no transforms).

        Inputs:
            test_images_dir - directory handed to flow_from_directory for the images
                              (the files sit in a sub-folder such as 'test')
            test_masks_dir - same layout for the RGB masks
            seed - seed given to both flow_from_directory calls so images and masks
                   are shuffled in the same order
            batch_size - number of samples per batch
            target_size - tuple of integers (height, width) every file is resized to

        Yields: (images, masks) - the frame batch, and an int8 array of shape
                (batch, height, width, len(id2code)) built with rgb_to_onehot.
                NOTE(review): pairing assumes both directories hold files with
                matching names so the seeded shuffles line up - confirm.
    '''
    test_image_generator = test_frames_datagen.flow_from_directory(
        test_images_dir,
        batch_size = batch_size,
        seed = seed,
        target_size = target_size)

    test_mask_generator = test_masks_datagen.flow_from_directory(
        test_masks_dir,
        batch_size = batch_size,
        seed = seed,
        target_size = target_size)

    while True:
        # Each step returns (data, folder-class labels); only the data is used.
        images = next(test_image_generator)[0]
        rgb_masks = next(test_mask_generator)[0]

        # One-hot encode every RGB mask in the batch.
        mask_encoded = [rgb_to_onehot(rgb_mask, id2code) for rgb_mask in rgb_masks]

        yield images, np.asarray(mask_encoded)

# 1. DeepLabV3+ Model

In [None]:
# Batch size the step counts below are computed for; the generators feeding
# fit() must be created with this same batch size.
batch_size = 32
num_train_samples = len(os.listdir(train_images+'train'))
num_val_samples = len(os.listdir(val_images+'val'))
# Integer ceiling division (a final partial batch still counts as a step); Keras
# expects integer step counts, not numpy floats.
steps_per_epoch = -(-num_train_samples // batch_size)
print('steps_per_epoch: ', steps_per_epoch)
validation_steps = -(-num_val_samples // batch_size)
print('validation_steps: ', validation_steps)

In [None]:
def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient pooled over every pixel and class in the batch.

    Computes (2 * sum(y_true * y_pred) + 1) / (sum(y_true) + sum(y_pred) + 1);
    the +1 smoothing keeps the ratio defined (equal to 1) when both inputs are
    all zero.
    """
    smooth = 1.
    overlap = K.sum(y_true * y_pred)
    total = K.sum(y_true) + K.sum(y_pred)
    return (2. * overlap + smooth) / (total + smooth)

In [None]:
def class_accuracy(confusion: np.ndarray) -> np.ndarray:
    """
    Return the per class accuracy (recall) from a confusion matrix.
    Args:
        confusion: square confusion matrix (array-like) with ground-truth classes
            on the rows and predicted classes on the columns
    Returns:
        a float vector of correct predictions / ground-truth count for each class;
        a class with no ground-truth entries gets NaN
    """
    confusion = np.asarray(confusion)
    # correct guesses are the diagonal entries
    preds_correct = np.diag(confusion).astype(float)
    # total ground-truth entries per class (row sums)
    trues = confusion.sum(axis=-1)
    # 0/0 for classes absent from the ground truth -> NaN, without a RuntimeWarning
    with np.errstate(divide='ignore', invalid='ignore'):
        return preds_correct / trues

def iou(confusion: np.ndarray) -> np.ndarray:
    """
    Return the per class Intersection over Union (I/U) from a confusion matrix.
    Args:
        confusion: square confusion matrix (array-like) with ground-truth classes
            on the rows and predicted classes on the columns
    Returns:
        a float vector of per class I/U; a class that appears in neither the
        ground truth nor the predictions (empty union) gets NaN
    Reference:
        https://en.wikipedia.org/wiki/Jaccard_index
    """
    confusion = np.asarray(confusion)
    # |intersection| (AND) is the diagonal of the confusion matrix
    intersection = np.diag(confusion).astype(float)
    # total predictions (column sums) and ground truths (row sums) per class
    preds = confusion.sum(axis=0)
    trues = confusion.sum(axis=-1)
    # |union| (OR) = |truth| + |prediction| - |intersection|
    union = trues + preds - intersection
    # 0/0 for empty unions -> NaN, without a RuntimeWarning
    with np.errstate(divide='ignore', invalid='ignore'):
        return intersection / union

In [None]:
def ASPP(inputs):
    """Atrous Spatial Pyramid Pooling head used by DeepLabV3Plus.

    Runs five parallel 256-filter branches over ``inputs``: global average pooling
    (1x1 conv, then bilinear upsampling back to the input's spatial size), a 1x1
    conv, and 3x3 convs with dilation rates 6, 12 and 18. Each branch gets
    BatchNormalization + ReLU. The branches are concatenated and fused by a
    1x1 conv + BN + ReLU.

    ``inputs`` needs a static height/width (shape[1], shape[2]) because they size
    the pooling and upsampling windows. The pooling branch uses fixed layer names
    ('average_pooling', 'bn_1', 'relu_1'), so applying ASPP twice in one model
    would produce duplicate layer names.
    """
    shape = inputs.shape
    pool_size = (shape[1], shape[2])

    # Image-level branch: pool to 1x1, project, then stretch back to the input size.
    image_pool = AveragePooling2D(pool_size=pool_size, name='average_pooling')(inputs)
    image_pool = Conv2D(filters=256, kernel_size=1, padding='same', use_bias=False)(image_pool)
    image_pool = BatchNormalization(name=f'bn_1')(image_pool)
    image_pool = Activation('relu', name=f'relu_1')(image_pool)
    image_pool = UpSampling2D(pool_size, interpolation="bilinear")(image_pool)

    # Convolutional branches; creation order (1x1, then rates 6, 12, 18) fixes the
    # auto-generated layer names, so keep it stable.
    branches = [image_pool]
    for kernel_size, rate in ((1, 1), (3, 6), (3, 12), (3, 18)):
        branch = Conv2D(filters=256, kernel_size=kernel_size, dilation_rate=rate, padding='same', use_bias=False)(inputs)
        branch = BatchNormalization()(branch)
        branch = Activation('relu')(branch)
        branches.append(branch)

    fused = Concatenate()(branches)
    fused = Conv2D(filters=256, kernel_size=1, dilation_rate=1, padding='same', use_bias=False)(fused)
    fused = BatchNormalization()(fused)
    return Activation('relu')(fused)

def DeepLabV3Plus(shape, num_classes):
    """Build a DeepLabV3+ segmentation model on an ImageNet-pretrained ResNet50.

    High-level features from 'conv4_block6_out' pass through ASPP and are
    upsampled 4x. Low-level features from 'conv2_block2_out' are projected to
    48 channels. The two are concatenated, refined by two 3x3 conv + BN + ReLU
    blocks, upsampled 4x again, and mapped to per-pixel class probabilities by
    Dropout + a 1x1 conv + softmax.

    Args:
        shape: input shape, e.g. (128, 128, 3). The fixed 4x upsamplings assume
            conv4_block6_out is at 1/16 and conv2_block2_out at 1/4 of the input
            resolution (true for keras ResNet50; verify if the backbone changes).
        num_classes: number of output channels.
    Returns:
        An uncompiled keras Model.
    """
    inputs = Input(shape)

    # Pre-trained ResNet50 backbone built on our input tensor.
    base_model = ResNet50(weights='imagenet', include_top=False, input_tensor=inputs)

    # High-level features -> ASPP -> 4x upsampling to the low-level resolution.
    image_features = base_model.get_layer('conv4_block6_out').output
    x_a = ASPP(image_features)
    x_a = UpSampling2D((4, 4), interpolation="bilinear")(x_a)

    # Low-level features projected to 48 channels.
    x_b = base_model.get_layer('conv2_block2_out').output
    x_b = Conv2D(filters=48, kernel_size=1, padding='same', use_bias=False)(x_b)
    x_b = BatchNormalization()(x_b)
    x_b = Activation('relu')(x_b)

    x = Concatenate()([x_a, x_b])

    # Two conv -> BN -> ReLU refinement blocks. The convolutions are linear; the
    # ReLU belongs after BatchNormalization, as in every other block of this model
    # (an activation inside Conv2D would apply ReLU both before and after BN).
    x = Conv2D(filters=256, kernel_size=3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(filters=256, kernel_size=3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = UpSampling2D((4, 4), interpolation="bilinear")(x)

    # Per-pixel class probabilities.
    x = Dropout(0.2)(x)
    x = Conv2D(num_classes, (1, 1), name='output_layer', padding="same")(x)
    x = Activation('softmax')(x)

    model = Model(inputs=inputs, outputs=x)
    return model

In [None]:
K.clear_session()

# Categorical cross-entropy matches the one-hot masks produced by rgb_to_onehot.
loss = tf.keras.losses.CategoricalCrossentropy()
# One output channel per entry of the colour map used to one-hot encode the masks,
# so the model's output depth always matches the encoded targets.
deeplabv3plus = DeepLabV3Plus(shape = (128, 128, 3), num_classes = len(id2code))
deeplabv3plus.compile(optimizer=Adam(learning_rate = 0.0001), loss=loss, metrics=[dice_coef, "accuracy"])
deeplabv3plus.summary()

In [None]:
# Layer diagram for DeepLabV3+: Dropout layers are filled grey.
color_map = defaultdict(dict, {Dropout: {'fill': 'gray'}})
font = ImageFont.truetype("./fonts/OpenSans-Semibold.ttf", 32)
visualkeras.layered_view(
    deeplabv3plus,
    legend=True,
    font=font,
    to_file='./Output/deeplabv3plus_unet_architecture.png',
    color_map=color_map,
    draw_volume=False,
)

# DeepLabV3+ Model Training

In [None]:
def exponential_decay(lr0, s):
    """Return an epoch -> learning-rate schedule that decays lr0 tenfold every s epochs.

    The returned function computes lr0 * 0.1 ** (epoch / s): lr0 at epoch 0,
    lr0 / 10 at epoch s, lr0 / 100 at epoch 2 * s, and so on.
    """
    def exponential_decay_fn(epoch):
        decades = epoch / s
        return lr0 * 0.1 ** decades
    return exponential_decay_fn

# Learning rate starts at 1e-4 and shrinks tenfold over 40 epochs.
exponential_decay_fn = exponential_decay(0.0001, 40)
lr_scheduler = LearningRateScheduler(exponential_decay_fn, verbose=1)

# Keep only the checkpoint with the lowest validation loss seen so far.
checkpoint = ModelCheckpoint(
    filepath='./Output/deeplabv3plus.h5',
    save_best_only=True,
    monitor='val_loss',
    mode='auto',
    verbose=1,
)

# Per-epoch metrics log, overwritten on every run.
csvlogger = CSVLogger(filename="./Output/deeplabv3plus_training.csv", separator=",", append=False)

# Early stopping is currently disabled; the previous configuration was
# EarlyStopping(monitor='val_loss', min_delta=0.001, patience=10, mode='auto',
#               verbose=1, restore_best_weights=True).
callbacks = [checkpoint, csvlogger, lr_scheduler]

In [None]:
# batch_size is passed explicitly: steps_per_epoch / validation_steps were
# computed for `batch_size`, while the generators default to batches of 8, which
# would make every "epoch" cover only a fraction of the data.
history_1 = deeplabv3plus.fit(
    TrainAugmentGenerator(train_images_dir = train_images, train_masks_dir = train_masks, batch_size = batch_size, target_size = (128, 128)),
    steps_per_epoch = steps_per_epoch,
    validation_data = ValAugmentGenerator(val_images_dir = val_images, val_masks_dir = val_masks, batch_size = batch_size, target_size = (128, 128)),
    validation_steps = validation_steps,
    epochs = 40,
    callbacks = callbacks,
    use_multiprocessing = False,
    verbose = 1
)

In [None]:
# Persist the per-epoch metrics so the curves can be re-plotted without retraining.
history_path = './Output/trainHistoryDict_deeplabv3plus'
with open(history_path, 'wb') as history_file:
    pickle.dump(history_1.history, history_file)

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(20, 12))
ax = ax.ravel()
sns.set_style("ticks")
metrics = ['Dice Coefficient', 'Accuracy', 'Loss', 'Learning Rate']
history_keys = ['dice_coef', 'accuracy', 'loss', 'lr']

# One panel per metric: train vs. validation curves, except the learning rate,
# which has no validation counterpart.
for panel, key, title in zip(ax, history_keys, metrics):
    has_validation = key != 'lr'
    panel.plot(history_1.history[key], '-')
    if has_validation:
        panel.plot(history_1.history['val_' + key], '-')
    panel.set_title('{} vs Epochs'.format(title), fontsize=16)
    panel.set_xlabel('Epochs', fontsize=12)
    panel.set_ylabel(title, fontsize=12)
    if has_validation:
        panel.legend(['Train', 'Validation'])
    panel.xaxis.grid(True, color = "lightgray", linewidth = "0.8", linestyle = "-")
    panel.yaxis.grid(True, color = "lightgray", linewidth = "0.8", linestyle = "-")

plt.savefig('./Output/deeplabv3plus_metrics_plot.png', facecolor= 'w',transparent= False, bbox_inches= 'tight', dpi= 300)

# Evaluating DeepLabV3+ Model on Test Set

In [None]:
# Test batches of 32 tiles at 128 x 128, matching the model input size.
testing_gen = TestAugmentGenerator(test_images, test_masks, batch_size=32, target_size=(128, 128))

In [None]:
# Restore the best checkpoint (lowest val_loss) before scoring the test split.
deeplabv3plus.load_weights("./Output/deeplabv3plus.h5")
# NOTE(review): 21 batches of 32 presumably covers the whole test split; confirm
# against the test image count.
deeplabv3plus_eval = deeplabv3plus.evaluate(
    testing_gen,
    steps=21,
    return_dict=True,
)

In [None]:
# One (initially empty) result dict per model, filled after each evaluation.
scores = {model_name: {} for model_name in ('deeplabv3plus', 'vgg16_unet', 'mobilenetv2_unet')}

In [None]:
# Record the test-set metrics returned by evaluate(return_dict=True).
scores.update(deeplabv3plus=deeplabv3plus_eval)

In [None]:
# Collect per-pixel ground-truth and predicted class ids over 21 test batches.
# Chunks are gathered in lists and concatenated once; calling np.append per image
# re-copies the whole accumulated array each time (quadratic in the pixel count).
# The resulting vectors hold integer class ids.
true_chunks, pred_chunks = [], []
count = 0

for i in range(21):
    batch_img, batch_mask = next(testing_gen)
    pred_all = deeplabv3plus.predict(batch_img)
    count += len(pred_all)

    # argmax over the class axis turns one-hot masks / softmax outputs into ids;
    # ravel keeps image-major, row-major pixel order.
    true_chunks.append(np.argmax(batch_mask, axis=-1).ravel())
    pred_chunks.append(np.argmax(pred_all, axis=-1).ravel())

Y_true_all_1 = np.concatenate(true_chunks)
y_pred_all_1 = np.concatenate(pred_chunks)

In [None]:
# Sanity check: both vectors hold one class id per test pixel and must match in length.
print(Y_true_all_1.shape, y_pred_all_1.shape)

In [None]:
# Per-class precision / recall / F1 / support over all test pixels (DeepLabV3+).
print(classification_report(Y_true_all_1, y_pred_all_1))

In [None]:
# Pixel accuracy plus support-weighted precision / recall / F1 for DeepLabV3+.
print('Accuracy:', accuracy_score(Y_true_all_1, y_pred_all_1))
for title, scorer in (('Precision:', precision_score), ('Recall:', recall_score), ('F1 Score:', f1_score)):
    print(title, scorer(Y_true_all_1, y_pred_all_1, average='weighted'))

In [None]:
# Fix the label set to every class id so cm_1 is always
# len(label_names) x len(label_names) and lines up with label_names, even when a
# class never occurs in the test pixels (sklearn otherwise drops absent labels and
# the DataFrame below fails or mislabels classes).
cm_1 = confusion_matrix(
    Y_true_all_1.astype(int),
    y_pred_all_1.astype(int),
    labels=list(range(len(label_names))),
)
df_cm_1 = pd.DataFrame(cm_1, label_names, label_names)
fig, ax = plt.subplots(figsize=(14,12))
sns.heatmap(df_cm_1, annot=True, annot_kws={"size": 16}, cmap=plt.cm.YlGnBu)
plt.title('Confusion Matrix for DeepLabV3+ CNN\n', fontsize=16)
plt.savefig('./Output/confusion_matrix_1.png', transparent= False, bbox_inches= 'tight', dpi= 300)
plt.show()

In [None]:
# Per-class Intersection over Union from the DeepLabV3+ confusion matrix.
per_class_iou_1 = iou(cm_1)
per_class_iou_1

In [None]:
# Per-class pixel accuracy: correct pixels / ground-truth pixels of each class.
per_class_acc_1 = class_accuracy(cm_1)
per_class_acc_1

In [None]:
# Per-class accuracy and F1 table for the first six classes.
# NOTE(review): the [:6] slices drop every class after the sixth; presumably the
# last class is excluded on purpose - confirm.
deeplabv3plus_class_acc = pd.DataFrame(zip(label_names, per_class_acc_1[:6]), columns=['Class', 'Accuracy'])
# labels= keeps the F1 vector indexed by class id (aligned with label_names) even
# if a class is absent from both ground truth and predictions.
deeplabv3plus_class_acc['F1'] = f1_score(
    Y_true_all_1.astype(int),
    y_pred_all_1.astype(int),
    labels=list(range(len(label_names))),
    average=None,
)[:6]
deeplabv3plus_class_acc

In [None]:
# Bar chart of per-class pixel accuracy for DeepLabV3+.
plt.figure(figsize=(12, 7))
plt.title('Class-wise Accuracy of DeepLabV3+ CNN', fontsize=16)
sns.set_style("ticks")
sns.barplot(data=deeplabv3plus_class_acc, x="Class", y="Accuracy", palette='turbo', alpha=0.8)
plt.grid(axis='y', color='lightgray', linestyle='-', linewidth=0.8)
plt.savefig('./Output/deeplabv3plus_class_acc', facecolor='w', transparent=False, bbox_inches='tight', dpi=300)
plt.show()

# Predictions on Test Set Using DeepLabV3+ Model

In [None]:
# Output folder for the per-image prediction figures; exist_ok=True keeps this
# idempotent when the notebook is re-run, and it works outside IPython.
os.makedirs('./Output/deeplabv3plus_pred', exist_ok=True)

In [None]:
# Plot input / ground truth / prediction side by side for 5 test batches and save
# each figure to ./Output/deeplabv3plus_pred/deeplabv3plus_pred_<n>.png.
os.makedirs('./Output/deeplabv3plus_pred', exist_ok=True)
tick_positions = np.arange(0, 129, 16)  # ticks every 16 px across the 128 px tiles
title_font = {'fontsize': 16, 'fontweight': 'medium'}
count = 0
for i in range(5):
    batch_img, batch_mask = next(testing_gen)
    pred_all = deeplabv3plus.predict(batch_img)

    for j in range(len(pred_all)):
        count += 1
        fig = plt.figure(figsize=(20,8))
        panels = (
            ('Input Image', batch_img[j]),
            ('Ground Truth Mask', onehot_to_rgb(batch_mask[j], id2code)),
            ('Predicted Mask', onehot_to_rgb(pred_all[j], id2code)),
        )
        for position, (title, picture) in enumerate(panels, start=1):
            panel = fig.add_subplot(1, 3, position)
            panel.imshow(picture)
            panel.set_title(title, fontdict=title_font)
            panel.set_xticks(tick_positions)
            panel.set_yticks(tick_positions)
            panel.grid(False)

        plt.savefig('./Output/deeplabv3plus_pred/deeplabv3plus_pred_{}.png'.format(count), facecolor= 'w', transparent= False, bbox_inches= 'tight', dpi= 200)
        plt.show()
        # Close explicitly so figures don't pile up (matplotlib warns past 20 open).
        plt.close(fig)

# 2. VGG16 Encoder Based UNet

In [None]:
def conv_block(input, num_filters):
    """Apply two stacked 3x3 Conv2D ('same' padding) -> BatchNormalization -> ReLU stages.

    Both convolutions use ``num_filters`` filters; the spatial size is unchanged.
    """
    x = input
    for _ in range(2):
        x = Conv2D(num_filters, 3, padding="same")(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
    return x

def decoder_block(input, skip_features, num_filters):
    """U-Net decoder stage: 2x transposed-conv upsampling, concat with the skip tensor, conv_block.

    ``skip_features`` must have the same height/width as ``input`` after the 2x
    upsampling for the concatenation to succeed.
    """
    upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input)
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)

def build_vgg16_unet(input_shape, num_classes):
    """Build a U-Net whose encoder is an ImageNet-pretrained VGG16.

    Skip connections come from block1_conv2, block2_conv2, block3_conv3 and
    block4_conv3; block5_conv3 is the bridge. Four decoder_block stages
    (512, 256, 128, 64 filters) upsample back to the block1_conv2 resolution,
    then Dropout(0.4) and a 1x1 softmax conv give per-pixel class probabilities.

    Args:
        input_shape: input shape, e.g. (128, 128, 3).
        num_classes: number of output channels.
    Returns:
        An uncompiled keras Model named "VGG16_UNet".
    """
    inputs = Input(input_shape)

    # Pre-trained VGG16 encoder built on our input tensor.
    vgg16 = VGG16(include_top=False, weights="imagenet", input_tensor=inputs)

    # Encoder skip features.
    s1 = vgg16.get_layer("block1_conv2").output
    s2 = vgg16.get_layer("block2_conv2").output
    s3 = vgg16.get_layer("block3_conv3").output
    s4 = vgg16.get_layer("block4_conv3").output

    # Bridge.
    b1 = vgg16.get_layer("block5_conv3").output

    # Decoder.
    d1 = decoder_block(b1, s4, 512)
    d2 = decoder_block(d1, s3, 256)
    d3 = decoder_block(d2, s2, 128)
    d4 = decoder_block(d3, s1, 64)

    # The classifier must consume the dropout output; otherwise the Dropout layer
    # is disconnected from the graph and has no effect.
    x1 = Dropout(0.4)(d4)

    outputs = Conv2D(num_classes, 1, padding="same", activation="softmax")(x1)

    model = Model(inputs, outputs, name="VGG16_UNet")
    return model

In [None]:
K.clear_session()

# Categorical cross-entropy matches the one-hot masks produced by rgb_to_onehot.
loss = tf.keras.losses.CategoricalCrossentropy()
# Output depth follows the colour map used to one-hot encode the masks.
vgg16_unet = build_vgg16_unet(input_shape = (128, 128, 3), num_classes = len(id2code))
vgg16_unet.compile(optimizer=Adam(learning_rate = 0.0001), loss=loss, metrics=[dice_coef, "accuracy"])
vgg16_unet.summary()

In [None]:
# Layer diagram for the VGG16-UNet: Dropout layers are filled grey.
color_map = defaultdict(dict, {Dropout: {'fill': 'gray'}})
font = ImageFont.truetype("./fonts/OpenSans-Semibold.ttf", 32)
visualkeras.layered_view(
    vgg16_unet,
    legend=True,
    font=font,
    to_file='./Output/vgg16_unet_architecture.png',
    color_map=color_map,
    draw_volume=False,
)

# VGG16 UNet Model Training

In [None]:
def exponential_decay(lr0, s):
    """Build a Keras-style schedule: epoch -> lr0 * 0.1 ** (epoch / s).

    The learning rate falls by a factor of 10 every ``s`` epochs, starting at ``lr0``.
    """
    def exponential_decay_fn(epoch):
        return lr0 * pow(0.1, epoch / s)
    return exponential_decay_fn

# Learning rate starts at 1e-4 and shrinks tenfold over 40 epochs.
exponential_decay_fn = exponential_decay(0.0001, 40)
lr_scheduler = LearningRateScheduler(exponential_decay_fn, verbose=1)

# Keep only the checkpoint with the lowest validation loss seen so far.
checkpoint = ModelCheckpoint(
    filepath='./Output/vgg16_unet.h5',
    save_best_only=True,
    monitor='val_loss',
    mode='auto',
    verbose=1,
)

# Per-epoch metrics log, overwritten on every run.
csvlogger = CSVLogger(filename="./Output/vgg16_unet_training.csv", separator=",", append=False)

# Early stopping is currently disabled; the previous configuration was
# EarlyStopping(monitor='val_loss', min_delta=0.001, patience=10, mode='auto',
#               verbose=1, restore_best_weights=True).
callbacks = [checkpoint, csvlogger, lr_scheduler]

In [None]:
# batch_size is passed explicitly: steps_per_epoch / validation_steps were
# computed for `batch_size`, while the generators default to batches of 8, which
# would make every "epoch" cover only a fraction of the data.
history_2 = vgg16_unet.fit(
    TrainAugmentGenerator(train_images_dir = train_images, train_masks_dir = train_masks, batch_size = batch_size, target_size = (128, 128)),
    steps_per_epoch = steps_per_epoch,
    validation_data = ValAugmentGenerator(val_images_dir = val_images, val_masks_dir = val_masks, batch_size = batch_size, target_size = (128, 128)),
    validation_steps = validation_steps,
    epochs = 40,
    callbacks = callbacks,
    use_multiprocessing = False,
    verbose = 1
)

In [None]:
# Persist the per-epoch metrics so the curves can be re-plotted without retraining.
history_path = './Output/trainHistoryDict_vgg16_unet'
with open(history_path, 'wb') as history_file:
    pickle.dump(history_2.history, history_file)

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(20, 12))
ax = ax.ravel()
sns.set_style("ticks")
metrics = ['Dice Coefficient', 'Accuracy', 'Loss', 'Learning Rate']
grid_style = dict(color = "lightgray", linewidth = "0.8", linestyle = "-")

# One panel per metric; the learning rate has no validation series or legend.
for i, met in enumerate(['dice_coef', 'accuracy', 'loss', 'lr']):
    axis = ax[i]
    axis.plot(history_2.history[met], '-')
    if met != 'lr':
        axis.plot(history_2.history['val_' + met], '-')
    axis.set_title(f'{metrics[i]} vs Epochs', fontsize=16)
    axis.set_xlabel('Epochs', fontsize=12)
    axis.set_ylabel(metrics[i], fontsize=12)
    if met != 'lr':
        axis.legend(['Train', 'Validation'])
    axis.xaxis.grid(True, **grid_style)
    axis.yaxis.grid(True, **grid_style)

plt.savefig('./Output/vgg16_unet_metrics_plot.png', facecolor= 'w',transparent= False, bbox_inches= 'tight', dpi= 300)

# Evaluating VGG16-UNet Model on Test Set

In [None]:
# Fresh test iterator for the VGG16-UNet: batches of 32 tiles at 128 x 128.
testing_gen = TestAugmentGenerator(test_images, test_masks, batch_size=32, target_size=(128, 128))

In [None]:
# Restore the best checkpoint (lowest val_loss) before scoring the test split.
vgg16_unet.load_weights("./Output/vgg16_unet.h5")
# NOTE(review): 21 batches of 32 presumably covers the whole test split; confirm.
vgg16_unet_eval = vgg16_unet.evaluate(
    testing_gen,
    steps=21,
    return_dict=True,
)

In [None]:
# Record the test-set metrics returned by evaluate(return_dict=True).
scores.update(vgg16_unet=vgg16_unet_eval)

In [None]:
# Collect per-pixel ground-truth and predicted class ids over 21 test batches.
# Chunks are gathered in lists and concatenated once; calling np.append per image
# re-copies the whole accumulated array each time (quadratic in the pixel count).
# The resulting vectors hold integer class ids.
true_chunks, pred_chunks = [], []
count = 0

for i in range(21):
    batch_img, batch_mask = next(testing_gen)
    pred_all = vgg16_unet.predict(batch_img)
    count += len(pred_all)

    # argmax over the class axis turns one-hot masks / softmax outputs into ids;
    # ravel keeps image-major, row-major pixel order.
    true_chunks.append(np.argmax(batch_mask, axis=-1).ravel())
    pred_chunks.append(np.argmax(pred_all, axis=-1).ravel())

Y_true_all_2 = np.concatenate(true_chunks)
y_pred_all_2 = np.concatenate(pred_chunks)

In [None]:
# Sanity check: both vectors hold one class id per test pixel and must match in length.
print(Y_true_all_2.shape, y_pred_all_2.shape)

In [None]:
# Per-class precision / recall / F1 / support over all test pixels (VGG16-UNet).
print(classification_report(Y_true_all_2, y_pred_all_2))

In [None]:
# Pixel accuracy plus support-weighted precision / recall / F1 for VGG16-UNet.
weighted_scorers = {'Precision:': precision_score, 'Recall:': recall_score, 'F1 Score:': f1_score}
print('Accuracy:', accuracy_score(Y_true_all_2, y_pred_all_2))
for label, scorer in weighted_scorers.items():
    print(label, scorer(Y_true_all_2, y_pred_all_2, average='weighted'))

In [None]:
# Unweighted F1 per class, one value per class id found in the true/predicted labels.
print('Class-wise F1 Score\n', f1_score(Y_true_all_2, y_pred_all_2, average=None))

In [None]:
# Fix the label set to every class id so cm_2 is always
# len(label_names) x len(label_names) and lines up with label_names, even when a
# class never occurs in the test pixels (sklearn otherwise drops absent labels and
# the DataFrame below fails or mislabels classes).
cm_2 = confusion_matrix(
    Y_true_all_2.astype(int),
    y_pred_all_2.astype(int),
    labels=list(range(len(label_names))),
)
df_cm_2 = pd.DataFrame(cm_2, label_names, label_names)
fig, ax = plt.subplots(figsize=(14,12))
sns.heatmap(df_cm_2, annot=True, annot_kws={"size": 16}, cmap=plt.cm.YlGnBu)
plt.title('Confusion Matrix for VGG16-UNet CNN\n', fontsize=16)
plt.savefig('./Output/confusion_matrix_2.png', transparent= False, bbox_inches= 'tight', dpi= 300)
plt.show()

In [None]:
# Per-class pixel accuracy for VGG16-UNet: correct pixels / ground-truth pixels.
per_class_acc_2 = class_accuracy(cm_2)
per_class_acc_2

In [None]:
# Per-class accuracy and F1 table for the first six classes.
# NOTE(review): the [:6] slices drop every class after the sixth; presumably the
# last class is excluded on purpose - confirm.
vgg16_unet_class_acc = pd.DataFrame(zip(label_names, per_class_acc_2[:6]), columns=['Class', 'Accuracy'])
# labels= keeps the F1 vector indexed by class id (aligned with label_names) even
# if a class is absent from both ground truth and predictions.
vgg16_unet_class_acc['F1'] = f1_score(
    Y_true_all_2.astype(int),
    y_pred_all_2.astype(int),
    labels=list(range(len(label_names))),
    average=None,
)[:6]
vgg16_unet_class_acc

In [None]:
# Bar chart of per-class pixel accuracy for VGG16-UNet.
plt.figure(figsize=(12, 7))
plt.title('Class-wise Accuracy of VGG16-UNet CNN', fontsize=16)
sns.set_style("ticks")
sns.barplot(data=vgg16_unet_class_acc, x="Class", y="Accuracy", palette='turbo', alpha=0.8)
plt.grid(axis='y', color='lightgray', linestyle='-', linewidth=0.8)
plt.savefig('./Output/vgg16_unet_class_acc', facecolor='w', transparent=False, bbox_inches='tight', dpi=300)
plt.show()

# Predictions on Test Set Using VGG16 UNet Model

In [None]:
# Output folder for the exported prediction images; exist_ok=True keeps this
# idempotent when the notebook is re-run, and it works outside IPython.
os.makedirs('./Output/vgg16_unet_pred', exist_ok=True)

In [None]:
# For 5 test batches: show input / ground truth / prediction side by side and
# export each panel as its own PNG under ./Output/vgg16_unet_pred/
# (<n>_img.png, <n>_gt.png, <n>_pred.png).
os.makedirs('./Output/vgg16_unet_pred', exist_ok=True)
tick_positions = np.arange(0, 129, 16)  # ticks every 16 px across the 128 px tiles
title_font = {'fontsize': 16, 'fontweight': 'medium'}
count = 0
for i in range(5):
    batch_img, batch_mask = next(testing_gen)
    pred_all = vgg16_unet.predict(batch_img)

    for j in range(len(pred_all)):
        count += 1
        fig = plt.figure(figsize=(20,8))

        # Decode each mask once; the RGB result is both plotted and written to disk.
        gt_rgb = onehot_to_rgb(batch_mask[j], id2code)
        pred_rgb = onehot_to_rgb(pred_all[j], id2code)

        plt.imsave(f'./Output/vgg16_unet_pred/{count}_img.png', batch_img[j])
        # cv2.imwrite expects BGR channel order, so convert from RGB before writing.
        cv2.imwrite(f'./Output/vgg16_unet_pred/{count}_gt.png', cv2.cvtColor(gt_rgb, cv2.COLOR_RGB2BGR))
        cv2.imwrite(f'./Output/vgg16_unet_pred/{count}_pred.png', cv2.cvtColor(pred_rgb, cv2.COLOR_RGB2BGR))

        panels = (
            ('Input Image', batch_img[j]),
            ('Ground Truth Mask', gt_rgb),
            ('Predicted Mask', pred_rgb),
        )
        for position, (title, picture) in enumerate(panels, start=1):
            panel = fig.add_subplot(1, 3, position)
            panel.imshow(picture)
            panel.set_title(title, fontdict=title_font)
            panel.set_xticks(tick_positions)
            panel.set_yticks(tick_positions)
            panel.grid(False)

        plt.show()
        # Close explicitly so figures don't pile up (matplotlib warns past 20 open).
        plt.close(fig)

# 3. MobileNetV2 Encoder Based UNet

In [None]:
def conv_block(inputs, num_filters):
    """Apply two stacked (3x3 Conv2D -> BatchNormalization -> ReLU) stages.

    Both convolutions use ``padding="same"``, so the spatial size of `inputs`
    is preserved; the result has `num_filters` channels.
    """
    out = inputs
    for _ in range(2):
        out = Conv2D(num_filters, 3, padding="same")(out)
        out = BatchNormalization()(out)
        out = Activation("relu")(out)
    return out

def decoder_block(inputs, skip, num_filters):
    """One U-Net decoder stage: upsample 2x, merge the skip tensor, refine.

    A 2x2 transposed convolution with stride 2 doubles the spatial size of
    `inputs`. The result is concatenated with `skip` along the last axis
    (Keras `Concatenate` default) and passed through `conv_block`.
    """
    upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(inputs)
    merged = Concatenate()([upsampled, skip])
    return conv_block(merged, num_filters)

def build_mobilenetv2_unet(input_shape, num_classes):
    """Build a U-Net whose encoder is an ImageNet-pretrained MobileNetV2 (alpha=1.4).

    Args:
        input_shape: (height, width, channels) of the input images, e.g. (128, 128, 3).
        num_classes: number of segmentation classes (channels of the softmax output).

    Returns:
        An uncompiled Keras `Model` named "MobileNetV2-UNet" that maps images to
        per-pixel class probabilities.
    """
    # Input
    inputs = Input(shape=input_shape)

    # Pre-trained MobileNetV2 encoder, built directly on top of `inputs`
    encoder = MobileNetV2(include_top=False, weights="imagenet",
        input_tensor=inputs, alpha=1.4)

    # Encoder skip connections. Each decoder_block upsamples 2x before
    # concatenating, so the skips run from full resolution (s1) down to 1/8 (s4).
    # s1 is the Input tensor itself. Looking it up as get_layer("input_1") is
    # fragile: that auto-generated name depends on how many Input layers the
    # current Keras session has created (and is "input_layer" in Keras 3).
    s1 = inputs
    s2 = encoder.get_layer("block_1_expand_relu").output
    s3 = encoder.get_layer("block_3_expand_relu").output
    s4 = encoder.get_layer("block_6_expand_relu").output

    # Bridge (one level below s4)
    b1 = encoder.get_layer("block_13_expand_relu").output

    # Decoder: four 2x upsampling stages back to full resolution
    d1 = decoder_block(b1, s4, 512)
    d2 = decoder_block(d1, s3, 256)
    d3 = decoder_block(d2, s2, 128)
    d4 = decoder_block(d3, s1, 64)

    # Dropout regularisation before the classifier
    x1 = Dropout(0.4)(d4)

    # Output: 1x1 conv to per-pixel class probabilities
    outputs = Conv2D(num_classes, 1, padding="same", activation="softmax")(x1)

    model = Model(inputs, outputs, name="MobileNetV2-UNet")
    return model

In [None]:
# Build and compile the MobileNetV2-UNet for 128x128 RGB inputs and 7 output classes.
# Reset Keras' global state (including auto-generated layer-name counters) before building.
K.clear_session()

# Categorical cross-entropy expects one-hot targets, matching the softmax output layer.
loss = tf.keras.losses.CategoricalCrossentropy()
mobilenetv2_unet = build_mobilenetv2_unet(input_shape = (128, 128, 3), num_classes = 7)
# dice_coef is defined elsewhere in this notebook.
mobilenetv2_unet.compile(optimizer=Adam(learning_rate = 0.0001), loss=loss, metrics=[dice_coef, "accuracy"])
mobilenetv2_unet.summary()

In [None]:
# visualkeras colour overrides, keyed by layer class: draw Dropout layers in gray.
color_map = defaultdict(dict)
color_map[Dropout]['fill'] = 'gray'

# Font path is relative to the notebook's working directory.
font = ImageFont.truetype("./fonts/OpenSans-Semibold.ttf", 32)
# Flat (draw_volume=False) layered diagram with a legend, saved as a PNG.
visualkeras.layered_view(mobilenetv2_unet, legend=True, font=font, to_file='./Output/mobilenetv2_unet_architecture.png', color_map=color_map, draw_volume=False)

# MobileNetV2-UNet Model Training

In [None]:
def exponential_decay(lr0, s):
    """Return an ``epoch -> learning rate`` schedule for `LearningRateScheduler`.

    The rate starts at `lr0` and shrinks smoothly by a factor of 10 every `s`
    epochs: ``lr(epoch) = lr0 * 0.1 ** (epoch / s)``.
    """
    def schedule(epoch):
        decay_factor = 0.1 ** (epoch / s)
        return lr0 * decay_factor

    return schedule

# Learning-rate schedule: start at 1e-4 and decay 10x every 40 epochs.
exponential_decay_fn = exponential_decay(0.0001, 40)

# Apply the schedule at the start of each epoch and print the rate being set.
lr_scheduler = LearningRateScheduler(
    exponential_decay_fn,
    verbose=1
)

# Save the model to ./Output/mobilenetv2_unet.h5 only when val_loss improves,
# so the file always holds the best epoch seen so far.
checkpoint = ModelCheckpoint(
    filepath = './Output/mobilenetv2_unet.h5',
    save_best_only = True, 
#     save_weights_only = False,
    monitor = 'val_loss', 
    mode = 'auto', 
    verbose = 1
)

# Early stopping is currently disabled; kept for reference.
# earlystop = EarlyStopping(
#     monitor = 'val_loss', 
#     min_delta = 0.001, 
#     patience = 10, 
#     mode = 'auto', 
#     verbose = 1,
#     restore_best_weights = True
# )

# Per-epoch metrics log; append=False overwrites any previous CSV.
csvlogger = CSVLogger(
    filename= "./Output/mobilenetv2_unet_training.csv",
    separator = ",",
    append = False
)

callbacks = [checkpoint, csvlogger, lr_scheduler]

In [None]:
# Train the MobileNetV2-UNet for 40 epochs with the callbacks defined above.
# TrainAugmentGenerator / ValAugmentGenerator, steps_per_epoch and
# validation_steps are defined elsewhere in this notebook.
history_3 = mobilenetv2_unet.fit(
    TrainAugmentGenerator(train_images_dir = train_images, train_masks_dir = train_masks, target_size = (128, 128)), 
    steps_per_epoch = steps_per_epoch,
    validation_data = ValAugmentGenerator(val_images_dir = val_images, val_masks_dir = val_masks, target_size = (128, 128)), 
    validation_steps = validation_steps, 
    epochs = 40,
    callbacks = callbacks,
    use_multiprocessing = False,
    verbose = 1
)

In [None]:
# Persist the per-epoch metrics dict (history_3.history) with pickle for later reuse.
with open('./Output/trainHistoryDict_mobilenetv2_unet', 'wb') as file_pi:
    pickle.dump(history_3.history, file_pi)

In [None]:
# Plot train/validation curves for Dice coefficient, accuracy and loss, plus the
# learning-rate schedule, from history_3 in a 2x2 grid, and save it as a PNG.
sns.set_style("ticks")  # set before creating axes: style rcParams only apply to axes created afterwards
fig, ax = plt.subplots(2, 2, figsize=(20, 12))
ax = ax.ravel()

metrics = ['Dice Coefficient', 'Accuracy', 'Loss', 'Learning Rate']
hist = history_3.history
# tf.keras 2.x LearningRateScheduler logs the rate as 'lr'; Keras 3 logs it as 'learning_rate'.
lr_key = 'lr' if 'lr' in hist else 'learning_rate'

for i, met in enumerate(['dice_coef', 'accuracy', 'loss', lr_key]):
    ax[i].plot(hist[met], '-')
    # The learning rate has no validation counterpart: a single curve, no legend.
    if met != lr_key:
        ax[i].plot(hist['val_' + met], '-')
        ax[i].legend(['Train', 'Validation'])
    ax[i].set_title('{} vs Epochs'.format(metrics[i]), fontsize=16)
    ax[i].set_xlabel('Epochs', fontsize=12)
    ax[i].set_ylabel(metrics[i], fontsize=12)
    ax[i].xaxis.grid(True, color="lightgray", linewidth=0.8, linestyle="-")
    ax[i].yaxis.grid(True, color="lightgray", linewidth=0.8, linestyle="-")

plt.savefig('./Output/mobilenetv2_unet_metrics_plot.png', facecolor='w', transparent=False, bbox_inches='tight', dpi=300)

# Evaluating MobileNetV2-UNet Model on Test Set

In [None]:
# Test-set generator yielding (images, one-hot masks) batches of 32 at 128x128
# (TestAugmentGenerator is defined elsewhere in this notebook).
testing_gen = TestAugmentGenerator(test_images_dir = test_images, test_masks_dir = test_masks, batch_size = 32, target_size = (128, 128))

In [None]:
# Restore the best checkpoint (lowest val_loss, as saved by ModelCheckpoint) before evaluating.
mobilenetv2_unet.load_weights("./Output/mobilenetv2_unet.h5")
# NOTE(review): steps=21 batches of 32 is presumably meant to cover the test set — verify against its size.
mobilenetv2_unet_eval = mobilenetv2_unet.evaluate(testing_gen, steps=21, return_dict=True)

In [None]:
# Record this model's test metrics dict (`scores` is created elsewhere in this notebook).
scores['mobilenetv2_unet'] = mobilenetv2_unet_eval

In [None]:
# Collect flattened ground-truth and predicted class indices for 21 test batches,
# so sklearn metrics can be computed pixel-wise over the evaluated test images.
# Chunks are gathered in lists and concatenated once: calling np.append inside
# the loop copies the whole accumulated array every time (quadratic cost).
true_chunks, pred_chunks = [], []
count = 0

for i in range(21):
    batch_img, batch_mask = next(testing_gen)
    pred_all = mobilenetv2_unet.predict(batch_img)
    n_images = np.shape(pred_all)[0]
    count += n_images
    # argmax over the last axis turns one-hot / softmax maps into class-index maps;
    # ravel() flattens image by image, row by row, as the per-image loop did.
    true_chunks.append(np.argmax(batch_mask[:n_images], axis=-1).ravel())
    pred_chunks.append(np.argmax(pred_all, axis=-1).ravel())

# float64 matches the dtype the previous np.append-based accumulation produced,
# so downstream outputs (e.g. classification_report labels) stay the same.
Y_true_all_3 = np.concatenate(true_chunks).astype(np.float64) if true_chunks else np.array([])
y_pred_all_3 = np.concatenate(pred_chunks).astype(np.float64) if pred_chunks else np.array([])

In [None]:
# Sanity check: one entry per evaluated pixel in each array; the lengths must match.
print(Y_true_all_3.shape, y_pred_all_3.shape)

In [None]:
# Pixel-wise metrics; average='weighted' averages per-class scores weighted by class support.
print('Accuracy:', accuracy_score(Y_true_all_3, y_pred_all_3))
print('Precision:', precision_score(Y_true_all_3, y_pred_all_3, average='weighted'))
print('Recall:', recall_score(Y_true_all_3, y_pred_all_3, average='weighted'))
print('F1 Score:', f1_score(Y_true_all_3, y_pred_all_3, average='weighted'))

In [None]:
# Per-class precision / recall / F1 / support over all collected test pixels.
print(classification_report(Y_true_all_3, y_pred_all_3))

In [None]:
# Pixel-level confusion matrix (sklearn convention: rows = true class,
# columns = predicted class), drawn as an annotated heatmap and saved.
cm_3 = confusion_matrix(Y_true_all_3, y_pred_all_3)
# NOTE(review): assumes label_names (defined elsewhere) has exactly one entry per
# class present in the data; the DataFrame constructor raises otherwise.
df_cm_3 = pd.DataFrame(cm_3, label_names, label_names)
fig, ax = plt.subplots(figsize=(14,12))
sns.heatmap(df_cm_3, annot=True, annot_kws={"size": 16}, cmap=plt.cm.YlGnBu)
plt.title('Confusion Matrix for MobileNetV2-UNet CNN\n', fontsize=16)
plt.savefig('./Output/confusion_matrix_3.png', transparent= False, bbox_inches= 'tight', dpi= 300)
plt.show()

In [None]:
# Per-class accuracy derived from the confusion matrix (class_accuracy is defined elsewhere in this notebook).
per_class_acc_3 = class_accuracy(cm_3)
per_class_acc_3

In [None]:
# Table of per-class accuracy and F1 for the MobileNetV2-UNet.
# NOTE(review): only the first 6 classes are kept although the model is built
# with num_classes = 7 — presumably the 7th class is excluded on purpose; confirm,
# and confirm label_names lists those classes in class-index order.
# NOTE(review): f1_score(average=None) only scores labels that occur in y_true or
# y_pred (in sorted order); if a class never occurs, the F1 column shifts
# relative to label_names.
mobilenetv2_unet_class_acc = pd.DataFrame(zip(label_names, per_class_acc_3[:6].T), columns=['Class', 'Accuracy'])
mobilenetv2_unet_class_acc['F1'] = f1_score(Y_true_all_3, y_pred_all_3, average=None)[:6]
mobilenetv2_unet_class_acc

In [None]:
# Bar chart of per-class accuracy for the MobileNetV2-UNet.
plt.figure(figsize=(12, 7))
plt.title('Class-wise Accuracy of MobileNetV2-UNet CNN', fontsize=16)
sns.set_style("ticks")
sns.barplot(x="Class", y="Accuracy", data=mobilenetv2_unet_class_acc, palette='turbo', alpha = 0.8)
plt.grid(axis='y', color = 'lightgray', linestyle='-', linewidth=0.8)
# No file extension: matplotlib falls back to rcParams['savefig.format'] ('png' by default).
plt.savefig('./Output/mobilenetv2_unet_class_acc', facecolor= 'w', transparent= False, bbox_inches= 'tight', dpi= 300)
plt.show()

# Predictions on Test Set Using MobileNetV2-UNet Model

In [None]:
!mkdir mobilenetv2_unet_pred

In [None]:
# Visualise MobileNetV2-UNet predictions on 5 test batches: for every image show
# the input / ground-truth mask / predicted mask side by side, and save the figure
# to ./Output/mobilenetv2_unet_pred/.
# The `!mkdir mobilenetv2_unet_pred` cell creates the folder next to the notebook,
# not under ./Output/, where savefig writes. Create the real target folder here.
os.makedirs('./Output/mobilenetv2_unet_pred', exist_ok=True)

count = 0
for _ in range(5):
    batch_img, batch_mask = next(testing_gen)
    pred_all = mobilenetv2_unet.predict(batch_img)

    for j in range(np.shape(pred_all)[0]):
        count += 1
        fig = plt.figure(figsize=(20, 8))

        ax1 = fig.add_subplot(1, 3, 1)
        ax1.imshow(batch_img[j])
        ax1.set_title('Input Image', fontdict={'fontsize': 16, 'fontweight': 'medium'})

        ax2 = fig.add_subplot(1, 3, 2)
        ax2.set_title('Ground Truth Mask', fontdict={'fontsize': 16, 'fontweight': 'medium'})
        ax2.imshow(onehot_to_rgb(batch_mask[j], id2code))

        ax3 = fig.add_subplot(1, 3, 3)
        ax3.set_title('Predicted Mask', fontdict={'fontsize': 16, 'fontweight': 'medium'})
        ax3.imshow(onehot_to_rgb(pred_all[j], id2code))

        # Same 16-px tick grid on all three panels.
        for panel in (ax1, ax2, ax3):
            panel.set_xticks(np.arange(0, 129, 16))
            panel.set_yticks(np.arange(0, 129, 16))
            panel.grid(False)

        plt.savefig(f'./Output/mobilenetv2_unet_pred/mobilenetv2_unet_pred_{count}.png', facecolor='w', transparent=False, bbox_inches='tight', dpi=200)
        plt.show()
        plt.close(fig)  # release the figure; not every backend closes it on show()

# Predict on Whole Satellite Imagery

In [None]:
# Directory holding the image tiles of the scene to predict on (relative to the notebook).
tiles_dir = '../Mumbai_Data/jp22/jp22/'

In [None]:
def predict_on_tiles(tiles_dir, model, save_dir):
    """Run `model` on every image tile in `tiles_dir` and save colour-coded masks.

    Each tile is read with OpenCV, converted to RGB, scaled to [0, 1] and
    predicted on its own. The prediction is mapped to colours with the notebook
    globals ``onehot_to_rgb(..., id2code)`` and written to `save_dir` as
    ``<tile name up to the first '.tif'>.png``.

    Args:
        tiles_dir: directory containing the input tiles.
        model: Keras model taking a (1, H, W, 3) float32 batch.
        save_dir: output directory for the predicted PNG masks; created if missing.

    Files OpenCV cannot decode (e.g. metadata sidecar files) are skipped with a
    warning instead of aborting the whole run.
    """
    os.makedirs(save_dir, exist_ok=True)
    files = sorted(os.listdir(tiles_dir))
    for tile in tqdm(files, desc="[Predicting…]", ascii=False, ncols=75):
        bgr = cv2.imread(os.path.join(tiles_dir, tile))
        if bgr is None:  # cv2.imread returns None rather than raising on unreadable files
            tqdm.write(f'Skipping unreadable file: {tile}')
            continue
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype('float32') / 255
        img = np.expand_dims(img, axis=0)  # add the batch dimension
        # verbose=0: a per-call Keras progress bar would garble the tqdm bar.
        pred_img = model.predict(img, verbose=0)
        pred_img = onehot_to_rgb(pred_img[0], id2code)
        pred_filename = tile.split('.tif')[0]
        # cv2.imwrite expects BGR; the mask is presumably RGB, so swap R and B.
        cv2.imwrite(os.path.join(save_dir, f'{pred_filename}.png'), cv2.cvtColor(pred_img, cv2.COLOR_BGR2RGB))

In [None]:
def stitch_tiles(pred_tiles_dir, save_filename, save_dir, tiles_per_row=285):
    """Stitch predicted PNG tiles back into one large mosaic.

    Tiles are read with ``imread_collection`` (which lists files sorted by
    filename) and laid out row-major. Every `tiles_per_row` consecutive tiles
    are concatenated horizontally into one strip, saved as
    ``<save_dir>/hconcat_tile_<k>.png``. The strips are then stacked vertically
    and the mosaic is written to ``<save_filename>.png``.

    Args:
        pred_tiles_dir: directory containing the per-tile prediction PNGs.
        save_filename: output path of the mosaic, without the ``.png`` extension.
        save_dir: directory for the intermediate row strips; created if missing.
        tiles_per_row: number of tiles per row of the source grid (default 285).

    Returns:
        The stitched mosaic as returned by OpenCV's concatenation of the tiles
        (channel order as read by skimage).

    Raises:
        ValueError: if no PNG tiles are found, or the tile count is not a
            multiple of `tiles_per_row` (rows of unequal width cannot be stacked).
    """
    os.makedirs(save_dir, exist_ok=True)
    img_list = imread_collection(os.path.join(pred_tiles_dir, '*.png'))
    # Count the PNGs actually read, not every directory entry.
    n_tiles = len(img_list)
    if n_tiles == 0:
        raise ValueError(f'No .png tiles found in {pred_tiles_dir!r}')
    if n_tiles % tiles_per_row:
        raise ValueError(
            f'{n_tiles} tiles cannot be split into full rows of {tiles_per_row}'
        )

    rows = []
    for k, start in enumerate(range(0, n_tiles, tiles_per_row)):
        row = cv2.hconcat([img_list[idx] for idx in range(start, start + tiles_per_row)])
        cv2.imwrite(os.path.join(save_dir, f'hconcat_tile_{k}.png'), cv2.cvtColor(row, cv2.COLOR_BGR2RGB))
        rows.append(row)

    # Stack the in-memory strips instead of re-globbing save_dir, which could
    # also pick up stale or unrelated PNGs left there by earlier runs.
    final_pred_tile = cv2.vconcat(rows)
    cv2.imwrite(f'{save_filename}.png', cv2.cvtColor(final_pred_tile, cv2.COLOR_BGR2RGB))
    return final_pred_tile
        

## 1. DeepLabV3+ Semantic Map Prediction

In [None]:
!mkdir deeplabv3plus_pred_tiles

In [None]:
# Restore the saved DeepLabV3+ weights (model built elsewhere in this notebook).
deeplabv3plus.load_weights("./Output/deeplabv3plus.h5")

In [None]:
# Predict every tile with DeepLabV3+ and write the colour-coded masks as PNGs.
predict_on_tiles(tiles_dir=tiles_dir, model=deeplabv3plus, save_dir='./deeplabv3plus_pred_tiles/')

In [None]:
!zip -r deeplabv3plus_pred_tiles.zip './deeplabv3plus_pred_tiles'
!mkdir deeplabv3plus_stitched_pred

In [None]:
# Stitch the DeepLabV3+ tile predictions into one mosaic saved as deeplabv3plus_jp22.png.
deeplabv3plus_tile = stitch_tiles(pred_tiles_dir='./deeplabv3plus_pred_tiles/', save_filename='deeplabv3plus_jp22', save_dir='./deeplabv3plus_stitched_pred/')

## 2. VGG16-UNet Semantic Map Predictions

In [None]:
!mkdir vgg16_unet_pred_tiles

In [None]:
# Restore the saved VGG16-UNet weights.
vgg16_unet.load_weights("./Output/vgg16_unet.h5")

In [None]:
# Predict every tile with VGG16-UNet and write the colour-coded masks as PNGs.
predict_on_tiles(tiles_dir=tiles_dir, model=vgg16_unet, save_dir='./vgg16_unet_pred_tiles/')

In [None]:
!zip -r vgg16_unet_pred_tiles_v2.zip './vgg16_unet_pred_tiles'
!mkdir vgg16_stitched_pred

In [None]:
# Stitch the VGG16-UNet tile predictions into one mosaic saved as vgg16_unet_jp22.png.
# pred_tiles_dir must be the path string the predict_on_tiles call wrote to;
# a bare `vgg16_unet_pred_tiles` name raises NameError.
vgg16_unet_tile = stitch_tiles(pred_tiles_dir='./vgg16_unet_pred_tiles/', save_filename='vgg16_unet_jp22', save_dir='./vgg16_stitched_pred/')

## 3. MobileNetV2-UNet Semantic Map Predictions

In [None]:
!mkdir mobilenetv2_unet_pred_tiles

In [None]:
# Restore the best MobileNetV2-UNet checkpoint before predicting on the full scene.
mobilenetv2_unet.load_weights("./Output/mobilenetv2_unet.h5")

In [None]:
# Predict every tile with MobileNetV2-UNet and write the colour-coded masks as PNGs.
predict_on_tiles(tiles_dir=tiles_dir, model=mobilenetv2_unet, save_dir='./mobilenetv2_unet_pred_tiles/')

In [None]:
!zip -r mobilenetv2_unet_pred_tiles.zip './mobilenetv2_unet_pred_tiles'
!mkdir mobilenetv2_unet_stitched_pred

In [None]:
# Stitch the MobileNetV2-UNet tile predictions into one mosaic saved as mobilenetv2_unet_jp22.png.
mobilenetv2_unet_tile = stitch_tiles(pred_tiles_dir='./mobilenetv2_unet_pred_tiles/', save_filename='mobilenetv2_unet_jp22', save_dir='./mobilenetv2_unet_stitched_pred/')