In [None]:
import os
import torch
from torch import nn, optim
import numpy as np
import matplotlib.pyplot as plt

from wavenet.wn.audiodata import AudioData, AudioLoader
from wavenet.wn.models import Model, Generator
from wavenet.wn.utils import list_files

%matplotlib inline

# ---- Configuration: every value a reader might want to tune lives here ----

# Data / model shape
x_len = 2**10        # length (in audio samples) of each input window fed to the model
num_classes = 256    # number of output classes (passed to AudioData and Model; used by CrossEntropyLoss)
num_layers = 8       # layers per block, used by Model and by get_ylen's receptive-field formula
num_blocks = 2       # number of repeated blocks, used by Model and get_ylen
num_hidden = 32      # hidden channel count passed to Model
kernel_size = 2      # convolution kernel size passed to Model and get_ylen

# Optimisation
learn_rate = 0.001   # Adam learning rate
step_size = 50       # StepLR: decay the learning rate every `step_size` scheduler steps ...
gamma = 0.8          # ... by multiplying it with `gamma`
batch_size = 8
num_workers = 1      # worker count passed to AudioLoader
num_epochs = 5

# I/O, logging and generation
model_file = 'model.pt'  # weights are loaded from here if present, otherwise saved here after training
use_visdom = True        # passed to Model.train; presumably needs a running visdom server — TODO confirm
n_new_samples = 1000     # number of samples to generate after the seed window
disp_interval = 1        # display interval passed to Model.train
device = torch.device('cpu')

# Training and generation are stochastic; fix the seeds so that
# Restart-Kernel -> Run-All reproduces the same model and outputs.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

## Create dataset and dataloader

In [None]:
# Collect the training audio files. Fail early with a clear message when the
# directory yields nothing, instead of an obscure error later in AudioData/training.
filelist = list_files('./audio')
if not filelist:
    raise FileNotFoundError('No audio files found in ./audio')

def get_ylen(x_len, num_layers, num_blocks, kernel_size):
    """Return the target length for an input window of ``x_len`` samples.

    The receptive field is computed as
    ``1 + (kernel_size - 1) * num_blocks * sum(2**i for i in range(num_layers))``,
    i.e. the formula assumes dilations 1, 2, 4, ... repeated once per block.
    The result is ``x_len`` minus that receptive field.
    """
    dilation_total = sum(2 ** layer for layer in range(num_layers))
    receptive_field = 1 + (kernel_size - 1) * num_blocks * dilation_total
    return x_len - receptive_field

# Target length per training window: the input window minus the receptive field.
y_len = get_ylen(x_len, num_layers, num_blocks, kernel_size)
print('y_len: {}'.format(y_len))
# A non-positive y_len means the receptive field is at least as long as the
# input window, which is presumably invalid for AudioData. Stop here with a
# clear message rather than failing somewhere inside the dataset or model.
if y_len <= 0:
    raise ValueError(
        'x_len={} is too short for a receptive field of {} samples; '
        'increase x_len or reduce num_layers/num_blocks.'.format(
            x_len, x_len - y_len))

dataset = AudioData(filelist, x_len, y_len=y_len,
                    num_classes=num_classes, store_tracks=True)
dataloader = AudioLoader(dataset, batch_size=batch_size,
                         num_workers=num_workers)

## Define the model, then load saved weights or train it

In [None]:
# Build the WaveNet model from the hyper-parameters defined in the config cell
# (single input channel).
wave_model = Model(
    x_len,
    num_channels=1,
    num_classes=num_classes,
    num_blocks=num_blocks,
    num_layers=num_layers,
    num_hidden=num_hidden,
    kernel_size=kernel_size,
)

In [None]:
# Move the model to the configured device, then either restore previously
# saved weights from `model_file` or train from scratch and save the result.
wave_model.set_device(device)
if os.path.isfile(model_file):
    print('Loading model data from file: {}'.format(model_file))
    # map_location puts the loaded tensors on `device`. Without it, torch.load
    # restores them to the device they were saved from, which fails when
    # loading GPU-saved weights on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file — only load trusted files.
    wave_model.load_state_dict(torch.load(model_file, map_location=device))
else:
    print('Model data not found: {}'.format(model_file))
    print('Training new model.')
    # Loss, optimizer and LR schedule are attached as attributes, presumably
    # for Model.train to use — TODO confirm against wavenet.wn.models.
    wave_model.criterion = nn.CrossEntropyLoss()
    wave_model.optimizer = optim.Adam(wave_model.parameters(),
                                      lr=learn_rate)
    wave_model.scheduler = optim.lr_scheduler.StepLR(
        wave_model.optimizer, step_size=step_size, gamma=gamma)

    wave_model.train(dataloader, num_epochs=num_epochs,
                     disp_interval=disp_interval,
                     use_visdom=use_visdom)

    print('Saving model data to file: {}'.format(model_file))
    torch.save(wave_model.state_dict(), model_file)

## Generate new samples from a seed window

In [None]:
# Seed the generator with the first `x_len` samples of the first stored track.
# The following `n_new_samples` samples (if present) are kept as ground truth
# for the comparison plot.
wave_generator = Generator(wave_model, dataset)

n_total_samples = x_len + n_new_samples
audio = dataset.tracks[0]['audio'][:n_total_samples]
sample_rate = dataset.tracks[0]['sample_rate']
# A track shorter than x_len would silently produce a truncated seed window.
if len(audio) < x_len:
    raise ValueError(
        'First track has only {} samples; at least x_len={} are needed '
        'to seed the generator.'.format(len(audio), x_len))
x = audio[:x_len]
n_predictions = n_total_samples - x_len

In [None]:
# Generate `n_predictions` new samples continuing the seed window `x`.
# `gen_disp_interval` is passed as Generator.run's display interval
# (presumably a progress-report period — TODO confirm).
gen_disp_interval = 100
print('Predicting {} samples'.format(n_predictions))
y = wave_generator.run(x, n_predictions, disp_interval=gen_disp_interval)

In [None]:
# Overlay the original track (blue) and the generated continuation (red).
# Time of sample i is i / sample_rate. The original used i * sample_rate,
# which gives a meaningless x-axis.
# NOTE(review): assumes sample_rate is in samples per second (Hz).
enc = dataset.encoder
reencoded_audio = enc.expand(enc.normalize(audio, span='minmax'))

# Build the time axis long enough for both curves, so a track shorter than
# n_total_samples no longer causes a length mismatch in plot().
n_plot = max(len(reencoded_audio), x_len + len(y))
idxs = np.arange(n_plot) / sample_rate

fig, ax = plt.subplots()
ax.plot(idxs[:len(reencoded_audio)], reencoded_audio, 'b', label='original')
ax.plot(idxs[x_len:x_len + len(y)], y, 'r', label='generated')
ax.set(xlabel='time [s]', ylabel='amplitude',
       title='Original audio vs. generated continuation')
ax.legend();