In [1]:

import onnx
import numpy as np
import onnx_graphsurgeon as gs
from onnxsim import simplify

In [2]:
# Load the raw exported model and wrap it in a GraphSurgeon graph so that
# individual nodes can be inspected.
onnx_raw = onnx.load("./pointpillar_raw.onnx")
graph = gs.import_onnx(onnx_raw)

# Take the first ConvTranspose node in graph order and show its inputs,
# outputs and attributes.
first_ConvTranspose_node = [n for n in graph.nodes if n.op == "ConvTranspose"][0]
print(first_ConvTranspose_node)

ConvTranspose_243 (ConvTranspose)
	Inputs: [
		Variable (427): (shape=None, dtype=None)
		Constant (backbone_2d.deblocks.0.0.weight): (shape=[64, 128, 1, 1], dtype=<class 'numpy.float32'>)
	]
	Outputs: [
		Variable (428): (shape=None, dtype=None)
	]
Attributes: OrderedDict([('dilations', [1, 1]), ('group', 1), ('kernel_shape', [1, 1]), ('pads', [0, 0, 0, 0]), ('strides', [1, 1])])


In [3]:
# Collect every node whose first input is the first output of
# `first_ConvTranspose_node`.
# NOTE: despite the variable name, these are the consumers of the
# ConvTranspose output, not of a Concat node.
first_node_after_concat = [
    candidate
    for candidate in graph.nodes
    if len(candidate.inputs) != 0
    and len(first_ConvTranspose_node.outputs) != 0
    and candidate.inputs[0] == first_ConvTranspose_node.outputs[0]
]
print(first_node_after_concat[0])

BatchNormalization_244 (BatchNormalization)
	Inputs: [
		Variable (428): (shape=None, dtype=None)
		Constant (backbone_2d.deblocks.0.1.weight): (shape=[128], dtype=<class 'numpy.float32'>)
		Constant (backbone_2d.deblocks.0.1.bias): (shape=[128], dtype=<class 'numpy.float32'>)
		Constant (backbone_2d.deblocks.0.1.running_mean): (shape=[128], dtype=<class 'numpy.float32'>)
		Constant (backbone_2d.deblocks.0.1.running_var): (shape=[128], dtype=<class 'numpy.float32'>)
	]
	Outputs: [
		Variable (429): (shape=None, dtype=None)
	]
Attributes: OrderedDict([('epsilon', 0.0010000000474974513), ('momentum', 0.9900000095367432)])


In [4]:
@gs.Graph.register()
def replace_with_clip(self, inputs, outputs):
    """Insert a ``PPScatterPlugin`` node between ``inputs`` and ``outputs``.

    Registered on ``gs.Graph``, so it is called as
    ``graph.replace_with_clip(inputs, outputs)``. The name is historical: the
    node that gets created is a ``PPScatterPlugin`` layer named
    ``PPScatter_0``, not a Clip.

    Args:
        inputs: Tensors fed to the new node. The ``outputs`` list of each one
            is cleared first.
        outputs: Tensors produced by the new node. The ``inputs`` list of each
            one is cleared first.

    Returns:
        Whatever ``self.layer(...)`` returns for the newly created node.
    """
    # Cut the existing links on both sides before wiring in the new node.
    for tensor in inputs:
        tensor.outputs.clear()
    for tensor in outputs:
        tensor.inputs.clear()

    # `dense_shape` is hard-coded here; presumably the (y, x) size of the
    # scattered grid -- TODO confirm against the plugin implementation.
    return self.layer(
        name="PPScatter_0",
        op="PPScatterPlugin",
        inputs=inputs,
        outputs=outputs,
        attrs={"dense_shape": np.array([496, 432])},
    )

def loop_node(graph, current_node, loop_time=0):
  """Walk ``loop_time`` steps downstream from ``current_node``.

  At every step the successor is the first node in ``graph.nodes`` whose first
  input is the first output of the current node.

  Args:
    graph: Object exposing a ``nodes`` sequence (e.g. a GraphSurgeon graph).
    current_node: Node to start from; must expose ``inputs`` and ``outputs``.
    loop_time: Number of steps to take. With ``0`` (the default) the start
      node itself is returned.

  Returns:
    The node reached after ``loop_time`` steps.

  Raises:
    IndexError: If a step is requested from a node that has no outputs, or
      whose first output is not the first input of any node.
  """
  for _ in range(loop_time):
    # The output to follow does not depend on the candidate, so check it once.
    if len(current_node.outputs) == 0:
      raise IndexError(
          "node %r has no outputs to follow" % (getattr(current_node, "name", current_node),))
    followed = current_node.outputs[0]

    successors = [
        node for node in graph.nodes
        if len(node.inputs) != 0 and node.inputs[0] == followed
    ]
    if not successors:
      raise IndexError(
          "no node consumes the first output of %r as its first input"
          % (getattr(current_node, "name", current_node),))
    current_node = successors[0]

  # Returning `current_node` (rather than a loop-local variable) keeps the
  # zero-step case well defined.
  return current_node

def simplify_postprocess(onnx_model):
  """Cut the model so that it ends at three freshly declared output tensors.

  The graph keeps only the ``voxels``, ``voxel_idxs`` and ``voxel_num`` inputs
  and gets ``cls_preds``, ``box_preds`` and ``dir_cls_preds`` as outputs. Each
  new output is attached to a ``Transpose`` node located one step after a
  consumer of the ``Concat`` node, which itself is expected three steps after
  the first ``ConvTranspose``.

  Args:
    onnx_model: ONNX model to rewrite; it must contain tensors named
      ``voxels``, ``voxel_idxs`` and ``voxel_num``.

  Returns:
    The rewritten model as produced by ``gs.export_onnx``.

  Raises:
    AssertionError: If the node found where a ``Concat`` or ``Transpose`` is
      expected has a different op.
  """
  print("Use onnx_graphsurgeon to adjust postprocessing part in the onnx...")
  graph = gs.import_onnx(onnx_model)

  # New graph outputs with hard-coded static shapes.
  new_outputs = [
      gs.Variable(name="cls_preds", dtype=np.float32, shape=(1, 248, 216, 18)),
      gs.Variable(name="box_preds", dtype=np.float32, shape=(1, 248, 216, 42)),
      gs.Variable(name="dir_cls_preds", dtype=np.float32, shape=(1, 248, 216, 12)),
  ]

  tensors_by_name = graph.tensors()
  new_inputs = [tensors_by_name[name] for name in ("voxels", "voxel_idxs", "voxel_num")]

  # Disconnect every current graph input that is not being kept ...
  for graph_input in graph.inputs:
    if graph_input not in new_inputs:
      graph_input.outputs.clear()

  # ... and every current graph output from its producer.
  for graph_output in graph.outputs:
    graph_output.inputs.clear()

  conv_transposes = [node for node in graph.nodes if node.op == "ConvTranspose"]
  concat_node = loop_node(graph, conv_transposes[0], 3)
  assert concat_node.op == "Concat"

  concat_consumers = [
      node for node in graph.nodes
      if len(node.inputs) != 0
      and len(concat_node.outputs) != 0
      and node.inputs[0] == concat_node.outputs[0]
  ]

  # The i-th consumer of the Concat output is paired with the i-th new output;
  # this assumes graph order is cls, box, dir -- TODO confirm.
  for head_index, head_output in enumerate(new_outputs):
    transpose_node = loop_node(graph, concat_consumers[head_index], 1)
    assert transpose_node.op == "Transpose"
    transpose_node.outputs = [head_output]

  graph.inputs = new_inputs
  graph.outputs = new_outputs
  graph.cleanup().toposort()
  return gs.export_onnx(graph)



def simplify_preprocess(onnx_model):
  """Insert the ``PPScatterPlugin`` node and redefine the graph inputs.

  Steps, in order:
    1. Locate the first ``MatMul`` node and walk five successors downstream;
       the node reached gets its ``keepdims`` attribute overwritten.
    2. Call ``graph.replace_with_clip`` to place a ``PPScatterPlugin`` node
       between that node's first output (plus the new ``voxel_idxs`` and
       ``voxel_num`` tensors) and the first input of the first ``Conv`` node.
    3. Clean up the graph, then set its inputs and outputs.

  Side effects: writes ``tmp_model_01.onnx`` .. ``tmp_model_04.onnx`` into the
  current working directory as intermediate snapshots.

  Args:
    onnx_model: ONNX model containing tensors named ``voxels``,
      ``cls_preds``, ``box_preds`` and ``dir_cls_preds``.

  Returns:
    The rewritten model as produced by ``gs.export_onnx``.

  Raises:
    IndexError: If the graph has no ``Conv`` or ``MatMul`` node, or the chain
      after the first ``MatMul`` is shorter than expected.
    KeyError: If one of the named tensors is missing.
  """
  print("Use onnx_graphsurgeon to modify onnx...")
  graph = gs.import_onnx(onnx_model)

  tmap = graph.tensors()
  MAX_VOXELS = tmap["voxels"].shape[0]

  # voxels: [V, P, C']
  # V is the maximum number of voxels per frame
  # P is the maximum number of points per voxel
  # C' is the number of channels (features) per point in voxels.
  input_new = gs.Variable(name="voxels", dtype=np.float32, shape=(MAX_VOXELS, 32, 10))

  # voxel_idxs: [V, 4]
  # V is the maximum number of voxels per frame
  # 4 is the length of the index encoded as (frame_id, z, y, x).
  X = gs.Variable(name="voxel_idxs", dtype=np.int32, shape=(MAX_VOXELS, 4))

  # voxel_num: [1]
  # Gives the number of valid voxels for each frame
  Y = gs.Variable(name="voxel_num", dtype=np.int32, shape=(1,))

  # The new node's output feeds the first input of the first Conv node.
  first_node_after_pillarscatter = [node for node in graph.nodes if node.op == "Conv"][0]

  # The walk below starts at the first MatMul node.
  first_node_pillarvfe = [node for node in graph.nodes if node.op == "MatMul"][0]

  next_node = current_node = first_node_pillarvfe

  for i in range(6):
    # Nodes without inputs are skipped so that indexing `inputs[0]` cannot
    # fail on them (same guard as `loop_node`).
    next_node = [
        node for node in graph.nodes
        if len(node.inputs) != 0 and node.inputs[0] == current_node.outputs[0]
    ][0]
    if i == 5:              # ReduceMax
      # NOTE(review): ONNX declares `keepdims` as a single int; a list is
      # stored here and left unchanged -- confirm what the consumer expects.
      current_node.attrs['keepdims'] = [0]
      break
    current_node = next_node

  last_node_pillarvfe = current_node

  # Merge the layers between `inputs` and `outputs` into one plugin node.
  inputs = [last_node_pillarvfe.outputs[0], X, Y]
  outputs = [first_node_after_pillarscatter.inputs[0]]
  graph.replace_with_clip(inputs, outputs)

  # Snapshot 01: plugin inserted, nothing cleaned up yet.
  tmp_model_01 = gs.export_onnx(graph)
  onnx.save(tmp_model_01, "tmp_model_01.onnx")

  # Remove the now-dangling subgraph.
  graph.cleanup().toposort()

  # Snapshot 02: after the first cleanup.
  tmp_model_02 = gs.export_onnx(graph)
  onnx.save(tmp_model_02, "tmp_model_02.onnx")

  # Keep only the layers between these inputs and outputs.
  graph.inputs = [first_node_pillarvfe.inputs[0], X, Y]
  graph.outputs = [tmap["cls_preds"], tmap["box_preds"], tmap["dir_cls_preds"]]

  # Snapshot 03: new inputs/outputs assigned, not yet cleaned up.
  tmp_model_03 = gs.export_onnx(graph)
  onnx.save(tmp_model_03, "tmp_model_03.onnx")

  graph.cleanup()

  # Snapshot 04: after the second cleanup.
  tmp_model_04 = gs.export_onnx(graph)
  onnx.save(tmp_model_04, "tmp_model_04.onnx")

  # Declare `input_new` as the first graph input.
  # NOTE(review): `input_new` is not attached to any node here (an earlier
  # version assigned it to the first MatMul's `inputs[0]`); presumably it takes
  # effect only because it shares the name "voxels" with the tensor that node
  # consumes -- confirm.
  graph.inputs = [input_new, X, Y]

  graph.cleanup().toposort()
  print('-------------------------END-------------------------')

  return gs.export_onnx(graph)
  
  
  
# Load the raw exported model.
onnx_raw = onnx.load("./pointpillar_raw.onnx")

# Stage 1: rewrite the output side of the graph and save a snapshot.
onnx_simp = simplify_postprocess(onnx_raw)
onnx.save(onnx_simp, "pointpillar_simple11.onnx")

# Stage 2: run onnx-simplifier. `check` reports whether the simplified model
# passed onnxsim's validation; fail loudly instead of silently saving a model
# that could not be validated.
onnx_simp, check = simplify(onnx_simp)
assert check, "Simplified ONNX model could not be validated"
onnx.save(onnx_simp, "pointpillar_simple111.onnx")

# Stage 3: insert the scatter plugin / rewrite the input side and save the
# final model.
onnx_simp = simplify_preprocess(onnx_simp)
onnx.save(onnx_simp, "pointpillar_simple22.onnx")


Use onnx_graphsurgeon to adjust postprocessing part in the onnx...
Use onnx_graphsurgeon to modify onnx...
[W] colored module is not installed, will not use colors when logging. To enable colors, please install the colored module: python3 -m pip install colored
[W] Found distinct tensors that share the same name:
[id: 139986394002288] Variable (voxel_idxs): (shape=[10000, 4], dtype=int32)
[id: 139986393825632] Variable (voxel_idxs): (shape=(10000, 4), dtype=<class 'numpy.int32'>)
Note: Producer node(s) of first tensor:
[]
Producer node(s) of second tensor:
[]
[W] colored module is not installed, will not use colors when logging. To enable colors, please install the colored module: python3 -m pip install colored
[W] Found distinct tensors that share the same name:
[id: 139986394002480] Variable (voxel_num): (shape=[1], dtype=int32)
[id: 139986393825344] Variable (voxel_num): (shape=(1,), dtype=<class 'numpy.int32'>)
Note: Producer node(s) of first tensor:
[]
Producer node(s) of second t

In [None]:
# Simplify a separately exported model with onnx-simplifier.
onnx_raw = onnx.load("./pp.onnx")  # load onnx model
onnx_simp, check = simplify(onnx_raw)
# `check` is onnxsim's validation flag; do not save a model that failed it.
assert check, "Simplified ONNX model could not be validated"
onnx.save(onnx_simp, "pp_simple.onnx")