Skip to content

Commit

Permalink
Add channels first testing for applications_test and cleanup testing …
Browse files Browse the repository at this point in the history
…script
  • Loading branch information
Inquisitive-ME committed Oct 20, 2023
1 parent 214fc39 commit c52e562
Showing 1 changed file with 77 additions and 42 deletions.
119 changes: 77 additions & 42 deletions tf_keras/applications/applications_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,14 @@

MODEL_LIST = MODEL_LIST_NO_NASNET + NASNET_LIST

MODELS_UNSUPPORTED_CHANNELS_FIRST = ["ConvNeXt", "NASNet", "RegNetX", "RegNetY"]
# Add each data format for each model
test_parameters_with_image_data_format = [
('{}_{}'.format(model[0].__name__, image_data_format), *model, image_data_format)
for image_data_format in ["channels_first", "channels_last"]
for model in MODEL_LIST
]

# Parameters for loading weights for MobileNetV3.
# (class, alpha, minimalistic, include_top)
MOBILENET_V3_FOR_WEIGHTS = [
Expand All @@ -138,7 +146,16 @@


class ApplicationsTest(tf.test.TestCase, parameterized.TestCase):
def assertShapeEqual(self, shape1, shape2):
@classmethod
def setUpClass(cls):
cls.original_image_data_format = backend.image_data_format()

@classmethod
def tearDownClass(cls):
backend.set_image_data_format(cls.original_image_data_format)

@classmethod
def assertShapeEqual(cls, shape1, shape2):
if len(shape1) != len(shape2):
raise AssertionError(
f"Shapes are different rank: {shape1} vs {shape2}"
Expand All @@ -147,8 +164,19 @@ def assertShapeEqual(self, shape1, shape2):
if v1 != v2:
raise AssertionError(f"Shapes differ: {shape1} vs {shape2}")

@parameterized.parameters(*MODEL_LIST)
def test_application_base(self, app, _):
def skip_if_invalid_image_data_format_for_model(self, app, image_data_format):
does_not_support_channels_first = any(
[unsupported_name.lower() in app.__name__.lower() for unsupported_name in
MODELS_UNSUPPORTED_CHANNELS_FIRST])
if image_data_format == "channels_first" and does_not_support_channels_first:
self.skipTest(
"{} does not support channels first".format(app.__name__)
)

@parameterized.named_parameters(test_parameters_with_image_data_format)
def test_application_base(self, app, _, image_data_format):
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
backend.set_image_data_format(image_data_format)
# Can be instantiated with default arguments
model = app(weights=None)
# Can be serialized and deserialized
Expand All @@ -162,36 +190,47 @@ def test_application_base(self, app, _):
self.assertEqual(len(model.weights), len(reconstructed_model.weights))
backend.clear_session()

@parameterized.parameters(*MODEL_LIST)
def test_application_notop(self, app, last_dim):
@parameterized.named_parameters(test_parameters_with_image_data_format)
def test_application_notop(self, app, last_dim, image_data_format):
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
backend.set_image_data_format(image_data_format)
if image_data_format == "channels_first":
input_shape = (3, None, None)
correct_output_shape = (None, last_dim, None, None)
channels_axis = 1
else:
input_shape = (None, None, 3)
correct_output_shape = (None, None, None, last_dim)
channels_axis = -1

if "NASNet" in app.__name__:
only_check_last_dim = True
else:
only_check_last_dim = False
output_shape = _get_output_shape(
lambda: app(weights=None, include_top=False)
)
output_shape = app(weights=None, include_top=False, input_shape=input_shape).output_shape
if only_check_last_dim:
self.assertEqual(output_shape[-1], last_dim)
self.assertEqual(output_shape[channels_axis], last_dim)
else:
self.assertShapeEqual(output_shape, (None, None, None, last_dim))
self.assertShapeEqual(output_shape, correct_output_shape)
backend.clear_session()

@parameterized.parameters(*MODEL_LIST)
def test_application_notop_custom_input_shape(self, app, last_dim):
output_shape = _get_output_shape(
lambda: app(
weights="imagenet", include_top=False, input_shape=(224, 224, 3)
)
)
@parameterized.named_parameters(test_parameters_with_image_data_format)
def test_application_notop_custom_input_shape(self, app, last_dim, image_data_format):
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
backend.set_image_data_format(image_data_format)
if image_data_format == "channels_first":
input_shape = (3, 224, 224)
channels_axis = 1
else:
input_shape = (224, 224, 3)
channels_axis = -1
output_shape = app(weights="imagenet", include_top=False, input_shape=input_shape).output_shape

self.assertEqual(output_shape[-1], last_dim)
self.assertEqual(output_shape[channels_axis], last_dim)

@parameterized.parameters(MODEL_LIST)
def test_application_pooling(self, app, last_dim):
output_shape = _get_output_shape(
lambda: app(weights=None, include_top=False, pooling="avg")
)
output_shape = app(weights=None, include_top=False, pooling="avg").output_shape
self.assertShapeEqual(output_shape, (None, last_dim))

@parameterized.parameters(MODEL_LIST)
Expand All @@ -204,30 +243,28 @@ def test_application_classifier_activation(self, app, _):
last_layer_act = model.layers[-1].activation.__name__
self.assertEqual(last_layer_act, "softmax")

@parameterized.parameters(*MODEL_LIST_NO_NASNET)
def test_application_variable_input_channels(self, app, last_dim):
@parameterized.named_parameters(test_parameters_with_image_data_format)
def test_application_variable_input_channels(self, app, last_dim, image_data_format):
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
backend.set_image_data_format(image_data_format)
if backend.image_data_format() == "channels_first":
input_shape = (1, None, None)
correct_output_shape = (None, last_dim, None, None)
else:
input_shape = (None, None, 1)
output_shape = _get_output_shape(
lambda: app(
weights=None, include_top=False, input_shape=input_shape
)
)
self.assertShapeEqual(output_shape, (None, None, None, last_dim))
correct_output_shape = (None, None, None, last_dim)
output_shape = app(weights=None, include_top=False, input_shape=input_shape).output_shape

self.assertShapeEqual(output_shape, correct_output_shape)
backend.clear_session()

if backend.image_data_format() == "channels_first":
input_shape = (4, None, None)
else:
input_shape = (None, None, 4)
output_shape = _get_output_shape(
lambda: app(
weights=None, include_top=False, input_shape=input_shape
)
)
self.assertShapeEqual(output_shape, (None, None, None, last_dim))
output_shape = app(weights=None, include_top=False, input_shape=input_shape).output_shape

self.assertShapeEqual(output_shape, correct_output_shape)
backend.clear_session()

@parameterized.parameters(*MOBILENET_V3_FOR_WEIGHTS)
Expand All @@ -242,9 +279,12 @@ def test_mobilenet_v3_load_weights(
include_top=include_top,
)

@parameterized.parameters(MODEL_LIST)
@parameterized.named_parameters(test_parameters_with_image_data_format)
@test_utils.run_v2_only
def test_model_checkpoint(self, app, _):
def test_model_checkpoint(self, app, _, image_data_format):
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)
backend.set_image_data_format(image_data_format)

model = app(weights=None)

checkpoint = tf.train.Checkpoint(model=model)
Expand All @@ -256,10 +296,5 @@ def test_model_checkpoint(self, app, _):
checkpoint_manager.save(checkpoint_number=1)


def _get_output_shape(model_fn):
model = model_fn()
return model.output_shape


if __name__ == "__main__":
tf.test.main()

0 comments on commit c52e562

Please sign in to comment.