In [None]:
# Auto-reload imported modules before each cell runs, so edits to the project
# code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Move the working directory up two levels.
# NOTE(review): presumably to the repository root so `wavenet` is importable — confirm.
# The path is relative, so re-running this cell moves up two more levels each time.
%cd ../..

In [None]:
import os
import pprint

import numpy as np
import torch
import librosa
from torch.nn import functional as F
import matplotlib.pyplot as plt
import librosa.display
import IPython.display as ipd
from tqdm import tqdm
import wandb

from wavenet import model, train, sample, audio, datasets, utils, viz, debug

In [None]:
# Shared pretty-printer used to dump the hyper-parameter dicts in later cells.
pp = pprint.PrettyPrinter(indent=2)

# Train on Tiny

A full training run with the big model against the Tiny dataset. The model should crush it.

In [None]:
# This cell contains the papermill-tagged parameters.
# Any of them can be overridden from the CLI when training, e.g.:
#   papermill in.ipynb out.ipynb -p learning_rate 0.01
#
# Each parameter is assigned exactly once so the default that is actually in
# effect is unambiguous.

learning_rate = 0.04
max_epochs = 20
n_samples = 30
batch_size = 40

In [None]:
# `n_samples` is intentionally NOT reassigned here: it is defined in the
# papermill parameters cell, and assigning it again in this cell would silently
# discard a value supplied with `papermill ... -p n_samples N`.
n_examples = 5_000

# Model hyper-parameters. Both the sample length and the number of classes are
# derived from `n_samples`, so they follow the parameter when it is overridden.
p = model.HParams(
    mixed_precision=False,
    n_audio_chans=1,
    n_classes=2*n_samples,
    dilation_stacks=1,
    n_layers=6,
    compress=False,
    sample_length=n_samples,
    seed=133,
    embed_inputs=True,
    verbose=False,
    batch_norm=False
).with_all_chans(32)

pp.pprint(dict(p))

In [None]:
# Seed from the model hyper-parameters first, then build the two datasets with
# identical arguments (train split first, test split second).
utils.seed(p)
ds = datasets.Tiny(n_samples, n_examples)
ds_test = datasets.Tiny(n_samples, n_examples)

In [None]:
# Build the network from the hyper-parameters `p`, then pass it to the debug
# summary helper.
m = model.Wavenet(p)
debug.summarize(m)

In [None]:
# Training hyper-parameters: everything except `num_workers` comes from the
# papermill parameters cell.
tp = train.HParams(
    max_epochs=max_epochs,
    batch_size=batch_size,
    num_workers=1,
    learning_rate=learning_rate,
)

pp.pprint(dict(tp))

In [None]:
# Bundle the model, both datasets and the training hyper-parameters into a trainer.
t = train.Trainer(m, ds, ds_test, tp)
# Last expression of the cell, so the notebook displays the trainer's metrics object.
t.metrics

In [None]:
# Start training with the trainer configured above.
t.train()

## Sample

In [None]:
def plotit(generated, name):
    """Scatter-plot every generated sequence in its own panel of one figure.

    The panel grid is sized from the batch that is passed in (at most 8
    columns), so the function works for any batch size; a batch of 40 gives
    the original 5x8 layout. Panels left over in the last row are hidden.

    Args:
        generated: batch of generated sequences. Example ``i`` is read as
            ``generated[i, 0]`` and moved to the CPU with ``.cpu()``
            (assumes a torch tensor laid out as (batch, channel, time) —
            TODO confirm against the sampler's return value).
        name: title drawn above every panel.

    Returns:
        None. The figure is left open for the notebook to render.
    """
    # One random colour per call, so each sampler's figure is visually distinct.
    color = np.random.rand(3,)

    n_panels = len(generated)
    if n_panels == 0:
        return  # nothing to draw

    n_cols = min(8, n_panels)
    n_rows = -(-n_panels // n_cols)  # ceiling division

    # Create the whole grid up front. The previous pattern of
    # `plt.subplots(figsize=...)` followed by `plt.subplot(...)` also created
    # one full-size Axes, which recent matplotlib versions no longer remove
    # automatically when it is overlapped.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(30, 12), squeeze=False)

    for i, ax in enumerate(axes.ravel()):
        if i >= n_panels:
            ax.set_visible(False)  # unused cell in the last row
            continue
        series = generated[i, 0].cpu()
        ax.set_ylim(0, p.n_classes)
        ax.grid(color='lightgray')
        ax.set_title(name)
        # Take the x axis from the data itself rather than a notebook global.
        ax.plot(np.arange(len(series)), series, '.', color=color)


# Draw one batch from each sampler and plot it. The panel title is the
# sampler's function name; `str(f)` would render the full function repr,
# including its memory address.
for f in [sample.fast, sample.simple]:
    y, logits, *_ = f(m, ds.transforms, utils.decode_random, n_samples=n_samples, batch_size=batch_size)
    plotit(y, getattr(f, '__name__', str(f)))

In [None]:
# Finalize the trainer's metrics object now that training and sampling are done.
# NOTE(review): presumably this closes the wandb run — confirm in `train.Trainer`.
t.metrics.finish()