diff --git a/keras/engine/network.py b/keras/engine/network.py index 1d6e752729e..3aef5964c79 100644 --- a/keras/engine/network.py +++ b/keras/engine/network.py @@ -19,6 +19,7 @@ from .. import backend as K from ..utils.io_utils import ask_to_proceed_with_overwrite from ..utils.layer_utils import print_summary as print_layer_summary +from ..utils.layer_utils import get_source_inputs from ..utils.generic_utils import has_arg from ..utils.generic_utils import to_list from ..utils.generic_utils import object_list_uid @@ -1270,49 +1271,6 @@ def summary(self, line_length=None, positions=None, print_fn=None): print_fn=print_fn) -def get_source_inputs(tensor, layer=None, node_index=None): - """Returns the list of input tensors necessary to compute `tensor`. - - Output will always be a list of tensors - (potentially with 1 element). - - # Arguments - tensor: The tensor to start from. - layer: Origin layer of the tensor. Will be - determined via tensor._keras_history if not provided. - node_index: Origin node index of the tensor. - - # Returns - List of input tensors. - """ - if not hasattr(tensor, '_keras_history'): - return tensor - - if layer is None or node_index: - layer, node_index, _ = tensor._keras_history - if not layer._inbound_nodes: - return [tensor] - else: - node = layer._inbound_nodes[node_index] - if not node.inbound_layers: - # Reached an Input layer, stop recursion. - return node.input_tensors - else: - source_tensors = [] - for i in range(len(node.inbound_layers)): - x = node.input_tensors[i] - layer = node.inbound_layers[i] - node_index = node.node_indices[i] - previous_sources = get_source_inputs(x, - layer, - node_index) - # Avoid input redundancy. - for x in previous_sources: - if x not in source_tensors: - source_tensors.append(x) - return source_tensors - - def _make_node_key(layer_name, node_index): return layer_name + '_ib-' + str(node_index) diff --git a/keras/utils/__init__.py b/keras/utils/__init__.py index 8f7422a7b23..47438b7d2cf 100644 --- a/keras/utils/__init__.py +++ b/keras/utils/__init__.py @@ -18,6 +18,7 @@ from .generic_utils import deserialize_keras_object from .generic_utils import Progbar from .layer_utils import convert_all_kernels_in_model +from .layer_utils import get_source_inputs from .layer_utils import print_summary from .vis_utils import plot_model from .np_utils import to_categorical diff --git a/keras/utils/layer_utils.py b/keras/utils/layer_utils.py index 7afe01886a1..f7f4b5b7445 100644 --- a/keras/utils/layer_utils.py +++ b/keras/utils/layer_utils.py @@ -246,3 +246,46 @@ def convert_dense_weights_data_format(dense, ki = np.transpose(ki, (1, 2, 0)) # first -> last kernel[:, i] = np.reshape(ki, (np.prod(previous_feature_map_shape),)) dense.set_weights([kernel, bias]) + + +def get_source_inputs(tensor, layer=None, node_index=None): + """Returns the list of input tensors necessary to compute `tensor`. + + Output will always be a list of tensors + (potentially with 1 element). + + # Arguments + tensor: The tensor to start from. + layer: Origin layer of the tensor. Will be + determined via tensor._keras_history if not provided. + node_index: Origin node index of the tensor. + + # Returns + List of input tensors. + """ + if not hasattr(tensor, '_keras_history'): + return tensor + + if layer is None or node_index: + layer, node_index, _ = tensor._keras_history + if not layer._inbound_nodes: + return [tensor] + else: + node = layer._inbound_nodes[node_index] + if not node.inbound_layers: + # Reached an Input layer, stop recursion. + return node.input_tensors + else: + source_tensors = [] + for i in range(len(node.inbound_layers)): + x = node.input_tensors[i] + layer = node.inbound_layers[i] + node_index = node.node_indices[i] + previous_sources = get_source_inputs(x, + layer, + node_index) + # Avoid input redundancy. + for x in previous_sources: + if x not in source_tensors: + source_tensors.append(x) + return source_tensors