Author: Daniel Lusk, University of Potsdam

Inspired by: Ankit Kariryaa ([github repo](https://github.com/ankitkariryaa/An-unexpectedly-large-count-of-trees-in-the-western-Sahara-and-Sahel))

### Overview

TODO: Write overview

### Getting started

TODO: Write getting started

In [2]:
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import tifffile as tiff
import matplotlib as mpl

from config import Preprocessing
from skimage.measure import label, regionprops
from skimage.morphology import erosion
from skimage.segmentation import find_boundaries
from scipy.ndimage import distance_transform_edt
from tqdm import tqdm_notebook as tqdm
from matplotlib.colors import ListedColormap
from core.utils import mask_bg

import warnings  # ignore annoying warnings
warnings.filterwarnings("ignore")

# Magic commands
%matplotlib inline
%reload_ext autoreload
%autoreload 2

plt.style.use("lusk") # Use custom plot styles

# from IPython.core.interactiveshell import InteractiveShell
# InteractiveShell.ast_node_interactivity = "all"

2023-04-29 14:45:25.143791: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-04-29 14:45:25.694045: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/opt/miniconda/envs/berlin-trees/lib/
2023-04-29 14:45:25.694101: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/opt/miniconda/envs/berlin-trees/lib/


Load configuration and get image directories

In [3]:
# Instantiate the project-local preprocessing configuration; later cells read
# directory names, suffixes and the image file extension from it.
config = Preprocessing.Config()
# Every entry directly under the configured data directory (glob order is
# arbitrary). Presumably one sub-directory per image patch -- TODO confirm.
im_dirs = glob.glob(os.path.join(config.data_dir_path, "*"))

Calculate and write NDVI

In [None]:
def calculate_and_write_ndvi(d):
    """Compute the NDVI layer for one image directory and write it to disk.

    Reads the first RGBI image matching ``{config.rgbi_dn}/*{config.image_file_ext}``
    inside ``d``, computes ``(NIR - red) / (NIR + red)`` using the first and last
    channels, and writes the result to the ``config.ndvi_dn`` sub-directory as
    ``{im_id}_{config.ndvi_suffix}{config.image_file_ext}``.

    Pixels where ``NIR + red == 0`` are written as 0 instead of NaN.

    Args:
        d (str): Path of image directory

    Raises:
        IndexError: If no RGBI image is found in ``d``.
    """
    im_id = os.path.basename(d)
    rgbi = tiff.imread(
        glob.glob(os.path.join(d, f"{config.rgbi_dn}/*{config.image_file_ext}"))[0]
    )

    # Cast to float before subtracting so unsigned integer inputs cannot wrap
    # around. No rescaling is needed: any common scale factor cancels in the ratio.
    red = rgbi[..., 0].astype(float)
    nir = rgbi[..., -1].astype(float)

    # Calculate NDVI, guarding against a zero denominator (e.g. fully black
    # pixels), which would otherwise silently produce NaN.
    denom = nir + red
    ndvi = np.divide(nir - red, denom, out=np.zeros_like(denom), where=denom != 0)

    # Write NDVI to disk
    ndvi_dir_path = os.path.join(d, config.ndvi_dn)
    os.makedirs(ndvi_dir_path, exist_ok=True)
    tiff.imwrite(
        os.path.join(
            ndvi_dir_path, f"{im_id}_{config.ndvi_suffix}{config.image_file_ext}"
        ),
        ndvi,
    )
    

# Compute and write the NDVI layer for every image directory.
for im_dir in tqdm(im_dirs):
    calculate_and_write_ndvi(im_dir)

Load labels and erode each tree by 1 px, optionally converting the result to a binary mask (this may take a while)

In [None]:
def erode_labels(d, bool_mask=False):
    """Erode every labelled tree in an image directory's label image.

    Locates the first label image matching
    ``{config.label_dn}/*{config.image_file_ext}`` inside ``d``, relabels it so
    that each connected region has its own id, erodes each region separately
    (``skimage.morphology.erosion`` with its default footprint) and returns the
    eroded regions either with their ids or as a boolean mask.

    Args:
        d (str): Path of image directory
        bool_mask (bool): If True, return ``eroded_labels > 0`` instead of the
            labelled image. Defaults to False.

    Returns:
        np.array: Eroded labels with ids or as boolean mask. A 2D array of shape
        (labels.height, labels.width)

    Raises:
        IndexError: If no label image is found in ``d``.
    """
    labels = tiff.imread(
        glob.glob(os.path.join(d, config.label_dn, f"*{config.image_file_ext}"))[0]
    )
    labels = label(labels)  # Ensure label count == region count
    eroded_labels = np.zeros_like(labels)

    # Iterate over the regions themselves rather than over an index range, so
    # that every region -- including the first (label 1) -- is processed.
    for region in regionprops(labels):
        label_i = region.label
        # Eroding one region at a time keeps touching trees from merging.
        eroded = erosion(labels == label_i)
        eroded_labels[eroded] = label_i

    if bool_mask:
        eroded_labels = eroded_labels > 0

    return eroded_labels

# Erode the labels of every patch so that neighbouring trees are separated by a
# background gap. The stack is preallocated (float, as produced by np.zeros)
# with a fixed 512 x 512 patch size.
n_patches = len(im_dirs)
eroded_labels = np.zeros((n_patches, 512, 512))

for patch_idx, im_dir in tqdm(enumerate(im_dirs), total=n_patches):
    eroded_labels[patch_idx] = erode_labels(im_dir)

Write the eroded labels to disk in case you want to use them for training.

In [None]:
def write_eroded_labels(eroded_label, d, eroded_dn):
    """Write one eroded label image into a sub-directory of its image directory.

    The file is written to ``{d}/{eroded_dn}/{im_id}_eroded_labels{ext}``, where
    ``im_id`` is the base name of ``d`` and ``ext`` is ``config.image_file_ext``.
    The sub-directory is created if it does not exist yet.

    Args:
        eroded_label (np.array): Eroded label image to write.
        d (str): Path of image directory
        eroded_dn (str): Name of the sub-directory that receives the file.
    """
    out_dir = os.path.join(d, eroded_dn)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    out_name = f"{os.path.basename(d)}_eroded_labels{config.image_file_ext}"
    tiff.imwrite(os.path.join(out_dir, out_name), eroded_label)

# Write each eroded label image next to its source patch.
# NOTE: the loop variable must not be called `label` -- that would rebind the
# `skimage.measure.label` function used by the label-processing functions and
# break them for the rest of the session.
for eroded_label, im_dir in tqdm(zip(eroded_labels, im_dirs), total=len(im_dirs)):
    write_eroded_labels(eroded_label, im_dir, "labels_eroded")

In [None]:
###########################################################################
# Acknowledgements:
# The code was taken and adapted from Rok Mihevc (rok/unet_weight_map.py).
# https://gist.github.com/rok/5f4314ed3c294521456c6afda36a3a50
###########################################################################

def calculate_ronneberger_weights(labels, wc=None, w0=10, sigma=5):
    """
    Generate weight maps as specified in the U-Net paper.

    "U-Net: Convolutional Networks for Biomedical Image Segmentation"
    https://arxiv.org/pdf/1505.04597.pdf

    The result is ``wc(x) + w0 * exp(-(d1(x) + d2(x))**2 / (2 * sigma**2))``,
    where ``d1`` and ``d2`` are the distances to the nearest and second-nearest
    object. The border term is only applied to background pixels and only when
    at least two objects are present.

    Parameters
    ----------
    labels: Numpy array
        2D array of shape (image_height, image_width). Either a boolean (or
        binary) mask of objects, or an instance-labelled image in which 0 is
        background. An image with exactly two distinct values is treated as a
        mask and split into connected components.
    wc: dict
        Dictionary of class weights: key 0 is the background weight, key 1 the
        weight of every object pixel (regardless of instance id). Other keys
        match no pixel. If None or empty, no class weights are added.
    w0: int
        Border weight parameter.
    sigma: int
        Border width parameter.

    Returns
    -------
    Numpy array
        Training weights. A 2D float array of shape (image_height, image_width).
    """
    labels = np.asarray(labels)

    # Check if mask is boolean or binary mask
    if len(np.unique(labels)) == 2:
        labels = label(labels)

    background = labels == 0
    unique_ids = np.unique(labels)
    # Exclude the background id explicitly instead of dropping the smallest
    # value, so an image without background does not lose a real label.
    label_ids = unique_ids[unique_ids != 0]

    w = np.zeros(labels.shape, dtype=float)

    if len(label_ids) > 1:
        # Track only the two smallest distances per pixel instead of stacking
        # and sorting one distance map per object (saves O(H * W * N) memory).
        d1 = np.full(labels.shape, np.inf)
        d2 = np.full(labels.shape, np.inf)

        for label_id in label_ids:
            dist = distance_transform_edt(labels != label_id)
            # Update d2 first: it needs the previous value of d1.
            d2 = np.minimum(d2, np.maximum(d1, dist))
            d1 = np.minimum(d1, dist)

        w = w0 * np.exp(-1 / 2 * ((d1 + d2) / sigma) ** 2) * background

    if wc:
        # Class weights are defined per class (background / object), so they
        # are matched against the foreground mask rather than instance ids.
        class_ids = (~background).astype(int)
        class_weights = np.zeros(labels.shape, dtype=float)
        for k, v in wc.items():
            class_weights[class_ids == k] = v
        w = w + class_weights

    return w


def calculate_border_weights(labels):
    """Return an integer map that is 1 on label boundaries and 0 elsewhere.

    Boundaries are detected with ``skimage.segmentation.find_boundaries`` using
    its default settings.

    Args:
        labels (np.array): Label image.

    Returns:
        np.array: Array of the same shape as the boundary mask, holding 1 where
        a boundary was found and 0 everywhere else.
    """
    is_boundary = find_boundaries(labels)  # boolean mask of boundaries
    return np.where(is_boundary, 1, 0)

Create and write the weights images.

In [None]:
def calc_and_write_weights(d, labels, wt_type="ronneberger"):
    """Compute a weight map for one label image and write it to disk.

    The map is written to
    ``{d}/{config.boundary_weights_dn}/{im_id}_{config.boundary_suffix}{ext}``,
    where ``im_id`` is the base name of ``d`` and ``ext`` is
    ``config.image_file_ext``.

    NOTE(review): both weight types are written to the same path, so running
    this with a different ``wt_type`` overwrites the earlier result unless the
    config is changed in between -- confirm this is intended.

    Args:
        d (str): Path of image directory
        labels (np.array): Label image the weights are derived from.
        wt_type (str): ``"ronneberger"`` for U-Net style weights (with class
            weights of 1 for background and objects) or ``"border"`` for a
            binary boundary map. Defaults to ``"ronneberger"``.

    Raises:
        ValueError: If ``wt_type`` is not one of the supported values.
    """
    if wt_type == "ronneberger":
        # Set the weights
        wc = {
            0: 1,  # background
            1: 1,  # objects
        }
        w = calculate_ronneberger_weights(labels, wc)
    elif wt_type == "border":
        w = calculate_border_weights(labels)
    else:
        # Fail before touching the file system instead of hitting an unbound
        # `w` after the output directory has already been created.
        raise ValueError(
            f"Unknown wt_type {wt_type!r}; expected 'ronneberger' or 'border'"
        )

    im_id = os.path.basename(d)

    # Write weights to disk
    weights_dir_path = os.path.join(d, config.boundary_weights_dn)
    os.makedirs(weights_dir_path, exist_ok=True)

    tiff.imwrite(
        os.path.join(
            weights_dir_path,
            f"{im_id}_{config.boundary_suffix}{config.image_file_ext}",
        ),
        w,
    )


# Load the label image of every patch and write its border weight map.
for im_dir in tqdm(im_dirs, total=len(im_dirs)):
    label_paths = glob.glob(
        os.path.join(im_dir, config.label_dn, f"*{config.image_file_ext}")
    )
    labels = tiff.imread(label_paths[0])
    calc_and_write_weights(im_dir, labels, wt_type="border")

Inspect some random weight files as a sanity check

In [None]:
# fig, ax = plt.subplots(6, 2, figsize=(10, 30), dpi=250)
# ax = ax.ravel()

# for a in ax:
#     k = np.random.randint(0, len(weight_maps))
#     im = a.imshow(weight_maps[k], vmin=0, vmax=10)
#     plt.colorbar(im, ax=a, shrink=0.8)
#     a.axis("off")
# plt.show();

## Data inspection

Get total tree counts in training and validation sets (watershed and hand-labeled sets).

In [None]:
train_label_fns = glob.glob("../../data/dap05/*loose*.tif")

# One tree count per label file: the number of distinct non-zero label ids
# (count_nonzero drops the background id 0 from the unique values).
train_tree_cts = np.array(
    [np.count_nonzero(np.unique(tiff.imread(path))) for path in train_label_fns],
    dtype=float,
)

print("Total number of trees in training set:", train_tree_cts.astype(int).sum())

### Semi-automated labels vs hand labels

In [4]:
class Patch:
    """Container exposing the layers of one image patch as attributes.

    Every path matched by ``dirname`` is treated as a layer directory: the
    first ``*.tif`` file found inside it is read and stored on the instance
    under the directory's base name (e.g. ``patch.labels``).

    Args:
        dirname (str): Glob pattern matching the layer directories of a patch,
            e.g. ``"/path/to/patch/*"``.
        labels_only (bool): If True, only the directory named ``"labels"`` is
            loaded. Defaults to False.
    """

    def __init__(self, dirname, labels_only=False):
        for d in sorted(glob.glob(dirname)):
            dname = os.path.basename(d)
            if labels_only and dname != "labels":
                continue

            # Matches without a .tif (empty sub-directories, plain files) are
            # skipped rather than failing with an IndexError.
            # NOTE(review): with several .tif files in one directory, the one
            # returned first by glob is used -- order is not guaranteed.
            tif_paths = glob.glob(os.path.join(d, "*.tif"))
            if not tif_paths:
                continue

            setattr(self, dname, tiff.imread(tif_paths[0]))

# Colormap: matplotlib's built-in "rainbow" map, sampled at random positions by
# the plotting cells to give each label id its own colour.
rainbow = mpl.colormaps["rainbow"]

In [None]:
# Glob patterns matching the layer directories of two example patches.
# Presumably the first is hand-labelled and the second watershed-labelled
# (going by the variable names) -- TODO confirm.
samp_hand_dir = "../../data/dap05/combined/512/393_5823_2020_01_01/*"
samp_ws_dir = "../../data/dap05/combined/512/Friedrichshain_1_2/*"

# Load every layer of each example patch as attributes (e.g. .rgbi, .labels).
hand = Patch(samp_hand_dir)
auto = Patch(samp_ws_dir)

In [None]:
# Side-by-side overlay of the labels on the RGB channels of each example patch.
fig, axs = plt.subplots(1, 2)

# One random rainbow colour per distinct label id of the hand-labelled patch.
n_hand_ids = len(np.unique(hand.labels))
cmap = ListedColormap(rainbow(np.random.random(n_hand_ids)))

patches = [auto, hand]
titles = [ "Train", "Val/Test"]

for ax, patch, title in zip(axs, patches, titles):
    rgb = patch.rgbi[..., :3]  # first three channels only
    ax.imshow(rgb)
    ax.imshow(mask_bg(patch.labels), cmap=cmap, alpha=0.4)
    ax.set_title(title)
    ax.axis("off")

fig_path = os.path.join(config.figures_dir, "hand-vs-auto-labels.png")
plt.savefig(fig_path, bbox_inches="tight")

Eroded vs non-eroded

In [None]:
# Compare the original and the eroded labels of the same patch on top of its
# RGB channels.
fig, axs = plt.subplots(1, 2)

# One random rainbow colour per distinct label id of the patch.
n_auto_ids = len(np.unique(auto.labels))
cmap = ListedColormap(rainbow(np.random.random(n_auto_ids)))

auto_rgbi = auto.rgbi[..., :3]
patches = [auto.labels, auto.labels_eroded]
titles = ["ORIG", "ERODED"]

for ax, title, patch in zip(axs, titles, patches):
    ax.imshow(auto_rgbi)
    ax.imshow(mask_bg(patch), cmap=cmap, alpha=0.4)
    ax.set_title(title)
    ax.axis("off")

fig_path = os.path.join(config.figures_dir, "orig-vs-eroded-labels.png")
plt.savefig(fig_path, bbox_inches="tight")

Weight maps

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Four panels: the binary labels and three weight-map variants derived from
# the stored weights of the hand-labelled patch.
fig, axs = plt.subplots(1, 4)
cmap = ListedColormap(rainbow(np.random.random(len(np.unique(auto.labels)))))

rgbi = hand.rgbi[..., :3]
titles = ["Labels", "RONN", "BOUNDS10", "BORD10"]
bound_thr = 2

# RONN: stored weights, with zero-weight pixels lifted to 0.5.
ronn = hand.weights.copy()
ronn[ronn == 0] = 0.5

# BOUNDS10: stored weights quantised into three levels (0.5 / 1 / 10).
# The assignments are order-dependent and must stay in this sequence.
bound = hand.weights.copy()
bound[bound == 0] = 0.5
bound[(bound < bound_thr) & (bound > 0.5)] = 1
bound[bound >= bound_thr] = 10

# BORD10: border pixels get 10, everything else 1, then background gets 1.5.
bord = hand.border_weights.copy()
bord = np.where(bord >= 1, 10., 1.)
bg = np.where(hand.labels > 0, np.nan, 1)
bord[bg == 1] = 1.5

patches = [hand.labels > 0, ronn, bound, bord]

for ax, title, patch in zip(axs, titles, patches):
    im = ax.imshow(mask_bg(patch), cmap="viridis")
    if title == "RONN":
        # Horizontal colorbar attached underneath the RONN panel only.
        divider = make_axes_locatable(ax)
        cax = divider.new_vertical(size="5%", pad=-0.05, pack_start=True)
        fig.add_axes(cax)
        plt.colorbar(im, cax=cax, orientation="horizontal")
    ax.set_title(title, fontsize=10)
    ax.axis("off")

fig_path = os.path.join(config.figures_dir, "weight_maps.png")
plt.savefig(fig_path, bbox_inches="tight")

Semantic vs instance segmentation

In [None]:
# Show the same labels once as a single foreground class (semantic) and once
# with one colour per label id (instance).
fig, axs = plt.subplots(1, 2)

n_hand_ids = len(np.unique(hand.labels))
cmap = ListedColormap(rainbow(np.random.random(n_hand_ids)))

patches = [hand.labels]
titles = ["Semantic Segmentation", "Instance Segmentation"]

for ax, title in zip(axs, titles):
    if title != "Semantic Segmentation":
        # Instance view: keep the ids and colour them individually.
        im = ax.imshow(mask_bg(hand.labels), cmap=cmap, alpha=1)
    else:
        # Semantic view: collapse all ids into one foreground mask.
        im = ax.imshow(mask_bg(hand.labels > 0))

    ax.set_title(title, fontsize=12)
    ax.axis("off")

fig_path = os.path.join(config.figures_dir, "semantic-vs-instance.png")
plt.savefig(fig_path, bbox_inches="tight")

Get label/bg percentages

In [None]:
# Every entry under the combined 512 data directory (glob order is arbitrary).
# Presumably one directory per patch -- TODO confirm. Not used yet in this cell.
all_dirs = glob.glob("../../data/dap05/combined/512/*")



In [None]:
# Re-read both example patches from disk, replacing the previously loaded
# `hand` and `auto` objects.
hand = Patch(samp_hand_dir)
auto = Patch(samp_ws_dir)