Add Functional serialization.
fchollet committed Apr 24, 2023
1 parent b4df8af commit 2f4f72c
Showing 10 changed files with 422 additions and 95 deletions.
keras_core/backend/keras_tensor.py (1 addition & 1 deletion)
@@ -85,7 +85,7 @@ def __tf_tensor__(self, dtype=None, name=None):
def __repr__(self):
return (
f"<KerasTensor shape={self.shape}, dtype={self.dtype}, "
f"name={self.name}>"
)

def __iter__(self):
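
The one-character fix above matters because only f-strings interpolate: without the prefix, the second fragment is a plain string literal, so repr() would print the text {self.name} verbatim instead of the tensor's name. A minimal standalone illustration (not keras_core code):

    class Demo:
        name = "dense_1"

        def broken(self):
            # Second fragment lacks the f prefix and stays a literal.
            return f"<KerasTensor dtype=float32, " "name={self.name}>"

        def fixed(self):
            # Adjacent fragments still concatenate, and the name is
            # now interpolated.
            return f"<KerasTensor dtype=float32, " f"name={self.name}>"

    print(Demo().broken())  # <KerasTensor dtype=float32, name={self.name}>
    print(Demo().fixed())   # <KerasTensor dtype=float32, name=dense_1>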
keras_core/layers/layer.py (2 additions & 0 deletions)
@@ -127,8 +127,10 @@ def build_from_config(self, config):
if config:
if "input_shape" in config:
self.build(config["input_shape"])
self._build_shapes_dict = config
elif "shapes_dict" in config:
self.build(**config["shapes_dict"])
self._build_shapes_dict = config["shapes_dict"]

def add_variable(
self,
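
The new _build_shapes_dict bookkeeping records how a layer was built so the shapes can be serialized again later. A minimal sketch of the first branch, assuming the package imports as keras_core and that Dense builds from a single input shape (only build_from_config itself appears in this diff):

    import keras_core as keras

    layer = keras.layers.Dense(4)
    config = {"input_shape": (None, 8)}
    layer.build_from_config(config)  # calls layer.build((None, 8))
    assert layer.built
    assert layer._build_shapes_dict == config  # kept for re-serialization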
keras_core/models/functional.py (277 additions & 30 deletions)
@@ -1,3 +1,4 @@
import copy
import inspect
import warnings

@@ -33,11 +34,8 @@ def __new__(cls, *args, **kwargs):
return Function.__new__(cls)

@tracking.no_automatic_dependency_tracking
@python_utils.default
def __init__(self, inputs, outputs, name=None, **kwargs):
if isinstance(inputs, dict):
for k, v in inputs.items():
if not isinstance(v, backend.KerasTensor):
@@ -87,7 +85,7 @@ def __init__(self, inputs, outputs, name=None, **kwargs):
f"Unrecognized type for `outputs`: {outputs} (of type {type(outputs)})"
)

Function.__init__(self, inputs, outputs, name=name, **kwargs)
self._layers = self.layers
self.built = True

@@ -108,9 +106,10 @@ def call(self, inputs, training=None, mask=None):
masks = self._flatten_to_reference_inputs(mask)
for x, mask in zip(inputs, masks):
x._keras_mask = mask
outputs = self._run_through_graph(
inputs, operation_fn=lambda op: operation_fn(op, training=training)
)
return unpack_singleton(outputs)

def compute_output_spec(self, inputs, training=None, mask=None):
# From Function
@@ -188,35 +187,202 @@ def add_loss(self, loss):
# Symbolic only. TODO
raise NotImplementedError

@python_utils.default
def get_config(self):
if not functional_like_constructor(self.__class__):
# Subclassed networks are not serializable
# (unless serialization is implemented by
# the author of the subclassed network).
return Model.get_config(self)

config = {
"name": self.name,
"trainable": self.trainable,
}
# Build a map from a layer's unique name (make_node_key)
# to the index of the nodes that are saved in the config.
# Only nodes in network_nodes are saved.
node_reindexing_map = {}
for operation in self.operations:
if issubclass(operation.__class__, Functional):
# Functional models start with a pre-existing node
# linking their input to output.
kept_nodes = 1
else:
kept_nodes = 0
for original_node_index, node in enumerate(
operation._inbound_nodes
):
node_key = make_node_key(operation, original_node_index)
if node_key in self._nodes:
# i.e. we mark it to be saved
node_reindexing_map[node_key] = kept_nodes
kept_nodes += 1
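
# Illustration (not part of the diff): if a layer owns two inbound
# nodes and both belong to this network, they are re-indexed to 0 and
# 1; a nested Functional layer starts counting at 1 because of its
# pre-existing input-to-output node.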

# Serialize and save the layers in layer_configs.
layer_configs = []
for operation in self.operations: # From the earliest layers on.
filtered_inbound_nodes = []
for original_node_index, node in enumerate(
operation._inbound_nodes
):
node_key = make_node_key(operation, original_node_index)
if node_key in self._nodes:
# The node is relevant to the model:
# add to filtered_inbound_nodes.
node_data = serialize_node(node, node_reindexing_map)
if node_data is not None:
filtered_inbound_nodes.append(node_data)

layer_config = serialization_lib.serialize_keras_object(operation)
layer_config["name"] = operation.name
layer_config["inbound_nodes"] = filtered_inbound_nodes
layer_configs.append(layer_config)
config["layers"] = layer_configs

# Gather info about inputs and outputs.
model_inputs = []
for tensor in self._inputs:
operation = tensor._keras_history[0]
node_index = tensor._keras_history[1]
tensor_index = tensor._keras_history[2]
node_key = make_node_key(operation, node_index)
if node_key not in self._nodes:
continue
new_node_index = node_reindexing_map[node_key]
model_inputs.append([operation.name, new_node_index, tensor_index])
config["input_layers"] = model_inputs
model_outputs = []
for tensor in self._outputs:
operation = tensor._keras_history[0]
node_index = tensor._keras_history[1]
tensor_index = tensor._keras_history[2]
node_key = make_node_key(operation, node_index)
if node_key not in self._nodes:
continue
new_node_index = node_reindexing_map[node_key]
model_outputs.append([operation.name, new_node_index, tensor_index])
config["output_layers"] = model_outputs
return copy.deepcopy(config)
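
# Illustrative shape of the returned config (reconstructed from the
# code above, not printed by it):
#   {
#       "name": ..., "trainable": ...,
#       "layers": [<serialized ops, each with "inbound_nodes">, ...],
#       "input_layers": [[op_name, new_node_index, tensor_index], ...],
#       "output_layers": [[op_name, new_node_index, tensor_index], ...],
#   }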

@classmethod
def from_config(cls, config, custom_objects=None):
"""Instantiates a Model from its config (output of `get_config()`)."""
# Layer instances created during
# the graph reconstruction process
created_layers = {}

# Dictionary mapping layer instances to
# node data that specifies a layer call.
# It acts as a queue that maintains any unprocessed
# layer call until it becomes possible to process it
# (i.e. until the input tensors to the call all exist).
unprocessed_nodes = {}

def add_unprocessed_node(layer, node_data):
"""Add node to layer list
Args:
layer: layer object
node_data: Node data specifying layer call
"""
if layer not in unprocessed_nodes:
unprocessed_nodes[layer] = [node_data]
else:
unprocessed_nodes[layer].append(node_data)

def process_node(layer, node_data):
"""Reconstruct node by linking to inbound layers
Args:
layer: Layer to process
node_data: node data specifying the layer call
"""
args, kwargs = deserialize_node(node_data, created_layers)
# Call layer on its inputs, thus creating the node
# and building the layer if needed.
layer(*args, **kwargs)

def process_layer(layer_data):
"""Deserializes a layer and enqueues its calls for later processing.
Args:
layer_data: layer config dict.
"""
layer_name = layer_data["name"]

# Instantiate layer.
layer = serialization_lib.deserialize_keras_object(
layer_data, custom_objects=custom_objects
)
created_layers[layer_name] = layer

# Gather layer inputs.
inbound_nodes_data = layer_data["inbound_nodes"]
for node_data in inbound_nodes_data:
# We don't process nodes (i.e. make layer calls)
# on the fly because the inbound node may not yet exist,
# in the case of layers shared at different topological depths
# (e.g. a model such as A(B(A(B(x)))))
add_unprocessed_node(layer, node_data)

# First, we create all layers and enqueue nodes to be processed
for layer_data in config["layers"]:
process_layer(layer_data)

# Then we process nodes in order of layer depth.
# Nodes that cannot yet be processed (if the inbound node
# does not yet exist) are re-enqueued, and the process
# is repeated until all nodes are processed.
while unprocessed_nodes:
for layer_data in config["layers"]:
layer = created_layers[layer_data["name"]]

# Process all nodes in layer, if not yet processed
if layer in unprocessed_nodes:
node_data_list = unprocessed_nodes[layer]

# Process nodes in order
node_index = 0
while node_index < len(node_data_list):
node_data = node_data_list[node_index]
try:
process_node(layer, node_data)

# If the node does not have all inbound layers
# available, stop processing and continue later
except IndexError:
break

node_index += 1

# If not all nodes processed then store unprocessed nodes
if node_index < len(node_data_list):
unprocessed_nodes[layer] = node_data_list[node_index:]
# If all nodes processed remove the layer
else:
del unprocessed_nodes[layer]

# Create lists of input and output tensors and return the new model.
name = config.get("name")
input_tensors = []
output_tensors = []
for layer_data in config["input_layers"]:
layer_name, node_index, tensor_index = layer_data
assert layer_name in created_layers
layer = created_layers[layer_name]
layer_output_tensors = layer._inbound_nodes[
node_index
].output_tensors
input_tensors.append(layer_output_tensors[tensor_index])
for layer_data in config["output_layers"]:
layer_name, node_index, tensor_index = layer_data
assert layer_name in created_layers
layer = created_layers[layer_name]
layer_output_tensors = layer._inbound_nodes[
node_index
].output_tensors
output_tensors.append(layer_output_tensors[tensor_index])
return cls(inputs=input_tensors, outputs=output_tensors, name=name)


def operation_fn(operation, training):
Expand All @@ -239,5 +405,86 @@ def functional_like_constructor(cls):
return False


def unpack_singleton(x):
if len(x) == 1:
return x[0]
return x
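
# Illustration: unpack_singleton([x]) returns x, so single-output
# models yield a bare tensor; unpack_singleton([x, y]) returns the
# list unchanged.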


def serialize_node(node, node_reindexing_map):
if not node.input_tensors:
# Does not need to be serialized.
return

args = node.arguments.args
kwargs = node.arguments.kwargs
return {
"args": serialization_lib.serialize_keras_object(args),
"kwargs": serialization_lib.serialize_keras_object(kwargs),
}
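
# Illustration (an assumption, not shown in the diff): a node called
# with one positional tensor and no keyword arguments serializes to
# {"args": <serialized args tuple>, "kwargs": {}}.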


def deserialize_node(node_data, created_layers):
"""Return (args, kwargs) for calling the node layer."""
if not node_data:
return [], {}

if isinstance(node_data, list):
# Legacy case.
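# Illustrative legacy format (inferred from the parsing below):
# node_data is a list with one entry per input tensor, either
# [layer_name, node_index, tensor_index] or
# [layer_name, node_index, tensor_index, kwargs].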
input_tensors = []
for input_data in node_data:
inbound_layer_name = input_data[0]
inbound_node_index = input_data[1]
inbound_tensor_index = input_data[2]
if len(input_data) == 3:
kwargs = {}
elif len(input_data) == 4:
kwargs = input_data[3]
else:
raise ValueError(
"Cannot deserialize the model (invalid config data?)"
)
inbound_layer = created_layers[inbound_layer_name]

# Raise an error if the corresponding layer node
# has not yet been created
if len(inbound_layer._inbound_nodes) <= inbound_node_index:
raise IndexError(
"Layer node index out of bounds.\n"
f"inbound_layer = {inbound_layer}\n"
f"inbound_layer._inbound_nodes = {inbound_layer._inbound_nodes}\n"
f"inbound_node_index = {inbound_node_index}"
)
inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
input_tensors.append(
inbound_node.output_tensors[inbound_tensor_index]
)
return [unpack_singleton(input_tensors)], kwargs

args = serialization_lib.deserialize_keras_object(node_data["args"])
kwargs = serialization_lib.deserialize_keras_object(node_data["kwargs"])

def convert_revived_tensor(x):
if isinstance(x, backend.KerasTensor):
history = x._pre_serialization_keras_history
if history is None:
return x
layer = created_layers.get(history[0], None)
if layer is None:
raise ValueError(f"Unknown layer: {history[0]}")
inbound_node_index = history[1]
inbound_tensor_index = history[2]
if len(layer._inbound_nodes) <= inbound_node_index:
raise ValueError(
"Layer node index out of bounds.\n"
f"inbound_layer = {layer}\n"
f"inbound_layer._inbound_nodes = {layer._inbound_nodes}\n"
f"inbound_node_index = {inbound_node_index}"
)
inbound_node = layer._inbound_nodes[inbound_node_index]
return inbound_node.output_tensors[inbound_tensor_index]
return x

args = nest.map_structure(convert_revived_tensor, args)
kwargs = nest.map_structure(convert_revived_tensor, kwargs)
return args, kwargs
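
Taken together, the new get_config/from_config pair gives Functional models a config round trip. A minimal usage sketch, assuming the package imports as keras_core, exposes Input, Model, and layers.Dense, and that Model.from_config dispatches to the Functional implementation above (none of this appears in the diff):

    import keras_core as keras

    inputs = keras.Input(shape=(8,), name="x")
    h = keras.layers.Dense(16, activation="relu")(inputs)
    outputs = keras.layers.Dense(1)(h)
    model = keras.Model(inputs, outputs, name="mlp")

    config = model.get_config()              # layers, nodes, input/output maps
    clone = keras.Model.from_config(config)  # replays the graph layer by layer
    assert clone.name == "mlp"
    assert len(clone.layers) == len(model.layers)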
