In [86]:
import os
from pathlib import Path
from dotenv import load_dotenv
import json
import matplotlib.pyplot as plt
# Load variables from a local .env file into os.environ before reading them.
load_dotenv()


def _env_path(name: str) -> Path:
    """Return the resolved filesystem path stored in environment variable ``name``.

    Raises:
        KeyError: if ``name`` is unset or empty. Without this check ``Path(None)``
            raises an opaque ``TypeError`` that does not say which variable is
            missing, and an empty value would silently resolve to the CWD.
    """
    value = os.getenv(name)
    if not value:
        raise KeyError(f"Environment variable {name!r} is not set or empty; add it to your .env file")
    return Path(value).resolve()


# Project-level directories, all configured via environment variables / .env.
PROJECT_ROOT = _env_path("PROJECT_ROOT")
MODEL_ROOT = _env_path("MODEL_ROOT")
IMAGE_ROOT = _env_path("IMAGE_ROOT")
CONFIG_ROOT = _env_path("CONFIG_ROOT")
LOG_ROOT = _env_path("LOG_ROOT")

In [87]:
# Collect per-run training curves: for every run directory under LOG_ROOT that
# contains a log.json, record the run name plus its logged step and PSNR sequences.
# names[i], steps[i] and psnrs[i] all describe the same run.
log_root = LOG_ROOT
log_name = "log.json"
names = []
steps = []
psnrs = []

for run_dir in sorted(log_root.iterdir()):
    log_file = run_dir / log_name
    # Skip plain files and run directories that have no log yet.
    if not run_dir.is_dir() or not log_file.exists():
        continue
    names.append(run_dir.name)
    with open(log_file, "r") as fh:
        entries = json.load(fh)
    # Each log entry is expected to carry "step" and "psnr" keys.
    run_steps = [entry["step"] for entry in entries]
    run_psnrs = [entry["psnr"] for entry in entries]
    steps.append(run_steps)
    psnrs.append(run_psnrs)

In [88]:
# Output directory for every generated figure; created (with parents) if missing.
graph_dir = PROJECT_ROOT / "graphs"
graph_dir.mkdir(parents=True, exist_ok=True)

In [89]:
# Figure: PSNR vs. training step for the three runs reported in the paper.
for_paper = True

if for_paper:
    # Compact typography for a small paper figure.
    plt.rcParams.update({
        "font.family": "Times New Roman",
        "font.size": 9,
        "axes.labelsize": 9,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "legend.fontsize": 8,
        "axes.unicode_minus": False
    })
    figsize = (3.5, 2.5)   # small figure size for the paper (inches)
    alpha_raw = 0.15
    lw = 1.0
    add_grid = False
    add_title = False
else:
    figsize = (10, 5)
    alpha_raw = 0.2
    lw = 1.8
    add_grid = True
    add_title = True

# Run name (underscores replaced by spaces) -> legend label.
# Only runs listed here are plotted; colors follow the sorted run order.
step_plot_labels = {
    "meerkat meta 260": "GKAN With Meta Learning",
    "meerkat no meta 260": "GKAN Without Meta Learning",
    "meerkat siren 450": "SIREN",
}

fig, ax = plt.subplots(figsize=figsize)
cnt = 0
for run_name, run_steps, run_psnrs in zip(names, steps, psnrs):
    label = step_plot_labels.get(run_name.replace('_', ' '))
    if label is None:
        continue
    ax.plot(run_steps, run_psnrs, label=label, color=f"C{cnt}", linewidth=lw)
    cnt += 1

ax.set_xlabel('Step')
ax.set_ylabel('PSNR (dB)')

if add_title:
    ax.set_title("PSNR vs Step")
if add_grid:
    ax.grid()
ax.legend(frameon=False)

fig.tight_layout()
# Save exactly once, after labels and legend are in place, so the file holds
# the finished figure with tight bounding box.
fig.savefig(graph_dir / "psnr_to_step.png", dpi=300, bbox_inches='tight')
plt.close(fig)

In [90]:
import json
import pandas as pd

# One summary row per run: recorded parameter count and best PSNR reached.
records = []
for exp_dir in sorted(LOG_ROOT.iterdir()):
    log_path = exp_dir / "log.json"
    if not log_path.exists():
        continue

    # Named `log_entries` (not `steps`) so this cell does not overwrite the
    # per-run `steps` list that the PSNR-vs-step figure cell reads.
    with log_path.open() as f:
        log_entries = json.load(f)
    if not log_entries:
        continue

    # Folder name, e.g. meerkat_meta_100 -> model family + size tag.
    parts = exp_dir.name.split("_")
    model_family = "_".join(parts[:2])     # meerkat_meta, meerkat_no (= no_meta), meerkat_siren
    param_label = parts[-1]                # size tag from the folder name: 100, 150, ...
    params = log_entries[0]["params"]      # actual parameter count recorded in log.json
    max_psnr = max(entry["psnr"] for entry in log_entries)

    records.append(
        {
            "run": exp_dir.name,
            "model_family": model_family,
            "param_label": param_label,
            "params": params,
            "max_psnr": max_psnr,
        }
    )

# Explicit columns keep sort_values working even when no run logs were found.
df = pd.DataFrame(
    records,
    columns=["run", "model_family", "param_label", "params", "max_psnr"],
).sort_values(["model_family", "params"])
display(df)

grouped = df.groupby("model_family")
display(*grouped)

Unnamed: 0,run,model_family,param_label,params,max_psnr
0,meerkat_meta_100,meerkat_meta,100,60901,23.926114
1,meerkat_meta_150,meerkat_meta,150,136351,26.759318
2,meerkat_meta_200,meerkat_meta,200,241801,29.218082
3,meerkat_meta_260,meerkat_meta,260,407941,32.045871
4,meerkat_no_meta_100,meerkat_no,100,60901,23.078521
5,meerkat_no_meta_150,meerkat_no,150,136351,26.057012
6,meerkat_no_meta_200,meerkat_no,200,241801,27.827926
7,meerkat_no_meta_260,meerkat_no,260,407941,29.585128
8,meerkat_siren_170,meerkat_siren,170,58991,16.721331
9,meerkat_siren_260,meerkat_siren,260,137021,19.226382


('meerkat_meta',
                 run  model_family param_label  params   max_psnr
 0  meerkat_meta_100  meerkat_meta         100   60901  23.926114
 1  meerkat_meta_150  meerkat_meta         150  136351  26.759318
 2  meerkat_meta_200  meerkat_meta         200  241801  29.218082
 3  meerkat_meta_260  meerkat_meta         260  407941  32.045871)

('meerkat_no',
                    run model_family param_label  params   max_psnr
 4  meerkat_no_meta_100   meerkat_no         100   60901  23.078521
 5  meerkat_no_meta_150   meerkat_no         150  136351  26.057012
 6  meerkat_no_meta_200   meerkat_no         200  241801  27.827926
 7  meerkat_no_meta_260   meerkat_no         260  407941  29.585128)

('meerkat_siren',
                   run   model_family param_label  params   max_psnr
 8   meerkat_siren_170  meerkat_siren         170   58991  16.721331
 9   meerkat_siren_260  meerkat_siren         260  137021  19.226382
 10  meerkat_siren_346  meerkat_siren         346  241855  21.178240
 11  meerkat_siren_450  meerkat_siren         450  408151  22.752765)

In [99]:
# Figure: best PSNR vs. parameter count, one line per model family.
for_paper = True

if for_paper:
    # Compact typography for a small paper figure.
    plt.rcParams.update({
        "font.family": "Times New Roman",
        "font.size": 9,
        "axes.labelsize": 9,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "legend.fontsize": 8,
        "axes.unicode_minus": False
    })
    figsize = (3.5, 2.5)   # small figure size for the paper (inches)
    alpha_raw = 0.15
    lw = 1.0
    add_grid = False
    add_title = False
else:
    figsize = (10, 5)
    alpha_raw = 0.2
    lw = 1.8
    add_grid = True
    add_title = True

# model_family -> legend label; unknown families fall back to their raw name.
family_labels = {
    "meerkat_meta": "GKAN With Meta Learning",
    "meerkat_no": "GKAN Without Meta Learning",
    "meerkat_siren": "SIREN",
}

fig, ax = plt.subplots(figsize=figsize)
for color_idx, (family, family_df) in enumerate(grouped):
    label = family_labels.get(family, family)
    # Rows within a family are already sorted by params, so the line is monotone in x.
    ax.plot(family_df["params"] / 10000, family_df["max_psnr"], label=label,
            color=f"C{color_idx}", linewidth=lw, marker='o', markersize=2.5)

ax.set_xlabel(r"Params ($\times 10^4$)")
ax.set_ylabel('PSNR (dB)')

if add_title:
    ax.set_title("PSNR vs Params")
if add_grid:
    ax.grid()
ax.legend(frameon=False)

fig.tight_layout()
# Save exactly once, after labels and legend are in place.
fig.savefig(graph_dir / "psnr_to_size.png", dpi=300, bbox_inches='tight')
plt.close(fig)