# Road Extraction

## Imports

In [None]:
import pandas as pd
import geopandas as gpd
import numpy as np
import os
import torch
import random
import cv2
import tqdm
from matplotlib import pyplot as plt
import segmentation_models_pytorch as smp
import albumentations as album
from PIL import Image

%matplotlib inline

## Global Variables

In [None]:
# --- Model configuration -----------------------------------------------------
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'

# --- Segmentation classes and the RGB colour that encodes each one -----------
CLASS_NAMES = ['road', 'background']
CLASS_RGB_VALUES = [
    [255, 255, 255],
    [0, 0, 0],
]

# --- Image size (pixels per side) and file locations --------------------------
IMG_SIZE = 1024
IMG_PATH = '/home/ah2719/FYP/Spatial_Finance_Transport/data/road_extraction_example.jpg'
PRED_MASK_IMG_PATH = '/home/ah2719/FYP/Spatial_Finance_Transport/data/pred_mask.jpeg'

# Position of every class name inside CLASS_NAMES, and the RGB rows picked out
# by those positions (as a NumPy array).
SELECT_CLASS_INDICES = [
    CLASS_NAMES.index(class_name.lower()) for class_name in CLASS_NAMES
]
SELECT_CLASS_RGB_VALUES = np.array(CLASS_RGB_VALUES)[SELECT_CLASS_INDICES]

## Helper Functions

In [None]:
def to_tensor(x, **kwargs):
    """Reorder a 3-D array's axes from (0, 1, 2) to (2, 0, 1) and cast to float32.

    Extra keyword arguments are accepted and ignored.
    """
    reordered = x.transpose((2, 0, 1))
    return reordered.astype('float32')

In [None]:
# One-hot encode a colour-coded label image (TRAINING ONLY)
def one_hot_encode(label, label_values):
    """
    Turn a colour-coded segmentation label into a stack of boolean class maps.
    # Arguments
        label: Array whose last axis holds the colour of each pixel
        label_values: Iterable of colours, one per class; a pixel belongs to
            a class when every component of its colour equals that class colour

    # Returns
        A boolean array with the same leading dimensions as `label` and a
        last axis of length len(label_values), one channel per class
    """
    per_class_maps = [
        np.all(np.equal(label, class_colour), axis=-1)
        for class_colour in label_values
    ]
    return np.stack(per_class_maps, axis=-1)

In [None]:
# Collapse one-hot labels / per-class predictions back to class indices
def reverse_one_hot(image):
    """
    Pick, for every pixel, the index of the largest value along the last axis.
    # Arguments
        image: Array whose last axis has one entry per class
            (one-hot labels or per-class scores)

    # Returns
        An array with the last axis removed, where each value is the index
        of the winning class for that pixel.
    """
    return np.argmax(image, -1)

In [None]:
# Keep the image pixels selected by the predicted mask, black out the rest
def colour_code_segmentation(pred_mask, img):
    """
    Mask an image with a segmentation result: pixels whose mask entry is
    non-zero keep their value from `img`, every other pixel is set to 0.
    # Arguments
        pred_mask: Array whose first two dimensions match those of `img`.
            Either a single-channel array of class keys, or an array with
            extra trailing axes (a pixel is kept if any of its entries is
            non-zero).
        img: The image to mask, as an (H, W) or (H, W, C) array.

    # Returns
        A float64 array with the same shape as `img` containing the masked image.

    # Raises
        ValueError: if the first two dimensions of `pred_mask` and `img` differ.
    """
    pred_mask = np.asarray(pred_mask)
    img = np.asarray(img)

    print("pred_mask shape: {}".format(pred_mask.shape))

    if pred_mask.shape[:2] != img.shape[:2]:
        raise ValueError(
            "pred_mask shape {} does not match image shape {}".format(
                pred_mask.shape, img.shape
            )
        )

    # Reduce the mask to one boolean per pixel, whatever its trailing axes are.
    keep = pred_mask != 0
    if keep.ndim > 2:
        keep = keep.any(axis=tuple(range(2, keep.ndim)))

    # Add trailing singleton axes so the (H, W) mask broadcasts over channels.
    keep = keep.reshape(keep.shape + (1,) * (img.ndim - 2))

    # Vectorised select; float64 matches what np.empty(img.shape) used to give.
    result_img = np.where(keep, img, 0).astype(np.float64)

    print("final result_img shape: {}".format(result_img.shape))
    return result_img

In [None]:
def get_preprocessing(preprocessing_fn=None):
    """Build the preprocessing pipeline.

    Args:
        preprocessing_fn (callable): optional data normalization function
            applied to the image (can be specific for each pretrained
            neural network); skipped when falsy.
    Return:
        transform: albumentations.Compose that runs the optional
            normalization step and then applies `to_tensor` to both
            image and mask.
    """
    steps = [album.Lambda(image=preprocessing_fn)] if preprocessing_fn else []
    steps.append(album.Lambda(image=to_tensor, mask=to_tensor))
    return album.Compose(steps)

In [None]:
# Ask segmentation_models_pytorch for the preprocessing function that matches
# the configured ENCODER / ENCODER_WEIGHTS pair.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

## Road Extraction Model
https://ieeexplore.ieee.org/document/8127098

https://www.kaggle.com/code/balraj98/road-extraction-from-satellite-images-deeplabv3/notebook

In [None]:
# If loading the full saved model below raises errors, refer to the discussion at https://www.kaggle.com/code/balraj98/road-extraction-from-satellite-images-deeplabv3/comments
#model = torch.load("/home/ah2719/FYP/Spatial_Finance_Transport/models/best_model.pth")

In [None]:
# Load the saved weights onto the CPU: inference below runs on CPU tensors, and
# without map_location torch.load fails on a CPU-only machine when the
# checkpoint was saved from a GPU.
chkpt = torch.load(
    "/home/ah2719/FYP/Spatial_Finance_Transport/models/state_dict.pth",
    map_location=torch.device("cpu"),
)

# Build the architecture from the notebook-level constants so it cannot drift
# from the encoder used to pick `preprocessing_fn`.
model = smp.DeepLabV3Plus(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASS_NAMES),
    activation='sigmoid',
)
model.load_state_dict(chkpt)

# Switch to inference mode (affects layers such as dropout / batch norm).
model.eval()

## Test Data Predictions

In [None]:
# Load the example image used for inference
# NOTE(review): random_idx is not used by any cell visible here; kept in case
# another cell relies on it.
random_idx = random.randint(1,10)

img = cv2.imread(IMG_PATH)
# cv2.imread returns None instead of raising when the file is missing or
# unreadable; fail here with a clear message rather than inside cv2.resize.
if img is None:
    raise FileNotFoundError("Could not read image: {}".format(IMG_PATH))

# Resize to the square size used throughout the notebook, then convert
# OpenCV's BGR channel order to RGB.
img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

#show image
plt.imshow(img)

In [None]:
# Preprocess the image, reorder it with to_tensor, and add a leading batch axis
img_preprocessed = preprocessing_fn(img)
x_tensor = torch.from_numpy(to_tensor(img_preprocessed)).unsqueeze(0)
print(x_tensor.shape)

In [None]:
# Run inference without tracking gradients: nothing is back-propagated here,
# so building the autograd graph would only waste memory.
with torch.no_grad():
    pred_mask = model(x_tensor)

# Drop the batch axis and move the class axis last: (C, H, W) -> (H, W, C)
pred_mask = pred_mask.squeeze().cpu().numpy()
pred_mask = np.transpose(pred_mask,(1,2,0))

# Per-pixel class index, then keep only the image pixels selected by the mask
pred_mask_reversed = reverse_one_hot(pred_mask)
pred_mask_processed = colour_code_segmentation(pred_mask_reversed, img)

# The masked image already holds 0-255 pixel values copied from `img`;
# multiplying by 255 before the uint8 cast overflowed and corrupted the image.
pred_mask_processed = np.clip(pred_mask_processed, 0, 255).astype(np.uint8)

# show prediction
plt.imshow(pred_mask_processed)

In [None]:
# Save the masked prediction to PRED_MASK_IMG_PATH, the same constant the line
# segment detector cell reads from, so writer and reader cannot drift apart.
pred_mask_im = Image.fromarray(pred_mask_processed)
pred_mask_im.save(PRED_MASK_IMG_PATH)

## Save Model State

In [None]:
# To save model state dict
#torch.save(model.state_dict(), "/home/ah2719/FYP/Spatial_Finance_Transport/models/state_dict.pth")

## Line Segment Detector

In [None]:
#Read the saved prediction as a single-channel grayscale image
img = cv2.imread(PRED_MASK_IMG_PATH, cv2.IMREAD_GRAYSCALE)
# cv2.imread returns None instead of raising when the file is missing or
# unreadable; fail early with a clear message.
if img is None:
    raise FileNotFoundError(
        "Could not read predicted mask image: {}".format(PRED_MASK_IMG_PATH)
    )

#Create an LSD detector with line refinement disabled (LSD_REFINE_NONE == 0;
#note this is not OpenCV's default refinement mode)
lsd = cv2.createLineSegmentDetector(cv2.LSD_REFINE_NONE)

#Detect lines in the image
lines = lsd.detect(img)[0] #Position 0 of the returned tuple are the detected lines

#Draw detected lines in the image
drawn_img = lsd.drawSegments(img,lines)

#Show image
plt.imshow(drawn_img)