In [None]:
"""
UPerNet + Swin Transformer backbone
FINAL VERSION (NO cv2, PyTorch interpolate only)

Key points:
- Swin requires 224x224 input
- Original data: (160,160) or (160,272)
- Strategy:
    (160,w) -> pad to (160,272) -> resize to (224,224) -> model
    output (224,224) -> resize back to (160,272) -> crop raw_w
- Each submission row stays 160*272 values long: the flattened (160, raw_w) prediction first, then -1 padding
"""

import re
from pathlib import Path
import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset

import segmentation_models_pytorch as smp


# =========================
# 0) Config
# =========================
# Root folder of the challenge data. NOTE: machine-specific absolute Windows
# path; it has to be edited before running on another machine.
DATA_ROOT = Path(r"C:\Users\lenovo\Desktop\deep_datachallenge")

# Folders scanned for "*.npy" images, and the CSV holding the training masks
# (one row per sample, indexed by the image file stem).
TRAIN_IMG_DIR = DATA_ROOT / "X_train_uDRk9z9" / "images"
TEST_IMG_DIR  = DATA_ROOT / "X_test_xNbnvIa" / "images"
Y_TRAIN_CSV   = DATA_ROOT / "Y_train_T9NrBYo.csv"

# submission size (fixed by challenge); narrower images are padded to this
# shape before being resized to the model input size
H_SUB, W_SUB = 160, 272

# model input size (Swin requirement)
H_MODEL, W_MODEL = 224, 224

# Number of output channels (classes) requested from the segmentation model.
NUM_CLASSES = 3
# Label value used for mask/submission padding; also passed to cross_entropy
# as ignore_index so padded pixels do not contribute to the loss.
IGNORE_INDEX = -1

BATCH_SIZE = 4        # RTX 4060 (8GB): 2~4 recommended
# AdamW hyper-parameters and number of training epochs.
LR = 1e-4
WEIGHT_DECAY = 1e-4
EPOCHS = 20

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# =========================
# 1) Utils
# =========================
def parse_well_id(name: str) -> int:
    """Return the well number embedded in a sample name.

    The number is the digit run of the first ``well_<digits>_`` occurrence
    found anywhere in ``name``. When the name contains no such pattern the
    sentinel ``-1`` is returned instead of raising.
    """
    match = re.search(r"well_(\d+)_", name)
    if match is None:
        return -1
    return int(match.group(1))


def minmax_norm(x: np.ndarray) -> np.ndarray:
    """Rescale an array to the [0, 1] range as float32.

    The input is cast to float32 and passed through ``np.nan_to_num`` before
    the minimum and maximum are measured, so NaNs take part in the range as
    zeros. An array whose value range is narrower than 1e-6 is mapped to all
    zeros to avoid dividing by a (near-)zero span.
    """
    values = np.nan_to_num(x.astype(np.float32))
    lowest = values.min()
    highest = values.max()

    # Guard against a flat image: the division below would blow up.
    if highest - lowest < 1e-6:
        return np.zeros_like(values)

    return (values - lowest) / (highest - lowest)


def resize_image_to_224(img: np.ndarray) -> np.ndarray:
    """Pad an image to the submission size, then resize it for the model.

    img: 2-D array whose height must match H_SUB and whose width must not
         exceed W_SUB (the slice assignment below fails otherwise)
    -> placed in the left part of a zero-filled (H_SUB, W_SUB) canvas
    -> bilinearly resized to (H_MODEL, W_MODEL) and returned as float32
    """
    canvas = np.zeros((H_SUB, W_SUB), dtype=img.dtype)
    _, width = img.shape
    canvas[:, :width] = img

    # F.interpolate expects a (batch, channel, H, W) float tensor.
    batch = torch.from_numpy(canvas)[None, None].float()
    resized = F.interpolate(
        batch,
        size=(H_MODEL, W_MODEL),
        mode="bilinear",
        align_corners=False,
    )
    return resized[0, 0].numpy()


def decode_mask(row: np.ndarray) -> np.ndarray:
    """Rebuild a 2-D label mask from one flattened, padded CSV row.

    Every entry equal to IGNORE_INDEX is dropped; the remaining entries are
    reshaped to (H_SUB, w) with w inferred from how many entries are left,
    and returned as int64. The reshape raises if that count is not a
    multiple of H_SUB.
    """
    is_label = row != IGNORE_INDEX
    labels = row[is_label]
    width = len(labels) // H_SUB
    mask = labels.reshape(H_SUB, width)
    return mask.astype(np.int64)


def resize_mask_to_224(mask: np.ndarray) -> np.ndarray:
    """Pad a label mask to the submission size, then resize it for the model.

    mask: 2-D array whose height must match H_SUB and whose width must not
          exceed W_SUB
    -> placed in the left part of an (H_SUB, W_SUB) canvas filled with
       IGNORE_INDEX
    -> resized to (H_MODEL, W_MODEL) with nearest-neighbour sampling, so no
       new label values are invented, and returned as int64
    """
    canvas = np.full((H_SUB, W_SUB), IGNORE_INDEX, dtype=np.int64)
    _, width = mask.shape
    canvas[:, :width] = mask

    # Interpolation works on floats; the labels are cast back afterwards.
    batch = torch.from_numpy(canvas)[None, None].float()
    resized = F.interpolate(batch, size=(H_MODEL, W_MODEL), mode="nearest")
    return resized[0, 0].long().numpy()


# =========================
# 2) Dataset
# =========================
class WellDataset(Dataset):
    """Dataset over the ``*.npy`` images of one folder, optionally labelled.

    Samples are ordered by sorted file path and named after the file stem.
    When ``y_csv`` is given it is read with its first column as index, and
    the row whose index equals the sample name supplies the flattened mask.

    Each item is a dict with:
      - "name":  the file stem
      - "image": float tensor of shape (1, H_MODEL, W_MODEL)
      - "mask":  long tensor of shape (H_MODEL, W_MODEL), labelled data only
      - "raw_w": width of the image as stored on disk
    """

    def __init__(self, img_dir: Path, y_csv: Path = None):
        self.paths = sorted(img_dir.glob("*.npy"))
        self.names = [path.stem for path in self.paths]
        self.has_label = y_csv is not None
        if self.has_label:
            self.y_df = pd.read_csv(y_csv, index_col=0)
        else:
            self.y_df = None

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        sample_name = self.names[idx]
        raw = np.load(self.paths[idx])  # expected to be 2-D: (160, w)
        original_width = raw.shape[1]

        # Normalise first, then pad + resize to the model input size.
        model_input = resize_image_to_224(minmax_norm(raw))
        sample = {
            "name": sample_name,
            "image": torch.from_numpy(model_input).unsqueeze(0).float(),
        }

        if self.has_label:
            flat_labels = self.y_df.loc[sample_name].values.astype(np.int64)
            model_mask = resize_mask_to_224(decode_mask(flat_labels))
            sample["mask"] = torch.from_numpy(model_mask).long()

        # Kept so predictions can be cropped back to the stored width.
        sample["raw_w"] = original_width
        return sample


# =========================
# 3) Model
# =========================
def build_model():
    """Create the UPerNet segmentation model with a Swin-Small encoder.

    The encoder is requested with ImageNet weights, the network takes
    single-channel input, and it outputs NUM_CLASSES channels of raw logits
    (no final activation).
    """
    settings = {
        "encoder_name": "tu-swin_small_patch4_window7_224",
        "encoder_weights": "imagenet",
        "in_channels": 1,
        "classes": NUM_CLASSES,
        "activation": None,
    }
    return smp.UPerNet(**settings)


# =========================
# 4) Train / Validate
# =========================
def train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Each batch dict must provide "image" and "mask". The loss is
    cross-entropy with IGNORE_INDEX pixels excluded. Every batch loss is
    weighted by its batch size and the sum is divided by the dataset length,
    giving a per-sample average.
    """
    model.train()
    running_loss = 0.0

    for batch in loader:
        images = batch["image"].to(DEVICE)
        targets = batch["mask"].to(DEVICE)

        batch_loss = F.cross_entropy(
            model(images), targets, ignore_index=IGNORE_INDEX
        )

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item() * images.size(0)

    return running_loss / len(loader.dataset)


@torch.no_grad()
def eval_one_epoch(model, loader):
    """Return the per-sample mean cross-entropy loss of ``model`` on ``loader``.

    Runs in eval mode with gradients disabled. Pixels labelled IGNORE_INDEX
    are excluded from the loss; batch losses are weighted by batch size and
    divided by the dataset length.
    """
    model.eval()
    running_loss = 0.0

    for batch in loader:
        images = batch["image"].to(DEVICE)
        targets = batch["mask"].to(DEVICE)
        batch_loss = F.cross_entropy(
            model(images), targets, ignore_index=IGNORE_INDEX
        )
        running_loss += batch_loss.item() * images.size(0)

    return running_loss / len(loader.dataset)


# =========================
# 5) Inference & Submission
# =========================
@torch.no_grad()
def predict_and_submit(model, out_csv: Path):
    """Predict every test image and write the submission CSV to ``out_csv``.

    For each sample the per-pixel argmax is taken at model resolution,
    resized back to (H_SUB, W_SUB) with nearest-neighbour sampling, cropped
    to the image's stored width, flattened, and written into a row of
    H_SUB * W_SUB values whose unused tail stays at IGNORE_INDEX. Rows are
    indexed by sample name in the output CSV.
    """
    model.eval()
    loader = DataLoader(WellDataset(TEST_IMG_DIR), batch_size=1, shuffle=False)

    rows = {}
    for batch in loader:
        sample_name = batch["name"][0]
        width = int(batch["raw_w"][0])

        logits = model(batch["image"].to(DEVICE))
        labels = torch.argmax(logits, dim=1)[0].cpu()  # (H_MODEL, W_MODEL)

        # resize back to (160,272); nearest keeps the values valid class ids
        resized = F.interpolate(
            labels[None, None].float(), size=(H_SUB, W_SUB), mode="nearest"
        )
        cropped = resized[0, 0].long().numpy()[:, :width]

        row = np.full((H_SUB * W_SUB,), IGNORE_INDEX, dtype=np.int64)
        row[: H_SUB * width] = cropped.ravel()
        rows[sample_name] = row

    # One column per sample in the dict, transposed to one row per sample.
    submission = pd.DataFrame(rows, dtype="int64").T
    submission.to_csv(out_csv)
    print(f"[OK] submission saved to {out_csv}")


# =========================
# 6) Main
# =========================
def main():
    """Train on every well except well 6, validate on well 6, then predict.

    The checkpoint with the lowest validation loss is saved under DATA_ROOT,
    reloaded after the last epoch, and used to write the submission CSV.
    """
    print("DEVICE:", DEVICE)

    full_ds = WellDataset(TRAIN_IMG_DIR, Y_TRAIN_CSV)

    # Hold out one whole well so validation data never shares a well with
    # the training data.
    well_ids = [parse_well_id(sample_name) for sample_name in full_ds.names]
    val_idx = [i for i, well in enumerate(well_ids) if well == 6]
    train_idx = [i for i, well in enumerate(well_ids) if well != 6]

    train_loader = DataLoader(
        Subset(full_ds, train_idx),  # every well other than 6 (wells 1-5)
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
    val_loader = DataLoader(
        Subset(full_ds, val_idx),    # well 6
        batch_size=BATCH_SIZE,
    )

    model = build_model().to(DEVICE)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY
    )

    best_path = DATA_ROOT / "best_upernet_swin.pth"
    best_val = 1e9

    for epoch in range(1, EPOCHS + 1):
        tr = train_one_epoch(model, train_loader, optimizer)
        va = eval_one_epoch(model, val_loader)
        print(f"Epoch {epoch:02d}/{EPOCHS} | train={tr:.4f} | val={va:.4f}")

        if va >= best_val:
            continue

        # New best validation loss: keep this checkpoint.
        best_val = va
        torch.save(model.state_dict(), best_path)
        print("  -> best model saved")

    # Predict with the best checkpoint rather than the last epoch's weights.
    best_state = torch.load(best_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(best_state)
    predict_and_submit(model, DATA_ROOT / "submission.csv")


# Run the full train -> validate -> predict pipeline only when this file is
# executed as a script, not when it is imported.
if __name__ == "__main__":
    main()


DEVICE: cuda
Epoch 01/20 | train=0.1124 | val=0.0792
  -> best model saved
Epoch 02/20 | train=0.0740 | val=0.0748
  -> best model saved
Epoch 03/20 | train=0.0674 | val=0.0782
Epoch 04/20 | train=0.0624 | val=0.0876
Epoch 05/20 | train=0.0597 | val=0.0759
Epoch 06/20 | train=0.0535 | val=0.0836
Epoch 07/20 | train=0.0491 | val=0.0806
Epoch 08/20 | train=0.0463 | val=0.0891
Epoch 09/20 | train=0.0423 | val=0.0952
Epoch 10/20 | train=0.0395 | val=0.1013
Epoch 11/20 | train=0.0362 | val=0.0972
Epoch 12/20 | train=0.0334 | val=0.1049
Epoch 13/20 | train=0.0334 | val=0.1016
Epoch 14/20 | train=0.0295 | val=0.1128
Epoch 15/20 | train=0.0279 | val=0.1291
Epoch 16/20 | train=0.0273 | val=0.1160
Epoch 17/20 | train=0.0250 | val=0.1186
Epoch 18/20 | train=0.0289 | val=0.1225
Epoch 19/20 | train=0.0227 | val=0.1314
Epoch 20/20 | train=0.0234 | val=0.1459
[OK] submission saved to C:\Users\lenovo\Desktop\deep_datachallenge\submission.csv
