-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Closed
Description
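For the model below, whose `call` method takes a custom keyword argument (`return_aux`), the model trains and saves fine. However, calling the reloaded model with `return_aux=True` raises a `ValueError`: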
class CustomModel(keras.Model):
    """Model with a single Dense(10) layer and an optional auxiliary output.

    ``call`` takes an extra plain-Python flag, ``return_aux``. Its value picks
    which branch runs, and therefore the output structure: a single tensor
    or a 2-tuple.

    NOTE(review): the traceback in this issue shows that the model restored
    from SavedModel only offered concrete functions traced with
    ``return_aux=False``. Do not assume other flag values work after a
    SavedModel round-trip without confirming.
    """

    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(10)

    def call(self, x, return_aux=False):
        """Apply the dense layer and optionally return a doubled copy as well.

        Args:
            x: input passed straight to the Dense layer.
            return_aux: when truthy, also return the dense output times 2.

        Returns:
            The dense output ``y``, or the tuple ``(y, y * 2)`` when
            ``return_aux`` is truthy.
        """
        out = self.dense(x)
        if not return_aux:
            return out
        return out, out * 2
# Reproduction: train, save, reload, then call with the non-default flag.
model = CustomModel()
model.compile(optimizer='adam', loss='mse')
# One epoch on random data: 32 samples of 5 features -> 10 targets
# (matching the Dense(10) output width).
model.fit(tf.random.normal((32, 5)), tf.random.normal((32, 10)), epochs=1)
# No `.keras`/`.h5` suffix on the path. The traceback below shows that
# TensorFlow's SavedModel loader (function_deserialization.py) handles the
# reload.
model.save('custom_model')
loaded_model = keras.models.load_model('custom_model')
# Raises ValueError: none of the restored concrete functions accepts
# return_aux=True. Both listed options were traced with return_aux=False.
y_pred, aux = loaded_model(tf.random.normal((1, 5)), return_aux=True)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[3], line 1
----> 1 y_pred, aux = loaded_model(tf.random.normal((1, 5)), return_aux=True)
File /opt/conda/lib/python3.10/site-packages/keras/utils/traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
67 filtered_tb = _process_traceback_frames(e.__traceback__)
68 # To get the full stack trace, call:
69 # `tf.debugging.disable_traceback_filtering()`
---> 70 raise e.with_traceback(filtered_tb) from None
71 finally:
72 del filtered_tb
File /opt/conda/lib/python3.10/site-packages/tensorflow/python/saved_model/function_deserialization.py:295, in recreate_function.<locals>.restored_function_body(*args, **kwargs)
291 positional, keyword = concrete_function.structured_input_signature
292 signature_descriptions.append(
293 "Option {}:\n {}\n Keyword arguments: {}".format(
294 index + 1, _pretty_format_positional(positional), keyword))
--> 295 raise ValueError(
296 "Could not find matching concrete function to call loaded from the "
297 f"SavedModel. Got:\n {_pretty_format_positional(args)}\n Keyword "
298 f"arguments: {kwargs}\n\n Expected these arguments to match one of the "
299 f"following {len(saved_function.concrete_functions)} option(s):\n\n"
300 f"{(chr(10)+chr(10)).join(signature_descriptions)}")
ValueError: Could not find matching concrete function to call loaded from the SavedModel. Got:
Positional arguments (2 total):
* <tf.Tensor 'x:0' shape=(1, 5) dtype=float32>
* True
Keyword arguments: {}
Expected these arguments to match one of the following 2 option(s):
Option 1:
Positional arguments (2 total):
* TensorSpec(shape=(None, 5), dtype=tf.float32, name='input_1')
* False
Keyword arguments: {}
Option 2:
Positional arguments (2 total):
* TensorSpec(shape=(None, 5), dtype=tf.float32, name='x')
* False
Keyword arguments: {}