Skip to content

Commit

Permalink
Move get_source_inputs (keras-team#10415)
Browse files Browse the repository at this point in the history
  • Loading branch information
taehoonlee authored and fchollet committed Jun 13, 2018
1 parent a40f335 commit 8e5b853
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 43 deletions.
44 changes: 1 addition & 43 deletions keras/engine/network.py
Expand Up @@ -19,6 +19,7 @@
from .. import backend as K
from ..utils.io_utils import ask_to_proceed_with_overwrite
from ..utils.layer_utils import print_summary as print_layer_summary
from ..utils.layer_utils import get_source_inputs
from ..utils.generic_utils import has_arg
from ..utils.generic_utils import to_list
from ..utils.generic_utils import object_list_uid
Expand Down Expand Up @@ -1270,49 +1271,6 @@ def summary(self, line_length=None, positions=None, print_fn=None):
print_fn=print_fn)


def get_source_inputs(tensor, layer=None, node_index=None):
"""Returns the list of input tensors necessary to compute `tensor`.
Output will always be a list of tensors
(potentially with 1 element).
# Arguments
tensor: The tensor to start from.
layer: Origin layer of the tensor. Will be
determined via tensor._keras_history if not provided.
node_index: Origin node index of the tensor.
# Returns
List of input tensors.
"""
if not hasattr(tensor, '_keras_history'):
return tensor

if layer is None or node_index:
layer, node_index, _ = tensor._keras_history
if not layer._inbound_nodes:
return [tensor]
else:
node = layer._inbound_nodes[node_index]
if not node.inbound_layers:
# Reached an Input layer, stop recursion.
return node.input_tensors
else:
source_tensors = []
for i in range(len(node.inbound_layers)):
x = node.input_tensors[i]
layer = node.inbound_layers[i]
node_index = node.node_indices[i]
previous_sources = get_source_inputs(x,
layer,
node_index)
# Avoid input redundancy.
for x in previous_sources:
if x not in source_tensors:
source_tensors.append(x)
return source_tensors


def _make_node_key(layer_name, node_index):
return layer_name + '_ib-' + str(node_index)

Expand Down
1 change: 1 addition & 0 deletions keras/utils/__init__.py
Expand Up @@ -18,6 +18,7 @@
from .generic_utils import deserialize_keras_object
from .generic_utils import Progbar
from .layer_utils import convert_all_kernels_in_model
from .layer_utils import get_source_inputs
from .layer_utils import print_summary
from .vis_utils import plot_model
from .np_utils import to_categorical
Expand Down
43 changes: 43 additions & 0 deletions keras/utils/layer_utils.py
Expand Up @@ -246,3 +246,46 @@ def convert_dense_weights_data_format(dense,
ki = np.transpose(ki, (1, 2, 0)) # first -> last
kernel[:, i] = np.reshape(ki, (np.prod(previous_feature_map_shape),))
dense.set_weights([kernel, bias])


def get_source_inputs(tensor, layer=None, node_index=None):
"""Returns the list of input tensors necessary to compute `tensor`.
Output will always be a list of tensors
(potentially with 1 element).
# Arguments
tensor: The tensor to start from.
layer: Origin layer of the tensor. Will be
determined via tensor._keras_history if not provided.
node_index: Origin node index of the tensor.
# Returns
List of input tensors.
"""
if not hasattr(tensor, '_keras_history'):
return tensor

if layer is None or node_index:
layer, node_index, _ = tensor._keras_history
if not layer._inbound_nodes:
return [tensor]
else:
node = layer._inbound_nodes[node_index]
if not node.inbound_layers:
# Reached an Input layer, stop recursion.
return node.input_tensors
else:
source_tensors = []
for i in range(len(node.inbound_layers)):
x = node.input_tensors[i]
layer = node.inbound_layers[i]
node_index = node.node_indices[i]
previous_sources = get_source_inputs(x,
layer,
node_index)
# Avoid input redundancy.
for x in previous_sources:
if x not in source_tensors:
source_tensors.append(x)
return source_tensors

0 comments on commit 8e5b853

Please sign in to comment.