Skip to content

Commit

Permalink
[P] [RELNOTES] Add ability to visualize wrapped models with plot_mode…
Browse files Browse the repository at this point in the history
…l function (#11431)

* Added ability to expand nested models for visualisation and ability to specify plot dpi

* Documented vis_utils.py new parameters

* Documentation and test for vis_utils.py

* Pep8 fixes

* Documentation adjustments, cluster labeling to vis_utils.py

* Returned missing dots

* PEP8 Fix for comment line
  • Loading branch information
yoks authored and fchollet committed Oct 25, 2018
1 parent dbf8062 commit 882302d
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 20 deletions.
4 changes: 3 additions & 1 deletion docs/templates/visualization.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ from keras.utils import plot_model
plot_model(model, to_file='model.png')
```

`plot_model` takes two optional arguments:
`plot_model` takes four optional arguments:

- `show_shapes` (defaults to False) controls whether output shapes are shown in the graph.
- `show_layer_names` (defaults to True) controls whether layer names are shown in the graph.
- `expand_nested` (defaults to False) controls whether to expand nested models into clusters in the graph.
- `dpi` (defaults to 96) controls image dpi.

You can also directly obtain the `pydot.Graph` object and render it yourself,
for example to show it in an ipython notebook :
Expand Down
73 changes: 55 additions & 18 deletions keras/utils/vis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ def _check_pydot():
def model_to_dot(model,
show_shapes=False,
show_layer_names=True,
rankdir='TB'):
rankdir='TB',
expand_nested=False,
dpi=96,
subgraph=False):
"""Convert a Keras model to dot format.
# Arguments
Expand All @@ -45,35 +48,60 @@ def model_to_dot(model,
a string specifying the format of the plot:
'TB' creates a vertical plot;
'LR' creates a horizontal plot.
expand_nested: whether to expand nested models into clusters.
dpi: dot DPI.
subgraph: whether to return a pydot.Cluster instance.
# Returns
A `pydot.Dot` instance representing the Keras model.
A `pydot.Dot` instance representing the Keras model or
a `pydot.Cluster` instance representing nested model if
`subgraph=True`.
"""
from ..layers.wrappers import Wrapper
from ..models import Model
from ..models import Sequential

_check_pydot()
dot = pydot.Dot()
dot.set('rankdir', rankdir)
dot.set('concentrate', True)
dot.set_node_defaults(shape='record')
if subgraph:
dot = pydot.Cluster(style='dashed')
dot.set('label', model.name)
dot.set('labeljust', 'l')
else:
dot = pydot.Dot()
dot.set('rankdir', rankdir)
dot.set('concentrate', True)
dot.set('dpi', dpi)
dot.set_node_defaults(shape='record')

if isinstance(model, Sequential):
if not model.built:
model.build()
layers = model._layers

# Create graph nodes.
for layer in layers:
for i, layer in enumerate(layers):
layer_id = str(id(layer))

# Append a wrapped layer's label to node's label, if it exists.
layer_name = layer.name
class_name = layer.__class__.__name__
if isinstance(layer, Wrapper):
layer_name = '{}({})'.format(layer_name, layer.layer.name)
child_class_name = layer.layer.__class__.__name__
class_name = '{}({})'.format(class_name, child_class_name)
if expand_nested and isinstance(layer.layer, Model):
submodel = model_to_dot(layer.layer, show_shapes,
show_layer_names, rankdir, expand_nested,
subgraph=True)
model_nodes = submodel.get_nodes()
dot.add_edge(pydot.Edge(layer_id, model_nodes[0].get_name()))
if len(layers) > i + 1:
next_layer_id = str(id(layers[i + 1]))
dot.add_edge(pydot.Edge(
model_nodes[len(model_nodes) - 1].get_name(),
next_layer_id))
dot.add_subgraph(submodel)
else:
layer_name = '{}({})'.format(layer_name, layer.layer.name)
child_class_name = layer.layer.__class__.__name__
class_name = '{}({})'.format(class_name, child_class_name)

# Create node's label.
if show_layer_names:
Expand Down Expand Up @@ -107,20 +135,26 @@ def model_to_dot(model,
node_key = layer.name + '_ib-' + str(i)
if node_key in model._network_nodes:
for inbound_layer in node.inbound_layers:
inbound_layer_id = str(id(inbound_layer))
# Make sure that both nodes exist before connecting them with
# an edge, as add_edge would otherwise create any missing node.
assert dot.get_node(inbound_layer_id)
assert dot.get_node(layer_id)
dot.add_edge(pydot.Edge(inbound_layer_id, layer_id))
if not expand_nested or not (
isinstance(inbound_layer, Wrapper) and
isinstance(inbound_layer.layer, Model)):
inbound_layer_id = str(id(inbound_layer))
# Make sure that both nodes exist before connecting them with
# an edge, as add_edge would otherwise
# create any missing node.
assert dot.get_node(inbound_layer_id)
assert dot.get_node(layer_id)
dot.add_edge(pydot.Edge(inbound_layer_id, layer_id))
return dot


def plot_model(model,
to_file='model.png',
show_shapes=False,
show_layer_names=True,
rankdir='TB'):
rankdir='TB',
expand_nested=False,
dpi=96):
"""Converts a Keras model to dot format and save to a file.
# Arguments
Expand All @@ -132,8 +166,11 @@ def plot_model(model,
a string specifying the format of the plot:
'TB' creates a vertical plot;
'LR' creates a horizontal plot.
expand_nested: whether to expand nested models into clusters.
dpi: dot DPI.
"""
dot = model_to_dot(model, show_shapes, show_layer_names, rankdir)
dot = model_to_dot(model, show_shapes, show_layer_names, rankdir,
expand_nested, dpi)
_, extension = os.path.splitext(to_file)
if not extension:
extension = 'png'
Expand Down
16 changes: 15 additions & 1 deletion tests/keras/utils/vis_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import os
import sys
import numpy as np
from keras.layers import Conv2D
from keras import Input, Model

from keras.layers import Conv2D, Bidirectional
from keras.layers import Dense
from keras.layers import Embedding
from keras.layers import Flatten
Expand All @@ -26,6 +28,18 @@ def test_plot_model():
vis_utils.plot_model(model, to_file='model2.png', show_shapes=True)
os.remove('model2.png')

inner_input = Input(shape=(2, 3), dtype='float32', name='inner_input')
inner_lstm = Bidirectional(LSTM(16, name='inner_lstm'), name='bd')(inner_input)
encoder = Model(inner_input, inner_lstm, name='Encoder_Model')
outer_input = Input(shape=(5, 2, 3), dtype='float32', name='input')
inner_encoder = TimeDistributed(encoder, name='td_encoder')(outer_input)
lstm = LSTM(16, name='outer_lstm')(inner_encoder)
preds = Dense(5, activation='softmax', name='predictions')(lstm)
model = Model(outer_input, preds)
vis_utils.plot_model(model, to_file='model3.png', show_shapes=True,
expand_nested=True, dpi=300)
os.remove('model3.png')


def test_plot_sequential_embedding():
"""Fixes #11376"""
Expand Down

0 comments on commit 882302d

Please sign in to comment.