From 79bd2fad6d763a502e62e7621e9331540da3993c Mon Sep 17 00:00:00 2001 From: Carolin Teuber Date: Wed, 19 Jun 2024 15:43:14 +0200 Subject: [PATCH 01/12] include lora in livecell training logic --- .../evaluation/submit_all_evaluation.py | 31 ++++++++++++++---- finetuning/livecell/lora/train_livecell.py | 18 ++++++++--- finetuning/run_all_finetuning.py | 32 ++++++++++--------- 3 files changed, 56 insertions(+), 25 deletions(-) diff --git a/finetuning/evaluation/submit_all_evaluation.py b/finetuning/evaluation/submit_all_evaluation.py index b64549de..a557d49e 100644 --- a/finetuning/evaluation/submit_all_evaluation.py +++ b/finetuning/evaluation/submit_all_evaluation.py @@ -6,6 +6,7 @@ from pathlib import Path from datetime import datetime +ROOT = "/scratch/usr/nimcarot/sam/experiments/lora" ALL_SCRIPTS = [ "precompute_embeddings", "evaluate_amg", "iterative_prompting", "evaluate_instance_segmentation" @@ -14,7 +15,8 @@ def write_batch_script( env_name, out_path, inference_setup, checkpoint, model_type, - experiment_folder, dataset_name, delay=None, use_masks=False + experiment_folder, dataset_name, delay=None, use_masks=False, + use_lora=False, lora_rank=None ): "Writing scripts with different fold-trainings for micro-sam evaluation" batch_script = f"""#!/bin/bash @@ -23,7 +25,7 @@ def write_batch_script( #SBATCH -t 4-00:00:00 #SBATCH -p grete:shared #SBATCH -G A100:1 -#SBATCH -A gzz0001 +#SBATCH -A nim00007 #SBATCH --constraint=80gb #SBATCH --qos=96h #SBATCH --job-name={inference_setup} @@ -56,6 +58,9 @@ def write_batch_script( if inference_setup == "iterative_prompting" and use_masks: python_script += "--use_masks " + # use lora if requested + if use_lora: + python_script += f"--use_lora --lora_rank {lora_rank}" # let's add the python script to the bash script batch_script += python_script @@ -84,7 +89,7 @@ def get_batch_script_names(tmp_folder): return batch_script -def get_checkpoint_path(experiment_set, dataset_name, model_type, region): +def get_checkpoint_path(experiment_set, dataset_name, model_type, region, lora=False, rank=None): # let's set the experiment type - either using the generalist or just using vanilla model if experiment_set == "generalist": checkpoint = f"/scratch/usr/nimanwai/micro-sam/checkpoints/{model_type}/" @@ -112,7 +117,11 @@ def get_checkpoint_path(experiment_set, dataset_name, model_type, region): if dataset_name.startswith("tissuenet"): dataset_name = "tissuenet" - checkpoint = f"/scratch/usr/nimanwai/micro-sam/checkpoints/{model_type}/{dataset_name}_sam/best.pt" + if lora: + assert rank is not None, "Provide the rank for LoRA finetuning." 
+ checkpoint_name = f"{ROOT}/checkpoints/{model_type}/{dataset_name}_lora_rank_{rank}/best.pt" + else: + checkpoint = f"{ROOT}/checkpoints/{model_type}/{dataset_name}_sam/best.pt" elif experiment_set == "vanilla": checkpoint = None @@ -144,8 +153,9 @@ def submit_slurm(args): if args.experiment_path is None: modality = region if region == "lm" else "em" - experiment_folder = "/scratch/projects/nim00007/sam/experiments/new_models/v3/" - experiment_folder += f"{experiment_set}/{modality}/{dataset_name}/{model_type}/" + # get the correct naming if lora finetuning was used + experiment_name = f"{experiment_set}_lora_{args.lora_rank}" if args.use_lora else experiment_set + experiment_folder = f"{ROOT}/{experiment_name}/{modality}/{dataset_name}/{model_type}/" else: experiment_folder = args.experiment_path @@ -176,6 +186,8 @@ def submit_slurm(args): dataset_name=dataset_name, delay=None if current_setup == "precompute_embeddings" else make_delay, use_masks=args.use_masks + use_lora=args.use_lora, + lora_rank=args.lora_rank ) # the logic below automates the process of first running the precomputation of embeddings, and only then inference. @@ -220,5 +232,12 @@ def main(args): # ask for a specific experiment parser.add_argument("-s", "--specific_experiment", type=str, default=None) + # LoRA specific arguments + parser.add_argument("--lora_rank", type=int, default=4) + parser.add_argument("--use_lora", action="store_true", help="Whether to use LoRA for finetuning.") + args = parser.parse_args() main(args) + + +# python ~/micro-sam/finetuning/evaluation/submit_all_evaluation.py -d livecell -m vit_b -e specialist -r lm \ No newline at end of file diff --git a/finetuning/livecell/lora/train_livecell.py b/finetuning/livecell/lora/train_livecell.py index fa887437..65571109 100644 --- a/finetuning/livecell/lora/train_livecell.py +++ b/finetuning/livecell/lora/train_livecell.py @@ -74,15 +74,14 @@ def finetune_livecell(args): patch_shape = (520, 704) # the patch shape for training n_objects_per_batch = 5 # this is the number of objects per batch that will be sampled freeze_parts = args.freeze # override this to freeze different parts of the model - rank = 4 # the rank - + rank = args.lora_rank # the rank # get the trainable segment anything model model = sam_training.get_trainable_sam_model( model_type=model_type, device=device, checkpoint_path=checkpoint_path, freeze=freeze_parts, - use_lora=True, + use_lora=args.use_lora, rank=rank, ) model.to(device) @@ -116,9 +115,14 @@ def finetune_livecell(args): # this class creates all the training data for a batch (inputs, prompts and labels) convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) + name = ( + f"{args.model_type}/livecell_" + f"{f'lora_rank_{args.lora_rank}' if args.use_lora else 'sam'}" + ) + trainer = sam_training.JointSamTrainer( - name="livecell_lora", + name=name, save_root=args.save_root, train_loader=train_loader, val_loader=val_loader, @@ -176,6 +180,12 @@ def main(): "--freeze", type=str, nargs="+", default=None, help="Which parts of the model to freeze for finetuning." ) + parser.add_argument( + "--use_lora", action="store_true", help="Whether to use LoRA for finetuning." + ) + parser.add_argument( + "--lora_rank", type=int, default=4, help="Pass the rank for LoRA." 
+ ) args = parser.parse_args() finetune_livecell(args) diff --git a/finetuning/run_all_finetuning.py b/finetuning/run_all_finetuning.py index 7562e374..c95e97b1 100644 --- a/finetuning/run_all_finetuning.py +++ b/finetuning/run_all_finetuning.py @@ -3,19 +3,14 @@ import subprocess from datetime import datetime +ROOT = "~/micro-sam" -N_OBJECTS = { - "vit_t": 50, - "vit_b": 40, - "vit_l": 30, - "vit_h": 25 -} +MODELS = ["vit_t", "vit_b", "vit_l", "vit_h"] - -def write_batch_script(out_path, _name, env_name, model_type, save_root): +def write_batch_script(out_path, _name, env_name, model_type, save_root, use_lora=False, lora_rank=4): "Writing scripts with different micro-sam finetunings." batch_script = f"""#!/bin/bash -#SBATCH -t 14-00:00:00 +#SBATCH -t 4-00:00:00 #SBATCH --mem 64G #SBATCH --nodes=1 #SBATCH --ntasks=1 @@ -23,7 +18,7 @@ def write_batch_script(out_path, _name, env_name, model_type, save_root): #SBATCH -G A100:1 #SBATCH -A nim00007 #SBATCH -c 16 -#SBATCH --qos=14d +#SBATCH --qos=96h #SBATCH --constraint=80gb #SBATCH --job-name={os.path.split(_name)[-1]} @@ -38,8 +33,8 @@ def write_batch_script(out_path, _name, env_name, model_type, save_root): # name of the model configuration python_script += f"-m {model_type} " - # choice of the number of objects - python_script += f"--n_objects {N_OBJECTS[model_type[:5]]} " + if use_lora: + python_script += f"--use_lora --lora_rank {lora_rank} " # let's add the python script to the bash script batch_script += python_script @@ -70,7 +65,7 @@ def submit_slurm(args): tmp_folder = "./gpu_jobs" script_combinations = { - "livecell_specialist": "livecell_finetuning", + "livecell_specialist": f"{ROOT}/finetuning/livecell/lora/train_livecell", "deepbacs_specialist": "specialists/training/light_microscopy/deepbacs_finetuning", "tissuenet_specialist": "specialists/training/light_microscopy/tissuenet_finetuning", "plantseg_root_specialist": "specialists/training/light_microscopy/plantseg_root_finetuning", @@ -90,7 +85,7 @@ def submit_slurm(args): experiments = [args.experiment_name] if args.model_type is None: - models = list(N_OBJECTS.keys()) + models = MODELS else: models = [args.model_type] @@ -103,7 +98,9 @@ def submit_slurm(args): _name=script_name, env_name="mobilesam" if model_type == "vit_t" else "sam", model_type=model_type, - save_root=args.save_root + save_root=args.save_root, + use_lora=args.use_lora, + lora_rank=args.lora_rank ) @@ -122,5 +119,10 @@ def main(args): parser.add_argument("-e", "--experiment_name", type=str, default=None) parser.add_argument("-s", "--save_root", type=str, default="/scratch/usr/nimanwai/micro-sam/") parser.add_argument("-m", "--model_type", type=str, default=None) + parser.add_argument("--use_lora", action="store_true", help="Whether to use LoRA for finetuning.") + parser.add_argument("--lora_rank", type=int, default=4, help="Pass the rank for LoRA") args = parser.parse_args() main(args) + + +# python ~/micro-sam/finetuning/run_all_finetuning.py -e livecell_specialist -m vit_b -s /scratch/usr/nimcarot/lora/ --use_lora --lora_rank 4 \ No newline at end of file From 111a9940147eb58066f2d3d59723c8afdedf97d5 Mon Sep 17 00:00:00 2001 From: Carolin Teuber Date: Wed, 19 Jun 2024 16:41:47 +0200 Subject: [PATCH 02/12] implemented training scripts for mouse embryo and covid if (non-stable) --- finetuning/run_all_finetuning.py | 2 + finetuning/specialists/lora/train_covid_if.py | 216 ++++++++++++++++++ .../specialists/lora/train_mouse_embryo.py | 210 +++++++++++++++++ 3 files changed, 428 insertions(+) create mode 100644 
finetuning/specialists/lora/train_covid_if.py create mode 100644 finetuning/specialists/lora/train_mouse_embryo.py diff --git a/finetuning/run_all_finetuning.py b/finetuning/run_all_finetuning.py index c95e97b1..2038453f 100644 --- a/finetuning/run_all_finetuning.py +++ b/finetuning/run_all_finetuning.py @@ -72,6 +72,8 @@ def submit_slurm(args): "neurips_cellseg_specialist": "specialists/training/light_microscopy/neurips_cellseg_finetuning", "dynamicnuclearnet_specialist": "specialists/training/light_microscopy/dynamicnuclearnet_finetuning", "lm_generalist": "generalists/training/light_microscopy/train_lm_generalist", + "covid_if_generalist": f"{ROOT}/finetuning/specialists/lora/train_covid_if", + "mouse_embryo_generalist": f"{ROOT}/finetuning/specialists/lora/train_mouse_embryo", "cremi_specialist": "specialists/training/electron_microscopy/boundaries/cremi_finetuning", "asem_specialist": "specialists/training/electron_microscopy/organelles/asem_finetuning", "em_mito_nuc_generalist": "generalists/training/electron_microscopy/mito_nuc/train_mito_nuc_em_generalist", diff --git a/finetuning/specialists/lora/train_covid_if.py b/finetuning/specialists/lora/train_covid_if.py new file mode 100644 index 00000000..d51cc571 --- /dev/null +++ b/finetuning/specialists/lora/train_covid_if.py @@ -0,0 +1,216 @@ +import os +import argparse + +import torch + +from torch_em.model import UNETR +from torch_em.loss import DiceBasedDistanceLoss +from torch_em.data.datasets import get_covid_if_loader +from torch_em.transform.label import PerObjectDistanceTransform +from torch_em.data import MinInstanceSampler + +import micro_sam.training as sam_training +from micro_sam.util import export_custom_sam_model + + +def get_dataloaders(patch_shape, data_path): + """This returns the livecell data loaders implemented in torch_em: + https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/livecell.py + It will automatically download the livecell data. + + Note: to replace this with another data loader you need to return a torch data loader + that retuns `x, y` tensors, where `x` is the image data and `y` are the labels. + The labels have to be in a label mask instance segmentation format. + I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID. 
+ Important: the ID 0 is reseved for background, and the IDs must be consecutive + """ + num_workers = 8 if torch.cuda.is_available() else 0 + + label_transform = PerObjectDistanceTransform( + distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25 + ) + raw_transform = sam_training.identity # the current workflow avoids rescaling the inputs to [-1, 1] + sampler = MinInstanceSampler() + + train_volumes = (None, 10) + val_volumes = (10, 13) + + # let's estimate the total number of patches + train_loader = get_covid_if_loader( + path=data_path, patch_shape=patch_shape, batch_size=1, target="cells", + download=True, sampler=sampler, sample_range=train_volumes + ) + + print( + f"Found {len(train_loader)} samples for training.", + "Hence, we will use {0} samples for training.".format(50 if len(train_loader) < 50 else len(train_loader)) + ) + + # now, let's get the training and validation dataloaders + train_loader = get_covid_if_loader( + path=data_path, patch_shape=patch_shape, batch_size=1, target="cells", num_workers=num_workers, shuffle=True, + raw_transform=raw_transform, sampler=sampler, label_transform=label_transform, label_dtype=torch.float32, + sample_range=train_volumes, n_samples=50 if len(train_loader) < 50 else None, + ) + + val_loader = get_covid_if_loader( + path=data_path, patch_shape=patch_shape, batch_size=1, target="cells", download=True, num_workers=num_workers, + raw_transform=raw_transform, sampler=sampler, label_transform=label_transform, label_dtype=torch.float32, + sample_range=val_volumes, n_samples=5, + ) + + return train_loader, val_loader + + +def count_parameters(model): + params = sum(p.numel() for p in model.parameters() if p.requires_grad) + params = params / 1e6 + return f"The number of trainable parameters for the provided model is {round(params, 2)}M" + + +def finetune_livecell(args): + """Code for finetuning SAM (using LoRA) on LIVECell + + Initial observations: There's no real memory advantage actually unless it's "truly" scaled up + # vit_b + # SAM: 93M (takes ~50GB) + # SAM-LoRA: 4.2M (takes ~49GB) + + # vit_l + # SAM: 312M (takes ~63GB) + # SAM-LoRA: 4.4M (takes ~61GB) + + # vit_h + # SAM: 641M (takes ~73GB) + # SAM-LoRA: 4.7M (takes ~67GB) + + # Q: Would quantization lead to better results? (eg. 
QLoRA / DoRA) + """ + # override this (below) if you have some more complex set-up and need to specify the exact gpu + device = "cuda" if torch.cuda.is_available() else "cpu" + + # training settings: + model_type = args.model_type + checkpoint_path = None # override this to start training from a custom checkpoint # the patch shape for training + n_objects_per_batch = 5 # this is the number of objects per batch that will be sampled + patch_shape = (512,512) + freeze_parts = args.freeze # override this to freeze different parts of the model + rank = args.lora_rank # the rank + # get the trainable segment anything model + model = sam_training.get_trainable_sam_model( + model_type=model_type, + device=device, + checkpoint_path=checkpoint_path, + freeze=freeze_parts, + use_lora=args.use_lora, + rank=rank, + ) + model.to(device) + + # let's get the UNETR model for automatic instance segmentation pipeline + unetr = UNETR( + backbone="sam", + encoder=model.sam.image_encoder, + out_channels=3, + use_sam_stats=True, + final_activation="Sigmoid", + use_skip_connection=False, + resize_input=True, + ) + unetr.to(device) + + # let's check the total number of trainable parameters + print(count_parameters(model)) + + # let's get the parameters for SAM and the decoder from UNETR + joint_model_params = model.parameters() + + joint_model_params = [params for params in joint_model_params] # sam parameters + for name, params in unetr.named_parameters(): # unetr's decoder parameters + if not name.startswith("encoder"): + joint_model_params.append(params) + + optimizer = torch.optim.Adam(joint_model_params, lr=1e-5) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10) + train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path) + + # this class creates all the training data for a batch (inputs, prompts and labels) + convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) + name = ( + f"{args.model_type}/covid_if_" + f"{f'lora_rank_{args.lora_rank}' if args.use_lora else 'sam'}" + ) + + + trainer = sam_training.JointSamTrainer( + name=name, + save_root=args.save_root, + train_loader=train_loader, + val_loader=val_loader, + model=model, + optimizer=optimizer, + device=device, + lr_scheduler=scheduler, + logger=sam_training.JointSamLogger, + log_image_interval=100, + mixed_precision=True, + convert_inputs=convert_inputs, + n_objects_per_batch=n_objects_per_batch, + n_sub_iteration=8, + compile_model=False, + mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training + unetr=unetr, + instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True), + instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True), + early_stopping=10 + ) + trainer.fit(args.iterations) + if args.export_path is not None: + checkpoint_path = os.path.join( + "" if args.save_root is None else args.save_root, "checkpoints", args.name, "best.pt" + ) + export_custom_sam_model( + checkpoint_path=checkpoint_path, + model_type=model_type, + save_path=args.export_path, + ) + + +def main(): + parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.") + parser.add_argument( + "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/covid_if/", + help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded." 
+ ) + parser.add_argument( + "--model_type", "-m", default="vit_b", + help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l." + ) + parser.add_argument( + "--save_root", "-s", default=None, + help="Where to save the checkpoint and logs. By default they will be saved where this script is run." + ) + parser.add_argument( + "--iterations", type=int, default=int(1e4), + help="For how many iterations should the model be trained? By default 100k." + ) + parser.add_argument( + "--export_path", "-e", + help="Where to export the finetuned model to. The exported model can be used in the annotation tools." + ) + parser.add_argument( + "--freeze", type=str, nargs="+", default=None, + help="Which parts of the model to freeze for finetuning." + ) + parser.add_argument( + "--use_lora", action="store_true", help="Whether to use LoRA for finetuning." + ) + parser.add_argument( + "--lora_rank", type=int, default=4, help="Pass the rank for LoRA." + ) + args = parser.parse_args() + finetune_livecell(args) + + +if __name__ == "__main__": + main() diff --git a/finetuning/specialists/lora/train_mouse_embryo.py b/finetuning/specialists/lora/train_mouse_embryo.py new file mode 100644 index 00000000..d1d61d25 --- /dev/null +++ b/finetuning/specialists/lora/train_mouse_embryo.py @@ -0,0 +1,210 @@ +import os +import argparse + +import torch +import numpy as np + +from torch_em.model import UNETR +from torch_em.loss import DiceBasedDistanceLoss +from torch_em.data.datasets import get_mouse_embryo_loader +from torch_em.transform.label import PerObjectDistanceTransform +from torch_em.data import MinInstanceSampler +from micro_sam.training.util import ResizeLabelTrafo, ResizeRawTrafo + +import micro_sam.training as sam_training +from micro_sam.util import export_custom_sam_model + + +def get_dataloaders(patch_shape, data_path): + # 3. 
Mouse Embryo + # the logic used here is: I use the first 100 slices per volume from the training split for training + # and the next ~20/30 slices per volume from the training split for validation + # and we use the whole volume from the val set for testing + train_rois = [np.s_[0:100, :, :], np.s_[0:100, :, :], np.s_[0:100, :, :], np.s_[0:100, :, :]] + val_rois = [np.s_[100:, :, :], np.s_[100:, :, :], np.s_[100:, :, :], np.s_[100:, :, :]] + + raw_transform = ResizeRawTrafo((1,512,512)) + label_transform = ResizeLabelTrafo((512,512)) + + train_loader = get_mouse_embryo_loader( + path=data_path, + name="membrane", + split="train", + patch_shape=(1, 512, 512), + batch_size=2, + download=True, + num_workers=16, + shuffle=True, + sampler=MinInstanceSampler(min_num_instances=3), + rois=train_rois, + raw_transform=raw_transform, + label_transform=label_transform + ) + val_loader = get_mouse_embryo_loader( + path=data_path, + name="membrane", + split="train", + patch_shape=(1, 512, 512), + batch_size=1, + download=True, + num_workers=16, + sampler=MinInstanceSampler(min_num_instances=3), + rois=val_rois, + raw_transform=raw_transform, + label_transform=label_transform + ) + + return train_loader, val_loader + +def count_parameters(model): + params = sum(p.numel() for p in model.parameters() if p.requires_grad) + params = params / 1e6 + return f"The number of trainable parameters for the provided model is {round(params, 2)}M" + + +def finetune_livecell(args): + """Code for finetuning SAM (using LoRA) on LIVECell + + Initial observations: There's no real memory advantage actually unless it's "truly" scaled up + # vit_b + # SAM: 93M (takes ~50GB) + # SAM-LoRA: 4.2M (takes ~49GB) + + # vit_l + # SAM: 312M (takes ~63GB) + # SAM-LoRA: 4.4M (takes ~61GB) + + # vit_h + # SAM: 641M (takes ~73GB) + # SAM-LoRA: 4.7M (takes ~67GB) + + # Q: Would quantization lead to better results? (eg. 
QLoRA / DoRA) + """ + # override this (below) if you have some more complex set-up and need to specify the exact gpu + device = "cuda" if torch.cuda.is_available() else "cpu" + + # training settings: + model_type = args.model_type + checkpoint_path = None # override this to start training from a custom checkpoint # the patch shape for training + n_objects_per_batch = 5 # this is the number of objects per batch that will be sampled + patch_shape = (512,512) + freeze_parts = args.freeze # override this to freeze different parts of the model + rank = args.lora_rank # the rank + # get the trainable segment anything model + model = sam_training.get_trainable_sam_model( + model_type=model_type, + device=device, + checkpoint_path=checkpoint_path, + freeze=freeze_parts, + use_lora=args.use_lora, + rank=rank, + ) + model.to(device) + + # let's get the UNETR model for automatic instance segmentation pipeline + unetr = UNETR( + backbone="sam", + encoder=model.sam.image_encoder, + out_channels=3, + use_sam_stats=True, + final_activation="Sigmoid", + use_skip_connection=False, + resize_input=True, + ) + unetr.to(device) + + # let's check the total number of trainable parameters + print(count_parameters(model)) + + # let's get the parameters for SAM and the decoder from UNETR + joint_model_params = model.parameters() + + joint_model_params = [params for params in joint_model_params] # sam parameters + for name, params in unetr.named_parameters(): # unetr's decoder parameters + if not name.startswith("encoder"): + joint_model_params.append(params) + + optimizer = torch.optim.Adam(joint_model_params, lr=1e-5) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10) + train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path) + + # this class creates all the training data for a batch (inputs, prompts and labels) + convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) + name = ( + f"{args.model_type}/mouse_embryo_" + f"{f'lora_rank_{args.lora_rank}' if args.use_lora else 'sam'}" + ) + + + trainer = sam_training.JointSamTrainer( + name=name, + save_root=args.save_root, + train_loader=train_loader, + val_loader=val_loader, + model=model, + optimizer=optimizer, + device=device, + lr_scheduler=scheduler, + logger=sam_training.JointSamLogger, + log_image_interval=100, + mixed_precision=True, + convert_inputs=convert_inputs, + n_objects_per_batch=n_objects_per_batch, + n_sub_iteration=8, + compile_model=False, + mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training + unetr=unetr, + instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True), + instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True), + early_stopping=10 + ) + trainer.fit(args.iterations) + if args.export_path is not None: + checkpoint_path = os.path.join( + "" if args.save_root is None else args.save_root, "checkpoints", args.name, "best.pt" + ) + export_custom_sam_model( + checkpoint_path=checkpoint_path, + model_type=model_type, + save_path=args.export_path, + ) + + +def main(): + parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.") + parser.add_argument( + "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/mouse_embryo/", + help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded." 
+ ) + parser.add_argument( + "--model_type", "-m", default="vit_b", + help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l." + ) + parser.add_argument( + "--save_root", "-s", default=None, + help="Where to save the checkpoint and logs. By default they will be saved where this script is run." + ) + parser.add_argument( + "--iterations", type=int, default=int(1e4), + help="For how many iterations should the model be trained? By default 100k." + ) + parser.add_argument( + "--export_path", "-e", + help="Where to export the finetuned model to. The exported model can be used in the annotation tools." + ) + parser.add_argument( + "--freeze", type=str, nargs="+", default=None, + help="Which parts of the model to freeze for finetuning." + ) + parser.add_argument( + "--use_lora", action="store_true", help="Whether to use LoRA for finetuning." + ) + parser.add_argument( + "--lora_rank", type=int, default=4, help="Pass the rank for LoRA." + ) + args = parser.parse_args() + finetune_livecell(args) + + +if __name__ == "__main__": + main() From 6b435eb74118e179f2dc71d3e9c610e4d1b97878 Mon Sep 17 00:00:00 2001 From: Carolin Teuber Date: Thu, 20 Jun 2024 09:04:23 +0200 Subject: [PATCH 03/12] removed use_lora from evaluation scripts --- finetuning/evaluation/submit_all_evaluation.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/finetuning/evaluation/submit_all_evaluation.py b/finetuning/evaluation/submit_all_evaluation.py index a557d49e..0be05f82 100644 --- a/finetuning/evaluation/submit_all_evaluation.py +++ b/finetuning/evaluation/submit_all_evaluation.py @@ -58,9 +58,6 @@ def write_batch_script( if inference_setup == "iterative_prompting" and use_masks: python_script += "--use_masks " - # use lora if requested - if use_lora: - python_script += f"--use_lora --lora_rank {lora_rank}" # let's add the python script to the bash script batch_script += python_script @@ -185,7 +182,7 @@ def submit_slurm(args): experiment_folder=experiment_folder, dataset_name=dataset_name, delay=None if current_setup == "precompute_embeddings" else make_delay, - use_masks=args.use_masks + use_masks=args.use_masks, use_lora=args.use_lora, lora_rank=args.lora_rank ) From e4db246c5a15e894928006a3d769058be52c9347 Mon Sep 17 00:00:00 2001 From: Carolin Teuber Date: Thu, 20 Jun 2024 10:58:53 +0200 Subject: [PATCH 04/12] clean up for pr --- .../evaluation/submit_all_evaluation.py | 5 ++--- finetuning/run_all_finetuning.py | 21 +++++++++++-------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/finetuning/evaluation/submit_all_evaluation.py b/finetuning/evaluation/submit_all_evaluation.py index 0be05f82..ca1cd198 100644 --- a/finetuning/evaluation/submit_all_evaluation.py +++ b/finetuning/evaluation/submit_all_evaluation.py @@ -6,7 +6,8 @@ from pathlib import Path from datetime import datetime -ROOT = "/scratch/usr/nimcarot/sam/experiments/lora" +# Replace with the path to the experiments folder +ROOT = "/scratch/projects/nim00007/sam/experiments/" ALL_SCRIPTS = [ "precompute_embeddings", "evaluate_amg", "iterative_prompting", "evaluate_instance_segmentation" @@ -236,5 +237,3 @@ def main(args): args = parser.parse_args() main(args) - -# python ~/micro-sam/finetuning/evaluation/submit_all_evaluation.py -d livecell -m vit_b -e specialist -r lm \ No newline at end of file diff --git a/finetuning/run_all_finetuning.py b/finetuning/run_all_finetuning.py index c95e97b1..d88d522c 100644 --- a/finetuning/run_all_finetuning.py +++ b/finetuning/run_all_finetuning.py @@ -3,14 +3,17 @@ 
import subprocess from datetime import datetime -ROOT = "~/micro-sam" - -MODELS = ["vit_t", "vit_b", "vit_l", "vit_h"] +N_OBJECTS = { + "vit_t": 50, + "vit_b": 40, + "vit_l": 30, + "vit_h": 25 +} def write_batch_script(out_path, _name, env_name, model_type, save_root, use_lora=False, lora_rank=4): "Writing scripts with different micro-sam finetunings." batch_script = f"""#!/bin/bash -#SBATCH -t 4-00:00:00 +#SBATCH -t 14-00:00:00 #SBATCH --mem 64G #SBATCH --nodes=1 #SBATCH --ntasks=1 @@ -18,7 +21,7 @@ def write_batch_script(out_path, _name, env_name, model_type, save_root, use_lor #SBATCH -G A100:1 #SBATCH -A nim00007 #SBATCH -c 16 -#SBATCH --qos=96h +#SBATCH --qos=14h #SBATCH --constraint=80gb #SBATCH --job-name={os.path.split(_name)[-1]} @@ -35,6 +38,8 @@ def write_batch_script(out_path, _name, env_name, model_type, save_root, use_lor if use_lora: python_script += f"--use_lora --lora_rank {lora_rank} " +# choice of the number of objects + python_script += f"--n_objects {N_OBJECTS[model_type[:5]]} " # let's add the python script to the bash script batch_script += python_script @@ -65,7 +70,7 @@ def submit_slurm(args): tmp_folder = "./gpu_jobs" script_combinations = { - "livecell_specialist": f"{ROOT}/finetuning/livecell/lora/train_livecell", + "livecell_specialist": "livecell/lora/train_livecell", "deepbacs_specialist": "specialists/training/light_microscopy/deepbacs_finetuning", "tissuenet_specialist": "specialists/training/light_microscopy/tissuenet_finetuning", "plantseg_root_specialist": "specialists/training/light_microscopy/plantseg_root_finetuning", @@ -85,7 +90,7 @@ def submit_slurm(args): experiments = [args.experiment_name] if args.model_type is None: - models = MODELS + models = list(N_OBJECTS.keys()) else: models = [args.model_type] @@ -124,5 +129,3 @@ def main(args): args = parser.parse_args() main(args) - -# python ~/micro-sam/finetuning/run_all_finetuning.py -e livecell_specialist -m vit_b -s /scratch/usr/nimcarot/lora/ --use_lora --lora_rank 4 \ No newline at end of file From 2eaaf2609b69785de2e954d93169232559fbde7c Mon Sep 17 00:00:00 2001 From: Carolin Teuber Date: Thu, 20 Jun 2024 11:00:45 +0200 Subject: [PATCH 05/12] changed qos specification in batch script --- finetuning/run_all_finetuning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finetuning/run_all_finetuning.py b/finetuning/run_all_finetuning.py index d88d522c..a530a526 100644 --- a/finetuning/run_all_finetuning.py +++ b/finetuning/run_all_finetuning.py @@ -21,7 +21,7 @@ def write_batch_script(out_path, _name, env_name, model_type, save_root, use_lor #SBATCH -G A100:1 #SBATCH -A nim00007 #SBATCH -c 16 -#SBATCH --qos=14h +#SBATCH --qos=14d #SBATCH --constraint=80gb #SBATCH --job-name={os.path.split(_name)[-1]} From 6dda34c5c78dd532d6d2e2188216d0af4c0b3405 Mon Sep 17 00:00:00 2001 From: Carolin Teuber Date: Thu, 20 Jun 2024 11:02:08 +0200 Subject: [PATCH 06/12] changed user in batch script --- finetuning/run_all_finetuning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finetuning/run_all_finetuning.py b/finetuning/run_all_finetuning.py index a530a526..b269a719 100644 --- a/finetuning/run_all_finetuning.py +++ b/finetuning/run_all_finetuning.py @@ -19,7 +19,7 @@ def write_batch_script(out_path, _name, env_name, model_type, save_root, use_lor #SBATCH --ntasks=1 #SBATCH -p grete:shared #SBATCH -G A100:1 -#SBATCH -A nim00007 +#SBATCH -A gzz0001 #SBATCH -c 16 #SBATCH --qos=14d #SBATCH --constraint=80gb From 0bcd7df41b8a785af592e4cf4e9c6f39436f35fb Mon Sep 17 00:00:00 2001 
From: Carolin Teuber Date: Thu, 20 Jun 2024 15:04:11 +0200 Subject: [PATCH 07/12] removed mistake in submit evaluation --- finetuning/evaluation/submit_all_evaluation.py | 4 ++-- finetuning/run_all_finetuning.py | 11 ++++++----- finetuning/specialists/lora/train_covid_if.py | 1 + 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/finetuning/evaluation/submit_all_evaluation.py b/finetuning/evaluation/submit_all_evaluation.py index ca1cd198..20401f52 100644 --- a/finetuning/evaluation/submit_all_evaluation.py +++ b/finetuning/evaluation/submit_all_evaluation.py @@ -7,7 +7,7 @@ from datetime import datetime # Replace with the path to the experiments folder -ROOT = "/scratch/projects/nim00007/sam/experiments/" +ROOT = "/scratch/usr/nimcarot/sam/experiments/lora" ALL_SCRIPTS = [ "precompute_embeddings", "evaluate_amg", "iterative_prompting", "evaluate_instance_segmentation" @@ -117,7 +117,7 @@ def get_checkpoint_path(experiment_set, dataset_name, model_type, region, lora=F if lora: assert rank is not None, "Provide the rank for LoRA finetuning." - checkpoint_name = f"{ROOT}/checkpoints/{model_type}/{dataset_name}_lora_rank_{rank}/best.pt" + checkpoint = f"{ROOT}/checkpoints/{model_type}/{dataset_name}_lora_rank_{rank}/best.pt" else: checkpoint = f"{ROOT}/checkpoints/{model_type}/{dataset_name}_sam/best.pt" diff --git a/finetuning/run_all_finetuning.py b/finetuning/run_all_finetuning.py index 75563343..b930194e 100644 --- a/finetuning/run_all_finetuning.py +++ b/finetuning/run_all_finetuning.py @@ -3,6 +3,8 @@ import subprocess from datetime import datetime +ROOT = "~/micro-sam/finetuing/" + N_OBJECTS = { "vit_t": 50, "vit_b": 40, @@ -13,20 +15,19 @@ def write_batch_script(out_path, _name, env_name, model_type, save_root, use_lora=False, lora_rank=4): "Writing scripts with different micro-sam finetunings." 
batch_script = f"""#!/bin/bash -#SBATCH -t 14-00:00:00 +#SBATCH -t 4-00:00:00 #SBATCH --mem 64G #SBATCH --nodes=1 #SBATCH --ntasks=1 #SBATCH -p grete:shared #SBATCH -G A100:1 -#SBATCH -A gzz0001 +#SBATCH -A nim00007 #SBATCH -c 16 -#SBATCH --qos=14d +#SBATCH --qos=96h #SBATCH --constraint=80gb #SBATCH --job-name={os.path.split(_name)[-1]} source activate {env_name} \n""" - # python script python_script = f"python {_name}.py " @@ -70,7 +71,7 @@ def submit_slurm(args): tmp_folder = "./gpu_jobs" script_combinations = { - "livecell_specialist": "livecell/lora/train_livecell", + "livecell_specialist": f"{ROOT}livecell/lora/train_livecell", "deepbacs_specialist": "specialists/training/light_microscopy/deepbacs_finetuning", "tissuenet_specialist": "specialists/training/light_microscopy/tissuenet_finetuning", "plantseg_root_specialist": "specialists/training/light_microscopy/plantseg_root_finetuning", diff --git a/finetuning/specialists/lora/train_covid_if.py b/finetuning/specialists/lora/train_covid_if.py index d51cc571..fb39aa1c 100644 --- a/finetuning/specialists/lora/train_covid_if.py +++ b/finetuning/specialists/lora/train_covid_if.py @@ -47,6 +47,7 @@ def get_dataloaders(patch_shape, data_path): ) # now, let's get the training and validation dataloaders + train_loader = get_covid_if_loader( path=data_path, patch_shape=patch_shape, batch_size=1, target="cells", num_workers=num_workers, shuffle=True, raw_transform=raw_transform, sampler=sampler, label_transform=label_transform, label_dtype=torch.float32, From df6a130e6b6f9417296176a67c778592e3caf24b Mon Sep 17 00:00:00 2001 From: Carolin Teuber Date: Fri, 21 Jun 2024 00:58:23 +0200 Subject: [PATCH 08/12] lora implementation in evaluation scripts --- finetuning/evaluation/evaluate_amg.py | 8 +++++--- .../evaluation/evaluate_instance_segmentation.py | 8 +++++--- finetuning/evaluation/iterative_prompting.py | 2 +- finetuning/evaluation/precompute_embeddings.py | 2 +- finetuning/evaluation/submit_all_evaluation.py | 11 ++++++++--- finetuning/evaluation/util.py | 6 ++++-- micro_sam/evaluation/inference.py | 8 ++++++-- micro_sam/instance_segmentation.py | 4 +++- micro_sam/util.py | 1 + 9 files changed, 34 insertions(+), 16 deletions(-) diff --git a/finetuning/evaluation/evaluate_amg.py b/finetuning/evaluation/evaluate_amg.py index 69ec63ef..8ddbe941 100644 --- a/finetuning/evaluation/evaluate_amg.py +++ b/finetuning/evaluation/evaluate_amg.py @@ -7,7 +7,7 @@ from util import get_pred_paths, get_default_arguments, VANILLA_MODELS -def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder): +def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, use_lora=False, rank=None): val_image_paths, val_gt_paths = get_paths(dataset_name, split="val") test_image_paths, _ = get_paths(dataset_name, split="test") prediction_folder = run_amg( @@ -16,7 +16,9 @@ def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder): experiment_folder, val_image_paths, val_gt_paths, - test_image_paths + test_image_paths, + use_lora=use_lora, + rank=rank ) return prediction_folder @@ -37,7 +39,7 @@ def main(): else: ckpt = args.checkpoint - prediction_folder = run_amg_inference(args.dataset, args.model, ckpt, args.experiment_folder) + prediction_folder = run_amg_inference(args.dataset, args.model, ckpt, args.experiment_folder, use_lora=args.use_lora, rank=args.lora_rank) eval_amg(args.dataset, prediction_folder, args.experiment_folder) diff --git a/finetuning/evaluation/evaluate_instance_segmentation.py 
b/finetuning/evaluation/evaluate_instance_segmentation.py index 70da7635..45bd8e8e 100644 --- a/finetuning/evaluation/evaluate_instance_segmentation.py +++ b/finetuning/evaluation/evaluate_instance_segmentation.py @@ -7,7 +7,7 @@ from util import get_pred_paths, get_default_arguments -def run_instance_segmentation_with_decoder_inference(dataset_name, model_type, checkpoint, experiment_folder): +def run_instance_segmentation_with_decoder_inference(dataset_name, model_type, checkpoint, experiment_folder, use_lora=False, rank=None): val_image_paths, val_gt_paths = get_paths(dataset_name, split="val") test_image_paths, _ = get_paths(dataset_name, split="test") prediction_folder = run_instance_segmentation_with_decoder( @@ -16,7 +16,9 @@ def run_instance_segmentation_with_decoder_inference(dataset_name, model_type, c experiment_folder, val_image_paths, val_gt_paths, - test_image_paths + test_image_paths, + use_lora=use_lora, + rank=rank ) return prediction_folder @@ -34,7 +36,7 @@ def main(): args = get_default_arguments() prediction_folder = run_instance_segmentation_with_decoder_inference( - args.dataset, args.model, args.checkpoint, args.experiment_folder + args.dataset, args.model, args.checkpoint, args.experiment_folder, use_lora=args.use_lora, rank=args.lora_rank ) eval_instance_segmentation_with_decoder(args.dataset, prediction_folder, args.experiment_folder) diff --git a/finetuning/evaluation/iterative_prompting.py b/finetuning/evaluation/iterative_prompting.py index 08c0cf3b..05b9abac 100644 --- a/finetuning/evaluation/iterative_prompting.py +++ b/finetuning/evaluation/iterative_prompting.py @@ -42,7 +42,7 @@ def main(): start_with_box_prompt = args.box # overwrite to start first iters' prompt with box instead of single point # get the predictor to perform inference - predictor = get_model(model_type=args.model, ckpt=args.checkpoint) + predictor = get_model(model_type=args.model, ckpt=args.checkpoint, use_lora=args.use_lora, rank=args.lora_rank) prediction_root = _run_iterative_prompting( args.dataset, args.experiment_folder, predictor, start_with_box_prompt, args.use_masks diff --git a/finetuning/evaluation/precompute_embeddings.py b/finetuning/evaluation/precompute_embeddings.py index 438cba59..c670b839 100644 --- a/finetuning/evaluation/precompute_embeddings.py +++ b/finetuning/evaluation/precompute_embeddings.py @@ -9,7 +9,7 @@ def main(): args = get_default_arguments() - predictor = get_model(model_type=args.model, ckpt=args.checkpoint) + predictor = get_model(model_type=args.model, ckpt=args.checkpoint, use_lora=args.use_lora, rank=args.lora_rank) embedding_dir = os.path.join(args.experiment_folder, "embeddings") os.makedirs(embedding_dir, exist_ok=True) diff --git a/finetuning/evaluation/submit_all_evaluation.py b/finetuning/evaluation/submit_all_evaluation.py index 20401f52..4a35a3fc 100644 --- a/finetuning/evaluation/submit_all_evaluation.py +++ b/finetuning/evaluation/submit_all_evaluation.py @@ -7,7 +7,7 @@ from datetime import datetime # Replace with the path to the experiments folder -ROOT = "/scratch/usr/nimcarot/sam/experiments/lora" +ROOT = "/scratch/usr/nimcarot/sam/experiments/dummy_directory" ALL_SCRIPTS = [ "precompute_embeddings", "evaluate_amg", "iterative_prompting", "evaluate_instance_segmentation" @@ -30,6 +30,7 @@ def write_batch_script( #SBATCH --constraint=80gb #SBATCH --qos=96h #SBATCH --job-name={inference_setup} +#SBATCH -x ggpu139 source ~/.bashrc mamba activate {env_name} \n""" @@ -58,6 +59,10 @@ def write_batch_script( # use logits for iterative 
prompting if inference_setup == "iterative_prompting" and use_masks: python_script += "--use_masks " + + if use_lora: + python_script += "--use_lora " + python_script += f"--lora_rank {lora_rank} " # let's add the python script to the bash script batch_script += python_script @@ -72,7 +77,7 @@ def write_batch_script( new_path = out_path[:-3] + f"_{inference_setup}_box.sh" with open(new_path, "w") as f: f.write(batch_script) - + print(batch_script) def get_batch_script_names(tmp_folder): tmp_folder = os.path.expanduser(tmp_folder) @@ -145,7 +150,7 @@ def submit_slurm(args): make_delay = "10s" # wait for precomputing the embeddings and later run inference scripts if args.checkpoint_path is None: - checkpoint = get_checkpoint_path(experiment_set, dataset_name, model_type, region) + checkpoint = get_checkpoint_path(experiment_set, dataset_name, model_type, region, lora=args.use_lora, rank=args.lora_rank) else: checkpoint = args.checkpoint_path diff --git a/finetuning/evaluation/util.py b/finetuning/evaluation/util.py index 8b1716e8..ed75dc5a 100644 --- a/finetuning/evaluation/util.py +++ b/finetuning/evaluation/util.py @@ -80,10 +80,10 @@ def get_dataset_paths(dataset_name, split_choice): return raw_dir, labels_dir -def get_model(model_type, ckpt): +def get_model(model_type, ckpt, use_lora=False, rank=None): if ckpt is None: ckpt = VANILLA_MODELS[model_type] - predictor = get_sam_model(model_type=model_type, checkpoint_path=ckpt) + predictor = get_sam_model(model_type=model_type, checkpoint_path=ckpt, use_lora=use_lora, rank=rank) return predictor @@ -226,6 +226,8 @@ def get_default_arguments(): parser.add_argument( "--use_masks", action="store_true", help="To use logits masks for iterative prompting." ) + parser.add_argument("--use_lora", action="store_true", help="Whether to use LoRA for finetuning.") + parser.add_argument("--lora_rank", type=int, default=4) args = parser.parse_args() return args diff --git a/micro_sam/evaluation/inference.py b/micro_sam/evaluation/inference.py index 1905fc77..e8dccdf9 100644 --- a/micro_sam/evaluation/inference.py +++ b/micro_sam/evaluation/inference.py @@ -547,11 +547,13 @@ def run_amg( test_image_paths: List[Union[str, os.PathLike]], iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None, + use_lora: bool = False, + rank: Optional[int] = None, ) -> str: embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved os.makedirs(embedding_folder, exist_ok=True) - predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint) + predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, use_lora=use_lora, rank=rank) amg = AutomaticMaskGenerator(predictor) amg_prefix = "amg" @@ -588,11 +590,13 @@ def run_instance_segmentation_with_decoder( val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], + use_lora: bool = False, + rank: Optional[int] = None, ) -> str: embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved os.makedirs(embedding_folder, exist_ok=True) - predictor, decoder = get_predictor_and_decoder(model_type=model_type, checkpoint_path=checkpoint) + predictor, decoder = get_predictor_and_decoder(model_type=model_type, checkpoint_path=checkpoint, use_lora=use_lora, rank=rank) segmenter = InstanceSegmentationWithDecoder(predictor, decoder) seg_prefix = 
"instance_segmentation_with_decoder" diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py index 23d666b9..148ed639 100644 --- a/micro_sam/instance_segmentation.py +++ b/micro_sam/instance_segmentation.py @@ -798,6 +798,8 @@ def get_predictor_and_decoder( model_type: str, checkpoint_path: Union[str, os.PathLike], device: Optional[Union[str, torch.device]] = None, + use_lora: bool = False, + rank: Optional[int] = None, ) -> Tuple[SamPredictor, DecoderAdapter]: """Load the SAM model (predictor) and instance segmentation decoder. @@ -816,7 +818,7 @@ def get_predictor_and_decoder( device = util.get_device(device) predictor, state = util.get_sam_model( model_type=model_type, checkpoint_path=checkpoint_path, - device=device, return_state=True + device=device, return_state=True, use_lora=use_lora, rank=rank, ) if "decoder_state" not in state: raise ValueError(f"The checkpoint at {checkpoint_path} does not contain a decoder state") diff --git a/micro_sam/util.py b/micro_sam/util.py index e61a28f7..30c97f72 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -354,6 +354,7 @@ def get_sam_model( if use_lora: # overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers from micro_sam.training.peft_sam import PEFT_Sam + print(use_lora, rank) if rank is None: rank = 4 # HACK: in case the user does not pass the rank, we provide a random rank to them sam = PEFT_Sam(sam, rank=rank).sam From bd0061179f3d163c9a543e39c98defc3157e501d Mon Sep 17 00:00:00 2001 From: Carolin Teuber Date: Sat, 22 Jun 2024 06:25:11 +0200 Subject: [PATCH 09/12] changed checkpoint handling in evaluation --- finetuning/evaluation/evaluate_amg.py | 4 ---- .../evaluation/submit_all_evaluation.py | 21 ++++++++++--------- finetuning/evaluation/util.py | 2 -- 3 files changed, 11 insertions(+), 16 deletions(-) diff --git a/finetuning/evaluation/evaluate_amg.py b/finetuning/evaluation/evaluate_amg.py index 8ddbe941..09e879ce 100644 --- a/finetuning/evaluation/evaluate_amg.py +++ b/finetuning/evaluation/evaluate_amg.py @@ -34,10 +34,6 @@ def eval_amg(dataset_name, prediction_folder, experiment_folder): def main(): args = get_default_arguments() - if args.checkpoint is None: - ckpt = VANILLA_MODELS[args.model] - else: - ckpt = args.checkpoint prediction_folder = run_amg_inference(args.dataset, args.model, ckpt, args.experiment_folder, use_lora=args.use_lora, rank=args.lora_rank) eval_amg(args.dataset, prediction_folder, args.experiment_folder) diff --git a/finetuning/evaluation/submit_all_evaluation.py b/finetuning/evaluation/submit_all_evaluation.py index 4a35a3fc..65878708 100644 --- a/finetuning/evaluation/submit_all_evaluation.py +++ b/finetuning/evaluation/submit_all_evaluation.py @@ -95,16 +95,17 @@ def get_batch_script_names(tmp_folder): def get_checkpoint_path(experiment_set, dataset_name, model_type, region, lora=False, rank=None): # let's set the experiment type - either using the generalist or just using vanilla model if experiment_set == "generalist": - checkpoint = f"/scratch/usr/nimanwai/micro-sam/checkpoints/{model_type}/" - - if region == "organelles": - checkpoint += "mito_nuc_em_generalist_sam/best.pt" - elif region == "boundaries": - checkpoint += "boundaries_em_generalist_sam/best.pt" - elif region == "lm": - checkpoint += "lm_generalist_sam/best.pt" - else: - raise ValueError("Choose `region` from lm / organelles / boundaries") + #checkpoint = f"/scratch/usr/nimanwai/micro-sam/checkpoints/{model_type}/" + #if region == "organelles": + # checkpoint 
+= "mito_nuc_em_generalist_sam/best.pt" + #elif region == "boundaries": + # checkpoint += "boundaries_em_generalist_sam/best.pt" + #elif region == "lm": + # checkpoint += "lm_generalist_sam/best.pt" + #else: + # raise ValueError("Choose `region` from lm / organelles / boundaries") + + checkpoint = None elif experiment_set == "specialist": _split = dataset_name.split("/") diff --git a/finetuning/evaluation/util.py b/finetuning/evaluation/util.py index ed75dc5a..543d8117 100644 --- a/finetuning/evaluation/util.py +++ b/finetuning/evaluation/util.py @@ -81,8 +81,6 @@ def get_dataset_paths(dataset_name, split_choice): def get_model(model_type, ckpt, use_lora=False, rank=None): - if ckpt is None: - ckpt = VANILLA_MODELS[model_type] predictor = get_sam_model(model_type=model_type, checkpoint_path=ckpt, use_lora=use_lora, rank=rank) return predictor From 313c0ed59c7a33aedb3bc436e0322eb1a5af6790 Mon Sep 17 00:00:00 2001 From: Carolin Teuber Date: Sat, 22 Jun 2024 06:59:32 +0200 Subject: [PATCH 10/12] corrected checkpoint argument in evaluate amg --- finetuning/evaluation/evaluate_amg.py | 2 +- finetuning/livecell/lora/train_livecell.py | 2 +- finetuning/specialists/lora/train_covid_if.py | 17 ++++++++++++++++- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/finetuning/evaluation/evaluate_amg.py b/finetuning/evaluation/evaluate_amg.py index 09e879ce..a50f58b1 100644 --- a/finetuning/evaluation/evaluate_amg.py +++ b/finetuning/evaluation/evaluate_amg.py @@ -35,7 +35,7 @@ def eval_amg(dataset_name, prediction_folder, experiment_folder): def main(): args = get_default_arguments() - prediction_folder = run_amg_inference(args.dataset, args.model, ckpt, args.experiment_folder, use_lora=args.use_lora, rank=args.lora_rank) + prediction_folder = run_amg_inference(args.dataset, args.model, args.checkpoint, args.experiment_folder, use_lora=args.use_lora, rank=args.lora_rank) eval_amg(args.dataset, prediction_folder, args.experiment_folder) diff --git a/finetuning/livecell/lora/train_livecell.py b/finetuning/livecell/lora/train_livecell.py index 65571109..fa379b73 100644 --- a/finetuning/livecell/lora/train_livecell.py +++ b/finetuning/livecell/lora/train_livecell.py @@ -109,7 +109,7 @@ def finetune_livecell(args): if not name.startswith("encoder"): joint_model_params.append(params) - optimizer = torch.optim.Adam(joint_model_params, lr=1e-5) + optimizer = torch.optim.AdamW(joint_model_params, lr=5e-5) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10) train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path) diff --git a/finetuning/specialists/lora/train_covid_if.py b/finetuning/specialists/lora/train_covid_if.py index fb39aa1c..2c3a5023 100644 --- a/finetuning/specialists/lora/train_covid_if.py +++ b/finetuning/specialists/lora/train_covid_if.py @@ -118,6 +118,21 @@ def finetune_livecell(args): use_skip_connection=False, resize_input=True, ) + + if checkpoint_path is not None: + import pickle + from micro_sam.util import _CustomUnpickler + custom_unpickle = pickle + custom_unpickle.Unpickler = _CustomUnpickler + + decoder_state = torch.load( + checkpoint_path, map_location="cpu", pickle_module=custom_unpickle + )["decoder_state"] + unetr_state_dict = unetr.state_dict() + for k, v in unetr_state_dict.items(): + if not k.startswith("encoder"): + unetr_state_dict[k] = decoder_state[k] + unetr.load_state_dict(unetr_state_dict) unetr.to(device) # let's check the total number of trainable parameters @@ -131,7 
+146,7 @@ def finetune_livecell(args): if not name.startswith("encoder"): joint_model_params.append(params) - optimizer = torch.optim.Adam(joint_model_params, lr=1e-5) + optimizer = torch.optim.AdamW(joint_model_params, lr=5e-5) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10) train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path) From cbf31ab0f01a45399c741db5708a3ec672bface1 Mon Sep 17 00:00:00 2001 From: Carolin Teuber Date: Sat, 22 Jun 2024 07:01:02 +0200 Subject: [PATCH 11/12] removed decoder initialization from covid_if training --- finetuning/specialists/lora/train_covid_if.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/finetuning/specialists/lora/train_covid_if.py b/finetuning/specialists/lora/train_covid_if.py index 2c3a5023..71470f66 100644 --- a/finetuning/specialists/lora/train_covid_if.py +++ b/finetuning/specialists/lora/train_covid_if.py @@ -119,22 +119,6 @@ def finetune_livecell(args): resize_input=True, ) - if checkpoint_path is not None: - import pickle - from micro_sam.util import _CustomUnpickler - custom_unpickle = pickle - custom_unpickle.Unpickler = _CustomUnpickler - - decoder_state = torch.load( - checkpoint_path, map_location="cpu", pickle_module=custom_unpickle - )["decoder_state"] - unetr_state_dict = unetr.state_dict() - for k, v in unetr_state_dict.items(): - if not k.startswith("encoder"): - unetr_state_dict[k] = decoder_state[k] - unetr.load_state_dict(unetr_state_dict) - unetr.to(device) - # let's check the total number of trainable parameters print(count_parameters(model)) From c2ef267e33206b71fe215fa0368421f4eb6a893c Mon Sep 17 00:00:00 2001 From: Carolin Teuber Date: Mon, 24 Jun 2024 13:42:03 +0200 Subject: [PATCH 12/12] changed optimizer and naming of functions --- finetuning/specialists/lora/train_covid_if.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/finetuning/specialists/lora/train_covid_if.py b/finetuning/specialists/lora/train_covid_if.py index 71470f66..7eb0e775 100644 --- a/finetuning/specialists/lora/train_covid_if.py +++ b/finetuning/specialists/lora/train_covid_if.py @@ -69,8 +69,8 @@ def count_parameters(model): return f"The number of trainable parameters for the provided model is {round(params, 2)}M" -def finetune_livecell(args): - """Code for finetuning SAM (using LoRA) on LIVECell +def finetune_covid_if(args): + """Code for finetuning SAM (using LoRA) on Covid IF Initial observations: There's no real memory advantage actually unless it's "truly" scaled up # vit_b @@ -130,7 +130,7 @@ def finetune_livecell(args): if not name.startswith("encoder"): joint_model_params.append(params) - optimizer = torch.optim.AdamW(joint_model_params, lr=5e-5) + optimizer = torch.optim.Adam(joint_model_params, lr=1e-5) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10) train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path) @@ -166,7 +166,7 @@ def finetune_livecell(args): ) trainer.fit(args.iterations) if args.export_path is not None: - checkpoint_path = os.path.join( + checkpoint_path = os.path.join(i "" if args.save_root is None else args.save_root, "checkpoints", args.name, "best.pt" ) export_custom_sam_model( @@ -177,10 +177,10 @@ def finetune_livecell(args): def main(): - parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.") + parser = 
argparse.ArgumentParser(description="Finetune Segment Anything for the CovidIF dataset.") parser.add_argument( "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/covid_if/", - help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded." + help="The filepath to the CovidIF data. If the data does not exist yet it will be downloaded." ) parser.add_argument( "--model_type", "-m", default="vit_b", @@ -209,7 +209,7 @@ def main(): "--lora_rank", type=int, default=4, help="Pass the rank for LoRA." ) args = parser.parse_args() - finetune_livecell(args) + finetune_covid_if(args) if __name__ == "__main__":
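Taken together, these patches thread the new --use_lora and --lora_rank options from the SLURM submission scripts (run_all_finetuning.py, submit_all_evaluation.py) through training (sam_training.get_trainable_sam_model) and evaluation (micro_sam.util.get_sam_model). The sketch below shows how the options are consumed; it is not part of the patch series, it assumes the micro_sam call signatures exactly as they appear in the diffs above, and the checkpoint path is a hypothetical example following the "<dataset>_lora_rank_<rank>" naming convention used in submit_all_evaluation.py.

# Minimal usage sketch (assumes the APIs as shown in the diffs above).
import torch
import micro_sam.training as sam_training
from micro_sam.util import get_sam_model

device = "cuda" if torch.cuda.is_available() else "cpu"

# Training: build a trainable SAM with LoRA enabled for finetuning.
model = sam_training.get_trainable_sam_model(
    model_type="vit_b",
    device=device,
    checkpoint_path=None,  # optionally resume from a custom checkpoint
    freeze=None,           # optionally freeze parts of the model
    use_lora=True,         # enable low-rank adaptation
    rank=4,                # LoRA rank, i.e. the value passed via --lora_rank
)

# Evaluation: load the finetuned checkpoint with the same LoRA settings.
# Hypothetical path following the naming convention from submit_all_evaluation.py.
checkpoint = "checkpoints/vit_b/livecell_lora_rank_4/best.pt"
predictor = get_sam_model(
    model_type="vit_b", checkpoint_path=checkpoint, use_lora=True, rank=4,
)

On the cluster, the same flags are forwarded by the generated batch scripts, e.g.
python finetuning/run_all_finetuning.py -e livecell_specialist -m vit_b --use_lora --lora_rank 4
(the example invocation added, and later removed again, as a trailing comment in run_all_finetuning.py).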