Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issues in ViT initialization and update UNETR state loading #180

Merged
merged 1 commit into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions torch_em/model/unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,30 @@


class UNETR(nn.Module):

def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):

if backbone == "sam":
# If we have a SAM encoder, then we first try to load the full SAM Model
# (using micro_sam) and otherwise fall back on directly loading the encoder state
# from the checkpoint
try:
_, model = get_sam_model(
model_type=encoder,
checkpoint_path=checkpoint,
return_sam=True
)
encoder_state = model.image_encoder.state_dict()
except Exception:
# If we have a MAE encoder, then we directly load the encoder state
# from the checkpoint.
encoder_state = torch.load(checkpoint)

elif backbone == "mae":
encoder_state = torch.load(checkpoint)

self.encoder.load_state_dict(encoder_state)

def __init__(
self,
backbone="sam",
Expand All @@ -36,20 +60,9 @@ def __init__(
self.use_mae_stats = use_mae_stats

print(f"Using {encoder} from {backbone.upper()}")

self.encoder = get_vision_transformer(backbone=backbone, model=encoder)

if encoder_checkpoint_path is not None:
if backbone == "sam":
_, model = get_sam_model(
model_type=encoder,
checkpoint_path=encoder_checkpoint_path,
return_sam=True
)
for param1, param2 in zip(model.parameters(), self.encoder.parameters()):
param2.data = param1
elif backbone == "mae":
raise NotImplementedError
self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint_path)

# parameters for the decoder network
depth = 3
Expand Down
6 changes: 5 additions & 1 deletion torch_em/model/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ def __init__(
"and then rerun your code."
)

super().__init__(embed_dim=embed_dim, **kwargs)
super().__init__(
embed_dim=embed_dim,
global_attn_indexes=global_attn_indexes,
**kwargs,
)
self.chunks_for_projection = global_attn_indexes
self.in_chans = in_chans
self.embed_dim = embed_dim
Expand Down