diff --git a/experiments/vision-transformer/unetr/livecell_unetr.py b/experiments/vision-transformer/unetr/livecell_unetr.py
index fd883c81..0ff9cde2 100644
--- a/experiments/vision-transformer/unetr/livecell_unetr.py
+++ b/experiments/vision-transformer/unetr/livecell_unetr.py
@@ -7,11 +7,12 @@
 from typing import Tuple, List
 
 import imageio.v2 as imageio
-from elf.evaluation import dice_score
 from skimage.segmentation import find_boundaries
+from elf.evaluation import dice_score, mean_segmentation_accuracy
 
 import torch
 import torch_em
+from torch_em.util import segmentation
 from torch_em.transform.raw import standardize
 from torch_em.data.datasets import get_livecell_loader
 from torch_em.util.prediction import predict_with_halo
@@ -126,6 +127,16 @@ def do_unetr_inference(
     model.to(device)
     model.eval()
 
+    fg_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "foreground")
+    bd_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "boundary")
+    ws1_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "watershed1")
+    ws2_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "watershed2")
+
+    os.makedirs(fg_save_dir, exist_ok=True)
+    os.makedirs(bd_save_dir, exist_ok=True)
+    os.makedirs(ws1_save_dir, exist_ok=True)
+    os.makedirs(ws2_save_dir, exist_ok=True)
+
     with torch.no_grad():
         for img_path in tqdm(glob(test_img_dir), desc=f"Run inference for all livecell with model {model_ckpt}"):
             fname = os.path.split(img_path)[-1]
@@ -138,14 +149,13 @@ def do_unetr_inference(
 
             fg, bd = outputs[0, :, :], outputs[1, :, :]
 
-            fg_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "foreground")
-            bd_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "boundary")
-
-            os.makedirs(fg_save_dir, exist_ok=True)
-            os.makedirs(bd_save_dir, exist_ok=True)
+            ws1 = segmentation.watershed_from_components(bd, fg, min_size=10)
+            ws2 = segmentation.watershed_from_maxima(bd, fg, min_size=10, min_distance=1)
 
             imageio.imwrite(os.path.join(fg_save_dir, fname), fg)
             imageio.imwrite(os.path.join(bd_save_dir, fname), bd)
+            imageio.imwrite(os.path.join(ws1_save_dir, fname), ws1)
+            imageio.imwrite(os.path.join(ws2_save_dir, fname), ws2)
 
 
 def do_unetr_evaluation(
@@ -156,6 +166,7 @@ def do_unetr_evaluation(
     source_choice: str
 ):
     fg_list, bd_list = [], []
+    ws1_msa_list, ws2_msa_list, ws1_sa50_list, ws2_sa50_list = [], [], [], []
 
     for c1 in cell_types:
         _save_dir = os.path.join(root_save_dir, f"src-{c1}")
@@ -164,18 +175,24 @@ def do_unetr_evaluation(
             continue
 
         fg_set, bd_set = {"CELL TYPE": c1}, {"CELL TYPE": c1}
+        ws1_msa_set, ws2_msa_set, ws1_sa50_set, ws2_sa50_set = {"CELL TYPE": c1}, {"CELL TYPE": c1}, {"CELL TYPE": c1}, {"CELL TYPE": c1}
         for c2 in tqdm(cell_types, desc=f"Evaluation on {c1} source models from {_save_dir}"):
             fg_dir = os.path.join(_save_dir, "foreground")
             bd_dir = os.path.join(_save_dir, "boundary")
+            ws1_dir = os.path.join(_save_dir, "watershed1")
+            ws2_dir = os.path.join(_save_dir, "watershed2")
             gt_dir = os.path.join(input_path, "annotations", "livecell_test_images", c2, "*")
 
             cwise_fg, cwise_bd = [], []
+            cwise_ws1_msa, cwise_ws2_msa, cwise_ws1_sa50, cwise_ws2_sa50 = [], [], [], []
             for gt_path in glob(gt_dir):
                 fname = os.path.split(gt_path)[-1]
 
                 gt = imageio.imread(gt_path)
                 fg = imageio.imread(os.path.join(fg_dir, fname))
                 bd = imageio.imread(os.path.join(bd_dir, fname))
+                ws1 = imageio.imread(os.path.join(ws1_dir, fname))
+                ws2 = imageio.imread(os.path.join(ws2_dir, fname))
 
                 true_bd = find_boundaries(gt)
 
@@ -191,21 +208,45 @@ def do_unetr_evaluation(
                 # For the ground-truth we have already a binary label, so we don't need to threshold it again.
                 cwise_bd.append(dice_score(bd, true_bd, threshold_gt=None, threshold_seg=None))
 
+                msa1, sa_acc1 = mean_segmentation_accuracy(ws1, gt, return_accuracies=True)  # type: ignore
+                msa2, sa_acc2 = mean_segmentation_accuracy(ws2, gt, return_accuracies=True)  # type: ignore
+
+                cwise_ws1_msa.append(msa1)
+                cwise_ws2_msa.append(msa2)
+                cwise_ws1_sa50.append(sa_acc1[0])
+                cwise_ws2_sa50.append(sa_acc2[0])
+
             fg_set[c2] = np.mean(cwise_fg)
             bd_set[c2] = np.mean(cwise_bd)
+            ws1_msa_set[c2] = np.mean(cwise_ws1_msa)
+            ws2_msa_set[c2] = np.mean(cwise_ws2_msa)
+            ws1_sa50_set[c2] = np.mean(cwise_ws1_sa50)
+            ws2_sa50_set[c2] = np.mean(cwise_ws2_sa50)
 
         fg_list.append(pd.DataFrame.from_dict([fg_set]))  # type: ignore
        bd_list.append(pd.DataFrame.from_dict([bd_set]))  # type: ignore
+        ws1_msa_list.append(pd.DataFrame.from_dict([ws1_msa_set]))  # type: ignore
+        ws2_msa_list.append(pd.DataFrame.from_dict([ws2_msa_set]))  # type: ignore
+        ws1_sa50_list.append(pd.DataFrame.from_dict([ws1_sa50_set]))  # type: ignore
+        ws2_sa50_list.append(pd.DataFrame.from_dict([ws2_sa50_set]))  # type: ignore
 
     f_df_fg = pd.concat(fg_list, ignore_index=True)
     f_df_bd = pd.concat(bd_list, ignore_index=True)
+    f_df_ws1_msa = pd.concat(ws1_msa_list, ignore_index=True)
+    f_df_ws2_msa = pd.concat(ws2_msa_list, ignore_index=True)
+    f_df_ws1_sa50 = pd.concat(ws1_sa50_list, ignore_index=True)
+    f_df_ws2_sa50 = pd.concat(ws2_sa50_list, ignore_index=True)
 
-    csv_save_dir = "./results/"
+    tmp_csv_name = f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch"
+    csv_save_dir = f"./results/{tmp_csv_name}"
     os.makedirs(csv_save_dir, exist_ok=True)
 
-    tmp_csv_name = f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch"
-    f_df_fg.to_csv(os.path.join(csv_save_dir, f"foreground-unetr-{tmp_csv_name}-results.csv"))
-    f_df_bd.to_csv(os.path.join(csv_save_dir, f"boundary-unetr-{tmp_csv_name}-results.csv"))
+    f_df_fg.to_csv(os.path.join(csv_save_dir, "foreground-dice.csv"))
+    f_df_bd.to_csv(os.path.join(csv_save_dir, "boundary-dice.csv"))
+    f_df_ws1_msa.to_csv(os.path.join(csv_save_dir, "watershed1-msa.csv"))
+    f_df_ws2_msa.to_csv(os.path.join(csv_save_dir, "watershed2-msa.csv"))
+    f_df_ws1_sa50.to_csv(os.path.join(csv_save_dir, "watershed1-sa50.csv"))
+    f_df_ws2_sa50.to_csv(os.path.join(csv_save_dir, "watershed2-sa50.csv"))
 
 
 def main(args):
diff --git a/torch_em/util/segmentation.py b/torch_em/util/segmentation.py
index 176e47ca..d1c27441 100644
--- a/torch_em/util/segmentation.py
+++ b/torch_em/util/segmentation.py
@@ -1,23 +1,28 @@
 import numpy as np
 
-import elf.segmentation as elseg
-import vigra
+import vigra
+import elf.segmentation as elseg
 from elf.segmentation.utils import normalize_input
+
 from skimage.measure import label
+from skimage.filters import gaussian
 from skimage.segmentation import watershed
+from skimage.feature import peak_local_max
+from scipy.ndimage import distance_transform_edt
 
 
 #
 # segmentation functionality
 #
 
+
 # could also refactor this into elf
 def size_filter(seg, min_size, hmap=None, with_background=False):
     if hmap is None:
         ids, sizes = np.unique(seg, return_counts=True)
         bg_ids = ids[sizes < min_size]
         seg[np.isin(seg, bg_ids)] = 0
-        vigra.analysis.relabelConsecutive(seg, out=seg, start_label=1, keep_zeros=True)
+        seg, _, _ = vigra.analysis.relabelConsecutive(seg.astype(np.uint), start_label=1, keep_zeros=True)
     else:
         assert hmap.ndim in (seg.ndim, seg.ndim + 1)
         hmap_ = np.max(hmap[:seg.ndim], axis=0) if hmap.ndim > seg.ndim else hmap
@@ -41,3 +46,58 @@ def connected_components_with_boundaries(foreground, boundaries, threshold=0.5):
     mask = normalize_input(foreground > threshold)
     seg = watershed(boundaries, markers=seeds, mask=mask)
     return seg.astype("uint64")
+
+
+def watershed_from_components(boundaries, foreground, min_size, threshold1=0.5, threshold2=0.5):
+    """The default approach:
+    - Subtract the boundary predictions from the foreground predictions to separate touching objects.
+    - Use the connected components of the result as seeds.
+    - Use the thresholded foreground predictions as mask to grow back the pieces
+      lost by subtracting the boundary predictions.
+
+    Arguments:
+        - boundaries: [np.ndarray] - The boundary predictions.
+        - foreground: [np.ndarray] - The foreground predictions.
+        - min_size: [int] - The minimal object size; smaller objects are filtered out.
+        - threshold1: [float] - The threshold applied to (foreground - boundaries) to compute the seeds.
+        - threshold2: [float] - The threshold applied to the foreground predictions to compute the mask.
+
+    Returns:
+        seg: [np.ndarray] - The instance segmentation.
+    """
+    seeds = label((foreground - boundaries) > threshold1)
+    mask = foreground > threshold2
+    seg = watershed(boundaries, seeds, mask=mask)
+    seg = size_filter(seg, min_size)
+    return seg
+
+
+def watershed_from_maxima(boundaries, foreground, min_size, min_distance, sigma=1.0, threshold1=0.5):
+    """Seeded watershed that instead uses the maxima of the distance transform to the boundaries as seeds.
+    This has the advantage that objects can be better separated, but it may over-segment
+    objects with complex shapes.
+
+    The min_distance parameter controls the minimal distance between seeds, which
+    corresponds to the minimal distance between object centers.
+
+    Arguments:
+        - boundaries: [np.ndarray] - The boundary predictions.
+        - foreground: [np.ndarray] - The foreground predictions.
+        - min_size: [int] - The minimal object size; smaller objects are filtered out.
+        - min_distance: [int] - The minimal distance between peaks (see skimage.feature.peak_local_max).
+        - sigma: [float] - The standard deviation of the gaussian used to smooth the distance transform (see skimage.filters.gaussian).
+        - threshold1: [float] - The threshold applied to the foreground predictions to compute the mask.
+
+    Returns:
+        seg: [np.ndarray] - The instance segmentation.
+    """
+    mask = foreground > threshold1
+    boundary_distances = distance_transform_edt(boundaries < 0.1)
+    boundary_distances[~mask] = 0  # type: ignore
+    boundary_distances = gaussian(boundary_distances, sigma)  # type: ignore
+    seed_points = peak_local_max(boundary_distances, min_distance=min_distance, exclude_border=False)
+    seeds = np.zeros(mask.shape, dtype="uint32")
+    seeds[seed_points[:, 0], seed_points[:, 1]] = np.arange(1, len(seed_points) + 1)
+    seg = watershed(boundaries, markers=seeds, mask=mask)
+    seg = size_filter(seg, min_size)
+    return seg
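To try the two new post-processing functions in isolation, here is a minimal sketch of how they are meant to be called; the synthetic prediction maps below are made up for illustration, whereas in the experiment script the inputs are the UNETR foreground and boundary channels:

```python
import numpy as np
from torch_em.util import segmentation

# Synthetic "predictions": one foreground block containing two objects that
# touch along a thin boundary ridge (values are illustrative only).
foreground = np.zeros((64, 64), dtype="float32")
foreground[10:54, 10:54] = 0.9
boundaries = np.zeros_like(foreground)
boundaries[30:34, 10:54] = 0.8

# Default approach: seeds from the connected components of (foreground - boundaries),
# grown back inside the thresholded foreground mask; should find two instances here.
seg1 = segmentation.watershed_from_components(boundaries, foreground, min_size=10)

# Alternative: seeds from the maxima of the smoothed distance transform to the boundaries.
seg2 = segmentation.watershed_from_maxima(boundaries, foreground, min_size=10, min_distance=1)

print(seg1.max(), seg2.max())  # number of instances found by each variant
```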
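Similarly, the indexing `sa_acc1[0]` in the evaluation assumes that `elf.evaluation.mean_segmentation_accuracy` with `return_accuracies=True` returns the mean segmentation accuracy together with the per-threshold accuracies over the default IoU thresholds 0.5, 0.55, ..., 0.95, so that index 0 is the SA50 value written to the CSVs. A quick self-contained check under that assumption, with toy labels made up for illustration:

```python
import numpy as np
from elf.evaluation import mean_segmentation_accuracy

# Toy instance labels: the prediction matches the ground-truth exactly,
# so every accuracy should come out as 1.0.
gt = np.zeros((32, 32), dtype="uint32")
gt[4:12, 4:12] = 1
gt[18:28, 18:28] = 2
pred = gt.copy()

msa, accuracies = mean_segmentation_accuracy(pred, gt, return_accuracies=True)
print(msa)            # mean over the IoU thresholds 0.5, 0.55, ..., 0.95
print(accuracies[0])  # accuracy at IoU 0.5, i.e. the SA50 reported in the CSVs
```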