Minor fix to trainable sam model functionality #646

Merged · 3 commits · Jun 28, 2024
micro_sam/training/util.py (24 changes: 10 additions & 14 deletions)
@@ -13,7 +13,6 @@
     get_centers_and_bounding_boxes, get_sam_model, get_device,
     segmentation_to_one_hot, _DEFAULT_MODEL,
 )
-from .peft_sam import PEFT_Sam
 from .trainable_sam import TrainableSAM
 
 from torch_em.transform.label import PerObjectDistanceTransform
@@ -87,21 +86,18 @@ def get_trainable_sam_model(
     #   (for e.g. encoder blocks to "image_encoder")
     if freeze is not None:
         for name, param in sam.named_parameters():
-            if isinstance(freeze, list):
-                # we would want to "freeze" all the components in the model if passed a list of parts
-                for l_item in freeze:
-                    if name.startswith(f"{l_item}"):
-                        param.requires_grad = False
-            else:
+            if not isinstance(freeze, list):
                 # we "freeze" only for one specific component when passed a "particular" part
-                if name.startswith(f"{freeze}"):
-                    param.requires_grad = False
+                freeze = [freeze]
 
-    # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything
-    if use_lora:  # overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers
-        if rank is None:
-            rank = 4  # HACK: in case the user does not pass the rank, we provide a random rank to them
-        sam = PEFT_Sam(sam, rank=rank).sam
+            # we would want to "freeze" all the components in the model if passed a list of parts
+            for l_item in freeze:
+                # in case LoRA is switched on, we cannot freeze the image encoder
+                if use_lora and (l_item == "image_encoder"):
+                    raise ValueError("You cannot use LoRA & freeze the image encoder at the same time.")
+
+                if name.startswith(f"{l_item}"):
+                    param.requires_grad = False
 
     # convert to trainable sam
     trainable_sam = TrainableSAM(sam)
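The change normalizes `freeze` to a list before iterating over the named parameters, so a single component name and a list of component names follow the same code path, and the LoRA compatibility check is applied to every part that is to be frozen. Below is a minimal standalone sketch of that control flow; the stand-in `nn.ModuleDict` model and the component names are illustrative assumptions, not micro_sam's actual SAM model.

```python
# Sketch of the freeze logic after this PR, applied to a stand-in model.
# The module and component names are illustrative, not micro_sam's real SAM parts.
import torch.nn as nn

model = nn.ModuleDict({
    "image_encoder": nn.Linear(4, 4),
    "mask_decoder": nn.Linear(4, 4),
})

freeze = "image_encoder"   # a single component name ...
use_lora = False

if freeze is not None:
    for name, param in model.named_parameters():
        if not isinstance(freeze, list):
            freeze = [freeze]  # ... is normalized to a list, so both inputs share one path

        for l_item in freeze:
            # freezing the image encoder is incompatible with LoRA on its attention layers
            if use_lora and (l_item == "image_encoder"):
                raise ValueError("You cannot use LoRA & freeze the image encoder at the same time.")

            if name.startswith(l_item):
                param.requires_grad = False

# only the requested component is frozen; everything else stays trainable
assert all(not p.requires_grad for p in model["image_encoder"].parameters())
assert all(p.requires_grad for p in model["mask_decoder"].parameters())
```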