LoRA implementation in finetuning and evaluation #638

Open
wants to merge 13 commits into base: dev
12 changes: 5 additions & 7 deletions finetuning/evaluation/evaluate_amg.py
@@ -7,7 +7,7 @@
from util import get_pred_paths, get_default_arguments, VANILLA_MODELS


def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder):
def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, use_lora=False, rank=None):
val_image_paths, val_gt_paths = get_paths(dataset_name, split="val")
test_image_paths, _ = get_paths(dataset_name, split="test")
prediction_folder = run_amg(
Expand All @@ -16,7 +16,9 @@ def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder):
experiment_folder,
val_image_paths,
val_gt_paths,
test_image_paths
test_image_paths,
use_lora=use_lora,
rank=rank
)
return prediction_folder

@@ -32,12 +34,8 @@ def eval_amg(dataset_name, prediction_folder, experiment_folder):

def main():
args = get_default_arguments()
if args.checkpoint is None:
ckpt = VANILLA_MODELS[args.model]
else:
ckpt = args.checkpoint

prediction_folder = run_amg_inference(args.dataset, args.model, ckpt, args.experiment_folder)
prediction_folder = run_amg_inference(args.dataset, args.model, args.checkpoint, args.experiment_folder, use_lora=args.use_lora, rank=args.lora_rank)
eval_amg(args.dataset, prediction_folder, args.experiment_folder)


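The new keyword arguments simply thread through to run_amg. A minimal usage sketch (dataset name, checkpoint path, and experiment folder are hypothetical, and it assumes run_amg forwards use_lora/rank to the model loader as this diff intends):

# Sketch only; all concrete values below are placeholders.
prediction_folder = run_amg_inference(
    dataset_name="livecell",
    model_type="vit_b",
    checkpoint="checkpoints/vit_b/livecell_lora_rank_4/best.pt",
    experiment_folder="./experiments/amg_lora",
    use_lora=True,
    rank=4,  # must match the rank used during finetuning
)
eval_amg("livecell", prediction_folder, "./experiments/amg_lora")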
8 changes: 5 additions & 3 deletions finetuning/evaluation/evaluate_instance_segmentation.py
@@ -7,7 +7,7 @@
from util import get_pred_paths, get_default_arguments


def run_instance_segmentation_with_decoder_inference(dataset_name, model_type, checkpoint, experiment_folder):
def run_instance_segmentation_with_decoder_inference(dataset_name, model_type, checkpoint, experiment_folder, use_lora=False, rank=None):
val_image_paths, val_gt_paths = get_paths(dataset_name, split="val")
test_image_paths, _ = get_paths(dataset_name, split="test")
prediction_folder = run_instance_segmentation_with_decoder(
Expand All @@ -16,7 +16,9 @@ def run_instance_segmentation_with_decoder_inference(dataset_name, model_type, c
experiment_folder,
val_image_paths,
val_gt_paths,
test_image_paths
test_image_paths,
use_lora=use_lora,
rank=rank
)
return prediction_folder

@@ -34,7 +36,7 @@ def main():
args = get_default_arguments()

prediction_folder = run_instance_segmentation_with_decoder_inference(
args.dataset, args.model, args.checkpoint, args.experiment_folder
args.dataset, args.model, args.checkpoint, args.experiment_folder, use_lora=args.use_lora, rank=args.lora_rank
)
eval_instance_segmentation_with_decoder(args.dataset, prediction_folder, args.experiment_folder)

2 changes: 1 addition & 1 deletion finetuning/evaluation/iterative_prompting.py
@@ -42,7 +42,7 @@ def main():
start_with_box_prompt = args.box # overwrite to start first iters' prompt with box instead of single point

# get the predictor to perform inference
predictor = get_model(model_type=args.model, ckpt=args.checkpoint)
predictor = get_model(model_type=args.model, ckpt=args.checkpoint, use_lora=args.use_lora, rank=args.lora_rank)

prediction_root = _run_iterative_prompting(
args.dataset, args.experiment_folder, predictor, start_with_box_prompt, args.use_masks
2 changes: 1 addition & 1 deletion finetuning/evaluation/precompute_embeddings.py
@@ -9,7 +9,7 @@
def main():
args = get_default_arguments()

predictor = get_model(model_type=args.model, ckpt=args.checkpoint)
predictor = get_model(model_type=args.model, ckpt=args.checkpoint, use_lora=args.use_lora, rank=args.lora_rank)
embedding_dir = os.path.join(args.experiment_folder, "embeddings")
os.makedirs(embedding_dir, exist_ok=True)

59 changes: 40 additions & 19 deletions finetuning/evaluation/submit_all_evaluation.py
@@ -6,6 +6,8 @@
from pathlib import Path
from datetime import datetime

# Replace with the path to the experiments folder
ROOT = "/scratch/usr/nimcarot/sam/experiments/dummy_directory"

ALL_SCRIPTS = [
"precompute_embeddings", "evaluate_amg", "iterative_prompting", "evaluate_instance_segmentation"
@@ -14,7 +16,8 @@

def write_batch_script(
env_name, out_path, inference_setup, checkpoint, model_type,
experiment_folder, dataset_name, delay=None, use_masks=False
experiment_folder, dataset_name, delay=None, use_masks=False,
use_lora=False, lora_rank=None
):
"Writing scripts with different fold-trainings for micro-sam evaluation"
batch_script = f"""#!/bin/bash
@@ -23,10 +26,11 @@ def write_batch_script(
#SBATCH -t 4-00:00:00
#SBATCH -p grete:shared
#SBATCH -G A100:1
#SBATCH -A gzz0001
#SBATCH -A nim00007
#SBATCH --constraint=80gb
#SBATCH --qos=96h
#SBATCH --job-name={inference_setup}
#SBATCH -x ggpu139

source ~/.bashrc
mamba activate {env_name} \n"""
@@ -55,6 +59,10 @@ def write_batch_script(
# use logits for iterative prompting
if inference_setup == "iterative_prompting" and use_masks:
python_script += "--use_masks "

if use_lora:
python_script += "--use_lora "
python_script += f"--lora_rank {lora_rank} "

# let's add the python script to the bash script
batch_script += python_script
@@ -69,7 +77,7 @@
new_path = out_path[:-3] + f"_{inference_setup}_box.sh"
with open(new_path, "w") as f:
f.write(batch_script)

print(batch_script)

def get_batch_script_names(tmp_folder):
tmp_folder = os.path.expanduser(tmp_folder)
@@ -84,19 +92,20 @@ def get_batch_script_names(tmp_folder):
return batch_script


def get_checkpoint_path(experiment_set, dataset_name, model_type, region):
def get_checkpoint_path(experiment_set, dataset_name, model_type, region, lora=False, rank=None):
# let's set the experiment type - either using the generalist or just using vanilla model
if experiment_set == "generalist":
checkpoint = f"/scratch/usr/nimanwai/micro-sam/checkpoints/{model_type}/"

if region == "organelles":
checkpoint += "mito_nuc_em_generalist_sam/best.pt"
elif region == "boundaries":
checkpoint += "boundaries_em_generalist_sam/best.pt"
elif region == "lm":
checkpoint += "lm_generalist_sam/best.pt"
else:
raise ValueError("Choose `region` from lm / organelles / boundaries")
#checkpoint = f"/scratch/usr/nimanwai/micro-sam/checkpoints/{model_type}/"
#if region == "organelles":
# checkpoint += "mito_nuc_em_generalist_sam/best.pt"
#elif region == "boundaries":
# checkpoint += "boundaries_em_generalist_sam/best.pt"
#elif region == "lm":
# checkpoint += "lm_generalist_sam/best.pt"
#else:
# raise ValueError("Choose `region` from lm / organelles / boundaries")

checkpoint = None

elif experiment_set == "specialist":
_split = dataset_name.split("/")
@@ -112,7 +121,11 @@
if dataset_name.startswith("tissuenet"):
dataset_name = "tissuenet"

checkpoint = f"/scratch/usr/nimanwai/micro-sam/checkpoints/{model_type}/{dataset_name}_sam/best.pt"
if lora:
assert rank is not None, "Provide the rank for LoRA finetuning."
checkpoint = f"{ROOT}/checkpoints/{model_type}/{dataset_name}_lora_rank_{rank}/best.pt"
else:
checkpoint = f"{ROOT}/checkpoints/{model_type}/{dataset_name}_sam/best.pt"

elif experiment_set == "vanilla":
checkpoint = None
@@ -138,14 +151,15 @@ def submit_slurm(args):
make_delay = "10s" # wait for precomputing the embeddings and later run inference scripts

if args.checkpoint_path is None:
checkpoint = get_checkpoint_path(experiment_set, dataset_name, model_type, region)
checkpoint = get_checkpoint_path(experiment_set, dataset_name, model_type, region, lora=args.use_lora, rank=args.lora_rank)
else:
checkpoint = args.checkpoint_path

if args.experiment_path is None:
modality = region if region == "lm" else "em"
experiment_folder = "/scratch/projects/nim00007/sam/experiments/new_models/v3/"
experiment_folder += f"{experiment_set}/{modality}/{dataset_name}/{model_type}/"
# get the correct naming if lora finetuning was used
experiment_name = f"{experiment_set}_lora_{args.lora_rank}" if args.use_lora else experiment_set
experiment_folder = f"{ROOT}/{experiment_name}/{modality}/{dataset_name}/{model_type}/"
else:
experiment_folder = args.experiment_path

@@ -175,7 +189,9 @@
experiment_folder=experiment_folder,
dataset_name=dataset_name,
delay=None if current_setup == "precompute_embeddings" else make_delay,
use_masks=args.use_masks
use_masks=args.use_masks,
use_lora=args.use_lora,
lora_rank=args.lora_rank
)

# the logic below automates the process of first running the precomputation of embeddings, and only then inference.
@@ -220,5 +236,10 @@ def main(args):
# ask for a specific experiment
parser.add_argument("-s", "--specific_experiment", type=str, default=None)

# LoRA specific arguments
parser.add_argument("--lora_rank", type=int, default=4)
parser.add_argument("--use_lora", action="store_true", help="Whether to use LoRA for finetuning.")

args = parser.parse_args()
main(args)
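
For reference, the checkpoint and experiment-folder naming above resolves as follows for a hypothetical LiveCELL specialist finetuned with LoRA rank 4 (a sketch; ROOT is the placeholder defined at the top of the script, all other values are illustrative):

# Illustrative values only.
model_type, dataset_name, rank = "vit_b", "livecell", 4
checkpoint = f"{ROOT}/checkpoints/{model_type}/{dataset_name}_lora_rank_{rank}/best.pt"
# -> .../checkpoints/vit_b/livecell_lora_rank_4/best.pt
experiment_folder = f"{ROOT}/specialist_lora_{rank}/lm/{dataset_name}/{model_type}/"
# -> .../specialist_lora_4/lm/livecell/vit_b/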

8 changes: 4 additions & 4 deletions finetuning/evaluation/util.py
@@ -80,10 +80,8 @@ def get_dataset_paths(dataset_name, split_choice):
return raw_dir, labels_dir


def get_model(model_type, ckpt):
if ckpt is None:
ckpt = VANILLA_MODELS[model_type]
predictor = get_sam_model(model_type=model_type, checkpoint_path=ckpt)
def get_model(model_type, ckpt, use_lora=False, rank=None):
predictor = get_sam_model(model_type=model_type, checkpoint_path=ckpt, use_lora=use_lora, rank=rank)
return predictor


@@ -226,6 +224,8 @@ def get_default_arguments():
parser.add_argument(
"--use_masks", action="store_true", help="To use logits masks for iterative prompting."
)
parser.add_argument("--use_lora", action="store_true", help="Whether to use LoRA for finetuning.")
parser.add_argument("--lora_rank", type=int, default=4)
args = parser.parse_args()
return args

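Note that get_sam_model only accepts use_lora/rank together with the matching micro_sam changes. A slightly more defensive variant of the helper (a sketch, not part of this PR) would forward the LoRA kwargs only when requested:

from micro_sam.util import get_sam_model

def get_model(model_type, ckpt, use_lora=False, rank=None):
    # Pass the LoRA kwargs only when LoRA is requested, so the helper keeps
    # working against a micro_sam build whose get_sam_model lacks them (assumption).
    lora_kwargs = {"use_lora": True, "rank": rank} if use_lora else {}
    return get_sam_model(model_type=model_type, checkpoint_path=ckpt, **lora_kwargs)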
20 changes: 15 additions & 5 deletions finetuning/livecell/lora/train_livecell.py
@@ -74,15 +74,14 @@ def finetune_livecell(args):
patch_shape = (520, 704) # the patch shape for training
n_objects_per_batch = 5 # this is the number of objects per batch that will be sampled
freeze_parts = args.freeze # override this to freeze different parts of the model
rank = 4 # the rank

rank = args.lora_rank # the rank
# get the trainable segment anything model
model = sam_training.get_trainable_sam_model(
model_type=model_type,
device=device,
checkpoint_path=checkpoint_path,
freeze=freeze_parts,
use_lora=True,
use_lora=args.use_lora,
rank=rank,
)
model.to(device)
@@ -110,15 +109,20 @@
if not name.startswith("encoder"):
joint_model_params.append(params)

optimizer = torch.optim.Adam(joint_model_params, lr=1e-5)
optimizer = torch.optim.AdamW(joint_model_params, lr=5e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10)
train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path)

# this class creates all the training data for a batch (inputs, prompts and labels)
convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)
name = (
f"{args.model_type}/livecell_"
f"{f'lora_rank_{args.lora_rank}' if args.use_lora else 'sam'}"
)


trainer = sam_training.JointSamTrainer(
name="livecell_lora",
name=name,
save_root=args.save_root,
train_loader=train_loader,
val_loader=val_loader,
@@ -176,6 +180,12 @@ def main():
"--freeze", type=str, nargs="+", default=None,
help="Which parts of the model to freeze for finetuning."
)
parser.add_argument(
"--use_lora", action="store_true", help="Whether to use LoRA for finetuning."
)
parser.add_argument(
"--lora_rank", type=int, default=4, help="Pass the rank for LoRA."
)
args = parser.parse_args()
finetune_livecell(args)

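The run name above keeps LoRA and full-finetuning checkpoints in separate folders by encoding the rank. A small self-contained sketch of the names it produces (model type and rank are example values):

def run_name(model_type, use_lora, lora_rank):
    # Mirrors the trainer `name` built in train_livecell.py above.
    suffix = f"lora_rank_{lora_rank}" if use_lora else "sam"
    return f"{model_type}/livecell_{suffix}"

assert run_name("vit_b", True, 4) == "vit_b/livecell_lora_rank_4"
assert run_name("vit_b", False, 4) == "vit_b/livecell_sam"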
24 changes: 16 additions & 8 deletions finetuning/run_all_finetuning.py
@@ -3,6 +3,7 @@
import subprocess
from datetime import datetime

ROOT = "~/micro-sam/finetuning/"

N_OBJECTS = {
"vit_t": 50,
@@ -11,24 +12,22 @@
"vit_h": 25
}


def write_batch_script(out_path, _name, env_name, model_type, save_root):
def write_batch_script(out_path, _name, env_name, model_type, save_root, use_lora=False, lora_rank=4):
"Writing scripts with different micro-sam finetunings."
batch_script = f"""#!/bin/bash
#SBATCH -t 14-00:00:00
#SBATCH -t 4-00:00:00
#SBATCH --mem 64G
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH -p grete:shared
#SBATCH -G A100:1
#SBATCH -A nim00007
#SBATCH -c 16
#SBATCH --qos=14d
#SBATCH --qos=96h
#SBATCH --constraint=80gb
#SBATCH --job-name={os.path.split(_name)[-1]}

source activate {env_name} \n"""

# python script
python_script = f"python {_name}.py "

@@ -38,7 +37,9 @@ def write_batch_script(out_path, _name, env_name, model_type, save_root):
# name of the model configuration
python_script += f"-m {model_type} "

# choice of the number of objects
if use_lora:
python_script += f"--use_lora --lora_rank {lora_rank} "
# choice of the number of objects
python_script += f"--n_objects {N_OBJECTS[model_type[:5]]} "

# let's add the python script to the bash script
@@ -70,13 +71,15 @@ def submit_slurm(args):
tmp_folder = "./gpu_jobs"

script_combinations = {
"livecell_specialist": "livecell_finetuning",
"livecell_specialist": f"{ROOT}livecell/lora/train_livecell",
"deepbacs_specialist": "specialists/training/light_microscopy/deepbacs_finetuning",
"tissuenet_specialist": "specialists/training/light_microscopy/tissuenet_finetuning",
"plantseg_root_specialist": "specialists/training/light_microscopy/plantseg_root_finetuning",
"neurips_cellseg_specialist": "specialists/training/light_microscopy/neurips_cellseg_finetuning",
"dynamicnuclearnet_specialist": "specialists/training/light_microscopy/dynamicnuclearnet_finetuning",
"lm_generalist": "generalists/training/light_microscopy/train_lm_generalist",
"covid_if_generalist": f"{ROOT}/finetuning/specialists/lora/train_covid_if",
"mouse_embryo_generalist": f"{ROOT}/finetuning/specialists/lora/train_mouse_embryo",
"cremi_specialist": "specialists/training/electron_microscopy/boundaries/cremi_finetuning",
"asem_specialist": "specialists/training/electron_microscopy/organelles/asem_finetuning",
"em_mito_nuc_generalist": "generalists/training/electron_microscopy/mito_nuc/train_mito_nuc_em_generalist",
@@ -103,7 +106,9 @@
_name=script_name,
env_name="mobilesam" if model_type == "vit_t" else "sam",
model_type=model_type,
save_root=args.save_root
save_root=args.save_root,
use_lora=args.use_lora,
lora_rank=args.lora_rank
)


@@ -122,5 +127,8 @@ def main(args):
parser.add_argument("-e", "--experiment_name", type=str, default=None)
parser.add_argument("-s", "--save_root", type=str, default="/scratch/usr/nimanwai/micro-sam/")
parser.add_argument("-m", "--model_type", type=str, default=None)
parser.add_argument("--use_lora", action="store_true", help="Whether to use LoRA for finetuning.")
parser.add_argument("--lora_rank", type=int, default=4, help="Pass the rank for LoRA")
args = parser.parse_args()
main(args)
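
A hedged usage sketch for submitting a single LoRA finetuning job through the helper above (output path, environment name, and model type are assumptions; the call is meant to run inside this script, where write_batch_script and ROOT are defined):

write_batch_script(
    out_path="./gpu_jobs/livecell_lora.sh",        # hypothetical script location
    _name=f"{ROOT}livecell/lora/train_livecell",   # resolves to the LoRA training script
    env_name="sam",
    model_type="vit_b",
    save_root="/scratch/usr/nimanwai/micro-sam/",
    use_lora=True,
    lora_rank=8,
)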
