In [2]:
# -*- coding: utf-8 -*-
"""
Swin-UNet (Swin Transformer encoder + UNet decoder) - Full runnable code (224x224 input for Swin)

Pipeline:
- Train images:   X_train_uDRk9z9/images (well1-6)
- Test images:    X_test_xNbnvIa/images  (well7-11)
- Train labels:   Y_train_T9NrBYo.csv (flatten + -1 padding)
- Validation split: well6 as val, well1-5 as train
- Output: submission.csv (each row = one patch, flattened, padded to 160*272 with -1)

Key point:
- Swin encoders in timm/SMP often expect 224x224 inputs.
- We pad each image/mask to 160x272, then resize it to 224x224 for training and inference.
- During inference we upsample logits back to (160,272), crop to the raw width, flatten, and pad with -1 to match the submission format.

Dependencies:
    pip install timm segmentation-models-pytorch
"""

import re
from pathlib import Path

import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset

import timm
import segmentation_models_pytorch as smp


# =========================
# 0. Paths & Hyperparameters
# =========================
# Root folder holding the challenge data; every other path below is derived from it.
DATA_ROOT = Path(r"C:\Users\lenovo\Desktop\deep_datachallenge")  # change to your path

# Folders scanned for *.npy images, and the CSV holding the flattened training masks.
TRAIN_IMAGES_DIR = DATA_ROOT / "X_train_uDRk9z9" / "images"
TEST_IMAGES_DIR = DATA_ROOT / "X_test_xNbnvIa" / "images"
Y_TRAIN_CSV = DATA_ROOT / "Y_train_T9NrBYo.csv"

# Original target size used by dataset/submission format
# (images and masks are padded to this size; each submission row has TARGET_H * TARGET_W values).
TARGET_H = 160
TARGET_W = 272

# Model input size for Swin (images and masks are resized to this before entering the network).
MODEL_H = 224
MODEL_W = 224

# Number of output classes requested from the model.
NUM_CLASSES = 3
# Label value used for padding; passed as ignore_index to the cross-entropy loss.
IGNORE_INDEX = -1

# Optimisation settings consumed in main(): DataLoader batch size, Adam lr / weight decay, epoch count.
BATCH_SIZE = 4          # for RTX 4060(8GB), start with 2~4
LR = 1e-4
WEIGHT_DECAY = 1e-4
EPOCHS = 20

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# =========================
# 1. Utils
# =========================
def parse_well_id(name: str) -> int:
    """Return the numeric well id embedded in a patch name.

    Example: ``well_1_section_0_patch_0`` -> ``1``.
    A name that contains no ``well_<digits>_`` fragment yields ``-1``.
    """
    match = re.search(r"well_(\d+)_", name)
    if match is None:
        return -1
    return int(match.group(1))


def minmax_normalize(x: np.ndarray) -> np.ndarray:
    """Scale an array to [0, 1] as float32.

    NaN and +/-inf entries are set to 0 *before* the minimum and maximum
    are taken, so they take part in the range computation as zeros.
    An (almost) constant array -- value range below 1e-6 -- maps to all zeros.
    """
    cleaned = np.nan_to_num(x.astype(np.float32), nan=0.0, posinf=0.0, neginf=0.0)
    lo = float(cleaned.min())
    hi = float(cleaned.max())
    span = hi - lo
    if span < 1e-6:
        # Avoid dividing by (near) zero for flat inputs.
        return np.zeros_like(cleaned, dtype=np.float32)
    return (cleaned - lo) / span


def pad_to_160x272(img: np.ndarray, fill_value: float = 0.0) -> np.ndarray:
    """Bring a 2-D image of height TARGET_H to shape (TARGET_H, TARGET_W).

    - width == TARGET_W: the input array itself is returned (no copy);
    - width <  TARGET_W: the image is placed at the left and the remaining
      columns are filled with ``fill_value``;
    - width >  TARGET_W: the extra columns on the right are cut off.

    Raises AssertionError when the height is not TARGET_H.
    """
    rows, cols = img.shape
    assert rows == TARGET_H, f"Expected height {TARGET_H}, got {rows}"

    if cols > TARGET_W:
        return img[:, :TARGET_W]
    if cols == TARGET_W:
        return img

    canvas = np.full((TARGET_H, TARGET_W), fill_value, dtype=img.dtype)
    canvas[:, :cols] = img
    return canvas


def decode_mask_from_csv_row(row_values: np.ndarray) -> np.ndarray:
    """Rebuild a 2-D mask from one flattened CSV row.

    Every entry equal to IGNORE_INDEX (-1) is treated as padding and dropped;
    the remaining values are reshaped to (TARGET_H, w) with
    w = n_valid // TARGET_H and returned as int64.

    Raises AssertionError when the number of remaining values is not a
    multiple of TARGET_H.
    """
    keep = row_values != IGNORE_INDEX
    labels = row_values[keep]
    n_valid = len(labels)
    assert n_valid % TARGET_H == 0, f"Valid mask length {n_valid} not divisible by 160"
    width = n_valid // TARGET_H
    return labels.reshape(TARGET_H, width).astype(np.int64)


def pad_mask_to_160x272(mask: np.ndarray) -> np.ndarray:
    """Bring a (TARGET_H, w) mask to shape (TARGET_H, TARGET_W).

    - w == TARGET_W: the input array itself is returned (no copy);
    - w <  TARGET_W: the mask is placed at the left and the remaining
      columns are filled with IGNORE_INDEX (-1) so the loss skips them;
    - w >  TARGET_W: the extra columns on the right are cut off, exactly as
      pad_to_160x272 does for the image, so image and mask stay aligned
      (previously this case failed with a broadcasting error).

    Raises AssertionError when the height is not TARGET_H.
    """
    h, w = mask.shape
    assert h == TARGET_H, f"Expected height {TARGET_H}, got {h}"
    if w == TARGET_W:
        return mask
    if w > TARGET_W:
        # Mirror the cropping rule of pad_to_160x272.
        return mask[:, :TARGET_W]
    out = np.full((TARGET_H, TARGET_W), IGNORE_INDEX, dtype=np.int64)
    out[:, :w] = mask
    return out


def resize_image_torch(img_t_1hw: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Bilinearly resize a single-image tensor (1,H,W) to (1,h,w).

    A temporary batch axis is added because F.interpolate works on
    (N,C,H,W) input for 2-D bilinear resizing; it is removed again before
    returning.
    """
    batched = img_t_1hw[None]  # (1,1,H,W)
    resized = F.interpolate(batched, size=(h, w), mode="bilinear", align_corners=False)
    return resized[0]          # (1,h,w)


def resize_mask_torch(mask_t_hw: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Resize a label mask (H,W) to (h,w) with nearest-neighbour sampling.

    Nearest-neighbour copies existing values only, so no new label values
    are invented and IGNORE_INDEX entries stay IGNORE_INDEX. The mask is
    cast to float for the interpolation call and back to int64 afterwards.
    """
    as_float = mask_t_hw.float()[None, None]  # (1,1,H,W)
    scaled = F.interpolate(as_float, size=(h, w), mode="nearest")
    return scaled[0, 0].long()                # (h,w)


# =========================
# 2. Dataset
# =========================
class WellSegDataset(Dataset):
    """Dataset over the ``*.npy`` images of one folder, optionally with masks.

    Each sample is a dict with:
    - "name":  file stem of the image;
    - "image": float tensor (1, MODEL_H, MODEL_W), min-max normalised,
               padded to (TARGET_H, TARGET_W) and then resized;
    - "raw_w": width of the image as stored on disk (used to crop
               predictions back when writing the submission);
    - "mask":  int64 tensor (MODEL_H, MODEL_W), only when labels are loaded.
    """

    def __init__(self, images_dir: Path, y_csv_path: Path = None):
        """
        Index the images in ``images_dir`` (sorted by path).

        If y_csv_path is None => test mode (no labels). Otherwise the CSV is
        read with its first column as index; rows are looked up by image name.
        """
        self.images_dir = images_dir
        self.has_label = y_csv_path is not None

        self.image_paths = sorted(images_dir.glob("*.npy"))
        self.names = [path.stem for path in self.image_paths]

        self.y_df = pd.read_csv(y_csv_path, index_col=0) if self.has_label else None

    def __len__(self):
        return len(self.image_paths)

    def _load_image(self, idx: int):
        """Load image ``idx``; return (model-sized tensor, width on disk)."""
        raw = np.load(self.image_paths[idx])     # (160,160) or (160,272)
        raw_w = raw.shape[1]
        padded = pad_to_160x272(minmax_normalize(raw), fill_value=0.0)
        tensor = torch.from_numpy(padded).unsqueeze(0).float()  # (1,160,272)
        return resize_image_torch(tensor, MODEL_H, MODEL_W), raw_w

    def _load_mask(self, name: str) -> torch.Tensor:
        """Decode the CSV row of ``name`` into a model-sized int64 mask."""
        flat = self.y_df.loc[name].values.astype(np.int64)
        full = pad_mask_to_160x272(decode_mask_from_csv_row(flat))  # (160,272)
        return resize_mask_torch(torch.from_numpy(full).long(), MODEL_H, MODEL_W)

    def __getitem__(self, idx: int):
        name = self.names[idx]
        image, raw_w = self._load_image(idx)

        if self.has_label:
            mask = self._load_mask(name)
            return {"name": name, "image": image, "mask": mask, "raw_w": raw_w}

        return {"name": name, "image": image, "raw_w": raw_w}


# =========================
# 3. Swin-UNet model (SMP + timm)
# =========================
def choose_swin_encoder_name() -> str:
    """
    Pick the SMP encoder name of the first Swin variant that timm lists.

    Variants are tried from smallest to largest (tiny, small, base). The
    returned name is the timm model name with the 'tu-' prefix in front
    (presumably the prefix SMP uses to select timm-backed encoders -- see the
    SMP documentation). If timm lists none of them, the tiny variant's name is
    returned anyway and any failure surfaces when the model is built.
    """
    # Only the timm names are listed once: un-prefixed duplicates of these
    # entries could never be selected, because the prefixed entry with the
    # same timm name is always checked (and returned) first.
    timm_names = (
        "swin_tiny_patch4_window7_224",
        "swin_small_patch4_window7_224",
        "swin_base_patch4_window7_224",
    )

    available = set(timm.list_models())
    for timm_name in timm_names:
        if timm_name in available:
            return f"tu-{timm_name}"

    return "tu-swin_tiny_patch4_window7_224"


def build_swin_unet(num_classes: int) -> torch.nn.Module:
    """Build an SMP U-Net on top of the Swin encoder chosen by
    choose_swin_encoder_name().

    The network takes single-channel input and returns raw logits
    (``activation=None``) with ``num_classes`` output channels.
    """
    # If your machine cannot download weights, set encoder_weights=None
    return smp.Unet(
        encoder_name=choose_swin_encoder_name(),
        encoder_weights="imagenet",
        in_channels=1,
        classes=num_classes,
        activation=None,
    )


# =========================
# 4. Train / Validate
# =========================
def train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader``.

    Returns the epoch loss: each batch's mean cross-entropy (pixels labelled
    IGNORE_INDEX excluded) weighted by batch size, divided by the number of
    samples in ``loader.dataset``.
    """
    model.train()
    running = 0.0

    for batch in loader:
        images = batch["image"].to(DEVICE)   # (B,1,224,224)
        targets = batch["mask"].to(DEVICE)   # (B,224,224)

        # Forward pass yields (B,C,224,224) logits.
        loss = F.cross_entropy(model(images), targets, ignore_index=IGNORE_INDEX)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running += loss.item() * images.size(0)

    return running / len(loader.dataset)


@torch.no_grad()
def valid_one_epoch(model, loader):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns the loss computed the same way as in training: per-batch mean
    cross-entropy (IGNORE_INDEX pixels excluded) weighted by batch size and
    divided by the number of samples in ``loader.dataset``.
    """
    model.eval()
    running = 0.0

    for batch in loader:
        images = batch["image"].to(DEVICE)
        targets = batch["mask"].to(DEVICE)

        batch_loss = F.cross_entropy(model(images), targets, ignore_index=IGNORE_INDEX)
        running += batch_loss.item() * images.size(0)

    return running / len(loader.dataset)


# =========================
# 5. Inference & submission
# =========================
@torch.no_grad()
def predict_and_make_submission(model, test_images_dir: Path, out_csv_path: Path):
    """
    Predict every .npy image in test_images_dir and write the submission CSV.

    Per image:
    - the model is run on the 224x224 input produced by WellSegDataset;
    - logits are bilinearly upsampled to (TARGET_H, TARGET_W) and arg-maxed;
    - the label map is cropped to the image's raw width, flattened, and
      right-padded with IGNORE_INDEX (-1) up to TARGET_H * TARGET_W values.

    The CSV has one row per image, indexed by image name.
    """
    model.eval()

    loader = DataLoader(
        WellSegDataset(test_images_dir, y_csv_path=None),
        batch_size=1,
        shuffle=False,
        num_workers=0,
    )

    rows = {}
    for batch in loader:
        # batch_size is 1, so element 0 is the only sample in the batch.
        sample_name = batch["name"][0]
        width = int(batch["raw_w"][0])

        logits_small = model(batch["image"].to(DEVICE))  # (1,C,224,224)
        logits_full = F.interpolate(
            logits_small, size=(TARGET_H, TARGET_W), mode="bilinear", align_corners=False
        )
        labels = torch.argmax(logits_full, dim=1)[0].cpu().numpy().astype(np.int64)  # (160,272)

        flat = labels[:, :width].flatten()  # crop back to original width

        if width >= TARGET_W:
            # Already a full-length row; nothing to pad.
            rows[sample_name] = flat
            continue

        row = np.full((TARGET_H * TARGET_W,), IGNORE_INDEX, dtype=np.int64)
        row[: TARGET_H * width] = flat
        rows[sample_name] = row

    sub = pd.DataFrame(rows, dtype="int64").T
    sub.to_csv(out_csv_path)
    print(f"[OK] submission saved to: {out_csv_path}")


# =========================
# 6. Main
# =========================
def main():
    """Train on every well except the validation well, keep the checkpoint
    with the lowest validation loss, then predict the test set and write
    submission.csv under DATA_ROOT."""
    print(f"DEVICE: {DEVICE}")
    print(f"Train dir: {TRAIN_IMAGES_DIR}")
    print(f"Test dir:  {TEST_IMAGES_DIR}")
    print(f"Model input size: {MODEL_H}x{MODEL_W}")

    # (A) Load all labelled data
    full_ds = WellSegDataset(TRAIN_IMAGES_DIR, Y_TRAIN_CSV)

    # (B) Split by well: samples whose well id is in val_wells form the validation set
    val_wells = {6}
    well_ids = [parse_well_id(name) for name in full_ds.names]
    val_indices = [i for i, well in enumerate(well_ids) if well in val_wells]
    train_indices = [i for i, well in enumerate(well_ids) if well not in val_wells]

    train_ds = Subset(full_ds, train_indices)
    val_ds = Subset(full_ds, val_indices)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    print(f"Train samples: {len(train_ds)} | Val samples: {len(val_ds)} | val_wells={val_wells}")

    # (C) Model & optimizer
    model = build_swin_unet(num_classes=NUM_CLASSES).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    # (D) Train, checkpointing whenever the validation loss improves
    best_path = DATA_ROOT / "best_swin_unet.pth"
    best_val = 1e9

    for epoch in range(1, EPOCHS + 1):
        tr_loss = train_one_epoch(model, train_loader, optimizer)
        va_loss = valid_one_epoch(model, val_loader)

        print(f"Epoch {epoch:02d}/{EPOCHS} | train_loss={tr_loss:.4f} | val_loss={va_loss:.4f}")

        if not va_loss < best_val:
            continue
        best_val = va_loss
        torch.save(model.state_dict(), best_path)
        print(f"  -> Best model saved: {best_path}")

    # (E) Reload the best checkpoint, predict the test set, write the submission
    model.load_state_dict(torch.load(best_path, map_location=DEVICE, weights_only=True))
    predict_and_make_submission(model, TEST_IMAGES_DIR, DATA_ROOT / "submission.csv")

# Run the full train -> validate -> predict pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()


DEVICE: cuda
Train dir: C:\Users\lenovo\Desktop\deep_datachallenge\X_train_uDRk9z9\images
Test dir:  C:\Users\lenovo\Desktop\deep_datachallenge\X_test_xNbnvIa\images
Model input size: 224x224
Train samples: 2790 | Val samples: 1620 | val_wells={6}
Epoch 01/20 | train_loss=0.2381 | val_loss=0.0994
  -> Best model saved: C:\Users\lenovo\Desktop\deep_datachallenge\best_swin_unet.pth
Epoch 02/20 | train_loss=0.0835 | val_loss=0.0767
  -> Best model saved: C:\Users\lenovo\Desktop\deep_datachallenge\best_swin_unet.pth
Epoch 03/20 | train_loss=0.0714 | val_loss=0.0793
Epoch 04/20 | train_loss=0.0675 | val_loss=0.0746
  -> Best model saved: C:\Users\lenovo\Desktop\deep_datachallenge\best_swin_unet.pth
Epoch 05/20 | train_loss=0.0666 | val_loss=0.0761
Epoch 06/20 | train_loss=0.0645 | val_loss=0.0755
Epoch 07/20 | train_loss=0.0617 | val_loss=0.0773
Epoch 08/20 | train_loss=0.0612 | val_loss=0.0849
Epoch 09/20 | train_loss=0.0611 | val_loss=0.0784
Epoch 10/20 | train_loss=0.0592 | val_loss=0.07