In [None]:
import os
import onnxruntime as ort
import numpy as np
from tqdm.auto import trange
from PIL import Image
from dev import bit_error_rate, GROUND_TRUTH_MESSAGES


def init_stable_signature_model(gpu_index=0):
    """Create an ONNX Runtime session for the Stable Signature decoder.

    The model is loaded from ``$MODEL_DIR/stable_signature.onnx`` and run
    with the CUDA execution provider on a single device.

    Args:
        gpu_index: CUDA device ordinal to run the session on (default 0).

    Returns:
        An ``onnxruntime.InferenceSession`` for the decoder.

    Raises:
        RuntimeError: if the installed onnxruntime build has no GPU support,
            or if the ``MODEL_DIR`` environment variable is unset or empty.
    """
    # An explicit check instead of `assert`: asserts are stripped under -O.
    if ort.get_device() != "GPU":
        raise RuntimeError(
            "onnxruntime reports no GPU device; a GPU build of onnxruntime "
            "is required for the CUDAExecutionProvider"
        )
    model_dir = os.environ.get("MODEL_DIR")
    if not model_dir:
        # Without this check os.path.join(None, ...) fails with an opaque TypeError.
        raise RuntimeError("MODEL_DIR environment variable is not set")

    session_options = ort.SessionOptions()
    session_options.intra_op_num_threads = 1
    session_options.inter_op_num_threads = 1
    # ONNX Runtime severity 3 == ERROR: suppress verbose/info/warning logs.
    session_options.log_severity_level = 3
    return ort.InferenceSession(
        os.path.join(model_dir, "stable_signature.onnx"),
        providers=["CUDAExecutionProvider"],
        provider_options=[{"device_id": str(gpu_index)}],
        sess_options=session_options,
    )


def decode_stable_signature(image, model):
    """Decode the Stable Signature watermark bits from a single image.

    The pixels are processed as follows:

    1. Scale to [0, 1].
    2. Normalize per channel with the standard ImageNet mean/std.
    3. Reorder to a 1xCxHxW float32 batch.
    4. Feed the batch to ``model`` under the input name ``"image"``.

    Every output value greater than 0 is read as a 1 bit.

    Args:
        image: a PIL image (converted to RGB first, so grayscale, palette and
            RGBA inputs are accepted) or an HxWx3 array of 8-bit pixel values.
        model: an object exposing an ONNX Runtime-style
            ``run(output_names, inputs)`` method, e.g. the session returned by
            ``init_stable_signature_model``.

    Returns:
        1-D boolean numpy array with one entry per decoded bit.
    """
    # Non-RGB PIL modes would give a 2-D or 4-channel array and break the
    # 3-channel normalization / (2, 0, 1) transpose below.
    if hasattr(image, "convert"):
        image = image.convert("RGB")

    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    pixels = np.asarray(image, dtype=np.float32) / 255.0
    pixels = (pixels - mean) / std

    # HWC -> 1xCxHxW; transpose yields a strided view, so hand ONNX Runtime a
    # C-contiguous buffer explicitly.
    batch = np.ascontiguousarray(pixels.transpose((2, 0, 1))[np.newaxis])

    outputs = model.run(None, {"image": batch})
    return (np.asarray(outputs[0]) > 0).flatten()


DATASET_DIR = "/fs/nexus-projects/HuangWM/datasets/main/diffusiondb"
NUM_IMAGES = 200


def _bit_error_rates(subdir, model, num_images=NUM_IMAGES):
    """Return per-image bit error rates for ``DATASET_DIR/<subdir>/{i}.png``.

    Each image is decoded with ``decode_stable_signature``. The decoded bits
    are compared via ``bit_error_rate`` against
    ``GROUND_TRUTH_MESSAGES["stable_sig"]``.

    Args:
        subdir: dataset subdirectory (e.g. ``"stable_sig"`` or ``"real"``).
        model: session passed through to ``decode_stable_signature``.
        num_images: number of images ``0.png .. {num_images-1}.png`` to read.

    Returns:
        List of bit error rates, one per image, in index order.
    """
    expected = GROUND_TRUTH_MESSAGES["stable_sig"]
    rates = []
    for i in trange(num_images):
        path = os.path.join(DATASET_DIR, subdir, f"{i}.png")
        # Close the file handle deterministically rather than relying on GC.
        with Image.open(path) as image:
            decoded_message = decode_stable_signature(image, model)
        rates.append(bit_error_rate(decoded_message, expected))
    return rates


model = init_stable_signature_model()

distances = _bit_error_rates("stable_sig", model)
print(f"Average bit error rate of watermarked images: {np.mean(distances)}")

distances = _bit_error_rates("real", model)
print(f"Average bit error rate of real images: {np.mean(distances)}")