Skip to content

Commit

Permalink
Improve visualize_util to support container layers with optional recu…
Browse files Browse the repository at this point in the history
…rsion when plotting.

Also add an option to show layer shapes
  • Loading branch information
Julien Rebetez committed Nov 23, 2015
1 parent b059945 commit 98c6c8b
Showing 1 changed file with 139 additions and 33 deletions.
172 changes: 139 additions & 33 deletions keras/utils/visualize_util.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,146 @@
import pydot
# old pydot will not work with python3, must use one
# that works with python3 such as pydot2 or pydot
from keras.models import Sequential, Graph

def to_graph(model):
graph = pydot.Dot(graph_type='digraph')
if type(model) == Sequential:
previous_node = None
written_nodes = []
n = 1
for node in model.get_config()['layers']:
# append number in case layers have same name to differentiate
if (node['name'] + str(n)) in written_nodes:
n += 1
current_node = pydot.Node(node['name'] + str(n))
written_nodes.append(node['name'] + str(n))
graph.add_node(current_node)
if previous_node:
graph.add_edge(pydot.Edge(previous_node, current_node))
previous_node = current_node
elif type(model) == Graph:
# don't need to append number for names since all nodes labeled
for input_node in model.input_config:
graph.add_node(pydot.Node(input_node['name']))

# intermediate and output nodes have input defined
for layer_config in [model.node_config, model.output_config]:
for node in layer_config:
graph.add_node(pydot.Node(node['name']))
# possible to have multiple 'inputs' vs 1 'input'
if node['inputs']:
for e in node['inputs']:
graph.add_edge(pydot.Edge(e, node['name']))
import itertools
from keras.layers.containers import Graph, Sequential
from keras.layers.core import Merge


def layer_typename(layer):
return type(layer).__module__ + "." + type(layer).__name__


def get_layer_to_name(model):
"""Returns a dict mapping layer to their name in the model"""
if not isinstance(model, Graph):
return {}
else:
node_to_name = itertools.chain(
model.nodes.items(), model.inputs.items(), model.outputs.items()
)
return {v: k for k, v in node_to_name}


class ModelToDot(object):
"""
This is a helper class which visits a keras model (Sequential or Graph) and
returns a pydot.Graph representation.
This is implemented as a class because we need to maintain various states.
Use it as ```ModelToDot()(model)```
Keras models can have an arbitrary number of inputs and outputs. A given
layer can have multiple inputs but has a single output. We therefore
explore the model by starting at its output and crawling "up" the tree.
"""
def _pydot_node_for_layer(self, layer, label):
"""
Returns the pydot.Node corresponding to the given layer.
`label` specify the name of the layer (only used if the layer isn't yet
associated with a pydot.Node)
"""
# Check if this already exists (will be the case for nodes that
# serve as input to more than one layer)
if layer in self.layer_to_pydotnode:
node = self.layer_to_pydotnode[layer]
else:
layer_id = 'layer%d' % self.idgen
self.idgen += 1

label = label + " (" + layer_typename(layer) + ")"

if self.show_shape:
# Build the label that will actually contain a table with the
# input/output
outputlabels = str(layer.output_shape)
if hasattr(layer, 'input_shape'):
inputlabels = str(layer.input_shape)
elif hasattr(layer, 'input_shapes'):
inputlabels = ', '.join(
[str(ishape) for ishape in layer.input_shapes])
else:
graph.add_edge(pydot.Edge(node['input'], node['name']))
return graph
inputlabels = ''
label = "%s\n|{input:|output:}|{{%s}|{%s}}" % (
label, inputlabels, outputlabels)

node = pydot.Node(layer_id, label=label)
self.g.add_node(node)
self.layer_to_pydotnode[layer] = node
return node

def _process_layer(self, layer, layer_to_name=None, connect_to=None):
"""
Process a layer, adding its node to the graph and creating edges to its
outputs.
`connect_to` specify where the output of the current layer will be
connected
`layer_to_name` is a dict mapping layer to their name in the Graph
model. Should be {} when processing a Sequential model
"""
# The layer can be a container layer, in which case we can recurse
is_graph = isinstance(layer, Graph)
is_seq = isinstance(layer, Sequential)
if self.recursive and (is_graph or is_seq):
# We got a container layer, recursively transform it
if is_graph:
child_layers = layer.outputs.values()
else:
child_layers = [layer.layers[-1]]
for l in child_layers:
self._process_layer(l, layer_to_name=get_layer_to_name(layer),
connect_to=connect_to)
else:
# This is a simple layer.
label = layer_to_name.get(layer, '')
layer_node = self._pydot_node_for_layer(layer, label=label)

if connect_to is not None:
self.g.add_edge(pydot.Edge(layer_node, connect_to))

# Proceed upwards to the parent(s). Only Merge layers have more
# than one parent
if isinstance(layer, Merge): # Merge layer
for l in layer.layers:
self._process_layer(l, layer_to_name,
connect_to=layer_node)
elif hasattr(layer, 'previous') and layer.previous is not None:
self._process_layer(layer.previous, layer_to_name,
connect_to=layer_node)

def __call__(self, model, recursive=True, show_shape=False,
connect_to=None):
self.idgen = 0
# Maps keras layer to the pydot.Node representing them
self.layer_to_pydotnode = {}
self.recursive = recursive
self.show_shape = show_shape

self.g = pydot.Dot()
self.g.set('rankdir', 'TB')
self.g.set('concentrate', True)
self.g.set_node_defaults(shape='record', fontname="Fira Mono")

if hasattr(model, 'outputs'):
# Graph
for name, l in model.outputs.items():
self._process_layer(l, get_layer_to_name(model),
connect_to=connect_to)
else:
# Sequential container
self._process_layer(model.layers[-1], {}, connect_to=connect_to)

return self.g


def to_graph(model, **kwargs):
"""
`recursive` controls wether we recursively explore container layers
`show_shape` controls wether the shape is shown in the graph
"""
return ModelToDot()(model, **kwargs)


def plot(model, to_file='model.png'):
graph = to_graph(model)
Expand Down

0 comments on commit 98c6c8b

Please sign in to comment.