Fixes and Updates for UNETR #148
Conversation
And now the final fix is also in (other ViT models can be switched and used as well) ;) Thanks!
Looks good overall! Just a few minor comments.
@@ -1,9 +1,15 @@
 import torch
-from torch_em.model.unetr import build_unetr_with_sam_intialization
+from torch_em.model.unetr import build_unetr_with_sam_initialization

 # FIXME this doesn't work yet
@anwai98 did you check now that this works? If yes you can remove the FIXME.
torch_em/model/unetr.py
Outdated
-    use_sam_preprocessing=False,
+    use_sam_preprocessing=True,
+    initialize_from_sam=False,
+    checkpoint_path=None
Is there a reason we need both initialize_from_sam and checkpoint_path? From what I see it would be enough to have checkpoint_path (see also comment below).
Also, I suggest we call it encoder_checkpoint_path.
> Is there a reason we need both initialize_from_sam and checkpoint_path? From what I see it would be enough to have checkpoint_path (see also comment below).
Yes, this makes total sense. I fixed it under one argument: encoder_checkpoint_path.
torch_em/model/unetr.py
Outdated
    else:
        self.encoder = encoder
        raise NotImplementedError
You should throw a ValueError with a meaningful message, e.g. raise ValueError(f"{encoder} is not supported. Currently only vit_b, vit_l, vit_h are supported.").
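For illustration, the suggested validation could look like the following sketch. The helper name and surrounding structure are assumptions for this example; in the PR the check would live directly in the UNETR constructor.

```python
# Sketch of the suggested encoder validation (helper name assumed,
# not part of torch_em): reject anything outside the supported ViTs
# with a ValueError carrying a meaningful message.
SUPPORTED_ENCODERS = ("vit_b", "vit_l", "vit_h")


def check_encoder(encoder):
    if encoder not in SUPPORTED_ENCODERS:
        raise ValueError(
            f"{encoder} is not supported. "
            "Currently only vit_b, vit_l, vit_h are supported."
        )
    return encoder
```

Compared to a bare raise NotImplementedError, the caller immediately sees which value was rejected and which values are accepted.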
torch_em/model/unetr.py
Outdated
        self.encoder = encoder
        raise NotImplementedError

+    if initialize_from_sam:
I think we don't need initialize_from_sam. Isn't it enough to have if checkpoint_path is not None here? (And then you don't need to have the assertion.)
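The simplification being proposed can be sketched as follows (the helper name is hypothetical and the argument name encoder_checkpoint_path is taken from the later discussion; this is not the torch_em implementation):

```python
def resolve_sam_initialization(encoder_checkpoint_path=None):
    # A single optional checkpoint path replaces the separate
    # `initialize_from_sam` flag: passing a path implies SAM
    # initialization, so no extra flag (and no assertion keeping
    # the two arguments consistent) is needed.
    return encoder_checkpoint_path is not None
```

This mirrors the reviewer's point: one optional argument encodes both "whether" and "from where" to initialize.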
torch_em/model/unetr.py
Outdated
    )
    predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path)
    _image_encoder = predictor.model.image_encoder

def build_unetr_with_sam_initialization(out_channels=1, model_type="vit_b", checkpoint_path=None):
We can get rid of these functions. The current design is simple enough and we don't need to wrap it in a function.
Okay. In that case, I will document the correct usage of the different ViT models, with and without initialization, in the directory.
The test failure is unrelated. I will merge and fix that in a separate PR.
I checked UNETR and tried to fix the overall issues (one final thing is still missing: the switch between ViT models).
To address some previous questions:
In general, UNETR models can now be built very easily from two obvious functions (from torch_em.model import build_unetr_with_sam_initialization, build_unetr_without_sam_initialization) (not 100% sure about the nomenclature, though).