In [None]:
# magic commands, make python reimport modules when code is changed
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

import sys

sys.path.append(".utilities/")

from utilities.download import download_sweep
from utilities.process import create_error_plots_custom

# set pandas dataframe display options: show up to 500 columns and allow
# 1000 characters per printed line
pd.set_option("display.max_columns", 500)
pd.set_option("display.width", 1000)

In [None]:
# Flags selecting which pre-filtered subset of runs the figures below use.
# USE_BIG_DATA picks the larger `num_train` value of each dataset;
# USE_FULL_AUGMENT picks the dataset-specific non-"trivial" `augment` setting.
USE_BIG_DATA = True
USE_FULL_AUGMENT = True

# Setup

In [None]:
import os

# make figure folder (exist_ok avoids the separate exists() check and its race)
os.makedirs("figures", exist_ok=True)

# Folder that caches one CSV per sweep, grouped in one sub-folder per dataset.
SAVE_FOLDER = "csv-files"
# Sweep ids to fetch for every dataset; a dataset may be spread over several sweeps.
SWEEPS = {
    "moons": ["204gpz3v"],
    "mnist": ["rut3a738"],
    "downsampled_mnist": ["npnuge21"],
    "cifar10": ["2uxccm2r", "669xhx85", "sgffog8z"],
}

# Build one DataFrame per dataset from all of its sweeps.
full_dfs = {}
for dataset, sweep_list in SWEEPS.items():
    print(f"Downloading {dataset}")
    sweep_save_folder = os.path.join(SAVE_FOLDER, dataset)
    os.makedirs(sweep_save_folder, exist_ok=True)

    dataset_dfs = []
    for sweep in sweep_list:
        sweep_id = f"ljroos-msc/knot-solver/{sweep}"
        save_loc = os.path.join(sweep_save_folder, f"{dataset}_{sweep}.csv")
        # NOTE(review): presumably reuses an existing CSV at save_loc because of
        # override_existing=False -- confirm in utilities.download.
        download_sweep(sweep_id, save_loc, override_existing=False)

        dataset_dfs.append(pd.read_csv(save_loc))

    print()

    # ignore_index=True gives every row a unique label; without it each sweep's
    # CSV restarts at 0 and multi-sweep datasets end up with duplicate labels.
    full_dfs[dataset] = pd.concat(dataset_dfs, ignore_index=True)

In [None]:
# Per-group metadata keyed by the value of the "group" column:
#   [0] group size (number of elements), [1] display name used in legends.
# NOTE(review): the optional third field is a boolean flag that is not read
# anywhere in the visible part of this notebook, and the translate* entries
# omit it -- confirm its meaning before relying on it. It was previously
# written as `(True,)` (a one-tuple) for flipH and `(True)` for flipW; both are
# normalised to the plain boolean used by every other entry.
group_info = {
    "trivial": [1, "Trivial", True],
    # D4
    "flipH": [2, "FlipH", True],
    "flipW": [2, "FlipW", True],
    "rot180": [2, "Rot180", False],
    "transpose": [2, "Transpose", False],
    "antidiagonal_transpose": [2, "Anti-Transpose", False],
    "flipH_and_or_flipW": [4, "FlipH and/or W", False],
    "rot180_and_or_transpose": [4, "Main- and/or Anti-Transpose", False],
    "rot90": [4, "Rot90", True],
    "D4": [8, "FlipRot90", True],
    # Z_2^7
    "translateH": [7, "TranslateH"],
    "translateW": [7, "TranslateW"],
    "translateH_and_W": [7, "TranslateDiag"],
    "translateH_and_or_W": [49, "TranslateH and/or W"],
}

# Convenience lookups derived from group_info.
group_sizes = {key: value[0] for key, value in group_info.items()}
proper_names = {key: value[1] for key, value in group_info.items()}

In [None]:
# Add derived columns to every dataset frame in place (`df` is the very object
# stored in full_dfs, so assigning columns on it updates full_dfs as well).
for dataset, df in full_dfs.items():
    group_names = df["group"]

    # group size and display name, looked up per row
    df["group_order"] = group_names.map(lambda name: group_sizes[name])
    df["group_proper_name"] = group_names.map(lambda name: proper_names[name])

    # channels divided by sqrt(group size)
    sqrt_group_size = group_names.map(lambda name: np.sqrt(group_sizes[name]))
    df["parameter_factor"] = df["hidden_group_channels"] / sqrt_group_size

    # smoothness on the square-root scale, also relative to the gradient norm
    df["sqrt_total_smoothness"] = np.sqrt(df["total_smoothness"])
    df["expected_gradient_norm_normalized_smoothness"] = (
        df["sqrt_total_smoothness"] / df["expected_gradient_norm"]
    )

    # absolute and relative difference between test_loss and val_loss
    df["generalization_gap"] = df["test_loss"] - df["val_loss"]
    df["generalization_ratio"] = df["generalization_gap"] / df["val_loss"]

In [None]:
# downloaded data from https://wandb.ai/ljroos-msc/mosaic/sweeps/w705aehx/table?workspace=user-luro
# not sure if link will work for others.

# all target columns of interest (reference list only). This used to be bound
# to `target_cols` and was overwritten further down before ever being read, so
# it now has its own name.
all_target_cols = [
    "total_num_knots",
    "expected_knot_uniformity",
    "expected_knot_entropy",
    "expected_gradient_norm",
    "sqrt_total_smoothness",
    "val_loss",
    "test_loss",
]

# main big table
main_target_cols = [
    "total_num_knots",
    "expected_knot_uniformity",
    "expected_gradient_norm",
    "sqrt_total_smoothness",
    "val_loss",
]

# for 'appendix'
alternative_target_cols = [
    "expected_knot_entropy",
    "expected_gradient_norm_normalized_smoothness",
    "val_loss",
    "test_loss",
    "generalization_gap",
]

# for important shot mnist
minimal_key_target_cols = ["total_num_knots", "sqrt_total_smoothness", "val_loss"]

# columns that must be present (non-NA) for a run to be kept by the dropna step
target_cols = ["val_loss", "test_loss"]

# NOTE(review): not read anywhere in the visible part of the notebook;
# presumably pairs of class indices for the per-pair columns -- confirm.
pairs = [(0, 2), (2, 3), (3, 5), (5, 8), (8, 1), (1, 7), (7, 9), (9, 4), (4, 6), (6, 0)]


def get_alternative_target_cols(pair):
    """Return the four per-pair column names for `pair`.

    Each name is a fixed metric prefix immediately followed by the default
    string formatting of `pair` (e.g. ``"num_knots(0, 2)"`` for ``(0, 2)``).
    """
    metric_prefixes = (
        "num_knots",
        "smoothness",
        "knot_uniformity",
        "expected_gradient_norm",
    )
    return [f"{prefix}{pair}" for prefix in metric_prefixes]


# Drop, in place, every run that has a missing value in one of `target_cols`,
# and report each dataset's row count before and after.
for dataset, df in full_dfs.items():
    rows_before = len(df)
    print("dataset: ", dataset)
    print(f"len before dropna: {rows_before}")

    # in-place on purpose: `df` is the frame stored in full_dfs
    df.dropna(subset=target_cols, inplace=True)

    rows_after = len(df)
    print(f"len after dropna: {rows_after}")

In [None]:
# Show which `augment` values occur in the downsampled-MNIST runs.
full_dfs["downsampled_mnist"]["augment"].unique()

In [None]:
# Show which `num_train` values occur in the downsampled-MNIST runs.
full_dfs["downsampled_mnist"]["num_train"].unique()

In [None]:
# Show which `augment` values occur in the MNIST runs.
full_dfs["mnist"]["augment"].unique()

In [None]:
# Per-dataset subset settings: (small num_train, big num_train, full-augment name).
# A lookup table replaces the former if/elif chain, which had no fallback and
# would silently reuse the previous dataset's settings for an unknown dataset.
subset_specs = {
    "mnist": (10000, 60000, "rotflip"),
    "cifar10": (10000, 50000, "rot90flip"),
    "moons": (500, 10000, "flipH_and_or_flipW"),
    "downsampled_mnist": (10000, 60000, "translateH_and_or_W"),
}

# For each (big_data, full_augment) combination, keep only the runs of every
# dataset that match the corresponding num_train and augment values.
big_data_full_augment_dfs = {}
for big_data in [False, True]:
    for full_augment in [False, True]:
        subset = {}
        for dataset, df in full_dfs.items():
            if dataset not in subset_specs:
                raise KeyError(f"no subset specification for dataset {dataset!r}")
            small_num_train, big_num_train, full_augment_name = subset_specs[dataset]
            num_train = big_num_train if big_data else small_num_train
            augment = full_augment_name if full_augment else "trivial"
            subset[dataset] = df[
                (df["augment"] == augment) & (df["num_train"] == num_train)
            ].copy()
        big_data_full_augment_dfs[(big_data, full_augment)] = subset

# Select the configuration used by the rest of the notebook *before* printing
# the summary; previously the summary read whatever `dfs` the last loop
# iteration left behind, which matched only when both flags were True.
dfs = big_data_full_augment_dfs[(USE_BIG_DATA, USE_FULL_AUGMENT)]

for dataset in full_dfs.keys():
    # print dataset name, and number of elements in full and reduced
    print("dataset:\t", dataset)
    print(f"full:   \t {len(full_dfs[dataset])}")
    print(f"reduced:  \t {len(dfs[dataset])}")
    print()

In [None]:
# Collect, per dataset, the sorted distinct values of the columns used as plot
# classes / covariates. Groups are ordered by group size, the rest naturally.
unique_groups = {}
unique_group_orders = {}
unique_group_channels = {}
unique_parameter_factors = {}

for dataset, df in dfs.items():
    unique_groups[dataset] = sorted(
        df["group"].unique(), key=lambda name: group_sizes[name]
    )
    unique_group_channels[dataset] = sorted(df["hidden_group_channels"].unique())
    unique_parameter_factors[dataset] = sorted(df["parameter_factor"].unique())
    unique_group_orders[dataset] = sorted(df["group_order"].unique())

    for collected in (
        unique_groups,
        unique_group_channels,
        unique_parameter_factors,
        unique_group_orders,
    ):
        print(collected[dataset])

In [None]:
# Axis labels for each metric column (matplotlib mathtext inside $...$).
# Labels containing backslashes are raw strings: sequences such as "\o", "\s"
# and "\m" are invalid escapes in ordinary literals and trigger a warning on
# current Python versions. The resulting text is unchanged.
metric_proper_names = {
    "total_num_knots": "# Knots (K)",
    "expected_knot_uniformity": r"Knot Uniformity $(\omega^2)$",
    "expected_gradient_norm": "Sensitivity (Sn)",
    "sqrt_total_smoothness": r"Smoothness $(\sqrt{ S }$)",
    "val_accuracy": "Training Accuracy",
    "val_loss": "Training Loss",
    "test_accuracy": "Test Accuracy",
    "test_loss": "Test Loss",
    "total_smoothness": "Smoothness ($S$)",
    "expected_knot_entropy": r"Entropy $(\mathbb{ H })$",
    "generalization_gap": "Generalization Gap",
    # spelling fixed (was "Standarized")
    "generalization_ratio": "Standardized Generalization Gap",
    "expected_gradient_norm_normalized_smoothness": r"$\frac{\sqrt{ S }}{ \mathrm{Sn} }$",
}

# Background colours used to highlight selected panels.
plt_facecolors = ["lightyellow", "honeydew", "lightcyan", "mistyrose"]

# Column titles for each dataset.
dataset_proper_names = {
    "mnist": r"MNIST $O(2)$",
    "moons": r"Moons $K_4$",
    "downsampled_mnist": r"Downsampled MNIST $\mathbb{Z}^2_7$",
    "cifar10": r"CIFAR10 $D_4$",
}

# Split CIFAR10 off from the other datasets; it gets its own figures.
# Both names are rebuilt from the cached selection instead of `dfs.pop(...)`,
# so re-running this cell no longer raises KeyError and the dictionary stored
# in big_data_full_augment_dfs is left intact.
_selected_dfs = big_data_full_augment_dfs[(USE_BIG_DATA, USE_FULL_AUGMENT)]
cifar10_df = _selected_dfs["cifar10"]
dfs = {name: frame for name, frame in _selected_dfs.items() if name != "cifar10"}
# (label note) \mathrm removes the italics from the text

# Without CIFAR

In [None]:
# Main comparison grid: one row per metric in main_target_cols and one column
# per dataset in dfs, every panel plotted against hidden_group_channels.
selected_target_cols = main_target_cols

fig, ax = plt.subplots(len(selected_target_cols), len(dfs), figsize=(14, 13))

for col, (dataset, df) in enumerate(dfs.items()):
    for row, target_col in enumerate(selected_target_cols):
        create_error_plots_custom(
            df=df,
            target_col=target_col,
            covariate_col="hidden_group_channels",
            unique_covariates=unique_group_channels[dataset],
            class_col="group",
            unique_classes=unique_groups[dataset],
            aggregate_mode="mean",
            std_error=True,
            ax=ax[row, col],
            eps_noise_level=0.35,
        )

# Shared cosmetics for every panel.
for panel in ax.flat:
    # blank out the title and x label
    panel.set_title("")
    panel.set_xlabel("")

    # y ticks on the right-hand side, larger tick fonts on both axes
    panel.yaxis.tick_right()
    panel.yaxis.set_tick_params(labelsize=13)
    panel.xaxis.set_tick_params(labelsize=13)

    # scientific notation on y, and only a few ticks per axis
    panel.ticklabel_format(
        style="scientific", axis="y", scilimits=(0, 0), useMathText=True
    )
    panel.yaxis.set_major_locator(MaxNLocator(nbins=3, prune="upper", min_n_ticks=3))
    panel.xaxis.set_major_locator(MaxNLocator(nbins=5, prune="both"))

# y labels stay on the first column only
for panel in ax[:, 1:].flat:
    panel.set_ylabel("")

# a single, large x label under the middle column of the bottom row
ax[-1, 1].set_xlabel("group-expanded channels", fontsize=16, labelpad=10)

# x ticks stay on the bottom row only
for panel in ax[:-1, :].flat:
    panel.set_xticks([])
    panel.set_xticklabels([])

# drop whatever legends are already attached to the panels ...
for panel in ax.flat:
    existing_legend = panel.get_legend()
    if existing_legend is not None:
        existing_legend.remove()

# ... and put one legend in each of the three top-row panels
for legend_col, legend_ncol in enumerate((1, 2, 1)):
    ax[0, legend_col].legend(
        loc="upper left", bbox_to_anchor=(0, 1), ncol=legend_ncol, fontsize=11
    )

# column titles
for n, (dataset, df) in enumerate(dfs.items()):
    ax[0, n].set_title(dataset_proper_names[dataset], fontsize=18)

# replace raw column names used as y labels by their display names
for panel in ax.flat:
    raw_label = panel.get_ylabel()
    if raw_label in metric_proper_names:
        panel.set_ylabel(metric_proper_names[raw_label], fontsize=16, rotation=90)

# highlight selected panels: (row, col) -> index into plt_facecolors
highlighted_panels = {
    (0, 0): 0,
    (0, 2): 0,
    (2, 1): 1,
    (2, 2): 1,
    (3, 2): 2,
    (4, 2): 3,
}
for position, color_index in highlighted_panels.items():
    ax[position].set_facecolor(plt_facecolors[color_index])

# make sure the y axis of every panel reaches down to 0
for panel in ax.flat:
    current_bottom = panel.get_ylim()[0]
    panel.set_ylim(bottom=min(0, current_bottom))

# tighten the layout, then set the vertical spacing between rows
fig.tight_layout()
plt.subplots_adjust(hspace=0.15)

# save as pdf
plt.savefig("figures/error_plots_wide.ignore.pdf", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# important shot plot, focused on MNIST: one row per metric in
# minimal_key_target_cols; left column plotted against hidden_group_channels,
# right column against parameter_factor.
fig, ax = plt.subplots(len(minimal_key_target_cols), 2, figsize=(8, 8))

dataset = "mnist"
df = dfs[dataset]

# (covariate column, its sorted unique values) for the left and right column
covariate_panels = [
    ("hidden_group_channels", unique_group_channels[dataset]),
    ("parameter_factor", unique_parameter_factors[dataset]),
]

for row, target_col in enumerate(minimal_key_target_cols):
    for col, (covariate_col, unique_covariates) in enumerate(covariate_panels):
        create_error_plots_custom(
            df=df,
            target_col=target_col,
            covariate_col=covariate_col,
            unique_covariates=unique_covariates,
            class_col="group",
            unique_classes=unique_groups[dataset],
            aggregate_mode="mean",
            std_error=True,
            ax=ax[row, col],
        )

for a in ax.flatten():
    # remove all titles and x labels from the plots
    a.set_title("")
    a.set_xlabel("")

    # Force scientific notation
    a.ticklabel_format(style="scientific", axis="y", scilimits=(0, 0), useMathText=True)
    a.yaxis.set_major_locator(MaxNLocator(nbins=3, prune="upper", min_n_ticks=3))
    a.xaxis.set_major_locator(MaxNLocator(nbins=5, prune="both"))

    a.tick_params(axis="x", labelsize=12)
    a.tick_params(axis="y", labelsize=12)

# replace raw column names used as y labels by their display names
for a in ax.flatten():
    if a.get_ylabel() in metric_proper_names:
        a.set_ylabel(metric_proper_names[a.get_ylabel()], fontsize=16)

# only keep y axis labels of left most plots
for a in ax[:, 1:].flatten():
    a.set_ylabel("")

# only keep x axis labels of bottom plots
# (raw string: "\p" is an invalid escape sequence in an ordinary literal)
ax[-1, 0].set_xlabel("group-expanded channels", fontsize=16)
ax[-1, 1].set_xlabel(r"$\propto$ trainable parameters", fontsize=16)

# remove all xticks, except bottom row
for a in ax[:-1, :].flatten():
    a.set_xticks([])
    a.set_xticklabels([])

# y ticks of the right column go on the right-hand side ...
for a in ax[:, 1].flatten():
    a.yaxis.tick_right()

# ... and the left column shows no y ticks at all
for a in ax[:, :1].flatten():
    a.set_yticks([])
    a.set_yticklabels([])

# remove legends
for a in ax.flatten():
    if a.get_legend() is not None:
        a.get_legend().remove()

# single small legend in the top-left panel
ax[0, 0].legend(loc="upper left", bbox_to_anchor=(0, 1), ncol=2, fontsize=12)

# remove fig title
fig.suptitle("")

# Dashed vertical reference line at x=20 in the bottom-right panel.
# NOTE(review): 20 is presumably the parameter_factor of the model used in the
# literature comparison -- confirm.
literature_line = ax[2, 1].axvline(20, color="black", ls="--", alpha=0.75)

# Pass the line handle explicitly: legend(["..."]) with labels only assigns the
# label to the first artist already on the axes (a data series), not to the
# dashed line that was added last.
ax[2, 1].legend(
    [literature_line], ["literature comparison"], loc="upper right", fontsize=12
)

# stack figures on top of each other
fig.tight_layout()

# save as pdf
plt.savefig(
    "figures/important_plot.ignore.pdf", dpi=300, bbox_inches="tight", pad_inches=0.1
)
plt.show()

# CIFAR

In [None]:
def _filter_ax_legend(ax, desired_labels, ncol=1, loc="lower right"):
    """Replace the legend of `ax` by one restricted to `desired_labels`.

    Only handles whose current label is in `desired_labels` are kept (in their
    existing order), and each kept label is replaced by the display name stored
    in group_info[label][1].

    Parameters
    ----------
    ax : matplotlib Axes whose legend is rebuilt.
    desired_labels : container of raw labels (keys of group_info) to keep.
    ncol : number of legend columns.
    loc : legend location; defaults to "lower right".
    """
    handles, labels = ax.get_legend_handles_labels()
    kept = [
        (handle, label) for handle, label in zip(handles, labels) if label in desired_labels
    ]
    ax.legend(
        [handle for handle, _label in kept],
        [group_info[label][1] for _handle, label in kept],
        loc=loc,
        ncol=ncol,
        fontsize=10,
    )


# CIFAR10 per-group figures: one column per set of group orders, one row per
# metric. Produced twice: main metrics (appendix=False) and alternative
# metrics (appendix=True).
dataset = "cifar10"
df = cifar10_df
group_order_pairings = [(1, 8), (2,), (4,)]

# raw group labels shown in the legend of each column, with the legend's ncol
legend_specs = [
    (["trivial", "D4"], 1),
    (["flipW", "flipH", "rot180", "transpose", "antidiagonal_transpose"], 2),
    (["rot180_and_or_transpose", "rot90", "flipH_and_or_flipW"], 1),
]

for appendix in [False, True]:
    if appendix:
        selected_target_cols = alternative_target_cols
        # keep the per-row height of the main figure
        fig_height = 13 / len(main_target_cols) * len(selected_target_cols)
    else:
        selected_target_cols = main_target_cols
        fig_height = 13

    # One column per pairing. The column count used to be len(dfs.items()),
    # which only matched because three non-CIFAR datasets happen to exist.
    fig, ax = plt.subplots(
        len(selected_target_cols),
        len(group_order_pairings),
        figsize=(14, fig_height),
    )

    for col, pairing in enumerate(group_order_pairings):
        for row, target_col in enumerate(selected_target_cols):
            create_error_plots_custom(
                df=df[df["group_order"].isin(pairing)],
                target_col=target_col,
                covariate_col="hidden_group_channels",
                unique_covariates=unique_group_channels[dataset],
                class_col="group",
                unique_classes=unique_groups[dataset],
                aggregate_mode="mean",
                std_error=True,
                ax=ax[row, col],
                eps_noise_level=0.35,
            )

    for a in ax.flatten():
        # remove all titles and x labels from the plots
        a.set_title("")
        a.set_xlabel("")

        # y ticks on the right, larger font for yticks and xticks
        a.yaxis.tick_right()
        a.yaxis.set_tick_params(labelsize=13)
        a.xaxis.set_tick_params(labelsize=13)

        # Force scientific notation
        a.ticklabel_format(
            style="scientific", axis="y", scilimits=(0, 0), useMathText=True
        )
        a.yaxis.set_major_locator(MaxNLocator(nbins=5, prune=None, min_n_ticks=5))
        a.xaxis.set_major_locator(MaxNLocator(nbins=7, prune="both"))

    # only keep y tick labels of right most plots
    for a in ax[:, :-1].flatten():
        a.set_yticklabels([])

    # only keep y axis labels of left most plots
    for a in ax[:, 1:].flatten():
        a.set_ylabel("")

    # single, large x label under the middle column
    ax[-1, 1].set_xlabel("group-expanded channels", fontsize=16, labelpad=10)

    # add x and y axis grid lines to all plots
    for a in ax.flatten():
        a.grid(True)

    # remove all x tick labels, except bottom row
    for a in ax[:-1, :].flatten():
        a.set_xticklabels([])

    # remove legends
    for a in ax.flatten():
        if a.get_legend() is not None:
            a.get_legend().remove()

    # One filtered legend per top-row panel. The plain ax.legend(...) calls that
    # used to precede these were replaced immediately and have been dropped.
    for col, (desired_labels, legend_ncol) in enumerate(legend_specs):
        _filter_ax_legend(ax[0, col], desired_labels, ncol=legend_ncol)

    # replace raw column names used as y labels by their display names
    for a in ax.flatten():
        if a.get_ylabel() in metric_proper_names:
            a.set_ylabel(metric_proper_names[a.get_ylabel()], fontsize=16, rotation=90)

    # ensure that every row has the same y limits
    for row in range(len(selected_target_cols)):
        y_lims = [a.get_ylim() for a in ax[row, :]]
        y_lims = [min(lim[0] for lim in y_lims), max(lim[1] for lim in y_lims)]
        for a in ax[row, :]:
            a.set_ylim(y_lims)

    if appendix:
        # Rows 2 and 3 (val_loss and test_loss in alternative_target_cols) share
        # one y range. Each row is already uniform across its columns, so the
        # first column of each row is representative.
        shared_rows = (2, 3)
        min_y = min(ax[r, 0].get_ylim()[0] for r in shared_rows)
        max_y = max(ax[r, 0].get_ylim()[1] for r in shared_rows)
        for r in shared_rows:
            for a in ax[r, :]:
                a.set_ylim(min_y, max_y)

    # Column titles "{1, 8}", "{2}", "{4}", derived from the pairings instead
    # of first setting the raw tuple and then overwriting it with literals.
    for col, pairing in enumerate(group_order_pairings):
        ax[0, col].set_title(f"{{{', '.join(map(str, pairing))}}}", fontsize=18)

    # stack figures on top of each other
    fig.tight_layout()

    # but keep ample space between row elements
    plt.subplots_adjust(hspace=0.15)

    # save as pdf
    plt.savefig(
        f"figures/cifar10_per_group_appendix={appendix}.ignore.pdf",
        dpi=300,
        bbox_inches="tight",
    )
    plt.show()

In [None]:
def _filter_ax_legend(ax, desired_labels, ncol=1, loc="lower right"):
    """Replace the legend of `ax` by one restricted to `desired_labels`.

    Only handles whose current label is in `desired_labels` are kept (in their
    existing order), and each kept label is replaced by the display name stored
    in group_info[label][1].

    Parameters
    ----------
    ax : matplotlib Axes whose legend is rebuilt.
    desired_labels : container of raw labels (keys of group_info) to keep.
    ncol : number of legend columns.
    loc : legend location; defaults to "lower right".
    """
    handles, labels = ax.get_legend_handles_labels()
    kept = [
        (handle, label) for handle, label in zip(handles, labels) if label in desired_labels
    ]
    ax.legend(
        [handle for handle, _label in kept],
        [group_info[label][1] for _handle, label in kept],
        loc=loc,
        ncol=ncol,
        fontsize=10,
    )


# CIFAR10 per-group figures, one figure per family of metrics.
plot_groups = [
    ("knots", ["total_num_knots", "expected_knot_uniformity", "expected_knot_entropy"]),
    (
        "smoothnesses",
        [
            "expected_gradient_norm",
            "sqrt_total_smoothness",
            "expected_gradient_norm_normalized_smoothness",
        ],
    ),
    ("losses", ["val_loss", "test_loss", "generalization_gap"]),
]

# legend position used for each metric family
legend_loc_by_plot_group = {
    "knots": "lower right",
    "smoothnesses": "upper left",
    "losses": "lower left",
}

dataset = "cifar10"
df = cifar10_df
group_order_pairings = [(1, 8), (2,), (4,)]

# raw group labels shown in the legend of each column, with the legend's ncol
legend_specs = [
    (["trivial", "D4"], 1),
    (["flipW", "flipH", "rot180", "transpose", "antidiagonal_transpose"], 2),
    (["rot180_and_or_transpose", "rot90", "flipH_and_or_flipW"], 1),
]

for plot_group_name, plot_group in plot_groups:
    # One column per pairing. The column count used to be len(dfs.items()),
    # which only matched because three non-CIFAR datasets happen to exist.
    fig, ax = plt.subplots(
        len(plot_group),
        len(group_order_pairings),
        figsize=(14, 13 / len(main_target_cols) * len(plot_group)),
    )

    for col, pairing in enumerate(group_order_pairings):
        for row, target_col in enumerate(plot_group):
            create_error_plots_custom(
                df=df[df["group_order"].isin(pairing)],
                target_col=target_col,
                covariate_col="hidden_group_channels",
                unique_covariates=unique_group_channels[dataset],
                class_col="group",
                unique_classes=unique_groups[dataset],
                aggregate_mode="mean",
                std_error=True,
                ax=ax[row, col],
                eps_noise_level=0.35,
            )

    # shared cosmetics: no title/x label, y ticks on the right, scientific y
    # notation and a fixed number of ticks
    for a in ax.flatten():
        a.set_title("")
        a.set_xlabel("")
        a.yaxis.tick_right()
        a.yaxis.set_tick_params(labelsize=13)
        a.xaxis.set_tick_params(labelsize=13)
        a.ticklabel_format(
            style="scientific", axis="y", scilimits=(0, 0), useMathText=True
        )
        a.yaxis.set_major_locator(MaxNLocator(nbins=5, prune=None, min_n_ticks=5))
        a.xaxis.set_major_locator(MaxNLocator(nbins=7, prune="both"))

    # y tick labels only on the right-most column, y labels only on the first
    for a in ax[:, :-1].flatten():
        a.set_yticklabels([])

    for a in ax[:, 1:].flatten():
        a.set_ylabel("")

    ax[-1, 1].set_xlabel("group-expanded channels", fontsize=16, labelpad=10)

    for a in ax.flatten():
        a.grid(True)

    # x tick labels only on the bottom row
    for a in ax[:-1, :].flatten():
        a.set_xticklabels([])

    for a in ax.flatten():
        if a.get_legend() is not None:
            a.get_legend().remove()

    if plot_group_name not in legend_loc_by_plot_group:
        raise ValueError("Unknown plot group")
    loc = legend_loc_by_plot_group[plot_group_name]

    # One filtered legend per top-row panel. The plain ax.legend(...) calls that
    # used to precede these were replaced immediately and have been dropped.
    for col, (desired_labels, legend_ncol) in enumerate(legend_specs):
        _filter_ax_legend(ax[0, col], desired_labels, ncol=legend_ncol, loc=loc)

    # column titles such as "{1, 8}"
    for col, pairing in enumerate(group_order_pairings):
        ax[0, col].set_title(f"{{{', '.join(map(str, pairing))}}}", fontsize=18)

    # replace raw column names used as y labels by their display names
    for a in ax.flatten():
        if a.get_ylabel() in metric_proper_names:
            a.set_ylabel(metric_proper_names[a.get_ylabel()], fontsize=16, rotation=90)

    # every row shares one y range across its columns
    for row in range(len(plot_group)):
        y_lims = [a.get_ylim() for a in ax[row, :]]
        y_lims = [min(lim[0] for lim in y_lims), max(lim[1] for lim in y_lims)]
        for a in ax[row, :]:
            a.set_ylim(y_lims)

    fig.tight_layout()
    plt.subplots_adjust(hspace=0.15)
    plt.savefig(
        f"figures/cifar10_per_group_{plot_group_name}.ignore.pdf",
        dpi=300,
        bbox_inches="tight",
    )
    plt.show()

In [None]:
# CIFAR10 close-up of the groups of order 2: a 2x2 grid with one metric per
# panel, plotted against hidden_group_channels.
selected_target_cols = [
    "total_num_knots",
    "expected_gradient_norm",
    "val_loss",
    "test_loss",
]
fig, ax = plt.subplots(2, 2, figsize=(9, 7))

dataset = "cifar10"
df = cifar10_df
unique_groups_order_2 = df[df["group_order"] == 2]["group"].unique()

# only use runs with 22 < hidden_group_channels < 60
# (the old comment said "> 30 and less than 40", which did not match the code)
df = df[(df["hidden_group_channels"] > 22) & (df["hidden_group_channels"] < 60)]

# panel position of each metric, in the order of selected_target_cols
panel_positions = [(0, 0), (0, 1), (1, 0), (1, 1)]

for (row, col), target_col in zip(panel_positions, selected_target_cols):
    create_error_plots_custom(
        df=df[df["group_order"] == 2],
        target_col=target_col,
        covariate_col="hidden_group_channels",
        unique_covariates=unique_group_channels[dataset],
        class_col="group",
        unique_classes=unique_groups_order_2,
        aggregate_mode="mean",
        std_error=True,
        ax=ax[row, col],
        eps_noise_level=0.45,
    )
    ax[row, col].set_ylabel(metric_proper_names[target_col], fontsize=18)

for a in ax.flatten():
    # remove all x labels from the plots
    a.set_xlabel("")

    # y ticks on the right
    a.yaxis.tick_right()

    # larger font for yticks and xticks
    a.yaxis.set_tick_params(labelsize=13)
    a.xaxis.set_tick_params(labelsize=13)

    # Force scientific notation
    a.ticklabel_format(style="scientific", axis="y", scilimits=(0, 0), useMathText=True)
    a.yaxis.set_major_locator(MaxNLocator(nbins=5, prune=None, min_n_ticks=5))
    a.xaxis.set_major_locator(MaxNLocator(nbins=7, prune="both"))

    # remove title
    a.set_title("")

# remove x ticks of top plots
for a in ax[:-1, :].flatten():
    a.set_xticks([])
    a.set_xticklabels([])

# remove legends
for a in ax.flatten():
    if a.get_legend() is not None:
        a.get_legend().remove()

# Legend in the top-left panel with display names for all groups of order 2.
# NOTE: _filter_ax_legend is defined in an earlier cell, which must have been
# run first. The plain ax.legend(...) call that used to precede this one was
# replaced immediately and has been dropped.
_filter_ax_legend(ax[0, 0], unique_groups_order_2, ncol=1)

# x labels on the bottom row (the intermediate "hidden_group_channels" label
# that was set and then overwritten here has been removed)
ax[1, 0].set_xlabel("group-expanded channels", fontsize=16)
ax[1, 1].set_xlabel("group-expanded channels", fontsize=16)

# stack figures on top of each other
fig.tight_layout()

# but keep ample space between row elements
plt.subplots_adjust(hspace=0.1)

# save as pdf
plt.savefig(
    "figures/cifar10_only_group_order_2.ignore.pdf", dpi=300, bbox_inches="tight"
)
plt.show()

In [None]:
# important shot plot:
# Draws two CIFAR-10 figures, one per loop pass: appendix=False uses
# main_target_cols as rows, appendix=True uses alternative_target_cols.
# Left column: each metric against "hidden_group_channels"; right column: the
# same metric against "parameter_factor"; curves are split by "group_order".

# data and sweep values are identical for both figures, so set them up once
dataset = "cifar10"
df = cifar10_df[~cifar10_df["group"].isin(["transpose", "flipW"])]

# (covariate column, its sweep values), in left-to-right column order
column_covariates = [
    ("hidden_group_channels", unique_group_channels[dataset]),
    ("parameter_factor", unique_parameter_factors[dataset]),
]

# x position of the dashed "literature comparison" marker on the
# parameter_factor axis
literature_parameter_factor = 20

for appendix in [False, True]:
    if appendix:
        selected_target_cols = alternative_target_cols
        # 7 inches of height per three rows
        rows_per_full_height = 3
    else:
        selected_target_cols = main_target_cols
        rows_per_full_height = len(minimal_key_target_cols)

    n_rows = len(selected_target_cols)
    # squeeze=False keeps `ax` two-dimensional even for a single-row figure,
    # so the ax[row, col] / ax[:, 1] indexing below always works
    fig, ax = plt.subplots(
        n_rows,
        2,
        figsize=(7, 7 * n_rows / rows_per_full_height),
        squeeze=False,
    )

    for row, target_col in enumerate(selected_target_cols):
        for col, (covariate_col, unique_covariates) in enumerate(column_covariates):
            create_error_plots_custom(
                df,
                target_col=target_col,
                covariate_col=covariate_col,
                unique_covariates=unique_covariates,
                class_col="group_order",
                unique_classes=unique_group_orders[dataset],
                aggregate_mode="mean",
                std_error=True,
                ax=ax[row, col],
            )

    for a in ax.flatten():
        # remove all titles from the plots
        a.set_title("")

        # remove all x labels from the plots
        a.set_xlabel("")

        # Force scientific notation
        a.ticklabel_format(
            style="scientific", axis="y", scilimits=(0, 0), useMathText=True
        )
        # few, non-overlapping ticks: the panels are small and stacked tightly
        a.yaxis.set_major_locator(MaxNLocator(nbins=3, prune="upper", min_n_ticks=3))
        a.xaxis.set_major_locator(MaxNLocator(nbins=5, prune="both"))

        a.tick_params(axis="both", labelsize=12)

    # set leftmost ytick labels appropriately
    for a in ax.flatten():
        # if y axis label in metric_proper_names, set it to the proper name
        if a.get_ylabel() in metric_proper_names:
            a.set_ylabel(metric_proper_names[a.get_ylabel()], fontsize=16)

    # only keep y axis labels of left most plots
    for a in ax[:, 1:].flatten():
        a.set_ylabel("")

    # only keep x axis labels of bottom plots
    ax[-1, 0].set_xlabel("group-expanded channels", fontsize=16)
    # raw string: "\p" is not a valid Python escape sequence
    ax[-1, 1].set_xlabel(r"$\propto$ trainable parameters", fontsize=16)

    # remove all xticks, except bottom row
    for a in ax[:-1, :].flatten():
        a.set_xticks([])
        a.set_xticklabels([])

    # put the yticks of the right column on the right side
    for a in ax[:, 1].flatten():
        a.yaxis.tick_right()

    # remove yticks and yticklabels except for right most plots
    for a in ax[:, :1].flatten():
        a.set_yticks([])
        a.set_yticklabels([])

    # remove the legends created by the plotting helper
    for a in ax.flatten():
        if a.get_legend() is not None:
            a.get_legend().remove()

    # single small legend for the curves, in the top left panel
    ax[0, 0].legend(loc="upper left", bbox_to_anchor=(0, 1), ncol=2, fontsize=12)

    # remove fig title
    fig.suptitle("")

    if not appendix:
        # draw a striped vertical line on the bottom right panel
        literature_line = ax[-1, 1].axvline(
            literature_parameter_factor,
            color="black",
            ls="--",
            alpha=0.75,
            label="literature comparison",
        )

        # legend containing only the striped vertical line. The line is passed
        # as an explicit handle: legend(["..."]) with labels only would attach
        # the label to the first artist drawn on the axes (a metric curve).
        ax[-1, 1].legend(handles=[literature_line], loc="upper right", fontsize=12)

    if appendix:
        # Ensure y-axis scaling and ticks are shared for ax[2, :] and ax[3, :]
        # (index list, not a slice, so a figure with fewer than four rows
        # still fails loudly instead of silently sharing fewer panels)
        shared_axes = ax[[2, 3], :].flatten()
        y_limits = [a.get_ylim() for a in shared_axes]
        min_y = min(low for low, _ in y_limits)
        max_y = max(high for _, high in y_limits)
        for a in shared_axes:
            a.set_ylim(min_y, max_y)

    # stack figures on top of each other
    fig.tight_layout()

    # save as pdf
    fig.savefig(
        f"figures/important_plot_cifar10_appendix={appendix}_extensive.ignore.pdf",
        dpi=300,
        bbox_inches="tight",
        pad_inches=0.1,
    )
    plt.show()

In [None]:
# important shot plot:
# focus on the cifar10 plot: one row per entry of minimal_key_target_cols,
# left column against "hidden_group_channels", right column against
# "parameter_factor", curves split by "group_order".
fig, ax = plt.subplots(len(minimal_key_target_cols), 2, figsize=(8, 8))

dataset = "cifar10"
df = cifar10_df[~cifar10_df["group"].isin(["transpose", "flipW"])]

# (covariate column, its sweep values), in left-to-right column order
column_covariates = [
    ("hidden_group_channels", unique_group_channels[dataset]),
    ("parameter_factor", unique_parameter_factors[dataset]),
]

for row, target_col in enumerate(minimal_key_target_cols):
    for col, (covariate_col, unique_covariates) in enumerate(column_covariates):
        create_error_plots_custom(
            df,
            target_col=target_col,
            covariate_col=covariate_col,
            unique_covariates=unique_covariates,
            class_col="group_order",
            unique_classes=unique_group_orders[dataset],
            aggregate_mode="mean",
            std_error=True,
            ax=ax[row, col],
        )

for a in ax.flatten():
    # remove all titles from the plots
    a.set_title("")

    # remove all x labels from the plots
    a.set_xlabel("")

    # Force scientific notation
    a.ticklabel_format(style="scientific", axis="y", scilimits=(0, 0), useMathText=True)
    # few, non-overlapping ticks: the panels are small and stacked tightly
    a.yaxis.set_major_locator(MaxNLocator(nbins=3, prune="upper", min_n_ticks=3))
    a.xaxis.set_major_locator(MaxNLocator(nbins=5, prune="both"))

    a.tick_params(axis="both", labelsize=12)

# set leftmost ytick labels appropriately
for a in ax.flatten():
    # if y axis label in metric_proper_names, set it to the proper name
    if a.get_ylabel() in metric_proper_names:
        a.set_ylabel(metric_proper_names[a.get_ylabel()], fontsize=16)

# only keep y axis labels of left most plots
for a in ax[:, 1:].flatten():
    a.set_ylabel("")

# only keep x axis labels of bottom plots
ax[-1, 0].set_xlabel("group-expanded channels", fontsize=16)
# raw string: "\p" is not a valid Python escape sequence
ax[-1, 1].set_xlabel(r"$\propto$ trainable parameters", fontsize=16)

# remove all xticks, except bottom row
for a in ax[:-1, :].flatten():
    a.set_xticks([])
    a.set_xticklabels([])

# put the yticks of the right column on the right side
for a in ax[:, 1].flatten():
    a.yaxis.tick_right()

# remove yticks and yticklabels except for right most plots
for a in ax[:, :1].flatten():
    a.set_yticks([])
    a.set_yticklabels([])

# remove the legends created by the plotting helper
for a in ax.flatten():
    if a.get_legend() is not None:
        a.get_legend().remove()

# single small legend for the curves, in the top left panel
ax[0, 0].legend(loc="upper left", bbox_to_anchor=(0, 1), ncol=2, fontsize=12)

# remove fig title
fig.suptitle("")

# draw a striped vertical line at the point 20 for ax[2, 1]
literature_line = ax[2, 1].axvline(
    20, color="black", ls="--", alpha=0.75, label="literature comparison"
)

# make a legend for ax[2, 1], only including the striped vertical line. The
# line is passed as an explicit handle: legend(["..."]) with labels only would
# attach the label to the first artist drawn on the axes (a metric curve).
ax[2, 1].legend(handles=[literature_line], loc="upper right", fontsize=12)

# stack figures on top of each other
fig.tight_layout()

# save as pdf
fig.savefig(
    "figures/important_plot_cifar10.ignore.pdf",
    dpi=300,
    bbox_inches="tight",
    pad_inches=0.1,
)
plt.show()

In [None]:
# Create a new set of plots on a 3x3 axes: one panel per metric, each plotted
# against "hidden_group_channels" with curves split by "group_order".
n_cols = 3
fig, ax = plt.subplots(3, n_cols, figsize=(12, 10))

# Define the metrics to plot (filled into the grid row by row)
metrics = [
    "total_num_knots",
    "expected_knot_uniformity",
    "expected_knot_entropy",
    "expected_gradient_norm",
    "sqrt_total_smoothness",
    "expected_gradient_norm_normalized_smoothness",
    "val_loss",
    "test_loss",
    "generalization_gap",
]

# use df that is cifar but not flipW or transpose
dataset = "cifar10"
df = cifar10_df[~cifar10_df["group"].isin(["transpose", "flipW"])]

# Plot each metric
for i, metric in enumerate(metrics):
    row, col = divmod(i, n_cols)
    create_error_plots_custom(
        df,
        target_col=metric,
        covariate_col="hidden_group_channels",
        unique_covariates=unique_group_channels[dataset],
        class_col="group_order",
        unique_classes=unique_group_orders[dataset],
        aggregate_mode="mean",
        std_error=True,
        ax=ax[row, col],
    )
    # the panel title names the metric, so y labels are dropped further down
    ax[row, col].set_title(metric_proper_names[metric], fontsize=16)

for a in ax.flatten():
    # remove all x labels from the plots
    a.set_xlabel("")
    a.yaxis.tick_right()

    # Force scientific notation
    a.ticklabel_format(style="scientific", axis="y", scilimits=(0, 0), useMathText=True)
    # few, non-overlapping ticks per panel
    a.yaxis.set_major_locator(MaxNLocator(nbins=3, prune="upper", min_n_ticks=3))
    a.xaxis.set_major_locator(MaxNLocator(nbins=5, prune="both"))

    a.tick_params(axis="both", labelsize=12)

# Remove x-axes for all but the lowest plots
for a in ax[:-1, :].flatten():
    a.set_xticks([])
    a.set_xticklabels([])

# remove all other legends
for a in ax.flatten():
    if a.get_legend() is not None:
        a.get_legend().remove()

    # remove ylabel
    a.set_ylabel("")

# only keep legend for the top left plot
ax[0, 0].legend(loc="upper left", bbox_to_anchor=(0, 1), ncol=1, fontsize=13)

# Set x-axis label for the lowest plots (centre panel of the bottom row)
ax[-1, 1].set_xlabel("group-expanded channels", fontsize=16, labelpad=15)

# Ensure y-axis scaling and ticks are shared for the val_loss and test_loss
# panels (ax[2, 0] and ax[2, 1]); looked up by metric name so the pairing
# survives a reordering of `metrics`
loss_axes = [
    ax[divmod(metrics.index(name), n_cols)] for name in ("val_loss", "test_loss")
]
y_limits = [a.get_ylim() for a in loss_axes]
min_y = min(low for low, _ in y_limits)
max_y = max(high for _, high in y_limits)
for a in loss_axes:
    a.set_ylim(min_y, max_y)

# Adjust layout
fig.tight_layout()

# Save the figure
fig.savefig(
    "figures/cifar10_all_metrics_important.pdf",
    dpi=300,
    bbox_inches="tight",
    pad_inches=0.1,
)
plt.show()

In [None]:
# Deliberate stop: raising here aborts a "Run All" at this cell, so any cells
# placed after it only run when executed by hand.
raise Exception("stop notebook autorun")