From 39d37ea0fcb809b00a4dc106e5cd4127b8dfb95f Mon Sep 17 00:00:00 2001 From: anwai98 Date: Mon, 16 Oct 2023 17:34:27 +0200 Subject: [PATCH 1/7] Add monai unetr for livecell --- .../unetr/livecell_unetr.py | 129 +++++++++++------- .../boundary-unetr-monai-scratch-results.csv | 3 + ...oundary-unetr-torch-em-scratch-results.csv | 3 + ...foreground-unetr-monai-scratch-results.csv | 3 + ...eground-unetr-torch-em-scratch-results.csv | 3 + .../unetr/submit_training.py | 26 ++-- 6 files changed, 110 insertions(+), 57 deletions(-) create mode 100644 experiments/vision-transformer/unetr/results/boundary-unetr-monai-scratch-results.csv create mode 100644 experiments/vision-transformer/unetr/results/boundary-unetr-torch-em-scratch-results.csv create mode 100644 experiments/vision-transformer/unetr/results/foreground-unetr-monai-scratch-results.csv create mode 100644 experiments/vision-transformer/unetr/results/foreground-unetr-torch-em-scratch-results.csv diff --git a/experiments/vision-transformer/unetr/livecell_unetr.py b/experiments/vision-transformer/unetr/livecell_unetr.py index 0aff767f..bb951c56 100644 --- a/experiments/vision-transformer/unetr/livecell_unetr.py +++ b/experiments/vision-transformer/unetr/livecell_unetr.py @@ -12,25 +12,55 @@ import torch import torch_em -from torch_em.model import UNETR from torch_em.transform.raw import standardize from torch_em.transform.label import labels_to_binary from torch_em.data.datasets import get_livecell_loader from torch_em.util.prediction import predict_with_halo +def get_unetr_model( + model_name: str, + source_choice: str, + patch_shape: Tuple[int, int], + sam_initialization: bool, + output_channels: int +): + """Returns the expected UNETR model + """ + if source_choice == "torch-em": + from torch_em import model as torch_em_models + model = torch_em_models.UNETR( + encoder=model_name, out_channels=output_channels, + encoder_checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth" if sam_initialization else None + ) + + elif source_choice == "monai": + from monai.networks import nets as monai_models + model = monai_models.unetr.UNETR( + in_channels=1, + out_channels=output_channels, + img_size=patch_shape, + spatial_dims=2 + ) + model.out_channels = 2 # type: ignore + + else: + raise ValueError("The available UNETR models are either from \"torch-em\" or \"monai\", choose from them") + + return model + + def do_unetr_training( input_path: str, - model: UNETR, - model_name: str, + model, cell_types: List[str], patch_shape: Tuple[int, int], device: torch.device, save_root: str, iterations: int, - sam_initialization: bool + sam_initialization: bool, + source_choice: str ): - os.makedirs(input_path, exist_ok=True) train_loader = get_livecell_loader( path=input_path, split="train", @@ -53,14 +83,13 @@ def do_unetr_training( num_workers=8 ) - _name = "livecell-unetr" if cell_types is None else f"livecell-{cell_types}-unetr" - _save_root = os.path.join( - save_root, f"sam-{model_name}" if sam_initialization else "scratch" + save_root, + f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch" ) if save_root is not None else save_root trainer = torch_em.default_segmentation_trainer( - name=_name, + name=f"livecell-{cell_types}", model=model, train_loader=train_loader, val_loader=val_loader, @@ -78,24 +107,19 @@ def do_unetr_training( def do_unetr_inference( input_path: str, device: torch.device, - model: UNETR, + model, cell_types: List[str], - save_dir: str, + root_save_dir: str, sam_initialization: bool, - model_name: str, - save_root: str + save_root: str, + source_choice: str ): - _save_dir = os.path.join( - save_dir, - f"unetr-torch-em-sam-{model_name}" if sam_initialization else f"unetr-torch-em-scratch-{model_name}" - ) - for ctype in cell_types: test_img_dir = os.path.join(input_path, "images", "livecell_test_images", "*") model_ckpt = os.path.join(save_root, - f"sam-{model_name}" if sam_initialization else "scratch", - "checkpoints", f"livecell-{ctype}-unetr", "best.pt") + f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch", + "checkpoints", f"livecell-{ctype}", "best.pt") assert os.path.exists(model_ckpt) model.load_state_dict(torch.load(model_ckpt, map_location=torch.device('cpu'))["model_state"]) @@ -112,8 +136,8 @@ def do_unetr_inference( fg, bd = outputs[0, :, :], outputs[1, :, :] - fg_save_dir = os.path.join(_save_dir, f"src-{ctype}", "foreground") - bd_save_dir = os.path.join(_save_dir, f"src-{ctype}", "boundary") + fg_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "foreground") + bd_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "boundary") os.makedirs(fg_save_dir, exist_ok=True) os.makedirs(bd_save_dir, exist_ok=True) @@ -125,14 +149,10 @@ def do_unetr_inference( def do_unetr_evaluation( input_path: str, cell_types: List[str], - save_dir: str, - model_name: str, - sam_initialization: bool + root_save_dir: str, + sam_initialization: bool, + source_choice: str ): - root_save_dir = os.path.join( - save_dir, - f"unetr-torch-em-sam-{model_name}" if sam_initialization else f"unetr-torch-em-scratch-{model_name}" - ) fg_list, bd_list = [], [] for c1 in cell_types: @@ -170,36 +190,46 @@ def do_unetr_evaluation( csv_save_dir = "./results/" os.makedirs(csv_save_dir, exist_ok=True) - tmp_csv_name = f"sam-{model_name}" if sam_initialization else "scratch" - f_df_fg.to_csv(os.path.join(csv_save_dir, f"foreground-torch-em-unetr-{tmp_csv_name}-results.csv")) - f_df_bd.to_csv(os.path.join(csv_save_dir, f"boundary-torch-em-unetr-{tmp_csv_name}-results.csv")) + tmp_csv_name = f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch" + f_df_fg.to_csv(os.path.join(csv_save_dir, f"foreground-unetr-{tmp_csv_name}-results.csv")) + f_df_bd.to_csv(os.path.join(csv_save_dir, f"boundary-unetr-{tmp_csv_name}-results.csv")) def main(args): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - n_channels = 2 - model = UNETR( - encoder=args.model_name, out_channels=n_channels, - encoder_checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth" if args.do_sam_ini else None) - model.to(device) + patch_shape = (512, 512) + output_channels = 2 all_cell_types = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"] + model = get_unetr_model( + model_name=args.model_name, + source_choice=args.source_choice, + patch_shape=patch_shape, + sam_initialization=args.do_sam_ini, + output_channels=output_channels + ) + model.to(device) + if args.train: print("2d UNETR training on LiveCell dataset") do_unetr_training( input_path=args.input, model=model, - model_name=args.model_name, cell_types=args.cell_type, - patch_shape=(512, 512), + patch_shape=patch_shape, device=device, save_root=args.save_root, iterations=args.iterations, - sam_initialization=args.do_sam_ini + sam_initialization=args.do_sam_ini, + source_choice=args.source_choice ) + root_save_dir = os.path.join( + args.save_dir, + f"unetr-torch-em-sam-{args.model_name}" if args.do_sam_ini else f"unetr-torch-em-scratch-{args.model_name}" + ) + if args.predict: print("2d UNETR inference on LiveCell dataset") do_unetr_inference( @@ -207,19 +237,19 @@ def main(args): device=device, model=model, cell_types=all_cell_types, - save_dir=args.save_dir, + root_save_dir=root_save_dir, sam_initialization=args.do_sam_ini, - model_name=args.model_name, - save_root=args.save_root + save_root=args.save_root, + source_choice=args.source_choice ) if args.evaluate: print("2d UNETR evaluation on LiveCell dataset") do_unetr_evaluation( input_path=args.input, cell_types=all_cell_types, - save_dir=args.save_dir, - model_name=args.model_name, - sam_initialization=args.do_sam_ini + root_save_dir=root_save_dir, + sam_initialization=args.do_sam_ini, + source_choice=args.source_choice ) @@ -234,6 +264,9 @@ def main(args): parser.add_argument("--evaluate", action='store_true', help="Enables UNETR evaluation on LiveCell dataset") + parser.add_argument("--source_choice", type=str, default="torch-em", + help="The source where the model comes from, i.e. either torch-em / monai") + parser.add_argument("-m", "--model_name", type=str, default="vit_b", help="Name of the ViT to use as the encoder in UNETR") @@ -246,7 +279,7 @@ def main(args): parser.add_argument("-i", "--input", type=str, default="/scratch/usr/nimanwai/data/livecell", help="Path where the dataset already exists/will be downloaded by the dataloader") - parser.add_argument("-s", "--save_root", type=str, default="/scratch/usr/nimanwai/models/unetr/torch-em/", + parser.add_argument("-s", "--save_root", type=str, default="/scratch/usr/nimanwai/models/unetr/", help="Path where checkpoints and logs will be saved") parser.add_argument("--save_dir", type=str, default="/scratch/usr/nimanwai/predictions/unetr", diff --git a/experiments/vision-transformer/unetr/results/boundary-unetr-monai-scratch-results.csv b/experiments/vision-transformer/unetr/results/boundary-unetr-monai-scratch-results.csv new file mode 100644 index 00000000..24ed9619 --- /dev/null +++ b/experiments/vision-transformer/unetr/results/boundary-unetr-monai-scratch-results.csv @@ -0,0 +1,3 @@ +,CELL TYPE,A172,BT474,BV2,Huh7,MCF7,SHSY5Y,SkBr3,SKOV3 +0,A172,0.281576376639954,0.22489535508532418,0.3360854423939335,0.10534368422737121,0.34645559398247966,0.41675285531129036,0.3801486359254411,0.23643396618484006 +1,BT474,0.24166370863434725,0.1776787222538735,0.2932050490579082,0.09689262054684385,0.27523996350005353,0.3649604124636659,0.33871591704503684,0.21826910822283302 diff --git a/experiments/vision-transformer/unetr/results/boundary-unetr-torch-em-scratch-results.csv b/experiments/vision-transformer/unetr/results/boundary-unetr-torch-em-scratch-results.csv new file mode 100644 index 00000000..dc2dafb0 --- /dev/null +++ b/experiments/vision-transformer/unetr/results/boundary-unetr-torch-em-scratch-results.csv @@ -0,0 +1,3 @@ +,CELL TYPE,A172,BT474,BV2,Huh7,MCF7,SHSY5Y,SkBr3,SKOV3 +0,A172,0.1850161673724403,0.11711088107632658,0.15398026900541587,0.05887791133257723,0.19190126929509818,0.2730259209313969,0.1751961078140095,0.17032320672603893 +1,BT474,0.1850161673724403,0.11711088107632658,0.15398026900541587,0.05887791133257723,0.19190126929509818,0.2730259209313969,0.1751961078140095,0.17032320672603893 diff --git a/experiments/vision-transformer/unetr/results/foreground-unetr-monai-scratch-results.csv b/experiments/vision-transformer/unetr/results/foreground-unetr-monai-scratch-results.csv new file mode 100644 index 00000000..80bd6652 --- /dev/null +++ b/experiments/vision-transformer/unetr/results/foreground-unetr-monai-scratch-results.csv @@ -0,0 +1,3 @@ +,CELL TYPE,A172,BT474,BV2,Huh7,MCF7,SHSY5Y,SkBr3,SKOV3 +0,A172,0.6788825592108618,0.39103204419936854,0.309162440868368,0.32081925317932786,0.506118775058835,0.49141294985601713,0.49570195994360167,0.7406890203296562 +1,BT474,0.7359839660952532,0.5335655032637894,0.3564281709945332,0.36817940042036595,0.6213287057364375,0.5324793139057482,0.5691675983852051,0.742574351912725 diff --git a/experiments/vision-transformer/unetr/results/foreground-unetr-torch-em-scratch-results.csv b/experiments/vision-transformer/unetr/results/foreground-unetr-torch-em-scratch-results.csv new file mode 100644 index 00000000..0359f011 --- /dev/null +++ b/experiments/vision-transformer/unetr/results/foreground-unetr-torch-em-scratch-results.csv @@ -0,0 +1,3 @@ +,CELL TYPE,A172,BT474,BV2,Huh7,MCF7,SHSY5Y,SkBr3,SKOV3 +0,A172,0.6786707807600854,0.3910081446358929,0.30896940855913735,0.3208134930311804,0.5058459829124426,0.491059230290008,0.49560698899107275,0.7406748006271707 +1,BT474,0.6786707807600854,0.3910081446358929,0.30896940855913735,0.3208134930311804,0.5058459829124426,0.491059230290008,0.49560698899107275,0.7406748006271707 diff --git a/experiments/vision-transformer/unetr/submit_training.py b/experiments/vision-transformer/unetr/submit_training.py index fc17e299..24de0aa0 100644 --- a/experiments/vision-transformer/unetr/submit_training.py +++ b/experiments/vision-transformer/unetr/submit_training.py @@ -7,8 +7,15 @@ from datetime import datetime -def write_batch_script(out_path, ini_sam=False): - cell_types = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"] +def write_batch_script(out_path, ini_sam=False, source_choice="torch-em"): + """ + inputs: + source_choice:str - [torch_em / monai] source of the unetr model coming from + ini_sam: bool - initialize torch-em's unetr implementation with sam encoder weights + """ + # ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"] + cell_types = ["A172", "BT474"] + for i, ctype in enumerate(cell_types): batch_script = """#!/bin/bash #SBATCH -t 2-00:00:00 @@ -20,22 +27,23 @@ def write_batch_script(out_path, ini_sam=False): #SBATCH -A gzz0001 """ if ini_sam: - batch_script += f"#SBATCH --job-name=unetr-sam-torch-em-{ctype}" + batch_script += f"#SBATCH --job-name=unetr-sam-{source_choice}-{ctype}" else: - batch_script += f"#SBATCH --job-name=unetr-torch-em-{ctype}" + batch_script += f"#SBATCH --job-name=unetr-{source_choice}-{ctype}" + + env_name = "monai2" if source_choice == "monai" else "sam" - batch_script += """ + batch_script += f""" source ~/.bashrc -mamba activate sam +mamba activate {env_name} python livecell_unetr.py --train """ add_ctype = f"-c {ctype} " - add_input_path = "-i /scratch/usr/nimanwai/data/livecell/ " - add_save_root = "-s /scratch/usr/nimanwai/models/unetr/torch-em/ " add_sam_ini = "--do_sam_ini " + add_source_choice = f"--source_choice {source_choice}" - batch_script += add_ctype + add_input_path + add_save_root + batch_script += add_ctype + add_source_choice if ini_sam: batch_script += add_sam_ini From 89b5a5c25e796c7f6692dc283ef87a0c343994c9 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 17 Oct 2023 16:06:48 +0200 Subject: [PATCH 2/7] Some clean up in livecell-unetr --- .../vision-transformer/unetr/livecell_unetr.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/experiments/vision-transformer/unetr/livecell_unetr.py b/experiments/vision-transformer/unetr/livecell_unetr.py index bb951c56..11d38c12 100644 --- a/experiments/vision-transformer/unetr/livecell_unetr.py +++ b/experiments/vision-transformer/unetr/livecell_unetr.py @@ -61,6 +61,9 @@ def do_unetr_training( sam_initialization: bool, source_choice: str ): + print("!!! Run training for cell types:") + print(cell_types) + print("!!!!!!!!!!!!!!!!!!!!!!!!!1") train_loader = get_livecell_loader( path=input_path, split="train", @@ -127,12 +130,14 @@ def do_unetr_inference( model.eval() with torch.no_grad(): - for img_path in glob(test_img_dir): + for img_path in tqdm(glob(test_img_dir), desc=f"Run inference for {ctype} with model {model_ckpt}"): fname = os.path.split(img_path)[-1] input_img = imageio.imread(img_path) input_img = standardize(input_img) - outputs = predict_with_halo(input_img, model, gpu_ids=[device], block_shape=[384, 384], halo=[64, 64]) + outputs = predict_with_halo( + input_img, model, gpu_ids=[device], block_shape=[384, 384], halo=[64, 64], disable_tqdm=True + ) fg, bd = outputs[0, :, :], outputs[1, :, :] @@ -157,6 +162,9 @@ def do_unetr_evaluation( for c1 in cell_types: _save_dir = os.path.join(root_save_dir, f"src-{c1}") + if not os.path.exists(_save_dir): + print("Skipping", _save_dir) + continue fg_set, bd_set = {"CELL TYPE": c1}, {"CELL TYPE": c1} for c2 in tqdm(cell_types, desc=f"Evaluation on {c1} source models"): @@ -193,6 +201,9 @@ def do_unetr_evaluation( tmp_csv_name = f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch" f_df_fg.to_csv(os.path.join(csv_save_dir, f"foreground-unetr-{tmp_csv_name}-results.csv")) f_df_bd.to_csv(os.path.join(csv_save_dir, f"boundary-unetr-{tmp_csv_name}-results.csv")) + print(csv_save_dir) + print(f_df_fg) + print(f_df_bd) def main(args): @@ -225,6 +236,7 @@ def main(args): source_choice=args.source_choice ) + # FIXME this is wrong for the MONAI models root_save_dir = os.path.join( args.save_dir, f"unetr-torch-em-sam-{args.model_name}" if args.do_sam_ini else f"unetr-torch-em-scratch-{args.model_name}" From e2dde0ae6aa9cf61837e1b4a62f56b5e8f3d30af Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 17 Oct 2023 16:53:54 +0200 Subject: [PATCH 3/7] Update save directory for monai inference --- experiments/vision-transformer/unetr/livecell_unetr.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/experiments/vision-transformer/unetr/livecell_unetr.py b/experiments/vision-transformer/unetr/livecell_unetr.py index 11d38c12..234740f5 100644 --- a/experiments/vision-transformer/unetr/livecell_unetr.py +++ b/experiments/vision-transformer/unetr/livecell_unetr.py @@ -130,7 +130,7 @@ def do_unetr_inference( model.eval() with torch.no_grad(): - for img_path in tqdm(glob(test_img_dir), desc=f"Run inference for {ctype} with model {model_ckpt}"): + for img_path in tqdm(glob(test_img_dir), desc=f"Run inference for all livecell with model {model_ckpt}"): fname = os.path.split(img_path)[-1] input_img = imageio.imread(img_path) @@ -236,10 +236,9 @@ def main(args): source_choice=args.source_choice ) - # FIXME this is wrong for the MONAI models root_save_dir = os.path.join( args.save_dir, - f"unetr-torch-em-sam-{args.model_name}" if args.do_sam_ini else f"unetr-torch-em-scratch-{args.model_name}" + f"unetr-{args.source_choice}-sam" if args.do_sam_ini else f"unetr-{args.source_choice}-scratch" ) if args.predict: From cf76cae74270731f06cdd8d827055806dc665672 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 17 Oct 2023 18:49:43 +0200 Subject: [PATCH 4/7] Fix issue in metric application for unetr experiments --- .../unetr/livecell_unetr.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/experiments/vision-transformer/unetr/livecell_unetr.py b/experiments/vision-transformer/unetr/livecell_unetr.py index 234740f5..e53a8b98 100644 --- a/experiments/vision-transformer/unetr/livecell_unetr.py +++ b/experiments/vision-transformer/unetr/livecell_unetr.py @@ -13,7 +13,6 @@ import torch import torch_em from torch_em.transform.raw import standardize -from torch_em.transform.label import labels_to_binary from torch_em.data.datasets import get_livecell_loader from torch_em.util.prediction import predict_with_halo @@ -167,7 +166,7 @@ def do_unetr_evaluation( continue fg_set, bd_set = {"CELL TYPE": c1}, {"CELL TYPE": c1} - for c2 in tqdm(cell_types, desc=f"Evaluation on {c1} source models"): + for c2 in tqdm(cell_types, desc=f"Evaluation on {c1} source models from {_save_dir}"): fg_dir = os.path.join(_save_dir, "foreground") bd_dir = os.path.join(_save_dir, "boundary") @@ -180,11 +179,19 @@ def do_unetr_evaluation( fg = imageio.imread(os.path.join(fg_dir, fname)) bd = imageio.imread(os.path.join(bd_dir, fname)) - true_fg = labels_to_binary(gt) true_bd = find_boundaries(gt) - cwise_fg.append(dice_score(fg, true_fg, threshold_gt=0)) - cwise_bd.append(dice_score(bd, true_bd, threshold_gt=0)) + # Compare the foreground prediction to the ground-truth. + # Here, it's important not to threshold the segmentation. Otherwise EVERYTHING will be set to + # foreground in the dice function, since we have a comparision > 0 in there, and everything in the + # binary prediction evaluates to true. + # For the GT we can set the threshold to 0, because this will map to the correct binary mask. + cwise_fg.append(dice_score(fg, gt, threshold_gt=0, threshold_seg=None)) + + # Compare the background prediction to the ground-truth. + # Here, we don't need any thresholds: for the prediction the same holds as before. + # For the ground-truth we have already a binary label, so we don't need to threshold it again. + cwise_bd.append(dice_score(bd, true_bd, threshold_gt=None, threshold_seg=None)) fg_set[c2] = np.mean(cwise_fg) bd_set[c2] = np.mean(cwise_bd) @@ -201,9 +208,7 @@ def do_unetr_evaluation( tmp_csv_name = f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch" f_df_fg.to_csv(os.path.join(csv_save_dir, f"foreground-unetr-{tmp_csv_name}-results.csv")) f_df_bd.to_csv(os.path.join(csv_save_dir, f"boundary-unetr-{tmp_csv_name}-results.csv")) - print(csv_save_dir) print(f_df_fg) - print(f_df_bd) def main(args): @@ -240,6 +245,7 @@ def main(args): args.save_dir, f"unetr-{args.source_choice}-sam" if args.do_sam_ini else f"unetr-{args.source_choice}-scratch" ) + print("Predictions are saved in", root_save_dir) if args.predict: print("2d UNETR inference on LiveCell dataset") @@ -253,6 +259,7 @@ def main(args): save_root=args.save_root, source_choice=args.source_choice ) + if args.evaluate: print("2d UNETR evaluation on LiveCell dataset") do_unetr_evaluation( From 3dc4a939f8cc3e7efe22d35b165dba81b348b51d Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 17 Oct 2023 18:51:41 +0200 Subject: [PATCH 5/7] Remove unnecessary prints --- experiments/vision-transformer/unetr/livecell_unetr.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/experiments/vision-transformer/unetr/livecell_unetr.py b/experiments/vision-transformer/unetr/livecell_unetr.py index e53a8b98..60367a9f 100644 --- a/experiments/vision-transformer/unetr/livecell_unetr.py +++ b/experiments/vision-transformer/unetr/livecell_unetr.py @@ -60,9 +60,7 @@ def do_unetr_training( sam_initialization: bool, source_choice: str ): - print("!!! Run training for cell types:") - print(cell_types) - print("!!!!!!!!!!!!!!!!!!!!!!!!!1") + print("Run training for cell types:", cell_types) train_loader = get_livecell_loader( path=input_path, split="train", @@ -208,7 +206,6 @@ def do_unetr_evaluation( tmp_csv_name = f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch" f_df_fg.to_csv(os.path.join(csv_save_dir, f"foreground-unetr-{tmp_csv_name}-results.csv")) f_df_bd.to_csv(os.path.join(csv_save_dir, f"boundary-unetr-{tmp_csv_name}-results.csv")) - print(f_df_fg) def main(args): From c136ac064f1fd9027c7d484afb4f987b12d0e08e Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 17 Oct 2023 23:34:10 +0200 Subject: [PATCH 6/7] Remove evaluation results --- .../unetr/results/boundary-unetr-monai-scratch-results.csv | 3 --- .../unetr/results/boundary-unetr-torch-em-scratch-results.csv | 3 --- .../unetr/results/foreground-unetr-monai-scratch-results.csv | 3 --- .../results/foreground-unetr-torch-em-scratch-results.csv | 3 --- 4 files changed, 12 deletions(-) delete mode 100644 experiments/vision-transformer/unetr/results/boundary-unetr-monai-scratch-results.csv delete mode 100644 experiments/vision-transformer/unetr/results/boundary-unetr-torch-em-scratch-results.csv delete mode 100644 experiments/vision-transformer/unetr/results/foreground-unetr-monai-scratch-results.csv delete mode 100644 experiments/vision-transformer/unetr/results/foreground-unetr-torch-em-scratch-results.csv diff --git a/experiments/vision-transformer/unetr/results/boundary-unetr-monai-scratch-results.csv b/experiments/vision-transformer/unetr/results/boundary-unetr-monai-scratch-results.csv deleted file mode 100644 index 24ed9619..00000000 --- a/experiments/vision-transformer/unetr/results/boundary-unetr-monai-scratch-results.csv +++ /dev/null @@ -1,3 +0,0 @@ -,CELL TYPE,A172,BT474,BV2,Huh7,MCF7,SHSY5Y,SkBr3,SKOV3 -0,A172,0.281576376639954,0.22489535508532418,0.3360854423939335,0.10534368422737121,0.34645559398247966,0.41675285531129036,0.3801486359254411,0.23643396618484006 -1,BT474,0.24166370863434725,0.1776787222538735,0.2932050490579082,0.09689262054684385,0.27523996350005353,0.3649604124636659,0.33871591704503684,0.21826910822283302 diff --git a/experiments/vision-transformer/unetr/results/boundary-unetr-torch-em-scratch-results.csv b/experiments/vision-transformer/unetr/results/boundary-unetr-torch-em-scratch-results.csv deleted file mode 100644 index dc2dafb0..00000000 --- a/experiments/vision-transformer/unetr/results/boundary-unetr-torch-em-scratch-results.csv +++ /dev/null @@ -1,3 +0,0 @@ -,CELL TYPE,A172,BT474,BV2,Huh7,MCF7,SHSY5Y,SkBr3,SKOV3 -0,A172,0.1850161673724403,0.11711088107632658,0.15398026900541587,0.05887791133257723,0.19190126929509818,0.2730259209313969,0.1751961078140095,0.17032320672603893 -1,BT474,0.1850161673724403,0.11711088107632658,0.15398026900541587,0.05887791133257723,0.19190126929509818,0.2730259209313969,0.1751961078140095,0.17032320672603893 diff --git a/experiments/vision-transformer/unetr/results/foreground-unetr-monai-scratch-results.csv b/experiments/vision-transformer/unetr/results/foreground-unetr-monai-scratch-results.csv deleted file mode 100644 index 80bd6652..00000000 --- a/experiments/vision-transformer/unetr/results/foreground-unetr-monai-scratch-results.csv +++ /dev/null @@ -1,3 +0,0 @@ -,CELL TYPE,A172,BT474,BV2,Huh7,MCF7,SHSY5Y,SkBr3,SKOV3 -0,A172,0.6788825592108618,0.39103204419936854,0.309162440868368,0.32081925317932786,0.506118775058835,0.49141294985601713,0.49570195994360167,0.7406890203296562 -1,BT474,0.7359839660952532,0.5335655032637894,0.3564281709945332,0.36817940042036595,0.6213287057364375,0.5324793139057482,0.5691675983852051,0.742574351912725 diff --git a/experiments/vision-transformer/unetr/results/foreground-unetr-torch-em-scratch-results.csv b/experiments/vision-transformer/unetr/results/foreground-unetr-torch-em-scratch-results.csv deleted file mode 100644 index 0359f011..00000000 --- a/experiments/vision-transformer/unetr/results/foreground-unetr-torch-em-scratch-results.csv +++ /dev/null @@ -1,3 +0,0 @@ -,CELL TYPE,A172,BT474,BV2,Huh7,MCF7,SHSY5Y,SkBr3,SKOV3 -0,A172,0.6786707807600854,0.3910081446358929,0.30896940855913735,0.3208134930311804,0.5058459829124426,0.491059230290008,0.49560698899107275,0.7406748006271707 -1,BT474,0.6786707807600854,0.3910081446358929,0.30896940855913735,0.3208134930311804,0.5058459829124426,0.491059230290008,0.49560698899107275,0.7406748006271707 From c2578d89154bc2389381424a520279d36aebf540 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Wed, 18 Oct 2023 10:42:13 +0200 Subject: [PATCH 7/7] Update training scripts --- experiments/vision-transformer/unetr/livecell_unetr.py | 4 ++-- .../vision-transformer/unetr/submit_training.py | 10 ++++------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/experiments/vision-transformer/unetr/livecell_unetr.py b/experiments/vision-transformer/unetr/livecell_unetr.py index 60367a9f..fd883c81 100644 --- a/experiments/vision-transformer/unetr/livecell_unetr.py +++ b/experiments/vision-transformer/unetr/livecell_unetr.py @@ -44,7 +44,7 @@ def get_unetr_model( model.out_channels = 2 # type: ignore else: - raise ValueError("The available UNETR models are either from \"torch-em\" or \"monai\", choose from them") + raise ValueError(f"The available UNETR models are either from \"torch-em\" or \"monai\", choose from them instead of - {source_choice}") return model @@ -300,7 +300,7 @@ def main(args): parser.add_argument("--save_dir", type=str, default="/scratch/usr/nimanwai/predictions/unetr", help="Path to save predictions from UNETR model") - parser.add_argument("--iterations", type=int, default=10000) + parser.add_argument("--iterations", type=int, default=100000) args = parser.parse_args() main(args) diff --git a/experiments/vision-transformer/unetr/submit_training.py b/experiments/vision-transformer/unetr/submit_training.py index 24de0aa0..b7042f8a 100644 --- a/experiments/vision-transformer/unetr/submit_training.py +++ b/experiments/vision-transformer/unetr/submit_training.py @@ -7,14 +7,13 @@ from datetime import datetime -def write_batch_script(out_path, ini_sam=False, source_choice="torch-em"): +def write_batch_script(out_path, ini_sam=True, source_choice="torch-em"): """ inputs: source_choice:str - [torch_em / monai] source of the unetr model coming from ini_sam: bool - initialize torch-em's unetr implementation with sam encoder weights """ - # ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"] - cell_types = ["A172", "BT474"] + cell_types = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"] for i, ctype in enumerate(cell_types): batch_script = """#!/bin/bash @@ -40,11 +39,10 @@ def write_batch_script(out_path, ini_sam=False, source_choice="torch-em"): python livecell_unetr.py --train """ add_ctype = f"-c {ctype} " - add_sam_ini = "--do_sam_ini " - add_source_choice = f"--source_choice {source_choice}" - + add_source_choice = f"--source_choice {source_choice} " batch_script += add_ctype + add_source_choice + add_sam_ini = "--do_sam_ini " if ini_sam: batch_script += add_sam_ini