In [None]:
# Mount Google Drive so the /content/drive/... paths used below resolve (Colab only).
import google.colab.drive as drive

drive.mount('/content/drive')

In [None]:
# Single image: Boundary IoU (Euclidean distance band of radius TOL_RADIUS_PX; 0 = strict boundary match) + save only the difference map of unmatched near-boundary bands (red=FP, blue=FN)
import os
import numpy as np
import cv2
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Best-effort Google Drive mount: only attempted inside Colab, and skipped when
# /content/drive is already a mount point.
try:
    from google.colab import drive
except ImportError as e:
    print("Non-Colab environment; skipping mount.", e)
else:
    try:
        # Skip the mount call if Drive is already mounted at this path.
        if not os.path.ismount('/content/drive'):
            drive.mount('/content/drive', force_remount=False)
        print("Attempted to mount Google Drive (ignored if not in Colab).")
    except Exception as e:
        # Stay best-effort, but report the real failure instead of claiming we
        # are outside Colab.
        print("Google Drive mount failed; continuing without it.", e)

# --- Input / output paths ----------------------------------------------------
PRED_PATH = "/content/drive/MyDrive/inferencee/png/geb15_1000_sdxl_binary.png"    # predicted mask image
GT_PATH = "/content/drive/MyDrive/inferencee/png/geb15_clip_4270_target.png"      # ground-truth mask image
OUT_PATH = "/content/drive/MyDrive/inferencee/boundary_diff_only_biou_15k_sdxl.png"  # saved difference overlay

# --- Prediction decoding -----------------------------------------------------
INFORMAT = "probs"   # "probs": values already in [0, 1]; "logit": sigmoid is applied first
BIN_THR = 0.5        # foreground threshold applied to both GT and prediction

# --- Boundary IoU parameters -------------------------------------------------
THETA0 = 3           # max-pool kernel size used to extract mask boundaries
TOL_RADIUS_PX = 28   # Euclidean band radius d in pixels (0 = boundary pixels only)

# --- Visualisation -----------------------------------------------------------
BASE_FOR_CONTEXT = "gt"  # background image: "gt", "pred", or anything else for flat grey
BASE_FADE = 0.6          # blend towards white: 0 keeps the base, 1 is fully white
VIS_THICKEN_PX = 2       # number of 3x3 dilation iterations applied to FP/FN pixels for display

def load_gray_01(path, resize_hw=None):
    """Read ``path`` as a grayscale image and return it as float32 scaled to [0, 1].

    Args:
        path: image file path passed to ``cv2.imread``.
        resize_hw: optional target size. Despite the name, it is handed directly
            to ``cv2.resize`` as ``dsize`` and must be ordered ``(width, height)``.
            Nearest-neighbour interpolation is used so no new intensity values
            are introduced (binary masks stay binary).

    Raises:
        ValueError: if OpenCV cannot read the file.
    """
    gray = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        raise ValueError(f"Failed to read image: {path}")
    if resize_hw is None:
        return gray.astype(np.float32) / 255.0
    resized = cv2.resize(gray, resize_hw, interpolation=cv2.INTER_NEAREST)
    return resized.astype(np.float32) / 255.0

def to_tensor01(a01_hw):
    """Wrap a NumPy array as a tensor with two leading singleton axes.

    For an (H, W) input the result is (1, 1, H, W). The tensor shares memory
    with the input array (no copy is made).
    """
    return torch.from_numpy(a01_hw)[None, None]

def extract_boundary(bin_mask, k):
    """Return the inner boundary of a binary mask as a {0, 1} map.

    A pixel is 1 when it is foreground (1) and at least one pixel in the k x k
    window centred on it is background (0); every other pixel is 0. The
    computation is ``maxpool(1 - mask) - (1 - mask)``, i.e. a dilation of the
    background minus the background. ``max_pool2d`` pads with -inf, so pixels
    outside the image never count as background: foreground touching the image
    frame is not marked along the frame.

    Args:
        bin_mask: float tensor accepted by ``F.max_pool2d`` (e.g. (N, C, H, W))
            with values in {0, 1}.
        k: window size; must be a positive odd integer so the window is centred
            and the output keeps the input's spatial size.

    Returns:
        Tensor with the same shape as ``bin_mask``.

    Raises:
        ValueError: if ``k`` is not a positive odd integer (an even ``k`` would
            otherwise fail later with an opaque shape-mismatch error).
    """
    if k < 1 or k % 2 == 0:
        raise ValueError(f"boundary kernel size must be a positive odd integer, got {k}")
    background = 1 - bin_mask
    dilated_bg = F.max_pool2d(background, kernel_size=k, stride=1,
                              padding=(k - 1) // 2)
    return dilated_bg - background

def boundary_iou(pred_bin, gt_bin, theta0=3, tol_px=0):
    """Boundary IoU between two binary masks using an exact Euclidean band.

    For each mask the boundary is extracted with ``extract_boundary`` (kernel
    ``theta0``). The mask's *band* is every foreground pixel whose exact
    Euclidean distance to that boundary is <= ``tol_px``. The score is
    |bandP & bandG| / |bandP | bandG|, or 1.0 when both bands are empty.

    Args:
        pred_bin, gt_bin: tensors where only ``[0, 0]`` is used (a 2-D map);
            values > 0.5 are foreground. Tensors requiring grad are accepted.
        theta0: odd kernel size for boundary extraction.
        tol_px: non-negative band radius in pixels (0 = boundary pixels only).
            Fractional values are honoured rather than truncated.

    Returns:
        (iou, stats, eP, eG, bandP, bandG) where ``stats`` has integer counts
        "inter", "union", "bandP", "bandG", and eP/eG/bandP/bandG are uint8
        {0, 1} arrays for the prediction / ground truth.

    Raises:
        ValueError: if ``tol_px`` is negative (it would otherwise yield empty
            bands and a misleading IoU of 1.0).
    """
    d = float(tol_px)
    if d < 0:
        raise ValueError(f"tol_px must be >= 0, got {tol_px}")

    # detach() so tensors that require grad can still be converted to NumPy.
    P = (pred_bin[0, 0].detach().cpu().numpy() > 0.5).astype(np.uint8)
    G = (gt_bin[0, 0].detach().cpu().numpy() > 0.5).astype(np.uint8)

    def _edge_from_mask(mask01, k):
        """Boundary pixels of a {0, 1} NumPy mask as a uint8 {0, 1} array."""
        t = torch.from_numpy(mask01.astype(np.float32)).unsqueeze(0).unsqueeze(0)
        e = extract_boundary(t, k)[0, 0].cpu().numpy()
        return (e > 0).astype(np.uint8)

    def _band(mask01, edge01):
        """Foreground pixels within Euclidean distance d of the edge pixels."""
        if not edge01.any():
            # No edge pixel to measure from: define the band as empty instead of
            # depending on the sentinel distance OpenCV reports in this case.
            return np.zeros_like(mask01)
        # distanceTransform measures distance to the nearest zero pixel, so edge
        # pixels are encoded as 0 and everything else as 1.
        src = np.where(edge01 > 0, 0, 1).astype(np.uint8)
        # DIST_MASK_PRECISE gives exact Euclidean distances; the 3x3/5x5 masks
        # are approximations that can flip pixels at the band radius.
        dist = cv2.distanceTransform(src, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
        return ((dist <= d) & (mask01 > 0)).astype(np.uint8)

    eP = _edge_from_mask(P, theta0)
    eG = _edge_from_mask(G, theta0)
    bandP = _band(P, eP)
    bandG = _band(G, eG)

    inter = int(np.logical_and(bandP > 0, bandG > 0).sum())
    union = int(np.logical_or(bandP > 0, bandG > 0).sum())
    iou = 1.0 if union == 0 else inter / union

    stats = {
        "inter": inter,
        "union": union,
        "bandP": int(bandP.sum()),
        "bandG": int(bandG.sum()),
    }
    return iou, stats, eP, eG, bandP, bandG

def thicken(mask_bool, iters=2):
    """Grow a boolean mask by ``iters`` iterations of 3x3 dilation (for display).

    When ``iters`` is None or <= 0 the input object itself is returned
    unchanged; otherwise a new boolean array of the same shape is returned.
    """
    skip = iters is None or iters <= 0
    if skip:
        return mask_bool
    kernel3 = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    as_u8 = mask_bool.astype(np.uint8) * 255
    grown = cv2.dilate(as_u8, kernel3, iterations=int(iters))
    return grown > 0

# Validate inputs with real exceptions (assert statements are stripped when
# Python runs with -O).
if not os.path.exists(PRED_PATH):
    raise FileNotFoundError(f"Prediction image not found: {PRED_PATH}")
if not os.path.exists(GT_PATH):
    raise FileNotFoundError(f"GT image not found: {GT_PATH}")
if INFORMAT not in ("probs", "logit"):
    raise ValueError(f'INFORMAT must be "probs" or "logit", got {INFORMAT!r}')

# The GT image defines the evaluation grid.
gt01 = load_gray_01(GT_PATH)
H, W = gt01.shape

# Read the prediction once; resize it onto the GT grid only when needed.
pred_raw = cv2.imread(PRED_PATH, cv2.IMREAD_GRAYSCALE)
if pred_raw is None:
    raise ValueError(f"Failed to read image: {PRED_PATH}")
if pred_raw.shape != (H, W):
    print(f"Prediction size differs from GT; resizing to {W}x{H}.")
    # cv2.resize takes dsize as (width, height); nearest-neighbour adds no new values.
    pred01 = cv2.resize(pred_raw, (W, H), interpolation=cv2.INTER_NEAREST).astype(np.float32) / 255.0
else:
    pred01 = pred_raw.astype(np.float32) / 255.0

gt_t   = to_tensor01(gt01)
pred_t = to_tensor01(pred01)
if INFORMAT == "logit":
    # NOTE(review): pred01 comes from an 8-bit image scaled to [0, 1], so the
    # sigmoid maps it into [0.5, ~0.73] and every nonzero pixel exceeds
    # BIN_THR=0.5. Only use "logit" when the input really stores logits.
    pred_t = torch.sigmoid(pred_t)

gt_bin   = (gt_t   > BIN_THR).float()
pred_bin = (pred_t > BIN_THR).float()

t = int(TOL_RADIUS_PX)
biou, stats, eP, eG, bandP, bandG = boundary_iou(pred_bin, gt_bin, theta0=THETA0, tol_px=t)

print(f"Params: theta0={THETA0} (boundary kernel), d={t}px (Euclidean tolerance)")
print(f"Boundary IoU (Euclidean d): {biou:.4f}")
print(f"counts: |band_pred|={stats['bandP']}, |band_gt|={stats['bandG']}, "
      f"|inter|={stats['inter']}, |union|={stats['union']}")

# Unmatched band pixels: present in one band but not in the other.
bfp = (bandP > 0) & (bandG == 0)   # predicted band with no GT band -> false positive
bfn = (bandG > 0) & (bandP == 0)   # GT band with no predicted band -> false negative

# Dilate for visibility only; the metric above is computed on the raw bands.
bfp_viz = thicken(bfp, VIS_THICKEN_PX)
bfn_viz = thicken(bfn, VIS_THICKEN_PX)

# Grey context image underneath the coloured pixels.
if BASE_FOR_CONTEXT == "gt":
    base = (gt01 * 255.0).astype(np.uint8)
elif BASE_FOR_CONTEXT == "pred":
    base = (pred01 * 255.0).astype(np.uint8)
else:
    # Any other value (e.g. "none") gives a flat light-grey canvas.
    base = np.full((H, W), 230, dtype=np.uint8)

# Blend towards white: FADE=0 keeps the base, FADE=1 yields a white canvas.
FADE = float(np.clip(BASE_FADE, 0.0, 1.0))
base_faded = np.clip(base*(1-FADE) + 255*FADE, 0, 255).astype(np.uint8)
overlay = cv2.cvtColor(base_faded, cv2.COLOR_GRAY2BGR)

# OpenCV images are BGR: [0, 0, 255] is red (FP), [255, 0, 0] is blue (FN).
# FN is painted last, so it wins where the thickened masks overlap.
overlay[bfp_viz] = [0, 0, 255]
overlay[bfn_viz] = [255, 0, 0]

os.makedirs(os.path.dirname(OUT_PATH) or ".", exist_ok=True)
# cv2.imwrite reports failure by returning False rather than raising.
if not cv2.imwrite(OUT_PATH, overlay):
    raise OSError(f"Failed to write image: {OUT_PATH}")
print(f"Saved difference map of unmatched near-boundary bands (Euclidean BIoU definition): {OUT_PATH}")

# Display the overlay inline; matplotlib expects RGB while OpenCV produced BGR.
title_tol = f"(d={t}px)" if t != 0 else "(strict d=0)"
plt.figure(figsize=(12, 7))
plt.imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB), interpolation="nearest")
plt.title(f"Boundary Non-overlap (Euclidean) {title_tol}  |  BIoU={biou:.4f}")
plt.axis("off")
plt.tight_layout()
plt.show()
