Skip to content

Commit

Permalink
Merge pull request #10 from chris-chris:feature/2002-conv-test
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 299358931
Change-Id: Ie803176d648cf3c0d9275e70d476f13ee4a87744
  • Loading branch information
Copybara-Service committed Mar 6, 2020
2 parents 25576a5 + 3e509db commit 8a81e5d
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 0 deletions.
1 change: 1 addition & 0 deletions haiku/_src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ hk_py_test(
":base",
":conv",
":initializers",
":test_utils",
# pip: absl/testing:absltest
# pip: absl/testing:parameterized
# pip: jax
Expand Down
90 changes: 90 additions & 0 deletions haiku/_src/conv_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from haiku._src import base
from haiku._src import conv
from haiku._src import initializers
from haiku._src import test_utils
from jax import random
import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -163,6 +164,42 @@ def f():
self.assertEqual(out.shape, expected_output_shape)
self.assertEqual(reached[0], n*2)

@test_utils.transform_and_run
def test_invalid_input_shape(self):
n = 1
input_shape = [2, 4] + [16]*n

with self.assertRaisesRegex(ValueError, "Input to ConvND needs to have "
"rank 3, but input has shape"):
data = jnp.zeros(input_shape * 2)
net = conv.ConvND(n, output_channels=3, kernel_shape=3,
data_format="channels_first")
net(data)

@test_utils.transform_and_run
def test_invalid_mask_shape(self):
n = 1
input_shape = [2, 4] + [16]*n

with self.assertRaisesRegex(ValueError, "Mask needs to have the same "
"shape as weights. Shapes are:"):
data = jnp.zeros(input_shape)
net = conv.ConvND(n, output_channels=3, kernel_shape=3,
data_format="channels_first", mask=jnp.ones([1, 5, 1]))
net(data)

@test_utils.transform_and_run
def test_valid_mask_shape(self):
n = 2
input_shape = [2, 4] + [16]*n
data = jnp.zeros(input_shape)
net = conv.ConvND(n, output_channels=3, kernel_shape=3,
data_format="channels_first",
mask=jnp.ones([3, 3, 4, 3]))
out = net(data)
expected_output_shape = (2, 3) + (16,)*n
self.assertEqual(out.shape, expected_output_shape)


class Conv1DTest(parameterized.TestCase):

Expand Down Expand Up @@ -340,6 +377,22 @@ def f():
out = np.squeeze(out, axis=(0, 4))
np.testing.assert_equal(out, expected_out)

@test_utils.transform_and_run
def test_invalid_input_shape(self):
with_bias = True

with self.assertRaisesRegex(ValueError, "Input to ConvND needs to have "
"rank 5, but input has shape"):
data = jnp.ones([1, 5, 5, 5, 1, 9, 9])
net = conv.Conv3D(
output_channels=1,
kernel_shape=3,
stride=1,
padding="VALID",
with_bias=with_bias,
**create_constant_initializers(1.0, 1.0, with_bias))
net(data)


class ConvTransposeTest(parameterized.TestCase):

Expand Down Expand Up @@ -408,6 +461,43 @@ def f():
expected_output_shape = (2, 3) + (16,)*n
self.assertEqual(out.shape, expected_output_shape)

@test_utils.transform_and_run
def test_invalid_input_shape(self):
n = 1
with self.assertRaisesRegex(ValueError,
"Input to ConvND needs to have rank"):
input_shape = [2, 4] + [16]*n
data = jnp.zeros(input_shape*2)
net = conv.ConvNDTranspose(
n, output_channels=3, kernel_shape=3, data_format="channels_first")
return net(data)

@test_utils.transform_and_run
def test_invalid_input_mask(self):
n = 2
with self.assertRaisesRegex(ValueError, "Mask needs to have the same "
"shape as weights. Shapes are:"):
input_shape = [2, 4] + [16]*n
data = jnp.zeros(input_shape)
net = conv.ConvNDTranspose(
n, output_channels=3, kernel_shape=3,
data_format="channels_first",
mask=jnp.zeros([1, 2, 3]))
net(data)

@test_utils.transform_and_run
def test_valid_input_mask(self):
n = 2
input_shape = [2, 4] + [16]*n
data = jnp.zeros(input_shape)
net = conv.ConvNDTranspose(
n, output_channels=3, kernel_shape=3,
data_format="channels_first",
mask=jnp.zeros([3, 3, 4, 3]))
out = net(data)
expected_output_shape = (2, 3, 16, 16)
self.assertEqual(out.shape, expected_output_shape)


class Conv1DTransposeTest(parameterized.TestCase):

Expand Down

0 comments on commit 8a81e5d

Please sign in to comment.