Skip to content

Commit

Permalink
Use static tensor shapes if possible for one-padding (#463)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgeiger committed Mar 24, 2020
1 parent 7bf5c7c commit da595ef
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 10 deletions.
4 changes: 3 additions & 1 deletion larq/layers_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def _get_spatial_shape(self, input_shape):
)

def _get_padding_same(self, inputs):
input_shape = tf.shape(inputs)
input_shape = inputs.shape
if not input_shape[1:].is_fully_defined():
input_shape = tf.shape(inputs)
padding = self._get_spatial_padding_same(self._get_spatial_shape(input_shape))
return (
[[0, 0], *padding, [0, 0]]
Expand Down
27 changes: 18 additions & 9 deletions larq/layers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,22 +207,31 @@ def test_non_zero_padding_layers(
spy.assert_called_once_with(mocker.ANY, mocker.ANY, constant_values=1.0)

@pytest.mark.parametrize(
"layer_cls, input_shape",
"layer_cls",
[
(lq.layers.QuantConv1D, (None, 3)),
(lq.layers.QuantConv2D, (None, None, 3)),
(lq.layers.QuantConv3D, (None, None, None, 3)),
(lq.layers.QuantSeparableConv1D, (None, 3)),
(lq.layers.QuantSeparableConv2D, (None, None, 3)),
(lq.layers.QuantDepthwiseConv2D, (None, None, 3)),
lq.layers.QuantConv1D,
lq.layers.QuantConv2D,
lq.layers.QuantConv3D,
lq.layers.QuantSeparableConv1D,
lq.layers.QuantSeparableConv2D,
lq.layers.QuantDepthwiseConv2D,
],
)
@pytest.mark.parametrize("data_format", ["channels_last", "channels_first"])
def test_non_zero_padding_unknown_inputs(self, layer_cls, input_shape, data_format):
@pytest.mark.parametrize("static", [True, False])
def test_non_zero_padding_shapes(self, layer_cls, data_format, static):
layer = layer_cls(
16, 3, padding="same", pad_values=1.0, data_format=data_format
)
input_shape = [32 if static else None] * layer.rank + [3]
if data_format == "channels_first":
input_shape = reversed(input_shape)
input = tf.keras.layers.Input(shape=input_shape)
layer_cls(16, 3, padding="same", pad_values=1.0, data_format=data_format)(input)

layer(input)
if static:
for dim in layer.output_shape[1:]:
assert dim is not None


class TestLayerWarns:
Expand Down
1 change: 1 addition & 0 deletions larq/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def get_profile_model():
input_quantizer=lq.quantizers.SteTern(),
depthwise_quantizer=lq.quantizers.SteTern(),
padding="same",
pad_values=1.0,
use_bias=False,
),
tf.keras.layers.BatchNormalization(scale=False),
Expand Down

0 comments on commit da595ef

Please sign in to comment.