In [2]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import albumentations as A
import matplotlib.pyplot as plt

from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader, random_split

# ====================== 数据增强 ======================
# 在训练集中，我们使用 Albumentations 做较强的增广
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.Rotate(limit=15, p=0.5),
    A.Perspective(scale=(0.05, 0.1), p=0.3),
    A.RandomBrightnessContrast(p=0.5),
    A.HueSaturationValue(p=0.5),
    A.CoarseDropout(max_holes=5, max_height=16, max_width=16,
                    fill_value=0, mask_fill_value=0, p=0.5),
    A.Resize(224, 224),           # 统一缩放到224×224
    ToTensorV2()
], additional_targets={'mask': 'mask'})

# 验证/测试集中，我们通常只做最基本的resize + ToTensor
val_transform = A.Compose([
    A.Resize(224, 224),
    ToTensorV2()
], additional_targets={'mask': 'mask'})

# ====================== 自定义Dataset ======================
class TextSegDataset(Dataset):
    """
    读取 (image, mask) 对，并应用数据增强
    假设:
    - 原图目录: img_dir
    - 掩码目录: mask_dir
    - 每张原图对应一个同名 + '_mask.png' 的掩码文件
    """
    def __init__(self, img_dir, mask_dir, transform=None):
        super().__init__()
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        
        self.img_files = sorted([f for f in os.listdir(img_dir)
                                 if f.lower().endswith(('.png', '.jpg', '.jpeg'))])

    def __len__(self):
        return len(self.img_files)
    
    def __getitem__(self, idx):
        img_name = self.img_files[idx]
        base_name = img_name.rsplit('.', 1)[0]
        
        img_path  = os.path.join(self.img_dir,  img_name)
        mask_path = os.path.join(self.mask_dir, base_name + "_mask.png")
        
        # 读取图像(BGR->RGB)与掩码(灰度)
        img_bgr  = cv2.imread(img_path)
        mask_gray= cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        if img_bgr is None or mask_gray is None:
            raise FileNotFoundError(f"Cannot read image/mask: {img_path}, {mask_path}")
        
        # 转成 float32
        img_bgr = img_bgr.astype(np.float32)
        mask_gray = mask_gray.astype(np.float32)


        img_rgb  = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        
        # Albumentations需要 numpy array
        # mask必须是单通道
        if self.transform:
            # 传入transform时要指定 'image' 和 'mask'
            augmented = self.transform(image=img_rgb, mask=mask_gray)
            img_t  = augmented['image']   # tensor: [3,224,224]
            mask_t = augmented['mask']    # tensor: [224,224]
        else:
            # 如果不做增强，就手动ToTensor
            img_rgb  = cv2.resize(img_rgb, (224,224))
            mask_gray= cv2.resize(mask_gray, (224,224), interpolation=cv2.INTER_NEAREST)
            img_t  = torch.from_numpy(img_rgb.transpose(2,0,1)).float() / 255.0
            mask_t = torch.from_numpy(mask_gray).float()
        
        # 如果掩码是0/255，转成0/1
        mask_t = (mask_t > 127.5).float().unsqueeze(0)  # [1,224,224]
        
        return img_t, mask_t

# ====================== 定义一个简化的ViT模型 ======================
class SmallViT(nn.Module):
    def __init__(self, image_size=224, patch_size=16,
                 embed_dim=128, num_heads=4, depth=4, num_classes=1):
        super().__init__()
        self.patch_size = patch_size
        # 1) 补丁嵌入
        self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
        num_patches = (image_size // patch_size) ** 2
        # 可学习位置编码
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        
        # 2) Transformer编码器
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim*4,
            dropout=0.1,
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        
        # 3) 解码头(上采样回原图尺寸)
        self.up = nn.ConvTranspose2d(embed_dim, num_classes,
                                     kernel_size=patch_size, stride=patch_size)
        
    def forward(self, x):
        B, C, H, W = x.shape
        # Patch Embedding
        x = self.patch_embed(x)  # [B,embed_dim,H/patch,W/patch]
        # flatten => [B,embed_dim, N], transpose => [B,N,embed_dim]
        x = x.flatten(2).transpose(1,2)
        
        # 加上位置编码
        x = x + self.pos_embed[:, :x.size(1), :]
        
        # Transformer
        x = self.transformer(x)  # [B,N,embed_dim]
        
        # reshape回 CNN 形式
        Hp = H // self.patch_size
        Wp = W // self.patch_size
        x = x.transpose(1,2).reshape(B, -1, Hp, Wp)  # [B,embed_dim,Hp,Wp]
        
        # 上采样回原图大小 => [B,num_classes,H,W]
        x = self.up(x)
        return x

# ====================== Dice Loss ======================
def dice_loss_fn(logits, targets):
    """
    logits: [B,1,H,W] raw output
    targets: [B,1,H,W] in {0,1}
    返回: 平均Dice损失 (越小越好)
    """
    probs = torch.sigmoid(logits)  # [B,1,H,W]
    num = (probs * targets).sum(dim=(1,2,3)) * 2.0
    den = (probs + targets).sum(dim=(1,2,3)) + 1e-6
    dice = num / den
    return 1 - dice.mean()

# ====================== 训练示例 ======================
def train_vit_seg():
    # 1) 数据准备
    img_dir = "data/annotation_images"
    mask_dir= "data/mask"
    
    full_dataset = TextSegDataset(img_dir, mask_dir, transform=train_transform)
    total_len = len(full_dataset)
    print(f"总共有 {total_len} 张图像+掩码.")
    
    # 若 total_len=18, 可以 12/3/3 划分
    train_size = 12
    val_size   = 3
    test_size  = total_len - train_size - val_size
    train_ds, val_ds, test_ds = random_split(
        full_dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )
    
    # 验证/测试集只做最基本的变换
    val_dataset = TextSegDataset(img_dir, mask_dir, transform=val_transform)
    # 这里需要注意：random_split只是拆分了索引，你得自己写法把train_ds,val_ds映射到新的transform
    
    # 简单做法：对train_ds用train_transform，对val_ds/test_ds用val_transform
    # 我们可以重写Dataset，或在Dataset里判断下标是否在train还是val
    # 为了演示，这里先只在train上做Albumentations(你也可以做更精细的写法)
    
    # 构造DataLoader
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2)
    val_loader   = DataLoader(val_ds,   batch_size=2, shuffle=False, num_workers=2)
    test_loader  = DataLoader(test_ds,  batch_size=2, shuffle=False, num_workers=2)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # 2) 模型
    model = SmallViT(
        image_size=224, patch_size=16,
        embed_dim=128, num_heads=4, depth=4, num_classes=1
    ).to(device)
    
    # 3) 损失函数 & 优化器
    bce_loss = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    
    # 4) 训练循环
    epochs = 30
    for epoch in range(1, epochs+1):
        model.train()
        total_train_loss = 0
        for imgs, masks in train_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            optimizer.zero_grad()
            logits = model(imgs)  # [B,1,H,W]
            loss = bce_loss(logits, masks) + dice_loss_fn(logits, masks)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
        avg_train_loss = total_train_loss / len(train_loader)
        
        # 验证
        model.eval()
        total_val_iou = 0
        total_val_dice= 0
        count = 0
        with torch.no_grad():
            for imgs, masks in val_loader:
                imgs, masks = imgs.to(device), masks.to(device)
                logits = model(imgs)
                probs  = torch.sigmoid(logits)
                # IoU
                pred_bin = (probs > 0.5).float()
                intersect = (pred_bin * masks).sum(dim=(1,2,3))
                union = ((pred_bin + masks) > 0).float().sum(dim=(1,2,3)) + 1e-6
                iou = (intersect / union).mean().item()
                # Dice
                dice_num = 2*intersect
                dice_den = pred_bin.sum(dim=(1,2,3)) + masks.sum(dim=(1,2,3)) + 1e-6
                dice = (dice_num/dice_den).mean().item()
                
                total_val_iou += iou
                total_val_dice+= dice
                count += 1
        
        avg_val_iou = total_val_iou / count
        avg_val_dice= total_val_dice/ count
        
        print(f"Epoch [{epoch}/{epochs}] - "
              f"Train Loss: {avg_train_loss:.4f} | "
              f"Val IoU: {avg_val_iou:.4f}, Val Dice: {avg_val_dice:.4f}")
    
    # 5) 测试
    model.eval()
    total_test_iou = 0
    total_test_dice= 0
    count = 0
    with torch.no_grad():
        for imgs, masks in test_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            logits = model(imgs)
            probs  = torch.sigmoid(logits)
            pred_bin = (probs > 0.5).float()
            
            intersect = (pred_bin * masks).sum(dim=(1,2,3))
            union = ((pred_bin + masks) > 0).float().sum(dim=(1,2,3)) + 1e-6
            iou = (intersect/union).mean().item()
            
            dice_num = 2*intersect
            dice_den = pred_bin.sum(dim=(1,2,3)) + masks.sum(dim=(1,2,3)) + 1e-6
            dice = (dice_num/dice_den).mean().item()
            
            total_test_iou  += iou
            total_test_dice += dice
            count += 1
    print(f"\n[测试结果] IoU={total_test_iou/count:.4f}, Dice={total_test_dice/count:.4f}")

if __name__ == "__main__":
    train_vit_seg()


总共有 18 张图像+掩码.


  A.CoarseDropout(max_holes=5, max_height=16, max_width=16,


Epoch [1/30] - Train Loss: 1.5883 | Val IoU: 0.0000, Val Dice: 0.0000
Epoch [2/30] - Train Loss: 1.3574 | Val IoU: 0.0000, Val Dice: 0.0000
Epoch [3/30] - Train Loss: 1.2079 | Val IoU: 0.0000, Val Dice: 0.0000
Epoch [4/30] - Train Loss: 1.1200 | Val IoU: 0.0000, Val Dice: 0.0000
Epoch [5/30] - Train Loss: 1.0723 | Val IoU: 0.0000, Val Dice: 0.0000
Epoch [6/30] - Train Loss: 1.0469 | Val IoU: 0.0000, Val Dice: 0.0000
Epoch [7/30] - Train Loss: 1.0329 | Val IoU: 0.0000, Val Dice: 0.0000
Epoch [8/30] - Train Loss: 1.0247 | Val IoU: 0.0000, Val Dice: 0.0000
Epoch [9/30] - Train Loss: 1.0195 | Val IoU: 0.0000, Val Dice: 0.0000
Epoch [10/30] - Train Loss: 1.0161 | Val IoU: 0.0000, Val Dice: 0.0000
Epoch [11/30] - Train Loss: 1.0136 | Val IoU: 0.0000, Val Dice: 0.0000
Epoch [12/30] - Train Loss: 1.0118 | Val IoU: 0.0000, Val Dice: 0.0000
Epoch [13/30] - Train Loss: 1.0104 | Val IoU: 0.0000, Val Dice: 0.0000
Epoch [14/30] - Train Loss: 1.0093 | Val IoU: 0.0000, Val Dice: 0.0000
Epoch [15/30] -

In [None]:
def visualize_single(model, dataset, idx=0, threshold=0.5):
    model.eval()
    with torch.no_grad():
        img_t, mask_t = dataset[idx]  # 一个 (image, mask)
        img_t = img_t.unsqueeze(0).cuda()  # 扩展batch维度
        logits = model(img_t)             # [1,1,H,W]
        probs  = torch.sigmoid(logits)[0] # [1,H,W]
        
        pred_bin = (probs > threshold).float().cpu().numpy()[0]
        img_np   = img_t[0].cpu().numpy().transpose(1,2,0)  # [H,W,3]
        mask_np  = mask_t[0].cpu().numpy()
        
        plt.figure(figsize=(12,4))
        plt.subplot(1,3,1)
        plt.imshow(img_np)
        plt.title("Original")
        plt.axis('off')
        
        plt.subplot(1,3,2)
        plt.imshow(mask_np, cmap='gray')
        plt.title("Ground Truth")
        plt.axis('off')
        
        plt.subplot(1,3,3)
        plt.imshow(pred_bin, cmap='gray')
        plt.title("Predicted")
        plt.axis('off')
        plt.show()

# 假设 val_ds[0] 是一张图
visualize_single(model, val_ds, idx=0, threshold=0.5)
