# In [None]:
import os
import onnxruntime as ort
import numpy as np
from tqdm.auto import trange
from PIL import Image, ImageOps
from space import bit_error_rate, GROUND_TRUTH_MESSAGES


def init_stega_stamp_model(gpu_index: int = 0) -> "ort.InferenceSession":
    """Create a CUDA-backed ONNX Runtime session for the StegaStamp model.

    The model file ``stega_stamp.onnx`` is looked up inside the directory
    named by the ``MODEL_DIR`` environment variable.

    Args:
        gpu_index: CUDA device id handed to the execution provider. Defaults
            to 0, the value that was previously hard-coded.

    Returns:
        An ``onnxruntime.InferenceSession`` that uses only the
        ``CUDAExecutionProvider``.

    Raises:
        RuntimeError: If ONNX Runtime does not report ``"GPU"`` as its device,
            or if ``MODEL_DIR`` is not set.
    """
    # Raise explicitly instead of using ``assert``: assertions are removed
    # when Python runs with ``-O``, which would silently drop this check.
    device = ort.get_device()
    if device != "GPU":
        raise RuntimeError(
            f"onnxruntime reports device {device!r}; a GPU build is required"
        )

    # Fail with a clear message rather than the TypeError that
    # ``os.path.join(None, ...)`` would otherwise produce.
    model_dir = os.environ.get("MODEL_DIR")
    if model_dir is None:
        raise RuntimeError(
            "The MODEL_DIR environment variable must point to the directory "
            "containing stega_stamp.onnx"
        )

    session_options = ort.SessionOptions()
    # Keep ONNX Runtime's own thread pools single-threaded.
    session_options.intra_op_num_threads = 1
    session_options.inter_op_num_threads = 1
    # Severity 3 limits ONNX Runtime logging to errors and above
    # (see the SessionOptions documentation).
    session_options.log_severity_level = 3

    return ort.InferenceSession(
        os.path.join(model_dir, "stega_stamp.onnx"),
        providers=["CUDAExecutionProvider"],
        provider_options=[{"device_id": str(gpu_index)}],
        sess_options=session_options,
    )


def decode_stega_stamp(image, model):
    """Run the StegaStamp session on ``image`` and return the decoded bits.

    Args:
        image: PIL image; it is cropped/resized to 400x400 with
            ``ImageOps.fit`` before being fed to the model.
        model: Object exposing ``run(output_names, input_feed)``, such as the
            session returned by ``init_stega_stamp_model``.

    Returns:
        A 1-D boolean NumPy array obtained by flattening the third model
        output and casting it to ``bool``.
    """
    # NOTE(review): the image mode is used as-is. Presumably callers pass RGB
    # images -- confirm, since other modes give a different channel layout.
    fitted = ImageOps.fit(image, (400, 400))

    # Scale the 0..255 pixel values into [0, 1] as float32.
    pixels = np.array(fitted, dtype=np.float32) / 255.0

    # The model takes a leading batch axis; the all-zero "secret" input is
    # supplied alongside the image because the session is fed both inputs.
    feed = {
        "image": pixels[np.newaxis, ...],
        "secret": np.zeros((1, 100), dtype=np.float32),
    }

    results = model.run(None, feed)
    decoded = results[2]
    return decoded.flatten().astype(bool)


def _stega_stamp_bit_error_rates(subset, model, count=200):
    """Decode ``0.png`` .. ``{count - 1}.png`` of one dataset sub-directory.

    Args:
        subset: Sub-directory name under the diffusiondb dataset root
            (``"stegastamp"`` or ``"real"`` in this script).
        model: Session passed through to ``decode_stega_stamp``.
        count: Number of consecutively numbered images to process.

    Returns:
        A list with one ``bit_error_rate`` result per image, each computed
        against ``GROUND_TRUTH_MESSAGES["stegastamp"]``.
    """
    dataset_root = "/fs/nexus-projects/HuangWM/datasets/main/diffusiondb"
    reference = GROUND_TRUTH_MESSAGES["stegastamp"]
    rates = []
    for index in trange(count):
        filename = f"{dataset_root}/{subset}/{index}.png"
        # The context manager closes the image file as soon as it is decoded.
        with Image.open(filename) as image:
            decoded_message = decode_stega_stamp(image, model)
        rates.append(bit_error_rate(decoded_message, reference))
    return rates


model = init_stega_stamp_model()

# Images from the "stegastamp" directory, compared with the StegaStamp
# ground-truth message.
distances = _stega_stamp_bit_error_rates("stegastamp", model)
print(f"Average bit error rate of watermarked images: {np.mean(distances)}")

# Images from the "real" directory, compared with the same message.
distances = _stega_stamp_bit_error_rates("real", model)
print(f"Average bit error rate of real images: {np.mean(distances)}")