In [1]:
# -*- coding: utf-8 -*-
"""
Attention U-Net - Full runnable semantic segmentation code (no extra libs)

Dataset:
- Train images: X_train_uDRk9z9/images (well1-6)
- Train labels: Y_train_T9NrBYo.csv (flatten + -1 padding)
- Test images:  X_test_xNbnvIa/images  (well7-11)

Split:
- Train: well1-5
- Val:   well6

Output:
- submission.csv (each row = one patch name, flattened, padded to 160*272 with -1)

Notes:
- Input patches are (160,160) or (160,272). We pad to (160,272).
- Class ids: 0/1/2; padding is -1 (ignored in loss).
"""

import re
from pathlib import Path

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset


# =========================
# 0) Paths & Hyperparameters
# =========================
# Root folder that holds the challenge data -- change to your path.
DATA_ROOT = Path(r"C:\Users\lenovo\Desktop\deep_datachallenge")

# Image folders and the label CSV, all located under DATA_ROOT.
TRAIN_IMAGES_DIR = DATA_ROOT.joinpath("X_train_uDRk9z9", "images")
TEST_IMAGES_DIR = DATA_ROOT.joinpath("X_test_xNbnvIa", "images")
Y_TRAIN_CSV = DATA_ROOT.joinpath("Y_train_T9NrBYo.csv")

# Common canvas size (height, width) every patch and mask is brought to.
H, W = 160, 272
NUM_CLASSES = 3
IGNORE_INDEX = -1  # label value of padded pixels; passed as ignore_index to the loss

# Optimisation settings.
# Attention U-Net is lighter than Swin/Mask2Former; 8 is often OK on 4060(8GB)
BATCH_SIZE = 8
LR = 1e-3
WEIGHT_DECAY = 1e-4
EPOCHS = 20

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# =========================
# 1) Utils
# =========================
def parse_well_id(name: str) -> int:
    """Return the integer well id embedded in a patch name.

    Example: ``well_1_section_0_patch_0`` -> ``1``. A name without a
    ``well_<digits>_`` fragment yields ``-1``.
    """
    match = re.search(r"well_(\d+)_", name)
    if match is None:
        return -1
    return int(match.group(1))


def minmax_normalize(x: np.ndarray) -> np.ndarray:
    """Rescale an array to the [0, 1] range as float32.

    NaN, +inf and -inf entries are replaced by 0 *before* the min and max are
    measured. If the resulting value range is smaller than 1e-6, an all-zero
    array of the same shape is returned instead of dividing by (almost) zero.
    """
    cleaned = np.nan_to_num(
        x.astype(np.float32), nan=0.0, posinf=0.0, neginf=0.0
    )
    lowest = float(cleaned.min())
    span = float(cleaned.max()) - lowest
    if span < 1e-6:
        return np.zeros_like(cleaned, dtype=np.float32)
    return (cleaned - lowest) / span


def pad_to_160x272(img: np.ndarray, fill_value: float = 0.0) -> np.ndarray:
    """Bring a 2-D image of height H to width W.

    - width == W: the input array is returned unchanged (same object).
    - width <  W: a new (H, W) array of the input dtype is returned, with the
      image in the left columns and ``fill_value`` in the remaining ones.
    - width >  W: the first W columns are returned.

    The height must already equal H (checked with an assert).
    """
    height, width = img.shape
    assert height == H, f"Expected height {H}, got {height}"

    if width > W:
        return img[:, :W]
    if width == W:
        return img

    canvas = np.full((H, W), fill_value, dtype=img.dtype)
    canvas[:, :width] = img
    return canvas


def decode_mask_from_csv_row(row_values: np.ndarray) -> np.ndarray:
    """Turn one flattened CSV label row into a 2-D int64 mask.

    Entries equal to IGNORE_INDEX are treated as padding and dropped; what is
    left is reshaped row-major into (H, w), where w = remaining length // H.
    The remaining length must be a multiple of H (checked with an assert).
    """
    is_label = row_values != IGNORE_INDEX
    labels = row_values[is_label]
    n_labels = len(labels)
    assert n_labels % H == 0, f"Valid mask length {n_labels} not divisible by 160"
    return labels.reshape(H, n_labels // H).astype(np.int64)


def pad_mask_to_160x272(mask: np.ndarray) -> np.ndarray:
    """Bring a 2-D label mask of height H to width W.

    - width == W: the input array is returned unchanged (same object).
    - width <  W: a new int64 (H, W) array is returned with the mask in the
      left columns and IGNORE_INDEX in the remaining ones, so padded pixels
      can be excluded from the loss.
    - width >  W: the first W columns are returned. This mirrors how
      ``pad_to_160x272`` treats over-wide images, so an image and its mask
      stay aligned instead of the mask failing with a broadcast error.

    The height must already equal H (checked with an assert).
    """
    h, w = mask.shape
    assert h == H, f"Expected height {H}, got {h}"
    if w == W:
        return mask
    if w > W:
        # Crop exactly like the image helper does.
        return mask[:, :W]
    out = np.full((H, W), IGNORE_INDEX, dtype=np.int64)
    out[:, :w] = mask
    return out


# =========================
# 2) Dataset
# =========================
class WellSegDataset(Dataset):
    """Dataset over a folder of ``.npy`` patches, optionally with labels.

    Each item is a dict with:
    - ``"name"``:  file stem of the patch,
    - ``"image"``: float tensor (1, H, W), min-max normalized and padded,
    - ``"raw_w"``: width of the patch as stored on disk,
    - ``"mask"``:  long tensor (H, W), only when a label CSV was given.
    """

    def __init__(self, images_dir: Path, y_csv_path: Path = None, wells=None):
        """
        images_dir: folder scanned (non-recursively) for ``*.npy`` files.
        y_csv_path: label CSV whose first column is the patch name;
                    None => test mode (no masks are returned).
        wells: optional iterable of well ids; when given, only patches whose
               name parses to one of these ids are kept.
        """
        self.images_dir = images_dir
        self.has_label = y_csv_path is not None
        self.wells = None if wells is None else set(wells)

        # Sorted so that item order is reproducible across runs.
        candidates = sorted(images_dir.glob("*.npy"))
        if self.wells is not None:
            candidates = [
                path for path in candidates
                if parse_well_id(path.stem) in self.wells
            ]
        self.image_paths = candidates
        self.names = [path.stem for path in candidates]

        if self.has_label:
            self.y_df = pd.read_csv(y_csv_path, index_col=0)
        else:
            self.y_df = None

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx: int):
        name = self.names[idx]

        raw = np.load(self.image_paths[idx])       # (160,160) or (160,272)
        raw_w = int(raw.shape[1])                  # remembered for cropping later
        padded = pad_to_160x272(minmax_normalize(raw), fill_value=0.0)
        image = torch.from_numpy(padded).unsqueeze(0).float()  # (1,160,272)

        if not self.has_label:
            return {"name": name, "image": image, "raw_w": raw_w}

        # Label row is looked up by patch name in the CSV index.
        flat_labels = self.y_df.loc[name].values.astype(np.int64)
        mask = pad_mask_to_160x272(decode_mask_from_csv_row(flat_labels))
        target = torch.from_numpy(mask).long()     # (160,272)

        return {"name": name, "image": image, "mask": target, "raw_w": raw_w}


# =========================
# 3) Attention U-Net building blocks
# =========================
class DoubleConv(nn.Module):
    """Two stacked stages of 3x3 Conv (no bias) -> BatchNorm -> ReLU.

    The first stage maps ``in_ch`` to ``out_ch`` channels, the second keeps
    ``out_ch``. ``padding=1`` keeps the spatial size unchanged.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # Kept as a flat Sequential so parameter names stay net.0, net.1, ...
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)


class Down(nn.Module):
    """Encoder stage: 2x2 max-pool followed by a DoubleConv."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_ch, out_ch)
        self.net = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.net(x)


class Up(nn.Module):
    """Decoder stage: bilinear resize, concat with the skip, then DoubleConv.

    ``in_ch`` must equal skip channels + decoder channels, because the two
    tensors are concatenated along the channel axis before the convolutions.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x, skip):
        # Resize to the skip's exact spatial size (not a fixed x2), so odd
        # sizes produced by pooling still line up.
        target_hw = skip.shape[-2:]
        upsampled = F.interpolate(
            x, size=target_hw, mode="bilinear", align_corners=False
        )
        merged = torch.cat([skip, upsampled], dim=1)  # skip channels first
        return self.conv(merged)


class AttentionGate(nn.Module):
    """
    Attention Gate (AG) from Attention U-Net.

    Inputs to ``forward``:
    - g: gating signal from the decoder (F_g channels)
    - x: skip connection from the encoder (F_l channels)

    Both are projected to F_int channels with 1x1 convs, summed, passed
    through ReLU, and reduced to a single-channel sigmoid map that rescales
    ``x``. The output has the same shape as ``x``.
    """
    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = self._conv1x1_bn(F_g, F_int)
        self.W_x = self._conv1x1_bn(F_l, F_int)
        self.psi = nn.Sequential(
            *self._conv1x1_bn(F_int, 1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _conv1x1_bn(in_ch, out_ch):
        """Bias-free 1x1 convolution followed by batch norm."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_ch),
        )

    def forward(self, g, x):
        gate_feat = self.W_g(g)
        skip_feat = self.W_x(x)

        # g and x may have different spatial sizes; resize g's projection to x's
        same_size = gate_feat.shape[-2:] == skip_feat.shape[-2:]
        if not same_size:
            gate_feat = F.interpolate(
                gate_feat, size=skip_feat.shape[-2:],
                mode="bilinear", align_corners=False,
            )

        combined = self.relu(gate_feat + skip_feat)
        attention = self.psi(combined)  # (B,1,H,W) attention mask in [0,1]
        return x * attention


class AttentionUNet(nn.Module):
    """U-Net with an attention gate on every skip connection.

    The encoder has one DoubleConv stem and four Down stages; channel widths
    are ``base * (1, 2, 4, 8, 16)``. Each decoder step first gates the encoder
    skip with the current decoder features, then upsamples and fuses.
    ``forward`` returns raw class logits at the input's spatial size.
    """
    def __init__(self, in_channels=1, num_classes=3, base=32):
        super().__init__()
        # Channel width at each of the five resolution levels.
        c1, c2, c3, c4, c5 = (base * (2 ** level) for level in range(5))

        # Encoder
        self.inc = DoubleConv(in_channels, c1)
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c3)
        self.down3 = Down(c3, c4)
        self.down4 = Down(c4, c5)

        # Decoder + attention gates (deepest level first)
        self.att4 = AttentionGate(F_g=c5, F_l=c4, F_int=c3)
        self.up4 = Up(in_ch=c5 + c4, out_ch=c4)

        self.att3 = AttentionGate(F_g=c4, F_l=c3, F_int=c2)
        self.up3 = Up(in_ch=c4 + c3, out_ch=c3)

        self.att2 = AttentionGate(F_g=c3, F_l=c2, F_int=c1)
        self.up2 = Up(in_ch=c3 + c2, out_ch=c2)

        self.att1 = AttentionGate(F_g=c2, F_l=c1, F_int=base // 2)
        self.up1 = Up(in_ch=c2 + c1, out_ch=c1)

        # 1x1 conv mapping the last decoder features to per-class logits.
        self.outc = nn.Conv2d(c1, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder: every Down stage halves H and W via 2x2 max-pooling.
        e1 = self.inc(x)        # (B,base,   H,   W)
        e2 = self.down1(e1)     # (B,base*2, H/2, W/2)
        e3 = self.down2(e2)     # (B,base*4, H/4, W/4)
        e4 = self.down3(e3)     # (B,base*8, H/8, W/8)
        bottom = self.down4(e4) # (B,base*16,H/16,W/16)

        # Decoder: gate the skip with the current decoder features, then fuse.
        dec = self.up4(bottom, self.att4(g=bottom, x=e4))
        dec = self.up3(dec, self.att3(g=dec, x=e3))
        dec = self.up2(dec, self.att2(g=dec, x=e2))
        dec = self.up1(dec, self.att1(g=dec, x=e1))

        return self.outc(dec)   # (B,C,H,W)


# =========================
# 4) Train / Validate
# =========================
def train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader`` and return the average loss.

    Each batch's cross-entropy (pixels labelled IGNORE_INDEX are skipped) is
    weighted by the batch size, and the sum is divided by the number of
    samples in ``loader.dataset``.
    """
    model.train()
    running_loss = 0.0

    for batch in loader:
        images = batch["image"].to(DEVICE)    # (B,1,160,272)
        targets = batch["mask"].to(DEVICE)    # (B,160,272)

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        loss = F.cross_entropy(
            model(images), targets, ignore_index=IGNORE_INDEX
        )                                     # logits are (B,C,160,272)
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        running_loss += float(loss.item()) * batch_size

    return running_loss / len(loader.dataset)


@torch.no_grad()
def valid_one_epoch(model, loader):
    """Evaluate ``model`` on ``loader`` and return the average loss.

    Runs in eval mode with gradients disabled. The per-batch cross-entropy
    (pixels labelled IGNORE_INDEX are skipped) is weighted by batch size and
    divided by the number of samples in ``loader.dataset``.
    """
    model.eval()
    running_loss = 0.0

    for batch in loader:
        images = batch["image"].to(DEVICE)
        targets = batch["mask"].to(DEVICE)

        batch_loss = F.cross_entropy(
            model(images), targets, ignore_index=IGNORE_INDEX
        )
        running_loss += float(batch_loss.item()) * images.size(0)

    return running_loss / len(loader.dataset)


# =========================
# 5) Inference & submission
# =========================
@torch.no_grad()
def predict_and_make_submission(model, test_images_dir: Path, out_csv_path: Path):
    """
    Predict every patch in ``test_images_dir`` and write the submission CSV.

    For each patch:
    - take the argmax over the class axis of the model's logits,
    - keep only the first ``raw_w`` columns (the patch's width on disk),
    - flatten row-major and, if the patch is narrower than W, fill the tail
      with IGNORE_INDEX so every row has H*W entries.

    The CSV has one row per patch, indexed by patch name.
    """
    model.eval()

    dataset = WellSegDataset(test_images_dir, y_csv_path=None)
    # batch_size=1 so each patch can be cropped to its own original width.
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

    rows = {}

    for batch in loader:
        patch_name = batch["name"][0]
        width = int(batch["raw_w"][0])

        logits = model(batch["image"].to(DEVICE))
        labels = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy().astype(np.int64)  # (160,272)

        flat_pred = labels[:, :width].flatten()  # crop to original width

        if width >= W:
            # Already a full-width row; nothing to pad.
            rows[patch_name] = flat_pred
            continue

        row = np.full((H * W,), IGNORE_INDEX, dtype=np.int64)
        row[: H * width] = flat_pred
        rows[patch_name] = row

    # Dict keys become columns, so transpose to get one row per patch.
    sub = pd.DataFrame(rows, dtype="int64").T
    sub.to_csv(out_csv_path)
    print(f"[OK] submission saved to: {out_csv_path}")


# =========================
# 6) Main
# =========================
def main():
    """Train on wells other than 6, validate on well 6, then predict the test set.

    Steps:
    1. Load every labelled training patch and split it by well id.
    2. Train an AttentionUNet for EPOCHS epochs with Adam, saving the weights
       whenever the validation loss improves.
    3. Reload the best weights saved during *this* run and write
       ``submission.csv`` under DATA_ROOT.

    Raises:
        RuntimeError: if the well-based split leaves the training or the
            validation subset empty.
    """
    print(f"DEVICE: {DEVICE}")
    print(f"Train dir: {TRAIN_IMAGES_DIR}")
    print(f"Test dir:  {TEST_IMAGES_DIR}")

    # Load all train data (well1-6)
    train_ds_all = WellSegDataset(TRAIN_IMAGES_DIR, Y_TRAIN_CSV)

    # Split by well: well6 as validation
    VAL_WELLS = {6}
    train_indices, val_indices = [], []
    for i, name in enumerate(train_ds_all.names):
        w = parse_well_id(name)
        if w in VAL_WELLS:
            val_indices.append(i)
        else:
            train_indices.append(i)

    # Fail early with a clear message; otherwise an empty subset only shows
    # up later as a sampler error or a division by zero in the loss average.
    if not train_indices or not val_indices:
        raise RuntimeError(
            f"Empty split: {len(train_indices)} train / {len(val_indices)} val "
            f"samples found in {TRAIN_IMAGES_DIR} (val_wells={VAL_WELLS})"
        )

    train_ds = Subset(train_ds_all, train_indices)  # well1-5
    val_ds = Subset(train_ds_all, val_indices)      # well6

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    print(f"Train samples: {len(train_ds)} | Val samples: {len(val_ds)} | val_wells={VAL_WELLS}")

    model = AttentionUNet(in_channels=1, num_classes=NUM_CLASSES, base=32).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    best_val = float("inf")
    best_path = DATA_ROOT / "best_attention_unet.pth"
    # Tracks whether best_path was written by this run. A NaN validation loss
    # never compares as "less than", so the file on disk could otherwise be
    # missing or left over from an earlier run.
    saved_this_run = False

    for epoch in range(1, EPOCHS + 1):
        tr_loss = train_one_epoch(model, train_loader, optimizer)
        va_loss = valid_one_epoch(model, val_loader)

        print(f"Epoch {epoch:02d}/{EPOCHS} | train_loss={tr_loss:.4f} | val_loss={va_loss:.4f}")

        if va_loss < best_val:
            best_val = va_loss
            torch.save(model.state_dict(), best_path)
            saved_this_run = True
            print(f"  -> Best model saved: {best_path}")

    # Predict test and write submission
    out_csv = DATA_ROOT / "submission.csv"
    if saved_this_run:
        state_dict = torch.load(best_path, map_location=DEVICE, weights_only=True)
        model.load_state_dict(state_dict)
    else:
        # Do not load a stale checkpoint; fall back to the in-memory weights.
        print("[WARN] no checkpoint was saved in this run; predicting with the current model weights")
    predict_and_make_submission(model, TEST_IMAGES_DIR, out_csv)


# Run the full train -> validate -> predict pipeline only when this file is
# executed as a script, not when it is imported as a module.
if __name__ == "__main__":
    main()


DEVICE: cuda
Train dir: C:\Users\lenovo\Desktop\deep_datachallenge\X_train_uDRk9z9\images
Test dir:  C:\Users\lenovo\Desktop\deep_datachallenge\X_test_xNbnvIa\images
Train samples: 2790 | Val samples: 1620 | val_wells={6}
Epoch 01/20 | train_loss=0.2101 | val_loss=0.1172
  -> Best model saved: C:\Users\lenovo\Desktop\deep_datachallenge\best_attention_unet.pth
Epoch 02/20 | train_loss=0.0944 | val_loss=0.1888
Epoch 03/20 | train_loss=0.0832 | val_loss=0.1081
  -> Best model saved: C:\Users\lenovo\Desktop\deep_datachallenge\best_attention_unet.pth
Epoch 04/20 | train_loss=0.0848 | val_loss=0.1979
Epoch 05/20 | train_loss=0.0776 | val_loss=0.0860
  -> Best model saved: C:\Users\lenovo\Desktop\deep_datachallenge\best_attention_unet.pth
Epoch 06/20 | train_loss=0.0754 | val_loss=0.1289
Epoch 07/20 | train_loss=0.0760 | val_loss=0.0803
  -> Best model saved: C:\Users\lenovo\Desktop\deep_datachallenge\best_attention_unet.pth
Epoch 08/20 | train_loss=0.0732 | val_loss=0.0971
Epoch 09/20 | trai