# Demos for the Talk "Working with Large Models in ONNX IR"

In [None]:
# Prepare environment

%pip install --upgrade onnxscript onnx-ir onnx-safetensors model-explorer-onnx onnxruntime

## Demo 1: Safetensors in ONNX

**Q1:** Is there a way to use the safetensors format as an external data format for ONNX?

**A1:** Yes. The data is contiguous, row-major, and little-endian (same as ONNX). Each tensor's data offsets can be found by parsing the JSON header.

<img src="resources/safetensors-format.svg" width="500"/>

Image source: https://huggingface.co/docs/safetensors/en/index

**Q2:** How do we do it efficiently?

In [None]:
import onnx_ir as ir

# Read the example model from its text-format protobuf and show its IR structure
model_path = "resources/model.textproto"
model = ir.load(model_path)
print(model)

<
    ir_version=10,
    opset_imports={'': 21},
    producer_name='onnx-safetensors-example',
    producer_version=None,
    domain=None,
    model_version=None,
>
graph(
    name=SimpleGraph,
    inputs=(
        %"input"<FLOAT,[1,3]>
    ),
    outputs=(
        %"output"<FLOAT,[1,3]>
    ),
    initializers=(
        %"weights"<FLOAT,[3]>{TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='weights')}
    ),
) {
    0 |  # :anonymous_node:130470423510864
         %"output"<FLOAT,[1,3]> ⬅️ ::Add(%"input", %"weights"{[1.0, 2.0, 3.0]})
    return %"output"<FLOAT,[1,3]>
}




### Loading tensors from a safetensors file into an ONNX model

We first create a safetensors file with compatible weights, then load these weights into the ONNX model.

In [20]:
import numpy as np
import safetensors.numpy

import onnx_safetensors

# Write a small safetensors file. Its tensor key ("weights") matches the name of the
# initializer in the model; that is how the loader knows which tensor to replace.
weights_path = "resources/weights.safetensors"
new_weights = np.arange(4.0, 7.0, dtype=np.float32)  # [4., 5., 6.]
weights_dict = {"weights": new_weights}
safetensors.numpy.save_file(weights_dict, weights_path)

# Swap the model's initializer values for the ones stored in the safetensors file
replaced_model = onnx_safetensors.load_file(model, weights_path)

# The printed graph should now show the weights as [4, 5, 6]
print(replaced_model)

<
    ir_version=10,
    opset_imports={'': 21},
    producer_name='onnx-safetensors-example',
    producer_version=None,
    domain=None,
    model_version=None,
>
graph(
    name=SimpleGraph,
    inputs=(
        %"input"<FLOAT,[1,3]>
    ),
    outputs=(
        %"output"<FLOAT,[1,3]>
    ),
    initializers=(
        %"weights"<FLOAT,[3]>{Tensor<FLOAT,[3]>(array([4., 5., 6.], dtype=float32), name='weights')}
    ),
) {
    0 |  # :anonymous_node:130470423510864
         %"output"<FLOAT,[1,3]> ⬅️ ::Add(%"input", %"weights"{[4.0, 5.0, 6.0]})
    return %"output"<FLOAT,[1,3]>
}




Use `load_file_as_external_data` to replace the weights in the model with references to the safetensors file (ONNX external data) instead of copying the values into the model.

In [22]:
# `onnx` is not imported by any earlier cell; import it here so this cell
# works on a fresh kernel instead of raising NameError.
import onnx

# Rather than copying the values into the model, point the "weights" initializer at
# its byte range (offset/length) inside the safetensors file as ONNX external data.
model_with_external_data = onnx_safetensors.load_file_as_external_data(
    model, "resources/weights.safetensors"
)

# Render the serialized proto as text so the external-data entries
# (location / offset / length) of the initializer are visible.
print(onnx.printer.to_text(ir.to_proto(model_with_external_data)))

<
   ir_version: 10,
   opset_import: ["" : 21],
   producer_name: "onnx-safetensors-example"
>
SimpleGraph (float[1,3] input) => (float[1,3] output) 
   <float[3] weights = ["location": "resources/weights.safetensors", "offset": "72", "length": "12"], float[3] weights>
{
   output = Add (input, weights)
}


### Using safetensors as external data for ONNX

Similarly, we can save an ONNX model's weights to a safetensors file and use that file as the model's external data. Because the tensor dtype is stored in the ONNX file, we can even use types that safetensors does not support yet, such as INT4.

You can read more at https://github.com/justinchuby/onnx-safetensors/blob/main/examples/tutorial.ipynb

### Inference with ONNX Runtime

In [23]:
import onnxruntime as ort

# Serialize the model. Its initializer still references resources/weights.safetensors
# as external data; per the ONNX external-data spec, that location is resolved
# relative to the directory of the saved model file.
ir.save(model_with_external_data, "model_with_external_data.onnx")
session = ort.InferenceSession("model_with_external_data.onnx")

input_data = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
output = session.run(None, {"input": input_data})
print("[[1.0, 2.0, 3.0]] + [4, 5, 6] =", output)

# Check that ONNX Runtime really used the safetensors weights [4, 5, 6]
# and not the model's original initializer values [1, 2, 3].
np.testing.assert_allclose(
    output[0], input_data + np.array([4.0, 5.0, 6.0], dtype=np.float32)
)

[[1.0, 2.0, 3.0]] + [4, 5, 6] = [array([[5., 7., 9.]], dtype=float32)]


## Demo 2: 