# Classify satellite images into building footprints 
If everything worked so far, we can then move on to actually classify an image provided we have enough confidence in the trained model. In this section, we will now load an image that we truly want to classify for real-world applications.

This code runs only with images that is ~5.5GB of disk space.
Bigger images will crash the notebook.

*Version: 0.2*

In [11]:
from google.colab import drive
drive.mount("/content/drive/")

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).
time: 1.76 ms (started: 2021-01-18 17:12:15 +00:00)


In [1]:
#Import libraries

import os
import gdal
import numpy as np
import random
import math

import matplotlib.pyplot as plt 
import matplotlib.patches as mpatches

!pip install rasterio
import rasterio
from rasterio.windows import Window

import tensorflow as tf
import keras
from tensorflow.python.keras import backend as K

import seaborn as sea

!pip install ipython-autotime
%load_ext autotime

!pip install tqdm
from tqdm import trange

time: 2.45 s (started: 2021-01-18 19:06:51 +00:00)


## Load pre-requisite functions and models for predictions

In [2]:
#Pre-requisite codes for loading the model

def accuracy(y_true, y_pred, threshold=0.5):
    """compute accuracy"""
    #y_t = y_true[...,0]
    #y_t = y_t[...,np.newaxis]
    y_pred = K.round(y_pred +0.5 - threshold)
    return K.equal(K.round(y_true), K.round(y_pred))

def dice_coef(y_true, y_pred, smooth=0.0000001):
    """compute dice coef"""
    #y_t = y_true[...,0]
    #y_t = y_t[...,np.newaxis]
    intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
    union = K.sum(y_true, axis=-1) + K.sum(y_pred, axis=-1)
    return K.mean((2. * intersection + smooth) / (union + smooth), axis=-1)

def dice_loss(y_true, y_pred):
    """compute dice loss"""
    #y_t = y_true[...,0]
    #y_t = y_t[...,np.newaxis]
    return 1 - dice_coef(y_true, y_pred)

# K.round() returns the Element-wise rounding to the closest integer!!!
# So the threshold to determine a true positive is set here!!!!!
def true_positives(y_true, y_pred, threshold=0.5):
    """compute true positive"""
    #y_t = y_true[...,0]
    #y_t = y_t[...,np.newaxis]
    y_pred = K.round(y_pred +0.5 - threshold)
    return K.round(y_true * y_pred)

def false_positives(y_true, y_pred, threshold=0.5):
    """compute false positive"""
    #y_t = y_true[...,0]
    #y_t = y_t[...,np.newaxis]
    y_pred = K.round(y_pred +0.5 - threshold)
    return K.round((1 - y_true) * y_pred)

def true_negatives(y_true, y_pred, threshold=0.5):
    """compute true negative"""
    #y_t = y_true[...,0]
    #y_t = y_t[...,np.newaxis]
    y_pred = K.round(y_pred +0.5 - threshold)
    return K.round((1 - y_true) * (1 - y_pred))

def false_negatives(y_true, y_pred, threshold=0.5):
    """compute false negative"""
    #y_t = y_true[...,0]
    #y_t = y_t[...,np.newaxis]
    y_pred = K.round(y_pred +0.5 - threshold)
    return K.round((y_true) * (1 - y_pred))

# K.sum() returns a single integer output unlike the K.round() which returns an element-wise matrix
def sensitivity(y_true, y_pred):
    """compute sensitivity (recall)"""
    #y_t = y_true[...,0]
    #y_t = y_t[...,np.newaxis]
    tp = true_positives(y_true, y_pred)
    fn = false_negatives(y_true, y_pred)
    return K.sum(tp) / (K.sum(tp) + K.sum(fn))

def specificity(y_true, y_pred):
    """compute specificity ()"""
    #y_t = y_true[...,0]
    #y_t = y_t[...,np.newaxis]
    tn = true_negatives(y_true, y_pred)
    fp = false_positives(y_true, y_pred)
    return K.sum(tn) / (K.sum(tn) + K.sum(fp))

def recall_m(y_true, y_pred):
    #y_t = y_true[...,0]
    #y_t = y_t[...,np.newaxis]
    tp = true_positives(y_true, y_pred)
    fn = false_negatives(y_true, y_pred)
    recall = K.sum(tp) / (K.sum(tp) + K.sum(fn)+ K.epsilon())
    return recall

def precision_m(y_true, y_pred):
    #y_t = y_true[...,0]
    #y_t = y_t[...,np.newaxis]
    tp = true_positives(y_true, y_pred)
    fp = false_positives(y_true, y_pred)
    precision = K.sum(tp) / (K.sum(tp) + K.sum(fp)+ K.epsilon())
    return precision

def f1_m(y_true, y_pred):
    precision = precision_m(y_true, y_pred)
    recall = recall_m(y_true, y_pred)
    return 2*((precision*recall)/(precision+recall+K.epsilon()))

time: 65.8 ms (started: 2021-01-18 18:15:56 +00:00)


In [3]:
# Tversky
def tversky(y_true, y_pred, alpha=0.3, beta=0.7):
    """
    Function to calculate the Tversky loss for imbalanced data
    :param prediction: the logits
    :param ground_truth: the segmentation ground_truth
    :param alpha: weight of false positives
    :param beta: weight of false negatives
    :param weight_map:
    :return: the loss
    """
    '''
    EPSILON = 0.00001 (default)
    '''
    y_true_pos = K.flatten(y_true)
    y_pred_pos = K.flatten(y_pred)
    # TP
    true_pos = K.sum(y_true_pos * y_pred_pos)
    # FN
    false_neg = K.sum(y_true_pos * (1-y_pred_pos))
    # FP
    false_pos = K.sum((1-y_true_pos) * y_pred_pos)
    return 1 - (true_pos + K.epsilon())/(true_pos + alpha * false_neg + beta * false_pos + K.epsilon())

time: 4.55 ms (started: 2021-01-18 18:15:56 +00:00)


In [4]:
# Load the model
from keras.models import load_model

model = load_model("/content/drive/MyDrive/Kushanav MSc Thesis shared folder/Local Dataset/All tiles/Saved models and weights/Saved models/resunet_12_12_1e-05.hdf5", 
                   custom_objects={"tversky": tversky, "f1_m": f1_m, "accuracy": accuracy, "precision_m": precision_m, "recall_m": recall_m}, compile=True) 

time: 13.8 s (started: 2021-01-18 18:15:56 +00:00)


## Load the image

In [2]:
#Import the image to classify
path = "/content/drive/MyDrive/Kushanav MSc Thesis shared folder/Local Dataset/All tiles/Test images/Entire Study Area/Images to classify"
img_dir = os.path.join(path, "A3.tif")

#Load the image to classify
image = gdal.Open(img_dir)
bands_test = [image.GetRasterBand(i+1).ReadAsArray() for i in trange(image.RasterCount)]
new_image = np.stack(bands_test, axis=2)       
del bands_test

# To store the meta data, open with Rasterio
src = rasterio.open(img_dir)

100%|██████████| 3/3 [01:08<00:00, 22.69s/it]


time: 1min 13s (started: 2021-01-18 19:06:54 +00:00)


In [None]:
#CIEW
print(f"Shape of the satellite image{new_image.shape}")

Shape of the satellite image(25088, 30208, 3)
time: 4.6 ms (started: 2021-01-17 17:35:45 +00:00)


## Patch generation and classification

In [6]:
# Patch the image into 512x512 to predict

def gridwise_sample(imgarray, patchsize):

    """Extract sample patches of size patchsize x patchsize from an image (imgarray) in a gridwise manner.
    """
    nrows, ncols, nbands = imgarray.shape
    patchsamples = np.zeros(shape=(0, patchsize, patchsize, nbands),
                            dtype=imgarray.dtype)
    for i in trange(int(nrows/patchsize)):
        for j in trange(int(ncols/patchsize)):
            tocat = imgarray[i*patchsize:(i+1)*patchsize,
                             j*patchsize:(j+1)*patchsize, :]
            tocat = np.expand_dims(tocat, axis=0)
            patchsamples = np.concatenate((patchsamples, tocat),
                                          axis=0)
    return patchsamples

time: 4.54 ms (started: 2021-01-18 18:17:37 +00:00)


In [7]:
# GENERATE PATCH TILES OF THE IMAGE
PATCHSIZE = 512

# Sample each tile systematically in a gridwise manner
patch = gridwise_sample(new_image, PATCHSIZE)

  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/59 [00:00<?, ?it/s][A
 39%|███▉      | 23/59 [00:00<00:00, 225.50it/s][A
 54%|█████▍    | 32/59 [00:00<00:00, 153.11it/s][A
 68%|██████▊   | 40/59 [00:00<00:00, 106.54it/s][A
 81%|████████▏ | 48/59 [00:00<00:00, 78.75it/s] [A
100%|██████████| 59/59 [00:00<00:00, 77.11it/s]
  2%|▏         | 1/50 [00:00<00:37,  1.30it/s]
  0%|          | 0/59 [00:00<?, ?it/s][A
  7%|▋         | 4/59 [00:00<00:01, 37.68it/s][A
 14%|█▎        | 8/59 [00:00<00:01, 36.92it/s][A
 20%|██        | 12/59 [00:00<00:01, 35.22it/s][A
 27%|██▋       | 16/59 [00:00<00:01, 34.21it/s][A
 32%|███▏      | 19/59 [00:00<00:01, 32.80it/s][A
 37%|███▋      | 22/59 [00:00<00:01, 31.61it/s][A
 42%|████▏     | 25/59 [00:00<00:01, 30.56it/s][A
 47%|████▋     | 28/59 [00:00<00:01, 29.64it/s][A
 53%|█████▎    | 31/59 [00:00<00:00, 28.82it/s][A
 58%|█████▊    | 34/59 [00:01<00:00, 27.91it/s][A
 63%|██████▎   | 37/59 [00:01<00:00, 26.91it/s][A
 68%|██████▊   

time: 26min 39s (started: 2021-01-18 18:17:37 +00:00)





In [8]:
#Predict on the satellite image
prediction = model.predict(patch)

print("The predictions were tested on %i number patches." % (patch.shape[0]))

The predictions were tested on 2950 number patches.
time: 56.4 s (started: 2021-01-18 18:44:16 +00:00)


In [3]:
# Saved the predicted patches as npy in order to avoid OOM issue
save_path = "/content/drive/MyDrive/Kushanav MSc Thesis shared folder/Local Dataset/All tiles/Test images/Entire Study Area/Output predictions/A3.npy"
#saved_array = np.save(save_path, prediction)

time: 1.23 ms (started: 2021-01-18 19:08:07 +00:00)


In [4]:
# Load the saved file "{}.npy" after saving if the notebook crashes 
# NO NEED TO RUN THIS CELL IF THE NOTEBOOK DOES NOT CRASH AFTER STITCHING THE PREDICTED PATCHES
datum = np.load(save_path)

time: 42.7 s (started: 2021-01-18 19:08:07 +00:00)


In [5]:
# Check the shape of loaded array
print(f'The shape of the loaded data is {datum.shape}')

The shape of the loaded data is (2950, 512, 512, 1)
time: 1.12 ms (started: 2021-01-18 19:08:50 +00:00)


## Stitch predicted patch images 
Stich the predicted image into one single image that is almost same as the satellite image

In [6]:
# Save only the number of rows and columns according to the test image
nrows, ncols = new_image[:,:,0].shape
PATCHSIZE = 512
# Iterate loop to generate a combined prediction image from the many predicted image patches 
combo = []
patch_col = math.floor(ncols/PATCHSIZE) 
patch_row = math.floor(nrows/PATCHSIZE)

for i in range(patch_row):
  patch = np.concatenate(datum[patch_col*i:patch_col*(i+1)], axis=1) 
  combo.append(patch)

join = np.concatenate(combo, axis=0)
stacked_image = join[:,:,0]

time: 3.99 s (started: 2021-01-18 19:10:55 +00:00)


In [7]:
#CIEW
print(f"Total number of rows and columns for the stitched predicted image: {stacked_image.shape}")

Total number of rows and columns for the stitched predicted image: (25088, 30720)
time: 1.11 ms (started: 2021-01-18 17:52:05 +00:00)


## Window Tranformation
To add geo-reference to the predicted image from the satellite image

In [7]:
# Size of pixels of the predicted stacked image
xsize, ysize = stacked_image[:,:].shape

# Generate a random window location / Comes from the OG image
xmin, xmax = 0, src.width - xsize
ymin, ymax = 0, src.height - ysize
xoff, yoff = 0, 0  #random.randint(xmin, xmax), random.randint(ymin, ymax)

# Create the window and calculate the transformation objects from the source data (OG image)
window = Window(xoff, yoff, xsize, ysize)
transform = src.window_transform(window)

# Update the profile of the new windowed image
profile = src.profile # Comes from the OG image
src.profile.update({
    "height": ysize,
    "width": xsize,
    "transform": transform
})

time: 27.3 ms (started: 2021-01-18 19:10:59 +00:00)


In [8]:
# Export the geo-referenced predicted image as a tiff file
new_transform = src.meta["transform"]
new_crs = src.meta["crs"]

new_tiff = rasterio.open("/content/drive/MyDrive/Kushanav MSc Thesis shared folder/Local Dataset/All tiles/Test images/Entire Study Area/Output predictions/A3.tif",
                         mode = "w",
                         height = stacked_image.shape[0],
                         width = stacked_image.shape[1],
                         driver = "GTiff",
                         count = 1,
                         dtype = str(stacked_image.dtype), # Here, the dtype comes from the stacked predicted image
                         crs = new_crs,
                         transform = new_transform)
new_tiff.write(stacked_image, 1)
new_tiff.close()
print("Geo-reference Transformation Successful !!!")

Geo-reference Transformation Successful !!!
time: 30.1 s (started: 2021-01-18 19:11:00 +00:00)
