In [None]:
"""
UPerNet (ResNet50 backbone, torchvision pretrained) - Full runnable version (simple implementation)

- Training images: X_train_uDRk9z9/images (well1–6)
- Test images: X_test_xNbnvIa/images (well7–11)
- Training labels: Y_train_T9NrBYo.csv (flattened with -1 padding)
- Validation split: by well (e.g. well6 as validation, others as training)
- Output: submission.csv (one row per patch, flattened, padded to 160*272 with -1)

You only need to check / modify:
1) DATA_ROOT path
2) EPOCHS / BATCH_SIZE (set smaller if training is slow)
"""

import os
import re
from pathlib import Path

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset

from torchvision.models import resnet50, ResNet50_Weights


# =========================
# 0. Hyperparameters & Paths
# =========================
# Root folder of the challenge data. Set the UPERNET_DATA_ROOT environment
# variable to point elsewhere without editing this file; otherwise the
# default below is used (change it to your actual path).
DATA_ROOT = Path(
    os.environ.get("UPERNET_DATA_ROOT", r"C:\Users\lenovo\Desktop\deep_datachallenge")
)

TRAIN_IMAGES_DIR = DATA_ROOT / "X_train_uDRk9z9" / "images"
TEST_IMAGES_DIR = DATA_ROOT / "X_test_xNbnvIa" / "images"
Y_TRAIN_CSV = DATA_ROOT / "Y_train_T9NrBYo.csv"

# Every patch (image and mask) is padded/cropped to this size.
TARGET_H = 160
TARGET_W = 272

NUM_CLASSES = 3          # only classes 0/1/2 in CSV
IGNORE_INDEX = -1        # padding value in CSV; also skipped by the loss

BATCH_SIZE = 4           # UPerNet is memory-heavy; for RTX 4060 (8GB) start with 2–4
LR = 1e-4                # smaller LR is usually more stable with pretrained backbones
WEIGHT_DECAY = 1e-4
EPOCHS = 20              # reduce for quick debugging

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# =========================
# 1. Utility Functions
# =========================
def parse_well_id(name: str) -> int:
    """Return the integer well id embedded in a patch name.

    ``"well_1_section_0_patch_0"`` yields ``1``. A name with no
    ``well_<digits>_`` fragment yields ``-1``.
    """
    match = re.search(r"well_(\d+)_", name)
    if match is None:
        return -1
    return int(match.group(1))


def minmax_normalize(x: np.ndarray) -> np.ndarray:
    """Linearly rescale an array into [0, 1] as float32.

    NaN and +/-Inf entries are replaced by 0 before the range is measured.
    A (near-)constant array, whose range is below 1e-6, maps to all zeros.
    """
    arr = np.nan_to_num(x.astype(np.float32), nan=0.0, posinf=0.0, neginf=0.0)
    lo, hi = float(arr.min()), float(arr.max())
    span = hi - lo
    if span >= 1e-6:
        return (arr - lo) / span
    return np.zeros_like(arr, dtype=np.float32)


def pad_to_160x272(img: np.ndarray, fill_value: float = 0.0) -> np.ndarray:
    """Bring a (TARGET_H, w) image to exactly (TARGET_H, TARGET_W).

    Narrower images are right-padded with ``fill_value`` (keeping the input
    dtype); wider images are cropped on the right (a view, not a copy);
    images that already have width TARGET_W are returned unchanged.
    """
    height, width = img.shape
    assert height == TARGET_H, f"Expected height {TARGET_H}, got {height}"
    if width > TARGET_W:
        return img[:, :TARGET_W]
    if width == TARGET_W:
        return img
    padded = np.full((TARGET_H, TARGET_W), fill_value, dtype=img.dtype)
    padded[:, :width] = img
    return padded


def decode_mask_from_csv_row(row_values: np.ndarray) -> np.ndarray:
    """Rebuild a (TARGET_H, w) int64 label mask from one flattened CSV row.

    Every entry equal to IGNORE_INDEX (the padding value) is dropped; the
    remaining values are reshaped row-major to TARGET_H rows, so ``w`` is
    inferred from how many labels survive.
    """
    labels = row_values[row_values != IGNORE_INDEX]
    n_labels = len(labels)
    assert n_labels % TARGET_H == 0, f"Valid mask length {n_labels} not divisible by 160"
    width = n_labels // TARGET_H
    return labels.reshape(TARGET_H, width).astype(np.int64)


def pad_mask_to_160x272(mask: np.ndarray) -> np.ndarray:
    """Bring a (TARGET_H, w) label mask to exactly (TARGET_H, TARGET_W).

    Narrower masks are right-padded with IGNORE_INDEX so the padded columns
    are skipped by the loss. Wider masks are cropped on the right, the same
    way ``pad_to_160x272`` crops wide images, so image and mask stay
    pixel-aligned (previously a wide mask raised a broadcast error).
    Masks already TARGET_W wide are returned unchanged.
    """
    h, w = mask.shape
    assert h == TARGET_H, f"Expected height {TARGET_H}, got {h}"
    if w == TARGET_W:
        return mask
    if w > TARGET_W:
        return mask[:, :TARGET_W]
    out = np.full((TARGET_H, TARGET_W), IGNORE_INDEX, dtype=np.int64)
    out[:, :w] = mask
    return out


# =========================
# 2. Dataset (shared for train/test)
# =========================
class WellSegDataset(Dataset):
    """Segmentation dataset over a folder of ``*.npy`` patches.

    Each patch is min-max normalised and padded/cropped to
    (TARGET_H, TARGET_W). When a label CSV is given, the row whose index
    equals the file stem is decoded into a padded int64 mask; with
    ``y_csv_path=None`` (test set) samples carry no mask.
    """

    def __init__(self, images_dir: Path, y_csv_path: Path = None):
        """
        Args:
            images_dir: folder scanned (non-recursively) for ``*.npy`` files,
                kept in sorted path order.
            y_csv_path: label CSV whose first column is the patch name;
                ``None`` marks unlabeled data (test set).
        """
        self.images_dir = images_dir
        self.image_paths = sorted(images_dir.glob("*.npy"))
        self.names = [path.stem for path in self.image_paths]
        self.has_label = y_csv_path is not None
        self.y_df = pd.read_csv(y_csv_path, index_col=0) if self.has_label else None

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx: int):
        """Return a sample dict: name, image (1, H, W) float tensor, raw_w, plus mask if labelled."""
        name = self.names[idx]
        raw = np.load(self.image_paths[idx])  # (160,160) or (160,272)
        raw_w = raw.shape[1]  # original width, used to crop predictions back
        image = torch.from_numpy(pad_to_160x272(minmax_normalize(raw), fill_value=0.0))
        image = image.unsqueeze(0).float()

        if self.has_label:
            labels = self.y_df.loc[name].values.astype(np.int64)
            mask = pad_mask_to_160x272(decode_mask_from_csv_row(labels))
            return {
                "name": name,
                "image": image,
                "mask": torch.from_numpy(mask).long(),
                "raw_w": raw_w,
            }

        return {"name": name, "image": image, "raw_w": raw_w}


# =========================
# 3. UPerNet Head (simplified)
# =========================
class ConvBNReLU(nn.Module):
    """Bias-free Conv2d -> BatchNorm2d -> ReLU, held in one ``Sequential``.

    Args:
        in_ch / out_ch: input / output channel counts.
        k: square kernel size (default 3).
        p: zero padding (default 1, which keeps H and W for k=3).
    """

    def __init__(self, in_ch, out_ch, k=3, p=1):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=k, padding=p, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


class PSPModule(nn.Module):
    """Pyramid pooling over the coarsest feature map.

    For each size in ``pool_sizes`` the input is adaptive-average-pooled to
    (size, size), projected to ``out_ch`` channels (1x1 conv + BN + ReLU) and
    bilinearly resized back. The input and all branches are concatenated on
    the channel axis and fused by a 3x3 ConvBNReLU into ``out_ch`` channels.
    """

    def __init__(self, in_ch, out_ch=256, pool_sizes=(1, 2, 3, 6)):
        super().__init__()
        self.stages = nn.ModuleList(
            [self._make_stage(in_ch, out_ch, size) for size in pool_sizes]
        )
        fused_ch = in_ch + out_ch * len(pool_sizes)
        self.bottleneck = ConvBNReLU(fused_ch, out_ch, k=3, p=1)

    @staticmethod
    def _make_stage(in_ch, out_ch, size):
        """Build one pooling branch: pool -> 1x1 conv -> BN -> ReLU."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d((size, size)),
            nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        spatial = x.shape[-2:]
        pyramid = [x]
        pyramid.extend(
            F.interpolate(stage(x), size=spatial, mode="bilinear", align_corners=False)
            for stage in self.stages
        )
        return self.bottleneck(torch.cat(pyramid, dim=1))


class UPerHead(nn.Module):
    """UPerNet decode head: PSP on the coarsest map plus an FPN top-down path.

    Takes four backbone maps c2..c5 (strides 1/4 .. 1/32) with channel counts
    given by ``in_channels``. It returns ``num_classes`` logits at the
    resolution of c2.
    """

    def __init__(self, in_channels=(256, 512, 1024, 2048), fpn_dim=256, num_classes=3):
        super().__init__()
        ch2, ch3, ch4, ch5 = in_channels

        # Pyramid pooling on the coarsest map.
        self.psp = PSPModule(ch5, out_ch=fpn_dim)

        # 1x1 lateral projections of the finer maps to fpn_dim channels.
        self.lateral_c2 = nn.Conv2d(ch2, fpn_dim, kernel_size=1, bias=False)
        self.lateral_c3 = nn.Conv2d(ch3, fpn_dim, kernel_size=1, bias=False)
        self.lateral_c4 = nn.Conv2d(ch4, fpn_dim, kernel_size=1, bias=False)

        # Per-level 3x3 smoothing after the top-down merge.
        self.fpn_c2 = ConvBNReLU(fpn_dim, fpn_dim)
        self.fpn_c3 = ConvBNReLU(fpn_dim, fpn_dim)
        self.fpn_c4 = ConvBNReLU(fpn_dim, fpn_dim)
        self.fpn_c5 = ConvBNReLU(fpn_dim, fpn_dim)

        # Fuse the four concatenated levels, then classify per pixel.
        self.fuse = ConvBNReLU(fpn_dim * 4, fpn_dim)
        self.cls = nn.Conv2d(fpn_dim, num_classes, kernel_size=1)

    def forward(self, c2, c3, c4, c5):
        def resize(t, like):
            # Bilinear resize of `t` to the spatial size of `like`.
            return F.interpolate(t, size=like.shape[-2:], mode="bilinear", align_corners=False)

        top = self.psp(c5)
        lat4 = self.lateral_c4(c4)
        lat3 = self.lateral_c3(c3)
        lat2 = self.lateral_c2(c2)

        # Top-down pathway: each level adds the upsampled coarser level.
        lat4 = lat4 + resize(top, lat4)
        lat3 = lat3 + resize(lat4, lat3)
        lat2 = lat2 + resize(lat3, lat2)

        out5 = self.fpn_c5(top)
        out4 = self.fpn_c4(lat4)
        out3 = self.fpn_c3(lat3)
        out2 = self.fpn_c2(lat2)

        stacked = torch.cat([out2] + [resize(o, out2) for o in (out3, out4, out5)], dim=1)
        return self.cls(self.fuse(stacked))


# =========================
# 4. UPerNet (ResNet50 backbone)
# =========================
class UPerNet(nn.Module):
    """UPerNet segmentation model: torchvision ResNet-50 encoder + UPerHead.

    The ImageNet (IMAGENET1K_V2) ResNet-50 is adapted to single-channel input
    by replacing ``conv1`` with a 1-input-channel conv whose kernel is the
    mean of the pretrained RGB kernels. The head consumes layer1..layer4
    features (256/512/1024/2048 channels). Its logits are bilinearly
    upsampled back to the spatial size of the input, so any (N, 1, H, W)
    input yields (N, num_classes, H, W) logits.
    """

    def __init__(self, num_classes: int):
        super().__init__()

        # Loads pretrained weights through torchvision (downloads on first use).
        backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

        # Collapse the RGB stem to one input channel by averaging the
        # pretrained kernels over the channel axis. NOTE(review): compared to
        # feeding a replicated grayscale image, averaging scales the stem
        # response by 1/3 (summing is the other common choice); kept as-is to
        # preserve training behaviour.
        old_conv1 = backbone.conv1
        new_conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=old_conv1.out_channels,
            kernel_size=old_conv1.kernel_size,
            stride=old_conv1.stride,
            padding=old_conv1.padding,
            bias=False,
        )
        with torch.no_grad():
            new_conv1.weight[:] = old_conv1.weight.mean(dim=1, keepdim=True)
        backbone.conv1 = new_conv1

        # The classifier (avgpool/fc) stays on the module so state_dict keys
        # remain compatible with saved checkpoints; forward never calls it.
        self.backbone = backbone
        self.head = UPerHead(
            in_channels=(256, 512, 1024, 2048),
            fpn_dim=256,
            num_classes=num_classes,
        )

    def forward(self, x):
        # Remember the input resolution so the logits match it exactly,
        # instead of assuming every input is TARGET_H x TARGET_W.
        out_size = x.shape[-2:]

        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        c2 = self.backbone.layer1(x)
        c3 = self.backbone.layer2(c2)
        c4 = self.backbone.layer3(c3)
        c5 = self.backbone.layer4(c4)

        logits_low = self.head(c2, c3, c4, c5)  # at c2 resolution
        return F.interpolate(logits_low, size=out_size, mode="bilinear", align_corners=False)


# =========================
# 5. Training & Validation
# =========================
def train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader`` and return the mean loss per sample.

    Each batch contributes its cross-entropy (pixels equal to IGNORE_INDEX
    are skipped) weighted by its batch size. The sum is divided by the
    number of samples actually seen rather than ``len(loader.dataset)``, so
    the value stays correct when the loader uses ``drop_last=True`` or a
    sampler that visits only part of the dataset.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.train()
    total_loss = 0.0
    n_seen = 0
    for batch in loader:
        x = batch["image"].to(DEVICE)
        y = batch["mask"].to(DEVICE)

        logits = model(x)
        loss = F.cross_entropy(logits, y, ignore_index=IGNORE_INDEX)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = x.size(0)
        total_loss += loss.item() * batch_size
        n_seen += batch_size

    if n_seen == 0:
        raise ValueError("train_one_epoch: loader produced no samples")
    return total_loss / n_seen


@torch.no_grad()
def valid_one_epoch(model, loader):
    """Evaluate ``model`` on ``loader`` without gradients; return the mean loss per sample.

    Batch losses (cross-entropy, IGNORE_INDEX pixels skipped) are weighted
    by batch size and divided by the number of samples actually seen, so
    the result is correct for any sampler / ``drop_last`` setting.

    Raises:
        ValueError: if the loader yields no samples (e.g. an empty
            validation split), instead of a bare ZeroDivisionError.
    """
    model.eval()
    total_loss = 0.0
    n_seen = 0
    for batch in loader:
        x = batch["image"].to(DEVICE)
        y = batch["mask"].to(DEVICE)

        logits = model(x)
        loss = F.cross_entropy(logits, y, ignore_index=IGNORE_INDEX)

        batch_size = x.size(0)
        total_loss += loss.item() * batch_size
        n_seen += batch_size

    if n_seen == 0:
        raise ValueError("valid_one_epoch: loader produced no samples")
    return total_loss / n_seen


# =========================
# 6. Inference & Submission
# =========================
@torch.no_grad()
def predict_and_make_submission(model, test_images_dir: Path, out_csv_path: Path):
    """Predict every ``*.npy`` patch in ``test_images_dir`` and write a submission CSV.

    Each patch becomes one CSV row indexed by its name: the argmax class map
    cropped to the patch's original width, flattened row-major. Rows of
    patches narrower than TARGET_W are right-padded with IGNORE_INDEX so
    every row has TARGET_H * TARGET_W entries.
    """
    model.eval()

    loader = DataLoader(
        WellSegDataset(test_images_dir, y_csv_path=None),
        batch_size=1,
        shuffle=False,
        num_workers=0,
    )
    row_len = TARGET_H * TARGET_W
    rows = {}

    for batch in loader:
        patch_name = batch["name"][0]
        width = int(batch["raw_w"][0])
        logits = model(batch["image"].to(DEVICE))
        classes = logits.argmax(dim=1).squeeze(0).cpu().numpy().astype(np.int64)

        flat = classes[:, :width].flatten()
        if width < TARGET_W:
            row = np.full((row_len,), IGNORE_INDEX, dtype=np.int64)
            row[: TARGET_H * width] = flat
            flat = row
        rows[patch_name] = flat

    submission = pd.DataFrame(rows, dtype="int64").T
    submission.to_csv(out_csv_path)
    print(f"[OK] Submission saved to: {out_csv_path}")


# =========================
# 7. Main
# =========================
def main():
    """Train UPerNet with a by-well split, keep the best checkpoint, write the submission.

    Patches from wells in VAL_WELLS form the validation set; all others are
    used for training. The weights with the lowest validation loss are
    saved to ``DATA_ROOT / "best_upernet.pth"`` and reloaded before
    predicting the test set into ``DATA_ROOT / "submission.csv"``.

    Raises:
        ValueError: if the split leaves the training or validation subset
            empty (e.g. wrong TRAIN_IMAGES_DIR or VAL_WELLS).
    """
    train_ds_all = WellSegDataset(TRAIN_IMAGES_DIR, Y_TRAIN_CSV)

    # Split by well id (whole wells held out), not by individual patch.
    VAL_WELLS = {6}
    train_indices, val_indices = [], []
    for i, name in enumerate(train_ds_all.names):
        w = parse_well_id(name)
        if w in VAL_WELLS:
            val_indices.append(i)
        else:
            train_indices.append(i)

    if not train_indices or not val_indices:
        raise ValueError(
            f"Empty split: {len(train_indices)} training / {len(val_indices)} validation "
            f"samples for val_wells={VAL_WELLS}; check TRAIN_IMAGES_DIR and VAL_WELLS."
        )

    train_ds = Subset(train_ds_all, train_indices)
    val_ds = Subset(train_ds_all, val_indices)

    # BatchNorm in training mode rejects a batch with a single sample when a
    # channel has only one value (the PSP branch pooled to 1x1). Drop the
    # last training batch only when it would be exactly that size-1 remainder.
    drop_last = len(train_ds) % BATCH_SIZE == 1
    train_loader = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, drop_last=drop_last
    )
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    print(f"Training samples: {len(train_ds)} | Validation samples: {len(val_ds)} | val_wells={VAL_WELLS}")
    print(f"DEVICE: {DEVICE}")

    model = UPerNet(num_classes=NUM_CLASSES).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    best_val = float("inf")
    best_path = DATA_ROOT / "best_upernet.pth"
    saved_best = False

    for epoch in range(1, EPOCHS + 1):
        tr_loss = train_one_epoch(model, train_loader, optimizer)
        va_loss = valid_one_epoch(model, val_loader)
        print(f"Epoch {epoch:02d}/{EPOCHS} | train_loss={tr_loss:.4f} | val_loss={va_loss:.4f}")

        if va_loss < best_val:
            best_val = va_loss
            torch.save(model.state_dict(), best_path)
            saved_best = True
            print(f"  -> Best model saved: {best_path}")

    out_csv = DATA_ROOT / "submission.csv"
    if saved_best:
        state_dict = torch.load(best_path, map_location=DEVICE, weights_only=True)
        model.load_state_dict(state_dict)
    else:
        # No epoch improved (EPOCHS == 0 or every val loss was NaN). Do not
        # pick up a checkpoint left behind by an earlier run.
        print("  -> No best checkpoint saved in this run; predicting with the final weights.")
    predict_and_make_submission(model, TEST_IMAGES_DIR, out_csv)


if __name__ == "__main__":
    main()


训练样本数: 2790 | 验证样本数: 1620 | val_wells={6}
DEVICE: cuda


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to C:\Users\lenovo/.cache\torch\hub\checkpoints\resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:09<00:00, 10.3MB/s]


Epoch 01/20 | train_loss=0.1318 | val_loss=0.1120
  -> 保存最优模型: C:\Users\lenovo\Desktop\deep_datachallenge\best_upernet.pth
Epoch 02/20 | train_loss=0.0745 | val_loss=0.0900
  -> 保存最优模型: C:\Users\lenovo\Desktop\deep_datachallenge\best_upernet.pth
Epoch 03/20 | train_loss=0.0664 | val_loss=0.0940
Epoch 04/20 | train_loss=0.0616 | val_loss=0.1032
Epoch 05/20 | train_loss=0.0576 | val_loss=0.1283
Epoch 06/20 | train_loss=0.0555 | val_loss=0.0903
Epoch 07/20 | train_loss=0.0538 | val_loss=0.0936
Epoch 08/20 | train_loss=0.0509 | val_loss=0.1156
Epoch 09/20 | train_loss=0.0481 | val_loss=0.0915
Epoch 10/20 | train_loss=0.0446 | val_loss=0.0991
Epoch 11/20 | train_loss=0.0449 | val_loss=0.1008
Epoch 12/20 | train_loss=0.0416 | val_loss=0.1002
Epoch 13/20 | train_loss=0.0402 | val_loss=0.1123
Epoch 14/20 | train_loss=0.0419 | val_loss=0.1042
Epoch 15/20 | train_loss=0.0380 | val_loss=0.1052
Epoch 16/20 | train_loss=0.0379 | val_loss=0.1320
Epoch 17/20 | train_loss=0.0359 | val_loss=0.1031
Epoc