In [1]:
import json
from pathlib import Path

import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Stats files to load. The position of a path in this list is used by later
# cells as the order of the training runs.
jsonl_paths = [
    "../data/training_runs/00004-stylegan2-00000-all-frames-512-gpus2-batch32-gamma1.6384/stats.jsonl",
    "../data/training_runs/00005-stylegan2-00000-all-frames-512-gpus2-batch32-gamma1.6384/stats.jsonl"
]


def _stat_mean(entry, key):
    """Return entry[key]["mean"], or None when the field is absent or not a dict."""
    field = entry.get(key)
    return field.get("mean") if isinstance(field, dict) else None


# Pass 1: parse every file into (entry, source_path) pairs, one per JSON line.
lines = []
for jsonl_path in jsonl_paths:
    if not jsonl_path.endswith(".jsonl"):
        raise ValueError(f"Expected a .jsonl file, got: {jsonl_path}")

    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # default text encoding.
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            text = line.strip()
            if not text:
                # A blank line holds no record; skip it rather than report it
                # as a JSON decoding error.
                continue
            try:
                entry = json.loads(text)
            except json.JSONDecodeError as e:
                # Best effort: report the malformed line and keep going.
                print(f"Error decoding JSON from line: {text}\nError: {e}")
                continue
            # Only JSON objects are kept; other top-level values are ignored.
            if isinstance(entry, dict):
                lines.append((entry, jsonl_path))

# Pass 2: flatten each entry into one record per dict-valued metric, tagged
# with the entry's tick/kimg and the file it came from.
data_records = []
for entry, json_path in lines:
    tick = _stat_mean(entry, "Progress/tick")
    kimg = _stat_mean(entry, "Progress/kimg")
    if tick is None:
        # Entries without a tick cannot be placed on the progress axis.
        continue
    for metric, stats in entry.items():
        # Scalar (non-dict) fields are not metrics with mean/std/num; skip them.
        if isinstance(stats, dict):
            data_records.append({
                "tick": tick,
                "kimg": kimg,
                "metric": metric,
                "mean": stats.get("mean", None),
                "std": stats.get("std", None),
                "num": stats.get("num", None),
                "json_path": json_path
            })

# Long-format table: one row per (entry, metric).
df = pd.DataFrame(data_records)


In [2]:
# Put all runs on one continuous kimg axis: each run's kimg is shifted by the
# summed maximum kimg of the runs listed before it in `jsonl_paths`.
# NOTE(review): this presumes every run's kimg counter restarts from zero;
# confirm against how the runs were resumed.

# Keep the untouched per-run value in its own column so this cell can be
# re-run without stacking the offset on top of an already shifted `kimg`.
if "kimg_in_run" not in df.columns:
    df["kimg_in_run"] = df["kimg"].astype(int)

# Run order is the position of the source file in `jsonl_paths`.
df["training_order"] = df["json_path"].map(jsonl_paths.index)

# Offset of run i = sum of the max kimg of runs 0..i-1 (0 for the first run).
# For two runs this is the first run's max kimg added to the second run.
_run_max = df.groupby("training_order")["kimg_in_run"].max().sort_index()
_run_offset = _run_max.cumsum().shift(1, fill_value=0)
df["kimg"] = df["kimg_in_run"] + df["training_order"].map(_run_offset)

In [3]:
# Wide table: one row per kimg, one column per metric, each cell holding the
# metric's "mean" value (rows sharing a kimg/metric pair are averaged).
df_pivot = pd.pivot_table(
    df,
    values="mean",
    index="kimg",
    columns="metric",
    aggfunc="mean",
)

In [4]:
# Folder the figures are exported to. Create it up front so the image export
# does not fail when the folder is missing.
FIG_DIR = Path("../data/figs")
FIG_DIR.mkdir(parents=True, exist_ok=True)


def _plot_metrics(pivot, traces, title, yaxis_title, filename, **layout_extra):
    """Draw the given metric columns against kimg, show the figure and save it.

    Args:
        pivot: frame indexed by kimg with one column per metric.
        traces: mapping of metric column name -> legend label; columns that
            are missing from `pivot` are silently skipped.
        title: figure title.
        yaxis_title: label of the y axis.
        filename: image file name, written inside FIG_DIR.
        **layout_extra: additional keyword arguments forwarded to
            `update_layout` (e.g. `yaxis_tickformat`).

    Returns:
        The plotly figure.
    """
    fig = go.Figure()
    for column, label in traces.items():
        if column in pivot.columns:
            fig.add_trace(go.Scatter(x=pivot.index, y=pivot[column],
                                     mode="lines", name=label))
    fig.update_layout(
        title=title,
        xaxis_title="kimg",
        yaxis_title=yaxis_title,
        template="plotly_white",
        legend=dict(x=0, y=1),
        **layout_extra,
    )
    fig.show()
    fig.write_image(str(FIG_DIR / filename))
    return fig


# Plot G and D loss together
fig1 = _plot_metrics(
    df_pivot,
    {"Loss/G/loss": "Generator Loss", "Loss/D/loss": "Discriminator Loss"},
    title="Generator and Discriminator Loss",
    yaxis_title="Loss",
    filename="Generator and Discriminator Loss.png",
)

# Plot scores fakes and real together
fig2 = _plot_metrics(
    df_pivot,
    {"Loss/scores/fake": "Loss/scores/fake", "Loss/scores/real": "Loss/scores/real"},
    title="Discriminator Real/Fake Scores",
    yaxis_title="Score",
    filename="Discriminator Real-Fake Scores.png",
)

# Plot pl_penalty alone
fig3 = _plot_metrics(
    df_pivot,
    {"Loss/pl_penalty": "Loss/pl_penalty"},
    title="Loss pl_penalty",
    yaxis_title="Penalty",
    filename="Loss pl_penalty.png",
)

# Plot progress augment alone (y axis ticks use the 'p' percentage format)
fig4 = _plot_metrics(
    df_pivot,
    {"Progress/augment": "Progress/augment"},
    title="Augment Probability during Training",
    yaxis_title="Augment Probability",
    filename="Augment Probability during Training.png",
    yaxis_tickformat='p',
)
