Skip to content

Commit

Permalink
add some test
Browse files Browse the repository at this point in the history
  • Loading branch information
haowen-xu committed Feb 28, 2019
1 parent 250410d commit fae0fa0
Showing 1 changed file with 85 additions and 5 deletions.
90 changes: 85 additions & 5 deletions tests/layers/convolutional/test_pixelcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_pixelcnn_2d_input_and_output(self):
'`PixelCNN2DOutput`: got 123'):
_ = pixelcnn_2d_output(123)

def test_pixelcnn_conv2d_resnet_deps(self):
def test_pixelcnn_conv2d_resnet_one_layer_dep(self):
tf.set_random_seed(1234)

H, W = 11, 12
Expand All @@ -85,8 +85,8 @@ def check_influence(vertical, horizontal):
idx = i * W + j

# vertical stack influences
for h in range(H):
for w in range(W):
for h in range(vertical.shape[1]):
for w in range(vertical.shape[2]):
has_value = (vertical[idx] != 0.)[h, w]
expect_value = ((h - 3 <= i <= h - 1) and
(w - 2 <= j <= w + 2))
Expand All @@ -98,8 +98,8 @@ def check_influence(vertical, horizontal):
)

# horizontal stack influences
for h in range(H):
for w in range(W):
for h in range(horizontal.shape[1]):
for w in range(horizontal.shape[2]):
has_value = (horizontal[idx] != 0.)[h, w]
expect_value = (
# value from vertical stack
Expand Down Expand Up @@ -155,3 +155,83 @@ def check_influence(vertical, horizontal):
with pytest.raises(TypeError, match='`input` is not an instance of '
'`PixelCNN2DOutput`: got 123'):
_ = pixelcnn_conv2d_resnet(123, out_channels=1)

def test_pixelcnn_conv2d_resnet_multi_layer_dep(self):
tf.set_random_seed(1234)

H, W = 6, 7
x = np.eye(H * W, dtype=np.float32).reshape([H * W, H, W, 1])

def check_influence(vertical, horizontal):
for i in range(H):
for j in range(W):
idx = i * W + j

# vertical stack influences
for h in range(vertical.shape[1]):
for w in range(vertical.shape[2]):
has_value = (vertical[idx] != 0.)[h, w]
expect_value = i <= h - 1
self.assertEqual(
has_value,
expect_value,
msg='Vertical stack not match at '
'({},{},{},{})'.format(i, j, h, w)
)

# horizontal stack influences
for h in range(horizontal.shape[1]):
for w in range(horizontal.shape[2]):
has_value = (horizontal[idx] != 0.)[h, w]
expect_value = (
# value from vertical stack
i <= h - 1 or
# value from horizontal stack
(j <= w - 1 and i <= h)
)
self.assertEqual(
has_value,
expect_value,
msg='Horizontal stack not match at '
'({},{},{},{})'.format(i, j, h, w)
)

with mock.patch('tensorflow.nn.conv2d', patched_conv2d), \
self.test_session() as sess:
n_layers = 3

# NHWC
y = o = pixelcnn_2d_input(
x, channels_last=True, auxiliary_channel=False)
for i in range(n_layers):
o = pixelcnn_conv2d_resnet(
o, out_channels=1, vertical_kernel_size=(2, 3),
horizontal_kernel_size=(2, 2), strides=(1, 1),
channels_last=True,
)
ensure_variables_initialized()
vertical, horizontal = sess.run([o.vertical, o.horizontal])
check_influence(
vertical.reshape(vertical.shape[:-1]),
horizontal.reshape(horizontal.shape[:-1])
)

# NCHW
y = o = pixelcnn_2d_input(
np.transpose(x, [0, 3, 1, 2]), channels_last=False,
auxiliary_channel=False
)
for i in range(n_layers):
o = pixelcnn_conv2d_resnet(
o, out_channels=1, vertical_kernel_size=(2, 3),
horizontal_kernel_size=(2, 2), strides=(1, 1),
channels_last=False, activation_fn=tf.nn.leaky_relu
)
ensure_variables_initialized()
vertical, horizontal = sess.run([o.vertical, o.horizontal])
vertical = np.transpose(vertical, [0, 2, 3, 1])
horizontal = np.transpose(horizontal, [0, 2, 3, 1])
check_influence(
vertical.reshape(vertical.shape[:-1]),
horizontal.reshape(horizontal.shape[:-1])
)

0 comments on commit fae0fa0

Please sign in to comment.