# Demos for the talk "Working with Large Models in ONNX IR"

In [5]:
# Prepare environment

%pip install --upgrade onnxscript onnx-ir onnx-safetensors model-explorer-onnx onnxruntime


## Demo 1: Safetensors in ONNX

Q1: Is there a way to use the safetensors format as an external data format for ONNX?
A1: Yes. Safetensors stores tensor data as contiguous, row-major, little-endian bytes (the same layout ONNX uses), so each tensor's data offset can be found by parsing the JSON header.

<img src="resources/safetensors-format.svg" width="500"/>

Image source: https://huggingface.co/docs/safetensors/en/index
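As a rough illustration of how the offsets can be recovered, the sketch below parses the header by hand using only the standard library. It assumes the layout described in the safetensors documentation: an 8-byte little-endian header length, then the JSON header, then the raw tensor bytes, with `data_offsets` relative to the start of the byte buffer. The function name `tensor_byte_ranges` is made up for this example and is not part of any library.

```python
import json
import struct


def tensor_byte_ranges(path):
    """Return {tensor_name: (absolute_offset, length)} for a safetensors file.

    Hypothetical helper for illustration; assumes the documented layout:
    [8-byte little-endian header size][JSON header][tensor bytes].
    """
    with open(path, "rb") as f:
        # The first 8 bytes hold the JSON header size as an unsigned 64-bit integer.
        (header_size,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_size))

    # Tensor bytes start right after the length prefix and the header.
    data_start = 8 + header_size

    ranges = {}
    for name, info in header.items():
        if name == "__metadata__":  # optional free-form metadata, not a tensor
            continue
        begin, end = info["data_offsets"]  # relative to data_start
        ranges[name] = (data_start + begin, end - begin)
    return ranges
```

The absolute offset and length returned for each tensor are the kind of values an ONNX external data entry needs in its `offset` and `length` fields.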

In [3]:
import onnx

model = onnx.load("resources/model.textproto")
print(onnx.printer.to_text(model))

<
   ir_version: 10,
   opset_import: ["" : 21],
   producer_name: "onnx-safetensors-example"
>
SimpleGraph (float[1,3] input) => (float[1,3] output) 
   <float[3] weights =  {1,2,3}>
{
   output = Add (input, weights)
}


### Loading tensors from a safetensors file into an ONNX model

We first create a safetensors file with compatible weights, then load these weights into the ONNX model.

In [None]:
import numpy as np
import safetensors.numpy

import onnx_safetensors

# Create a safetensors file with compatible weights
# Note that the tensor key "weights" matches the name of the tensor in the model
weights_dict = {"weights": np.array([4.0, 5.0, 6.0], dtype=np.float32)}
safetensors.numpy.save_file(weights_dict, "resources/weights.safetensors")

# Now you can replace the weights in the model
replaced_model = onnx_safetensors.load_file(model, "resources/weights.safetensors")

# Notice how the weights have been replaced with [4, 5, 6]
print(onnx.printer.to_text(replaced_model))

<
   ir_version: 10,
   opset_import: ["" : 21],
   producer_name: "onnx-safetensors-example"
>
SimpleGraph (float[1,3] input) => (float[1,3] output) 
   <float[3] weights =  {4,5,6}, float[3] weights>
{
   output = Add (input, weights)
}


Use `load_file_as_external_data` to replace the model's weights with references to the safetensors file as external data, rather than copying the values into the model.

In [12]:
model_with_external_data = onnx_safetensors.load_file_as_external_data(
    model, "resources/weights.safetensors"
)

print(onnx.printer.to_text(model_with_external_data))

<
   ir_version: 10,
   opset_import: ["" : 21],
   producer_name: "onnx-safetensors-example"
>
SimpleGraph (float[1,3] input) => (float[1,3] output) 
   <float[3] weights = ["location": "resources/weights.safetensors", "offset": "72", "length": "12"], float[3] weights>
{
   output = Add (input, weights)
}
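To see what the `offset` and `length` entries in the output above mean, here is a minimal sketch that reads a float32 tensor straight from a byte range of a file. It uses only the standard library, and `read_float32_external` is a hypothetical helper written for this example, not part of `onnx` or `onnx_safetensors`.

```python
import struct


def read_float32_external(path, offset, length):
    """Read `length` bytes at `offset` and decode them as little-endian float32.

    Hypothetical helper for illustration only.
    """
    if length % 4 != 0:
        raise ValueError("length must be a multiple of 4 for float32 data")
    with open(path, "rb") as f:
        f.seek(offset)
        raw = f.read(length)
    if len(raw) != length:
        raise ValueError("file is shorter than offset + length")
    # "<" selects little-endian, "f" is a 4-byte float.
    return list(struct.unpack(f"<{length // 4}f", raw))
```

Assuming the `resources/weights.safetensors` file created earlier, calling `read_float32_external("resources/weights.safetensors", 72, 12)` should recover the three weight values stored in it.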


### Using safetensors as external data for ONNX

We can similarly save the external data of an ONNX model to a safetensors file. Because the tensor dtype is stored in the ONNX file, we can even use types that safetensors does not yet support, such as INT4.

You can read more at https://github.com/justinchuby/onnx-safetensors/blob/main/examples/tutorial.ipynb
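One way to picture why this can work: the data file only has to hold bytes, while the ONNX model records how to interpret them. The sketch below packs signed 4-bit values two per byte, first element in the low 4 bits, which is the packing convention ONNX describes for INT4. How `onnx-safetensors` actually stores such tensors is not shown in this notebook, so treat this as an assumption. `pack_int4` and `unpack_int4` are hypothetical helpers.

```python
def pack_int4(values):
    """Pack signed 4-bit integers (-8..7) two per byte, first value in the low nibble.

    Hypothetical helper; an odd-length input is padded with a zero nibble.
    """
    packed = bytearray()
    for i in range(0, len(values), 2):
        pair = values[i:i + 2]
        if any(v < -8 or v > 7 for v in pair):
            raise ValueError("INT4 values must be in the range -8..7")
        low = pair[0] & 0x0F
        high = (pair[1] & 0x0F) if len(pair) == 2 else 0
        packed.append(low | (high << 4))
    return bytes(packed)


def unpack_int4(data, count):
    """Inverse of pack_int4: return the first `count` signed 4-bit values."""
    values = []
    for byte in data:
        for nibble in (byte & 0x0F, byte >> 4):
            # Nibbles 8..15 represent the negative values -8..-1 (two's complement).
            values.append(nibble - 16 if nibble >= 8 else nibble)
    return values[:count]
```

The packed bytes could then be written to the data file as plain 8-bit data, with the INT4 element type and the true element count kept on the ONNX side.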

### Inference with ONNX Runtime

In [18]:
import onnxruntime as ort

onnx.save(model_with_external_data, "model_with_external_data.onnx")
session = ort.InferenceSession("model_with_external_data.onnx")
output = session.run(None, {"input": np.array([[1.0, 2.0, 3.0]], dtype=np.float32)})
print("[[1.0, 2.0, 3.0]] + [4, 5, 6] =", output)

[[1.0, 2.0, 3.0]] + [4, 5, 6] = [array([[5., 7., 9.]], dtype=float32)]
