# Info Extraction

It is much easier to extract model information from a PyTorch module than from ONNX: ONNX doesn't provide the output shapes.

In [1]:
import onnx

# Load the ONNX model (path is relative to the current working directory).
# `model` is a notebook-level global: get_graph_order() and
# find_sequencial_nodes() below read it directly.
model = onnx.load("onnx/vgg19.onnx")

# Check that the IR is well formed
onnx.checker.check_model(model)

# Print a human readable representation of the graph
# (inputs, initializers, then one line per node)
print(onnx.helper.printable_graph(model.graph))

graph torch-jit-export (
  %1[FLOAT, 1x3x224x224]
) initializers (
  %2[FLOAT, 64x3x3x3]
  %3[FLOAT, 64]
  %4[FLOAT, 64x64x3x3]
  %5[FLOAT, 64]
  %6[FLOAT, 128x64x3x3]
  %7[FLOAT, 128]
  %8[FLOAT, 128x128x3x3]
  %9[FLOAT, 128]
  %10[FLOAT, 256x128x3x3]
  %11[FLOAT, 256]
  %12[FLOAT, 256x256x3x3]
  %13[FLOAT, 256]
  %14[FLOAT, 256x256x3x3]
  %15[FLOAT, 256]
  %16[FLOAT, 256x256x3x3]
  %17[FLOAT, 256]
  %18[FLOAT, 512x256x3x3]
  %19[FLOAT, 512]
  %20[FLOAT, 512x512x3x3]
  %21[FLOAT, 512]
  %22[FLOAT, 512x512x3x3]
  %23[FLOAT, 512]
  %24[FLOAT, 512x512x3x3]
  %25[FLOAT, 512]
  %26[FLOAT, 512x512x3x3]
  %27[FLOAT, 512]
  %28[FLOAT, 512x512x3x3]
  %29[FLOAT, 512]
  %30[FLOAT, 512x512x3x3]
  %31[FLOAT, 512]
  %32[FLOAT, 512x512x3x3]
  %33[FLOAT, 512]
  %34[FLOAT, 4096x25088]
  %35[FLOAT, 4096]
  %36[FLOAT, 4096x4096]
  %37[FLOAT, 4096]
  %38[FLOAT, 1000x4096]
  %39[FLOAT, 1000]
) {
  %41 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%1, %

In [2]:
#import onnx_caffe2.backend as backend
import onnx_tf.backend as backend
import numpy as np
import time

  from ._conv import register_converters as _register_converters


## Find Graph Edge (each link)

A Node is an operation, indexed from 0; an Entity is an object (an input or output of a node), named starting from `u'1'` (which corresponds to `%1` in the printed graph).

基本上把每個node跑過一次後，所有的Entity都會摸到 (Basically, once every node has been visited, every Entity will have been touched.)

In [3]:
def get_graph_order(onnx_model=None):
    """Build the two lookup tables used to walk the graph node by node.

    Args:
        onnx_model: object exposing ``graph.node``, where every node has
            ``input`` and ``output`` sequences of entity names. When omitted,
            the notebook-level global ``model`` is used.

    Returns:
        A ``(Node2nextEntity, Entity2nextNode)`` tuple of dicts:
          * ``Node2nextEntity`` maps a node index to the FIRST entity listed
            in that node's outputs.
          * ``Entity2nextNode`` maps an entity name to the index of the FIRST
            node (in graph order) that lists it as an input.
        Additional outputs of a node and additional consumers of an entity
        are deliberately not recorded.
    """
    # Fall back to the notebook global so existing zero-argument calls keep working.
    if onnx_model is None:
        onnx_model = model

    node_to_entity = {}
    entity_to_node = {}
    for node_idx, node in enumerate(onnx_model.graph.node):
        # node input: setdefault keeps the first consumer seen for each entity
        for entity_name in node.input:
            entity_to_node.setdefault(entity_name, node_idx)
        # node output: setdefault keeps only the first output of each node
        for entity_name in node.output:
            node_to_entity.setdefault(node_idx, entity_name)
    return node_to_entity, entity_to_node
      
# Build both lookup tables once; find_sequencial_nodes() reads them as globals.
Node2nextEntity, Entity2nextNode = get_graph_order()

In [4]:
# Sanity check: (nodes with a recorded output, entities with a recorded consumer)
len(Node2nextEntity), len(Entity2nextNode)

(61, 99)

In [5]:
import pickle
# Persist both lookup tables. The `with` blocks close each file
# deterministically, so the pickled bytes are flushed to disk before the
# next cell reads them back.
with open('onnx/vgg19_Node2nextEntity_dict.pkl', 'wb') as node_file:
    pickle.dump(Node2nextEntity, node_file)
with open('onnx/vgg19_Entity2nextNode_dict.pkl', 'wb') as entity_file:
    pickle.dump(Entity2nextNode, entity_file)

## Get Subgroup

In [6]:
import pickle
# Reload the lookup tables from disk; the `with` blocks close each file as
# soon as it has been read.
# NOTE: pickle.load can execute arbitrary code - only load files that this
# notebook itself produced.
with open('onnx/vgg19_Node2nextEntity_dict.pkl', 'rb') as node_file:
    Node2nextEntity = pickle.load(node_file)
with open('onnx/vgg19_Entity2nextNode_dict.pkl', 'rb') as entity_file:
    Entity2nextNode = pickle.load(entity_file)

In [7]:
def find_sequencial_nodes(search_target=['Conv', 'Add', 'Relu', 'MaxPool'], if_print = False,
                          onnx_model=None, node2entity=None, entity2node=None):
    """Find every node that starts a chain of ops matching ``search_target``.

    From each candidate node the chain is followed as
    node -> ``node2entity[node]`` -> ``entity2node[entity]`` -> next node,
    comparing ``op_type`` against the targets in order.

    Args:
        search_target: sequence of op_type strings to match consecutively.
            The default list is only read, never mutated.
        if_print: when True, print a trace of every comparison.
        onnx_model: object exposing ``graph.node`` (items with ``op_type``).
            Defaults to the notebook-level global ``model``.
        node2entity: dict node index -> output entity name. Defaults to the
            notebook-level global ``Node2nextEntity``.
        entity2node: dict entity name -> consumer node index. Defaults to the
            notebook-level global ``Entity2nextNode``.

    Returns:
        List of indices of the nodes at which the whole pattern matches.
        An empty ``search_target`` matches every node.
    """
    # Fall back to the notebook globals so existing calls keep working.
    if onnx_model is None:
        onnx_model = model
    if node2entity is None:
        node2entity = Node2nextEntity
    if entity2node is None:
        entity2node = Entity2nextNode

    nodes = onnx_model.graph.node
    last_step = len(search_target) - 1
    found_nodes = []
    for i, node in enumerate(nodes):
        if if_print: print("\nnode[{}] ...".format(i))
        n_idx = i #init
        is_fit = True
        for step, tar in enumerate(search_target):
            # Explicit comparison instead of assert/except: asserts vanish
            # under `python -O` and a bare except hides unrelated errors.
            if nodes[n_idx].op_type != tar:
                is_fit = False
                if if_print: print("node[{}] doesn't fit op_type [{}]".format(n_idx, tar))
                break
            if if_print: print("node[{}] fit op_type [{}]".format(n_idx, tar))
            # Only look for a successor while targets remain; otherwise a
            # pattern ending on a node without a consumer (e.g. the last node
            # of the graph) would be rejected even though it matched.
            if step < last_step:
                next_idx = entity2node.get(node2entity.get(n_idx))
                if next_idx is None:
                    is_fit = False
                    if if_print: print("node[{}] has no following node to match op_type [{}]".format(n_idx, search_target[step + 1]))
                    break
                n_idx = next_idx

        if is_fit:
            if if_print: print("node[{}] ...fit!".format(i))
            found_nodes.append(i)
        else:
            if if_print: print("node[{}] ...NOT fit!".format(i))
    if if_print: print("\nNode{} fit the matching pattern".format(found_nodes))
    return found_nodes
# Verbose trace for the 3-op pattern, then a silent search for the 4-op pattern;
# as the cell's last expression, the second call's result is what gets displayed.
find_sequencial_nodes(search_target=['Conv', 'Add', 'Relu'], if_print = True)
find_sequencial_nodes(search_target=['Conv', 'Add', 'Relu', 'MaxPool'], if_print = False)


node[0] ...
node[0] fit op_type [Conv]
node[1] fit op_type [Add]
node[2] fit op_type [Relu]
node[0] ...fit!

node[1] ...
node[1] doesn't fit op_type [Conv]
node[1] ...NOT fit!

node[2] ...
node[2] doesn't fit op_type [Conv]
node[2] ...NOT fit!

node[3] ...
node[3] fit op_type [Conv]
node[4] fit op_type [Add]
node[5] fit op_type [Relu]
node[3] ...fit!

node[4] ...
node[4] doesn't fit op_type [Conv]
node[4] ...NOT fit!

node[5] ...
node[5] doesn't fit op_type [Conv]
node[5] ...NOT fit!

node[6] ...
node[6] doesn't fit op_type [Conv]
node[6] ...NOT fit!

node[7] ...
node[7] fit op_type [Conv]
node[8] fit op_type [Add]
node[9] fit op_type [Relu]
node[7] ...fit!

node[8] ...
node[8] doesn't fit op_type [Conv]
node[8] ...NOT fit!

node[9] ...
node[9] doesn't fit op_type [Conv]
node[9] ...NOT fit!

node[10] ...
node[10] fit op_type [Conv]
node[11] fit op_type [Add]
node[12] fit op_type [Relu]
node[10] ...fit!

node[11] ...
node[11] doesn't fit op_type [Conv]
node[11] ...NOT fit!

node[12] ..

[3, 10, 23, 36, 49]

In [8]:
import itertools
def get_permutations(a):
    """Return every ordered arrangement of every subset of ``a``.

    Subsets are visited by increasing size (starting with the empty subset)
    in ``itertools.combinations`` order; each subset is then expanded into
    all of its orderings via ``itertools.permutations``.

    Args:
        a: sequence of items.

    Returns:
        List of tuples; the first element is always the empty tuple ``()``.
    """
    return [
        ordering
        for size in range(len(a) + 1)
        for subset in itertools.combinations(a, size)
        for ordering in itertools.permutations(subset)
    ]
#a = [4,5,6]
#get_permutations(a)    

In [9]:
# Every candidate pattern starts with 'Conv', followed by one ordered arrangement
# of a subset of `followings` (the empty subset yields the bare ['Conv'] pattern).
search_head = ['Conv']
followings = ['Add', 'Relu', 'MaxPool']
search_targets = [ search_head+list(foll) for foll in get_permutations(followings)] 
search_targets

[['Conv'],
 ['Conv', 'Add'],
 ['Conv', 'Relu'],
 ['Conv', 'MaxPool'],
 ['Conv', 'Add', 'Relu'],
 ['Conv', 'Relu', 'Add'],
 ['Conv', 'Add', 'MaxPool'],
 ['Conv', 'MaxPool', 'Add'],
 ['Conv', 'Relu', 'MaxPool'],
 ['Conv', 'MaxPool', 'Relu'],
 ['Conv', 'Add', 'Relu', 'MaxPool'],
 ['Conv', 'Add', 'MaxPool', 'Relu'],
 ['Conv', 'Relu', 'Add', 'MaxPool'],
 ['Conv', 'Relu', 'MaxPool', 'Add'],
 ['Conv', 'MaxPool', 'Add', 'Relu'],
 ['Conv', 'MaxPool', 'Relu', 'Add']]

In [10]:
# Run the search once per candidate pattern; matchings[i] lines up with search_targets[i].
matchings = [find_sequencial_nodes(search_target) for search_target in search_targets]

# Report only the patterns that matched at least one node.
for i,matching in enumerate(matchings):
    if matching!=[]:
        print("\nsearch:{}, \nget matching node:{}".format(search_targets[i],matching))


search:['Conv'], 
get matching node:[0, 3, 7, 10, 14, 17, 20, 23, 27, 30, 33, 36, 40, 43, 46, 49]

search:['Conv', 'Add'], 
get matching node:[0, 3, 7, 10, 14, 17, 20, 23, 27, 30, 33, 36, 40, 43, 46, 49]

search:['Conv', 'Add', 'Relu'], 
get matching node:[0, 3, 7, 10, 14, 17, 20, 23, 27, 30, 33, 36, 40, 43, 46, 49]

search:['Conv', 'Add', 'Relu', 'MaxPool'], 
get matching node:[3, 10, 23, 36, 49]
