### ONNX (Open Neural Network Exchange)

In [2]:
from torch import nn
import torch

### Define a demo model with PyTorch

In [3]:
class DemoModel(nn.Module):
    """Toy three-input network used to demonstrate ONNX export.

    Branches (output sizes follow from the layer definitions below):
      * ``features1``: 3x3 conv (1 -> 3 channels, stride 1, padding 1),
        2x2 max-pool, ReLU. A 1x256x256 input becomes 3x128x128 features.
      * ``features2``: same structure. A 1x64x64 input becomes 3x32x32 features.
      * ``features3``: Linear(10 -> 5) + ReLU on a 10-dim vector.

    The three branches are flattened per sample, concatenated, and fed to a
    linear classifier with 4 outputs. The classifier's input width is
    hard-coded for 256x256 (branch 1) and 64x64 (branch 2) spatial inputs, so
    other spatial sizes fail at the classifier.
    """

    def __init__(self):
        super().__init__()
        self.features1 = nn.Sequential(
            nn.Conv2d(1, 3, 3, 1, 1),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )
        self.features2 = nn.Sequential(
            nn.Conv2d(1, 3, 3, 1, 1),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )
        self.features3 = nn.Sequential(
            nn.Linear(10, 5),
            nn.ReLU(),
        )

        # 3*128*128 (branch 1) + 3*32*32 (branch 2) + 5 (branch 3) = 52229
        self.classifier = nn.Linear(128*128*3 + 32*32*3 + 5, 4)

    def forward(self, x1, x2, x3):
        """Run all three branches and classify their concatenated features.

        Args:
            x1: input to ``features1``, expected (batch, 1, 256, 256).
            x2: input to ``features2``, expected (batch, 1, 64, 64).
            x3: input to ``features3``, expected (batch, 10).

        Returns:
            Tensor of shape (batch, 4).
        """
        # Flatten each branch to (batch, features). This avoids `.data`, which
        # bypasses autograd and is discouraged for merely reading a size.
        x1 = torch.flatten(self.features1(x1), 1)
        x2 = torch.flatten(self.features2(x2), 1)
        x3 = torch.flatten(self.features3(x3), 1)

        x = torch.cat((x1, x2, x3), dim=1)
        return self.classifier(x)

### Construct the Inputs

In [4]:
# Seed the global RNG so these random inputs (and the randomly initialised
# weights of the model built in the next cell) are reproducible on re-run.
torch.manual_seed(0)

# Dummy inputs for the three branches of DemoModel, batch size 1.
x1 = torch.randn(1, 1, 256, 256)  # input for features1
x2 = torch.randn(1, 1, 64, 64)    # input for features2
x3 = torch.randn(1, 10)           # input for features3

In [5]:
# Build the network and run one forward pass as a sanity check before export.
model = DemoModel()
output = model(x1=x1, x2=x2, x3=x3)

### Export the PyTorch model to ONNX

In [6]:
# Trace `model` on the sample inputs and write the graph to demo.onnx.
# verbose=True prints the traced graph as the cell output.
torch.onnx.export(
    model,
    (x1, x2, x3),
    "demo.onnx",
    export_params=True,             # store the model's parameters in the file
    opset_version=9,
    input_names=["x1", "x2", "x3"],
    output_names=["output"],
    verbose=True,
)

graph(%x1 : Float(1, 1, 256, 256),
      %x2 : Float(1, 1, 64, 64),
      %x3 : Float(1, 10),
      %features1.0.weight : Float(3, 1, 3, 3),
      %features1.0.bias : Float(3),
      %features2.0.weight : Float(3, 1, 3, 3),
      %features2.0.bias : Float(3),
      %features3.0.weight : Float(5, 10),
      %features3.0.bias : Float(5),
      %classifier.weight : Float(4, 52229),
      %classifier.bias : Float(4)):
  %11 : Float(1, 3, 256, 256) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%x1, %features1.0.weight, %features1.0.bias) # /opt/anaconda3/envs/tf2/lib/python3.8/site-packages/torch/nn/modules/conv.py:341:0
  %12 : Float(1, 3, 128, 128) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%11) # /opt/anaconda3/envs/tf2/lib/python3.8/site-packages/torch/nn/functional.py:487:0
  %13 : Float(1, 3, 128, 128) = onnx::Relu(%12) # /opt/anaconda3/envs/tf2/lib/python3.8/site-packages/torch/nn/functional.py:914:0
  %14

### ONNX and ONNX Runtime

In [7]:
import onnx
import onnxruntime as rt

In [8]:
# Read the exported file back as a ModelProto.
onnx_model = onnx.load_model("./demo.onnx")

### GraphProto

In [9]:
# GraphProto of the loaded model; its `initializer` list holds the stored tensors.
graph = onnx_model.graph

In [None]:
# Displaying the raw GraphProto would also dump every initializer's raw bytes
# (the classifier weight alone is 4 x 52229 floats). Print the compact,
# human-readable text form of the graph instead.
print(onnx.helper.printable_graph(graph))

In [10]:
# Inspect the first stored tensor (a TensorProto). Its printed form below shows
# the name, dims, data type and raw bytes.
graph.initializer[0]

dims: 4
data_type: 1
name: "classifier.bias"
raw_data: "@\252\306\270\320h\2179\231\367a\2736\004h;"

### Helper functions (`onnx.helper`)

In [11]:
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto

In [12]:
# Input: value info for a float tensor of shape [2, 3, 7, 7]
X = helper.make_tensor_value_info(name='X', elem_type=TensorProto.FLOAT, shape=[2, 3, 7, 7])

In [13]:
# Convolution weight: value info for a float tensor of shape [1, 3, 3, 3]
weight = helper.make_tensor_value_info(name='weight', elem_type=TensorProto.FLOAT, shape=[1, 3, 3, 3])

In [14]:
# Convolution bias: value info for a float tensor of shape [1]
bias = helper.make_tensor_value_info(name='bias', elem_type=TensorProto.FLOAT, shape=[1])

### 构建一个简单模型 (Build a simple model from scratch)

In [17]:
# ValueInfoProto for the graph inputs.
# Pad-11 takes `pads` as a 1-D int64 tensor of length 2 * rank(X) = 4, and the
# optional constant value as a scalar with the same element type as X.
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [3, 2])
pads = helper.make_tensor_value_info('pads', TensorProto.INT64, [4])
value = helper.make_tensor_value_info('value', TensorProto.FLOAT, [])

# ValueInfoProto for the graph output
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [3, 4])

# NodeProto - This is based on Pad-11
node_def = helper.make_node(
    'Pad',                  # op_type
    ['X', 'pads', 'value'], # inputs
    ['Y'],                  # outputs
    mode='constant',        # attributes
)

# GraphProto
graph_proto = helper.make_graph(
    [node_def],        # nodes
    'test-model',      # name
    [X, pads, value],  # inputs
    [Y],               # outputs
)

# ModelProto. Pin the default-domain opset to 11 so the Pad node is
# interpreted with the Pad-11 signature used above.
model_proto = helper.make_model(
    graph_proto,
    producer_name='onnx-example',
    opset_imports=[helper.make_opsetid('', 11)],
)

onnx.checker.check_model(model_proto)
print('The model is checked!')

The model is checked!


In [18]:
# Serialize the hand-built ModelProto to disk.
onnx.save_model(model_proto, "./construct_model_from_scratch.onnx")

### Simplify model

In [25]:
class CustomReshape(torch.nn.Module):
    """View a tensor with dims (d0, d1, d2, d3) as (d0, d1, d3, d2).

    This is a ``view`` (reshape), not a transpose: the elements keep their
    memory order and only the shape changes. The target shape is read from
    the input's sizes at run time. When traced for ONNX export (see the output
    below), this yields a Shape/Gather/Unsqueeze/Concat chain feeding a
    Reshape, which makes the module a useful graph-simplification example.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Swap the sizes of the last two dims; the data is not permuted.
        return x.view(x.size(0), x.size(1), x.size(3), x.size(2))

In [27]:
# Instance of the reshape-only module; it has no parameters, so only its traced
# shape logic ends up in the exported graph.
custom_shape = CustomReshape()

In [30]:
# Example input: (2, 3, 4, 5) -> CustomReshape -> (2, 3, 5, 4).
dummy_input = torch.randn(2, 3, 4, 5)

# Trace the module and write the graph. verbose=True prints the traced graph,
# including the Shape/Gather chain produced by the runtime shape lookup.
torch.onnx.export(
    custom_shape,
    (dummy_input,),
    "simplify_model.onnx",
    export_params=True,
    opset_version=9,
    input_names=["x"],
    output_names=["output"],
    verbose=True,
)

graph(%x : Float(2, 3, 4, 5)):
  %1 : Tensor = onnx::Shape(%x)
  %2 : Tensor = onnx::Constant[value={0}]()
  %3 : Long() = onnx::Gather[axis=0](%1, %2) # <ipython-input-25-811817e7beb6>:6:0
  %4 : Tensor = onnx::Shape(%x)
  %5 : Tensor = onnx::Constant[value={1}]()
  %6 : Long() = onnx::Gather[axis=0](%4, %5) # <ipython-input-25-811817e7beb6>:6:0
  %7 : Tensor = onnx::Shape(%x)
  %8 : Tensor = onnx::Constant[value={3}]()
  %9 : Long() = onnx::Gather[axis=0](%7, %8) # <ipython-input-25-811817e7beb6>:6:0
  %10 : Tensor = onnx::Shape(%x)
  %11 : Tensor = onnx::Constant[value={2}]()
  %12 : Long() = onnx::Gather[axis=0](%10, %11) # <ipython-input-25-811817e7beb6>:6:0
  %13 : Tensor = onnx::Unsqueeze[axes=[0]](%3)
  %14 : Tensor = onnx::Unsqueeze[axes=[0]](%6)
  %15 : Tensor = onnx::Unsqueeze[axes=[0]](%9)
  %16 : Tensor = onnx::Unsqueeze[axes=[0]](%12)
  %17 : Tensor = onnx::Concat[axis=0](%13, %14, %15, %16)
  %output : Float(2, 3, 5, 4) = onnx::Reshape(%x, %17) # <ipython-input-25-811817