From 0701737b4a3df65cb4080efb0f533e1b28d44d67 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Sun, 10 Dec 2023 13:24:35 +0100
Subject: [PATCH 1/6] Add BCEDiceLoss, update distance loss

---
 torch_em/loss/dice.py           | 26 ++++++++++++++++++++++++++
 torch_em/loss/distance_based.py |  4 +++-
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/torch_em/loss/dice.py b/torch_em/loss/dice.py
index f7ccbe78..213815b4 100644
--- a/torch_em/loss/dice.py
+++ b/torch_em/loss/dice.py
@@ -84,6 +84,32 @@ def forward(self, input_, target):
         )
 
 
+class BCEDiceLoss(nn.Module):
+
+    def __init__(self, alpha=1., beta=1., channelwise=True, eps=1e-7):
+        super().__init__()
+        self.alpha = alpha
+        self.beta = beta
+        self.channelwise = channelwise
+        self.eps = eps
+
+        # all torch_em classes should store init kwargs to easily recreate the init call
+        self.init_kwargs = {"alpha": alpha, "beta": beta, "channelwise": channelwise, "eps": self.eps}
+
+    def forward(self, input_, target):
+        loss_dice = dice_score(
+            input_,
+            target,
+            invert=True,
+            channelwise=self.channelwise,
+            eps=self.eps
+        )
+        loss_bce = nn.functional.binary_cross_entropy(
+            input_, target
+        )
+        return self.alpha * loss_dice + self.beta * loss_bce
+
+
 # TODO think about how to handle combined losses like this for mixed precision training
 class BCEDiceLossWithLogits(nn.Module):
 
diff --git a/torch_em/loss/distance_based.py b/torch_em/loss/distance_based.py
index 537ec5e9..8374ad4d 100644
--- a/torch_em/loss/distance_based.py
+++ b/torch_em/loss/distance_based.py
@@ -16,7 +16,7 @@ class DistanceLoss(nn.Module):
     """
     def __init__(
         self,
-        mask_distances_in_bg: bool,
+        mask_distances_in_bg: bool = True,
         foreground_loss: nn.Module = DiceLoss(),
         distance_loss: nn.Module = nn.MSELoss(reduction="mean")
     ) -> None:
@@ -26,6 +26,8 @@ def __init__(
         self.distance_loss = distance_loss
         self.mask_distances_in_bg = mask_distances_in_bg
 
+        self.init_kwargs = {"mask_distances_in_bg": mask_distances_in_bg}
+
     def forward(self, input_, target):
         assert input_.shape == target.shape, input_.shape
         assert input_.shape[1] == 3, input_.shape

From caeaf2966ec23448a53f7cd86fdbc4d7709db6a1 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Sun, 10 Dec 2023 13:34:29 +0100
Subject: [PATCH 2/6] Add doc string for distance label transformations

---
 torch_em/transform/label.py | 31 ++++++++++++++++++++++++++-----
 1 file changed, 26 insertions(+), 5 deletions(-)

diff --git a/torch_em/transform/label.py b/torch_em/transform/label.py
index 63dfdf0b..45d2f5a5 100644
--- a/torch_em/transform/label.py
+++ b/torch_em/transform/label.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 import skimage.measure
 import skimage.segmentation
@@ -192,16 +194,25 @@ def __call__(self, labels):
 
 
 class DistanceTransform:
-    """Compute distances to foreground.
+    """Compute distances to foreground in the labels.
+
+    Args:
+        distances: Whether to compute the absolute distances.
+        directed_distances: Whether to compute the directed distances (vector distances).
+        normalize: Whether to normalize the computed distances.
+        max_distance: Maximal distance at which to threshold the distances.
+        foreground_id: Label id to which the distance is computed.
+        invert: Whether to invert the distances.
+        func: Normalization function for the distances.
""" eps = 1e-7 def __init__( self, - distances=True, - directed_distances=False, - normalize=True, - max_distance=None, + distances: bool = True, + directed_distances: bool = False, + normalize: bool = True, + max_distance: Optional[float] = None, foreground_id=1, invert=False, func=None @@ -272,6 +283,16 @@ def __call__(self, labels): class PerObjectDistanceTransform: """Compute normalized distances per object in a segmentation. + + Args: + distances: Whether to compute the undirected distances. + boundary_distances: Whether to compute the distances to the object boundaries. + directed_distances: Whether to compute the directed distances (vector distances). + foreground: Whether to return a foreground channel. + apply_label: Whether to apply connected components to the labels before computing distances. + correct_centers: Whether to correct centers that are not in the objects. + min_size: Minimal size of objects for distance calculdation. + distance_fill_value: Fill value for the distances outside of objects. """ eps = 1e-7 From 5b9e34134f705d458bf2f0a1caa4170d2c9b045b Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 10 Dec 2023 14:16:10 +0100 Subject: [PATCH 3/6] Update UNETR to take state dict in addition to checkpoint path; add example script --- scripts/load_mae_vit.py | 13 +++++++++ torch_em/model/unetr.py | 61 ++++++++++++++++++++++------------------- 2 files changed, 46 insertions(+), 28 deletions(-) create mode 100644 scripts/load_mae_vit.py diff --git a/scripts/load_mae_vit.py b/scripts/load_mae_vit.py new file mode 100644 index 00000000..b1d5c5e0 --- /dev/null +++ b/scripts/load_mae_vit.py @@ -0,0 +1,13 @@ +from collections import OrderedDict + +import torch +from torch_em.model import UNETR + +checkpoint = "imagenet.pth" +encoder_state = torch.load(checkpoint, map_location="cpu")["model"] +encoder_state = OrderedDict({ + k: v for k, v in encoder_state.items() + if (k != "mask_token" and not k.startswith("decoder")) +}) + +unetr_model = UNETR(backbone="mae", encoder="vit_l", encoder_checkpoint=encoder_state) diff --git a/torch_em/model/unetr.py b/torch_em/model/unetr.py index b09de6b9..cf272404 100644 --- a/torch_em/model/unetr.py +++ b/torch_em/model/unetr.py @@ -1,9 +1,10 @@ +from collections import OrderedDict +from typing import Optional, Tuple, Union + import torch import torch.nn as nn import torch.nn.functional as F -from typing import Tuple - from .unet import Decoder, ConvBlock2d, Upsampler2d from .vit import get_vision_transformer @@ -22,37 +23,41 @@ class UNETR(nn.Module): def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint): - if backbone == "sam": - # If we have a SAM encoder, then we first try to load the full SAM Model - # (using micro_sam) and otherwise fall back on directly loading the encoder state - # from the checkpoint - try: - _, model = get_sam_model( - model_type=encoder, - checkpoint_path=checkpoint, - return_sam=True - ) - encoder_state = model.image_encoder.state_dict() - except Exception: - # If we have a MAE encoder, then we directly load the encoder state - # from the checkpoint. 
+ if isinstance(checkpoint, str): + if backbone == "sam": + # If we have a SAM encoder, then we first try to load the full SAM Model + # (using micro_sam) and otherwise fall back on directly loading the encoder state + # from the checkpoint + try: + _, model = get_sam_model( + model_type=encoder, + checkpoint_path=checkpoint, + return_sam=True + ) + encoder_state = model.image_encoder.state_dict() + except Exception: + # If we have a MAE encoder, then we directly load the encoder state + # from the checkpoint. + encoder_state = torch.load(checkpoint) + + elif backbone == "mae": encoder_state = torch.load(checkpoint) - elif backbone == "mae": - encoder_state = torch.load(checkpoint) + else: + encoder_state = checkpoint self.encoder.load_state_dict(encoder_state) def __init__( self, - backbone="sam", - encoder="vit_b", - decoder=None, - out_channels=1, - use_sam_stats=False, - use_mae_stats=False, - encoder_checkpoint_path=None, - final_activation=None, + backbone: str = "sam", + encoder: str = "vit_b", + decoder: Optional[nn.Module] = None, + out_channels: int = 1, + use_sam_stats: bool = False, + use_mae_stats: bool = False, + encoder_checkpoint: Optional[Union[str, OrderedDict]] = None, + final_activation: Optional[Union[str, nn.Module]] = None, ) -> None: super().__init__() @@ -61,8 +66,8 @@ def __init__( print(f"Using {encoder} from {backbone.upper()}") self.encoder = get_vision_transformer(backbone=backbone, model=encoder) - if encoder_checkpoint_path is not None: - self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint_path) + if encoder_checkpoint is not None: + self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint) # parameters for the decoder network depth = 3 From d3acf43a8dacdfea156bdd49e3af5a5cdbee1709 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Sun, 10 Dec 2023 16:28:40 +0100 Subject: [PATCH 4/6] Update mae checkpoint ini --- scripts/load_mae_vit.py | 18 ++++++++---------- torch_em/model/unetr.py | 16 ++++++++++++++-- torch_em/model/vit.py | 14 +++++++------- 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/scripts/load_mae_vit.py b/scripts/load_mae_vit.py index b1d5c5e0..00b689dc 100644 --- a/scripts/load_mae_vit.py +++ b/scripts/load_mae_vit.py @@ -1,13 +1,11 @@ -from collections import OrderedDict - -import torch from torch_em.model import UNETR -checkpoint = "imagenet.pth" -encoder_state = torch.load(checkpoint, map_location="cpu")["model"] -encoder_state = OrderedDict({ - k: v for k, v in encoder_state.items() - if (k != "mask_token" and not k.startswith("decoder")) -}) -unetr_model = UNETR(backbone="mae", encoder="vit_l", encoder_checkpoint=encoder_state) +def main(): + checkpoint = "/home/nimanwai/mae_models/imagenet.pth" + unetr_model = UNETR(img_size=224, backbone="mae", encoder="vit_l", encoder_checkpoint=checkpoint) + print(unetr_model) + + +if __name__ == "__main__": + main() diff --git a/torch_em/model/unetr.py b/torch_em/model/unetr.py index cf272404..55680fd9 100644 --- a/torch_em/model/unetr.py +++ b/torch_em/model/unetr.py @@ -41,7 +41,18 @@ def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint): encoder_state = torch.load(checkpoint) elif backbone == "mae": - encoder_state = torch.load(checkpoint) + # vit initialization hints from: + # - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242 + encoder_state = torch.load(checkpoint)["model"] + encoder_state = OrderedDict({ + k: v for k, v in encoder_state.items() + if (k != "mask_token" and not k.startswith("decoder")) + 
}) + + # let's remove the `head` from our current encoder (as the MAE pretrained don't expect it) + current_encoder_state = self.encoder.state_dict() + if "head.weight" and "head.bias" in current_encoder_state: + del self.encoder.head else: encoder_state = checkpoint @@ -50,6 +61,7 @@ def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint): def __init__( self, + img_size: int = 1024, backbone: str = "sam", encoder: str = "vit_b", decoder: Optional[nn.Module] = None, @@ -65,7 +77,7 @@ def __init__( self.use_mae_stats = use_mae_stats print(f"Using {encoder} from {backbone.upper()}") - self.encoder = get_vision_transformer(backbone=backbone, model=encoder) + self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder) if encoder_checkpoint is not None: self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint) diff --git a/torch_em/model/vit.py b/torch_em/model/vit.py index 2f5a9bbc..2e67755e 100644 --- a/torch_em/model/vit.py +++ b/torch_em/model/vit.py @@ -123,7 +123,7 @@ def forward(self, x): return x, list_from_encoder -def get_vision_transformer(backbone: str, model: str): +def get_vision_transformer(backbone: str, model: str, img_size: int = 1024): if backbone == "sam": if model == "vit_b": encoder = ViT_Sam( @@ -155,18 +155,18 @@ def get_vision_transformer(backbone: str, model: str): elif backbone == "mae": if model == "vit_b": encoder = ViT_MAE( - patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6) + img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) ) elif model == "vit_l": encoder = ViT_MAE( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6) + img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) ) elif model == "vit_h": encoder = ViT_MAE( - patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6) + img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) ) else: raise ValueError(f"{model} is not supported by MAE. 
Currently vit_b, vit_l, vit_h are supported.") From e0beb448a8e50f7ee04fe4ac2ec862df1ed13cd8 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Sun, 10 Dec 2023 19:00:50 +0100 Subject: [PATCH 5/6] Fix weights and bias key checks --- torch_em/model/unetr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_em/model/unetr.py b/torch_em/model/unetr.py index 55680fd9..33c35070 100644 --- a/torch_em/model/unetr.py +++ b/torch_em/model/unetr.py @@ -51,7 +51,7 @@ def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint): # let's remove the `head` from our current encoder (as the MAE pretrained don't expect it) current_encoder_state = self.encoder.state_dict() - if "head.weight" and "head.bias" in current_encoder_state: + if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state): del self.encoder.head else: From db299c8d5768d1e77a413378a55a24afdb33b870 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 10 Dec 2023 19:44:20 +0100 Subject: [PATCH 6/6] Update sample script for loading the mae vit in unetr --- scripts/load_mae_vit.py | 11 ----------- .../load_mae_vit_in_unetr.py | 19 +++++++++++++++++++ 2 files changed, 19 insertions(+), 11 deletions(-) delete mode 100644 scripts/load_mae_vit.py create mode 100644 scripts/vision_transformer/load_mae_vit_in_unetr.py diff --git a/scripts/load_mae_vit.py b/scripts/load_mae_vit.py deleted file mode 100644 index 00b689dc..00000000 --- a/scripts/load_mae_vit.py +++ /dev/null @@ -1,11 +0,0 @@ -from torch_em.model import UNETR - - -def main(): - checkpoint = "/home/nimanwai/mae_models/imagenet.pth" - unetr_model = UNETR(img_size=224, backbone="mae", encoder="vit_l", encoder_checkpoint=checkpoint) - print(unetr_model) - - -if __name__ == "__main__": - main() diff --git a/scripts/vision_transformer/load_mae_vit_in_unetr.py b/scripts/vision_transformer/load_mae_vit_in_unetr.py new file mode 100644 index 00000000..248d05e7 --- /dev/null +++ b/scripts/vision_transformer/load_mae_vit_in_unetr.py @@ -0,0 +1,19 @@ +import argparse +from torch_em.model import UNETR + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint") + parser.add_argument("--encoder", default="vit_l") + parser.add_argument("--img_size", type=int, default=224) + args = parser.parse_args() + + UNETR( + img_size=args.img_size, backbone="mae", encoder=args.encoder, encoder_checkpoint=args.checkpoint + ) + print("UNETR Model successfully created and encoder initialized from", args.checkpoint) + + +if __name__ == "__main__": + main()
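
Editor's note: the code below is a minimal usage sketch and not part of the patch series. It assumes all six patches are applied, that the optional vision-transformer dependencies of torch_em are installed, and that a local MAE checkpoint is available; the checkpoint path is a placeholder. BCEDiceLoss and DistanceLoss are imported from their defining modules directly, since the patches do not show them being re-exported from torch_em.loss.

import torch

from torch_em.loss.dice import BCEDiceLoss
from torch_em.loss.distance_based import DistanceLoss
from torch_em.model import UNETR

# BCEDiceLoss (PATCH 1): weighted sum of the inverted dice score and binary cross
# entropy. It operates on probabilities, so apply a sigmoid to the predictions first.
pred = torch.sigmoid(torch.randn(2, 1, 128, 128))
target = (torch.rand(2, 1, 128, 128) > 0.5).float()
print("bce-dice loss:", BCEDiceLoss()(pred, target).item())

# DistanceLoss (PATCH 1): mask_distances_in_bg now defaults to True and the init
# kwargs are stored, so the loss can be recreated from loss.init_kwargs.
print(DistanceLoss().init_kwargs)

# UNETR (PATCHES 3-6): encoder_checkpoint accepts either a checkpoint path (the MAE
# decoder and mask_token keys are filtered internally) or an already filtered state
# dict, and img_size is forwarded to the vision transformer.
# "/path/to/imagenet_mae.pth" is a placeholder; point it to a local MAE checkpoint.
checkpoint = "/path/to/imagenet_mae.pth"
model = UNETR(img_size=224, backbone="mae", encoder="vit_l", encoder_checkpoint=checkpoint)
print("UNETR encoder initialized from", checkpoint)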