Revert "include lora in livecell training logic"
This reverts commit fadf211.
caroteu committed Jun 19, 2024
1 parent fadf211 · commit 837b70b
Showing 3 changed files with 25 additions and 56 deletions.
31 changes: 6 additions & 25 deletions finetuning/evaluation/submit_all_evaluation.py
@@ -6,7 +6,6 @@
from pathlib import Path
from datetime import datetime

ROOT = "/scratch/usr/nimcarot/sam/experiments/lora"

ALL_SCRIPTS = [
"precompute_embeddings", "evaluate_amg", "iterative_prompting", "evaluate_instance_segmentation"
@@ -15,8 +14,7 @@

def write_batch_script(
env_name, out_path, inference_setup, checkpoint, model_type,
experiment_folder, dataset_name, delay=None, use_masks=False,
use_lora=False, lora_rank=None
experiment_folder, dataset_name, delay=None, use_masks=False
):
"Writing scripts with different fold-trainings for micro-sam evaluation"
batch_script = f"""#!/bin/bash
@@ -25,7 +23,7 @@ def write_batch_script(
#SBATCH -t 4-00:00:00
#SBATCH -p grete:shared
#SBATCH -G A100:1
#SBATCH -A nim00007
#SBATCH -A gzz0001
#SBATCH --constraint=80gb
#SBATCH --qos=96h
#SBATCH --job-name={inference_setup}
@@ -58,9 +56,6 @@ def write_batch_script(
if inference_setup == "iterative_prompting" and use_masks:
python_script += "--use_masks "

# use lora if requested
if use_lora:
python_script += f"--use_lora --lora_rank {lora_rank}"
# let's add the python script to the bash script
batch_script += python_script

@@ -89,7 +84,7 @@ def get_batch_script_names(tmp_folder):
return batch_script


def get_checkpoint_path(experiment_set, dataset_name, model_type, region, lora=False, rank=None):
def get_checkpoint_path(experiment_set, dataset_name, model_type, region):
# let's set the experiment type - either using the generalist or just using vanilla model
if experiment_set == "generalist":
checkpoint = f"/scratch/usr/nimanwai/micro-sam/checkpoints/{model_type}/"
@@ -117,11 +112,7 @@ def get_checkpoint_path(experiment_set, dataset_name, model_type, region, lora=F
if dataset_name.startswith("tissuenet"):
dataset_name = "tissuenet"

if lora:
assert rank is not None, "Provide the rank for LoRA finetuning."
checkpoint_name = f"{ROOT}/checkpoints/{model_type}/{dataset_name}_lora_rank_{rank}/best.pt"
else:
checkpoint = f"{ROOT}/checkpoints/{model_type}/{dataset_name}_sam/best.pt"
checkpoint = f"/scratch/usr/nimanwai/micro-sam/checkpoints/{model_type}/{dataset_name}_sam/best.pt"

elif experiment_set == "vanilla":
checkpoint = None
@@ -153,9 +144,8 @@ def submit_slurm(args):

if args.experiment_path is None:
modality = region if region == "lm" else "em"
# get the correct naming if lora finetuning was used
experiment_name = f"{experiment_set}_lora_{args.lora_rank}" if args.use_lora else experiment_set
experiment_folder = f"{ROOT}/{experiment_name}/{modality}/{dataset_name}/{model_type}/"
experiment_folder = "/scratch/projects/nim00007/sam/experiments/new_models/v3/"
experiment_folder += f"{experiment_set}/{modality}/{dataset_name}/{model_type}/"
else:
experiment_folder = args.experiment_path

@@ -186,8 +176,6 @@
dataset_name=dataset_name,
delay=None if current_setup == "precompute_embeddings" else make_delay,
use_masks=args.use_masks
use_lora=args.use_lora,
lora_rank=args.lora_rank
)

# the logic below automates the process of first running the precomputation of embeddings, and only then inference.
@@ -232,12 +220,5 @@ def main(args):
# ask for a specific experiment
parser.add_argument("-s", "--specific_experiment", type=str, default=None)

# LoRA specific arguments
parser.add_argument("--lora_rank", type=int, default=4)
parser.add_argument("--use_lora", action="store_true", help="Whether to use LoRA for finetuning.")

args = parser.parse_args()
main(args)


# python ~/micro-sam/finetuning/evaluation/submit_all_evaluation.py -d livecell -m vit_b -e specialist -r lm
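For orientation, the restored default output layout in submit_all_evaluation.py (used when --experiment_path is not passed) is sketched below. The root path and folder pattern are taken from the hunk above; build_experiment_folder is an illustrative helper, not a function that exists in the script.

    def build_experiment_folder(experiment_set, region, dataset_name, model_type):
        # Restored behaviour: fixed project root, no f"{experiment_set}_lora_{rank}" naming anymore.
        modality = region if region == "lm" else "em"
        folder = "/scratch/projects/nim00007/sam/experiments/new_models/v3/"
        folder += f"{experiment_set}/{modality}/{dataset_name}/{model_type}/"
        return folder

    # Matches the example invocation above (-d livecell -m vit_b -e specialist -r lm):
    print(build_experiment_folder("specialist", "lm", "livecell", "vit_b"))
    # /scratch/projects/nim00007/sam/experiments/new_models/v3/specialist/lm/livecell/vit_b/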
18 changes: 4 additions & 14 deletions finetuning/livecell/lora/train_livecell.py
@@ -74,14 +74,15 @@ def finetune_livecell(args):
patch_shape = (520, 704) # the patch shape for training
n_objects_per_batch = 5 # this is the number of objects per batch that will be sampled
freeze_parts = args.freeze # override this to freeze different parts of the model
rank = args.lora_rank # the rank
rank = 4 # the rank

# get the trainable segment anything model
model = sam_training.get_trainable_sam_model(
model_type=model_type,
device=device,
checkpoint_path=checkpoint_path,
freeze=freeze_parts,
use_lora=args.use_lora,
use_lora=True,
rank=rank,
)
model.to(device)
@@ -115,14 +116,9 @@ def finetune_livecell(args):

# this class creates all the training data for a batch (inputs, prompts and labels)
convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)
name = (
f"{args.model_type}/livecell_"
f"{f'lora_rank_{args.lora_rank}' if args.use_lora else 'sam'}"
)


trainer = sam_training.JointSamTrainer(
name=name,
name="livecell_lora",
save_root=args.save_root,
train_loader=train_loader,
val_loader=val_loader,
@@ -180,12 +176,6 @@ def main():
"--freeze", type=str, nargs="+", default=None,
help="Which parts of the model to freeze for finetuning."
)
parser.add_argument(
"--use_lora", action="store_true", help="Whether to use LoRA for finetuning."
)
parser.add_argument(
"--lora_rank", type=int, default=4, help="Pass the rank for LoRA."
)
args = parser.parse_args()
finetune_livecell(args)

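The net effect of the changes to train_livecell.py is that LoRA is no longer optional: it is always enabled with a hard-coded rank of 4, and the run is always named "livecell_lora". Below is a minimal sketch of the post-revert model construction, using only the arguments visible in the hunks above; sam_training is assumed to be micro_sam.training imported under that alias, and the default model_type is an assumption, not shown in the diff.

    import torch
    import micro_sam.training as sam_training

    model_type = "vit_b"    # assumed default; not visible in the hunks
    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint_path = None  # start from the released SAM weights
    freeze_parts = None     # nothing frozen

    model = sam_training.get_trainable_sam_model(
        model_type=model_type,
        device=device,
        checkpoint_path=checkpoint_path,
        freeze=freeze_parts,
        use_lora=True,  # previously: use_lora=args.use_lora
        rank=4,         # previously: rank=args.lora_rank
    )
    model.to(device)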
32 changes: 15 additions & 17 deletions finetuning/run_all_finetuning.py
@@ -3,22 +3,27 @@
import subprocess
from datetime import datetime

ROOT = "~/micro-sam"

MODELS = ["vit_t", "vit_b", "vit_l", "vit_h"]
N_OBJECTS = {
"vit_t": 50,
"vit_b": 40,
"vit_l": 30,
"vit_h": 25
}

def write_batch_script(out_path, _name, env_name, model_type, save_root, use_lora=False, lora_rank=4):

def write_batch_script(out_path, _name, env_name, model_type, save_root):
"Writing scripts with different micro-sam finetunings."
batch_script = f"""#!/bin/bash
#SBATCH -t 4-00:00:00
#SBATCH -t 14-00:00:00
#SBATCH --mem 64G
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH -p grete:shared
#SBATCH -G A100:1
#SBATCH -A nim00007
#SBATCH -c 16
#SBATCH --qos=96h
#SBATCH --qos=14d
#SBATCH --constraint=80gb
#SBATCH --job-name={os.path.split(_name)[-1]}
@@ -33,8 +38,8 @@ def write_batch_script(out_path, _name, env_name, model_type, save_root, use_lor
# name of the model configuration
python_script += f"-m {model_type} "

if use_lora:
python_script += f"--use_lora --lora_rank {lora_rank} "
# choice of the number of objects
python_script += f"--n_objects {N_OBJECTS[model_type[:5]]} "

# let's add the python script to the bash script
batch_script += python_script
@@ -65,7 +70,7 @@ def submit_slurm(args):
tmp_folder = "./gpu_jobs"

script_combinations = {
"livecell_specialist": f"{ROOT}/finetuning/livecell/lora/train_livecell",
"livecell_specialist": "livecell_finetuning",
"deepbacs_specialist": "specialists/training/light_microscopy/deepbacs_finetuning",
"tissuenet_specialist": "specialists/training/light_microscopy/tissuenet_finetuning",
"plantseg_root_specialist": "specialists/training/light_microscopy/plantseg_root_finetuning",
@@ -85,7 +90,7 @@
experiments = [args.experiment_name]

if args.model_type is None:
models = MODELS
models = list(N_OBJECTS.keys())
else:
models = [args.model_type]

@@ -98,9 +103,7 @@
_name=script_name,
env_name="mobilesam" if model_type == "vit_t" else "sam",
model_type=model_type,
save_root=args.save_root,
use_lora=args.use_lora,
lora_rank=args.lora_rank
save_root=args.save_root
)


@@ -119,10 +122,5 @@ def main(args):
parser.add_argument("-e", "--experiment_name", type=str, default=None)
parser.add_argument("-s", "--save_root", type=str, default="/scratch/usr/nimanwai/micro-sam/")
parser.add_argument("-m", "--model_type", type=str, default=None)
parser.add_argument("--use_lora", action="store_true", help="Whether to use LoRA for finetuning.")
parser.add_argument("--lora_rank", type=int, default=4, help="Pass the rank for LoRA")
args = parser.parse_args()
main(args)


# python ~/micro-sam/finetuning/run_all_finetuning.py -e livecell_specialist -m vit_b -s /scratch/usr/nimcarot/lora/ --use_lora --lora_rank 4
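After the revert, run_all_finetuning.py no longer forwards --use_lora/--lora_rank; instead it derives the per-model --n_objects value from the N_OBJECTS table, and the default model list is simply N_OBJECTS.keys(). A small sketch of that restored piece; build_python_args is an illustrative helper, not a function in the script.

    N_OBJECTS = {"vit_t": 50, "vit_b": 40, "vit_l": 30, "vit_h": 25}

    def build_python_args(model_type):
        # model_type[:5] reduces variants such as "vit_b_lm" to the base architecture key.
        args = f"-m {model_type} "
        args += f"--n_objects {N_OBJECTS[model_type[:5]]} "
        return args

    print(build_python_args("vit_b"))  # -m vit_b --n_objects 40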
