Add LIVECell Distance Map Experiments (#175)
Add LIVECell experiments for distance-based segmentation
anwai98 committed Dec 5, 2023
1 parent 87441eb commit 22c6bb7
Showing 4 changed files with 328 additions and 11 deletions.
9 changes: 9 additions & 0 deletions experiments/vision-transformer/unetr/livecell/README.md
@@ -0,0 +1,9 @@
# Using different `UNETR` settings on LIVECell

- Binary Segmentation - TODO
- Foreground-Boundary Segmentation - TODO
- Affinities - TODO
- Distance Maps (HoVerNet-style)
```bash
python livecell_all_hovernet.py [--train / --predict / --evaluate] -i <LIVECELL_DATA> -s <SAVE_ROOT> --save_dir <PREDICTION_DIR>
```
@@ -0,0 +1,80 @@
import os

import imageio.v2 as imageio
import napari

LIVECELL_FOLDER = "/home/pape/Work/data/incu_cyte/livecell"


def check_hv_segmentation(image, gt):
    from torch_em.transform.label import PerObjectDistanceTransform
    from common import opencv_hovernet_instance_segmentation

    # This transform gives only directed boundary distances
    # and foreground probabilities.
    trafo = PerObjectDistanceTransform(
        distances=False,
        boundary_distances=False,
        directed_distances=True,
        foreground=True,
        min_size=10,
    )
    target = trafo(gt)
    seg = opencv_hovernet_instance_segmentation(target)

    v = napari.Viewer()
    v.add_image(image)
    v.add_image(target)
    v.add_labels(gt)
    v.add_labels(seg)
    napari.run()


def check_distance_segmentation(image, gt):
    from torch_em.transform.label import PerObjectDistanceTransform
    from torch_em.util.segmentation import watershed_from_center_and_boundary_distances

    # This transform gives the distances to the centroid,
    # the distances to the boundary and the foreground probabilities.
    trafo = PerObjectDistanceTransform(
        distances=True,
        boundary_distances=True,
        directed_distances=False,
        foreground=True,
        min_size=10,
    )
    target = trafo(gt)

    # run the watershed segmentation on the distance channels
    fg, cdist, bdist = target
    seg = watershed_from_center_and_boundary_distances(
        cdist, bdist, fg, min_size=50,
    )

    # visualize it
    v = napari.Viewer()
    v.add_image(image)
    v.add_image(target)
    v.add_labels(gt)
    v.add_labels(seg)
    napari.run()


def main():
    # load an image and its ground-truth from LIVECell
    fname = "A172_Phase_A7_1_01d00h00m_1.tif"
    image_path = os.path.join(LIVECELL_FOLDER, "images/livecell_train_val_images", fname)
    image = imageio.imread(image_path)

    label_path = os.path.join(LIVECELL_FOLDER, "annotations/livecell_train_val_images/A172", fname)
    gt = imageio.imread(label_path)

    # Check the HoVerNet-style instance segmentation on the GT.
    check_hv_segmentation(image, gt)

    # Check the new distance-based segmentation on the GT.
    check_distance_segmentation(image, gt)


if __name__ == "__main__":
    main()
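
For reference, the transform's output layout can be verified on a synthetic label image. A minimal sketch (the square object and array shapes are made up for illustration; the channel order follows the `fg, cdist, bdist` unpacking in `check_distance_segmentation` above):

```python
import numpy as np
from torch_em.transform.label import PerObjectDistanceTransform

gt = np.zeros((256, 256), dtype="uint32")
gt[64:128, 64:128] = 1  # a single synthetic object

trafo = PerObjectDistanceTransform(
    distances=True, boundary_distances=True, directed_distances=False,
    foreground=True, min_size=10,
)
target = trafo(gt)
# expected: three stacked channels - foreground, center distances, boundary distances
print(target.shape)  # (3, 256, 256) under these assumptions
```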
106 changes: 95 additions & 11 deletions experiments/vision-transformer/unetr/livecell/common.py
@@ -9,6 +9,9 @@
 from skimage.segmentation import find_boundaries
 from elf.evaluation import dice_score, mean_segmentation_accuracy
 
+import torch
+import torch_em
+import torch.nn as nn
 from torch_em.util import segmentation
 from torch_em.transform.raw import standardize
 from torch_em.data.datasets import get_livecell_loader
@@ -44,10 +47,25 @@ def get_my_livecell_loaders(
     input_path: str,
     patch_shape: Tuple[int, int],
     cell_types: Optional[str] = None,
-    with_affinities: bool = False
+    with_binary: bool = False,
+    with_boundary: bool = False,
+    with_affinities: bool = False,
+    with_distance_maps: bool = False,
 ):
     """Returns the LIVECell training and validation dataloaders
     """
+
+    if with_distance_maps:
+        label_trafo = torch_em.transform.label.PerObjectDistanceTransform(
+            distances=True,
+            boundary_distances=True,
+            directed_distances=False,
+            foreground=True,
+            min_size=25,
+        )
+    else:
+        label_trafo = None
+
     train_loader = get_livecell_loader(
         path=input_path,
         split="train",
@@ -58,8 +76,10 @@ def get_my_livecell_loaders(
         cell_types=None if cell_types is None else [cell_types],
         # this returns dataloaders with affinity channels and foreground-background channels
         offsets=OFFSETS if with_affinities else None,
-        # this returns dataloaders with foreground and boundary channels
-        boundaries=False if with_affinities else True
+        boundaries=with_boundary,  # this returns dataloaders with foreground and boundary channels
+        binary=with_binary,
+        label_transform=label_trafo,
+        label_dtype=torch.float32
     )
     val_loader = get_livecell_loader(
         path=input_path,
@@ -71,8 +91,10 @@ def get_my_livecell_loaders(
         cell_types=None if cell_types is None else [cell_types],
         # this returns dataloaders with affinity channels and foreground-background channels
         offsets=OFFSETS if with_affinities else None,
-        # this returns dataloaders with foreground and boundary channels
-        boundaries=False if with_affinities else True
+        boundaries=with_boundary,  # this returns dataloaders with foreground and boundary channels
+        binary=with_binary,
+        label_transform=label_trafo,
+        label_dtype=torch.float32
     )
 
     return train_loader, val_loader
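
With the new flags in place, a distance-map training setup would be requested roughly like this (a sketch only: the path and patch shape are placeholders, and since part of the function is collapsed in this diff, additional loader arguments may be required):

```python
# Hypothetical call; <LIVECELL_DATA> is a placeholder for the dataset root.
train_loader, val_loader = get_my_livecell_loaders(
    input_path="<LIVECELL_DATA>",
    patch_shape=(512, 512),        # placeholder patch shape
    with_distance_maps=True,       # per-object center + boundary distance targets
)
```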
@@ -127,10 +149,12 @@ def get_unetr_model(
 
 
 #
-# LIVECELL UNETR INFERENCE - foreground boundary / foreground affinities
+# LIVECELL UNETR INFERENCE - foreground boundary / foreground affinities / foreground dist. maps
 #
 
-def predict_for_unetr(img_path, model, root_save_dir, device, with_affinities, ctype=None):
+def predict_for_unetr(
+    img_path, model, root_save_dir, device, with_affinities=False, with_distance_maps=False, ctype=None
+):
     input_ = imageio.imread(img_path)
     input_ = standardize(input_)
 
@@ -139,6 +163,11 @@ def predict_for_unetr(img_path, model, root_save_dir, device, with_affinities, ctype=None):
         fg, affs = np.array(outputs[0, 0]), np.array(outputs[0, 1:])
         mws = segmentation.mutex_watershed_segmentation(fg, affs, offsets=OFFSETS)
 
+    elif with_distance_maps:  # inference using foreground and hv distance maps
+        outputs = predict_with_padding(model, input_, device=device, min_divisible=(16, 16))
+        fg, cdist, bdist = outputs.squeeze()
+        dm_seg = segmentation.watershed_from_center_and_boundary_distances(cdist, bdist, fg, min_size=50)
+
     else:  # inference using foreground-boundary inputs - for the unetr training
         outputs = predict_with_halo(input_, model, [device], block_shape=[384, 384], halo=[64, 64], disable_tqdm=True)
         fg, bd = outputs[0, :, :], outputs[1, :, :]
@@ -148,14 +177,21 @@ def predict_for_unetr(img_path, model, root_save_dir, device, with_affinities, ctype=None):
     fname = Path(img_path).stem
     save_path = os.path.join(root_save_dir, "src-all" if ctype is None else f"src-{ctype}", f"{fname}.h5")
     with h5py.File(save_path, "a") as f:
-        ds = f.require_dataset("foreground", shape=fg.shape, compression="gzip", dtype=fg.dtype)
-        ds[:] = fg
         if with_affinities:
+            ds = f.require_dataset("foreground", shape=fg.shape, compression="gzip", dtype=fg.dtype)
+            ds[:] = fg
             ds = f.require_dataset("affinities", shape=affs.shape, compression="gzip", dtype=affs.dtype)
             ds[:] = affs
             ds = f.require_dataset("segmentation", shape=mws.shape, compression="gzip", dtype=mws.dtype)
             ds[:] = mws
+
+        elif with_distance_maps:
+            ds = f.require_dataset("segmentation", shape=dm_seg.shape, compression="gzip", dtype=dm_seg.dtype)
+            ds[:] = dm_seg
+
         else:
+            ds = f.require_dataset("foreground", shape=fg.shape, compression="gzip", dtype=fg.dtype)
+            ds[:] = fg
             ds = f.require_dataset("boundary", shape=bd.shape, compression="gzip", dtype=bd.dtype)
             ds[:] = bd
             ds = f.require_dataset("watershed1", shape=ws1.shape, compression="gzip", dtype=ws1.dtype)
@@ -168,14 +204,18 @@ def predict_for_unetr(img_path, model, root_save_dir, device, with_affinities, ctype=None):
 # LIVECELL UNETR EVALUATION - foreground boundary / foreground affinities
 #
 
-def evaluate_for_unetr(gt_path, _save_dir, with_affinities):
+def evaluate_for_unetr(gt_path, _save_dir, with_affinities=False, with_distance_maps=False):
     fname = Path(gt_path).stem
     gt = imageio.imread(gt_path)
 
     output_file = os.path.join(_save_dir, f"{fname}.h5")
     with h5py.File(output_file, "r") as f:
         if with_affinities:
             mws = f["segmentation"][:]
+
+        elif with_distance_maps:
+            instances = f["segmentation"][:]
+
         else:
             fg = f["foreground"][:]
             bd = f["boundary"][:]
@@ -186,6 +226,10 @@ def evaluate_for_unetr(gt_path, _save_dir, with_affinities):
         mws_msa, mws_sa_acc = mean_segmentation_accuracy(mws, gt, return_accuracies=True)
         return mws_msa, mws_sa_acc[0]
 
+    elif with_distance_maps:
+        instances_msa, instances_sa_acc = mean_segmentation_accuracy(instances, gt, return_accuracies=True)
+        return instances_msa, instances_sa_acc[0]
+
     else:
         true_bd = find_boundaries(gt)
 
@@ -258,6 +302,11 @@ def get_parser():
         help="Path to save predictions from UNETR model"
     )
 
+    # this argument takes care of which ViT encoder to use for the UNETR (as ViTs from SAM and MAE are different)
+    parser.add_argument(
+        "--pretrained_choice", type=str, default="sam",
+    )
+
     parser.add_argument(
         "--with_affinities", action="store_true",
         help="Trains the UNETR model with affinities"
@@ -267,13 +316,48 @@ def get_parser():
     return parser
 
 
-def get_loss_function(with_affinities=True):
+def get_loss_function(with_affinities=False, with_distance_maps=False):
     if with_affinities:
         loss = LossWrapper(
             loss=DiceLoss(),
             transform=ApplyAndRemoveMask(masking_method="multiply")
         )
+    elif with_distance_maps:
+        # Updated the loss function for the simplified distance loss.
+        # TODO we can try both with and without masking
+        loss = DistanceLoss(mask_distances_in_bg=True)
     else:
         loss = DiceLoss()
 
     return loss
+
+
+class DistanceLoss(nn.Module):
+    def __init__(self, mask_distances_in_bg):
+        super().__init__()
+
+        self.dice_loss = DiceLoss()
+        self.mse_loss = nn.MSELoss()
+        self.mask_distances_in_bg = mask_distances_in_bg
+
+    def forward(self, input_, target):
+        assert input_.shape == target.shape, input_.shape
+        assert input_.shape[1] == 3, input_.shape
+
+        fg_input, fg_target = input_[:, 0, ...], target[:, 0, ...]
+        fg_loss = self.dice_loss(fg_target, fg_input)
+
+        cdist_input, cdist_target = input_[:, 1, ...], target[:, 1, ...]
+        if self.mask_distances_in_bg:
+            cdist_loss = self.mse_loss(cdist_target * fg_target, cdist_input * fg_target)
+        else:
+            cdist_loss = self.mse_loss(cdist_target, cdist_input)
+
+        bdist_input, bdist_target = input_[:, 2, ...], target[:, 2, ...]
+        if self.mask_distances_in_bg:
+            bdist_loss = self.mse_loss(bdist_target * fg_target, bdist_input * fg_target)
+        else:
+            bdist_loss = self.mse_loss(bdist_target, bdist_input)
+
+        overall_loss = fg_loss + cdist_loss + bdist_loss
+        return overall_loss
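
A quick sanity check for the new loss on random tensors (an illustrative sketch, not part of the commit; it assumes `common.py` is importable and that the three channels are foreground, center distance and boundary distance, as asserted in `forward`):

```python
import torch
from common import DistanceLoss  # the class defined in this commit

pred = torch.rand(2, 3, 64, 64)    # batch of 2; channels: [fg, cdist, bdist]
target = torch.rand(2, 3, 64, 64)

criterion = DistanceLoss(mask_distances_in_bg=True)
print(criterion(pred, target))  # scalar: dice(fg) + masked MSE(cdist) + masked MSE(bdist)
```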
