In [None]:
import os
import torch
import numpy as np
from tqdm.auto import trange
from PIL import Image
from utils import to_tensor
from space import bit_error_rate, STABLE_SIGNATURE_MESSAGE


def init_stable_signature_model(model_dir=None):
    """Load the TorchScript Stable Signature decoder onto the GPU.

    Args:
        model_dir: Directory containing ``stable_signature.pt``. When ``None``
            (the default), the ``MODEL_DIR`` environment variable is used.

    Returns:
        The loaded scripted module after ``torch.jit.optimize_for_inference``.

    Raises:
        RuntimeError: If no model directory is configured, or if CUDA is not
            available (the weights are mapped directly onto ``"cuda"``).
    """
    if model_dir is None:
        model_dir = os.environ.get("MODEL_DIR")
    if model_dir is None:
        # Without this check os.path.join(None, ...) raises an opaque TypeError.
        raise RuntimeError(
            "Model directory not set: pass model_dir or set the MODEL_DIR "
            "environment variable."
        )
    # Explicit check rather than `assert`, which is removed under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required to run the Stable Signature decoder.")
    model = torch.jit.load(
        os.path.join(model_dir, "stable_signature.pt"),
        map_location="cuda",
    )
    return torch.jit.optimize_for_inference(model)


def decode_stable_signature(image, model):
    """Run *model* on a single image and return its decoded bits.

    The image is wrapped in a one-element list and converted with the
    project's ``to_tensor`` helper using ImageNet normalization, then moved to
    the GPU. Each model output is thresholded at zero, singleton dimensions
    are squeezed away, and the result comes back as a NumPy ``bool`` array.
    """
    batch = to_tensor([image], norm_type="imagenet").to("cuda")
    logits = model(batch)
    # A positive logit decodes as a 1-bit, anything else as a 0-bit.
    bits = (logits > 0).squeeze()
    return bits.cpu().numpy().astype(bool)


# Root of the evaluation images; each subset holds files named 0.png, 1.png, ...
DATASET_ROOT = "/fs/nexus-projects/HuangWM/datasets/main/diffusiondb"
NUM_IMAGES = 200


def _bit_error_rates(image_dir, model, num_images=NUM_IMAGES):
    """Return the bit error rate against STABLE_SIGNATURE_MESSAGE for each image.

    Reads ``{i}.png`` for ``i`` in ``range(num_images)`` from *image_dir*,
    decodes it with *model*, and compares the decoded bits with
    ``STABLE_SIGNATURE_MESSAGE``.
    """
    rates = []
    for i in trange(num_images):
        path = os.path.join(image_dir, f"{i}.png")
        # Context manager closes the file handle; convert("RGB") forces a full
        # load and normalizes RGBA/palette/grayscale PNGs to 3 channels.
        # Assumes the decoder expects RGB input, which the 3-channel ImageNet
        # normalization suggests. TODO confirm against utils.to_tensor.
        with Image.open(path) as image:
            rgb_image = image.convert("RGB")
        decoded_message = decode_stable_signature(rgb_image, model)
        rates.append(bit_error_rate(decoded_message, STABLE_SIGNATURE_MESSAGE))
    return rates


model = init_stable_signature_model()

distances = _bit_error_rates(os.path.join(DATASET_ROOT, "stable_sig"), model)
print(f"Average bit error rate of watermarked images: {np.mean(distances)}")

distances = _bit_error_rates(os.path.join(DATASET_ROOT, "real"), model)
print(f"Average bit error rate of real images: {np.mean(distances)}")