In [1]:
import os
import concurrent.futures
from PIL import Image
from tqdm.auto import tqdm
from dev import *
from distortions import *

# Attacks to run, as (attack name, strength) pairs. Strengths stay strings:
# they are embedded verbatim in output directory names and only converted with
# float() when the distortion is applied.
tasks = [
    *(("distortion_single_rotation", degrees) for degrees in ("3", "1")),
    *(("distortion_single_resizedcrop", scale) for scale in ("0.85", "0.95")),
    ("distortion_combo_degradation", "0.02"),
]


# attack name -> (distortion types applied in order, relative_strength flag
# forwarded to apply_distortion).
ATTACK_DISTORTIONS = {
    "distortion_single_rotation": (("rotation",), False),
    "distortion_single_resizedcrop": (("resizedcrop",), False),
    "distortion_combo_degradation": (("blurring", "noise", "compression"), True),
}


def distortion(attack_name, attack_strength, original_image):
    """Apply the named attack to ``original_image`` and return the result.

    The distortion types listed for ``attack_name`` in ``ATTACK_DISTORTIONS``
    are applied one after another, each to the previous output, all with the
    same strength and a fixed ``distortion_seed=0``.

    Args:
        attack_name: a key of ``ATTACK_DISTORTIONS``.
        attack_strength: strength value; anything ``float()`` accepts (the
            callers in this file pass strings such as ``"0.85"``).
        original_image: image handed to ``apply_distortion`` (presumably a
            PIL image -- it comes from ``Image.open`` in ``process_image``).

    Returns:
        Element 0 of the list returned by the last ``apply_distortion`` call.

    Raises:
        ValueError: if ``attack_name`` is unknown or ``attack_strength`` is
            not convertible to float.
    """
    try:
        distortion_types, relative_strength = ATTACK_DISTORTIONS[attack_name]
    except KeyError:
        # An explicit exception, unlike `assert False`, survives `python -O`.
        raise ValueError(f"Unknown attack name: {attack_name!r}") from None
    strength = float(attack_strength)

    image = original_image
    for distortion_type in distortion_types:
        image = apply_distortion(
            [image],
            distortion_type=distortion_type,
            strength=strength,
            distortion_seed=0,
            same_operation=False,
            relative_strength=relative_strength,
            return_image=True,
        )[0]
    return image


def process_image(index, attack_name, attack_strength, original_path, attacked_path):
    """Attack a single image: read ``{index}.png`` and write the result.

    Args:
        index: image number; the file name is ``f"{index}.png"`` in both the
            input and output directories.
        attack_name: attack passed through to ``distortion``.
        attack_strength: strength passed through to ``distortion``.
        original_path: directory containing the source image.
        attacked_path: directory the attacked image is saved into (must exist).
    """
    source_file = os.path.join(original_path, f"{index}.png")
    target_file = os.path.join(attacked_path, f"{index}.png")
    # This runs thousands of times per worker process; the context manager
    # releases the source file handle deterministically instead of leaving it
    # to lazy loading / garbage collection. Distort and save inside the block
    # in case the result still refers to the source image's data.
    with Image.open(source_file) as original_image:
        attacked_image = distortion(attack_name, attack_strength, original_image)
        attacked_image.save(target_file)


def process_dataset(
    attack_name,
    attack_strength,
    dataset_name,
    source_name,
    pbar,
    num_images=5000,
    max_workers=16,
):
    """Apply one attack to every image of one dataset source, in parallel.

    Images ``0.png`` .. ``{num_images - 1}.png`` are read from the source
    directory under ``datasets/main`` and written to
    ``datasets/attacked/{dataset_name}/{attack_name}-{attack_strength}-{source_name}``
    (created if missing). ``pbar`` is advanced once per successfully
    processed image.

    Args:
        attack_name: attack passed to ``process_image``.
        attack_strength: strength passed to ``process_image``; also used
            verbatim in the output directory name.
        dataset_name: dataset directory name.
        source_name: source directory name within the dataset.
        pbar: progress bar exposing ``update(n)``.
        num_images: number of images in the source directory.
        max_workers: number of worker processes.

    Returns:
        The output directory path.

    Raises:
        Any exception raised in a worker (e.g. a missing input file) is
        re-raised here; images that have not started yet are cancelled.
    """
    original_path = (
        f"/fs/nexus-projects/HuangWM/datasets/main/{dataset_name}/{source_name}"
    )
    attacked_path = f"/fs/nexus-projects/HuangWM/datasets/attacked/{dataset_name}/{attack_name}-{attack_strength}-{source_name}"
    # exist_ok avoids the race between an existence check and the creation.
    os.makedirs(attacked_path, exist_ok=True)

    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                process_image,
                index,
                attack_name,
                attack_strength,
                original_path,
                attacked_path,
            )
            for index in range(num_images)
        ]
        try:
            for future in concurrent.futures.as_completed(futures):
                # result() re-raises a worker's exception; without it, failed
                # images would be dropped silently.
                future.result()
                pbar.update(1)
        except BaseException:
            # On an error or Ctrl-C, drop queued work so the executor's
            # shutdown does not wait for the rest of the queue.
            for future in futures:
                future.cancel()
            raise
    return attacked_path


# Run every task over every source ("real" plus each watermark method) of every
# dataset, then verify that each output directory holds the expected count.
images_per_source = 5000
source_names = ["real"] + list(WATERMARK_METHODS)
total_operations = (
    len(DATASET_NAMES) * len(source_names) * len(tasks) * images_per_source
)

with tqdm(total=total_operations) as pbar:
    for attack_name, attack_strength in tasks:
        for dataset_name in DATASET_NAMES:
            for source_name in source_names:
                process_dataset(
                    attack_name, attack_strength, dataset_name, source_name, pbar
                )
                attacked_path = f"/fs/nexus-projects/HuangWM/datasets/attacked/{dataset_name}/{attack_name}-{attack_strength}-{source_name}"
                print(attacked_path)
                # Explicit check instead of `assert`, which `python -O` strips.
                written = len(os.listdir(attacked_path))
                if written != images_per_source:
                    raise RuntimeError(
                        f"{attacked_path}: expected {images_per_source} files, "
                        f"found {written}"
                    )

  0%|          | 0/300000 [00:00<?, ?it/s]

/fs/nexus-projects/HuangWM/datasets/attacked/diffusiondb/distortion_single_rotation-3-real
/fs/nexus-projects/HuangWM/datasets/attacked/diffusiondb/distortion_single_rotation-3-tree_ring


KeyboardInterrupt: 