Skip to content

Commit

Permalink
Merge pull request #160 from anwai98/aa-ws
Browse files Browse the repository at this point in the history
Add Watershed Segmentation Functionality
  • Loading branch information
constantinpape committed Oct 23, 2023
2 parents 03559f7 + ee309eb commit 38f2096
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 13 deletions.
61 changes: 51 additions & 10 deletions experiments/vision-transformer/unetr/livecell_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from typing import Tuple, List

import imageio.v2 as imageio
from elf.evaluation import dice_score
from skimage.segmentation import find_boundaries
from elf.evaluation import dice_score, mean_segmentation_accuracy

import torch
import torch_em
from torch_em.util import segmentation
from torch_em.transform.raw import standardize
from torch_em.data.datasets import get_livecell_loader
from torch_em.util.prediction import predict_with_halo
Expand Down Expand Up @@ -126,6 +127,16 @@ def do_unetr_inference(
model.to(device)
model.eval()

fg_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "foreground")
bd_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "boundary")
ws1_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "watershed1")
ws2_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "watershed2")

os.makedirs(fg_save_dir, exist_ok=True)
os.makedirs(bd_save_dir, exist_ok=True)
os.makedirs(ws1_save_dir, exist_ok=True)
os.makedirs(ws2_save_dir, exist_ok=True)

with torch.no_grad():
for img_path in tqdm(glob(test_img_dir), desc=f"Run inference for all livecell with model {model_ckpt}"):
fname = os.path.split(img_path)[-1]
Expand All @@ -138,14 +149,13 @@ def do_unetr_inference(

fg, bd = outputs[0, :, :], outputs[1, :, :]

fg_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "foreground")
bd_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "boundary")

os.makedirs(fg_save_dir, exist_ok=True)
os.makedirs(bd_save_dir, exist_ok=True)
ws1 = segmentation.watershed_from_components(bd, fg, min_size=10)
ws2 = segmentation.watershed_from_maxima(bd, fg, min_size=10, min_distance=1)

imageio.imwrite(os.path.join(fg_save_dir, fname), fg)
imageio.imwrite(os.path.join(bd_save_dir, fname), bd)
imageio.imwrite(os.path.join(ws1_save_dir, fname), ws1)
imageio.imwrite(os.path.join(ws2_save_dir, fname), ws2)


def do_unetr_evaluation(
Expand All @@ -156,6 +166,7 @@ def do_unetr_evaluation(
source_choice: str
):
fg_list, bd_list = [], []
ws1_msa_list, ws2_msa_list, ws1_sa50_list, ws2_sa50_list = [], [], [], []

for c1 in cell_types:
_save_dir = os.path.join(root_save_dir, f"src-{c1}")
Expand All @@ -164,18 +175,24 @@ def do_unetr_evaluation(
continue

fg_set, bd_set = {"CELL TYPE": c1}, {"CELL TYPE": c1}
ws1_msa_set, ws2_msa_set, ws1_sa50_set, ws2_sa50_set = {"CELL TYPE": c1}, {"CELL TYPE": c1}, {"CELL TYPE": c1}, {"CELL TYPE": c1}
for c2 in tqdm(cell_types, desc=f"Evaluation on {c1} source models from {_save_dir}"):
fg_dir = os.path.join(_save_dir, "foreground")
bd_dir = os.path.join(_save_dir, "boundary")
ws1_dir = os.path.join(_save_dir, "watershed1")
ws2_dir = os.path.join(_save_dir, "watershed2")

gt_dir = os.path.join(input_path, "annotations", "livecell_test_images", c2, "*")
cwise_fg, cwise_bd = [], []
cwise_ws1_msa, cwise_ws2_msa, cwise_ws1_sa50, cwise_ws2_sa50 = [], [], [], []
for gt_path in glob(gt_dir):
fname = os.path.split(gt_path)[-1]

gt = imageio.imread(gt_path)
fg = imageio.imread(os.path.join(fg_dir, fname))
bd = imageio.imread(os.path.join(bd_dir, fname))
ws1 = imageio.imread(os.path.join(ws1_dir, fname))
ws2 = imageio.imread(os.path.join(ws2_dir, fname))

true_bd = find_boundaries(gt)

Expand All @@ -191,21 +208,45 @@ def do_unetr_evaluation(
# For the ground-truth we have already a binary label, so we don't need to threshold it again.
cwise_bd.append(dice_score(bd, true_bd, threshold_gt=None, threshold_seg=None))

msa1, sa_acc1 = mean_segmentation_accuracy(ws1, gt, return_accuracies=True) # type: ignore
msa2, sa_acc2 = mean_segmentation_accuracy(ws2, gt, return_accuracies=True) # type: ignore

cwise_ws1_msa.append(msa1)
cwise_ws2_msa.append(msa2)
cwise_ws1_sa50.append(sa_acc1[0])
cwise_ws2_sa50.append(sa_acc2[0])

fg_set[c2] = np.mean(cwise_fg)
bd_set[c2] = np.mean(cwise_bd)
ws1_msa_set[c2] = np.mean(cwise_ws1_msa)
ws2_msa_set[c2] = np.mean(cwise_ws2_msa)
ws1_sa50_set[c2] = np.mean(cwise_ws1_sa50)
ws2_sa50_set[c2] = np.mean(cwise_ws2_sa50)

fg_list.append(pd.DataFrame.from_dict([fg_set])) # type: ignore
bd_list.append(pd.DataFrame.from_dict([bd_set])) # type: ignore
ws1_msa_list.append(pd.DataFrame.from_dict([ws1_msa_set])) # type: ignore
ws2_msa_list.append(pd.DataFrame.from_dict([ws2_msa_set])) # type: ignore
ws1_sa50_list.append(pd.DataFrame.from_dict([ws1_sa50_set])) # type: ignore
ws2_sa50_list.append(pd.DataFrame.from_dict([ws2_sa50_set])) # type: ignore

f_df_fg = pd.concat(fg_list, ignore_index=True)
f_df_bd = pd.concat(bd_list, ignore_index=True)
f_df_ws1_msa = pd.concat(ws1_msa_list, ignore_index=True)
f_df_ws2_msa = pd.concat(ws2_msa_list, ignore_index=True)
f_df_ws1_sa50 = pd.concat(ws1_sa50_list, ignore_index=True)
f_df_ws2_sa50 = pd.concat(ws2_sa50_list, ignore_index=True)

csv_save_dir = "./results/"
tmp_csv_name = f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch"
csv_save_dir = f"./results/{tmp_csv_name}"
os.makedirs(csv_save_dir, exist_ok=True)

tmp_csv_name = f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch"
f_df_fg.to_csv(os.path.join(csv_save_dir, f"foreground-unetr-{tmp_csv_name}-results.csv"))
f_df_bd.to_csv(os.path.join(csv_save_dir, f"boundary-unetr-{tmp_csv_name}-results.csv"))
f_df_fg.to_csv(os.path.join(csv_save_dir, "foreground-dice.csv"))
f_df_bd.to_csv(os.path.join(csv_save_dir, "boundary-dice.csv"))
f_df_ws1_msa.to_csv(os.path.join(csv_save_dir, "watershed1-msa.csv"))
f_df_ws2_msa.to_csv(os.path.join(csv_save_dir, "watershed2-msa.csv"))
f_df_ws1_sa50.to_csv(os.path.join(csv_save_dir, "watershed1-sa50.csv"))
f_df_ws2_sa50.to_csv(os.path.join(csv_save_dir, "watershed2-sa50.csv"))


def main(args):
Expand Down
66 changes: 63 additions & 3 deletions torch_em/util/segmentation.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
import numpy as np
import elf.segmentation as elseg
import vigra

import vigra
import elf.segmentation as elseg
from elf.segmentation.utils import normalize_input

from skimage.measure import label
from skimage.filters import gaussian
from skimage.segmentation import watershed
from skimage.feature import peak_local_max
from scipy.ndimage import distance_transform_edt


#
# segmentation functionality
#


# could also refactor this into elf
def size_filter(seg, min_size, hmap=None, with_background=False):
if hmap is None:
ids, sizes = np.unique(seg, return_counts=True)
bg_ids = ids[sizes < min_size]
seg[np.isin(seg, bg_ids)] = 0
vigra.analysis.relabelConsecutive(seg, out=seg, start_label=1, keep_zeros=True)
seg, _, _ = vigra.analysis.relabelConsecutive(seg.astype(np.uint), start_label=1, keep_zeros=True)
else:
assert hmap.ndim in (seg.ndim, seg.ndim + 1)
hmap_ = np.max(hmap[:seg.ndim], axis=0) if hmap.ndim > seg.ndim else hmap
Expand All @@ -41,3 +46,58 @@ def connected_components_with_boundaries(foreground, boundaries, threshold=0.5):
mask = normalize_input(foreground > threshold)
seg = watershed(boundaries, markers=seeds, mask=mask)
return seg.astype("uint64")


def watershed_from_components(boundaries, foreground, min_size, threshold1=0.5, threshold2=0.5):
    """Instance segmentation via a watershed seeded from connected components.

    Steps:
        - Subtract the boundary prediction from the foreground prediction so that
          touching objects fall apart, and threshold the result.
        - Label the connected components of that thresholded map and use them as seeds.
        - Run a seeded watershed on the boundary map, restricted to the thresholded
          foreground, to grow back the pixels lost by subtracting the boundaries.
        - Pass the result through `size_filter` with `min_size`.

    Arguments:
        - boundaries: [np.ndarray] - The boundary prediction for the objects
        - foreground: [np.ndarray] - The foreground prediction for the objects
        - min_size: [int] - The size (in pixels) below which objects are filtered out
        - threshold1: [float] - Threshold applied to (foreground - boundaries) to obtain the seeds
        - threshold2: [float] - Threshold applied to the foreground to obtain the watershed mask
    Returns:
        seg: [np.ndarray] - The instance segmentation
    """
    # Seeds: connected components of the "foreground minus boundary" map above threshold1.
    seed_region = (foreground - boundaries) > threshold1
    markers = label(seed_region)

    # The watershed may only label pixels inside the thresholded foreground.
    growth_mask = foreground > threshold2
    instances = watershed(boundaries, markers, mask=growth_mask)

    # Remove objects below the minimal size.
    return size_filter(instances, min_size)


def watershed_from_maxima(boundaries, foreground, min_size, min_distance, sigma=1.0, threshold1=0.5):
    """Find objects via a seeded watershed starting from the maxima of the boundary distance transform.

    Compared to seeding from connected components this can separate objects better,
    but it may over-segment objects with complex shapes.
    The min_distance parameter controls the minimal distance between seeds, which
    corresponds to the minimal distance between object centers.

    Arguments:
        - boundaries: [np.ndarray] - The boundary prediction for the objects
        - foreground: [np.ndarray] - The foreground prediction for the objects
        - min_size: [int] - The size (in pixels) below which objects are filtered out
        - min_distance: [int] - Minimal distance between peaks (see `skimage.feature.peak_local_max`)
        - sigma: [float] - Standard deviation of the gaussian kernel used to smooth
            the distances (see `skimage.filters.gaussian`)
        - threshold1: [float] - Threshold applied to the foreground prediction
    Returns:
        seg: [np.ndarray] - The instance segmentation
    """
    mask = foreground > threshold1

    # Distance of every pixel to the closest boundary pixel; pixels with a boundary
    # value of 0.1 or more are treated as boundary. Distances outside of the
    # foreground mask are zeroed so that no seeds are placed in the background.
    boundary_distances = distance_transform_edt(boundaries < 0.1)
    boundary_distances[~mask] = 0  # type: ignore
    # Smooth the distances before looking for local maxima.
    boundary_distances = gaussian(boundary_distances, sigma)  # type: ignore

    # Each local maximum of the smoothed distances becomes one seed with a unique id.
    seed_points = peak_local_max(boundary_distances, min_distance=min_distance, exclude_border=False)
    seeds = np.zeros(mask.shape, dtype="uint32")
    # Index with one coordinate array per axis, so this is not restricted to 2d inputs.
    seeds[tuple(seed_points.T)] = np.arange(1, len(seed_points) + 1)

    # Restrict the watershed to the thresholded foreground. Passing the raw foreground
    # prediction here would mark every non-zero pixel as foreground, because the mask
    # is interpreted as a boolean array, and the threshold would have no effect.
    seg = watershed(boundaries, markers=seeds, mask=mask)
    seg = size_filter(seg, min_size)
    return seg

0 comments on commit 38f2096

Please sign in to comment.