In [6]:
#!/usr/bin/env python
# coding: utf-8

import os
import torch
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped, Orientationd,
    Spacingd, NormalizeIntensityd, Activations, AsDiscrete, MapTransform, SpatialPadd
)
from monai.networks.nets import SwinUNETR
from monai.inferers import sliding_window_inference
import torch.nn.functional as F
from scipy.spatial.distance import directed_hausdorff

# 仅解决负号显示问题，无中文字体依赖
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['font.family'] = 'DejaVu Sans'  # 通用无衬线字体


class ConvertToMultiChannel5Classesd(MapTransform):
    """将5亚型标签转换为多通道格式"""
    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = [
                d[key] == 1,  # 亚型1
                d[key] == 2,  # 亚型2
                d[key] == 3,  # 亚型3
                d[key] == 4,  # 亚型4
                d[key] == 5   # 亚型5
            ]
            d[key] = torch.stack(result, axis=0).float()
        return d


class CustomCTDataset(Dataset):
    """CT数据集类（支持全量加载）"""
    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        # 初始化image_info（关键修复：确保属性始终存在）
        self.image_info = {}
        self.data = self._load_data()

    def _load_data(self):
        data = []
        patient_ids = set()
        
        # 提取所有有效患者ID
        for filename in os.listdir(self.image_dir):
            if filename.endswith("_head.nii.gz"):
                patient_id = filename.split("_head.nii.gz")[0]
                patient_ids.add(patient_id)
        
        # 匹配图像和标签
        for patient_id in sorted(patient_ids):  # 排序保证处理顺序固定
            image_path = os.path.join(self.image_dir, f"{patient_id}_head.nii.gz")
            label_path = os.path.join(self.label_dir, f"{patient_id}_merged.nii")
            
            if os.path.exists(image_path) and os.path.exists(label_path):
                try:
                    # 读取图像信息用于后续保存nii
                    image_itk = sitk.ReadImage(image_path)
                    self.image_info[patient_id] = {
                        'origin': image_itk.GetOrigin(),
                        'spacing': image_itk.GetSpacing(),
                        'direction': image_itk.GetDirection(),
                        'size': image_itk.GetSize()  # 新增：保存图像尺寸
                    }
                    data.append({
                        "image": image_path,
                        "label": label_path,
                        "patient_id": patient_id,
                        "original_image_path": image_path  # 新增：传递原始图像路径
                    })
                except Exception as e:
                    print(f"Warning: Failed to read image info for patient {patient_id}: {e}, skipped")
            else:
                print(f"Warning: Incomplete data for patient {patient_id}, skipped")
        
        if not data:
            print("Warning: No valid patients found in dataset!")
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        data = self.data[index]
        if self.transform:
            data = self.transform(data)
        return data


class EvaluationMetrics:
    """评估指标计算类"""
    @staticmethod
    def dice(pred, target):
        """计算DICE系数"""
        intersection = np.sum(pred * target)
        union = np.sum(pred) + np.sum(target)
        return 2 * intersection / (union + 1e-8)

    @staticmethod
    def iou(pred, target):
        """计算交并比IoU"""
        intersection = np.sum(pred * target)
        union = np.sum(pred) + np.sum(target) - intersection
        return intersection / (union + 1e-8)

    @staticmethod
    def hd95(pred, target, spacing):
        """计算95%豪斯多夫距离(mm)"""
        if np.sum(pred) == 0 or np.sum(target) == 0:
            return np.nan
        
        pred_coords = np.argwhere(pred) * spacing
        target_coords = np.argwhere(target) * spacing
        
        d1 = directed_hausdorff(pred_coords, target_coords)[0]
        d2 = directed_hausdorff(target_coords, pred_coords)[0]
        hd = max(d1, d2)
        
        all_distances = []
        for p in pred_coords:
            min_dist = np.min(np.linalg.norm(p - target_coords, axis=1))
            all_distances.append(min_dist)
        for t in target_coords:
            min_dist = np.min(np.linalg.norm(t - pred_coords, axis=1))
            all_distances.append(min_dist)
        
        return np.percentile(all_distances, 95) if all_distances else np.nan

    @staticmethod
    def fp_volume(pred, target, spacing):
        """计算假阳性体积(mm³)"""
        fp = np.logical_and(pred > 0, target == 0).astype(np.float32)
        voxel_volume = np.prod(spacing)
        return np.sum(fp) * voxel_volume

    @staticmethod
    def calculate_all_metrics(pred, target, spacing):
        """计算所有亚型的完整指标（仅包含标签中存在的亚型）"""
        metrics = {
            "dice": [],
            "iou": [],
            "hd95": [],
            "fp_volume": EvaluationMetrics.fp_volume(pred, target, spacing),
            "present_subtypes": []  # 记录存在的亚型
        }
        
        # 找出标签中存在的亚型
        present_subtypes = np.unique(target)
        present_subtypes = [s for s in present_subtypes if s != 0]  # 排除背景
        
        for c in range(5):
            class_idx = c + 1
            # 只计算标签中存在的亚型
            if class_idx in present_subtypes:
                pred_class = (pred == class_idx).astype(np.float32)
                target_class = (target == class_idx).astype(np.float32)
                
                metrics["dice"].append(EvaluationMetrics.dice(pred_class, target_class))
                metrics["iou"].append(EvaluationMetrics.iou(pred_class, target_class))
                metrics["hd95"].append(EvaluationMetrics.hd95(pred_class, target_class, spacing))
                metrics["present_subtypes"].append(class_idx)
            else:
                # 对于不存在的亚型，不纳入计算（使用NaN标记）
                metrics["dice"].append(np.nan)
                metrics["iou"].append(np.nan)
                metrics["hd95"].append(np.nan)
        
        return metrics


def create_color_map():
    """创建5亚型的颜色映射（RGBA格式）"""
    color_map = np.zeros((6, 4))  # 0=背景(透明), 1-5=亚型
    color_map[0] = [0, 0, 0, 0]          # 背景
    color_map[1] = [1, 0, 0, 0.7]        # 亚型1：红色
    color_map[2] = [0, 1, 0, 0.7]        # 亚型2：绿色
    color_map[3] = [0, 0, 1, 0.7]        # 亚型3：蓝色
    color_map[4] = [1, 1, 0, 0.7]        # 亚型4：黄色
    color_map[5] = [1, 0, 1, 0.7]        # 亚型5：紫色
    return color_map


def generate_visualization(image_3d, target_3d, pred_3d, metrics, patient_id, save_dir):
    """生成单例患者的可视化结果（x-y切片 + 定量指标）"""
    # 1. 基础配置
    color_map = create_color_map()
    mid_slice = image_3d.shape[2] // 2  # x-y轴位切片（沿W轴中间层）
    
    # 2. 提取切片数据
    image_slice = image_3d[:, :, mid_slice]
    target_slice = target_3d[:, :, mid_slice]
    pred_slice = pred_3d[:, :, mid_slice]
    
    # 3. 生成彩色标签/预测图
    # 金标准彩色图
    target_color = np.zeros((target_slice.shape[0], target_slice.shape[1], 4))
    for i in range(6):
        target_color[target_slice == i] = color_map[i]
    
    # 预测结果彩色图
    pred_color = np.zeros((pred_slice.shape[0], pred_slice.shape[1], 4))
    for i in range(6):
        pred_color[pred_slice == i] = color_map[i]
    
    # 4. 创建可视化图像
    fig, axes = plt.subplots(1, 3, figsize=(24, 8))
    fig.suptitle(f"Patient ID: {patient_id} | Axial Slice (W-axis: {mid_slice})", fontsize=16, fontweight="bold")
    
    # 4.1 CT原图
    axes[0].imshow(image_slice, cmap="gray", aspect="auto")
    axes[0].set_title("CT Original Image", fontsize=14)
    axes[0].axis("off")
    
    # 4.2 金标准标签（多颜色）
    axes[1].imshow(image_slice, cmap="gray", aspect="auto", alpha=0.8)
    axes[1].imshow(target_color, aspect="auto")
    axes[1].set_title("Ground Truth (Color-Coded Subtypes)", fontsize=14)
    axes[1].axis("off")
    
    # 4.3 模型预测结果（多颜色）
    axes[2].imshow(image_slice, cmap="gray", aspect="auto", alpha=0.8)
    axes[2].imshow(pred_color, aspect="auto")
    axes[2].set_title("Model Prediction (Color-Coded Subtypes)", fontsize=14)
    axes[2].axis("off")
    
    # 5. 添加定量指标文本（图下方）
    metrics_text = f"Evaluation Metrics:\nFP Volume: {metrics['fp_volume']:.2f} mm³\n"
    # 只显示存在的亚型指标
    for i in metrics["present_subtypes"]:
        idx = i - 1
        metrics_text += f"Subtype {i} - DICE: {metrics['dice'][idx]:.4f} | IoU: {metrics['iou'][idx]:.4f} | HD95: {np.nan if np.isnan(metrics['hd95'][idx]) else metrics['hd95'][idx]:.2f} mm\n"
    
    fig.text(0.05, 0.02, metrics_text, fontsize=10, verticalalignment='bottom',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    # 6. 添加颜色图例
    handles = [plt.Rectangle((0, 0), 1, 1, facecolor=color_map[i][:3], alpha=color_map[i][3]) for i in range(1, 6)]
    labels = [f"Subtype {i}" for i in range(1, 6)]
    fig.legend(handles, labels, loc="upper right", bbox_to_anchor=(0.98, 0.95), ncol=5, fontsize=10)
    
    # 7. 保存图像
    plt.tight_layout(rect=[0.05, 0.1, 0.95, 0.95])  # 预留指标和图例空间
    save_path = os.path.join(save_dir, f"{patient_id}_visualization.png")
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close()
    
    return save_path


def save_patient_metrics(patient_id, metrics, save_dir):
    """保存单例患者的详细指标到文件"""
    metrics_path = os.path.join(save_dir, f"{patient_id}_metrics.txt")
    with open(metrics_path, "w", encoding="utf-8") as f:
        f.write(f"Patient ID: {patient_id}\n")
        f.write("="*80 + "\n")
        f.write(f"False Positive Volume: {metrics['fp_volume']:.2f} mm³\n\n")
        f.write("Present Subtypes in Ground Truth: " + ", ".join(map(str, metrics["present_subtypes"])) + "\n\n")
        f.write("Subtype-wise Metrics:\n")
        # 只保存存在的亚型指标
        for i in metrics["present_subtypes"]:
            idx = i - 1
            f.write(f"Subtype {i}:\n")
            f.write(f"  DICE: {metrics['dice'][idx]:.4f}\n")
            f.write(f"  IoU: {metrics['iou'][idx]:.4f}\n")
            f.write(f"  HD95: {np.nan if np.isnan(metrics['hd95'][idx]) else metrics['hd95'][idx]:.2f} mm\n\n")
    return metrics_path


def save_prediction_as_nii(pred_3d, patient_id, dataset, save_dir):
    """将预测结果保存为nii格式（增加容错处理）"""
    try:
        # 获取原始图像的空间信息
        if patient_id not in dataset.image_info:
            print(f"Warning: No image info for patient {patient_id}, using default spacing")
            # 使用默认空间信息（兜底方案）
            image_info = {
                'origin': (0.0, 0.0, 0.0),
                'spacing': (1.0, 1.0, 1.0),
                'direction': (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
            }
        else:
            image_info = dataset.image_info[patient_id]
        
        # 调整维度顺序：numpy (D,H,W) → SimpleITK (W,H,D)
        pred_3d_transposed = np.transpose(pred_3d, (2, 1, 0))
        
        # 转换数据类型为整数
        pred_itk = sitk.GetImageFromArray(pred_3d_transposed.astype(np.int16))
        
        # 设置空间信息（与原始图像保持一致）
        pred_itk.SetOrigin(image_info['origin'])
        pred_itk.SetSpacing(image_info['spacing'])
        pred_itk.SetDirection(image_info['direction'])
        
        # 保存为nii文件
        save_path = os.path.join(save_dir, f"{patient_id}_prediction.nii.gz")
        sitk.WriteImage(pred_itk, save_path, useCompression=True)
        print(f"  Successfully saved prediction to: {save_path}")
        return save_path
    except Exception as e:
        print(f"  Error saving nii for patient {patient_id}: {e}")
        return None


def update_summary_file(all_metrics, summary_path):
    """更新全局汇总文件"""
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("Global Summary of All Patients\n")
        f.write("="*120 + "\n")
        # 表头
        headers = ["Patient ID", "FP_Volume(mm³)", 
                   "S1_DICE", "S1_IoU", "S1_HD95",
                   "S2_DICE", "S2_IoU", "S2_HD95",
                   "S3_DICE", "S3_IoU", "S3_HD95",
                   "S4_DICE", "S4_IoU", "S4_HD95",
                   "S5_DICE", "S5_IoU", "S5_HD95"]
        f.write("\t".join(headers) + "\n")
        
        # 逐行写入患者数据
        for patient_id, metrics in all_metrics.items():
            row = [
                patient_id,
                f"{metrics['fp_volume']:.2f}",
                # 亚型1
                f"{metrics['dice'][0]:.4f}" if not np.isnan(metrics['dice'][0]) else "N/A",
                f"{metrics['iou'][0]:.4f}" if not np.isnan(metrics['iou'][0]) else "N/A",
                f"{metrics['hd95'][0]:.2f}" if not np.isnan(metrics['hd95'][0]) else "N/A",
                # 亚型2
                f"{metrics['dice'][1]:.4f}" if not np.isnan(metrics['dice'][1]) else "N/A",
                f"{metrics['iou'][1]:.4f}" if not np.isnan(metrics['iou'][1]) else "N/A",
                f"{metrics['hd95'][1]:.2f}" if not np.isnan(metrics['hd95'][1]) else "N/A",
                # 亚型3
                f"{metrics['dice'][2]:.4f}" if not np.isnan(metrics['dice'][2]) else "N/A",
                f"{metrics['iou'][2]:.4f}" if not np.isnan(metrics['iou'][2]) else "N/A",
                f"{metrics['hd95'][2]:.2f}" if not np.isnan(metrics['hd95'][2]) else "N/A",
                # 亚型4
                f"{metrics['dice'][3]:.4f}" if not np.isnan(metrics['dice'][3]) else "N/A",
                f"{metrics['iou'][3]:.4f}" if not np.isnan(metrics['iou'][3]) else "N/A",
                f"{metrics['hd95'][3]:.2f}" if not np.isnan(metrics['hd95'][3]) else "N/A",
                # 亚型5
                f"{metrics['dice'][4]:.4f}" if not np.isnan(metrics['dice'][4]) else "N/A",
                f"{metrics['iou'][4]:.4f}" if not np.isnan(metrics['iou'][4]) else "N/A",
                f"{metrics['hd95'][4]:.2f}" if not np.isnan(metrics['hd95'][4]) else "N/A"
            ]
            f.write("\t".join(row) + "\n")
        
        # 计算并写入平均值（仅包含存在的亚型）
        f.write("\n" + "="*120 + "\n")
        f.write("Average Metrics (only for present subtypes):\n")
        fp_volumes = [m['fp_volume'] for m in all_metrics.values()]
        f.write(f"Average FP Volume: {np.mean(fp_volumes):.2f} mm³\n")
        
        for i in range(5):
            dice_list = [m['dice'][i] for m in all_metrics.values() if not np.isnan(m['dice'][i])]
            iou_list = [m['iou'][i] for m in all_metrics.values() if not np.isnan(m['iou'][i])]
            hd95_list = [m['hd95'][i] for m in all_metrics.values() if not np.isnan(m['hd95'][i])]
            
            f.write(f"Subtype {i+1}:\n")
            f.write(f"  Average DICE: {np.mean(dice_list):.4f} (n={len(dice_list)})\n" if dice_list else f"  Average DICE: N/A (n=0)\n")
            f.write(f"  Average IoU: {np.mean(iou_list):.4f} (n={len(iou_list)})\n" if iou_list else f"  Average IoU: N/A (n=0)\n")
            f.write(f"  Average HD95: {np.mean(hd95_list):.2f} mm (n={len(hd95_list)})\n" if hd95_list else f"  Average HD95: N/A (n=0)\n")


def main():
    # 1. 配置参数
    config = {
        "image_dir": "/workspace/Task04_Hippocampus/imagesTr",
        "label_dir": "/workspace/Task04_Hippocampus/labelsTr",
        "model_path": "/workspace/ct_models/best_model_27.pth",
        "root_output_dir": "model_results_test",  # 根输出目录
        "spacing": (0.5, 0.5, 5),          # 图像spacing
        "batch_size": 1,
        "num_workers": 4
    }
    
    # 2. 创建目录结构
    os.makedirs(config["root_output_dir"], exist_ok=True)
    summary_path = os.path.join(config["root_output_dir"], "global_summary.txt")
    
    # 3. 设备配置
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # 4. 数据变换
    val_transform = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys="image"),
            EnsureTyped(keys=["image", "label"]),
            ConvertToMultiChannel5Classesd(keys="label"),
            SpatialPadd(
                keys=["image", "label"],
                spatial_size=(64, 64, 64),  # (D, H, W) — divisible by 32
                method="end",
            ),
            #rientationd(keys=["image", "label"], axcodes="RAS"),
            #pacingd(
            #   keys=["image", "label"],
            #   pixdim=(1.0, 1.0, 1.0),
            #   mode=("bilinear", "nearest"),
            #,
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ]
    )
    
    # 5. 加载数据集
    dataset = CustomCTDataset(
        image_dir=config["image_dir"],
        label_dir=config["label_dir"],
        transform=val_transform
    )
    dataloader = DataLoader(
        dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=config["num_workers"],
        pin_memory=True if torch.cuda.is_available() else False
    )
    print(f"Total patients to process: {len(dataset)}")
    
    # 6. 初始化模型（增加容错：适配不同输入尺寸）
    try:
        model = SwinUNETR(
            #img_size=(192, 192, 192),  # 显式指定输入尺寸
            in_channels=1,
            out_channels=5,
            feature_size=48,
            use_checkpoint=True if torch.cuda.is_available() else False
        ).to(device)
    except Exception as e:
        print(f"Warning: Failed to initialize SwinUNETR with img_size=(192,192,192): {e}")
        print("Trying to initialize model without explicit img_size...")
        model = SwinUNETR(
            in_channels=1,
            out_channels=5,
            feature_size=48,
            use_checkpoint=True if torch.cuda.is_available() else False
        ).to(device)
    
    # 加载模型权重（增加容错）
    try:
        model.load_state_dict(torch.load(config["model_path"], map_location=device))
    except RuntimeError as e:
        print(f"Warning: Strict loading failed: {e}")
        print("Trying non-strict loading...")
        model.load_state_dict(torch.load(config["model_path"], map_location=device), strict=False)
    model.eval()
    print("Model loaded successfully")
    
    # 后处理
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    
    # 7. 推理+可视化+指标计算（逐例处理）
    all_metrics = {}  # 存储所有患者的指标 {patient_id: metrics}
    
    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            patient_id = batch["patient_id"][0]
            print(f"\nProcessing patient {patient_id} ({idx+1}/{len(dataset)})")
            
            # 创建患者专属目录
            patient_dir = os.path.join(config["root_output_dir"], patient_id)
            os.makedirs(patient_dir, exist_ok=True)
            
            # 数据准备
            image = batch["image"].to(device)  # (1, 1, D, H, W)
            label = batch["label"].cpu().numpy()[0]  # (5, D, H, W)
            original_image = image.cpu().numpy()[0, 0]  # (D, H, W) - 原始CT图像
            
            # 模型推理（增加容错：适配不同ROI尺寸）
            try:
                output = sliding_window_inference(
                    inputs=image,
                    roi_size=(64, 64, 32),
                    sw_batch_size=1,
                    predictor=model,
                    overlap=0.5,
                )
            except Exception as e:
                print(f"Warning: Sliding window with roi_size=(192,192,192) failed: {e}")
                print("Trying smaller roi_size=(96,96,96)...")
                output = sliding_window_inference(
                    inputs=image,
                    roi_size=(96, 96, 32),
                    sw_batch_size=1,
                    predictor=model,
                    overlap=0.5,
                )
            
            output = post_trans(output).cpu().numpy()[0]  # (5, D, H, W)
            
            # 转换为单通道标签（1-5=亚型，0=背景）
            pred_3d = np.argmax(output, axis=0) + 1
            pred_3d[np.sum(output, axis=0) < 0.5] = 0
            target_3d = np.argmax(label, axis=0) + 1
            target_3d[np.sum(label, axis=0) < 0.5] = 0
            
            # 计算定量指标（仅包含标签中存在的亚型）
            metrics = EvaluationMetrics.calculate_all_metrics(
                pred=pred_3d,
                target=target_3d,
                spacing=config["spacing"]
            )
            
            all_metrics[patient_id] = metrics
            
            # 生成可视化
            vis_path = generate_visualization(
                image_3d=original_image,
                target_3d=target_3d,
                pred_3d=pred_3d,
                metrics=metrics,
                patient_id=patient_id,
                save_dir=patient_dir
            )
            
            # 保存患者详细指标
            metrics_path = save_patient_metrics(patient_id, metrics, patient_dir)
            
            # 保存预测结果为nii文件
            nii_path = save_prediction_as_nii(
                pred_3d=pred_3d,
                patient_id=patient_id,
                dataset=dataset,
                save_dir=patient_dir
            )
            
            print(f"  Visualization saved to: {vis_path}")
            print(f"  Metrics saved to: {metrics_path}")
            if nii_path:
                print(f"  Prediction saved to: {nii_path}")
    
    # 8. 生成全局汇总文件（增加容错：空数据处理）
    print(all_metrics)
    if all_metrics:
        update_summary_file(all_metrics, summary_path)
        print(f"\nGlobal summary saved to: {summary_path}")
    else:
        print("\nWarning: No metrics to summarize (all_metrics is empty)!")
    
    # 9. 输出最终统计
    print("\n=== Processing Complete ===")
    print(f"Total patients processed: {len(all_metrics)}")
    print(f"Results stored in: {config['root_output_dir']}")
    print(f"Each patient has:")
    print(f"  - Visualization image: [ID]/[ID]_visualization.png")
    print(f"  - Detailed metrics: [ID]/[ID]_metrics.txt")
    print(f"  - Prediction nii: [ID]/[ID]_prediction.nii.gz")
    if all_metrics:
        print(f"Global summary: {summary_path}")


if __name__ == "__main__":
    main()

Using device: cuda
Total patients to process: 0
Model loaded successfully
{}


=== Processing Complete ===
Total patients processed: 0
Results stored in: model_results_test
Each patient has:
  - Visualization image: [ID]/[ID]_visualization.png
  - Detailed metrics: [ID]/[ID]_metrics.txt
  - Prediction nii: [ID]/[ID]_prediction.nii.gz


In [7]:
#!/usr/bin/env python
# coding: utf-8

import os
import torch
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped, Orientationd,
    Spacingd, NormalizeIntensityd, Activations, AsDiscrete, MapTransform, SpatialPadd
)
from monai.networks.nets import SwinUNETR
from monai.inferers import sliding_window_inference
import torch.nn.functional as F
from scipy.spatial.distance import directed_hausdorff

# Only solve minus sign display issue, no Chinese font dependency
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['font.family'] = 'DejaVu Sans'  # Universal sans-serif font


class ConvertToMultiChannel5Classesd(MapTransform):
    """Convert 5 subtype labels to multi-channel format"""
    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = [
                d[key] == 1,  # Subtype 1
                d[key] == 2,  # Subtype 2
                d[key] == 3,  # Subtype 3
                d[key] == 4,  # Subtype 4
                d[key] == 5   # Subtype 5
            ]
            d[key] = torch.stack(result, axis=0).float()
        return d


class CustomCTDataset(Dataset):
    """CT Dataset class (supports full loading)"""
    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.image_info = {}
        self.data = self._load_data()
        
        # Debug: Print directory contents
        print(f"Image directory contents: {os.listdir(self.image_dir)}")
        print(f"Label directory contents: {os.listdir(self.label_dir)}")

    def _load_data(self):
        data = []
        
        # Get all files in image directory
        image_files = [f for f in os.listdir(self.image_dir) 
                      if f.endswith('.nii') or f.endswith('.nii.gz')]
        
        # Get all files in label directory
        label_files = [f for f in os.listdir(self.label_dir) 
                      if f.endswith('.nii') or f.endswith('.nii.gz')]
        
        print(f"Found {len(image_files)} image files and {len(label_files)} label files")
        
        # Create mapping of patient IDs to files (more flexible approach)
        image_dict = {}
        for img_file in image_files:
            # Try to extract patient ID (remove common suffixes)
            patient_id = img_file
            for suffix in ['_head.nii.gz', '_head.nii', '.nii.gz', '.nii']:
                if patient_id.endswith(suffix):
                    patient_id = patient_id.replace(suffix, '')
                    break
            image_dict[patient_id] = img_file
        
        label_dict = {}
        for lbl_file in label_files:
            # Try to extract patient ID (remove common suffixes)
            patient_id = lbl_file
            for suffix in ['_merged.nii', '_merged.nii.gz', '_label.nii', 
                          '_label.nii.gz', '.nii.gz', '.nii']:
                if patient_id.endswith(suffix):
                    patient_id = patient_id.replace(suffix, '')
                    break
            label_dict[patient_id] = lbl_file
        
        print(f"Found {len(image_dict)} unique image patients and {len(label_dict)} unique label patients")
        
        # Find common patient IDs
        common_patients = set(image_dict.keys()) & set(label_dict.keys())
        print(f"Common patients found: {len(common_patients)}")
        
        if not common_patients:
            print("ERROR: No matching patients found between image and label directories!")
            print(f"Image patient IDs: {list(image_dict.keys())[:5]}...")
            print(f"Label patient IDs: {list(label_dict.keys())[:5]}...")
            return data
        
        # Process each common patient
        for patient_id in sorted(common_patients):
            image_path = os.path.join(self.image_dir, image_dict[patient_id])
            label_path = os.path.join(self.label_dir, label_dict[patient_id])
            
            try:
                # Read image for metadata
                image_itk = sitk.ReadImage(image_path)
                self.image_info[patient_id] = {
                    'origin': image_itk.GetOrigin(),
                    'spacing': image_itk.GetSpacing(),
                    'direction': image_itk.GetDirection(),
                    'size': image_itk.GetSize()
                }
                
                data.append({
                    "image": image_path,
                    "label": label_path,
                    "patient_id": patient_id,
                    "original_image_path": image_path
                })
                
                print(f"  Loaded: {patient_id} (image: {image_dict[patient_id]}, label: {label_dict[patient_id]})")
                
            except Exception as e:
                print(f"Warning: Failed to read image info for patient {patient_id}: {e}")
        
        print(f"Total valid patients loaded: {len(data)}")
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        data = self.data[index]
        if self.transform:
            data = self.transform(data)
        return data


class EvaluationMetrics:
    """Evaluation metrics calculation class"""
    @staticmethod
    def dice(pred, target):
        """Calculate DICE coefficient"""
        intersection = np.sum(pred * target)
        union = np.sum(pred) + np.sum(target)
        return 2 * intersection / (union + 1e-8)

    @staticmethod
    def iou(pred, target):
        """Calculate Intersection over Union (IoU)"""
        intersection = np.sum(pred * target)
        union = np.sum(pred) + np.sum(target) - intersection
        return intersection / (union + 1e-8)

    @staticmethod
    def hd95(pred, target, spacing):
        """Calculate 95% Hausdorff Distance (mm)"""
        if np.sum(pred) == 0 or np.sum(target) == 0:
            return np.nan
        
        pred_coords = np.argwhere(pred) * spacing
        target_coords = np.argwhere(target) * spacing
        
        d1 = directed_hausdorff(pred_coords, target_coords)[0]
        d2 = directed_hausdorff(target_coords, pred_coords)[0]
        hd = max(d1, d2)
        
        all_distances = []
        for p in pred_coords:
            min_dist = np.min(np.linalg.norm(p - target_coords, axis=1))
            all_distances.append(min_dist)
        for t in target_coords:
            min_dist = np.min(np.linalg.norm(t - pred_coords, axis=1))
            all_distances.append(min_dist)
        
        return np.percentile(all_distances, 95) if all_distances else np.nan

    @staticmethod
    def fp_volume(pred, target, spacing):
        """Calculate false positive volume (mm³)"""
        fp = np.logical_and(pred > 0, target == 0).astype(np.float32)
        voxel_volume = np.prod(spacing)
        return np.sum(fp) * voxel_volume

    @staticmethod
    def calculate_all_metrics(pred, target, spacing):
        """Calculate all metrics for all subtypes (only for subtypes present in labels)"""
        metrics = {
            "dice": [],
            "iou": [],
            "hd95": [],
            "fp_volume": EvaluationMetrics.fp_volume(pred, target, spacing),
            "present_subtypes": []  # Record existing subtypes
        }
        
        # Find subtypes present in labels
        present_subtypes = np.unique(target)
        present_subtypes = [s for s in present_subtypes if s != 0]  # Exclude background
        
        for c in range(5):
            class_idx = c + 1
            # Only calculate for subtypes present in labels
            if class_idx in present_subtypes:
                pred_class = (pred == class_idx).astype(np.float32)
                target_class = (target == class_idx).astype(np.float32)
                
                metrics["dice"].append(EvaluationMetrics.dice(pred_class, target_class))
                metrics["iou"].append(EvaluationMetrics.iou(pred_class, target_class))
                metrics["hd95"].append(EvaluationMetrics.hd95(pred_class, target_class, spacing))
                metrics["present_subtypes"].append(class_idx)
            else:
                # For non-existing subtypes, don't include in calculation (marked with NaN)
                metrics["dice"].append(np.nan)
                metrics["iou"].append(np.nan)
                metrics["hd95"].append(np.nan)
        
        return metrics


def create_color_map():
    """Create color map for 5 subtypes (RGBA format)"""
    color_map = np.zeros((6, 4))  # 0=background(transparent), 1-5=subtypes
    color_map[0] = [0, 0, 0, 0]          # Background
    color_map[1] = [1, 0, 0, 0.7]        # Subtype 1: Red
    color_map[2] = [0, 1, 0, 0.7]        # Subtype 2: Green
    color_map[3] = [0, 0, 1, 0.7]        # Subtype 3: Blue
    color_map[4] = [1, 1, 0, 0.7]        # Subtype 4: Yellow
    color_map[5] = [1, 0, 1, 0.7]        # Subtype 5: Purple
    return color_map


def generate_visualization(image_3d, target_3d, pred_3d, metrics, patient_id, save_dir):
    """Generate visualization results for single patient (x-y slice + quantitative metrics)"""
    # 1. Basic configuration
    color_map = create_color_map()
    mid_slice = image_3d.shape[2] // 2  # x-y axial slice (middle layer along W axis)
    
    # 2. Extract slice data
    image_slice = image_3d[:, :, mid_slice]
    target_slice = target_3d[:, :, mid_slice]
    pred_slice = pred_3d[:, :, mid_slice]
    
    # 3. Generate colored label/prediction maps
    # Ground truth color map
    target_color = np.zeros((target_slice.shape[0], target_slice.shape[1], 4))
    for i in range(6):
        target_color[target_slice == i] = color_map[i]
    
    # Prediction color map
    pred_color = np.zeros((pred_slice.shape[0], pred_slice.shape[1], 4))
    for i in range(6):
        pred_color[pred_slice == i] = color_map[i]
    
    # 4. Create visualization image
    fig, axes = plt.subplots(1, 3, figsize=(24, 8))
    fig.suptitle(f"Patient ID: {patient_id} | Axial Slice (W-axis: {mid_slice})", fontsize=16, fontweight="bold")
    
    # 4.1 Original CT image
    axes[0].imshow(image_slice, cmap="gray", aspect="auto")
    axes[0].set_title("CT Original Image", fontsize=14)
    axes[0].axis("off")
    
    # 4.2 Ground truth label (multi-color)
    axes[1].imshow(image_slice, cmap="gray", aspect="auto", alpha=0.8)
    axes[1].imshow(target_color, aspect="auto")
    axes[1].set_title("Ground Truth (Color-Coded Subtypes)", fontsize=14)
    axes[1].axis("off")
    
    # 4.3 Model prediction (multi-color)
    axes[2].imshow(image_slice, cmap="gray", aspect="auto", alpha=0.8)
    axes[2].imshow(pred_color, aspect="auto")
    axes[2].set_title("Model Prediction (Color-Coded Subtypes)", fontsize=14)
    axes[2].axis("off")
    
    # 5. Add quantitative metrics text (below image)
    metrics_text = f"Evaluation Metrics:\nFP Volume: {metrics['fp_volume']:.2f} mm³\n"
    # Only show metrics for existing subtypes
    for i in metrics["present_subtypes"]:
        idx = i - 1
        metrics_text += f"Subtype {i} - DICE: {metrics['dice'][idx]:.4f} | IoU: {metrics['iou'][idx]:.4f} | HD95: {np.nan if np.isnan(metrics['hd95'][idx]) else metrics['hd95'][idx]:.2f} mm\n"
    
    fig.text(0.05, 0.02, metrics_text, fontsize=10, verticalalignment='bottom',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    # 6. Add color legend
    handles = [plt.Rectangle((0, 0), 1, 1, facecolor=color_map[i][:3], alpha=color_map[i][3]) for i in range(1, 6)]
    labels = [f"Subtype {i}" for i in range(1, 6)]
    fig.legend(handles, labels, loc="upper right", bbox_to_anchor=(0.98, 0.95), ncol=5, fontsize=10)
    
    # 7. Save image
    plt.tight_layout(rect=[0.05, 0.1, 0.95, 0.95])  # Reserve space for metrics and legend
    save_path = os.path.join(save_dir, f"{patient_id}_visualization.png")
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close()
    
    return save_path


def save_patient_metrics(patient_id, metrics, save_dir):
    """Save detailed metrics for single patient to file"""
    metrics_path = os.path.join(save_dir, f"{patient_id}_metrics.txt")
    with open(metrics_path, "w", encoding="utf-8") as f:
        f.write(f"Patient ID: {patient_id}\n")
        f.write("="*80 + "\n")
        f.write(f"False Positive Volume: {metrics['fp_volume']:.2f} mm³\n\n")
        f.write("Present Subtypes in Ground Truth: " + ", ".join(map(str, metrics["present_subtypes"])) + "\n\n")
        f.write("Subtype-wise Metrics:\n")
        # Only save metrics for existing subtypes
        for i in metrics["present_subtypes"]:
            idx = i - 1
            f.write(f"Subtype {i}:\n")
            f.write(f"  DICE: {metrics['dice'][idx]:.4f}\n")
            f.write(f"  IoU: {metrics['iou'][idx]:.4f}\n")
            f.write(f"  HD95: {np.nan if np.isnan(metrics['hd95'][idx]) else metrics['hd95'][idx]:.2f} mm\n\n")
    return metrics_path


def save_prediction_as_nii(pred_3d, patient_id, dataset, save_dir):
    """Save prediction results as nii format (with error handling)"""
    try:
        # Get spatial information from original image
        if patient_id not in dataset.image_info:
            print(f"Warning: No image info for patient {patient_id}, using default spacing")
            # Use default spatial information (fallback)
            image_info = {
                'origin': (0.0, 0.0, 0.0),
                'spacing': (1.0, 1.0, 1.0),
                'direction': (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
            }
        else:
            image_info = dataset.image_info[patient_id]
        
        # Adjust dimension order: numpy (D,H,W) → SimpleITK (W,H,D)
        pred_3d_transposed = np.transpose(pred_3d, (2, 1, 0))
        
        # Convert data type to integer
        pred_itk = sitk.GetImageFromArray(pred_3d_transposed.astype(np.int16))
        
        # Set spatial information (consistent with original image)
        pred_itk.SetOrigin(image_info['origin'])
        pred_itk.SetSpacing(image_info['spacing'])
        pred_itk.SetDirection(image_info['direction'])
        
        # Save as nii file
        save_path = os.path.join(save_dir, f"{patient_id}_prediction.nii.gz")
        sitk.WriteImage(pred_itk, save_path, useCompression=True)
        print(f"  Successfully saved prediction to: {save_path}")
        return save_path
    except Exception as e:
        print(f"  Error saving nii for patient {patient_id}: {e}")
        return None


def update_summary_file(all_metrics, summary_path):
    """Update global summary file"""
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("Global Summary of All Patients\n")
        f.write("="*120 + "\n")
        # Header
        headers = ["Patient ID", "FP_Volume(mm³)", 
                   "S1_DICE", "S1_IoU", "S1_HD95",
                   "S2_DICE", "S2_IoU", "S2_HD95",
                   "S3_DICE", "S3_IoU", "S3_HD95",
                   "S4_DICE", "S4_IoU", "S4_HD95",
                   "S5_DICE", "S5_IoU", "S5_HD95"]
        f.write("\t".join(headers) + "\n")
        
        # Write patient data row by row
        for patient_id, metrics in all_metrics.items():
            row = [
                patient_id,
                f"{metrics['fp_volume']:.2f}",
                # Subtype 1
                f"{metrics['dice'][0]:.4f}" if not np.isnan(metrics['dice'][0]) else "N/A",
                f"{metrics['iou'][0]:.4f}" if not np.isnan(metrics['iou'][0]) else "N/A",
                f"{metrics['hd95'][0]:.2f}" if not np.isnan(metrics['hd95'][0]) else "N/A",
                # Subtype 2
                f"{metrics['dice'][1]:.4f}" if not np.isnan(metrics['dice'][1]) else "N/A",
                f"{metrics['iou'][1]:.4f}" if not np.isnan(metrics['iou'][1]) else "N/A",
                f"{metrics['hd95'][1]:.2f}" if not np.isnan(metrics['hd95'][1]) else "N/A",
                # Subtype 3
                f"{metrics['dice'][2]:.4f}" if not np.isnan(metrics['dice'][2]) else "N/A",
                f"{metrics['iou'][2]:.4f}" if not np.isnan(metrics['iou'][2]) else "N/A",
                f"{metrics['hd95'][2]:.2f}" if not np.isnan(metrics['hd95'][2]) else "N/A",
                # Subtype 4
                f"{metrics['dice'][3]:.4f}" if not np.isnan(metrics['dice'][3]) else "N/A",
                f"{metrics['iou'][3]:.4f}" if not np.isnan(metrics['iou'][3]) else "N/A",
                f"{metrics['hd95'][3]:.2f}" if not np.isnan(metrics['hd95'][3]) else "N/A",
                # Subtype 5
                f"{metrics['dice'][4]:.4f}" if not np.isnan(metrics['dice'][4]) else "N/A",
                f"{metrics['iou'][4]:.4f}" if not np.isnan(metrics['iou'][4]) else "N/A",
                f"{metrics['hd95'][4]:.2f}" if not np.isnan(metrics['hd95'][4]) else "N/A"
            ]
            f.write("\t".join(row) + "\n")
        
        # Calculate and write averages (only for existing subtypes)
        f.write("\n" + "="*120 + "\n")
        f.write("Average Metrics (only for present subtypes):\n")
        fp_volumes = [m['fp_volume'] for m in all_metrics.values()]
        f.write(f"Average FP Volume: {np.mean(fp_volumes):.2f} mm³\n")
        
        for i in range(5):
            dice_list = [m['dice'][i] for m in all_metrics.values() if not np.isnan(m['dice'][i])]
            iou_list = [m['iou'][i] for m in all_metrics.values() if not np.isnan(m['iou'][i])]
            hd95_list = [m['hd95'][i] for m in all_metrics.values() if not np.isnan(m['hd95'][i])]
            
            f.write(f"Subtype {i+1}:\n")
            f.write(f"  Average DICE: {np.mean(dice_list):.4f} (n={len(dice_list)})\n" if dice_list else f"  Average DICE: N/A (n=0)\n")
            f.write(f"  Average IoU: {np.mean(iou_list):.4f} (n={len(iou_list)})\n" if iou_list else f"  Average IoU: N/A (n=0)\n")
            f.write(f"  Average HD95: {np.mean(hd95_list):.2f} mm (n={len(hd95_list)})\n" if hd95_list else f"  Average HD95: N/A (n=0)\n")


def main():
    # 1. Configuration parameters
    config = {
        "image_dir": "/workspace/Task04_Hippocampus/imagesTr",
        "label_dir": "/workspace/Task04_Hippocampus/labelsTr",
        "model_path": "/workspace/ct_models/best_model_27.pth",
        "root_output_dir": "model_results_test",  # Root output directory
        "spacing": (0.5, 0.5, 5),          # Image spacing
        "batch_size": 1,
        "num_workers": 4
    }
    
    # 2. Create directory structure
    os.makedirs(config["root_output_dir"], exist_ok=True)
    summary_path = os.path.join(config["root_output_dir"], "global_summary.txt")
    
    # 3. Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # 4. Check if directories exist
    print(f"Checking if image directory exists: {os.path.exists(config['image_dir'])}")
    print(f"Checking if label directory exists: {os.path.exists(config['label_dir'])}")
    
    if not os.path.exists(config['image_dir']):
        print(f"ERROR: Image directory does not exist: {config['image_dir']}")
        return
    if not os.path.exists(config['label_dir']):
        print(f"ERROR: Label directory does not exist: {config['label_dir']}")
        return
    
    # 5. Data transformations
    val_transform = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys="image"),
            EnsureTyped(keys=["image", "label"]),
            ConvertToMultiChannel5Classesd(keys="label"),
            SpatialPadd(
                keys=["image", "label"],
                spatial_size=(64, 64, 64),  # (D, H, W) — divisible by 32
                method="end",
            ),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ]
    )
    
    # 6. Load dataset
    dataset = CustomCTDataset(
        image_dir=config["image_dir"],
        label_dir=config["label_dir"],
        transform=val_transform
    )
    
    if len(dataset) == 0:
        print("ERROR: No data loaded! Please check:")
        print(f"1. Image directory: {config['image_dir']}")
        print(f"2. Label directory: {config['label_dir']}")
        print("3. File naming patterns")
        print("4. File extensions (.nii or .nii.gz)")
        return
    
    dataloader = DataLoader(
        dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=config["num_workers"],
        pin_memory=True if torch.cuda.is_available() else False
    )
    print(f"Total patients to process: {len(dataset)}")
    
    # 7. Initialize model
    try:
        model = SwinUNETR(
            img_size=(192, 192, 192),  # Explicitly specify input size
            in_channels=1,
            out_channels=5,
            feature_size=48,
            use_checkpoint=True if torch.cuda.is_available() else False
        ).to(device)
    except Exception as e:
        print(f"Warning: Failed to initialize SwinUNETR with img_size=(192,192,192): {e}")
        print("Trying to initialize model without explicit img_size...")
        model = SwinUNETR(
            in_channels=1,
            out_channels=5,
            feature_size=48,
            use_checkpoint=True if torch.cuda.is_available() else False
        ).to(device)
    
    # Load model weights (with error handling)
    try:
        model.load_state_dict(torch.load(config["model_path"], map_location=device))
        print("Model loaded successfully with strict=True")
    except RuntimeError as e:
        print(f"Warning: Strict loading failed: {e}")
        print("Trying non-strict loading...")
        model.load_state_dict(torch.load(config["model_path"], map_location=device), strict=False)
        print("Model loaded successfully with strict=False")
    model.eval()
    
    # Post-processing
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    
    # 8. Inference + Visualization + Metrics calculation (process case by case)
    all_metrics = {}  # Store metrics for all patients {patient_id: metrics}
    
    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            patient_id = batch["patient_id"][0]
            print(f"\nProcessing patient {patient_id} ({idx+1}/{len(dataset)})")
            
            # Create patient-specific directory
            patient_dir = os.path.join(config["root_output_dir"], patient_id)
            os.makedirs(patient_dir, exist_ok=True)
            
            # Data preparation
            image = batch["image"].to(device)  # (1, 1, D, H, W)
            label = batch["label"].cpu().numpy()[0]  # (5, D, H, W)
            original_image = image.cpu().numpy()[0, 0]  # (D, H, W) - original CT image
            
            # Model inference (with error handling for different ROI sizes)
            try:
                output = sliding_window_inference(
                    inputs=image,
                    roi_size=(64, 64, 64),
                    sw_batch_size=1,
                    predictor=model,
                    overlap=0.5,
                )
            except Exception as e:
                print(f"Warning: Sliding window with roi_size=(64,64,64) failed: {e}")
                print("Trying smaller roi_size=(32,32,32)...")
                try:
                    output = sliding_window_inference(
                        inputs=image,
                        roi_size=(32, 32, 32),
                        sw_batch_size=1,
                        predictor=model,
                        overlap=0.5,
                    )
                except Exception as e2:
                    print(f"Error: All sliding window attempts failed: {e2}")
                    continue
            
            output = post_trans(output).cpu().numpy()[0]  # (5, D, H, W)
            
            # Convert to single-channel labels (1-5=subtypes, 0=background)
            pred_3d = np.argmax(output, axis=0) + 1
            pred_3d[np.sum(output, axis=0) < 0.5] = 0
            target_3d = np.argmax(label, axis=0) + 1
            target_3d[np.sum(label, axis=0) < 0.5] = 0
            
            # Calculate quantitative metrics (only for subtypes present in labels)
            metrics = EvaluationMetrics.calculate_all_metrics(
                pred=pred_3d,
                target=target_3d,
                spacing=config["spacing"]
            )
            
            all_metrics[patient_id] = metrics
            
            # Generate visualization
            vis_path = generate_visualization(
                image_3d=original_image,
                target_3d=target_3d,
                pred_3d=pred_3d,
                metrics=metrics,
                patient_id=patient_id,
                save_dir=patient_dir
            )
            
            # Save detailed patient metrics
            metrics_path = save_patient_metrics(patient_id, metrics, patient_dir)
            
            # Save prediction as nii file
            nii_path = save_prediction_as_nii(
                pred_3d=pred_3d,
                patient_id=patient_id,
                dataset=dataset,
                save_dir=patient_dir
            )
            
            print(f"  Visualization saved to: {vis_path}")
            print(f"  Metrics saved to: {metrics_path}")
            if nii_path:
                print(f"  Prediction saved to: {nii_path}")
    
    # 9. Generate global summary file (with error handling for empty data)
    if all_metrics:
        update_summary_file(all_metrics, summary_path)
        print(f"\nGlobal summary saved to: {summary_path}")
    else:
        print("\nWarning: No metrics to summarize (all_metrics is empty)!")
    
    # 10. Output final statistics
    print("\n=== Processing Complete ===")
    print(f"Total patients processed: {len(all_metrics)}")
    print(f"Results stored in: {config['root_output_dir']}")
    if all_metrics:
        print(f"Each patient has:")
        print(f"  - Visualization image: [ID]/[ID]_visualization.png")
        print(f"  - Detailed metrics: [ID]/[ID]_metrics.txt")
        print(f"  - Prediction nii: [ID]/[ID]_prediction.nii.gz")
        print(f"Global summary: {summary_path}")


if __name__ == "__main__":
    main()

Using device: cuda
Checking if image directory exists: True
Checking if label directory exists: True
Found 24 image files and 24 label files
Found 24 unique image patients and 24 unique label patients
Common patients found: 24
  Loaded: hippocampus_001 (image: hippocampus_001.nii.gz, label: hippocampus_001.nii.gz)
  Loaded: hippocampus_003 (image: hippocampus_003.nii.gz, label: hippocampus_003.nii.gz)
  Loaded: hippocampus_004 (image: hippocampus_004.nii.gz, label: hippocampus_004.nii.gz)
  Loaded: hippocampus_006 (image: hippocampus_006.nii.gz, label: hippocampus_006.nii.gz)
  Loaded: hippocampus_007 (image: hippocampus_007.nii.gz, label: hippocampus_007.nii.gz)
  Loaded: hippocampus_008 (image: hippocampus_008.nii.gz, label: hippocampus_008.nii.gz)
  Loaded: hippocampus_011 (image: hippocampus_011.nii.gz, label: hippocampus_011.nii.gz)
  Loaded: hippocampus_014 (image: hippocampus_014.nii.gz, label: hippocampus_014.nii.gz)
  Loaded: hippocampus_015 (image: hippocampus_015.nii.gz, lab

  win_data = inputs[unravel_slice[0]].to(sw_device)
  out[idx_zm] += p


  Successfully saved prediction to: model_results_test/hippocampus_001/hippocampus_001_prediction.nii.gz
  Visualization saved to: model_results_test/hippocampus_001/hippocampus_001_visualization.png
  Metrics saved to: model_results_test/hippocampus_001/hippocampus_001_metrics.txt
  Prediction saved to: model_results_test/hippocampus_001/hippocampus_001_prediction.nii.gz

Processing patient hippocampus_003 (2/24)
  Successfully saved prediction to: model_results_test/hippocampus_003/hippocampus_003_prediction.nii.gz
  Visualization saved to: model_results_test/hippocampus_003/hippocampus_003_visualization.png
  Metrics saved to: model_results_test/hippocampus_003/hippocampus_003_metrics.txt
  Prediction saved to: model_results_test/hippocampus_003/hippocampus_003_prediction.nii.gz

Processing patient hippocampus_004 (3/24)
  Successfully saved prediction to: model_results_test/hippocampus_004/hippocampus_004_prediction.nii.gz
  Visualization saved to: model_results_test/hippocampus_00