# Watershed optimization

## Imports and functions

In [None]:
import glob
import os
import time

import geopandas as gpd
import laspy as lp
import matplotlib.pyplot as plt
import napari
import numpy as np
import rasterio
import scipy
import scipy.ndimage as ndimage
import scipy.ndimage.filters as filters
import tifffile as tiff
from scipy import interpolate
from scipy.spatial import cKDTree as kdtree
from skimage.color import label2rgb
from skimage.measure import regionprops
from skimage.segmentation import watershed
from skimage.transform import resize


def _max_height_grid(xyz, origin, shape, gridsize):
    """Rasterize points onto a grid, keeping the highest z per cell.

    ``xyz`` is an (N, 3) array of x, y, z. Cells are indexed ``[x_cell, y_cell]``
    relative to ``origin`` (the grid's minimum x/y). Cells that receive no point
    stay NaN.
    """
    cells = ((xyz[:, :2] - origin) / gridsize).astype(int)
    grid = np.full(shape, np.nan)
    # fmax ignores NaN, so an empty cell takes its first height and then keeps
    # the running maximum; ufunc.at applies every point, including repeated cells.
    np.fmax.at(grid, (cells[:, 0], cells[:, 1]), xyz[:, 2])
    return grid


def las2chm(las_file, gridsize=1.0):
    """Build a canopy height model (CHM) raster from a LAS point cloud.

    Class-5 returns (high vegetation in the ASPRS LAS scheme) that are not the
    last return of their pulse are de-noised with a nearest-neighbour distance
    filter and then densified with nearby original points. The highest point per
    cell forms the DSM. Class-2 (ground) points form a DTM whose empty cells are
    filled by linear interpolation. The CHM is ``DSM - DTM``, with 0 wherever
    either surface is undefined.

    Both surfaces are rasterized on one grid anchored at the minimum x/y of the
    whole point cloud, so the subtraction compares the same ground location.

    Parameters
    ----------
    las_file : str
        Path to a point-cloud file readable by ``laspy.read``.
    gridsize : float, optional
        Cell size in point-coordinate units (1.0, i.e. 1 m, by default).

    Returns
    -------
    numpy.ndarray
        2-D float array indexed ``[x_cell, y_cell]``.

    Raises
    ------
    ValueError
        If the file contains no class-2 (ground) points.
    """
    las = lp.read(las_file)
    all_xyz = np.asarray(las.xyz)
    classification = np.asarray(las.classification)
    return_num = np.asarray(las.return_number)
    num_of_returns = np.asarray(las.number_of_returns)

    # Class-5 returns that are not the last return of their pulse. Only pulses
    # with 1-5 returns have their last return recognised (as originally coded).
    is_last = (return_num == num_of_returns) & (return_num >= 1) & (return_num <= 5)
    veg = all_xyz[(classification == 5) & ~is_last]

    if len(veg):
        # Drop isolated returns: keep points whose 25th neighbour is within 2 m.
        dist, _ = kdtree(veg).query(veg, k=25, workers=-1)
        veg = veg[dist[:, -1] <= 2.0]
    if len(veg):
        # Densify: where all 10 nearest original points (any class) lie within
        # 0.5 m of a kept point, add those neighbours as well. Duplicates are
        # harmless because only the per-cell maximum is used.
        dist, idx = kdtree(all_xyz).query(veg, k=10, workers=-1)
        neighbours = np.unique(idx[dist[:, -1] < 0.5])
        veg = np.vstack((veg, all_xyz[neighbours]))

    ground = all_xyz[classification == 2]
    if not len(ground):
        raise ValueError(f"{las_file}: no ground (class 2) points to build a DTM")

    # A single grid for DTM and DSM spanning the whole cloud, so DSM - DTM
    # subtracts matching cells and the shapes always agree.
    origin = all_xyz[:, :2].min(axis=0)
    extent = (all_xyz[:, :2].max(axis=0) - origin) / gridsize
    shape = tuple(int(n) + 1 for n in extent.astype(int))

    dtm = _max_height_grid(ground, origin, shape, gridsize)
    known = ~np.isnan(dtm)
    # indexing="ij" keeps the coordinate arrays in the grid's [x, y] layout.
    cell_x, cell_y = np.meshgrid(
        np.arange(shape[0]), np.arange(shape[1]), indexing="ij"
    )
    dtm = interpolate.griddata(
        (cell_x[known], cell_y[known]), dtm[known], (cell_x, cell_y), method="linear"
    )

    dsm = _max_height_grid(veg, origin, shape, gridsize)
    chm = dsm - dtm
    # Cells without vegetation, or outside the ground points' hull (NaN from
    # griddata), get height 0.
    chm[np.isnan(chm)] = 0
    return chm


def ws_labels(CHM, ns=4, thr=3, min_height=2, max_height=40):
    """Segment individual crowns in a canopy height model with a marker watershed.

    Parameters
    ----------
    CHM : numpy.ndarray
        2-D canopy height model (e.g. from ``las2chm``). Cells equal to 0 never
        seed a marker.
    ns : int, optional
        Window size of the maximum filter used to find local maxima.
    thr : float, optional
        Sigma of the Gaussian smoothing applied before peak detection.
    min_height, max_height : float, optional
        Only cells with ``min_height <= CHM <= max_height`` (holes filled) are
        labelled. The defaults are the previously hard-coded 2 and 40.

    Returns
    -------
    numpy.ndarray
        Integer label image shaped like ``CHM``; 0 outside the height mask.
    """
    chm_smooth = ndimage.gaussian_filter(CHM, thr)
    # A cell seeds a marker when it equals the maximum of its ns-window.
    local_maxima = chm_smooth == ndimage.maximum_filter(chm_smooth, ns)
    local_maxima[CHM == 0] = False
    markers, _ = ndimage.label(local_maxima)

    mask = ndimage.binary_fill_holes((CHM >= min_height) & (CHM <= max_height))
    # Flood the inverted CHM so that basins grow downhill from the crown tops.
    return watershed(-CHM, markers, mask=mask)


def filter_labels(labels, img, channel, area, ecc, ar, abr, intensity):
    """Keep only label regions whose shape and brightness pass the thresholds.

    NOTE(review): this function's header was lost from the notebook. The
    signature is reconstructed from its call sites, ``filter_labels(labels,
    img, 3, *params)``, where params are ordered ``[area, ecc, ar, abr,
    intensity]``. ``channel`` (3 at every call site) is assumed to select the
    band of ``img`` used for the mean intensity; confirm against the original.

    A region is kept when all of the following hold:
      * pixel area >= ``area``
      * minor / major axis length >= ``ar``
      * eccentricity <= ``ecc``
      * area / bounding-box area >= ``abr``
      * mean intensity in the selected band >= ``intensity``

    Returns
    -------
    numpy.ndarray
        Array like ``labels`` with rejected regions set to 0 and kept regions
        retaining their original label value.
    """
    labels = np.asarray(labels)
    band = img[..., channel] if np.ndim(img) == 3 else img
    filtered_labels = np.zeros_like(labels)
    for region in regionprops(labels, intensity_image=band):
        major = region.axis_major_length
        if (
            region.area >= area
            # Degenerate (e.g. single-pixel) regions have a zero major axis.
            and major > 0
            and region.axis_minor_length / major >= ar
            and region.eccentricity <= ecc
            and region.area / region.area_bbox >= abr
            and region.intensity_mean >= intensity
        ):
            filtered_labels[region.coords[:, 0], region.coords[:, 1]] = region.label
    return filtered_labels

# LAS -> CHM -> loose and strict labels

In [None]:
data_dir = "../../../data/"
las_dir = f"{data_dir}las/"
# Only point-cloud files; os.listdir also returns unrelated directory entries.
las_files = sorted(
    f for f in os.listdir(las_dir) if os.path.splitext(f)[1].lower() in (".las", ".laz")
)
out_dir = f"{data_dir}watershed/"

# CHM / watershed params
ns = 4
thr = 3

# Label filter params, ordered [area, ecc, ar, abr, intensity]
loose_params = [40, 0.95, 0.1, 0.3, 80]
strict_params = [55, 0.8, 0.5, 0.5, 115]
# "loose" or "strict": selects the parameter set AND names the output file,
# so the two can no longer disagree.
keyword = "strict"
label_params = {"loose": loose_params, "strict": strict_params}[keyword]
label_dtype = np.dtype(int)

for f in las_files:
    print("Reading files...")
    name = os.path.splitext(f)[0]
    img_path = f"{data_dir}watershed/RGBI_{name}.tif"
    img = tiff.imread(img_path)

    print("Creating CHM...")
    chm = las2chm(las_dir + f)

    print("Generating labels...")
    labels = ws_labels(chm, ns, thr)
    # Rotate the CHM-derived labels, presumably into the image's row/column
    # orientation, before resampling them to the image size.
    labels = np.rot90(labels, 1)

    print("Filtering regions...")
    labels = resize(labels, (img.shape[0], img.shape[1]), order=0)
    filtered_labels = filter_labels(labels, img, 3, *label_params)

    print("Writing tif...")
    # Copy the georeferencing of the source image; both files are closed on exit.
    with rasterio.open(img_path) as raster, rasterio.open(
        f"{out_dir}{name}_{keyword}_labels.tif",
        "w",
        driver="GTiff",
        height=raster.height,
        width=raster.width,
        count=1,
        dtype=label_dtype,
        crs=raster.crs,
        transform=raster.transform,
    ) as new_dataset:
        new_dataset.write(filtered_labels.astype(label_dtype), 1)

## Explore the labels

In [None]:
img_dir = "../../../data/watershed/"
label_dir = "../../../data/watershed_imp/"
viewer = napari.Viewer()

# Show the Friedrichshain RGBI image with its label raster overlaid on top.
img = tiff.imread(img_dir + "RGBI_Friedrichshain.tif")
labels = tiff.imread(label_dir + "Friedrichshain_labels.tif")
viewer.add_image(img)
viewer.add_labels(labels)

## Grid search for optimal params

In [None]:
# Thresholds searched for the region filter. The loose params
# [ecc=0.95, ar=0.1, abr=0.3, intensity=80] and the strict params
# [ecc=0.8, ar=0.5, abr=0.5, intensity=115] are corners of this grid.
# nsx = [3, 4, 5, 6]
# thrx = [1.5, 2, 2.5, 3, 3.5]
ns = 4
thr = 3

area = 40
eccx = [0.8, 0.95]
arx = [0.1, 0.5]
abrx = [0.3, 0.5]
intx = [80, 115]

las_dir = "../../../data/las/"
las_files = sorted(
    f for f in os.listdir(las_dir) if os.path.splitext(f)[1].lower() in (".las", ".laz")
)
# Pair the LAS tile with its own RGBI image (same naming as the labelling pipeline).
name = os.path.splitext(las_files[0])[0]

chm = las2chm(las_dir + las_files[0])
viewer = napari.Viewer()
img = tiff.imread(f"../../../data/watershed/RGBI_{name}.tif")
viewer.add_image(img)

# The watershed depends only on the CHM, so segment once and reuse it below.
labels = ws_labels(chm, ns, thr)
labels = np.rot90(labels, 1)  # same orientation step as the labelling pipeline
labels = resize(labels, (img.shape[0], img.shape[1]), order=0)

for ecc in eccx:
    for ar in arx:
        for abr in abrx:
            for intensity in intx:
                # Filter with THIS combination's thresholds, not a fixed set.
                filtered_labels = filter_labels(
                    labels, img, 3, area, ecc, ar, abr, intensity
                )
                viewer.add_labels(
                    filtered_labels,
                    name=f"ar={ar}_ecc={ecc}_abr={abr}_int={intensity}",
                    blending="opaque",
                )

## Mask bad automatic labels: pixels covered by a loose label but not by a strict label are replaced with the image's per-channel mean.

In [None]:
bad_labels = False
if bad_labels:
    data_dir = "../../../data/watershed/"
    # Exclude "*.masked.tif" files written by earlier runs into the same folder;
    # otherwise a rerun would shift the image/label pairing below.
    rgb_fn = sorted(
        fn for fn in glob.glob(f"{data_dir}rgbi/*.tif") if not fn.endswith(".masked.tif")
    )
    lab_l_fn = sorted(glob.glob(f"{data_dir}labels/loose/*.tif"))
    lab_s_fn = sorted(glob.glob(f"{data_dir}labels/strict/*.tif"))
    # Files are paired purely by sorted position, so the counts must agree.
    if not len(rgb_fn) == len(lab_l_fn) == len(lab_s_fn):
        raise ValueError(
            f"file counts differ: {len(rgb_fn)} RGBI, "
            f"{len(lab_l_fn)} loose, {len(lab_s_fn)} strict"
        )

    for rgb_path, loose_path, strict_path in zip(rgb_fn, lab_l_fn, lab_s_fn):
        rgb = tiff.imread(rgb_path)
        lab_l = tiff.imread(loose_path)
        lab_s = tiff.imread(strict_path)
        # Channel means of the unmodified image.
        rgb_mns = [rgb[..., c].mean() for c in range(rgb.shape[2])]
        # "Bad" pixels: covered by a loose label but by no strict label.
        bad = (lab_l > 0) & (lab_s == 0)
        for c, m in enumerate(rgb_mns):
            rgb[bad, c] = m
        # "<stem>.tif" -> "<stem>.masked.tif" next to the input.
        tiff.imwrite(os.path.splitext(rgb_path)[0] + ".masked.tif", rgb)

## Create "ignore" labels (-1 where a loose label exists but no strict label does)

In [None]:
import os
import glob
import tifffile as tiff
import numpy as np

ignore = False
if ignore:
    data_dir = "../../../data/watershed/"
    masked_dir = os.path.join(data_dir, "labels/masked")
    lab_l_fn = sorted(glob.glob(f"{data_dir}labels/loose/*.tif"))
    lab_s_fn = sorted(glob.glob(f"{data_dir}labels/strict/*.tif"))
    # Loose and strict files are paired by sorted position.
    if len(lab_l_fn) != len(lab_s_fn):
        raise ValueError(
            f"{len(lab_l_fn)} loose vs {len(lab_s_fn)} strict label files"
        )

    os.makedirs(masked_dir, exist_ok=True)

    for loose_fn, strict_fn in zip(lab_l_fn, lab_s_fn):
        loose = tiff.imread(loose_fn)
        strict = tiff.imread(strict_fn)
        # -1 marks "ignore" pixels, so the output needs a signed dtype; an
        # unsigned label raster would otherwise wrap or reject -1.
        masked = strict.astype(np.promote_types(strict.dtype, np.int8))
        masked[(loose > 0) & (strict == 0)] = -1
        # The labelling pipeline writes "<name>_<keyword>_labels.tif"; drop the
        # last two "_" parts so names that contain "_" stay intact.
        stem = os.path.splitext(os.path.basename(loose_fn))[0]
        name = stem.rsplit("_", 2)[0]
        tiff.imwrite(os.path.join(masked_dir, f"{name}_masked_labels.tif"), masked)

In [None]:
data_dir = "../../../data/watershed/"
# glob order is arbitrary; sort so that the index below picks the same file every run.
labels = sorted(glob.glob(f"{data_dir}labels/masked/*.tif"))
label = tiff.imread(labels[1])

In [None]:
# Open a new viewer and display the label raster loaded in the previous cell.
viewer = napari.Viewer()
viewer.add_labels(data=label)

## Patchify watershed data (labels of the chosen type: loose, strict, or masked; optionally the RGBI images)

In [None]:
do_it = False

if do_it:
    import os
    import glob
    import tifffile as tiff
    from patchify import patchify

    patch_size = 512
    label_type = "masked"

    data_dir = "../../../data/watershed/"
    # Unpatchified directories. RGBI patching is disabled (None) for this run;
    # set e.g. os.path.join(data_dir, f"rgbi/{label_type}/") to enable it.
    unpatched_watershed_rgbi_dir = None
    unpatched_watershed_label_dir = os.path.join(data_dir, f"labels/{label_type}")

    # Patchified directories (RGBI destination is likewise optional).
    patched_watershed_rgbi_dest = None
    patched_watershed_label_dest = os.path.join(
        data_dir, f"labels/{label_type}/{patch_size}/"
    )

    def patchify_watershed(rgbi_loc, lab_loc, rgbi_dest, lab_dest, patch_size):
        """Cut label rasters (and optionally RGBI images) into square tiles.

        Tiles are non-overlapping (step == patch_size) and written as
        ``<dest>/<stem>_<row>_<col>.tif``. Images are processed only when both
        ``rgbi_loc`` and ``rgbi_dest`` are given; they are paired with the
        labels by sorted file name, so the two folders must hold the same number
        of ``*.tif`` files.
        """
        do_rgbi = rgbi_loc is not None and rgbi_dest is not None
        label_files = sorted(glob.glob(os.path.join(lab_loc, "*.tif")))
        rgbi_files = sorted(glob.glob(os.path.join(rgbi_loc, "*.tif"))) if do_rgbi else []
        if do_rgbi and len(rgbi_files) != len(label_files):
            raise ValueError(
                f"{len(rgbi_files)} RGBI vs {len(label_files)} label files"
            )

        os.makedirs(lab_dest, exist_ok=True)
        if do_rgbi:
            os.makedirs(rgbi_dest, exist_ok=True)

        for k, label_path in enumerate(label_files):
            label_name = os.path.splitext(os.path.basename(label_path))[0]
            patches_label = patchify(
                tiff.imread(label_path),
                (patch_size, patch_size),
                step=patch_size,
            )
            for i in range(patches_label.shape[0]):
                for j in range(patches_label.shape[1]):
                    tiff.imwrite(
                        os.path.join(lab_dest, f"{label_name}_{i}_{j}.tif"),
                        patches_label[i, j],
                    )

            if not do_rgbi:
                continue
            rgbi_path = rgbi_files[k]
            rgbi_name = os.path.splitext(os.path.basename(rgbi_path))[0]
            rgbi = tiff.imread(rgbi_path)
            # A window spanning all channels yields a singleton third axis.
            patches_rgbi = patchify(
                rgbi, (patch_size, patch_size, rgbi.shape[-1]), step=patch_size
            )
            for i in range(patches_rgbi.shape[0]):
                for j in range(patches_rgbi.shape[1]):
                    tiff.imwrite(
                        os.path.join(rgbi_dest, f"{rgbi_name}_{i}_{j}.tif"),
                        patches_rgbi[i, j, 0],
                    )

    patchify_watershed(
        unpatched_watershed_rgbi_dir,
        unpatched_watershed_label_dir,
        patched_watershed_rgbi_dest,
        patched_watershed_label_dest,
        patch_size,
    )

## Explore labels and region properties in Plotly

In [None]:
# img = tiff.imread("../../../data/watershed/RGBI_Friedrichshain.tif")
# labels = ws_labels(chm, 4, 3)
# labels = resize(labels, (img.shape[0], img.shape[1]), order=0)
# labels = np.rot90(labels, 1)
# regions = regionprops(labels, img)

# import plotly
# import plotly.express as px
# import plotly.graph_objects as go
# from skimage import data, filters, measure, morphology

# dim = 1000
# img = img[0:dim, 0:dim, 1]
# labels = labels[0:dim, 0:dim]

# fig = px.imshow(img, binary_string=True, width=dim, height=dim)
# fig.update_traces(hoverinfo="skip")  # hover is only for label info


# props = measure.regionprops(labels, img)
# properties = ["area", "eccentricity", "perimeter", "intensity_mean"]

# # For each label, add a filled scatter trace for its contour,
# # and display the properties of the label in the hover of this trace.
# for index in range(0, len(np.unique(labels)) - 1):
#     label_i = props[index].label
#     contour = measure.find_contours(labels == label_i, 0.5)[0]
#     y, x = contour.T
#     hoverinfo = ""
#     for prop_name in properties:
#         hoverinfo += f"<b>{prop_name}: {np.mean(getattr(props[index], prop_name)):.2f}</b><br>"
#     fig.add_trace(
#         go.Scatter(
#             x=x,
#             y=y,
#             name=label_i,
#             mode="lines",
#             fill="toself",
#             showlegend=False,
#             hovertemplate=hoverinfo,
#             hoveron="points+fills",
#         )
#     )

# plotly.io.show(fig)