In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import os

from utils.cluster import ClusterManager
from utils import notebooks as nb

# Every figure created after this point is rendered at 150 dots per inch.
plt.rcParams["figure.dpi"] = 150

# `cluster.artifact_dir` is the folder the plotting cells below write their PDFs into.
cluster = ClusterManager()


In [None]:
# Plotting helpers; `plot_lineplots` creates one lr-scaling figure per dataset.
def errorbar(x: np.ndarray):
    """Return the full range of ``x`` as a ``(min, max)`` interval.

    Not called in the visible cells; its shape matches what seaborn accepts as a
    custom ``errorbar=`` callable (presumably the intended use).
    """
    lowest, highest = x.min(), x.max()
    return lowest, highest

def mark_max(data: pd.DataFrame, color=None):
    """Highlight, for every ``h_dim``, the ``(h_lr, accuracy)`` point with the highest accuracy.

    Draws one large translucent marker per hidden dim on the current matplotlib axes,
    coloured from a viridis palette in ascending ``h_dim`` order. Rows with a missing
    accuracy are ignored; if no row is left, nothing is drawn.

    Args:
        data: rows with columns ``h_dim``, ``h_lr`` and ``accuracy``; assumes a unique
            index (the best rows are looked up by index label).
        color: unused. Presumably present so the function can serve as a seaborn
            ``FacetGrid.map_dataframe`` callback, which passes ``color``.
    """
    valid = data.dropna(subset=["accuracy"])
    if valid.empty:
        return

    # Select the best row per h_dim by index label. Unlike groupby().apply, this keeps the
    # h_dim column (apply operating on grouping columns is deprecated since pandas 2.2).
    best = valid.loc[valid.groupby("h_dim")["accuracy"].idxmax()].sort_values("h_dim")

    palette = sns.color_palette("viridis", n_colors=len(best))
    # enumerate yields the sorted position, so colours follow ascending h_dim regardless of
    # the original index labels.
    for position, (h_lr, accuracy) in enumerate(zip(best["h_lr"], best["accuracy"])):
        plt.scatter(h_lr, accuracy, s=500, alpha=0.2, color=palette[position])


def _filter_lr_scaling_results(results: pd.DataFrame, dataset: str, max_pc_h_lr: float) -> pd.DataFrame:
    """Return the runs of ``dataset`` to plot: every BP run plus PC runs with ``config.h_lr < max_pc_h_lr``.

    Conditions are matched by substring ("BP" / "PC"). Rows whose ``condition`` is missing
    are dropped. The input frame is not modified.
    """
    condition = results["condition"]
    # na=False keeps the masks strictly boolean, so a missing condition label drops the row
    # instead of making the boolean indexing raise.
    is_bp = condition.str.contains("BP", na=False)
    is_pc = condition.str.contains("PC", na=False) & (results["config.h_lr"] < max_pc_h_lr)
    selected = results.loc[is_bp | is_pc]
    return selected.loc[selected["experiment.data.dataset"] == dataset]


def _draw_condition_panel(ax, condition_results: pd.DataFrame, condition: str, title: str, palette: dict, is_first: bool) -> None:
    """Draw accuracy vs. ``config.h_lr`` for one condition on ``ax``, one line per hidden width.

    ``palette`` maps every width of the dataset to a colour. Passing the full mapping (and
    its keys as ``hue_order``) keeps a width's colour identical across panels, even when a
    panel lacks some widths. Only the first panel (``is_first``) keeps its y label, y tick
    labels and legend.
    """
    line_kwargs = dict(
        x="config.h_lr",
        y="results.accuracy",
        hue="config.hidden_dims",
        hue_order=list(palette),
        palette=palette,
        ax=ax,
    )
    if "BP" in condition:
        # Duplicate the BP rows at h_lr + 0.1 so a width with a single h_lr value is drawn as a
        # short segment instead of an invisible lone point (presumably BP runs do not sweep
        # h_lr). The x positions are then artificial, so the ticks are hidden.
        shifted = condition_results.copy()
        shifted["config.h_lr"] = shifted["config.h_lr"] + 0.1
        sns.lineplot(data=pd.concat([condition_results, shifted]), **line_kwargs)
        ax.set_xlabel("")
        ax.set_xticks([])
    else:
        sns.lineplot(data=condition_results, marker="o", **line_kwargs)
        ax.set_xlabel(r"$\gamma$ (State Learning Rate)")
        ax.set_xscale("log")

    if is_first:
        ax.set_ylabel("Accuracy")
        ax.legend(ncol=2, title="Width", title_fontsize=10, loc="lower right")
    else:
        ax.set_ylabel("")
        # Hide the y tick labels without pinning a FixedFormatter (unlike set_yticklabels([])).
        ax.tick_params(axis="y", labelleft=False)
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    ax.set_title(title)
    ax.set_ylim(0, 1)


def plot_lineplots(results: pd.DataFrame, experiment_folder: str, dataset: str, suffix: str = "", max_pc_h_lr: float = 0.3):
    """Plot accuracy against the state learning rate for one dataset and save it as a PDF.

    Draws three panels side by side: "PC-sgd-0.9", "PC-adamw-0.9" and "BP-sgd-0.9". Each
    panel has one line per hidden width (``config.hidden_dims``). The figure is written to
    ``<experiment_folder>/lr-scaling-lineplot-<dataset>[-<suffix>].pdf``.

    Args:
        results: one row per run, with columns ``condition``, ``config.h_lr``,
            ``config.hidden_dims``, ``experiment.data.dataset`` and ``results.accuracy``.
            Not modified.
        experiment_folder: directory the PDF is written into.
        dataset: value of ``experiment.data.dataset`` to plot.
        suffix: optional tag appended to the file name as ``-<suffix>``.
        max_pc_h_lr: PC runs with ``config.h_lr`` at or above this value are left out.
            BP runs are always kept.

    Side effects:
        Applies the seaborn "paper"/"whitegrid" theme, ``nb.NEURIPS_FORMAT_FULL`` and a
        2.8-inch figure height to the global matplotlib rcParams. These settings persist
        after the call.
    """
    selected = _filter_lr_scaling_results(results, dataset, max_pc_h_lr)
    h_dims = sorted(selected["config.hidden_dims"].dropna().unique())
    palette = dict(zip(h_dims, sns.color_palette("viridis", n_colors=len(h_dims))))

    sns.set_theme("paper", style="whitegrid")
    plt.rcParams.update(nb.NEURIPS_FORMAT_FULL)
    # Keep the configured figure width but reduce the height.
    plt.rcParams["figure.figsize"] = (plt.rcParams["figure.figsize"][0], 2.8)

    fig = plt.figure()
    try:
        # Two wide PC panels and one narrow BP panel.
        grid = fig.add_gridspec(1, 3, width_ratios=[5, 5, 1])
        panels = [
            ("PC-sgd-0.9", "PC - SGD"),
            ("PC-adamw-0.9", "PC - Adam"),
            ("BP-sgd-0.9", "BP"),
        ]
        for column, (condition, title) in enumerate(panels):
            ax = fig.add_subplot(grid[0, column])
            condition_results = selected.loc[selected["condition"] == condition]
            _draw_condition_panel(ax, condition_results, condition, title, palette, is_first=column == 0)

        file_suffix = f"-{suffix}" if suffix else ""
        fig.savefig(os.path.join(experiment_folder, f"lr-scaling-lineplot-{dataset}{file_suffix}.pdf"))
    finally:
        # Close even if saving fails, so repeated calls do not accumulate open figures.
        plt.close(fig)


# Main Experiments

This section analyses the main experiment (without data scaling).

In [None]:
experiment_ids = []  # fill in the experiment ids from training e.g. ab12-cd34

In [None]:
# `load_data` returns four values; only `df` is used by the plotting cells below.
(
    experiment_folder,
    results,
    results_with_metrics,
    df,
) = nb.load_data("training_stability", experiment_ids)


In [None]:

# Writes lr-scaling-lineplot-fashion_mnist.pdf into the cluster artifact directory.
plot_lineplots(results=df, experiment_folder=cluster.artifact_dir, dataset="fashion_mnist")


In [None]:
# Writes lr-scaling-lineplot-two_moons.pdf into the cluster artifact directory.
plot_lineplots(results=df, experiment_folder=cluster.artifact_dir, dataset="two_moons")


# Ablation

Analyse the ablation results (the figure is saved with the `constant-layer-size` suffix).

In [None]:
# TODO: fill in the experiment ids of the ablation runs.
experiment_ids = list()

In [None]:
# Reload with the ablation ids; this rebinds the same names as the main-experiment cell.
(
    experiment_folder,
    results,
    results_with_metrics,
    df,
) = nb.load_data("training_stability", experiment_ids)

In [None]:
# Writes lr-scaling-lineplot-fashion_mnist-constant-layer-size.pdf into the cluster artifact directory.
plot_lineplots(
    results=df,
    experiment_folder=cluster.artifact_dir,
    dataset="fashion_mnist",
    suffix="constant-layer-size",
)
