Note: this notebook uses TensorFlow 2.0 (eager execution).

In [1]:
import tensorflow as tf
import numpy as np

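Output shape formulas for `padding='VALID'`:
$$y_{\text{height}} = (x_{\text{height}} - 1) \times \text{strides}_{\text{height}} + \text{kernel}_{\text{height}}$$
$$y_{\text{width}} = (x_{\text{width}} - 1) \times \text{strides}_{\text{width}} + \text{kernel}_{\text{width}}$$
Output shape formulas for `padding='SAME'`:
$$y_{\text{height}} = x_{\text{height}} \times \text{strides}_{\text{height}}$$
$$y_{\text{width}} = x_{\text{width}} \times \text{strides}_{\text{width}}$$

The helper below passes these output sizes to `tf.nn.conv2d_transpose` as `output_shape`. They are the conventional choice rather than the only valid one. TensorFlow accepts any `output_shape` that a forward convolution with the same kernel, strides and padding maps back to the input shape.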

In [2]:
def print_transpose_conv_example(input_shape, kernel_shape, strides, padding):
    '''
    Run a single-channel transposed convolution (tf.nn.conv2d_transpose) on an
    all-ones input with an all-ones kernel, then print the input, the kernel
    and the output.

    input_shape - (height, width) of the input
    kernel_shape - (height, width) of the kernel
    strides - int, or a (stride_height, stride_width) pair
    padding - 'VALID' or 'SAME'

    The output size is computed with the formulas in the markdown cell above
    and passed to TensorFlow explicitly as `output_shape`.

    Raises ValueError if `padding` is neither 'VALID' nor 'SAME'.
    '''
    # Reject unknown padding up front. Otherwise neither branch below assigns
    # the output size, and the call fails later with an UnboundLocalError.
    if padding not in ('VALID', 'SAME'):
        raise ValueError(f"padding must be 'VALID' or 'SAME', got {padding!r}")

    x_height, x_width = input_shape
    kernel_height, kernel_width = kernel_shape
    x = np.ones((1, x_height, x_width, 1)) # (batch, x_height, x_width, in_channels)
    kernel = np.ones((kernel_height, kernel_width, 1, 1)) # (kernel_height, kernel_width, output_channels, input_channels)

    # numpy broadcasting applies `strides` element-wise per axis, so it may be
    # a scalar or a (height, width) pair.
    if padding == 'VALID':
        output_dims = (np.array(input_shape) - 1) * strides + np.array(kernel_shape)
    else:  # 'SAME'
        output_dims = np.array(input_shape) * strides
    y_height, y_width = (int(dim) for dim in output_dims)
    y = tf.nn.conv2d_transpose(x, kernel, output_shape=(1, y_height, y_width, 1), # (batch, y_height, y_width, output_channels)
                               strides=strides, padding=padding)

    print(' Input:')
    print(np.reshape(x, (x_height, x_width)))
    print(' Kernel: ')
    print(np.reshape(kernel, (kernel_height, kernel_width)))
    print(f' Output (strides = {strides}, padding = "{padding}"):')
    print(np.reshape(y, (y_height, y_width)))

Valid padding, unit strides

In [3]:
print_transpose_conv_example(input_shape=(2, 2), kernel_shape=(3, 3), strides=1, padding='VALID')

 Input:
[[1. 1.]
 [1. 1.]]
 Kernel: 
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]
 Output (strides = 1, padding = "VALID"):
[[1. 2. 2. 1.]
 [2. 4. 4. 2.]
 [2. 4. 4. 2.]
 [1. 2. 2. 1.]]


Same padding, unit strides, center crop

In [4]:
print('Center crop')
print_transpose_conv_example(input_shape=(2, 2), kernel_shape=(3, 3), strides=1, padding='SAME')

Center crop
 Input:
[[1. 1.]
 [1. 1.]]
 Kernel: 
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]
 Output (strides = 1, padding = "SAME"):
[[4. 4.]
 [4. 4.]]


Same padding, unit strides, top-left crop (the `'VALID'` result is printed first for comparison)

In [5]:
# Show the uncropped 'VALID' output first so the 'SAME' crop can be compared with it.
print_transpose_conv_example(input_shape=(2, 2), kernel_shape=(2, 2), strides=1, padding='VALID')
print('Top left crop')
print_transpose_conv_example(input_shape=(2, 2), kernel_shape=(2, 2), strides=1, padding='SAME')

 Input:
[[1. 1.]
 [1. 1.]]
 Kernel: 
[[1. 1.]
 [1. 1.]]
 Output (strides = 1, padding = "VALID"):
[[1. 2. 1.]
 [2. 4. 2.]
 [1. 2. 1.]]
Top left crop
 Input:
[[1. 1.]
 [1. 1.]]
 Kernel: 
[[1. 1.]
 [1. 1.]]
 Output (strides = 1, padding = "SAME"):
[[1. 2.]
 [2. 4.]]


Valid padding, strides > 1

In [6]:
print_transpose_conv_example(input_shape=(2, 2), kernel_shape=(3, 3), strides=2, padding='VALID')

 Input:
[[1. 1.]
 [1. 1.]]
 Kernel: 
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]
 Output (strides = 2, padding = "VALID"):
[[1. 1. 2. 1. 1.]
 [1. 1. 2. 1. 1.]
 [2. 2. 4. 2. 2.]
 [1. 1. 2. 1. 1.]
 [1. 1. 2. 1. 1.]]


Same padding, strides > 1 (top-left crop)

In [7]:
# top left crop: the 4x4 'SAME' output equals the top-left corner of the 5x5 'VALID' output.
print_transpose_conv_example(input_shape=(2, 2), kernel_shape=(3, 3), strides=2, padding='SAME')

 Input:
[[1. 1.]
 [1. 1.]]
 Kernel: 
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]
 Output (strides = 2, padding = "SAME"):
[[1. 1. 2. 1.]
 [1. 1. 2. 1.]
 [2. 2. 4. 2.]
 [1. 1. 2. 1.]]
