# 跑冒滴漏异常检测（条件VAE-UNet）

- 数据来自课程要求的 DATASET2（train 2836 + test 120），标签使用 YOLO 格式，仅保留前三类（Oil_accumulation / Oil_seepage / Standing_water）。
- 参考 PPT 提示，本文实现了一个条件 VAE 版本的 UNet，用生成式建模的方式学习异常区域，再通过后处理还原成提交所需的检测框。

## 流程概览
1. 解析 train 标签并构建热力图 mask；划分 train/val。
2. 构建条件 VAE-UNet（image encoder + latent branch + decoder）并以 mask 重建损失训练。
3. 使用验证集监控 IoU，保存最优 checkpoint。
4. 推理阶段得到每类概率图，经阈值化、轮廓提取与 NMS 转为 YOLO 坐标，生成 120 个结果文件。

In [59]:
import json
import math
import os
import random
import shutil
from pathlib import Path

import cv2
import matplotlib; matplotlib.use('Agg');
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

# Plot styling for every figure produced below.
plt.style.use('seaborn-v0_8')

# Dataset layout: images/{train,test} and labels/train live under DATA_ROOT;
# checkpoints and prediction files are written to sub-folders of DATA_ROOT.
DATA_ROOT = Path('d:/BaiduNetdiskDownload/dataset_work_final')
TRAIN_IMG_DIR = DATA_ROOT / 'images' / 'train'
TEST_IMG_DIR = DATA_ROOT / 'images' / 'test'
LABEL_DIR = DATA_ROOT / 'labels' / 'train'
CHECKPOINT_DIR = DATA_ROOT / 'artifacts'
RESULT_DIR = DATA_ROOT / 'results_cvae'

# Create the output folders up front so later cells can write into them.
for path in [CHECKPOINT_DIR, RESULT_DIR]:
    path.mkdir(parents=True, exist_ok=True)

# Only class ids below NUM_CLASSES are kept when labels are parsed.
CLASS_NAMES = ['Oil_accumulation', 'Oil_seepage', 'Standing_water']
NUM_CLASSES = len(CLASS_NAMES)
IMAGE_SIZE = (288, 512)  # (H, W), divisible by 16 for UNet
SEED = 7
FAST_DEBUG = False  # quick smoke test; set True for mini-run
# Sample caps applied to the train/val id lists when FAST_DEBUG is on.
FAST_TRAIN_LIMIT = 256
FAST_VAL_LIMIT = 64
# Share of the training images held out for validation.
VAL_FRACTION = 0.1
BASE_BATCH_SIZE = 2
BASE_EPOCHS = 5
# FAST_DEBUG shrinks the run to batch size 1 and a single epoch.
BATCH_SIZE = 1 if FAST_DEBUG else BASE_BATCH_SIZE
EPOCHS = 1 if FAST_DEBUG else BASE_EPOCHS
LEARNING_RATE = 3e-4
# Size of the latent vector z and the weight of the KL term in the training loss.
LATENT_DIM = 64
KL_WEIGHT = 1e-4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
IS_WINDOWS = os.name == 'nt'
# At most 4 DataLoader workers, and no more than half of the reported CPUs.
MAX_WORKERS = max(0, min(4, (os.cpu_count() or 0) // 2))
# Data loading stays in the main process on Windows and in debug runs.
NUM_WORKERS = 0 if FAST_DEBUG or IS_WINDOWS else MAX_WORKERS
# Only handed to the DataLoader when NUM_WORKERS > 0.
PREFETCH_FACTOR = 2
PIN_MEMORY = DEVICE == 'cuda'

def seed_everything(seed: int = 7):
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available, all CUDA devices are seeded too and cuDNN
    benchmark mode is switched on.

    Args:
        seed: Value passed to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    # NOTE(review): benchmark mode favours speed; cuDNN may pick different
    # kernels between runs, so GPU results are presumably not bit-reproducible.
    torch.backends.cudnn.benchmark = True

# Seed every RNG once, then report the device and how many .jpg files each
# image folder contains.
seed_everything(SEED)
print(f'device: {DEVICE}')
print(f'train images: {len(list(TRAIN_IMG_DIR.glob("*.jpg")))}, test images: {len(list(TEST_IMG_DIR.glob("*.jpg")))}')


device: cuda
train images: 2836, test images: 120


In [60]:
def load_yolo_boxes(label_path: Path):
    """Read YOLO-format boxes from a label file.

    Each usable line holds five whitespace-separated fields,
    ``class cx cy w h``.  Lines with a different field count are skipped, and
    so are boxes whose class id lies outside ``[0, NUM_CLASSES)``.

    Args:
        label_path: Path of the ``.txt`` label file.

    Returns:
        A list of ``(cls, cx, cy, w, h)`` tuples in file order; an empty list
        when the file does not exist.
    """
    boxes = []
    if not label_path.exists():
        return boxes
    with label_path.open() as f:
        for line in f:
            parts = line.split()
            if len(parts) != 5:
                continue
            # The class field may be written as "1" or "1.0".
            cls = int(float(parts[0]))
            # Reject negative ids as well: used as a channel index they would
            # wrap around and mark the wrong class.
            if not 0 <= cls < NUM_CLASSES:
                continue
            cx, cy, w, h = map(float, parts[1:])
            boxes.append((cls, cx, cy, w, h))
    return boxes


def boxes_to_mask(boxes, out_h, out_w, num_classes=NUM_CLASSES):
    """Rasterise normalised YOLO boxes into a per-class binary mask.

    Args:
        boxes: Iterable of ``(cls, cx, cy, w, h)`` with coordinates given as
            fractions of the image size.
        out_h: Mask height in pixels.
        out_w: Mask width in pixels.
        num_classes: Number of mask channels; boxes whose class id is outside
            ``[0, num_classes)`` are ignored.

    Returns:
        A float32 array of shape ``(out_h, out_w, num_classes)`` that is 1.0
        inside every box of the corresponding class and 0.0 elsewhere.
    """
    mask = np.zeros((out_h, out_w, num_classes), dtype=np.float32)
    for cls, cx, cy, bw, bh in boxes:
        # A negative id would index channels from the end, so reject it too.
        if not 0 <= cls < num_classes:
            continue
        # Slice ends are exclusive, so the far edge is clipped to the full
        # extent (out_w / out_h); clipping to extent - 1 would drop the last
        # column/row of any box that touches the right or bottom border.
        x1 = np.clip((cx - bw / 2.0) * out_w, 0, out_w)
        x2 = np.clip((cx + bw / 2.0) * out_w, 0, out_w)
        y1 = np.clip((cy - bh / 2.0) * out_h, 0, out_h)
        y2 = np.clip((cy + bh / 2.0) * out_h, 0, out_h)
        if x2 <= x1 or y2 <= y1:
            # Degenerate box, or one that lies entirely outside the image.
            continue
        # Round outwards so partially covered pixels are included.
        x1i, x2i = int(np.floor(x1)), int(np.ceil(x2))
        y1i, y2i = int(np.floor(y1)), int(np.ceil(y2))
        mask[y1i:y2i, x1i:x2i, cls] = 1.0
    return mask


def apply_augmentations(image, mask):
    """Randomly augment an HWC image together with its mask.

    Three independent steps are tried in order: a horizontal flip of image
    and mask (p=0.5), a brightness/contrast change of the image (p=0.3) and
    additive Gaussian noise on the image (p=0.2).  Image values are clipped
    to [0, 255] after each photometric step.

    Returns:
        The ``(image, mask)`` pair after augmentation.
    """

    def signed_unit():
        # Uniform sample in [-1, 1).
        return random.random() * 2 - 1

    flip = random.random() < 0.5
    if flip:
        image = np.flip(image, axis=1).copy()
        mask = np.flip(mask, axis=1).copy()

    jitter = random.random() < 0.3
    if jitter:
        gain = 1.0 + 0.3 * signed_unit()
        offset = 15.0 * signed_unit()
        adjusted = image * gain + offset
        image = np.clip(adjusted, 0, 255).astype(np.float32)

    add_noise = random.random() < 0.2
    if add_noise:
        grain = np.random.normal(0, 5, size=image.shape).astype(np.float32)
        image = np.clip(image + grain, 0, 255)

    return image, mask


def draw_boxes(image, boxes):
    """Return a copy of ``image`` with labelled rectangles drawn on it.

    Args:
        image: HWC image array; it is not modified.
        boxes: Iterable of ``(cls, cx, cy, w, h)`` with coordinates given as
            fractions of the image size.

    Returns:
        The annotated copy.
    """
    palette = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]
    canvas = image.copy()
    height, width = canvas.shape[:2]
    for cls, cx, cy, bw, bh in boxes:
        left = int((cx - bw / 2) * width)
        top = int((cy - bh / 2) * height)
        right = int((cx + bw / 2) * width)
        bottom = int((cy + bh / 2) * height)
        # Colours repeat if there are more classes than palette entries.
        tint = palette[cls % len(palette)]
        cv2.rectangle(canvas, (left, top), (right, bottom), tint, 2)
        # Keep the caption inside the image when the box touches the top edge.
        caption_origin = (left, max(0, top - 5))
        cv2.putText(canvas, CLASS_NAMES[cls], caption_origin, cv2.FONT_HERSHEY_SIMPLEX, 0.5, tint, 1, cv2.LINE_AA)
    return canvas


def mask_overlay(image_tensor, mask_tensor):
    """Blend per-class masks onto an image for visualisation.

    Args:
        image_tensor: CHW tensor; values are multiplied by 255 and clipped
            to [0, 255].
        mask_tensor: CHW tensor with one channel per class; values above 0.5
            count as foreground.

    Returns:
        A uint8 HWC array in which foreground pixels are a 60/40 blend of the
        image and the class colour.  Only the first three channels are drawn.
    """
    palette = np.array([(255, 0, 0), (0, 255, 0), (0, 0, 255)], dtype=np.float32)
    rgb = image_tensor.numpy().transpose(1, 2, 0)
    rgb = np.clip(rgb * 255.0, 0, 255).astype(np.uint8)
    layers = mask_tensor.numpy().transpose(1, 2, 0)
    # Blend in float32, starting from the already quantised uint8 image.
    blended = rgb.astype(np.float32)
    for channel, tint in enumerate(palette[:layers.shape[-1]]):
        hit = layers[..., channel] > 0.5
        if not hit.any():
            continue
        blended[hit] = 0.6 * blended[hit] + 0.4 * tint
    return blended.astype(np.uint8)


In [61]:
# Deterministic train/val split: sort the image stems, shuffle them with a
# dedicated RNG seeded by SEED, then hold out the first VAL_FRACTION share
# (at least one image) for validation.
all_image_ids = sorted([p.stem for p in TRAIN_IMG_DIR.glob('*.jpg')])
random.Random(SEED).shuffle(all_image_ids)
val_count = max(1, int(len(all_image_ids) * VAL_FRACTION))
val_ids = sorted(all_image_ids[:val_count])
train_ids = sorted(all_image_ids[val_count:])

# Debug runs keep only the first FAST_*_LIMIT ids of each (sorted) split.
if FAST_DEBUG:
    original_train = len(train_ids)
    original_val = len(val_ids)
    train_ids = train_ids[:min(original_train, FAST_TRAIN_LIMIT)]
    val_ids = val_ids[:min(original_val, FAST_VAL_LIMIT)]
    print(f'[FAST_DEBUG] using {len(train_ids)}/{original_train} train and {len(val_ids)}/{original_val} val samples')

print(f'train samples: {len(train_ids)}, val samples: {len(val_ids)}')
# Persist the split that is actually used (after any debug truncation).
with (CHECKPOINT_DIR / 'split.json').open('w', encoding='utf-8') as f:
    json.dump({'train_ids': train_ids, 'val_ids': val_ids}, f, ensure_ascii=False, indent=2)


[FAST_DEBUG] using 256/2553 train and 64/283 val samples
train samples: 256, val samples: 64


In [62]:
class LeakDataset(Dataset):
    """Training/validation dataset yielding ``(image, mask)`` tensor pairs.

    Each image is read from ``image_dir / f'{image_id}.jpg'``, converted to
    RGB and resized to ``image_size`` given as (H, W).  The mask is rasterised
    from the YOLO label file with the same stem in ``label_dir``, one channel
    per class.
    """

    def __init__(self, image_ids, image_dir, label_dir, image_size=IMAGE_SIZE, augment=False):
        """Store the sample ids, the source folders and the augmentation flag."""
        self.image_ids = image_ids
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.image_size = image_size
        self.augment = augment

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(image, mask)`` as float32 CHW tensors; the image is scaled to
            [0, 1].

        Raises:
            OSError: if the image file is missing or cannot be decoded.
        """
        image_id = self.image_ids[idx]
        img_path = self.image_dir / f'{image_id}.jpg'
        image = cv2.imread(str(img_path))
        if image is None:
            # cv2.imread reports failure by returning None rather than
            # raising; fail here with the offending path instead of letting
            # cvtColor raise an opaque error.
            raise OSError(f'missing or unreadable image: {img_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # cv2.resize expects (width, height), image_size is (H, W).
        image = cv2.resize(image, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_LINEAR)

        boxes = load_yolo_boxes(self.label_dir / f'{image_id}.txt')
        mask = boxes_to_mask(boxes, self.image_size[0], self.image_size[1])

        if self.augment:
            image, mask = apply_augmentations(image, mask)

        image = image.astype(np.float32) / 255.0
        mask = mask.astype(np.float32)
        # HWC -> CHW for PyTorch.
        image = torch.from_numpy(image.transpose(2, 0, 1))
        mask = torch.from_numpy(mask.transpose(2, 0, 1))
        return image, mask


class LeakInferenceDataset(Dataset):
    """Inference dataset yielding resized images plus their id and original size.

    The image paths are sorted once at construction time.
    """

    def __init__(self, image_paths, image_size=IMAGE_SIZE):
        """Store the sorted image paths and the target ``image_size`` (H, W)."""
        self.image_paths = sorted(image_paths)
        self.image_size = image_size

    def __len__(self):
        """Return the number of images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load one image.

        Returns:
            ``(tensor, stem, (orig_h, orig_w))``: a float CHW tensor scaled
            to [0, 1], the file stem, and the image size before resizing.

        Raises:
            OSError: if the image file is missing or cannot be decoded.
        """
        path = self.image_paths[idx]
        image = cv2.imread(str(path))
        if image is None:
            # cv2.imread reports failure by returning None rather than
            # raising; fail here with the offending path.
            raise OSError(f'missing or unreadable image: {path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        orig_h, orig_w = image.shape[:2]
        # cv2.resize expects (width, height), image_size is (H, W).
        image_resized = cv2.resize(image, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_LINEAR)
        tensor = torch.from_numpy(image_resized.transpose(2, 0, 1)).float() / 255.0
        return tensor, path.stem, (orig_h, orig_w)


In [63]:
# Visual sanity check: show the first four training samples, un-augmented,
# with their masks blended onto the resized images.
preview_ds = LeakDataset(train_ids[:4], TRAIN_IMG_DIR, LABEL_DIR, augment=False)
fig, axes = plt.subplots(1, len(preview_ds), figsize=(18, 4))
for ax, sample in zip(axes, preview_ds):
    img, mask = sample
    overlay = mask_overlay(img, mask)
    ax.imshow(overlay)
    ax.axis('off')
plt.suptitle('训练样本可视化（mask 覆盖在缩放后图片上）')
plt.show()


  plt.show()


## 条件 VAE-UNet 结构
- `image encoder` 提取逐级特征；`posterior encoder` 在训练时接收 (image, mask) 以推断潜变量；先验分支（`prior head`）复用 image encoder 的 bottleneck 特征，因此仅依赖 image。
- 将采样得到的 latent `z` 拼接到 bottleneck，再经过带 skip 的 decoder 输出每类热力图。
- 损失函数 = BCE + Dice + β·KL。β 设为 1e-4，既能稳定训练，又能保留生成式约束。

In [64]:
class ConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by batch norm and GELU.

    The convolutions use padding 1, so the spatial size is unchanged, and
    carry no bias term.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second stage out_ch -> out_ch.
        for source_ch in (in_ch, out_ch):
            stages += [
                nn.Conv2d(source_ch, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.GELU(),
            ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        features = self.block(x)
        return features


class DownBlock(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a ``ConvBlock``."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        halve = nn.MaxPool2d(2)
        refine = ConvBlock(in_ch, out_ch)
        self.block = nn.Sequential(halve, refine)

    def forward(self, x):
        """Downsample ``x`` and map it to ``out_ch`` channels."""
        reduced = self.block(x)
        return reduced


class UpBlock(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concat, ``ConvBlock``."""

    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = ConvBlock(out_ch + skip_ch, out_ch)

    def forward(self, x, skip):
        """Upsample ``x``, align it with ``skip`` and fuse the two.

        Any height/width difference to ``skip`` is padded away, split as
        evenly as possible between the two sides.
        """
        upsampled = self.up(x)
        gap_h = skip.size(2) - upsampled.size(2)
        gap_w = skip.size(3) - upsampled.size(3)
        if (gap_h, gap_w) != (0, 0):
            left, top = gap_w // 2, gap_h // 2
            # F.pad takes the pads as [left, right, top, bottom].
            upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])
        # Skip features come first in the channel dimension.
        fused = torch.cat([skip, upsampled], dim=1)
        return self.conv(fused)


class Encoder(nn.Module):
    """Five-level convolutional encoder.

    Channel widths are ``base_ch`` times 1, 2, 4, 8 and 16; every level after
    the first starts with a 2x downsampling.
    """

    def __init__(self, in_ch, base_ch):
        super().__init__()
        c1, c2, c4, c8, c16 = (base_ch * factor for factor in (1, 2, 4, 8, 16))
        self.inc = ConvBlock(in_ch, c1)
        self.down1 = DownBlock(c1, c2)
        self.down2 = DownBlock(c2, c4)
        self.down3 = DownBlock(c4, c8)
        self.down4 = DownBlock(c8, c16)

    def forward(self, x):
        """Return the feature maps of all five levels, shallowest first."""
        pyramid = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4):
            pyramid.append(stage(pyramid[-1]))
        return pyramid


class PosteriorEncoder(nn.Module):
    """Wrap an ``Encoder`` and expose only its deepest feature map."""

    def __init__(self, in_ch, base_ch):
        super().__init__()
        self.enc = Encoder(in_ch, base_ch)

    def forward(self, x):
        """Return the last (bottleneck) feature map for ``x``."""
        *_, deepest = self.enc(x)
        return deepest


class LatentHead(nn.Module):
    """Map a feature map to the mean and log-variance of a latent Gaussian.

    The feature map is globally average-pooled to one vector per sample,
    which two separate linear layers project to ``latent_dim`` values each.
    """

    def __init__(self, in_ch, latent_dim):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.mu = nn.Linear(in_ch, latent_dim)
        self.logvar = nn.Linear(in_ch, latent_dim)

    def forward(self, feat):
        """Return ``(mu, logvar)``, each of shape ``(batch, latent_dim)``."""
        summary = torch.flatten(self.pool(feat), start_dim=1)
        mean = self.mu(summary)
        log_variance = self.logvar(summary)
        return mean, log_variance


class ConditionalVAEUNet(nn.Module):
    """UNet whose bottleneck is conditioned on a latent Gaussian sample.

    The prior over the latent ``z`` is predicted from the image encoder's
    bottleneck.  In training mode with a target mask, ``z`` is drawn from a
    posterior predicted from the image concatenated with the mask, and the
    KL divergence from posterior to prior is returned alongside the logits.
    """

    def __init__(self, in_ch=3, num_classes=NUM_CLASSES, base_ch=32, latent_dim=LATENT_DIM):
        super().__init__()
        bottleneck_ch = base_ch * 16
        self.encoder = Encoder(in_ch, base_ch)
        # The posterior branch also sees the mask, hence the extra channels.
        self.posterior_encoder = PosteriorEncoder(in_ch + num_classes, base_ch)
        self.prior_head = LatentHead(bottleneck_ch, latent_dim)
        self.posterior_head = LatentHead(bottleneck_ch, latent_dim)
        # z is concatenated to the bottleneck before the first up block.
        self.up1 = UpBlock(bottleneck_ch + latent_dim, base_ch * 8, base_ch * 8)
        self.up2 = UpBlock(base_ch * 8, base_ch * 4, base_ch * 4)
        self.up3 = UpBlock(base_ch * 4, base_ch * 2, base_ch * 2)
        self.up4 = UpBlock(base_ch * 2, base_ch, base_ch)
        self.out_conv = nn.Conv2d(base_ch, num_classes, kernel_size=1)
        self.latent_dim = latent_dim

    @staticmethod
    def reparameterize(mu, logvar):
        """Draw ``mu + eps * std`` with ``eps`` standard normal."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    @staticmethod
    def _gaussian_kl(post_mu, post_logvar, prior_mu, prior_logvar):
        """Batch mean of KL(posterior || prior) for diagonal Gaussians."""
        per_dim = (
            torch.exp(post_logvar - prior_logvar)
            + ((prior_mu - post_mu) ** 2) / torch.exp(prior_logvar)
            - 1.0 + (prior_logvar - post_logvar)
        )
        per_sample = 0.5 * torch.sum(per_dim, dim=1)
        return per_sample.mean()

    def decode(self, feats, z):
        """Decode encoder features and a latent vector into class logits."""
        x1, x2, x3, x4, x5 = feats
        batch, _, height, width = x5.shape
        # Broadcast z over every bottleneck position.
        z_grid = z.view(batch, self.latent_dim, 1, 1).expand(-1, -1, height, width)
        hidden = torch.cat([x5, z_grid], dim=1)
        for up, skip in ((self.up1, x4), (self.up2, x3), (self.up3, x2), (self.up4, x1)):
            hidden = up(hidden, skip)
        return self.out_conv(hidden)

    def forward(self, x, target_mask=None):
        """Return ``(logits, kl)`` for a batch of images.

        Outside training mode, or when ``target_mask`` is None, ``z`` is
        sampled from the prior and ``kl`` is a zero tensor.
        """
        feats = self.encoder(x)
        prior_mu, prior_logvar = self.prior_head(feats[-1])
        use_posterior = self.training and target_mask is not None
        if use_posterior:
            joint = torch.cat([x, target_mask], dim=1)
            post_mu, post_logvar = self.posterior_head(self.posterior_encoder(joint))
            z = self.reparameterize(post_mu, post_logvar)
            kl = self._gaussian_kl(post_mu, post_logvar, prior_mu, prior_logvar)
        else:
            z = self.reparameterize(prior_mu, prior_logvar)
            kl = torch.tensor(0.0, device=x.device)
        return self.decode(feats, z), kl


In [65]:
def dice_loss(logits, targets, eps=1e-6):
    """Soft Dice loss averaged over classes.

    Sigmoid probabilities and targets are summed over the batch and both
    spatial axes, giving one Dice score per channel; the loss is one minus
    their mean.  ``eps`` is added to numerator and denominator.
    """
    reduce_axes = (0, 2, 3)
    probs = logits.sigmoid()
    overlap = torch.sum(probs * targets, dim=reduce_axes)
    mass = torch.sum(probs, dim=reduce_axes) + torch.sum(targets, dim=reduce_axes)
    per_class = (2 * overlap + eps) / (mass + eps)
    return 1 - per_class.mean()


def compute_iou(logits, targets, threshold=0.5, eps=1e-6):
    """Mean per-class IoU of thresholded predictions, as a Python float.

    Predictions are ``sigmoid(logits) > threshold``.  Intersection and union
    are accumulated over the batch and both spatial axes for each channel.
    Because ``eps`` is added to both terms, a class with neither predicted
    nor target pixels scores 1.
    """
    reduce_axes = (0, 2, 3)
    hard = (logits.sigmoid() > threshold).float()
    overlap = torch.sum(hard * targets, dim=reduce_axes)
    combined = torch.sum(hard, dim=reduce_axes) + torch.sum(targets, dim=reduce_axes) - overlap
    per_class = (overlap + eps) / (combined + eps)
    return per_class.mean().item()


def train_one_epoch(model, loader, optimizer, scaler):
    """Run one training pass over ``loader``.

    The loss is ``BCE + Dice + KL_WEIGHT * KL``.  The forward pass runs under
    CUDA autocast when DEVICE is 'cuda', and the backward pass and optimizer
    step go through ``scaler``.

    Returns:
        A dict with the per-batch averages of 'loss', 'bce', 'dice' and 'kl'.
    """
    model.train()
    use_amp = DEVICE == 'cuda'
    totals = dict.fromkeys(('loss', 'bce', 'dice', 'kl'), 0.0)
    for images, masks in tqdm(loader, total=len(loader)):
        images = images.to(DEVICE)
        masks = masks.to(DEVICE)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            logits, kl = model(images, masks)
            bce = F.binary_cross_entropy_with_logits(logits, masks)
            dice = dice_loss(logits, masks)
            loss = bce + dice + KL_WEIGHT * kl
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        for name, term in (('loss', loss), ('bce', bce), ('dice', dice), ('kl', kl)):
            totals[name] += term.item()
    # Guard against an empty loader.
    batches = max(1, len(loader))
    return {name: total / batches for name, total in totals.items()}


def evaluate(model, loader):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    The reported loss is ``BCE + Dice``; the KL value returned by the model
    is discarded.

    Returns:
        A dict with the per-batch averages of 'loss', 'bce', 'dice' and 'iou'.
    """
    model.eval()
    sums = dict.fromkeys(('loss', 'bce', 'dice', 'iou'), 0.0)
    with torch.no_grad():
        for images, masks in loader:
            images = images.to(DEVICE)
            masks = masks.to(DEVICE)
            logits, _ = model(images, masks)
            bce = F.binary_cross_entropy_with_logits(logits, masks)
            dice = dice_loss(logits, masks)
            batch_stats = {
                'loss': (bce + dice).item(),
                'bce': bce.item(),
                'dice': dice.item(),
                'iou': compute_iou(logits, masks),
            }
            for name, value in batch_stats.items():
                sums[name] += value
    # Guard against an empty loader.
    batches = max(1, len(loader))
    return {name: total / batches for name, total in sums.items()}


In [66]:
# Only the training split is augmented; validation images are left as-is.
train_dataset = LeakDataset(train_ids, TRAIN_IMG_DIR, LABEL_DIR, augment=True)
val_dataset = LeakDataset(val_ids, TRAIN_IMG_DIR, LABEL_DIR, augment=False)
# Options shared by the train and val loaders.
loader_kwargs = {
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'pin_memory': PIN_MEMORY,
}
# Worker-related options are only passed when worker processes are used.
if NUM_WORKERS > 0:
    loader_kwargs['persistent_workers'] = True
    loader_kwargs['prefetch_factor'] = PREFETCH_FACTOR

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
print(f'train batches: {len(train_loader)}, val batches: {len(val_loader)}')


train batches: 256, val batches: 64


In [67]:
# Model, optimizer, cosine LR schedule over EPOCHS, and an AMP gradient
# scaler that is only active on CUDA.
model = ConditionalVAEUNet().to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == 'cuda'))

history = []
best_iou = 0.0
best_path = CHECKPOINT_DIR / 'cvae_detector.pt'

for epoch in range(1, EPOCHS + 1):
    print(f'Epoch {epoch}/{EPOCHS}')
    train_stats = train_one_epoch(model, train_loader, optimizer, scaler)
    val_stats = evaluate(model, val_loader)
    scheduler.step()
    # The stored 'dice' stats are losses, so 1 - loss gives the Dice score.
    # 'lr' is read after scheduler.step(), i.e. it is the rate the scheduler
    # has just set for the following epoch.
    record = {
        'epoch': epoch,
        'train_loss': train_stats['loss'],
        'train_dice': 1 - train_stats['dice'],
        'val_loss': val_stats['loss'],
        'val_dice': 1 - val_stats['dice'],
        'val_iou': val_stats['iou'],
        'lr': scheduler.get_last_lr()[0],
    }
    history.append(record)
    print(record)
    # A checkpoint is written only when validation IoU strictly improves on
    # the best value so far (which starts at 0.0).
    if record['val_iou'] > best_iou:
        best_iou = record['val_iou']
        torch.save({'model_state': model.state_dict(), 'config': {'image_size': IMAGE_SIZE, 'latent_dim': LATENT_DIM}}, best_path)
        print(f'>> saved new best model to {best_path}, IoU={best_iou:.4f}')

# Bare expression so the notebook displays the per-epoch records.
history


Epoch 1/1


  scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == 'cuda'))


  0%|          | 0/256 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=(DEVICE == 'cuda')):


{'epoch': 1, 'train_loss': 1.4616852323524654, 'train_dice': 0.016534110298380256, 'val_loss': 1.38950889185071, 'val_dice': 0.020868970081210136, 'val_iou': 0.08487947774271319, 'lr': 0.0}
>> saved new best model to d:\BaiduNetdiskDownload\dataset_work_final\artifacts\cvae_detector.pt, IoU=0.0849


[{'epoch': 1,
  'train_loss': 1.4616852323524654,
  'train_dice': 0.016534110298380256,
  'val_loss': 1.38950889185071,
  'val_dice': 0.020868970081210136,
  'val_iou': 0.08487947774271319,
  'lr': 0.0}]

In [68]:
# Plot train/val loss (left) and validation IoU (right) per epoch; skipped
# when no epoch has been recorded.
if history:
    epochs = [h['epoch'] for h in history]
    val_iou = [h['val_iou'] for h in history]
    train_loss = [h['train_loss'] for h in history]
    val_loss = [h['val_loss'] for h in history]
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    axes[0].plot(epochs, train_loss, label='train loss')
    axes[0].plot(epochs, val_loss, label='val loss')
    axes[0].legend()
    axes[0].set_title('Loss curves')
    axes[1].plot(epochs, val_iou, label='val IoU', color='purple')
    axes[1].set_title('Validation IoU')
    axes[1].legend()
    plt.show()


  plt.show()


In [69]:
# Reload the best checkpoint written by the training loop and switch the
# model to eval mode.
best_checkpoint = CHECKPOINT_DIR / 'cvae_detector.pt'
assert best_checkpoint.exists(), '请先完成训练再继续'
state = torch.load(best_checkpoint, map_location=DEVICE)
model.load_state_dict(state['model_state'])
model.eval()
print('Loaded checkpoint from', best_checkpoint)

# Show input, ground-truth overlay and predicted overlay for up to four
# validation samples, one sample per row.
val_samples = min(4, len(val_dataset))
# squeeze=False keeps `axes` two-dimensional even when there is a single
# row; otherwise the axes[row, col] indexing below fails for one sample.
fig, axes = plt.subplots(val_samples, 3, figsize=(12, 3 * val_samples), squeeze=False)
for row in range(val_samples):
    img, mask = val_dataset[row]
    with torch.no_grad():
        logits, _ = model(img.unsqueeze(0).to(DEVICE), mask.unsqueeze(0).to(DEVICE))
        probs = torch.sigmoid(logits)[0].cpu()
    overlay_gt = mask_overlay(img, mask)
    overlay_pred = mask_overlay(img, probs)
    axes[row, 0].imshow(img.numpy().transpose(1, 2, 0))
    axes[row, 0].set_title('Input')
    axes[row, 1].imshow(overlay_gt)
    axes[row, 1].set_title('GT mask')
    axes[row, 2].imshow(overlay_pred)
    axes[row, 2].set_title('Pred mask')
    for col in range(3):
        axes[row, col].axis('off')
plt.tight_layout()
plt.show()


  state = torch.load(best_checkpoint, map_location=DEVICE)


Loaded checkpoint from d:\BaiduNetdiskDownload\dataset_work_final\artifacts\cvae_detector.pt


  plt.show()


In [70]:
def non_max_suppression(boxes, iou_thresh=0.4):
    """Greedy non-maximum suppression on score-ranked boxes.

    Args:
        boxes: List of ``[cls, score, x, y, w, h]`` entries.
        iou_thresh: A box is dropped when its IoU with an already kept,
            higher-scoring box is at or above this value.

    Returns:
        The kept boxes in descending score order.  The class id plays no
        part in the comparison, so boxes of different classes can suppress
        one another.
    """
    if not boxes:
        return []
    pending = sorted(boxes, key=lambda entry: entry[1], reverse=True)
    survivors = []
    while pending:
        best, *rest = pending
        survivors.append(best)
        pending = [cand for cand in rest if box_iou(best, cand) < iou_thresh]
    return survivors


def box_iou(box_a, box_b):
    """Intersection over union of two ``[cls, score, x, y, w, h]`` boxes.

    ``x, y`` is the top-left corner.  A small constant (1e-6) is added to
    the union so two zero-area boxes do not cause a division by zero.
    """
    ax, ay, aw, ah = box_a[2:6]
    bx, by, bw, bh = box_b[2:6]
    overlap_w = max(0, min(ax + aw, bx + bw) - max(ax, bx))
    overlap_h = max(0, min(ay + ah, by + bh) - max(ay, by))
    shared = overlap_w * overlap_h
    total = aw * ah + bw * bh - shared + 1e-6
    return shared / total


def probmap_to_boxes(prob_map, min_area_ratio=1e-4, threshold=0.45, min_score=0.3, nms_thresh=0.4):
    """Turn per-class probability maps into normalised, NMS-filtered boxes.

    Each channel is smoothed with a 5x5 Gaussian blur and thresholded; the
    bounding rectangle of every external contour becomes a candidate box,
    scored by the mean smoothed probability inside the rectangle.

    Args:
        prob_map: Array of shape ``(channels, h, w)``.
        min_area_ratio: Candidates smaller than this fraction of ``h * w``
            are dropped.
        threshold: Binarisation threshold on the smoothed map.
        min_score: Candidates scoring below this are dropped.
        nms_thresh: IoU threshold passed to ``non_max_suppression``.

    Returns:
        A list of ``(cls_idx, score, cx, cy, w, h)`` tuples with centre and
        size given as fractions of the map size.
    """
    n_channels, h, w = prob_map.shape
    min_area = min_area_ratio * h * w
    candidates = []
    for cls_idx in range(n_channels):
        heat = cv2.GaussianBlur(prob_map[cls_idx], (5, 5), 0)
        binary = (heat >= threshold).astype(np.uint8)
        if not binary.any():
            continue
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for x, y, bw, bh in map(cv2.boundingRect, contours):
            if bw * bh < min_area:
                continue
            score = float(heat[y:y + bh, x:x + bw].mean())
            if score < min_score:
                continue
            candidates.append([cls_idx, score, x, y, bw, bh])
    survivors = non_max_suppression(candidates, iou_thresh=nms_thresh)
    # Convert pixel top-left/size to normalised centre/size.
    return [
        (cls_idx, score, (x + bw / 2) / w, (y + bh / 2) / h, bw / w, bh / h)
        for cls_idx, score, x, y, bw, bh in survivors
    ]


def write_prediction_file(image_id, boxes, output_dir=RESULT_DIR):
    """Write one image's boxes to ``output_dir / f'{image_id}.txt'``.

    Each box becomes a line ``cls cx cy w h`` with six decimals; the score is
    not written.  The directory is created if needed and an existing file is
    overwritten.

    Returns:
        The path of the written file.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    target = output_dir / f'{image_id}.txt'
    with target.open('w') as out:
        out.writelines(
            f'{cls} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}\n'
            for cls, _score, cx, cy, w, h in boxes
        )
    return target


In [71]:

# Run inference on the test set and write the 120 result files.
# Remove result files from earlier runs first so stale predictions cannot
# survive.
for old in RESULT_DIR.glob('*.txt'):
    old.unlink()

# One image per batch, loaded in the main process, in sorted path order.
test_paths = sorted(TEST_IMG_DIR.glob('*.jpg'))
test_dataset = LeakInferenceDataset(test_paths)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

def _to_list(value):
    if isinstance(value, torch.Tensor):
        return value.tolist()
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def normalize_original_sizes(batch_sizes):
    """Convert a batch's original-size field to a list of ``(height, width)`` ints.

    Layouts are tried in this order:

    * a ``torch.Tensor`` whose ``tolist()`` is a list of pairs, one row per
      sample;
    * a list/tuple of exactly two equal-length entries, read as
      ``(heights, widths)``;
    * any other list/tuple, read as one ``(h, w)`` entry per sample;
    * anything else, read as a single ``(h, w)`` pair.

    Raises:
        ValueError: if a tensor does not hold pairs, or an entry does not
            consist of exactly two values.
    """
    if isinstance(batch_sizes, torch.Tensor):
        values = _to_list(batch_sizes)
        if isinstance(values[0], (list, tuple)):
            return [(int(h), int(w)) for h, w in values]
        raise ValueError(f'unexpected tensor shape for original size: {batch_sizes.shape}')

    if isinstance(batch_sizes, (list, tuple)):
        # NOTE(review): every two-entry list with equal-length entries takes
        # this branch, so two per-sample pairs [(h1, w1), (h2, w2)] would be
        # read as heights=(h1, w1), widths=(h2, w2). The loader in this file
        # uses batch_size=1 - confirm before raising it.
        if len(batch_sizes) == 2 and all(len(_to_list(bs)) == len(_to_list(batch_sizes[0])) for bs in batch_sizes):
            heights = _to_list(batch_sizes[0])
            widths = _to_list(batch_sizes[1])
            return [(int(h), int(w)) for h, w in zip(heights, widths)]
        normalized = []
        for size in batch_sizes:
            values = _to_list(size)
            # Unwrap a pair that arrived nested one level deep.
            if len(values) == 1 and isinstance(values[0], (list, tuple)):
                values = values[0]
            if len(values) != 2:
                raise ValueError(f'unexpected original size entry: {values}')
            normalized.append((int(values[0]), int(values[1])))
        return normalized

    # Fallback: a single array-like (h, w) pair.
    values = _to_list(batch_sizes)
    if len(values) != 2:
        raise ValueError(f'unexpected original size: {values}')
    return [(int(values[0]), int(values[1]))]


# Predict every test image and write one result file per image.
model.eval()
with torch.no_grad():
    for batch in tqdm(test_loader, total=len(test_loader)):
        images, image_ids, original_sizes = batch
        # Accept a bare string as well as a sequence of ids.
        if isinstance(image_ids, str):
            image_ids = [image_ids]
        else:
            image_ids = list(image_ids)
        original_sizes = normalize_original_sizes(original_sizes)
        images = images.to(DEVICE)
        # No mask is passed, so the model samples z from its prior.
        logits, _ = model(images)
        probs = torch.sigmoid(logits).cpu().numpy()
        for prob, image_id, (orig_h, orig_w) in zip(probs, image_ids, original_sizes):
            # Resize each class map back to the original resolution before
            # extracting boxes; cv2.resize takes (width, height).
            resized_map = np.stack([
                cv2.resize(prob[c], (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
                for c in range(prob.shape[0])
            ])
            boxes = probmap_to_boxes(resized_map)
            write_prediction_file(image_id, boxes)

print('result files:', len(list(RESULT_DIR.glob('*.txt'))))


  0%|          | 0/120 [00:00<?, ?it/s]

result files: 120


In [72]:
# Print the first five result files (by sorted name) as a spot check.
sample_files = sorted(RESULT_DIR.glob('*.txt'))[:5]
for path in sample_files:
    print(path.name)
    with path.open() as f:
        print(f.read().strip())


0014.txt
1 0.541146 0.069444 0.083333 0.118519
1 0.360677 0.143519 0.080729 0.159259
1 0.635417 0.141667 0.036458 0.074074
2 0.540365 0.081481 0.053646 0.003704
2 0.538281 0.088426 0.050521 0.004630
2 0.547917 0.067593 0.046875 0.003704
2 0.531771 0.102315 0.045833 0.004630
2 0.544010 0.074537 0.053646 0.004630
2 0.550260 0.060648 0.042188 0.002778
2 0.534115 0.095833 0.050521 0.004630
2 0.360156 0.143981 0.038021 0.002778
2 0.530208 0.109722 0.037500 0.004630
2 0.371615 0.200000 0.031771 0.003704
2 0.371094 0.192593 0.039062 0.003704
2 0.366667 0.178704 0.042708 0.003704
2 0.364063 0.158333 0.038542 0.003704
2 0.364844 0.172222 0.046354 0.003704
2 0.366146 0.164815 0.034375 0.003704
2 0.362240 0.150926 0.033854 0.003704
2 0.368229 0.185648 0.044792 0.004630
0 0.481250 0.420833 0.043750 0.145370
0 0.972917 0.589815 0.050000 0.064815
0 0.768750 0.125463 0.030208 0.073148
0 0.138802 0.960648 0.077604 0.073148
0 0.736458 0.186574 0.045833 0.136111
0 0.733333 0.515741 0.025000 0.109259
0 0