# Evaluation of Combined Natural Distortions for Attacking Tree-Ring Watermarks


In [1]:
# Package imports
import wandb, os
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm, trange
from IPython.display import display, Image

# Relative imports
from metrics import *
from tree_ring import *
from guided_diffusion import *
from distortions import *
from utils import *

# Experiment name (also used as the wandb project name and the results directory)
experiment_name = "Eval_Comb_Distort_Tree_Ring"

# Experiment parameters
image_size = 64
dataset_name = "Tiny-ImageNet"
dataset_template = dataset_name
# 20 images per class, so that we get 20 * 200 = 4000 images for Tiny-ImageNet
num_sample_per_class = 20


# Tree-Ring watermark configuration.
# NOTE(review): not used in this cell; presumably passed to the tree_ring helpers
# in later cells — confirm.
tree_ring_paras = dict(
    w_channel=2,
    w_pattern="ring",
    w_mask_shape="circle",
    w_radius=10,
    w_measurement="l1_complex",
    w_injection="complex",
)
# Seeds
sampling_seed = 0  # seeds the per-class image sampling and the FID subsampling
distortion_seed = 0  # NOTE(review): unused in this cell; presumably seeds the distortions later

# Wandb and device setup
os.environ["WANDB_DIR"] = f"results/{experiment_name}/"
# wandb expects WANDB_DIR to already exist and be writable, so create it up front.
os.makedirs(os.environ["WANDB_DIR"], exist_ok=True)
# NOTE(review): "dryrun" looks like wandb's legacy alias for "offline" (run kept
# locally, not synced) — confirm against the installed wandb version.
os.environ["WANDB_MODE"] = "dryrun"
os.environ["WANDB_SILENT"] = "true"
wandb.init(
    project=experiment_name,
    name="",
    # Record the experiment parameters so each (offline) run is self-describing.
    config=dict(
        image_size=image_size,
        dataset_name=dataset_name,
        dataset_template=dataset_template,
        num_sample_per_class=num_sample_per_class,
        tree_ring_paras=tree_ring_paras,
        sampling_seed=sampling_seed,
        distortion_seed=distortion_seed,
    ),
    save_code=False,
)
# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Evaluate the Image Quality Metrics for Tree-Ring Watermarks


In [2]:
# Load the original imagenet subset (real images)
dataset_org, class_names_org = load_imagenet_subset(
    dataset_name, convert_to_tensor=False
)

# Load generated images without watermarks
dataset_wo, class_names_wo = load_imagenet_guided(
    image_size, dataset_template, convert_to_tensor=False
)

# Load generated images with watermarks (one key seed, one message seed)
dataset_w, class_names_w, keys, messages = load_tree_ring_guided(
    image_size,
    dataset_template,
    num_key_seeds=1,
    num_message_seeds=1,
    convert_to_tensor=False,
)

# All three image sets must share the same class list, otherwise the
# per-class sampling below would not be comparable.
if not (class_names_org == class_names_wo == class_names_w):
    raise ValueError(
        "Class names differ between the original, non-watermarked and "
        "watermarked datasets"
    )

# Sample the images, class evenly distributed (same seed for every set)
images_org, labels_org = sample_images_by_label_set(
    dataset_org, num_sample_per_class, sampling_seed=sampling_seed
)
images_wo, labels_wo = sample_images_by_label_set(
    dataset_wo, num_sample_per_class, sampling_seed=sampling_seed
)
images_w, labels_w = sample_images_by_label_set(
    dataset_w, num_sample_per_class, sampling_seed=sampling_seed
)
# Watermarked labels are indexed with [0] here; presumably each is a tuple whose
# first entry is the class label (followed by key/message ids) — confirm.
if not (labels_org == labels_wo == [label[0] for label in labels_w]):
    raise ValueError(
        "Sampled labels differ between the original, non-watermarked and "
        "watermarked image sets"
    )

# Metrics to calculate: name -> callable(images1, images2).
# Each callable presumably returns a (mean, std) pair, given how
# tuples_to_lists / format_mean_and_std_list consume the results below.
metrics_funcs = dict(
    FID=lambda images1, images2: compute_fid_repeated(
        images1,
        images2,
        num_repeats=3,
        sample_size=2048,
        pairwise=True,
        sampling_seed=sampling_seed,
    ),
)

# Image-set pairs to compare, defined once so the wandb keys, printed labels
# and metric inputs cannot drift apart: (wandb key, printed label, set 1, set 2).
comparisons = [
    ("dist_org_wo", "(original, wo/ watermark)", images_org, images_wo),
    ("dist_org_w", "(original, w/ watermark)", images_org, images_w),
    ("dist_wo_w", "(wo/ watermark, w/ watermark)", images_wo, images_w),
]

# Calculate, log and print results
for metric_name, metric_func in metrics_funcs.items():
    means, stds = tuples_to_lists(
        [metric_func(images1, images2) for _, _, images1, images2 in comparisons]
    )
    fmt_strings = format_mean_and_std_list(means, stds, style="latex")

    # Log both the mean and the standard deviation of every comparison.
    log_entry = {"metric": metric_name}
    for (key, _, _, _), mean, std in zip(comparisons, means, stds):
        log_entry[key] = mean
        log_entry[f"{key}_std"] = std
    wandb.log(log_entry)

    report_lines = [f"{metric_name} scores between:"]
    for (_, label, _, _), fmt_string in zip(comparisons, fmt_strings):
        report_lines.append(f"  {label}: {fmt_string}")
    print("\n".join(report_lines))

FID scores between:
  (original, wo/ watermark): ($430.5 \pm 3.0$) $\times 10^{-1}$
  (original, w/ watermark): ($531.7 \pm 3.4$) $\times 10^{-1}$
  (wo/ watermark, w/ watermark): ($376.6 \pm 2.6$) $\times 10^{-1}$
