In [None]:
import os
import copy
import json
import tqdm

import numpy as np
import torch

In [None]:
# When True, the manifest-generation loop below still builds every manifest
# but does not write the JSON files to disk.
dry_run = False

Load the common manifest template for MusicNet

In [None]:
import pkg_resources

# Base manifest shipped inside the `cplxpaper.musicnet` package; each grid
# point below starts from a deep copy of it and patches in its own settings.
with pkg_resources.resource_stream("cplxpaper.musicnet", "template.json") as template_file:
    template = json.load(template_file)

Functions to edit nested dictionaries

In [None]:
def get_params(self, deep=True, keepcontainers=True):
    """Flatten a nested dictionary into `'__'`-joined keys, depth first.

    Arguments
    ---------
    self : dict
        The dictionary to traverse and flatten.

    deep : bool, default=True
        Whether to descend into values that are themselves dicts. When
        `False`, only the top-level items are returned unchanged.

    keepcontainers : bool, default=True
        Whether each nested dict is also kept in the output under its own
        key (which makes the result redundant). Only relevant when `deep`
        is `True`.

    Returns
    -------
    out : dict
        Maps ``outer__inner`` style keys to values. For every nested dict,
        its flattened children are inserted before the container's own key.
    """
    flat = dict()
    for name, item in self.items():
        descend = deep and isinstance(item, dict)
        if descend:
            children = get_params(item, deep=True, keepcontainers=keepcontainers)
            for child_name, child_item in children.items():
                flat[name + '__' + child_name] = child_item

        # a container is dropped only when we descended into it and
        # `keepcontainers` is off; plain values are always kept
        if keepcontainers or not descend:
            flat[name] = item

    return flat

In [None]:
from collections import defaultdict

def set_params(self, **params):
    """Update a nested dictionary in place from `'__'`-joined keys.

    Details
    -------
    Adapted from scikit's BaseEstimator. A key such as ``a__b__c`` assigns
    ``self['a']['b']['c']``, recursing one level per ``'__'``. Every
    intermediate container must already exist in `self` (a missing one
    raises ``KeyError``). Plain keys are assigned first, then the nested
    groups are applied. Returns `self`.
    """
    if not params:
        return self

    # collect "head__tail" keys per head, preserving first-seen order
    grouped = {}
    for name, value in params.items():
        head, sep, tail = name.partition('__')
        if not sep:
            self[head] = value
            continue
        grouped.setdefault(head, {})[tail] = value

    for head, sub_params in grouped.items():
        set_params(self[head], **sub_params)

    return self

In [None]:
def special_params(**params):
    """Split keyword arguments into a pair (params, special).

    Details
    -------
    Special parameters are the keys that begin with '__'. They are returned
    in the second dict with that prefix stripped; all other keys are
    returned unchanged in the first dict. Both dicts keep the input order.
    """
    ordinary, special = {}, {}
    for key, value in params.items():
        if key.startswith("__"):
            special[key[2:]] = value
        else:
            ordinary[key] = value
    return ordinary, special

<br>

In [None]:
# Settings shared by every experiment grid. Every value is a list of
# candidates (the ParameterGrid convention); keys use '__' to address nested
# manifest fields, e.g. "stages__dense__n_epochs".
# NOTE(review): class references are spelled as `repr(cls)` strings
# ("<class '...'>"); presumably the experiment runner resolves them — confirm.
grid_default_settings = {
    "dataset_sources": [
        {
            "train-1": {
                "filename": os.path.abspath("./data/musicnet_11khz_train.h5"),
                "window": 4096,
                "stride": 1
            },
            "train-512": {
                "filename": os.path.abspath("./data/musicnet_11khz_train.h5"),
                "window": 4096,
                "stride": 512
            },
            "valid-128": {
                "filename": os.path.abspath("./data/musicnet_11khz_valid.h5"),
                "window": 4096,
                "stride": 128
            },
            "test-128": {
                "filename": os.path.abspath("./data/musicnet_11khz_test.h5"),
                "window": 4096,
                "stride": 128
            }
        },
    ],
    "feeds": [{
        "train_trabelsi": {
            "cls": "<class 'cplxpaper.musicnet.dataset.MusicNetDataLoader'>",
            "dataset": "train-1",
            "pin_memory": True,
            "n_batches": 1000
        },
        "valid_768": {
            "cls": "<class 'torch.utils.data.dataloader.DataLoader'>",
            "dataset": "valid-128",
            "batch_size": 768,
            "pin_memory": True,
            "shuffle": False,
            "n_batches": -1
        },
        "test_256": {
            "cls": "<class 'torch.utils.data.dataloader.DataLoader'>",
            "dataset": "test-128",
            "batch_size": 256,
            "pin_memory": True,
            "shuffle": False,
            "n_batches": -1
        }
    }],
    # dense
    "stages__dense__n_epochs": [200],
    "stages__dense__early": [{
        "feed": "valid_768",
        "patience": 10,
        "cooldown": 0,
        "rtol": 0,
        "atol": 2e-2,
        "raises": "<class 'StopIteration'>"
    }],
    "stages__dense__restart": [True],
    "stages__dense__reset": [True],

    # sparsify: continue from dense
    "stages__sparsify__n_epochs": [50],
    "stages__sparsify__restart": [False],
    # was "stages__dense__reset" (copy-paste slip): that duplicate key
    # silently overrode the dense stage's reset=True above and left the
    # sparsify stage's reset unset
    "stages__sparsify__reset": [False],
    "stages__sparsify__early": [None],

    # fine-tune: restart and reset with its own early stopping
    # (NOTE(review): presumably reuses the masks produced by sparsify — confirm)
    "stages__fine-tune__n_epochs": [200],
    "stages__fine-tune__restart": [True],
    "stages__fine-tune__reset": [True],
    "stages__fine-tune__early": [{
        "feed": "valid_768",
        "patience": 10,
        "cooldown": 0,
        "rtol": 0,
        "atol": 2e-2,
        "raises": "<class 'StopIteration'>"
    }]
}

<br>

## Experiment 1: Deep convolutional complex-valued network

In [None]:
# Experiment 1 grid: the shared defaults plus the axes swept here.
grid_trabelsi = dict(grid_default_settings)
grid_trabelsi.update({
    # input feature representation
    "features__kind": ["fft", "fft-shifted"],
    # presumably the weight of the KL-divergence term during sparsify — confirm
    "stages__sparsify__objective__kl_div": [1e-5, 5e-5, 1e-4, 5e-4],
    # replication index; it is used only to name the manifest file
    "n_replication": list(range(5)),
})

<br>

## Generate manifest JSONs

In [None]:
from sklearn.model_selection import ParameterGrid

grid = ParameterGrid([
    grid_trabelsi,
])

PATH = os.path.abspath(os.path.join(".", "runs", "grid_trabelsi"))

# exist_ok=False makes this raise FileExistsError if the directory is already
# there (presumably to avoid clobbering manifests of an earlier run). Creating
# it is skipped on a dry run: otherwise a dry run would leave the directory
# behind and make the subsequent real run fail.
if not dry_run:
    os.makedirs(PATH, exist_ok=False)

In [None]:
for i, par in enumerate(tqdm.tqdm(grid)):
    # split off "special" ('__'-prefixed) keys, which are not set via set_params
    param, recipe = special_params(**par)

    # (patch 20191224) in case frequency spec changes (very unlikely)
    frequency = recipe.pop("frequency", None)

    # update the manifest: patch a private deep copy of the shared template
    local = set_params(copy.deepcopy(template), **param)
    if frequency is not None:
        local["frequency"] = frequency

    # overrides
    local["device"] = None
    # the replication index is a grid axis only: drop it from the manifest
    # and encode it in the file name instead
    n_copy = local.pop("n_replication", None)

    # save
    if n_copy is not None:
        manifest = f"musicnet[{n_copy:03d}]-{i:03d}.json"
    else:
        manifest = f"musicnet-{i:03d}.json"

    experiment = os.path.join(PATH, manifest)
    if not dry_run:
        # context manager guarantees the file is flushed and closed here,
        # instead of whenever the anonymous handle gets garbage-collected
        with open(experiment, "w") as fout:
            json.dump(local, fout, indent=2)

In [None]:
# Deliberate stop: this cell always fails, so "Run All" halts here and nothing
# after it runs unattended (presumably the intent — confirm).
# NOTE(review): `assert` statements are stripped under `python -O`; an explicit
# `raise` would stop execution regardless of optimisation flags.
assert False

<br>