Skip to content

Commit

Permalink
Add docstrings to watershed functions
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Oct 21, 2023
1 parent e0cfad9 commit 35bd684
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 26 deletions.
18 changes: 9 additions & 9 deletions experiments/vision-transformer/unetr/livecell_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@

import torch
import torch_em
from torch_em.util import segmentation
from torch_em.transform.raw import standardize
from torch_em.data.datasets import get_livecell_loader
from torch_em.util.prediction import predict_with_halo
from torch_em.util import segmentation


def get_unetr_model(
Expand Down Expand Up @@ -134,6 +134,8 @@ def do_unetr_inference(

os.makedirs(fg_save_dir, exist_ok=True)
os.makedirs(bd_save_dir, exist_ok=True)
os.makedirs(ws1_save_dir, exist_ok=True)
os.makedirs(ws2_save_dir, exist_ok=True)

with torch.no_grad():
for img_path in tqdm(glob(test_img_dir), desc=f"Run inference for all livecell with model {model_ckpt}"):
Expand All @@ -147,15 +149,13 @@ def do_unetr_inference(

fg, bd = outputs[0, :, :], outputs[1, :, :]

ws1 = segmentation.watershed_from_components(bd, fg, min_size=100)
ws2 = segmentation.watershed_from_maxima(bd, fg, min_size=100, min_distance=1)

breakpoint()
ws1 = segmentation.watershed_from_components(bd, fg, min_size=10)
ws2 = segmentation.watershed_from_maxima(bd, fg, min_size=10, min_distance=1)

# imageio.imwrite(os.path.join(fg_save_dir, fname), fg)
# imageio.imwrite(os.path.join(bd_save_dir, fname), bd)
# imageio.imwrite(os.path.join(ws1_save_dir, fname), ws1)
# imageio.imwrite(os.path.join(ws2_save_dir, fname), ws2)
imageio.imwrite(os.path.join(fg_save_dir, fname), fg)
imageio.imwrite(os.path.join(bd_save_dir, fname), bd)
imageio.imwrite(os.path.join(ws1_save_dir, fname), ws1)
imageio.imwrite(os.path.join(ws2_save_dir, fname), ws2)


def do_unetr_evaluation(
Expand Down
43 changes: 26 additions & 17 deletions torch_em/util/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def size_filter(seg, min_size, hmap=None, with_background=False):
ids, sizes = np.unique(seg, return_counts=True)
bg_ids = ids[sizes < min_size]
seg[np.isin(seg, bg_ids)] = 0
vigra.analysis.relabelConsecutive(seg, out=seg, start_label=1, keep_zeros=True)
seg, _, _ = vigra.analysis.relabelConsecutive(seg.astype(np.uint), start_label=1, keep_zeros=True)
else:
assert hmap.ndim in (seg.ndim, seg.ndim + 1)
hmap_ = np.max(hmap[:seg.ndim], axis=0) if hmap.ndim > seg.ndim else hmap
Expand All @@ -48,18 +48,22 @@ def connected_components_with_boundaries(foreground, boundaries, threshold=0.5):
return seg.astype("uint64")


def watershed_from_components(
boundaries,
foreground,
min_size,
threshold1=0.5,
threshold2=0.5,
):
def watershed_from_components(boundaries, foreground, min_size, threshold1=0.5, threshold2=0.5):
"""The default approach:
- Subtract the boundaries from the foreground to separate touching objects.
- Use the connected components of this as seeds.
- Use the thresholded foreground predictions as mask to grow back the pieces
lost by subtracting the boundary prediction.
lost by subtracting the boundary prediction.
Arguments:
- boundaries: [np.ndarray] - The boundaries for objects
- foreground: [np.ndarray] - The foregrounds for objects
- min_size: [int] - The minimum pixels (below which) to filter objects
- threshold1: [float] - To separate touching objects (by subtracting bd and fg) above threshold
- threshold2: [float] - To threshold foreground predictions
Returns:
seg: [np.ndarray] - instance segmentation
"""
seeds = label((foreground - boundaries) > threshold1)
mask = foreground > threshold2
Expand All @@ -68,21 +72,26 @@ def watershed_from_components(
return seg


def watershed_from_maxima(
boundaries,
foreground,
min_size,
min_distance,
sigma=1.0,
):
def watershed_from_maxima(boundaries, foreground, min_size, min_distance, sigma=1.0, threshold1=0.5):
"""Find objects via seeded watershed starting from the maxima of the distance transform instead.
This has the advantage that objects can be better separated, but it may over-segment
if the objects have complex shapes.
The min_distance parameter controls the minimal distance between seeds, which
corresponds to the minimal distance between object centers.
Arguments:
- boundaries: [np.ndarray] - The boundaries for objects
- foreground: [np.ndarray] - The foreground for objects
- min_size: [int] - min. pixels (below which) to filter objects
- min_distance: [int] - min. distance of peaks (see `from skimage.feature import peak_local_max`)
- sigma: [float] - standard deviation for gaussian kernel. (see `from skimage.filters import gaussian`)
- threshold1: [float] - To threshold foreground predictions
Returns
seg: [np.ndarray] - instance segmentation
"""
mask = foreground > 0.5
mask = foreground > threshold1
boundary_distances = distance_transform_edt(boundaries < 0.1)
boundary_distances[~mask] = 0 # type: ignore
boundary_distances = gaussian(boundary_distances, sigma) # type: ignore
Expand Down

0 comments on commit 35bd684

Please sign in to comment.