In [None]:
import numpy as np

import torch
import torch.nn.functional as F

<br>

```python
import tqdm

# ...

[*map(tqdm.tqdm._decr_instances, list(tqdm.tqdm._instances))]
```

<br>

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
# Compute device used for training and scoring. The third GPU ("cuda:2") is
# kept as the first choice; when fewer than three GPUs are visible fall back
# to the first GPU, and to the CPU when CUDA is unavailable, so that the
# notebook still survives "Restart & Run All" on other machines.
if torch.cuda.is_available():
    device_ = torch.device(
        "cuda:2" if torch.cuda.device_count() > 2 else "cuda:0")
else:
    device_ = torch.device("cpu")

ARD penalty class

In [None]:
from cplxmodule.nn.relevance.base import BaseARD
from torch.nn.modules.loss import _Loss


class Penalty(_Loss):
    """Base class for penalties that are computed from a module.

    Parameters
    ----------
    mean : bool
        When truthy the `reduction` attribute is `torch.mean`, otherwise it
        is `torch.sum`.
    """
    def __init__(self, mean):
        super().__init__()
        # replace the string reduction set by `_Loss` with a callable
        if mean:
            self.reduction = torch.mean
        else:
            self.reduction = torch.sum

    def forward(self, module, target=None):
        """Compute the penalty of `module`; subclasses must override this."""
        raise NotImplementedError


def named_ard_modules(module, prefix=""):
    """Yield `(name, submodule)` for every `BaseARD` submodule of `module`.

    The names are those produced by `module.named_modules(prefix=prefix)`.
    """
    candidates = module.named_modules(prefix=prefix)
    yield from (
        (name, mod) for name, mod in candidates if isinstance(mod, BaseARD)
    )


class ARDPenalty(Penalty):
    """Weighted sum of the penalties of all ARD submodules of a module.

    Parameters
    ----------
    coef : int, float, dict, or callable, default=1.
        Per-submodule coefficient. A number is used for every submodule;
        a dict is looked up by submodule name; a callable is called with
        the submodule name. A lookup that yields `None` (e.g. a name that
        is missing from the dict) is treated as a coefficient of `1.`.

    mean : bool, default=False
        Reduce each submodule's penalty with `torch.mean` if True, with
        `torch.sum` otherwise.

    Raises
    ------
    TypeError
        If `coef` is not a number, a dict, or a callable.
    """
    def __init__(self, coef=1., mean=False):
        super().__init__(mean=mean)
        if isinstance(coef, dict):
            self.coef = coef.get

        elif callable(coef):
            self.coef = coef

        elif isinstance(coef, (int, float)):
            # ints (e.g. `coef=1`) used to fall through every branch and
            #  leave `self.coef` undefined until `forward` was called
            self.coef = lambda n: coef

        else:
            raise TypeError("`coef` must be a number, a dict, or a callable."
                            f" Got `{type(coef).__name__}`.")

    def forward(self, module, target=None):
        """Reimplements `named_penalties` with non-uniform coefficients.

        Returns the sum of `coef * reduction(submodule.penalty)` over the
        ARD submodules with a positive coefficient. When there are no such
        submodules the result is the plain integer `0` (the start value
        of `sum`).
        """
        # get names of variational modules and prefetch coefficients
        submods = dict(named_ard_modules(module))
        # names, submods = zip(*named_ard_modules(module))  # can't handle empty iterators
        weights = (weight if weight is not None else 1.
                   for weight in map(self.coef, submods.keys()))

        # lazy-compute the weighted sum, skipping non-positive coefficients
        #  so that their penalties are never evaluated
        return sum(weight * self.reduction(mod.penalty)
                   for mod, weight in zip(submods.values(), weights)
                   if weight > 0)

<br>

test this out

In [None]:
class FeedWrapper(object):
    """Iterate over a data loader, moving each batch with `Tensor.to`.

    Parameters
    ----------
    feed : torch.utils.data.DataLoader
        The data loader instance to be wrapped.

    **kwargs : keyword arguments
        Forwarded to `torch.Tensor.to()` for every element of a batch.
        When no keyword arguments are given, batches are passed through
        exactly as the loader produced them.
    """
    def __init__(self, feed, **kwargs):
        assert isinstance(feed, torch.utils.data.DataLoader)
        self.feed = feed
        self.kwargs = kwargs

    def __len__(self):
        # number of batches, as reported by the wrapped loader
        return len(self.feed)

    def __iter__(self):
        if self.kwargs:
            # convert every element of the batch and repack as a tuple
            for batch in self.feed:
                yield tuple(item.to(**self.kwargs) for item in batch)
            return

        # nothing to convert: hand the batches through untouched
        yield from self.feed

In [None]:
import tqdm
from dpd.tools.delayed import DelayedKeyboardInterrupt
from torch.nn.utils import clip_grad_norm_

def fit(model, feed, optim, criterion, penalty=None, sched=None,
        n_epochs=100, klw=1e-2, grad_clip=0., verbose=True,
        max_batches=1000):
    """Train `model` on batches from `feed`.

    Parameters
    ----------
    model : torch.nn.Module
        The model to train; it is switched to train mode on entry.
    feed : iterable
        Yields `(data, target)` pairs, re-iterated once per epoch.
    optim : torch.optim.Optimizer
        Optimizer over the parameters of `model`.
    criterion : callable
        Called as `criterion(model(data), target)`.
    penalty : callable, optional
        Called as `penalty(model)`; its value times `klw` is added to the
        criterion. When None the penalty term is zero.
    sched : learning rate scheduler, optional
        Stepped once at the end of each completed epoch.
        `ReduceLROnPlateau` is stepped with the mean epoch loss.
    n_epochs : int, default=100
        Maximal number of epochs.
    klw : float, default=1e-2
        Weight of the penalty term.
    grad_clip : float, default=0.
        When positive, gradients are clipped to this total norm.
    verbose : bool, default=True
        Show a progress bar with the current loss terms.
    max_batches : int, default=1000
        Upper bound on the number of batches used in each epoch.

    Returns
    -------
    model : torch.nn.Module
        The model in eval mode.
    history : list of (float, float)
        Per-batch values of the criterion and the penalty.
    flag
        Truthy if training stopped early, either on a NaN loss or through
        the `DelayedKeyboardInterrupt` context.
    """
    # `ReduceLROnPlateau` was referenced without ever being imported, which
    #  raised a NameError as soon as any scheduler was passed in
    ReduceLROnPlateau = torch.optim.lr_scheduler.ReduceLROnPlateau

    model.train()
    history, losses, abort = [], [], False
    with tqdm.tqdm(range(n_epochs), disable=not verbose) as bar, \
            DelayedKeyboardInterrupt("ignore") as stop:
        for epoch in bar:
            epoch_loss, kl_d, grad_norm = [], 0., float("nan")
            # `range` comes first in the zip, so no extra batch is drawn
            #  from the feed once the cap is reached
            for j, (data, target) in zip(range(max_batches), feed):
                optim.zero_grad()

                crit = criterion(model(data), target)
                if penalty is not None:
                    kl_d = penalty(model)
                loss = crit + klw * kl_d

                loss.backward()
                if grad_clip > 0:
                    grad_norm = float(
                        clip_grad_norm_(model.parameters(), grad_clip))

                optim.step()
                if verbose:
                    bar.set_postfix_str(
                        f"{float(crit):.3e} {float(kl_d):.3e} |g| {grad_norm:.1e}"
                    )

                history.append((float(crit), float(kl_d)))
                epoch_loss.append(float(loss))

                # abort on nan -- no need to waste compute
                abort = np.isnan(epoch_loss[-1])
                if abort or stop:
                    break

            if abort or stop:
                break

            if sched is not None:
                # exclusions to `.step` api only apply to ReduceLROnPlateau
                if isinstance(sched, ReduceLROnPlateau):
                    sched.step(np.mean(epoch_loss))
                else:
                    sched.step()
        # end for
    # end with

    return model.eval(), history, abort or stop

In [None]:
def predict(model, feed):
    """Run `model` in eval mode over `feed` and concatenate the outputs.

    Only the first element of every batch is used. It is moved to the
    device of the model's first parameter, and each output is moved to
    the CPU before concatenation along the first dimension.
    """
    model.eval()
    device = next(model.parameters()).device

    outputs = []
    with torch.no_grad():
        for X, *_ in feed:
            outputs.append(model(X.to(device)).cpu())

        return torch.cat(outputs, dim=0)

<br>

In [None]:
from cplxpaper.musicnet.trabelsi2017 import TwoLayerDense
from cplxpaper.musicnet.trabelsi2017.extensions import TwoLayerDenseARD
from cplxpaper.musicnet.trabelsi2017.extensions import TwoLayerDenseMasked

In [None]:
from cplxpaper.musicnet.trabelsi2017 import DeepConvNet
from cplxpaper.musicnet.trabelsi2017.extensions import DeepConvNetARD
from cplxpaper.musicnet.trabelsi2017.extensions import DeepConvNetMasked

* Learning-rate schedule by epoch: `1e-3` for epochs `<10`, `1e-4` for `<100`, `5e-5` for `<120`, `1e-5` for `<150`, and `1e-6` for epochs `>=150`.


In [None]:
def lr_lambda(epoch):
    """Multiplicative learning-rate factor for the given epoch.

    Intended for `torch.optim.lr_scheduler.LambdaLR` with a base learning
    rate of 1e-3, so that the effective rate follows the schedule
    1e-3 (<10), 1e-4 (<100), 5e-5 (<120), 1e-5 (<150), 1e-6 (>=150).
    """
    if epoch < 10:
        return 1e-0
    if epoch < 100:
        return 1e-1
    if epoch < 120:
        # was 2e-1, which *raised* the rate to 2e-4 instead of the
        #  scheduled 5e-5
        return 5e-2
    if epoch < 150:
        return 1e-2
    return 1e-3

<br>

In [None]:
# Two-layer dense variants keyed by stage name. The training cell iterates
# this dict in insertion order and initialises each model from the previous
# one. NOTE(review): the next cell rebinds `models` to the conv-net variants,
# so this dict is only used if that cell is skipped.
models = {
    "dense": TwoLayerDense(),
    "bayes": TwoLayerDenseARD(),  # (imbalanced) AP 12% compression 15.2%
    "masked": TwoLayerDenseMasked()
}

In [None]:
# Deep conv-net variants keyed by stage name. The training cell iterates this
# dict in insertion order and initialises each model from the previous one.
# NOTE(review): this rebinds `models`, replacing the dense variants defined in
# the preceding cell.
models = {
    "dense": DeepConvNet(),
    "bayes": DeepConvNetARD(),
    "masked": DeepConvNetMasked()
}

<br>

# Train

In [None]:
import h5py
from cplxpaper.musicnet import MusicNetHDF5

# Open the training set read-only and wrap it in a dataset.
# NOTE(review): the HDF5 handle is never closed in this notebook, and
# `resident=True` presumably keeps the data in memory -- confirm.
h5_in = h5py.File("./data/musicnet_11khz_train.h5", "r")
dataset = MusicNetHDF5(h5_in, resident=True)

```python
def get_note_dur(dataset):
    lab = dataset["labels"]
    note_id = lab["note_id"] - 21
    duration = lab["end_time"] - lab["start_time"]

    f_samples = float(len(dataset["data"]))
    return note_id, duration / f_samples

note_durs = (get_note_dur(h5_in[f"{key}"]) for key in tqdm.tqdm(h5_in))
notes, durs = map(np.concatenate, zip(*note_durs))
```

The dataset is imbalanced: some musical notes occur more frequently than others.

In [None]:
def note_proba(dataset):
    """Estimate, for each of 84 note bins, the probability that the note
    is playing at a random time within the composition.

    `dataset` must provide `"data"` (only its length is used) and `"labels"`
    with the fields `"note_id"`, `"start_time"` and `"end_time"`. Note ids
    are shifted by 21 to obtain the bin index.
    """
    labels = dataset["labels"]
    n_samples = len(dataset["data"])

    # total time each note is on, accumulated per bin
    note_bins = labels["note_id"] - 21
    durations = labels["end_time"] - labels["start_time"]
    total_duration = np.bincount(note_bins, weights=durations, minlength=84)

    return total_duration / float(n_samples)

# Stack the per-entry note probabilities column-wise: axis 0 indexes the note
# bins returned by `note_proba`, axis 1 the entries of the HDF5 file.
train_proba = np.stack(list(map(note_proba, tqdm.tqdm(h5_in.values()))), axis=1)

Average probability across all 321 compositions. Clamp to prevent saturation.

In [None]:
# Average over the entries (axis 1) and clip away from 0 and 1, so that the
# logit of these probabilities taken further below stays finite.
proba_hat = train_proba.mean(axis=1).clip(1e-6, 1 - 1e-6)

Clearly, we need to rebalance the labels.

Given $n_-$ negative and $n_+$ positive samples ($m = n_- + n_+$ in total),
the desired balance $\alpha$ and average weight $w$, choose weights $w_-$
and $w_+$ such that:
$$
w_- n_- + w_+ n_+ = w m
    \,,
    n_+ w_+ = \alpha n_- w_-
    \,. $$

Therefore
$$
w_- = \frac{w m}{n_-} \frac1{1 + \alpha}
    \,,
    w_+ = \frac{w m}{n_+} \frac{\alpha}{1 + \alpha}
    \,. $$

If $w_-$ is fixed to $1$, then
$$
w_+
    = \alpha \frac{n_-}{n_+}
    = \alpha \biggl( \frac{m}{n_+} - 1\biggr)
    = \alpha \frac{1-p_+}{p_+}
    \,. $$

In [None]:
from scipy.special import logit

alpha = 0.5  # 2:1 neg-pos balance

# exp(-logit(p)) equals (1 - p) / p, so this is alpha * (1 - p) / p per note
pos_weight = np.exp(-logit(proba_hat)) * alpha

pos_weight_clip = pos_weight.clip(max=1e2)  # clip weights to avoid overflows

# float32 tensor copy of the clipped weights (used by the commented-out
# `pos_weight=` criterion in the training cell)
tr_pos_weight = torch.from_numpy(pos_weight_clip).float()

Look at them smiles and grins

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(15, 3))

# left: raw and clipped positive-class weights per note bin, log scale
ax[0].semilogy(pos_weight)
ax[0].semilogy(pos_weight_clip)

# right: estimated note probability times its clipped weight
ax[1].plot(proba_hat * pos_weight_clip)
ax[1].set_ylim(-0.05)  # only the lower limit is set

plt.show()

Create train feed with fft features.

In [None]:
from scipy.fftpack import fft
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

def fft_collate_fn(batch):
    """Collate `(signal, target)` samples into FFT features and targets.

    The signals are stacked, transformed with an FFT along the last axis,
    and the real and imaginary parts are placed on a new second-to-last
    axis. The result is handed, sample by sample, to `default_collate`.
    """
    signals, labels = zip(*batch)
    data, target = np.stack(signals), np.stack(labels)

    # produce features real/cplx, fft/stft/raw
    spectrum = fft(data, axis=-1)
    features = np.stack([spectrum.real, spectrum.imag], axis=-2)

    samples = list(zip(features, target))
    return default_collate(samples)

# Shuffled training loader producing FFT features through `fft_collate_fn`
feed = DataLoader(dataset, batch_size=320, shuffle=True,
                  collate_fn=fft_collate_fn, pin_memory=True)
# same batches, with every tensor moved to `device_` on the fly
feed_dev = FeedWrapper(feed, device=device_)

Create losses and penalties and train

In [None]:
from cplxmodule.nn.masked import binarize_masks
from cplxmodule.nn.relevance import compute_ard_masks

# Train the models in `models` one after another; each model after the first
# is initialised from the state and the ARD masks of the previous one.
# (baseline) 98.7% compression, AP 56%
model, threshold = None, 0.

criterion = torch.nn.BCEWithLogitsLoss(reduction="mean")
# criterion = torch.nn.BCEWithLogitsLoss(reduction="mean", pos_weight=tr_pos_weight).to(device_)
# (try, no dense, grad_clip=0, batch 64) not that much slower than 32, AP 39% compression 95%

penalty = ARDPenalty(mean=True)
# penalty = ARDPenalty(lambda n: (1e-3 if "conv" in n else 1.), mean=True)
# (try, no dense, grad_clip=0.05)  : gives 90% compression, valid-precision ~0.7-0.8

histories = []
for name, new_model in models.items():
    print(">>> ", name, flush=True)
    if model is not None:
        # carry the previous stage over: its state dict plus the masks
        #  computed from it (`strict=False` tolerates mismatching keys)
        state_dict = model.state_dict()
        masks = compute_ard_masks(model, threshold=threshold, hard=True)
        state_dict, masks = binarize_masks(state_dict, masks)
        state_dict.update(masks)
        new_model.load_state_dict(state_dict, strict=False)

        # (optional) crudely estimate sparsity from masks (ignores
        #  non-dropout layers and biases)
        if masks:
            nnz = int(sum(map(torch.sum, masks.values())))
            nel = int(sum(map(torch.numel, masks.values())))
            print(f">>> {1 - nnz / nel:4.1%}", flush=True)
    
    # create stuff
    model = new_model.to(device_)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    # NOTE(review): `sched` is created but never passed to `fit` below, so
    #  the `lr_lambda` schedule is not applied -- confirm this is intended.
    sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda)

    # NOTE(review): `flag` (early-stop indicator) is not checked afterwards
    model, history, flag = fit(model, feed_dev, optim, criterion,
                               penalty, klw=1e-1, n_epochs=20, grad_clip=0.05)
    # (try, no dense, grad_clip=0.05) klw=1e-3 : gives 90% compression, valid-precision ~0.6-0.7

    # move the trained model back to the CPU before the next stage
    model.cpu()
    histories.append((name, history))

<br>

In [None]:
from scipy import stats
from cplxpaper.auto.objective import named_ard_modules
from ipywidgets import widgets

# Collect `log_alpha` of every ARD submodule of the last trained `model` as
# numpy arrays keyed by submodule name.
# NOTE(review): `named_ard_modules` is the one imported from
# `cplxpaper.auto.objective` in this cell, which shadows the function of the
# same name defined earlier in the notebook.
log_alphas = {}
with torch.no_grad():
    for name, submod in named_ard_modules(model):
        log_alpha = submod.log_alpha.detach().cpu()
        log_alphas[name] = log_alpha.numpy()


def darker(color, a=0.5):
    """Scale the lightness of `color` by `a` in HLS space.

    The scaled lightness is limited to the range [0, 1]; hue and saturation
    are kept. Returns an RGB triple.

    Adapted from this stackoverflow question_.
    .. _question: https://stackoverflow.com/questions/37765197/
    """
    from colorsys import hls_to_rgb, rgb_to_hls
    from matplotlib.colors import to_rgb

    hue, lightness, saturation = rgb_to_hls(*to_rgb(color))
    scaled = min(a * lightness, 1)
    return hls_to_rgb(hue, max(0, scaled), saturation)

In [None]:
# Interactive histogram of `log_alpha` per ARD layer. The layer selected in
# the dropdown is drawn as filled bars with a kernel density estimate on top;
# the other layers are drawn as step outlines.
if log_alphas:
    w_keys = widgets.Dropdown(options=list(log_alphas), description="Layer")

    @widgets.interact(layer=w_keys)
    def plot_hists(layer):
        """Plot the `log_alpha` histograms, highlighting `layer`."""
        fig, ax = plt.subplots(1, 1, figsize=(16, 5))
        support = np.linspace(-12, 40, num=265)
        for name, log_alpha in log_alphas.items():
            values = log_alpha.ravel()
            if name != layer:
                extra = dict(histtype="step", lw=1, zorder=10)
            else:
                extra = dict(histtype="bar", lw=0, alpha=1., zorder=-10)

            *_, patches = ax.hist(values, label=name, bins=51,
                                  density=True, **extra)
            if name == layer:
                # cap the sample size to keep the density estimate cheap
                subsample = values
                if len(subsample) > 50000:
                    subsample = np.random.choice(
                        subsample, replace=False, size=50000)

                # use the public `stats.gaussian_kde`: the `stats.kde`
                #  namespace is deprecated
                density = stats.gaussian_kde(subsample)

                color = darker(patches[0].get_facecolor(), 0.75)
                ax.plot(support, density(support), c=color, lw=3, zorder=10)

        ax.legend(ncol=2)
        plt.show()

In [None]:
# Drop the stage names and keep the per-stage histories
_, history = zip(*histories)

from itertools import chain

# Chain the histories of all stages into two per-batch series: the criterion
# and the penalty values recorded by `fit`.
crit, kl_d = map(np.array, zip(*chain(*history)))
plt.semilogy(crit)
plt.plot(kl_d)

In [None]:
from cplxmodule.utils.stats import sparsity, named_sparsity

# Report the sparsity of every trained model at the training threshold.
# NOTE(review): the loop rebinds the notebook-level `model` and `name`; after
# it `model` is the last entry of `models`.
for name, model in models.items():
    print(f">>> {name:10} {sparsity(model, threshold=threshold):6.1%}")

# for name, model in models.items():
#     print(f">>> {name}", [*named_sparsity(model, threshold=threshold)])

<br>

# Scoring

In [None]:
# Open the validation set read-only and wrap it in a dataset.
# NOTE(review): the handle is never closed in this notebook; the exact meaning
# of `stride=128` is defined by `MusicNetHDF5` -- confirm.
h5_in_valid = h5py.File("./data/musicnet_11khz_valid.h5", "r")
dataset_valid = MusicNetHDF5(h5_in_valid, stride=128, resident=True)

In [None]:
# Validation loader in a fixed order (no shuffling), so that predictions and
# targets collected in separate passes stay aligned.
feed_valid = DataLoader(dataset_valid, batch_size=256, shuffle=False,
                        collate_fn=fft_collate_fn, pin_memory=True)
# same batches, with every tensor moved to `device_` on the fly
feed_valid_dev = FeedWrapper(feed_valid, device=device_)

## Here

In [None]:
def predict(model, feed):
    """Concatenate the outputs of `model` over all batches of `feed`.

    The model is put in eval mode and run without gradient tracking. Only
    the first element of each batch is fed to it, after being moved to the
    device of the model's first parameter; outputs are gathered on the CPU.
    """
    model.eval()
    device = next(model.parameters()).device

    with torch.no_grad():
        chunks = []
        for batch in feed:
            X, *rest = batch
            chunks.append(model(X.to(device)).cpu())

        return torch.cat(chunks, dim=0)

# Logits of the "bayes" model over the validation loader.
# NOTE(review): a later cell rebinds `logits` to a numpy array.
logits = predict(models["bayes"], tqdm.tqdm(feed_valid))

In [None]:
# Targets of the whole validation set: the last element of every batch,
# concatenated along the first dimension.
target = torch.cat([target.cpu() for *rest, target in tqdm.tqdm(feed_valid)], dim=0)

<br>

In [None]:
# Score the "bayes" model on the validation set: collect the logits and the
# integer targets batch by batch, then threshold the logits at zero.
model = models["bayes"].to(device_)
# put the model that is actually evaluated in eval mode (`eval` used to be
#  called on whatever `model` was bound to before this cell)
model.eval()
with torch.no_grad():
    logits, y_true = [], []
    for inputs, target in tqdm.tqdm(feed_valid_dev):
        logits.append(model(inputs).cpu().numpy())
        # the `np.int` alias was removed from numpy; use the builtin `int`
        y_true.append(target.cpu().numpy().astype(int))

    model.cpu()

logits, y_true = map(np.concatenate, (logits, y_true))
# a non-negative logit corresponds to a predicted probability of at least 0.5
y_pred = (logits >= 0).astype(int)

In [None]:
# Compare the per-note positive rate on the validation targets ("true") with
# the probability estimated from the training set ("est."): raw on the left,
# multiplied by the clipped positive-class weights on the right.
fig, ax = plt.subplots(1, 2, figsize=(15, 3))

ax[0].plot(y_true.mean(axis=0), label="true")
ax[0].plot(proba_hat, label="est.")
# `plt.legend()` only annotates the current (last) axes, so each panel
#  gets its own legend explicitly
ax[0].legend()

ax[1].plot(y_true.mean(axis=0) * pos_weight_clip, label="true")
ax[1].plot(proba_hat * pos_weight_clip, label="est.")
ax[1].legend()

plt.show()

In [None]:
from sklearn.metrics import confusion_matrix

# One 2x2 confusion matrix per output column, stacked along the last axis,
# giving an array of shape (2, 2, 84). `labels=[0, 1]` fixes the layout even
# when a column contains a single class.
confusion_cube = np.stack([
    confusion_matrix(y_true[:, j], y_pred[:, j], labels=[0, 1])
    for j in tqdm.tqdm(range(84))
], axis=-1)

In [None]:
# Unpack the per-note counts (rows are true labels, columns are predictions)
(tn, fp), (fn, tp) = confusion_cube

accuracy = (tp + tn) / (tp + tn + fp + fn)
# the denominators are floored at 1 to avoid division by zero; the ratio is
# then 0 for notes with no predicted (resp. no actual) positives
precision = tp / np.maximum(tp + fp, 1)
recall = tp / np.maximum(tp + fn, 1)

Compute per-note aggregate metrics suitable for imbalanced classification.

In [None]:
from sklearn.metrics import average_precision_score, roc_auc_score

# `average=None` keeps one average-precision score per output column
average_precision = average_precision_score(
    y_true, logits, average=None, pos_label=1)

# roc_auc = roc_auc_score(y_true, logits, average=None)

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(15, 3))

# left: per-note average precision; the legend shows its mean, ignoring NaNs
ax[0].plot(average_precision, label=f"AP {np.nanmean(average_precision):.1%}")
ax[0].legend(ncol=2)

# right: per-note accuracy, precision and recall from the confusion counts
ax[1].plot(accuracy, label="acc.")
ax[1].plot(precision, label="P")
ax[1].plot(recall, label="R")
ax[1].legend(ncol=3)

plt.show()

Pooled average precision (treating different outputs as the same).

In [None]:
# Flatten targets and logits so that all outputs are scored as one problem
average_precision_score(y_true.flat, logits.flat, average=None, pos_label=1)

Precision-recall curves

In [None]:
from sklearn.metrics import precision_recall_curve
from matplotlib.collections import LineCollection

fig, ax = plt.subplots(1, 1, figsize=(16, 7))

# one precision-recall curve per output column (columns of `y_true`/`logits`)
p, r, t = zip(*map(precision_recall_curve, y_true.T, logits.T))
# each curve becomes an array of (recall, precision) points, coloured along
# the PuBuGn colormap in column order
ax.add_collection(
    LineCollection([*map(np.transpose, map(np.stack, zip(r, p)))],
                   colors=plt.cm.PuBuGn(np.linspace(0, 1, num=len(p))),
                   alpha=0.7)
)

# pooled curve over all outputs, drawn in black on top
p, r, t = precision_recall_curve(y_true.flat, logits.flat)
ax.plot(r, p, c="k", lw=2)
plt.show()

In [None]:
# per-note average precision in percent, rounded to one decimal
average_precision.round(3)*100

In [None]:
# Heat map of the binary predictions for rows 3500-3999 of the validation
# output, transposed so that the output index runs along the vertical axis.
fig = plt.figure(figsize=(16, 5))
ax = fig.add_subplot(111)
ax.imshow(y_pred[3500:4000].T, cmap=plt.cm.hot)
# ax.imshow(y_true[3500:4000].T, cmap=plt.cm.hot)

In [None]:
# Deliberate stop: always raises AssertionError, so "Run All" halts here and
# the cells below are only executed by hand.
assert False

Let the scales $\beta \in \mathbb{R}^n$ and the centers $c \in \mathbb{R}^{n\times m}$ (with rows $c_j$) be parameters, then

$$
F
  \colon \mathbb{R}^m \to \mathbb{R}^n
  \colon x \mapsto \bigl( \beta_j \| x - c_j \|_2 \bigr)_{j=1}^n
    \,, $$

In [None]:
from torch.nn import Module, Parameter, init

class ClusteringLayer(Module):
    """Scaled p-norm distances from the input to a set of learnt centers.

    Parameters
    ----------
    in_features : int
        Size of the last dimension of the input.
    out_features : int
        Number of centers, i.e. the size of the last output dimension.
    p : default=2
        Order of the norm, passed on to the tensor norm.
    """
    def __init__(self, in_features, out_features, p=2):
        super().__init__()
        self.p = p

        # one center per output feature (rows) and one scale per center
        centers = torch.empty(out_features, in_features)
        scales = torch.empty(out_features)
        self.weight = Parameter(centers)
        self.scale = Parameter(scales)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the centers and the scales."""
        init.kaiming_normal_(self.weight)
        init.uniform_(self.scale, -0.01, +0.01)

    def forward(self, input):
        # broadcast (..., 1, in) against (out, in) to get (..., out, in)
        delta = input.unsqueeze(-2) - self.weight
        distance = delta.norm(p=self.p, dim=-1)
        return distance * self.scale

In [None]:
# Deliberate stop: always raises AssertionError, so "Run All" halts here and
# the cells below are only executed by hand.
assert False

<br>

In [None]:
from itertools import chain

class CompositeCriterion(object):
    """An ordered collection of `(loss, coefficient)` pairs.

    Parameters
    ----------
    *elements
        The initial elements, stored as given (`append` and the `+`
        operators store `(loss, coef)` tuples).
    """
    def __init__(self, *elements):
        self._elements = list(elements)

    def __repr__(self):
        return f"{type(self).__name__}({repr(self._elements)})"

    def __iter__(self):
        return iter(self._elements)

    def append(self, other, coef):
        """Add the loss `other` with coefficient `coef` in place.

        Raises NotImplementedError if `other` is not a `_Loss` instance.
        """
        if not isinstance(other, _Loss):
            raise NotImplementedError

        self._elements.append((other, coef))

    def __add__(self, other):
        """Return a NEW criterion; the operands are left untouched."""
        if isinstance(other, CompositeCriterion):
            return type(self)(*self, *other)

        elif isinstance(other, _Loss):
            # used to append to `self` and return it, so `a + loss`
            #  silently modified `a`
            return type(self)(*self, (other, 1.))

        return NotImplemented

    def __iadd__(self, other):
        """Extend this criterion in place (`a += ...`) and return it."""
        if isinstance(other, CompositeCriterion):
            self._elements.extend(other)
            return self

        elif isinstance(other, _Loss):
            self.append(other, 1.)
            return self

        return NotImplemented

    def __getitem__(self, index):
        # slices give a new criterion, plain indices the stored element
        if isinstance(index, slice):
            return type(self)(*self._elements[index])

        return self._elements[index]

In [None]:
# start from an empty composite criterion
a = CompositeCriterion()

In [None]:
# add a penalty through the operator interface
a += ARDPenalty()

In [None]:
# add a loss with an explicit coefficient
a.append(torch.nn.BCEWithLogitsLoss(), 1.)

In [None]:
# display the repr of the composite criterion
a

In [None]:
# Deliberate guard: always raises AssertionError with the message below, so
# the rest of this cell is never executed.
assert False, """Bad design and do not do this! A `list`, but not the list... what was i thinking?!"""

from collections.abc import Iterable

def as_iterable(a):
    """Return `a` unchanged if it is iterable, else the 1-tuple `(a,)`."""
    if isinstance(a, Iterable):
        return a

    return (a,)

class CompositeCriterion(list):
    """A `list` subclass with several list operations disabled.

    NOTE(review): the `assert False` at the top of this cell marks this
    design as abandoned; running the cell past it would rebind the name
    `CompositeCriterion`.
    """
    def __new__(cls, loss=()):
        # an existing instance is returned as is, not copied
        if isinstance(loss, CompositeCriterion):
            return loss
        
        return super().__new__(cls, loss)

    def __repr__(self):
        return f"{type(self).__name__}({super().__repr__()})"

    def __add__(self, other):
        # concatenation yields an instance of the same class
        return type(self)(list(self) + list(other))

    def __getitem__(self, index):
        ret = super().__getitem__(index)
        if isinstance(index, slice):
            return type(self)(ret)

        # for a plain index: the first entry of the stored item when it is
        #  iterable, otherwise the item itself
        return as_iterable(ret)[0]
    
    def __mul__(self, other):
        return NotImplemented
    
    # in-place repetition shares the same stub
    __imul__ = __mul__

    def __reversed__(self):
        return NotImplemented
    
    # NOTE(review): calling `reverse()` or `sort()` returns the
    #  `NotImplemented` object instead of raising -- confirm intent.
    reverse = __reversed__
    sort = __reversed__
