diff --git a/experiments/vision-transformer/unetr/livecell_unetr.py b/experiments/vision-transformer/unetr/livecell_unetr.py index 250ff5c7..0ff9cde2 100644 --- a/experiments/vision-transformer/unetr/livecell_unetr.py +++ b/experiments/vision-transformer/unetr/livecell_unetr.py @@ -7,8 +7,8 @@ from typing import Tuple, List import imageio.v2 as imageio -from elf.evaluation import dice_score from skimage.segmentation import find_boundaries +from elf.evaluation import dice_score, mean_segmentation_accuracy import torch import torch_em @@ -166,6 +166,7 @@ def do_unetr_evaluation( source_choice: str ): fg_list, bd_list = [], [] + ws1_msa_list, ws2_msa_list, ws1_sa50_list, ws2_sa50_list = [], [], [], [] for c1 in cell_types: _save_dir = os.path.join(root_save_dir, f"src-{c1}") @@ -174,18 +175,24 @@ def do_unetr_evaluation( continue fg_set, bd_set = {"CELL TYPE": c1}, {"CELL TYPE": c1} + ws1_msa_set, ws2_msa_set, ws1_sa50_set, ws2_sa50_set = {"CELL TYPE": c1}, {"CELL TYPE": c1}, {"CELL TYPE": c1}, {"CELL TYPE": c1} for c2 in tqdm(cell_types, desc=f"Evaluation on {c1} source models from {_save_dir}"): fg_dir = os.path.join(_save_dir, "foreground") bd_dir = os.path.join(_save_dir, "boundary") + ws1_dir = os.path.join(_save_dir, "watershed1") + ws2_dir = os.path.join(_save_dir, "watershed2") gt_dir = os.path.join(input_path, "annotations", "livecell_test_images", c2, "*") cwise_fg, cwise_bd = [], [] + cwise_ws1_msa, cwise_ws2_msa, cwise_ws1_sa50, cwise_ws2_sa50 = [], [], [], [] for gt_path in glob(gt_dir): fname = os.path.split(gt_path)[-1] gt = imageio.imread(gt_path) fg = imageio.imread(os.path.join(fg_dir, fname)) bd = imageio.imread(os.path.join(bd_dir, fname)) + ws1 = imageio.imread(os.path.join(ws1_dir, fname)) + ws2 = imageio.imread(os.path.join(ws2_dir, fname)) true_bd = find_boundaries(gt) @@ -201,21 +208,45 @@ def do_unetr_evaluation( # For the ground-truth we have already a binary label, so we don't need to threshold it again. cwise_bd.append(dice_score(bd, true_bd, threshold_gt=None, threshold_seg=None)) + msa1, sa_acc1 = mean_segmentation_accuracy(ws1, gt, return_accuracies=True) # type: ignore + msa2, sa_acc2 = mean_segmentation_accuracy(ws2, gt, return_accuracies=True) # type: ignore + + cwise_ws1_msa.append(msa1) + cwise_ws2_msa.append(msa2) + cwise_ws1_sa50.append(sa_acc1[0]) + cwise_ws2_sa50.append(sa_acc2[0]) + fg_set[c2] = np.mean(cwise_fg) bd_set[c2] = np.mean(cwise_bd) + ws1_msa_set[c2] = np.mean(cwise_ws1_msa) + ws2_msa_set[c2] = np.mean(cwise_ws2_msa) + ws1_sa50_set[c2] = np.mean(cwise_ws1_sa50) + ws2_sa50_set[c2] = np.mean(cwise_ws2_sa50) fg_list.append(pd.DataFrame.from_dict([fg_set])) # type: ignore bd_list.append(pd.DataFrame.from_dict([bd_set])) # type: ignore + ws1_msa_list.append(pd.DataFrame.from_dict([ws1_msa_set])) # type: ignore + ws2_msa_list.append(pd.DataFrame.from_dict([ws2_msa_set])) # type: ignore + ws1_sa50_list.append(pd.DataFrame.from_dict([ws1_sa50_set])) # type: ignore + ws2_sa50_list.append(pd.DataFrame.from_dict([ws2_sa50_set])) # type: ignore f_df_fg = pd.concat(fg_list, ignore_index=True) f_df_bd = pd.concat(bd_list, ignore_index=True) + f_df_ws1_msa = pd.concat(ws1_msa_list, ignore_index=True) + f_df_ws2_msa = pd.concat(ws2_msa_list, ignore_index=True) + f_df_ws1_sa50 = pd.concat(ws1_sa50_list, ignore_index=True) + f_df_ws2_sa50 = pd.concat(ws2_sa50_list, ignore_index=True) - csv_save_dir = "./results/" + tmp_csv_name = f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch" + csv_save_dir = f"./results/{tmp_csv_name}" os.makedirs(csv_save_dir, exist_ok=True) - tmp_csv_name = f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch" - f_df_fg.to_csv(os.path.join(csv_save_dir, f"foreground-unetr-{tmp_csv_name}-results.csv")) - f_df_bd.to_csv(os.path.join(csv_save_dir, f"boundary-unetr-{tmp_csv_name}-results.csv")) + f_df_fg.to_csv(os.path.join(csv_save_dir, "foreground-dice.csv")) + f_df_bd.to_csv(os.path.join(csv_save_dir, "boundary-dice.csv")) + f_df_ws1_msa.to_csv(os.path.join(csv_save_dir, "watershed1-msa.csv")) + f_df_ws2_msa.to_csv(os.path.join(csv_save_dir, "watershed2-msa.csv")) + f_df_ws1_sa50.to_csv(os.path.join(csv_save_dir, "watershed1-sa50.csv")) + f_df_ws2_sa50.to_csv(os.path.join(csv_save_dir, "watershed2-sa50.csv")) def main(args):