In [1]:
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# Select seaborn's "dark" style once, up front, for the figures drawn below.
sns.set_theme(style="dark")


def extract_epoch_loss(event_files: list[Path], tag_to_plot: str) -> dict[int, float]:
    """Average a scalar tag per epoch across TensorBoard event files.

    Every file must contain both an ``"epoch"`` scalar and ``tag_to_plot``;
    files lacking either one are reported with a warning and skipped. A value
    of ``tag_to_plot`` is assigned to an epoch by looking up the ``"epoch"``
    scalar recorded at the same step. Values whose step has no ``"epoch"``
    entry are ignored.

    Args:
        event_files: Paths of the TensorBoard event files to read.
        tag_to_plot: Name of the scalar tag to aggregate.

    Returns:
        Mapping from epoch number to the mean of all values collected for that
        epoch (pooled over every file), in ascending epoch order. Empty when
        no file contributed a value.
    """
    epoch_loss: defaultdict[int, list[float]] = defaultdict(list)

    for event_file_path in event_files:
        # A size guidance of 0 keeps every scalar event. The default keeps only
        # a bounded sample per tag, so in long runs the "epoch" and loss series
        # can lose matching steps and points would silently disappear below.
        event_acc = EventAccumulator(
            str(event_file_path),
            size_guidance={"scalars": 0},
        )
        event_acc.Reload()

        tags = event_acc.Tags()["scalars"]
        if "epoch" not in tags or tag_to_plot not in tags:
            print(f"Warning: Required tags not found in {event_file_path}")
            continue

        # The "epoch" scalar is stored as a float; truncate it to an int key.
        step_to_epoch = {e.step: int(e.value) for e in event_acc.Scalars("epoch")}

        for e in event_acc.Scalars(tag_to_plot):
            # Only values logged at a step that also has an epoch entry count.
            if e.step in step_to_epoch:
                epoch_loss[step_to_epoch[e.step]].append(e.value)

    # Epoch keys are unique, so sorting the items never compares the lists.
    return {
        epoch: sum(losses) / len(losses)
        for epoch, losses in sorted(epoch_loss.items())
    }


def plot_loss(
    log_dirs: dict[str, Path],
    tag_to_plot: str,
    filename: str,
    title: str,
    ylim_top: float,
) -> None:
    """Plot per-epoch loss curves from several TensorBoard log directories.

    One line is drawn per entry of ``log_dirs``. Directories without event
    files, or without usable data for ``tag_to_plot``, are reported with a
    warning and left out of the plot. The figure is saved as a PDF, shown,
    and then closed.

    Args:
        log_dirs: Mapping from legend label to the directory holding that
            run's ``events.out.tfevents.*`` files.
        tag_to_plot: Name of the scalar tag to plot.
        filename: Path the PDF is written to.
        title: Title placed above the axes.
        ylim_top: Upper limit of the y-axis; the lower limit is fixed at 0.
    """
    # Explicit figure/axes so nothing depends on pyplot's "current" figure.
    fig, ax = plt.subplots(figsize=(4, 4))

    for label, log_dir in log_dirs.items():
        event_files = sorted(log_dir.glob("events.out.tfevents.*"))
        if not event_files:
            print(f"Warning: No event files found in {log_dir}")
            continue

        epoch_loss = extract_epoch_loss(event_files, tag_to_plot)
        if not epoch_loss:
            print(f"Warning: No valid data for {tag_to_plot} in {log_dir}")
            continue

        # One sorted pass yields (epoch, value) rows in ascending epoch order.
        dataframe = pd.DataFrame(
            sorted(epoch_loss.items()),
            columns=["epoch", "value"],
        )

        sns.lineplot(x="epoch", y="value", data=dataframe, label=label, ax=ax)

    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title(title)
    # Skip the legend when every directory was skipped; matplotlib would
    # otherwise warn that there are no labelled artists.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()
    ax.grid(True)
    ax.set_ylim(0, ylim_top)
    fig.savefig(filename, format="pdf", dpi=300, bbox_inches="tight")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


In [None]:
# Runs stored under lightning_logs versions 0-3, keyed by legend label.
# A dedicated name keeps this mapping from being rebound by the next cell,
# which assigns its own `log_dirs`.
log_dirs_v0_to_v3 = {
    "Original TrueType": Path("../lightning_logs/version_0"),
    "Decomposed TrueType": Path("../lightning_logs/version_1"),
    "Segmented TrueType": Path("../lightning_logs/version_2"),
    "PostScript": Path("../lightning_logs/version_3"),
}

# Version-specific file names: the next cell saves to "train_loss.pdf" and
# "validation_loss.pdf", which would silently overwrite these figures when
# the notebook is run top to bottom.
plot_loss(
    log_dirs_v0_to_v3,
    "train_loss_epoch",
    "train_loss_v0-3.pdf",
    "Training Loss",
    0.6,
)

plot_loss(
    log_dirs_v0_to_v3,
    "val_loss",
    "validation_loss_v0-3.pdf",
    "Validation Loss",
    0.6,
)


In [None]:
# Runs stored under lightning_logs versions 4-7, keyed by legend label.
log_dirs = {
    "Original TrueType": Path("../lightning_logs/version_4"),
    "Decomposed TrueType": Path("../lightning_logs/version_5"),
    "Segmented TrueType": Path("../lightning_logs/version_6"),
    "PostScript": Path("../lightning_logs/version_7"),
}

# Training curve; both figures share the same y-axis ceiling.
plot_loss(
    log_dirs=log_dirs,
    tag_to_plot="train_loss_epoch",
    filename="train_loss.pdf",
    title="Training Loss",
    ylim_top=2.5,
)

# Validation curve for the same runs.
plot_loss(
    log_dirs=log_dirs,
    tag_to_plot="val_loss",
    filename="validation_loss.pdf",
    title="Validation Loss",
    ylim_top=2.5,
)
