In [None]:
import pandas as pd
import wandb

# W&B public API client; used further down to fetch the runs of a project.
api = wandb.Api()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import font_manager

# Korean (Hangul) font used for the figure title and axis labels.
KOREAN_FONT_PATH = "/usr/share/fonts/NanumFont/NanumGothicBold.ttf"

# Global figure style. The 'seaborn-whitegrid' style was renamed to
# 'seaborn-v0_8-whitegrid' in matplotlib 3.6 and the old name was removed in
# 3.8, where plt.style.use('seaborn-whitegrid') raises. Use whichever name this
# matplotlib install provides (skip it if neither exists).
_whitegrid_style = next(
    (
        style_name
        for style_name in ("seaborn-v0_8-whitegrid", "seaborn-whitegrid")
        if style_name in plt.style.available
    ),
    None,
)
if _whitegrid_style is not None:
    plt.style.use(_whitegrid_style)
plt.rcParams["figure.facecolor"] = "white"
sns.set(style="white")

# Register the font file explicitly so that the family name set below resolves
# even if matplotlib's font cache was built before the font was installed.
font_manager.fontManager.addfont(KOREAN_FONT_PATH)
font_name = font_manager.FontProperties(fname=KOREAN_FONT_PATH).get_name()
plt.rc("font", family=font_name)

In [None]:
# Heatmap row order: (dataset group, metric key) pairs, listed group by group.
_METRICS_BY_GROUP = {
    "Basic": [
        "a/accuracy",
        "ac/accuracy",
        "f/accuracy",
        "fc/accuracy",
        "fn/accuracy",
        "g/accuracy",
        "gc/accuracy",
    ],
    "Family": [
        "a/accuracy",
        "ca/accuracy",
        "g/accuracy",
        "cg/accuracy",
        "ag/accuracy",
        "cag/accuracy",
    ],
    "Personal": [
        "a/accuracy",
        "ac/accuracy",
    ],
}
sorted_columns = [
    (group, metric)
    for group, metrics in _METRICS_BY_GROUP.items()
    for metric in metrics
]

# Heatmap column order: the W&B run names to display, in this order.
sorted_index = list(
    (
        "single-fr-ver-1-230529_0140-230602_1942",
        "dual-frkr-ver-1-230602_2021-230606_2104",
        "triple-fraekr-ver-1-230601_1838-230602_2112",
    )
)

In [None]:
# Project is specified by <entity/project-name>
# W&B project to query, written as "<entity>/<project-name>".
WANDB_PROJECT = "jongphago/test_validation"
runs = api.runs(WANDB_PROJECT)

In [None]:
summary_list = []
config_list = []
name_list = []

for run in runs:
    # Logged metric key/values; ._json_dict is used to omit large files.
    run_summary = run.summary._json_dict
    # Hyperparameters, without the special keys that start with "_".
    run_config = {
        key: value
        for key, value in run.config.items()
        if not key.startswith("_")
    }
    summary_list.append(run_summary)
    config_list.append(run_config)
    # Human-readable run name.
    name_list.append(run.name)

In [None]:
# One row per run: its summary dict, filtered config dict and name.
runs_df = pd.DataFrame(
    data={"summary": summary_list, "config": config_list, "name": name_list}
)

# Save every fetched run (before the row drop below).
runs_df.to_csv("project.csv")

# Remove the row labelled 2. NOTE(review): this depends on the order in which
# the W&B API returned the runs; which run is removed, and why, is not visible
# here. Selecting by run name would be more robust. Confirm intent.
runs_df = runs_df.drop(index=2)

In [None]:
# Run-by-key table: one row per run name, one column per summary key.
_summary_df = pd.DataFrame(dict(zip(runs_df.name, runs_df.summary))).T

# Drop special keys that start with "_" and order the remaining columns by name.
_public_keys = sorted(key for key in _summary_df.columns if not key.startswith("_"))
_summary_df = _summary_df[_public_keys]

# Split each column name on "-" into a MultiIndex so columns can be selected
# with tuples such as ("Basic", "a/accuracy").
_summary_df.columns = pd.MultiIndex.from_tuples(
    list(map(lambda key: tuple(key.split("-")), _summary_df.columns))
)

# After the transpose, runs (sorted_index) are columns and the selected
# (group, metric) pairs (sorted_columns) are rows, as floats.
summary_df = _summary_df.loc[sorted_index, sorted_columns].T.astype(float)

In [None]:
# 컬럼명에서 첫 번째 원소를 추출하여 X축 라벨로 사용합니다.
# Short x-tick labels: keep everything left of the last four "-"-separated
# fields of each run name and upper-case it,
# e.g. "single-fr-ver-1-230529_0140-230602_1942" -> "SINGLE-FR".
# `n` must be passed by keyword: pandas >= 2.0 rejects a positional `n`.
edited_x_label = summary_df.columns.str.rsplit("-", n=4).map(lambda x: x[0].upper())

# Set the figure size and create the axes for the heatmap.
fig, ax = plt.subplots(figsize=(10, 8))

# summary_df has the runs (models) as columns and the (dataset group, metric)
# pairs as rows, so models go on the x axis and datasets on the y axis.
sns.heatmap(
    summary_df,
    annot=True,
    cmap="coolwarm",
    fmt=".4g",
    cbar_kws={"label": "Accuracy"},
    xticklabels=edited_x_label,
    ax=ax,
)
print("Figure | 멀티 태스크 모델별 얼굴 검증 데이터셋 정확도")

# Title and axis labels. The x axis carries the models and the y axis the
# verification datasets (previously these two labels were swapped).
ax.set_title("멀티 태스크 모델별 얼굴 검증 데이터셋 정확도", fontsize=15)
ax.set_xlabel("멀티 태스크 모델 종류", fontsize=12)
ax.set_ylabel("얼굴 검증 데이터셋 종류", fontsize=12)

# Rotate the x tick labels so the model names do not overlap.
ax.tick_params(axis="x", labelrotation=45)

plt.show()
