In [1]:
import pandas as pd 
import matplotlib.pyplot as plt
import seaborn as sns
import wandb

In [2]:
# Alternative kept for reference: pass the key and a longer timeout explicitly.
# Never commit a real key in place of the placeholder.
# api = wandb.Api(api_key="<API_KEY>",timeout=50)
# No arguments are passed here, so credential lookup is left to wandb's defaults
# (presumably WANDB_API_KEY or a prior `wandb login` -- confirm for your setup).
api = wandb.Api()

In [12]:
def extract_runs(api: "wandb.Api", workspace: str = "ludekcizinsky/seizure-prediction"):
    """Collect per-epoch validation metrics of cross-validation runs, averaged over folds.

    Only runs tagged with ``fold_0``, ``fold_1`` or ``fold_2`` are considered.
    Runs are grouped by a configuration key built from their remaining tags
    (everything not starting with ``fold_`` or ``part``), sorted and joined
    with ``_``. For every ``(key, epoch)`` pair the metrics of all matching
    history rows are averaged.

    Args:
        api: W&B API client; only ``api.runs(path)`` is used, and each returned
            run must expose ``tags``, ``config["fold_id"]``, ``name`` and
            ``scan_history()``.
        workspace: ``"entity/project"`` path whose runs are scanned.

    Returns:
        DataFrame with columns ``key, epoch, val_f1, val_f1_plus, val_f1_neg,
        val_accuracy, val_loss, name``. ``name`` is the run name of the first
        row of each group after sorting by ``key`` and ``fold_id``. The frame
        is empty (same columns) when no history row qualifies.

    Raises:
        KeyError: if a cross-validation run has no ``fold_id`` in its config.
    """
    # Output column -> key under which the metric is logged in the run history.
    # The order here is the column order of the returned frame.
    metric_keys = {
        "val_f1": "val/f1_macro",
        "val_f1_plus": "val/f1_class_1",
        "val_f1_neg": "val/f1_class_0",
        "val_accuracy": "val/acc",
        "val_loss": "val/loss",
    }
    metric_columns = list(metric_keys)
    output_columns = ["key", "epoch", *metric_columns, "name"]
    cross_val_tags = {"fold_0", "fold_1", "fold_2"}

    records = []
    # Honour the caller's workspace instead of a hard-coded project path.
    for run in api.runs(workspace):
        tags = list(run.tags)
        if cross_val_tags.isdisjoint(tags):
            continue

        config_tags = [
            tag for tag in tags
            if not tag.startswith("fold_") and not tag.startswith("part")
        ]
        key = "_".join(sorted(config_tags))
        fold_id = run.config["fold_id"]

        for step in run.scan_history():
            # Skip rows without a validation F1. Test for None explicitly:
            # a truthiness check would also drop a legitimate F1 of 0.0 and
            # bias the fold average upwards.
            if step.get("val/f1_macro") is None:
                continue
            record = {
                "key": key,
                "name": run.name,
                "fold_id": fold_id,
                "epoch": step.get("epoch", None),
            }
            for column, history_key in metric_keys.items():
                record[column] = step.get(history_key, None)
            records.append(record)

    if not records:
        # sort_values/groupby would raise KeyError on a frame with no columns.
        return pd.DataFrame(columns=output_columns)

    runs_df = pd.DataFrame(records)
    # Metrics missing from every row would otherwise form an object column of
    # None on which "mean" fails; non-numeric entries become NaN and are
    # ignored by the mean.
    runs_df[metric_columns] = runs_df[metric_columns].apply(pd.to_numeric, errors="coerce")

    runs_df = runs_df.sort_values(by=["key", "fold_id"])
    aggregations = {column: "mean" for column in metric_columns}
    aggregations["name"] = "first"
    # Note: groupby drops rows whose epoch is missing (NaN group key).
    runs_df = runs_df.groupby(["key", "epoch"]).agg(aggregations).reset_index()
    return runs_df

In [13]:
# Fetch the fold-averaged validation metrics (one row per configuration key and epoch).
# This scans the full history of every cross-validation run, so it can take a while.
runs_df = extract_runs(api)

In [16]:
# For every configuration key, keep the epoch with the highest fold-averaged
# macro F1, then rank the configurations from best to worst.
# idxmax() returns index *labels*, so rows are selected with .loc (not .iloc),
# and only the val_f1 column is reduced instead of every column of the frame.
(
    runs_df
    .loc[runs_df.groupby("key")["val_f1"].idxmax()]
    .sort_values(by="val_f1", ascending=False)
)

Unnamed: 0,key,epoch,val_f1,val_f1_plus,val_f1_neg,val_accuracy,val_loss,name
98,baseline_conv1d_fft,474,0.840546,0.738908,0.942185,0.905334,0.711557,tough-cherry-335
272,baseline_fft_tencoder,824,0.835179,0.727594,0.942764,0.905411,1.091691,denim-sun-329
230,baseline_fft_lstm,774,0.835041,0.733113,0.936968,0.898022,0.379498,dark-dew-328
523,dist_1_5_lstm learned_pool_window_gcn,47,0.795578,0.6596,0.931555,0.886054,0.330168,apricot-fog-389
547,dist_1_5_lstm_att learned_pool_window_gcn,45,0.790375,0.652393,0.928356,0.881205,0.320345,fresh-blaze-396
474,dist_1_5_kaggle_lstm learned_pool_window_gcn,49,0.786747,0.645777,0.927718,0.879935,0.341065,gentle-firefly-387
498,dist_1_5_kaggle_lstm_att learned_pool_window_gcn,47,0.784797,0.648649,0.920945,0.87093,0.323083,generous-terrain-394
431,dist_1_5_fft_gcn_learned_pool,549,0.7825,0.641074,0.923926,0.874471,0.403612,clean-haze-367
730,fft_gcn_learned_pool,899,0.781911,0.641377,0.922445,0.87247,1.078126,desert-shape-343
359,corr_fft_gcn_learned_pool,749,0.776973,0.633556,0.92039,0.869237,0.744131,lucky-hill-362


In [10]:
# Highest fold-averaged validation macro F1 of this configuration over all epochs.
runs_df.loc[runs_df["key"].eq("dist_0_5_fft_gcn_learned_pool"), "val_f1"].max()

np.float64(0.7605843941370646)

In [11]:
# Highest fold-averaged validation macro F1 of this configuration over all epochs.
runs_df.loc[runs_df["key"].eq("dist_1_5_fft_gcn_learned_pool"), "val_f1"].max()

np.float64(0.7825000087420145)

In [12]:
# Highest fold-averaged validation macro F1 of this configuration over all epochs.
runs_df.loc[runs_df["key"].eq("fft_gcn_learnable_adj_learned_pool"), "val_f1"].max()

np.float64(0.7648681402206421)

In [13]:
# Highest fold-averaged validation macro F1 of this configuration over all epochs.
runs_df.loc[runs_df["key"].eq("corr_fft_gcn_learned_pool"), "val_f1"].max()

np.float64(0.7769731879234314)

In [14]:
# Highest fold-averaged validation macro F1 of this configuration over all epochs.
runs_df.loc[runs_df["key"].eq("fft_gcn_learned_pool"), "val_f1"].max()

np.float64(0.7819106976191202)