Skip to content

Commit

Permalink
Merge pull request #15286 from ddrakard:plot_model_show_activations
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 399077719
  • Loading branch information
tensorflower-gardener committed Sep 26, 2021
2 parents f4dd5a7 + 7658bfc commit 0f3e1ac
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
6 changes: 2 additions & 4 deletions keras/utils/vis_utils.py
Expand Up @@ -290,10 +290,8 @@ def format_shape(shape):
[format_shape(ishape) for ishape in layer.input_shapes])
else:
inputlabels = '?'
label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label,
inputlabels,
outputlabels)

label = '{%s}|{input:|output:}|{{%s}}|{{%s}}' % (label, inputlabels,
outputlabels)
if not expand_nested or not isinstance(
layer, functional.Functional):
node = pydot.Node(layer_id, label=label)
Expand Down
19 changes: 16 additions & 3 deletions keras/utils/vis_utils_test.py
Expand Up @@ -105,7 +105,20 @@ def test_plot_model_with_add_loss(self):
except ImportError:
pass

def test_plot_model_cnn_with_activations(self):
@parameterized.parameters({
'show_shapes': False,
'show_dtype': False
}, {
'show_shapes': False,
'show_dtype': True
}, {
'show_shapes': True,
'show_dtype': False
}, {
'show_shapes': True,
'show_dtype': True
})
def test_plot_model_cnn_with_activations(self, show_shapes, show_dtype):
model = keras.Sequential()
model.add(
keras.layers.Conv2D(
Expand All @@ -120,8 +133,8 @@ def test_plot_model_cnn_with_activations(self):
vis_utils.plot_model(
model,
to_file=dot_img_file,
show_shapes=True,
show_dtype=True,
show_shapes=show_shapes,
show_dtype=show_dtype,
show_layer_activations=True)
self.assertTrue(tf.io.gfile.exists(dot_img_file))
tf.io.gfile.remove(dot_img_file)
Expand Down

0 comments on commit 0f3e1ac

Please sign in to comment.