In [None]:
import json

def load(filename):
    """Read a timings JSON file and flatten it into a list of row dicts.

    The file must contain a JSON object whose values are lists of records;
    every record needs a ``"seconds"`` entry and a ``"params"`` mapping.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        A list with one dict per record. Each dict holds ``"x"`` (the
        top-level key the record was listed under), ``"seconds"`` (copied
        unchanged from the record) and every key/value of the record's
        ``"params"``. Because the params are unpacked last, a param named
        ``"x"`` or ``"seconds"`` overrides the corresponding entry.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If a record lacks ``"seconds"`` or ``"params"``.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, so the same file loads identically on every machine.
    with open(filename, "r", encoding="utf-8") as f:
        raw = json.load(f)

    return [
        {"x": x, "seconds": d["seconds"], **d["params"]}
        for x, records in raw.items()
        for d in records
    ]

# Flattened timing rows, one list per input file.
data_main = load(filename="main_timings.json")
data_pmain = load(filename="pmain_timings.json")

In [None]:
import pandas as pd

# Combine both timing sets into one long-format frame with one row per
# individual measurement.
pdf = (
    # ignore_index: each input frame carries its own 0..n-1 labels, which
    # would otherwise be repeated in the combined frame.
    pd.concat([pd.DataFrame(data_main), pd.DataFrame(data_pmain)], ignore_index=True)
    # One row per element of "seconds" (presumably a list of repeated
    # timings per record -- verify against the JSON files).
    .explode("seconds")
    # explode() leaves the column with object dtype; make it numeric.
    .astype({"seconds": float})
    .sort_values(by="x")
    # Exploded rows share their source row's label; rebuild a unique index.
    .reset_index(drop=True)
)
pdf.head()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from tueplots import bundles
from tueplots import figsizes, fontsizes
import numpy as np

# rcParams bundle from tueplots for the NeurIPS 2024 style.
NEURIPS_FORMAT = bundles.neurips2024(rel_width=0.5)

# Apply the seaborn theme first and the bundle second, so the bundle's
# settings win wherever both touch the same rcParams key.
sns.set_theme("paper", style="whitegrid")
plt.rcParams.update(NEURIPS_FORMAT)

# Create the main figure. The explicit figsize takes precedence over any
# figure size configured through rcParams above.
fig = plt.figure(figsize=(2.75, 2.4))

# Continuous viridis colormap (as_cmap=True returns a callable colormap
# rather than a list of colours); sampled into discrete colours when plotting.
palette = sns.color_palette("viridis", as_cmap=True)

def discrete_palette(p, n, margin=0.01):
    """Evaluate a callable palette at ``n`` evenly spaced positions.

    Args:
        p: Callable applied to each sample position (for example a colormap).
        n: Number of positions to sample.
        margin: Distance kept from both ends of the unit interval; positions
            span ``margin`` to ``1 - margin`` inclusive.

    Returns:
        A list with the result of ``p`` at each position, in increasing
        position order.
    """
    positions = np.linspace(margin, 1 - margin, n)
    return list(map(p, positions))

# Take the axes from the figure itself and pass it to every call below,
# instead of relying on pyplot's implicit "current axes" between calls.
ax = fig.gca()

# All series except the two layer-count ones share three viridis colours.
sns.lineplot(
    data=pdf[~pdf["x"].isin(["vL", "L"])],
    x="mul", hue="x", y="seconds", marker="o",
    palette=discrete_palette(palette, 3), ax=ax,
)
# The "L" and "vL" series are drawn separately with the rocket palette;
# "vL" is additionally dashed.
sns.lineplot(
    data=pdf[pdf["x"] == "L"],
    x="mul", hue="x", y="seconds", marker="o",
    palette="rocket", ax=ax,
)
sns.lineplot(
    data=pdf[pdf["x"] == "vL"],
    x="mul", hue="x", y="seconds", marker="o", dashes=(2, 2),
    palette="rocket", ax=ax,
)

# The title carries the unit, so the y-axis label is left empty.
ax.set_title("Seconds per epoch")
ax.set_ylabel("")
ax.set_xlabel("Multiplicative Factor")
ax.set_ylim(0.5, 4.0)
ax.set_xlim(0.8, 6.2)

# Tick labels are wrapped in $...$ so they are typeset as math.
yticks = [1.0, 2.0, 3.0, 4.0]
ax.set_yticks(yticks, labels=[f"${tick}$" for tick in yticks], rotation=90, va="center", ha="left")
xticks = [1, 2, 3, 4, 5, 6]
ax.set_xticks(xticks, labels=[f"$\\times{tick}$" for tick in xticks])

# Rebuild the legend from the handles already on the axes, swapping in
# readable labels.
# NOTE(review): the label order must match the handle order, i.e. the order
# in which the hue levels were drawn above -- verify against the data keys.
ax.legend(
    loc="upper left",
    handles=ax.get_legend().legend_handles,
    labels=[r"Batch size", r"Network width", r"T", r"\# of layers", r"\# of layers (\textit{vmap})"],
)
fig.savefig("parallel.pdf")