In [1]:
import pandas as pd
import plotly.express as px
import wandb
import plotly.graph_objects as go

In [2]:
def export_project(project_name: str, entity: str = "smtb2023") -> pd.DataFrame:
    """Load all runs of a wandb project into one flat DataFrame.

    Each row is one run and is built in this order: the run's summary
    metrics, then its config (keys starting with "_" are dropped), then a
    "name" column holding the run name. Later sources overwrite earlier ones
    when keys collide.

    Args:
        project_name: wandb project to read.
        entity: wandb entity that owns the project (defaults to "smtb2023").

    Returns:
        DataFrame with one row per run; columns are the union of all keys.
    """
    api = wandb.Api()
    runs = api.runs(f"{entity}/{project_name}")

    rows = []
    for run in runs:
        # Copy so that later updates never mutate wandb's own summary dict.
        row = dict(run.summary._json_dict)
        row.update({k: v for k, v in run.config.items() if not k.startswith("_")})
        # Assign the run name directly. Reading it back from a pandas row via
        # `row.name` would return the Series' index label, not this value.
        row["name"] = run.name
        rows.append(row)
    return pd.DataFrame(rows)

In [3]:
# Load every run of the "stability" project and add `layer_prog`: the probed
# layer index divided by the model's layer count. Assumes every model_name
# looks like "esm2_t33_650M_UR50D", i.e. the second "_"-separated token is
# "t<layers>" (here 33). TODO confirm for all models in the project.
#
# Sort with both keys in a single call. Two consecutive sort_values calls do
# not compose (the default quicksort is not stable), so sorting by model_name
# and then by layer_num silently discarded the model_name order.
df = (
    export_project("stability")
    .assign(
        layer_prog=lambda d: d["layer_num"]
        / d["model_name"].str.split("_").str[1].str[1:].astype(int)
    )
    .sort_values(["layer_num", "model_name"])
)

In [18]:
# Keep the runs of model "esm2_t33_650M_UR50D" that logged a val/pearson value,
# restricted to the columns used by the plot.
t = df.loc[
    df["model_name"].eq("esm2_t33_650M_UR50D") & df["val/pearson"].notna(),
    ["model_name", "layer_prog", "val/pearson"],
]

In [19]:
# Aggregate val/pearson over all runs at each relative layer depth (layer_prog):
# the mean is drawn as a line, the min-max spread across runs as a shaded band.
grouped = t.groupby("layer_prog")["val/pearson"].agg(["mean", "min", "max"])

fig = go.Figure()

# Shaded min-max band: trace the lower edge left-to-right, then the upper edge
# right-to-left, and let fill="toself" close the polygon. Added first so the
# mean line is drawn on top of the band.
fig.add_trace(
    go.Scatter(
        x=grouped.index.tolist() + grouped.index.tolist()[::-1],
        y=grouped["min"].tolist() + grouped["max"].tolist()[::-1],
        fill="toself",
        fillcolor="rgba(0,100,240,0.2)",
        line=dict(color="rgba(255,255,255,0)"),
        name="Min-Max Range",
    )
)

# Line for the mean value.
fig.add_trace(
    go.Scatter(
        x=grouped.index,
        y=grouped["mean"],
        mode="lines",
        name="Mean",
        line=dict(color="blue"),
    )
)

fig.update_layout(
    title="Layer Progress vs. Pearson Value",
    xaxis_title="Layer Progress",
    yaxis_title="Pearson Value",
    showlegend=True,
)

fig.show()

In [20]:
# Display the filtered runs (model_name, layer_prog, val/pearson) that the plot aggregates.
t

Unnamed: 0,model_name,layer_prog,val/pearson
134,esm2_t33_650M_UR50D,0.0,0.385732
124,esm2_t33_650M_UR50D,0.0,0.37914
183,esm2_t33_650M_UR50D,0.0,0.371838
115,esm2_t33_650M_UR50D,0.0,0.382675
117,esm2_t33_650M_UR50D,0.0,0.382913
121,esm2_t33_650M_UR50D,0.0,0.380921
96,esm2_t33_650M_UR50D,0.030303,0.476102
180,esm2_t33_650M_UR50D,0.030303,0.409198
103,esm2_t33_650M_UR50D,0.030303,0.511602
98,esm2_t33_650M_UR50D,0.030303,0.508983
