diff --git a/experiments/vision-transformer/unetr/livecell/README.md b/experiments/vision-transformer/unetr/livecell/README.md
new file mode 100644
index 00000000..cdc3df01
--- /dev/null
+++ b/experiments/vision-transformer/unetr/livecell/README.md
@@ -0,0 +1,9 @@
+# Using different `UNETR` settings on LIVECell
+
+- Binary Segmentation - TODO
+- Foreground-Boundary Segmentation - TODO
+- Affinities - TODO
+- Distance Maps (HoVerNet-style)
+```bash
+python livecell_all_hovernet.py [--train / --predict / --evaluate] -i <LIVECELL_FOLDER> -s <SAVE_ROOT> --save_dir <SAVE_DIR>
+```
diff --git a/experiments/vision-transformer/unetr/livecell/check_hv_segmentation.py b/experiments/vision-transformer/unetr/livecell/check_hv_segmentation.py
new file mode 100644
index 00000000..8097bc2a
--- /dev/null
+++ b/experiments/vision-transformer/unetr/livecell/check_hv_segmentation.py
@@ -0,0 +1,80 @@
+import os
+
+import imageio.v2 as imageio
+import napari
+
+LIVECELL_FOLDER = "/home/pape/Work/data/incu_cyte/livecell"
+
+
+def check_hv_segmentation(image, gt):
+    from torch_em.transform.label import PerObjectDistanceTransform
+    from common import opencv_hovernet_instance_segmentation
+
+    # This transform gives only directed boundary distances
+    # and foreground probabilities.
+    trafo = PerObjectDistanceTransform(
+        distances=False,
+        boundary_distances=False,
+        directed_distances=True,
+        foreground=True,
+        min_size=10,
+    )
+    target = trafo(gt)
+    seg = opencv_hovernet_instance_segmentation(target)
+
+    v = napari.Viewer()
+    v.add_image(image)
+    v.add_image(target)
+    v.add_labels(gt)
+    v.add_labels(seg)
+    napari.run()
+
+
+def check_distance_segmentation(image, gt):
+    from torch_em.transform.label import PerObjectDistanceTransform
+    from torch_em.util.segmentation import watershed_from_center_and_boundary_distances
+
+    # This transform gives the distances to the centroid and
+    # to the boundary, plus the foreground probabilities.
+    trafo = PerObjectDistanceTransform(
+        distances=True,
+        boundary_distances=True,
+        directed_distances=False,
+        foreground=True,
+        min_size=10,
+    )
+    target = trafo(gt)
+
+    # run the segmentation
+    fg, cdist, bdist = target
+    seg = watershed_from_center_and_boundary_distances(
+        cdist, bdist, fg, min_size=50,
+    )
+
+    # visualize it
+    v = napari.Viewer()
+    v.add_image(image)
+    v.add_image(target)
+    v.add_labels(gt)
+    v.add_labels(seg)
+    napari.run()
+
+
+def main():
+    # load the image and ground-truth from LIVECell
+    fname = "A172_Phase_A7_1_01d00h00m_1.tif"
+    image_path = os.path.join(LIVECELL_FOLDER, "images/livecell_train_val_images", fname)
+    image = imageio.imread(image_path)
+
+    label_path = os.path.join(LIVECELL_FOLDER, "annotations/livecell_train_val_images/A172", fname)
+    gt = imageio.imread(label_path)
+
+    # Check the hovernet instance segmentation on GT.
+    check_hv_segmentation(image, gt)
+
+    # Check the new distance-based segmentation on GT.
+    check_distance_segmentation(image, gt)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/experiments/vision-transformer/unetr/livecell/common.py b/experiments/vision-transformer/unetr/livecell/common.py
index d4497cf7..990ff7e8 100644
--- a/experiments/vision-transformer/unetr/livecell/common.py
+++ b/experiments/vision-transformer/unetr/livecell/common.py
@@ -9,6 +9,9 @@ from skimage.segmentation import find_boundaries
 
 from elf.evaluation import dice_score, mean_segmentation_accuracy
 
+import torch
+import torch_em
+import torch.nn as nn
 from torch_em.util import segmentation
 from torch_em.transform.raw import standardize
 from torch_em.data.datasets import get_livecell_loader
@@ -44,10 +47,25 @@ def get_my_livecell_loaders(
     input_path: str,
     patch_shape: Tuple[int, int],
     cell_types: Optional[str] = None,
-    with_affinities: bool = False
+    with_binary: bool = False,
+    with_boundary: bool = False,
+    with_affinities: bool = False,
+    with_distance_maps: bool = False,
 ):
     """Returns the LIVECell training and validation dataloaders.
     """
+
+    if with_distance_maps:
+        label_trafo = torch_em.transform.label.PerObjectDistanceTransform(
+            distances=True,
+            boundary_distances=True,
+            directed_distances=False,
+            foreground=True,
+            min_size=25,
+        )
+    else:
+        label_trafo = None
+
     train_loader = get_livecell_loader(
         path=input_path,
         split="train",
@@ -58,8 +76,10 @@
         cell_types=None if cell_types is None else [cell_types],
         # this returns dataloaders with affinity channels and foreground-background channels
         offsets=OFFSETS if with_affinities else None,
-        # this returns dataloaders with foreground and boundary channels
-        boundaries=False if with_affinities else True
+        boundaries=with_boundary,  # adds foreground and boundary channels to the target
+        binary=with_binary,
+        label_transform=label_trafo,
+        label_dtype=torch.float32
     )
     val_loader = get_livecell_loader(
         path=input_path,
@@ -71,8 +91,10 @@
         cell_types=None if cell_types is None else [cell_types],
         # this returns dataloaders with affinity channels and foreground-background channels
         offsets=OFFSETS if with_affinities else None,
-        # this returns dataloaders with foreground and boundary channels
-        boundaries=False if with_affinities else True
+        boundaries=with_boundary,  # adds foreground and boundary channels to the target
+        binary=with_binary,
+        label_transform=label_trafo,
+        label_dtype=torch.float32
     )
     return train_loader, val_loader
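As a quick smoke test for the new loader options, the distance-map targets can be inspected directly. A minimal sketch (assuming the LIVECell data is already available locally; the path is a placeholder, and the channel order follows the `PerObjectDistanceTransform` configuration above):

```python
from common import get_my_livecell_loaders

LIVECELL_PATH = "/path/to/livecell"  # placeholder

# loaders whose targets carry three channels:
# foreground, centroid distances, boundary distances
train_loader, _ = get_my_livecell_loaders(
    LIVECELL_PATH, patch_shape=(512, 512), with_distance_maps=True
)

x, y = next(iter(train_loader))
print(x.shape, y.shape)  # e.g. (B, 1, 512, 512) and (B, 3, 512, 512)
```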
@@ -127,10 +149,12 @@ def get_unetr_model(
 #
-# LIVECELL UNETR INFERENCE - foreground boundary / foreground affinities
+# LIVECELL UNETR INFERENCE - foreground boundary / foreground affinities / foreground distance maps
 #
 
-def predict_for_unetr(img_path, model, root_save_dir, device, with_affinities, ctype=None):
+def predict_for_unetr(
+    img_path, model, root_save_dir, device, with_affinities=False, with_distance_maps=False, ctype=None
+):
     input_ = imageio.imread(img_path)
     input_ = standardize(input_)
 
@@ -139,6 +163,11 @@
         fg, affs = np.array(outputs[0, 0]), np.array(outputs[0, 1:])
         mws = segmentation.mutex_watershed_segmentation(fg, affs, offsets=OFFSETS)
 
+    elif with_distance_maps:  # inference using foreground probabilities and centroid / boundary distance maps
+        outputs = predict_with_padding(model, input_, device=device, min_divisible=(16, 16))
+        fg, cdist, bdist = outputs.squeeze()
+        dm_seg = segmentation.watershed_from_center_and_boundary_distances(cdist, bdist, fg, min_size=50)
+
     else:  # inference using foreground-boundary inputs - for the unetr training
         outputs = predict_with_halo(input_, model, [device], block_shape=[384, 384], halo=[64, 64], disable_tqdm=True)
         fg, bd = outputs[0, :, :], outputs[1, :, :]
@@ -148,14 +177,21 @@
     fname = Path(img_path).stem
     save_path = os.path.join(root_save_dir, "src-all" if ctype is None else f"src-{ctype}", f"{fname}.h5")
     with h5py.File(save_path, "a") as f:
-        ds = f.require_dataset("foreground", shape=fg.shape, compression="gzip", dtype=fg.dtype)
-        ds[:] = fg
         if with_affinities:
+            ds = f.require_dataset("foreground", shape=fg.shape, compression="gzip", dtype=fg.dtype)
+            ds[:] = fg
             ds = f.require_dataset("affinities", shape=affs.shape, compression="gzip", dtype=affs.dtype)
             ds[:] = affs
             ds = f.require_dataset("segmentation", shape=mws.shape, compression="gzip", dtype=mws.dtype)
             ds[:] = mws
+
+        elif with_distance_maps:
+            ds = f.require_dataset("segmentation", shape=dm_seg.shape, compression="gzip", dtype=dm_seg.dtype)
+            ds[:] = dm_seg
+
         else:
+            ds = f.require_dataset("foreground", shape=fg.shape, compression="gzip", dtype=fg.dtype)
+            ds[:] = fg
             ds = f.require_dataset("boundary", shape=bd.shape, compression="gzip", dtype=bd.dtype)
             ds[:] = bd
             ds = f.require_dataset("watershed1", shape=ws1.shape, compression="gzip", dtype=ws1.dtype)
             ds[:] = ws1
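The distance-map branch above decodes instances outside the network, so the same decoding can be exercised standalone on any three-channel output. A shape-level sketch with random inputs (the result is meaningless, this only checks the call; `min_size` mirrors the value used above):

```python
import numpy as np
from torch_em.util.segmentation import watershed_from_center_and_boundary_distances

# stand-in for a squeezed model output of shape (3, H, W):
# channel 0 foreground, 1 centroid distances, 2 boundary distances
outputs = np.random.rand(3, 512, 512).astype("float32")
fg, cdist, bdist = outputs

seg = watershed_from_center_and_boundary_distances(cdist, bdist, fg, min_size=50)
print(seg.shape, seg.max())  # label image and highest instance id
```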
@@ -168,7 +204,7 @@
 # LIVECELL UNETR EVALUATION - foreground boundary / foreground affinities
 #
 
-def evaluate_for_unetr(gt_path, _save_dir, with_affinities):
+def evaluate_for_unetr(gt_path, _save_dir, with_affinities=False, with_distance_maps=False):
     fname = Path(gt_path).stem
     gt = imageio.imread(gt_path)
 
@@ -176,6 +212,10 @@
     with h5py.File(output_file, "r") as f:
         if with_affinities:
             mws = f["segmentation"][:]
+
+        elif with_distance_maps:
+            instances = f["segmentation"][:]
+
         else:
             fg = f["foreground"][:]
             bd = f["boundary"][:]
@@ -186,6 +226,10 @@
         mws_msa, mws_sa_acc = mean_segmentation_accuracy(mws, gt, return_accuracies=True)
         return mws_msa, mws_sa_acc[0]
 
+    elif with_distance_maps:
+        instances_msa, instances_sa_acc = mean_segmentation_accuracy(instances, gt, return_accuracies=True)
+        return instances_msa, instances_sa_acc[0]
+
     else:
         true_bd = find_boundaries(gt)
 
@@ -258,6 +302,11 @@
         help="Path to save predictions from UNETR model"
     )
 
+    # choose which ViT encoder to use for the UNETR (the ViTs from SAM and MAE differ)
+    parser.add_argument(
+        "--pretrained_choice", type=str, default="sam",
+    )
+
     parser.add_argument(
         "--with_affinities", action="store_true",
         help="Trains the UNETR model with affinities"
@@ -267,13 +316,48 @@
     return parser
 
 
-def get_loss_function(with_affinities=True):
+def get_loss_function(with_affinities=False, with_distance_maps=False):
     if with_affinities:
         loss = LossWrapper(
             loss=DiceLoss(),
             transform=ApplyAndRemoveMask(masking_method="multiply")
         )
+    elif with_distance_maps:
+        # Use the simplified distance loss.
+        # TODO we can try both with and without masking
+        loss = DistanceLoss(mask_distances_in_bg=True)
     else:
         loss = DiceLoss()
 
     return loss
+
+
+class DistanceLoss(nn.Module):
+    def __init__(self, mask_distances_in_bg):
+        super().__init__()
+
+        self.dice_loss = DiceLoss()
+        self.mse_loss = nn.MSELoss()
+        self.mask_distances_in_bg = mask_distances_in_bg
+
+    def forward(self, input_, target):
+        assert input_.shape == target.shape, (input_.shape, target.shape)
+        assert input_.shape[1] == 3, input_.shape  # foreground, centroid distances, boundary distances
+
+        fg_input, fg_target = input_[:, 0, ...], target[:, 0, ...]
+        fg_loss = self.dice_loss(fg_input, fg_target)
+
+        cdist_input, cdist_target = input_[:, 1, ...], target[:, 1, ...]
+        if self.mask_distances_in_bg:
+            cdist_loss = self.mse_loss(cdist_input * fg_target, cdist_target * fg_target)
+        else:
+            cdist_loss = self.mse_loss(cdist_input, cdist_target)
+
+        bdist_input, bdist_target = input_[:, 2, ...], target[:, 2, ...]
+        if self.mask_distances_in_bg:
+            bdist_loss = self.mse_loss(bdist_input * fg_target, bdist_target * fg_target)
+        else:
+            bdist_loss = self.mse_loss(bdist_input, bdist_target)
+
+        overall_loss = fg_loss + cdist_loss + bdist_loss
+        return overall_loss
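A quick sanity check of `DistanceLoss` on random tensors (shapes only; batch and spatial sizes are arbitrary):

```python
import torch
from common import DistanceLoss

loss_fn = DistanceLoss(mask_distances_in_bg=True)

# predictions and targets: (batch, 3, height, width) with
# channels foreground, centroid distances, boundary distances
pred = torch.rand(2, 3, 64, 64)
target = torch.rand(2, 3, 64, 64)

print(loss_fn(pred, target).item())  # dice(fg) + mse(cdist) + mse(bdist)
```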
diff --git a/experiments/vision-transformer/unetr/livecell/livecell_all_hovernet.py b/experiments/vision-transformer/unetr/livecell/livecell_all_hovernet.py
new file mode 100644
index 00000000..92887f63
--- /dev/null
+++ b/experiments/vision-transformer/unetr/livecell/livecell_all_hovernet.py
@@ -0,0 +1,144 @@
+import os
+from tqdm import tqdm
+from glob import glob
+
+import numpy as np
+import pandas as pd
+
+import torch
+import torch_em
+
+import common
+
+
+def do_unetr_hovernet_training(
+    train_loader, val_loader, model, device, iterations, loss, save_root
+):
+    print("Run training with hovernet ideas for all cell types")
+    trainer = torch_em.default_segmentation_trainer(
+        name="livecell-all",
+        model=model,
+        train_loader=train_loader,
+        val_loader=val_loader,
+        learning_rate=1e-5,
+        device=device,
+        mixed_precision=True,
+        log_image_interval=50,
+        compile_model=False,
+        save_root=save_root,
+        loss=loss,
+        metric=loss
+    )
+    trainer.fit(iterations)
+
+
+def do_unetr_hovernet_inference(
+    input_path: str,
+    device: torch.device,
+    model,
+    root_save_dir: str,
+    save_root: str,
+    with_distance_maps: bool
+):
+    test_img_dir = os.path.join(input_path, "images", "livecell_test_images", "*")
+    model_ckpt = os.path.join(save_root, "checkpoints", "livecell-all", "best.pt")
+    assert os.path.exists(model_ckpt), model_ckpt
+
+    model.load_state_dict(torch.load(model_ckpt, map_location=torch.device("cpu"))["model_state"])
+    model.to(device)
+    model.eval()
+
+    # creating the respective directories for saving the outputs
+    os.makedirs(os.path.join(root_save_dir, "src-all"), exist_ok=True)
+
+    with torch.no_grad():
+        for img_path in tqdm(glob(test_img_dir), desc=f"Run inference for all livecell with model {model_ckpt}"):
+            common.predict_for_unetr(img_path, model, root_save_dir, device, with_distance_maps=with_distance_maps)
+
+
+def do_unetr_hovernet_evaluation(
+    input_path: str,
+    root_save_dir: str,
+    csv_save_dir: str,
+    with_distance_maps: bool
+):
+    _save_dir = os.path.join(root_save_dir, "src-all")
+    assert os.path.exists(_save_dir), _save_dir
+
+    gt_dir = os.path.join(input_path, "annotations", "livecell_test_images", "*", "*")
+
+    msa_list, sa50_list = [], []
+    for gt_path in tqdm(glob(gt_dir)):
+        all_metrics = common.evaluate_for_unetr(gt_path, _save_dir, with_distance_maps=with_distance_maps)
+        msa, sa50 = all_metrics
+        msa_list.append(msa)
+        sa50_list.append(sa50)
+
+    res_dict = {
+        "LIVECell": "Metrics",
+        "mSA": np.mean(msa_list),
+        "SA50": np.mean(sa50_list)
+    }
+
+    df = pd.DataFrame.from_dict([res_dict])
+    df.to_csv(os.path.join(csv_save_dir, "livecell.csv"))
+
+
+def main(args):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # overwrite this for other device setups
+    patch_shape = (512, 512)  # patch size used for training on livecell
+
+    # directory structure for saving the different parts of the experiment
+    dir_structure = os.path.join(args.model_name, "hovernet", "torch-em-sam")
+
+    # get the desired loss function for training
+    loss = common.get_loss_function(with_distance_maps=True)
+
+    # get the model for the training and inference on livecell dataset
+    model = common.get_unetr_model(
+        model_name=args.model_name, source_choice="torch-em", patch_shape=patch_shape,
+        sam_initialization=args.do_sam_ini, output_channels=3,  # foreground, centroid distances, boundary distances
+        backbone=args.pretrained_choice
+    )
+    model.to(device)
+
+    # determine where to save the checkpoints and tensorboard logs
+    save_root = os.path.join(args.save_root, dir_structure) if args.save_root is not None else args.save_root
+
+    if args.train:
+        print("2d UNETR hovernet-style training on LIVECell dataset")
+        # get the desired livecell loaders for training
+        train_loader, val_loader = common.get_my_livecell_loaders(
+            args.input, patch_shape, args.cell_type, with_distance_maps=True
+        )
+        do_unetr_hovernet_training(
+            train_loader=train_loader, val_loader=val_loader, model=model,
+            device=device, save_root=save_root, iterations=args.iterations, loss=loss
+        )
+
+    # determine the directory where the predictions will be saved
+    root_save_dir = os.path.join(args.save_dir, dir_structure)
+
+    if args.predict:
+        print("2d UNETR hovernet-style inference on LIVECell dataset")
+        do_unetr_hovernet_inference(
+            input_path=args.input, device=device, model=model, save_root=save_root,
+            root_save_dir=root_save_dir, with_distance_maps=True
+        )
+        print("Predictions are saved in", root_save_dir)
+
+    if args.evaluate:
+        print("2d UNETR hovernet-style evaluation on LIVECell dataset")
+        csv_save_dir = os.path.join("results", dir_structure)
+        os.makedirs(csv_save_dir, exist_ok=True)
+
+        do_unetr_hovernet_evaluation(
+            input_path=args.input, root_save_dir=root_save_dir,
+            csv_save_dir=csv_save_dir, with_distance_maps=True
+        )
+
+
+if __name__ == "__main__":
+    parser = common.get_parser()
+    args = parser.parse_args()
+    main(args)
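After `--predict` has run, individual results can be spot-checked without rerunning the pipeline; the `segmentation` dataset name matches what `predict_for_unetr` writes in the distance-map branch. A sketch with placeholder paths:

```python
import h5py
import imageio.v2 as imageio
import napari

image = imageio.imread("/path/to/livecell_test_images/example.tif")  # placeholder
with h5py.File("/path/to/save_dir/src-all/example.h5", "r") as f:  # placeholder
    seg = f["segmentation"][:]

# overlay the saved instance segmentation on the raw image
v = napari.Viewer()
v.add_image(image)
v.add_labels(seg)
napari.run()
```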