# Fetch APE results from Weights & Biases (wandb)

In [None]:
import wandb
import pandas as pd
import numpy as np

# Display names for the model identifiers that appear in the W&B run names.
# The iteration order of this dict is the order in which models are queried.
model_label = {
    "ViT": "ViT",
    "plstm": "pLSTM",
    "plstmPonly": "pLSTM / (P-mode)",
    "plstmDonly": "pLSTM / (D-mode)",
    "plstmNoPosEmb": "pLSTM / (no pos emb.)",
    "plstmConstSTM": "pLSTM / (STM bias only)",
    "ViL": "ViL",
    "2DMamba": "2DMamba",
    "Mamba2D": "Mamba2D",
    "EfficientNet": "EfficientNet",
}
lrs = [3e-5, 1e-4, 3e-4, 1e-3]
# Tolerance for the range-based learning-rate filter that is commented out in
# `filters` below; unused while the exact-match filter is active.
eps = 1e-7
metric_keys = ["val/acc1", "val_ext/acc1", "train/acc1"]
# Runs whose summary does not contain this key are skipped.
check_finished_metrics = "test/acc1"
# True: store each metric's logged history; False: store the summary values.
full_curve = True

entity = "poeppel"
project = "plstm"
# A single API client serves every query below.
api = wandb.Api()

# results[metric_key][(model, lr)] -> DataFrame indexed by `_step` with one
# "seed<N>" column per run (full_curve) or a list of summary values.
# Combinations without any usable run get no entry.
results = {metric_key: {} for metric_key in metric_keys}

for model in model_label:
    for lr_value in lrs:
        name_pattern = f"ape_(spl|trc)_{model}_s4[3-9].*"

        # Construct filters: name regex and specific learning rate
        filters = {
            "$and": [
                {"display_name": {"$regex": name_pattern}},
                {"config.optimizer.learning_rate.peak_value": lr_value}, # {"$and": [{"$lt": lr_value+eps}, {"$gt": lr_value-eps}]}},
            ]
        }

        # Query once per (model, lr) and reuse the runs for every metric.
        runs = api.runs(f"{entity}/{project}", filters=filters)
        finished_runs = [
            candidate for candidate in runs
            if check_finished_metrics in candidate.summary
        ]

        for metric_key in metric_keys:
            res = []
            for run in finished_runs:
                if full_curve:
                    df = run.history(keys=[metric_key], pandas=True)
                    # Skip runs that returned no usable history for this metric.
                    if "_step" not in df.columns or metric_key not in df.columns:
                        continue
                    seed_column = "seed" + str(run.config['aux']['seed'])
                    # The frame fetched above already holds `_step` and the
                    # metric, so it is reshaped directly instead of being
                    # downloaded a second time.
                    df = df[["_step", metric_key]].rename(columns={metric_key: seed_column})
                    df = df.set_index("_step")
                    res.append(df)
                else:
                    test_metric = run.summary.get(metric_key)
                    if test_metric is not None:
                        res.append(test_metric)

            if res:
                if full_curve:
                    results[metric_key][(model, lr_value)] = pd.concat(res, axis=1)
                else:
                    results[metric_key][(model, lr_value)] = res

In [None]:
import pickle as pkl

# To refresh the cache from a freshly fetched `results`, uncomment:
# with open("ape_results_v3.pkl", "wb") as cache_file:
#     pkl.dump(results, cache_file)

# Replace `results` with the pickled copy stored on disk.
# NOTE: unpickling can execute arbitrary code; only load cache files you
# created yourself.
with open("ape_results_v3.pkl", mode="rb") as cache_file:
    results = pkl.load(cache_file)

In [None]:
# Display the stored train/acc1 entry for plstmConstSTM at learning rate 1e-4.
results['train/acc1'][('plstmConstSTM', 1e-4)]

In [None]:
# List the (model, lr) keys stored under train/acc1 whose model is "ViL".
[model_lr for model_lr in results['train/acc1'] if model_lr[0] == "ViL"]

In [None]:
import matplotlib.pyplot as plt
metric_keys = ["val_ext/acc1", "train/acc1", "val/acc1"]
# Value plotted at step 0, before the first logged point.
random_init = 0.5

metric_axes_label = {
    "val_ext/acc1": "Val. Acc. (Ext.)",
    "val/acc1": "Val. Acc.",
    "train/acc1": "Train. Acc.",
}

# (model, learning rate) pairs to plot, in legend order.
optimal_model_lrs = [
    ("plstm", 1e-4), ("plstmNoPosEmb", 1e-4),
    ("plstmPonly", 1e-4), ("plstmDonly", 1e-4), ("plstmConstSTM", 1e-4), ("ViT", 3e-4),
    ("ViL", 3e-5),
    ("2DMamba", 1e-4), ("Mamba2D", 1e-4), ("EfficientNet", 1e-3),
]


def _mean_and_band(frame, start_value):
    """Return ``(steps, mean, band)`` for one table of per-seed curves.

    ``frame`` is expected to be indexed by training step with one column per
    seed. A synthetic point is prepended at step 0 with value ``start_value``
    and zero band width. ``band`` is the row-wise standard deviation
    (NumPy default, ddof=0) scaled by ``2.015 / sqrt(4)``.
    """
    # NOTE(review): 2.015 matches the one-sided 95% Student-t quantile for
    # 5 degrees of freedom; with presumably 5 seeds (sqrt(4) divisor) one
    # would expect 4 degrees of freedom (2.132) -- confirm the intended value.
    t_value = 2.015
    sem_divisor = np.sqrt(4)
    per_seed = frame.to_numpy()
    # np.concatenate (rather than np.concat) also works on NumPy < 2.0.
    steps = np.concatenate([np.array([0.0]), np.array(frame.index)])
    mean = np.concatenate([np.array([start_value]), per_seed.mean(axis=1)])
    band = t_value / sem_divisor * np.concatenate([np.array([0.0]), per_seed.std(axis=1)])
    return steps, mean, band


for metric_key in metric_keys:
    fig, ax = plt.subplots(figsize=(2.5,2))

    curves = [
        _mean_and_band(results[metric_key][item], random_init)
        for item in optimal_model_lrs
    ]

    # Mean curves first; keep the line handles so the legend is bound to them
    # explicitly instead of relying on positional label matching.
    line_handles = []
    for steps, mean, _ in curves:
        (line,) = ax.plot(steps, mean)
        line_handles.append(line)
    fig.legend(
        line_handles,
        [model_label[model] for model, _ in optimal_model_lrs],
        loc=(0, 0.7),
        ncols=len(optimal_model_lrs),
    )

    # Shaded uncertainty bands, in the same order as the lines.
    for steps, mean, band in curves:
        ax.fill_between(steps, mean-band, mean+band, alpha=0.3)

    ax.set_xlabel("Training Steps")
    ax.set_ylabel(metric_axes_label[metric_key])
    ax.set_xticks((0, 10000, 20000, 30000, 40000))
    ax.grid(alpha=0.1)
    ax.spines[['right', 'top']].set_visible(False)

    # One file per metric group, e.g. "val_ext/acc1" -> "val_ext.svg".
    fig.savefig(metric_key.split("/")[0] + ".svg")
    fig.show()



In [None]:
# df = run.history(keys=["_step", "val/acc1", "val_ext/acc1"], pandas=True)

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
metric_keys = ["train/acc1", "val/acc1", "val_ext/acc1"]
# Value plotted at step 0, before the first logged point.
random_init = 0.5

# One color per entry of `optimal_model_lrs`, matched by position.
colors = ["tab:blue", "tab:orange", "tab:red", "tab:green", "tab:purple", "tab:brown"]

metric_axes_label = {
    "train/acc1": "Train. Acc.",
    "val/acc1": "Val. Acc.",
    "val_ext/acc1": "Val. Acc. (Ext.)",
}

# (model, learning rate) pairs to plot, in legend order. The commented
# entries are pLSTM variants that are currently left out of this figure.
optimal_model_lrs = [
    ("plstm", 1e-4),
    # ("plstmNoPosEmb", 1e-4),
    # ("plstmPonly", 1e-4), ("plstmDonly", 1e-4),
    # ("plstmConstSTM", 1e-4),
    ("ViL", 3e-5),
    ("2DMamba", 1e-4), ("Mamba2D", 1e-4), ("EfficientNet", 1e-3),
    ("ViT", 3e-4),
]

num_models = len(optimal_model_lrs)

fontsize=14
fontsize_small=12
fontsize_ticks=12


def _mean_and_band(frame, start_value):
    """Return ``(steps, mean, band)`` for one table of per-seed curves.

    ``frame`` is expected to be indexed by training step with one column per
    seed. A synthetic point is prepended at step 0 with value ``start_value``
    and zero band width. ``band`` is the row-wise standard deviation
    (NumPy default, ddof=0) scaled by ``2.015 / sqrt(4)``.
    """
    # NOTE(review): 2.015 matches the one-sided 95% Student-t quantile for
    # 5 degrees of freedom; with presumably 5 seeds (sqrt(4) divisor) one
    # would expect 4 degrees of freedom (2.132) -- confirm the intended value.
    t_value = 2.015
    sem_divisor = np.sqrt(4)
    per_seed = frame.to_numpy()
    # np.concatenate (rather than np.concat) also works on NumPy < 2.0.
    steps = np.concatenate([np.array([0.0]), np.array(frame.index)])
    mean = np.concatenate([np.array([start_value]), per_seed.mean(axis=1)])
    band = t_value / sem_divisor * np.concatenate([np.array([0.0]), per_seed.std(axis=1)])
    return steps, mean, band


with mpl.rc_context(
        rc={
            "text.usetex": False,
            "font.size": fontsize,
            "axes.labelsize": fontsize,
            "figure.labelsize": fontsize,
            "legend.fontsize": fontsize_small,
            "xtick.labelsize": fontsize_ticks,
            "ytick.labelsize": fontsize_ticks,
            "axes.titlesize": fontsize,
            "lines.markersize": 4.0,  # * default: 6.0
        }
    ):

    fig, axs = plt.subplots(1, 3, figsize=(9,2))
    legend_handles = []
    for midx, metric_key in enumerate(metric_keys):
        ax = axs[midx]
        # Pair every curve with its color once, in legend order.
        curves = [
            (_mean_and_band(results[metric_key][item], random_init), colors[modidx])
            for modidx, item in enumerate(optimal_model_lrs)
        ]

        # Draw back-to-front so the first-listed model ends up on top.
        for (steps, mean, band), color in reversed(curves):
            ax.fill_between(steps, mean-band, mean+band, alpha=0.3, color=color)

        panel_lines = []
        for (steps, mean, _), color in reversed(curves):
            (line,) = ax.plot(steps, mean, color=color)
            panel_lines.append(line)
        if midx == 0:
            # Lines were drawn in reverse; restore legend order.
            legend_handles = panel_lines[::-1]

        ax.set_ylim(0.5, 1.)
        ax.set_xlabel("Training Steps (k)")
        ax.set_ylabel(metric_axes_label[metric_key])
        ax.set_xticks(range(0, 50000, 10000))
        ax.set_xticklabels([i//1000 for i in range(0, 50000, 10000)])
        ax.grid(alpha=0.2)
        ax.spines[['right', 'top']].set_visible(False)

    # Explicit handles make the legend independent of the drawing order.
    fig.legend(
        legend_handles,
        [model_label[model] for model, _ in optimal_model_lrs],
        loc=(0.1, 0.7),
        ncols=(num_models-1) //2+1,
    )

    fig.savefig("ape_training_comp.svg")
    fig.show()



In [None]:
# Show the name of `run`, i.e. whichever run object the most recently executed
# fetch loop left bound (requires a fetch cell to have been run in this session).
run.name

In [None]:
# Show the summary of the same leftover `run` object from the fetch loop.
run.summary

In [None]:
# Display the complete `results` dict (metric -> (model, lr) -> data).
results

In [None]:
# runs = api.runs(f"{entity}/{project}")

# Best Final Test Results

In [None]:
import wandb
import pandas as pd
import numpy as np

In [None]:
# Model identifiers as they occur in the run names.
models = [
    "ViT",
    "plstm",
    "plstmNoPosEmb",
    "plstmPonly",
    "plstmDonly",
    "plstmConstSTM",
    "Mamba2D",
    "2DMamba",
    "ViL",
    "EfficientNet",
]

# Identifier -> display label; the pair order fixes the dict's iteration order.
model_label = dict(
    [
        ("ViT", "ViT"),
        ("plstm", "pLSTM"),
        ("plstmPonly", "pLSTM / (P-mode)"),
        ("plstmDonly", "pLSTM / (D-mode)"),
        ("plstmNoPosEmb", "pLSTM / (no pos emb.)"),
        ("plstmConstSTM", "pLSTM / (STM bias only)"),
        ("ViL", "ViL"),
        ("2DMamba", "2DMamba"),
        ("Mamba2D", "Mamba2D"),
        ("EfficientNet", "EfficientNet"),
    ]
)

In [None]:
lrs = [1e-3, 3e-4, 1e-4, 3e-5]

# False: collect the final summary value per run; True: collect full histories.
full_curve = False
test_metric_keys = ["test_ext/acc1", "test/acc1"]

entity = "poeppel"
project = "plstm"
# A single API client serves every query below.
api = wandb.Api()

# results_test[metric_key][(model, lr)] -> list of summary values, one per run
# that reports the metric (possibly empty), or -- with full_curve -- a
# DataFrame indexed by `_step` with one "seed<N>" column per run.
results_test = {metric_key: {} for metric_key in test_metric_keys}

for model in model_label:
    for lr_value in lrs:
        name_pattern = f"ape_(spl|trc)_{model}_s4[3-9].*"

        # Construct filters: name regex and specific learning rate
        filters = {
            "$and": [
                {"display_name": {"$regex": name_pattern}},
                {"config.optimizer.learning_rate.peak_value": lr_value}, # {"$and": [{"$lt": lr_value+eps}, {"$gt": lr_value-eps}]}},
            ]
        }

        # Query once per (model, lr) and reuse the runs for every metric.
        runs = api.runs(f"{entity}/{project}", filters=filters)
        fetched_runs = list(runs)

        for metric_key in test_metric_keys:
            res = []
            for run in fetched_runs:
                if full_curve:
                    df = run.history(keys=[metric_key], pandas=True)
                    # Skip runs that returned no usable history for this metric.
                    if "_step" not in df.columns or metric_key not in df.columns:
                        continue
                    seed_column = "seed" + str(run.config['aux']['seed'])
                    # The frame fetched above already holds `_step` and the
                    # metric, so it is reshaped directly instead of being
                    # downloaded a second time.
                    df = df[["_step", metric_key]].rename(columns={metric_key: seed_column})
                    df = df.set_index("_step")
                    res.append(df)
                else:
                    test_metric = run.summary.get(metric_key)
                    if test_metric is not None:
                        res.append(test_metric)

            if not full_curve:
                # Empty lists are stored on purpose: they mark combinations
                # for which no run reported the metric.
                results_test[metric_key][(model, lr_value)] = res
            elif res:
                # pd.concat raises on an empty list, so only concatenate
                # when at least one curve was collected.
                results_test[metric_key][(model, lr_value)] = pd.concat(res, axis=1)

In [None]:
# Display the collected test metrics (metric -> (model, lr) -> values).
results_test

In [None]:
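# Pre-extracted copy of `results_test` (metric -> (model, lr) -> list of
# per-run values), pasted here so the analysis below can run without
# querying W&B. An empty list means no values were collected for that pair.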

results_test = {'test_ext/acc1': {('ViT', 0.001): [0.623046875,
   0.62548828125,
   0.6396484375,
   0.68291015625,
   0.6900390625],
  ('ViT', 0.0003): [0.70673828125,
   0.73076171875,
   0.7080078125,
   0.6876953125,
   0.7],
  ('ViT', 0.0001): [0.6732421875,
   0.6904296875,
   0.680078125,
   0.68251953125,
   0.68369140625],
  ('ViT', 3e-05): [],
  ('plstm', 0.001): [0.5021484375,
   0.75498046875,
   0.76494140625,
   0.69423828125,
   0.680859375],
  ('plstm', 0.0003): [0.69814453125,
   0.75302734375,
   0.64892578125,
   0.69560546875,
   0.7185546875],
  ('plstm', 0.0001): [0.790234375,
   0.74326171875,
   0.76533203125,
   0.76201171875,
   0.83134765625],
  ('plstm', 3e-05): [],
  ('plstmPonly', 0.001): [0.60341796875,
   0.52216796875,
   0.5265625,
   0.4978515625,
   0.56328125],
  ('plstmPonly', 0.0003): [0.80478515625,
   0.748828125,
   0.74140625,
   0.7095703125,
   0.6853515625],
  ('plstmPonly', 0.0001): [0.7220703125,
   0.72744140625,
   0.7451171875,
   0.79775390625,
   0.73720703125],
  ('plstmPonly', 3e-05): [],
  ('plstmDonly', 0.001): [0.6970703125,
   0.5005859375,
   0.75576171875,
   0.6537109375,
   0.62568359375],
  ('plstmDonly', 0.0003): [0.751953125,
   0.648828125,
   0.71630859375,
   0.694140625,
   0.56328125],
  ('plstmDonly', 0.0001): [0.7787109375,
   0.82294921875,
   0.83623046875,
   0.86220703125,
   0.837890625],
  ('plstmDonly', 3e-05): [],
  ('plstmNoPosEmb', 0.001): [0.593359375,
   0.64189453125,
   0.546484375,
   0.50048828125,
   0.496875],
  ('plstmNoPosEmb', 0.0003): [0.73115234375,
   0.75419921875,
   0.85380859375,
   0.670703125,
   0.72744140625],
  ('plstmNoPosEmb', 0.0001): [0.79970703125,
   0.75478515625,
   0.76181640625,
   0.75009765625,
   0.7798828125],
  ('plstmNoPosEmb', 3e-05): [],
  ('plstmConstSTM', 0.001): [],
  ('plstmConstSTM', 0.0003): [],
  ('plstmConstSTM', 0.0001): [0.78671875,
   0.76064453125,
   0.78623046875,
   0.770703125,
   0.81806640625],
  ('plstmConstSTM', 3e-05): [],
  ('ViL', 0.001): [],
  ('ViL', 0.0003): [0.4996093809604645,
   0.498046875,
   0.5005859136581421,
   0.49931639432907104,
   0.49833983182907104],
  ('ViL', 0.0001): [0.500683605670929,
   0.49882811307907104,
   0.5013672113418579,
   0.501269519329071,
   0.5009765625],
  ('ViL', 3e-05): [0.50244140625,
   0.5013672113418579,
   0.501171886920929,
   0.51171875,
   0.5003906488418579],
  ('2DMamba', 0.001): [],
  ('2DMamba', 0.0003): [],
  ('2DMamba', 0.0001): [0.500683605670929,
   0.8662109375,
   0.5511718988418579,
   0.500683605670929,
   0.5008789300918579],
  ('2DMamba', 3e-05): [],
  ('Mamba2D', 0.001): [],
  ('Mamba2D', 0.0003): [0.49931639432907104,
   0.500683605670929,
   0.500683605670929,
   0.4994140565395355,
   0.49931639432907104],
  ('Mamba2D', 0.0001): [0.43046873807907104,
   0.4969726502895355,
   0.500781238079071,
   0.4989257752895355,
   0.49931639432907104],
  ('Mamba2D', 3e-05): [],
  ('EfficientNet', 0.001): [0.676953136920929,
   0.619824230670929,
   0.6299804449081421,
   0.6728515625,
   0.6478515863418579],
  ('EfficientNet', 0.0003): [0.6151367425918579,
   0.628125011920929,
   0.6415039300918579,
   0.639843761920929,
   0.6309570074081421],
  ('EfficientNet', 0.0001): [0.5904296636581421,
   0.651171863079071,
   0.598925769329071,
   0.5478515625,
   0.56591796875],
  ('EfficientNet', 3e-05): [0.59033203125,
   0.5826171636581421,
   0.5386718511581421,
   0.547656238079071,
   0.5810546875]},
 'test/acc1': {('ViT', 0.001): [0.7412109375,
   0.7609375,
   0.76123046875,
   0.86884765625,
   0.865234375],
  ('ViT', 0.0003): [0.906640625,
   0.94326171875,
   0.921875,
   0.88447265625,
   0.917578125],
  ('ViT', 0.0001): [0.83994140625,
   0.86767578125,
   0.8501953125,
   0.8486328125,
   0.85283203125],
  ('ViT', 3e-05): [],
  ('plstm', 0.001): [0.669140625,
   0.9734375,
   0.95380859375,
   0.9611328125,
   0.96982421875],
  ('plstm', 0.0003): [0.97041015625,
   0.97080078125,
   0.94970703125,
   0.9431640625,
   0.98447265625],
  ('plstm', 0.0001): [0.97314453125,
   0.9681640625,
   0.97421875,
   0.9697265625,
   0.9751953125],
  ('plstm', 3e-05): [],
  ('plstmPonly', 0.001): [0.94306640625,
   0.56357421875,
   0.6646484375,
   0.50576171875,
   0.8033203125],
  ('plstmPonly', 0.0003): [0.97939453125,
   0.97578125,
   0.98408203125,
   0.97958984375,
   0.97060546875],
  ('plstmPonly', 0.0001): [0.9759765625,
   0.97900390625,
   0.9787109375,
   0.97529296875,
   0.9734375],
  ('plstmPonly', 3e-05): [],
  ('plstmDonly', 0.001): [0.88232421875,
   0.5060546875,
   0.90791015625,
   0.9072265625,
   0.83896484375],
  ('plstmDonly', 0.0003): [0.95712890625,
   0.8732421875,
   0.9796875,
   0.887890625,
   0.85859375],
  ('plstmDonly', 0.0001): [0.95556640625,
   0.955078125,
   0.959765625,
   0.95908203125,
   0.95595703125],
  ('plstmDonly', 3e-05): [],
  ('plstmNoPosEmb', 0.001): [0.76611328125,
   0.83984375,
   0.75791015625,
   0.521484375,
   0.54052734375],
  ('plstmNoPosEmb', 0.0003): [0.98447265625,
   0.9888671875,
   0.990625,
   0.9318359375,
   0.969921875],
  ('plstmNoPosEmb', 0.0001): [0.98017578125,
   0.9712890625,
   0.97744140625,
   0.9728515625,
   0.972265625],
  ('plstmNoPosEmb', 3e-05): [],
  ('plstmConstSTM', 0.001): [],
  ('plstmConstSTM', 0.0003): [],
  ('plstmConstSTM', 0.0001): [0.974609375,
   0.975,
   0.97451171875,
   0.9705078125,
   0.9796875],
  ('plstmConstSTM', 3e-05): [],
  ('ViL', 0.001): [],
  ('ViL', 0.0003): [0.49394530057907104,
   0.4964843690395355,
   0.4969726502895355,
   0.4935546815395355,
   0.5025390386581421],
  ('ViL', 0.0001): [0.4999023377895355,
   0.49833983182907104,
   0.49609375,
   0.5077148675918579,
   0.523242175579071],
  ('ViL', 3e-05): [0.85546875,
   0.890429675579071,
   0.520800769329071,
   0.972558617591858,
   0.876757800579071],
  ('2DMamba', 0.001): [],
  ('2DMamba', 0.0003): [],
  ('2DMamba', 0.0001): [0.997167944908142,
   0.994921863079071,
   0.945117175579071,
   0.8843749761581421,
   0.996972680091858],
  ('2DMamba', 3e-05): [],
  ('Mamba2D', 0.001): [],
  ('Mamba2D', 0.0003): [0.49345701932907104,
   0.506542980670929,
   0.498046875,
   0.49335938692092896,
   0.49345701932907104],
  ('Mamba2D', 0.0001): [0.7938476800918579,
   0.4927734434604645,
   0.49433594942092896,
   0.49296873807907104,
   0.49345701932907104],
  ('Mamba2D', 3e-05): [],
  ('EfficientNet', 0.001): [0.997363269329071,
   0.9970703125,
   0.997753918170929,
   0.996972680091858,
   0.998437523841858],
  ('EfficientNet', 0.0003): [0.994238257408142,
   0.994726538658142,
   0.994042992591858,
   0.99462890625,
   0.99560546875],
  ('EfficientNet', 0.0001): [0.96484375,
   0.96875,
   0.965722680091858,
   0.971093773841858,
   0.9716796875],
  ('EfficientNet', 3e-05): [0.74462890625,
   0.741015613079071,
   0.7607421875,
   0.753125011920929,
   0.7533203363418579]}}

In [None]:
# get best lr per model

In [None]:
from collections import defaultdict
import numpy as np
from operator import itemgetter

In [None]:
# model_res[metric_key][model] -> (lr, mean, band) of the learning rate with
# the highest mean for that metric. `band` is the population standard
# deviation of the per-run values scaled by 2.015 / sqrt(4).
# NOTE(review): 2.015 matches the one-sided 95% Student-t quantile for
# 5 degrees of freedom; with presumably 5 runs per setting (sqrt(4) divisor)
# 4 degrees of freedom (2.132) would be expected -- confirm the intended value.
model_res = {}
# Placeholder for models that have no values at any learning rate, so the
# selection below does not fail on an empty candidate list.
no_result = (float("nan"), float("nan"), float("nan"))

for metric_key in ["test_ext/acc1", "test/acc1"]:
    # Rebuilt per metric; after the loop it holds the last metric's entries.
    model_means = defaultdict(list)

    for model in model_label:
        for (res_model, lr), vals in results_test[metric_key].items():
            # Only this model's settings that actually have values count.
            if res_model != model or not vals:
                continue
            scores = np.array(vals)
            model_means[model].append(
                (lr, np.mean(scores), 2.015/np.sqrt(4.)* np.std(scores))
            )

    # Pick the learning rate with the highest mean; ties keep the first one.
    model_best_lr = {}
    for model in model_label:
        model_best_lr[model] = max(model_means[model], key=itemgetter(1), default=no_result)

    model_res[metric_key] = model_best_lr

In [None]:
# Display the per-model (lr, mean, band) candidates; `model_means` is rebuilt
# for every metric, so this shows the last one processed ("test/acc1").
model_means

In [None]:
import pandas as pd


# Summary table with one row per model, in `model_label` order.
# "Best LR" is the learning rate selected on test/acc1; the "(Ext.)" columns
# come from the best entry selected separately on test_ext/acc1.
table_columns = ["Model", "Best LR", "Test Acc.", "STD", "Test Acc. (Ext.)", "STD2"]
table_rows = []
for mod in model_label:
    best_test = model_res["test/acc1"][mod]
    best_ext = model_res["test_ext/acc1"][mod]
    table_rows.append(
        [
            model_label[mod],
            best_test[0],
            best_test[1],
            best_test[2],
            best_ext[1],
            best_ext[2],
        ]
    )
df = pd.DataFrame(table_rows, columns=table_columns)
df = df.set_index("Model")

In [None]:
# Display the summary table.
df

In [None]:
# Print the table as LaTeX: floats with three decimals, except "Best LR",
# which is rendered with str() so the learning rates are not rounded.
print(df.to_latex(float_format="%.3f", formatters={"Best LR": str}))