In [1]:
import torch
from tqdm.auto import tqdm
import numpy as np
from dev import (
    GROUND_TRUTH_MESSAGES,
    get_messages_from_json,
    message_distance,
    theoretical_distance_cdf,
)

# "user 1" key: the tree-ring ground-truth message exported by the project's `dev` module.
msg_org = GROUND_TRUTH_MESSAGES["tree_ring"]
print(f"{msg_org.shape} {msg_org.dtype}")

(634,) float16


In [2]:
# "user 2" key: rebuilt from the saved tree-ring key of the *_wm2 run. The key patch is
# stored as separate real / imaginary tensors; the message is the masked real entries
# followed by the masked imaginary entries.
TREE_RING_WM2_DIR = "/fs/nexus-projects/HuangWM/datasets/generated/imagenet/tree_ring_imagenet_512_wm2"

# map_location="cpu" lets these load on CPU-only machines even if they were saved from GPU.
patch = torch.load(f"{TREE_RING_WM2_DIR}/gt_patch.pt", map_location="cpu")
real = patch["real"].numpy()
imag = patch["imag"].numpy()
mask = torch.load(f"{TREE_RING_WM2_DIR}/watermarking_mask.pt", map_location="cpu").numpy()

print(real.shape, real.dtype)
print(imag.shape, imag.dtype)
print(mask.shape, mask.dtype)
# Boolean-mask indexing flattens the selected entries (C order) into a 1-D array.
msg_new = np.concatenate([real[mask], imag[mask]])
print(msg_new.shape, msg_new.dtype)
# Both keys are scored against the same decoded messages below, so they must line up.
assert msg_new.shape == msg_org.shape, f"msg_new {msg_new.shape} != msg_org {msg_org.shape}"

(1, 4, 64, 64) float16
(1, 4, 64, 64) float16
(1, 4, 64, 64) bool
(634,) float16


In [3]:
# Identification accuracy under increasing attack strength.
#
# For each decoded image, measure its distance to a candidate key, turn that into a CDF
# value F via `theoretical_distance_cdf`, and score it as (1 - F) ** (num_users - 1).
# NOTE(review): this reads as "probability that none of the other num_users - 1 random
# keys is at least as close" -- confirm against the `dev` implementation.
# The power is evaluated in log space with log1p, which stays accurate when F is tiny and
# num_users is large (plain log(1.0 - F) rounds 1.0 - F before taking the log).
DECODED_DIR = "/fs/nexus-projects/HuangWM/datasets/decoded/diffusiondb"
ATTACK_STRENGTHS = [0, 2, 4, 6, 8]  # 0 = un-attacked decodes
NUM_USERS_GRID = [100, 1000, 1000000]
CANDIDATE_KEYS = {"user 1": msg_org, "user 2": msg_new}

acc_dict = {}
mode = "tree_ring"
for strength in tqdm(ATTACK_STRENGTHS):
    if strength == 0:
        watermarked_path = f"{DECODED_DIR}/tree_ring-decode.json"
    else:
        watermarked_path = f"{DECODED_DIR}/adv_cls_wm1_wm2_0.01_50_warm-{strength}-tree_ring-decode.json"
    watermarked_messages = get_messages_from_json(watermarked_path, mode)

    # Distances and CDF values depend only on (strength, gt_user), not on num_users,
    # so compute them once per candidate key instead of once per population size.
    log_survival = {}
    for gt_user, reference in CANDIDATE_KEYS.items():
        distances = np.array(
            [
                message_distance(message, reference, "identification")
                for message in watermarked_messages
            ]
        )
        cdf = theoretical_distance_cdf(watermarked_messages, distances, mode)
        log_survival[gt_user] = np.log1p(-np.asarray(cdf, dtype=np.float64))

    # Loop order (num_users outer, gt_user inner) matches the original print/insert order.
    for num_users in NUM_USERS_GRID:
        for gt_user, log_s in log_survival.items():
            acc = np.mean(np.exp((num_users - 1) * log_s))
            acc_dict[(strength, num_users, gt_user)] = acc
            print(
                f"strength: {strength}, num_users: {num_users}, gt_user: {gt_user}, acc: {acc}"
            )

  0%|          | 0/5 [00:00<?, ?it/s]

strength: 0, num_users: 100, gt_user: user 1, acc: 0.9981658661009984
strength: 0, num_users: 100, gt_user: user 2, acc: 4.3489422837110336e-76
strength: 0, num_users: 1000, gt_user: user 1, acc: 0.9966199455324674
strength: 0, num_users: 1000, gt_user: user 2, acc: 0.0
strength: 0, num_users: 1000000, gt_user: user 1, acc: 0.9908251614225024
strength: 0, num_users: 1000000, gt_user: user 2, acc: 0.0
strength: 2, num_users: 100, gt_user: user 1, acc: 0.8330119860288396
strength: 2, num_users: 100, gt_user: user 2, acc: 0.0005147524943915642
strength: 2, num_users: 1000, gt_user: user 1, acc: 0.7764117892334849
strength: 2, num_users: 1000, gt_user: user 2, acc: 1.109519250542755e-06
strength: 2, num_users: 1000000, gt_user: user 1, acc: 0.6470593889987359
strength: 2, num_users: 1000000, gt_user: user 2, acc: 0.0
strength: 4, num_users: 100, gt_user: user 1, acc: 0.6398497235683687
strength: 4, num_users: 100, gt_user: user 2, acc: 0.0028846944494287973
strength: 4, num_users: 1000, gt

In [5]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

# Accuracy of identifying each candidate user vs. attack strength. User 1 uses the left
# (primary) y-axis and user 2 the right (secondary) y-axis, because their scales differ
# by orders of magnitude.
fig = make_subplots(specs=[[{"secondary_y": True}]])

strengths = [0, 2, 4, 6, 8]

# One qualitative colour per candidate user (and per y-axis).
color1, color2 = px.colors.qualitative.T10[:2]

# The secondary lower bound is chosen so that y = 0 sits at the same height on both axes:
# low2 / high2 == low1 / high1  (this reproduces the previous -0.008 / 21).
PRIMARY_Y_RANGE = [-0.05, 1.05]
SECONDARY_Y_MAX = 0.008
secondary_y_range = [
    SECONDARY_Y_MAX * PRIMARY_Y_RANGE[0] / PRIMARY_Y_RANGE[1],
    SECONDARY_Y_MAX,
]

# (num_users, legend suffix, marker symbol): one marker style per population size.
user_counts = [(100, "100 Users", "circle"), (1000, "1K Users", "cross")]
# (acc_dict key, legend name, colour, plot on secondary axis?)
candidates = [
    ("user 1", "User 1", color1, False),
    ("user 2", "User 2", color2, True),
]

for num_users, users_label, symbol in user_counts:
    for gt_user, user_label, color, on_secondary in candidates:
        fig.add_trace(
            go.Scatter(
                x=strengths,
                y=[acc_dict[(strength, num_users, gt_user)] for strength in strengths],
                name=f"Identified as {user_label} ({users_label})",
                marker=dict(size=11, symbol=symbol),
                line=dict(color=color),
            ),
            secondary_y=on_secondary,
        )

# Axis titles. `title_font` replaces the deprecated `titlefont`, which Plotly 6 no longer accepts.
fig.update_xaxes(title_text="Attack Strength")
fig.update_yaxes(
    title_text="Identified as <b>User 1</b>",
    secondary_y=False,
    tickfont=dict(color=color1),
    title_font=dict(color=color1),
)
fig.update_yaxes(
    title_text="Identified as <b>User 2</b>",
    secondary_y=True,
    tickfont=dict(color=color2),
    title_font=dict(color=color2),
)

fig.update_layout(
    yaxis=dict(range=PRIMARY_Y_RANGE),
    yaxis2=dict(range=secondary_y_range),
    height=480,  # pixels; 860x480 leaves room for the legend placed right of the plot
    width=860,
    legend=dict(x=1.1, y=1.0, xanchor="left", orientation="v"),  # legend outside, top-right
)

fig.show()