# Experiment Analysis: MNIST-like

In [None]:
import os
import tqdm
import json
import copy

In [None]:
import numpy as np

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

Fully flatten the dictionary.

In [None]:
from cplxpaper.auto.parameter_grid import get_params

def get_details(self):
    """Flatten a nested mapping into a single-level dictionary.

    Nested dictionaries contribute keys of the form ``parent__child``.
    Lists and tuples contribute keys of the form ``parent[i]``, where the
    index is appended directly to the parent key without a separator.
    Empty dictionaries, lists and tuples contribute no keys at all.

    Args:
        self: the (possibly nested) dictionary to flatten.

    Returns:
        A new flat dictionary mapping composite string keys to leaf values.
    """
    flat = {}
    for key, value in self.items():
        if isinstance(value, (list, tuple)):
            # treat a sequence as a dictionary keyed by the bracketed position
            indexed = {f"[{i}]": v for i, v in enumerate(value)}
            for suffix, leaf in get_details(indexed).items():
                flat[key + suffix] = leaf

        elif isinstance(value, dict):
            for suffix, leaf in get_details(value).items():
                flat[key + '__' + suffix] = leaf

        else:
            flat[key] = value

    return flat

Load performance results from each snapshot in the experiment.

In [None]:
from cplxpaper.auto.utils import load_snapshot

def from_snapshots(*snapshots):
    """Gather the performance records stored in the given snapshot files.

    The snapshots are read in sorted path order.

    Args:
        *snapshots: paths of the snapshot files to read.

    Returns:
        A pair ``(results, options)``. ``results`` maps the base name of
        every snapshot file to ``(stage, performance)``. ``options`` is the
        'options' entry of the last snapshot read, or an empty dictionary
        if no snapshot was given.
    """
    results = {}
    options = {}
    for path in sorted(snapshots):
        payload = load_snapshot(path)

        # every snapshot carries the options: the last one read wins
        options = payload['options']

        # the 'stage' entry is a pair, of which only the first item is kept
        stage, _ = payload['stage']
        results[os.path.basename(path)] = (stage, payload['performance'])

    return results, options

Load the experiment from its snapshots or from the cache.

In [None]:
import re
import pickle


def load_experiment(folder, cache="cache.pk"):
    """Load the scorer results of one experiment, via a pickle cache if possible.

    Only files located directly inside `folder` whose names start with a
    digit and end with ``.gz`` are treated as snapshots. The cache is used
    when it covers every such snapshot; otherwise all snapshots are re-read
    with `from_snapshots` and the cache file is rewritten.

    Args:
        folder: directory that holds the snapshots of the experiment.
        cache: file name of the cache, relative to `folder`, or None to
            neither read nor write a cache.

    Returns:
        A pair ``(scores, options)`` as produced by `from_snapshots`. Both
        are empty dictionaries if the folder has no snapshots and no cache.

    Raises:
        TypeError: if `cache` is neither a string nor None.
    """
    # explicit check instead of `assert`, which is stripped under `python -O`
    if cache is not None and not isinstance(cache, str):
        raise TypeError(
            f"`cache` must be a string or None. Got {type(cache).__name__}.")

    if cache is not None:
        cache = os.path.join(folder, cache)

    # only the top level of the folder is inspected
    folder, _, filenames = next(os.walk(folder))
    is_snapshot = re.compile(r"^\d+.*\.gz$").match
    snapshots = [name for name in sorted(filenames) if is_snapshot(name)]

    # load scorer results from the cache, if there is one
    scores, options = {}, {}
    if cache is not None and os.path.exists(cache):
        # NOTE: unpickling can execute arbitrary code, so only caches
        #  written by this function should ever be placed in `folder`.
        try:
            with open(cache, "rb") as fin:
                scores, options = pickle.load(fin)

        except (pickle.UnpicklingError, EOFError, AttributeError,
                ImportError, IndexError, ValueError):
            # a truncated or otherwise unreadable cache (e.g. left by an
            #  interrupted write) is treated as missing and rebuilt below
            scores, options = {}, {}

    # reload everything from the originals if the cache lacks any snapshot
    #  (snapshots are matched by file name only)
    if any(name not in scores for name in snapshots):
        paths = [os.path.join(folder, name) for name in snapshots]
        scores, options = from_snapshots(*paths)
        if cache is not None:
            with open(cache, "wb") as fout:
                pickle.dump((scores, options), fout)

    return scores, options

## Cache the results

In [None]:
# Root folder of the experiment grid: one sub-folder per experiment.
source = "./grids/sum__fashion-mnist/"

Collect the performance grid

In [None]:
def performance_summary(scores):
    """Tabulate the test accuracy and aggregate sparsity of every stage.

    Args:
        scores: dictionary whose values are ``(stage, results)`` pairs, with
            ``results["test"]`` holding an "accuracy" value and a "sparsity"
            dictionary of two-element records.

    Returns:
        A DataFrame indexed by stage with the columns "accuracy", "n_zer"
        and "n_par". If several entries share a stage, the last one wins.
    """
    rows = {}
    for stage, results in scores.values():
        test_score = results["test"]

        # transpose the sparsity records into two columns and total each one
        zeros, totals = zip(*test_score["sparsity"].values())
        rows[stage] = {
            "accuracy": test_score["accuracy"],
            "n_zer": int(sum(zeros)),
            "n_par": int(sum(totals)),
        }

    return pd.DataFrame.from_dict(rows, orient='index')

Collect results and reconstruct the grid

In [None]:
import pandas as pd
from collections import defaultdict

# Distinct values seen for every flattened option key, across all experiments.
grid = defaultdict(set)
# service keys, which do not take part in the grid
ignore = {"__name__", "__timestamp__", "__version__", "device"}

# One (name, performance summary, flattened options) triple per experiment.
results = []
source, experiments, manifests = next(os.walk(source))
for experiment in tqdm.tqdm(experiments):
    # experiment folders do not start with a dot and end with `__<digits>`
    match = re.match(r"^(?!\.).*__\d+$", experiment)
    if not match:
        continue

    # NOTE: the folder name is deliberately not split into parts bound to
    #  names like `copy`: that would shadow the `copy` module imported by
    #  this notebook and fail on names with just one `__` separator.

    # load scorer results from the snapshots
    scores, options = load_experiment(
        os.path.join(source, experiment),
        cache='cache.pk')

    # nothing could be loaded for this experiment
    if not options:
        continue

    flat = get_details(options)
    for k, v in flat.items():
        if k not in ignore:
            grid[k].add(v)

    results.append((
        experiment,
        performance_summary(scores),
        flat
    ))

In [None]:
# Transpose the list of (name, summary, flat options) triples into three
#  parallel tuples. This rebinds `experiments` and `manifests`, which used
#  to hold the folder and file names returned by `os.walk`: from here on
#  `manifests` holds the flattened options of the kept experiments.
experiments, scores, manifests = zip(*results)

Finalize the grid variables

In [None]:
# The grid variables are the option keys that take more than one unique
#  value, except for
#   * nested model spec changes (any key containing `__model__cls`), and
#   * `__upcast`, a service variable, which only complex models have
#     and which is usually mirrored in the `features` settings.
grid = [
    key for key, values in grid.items()
    if len(values) > 1
    and "__model__cls" not in key
    and not key.endswith("__upcast")
]

In [None]:
# One row per experiment holding its value of every grid variable
#  (None for the variables an experiment does not define).
params = pd.DataFrame.from_dict(
    {
        name: {key: flat.get(key) for key in grid}
        for name, flat in zip(experiments, manifests)
    },
    orient="index",
)

# Stack the per-experiment summaries under an outer "expno" index level
#  that holds the experiment name.
scores = pd.concat(dict(zip(experiments, scores)), names=["expno"])

In [None]:
# Move the innermost index level of `scores` (the stage) into the columns,
#  so that every experiment occupies a single row.
df = scores.unstack(-1)
# Flatten the two-level column labels into strings joined by '-': later
#  cells look the columns up by names such as "accuracy-dense".
df.columns = df.columns.to_flat_index().map('-'.join)
# Attach the grid parameters of each experiment. `reset_index` moves the
#  experiment name into an ordinary column, referred to as "index" later on.
df = params.join(df).reset_index()

In [None]:
# Compact labels for the class names found in the `model__cls` column.
model_labels = {
    "<class 'cplxpaper.mnist.models.real.SimpleConvModel'>": "real.SimpleConvModel",
    "<class 'cplxpaper.mnist.models.complex.SimpleConvModel'>": "cplx.SimpleConvModel",
    "<class 'cplxpaper.mnist.models.real.TwoLayerDenseModel'>": "real.TwoLayerDenseModel",
    "<class 'cplxpaper.mnist.models.complex.TwoLayerDenseModel'>": "cplx.TwoLayerDenseModel",
    "<class 'cplxpaper.mnist.models.real.SimpleDenseModel'>": "real.SimpleDenseModel",
    "<class 'cplxpaper.mnist.models.complex.SimpleDenseModel'>": "cplx.SimpleDenseModel",
}

# Compact labels for the class names found in the `features__cls` column.
feature_labels = {
    "<class 'cplxpaper.auto.feeds.FeedRawFeatures'>": 'raw',
    "<class 'cplxpaper.auto.feeds.FeedFourierFeatures'>": 'fourier'
}

# The replacement is column-specific: each mapping applies to its own column.
df = df.replace({
    "model__cls": model_labels,
    "features__cls": feature_labels,
})

In [None]:
# Index the table by the grid variables followed by the experiment name,
#  and sort the rows. `axis` is passed by keyword, because pandas 2.x
#  accepts the arguments of `sort_index` as keywords only.
df = df.set_index([*grid, "index"], append=False, drop=True).sort_index(axis=0)
# Grid variables other than the `__kl_div` settings: the rows are grouped
#  by these when the summary is built.
main_grid = [g for g in grid if not g.endswith('__kl_div')]

In [None]:
# For every setting of the main grid variables: the (mean, std) of the
#  accuracy of the dense model, and the sparsity-vs-accuracy points.
summary = {}
# rows are grouped along the index by default: the `axis` keyword of
#  `groupby` is deprecated in recent pandas, hence it is not passed
for k, g in df.groupby(level=main_grid):
    # drop the index levels that are constant within the group
    g = g.loc[k]

    acc_before = g["accuracy-dense"].mean(), g["accuracy-dense"].std()
    f_acc, n_par, n_zer = g["accuracy-fine-tune"], g["n_par-sparsify"], g["n_zer-sparsify"]

    # columns: fraction of zeroed parameters after the `sparsify` stage
    #  and accuracy after the `fine-tune` stage
    curve = pd.concat([n_zer / n_par, f_acc], axis=1)
    curve = curve.to_numpy()

    # order the points by increasing sparsity
    order = curve[:, 0].argsort()

    summary[k] = acc_before, curve[order]

In [None]:
from matplotlib.ticker import FormatStrFormatter, FuncFormatter

In [None]:
# Accuracy after fine-tuning against the fraction of nonzero parameters:
#  one scatter series per setting of the main grid variables.
fig, ax = plt.subplots(1, 1, figsize=(14, 5))

for name, (dense, curve) in summary.items():
    # mean and standard deviation of the accuracy of the dense model
    m, s = dense
    # columns of the curve: fraction of zeroed parameters and accuracy
    spr, acc = curve.T
    pts = ax.scatter(1 - spr, acc, label=name, s=15)
    # reuse the colour of the series for its dense-accuracy band
    color = pts.get_facecolor()[0]
#     pts, = ax.plot(1 - spr, acc, label=name)
#     color = pts.get_color()
    # horizontal band of mean +/- 1.96 std of the dense accuracy
    ax.axhspan(m-1.96*s, m+1.96*s, alpha=0.1, color=color)

ax.legend(ncol=2)
ax.set_title("Fashion MNIST")
ax.set_ylabel("accuracy")

ax.set_xlabel("% nonzero")
ax.set_xscale("log")

# show the tick values, which are fractions, as percentages with one decimal
ax.xaxis.set_major_formatter(FuncFormatter(lambda x, p: f"{x:.1%}"))
plt.show()

In [None]:
# Always raises AssertionError (unless assertions are disabled).
# NOTE(review): presumably a deliberate stop that keeps "run all" from
#  reaching the scratch cells below -- confirm before removing.
assert False

In [None]:
# Scratch: uses `g` and `n_zer` as left by the last iteration of the
#  summary loop. Per-level statistics are computed via `groupby(level=0)`,
#  since the `level` argument of `mean` / `std` was removed in pandas 2.0.
m = n_zer.groupby(level=0).mean()
s = n_zer.groupby(level=0).std()
m.plot()
plt.fill_between(m.index, m-1.96*s, m+1.96*s, alpha=0.25)
plt.gca().set_xscale("log")

# accuracies go on a secondary y-axis
plt.twinx()
m = g["accuracy-dense"].groupby(level=0).mean()
s = g["accuracy-dense"].groupby(level=0).std()
m.plot(c="C1")
plt.fill_between(m.index, m-1.96*s, m+1.96*s, alpha=0.25, color="C1")
m = g["accuracy-fine-tune"].groupby(level=0).mean()
s = g["accuracy-fine-tune"].groupby(level=0).std()
m.plot(c="C2")
plt.fill_between(m.index, m-1.96*s, m+1.96*s, alpha=0.25, color="C2")


In [None]:
# Scratch: fraction of nonzero parameters, computed from `n_par` and `n_zer`
#  as left by the last iteration of the summary loop.
(n_par - n_zer) / n_par

In [None]:
# Scratch: scatter of `1 - acc_after` against the fraction of nonzero
#  parameters, with a logarithmic x-axis.
# NOTE(review): `acc_after` is not defined in any visible cell -- this cell
#  looks stale.
comp = ((n_par - n_zer) / n_par).to_numpy()
plt.gca().set_xscale("log")
plt.gca().set_xlim(2e-3, 1.1)
plt.scatter(comp, 1-acc_after.to_numpy())

In [None]:
# Scratch: per-level mean accuracies on a logarithmic x-axis.
# NOTE(review): this cell looks stale -- the summary loop binds `acc_before`
#  to a (mean, std) tuple, and `acc_after` is not defined in any visible
#  cell. Also confirm that the installed pandas still accepts `level=` here.
acc_before.mean(level=0).plot()
acc_after.mean(level=0).plot()
plt.gca().set_xscale("log")

In [None]:
# Scratch: display `n_par` as left by the last iteration of the summary loop.
n_par

In [None]:
# Line-plot variant of the summary figure: accuracy against `1 - spr` on a
#  logarithmic x-axis, one line per setting of the main grid variables.
fig, ax = plt.subplots(1, 1, figsize=(14, 5))

for name, (dense, curve) in summary.items():
    # mean and standard deviation of the accuracy of the dense model
    m, s = dense
#     pts = ax.scatter(*curve.T, label=name, s=25)
#     color = pts.get_facecolor()[0]
    # columns of the curve: fraction of zeroed parameters and accuracy
    spr, acc = curve.T
    pts, = ax.semilogx(1-spr, acc, label=name)
    # reuse the colour of the line for its dense-accuracy band
    color = pts.get_color()

    # horizontal band of mean +/- 1.96 std of the dense accuracy
    ax.axhspan(m-1.96*s, m+1.96*s, alpha=0.1, color=color)

#     ax.set_yscale("log")
#     ax.set_xscale("log");# ax.set_xlim(0.01, 1.5)
#     ax.set_xlim(-0.05, 1.05)

ax.legend(ncol=2)
# NOTE(review): the title says "MNIST", whereas `source` points at a
#  fashion-mnist grid -- confirm which one is intended.
ax.set_title("MNIST")
ax.set_ylabel("accuracy")
ax.set_xlabel("compression")

In [None]:
# Display the summary: group key -> ((mean, std) of dense accuracy, curve).
summary

<br>