In [14]:
import plotly
import plotly.graph_objects as go
import numpy as np
import warnings
from dev import *

warnings.filterwarnings("ignore")


# Metric key (used to index the result of get_quality) -> label shown on the radar axes.
# Insertion order fixes the order of the radar axes.
QUALITY_METRICS = dict(
    legacy_fid="Legacy FID",
    clip_fid="CLIP FID",
    psnr="PSNR",
    ssim="SSIM",
    nmi="Normed Mutual-Info",
    lpips="LPIPS",
    aesthetics="Delta Aesthetics",
    artifacts="Delta Artifacts",
)

# Display name of each attack family -> the individual attack identifiers whose quality
# scores are pooled together (averaged) for that family. Insertion order is the order in
# which the families are plotted.
aggregate_attack_name_dict = {
    "Distortions (single)": [
        f"distortion_single_{kind}"
        for kind in (
            "rotation",
            "resizedcrop",
            "erasing",
            "brightness",
            "contrast",
            "blurring",
            "noise",
            "jpeg",
        )
    ],
    "Distortions (combo)": [
        f"distortion_combo_{kind}"
        for kind in ("geometric", "photometric", "degradation", "all")
    ],
    "Embedding Attack (grey-box)": ["adv_emb_same_vae_untg"],
    "Embedding Attack (black-box)": [
        "adv_emb_clip_untg_alphaRatio_0.05_step_200",
        *(f"adv_emb_{vae}_vae_untg" for vae in ("sdxl", "klf16")),
    ],
    "Surrogate Detector Attack": ["adv_cls_wm1_wm2_0.01_50_warm"],
    "Regeneration (black-box)": ["regen_diffusion", "kl_vae"],
    "Rinsing regeneration (black-box)": [f"{rounds}x_regen" for rounds in (2, 4)],
    "Mixed regeneration (black-box)": [
        f"4x_regen_{model}" for model in ("bmshj", "kl_vae")
    ],
}


def _mean_qualities(attack_names):
    """Average every quality metric over all synthetic-image records of the given attacks.

    For each attack identifier, ``get_all_json_paths`` is queried with a predicate that keeps
    entries whose source name does not start with ``"real"`` and whose attack name matches.
    Each distinct (dataset, source, attack, strength) record is scored with
    ``get_quality(..., mode="removal")``, and the first element of every non-None metric value
    is collected.

    Args:
        attack_names: iterable of attack identifiers pooled into one family.

    Returns:
        dict mapping each key of ``QUALITY_METRICS`` (in that order) to the mean of its
        collected values. A metric with no collected values maps to NaN. The "artifacts"
        mean is negated.
    """
    collected = {metric: [] for metric in QUALITY_METRICS}
    for target_attack in attack_names:
        # The predicate reads `target_attack` when it is called. Nothing below rebinds that
        # name within this iteration; the original code reused `attack_name` in the inner
        # unpack, which silently shadowed the value this closure depends on.
        json_dict = get_all_json_paths(
            lambda _dataset_name, _attack_name, _attack_strength, _source_name, _result_type: (
                (not _source_name.startswith("real")) and (_attack_name == target_attack)
            )
        )
        # Reorder each key from the predicate's argument order
        # (dataset, attack, strength, source, ...) to (dataset, source, attack, strength).
        # The set collapses keys that differ only in the remaining fields.
        record_keys = {(key[0], key[3], key[1], key[2]) for key in json_dict.keys()}
        for dataset_name, source_name, record_attack, attack_strength in record_keys:
            quality = get_quality(
                dataset_name,
                source_name,
                record_attack,
                attack_strength,
                mode="removal",
            )
            for metric in QUALITY_METRICS:
                if quality[metric] is not None:
                    # NOTE(review): index 0 is presumably the point estimate — confirm.
                    collected[metric].append(quality[metric][0])

    means = {}
    for metric, values in collected.items():
        # Make the "no data" case explicit instead of relying on the (suppressed)
        # warning that np.array([]).mean() emits before returning NaN.
        mean = np.array(values).mean() if values else np.nan
        means[metric] = -mean if metric == "artifacts" else mean
    return means


# Attack family -> {metric: mean value}, with families and metrics in declaration order.
result_dict = {
    aggregate_attack_name: _mean_qualities(attack_names)
    for aggregate_attack_name, attack_names in aggregate_attack_name_dict.items()
}


fig = go.Figure()

# Plotly's default palette is a list of "rgb(r, g, b)" strings. Rewrite each one as
# "rgba(r, g, b, alpha)" so that the polygon fill opacity can be tuned. With alpha = 0.0
# the fills are fully transparent; only the fill colour is affected.
default_colors = plotly.colors.DEFAULT_PLOTLY_COLORS
alpha = 0.0  # fill opacity in [0, 1]; e.g. 0.5 for half-transparent fills
modified_colors = list(
    map(
        lambda rgb_string: rgb_string.replace("rgb", "rgba").replace(")", f", {alpha})"),
        default_colors,
    )
)

# Per-metric normalisation bounds for the radar chart: left_dict[m] maps to radius 0 and
# right_dict[m] maps to radius 1. For the metrics listed below the radius grows with the raw
# value (min -> 0, max -> 1). For every other metric the axis is flipped (max -> 0, min -> 1).
# NOTE(review): presumably these are the "higher = closer to the original image" metrics — confirm.
_ASCENDING_METRICS = ("psnr", "ssim", "nmi")

left_dict = {}
right_dict = {}
for metric in QUALITY_METRICS:
    values = np.array(
        [result_dict[group][metric] for group in aggregate_attack_name_dict]
    )
    # A family with no data for a metric yields NaN. The builtin min/max return an
    # order-dependent result when a NaN is present, whereas nanmin/nanmax ignore it.
    # If every value is NaN the bound is NaN (numpy's warning is filtered at file top).
    low, high = np.nanmin(values), np.nanmax(values)
    if metric in _ASCENDING_METRICS:
        left_dict[metric], right_dict[metric] = low, high
    else:
        left_dict[metric], right_dict[metric] = high, low

# Draw one closed polygon per attack family. Each metric is rescaled linearly so that
# left_dict[m] maps to 0 and right_dict[m] maps to 1. The first vertex is repeated so the
# outline closes.
metric_keys = list(QUALITY_METRICS)
lefts = np.array([left_dict[m] for m in metric_keys])
rights = np.array([right_dict[m] for m in metric_keys])

# Axis labels embed the raw [left, right] bounds. They are identical for every family,
# so they are built once.
theta = [
    f" {label} [{lo:.2f}, {hi:.2f}]"
    for label, lo, hi in zip(QUALITY_METRICS.values(), lefts, rights)
]
theta.append(theta[0])

for index, aggregate_attack_name in enumerate(aggregate_attack_name_dict):
    # Index by metric key rather than relying on dict .values() ordering to line up.
    raw = np.array([result_dict[aggregate_attack_name][m] for m in metric_keys])
    # NOTE: if every family shares the same value for a metric, rights - lefts is 0 and
    # that vertex becomes NaN (numpy's divide warning is filtered at file top).
    r = ((raw - lefts) / (rights - lefts)).tolist()
    r.append(r[0])
    fig.add_trace(
        go.Scatterpolar(
            r=r,
            theta=theta,
            fill="toself",
            name=aggregate_attack_name,
            # Cycle through the palette so that more families than colours cannot raise
            # IndexError.
            fillcolor=modified_colors[index % len(modified_colors)],
        )
    )

# Radii are normalised per metric, so pin the radial axis to [0, 1] and keep the legend
# so that the attack families can be told apart.
fig.update_layout(
    showlegend=True,
    polar={"radialaxis": {"visible": True, "range": [0, 1]}},
)

fig.show()