In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%cd ../..

In [None]:
from pathlib import Path
import os
import pprint

import numpy as np
import torch
import librosa
from torch.nn import functional as F
import matplotlib.pyplot as plt
import IPython.display as ipd

from wavenet import model, train, sample, audio, datasets, utils, viz, debug, distributed

In [None]:
# Shared pretty-printer used to dump the hyper-parameter dicts below
# (80 columns is the pprint default, spelled out for clarity).
pp = pprint.PrettyPrinter(indent=2, width=80)

# Train WaveNet on MAESTRO

In [None]:
# This cell contains the papermill-tagged parameters.
# They can be overridden from the CLI when training, e.g.:
# papermill in.ipynb out.ipynb -p batch_norm True

batch_norm = False  # passed to model.HParams(batch_norm=...)
learning_rate = 0.0044  # passed to train.HParams(learning_rate=...)
finder = False  # passed to train.HParams(finder=...); presumably toggles an LR-finder run
batch_size = 12  # passed to train.HParams(batch_size=...)
max_epochs = 2  # passed to train.HParams(max_epochs=...)
with_all_chans = None  # if truthy, applied via p.with_all_chans(...); None keeps the model defaults
sample_overlap_receptive_field = True  # passed to model.HParams(sample_overlap_receptive_field=...)
progress_bar = False  # passed to train.HParams(progress_bar=...)

In [None]:
# Model hyper-parameters: embed_inputs and squash_to_mono are enabled with a
# single audio channel; batch_norm and sample_overlap_receptive_field come
# from the papermill parameters cell.
p = model.HParams(
    embed_inputs=True,
    n_audio_chans=1,
    squash_to_mono=True,
    sample_overlap_receptive_field=sample_overlap_receptive_field,
    batch_norm=batch_norm,
)

# `with_all_chans` is an optional override; any falsy value (e.g. None)
# leaves `p` as constructed above.
p = p.with_all_chans(with_all_chans) if with_all_chans else p

pp.pprint(dict(p))

In [None]:
# Training hyper-parameters. Everything except `num_workers` comes from the
# papermill parameters cell.
tp = train.HParams(
    max_epochs=max_epochs,
    batch_size=batch_size,
    num_workers=8,  # NOTE(review): hard-coded; presumably data-loader workers, consider tying to the host
    finder=finder,
    progress_bar=progress_bar,
    learning_rate=learning_rate,
)

pp.pprint(dict(tp))

In [None]:
# Load the 2017 subset of MAESTRO v2.0.0 as train/test datasets.
# Both dataset roots can be overridden with environment variables so the
# notebook runs on hosts with a different mount layout. The defaults are
# the original paths. NOTE(review): ssd_path is presumably a faster local
# copy of the NAS data; confirm against `datasets.maestro`.
utils.seed(p)
nas_path = Path(os.environ.get('MAESTRO_NAS_PATH', '/srv/datasets/maestro/maestro-v2.0.0'))
ssd_path = Path(os.environ.get('MAESTRO_SSD_PATH', '/srv/datasets-ssd/maestro/maestro-v2.0.0'))
ds_train, ds_test = datasets.maestro(nas_path, p, ssd_path, year=2017)

In [None]:
# Build the WaveNet from `p` and print a summary. The reseed comes first,
# presumably so that weight initialisation is repeatable; keep it before
# construction.
utils.seed(p)
m = model.Wavenet(p)
debug.summarize(m)

In [None]:
# Wrap model, datasets and training hyper-parameters in the distributed (DDP)
# trainer; training is started by `t.train()` in a later cell.
t = distributed.DDP(m, ds_train, ds_test, tp)

In [None]:
# Plot a random training example and make it playable. `mu_expand`
# presumably inverts the dataset's mu-law companding back to a waveform.
track_i = viz.plot_random_track(ds_train)
track, *_ = ds_train[track_i]
waveform = audio.mu_expand(track.squeeze().numpy(), p)
ipd.Audio(waveform, rate=p.sampling_rate)

In [None]:
%%capture
utils.seed(p)
t.train()

In [None]:
# Sample from the model with the helper `utils.decode_nucleus()` decoder,
# reseeding first so the draw is repeatable.
# n_samples=32000 is presumably samples per track; batch_size=10 is
# presumably the number of tracks.
# NOTE(review): confirm that `t.train()` leaves trained weights in `m` in
# this process, since DDP training may run in separate worker processes.
utils.seed(p)
transforms = ds_train.transforms
decoder = utils.decode_nucleus()
tracks, logits, g = sample.fast(m, transforms, decoder, n_samples=32000, batch_size=10)

In [None]:
# Play back each generated track. Each track is passed through the
# dataset's `normalise` transform and then mu-law expanded.
# NOTE(review): the dataset preview cell skips `normalise`; presumably the
# sampler returns un-normalised values. Confirm.
# `.detach().cpu()` lets this work when sampling returns CUDA tensors or
# tensors attached to the autograd graph; both are no-ops for plain CPU tensors.
for track in tracks:
    track = ds_train.transforms.normalise(track.detach().cpu().numpy())
    track = audio.mu_expand(track, p)
    ipd.display(ipd.Audio(track, rate=p.sampling_rate))