# Intro to the ONNX IR

https://github.com/microsoft/onnxscript/tree/main/onnxscript/ir

The ONNX IR aims to provide a robust, efficient, and Pythonic in-memory IR for ONNX that powers model building, analysis, and manipulation. It offers:

- **Full ONNX spec support**: all valid models representable by ONNX protobuf, and a subset of invalid models (so you can load and fix them).
- **Low memory footprint**: external tensors are memory-mapped, and ONNX TensorProto, NumPy arrays, PyTorch tensors, and other tensor types share one unified interface. No tensor size limit. Zero copies.
- **Straightforward access patterns**: Access value information and traverse the graph topology with ease.
- **Robust mutation support**: Create as many iterators as you like on the graph while mutating it.
- **Speed**: Performant graph manipulation, plus fast serialization to and deserialization from protobuf.
- **Pythonic and familiar APIs**: Classes define Pythonic APIs and still map to ONNX protobuf concepts in an intuitive way.
- **No protobuf dependency**: The IR does not require protobuf once the model is converted to the IR representation, decoupling from the serialization format.

In [None]:
import onnx
from onnxscript import ir

In [None]:
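# Load an ONNX model and convert the protobuf into the IR.
# Replace the path below with a model available on your machine.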

proto = onnx.load("/home/justinchu/dev/onnx-script/testdata/e2e_models/mobilenetv2_100/dynamo/mobilenetv2_100_dynamo.onnx")
model = ir.serde.deserialize_model(proto)

In [None]:
model

In [None]:
graph = model.graph

In [None]:
graph.display()

In [None]:
print(graph.initializers.keys())

In [None]:
graph.initializers["conv_stem.weight"].display()
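
The `graph.initializers` mapping shown above can also be summarized programmatically. The following is a minimal sketch, not part of the IR: `largest_initializers` is a hypothetical helper, and it assumes each initializer value exposes a `const_value` tensor (possibly `None`) with an `nbytes` attribute. Verify both names against your installed version before relying on them.

```python
def largest_initializers(initializers, top=5):
    """Return up to ``top`` (name, nbytes) pairs, largest first.

    Assumptions (not confirmed by this notebook): ``initializers`` maps
    names to values, each value has a ``const_value`` attribute that may
    be None, and tensors expose their size in bytes as ``nbytes``.
    """
    sizes = [
        (name, value.const_value.nbytes)
        for name, value in initializers.items()
        if value.const_value is not None  # skip initializers without data
    ]
    # Sort by byte size, largest first.
    sizes.sort(key=lambda item: item[1], reverse=True)
    return sizes[:top]
```

With the model loaded above, a call would look like `largest_initializers(graph.initializers, top=3)`.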

In [None]:
# graph.initializers["model.embed_tokens.weight"].numpy()

In [None]:
len(graph)

In [None]:
node = graph[6]

In [None]:
print(node)
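
Since the graph supports `len()` and integer indexing, as the cells above show, it can also be walked node by node. A minimal sketch of an operator histogram follows. It assumes that iterating the graph yields its nodes and that each node has an `op_type` attribute; `count_op_types` is a hypothetical helper name, not an IR function.

```python
from collections import Counter


def count_op_types(nodes):
    """Count how often each operator type appears in an iterable of nodes.

    Assumption: every node exposes an ``op_type`` attribute.
    """
    return Counter(node.op_type for node in nodes)
```

For example, `count_op_types(graph).most_common(5)` would list the five most frequent operators in the loaded model.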

In [None]:
# Create a replacement node that consumes the same inputs as `node`.
new_node = ir.Node(
    "my_custom_domain",  # domain
    "Linear_classifier",  # op_type
    node.inputs,  # inputs
    name="new_torch_nn_modules_linear_Linear_classifier_1",
)

In [None]:
new_node.display()

In [None]:
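# Redirect every consumer of the old node's outputs to the new node's outputs.
# tuple() takes a snapshot of the uses, because rewiring an input changes them.
for value, replacement in zip(node.outputs, new_node.outputs):
    for user_node, index in tuple(value.uses()):
        user_node.replace_input_with(index, replacement)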

In [None]:
graph.insert_after(node, new_node)

In [None]:
print(graph)

In [None]:
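# This call may raise at this point: if `node` still produces a graph output,
# safe=True refuses to remove it.
graph.remove(node, safe=True)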

In [None]:
# Graph outputs are not node inputs, so point them at the new value explicitly.
graph.outputs[0] = new_node.outputs[0]

In [None]:
print(graph)

In [None]:
print(node.inputs)

In [None]:
# Once nothing consumes the old node's outputs, it can be removed safely.
graph.remove(node, safe=True)
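
The replacement steps in the preceding cells (insert the new node, rewire consumers, update graph outputs, remove the old node) can be gathered into one function. This is a sketch under stated assumptions: `replace_node` is a hypothetical helper, it uses only the calls already shown in this notebook, and it assumes both nodes have the same number of outputs.

```python
def replace_node(graph, old_node, new_node):
    """Swap ``old_node`` for ``new_node`` and drop ``old_node`` from ``graph``.

    Assumption: the two nodes have the same number of outputs, so outputs
    can be paired up positionally.
    """
    if len(old_node.outputs) != len(new_node.outputs):
        raise ValueError("old_node and new_node must have the same number of outputs")

    # Place the replacement right after the node it replaces.
    graph.insert_after(old_node, new_node)

    for old_value, new_value in zip(old_node.outputs, new_node.outputs):
        # Snapshot the uses first: rewiring an input changes the use list.
        for user_node, index in tuple(old_value.uses()):
            user_node.replace_input_with(index, new_value)
        # Graph outputs are not node inputs, so redirect them separately.
        for i, graph_output in enumerate(graph.outputs):
            if graph_output is old_value:
                graph.outputs[i] = new_value

    # Nothing refers to the old outputs any more, so request a safe removal.
    graph.remove(old_node, safe=True)
```

Handling graph outputs before the removal avoids depending on the order in which the individual cells are run.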

In [None]:
new_model_proto = ir.serde.serialize_model(model)

In [None]:
# The new output value was created without shape or type information; set both.
graph.outputs[0].shape = ir.Shape([1, 1000])

In [None]:
graph.outputs[0].dtype = ir.DataType.FLOAT

In [None]:
graph.outputs[0]

In [None]:
print(graph)

In [None]:
new_model_proto = ir.serde.serialize_model(model)
onnx.save(new_model_proto, "new_model_proto.onnx")

In [None]:
len(node.inputs)

In [None]:
len(node.outputs)

In [None]:
graph.insert_after