
Commit

Include LoRA in the evaluation and allow loading models without giving a checkpoint
caroteu committed Jul 15, 2024
1 parent 7c6e1a4 commit ed4d5c9
Showing 9 changed files with 40 additions and 22 deletions.
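Taken together, the changes let the evaluation and finetuning entry points construct the model directly via micro_sam.util.get_sam_model, so a finetuned checkpoint no longer has to be given and a LoRA rank can be forwarded to model construction. A minimal sketch of the resulting call pattern (not part of the diff; the model type, rank, and checkpoint path below are placeholder values, and falling back to the default SAM weights when checkpoint_path is None is an assumption):

from micro_sam.util import get_sam_model

# No finetuned checkpoint given: assumed to fall back to the default weights for the model type.
predictor = get_sam_model(model_type="vit_b", checkpoint_path=None)

# LoRA checkpoint: the rank passed here has to match the rank used during finetuning.
predictor = get_sam_model(
    model_type="vit_b",
    checkpoint_path="path/to/lora_checkpoint.pt",  # placeholder path
    lora_rank=4,  # placeholder rank
)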
11 changes: 4 additions & 7 deletions finetuning/evaluation/evaluate_amg.py
@@ -7,7 +7,7 @@
 from util import get_pred_paths, get_default_arguments, VANILLA_MODELS


-def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder):
+def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, lora_rank=None):
     val_image_paths, val_gt_paths = get_paths(dataset_name, split="val")
     test_image_paths, _ = get_paths(dataset_name, split="test")
     prediction_folder = run_amg(
@@ -16,7 +16,8 @@ def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder):
         experiment_folder,
         val_image_paths,
         val_gt_paths,
-        test_image_paths
+        test_image_paths,
+        lora_rank=lora_rank,
     )
     return prediction_folder

@@ -32,12 +33,8 @@ def eval_amg(dataset_name, prediction_folder, experiment_folder):

 def main():
     args = get_default_arguments()
-    if args.checkpoint is None:
-        ckpt = VANILLA_MODELS[args.model]
-    else:
-        ckpt = args.checkpoint

-    prediction_folder = run_amg_inference(args.dataset, args.model, ckpt, args.experiment_folder)
+    prediction_folder = run_amg_inference(args.dataset, args.model, args.checkpoint, args.experiment_folder, args.lora_rank)
     eval_amg(args.dataset, prediction_folder, args.experiment_folder)


7 changes: 4 additions & 3 deletions finetuning/evaluation/evaluate_instance_segmentation.py
@@ -7,7 +7,7 @@
 from util import get_pred_paths, get_default_arguments


-def run_instance_segmentation_with_decoder_inference(dataset_name, model_type, checkpoint, experiment_folder):
+def run_instance_segmentation_with_decoder_inference(dataset_name, model_type, checkpoint, experiment_folder, lora_rank):
     val_image_paths, val_gt_paths = get_paths(dataset_name, split="val")
     test_image_paths, _ = get_paths(dataset_name, split="test")
     prediction_folder = run_instance_segmentation_with_decoder(
@@ -16,7 +16,8 @@ def run_instance_segmentation_with_decoder_inference(dataset_name, model_type, checkpoint, experiment_folder):
         experiment_folder,
         val_image_paths,
         val_gt_paths,
-        test_image_paths
+        test_image_paths,
+        lora_rank=lora_rank,
     )
     return prediction_folder

@@ -34,7 +35,7 @@ def main():
     args = get_default_arguments()

     prediction_folder = run_instance_segmentation_with_decoder_inference(
-        args.dataset, args.model, args.checkpoint, args.experiment_folder
+        args.dataset, args.model, args.checkpoint, args.experiment_folder, args.lora_rank
     )
     eval_instance_segmentation_with_decoder(args.dataset, prediction_folder, args.experiment_folder)

3 changes: 2 additions & 1 deletion finetuning/evaluation/iterative_prompting.py
@@ -5,6 +5,7 @@

 from util import get_paths  # comment this and create a custom function with the same name to run int. seg. on your data
 from util import get_model, get_default_arguments
+from micro_sam.util import get_sam_model


 def _run_iterative_prompting(dataset_name, exp_folder, predictor, start_with_box_prompt, use_masks):
@@ -42,7 +43,7 @@ def main():
     start_with_box_prompt = args.box  # overwrite to start first iters' prompt with box instead of single point

     # get the predictor to perform inference
-    predictor = get_model(model_type=args.model, ckpt=args.checkpoint)
+    predictor = get_sam_model(model_type=args.model, checkpoint_path=args.checkpoint, lora_rank=args.lora_rank)

     prediction_root = _run_iterative_prompting(
         args.dataset, args.experiment_folder, predictor, start_with_box_prompt, args.use_masks
3 changes: 2 additions & 1 deletion finetuning/evaluation/precompute_embeddings.py
@@ -4,12 +4,13 @@

 from util import get_paths  # comment this and create a custom function with the same name to execute on your data
 from util import get_model, get_default_arguments
+from micro_sam.util import get_sam_model


 def main():
     args = get_default_arguments()

-    predictor = get_model(model_type=args.model, ckpt=args.checkpoint)
+    predictor = get_sam_model(model_type=args.model, checkpoint_path=args.checkpoint, lora_rank=args.lora_rank)
     embedding_dir = os.path.join(args.experiment_folder, "embeddings")
     os.makedirs(embedding_dir, exist_ok=True)

5 changes: 4 additions & 1 deletion finetuning/evaluation/util.py
@@ -219,13 +219,16 @@ def get_default_arguments():
     parser.add_argument(
         "-m", "--model", type=str, required=True, help="Provide the model type to initialize the predictor"
     )
-    parser.add_argument("-c", "--checkpoint", type=none_or_str, required=True, default=None)
+    parser.add_argument("-c", "--checkpoint", type=none_or_str, default=None)
     parser.add_argument("-e", "--experiment_folder", type=str, required=True)
     parser.add_argument("-d", "--dataset", type=str, default=None)
     parser.add_argument("--box", action="store_true", help="If passed, starts with first prompt as box")
     parser.add_argument(
         "--use_masks", action="store_true", help="To use logits masks for iterative prompting."
     )
+    parser.add_argument(
+        "--lora_rank", type=int, default=None, help="The rank of the LoRA if LoRA model is used for inference."
+    )
     args = parser.parse_args()
     return args

18 changes: 14 additions & 4 deletions finetuning/specialists/resource-efficient/covid_if_finetuning.py
@@ -80,13 +80,13 @@ def finetune_covid_if(args):
     patch_shape = (512, 512)  # the patch shape for training
     n_objects_per_batch = args.n_objects  # the number of objects per batch that will be sampled
     freeze_parts = args.freeze  # override this to freeze different parts of the model
-    checkpoint_name = f"{args.model_type}/covid_if_sam"
-
+    checkpoint_name = f"{model_type}/{args.checkpoint_name}"
     # all stuff we need for training
     train_loader, val_loader = get_dataloaders(
         patch_shape=patch_shape, data_path=args.input_path, n_images=args.n_images
     )
     scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True}
+    optimizer_class = torch.optim.AdamW

     # Run training
     sam_training.train_sam(
@@ -99,12 +99,13 @@ def finetune_covid_if(args):
         checkpoint_path=checkpoint_path,
         freeze=freeze_parts,
         device=device,
-        lr=1e-5,
+        lr=args.lr,
         n_epochs=args.epochs,
         save_root=args.save_root,
         scheduler_kwargs=scheduler_kwargs,
         save_every_kth_epoch=args.save_every_kth_epoch,
-
+        optimizer_class=optimizer_class,
+        lora_rank=args.lora_rank
     )


@@ -148,6 +149,15 @@ def main():
     parser.add_argument(
         "--n_images", type=int, default=None, help="The number of images used for finetuning."
     )
+    parser.add_argument(
+        "--lora_rank", type=int, default=None, help="The rank of the LoRA model."
+    )
+    parser.add_argument(
+        "--lr", type=float, default=5e-5, help="The learning rate for the optimizer. Default is 5e-5."
+    )
+    parser.add_argument(
+        "--checkpoint_name", type=str, default="covid_if_sam",
+    )
     args = parser.parse_args()
     finetune_covid_if(args)

7 changes: 5 additions & 2 deletions micro_sam/evaluation/inference.py
@@ -547,11 +547,12 @@ def run_amg(
     test_image_paths: List[Union[str, os.PathLike]],
     iou_thresh_values: Optional[List[float]] = None,
     stability_score_values: Optional[List[float]] = None,
+    lora_rank: Optional[int] = None,
 ) -> str:
     embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
     os.makedirs(embedding_folder, exist_ok=True)

-    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint)
+    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, lora_rank=lora_rank)
     amg = AutomaticMaskGenerator(predictor)
     amg_prefix = "amg"

@@ -588,11 +589,13 @@ def run_instance_segmentation_with_decoder(
     val_image_paths: List[Union[str, os.PathLike]],
     val_gt_paths: List[Union[str, os.PathLike]],
     test_image_paths: List[Union[str, os.PathLike]],
+    lora_rank: Optional[int] = None,
 ) -> str:
+
     embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
     os.makedirs(embedding_folder, exist_ok=True)

-    predictor, decoder = get_predictor_and_decoder(model_type=model_type, checkpoint_path=checkpoint)
+    predictor, decoder = get_predictor_and_decoder(model_type=model_type, checkpoint_path=checkpoint, lora_rank=lora_rank)
     segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
     seg_prefix = "instance_segmentation_with_decoder"

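For reference, a hedged sketch of calling the updated run_amg with the new keyword (not part of the commit; the paths, model type, and rank are placeholders, and keyword arguments are used throughout so the positional order of the first parameters does not matter):

from micro_sam.evaluation.inference import run_amg

val_image_paths = ["data/val/im01.tif"]  # placeholder image and label paths
val_gt_paths = ["data/val/gt01.tif"]
test_image_paths = ["data/test/im01.tif"]

prediction_folder = run_amg(
    checkpoint="path/to/lora_checkpoint.pt",  # placeholder checkpoint
    model_type="vit_b",
    experiment_folder="experiments/amg_lora",  # placeholder folder
    val_image_paths=val_image_paths,
    val_gt_paths=val_gt_paths,
    test_image_paths=test_image_paths,
    lora_rank=4,  # forwarded to util.get_sam_model
)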
3 changes: 2 additions & 1 deletion micro_sam/instance_segmentation.py
@@ -798,6 +798,7 @@ def get_predictor_and_decoder(
     model_type: str,
     checkpoint_path: Union[str, os.PathLike],
     device: Optional[Union[str, torch.device]] = None,
+    lora_rank: Optional[int] = None,
 ) -> Tuple[SamPredictor, DecoderAdapter]:
     """Load the SAM model (predictor) and instance segmentation decoder.
@@ -816,7 +817,7 @@
     device = util.get_device(device)
     predictor, state = util.get_sam_model(
         model_type=model_type, checkpoint_path=checkpoint_path,
-        device=device, return_state=True
+        device=device, return_state=True, lora_rank=lora_rank
     )
     if "decoder_state" not in state:
         raise ValueError(f"The checkpoint at {checkpoint_path} does not contain a decoder state")
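A hedged sketch of the extended get_predictor_and_decoder call (not part of the commit; the checkpoint path and rank are placeholders, and the checkpoint is assumed to contain the "decoder_state" the function checks for):

from micro_sam.instance_segmentation import get_predictor_and_decoder

predictor, decoder = get_predictor_and_decoder(
    model_type="vit_b",  # placeholder model type
    checkpoint_path="path/to/lora_checkpoint.pt",  # placeholder; must contain a "decoder_state"
    device="cpu",
    lora_rank=4,  # assumed to match the rank used during finetuning
)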
5 changes: 3 additions & 2 deletions micro_sam/training/training.py
@@ -148,6 +148,7 @@ def train_sam(
     save_every_kth_epoch: Optional[int] = None,
     pbar_signals: Optional[QObject] = None,
     optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
+    lora_rank: Optional[int] = None,
     **model_kwargs,
 ) -> None:
     """Run training for a SAM model.
@@ -183,22 +184,22 @@ def train_sam(
         If passed None, the chosen default parameters are used in ReduceLROnPlateau.
         save_every_kth_epoch: Save checkpoints after every kth epoch separately.
         pbar_signals: Controls for napari progress bar.
+        lora_rank: The rank of the LoRA Training
     """
     _check_loader(train_loader, with_segmentation_decoder)
     _check_loader(val_loader, with_segmentation_decoder)

     device = get_device(device)

     # Get the trainable segment anything model.
     model, state = get_trainable_sam_model(
         model_type=model_type,
         device=device,
         freeze=freeze,
         checkpoint_path=checkpoint_path,
         return_state=True,
+        lora_rank=lora_rank,
         **model_kwargs
     )

     # This class creates all the training data for a batch (inputs, prompts and labels).
     convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)

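Finally, a hedged sketch of passing the new lora_rank through train_sam (not part of the commit; the name, model type, and dataloaders are placeholders, and the loaders are assumed to come from a helper such as get_dataloaders in the finetuning script above):

import torch
from micro_sam import training as sam_training

# train_loader, val_loader = get_dataloaders(...)  # assumed, as in covid_if_finetuning.py
sam_training.train_sam(
    name="covid_if_lora",  # placeholder checkpoint name
    model_type="vit_b",  # placeholder model type
    train_loader=train_loader,
    val_loader=val_loader,
    lr=5e-5,
    n_epochs=50,
    optimizer_class=torch.optim.AdamW,
    lora_rank=4,  # None (the default) keeps full finetuning; an int enables LoRA
)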
