# Model viewer

In [None]:
import os

import torch
import numpy as np

%matplotlib inline
import matplotlib.pyplot as plt

Pick a snapshot.

In [None]:
# Relative path of the snapshot file to inspect. Swap in the alternative
# below to look at a different run.
snapshot = "./runs/grid_cplx_fine_kl_div/musicnet[000]-011/1-sparsify 20200112-023551.gz"
# snapshot = "../mnist/grids/sum__fashion-mnist/cplx__simpleconv__cplx-fft-raw__0__0022/1-sparsify 20200105-183555.gz"

Load it and get the configuration.

In [None]:
from cplxpaper.auto.utils import load_snapshot

# Read the snapshot and keep a handle on the options stored inside it.
cold = load_snapshot(snapshot)
options = cold["options"]

# Sparsity threshold taken from the stored options. It is expressed as
# log(p / (1 - p)) where p is the dropout rate.
threshold = options["threshold"]
# threshold = -0.5  # uncomment to override the stored value

# Keyword arguments selecting the CPU as the torch device.
devtype = {"device": torch.device("cpu")}

Get an instance of the model and load its weights.

In [None]:
from cplxpaper.auto import auto

# The snapshot's "stage" entry unpacks into the stage name and its settings.
stage = cold["stage"]
name, settings = stage

# Build a fresh state object from the model options and the stage settings.
state = auto.state_create(options["model"], settings, devtype)

# Restore the model weights; the optimizer state is restored only when
# the snapshot actually carries one.
state.model.load_state_dict(cold["model"])
if cold["optim"] is not None:
    optim_state = cold["optim"]["state"]
    state.optim.load_state_dict(optim_state)

<br>

Inspect the sparsity of each layer.

In [None]:
from cplxmodule.utils.stats import named_sparsity

# Per-parameter pairs reported by `named_sparsity`, keyed by parameter name.
# NOTE(review): each pair looks like (zeroed count, total count) -- confirm
# against `cplxmodule.utils.stats.named_sparsity`.
sparsity = dict(named_sparsity(state.model, threshold=threshold, hard=True))

# Sum the first and the second components over all parameters and report
# their ratio as a percentage.
firsts, seconds = zip(*sparsity.values())
n_zer, n_par = sum(firsts), sum(seconds)
print(f">>> {n_zer / n_par:6.1%}")

Inspect the model's performance.

In [None]:
# Key of the performance metric to display as the headline score.
score = 'average_precision'
# score = 'accuracy'

In [None]:
# Metrics recorded for the test split: a mapping from metric name to a
# sequence of values.
performance = cold["performance"]["test"]

fig, ax = plt.subplots(1, 2, figsize=(15, 3))

# Left panel: the metric selected by `score`, with its NaN-ignoring mean in
# the legend. The legend tag follows `score` so that a different metric
# (e.g. "accuracy") is not mislabelled as "AP".
score_tag = {"average_precision": "AP"}.get(score, score)
mean_ap = np.nanmean(performance[score])
ax[0].plot(performance[score], label=f"{score_tag} {mean_ap:.1%}")
ax[0].set_title(score)
ax[0].legend(ncol=2)

# Right panel: precision and recall. Each curve is drawn only if the
# snapshot recorded it, so snapshots without these metrics do not raise
# a KeyError.
# ax[1].plot(performance["accuracy"], label="acc.")
for metric, tag in (("precision", "P"), ("recall", "R")):
    if metric in performance:
        ax[1].plot(performance[metric], label=tag)
ax[1].set_title("precision / recall")

# Only draw a legend when there is something to describe.
handles, _ = ax[1].get_legend_handles_labels()
if handles:
    ax[1].legend(ncol=3)

plt.show()

In [None]:
from sklearn.metrics import precision_recall_curve
from matplotlib.collections import LineCollection

# Precision-recall curves are optional: plot them only if the snapshot has them.
if "ap_curves" in performance:
    fig, ax = plt.subplots(1, 1, figsize=(16, 7))

    # Individual curves: every entry except the one keyed 'pooled'. Each
    # entry unpacks into a (precision, recall, thresholds) triple.
    p, r, t = zip(*(
        prt for k, prt in performance["ap_curves"].items() if k != 'pooled'
    ))

    # One (recall, precision) polyline per curve, coloured along PuBuGn.
    segments = [np.stack(pair).T for pair in zip(r, p)]
    palette = plt.cm.PuBuGn(np.linspace(0, 1, num=len(p)))
    ax.add_collection(LineCollection(segments, colors=palette, alpha=0.7))

    # The pooled curve is drawn on top as a thick black line.
    p, r, t = performance["ap_curves"]["pooled"]
    ax.plot(r, p, c="k", lw=2)
    plt.show()

<br>

Remove technically duplicated parameters (useful for
computing sparsity w.r.t. raw floating point numbers,
but redundant for visualization).

In [None]:
# Collapse the `.real` / `.imag` pair of each complex parameter into a
# single entry keyed by the common prefix. The counts of the `.real` part
# are kept and the `.imag` entry is skipped.
# NOTE(review): assumes each value of `sparsity` is a pair
# (zeroed count, total count) -- confirm against `named_sparsity`.
cleaned = {}
for key, counts in sparsity.items():
    # A dedicated loop variable is used so the stage `name` defined by an
    # earlier cell is not overwritten.
    base, dot, part = key.rpartition(".")
    if dot and part in ("real", "imag"):
        if part == "imag":
            continue
        # every `.real` entry must have a matching `.imag` entry
        assert f"{base}.imag" in sparsity
        key = base

    cleaned[key] = counts

# The middle column shows `n - z`, i.e. the number of entries that are NOT
# zeroed, so it is labelled `n_nnz` rather than `n_zer`.
fields = "{:<32}   {:>5}   {:>10}".format
print(fields('name', 'n_nnz', 'sparsity'))
for k, (z, n) in cleaned.items():
    # an empty parameter has no meaningful ratio: show NaN instead of failing
    ratio = z / n if n else float("nan")
    print(fields(k, int(n - z), f"{ratio:>.1%}"))

<br>

Visualize the computed relevance scores ($\log \alpha$ of the variational dropout layers).

In [None]:
from scipy import stats
from cplxpaper.auto.objective import named_ard_modules
from ipywidgets import widgets

# Collect the `log_alpha` attribute of every module yielded by
# `named_ard_modules`, as a numpy array keyed by the module's name.
log_alphas = {}
with torch.no_grad():
    for name, submod in named_ard_modules(state.model):
        # detach and move to host memory before converting to numpy
        log_alphas[name] = submod.log_alpha.detach().cpu().numpy()


def darker(color, a=0.5):
    """Return `color` with its HLS lightness multiplied by `a`.

    The color is parsed with matplotlib's `to_rgb`, converted to HLS, its
    lightness is scaled by `a` and clipped to the [0, 1] range, and the
    result is converted back into an RGB triple.

    Adapted from https://stackoverflow.com/questions/37765197/
    """
    from matplotlib.colors import to_rgb
    from colorsys import hls_to_rgb, rgb_to_hls

    hue, lightness, saturation = rgb_to_hls(*to_rgb(color))
    clipped = max(0, min(a * lightness, 1))
    return hls_to_rgb(hue, clipped, saturation)

# The interactive viewer is only built when at least one layer was collected.
if log_alphas:
    # `None` is the first option, so initially no layer is highlighted.
    w_keys = widgets.Dropdown(options=[None, *log_alphas], description="Layer")

    @widgets.interact(layer=w_keys)
    def plot_hists(layer):
        """Overlay the log-alpha histograms of all layers, highlighting `layer`.

        Every entry of `log_alphas` is drawn as a faint step histogram. The
        entry whose name equals `layer` (if any) is instead drawn as solid
        bars with a Gaussian kernel density estimate on top. A vertical black
        line marks `threshold`.

        Parameters
        ----------
        layer : str or None
            Key of `log_alphas` to highlight, or None for no highlight.
        """
        # horizontal range shared by the KDE support and the axis limits
        x_min, x_max = -15, 40
        # the KDE is fitted on at most this many values to keep it responsive
        max_kde_points = 50000

        colors = plt.cm.jet(np.linspace(0, 1, num=len(log_alphas)))

        fig, ax = plt.subplots(1, 1, figsize=(16, 5))
        support = np.linspace(x_min, x_max, num=265)
        for (name, log_alpha), col in zip(log_alphas.items(), colors):
            selected = name == layer
            if selected:
                # solid bars, pushed behind the outlines of the other layers
                extra = dict(histtype="bar", lw=0, alpha=1., zorder=-10)
            else:
                extra = dict(histtype="step", lw=1, zorder=10, alpha=0.25)

            *_, patches = ax.hist(log_alpha.flat, label=name, bins=51,
                                  density=True, **extra, color=col)
            if not selected:
                continue

            subsample = log_alpha.ravel()
            if len(subsample) > max_kde_points:
                subsample = np.random.choice(subsample, replace=False,
                                             size=max_kde_points)

            # Use the public `scipy.stats.gaussian_kde`: the
            # `scipy.stats.kde` submodule path is deprecated.
            density = stats.gaussian_kde(subsample)

            # draw the density in a darker shade of the bars' own color
            color = darker(patches[0].get_facecolor(), 0.75)
            ax.plot(support, density(support), c=color, lw=1, zorder=10)

        ax.axvline(threshold, c="k")
        ax.legend(ncol=2, loc='upper right')
        ax.set_ylim(0, 0.5)
        ax.set_xlim(x_min, x_max)
        plt.show()

<br>