# Demos for the Talk "Working with Large Models in ONNX IR"

In [79]:
# Prepare environment

# %pip install --upgrade onnxscript onnx-ir onnx-safetensors model-explorer-onnx onnxruntime

## Demo 1: Safetensors in ONNX

**Q1: Is there a way to use the safetensors format as an external data format for ONNX?**

**A1:** Yes. Safetensors stores tensor data contiguously, in row-major order and little-endian (the same layout ONNX uses), and each tensor's data offset can be found by parsing the file's JSON header.

<img src="resources/safetensors-format.svg" width="500"/>

Image source: https://huggingface.co/docs/safetensors/en/index

**Q2: How do we do it efficiently?**

**A2:** Use `onnx_ir` to replace the model's initializer tensors with tensors that reference the data in the safetensors file.

In [80]:
import onnx_ir as ir

# Load the example model from its textual protobuf form and show its IR.
example_model_path = "resources/model.textproto"
model = ir.load(example_model_path)
print(model)

<
    ir_version=10,
    opset_imports={'': 21},
    producer_name='onnx-safetensors-example',
    producer_version=None,
    domain=None,
    model_version=None,
>
graph(
    name=SimpleGraph,
    inputs=(
        %"input"<FLOAT,[1,3]>
    ),
    outputs=(
        %"output"<FLOAT,[1,3]>
    ),
    initializers=(
        %"weights"<FLOAT,[3]>{TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='weights')}
    ),
) {
    0 |  # :anonymous_node:130470374283680
         %"output"<FLOAT,[1,3]> ⬅️ ::Add(%"input", %"weights"{[1.0, 2.0, 3.0]})
    return %"output"<FLOAT,[1,3]>
}




### Loading tensors from a safetensors file into an ONNX model

Use `onnx_safetensors.load_file_as_external_data` to load a safetensors file as external data and replace the weights in the model.

In [81]:
# These modules were used here without ever being imported, so the cell failed
# on a fresh kernel.
import onnx
import onnx_safetensors

# Replace the model's initializers with tensors that reference the data stored
# in the safetensors file (as external data).
model_with_external_data = onnx_safetensors.load_file_as_external_data(
    model,
    "resources/weights.safetensors",  # weights containing [4, 5, 6]
)

# Print in ONNX textual syntax so the external-data location/offset/length of
# each initializer is visible.
print(onnx.printer.to_text(ir.to_proto(model_with_external_data)))

<
   ir_version: 10,
   opset_import: ["" : 21],
   producer_name: "onnx-safetensors-example"
>
SimpleGraph (float[1,3] input) => (float[1,3] output) 
   <float[3] weights = ["location": "resources/weights.safetensors", "offset": "72", "length": "12"], float[3] weights>
{
   output = Add (input, weights)
}


### Using safetensors as external data for ONNX

We can similarly save an ONNX model's external data to a safetensors file. Because the tensor dtype is stored in the ONNX file, we can even use types that safetensors does not yet support, such as INT4.

You can read more at https://github.com/justinchuby/onnx-safetensors/blob/main/examples/tutorial.ipynb

### Inference with ONNX Runtime

In [82]:
# numpy was previously only imported in a later cell, so this cell depended on
# out-of-order execution.
import numpy as np
import onnxruntime as ort

ir.save(model_with_external_data, "model_with_external_data.onnx")
session = ort.InferenceSession("model_with_external_data.onnx")
output = session.run(None, {"input": np.array([[1.0, 2.0, 3.0]], dtype=np.float32)})
print("[[1.0, 2.0, 3.0]] + [4, 5, 6] =", output)
# Sanity check: the result must reflect the safetensors weights [4, 5, 6],
# not the model's original initializer values [1, 2, 3].
np.testing.assert_allclose(output[0], [[5.0, 7.0, 9.0]])

[[1.0, 2.0, 3.0]] + [4, 5, 6] = [array([[5., 7., 9.]], dtype=float32)]


## Demo 2: Fusion in onnxscript for ONNX Runtime

## Demo 3: Putting it together

1. Build a model with the tape module.
2. Replace some initializers.
3. Build a pass to modify the model (merge QKV weights).
4. Use the rewriter to do the same thing.
5. Show it in Model Explorer.


### 1. Build a model with the tape module.

In [90]:
import numpy as np


def build_model(
    batch_size: int = 2,
    seq_len: int = 3,
    hidden_size: int = 16,
    rng: "np.random.Generator | None" = None,
) -> ir.Model:
    R"""Build a single-head attention model with separate Q, K and V projections.

    The graph built here (before any merging) is::

          Input Embeddings  [batch_size, seq_len, hidden_size]
           /         |         \
          V          V          V
      [MatMul W_Q] [MatMul W_K] [MatMul W_V]   <-- three [hidden_size, hidden_size] weights
          |          |          |
        Query       Key       Value     [batch_size, seq_len, hidden_size] each
            \        |        /
        [Attention (q_num_heads=1, kv_num_heads=1)]
                     |
                     V
            [Attention Output]  [batch_size, seq_len, hidden_size]

    The later cells merge W_Q, W_K and W_V into a single
    [hidden_size, 3 * hidden_size] weight W_QKV, followed by a Split into Q, K, V.

    Args:
        batch_size: Static batch dimension of the graph input.
        seq_len: Static sequence-length dimension of the graph input.
        hidden_size: Hidden dimension; each Q/K/V weight is hidden_size x hidden_size.
        rng: Optional NumPy generator for the random weights. When None (the
            default), weights are drawn from the global ``np.random.rand`` state.

    Returns:
        An IR-version-10 model importing the default domain at opset 23, whose
        graph is named ``main_graph``.
    """

    def random_weight() -> np.ndarray:
        # Uniform [0, 1) float32 values, as np.random.rand / Generator.random produce.
        shape = (hidden_size, hidden_size)
        values = np.random.rand(*shape) if rng is None else rng.random(shape)
        return values.astype(np.float32)

    # Initialize the Tape. It is simply a recorder of operations and initializers.
    tape = ir.tape.Tape()

    # Create initializers (Q, K, V in this order, so draws from the global RNG
    # are consumed in the same order as before).
    q_weight = tape.initializer(ir.tensor(random_weight()), name="q_weight")
    k_weight = tape.initializer(ir.tensor(random_weight()), name="k_weight")
    v_weight = tape.initializer(ir.tensor(random_weight()), name="v_weight")

    # Create the graph input. `input_value` avoids shadowing the `input` builtin;
    # the ONNX value is still named "input".
    input_value = ir.Value(
        name="input",
        type=ir.TensorType(ir.DataType.FLOAT),
        shape=ir.Shape([batch_size, seq_len, hidden_size]),
    )

    query = tape.op("MatMul", inputs=[input_value, q_weight])
    key = tape.op("MatMul", inputs=[input_value, k_weight])
    value = tape.op("MatMul", inputs=[input_value, v_weight])
    attention_output = tape.op(
        "Attention",
        inputs=[query, key, value],
        attributes={"q_num_heads": 1, "kv_num_heads": 1},
    )
    # Values created by the tape carry no type or shape, so annotate the graph
    # output explicitly.
    attention_output.shape = ir.Shape([batch_size, seq_len, hidden_size])
    attention_output.type = ir.TensorType(ir.DataType.FLOAT)

    return ir.Model(
        graph=ir.Graph(
            inputs=[input_value],
            outputs=[attention_output],
            nodes=tape.nodes,
            initializers=tape.initializers,
            opset_imports={"": 23},
            name="main_graph",
        ),
        ir_version=10,
    )

### 2. Replace some initializers

In [92]:
model = build_model()

# Overwrite the Q projection weight with zeros. The replacement's shape and dtype
# come from the existing initializer, because `hidden_size` is local to
# `build_model` and is not defined at notebook scope (NameError on a fresh kernel).
q_weight_init = model.graph.initializers["q_weight"]
q_weight_init.const_value = ir.tensor(
    np.zeros_like(q_weight_init.const_value.numpy()), name="q_weight"
)
q_weight_init.const_value.display()

Tensor<FLOAT,[16,16]>(array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

### 3. Build a pass to modify the model (merge QKV weights)

In [93]:
# Now we want to combine the weights into a single QKV weight matrix.
# Match Attention(MatMul(x, W_Q), MatMul(x, W_K), MatMul(x, W_V)) where all three
# MatMuls share the same input x and the weights are constants, and rewrite it to
# Attention(*Split(MatMul(x, concat([W_Q, W_K, W_V], axis=1)), num_outputs=3)).

for node in model.graph:
    if node.op_type != "Attention":
        continue

    # Q, K and V are the first three Attention inputs. Any further inputs
    # (if present) are left untouched.
    qkv_inputs = node.inputs[:3]
    if len(qkv_inputs) != 3 or any(value is None for value in qkv_inputs):
        continue

    # Each of Q, K and V must be produced by a MatMul over the same input.
    producers = [value.producer() for value in qkv_inputs]
    if any(p is None or p.op_type != "MatMul" for p in producers):
        continue
    input_val = producers[0].inputs[0]
    if any(p.inputs[0] is not input_val for p in producers):
        continue

    # Find the weights for Q, K, V. They are concatenated right here, so they
    # must be constants.
    q_weight_val, k_weight_val, v_weight_val = (p.inputs[1] for p in producers)
    weight_vals = (q_weight_val, k_weight_val, v_weight_val)
    if any(w.const_value is None for w in weight_vals):
        continue

    # Show the values of the weights
    print("Q weight value:")
    q_weight_val.const_value.display()

    # Concatenate along the columns. The shape is derived from the data itself
    # because `hidden_size` is local to `build_model`.
    qkv_array = np.concatenate([w.const_value.numpy() for w in weight_vals], axis=1)
    qkv_tensor = ir.tensor(qkv_array)
    qkv_weight = ir.Value(
        name="qkv_weight",
        type=ir.TensorType(qkv_tensor.dtype),
        shape=ir.Shape(list(qkv_array.shape)),
        const_value=qkv_tensor,
    )
    # Create a new MatMul node that uses the combined Q, K, V weights, then split
    # its output into three equal parts along axis 2 (the last axis of the
    # rank-3 input built by `build_model`).
    combined_matmul_node = ir.node("MatMul", inputs=[input_val, qkv_weight])
    new_qkv = ir.node(
        "Split",
        inputs=combined_matmul_node.outputs,
        attributes={
            "axis": 2,
            "num_outputs": 3,
        },
        num_outputs=3,
    )
    # Insert the new nodes into the graph just before the Attention node
    node.prepend((combined_matmul_node, new_qkv))
    # Reconnect every consumer of the old Q/K/V values (including this Attention
    # node) to the Split outputs. The old MatMuls are left without users.
    ir.convenience.replace_all_uses_with(
        qkv_inputs,
        new_qkv.outputs,
    )
    # Add the new initializer to the graph
    model.graph.register_initializer(qkv_weight)


Q weight value:
Tensor<FLOAT,[16,16]>(array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0.

In [94]:
# Print a label followed by the model's text representation, one per line.
print("Combined QKV model:", model, sep="\n")

Combined QKV model:
<
    ir_version=10,
    opset_imports={'': 23},
    producer_name=None,
    producer_version=None,
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"input"<FLOAT,[2,3,16]>
    ),
    outputs=(
        %"val_3"<FLOAT,[2,3,16]>
    ),
    initializers=(
        %"q_weight"<FLOAT,[16,16]>{Tensor(...)},
        %"k_weight"<FLOAT,[16,16]>{Tensor(...)},
        %"v_weight"<FLOAT,[16,16]>{Tensor(...)},
        %"qkv_weight"<FLOAT,[16,48]>{Tensor(...)}
    ),
) {
    0 |  # node_MatMul_0
         %"val_0"<?,?> ⬅️ ::MatMul(%"input", %"q_weight"{...})
    1 |  # node_MatMul_1
         %"val_1"<?,?> ⬅️ ::MatMul(%"input", %"k_weight"{...})
    2 |  # node_MatMul_2
         %"val_2"<?,?> ⬅️ ::MatMul(%"input", %"v_weight"{...})
    3 |  # node_MatMul_4
         %"val_4"<?,?> ⬅️ ::MatMul(%"input", %"qkv_weight"{...})
    4 |  # node_Split_5
         %"val_5"<?,?>, %"val_6"<?,?>, %"val_7"<?,?> ⬅️ ::Split(%"val_4") {axis=2, num_outputs=3}

### 3.5. Run passes

In [95]:
import onnx_ir.passes.common as common_passes

# Remove nodes whose outputs are no longer used (here: the original Q/K/V MatMuls).
remove_unused_nodes = common_passes.RemoveUnusedNodesPass()
remove_unused_nodes(model)

print("Model after removing unused nodes:", model, sep="\n")

Model after removing unused nodes:
<
    ir_version=10,
    opset_imports={'': 23},
    producer_name=None,
    producer_version=None,
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"input"<FLOAT,[2,3,16]>
    ),
    outputs=(
        %"val_3"<FLOAT,[2,3,16]>
    ),
    initializers=(
        %"qkv_weight"<FLOAT,[16,48]>{Tensor(...)}
    ),
) {
    0 |  # node_MatMul_4
         %"val_4"<?,?> ⬅️ ::MatMul(%"input", %"qkv_weight"{...})
    1 |  # node_Split_5
         %"val_5"<?,?>, %"val_6"<?,?>, %"val_7"<?,?> ⬅️ ::Split(%"val_4") {axis=2, num_outputs=3}
    2 |  # node_Attention_3
         %"val_3"<FLOAT,[2,3,16]> ⬅️ ::Attention(%"val_5", %"val_6", %"val_7") {q_num_heads=1, kv_num_heads=1}
    return %"val_3"<FLOAT,[2,3,16]>
}




In [96]:
# Fill in types and shapes for intermediate values (e.g. the MatMul and Split outputs).
shape_inference = common_passes.ShapeInferencePass()
shape_inference(model)

print("Model after shape inference:", model, sep="\n")

Model after shape inference:
<
    ir_version=10,
    opset_imports={'': 23},
    producer_name=None,
    producer_version=None,
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"input"<FLOAT,[2,3,16]>
    ),
    outputs=(
        %"val_3"<FLOAT,[2,3,16]>
    ),
    initializers=(
        %"qkv_weight"<FLOAT,[16,48]>{Tensor(...)}
    ),
) {
    0 |  # node_MatMul_4
         %"val_4"<FLOAT,[2,3,48]> ⬅️ ::MatMul(%"input", %"qkv_weight"{...})
    1 |  # node_Split_5
         %"val_5"<FLOAT,[2,3,16]>, %"val_6"<FLOAT,[2,3,16]>, %"val_7"<FLOAT,[2,3,16]> ⬅️ ::Split(%"val_4") {axis=2, num_outputs=3}
    2 |  # node_Attention_3
         %"val_3"<FLOAT,[2,3,16]> ⬅️ ::Attention(%"val_5", %"val_6", %"val_7") {q_num_heads=1, kv_num_heads=1}
    return %"val_3"<FLOAT,[2,3,16]>
}




### 4. Use the rewriter to do the same thing

In [88]:
# Now use the rewriter

import onnxscript.rewriter as rewriter


class CombineQKVWeights(rewriter.pattern.RewriteRuleClassBase):
    """Merge the Q/K/V MatMul projections feeding an Attention node into one MatMul.

    Matches ``Attention(MatMul(x, Wq), MatMul(x, Wk), MatMul(x, Wv))`` with
    ``q_num_heads=1, kv_num_heads=1`` and rewrites it to
    ``Attention(*Split(MatMul(x, concat([Wq, Wk, Wv], axis=1)), num_outputs=3))``.
    """

    def pattern(self, op, input, q_weight, k_weight, v_weight):
        q = op.MatMul(input, q_weight)
        k = op.MatMul(input, k_weight)
        v = op.MatMul(input, v_weight)
        return op.Attention(q, k, v, q_num_heads=1, kv_num_heads=1)

    def check(self, context, q_weight, k_weight, v_weight, **_):
        # `rewrite` concatenates the weights along axis 1, so they must be
        # constant 2-D tensors. Without this check, a non-constant weight would
        # crash `rewrite` with an AttributeError on `None.numpy()`.
        weights = (q_weight, k_weight, v_weight)
        return all(
            w.const_value is not None and len(w.const_value.shape) == 2 for w in weights
        )

    def rewrite(self, op, input, q_weight, k_weight, v_weight):
        qkv_weight = op.initializer(
            ir.tensor(
                np.concatenate(
                    [
                        q_weight.const_value.numpy(),
                        k_weight.const_value.numpy(),
                        v_weight.const_value.numpy(),
                    ],
                    axis=1,
                )
            ),
            name="qkv_weight",
        )
        combined_matmul = op.MatMul(input, qkv_weight)
        new_q, new_k, new_v = op.Split(combined_matmul, axis=2, num_outputs=3, _outputs=3)
        return op.Attention(new_q, new_k, new_v, q_num_heads=1, kv_num_heads=1)


model = build_model()
# Create the rewrite rule
rule = CombineQKVWeights.rule()
# Apply the rewrite rule to the model
rule.apply_to_model(model)
# Clean up and run shape inference. Note that you can use the Sequential pass to chain multiple passes together.
ir.passes.Sequential(
    common_passes.RemoveUnusedNodesPass(),
    common_passes.ShapeInferencePass(),
)(model)

print(model)

<
    ir_version=10,
    opset_imports={'': 23},
    producer_name=None,
    producer_version=None,
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"input"<FLOAT,[2,3,16]>
    ),
    outputs=(
        %"val_3"<FLOAT,[2,3,16]>
    ),
    initializers=(
        %"qkv_weight"<FLOAT,[16,48]>{Tensor(...)}
    ),
) {
    0 |  # node_MatMul_4
         %"val_4"<FLOAT,[2,3,48]> ⬅️ ::MatMul(%"input", %"qkv_weight"{...})
    1 |  # node_Split_5
         %"val_5"<FLOAT,[2,3,16]>, %"val_6"<FLOAT,[2,3,16]>, %"val_7"<FLOAT,[2,3,16]> ⬅️ ::Split(%"val_4") {axis=2, num_outputs=3}
    2 |  # node_Attention_6
         %"val_3"<FLOAT,[2,3,16]> ⬅️ ::Attention(%"val_5", %"val_6", %"val_7") {q_num_heads=1, kv_num_heads=1}
    return %"val_3"<FLOAT,[2,3,16]>
}




### 5. Show it in Model Explorer

In [97]:
# Serialize the merged model so it can be opened in Model Explorer below.
merged_model_path = "merged_qkv.onnx"
ir.save(model, merged_model_path)

In [98]:
# Open the saved model in Model Explorer. `onnxvis` is presumably the CLI from the
# model-explorer-onnx package listed in the install cell; it starts a local server
# and keeps running until interrupted (Ctrl+C / stop the cell).
!onnxvis merged_qkv.onnx

Loading extensions...
2025-06-07 18:56:49.125906: I external/org_tensorflow/tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
Loaded 9 extensions:
 - TFLite adapter (Flatbuffer)
 - TFLite adapter (MLIR)
 - TF adapter (MLIR)
 - TF adapter (direct)
 - GraphDef adapter
 - Pytorch adapter (exported program)
 - MLIR adapter
 - ONNX adapter
 - JSON adapter

Starting Model Explorer server at:
http://localhost:8083/?data=%7B%22models%22%3A%20%5B%7B%22url%22%3A%20%22/home/justinchu/dev/onnx-meetup-2025/merged_qkv.onnx%22%7D%5D%7D

Press Ctrl+C to stop.
gio: http://localhost:8083/?data=%7B%22models%22%3A%20%5B%7B%22url%22%3A%20%22/home/justinchu/dev/onnx-meetup-2025/merged_qkv.onnx%22%7D%5D%7D: Operation not supported
Stopping server...
^C
