In [None]:
"""Traverse an ONNX model graph and render it with Graphviz; requires the `onnx`, `nographs` and `graphviz` packages."""

In [None]:
%%capture
from functools import wraps, cache
from collections.abc import Iterable, Callable, Generator
import onnx
import nographs as nog
from graphviz import Digraph

In [None]:
# ONNX file to visualize; change this constant to inspect a different model.
MODEL_PATH = "inception_v4_no_w.onnx"
onnx_model = onnx.load( MODEL_PATH )

In [None]:
# Graphviz attribute sets passed to every Digraph created by adapter_graphviz.
# NOTE(review): per the Graphviz attribute reference `ranksep` applies to graphs,
# not nodes, so setting it here is probably ignored -- confirm before relying on it.
node_style = dict(
    style='filled',
    shape='box',
    align='left',
    fontsize='12',
    ranksep='0.01',
    height='0.02',
    width='0.04',
)
# rankdir='LR' lays the graph out left-to-right instead of top-to-bottom.
gr_style = dict(linelength='16', rankdir='LR')

def adapter_graphviz( wrapped_frw, node_attr=None, graph_attr=None ):
    """Wrap a labeled-edge successor function so that it also records a Graphviz graph.

    Args:
        wrapped_frw: callable ``(vertex, context) -> iterable of edges``; each edge
            is indexable as ``(target_vertex, label, ...)``.
        node_attr: Graphviz node attributes; ``None`` means the module-level
            ``node_style``.
        graph_attr: Graphviz graph attributes; ``None`` means the module-level
            ``gr_style``.

    Returns:
        ``(wrapper, dot)``. ``wrapper`` has the same call signature as
        ``wrapped_frw`` and yields the very same edge objects. ``dot`` is the
        ``Digraph`` that ``wrapper`` fills in: one ``node`` call per expanded vertex,
        and one ``edge`` per forwarded edge, labelled with ``edge[1]``.
    """
    dot = Digraph(
        node_attr = node_style if node_attr is None else node_attr,
        graph_attr = gr_style if graph_attr is None else graph_attr,
    )

    @wraps( wrapped_frw )
    def wrapper_bg_frw( vert, _traversal_context ):
        """Add ``vert`` to ``dot``, then record and forward each outgoing edge."""
        dot.node( vert, vert )
        for edge in wrapped_frw( vert, _traversal_context ):
            # edge[0] is the target vertex, edge[1] the label drawn on the arrow
            dot.edge( vert, edge[0], edge[1] )
            yield edge

    return wrapper_bg_frw, dot

In [None]:
def get_onnx_shape( onnx_tensor_type, keep_symbolic=False ):
    """Return the dimensions of an ONNX tensor type as a list.

    Args:
        onnx_tensor_type: an ONNX tensor type, e.g. ``output.type.tensor_type``.
        keep_symbolic: if True, a dimension that carries a symbolic name
            (``dim_param``, e.g. ``"batch"``) is returned as that string. Reading
            ``dim_value`` for such a dimension yields the protobuf default 0.
            Defaults to False, so existing callers keep receiving ``dim_value`` ints.

    Returns:
        ``dim_value`` of every dimension, in order. With ``keep_symbolic``, a
        dimension whose ``dim_param`` is non-empty contributes that string instead.
    """
    dims = onnx_tensor_type.shape.dim
    if not keep_symbolic:
        return [d.dim_value for d in dims]
    # dim_param is "" when unset, so `or` falls back to the numeric value.
    return [d.dim_param or d.dim_value for d in dims]

In [None]:
%%time

class onnx_walker:
    """Labeled-edge successor function over an ONNX graph (for nographs traversals).

    Vertices are:
    * ONNX node names; an unnamed node gets ``f"{op_type}_{index}"``.
    * One synthetic vertex per graph output, named ``str(shape)`` (prefixed with
      the output name if that vertex name is already taken).
    * The start vertex ``'_GRAPH_INPUTS'``.

    Edge labels are tensor names.
    """
    def __init__( self, graph ):
        # tensor (edge) name -> vertices that consume that tensor
        self.edge_2_out = {}
        # vertex name -> names of the tensors that vertex produces
        self.in_2_edge  = {}

        def push_to_list_in_dict( edge, value ):
            self.edge_2_out.setdefault( edge, [] ).append( value )

        for idx, node in enumerate( graph.node ):
            # ONNX node names are optional; without a fallback every unnamed node
            # would collapse onto the single key ''.
            name = node.name or f"{node.op_type}_{idx}"
            self.in_2_edge[ name ] = node.output
            for inedge in node.input:
                # An empty input name marks an omitted optional input, not a tensor.
                if inedge:
                    push_to_list_in_dict( inedge, name )

        for output in graph.output:
            out_name = str( get_onnx_shape( output.type.tensor_type ) )
            if out_name in self.in_2_edge:
                # Outputs with identical shapes would otherwise merge into one vertex.
                out_name = f"{output.name} {out_name}"
            push_to_list_in_dict( output.name, out_name )
            self.in_2_edge[ out_name ] = ()

        # A list, not a one-shot `map`: the start vertex must be expandable again
        # when the traversal cell is re-run with the same walker.
        self.inputs = [ i.name for i in graph.input ]

    def __call__( self, nname, _ ):
        """Yield ``(child_vertex, tensor_name)`` for every consumer of ``nname``'s outputs.

        The second argument (the traversal object) is ignored.
        """
        if nname == '_GRAPH_INPUTS':
            edges = self.inputs
        else:
            edges = self.in_2_edge[ nname ]
        for edge in edges:
            # A tensor nobody consumes (dangling output, unused input) has no entry.
            for out in self.edge_2_out.get( edge, () ):
                yield ( out, edge )
        
# Index the loaded graph once; the resulting callable is the traversal's successor function.
onnx_walk = onnx_walker( graph = onnx_model.graph )

In [None]:
%%time
# Wrap the walker so every expanded vertex and followed edge is also recorded in `dot`.
f_1, dot = adapter_graphviz( onnx_walk )

# Breadth-first traversal from the synthetic start vertex. Iterating it to
# exhaustion is what drives the wrapper and thereby fills `dot`.
trav_b = nog.TraversalBreadthFirst( next_labeled_edges=f_1 ).start_from( '_GRAPH_INPUTS' )
for _ in trav_b:
    pass

In [None]:
# Last expression of the cell: Jupyter displays the Digraph built by the traversal cell above.
dot              # pylint: disable=pointless-statement

In [None]:
#help(onnx.save)