In [None]:
import matplotlib.pyplot as plt
import cv2
# Preview a single PNG from disk as a grayscale image.
preview_path = 'VOCDatasets/SegmentationClassPNG/CHNCXR_0019_0.png'
img = cv2.imread(preview_path, cv2.IMREAD_GRAYSCALE)
# cv2.imread reports failure by returning None instead of raising; without
# this check the problem only shows up as an obscure error inside plt.imshow.
if img is None:
    raise FileNotFoundError(f"Could not read image: {preview_path}")
plt.figure(figsize=(10,10))
plt.imshow(img, cmap='gray')
plt.axis('off')
plt.title('Image')
plt.show()

数据集下载（通过百度网盘分享的文件）：LungDataset.zip
链接: https://pan.baidu.com/s/18Vy_cH0DfiXhjJLMC-YrJw?pwd=atqj 提取码: atqj
下载后请解压到本笔记本所在目录，确保存在 `LungDataset/images` 和 `LungDataset/masks` 两个文件夹（与下文 Config 中的路径一致）。
--来自百度网盘超级会员v5的分享

In [None]:
## 导入相关库
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from PIL import Image


In [None]:
## Seed every random number generator used in this notebook
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may select a different convolution algorithm on each
    # run, so it has to be switched off for the seed to be meaningful.
    torch.backends.cudnn.benchmark = False

In [None]:
# Configuration parameters
class Config:
    """Single place for dataset paths, hyperparameters and output folders."""

    # --- Dataset locations ---
    DATA_ROOT = Path('LungDataset')  # change to the processed dataset path
    IMAGES_DIR = DATA_ROOT / 'images'
    MASKS_DIR = DATA_ROOT / 'masks'

    # --- Training hyperparameters ---
    BATCH_SIZE = 4
    LEARNING_RATE = 1e-4
    NUM_EPOCHS = 50
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    IMG_SIZE = 512  # images are resized to 512x512
    SEED = 42
    VAL_SPLIT = 0.2  # 20% of the data is used for validation

    # --- Model hyperparameters ---
    IN_CHANNELS = 1  # grayscale input
    OUT_CHANNELS = 1  # binary segmentation
    FEATURES = [64, 128, 256, 512]  # UNet feature channel counts

    # --- Output locations ---
    CHECKPOINTS_DIR = Path('checkpoints')
    LOGS_DIR = Path('logs')
    RESULTS_DIR = Path('results')

    def create_directories(self):
        """Create the checkpoint, log and result directories if they are missing."""
        for directory in (self.CHECKPOINTS_DIR, self.LOGS_DIR, self.RESULTS_DIR):
            os.makedirs(directory, exist_ok=True)

# Build the shared configuration object, make sure the output folders exist,
# and seed the random number generators with the configured seed.
cfg = Config()
cfg.create_directories()
set_seed(cfg.SEED)

In [None]:
# Double-convolution building block of the U-Net
class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> batch norm -> ReLU) stages."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first stage maps in_channels -> out_channels; the second keeps
        # the channel count unchanged.
        for stage_in in (in_channels, out_channels):
            stages.append(
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False)
            )
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both convolution stages and return the result."""
        return self.conv(x)
class UNet(nn.Module):
    """U-Net style encoder/decoder built from ``DoubleConv`` blocks.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output map.
        features: Channel widths of the encoder levels, shallowest first.
            Defaults to ``[64, 128, 256, 512]``.

    The forward pass returns sigmoid-activated values, i.e. numbers in
    ``[0, 1]``, with ``out_channels`` channels.
    """

    def __init__(self, in_channels=1, out_channels=1, features=None):
        super(UNet, self).__init__()
        if features is None:
            features = [64, 128, 256, 512]

        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per level
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Decoder: for every level, an upsampling transposed convolution
        # followed by a DoubleConv applied to the concatenated skip + upsampled map
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2)
            )
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the encoder, the bottleneck and the decoder on ``x``.

        ``x`` is indexed as a 4-D tensor (``shape[2:]`` is treated as the
        spatial size) — presumably ``(batch, channels, height, width)``.
        """
        skip_connections = []

        # Encoder path: remember the pre-pooling activation of every level
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Decoder path: consume the skips deepest-first
        for level, skip_connection in enumerate(reversed(skip_connections)):
            x = self.ups[2 * level](x)  # transposed convolution

            # Pooling floors odd sizes, so the upsampled map can be smaller
            # than the skip. Only the spatial dims matter for this check, and
            # the resize uses the tensor API rather than an image transform.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = nn.functional.interpolate(
                    x, size=skip_connection.shape[2:],
                    mode="bilinear", align_corners=False,
                )

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[2 * level + 1](concat_skip)  # double convolution

        return self.sigmoid(self.final_conv(x))

In [None]:
# Custom augmentation transforms that act on an (image, mask) pair
class RandomFlip:
    """Randomly mirror an image and its mask horizontally and/or vertically."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, mask):
        """Apply each flip independently with probability ``p`` to both inputs."""
        mirror_horizontally = random.random() < self.p
        if mirror_horizontally:
            img, mask = TF.hflip(img), TF.hflip(mask)

        mirror_vertically = random.random() < self.p
        if mirror_vertically:
            img, mask = TF.vflip(img), TF.vflip(mask)

        return img, mask

class RandomRotation:
    """With probability ``p``, rotate image and mask by an angle drawn from ``angles``."""

    def __init__(self, p=0.5):
        self.p = p
        # Candidate angles in degrees; 0 is included, so a "rotation" may be a no-op.
        self.angles = list(range(0, 360, 90))

    def __call__(self, img, mask):
        """Return the pair, rotated by one shared angle when the draw succeeds."""
        if not random.random() < self.p:
            return img, mask
        angle = random.choice(self.angles)
        return TF.rotate(img, angle), TF.rotate(mask, angle)

class RandomBrightnessContrast:
    """Randomly jitter brightness and contrast of the image; the mask is untouched."""

    def __init__(self, p=0.2, brightness=0.2, contrast=0.2):
        self.p, self.brightness, self.contrast = p, brightness, contrast

    def __call__(self, img, mask):
        """Apply each adjustment independently with probability ``p``."""
        if random.random() < self.p:
            # Factor is drawn uniformly from [1 - brightness, 1 + brightness]
            low, high = 1-self.brightness, 1+self.brightness
            img = TF.adjust_brightness(img, random.uniform(low, high))

        if random.random() < self.p:
            # Factor is drawn uniformly from [1 - contrast, 1 + contrast]
            low, high = 1-self.contrast, 1+self.contrast
            img = TF.adjust_contrast(img, random.uniform(low, high))

        return img, mask

In [None]:
class LungSegmentationDataset(Dataset):
    """Dataset yielding (image, mask) tensor pairs for binary segmentation.

    Every image is paired with the file of the same name inside
    ``cfg.MASKS_DIR``. Both are loaded as single-channel images and resized
    to ``cfg.IMG_SIZE`` x ``cfg.IMG_SIZE``.

    Args:
        images_paths: Sequence of image paths (``pathlib.Path`` or ``str``).
        phase: ``"train"`` enables the random augmentations; any other value
            disables them.
    """

    def __init__(self, images_paths, phase="train"):
        self.images_paths = images_paths
        self.phase = phase

        target_size = (cfg.IMG_SIZE, cfg.IMG_SIZE)
        # Masks are label maps: nearest-neighbour resampling never invents
        # intermediate gray values.
        self.resize = transforms.Resize(
            target_size, interpolation=transforms.InterpolationMode.NEAREST)
        # Intensity images are resampled bilinearly; nearest-neighbour would
        # alias fine structures when the image is shrunk.
        self.resize_image = transforms.Resize(
            target_size, interpolation=transforms.InterpolationMode.BILINEAR)
        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize((0.5,), (0.5,))

        # Augmentations are applied only during training
        if phase == "train":
            self.augmentations = [
                RandomFlip(p=0.5),
                RandomRotation(p=0.5),
                RandomBrightnessContrast(p=0.2)
            ]
        else:
            self.augmentations = []

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.images_paths)

    def __getitem__(self, idx):
        """Load, resize, augment and tensorize the pair at position ``idx``.

        Returns:
            ``(image, mask)``: ``image`` is the normalized single-channel
            tensor produced by ``ToTensor`` + ``Normalize((0.5,), (0.5,))``;
            ``mask`` is a float tensor of 0/1 values with a leading channel
            dimension.
        """
        # Accept plain strings as well as Path objects
        img_path = Path(self.images_paths[idx])

        # The mask shares the image's file name but lives in the masks folder
        mask_path = cfg.MASKS_DIR / img_path.name

        # convert() returns an independent in-memory copy, so the underlying
        # file handles can be closed right away instead of waiting for GC.
        with Image.open(str(img_path)) as image_file:
            image = image_file.convert("L")
        with Image.open(str(mask_path)) as mask_file:
            mask = mask_file.convert("L")

        image = self.resize_image(image)
        mask = self.resize(mask)

        # Augmentations receive both so geometric changes stay aligned
        for aug in self.augmentations:
            image, mask = aug(image, mask)

        image = self.normalize(self.to_tensor(image))

        # Binarize the mask (values above 127 are foreground) and add a
        # channel dimension
        mask = torch.from_numpy(np.array(mask))
        mask = (mask > 127).float().unsqueeze(0)

        return image, mask

In [None]:
def prepare_dataloaders():
    """Split the PNG images under ``cfg.IMAGES_DIR`` and wrap them in loaders.

    Returns:
        ``(train_loader, val_loader)``: the training loader shuffles and
        uses the ``"train"`` phase; the validation loader does neither.

    Raises:
        ValueError: If no ``*.png`` file is found in ``cfg.IMAGES_DIR``.
    """
    # Sorted so the seeded split does not depend on directory listing order
    image_files = sorted(cfg.IMAGES_DIR.glob("*.png"))
    if not image_files:
        raise ValueError(f"No .png images found in {cfg.IMAGES_DIR}")

    train_files, val_files = train_test_split(
        image_files, test_size=cfg.VAL_SPLIT, random_state=cfg.SEED
    )

    train_dataset = LungSegmentationDataset(train_files, phase="train")
    val_dataset = LungSegmentationDataset(val_files, phase="val")

    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # CPU-only run just produces a warning.
    loader_kwargs = dict(
        batch_size=cfg.BATCH_SIZE,
        num_workers=0,
        pin_memory=str(cfg.DEVICE).startswith('cuda'),
    )

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader

In [None]:
# Train for one epoch
def train_epoch(model, loader, optimizer, loss_fn, device):
    """Run one optimization pass over ``loader``.

    Returns:
        The sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    running_loss = 0

    for images, masks in tqdm(loader, desc="训练中"):
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        batch_loss = loss_fn(model(images), masks)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(loader)

In [None]:

# Validate for one epoch
def val_epoch(model, loader, loss_fn, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns:
        ``(mean_loss, mean_dice)``: the per-batch loss averaged over the
        batches, and the mean of the per-batch Dice coefficients computed on
        predictions thresholded at 0.5.
    """
    model.eval()
    total_loss = 0
    dice_scores = []

    with torch.no_grad():
        for images, masks in tqdm(loader, desc="验证中"):
            images, masks = images.to(device), masks.to(device)

            predictions = model(images)
            total_loss += loss_fn(predictions, masks).item()

            # Dice coefficient of the whole batch on the binarized output
            binary = (predictions > 0.5).float()
            overlap = (binary * masks).sum()
            denominator = binary.sum() + masks.sum() + 1e-8
            dice_scores.append(((2.0 * overlap) / denominator).item())

    mean_loss = total_loss / len(loader)
    return mean_loss, np.mean(dice_scores)

In [None]:
# Dice loss
class DiceLoss(nn.Module):
    """Soft Dice loss: ``1 - (2*|X∩Y| + smooth) / (|X| + |Y| + smooth)``.

    Args:
        smooth: Constant added to numerator and denominator; it keeps the
            ratio defined when both inputs sum to zero.
    """

    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, predictions, targets):
        """Return the Dice loss computed over all elements of both tensors."""
        # reshape() flattens non-contiguous tensors (e.g. transposed or
        # sliced views) too, where view() raises a RuntimeError.
        predictions = predictions.reshape(-1)
        targets = targets.reshape(-1)

        # Soft intersection
        intersection = (predictions * targets).sum()

        dice_score = (2.0 * intersection + self.smooth) / (
            predictions.sum() + targets.sum() + self.smooth
        )

        return 1.0 - dice_score

# Combined loss: BCE + Dice
class BCEDiceLoss(nn.Module):
    """Weighted sum of binary cross-entropy and Dice loss.

    Args:
        bce_weight: Multiplier of the BCE term.
        dice_weight: Multiplier of the Dice term.
    """

    def __init__(self, bce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.bce = nn.BCELoss()
        self.dice = DiceLoss()
        self.bce_weight, self.dice_weight = bce_weight, dice_weight

    def forward(self, predictions, targets):
        """Return ``bce_weight * BCE + dice_weight * Dice`` for the given tensors."""
        weighted_bce = self.bce_weight * self.bce(predictions, targets)
        weighted_dice = self.dice_weight * self.dice(predictions, targets)
        return weighted_bce + weighted_dice

In [None]:
# Save model weights (used for both the best and the last model)
def save_model(model, filename):
    """Write ``model.state_dict()`` to ``cfg.CHECKPOINTS_DIR / filename`` and report the path."""
    destination = cfg.CHECKPOINTS_DIR / filename
    weights = model.state_dict()
    torch.save(weights, destination)
    print(f"模型已保存至 {destination}")

# Visualize prediction results
def visualize_predictions(model, loader, device, num_samples=4, filename='predictions.png'):
    """Save a grid comparing input, ground-truth mask and predicted mask.

    The first sample of each of the first ``num_samples`` batches is drawn
    as one row (image | true mask | prediction thresholded at 0.5).

    Args:
        model: Network to evaluate; it is switched to eval mode.
        loader: Iterable yielding ``(image, mask)`` batches.
        device: Device the images are moved to before inference.
        num_samples: Maximum number of rows to draw.
        filename: Name of the output file inside ``cfg.RESULTS_DIR``.
    """
    model.eval()
    rows = []

    with torch.no_grad():
        for batch_idx, (image, mask) in enumerate(loader):
            if batch_idx >= num_samples:
                break

            pred = model(image.to(device))

            # Keep only the first sample / first channel of each batch
            rows.append((
                image.cpu().numpy()[0, 0],
                mask.cpu().numpy()[0, 0],
                (pred > 0.5).float().cpu().numpy()[0, 0],
            ))

    # Nothing collected (empty loader or num_samples <= 0): a zero-row
    # figure cannot be laid out, so there is nothing to save.
    if not rows:
        return

    # Size the grid by what was actually collected, not by what was requested
    n_rows = len(rows)
    fig, axes = plt.subplots(n_rows, 3, figsize=(12, 4 * n_rows), squeeze=False)
    titles = ('原始图像', '真实掩码', '预测掩码')
    for row_axes, panels in zip(axes, rows):
        for ax, title, panel in zip(row_axes, titles, panels):
            ax.imshow(panel, cmap='gray')
            ax.set_title(title)
            ax.axis('off')

    fig.tight_layout()
    fig.savefig(cfg.RESULTS_DIR / filename)
    plt.close(fig)

In [None]:
# Main training function
def train_model():
    """Train a UNet with the settings in ``cfg`` and return the trained model.

    Each epoch runs training and validation, steps the plateau scheduler on
    the validation loss, logs scalars to TensorBoard, saves
    ``best_model.pth`` whenever the validation loss improves and renders
    sample predictions every 10 epochs. ``last_model.pth`` is saved once the
    loop finishes.
    """
    train_loader, val_loader = prepare_dataloaders()

    model = UNet(
        in_channels=cfg.IN_CHANNELS,
        out_channels=cfg.OUT_CHANNELS,
        features=cfg.FEATURES
    )
    model = model.to(cfg.DEVICE)

    optimizer = optim.Adam(model.parameters(), lr=cfg.LEARNING_RATE)
    loss_fn = BCEDiceLoss()

    # The scheduler's `verbose` flag is deprecated in recent PyTorch releases
    # (and no longer accepted by the newest ones), so learning-rate changes
    # are reported manually in the loop below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.1, patience=5
    )

    writer = SummaryWriter(log_dir=cfg.LOGS_DIR)

    best_val_loss = float('inf')
    # try/finally so the event file is flushed and closed even when training
    # is interrupted or fails part-way through.
    try:
        for epoch in range(cfg.NUM_EPOCHS):
            print(f"Epoch {epoch+1}/{cfg.NUM_EPOCHS}")

            train_loss = train_epoch(model, train_loader, optimizer, loss_fn, cfg.DEVICE)
            val_loss, val_dice = val_epoch(model, val_loader, loss_fn, cfg.DEVICE)

            # Step the scheduler and report any learning-rate reduction
            lr_before = optimizer.param_groups[0]['lr']
            scheduler.step(val_loss)
            lr_after = optimizer.param_groups[0]['lr']
            if lr_after != lr_before:
                print(f"学习率调整: {lr_before:.2e} -> {lr_after:.2e}")

            writer.add_scalar("Loss/train", train_loss, epoch)
            writer.add_scalar("Loss/val", val_loss, epoch)
            writer.add_scalar("Dice/val", val_dice, epoch)

            print(f"训练损失: {train_loss:.4f}, 验证损失: {val_loss:.4f}, 验证Dice: {val_dice:.4f}")

            # Keep the weights with the lowest validation loss seen so far
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_model(model, "best_model.pth")

            # Render sample predictions every 10 epochs
            if (epoch + 1) % 10 == 0:
                visualize_predictions(model, val_loader, cfg.DEVICE)

        save_model(model, "last_model.pth")
    finally:
        writer.close()

    return model

# Evaluation function
def test_model(model_path):
    """Load a saved state dict, evaluate it on the validation split and plot samples.

    Args:
        model_path: Path of a checkpoint holding a UNet ``state_dict``.
    """
    model = UNet(
        in_channels=cfg.IN_CHANNELS,
        out_channels=cfg.OUT_CHANNELS,
        features=cfg.FEATURES
    )
    # map_location lets a checkpoint saved on a GPU load on a CPU-only
    # machine; without it torch.load tries to restore onto the original device.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    state_dict = torch.load(model_path, map_location=cfg.DEVICE)
    model.load_state_dict(state_dict)
    model = model.to(cfg.DEVICE)

    # Only the validation loader is needed here
    _, val_loader = prepare_dataloaders()

    loss_fn = BCEDiceLoss()
    val_loss, val_dice = val_epoch(model, val_loader, loss_fn, cfg.DEVICE)
    print(f"测试损失: {val_loss:.4f}, 测试Dice: {val_dice:.4f}")

    visualize_predictions(model, val_loader, cfg.DEVICE, num_samples=8)

In [None]:
def main():
    """Check that the dataset folders exist, then train and evaluate the best checkpoint."""
    # Bail out early when any of the expected dataset folders is missing
    required_dirs = (cfg.DATA_ROOT, cfg.IMAGES_DIR, cfg.MASKS_DIR)
    if not all(folder.exists() for folder in required_dirs):
        print(f"警告: 处理后的数据集不存在，请先运行 preprocess_dataset.py 脚本处理数据集!")
        return

    image_count = sum(1 for _ in cfg.IMAGES_DIR.glob('*.png'))
    print(f"使用设备: {cfg.DEVICE}")
    print(f"数据集路径: {cfg.DATA_ROOT}")
    print(f"图像数量: {image_count}")
    print("开始训练...")

    # Training
    trained_model = train_model()

    # Evaluation of the best checkpoint written during training
    print("\n评估最佳模型...")
    test_model(cfg.CHECKPOINTS_DIR / "best_model.pth")

    print("完成!")

In [None]:
# Run the pipeline: dataset check, training, then evaluation of the best checkpoint.
main()