diff --git a/experiments/vision-transformer/unetr/.gitignore b/experiments/vision-transformer/unetr/.gitignore
index 3ac63299..830609da 100644
--- a/experiments/vision-transformer/unetr/.gitignore
+++ b/experiments/vision-transformer/unetr/.gitignore
@@ -1,2 +1,3 @@
 *.sh
-*.out
\ No newline at end of file
+*.out
+*.csv
\ No newline at end of file
diff --git a/experiments/vision-transformer/unetr/livecell/common.py b/experiments/vision-transformer/unetr/livecell/common.py
index cc3d7dc1..ba71c6f9 100644
--- a/experiments/vision-transformer/unetr/livecell/common.py
+++ b/experiments/vision-transformer/unetr/livecell/common.py
@@ -174,7 +174,7 @@ def get_unetr_model(
         model = torch_em_models.UNETR(
             backbone=backbone, encoder=model_name, out_channels=output_channels,
             use_sam_stats=sam_initialization, final_activation="Sigmoid",
-            encoder_checkpoint_path=MODELS[model_name] if sam_initialization else None,
+            encoder_checkpoint=MODELS[model_name] if sam_initialization else None,
         )
 
     elif source_choice == "monai":
@@ -241,7 +241,12 @@ def predict_for_unetr(
     elif with_distances:  # inference using foreground and hv distance maps
         outputs = predict_with_padding(model, input_, device=device, min_divisible=(16, 16))
         fg, cdist, bdist = outputs.squeeze()
-        dm_seg = segmentation.watershed_from_center_and_boundary_distances(cdist, bdist, fg, min_size=50)
+        dm_seg = segmentation.watershed_from_center_and_boundary_distances(
+            cdist, bdist, fg, min_size=50,
+            center_distance_threshold=0.5,
+            boundary_distance_threshold=0.6,
+            distance_smoothing=1.0
+        )
 
     else:  # inference using foreground-boundary inputs - for the unetr training
         outputs = predict_with_halo(
diff --git a/experiments/vision-transformer/unetr/livecell/train_by_parts.py b/experiments/vision-transformer/unetr/livecell/train_by_parts.py
new file mode 100644
index 00000000..abbd09f4
--- /dev/null
+++ b/experiments/vision-transformer/unetr/livecell/train_by_parts.py
@@ -0,0 +1,134 @@
+import os
+from collections import OrderedDict
+
+import torch
+from torch_em import model as torch_em_models
+
+import common
+
+
+def prune_prefix(checkpoint_path):
+    state = torch.load(checkpoint_path, map_location="cpu")
+    model_state = state["model_state"]
+
+    # let's prune the `sam.image_encoder.` prefix of the finetuned models
+    sam_prefix = "sam.image_encoder."
+    updated_model_state = []
+    for k, v in model_state.items():
+        if k.startswith(sam_prefix):
+            updated_model_state.append((k[len(sam_prefix):], v))
+    updated_model_state = OrderedDict(updated_model_state)
+
+    return updated_model_state
+
+
+def get_custom_unetr_model(
+    device, model_name, sam_initialization, output_channels, checkpoint_path, freeze_encoder, joint_training
+):
+    if checkpoint_path is not None:
+        if checkpoint_path.endswith("pt"):  # for finetuned models
+            model_state = prune_prefix(checkpoint_path)
+        else:  # for vanilla sam models
+            model_state = checkpoint_path
+    else:  # when the checkpoint path is None, we train from scratch
+        model_state = checkpoint_path
+
+    model = torch_em_models.UNETR(
+        backbone="sam",
+        encoder=model_name,
+        out_channels=output_channels,
+        use_sam_stats=sam_initialization,
+        final_activation="Sigmoid",
+        encoder_checkpoint=model_state,
+        use_skip_connection=not joint_training  # if joint_training, don't use skip connections; else use them by default
+    )
+
+    model.to(device)
+
+    # if desired, let's freeze the image encoder
+    if freeze_encoder:
+        for name, param in model.named_parameters():
+            if name.startswith("encoder"):
+                param.requires_grad = False
+
+    return model
+
+
+def main(args):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # overwrite to use complex device setups
+    patch_shape = (512, 512)  # patch size used for training on livecell
+
+    # directory structure to save the different parts of the scheme
+    dir_structure = os.path.join(
+        args.model_name, f"freeze_encoder_{args.freeze_encoder}", "distances", "dicebaseddistloss",
+        f"{args.source_choice}-sam" if args.do_sam_ini else f"{args.source_choice}-scratch"
+    )
+
+    # get the desired loss function for training
+    loss = common.get_loss_function(with_distances=True, combine_dist_with_dice=True)
+
+    # get the custom model for the training and inference on livecell dataset
+    model = get_custom_unetr_model(
+        device, args.model_name, sam_initialization=args.do_sam_ini, output_channels=3,
+        checkpoint_path=args.checkpoint, freeze_encoder=args.freeze_encoder, joint_training=args.joint_training
+    )
+
+    # determines where to save the checkpoints and tensorboard logs
+    save_root = os.path.join(args.save_root, dir_structure) if args.save_root is not None else args.save_root
+
+    # determines the directory where the predictions will be saved
+    root_save_dir = os.path.join(args.save_dir, dir_structure)
+
+    if args.train:
+        print("2d (custom) UNETR training (with distances) on LiveCELL...")
+
+        # get the desired livecell loaders for training
+        train_loader, val_loader = common.get_my_livecell_loaders(
+            args.input, patch_shape, args.cell_type, with_distances=True,
+            input_norm=not args.do_sam_ini
+        )
+
+        common.do_unetr_training(
+            train_loader=train_loader, val_loader=val_loader, model=model, loss=loss,
+            device=device, save_root=save_root, iterations=args.iterations
+        )
+
+    if args.predict:
+        print("2d (custom) UNETR inference (with distances) on LiveCELL...")
+        common.do_unetr_inference(
+            input_path=args.input, device=device, model=model, save_root=save_root,
+            root_save_dir=root_save_dir, with_distances=True,
+            # the logic written for `input_norm` is complicated, but the idea is simple:
+            # - we should standardize the inputs when we DO NOT use SAM initialization
+            # - we should not standardize the inputs when we use SAM initialization
+            input_norm=not args.do_sam_ini
+        )
+        print("Predictions are saved in", root_save_dir)
+
+    if args.evaluate:
+        print("2d (custom) UNETR evaluation (with distances) on LiveCELL...")
+        csv_save_dir = os.path.join("results", dir_structure)
+        os.makedirs(csv_save_dir, exist_ok=True)
+
+        common.do_unetr_evaluation(
+            input_path=args.input, root_save_dir=root_save_dir, csv_save_dir=csv_save_dir, with_distances=True
+        )
+
+
+# we train three setups:
+# - training from scratch, seeing the performance using instance segmentation
+# - training from vanilla SAM, seeing the performance using instance segmentation
+# - training from finetuned SAM, seeing the performance using instance segmentation
+if __name__ == "__main__":
+    parser = common.get_parser()
+    parser.add_argument(
+        "--checkpoint", type=str, default=None, help="The path to the checkpoint of the specific pretrained model."
+    )
+    parser.add_argument(
+        "--freeze_encoder", action="store_true", help="Whether to freeze the image encoder."
+    )
+    parser.add_argument(
+        "--joint_training", action="store_true", help="Uses VNETR for training"
+    )
+    args = parser.parse_args()
+    main(args)
diff --git a/scripts/vision_transformer/load_sam_encoder_in_unetr.py b/scripts/vision_transformer/load_sam_encoder_in_unetr.py
new file mode 100644
index 00000000..6a6158ce
--- /dev/null
+++ b/scripts/vision_transformer/load_sam_encoder_in_unetr.py
@@ -0,0 +1,35 @@
+import torch
+
+from torch_em.model import UNETR
+
+from micro_sam.util import get_sam_model
+
+
+def main():
+    checkpoint = "/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth"
+    model_type = "vit_b"
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    predictor = get_sam_model(
+        model_type=model_type,
+        checkpoint_path=checkpoint
+    )
+
+    model = UNETR(
+        backbone="sam",
+        encoder=predictor.model.image_encoder,
+        out_channels=3,
+        use_sam_stats=True,
+        final_activation="Sigmoid",
+        use_skip_connection=False
+    )
+    model.to(device)
+
+    x = torch.ones((1, 1, 512, 512)).to(device)
+    y = model(x)
+
+    print("UNETR Model successfully created and encoder initialized from", checkpoint)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/torch_em/model/unetr.py b/torch_em/model/unetr.py
index 33c35070..010439a2 100644
--- a/torch_em/model/unetr.py
+++ b/torch_em/model/unetr.py
@@ -6,7 +6,7 @@
 import torch.nn.functional as F
 
 from .unet import Decoder, ConvBlock2d, Upsampler2d
-from .vit import get_vision_transformer
+from .vit import get_vision_transformer, ViT_MAE, ViT_Sam
 
 try:
     from micro_sam.util import get_sam_model
@@ -24,7 +24,7 @@ class UNETR(nn.Module):
 
     def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
         if isinstance(checkpoint, str):
-            if backbone == "sam":
+            if backbone == "sam" and isinstance(encoder, str):
                 # If we have a SAM encoder, then we first try to load the full SAM Model
                 # (using micro_sam) and otherwise fall back on directly loading the encoder state
                 # from the checkpoint
@@ -63,23 +63,47 @@ def __init__(
         self,
         img_size: int = 1024,
         backbone: str = "sam",
-        encoder: str = "vit_b",
+        encoder: Optional[Union[nn.Module, str]] = "vit_b",
         decoder: Optional[nn.Module] = None,
         out_channels: int = 1,
         use_sam_stats: bool = False,
         use_mae_stats: bool = False,
         encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
         final_activation: Optional[Union[str, nn.Module]] = None,
+        use_skip_connection: bool = True,
+        embed_dim: Optional[int] = None
     ) -> None:
         super().__init__()
 
         self.use_sam_stats = use_sam_stats
         self.use_mae_stats = use_mae_stats
+        self.use_skip_connection = use_skip_connection
 
-        print(f"Using {encoder} from {backbone.upper()}")
-        self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder)
-        if encoder_checkpoint is not None:
-            self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint)
+        if isinstance(encoder, str):  # "vit_b" / "vit_l" / "vit_h"
+            print(f"Using {encoder} from {backbone.upper()}")
+            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder)
+            if encoder_checkpoint is not None:
+                self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint)
+
+            in_chans = self.encoder.in_chans
+            if embed_dim is None:
+                embed_dim = self.encoder.embed_dim
+
+        else:  # `nn.Module` ViT backbone
+            self.encoder = encoder
+
+            have_neck = False
+            for name, _ in self.encoder.named_parameters():
+                if name.startswith("neck"):
+                    have_neck = True
+
+            if embed_dim is None:
+                if have_neck:
+                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
+                else:
+                    embed_dim = self.encoder.patch_embed.proj.out_channels
+
+            in_chans = self.encoder.patch_embed.proj.in_channels
 
         # parameters for the decoder network
         depth = 3
@@ -99,18 +123,21 @@ def __init__(
         else:
             self.decoder = decoder
 
-        self.z_inputs = ConvBlock2d(self.encoder.in_chans, features_decoder[-1])
+        self.z_inputs = ConvBlock2d(in_chans, features_decoder[-1])
+
+        self.base = ConvBlock2d(embed_dim, features_decoder[0])
 
-        self.base = ConvBlock2d(self.encoder.embed_dim, features_decoder[0])
         self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
 
-        self.deconv1 = Deconv2DBlock(self.encoder.embed_dim, features_decoder[0])
+        self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0])
         self.deconv2 = Deconv2DBlock(features_decoder[0], features_decoder[1])
         self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2])
+        self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3])
+
+        self.deconv_out = SingleDeconv2DBlock(features_decoder[-1], features_decoder[-1])
 
-        self.deconv4 = SingleDeconv2DBlock(features_decoder[-1], features_decoder[-1])
+        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
 
-        self.decoder_head = ConvBlock2d(2*features_decoder[-1], features_decoder[-1])
         self.final_activation = self._get_activation(final_activation)
 
     def _get_activation(self, activation):
@@ -167,26 +194,42 @@ def forward(self, x):
         # backbone used for reshaping inputs to the desired "encoder" shape
         x = torch.stack([self.preprocess(e) for e in x], dim=0)
 
-        z0 = self.z_inputs(x)
+        use_skip_connection = getattr(self, "use_skip_connection", True)
 
-        z12, from_encoder = self.encoder(x)
-        x = self.base(z12)
+        encoder_outputs = self.encoder(x)
 
-        from_encoder = from_encoder[::-1]
-        z9 = self.deconv1(from_encoder[0])
+        if isinstance(self.encoder, ViT_Sam) or isinstance(self.encoder, ViT_MAE):
+            z12, from_encoder = encoder_outputs
+        else:
+            z12 = encoder_outputs
 
-        z6 = self.deconv1(from_encoder[1])
-        z6 = self.deconv2(z6)
+        if use_skip_connection:
+            # TODO: we share the weights in the deconv(s), and should preferably avoid doing that
+            from_encoder = from_encoder[::-1]
+            z9 = self.deconv1(from_encoder[0])
 
-        z3 = self.deconv1(from_encoder[2])
-        z3 = self.deconv2(z3)
-        z3 = self.deconv3(z3)
+            z6 = self.deconv1(from_encoder[1])
+            z6 = self.deconv2(z6)
+
+            z3 = self.deconv1(from_encoder[2])
+            z3 = self.deconv2(z3)
+            z3 = self.deconv3(z3)
+
+            z0 = self.z_inputs(x)
+
+        else:
+            z9 = self.deconv1(z12)
+            z6 = self.deconv2(z9)
+            z3 = self.deconv3(z6)
+            z0 = self.deconv4(z3)
 
         updated_from_encoder = [z9, z6, z3]
+
+        x = self.base(z12)
         x = self.decoder(x, encoder_inputs=updated_from_encoder)
-        x = self.deconv4(x)
-        x = torch.cat([x, z0], dim=1)
+        x = self.deconv_out(x)
+        x = torch.cat([x, z0], dim=1)
 
         x = self.decoder_head(x)
         x = self.out_conv(x)
diff --git a/torch_em/transform/label.py b/torch_em/transform/label.py
index 45d2f5a5..d13233e2 100644
--- a/torch_em/transform/label.py
+++ b/torch_em/transform/label.py
@@ -302,6 +302,7 @@ def __init__(
         boundary_distances=True,
         directed_distances=False,
         foreground=True,
+        instances=False,
         apply_label=True,
         correct_centers=True,
         min_size=0,
@@ -313,6 +314,7 @@
         self.boundary_distances = boundary_distances
         self.directed_distances = directed_distances
         self.foreground = foreground
+        self.instances = instances
         self.apply_label = apply_label
         self.correct_centers = correct_centers
 
@@ -441,4 +443,7 @@ def __call__(self, labels):
             binary_labels = (labels > 0).astype("float32")
             distances = np.concatenate([binary_labels[None], distances], axis=0)
 
+        if self.instances:
+            distances = np.concatenate([labels[None], distances], axis=0)
+
         return distances
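
Note (usage sketch, not part of the patch): with the `torch_em/model/unetr.py` changes above, `encoder_checkpoint` accepts an already-loaded `OrderedDict` in addition to a checkpoint path, and `use_skip_connection` toggles the skip connections. A minimal sketch of driving this from a finetuned micro_sam checkpoint, mirroring `prune_prefix` from `train_by_parts.py`; the checkpoint path is a placeholder:

    from collections import OrderedDict

    import torch
    from torch_em.model import UNETR

    # strip the "sam.image_encoder." prefix from a finetuned micro_sam checkpoint,
    # as prune_prefix in train_by_parts.py does
    state = torch.load("/path/to/finetuned_sam.pt", map_location="cpu")["model_state"]  # placeholder path
    prefix = "sam.image_encoder."
    encoder_state = OrderedDict(
        (k[len(prefix):], v) for k, v in state.items() if k.startswith(prefix)
    )

    model = UNETR(
        backbone="sam",
        encoder="vit_b",
        out_channels=3,                    # foreground + center and boundary distances
        use_sam_stats=True,
        final_activation="Sigmoid",
        encoder_checkpoint=encoder_state,  # an OrderedDict is now accepted, not only a path
        use_skip_connection=False,         # new flag; True keeps the previous behaviour
    )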
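
Note (usage sketch, not part of the patch): the `predict_for_unetr` change above makes the watershed parameters explicit. A standalone sketch of that distance-based post-processing for a single image; the import locations of `predict_with_padding` and the segmentation utilities are assumptions and should match what `common.py` already imports:

    import numpy as np
    import torch

    from torch_em.util.prediction import predict_with_padding
    from torch_em.util import segmentation

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_ = np.random.rand(512, 512).astype("float32")  # placeholder for a LiveCELL image

    # `model` is the trained UNETR from the sketch above; it predicts
    # foreground, center-distance and boundary-distance channels
    outputs = predict_with_padding(model, input_, device=device, min_divisible=(16, 16))
    fg, cdist, bdist = outputs.squeeze()

    # same parameterization as in predict_for_unetr above
    instances = segmentation.watershed_from_center_and_boundary_distances(
        cdist, bdist, fg, min_size=50,
        center_distance_threshold=0.5,
        boundary_distance_threshold=0.6,
        distance_smoothing=1.0
    )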