In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import os
import numpy as np

In [None]:
# Show the current font fallback chains before modifying them.
print(f"{plt.rcParams['font.sans-serif'] = }")
print(f"{plt.rcParams['font.monospace'] = }")
# Prepend the desired fonts to the fallback chains. Any existing occurrence is
# filtered out first so that re-running this cell (rcParams persist across cell
# executions in the kernel) does not keep stacking duplicate entries onto the
# front of the list.
plt.rcParams['font.sans-serif'] = ['Helvetica'] + [f for f in plt.rcParams['font.sans-serif'] if f != 'Helvetica']
plt.rcParams['font.monospace'] = ['Berkeley Mono'] + [f for f in plt.rcParams['font.monospace'] if f != 'Berkeley Mono']

In [None]:
# plt.style.use(style='ggplot')

In [None]:
# Load the training log of every checkpoint directory matching the whitelist.
# A directory is included when any whitelist entry is a *substring* of its name,
# so e.g. "modern-800" would also match "modern-800-v2".
whitelist = ["delenda", "modern-800", "oculist-800", "insurmountable-700"]
# whitelist = ["alembic"]
PATH = "../../bullet/checkpoints"
# Batches per superbatch, used to flatten (superbatch, batch) into one running
# batch counter. NOTE(review): must match the trainer config and the 6104 used
# to scale the x-axis when plotting — confirm.
BATCHES_PER_SUPERBATCH = 6104

dfs: list[tuple[str, pd.DataFrame]] = []
# Sort the listing: os.listdir order is arbitrary and filesystem-dependent, and
# load order decides each run's colour and legend position in the plot.
for file in sorted(os.listdir(PATH)):
    if not any(name in file for name in whitelist):
        continue
    file_path = f"{PATH}/{file}/log.txt"
    print(f"{file_path = }")

    # Each line is expected to be three comma-separated fields of the form
    # "superbatch:N", "batch:M", "loss:X"; read them as strings and strip the
    # "key:" prefixes below before converting to numbers.
    df = pd.read_csv(file_path, header=None, names=["superbatch", "batch", "loss"], dtype=str) # type: ignore

    assert isinstance(df, pd.DataFrame)

    print(f"lines = {len(df)}")

    df["superbatch"] = df["superbatch"].str.removeprefix("superbatch:").astype(int)
    df["batch"] = df["batch"].str.removeprefix("batch:").astype(int)
    df["loss"] = df["loss"].str.removeprefix("loss:").astype(float)

    # Flatten (superbatch, batch) into a single running batch index; the "- 1"
    # implies superbatch numbering in the log starts at 1.
    df["total_batch"] = (df["superbatch"] - 1) * BATCHES_PER_SUPERBATCH + df["batch"]
    df = df.drop(columns=["superbatch", "batch"])

    dfs.append((file, df))

# Plot a centred rolling mean of the per-row loss for every loaded run.
fig, ax = plt.subplots(figsize=(15, 12)) # type: ignore

CONV = 551  # rolling-mean window length, in log rows (batches)
HALF = (CONV - 1) // 2  # offset of the window centre from its start
# Leading rows dropped before smoothing. NOTE(review): presumably to hide the
# noisy start of training — confirm.
SKIP = 500

colors = ["red", "blue", "orange", "green"]

for i, (file, df) in enumerate(dfs):
    loss = df["loss"].iloc[SKIP:].to_numpy()
    # np.convolve(..., "valid") swaps its arguments when the window is longer
    # than the data and returns meaningless values, so skip short runs.
    if len(loss) < CONV:
        print(f"skipping {file}: {len(loss)} rows after the first {SKIP}, need at least {CONV}")
        continue
    smoothed = np.convolve(loss, np.ones(CONV) / CONV, "valid")
    # "valid" yields len(loss) - CONV + 1 points and point k averages
    # loss[k : k + CONV]; align it with the window centre. Slicing to
    # len(smoothed) (rather than [HALF:-HALF]) keeps x and y the same length
    # for any CONV, including even windows and CONV == 1.
    x = df["total_batch"].iloc[SKIP:].iloc[HALF : HALF + len(smoothed)] / 6104  # batches per superbatch
    ax.plot(  # type: ignore
        x,
        smoothed,
        label=f"{file} rolling mean loss ({CONV}-batch window)",
        alpha=0.6,
        # Cycle through the palette so more runs than colours can't crash.
        color=colors[i % len(colors)],
    )

# Adding labels and title
ax.set_xlabel("Superbatch") # type: ignore
ax.set_ylabel('Loss') # type: ignore
ax.set_title('Experiment loss over time') # type: ignore
if ax.lines:  # avoid the "no artists with labels" warning when nothing was plotted
    ax.legend() # type: ignore
ax.grid(True) # type: ignore

# Show the plot
plt.show() # type: ignore