# Calculate 2D and 3D Dice for Any Model

In [22]:
import os
# Work from the project directory so the relative paths used later in this
# notebook (e.g. the labels JSON and the result files) resolve against it.
os.chdir('/workspace/noise_and_dice')

In [None]:
# Run configuration: data locations, the dataset/model tags that end up in the
# output file names, and the JSON file holding the label map.
config = {
    'images': '/processed/Public/TCIA',
    'dataset': 'tcia',
    'model': 'TS',
    'gt': '/processed/Public/TCIA/labels-vista3d',

    # Alternative prediction source, kept for quick switching:
    #   'predictions': '/processed/Public/TCIA/predictions/vista3d',
    #   'labels_json': 'labels-TCIA-pediatric.json',

    'predictions': '/processed/Public/TCIA/predictions/TotalSegmentator_mapped',
    'labels_json': 'labels-TCIA-pediatric.json',
}

In [24]:
#!pip install cupy
#!pip install xlsxwriter

In [25]:
import json

# Read the label map named by config['labels_json'] and store it on the config
# under 'labels', where the Dice routines below pick it up.
with open(config['labels_json']) as labels_file:
    config['labels'] = json.load(labels_file)

KeyError: 'labels_json'

In [None]:
import os
import tqdm
import SimpleITK as sitk
import numpy as np
import cupy as cp
#import cudf
import pandas as pd

In [None]:
def sitk_to_numpy(image):
    """Return the voxel data of a SimpleITK image as a (Y, X, Z) NumPy array.

    The array obtained from ``sitk.GetArrayFromImage`` is ordered (Z, Y, X);
    the slice axis is moved to the last position so that a single slice is
    addressed as ``array[:, :, z]``.
    """
    zyx = sitk.GetArrayFromImage(image)
    # Axis order (1, 2, 0): Y first, then X, with the original Z axis last.
    return zyx.transpose(1, 2, 0)

def dice_coefficient(true_array, pred_array, label):
    """Compute the Dice coefficient of one label between two label maps.

    Only array methods are used, so the inputs may be CuPy or NumPy arrays.
    Voxels are counted on boolean masks with integer accumulation, which keeps
    the counts exact for arbitrarily large volumes.

    Args:
        true_array: Ground-truth label map.
        pred_array: Predicted label map, same shape as ``true_array``.
        label (int): Label value whose overlap is measured.

    Returns:
        tuple: ``(dice, intersection, sum_sizes)`` where

        * ``dice`` (float) is ``2 * intersection / sum_sizes``; it is NaN when
          the label is absent from both arrays, and 0.0 when it is present in
          only one of them (pure false positive or pure false negative).
        * ``intersection`` (int) is the number of voxels carrying ``label`` in
          both arrays.
        * ``sum_sizes`` (int) is the number of ``label`` voxels in the ground
          truth plus the number in the prediction.
    """
    true_mask = true_array == label
    pred_mask = pred_array == label

    true_size = int(true_mask.sum())
    pred_size = int(pred_mask.sum())
    sum_sizes = true_size + pred_size

    if sum_sizes == 0:
        # Label absent on both sides: the score is undefined, not perfect.
        return np.nan, 0, 0

    intersection = int((true_mask & pred_mask).sum())

    # sum_sizes > 0 here, so no smoothing term is needed; a one-sided mask has
    # an empty intersection and therefore scores exactly 0.0, while a perfect
    # match scores exactly 1.0.
    dice = 2.0 * intersection / sum_sizes
    return dice, intersection, sum_sizes

In [None]:
def multi_class_dice(true_array, pred_array, labels):
    """Compute per-organ, macro-average and micro-average Dice for many labels.

    Args:
        true_array: Ground-truth label map.
        pred_array: Predicted label map.
        labels (dict): Maps an organ name to its integer label value.

    Returns:
        tuple: ``(dice_scores, macro_average_dice, micro_average_dice)``

        * ``dice_scores`` (dict): organ name -> Dice score; NaN for organs
          absent from both the ground truth and the prediction.
        * ``macro_average_dice`` (float): unweighted mean of the scores of all
          organs present in at least one of the two arrays. Organs that were
          missed entirely (Dice 0.0) count towards this mean. NaN when no
          organ is present.
        * ``micro_average_dice`` (float): ``2 * sum(intersections) /
          sum(sizes)`` accumulated over all organs; NaN when no organ is
          present.
    """
    dice_scores = {}
    present_scores = []
    total_intersection = 0
    total_sum_sizes = 0

    for key, label in labels.items():
        dice, intersection, sum_sizes = dice_coefficient(true_array, pred_array, label)
        dice_scores[key] = dice

        # NaN marks an organ that is absent on both sides; it is reported but
        # kept out of both averages. The test must be an explicit NaN check:
        # a truthiness test would also discard legitimate 0.0 scores.
        if np.isnan(dice):
            continue

        present_scores.append(dice)
        total_intersection += intersection
        total_sum_sizes += sum_sizes

    macro_average_dice = float(np.mean(present_scores)) if present_scores else float('nan')

    if total_sum_sizes > 0:
        micro_average_dice = float(2.0 * total_intersection / total_sum_sizes)
    else:
        micro_average_dice = float('nan')

    return dice_scores, macro_average_dice, micro_average_dice

def _dice_from_counts(tp, fp, fn):
    """Return 2*tp / (2*tp + fp + fn), or NaN when all three counts are zero."""
    denominator = 2 * tp + fp + fn
    if denominator > 0:
        return (2 * tp) / denominator
    return float('nan')


def compute_dice_2d(label_3d, pred_3d, labels):
    """Compute slice-wise (2D) Dice scores along the last axis of two volumes.

    For every label the true-positive, false-positive and false-negative
    counts of all slices are obtained with one reduction over the two in-plane
    axes, then brought to the host as Python ints. Only array methods are
    used, so the inputs may be CuPy or NumPy arrays.

    Args:
        label_3d: Ground-truth label volume; slices are ``label_3d[:, :, z]``.
        pred_3d: Predicted label volume with the same shape as ``label_3d``.
        labels (dict): Maps an organ name to its integer label value.

    Returns:
        dict: ``{'overall': {...}, organ_name: {...}, ...}``. Each inner dict
        maps the 1-based slice number (as a string) to the Dice score of that
        slice, NaN where the label occurs in neither volume on that slice.
        The ``'overall'`` entry is the micro-average: Dice of the counts
        summed over all labels.

    Raises:
        ValueError: If the two volumes do not have the same shape.
    """
    if label_3d.shape != pred_3d.shape:
        raise ValueError(
            f"Shape mismatch: ground truth {tuple(label_3d.shape)} "
            f"vs prediction {tuple(pred_3d.shape)}"
        )

    num_slices = label_3d.shape[2]
    slice_keys = [str(z + 1) for z in range(num_slices)]

    # 'overall' is inserted first so it keeps the leading position in the
    # result; it is filled in once all labels have been accumulated.
    record_2d = {'overall': {}}
    total_tp = [0] * num_slices
    total_fp = [0] * num_slices
    total_fn = [0] * num_slices

    for key, label in labels.items():
        pred_mask = pred_3d == label
        gt_mask = label_3d == label

        # One count per slice; tolist() yields plain Python ints.
        tp = (pred_mask & gt_mask).sum(axis=(0, 1)).tolist()
        fp = (pred_mask & ~gt_mask).sum(axis=(0, 1)).tolist()
        fn = (~pred_mask & gt_mask).sum(axis=(0, 1)).tolist()

        per_slice = {}
        for z, slice_key in enumerate(slice_keys):
            per_slice[slice_key] = _dice_from_counts(tp[z], fp[z], fn[z])
            total_tp[z] += tp[z]
            total_fp[z] += fp[z]
            total_fn[z] += fn[z]
        record_2d[key] = per_slice

    record_2d['overall'] = {
        slice_key: _dice_from_counts(total_tp[z], total_fp[z], total_fn[z])
        for z, slice_key in enumerate(slice_keys)
    }

    return record_2d


def process_ct_file(ct_name, config):
    """Score one scan: 3D Dice per organ plus slice-wise 2D Dice.

    The ground truth is read from ``config['gt']`` and the prediction from
    ``config['predictions']``, both under the file name ``ct_name``; the label
    map comes from ``config['labels']``.

    Returns:
        tuple: ``(record, record_2d)`` where ``record`` is a flat dict with
        ``'ct_name'``, ``'macro_avg_dice'``, ``'micro_avg_dice'`` and one
        ``'dice_<organ>'`` entry per label, and ``record_2d`` is the result of
        ``compute_dice_2d``. On any failure the error is printed and
        ``(None, None)`` is returned so a batch run can continue.
    """
    try:
        gt_path = os.path.join(config['gt'], ct_name)
        prediction_path = os.path.join(config['predictions'], ct_name)

        gt_volume = sitk_to_numpy(sitk.ReadImage(gt_path))
        prediction_volume = sitk_to_numpy(sitk.ReadImage(prediction_path))

        # Move both volumes to the GPU once; both Dice passes reuse them.
        gt_gpu = cp.asarray(gt_volume)
        prediction_gpu = cp.asarray(prediction_volume)

        # Volume-level scores.
        per_organ, macro_avg, micro_avg = multi_class_dice(
            gt_gpu, prediction_gpu, config['labels']
        )

        # Slice-level scores.
        record_2d = compute_dice_2d(gt_gpu, prediction_gpu, config['labels'])

        record = {
            'ct_name': ct_name,
            'macro_avg_dice': macro_avg,
            'micro_avg_dice': micro_avg,
        }
        for organ, score in per_organ.items():
            record[f'dice_{organ}'] = score

        return record, record_2d

    except Exception as e:
        print(f"Error processing {ct_name}: {e}")
        return None, None

def save_2d_dice_to_excel(records_2d, output_path):
    """Write slice-wise Dice records to an Excel workbook, one sheet per organ.

    Args:
        records_2d (dict): organ name -> {scan name -> {slice number (str) ->
            Dice score}}.
        output_path: Destination of the workbook.

    Each sheet has a ``ct_name`` column followed by one ``slice_<n>`` column
    for every slice number seen in any scan of that organ, in numeric order.
    Scans lacking a given slice get NaN in that column.
    """
    with pd.ExcelWriter(output_path, engine='xlsxwriter') as writer:
        for organ, per_scan in records_2d.items():
            # Union of slice numbers over all scans, ordered numerically
            # (the keys are strings, so a plain sort would be lexicographic).
            seen = {number for slices in per_scan.values() for number in slices}
            ordered = sorted(seen, key=int)

            columns = ['ct_name'] + [f'slice_{number}' for number in ordered]
            rows = [
                [scan] + [slices.get(number, np.nan) for number in ordered]
                for scan, slices in per_scan.items()
            ]

            sheet = pd.DataFrame(rows, columns=columns)
            sheet.to_excel(writer, sheet_name=organ, index=False)

In [None]:
# Every ground-truth volume in config['gt'] is scored; process_ct_file looks
# up the prediction under the same file name.
ct_images = [name for name in os.listdir(config['gt']) if name.endswith('.nii.gz')]

results = []

# organ -> {scan name -> per-slice scores}, with 'overall' as the first entry.
records_2d = {'overall': {}}
records_2d.update({key: {} for key in config['labels']})

print('Starting DICE Calculation...')
for ct_name in tqdm.tqdm(ct_images):
    record, record_2d = process_ct_file(ct_name, config)
    if not record:
        # Failed scans come back as (None, None) and are left out.
        continue

    results.append(record)
    for organ in records_2d:
        records_2d[organ][ct_name] = record_2d.get(organ, {})

Starting DICE Calculation...


100%|██████████| 358/358 [29:28<00:00,  4.94s/it]


In [None]:
# Both output names are derived from the dataset and model tags in the config.
output_excel = f"{config['dataset']}_2d_dice_{config['model']}.xlsx"
output_csv_path = f"{config['dataset']}_3d_dice_{config['model']}.csv"

# Slice-wise (2D) scores go to an Excel workbook.
save_2d_dice_to_excel(records_2d, output_excel)
print(f"2D Results saved to {output_excel}")

# Volume-level (3D) scores go to a CSV built from the collected records.
df = pd.DataFrame(results)
df.to_csv(output_csv_path, index=False)

print(f"3D Results saved to {output_csv_path}")

2D Results saved to tcia_2d_dice_vista3d.xlsx
3D Results saved to tcia_3d_dice_vista3d.csv
