Add Watershed Segmentation Functionality #160

Merged 4 commits on Oct 23, 2023
61 changes: 51 additions & 10 deletions experiments/vision-transformer/unetr/livecell_unetr.py
@@ -7,11 +7,12 @@
from typing import Tuple, List

import imageio.v2 as imageio
from elf.evaluation import dice_score
from skimage.segmentation import find_boundaries
from elf.evaluation import dice_score, mean_segmentation_accuracy

import torch
import torch_em
from torch_em.util import segmentation
from torch_em.transform.raw import standardize
from torch_em.data.datasets import get_livecell_loader
from torch_em.util.prediction import predict_with_halo
@@ -126,6 +127,16 @@ def do_unetr_inference(
model.to(device)
model.eval()

fg_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "foreground")
bd_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "boundary")
ws1_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "watershed1")
ws2_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "watershed2")

os.makedirs(fg_save_dir, exist_ok=True)
os.makedirs(bd_save_dir, exist_ok=True)
os.makedirs(ws1_save_dir, exist_ok=True)
os.makedirs(ws2_save_dir, exist_ok=True)

with torch.no_grad():
for img_path in tqdm(glob(test_img_dir), desc=f"Run inference for all livecell with model {model_ckpt}"):
fname = os.path.split(img_path)[-1]
@@ -138,14 +149,13 @@

fg, bd = outputs[0, :, :], outputs[1, :, :]

fg_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "foreground")
bd_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "boundary")

os.makedirs(fg_save_dir, exist_ok=True)
os.makedirs(bd_save_dir, exist_ok=True)
ws1 = segmentation.watershed_from_components(bd, fg, min_size=10)
ws2 = segmentation.watershed_from_maxima(bd, fg, min_size=10, min_distance=1)

imageio.imwrite(os.path.join(fg_save_dir, fname), fg)
imageio.imwrite(os.path.join(bd_save_dir, fname), bd)
imageio.imwrite(os.path.join(ws1_save_dir, fname), ws1)
imageio.imwrite(os.path.join(ws2_save_dir, fname), ws2)


def do_unetr_evaluation(
@@ -156,6 +166,7 @@ def do_unetr_evaluation(
source_choice: str
):
fg_list, bd_list = [], []
ws1_msa_list, ws2_msa_list, ws1_sa50_list, ws2_sa50_list = [], [], [], []

for c1 in cell_types:
_save_dir = os.path.join(root_save_dir, f"src-{c1}")
@@ -164,18 +175,24 @@
continue

fg_set, bd_set = {"CELL TYPE": c1}, {"CELL TYPE": c1}
ws1_msa_set, ws2_msa_set, ws1_sa50_set, ws2_sa50_set = {"CELL TYPE": c1}, {"CELL TYPE": c1}, {"CELL TYPE": c1}, {"CELL TYPE": c1}
for c2 in tqdm(cell_types, desc=f"Evaluation on {c1} source models from {_save_dir}"):
fg_dir = os.path.join(_save_dir, "foreground")
bd_dir = os.path.join(_save_dir, "boundary")
ws1_dir = os.path.join(_save_dir, "watershed1")
ws2_dir = os.path.join(_save_dir, "watershed2")

gt_dir = os.path.join(input_path, "annotations", "livecell_test_images", c2, "*")
cwise_fg, cwise_bd = [], []
cwise_ws1_msa, cwise_ws2_msa, cwise_ws1_sa50, cwise_ws2_sa50 = [], [], [], []
for gt_path in glob(gt_dir):
fname = os.path.split(gt_path)[-1]

gt = imageio.imread(gt_path)
fg = imageio.imread(os.path.join(fg_dir, fname))
bd = imageio.imread(os.path.join(bd_dir, fname))
ws1 = imageio.imread(os.path.join(ws1_dir, fname))
ws2 = imageio.imread(os.path.join(ws2_dir, fname))

true_bd = find_boundaries(gt)

@@ -191,21 +208,45 @@
# For the ground-truth we already have a binary label, so we don't need to threshold it again.
cwise_bd.append(dice_score(bd, true_bd, threshold_gt=None, threshold_seg=None))

msa1, sa_acc1 = mean_segmentation_accuracy(ws1, gt, return_accuracies=True) # type: ignore
msa2, sa_acc2 = mean_segmentation_accuracy(ws2, gt, return_accuracies=True) # type: ignore

cwise_ws1_msa.append(msa1)
cwise_ws2_msa.append(msa2)
cwise_ws1_sa50.append(sa_acc1[0])
cwise_ws2_sa50.append(sa_acc2[0])

fg_set[c2] = np.mean(cwise_fg)
bd_set[c2] = np.mean(cwise_bd)
ws1_msa_set[c2] = np.mean(cwise_ws1_msa)
ws2_msa_set[c2] = np.mean(cwise_ws2_msa)
ws1_sa50_set[c2] = np.mean(cwise_ws1_sa50)
ws2_sa50_set[c2] = np.mean(cwise_ws2_sa50)

fg_list.append(pd.DataFrame.from_dict([fg_set])) # type: ignore
bd_list.append(pd.DataFrame.from_dict([bd_set])) # type: ignore
ws1_msa_list.append(pd.DataFrame.from_dict([ws1_msa_set])) # type: ignore
ws2_msa_list.append(pd.DataFrame.from_dict([ws2_msa_set])) # type: ignore
ws1_sa50_list.append(pd.DataFrame.from_dict([ws1_sa50_set])) # type: ignore
ws2_sa50_list.append(pd.DataFrame.from_dict([ws2_sa50_set])) # type: ignore

f_df_fg = pd.concat(fg_list, ignore_index=True)
f_df_bd = pd.concat(bd_list, ignore_index=True)
f_df_ws1_msa = pd.concat(ws1_msa_list, ignore_index=True)
f_df_ws2_msa = pd.concat(ws2_msa_list, ignore_index=True)
f_df_ws1_sa50 = pd.concat(ws1_sa50_list, ignore_index=True)
f_df_ws2_sa50 = pd.concat(ws2_sa50_list, ignore_index=True)

csv_save_dir = "./results/"
tmp_csv_name = f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch"
csv_save_dir = f"./results/{tmp_csv_name}"
os.makedirs(csv_save_dir, exist_ok=True)

tmp_csv_name = f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch"
f_df_fg.to_csv(os.path.join(csv_save_dir, f"foreground-unetr-{tmp_csv_name}-results.csv"))
f_df_bd.to_csv(os.path.join(csv_save_dir, f"boundary-unetr-{tmp_csv_name}-results.csv"))
f_df_fg.to_csv(os.path.join(csv_save_dir, "foreground-dice.csv"))
f_df_bd.to_csv(os.path.join(csv_save_dir, "boundary-dice.csv"))
f_df_ws1_msa.to_csv(os.path.join(csv_save_dir, "watershed1-msa.csv"))
f_df_ws2_msa.to_csv(os.path.join(csv_save_dir, "watershed2-msa.csv"))
f_df_ws1_sa50.to_csv(os.path.join(csv_save_dir, "watershed1-sa50.csv"))
f_df_ws2_sa50.to_csv(os.path.join(csv_save_dir, "watershed2-sa50.csv"))


def main(args):
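For reference, a minimal sketch of the metric calls used in the evaluation above. It assumes elf's default IoU thresholds (0.5 to 0.95 in steps of 0.05), so that the first entry of the returned accuracies is the score at IoU 0.5, i.e. the SA50 reported above; the toy label images are hypothetical stand-ins for real predictions.

import numpy as np
from elf.evaluation import mean_segmentation_accuracy

# One predicted and one ground-truth instance, overlapping with IoU ~0.62.
seg = np.zeros((64, 64), dtype="uint32")
seg[8:24, 8:24] = 1
gt = np.zeros((64, 64), dtype="uint32")
gt[10:26, 10:26] = 1

msa, accuracies = mean_segmentation_accuracy(seg, gt, return_accuracies=True)
sa50 = accuracies[0]  # segmentation accuracy at IoU threshold 0.5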
66 changes: 63 additions & 3 deletions torch_em/util/segmentation.py
@@ -1,23 +1,28 @@
import numpy as np
import elf.segmentation as elseg
import vigra

import vigra
import elf.segmentation as elseg
from elf.segmentation.utils import normalize_input

from skimage.measure import label
from skimage.filters import gaussian
from skimage.segmentation import watershed
from skimage.feature import peak_local_max
from scipy.ndimage import distance_transform_edt


#
# segmentation functionality
#


# could also refactor this into elf
def size_filter(seg, min_size, hmap=None, with_background=False):
if hmap is None:
ids, sizes = np.unique(seg, return_counts=True)
bg_ids = ids[sizes < min_size]
seg[np.isin(seg, bg_ids)] = 0
vigra.analysis.relabelConsecutive(seg, out=seg, start_label=1, keep_zeros=True)
seg, _, _ = vigra.analysis.relabelConsecutive(seg.astype(np.uint), start_label=1, keep_zeros=True)
else:
assert hmap.ndim in (seg.ndim, seg.ndim + 1)
hmap_ = np.max(hmap[:seg.ndim], axis=0) if hmap.ndim > seg.ndim else hmap
@@ -41,3 +46,58 @@ def connected_components_with_boundaries(foreground, boundaries, threshold=0.5):
mask = normalize_input(foreground > threshold)
seg = watershed(boundaries, markers=seeds, mask=mask)
return seg.astype("uint64")


def watershed_from_components(boundaries, foreground, min_size, threshold1=0.5, threshold2=0.5):
"""The default approach:
- Subtract the boundaries from the foreground to separate touching objects.
- Use the connected components of this as seeds.
- Use the thresholded foreground predictions as mask to grow back the pieces
lost by subtracting the boundary prediction.

Arguments:
- boundaries: [np.ndarray] - The boundary predictions for the objects
- foreground: [np.ndarray] - The foreground predictions for the objects
- min_size: [int] - The minimal object size in pixels; smaller objects are filtered out
- threshold1: [float] - The threshold applied to (foreground - boundaries) to compute the seeds
- threshold2: [float] - The threshold applied to the foreground predictions to compute the mask

Returns:
seg: [np.ndarray] - instance segmentation
"""
seeds = label((foreground - boundaries) > threshold1)
mask = foreground > threshold2
seg = watershed(boundaries, seeds, mask=mask)
seg = size_filter(seg, min_size)
return seg


def watershed_from_maxima(boundaries, foreground, min_size, min_distance, sigma=1.0, threshold1=0.5):
"""Find objects via seeded watershed starting from the maxima of the distance transform instead.
This has the advantage that objects can be better separated, but it may over-segment
if the objects have complex shapes.

The min_distance parameter controls the minimal distance between seeds, which
corresponds to the minimal distance between object centers.

Arguments:
- boundaries: [np.ndarray] - The boundary predictions for the objects
- foreground: [np.ndarray] - The foreground predictions for the objects
- min_size: [int] - The minimal object size in pixels; smaller objects are filtered out
- min_distance: [int] - The minimal distance between seed points (see skimage.feature.peak_local_max)
- sigma: [float] - The standard deviation of the gaussian kernel used to smooth the distance transform (see skimage.filters.gaussian)
- threshold1: [float] - The threshold applied to the foreground predictions to compute the mask

Returns:
seg: [np.ndarray] - instance segmentation
"""
mask = foreground > threshold1
boundary_distances = distance_transform_edt(boundaries < 0.1)
boundary_distances[~mask] = 0 # type: ignore
boundary_distances = gaussian(boundary_distances, sigma) # type: ignore
seed_points = peak_local_max(boundary_distances, min_distance=min_distance, exclude_border=False)
seeds = np.zeros(mask.shape, dtype="uint32")
seeds[seed_points[:, 0], seed_points[:, 1]] = np.arange(1, len(seed_points) + 1)
seg = watershed(boundaries, markers=seeds, mask=mask)
seg = size_filter(seg, min_size)
return seg
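
For reference, a minimal usage sketch of the two new helpers. It assumes that foreground and boundaries are float probability maps in [0, 1], e.g. the two output channels predicted by the UNETR above; random arrays stand in for real predictions here.

import numpy as np
from torch_em.util import segmentation

foreground = np.random.rand(256, 256).astype("float32")  # placeholder prediction
boundaries = np.random.rand(256, 256).astype("float32")  # placeholder prediction

# Seeds from connected components of (foreground - boundaries):
ws1 = segmentation.watershed_from_components(boundaries, foreground, min_size=10)

# Seeds from maxima of the distance transform to the predicted boundaries:
ws2 = segmentation.watershed_from_maxima(boundaries, foreground, min_size=10, min_distance=1)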