# Prototype of auto-experiment

In [None]:
import os
import re
import copy
import time
import tqdm
import json
import h5py
import importlib

import torch
import numpy as np

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
from auto import auto

<br>

In [None]:
folder = "./cplx_09"
# Create the experiment output folder portably (no shell magic); exist_ok
# lets the cell be re-run without an error when the folder already exists.
os.makedirs(folder, exist_ok=True)

In [None]:
import logging

# Send DEBUG-and-above records from the root logger to <folder>/main.log.
# The root logger also receives records propagated from library loggers; the
# commented "auto" name would scope this to that logger only.
logger = logging.getLogger()  # "auto"
logger.setLevel(logging.DEBUG)

# Re-running this cell would otherwise stack another FileHandler on the root
# logger (duplicating every record) and leave the previous file handle open.
# FileHandler stores an absolute baseFilename, so compare against that.
log_path = os.path.abspath(os.path.join(folder, "main.log"))
for handler in list(logger.handlers):
    if isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path:
        logger.removeHandler(handler)
        handler.close()

fh = logging.FileHandler(log_path, mode="w")
fh.setLevel(logging.DEBUG)

fh.setFormatter(logging.Formatter(
    "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
))
logger.addHandler(fh)

In [None]:
# Base experiment configuration. JSON is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform's locale encoding.
with open("sparsify_masked_weighted_lrsched_cplx.json", "r", encoding="utf-8") as fin:
    options = json.load(fin)

# Device to run this experiment on.
options['device'] = "cuda:3"

In [None]:
# Define the "dense" stage: CplxDeepConvNet trained with Adam under the
# Trabelsi2017 LR schedule for 200 epochs, objective weighted on the loss
# only (kl_div weight 0.0). Classes are referenced by their str() form.
options["stages"]["dense"] = {
    "snapshot": None,
    "feed": "train_trabelsi",
    "restart": True,
    "n_epochs": 200,
    "grad_clip": 0.,
    "model": {
        "cls": "<class 'musicnet.complex.base.CplxDeepConvNet'>",
    },
    "lr_scheduler": {
        "cls": "<class 'musicnet.trabelsi2017.base.Trabelsi2017LRSchedule'>",
    },
    "optimizer": {
        "cls": "<class 'torch.optim.adam.Adam'>",
        "lr": 0.001,
        "betas": [0.9, 0.999],
        "eps": 1e-08,
        "weight_decay": 0,
        "amsgrad": False,
    },
    "objective": {
        "loss": 1.0,
        "kl_div": 0.0,
    },
    "early": {
        "feed": "valid_256",
        "patience": 200,
        "cooldown": 0,
        "rtol": 0,
        "atol": 0.01,
        "raises": "<class 'StopIteration'>",
    },
}

In [None]:
# Stage sequence for this experiment (names refer to options["stages"]).
options["stage-order"] = ["dense", "sparsify", "fine-tune"]

In [None]:
# Optional tweak, disabled: drop the loss term's pos_weight.
# del options["objective_terms"]["loss"]["pos_weight"]

# Use the same absolute early-stopping tolerance for the sparsify and
# fine-tune stages (chained targets are assigned left to right).
options["stages"]["sparsify"]["early"]["atol"] = \
    options["stages"]["fine-tune"]["early"]["atol"] = 2e-2

In [None]:
# Display the fully assembled options dict as this cell's output.
options

In [None]:
# Launch the experiment with the options above, the output `folder`, and a
# timestamp string identifying this run.
# NOTE(review): the meaning of the trailing positional `True` is not visible
# here — confirm against auto.run's signature and pass it by keyword.
auto.run(options, folder, time.strftime("%Y%m%d-%H%M%S"), True)

In [None]:
# Display auto.defaults(options); presumably the options with defaults
# filled in — TODO confirm.
auto.defaults(options)

<br>

# Reload

In [None]:
from auto.utils import load_snapshot

# Load a saved snapshot and replace `options` with the options stored in it.
# NOTE(review): this reads from ./cplx_08, not from `folder` (./cplx_09)
# configured above — confirm the hard-coded path is intended.
cold = load_snapshot('./cplx_08/1-sparsify 20191220-204057.gz')

options = cold["options"]

In [None]:
# placement (dtype / device)
options["device"] = "cuda:3"
devtype = dict(device=torch.device(options["device"]), dtype=torch.float32)

# sparsity settings: threshold is log(p / (1 - p)) for p=dropout rate
# NOTE(review): this `sparsity` dict is not used in the visible cells and is
# rebound later by `from cplxmodule.utils.stats import sparsity`.
sparsity = dict(hard=True, threshold=options["threshold"])

# Rebuild datasets, the batch collate function and the data feeds from the
# snapshot's options, placing batches according to `devtype`.
datasets = auto.get_datasets(options["dataset"], options["dataset_sources"])
collate_fn = auto.get_collate_fn(options["features"])
feeds = auto.get_feeds(datasets, collate_fn, devtype, options["feeds"])

In [None]:
# Recreate the training state for the snapshot's stage and restore the saved
# model and optimizer state dicts into it.
name, settings = cold["stage"]
state = auto.state_create(options["model"], settings, devtype)

state.model.load_state_dict(cold["model"])
state.optim.load_state_dict(cold["optim"]["state"])

In [None]:
# Build the configured objective terms for these datasets, combine them
# according to the snapshot stage's `objective` formula, and move the result
# to the configured device/dtype.
objective_terms = auto.get_objective_terms(datasets, options["objective_terms"])

formula = settings["objective"]
objective = auto.get_objective(objective_terms, formula).to(**devtype)

<br>

In [None]:
from cplxmodule.utils.stats import sparsity, named_sparsity

# Overall model sparsity at the snapshot's threshold, printed as a percentage.
print(">>> {:6.1%}".format(sparsity(state.model, threshold=options['threshold'])))

In [None]:
# Per-layer sparsity as a {name: value} dict, at a hard-coded threshold.
# NOTE(review): -0.5 differs from options['threshold'] used above — confirm.
dict(named_sparsity(state.model, threshold=-0.5))

In [None]:
from scipy import stats
from auto.objective import named_ard_modules
from ipywidgets import widgets

# Collect every ARD module's `log_alpha` as a CPU numpy array, keyed by the
# name yielded by named_ard_modules. A comprehension keeps the loop variables
# out of the notebook namespace, so `name` from the reload cell survives.
# .detach() is still needed: a CPU parameter's .cpu() returns the parameter
# itself, and .numpy() refuses tensors that require grad.
with torch.no_grad():
    log_alphas = {
        name: submod.log_alpha.detach().cpu().numpy()
        for name, submod in named_ard_modules(state.model)
    }


def darker(color, a=0.5):
    """Scale the HLS lightness of ``color`` by ``a`` and return an RGB tuple.

    The colour is converted to HLS, its lightness is multiplied by ``a`` and
    clamped to ``[0, 1]``, and the result is converted back to RGB. Values of
    ``a`` below 1 darken the colour; values above 1 lighten it.

    Adapted from this stackoverflow question_.

    .. _question: https://stackoverflow.com/questions/37765197/
    """
    from colorsys import hls_to_rgb, rgb_to_hls
    from matplotlib.colors import to_rgb

    hue, lightness, saturation = rgb_to_hls(*to_rgb(color))
    scaled = max(0, min(a * lightness, 1))
    return hls_to_rgb(hue, scaled, saturation)

In [None]:
if log_alphas:
    w_keys = widgets.Dropdown(options=[None, *log_alphas], description="Layer")

    @widgets.interact(layer=w_keys)
    def plot_hists(layer):
        """Overlay per-layer histograms of ``log_alpha``, highlighting one.

        Every layer in ``log_alphas`` is drawn as a faint outline, except the
        selected ``layer``, which is drawn filled with a Gaussian KDE curve on
        top. A vertical black line marks ``options["threshold"]``, the
        threshold used for the sparsity figures above.
        """
        colors = plt.cm.jet(np.linspace(0, 1, num=len(log_alphas)))

        _, ax = plt.subplots(1, 1, figsize=(16, 5))
        support = np.linspace(-15, 40, num=265)
        for (name, log_alpha), col in zip(log_alphas.items(), colors):
            values = log_alpha.ravel()
            if name != layer:
                extra = dict(histtype="step", lw=1, zorder=10, alpha=0.25)
            else:
                extra = dict(histtype="bar", lw=0, alpha=1., zorder=-10)

            *_, patches = ax.hist(values, label=name, bins=51,
                                  density=True, **extra, color=col)
            if name == layer:
                # Subsample large layers so fitting/evaluating the KDE stays fast.
                subsample = values
                if len(subsample) > 50000:
                    subsample = np.random.choice(subsample, replace=False, size=50000)
                # `stats.kde.gaussian_kde` is a deprecated module path.
                density = stats.gaussian_kde(subsample)

                color = darker(patches[0].get_facecolor(), 0.75)
                ax.plot(support, density(support), c=color, lw=1, zorder=10)

        # The bare name `threshold` is never defined in this notebook (it
        # raised NameError); use the snapshot's configured threshold instead.
        ax.axvline(options["threshold"], c="k")
        ax.legend(ncol=2, loc='upper right')
        ax.set_ylim(0, 0.5)
        ax.set_xlim(-15, 40)
        plt.show()

<br>

## Conclusions from slowdown investigation
* devtype: adds 40 GB
* the dataset length is the culprit that causes the slowdown!
    * `torch.utils.data.sampler.RandomSampler` allocates a HUGE amount of RAM for `randint` or `randperm` over 1G sample indices
        * plus deallocation!
* pinned memory is moderately faster than unpinned
* moving to device is very slow

* Create a custom random sampler that preallocates the sample schedule. The dataset is very large, so we can treat it as an almost infinite stream.

In [None]:
# Add a stride-1 test split ("test-1") to the datasets built from the
# snapshot; the commented entries are alternative splits kept for toggling.
datasets.update(auto.get_datasets(options["dataset"], {
#     "train-1": {"filename": "./data/musicnet/musicnet_11khz_train.h5", "stride": 1}
    "test-1": {"filename": "./data/musicnet/musicnet_11khz_test.h5", "stride": 1}
#     "test-32": {"filename": "./data/musicnet/musicnet_11khz_test.h5", "stride": 32}
}))

In [None]:
from musicnet.dataset import MusicNetDataLoader

# Register a "test_trabelsi" feed over the "test-1" dataset. `cls` is given as
# str(class), i.e. the "<class '...'>" form used throughout the options.
# NOTE(review): n_batches=1000 presumably caps the batches drawn — confirm.
feeds.update(auto.get_feeds(datasets, collate_fn, devtype, {
#     "train-512": {'dataset': 'train-512', 'batch_size': 128, 'shuffle': False},
    "test_trabelsi": {'cls': str(MusicNetDataLoader), 'dataset': 'test-1', "n_batches": 1000},
#     "test-32": {'dataset': 'test-32', 'batch_size': 512, 'shuffle': False, "pin_memory": True}
}))

In [None]:
# Display all registered feeds.
feeds

In [None]:
import warnings

# Evaluate the model on the "test_trabelsi" feed with a tqdm progress bar.
# With record=True, warnings raised inside the block are appended to a list
# (discarded here) instead of being printed, so no filter needs to be set.
with warnings.catch_warnings(record=True):  # no need to filter
#     feed = auto.wrap_feed(feeds["train-512"], max_iter=-1, **devtype)
#     feed = auto.wrap_feed(feeds["test_256"], max_iter=-1, **devtype)
    feed = feeds["test_trabelsi"]
#     feed = auto.wrap_feed(feeds["test-32"], max_iter=10000, **devtype)
    # curves=True presumably requests the PR curves read back from
    # out["ap_curves"] below — TODO confirm against auto.evaluate.
    out = auto.evaluate(state.model, tqdm.tqdm(feed), curves=True)

In [None]:
# Pooled average precision on the feed evaluated above.
out["pooled_average_precision"]

In [None]:
# Reference value: pooled AP stored in the snapshot for the 'test_256' feed.
# NOTE(review): that is a different feed from 'test_trabelsi' evaluated above,
# so the two numbers are not directly comparable.
cold["performance"]['test_256']['pooled_average_precision']

In [None]:
# Left: out["average_precision"] per index, with its NaN-ignoring mean shown
# in the legend. Right: the accuracy, precision and recall series from `out`.
fig, ax = plt.subplots(1, 2, figsize=(15, 3))

ax[0].plot(out["average_precision"], label=f"AP {np.nanmean(out['average_precision']):.1%}")
ax[0].legend(ncol=2)

ax[1].plot(out["accuracy"], label="acc.")
ax[1].plot(out["precision"], label="P")
ax[1].plot(out["recall"], label="R")
ax[1].legend(ncol=3)

plt.show()

In [None]:
# NOTE(review): precision_recall_curve is imported but not used in this cell.
from sklearn.metrics import precision_recall_curve
from matplotlib.collections import LineCollection

# Precision-recall curves: one coloured line per non-pooled entry of
# out["ap_curves"], with the pooled curve drawn on top in black.
fig, ax = plt.subplots(1, 1, figsize=(16, 7))

# Each entry unpacks as (precision, recall, thresholds); regroup them into the
# parallel tuples p, r, t. NOTE(review): if "pooled" is the only key, the
# unpacking raises ValueError because there is nothing to zip.
p, r, t = zip(*[prt for k, prt in out["ap_curves"].items() if k != 'pooled'])
# np.stack((r_i, p_i)) followed by a transpose yields one (n_points, 2) array
# of (recall, precision) vertices per entry — assuming r_i and p_i are 1-D
# arrays of equal length.
ax.add_collection(
    LineCollection([*map(np.transpose, map(np.stack, zip(r, p)))],
                   colors=plt.cm.PuBuGn(np.linspace(0, 1, num=len(p))),
                   alpha=0.7)
)

p, r, t = out["ap_curves"]["pooled"]
ax.plot(r, p, c="k", lw=2)
plt.show()

<br>