# Compression and early stopping

In [None]:
import os
import tqdm

%matplotlib inline
import matplotlib.pyplot as plt

Pick the grids

In [None]:
# Grid folders to analyze: `musicnet-fast__00` through `musicnet-fast__04`.
grids = [
    f"./grids/grid-fast/musicnet-fast__{index:02d}/"
    for index in range(5)
]

Using the tools available in `cplxpaper.auto`, enumerate the experiments in each grid,
collecting the early-stopping validation score history (`early_history`) and the settings.

In [None]:
from itertools import chain
from cplxpaper.auto.utils import get_stage_snapshot, load_snapshot
from cplxpaper.auto.parameter_grid import flatten
from cplxpaper.auto.reports.utils import enumerate_experiments, get_model_tag

# One record per experiment, read from the snapshot of its "fine-tune" stage.
histories = []

# The star-unpacking expands every grid up front, so tqdm wraps one flat stream.
all_experiments = chain(*(enumerate_experiments(grid) for grid in grids))
for experiment in tqdm.tqdm(all_experiments):
    filename = get_stage_snapshot("fine-tune", experiment)
    snapshot = load_snapshot(filename)
    options = flatten(snapshot["options"])

    # Every sparsity entry unpacks into two numbers, which are summed columnwise
    # -- presumably (zeroed, total) parameter counts; verify against the
    # code that writes the snapshot.
    sparsity = snapshot["performance"]["test"]["sparsity"]
    n_zeros, n_params = (sum(column) for column in zip(*sparsity.values()))

    record = {
        "path": experiment,
        "kl_div": options["stages__sparsify__objective__kl_div"],
        # total count over the count that is left after removing the zeros
        "compression": n_params / (n_params - n_zeros),
        "early": snapshot["early_history"],
    }
    # The model tag is merged last, so on a key clash its values win -- the
    # same outcome as a trailing `**` unpacking inside the dict literal.
    record.update(get_model_tag(options))

    histories.append(record)

Analyze

In [None]:
import pandas as pd

# Tabulate the collected records: one row per experiment.
df_main = pd.DataFrame(histories)

# Collapse the prefix shared by all experiment paths into "*" for readability.
# `regex=False` forces a literal substring match: pandas < 2.0 treats the
# pattern as a regular expression by default, in which case characters such
# as ".", "+", "(" or a backslash inside the path would mis-match or raise.
common_prefix = os.path.commonpath(df_main.path.to_list())
df_main["path"] = df_main.path.str.replace(common_prefix, "*", regex=False)

# Keep only the length of each early-stopping history (plotted as "epochs").
df_main["early"] = df_main.early.map(len)

In [None]:
# Epochs until early stopping, grouped by (method, kl_div).
gr = df_main.groupby(["method", "kl_div"])["early"]

# Summarize each group by its minimum, median and maximum; `unstack(0)` then
# moves the `method` level from the rows into the columns, leaving `kl_div`
# as the row index.
df = pd.concat({
    "lo": gr.min(),
    "value": gr.median(),
    "hi": gr.max(),
}, names=["agg"], axis=1).unstack(0)

# Make `method` the outer column level so that `df[method]` yields the
# lo / value / hi columns of that method.
# `axis` has to be given by keyword: the positional form `sort_index(1)` was
# deprecated in pandas 1.3 and raises a TypeError since pandas 2.0.
df = df.swaplevel(axis=1).sort_index(axis=1)

In [None]:
from matplotlib.ticker import FormatStrFormatter, FuncFormatter


def darker(color, a=0.5):
    """Rescale the lightness of `color` by the factor `a`.

    The color is converted to HLS, its lightness is multiplied by `a` and
    clipped to the range [0, 1], and the result is converted back to an RGB
    triple. Factors below one darken the color, factors above one lighten it.

    Adapted from this stackoverflow question_.
    .. _question: https://stackoverflow.com/questions/37765197/
    """
    from colorsys import hls_to_rgb, rgb_to_hls
    from matplotlib.colors import to_rgb

    hue, lightness, saturation = rgb_to_hls(*to_rgb(color))

    # clip the rescaled lightness from above first, then from below
    capped = min(a * lightness, 1)
    clipped = max(0, capped)

    return hls_to_rgb(hue, clipped, saturation)

In [None]:
# Figure: per-method median (line) and min-max band (shaded) of the number
# of epochs until early stopping, plotted against the index of `df`.
fig, ax = plt.subplots(1, 1, figsize=(5, 3), dpi=300)
ax.set_xscale("log")
ax.set_xlabel("$C$")
ax.set_ylabel("epochs")
ax.set_title("Median epochs until early stopping for MusicNet")

for method in ("ARD", "VD"):
    stats = df[method]

    # the median curve picks the next color in the cycle ...
    (median_line,) = ax.plot(stats["value"], label=method)

    # ... and the band reuses it with the lightness rescaled by 1.4
    band_color = darker(median_line.get_color(), 1.4)
    ax.fill_between(
        stats.index, stats["lo"], stats["hi"],
        color=band_color, alpha=0.15, lw=0, zorder=-15,
    )

ax.set_xlim(1e-4, 1e-1)
ax.legend(ncol=2)

fig.tight_layout()

fig.savefig("../../assets/figure__fine-tune_fx__early.pdf", dpi=300)
plt.show()

The value of $C$ in the ELBO is a good proxy for the ranking of the final compression
rate (just after `sparsify`). Thus this plot provides evidence for the following
statement, which is applicable to this MusicNet experiment only: for the under-compressed
models the `fine-tune` stage acts as a continuation of the uncompressed training during
`dense`, and thus replicates the peaking and then declining validation score.

In [None]:
from matplotlib.ticker import FormatStrFormatter, FuncFormatter

def _as_times(x, pos):
    """Render a tick value as a multiplier, e.g. 100 -> `x100` (LaTeX times)."""
    return f"$\\times${int(x):d}"


# Figure: scatter of epochs until early stopping against the compression
# rate, with one color per method.
fig, ax = plt.subplots(1, 1, figsize=(5, 3), dpi=300)
ax.set_xscale("log")
ax.set_xlabel("compression")
ax.set_ylabel("epochs")
ax.set_title("Epochs until early stopping for MusicNet")
ax.xaxis.set_major_formatter(FuncFormatter(_as_times))

# shade the x50 - x500 compression range behind the points
ax.axvspan(50, 500, color="k", alpha=0.05, zorder=-10)

gr = df_main.groupby("method")
palette = ["C0", "C1"]
for (method, data), color in zip(gr, palette):
    ax.scatter(data["compression"], data["early"],
               c=color, s=5, label=method)

ax.set_xlim(2e0, 2e3)
ax.legend(ncol=2)

fig.tight_layout()

fig.savefig("../../assets/figure__fine-tune_fx__early__compression.pdf", dpi=300)
plt.show()

In [None]:
# Unconditional failure: executing this cell always raises AssertionError
# (unless Python runs with assertions disabled).
# NOTE(review): presumably a deliberate guard that halts "Run All" at this
# point -- confirm before removing it.
assert False