From 648fb14c5ca8398916f444f0363d19f091cf49a7 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Thu, 2 Nov 2023 18:10:25 +0100
Subject: [PATCH 1/9] Add training script for livecell with affinities

---
 .../vision-transformer/unetr/common.py       | 166 +++++++++++++++
 .../unetr/livecell_unetr.py                  | 194 +++++-------------
 2 files changed, 214 insertions(+), 146 deletions(-)
 create mode 100644 experiments/vision-transformer/unetr/common.py

diff --git a/experiments/vision-transformer/unetr/common.py b/experiments/vision-transformer/unetr/common.py
new file mode 100644
index 00000000..fe3dbaf5
--- /dev/null
+++ b/experiments/vision-transformer/unetr/common.py
@@ -0,0 +1,166 @@
+import argparse
+from typing import Tuple
+
+from torch_em.data.datasets import get_livecell_loader
+from torch_em.loss import DiceLoss, LossWrapper, ApplyAndRemoveMask
+
+
+OFFSETS = [
+    [-1, 0], [0, -1],
+    [-3, 0], [0, -3],
+    [-9, 0], [0, -9],
+    [-27, 0], [0, -27]
+]
+
+
+CELL_TYPES = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"]
+
+
+#
+# LIVECELL DATALOADERS
+#
+
+def get_my_livecell_loaders(
+    input_path: str,
+    patch_shape: Tuple[int, int],
+    cell_types: str,
+    with_affinities: bool = False
+):
+    """Returns the LIVECell training and validation dataloaders
+    """
+    if with_affinities:
+        # this returns dataloaders with affinity channels and foreground-background channels
+        n_out = len(OFFSETS) + 1
+        train_loader = get_livecell_loader(
+            path=input_path, split="train", patch_shape=patch_shape, batch_size=2,
+            cell_types=[cell_types], download=True, offsets=OFFSETS, num_workers=16
+        )
+        val_loader = get_livecell_loader(
+            path=input_path, split="val", patch_shape=patch_shape, batch_size=1,
+            cell_types=[cell_types], download=True, offsets=OFFSETS, num_workers=16
+        )
+
+    else:
+        # this returns dataloaders with foreground and boundary channels
+        n_out = 2
+        train_loader = get_livecell_loader(
+            path=input_path, split="train", patch_shape=patch_shape, batch_size=2,
+            cell_types=[cell_types], download=True, boundaries=True, num_workers=16
+        )
+        val_loader = get_livecell_loader(
+            path=input_path, split="val", patch_shape=patch_shape, batch_size=1,
+            cell_types=[cell_types], download=True, boundaries=True, num_workers=16
+        )
+
+    return train_loader, val_loader, n_out
+
+
+#
+# UNETR MODEL(S) FROM MONAI AND torch_em
+#
+
+
+def get_unetr_model(
+    model_name: str,
+    source_choice: str,
+    patch_shape: Tuple[int, int],
+    sam_initialization: bool,
+    output_channels: int
+):
+    """Returns the expected UNETR model
+    """
+    if source_choice == "torch-em":
+        # this returns the unetr model which uses the vision transformer from segment anything
+        from torch_em import model as torch_em_models
+        model = torch_em_models.UNETR(
+            encoder=model_name, out_channels=output_channels,
+            encoder_checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth" if sam_initialization else None
+        )
+
+    elif source_choice == "monai":
+        # this returns the unetr model from monai
+        from monai.networks import nets as monai_models
+        model = monai_models.unetr.UNETR(
+            in_channels=1,
+            out_channels=output_channels,
+            img_size=patch_shape,
+            spatial_dims=2
+        )
+        model.out_channels = 2  # type: ignore
+
+    else:
+        raise ValueError(f"The available UNETR models are either from \"torch-em\" or \"monai\", choose from them instead of - {source_choice}")
+
+    return model
+
+
+#
+# miscellanous utilities
+#
+
+
+def get_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--train", action='store_true',
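For orientation, each entry in `OFFSETS` above is a `[dy, dx]` pixel shift and contributes one affinity channel: a pixel gets affinity 1 for a given offset when it and the pixel at that shift belong to the same instance. A minimal sketch of that target construction follows (illustrative only; the transform wired up by `get_livecell_loader(offsets=...)` additionally handles image borders properly instead of wrapping, and appends the mask channels consumed by the masked loss defined further down):

```python
import numpy as np

def affinities_from_labels(labels: np.ndarray, offsets) -> np.ndarray:
    # one channel per offset: 1 where a pixel and its shifted neighbor share a non-zero instance id
    affs = np.zeros((len(offsets),) + labels.shape, dtype="float32")
    for chan, (dy, dx) in enumerate(offsets):
        shifted = np.roll(labels, shift=(-dy, -dx), axis=(0, 1))  # neighbor at [y + dy, x + dx]
        affs[chan] = (labels == shifted) & (labels != 0)
    return affs

# eight offsets plus one extra foreground channel are what make n_out = len(OFFSETS) + 1 above
```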
help="Enables UNETR training on LiveCell dataset" + ) + + parser.add_argument( + "--predict", action='store_true', help="Enables UNETR prediction on LiveCell dataset" + ) + + parser.add_argument( + "--evaluate", action='store_true', help="Enables UNETR evaluation on LiveCell dataset" + ) + + parser.add_argument( + "--source_choice", type=str, default="torch-em", + help="The source where the model comes from, i.e. either torch-em / monai" + ) + + parser.add_argument( + "-m", "--model_name", type=str, default="vit_b", help="Name of the ViT to use as the encoder in UNETR" + ) + + parser.add_argument( + "--do_sam_ini", action='store_true', help="Enables initializing UNETR with SAM's ViT weights" + ) + + parser.add_argument( + "-c", "--cell_type", type=str, default=None, help="Choice of cell-type for doing the training" + ) + + parser.add_argument( + "-i", "--input", type=str, default="/scratch/usr/nimanwai/data/livecell", + help="Path where the dataset already exists/will be downloaded by the dataloader" + ) + + parser.add_argument( + "-s", "--save_root", type=str, default="/scratch/usr/nimanwai/models/unetr/", + help="Path where checkpoints and logs will be saved" + ) + + parser.add_argument( + "--save_dir", type=str, default="/scratch/usr/nimanwai/predictions/unetr", + help="Path to save predictions from UNETR model" + ) + + parser.add_argument( + "--with_affinities", action="store_true", + help="Trains the UNETR model with affinities" + ) + + parser.add_argument("--iterations", type=int, default=100000) + return parser + + +def get_loss_function(with_affinities=True): + if with_affinities: + loss = LossWrapper( + loss=DiceLoss(), + transform=ApplyAndRemoveMask(masking_method="multiply") + ) + else: + loss = DiceLoss() + + return loss diff --git a/experiments/vision-transformer/unetr/livecell_unetr.py b/experiments/vision-transformer/unetr/livecell_unetr.py index 0ff9cde2..a7f6ba79 100644 --- a/experiments/vision-transformer/unetr/livecell_unetr.py +++ b/experiments/vision-transformer/unetr/livecell_unetr.py @@ -1,10 +1,9 @@ import os -import argparse import numpy as np import pandas as pd from glob import glob from tqdm import tqdm -from typing import Tuple, List +from typing import List import imageio.v2 as imageio from skimage.segmentation import find_boundaries @@ -14,81 +13,22 @@ import torch_em from torch_em.util import segmentation from torch_em.transform.raw import standardize -from torch_em.data.datasets import get_livecell_loader from torch_em.util.prediction import predict_with_halo - -def get_unetr_model( - model_name: str, - source_choice: str, - patch_shape: Tuple[int, int], - sam_initialization: bool, - output_channels: int -): - """Returns the expected UNETR model - """ - if source_choice == "torch-em": - from torch_em import model as torch_em_models - model = torch_em_models.UNETR( - encoder=model_name, out_channels=output_channels, - encoder_checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth" if sam_initialization else None - ) - - elif source_choice == "monai": - from monai.networks import nets as monai_models - model = monai_models.unetr.UNETR( - in_channels=1, - out_channels=output_channels, - img_size=patch_shape, - spatial_dims=2 - ) - model.out_channels = 2 # type: ignore - - else: - raise ValueError(f"The available UNETR models are either from \"torch-em\" or \"monai\", choose from them instead of - {source_choice}") - - return model +import common def do_unetr_training( - input_path: str, + train_loader, + val_loader, 
model, cell_types: List[str], - patch_shape: Tuple[int, int], device: torch.device, - save_root: str, iterations: int, - sam_initialization: bool, - source_choice: str + loss, + save_root: str ): print("Run training for cell types:", cell_types) - train_loader = get_livecell_loader( - path=input_path, - split="train", - patch_shape=patch_shape, - batch_size=2, - cell_types=[cell_types], - download=True, - boundaries=True, - num_workers=8 - ) - - val_loader = get_livecell_loader( - path=input_path, - split="val", - patch_shape=patch_shape, - batch_size=1, - cell_types=[cell_types], - download=True, - boundaries=True, - num_workers=8 - ) - - _save_root = os.path.join( - save_root, - f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch" - ) if save_root is not None else save_root - trainer = torch_em.default_segmentation_trainer( name=f"livecell-{cell_types}", model=model, @@ -99,9 +39,9 @@ def do_unetr_training( mixed_precision=True, log_image_interval=50, compile_model=False, - save_root=_save_root + save_root=save_root, + loss=loss ) - trainer.fit(iterations) @@ -111,16 +51,12 @@ def do_unetr_inference( model, cell_types: List[str], root_save_dir: str, - sam_initialization: bool, save_root: str, - source_choice: str ): for ctype in cell_types: test_img_dir = os.path.join(input_path, "images", "livecell_test_images", "*") - model_ckpt = os.path.join(save_root, - f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch", - "checkpoints", f"livecell-{ctype}", "best.pt") + model_ckpt = os.path.join(save_root, "checkpoints", f"livecell-{ctype}", "best.pt") assert os.path.exists(model_ckpt) model.load_state_dict(torch.load(model_ckpt, map_location=torch.device('cpu'))["model_state"]) @@ -175,7 +111,8 @@ def do_unetr_evaluation( continue fg_set, bd_set = {"CELL TYPE": c1}, {"CELL TYPE": c1} - ws1_msa_set, ws2_msa_set, ws1_sa50_set, ws2_sa50_set = {"CELL TYPE": c1}, {"CELL TYPE": c1}, {"CELL TYPE": c1}, {"CELL TYPE": c1} + (ws1_msa_set, ws2_msa_set, + ws1_sa50_set, ws2_sa50_set) = {"CELL TYPE": c1}, {"CELL TYPE": c1}, {"CELL TYPE": c1}, {"CELL TYPE": c1} for c2 in tqdm(cell_types, desc=f"Evaluation on {c1} source models from {_save_dir}"): fg_dir = os.path.join(_save_dir, "foreground") bd_dir = os.path.join(_save_dir, "boundary") @@ -250,98 +187,63 @@ def do_unetr_evaluation( def main(args): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - patch_shape = (512, 512) - output_channels = 2 - - all_cell_types = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"] - - model = get_unetr_model( - model_name=args.model_name, - source_choice=args.source_choice, - patch_shape=patch_shape, - sam_initialization=args.do_sam_ini, - output_channels=output_channels + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # overwrite to use complex device setups + patch_shape = (512, 512) # patch size used for training on livecell + + # get the desired loss function for training + loss = common.get_loss_function( + with_affinities=args.with_affinities # takes care of calling the loss for training with affinities + ) + + # get the desired livecell loaders for training + train_loader, val_loader, output_channels = common.get_my_livecell_loaders( + args.input, patch_shape, args.cell_type, + with_affinities=args.with_affinities # this takes care of getting the loaders with affinities + ) + + # get the model for the training and inference on livecell dataset + model = common.get_unetr_model( + model_name=args.model_name, 
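As context for the `save_root` and `root_save_dir` joins a few lines below and the checkpoint lookup in `do_unetr_inference`: the target representation, the model source, and the initialization are all encoded in the directory layout, while the trainer name supplies the per-cell-type leaf. A quick illustration using the default CLI paths (values hypothetical):

```python
import os

save_root = os.path.join("/scratch/usr/nimanwai/models/unetr/", "affinities", "torch-em-sam")
checkpoint = os.path.join(save_root, "checkpoints", "livecell-A172", "best.pt")
print(checkpoint)
# /scratch/usr/nimanwai/models/unetr/affinities/torch-em-sam/checkpoints/livecell-A172/best.pt
```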
source_choice=args.source_choice, patch_shape=patch_shape, + sam_initialization=args.do_sam_ini, output_channels=output_channels ) model.to(device) + # determining where to save the checkpoints and tensorboard logs + save_root = os.path.join( + args.save_root, + "affinities" if args.with_affinities else "boundaries", + f"{args.source_choice}-sam" if args.do_sam_ini else f"{args.source_choice}-scratch" + ) if args.save_root is not None else args.save_root + if args.train: - print("2d UNETR training on LiveCell dataset") + print("2d UNETR training on LIVECell dataset") do_unetr_training( - input_path=args.input, - model=model, - cell_types=args.cell_type, - patch_shape=patch_shape, - device=device, - save_root=args.save_root, - iterations=args.iterations, - sam_initialization=args.do_sam_ini, - source_choice=args.source_choice + train_loader=train_loader, val_loader=val_loader, model=model, cell_types=args.cell_type, + device=device, save_root=save_root, iterations=args.iterations, loss=loss ) + # determines the directory where the predictions will be saved root_save_dir = os.path.join( - args.save_dir, - f"unetr-{args.source_choice}-sam" if args.do_sam_ini else f"unetr-{args.source_choice}-scratch" + args.save_dir, f"unetr-{args.source_choice}-sam" if args.do_sam_ini else f"unetr-{args.source_choice}-scratch" ) print("Predictions are saved in", root_save_dir) if args.predict: - print("2d UNETR inference on LiveCell dataset") + print("2d UNETR inference on LIVECell dataset") do_unetr_inference( - input_path=args.input, - device=device, - model=model, - cell_types=all_cell_types, - root_save_dir=root_save_dir, - sam_initialization=args.do_sam_ini, - save_root=args.save_root, - source_choice=args.source_choice + input_path=args.input, device=device, model=model, cell_types=common.CELL_TYPES, + root_save_dir=root_save_dir, save_root=save_root ) if args.evaluate: - print("2d UNETR evaluation on LiveCell dataset") + print("2d UNETR evaluation on LIVECell dataset") do_unetr_evaluation( - input_path=args.input, - cell_types=all_cell_types, - root_save_dir=root_save_dir, - sam_initialization=args.do_sam_ini, - source_choice=args.source_choice + input_path=args.input, cell_types=common.CELL_TYPES, root_save_dir=root_save_dir, + sam_initialization=args.do_sam_ini, source_choice=args.source_choice ) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--train", action='store_true', - help="Enables UNETR training on LiveCell dataset") - - parser.add_argument("--predict", action='store_true', - help="Enables UNETR prediction on LiveCell dataset") - - parser.add_argument("--evaluate", action='store_true', - help="Enables UNETR evaluation on LiveCell dataset") - - parser.add_argument("--source_choice", type=str, default="torch-em", - help="The source where the model comes from, i.e. 
either torch-em / monai") - - parser.add_argument("-m", "--model_name", type=str, default="vit_b", - help="Name of the ViT to use as the encoder in UNETR") - - parser.add_argument("--do_sam_ini", action='store_true', - help="Enables initializing UNETR with SAM's ViT weights") - - parser.add_argument("-c", "--cell_type", type=str, default=None, - help="Choice of cell-type for doing the training") - - parser.add_argument("-i", "--input", type=str, default="/scratch/usr/nimanwai/data/livecell", - help="Path where the dataset already exists/will be downloaded by the dataloader") - - parser.add_argument("-s", "--save_root", type=str, default="/scratch/usr/nimanwai/models/unetr/", - help="Path where checkpoints and logs will be saved") - - parser.add_argument("--save_dir", type=str, default="/scratch/usr/nimanwai/predictions/unetr", - help="Path to save predictions from UNETR model") - - parser.add_argument("--iterations", type=int, default=100000) - + parser = common.get_parser() args = parser.parse_args() main(args) From 1cda92648142a0cea847ca99cb673f95beb9dd5f Mon Sep 17 00:00:00 2001 From: anwai98 Date: Thu, 2 Nov 2023 21:32:26 +0100 Subject: [PATCH 2/9] Update metric in training --- experiments/vision-transformer/unetr/livecell_unetr.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/experiments/vision-transformer/unetr/livecell_unetr.py b/experiments/vision-transformer/unetr/livecell_unetr.py index a7f6ba79..3757a0c9 100644 --- a/experiments/vision-transformer/unetr/livecell_unetr.py +++ b/experiments/vision-transformer/unetr/livecell_unetr.py @@ -40,7 +40,8 @@ def do_unetr_training( log_image_interval=50, compile_model=False, save_root=save_root, - loss=loss + loss=loss, + metric=loss ) trainer.fit(iterations) From 625443f62f5472111543494689d89daae9fda088 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Fri, 3 Nov 2023 16:54:19 +0100 Subject: [PATCH 3/9] Refactor inference for livecell unetr --- .../unetr/livecell_unetr.py | 39 +++++++++---------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/experiments/vision-transformer/unetr/livecell_unetr.py b/experiments/vision-transformer/unetr/livecell_unetr.py index 3757a0c9..b89703ab 100644 --- a/experiments/vision-transformer/unetr/livecell_unetr.py +++ b/experiments/vision-transformer/unetr/livecell_unetr.py @@ -53,6 +53,7 @@ def do_unetr_inference( cell_types: List[str], root_save_dir: str, save_root: str, + with_affinities: bool ): for ctype in cell_types: test_img_dir = os.path.join(input_path, "images", "livecell_test_images", "*") @@ -64,15 +65,11 @@ def do_unetr_inference( model.to(device) model.eval() - fg_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "foreground") - bd_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "boundary") - ws1_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "watershed1") - ws2_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "watershed2") - - os.makedirs(fg_save_dir, exist_ok=True) - os.makedirs(bd_save_dir, exist_ok=True) - os.makedirs(ws1_save_dir, exist_ok=True) - os.makedirs(ws2_save_dir, exist_ok=True) + # creating the respective directories for saving the outputs + _settings = ["foreground", "boundary", "watershed1", "watershed2"] + for _setting in _settings: + tmp_save_dir = os.path.join(root_save_dir, f"src-{ctype}", _setting) + os.makedirs(tmp_save_dir, exist_ok=True) with torch.no_grad(): for img_path in tqdm(glob(test_img_dir), desc=f"Run inference for all livecell with model {model_ckpt}"): @@ -80,19 +77,21 @@ def do_unetr_inference( 
input_img = imageio.imread(img_path) input_img = standardize(input_img) - outputs = predict_with_halo( - input_img, model, gpu_ids=[device], block_shape=[384, 384], halo=[64, 64], disable_tqdm=True - ) - fg, bd = outputs[0, :, :], outputs[1, :, :] + if with_affinities: + raise NotImplementedError("This still needs to be implemented for affinity-based training") - ws1 = segmentation.watershed_from_components(bd, fg, min_size=10) - ws2 = segmentation.watershed_from_maxima(bd, fg, min_size=10, min_distance=1) + else: # inference using foreground-boundary inputs - for the unetr training + outputs = predict_with_halo( + input_img, model, gpu_ids=[device], block_shape=[384, 384], halo=[64, 64], disable_tqdm=True + ) + fg, bd = outputs[0, :, :], outputs[1, :, :] + ws1 = segmentation.watershed_from_components(bd, fg, min_size=10) + ws2 = segmentation.watershed_from_maxima(bd, fg, min_size=10, min_distance=1) - imageio.imwrite(os.path.join(fg_save_dir, fname), fg) - imageio.imwrite(os.path.join(bd_save_dir, fname), bd) - imageio.imwrite(os.path.join(ws1_save_dir, fname), ws1) - imageio.imwrite(os.path.join(ws2_save_dir, fname), ws2) + _save_outputs = [fg, bd, ws1, ws2] + for _setting, _output in zip(_settings, _save_outputs): + imageio.imwrite(os.path.join(root_save_dir, f"src-{ctype}", _setting, fname), _output) def do_unetr_evaluation( @@ -233,7 +232,7 @@ def main(args): print("2d UNETR inference on LIVECell dataset") do_unetr_inference( input_path=args.input, device=device, model=model, cell_types=common.CELL_TYPES, - root_save_dir=root_save_dir, save_root=save_root + root_save_dir=root_save_dir, save_root=save_root, with_affinities=args.with_affinities ) if args.evaluate: From e4d77b3084548b62049ba9cb1ded8f4038f7ad03 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Sat, 4 Nov 2023 12:31:15 +0100 Subject: [PATCH 4/9] Update mutex watershed segmentation functionality --- torch_em/util/segmentation.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/torch_em/util/segmentation.py b/torch_em/util/segmentation.py index d1c27441..b18b8635 100644 --- a/torch_em/util/segmentation.py +++ b/torch_em/util/segmentation.py @@ -34,10 +34,21 @@ def size_filter(seg, min_size, hmap=None, with_background=False): return seg -def mutex_watershed(affinities, offsets, mask=None, strides=None): - return elseg.mutex_watershed( - affinities, offsets, mask=mask, strides=strides, randomize_strides=True - ).astype("uint64") +def mutex_watershed_segmentation(foreground, affinities, offsets, min_size, threshold=0.5): + """Computes the mutex watershed segmentation using the affinity maps for respective pixel offsets + + Arguments: + - foreground: [np.ndarray] - The foreground background channel for the objects + - affinities [np.ndarray] - The input affinity maps + - offsets: [list[list[int]]] - The pixel offsets corresponding to the affinity channels + - min_size: [int] - The minimum pixels (below which) to filter objects + - threshold: [float] - To threshold foreground predictions + """ + mask = (foreground >= threshold) + strides = [2] * foreground.ndim + segmentation = elseg.mutex_watershed(affinities, offets=offsets, mask=mask, strides=strides, randomize_strides=True) + segmentation = size_filter(segmentation.astype("uint32"), min_size=min_size, hmap=affinities, with_background=True) + return segmentation def connected_components_with_boundaries(foreground, boundaries, threshold=0.5): From 5b4a3ae1d887746ebf32425fc97dd15fa2e971eb Mon Sep 17 00:00:00 2001 From: anwai98 Date: Sat, 4 Nov 
2023 15:25:23 +0100 Subject: [PATCH 5/9] Update inference with affinities for unetr --- .../vision-transformer/unetr/common.py | 60 ++++++++++++++++++- .../unetr/livecell_unetr.py | 54 +++++------------ torch_em/util/segmentation.py | 7 ++- 3 files changed, 75 insertions(+), 46 deletions(-) diff --git a/experiments/vision-transformer/unetr/common.py b/experiments/vision-transformer/unetr/common.py index fe3dbaf5..df24e0bb 100644 --- a/experiments/vision-transformer/unetr/common.py +++ b/experiments/vision-transformer/unetr/common.py @@ -1,8 +1,18 @@ +import os +import h5py import argparse +import numpy as np from typing import Tuple +import imageio.v3 as imageio + +from torch_em.util import segmentation +from torch_em.transform.raw import standardize from torch_em.data.datasets import get_livecell_loader from torch_em.loss import DiceLoss, LossWrapper, ApplyAndRemoveMask +from torch_em.util.prediction import predict_with_halo, predict_with_padding + +import common OFFSETS = [ @@ -20,6 +30,15 @@ # LIVECELL DATALOADERS # + +def _get_output_channels(with_affinities): + if with_affinities: + n_out = len(OFFSETS) + 1 + else: + n_out = 2 + return n_out + + def get_my_livecell_loaders( input_path: str, patch_shape: Tuple[int, int], @@ -30,7 +49,6 @@ def get_my_livecell_loaders( """ if with_affinities: # this returns dataloaders with affinity channels and foreground-background channels - n_out = len(OFFSETS) + 1 train_loader = get_livecell_loader( path=input_path, split="train", patch_shape=patch_shape, batch_size=2, cell_types=[cell_types], download=True, offsets=OFFSETS, num_workers=16 @@ -42,7 +60,6 @@ def get_my_livecell_loaders( else: # this returns dataloaders with foreground and boundary channels - n_out = 2 train_loader = get_livecell_loader( path=input_path, split="train", patch_shape=patch_shape, batch_size=2, cell_types=[cell_types], download=True, boundaries=True, num_workers=16 @@ -52,7 +69,7 @@ def get_my_livecell_loaders( cell_types=[cell_types], download=True, boundaries=True, num_workers=16 ) - return train_loader, val_loader, n_out + return train_loader, val_loader # @@ -94,6 +111,43 @@ def get_unetr_model( return model +# +# LIVECELL UNETR INFERENCE - foreground boundary / foreground affinities +# + +def predict_for_unetr(img_path, model, root_save_dir, ctype, device, with_affinities): + input_ = imageio.imread(img_path) + input_ = standardize(input_) + + if with_affinities: # inference using affinities + outputs = predict_with_padding(model, input_, device=device, min_divisible=(16, 16)) + fg, affs = np.array(outputs[0, 0]), np.array(outputs[0, 1:]) + mws = segmentation.mutex_watershed_segmentation(fg, affs, common.OFFSETS, 100) + + else: # inference using foreground-boundary inputs - for the unetr training + outputs = predict_with_halo(input_, model, [device], block_shape=[384, 384], halo=[64, 64], disable_tqdm=True) + fg, bd = outputs[0, :, :], outputs[1, :, :] + ws1 = segmentation.watershed_from_components(bd, fg, min_size=10) + ws2 = segmentation.watershed_from_maxima(bd, fg, min_size=10, min_distance=1) + + fname = os.path.split(img_path)[-1] + with h5py.File(os.path.join(root_save_dir, f"src-{ctype}", f"{fname}.h5"), "a") as f: + ds = f.require_dataset("foreground", shape=fg.shape, compression="gzip", dtype=fg.dtype) + ds[:] = fg + if with_affinities: + ds = f.require_dataset("affinities", shape=affs.shape, compression="gzip", dtype=affs.dtype) + ds[:] = affs + ds = f.require_dataset("segmentation", shape=mws.shape, compression="gzip", dtype=mws.dtype) + ds[:] = mws + 
else: + ds = f.require_dataset("boundary", shape=bd.shape, compression="gzip", dtype=bd.dtype) + ds[:] = bd + ds = f.require_dataset("watershed1", shape=ws1.shape, compression="gzip", dtype=ws1.dtype) + ds[:] = ws1 + ds = f.require_dataset("watershed2", shape=ws2.shape, compression="gzip", dtype=ws2.dtype) + ds[:] = ws2 + + # # miscellanous utilities # diff --git a/experiments/vision-transformer/unetr/livecell_unetr.py b/experiments/vision-transformer/unetr/livecell_unetr.py index b89703ab..4040d4c3 100644 --- a/experiments/vision-transformer/unetr/livecell_unetr.py +++ b/experiments/vision-transformer/unetr/livecell_unetr.py @@ -5,15 +5,12 @@ from tqdm import tqdm from typing import List -import imageio.v2 as imageio +import imageio.v3 as imageio from skimage.segmentation import find_boundaries from elf.evaluation import dice_score, mean_segmentation_accuracy import torch import torch_em -from torch_em.util import segmentation -from torch_em.transform.raw import standardize -from torch_em.util.prediction import predict_with_halo import common @@ -55,43 +52,21 @@ def do_unetr_inference( save_root: str, with_affinities: bool ): + test_img_dir = os.path.join(input_path, "images", "livecell_test_images", "*") for ctype in cell_types: - test_img_dir = os.path.join(input_path, "images", "livecell_test_images", "*") - model_ckpt = os.path.join(save_root, "checkpoints", f"livecell-{ctype}", "best.pt") - assert os.path.exists(model_ckpt) + assert os.path.exists(model_ckpt), model_ckpt model.load_state_dict(torch.load(model_ckpt, map_location=torch.device('cpu'))["model_state"]) model.to(device) model.eval() # creating the respective directories for saving the outputs - _settings = ["foreground", "boundary", "watershed1", "watershed2"] - for _setting in _settings: - tmp_save_dir = os.path.join(root_save_dir, f"src-{ctype}", _setting) - os.makedirs(tmp_save_dir, exist_ok=True) + os.makedirs(os.path.join(root_save_dir, f"src-{ctype}"), exist_ok=True) with torch.no_grad(): for img_path in tqdm(glob(test_img_dir), desc=f"Run inference for all livecell with model {model_ckpt}"): - fname = os.path.split(img_path)[-1] - - input_img = imageio.imread(img_path) - input_img = standardize(input_img) - - if with_affinities: - raise NotImplementedError("This still needs to be implemented for affinity-based training") - - else: # inference using foreground-boundary inputs - for the unetr training - outputs = predict_with_halo( - input_img, model, gpu_ids=[device], block_shape=[384, 384], halo=[64, 64], disable_tqdm=True - ) - fg, bd = outputs[0, :, :], outputs[1, :, :] - ws1 = segmentation.watershed_from_components(bd, fg, min_size=10) - ws2 = segmentation.watershed_from_maxima(bd, fg, min_size=10, min_distance=1) - - _save_outputs = [fg, bd, ws1, ws2] - for _setting, _output in zip(_settings, _save_outputs): - imageio.imwrite(os.path.join(root_save_dir, f"src-{ctype}", _setting, fname), _output) + common.predict_for_unetr(img_path, model, root_save_dir, ctype, device, with_affinities) def do_unetr_evaluation( @@ -195,28 +170,26 @@ def main(args): with_affinities=args.with_affinities # takes care of calling the loss for training with affinities ) - # get the desired livecell loaders for training - train_loader, val_loader, output_channels = common.get_my_livecell_loaders( - args.input, patch_shape, args.cell_type, - with_affinities=args.with_affinities # this takes care of getting the loaders with affinities - ) - # get the model for the training and inference on livecell dataset model = common.get_unetr_model( 
model_name=args.model_name, source_choice=args.source_choice, patch_shape=patch_shape, - sam_initialization=args.do_sam_ini, output_channels=output_channels + sam_initialization=args.do_sam_ini, output_channels=common._get_output_channels(args.with_affinities) ) model.to(device) # determining where to save the checkpoints and tensorboard logs save_root = os.path.join( - args.save_root, - "affinities" if args.with_affinities else "boundaries", + args.save_root, "affinities" if args.with_affinities else "boundaries", f"{args.source_choice}-sam" if args.do_sam_ini else f"{args.source_choice}-scratch" ) if args.save_root is not None else args.save_root if args.train: print("2d UNETR training on LIVECell dataset") + # get the desired livecell loaders for training + train_loader, val_loader = common.get_my_livecell_loaders( + args.input, patch_shape, args.cell_type, + with_affinities=args.with_affinities # this takes care of getting the loaders with affinities + ) do_unetr_training( train_loader=train_loader, val_loader=val_loader, model=model, cell_types=args.cell_type, device=device, save_root=save_root, iterations=args.iterations, loss=loss @@ -224,7 +197,8 @@ def main(args): # determines the directory where the predictions will be saved root_save_dir = os.path.join( - args.save_dir, f"unetr-{args.source_choice}-sam" if args.do_sam_ini else f"unetr-{args.source_choice}-scratch" + args.save_dir, "affinities" if args.with_affinities else "boundaries", + f"{args.source_choice}-sam" if args.do_sam_ini else f"{args.source_choice}-scratch" ) print("Predictions are saved in", root_save_dir) diff --git a/torch_em/util/segmentation.py b/torch_em/util/segmentation.py index b18b8635..8767589f 100644 --- a/torch_em/util/segmentation.py +++ b/torch_em/util/segmentation.py @@ -3,6 +3,7 @@ import vigra import elf.segmentation as elseg from elf.segmentation.utils import normalize_input +from elf.segmentation.mutex_watershed import mutex_watershed from skimage.measure import label from skimage.filters import gaussian @@ -46,9 +47,9 @@ def mutex_watershed_segmentation(foreground, affinities, offsets, min_size, thre """ mask = (foreground >= threshold) strides = [2] * foreground.ndim - segmentation = elseg.mutex_watershed(affinities, offets=offsets, mask=mask, strides=strides, randomize_strides=True) - segmentation = size_filter(segmentation.astype("uint32"), min_size=min_size, hmap=affinities, with_background=True) - return segmentation + seg = mutex_watershed(affinities, offsets=offsets, mask=mask, strides=strides, randomize_strides=True) + seg = size_filter(seg.astype("uint32"), min_size=min_size, hmap=affinities, with_background=True) + return seg def connected_components_with_boundaries(foreground, boundaries, threshold=0.5): From 249b55e4c62f9c44ad3910056b5be85a589a4586 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Mon, 6 Nov 2023 15:54:32 +0100 Subject: [PATCH 6/9] Update MWS Evaluation --- .../vision-transformer/unetr/common.py | 49 +++++- .../unetr/livecell_unetr.py | 144 +++++++++--------- .../unetr/submit_training.py | 12 +- 3 files changed, 125 insertions(+), 80 deletions(-) diff --git a/experiments/vision-transformer/unetr/common.py b/experiments/vision-transformer/unetr/common.py index df24e0bb..50640390 100644 --- a/experiments/vision-transformer/unetr/common.py +++ b/experiments/vision-transformer/unetr/common.py @@ -3,8 +3,11 @@ import argparse import numpy as np from typing import Tuple +from pathlib import Path import imageio.v3 as imageio +from skimage.segmentation import find_boundaries 
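A small usage sketch for the `mutex_watershed_segmentation` helper as fixed in the previous patch (toy random inputs, just to pin down the expected shapes and argument order; the foreground scores gate which pixels are segmented, and the affinity channels must match the offsets in number and order):

```python
import numpy as np
from torch_em.util import segmentation

OFFSETS = [[-1, 0], [0, -1], [-3, 0], [0, -3], [-9, 0], [0, -9], [-27, 0], [0, -27]]

fg = np.random.rand(512, 512).astype("float32")                   # foreground probabilities
affs = np.random.rand(len(OFFSETS), 512, 512).astype("float32")   # one channel per offset
seg = segmentation.mutex_watershed_segmentation(fg, affs, OFFSETS, min_size=250)
print(seg.shape, seg.dtype)  # (512, 512), instance labels with small objects filtered out
```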
+from elf.evaluation import dice_score, mean_segmentation_accuracy
 
 from torch_em.util import segmentation
 from torch_em.transform.raw import standardize
@@ -130,7 +133,7 @@ def predict_for_unetr(img_path, model, root_save_dir, ctype, device, with_affini
     ws1 = segmentation.watershed_from_components(bd, fg, min_size=10)
     ws2 = segmentation.watershed_from_maxima(bd, fg, min_size=10, min_distance=1)
 
-    fname = os.path.split(img_path)[-1]
+    fname = Path(img_path).stem
     with h5py.File(os.path.join(root_save_dir, f"src-{ctype}", f"{fname}.h5"), "a") as f:
         ds = f.require_dataset("foreground", shape=fg.shape, compression="gzip", dtype=fg.dtype)
         ds[:] = fg
@@ -148,6 +151,50 @@ def predict_for_unetr(img_path, model, root_save_dir, ctype, device, with_affini
         ds[:] = ws2
 
 
+#
+# LIVECELL UNETR EVALUATION - foreground boundary / foreground affinities
+#
+
+def evaluate_for_unetr(gt_path, _save_dir, with_affinities):
+    # FIXME: fname = Path(img_path).stem
+    fname = os.path.split(gt_path)[-1]
+    gt = imageio.imread(gt_path)
+
+    output_file = os.path.join(_save_dir, f"{fname}.h5")
+    with h5py.File(output_file, "r") as f:
+        if with_affinities:
+            mws = f["segmentation"][:]
+        else:
+            fg = f["foreground"][:]
+            bd = f["boundary"][:]
+            ws1 = f["watershed1"][:]
+            ws2 = f["watershed2"][:]
+
+    if with_affinities:
+        mws_msa, mws_sa_acc = mean_segmentation_accuracy(mws, gt, return_accuracies=True)
+        return mws_msa, mws_sa_acc[0]
+
+    else:
+        true_bd = find_boundaries(gt)
+
+        # Compare the foreground prediction to the ground-truth.
+        # Here, it's important not to threshold the segmentation. Otherwise EVERYTHING will be set to
+        # foreground in the dice function, since we have a comparison > 0 in there, and everything in the
+        # binary prediction evaluates to true.
+        # For the GT we can set the threshold to 0, because this will map to the correct binary mask.
+        fg_dice = dice_score(fg, gt, threshold_gt=0, threshold_seg=None)
+
+        # Compare the boundary prediction to the ground-truth boundaries.
+        # Here, we don't need any thresholds: for the prediction the same holds as before.
+        # For the ground-truth we have already a binary label, so we don't need to threshold it again.
+ bd_dice = dice_score(bd, true_bd, threshold_gt=None, threshold_seg=None) + + msa1, sa_acc1 = mean_segmentation_accuracy(ws1, gt, return_accuracies=True) # type: ignore + msa2, sa_acc2 = mean_segmentation_accuracy(ws2, gt, return_accuracies=True) # type: ignore + + return fg_dice, bd_dice, msa1, sa_acc1, msa2, sa_acc2 + + # # miscellanous utilities # diff --git a/experiments/vision-transformer/unetr/livecell_unetr.py b/experiments/vision-transformer/unetr/livecell_unetr.py index 4040d4c3..32fabdb6 100644 --- a/experiments/vision-transformer/unetr/livecell_unetr.py +++ b/experiments/vision-transformer/unetr/livecell_unetr.py @@ -5,10 +5,6 @@ from tqdm import tqdm from typing import List -import imageio.v3 as imageio -from skimage.segmentation import find_boundaries -from elf.evaluation import dice_score, mean_segmentation_accuracy - import torch import torch_em @@ -73,92 +69,88 @@ def do_unetr_evaluation( input_path: str, cell_types: List[str], root_save_dir: str, - sam_initialization: bool, - source_choice: str + csv_save_dir: str, + with_affinities: bool ): + # list for foreground-boundary evaluations fg_list, bd_list = [], [] ws1_msa_list, ws2_msa_list, ws1_sa50_list, ws2_sa50_list = [], [], [], [] + # lists for affinities evaluation + mws_msa_list, mws_sa50_list = [], [] + for c1 in cell_types: + # we check whether we have predictions from a particular cell-type _save_dir = os.path.join(root_save_dir, f"src-{c1}") if not os.path.exists(_save_dir): print("Skipping", _save_dir) continue + # dict for foreground-boundary evaluations fg_set, bd_set = {"CELL TYPE": c1}, {"CELL TYPE": c1} (ws1_msa_set, ws2_msa_set, ws1_sa50_set, ws2_sa50_set) = {"CELL TYPE": c1}, {"CELL TYPE": c1}, {"CELL TYPE": c1}, {"CELL TYPE": c1} - for c2 in tqdm(cell_types, desc=f"Evaluation on {c1} source models from {_save_dir}"): - fg_dir = os.path.join(_save_dir, "foreground") - bd_dir = os.path.join(_save_dir, "boundary") - ws1_dir = os.path.join(_save_dir, "watershed1") - ws2_dir = os.path.join(_save_dir, "watershed2") + # dict for affinities evaluation + mws_msa_set, mws_sa50_set = {"CELL TYPE": c1}, {"CELL TYPE": c1} + + for c2 in tqdm(cell_types, desc=f"Evaluation on {c1} source models from {_save_dir}"): gt_dir = os.path.join(input_path, "annotations", "livecell_test_images", c2, "*") + + # cell-wise evaluation list for foreground-boundary evaluations cwise_fg, cwise_bd = [], [] cwise_ws1_msa, cwise_ws2_msa, cwise_ws1_sa50, cwise_ws2_sa50 = [], [], [], [] + + # cell-wise evaluation list for affinities evaluation + cwise_mws_msa, cwise_mws_sa50 = [], [] + for gt_path in glob(gt_dir): - fname = os.path.split(gt_path)[-1] - - gt = imageio.imread(gt_path) - fg = imageio.imread(os.path.join(fg_dir, fname)) - bd = imageio.imread(os.path.join(bd_dir, fname)) - ws1 = imageio.imread(os.path.join(ws1_dir, fname)) - ws2 = imageio.imread(os.path.join(ws2_dir, fname)) - - true_bd = find_boundaries(gt) - - # Compare the foreground prediction to the ground-truth. - # Here, it's important not to threshold the segmentation. Otherwise EVERYTHING will be set to - # foreground in the dice function, since we have a comparision > 0 in there, and everything in the - # binary prediction evaluates to true. - # For the GT we can set the threshold to 0, because this will map to the correct binary mask. - cwise_fg.append(dice_score(fg, gt, threshold_gt=0, threshold_seg=None)) - - # Compare the background prediction to the ground-truth. - # Here, we don't need any thresholds: for the prediction the same holds as before. 
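Two evaluation details are worth pinning down here. First, the threshold semantics that the comments above describe (a hedged sketch of how elf's `dice_score` treats its threshold arguments, not its exact implementation):

```python
import numpy as np

def dice(seg, gt, threshold_seg=None, threshold_gt=None):
    # None keeps the soft values; a number binarizes via `> threshold`. Passing
    # threshold_seg=0 on a probability map would mark almost every pixel as
    # foreground, which is exactly the pitfall the comments above warn about.
    if threshold_seg is not None:
        seg = seg > threshold_seg
    if threshold_gt is not None:
        gt = gt > threshold_gt
    seg, gt = seg.astype("float32"), gt.astype("float32")
    return 2.0 * float((seg * gt).sum()) / (float(seg.sum()) + float(gt.sum()))
```

Second, `mean_segmentation_accuracy(..., return_accuracies=True)` returns the mean score together with the per-threshold accuracies; indexing with `[0]`, as done throughout this function, picks the accuracy at the lowest IoU threshold, 0.5, which is what the result tables report as SA50 (assuming elf's usual 0.5 to 0.95 threshold range).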
- # For the ground-truth we have already a binary label, so we don't need to threshold it again. - cwise_bd.append(dice_score(bd, true_bd, threshold_gt=None, threshold_seg=None)) - - msa1, sa_acc1 = mean_segmentation_accuracy(ws1, gt, return_accuracies=True) # type: ignore - msa2, sa_acc2 = mean_segmentation_accuracy(ws2, gt, return_accuracies=True) # type: ignore - - cwise_ws1_msa.append(msa1) - cwise_ws2_msa.append(msa2) - cwise_ws1_sa50.append(sa_acc1[0]) - cwise_ws2_sa50.append(sa_acc2[0]) - - fg_set[c2] = np.mean(cwise_fg) - bd_set[c2] = np.mean(cwise_bd) - ws1_msa_set[c2] = np.mean(cwise_ws1_msa) - ws2_msa_set[c2] = np.mean(cwise_ws2_msa) - ws1_sa50_set[c2] = np.mean(cwise_ws1_sa50) - ws2_sa50_set[c2] = np.mean(cwise_ws2_sa50) - - fg_list.append(pd.DataFrame.from_dict([fg_set])) # type: ignore - bd_list.append(pd.DataFrame.from_dict([bd_set])) # type: ignore - ws1_msa_list.append(pd.DataFrame.from_dict([ws1_msa_set])) # type: ignore - ws2_msa_list.append(pd.DataFrame.from_dict([ws2_msa_set])) # type: ignore - ws1_sa50_list.append(pd.DataFrame.from_dict([ws1_sa50_set])) # type: ignore - ws2_sa50_list.append(pd.DataFrame.from_dict([ws2_sa50_set])) # type: ignore - - f_df_fg = pd.concat(fg_list, ignore_index=True) - f_df_bd = pd.concat(bd_list, ignore_index=True) - f_df_ws1_msa = pd.concat(ws1_msa_list, ignore_index=True) - f_df_ws2_msa = pd.concat(ws2_msa_list, ignore_index=True) - f_df_ws1_sa50 = pd.concat(ws1_sa50_list, ignore_index=True) - f_df_ws2_sa50 = pd.concat(ws2_sa50_list, ignore_index=True) - - tmp_csv_name = f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch" - csv_save_dir = f"./results/{tmp_csv_name}" - os.makedirs(csv_save_dir, exist_ok=True) - - f_df_fg.to_csv(os.path.join(csv_save_dir, "foreground-dice.csv")) - f_df_bd.to_csv(os.path.join(csv_save_dir, "boundary-dice.csv")) - f_df_ws1_msa.to_csv(os.path.join(csv_save_dir, "watershed1-msa.csv")) - f_df_ws2_msa.to_csv(os.path.join(csv_save_dir, "watershed2-msa.csv")) - f_df_ws1_sa50.to_csv(os.path.join(csv_save_dir, "watershed1-sa50.csv")) - f_df_ws2_sa50.to_csv(os.path.join(csv_save_dir, "watershed2-sa50.csv")) + all_metrics = common.evaluate_for_unetr(gt_path, _save_dir, with_affinities) + if with_affinities: + mws_msa, mws_sa50 = all_metrics + cwise_mws_msa.append(mws_msa) + cwise_mws_sa50.append(mws_sa50) + else: + fg_dice, bd_dice, msa1, sa_acc1, msa2, sa_acc2 = all_metrics + cwise_fg.append(fg_dice) + cwise_bd.append(bd_dice) + cwise_ws1_msa.append(msa1) + cwise_ws2_msa.append(msa2) + cwise_ws1_sa50.append(sa_acc1[0]) + cwise_ws2_sa50.append(sa_acc2[0]) + + if with_affinities: + mws_msa_set[c2], mws_sa50_set[c2] = np.mean(cwise_mws_msa), np.mean(cwise_mws_sa50) + else: + fg_set[c2], bd_set[c2] = np.mean(cwise_fg), np.mean(cwise_bd) + ws1_msa_set[c2], ws2_msa_set[c2] = np.mean(cwise_ws1_msa), np.mean(cwise_ws2_msa) + ws1_sa50_set[c2], ws2_sa50_set[c2] = np.mean(cwise_ws1_sa50), np.mean(cwise_ws2_sa50) + + if with_affinities: + mws_msa_list.append(pd.DataFrame.from_dict([mws_msa_set])) + mws_sa50_list.append(pd.DataFrame.from_dict([mws_sa50_set])) + else: + fg_list.append(pd.DataFrame.from_dict([fg_set])) + bd_list.append(pd.DataFrame.from_dict([bd_set])) + ws1_msa_list.append(pd.DataFrame.from_dict([ws1_msa_set])) + ws2_msa_list.append(pd.DataFrame.from_dict([ws2_msa_set])) + ws1_sa50_list.append(pd.DataFrame.from_dict([ws1_sa50_set])) + ws2_sa50_list.append(pd.DataFrame.from_dict([ws2_sa50_set])) + + if with_affinities: + df_mws_msa, df_mws_sa50 = pd.concat(mws_msa_list, ignore_index=True), 
pd.concat(mws_sa50_list, ignore_index=True) + df_mws_msa.to_csv(os.path.join(csv_save_dir, "mws-affs-msa.csv")) + df_mws_sa50.to_csv(os.path.join(csv_save_dir, "mws-affs-sa50.csv")) + else: + df_fg, df_bd = pd.concat(fg_list, ignore_index=True), pd.concat(bd_list, ignore_index=True) + df_ws1_msa, df_ws2_msa = pd.concat(ws1_msa_list, ignore_index=True), pd.concat(ws2_msa_list, ignore_index=True) + df_ws1_sa50, df_ws2_sa50 = pd.concat(ws1_sa50_list, ignore_index=True), pd.concat(ws2_sa50_list, ignore_index=True) + df_fg.to_csv(os.path.join(csv_save_dir, "foreground-dice.csv")) + df_bd.to_csv(os.path.join(csv_save_dir, "boundary-dice.csv")) + df_ws1_msa.to_csv(os.path.join(csv_save_dir, "watershed1-msa.csv")) + df_ws2_msa.to_csv(os.path.join(csv_save_dir, "watershed2-msa.csv")) + df_ws1_sa50.to_csv(os.path.join(csv_save_dir, "watershed1-sa50.csv")) + df_ws2_sa50.to_csv(os.path.join(csv_save_dir, "watershed2-sa50.csv")) def main(args): @@ -200,7 +192,6 @@ def main(args): args.save_dir, "affinities" if args.with_affinities else "boundaries", f"{args.source_choice}-sam" if args.do_sam_ini else f"{args.source_choice}-scratch" ) - print("Predictions are saved in", root_save_dir) if args.predict: print("2d UNETR inference on LIVECell dataset") @@ -208,12 +199,17 @@ def main(args): input_path=args.input, device=device, model=model, cell_types=common.CELL_TYPES, root_save_dir=root_save_dir, save_root=save_root, with_affinities=args.with_affinities ) + print("Predictions are saved in", root_save_dir) if args.evaluate: print("2d UNETR evaluation on LIVECell dataset") + tmp_csv_name = f"{args.source_choice}-sam" if args.do_sam_ini else f"{args.source_choice}-scratch" + csv_save_dir = os.path.join("results", "affinities" if args.with_affinities else "boundaries", tmp_csv_name) + os.makedirs(csv_save_dir, exist_ok=True) + do_unetr_evaluation( input_path=args.input, cell_types=common.CELL_TYPES, root_save_dir=root_save_dir, - sam_initialization=args.do_sam_ini, source_choice=args.source_choice + csv_save_dir=csv_save_dir, with_affinities=args.with_affinities ) diff --git a/experiments/vision-transformer/unetr/submit_training.py b/experiments/vision-transformer/unetr/submit_training.py index b7042f8a..84ea6b56 100644 --- a/experiments/vision-transformer/unetr/submit_training.py +++ b/experiments/vision-transformer/unetr/submit_training.py @@ -7,7 +7,7 @@ from datetime import datetime -def write_batch_script(out_path, ini_sam=True, source_choice="torch-em"): +def write_batch_script(out_path, ini_sam=True, source_choice="torch-em", with_affinity=True): """ inputs: source_choice:str - [torch_em / monai] source of the unetr model coming from @@ -22,7 +22,7 @@ def write_batch_script(out_path, ini_sam=True, source_choice="torch-em"): #SBATCH --ntasks=1 #SBATCH -p grete:shared #SBATCH -G A100:1 -#SBATCH -c 8 +#SBATCH -c 16 #SBATCH -A gzz0001 """ if ini_sam: @@ -42,9 +42,11 @@ def write_batch_script(out_path, ini_sam=True, source_choice="torch-em"): add_source_choice = f"--source_choice {source_choice} " batch_script += add_ctype + add_source_choice - add_sam_ini = "--do_sam_ini " if ini_sam: - batch_script += add_sam_ini + batch_script += "--do_sam_ini " + + if with_affinity: + batch_script += "--with_affinities " _op = out_path[:-3] + f"_{i}.sh" @@ -58,7 +60,7 @@ def submit_slurm(): tmp_folder = os.path.expanduser("./gpu_jobs") os.makedirs(tmp_folder, exist_ok=True) - script_name = "unetr-monai" + script_name = "unetr_" dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") tmp_name = script_name + dt From 
da1edb5469ace7bb28009298c539088418a5392f Mon Sep 17 00:00:00 2001 From: anwai98 Date: Mon, 6 Nov 2023 23:07:25 +0100 Subject: [PATCH 7/9] Update min_size for watershed --- experiments/vision-transformer/unetr/common.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/experiments/vision-transformer/unetr/common.py b/experiments/vision-transformer/unetr/common.py index 50640390..31af2a4f 100644 --- a/experiments/vision-transformer/unetr/common.py +++ b/experiments/vision-transformer/unetr/common.py @@ -125,13 +125,13 @@ def predict_for_unetr(img_path, model, root_save_dir, ctype, device, with_affini if with_affinities: # inference using affinities outputs = predict_with_padding(model, input_, device=device, min_divisible=(16, 16)) fg, affs = np.array(outputs[0, 0]), np.array(outputs[0, 1:]) - mws = segmentation.mutex_watershed_segmentation(fg, affs, common.OFFSETS, 100) + mws = segmentation.mutex_watershed_segmentation(fg, affs, common.OFFSETS, 250) else: # inference using foreground-boundary inputs - for the unetr training outputs = predict_with_halo(input_, model, [device], block_shape=[384, 384], halo=[64, 64], disable_tqdm=True) fg, bd = outputs[0, :, :], outputs[1, :, :] - ws1 = segmentation.watershed_from_components(bd, fg, min_size=10) - ws2 = segmentation.watershed_from_maxima(bd, fg, min_size=10, min_distance=1) + ws1 = segmentation.watershed_from_components(bd, fg, min_size=250) + ws2 = segmentation.watershed_from_maxima(bd, fg, min_size=250, min_distance=1) fname = Path(img_path).stem with h5py.File(os.path.join(root_save_dir, f"src-{ctype}", f"{fname}.h5"), "a") as f: From a7a7a4c5b16edcd880aad2931d75ca33fda8496e Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Wed, 8 Nov 2023 00:44:08 +0100 Subject: [PATCH 8/9] Refactor UNETR Training (#4) Refactoring the unetr training and inference scripts to adapt to training for all cell-types --- .../unetr/{ => cremi}/cremi_unetr.py | 0 .../unetr/{ => livecell}/common.py | 66 ++++--- .../unetr/livecell/livecell_all_unetr.py | 173 ++++++++++++++++++ .../livecell_cell_types_unetr.py} | 0 .../unetr/{ => livecell}/submit_training.py | 0 5 files changed, 210 insertions(+), 29 deletions(-) rename experiments/vision-transformer/unetr/{ => cremi}/cremi_unetr.py (100%) rename experiments/vision-transformer/unetr/{ => livecell}/common.py (84%) create mode 100644 experiments/vision-transformer/unetr/livecell/livecell_all_unetr.py rename experiments/vision-transformer/unetr/{livecell_unetr.py => livecell/livecell_cell_types_unetr.py} (100%) rename experiments/vision-transformer/unetr/{ => livecell}/submit_training.py (100%) diff --git a/experiments/vision-transformer/unetr/cremi_unetr.py b/experiments/vision-transformer/unetr/cremi/cremi_unetr.py similarity index 100% rename from experiments/vision-transformer/unetr/cremi_unetr.py rename to experiments/vision-transformer/unetr/cremi/cremi_unetr.py diff --git a/experiments/vision-transformer/unetr/common.py b/experiments/vision-transformer/unetr/livecell/common.py similarity index 84% rename from experiments/vision-transformer/unetr/common.py rename to experiments/vision-transformer/unetr/livecell/common.py index 31af2a4f..b3b3828d 100644 --- a/experiments/vision-transformer/unetr/common.py +++ b/experiments/vision-transformer/unetr/livecell/common.py @@ -2,8 +2,8 @@ import h5py import argparse import numpy as np -from typing import Tuple from pathlib import Path +from typing import Tuple, Optional import imageio.v3 as imageio from 
skimage.segmentation import find_boundaries @@ -15,8 +15,6 @@ from torch_em.loss import DiceLoss, LossWrapper, ApplyAndRemoveMask from torch_em.util.prediction import predict_with_halo, predict_with_padding -import common - OFFSETS = [ [-1, 0], [0, -1], @@ -45,32 +43,37 @@ def _get_output_channels(with_affinities): def get_my_livecell_loaders( input_path: str, patch_shape: Tuple[int, int], - cell_types: str, + cell_types: Optional[str] = None, with_affinities: bool = False ): """Returns the LIVECell training and validation dataloaders """ - if with_affinities: + train_loader = get_livecell_loader( + path=input_path, + split="train", + patch_shape=patch_shape, + batch_size=2, + download=True, + num_workers=16, + cell_types=None if cell_types is None else [cell_types], # this returns dataloaders with affinity channels and foreground-background channels - train_loader = get_livecell_loader( - path=input_path, split="train", patch_shape=patch_shape, batch_size=2, - cell_types=[cell_types], download=True, offsets=OFFSETS, num_workers=16 - ) - val_loader = get_livecell_loader( - path=input_path, split="val", patch_shape=patch_shape, batch_size=1, - cell_types=[cell_types], download=True, offsets=OFFSETS, num_workers=16 - ) - - else: + offsets=OFFSETS if with_affinities else None, # this returns dataloaders with foreground and boundary channels - train_loader = get_livecell_loader( - path=input_path, split="train", patch_shape=patch_shape, batch_size=2, - cell_types=[cell_types], download=True, boundaries=True, num_workers=16 - ) - val_loader = get_livecell_loader( - path=input_path, split="val", patch_shape=patch_shape, batch_size=1, - cell_types=[cell_types], download=True, boundaries=True, num_workers=16 - ) + boundaries=False if with_affinities else True + ) + val_loader = get_livecell_loader( + path=input_path, + split="val", + patch_shape=patch_shape, + batch_size=1, + download=True, + num_workers=16, + cell_types=None if cell_types is None else [cell_types], + # this returns dataloaders with affinity channels and foreground-background channels + offsets=OFFSETS if with_affinities else None, + # this returns dataloaders with foreground and boundary channels + boundaries=False if with_affinities else True + ) return train_loader, val_loader @@ -79,6 +82,11 @@ def get_my_livecell_loaders( # UNETR MODEL(S) FROM MONAI AND torch_em # +MODELS = { + "vit_b": "/scratch/projects/nim00007/sam/vanilla/sam_vit_b_01ec64.pth", + "vit_h": "/scratch/projects/nim00007/sam/vanilla/sam_vit_h_4b8939.pth" +} + def get_unetr_model( model_name: str, @@ -94,7 +102,7 @@ def get_unetr_model( from torch_em import model as torch_em_models model = torch_em_models.UNETR( encoder=model_name, out_channels=output_channels, - encoder_checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth" if sam_initialization else None + encoder_checkpoint_path=MODELS[model_name] if sam_initialization else None ) elif source_choice == "monai": @@ -118,14 +126,14 @@ def get_unetr_model( # LIVECELL UNETR INFERENCE - foreground boundary / foreground affinities # -def predict_for_unetr(img_path, model, root_save_dir, ctype, device, with_affinities): +def predict_for_unetr(img_path, model, root_save_dir, device, with_affinities, ctype=None): input_ = imageio.imread(img_path) input_ = standardize(input_) if with_affinities: # inference using affinities outputs = predict_with_padding(model, input_, device=device, min_divisible=(16, 16)) fg, affs = np.array(outputs[0, 0]), np.array(outputs[0, 1:]) - mws = 
segmentation.mutex_watershed_segmentation(fg, affs, common.OFFSETS, 250) + mws = segmentation.mutex_watershed_segmentation(fg, affs, OFFSETS, 250) else: # inference using foreground-boundary inputs - for the unetr training outputs = predict_with_halo(input_, model, [device], block_shape=[384, 384], halo=[64, 64], disable_tqdm=True) @@ -134,7 +142,8 @@ def predict_for_unetr(img_path, model, root_save_dir, ctype, device, with_affini ws2 = segmentation.watershed_from_maxima(bd, fg, min_size=250, min_distance=1) fname = Path(img_path).stem - with h5py.File(os.path.join(root_save_dir, f"src-{ctype}", f"{fname}.h5"), "a") as f: + save_path = os.path.join(root_save_dir, "src-all" if ctype is None else f"src-{ctype}", f"{fname}.h5") + with h5py.File(save_path, "a") as f: ds = f.require_dataset("foreground", shape=fg.shape, compression="gzip", dtype=fg.dtype) ds[:] = fg if with_affinities: @@ -156,8 +165,7 @@ def predict_for_unetr(img_path, model, root_save_dir, ctype, device, with_affini # def evaluate_for_unetr(gt_path, _save_dir, with_affinities): - # FIXME: fname = Path(img_path).stem - fname = os.path.split(gt_path)[-1] + fname = Path(gt_path).stem gt = imageio.imread(gt_path) output_file = os.path.join(_save_dir, f"{fname}.h5") diff --git a/experiments/vision-transformer/unetr/livecell/livecell_all_unetr.py b/experiments/vision-transformer/unetr/livecell/livecell_all_unetr.py new file mode 100644 index 00000000..2cad35a9 --- /dev/null +++ b/experiments/vision-transformer/unetr/livecell/livecell_all_unetr.py @@ -0,0 +1,173 @@ +import os +import numpy as np +import pandas as pd +from glob import glob +from tqdm import tqdm + +import torch +import torch_em + +import common + + +def do_unetr_training( + train_loader, + val_loader, + model, + device: torch.device, + iterations: int, + loss, + save_root: str +): + print("Run training for all cell types") + trainer = torch_em.default_segmentation_trainer( + name="livecell-all", + model=model, + train_loader=train_loader, + val_loader=val_loader, + learning_rate=1e-5, + device=device, + mixed_precision=True, + log_image_interval=50, + compile_model=False, + save_root=save_root, + loss=loss, + metric=loss + ) + trainer.fit(iterations) + + +def do_unetr_inference( + input_path: str, + device: torch.device, + model, + root_save_dir: str, + save_root: str, + with_affinities: bool +): + test_img_dir = os.path.join(input_path, "images", "livecell_test_images", "*") + model_ckpt = os.path.join(save_root, "checkpoints", "livecell-all", "best.pt") + assert os.path.exists(model_ckpt), model_ckpt + + model.load_state_dict(torch.load(model_ckpt, map_location=torch.device('cpu'))["model_state"]) + model.to(device) + model.eval() + + # creating the respective directories for saving the outputs + os.makedirs(os.path.join(root_save_dir, "src-all"), exist_ok=True) + + with torch.no_grad(): + for img_path in tqdm(glob(test_img_dir), desc=f"Run inference for all livecell with model {model_ckpt}"): + common.predict_for_unetr(img_path, model, root_save_dir, device, with_affinities) + + +def do_unetr_evaluation( + input_path: str, + root_save_dir: str, + csv_save_dir: str, + with_affinities: bool +): + _save_dir = os.path.join(root_save_dir, "src-all") + assert os.path.exists(_save_dir), _save_dir + + gt_dir = os.path.join(input_path, "annotations", "livecell_test_images", "*", "*") + + mws_msa_list, mws_sa50_list = [], [] + fg_list, bd_list, msa1_list, sa501_list, msa2_list, sa502_list = [], [], [], [], [], [] + for gt_path in tqdm(glob(gt_dir)): + all_metrics = 
common.evaluate_for_unetr(gt_path, _save_dir, with_affinities) + if with_affinities: + msa, sa50 = all_metrics + mws_msa_list.append(msa) + mws_sa50_list.append(sa50) + else: + fg_dice, bd_dice, msa1, sa_acc1, msa2, sa_acc2 = all_metrics + fg_list.append(fg_dice) + bd_list.append(bd_dice) + msa1_list.append(msa1) + sa501_list.append(sa_acc1[0]) + msa2_list.append(msa2) + sa502_list.append(sa_acc2[0]) + + if with_affinities: + res_dict = { + "LIVECell": "Metrics", + "mSA": np.mean(mws_msa_list), + "SA50": np.mean(mws_sa50_list) + } + else: + res_dict = { + "LIVECell": "Metrics", + "ws1_mSA": np.mean(msa1_list), + "ws1_SA50": np.mean(sa501_list), + "ws2_mSA": np.mean(msa2_list), + "ws2_SA50": np.mean(sa502_list) + } + + df = pd.DataFrame.from_dict([res_dict]) + df.to_csv(os.path.join(csv_save_dir, "livecell.csv")) + + +def main(args): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # overwrite to use complex device setups + patch_shape = (512, 512) # patch size used for training on livecell + + # directory folder to save different parts of the scheme + dir_structure = os.path.join( + args.model_name, "affinities" if args.with_affinities else "boundaries", + f"{args.source_choice}-sam" if args.do_sam_ini else f"{args.source_choice}-scratch" + ) + + # get the desired loss function for training + loss = common.get_loss_function( + with_affinities=args.with_affinities # takes care of calling the loss for training with affinities + ) + + # get the model for the training and inference on livecell dataset + model = common.get_unetr_model( + model_name=args.model_name, source_choice=args.source_choice, patch_shape=patch_shape, + sam_initialization=args.do_sam_ini, output_channels=common._get_output_channels(args.with_affinities) + ) + model.to(device) + + # determining where to save the checkpoints and tensorboard logs + save_root = os.path.join(args.save_root, dir_structure) if args.save_root is not None else args.save_root + + if args.train: + print("2d UNETR training on LIVECell dataset") + # get the desired livecell loaders for training + train_loader, val_loader = common.get_my_livecell_loaders( + args.input, patch_shape, args.cell_type, + with_affinities=args.with_affinities # this takes care of getting the loaders with affinities + ) + do_unetr_training( + train_loader=train_loader, val_loader=val_loader, model=model, + device=device, save_root=save_root, iterations=args.iterations, loss=loss + ) + + # determines the directory where the predictions will be saved + root_save_dir = os.path.join(args.save_dir, dir_structure) + + if args.predict: + print("2d UNETR inference on LIVECell dataset") + do_unetr_inference( + input_path=args.input, device=device, model=model, save_root=save_root, + root_save_dir=root_save_dir, with_affinities=args.with_affinities + ) + print("Predictions are saved in", root_save_dir) + + if args.evaluate: + print("2d UNETR evaluation on LIVECell dataset") + csv_save_dir = os.path.join("results", dir_structure) + os.makedirs(csv_save_dir, exist_ok=True) + + do_unetr_evaluation( + input_path=args.input, root_save_dir=root_save_dir, + csv_save_dir=csv_save_dir, with_affinities=args.with_affinities + ) + + +if __name__ == "__main__": + parser = common.get_parser() + args = parser.parse_args() + main(args) diff --git a/experiments/vision-transformer/unetr/livecell_unetr.py b/experiments/vision-transformer/unetr/livecell/livecell_cell_types_unetr.py similarity index 100% rename from experiments/vision-transformer/unetr/livecell_unetr.py rename to 
experiments/vision-transformer/unetr/livecell/livecell_cell_types_unetr.py diff --git a/experiments/vision-transformer/unetr/submit_training.py b/experiments/vision-transformer/unetr/livecell/submit_training.py similarity index 100% rename from experiments/vision-transformer/unetr/submit_training.py rename to experiments/vision-transformer/unetr/livecell/submit_training.py From 185e4a3d30a7e77aab8e59293515455870e168c4 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Wed, 8 Nov 2023 00:57:51 +0100 Subject: [PATCH 9/9] Update watershed min_size defaults for size filtering --- experiments/vision-transformer/unetr/livecell/common.py | 6 +++--- torch_em/util/segmentation.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/experiments/vision-transformer/unetr/livecell/common.py b/experiments/vision-transformer/unetr/livecell/common.py index b3b3828d..33bbbc52 100644 --- a/experiments/vision-transformer/unetr/livecell/common.py +++ b/experiments/vision-transformer/unetr/livecell/common.py @@ -133,13 +133,13 @@ def predict_for_unetr(img_path, model, root_save_dir, device, with_affinities, c if with_affinities: # inference using affinities outputs = predict_with_padding(model, input_, device=device, min_divisible=(16, 16)) fg, affs = np.array(outputs[0, 0]), np.array(outputs[0, 1:]) - mws = segmentation.mutex_watershed_segmentation(fg, affs, OFFSETS, 250) + mws = segmentation.mutex_watershed_segmentation(fg, affs, offsets=OFFSETS) else: # inference using foreground-boundary inputs - for the unetr training outputs = predict_with_halo(input_, model, [device], block_shape=[384, 384], halo=[64, 64], disable_tqdm=True) fg, bd = outputs[0, :, :], outputs[1, :, :] - ws1 = segmentation.watershed_from_components(bd, fg, min_size=250) - ws2 = segmentation.watershed_from_maxima(bd, fg, min_size=250, min_distance=1) + ws1 = segmentation.watershed_from_components(bd, fg) + ws2 = segmentation.watershed_from_maxima(bd, fg, min_distance=1) fname = Path(img_path).stem save_path = os.path.join(root_save_dir, "src-all" if ctype is None else f"src-{ctype}", f"{fname}.h5") diff --git a/torch_em/util/segmentation.py b/torch_em/util/segmentation.py index 8767589f..2327a60d 100644 --- a/torch_em/util/segmentation.py +++ b/torch_em/util/segmentation.py @@ -35,7 +35,7 @@ def size_filter(seg, min_size, hmap=None, with_background=False): return seg -def mutex_watershed_segmentation(foreground, affinities, offsets, min_size, threshold=0.5): +def mutex_watershed_segmentation(foreground, affinities, offsets, min_size=250, threshold=0.5): """Computes the mutex watershed segmentation using the affinity maps for respective pixel offsets Arguments: @@ -60,7 +60,7 @@ def connected_components_with_boundaries(foreground, boundaries, threshold=0.5): return seg.astype("uint64") -def watershed_from_components(boundaries, foreground, min_size, threshold1=0.5, threshold2=0.5): +def watershed_from_components(boundaries, foreground, min_size=250, threshold1=0.5, threshold2=0.5): """The default approach: - Subtract the boundaries from the foreground to separate touching objects. - Use the connected components of this as seeds. 
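One caveat on the reordered signature in the next hunk: `watershed_from_maxima` now takes `min_distance` before `min_size`. Keyword callers, like the updated `common.py`, are unaffected, but an editorial illustration (not code from the repo; `bd` and `fg` stand for boundary and foreground maps) of what happens to a hypothetical positional caller:

```python
from torch_em.util.segmentation import watershed_from_maxima

# old order: (boundaries, foreground, min_size, min_distance, ...)
# new order: (boundaries, foreground, min_distance, min_size=250, ...)
seg = watershed_from_maxima(bd, fg, 10, 1)  # was min_size=10, min_distance=1;
                                            # now min_distance=10, min_size=1 -- no error raised
seg = watershed_from_maxima(bd, fg, min_distance=1)  # keyword style, as common.py now calls it
```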
@@ -84,7 +84,7 @@ def watershed_from_components(boundaries, foreground, min_size, threshold1=0.5, return seg -def watershed_from_maxima(boundaries, foreground, min_size, min_distance, sigma=1.0, threshold1=0.5): +def watershed_from_maxima(boundaries, foreground, min_distance, min_size=250, sigma=1.0, threshold1=0.5): """Find objects via seeded watershed starting from the maxima of the distance transform instead. This has the advantage that objects can be better separated, but it may over-segment if the objects have complex shapes.
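To close the loop on the refactor: after patches 8 and 9, the all-cell-types workflow is driven entirely through `common.get_parser`. A hypothetical end-to-end invocation, assuming it is run from `experiments/vision-transformer/unetr/livecell/` so that the script modules are importable:

```python
import common
from livecell_all_unetr import main  # the new script for training on all cell types

parser = common.get_parser()
args = parser.parse_args([
    "--train", "--predict", "--evaluate",
    "--with_affinities",   # 9 output channels: foreground + 8 affinity maps
    "--do_sam_ini",        # initialize the ViT encoder from SAM weights
    "-m", "vit_b",
    "-i", "/scratch/usr/nimanwai/data/livecell",
])
main(args)
# checkpoints land under <save_root>/vit_b/affinities/torch-em-sam/checkpoints/livecell-all/
```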