In [None]:
import plotly.graph_objects as go
import numpy as np
from dev import *


# Each aggregate attack family (used as a trace name in the radar chart below)
# maps to the list of individual attack names grouped under it. Insertion order
# is preserved and determines trace order.
aggregate_attack_name_dict = dict(
    [
        (
            "Distortions (single)",
            [
                "distortion_single_rotation",
                "distortion_single_resizedcrop",
                "distortion_single_erasing",
                "distortion_single_brightness",
                "distortion_single_contrast",
                "distortion_single_blurring",
                "distortion_single_noise",
                "distortion_single_jpeg",
            ],
        ),
        (
            "Distortions (combo)",
            [
                "distortion_combo_geometric",
                "distortion_combo_photometric",
                "distortion_combo_degradation",
                "distortion_combo_all",
            ],
        ),
        ("Embedding Attack (grey-box)", ["adv_emb_same_vae_untg"]),
        (
            "Embedding Attack (black-box)",
            [
                "adv_emb_clip_untg_alphaRatio_0.05_step_200",
                "adv_emb_sdxl_vae_untg",
                "adv_emb_klf16_vae_untg",
            ],
        ),
        ("Surrogate Detector Attack", ["adv_cls_wm1_wm2_0.01_50_warm"]),
        ("Regeneration (black-box)", ["regen_diffusion", "kl_vae"]),
        ("Rinsing regeneration (black-box)", ["2x_regen", "4x_regen"]),
        ("Mixed regeneration (black-box)", ["4x_regen_kl_bmshj"]),
    ]
)

# NOTE(review): `categories` is not referenced anywhere in this cell; it looks
# like a leftover from the Plotly radar-chart example. Confirm no other cell
# uses it before removing.
categories = list(
    (
        "processing cost",
        "mechanical properties",
        "chemical stability",
        "thermal stability",
        "device integration",
    )
)


# For every watermark method and every aggregate attack family, average the
# "low1000_1" entry returned by get_performance(..., mode="removal") over all
# result records of all attacks in that family.
# result_dict[family][method] holds that mean, or NaN when no record yielded a
# non-None value.
result_dict = {k: {} for k in aggregate_attack_name_dict}
for source_name in WATERMARK_METHODS.keys():
    for aggregate_attack_name, attack_names in aggregate_attack_name_dict.items():
        performances = []
        for attack_name in attack_names:
            # Presumably the filter is evaluated eagerly inside
            # get_all_json_paths, so the closure reads the current loop values.
            json_dict = get_all_json_paths(
                lambda _dataset_name, _attack_name, _attack_strength, _source_name, _result_type: (
                    _source_name == source_name and _attack_name == attack_name
                )
            )
            # Key layout follows the filter's parameters:
            # (dataset, attack, strength, source, result_type).
            # De-duplicating without the result type means one run with
            # several result files is counted once.
            record_keys = {(key[0], key[3], key[1], key[2]) for key in json_dict}
            # Use distinct names for the record's source/attack. Rebinding
            # `source_name`/`attack_name` here would leak into the filter and
            # the result_dict store below.
            for dataset_name, record_source, record_attack, attack_strength in record_keys:
                performances.append(
                    get_performance(
                        dataset_name,
                        record_source,
                        record_attack,
                        attack_strength,
                        mode="removal",
                    )["low1000_1"]
                )
        valid = [p for p in performances if p is not None]
        # Avoid numpy's "Mean of empty slice" RuntimeWarning; the value stays NaN.
        result_dict[aggregate_attack_name][source_name] = (
            np.array(valid).mean() if valid else np.nan
        )


# Radar chart: one filled trace per aggregate attack family, with one angular
# position per watermark method.
fig = go.Figure()

method_names = list(WATERMARK_METHODS.keys())
for aggregate_attack_name in aggregate_attack_name_dict.keys():
    per_method = result_dict[aggregate_attack_name]
    fig.add_trace(
        go.Scatterpolar(
            # Look up by method name so each radius stays aligned with its
            # theta label, regardless of dict insertion order. A missing
            # method becomes a NaN gap instead of shifting later points.
            r=[per_method.get(name, float("nan")) for name in method_names],
            theta=method_names,
            fill="toself",
            name=aggregate_attack_name,
        )
    )

fig.update_layout(
    # Autorange from 0 rather than a fixed [0, 5]. That range came from the
    # Plotly example and does not reflect this metric's scale.
    polar=dict(radialaxis=dict(visible=True, rangemode="tozero")),
    # Several overlapping traces are drawn; the legend identifies each family.
    showlegend=True,
)

fig.show()