In [28]:
import pandas as pd
import plotly.express as px
import wandb

In [29]:
def export_project(project_name: str, entity: str = "smtb2023") -> pd.DataFrame:
    """Load all runs of a wandb project into a flat DataFrame, one row per run.

    Each row merges the run's summary values with its config entries (config
    keys starting with "_" are dropped) and stores the run's name under the
    ``"name"`` column. On a key collision the config value overrides the
    summary value, and ``"name"`` overrides both.

    Args:
        project_name: Name of the wandb project to export.
        entity: wandb entity that owns the project.

    Returns:
        DataFrame with one row per run; its columns are the union of all
        summary/config keys seen across runs, plus ``"name"``.
    """
    api = wandb.Api()
    runs = api.runs(f"{entity}/{project_name}")

    rows = []
    for run in runs:
        # NOTE: `_json_dict` is a private attribute of the run summary object.
        row = dict(run.summary._json_dict)
        row.update({k: v for k, v in run.config.items() if not k.startswith("_")})
        # Read the name straight from the run. Going through an intermediate
        # DataFrame and `iterrows()` is a trap: on the yielded Series,
        # `row.name` is the row's index label, not the "name" column.
        row["name"] = run.name
        rows.append(row)
    return pd.DataFrame(rows)

In [33]:
# Fetch the runs of the "stability" project as a DataFrame.
df = export_project(project_name="stability")

In [34]:
# get layer progress (layer / total_layers)
# The total layer count is parsed from model_name: take the second
# "_"-separated token, drop its first character and read the rest as an int
# (presumably names look like "<family>_t<N>_..." — TODO confirm).
total_layers = df["model_name"].str.split("_").str[1].str[1:].astype(int)
df["layer_prog"] = df["layer_num"] / total_layers

# Sort on both keys in a single call. Two consecutive single-column sorts do
# not compose: the default sort algorithm is not stable, so sorting by
# layer_num afterwards would discard the model_name ordering.
df = df.sort_values(["model_name", "layer_num"])

In [35]:
# One line per model_name: the "test/pearson" column against layer progress.
layer_fig = px.line(
    data_frame=df,
    x="layer_prog",
    y="test/pearson",
    color="model_name",
)
layer_fig