# E001: Visualization of Tree-Ring Watermark's Generation


In [1]:
# Package imports
from collections import namedtuple
from io import BytesIO
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from IPython.display import display, Image
from tqdm.notebook import tqdm, trange

# Relative imports
from tree_ring import *
from utils import *

# Device: prefer the GPU when present, otherwise run on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Experiment parameters
image_size = 64  # resolution handed to the diffusion model loader and the samplers
labels = [0]  # ImageNet class ids to sample; one image is produced per label
# Tree-Ring watermark hyper-parameters, forwarded unchanged to the tree_ring helpers
# (message/key generation, watermarked sampling, detection)
tree_ring_paras = {
    "w_channel": 2,
    "w_pattern": "ring",
    "w_mask_shape": "circle",
    "w_radius": 10,
    "w_measurement": "l1_complex",
    "w_injection": "complex",
}

In [2]:
# TODO: Debug, reverse diffusion does not go back to the initial latent!!!
#
# Inversion sanity check: sample un-watermarked images, run reverse diffusion on
# them, and compare the recovered latents with the initial noise. If the
# inversion were exact, both printed maxima would be (close to) zero.

# Class-conditional guided diffusion model trained on ImageNet
model, diffusion = load_guided_diffusion_model(image_size, device)

# Un-watermarked images sampled with diffusion seed 0
images_wo = guided_diffusion_without_watermark(
    model, diffusion, labels, image_size, diffusion_seed=0
)

# Invert the un-watermarked images (only these; no watermarked images here)
reversed_latents_wo = reverse_guided_diffusion(
    model,
    diffusion,
    images=images_wo,
    image_size=image_size,
    default_labels=labels,  # for sanity-check purposes only; remove later
)

# Re-draw the initial latent.
# NOTE(review): assumes guided_diffusion_without_watermark(diffusion_seed=0) draws its
# starting noise exactly like this after seeding with 0 -- TODO confirm.
set_random_seed(0)
shape = (len(labels), 3, image_size, image_size)
init_latents_wo = torch.randn(*shape, device=device)

# Element-wise absolute and relative errors; the tiny epsilon only prevents division by zero
epsilon = 1e-10
abs_diff = (reversed_latents_wo - init_latents_wo).abs()
rel_diff = abs_diff / (init_latents_wo.abs() + epsilon)

# Worst-case errors over all latent elements
max_abs_diff, max_rel_diff = abs_diff.max(), rel_diff.max()

print(f"Maximum Absolute Difference: {max_abs_diff.item()}")
print(f"Maximum Relative Difference: {max_rel_diff.item()}")

t=3, alpha_prev=0.9943352937698364, alpha_next=0.9886365532875061, a=-0.5132330656051636, b=-0.6949071884155273, in=-0.5843853950500488, out=-0.5640789866447449
t=2, alpha_prev=0.9981142282485962, alpha_next=0.9943352937698364, a=-0.519475519657135, b=-0.6122012138366699, in=-0.5640789866447449, out=-0.5455706119537354
t=1, alpha_prev=0.9999586939811707, alpha_next=0.9981142282485962, a=-0.525180459022522, b=-0.4809526801109314, in=-0.5455706119537354, out=-0.5282607078552246
t=0, alpha_prev=1.0, alpha_next=0.9999586939811707, a=-0.5274044275283813, b=-0.13496221601963043, in=-0.5282607078552246, out=-0.5274044275283813
t=0, alpha_prev=None, alpha_next=0.9981142282485962, a=-0.5272672772407532, b=-0.023033279925584793, in=-0.5274044275283813, out=-0.5277701020240784
t=1, alpha_prev=None, alpha_next=0.9943352937698364, a=-0.5061648488044739, b=-0.5085208415985107, in=-0.5277701020240784, out=-0.5430026650428772
t=2, alpha_prev=None, alpha_next=0.9886365532875061, a=-0.5016630291938782, 

## Implement Tree-Ring Watermark


In [None]:
# Class-conditional guided diffusion model trained on ImageNet
model, diffusion = load_guided_diffusion_model(image_size, device)

# Baseline: images sampled without any watermark
images_wo = guided_diffusion_without_watermark(
    model, diffusion, labels, image_size, diffusion_seed=0
)

# Watermark message -- called the "key" in the Tree-Ring paper
message = generate_message(
    message_seed=0, image_size=image_size, tree_ring_paras=tree_ring_paras, device=device
)

# Watermark key -- called the "mask" in the Tree-Ring paper
key = generate_key(
    key_seed=0,
    image_size=image_size,
    tree_ring_paras=tree_ring_paras,
    device=device,
)

# Same labels and diffusion seed as the baseline, but with the watermark injected
images_w = guided_diffusion_with_watermark(
    model,
    diffusion,
    labels,
    keys=key,
    messages=message,
    tree_ring_paras=tree_ring_paras,
    image_size=image_size,
    diffusion_seed=0,
)

# Reverse diffusion on both image sets (un-watermarked first, then watermarked)
reversed_latents_wo, reversed_latents_w = [
    reverse_guided_diffusion(model, diffusion, images=batch, image_size=image_size)
    for batch in (images_wo, images_w)
]

# Detect the watermark in both latent sets and score how well they separate
auc, acc, low = detect_evaluate_watermark(
    reversed_latents_wo,
    reversed_latents_w,
    keys=key,
    messages=message,
    tree_ring_paras=tree_ring_paras,
    image_size=image_size,
)
print(f"Sanity check when there is no attack: AUC={auc}, accuracy={acc}, and TPR@1%FPR={low}.")

## How Tree-Ring Watermarks Evolve During Forward Diffusion Sampling


In [None]:
# Progressive (per-step) samplers with and without watermark. Both use the same
# labels and diffusion seed, so the i-th outputs of the two are directly comparable.
images_prog_wo = guided_diffusion_without_watermark(
    model,
    diffusion,
    labels,
    image_size=image_size,
    diffusion_seed=0,
    progressive=True,
    return_image=True,
)
images_prog_w = guided_diffusion_with_watermark(
    model,
    diffusion,
    labels,
    keys=key,
    messages=message,
    tree_ring_paras=tree_ring_paras,
    image_size=image_size,
    diffusion_seed=0,
    progressive=True,
    return_image=True,
)

# Output GIF; create its directory up front so the save after the long loop cannot fail
gif_path = Path("./results/E001/forward_diffusion.gif")
gif_path.parent.mkdir(parents=True, exist_ok=True)

# Step through both samplers together, rendering one comparison figure per step.
# The loop variables are deliberately NOT named `images_wo` / `images_w`: reusing those
# names would silently overwrite the final images produced in the previous section.
figs = []
for step_images_wo, step_images_w in zip(images_prog_wo, images_prog_w):
    # Convert the step images to tensors (via `to_tensor_and_normalize`) so they can be subtracted
    step_tensor_wo = to_tensor_and_normalize(step_images_wo)
    step_tensor_w = to_tensor_and_normalize(step_images_w)
    step_delta = step_tensor_wo - step_tensor_w
    # Pixel-wise delta between images with and without watermark
    pixel_delta = torch.abs(step_delta)
    # FFT magnitude of the delta, zero frequency shifted to the centre of the last two dims
    fft_delta = torch.abs(torch.fft.fftshift(torch.fft.fft2(step_delta), dim=(-1, -2)))
    # Max-normalize for display; an all-zero delta would otherwise give 0/0 = NaN
    fft_peak = fft_delta.max()
    fft_delta_normed = fft_delta / fft_peak if fft_peak > 0 else fft_delta

    fig = visualize_image_grid(
        [
            step_images_wo,
            step_images_w,
            unnormalize_and_to_pil(pixel_delta * 10),
            unnormalize_and_to_pil(fft_delta_normed),
        ],
        col_headers=[
            "w/o watermark",
            "w/ watermark",
            "pixel-delta * 10",
            "fft-delta (max-normd)",
        ],
        row_headers=get_imagenet_class_names(labels),
        fontsize=10,
        column_first=True,
    )
    figs.append(fig)
    # `plt.show()` does not accept a figure; display this one explicitly and drop it from
    # pyplot's registry so open figures don't pile up (the object itself stays in `figs`)
    display(fig)
    plt.close(fig)

# Make gif from the figures
make_gif(figs, filepath=str(gif_path))

# Deliberate stop for "Run All": the next cell is still work in progress.
# An explicit raise, unlike `assert False`, is not stripped when running under `python -O`.
raise RuntimeError("Intentional stop: the cell below is work in progress.")

In [None]:
# Visualizing the evolution of tree-ring watermark through diffusion

import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import random


def _map_images(fn, images):
    """Apply ``fn`` to each image in ``images``; ``None`` entries pass through unchanged."""
    return tuple(None if img is None else fn(img) for img in images)


def _erase_pil(img, i, j, h, w):
    """Return a copy of PIL image ``img`` with the ``h`` x ``w`` patch at (row ``i``, col ``j``) zeroed."""
    erased = F.erase(transforms.ToTensor()(img), i, j, h, w, v=0)
    return transforms.ToPILImage()(erased)


def paired_transforms(img1, img2, img3, type=None, index=None):
    """Apply one identically-parameterized augmentation to three PIL images.

    Any randomness (rotation angle, crop box, erased region) is drawn once and reused
    for all three images so they stay pixel-aligned. A ``None`` image is passed
    through untouched.

    Args:
        img1, img2, img3: PIL images; assumed to share the same size (TODO confirm at
            call sites). Random parameters are derived from ``img1``.
        type: One of "Rotation", "RandomResizedCrop", "RandomErasing", "IndexedErasing".
        index: Only for "IndexedErasing": position in [0, 64) on an 8x8 grid (stride 8);
            a 16x16 patch whose top-left corner is that grid point is erased.

    Returns:
        Tuple ``(img1, img2, img3)`` of transformed PIL images.

    Raises:
        AssertionError: on an unknown ``type`` or an out-of-range ``index``.
    """
    assert type in ["Rotation", "RandomResizedCrop", "RandomErasing", "IndexedErasing"]
    images = (img1, img2, img3)

    if type == "Rotation":
        # Random rotation by up to 30 degrees, same angle for every image
        angle = random.uniform(-30, 30)
        return _map_images(lambda img: F.rotate(img, angle), images)

    if type == "RandomResizedCrop":
        i, j, h, w = transforms.RandomResizedCrop.get_params(
            img1, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3)
        )
        return _map_images(lambda img: F.resized_crop(img, i, j, h, w, (64, 64)), images)

    if type == "RandomErasing":
        # Cutout at a random position. F.erase takes (top row i, left column j, height, width),
        # so rows/heights are drawn from the image height and columns/widths from its width.
        width, height = img1.size
        i, j = random.randint(0, height - 1), random.randint(0, width - 1)
        h = random.randint(int(0.02 * height), int(0.33 * height))
        w = random.randint(int(0.02 * width), int(0.33 * width))
        return _map_images(lambda img: _erase_pil(img, i, j, h, w), images)

    # "IndexedErasing": sliding 16x16 cutout on an 8x8 grid of positions with stride 8,
    # which spans a 64x64 image; valid indices are therefore 0 .. 63
    assert index is not None and 0 <= index < 8 * 8
    i, j = 8 * (index // 8), 8 * (index % 8)
    return _map_images(lambda img: _erase_pil(img, i, j, 16, 16), images)


def _last_item(iterable):
    """Return the final element of ``iterable``.

    Raises:
        ValueError: if ``iterable`` yields nothing (rather than leaving a loop
            variable unbound and failing later with a NameError).
    """
    sentinel = last = object()
    for last in iterable:
        pass
    if last is sentinel:
        raise ValueError("expected at least one item, got an empty iterable")
    return last


def _centered_fft_magnitude(x):
    """Max-normalized magnitude of the 2-D FFT of ``x`` over its last two dims.

    The zero frequency is shifted to the centre of the two spatial dims only: calling
    ``torch.fft.fftshift`` without ``dim`` shifts *every* dim, which would also roll the
    channel axis and scramble the colour channels. An all-zero input is returned
    unscaled instead of producing NaN from 0/0.
    """
    magnitude = torch.abs(torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1)))
    peak = magnitude.max()
    return magnitude / peak if peak > 0 else magnitude


def generate_and_compare_reverse(
    model,
    diffusion,
    prompt,
    key,
    image_size,
    tree_ring_paras,
    init_latents_w,
    watermarking_mask,
    diffusion_seed,
    gif_path=None,
):
    """Visualize how reverse diffusion of (un)watermarked images reacts to patch erasing.

    Samples one image without and one with the Tree-Ring watermark from the same initial
    latent. For each of the 64 ``IndexedErasing`` positions it then:

    1. erases the same 16x16 patch from both images;
    2. runs ``ddim_reverse_sample_loop_progressive`` on each and keeps the final sample;
    3. renders a row of panels: inputs, their deltas, the reversed samples, their delta,
       and the centred FFT magnitude of that delta.

    Every frame is displayed inline and the frames are saved as an animated GIF.

    Args:
        model, diffusion: Guided diffusion model and its diffusion object.
        prompt: ImageNet class id in [0, 1000) (the model is class-conditional).
        key: Watermark key passed to ``inject_watermark``.
        image_size: Height/width of the square samples.
        tree_ring_paras: Tree-Ring hyper-parameters; wrapped in a namedtuple for
            ``inject_watermark``.
        init_latents_w: Initial latent shared by both runs; expected to have shape
            (1, 3, image_size, image_size) -- batch size is fixed to one.
        watermarking_mask: Mask passed to ``inject_watermark``.
        diffusion_seed: Seed set via ``set_random_seed`` before any sampling.
        gif_path: Where to write the GIF. Defaults to
            ``vis_{image_size}x{image_size}_{prompt}_{diffusion_seed}.gif`` in the
            working directory.

    Returns:
        None. Side effects: displays figures and writes the GIF.
    """
    set_random_seed(diffusion_seed)
    # For this class-conditioned diffusion model, prompts are just class ids
    assert isinstance(prompt, int) and 0 <= prompt < 1000
    # For simplicity, fix batch size to one
    shape = (1, 3, image_size, image_size)
    tree_ring_args = namedtuple("Args", tree_ring_paras.keys())(**tree_ring_paras)
    # Arguments shared by every forward and reverse sampling call below
    sampler_kwargs = dict(
        model=model,
        shape=shape,
        clip_denoised=True,
        model_kwargs=dict(y=torch.tensor([prompt], device=device)),
        device=device,
    )
    # Undo a mean=0.5 / std=0.5 normalization: maps [-1, 1] (assumed sample range) to [0, 1]
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    unnormalize = transforms.Normalize(
        (-mean[0] / std[0], -mean[1] / std[1], -mean[2] / std[2]),
        (1 / std[0], 1 / std[1], 1 / std[2]),
    )
    to_tensor = transforms.ToTensor()
    to_pil = transforms.ToPILImage()

    # Watermarked version of the shared initial latent, computed once so both watermarked
    # runs below start from exactly the same noise
    init_latents_wm = inject_watermark(init_latents_w, watermarking_mask, key, tree_ring_args)

    # First yielded sample of each progressive run -> amplified pixel delta ("inited" panel)
    no_wm_first, wm_first = next(
        zip(
            diffusion.ddim_sample_loop_progressive(noise=init_latents_w, **sampler_kwargs),
            diffusion.ddim_sample_loop_progressive(noise=init_latents_wm, **sampler_kwargs),
        )
    )
    # Clamp to [0, 1]: the x10 amplification can exceed 1, and ToPILImage's uint8
    # conversion would otherwise wrap those values around
    diff_init = torch.clamp(
        torch.abs(unnormalize(no_wm_first["sample"][0]) - unnormalize(wm_first["sample"][0])) * 10,
        0,
        1,
    )
    diff_init_pil = to_pil(diff_init)  # identical for every frame, so convert once

    # Full sampling runs returning final images
    no_wm_image = diffusion.ddim_sample_loop(
        noise=init_latents_w, return_image=True, **sampler_kwargs
    )[0]
    wm_image = diffusion.ddim_sample_loop(
        noise=init_latents_wm, return_image=True, **sampler_kwargs
    )[0]

    frames = []
    for index in trange(8 * 8):
        # Erase the same 16x16 patch from both images and from the "inited" delta
        no_wm_aug, wm_aug, diff_init_aug = paired_transforms(
            no_wm_image, wm_image, diff_init_pil, type="IndexedErasing", index=index
        )

        # Reverse diffusion on both erased images; only the final sample of each is kept
        no_wm_last, wm_last = _last_item(
            zip(
                diffusion.ddim_reverse_sample_loop_progressive(image=no_wm_aug, **sampler_kwargs),
                diffusion.ddim_reverse_sample_loop_progressive(image=wm_aug, **sampler_kwargs),
            )
        )
        no_wm_reversed = unnormalize(no_wm_last["sample"][0])
        wm_reversed = unnormalize(wm_last["sample"][0])
        no_wm_aug_t, wm_aug_t = to_tensor(no_wm_aug), to_tensor(wm_aug)

        fig = visualize_images(
            [
                [
                    no_wm_aug_t,
                    wm_aug_t,
                    to_tensor(diff_init_aug),
                    torch.abs(no_wm_aug_t - wm_aug_t) * 10,
                    no_wm_reversed,
                    wm_reversed,
                    torch.abs(no_wm_reversed - wm_reversed) * 10,
                    _centered_fft_magnitude(no_wm_reversed - wm_reversed),
                ]
            ],
            [
                "w/o watermark",
                "w/ watermark",
                "delta*10 (inited)",
                "delta*10",
                "w/o watermark (reversed)",
                "w/ watermark (reversed)",
                "delta*10 (reversed)",
                "fft-delta (reversed, normd)",
            ],
            [""],
            fontsize=10,
        )

        # Rasterize the figure in memory so it can become a GIF frame
        with BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=200, bbox_inches="tight")
            buf.seek(0)
            frames.append(imageio.imread(buf))
        # `plt.show()` does not accept a figure; display this one explicitly, then free it
        display(fig)
        plt.close(fig)

    if gif_path is None:
        gif_path = f"vis_{image_size}x{image_size}_{prompt}_{diffusion_seed}.gif"
    imageio.mimsave(gif_path, frames, loop=1, duration=0.5)
    return None


# Inputs for the patch-erasing reverse-diffusion visualization: class 0, diffusion seed 3
set_random_seed(0)
key = create_key(key_seed=0, image_size=image_size, tree_ring_paras=tree_ring_paras)
shape = (1, 3, image_size, image_size)  # batch of one
tree_ring_args = namedtuple("Args", tree_ring_paras.keys())(**tree_ring_paras)
# Shared initial latent; generate_and_compare_reverse injects the watermark into it itself
init_latents_w = torch.randn(shape, device=device)
watermarking_mask = get_watermarking_mask(init_latents_w, tree_ring_args, device=device)

generate_and_compare_reverse(
    model,
    diffusion,
    prompt=0,
    key=key,
    image_size=image_size,
    tree_ring_paras=tree_ring_paras,
    init_latents_w=init_latents_w,
    watermarking_mask=watermarking_mask,
    diffusion_seed=3,
)