From 87441ebf66605f9d4378e5f5650db9731484c14c Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Tue, 5 Dec 2023 16:15:29 +0100 Subject: [PATCH] Fix UNETR Model for LIVECell (#174) * Update unetr segmentation --- .../vision-transformer/unetr/livecell/common.py | 12 ++++++++---- .../unetr/livecell/livecell_all_unetr.py | 5 +++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/experiments/vision-transformer/unetr/livecell/common.py b/experiments/vision-transformer/unetr/livecell/common.py index 33bbbc52..d4497cf7 100644 --- a/experiments/vision-transformer/unetr/livecell/common.py +++ b/experiments/vision-transformer/unetr/livecell/common.py @@ -93,7 +93,8 @@ def get_unetr_model( source_choice: str, patch_shape: Tuple[int, int], sam_initialization: bool, - output_channels: int + output_channels: int, + backbone: str = "sam" ): """Returns the expected UNETR model """ @@ -101,8 +102,9 @@ def get_unetr_model( # this returns the unetr model whihc uses the vision transformer from segment anything from torch_em import model as torch_em_models model = torch_em_models.UNETR( - encoder=model_name, out_channels=output_channels, - encoder_checkpoint_path=MODELS[model_name] if sam_initialization else None + backbone=backbone, encoder=model_name, out_channels=output_channels, + encoder_checkpoint_path=MODELS[model_name] if sam_initialization else None, + use_sam_stats=sam_initialization # FIXME: add mae weight initialization ) elif source_choice == "monai": @@ -117,7 +119,9 @@ def get_unetr_model( model.out_channels = 2 # type: ignore else: - raise ValueError(f"The available UNETR models are either from \"torch-em\" or \"monai\", choose from them instead of - {source_choice}") + tmp_msg = "The available UNETR models are either from `torch-em` or `monai`. " + tmp_msg += f"Please choose from them instead of {source_choice}" + raise ValueError(tmp_msg) return model diff --git a/experiments/vision-transformer/unetr/livecell/livecell_all_unetr.py b/experiments/vision-transformer/unetr/livecell/livecell_all_unetr.py index 2cad35a9..862848c1 100644 --- a/experiments/vision-transformer/unetr/livecell/livecell_all_unetr.py +++ b/experiments/vision-transformer/unetr/livecell/livecell_all_unetr.py @@ -126,7 +126,8 @@ def main(args): # get the model for the training and inference on livecell dataset model = common.get_unetr_model( model_name=args.model_name, source_choice=args.source_choice, patch_shape=patch_shape, - sam_initialization=args.do_sam_ini, output_channels=common._get_output_channels(args.with_affinities) + sam_initialization=args.do_sam_ini, output_channels=common._get_output_channels(args.with_affinities), + backbone=args.pretrained_choice ) model.to(device) @@ -137,7 +138,7 @@ def main(args): print("2d UNETR training on LIVECell dataset") # get the desired livecell loaders for training train_loader, val_loader = common.get_my_livecell_loaders( - args.input, patch_shape, args.cell_type, + args.input, patch_shape, args.cell_type, with_boundary=not args.with_affinities, with_affinities=args.with_affinities # this takes care of getting the loaders with affinities ) do_unetr_training(