From a7bcc5bae81d11c02c7445f372e86c9a6e8c6bb4 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Mon, 5 Feb 2024 17:13:56 +0100 Subject: [PATCH 01/10] Fix deepbacs + add different UNETR generalist evaluation experiments --- .../run_updated_unetr_evaluations.py | 60 +++++++++++++++++++ finetuning/evaluation/preprocess_datasets.py | 6 +- 2 files changed, 63 insertions(+), 3 deletions(-) create mode 100644 finetuning/evaluation/experiments/run_updated_unetr_evaluations.py diff --git a/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py b/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py new file mode 100644 index 00000000..789379f8 --- /dev/null +++ b/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py @@ -0,0 +1,60 @@ +import os +import re +import subprocess +from glob import glob + + +CMD = "python ../submit_all_evaluation.py " +CHECKPOINT_ROOT = "/scratch/usr/nimanwai/experiments/micro-sam/unetr-decoder-updates/" +EXPERIMENT_ROOT = "/scratch/projects/nim00007/sam/experiments/new_models/test/unetr-decoder-updates" + + +def run_eval_process(cmd): + proc = subprocess.Popen(cmd) + try: + outs, errs = proc.communicate(timeout=60) + except subprocess.TimeoutExpired: + proc.terminate() + outs, errs = proc.communicate() + + +def run_specific_experiment(dataset_name, model_type, setup): + all_checkpoint_dirs = sorted(glob(os.path.join(CHECKPOINT_ROOT, f"{setup}-*"))) + for checkpoint_dir in all_checkpoint_dirs: + checkpoint_path = os.path.join(checkpoint_dir, "checkpoints", model_type, "lm_generalist_sam", "best.pt") + + experiment_name = checkpoint_dir.split("/")[-1] + experiment_folder = os.path.join(EXPERIMENT_ROOT, experiment_name, dataset_name, model_type) + + cmd = CMD + f"-d {dataset_name} " + f"-m {model_type} " + "-e generalist " + cmd += f"--checkpoint_path {checkpoint_path} " + cmd += f"--experiment_path {experiment_folder}" + print(f"Running the command: {cmd} \n") + _cmd = re.split(r"\s", cmd) + # run_eval_process(_cmd) + + +def run_one_setup(all_dataset_list, all_model_list, setup): + for dataset_name in all_dataset_list: + for model_type in all_model_list: + run_specific_experiment(dataset_name=dataset_name, model_type=model_type, setup=setup) + breakpoint() + + +def for_all_lm(setup): + assert setup in ["conv-transpose", "bilinear"] + + # let's run for in-domain + run_one_setup( + all_dataset_list=["tissuenet", "deepbacs", "plantseg_root", "livecell", "neurips-cell-seg"], + all_model_list=["vit_t", "vit_b", "vit_l", "vit_h"], + setup=setup + ) + + +def main(): + for_all_lm("conv-transpose") + + +if __name__ == "__main__": + main() diff --git a/finetuning/evaluation/preprocess_datasets.py b/finetuning/evaluation/preprocess_datasets.py index 6f0fcbce..1bcebaa1 100644 --- a/finetuning/evaluation/preprocess_datasets.py +++ b/finetuning/evaluation/preprocess_datasets.py @@ -752,8 +752,8 @@ def neurips_raw_trafo(raw): def for_deepbacs(save_dir): "Move the datasets from the internal split (provided by default in deepbacs) to our `slices` logic" for split in ["val", "test"]: - image_paths = os.path.join(ROOT, "deepbacs", "mixed", split, "source", "*") - label_paths = os.path.join(ROOT, "deepbacs", "mixed", split, "target", "*") + image_paths = sorted(glob(os.path.join(ROOT, "deepbacs", "mixed", split, "source", "*"))) + label_paths = sorted(glob(os.path.join(ROOT, "deepbacs", "mixed", split, "target", "*"))) os.makedirs(os.path.join(save_dir, split, "raw"), exist_ok=True) os.makedirs(os.path.join(save_dir, split, "labels"), exist_ok=True) @@ -801,7 +801,7 @@ def main(): # let's ensure all the data is downloaded download_all_datasets(ROOT) - # now let's save the slices as tif + # now let's save the slices as tif preprocess_lm_datasets() preprocess_em_datasets() From 061c2469f67a65403de0f1fb041e92bcdcea7dda Mon Sep 17 00:00:00 2001 From: anwai98 Date: Mon, 5 Feb 2024 19:20:11 +0100 Subject: [PATCH 02/10] Minor fixes --- .../experiments/run_updated_unetr_evaluations.py | 8 ++++---- finetuning/evaluation/run_all_evaluations.py | 4 ++-- finetuning/evaluation/submit_all_evaluation.py | 9 ++++++++- finetuning/evaluation/util.py | 2 +- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py b/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py index 789379f8..3f650d41 100644 --- a/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py +++ b/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py @@ -4,7 +4,7 @@ from glob import glob -CMD = "python ../submit_all_evaluation.py " +CMD = "python submit_all_evaluation.py " CHECKPOINT_ROOT = "/scratch/usr/nimanwai/experiments/micro-sam/unetr-decoder-updates/" EXPERIMENT_ROOT = "/scratch/projects/nim00007/sam/experiments/new_models/test/unetr-decoder-updates" @@ -31,14 +31,13 @@ def run_specific_experiment(dataset_name, model_type, setup): cmd += f"--experiment_path {experiment_folder}" print(f"Running the command: {cmd} \n") _cmd = re.split(r"\s", cmd) - # run_eval_process(_cmd) + run_eval_process(_cmd) def run_one_setup(all_dataset_list, all_model_list, setup): for dataset_name in all_dataset_list: for model_type in all_model_list: run_specific_experiment(dataset_name=dataset_name, model_type=model_type, setup=setup) - breakpoint() def for_all_lm(setup): @@ -46,13 +45,14 @@ def for_all_lm(setup): # let's run for in-domain run_one_setup( - all_dataset_list=["tissuenet", "deepbacs", "plantseg_root", "livecell", "neurips-cell-seg"], + all_dataset_list=["tissuenet", "deepbacs", "plantseg/root", "livecell", "neurips-cell-seg"], all_model_list=["vit_t", "vit_b", "vit_l", "vit_h"], setup=setup ) def main(): + os.chdir("../") for_all_lm("conv-transpose") diff --git a/finetuning/evaluation/run_all_evaluations.py b/finetuning/evaluation/run_all_evaluations.py index 4fa5e1c1..184a7f14 100644 --- a/finetuning/evaluation/run_all_evaluations.py +++ b/finetuning/evaluation/run_all_evaluations.py @@ -31,7 +31,7 @@ def run_one_setup(all_dataset_list, all_model_list, all_experiment_set_list, roi def for_all_lm(): # let's run for in-domain run_one_setup( - all_dataset_list=["tissuenet", "deepbacs", "plantseg_root", "livecell"], + all_dataset_list=["tissuenet", "deepbacs", "plantseg/root", "livecell"], all_model_list=["vit_b", "vit_h"], all_experiment_set_list=["vanilla", "generalist", "specialist"], roi="lm" @@ -39,7 +39,7 @@ def for_all_lm(): # next, let's run for out-of-domain run_one_setup( - all_dataset_list=["covid_if", "plantseg_ovules", "hpa", "lizard", "mouse-embryo", "ctc", "neurips-cell-seg"], + all_dataset_list=["covid_if", "plantseg/ovules", "hpa", "lizard", "mouse-embryo", "ctc", "neurips-cell-seg"], all_model_list=["vit_b", "vit_h"], all_experiment_set_list=["vanilla", "generalist"], roi="lm" diff --git a/finetuning/evaluation/submit_all_evaluation.py b/finetuning/evaluation/submit_all_evaluation.py index d10a53e5..4b788ed6 100644 --- a/finetuning/evaluation/submit_all_evaluation.py +++ b/finetuning/evaluation/submit_all_evaluation.py @@ -130,9 +130,16 @@ def submit_slurm(args): all_setups = ["precompute_embeddings", "evaluate_amg", "iterative_prompting"] else: all_setups = ["precompute_embeddings", "evaluate_amg", "evaluate_instance_segmentation", "iterative_prompting"] + + # env name + if model_type == "vit_t": + env_name = "mobilesam" + else: + env_name = "sam" + for current_setup in all_setups: write_batch_script( - env_name="sam", + env_name=env_name, out_path=get_batch_script_names(tmp_folder), inference_setup=current_setup, checkpoint=checkpoint, diff --git a/finetuning/evaluation/util.py b/finetuning/evaluation/util.py index a29d13f5..52ba7bc7 100644 --- a/finetuning/evaluation/util.py +++ b/finetuning/evaluation/util.py @@ -86,7 +86,7 @@ def get_model(model_type, ckpt): def get_paths(dataset_name, split): - assert dataset_name in DATASETS + assert dataset_name in DATASETS, dataset_name if dataset_name == "livecell": return _get_livecell_paths(input_folder=os.path.join(ROOT, "livecell"), split=split) From d9add24fb7c0c22140314fb59579c63dda6317d3 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 6 Feb 2024 00:07:57 +0100 Subject: [PATCH 03/10] Fix livecell paths --- .../evaluation/experiments/run_updated_unetr_evaluations.py | 1 + finetuning/evaluation/util.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py b/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py index 3f650d41..c46d304b 100644 --- a/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py +++ b/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py @@ -38,6 +38,7 @@ def run_one_setup(all_dataset_list, all_model_list, setup): for dataset_name in all_dataset_list: for model_type in all_model_list: run_specific_experiment(dataset_name=dataset_name, model_type=model_type, setup=setup) + breakpoint() def for_all_lm(setup): diff --git a/finetuning/evaluation/util.py b/finetuning/evaluation/util.py index 52ba7bc7..3243e87c 100644 --- a/finetuning/evaluation/util.py +++ b/finetuning/evaluation/util.py @@ -89,7 +89,8 @@ def get_paths(dataset_name, split): assert dataset_name in DATASETS, dataset_name if dataset_name == "livecell": - return _get_livecell_paths(input_folder=os.path.join(ROOT, "livecell"), split=split) + image_paths, gt_paths = _get_livecell_paths(input_folder=os.path.join(ROOT, "livecell"), split=split) + return sorted(image_paths), sorted(gt_paths) image_dir, gt_dir = get_dataset_paths(dataset_name, split) image_paths = sorted(glob(os.path.join(image_dir))) From fd394cf565e6ee3cd81dbd39d578059bf56af594 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 6 Feb 2024 10:55:52 +0100 Subject: [PATCH 04/10] Add plotting scripts --- finetuning/evaluation/.gitignore | 2 +- .../run_updated_unetr_evaluations.py | 115 +++++++++++++++++- 2 files changed, 114 insertions(+), 3 deletions(-) diff --git a/finetuning/evaluation/.gitignore b/finetuning/evaluation/.gitignore index 1046cc82..e33609d2 100644 --- a/finetuning/evaluation/.gitignore +++ b/finetuning/evaluation/.gitignore @@ -1 +1 @@ -figures/* +*.png diff --git a/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py b/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py index c46d304b..d1f364fd 100644 --- a/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py +++ b/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py @@ -2,6 +2,11 @@ import re import subprocess from glob import glob +from pathlib import Path + +import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt CMD = "python submit_all_evaluation.py " @@ -52,9 +57,115 @@ def for_all_lm(setup): ) -def main(): +def _run_evaluations(): os.chdir("../") - for_all_lm("conv-transpose") + # for_all_lm("conv-transpose") + for_all_lm("bilinear") + + +def _get_plots(dataset_name, model_type): + experiment_dirs = sorted(glob(os.path.join(EXPERIMENT_ROOT, "*"))) + + # adding a fixed color palette to each experiments, for consistency in plotting the legends + palette = {"amg": "C0", "ais": "C1", "box": "C2", "i_b": "C3", "point": "C4", "i_p": "C5"} + + fig, ax = plt.subplots(1, len(experiment_dirs), figsize=(20, 10), sharex="col", sharey="row") + + for idx, _experiment_dir in enumerate(experiment_dirs): + all_result_paths = sorted(glob(os.path.join(_experiment_dir, dataset_name, model_type, "results", "*"))) + res_list_per_experiment = [] + for i, result_path in enumerate(all_result_paths): + # avoid using the grid-search parameters' files + if os.path.split(result_path)[-1].startswith("grid_search_"): + continue + + res = pd.read_csv(result_path) + setting_name = Path(result_path).stem + if setting_name == "amg" or setting_name.startswith("instance"): # saving results from amg or ais + res_df = pd.DataFrame( + { + "name": model_type, + "type": Path(result_path).stem if len(setting_name) == 3 else "ais", + "results": res.iloc[0]["msa"] + }, index=[i] + ) + else: # saving results from iterative prompting + prompt_name = Path(result_path).stem.split("_")[-1] + res_df = pd.concat( + [ + pd.DataFrame( + { + "name": model_type, + "type": prompt_name, + "results": res.iloc[0]["msa"] + }, index=[i] + ), + pd.DataFrame( + { + "name": model_type, + "type": f"i_{prompt_name[0]}", + "results": res.iloc[-1]["msa"] + }, index=[i] + ) + ] + ) + res_list_per_experiment.append(res_df) + + res_df_per_experiment = pd.concat(res_list_per_experiment, ignore_index=True) + + container = sns.barplot( + x="name", y="results", hue="type", data=res_df_per_experiment, ax=ax[idx], palette=palette + ) + ax[idx].set(xlabel="Experiments", ylabel="Segmentation Quality") + ax[idx].legend(title="Settings", bbox_to_anchor=(1, 1)) + + # adding the numbers over the barplots + for j in container.containers: + container.bar_label(j, fmt='%.2f') + + # titles for each subplot + ax[idx].title.set_text(_experiment_dir.split("/")[-1]) + + # here, we remove the legends for each subplot, and get one common legend for all + all_lines, all_labels = [], [] + for ax in fig.axes: + lines, labels = ax.get_legend_handles_labels() + for line, label in zip(lines, labels): + if label not in all_labels: + all_lines.append(line) + all_labels.append(label) + ax.get_legend().remove() + + fig.legend(all_lines, all_labels) + plt.show() + plt.tight_layout() + plt.subplots_adjust(top=0.90, right=0.95) + fig.suptitle(dataset_name, fontsize=20) + + save_path = f"figures/{dataset_name}/{model_type}.png" + + try: + plt.savefig(save_path) + except FileNotFoundError: + os.makedirs(os.path.split(save_path)[0], exist_ok=True) + plt.savefig(save_path) + + plt.close() + print(f"Plot saved at {save_path}") + + +def _get_all_plots(): + all_datasets = ["tissuenet", "deepbacs", "plantseg/root", "livecell", "neurips-cell-seg"] + all_models = ["vit_t", "vit_b", "vit_l", "vit_h"] + + for dataset_name in all_datasets: + for model_type in all_models: + _get_plots(dataset_name, model_type) + + +def main(): + # _run_evaluations() + _get_all_plots() if __name__ == "__main__": From f2baf5b5cd6f9b292d708479683fa7725bf9d2ba Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 6 Feb 2024 11:01:17 +0100 Subject: [PATCH 05/10] Update slurm constraint --- finetuning/evaluation/submit_all_evaluation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/finetuning/evaluation/submit_all_evaluation.py b/finetuning/evaluation/submit_all_evaluation.py index 4b788ed6..25cb18b6 100644 --- a/finetuning/evaluation/submit_all_evaluation.py +++ b/finetuning/evaluation/submit_all_evaluation.py @@ -17,6 +17,7 @@ def write_batch_script( #SBATCH -p grete:shared #SBATCH -G A100:1 #SBATCH -A gzz0001 +#SBATCH --constraint=80gb #SBATCH --job-name={inference_setup} source ~/.bashrc From dc9e5125a09c8382fc72c8aa99591d2a0695fda1 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 6 Feb 2024 11:02:15 +0100 Subject: [PATCH 06/10] Refactor scripts --- .../run_updated_unetr_evaluations.py | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py b/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py index d1f364fd..c25a4f0f 100644 --- a/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py +++ b/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py @@ -39,30 +39,6 @@ def run_specific_experiment(dataset_name, model_type, setup): run_eval_process(_cmd) -def run_one_setup(all_dataset_list, all_model_list, setup): - for dataset_name in all_dataset_list: - for model_type in all_model_list: - run_specific_experiment(dataset_name=dataset_name, model_type=model_type, setup=setup) - breakpoint() - - -def for_all_lm(setup): - assert setup in ["conv-transpose", "bilinear"] - - # let's run for in-domain - run_one_setup( - all_dataset_list=["tissuenet", "deepbacs", "plantseg/root", "livecell", "neurips-cell-seg"], - all_model_list=["vit_t", "vit_b", "vit_l", "vit_h"], - setup=setup - ) - - -def _run_evaluations(): - os.chdir("../") - # for_all_lm("conv-transpose") - for_all_lm("bilinear") - - def _get_plots(dataset_name, model_type): experiment_dirs = sorted(glob(os.path.join(EXPERIMENT_ROOT, "*"))) @@ -154,6 +130,30 @@ def _get_plots(dataset_name, model_type): print(f"Plot saved at {save_path}") +def run_one_setup(all_dataset_list, all_model_list, setup): + for dataset_name in all_dataset_list: + for model_type in all_model_list: + run_specific_experiment(dataset_name=dataset_name, model_type=model_type, setup=setup) + breakpoint() + + +def for_all_lm(setup): + assert setup in ["conv-transpose", "bilinear"] + + # let's run for in-domain + run_one_setup( + all_dataset_list=["tissuenet", "deepbacs", "plantseg/root", "livecell", "neurips-cell-seg"], + all_model_list=["vit_t", "vit_b", "vit_l", "vit_h"], + setup=setup + ) + + +def _run_evaluations(): + os.chdir("../") + # for_all_lm("conv-transpose") + for_all_lm("bilinear") + + def _get_all_plots(): all_datasets = ["tissuenet", "deepbacs", "plantseg/root", "livecell", "neurips-cell-seg"] all_models = ["vit_t", "vit_b", "vit_l", "vit_h"] From 32de6a6fb030b15abc1cb2ab9f073e1430d6833e Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 6 Feb 2024 11:10:00 +0100 Subject: [PATCH 07/10] Ignore amg figures --- .../evaluation/experiments/run_updated_unetr_evaluations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py b/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py index c25a4f0f..947b7003 100644 --- a/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py +++ b/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py @@ -52,7 +52,8 @@ def _get_plots(dataset_name, model_type): res_list_per_experiment = [] for i, result_path in enumerate(all_result_paths): # avoid using the grid-search parameters' files - if os.path.split(result_path)[-1].startswith("grid_search_"): + _tmp_check = os.path.split(result_path)[-1] + if _tmp_check.startswith("grid_search_") or _tmp_check.startswith("amg"): continue res = pd.read_csv(result_path) From 170d6dd2850f569c1cb70ec0eedad4c5674d099d Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 6 Feb 2024 19:43:42 +0100 Subject: [PATCH 08/10] Add amg --- .../evaluation/experiments/run_updated_unetr_evaluations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py b/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py index 947b7003..4414aa05 100644 --- a/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py +++ b/finetuning/evaluation/experiments/run_updated_unetr_evaluations.py @@ -53,7 +53,7 @@ def _get_plots(dataset_name, model_type): for i, result_path in enumerate(all_result_paths): # avoid using the grid-search parameters' files _tmp_check = os.path.split(result_path)[-1] - if _tmp_check.startswith("grid_search_") or _tmp_check.startswith("amg"): + if _tmp_check.startswith("grid_search_"): continue res = pd.read_csv(result_path) From e038720e3aa5dbf82f41119a5fa0965a649cb865 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 6 Feb 2024 20:11:45 +0100 Subject: [PATCH 09/10] Update lm generalist training --- .../training/light_microscopy/obtain_lm_datasets.py | 5 ++--- .../training/light_microscopy/train_lm_generalist.py | 4 +++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py index 474a913f..eff23d05 100644 --- a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py +++ b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py @@ -44,9 +44,8 @@ def get_concat_lm_datasets(input_path, patch_shape, split_choice): n_samples=1000 if split_choice == "train" else 100 ), datasets.get_livecell_dataset( - path=os.path.join(input_path, "livecell"), split=split_choice, patch_shape=patch_shape, - label_transform=label_transform, sampler=sampler, label_dtype=label_dtype, raw_transform=identity, - n_samples=1000 if split_choice == "train" else 100, download=True + path=os.path.join(input_path, "livecell"), split=split_choice, patch_shape=patch_shape, download=True, + label_transform=label_transform, sampler=sampler, label_dtype=label_dtype, raw_transform=identity ), datasets.get_deepbacs_dataset( path=os.path.join(input_path, "deepbacs"), split=split_choice, patch_shape=patch_shape, diff --git a/finetuning/generalists/training/light_microscopy/train_lm_generalist.py b/finetuning/generalists/training/light_microscopy/train_lm_generalist.py index 72f358be..4e43413e 100644 --- a/finetuning/generalists/training/light_microscopy/train_lm_generalist.py +++ b/finetuning/generalists/training/light_microscopy/train_lm_generalist.py @@ -41,7 +41,8 @@ def finetune_lm_generalist(args): use_sam_stats=True, final_activation="Sigmoid", use_skip_connection=False, - resize_input=True + resize_input=True, + use_conv_transpose=not args.use_bilinear ) unetr.to(device) @@ -121,6 +122,7 @@ def main(): "--save_every_kth_epoch", type=int, default=None, help="To save every kth epoch while fine-tuning. Expects an integer value." ) + parser.add_argument("--use_bilinear", action="store_true") args = parser.parse_args() finetune_lm_generalist(args) From 9c14e700fbd2b9c833ec8e2756ebdc97d29d20a2 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Thu, 8 Feb 2024 18:43:56 +0100 Subject: [PATCH 10/10] Update defaults for using conv transpose --- .../training/light_microscopy/train_lm_generalist.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/finetuning/generalists/training/light_microscopy/train_lm_generalist.py b/finetuning/generalists/training/light_microscopy/train_lm_generalist.py index 4e43413e..c693dd93 100644 --- a/finetuning/generalists/training/light_microscopy/train_lm_generalist.py +++ b/finetuning/generalists/training/light_microscopy/train_lm_generalist.py @@ -42,7 +42,7 @@ def finetune_lm_generalist(args): final_activation="Sigmoid", use_skip_connection=False, resize_input=True, - use_conv_transpose=not args.use_bilinear + use_conv_transpose=True ) unetr.to(device) @@ -122,7 +122,6 @@ def main(): "--save_every_kth_epoch", type=int, default=None, help="To save every kth epoch while fine-tuning. Expects an integer value." ) - parser.add_argument("--use_bilinear", action="store_true") args = parser.parse_args() finetune_lm_generalist(args)