In [None]:
import os, glob, json, random
import numpy as np
import cv2
import openslide
from tqdm import tqdm

# =========================
# CONFIG
# =========================
# Input folder with the paired slides and the output folder for the extracted dataset.
KPI_DIR = "/Users/edmundtsou/Desktop/JEFworks/jefworks-structure_segmentation/data/KPI_challenge_wsi_data"   # contains normal_F*_wsi.tiff and normal_F*_mask.tiff
OUT_DIR = "/Users/edmundtsou/Desktop/JEFworks/jefworks-structure_segmentation/data/medsam_adaptor_train_data"

PATCH_SIZE = 1024               # side length (px) of every square crop read at level 0
N_AUG_PER_INSTANCE = 4          # number of augmented versions per glomerulus crop
JITTER_FRAC = 0.25              # jitter up to 25% of patch size in x/y
MIN_MASK_PIXELS = 5000          # drop crops with too little glomerulus mask
MASK_THRESH = 0                 # mask > thresh => foreground

# speed: find bboxes at a coarse level
FIND_LEVEL = None               # None = use the coarsest (highest-index) level automatically; or set int like 3

# Seed both generators: the crop jitter / augmentations draw from `random`,
# and the gaussian-noise augmentation draws from `np.random`.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

# Create the output layout up front: OUT_DIR/images and OUT_DIR/masks.
os.makedirs(OUT_DIR, exist_ok=True)
IMG_OUT = os.path.join(OUT_DIR, "images")
MSK_OUT = os.path.join(OUT_DIR, "masks")
os.makedirs(IMG_OUT, exist_ok=True)
os.makedirs(MSK_OUT, exist_ok=True)


# =========================
# HELPERS: OpenSlide reading
# =========================
def read_region_rgb(slide, x, y, size, level=0):
    """
    Read a square patch from a slide and return it as an RGB uint8 array.

    `x`, `y` and `size` are truncated to ints before being handed to
    `slide.read_region`; `level` is passed through unchanged. The returned
    image is converted to "RGB" before being turned into a NumPy array.
    """
    origin = (int(x), int(y))
    extent = (int(size), int(size))
    region = slide.read_region(origin, level, extent)
    return np.array(region.convert("RGB"), dtype=np.uint8)

def read_region_mask(mask_slide, x, y, size, level=0):
    """
    Read a square mask patch and binarize it.

    The region is read like an image patch, converted to "RGB", and only the
    first channel is kept (for grayscale masks all channels are presumably
    identical). Pixels strictly greater than MASK_THRESH become 1, all others
    0; the result is a single-channel uint8 array.
    """
    origin = (int(x), int(y))
    extent = (int(size), int(size))
    tile = mask_slide.read_region(origin, level, extent).convert("RGB")
    first_channel = np.array(tile, dtype=np.uint8)[..., 0]
    return (first_channel > MASK_THRESH).astype(np.uint8)

def clamp_xy(x, y, W, H, size):
    """
    Clamp a patch's top-left corner so a `size` x `size` patch fits in a W x H canvas.

    Each coordinate is truncated to int, capped at (dimension - size) and then
    floored at 0, so a canvas smaller than `size` yields 0 for that axis.
    Returns the clamped (x, y) as ints.
    """
    max_x = int(W - size)
    max_y = int(H - size)
    capped_x = min(int(x), max_x)
    capped_y = min(int(y), max_y)
    return max(0, capped_x), max(0, capped_y)


# =========================
# FIND GLOMERULI BBOXES FROM MASK (FAST)
# =========================
def find_glomeruli_bboxes_from_mask(mask_slide, find_level=None, min_area_lvl=50):
    """
    Locate mask components on a coarse pyramid level and return level-0 bboxes.

    Steps:
      1. read the whole mask at `find_level` (the last pyramid level when None)
      2. binarize the first channel with MASK_THRESH
      3. apply one 3x3 elliptical morphological opening
      4. label 8-connected components
      5. scale each component's bbox by that level's downsample factor

    Components whose area (in pixels of the chosen level) is below
    `min_area_lvl` are skipped. Returns a list of (x0, y0, x1, y1) int tuples.
    """
    lvl = mask_slide.level_count - 1 if find_level is None else int(find_level)

    Wl, Hl = mask_slide.level_dimensions[lvl]
    down = mask_slide.level_downsamples[lvl]

    # whole mask at the coarse level; only the first channel is used
    coarse = mask_slide.read_region((0, 0), lvl, (Wl, Hl)).convert("RGB")
    fg = (np.array(coarse, dtype=np.uint8)[..., 0] > MASK_THRESH).astype(np.uint8)

    # light clean-up of the binary mask
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    opened = cv2.morphologyEx(fg * 255, cv2.MORPH_OPEN, kernel, iterations=1)
    fg = (opened > 0).astype(np.uint8)

    n_labels, _labels, stats, _centroids = cv2.connectedComponentsWithStats(fg, connectivity=8)

    # label 0 is the background, so start at row 1
    return [
        (int(x * down), int(y * down), int((x + w) * down), int((y + h) * down))
        for x, y, w, h, area in stats[1:n_labels]
        if area >= min_area_lvl
    ]


# =========================
# AUGMENTATION (SAM-adaptor friendly)
# =========================
def aug_geometric(img, msk):
    """
    Apply the same random flips and quarter-turn rotation to an image and its mask.

    Draw order from `random`: horizontal-flip coin, vertical-flip coin, then
    the number of 90-degree rotations (0..3). Transformed arrays are copied so
    they do not alias the inputs; untouched inputs are returned as-is.
    """
    # left-right mirror
    if random.random() < 0.5:
        img, msk = np.fliplr(img).copy(), np.fliplr(msk).copy()

    # up-down mirror
    if random.random() < 0.5:
        img, msk = np.flipud(img).copy(), np.flipud(msk).copy()

    # rotation by a multiple of 90 degrees
    quarter_turns = random.randint(0, 3)
    if quarter_turns != 0:
        img = np.rot90(img, quarter_turns).copy()
        msk = np.rot90(msk, quarter_turns).copy()

    return img, msk

def aug_color_noise(img):
    """
    Randomly perturb the appearance of an RGB image and return it as uint8.

    Each stage is gated by its own `random.random()` draw, applied in this order:
      1. HSV jitter (hue shift wrapped modulo 180, saturation/value gains)
      2. contrast gain + brightness offset
      3. gamma
      4. Gaussian blur with a 3x3 or 5x5 kernel
      5. additive Gaussian noise (drawn from `np.random`)

    Geometry is untouched, so a paired mask stays valid.
    """
    work = img.astype(np.float32)

    # 1) HSV jitter
    if random.random() < 0.9:
        hsv = cv2.cvtColor(work.astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
        hue_shift = random.uniform(-12, 12)
        hsv[..., 0] = (hsv[..., 0] + hue_shift) % 180
        sat_gain = random.uniform(0.6, 1.6)
        hsv[..., 1] = np.clip(hsv[..., 1] * sat_gain, 0, 255)
        val_gain = random.uniform(0.7, 1.4)
        hsv[..., 2] = np.clip(hsv[..., 2] * val_gain, 0, 255)
        work = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB).astype(np.float32)

    # 2) contrast / brightness
    if random.random() < 0.9:
        contrast = random.uniform(0.7, 1.5)
        brightness = random.uniform(-25, 25)
        work = np.clip(contrast * work + brightness, 0, 255)

    # 3) gamma curve on the [0, 1]-normalized image
    if random.random() < 0.7:
        exponent = random.uniform(0.7, 1.6)
        work = np.clip(255.0 * np.power(work / 255.0, exponent), 0, 255)

    # 4) blur
    if random.random() < 0.3:
        ksize = random.choice([3, 5])
        work = cv2.GaussianBlur(work, (ksize, ksize), 0)

    # 5) additive gaussian noise
    if random.random() < 0.5:
        std = random.uniform(2, 10)
        jitter = np.random.normal(0, std, size=work.shape).astype(np.float32)
        work = np.clip(work + jitter, 0, 255)

    return work.astype(np.uint8)

def augment_pair(img, msk):
    """
    Augment an image/mask pair: shared geometric transform, then
    appearance-only changes on the image. Returns (image, mask).
    """
    geo_img, geo_msk = aug_geometric(img, msk)
    return aug_color_noise(geo_img), geo_msk


# =========================
# MAIN EXTRACTION
# =========================
def paired_paths_from_kpi_dir(kpi_dir):
    """
    Return sorted (wsi_path, mask_path) pairs found directly inside `kpi_dir`.

    A slide `<name>_wsi.tiff` is paired with `<name>_mask.tiff` in the same
    folder; slides whose mask file does not exist are skipped. Returns an
    empty list when nothing matches.
    """
    wsi_suffix = "_wsi.tiff"
    mask_suffix = "_mask.tiff"

    pairs = []
    for wsi in sorted(glob.glob(os.path.join(kpi_dir, "*" + wsi_suffix))):
        # Swap only the trailing suffix. A plain str.replace would also rewrite
        # any "_wsi.tiff" occurring earlier in the path (e.g. in a folder name)
        # and produce a mask path that does not exist.
        mask = wsi[:-len(wsi_suffix)] + mask_suffix
        if os.path.exists(mask):
            pairs.append((wsi, mask))
    return pairs

def _save_patch_pair(name, img, msk, aug_tag, common):
    """Write one RGB/mask PNG pair under IMG_OUT/MSK_OUT and return its metadata record."""
    img_path = os.path.join(IMG_OUT, name + ".png")
    msk_path = os.path.join(MSK_OUT, name + ".png")

    # cv2 writes BGR, the patches are RGB; the 0/1 mask is stored as 0/255.
    cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
    cv2.imwrite(msk_path, (msk * 255).astype(np.uint8))

    record = {"id": name}
    record.update(common)
    record.update({
        "aug": aug_tag,
        "mask_pixels": int(msk.sum()),
        "img_path": img_path,
        "msk_path": msk_path,
    })
    return record


def _extract_slide_patches(wsi, msk_slide, base, wsi_path, mask_path, meta, idx):
    """Extract and save all crops of one slide; append records to `meta` and return the updated running index."""
    W0, H0 = wsi.level_dimensions[0]
    bboxes = find_glomeruli_bboxes_from_mask(msk_slide, find_level=FIND_LEVEL, min_area_lvl=50)
    print(f"Found {len(bboxes)} glomerulus components (coarse bboxes)")

    max_jitter = int(JITTER_FRAC * PATCH_SIZE)

    for gi, (x0, y0, x1, y1) in enumerate(tqdm(bboxes, desc=f"{base} gloms")):
        # center of bbox in level-0
        cx = 0.5 * (x0 + x1)
        cy = 0.5 * (y0 + y1)

        # random jitter so glomerulus not always centered
        jx = random.randint(-max_jitter, max_jitter)
        jy = random.randint(-max_jitter, max_jitter)

        px = int(cx - PATCH_SIZE / 2 + jx)
        py = int(cy - PATCH_SIZE / 2 + jy)
        px, py = clamp_xy(px, py, W0, H0, PATCH_SIZE)

        # read paired crop
        img = read_region_rgb(wsi, px, py, PATCH_SIZE, level=0)
        msk = read_region_mask(msk_slide, px, py, PATCH_SIZE, level=0)

        # filter out empties
        if int(msk.sum()) < MIN_MASK_PIXELS:
            continue

        # fields shared by the original crop and all of its augmentations
        common = {
            "slide": base,
            "wsi_path": wsi_path,
            "mask_path": mask_path,
            "glm_index": gi,
            "x": px, "y": py,
            "size": PATCH_SIZE,
        }

        # save ORIGINAL
        name = f"{base}_g{gi:05d}_i{idx:07d}_orig"
        meta.append(_save_patch_pair(name, img, msk, "orig", common))
        idx += 1

        # save AUGMENTED versions
        for ai in range(N_AUG_PER_INSTANCE):
            img_a, msk_a = augment_pair(img, msk)
            name = f"{base}_g{gi:05d}_i{idx:07d}_aug{ai}"
            meta.append(_save_patch_pair(name, img_a, msk_a, f"aug{ai}", common))
            idx += 1

    return idx


def extract_kpi_patches():
    """
    Build the patch dataset from every (WSI, mask) pair in KPI_DIR.

    For each mask component found on the coarse level, one jittered
    PATCH_SIZE crop is read at level 0 from both slides. Crops with fewer
    than MIN_MASK_PIXELS foreground pixels are dropped; the rest are saved
    together with N_AUG_PER_INSTANCE augmented copies. A metadata list
    describing every saved pair is written to OUT_DIR/meta.json.

    Raises:
        AssertionError: if no slide/mask pair is found in KPI_DIR.
    """
    pairs = paired_paths_from_kpi_dir(KPI_DIR)
    assert pairs, f"No KPI (*_wsi.tiff + *_mask.tiff) pairs found in: {KPI_DIR}"

    meta = []
    idx = 0  # running index over all saved pairs, embedded in the file names

    for wsi_path, mask_path in pairs:
        base = os.path.basename(wsi_path).replace("_wsi.tiff", "")
        print(f"\n=== {base} ===")

        # try/finally so both slide handles are released even if reading,
        # augmentation or writing raises part-way through a slide.
        wsi = openslide.OpenSlide(wsi_path)
        try:
            msk_slide = openslide.OpenSlide(mask_path)
            try:
                idx = _extract_slide_patches(
                    wsi, msk_slide, base, wsi_path, mask_path, meta, idx
                )
            finally:
                msk_slide.close()
        finally:
            wsi.close()

    # write metadata
    meta_path = os.path.join(OUT_DIR, "meta.json")
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=2)
    print(f"\nSaved {len(meta)} patches (+augs) to: {OUT_DIR}")
    print(f"Metadata: {meta_path}")


if __name__ == "__main__":
    extract_kpi_patches()


In [4]:
import os, json, glob
import numpy as np
import cv2
import openslide

# ---------- helpers ----------
def eq_diameter_from_area(area_px):
    """Return, as a Python float, the diameter of the circle whose area is `area_px`."""
    # A = pi * d^2 / 4  =>  d^2 = 4 * A / pi
    diameter_squared = 4.0 * area_px / np.pi
    return float(np.sqrt(diameter_squared))

def kpi_glom_diams_px(mask_path, find_level=None, min_area_lvl=50):
    """
    Return the equivalent-circle diameters (level-0 pixels) of the mask components.

    The mask slide at `mask_path` is read in full at `find_level` (the last
    pyramid level when None), its first channel is binarized (> 0), and
    8-connected components are labelled. Components smaller than
    `min_area_lvl` pixels at that level are ignored. Each remaining area is
    scaled to level 0 by the squared downsample factor and converted to the
    diameter of a circle with the same area.
    """
    ms = openslide.OpenSlide(mask_path)
    try:
        lvl = ms.level_count - 1 if find_level is None else int(find_level)
        Wl, Hl = ms.level_dimensions[lvl]
        down = float(ms.level_downsamples[lvl])
        m = np.array(ms.read_region((0, 0), lvl, (Wl, Hl)).convert("RGB"), dtype=np.uint8)[..., 0]
    finally:
        # release the slide handle even when reading fails; the pixel data
        # has already been copied into `m` on success
        ms.close()

    mb = (m > 0).astype(np.uint8)
    num, _lab, stats, _ = cv2.connectedComponentsWithStats(mb, connectivity=8)

    # area at level-0 scales by down^2
    area_scale = down * down

    diams = []
    for i in range(1, num):  # label 0 is the background
        area_lvl = int(stats[i, cv2.CC_STAT_AREA])
        if area_lvl < min_area_lvl:
            continue
        diams.append(eq_diameter_from_area(area_lvl * area_scale))
    return diams

def ndpi_glom_diams_px_from_geojson(geojson_path):
    """
    Return equivalent-circle diameters for the features of a GeoJSON file.

    For every entry under "features", the first element of
    geometry.coordinates is treated as a polygon ring; its area is measured
    with cv2.contourArea and converted to an equivalent diameter. Features
    with non-positive area are skipped.

    NOTE(review): only coordinates[0] is used, so any further rings (holes)
    are ignored -- confirm that is intended for the input data.
    """
    with open(geojson_path, "r") as f:
        features = json.load(f).get("features", [])

    diams = []
    for feat in features:
        ring = np.array(feat["geometry"]["coordinates"][0], dtype=np.float32)
        # contourArea expects an (N, 1, 2) float32 contour
        area = float(cv2.contourArea(ring.reshape(-1, 1, 2)))
        if area <= 0:
            continue
        diams.append(eq_diameter_from_area(area))
    return diams

# ---------- compute scale ----------
# Inputs: folder with the KPI mask slides, and the GeoJSON with the NDPI glomerulus polygons.
KPI_DIR = "/Users/edmundtsou/Desktop/JEFworks/jefworks-structure_segmentation/data/KPI_challenge_wsi_data"
NDPI_GEOJSON = "/Users/edmundtsou/Desktop/JEFworks/jefworks-structure_segmentation/final_glomeruli_segmentation_pipeline/out/sam_all_features_merged.geojson"

# Pool the equivalent diameters of all mask components over every *_mask.tiff slide.
kpi_masks = sorted(glob.glob(os.path.join(KPI_DIR, "*_mask.tiff")))
kpi_d = []
for mp in kpi_masks:
    kpi_d.extend(kpi_glom_diams_px(mp, find_level=None))

# Equivalent diameters of the polygons in the GeoJSON file.
ndpi_d = ndpi_glom_diams_px_from_geojson(NDPI_GEOJSON)

# Medians of the two diameter distributions.
# NOTE(review): np.median of an empty list returns nan, so an empty input makes `scale` nan.
kpi_med = np.median(kpi_d)
ndpi_med = np.median(ndpi_d)

# Ratio of the two medians: multiply KPI pixel sizes by this to match the NDPI glomerulus size.
scale = ndpi_med / kpi_med
print("KPI median diam(px):", kpi_med)
print("NDPI median diam(px):", ndpi_med)
print("Suggested KPI -> NDPI scale factor:", scale)


KPI median diam(px): 472.6992991577287
NDPI median diam(px): 293.4284882788751
Suggested KPI -> NDPI scale factor: 0.6207508426640693


In [1]:
import os, glob, json, random
import numpy as np
import cv2
import openslide
from tqdm import tqdm

# =========================
# CONFIG
# =========================
# Input folder with *_wsi.tiff / *_mask.tiff pairs, and the output folder for the patch dataset.
KPI_DIR = "/Users/edmundtsou/Desktop/JEFworks/jefworks-structure_segmentation/data/KPI_challenge_wsi_data"
OUTDIR  = "/Users/edmundtsou/Desktop/JEFworks/jefworks-structure_segmentation/data/kpi_glom_patches_for_sam_adaptor"

PATCH_SIZE = 1024                 # final saved size
N_AUG_PER_GLOM = 7                # how many augmented variants per glomerulus (in addition to base)
SCALE_MATCH = 0.6207508426640693  # KPI -> NDPI scale factor you computed

# Crop sampling
PAD_FRAC = 0.35                   # extra context around bbox (as fraction of max(bw,bh))
JITTER_FRAC = 0.25                # random center jitter relative to crop size
MIN_AREA0 = 2000                  # min glomerulus area (px at level0) to keep

# Mask reading level (use a downsampled level for speed)
MASK_LEVEL = None                 # None => use coarsest level

# =========================
# IO helpers
# =========================
def ensure_dir(p):
    """Create directory `p` (including missing parents); do nothing if it already exists."""
    os.makedirs(p, exist_ok=True)

# Output sub-folders for the RGB patches and their masks, created immediately.
img_dir = os.path.join(OUTDIR, "images")
msk_dir = os.path.join(OUTDIR, "masks")
ensure_dir(img_dir); ensure_dir(msk_dir)

# =========================
# Core utilities
# =========================
def read_slide_region_rgb(slide, x, y, w, h, level=0):
    """
    Read a w x h region from a slide at `level` and return it as an RGB uint8 array.

    All numeric arguments are truncated to ints before the read; the returned
    image is converted to "RGB" before being wrapped as a NumPy array.
    """
    origin = (int(x), int(y))
    lvl = int(level)
    extent = (int(w), int(h))
    region = slide.read_region(origin, lvl, extent)
    return np.asarray(region.convert("RGB"), dtype=np.uint8)

def read_mask_level(mask_slide, level=None):
    """
    Read the entire mask at one pyramid level.

    When `level` is None the last level (level_count - 1) is used. Returns
    (mask, level): the first channel of the RGB-converted region as a uint8
    array, and the level that was actually read.
    """
    lvl = mask_slide.level_count - 1 if level is None else level
    width, height = mask_slide.level_dimensions[lvl]
    region = mask_slide.read_region((0, 0), lvl, (width, height)).convert("RGB")
    first_channel = np.asarray(region, dtype=np.uint8)[..., 0]
    return first_channel, lvl

def connected_components_from_mask_level0(mask_slide, min_area0=2000, level=None):
    """
    Label mask components at a chosen pyramid level and report them in level-0 units.

    The mask is read via read_mask_level, binarized (> 0) and split into
    8-connected components. Each component's area is scaled to level 0 by the
    squared downsample factor; components below `min_area0` are dropped.

    Returns a list of dicts with keys "bbox0" ((x0, y0, x1, y1), rounded
    level-0 ints) and "area0" (float).
    """
    mL, lvl = read_mask_level(mask_slide, level=level)
    down = float(mask_slide.level_downsamples[lvl])

    fg = (mL > 0).astype(np.uint8)
    n_labels, _labels, stats, _centroids = cv2.connectedComponentsWithStats(fg, connectivity=8)

    def to_level0(v):
        # scale a coordinate of the read level to level-0 pixels
        return int(round(v * down))

    area_scale = down * down
    comps = []
    # row 0 of `stats` is the background label
    for x, y, w, h, areaL in stats[1:n_labels]:
        if areaL <= 0:
            continue
        area0 = areaL * area_scale
        if area0 < min_area0:
            continue
        bbox0 = (to_level0(x), to_level0(y), to_level0(x + w), to_level0(y + h))
        comps.append({"bbox0": bbox0, "area0": float(area0)})
    return comps

def crop_box_with_jitter(bbox0, slide_w, slide_h, patch_size, pad_frac=0.35, jitter_frac=0.25):
    """
    Choose a square crop around a bbox, with padding and a random centre shift.

    The side is the bbox's longer edge grown by `pad_frac`, but never less
    than int(patch_size / SCALE_MATCH), so that downscaling by SCALE_MATCH
    still leaves roughly `patch_size` pixels. The centre is shifted
    independently in x and y by up to `jitter_frac` * side, and the top-left
    corner is clamped against the slide extent (floored at 0).

    Returns (x, y, side, side) in the same coordinate frame as `bbox0`.
    """
    left, top, right, bottom = bbox0
    longest_edge = max(1, right - left, bottom - top)

    side = max(
        int(round(longest_edge * (1.0 + pad_frac))),
        int(patch_size / SCALE_MATCH),
    )

    # draw the x shift first, then the y shift
    max_shift = jitter_frac * side
    center_x = (left + right) / 2.0 + random.uniform(-max_shift, max_shift)
    center_y = (top + bottom) / 2.0 + random.uniform(-max_shift, max_shift)

    half = side / 2
    crop_x = int(round(center_x - half))
    crop_y = int(round(center_y - half))

    # keep the crop inside the slide where possible
    crop_x = max(0, min(crop_x, slide_w - side))
    crop_y = max(0, min(crop_y, slide_h - side))

    return crop_x, crop_y, side, side

def rescale_to_match(img_rgb_u8, mask_u8, scale, out_size=1024):
    """
    Resize an image/mask pair by `scale`, then centre crop or zero-pad to `out_size`.

    The image is resized with bilinear interpolation and the mask with
    nearest-neighbour so it stays binary. Each axis is handled on its own: an
    axis larger than `out_size` is centre-cropped, a smaller one is padded
    with zeros on both sides. Non-square inputs are therefore supported.

    Returns (image, mask): an out_size x out_size RGB array and a uint8 mask
    holding only 0/1.
    """
    H, W = img_rgb_u8.shape[:2]
    newW = max(2, int(round(W * scale)))
    newH = max(2, int(round(H * scale)))

    img2 = cv2.resize(img_rgb_u8, (newW, newH), interpolation=cv2.INTER_LINEAR)
    m2   = cv2.resize(mask_u8.astype(np.uint8), (newW, newH), interpolation=cv2.INTER_NEAREST)

    def center_crop_or_pad(arr, out_size, is_mask=False):
        h, w = arr.shape[:2]

        # Crop whichever axes are too large. The offset is 0 for an axis that
        # is already small enough, so that axis passes through untouched.
        y0 = max(0, (h - out_size) // 2)
        x0 = max(0, (w - out_size) // 2)
        arr = arr[y0:y0 + out_size, x0:x0 + out_size]

        h, w = arr.shape[:2]
        if h == out_size and w == out_size:
            return arr

        # Pad the remaining axes. After the crop neither axis exceeds
        # out_size, so no border width can be negative.
        top = (out_size - h) // 2
        bottom = out_size - h - top
        left = (out_size - w) // 2
        right = out_size - w - left
        val = 0 if is_mask else (0, 0, 0)
        return cv2.copyMakeBorder(np.ascontiguousarray(arr), top, bottom, left, right,
                                  cv2.BORDER_CONSTANT, value=val)

    img3 = center_crop_or_pad(img2, out_size, is_mask=False)
    m3   = center_crop_or_pad(m2, out_size, is_mask=True)
    m3   = (m3 > 0).astype(np.uint8)
    return img3, m3

# =========================
# Augmentations (stain-robust, PAS-friendly)
# =========================
def random_color_aug(rgb):
    """
    Randomly perturb the colours of an RGB image and return a uint8 result.

    Stages, each gated by its own `random.random()` draw, in order:
    contrast/brightness, gamma, HSV jitter, per-channel gain, Gaussian blur
    (3x3 or 5x5), additive Gaussian noise. Intermediate values are kept in
    float32 and clipped to [0, 255] wherever a stage needs 8-bit input and
    once more at the end.
    """
    def as_u8(arr):
        # saturate to the displayable range before an 8-bit-only operation
        return np.clip(arr, 0, 255).astype(np.uint8)

    pix = rgb.astype(np.float32)

    # contrast gain + brightness offset
    if random.random() < 0.9:
        gain = random.uniform(0.75, 1.35)
        offset = random.uniform(-25, 25)
        pix = pix * gain + offset

    # gamma curve on the [0, 1]-clipped image
    if random.random() < 0.7:
        exponent = random.uniform(0.75, 1.45)
        unit = np.clip(pix / 255.0, 0, 1)
        pix = 255.0 * np.power(unit, exponent)

    # HSV jitter: hue shift wrapped modulo 180, then saturation / value gains
    if random.random() < 0.8:
        hsv = cv2.cvtColor(as_u8(pix), cv2.COLOR_RGB2HSV).astype(np.float32)
        hue_shift = random.uniform(-12, 12)
        hsv[..., 0] = (hsv[..., 0] + hue_shift) % 180
        sat_gain = random.uniform(0.70, 1.40)
        hsv[..., 1] *= sat_gain
        val_gain = random.uniform(0.80, 1.30)
        hsv[..., 2] *= val_gain
        pix = cv2.cvtColor(as_u8(hsv), cv2.COLOR_HSV2RGB).astype(np.float32)

    # independent gain per colour channel
    if random.random() < 0.8:
        r_gain = random.uniform(0.8, 1.25)
        g_gain = random.uniform(0.8, 1.25)
        b_gain = random.uniform(0.8, 1.25)
        channel_gain = np.array([r_gain, g_gain, b_gain], dtype=np.float32)
        pix = pix * channel_gain[None, None, :]

    # slight blur
    if random.random() < 0.35:
        ksize = random.choice([3, 5])
        pix = cv2.GaussianBlur(as_u8(pix), (ksize, ksize), 0).astype(np.float32)

    # additive noise (sigma is drawn before the noise field)
    if random.random() < 0.35:
        sigma = random.uniform(2, 8)
        pix = pix + np.random.normal(0, sigma, size=pix.shape).astype(np.float32)

    return as_u8(pix)

def random_geom_aug(rgb, mask):
    """
    Apply one random geometric transform to an image and its mask together.

    Draw order from `random`: horizontal-flip coin, vertical-flip coin,
    number of quarter turns (0..3), zoom coin, zoom factor. With probability
    0.8 the pair is resized by a factor in [0.85, 1.20] and brought back to
    the pre-zoom height via rescale_to_match (centre crop or pad).
    """
    # mirror left-right
    if random.random() < 0.5:
        rgb, mask = rgb[:, ::-1].copy(), mask[:, ::-1].copy()

    # mirror up-down
    if random.random() < 0.5:
        rgb, mask = rgb[::-1].copy(), mask[::-1].copy()

    # rotate by 0/90/180/270 degrees
    quarter_turns = random.randint(0, 3)
    if quarter_turns:
        rgb = np.rot90(rgb, quarter_turns).copy()
        mask = np.rot90(mask, quarter_turns).copy()

    # mild random zoom for multi-scale robustness
    if random.random() >= 0.8:
        return rgb, mask

    zoom = random.uniform(0.85, 1.20)
    H, W = rgb.shape[:2]
    target = (max(2, int(round(W * zoom))), max(2, int(round(H * zoom))))
    zoomed_rgb = cv2.resize(rgb, target, interpolation=cv2.INTER_LINEAR)
    zoomed_mask = cv2.resize(mask, target, interpolation=cv2.INTER_NEAREST)

    # scale=1.0: no further resizing, only the centre crop/pad back to H
    return rescale_to_match(zoomed_rgb, zoomed_mask, scale=1.0, out_size=H)

# =========================
# Main extraction
# =========================
def _write_patch_pair(meta_f, uid, img, m, common, aug_tag):
    """Save one RGB/mask PNG pair under img_dir/msk_dir and append its JSON line to `meta_f`."""
    img_path = os.path.join(img_dir, uid + ".png")
    msk_path = os.path.join(msk_dir, uid + ".png")

    # cv2 writes BGR, the patches are RGB; the 0/1 mask is stored as 0/255.
    cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
    cv2.imwrite(msk_path, (m * 255).astype(np.uint8))

    meta = {"id": uid}
    meta.update(common)
    meta["aug"] = aug_tag
    meta_f.write(json.dumps(meta) + "\n")


def _extract_glomeruli_from_slide(slide, mslide, base, wsi_path, mask_path, meta_f):
    """Extract, scale-match, augment and save every glomerulus crop of one slide; return the number of pairs written."""
    written = 0

    sw, sh = slide.level_dimensions[0]
    comps = connected_components_from_mask_level0(mslide, min_area0=MIN_AREA0, level=MASK_LEVEL)

    print(f"{base}: {len(comps)} glomeruli comps")

    for gi, comp in enumerate(comps):
        bbox0 = comp["bbox0"]

        # crop a big square at level0 (pre-scale)
        xs, ys, ww, hh = crop_box_with_jitter(
            bbox0, sw, sh,
            patch_size=PATCH_SIZE,
            pad_frac=PAD_FRAC,
            jitter_frac=JITTER_FRAC,
        )

        img0 = read_slide_region_rgb(slide, xs, ys, ww, hh, level=0)

        # Mask patch, read from the mask slide at level 0 with the same box.
        # NOTE(review): this assumes the mask slide shares the WSI's level-0
        # geometry; no resizing is done if the two differ.
        m0 = np.asarray(mslide.read_region((xs, ys), 0, (ww, hh)).convert("RGB"), dtype=np.uint8)[..., 0]
        m0 = (m0 > 0).astype(np.uint8)

        # scale-match KPI -> NDPI zoom and bring to PATCH_SIZE
        img, m = rescale_to_match(img0, m0, scale=SCALE_MATCH, out_size=PATCH_SIZE)

        # fields shared by the base crop and all of its augmentations
        common = {
            "wsi": os.path.basename(wsi_path),
            "mask": os.path.basename(mask_path),
            "bbox0": [int(v) for v in bbox0],
            "crop0_xywh": [int(xs), int(ys), int(ww), int(hh)],
            "scale_match": float(SCALE_MATCH),
        }

        # base save
        _write_patch_pair(meta_f, f"{base}_g{gi:05d}_a000", img, m, common, "none")
        written += 1

        # augmented variants
        for ai in range(1, N_AUG_PER_GLOM + 1):
            aug_img, aug_m = random_geom_aug(img.copy(), m.copy())
            aug_img = random_color_aug(aug_img)
            _write_patch_pair(meta_f, f"{base}_g{gi:05d}_a{ai:03d}", aug_img, aug_m, common, "geom+color")
            written += 1

    return written


def extract_kpi_glomeruli_patches():
    """
    Build the scale-matched glomerulus patch dataset from the slides in KPI_DIR.

    For every `<base>_wsi.tiff` with a matching `<base>_mask.tiff`, each mask
    component yields one jittered square crop that is rescaled by SCALE_MATCH
    and cropped/padded to PATCH_SIZE, plus N_AUG_PER_GLOM augmented variants.
    Images and masks are written as PNGs under OUTDIR, with one JSON line per
    saved pair in OUTDIR/meta.jsonl. Slides without a mask are reported and
    skipped.

    Raises:
        AssertionError: if KPI_DIR contains no *_wsi.tiff file.
    """
    wsi_paths = sorted(glob.glob(os.path.join(KPI_DIR, "*_wsi.tiff")))
    assert wsi_paths, f"No *_wsi.tiff found in {KPI_DIR}"

    meta_path = os.path.join(OUTDIR, "meta.jsonl")
    out_count = 0

    # `with` closes (and flushes) the metadata file even if a slide fails.
    with open(meta_path, "w") as meta_f:
        for wsi_path in wsi_paths:
            base = os.path.basename(wsi_path).replace("_wsi.tiff", "")
            mask_path = os.path.join(KPI_DIR, f"{base}_mask.tiff")
            if not os.path.exists(mask_path):
                print("Missing mask for", wsi_path)
                continue

            # try/finally so both slide handles are released on any error.
            slide = openslide.OpenSlide(wsi_path)
            try:
                mslide = openslide.OpenSlide(mask_path)
                try:
                    out_count += _extract_glomeruli_from_slide(
                        slide, mslide, base, wsi_path, mask_path, meta_f
                    )
                finally:
                    mslide.close()
            finally:
                slide.close()

    print("\nDone. Wrote:", out_count, "pairs to", OUTDIR)
    print("Images:", img_dir)
    print("Masks :", msk_dir)
    print("Meta  :", meta_path)

# Run
if __name__ == "__main__":
    extract_kpi_glomeruli_patches()


normal_F1576: 386 glomeruli comps
normal_F1: 357 glomeruli comps
normal_F2: 270 glomeruli comps
normal_F3: 331 glomeruli comps
normal_F4: 444 glomeruli comps

Done. Wrote: 14304 pairs to /Users/edmundtsou/Desktop/JEFworks/jefworks-structure_segmentation/data/kpi_glom_patches_for_sam_adaptor
Images: /Users/edmundtsou/Desktop/JEFworks/jefworks-structure_segmentation/data/kpi_glom_patches_for_sam_adaptor/images
Masks : /Users/edmundtsou/Desktop/JEFworks/jefworks-structure_segmentation/data/kpi_glom_patches_for_sam_adaptor/masks
Meta  : /Users/edmundtsou/Desktop/JEFworks/jefworks-structure_segmentation/data/kpi_glom_patches_for_sam_adaptor/meta.jsonl
