# Dilated and Causal Convolution Guide

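This notebook gives bare-minimum examples of causal and dilated convolutions, the main ingredients of the WaveNet architecture. The code uses the TensorFlow 1.x graph API (`tf.Session`). Under TensorFlow 2, one option is to replace the import with `import tensorflow.compat.v1 as tf` and then call `tf.disable_v2_behavior()`. First, we create our input signal. It is shaped [batch, width, channels], which is the layout `tf.nn.conv1d` expects.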

In [45]:
import tensorflow as tf

input_len = 8
filter_width = 3

X = tf.range(input_len, dtype=tf.float32)
X = tf.reshape(X, [1, input_len, 1])

with tf.Session() as sess:
    x = sess.run(X)
    print('X is', x.flatten())
    print('X has shape', X.shape.as_list())

X is [0. 1. 2. 3. 4. 5. 6. 7.]
X has shape [1, 8, 1]


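Next, we create our filter and perform an ordinary 1D convolution for later comparison. The filter is a width-3 kernel of ones, so it acts as an accumulator that sums the three samples in its window. With 'VALID' padding, the output has only input_len - filter_width + 1 = 6 samples.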

In [46]:
# Basic 1D convolution

H = tf.ones([filter_width, 1, 1])
Y = tf.nn.conv1d(X, H, 1, 'VALID')

with tf.Session() as sess:
    x, h, y = sess.run([X, H, Y])
    print('{} * {} = {}'.format(x.flatten(), h.flatten(), y.flatten()))

[0. 1. 2. 3. 4. 5. 6. 7.] * [1. 1. 1.] = [ 3.  6.  9. 12. 15. 18.]
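
As a cross-check, here is a plain-Python sketch of the same 'VALID' convolution. It uses no TensorFlow, and `conv1d_valid` is a hypothetical helper name used only for illustration. Like TensorFlow's convolution ops, it computes a cross-correlation and does not flip the kernel. For a symmetric all-ones filter, cross-correlation and convolution give the same result.

```python
def conv1d_valid(x, h):
    """'VALID' 1D cross-correlation with stride 1 (kernel is not flipped)."""
    k = len(h)
    # Only positions where the whole filter fits inside x produce an output,
    # so the result has len(x) - k + 1 elements.
    return [sum(h[j] * x[i + j] for j in range(k)) for i in range(len(x) - k + 1)]


x = [float(v) for v in range(8)]  # same signal as X
h = [1.0, 1.0, 1.0]               # same all-ones filter as H
print(conv1d_valid(x, h))         # [3.0, 6.0, 9.0, 12.0, 15.0, 18.0]
```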


Now we implement a basic causal convolution. We left-pad the input signal with filter_width - 1 zeros, which shifts it later in time. Because of this shift, y[t] depends only on x[t - filter_width + 1], ..., x[t]. It never uses "future" values such as x[t+1]. Note that the output has the same length as the input signal.

In [47]:
# Non-dilated causal convolution

X_pad = tf.pad(X, [[0, 0], [filter_width - 1, 0], [0, 0]], 'CONSTANT')
Y = tf.nn.conv1d(X_pad, H, 1, 'VALID')

with tf.Session() as sess:
    x, h, y = sess.run([X, H, Y])
    print('{} * {} = {}'.format(x.flatten(), h.flatten(), y.flatten()))

[0. 1. 2. 3. 4. 5. 6. 7.] * [1. 1. 1.] = [ 0.  1.  3.  6.  9. 12. 15. 18.]
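
The padding step can be mirrored in plain Python as a sketch. The helpers `conv1d_valid` and `causal_conv1d` are hypothetical names, not TensorFlow functions. The final assertion checks the causal property directly: changing a future sample leaves every earlier output unchanged.

```python
def conv1d_valid(x, h):
    """'VALID' 1D cross-correlation with stride 1."""
    k = len(h)
    return [sum(h[j] * x[i + j] for j in range(k)) for i in range(len(x) - k + 1)]


def causal_conv1d(x, h):
    """Left-pad with len(h) - 1 zeros, then apply a 'VALID' convolution.
    Output t only sees x[t - len(h) + 1], ..., x[t]."""
    padded = [0.0] * (len(h) - 1) + list(x)
    return conv1d_valid(padded, h)


x = [float(v) for v in range(8)]
h = [1.0, 1.0, 1.0]
print(causal_conv1d(x, h))  # [0.0, 1.0, 3.0, 6.0, 9.0, 12.0, 15.0, 18.0]

# Causality check: changing x[5] must not change outputs 0..4.
x_changed = x[:5] + [100.0] + x[6:]
assert causal_conv1d(x_changed, h)[:5] == causal_conv1d(x, h)[:5]
```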


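Now we add dilation to our convolution. With dilation rate d, the filter taps are spaced d samples apart. For our width-3 filter with d = 2, this gives y[t] = x[t] + x[t-2] + x[t-4]. To do this efficiently, we use a trick instead of building a filter with zeros between its taps. We split the input signal into d interleaved subsequences (here x[0], x[2], x[4], ... and x[1], x[3], x[5], ...), stack them along the batch dimension, and apply an ordinary causal convolution to each. This first requires padding the end of the signal with zeros so that its length is a multiple of the dilation rate. Afterwards, we interleave the per-subsequence results back into a single signal.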
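
To make the definition concrete, here is the slow approach that the trick avoids, sketched in plain Python with hypothetical helper names. It inserts dilation - 1 zeros between the filter taps and then runs an ordinary causal convolution with the wider filter. The filter becomes (len(h) - 1) * dilation + 1 taps wide, and most of its multiplications are by zero.

```python
def conv1d_valid(x, h):
    """'VALID' 1D cross-correlation with stride 1."""
    k = len(h)
    return [sum(h[j] * x[i + j] for j in range(k)) for i in range(len(x) - k + 1)]


def dilate_filter(h, dilation):
    """Insert dilation - 1 zeros between consecutive taps of h."""
    out = []
    for j, tap in enumerate(h):
        if j > 0:
            out.extend([0.0] * (dilation - 1))
        out.append(tap)
    return out


def naive_dilated_causal_conv1d(x, h, dilation):
    """Ordinary causal convolution with a zero-stuffed filter."""
    h_dil = dilate_filter(h, dilation)
    padded = [0.0] * (len(h_dil) - 1) + list(x)
    return conv1d_valid(padded, h_dil)


x = [float(v) for v in range(8)]
h = [1.0, 1.0, 1.0]
print(dilate_filter(h, 2))                   # [1.0, 0.0, 1.0, 0.0, 1.0]
print(naive_dilated_causal_conv1d(x, h, 2))  # [0.0, 1.0, 2.0, 4.0, 6.0, 9.0, 12.0, 15.0]
```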

In [48]:
# Dilated causal convolution

dilation = 2

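# Turn our input signal into `dilation` interleaved subsampled signals,
# stacked along the batch dimension (this reshape assumes a batch size of 1
# and a single channel)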
padded = tf.pad(X, [[0, 0], [0, (dilation - input_len % dilation) % dilation], [0, 0]])
reshaped = tf.reshape(padded, [-1, dilation, 1])
transposed = tf.transpose(reshaped, perm=[1, 0, 2])
X_pad = tf.reshape(transposed, [dilation, -1, 1])

# Perform causal convolution as normal
X_pad = tf.pad(X_pad, [[0, 0], [filter_width - 1, 0], [0, 0]], 'CONSTANT')
Y = tf.nn.conv1d(X_pad, H, 1, 'VALID')

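# Undo our subsampling by interleaving the results back into one signal,
# then drop any outputs that correspond to the zeros padded onto the end
prepared = tf.reshape(Y, [dilation, -1, 1])
transposed = tf.transpose(prepared, perm=[1, 0, 2])
Y = tf.reshape(transposed, [1, -1, 1])
Y = Y[:, :input_len, :]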

with tf.Session() as sess:
    x, h, y = sess.run([X, H, Y])
    print('{} * {} = {}'.format(x.flatten(), h.flatten(), y.flatten()))

[0. 1. 2. 3. 4. 5. 6. 7.] * [1. 1. 1.] = [ 0.  1.  2.  4.  6.  9. 12. 15.]
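
The subsample, convolve and interleave steps can also be sketched in plain Python. This sketch uses hypothetical helper names, no TensorFlow, and handles a single 1D sequence only. It is compared against a direct implementation of the definition y[t] = sum_j h[j] * x[t - (k - 1 - j) * d]. The sketch trims the outputs that come from the end padding, so it also works when the length is not a multiple of the dilation rate.

```python
def causal_conv1d(x, h):
    """Ordinary (non-dilated) causal convolution via left zero-padding."""
    k = len(h)
    padded = [0.0] * (k - 1) + list(x)
    return [sum(h[j] * padded[i + j] for j in range(k)) for i in range(len(x))]


def dilated_causal_conv1d(x, h, dilation):
    """Dilated causal convolution via the subsampling trick."""
    n = len(x)
    # 1. Right-pad so the length is a multiple of the dilation rate.
    pad = (dilation - n % dilation) % dilation
    x_padded = list(x) + [0.0] * pad
    # 2. Split into `dilation` interleaved subsequences: phase p holds
    #    x[p], x[p + dilation], x[p + 2 * dilation], ...
    phases = [x_padded[p::dilation] for p in range(dilation)]
    # 3. Run an ordinary causal convolution on each subsequence.
    outputs = [causal_conv1d(phase, h) for phase in phases]
    # 4. Interleave the results back into time order.
    y = [0.0] * len(x_padded)
    for p, out in enumerate(outputs):
        y[p::dilation] = out
    # 5. Drop the outputs that belong to the end padding.
    return y[:n]


def direct_dilated_causal_conv1d(x, h, dilation):
    """Direct definition: y[t] = sum_j h[j] * x[t - (k - 1 - j) * dilation]."""
    k = len(h)
    y = []
    for t in range(len(x)):
        total = 0.0
        for j in range(k):
            s = t - (k - 1 - j) * dilation
            if s >= 0:  # samples before t = 0 are treated as zeros
                total += h[j] * x[s]
        y.append(total)
    return y


x = [float(v) for v in range(8)]
h = [1.0, 1.0, 1.0]
print(dilated_causal_conv1d(x, h, 2))  # [0.0, 1.0, 2.0, 4.0, 6.0, 9.0, 12.0, 15.0]

# A length that is not a multiple of the dilation rate, with an asymmetric filter.
x7 = [float(v) for v in range(7)]
assert dilated_causal_conv1d(x7, [1.0, 2.0, 3.0], 3) == direct_dilated_causal_conv1d(x7, [1.0, 2.0, 3.0], 3)
```

Each subsequence is 1/d as long as the input. As a result, the trick performs len(h) multiplications per output, instead of the (len(h) - 1) * d + 1 that a zero-stuffed filter would need.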
