In [1]:
from keras.models import Model
from keras.layers import Input, Dense
from keras.layers.wrappers import Wrapper
import pydot

Using TensorFlow backend.


In [2]:
# Symbolic input of width 3 plus two named Dense layers (3 units each).
# The layers are only created here; they are wired into a graph in the next cell.
input_ = Input(shape=(3,))
dense0 = Dense(units=3, name='dense0')
dense1 = Dense(units=3, name='dense1')

In [3]:
# Call the same `dense0` instance twice (a shared layer), then `dense1`.
# Each call adds an inbound node to the layer, which is why `dense0` appears
# as both 'dense0_ib-0' and 'dense0_ib-1' in `model._network_nodes`.
h = dense0(dense0(input_))
output_ = dense1(h)

model = Model(input_, output_)

Instructions for updating:
Colocations handled automatically by placer.


In [4]:
# Private Keras attribute: the set of node keys that belong to this model,
# formatted as '<layer name>_ib-<inbound node index>'. `dense0` contributes
# two keys because it was called twice. Private API — may differ across Keras versions.
model._network_nodes

{'dense0_ib-0', 'dense0_ib-1', 'dense1_ib-0', 'input_1_ib-0'}

In [5]:
def _layer_label(layer, show_layer_names=True):
    """Build the Graphviz 'record' label for one Keras layer.

    The label holds the layer name and class (or only the class when
    `show_layer_names` is False) followed by an input/output shape table.
    For a `Wrapper` layer the wrapped layer's name and class are appended
    in parentheses.
    """
    layer_name = layer.name
    class_name = layer.__class__.__name__
    if isinstance(layer, Wrapper):
        layer_name = '{}({})'.format(layer_name, layer.layer.name)
        child_class_name = layer.layer.__class__.__name__
        class_name = '{}({})'.format(class_name, child_class_name)

    if show_layer_names:
        label = '{}: {}'.format(layer_name, class_name)
    else:
        label = class_name

    # `output_shape` may raise AttributeError when it cannot be determined;
    # fall back to a placeholder instead of failing the whole plot.
    try:
        outputlabels = str(layer.output_shape)
    except AttributeError:
        outputlabels = 'multiple'

    if hasattr(layer, 'input_shape'):
        inputlabels = str(layer.input_shape)
    elif hasattr(layer, 'input_shapes'):
        inputlabels = ', '.join(str(ishape) for ishape in layer.input_shapes)
    else:
        inputlabels = 'multiple'

    # Record-shape syntax: '|' separates fields, '{...}' nests a sub-table.
    return '%s\n|{input:|output:}|{{%s}|{%s}}' % (label,
                                                   inputlabels,
                                                   outputlabels)


def get_nodes_edges(model, show_layer_names=True, rankdir='TB'):
    """Convert a Keras model into a `pydot.Dot` graph of its layers.

    One record-shaped node is created per layer, keyed by ``str(id(layer))``.
    One edge is added per (inbound layer -> layer) connection, counting only
    inbound nodes whose key is in ``model._network_nodes``. A shared layer is
    still a single graph node, so feeding it its own output yields a self-loop.

    Args:
        model: Keras functional model to draw.
        show_layer_names: if True, prefix each label with the layer name.
        rankdir: Graphviz ``rankdir`` graph attribute (default 'TB').

    Returns:
        pydot.Dot: the graph; render it with e.g. ``dot.write(path, format='png')``.
    """
    dot = pydot.Dot()
    dot.set('rankdir', rankdir)
    dot.set('concentrate', True)
    dot.set('dpi', 96)
    dot.set_node_defaults(shape='record')

    # NOTE: relies on private Keras attributes (`_layers`, `_inbound_nodes`,
    # `_network_nodes`), which may change between Keras versions.
    layers = model._layers

    # Create one graph node per layer.
    for layer in layers:
        node = pydot.Node(str(id(layer)),
                          label=_layer_label(layer, show_layer_names))
        dot.add_node(node)

    # Connect the nodes.
    for layer in layers:
        layer_id = str(id(layer))
        assert dot.get_node(layer_id)
        for i, node in enumerate(layer._inbound_nodes):
            # Each inbound node of a layer is keyed '<layer name>_ib-<index>';
            # skip nodes that do not belong to this model.
            node_key = layer.name + '_ib-' + str(i)
            if node_key not in model._network_nodes:
                continue
            for inbound_layer in node.inbound_layers:
                inbound_layer_id = str(id(inbound_layer))
                # The source layer must already have a graph node.
                assert dot.get_node(inbound_layer_id)
                dot.add_edge(pydot.Edge(inbound_layer_id, layer_id))
    return dot

In [26]:
# Build the layer graph and render it through Graphviz into a PNG file.
# The `True` displayed below is the return value of `write`.
dot = get_nodes_edges(model, show_layer_names=True)
dot.write(path='./model.png', format='png')

True

In [52]:
# Peek at the plain dict pydot keeps behind the graph: graph attributes
# (rankdir/concentrate/dpi), the 'node' defaults entry, and one entry per layer node.
dot.obj_dict

{'attributes': {'rankdir': 'TB', 'concentrate': True, 'dpi': 96},
 'name': 'G',
 'type': 'digraph',
 'strict': False,
 'suppress_disconnected': False,
 'simplify': False,
 'current_child_sequence': 8,
 'nodes': {'node': [{'attributes': {'shape': 'record'},
    'type': 'node',
    'parent_graph': <pydot.Dot at 0x11df97048>,
    'parent_node_list': None,
    'sequence': 1,
    'name': 'node',
    'port': None}],
  '4795137944': [{'attributes': {'label': 'input_1: InputLayer\n|{input:|output:}|{{(None, 3)}|{(None, 3)}}'},
    'type': 'node',
    'parent_graph': <pydot.Dot at 0x11df97048>,
    'parent_node_list': None,
    'sequence': 2,
    'name': '4795137944',
    'port': None}],
  '4795137888': [{'attributes': {'label': 'dense0: Dense\n|{input:|output:}|{{(None, 3)}|{(None, 3)}}'},
    'type': 'node',
    'parent_graph': <pydot.Dot at 0x11df97048>,
    'parent_node_list': None,
    'sequence': 3,
    'name': '4795137888',
    'port': None}],
  '4795196248': [{'attributes': {'label': 'd

In [None]:
# Exploration only: list every attribute/method available on the pydot graph.
print(dir(dot))

In [7]:
# Pull the graph's node and edge objects out for inspection in the next cell.
node_list = dot.get_node_list()
edge_list = dot.get_edge_list()
# Graph-level attributes that get_nodes_edges never sets (expected to be None).
label = dot.get_label()
layout = dot.get_layout()

In [8]:
# Dump the edges, graph-level attributes, and nodes of the pydot graph.
for edge in edge_list:
    print('edge', edge)
print('layout', layout)
print('label', label)
for node in node_list:
    print('node', node)
    # Nothing in this notebook assigns a `pos` attribute, so this is None.
    print('pos', node.get_pos())
    

edge 4795137944 -> 4795137888;
edge 4795137888 -> 4795137888;
edge 4795137888 -> 4795196248;
layout None
label None
node node [shape=record];
pos None
node 4795137944 [label="input_1: InputLayer\n|{input:|output:}|{{(None, 3)}|{(None, 3)}}"];
pos None
node 4795137888 [label="dense0: Dense\n|{input:|output:}|{{(None, 3)}|{(None, 3)}}"];
pos None
node 4795196248 [label="dense1: Dense\n|{input:|output:}|{{(None, 3)}|{(None, 3)}}"];
pos None


In [24]:
import io

# DOT source text of the graph (pydot's "raw" format).
s = dot.to_string()
print(s)
# Save the raw DOT text. It is plain text, so it gets a .dot extension, and it
# is written as UTF-8 so non-ASCII layer names behave the same on every platform.
with io.open('raw_image.dot', mode='wt', encoding='utf-8') as f:
    f.write(s)
# Run the Graphviz 'dot' program to render the graph; the PNG is returned
# in memory as bytes rather than written to a file.
img = dot.create(prog='dot', format='png', encoding=None)
print(type(img))  # <class 'bytes'>: PNG-encoded image data
# Rendering does not modify the graph: the DOT source printed here is unchanged.
s_after = dot.to_string()
print(s_after)



digraph G {
concentrate=True;
dpi=96;
rankdir=TB;
node [shape=record];
4795137944 [label="input_1: InputLayer\n|{input:|output:}|{{(None, 3)}|{(None, 3)}}"];
4795137888 [label="dense0: Dense\n|{input:|output:}|{{(None, 3)}|{(None, 3)}}"];
4795196248 [label="dense1: Dense\n|{input:|output:}|{{(None, 3)}|{(None, 3)}}"];
4795137944 -> 4795137888;
4795137888 -> 4795137888;
4795137888 -> 4795196248;
}

<class 'bytes'>
digraph G {
concentrate=True;
dpi=96;
rankdir=TB;
node [shape=record];
4795137944 [label="input_1: InputLayer\n|{input:|output:}|{{(None, 3)}|{(None, 3)}}"];
4795137888 [label="dense0: Dense\n|{input:|output:}|{{(None, 3)}|{(None, 3)}}"];
4795196248 [label="dense1: Dense\n|{input:|output:}|{{(None, 3)}|{(None, 3)}}"];
4795137944 -> 4795137888;
4795137888 -> 4795137888;
4795137888 -> 4795196248;
}

dot


str

In [43]:
# Default attributes that apply to every node / edge in the graph
# (get_nodes_edges only sets a node default: shape=record).
node_defaults = dot.get_node_defaults()
edge_defaults = dot.get_edge_defaults()
print(node_defaults)
print(edge_defaults)

[{'shape': 'record'}]
[]


AttributeError: 'list' object has no attribute 'get_name'

In [48]:
# All node objects of the graph, including pydot's 'node' defaults entry.
all_nodes = dot.get_nodes()
print(all_nodes)
# The first entry is that 'node' defaults statement, not a Keras layer.
n1 = all_nodes[0]
print(dot.get_edges())

[<pydot.Node object at 0x11ddd2fd0>, <pydot.Node object at 0x11dfcc710>, <pydot.Node object at 0x11ddd2080>, <pydot.Node object at 0x11de37da0>]
[<pydot.Edge object at 0x11e007780>, <pydot.Edge object at 0x11de6ceb8>, <pydot.Edge object at 0x11e0077b8>]


In [51]:
# Name of the first entry returned by get_nodes(): the 'node' defaults
# statement rather than one of the layer nodes.
n1.get_name()

'node'

In [58]:
import networkx as nx
def setup_europe():
    """Return an undirected networkx graph of (a sample of) European borders.

    Nodes are country names and each edge joins two neighbouring countries.
    England, Wales and Scotland form their own connected component.
    """
    borders = [
        ("Portugal", "Spain"),
        ("Spain", "France"),
        ("France", "Belgium"),
        ("France", "Germany"),
        ("France", "Italy"),
        ("Belgium", "Netherlands"),
        ("Germany", "Belgium"),
        ("Germany", "Netherlands"),
        ("England", "Wales"),
        ("England", "Scotland"),
        ("Scotland", "Wales"),
        ("Switzerland", "Austria"),
        ("Switzerland", "Germany"),
        ("Switzerland", "France"),
        ("Switzerland", "Italy"),
        ("Austria", "Germany"),
        ("Austria", "Italy"),
        ("Austria", "Czech Republic"),
        ("Austria", "Slovakia"),
        ("Austria", "Hungary"),
        ("Denmark", "Germany"),
        ("Poland", "Czech Republic"),
        ("Poland", "Slovakia"),
        ("Poland", "Germany"),
        ("Czech Republic", "Slovakia"),
        ("Czech Republic", "Germany"),
        ("Slovakia", "Hungary"),
    ]
    graph = nx.Graph()
    # Adds edges (and their endpoint nodes) in list order.
    graph.add_edges_from(borders)
    return graph

G = setup_europe()

# Node coordinates computed by Graphviz's `dot` layout program.
pos = nx.nx_agraph.graphviz_layout(G, prog='dot')
print(pos)

# Convert to a pygraphviz AGraph and render it to PNG. `draw` is given its own
# `prog`, so it lays the graph out again; `pos` above is informational only.
agraph = nx.nx_agraph.to_agraph(G)
agraph.draw("europe.png", format='png', prog='dot')

{'Portugal': (322.83, 666.0), 'Spain': (322.83, 594.0), 'France': (322.83, 522.0), 'Belgium': (234.83, 450.0), 'Germany': (239.83, 378.0), 'Italy': (396.83, 378.0), 'Netherlands': (52.834, 306.0), 'England': (440.83, 666.0), 'Wales': (408.83, 594.0), 'Scotland': (441.83, 522.0), 'Switzerland': (393.83, 306.0), 'Austria': (340.83, 234.0), 'Czech Republic': (227.83, 162.0), 'Slovakia': (290.83, 90.0), 'Hungary': (337.83, 18.0), 'Denmark': (242.83, 306.0), 'Poland': (221.83, 18.0)}
