In [1]:
# Required by the graph viewer itself:
# !pip install protobuf numpy

# Additionally required for the ONNX and Chainer examples in this notebook:
# !pip install onnx chainer

In [2]:
from graphviewer.proto.attr_value_pb2 import AttrValue
from graphviewer.proto.graph_pb2 import GraphDef
from graphviewer.proto.node_def_pb2 import NodeDef
from graphviewer.proto.tensor_shape_pb2 import TensorShapeProto
from graphviewer.proto.versions_pb2 import VersionDef

def get_graphdef_from_file(path):
    """Load an ONNX model file and convert its graph to a ``GraphDef``.

    Args:
        path: Location of the ONNX model, passed unchanged to ``onnx.load``.

    Returns:
        The ``GraphDef`` produced by ``parse_onnx`` for ``model.graph``.
    """
    # Imported at call time, so onnx is only needed when this helper runs.
    import onnx
    model = onnx.load(path)
    # Call the ONNX parser by its dedicated name: the generic name ``parse``
    # is rebound to the Chainer parser further down in this notebook, and
    # resolving ``parse`` at call time would then pick the wrong function.
    return parse_onnx(model.graph)


def parse_onnx(graph):
    """Convert an ONNX graph into a ``GraphDef`` for the graph viewer.

    Every entry of ``graph.input`` and ``graph.output`` becomes a node with
    op ``'Variable'`` carrying ``dtype`` and ``shape`` attributes.  Every
    entry of ``graph.node`` becomes a node named after its first output,
    with the ONNX op type as ``op`` and a text rendering of its attributes
    stored under the ``'parameters'`` attribute.

    Args:
        graph: The ``graph`` attribute of a loaded ONNX model.

    Returns:
        A ``GraphDef`` holding the converted nodes.
    """
    import itertools

    nodes = []

    for value_info in itertools.chain(graph.input, graph.output):
        tensor_type = value_info.type.tensor_type
        shapeproto = TensorShapeProto(dim=[
            # A symbolic (dim_param) or unset dimension has no dim_value and
            # would read back as the default 0; emit -1 instead, which is the
            # TensorFlow TensorShapeProto convention for an unknown size.
            TensorShapeProto.Dim(
                size=d.dim_value if d.HasField('dim_value') else -1)
            for d in tensor_type.shape.dim])
        nodes.append(NodeDef(
            name=value_info.name.encode(encoding='utf_8'),
            op='Variable',
            input=[],
            attr={
                'dtype': AttrValue(type=tensor_type.elem_type),
                'shape': AttrValue(shape=shapeproto),
            }
        ))

    for node in graph.node:
        # Each attribute is rendered as its set field values joined by ' = '
        # (e.g. name and value); attributes are separated by ', '.
        attr = ', '.join(
            ' = '.join(str(field[1]) for field in s.ListFields())
            for s in node.attribute
        ).encode(encoding='utf_8')
        nodes.append(NodeDef(
            name=node.output[0].encode(encoding='utf_8'),
            op=node.op_type,
            input=node.input,
            attr={'parameters': AttrValue(s=attr)},
        ))

    return GraphDef(node=nodes, versions=VersionDef(producer=22))


# Backward-compatible alias: the ONNX parser was originally exposed as ``parse``.
parse = parse_onnx

In [3]:
from IPython.display import display
from IPython.display import HTML
import numpy

from graphviewer.proto.graph_pb2 import GraphDef


def strip_consts(graph_def, max_const_size=32):
    """Return a copy of ``graph_def`` with large constant payloads replaced.

    Each node is copied into a fresh ``GraphDef``.  For nodes whose op is
    ``'Const'`` and whose ``attr['value'].tensor.tensor_content`` is longer
    than ``max_const_size``, the content is replaced by a short placeholder
    stating how many bytes were removed.  ``graph_def`` itself is not
    modified.

    Args:
        graph_def: Graph whose ``node`` entries are copied.
        max_const_size: Largest ``tensor_content`` length that is kept as is.

    Returns:
        The new ``GraphDef`` with oversized constants stripped.
    """
    strip_def = GraphDef()
    for n0 in graph_def.node:
        n = strip_def.node.add()
        n.MergeFrom(n0)
        if n.op != 'Const':
            continue
        tensor = n.attr['value'].tensor
        size = len(tensor.tensor_content)
        if size > max_const_size:
            # tensor_content is a ``bytes`` field in TensorFlow's TensorProto;
            # assigning a text string raises TypeError on Python 3, so the
            # placeholder is encoded before assignment.
            placeholder = '<stripped {:d} bytes>'.format(size)
            tensor.tensor_content = placeholder.encode('utf_8')
    return strip_def


def show_graph(graph_def, max_const_size=32):
    """Render a graph inline using the TensorBoard ``tf-graph-basic`` widget.

    The graph is passed through ``strip_consts``, converted to text with
    ``str`` and embedded in a small HTML page that imports the widget from
    ``tensorboard.appspot.com``.  That page is shown inside an ``<iframe>``
    via its ``srcdoc`` attribute.

    Args:
        graph_def: The graph to show.  If the object has an ``as_graph_def``
            method, its result is used instead.
        max_const_size: Forwarded to ``strip_consts``.
    """
    if hasattr(graph_def, 'as_graph_def'):
        graph_def = graph_def.as_graph_def()
    strip_def = strip_consts(graph_def, max_const_size=max_const_size)
    code = """
        <script>
        function load() {{
            document.getElementById("{id}").pbtxt = {data};
        }}
        </script>
        <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html"
          onload=load()>
        <div style="height:600px">
        <tf-graph-basic id="{id}"></tf-graph-basic>
        </div>
    """.format(data=repr(str(strip_def)), id='graph'+str(numpy.random.rand()))

    # The page lives inside a double-quoted attribute, where the browser
    # decodes character references.  Escape '&' first so that entity-like
    # text in the graph (e.g. a literal '&quot;') survives unchanged, then
    # escape the double quotes that would otherwise end the attribute.
    escaped = code.replace('&', '&amp;').replace('"', '&quot;')
    iframe = """
        <iframe seamless style="width:960px;height:720px;border:0" srcdoc="{}"></iframe>
    """.format(escaped)
    display(HTML(iframe))

In [4]:
# Convert the ONNX model stored in 'model.onnx' to a GraphDef and render it inline.
gdef = get_graphdef_from_file('model.onnx')
show_graph(gdef)

In [5]:
from collections import Counter

import chainer
from chainer.computational_graph import build_computational_graph

from graphviewer.parser.dtypes import convert_dtype


def parse(outputs):
    """Convert the Chainer computational graph ending at ``outputs`` to a ``GraphDef``.

    The graph is obtained from ``build_computational_graph([outputs])``.
    Nodes without any incoming edge are emitted with op ``'Variable'`` and
    ``dtype``/``shape`` attributes.  Every other node is emitted with its
    class name as op, the names of its predecessor nodes as inputs and its
    ``label`` stored under the ``'parameters'`` attribute.

    Args:
        outputs: The output variable whose computational graph is converted.

    Returns:
        A ``GraphDef`` holding one node per node of the computational graph.
    """
    cgraph = build_computational_graph([outputs])

    nodes = []
    # Maps id(node) -> list of nodes that have an edge pointing to it.
    input_dict = {}
    for head, tail in cgraph.edges:
        input_dict.setdefault(id(tail), []).append(head)

    name_cnt = Counter()
    id_to_name = {}

    def name_resolver(node):
        """Return a stable, unique name for ``node``, creating it on first use."""
        name = id_to_name.get(id(node), None)
        if name is not None:
            return name
        if isinstance(node, chainer.variable.VariableNode):
            name = 'Variable{:d}'.format(name_cnt['Variable'])
            name_cnt['Variable'] += 1
        else:
            # Function nodes are numbered per label, e.g. the first node
            # with a given label gets the suffix 0.
            name = '{}{:d}'.format(node.label, name_cnt[node.label])
            name_cnt[node.label] += 1
        id_to_name[id(node)] = name
        return name

    for node in cgraph.nodes:
        assert isinstance(node, (
            chainer.variable.VariableNode, chainer.function_node.FunctionNode))

        if id(node) not in input_dict:
            shapeproto = TensorShapeProto(
                dim=[TensorShapeProto.Dim(size=s) for s in node.shape])
            nodes.append(NodeDef(
                name=name_resolver(node).encode(encoding='utf_8'),
                op='Variable',
                input=[],
                attr={
                    'dtype': AttrValue(type=convert_dtype(node.dtype)),
                    # Key was misspelled 'shpae'; 'shape' matches the key
                    # used for variables by the ONNX parser in this notebook.
                    'shape': AttrValue(shape=shapeproto),
                }
            ))
        else:
            inputs = [name_resolver(n).encode(encoding='utf_8') for n in input_dict[id(node)]]
            attr = node.label.encode(encoding='utf_8')  # TODO
            nodes.append(NodeDef(
                name=name_resolver(node).encode(encoding='utf_8'),
                op=node.__class__.__name__,
                input=inputs,
                attr={'parameters': AttrValue(s=attr)},
            ))
    return GraphDef(node=nodes, versions=VersionDef(producer=22))

In [6]:
import chainer.functions as F
import chainer.links as L

# Network definition
class Net(chainer.Chain):
    """Demo network: two convolution layers followed by two linear layers."""

    def __init__(self):
        """Register the four layers inside the chain's ``init_scope``."""
        super(Net, self).__init__()
        with self.init_scope():
            # Input sizes are given as None so that each layer infers them.
            self.conv1 = L.Convolution2D(None, 10, ksize=5)
            self.conv2 = L.Convolution2D(None, 20, ksize=5)
            self.l1 = L.Linear(None, 50)
            self.l2 = L.Linear(None, 10)

    def forward(self, x):
        """Apply the layers to ``x`` and return the output of ``l2``."""
        # First stage: convolution -> 2x2 max pooling -> ReLU.
        h = self.conv1(x)
        h = F.max_pooling_2d(h, 2)
        h = F.relu(h)
        # Second stage: same pattern with the second convolution.
        h = self.conv2(h)
        h = F.max_pooling_2d(h, 2)
        h = F.relu(h)
        # Linear layer with ReLU, then dropout before the final linear layer.
        h = F.relu(self.l1(h))
        h = F.dropout(h)
        return self.l2(h)

# Wrap the network in L.Classifier; the wrapped model is called with an
# input and a label, i.e. as model(x, t), later in this notebook.
model = L.Classifier(Net())

In [7]:
# Dummy data for one forward pass: a single 1x28x28 float32 input and one int32 label.
x = chainer.Variable(numpy.random.rand(1, 1, 28, 28).astype(numpy.float32))
# NOTE: numpy.random.rand() samples from [0, 1), so the int32 cast always yields label 0.
t = chainer.Variable(numpy.random.rand(1).astype(numpy.int32))
# Calling the classifier yields the variable whose graph is parsed in the next cell.
y = model(x, t)

In [8]:
# Build a GraphDef from the Chainer computational graph ending at y and render it.
# This relies on ``parse`` being the Chainer parser (the definition taking
# ``outputs``), which must have been run after the ONNX cells.
gdef = parse(y)
show_graph(gdef)