# In [None]:
import torch
import numpy as np
from tqdm.auto import trange
from dev import complex_l1, GROUND_TRUTH_MESSAGES


def get_tree_ring_key(size=64, radius=10, channel=3):
    """Build a boolean tree-ring key mask of shape ``(1, 4, size, size)``.

    Only ``channel`` is populated. It holds a filled disc of the given
    ``radius`` centred at column ``size // 2``. The row axis is flipped before
    the distance test, so the centre row is ``size - 1 - size // 2``. All
    other channels stay ``False``.
    """
    centre = size // 2
    rows, cols = np.ogrid[:size, :size]
    # Row coordinates are reversed (bottom-to-top), as in the original mask
    # construction; keep this so the key lines up with existing messages.
    flipped_rows = rows[::-1]
    inside = (cols - centre) ** 2 + (flipped_rows - centre) ** 2 <= radius**2

    key = torch.zeros((1, 4, size, size), dtype=torch.bool)
    key[:, channel] = torch.tensor(inside)
    return key


def decode_tree_ring(reversed_latents, key):
    """Extract the masked Fourier coefficients of ``reversed_latents``.

    Takes the 2-D FFT over the last two dimensions and centres it with
    ``fftshift``. It then selects the entries where the boolean ``key`` is
    True and flattens them.

    Returns a 1-D NumPy array on the CPU: the real parts of the selected
    coefficients followed by their imaginary parts.

    NOTE(review): ``key`` presumably has the same shape as
    ``reversed_latents`` (e.g. the ``(1, 4, H, W)`` mask from
    ``get_tree_ring_key``); confirm against the saved latents.
    """
    spectrum = torch.fft.fft2(reversed_latents)
    centred = torch.fft.fftshift(spectrum, dim=(-1, -2))
    coeffs = centred[key].flatten()
    stacked = torch.cat((coeffs.real, coeffs.imag))
    return stacked.cpu().numpy()


# Root directory of the inverted latents. Each subset folder ("tree_ring",
# "real") holds files named "{i}_reversed.pkl".
DATA_ROOT = "/fs/nexus-projects/HuangWM/datasets/main/diffusiondb"
NUM_IMAGES = 200


def compute_tree_ring_distances(subset, key, num_images=NUM_IMAGES):
    """Score each reversed latent in ``DATA_ROOT/subset`` against the tree-ring message.

    For ``i`` in ``range(num_images)``, this loads
    ``{DATA_ROOT}/{subset}/{i}_reversed.pkl`` with ``torch.load`` and decodes
    the masked Fourier coefficients with ``decode_tree_ring``. It then compares
    them to ``GROUND_TRUTH_MESSAGES["tree_ring"]`` using ``complex_l1``.

    Returns the list of per-image ``complex_l1`` values, in file-index order.
    """
    distances = []
    for i in trange(num_images, desc=subset):
        filename = f"{DATA_ROOT}/{subset}/{i}_reversed.pkl"
        # NOTE(review): torch.load unpickles the file; only use on trusted data.
        reversed_latents = torch.load(filename)
        decoded_message = decode_tree_ring(reversed_latents, key)
        distances.append(
            complex_l1(decoded_message, GROUND_TRUTH_MESSAGES["tree_ring"])
        )
    return distances


key = get_tree_ring_key()

# Keep both result lists so they can be compared afterwards (e.g. to pick a
# detection threshold) instead of overwriting the watermarked scores.
watermarked_distances = compute_tree_ring_distances("tree_ring", key)
print(f"Average complex L1 of watermarked images: {np.mean(watermarked_distances)}")

real_distances = compute_tree_ring_distances("real", key)
print(f"Average complex L1 of real images: {np.mean(real_distances)}")

# Backward compatibility: the original cell left `distances` bound to the
# real-image results, and later cells may rely on that name.
distances = real_distances