## Baseline - Otsu thresholding

In [1]:
import os
from pathlib import Path
from torch.utils.data import DataLoader

import cv2 as cv
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.pyplot import figure
from sklearn.metrics import precision_score, recall_score, f1_score, jaccard_score
import torch
from typing import List, Tuple

In [None]:
# Step the working directory up one level and echo the result.
# NOTE: this is relative to the *current* directory, so every re-run of this
# cell climbs one more level.
cwd = Path.cwd()
os.chdir(cwd.parent)
print(os.getcwd())

In [3]:
from floods.datasets.flood import FloodDataset
from floods.prepare import eval_transforms, inverse_transform
from tqdm import tqdm

In [4]:
# Dataset class under evaluation; everything below is derived from it.
dataset_cls = FloodDataset

# Keep only the first two normalisation statistics (the two image channels that
# `preprocess_images` later treats as VV and VH).
mean = dataset_cls.mean()[:2]
std = dataset_cls.std()[:2]

# Evaluation-time normalisation with value clipping to [-30, 30].
test_transform = eval_transforms(mean=mean, std=std, clip_max=30, clip_min=-30)

# Test split, built without the DEM channel.
test_dataset = dataset_cls(
    path=Path('/mnt/data1/projects/shelter/flood/tiled'),
    subset="test",
    include_dem=False,
    normalization=test_transform,
)

# One sample per batch, in dataset order.
loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=False)

# NOTE(review): the inverse transform is built from the *full* statistics, while
# the forward transform above uses only the first two entries — confirm intent.
invert = inverse_transform(mean=dataset_cls.mean(), std=dataset_cls.std())

In [5]:
def plot_images(output: List[np.array], titles: List[str], figsize: Tuple[int, int] = (10, 10)) -> None:
    """Plot a list of images with titles, side by side in a single row.

    Args:
        output: images to show; each one is drawn with the 'gray' colormap.
        titles: titles matched to ``output`` by position (needs at least as
            many entries as ``output``).
        figsize: figure size forwarded to matplotlib's ``figure``.
    """
    figure(figsize=figsize, dpi=80)
    n_images = len(output)
    for idx, img in enumerate(output):
        plt.subplot(1, n_images, idx + 1)
        plt.imshow(img, 'gray')
        plt.title(titles[idx])
        # Hide the tick marks on both axes: pixel coordinates add nothing here.
        plt.xticks([])
        plt.yticks([])
    plt.show()


In [19]:
def minmaxscaler_image(x: np.array, xmin: float, xmax: float, scale: float = 1.0, dtype=np.float32) -> np.array:
    """Min-max scale ``x`` to ``[0, scale]`` given the bounds ``xmin`` and ``xmax``.

    Args:
        x: array to rescale; it is not modified in place.
        xmin: value mapped to 0.
        xmax: value mapped to ``scale``.
        scale: multiplier applied after normalising to [0, 1]; pass ``None``
            to skip the multiplication.
        dtype: dtype of the returned array.

    Returns:
        The rescaled array cast to ``dtype``. When ``xmax == xmin`` (e.g. a
        constant image) an all-zero array is returned instead of the NaN/inf
        values a division by zero would produce.
    """
    value_range = xmax - xmin
    # Degenerate range: dividing would give 0/0 = NaN, which silently poisons
    # every later threshold comparison. Only scalar bounds are guarded here;
    # array-valued bounds keep the plain broadcasting behaviour.
    if np.ndim(value_range) == 0 and value_range == 0:
        return np.zeros_like(x, dtype=dtype)
    x = (x - xmin) / value_range
    if scale is not None:
        x = x * scale
    return x.astype(dtype)

In [30]:
def preprocess_images(images: torch.Tensor, scale: float = 1.0) -> Tuple[np.array, np.array]:
    """Split a multi-channel tensor into two independently min-max scaled arrays.

    Args:
        images: tensor indexed as ``[channel, row, col]``; channel 0 is
            treated as VV and channel 1 as VH.
        scale: upper bound of the output range, forwarded to
            ``minmaxscaler_image`` (e.g. 255.0 for an 8-bit-like range).

    Returns:
        ``(vv_img, vh_img)``: each channel rescaled using its own minimum and
        maximum, so the two outputs are not on a shared scale.
    """
    images = images.numpy()

    vv_img = images[0, :, :]
    vh_img = images[1, :, :]

    # Each channel is normalised against its own per-image extrema.
    vv_img = minmaxscaler_image(vv_img, vv_img.min(), vv_img.max(), scale=scale)
    vh_img = minmaxscaler_image(vh_img, vh_img.min(), vh_img.max(), scale=scale)

    return vv_img, vh_img

In [None]:
from skimage.filters import threshold_otsu
from skimage.restoration.non_local_means import denoise_nl_means
from skimage.morphology import opening

# Single-image demo: denoise the VH channel of one test sample, threshold it
# with Otsu's method and show the image, its histogram and the binary result.
images, label = test_dataset[124]
label = label.squeeze(0)
# Pixels whose label differs from 255 (presumably 255 = ignore/no-data — confirm).
# NOTE: `mask` is computed but not used anywhere else in this cell.
mask = (label.flatten() != 255)
images = images.squeeze(0)
# Rescale both channels to [0, 255]; only `vh` is used below.
vv, vh = preprocess_images(images, scale=255.0)

image = vh
# Non-local-means denoising; preserve_range=True keeps the [0, 255] value range.
image = denoise_nl_means(image, preserve_range=True)
# automated threshold definition is not ideal
thresh = threshold_otsu(image)
# Pixels darker than the threshold become True.
binary = image < thresh
# Morphological opening of the binary map.
binary = opening(binary)

fig, axes = plt.subplots(ncols=3, figsize=(24, 7.5))
ax = axes.ravel()
# Re-request the three subplot slots; the third one shares its x/y axes with the first.
ax[0] = plt.subplot(1, 3, 1)
ax[1] = plt.subplot(1, 3, 2)
ax[2] = plt.subplot(1, 3, 3, sharex=ax[0], sharey=ax[0])

# Note: the panel titled 'Original' shows the *denoised* VH image.
ax[0].imshow(image, cmap=plt.cm.gray)
ax[0].set_title('Original')
ax[0].axis('off')

# Intensity histogram with the Otsu threshold marked by a red vertical line.
ax[1].hist(image.ravel(), bins=256)
ax[1].set_title(f'Histogram: {thresh}')
ax[1].axvline(thresh, color='r')

ax[2].imshow(binary, cmap=plt.cm.gray)
ax[2].set_title('Thresholded')
ax[2].axis('off')

plt.show()

In [None]:
# Visual comparison of several thresholding variants on the first five test samples.
titles = ['mask', 'vv', 'vh', 'pred_mean_vv_vh', 'pred_gaussian_vv_vh', 'pred_otsu_vv_vh', 'pred_otsu_gaussian_vv_vh']

# NOTE(review): `compute_multiple_thresholds` is neither defined nor imported in
# the visible part of this notebook — confirm where it comes from before running.

# tqdm wraps the loader itself (not `enumerate(...)`) so the progress bar can
# read the loader's length and display a total.
for i, (images, label) in enumerate(tqdm(loader)):

    # Drop the batch dimension added by the DataLoader (batch_size=1).
    label = label.squeeze(0)
    # Pixels whose label differs from 255; not used further in this cell.
    mask = (label.flatten() != 255)
    images = images.squeeze(0)

    vv_img, vh_img = preprocess_images(images)
    pred_mean, pred_gauss, pred_otsu, pred_otsu_gauss = compute_multiple_thresholds(vv_img, vh_img)

    # Order must match `titles`; the first panel ('mask') shows the label.
    output = [label, vv_img, vh_img, pred_mean, pred_gauss, pred_otsu, pred_otsu_gauss]

    plot_images(output, titles, (16, 20))

    # Stop after five samples (i = 0..4).
    if i == 4:
        break


In [None]:
# Evaluate the fixed-threshold baseline on the whole test set: one score per
# image for every metric, averaged over images at the end.

# manually fixed after evaluating on a batch of images
# (applied to images rescaled to [0, 255] below)
thresh = 4

n_samples = len(test_dataset)
prec_scores = np.zeros(n_samples)
recall_scores = np.zeros(n_samples)
f1_scores = np.zeros(n_samples)
bg_f1_scores = np.zeros(n_samples)
iou_scores = np.zeros(n_samples)
bg_iou_scores = np.zeros(n_samples)

# plt.imsave does not create missing directories, so make sure the target exists.
os.makedirs("outputs/otsu", exist_ok=True)

# tqdm wraps the loader itself (not `enumerate(...)`) so the progress bar can
# read the loader's length and display a total.
for i, (images, label) in enumerate(tqdm(loader)):

    # Drop the batch dimension added by the DataLoader (batch_size=1).
    label = label.squeeze(0)
    # Keep only pixels whose label differs from 255
    # (presumably 255 = ignore/no-data — confirm).
    mask = (label.flatten() != 255)
    images = images.squeeze(0)

    vv, vh = preprocess_images(images, scale=255.0)

    image_vh = denoise_nl_means(vh, preserve_range=True)
    image_vv = denoise_nl_means(vv, preserve_range=True)

    # A pixel is predicted positive only if it is below the threshold in BOTH channels.
    binary_vh = image_vh < thresh
    binary_vv = image_vv < thresh
    binary = np.logical_and(binary_vh, binary_vv).astype(np.uint8)
    binary = opening(binary)

    # Save the prediction as a black/white PNG, named by sample index.
    to_save = (binary * 255).astype(np.uint8)
    plt.imsave(f"outputs/otsu/{i}.png",
               to_save,
               cmap='gray')

    # Score only the pixels selected by `mask`.
    label = label.flatten()[mask]
    binary = binary.flatten()[mask]

    prec_scores[i] = precision_score(y_true=label, y_pred=binary)
    recall_scores[i] = recall_score(y_true=label, y_pred=binary)
    f1_scores[i] = f1_score(y_true=label, y_pred=binary)
    # pos_label=0 scores the background class instead of the positive class.
    bg_f1_scores[i] = f1_score(y_true=label, y_pred=binary, pos_label=0)
    iou_scores[i] = jaccard_score(y_true=label, y_pred=binary)
    bg_iou_scores[i] = jaccard_score(y_true=label, y_pred=binary, pos_label=0)


print(f"precision: {np.mean(prec_scores)}")
print(f"recall: {np.mean(recall_scores)}")
print(f"f1: {np.mean(f1_scores)}")
print(f"bg_f1: {np.mean(bg_f1_scores)}")
print(f"iou: {np.mean(iou_scores)}")
print(f"bg_iou: {np.mean(bg_iou_scores)}")