diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py
index 3e4f01e3..6ad6ce40 100644
--- a/micro_sam/training/util.py
+++ b/micro_sam/training/util.py
@@ -13,7 +13,6 @@
     get_centers_and_bounding_boxes, get_sam_model, get_device,
     segmentation_to_one_hot, _DEFAULT_MODEL,
 )
-from .peft_sam import PEFT_Sam
 from .trainable_sam import TrainableSAM
 
 from torch_em.transform.label import PerObjectDistanceTransform
@@ -87,21 +86,18 @@ def get_trainable_sam_model(
     # (for e.g. encoder blocks to "image_encoder")
     if freeze is not None:
         for name, param in sam.named_parameters():
-            if isinstance(freeze, list):
-                # we would want to "freeze" all the components in the model if passed a list of parts
-                for l_item in freeze:
-                    if name.startswith(f"{l_item}"):
-                        param.requires_grad = False
-            else:
+            if not isinstance(freeze, list):
                 # we "freeze" only for one specific component when passed a "particular" part
-                if name.startswith(f"{freeze}"):
-                    param.requires_grad = False
+                freeze = [freeze]
+
+            # we would want to "freeze" all the components in the model if passed a list of parts
+            for l_item in freeze:
+                # in case LoRA is switched on, we cannot freeze the image encoder
+                if use_lora and (l_item == "image_encoder"):
+                    raise ValueError("You cannot use LoRA & freeze the image encoder at the same time.")
 
-    # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything
-    if use_lora:  # overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers
-        if rank is None:
-            rank = 4  # HACK: in case the user does not pass the rank, we provide a random rank to them
-        sam = PEFT_Sam(sam, rank=rank).sam
+                if name.startswith(f"{l_item}"):
+                    param.requires_grad = False
 
     # convert to trainable sam
     trainable_sam = TrainableSAM(sam)
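
To make the new control flow easier to follow outside the diff context, here is a minimal, self-contained sketch of the freeze-by-prefix pattern this change introduces. `ToyModel`, `freeze_components`, and the plain `use_lora` flag are illustrative stand-ins rather than part of micro_sam's API; only the normalization of `freeze` to a list, the prefix match against `named_parameters()`, and the LoRA / image-encoder conflict check mirror the diff.

```python
# Minimal sketch of the freezing pattern from the diff, using a toy torch model
# instead of SAM. ToyModel and freeze_components are hypothetical names for
# illustration only; the submodule names mimic those the prefix match relies on.
from typing import List, Optional, Union

import torch.nn as nn


class ToyModel(nn.Module):
    """Stand-in for SAM with similarly named submodules."""
    def __init__(self):
        super().__init__()
        self.image_encoder = nn.Linear(8, 8)
        self.prompt_encoder = nn.Linear(8, 8)
        self.mask_decoder = nn.Linear(8, 8)


def freeze_components(
    model: nn.Module,
    freeze: Optional[Union[str, List[str]]],
    use_lora: bool = False,
) -> None:
    if freeze is None:
        return

    # Normalize a single component name to a list, so one code path handles both cases.
    if not isinstance(freeze, list):
        freeze = [freeze]

    # LoRA adapts the image encoder's attention layers, so also freezing it is contradictory.
    if use_lora and "image_encoder" in freeze:
        raise ValueError("You cannot use LoRA & freeze the image encoder at the same time.")

    # Freeze every parameter whose dotted name starts with one of the requested components.
    for name, param in model.named_parameters():
        if any(name.startswith(component) for component in freeze):
            param.requires_grad = False


model = ToyModel()
freeze_components(model, freeze=["prompt_encoder", "mask_decoder"])
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # only the image_encoder parameters remain trainable
```

One small difference from the patched function: the sketch hoists the LoRA conflict check out of the parameter loop, so the `ValueError` is raised once up front instead of on the first loop iteration; the observable behavior is the same.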