In [2]:
import os
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt
from tqdm import tqdm

In [3]:


# Configuration
# Atlas (template) cases whose labels are propagated: 1001 through 1009.
TEMPLATE_IDS = [1001 + offset for offset in range(9)]
# Case(s) to be segmented.
TARGET_IDS = [1010]
# Structures for which one label file per template is loaded and fused.
LABEL_CATEGORIES = ['Ao', 'Heart', 'LA', 'LV', 'LV_Myo', 'PA', 'RA', 'RV']
# Output locations (relative to the working directory).
OUTPUT_DIR = 'results'
VISUALIZATION_DIR = os.path.join(OUTPUT_DIR, 'visualization')

In [4]:


def create_directories():
    """Make sure the results folder and its visualization subfolder exist.

    Safe to call repeatedly: existing directories are left as they are.
    """
    for directory in (OUTPUT_DIR, VISUALIZATION_DIR):
        os.makedirs(directory, exist_ok=True)

def load_nii(path):
    """Read the image stored at ``path`` and return it as a SimpleITK image.

    No NIfTI-specific handling is done here; the file is passed straight to
    ``sitk.ReadImage``, which raises if the file cannot be read.
    """
    return sitk.ReadImage(path)

def register_images(fixed_img, moving_img, transform_types=('affine',),
                    final_grid_spacing=10, number_of_resolutions=4,
                    max_iterations=2000):
    """Register ``moving_img`` onto ``fixed_img`` with elastix.

    The default runs a single affine stage. For a deformable (elastic)
    refinement, pass ``transform_types=('affine', 'bspline')``. The stages
    are executed in the given order.

    Args:
        fixed_img: target image that defines the output space.
        moving_img: template image that is transformed onto ``fixed_img``.
        transform_types: names accepted by ``sitk.GetDefaultParameterMap``
            (e.g. 'rigid', 'affine', 'bspline'), one per registration stage.
        final_grid_spacing: B-spline control-point spacing in the image's
            physical units. Only used by 'bspline' stages, the only transform
            that reads this parameter.
        number_of_resolutions: pyramid levels for every stage.
        max_iterations: optimizer iterations per resolution for every stage.

    Returns:
        (result_image, transform_parameter_maps): the moving image resampled
        into fixed space, and the transform parameter maps that can be handed
        to transformix to warp other images (e.g. labels) the same way.

    Raises:
        ValueError: if ``transform_types`` is empty.
    """
    if not transform_types:
        raise ValueError('transform_types must name at least one transform.')

    elastix = sitk.ElastixImageFilter()
    elastix.SetFixedImage(fixed_img)
    elastix.SetMovingImage(moving_img)

    # Registration settings (originally annotated as tuned for cardiac CTA).
    for stage, transform_type in enumerate(transform_types):
        parameter_map = sitk.GetDefaultParameterMap(transform_type)
        parameter_map['NumberOfResolutions'] = [str(number_of_resolutions)]
        parameter_map['MaximumNumberOfIterations'] = [str(max_iterations)]
        if transform_type == 'bspline':
            # Grid spacing has no meaning for linear transforms, so set it only here.
            parameter_map['FinalGridSpacingInPhysicalUnits'] = [str(final_grid_spacing)]

        # The first map replaces elastix's defaults; later ones are chained after it.
        if stage == 0:
            elastix.SetParameterMap(parameter_map)
        else:
            elastix.AddParameterMap(parameter_map)

    elastix.Execute()

    return elastix.GetResultImage(), elastix.GetTransformParameterMap()

def visualize_registration(fixed, moving, registered, target_id, template_id):
    """Save a side-by-side PNG of the target, the template and the registered template.

    Each panel shows the middle slice (first array axis) of its own volume.
    The template lives on its own voxel grid, so it gets its own middle-slice
    index. A template with fewer slices than the target therefore cannot
    cause an out-of-range index.

    Args:
        fixed: target image.
        moving: template image before registration.
        registered: template image after registration.
        target_id: target case number (used in the title and file name).
        template_id: template case number (used in the titles and file name).
    """
    def _middle_slice(img):
        # GetArrayFromImage orders axes (z, y, x); take the central z slice.
        volume = sitk.GetArrayFromImage(img)
        return volume[volume.shape[0] // 2]

    panels = [
        (fixed, f'Target {target_id}'),
        (moving, f'Template {template_id}'),
        (registered, f'Registered {template_id}'),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (img, title) in zip(axes, panels):
        ax.imshow(_middle_slice(img), cmap='gray')
        ax.set_title(title)

    fig.savefig(os.path.join(VISUALIZATION_DIR,
                f'registration_{target_id}_template_{template_id}.png'))
    # Close this specific figure so repeated calls in a loop do not accumulate figures.
    plt.close(fig)

def _warp_label(label_img, transform_params):
    """Resample a template label image into target space with nearest-neighbour interpolation.

    ``transform_params`` comes from the intensity registration. It therefore
    carries that run's ``FinalBSplineInterpolationOrder``, which is a cubic
    B-spline in the default elastix maps (NOTE(review): confirm for custom
    maps). That setting would smear a binary mask into fractional and
    overshooting values. Forcing order 0 keeps every warped voxel equal to
    one of the template's own label values.
    """
    transformix = sitk.TransformixImageFilter()
    transformix.LogToConsoleOff()
    transformix.SetTransformParameterMap(transform_params)
    # Set after the maps are loaded; applies to every map in the chain.
    transformix.SetTransformParameter('FinalBSplineInterpolationOrder', '0')
    transformix.SetMovingImage(label_img)
    transformix.Execute()
    return transformix.GetResultImage()


def _save_segmentation_preview(mask, slice_idx, target_id, category):
    """Save slice ``slice_idx`` (first array axis) of a fused mask as a PNG."""
    fig, ax = plt.subplots(figsize=(8, 6))
    # Masks are (z, y, x) arrays, so the axial slice is selected on axis 0.
    image = ax.imshow(mask[slice_idx], cmap='jet')
    ax.set_title(f'{category} result')
    fig.colorbar(image, ax=ax)
    fig.savefig(os.path.join(VISUALIZATION_DIR,
                f'seg_{target_id}_{category}.png'))
    plt.close(fig)


def process_target(target_id):
    """Segment one target volume by registration-based multi-atlas label fusion.

    For every template in ``TEMPLATE_IDS``, the function does three things:
    it registers the template image to the target, saves a registration
    preview, and warps each of the template's ``LABEL_CATEGORIES`` masks into
    target space.

    Fusion is a per-category majority vote: a voxel is kept when strictly
    more than half of the templates label it non-zero. For every category, a
    preview PNG and a ``ct_train_<target>_<category>_seg.nii`` file are
    written.

    Args:
        target_id: case number used to build the input and output file names.

    Returns:
        dict mapping each category to its fused uint8 mask, indexed (z, y, x)
        on the target's voxel grid.

    Raises:
        ValueError: if ``TEMPLATE_IDS`` is empty (there is nothing to fuse).
    """
    if not TEMPLATE_IDS:
        raise ValueError('TEMPLATE_IDS is empty; at least one template is required.')

    print(f"\n正在处理目标 {target_id}...")

    # Load the target image (the fixed image of every registration).
    fixed_img = load_nii(f'ct_train_{target_id}_imageROI.nii')

    # Middle slice along z (SimpleITK size index 2) for the segmentation previews.
    slice_idx = fixed_img.GetSize()[2] // 2

    # Running per-voxel count of templates that label the voxel as foreground.
    # numpy arrays are (z, y, x), the reverse of SimpleITK's (x, y, z) size.
    vote_shape = tuple(reversed(fixed_img.GetSize()))
    votes = {category: np.zeros(vote_shape, dtype=np.int32)
             for category in LABEL_CATEGORIES}

    for template_id in tqdm(TEMPLATE_IDS, desc="正在处理模板"):
        moving_img = load_nii(f'ct_train_{template_id}_imageROI.nii')

        registered_img, transform_params = register_images(fixed_img, moving_img)

        visualize_registration(fixed_img, moving_img, registered_img,
                               target_id, template_id)

        for category in LABEL_CATEGORIES:
            label_img = load_nii(f'ct_train_{template_id}_{category}.nii')
            warped_label = _warp_label(label_img, transform_params)
            # Any non-zero voxel counts as one vote for this structure.
            votes[category] += sitk.GetArrayFromImage(warped_label) != 0

    # "> n // 2" means strictly more than half of the templates for odd and even n.
    majority = len(TEMPLATE_IDS) // 2
    final_segmentation = {}
    for category in LABEL_CATEGORIES:
        mask = (votes[category] > majority).astype(np.uint8)
        final_segmentation[category] = mask

        _save_segmentation_preview(mask, slice_idx, target_id, category)

        # The masks share the target's grid, so reuse its origin/spacing/direction.
        output_img = sitk.GetImageFromArray(mask)
        output_img.CopyInformation(fixed_img)
        sitk.WriteImage(output_img,
                        os.path.join(OUTPUT_DIR, f'ct_train_{target_id}_{category}_seg.nii'))

    return final_segmentation

def main():
    """Create the output folders, then run label fusion for every case in TARGET_IDS."""
    create_directories()
    for current_target in TARGET_IDS:
        process_target(current_target)


if __name__ == '__main__':
    main()


正在处理目标 1010...


正在处理模板: 100%|██████████| 9/9 [03:31<00:00, 23.46s/it]
