In [3]:
import os
import glob
import random
from pathlib import Path
from PIL import Image
import numpy as np
from sklearn.model_selection import train_test_split

# ---------------- CONFIG ----------------
DATA_DIR = "Normal_masks"   # your folder with *_img.png and *_mask.png
OUT_DIR = "dataset_splits (normal)"
IMG_SIZE = (256, 256)       # (H, W)
SEED = 42
VAL_SPLIT = 0.15            # fraction of *patients* (not slices) for validation
TEST_SPLIT = 0.15           # fraction of *patients* (not slices) for testing
# ----------------------------------------
# NOTE(review): the next cell repeats this entire pipeline with only
# DATA_DIR / OUT_DIR changed; consider wrapping it in one function
# parameterised by those two paths.

# Seed the global RNGs; the patient split itself is seeded via random_state=SEED.
random.seed(SEED)
np.random.seed(SEED)
os.makedirs(OUT_DIR, exist_ok=True)

# --- 1. Collect and check pairs ---
img_files = sorted(glob.glob(os.path.join(DATA_DIR, "*_img.png")))
mask_files = sorted(glob.glob(os.path.join(DATA_DIR, "*_mask.png")))

# Strip only the trailing "_img" / "_mask" marker (the glob patterns guarantee
# it is present). str.replace would also delete the marker if it happened to
# occur elsewhere in the file name, producing a wrong base name.
img_basenames = {Path(f).stem[: -len("_img")] for f in img_files}
mask_basenames = {Path(f).stem[: -len("_mask")] for f in mask_files}

missing_masks = img_basenames - mask_basenames
missing_imgs = mask_basenames - img_basenames

print("=== Dataset Check ===")
print(f"Total images found: {len(img_files)}")
print(f"Total masks found: {len(mask_files)}")

# Sorted so the report is stable between runs.
if missing_masks:
    print(f"⚠️ Missing masks for: {sorted(missing_masks)}")
if missing_imgs:
    print(f"⚠️ Missing images for: {sorted(missing_imgs)}")
if not missing_masks and not missing_imgs:
    print("✅ Every image has a corresponding mask.")

# --- 2. Build valid pairs list ---
# Iterate in sorted image order: patients are later listed in first-seen order
# from `pairs`, and the seeded split depends on that order.
pairs = []
for img in img_files:
    base = Path(img).stem[: -len("_img")]
    mask = os.path.join(DATA_DIR, base + "_mask.png")
    if os.path.exists(mask):
        pairs.append((img, mask))

print(f"Total valid pairs: {len(pairs)}")
if not pairs:
    # Fail here with an actionable message instead of later in the split step.
    raise FileNotFoundError(
        f"No *_img.png / *_mask.png pairs found in {DATA_DIR!r}"
    )

# --- 3. Group by patient ID ---
def get_patient_id(path):
    """Return the patient identifier encoded in a slice file name.

    The file stem is split on underscores and its last two fields are
    discarded. For an image file these are the slice token and the trailing
    ``img`` marker, e.g. ``"P01_007_img.png" -> "P01"``; for a mask file the
    ``mask`` marker is dropped the same way.

    NOTE(review): stems with fewer than three underscore-separated fields
    yield ``""``, so all such files would be grouped as a single "patient".
    """
    fields = Path(path).stem.split("_")
    kept = fields[: len(fields) - 2]
    return "_".join(kept)

# Bucket every (image, mask) pair under its patient so that all slices of a
# patient land in the same split (prevents patient-level leakage).
patient_to_pairs = {}
for img, mask in pairs:
    pid = get_patient_id(img)
    if pid not in patient_to_pairs:
        patient_to_pairs[pid] = []
    patient_to_pairs[pid].append((img, mask))

# dicts keep insertion order, so patients are listed in first-seen order.
patients = [*patient_to_pairs]
print(f"Unique patients: {len(patients)}")

# --- 4. Train/Val/Test split ---
# Hold out the test patients first, then split validation off the remainder.
# VAL_SPLIT is rescaled by 1/(1-TEST_SPLIT) so that it is (approximately) a
# fraction of ALL patients rather than of the remainder.
train_pat, test_pat = train_test_split(
    patients, test_size=TEST_SPLIT, random_state=SEED
)
train_pat, val_pat = train_test_split(
    train_pat, test_size=VAL_SPLIT / (1 - TEST_SPLIT), random_state=SEED
)

splits = dict(train=train_pat, val=val_pat, test=test_pat)
for s in splits:
    for subdir in ("images", "masks"):
        os.makedirs(os.path.join(OUT_DIR, s, subdir), exist_ok=True)

print(f"Split: Train={len(train_pat)}, Val={len(val_pat)}, Test={len(test_pat)} patients")

# --- 5. Preprocess and save ---
def preprocess_and_save(img_path, mask_path, split, save_vis=True):
    """Resize, intensity-stretch and relabel one image/mask pair, then save it.

    The image is converted to 8-bit grayscale, resized to ``IMG_SIZE`` with
    bilinear interpolation and min-max stretched to the full 0..255 range.
    The mask is converted to grayscale, resized with nearest-neighbour
    interpolation (so no new grey levels are invented) and its raw grey levels
    {63, 126, 189, 252} are mapped to class indices {0, 1, 2, 3}.

    Files written under ``OUT_DIR/<split>/``:
      * ``images/<base>.png``     - stretched grayscale image
      * ``masks/<base>.png``      - class-index mask (values 0..3)
      * ``masks/<base>_vis.png``  - RGB colour rendering (only if ``save_vis``)
    where ``<base>`` is the image file stem without its trailing ``_img``.

    NOTE(review): the ``_vis.png`` files share the ``masks/`` folder with the
    label masks; loaders that glob ``masks/*.png`` must exclude them (or call
    with ``save_vis=False``).

    Args:
        img_path: Path to the ``*_img.png`` file.
        mask_path: Path to the matching ``*_mask.png`` file.
        split: Split sub-directory name ("train", "val" or "test").
        save_vis: Also write the colour visualisation mask (default True,
            matching the previous behaviour).

    Returns:
        True if the pair was written; False if it was skipped because the
        mapped mask contains at most one distinct class.
    """
    # Raw mask grey level -> class index. Grey levels not listed here
    # (including 0) fall through to class 0.
    # NOTE(review): unexpected grey levels are silently treated as class 0;
    # confirm the masks only contain these four values.
    mapping = {63: 0, 126: 1, 189: 2, 252: 3}
    colors = {
        0: (0, 0, 0),         # black (background)
        1: (255, 0, 0),       # red
        2: (0, 255, 0),       # green
        3: (0, 0, 255),       # blue
    }
    size_wh = (IMG_SIZE[1], IMG_SIZE[0])  # PIL wants (W, H); IMG_SIZE is (H, W)

    # --- Load + resize (bilinear for image, nearest for mask) ---
    # Context managers close the source files once the pixels are decoded.
    with Image.open(img_path) as src:
        img = src.convert("L").resize(size_wh, Image.BILINEAR)
    with Image.open(mask_path) as src:
        mask = src.convert("L").resize(size_wh, Image.NEAREST)

    # --- Min-max stretch: scale to [0, 1], then back to 0..255 uint8 for PNG ---
    # The 1e-8 avoids division by zero; a constant image becomes all-black.
    img_np = np.asarray(img, dtype=np.float32)
    img_norm = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
    img_out = Image.fromarray((img_norm * 255).astype(np.uint8))

    # --- Map raw mask grey levels to class indices via a 256-entry lookup table ---
    lut = np.zeros(256, dtype=np.uint8)
    for raw_val, class_idx in mapping.items():
        lut[raw_val] = class_idx
    mask_mapped = lut[np.asarray(mask, dtype=np.uint8)]

    # --- Sanity check: skip masks that collapse to a single class ---
    unique_vals = np.unique(mask_mapped)
    if unique_vals.size <= 1:
        print(f"⚠️ Skipped empty/invalid mask {mask_path}, unique values: {unique_vals}")
        return False

    mask_out = Image.fromarray(mask_mapped)

    # --- Save stretched image + numeric mask ---
    stem = Path(img_path).stem
    # Remove only the trailing "_img" marker, not every occurrence of it.
    base = stem[: -len("_img")] if stem.endswith("_img") else stem
    img_out.save(os.path.join(OUT_DIR, split, "images", base + ".png"))
    mask_out.save(os.path.join(OUT_DIR, split, "masks", base + ".png"))

    # --- Optional: colour visualisation of the mask ---
    if save_vis:
        # Row k of the palette is the RGB colour of class k; indexing with the
        # class mask (values 0..3 by construction of the LUT) yields (H, W, 3).
        palette = np.array([colors[c] for c in range(len(colors))], dtype=np.uint8)
        mask_vis_out = Image.fromarray(palette[mask_mapped])
        mask_vis_out.save(os.path.join(OUT_DIR, split, "masks", base + "_vis.png"))

    return True

# --- 6. Process every pair, split by split ---
# preprocess_and_save already prints a per-file warning when it skips a mask,
# so here we only tally outcomes instead of printing a second, duplicate warning.
for split, plist in splits.items():
    print(f"Processing {split}...")
    n_saved = n_skipped = 0
    for pid in plist:
        for img_path, mask_path in patient_to_pairs[pid]:
            if preprocess_and_save(img_path, mask_path, split):
                n_saved += 1
            else:
                n_skipped += 1
    print(f"  {split}: saved {n_saved} pairs, skipped {n_skipped}")

print("✅ Dataset preparation complete. Saved in:", OUT_DIR)

=== Dataset Check ===
Total images found: 8840
Total masks found: 8840
✅ Every image has a corresponding mask.
Total valid pairs: 8840
Unique patients: 115
Split: Train=79, Val=18, Test=18 patients
Processing train...
Processing val...
Processing test...
✅ Dataset preparation complete. Saved in: dataset_splits (normal)


In [4]:
import os
import glob
import random
from pathlib import Path
from PIL import Image
import numpy as np
from sklearn.model_selection import train_test_split

# ---------------- CONFIG ----------------
DATA_DIR = "AMD_masks"   # your folder with *_img.png and *_mask.png
OUT_DIR = "dataset_splits (AMD)"
IMG_SIZE = (256, 256)       # (H, W)
SEED = 42
VAL_SPLIT = 0.15
TEST_SPLIT = 0.15
# ----------------------------------------

random.seed(SEED)
np.random.seed(SEED)
os.makedirs(OUT_DIR, exist_ok=True)

# --- 1. Collect and check pairs ---
img_files = sorted(glob.glob(os.path.join(DATA_DIR, "*_img.png")))
mask_files = sorted(glob.glob(os.path.join(DATA_DIR, "*_mask.png")))

img_basenames = {Path(f).stem.replace("_img", "") for f in img_files}
mask_basenames = {Path(f).stem.replace("_mask", "") for f in mask_files}

missing_masks = img_basenames - mask_basenames
missing_imgs = mask_basenames - img_basenames

print("=== Dataset Check ===")
print(f"Total images found: {len(img_files)}")
print(f"Total masks found: {len(mask_files)}")

if missing_masks:
    print(f"⚠️ Missing masks for: {missing_masks}")
if missing_imgs:
    print(f"⚠️ Missing images for: {missing_imgs}")
if not missing_masks and not missing_imgs:
    print("✅ Every image has a corresponding mask.")

# --- 2. Build valid pairs list ---
pairs = []
for img in img_files:
    base = Path(img).stem.replace("_img", "")
    mask = os.path.join(DATA_DIR, base + "_mask.png")
    if os.path.exists(mask):
        pairs.append((img, mask))

print(f"Total valid pairs: {len(pairs)}")

# --- 3. Group by patient ID ---
def get_patient_id(path):
    """Return the patient identifier encoded in a slice file name.

    The file stem is split on underscores and its last two fields are
    discarded. For an image file these are the slice token and the trailing
    ``img`` marker, e.g. ``"P01_007_img.png" -> "P01"``; for a mask file the
    ``mask`` marker is dropped the same way.

    NOTE(review): stems with fewer than three underscore-separated fields
    yield ``""``, so all such files would be grouped as a single "patient".
    """
    fields = Path(path).stem.split("_")
    kept = fields[: len(fields) - 2]
    return "_".join(kept)

# Bucket every (image, mask) pair under its patient so that all slices of a
# patient land in the same split (prevents patient-level leakage).
patient_to_pairs = {}
for img, mask in pairs:
    pid = get_patient_id(img)
    if pid not in patient_to_pairs:
        patient_to_pairs[pid] = []
    patient_to_pairs[pid].append((img, mask))

# dicts keep insertion order, so patients are listed in first-seen order.
patients = [*patient_to_pairs]
print(f"Unique patients: {len(patients)}")

# --- 4. Train/Val/Test split ---
# Hold out the test patients first, then split validation off the remainder.
# VAL_SPLIT is rescaled by 1/(1-TEST_SPLIT) so that it is (approximately) a
# fraction of ALL patients rather than of the remainder.
train_pat, test_pat = train_test_split(
    patients, test_size=TEST_SPLIT, random_state=SEED
)
train_pat, val_pat = train_test_split(
    train_pat, test_size=VAL_SPLIT / (1 - TEST_SPLIT), random_state=SEED
)

splits = dict(train=train_pat, val=val_pat, test=test_pat)
for s in splits:
    for subdir in ("images", "masks"):
        os.makedirs(os.path.join(OUT_DIR, s, subdir), exist_ok=True)

print(f"Split: Train={len(train_pat)}, Val={len(val_pat)}, Test={len(test_pat)} patients")

# --- 5. Preprocess and save ---
def preprocess_and_save(img_path, mask_path, split, save_vis=True):
    """Resize, intensity-stretch and relabel one image/mask pair, then save it.

    The image is converted to 8-bit grayscale, resized to ``IMG_SIZE`` with
    bilinear interpolation and min-max stretched to the full 0..255 range.
    The mask is converted to grayscale, resized with nearest-neighbour
    interpolation (so no new grey levels are invented) and its raw grey levels
    {63, 126, 189, 252} are mapped to class indices {0, 1, 2, 3}.

    Files written under ``OUT_DIR/<split>/``:
      * ``images/<base>.png``     - stretched grayscale image
      * ``masks/<base>.png``      - class-index mask (values 0..3)
      * ``masks/<base>_vis.png``  - RGB colour rendering (only if ``save_vis``)
    where ``<base>`` is the image file stem without its trailing ``_img``.

    NOTE(review): the ``_vis.png`` files share the ``masks/`` folder with the
    label masks; loaders that glob ``masks/*.png`` must exclude them (or call
    with ``save_vis=False``).

    Args:
        img_path: Path to the ``*_img.png`` file.
        mask_path: Path to the matching ``*_mask.png`` file.
        split: Split sub-directory name ("train", "val" or "test").
        save_vis: Also write the colour visualisation mask (default True,
            matching the previous behaviour).

    Returns:
        True if the pair was written; False if it was skipped because the
        mapped mask contains at most one distinct class.
    """
    # Raw mask grey level -> class index. Grey levels not listed here
    # (including 0) fall through to class 0.
    # NOTE(review): unexpected grey levels are silently treated as class 0;
    # confirm the masks only contain these four values.
    mapping = {63: 0, 126: 1, 189: 2, 252: 3}
    colors = {
        0: (0, 0, 0),         # black (background)
        1: (255, 0, 0),       # red
        2: (0, 255, 0),       # green
        3: (0, 0, 255),       # blue
    }
    size_wh = (IMG_SIZE[1], IMG_SIZE[0])  # PIL wants (W, H); IMG_SIZE is (H, W)

    # --- Load + resize (bilinear for image, nearest for mask) ---
    # Context managers close the source files once the pixels are decoded.
    with Image.open(img_path) as src:
        img = src.convert("L").resize(size_wh, Image.BILINEAR)
    with Image.open(mask_path) as src:
        mask = src.convert("L").resize(size_wh, Image.NEAREST)

    # --- Min-max stretch: scale to [0, 1], then back to 0..255 uint8 for PNG ---
    # The 1e-8 avoids division by zero; a constant image becomes all-black.
    img_np = np.asarray(img, dtype=np.float32)
    img_norm = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
    img_out = Image.fromarray((img_norm * 255).astype(np.uint8))

    # --- Map raw mask grey levels to class indices via a 256-entry lookup table ---
    lut = np.zeros(256, dtype=np.uint8)
    for raw_val, class_idx in mapping.items():
        lut[raw_val] = class_idx
    mask_mapped = lut[np.asarray(mask, dtype=np.uint8)]

    # --- Sanity check: skip masks that collapse to a single class ---
    unique_vals = np.unique(mask_mapped)
    if unique_vals.size <= 1:
        print(f"⚠️ Skipped empty/invalid mask {mask_path}, unique values: {unique_vals}")
        return False

    mask_out = Image.fromarray(mask_mapped)

    # --- Save stretched image + numeric mask ---
    stem = Path(img_path).stem
    # Remove only the trailing "_img" marker, not every occurrence of it.
    base = stem[: -len("_img")] if stem.endswith("_img") else stem
    img_out.save(os.path.join(OUT_DIR, split, "images", base + ".png"))
    mask_out.save(os.path.join(OUT_DIR, split, "masks", base + ".png"))

    # --- Optional: colour visualisation of the mask ---
    if save_vis:
        # Row k of the palette is the RGB colour of class k; indexing with the
        # class mask (values 0..3 by construction of the LUT) yields (H, W, 3).
        palette = np.array([colors[c] for c in range(len(colors))], dtype=np.uint8)
        mask_vis_out = Image.fromarray(palette[mask_mapped])
        mask_vis_out.save(os.path.join(OUT_DIR, split, "masks", base + "_vis.png"))

    return True

# --- 6. Process every pair, split by split ---
# preprocess_and_save already prints a per-file warning when it skips a mask,
# so here we only tally outcomes instead of printing a second, duplicate warning.
for split, plist in splits.items():
    print(f"Processing {split}...")
    n_saved = n_skipped = 0
    for pid in plist:
        for img_path, mask_path in patient_to_pairs[pid]:
            if preprocess_and_save(img_path, mask_path, split):
                n_saved += 1
            else:
                n_skipped += 1
    print(f"  {split}: saved {n_saved} pairs, skipped {n_skipped}")

print("✅ Dataset preparation complete. Saved in:", OUT_DIR)

=== Dataset Check ===
Total images found: 20514
Total masks found: 20514
✅ Every image has a corresponding mask.
Total valid pairs: 20514
Unique patients: 269
Split: Train=187, Val=41, Test=41 patients
Processing train...
Processing val...
Processing test...
✅ Dataset preparation complete. Saved in: dataset_splits (AMD)
