diff --git a/experiments/vision-transformer/unetr/cremi_unetr.py b/experiments/vision-transformer/unetr/cremi/cremi_unetr.py
similarity index 100%
rename from experiments/vision-transformer/unetr/cremi_unetr.py
rename to experiments/vision-transformer/unetr/cremi/cremi_unetr.py
diff --git a/experiments/vision-transformer/unetr/livecell/common.py b/experiments/vision-transformer/unetr/livecell/common.py
new file mode 100644
index 00000000..33bbbc52
--- /dev/null
+++ b/experiments/vision-transformer/unetr/livecell/common.py
@@ -0,0 +1,275 @@
+import os
+import h5py
+import argparse
+import numpy as np
+from pathlib import Path
+from typing import Tuple, Optional
+
+import imageio.v3 as imageio
+from skimage.segmentation import find_boundaries
+from elf.evaluation import dice_score, mean_segmentation_accuracy
+
+from torch_em.util import segmentation
+from torch_em.transform.raw import standardize
+from torch_em.data.datasets import get_livecell_loader
+from torch_em.loss import DiceLoss, LossWrapper, ApplyAndRemoveMask
+from torch_em.util.prediction import predict_with_halo, predict_with_padding
+
+
+OFFSETS = [
+    [-1, 0], [0, -1],
+    [-3, 0], [0, -3],
+    [-9, 0], [0, -9],
+    [-27, 0], [0, -27]
+]
+
+
+CELL_TYPES = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"]
+
+
+#
+# LIVECELL DATALOADERS
+#
+
+
+def _get_output_channels(with_affinities):
+    if with_affinities:
+        n_out = len(OFFSETS) + 1
+    else:
+        n_out = 2
+    return n_out
+
+
+def get_my_livecell_loaders(
+    input_path: str,
+    patch_shape: Tuple[int, int],
+    cell_types: Optional[str] = None,
+    with_affinities: bool = False
+):
+    """Returns the LIVECell training and validation dataloaders.
+    """
+    train_loader = get_livecell_loader(
+        path=input_path,
+        split="train",
+        patch_shape=patch_shape,
+        batch_size=2,
+        download=True,
+        num_workers=16,
+        cell_types=None if cell_types is None else [cell_types],
+        # passing the offsets returns dataloaders with affinity channels and a foreground channel
+        offsets=OFFSETS if with_affinities else None,
+        # otherwise the dataloaders return foreground and boundary channels
+        boundaries=False if with_affinities else True
+    )
+    val_loader = get_livecell_loader(
+        path=input_path,
+        split="val",
+        patch_shape=patch_shape,
+        batch_size=1,
+        download=True,
+        num_workers=16,
+        cell_types=None if cell_types is None else [cell_types],
+        # passing the offsets returns dataloaders with affinity channels and a foreground channel
+        offsets=OFFSETS if with_affinities else None,
+        # otherwise the dataloaders return foreground and boundary channels
+        boundaries=False if with_affinities else True
+    )
+
+    return train_loader, val_loader
+
+
+#
+# UNETR MODEL(S) FROM MONAI AND torch_em
+#
+
+MODELS = {
+    "vit_b": "/scratch/projects/nim00007/sam/vanilla/sam_vit_b_01ec64.pth",
+    "vit_h": "/scratch/projects/nim00007/sam/vanilla/sam_vit_h_4b8939.pth"
+}
+
+
+def get_unetr_model(
+    model_name: str,
+    source_choice: str,
+    patch_shape: Tuple[int, int],
+    sam_initialization: bool,
+    output_channels: int
+):
+    """Returns the UNETR model from the chosen source.
+    """
+    if source_choice == "torch-em":
+        # this returns the unetr model which uses the vision transformer from Segment Anything
+        from torch_em import model as torch_em_models
+        model = torch_em_models.UNETR(
+            encoder=model_name, out_channels=output_channels,
+            encoder_checkpoint_path=MODELS[model_name] if sam_initialization else None
+        )
+
+    elif source_choice == "monai":
+        # this returns the unetr model from monai
+        from monai.networks import nets as monai_models
+        model = monai_models.unetr.UNETR(
+            in_channels=1,
+            out_channels=output_channels,
+            img_size=patch_shape,
+            spatial_dims=2
+        )
+        model.out_channels = output_channels  # type: ignore
+
+    else:
+        raise ValueError(f"The available UNETR models come from either \"torch-em\" or \"monai\", but got {source_choice}")
+
+    return model
+
+
+#
+# LIVECELL UNETR INFERENCE - foreground boundary / foreground affinities
+#
+
+def predict_for_unetr(img_path, model, root_save_dir, device, with_affinities, ctype=None):
+    input_ = imageio.imread(img_path)
+    input_ = standardize(input_)
+
+    if with_affinities:  # inference using affinities
+        outputs = predict_with_padding(model, input_, device=device, min_divisible=(16, 16))
+        fg, affs = np.array(outputs[0, 0]), np.array(outputs[0, 1:])
+        mws = segmentation.mutex_watershed_segmentation(fg, affs, offsets=OFFSETS)
+
+    else:  # inference using the foreground and boundary predictions
+        outputs = predict_with_halo(input_, model, [device], block_shape=[384, 384], halo=[64, 64], disable_tqdm=True)
+        fg, bd = outputs[0, :, :], outputs[1, :, :]
+        ws1 = segmentation.watershed_from_components(bd, fg)
+        ws2 = segmentation.watershed_from_maxima(bd, fg, min_distance=1)
+
+    fname = Path(img_path).stem
+    save_path = os.path.join(root_save_dir, "src-all" if ctype is None else f"src-{ctype}", f"{fname}.h5")
+    with h5py.File(save_path, "a") as f:
+        ds = f.require_dataset("foreground", shape=fg.shape, compression="gzip", dtype=fg.dtype)
+        ds[:] = fg
+        if with_affinities:
+            ds = f.require_dataset("affinities", shape=affs.shape, compression="gzip", dtype=affs.dtype)
+            ds[:] = affs
+            ds = f.require_dataset("segmentation", shape=mws.shape, compression="gzip", dtype=mws.dtype)
+            ds[:] = mws
+        else:
+            ds = f.require_dataset("boundary", shape=bd.shape, compression="gzip", dtype=bd.dtype)
+            ds[:] = bd
+            ds = f.require_dataset("watershed1", shape=ws1.shape, compression="gzip", dtype=ws1.dtype)
+            ds[:] = ws1
+            ds = f.require_dataset("watershed2", shape=ws2.shape, compression="gzip", dtype=ws2.dtype)
+            ds[:] = ws2
+
+
+#
+# LIVECELL UNETR EVALUATION - foreground boundary / foreground affinities
+#
+
+def evaluate_for_unetr(gt_path, _save_dir, with_affinities):
+    fname = Path(gt_path).stem
+    gt = imageio.imread(gt_path)
+
+    output_file = os.path.join(_save_dir, f"{fname}.h5")
+    with h5py.File(output_file, "r") as f:
+        if with_affinities:
+            mws = f["segmentation"][:]
+        else:
+            fg = f["foreground"][:]
+            bd = f["boundary"][:]
+            ws1 = f["watershed1"][:]
+            ws2 = f["watershed2"][:]
+
+    if with_affinities:
+        mws_msa, mws_sa_acc = mean_segmentation_accuracy(mws, gt, return_accuracies=True)
+        return mws_msa, mws_sa_acc[0]
+
+    else:
+        true_bd = find_boundaries(gt)
+
+        # Compare the foreground prediction to the ground-truth.
+        # Here, it's important not to threshold the segmentation. Otherwise EVERYTHING will be set to
+        # foreground in the dice function, since we have a comparison > 0 in there, and everything in the
+        # binary prediction evaluates to true.
+        # For the GT we can set the threshold to 0, because this will map to the correct binary mask.
+        fg_dice = dice_score(fg, gt, threshold_gt=0, threshold_seg=None)
+
+        # Compare the boundary prediction to the boundaries derived from the ground-truth.
+        # Here, we don't need any thresholds: for the prediction the same holds as before.
+        # For the ground-truth we already have a binary label, so we don't need to threshold it again.
+        bd_dice = dice_score(bd, true_bd, threshold_gt=None, threshold_seg=None)
+
+        msa1, sa_acc1 = mean_segmentation_accuracy(ws1, gt, return_accuracies=True)  # type: ignore
+        msa2, sa_acc2 = mean_segmentation_accuracy(ws2, gt, return_accuracies=True)  # type: ignore
+
+        return fg_dice, bd_dice, msa1, sa_acc1, msa2, sa_acc2
+
+
+#
+# miscellaneous utilities
+#
+
+
+def get_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--train", action='store_true', help="Enables UNETR training on the LIVECell dataset"
+    )
+
+    parser.add_argument(
+        "--predict", action='store_true', help="Enables UNETR prediction on the LIVECell dataset"
+    )
+
+    parser.add_argument(
+        "--evaluate", action='store_true', help="Enables UNETR evaluation on the LIVECell dataset"
+    )
+
+    parser.add_argument(
+        "--source_choice", type=str, default="torch-em",
+        help="The source of the UNETR implementation, either \"torch-em\" or \"monai\""
+    )
+
+    parser.add_argument(
+        "-m", "--model_name", type=str, default="vit_b", help="Name of the ViT to use as the encoder in UNETR"
+    )
+
+    parser.add_argument(
+        "--do_sam_ini", action='store_true', help="Enables initializing UNETR with SAM's ViT weights"
+    )
+
+    parser.add_argument(
+        "-c", "--cell_type", type=str, default=None, help="Choice of cell type to train on (default: all cell types)"
+    )
+
+    parser.add_argument(
+        "-i", "--input", type=str, default="/scratch/usr/nimanwai/data/livecell",
+        help="Path where the dataset already exists or will be downloaded by the dataloader"
+    )
+
+    parser.add_argument(
+        "-s", "--save_root", type=str, default="/scratch/usr/nimanwai/models/unetr/",
+        help="Path where checkpoints and logs will be saved"
+    )
+
+    parser.add_argument(
+        "--save_dir", type=str, default="/scratch/usr/nimanwai/predictions/unetr",
+        help="Path to save predictions from the UNETR model"
+    )
+
+    parser.add_argument(
+        "--with_affinities", action="store_true",
+        help="Trains the UNETR model to predict affinities"
+    )
+
+    parser.add_argument("--iterations", type=int, default=100000, help="Number of training iterations")
+    return parser
+
+
+def get_loss_function(with_affinities=True):
+    if with_affinities:
+        loss = LossWrapper(
+            loss=DiceLoss(),
+            transform=ApplyAndRemoveMask(masking_method="multiply")
+        )
+    else:
+        loss = DiceLoss()
+
+    return loss
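
The two runner scripts below consume these helpers. As a rough, minimal sketch of the intended wiring (assuming the helpers are importable as common; the local paths here are illustrative placeholders, not the cluster paths hard-coded above):

import torch

import common

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with_affinities = True  # switch between affinity-based and boundary-based training

# 9 output channels with affinities (8 offsets + foreground), 2 (foreground + boundary) otherwise
n_out = common._get_output_channels(with_affinities)

# sam_initialization=False avoids needing the SAM checkpoints listed in MODELS
model = common.get_unetr_model(
    model_name="vit_b", source_choice="torch-em", patch_shape=(512, 512),
    sam_initialization=False, output_channels=n_out
)
model.to(device)

train_loader, val_loader = common.get_my_livecell_loaders(
    "./data/livecell", patch_shape=(512, 512), with_affinities=with_affinities
)
loss = common.get_loss_function(with_affinities=with_affinities)
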
diff --git a/experiments/vision-transformer/unetr/livecell/livecell_all_unetr.py b/experiments/vision-transformer/unetr/livecell/livecell_all_unetr.py
new file mode 100644
index 00000000..2cad35a9
--- /dev/null
+++ b/experiments/vision-transformer/unetr/livecell/livecell_all_unetr.py
@@ -0,0 +1,175 @@
+import os
+import numpy as np
+import pandas as pd
+from glob import glob
+from tqdm import tqdm
+
+import torch
+import torch_em
+
+import common
+
+
+def do_unetr_training(
+    train_loader,
+    val_loader,
+    model,
+    device: torch.device,
+    iterations: int,
+    loss,
+    save_root: str
+):
+    print("Run training for all cell types")
+    trainer = torch_em.default_segmentation_trainer(
+        name="livecell-all",
+        model=model,
+        train_loader=train_loader,
+        val_loader=val_loader,
+        learning_rate=1e-5,
+        device=device,
+        mixed_precision=True,
+        log_image_interval=50,
+        compile_model=False,
+        save_root=save_root,
+        loss=loss,
+        metric=loss
+    )
+    trainer.fit(iterations)
+
+
+def do_unetr_inference(
+    input_path: str,
+    device: torch.device,
+    model,
+    root_save_dir: str,
+    save_root: str,
+    with_affinities: bool
+):
+    test_img_dir = os.path.join(input_path, "images", "livecell_test_images", "*")
+    model_ckpt = os.path.join(save_root, "checkpoints", "livecell-all", "best.pt")
+    assert os.path.exists(model_ckpt), model_ckpt
+
+    model.load_state_dict(torch.load(model_ckpt, map_location=torch.device('cpu'))["model_state"])
+    model.to(device)
+    model.eval()
+
+    # create the respective directories for saving the outputs
+    os.makedirs(os.path.join(root_save_dir, "src-all"), exist_ok=True)
+
+    with torch.no_grad():
+        for img_path in tqdm(glob(test_img_dir), desc=f"Run inference on LIVECell with model {model_ckpt}"):
+            common.predict_for_unetr(img_path, model, root_save_dir, device, with_affinities)
+
+
+def do_unetr_evaluation(
+    input_path: str,
+    root_save_dir: str,
+    csv_save_dir: str,
+    with_affinities: bool
+):
+    _save_dir = os.path.join(root_save_dir, "src-all")
+    assert os.path.exists(_save_dir), _save_dir
+
+    gt_dir = os.path.join(input_path, "annotations", "livecell_test_images", "*", "*")
+
+    mws_msa_list, mws_sa50_list = [], []
+    fg_list, bd_list, msa1_list, sa501_list, msa2_list, sa502_list = [], [], [], [], [], []
+    for gt_path in tqdm(glob(gt_dir)):
+        all_metrics = common.evaluate_for_unetr(gt_path, _save_dir, with_affinities)
+        if with_affinities:
+            msa, sa50 = all_metrics
+            mws_msa_list.append(msa)
+            mws_sa50_list.append(sa50)
+        else:
+            fg_dice, bd_dice, msa1, sa_acc1, msa2, sa_acc2 = all_metrics
+            fg_list.append(fg_dice)
+            bd_list.append(bd_dice)
+            msa1_list.append(msa1)
+            sa501_list.append(sa_acc1[0])
+            msa2_list.append(msa2)
+            sa502_list.append(sa_acc2[0])
+
+    if with_affinities:
+        res_dict = {
+            "LIVECell": "Metrics",
+            "mSA": np.mean(mws_msa_list),
+            "SA50": np.mean(mws_sa50_list)
+        }
+    else:
+        res_dict = {
+            "LIVECell": "Metrics",
+            "fg_dice": np.mean(fg_list),
+            "bd_dice": np.mean(bd_list),
+            "ws1_mSA": np.mean(msa1_list),
+            "ws1_SA50": np.mean(sa501_list),
+            "ws2_mSA": np.mean(msa2_list),
+            "ws2_SA50": np.mean(sa502_list)
+        }
+
+    df = pd.DataFrame.from_dict([res_dict])
+    df.to_csv(os.path.join(csv_save_dir, "livecell.csv"))
+
+
+def main(args):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # overwrite this to use a custom device setup
+    patch_shape = (512, 512)  # patch size used for training on livecell
+
+    # directory structure for saving the checkpoints, predictions and results
+    dir_structure = os.path.join(
+        args.model_name, "affinities" if args.with_affinities else "boundaries",
+        f"{args.source_choice}-sam" if args.do_sam_ini else f"{args.source_choice}-scratch"
+    )
+
+    # get the desired loss function for training
+    loss = common.get_loss_function(
+        with_affinities=args.with_affinities  # selects the masked loss needed for training with affinities
+    )
+
+    # get the model for training and inference on the livecell dataset
+    model = common.get_unetr_model(
+        model_name=args.model_name, source_choice=args.source_choice, patch_shape=patch_shape,
+        sam_initialization=args.do_sam_ini, output_channels=common._get_output_channels(args.with_affinities)
+    )
+    model.to(device)
+
+    # determine where to save the checkpoints and tensorboard logs
+    save_root = os.path.join(args.save_root, dir_structure) if args.save_root is not None else args.save_root
+
+    if args.train:
+        print("2d UNETR training on LIVECell dataset")
+        # get the desired livecell loaders for training
+        train_loader, val_loader = common.get_my_livecell_loaders(
+            args.input, patch_shape, args.cell_type,
+            with_affinities=args.with_affinities  # returns loaders with affinity channels if requested
+        )
+        do_unetr_training(
+            train_loader=train_loader, val_loader=val_loader, model=model,
+            device=device, save_root=save_root, iterations=args.iterations, loss=loss
+        )
+
+    # the directory where the predictions will be saved
+    root_save_dir = os.path.join(args.save_dir, dir_structure)
+
+    if args.predict:
+        print("2d UNETR inference on LIVECell dataset")
+        do_unetr_inference(
+            input_path=args.input, device=device, model=model, save_root=save_root,
+            root_save_dir=root_save_dir, with_affinities=args.with_affinities
+        )
+        print("Predictions are saved in", root_save_dir)
+
+    if args.evaluate:
+        print("2d UNETR evaluation on LIVECell dataset")
+        csv_save_dir = os.path.join("results", dir_structure)
+        os.makedirs(csv_save_dir, exist_ok=True)
+
+        do_unetr_evaluation(
+            input_path=args.input, root_save_dir=root_save_dir,
+            csv_save_dir=csv_save_dir, with_affinities=args.with_affinities
+        )
+
+
+if __name__ == "__main__":
+    parser = common.get_parser()
+    args = parser.parse_args()
+    main(args)
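
For reference, livecell_all_unetr.py is driven entirely by the shared parser from common.py; a hypothetical end-to-end run (placeholder paths and a small iteration count, purely for illustration) could be scripted like this:

import common
from livecell_all_unetr import main

args = common.get_parser().parse_args([
    "--train", "--predict", "--evaluate", "--with_affinities",
    "-i", "./data/livecell", "-s", "./models", "--save_dir", "./predictions",
    "--iterations", "10000",
])
main(args)
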
{"CELL TYPE": c1} + + # dict for affinities evaluation + mws_msa_set, mws_sa50_set = {"CELL TYPE": c1}, {"CELL TYPE": c1} + + for c2 in tqdm(cell_types, desc=f"Evaluation on {c1} source models from {_save_dir}"): + gt_dir = os.path.join(input_path, "annotations", "livecell_test_images", c2, "*") + + # cell-wise evaluation list for foreground-boundary evaluations + cwise_fg, cwise_bd = [], [] + cwise_ws1_msa, cwise_ws2_msa, cwise_ws1_sa50, cwise_ws2_sa50 = [], [], [], [] + + # cell-wise evaluation list for affinities evaluation + cwise_mws_msa, cwise_mws_sa50 = [], [] + + for gt_path in glob(gt_dir): + all_metrics = common.evaluate_for_unetr(gt_path, _save_dir, with_affinities) + if with_affinities: + mws_msa, mws_sa50 = all_metrics + cwise_mws_msa.append(mws_msa) + cwise_mws_sa50.append(mws_sa50) + else: + fg_dice, bd_dice, msa1, sa_acc1, msa2, sa_acc2 = all_metrics + cwise_fg.append(fg_dice) + cwise_bd.append(bd_dice) + cwise_ws1_msa.append(msa1) + cwise_ws2_msa.append(msa2) + cwise_ws1_sa50.append(sa_acc1[0]) + cwise_ws2_sa50.append(sa_acc2[0]) + + if with_affinities: + mws_msa_set[c2], mws_sa50_set[c2] = np.mean(cwise_mws_msa), np.mean(cwise_mws_sa50) + else: + fg_set[c2], bd_set[c2] = np.mean(cwise_fg), np.mean(cwise_bd) + ws1_msa_set[c2], ws2_msa_set[c2] = np.mean(cwise_ws1_msa), np.mean(cwise_ws2_msa) + ws1_sa50_set[c2], ws2_sa50_set[c2] = np.mean(cwise_ws1_sa50), np.mean(cwise_ws2_sa50) + + if with_affinities: + mws_msa_list.append(pd.DataFrame.from_dict([mws_msa_set])) + mws_sa50_list.append(pd.DataFrame.from_dict([mws_sa50_set])) + else: + fg_list.append(pd.DataFrame.from_dict([fg_set])) + bd_list.append(pd.DataFrame.from_dict([bd_set])) + ws1_msa_list.append(pd.DataFrame.from_dict([ws1_msa_set])) + ws2_msa_list.append(pd.DataFrame.from_dict([ws2_msa_set])) + ws1_sa50_list.append(pd.DataFrame.from_dict([ws1_sa50_set])) + ws2_sa50_list.append(pd.DataFrame.from_dict([ws2_sa50_set])) + + if with_affinities: + df_mws_msa, df_mws_sa50 = pd.concat(mws_msa_list, ignore_index=True), pd.concat(mws_sa50_list, ignore_index=True) + df_mws_msa.to_csv(os.path.join(csv_save_dir, "mws-affs-msa.csv")) + df_mws_sa50.to_csv(os.path.join(csv_save_dir, "mws-affs-sa50.csv")) + else: + df_fg, df_bd = pd.concat(fg_list, ignore_index=True), pd.concat(bd_list, ignore_index=True) + df_ws1_msa, df_ws2_msa = pd.concat(ws1_msa_list, ignore_index=True), pd.concat(ws2_msa_list, ignore_index=True) + df_ws1_sa50, df_ws2_sa50 = pd.concat(ws1_sa50_list, ignore_index=True), pd.concat(ws2_sa50_list, ignore_index=True) + df_fg.to_csv(os.path.join(csv_save_dir, "foreground-dice.csv")) + df_bd.to_csv(os.path.join(csv_save_dir, "boundary-dice.csv")) + df_ws1_msa.to_csv(os.path.join(csv_save_dir, "watershed1-msa.csv")) + df_ws2_msa.to_csv(os.path.join(csv_save_dir, "watershed2-msa.csv")) + df_ws1_sa50.to_csv(os.path.join(csv_save_dir, "watershed1-sa50.csv")) + df_ws2_sa50.to_csv(os.path.join(csv_save_dir, "watershed2-sa50.csv")) + + +def main(args): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # overwrite to use complex device setups + patch_shape = (512, 512) # patch size used for training on livecell + + # get the desired loss function for training + loss = common.get_loss_function( + with_affinities=args.with_affinities # takes care of calling the loss for training with affinities + ) + + # get the model for the training and inference on livecell dataset + model = common.get_unetr_model( + model_name=args.model_name, source_choice=args.source_choice, patch_shape=patch_shape, + 
sam_initialization=args.do_sam_ini, output_channels=common._get_output_channels(args.with_affinities) + ) + model.to(device) + + # determining where to save the checkpoints and tensorboard logs + save_root = os.path.join( + args.save_root, "affinities" if args.with_affinities else "boundaries", + f"{args.source_choice}-sam" if args.do_sam_ini else f"{args.source_choice}-scratch" + ) if args.save_root is not None else args.save_root + + if args.train: + print("2d UNETR training on LIVECell dataset") + # get the desired livecell loaders for training + train_loader, val_loader = common.get_my_livecell_loaders( + args.input, patch_shape, args.cell_type, + with_affinities=args.with_affinities # this takes care of getting the loaders with affinities + ) + do_unetr_training( + train_loader=train_loader, val_loader=val_loader, model=model, cell_types=args.cell_type, + device=device, save_root=save_root, iterations=args.iterations, loss=loss + ) + + # determines the directory where the predictions will be saved + root_save_dir = os.path.join( + args.save_dir, "affinities" if args.with_affinities else "boundaries", + f"{args.source_choice}-sam" if args.do_sam_ini else f"{args.source_choice}-scratch" + ) + + if args.predict: + print("2d UNETR inference on LIVECell dataset") + do_unetr_inference( + input_path=args.input, device=device, model=model, cell_types=common.CELL_TYPES, + root_save_dir=root_save_dir, save_root=save_root, with_affinities=args.with_affinities + ) + print("Predictions are saved in", root_save_dir) + + if args.evaluate: + print("2d UNETR evaluation on LIVECell dataset") + tmp_csv_name = f"{args.source_choice}-sam" if args.do_sam_ini else f"{args.source_choice}-scratch" + csv_save_dir = os.path.join("results", "affinities" if args.with_affinities else "boundaries", tmp_csv_name) + os.makedirs(csv_save_dir, exist_ok=True) + + do_unetr_evaluation( + input_path=args.input, cell_types=common.CELL_TYPES, root_save_dir=root_save_dir, + csv_save_dir=csv_save_dir, with_affinities=args.with_affinities + ) + + +if __name__ == "__main__": + parser = common.get_parser() + args = parser.parse_args() + main(args) diff --git a/experiments/vision-transformer/unetr/submit_training.py b/experiments/vision-transformer/unetr/livecell/submit_training.py similarity index 91% rename from experiments/vision-transformer/unetr/submit_training.py rename to experiments/vision-transformer/unetr/livecell/submit_training.py index b7042f8a..84ea6b56 100644 --- a/experiments/vision-transformer/unetr/submit_training.py +++ b/experiments/vision-transformer/unetr/livecell/submit_training.py @@ -7,7 +7,7 @@ from datetime import datetime -def write_batch_script(out_path, ini_sam=True, source_choice="torch-em"): +def write_batch_script(out_path, ini_sam=True, source_choice="torch-em", with_affinity=True): """ inputs: source_choice:str - [torch_em / monai] source of the unetr model coming from @@ -22,7 +22,7 @@ def write_batch_script(out_path, ini_sam=True, source_choice="torch-em"): #SBATCH --ntasks=1 #SBATCH -p grete:shared #SBATCH -G A100:1 -#SBATCH -c 8 +#SBATCH -c 16 #SBATCH -A gzz0001 """ if ini_sam: @@ -42,9 +42,11 @@ def write_batch_script(out_path, ini_sam=True, source_choice="torch-em"): add_source_choice = f"--source_choice {source_choice} " batch_script += add_ctype + add_source_choice - add_sam_ini = "--do_sam_ini " if ini_sam: - batch_script += add_sam_ini + batch_script += "--do_sam_ini " + + if with_affinity: + batch_script += "--with_affinities " _op = out_path[:-3] + f"_{i}.sh" @@ -58,7 +60,7 @@ def 
submit_slurm(): tmp_folder = os.path.expanduser("./gpu_jobs") os.makedirs(tmp_folder, exist_ok=True) - script_name = "unetr-monai" + script_name = "unetr_" dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") tmp_name = script_name + dt diff --git a/experiments/vision-transformer/unetr/livecell_unetr.py b/experiments/vision-transformer/unetr/livecell_unetr.py deleted file mode 100644 index 0ff9cde2..00000000 --- a/experiments/vision-transformer/unetr/livecell_unetr.py +++ /dev/null @@ -1,347 +0,0 @@ -import os -import argparse -import numpy as np -import pandas as pd -from glob import glob -from tqdm import tqdm -from typing import Tuple, List - -import imageio.v2 as imageio -from skimage.segmentation import find_boundaries -from elf.evaluation import dice_score, mean_segmentation_accuracy - -import torch -import torch_em -from torch_em.util import segmentation -from torch_em.transform.raw import standardize -from torch_em.data.datasets import get_livecell_loader -from torch_em.util.prediction import predict_with_halo - - -def get_unetr_model( - model_name: str, - source_choice: str, - patch_shape: Tuple[int, int], - sam_initialization: bool, - output_channels: int -): - """Returns the expected UNETR model - """ - if source_choice == "torch-em": - from torch_em import model as torch_em_models - model = torch_em_models.UNETR( - encoder=model_name, out_channels=output_channels, - encoder_checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth" if sam_initialization else None - ) - - elif source_choice == "monai": - from monai.networks import nets as monai_models - model = monai_models.unetr.UNETR( - in_channels=1, - out_channels=output_channels, - img_size=patch_shape, - spatial_dims=2 - ) - model.out_channels = 2 # type: ignore - - else: - raise ValueError(f"The available UNETR models are either from \"torch-em\" or \"monai\", choose from them instead of - {source_choice}") - - return model - - -def do_unetr_training( - input_path: str, - model, - cell_types: List[str], - patch_shape: Tuple[int, int], - device: torch.device, - save_root: str, - iterations: int, - sam_initialization: bool, - source_choice: str -): - print("Run training for cell types:", cell_types) - train_loader = get_livecell_loader( - path=input_path, - split="train", - patch_shape=patch_shape, - batch_size=2, - cell_types=[cell_types], - download=True, - boundaries=True, - num_workers=8 - ) - - val_loader = get_livecell_loader( - path=input_path, - split="val", - patch_shape=patch_shape, - batch_size=1, - cell_types=[cell_types], - download=True, - boundaries=True, - num_workers=8 - ) - - _save_root = os.path.join( - save_root, - f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch" - ) if save_root is not None else save_root - - trainer = torch_em.default_segmentation_trainer( - name=f"livecell-{cell_types}", - model=model, - train_loader=train_loader, - val_loader=val_loader, - learning_rate=1e-5, - device=device, - mixed_precision=True, - log_image_interval=50, - compile_model=False, - save_root=_save_root - ) - - trainer.fit(iterations) - - -def do_unetr_inference( - input_path: str, - device: torch.device, - model, - cell_types: List[str], - root_save_dir: str, - sam_initialization: bool, - save_root: str, - source_choice: str -): - for ctype in cell_types: - test_img_dir = os.path.join(input_path, "images", "livecell_test_images", "*") - - model_ckpt = os.path.join(save_root, - f"{source_choice}-sam" if sam_initialization else 
f"{source_choice}-scratch", - "checkpoints", f"livecell-{ctype}", "best.pt") - assert os.path.exists(model_ckpt) - - model.load_state_dict(torch.load(model_ckpt, map_location=torch.device('cpu'))["model_state"]) - model.to(device) - model.eval() - - fg_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "foreground") - bd_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "boundary") - ws1_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "watershed1") - ws2_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "watershed2") - - os.makedirs(fg_save_dir, exist_ok=True) - os.makedirs(bd_save_dir, exist_ok=True) - os.makedirs(ws1_save_dir, exist_ok=True) - os.makedirs(ws2_save_dir, exist_ok=True) - - with torch.no_grad(): - for img_path in tqdm(glob(test_img_dir), desc=f"Run inference for all livecell with model {model_ckpt}"): - fname = os.path.split(img_path)[-1] - - input_img = imageio.imread(img_path) - input_img = standardize(input_img) - outputs = predict_with_halo( - input_img, model, gpu_ids=[device], block_shape=[384, 384], halo=[64, 64], disable_tqdm=True - ) - - fg, bd = outputs[0, :, :], outputs[1, :, :] - - ws1 = segmentation.watershed_from_components(bd, fg, min_size=10) - ws2 = segmentation.watershed_from_maxima(bd, fg, min_size=10, min_distance=1) - - imageio.imwrite(os.path.join(fg_save_dir, fname), fg) - imageio.imwrite(os.path.join(bd_save_dir, fname), bd) - imageio.imwrite(os.path.join(ws1_save_dir, fname), ws1) - imageio.imwrite(os.path.join(ws2_save_dir, fname), ws2) - - -def do_unetr_evaluation( - input_path: str, - cell_types: List[str], - root_save_dir: str, - sam_initialization: bool, - source_choice: str -): - fg_list, bd_list = [], [] - ws1_msa_list, ws2_msa_list, ws1_sa50_list, ws2_sa50_list = [], [], [], [] - - for c1 in cell_types: - _save_dir = os.path.join(root_save_dir, f"src-{c1}") - if not os.path.exists(_save_dir): - print("Skipping", _save_dir) - continue - - fg_set, bd_set = {"CELL TYPE": c1}, {"CELL TYPE": c1} - ws1_msa_set, ws2_msa_set, ws1_sa50_set, ws2_sa50_set = {"CELL TYPE": c1}, {"CELL TYPE": c1}, {"CELL TYPE": c1}, {"CELL TYPE": c1} - for c2 in tqdm(cell_types, desc=f"Evaluation on {c1} source models from {_save_dir}"): - fg_dir = os.path.join(_save_dir, "foreground") - bd_dir = os.path.join(_save_dir, "boundary") - ws1_dir = os.path.join(_save_dir, "watershed1") - ws2_dir = os.path.join(_save_dir, "watershed2") - - gt_dir = os.path.join(input_path, "annotations", "livecell_test_images", c2, "*") - cwise_fg, cwise_bd = [], [] - cwise_ws1_msa, cwise_ws2_msa, cwise_ws1_sa50, cwise_ws2_sa50 = [], [], [], [] - for gt_path in glob(gt_dir): - fname = os.path.split(gt_path)[-1] - - gt = imageio.imread(gt_path) - fg = imageio.imread(os.path.join(fg_dir, fname)) - bd = imageio.imread(os.path.join(bd_dir, fname)) - ws1 = imageio.imread(os.path.join(ws1_dir, fname)) - ws2 = imageio.imread(os.path.join(ws2_dir, fname)) - - true_bd = find_boundaries(gt) - - # Compare the foreground prediction to the ground-truth. - # Here, it's important not to threshold the segmentation. Otherwise EVERYTHING will be set to - # foreground in the dice function, since we have a comparision > 0 in there, and everything in the - # binary prediction evaluates to true. - # For the GT we can set the threshold to 0, because this will map to the correct binary mask. - cwise_fg.append(dice_score(fg, gt, threshold_gt=0, threshold_seg=None)) - - # Compare the background prediction to the ground-truth. 
- # Here, we don't need any thresholds: for the prediction the same holds as before. - # For the ground-truth we have already a binary label, so we don't need to threshold it again. - cwise_bd.append(dice_score(bd, true_bd, threshold_gt=None, threshold_seg=None)) - - msa1, sa_acc1 = mean_segmentation_accuracy(ws1, gt, return_accuracies=True) # type: ignore - msa2, sa_acc2 = mean_segmentation_accuracy(ws2, gt, return_accuracies=True) # type: ignore - - cwise_ws1_msa.append(msa1) - cwise_ws2_msa.append(msa2) - cwise_ws1_sa50.append(sa_acc1[0]) - cwise_ws2_sa50.append(sa_acc2[0]) - - fg_set[c2] = np.mean(cwise_fg) - bd_set[c2] = np.mean(cwise_bd) - ws1_msa_set[c2] = np.mean(cwise_ws1_msa) - ws2_msa_set[c2] = np.mean(cwise_ws2_msa) - ws1_sa50_set[c2] = np.mean(cwise_ws1_sa50) - ws2_sa50_set[c2] = np.mean(cwise_ws2_sa50) - - fg_list.append(pd.DataFrame.from_dict([fg_set])) # type: ignore - bd_list.append(pd.DataFrame.from_dict([bd_set])) # type: ignore - ws1_msa_list.append(pd.DataFrame.from_dict([ws1_msa_set])) # type: ignore - ws2_msa_list.append(pd.DataFrame.from_dict([ws2_msa_set])) # type: ignore - ws1_sa50_list.append(pd.DataFrame.from_dict([ws1_sa50_set])) # type: ignore - ws2_sa50_list.append(pd.DataFrame.from_dict([ws2_sa50_set])) # type: ignore - - f_df_fg = pd.concat(fg_list, ignore_index=True) - f_df_bd = pd.concat(bd_list, ignore_index=True) - f_df_ws1_msa = pd.concat(ws1_msa_list, ignore_index=True) - f_df_ws2_msa = pd.concat(ws2_msa_list, ignore_index=True) - f_df_ws1_sa50 = pd.concat(ws1_sa50_list, ignore_index=True) - f_df_ws2_sa50 = pd.concat(ws2_sa50_list, ignore_index=True) - - tmp_csv_name = f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch" - csv_save_dir = f"./results/{tmp_csv_name}" - os.makedirs(csv_save_dir, exist_ok=True) - - f_df_fg.to_csv(os.path.join(csv_save_dir, "foreground-dice.csv")) - f_df_bd.to_csv(os.path.join(csv_save_dir, "boundary-dice.csv")) - f_df_ws1_msa.to_csv(os.path.join(csv_save_dir, "watershed1-msa.csv")) - f_df_ws2_msa.to_csv(os.path.join(csv_save_dir, "watershed2-msa.csv")) - f_df_ws1_sa50.to_csv(os.path.join(csv_save_dir, "watershed1-sa50.csv")) - f_df_ws2_sa50.to_csv(os.path.join(csv_save_dir, "watershed2-sa50.csv")) - - -def main(args): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - patch_shape = (512, 512) - output_channels = 2 - - all_cell_types = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"] - - model = get_unetr_model( - model_name=args.model_name, - source_choice=args.source_choice, - patch_shape=patch_shape, - sam_initialization=args.do_sam_ini, - output_channels=output_channels - ) - model.to(device) - - if args.train: - print("2d UNETR training on LiveCell dataset") - do_unetr_training( - input_path=args.input, - model=model, - cell_types=args.cell_type, - patch_shape=patch_shape, - device=device, - save_root=args.save_root, - iterations=args.iterations, - sam_initialization=args.do_sam_ini, - source_choice=args.source_choice - ) - - root_save_dir = os.path.join( - args.save_dir, - f"unetr-{args.source_choice}-sam" if args.do_sam_ini else f"unetr-{args.source_choice}-scratch" - ) - print("Predictions are saved in", root_save_dir) - - if args.predict: - print("2d UNETR inference on LiveCell dataset") - do_unetr_inference( - input_path=args.input, - device=device, - model=model, - cell_types=all_cell_types, - root_save_dir=root_save_dir, - sam_initialization=args.do_sam_ini, - save_root=args.save_root, - source_choice=args.source_choice - ) - - if args.evaluate: 
- print("2d UNETR evaluation on LiveCell dataset") - do_unetr_evaluation( - input_path=args.input, - cell_types=all_cell_types, - root_save_dir=root_save_dir, - sam_initialization=args.do_sam_ini, - source_choice=args.source_choice - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--train", action='store_true', - help="Enables UNETR training on LiveCell dataset") - - parser.add_argument("--predict", action='store_true', - help="Enables UNETR prediction on LiveCell dataset") - - parser.add_argument("--evaluate", action='store_true', - help="Enables UNETR evaluation on LiveCell dataset") - - parser.add_argument("--source_choice", type=str, default="torch-em", - help="The source where the model comes from, i.e. either torch-em / monai") - - parser.add_argument("-m", "--model_name", type=str, default="vit_b", - help="Name of the ViT to use as the encoder in UNETR") - - parser.add_argument("--do_sam_ini", action='store_true', - help="Enables initializing UNETR with SAM's ViT weights") - - parser.add_argument("-c", "--cell_type", type=str, default=None, - help="Choice of cell-type for doing the training") - - parser.add_argument("-i", "--input", type=str, default="/scratch/usr/nimanwai/data/livecell", - help="Path where the dataset already exists/will be downloaded by the dataloader") - - parser.add_argument("-s", "--save_root", type=str, default="/scratch/usr/nimanwai/models/unetr/", - help="Path where checkpoints and logs will be saved") - - parser.add_argument("--save_dir", type=str, default="/scratch/usr/nimanwai/predictions/unetr", - help="Path to save predictions from UNETR model") - - parser.add_argument("--iterations", type=int, default=100000) - - args = parser.parse_args() - main(args) diff --git a/torch_em/util/segmentation.py b/torch_em/util/segmentation.py index d1c27441..2327a60d 100644 --- a/torch_em/util/segmentation.py +++ b/torch_em/util/segmentation.py @@ -3,6 +3,7 @@ import vigra import elf.segmentation as elseg from elf.segmentation.utils import normalize_input +from elf.segmentation.mutex_watershed import mutex_watershed from skimage.measure import label from skimage.filters import gaussian @@ -34,10 +35,21 @@ def size_filter(seg, min_size, hmap=None, with_background=False): return seg -def mutex_watershed(affinities, offsets, mask=None, strides=None): - return elseg.mutex_watershed( - affinities, offsets, mask=mask, strides=strides, randomize_strides=True - ).astype("uint64") +def mutex_watershed_segmentation(foreground, affinities, offsets, min_size=250, threshold=0.5): + """Computes the mutex watershed segmentation using the affinity maps for respective pixel offsets + + Arguments: + - foreground: [np.ndarray] - The foreground background channel for the objects + - affinities [np.ndarray] - The input affinity maps + - offsets: [list[list[int]]] - The pixel offsets corresponding to the affinity channels + - min_size: [int] - The minimum pixels (below which) to filter objects + - threshold: [float] - To threshold foreground predictions + """ + mask = (foreground >= threshold) + strides = [2] * foreground.ndim + seg = mutex_watershed(affinities, offsets=offsets, mask=mask, strides=strides, randomize_strides=True) + seg = size_filter(seg.astype("uint32"), min_size=min_size, hmap=affinities, with_background=True) + return seg def connected_components_with_boundaries(foreground, boundaries, threshold=0.5): @@ -48,7 +60,7 @@ def connected_components_with_boundaries(foreground, boundaries, threshold=0.5): return seg.astype("uint64") 
-def watershed_from_components(boundaries, foreground, min_size, threshold1=0.5, threshold2=0.5): +def watershed_from_components(boundaries, foreground, min_size=250, threshold1=0.5, threshold2=0.5): """The default approach: - Subtract the boundaries from the foreground to separate touching objects. - Use the connected components of this as seeds. @@ -72,7 +84,7 @@ def watershed_from_components(boundaries, foreground, min_size, threshold1=0.5, return seg -def watershed_from_maxima(boundaries, foreground, min_size, min_distance, sigma=1.0, threshold1=0.5): +def watershed_from_maxima(boundaries, foreground, min_distance, min_size=250, sigma=1.0, threshold1=0.5): """Find objects via seeded watershed starting from the maxima of the distance transform instead. This has the advantage that objects can be better separated, but it may over-segment if the objects have complex shapes.
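
For context, the new mutex_watershed_segmentation helper can be exercised on its own; below is a minimal sketch with random arrays standing in for real network predictions (shapes and offsets follow the 2D LIVECell setup above; note that randomize_strides makes the result non-deterministic):

import numpy as np

from torch_em.util.segmentation import mutex_watershed_segmentation

offsets = [
    [-1, 0], [0, -1],
    [-3, 0], [0, -3],
    [-9, 0], [0, -9],
    [-27, 0], [0, -27]
]

# stand-ins for the foreground channel and the per-offset affinity channels
foreground = np.random.rand(512, 512).astype("float32")
affinities = np.random.rand(len(offsets), 512, 512).astype("float32")

# pixels with foreground < 0.5 are masked out; objects below 250 pixels are filtered
seg = mutex_watershed_segmentation(foreground, affinities, offsets=offsets)
print(seg.shape)  # (512, 512) instance segmentation, with 0 as background
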