From 4fd38ecdf2a8e806b57346af60a76a1c4ca21db2 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Wed, 13 Sep 2023 15:58:20 +0200 Subject: [PATCH 1/5] Fix UNETR Training + Update Initialization for SAM --- .../unetr/initialize_with_sam.py | 11 ++- .../unetr/livecell_unetr.py | 30 ++++--- torch_em/model/__init__.py | 2 +- torch_em/model/unetr.py | 80 ++++++++++++++----- 4 files changed, 85 insertions(+), 38 deletions(-) diff --git a/experiments/vision-transformer/unetr/initialize_with_sam.py b/experiments/vision-transformer/unetr/initialize_with_sam.py index 75d80d3a..e325f674 100644 --- a/experiments/vision-transformer/unetr/initialize_with_sam.py +++ b/experiments/vision-transformer/unetr/initialize_with_sam.py @@ -1,9 +1,14 @@ import torch -from torch_em.model.unetr import build_unetr_with_sam_intialization +from torch_em.model.unetr import build_unetr_with_sam_initialization # FIXME this doesn't work yet -model = build_unetr_with_sam_intialization() -x = torch.randn(1, 3, 1024, 1024) +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +model = build_unetr_with_sam_initialization( + checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth") +model.to(device) + +x = torch.randn(1, 3, 1024, 1024).to(device=device) y = model(x) print(y.shape) diff --git a/experiments/vision-transformer/unetr/livecell_unetr.py b/experiments/vision-transformer/unetr/livecell_unetr.py index bdfc4f82..0d895824 100644 --- a/experiments/vision-transformer/unetr/livecell_unetr.py +++ b/experiments/vision-transformer/unetr/livecell_unetr.py @@ -3,7 +3,7 @@ import torch import torch_em -from torch_em.model import UNETR +from torch_em.model.unetr import build_unetr_with_sam_initialization from torch_em.data.datasets import get_livecell_loader @@ -16,7 +16,7 @@ def do_unetr_training(data_path: str, save_root: str, cell_type: list, iteration batch_size=2, cell_types=cell_type, download=True, - binary=True + boundaries=True ) val_loader = get_livecell_loader( @@ -26,11 +26,15 @@ def do_unetr_training(data_path: str, save_root: str, cell_type: list, iteration batch_size=1, cell_types=cell_type, download=True, - binary=True + boundaries=True ) - model = UNETR(out_channels=1, - initialize_from_sam=True) + n_channels = 2 + + model = build_unetr_with_sam_initialization( + out_channels=n_channels, + checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth" + ) model.to(device) trainer = torch_em.default_segmentation_trainer( @@ -39,7 +43,7 @@ def do_unetr_training(data_path: str, save_root: str, cell_type: list, iteration train_loader=train_loader, val_loader=val_loader, device=device, - learning_rate=1.0e-4, + learning_rate=1e-5, log_image_interval=10, save_root=save_root, compile_model=False @@ -53,12 +57,14 @@ def main(args): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if args.train: - print("Training a 2D UNETR on LiveCELL dataset") - do_unetr_training(data_path=args.inputs, - save_root=args.save_root, - cell_type=args.cell_type, - iterations=args.iterations, - device=device) + print("Training a 2D UNETR on LiveCell dataset") + do_unetr_training( + data_path=args.inputs, + save_root=args.save_root, + cell_type=args.cell_type, + iterations=args.iterations, + device=device + ) if __name__ == "__main__": diff --git a/torch_em/model/__init__.py b/torch_em/model/__init__.py index 7ffe347a..ac147f17 100644 --- a/torch_em/model/__init__.py +++ b/torch_em/model/__init__.py @@ -1,3 +1,3 @@ from .unet import 
AnisotropicUNet, UNet2d, UNet3d from .probabilistic_unet import ProbabilisticUNet -from .unetr import UNETR, build_unetr_with_sam_intialization +from .unetr import UNETR, build_unetr_with_sam_initialization diff --git a/torch_em/model/unetr.py b/torch_em/model/unetr.py index 244b93a3..580ab46c 100644 --- a/torch_em/model/unetr.py +++ b/torch_em/model/unetr.py @@ -21,7 +21,7 @@ get_sam_model = None -class ViTb_Sam(ImageEncoderViT): +class ViT_Sam(ImageEncoderViT): """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643): https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py """ @@ -112,10 +112,12 @@ def window_unpartition( class UNETR(nn.Module): def __init__( self, - encoder=None, + encoder="vit_b", decoder=None, out_channels=1, - use_sam_preprocessing=False, + use_sam_preprocessing=True, + initialize_from_sam=False, + checkpoint_path=None ) -> None: depth = 3 initial_features = 64 @@ -127,8 +129,8 @@ def __init__( super().__init__() - if encoder is None: - self.encoder = ViTb_Sam( + if encoder == "vit_b": + self.encoder = ViT_Sam( depth=12, embed_dim=768, img_size=1024, @@ -142,8 +144,51 @@ def __init__( window_size=14, out_chans=256, ) + + elif encoder == "vit_l": + self.encoder = ViT_Sam( + depth=24, + embed_dim=1024, + img_size=1024, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), # type: ignore + num_heads=16, + patch_size=16, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=[5, 11, 17, 23], # type: ignore + window_size=14, + out_chans=256 + ) + + elif encoder == "vit_h": + self.encoder = ViT_Sam( + depth=32, + embed_dim=1280, + img_size=1024, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), # type: ignore + num_heads=16, + patch_size=16, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=[7, 15, 23, 31], # type: ignore + window_size=14, + out_chans=256 + ) + else: - self.encoder = encoder + raise NotImplementedError + + if initialize_from_sam: + assert checkpoint_path is not None + _, model = get_sam_model( + model_type=encoder, + checkpoint_path=checkpoint_path, + return_sam=True + ) # type: ignore + for param1, param2 in zip(model.parameters(), self.encoder.parameters()): + param2.data = param1 if decoder is None: self.decoder = Decoder( @@ -206,7 +251,7 @@ def forward(self, x): z0 = self.z_inputs(x) - z12, from_encoder = self.encoder(x) + z12, from_encoder = self.encoder(x) # type: ignore x = self.base(z12) from_encoder = from_encoder[::-1] @@ -279,21 +324,12 @@ def forward(self, x): return self.block(x) -def build_unetr_with_sam_intialization(out_channels=1, model_type="vit_b", checkpoint_path=None): - if get_sam_model is None: - raise RuntimeError( - "micro_sam is required to initialize the UNETR image encoder from segment anything weights." - "Please install it from https://github.com/computational-cell-analytics/micro-sam" - "and then rerun your code." 
- ) - predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path) - _image_encoder = predictor.model.image_encoder +def build_unetr_with_sam_initialization(out_channels=1, model_type="vit_b", checkpoint_path=None): + unetr = UNETR(encoder=model_type, out_channels=out_channels, + initialize_from_sam=True, checkpoint_path=checkpoint_path) + return unetr - image_encoder = ViTb_Sam() - # FIXME this doesn't work yet because the parameters don't match - with torch.no_grad(): - for param1, param2 in zip(_image_encoder.parameters(), image_encoder.parameters()): - param2.data = param1.data - unetr = UNETR(encoder=image_encoder, out_channels=out_channels) +def build_unetr_without_sam_initialization(out_channels=1, model_type="vit_b"): + unetr = UNETR(encoder=model_type, out_channels=out_channels) return unetr From efe36aaee4f2d11eef92d94554a6b27443758cf6 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Wed, 13 Sep 2023 16:29:43 +0200 Subject: [PATCH 2/5] Add UNETR Training for Cremi Dataset --- .../vision-transformer/unetr/cremi_unetr.py | 80 +++++++++++++++++++ .../unetr/livecell_unetr.py | 11 ++- 2 files changed, 87 insertions(+), 4 deletions(-) create mode 100644 experiments/vision-transformer/unetr/cremi_unetr.py diff --git a/experiments/vision-transformer/unetr/cremi_unetr.py b/experiments/vision-transformer/unetr/cremi_unetr.py new file mode 100644 index 00000000..69bcbd07 --- /dev/null +++ b/experiments/vision-transformer/unetr/cremi_unetr.py @@ -0,0 +1,80 @@ +import os +import argparse +import numpy as np + +import torch +import torch_em +from torch_em.model.unetr import build_unetr_with_sam_initialization +from torch_em.data.datasets import get_cremi_loader + + +def do_unetr_training(data_path: str, save_root: str, iterations: int, device, patch_shape=(1, 512, 512)): + os.makedirs(data_path, exist_ok=True) + + cremi_train_rois = {"A": np.s_[0:75, :, :], "B": np.s_[0:75, :, :], "C": np.s_[0:75, :, :]} + cremi_val_rois = {"A": np.s_[75:100, :, :], "B": np.s_[75:100, :, :], "C": np.s_[75:100, :, :]} + + train_loader = get_cremi_loader( + path=data_path, + patch_shape=patch_shape, download=True, + rois=cremi_train_rois, + ndim=2, + defect_augmentation_kwargs=None, + boundaries=True, + batch_size=2 + ) + + val_loader = get_cremi_loader( + path=data_path, + patch_shape=patch_shape, download=True, + rois=cremi_val_rois, + ndim=2, + defect_augmentation_kwargs=None, + boundaries=True, + batch_size=1 + ) + + model = build_unetr_with_sam_initialization( + checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth" + ) + model.to(device) + + trainer = torch_em.default_segmentation_trainer( + name="unetr-cremi", + model=model, + train_loader=train_loader, + val_loader=val_loader, + device=device, + learning_rate=1e-5, + log_image_interval=10, + save_root=save_root, + compile_model=False + ) + + trainer.fit(iterations) + + +def main(args): + print(torch.cuda.get_device_name() if torch.cuda.is_available() else "GPU not available, hence running on CPU") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if args.train: + print("Training a 2D UNETR on Cremi dataset") + do_unetr_training( + data_path=args.inputs, + save_root=args.save_root, + iterations=args.iterations, + device=device + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--train", action='store_true', help="Enables UNETR training on Cremi dataset") + parser.add_argument("-i", "--inputs", type=str, default="./cremi/", + 
help="Path where the dataset already exists/will be downloaded by the dataloader") + parser.add_argument("-s", "--save_root", type=str, default=None, + help="Path where checkpoints and logs will be saved") + parser.add_argument("--iterations", type=int, default=100000, help="No. of iterations to run the training for") + args = parser.parse_args() + main(args) diff --git a/experiments/vision-transformer/unetr/livecell_unetr.py b/experiments/vision-transformer/unetr/livecell_unetr.py index 0d895824..e574fcf7 100644 --- a/experiments/vision-transformer/unetr/livecell_unetr.py +++ b/experiments/vision-transformer/unetr/livecell_unetr.py @@ -38,7 +38,7 @@ def do_unetr_training(data_path: str, save_root: str, cell_type: list, iteration model.to(device) trainer = torch_em.default_segmentation_trainer( - name=f"unet-source-livecell-{cell_type[0]}", + name=f"unetr-source-livecell-{cell_type[0]}", model=model, train_loader=train_loader, val_loader=val_loader, @@ -70,9 +70,12 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--train", action='store_true', help="Enables UNETR training on LiveCELL dataset") - parser.add_argument("-c", "--cell_type", nargs='+', default=["A172"], help="Choice of cell-type for doing the training") - parser.add_argument("-i", "--inputs", type=str, default="./livecell/", help="Path where the dataset already exists/will be downloaded by the dataloader") - parser.add_argument("-s", "--save_root", type=str, default=None, help="Path where checkpoints and logs will be saved") + parser.add_argument("-c", "--cell_type", nargs='+', default=["A172"], + help="Choice of cell-type for doing the training") + parser.add_argument("-i", "--inputs", type=str, default="./livecell/", + help="Path where the dataset already exists/will be downloaded by the dataloader") + parser.add_argument("-s", "--save_root", type=str, default=None, + help="Path where checkpoints and logs will be saved") parser.add_argument("--iterations", type=int, default=100000, help="No. of iterations to run the training for") args = parser.parse_args() main(args) From b5d96af7c23b5af8f5afc6567bcdcffe9adcc96d Mon Sep 17 00:00:00 2001 From: anwai98 Date: Wed, 13 Sep 2023 18:12:26 +0200 Subject: [PATCH 3/5] Fix and Test different ViT models --- experiments/vision-transformer/unetr/initialize_with_sam.py | 3 ++- torch_em/model/unetr.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/experiments/vision-transformer/unetr/initialize_with_sam.py b/experiments/vision-transformer/unetr/initialize_with_sam.py index e325f674..c97d1a96 100644 --- a/experiments/vision-transformer/unetr/initialize_with_sam.py +++ b/experiments/vision-transformer/unetr/initialize_with_sam.py @@ -5,7 +5,8 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = build_unetr_with_sam_initialization( - checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth") + model_type="vit_h", + checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_h_4b8939.pth") model.to(device) x = torch.randn(1, 3, 1024, 1024).to(device=device) diff --git a/torch_em/model/unetr.py b/torch_em/model/unetr.py index 580ab46c..d53c4c0a 100644 --- a/torch_em/model/unetr.py +++ b/torch_em/model/unetr.py @@ -39,7 +39,7 @@ def __init__( "and then rerun your code." 
) - super().__init__(**kwargs) + super().__init__(embed_dim=embed_dim, **kwargs) self.chunks_for_projection = global_attn_indexes self.in_chans = in_chans self.embed_dim = embed_dim From c12b7440588623e2c4f4f7ca7eb589390834bee1 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Wed, 13 Sep 2023 21:33:50 +0200 Subject: [PATCH 4/5] Update Minor Fixes --- .../vision-transformer/unetr/README.md | 13 +++++++++++ .../vision-transformer/unetr/cremi_unetr.py | 8 +++---- .../unetr/initialize_with_sam.py | 8 +++---- .../unetr/livecell_unetr.py | 9 ++++---- torch_em/model/__init__.py | 2 +- torch_em/model/unetr.py | 23 ++++--------------- 6 files changed, 30 insertions(+), 33 deletions(-) create mode 100644 experiments/vision-transformer/unetr/README.md diff --git a/experiments/vision-transformer/unetr/README.md b/experiments/vision-transformer/unetr/README.md new file mode 100644 index 00000000..2da3a87d --- /dev/null +++ b/experiments/vision-transformer/unetr/README.md @@ -0,0 +1,13 @@ +## Usage of SAM's ViT initialization in UNETR + +### Initialize ViT models for UNETR +``` +from torch_em.model import UNETR +unetr = UNETR(encoder=model_type, out_channels=out_channels, encoder_checkpoint_path=checkpoint_path) +``` + +### Vanilla ViT models for UNETR +``` +from torch_em.model import UNETR +unetr = UNETR(encoder=model_type, out_channels=out_channels) +``` \ No newline at end of file diff --git a/experiments/vision-transformer/unetr/cremi_unetr.py b/experiments/vision-transformer/unetr/cremi_unetr.py index 69bcbd07..c4696ee1 100644 --- a/experiments/vision-transformer/unetr/cremi_unetr.py +++ b/experiments/vision-transformer/unetr/cremi_unetr.py @@ -4,7 +4,7 @@ import torch import torch_em -from torch_em.model.unetr import build_unetr_with_sam_initialization +from torch_em.model import UNETR from torch_em.data.datasets import get_cremi_loader @@ -34,9 +34,9 @@ def do_unetr_training(data_path: str, save_root: str, iterations: int, device, p batch_size=1 ) - model = build_unetr_with_sam_initialization( - checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth" - ) + model = UNETR( + encoder="vit_b", out_channels=1, + encoder_checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth") model.to(device) trainer = torch_em.default_segmentation_trainer( diff --git a/experiments/vision-transformer/unetr/initialize_with_sam.py b/experiments/vision-transformer/unetr/initialize_with_sam.py index c97d1a96..7ba7f9fa 100644 --- a/experiments/vision-transformer/unetr/initialize_with_sam.py +++ b/experiments/vision-transformer/unetr/initialize_with_sam.py @@ -1,12 +1,10 @@ import torch -from torch_em.model.unetr import build_unetr_with_sam_initialization +from torch_em.model import UNETR -# FIXME this doesn't work yet device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -model = build_unetr_with_sam_initialization( - model_type="vit_h", - checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_h_4b8939.pth") +model = UNETR(encoder="vit_h", out_channels=1, + encoder_checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_h_4b8939.pth") model.to(device) x = torch.randn(1, 3, 1024, 1024).to(device=device) diff --git a/experiments/vision-transformer/unetr/livecell_unetr.py b/experiments/vision-transformer/unetr/livecell_unetr.py index e574fcf7..aab00c94 100644 --- a/experiments/vision-transformer/unetr/livecell_unetr.py +++ b/experiments/vision-transformer/unetr/livecell_unetr.py 
@@ -3,7 +3,7 @@ import torch import torch_em -from torch_em.model.unetr import build_unetr_with_sam_initialization +from torch_em.model import UNETR from torch_em.data.datasets import get_livecell_loader @@ -31,10 +31,9 @@ def do_unetr_training(data_path: str, save_root: str, cell_type: list, iteration n_channels = 2 - model = build_unetr_with_sam_initialization( - out_channels=n_channels, - checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth" - ) + model = UNETR( + encoder="vit_b", out_channels=n_channels, + encoder_checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth") model.to(device) trainer = torch_em.default_segmentation_trainer( diff --git a/torch_em/model/__init__.py b/torch_em/model/__init__.py index ac147f17..54b1264d 100644 --- a/torch_em/model/__init__.py +++ b/torch_em/model/__init__.py @@ -1,3 +1,3 @@ from .unet import AnisotropicUNet, UNet2d, UNet3d from .probabilistic_unet import ProbabilisticUNet -from .unetr import UNETR, build_unetr_with_sam_initialization +from .unetr import UNETR diff --git a/torch_em/model/unetr.py b/torch_em/model/unetr.py index d53c4c0a..5d1d80a1 100644 --- a/torch_em/model/unetr.py +++ b/torch_em/model/unetr.py @@ -21,7 +21,7 @@ get_sam_model = None -class ViT_Sam(ImageEncoderViT): +class ViT_Sam(ImageEncoderViT): # type: ignore """Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643): https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py """ @@ -116,8 +116,7 @@ def __init__( decoder=None, out_channels=1, use_sam_preprocessing=True, - initialize_from_sam=False, - checkpoint_path=None + encoder_checkpoint_path=None ) -> None: depth = 3 initial_features = 64 @@ -178,13 +177,12 @@ def __init__( ) else: - raise NotImplementedError + raise ValueError(f"{encoder} is not supported. 
Currently only vit_b, vit_l, vit_h are supported.") - if initialize_from_sam: - assert checkpoint_path is not None + if encoder_checkpoint_path is not None: _, model = get_sam_model( model_type=encoder, - checkpoint_path=checkpoint_path, + checkpoint_path=encoder_checkpoint_path, return_sam=True ) # type: ignore for param1, param2 in zip(model.parameters(), self.encoder.parameters()): @@ -322,14 +320,3 @@ def __init__(self, in_planes, out_planes, kernel_size=3): def forward(self, x): return self.block(x) - - -def build_unetr_with_sam_initialization(out_channels=1, model_type="vit_b", checkpoint_path=None): - unetr = UNETR(encoder=model_type, out_channels=out_channels, - initialize_from_sam=True, checkpoint_path=checkpoint_path) - return unetr - - -def build_unetr_without_sam_initialization(out_channels=1, model_type="vit_b"): - unetr = UNETR(encoder=model_type, out_channels=out_channels) - return unetr From 57df2bbbd75730b84fcf1dc6dbc52cb92a54891f Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Wed, 13 Sep 2023 21:44:14 +0200 Subject: [PATCH 5/5] Update README.md - Add usage details --- experiments/vision-transformer/unetr/README.md | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/experiments/vision-transformer/unetr/README.md b/experiments/vision-transformer/unetr/README.md index 2da3a87d..43a6d7ca 100644 --- a/experiments/vision-transformer/unetr/README.md +++ b/experiments/vision-transformer/unetr/README.md @@ -1,6 +1,12 @@ -## Usage of SAM's ViT initialization in UNETR +## SAM's ViT Initialization in UNETR -### Initialize ViT models for UNETR +Note: +- `model_type` - [`vit_b`/`vit_l`/`vit_h`] +- `out_channels` - Number of output channels +- `encoder_checkpoint_path` - Pass the checkpoints from the pretrained [Segment Anything](https://github.com/facebookresearch/segment-anything) models to initialize the SAM weights to the (ViT) encoder backbone (Click on the model names to download them - [ViT-b](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) / [ViT-l](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth) / [ViT-h](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth)) + + +### How to initialize ViT models for UNETR? ``` from torch_em.model import UNETR unetr = UNETR(encoder=model_type, out_channels=out_channels, encoder_checkpoint_path=checkpoint_path) @@ -10,4 +16,4 @@ unetr = UNETR(encoder=model_type, out_channels=out_channels, encoder_checkpoint_ ``` from torch_em.model import UNETR unetr = UNETR(encoder=model_type, out_channels=out_channels) -``` \ No newline at end of file +```
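+
+### Example: full forward pass with a SAM-initialized UNETR
+A minimal sketch, assuming a locally downloaded `vit_b` SAM checkpoint (the path below is a placeholder); it mirrors `initialize_with_sam.py` from this PR:
+```
+import torch
+from torch_em.model import UNETR
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# placeholder path - download sam_vit_b_01ec64.pth from the ViT-b link above
+model = UNETR(encoder="vit_b", out_channels=1,
+              encoder_checkpoint_path="./sam_vit_b_01ec64.pth")
+model.to(device)
+
+x = torch.randn(1, 3, 1024, 1024).to(device)  # the SAM ViT backbone expects 1024x1024 inputs
+y = model(x)
+print(y.shape)
+```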