# MusicNet: compression experiments

Generate the experiment manifests for the compression vs. average-precision curve
of the complex-valued model from Trabelsi et al. (2017).

In [None]:
import os
import copy
import json
import tqdm

import numpy as np
import torch

In [None]:
# When True, the grid is still expanded but no manifest JSON files are written.
dry_run: bool = False

Functions to edit nested dictionaries

In [None]:
from cplxpaper.auto.parameter_grid import special_params

from cplxpaper.auto.parameter_grid import get_params, set_params

Load the common manifest template for MusicNet.

In [None]:
import pkg_resources

# Read the shared manifest template shipped inside the `cplxpaper.musicnet` package.
with pkg_resources.resource_stream("cplxpaper.musicnet", "template.json") as stream:
    template = json.loads(stream.read())

# Drop the `musicnet-train-512` dataset entry from the template
# (raises KeyError if the template does not define it).
template["datasets"].pop("musicnet-train-512")

In [None]:
# Early-stopping spec (class path plus keyword arguments) reused below by the
# `dense` and `fine-tune` stages.
early = {
    'cls': "<class 'cplxpaper.musicnet.performance.PooledAveragePrecisionEarlyStopper'>",
    'scorer': 'validation',
    'patience': 20,
    'cooldown': 0,
    'rtol': 0,
    'atol': 0.02,
    'raises': "<class 'StopIteration'>"
}

# Overrides applied to every experiment; individual grids may override them again.
template = set_params(template, **{
    "feeds__train__pin_memory": False,

    # epoch schedule 50-75-50 (dense / sparsify / fine-tune)
    # dense:
    "stages__dense__n_epochs": 50,
    "stages__dense__reset": False,
    "stages__dense__early": early,

    # sparsify: continue from dense
    "stages__sparsify__n_epochs": 75,
    "stages__sparsify__restart": False,
    # (fix) this group previously set `stages__dense__reset` by mistake,
    # leaving the sparsify stage's own `reset` untouched
    "stages__sparsify__reset": False,
    "stages__sparsify__early": None,

    # fine-tune: use vi means and masks from sparsify
    "stages__fine-tune__n_epochs": 50,
    "stages__fine-tune__restart": True,
    "stages__fine-tune__reset": False,
    "stages__fine-tune__early": early,

    # Use Bayes-consistent weights
    "objective_terms__kl_div__reduction": "sum",
    "objective_terms__kl_div__coef": 1e-5,  # corresponds to the number of samples (approximate)
                                            #  1k batches x 321 windows

    # shifted complex fft
    "features__signal_ndim": 1,
    "features__cplx": True,
    "features__shift": True,

    # paths
    "datasets__musicnet-train__filename":
        os.path.abspath("./data/musicnet_11khz_train.h5"),
    "datasets__musicnet-valid-128__filename":
        os.path.abspath("./data/musicnet_11khz_valid.h5"),
    "datasets__musicnet-test-128__filename":
        os.path.abspath("./data/musicnet_11khz_test.h5"),

    # scorer
    "scorers__validation__threshold": -0.5,
    "scorers__test__threshold": -0.5,
})

<br>

## Experiment 1: Deep convolutional complex-valued network

In [None]:
# Coarse grid of KL-divergence coefficients: {0.25, 0.5, 0.75} x 10^k for
# k = -4..0, plus the plain powers of ten 10^k for k = -5..0 (sorted, unique).
kl_div = np.unique(np.concatenate([
    *(factor * np.logspace(-4, 0, num=5) for factor in (0.25, 0.50, 0.75)),
    1.00 * np.logspace(-5, 0, num=6),
]))

len(kl_div), kl_div

In [None]:
# Experiment 1 grid: every coarse KL coefficient, five replicates each.
grid_cplx_fine_kl_div = dict(
    stages__sparsify__objective__kl_div=kl_div,
    n_replication=list(range(5)),
)

<br>

## Experiment 2: Deep convolutional complex-valued network

A finer grid of KL-divergence coefficients around the peak of the validation scores.

In [None]:
# Finer grid on [2.5e-3, 2.5e-2]; keep only the values that Experiment 1
# (`kl_div`) has not already covered. The result is sorted and unique.
kl_div_v2 = np.setdiff1d(
    np.concatenate([
        np.linspace(2.5e-3, 1e-2, num=7),
        np.linspace(1e-2, 2.5e-2, num=7),
    ]),
    kl_div,
)

len(kl_div_v2), kl_div_v2

In [None]:
# Experiment 2 grid: the refined KL coefficients, five replicates each.
grid_cplx_fine_kl_div_v2 = dict(
    stages__sparsify__objective__kl_div=kl_div_v2,
    n_replication=list(range(5)),
)

<br>

## Experiment 3: fast experiment for VD and ARD

Fewer epochs in the `dense` stage and a shorter `sparsify` stage.

In [None]:
# Reduced KL grid for the fast runs: {0.25, 0.5, 0.75} x 10^k for k = -3..-1,
# plus the plain powers of ten 10^k for k = -4..-1 (sorted, unique).
kl_div_v3 = np.unique(np.concatenate([
    *(factor * np.logspace(-3, -1, num=3) for factor in (0.25, 0.50, 0.75)),
    1.00 * np.logspace(-4, -1, num=4),
]))

len(kl_div_v3), kl_div_v3

In [None]:
# Experiment 3 grid: shortened schedule with fast lr annealing, sweeping both
# the VD and the ARD sparsification models over `kl_div_v3`, five replicates each.
grid_cplx_fine_kl_div_v3_fast = {
    # schedule 12-32-50 epochs (dense / sparsify / fine-tune)
    "stages__dense__n_epochs": [12],
    "stages__sparsify__n_epochs": [32],
    "stages__fine-tune__n_epochs": [50],

    # keep early stopper at fine-tune stage only: `stages__fine-tune__early`
    # is deliberately not overridden, so the template's stopper stays in place
    "stages__dense__early": [None],
    "stages__sparsify__early": [None],

    # use faster lr annealing schedule
    "stages__dense__lr_scheduler": [{
        "cls": "<class 'cplxpaper.musicnet.lr_scheduler.FastStepScheduler'>"
    }],
    "stages__sparsify__lr_scheduler": [{
        "cls": "<class 'cplxpaper.musicnet.lr_scheduler.FastStepScheduler'>"
    }],
    "stages__fine-tune__lr_scheduler": [{
        "cls": "<class 'cplxpaper.musicnet.lr_scheduler.FastStepScheduler'>"
    }],

    # select the model
    "model__legacy": [
        False,  # Set to false to correct the sad kernel size mistake
                # detected on that fateful day of the 28th of January, 2020.
    ],
    "stages__dense__model__cls": [
        "<class 'cplxpaper.musicnet.models.complex.DeepConvNet'>",
    ],
    "stages__sparsify__model__cls": [
        "<class 'cplxpaper.musicnet.models.complex.DeepConvNetVD'>",
        "<class 'cplxpaper.musicnet.models.complex.DeepConvNetARD'>",
    ],
    "stages__fine-tune__model__cls": [
        "<class 'cplxpaper.musicnet.models.complex.DeepConvNetMasked'>",
    ],

    "stages__sparsify__objective__kl_div": kl_div_v3,
    "n_replication": [*range(5)]
}

<br>

## Generate manifest JSONs

In [None]:
from sklearn.model_selection import ParameterGrid

# Grids to expand into manifests; earlier experiments are kept commented out.
grid = ParameterGrid([
#     grid_cplx_fine_kl_div,
#     grid_cplx_fine_kl_div_v2,
    grid_cplx_fine_kl_div_v3_fast,
])

# Output folder for the manifests; its name should match the active grid above.
# NOTE: leave exactly one name uncommented -- adjacent string literals would be
# silently concatenated into a single folder name.
PATH = os.path.abspath(os.path.join(
    ".", "runs",
#     "grid_cplx_fine_kl_div"
#     "grid_cplx_fine_kl_div_v2"
    "grid_cplx_fine_kl_div_v3_fast"
))

# exist_ok=False refuses to reuse an existing folder, so old and new manifests
# are never mixed. The folder is not created during a dry run: an empty leftover
# folder would make the following real run fail with FileExistsError.
if not dry_run:
    os.makedirs(PATH, exist_ok=False)

In [None]:
# Expand the grid: write one manifest JSON per parameter combination into PATH.
for i, par in enumerate(tqdm.tqdm(grid)):
    # `special_params` splits the grid point into template overrides (`param`,
    # passed to `set_params` below) and extra entries (`recipe`), of which only
    # "frequency" is used here.
    param, recipe = special_params(**par)

    # (patch 20191224) in case frequency spec changes (very unlikely)
    frequency = recipe.pop("frequency", None)

    # update a private deep copy of the template so iterations share no state
    local = set_params(copy.deepcopy(template), **param)
    if frequency is not None:
        local["frequency"] = frequency

    # overrides: clear the device; the replica index is removed from the
    # manifest and only used in the file name
    local["device"] = None
    n_copy = local.pop("n_replication", None)

    # save
    if n_copy is not None:
        manifest = f"musicnet[{n_copy:03d}]-{i:03d}.json"
    else:
        manifest = f"musicnet-{i:03d}.json"

    experiment = os.path.join(PATH, manifest)
    if not dry_run:
        # context manager guarantees the file is flushed and closed immediately
        with open(experiment, "w") as fout:
            json.dump(local, fout, indent=2)

In [None]:
# Deliberately halt execution here (e.g. when using "Run All"). An explicit
# raise is used because `assert` statements are stripped under `python -O`.
raise RuntimeError("execution halted deliberately")

<br>