In [1]:
import os
import glob
import torch
import time
import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped, Orientationd, Spacingd,
    RandSpatialCropd, RandFlipd, NormalizeIntensityd, RandScaleIntensityd, RandShiftIntensityd,
    Activations, AsDiscrete, MapTransform, SpatialPadd
)
from monai.networks.nets import SwinUNETR
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from torch.nn import functional as F

In [2]:
class CustomCTDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.data = self._load_data()

    def _load_data(self):
        """加载CT图像和对应的标签数据"""
        data = []
        # 获取所有患者ID
        patient_ids = set()
        for filename in os.listdir(self.image_dir):
            if filename.endswith(".nii.gz"):
                patient_id = filename.split(".nii.gz")[0]
                patient_ids.add(patient_id)
        
        
        # 匹配图像和标签
        for patient_id in patient_ids:
            image_path = os.path.join(self.image_dir, f"{patient_id}.nii.gz")
            label_path = os.path.join(self.label_dir, f"{patient_id}.nii.gz")
            
            if os.path.exists(image_path) and os.path.exists(label_path):
                data.append({
                    "image": image_path,
                    "label": label_path
                })
            else:
                print(f"Warning: Missing data for patient {patient_id}")
                print(image_path, label_path)
                
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        data = self.data[index]
        if self.transform:
            data = self.transform(data)
        return data

In [None]:
CustomCTDataset("/workspace/Task04_Hippocampus/imagesTr", "/workspace/Task04_Hippocampus/labelsTr")

In [3]:
class ConvertToMultiChannel5Classesd(MapTransform):
    """将5亚型标签转换为多通道格式"""
    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            # 为每个亚型创建一个通道
            result = [
                d[key] == 1,  # 亚型1
                d[key] == 2,  # 亚型2
                d[key] == 3,  # 亚型3
                d[key] == 4,  # 亚型4
                d[key] == 5   # 亚型5
            ]
            d[key] = torch.stack(result, axis=0).float()
        return d

In [14]:


# 定义数据变换
train_transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),  # 单模态数据，添加通道维度
        EnsureTyped(keys=["image", "label"]),
        ConvertToMultiChannel5Classesd(keys="label"),  # 转换为5通道标签
        SpatialPadd(
            keys=["image", "label"],
            spatial_size=(64, 64, 64),  # (D, H, W) — divisible by 32
            method="end",
        ),
        #rientationd(keys=["image", "label"], axcodes="RAS"),  # 调整方向
        #pacingd(
        #   keys=["image", "label"],
        #   pixdim=(1.0, 1.0, 1.0),  # 重采样到1mm×1mm×1mm
        #   mode=("bilinear", "nearest"),  # 图像双线性插值，标签最近邻插值
        #,
        #andSpatialCropd(keys=["image", "label"], roi_size=[192, 192, 192], random_size=False),  # 随机裁剪
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),  # 随机翻转
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),  # 归一化
        RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),  # 随机缩放强度
        RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),  # 随机偏移强度
    ]
)

val_transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
        ConvertToMultiChannel5Classesd(keys="label"),
        SpatialPadd(
            keys=["image", "label"],
            spatial_size=(64, 64, 64),  # (D, H, W) — divisible by 32
            method="end",
        ),
        #rientationd(keys=["image", "label"], axcodes="RAS"),
        #pacingd(
        #   keys=["image", "label"],
        #   pixdim=(1.0, 1.0, 1.0),
        #   mode=("bilinear", "nearest"),
        #,
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

In [15]:
# 数据路径
image_dir = "/workspace/Task04_Hippocampus/imagesTr"
label_dir = "/workspace/Task04_Hippocampus/labelsTr"

# 创建数据集
full_dataset = CustomCTDataset(
    image_dir=image_dir,
    label_dir=label_dir,
    transform=None
)

# 划分训练集和验证集
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# 应用变换
train_dataset.dataset.transform = train_transform
val_dataset.dataset.transform = val_transform

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)

In [16]:
sample = train_dataset[0]
print(sample["image"].shape)

torch.Size([1, 64, 64, 64])


In [17]:
batch = next(iter(train_loader))
print(batch["image"].shape)

torch.Size([1, 1, 64, 64, 64])


In [19]:
# 模型配置
max_epochs = 50
val_interval = 1
VAL_AMP = True

# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 初始化模型
model = SwinUNETR(
    #img_size=(192, 192, 192),  # 输入图像大小
    in_channels=1,             # 单模态CT
    out_channels=5,            # 5个亚型
    feature_size=48,
    use_checkpoint=True
).to(device)

# 损失函数和优化器
loss_function = DiceLoss(
    smooth_nr=0, 
    smooth_dr=1e-5, 
    squared_pred=True, 
    to_onehot_y=False, 
    sigmoid=True
)
optimizer = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

# 评估指标
dice_metric = DiceMetric(include_background=True, reduction="mean")
dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

# 后处理
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])


# 推理函数
def inference(input):
    def _compute(input):
        return sliding_window_inference(
            inputs=input,
            roi_size=(64, 64, 32),
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )

    if VAL_AMP:
        with torch.cuda.amp.autocast():
            return _compute(input)
    else:
        return _compute(input)


# 训练过程
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
# 5个亚型的指标存储
metric_values_class1 = []
metric_values_class2 = []
metric_values_class3 = []
metric_values_class4 = []
metric_values_class5 = []

total_start = time.time()

# 创建模型保存目录
os.makedirs("ct_models", exist_ok=True)

for epoch in range(max_epochs):
    epoch_start = time.time()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0

    for batch_data in train_loader:
        step_start = time.time()
        step += 1
        inputs = batch_data["image"].to(device)
        labels = batch_data["label"].to(device)

        # 根据需要调整pad大小以匹配模型输入
        pad_sizes = (0, 0, 0, 0, 0, 0)  # 可根据实际数据尺寸调整
        # if any(pad_sizes):
        #     inputs = F.pad(inputs, pad_sizes, mode="constant", value=0).to(device)

        # 前向传播
        outputs = model(inputs)
        
        # 如果有padding，裁剪回原始大小
        # if any(pad_sizes):
        #     outputs = outputs[:, :, pad_sizes[0]:-pad_sizes[1], 
        #                      pad_sizes[2]:-pad_sizes[3], 
        #                      pad_sizes[4]:-pad_sizes[5]]

        # 计算损失
        loss = loss_function(outputs, labels)
        epoch_loss += loss.item()

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # print(
        #     f"{step}/{len(train_dataset) // train_loader.batch_size}"
        #     f", train_loss: {loss.item():.4f}"
        #     f", step time: {(time.time() - step_start):.4f}"
        # )

    # 学习率调整
    lr_scheduler.step()
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    # 验证过程
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs = val_data["image"].to(device)
                val_labels = val_data["label"].to(device)

                # 处理padding
                # if any(pad_sizes):
                #     val_inputs = F.pad(val_inputs, pad_sizes, mode="constant", value=0).to(device)

                # 推理
                val_outputs = inference(val_inputs)
                
                # 裁剪回原始大小
                if any(pad_sizes):
                    val_outputs = val_outputs[:, :, pad_sizes[0]:-pad_sizes[1], 
                                            pad_sizes[2]:-pad_sizes[3], 
                                            pad_sizes[4]:-pad_sizes[5]]

                # 后处理
                val_outputs = [post_trans(i) for i in val_outputs]
                
                # 计算指标
                dice_metric(y_pred=val_outputs, y=val_labels)
                dice_metric_batch(y_pred=val_outputs, y=val_labels)

            # 聚合指标
            metric = dice_metric.aggregate().item()
            metric_values.append(metric)
            metric_batch = dice_metric_batch.aggregate()

            # 每个类别的Dice
            metric_class1 = metric_batch[0].item()
            metric_class2 = metric_batch[1].item()
            metric_class3 = metric_batch[2].item()
            metric_class4 = metric_batch[3].item()
            metric_class5 = metric_batch[4].item()

            metric_values_class1.append(metric_class1)
            metric_values_class2.append(metric_class2)
            metric_values_class3.append(metric_class3)
            metric_values_class4.append(metric_class4)
            metric_values_class5.append(metric_class5)

            # 重置指标
            dice_metric.reset()
            dice_metric_batch.reset()

            # 保存最佳模型
            current_time = time.localtime()
            formatted_time = time.strftime("%Y-%m-%d %H:%M:%S", current_time)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(
                    model.state_dict(),
                    os.path.join("ct_models", f"best_model_{epoch + 1}.pth"),
                )
                print("Saved new best metric model")

            # 打印和记录指标
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nclass1: {metric_class1:.4f} class2: {metric_class2:.4f} "
                f"class3: {metric_class3:.4f} class4: {metric_class4:.4f} class5: {metric_class5:.4f}"
                f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            )

            # 写入日志文件
            with open('ct_training_log.txt', 'a+') as f:
                f.write(
                    f"epoch: {epoch + 1} time: {formatted_time}\n"
                    f"mean dice: {metric:.4f}\n"
                    f"class1: {metric_class1:.4f} class2: {metric_class2:.4f} "
                    f"class3: {metric_class3:.4f} class4: {metric_class4:.4f} class5: {metric_class5:.4f}\n"
                    f"best mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}\n\n"
                )

    print(f"epoch {epoch + 1} time: {(time.time() - epoch_start):.4f}")

total_time = time.time() - total_start
print(f"Training completed! Best metric: {best_metric:.4f} at epoch {best_metric_epoch}, total time: {total_time:.4f}s")

# 保存最终结果
with open('ct_training_final.txt', 'w') as f:
    f.write(f"Training completed! Best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}\n")
    f.write(f"Total training time: {total_time:.4f} seconds\n")
    f.write(f"Class 1 best dice: {max(metric_values_class1):.4f}\n")
    f.write(f"Class 2 best dice: {max(metric_values_class2):.4f}\n")
    f.write(f"Class 3 best dice: {max(metric_values_class3):.4f}\n")
    f.write(f"Class 4 best dice: {max(metric_values_class4):.4f}\n")
    f.write(f"Class 5 best dice: {max(metric_values_class5):.4f}\n")

----------
epoch 1/50
epoch 1 average loss: 0.9782


  with torch.cuda.amp.autocast():
  out[idx_zm] += p


Saved new best metric model
current epoch: 1 current mean dice: 0.1094
class1: 0.0640 class2: 0.1548 class3: 0.0000 class4: 0.0000 class5: 0.0000
best mean dice: 0.1094 at epoch: 1
epoch 1 time: 6.1932
----------
epoch 2/50
epoch 2 average loss: 0.9553
Saved new best metric model
current epoch: 2 current mean dice: 0.1650
class1: 0.0456 class2: 0.2845 class3: 0.0000 class4: 0.0000 class5: 0.0000
best mean dice: 0.1650 at epoch: 2
epoch 2 time: 6.1523
----------
epoch 3/50
epoch 3 average loss: 0.9339
Saved new best metric model
current epoch: 3 current mean dice: 0.2875
class1: 0.0455 class2: 0.5295 class3: 0.0000 class4: 0.0000 class5: 0.0000
best mean dice: 0.2875 at epoch: 3
epoch 3 time: 6.1318
----------
epoch 4/50
epoch 4 average loss: 0.9157
Saved new best metric model
current epoch: 4 current mean dice: 0.3407
class1: 0.0445 class2: 0.6369 class3: 0.0000 class4: 0.0000 class5: 0.0000
best mean dice: 0.3407 at epoch: 4
epoch 4 time: 6.1392
----------
epoch 5/50
epoch 5 average l