In [18]:
import onnx
from onnx import helper
from onnx import TensorProto

# Describe the three graph inputs and the single graph output. All four are
# 10x10 FLOAT tensors, and each Python name matches the ONNX tensor name.
a, x, b, output = (
    helper.make_tensor_value_info(tensor_name, TensorProto.FLOAT, [10, 10])
    for tensor_name in ('a', 'x', 'b', 'output')
)

In [19]:
# output = a * x + b: Mul writes the intermediate tensor 'ax', which Add then
# combines with 'b' to produce the graph output.
mul = helper.make_node('Mul', inputs=['a', 'x'], outputs=['ax'])
add = helper.make_node('Add', inputs=['ax', 'b'], outputs=['output'])

graph = helper.make_graph(
    nodes=[mul, add],
    name='linear_func',
    inputs=[a, x, b],
    outputs=[output],
)

In [20]:
# Wrap the graph in a model. No opset or IR version is passed here, so both
# are left to the installed onnx package.
model = helper.make_model(graph)

In [21]:
# Validate the model first, then dump its text form and write it to disk.
onnx.checker.check_model(model)
print(model)
onnx.save(model, 'linear_func.onnx')

ir_version: 10
opset_import {
  version: 21
}
graph {
  node {
    input: "a"
    input: "x"
    output: "ax"
    op_type: "Mul"
  }
  node {
    input: "ax"
    input: "b"
    output: "output"
    op_type: "Add"
  }
  name: "linear_func"
  input {
    name: "a"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 10
          }
          dim {
            dim_value: 10
          }
        }
      }
    }
  }
  input {
    name: "x"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 10
          }
          dim {
            dim_value: 10
          }
        }
      }
    }
  }
  input {
    name: "b"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 10
          }
          dim {
            dim_value: 10
          }
        }
      }
    }
  }
  output {
    name: "output"
    type {
      tensor_type {
        elem_typ

In [22]:
import onnxruntime
import numpy as np

# Name the execution provider explicitly: onnxruntime builds that ship more
# than the CPU provider can refuse to create a session without this list.
sess = onnxruntime.InferenceSession(
    'linear_func.onnx', providers=['CPUExecutionProvider'])

# Seeded generator so the random inputs are identical on every run.
# NOTE: a, x, b and output are rebound here from the value infos declared
# earlier to NumPy arrays.
rng = np.random.default_rng(0)
a = rng.standard_normal((10, 10), dtype=np.float32)
x = rng.standard_normal((10, 10), dtype=np.float32)
b = rng.standard_normal((10, 10), dtype=np.float32)

# Feed keys are the graph input names; the graph has a single output, so the
# first element of the returned list is the result.
output = sess.run(None, {'a': a, 'x': x, 'b': b})[0]
# Same tolerances as np.allclose, but a failure reports the mismatching values.
np.testing.assert_allclose(output, a * x + b, rtol=1e-5, atol=1e-8)

### Modify model

In [23]:
# Read the saved model back from disk and validate it before editing.
model = onnx.load('linear_func.onnx')
onnx.checker.check_model(model)

In [24]:
# `nodes` refers to the model's own node list, so edits made through it
# change `model` itself.
nodes = model.graph.node
print(nodes)

[input: "a"
input: "x"
output: "ax"
op_type: "Mul"
, input: "ax"
input: "b"
output: "output"
op_type: "Add"
]


In [25]:
# nodes[1] is the 'Add' node in the listing printed above; switching its
# op_type makes the graph compute a * x - b.
nodes[1].op_type = 'Sub'
# Re-validate the edited model before saving it under a new file name.
onnx.checker.check_model(model)

onnx.save(model, 'linear_func_sub.onnx')

In [26]:
# Name the execution provider explicitly: onnxruntime builds that ship more
# than the CPU provider can refuse to create a session without this list.
sess = onnxruntime.InferenceSession(
    'linear_func_sub.onnx', providers=['CPUExecutionProvider'])
# a, x and b are the arrays created for the first inference run.
output = sess.run(None, {'a': a, 'x': x, 'b': b})[0]
# Same tolerances as np.allclose, but a failure reports the mismatching values.
np.testing.assert_allclose(output, a * x - b, rtol=1e-5, atol=1e-8)

### Debug model

In [27]:
import torch 
 
class Model(torch.nn.Module):
    """Conv stem, two parallel conv branches summed element-wise, conv head.

    Every layer is ``Conv2d(3, 3, 3)``. The stem and head hold three layers
    each, and the two branches hold two layers each.
    """

    @staticmethod
    def _conv_stack(depth):
        """Return a Sequential of ``depth`` freshly built Conv2d(3, 3, 3) layers."""
        return torch.nn.Sequential(
            *(torch.nn.Conv2d(3, 3, 3) for _ in range(depth)))

    def __init__(self):
        super().__init__()
        # Keep this assignment order: it fixes both the order in which the
        # sub-modules are registered and the order their weights are drawn.
        self.convs1 = self._conv_stack(3)
        self.convs2 = self._conv_stack(2)
        self.convs3 = self._conv_stack(2)
        self.convs4 = self._conv_stack(3)

    def forward(self, x):
        """Apply the stem, sum the two branch outputs, then apply the head."""
        stem = self.convs1(x)
        merged = self.convs2(stem) + self.convs3(stem)
        return self.convs4(merged)
 
# Build the network and export it to ONNX using one random 1x3x20x20 tensor
# as the example input.
# NOTE(review): `input` shadows the Python builtin. The name is kept here
# because later cells may rely on it.
model = Model()
input = torch.randn(1, 3, 20, 20)

# The example arguments are passed as a 1-tuple, which the exporter treats the
# same as passing the tensor on its own.
torch.onnx.export(model, (input,), 'whole_model.onnx')

In [30]:
# Load the exported model and list every node's input and output tensor
# names, one node per paragraph.
model = onnx.load('whole_model.onnx')
for node in model.graph.node:
    # One print call: inputs, outputs, then an empty line as separator.
    print(node.input, node.output, '', sep='\n')

['input.1', 'convs1.0.weight', 'convs1.0.bias']
['/convs1/convs1.0/Conv_output_0']

['/convs1/convs1.0/Conv_output_0', 'convs1.1.weight', 'convs1.1.bias']
['/convs1/convs1.1/Conv_output_0']

['/convs1/convs1.1/Conv_output_0', 'convs1.2.weight', 'convs1.2.bias']
['/convs1/convs1.2/Conv_output_0']

['/convs1/convs1.2/Conv_output_0', 'convs2.0.weight', 'convs2.0.bias']
['/convs2/convs2.0/Conv_output_0']

['/convs2/convs2.0/Conv_output_0', 'convs2.1.weight', 'convs2.1.bias']
['/convs2/convs2.1/Conv_output_0']

['/convs1/convs1.2/Conv_output_0', 'convs3.0.weight', 'convs3.0.bias']
['/convs3/convs3.0/Conv_output_0']

['/convs3/convs3.0/Conv_output_0', 'convs3.1.weight', 'convs3.1.bias']
['/convs3/convs3.1/Conv_output_0']

['/convs2/convs2.1/Conv_output_0', '/convs3/convs3.1/Conv_output_0']
['/Add_output_0']

['/Add_output_0', 'convs4.0.weight', 'convs4.0.bias']
['/convs4/convs4.0/Conv_output_0']

['/convs4/convs4.0/Conv_output_0', 'convs4.1.weight', 'convs4.1.bias']
['/convs4/convs4.1/Conv_o

In [31]:
# Cut out the sub-graph that runs from the model input 'input.1' up to the
# tensor '/Add_output_0' and store it as its own model file.
onnx.utils.extract_model(
    input_path='whole_model.onnx',
    output_path='partial_model.onnx',
    input_names=['input.1'],
    output_names=['/Add_output_0'],
)

# Load the extracted file back and dump it for inspection.
partial_model = onnx.load('partial_model.onnx')
print(partial_model)

ir_version: 8
opset_import {
  version: 17
}
producer_name: "onnx.utils.extract_model"
graph {
  node {
    input: "input.1"
    input: "convs1.0.weight"
    input: "convs1.0.bias"
    output: "/convs1/convs1.0/Conv_output_0"
    name: "/convs1/convs1.0/Conv"
    op_type: "Conv"
    attribute {
      name: "dilations"
      type: INTS
      ints: 1
      ints: 1
    }
    attribute {
      name: "group"
      type: INT
      i: 1
    }
    attribute {
      name: "kernel_shape"
      type: INTS
      ints: 3
      ints: 3
    }
    attribute {
      name: "pads"
      type: INTS
      ints: 0
      ints: 0
      ints: 0
      ints: 0
    }
    attribute {
      name: "strides"
      type: INTS
      ints: 1
      ints: 1
    }
  }
  node {
    input: "/convs1/convs1.0/Conv_output_0"
    input: "convs1.1.weight"
    input: "convs1.1.bias"
    output: "/convs1/convs1.1/Conv_output_0"
    name: "/convs1/convs1.1/Conv"
    op_type: "Conv"
    attribute {
      name: "dilations"
      type: