## Mask2Former

In [None]:
"""
Mask2Former (HuggingFace Transformers) - Full runnable semantic segmentation code

What this script does:
- Reads .npy ultrasound patches
- Trains on well1-5, validates on well6 (from X_train_uDRk9z9/images)
- Predicts on X_test_xNbnvIa/images (well7-11)
- Writes submission.csv with the SAME format as before:
  - each row = one patch name
  - flattened mask
  - padded to 160*272 with -1

Why we resize:
- Mask2Former backbones are usually trained on larger resolutions.
- To keep it simple and fit RTX 4060 (8GB), we resize inputs to 224x224 during training/inference,
  then upsample predictions back to (160,272) for submission.

Install (in your CUDA environment):
    pip install transformers accelerate

Notes:
- If your machine cannot download pretrained weights (no internet), set PRETRAINED=None and it will start from scratch.
- Mask2Former expects instance-style labels: a set of binary masks + class ids per image.
  We convert your (H,W) semantic mask into that format automatically.
"""

import re
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset

from transformers import (
    AutoImageProcessor,
    Mask2FormerForUniversalSegmentation,
)


# =========================
# 0. Paths & Hyperparameters
# =========================
# Root folder of the challenge data. The best checkpoint (best_mask2former.pth)
# and submission.csv are also written directly under this folder by main().
DATA_ROOT = Path(r"C:\Users\lenovo\Desktop\deep_datachallenge")  # change to your path

# Folders of *.npy patches (train / test) and the CSV of flattened training masks.
# The CSV is read with its first column as index and looked up by patch name.
TRAIN_IMAGES_DIR = DATA_ROOT / "X_train_uDRk9z9" / "images"
TEST_IMAGES_DIR = DATA_ROOT / "X_test_xNbnvIa" / "images"
Y_TRAIN_CSV = DATA_ROOT / "Y_train_T9NrBYo.csv"

# Size of one submission mask: every CSV row holds TARGET_H * TARGET_W values,
# narrower patches are padded with IGNORE_INDEX.
TARGET_H = 160
TARGET_W = 272

# Resolution the patches/masks are resized to before entering the model
# (kept small so training fits on an 8GB RTX 4060).
MODEL_H = 224
MODEL_W = 224

# Valid class ids are 0..NUM_CLASSES-1; IGNORE_INDEX marks padded pixels in
# masks and in submission rows.
NUM_CLASSES = 3
IGNORE_INDEX = -1

# Optimisation settings (AdamW is used in main()).
BATCH_SIZE = 2          # Mask2Former is heavy; start with 1~2 on 4060 8GB
LR = 5e-5
WEIGHT_DECAY = 1e-4
EPOCHS = 20

# Tensors and the model are moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face checkpoint id (Swin-T backbone, ADE semantic) handed to
# from_pretrained(). build_model_and_processor() takes a separate branch
# when this is set to None.
PRETRAINED = "facebook/mask2former-swin-tiny-ade-semantic"


# =========================
# 1. Utils
# =========================
def parse_well_id(name: str) -> int:
    """Return the well number embedded in a patch name, or -1 if there is none.

    Example: "well_1_section_0_patch_0" -> 1
    """
    match = re.search(r"well_(\d+)_", name)
    if match is None:
        return -1
    return int(match.group(1))


def minmax_normalize(x: np.ndarray) -> np.ndarray:
    """Rescale an array to [0, 1] as float32.

    NaN and +/-inf entries are replaced by 0 before the range is computed.
    A (near-)constant array, where max - min < 1e-6, yields all zeros.
    """
    cleaned = np.nan_to_num(x.astype(np.float32), nan=0.0, posinf=0.0, neginf=0.0)
    lo = float(cleaned.min())
    hi = float(cleaned.max())
    span = hi - lo
    if span < 1e-6:
        # Avoid dividing by (almost) zero on flat inputs.
        return np.zeros_like(cleaned, dtype=np.float32)
    return (cleaned - lo) / span


def pad_to_160x272(img: np.ndarray, fill_value: float = 0.0) -> np.ndarray:
    """Bring a 2-D image of height TARGET_H to width TARGET_W.

    Narrower images are right-padded with ``fill_value``; wider images are
    cropped on the right; images that already have the target width are
    returned unchanged (same object).
    """
    rows, cols = img.shape
    assert rows == TARGET_H, f"Expected height {TARGET_H}, got {rows}"

    if cols > TARGET_W:
        return img[:, :TARGET_W]
    if cols == TARGET_W:
        return img

    canvas = np.full((TARGET_H, TARGET_W), fill_value, dtype=img.dtype)
    canvas[:, :cols] = img
    return canvas


def decode_mask_from_csv_row(row_values: np.ndarray) -> np.ndarray:
    """
    Turn one flattened CSV row back into a (TARGET_H, w) int64 semantic mask.

    Entries equal to IGNORE_INDEX are treated as padding and dropped; the
    remaining values must form a whole number of columns of height TARGET_H.
    """
    kept = row_values[row_values != IGNORE_INDEX]
    n_kept = len(kept)
    assert n_kept % TARGET_H == 0, f"Valid mask length {n_kept} not divisible by 160"
    width = n_kept // TARGET_H
    reshaped = kept.reshape(TARGET_H, width)
    return reshaped.astype(np.int64)


def pad_mask_to_160x272(mask: np.ndarray) -> np.ndarray:
    """Right-pad a (TARGET_H, w) mask with IGNORE_INDEX up to width TARGET_W.

    A mask that already has the target width is returned as-is.
    """
    rows, cols = mask.shape
    assert rows == TARGET_H
    if cols != TARGET_W:
        canvas = np.full((TARGET_H, TARGET_W), IGNORE_INDEX, dtype=np.int64)
        canvas[:, :cols] = mask
        return canvas
    return mask


def resize_image_torch(img_1hw: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Bilinearly resize a (1,H,W) image tensor to (1,h,w)."""
    # F.interpolate needs a batch dimension: (1,H,W) -> (1,1,H,W).
    batched = img_1hw[None]
    resized = F.interpolate(batched, size=(h, w), mode="bilinear", align_corners=False)
    # Drop the batch dimension again.
    return resized[0]


def resize_mask_torch(mask_hw: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Nearest-neighbour resize of a (H,W) label mask to (h,w), returned as long."""
    # Cast to float and add batch/channel dims for F.interpolate: (H,W) -> (1,1,H,W).
    as_float = mask_hw.float()[None, None]
    scaled = F.interpolate(as_float, size=(h, w), mode="nearest")
    # Nearest mode only copies existing values, so casting back to long is lossless.
    return scaled[0, 0].long()


def semantic_to_mask2former_targets(
    semantic_mask: torch.Tensor,
    num_classes: int,
    ignore_index: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Convert a semantic mask (H,W) into Mask2Former targets.

    One binary mask is produced per class id that is present in the image,
    in ascending class order. Pixels equal to ``ignore_index`` and class ids
    outside ``[0, num_classes)`` never contribute a target.

    Args:
        semantic_mask: (H,W) tensor of class ids.
        num_classes: number of valid classes; valid ids are 0..num_classes-1.
        ignore_index: value marking pixels to exclude.

    Returns:
        class_labels: (N,) long tensor of class ids.
        mask_labels:  (N,H,W) float32 tensor of 0/1 masks.
        Both tensors are created on ``semantic_mask.device``. If no valid
        class is present, a single dummy target is returned: class 0 with an
        all-zero mask.
    """
    device = semantic_mask.device
    valid = semantic_mask != ignore_index

    # torch.unique returns sorted values; an all-ignore mask yields an empty list.
    candidates = [
        int(c)
        for c in torch.unique(semantic_mask[valid]).tolist()
        if 0 <= int(c) < num_classes
    ]

    classes = []
    masks = []
    for c in candidates:
        binary = (semantic_mask == c) & valid
        # Guard against ids that match no pixel exactly (e.g. non-integer values).
        if not binary.any():
            continue
        classes.append(c)
        masks.append(binary.float())

    if not classes:
        # Nothing usable in this image: dummy background target with an empty mask.
        class_labels = torch.zeros(1, dtype=torch.long, device=device)
        mask_labels = torch.zeros(
            (1, semantic_mask.shape[0], semantic_mask.shape[1]),
            dtype=torch.float32,
            device=device,
        )
        return class_labels, mask_labels

    class_labels = torch.tensor(classes, dtype=torch.long, device=device)
    mask_labels = torch.stack(masks, dim=0)  # (N,H,W)
    return class_labels, mask_labels


# =========================
# 2. Dataset
# =========================
class WellSegDataset(Dataset):
    """Dataset over the ``*.npy`` patches of one folder, optionally with masks.

    Each item is a dict with:
    - "name":  file stem of the patch
    - "image": (1, MODEL_H, MODEL_W) float tensor (min-max normalised,
               padded to TARGET_W, then resized)
    - "mask":  (MODEL_H, MODEL_W) long tensor, only when a label CSV was given
    - "raw_w": original patch width, used later to crop predictions
    """

    def __init__(self, images_dir: Path, y_csv_path: Path = None):
        self.images_dir = images_dir
        self.has_label = y_csv_path is not None

        # Sorted so that indices are stable across runs.
        self.image_paths = sorted(images_dir.glob("*.npy"))
        self.names = [path.stem for path in self.image_paths]

        if self.has_label:
            # First CSV column is the patch name and becomes the row index.
            self.y_df = pd.read_csv(y_csv_path, index_col=0)
        else:
            self.y_df = None

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx: int):
        name = self.names[idx]

        raw = np.load(self.image_paths[idx])    # (160,160) or (160,272)
        raw_w = int(raw.shape[1])               # remembered to crop back for submission
        padded = pad_to_160x272(minmax_normalize(raw), fill_value=0.0)

        image = torch.from_numpy(padded).unsqueeze(0).float()   # (1,160,272)
        image = resize_image_torch(image, MODEL_H, MODEL_W)     # (1,224,224)

        sample = {"name": name, "image": image}

        if self.has_label:
            flat = self.y_df.loc[name].values.astype(np.int64)
            label = pad_mask_to_160x272(decode_mask_from_csv_row(flat))  # (160,272)
            label_t = torch.from_numpy(label).long()
            sample["mask"] = resize_mask_torch(label_t, MODEL_H, MODEL_W)  # (224,224)

        sample["raw_w"] = raw_w
        return sample


# =========================
# 3. Collate for Mask2Former
# =========================
def collate_mask2former(batch: List[Dict]) -> Dict:
    """
    Collate dataset samples into the dict consumed by the train/predict loops.

    Always present:
    - names:        list of patch names
    - raw_ws:       (B,) long tensor of original widths
    - pixel_values: (B,3,H,W) float, the single channel repeated three times
    - pixel_mask:   (B,MODEL_H,MODEL_W) long, all ones (every pixel valid)

    Only when the samples carry a "mask":
    - class_labels: list of (Ni,) long tensors
    - mask_labels:  list of (Ni,H,W) float tensors
    """
    images = torch.stack([sample["image"] for sample in batch], dim=0)  # (B,1,H,W)
    pixel_values = images.repeat(1, 3, 1, 1)                            # (B,3,H,W)

    collated = {
        "names": [sample["name"] for sample in batch],
        "raw_ws": torch.tensor([sample["raw_w"] for sample in batch], dtype=torch.long),
        "pixel_values": pixel_values,
        "pixel_mask": torch.ones((pixel_values.shape[0], MODEL_H, MODEL_W), dtype=torch.long),
    }

    # Unlabelled (test) batches stop here.
    if "mask" not in batch[0]:
        return collated

    targets = [
        semantic_to_mask2former_targets(sample["mask"], NUM_CLASSES, IGNORE_INDEX)
        for sample in batch
    ]
    collated["class_labels"] = [cls for cls, _ in targets]
    collated["mask_labels"] = [msk for _, msk in targets]
    return collated


# =========================
# 4. Model builder
# =========================
def build_model_and_processor(num_classes: int):
    """Create the Mask2Former model and its image processor.

    Args:
        num_classes: number of semantic classes; labels are named
            "class0" .. "class{num_classes-1}".

    Returns:
        (model, processor)

    Behaviour depends on the module-level PRETRAINED setting:
    - PRETRAINED is None: the model is built from a default Mask2FormerConfig
      with randomly initialised weights and the processor is constructed
      locally, so nothing is downloaded.
    - otherwise: both are loaded from the PRETRAINED checkpoint and the
      classification head is re-initialised for ``num_classes``.
    """
    # Derive the label maps from num_classes so they can never disagree with it.
    id2label = {i: f"class{i}" for i in range(num_classes)}
    label2id = {label: i for i, label in id2label.items()}

    if PRETRAINED is None:
        # Imported here because only the from-scratch path needs these names.
        from transformers import Mask2FormerConfig, Mask2FormerImageProcessor

        # The number of labels is taken from id2label by the config.
        config = Mask2FormerConfig(id2label=id2label, label2id=label2id)
        model = Mask2FormerForUniversalSegmentation(config)
        processor = Mask2FormerImageProcessor()
        return model, processor

    processor = AutoImageProcessor.from_pretrained(PRETRAINED)
    model = Mask2FormerForUniversalSegmentation.from_pretrained(
        PRETRAINED,
        ignore_mismatched_sizes=True,   # allow changing num_labels
        id2label=id2label,
        label2id=label2id,
        num_labels=num_classes,
    )
    return model, processor


# =========================
# 5. Train / Validate
# =========================
def train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader``.

    Returns the sum of per-batch losses weighted by batch size, divided by
    ``len(loader.dataset)``.
    """
    model.train()
    running = 0.0

    for batch in loader:
        # Labels are per-image lists (length B), so each tensor is moved separately.
        inputs = {
            "pixel_values": batch["pixel_values"].to(DEVICE),   # (B,3,224,224)
            "pixel_mask": batch["pixel_mask"].to(DEVICE),       # (B,224,224)
            "class_labels": [t.to(DEVICE) for t in batch["class_labels"]],
            "mask_labels": [t.to(DEVICE) for t in batch["mask_labels"]],
        }

        loss = model(**inputs).loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running += float(loss.item()) * inputs["pixel_values"].size(0)

    return running / len(loader.dataset)


@torch.no_grad()
def valid_one_epoch(model, loader):
    """Compute the loss over ``loader`` in eval mode, without gradients.

    Returns the sum of per-batch losses weighted by batch size, divided by
    ``len(loader.dataset)``.
    """
    model.eval()
    running = 0.0

    for batch in loader:
        images = batch["pixel_values"].to(DEVICE)
        valid_pixels = batch["pixel_mask"].to(DEVICE)
        target_classes = [t.to(DEVICE) for t in batch["class_labels"]]
        target_masks = [t.to(DEVICE) for t in batch["mask_labels"]]

        result = model(
            pixel_values=images,
            pixel_mask=valid_pixels,
            class_labels=target_classes,
            mask_labels=target_masks,
        )
        running += float(result.loss.item()) * images.size(0)

    return running / len(loader.dataset)


# =========================
# 6. Inference & submission
# =========================
@torch.no_grad()
def predict_and_make_submission(model, processor, test_images_dir: Path, out_csv_path: Path):
    """
    Run the model on every test patch and write the submission CSV.

    Per patch:
    - forward pass at MODEL_H x MODEL_W
    - processor.post_process_semantic_segmentation turns the outputs into a label map
    - the label map is upsampled (nearest) to (TARGET_H, TARGET_W)
    - columns beyond the patch's original width are dropped and the flattened
      result is padded with IGNORE_INDEX to TARGET_H * TARGET_W entries

    The CSV has one row per patch, indexed by patch name.
    """
    model.eval()

    loader = DataLoader(
        WellSegDataset(test_images_dir, y_csv_path=None),
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_mask2former,
    )

    row_len = TARGET_H * TARGET_W
    rows = {}

    for batch in loader:
        # batch_size is 1, so element 0 is the only sample.
        patch_name = batch["names"][0]
        orig_w = int(batch["raw_ws"][0].item())

        outputs = model(
            pixel_values=batch["pixel_values"].to(DEVICE),  # (1,3,224,224)
            pixel_mask=batch["pixel_mask"].to(DEVICE),
        )

        label_map = processor.post_process_semantic_segmentation(
            outputs, target_sizes=[(MODEL_H, MODEL_W)]
        )[0].to(torch.int64)  # (224,224)

        # Nearest-neighbour upsampling keeps the values valid class ids.
        upsampled = F.interpolate(
            label_map[None, None].float(), size=(TARGET_H, TARGET_W), mode="nearest"
        )[0, 0]
        full_pred = upsampled.cpu().numpy().astype(np.int64)  # (160,272)

        # Keep only the columns of the original patch, then pad the flat row with -1.
        flat = full_pred[:, :orig_w].flatten()
        row = np.full((row_len,), IGNORE_INDEX, dtype=np.int64)
        row[: flat.size] = flat
        rows[patch_name] = row

    sub = pd.DataFrame(rows, dtype="int64").T
    sub.to_csv(out_csv_path)
    print(f"[OK] submission saved to: {out_csv_path}")


# =========================
# 7. Main
# =========================
def main():
    """Train on wells other than VAL_WELLS, validate on VAL_WELLS, then predict the test set.

    The weights with the lowest validation loss of this run are saved to
    DATA_ROOT / "best_mask2former.pth" and reloaded before writing
    DATA_ROOT / "submission.csv". If no epoch of this run produced a
    checkpoint (e.g. the validation loss was NaN every epoch), the weights
    from the last epoch are used instead of loading a possibly stale file.
    """
    print(f"DEVICE: {DEVICE}")
    print(f"Train dir: {TRAIN_IMAGES_DIR}")
    print(f"Test dir:  {TEST_IMAGES_DIR}")
    print(f"Model input size: {MODEL_H}x{MODEL_W}")
    print(f"Pretrained: {PRETRAINED}")

    # Load all train data (well1-6)
    train_ds_all = WellSegDataset(TRAIN_IMAGES_DIR, Y_TRAIN_CSV)

    # Split by well: well6 as validation
    VAL_WELLS = {6}
    train_indices, val_indices = [], []
    for i, name in enumerate(train_ds_all.names):
        if parse_well_id(name) in VAL_WELLS:
            val_indices.append(i)
        else:
            train_indices.append(i)

    train_ds = Subset(train_ds_all, train_indices)  # well1-5
    val_ds = Subset(train_ds_all, val_indices)      # well6

    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        collate_fn=collate_mask2former,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_mask2former,
    )

    print(f"Train samples: {len(train_ds)} | Val samples: {len(val_ds)} | val_wells={VAL_WELLS}")

    model, processor = build_model_and_processor(NUM_CLASSES)
    model = model.to(DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    best_val = float("inf")
    best_path = DATA_ROOT / "best_mask2former.pth"
    # Tracks whether THIS run wrote best_path, so an old file is never mistaken for it.
    checkpoint_saved = False

    for epoch in range(1, EPOCHS + 1):
        tr_loss = train_one_epoch(model, train_loader, optimizer)
        va_loss = valid_one_epoch(model, val_loader)
        print(f"Epoch {epoch:02d}/{EPOCHS} | train_loss={tr_loss:.4f} | val_loss={va_loss:.4f}")

        # A NaN loss compares False here and therefore never becomes "best".
        if va_loss < best_val:
            best_val = va_loss
            torch.save(model.state_dict(), best_path)
            checkpoint_saved = True
            print(f"  -> Best model saved: {best_path}")

    # Predict test and write submission
    out_csv = DATA_ROOT / "submission.csv"
    if checkpoint_saved:
        state_dict = torch.load(best_path, map_location=DEVICE, weights_only=True)
        model.load_state_dict(state_dict)
    else:
        print("[WARN] no checkpoint was saved during this run; predicting with the last-epoch weights")
    predict_and_make_submission(model, processor, TEST_IMAGES_DIR, out_csv)


if __name__ == "__main__":
    main()


  from .autonotebook import tqdm as notebook_tqdm


DEVICE: cuda
Train dir: C:\Users\lenovo\Desktop\deep_datachallenge\X_train_uDRk9z9\images
Test dir:  C:\Users\lenovo\Desktop\deep_datachallenge\X_test_xNbnvIa\images
Model input size: 224x224
Pretrained: facebook/mask2former-swin-tiny-ade-semantic
Train samples: 2790 | Val samples: 1620 | val_wells={6}


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
  image_processor = cls(**image_processor_dict)
Some weights of Mask2FormerForUniversalSegmentation were not initialized from the model checkpoint at facebook/mask2former-swin-tiny-ade-semantic and are newly initialized because the shapes did not match:
- class_predictor.weight: found shape torch.Size([151, 256]) in the checkpoint and torch.Size([4, 256]) in the model instantiated
- class_predictor.bias: found shape torch.Size([151]) in the checkpoint and torch.Size([4]) in the model instantiated
- criterion.empty_weight: found shape torch.Size([151]) in the checkpoint and torch.Size([4]) in the model instantiated
You should probably TRAIN this model on a down-stream 

Epoch 01/20 | train_loss=17.3520 | val_loss=15.0105
  -> Best model saved: C:\Users\lenovo\Desktop\deep_datachallenge\best_mask2former.pth
Epoch 02/20 | train_loss=13.7775 | val_loss=15.2349
Epoch 03/20 | train_loss=12.6401 | val_loss=14.5714
  -> Best model saved: C:\Users\lenovo\Desktop\deep_datachallenge\best_mask2former.pth
Epoch 04/20 | train_loss=12.0788 | val_loss=14.2691
  -> Best model saved: C:\Users\lenovo\Desktop\deep_datachallenge\best_mask2former.pth
Epoch 05/20 | train_loss=11.6676 | val_loss=15.0381
Epoch 06/20 | train_loss=11.4770 | val_loss=14.3574
Epoch 07/20 | train_loss=10.9957 | val_loss=14.1694
  -> Best model saved: C:\Users\lenovo\Desktop\deep_datachallenge\best_mask2former.pth
Epoch 08/20 | train_loss=10.6843 | val_loss=14.9301
Epoch 09/20 | train_loss=10.3577 | val_loss=15.2518
Epoch 10/20 | train_loss=10.2286 | val_loss=15.0672
Epoch 11/20 | train_loss=9.6430 | val_loss=15.5822
Epoch 12/20 | train_loss=9.4476 | val_loss=14.6725
Epoch 13/20 | train_loss=9.092