Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 10 additions & 18 deletions keras/src/backend/tensorflow/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,26 +94,18 @@ def _default_save_signature(self):
inputs = self.input

if inputs is not None:
input_signature = [
input_signature = (
tree.map_structure(
lambda x: tf.TensorSpec(x.shape, self.compute_dtype),
inputs,
)
]
lambda x: tf.TensorSpec(x.shape, x.dtype), inputs
),
)
else:
shapes_dict = self._build_shapes_dict
if len(shapes_dict) == 1:
input_shape = tuple(shapes_dict.values())[0]
input_signature = [
tf.TensorSpec(input_shape, self.compute_dtype)
]
else:
input_signature = [
tree.map_structure(
lambda x: tf.TensorSpec(x.shape, self.compute_dtype),
shapes_dict,
)
]
input_signature = tuple(
tree.map_shape_structure(
lambda s: tf.TensorSpec(s, self.input_dtype), value
)
for value in self._build_shapes_dict.values()
)

@tf.function(input_signature=input_signature)
def serving_default(inputs):
Expand Down
61 changes: 57 additions & 4 deletions keras/src/backend/tensorflow/saved_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
import numpy as np
import pytest
import tensorflow as tf
from absl.testing import parameterized

from keras.src import backend
from keras.src import layers
from keras.src import metrics
from keras.src import models
from keras.src import ops
from keras.src import optimizers
from keras.src import testing
from keras.src.saving import object_registration
from keras.src.testing.test_utils import named_product


@object_registration.register_keras_serializable(package="my_package")
Expand Down Expand Up @@ -49,7 +52,7 @@ def mutate(self, new_v):
backend.backend() != "tensorflow",
reason="The SavedModel test can only run with TF backend.",
)
class SavedModelTest(testing.TestCase):
class SavedModelTest(testing.TestCase, parameterized.TestCase):
def test_sequential(self):
model = models.Sequential([layers.Dense(1)])
model.compile(loss="mse", optimizer="adam")
Expand Down Expand Up @@ -143,6 +146,52 @@ def call(self, inputs):
atol=1e-4,
)

@parameterized.named_parameters(
named_product(struct_type=["tuple", "array", "dict"])
)
def test_model_with_input_structure(self, struct_type):

class TupleModel(models.Model):

def call(self, inputs):
x, y = inputs
return x + ops.mean(y, axis=1)

class ArrayModel(models.Model):

def call(self, inputs):
x = inputs[0]
y = inputs[1]
return x + ops.mean(y, axis=1)

class DictModel(models.Model):

def call(self, inputs):
x = inputs["x"]
y = inputs["y"]
return x + ops.mean(y, axis=1)

input_x = tf.constant([1.0])
input_y = tf.constant([[1.0, 0.0, 2.0]])
if struct_type == "tuple":
model = TupleModel()
inputs = (input_x, input_y)
elif struct_type == "array":
model = ArrayModel()
inputs = [input_x, input_y]
elif struct_type == "dict":
model = DictModel()
inputs = {"x": input_x, "y": input_y}

result = model(inputs)
path = os.path.join(self.get_temp_dir(), "my_keras_model")
tf.saved_model.save(model, path)
restored_model = tf.saved_model.load(path)
outputs = restored_model.signatures["serving_default"](
inputs=input_x, inputs_1=input_y
)
self.assertAllClose(result, outputs["output_0"], rtol=1e-4, atol=1e-4)

def test_multi_input_model(self):
input_1 = layers.Input(shape=(3,))
input_2 = layers.Input(shape=(5,))
Expand All @@ -169,15 +218,19 @@ def test_multi_input_model(self):
def test_multi_input_custom_model_and_layer(self):
@object_registration.register_keras_serializable(package="my_package")
class CustomLayer(layers.Layer):
def __call__(self, *input_list):
def build(self, *input_shape):
self.built = True

def call(self, *input_list):
self.add_loss(input_list[-2] * 2)
return sum(input_list)

@object_registration.register_keras_serializable(package="my_package")
class CustomModel(models.Model):
def build(self, input_shape):
super().build(input_shape)
def build(self, *input_shape):
self.layer = CustomLayer()
self.layer.build(*input_shape)
self.built = True

@tf.function
def call(self, *inputs):
Expand Down