From b258f1a560312624692cb4ef690e27e1359dfb92 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Fri, 8 Dec 2023 15:43:23 +0100
Subject: [PATCH] Fix issues in ViT initialization and update UNETR state loading

---
 torch_em/model/unetr.py | 37 +++++++++++++++++++++++++------------
 torch_em/model/vit.py   |  6 +++++-
 2 files changed, 30 insertions(+), 13 deletions(-)

diff --git a/torch_em/model/unetr.py b/torch_em/model/unetr.py
index 906e801b..b09de6b9 100644
--- a/torch_em/model/unetr.py
+++ b/torch_em/model/unetr.py
@@ -19,6 +19,30 @@ class UNETR(nn.Module):
+
+    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
+
+        if backbone == "sam":
+            # For a SAM encoder we first try to load the full SAM model
+            # (using micro_sam) and otherwise fall back on loading the
+            # encoder state directly from the checkpoint.
+            try:
+                _, model = get_sam_model(
+                    model_type=encoder,
+                    checkpoint_path=checkpoint,
+                    return_sam=True
+                )
+                encoder_state = model.image_encoder.state_dict()
+            except Exception:
+                # Loading the full SAM model failed, so we fall back on
+                # loading the encoder state directly from the checkpoint.
+                encoder_state = torch.load(checkpoint)
+
+        elif backbone == "mae":
+            encoder_state = torch.load(checkpoint)
+
+        self.encoder.load_state_dict(encoder_state)
+
     def __init__(
         self,
         backbone="sam",
@@ -36,20 +60,9 @@ def __init__(
         self.use_mae_stats = use_mae_stats
 
         print(f"Using {encoder} from {backbone.upper()}")
-
         self.encoder = get_vision_transformer(backbone=backbone, model=encoder)
-
         if encoder_checkpoint_path is not None:
-            if backbone == "sam":
-                _, model = get_sam_model(
-                    model_type=encoder,
-                    checkpoint_path=encoder_checkpoint_path,
-                    return_sam=True
-                )
-                for param1, param2 in zip(model.parameters(), self.encoder.parameters()):
-                    param2.data = param1
-            elif backbone == "mae":
-                raise NotImplementedError
+            self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint_path)
 
         # parameters for the decoder network
         depth = 3
diff --git a/torch_em/model/vit.py b/torch_em/model/vit.py
index 0ae3fc23..2f5a9bbc 100644
--- a/torch_em/model/vit.py
+++ b/torch_em/model/vit.py
@@ -39,7 +39,11 @@ def __init__(
             "and then rerun your code."
         )
 
-        super().__init__(embed_dim=embed_dim, **kwargs)
+        super().__init__(
+            embed_dim=embed_dim,
+            global_attn_indexes=global_attn_indexes,
+            **kwargs,
+        )
         self.chunks_for_projection = global_attn_indexes
         self.in_chans = in_chans
         self.embed_dim = embed_dim
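
-- 
Usage note: a minimal sketch (not part of the patch) of how the refactored
checkpoint loading is intended to be used. Only the constructor arguments
visible in the diff are used; the encoder flavor and checkpoint path below
are illustrative placeholders, and all other arguments are assumed to keep
their defaults.

    from torch_em.model.unetr import UNETR

    # Build a UNETR with a SAM image encoder. With the patch applied, the
    # checkpoint may either be a full SAM model (loaded via micro_sam) or a
    # plain encoder state dict, which is then loaded directly as a fallback.
    model = UNETR(
        backbone="sam",
        encoder="vit_b",  # illustrative choice of ViT flavor
        encoder_checkpoint_path="/path/to/checkpoint.pth",  # hypothetical path
    )

For backbone="mae" the encoder state is always read directly from the
checkpoint file, instead of raising NotImplementedError as before.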