# Watershed "instance segmentation" from semantic segmentation

## Imports and helper functions

In [None]:
import glob
import matplotlib.pyplot as pl
import numpy as np
import tifffile as tiff
from sklearn.model_selection import train_test_split
from keras.models import Model, load_model
from patchify import patchify
import napari
import utils
from unet import build_unet
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

## Data preparation

In [None]:
# Sorted listing of whatever utils.get_by_ext returns for the data/hand/rgb
# directory (presumably image file paths — confirm against utils).
# NOTE(review): `test` is not referenced again in the visible cells.
test = sorted(utils.get_by_ext("../../data/hand/rgb"))


In [None]:
# Root folder handed to utils.prep_data.
DATA_DIR = "../../data/"

# prep_data yields six items; only the 3rd, 4th and 6th are kept. Judging by
# how they are used in later cells these are the test images, the test masks
# and the test instance maps — confirm against utils.prep_data.
_, _, X_test, y_test, _, inst_test = utils.prep_data(
    DATA_DIR,
    "hand",
    "loose",
    "loose",
    channels=3,
)

In [None]:
# Checkpoint file whose weights are loaded into the U-Net built below.
WEIGHTS = "logs/simple_imp/20221016-172119_KF5of10_simp_imp/KF5of10_simp_imp_best_wg.h5"

# Per-sample input shape: every axis of X_test after the first one.
input_shape = X_test.shape[1:]
model = build_unet(input_shape)
model.load_weights(WEIGHTS)
# Run the model over the whole test set in one call.
# NOTE(review): y_pred is not read by the visible cells, which call
# model.predict again on single images instead.
y_pred = model.predict(X_test)

In [None]:


# Inspect one fixed test sample: input image, ground-truth mask and the
# model's confidence map, side by side.
k = 183  # fixed index; e.g. use np.random.randint(len(X_test)) to browse
test_img = X_test[k]
ground_truth = y_test[k]
# Add a leading batch axis for predict, then keep the first sample / channel.
prediction = model.predict(test_img[np.newaxis, ...])[0, :, :, 0]

# Settings handed to utils.watershed_labels.
neighborhood_size = 10
threshold = 6

# Not consumed in this cell; kept so the notebook namespace is unchanged.
area = 40
ratio = 0.4
strict_params = [55, 0.8, 0.5, 0.5, 0.5]

# Settings handed to utils.filter_labels.
params = [40, 0.95, 0.1, 0.3, 0.7]

labels, _ = utils.watershed_labels(prediction, neighborhood_size, threshold, 0.7)
filtered_regions, bbox = utils.filter_labels(labels, prediction, False, *params)

fig, ax = plt.subplots(1, 3, figsize=(20, 7), dpi=300)
ax = ax.ravel()
image_ax, truth_ax, pred_ax = ax

image_ax.imshow(test_img[..., :3])
image_ax.set_title("Testing Image")

truth_ax.imshow(ground_truth[:, :, 0], cmap="gray")
truth_ax.set_title("Testing Label")

im = pred_ax.imshow(prediction, vmin=0, vmax=1)
pred_ax.set_title("Prediction on test image")
# Colour bar sits underneath the prediction panel only.
fig.colorbar(
    im,
    ax=[pred_ax],
    location="bottom",
    pad=0.05,
    label="Prediction confidence",
)

# The "Watershed instances" overlay (filtered_regions + bbox) is deliberately
# not drawn in this three-panel figure.

for panel in ax:
    panel.axis("off")

In [None]:
# Pick a random test sample and compare, side by side: the input image, the
# ground-truth instances, the model's confidence map and the filtered
# watershed instances.
#
# np.random.randint's upper bound is exclusive, so it must be len(X_test)
# (not len(X_test) - 1); otherwise the last sample can never be drawn and a
# one-image test set raises ValueError.
k = np.random.randint(0, len(X_test))

test_img = X_test[k]
# Float copy so zero-valued pixels (presumably background) can be replaced by
# NaN, which imshow leaves blank.
ground_truth = inst_test[k].astype(float)
ground_truth[ground_truth == 0] = np.nan
# Number of distinct non-zero labels in the ground-truth instance map.
inst_count = len(np.unique(ground_truth[~np.isnan(ground_truth)]))
# Add a batch axis for predict, then keep the first sample / first channel.
prediction = model.predict(np.expand_dims(test_img, 0))[0, :, :, 0]

# Settings handed to utils.watershed_labels.
neighborhood_size = 5
threshold = 2

# Not consumed in this cell; kept so the notebook namespace is unchanged.
area = 40
ratio = 0.4

labels, _ = utils.watershed_labels(prediction, neighborhood_size, threshold, 0.7)
strict_params = [55, 0.8, 0.5, 0.5, 0.5]  # alternative settings, unused here
params = [40, 0.95, 0.1, 0.3, 0.7]
filtered_regions, bbox = utils.filter_labels(labels, prediction, False, *params)

fig, ax = plt.subplots(1, 4, figsize=(20, 7.25), dpi=300)
ax = ax.ravel()
ax[0].set_title("Testing Image")
ax[0].imshow(test_img[:, :, :3])

ax[1].set_title("Testing Labels")
ax[1].imshow(ground_truth[:, :, 0], cmap="tab20_r")
obs_tree_txt = "tree" if inst_count == 1 else "trees"
# Optional caption with the ground-truth count (disabled):
# ax[1].text(0.5,-0.1, f"{inst_count} {obs_tree_txt}", size=12, ha="center", transform=ax[1].transAxes)

ax[2].set_title("Prediction on test image")
im = ax[2].imshow(prediction, vmin=0, vmax=1)
fig.colorbar(im, ax=[ax[2]], location="bottom", pad=0.05, label="Prediction confidence")

ax[3].set_title("Watershed instances")
ax[3].imshow(test_img[..., :3])

# Overlay the kept regions semi-transparently; zeros become NaN so that
# unlabelled pixels show the underlying image.
filtered_regions = filtered_regions.astype(np.float32)
filtered_regions[filtered_regions == 0] = np.nan
ax[3].imshow(filtered_regions, alpha=0.5, cmap=plt.cm.tab20b)

# Each bbox entry is plotted as (x values, y values) of a box outline.
for bb in bbox:
    ax[3].plot(bb[0], bb[1], c="limegreen", ls="--", lw=1)

inst_tree_txt = "tree" if len(bbox) == 1 else "trees"
# Optional caption with the predicted count (disabled):
# ax[3].text(0.5,-0.1, f"{len(bbox)} {inst_tree_txt}", size=12, ha="center", transform=ax[3].transAxes)
for a in ax:
    a.axis("off")

## Average % agreement between actual and predicted instance counts

In [None]:
# One row per test sample: column 0 holds the number of ground-truth
# instances, column 1 the number of instances kept after watershed + filtering.
data = np.zeros((len(inst_test), 2))

# Settings are the same for every sample, so they are set once up front.
neighborhood_size = 4
threshold = 2
params = [40, 0.95, 0.1, 0.3, 0.7]

for row, (img, p) in enumerate(zip(X_test, inst_test)):
    # Ground truth: count the distinct non-zero labels in the instance map.
    p = p.astype(float)
    p[p == 0] = np.nan
    labelled_values = p[~np.isnan(p)]
    num_inst = len(np.unique(labelled_values))

    # Prediction: confidence map -> watershed labels -> filtered boxes.
    prediction = model.predict(np.expand_dims(img, 0))[0, :, :, 0]
    labels, _ = utils.watershed_labels(prediction, neighborhood_size, threshold, 0.7)
    _, bbox = utils.filter_labels(labels, prediction, False, *params)
    pred_inst = len(bbox)

    # An earlier per-sample TP/TN/FP/FN accuracy, precision and recall
    # calculation was disabled; only the raw counts are stored.
    data[row] = (num_inst, pred_inst)

In [None]:
# Relative error of the mean predicted count with respect to the mean actual
# count (column 0 = actual, column 1 = predicted).
actual_mean = data[:, 0].mean()
predicted_mean = data[:, 1].mean()
err = np.abs(predicted_mean - actual_mean) / actual_mean

In [None]:
# Complement of the relative error, shown as the cell output.
1 - err

In [None]:
import pandas as pd

# Tabulate the per-sample counts, then describe the actual counts grouped by
# the observed (predicted) count.
df = pd.DataFrame({"Actual": data[:, 0], "Observed": data[:, 1]})
group = df.groupby(["Observed"])
group.describe()

In [None]:
# Mean actual count (column 0) followed by mean predicted count (column 1).
for column in (0, 1):
    print(data[:, column].mean())

In [None]:
# Overlay the distributions of actual and predicted instance counts.
# Both histograms share one set of bin edges (10 bins, plt.hist's default
# count, spanning the combined range of both columns) so their bars line up,
# and both are semi-transparent so the second does not hide the first.
bins = np.histogram_bin_edges(data, bins=10)
plt.hist(data[:, 0], bins=bins, alpha=0.5, label="Actual")
plt.hist(data[:, 1], bins=bins, alpha=0.5, label="Predicted")
plt.legend()

In [None]:
# Mean squared error between actual (column 0) and predicted (column 1)
# counts, and its square root.
squared_diff = np.square(data[:, 0] - data[:, 1])
MSE = squared_diff.sum() / len(data)
RMSE = np.sqrt(MSE)

In [None]:
# Coefficient of determination: 1 - SS_res / SS_tot, with the actual counts
# (column 0) as reference and the predicted counts (column 1) as estimate.
ss_res = np.sum((data[:, 1] - data[:, 0]) ** 2)
ss_tot = np.sum((data[:, 0] - data[:, 0].mean()) ** 2)
R2 = 1 - ss_res / ss_tot

In [None]:
# Show R2 as the cell output.
R2

In [None]:
# Manual watershed walkthrough on random test samples: for each sample, seed a
# watershed from local maxima of the smoothed confidence map, keep regions that
# pass the area / axis-ratio filter, draw their bounding boxes over the image
# and push the ground truth into a napari viewer.
import random

import scipy.ndimage as ndimage
from matplotlib.pyplot import tight_layout
from skimage.color import label2rgb
from skimage.measure import regionprops
from skimage.segmentation import watershed

# Lists of alternative settings (presumably for a parameter sweep); not read
# in this cell.
nsx = [2, 4, 5, 10]
thx = [4, 5, 8]
areas = [10, 20, 30, 40, 50]
ratios = [0.0001, 0.001, 0.01, 0.1, 0.5]

neighborhood_size = 5  # size passed to the maximum filter
threshold = 4  # NOTE: used below as the Gaussian filter's sigma
area = 40  # minimum region area to keep
ratio = 0.1  # minimum minor/major axis-length ratio to keep
min_height = 0.3  # confidence cut-off for the watershed mask

viewer = napari.Viewer()

_, ax = pl.subplots(5, 2, figsize=(10, 20), dpi=200, tight_layout=True)
ax = ax.ravel()

# random.sample raises ValueError when asked for more items than the
# population holds, so never request more samples than there are test images
# (or axes to draw them on).
sample_indices = random.sample(range(len(X_test)), min(ax.size, len(X_test)))

for i, k in enumerate(sample_indices):
    test_img = X_test[k]
    ground_truth = y_test[k]
    prediction = model.predict(np.expand_dims(test_img, 0))[0, :, :, 0]

    # Seeds: pixels where the smoothed map equals its local maximum, excluding
    # pixels whose raw prediction is exactly zero.
    p_smooth = ndimage.gaussian_filter(prediction, threshold)
    p_max = ndimage.maximum_filter(p_smooth, neighborhood_size)
    local_maxima = p_smooth == p_max
    local_maxima[prediction == 0] = 0
    labeled, num_objects = ndimage.label(local_maxima)
    # NOTE(review): xy (seed centres of mass) is not used later in this cell.
    xy = np.array(
        ndimage.center_of_mass(input=prediction, labels=labeled, index=range(1, num_objects + 1))
    )

    # Restrict the watershed to pixels at or above the cut-off, holes filled.
    binary_mask = np.where(prediction >= min_height, 1, 0)
    binary_mask = ndimage.binary_fill_holes(binary_mask).astype(int)

    # Negate the map so that confidence peaks become basins.
    labels = watershed(-prediction, labeled, mask=binary_mask)

    regions = regionprops(labels)
    filtered_regions = np.zeros((prediction.shape[0], prediction.shape[1]), dtype=int)

    for region in regions:
        # The area test runs first, so the axis ratio is only evaluated for
        # regions that are already large enough.
        if region.area >= area and (
            region.axis_minor_length / region.axis_major_length >= ratio
        ):
            filtered_regions[region.coords[:, 0], region.coords[:, 1]] = region.label
            # Closed rectangle outline in (x, y) = (column, row) order.
            minr, minc, maxr, maxc = region.bbox
            bx = (minc, maxc, maxc, minc, minc)
            by = (minr, minr, maxr, maxr, minr)
            ax[i].plot(bx, by, c="limegreen", lw=1)

    ax[i].imshow(test_img[..., :3])
    viewer.add_image(ground_truth, blending="opaque", name=f"img_{i}")
    # viewer.add_labels(filtered_regions, name=f"img_{i}")