In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive so the /content/drive/MyDrive/... paths
# configured below can be read and written. Requires the google.colab package,
# i.e. a Colab runtime.
drive.mount('/content/drive')



In [None]:
import os
import numpy as np
import cv2
from scipy.ndimage import distance_transform_edt
import matplotlib.pyplot as plt

# Input/output locations on the mounted Google Drive (edit per experiment).
PRED_PATH = "/content/drive/MyDrive/inferencee/png/geb25_1000_sdxl_binary.png"  # prediction image
GT_PATH   = "/content/drive/MyDrive/inferencee/png/geb25_clip_4270_target.png"  # ground-truth image
OUT_DIR   = "/content/drive/MyDrive/inferencee/Bdis_25k_sdxl"  # directory for the visualization PNG
IMG_NAME  = "example25"  # label used in the plot title and the output file name

# Binarisation settings read by to_binary_u8():
#   USE_POSITIVE_AS_ONE = True  -> every pixel > 0 is foreground.
#   USE_POSITIVE_AS_ONE = False -> values are scaled to [0, 1] (divided by 255
#                                  when the image max exceeds 1.0) and pixels
#                                  strictly above BIN_THR are foreground.
USE_POSITIVE_AS_ONE = False
BIN_THR = 0.5

def distance_field(binary_mask_u8, sampling=None):
    """
    Compute the unsigned Euclidean distance from every pixel to the nearest
    foreground pixel of a binary mask.

    The mask is binarised with ``> 0``, so both 0/1 and 0/255 inputs work.
    ``scipy.ndimage.distance_transform_edt`` measures distance to the nearest
    zero of its input, so it is applied to the *inverted* mask. Foreground
    pixels therefore get distance 0, and every background pixel gets its
    distance to the closest foreground pixel. This is the distance to the mask
    region, not a signed distance to its contour: every pixel inside the mask
    is 0.

    Args:
        binary_mask_u8: array whose nonzero entries are foreground.
        sampling: optional pixel spacing forwarded to
            ``distance_transform_edt`` (a scalar or one value per axis).
            ``None`` (the default) means unit spacing, so distances are in
            pixels.

    Returns:
        A float32 array with the same shape as the input. If the mask has no
        foreground pixels, the distance is undefined and every entry is +inf.
    """
    binary_mask = np.asarray(binary_mask_u8) > 0
    if not binary_mask.any():
        # Nothing to measure distance to. Return an explicit sentinel rather
        # than relying on what the EDT yields for an input with no zeros.
        return np.full(binary_mask.shape, np.inf, dtype=np.float32)
    dist_map = distance_transform_edt(~binary_mask, sampling=sampling)
    return dist_map.astype(np.float32)

def to_binary_u8(img_gray, *, positive_as_one=None, thr=None):
    """
    Binarise a grayscale image into a uint8 mask containing only 0 and 1.

    Two modes are available, selected by ``positive_as_one``:

    * ``True``: every strictly positive pixel becomes 1.
    * ``False``: values are first brought to [0, 1]. If the image maximum
      exceeds 1.0, the image is assumed to be 0..255 and divided by 255.
      Otherwise it is used as-is, so 0/1 masks are not squashed. Pixels
      strictly greater than ``thr`` then become 1.

    Args:
        img_gray: array of pixel values (any numeric dtype).
        positive_as_one: mode override. ``None`` (the default) uses the
            module-level ``USE_POSITIVE_AS_ONE``.
        thr: threshold applied to the [0, 1]-scaled values. ``None`` (the
            default) uses the module-level ``BIN_THR``.

    Returns:
        A uint8 array of the same shape whose entries are 0 or 1.
    """
    if positive_as_one is None:
        positive_as_one = USE_POSITIVE_AS_ONE

    img = np.asarray(img_gray)
    if positive_as_one:
        return (img > 0).astype(np.uint8)

    if thr is None:
        thr = BIN_THR

    # Scale detection runs on the original dtype and is decided per image.
    # The size guard keeps .max() from raising on an empty array.
    if img.size and img.max() > 1.0:
        x = img.astype(np.float32) / 255.0
    else:
        x = img.astype(np.float32)
    return (x > thr).astype(np.uint8)

def compute_bnddis_single(pred_img, gt_img, return_intermediates=False):
    """
    Compute BndDis for a single prediction / ground-truth image pair.

    Both images are binarised with ``to_binary_u8``, then::

        diff  = pred XOR gt
        dist  = distance_field(gt)          # edt(~gt): 0 on every GT pixel
        recip = 1 / dist where dist > 0, else 0
        bdis  = mean(diff * recip)

    A mismatched pixel outside the GT is weighted by the reciprocal of its
    distance to the nearest GT pixel: 1.0 for a direct neighbour, decaying
    with distance.

    NOTE(review): because ``dist`` is 0 on every GT pixel, missed GT pixels
    (pred=0, gt=1) get weight 0 and do not contribute to the score. Confirm
    this is the intended definition of the metric.

    Args:
        pred_img: grayscale prediction image.
        gt_img: grayscale ground-truth image with the same shape as
            ``pred_img``.
        return_intermediates: if True, also return the masks and maps used.

    Returns:
        ``bdis`` as a float, or ``None`` if either binarised mask is empty.
        With ``return_intermediates=True``, returns the tuple
        ``(bdis, pred_mask, gt_mask, diff, recip_map)``. In the empty case,
        ``bdis``, ``diff`` and ``recip_map`` are ``None``.

    Raises:
        ValueError: if ``pred_img`` and ``gt_img`` have different shapes.
            Broadcasting would otherwise silently produce a wrong score.
    """
    if np.shape(pred_img) != np.shape(gt_img):
        raise ValueError(
            f"pred and gt must have the same shape, got "
            f"{np.shape(pred_img)} and {np.shape(gt_img)}"
        )

    pred_mask = to_binary_u8(pred_img)
    gt_mask = to_binary_u8(gt_img)

    if not pred_mask.any() or not gt_mask.any():
        if return_intermediates:
            return None, pred_mask, gt_mask, None, None
        return None

    dist_map = distance_field(gt_mask)
    # Divide only where dist > 0, i.e. outside the GT. GT pixels keep the 0
    # from zeros_like, so no epsilon is needed to avoid division by zero.
    recip_map = np.zeros_like(dist_map)
    np.divide(1.0, dist_map, out=recip_map, where=dist_map > 0)

    diff = (pred_mask ^ gt_mask).astype(np.float32)

    bdis = float((diff * recip_map).mean())

    if return_intermediates:
        return bdis, pred_mask, gt_mask, diff, recip_map
    return bdis

def visualize_single(pred_mask, gt_mask, diff, recip_map, title, save_dir, save_name):
    """
    Plot the four BndDis inputs/outputs in a 2x2 grid, save it, and show it.

    Panels, in reading order: prediction mask, ground-truth mask, XOR
    difference, and reciprocal-distance map (``jet`` colormap).

    Args:
        pred_mask, gt_mask, diff, recip_map: 2-D arrays to display.
        title: figure super-title.
        save_dir: output directory, created if it does not exist.
        save_name: file stem; the figure is written to
            ``<save_dir>/<save_name>.png`` at 150 dpi.

    Returns:
        The path of the saved PNG.
    """
    os.makedirs(save_dir, exist_ok=True)

    panels = [
        (pred_mask, 'gray', 'Prediction'),
        (gt_mask, 'gray', 'Ground Truth'),
        (diff, 'gray', 'Difference (XOR)'),
        (recip_map, 'jet', 'Reciprocal Dist to GT boundary'),
    ]
    fig, axs = plt.subplots(2, 2, figsize=(8, 8))
    for ax, (image, cmap, name) in zip(axs.flat, panels):
        ax.imshow(image, cmap=cmap)
        ax.set_title(name)
        ax.axis('off')

    fig.suptitle(title, fontsize=14)
    fig.tight_layout()
    out_path = os.path.join(save_dir, f"{save_name}.png")
    fig.savefig(out_path, dpi=150)
    plt.show()
    # Release the figure explicitly. Otherwise repeated calls (e.g. in a loop
    # over images) keep every figure alive in pyplot's figure manager.
    plt.close(fig)
    print(f"Visualization saved: {out_path}")
    return out_path

if __name__ == "__main__":
    # Validate inputs with explicit raises: `assert` statements are stripped
    # under `python -O`, which would let a bad path through to cv2.imread.
    for label, path in (("Prediction", PRED_PATH), ("GT", GT_PATH)):
        if not os.path.exists(path):
            raise FileNotFoundError(f"{label} image does not exist: {path}")

    gt_img = cv2.imread(GT_PATH, cv2.IMREAD_GRAYSCALE)
    pred_img_raw = cv2.imread(PRED_PATH, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None instead of raising.
    if gt_img is None or pred_img_raw is None:
        raise RuntimeError("Failed to read images. Please check the paths.")

    # Align the prediction to the GT grid. Nearest-neighbour interpolation
    # keeps a binary mask binary (no intermediate gray values).
    H, W = gt_img.shape
    if pred_img_raw.shape != gt_img.shape:
        print(f"ℹ️ Size mismatch: pred={pred_img_raw.shape[::-1]}, gt={gt_img.shape[::-1]}. Resizing pred to {W}x{H}")
        pred_img = cv2.resize(pred_img_raw, (W, H), interpolation=cv2.INTER_NEAREST)
    else:
        pred_img = pred_img_raw

    bdis, pred_mask, gt_mask, diff, recip_map = compute_bnddis_single(
        pred_img, gt_img, return_intermediates=True
    )

    if bdis is None:
        print("Prediction or GT mask is empty; cannot compute BndDis.")
    else:
        print(f"BndDis (XOR + reciprocal distance to GT) = {bdis:.6f}")
        visualize_single(pred_mask, gt_mask, diff, recip_map,
                         title=f"{IMG_NAME} | BndDis={bdis:.6f}",
                         save_dir=OUT_DIR,
                         save_name=f"{IMG_NAME}_bdis_{bdis:.6f}")
