In [1]:
import os
import random

import albumentations as A
import cv2
import numpy as np
from tqdm import tqdm

In [2]:
# Source folders: images, and segmentation masks named "<id>_segmentation.png".
IMG_IN_DIR = r"ISIC2018/images"
MASK_IN_DIR = r"ISIC2018/masks"

# Output folders: originals plus offline augmentations, topped up to TARGET_COUNT.
IMG_OUT_DIR = r"ISIC2018/images_aug_20000"
MASK_OUT_DIR = r"ISIC2018/masks_aug_20000"

TARGET_COUNT = 20000

SEED = 42
# `rng` only decides which source image each augmented sample is drawn from.
rng = np.random.default_rng(SEED)
# Albumentations 1.x draws its transform parameters from Python's `random` and
# NumPy's legacy global RNG, not from `rng`; seed both so the augmentations
# themselves are reproducible on Restart & Run All.
random.seed(SEED)
np.random.seed(SEED)

os.makedirs(IMG_OUT_DIR, exist_ok=True)
if MASK_IN_DIR is not None:
    os.makedirs(MASK_OUT_DIR, exist_ok=True)


In [3]:
def list_image_ids(folder):
    """Return the sorted, de-duplicated base names of image files in ``folder``.

    A regular file counts as an image when its extension, compared
    case-insensitively, is ``.jpg``, ``.jpeg`` or ``.png``. A base name is
    returned once even if several extensions exist for it (``a.jpg`` and
    ``a.png`` -> ``"a"``), so ``len(result)`` counts distinct ids.
    Sub-directories are ignored even if their name looks like an image.

    Args:
        folder: directory to scan (non-recursive).

    Returns:
        Sorted list of unique base names (file name without extension).
    """
    exts = {".jpg", ".jpeg", ".png"}
    ids = set()  # set: one id per base name regardless of how many extensions exist
    with os.scandir(folder) as entries:
        for entry in entries:
            base, ext = os.path.splitext(entry.name)
            if ext.lower() in exts and entry.is_file():
                ids.add(base)
    return sorted(ids)

def find_existing_image_path(img_id, folder):
    """Return the first existing ``folder/<img_id><ext>`` path, or None.

    Extensions are probed in a fixed order (lower-case variants before
    upper-case ones), so when several files share ``img_id`` the ``.jpg``
    file wins over ``.jpeg``, which wins over ``.png``.
    """
    candidates = (
        os.path.join(folder, img_id + suffix)
        for suffix in (".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG")
    )
    return next((candidate for candidate in candidates if os.path.exists(candidate)), None)

def mask_path_from_id(img_id, mask_dir=None):
    """Return the path of the segmentation mask paired with ``img_id``.

    Masks follow the ``<img_id>_segmentation.png`` naming scheme.

    Args:
        img_id: image base name without extension.
        mask_dir: folder holding the masks. Defaults to the ``MASK_IN_DIR``
            config value, so existing callers behave as before.

    Returns:
        ``os.path.join(mask_dir, img_id + "_segmentation.png")``.

    Raises:
        ValueError: if no folder is given and ``MASK_IN_DIR`` is None (the
            config cell allows that value, but no mask path can be built).
    """
    # The global is only looked up when no explicit folder is passed.
    folder = MASK_IN_DIR if mask_dir is None else mask_dir
    if folder is None:
        raise ValueError("No mask folder configured (MASK_IN_DIR is None).")
    return os.path.join(folder, img_id + "_segmentation.png")

def read_image_bgr(path):
    """Load ``path`` with OpenCV as a 3-channel BGR image.

    ``cv2.imread`` signals failure by returning None rather than raising
    (missing file, bad permissions, unsupported or corrupt data), so that
    case is converted into an exception here.

    Raises:
        FileNotFoundError: if OpenCV could not read/decode the file.
    """
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is not None:
        return image
    raise FileNotFoundError(f"Could not read image: {path}")

def read_mask_gray(path):
    """Load ``path`` with OpenCV as a single-channel grayscale mask.

    Mirrors ``read_image_bgr``: a None result from ``cv2.imread`` is turned
    into an exception instead of being passed on.

    Raises:
        FileNotFoundError: if OpenCV could not read/decode the file.
    """
    loaded = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if loaded is not None:
        return loaded
    raise FileNotFoundError(f"Could not read mask: {path}")

def safe_imwrite(path, img):
    """Write ``img`` to ``path`` with OpenCV, raising instead of failing silently.

    ``cv2.imwrite`` reports most failures (for example a missing parent
    directory) through a False return value rather than an exception, so the
    flag is checked here.

    Raises:
        RuntimeError: if ``cv2.imwrite`` returns a falsy result.
    """
    if cv2.imwrite(path, img):
        return
    raise RuntimeError(f"Failed to write: {path}")


In [4]:
# Offline augmentation pipeline applied jointly to an image and its mask.
# Every sample is first resized to 256x256, then passes through three groups
# of randomly-applied transforms, each with its own probability `p`.
_geometric_tfms = [
    A.Resize(256, 256),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.1),
    A.RandomRotate90(p=0.4),
    A.ShiftScaleRotate(
        shift_limit=0.06,
        scale_limit=0.15,
        rotate_limit=25,
        border_mode=cv2.BORDER_REFLECT_101,
        p=0.7,
    ),
]

# Colour / intensity perturbations, using the library's default ranges.
# NOTE(review): albumentations documents its colour transforms as expecting
# RGB input, while cv2.imread returns BGR — callers should convert first.
_photometric_tfms = [
    A.ColorJitter(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.HueSaturationValue(p=0.35),
]

# Blur, sensor noise and JPEG-compression degradations.
# NOTE(review): quality_lower/quality_upper are the albumentations 1.x argument
# names; later releases use quality_range — confirm against the installed version.
_degradation_tfms = [
    A.GaussianBlur(p=0.15),
    A.GaussNoise(p=0.25),
    A.ISONoise(p=0.15),
    A.ImageCompression(quality_lower=50, quality_upper=95, p=0.25),
]

offline_aug = A.Compose(_geometric_tfms + _photometric_tfms + _degradation_tfms)


In [5]:
image_ids = list_image_ids(IMG_IN_DIR)
if not image_ids:
    raise ValueError(f"No images found in: {IMG_IN_DIR}")

print("Original images:", len(image_ids))

# Copy every original image/mask pair into the output folders, normalising the
# names to <id>.jpg / <id>.png. Ids whose image file cannot be located are
# reported after the loop rather than dropped silently.
skipped_ids = []
for img_id in tqdm(image_ids, desc="copying originals"):
    img_path = find_existing_image_path(img_id, IMG_IN_DIR)
    if img_path is None:
        skipped_ids.append(img_id)
        continue

    img = read_image_bgr(img_path)
    mask = read_mask_gray(mask_path_from_id(img_id))

    # A mask that does not line up pixel-for-pixel with its image is not a
    # usable segmentation target; fail loudly instead of copying it.
    if img.shape[:2] != mask.shape[:2]:
        raise ValueError(
            f"Image/mask size mismatch for {img_id}: "
            f"image {img.shape[:2]} vs mask {mask.shape[:2]}"
        )

    # Binarise with the same 127 threshold applied to the augmented masks so
    # every mask in the output folder uses one {0, 255} encoding.
    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

    # NOTE(review): originals keep their native resolution (and are re-encoded
    # as JPEG), whereas augmented samples are resized to 256x256 by offline_aug —
    # confirm the training loader resizes, or resize here for consistency.
    safe_imwrite(os.path.join(IMG_OUT_DIR, img_id + ".jpg"), img)
    safe_imwrite(os.path.join(MASK_OUT_DIR, img_id + ".png"), mask)

if skipped_ids:
    print(f"⚠️ Skipped {len(skipped_ids)} id(s) with no image file, e.g. {skipped_ids[:5]}")
print("✅ Copied originals to output.")


Original images: 2594
✅ Copied originals to output.


In [6]:
existing_out = list_image_ids(IMG_OUT_DIR)
current = len(existing_out)

need = max(0, TARGET_COUNT - current)
print(f"Current in output: {current} | Need to generate: {need}")

if need == 0:
    print("✅ Output already at or above target.")
else:
    # Resolve every source image path once instead of probing the filesystem
    # on each of the `need` iterations. Ids without a file are dropped up
    # front, so every iteration below writes exactly one image/mask pair.
    sources = []
    for src_id in image_ids:
        src_path = find_existing_image_path(src_id, IMG_IN_DIR)
        if src_path is not None:
            sources.append((src_id, src_path))
    if not sources:
        raise ValueError(f"No readable source images in: {IMG_IN_DIR}")

    # Number new samples from `current`, not 0. After a kernel restart `rng`
    # is re-seeded and replays the same base_id sequence, so restarting the
    # numbering at 0 would regenerate identical names, overwrite earlier
    # outputs and leave a resumed run short of TARGET_COUNT.
    for k in tqdm(range(current, current + need), desc="augmenting"):
        base_id, img_path = sources[int(rng.integers(0, len(sources)))]

        img = read_image_bgr(img_path)
        mask = read_mask_gray(mask_path_from_id(base_id))

        # Albumentations' colour transforms expect RGB, but OpenCV loads BGR:
        # convert before augmenting and back before writing with cv2.
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        augmented = offline_aug(image=img_rgb, mask=mask)
        aug_img = cv2.cvtColor(augmented["image"], cv2.COLOR_RGB2BGR)
        aug_msk = augmented["mask"]

        # Re-binarise defensively in case any transform interpolated the mask.
        _, aug_msk = cv2.threshold(aug_msk, 127, 255, cv2.THRESH_BINARY)

        new_id = f"{base_id}_aug_{k:06d}"

        safe_imwrite(os.path.join(IMG_OUT_DIR, new_id + ".jpg"), aug_img)
        safe_imwrite(os.path.join(MASK_OUT_DIR, new_id + ".png"), aug_msk)

print("✅ Done generating augmentations.")


Current in output: 2594 | Need to generate: 17406
✅ Done generating augmentations.


In [7]:
final_img_ids = set(list_image_ids(IMG_OUT_DIR))
final_msk_ids = set(list_image_ids(MASK_OUT_DIR))

print("Final images:", len(final_img_ids))
print("Final masks :", len(final_msk_ids))

# Equal counts are not enough: every image must have a mask with the same id
# (and vice versa), otherwise a loader pairing them by name will break.
images_without_mask = sorted(final_img_ids - final_msk_ids)
masks_without_image = sorted(final_msk_ids - final_img_ids)

if images_without_mask or masks_without_image:
    print("⚠️ WARNING: image/mask mismatch. Check failed writes or missing masks.")
    print(f"  images without mask: {len(images_without_mask)} e.g. {images_without_mask[:5]}")
    print(f"  masks without image: {len(masks_without_image)} e.g. {masks_without_image[:5]}")
else:
    print("✅ Image/mask ids match one-to-one.")


Final images: 20000
Final masks : 20000
✅ Image/mask counts match.
