In [None]:
# Notebook setup: enable the autoreload extension in mode 2 (re-import
# modules before executing each cell) and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os
import pprint

import numpy as np
import torch
import librosa
from torch.nn import functional as F
import matplotlib.pyplot as plt
import IPython.display as ipd

from wavenet import model, train, sample, audio, datasets, utils, viz, debug

In [None]:
# Shared pretty-printer used by later cells to dump hyperparameter dicts.
PP_INDENT = 2
pp = pprint.PrettyPrinter(indent=PP_INDENT)

In [None]:
# This cell contains the papermill-tagged parameters.
# They can be overridden from the CLI when training:
# papermill in.ipynb out.ipynb -p batch_norm True

batch_norm = True  # forwarded to model.HParams(batch_norm=...)
learning_rate = 0.0026  # forwarded to train.HParams(learning_rate=...)
# NOTE(review): `finder` is not referenced by any other cell shown in this
# notebook; presumably a learning-rate-finder toggle — confirm or remove.
finder = False
batch_size = 8  # forwarded to train.HParams(batch_size=...)
max_epochs = 50  # forwarded to train.HParams(max_epochs=...)

# Train on a single Track

Run this notebook on a GPU. The goal is to overfit the model to a single track.

In [None]:
# Model hyperparameters. Only `batch_norm` comes from the papermill
# parameters cell; the other settings are fixed for this experiment.
model_settings = {
    'embed_inputs': True,
    'n_audio_chans': 1,
    'squash_to_mono': True,
    'batch_norm': batch_norm,
}
p = model.HParams(**model_settings).with_all_chans(256)

# Dump the resulting hyperparameters for the run log.
pp.pprint(dict(p))

In [None]:
# Training hyperparameters. Everything except `num_workers` is taken from the
# papermill parameters cell.
train_settings = dict(
    max_epochs=max_epochs,
    batch_size=batch_size,
    num_workers=8,
    learning_rate=learning_rate,
)
tp = train.HParams(**train_settings)

# Dump the resulting hyperparameters for the run log.
pp.pprint(dict(tp))

In [None]:
# utils.seed(p) runs before the dataset is built (presumably seeding the RNGs
# from the model hyperparameters — confirm what it covers).
utils.seed(p)
ds = datasets.Track('fixtures/goldberg/aria.wav', p)

# Cell output: the dataset, its length, and the number of training steps
# tp.n_steps reports for that length.
n_examples = len(ds)
ds, n_examples, tp.n_steps(n_examples)

In [None]:
# Re-seed immediately before constructing the network (presumably so that
# weight initialisation is reproducible — confirm what utils.seed covers).
utils.seed(p)
m = model.Wavenet(p)
# Last expression of the cell: displays whatever debug.summarize returns for m.
debug.summarize(m)

In [None]:
# Build the trainer from the model, the single-track dataset and the training
# hyperparameters.
# NOTE(review): the 3rd and 5th positional arguments are None and their
# meaning is not visible here — confirm against train.Trainer and consider
# passing them by keyword.
t = train.Trainer(m, ds, None, tp, None)
# Cell output: the trainer's metrics object.
t.metrics

In [None]:
# Plot a randomly chosen dataset item; the returned index is used to fetch
# that same item for audio playback.
track_i = viz.plot_random_track(ds)
track, *_ = ds[track_i]

# Expand the item with audio.mu_expand and render a player as the cell output.
preview_waveform = audio.mu_expand(track.numpy(), p)
ipd.Audio(preview_waveform, rate=p.sampling_rate)

In [None]:
%%capture
# The %%capture cell magic above must stay on the first line of the cell; it
# captures this cell's output so the training run does not flood the notebook.
# Re-seed right before training so the run starts from a known RNG state
# (presumably — confirm what utils.seed covers).
utils.seed(p)
t.train()

In [None]:
# Re-seed, then draw samples from the trained model.
# `logits` and `g` are not used by any later cell shown in this notebook.
utils.seed(p)
tracks, logits, g = sample.fast(
    m,
    ds.transforms,
    utils.decode_nucleus(),
    n_samples=32000,
    batch_size=10,
)

In [None]:
# Render an audio player for every sampled track.
# Each stage gets its own name so the loop neither rebinds one name from a
# tensor to an array nor clobbers the `track` bound by the preview cell.
for sampled in tracks:
    # .detach().cpu() are no-ops for plain CPU tensors; without them .numpy()
    # raises when the tensor lives on the GPU or requires grad.
    # Assumes sample.fast returns torch tensors — TODO confirm.
    sampled_np = sampled.detach().cpu().numpy()
    normalised = ds.transforms.normalise(sampled_np)
    expanded = audio.mu_expand(normalised, p)
    ipd.display(ipd.Audio(expanded, rate=p.sampling_rate))

In [None]:
# Finalise the trainer's metrics object once training and sampling are done.
# NOTE(review): the message suggests t.metrics wraps a wandb run — confirm.
# Under "Run All" / papermill this cell is skipped if an earlier cell raises,
# so the run may be left open after a failure.
print('closing wandb')
t.metrics.finish()