Skip to content

Commit

Permalink
Fix UNETR Model for LIVECell (#174)
Browse files Browse the repository at this point in the history
* Update unetr segmentation
  • Loading branch information
anwai98 committed Dec 5, 2023
1 parent 4925ec5 commit 87441eb
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
12 changes: 8 additions & 4 deletions experiments/vision-transformer/unetr/livecell/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,18 @@ def get_unetr_model(
source_choice: str,
patch_shape: Tuple[int, int],
sam_initialization: bool,
output_channels: int
output_channels: int,
backbone: str = "sam"
):
"""Returns the expected UNETR model
"""
if source_choice == "torch-em":
# this returns the unetr model whihc uses the vision transformer from segment anything
from torch_em import model as torch_em_models
model = torch_em_models.UNETR(
encoder=model_name, out_channels=output_channels,
encoder_checkpoint_path=MODELS[model_name] if sam_initialization else None
backbone=backbone, encoder=model_name, out_channels=output_channels,
encoder_checkpoint_path=MODELS[model_name] if sam_initialization else None,
use_sam_stats=sam_initialization # FIXME: add mae weight initialization
)

elif source_choice == "monai":
Expand All @@ -117,7 +119,9 @@ def get_unetr_model(
model.out_channels = 2 # type: ignore

else:
raise ValueError(f"The available UNETR models are either from \"torch-em\" or \"monai\", choose from them instead of - {source_choice}")
tmp_msg = "The available UNETR models are either from `torch-em` or `monai`. "
tmp_msg += f"Please choose from them instead of {source_choice}"
raise ValueError(tmp_msg)

return model

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def main(args):
# get the model for the training and inference on livecell dataset
model = common.get_unetr_model(
model_name=args.model_name, source_choice=args.source_choice, patch_shape=patch_shape,
sam_initialization=args.do_sam_ini, output_channels=common._get_output_channels(args.with_affinities)
sam_initialization=args.do_sam_ini, output_channels=common._get_output_channels(args.with_affinities),
backbone=args.pretrained_choice
)
model.to(device)

Expand All @@ -137,7 +138,7 @@ def main(args):
print("2d UNETR training on LIVECell dataset")
# get the desired livecell loaders for training
train_loader, val_loader = common.get_my_livecell_loaders(
args.input, patch_shape, args.cell_type,
args.input, patch_shape, args.cell_type, with_boundary=not args.with_affinities,
with_affinities=args.with_affinities # this takes care of getting the loaders with affinities
)
do_unetr_training(
Expand Down

0 comments on commit 87441eb

Please sign in to comment.