In [1]:
import wandb
import warnings
import pandas as pd
from plotly.subplots import make_subplots
import plotly.graph_objects as go

api = wandb.Api()
# `api.runs` returns a lazily-paginated collection; iterating over it (the loop
# below) loads further pages on demand, so every run in the project is visited.
runs = api.runs("lucyfarnik/jacobian_saes2_jac_sweep")

# Summary metrics collected for each run. Each becomes a DataFrame column named
# after the part of the key following the "/" (e.g. "mse", "mse2", "jac_l0").
metrics = [
    # reconstruction of SAE1
    "reconstruction_quality/mse", "reconstruction_quality/cossim",
    "reconstruction_quality/explained_variance",
    "model_performance_preservation/ce_loss_score",
    "sparsity/dead_features", "shrinkage/l2_ratio",

    # reconstruction of SAE2
    "reconstruction_quality/mse2", "reconstruction_quality/cossim2",
    "reconstruction_quality/explained_variance2",
    "model_performance_preservation/ce_loss_score_double",
    "sparsity/dead_features2", "shrinkage/l2_ratio2",

    # jacobian sparsity
    "jacobian_sparsity/jac_l0", "jacobian_sparsity/jac_l1", "losses/jacobian_loss",
]

data = []
missing_metrics = []  # (run name, metric key) pairs absent from a run's summary
for run in runs:
    # Only completed runs are compared; crashed/running runs are skipped.
    if run.state != "finished":
        continue
    run_data = {
        "jacobian_coefficient": run.config["jacobian_coefficient"],
        "expansion_factor": run.config["expansion_factor"],
    }
    for m in metrics:
        # `.get` instead of `[...]`: one run missing a metric should yield a
        # NaN cell, not abort the whole fetch with a KeyError.
        value = run.summary.get(m)
        if value is None:
            missing_metrics.append((run.name, m))
        run_data[m.split("/")[1]] = value
    data.append(run_data)

# Check pagination *after* iterating: before the first page is loaded the
# paginator reports `more` as True, so checking up front always warned.
if runs.more:
    warnings.warn("You're not fetching all of the runs.\n\n")
if missing_metrics:
    warnings.warn(
        f"{len(missing_metrics)} summary values were missing and stored as None; "
        f"first few (run, metric): {missing_metrics[:3]}"
    )

df = pd.DataFrame(data)





In [2]:
# Inclusive range of jacobian_coefficient values to plot.
jac_range = [0, 10_000]
n_cols = 3
# Ceil division: enough subplot rows for every metric, whatever len(metrics) is.
n_rows = (len(metrics) + n_cols - 1) // n_cols

# The coefficient filter does not depend on the expansion factor, so apply it once.
df_in_range = df[df["jacobian_coefficient"].between(jac_range[0], jac_range[1])]

# One figure pooling all expansion factors ("any"), then one per expansion factor
# (sorted so figures always appear in the same order).
for expansion_factor in ["any", *sorted(df["expansion_factor"].unique())]:
    if expansion_factor == "any":
        df_filtered = df_in_range
    else:
        df_filtered = df_in_range[df_in_range["expansion_factor"] == expansion_factor]

    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=metrics)
    for i, m in enumerate(metrics):
        metric = m.split("/")[1]  # column name used when the DataFrame was built
        row, col = i // n_cols + 1, i % n_cols + 1
        fig.add_trace(
            go.Scatter(
                x=df_filtered["jacobian_coefficient"],
                y=df_filtered[metric],
                mode="markers",
                name=metric,
            ),
            row=row,
            col=col,
        )
        # Log-scale x-axis, since the coefficient sweep spans orders of magnitude.
        # NOTE: runs with jacobian_coefficient == 0 (inside jac_range) cannot be
        # placed on a log axis, so they are not shown in these plots.
        fig.update_xaxes(type="log", row=row, col=col)
    fig.update_layout(
        title=f"Jacobian sparsity sweep (Expansion factor: {expansion_factor})",
        showlegend=False,
        height=240 * n_rows,  # 240px per row (1200px for the 5-row grid)
    )
    fig.show()