In [14]:
import pandas as pd
import wandb

import seaborn as sns
import numpy as np
import os

import matplotlib.pyplot as plt
# Global seaborn theme: "white" style with grid-line overrides (grey ".6", dotted ":").
grid_overrides = {"grid.color": ".6", "grid.linestyle": ":"}
sns.set_style(style="white", rc=grid_overrides)

In [15]:
def get_runs_df_stable_sig(project, entity="jurujin", runtime_limit=6*3600):
    '''
    Collect every run of a wandb project into one DataFrame (one row per run).

    Each row holds the run's summary metrics, its config values (config keys
    starting with "_" are dropped) and the run's name in the "name" column.
    Despite the name, this is also the generic loader used by `get_runs_df`.

    Only runs whose logged `_runtime` is strictly greater than `runtime_limit`
    are kept. Runs that never logged `_runtime` are dropped (NaN > x is False),
    and an empty project yields an empty frame instead of raising KeyError.

    Args:
        project: wandb project name.
        entity: wandb entity (user or team) owning the project.
        runtime_limit: threshold compared against `_runtime`
            (presumably seconds, so the default 6*3600 reads as 6 hours).

    Returns:
        A new, independent DataFrame; callers may add columns to it without
        triggering pandas' SettingWithCopyWarning.
    '''
    api = wandb.Api()
    runs = api.runs(entity + "/" + project)

    summary_list, config_list, name_list = [], [], []
    for run in runs:
        summary_list.append(run.summary._json_dict)
        # "_"-prefixed config entries are wandb-internal, not experiment parameters.
        config_list.append({k: v for k, v in run.config.items() if not k.startswith("_")})
        name_list.append(run.name)

    summary_df = pd.DataFrame(summary_list)
    config_df = pd.DataFrame(config_list)

    df = pd.concat([summary_df.reset_index(drop=True), config_df.reset_index(drop=True)], axis=1)
    df["name"] = name_list

    # With no runs (or no run reporting `_runtime`) the column is absent:
    # treat that as "no run passes the filter" rather than crashing.
    if "_runtime" in df.columns:
        keep = df["_runtime"] > runtime_limit
    else:
        keep = pd.Series(False, index=df.index, dtype=bool)

    # .copy() detaches the result from `df`, so later column assignments
    # by callers do not write into a filtered view.
    return df[keep].copy()

In [16]:
# To select only the long FID runs, filter with: df["_runtime"] > 7200

def get_runs_df(project, entity="jurujin", runtime_limit=6*3600, resolution=False):
    '''
    Load the runs of a wandb project and sort them by watermark parameters.

    Args:
        project: wandb project name.
        entity: wandb entity owning the project.
        runtime_limit: runs whose `_runtime` is not strictly greater than this
            are dropped (the filtering is done by `get_runs_df_stable_sig`).
        resolution: if True, add a "det_resol" column equal to
            no_w_det_dist_mean - w_det_dist_mean (the gap between the mean
            detection distances without and with the watermark).

    Returns:
        A DataFrame sorted by ("w_radius", "msg_scaler").
    '''
    df = get_runs_df_stable_sig(project, entity, runtime_limit)

    if resolution:
        # `assign` returns a new frame, so this never writes into a filtered
        # view of another DataFrame (avoids SettingWithCopyWarning).
        df = df.assign(det_resol=df["no_w_det_dist_mean"] - df["w_det_dist_mean"])

    return df.sort_values(by=["w_radius", "msg_scaler"])

In [17]:
# wandb projects exported by the cells below.

# Detection metrics (exported with resolution=True). The "clip_different_msg"
# project (CLIP quality for a different message) is not listed here; it is
# exported on its own in a separate cell.
detection_projects = ["detect_msg_all_att_vae", "detect_msg_all_att_no_vae"]

# Loaded with get_runs_df_stable_sig and no runtime filter.
stable_signature_detection_projects = ["eval_stable_tree_all_attacks"]

# FID / PSNR / SSIM projects. Earlier per-attack projects, currently disabled:
#   fid_gt_msg_all_att_vae, fid_gt_msg_all_att_no_vae,
#   fid_gen_msg_all_att_vae, fid_gen_msg_all_att_no_vae
fid_projects = ["fid_gen_message_dependency", "fid_gt_message_dependency"]

# Column selections (and their order) for the exported CSV files.

# Attack-parameter config keys, shared by the detection and attacked-FID exports.
attack_param_cols = [
    "jpeg_ratio", "crop_scale", "crop_ratio",
    "gaussian_blur_r", "gaussian_std",
    "brightness_factor", "r_degree",
]

# Watermark parameters, shared by the detection and FID exports.
watermark_param_cols = ["msg", "w_radius", "msg_scaler"]

# "det_resol" only exists when runs are loaded with get_runs_df(..., resolution=True).
detection_cols = (
    ["name"]
    + ["TPR@1%FPR", "auc", "acc", "Bit_acc", "Word_acc", "det_resol"]
    + ["w_clip_score_mean"]
    + ["w_det_dist_mean", "no_w_det_dist_mean", "w_det_dist_std", "no_w_det_dist_std"]
    + watermark_param_cols
    + attack_param_cols
)

stable_signature_detection_cols = ["name", "Bit_acc", "Word_acc"]

fid_cols = (
    ["name", "psnr_w", "ssim_w", "psnr_no_w", "ssim_no_w", "fid_w", "fid_no_w"]
    + watermark_param_cols
)

fid_att_cols = fid_cols + attack_param_cols

In [18]:
# Export one detection CSV per project. Only runs with `_runtime` > 4 * 3600
# are kept (4 hours, assuming `_runtime` is reported in seconds).
os.makedirs("./detection", exist_ok=True)  # loop-invariant: create the output dir once
for project in detection_projects:
    get_runs_df(project, resolution=True, runtime_limit=4 * 3600).to_csv(
        f"./detection/{project}.csv", index=False, columns=detection_cols
    )

In [19]:
# Export one FID/quality CSV per project, ordered by run name (descending).
os.makedirs("./fid", exist_ok=True)  # loop-invariant: create the output dir once
for project in fid_projects:
    # kind="stable": runs sharing a name keep get_runs_df's (w_radius, msg_scaler)
    # order instead of an arbitrary one.
    get_runs_df(project).sort_values(by="name", ascending=False, kind="stable").to_csv(
        f"./fid/{project}.csv", index=False, columns=fid_cols
    )

In [20]:
# Export bit/word accuracy for the stable-signature projects.
# runtime_limit=0 keeps every run whose `_runtime` is strictly positive.
os.makedirs("./detection", exist_ok=True)  # loop-invariant: create the output dir once
for project in stable_signature_detection_projects:
    get_runs_df_stable_sig(project, runtime_limit=0).to_csv(
        f"./detection/{project}.csv", index=False, columns=stable_signature_detection_cols
    )

In [21]:
# CLIP-quality check for a different message: keep every run with a positive
# `_runtime` (runtime_limit=0) and order rows by run name, descending.
clip_different_message = "clip_different_msg"
os.makedirs("./detection", exist_ok=True)
(
    get_runs_df(clip_different_message, runtime_limit=0, resolution=True)
    # kind="stable": ties on "name" keep get_runs_df's (w_radius, msg_scaler) order.
    .sort_values(by="name", ascending=False, kind="stable")
    .to_csv(f"./detection/{clip_different_message}.csv", index=False, columns=detection_cols)
)