In [1]:
from onnx.onnx_pb2 import *
from onnx import checker, helper

In [2]:
# Attributes: helper.make_attribute turns a (name, Python value) pair into an
# AttributeProto. Each example prints a header followed by the proto's text
# form; the captured output shows an int stored in `i`, a float in `f`, a
# string in `s`, and a list of ints in the repeated `ints` field.
attribute_examples = [
    ("\nInt attribute:\n", "this_is_an_int", 1701),
    ("\nFloat attribute:\n", "this_is_a_float", 3.14),
    ("\nString attribute:\n", "this_is_a_string", "string_content"),
    ("\nRepeated int attribute:\n", "this_is_a_repeated_int", [1, 2, 3, 4]),
]
for header, attr_name, attr_value in attribute_examples:
    arg = helper.make_attribute(attr_name, attr_value)
    # sep="\n" reproduces the original two separate print() calls exactly.
    print(header, arg, sep="\n")


Int attribute:

name: "this_is_an_int"
i: 1701


Float attribute:

name: "this_is_a_float"
f: 3.14


String attribute:

name: "this_is_a_string"
s: "string_content"


Repeated int attribute:

name: "this_is_a_repeated_int"
ints: 1
ints: 2
ints: 3
ints: 4



In [3]:
# A minimal node: one Relu that reads tensor "X" and writes tensor "Y".
node_proto = helper.make_node("Relu", ["X"], ["Y"])

# Header, then the NodeProto's text form (print() applies str() itself).
print("\nNodeProto:\n", node_proto, sep="\n")


NodeProto:

input: "X"
output: "Y"
op_type: "Relu"



In [4]:
# A node with attributes: extra keyword arguments to make_node become
# AttributeProto entries (the "attribute { ... }" blocks in the output).
# NOTE(review): kernel/stride/pad are illustrative names; the current ONNX Conv
# operator spec uses kernel_shape/strides/pads (lists) — confirm before reusing.
conv_attrs = dict(kernel=3, stride=1, pad=1)
node_proto = helper.make_node("Conv", ["X", "W", "B"], ["Y"], **conv_attrs)

print("\nNodeProto:\n", node_proto, sep="\n")

# printable_node renders a compact one-liner; as its label says, attributes are
# not included in that rendering here.
print("\nMore Readable NodeProto (no args yet):\n",
      helper.printable_node(node_proto), sep="\n")


NodeProto:

input: "X"
input: "W"
input: "B"
output: "Y"
op_type: "Conv"
attribute {
  name: "kernel"
  i: 3
}
attribute {
  name: "pad"
  i: 1
}
attribute {
  name: "stride"
  i: 1
}


More Readable NodeProto (no args yet):

%Y = Conv %X, %W, %B


In [5]:
# A graph: a two-layer MLP, FC -> Relu -> FC, named "MLP". Every declared input
# and output is described as a FLOAT tensor of shape [1].
# NOTE(review): "FC" is presumably an illustrative op name (not an ONNX-standard
# operator; Gemm is the closest) — confirm before running the checker on this.
# NOTE(review): the captured output lists inputs as bare strings plus a
# `version` field, which looks like it came from an older ONNX release and not
# from these make_tensor_value_info calls — re-run to refresh.
mlp_nodes = [
    helper.make_node("FC", ["X", "W1", "B1"], ["H1"]),
    helper.make_node("Relu", ["H1"], ["R1"]),
    helper.make_node("FC", ["R1", "W2", "B2"], ["Y"]),
]
mlp_inputs = [
    helper.make_tensor_value_info(input_name, TensorProto.FLOAT, [1])
    for input_name in ('X', 'W1', 'B1', 'W2', 'B2')
]
mlp_outputs = [helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1])]
graph_proto = helper.make_graph(mlp_nodes, "MLP", mlp_inputs, mlp_outputs)

print("\ngraph proto:\n", graph_proto, sep="\n")
print("\nMore Readable GraphProto:\n",
      helper.printable_graph(graph_proto), sep="\n")


graph proto:

version: 1
node {
  input: "X"
  input: "W1"
  input: "B1"
  output: "H1"
  op_type: "FC"
}
node {
  input: "H1"
  output: "R1"
  op_type: "Relu"
}
node {
  input: "R1"
  input: "W2"
  input: "B2"
  output: "Y"
  op_type: "FC"
}
name: "MLP"
input: "X"
input: "W1"
input: "B1"
input: "W2"
input: "B2"
output: "Y"


More Readable GraphProto:

graph MLP %X, %W1, %B1, %W2, %B2 {
  %H1 = FC %X, %W1, %B1
  %R1 = Relu %H1
  %Y = FC %R1, %W2, %B2
  return %Y
}


In [6]:
# A node that is also a graph: the MLP GraphProto built in the previous cell
# (`graph_proto`) is attached to a node of op_type "graph" as a graph-valued
# attribute named "graph". Conceptually the node computes
#     Output = MLP(Input, W1, B1, W2, B2)
# Reusing `graph_proto` (instead of rebuilding an identical copy here) keeps a
# single definition of the MLP in the notebook.
# NOTE(review): how the node's inputs (Input, W1, ...) bind to the subgraph's
# inputs (X, W1, ...) is not shown here — presumably positionally; confirm.
node_proto = helper.make_node(
    "graph",
    ["Input", "W1", "B1", "W2", "B2"],
    ["Output"],
    # A list is passed, so the graph is stored in the attribute's repeated
    # `graphs` field, as the captured output shows.
    graph=[graph_proto],
)

print("\nNodeProto that contains a graph:\n")
print(str(node_proto))


NodeProto that contains a graph:

input: "Input"
input: "W1"
input: "B1"
input: "W2"
input: "B2"
output: "Output"
op_type: "graph"
attribute {
  name: "graph"
  graphs {
    version: 1
    node {
      input: "X"
      input: "W1"
      input: "B1"
      output: "H1"
      op_type: "FC"
    }
    node {
      input: "H1"
      output: "R1"
      op_type: "Relu"
    }
    node {
      input: "R1"
      input: "W2"
      input: "B2"
      output: "Y"
      op_type: "FC"
    }
    name: "MLP"
    input: "X"
    input: "W1"
    input: "B1"
    input: "W2"
    input: "B2"
    output: "Y"
  }
}

