In [1]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
from IPython.display import display, HTML


# Quality-metric column names in the results frame -> display names used as axis /
# trace labels. "watson" and "clip_score" are skipped by the CDF and normalization
# cells below.
QUALITY_METRICS = {
    "legacy_fid": "FID",
    "clip_fid": "CLIP FID",
    "psnr": "PSNR",
    "ssim": "SSIM",
    "nmi": "Normed Mutual-Info",
    "lpips": "LPIPS",
    "watson": "Watson-DFT",
    "aesthetics": "Delta Aesthetics",
    "artifacts": "Delta Artifacts",
    "clip_score": "Delta CLIP-Score",
}

# Values of the ``attack_name`` column -> short display names for plots and the
# leaderboard. Iteration order also fixes each attack's plot colour. The prefix of a key
# chooses the marker shape: "dist..." -> square, "adv..." -> star, anything else -> x.
ATTACK_NAMES = {
    "distortion_single_rotation": "Dist-Rotation",
    "distortion_single_resizedcrop": "Dist-RCrop",
    "distortion_single_erasing": "Dist-Erase",
    "distortion_single_brightness": "Dist-Bright",
    "distortion_single_contrast": "Dist-Contrast",
    "distortion_single_blurring": "Dist-Blur",
    "distortion_single_noise": "Dist-Noise",
    "distortion_single_jpeg": "Dist-JPEG",
    "distortion_combo_geometric": "DistCom-Geo",
    "distortion_combo_photometric": "DistCom-Photo",
    "distortion_combo_degradation": "DistCom-Deg",
    "distortion_combo_all": "DistCom-All",
    "regen_diffusion": "Regen-Diff",
    "regen_diffusion_prompt": "Regen-DiffP",
    "regen_vae": "Regen-VAE",
    "kl_vae": "Regen-KLVAE",
    "2x_regen": "Rinse-2xDiff",
    "4x_regen": "Rinse-4xDiff",
    "4x_regen_bmshj": "RinseD-VAE",
    "4x_regen_kl_vae": "RinseD-KLVAE",
    "adv_emb_same_vae_untg": "AdvEmbG-KLVAE8",
    "adv_emb_resnet18_untg": "AdvEmbB-RN18",
    "adv_emb_clip_untg_alphaRatio_0.05_step_200": "AdvEmbB-CLIP",
    "adv_emb_klf16_vae_untg": "AdvEmbB-KLVAE16",
    "adv_emb_sdxl_vae_untg": "AdvEmbB-SdxlVAE",
    "adv_cls_unwm_wm_0.01_50_warm_train3k": "AdvCls-UnWM&WM",
    "adv_cls_real_wm_0.01_50_warm": "AdvCls-Real&WM",
    "adv_cls_wm1_wm2_0.01_50_warm": "AdvCls-WM1&WM2",
}

# Values of the ``source_name`` column (the watermarking method) -> display names.
WATERMARK_METHODS = {
    "tree_ring": "Tree-Ring",
    "stable_sig": "Stable-Signature",
    "stegastamp": "Stega-Stamp",
}

# Performance-metric column names -> y-axis titles. The active one is chosen by
# ``performance_metric`` in the configuration below.
PERFORMANCE_METRICS = {
    "acc_1": "Mean Detection Accuracy",
    "auc_1": "AUC",
    "low100_1": "TPR@1%FPR",
    "low1000_1": "TPR@0.1%FPR",
    "acc_100": "Identification Accuracy (100 Users)",
    "acc_1000": "Identification Accuracy (1K Users)",
    "acc_1000000": "Identification Accuracy (1M Users)",
}


# Load all evaluation results. NOTE(review): unpickling can execute arbitrary code
# embedded in the file -- only load a trusted "all_result.pkl".
df = pd.read_pickle("all_result.pkl")
# Cumulative-probability levels at which each quality metric's empirical CDF is
# inverted; the two resulting values anchor the quality normalization.
quality_threshold_low, quality_threshold_high = 0.1, 0.9

# ---- User-tunable settings ------------------------------------------------
# performance_metric: key of PERFORMANCE_METRICS used for the y-axis and the leaderboard.
# performance_threshold_low/high: performance levels at which the leaderboard reads off
#   the normalized quality (the "Q@<x>P" columns).
#############################################
performance_metric = "acc_1000000"  # alternative: "low1000_1"
performance_threshold_low, performance_threshold_high = 0.4, 0.7  # alternative: 0.7, 0.95

In [2]:
# Extract the quality metrics of the removal setup (skipping abnormal None values)
# as a dictionary of lists. Metrics in NEGATED_QUALITY_METRICS are sign-flipped and
# labelled with a leading "- " (presumably so that a larger value consistently means a
# larger quality change for every metric -- TODO confirm).
NEGATED_QUALITY_METRICS = ["psnr", "ssim", "nmi", "artifacts"]
EXCLUDED_QUALITY_METRICS = ["watson", "clip_score"]


def _quality_label(key):
    """Display label for a quality metric, prefixed with "- " when it is negated."""
    return ("- " if key in NEGATED_QUALITY_METRICS else "") + QUALITY_METRICS[key]


removal_results = df[df["eval_setup"] == "removal"]
qualities = {
    k: [
        -v[0] if k in NEGATED_QUALITY_METRICS else v[0]
        for v in removal_results[k]
        if v is not None
    ]
    for k in QUALITY_METRICS.keys()
    if k not in EXCLUDED_QUALITY_METRICS
}

# Empirical CDF of every metric plus its inverse at both thresholds. Computed once and
# reused by the plotting pass and the annotation pass below.
cdf_stats = {}
for key, values in qualities.items():
    if not values:
        raise ValueError(
            f"No non-None values for quality metric {key!r} in the removal setup"
        )
    values_sorted = np.sort(values)
    cdf = np.arange(1, len(values) + 1) / len(values)
    cdf_stats[key] = (
        values_sorted,
        cdf,
        np.interp(quality_threshold_low, cdf, values_sorted),
        np.interp(quality_threshold_high, cdf, values_sorted),
    )

# One subplot per metric, all sharing the cumulative-probability y-axis.
num_types = len(qualities)
fig = make_subplots(rows=1, cols=num_types, shared_yaxes=True)

for i, (key, (values_sorted, cdf, intersect_a, intersect_b)) in enumerate(
    cdf_stats.items(), start=1
):
    label = _quality_label(key)
    fig.add_trace(
        go.Scatter(x=values_sorted, y=cdf, mode="lines", name=label),
        row=1,
        col=i,
    )
    fig.update_xaxes(title_text=label, row=1, col=i)

    # Dashed drop lines from each threshold's CDF intersection down to the x-axis.
    for x_at, y_top, color in (
        (intersect_a, quality_threshold_low, "red"),
        (intersect_b, quality_threshold_high, "blue"),
    ):
        fig.add_trace(
            go.Scatter(
                x=[x_at, x_at],
                y=[0, y_top],
                mode="lines",
                line=dict(color=color, dash="dash"),
                showlegend=False,
            ),
            row=1,
            col=i,
        )

# Dotted horizontal reference lines at both thresholds. ":g" keeps float noise out of
# the label (0.07 * 100 would otherwise print as "7.000000000000001%").
fig.add_hline(
    y=quality_threshold_low,
    line_dash="dot",
    annotation_text=f"{quality_threshold_low * 100:g}%",
    annotation_position="bottom right",
)
fig.add_hline(
    y=quality_threshold_high,
    line_dash="dot",
    annotation_text=f"{quality_threshold_high * 100:g}%",
    annotation_position="top right",
)

fig.update_layout(
    title="Cumulative Distribution Function Plot with Independent X-Axes",
    yaxis_title="Cumulative Probability",
    yaxis=dict(tickformat=".0%", range=[0, 1]),  # Setting y-axis as percentage
)

# Label each intersection and record, per metric, the two raw values that the quality
# normalization maps onto quality_threshold_low / quality_threshold_high. Negated
# metrics are flipped back into their original units.
ranges = {}
for i, (key, (_, _, intersect_a, intersect_b)) in enumerate(cdf_stats.items(), start=1):
    xshift = -10 if i > 1 else 10  # Adjust position to avoid overlap
    fig.add_annotation(
        x=intersect_a,
        y=0.05,
        xref=f"x{i}",
        yref="y",
        text=f"<b>{intersect_a:.2f}</b>",
        showarrow=False,
        font=dict(color="red", size=12),
        xshift=xshift,
    )
    fig.add_annotation(
        x=intersect_b,
        y=0,
        xref=f"x{i}",
        yref="y",
        text=f"<b>{intersect_b:.2f}</b>",
        showarrow=False,
        font=dict(color="blue", size=12),
        xshift=xshift,
    )
    ranges[key] = (
        (-intersect_a, -intersect_b)
        if key in NEGATED_QUALITY_METRICS
        else (intersect_a, intersect_b)
    )

In [3]:
def normalized_quality(
    qualities,
    quality_ranges=None,
    threshold_low=None,
    threshold_high=None,
    weights=None,
):
    """Combine per-metric quality values into one weighted, normalized score per sample.

    Each value of metric ``key`` is mapped linearly so that ``quality_ranges[key][0]``
    lands on ``threshold_low`` and ``quality_ranges[key][1]`` lands on ``threshold_high``.
    The mapped values are then summed with per-metric weights.

    Args:
        qualities: dict metric-name -> list of raw values. All lists must have the same
            length; position ``idx`` in every list is assumed to describe the same sample.
        quality_ranges: dict metric-name -> (anchor_low, anchor_high). Defaults to the
            module-level ``ranges``.
        threshold_low, threshold_high: target values for the two anchors. Default to the
            module-level ``quality_threshold_low`` / ``quality_threshold_high``.
        weights: dict metric-name -> weight. The default gives each of four groups
            (FIDs; PSNR/SSIM/NMI; LPIPS; aesthetics/artifacts) a total weight of 1/4,
            split evenly inside the group.

    Returns:
        List of floats, one per sample; ``[]`` when ``qualities`` is empty.

    Raises:
        KeyError: a metric in ``qualities`` has no range or weight.
        ValueError: the per-metric lists differ in length.
    """
    if not qualities:
        return []

    # Fall back to the notebook globals so existing single-argument calls are unchanged.
    if quality_ranges is None:
        quality_ranges = ranges
    if threshold_low is None:
        threshold_low = quality_threshold_low
    if threshold_high is None:
        threshold_high = quality_threshold_high
    if weights is None:
        weights = {
            "legacy_fid": 1 / 4 / 2,
            "clip_fid": 1 / 4 / 2,
            "psnr": 1 / 4 / 3,
            "ssim": 1 / 4 / 3,
            "nmi": 1 / 4 / 3,
            "lpips": 1 / 4 / 1,
            "aesthetics": 1 / 4 / 2,
            "artifacts": 1 / 4 / 2,
        }

    # Samples are aligned by index across metrics; unequal lengths would silently
    # misalign or truncate them.
    lengths = {len(values) for values in qualities.values()}
    if len(lengths) != 1:
        raise ValueError(
            f"All quality lists must have the same length, got lengths {sorted(lengths)}"
        )
    num_samples = lengths.pop()

    span = threshold_high - threshold_low
    normed_qualities = {
        key: [
            (value - quality_ranges[key][0])
            / (quality_ranges[key][1] - quality_ranges[key][0])
            * span
            + threshold_low
            for value in values
        ]
        for key, values in qualities.items()
    }
    return [
        sum(weights[key] * normed_qualities[key][idx] for key in normed_qualities)
        for idx in range(num_samples)
    ]


def modified_interp(x, xp, fp):
    """Linearly interpolate ``fp`` at ``x`` with infinite sentinels outside ``xp``'s range.

    Args:
        x: scalar query point.
        xp: sample x-coordinates (any array-like, e.g. a possibly reversed pandas Series).
            They need not be sorted; they are stably sorted before interpolation.
        fp: sample y-coordinates, aligned element-wise with ``xp``.

    Returns:
        ``np.inf`` if ``x`` is below every ``xp``; ``-np.inf`` if ``x`` is above every
        ``xp``; otherwise the piecewise-linear interpolation from ``np.interp``.

    Raises:
        ValueError: ``xp`` is empty.
    """
    xp = np.asarray(xp, dtype=float)
    fp = np.asarray(fp, dtype=float)
    if xp.size == 0:
        raise ValueError("modified_interp() requires at least one sample point")
    if x < xp.min():
        return np.inf
    if x > xp.max():
        return -np.inf
    # np.interp does not check monotonicity and returns nonsense for non-increasing xp.
    # A stable sort is a no-op for already increasing data and keeps equal xp in order.
    order = np.argsort(xp, kind="stable")
    return np.interp(x, xp[order], fp[order])


def rank_with_ties(column, threshold=0.01):
    """Rank values ascending (smallest -> 1), merging near-equal neighbours into one rank.

    Distinct values get consecutive ranks in sorted order. A distinct value that is
    less than ``threshold`` above the previous distinct value inherits that value's
    rank. Merges chain: a run of values that are each close to the previous one all
    share the first value's rank. Ranks are not renumbered afterwards, so gaps remain.

    Args:
        column: iterable of numeric values (list, array or pandas Series).
        threshold: maximum gap between sorted neighbours that still counts as a tie.

    Returns:
        List of integer ranks in the original order of ``column``.
    """
    rank_of = {}
    previous = None
    # np.unique already yields the distinct values in ascending order.
    for position, current in enumerate(np.unique(column), start=1):
        tied = previous is not None and current - previous < threshold
        rank_of[current] = rank_of[previous] if tied else position
        previous = current
    return [rank_of[item] for item in column]


def _attack_marker(attack_name):
    """Plot marker symbol by attack family: distortion, adversarial, or other."""
    if attack_name.startswith("dist"):
        return "square"
    if attack_name.startswith("adv"):
        return "star"
    return "x"


def _attack_curve(df_sel, attack_name):
    """Per-strength means of the performance metric and normalized quality for one attack,
    ordered by the numeric value of ``attack_strength``."""
    df_cat = df_sel[df_sel["attack_name"] == attack_name].drop(columns=["attack_name"])
    df_cat = (
        df_cat.groupby("attack_strength")[[performance_metric, "normalized_quality"]]
        .mean()
        .reset_index()
    )
    return (
        df_cat.assign(attack_strength_float=df_cat["attack_strength"].astype(float))
        .sort_values("attack_strength_float")
        .drop(columns=["attack_strength_float"])
    )


def _rank_leaderboard(leaderboard_rows, rank_columns, watermark_method):
    """Rank attacks lexicographically on ``rank_columns`` (each ascending, near-ties merged)
    and put the Watermark / Attack / Rank columns first."""
    leaderboard_df = pd.DataFrame(leaderboard_rows)
    rank_keys = [column + "_rank" for column in rank_columns]
    for column, rank_key in zip(rank_columns, rank_keys):
        leaderboard_df[rank_key] = rank_with_ties(leaderboard_df[column])
    leaderboard_df_sorted = leaderboard_df.sort_values(by=rank_keys, ascending=True)
    # Overall rank: attacks tie only if all per-column ranks tie.
    leaderboard_df_sorted["Rank"] = (
        leaderboard_df_sorted[rank_keys]
        .apply(tuple, axis=1)
        .rank(method="min")
        .astype(int)
    )
    leaderboard_df_sorted["Watermark"] = WATERMARK_METHODS[watermark_method]
    leading = ["Watermark", "Attack", "Rank"]
    return leaderboard_df_sorted[
        leading + [col for col in leaderboard_df_sorted.columns if col not in leading]
    ]


def plot_aggregated_2d_plot(watermark_method):
    """Plot performance vs. normalized quality for every attack on one watermark method,
    and build that method's attack leaderboard.

    Reads the module-level globals ``df``, ``performance_metric``,
    ``performance_threshold_low`` / ``performance_threshold_high``, ``ATTACK_NAMES``,
    ``WATERMARK_METHODS`` and ``PERFORMANCE_METRICS``.

    Args:
        watermark_method: key of ``WATERMARK_METHODS``; matched against ``source_name``.

    Returns:
        ``(leaderboard, fig)``.
        ``leaderboard`` is a DataFrame sorted by rank. Its columns are Watermark, Attack,
        Rank, the two ``Q@<x>P`` columns, AvgP, AvgQ and their ``*_rank`` columns.
        A ``Q@<x>P`` value is the normalized quality at which the interpolated
        performance curve reaches ``x``. It is ``inf`` if performance never drops that
        low, and ``-inf`` if it is already below ``x`` at every strength.
        ``fig`` is the plotly Figure.

    Raises:
        ValueError: no attack has any complete removal result for this method.
    """
    # acc_100 is dropped before dropna() (presumably it contains missing values that
    # would discard whole rows -- TODO confirm), unless it is the metric being plotted.
    dropped_columns = [col for col in ["acc_100"] if col != performance_metric]
    df_sel = df.drop(columns=dropped_columns)
    df_sel = df_sel[df_sel["eval_setup"] == "removal"].dropna()
    df_sel = df_sel[df_sel["source_name"] == watermark_method]

    # normalized_quality cells hold indexable containers; keep their first element.
    # (To recompute instead, pass per-metric value lists to normalized_quality().)
    df_sel = df_sel.assign(
        normalized_quality=[v[0] for v in df_sel["normalized_quality"]]
    )
    df_sel = df_sel[
        [
            "dataset_name",
            "attack_name",
            "attack_strength",
            performance_metric,
            "normalized_quality",
        ]
    ]

    colors = (
        px.colors.qualitative.Plotly
        + px.colors.qualitative.D3
        + px.colors.qualitative.G10
    )
    show_text = False
    line_width = 3
    marker_size = 9
    tick_size = 10
    legend_fontsize = 15
    plot_height = 800
    threshold_line_x = [0, 1.6]  # x-extent of the dashed performance-threshold lines

    # Defined before the loop so they exist even if every attack is skipped.
    performance_threshold_low_str = f"Q@{performance_threshold_low:.2f}P"
    performance_threshold_high_str = f"Q@{performance_threshold_high:.2f}P"

    fig = go.Figure()
    leaderboard_rows = []
    for i, attack_name in enumerate(ATTACK_NAMES.keys()):
        if attack_name in ["4x_regen_bmshj", "4x_regen_kl_vae"]:
            continue
        df_cat = _attack_curve(df_sel, attack_name)
        if df_cat.empty:
            # No complete results for this attack on this watermark: nothing to plot,
            # and modified_interp() cannot interpolate over zero points.
            continue

        fig.add_trace(
            go.Scatter(
                x=df_cat["normalized_quality"],
                y=df_cat[performance_metric],
                text=df_cat["attack_strength"],
                mode="lines+markers+text" if show_text else "lines+markers",
                name=ATTACK_NAMES[attack_name],
                # Colour is keyed on the attack's position in ATTACK_NAMES, so it is
                # stable across watermark methods.
                line=dict(color=colors[i % len(colors)], width=line_width),
                marker=dict(symbol=_attack_marker(attack_name), size=marker_size),
                textposition="bottom right",
            )
        )

        # Reversed to strongest-attack-first, where performance is presumably lowest.
        performance_rev = df_cat[performance_metric][::-1]
        quality_rev = df_cat["normalized_quality"][::-1]
        leaderboard_rows.append(
            {
                "Attack": ATTACK_NAMES[attack_name],
                performance_threshold_high_str: modified_interp(
                    performance_threshold_high, performance_rev, quality_rev
                ),
                performance_threshold_low_str: modified_interp(
                    performance_threshold_low, performance_rev, quality_rev
                ),
                "AvgP": np.mean(df_cat[performance_metric]),
                "AvgQ": np.mean(df_cat["normalized_quality"]),
            }
        )

    if not leaderboard_rows:
        raise ValueError(
            f"No complete removal results for watermark method {watermark_method!r}"
        )

    leaderboard_df_sorted = _rank_leaderboard(
        leaderboard_rows,
        [performance_threshold_high_str, performance_threshold_low_str, "AvgP", "AvgQ"],
        watermark_method,
    )

    # Dashed horizontal lines at both performance thresholds.
    for threshold in (performance_threshold_high, performance_threshold_low):
        fig.add_trace(
            go.Scatter(
                x=threshold_line_x,
                y=[threshold, threshold],
                mode="lines",
                line=dict(color="black", dash="dash"),
                showlegend=False,
            ),
        )

    fig.update_layout(
        title_text=f"Comparison of Attacks on {WATERMARK_METHODS[watermark_method]}",
        showlegend=True,
        legend=dict(
            orientation="v",
            yanchor="top",
            y=1,
            xanchor="left",
            x=1.05,
            font=dict(size=legend_fontsize),
        ),
        xaxis=dict(title="Normalized Quality", tickfont=dict(size=tick_size)),
        yaxis=dict(
            title=PERFORMANCE_METRICS[performance_metric], tickfont=dict(size=tick_size)
        ),
        height=plot_height,
    )

    return leaderboard_df_sorted, fig

In [4]:
# Build one leaderboard and one JSON-serialized figure per watermark method.
fig_jsons = {}
leaderboards = []
for watermark_method in WATERMARK_METHODS:
    leaderboard, fig = plot_aggregated_2d_plot(watermark_method)
    leaderboards.append(leaderboard)
    fig_jsons[watermark_method] = fig.to_json()
    # fig.show()

# To publish the combined leaderboard, uncomment:
# leaderboard = pd.concat(leaderboards, axis=0)
# display(HTML(leaderboard.to_html(index=False)))
# leaderboard.to_csv("leaderboard.csv", index=False)