Skip to content

Commit

Permalink
Merge pull request #2 from monthly-hack/1zk-patch-1
Browse files Browse the repository at this point in the history
1zk patch 1
  • Loading branch information
610mto committed Nov 6, 2016
2 parents fd97ec8 + 8b473b3 commit 3c35157
Show file tree
Hide file tree
Showing 4 changed files with 280 additions and 6 deletions.
6 changes: 0 additions & 6 deletions README.md
Expand Up @@ -7,9 +7,3 @@ A conditional model is not implemented yet.
Python 3

Chainer

Monoral μ-law AIFF file (as a dataset)

## Usage

under construction
133 changes: 133 additions & 0 deletions model.py
@@ -0,0 +1,133 @@
import chainer.functions as F
import chainer.links as L
from chainer import Chain
from chainer import reporter


class WaveNet(Chain):

''' Implements the WaveNet network for generative audio.
Usage (with the architecture as in the DeepMind paper):
dilations = [2**i for i in range(10)] * 3
residual_channels = 16 # Not specified in the paper.
dilation_channels = 32 # Not specified in the paper.
skip_channels = 16 # Not specified in the paper.
model = WaveNet(dilations, residual_channels, dilation_channels, skip_channels,
quantization_channels)
'''

def __init__(self, dilations,
residual_channels=16,
dilation_channels=32,
skip_channels=128,
quantization_channels=256):
'''
Args:
dilations (list of int):
A list with the dilation factor for each layer.
residual_channels (int):
How many filters to learn for the residual.
dilation_channels (int):
How many filters to learn for the dilated convolution.
skip_channels (int):
How many filters to learn that contribute to the quantized softmax output.
quantization_channels (int):
How many amplitude values to use for audio quantization and the corresponding
one-hot encoding.
Default: 256 (8-bit quantization).
'''

super(WaveNet, self).__init__(
# a "one-hot" causal conv
causal_embedID=L.EmbedID(
quantization_channels, 2 * residual_channels),

# last 3 layers (include convolution on skip-connections)
conv1x1_0=L.Convolution2D(None, skip_channels, 1),
conv1x1_1=L.Convolution2D(None, skip_channels, 1),
conv1x1_2=L.Convolution2D(None, quantization_channels, 1),
)
# dilated stack
for i, dilation in enumerate(dilations):
self.add_link('conv_filter{}'.format(i),
L.DilatedConvolution2D(None, dilation_channels, (1, 2), dilate=dilation))
self.add_link('conv_gate{}'.format(i),
L.DilatedConvolution2D(None, dilation_channels, (1, 2), dilate=dilation, bias=1))
self.add_link('conv_res{}'.format(i),
L.Convolution2D(None, residual_channels, 1, nobias=True))

self.residual_channels = residual_channels
self.dilations = dilations

def __call__(self, x):
''' Computes the unnormalized log probability.
It uses L.EmbedID in first causal conv because it is efficient for one-hot input.
Args:
x (Variable): Variable holding 3 dimensional int32 array whose element indicates
quantized amplitude.
The shape must be (B, 1, wavelength).
Returns:
Variable: A variable holding 4 dimensional float32 array whose element indicates
unnormalized log probability.
The shape is (B, quantization_channels, 1, wavelength - ar_order + 1).
'''

# a "one-hot" causal conv
x = self.causal_embedID(x)
x = x[..., :-1, :self.residual_channels] + \
x[..., 1:, self.residual_channels:]

# shape (B, residual_channels, 1, wavelength-1)
x = F.transpose(x, (0, 3, 1, 2))

# dilated stack and skip connections
skip = []
for i in range(len(self.dilations)):
out = F.tanh(self['conv_filter{}'.format(i)](x)) * \
F.sigmoid(self['conv_gate{}'.format(i)](x))
skip.append(out)
len_out = out.data.shape[3]
x = self['conv_res{}'.format(i)](out) + x[..., -len_out:]

skip = [out[:, :, :, -len_out:] for out in skip]
y = F.concat(skip)

# last 3 layers
y = F.relu(self.conv1x1_0(y))
y = F.relu(self.conv1x1_1(y))
y = self.conv1x1_2(y)

return y


class ARClassifier(Chain):

compute_accuracy = True

def __init__(self, predictor, ar_order,
lossfun=F.softmax_cross_entropy,
accfun=F.accuracy):
super(ARClassifier, self).__init__(predictor=predictor)
self.lossfun = lossfun
self.accfun = accfun
self.y = None
self.loss = None
self.accuracy = None

self.ar_order = ar_order

def __call__(self, arg):
x = arg[..., :-1]
t = arg[..., self.ar_order:]
self.y = None
self.loss = None
self.accuracy = None
self.y = self.predictor(x)
self.loss = self.lossfun(self.y, t)
reporter.report({'loss': self.loss}, self)
if self.compute_accuracy:
self.accuracy = self.accfun(self.y, t)
reporter.report({'accuracy': self.accuracy}, self)
return self.loss
147 changes: 147 additions & 0 deletions train.py
@@ -0,0 +1,147 @@
# coding: utf-8

import matplotlib.pyplot as plt
import numpy as np

import chainer
from chainer import cuda, Variable
from chainer import datasets, iterators, optimizers, serializers, training
import chainer.functions as F
from chainer.training import extensions
from chainer.dataset import iterator

import scipy.io.wavfile as wavfile
import os, librosa, fnmatch

from model import *


directory = 'dataset/dateset/'

sample_rate = 8000

output_file_dir = 'results/'
output_len = 100000
gpu = 0
resume = False
epoch = 100
train_length = 10000

residual_channels = 16
dilation_channels = 32
skip_channels = 16
dilations = [2**i for i in range(10)] * 3

quantization_channels = 255


def find_files(directory, pattern='*.wav'):
'''Recursively finds all files matching the pattern.'''
files = []
for root, dirnames, filenames in os.walk(directory):
for filename in fnmatch.filter(filenames, pattern):
files.append(os.path.join(root, filename))
return files


def load_generic_audio(directory, sample_rate):
'''Generator that yields audio waveforms from the directory.'''
files = find_files(directory)
for filename in files:
audio, _ = librosa.load(filename, sr=sample_rate, mono=True)
audio = audio.reshape(-1, 1)
yield audio, filename


def mu_law_encode(audio, quantization_channels):
'''Quantizes waveform amplitudes.'''
mu = quantization_channels - 1
# Perform mu-law companding transformation (ITU-T, 1988).
magnitude = np.log(1 + mu * np.abs(audio)) / np.log(1. + mu)
signal = np.sign(audio) * magnitude
# Quantize signal to the specified number of levels.
return ((signal + 1) / 2 * mu + 0.5).astype(np.int32)


def mu_law_decode(output, quantization_channels):
'''Recovers waveform from quantized values.'''
mu = quantization_channels - 1
# Map values back to [-1, 1].
casted = output.astype(np.float32)
signal = 2. * (casted / mu) - 1
# Perform inverse of mu-law transformation.
magnitude = (1 / mu) * ((1 + mu)**abs(signal) - 1)
return np.sign(signal) * magnitude


def chop_dataset(data, train_length, stride, ar_order):
k = train_length + ar_order
dataset = np.stack([data[stride * i : stride * i + k]
for i in range((len(data) - k) // stride + 1)])
return dataset[:, np.newaxis, :, 0]


def generate_and_write_one_sample(ar_order, x, loc):
y = model.predictor(x[..., loc - ar_order : loc])

prob = F.softmax(y).data.flatten()
prob = cuda.to_cpu(prob)
x.data[..., loc] = np.random.choice(range(quantization_channels), p=prob)


def save_x(x, ar_order, quanttization_channels, filename, fs):
output = mu_law_decode(cuda.to_cpu(x.data[0, 0, ar_order:]), quantization_channels)
output = np.round(output * 2 ** 15).astype(np.int16).reshape((-1,))
wavfile.write(filename, fs, output)



ar_order = sum(dilations) + 2

wave_arrays = []
for audio, _ in load_generic_audio(directory, sample_rate):
x = mu_law_encode(audio, quantization_channels)
x = chop_dataset(x, train_length, train_length, ar_order)
wave_arrays.append(x)

dataset = np.concatenate(wave_arrays).astype(np.int32)

if gpu >= 0:
cuda.get_device(gpu).use()



model = ARClassifier(WaveNet(dilations,
residual_channels,
dilation_channels,
skip_channels,
quantization_channels),
ar_order)
model.to_gpu()

optimizer = optimizers.Adam()
optimizer.setup(model)


train, test = chainer.datasets.split_dataset_random(dataset, len(dataset) // 10 * 9)

train_iter = chainer.iterators.SerialIterator(train, 6)
test_iter = chainer.iterators.SerialIterator(test, 8, repeat=False, shuffle=False)

updater = training.StandardUpdater(train_iter, optimizer, device=gpu)


trainer = training.Trainer(updater, (epoch, 'epoch'))
trainer.extend(extensions.Evaluator(test_iter, model, device=gpu))

trainer.extend(extensions.dump_graph('main/loss'))

trainer.extend(extensions.snapshot(), trigger=(epoch, 'epoch'))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(
['epoch', 'main/loss', 'validation/main/loss',
'main/accuracy', 'validation/main/accuracy']))

trainer.extend(extensions.ProgressBar())

trainer.run()
File renamed without changes.

0 comments on commit 3c35157

Please sign in to comment.