In [None]:
# --- Standalone cell: ORIGINAL | GT (colored) | OVERLAY as three separate images ---

import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

# Import your project pieces
from mask2former import (
    SegmentationDataModule, 
    DATASET_DIR, 
    BATCH_SIZE, 
    NUM_WORKERS, 
    ID2LABEL,
)

# ---------- helpers ----------

def _to_numpy_image(img_t):
    """
    Convert an image to a displayable uint8 HWC array with three channels.

    Accepts a torch tensor or anything ``np.asarray`` understands, laid out as
    CHW, HWC or plain HW. Single-channel inputs (HW, 1xHxW or HxWx1) are
    replicated to three channels so that ``imshow`` can render them.

    Value-range handling:
      * float/bool data whose maximum is <= 1.0 is treated as [0, 1] and
        scaled to [0, 255];
      * integer data is treated as already being in [0, 255] and is never
        rescaled (a very dark uint8 image must not be blown up by 255);
      * the result is clipped to [0, 255] before the uint8 cast.
    """
    if isinstance(img_t, torch.Tensor):
        img = img_t.detach().cpu().numpy()
    else:
        img = np.asarray(img_t)

    # CHW -> HWC if needed (a leading axis of size 1 or 3 is taken as channels)
    if img.ndim == 3 and img.shape[0] in (1, 3):
        img = np.moveaxis(img, 0, -1)

    # A trailing singleton channel (HxWx1) is not a valid imshow shape:
    # drop it so the grayscale branch below expands it to three channels.
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]

    # If single channel, stack to 3 channels for nicer visualization
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)

    # Only non-integer data can be a [0, 1] image; decide before the float cast.
    may_be_unit_range = not np.issubdtype(img.dtype, np.integer)

    # normalize to 0..255 uint8 (the size check avoids max() on an empty array)
    img_float = img.astype(np.float32)
    if may_be_unit_range and img_float.size and img_float.max() <= 1.0:
        img_float = img_float * 255.0
    img_uint8 = np.clip(img_float, 0, 255).astype(np.uint8)
    return img_uint8

def _make_palette(id2label, seed=13):
    """
    Build a reproducible per-class color table.

    One uniform random RGB triple is drawn for every id in 0..max(id2label)
    from a generator seeded with `seed`. When id 0 is a key of `id2label`
    its color is forced to black for contrast.

    Returns:
        (ListedColormap over the table, dict mapping class id -> RGB row).
    """
    highest_id = max(id2label.keys(), default=0)
    palette = np.random.default_rng(seed).uniform(0, 1, size=(highest_id + 1, 3))
    if 0 in id2label:
        palette[0] = np.array([0.0, 0.0, 0.0])
    return ListedColormap(palette), dict(enumerate(palette))

def _legend_handles(mask, id2label, id2color):
    """
    Build one legend Patch per class id that actually occurs in `mask`.

    Ids missing from `id2label` are labelled "id_<n>"; ids missing from
    `id2color` are drawn mid-gray. If the mask yields no ids at all, a single
    gray "(no labels)" placeholder handle is returned instead.
    """
    present = [int(value) for value in np.unique(mask)]
    names = [id2label.get(class_id, f"id_{class_id}") for class_id in present]
    handles = [
        Patch(
            facecolor=id2color.get(class_id, (0.5, 0.5, 0.5)),
            edgecolor='none',
            label=f"{class_id}: {name}",
        )
        for class_id, name in zip(present, names)
    ]
    if handles:
        return handles
    return [Patch(facecolor='gray', edgecolor='none', label="(no labels)")]

def visualize_image_gt_overlay(batch, id2label, sample_index=0, overlay_alpha=0.5, figsize=(18,6)):
    """
    Draw one sample of `batch` as a row of three panels and show the figure:
      1) the original image
      2) the ground-truth mask, colorized, with a legend of the ids present
      3) the original image with the ground-truth mask blended on top

    Reads batch["original_images"] and batch["original_segmentation_maps"]
    and picks entry `sample_index` from each. `overlay_alpha` is the opacity
    of the mask in panel 3.
    """
    orig_imgs = batch["original_images"]
    gts = batch["original_segmentation_maps"]

    img = _to_numpy_image(orig_imgs[sample_index])
    gt = gts[sample_index]
    gt = gt.detach().cpu().numpy() if isinstance(gt, torch.Tensor) else np.asarray(gt)

    cmap, id2color = _make_palette(id2label)
    vmax = max(id2label.keys(), default=0)
    # Shared rendering options so both mask panels use identical colors.
    mask_style = dict(cmap=cmap, vmin=0, vmax=vmax, interpolation='nearest')

    fig, (ax_orig, ax_gt, ax_overlay) = plt.subplots(1, 3, figsize=figsize)

    # Panel 1: original image
    ax_orig.imshow(img)
    ax_orig.set_title("Original")
    ax_orig.axis("off")

    # Panel 2: colorized ground truth + legend
    ax_gt.imshow(gt, **mask_style)
    ax_gt.set_title("Ground Truth")
    ax_gt.axis("off")
    ax_gt.legend(
        handles=_legend_handles(gt, id2label, id2color),
        loc="lower right",
        fontsize=8,
        frameon=True,
        ncol=1,
    )

    # Panel 3: image with the mask blended on top
    ax_overlay.imshow(img)
    ax_overlay.imshow(gt, alpha=overlay_alpha, **mask_style)
    ax_overlay.set_title("Overlay (Image + GT)")
    ax_overlay.axis("off")

    plt.tight_layout()
    plt.show()

# ---------- build a batch (standalone) ----------

# Split preference: test first, then val; train is the last resort.
dm = SegmentationDataModule(dataset_dir=DATASET_DIR, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

dl = None
for _stage, _loader_name in (("test", "test_dataloader"), ("validate", "val_dataloader")):
    try:
        dm.setup(_stage)
        dl = getattr(dm, _loader_name)()
    except Exception:
        # Best effort: any failure just means this split is unavailable.
        dl = None
    if dl is not None:
        break

if dl is None:
    # Neither test nor val produced a loader; errors here are allowed to surface.
    dm.setup("fit")
    dl = dm.train_dataloader()

batch = next(iter(dl))

# ---------- visualize ----------
visualize_image_gt_overlay(batch, ID2LABEL, sample_index=0, overlay_alpha=0.5, figsize=(18,6))

In [None]:
# --- Standalone cell: PROGRAMMABLE N SAMPLES (rows) -> [OG | OG+GT | OG+PRED] ---

import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Project imports (adapt to your package layout if needed)
from mask2former import (
    Mask2FormerFinetuner,
    SegmentationDataModule,
    DATASET_DIR,
    BATCH_SIZE,
    NUM_WORKERS,
    ID2LABEL,
    LEARNING_RATE,
)

# ----------------- CONFIG -----------------
# Lightning checkpoint handed to Mask2FormerFinetuner.load_from_checkpoint in this cell.
CKPT_PATH = "/home/erik/Documents/Finetune-Mask2Former/outputs/lightning_logs_csv/version_4/checkpoints/epoch=6-step=6699.ckpt"   # change if your checkpoint is elsewhere
N_SAMPLES  = 50                  # <--- number of rows to show (collection stops once this many samples are gathered)
OVERLAY_ALPHA = 0.5              # opacity of the GT / prediction mask drawn over the image
FIG_W = 18                       # figure width used in plt.subplots(figsize=...)
ROW_H = 3.0                      # figure height per row; total height = ROW_H * number of collected samples
SEED_PALETTE = 13                # default seed of _make_palette, keeps class colors reproducible
# ------------------------------------------

def _to_numpy_image(img_t):
    """Convert a CHW/HWC/HW torch tensor or array-like to a uint8 HxWx3 array.

    Single-channel inputs (HW, 1xHxW, HxWx1) are replicated to three channels.
    Float/bool data with a maximum <= 1.0 is treated as [0, 1] and scaled to
    [0, 255]; integer data is assumed to already be in [0, 255] and is left
    unscaled. Values are clipped to [0, 255] before the uint8 cast.
    """
    if isinstance(img_t, torch.Tensor):
        img = img_t.detach().cpu().numpy()
    else:
        img = np.asarray(img_t)

    if img.ndim == 3 and img.shape[0] in (1, 3):  # CHW -> HWC
        img = np.moveaxis(img, 0, -1)
    if img.ndim == 3 and img.shape[-1] == 1:      # HxWx1 -> HxW (imshow rejects HxWx1)
        img = img[..., 0]
    if img.ndim == 2:                             # grayscale -> 3ch
        img = np.stack([img] * 3, axis=-1)

    # Integer images are never [0, 1]-normalized, so only rescale float/bool data.
    rescale = not np.issubdtype(img.dtype, np.integer)
    img = img.astype(np.float32)
    if rescale and img.size and img.max() <= 1.0:  # size check: max() fails on empty arrays
        img = img * 255.0
    return np.clip(img, 0, 255).astype(np.uint8)

def _make_palette(id2label, seed=SEED_PALETTE):
    """Return (ListedColormap, max_id) with one seeded random color per id in 0..max_id.

    Id 0 is painted black when it is a key of `id2label`.
    """
    top_id = max(id2label.keys(), default=0)
    table = np.random.default_rng(seed).uniform(0, 1, size=(top_id + 1, 3))
    if 0 in id2label:
        table[0] = np.array([0.0, 0.0, 0.0])
    return ListedColormap(table), top_id

# --------- load data & model ----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dm = SegmentationDataModule(dataset_dir=DATASET_DIR, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

# Prefer test -> val; the training split is intentionally not a fallback here.
dl = None
for _stage, _getter, _label in (
    ("test", "test_dataloader", "test"),
    ("validate", "val_dataloader", "val"),
):
    try:
        dm.setup(_stage)
        dl = getattr(dm, _getter)()
        split_name = _label
    except Exception:
        # Best effort: a failing split is simply treated as unavailable.
        dl = None
    if dl is not None:
        break

if dl is None:
    raise RuntimeError("No test or val dataloader available — cannot visualize samples.")

print(f"Showing samples from the **{split_name} split**")

_restored = Mask2FormerFinetuner.load_from_checkpoint(
    CKPT_PATH,
    id2label=ID2LABEL,
    lr=LEARNING_RATE,
)
model = _restored.eval().to(device)  # eval mode, moved to the selected device

# --------- collect up to N_SAMPLES ----------
samples = []  # one dict per sample with keys "img", "gt", "pred"
cmap, vmax = _make_palette(ID2LABEL)


def _height_width(image):
    """Return (H, W) for a CHW or HWC image given as a torch tensor or array-like."""
    arr = image.detach().cpu().numpy() if isinstance(image, torch.Tensor) else np.asarray(image)
    if arr.ndim == 3 and arr.shape[0] in (1, 3):  # CHW
        return arr.shape[1], arr.shape[2]
    return arr.shape[0], arr.shape[1]             # HWC


with torch.no_grad():
    for batch in dl:
        # keys read from each batch: original_images, original_segmentation_maps, pixel_values
        orig_imgs = batch["original_images"]
        gts       = batch["original_segmentation_maps"]
        pvals     = batch["pixel_values"].to(device)

        # one forward pass for the whole batch
        outputs = model.model(pixel_values=pvals)

        # predictions are resized back to each original image's (H, W)
        target_sizes = [_height_width(img) for img in orig_imgs]
        pred_list = model.processor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)

        # stash results, stopping as soon as the quota is reached
        for i, raw_img in enumerate(orig_imgs):
            gt_i = gts[i]
            samples.append({
                "img": _to_numpy_image(raw_img),
                "gt": gt_i.detach().cpu().numpy() if isinstance(gt_i, torch.Tensor) else np.asarray(gt_i),
                "pred": pred_list[i].detach().cpu().numpy(),
            })
            if len(samples) >= N_SAMPLES:
                break
        if len(samples) >= N_SAMPLES:
            break

if len(samples) == 0:
    raise RuntimeError("No samples found. Check your dataloaders / dataset paths.")

# --------- plot grid: N_SAMPLES rows x 3 cols ----------
n_rows = len(samples)
# squeeze=False keeps `axes` two-dimensional (n_rows x 3) even for a single row
fig, axes = plt.subplots(
    nrows=n_rows,
    ncols=3,
    figsize=(FIG_W, ROW_H * n_rows),
    squeeze=False,
)

# (panel title, key of the mask to blend on top of the image or None)
columns = (
    ("Original", None),
    ("Original + GT", "gt"),
    ("Original + Prediction", "pred"),
)

for r, s in enumerate(samples):
    for c, (title, mask_key) in enumerate(columns):
        ax = axes[r, c]
        ax.imshow(s["img"])
        if mask_key is not None:
            ax.imshow(s[mask_key], cmap=cmap, vmin=0, vmax=vmax, alpha=OVERLAY_ALPHA, interpolation='nearest')
        ax.set_title(title)
        ax.axis("off")

plt.tight_layout()
plt.show()


In [None]:
# --- Standalone cell: Inference on ALL images in a folder -> [Original | Original + Prediction] ---

import os
from pathlib import Path
from typing import List, Tuple

import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Project imports (adjust if your package path differs)
from mask2former import (
    Mask2FormerFinetuner,
    ID2LABEL,
    LEARNING_RATE,
)

# ===================== USER CONFIG =====================
# Checkpoint handed to Mask2FormerFinetuner.load_from_checkpoint in this cell.
CKPT_PATH   = "/home/erik/Documents/Finetune-Mask2Former/outputs/lightning_logs_csv/version_4/checkpoints/epoch=6-step=6699.ckpt"     # <-- change to your .ckpt
IMAGES_DIR  = "/home/erik/Documents/Finetune-Mask2Former/data/rs19/images/custom_test"   # <-- folder with input images
# Accepted file suffixes; list_images compares them against the lower-cased suffix.
IMAGE_EXTS  = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}

BATCH_SIZE  = 8            # batching for speed/memory
OVERLAY_ALPHA = 0.5        # opacity of the prediction mask drawn over the image
SEED_PALETTE  = 14         # default seed of _make_palette, keeps class colors reproducible

# Display options:
SHOW_MAX = -1              # any negative value = show ALL; otherwise show only first N images in the notebook
FIG_W    = 14              # width of the full figure
ROW_H    = 3.2             # height per image row (figure height = ROW_H * num_rows)
# =======================================================


# ---------- helpers ----------
def list_images(folder: str, exts: set) -> List[Path]:
    """Return the sorted regular files directly inside `folder` whose lower-cased suffix is in `exts`.

    The listing is not recursive.

    Raises:
        FileNotFoundError: if no entry of `folder` qualifies.
    """
    matches = []
    for entry in sorted(Path(folder).iterdir()):
        if entry.suffix.lower() not in exts:
            continue
        if entry.is_file():
            matches.append(entry)
    if matches:
        return matches
    raise FileNotFoundError(f"No images found in {folder} with extensions {exts}")

def load_rgb(path: Path) -> Image.Image:
    """Open the image at `path` and return it converted to RGB mode.

    The file is opened in a context manager so its handle is released as soon
    as the pixel data has been converted, rather than staying open until the
    source image object is garbage collected (this matters when walking a
    large folder). `convert` returns a new image, so the result stays usable
    after the source is closed.
    """
    with Image.open(path) as src:
        return src.convert("RGB")

def _make_palette(id2label, seed=SEED_PALETTE) -> Tuple[ListedColormap, int]:
    """Return (ListedColormap, max_id) holding one seeded random color per id in 0..max_id.

    Id 0 is painted black (for contrast) when it is a key of `id2label`.
    """
    largest_id = max(id2label.keys(), default=0)
    rgb_table = np.random.default_rng(seed).uniform(0, 1, size=(largest_id + 1, 3))
    if 0 in id2label:
        rgb_table[0] = np.array([0.0, 0.0, 0.0])
    return ListedColormap(rgb_table), largest_id

def overlay_image_with_mask(img_np: np.ndarray, mask: np.ndarray, cmap: ListedColormap, vmax: int, alpha: float) -> np.ndarray:
    """
    Render `mask` blended over `img_np` with matplotlib and return the picture
    as an RGB uint8 array (the rendered PNG is converted to RGB, so there is
    no alpha channel in the result).

    The output is a rasterized 6x6 inch figure at 150 dpi, tightly cropped, so
    its resolution generally differs from that of `img_np`.

    Args:
        img_np: Image to draw underneath (anything `imshow` accepts).
        mask: Class-id map drawn on top using `cmap` with limits [0, vmax].
        cmap: Colormap used to colorize the mask.
        vmax: Upper color limit (highest class id).
        alpha: Opacity of the mask layer.
    """
    import matplotlib.pyplot as plt
    from io import BytesIO

    buf = BytesIO()
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    try:
        ax.imshow(img_np)
        ax.imshow(mask, cmap=cmap, vmin=0, vmax=vmax, alpha=alpha, interpolation='nearest')
        ax.axis("off")
        # Act on this figure explicitly instead of relying on pyplot's current figure.
        fig.tight_layout(pad=0)
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0, dpi=150)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
    buf.seek(0)
    out = Image.open(buf).convert("RGB")
    return np.array(out)

def pil_to_numpy_uint8(img_pil: Image.Image) -> np.ndarray:
    """Copy the pixels of a PIL image into a new uint8 numpy array."""
    pixels = np.array(img_pil, dtype=np.uint8)
    return pixels

# ---------- load model ----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Mask2FormerFinetuner.load_from_checkpoint(CKPT_PATH, id2label=ID2LABEL, lr=LEARNING_RATE)
model = model.eval().to(device)  # eval mode, moved to the selected device

processor = model.processor  # image processor exposed by the finetuner
cmap, vmax = _make_palette(ID2LABEL)

# ---------- gather images ----------
paths = list_images(IMAGES_DIR, IMAGE_EXTS)
num_imgs = len(paths)
print(f"Found {num_imgs} images in: {IMAGES_DIR}")

# ---------- batched inference ----------
all_originals = []   # np.uint8 HxWx3 arrays, one per input image
all_preds     = []   # HxW class-id arrays, aligned with all_originals

with torch.no_grad():
    # Pending-batch state: filled by the loop below, drained by flush_batch().
    batch_imgs_pil = []
    batch_target_sizes = []
    batch_indices = []

    def flush_batch():
        """Run the model on the pending images, store the results, and empty the pending lists."""
        if batch_imgs_pil:
            # preprocess and forward the whole pending batch at once
            encoded = processor(images=batch_imgs_pil, return_tensors="pt")
            outputs = model.model(pixel_values=encoded["pixel_values"].to(device))
            pred_list = processor.post_process_semantic_segmentation(outputs, target_sizes=batch_target_sizes)

            # keep originals and predictions in matching order
            for pending_img, pred in zip(batch_imgs_pil, pred_list):
                all_originals.append(pil_to_numpy_uint8(pending_img))
                all_preds.append(pred.detach().cpu().numpy())

            # reset for the next batch
            for pending in (batch_imgs_pil, batch_target_sizes, batch_indices):
                pending.clear()

    for idx, p in enumerate(paths):
        img_pil = load_rgb(p)
        W, H = img_pil.size  # PIL reports (width, height)
        batch_imgs_pil.append(img_pil)
        batch_target_sizes.append((H, W))
        batch_indices.append(idx)

        if len(batch_imgs_pil) == BATCH_SIZE:
            flush_batch()

    # images left over after the last full batch
    flush_batch()

# ---------- visualize ALL: 2 columns (Original | Original + Prediction) ----------
# SHOW_MAX < 0 means "everything"; otherwise cap at the first SHOW_MAX results.
to_show = len(all_originals) if SHOW_MAX < 0 else min(SHOW_MAX, len(all_originals))

if to_show == 0:
    # plt.subplots raises ValueError for nrows=0 (e.g. SHOW_MAX = 0), so skip plotting.
    print("Nothing to display (SHOW_MAX is 0 or there are no predictions).")
else:
    # squeeze=False keeps `axes` two-dimensional (to_show x 2) even for a single image
    fig, axes = plt.subplots(
        nrows=to_show,
        ncols=2,
        figsize=(FIG_W, ROW_H * to_show),
        squeeze=False,
    )

    shown = zip(all_originals[:to_show], all_preds, paths)
    for r, (img_np, pred, src_path) in enumerate(shown):
        # 1) Original
        ax = axes[r, 0]
        ax.imshow(img_np)
        ax.set_title(f"Original ({Path(src_path).name})")
        ax.axis("off")

        # 2) Original + Prediction
        ax = axes[r, 1]
        ax.imshow(img_np)
        ax.imshow(pred, cmap=cmap, vmin=0, vmax=vmax, alpha=OVERLAY_ALPHA, interpolation='nearest')
        ax.set_title("Original + Prediction")
        ax.axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
import os
from pathlib import Path
from typing import List, Tuple
from tqdm import tqdm

import torch
import numpy as np
from PIL import Image
from matplotlib.colors import ListedColormap

# Project imports (adjust if your package path differs)
from mask2former import (
    Mask2FormerFinetuner,
    ID2LABEL,
    LEARNING_RATE,
)

# ===================== USER CONFIG =====================
# Checkpoint handed to Mask2FormerFinetuner.load_from_checkpoint in this cell.
CKPT_PATH    = "/home/erik/Documents/Finetune-Mask2Former/outputs/lightning_logs_csv/version_4/checkpoints/epoch=6-step=6699.ckpt"
IMAGES_DIR   = "//home/erik/Desktop/test_imgs/imgs_26"  # input folder; NOTE(review): leading "//" looks like a typo for "/" — confirm
OUTPUT_DIR   = "/home/erik/Desktop/test_imgs/workspace_26/segmentation/mask2former"  # overlays are written here under the input file name
# Accepted file suffixes; list_images compares them against the lower-cased suffix.
IMAGE_EXTS   = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}

BATCH_SIZE     = 64           # batching for speed/memory
OVERLAY_ALPHA  = 0.5         # transparency of the segmentation overlay
SEED_PALETTE   = 14          # to get consistent colors across runs

# Optional: also save raw label mask as PNG (each pixel = class id)
SAVE_RAW_MASKS = False
RAW_MASKS_DIR  = "/home/erik/Documents/Finetune-Mask2Former/outputs/masks"  # only created/used when SAVE_RAW_MASKS is True
# =======================================================

# ---------- helpers ----------
def list_images(folder: str, exts: set) -> List[Path]:
    """List the image files found directly in `folder`, in sorted order.

    An entry qualifies when its lower-cased suffix is in `exts` and it is a
    regular file. Sub-directories are not searched.

    Raises:
        FileNotFoundError: if nothing in `folder` qualifies.
    """
    candidates = sorted(Path(folder).iterdir())
    images = [item for item in candidates if item.suffix.lower() in exts and item.is_file()]
    if len(images) == 0:
        raise FileNotFoundError(f"No images found in {folder} with extensions {exts}")
    return images

def load_rgb(path: Path) -> Image.Image:
    """Open the image at `path` and return it converted to RGB mode.

    Uses a context manager so the underlying file handle is closed right after
    conversion instead of lingering until garbage collection; with large
    folders this keeps the number of open files bounded. `convert` returns a
    new image, so the result remains valid after the source is closed.
    """
    with Image.open(path) as src:
        return src.convert("RGB")

def _make_palette(id2label, seed=SEED_PALETTE) -> Tuple[ListedColormap, int]:
    """Build a seeded random color table for ids 0..max_id and return (ListedColormap, max_id).

    When id 0 is a key of `id2label`, its entry is set to black (background).
    """
    if len(id2label):
        max_id = max(id2label.keys())
    else:
        max_id = 0
    generator = np.random.default_rng(seed)
    colors = generator.uniform(0, 1, size=(max_id + 1, 3))
    if 0 in id2label:
        colors[0] = np.array([0.0, 0.0, 0.0])
    colormap = ListedColormap(colors)
    return colormap, max_id

def overlay_image_with_mask(img_np: np.ndarray, mask: np.ndarray, cmap: ListedColormap, vmax: int, alpha: float) -> np.ndarray:
    """
    Overlay a segmentation mask directly onto the original RGB image (keeps original resolution).

    Every pixel whose class id is non-zero is replaced by
    ``(1 - alpha) * image + alpha * class_color``; pixels with id 0
    (background) are copied from the image unchanged.

    Args:
        img_np: HxWx3 uint8 image. It is not modified; a copy is returned.
        mask: HxW map of class ids.
        cmap: Palette in which entry ``k`` is the color of class id ``k``.
        vmax: Highest class id of the palette. Kept for interface
            compatibility; colors are looked up by integer id, so the value is
            no longer used for normalization (which divided by zero when
            ``vmax == 0``).
        alpha: Opacity of the class color in the blend.

    Returns:
        HxWx3 uint8 array with the blended overlay.
    """
    # Integer input makes the colormap index its color table directly, so
    # class id k maps to palette entry k without any float normalization.
    class_ids = np.asarray(mask).astype(np.intp)
    mask_rgba = cmap(class_ids)
    mask_rgb = (mask_rgba[:, :, :3] * 255).astype(np.uint8)

    # blend: img * (1-alpha) + mask * alpha (where mask != background)
    overlay = img_np.copy()
    non_bg = class_ids != 0
    overlay[non_bg] = ((1 - alpha) * img_np[non_bg] + alpha * mask_rgb[non_bg]).astype(np.uint8)

    return overlay


# ---------- prep IO ----------
paths = list_images(IMAGES_DIR, IMAGE_EXTS)

# Create the overlay folder and, only when raw masks are requested, the mask folder.
for _needed_dir in [OUTPUT_DIR] + ([RAW_MASKS_DIR] if SAVE_RAW_MASKS else []):
    Path(_needed_dir).mkdir(parents=True, exist_ok=True)

print(f"Found {len(paths)} images in: {IMAGES_DIR}")
print(f"Saving overlays to: {OUTPUT_DIR}")

# ---------- load model ----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Mask2FormerFinetuner.load_from_checkpoint(CKPT_PATH, id2label=ID2LABEL, lr=LEARNING_RATE)
model = model.eval().to(device)  # eval mode, moved to the selected device

processor = model.processor
cmap, vmax = _make_palette(ID2LABEL)

# ---------- batched inference and saving ----------
with torch.no_grad():
    # Pending-batch state: filled by the loop below, drained by flush_batch().
    batch_imgs_pil = []
    batch_sizes = []
    batch_paths = []

    def flush_batch():
        """Segment the pending images, write their overlays (and optional raw masks), then reset the batch."""
        if not batch_imgs_pil:
            return

        encoded = processor(images=batch_imgs_pil, return_tensors="pt")
        outputs = model.model(pixel_values=encoded["pixel_values"].to(device))
        pred_list = processor.post_process_semantic_segmentation(outputs, target_sizes=batch_sizes)

        overlay_root = Path(OUTPUT_DIR)
        for pending_img, mask_t, src in zip(batch_imgs_pil, pred_list, batch_paths):
            mask_np = mask_t.detach().cpu().numpy().astype(np.int32)
            overlay_np = overlay_image_with_mask(
                np.array(pending_img, dtype=np.uint8), mask_np, cmap, vmax, OVERLAY_ALPHA
            )

            # overlay keeps the input file name (and therefore its format)
            Image.fromarray(overlay_np).save(overlay_root / src.name)

            if SAVE_RAW_MASKS:
                # raw class ids as a 16-bit PNG named after the input stem
                raw_mask = Image.fromarray(mask_np.astype(np.uint16), mode="I;16")
                raw_mask.save(Path(RAW_MASKS_DIR) / (src.stem + ".png"))

        for pending in (batch_imgs_pil, batch_sizes, batch_paths):
            pending.clear()

    for p in tqdm(paths, desc="Processing images"):
        img_pil = load_rgb(p)
        W, H = img_pil.size  # PIL reports (width, height)
        batch_imgs_pil.append(img_pil)
        batch_sizes.append((H, W))
        batch_paths.append(p)

        if len(batch_imgs_pil) == BATCH_SIZE:
            flush_batch()

    # images left over after the last full batch
    flush_batch()

print("Done.")