In [None]:
"""traverse onnx model graph for visualization, requires `onnx` package, simplest version"""

In [None]:
from __future__ import annotations  # must be the first statement of the cell/module

from functools import wraps, cache
from itertools import product
from typing import Generator

import nographs as nog
import onnx
from decorator import decorator
from graphviz import Digraph

In [None]:
# Model to visualize; a relative path is resolved against the notebook's working directory.
ONNX_MODEL_PATH = "inception_v4_no_w.onnx"
onnx_model = onnx.load( ONNX_MODEL_PATH )

In [None]:
# GraphViz attributes handed to Digraph as node_attr (every node) and graph_attr (whole graph).
# NOTE(review): `ranksep` is a graph-level attribute in GraphViz; as a node attribute it is
# presumably ignored — consider moving it to gr_style.
node_style = dict(
    style='filled',
    shape='box',
    align='left',
    fontsize='12',
    ranksep='0.01',
    height='0.02',
    width='0.04',
)
gr_style = dict(linelength='16', rankdir='TB')

def adapter_graphviz( wrapped_frw ):
    """Wrap a successor function so that every step it takes is also drawn with GraphViz.

    ``wrapped_frw(vert, traversal_context)`` may yield either plain labeled edges
    ``(target, label)`` (as ``onnx_walk`` does) or annotated pairs
    ``((target, label), temps)`` (as ``glue_conv_relu`` produces).  Either way the
    returned wrapper draws ``vert`` plus an edge ``vert -> target`` labeled ``label``,
    then yields the plain ``(target, label)`` edge, dropping any ``temps``.
    Assumes targets are not tuples themselves (here they are node-name strings).

    Returns:
        ``(wrapper, graph_dot)``: the wrapping successor function (usable as a
        nographs ``next_labeled_edges``) and the ``Digraph`` it fills.
    """
    graph_dot = Digraph( node_attr = node_style, graph_attr = gr_style, )

    @wraps( wrapped_frw )
    def wrapper_bg_frw( vert, _traversal_context ):
        """visitor wrapper for edge following: draw the vertex, then every outgoing edge"""
        graph_dot.node( vert, vert )
        for item in wrapped_frw( vert, _traversal_context ):
            # Only annotated pairs start with a tuple; a plain edge starts with the
            # target name (a str).  Treating a plain edge as (edge, temps) would draw
            # the first two characters of that name instead of the real edge.
            edge = item[0] if isinstance( item[0], tuple ) else item
            graph_dot.edge( vert, edge[0], edge[1], )
            yield edge
    return wrapper_bg_frw, graph_dot

In [None]:
@cache
def _nodes_by_name():
    """Build (once) a dict mapping node name -> node for ``onnx_model.graph.node``.

    On duplicate names the first node in graph order wins, matching a linear search.
    """
    index = {}
    for node in onnx_model.graph.node:          # pylint: disable=no-member
        index.setdefault( node.name, node )
    return index

@cache
def name2node( name:str ):
    """Return the first node of ``onnx_model.graph`` named ``name``, or None if there is none.

    Pseudo-vertices such as ``'_GRAPH_INPUTS'`` / ``'_GRAPH_OUTPUTS'`` are not graph
    nodes and therefore map to None.  Lookups go through an index built once, so
    resolving every vertex of a traversal is linear rather than quadratic in the
    node count.  Both functions cache their results: after reloading ``onnx_model``
    call ``_nodes_by_name.cache_clear()`` and ``name2node.cache_clear()``.
    """
    return _nodes_by_name().get( name )

def locate_children( incoming_name ):
    """Yield a labeled edge ``(consumer, incoming_name)`` for every consumer of a tensor.

    Consumers are the graph nodes that list ``incoming_name`` among their inputs
    (each node reported once, in graph order), followed by the pseudo-vertex
    ``'_GRAPH_OUTPUTS'`` once per graph output carrying that name.
    """
    consumer_names = (
        child.name
        for child in onnx_model.graph.node       # pylint: disable=no-member
        if incoming_name in child.input
    )
    yield from ( (consumer, incoming_name) for consumer in consumer_names )
    yield from (
        ('_GRAPH_OUTPUTS', incoming_name)
        for output in onnx_model.graph.output    # pylint: disable=no-member
        if output.name == incoming_name
    )

def onnx_walk( nname, _ ):
    """Yield the outgoing labeled edges ``(child_name, tensor_name)`` of vertex ``nname``.

    ``'_GRAPH_INPUTS'`` is a pseudo-vertex whose successors are the consumers of the
    graph inputs.  Any other vertex is resolved as a node name with ``name2node``
    (the shared first-match lookup, instead of a second hand-written scan); its
    successors are the consumers of each of that node's outputs.  Names that are
    not nodes, such as the ``'_GRAPH_OUTPUTS'`` sink, have no successors.

    The second parameter (the traversal context) is unused.
    """
    if nname == '_GRAPH_INPUTS':
        for graph_in in onnx_model.graph.input:  # pylint: disable=no-member
            yield from locate_children( graph_in.name )
        return
    node = name2node( nname )
    if node is None:
        return
    for node_output in node.output:
        yield from locate_children( node_output )

In [None]:
@decorator
def glue_conv_relu( wrapped_frw, vert, _traversal_context ):
    """Pair every edge yielded by ``wrapped_frw`` with a ``temps`` annotation.

    Meant to merge Convolution nodes with the ReLU nodes that follow them.
    NOTE(review): no merging is implemented yet; each edge is passed through
    unchanged and paired with ``temps = 0``.

    Yields:
        ``(edge, temps)`` pairs, where ``edge`` is exactly what ``wrapped_frw``
        yielded for ``(vert, _traversal_context)``.
    """
    annotated = ( (edge, 0) for edge in wrapped_frw( vert, _traversal_context ) )
    yield from annotated

# @decorator
# def clear_temps( wrapped_frw, vert, _traversal_context ):
#     """combining Convolution nodes with ReLU ones"""
#     for result in wrapped_frw( vert, _traversal_context ):
#         yield result

In [None]:
# adapter_graphviz draws and re-yields the edge part of ((child, tensor), temps) pairs;
# glue_conv_relu gives onnx_walk's plain (child, tensor) edges that shape.
(f_1, dot) = adapter_graphviz( glue_conv_relu( onnx_walk ) )

In [None]:
# trav_b = nog.TraversalBreadthFirst( next_labeled_edges= f_1 )
# trav_b.start_from( '_GRAPH_INPUTS' )
# for _ in trav_b:
#     pass

In [None]:
# Show the graph drawn so far. It is filled only while a traversal consumes f_1,
# and no traversal has run at this point (the traversal cell above is disabled).
dot              # pylint: disable=pointless-statement

In [None]:
class SubgraphTemplate:
    """Pattern to search for: currently matches single vertices whose ONNX node is a ``Relu``.

    Args:
        report_match: callable applied to the name of a matching vertex; its result
            is what calling the template yields for that vertex.
    """
    def __init__(self, report_match):
        self.report_match = report_match

    def __call__(self, vert):
        """Yield ``report_match(vert)`` if ``vert`` names a Relu node, otherwise nothing."""
        op_type = getattr( name2node(vert), 'op_type', '' )
        if op_type != 'Relu':
            return
        yield self.report_match(vert)

class LazyGraph():
    """Graph to search in, given lazily by its start vertices and a successor function.

    Args:
        start_vertices: iterable of vertices the traversal starts from.
        next_edges: successor function with signature ``(vertex, traversal)`` that
            yields ``(successor, label)`` edges; passed to nographs as
            ``next_labeled_edges``.
    """
    def __init__(self, start_vertices, next_edges):
        self.start_vertices = start_vertices
        self.next_edges     = next_edges

    def find_subgraphs(
        self,
        sub_template: "SubgraphTemplate",
    ) -> Generator[list, None, None]:
        """Traverse the graph breadth-first and yield the matches found at each vertex.

        For every vertex reported by the traversal, ``find_rooted`` collects what
        ``sub_template`` yields for it; only non-empty match lists are yielded.
        NOTE(review): nographs traversals presumably do not report the start
        vertices themselves, so those are never matched; confirm.
        """
        trav_b = nog.TraversalBreadthFirst( next_labeled_edges= self.next_edges )
        trav_b.start_from( start_vertices= self.start_vertices )
        for vert in trav_b:
            yield from self.find_rooted( sub_template, vert )

    def find_rooted( self, sub_template, vert ):
        """Yield the list of matches ``sub_template(vert)`` produces, if it is non-empty."""
        matches = list( sub_template( vert ) )
        if matches:
            yield matches

In [None]:
# Search the graph reachable from the '_GRAPH_INPUTS' pseudo-vertex by following f_1;
# a match is reported as the vertex name itself. `subs` is a generator: nothing runs yet.
lg = LazyGraph( start_vertices=('_GRAPH_INPUTS',), next_edges=f_1 )
subs = lg.find_subgraphs( SubgraphTemplate( report_match=lambda vert: vert ) )

In [None]:
# `subs` is a lazy generator: `dot` gains vertices and edges only once it is
# consumed, e.g. via `list(subs)` (disabled here).
dot              # pylint: disable=pointless-statement

In [None]:
#(SubgraphTemplate( lambda x:x )(onnx_model.graph.node[2]))