# Road Segmentation Project


In [None]:
# Imports
import math
import os
import re
import cv2
import torch
import numpy as np
import parameters as params
from utils import utils
import trainer
from processing import augment
import matplotlib.pyplot as plt
from glob import glob
from random import sample
from PIL import Image
from torch import nn
from train import train
from sklearn.model_selection import train_test_split
from utils.datasets import ImageDataset
from utils.losses import DiceBCELoss

In [None]:
AUGMENT_FACTOR = 0

# Pick the compute device once; later cells reuse this name.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the raw training images and their ground-truth masks.
# NOTE(review): the boolean passed to utils.load_images presumably marks the
# mask/grayscale case -- confirm against utils.load_images.
training_root = os.path.join(params.ROOT_PATH, 'training')
images_org = utils.load_images(os.path.join(training_root, 'images'), False)
masks_org = utils.load_images(os.path.join(training_root, 'groundtruth'), True)

# Augment the (image, mask) pairs.
images_aug, masks_aug = augment.augment_data(images_org, masks_org, AUGMENT_FACTOR)

# Divide by 255 and stack everything into single float32 arrays.
images = np.stack([image / 255.0 for image in images_aug]).astype(np.float32)
masks = np.stack([target / 255.0 for target in masks_aug]).astype(np.float32)
print("Finished data augmentation.")

In [None]:
# INSPECT AUGMENTED IMAGES
# Look at five samples spread evenly over the whole array. The upper bound is
# taken from the data itself instead of a hard-coded sample count, so the cell
# cannot index past the end when the dataset size changes.
indices = np.linspace(0, len(images) - 1, 5, dtype=int)
for i in indices:
    print(f"INDEX:{i}", f"MIN: {np.min(images[i])}", f"MAX: {np.max(images[i])}")
    utils.show_image(images[i], masks[i])

In [None]:
# Quick sanity check: display the array shape of a single mask (index 5).
masks[5].shape

In [None]:
# Hold out the last 20% of the samples for validation. With shuffle=False the
# split is a plain head/tail cut, so random_state has no effect here.
train_images, val_images, train_masks, val_masks = train_test_split(
    images, masks, test_size=0.2, random_state=42, shuffle=False
)

# reshape the image to simplify the handling of skip connections and maxpooling
train_dataset = ImageDataset(train_images, train_masks, device, use_patches=False, resize_to=(params.RESIZE, params.RESIZE))
val_dataset = ImageDataset(val_images, val_masks, device, use_patches=False, resize_to=(params.RESIZE, params.RESIZE))
full_dataset = ImageDataset(images, masks, device, use_patches=False, resize_to=(params.RESIZE, params.RESIZE))

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=params.BATCH_SIZE, shuffle=True)
# Validation is not shuffled: shuffling brings no benefit for evaluation, and
# later cells match predictions to val_images / val_dataset entries by position.
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=params.BATCH_SIZE, shuffle=False)
full_dataloader = torch.utils.data.DataLoader(full_dataset, batch_size=params.BATCH_SIZE, shuffle=True)

## Baseline 2: ResU-Net --> Road Extraction by Deep Residual U-Net
This is the provided baseline U-Net, with an F1 score of 89%.

In [None]:
import segmentation_models_pytorch as smp

# Set to False to skip the train/validation run below.
validate_model = True

if validate_model:
    model = smp.Unet(
        encoder_name="efficientnet-b3",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7 #efficientnet-b3
        encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
        in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
        classes=1,                      # model output channels (number of classes in your dataset)
    )

    # TRAINING WITH VALIDATION
    model = model.to(device)
    # Combined Dice + BCE loss with equal weights; pos_weight is handed to the
    # loss so the positive class can be weighted (2.0 here).
    pos_weight = torch.tensor([2.0], dtype=torch.float32, device=device)
    loss_fn = DiceBCELoss(dice_weight=0.5, bce_weight=0.5, pos_weight=pos_weight)
    metric_fns = {'acc': utils.accuracy_fn}
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)
    # NOTE(review): the trailing 100 and 10 are positional arguments of
    # trainer.train_smp (presumably epoch count and a logging/eval interval) --
    # confirm against its signature.
    trainer.train_smp(train_dataloader, val_dataloader, model, loss_fn, metric_fns, optimizer, 100, 10)

In [None]:
# TRAINING WITHOUT VALIDATION ON FULL DATASET

model_full = smp.Unet(
    encoder_name="efficientnet-b3",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7 #efficientnet-b3
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                      # model output channels (number of classes in your dataset)
)

# Train on every sample: None is passed where the validation loader would go.
model_full = model_full.to(device)
# Dice loss computed directly on the raw logits (from_logits=True).
loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
metric_fns = {'acc': utils.accuracy_fn}
optimizer = torch.optim.Adam(model_full.parameters(), lr=1e-3)
# NOTE(review): the trailing 100 and 10 are positional arguments of
# trainer.train_smp (presumably epoch count and a logging/eval interval) --
# confirm against its signature.
trainer.train_smp(full_dataloader, None, model_full, loss_fn, metric_fns, optimizer, 100, 10)

In [None]:
# Produce 'efficientunet_submission.csv' from the model trained on the full dataset.
# NOTE(review): "test" / "images" are presumably the test-set folder names -- confirm in utils.create_submission.
utils.create_submission("test", "images",'efficientunet_submission.csv', model_full, device)

In [None]:
from models import resunet

In [None]:
# Train the ResUnet model on the train/validation split.
# Note: this rebinds the global `model`, replacing any model assigned earlier.
model = resunet.ResUnet(3).to(device)
# NOTE(review): nn.BCELoss expects probabilities in [0, 1], so this assumes
# ResUnet ends with a sigmoid -- confirm in models/resunet.
loss_fn = nn.BCELoss()
metric_fns = {'acc': utils.accuracy_fn, 'patch_acc': utils.patch_accuracy_fn}
# Adam with PyTorch's default hyper-parameters.
optimizer = torch.optim.Adam(model.parameters())
# NOTE(review): the trailing 100 and 10 are positional arguments of
# trainer.train (presumably epoch count and a logging/eval interval) -- confirm.
trainer.train(train_dataloader, val_dataloader, model, loss_fn, metric_fns, optimizer, 100, 10)

In [None]:
# Produce 'resunet_submission.csv' from whatever the global `model` currently
# refers to (the ResUnet when the cells are run in order).
utils.create_submission("test", "images",'resunet_submission.csv', model, device)

# Upgrade 1 - Using Transfer Learning for the Encoder
In the U-Net architecture, the encoder is replaced with a pretrained VGG16 model.

### Training


## Upgrade 2 - CGAN --> https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8628717
The paper uses a simple U-Net architecture. I also tried transfer learning in this part, but it did not give a better score.

## Upgrade 2.3 -- CGAN with DCED Framework --> Road Segmentation of Remotely-Sensed Images Using Deep Convolutional Neural Networks with Landscape Metrics and Conditional Random Fields
In this framework, the authors use four additional ideas:
1. Using the ELU activation function instead of ReLU
2. Using Gaussian smoothing and connected component labeling (CCL)
3. False road object removal with landscape metrics (LMs)
4. Road object sharpening with conditional random fields (CRFs)

In [None]:
import segmentation_models_pytorch as smp

In [None]:
# U-Net from segmentation_models_pytorch with an ImageNet-pretrained ResNet-50
# encoder and a single output channel. This rebinds the global `model`.
model = smp.Unet(
    encoder_name="resnet50",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                      # model output channels (number of classes in your dataset)
)

In [None]:
# Move the model's parameters to the selected device ('cuda' or 'cpu').
model = model.to(device)

In [None]:
import wandb

In [None]:
# Start a Weights & Biases run named "smp-Unet" in project "CIL-2024";
# the run config is left empty.
wandb.init(
    name="smp-Unet",
    project="CIL-2024",
    config={
    },
    group="Unet"
)
# Binary Dice loss computed directly on raw logits (from_logits=True).
loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
metric_fns = {'acc': utils.accuracy_fn, 'patch_acc': utils.patch_accuracy_fn, "f1": utils.f1_fn}
# Adam with PyTorch's default hyper-parameters.
optimizer = torch.optim.Adam(model.parameters())
# Train for 60 epochs starting at epoch 0; the wandb module is handed to
# train() (presumably for metric logging -- confirm in train.train).
train(model, optimizer, train_dataloader, val_dataloader, loss_fn, epoch_n=60, start_epoch=0, metric_fns=metric_fns, wandb=wandb)

In [None]:
import copy
import pydensecrf.densecrf as dcrf
import numpy as np
from pydensecrf.utils import create_pairwise_gaussian, create_pairwise_bilateral, unary_from_softmax


# Defining the functions for the framework
def gaussian_smoothing(kernel_size, sigma=1):
    """Sample a 2-D Gaussian density on a square integer grid.

    The grid spans ``-r .. r`` on both axes with ``r = int(kernel_size) // 2``,
    so the result has shape ``(2*r + 1, 2*r + 1)``. The values are the Gaussian
    density itself; they are not rescaled to sum to 1.

    :param kernel_size: Requested kernel width; only ``int(kernel_size) // 2`` is used.
    :param sigma: Standard deviation of the Gaussian.
    :return: 2-D array holding the sampled kernel.
    """
    radius = int(kernel_size) // 2
    grid = np.mgrid[-radius:radius + 1, -radius:radius + 1]
    rows, cols = grid[0], grid[1]
    amplitude = 1 / (2.0 * np.pi * sigma**2)
    squared_distance = rows**2 + cols**2
    return np.exp(-(squared_distance / (2.0 * sigma**2))) * amplitude

def connected_component_labeling(prediction, gaussian_filter, threshold=128):
    """Smooth a prediction, binarise it and label its connected components.

    The prediction is multiplied by 255 and cast to uint8, filtered with
    ``gaussian_filter``, thresholded at ``threshold`` and labelled with
    4-connectivity.

    :param prediction: Prediction map; it is scaled by 255 before the uint8 cast.
    :param gaussian_filter: 2-D kernel passed to ``cv2.filter2D``.
    :param threshold: Binarisation threshold applied to the smoothed uint8 map.
    :return: ``(num_labels, labels, stats, centroids)`` exactly as returned by
        ``cv2.connectedComponentsWithStats``.
    """
    scaled = np.uint8(prediction * 255)
    smoothed = cv2.filter2D(scaled, -1, gaussian_filter)
    _, binary_image = cv2.threshold(np.uint8(smoothed), threshold, 255, cv2.THRESH_BINARY)

    # Layout of `stats` is explained at
    # https://stackoverflow.com/questions/35854197/how-to-use-opencvs-connectedcomponentswithstats-in-python
    component_info = cv2.connectedComponentsWithStats(binary_image, connectivity=4)
    num_labels, labels, stats, centroids = component_info
    return num_labels, labels, stats, centroids

def calculate_shape_index(stats):
    """Return the shape index of one connected component.

    The index is ``2 * (stats[2] + stats[3]) / (4 * sqrt(stats[-1]))``. In this
    file ``stats`` is one row of the stats array from
    ``cv2.connectedComponentsWithStats``, where entries 2 and 3 are the
    bounding-box width and height and the last entry is the pixel area.

    :param stats: Sequence of per-component statistics.
    :return: The shape index as a float.
    """
    box_width, box_height, area = stats[2], stats[3], stats[-1]
    box_perimeter = 2 * (box_width + box_height)
    return box_perimeter / (4 * math.sqrt(area))

def remove_noise(image, num_labels, labels, stats, threshold=1.25, isprint=False):
    """Zero out connected components whose shape index is below ``threshold``.

    Label 0 is skipped; labels ``1 .. num_labels - 1`` are examined. The input
    is deep-copied first, so ``image`` itself is never modified.

    :param image: Array to clean.
    :param num_labels: Number of labels (label 0 included).
    :param labels: Label map used to build the per-component boolean masks.
    :param stats: Per-label statistics; ``stats[label].tolist()`` is passed to
        ``calculate_shape_index``.
    :param threshold: Components with a shape index strictly below this are removed.
    :param isprint: If true, print each label together with its shape index.
    :return: Cleaned copy of ``image``.
    """
    cleaned = copy.deepcopy(image)
    for component_id in range(1, num_labels):
        shape_index = calculate_shape_index(stats[component_id].tolist())
        if isprint:
            print(component_id, 'and', shape_index)
        if shape_index >= threshold:
            continue
        # Shape index too small: erase every pixel of this component.
        cleaned[labels == component_id] = 0
    return cleaned

# Applying gaussian blur + ccl + lm
#test_pred = test_pred.reshape(test_pred.shape[0], test_pred.shape[1], test_pred.shape[2], 1)
def apply_lm(test_pred, sigma=3, ccl_threshold=128, shape_index_threshold=1.3):
    """Run Gaussian smoothing, connected-component labelling and shape-index
    filtering over a batch of predictions.

    Each prediction is labelled with ``connected_component_labeling`` and then
    passed to ``remove_noise``, which drops components whose shape index is
    below ``shape_index_threshold``.

    :param test_pred: Batch of predictions, indexed along the first axis.
    :param sigma: Standard deviation of the Gaussian smoothing kernel.
    :param ccl_threshold: Binarisation threshold used before labelling.
    :param shape_index_threshold: Minimum shape index a component needs to be kept.
    :return: ``np.array`` of the cleaned predictions, in input order.
    """
    # The size must be computed before the kernel is built (the original read
    # it first and raised UnboundLocalError).
    filter_size = 6 * sigma + 1  # Rule of thumb: size is 6 times standard deviation
    gaussian_filter = gaussian_smoothing(filter_size, sigma)

    output = []
    for prediction in test_pred:
        # Pass the function's own `ccl_threshold`; the original referenced an
        # undefined name `threshold` here.
        num_labels, labels, stats, _ = connected_component_labeling(
            prediction, gaussian_filter, ccl_threshold
        )
        output.append(
            remove_noise(prediction, num_labels, labels, stats, shape_index_threshold)
        )
    return np.array(output)

## How to call 
#output = []
#for i, pred in enumerate(test_pred):
    #crf_result = apply_dense_crf_2(test_pred[i] --> model prediction --> logits, test_images2[i] --> original image, num_classes=2)
    #output.append(crf_result)

def apply_dense_crf(model_pred, image, num_classes=2, iterations=10, sxy_gaussian=(3, 3), compat_gaussian=3):
    """
    Refine a binary logit map with a fully connected (dense) CRF.

    The logits are turned into two-class unary energies (background / foreground),
    a Gaussian and a bilateral pairwise term are added, and the per-pixel argmax
    of the inferred distribution is returned.

    :param model_pred: Raw model logits for one image; must contain exactly
        height * width values (e.g. shape (H, W) or (1, H, W))
    :param image: The original image of shape (height, width, channels)
    :param num_classes: Number of classes; only 2 is supported because the unary
        term is built from a single logit map
    :param iterations: Number of iterations for CRF inference
    :param sxy_gaussian: Spatial kernel size for Gaussian kernel
    :param compat_gaussian: Compatibility for Gaussian kernel
    :return: Refined predictions, an integer array of shape (height, width) with
        values 0 (background) and 1 (foreground)
    :raises ValueError: If num_classes is not 2, or model_pred does not hold
        height * width values
    """
    if num_classes != 2:
        raise ValueError(
            f"apply_dense_crf builds a two-class unary term from one logit map; got num_classes={num_classes}"
        )

    height, width = image.shape[:2]
    logits = np.asarray(model_pred, dtype=np.float64).reshape(height, width)

    # The unary potential is the negative log probability. With p = sigmoid(x):
    #   -log(1 - p) = log(1 + exp(x))  and  -log(p) = log(1 + exp(-x)).
    # np.logaddexp evaluates these without overflow, so saturated logits no
    # longer produce log(0) = inf energies as the explicit sigmoid did.
    unary = np.empty((num_classes, height, width), dtype=np.float32)
    unary[0] = np.logaddexp(0.0, logits)
    unary[1] = np.logaddexp(0.0, -logits)
    # pydensecrf requires a C-contiguous float32 array of shape (classes, pixels).
    unary = np.ascontiguousarray(unary.reshape((num_classes, -1)))

    d = dcrf.DenseCRF2D(width, height, num_classes)
    d.setUnaryEnergy(unary)

    # Add pairwise Gaussian
    d.addPairwiseGaussian(sxy=sxy_gaussian, compat=compat_gaussian, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC)

    # float32 images are taken to be in [0, 1] and rescaled; anything else is cast as-is.
    # NOTE(review): a float64 image in [0, 1] would be cast to all zeros here -- confirm callers pass float32.
    image_uint8 = (image * 255).astype(np.uint8) if image.dtype == np.float32 else image.astype(np.uint8)
    # pydensecrf rejects non-contiguous images (e.g. transposed views).
    image_uint8 = np.ascontiguousarray(image_uint8)

    # Add pairwise Bilateral
    d.addPairwiseBilateral(sxy=(80, 80), srgb=(13, 13, 13), rgbim=image_uint8, compat=10, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC)

    # Perform inference
    Q = d.inference(iterations)
    result = np.argmax(Q, axis=0).reshape((height, width))

    return result

In [None]:
# Predict every validation sample and store the thresholded network output
# (img1) next to its DenseCRF-refined version (img2).
# The buffers follow the dataset's resize target instead of a hard-coded 384.
img1 = np.zeros((len(val_dataset), params.RESIZE, params.RESIZE), np.uint8)
img2 = np.zeros((len(val_dataset), params.RESIZE, params.RESIZE), np.uint8)

# Use an unshuffled loader so that sample number c of the loop is val_images[c]
# / val_dataset[c]; with a shuffled loader the CRF would be guided by the wrong
# image and the stored results would not line up with the dataset order.
eval_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=params.BATCH_SIZE, shuffle=False)

# Inference only: fix batch-norm/dropout behaviour and skip autograd bookkeeping.
model.eval()
c = 0
with torch.no_grad():
    for img in eval_dataloader:
        y_hat = model(img[0])
        for i in range(y_hat.shape[0]):
            y_hat_i = y_hat[i].detach().cpu().numpy()

            # The CRF needs the guidance image at the prediction's resolution.
            image_i = val_images[c]
            pred_h, pred_w = y_hat_i.shape[-2:]
            if image_i.shape[:2] != (pred_h, pred_w):
                image_i = cv2.resize(image_i, (pred_w, pred_h))  # dsize is (width, height)
            image_i = np.ascontiguousarray(image_i)

            # NOTE(review): y_hat_i holds raw logits (apply_dense_crf applies the
            # sigmoid itself); confirm params.CUTOFF is meant for logits and not
            # for probabilities.
            img1[c] = (y_hat_i >= params.CUTOFF).astype(np.uint8) * 255
            img2[c] = apply_dense_crf(y_hat_i, image_i)
            c += 1
            print(c)

In [None]:
# Show the first n validation samples as a 4-row grid:
# input image, ground-truth mask, thresholded output, CRF-refined output.
n = 5
fig, axs = plt.subplots(4, n, figsize=(18.5, 12))
row_titles = ('input', 'true mask', 'output', 'crf')

for col in range(n):
    # Dataset tensors are channel-first; imshow wants channels last.
    panels = (
        np.moveaxis(val_dataset[col][0].detach().cpu().numpy(), 0, -1),
        np.moveaxis(val_dataset[col][1].detach().cpu().numpy(), 0, -1),
        img1[col],
        img2[col],
    )
    for ax, panel, label in zip(axs[:, col], panels, row_titles):
        ax.imshow(panel)
        ax.set_title(f'{label} {col}')
        ax.set_axis_off()