Skip to content

Commit

Permalink
Fix failing channels first tests for efficientnet and mobilenet_v3
Browse files Browse the repository at this point in the history
  • Loading branch information
Inquisitive-ME committed Oct 20, 2023
1 parent c52e562 commit faa997d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
12 changes: 9 additions & 3 deletions tf_keras/applications/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,15 @@ def round_repeats(repeats):
# original implementation.
# See https://github.com/tensorflow/tensorflow/issues/49930 for more
# details
x = layers.Rescaling(
[1.0 / math.sqrt(stddev) for stddev in IMAGENET_STDDEV_RGB]
)(x)
if backend.image_data_format() == 'channels_first':
shape_for_multiply = [1, 3, 1, 1]
else:
shape_for_multiply = [1, 1, 1, 3]
x = tf.math.multiply(x,
tf.reshape(
[1.0 / math.sqrt(stddev) for stddev in IMAGENET_STDDEV_RGB],
shape_for_multiply
))

x = layers.ZeroPadding2D(
padding=imagenet_utils.correct_pad(x, 3), name="stem_conv_pad"
Expand Down
5 changes: 4 additions & 1 deletion tf_keras/applications/mobilenet_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,10 @@ def MobileNetV3(
input_shape = (cols, rows, 3)
# If input_shape is None and input_tensor is None using standard shape
if input_shape is None and input_tensor is None:
input_shape = (None, None, 3)
if backend.image_data_format() == "channels_last":
input_shape = (None, None, 3)
else:
input_shape = (3, None, None)

if backend.image_data_format() == "channels_last":
row_axis, col_axis = (0, 1)
Expand Down

0 comments on commit faa997d

Please sign in to comment.