In [10]:
from PIL import Image
import numpy as np, glob, os, shutil

from torchvision import transforms

def add_noise(img, sigma):
    """Return a copy of ``img`` corrupted with additive Gaussian noise.

    Args:
        img: Image (anything ``np.array`` can convert, e.g. a PIL image).
        sigma: Standard deviation of the zero-mean noise, expressed in
            pixel-intensity units on the 0-255 scale.

    Returns:
        A new PIL image built from the noisy pixels, clipped to [0, 255]
        and stored as ``uint8``.
    """
    pixels = np.array(img).astype(np.float32)
    # Independent standard-normal sample for every array element, scaled by sigma.
    perturbation = np.random.randn(*pixels.shape) * sigma
    # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
    clipped = np.clip(pixels + perturbation, 0, 255)
    return Image.fromarray(clipped.astype(np.uint8))

def add_salt_and_pepper(img, amount=0.05):
    """Return a copy of ``img`` corrupted with salt-and-pepper noise.

    Two independent random pixel masks are drawn over the first two array
    dimensions. Pixels in the first mask are set to 0 ("pepper") and pixels
    in the second to 255 ("salt"); for multi-channel arrays every channel of
    a selected pixel is overwritten.

    Args:
        img: Image (anything ``np.array`` can convert, e.g. a PIL image).
        amount: Per-pixel probability used for EACH of the two masks, so
            roughly ``2 * amount`` of the pixels end up corrupted.

    Returns:
        A new PIL image built from the corrupted pixel array.
    """
    pixels = np.array(img)
    plane_shape = pixels.shape[:2]

    # Pepper: selected pixels become 0.
    pepper_mask = np.random.rand(*plane_shape) < amount
    pixels[pepper_mask] = 0

    # Salt: selected pixels become 255. Applied second, so a pixel picked
    # by both masks ends up as salt.
    salt_mask = np.random.rand(*plane_shape) < amount
    pixels[salt_mask] = 255

    return Image.fromarray(pixels)

def add_color_jitter(img, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1):
    """Apply a random torchvision ``ColorJitter`` to ``img``.

    Args:
        img: Image to perturb; passed straight to the transform.
        brightness: Forwarded to ``transforms.ColorJitter``.
        contrast: Forwarded to ``transforms.ColorJitter``.
        saturation: Forwarded to ``transforms.ColorJitter``.
        hue: Forwarded to ``transforms.ColorJitter``.

    Returns:
        Whatever the ``ColorJitter`` transform returns for ``img``.
    """
    jitter_settings = {
        "brightness": brightness,
        "contrast": contrast,
        "saturation": saturation,
        "hue": hue,
    }
    # A fresh transform is built per call and applied once.
    return transforms.ColorJitter(**jitter_settings)(img)

def add_masking(img, mask_fraction=0.2):
    """Return a copy of ``img`` with one random rectangle blacked out.

    The rectangle's height and width are ``mask_fraction`` of the image's
    height and width (truncated to integers), and its position is drawn
    uniformly so that it always lies fully inside the image.

    Args:
        img: Image (anything ``np.array`` can convert, e.g. a PIL image).
        mask_fraction: Fraction of each spatial dimension that the masked
            rectangle covers.

    Returns:
        A new PIL image built from the masked pixel array.
    """
    pixels = np.array(img)
    height, width = pixels.shape[:2]
    patch_h, patch_w = int(height * mask_fraction), int(width * mask_fraction)

    # randint's upper bound is exclusive, hence the +1: the patch may touch
    # the bottom/right edge but never extends past it.
    row0 = np.random.randint(0, height - patch_h + 1)
    col0 = np.random.randint(0, width - patch_w + 1)

    # Zero the patch (black for ordinary intensity images).
    pixels[row0:row0 + patch_h, col0:col0 + patch_w] = 0
    return Image.fromarray(pixels)

def make_noisy_dataset(src_root, dst_root, corruption, **kwargs):
    """Create a corrupted copy of a dataset laid out as ``train``/``test`` folders.

    For each of the ``train`` and ``test`` sub-folders of ``src_root``, every
    ``*.png`` file is loaded, corrupted, and written under the same file name
    in the matching sub-folder of ``dst_root``. If a ``.npy`` file with the
    same stem sits next to an image, it is copied over unchanged.

    Args:
        src_root: Directory containing the ``train`` and ``test`` folders.
        dst_root: Output directory; sub-folders are created as needed.
        corruption: One of ``"gaussian"``, ``"s&p"``, ``"color_jitter"`` or
            ``"masking"``.
        **kwargs: Forwarded unchanged to the selected corruption function
            (e.g. ``sigma=25`` for ``"gaussian"``).

    Raises:
        ValueError: If ``corruption`` is not one of the supported names.
            Raised before any directory is created or file is read.
    """
    corruptors = {
        "gaussian": add_noise,
        "s&p": add_salt_and_pepper,
        # Previously dispatched to add_salt_and_pepper by mistake.
        "color_jitter": add_color_jitter,
        "masking": add_masking,
    }
    if corruption not in corruptors:
        raise ValueError("Unsupported corruption")
    corrupt = corruptors[corruption]

    for split in ("train", "test"):
        src_dir = os.path.join(src_root, split)
        dst_dir = os.path.join(dst_root, split)
        os.makedirs(dst_dir, exist_ok=True)

        # Sorted so files are processed in a stable order, which keeps the
        # sequence of random draws reproducible under a fixed seed.
        for img_path in sorted(glob.glob(os.path.join(src_dir, "*.png"))):
            # Context manager closes the source file handle; saving happens
            # inside the block so the source is still open if a corruption
            # ever hands back the very same image object.
            with Image.open(img_path) as img:
                corrupted = corrupt(img, **kwargs)
                corrupted.save(os.path.join(dst_dir, os.path.basename(img_path)))

            # Carry over the companion .npy file (same stem), if there is one.
            label_path = os.path.splitext(img_path)[0] + ".npy"
            if os.path.exists(label_path):
                shutil.copy(label_path, os.path.join(dst_dir, os.path.basename(label_path)))

    print(f"Created dataset '{dst_root}' with corruption '{corruption}'")

# Root folder holding the source dataset and every corrupted copy of it.
base = "ddpm-segmentation/datasets"
src_root = os.path.join(base, "horse_21/real")

# One destination folder per corruption type, all created directly under `base`.
dst_gaussian, dst_saltpepper, dst_color_jitter, dst_masked = (
    os.path.join(base, folder)
    for folder in (
        "horse_21_gaussian",
        "horse_21_salt_and_pepper",
        "horse_21_color_jitter",
        "horse_21_masked",
    )
)

# Generate the four corrupted datasets; the last two rely on the default
# parameters of their corruption functions.
make_noisy_dataset(src_root, dst_gaussian, "gaussian", sigma=25)
make_noisy_dataset(src_root, dst_saltpepper, "s&p", amount=0.05)
make_noisy_dataset(src_root, dst_color_jitter, "color_jitter")
make_noisy_dataset(src_root, dst_masked, "masking")



Created dataset 'ddpm-segmentation/datasets\horse_21_gaussian' with corruption 'gaussian'
Created dataset 'ddpm-segmentation/datasets\horse_21_salt_and_pepper' with corruption 's&p'
Created dataset 'ddpm-segmentation/datasets\horse_21_color_jitter' with corruption 'color_jitter'
Created dataset 'ddpm-segmentation/datasets\horse_21_masked' with corruption 'masking'
