In [6]:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import os
from pathlib import Path
from dotenv import load_dotenv
import matplotlib.pyplot as plt

load_dotenv()

PROJECT_ROOT = Path(os.getenv("PROJECT_ROOT")).resolve()  # type: ignore

def ema(values, alpha=0.98):
    """Exponential Moving Average smoothing"""
    smoothed = []
    m = None
    for v in values:
        m = v if m is None else alpha * m + (1 - alpha) * v
        smoothed.append(m)
    return smoothed

In [7]:
def draw_one_scalar(event_accs, tag, graph_dir, names=None, ema_alpha=0.98, sample=3, for_paper=False):
    # 스타일 설정
    if for_paper:
        plt.rcParams.update({
            "font.family": "Times New Roman",
            "font.size": 9,
            "axes.labelsize": 9,
            "xtick.labelsize": 8,
            "ytick.labelsize": 8,
            "legend.fontsize": 8,
            "axes.unicode_minus": False
        })
        figsize = (3.5, 2.5)   # 논문용 작은 사이즈 (inch)
        alpha_raw = 0.15
        lw = 1.0
        add_grid = False
        add_title = False
    else:
        figsize = (10, 5)
        alpha_raw = 0.2
        lw = 1.8
        add_grid = True
        add_title = True

    plt.figure(figsize=figsize)
    labels = names if names is not None else [f'Log {i+1}' for i in range(len(event_accs))]
    labels = [label.replace('_', ' ') for label in labels]

    # 원시 데이터 (옅게)
    for i, event_acc in enumerate(event_accs):
        scalar_events = event_acc.Scalars(tag)
        steps = [e.step for e in scalar_events]
        values = [e.value for e in scalar_events]
        plt.plot(steps[::sample], values[::sample], color=f"C{i}", alpha=alpha_raw)

    # EMA 데이터 (선명하게)
    for i, event_acc in enumerate(event_accs):
        scalar_events = event_acc.Scalars(tag)
        steps = [e.step for e in scalar_events]
        values = [e.value for e in scalar_events]
        smoothed_values = ema(values, alpha=ema_alpha)
        plt.plot(steps, smoothed_values, label=labels[i], color=f"C{i}", linewidth=lw)

    # 축 라벨
    plt.xlabel('Step')
    plt.ylabel(tag[6:].replace('/', ' '))

    # 제목 / grid / legend
    if add_title:
        plt.title(f'{tag}'.replace('_', ' '))
    if add_grid:
        plt.grid()
    plt.legend(frameon=False)

    plt.tight_layout()
    # 논문용은 pdf 저장도 가능
    plt.savefig(graph_dir / f"{tag.replace('/', '_')}.png", dpi=300)
    plt.close()

In [8]:
log_root = PROJECT_ROOT / "logs" / "experiment3"
log_dirs = []
event_accs = []
names = []

name_dict = {
    "cyclic": "Baseline (All edges)",
    "acyclic_reason_kahn": "FASO Reason-Kahn",
    "acyclic_no_reason_kahn": "FASO-Kahn",
    "acyclic_reason_deterministic": "FASO-Deterministic",
    "acyclic_no_reason_deterministic": "FASO Reason-Deterministic",
}

for sub in sorted(log_root.iterdir()):
    if sub.is_dir():
        log_dirs.append(sub)
        event_acc = EventAccumulator(str(sub))
        event_acc.Reload()
        event_accs.append(event_acc)
        names.append(name_dict[sub.name])

In [9]:
graph_dir = PROJECT_ROOT / "graphs" / "experiment3" / "comparisons"
os.makedirs(graph_dir, exist_ok=True)

In [10]:
# only common tags
scalars = set(event_accs[0].Tags()['scalars'])
for event_acc in event_accs[1:]:
    scalars.intersection_update(set(event_acc.Tags()['scalars']))

for tag in scalars:
    draw_one_scalar(event_accs, tag, graph_dir, ema_alpha=0.9, names=names, sample=1, for_paper=False)
    print(f"Saved graph for {tag}")

Saved graph for train/loss
Saved graph for train/rewards/margins
Saved graph for train/logits/rejected
Saved graph for train/train_loss
Saved graph for train/rewards/rejected
Saved graph for train/train_samples_per_second
Saved graph for train/train_steps_per_second
Saved graph for train/logps/chosen
Saved graph for train/rewards/accuracies
Saved graph for train/train_runtime
Saved graph for train/grad_norm
Saved graph for train/logps/rejected
Saved graph for train/logits/chosen
Saved graph for train/rewards/chosen
Saved graph for train/total_flos
Saved graph for train/epoch
Saved graph for train/learning_rate
