# In [None]:
from functools import reduce
import sys
import onnx
from collections import defaultdict

def inspect_onnx(model_path):
    """Load an ONNX model and print a summary of its graph to stdout.

    Three sections are printed, in order:

    1. A histogram of node ``op_type`` values (sorted by op type), followed
       by the total number of nodes.
    2. The graph inputs and then the graph outputs, one per line, as
       ``name  elem_type  shape``. ``elem_type`` is the raw integer stored in
       ``type.tensor_type.elem_type``. Each shape entry is the concrete size
       (``dim_value``), the symbolic name (``dim_param``), or ``"?"`` when
       neither is set.
    3. The number of distinct node names, distinct ``value_info`` names, and
       distinct names in the union of the two sets.

    Args:
        model_path: Anything accepted by ``onnx.load`` (typically a path to
            a ``.onnx`` file).

    Returns:
        None. All output goes to stdout via ``print``.
    """
    onnx_model = onnx.load(model_path)
    print(f"====== load model {model_path} done.")

    g = onnx_model.graph

    def _dim_repr(dim):
        # A dimension holds either a concrete size (dim_value) or a symbolic
        # name (dim_param) in the "value" oneof. Reading dim_value alone
        # reports symbolic or unset dims as 0, which is indistinguishable
        # from a genuine zero-length axis.
        kind = dim.WhichOneof("value")
        if kind == "dim_value":
            return dim.dim_value
        if kind == "dim_param":
            return dim.dim_param
        return "?"

    def process_node():
        # Histogram of operator types across all nodes in the graph.
        op_counter = defaultdict(int)
        for node in g.node:
            op_counter[node.op_type] += 1

        print("*" * 50)
        for k, v in sorted(op_counter.items()):
            print(f"{k:40}:  {v}")
        print("*" * 50)
        print("Total OP num: ", sum(op_counter.values()))

    def _print_values(title, values):
        # Shared printer for graph inputs and outputs (both are ValueInfo
        # entries with the same layout).
        print("\n", "*" * 50, title)
        for vi in values:
            tensor_type = vi.type.tensor_type
            shape = [_dim_repr(dim) for dim in tensor_type.shape.dim]
            print(f"{vi.name:40}{tensor_type.elem_type:<4}{shape}")

    def process_inout():
        _print_values("inputs:", g.input)
        _print_values("outputs:", g.output)

    def nodename_info():
        # Compare the set of node names with the set of value_info names.
        # The union size equals |nodes| + |value_info| - |shared names|, so
        # it shows how many names occur in both collections.
        # NOTE(review): value_info is frequently empty unless shape inference
        # has been run on the model — confirm for the models inspected.
        node_name_set = {n.name for n in g.node}
        valinfo_name_set = {n.name for n in g.value_info}

        print(f"node  name len: {len(node_name_set)}")
        print(f"vainf name len: {len(valinfo_name_set)}")

        merged_name_set = node_name_set | valinfo_name_set
        print(f"merge name len: {len(merged_name_set)}")

    process_node()
    process_inout()
    nodename_info()
