In [None]:
import os

import numpy as np
import SimpleITK as sitk
from scipy.ndimage import zoom
from skimage import transform
from tqdm import tqdm

from metric.tree_parse import *
from metric.evaluation_atm_22 import *

# Per-dataset directory pairs: `*_PRED_PATH` points at a `test_results`
# folder under nnUNet_results, `*_GT_PATH` at the matching `labelsTs` folder
# under nnUNet_raw (presumably nnU-Net predictions and test labels -- confirm).
# NOTE(review): none of these constants is referenced by the cells visible
# here; the evaluation cell hard-codes its own `gt_path` / `pred_path`.

# Dataset011_BAS
BAS_PRED_PATH = '/data/dengxiaolong/airway/nnunet/nnUNet_results/Dataset011_BAS/nnUNetTrainer__nnUNetPlans__3d_fullres/test_results/'
BAS_GT_PATH = '/data/dengxiaolong/airway/nnunet/nnUNet_raw/Dataset011_BAS/labelsTs/'

# Dataset012_ATM
ATM_PRED_PATH = '/data/dengxiaolong/airway/nnunet/nnUNet_results/Dataset012_ATM/nnUNetTrainer__nnUNetPlans__3d_fullres/test_results/'
ATM_GT_PATH = '/data/dengxiaolong/airway/nnunet/nnUNet_raw/Dataset012_ATM/labelsTs/'

# Dataset013_PARSE
PARSE_PRED_PATH = '/data/dengxiaolong/airway/nnunet/nnUNet_results/Dataset013_PARSE/nnUNetTrainer__nnUNetPlans__3d_fullres/test_results/'
PARSE_GT_PATH = '/data/dengxiaolong/airway/nnunet/nnUNet_raw/Dataset013_PARSE/labelsTs/'

# Dataset014_ImageCAS
IMAGECAS_PRED_PATH = '/data/dengxiaolong/airway/nnunet/nnUNet_results/Dataset014_ImageCAS/nnUNetTrainer__nnUNetPlans__3d_fullres/test_results/'
IMAGECAS_GT_PATH = '/data/dengxiaolong/airway/nnunet/nnUNet_raw/Dataset014_ImageCAS/labelsTs/'

def resample_array(array, original_spacing, new_spacing=(1.0, 1.0, 1.0), order=0):
    """
    Resample a NumPy volume from its original voxel spacing to a new spacing.

    The target shape along each axis is
    ``round(array.shape * original_spacing / new_spacing)`` (never less than
    one voxel); the zoom factor actually applied is derived from that rounded
    shape so that the output has exactly the target shape.

    :param array: input NumPy array (a 3-D volume in this notebook).
    :param original_spacing: voxel spacing of ``array``, given in the same
        axis order as ``array.shape`` -- i.e. (z, y, x) for arrays obtained
        from ``sitk.GetArrayFromImage``. Note that SimpleITK's
        ``Image.GetSpacing()`` returns (x, y, z), so it must be reversed
        (``image.GetSpacing()[::-1]``) before being passed here.
    :param new_spacing: target voxel spacing in the same axis order,
        defaults to (1.0, 1.0, 1.0).
    :param order: spline interpolation order forwarded to
        ``scipy.ndimage.zoom``. The default 0 (nearest neighbour) never
        creates values that are absent from the input, which is what label /
        segmentation masks need. Pass 1 or 3 for intensity images.
    :return: the resampled NumPy array.
    :raises ValueError: if the spacings cannot be matched to the array's axes
        or are not strictly positive.
    """
    shape = np.array(array.shape, dtype=float)
    source_spacing = np.asarray(original_spacing, dtype=float)
    target_spacing = np.asarray(new_spacing, dtype=float)

    if not (np.all(source_spacing > 0) and np.all(target_spacing > 0)):
        raise ValueError("voxel spacings must be strictly positive")

    # Ratio between the physical voxel sizes = scale factor per axis.
    try:
        resize_factor = np.broadcast_to(source_spacing / target_spacing, shape.shape)
    except ValueError as exc:
        raise ValueError(
            "spacing does not match array dimensions: array has %d axes, "
            "original_spacing=%r, new_spacing=%r"
            % (array.ndim, original_spacing, new_spacing)
        ) from exc

    # Integer target shape; clamp to 1 so a coarse spacing cannot yield an
    # empty axis.
    new_shape = np.maximum(np.round(shape * resize_factor), 1)
    # Zoom factor that reproduces the rounded shape exactly.
    real_resize_factor = new_shape / shape

    # `mode='nearest'` only controls how the borders are extended; the
    # interpolation itself is selected by `order`.
    resampled_array = zoom(array, real_resize_factor, order=order, mode='nearest')

    return resampled_array


In [None]:
# Evaluate predicted segmentations against the ground-truth labels.
# Each prediction in `pred_path` is compared with the ground-truth file of the
# same name in `gt_path`.
gt_path = '/data/dengxiaolong/airway/BAS/Data3/labels'
pred_path = '/home/dengxiaolong/code/3DSAM-adapter/3DSAM-adapter/logs/bas/preds'

# Keep only the '.nii.gz' volumes. os.listdir returns entries in arbitrary
# order, so sort them to process the cases in a reproducible order.
pred_file = sorted(i for i in os.listdir(pred_path) if i.endswith('.nii.gz'))
if not pred_file:
    # Without this guard np.mean([]) would silently report nan for every metric.
    raise FileNotFoundError("No '.nii.gz' predictions found in %s" % pred_path)

# One entry per evaluated case, in the order of `pred_file`.
tds = []    # tree_length_calculation results (reported as "TD")
bds = []    # last element of branch_detected_calculation's result ("BD")
dices = []  # dice_coefficient_score_calculation results ("DICE")
pres = []   # precision_calculation results ("PRE")
sens = []   # sensitivity_calculation results ("SEN")
spes = []   # specificity_calculation results ("SPE")

for pred in tqdm(pred_file):
    pred_image = sitk.ReadImage(os.path.join(pred_path, pred))
    pred_array = sitk.GetArrayFromImage(pred_image)

    # The ground truth is looked up under the same file name as the prediction.
    gt_image = sitk.ReadImage(os.path.join(gt_path, pred))
    gt_array = sitk.GetArrayFromImage(gt_image)

    assert gt_array.shape == pred_array.shape, (
        'Shape Mismatch! %s: gt %s vs pred %s'
        % (pred, gt_array.shape, pred_array.shape)
    )

    # Optional resampling to isotropic spacing (disabled). GetSpacing() returns
    # (x, y, z) while the arrays are (z, y, x), hence the reversal.
    # gt_array = resample_array(gt_array, original_spacing=gt_image.GetSpacing()[::-1])
    # pred_array = resample_array(pred_array, original_spacing=gt_image.GetSpacing()[::-1])

    # Parsing and skeleton of the ground truth feed the TD / BD metrics.
    gt_parsing, gt_skeleton = get_parsing_and_skeleton(gt_array)

    tds.append(tree_length_calculation(pred_array, gt_skeleton))
    bds.append(branch_detected_calculation(pred_array, gt_parsing, gt_skeleton)[-1])
    dices.append(dice_coefficient_score_calculation(pred_array, gt_array))
    pres.append(precision_calculation(pred_array, gt_array))
    sens.append(sensitivity_calculation(pred_array, gt_array))
    spes.append(specificity_calculation(pred_array, gt_array))


# Average every metric over all evaluated cases.
td_mean = np.mean(tds)
bd_mean = np.mean(bds)
dice_mean = np.mean(dices)
pre_mean = np.mean(pres)
sen_mean = np.mean(sens)
spe_mean = np.mean(spes)

print("TD: %0.4f, BD: %0.4f, DICE: %0.4f, PRE: %0.4f, SEN: %0.4f, SPE: %0.4f" %
        (td_mean, bd_mean, dice_mean, pre_mean, sen_mean, spe_mean))


In [None]:
# Smoke test: run the tree parsing / skeleton extraction on every
# ground-truth label volume. `gt_path` comes from the evaluation cell.
# Only files ending in '.nii.gz' are considered.
gt_file = [fname for fname in os.listdir(gt_path) if fname.endswith('.nii.gz')]

for gt in tqdm(gt_file):
    # Show which case is being processed so a failure can be traced to its file.
    print(gt)

    gt_volume_path = os.path.join(gt_path, gt)
    gt_image = sitk.ReadImage(gt_volume_path)
    gt_array = sitk.GetArrayFromImage(gt_image)

    # The results are not stored; the loop only checks that parsing runs.
    gt_parsing, gt_skeleton = get_parsing_and_skeleton(gt_array)