In [27]:
import wandb
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from copy import deepcopy
from tqdm.notebook import tqdm

In [28]:
# Client for the W&B public API; used in the next cell to look up sweeps (and, through them, runs).
api = wandb.Api()

In [59]:
# W&B location of the sweeps to analyse -- edit these constants (not the lookup below)
# to point the notebook at another project or at additional sweeps.
WANDB_ENTITY = "project-avengers"
WANDB_PROJECT = "goon-test"
sweep_ids = ("6snt47r7",)
sweeps = [api.sweep(f"{WANDB_ENTITY}/{WANDB_PROJECT}/sweeps/{sweep_id}") for sweep_id in sweep_ids]

In [None]:
# Collect one record per finished run: a copy of its config plus the loss logged at its final step.
results = []
for sweep in sweeps:
    for run in tqdm(sweep.runs):
        if run.state != "finished":
            continue  # runs in any other state (e.g. crashed or still running) are skipped
        # Copy so that adding "loss" below does not touch the run object's own config.
        config = deepcopy(run.config)
        # Assumes the training script logs "loss" at step == num_steps as its last logged step,
        # so scanning from that step onwards must yield exactly one row.
        final_rows = list(run.scan_history(min_step=config["num_steps"], keys=["loss"]))
        assert len(final_rows) == 1, (
            f"run {run.id}: expected exactly 1 'loss' row at step >= {config['num_steps']}, "
            f"got {len(final_rows)}"
        )
        config["loss"] = final_rows[0]["loss"]
        results.append(config)

  0%|          | 0/132 [00:00<?, ?it/s]

In [None]:
# Fail fast: with no finished runs the frame would have no "loss" column and the
# filtering/plotting cells below would break with an unhelpful AttributeError.
assert results, "no finished runs found in the selected sweeps"
# One row per finished run; columns are the union of config keys plus "loss".
results_df = pd.DataFrame(results)
results_df.head()

In [None]:
# Diverged runs may report their loss either as the string "NaN" or as a float NaN.
# Comparing against the string alone lets float NaNs through, so coerce to numeric first:
# both spellings (and any other non-numeric value) become NaN and are dropped.
# The coerced column replaces "loss", giving a float dtype for the log-scale plots.
loss_numeric = pd.to_numeric(results_df["loss"], errors="coerce")
non_nan_results_df = results_df[loss_numeric.notna()].assign(loss=loss_numeric)
print(f"Removed {len(results_df) - len(non_nan_results_df)} NaNs")

In [None]:
# Loss vs. learning rate on log-log axes, one line per d_model:
# left panel = runs with mup enabled, right panel = runs without it (titled "sp").
fig, axs = plt.subplots(ncols=2, nrows=1, sharey=True, sharex=True, figsize=(6 * 2, 4))
is_mup = non_nan_results_df.mup
for ax, panel_mask, panel_title in ((axs[0], is_mup, "mup"), (axs[1], ~is_mup, "sp")):
    sns.lineplot(data=non_nan_results_df[panel_mask], x="learning_rate", y="loss", hue="d_model", ax=ax)
    ax.set(title=panel_title, xscale="log", yscale="log")
# Keep the original handles for later cells (sns.lineplot returns the Axes it drew on).
mup_plot, no_mup_plot = axs

# Hyper-parameters expected to be identical for every plotted run; they label the figure and
# its output filename. Read them from the plotted data rather than from the `config` variable
# left over from the collection loop (that is only the last finished run, and is undefined on a
# fresh kernel if no run finished). Refuse to mislabel the figure if they actually vary.
SHARED_KEYS = ("n_layer", "seq_length", "head_dim", "num_steps", "batch_size", "acc_steps", "optim")
shared_values = {}
for key in SHARED_KEYS:
    distinct = non_nan_results_df[key].unique()
    assert len(distinct) == 1, f"{key!r} is not constant across plotted runs: {list(distinct)}"
    shared_values[key] = distinct[0]
suptitle = "_".join(f"{key}-{shared_values[key]}" for key in SHARED_KEYS)

fig.suptitle(suptitle)
fig.subplots_adjust(top=0.8)
fig.savefig(suptitle + ".png", dpi=256, bbox_inches="tight")