In [86]:
import monai
import os
import numpy as np
import nibabel as nib
import torch
import json
from tqdm import tqdm
import pandas as pd

In [2]:
# MONAI Dice metric (foreground only, per-item/per-class results) and one-hot
# post-transforms with 13 channels.
# NOTE(review): 13 presumably = background + the 12 organ labels; these objects
# appear unused by the evaluation cells (which call nnunet_dice) — consider removing.
dice_metric = monai.metrics.DiceMetric(
    include_background=False,
    reduction="none",
    get_not_nans=False,
)
post_label = monai.transforms.AsDiscrete(to_onehot=13)
post_pred = monai.transforms.AsDiscrete(to_onehot=13)

In [87]:
# Names of the 12 foreground structures.
# NOTE(review): presumably ordered by label value (entry i <-> label i + 1),
# matching the order of nnunet_dice's output — confirm against the dataset.json.
data_classes = [
    'spleen',
    'kidney_right',
    'kidney_left',
    'gallbladder',
    'esophagus',
    'liver',
    'stomach',
    'aorta',
    'inferior_vena_cava',
    'pancreas',
    'adrenal_gland_right',
    'adrenal_gland_left',
]

In [88]:
def nnunet_dice(mask_pred: np.ndarray, mask_ref: np.ndarray, num_classes: int = 12) -> list:
    """Per-class Dice between a predicted and a reference label map.

    Dice is computed from voxel counts as 2*TP / (2*TP + FP + FN) for each
    foreground label separately.

    Args:
        mask_pred: predicted integer label map. Anything ``np.asarray`` accepts
            (e.g. CPU torch / MONAI tensors as returned by ``LoadImage``) is
            converted to a NumPy array first, so the arithmetic is always float64.
        mask_ref: reference label map; must have the same shape as ``mask_pred``.
        num_classes: number of foreground labels; labels ``1..num_classes`` are
            scored and label 0 is skipped. Defaults to 12.

    Returns:
        List of ``num_classes`` Dice values ordered by label. A label absent from
        both masks (TP + FP + FN == 0) yields ``np.nan`` so callers can skip it
        with ``np.nanmean``.

    Raises:
        ValueError: if the two maps have different shapes (broadcasting would
            otherwise silently produce meaningless counts).
    """
    pred_all = np.asarray(mask_pred)
    ref_all = np.asarray(mask_ref)
    if pred_all.shape != ref_all.shape:
        raise ValueError(
            f"shape mismatch: prediction {pred_all.shape} vs reference {ref_all.shape}"
        )

    dice_scores = []
    for organ_class in range(1, num_classes + 1):
        gt = ref_all == organ_class
        pred = pred_all == organ_class

        tp = np.count_nonzero(gt & pred)
        fp = np.count_nonzero(pred & ~gt)
        fn = np.count_nonzero(gt & ~pred)

        denom = 2 * tp + fp + fn
        # denom == 0 exactly when tp + fp + fn == 0 (all counts are non-negative).
        dice_scores.append(np.nan if denom == 0 else 2 * tp / denom)

    return dice_scores

In [90]:
# Case ids listed in the merged metadata. An id is kept when it has no "_"
# separator, or when its second "_"-separated field is a number <= 500.
# NOTE(review): the rationale for the 500 cutoff is not stated here — presumably
# it restricts one source dataset to a subset of its cases; confirm.
old_meta = pd.read_csv("/data_analysis/merged_metadata.csv")['id']

accepted_names = []
for case_id in old_meta:
    fields = case_id.split('_')
    if len(fields) == 1 or int(fields[1]) <= 500:
        accepted_names.append(case_id)

# Additionally accept every 'img*' reference segmentation, keyed by its filename
# stem (text before the first '.').
accepted_names.extend(
    seg_file.split('.')[0]
    for seg_file in os.listdir("/nnUNet/preprocessed_data/Dataset060_Merged_Def/gt_segmentations")
    if seg_file.startswith('img')
)

In [92]:
# Number of accepted case stems (duplicates, if any, are counted).
len(accepted_names)

680

In [89]:
def per_dataset(dataset, inference_folder="/nnUNet/inference/", num_folds=5, accepted=None):
    """Mean per-class Dice of the cross-validation predictions for one source dataset.

    For every fold ``fold_<k>`` under ``inference_folder``, the ``.nii.gz`` files in
    ``preds/`` are paired (in sorted order) with the ``.nii.gz`` files in ``gt/``
    and scored with ``nnunet_dice``. Only cases whose prediction filename starts
    with the dataset's prefix and whose stem (text before the first '.') is in
    ``accepted`` are scored.

    Args:
        dataset: one of 'btcv', 'amos', 'totalseg'.
        inference_folder: root directory containing the ``fold_<k>`` folders.
        num_folds: number of folds to iterate (``fold_0`` .. ``fold_<num_folds-1>``).
        accepted: collection of allowed case stems; defaults to the
            notebook-level ``accepted_names``.

    Returns:
        ``(overall, per_class)`` where ``per_class`` is the NaN-ignoring mean over
        scored cases for each label, and ``overall`` is the NaN-ignoring mean of
        ``per_class``.

    Raises:
        ValueError: for an unknown dataset name, or when a fold's ``preds/`` and
            ``gt/`` folders hold different numbers of ``.nii.gz`` files (the
            positional pairing would then be misaligned).
    """
    prefixes = {'btcv': 'img', 'amos': 'amos', 'totalseg': 's'}
    if dataset not in prefixes:
        raise ValueError("Invalid dataset name. Expected one of: %s" % list(prefixes))
    pattern = prefixes[dataset]

    # Explicit argument avoids a hidden dependency on notebook state; a set makes
    # the per-file membership test O(1).
    accepted = set(accepted_names if accepted is None else accepted)
    load_image = monai.transforms.LoadImage()

    class_res = []
    for fold in tqdm(range(num_folds)):
        fold_folder = os.path.join(inference_folder, f'fold_{fold}')
        input_folder = os.path.join(fold_folder, 'preds')
        gt_folder = os.path.join(fold_folder, 'gt')
        pred_names = sorted(f for f in os.listdir(input_folder) if f.endswith(".nii.gz"))
        # Same filter as the predictions: any stray file in gt/ would otherwise
        # shift the zip() pairing below and silently mis-score every later case.
        gt_names = sorted(f for f in os.listdir(gt_folder) if f.endswith(".nii.gz"))
        if len(pred_names) != len(gt_names):
            raise ValueError(
                f"{fold_folder}: {len(pred_names)} predictions but {len(gt_names)} "
                "references; positional pairing would be misaligned"
            )

        for pred_name, gt_name in tqdm(zip(pred_names, gt_names), total=len(pred_names), leave=False):
            if not pred_name.startswith(pattern) or pred_name.split('.')[0] not in accepted:
                continue
            # Convert MONAI tensors to NumPy so scores are plain float64 values.
            pred = np.asarray(load_image(os.path.join(input_folder, pred_name)))
            gt = np.asarray(load_image(os.path.join(gt_folder, gt_name)))
            class_res.append(nnunet_dice(pred, gt))

    mean_class_res = np.nanmean(np.asarray(class_res, dtype=float), axis=0)
    return np.nanmean(mean_class_res), mean_class_res

In [93]:
# Evaluate the BTCV cases ('img*' predictions).
per_dataset(dataset='btcv')

170it [00:24,  7.07it/s]:00<?, ?it/s]
170it [00:36,  4.72it/s]:24<01:36, 24.07s/it]
169it [00:41,  4.03it/s]:00<01:33, 31.11s/it]
169it [00:23,  7.24it/s]:42<01:12, 36.07s/it]
169it [00:17,  9.62it/s]:05<00:31, 31.05s/it]
100%|██████████| 5/5 [02:23<00:00, 28.60s/it]


(0.8597102980832645,
 array([0.9554186586875473, 0.9164737610884095, 0.9416819174270082,
        0.7150160930331665, 0.7921410762610425, 0.9681061325682294,
        0.9343983383768842, 0.9236569465386706, 0.8814241177421734,
        0.8338661348142795, 0.7338441621916609, 0.720496238270102 ]))

In [44]:
# Evaluate the AMOS cases ('amos*' predictions).
per_dataset(dataset='amos')

100%|██████████| 5/5 [32:53<00:00, 394.65s/it]


(0.903191627615454,
 array([0.96931269, 0.96362673, 0.95894844, 0.87630084, 0.8558492 ,
        0.97628755, 0.92722287, 0.95438065, 0.90723227, 0.86089942,
        0.78775857, 0.8004803 ]))

In [61]:
# Evaluate the TotalSegmentator cases ('s*' predictions).
per_dataset(dataset='totalseg')

170it [02:00,  1.42it/s]:00<?, ?it/s]
170it [02:28,  1.14it/s]:00<08:00, 120.04s/it]
169it [02:29,  1.13it/s]:28<06:50, 137.00s/it]
169it [02:16,  1.24it/s]:58<04:45, 142.79s/it]
169it [02:05,  1.34it/s]:15<02:20, 140.42s/it]
100%|██████████| 5/5 [11:21<00:00, 136.28s/it]


(0.9425105963400348,
 array([0.97797404, 0.970631  , 0.96993605, 0.87453245, 0.9219088 ,
        0.98790338, 0.95326074, 0.97767958, 0.95144403, 0.9218948 ,
        0.9026587 , 0.90030359]))