# Exporting an LLM from transformers to ONNX for running in ONNX Runtime GenAI

In this tutorial, you will learn how to export a model from transformers to ONNX, optimize it for ONNX Runtime, and run it with ONNX Runtime GenAI. The exported model is in float32. For quantization and better performance, refer to Olive or the other quantization tools available for ONNX.

In [None]:
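# First install required packages. To use torch.onnx export, you need the latest version of onnxscript
%pip install transformers==4.55 torch==2.9.0
%pip install --upgrade onnxscript
# onnxruntime-genai is imported later, in the "Run with ONNX Runtime GenAI" section
%pip install onnxruntime-genai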

In [1]:
import os

import torch
from transformers import AutoConfig, AutoModelForCausalLM
import transformers

from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


In [2]:
# For this example, we will use a Gemma-3 model
# MODEL_ID = "google/gemma-3-270m-it"
MODEL_ID = "google/gemma-3-1b-it"
# MODEL_ID = "google/gemma-3-4b-it"
# MODEL_ID = "google/gemma-3-27b-it"

MODEL_NAME = MODEL_ID.split("/")[-1]

### Make transformers model exportable

The model needs a few tweaks to be torch.export friendly. We make them by registering a simplified version of the attention implementation with transformers.

In [3]:
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def sdpa_attention_forward_with_check(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    dropout: float = 0.0,
    scaling: float | None = None,
    is_causal: bool | None = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    if hasattr(module, "num_key_value_groups"):
        key = repeat_kv(key, module.num_key_value_groups)
        value = repeat_kv(value, module.num_key_value_groups)
    if attention_mask is not None and attention_mask.ndim == 4:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
        torch._check(
            attention_mask.shape[-1] == key.shape[-2],
            lambda: "attention_mask.shape[-1] == key.shape[-2] should be true",
        )

    if is_causal is None:
        is_causal = (
            query.shape[2] > 1
            and attention_mask is None
            and getattr(module, "is_causal", True)
        )

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None


# Patch the attention functions to use the custom SDPA without vmap
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register(
    "sdpa_without_vmap",
    transformers.integrations.executorch.sdpa_mask_without_vmap,
)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_attention_forward_with_check)
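
The docstring of `repeat_kv` states that it is equivalent to `torch.repeat_interleave` along the head dimension. A minimal check of that claim, assuming small made-up tensor sizes (2 key/value heads repeated 3 times), could look like this. `repeat_kv` is copied from the cell above so the snippet runs on its own:

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Copy of the helper above: expand each key/value head n_rep times."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


# Illustrative sizes only: batch=2, 2 key/value heads, 4 positions, head_dim=8
kv = torch.randn(2, 2, 4, 8)
expanded = repeat_kv(kv, n_rep=3)
reference = torch.repeat_interleave(kv, repeats=3, dim=1)

# The head dimension grows from 2 to 6 and the values match the reference
shapes_match = expanded.shape == reference.shape == (2, 6, 4, 8)
values_match = torch.equal(expanded, reference)
print(shapes_match, values_match)
```

The `expand` plus `reshape` form avoids calling `repeat_interleave` directly while producing the same head ordering.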

### Prepare the model and example inputs for tracing

Load the model from transformers:

In [4]:
def get_hf_model(model_id: str):
    """Load a Hugging Face model and its config."""
    # We use our custom attention implementation here
    config = AutoConfig.from_pretrained(
        model_id, attn_implementation="sdpa_without_vmap"
    )
    config.use_cache = True
    # Use the correct AutoModel class for your model architecture
    model = AutoModelForCausalLM.from_pretrained(model_id, config=config)

    return model, config


model, config = get_hf_model(MODEL_ID)

Then, create example inputs and specify all the input and output names. Always provide example inputs whose size is at least 2 along every dimension that is declared dynamic. This is due to the [0/1 specialization problem](https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk) in PyTorch.

Refer to the torch.export documentation for examples on how you can provide dynamic shapes.

In [5]:
def make_dynamic_cache(
    past_key_values: list[tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.DynamicCache:
    """Create a DynamicCache from past_key_values."""
    cache = transformers.cache_utils.DynamicCache()
    for layer_idx in range(len(past_key_values)):
        key_states, value_states = past_key_values[layer_idx]
        cache.update(key_states, value_states, layer_idx)
    return cache


def create_text_gen_example_inputs(
    config, batch_size: int = 2, seq_len: int = 3, past_seq_len: int = 2
):
    """Create example inputs and dynamic axes for ONNX export."""
    config = config.get_text_config()
    num_hidden_layers = config.num_hidden_layers
    # batch = "batch"
    # sequence_len = "sequence_len"
    # past_sequence_len = "past_sequence_len"
    batch = torch.export.Dim("batch")
    sequence_len = torch.export.Dim("sequence_len")
    # past_sequence_len = torch.export.Dim("past_sequence_len")

    dynamic_shapes = {
        "input_ids": {0: batch, 1: sequence_len},
        "attention_mask": {
            0: batch,
            1: "past_sequence_len+sequence_len",
        },
        "position_ids": {
            0: batch,
            1: sequence_len,
        },
        "past_key_values": [
            [{0: batch, 2: "past_sequence_len"} for _ in range(num_hidden_layers)],
            [{0: batch, 2: "past_sequence_len"} for _ in range(num_hidden_layers)],
        ],
    }
    input_names = [
        "input_ids",
        "attention_mask",
        "position_ids",
        *[f"past_key_values.{i}.key" for i in range(num_hidden_layers)],
        *[f"past_key_values.{i}.value" for i in range(num_hidden_layers)],
    ]
    output_names = [
        "logits",
        *[f"present.{i}.key" for i in range(num_hidden_layers)],
        *[f"present.{i}.value" for i in range(num_hidden_layers)],
    ]

    num_key_value_heads = config.num_key_value_heads
    head_dim = config.head_dim

    example_inputs = dict(
        input_ids=torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int64),
        attention_mask=torch.ones(
            (batch_size, past_seq_len + seq_len),
            dtype=torch.int64,
        ),
        position_ids=torch.arange(
            past_seq_len,
            past_seq_len + seq_len,
            dtype=torch.int64,
        ).expand((batch_size, -1)),
        past_key_values=make_dynamic_cache(
            [
                (
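                    # Cached keys/values cover the past tokens only, so that
                    # past_seq_len + seq_len matches the attention_mask length
                    torch.randn(
                        batch_size,
                        num_key_value_heads,
                        past_seq_len,
                        head_dim,
                    ),
                    torch.randn(
                        batch_size,
                        num_key_value_heads,
                        past_seq_len,
                        head_dim,
                    ),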
                )
                for _ in range(num_hidden_layers)
            ]
        ),
    )

    return example_inputs, dynamic_shapes, input_names, output_names


# Obtain example inputs and dynamic axes
example_kwargs, dynamic_shapes, input_names, output_names = (
    create_text_gen_example_inputs(config)
)


In [6]:
print("input_names:", input_names)
print("output_names:", output_names)

input_names: ['input_ids', 'attention_mask', 'position_ids', 'past_key_values.0.key', 'past_key_values.1.key', 'past_key_values.2.key', 'past_key_values.3.key', 'past_key_values.4.key', 'past_key_values.5.key', 'past_key_values.6.key', 'past_key_values.7.key', 'past_key_values.8.key', 'past_key_values.9.key', 'past_key_values.10.key', 'past_key_values.11.key', 'past_key_values.12.key', 'past_key_values.13.key', 'past_key_values.14.key', 'past_key_values.15.key', 'past_key_values.16.key', 'past_key_values.17.key', 'past_key_values.18.key', 'past_key_values.19.key', 'past_key_values.20.key', 'past_key_values.21.key', 'past_key_values.22.key', 'past_key_values.23.key', 'past_key_values.24.key', 'past_key_values.25.key', 'past_key_values.0.value', 'past_key_values.1.value', 'past_key_values.2.value', 'past_key_values.3.value', 'past_key_values.4.value', 'past_key_values.5.value', 'past_key_values.6.value', 'past_key_values.7.value', 'past_key_values.8.value', 'past_key_values.9.value', 'pa
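
The names follow a regular pattern: each `past_key_values.{i}.key` / `.value` input has a matching `present.{i}.key` / `.value` output. A small sketch of that pairing, using hypothetical helpers (`kv_io_names`, `present_to_past`) and an assumed layer count of 2 for brevity:

```python
def kv_io_names(num_hidden_layers: int) -> tuple[list[str], list[str]]:
    """Build the cache input and output names in the same order as the cell above."""
    past = [
        *[f"past_key_values.{i}.key" for i in range(num_hidden_layers)],
        *[f"past_key_values.{i}.value" for i in range(num_hidden_layers)],
    ]
    present = [
        *[f"present.{i}.key" for i in range(num_hidden_layers)],
        *[f"present.{i}.value" for i in range(num_hidden_layers)],
    ]
    return past, present


def present_to_past(name: str) -> str:
    """Map an output name such as 'present.0.key' to its matching input name."""
    prefix, layer, kind = name.split(".")
    if prefix != "present" or kind not in ("key", "value"):
        raise ValueError(f"unexpected cache output name: {name}")
    return f"past_key_values.{layer}.{kind}"


past_names, present_names = kv_io_names(2)
pairs = {name: present_to_past(name) for name in present_names}
for output_name, input_name in pairs.items():
    print(f"{output_name} -> {input_name}")
```

Because both lists use the same ordering, cache output *k* lines up with cache input *k*, which makes it straightforward for a generation loop to feed the updated cache back in on the next step.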

Now, define a wrapper for the transformers model so that it takes `torch.Tensor`s as inputs and returns `torch.Tensor`s as outputs. Keeping the model signature simple makes the model IO easier for `torch.export` to understand, especially when dynamic shapes are involved.

In [7]:
class TextGenerationModelWrapper(torch.nn.Module):
    """A wrapper around a Hugging Face model to adjust the forward method for ONNX export."""

    def __init__(self, model: transformers.PreTrainedModel):
        super().__init__()
        self.model = model

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values,
    ):
        hf_output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return hf_output.logits, hf_output.past_key_values


# Wrap the model to adjust the forward method for ONNX export
model = TextGenerationModelWrapper(model)

### Export the model

With everything ready, we can now call `torch.onnx.export` to export the model. Internally, `torch.onnx.export` calls `torch.export` to obtain the model graph, then translates the model to ONNX using ONNX Script and `onnx-ir`.

Set `opset_version` to 23 if you want to get the ONNX standard Attention and RotaryEmbedding ops. Set it to 20 (as of Oct 2025) if you want to do fusion for ONNX Runtime. As we improve the fusion logic, we expect opset 23 and higher to be the generally supported and recommended opsets very soon.

In [8]:
onnx_program = torch.onnx.export(
    model,
    (),
    kwargs=example_kwargs,
    input_names=input_names,
    output_names=output_names,
    dynamic_shapes=dynamic_shapes,
    opset_version=20,  # Set to 20 for ORT fusion rules
    dynamo=True,
    # report=True,  # Uncomment to get a report of the export
)

print("✅ Export successful")

W1017 18:57:12.658000 41808 site-packages/torch/onnx/_internal/exporter/_registration.py:107] torchvision is not installed. Skipping torchvision::nms


[torch.onnx] Obtain model graph for `TextGenerationModelWrapper([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `TextGenerationModelWrapper([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅




Applied 474 of general pattern rewrite rules.
✅ Export successful


Now we can call the ONNX Script optimizer to optimize the model for ONNX Runtime:

In [9]:
from onnxscript.rewriter.ort_fusions import optimize_for_ort

print("Optimize the model with ONNX Runtime custom ops...")

# Get the onnx_ir Model with onnx_program.model. The fusion is done in place.
_, count = optimize_for_ort(onnx_program.model)
print(f"Applied fusions: {count}")

Optimize the model with ONNX Runtime custom ops...
Applied 1 of general pattern rewrite rules.
Applied fusions: {'erf_gelu': 0, 'rms_normalization': 157, 'skip_layer_normalization': 0, 'skip_rms_normalization': 52, 'rotary_embedding': 52, 'cos_sin_cache': 52, 'partial_rotary_embedding': 0, 'sdpa': 19, 'gqa': 19, 'packed_qkv_for_gqa': 0, 'mha1': 0, 'mha2': 0, 'mha_scale': 0, 'mha_bias': 0, 'attention': 0, 'gelu': 0, 'bias_gelu': 0}
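
Most entries in the fusion report are zero for this model. A short sketch that keeps only the fusions that actually fired, using the `count` dictionary printed above (copied here as a literal so the snippet is self-contained):

```python
# Fusion counts copied from the output above
count = {
    "erf_gelu": 0,
    "rms_normalization": 157,
    "skip_layer_normalization": 0,
    "skip_rms_normalization": 52,
    "rotary_embedding": 52,
    "cos_sin_cache": 52,
    "partial_rotary_embedding": 0,
    "sdpa": 19,
    "gqa": 19,
    "packed_qkv_for_gqa": 0,
    "mha1": 0,
    "mha2": 0,
    "mha_scale": 0,
    "mha_bias": 0,
    "attention": 0,
    "gelu": 0,
    "bias_gelu": 0,
}

# Keep only the fusion rules that matched at least once, largest count first
applied = {name: n for name, n in count.items() if n > 0}
for name, n in sorted(applied.items(), key=lambda item: -item[1]):
    print(f"{name:>24}: {n}")
```

For this export, the report suggests that the normalization, rotary embedding, and grouped-query attention (`gqa`) patterns were recognized, while the `mha*` and `gelu` patterns did not match.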


Save the model to disk:

In [10]:
import pathlib

model_dir = pathlib.Path(f"models/{MODEL_NAME}_ort")
os.makedirs(model_dir, exist_ok=True)

In [11]:
# Use the ONNXProgram.save method to save the model. Specifying external_data=True
# will save the model weights in external files, which is required for models > 2GB
path = model_dir / f"{MODEL_NAME}.onnx"
onnx_program.save(path, external_data=True)

print(f"🧠 Model saved to {path}")

🧠 Model saved to models/gemma-3-1b-it_ort/gemma-3-1b-it.onnx
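
To see why external data matters here, a rough back-of-the-envelope estimate helps: float32 stores each parameter in 4 bytes. The sketch below assumes a round figure of one billion parameters for the 1B model (an approximation, not a measured count) and compares the result with the 2 GB threshold mentioned in the comment above. The helper names are hypothetical:

```python
BYTES_PER_FLOAT32 = 4
TWO_GB = 2 * 1024**3  # threshold referenced in the comment above


def estimate_float32_bytes(num_parameters: int) -> int:
    """Approximate weight size in bytes, ignoring graph structure and metadata."""
    return num_parameters * BYTES_PER_FLOAT32


def needs_external_data(num_parameters: int) -> bool:
    """True when the float32 weights alone exceed the 2 GB threshold."""
    return estimate_float32_bytes(num_parameters) > TWO_GB


assumed_parameters = 1_000_000_000  # assumption: round number for a "1B" model
size_gib = estimate_float32_bytes(assumed_parameters) / 1024**3
print(
    f"~{size_gib:.2f} GiB, external data needed: "
    f"{needs_external_data(assumed_parameters)}"
)
```

Under this estimate the float32 weights alone are well past the threshold, so saving with `external_data=True` is the safe default for the 1B model and anything larger.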


## Run with ONNX Runtime GenAI

First load the model:

In [12]:
import onnxruntime_genai as og

model_name = MODEL_ID.split("/")[-1]
model_path = f"models/{model_name}_ort"

# Create tokenizer files if they don't exist
if not (model_dir / "tokenizer_config.json").exists():
    print("Downloading tokenizer...")
    from transformers import AutoTokenizer

    AutoTokenizer.from_pretrained(MODEL_ID).save_pretrained(model_dir)

model = og.Model(str(model_dir))
tokenizer = og.Tokenizer(model)
tokenizer_stream = tokenizer.create_stream()

# Set the max length to something sensible by default,
# since otherwise it will be set to the entire context length
search_options = {}
search_options["max_length"] = 2048
search_options["batch_size"] = 1

Now generate!

In [13]:
text = "Why is the sky blue?"

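import json

# Generate the prompt by applying the chat template.
# json.dumps escapes quotes and backslashes in the text so the messages stay valid JSON
prompt = tokenizer.apply_chat_template(
    messages=json.dumps([{"role": "user", "content": text}]),
    add_generation_prompt=True,
)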

input_tokens = tokenizer.encode(prompt)

params = og.GeneratorParams(model)
params.set_search_options(**search_options)
generator = og.Generator(model, params)

print("Output: ", end="", flush=True)

try:
    generator.append_tokens(input_tokens)
    while not generator.is_done():
        generator.generate_next_token()

        new_token = generator.get_next_tokens()[0]
        print(tokenizer_stream.decode(new_token), end="", flush=True)
except KeyboardInterrupt:
    print("  --control+c pressed, aborting generation--")

print()


Output: The sky is blue due to a phenomenon called **Rayleigh scattering**. Here's a breakdown of how it works:

* **Sunlight is made of all colors:** White sunlight is actually a mixture of all the colors of the rainbow – red, orange, yellow, green, blue, indigo, and violet.

* **Entering the atmosphere:** When sunlight enters the Earth's atmosphere, it collides with tiny air molecules (mostly nitrogen and oxygen).

* **Scattering of light:** This collision causes the light to scatter in different directions.▁▁Rayleigh scattering is a type of scattering that's more effective at shorter wavelengths of light.

* **Blue light is scattered more:** Blue and violet light have shorter wavelengths than other colors like red and orange.▁▁Because of this, they are scattered much more strongly by the air molecules.▁▁It's like throwing a small pebble (blue light) at a bumpy surface – it bounces in all directions.

* **Why we see blue, not violet:** While violet light is scattered even *more* than

## Summary

Congratulations! You have exported a Hugging Face transformers model to ONNX and run it with ONNX Runtime GenAI. Here's what you accomplished:

### Key Steps Completed:
1. **Model Preparation**: Modified the Gemma-3 model to be torch.export friendly by implementing custom attention functions that avoid non-exportable operations like `vmap`
2. **Export Setup**: Created proper example inputs with dynamic shapes and defined input/output names for the ONNX export process
3. **Model Wrapping**: Wrapped the transformers model to have a clean tensor-based interface suitable for ONNX export
4. **ONNX Export**: Exported the model using `torch.onnx.export` with opset version 20, the version the ONNX Runtime fusion rules target as of Oct 2025
5. **Optimization**: Applied ONNX Runtime-specific fusion optimizations to improve performance
6. **Integration**: Loaded the exported model in ONNX Runtime GenAI and demonstrated text generation capabilities

### Next Steps:
- For even better performance, consider quantizing the model using tools like Olive
