# Plots `fig:hist__and__threshold__tradeoff`

In [None]:
import os
import re

import torch
import numpy as np

In [None]:
import pandas as pd

%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
import warnings

# Install a blanket "ignore" filter so that warnings raised by the cells
# below are not shown in the notebook output.
warnings.simplefilter("ignore")

In [None]:
from cplxpaper.auto import auto

from cplxpaper.auto.utils import file_cache
from cplxpaper.auto.utils import load_stage_snapshot

In [None]:
from cplxpaper.auto.utils import load_manifest, verify_experiment

In [None]:
# Device used to evaluate the models. The preferred device is the GPU with
# index 3; fall back to the CPU on machines that do not expose that many
# GPUs, so that the notebook still runs top to bottom (only slower).
device_ = torch.device(
    "cuda:3"
    if torch.cuda.is_available() and torch.cuda.device_count() > 3
    else "cpu"
)

In [None]:
# Name of the report; it also names the cache file used further below.
report_name = "figure__musicnet__threshold"

# Absolute, normalised path of the report inside the assets folder
# (without the file extension).
report_target = os.path.normpath(
    os.path.abspath(os.path.join("../../assets", report_name))
)

# The figure is written as a PDF next to the other assets.
filename = report_target + ".pdf"

<br>

Borrowed from `figure__musicnet__trade-off.ipynb`

In [None]:
def get_model_tag(opt):
    """Build a short descriptive tag of a model from flattened options.

    Parameters
    ----------
    opt : dict
        Flattened experiment options. Must contain the key
        ``"stages__sparsify__model__cls"`` holding the ``repr`` of the model
        class, e.g. ``"<class 'pkg.models.complex.DeepConvNetVD'>"``. The
        optional keys ``"model__double"``, ``"model__half"`` and
        ``"model__legacy"`` refine the tag.

    Returns
    -------
    dict
        A dict with keys ``"model"`` (class name stripped of the kind prefix
        and the method suffix), ``"kind"`` (``"R"``, ``"R*2"``, ``"C"`` or
        ``"C/2"``) and ``"method"`` (``"VD"`` or ``"ARD"``).

    Raises
    ------
    ValueError
        If the class is neither under ``real.`` nor ``complex.``, or if its
        name ends with neither ``VD`` nor ``ARD``.
    """
    # extract the class name: keep whatever follows `.models.` in the repr
    # (raw string, so that `\.` is not treated as an invalid string escape)
    cls = opt["stages__sparsify__model__cls"]
    cls = re.sub(r"^<class '.*?\.models\.(.*?)'>$", r"\1", cls)

    # get the model kind: real/complex
    if not cls.startswith(("real.", "complex.")):
        raise ValueError("Unknown model type.")

    if cls.startswith("real."):
        kind, cls = "R", cls[5:]
    elif cls.startswith("complex."):
        kind, cls = "C", cls[8:]

    # handle real `double` and cplx `half`
    if kind == "R" and opt.get("model__double", False):
        kind = kind + "*2"
    elif kind == "C" and opt.get("model__half", False):
        kind = kind + "/2"

    # get method
    if not cls.endswith(("VD", "ARD")):
        raise ValueError("Unknown Bayesian method.")

    if cls.endswith("VD"):
        method, cls = "VD", cls[:-2]
    elif cls.endswith("ARD"):
        method, cls = "ARD", cls[:-3]

    # Legacy model patch: if not specified then True (see `musicnet.models.base`)
    if "DeepConvNet" in cls and opt.get("model__legacy", True):
        cls += " k3"

    return {"model": cls, "kind": kind, "method": method}

The same service function, as in other notebooks, to load a model (with weights)
from the given snapshot.

In [None]:
def load_model(snapshot, errors="ignore"):
    """Rebuild a model on the CPU and restore its weights from a snapshot.

    A snapshot is considered bad unless it has the keys "options", "stage"
    and "model". For a bad snapshot an empty `torch.nn.Module` is returned
    when `errors` is 'ignore', and a `ValueError` is raised when `errors`
    is 'raise'. Any other value of `errors` raises a `ValueError`.
    """
    if errors != "ignore" and errors != "raise":
        raise ValueError("`errors` must be either 'ignore' or 'raise'.")

    required = ("options", "stage", "model")
    if not all(field in snapshot for field in required):
        if errors == "raise":
            raise ValueError("Bad snapshot.")
        return torch.nn.Module()

    options = snapshot["options"]
    # the stage entry is a pair, of which only the settings are needed
    _, stage_settings = snapshot["stage"]

    # instantiate the model, keep it on the CPU, then restore the weights
    net = auto.get_model(options["model"], **stage_settings["model"])
    net.to(device=torch.device("cpu"))
    net.load_state_dict(snapshot["model"])

    return net

A dirty hack to avoid reloading the same dataset multiple times.

In [None]:
import copy
import pickle
from functools import lru_cache

@lru_cache(maxsize=None)
def _get_datasets(key):
    """Create the datasets from a pickled specification (memoized).

    The specification is passed as pickled bytes, because `lru_cache`
    needs hashable arguments.
    """
    spec = pickle.loads(key)
    return auto.get_datasets(spec)


def get_datasets(datasets):
    """Return the datasets for the given specification, reusing the cache.

    The specification is pickled and used as the cache key, so specifications
    that pickle to the same bytes share one set of loaded datasets.
    """
    key = pickle.dumps(datasets)
    return _get_datasets(key)

In [None]:
from cplxpaper.auto.parameter_grid import set_params

def get_scorers(options, threshold=None, **kwargs):
    """Create the scorers of an experiment without its training data.

    The 'musicnet-train' dataset and the 'train' feed are removed from
    copies of the specifications before the datasets and feeds are created.
    Every scorer gets `curves` disabled and `threshold` set to the given
    value, which defaults to `options["threshold"]`. Extra keyword arguments
    are handed over to `auto.get_feeds` as a single dict.
    """
    tau = options["threshold"] if threshold is None else threshold

    # datasets: everything but the train split (loaded through the cache)
    dataset_specs = copy.deepcopy(options["datasets"])
    dataset_specs.pop('musicnet-train')
    loaded = get_datasets(dataset_specs)

    # feeds: everything but the train feed
    feed_specs = copy.deepcopy(options["feeds"])
    feed_specs.pop('train')
    loaders = auto.get_feeds(loaded, kwargs, options["features"], feed_specs)

    # scorers: override `curves` and `threshold` of each one
    scorer_specs = copy.deepcopy(options['scorers'])
    overrides = {}
    for scorer in scorer_specs:
        overrides[f"{scorer}__curves"] = False
    for scorer in scorer_specs:
        overrides[f"{scorer}__threshold"] = tau
    scorer_specs = set_params(scorer_specs, **overrides)

    return auto.get_scorers(loaders, scorer_specs)

This notebook uses sparsified and fine-tuned models from each experiment.

In [None]:
from functools import lru_cache
from cplxpaper.auto.auto import state_dict_with_masks

@lru_cache(maxsize=None)
def load_experiment(folder):
    """Load the manifest and the models of an experiment (memoized).

    Returns a pair of the manifest options and a dict with the model
    restored from the "sparsify" stage snapshot under "sparsify", and the
    model restored from the "fine-tune" stage snapshot under "masked".
    """
    manifest = load_manifest(folder)

    models = {}
    for label, stage in (("sparsify", "sparsify"), ("masked", "fine-tune")):
        models[label] = load_model(load_stage_snapshot(stage, folder))

    return manifest, models

@file_cache(f"./cache__{report_name}.pk")
def evaluate_experiment(folder, threshold, name="test"):
    """Score the masked model of an experiment at the given threshold.

    Parameters
    ----------
    folder : str
        Folder of the experiment, passed on to `load_experiment`.
    threshold : float
        Threshold used both by the scorers and for computing the masks.
    name : str, default "test"
        Key of the scorer to run.

    Returns
    -------
    tuple
        The pair `(options, result)` of the experiment's manifest options
        and the output of the selected scorer.

    Notes
    -----
    The models come from the memoized `load_experiment`, so the "masked"
    model is shared between calls and has its state overwritten here.
    NOTE(review): `file_cache` presumably persists the results in the given
    pickle file -- confirm against `cplxpaper.auto.utils`.
    """
    device = torch.device(device_)

    # get the models and the scorer
    options, models = load_experiment(folder)
    scorers = get_scorers(options, threshold=threshold, device=device)

    model = models["masked"].to(device)
    try:
        with torch.no_grad():
            # get the masks (the separately returned masks are not needed)
            state_dict, _ = state_dict_with_masks(
                models["sparsify"], threshold=threshold, hard=True)

            # copy weights and deploy masks onto the next-stage model
            model.load_state_dict(state_dict, strict=False)
            result = scorers[name](model.eval())

    finally:
        # always release the device, even if the evaluation has failed,
        # since the cached model would otherwise stay on it
        model.cpu()

    return options, result

<br>

In [None]:
# Folders of the experiments to evaluate: five folders per method (ARD and
# VD) in each section. Commented-out sections are not evaluated.
# NOTE(review): the fractions in the section headers (1/200, 1/20, ...) are
# presumably the KL-divergence coefficient $C$ of the sparsify stage -- confirm.
experiments = [
# # 1.5/200  # the best model
#     './grids/grid-fast/musicnet-fast__00/musicnet[003]-098',  # ARD
#     './grids/grid-fast/musicnet-fast__01/musicnet[001]-046',
#     './grids/grid-fast/musicnet-fast__02/musicnet[004]-124',
#     './grids/grid-fast/musicnet-fast__03/musicnet[002]-072',
#     './grids/grid-fast/musicnet-fast__04/musicnet[000]-020',

#     './grids/grid-fast/musicnet-fast__00/musicnet[003]-085',  # VD
#     './grids/grid-fast/musicnet-fast__01/musicnet[001]-033',
#     './grids/grid-fast/musicnet-fast__02/musicnet[004]-111',
#     './grids/grid-fast/musicnet-fast__03/musicnet[002]-059',
#     './grids/grid-fast/musicnet-fast__04/musicnet[000]-007',

# 1/200
    './grids/grid-fast/musicnet-fast__00/musicnet[003]-097',  # ARD
    './grids/grid-fast/musicnet-fast__01/musicnet[001]-045',
    './grids/grid-fast/musicnet-fast__02/musicnet[004]-123',
    './grids/grid-fast/musicnet-fast__03/musicnet[002]-071',
    './grids/grid-fast/musicnet-fast__04/musicnet[000]-019',

    './grids/grid-fast/musicnet-fast__00/musicnet[003]-084',  # VD
    './grids/grid-fast/musicnet-fast__01/musicnet[001]-032',
    './grids/grid-fast/musicnet-fast__02/musicnet[004]-110',
    './grids/grid-fast/musicnet-fast__03/musicnet[002]-058',
    './grids/grid-fast/musicnet-fast__04/musicnet[000]-006',
    
# # 1/2000
#     './grids/grid-fast/musicnet-fast__00/musicnet[003]-093',  # ARD
#     './grids/grid-fast/musicnet-fast__01/musicnet[001]-041',
#     './grids/grid-fast/musicnet-fast__02/musicnet[004]-119',
#     './grids/grid-fast/musicnet-fast__03/musicnet[002]-067',
#     './grids/grid-fast/musicnet-fast__04/musicnet[000]-015',

#     './grids/grid-fast/musicnet-fast__00/musicnet[003]-080',  # VD
#     './grids/grid-fast/musicnet-fast__01/musicnet[001]-028',
#     './grids/grid-fast/musicnet-fast__02/musicnet[004]-106',
#     './grids/grid-fast/musicnet-fast__03/musicnet[002]-054',
#     './grids/grid-fast/musicnet-fast__04/musicnet[000]-002',

# 1/20
    './grids/grid-fast/musicnet-fast__00/musicnet[003]-101',  # ARD
    './grids/grid-fast/musicnet-fast__01/musicnet[001]-049',
    './grids/grid-fast/musicnet-fast__02/musicnet[004]-127',
    './grids/grid-fast/musicnet-fast__03/musicnet[002]-075',
    './grids/grid-fast/musicnet-fast__04/musicnet[000]-023',

    './grids/grid-fast/musicnet-fast__00/musicnet[003]-088',  # VD
    './grids/grid-fast/musicnet-fast__01/musicnet[001]-036',
    './grids/grid-fast/musicnet-fast__02/musicnet[004]-114',
    './grids/grid-fast/musicnet-fast__03/musicnet[002]-062',
    './grids/grid-fast/musicnet-fast__04/musicnet[000]-010',
]

Collect the results from the specified experiments

In [None]:
import tqdm
from itertools import starmap, product
from cplxpaper.auto.parameter_grid import flatten

# the grid of thresholds, at which every experiment is evaluated
thresholds = np.linspace(-3, 3, num=25)

# label of the model -> list of (flattened options, measurements) pairs
raw_results = {}
for experiment, tau in tqdm.tqdm(list(product(experiments, thresholds))):
    # the default "test" scorer is used: "validation" was found to be too
    # slow, and the intent of using it was unclear
    options, score = evaluate_experiment(experiment, tau)
    options = flatten(options)

    # experiments with the same tag and KL coefficient share one label
    key = "{kind} {model} {method} ($C={kld}$)".format(
        kld=options["stages__sparsify__objective__kl_div"],
        **get_model_tag(options)
    )

    # total the first and the second entries of the sparsity pairs, which
    # are used below as the number of zeros and of parameters, respectively
    n_zer, n_par = (sum(column) for column in zip(*score["sparsity"].values()))

    bucket = raw_results.setdefault(key, [])
    bucket.append((
        options,
        dict(
            tau=tau,
            score=score["pooled_average_precision"],
            compression=n_par / (n_par - n_zer),
        ),
    ))

Group raw results in $\tau$ buckets.

In [None]:
# label of the model -> (options of its first record, frame of measurements
# indexed by the threshold and by the replication within the threshold)
results = {}
for key, result in raw_results.items():
    options, data = zip(*result)

    # one sub-frame per value of tau, with the now redundant column dropped
    frame = pd.DataFrame(data)
    groups = {
        tau: part.drop(columns="tau")
        for tau, part in frame.groupby("tau")
    }

    stacked = pd.concat(groups, axis=0, names=["tau", "replication"])
    results[key] = options[0], stacked

Color scheme borrowed from `figure__musicnet__trade-off.ipynb`

In [None]:
def darker(color, a=0.5):
    """Scale the lightness of a colour by the factor `a`.

    The colour is converted to HLS, its lightness is multiplied by `a` and
    clipped to [0, 1], and the result is returned as an RGB triplet. Hence
    `a < 1` darkens the colour, whereas `a > 1` makes it lighter.

    Adapted from this stackoverflow question_.
    .. _question: https://stackoverflow.com/questions/37765197/
    """
    from matplotlib.colors import to_rgb
    from colorsys import rgb_to_hls, hls_to_rgb

    hue, lightness, saturation = rgb_to_hls(*to_rgb(color))

    # keep the rescaled lightness within the valid range
    rescaled = min(a * lightness, 1)
    rescaled = max(0, rescaled)

    return hls_to_rgb(hue, rescaled, saturation)


def kind_model_method_color(kind, model, method, kld):
    """Look up the matplotlib colour alias of a model configuration.

    Only the complex ("C") "DeepConvNet" is known, with the method being
    either "VD" or "ARD", and `kld` being one of 1/200, 1/20 or 1/2000.
    Any other combination raises a `KeyError`.
    """
    palette = {}

    # each value of `kld` gets a pair of consecutive colours: the even one
    # is for VD, the odd one is for ARD
    for pair, coef in enumerate((1/200, 1/20, 1/2000)):
        for offset, name in enumerate(("VD", "ARD")):
            palette["C", "DeepConvNet", name, coef] = f"C{2 * pair + offset}"

    return palette[kind, model, method, kld]

Do a crude plot

In [None]:
# One figure with two y-axes over the shared threshold axis: the score is
# drawn against the left axis, the compression rate -- against the right one.
fig, ax_l = plt.subplots(1, 1, figsize=(12, 5), dpi=300)
fig.patch.set_alpha(1.0)

ax_r = ax_l.twinx()

ax_l.set_ylabel("Average Precision")
ax_r.set_ylabel("$\\times$ compression")
ax_r.set_yscale("log")
ax_l.set_xlabel("Threshold $\\tau$")
ax_l.set_title("The effect of $\\tau$ on performance and compression (MusicNet)")

# vertical reference line at $\tau = -0.5$
# NOTE(review): presumably the threshold used in the experiments -- confirm
ax_l.axvline(-0.5, c="k", lw=2, zorder=-10)

for experiment, (options, df) in tqdm.tqdm(results.items(), desc="populating plots"):
    kld = options["stages__sparsify__objective__kl_div"]

    # aggregate over the replications within each value of tau: grouping by
    # the first index level replaces the `level=` keyword of the reductions,
    # which is no longer available in pandas 2.0 and later
    by_tau = df.groupby(level=0)
    m, min_, max_ = by_tau.mean(), by_tau.min(), by_tau.max()

    model = get_model_tag(options)
    color = kind_model_method_color(**model, kld=kld)

    # the band spans the min-max range, and the line follows the mean;
    # only the compression curve is drawn with markers
    for ax, field, marker in zip([ax_l, ax_r], ["score", "compression"], ["", "o"]):
        ax.fill_between(m.index, min_[field], max_[field],
                        color=darker(color, 1.4), alpha=0.25, zorder=-10)
        ax.plot(m[field], c=color, alpha=1.0, label=experiment,
                marker=marker, markersize=4)

ax_l.legend(ncol=1, loc=(0.5, .35))  # loc="center right")

fig.savefig(filename, dpi=300)

plt.show()
plt.close()

In [None]:
# Always raises an AssertionError, so that "Run All" stops at this cell.
# NOTE(review): presumably guards scratch cells further below -- confirm.
assert False

<br>