Skip to content

Commit

Permalink
Update evaluation in unetr scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Oct 22, 2023
1 parent 35bd684 commit ee309eb
Showing 1 changed file with 36 additions and 5 deletions.
41 changes: 36 additions & 5 deletions experiments/vision-transformer/unetr/livecell_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from typing import Tuple, List

import imageio.v2 as imageio
from elf.evaluation import dice_score
from skimage.segmentation import find_boundaries
from elf.evaluation import dice_score, mean_segmentation_accuracy

import torch
import torch_em
Expand Down Expand Up @@ -166,6 +166,7 @@ def do_unetr_evaluation(
source_choice: str
):
fg_list, bd_list = [], []
ws1_msa_list, ws2_msa_list, ws1_sa50_list, ws2_sa50_list = [], [], [], []

for c1 in cell_types:
_save_dir = os.path.join(root_save_dir, f"src-{c1}")
Expand All @@ -174,18 +175,24 @@ def do_unetr_evaluation(
continue

fg_set, bd_set = {"CELL TYPE": c1}, {"CELL TYPE": c1}
ws1_msa_set, ws2_msa_set, ws1_sa50_set, ws2_sa50_set = {"CELL TYPE": c1}, {"CELL TYPE": c1}, {"CELL TYPE": c1}, {"CELL TYPE": c1}
for c2 in tqdm(cell_types, desc=f"Evaluation on {c1} source models from {_save_dir}"):
fg_dir = os.path.join(_save_dir, "foreground")
bd_dir = os.path.join(_save_dir, "boundary")
ws1_dir = os.path.join(_save_dir, "watershed1")
ws2_dir = os.path.join(_save_dir, "watershed2")

gt_dir = os.path.join(input_path, "annotations", "livecell_test_images", c2, "*")
cwise_fg, cwise_bd = [], []
cwise_ws1_msa, cwise_ws2_msa, cwise_ws1_sa50, cwise_ws2_sa50 = [], [], [], []
for gt_path in glob(gt_dir):
fname = os.path.split(gt_path)[-1]

gt = imageio.imread(gt_path)
fg = imageio.imread(os.path.join(fg_dir, fname))
bd = imageio.imread(os.path.join(bd_dir, fname))
ws1 = imageio.imread(os.path.join(ws1_dir, fname))
ws2 = imageio.imread(os.path.join(ws2_dir, fname))

true_bd = find_boundaries(gt)

Expand All @@ -201,21 +208,45 @@ def do_unetr_evaluation(
# For the ground-truth we have already a binary label, so we don't need to threshold it again.
cwise_bd.append(dice_score(bd, true_bd, threshold_gt=None, threshold_seg=None))

msa1, sa_acc1 = mean_segmentation_accuracy(ws1, gt, return_accuracies=True) # type: ignore
msa2, sa_acc2 = mean_segmentation_accuracy(ws2, gt, return_accuracies=True) # type: ignore

cwise_ws1_msa.append(msa1)
cwise_ws2_msa.append(msa2)
cwise_ws1_sa50.append(sa_acc1[0])
cwise_ws2_sa50.append(sa_acc2[0])

fg_set[c2] = np.mean(cwise_fg)
bd_set[c2] = np.mean(cwise_bd)
ws1_msa_set[c2] = np.mean(cwise_ws1_msa)
ws2_msa_set[c2] = np.mean(cwise_ws2_msa)
ws1_sa50_set[c2] = np.mean(cwise_ws1_sa50)
ws2_sa50_set[c2] = np.mean(cwise_ws2_sa50)

fg_list.append(pd.DataFrame.from_dict([fg_set])) # type: ignore
bd_list.append(pd.DataFrame.from_dict([bd_set])) # type: ignore
ws1_msa_list.append(pd.DataFrame.from_dict([ws1_msa_set])) # type: ignore
ws2_msa_list.append(pd.DataFrame.from_dict([ws2_msa_set])) # type: ignore
ws1_sa50_list.append(pd.DataFrame.from_dict([ws1_sa50_set])) # type: ignore
ws2_sa50_list.append(pd.DataFrame.from_dict([ws2_sa50_set])) # type: ignore

f_df_fg = pd.concat(fg_list, ignore_index=True)
f_df_bd = pd.concat(bd_list, ignore_index=True)
f_df_ws1_msa = pd.concat(ws1_msa_list, ignore_index=True)
f_df_ws2_msa = pd.concat(ws2_msa_list, ignore_index=True)
f_df_ws1_sa50 = pd.concat(ws1_sa50_list, ignore_index=True)
f_df_ws2_sa50 = pd.concat(ws2_sa50_list, ignore_index=True)

csv_save_dir = "./results/"
tmp_csv_name = f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch"
csv_save_dir = f"./results/{tmp_csv_name}"
os.makedirs(csv_save_dir, exist_ok=True)

tmp_csv_name = f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch"
f_df_fg.to_csv(os.path.join(csv_save_dir, f"foreground-unetr-{tmp_csv_name}-results.csv"))
f_df_bd.to_csv(os.path.join(csv_save_dir, f"boundary-unetr-{tmp_csv_name}-results.csv"))
f_df_fg.to_csv(os.path.join(csv_save_dir, "foreground-dice.csv"))
f_df_bd.to_csv(os.path.join(csv_save_dir, "boundary-dice.csv"))
f_df_ws1_msa.to_csv(os.path.join(csv_save_dir, "watershed1-msa.csv"))
f_df_ws2_msa.to_csv(os.path.join(csv_save_dir, "watershed2-msa.csv"))
f_df_ws1_sa50.to_csv(os.path.join(csv_save_dir, "watershed1-sa50.csv"))
f_df_ws2_sa50.to_csv(os.path.join(csv_save_dir, "watershed2-sa50.csv"))


def main(args):
Expand Down

0 comments on commit ee309eb

Please sign in to comment.