diff --git a/scripts/vision_transformer/load_mae_vit_in_unetr.py b/scripts/vision_transformer/load_mae_vit_in_unetr.py
new file mode 100644
index 00000000..248d05e7
--- /dev/null
+++ b/scripts/vision_transformer/load_mae_vit_in_unetr.py
@@ -0,0 +1,19 @@
+import argparse
+from torch_em.model import UNETR
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("checkpoint")
+    parser.add_argument("--encoder", default="vit_l")
+    parser.add_argument("--img_size", type=int, default=224)
+    args = parser.parse_args()
+
+    UNETR(
+        img_size=args.img_size, backbone="mae", encoder=args.encoder, encoder_checkpoint=args.checkpoint
+    )
+    print("UNETR Model successfully created and encoder initialized from", args.checkpoint)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/torch_em/loss/dice.py b/torch_em/loss/dice.py
index f7ccbe78..213815b4 100644
--- a/torch_em/loss/dice.py
+++ b/torch_em/loss/dice.py
@@ -84,6 +84,32 @@ def forward(self, input_, target):
         )
 
 
+class BCEDiceLoss(nn.Module):
+
+    def __init__(self, alpha=1., beta=1., channelwise=True, eps=1e-7):
+        super().__init__()
+        self.alpha = alpha
+        self.beta = beta
+        self.channelwise = channelwise
+        self.eps = eps
+
+        # all torch_em classes should store init kwargs to easily recreate the init call
+        self.init_kwargs = {"alpha": alpha, "beta": beta, "channelwise": channelwise, "eps": self.eps}
+
+    def forward(self, input_, target):
+        loss_dice = dice_score(
+            input_,
+            target,
+            invert=True,
+            channelwise=self.channelwise,
+            eps=self.eps
+        )
+        loss_bce = nn.functional.binary_cross_entropy(
+            input_, target
+        )
+        return self.alpha * loss_dice + self.beta * loss_bce
+
+
 # TODO think about how to handle combined losses like this for mixed precision training
 class BCEDiceLossWithLogits(nn.Module):
 
diff --git a/torch_em/loss/distance_based.py b/torch_em/loss/distance_based.py
index 537ec5e9..8374ad4d 100644
--- a/torch_em/loss/distance_based.py
+++ b/torch_em/loss/distance_based.py
@@ -16,7 +16,7 @@ class DistanceLoss(nn.Module):
     """
     def __init__(
         self,
-        mask_distances_in_bg: bool,
+        mask_distances_in_bg: bool = True,
        foreground_loss: nn.Module = DiceLoss(),
        distance_loss: nn.Module = nn.MSELoss(reduction="mean")
     ) -> None:
@@ -26,6 +26,8 @@ def __init__(
         self.distance_loss = distance_loss
         self.mask_distances_in_bg = mask_distances_in_bg
 
+        self.init_kwargs = {"mask_distances_in_bg": mask_distances_in_bg}
+
     def forward(self, input_, target):
         assert input_.shape == target.shape, input_.shape
         assert input_.shape[1] == 3, input_.shape
diff --git a/torch_em/model/unetr.py b/torch_em/model/unetr.py
index b09de6b9..33c35070 100644
--- a/torch_em/model/unetr.py
+++ b/torch_em/model/unetr.py
@@ -1,9 +1,10 @@
+from collections import OrderedDict
+from typing import Optional, Tuple, Union
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from typing import Tuple
-
 from .unet import Decoder, ConvBlock2d, Upsampler2d
 from .vit import get_vision_transformer
 
@@ -22,37 +23,53 @@ class UNETR(nn.Module):
 
     def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
-        if backbone == "sam":
-            # If we have a SAM encoder, then we first try to load the full SAM Model
-            # (using micro_sam) and otherwise fall back on directly loading the encoder state
-            # from the checkpoint
-            try:
-                _, model = get_sam_model(
-                    model_type=encoder,
-                    checkpoint_path=checkpoint,
-                    return_sam=True
-                )
-                encoder_state = model.image_encoder.state_dict()
-            except Exception:
-                # If we have a MAE encoder, then we directly load the encoder state
-                # from the checkpoint.
-                encoder_state = torch.load(checkpoint)
-
-        elif backbone == "mae":
-            encoder_state = torch.load(checkpoint)
+        if isinstance(checkpoint, str):
+            if backbone == "sam":
+                # If we have a SAM encoder, then we first try to load the full SAM Model
+                # (using micro_sam) and otherwise fall back on directly loading the encoder state
+                # from the checkpoint
+                try:
+                    _, model = get_sam_model(
+                        model_type=encoder,
+                        checkpoint_path=checkpoint,
+                        return_sam=True
+                    )
+                    encoder_state = model.image_encoder.state_dict()
+                except Exception:
+                    # If we have a MAE encoder, then we directly load the encoder state
+                    # from the checkpoint.
+                    encoder_state = torch.load(checkpoint)
+
+            elif backbone == "mae":
+                # vit initialization hints from:
+                # - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
+                encoder_state = torch.load(checkpoint)["model"]
+                encoder_state = OrderedDict({
+                    k: v for k, v in encoder_state.items()
+                    if (k != "mask_token" and not k.startswith("decoder"))
+                })
+
+                # let's remove the `head` from our current encoder (as the MAE pretrained models don't expect it)
+                current_encoder_state = self.encoder.state_dict()
+                if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
+                    del self.encoder.head
+
+        else:
+            encoder_state = checkpoint
 
         self.encoder.load_state_dict(encoder_state)
 
     def __init__(
         self,
-        backbone="sam",
-        encoder="vit_b",
-        decoder=None,
-        out_channels=1,
-        use_sam_stats=False,
-        use_mae_stats=False,
-        encoder_checkpoint_path=None,
-        final_activation=None,
+        img_size: int = 1024,
+        backbone: str = "sam",
+        encoder: str = "vit_b",
+        decoder: Optional[nn.Module] = None,
+        out_channels: int = 1,
+        use_sam_stats: bool = False,
+        use_mae_stats: bool = False,
+        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
+        final_activation: Optional[Union[str, nn.Module]] = None,
     ) -> None:
         super().__init__()
@@ -60,9 +77,9 @@ def __init__(
         self.use_mae_stats = use_mae_stats
         print(f"Using {encoder} from {backbone.upper()}")
 
-        self.encoder = get_vision_transformer(backbone=backbone, model=encoder)
-        if encoder_checkpoint_path is not None:
-            self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint_path)
+        self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder)
+        if encoder_checkpoint is not None:
+            self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint)
 
         # parameters for the decoder network
         depth = 3
diff --git a/torch_em/model/vit.py b/torch_em/model/vit.py
index 2f5a9bbc..2e67755e 100644
--- a/torch_em/model/vit.py
+++ b/torch_em/model/vit.py
@@ -123,7 +123,7 @@ def forward(self, x):
         return x, list_from_encoder
 
 
-def get_vision_transformer(backbone: str, model: str):
+def get_vision_transformer(backbone: str, model: str, img_size: int = 1024):
     if backbone == "sam":
         if model == "vit_b":
             encoder = ViT_Sam(
@@ -155,18 +155,18 @@ def get_vision_transformer(backbone: str, model: str):
     elif backbone == "mae":
         if model == "vit_b":
             encoder = ViT_MAE(
-                patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
-                norm_layer=partial(nn.LayerNorm, eps=1e-6)
+                img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
             )
         elif model == "vit_l":
             encoder = ViT_MAE(
-                patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
-                norm_layer=partial(nn.LayerNorm, eps=1e-6)
+                img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
+                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
             )
         elif model == "vit_h":
             encoder = ViT_MAE(
-                patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
-                norm_layer=partial(nn.LayerNorm, eps=1e-6)
+                img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
+                qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
             )
         else:
             raise ValueError(f"{model} is not supported by MAE. Currently vit_b, vit_l, vit_h are supported.")
diff --git a/torch_em/transform/label.py b/torch_em/transform/label.py
index 63dfdf0b..45d2f5a5 100644
--- a/torch_em/transform/label.py
+++ b/torch_em/transform/label.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 import skimage.measure
 import skimage.segmentation
@@ -192,16 +194,25 @@ def __call__(self, labels):
 
 
 class DistanceTransform:
-    """Compute distances to foreground.
+    """Compute distances to foreground in the labels.
+
+    Args:
+        distances: Whether to compute the absolute distances.
+        directed_distances: Whether to compute the directed distances (vector distances).
+        normalize: Whether to normalize the computed distances.
+        max_distance: Maximal distance at which to threshold the distances.
+        foreground_id: Label id to which the distance is computed.
+        invert: Whether to invert the distances.
+        func: Normalization function for the distances.
     """
     eps = 1e-7
 
     def __init__(
         self,
-        distances=True,
-        directed_distances=False,
-        normalize=True,
-        max_distance=None,
+        distances: bool = True,
+        directed_distances: bool = False,
+        normalize: bool = True,
+        max_distance: Optional[float] = None,
         foreground_id=1,
         invert=False,
         func=None
@@ -272,6 +283,16 @@ def __call__(self, labels):
 
 
 class PerObjectDistanceTransform:
     """Compute normalized distances per object in a segmentation.
+
+    Args:
+        distances: Whether to compute the undirected distances.
+        boundary_distances: Whether to compute the distances to the object boundaries.
+        directed_distances: Whether to compute the directed distances (vector distances).
+        foreground: Whether to return a foreground channel.
+        apply_label: Whether to apply connected components to the labels before computing distances.
+        correct_centers: Whether to correct centers that are not in the objects.
+        min_size: Minimal size of objects for distance calculation.
+        distance_fill_value: Fill value for the distances outside of objects.
     """
     eps = 1e-7
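
For reference, a minimal usage sketch of the new BCEDiceLoss, assuming the patch above has been applied. It imports directly from torch_em.loss.dice, the module changed in this diff; whether the class is also re-exported from torch_em.loss is not shown here. Because the loss calls binary_cross_entropy rather than the logits variant, predictions are expected to be probabilities in [0, 1].

    import torch
    from torch_em.loss.dice import BCEDiceLoss

    # equal weighting of the dice term (alpha) and the BCE term (beta)
    loss_function = BCEDiceLoss(alpha=1.0, beta=1.0)

    # predictions must be probabilities, e.g. a sigmoid over the model output
    predictions = torch.sigmoid(torch.randn(4, 1, 256, 256))
    targets = (torch.rand(4, 1, 256, 256) > 0.5).float()

    loss_value = loss_function(predictions, targets)
    print(float(loss_value))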