Skip to content

Commit

Permalink
Improve performance by using a custom causal convolution
Browse files Browse the repository at this point in the history
tf.atrous_conv2d pads the height dimension of the input to be a multiple
of the dilation rate, which explains why the model was slow and used a
lot of memory.

After switching to the custom causal convolution function in this
commit, performance has increased massively, and it's now possible to
train large models with dilation rates that wrap around multiple times
from 1 to 512.
  • Loading branch information
ibab committed Sep 16, 2016
1 parent a29de11 commit 8add545
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 24 deletions.
59 changes: 38 additions & 21 deletions wavenet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import tensorflow as tf

def Print(op):
return tf.Print(op, [op, tf.shape(op)], summarize=10)

class WaveNet(object):
'''Implements the WaveNet network for generative audio.
Expand All @@ -10,7 +13,13 @@ class WaveNet(object):
loss = net.loss(input_batch)
'''

def __init__(self, batch_size, channels, dilations, filter_width, residual_channels, dilation_channels):
def __init__(self,
batch_size,
channels,
dilations,
filter_width,
residual_channels,
dilation_channels):
self.batch_size = batch_size
self.channels = channels
self.dilations = dilations
Expand All @@ -19,18 +28,37 @@ def __init__(self, batch_size, channels, dilations, filter_width, residual_chann
self.dilation_channels = dilation_channels


# We add our own dilated convolution here, because atrous_conv2d
# pads the height so that is matches `dilation`, which leads
# to terrible performance if dilation is large.
def _causal_dilated_conv(self, value, filter, dilation):
if dilation == 1:
out = tf.nn.conv2d(value, filter, strides=4 * [1], padding='VALID')
else:
shape = tf.shape(value)
# How many elements we are missing to be divisible by dilation.
pad_elements = dilation - 1 - (shape[2] + dilation - 1) % dilation
padded = tf.pad(value, [[0, 0], [0, 0], [0, pad_elements], [0, 0]])
# Use the batch dimension to skip (dilation - 1) elements.
reshaped = tf.reshape(padded, [shape[0] * dilation, 1, -1, shape[3]])
# Perform a simple convolution.
conv = tf.nn.conv2d(reshaped, filter, strides=[1, 1, 1, 1], padding='VALID')
restored = tf.reshape(conv, [shape[0], 1, -1, tf.shape(filter)[3]])
# Remove padding elements from the end
out = tf.slice(restored, 4 * [0], [-1, -1, tf.shape(restored)[2] - pad_elements, -1])

padding = (tf.shape(filter)[1] - 1) * dilation
return tf.pad(out, [[0, 0], [0, 0], [padding, 0], [0, 0]])


# A single causal convolution layer that reduces the number of channels.
def _create_causal_layer(self, input_batch, in_channels, out_channels):
with tf.name_scope('causal_layer'):
weights_filter = tf.Variable(tf.truncated_normal(
[1, self.filter_width, in_channels, out_channels],
stddev=0.2,
name="filter"))
conv = tf.nn.conv2d(
input_batch,
weights_filter,
strides=4 * [1],
padding='VALID')
return tf.pad(conv, [[0, 0], [0, 0], [self.filter_width - 1, 0], [0, 0]])
return self._causal_dilated_conv(input_batch, weights_filter, 1)


def _create_dilation_layer(self, input_batch, layer_index, dilation, in_channels, dilation_channels):
Expand All @@ -44,34 +72,23 @@ def _create_dilation_layer(self, input_batch, layer_index, dilation, in_channels
[1, self.filter_width, in_channels, dilation_channels],
stddev=0.2, name="gate"))

# TensorFlow has an operator for convolution with holes.
conv_filter = tf.nn.atrous_conv2d(input_batch, weights_filter,
rate=dilation,
padding="VALID",
name="conv_filter")
conv_gate = tf.nn.atrous_conv2d(input_batch, weights_gate,
rate=dilation,
padding="VALID",
name="conv_gate")
conv_filter = self._causal_dilated_conv(input_batch, weights_filter, dilation)
conv_gate = self._causal_dilated_conv(input_batch, weights_gate, dilation)

out = tf.tanh(conv_filter) * tf.sigmoid(conv_gate)

# Pad output with zeros from the left to ensure that a given pixel only
# uses current/past values
out = tf.pad(out, [[0, 0], [0, 0], [dilation, 0], [0, 0]])

weights_dense = tf.Variable(tf.truncated_normal(
[1, 1, dilation_channels, in_channels], stddev=0.2, name="dense"))
transformed = tf.nn.conv2d(out, weights_dense, strides=[1] * 4,
padding="SAME", name="dense")

layer = 'layer{}'.format(layer_index)
tf.histogram_summary(layer + '_filter', weights_filter)
tf.histogram_summary(layer + '_gate', weights_gate)
tf.histogram_summary(layer + '_dense', weights_dense)

return transformed, input_batch + transformed


def _preprocess(self, audio):
'''Quantizes waveform amplitudes.'''
with tf.name_scope('preprocessing'):
Expand Down
6 changes: 3 additions & 3 deletions wavenet_params.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"filter_width": 2,
"quantization_steps": 256,
"sample_rate": 16000,
"dilations": [1, 2, 4, 8, 16, 32],
"residual_channels": 64,
"dilation_channels": 32
"dilations": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
"residual_channels": 32,
"dilation_channels":16
}

0 comments on commit 8add545

Please sign in to comment.