# ONNX Safetensors Tutorial

This notebook demonstrates how to use the public API of the `onnx_safetensors` package to load ONNX weights from safetensors files and to save ONNX weights to them.

In [1]:
# !pip install --upgrade onnx-safetensors

## Load ONNX model

In [2]:
import onnx

model = onnx.load("model.textproto")
print(onnx.printer.to_text(model))

<
   ir_version: 10,
   opset_import: ["" : 21],
   producer_name: "onnx-safetensors-example"
>
SimpleGraph (float[1,3] input) => (float[1,3] output) 
   <float[3] weights =  {1,2,3}>
{
   output = Add (input, weights)
}


## Loading tensors from a safetensors file into an ONNX model

We first create a safetensors file with compatible weights, then load these weights into the ONNX model.

In [3]:
import numpy as np
import safetensors.numpy

import onnx_safetensors

# Create a safetensors file with compatible weights
# Note that the tensor key "weights" matches the name of the tensor in the model
weights_dict = {"weights": np.array([4.0, 5.0, 6.0], dtype=np.float32)}
safetensors.numpy.save_file(weights_dict, "weights.safetensors")

# Now you can replace the weights in the model
replaced_model = onnx_safetensors.load_file(model, "weights.safetensors")

# The weights have been replaced with [4, 5, 6]
print(onnx.printer.to_text(replaced_model))

<
   ir_version: 10,
   opset_import: ["" : 21],
   producer_name: "onnx-safetensors-example"
>
SimpleGraph (float[1,3] input) => (float[1,3] output) 
   <float[3] weights =  {4,5,6}>
{
   output = Add (input, weights)
}
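
Because tensors are matched by name, it can help to compare the contents of a safetensors file with the model's tensors before loading. The helper below is a hypothetical sketch, not part of `onnx_safetensors`, and this notebook does not show how the package itself reacts to a mismatch. It works on plain dictionaries of NumPy arrays so that it stays independent of any particular model.

```python
import numpy as np


def find_mismatches(model_tensors, file_tensors):
    """List reasons why tensors in file_tensors could not replace model tensors.

    Both arguments map tensor names to NumPy arrays. This is an illustrative
    helper, not an onnx_safetensors function.
    """
    problems = []
    for name, new in file_tensors.items():
        old = model_tensors.get(name)
        if old is None:
            problems.append(f"{name}: no tensor with this name in the model")
        elif old.dtype != new.dtype:
            problems.append(f"{name}: dtype {new.dtype} differs from {old.dtype}")
        elif old.shape != new.shape:
            problems.append(f"{name}: shape {new.shape} differs from {old.shape}")
    return problems


# For a real model, the first dictionary could be built with
# {init.name: onnx.numpy_helper.to_array(init) for init in model.graph.initializer}
model_tensors = {"weights": np.array([1.0, 2.0, 3.0], dtype=np.float32)}

good = {"weights": np.array([4.0, 5.0, 6.0], dtype=np.float32)}
misnamed = {"weight": np.array([4.0, 5.0, 6.0], dtype=np.float32)}

print(find_mismatches(model_tensors, good))
print(find_mismatches(model_tensors, misnamed))
```

The first call reports no problems; the second flags the key `weight`, which does not match the model's tensor name `weights`.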


Use `load_file_as_external_data` to replace the weights with references to the safetensors file, so that the tensor data stays in that file as ONNX external data.

In [4]:
replaced_model_with_external_data = onnx_safetensors.load_file_as_external_data(model, "weights.safetensors")

print(onnx.printer.to_text(replaced_model_with_external_data))

<
   ir_version: 10,
   opset_import: ["" : 21],
   producer_name: "onnx-safetensors-example"
>
SimpleGraph (float[1,3] input) => (float[1,3] output) 
   <float[3] weights = ["location": "weights.safetensors", "offset": "72", "length": "12"]>
{
   output = Add (input, weights)
}
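
The `offset` and `length` entries point at the raw bytes of the tensor inside the safetensors file. A safetensors file starts with an 8-byte little-endian header size, followed by a JSON header, followed by the tensor data; the `data_offsets` in the header are relative to the start of the data section. The sketch below uses only the standard library and a hand-built file in memory to show how such a span can be derived. The function names are hypothetical, and padding the header with spaces to a multiple of 8 bytes is an assumption about how the writer lays out the file.

```python
import json
import struct


def read_safetensors_header(blob):
    """Return (header_dict, data_start) for safetensors content held in bytes."""
    (header_len,) = struct.unpack("<Q", blob[:8])  # 8-byte little-endian size
    header = json.loads(blob[8 : 8 + header_len])
    return header, 8 + header_len


def external_data_span(blob, name):
    """Return (absolute_offset, length) of tensor `name` within the file."""
    header, data_start = read_safetensors_header(blob)
    begin, end = header[name]["data_offsets"]  # relative to the data section
    return data_start + begin, end - begin


# Hand-built example: one float32 tensor "weights" with values 4, 5, 6.
payload = struct.pack("<3f", 4.0, 5.0, 6.0)
header_json = json.dumps(
    {"weights": {"dtype": "F32", "shape": [3], "data_offsets": [0, len(payload)]}},
    separators=(",", ":"),
).encode("utf-8")
# Assumption: pad the header with spaces so the data section is 8-byte aligned.
header_json += b" " * (-len(header_json) % 8)
blob = struct.pack("<Q", len(header_json)) + header_json + payload

offset, length = external_data_span(blob, "weights")
values = struct.unpack("<3f", blob[offset : offset + length])
print(offset, length, values)
```

For this hand-built header the span comes out as offset 72 and length 12 (three 4-byte floats), which is consistent with the values printed above.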


## Using safetensors as external data for ONNX

You can also save a model's weights to a safetensors file and have the ONNX model reference that file as external data.

In [5]:
# Start from the original ONNX model, whose weights are still 1, 2, 3
print(onnx.printer.to_text(model))

<
   ir_version: 10,
   opset_import: ["" : 21],
   producer_name: "onnx-safetensors-example"
>
SimpleGraph (float[1,3] input) => (float[1,3] output) 
   <float[3] weights =  {1,2,3}>
{
   output = Add (input, weights)
}


In [6]:
# Save the weights to a safetensors file that the model references as external data.
# The file should contain 1, 2, 3
model_with_external_data = onnx_safetensors.save_file(model, 'model.safetensors', base_dir='.', replace_data=True)
print("Weights saved:", safetensors.numpy.load_file('model.safetensors'))

# This is a model referencing safetensors as external data
print("\nmodel_with_external_data:")
print(onnx.printer.to_text(model_with_external_data))

Weights saved: {'weights': array([1., 2., 3.], dtype=float32)}

model_with_external_data:
<
   ir_version: 10,
   opset_import: ["" : 21],
   producer_name: "onnx-safetensors-example"
>
SimpleGraph (float[1,3] input) => (float[1,3] output) 
   <float[3] weights = ["location": "model.safetensors", "offset": "72", "length": "12"]>
{
   output = Add (input, weights)
}


## Inference with ONNX Runtime

In [7]:
import onnxruntime as ort

onnx.save(model_with_external_data, "model_with_external_data.onnx")
session = ort.InferenceSession("model_with_external_data.onnx")
output = session.run(None, {"input": np.array([[1.0, 2.0, 3.0]], dtype=np.float32)})
print("Output:", output)

Output: [array([[2., 4., 6.]], dtype=float32)]
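
The result can be cross-checked without ONNX Runtime. The graph is a single `Add` of the input and the `weights` tensor, and the saved file holds the original weights 1, 2, 3, so plain NumPy broadcasting should give the same numbers. This is a minimal sketch; the names `reference_weights`, `reference_input` and `expected` are introduced here for illustration.

```python
import numpy as np

# Weights stored in model.safetensors, as printed in the save step.
reference_weights = np.array([1.0, 2.0, 3.0], dtype=np.float32)
reference_input = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)

# Add broadcasts the [3] weights across the [1, 3] input.
expected = reference_input + reference_weights
print(expected)

# In the notebook, this could be compared with the session result:
# np.testing.assert_allclose(output[0], expected)
```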
