Copyright (c) Microsoft Corporation. All rights reserved.  
Licensed under the MIT License.

# Accelerating LLaMA-2 Inference with ONNX Runtime

In this tutorial, you will export, optimize, and run the LLaMA-2 model using ONNX Runtime.

## Prerequisites

0. Use a machine with at least 64GB of memory. Exporting LLaMA-2 requires a significant amount of memory because of the model's size.

1. Install [Anaconda](https://www.anaconda.com/distribution/). Once installed, create a `conda` environment named `llama2` by running the following in your terminal (outside of this notebook).

```console
$ conda create -n llama2 python=3.9
$ conda activate llama2
```

If you don't have Jupyter installed to run this notebook, here is how you can install it and connect it to your new `conda` environment (run these commands in your terminal outside of this notebook).

```console
$ pip install jupyterlab
$ conda install ipykernel
$ conda install -c conda-forge ipywidgets
$ ipython kernel install --user --name llama2
$ jupyter-lab
```

Once you have this notebook open in Jupyter, you can select the `llama2` environment that you created as the kernel for this notebook.

2. Select the `torch` package for your environment. For this notebook, you need to install `torch` with CUDA enabled for your installed CUDA version.

First, you need to identify your installed CUDA version. Run the following command in your terminal (outside of this notebook).

```console
$ nvidia-smi
```

A table should print that shows the status of your GPUs. Your CUDA version is located in the top right of the table. For this notebook, CUDA 11.8 is used as the version.

Once you have identified your CUDA version, you need to visit the [PyTorch website](https://pytorch.org/get-started/locally/) to get the download instructions for your specific `torch` package. For CUDA version 11.8 and lower, you can select `CUDA 11.8` on the PyTorch website. For CUDA version 12.0 and higher, you can select `CUDA 12.1` on the PyTorch website.

Important: You must select the `Preview (Nightly)` PyTorch build option. Otherwise, the export may fail.
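The CUDA-to-wheel rule above can be written as a small helper. This is a minimal sketch: `nightly_index_url` is a hypothetical name, and the `cu121` URL is assumed to follow the same pattern as the `cu118` nightly URL used in the installation cell below. Check the PyTorch website for the exact command.

```python
def nightly_index_url(cuda_version: str) -> str:
    """Return a PyTorch nightly wheel index URL for an installed CUDA version.

    Follows the rule above: CUDA 11.8 and lower -> cu118, CUDA 12.0 and higher -> cu121.
    """
    major = int(cuda_version.split(".")[0])
    tag = "cu121" if major >= 12 else "cu118"
    return f"https://download.pytorch.org/whl/nightly/{tag}"


# Example: build the install command for the CUDA version reported by nvidia-smi
cuda_version = "11.8"  # replace with the version shown in the nvidia-smi table
print(f"pip install --pre torch --index-url {nightly_index_url(cuda_version)}")
```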

In [1]:
import sys

# Uninstall existing torch
!{sys.executable} -m pip uninstall -y torch

# Example installation command for CUDA 11.8
!{sys.executable} -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118

Looking in indexes: https://download.pytorch.org/whl/nightly/cu118
Collecting torch
  Using cached https://download.pytorch.org/whl/nightly/cu118/torch-2.2.0.dev20231111%2Bcu118-cp39-cp39-linux_x86_64.whl (2532.6 MB)
Collecting filelock (from torch)
  Using cached https://download.pytorch.org/whl/nightly/filelock-3.9.0-py3-none-any.whl (9.7 kB)
Collecting sympy (from torch)
  Using cached https://download.pytorch.org/whl/nightly/sympy-1.11.1-py3-none-any.whl (6.5 MB)
Collecting networkx (from torch)
  Using cached https://download.pytorch.org/whl/nightly/networkx-3.0rc1-py3-none-any.whl (2.0 MB)
Collecting fsspec (from torch)
  Using cached https://download.pytorch.org/whl/nightly/fsspec-2023.4.0-py3-none-any.whl (153 kB)
Collecting pytorch-triton==2.1.0+6e4932cda8 (from torch)
  Using cached https://download.pytorch.org/whl/nightly/pytorch_triton-2.1.0%2B6e4932cda8-cp39-cp39-linux_x86_64.whl (125.4 MB)
Collecting mpmath>=0.19 (from sympy->torch)
  Using cached https://download.pyt

3. Install the packages from the `requirements-*.txt` file that [fits your scenario](https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/llama). Since this notebook is showing inference with CUDA, you can use `requirements-cuda.txt`.

In [2]:
# Replace requirements-*.txt filename with the one for your scenario
!wget https://raw.githubusercontent.com/microsoft/onnxruntime/main/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt

--2023-11-12 08:59:28--  https://raw.githubusercontent.com/microsoft/onnxruntime/main/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 234 [text/plain]
Saving to: ‘requirements-cuda.txt’


2023-11-12 08:59:28 (24.4 MB/s) - ‘requirements-cuda.txt’ saved [234/234]



Once you have downloaded the desired `requirements-*.txt` file, you need to download the common `requirements.txt` file that contains all shared packages across the `requirements-*.txt` files.

In [3]:
!wget https://raw.githubusercontent.com/microsoft/onnxruntime/main/onnxruntime/python/tools/transformers/models/llama/requirements.txt

--2023-11-12 08:59:33--  https://raw.githubusercontent.com/microsoft/onnxruntime/main/onnxruntime/python/tools/transformers/models/llama/requirements.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 138 [text/plain]
Saving to: ‘requirements.txt’


2023-11-12 08:59:34 (11.3 MB/s) - ‘requirements.txt’ saved [138/138]



Once downloaded, you can install the packages from the `requirements-*.txt` file for your scenario. That file references `requirements.txt` with `-r`, so pip installs the shared packages as well.

In [4]:
!{sys.executable} -m pip install -r requirements-cuda.txt

Collecting git+https://github.com/huggingface/optimum.git (from -r requirements.txt (line 1))
  Cloning https://github.com/huggingface/optimum.git to /tmp/pip-req-build-6v3xkch4
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/optimum.git /tmp/pip-req-build-6v3xkch4
  Resolved https://github.com/huggingface/optimum.git to commit 832f3b292b501c7ab920ecc7913de3c6b7894d60
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting transformers>=4.33.2 (from -r requirements.txt (line 2))
  Using cached transformers-4.35.0-py3-none-any.whl.metadata (123 kB)
Collecting onnx>=1.14.0 (from -r requirements.txt (line 4))
  Using cached onnx-1.15.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (15 kB)
Collecting datasets>=2.8.0 (from -r requirements.txt (line 5))
  Using cached datasets-2.14.6-py3-none-any.whl.metadata (

4. Create an account on Hugging Face to access the LLaMA-2 models. Once you have created your account, you can apply for access to one of the models [here](https://huggingface.co/meta-llama/Llama-2-7b-hf). Once you apply for access to one model and accept Meta's license, you will have access to [all LLaMA-2 models](https://huggingface.co/meta-llama/) in Hugging Face.

5. Once you have access, install the Hugging Face CLI so that you can authenticate your account with Hugging Face and download the model.

In [5]:
# Install the huggingface_hub package, which provides the Hugging Face CLI
!{sys.executable} -m pip install huggingface_hub



Before you authenticate, make sure you set up a [user access token](https://huggingface.co/docs/hub/security-tokens) in your Hugging Face account. You can then run the following command and enter your token to authenticate.

In [6]:
from huggingface_hub import notebook_login
notebook_login()


If the above does not work, you can enter the following command in your terminal (outside of this notebook).

```console
$ huggingface-cli login
```

6. Verify that you have the following versions.
- Transformers: v4.33.2 or higher
  - Without this version, ONNX Runtime optimizations may not apply and you will miss performance benefits.
- Protobuf: v3.20.2
  - Without this version, you will not be able to optimize the exported ONNX model and will likely see a `Segmentation fault`.

To see what versions you have installed, you can run the following.

In [7]:
!{sys.executable} -m pip show transformers

Name: transformers
Version: 4.35.0
Summary: State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Home-page: https://github.com/huggingface/transformers
Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)
Author-email: transformers@huggingface.co
License: Apache 2.0 License
Location: /data/kvaishnavi/anaconda3/envs/llama2/lib/python3.9/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by: optimum


In [8]:
!{sys.executable} -m pip show protobuf

Name: protobuf
Version: 3.20.2
Summary: Protocol Buffers
Home-page: https://developers.google.com/protocol-buffers/
Author: 
Author-email: 
License: BSD-3-Clause
Location: /data/kvaishnavi/anaconda3/envs/llama2/lib/python3.9/site-packages
Requires: 
Required-by: onnx, onnxruntime-gpu
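As an alternative to reading the `pip show` output by eye, the version rules from step 6 can be checked in Python. This is a minimal sketch with hypothetical helper names; it uses the standard library's `importlib.metadata` and compares only the numeric release part of each version string.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def release_tuple(text):
    """Parse the leading numeric part of a version string, e.g. "4.35.0" -> (4, 35, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        raise ValueError(f"Unrecognized version string: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def check_versions(installed):
    """Return a list of problems with respect to the version requirements in step 6."""
    problems = []
    transformers_version = installed.get("transformers")
    if transformers_version is None or release_tuple(transformers_version) < (4, 33, 2):
        problems.append(f"transformers must be v4.33.2 or higher (found {transformers_version})")
    protobuf_version = installed.get("protobuf")
    if protobuf_version is None or release_tuple(protobuf_version) != (3, 20, 2):
        problems.append(f"protobuf must be v3.20.2 (found {protobuf_version})")
    return problems


def installed_versions(names=("transformers", "protobuf")):
    """Look up installed package versions; None means the package is not installed."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


print(check_versions(installed_versions()) or "All versions match")
```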


If there is a version mismatch, please uninstall the existing version and install the following versions.

In [None]:
# If `transformers` version is wrong
!{sys.executable} -m pip uninstall -y transformers
!{sys.executable} -m pip install transformers==4.33.2

In [None]:
# If `protobuf` version is wrong
!{sys.executable} -m pip uninstall -y protobuf
!{sys.executable} -m pip install protobuf==3.20.2

## 1. Export + Optimize + Quantize LLaMA-2

Now that all prerequisites have been completed, you are ready to export LLaMA-2 to an optimized (and quantized, if requested) ONNX model. Before you begin, let's define a cache directory to store downloaded files.

In [9]:
cache_dir = "./cache_dir"

ONNX Runtime provides a `convert_to_onnx` script for LLaMA-2 that runs all of these steps (export, optimization, and optional quantization) in one command.

In [10]:
# List all flag options with convert_to_onnx
import sys
!{sys.executable} -m onnxruntime.transformers.models.llama.convert_to_onnx --help

usage: convert_to_onnx.py [-h] -m MODEL_NAME [-i INPUT] [-o OUTPUT]
                          [-p {fp32,fp16,int8,int4}] [-e {cpu,cuda,rocm}] [-r]
                          [--use_gqa] [--no_merged]
                          [-q {blockwise,smooth_quant,quantize_dynamic}]
                          [--block_size BLOCK_SIZE]
                          [--smooth_quant_alpha SMOOTH_QUANT_ALPHA]
                          [--smooth_quant_dataset SMOOTH_QUANT_DATASET]
                          [--pad_max PAD_MAX]
                          [--calibration_sampling_size CALIBRATION_SAMPLING_SIZE]
                          [--nc_workspace NC_WORKSPACE]
                          [--quantize_embedding_layer]
                          [--quantize_per_channel] [--quantize_reduce_range]
                          [-v] [-d] [--cache_dir CACHE_DIR]

optional arguments:
  -h, --help            show this help message and exit
  -m MODEL_NAME, --model_name MODEL_NAME
                        Model name in Hugg

Here is what each of the main flags does.
- `-m/--model_name`: corresponds to the model name in Hugging Face
- `--output`: folder to store the exported ONNX model in
- `--precision`: precision you want the final exported ONNX model to be in
- `--execution_provider`: the execution provider to run the model with
- `--quantization_method`: the method by which to quantize the model. For INT4, the quantization method is called `blockwise`.
- `--use_gqa`: replace MultiHeadAttention with GroupQueryAttention. This replacement is only supported for FP16 CUDA and INT4 CUDA. This flag is also required if `num_attention_heads != num_key_value_heads` in your model, which you can check by running the code below.

In [11]:
from transformers import LlamaConfig

model_name = "meta-llama/Llama-2-7b-hf"  # Replace with your model name
config = LlamaConfig.from_pretrained(model_name, use_auth_token=True, cache_dir=cache_dir)




In [12]:
f"Does {model_name} require GroupQueryAttention? {config.num_attention_heads != config.num_key_value_heads}"

'Does meta-llama/Llama-2-7b-hf require GroupQueryAttention? False'

If you are unsure whether your machine meets the memory requirements, run `convert_to_onnx` in your terminal outside of this notebook so that the notebook does not crash. You can load the exported model back into this notebook later.

You can export LLaMA-2 for a [wide range of scenarios](https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/llama). Here are some common options and their export commands.

FP16 CUDA (with GroupQueryAttention)

```console
# Model will be stored at ./llama2-7b-fp16-gqa/rank_0_Llama-2-7b-hf_decoder_merged_model_fp16.onnx
$ python -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16-gqa --precision fp16 --execution_provider cuda --use_gqa
```

INT4 CPU (with FP32 inputs/outputs)
```console
# Model will be stored at ./llama2-7b-int4-cpu/rank_0_Llama-2-7b-hf_decoder_merged_model_int4.onnx
$ python -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --execution_provider cpu --quantization_method blockwise
```

INT4 CUDA (with FP16 inputs/outputs)
```console
# Model will be stored at ./llama2-7b-int4-gpu/rank_0_Llama-2-7b-hf_decoder_merged_model_int4.onnx
$ python -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --execution_provider cuda --quantization_method blockwise --use_gqa
```
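The three commands above differ only in a few flags, so they can be assembled programmatically. This is a minimal sketch: `build_convert_command` is a hypothetical helper that encodes only the rules stated in this notebook (INT4 uses `blockwise` quantization; `--use_gqa` applies to FP16 CUDA and INT4 CUDA) and leaves INT8 out because its quantization options are not covered here.

```python
import shlex
import sys

SUPPORTED_PRECISIONS = ("fp32", "fp16", "int4")  # int8 is left out of this sketch


def build_convert_command(model_name, output, precision, execution_provider,
                          use_gqa=False, cache_dir=None):
    """Assemble a convert_to_onnx command line from the flags described above."""
    if precision not in SUPPORTED_PRECISIONS:
        raise ValueError(f"precision must be one of {SUPPORTED_PRECISIONS}")
    if use_gqa and not (execution_provider == "cuda" and precision in ("fp16", "int4")):
        raise ValueError("--use_gqa only applies to FP16 CUDA and INT4 CUDA")
    cmd = [
        sys.executable, "-m", "onnxruntime.transformers.models.llama.convert_to_onnx",
        "-m", model_name,
        "--output", output,
        "--precision", precision,
        "--execution_provider", execution_provider,
    ]
    if precision == "int4":
        cmd += ["--quantization_method", "blockwise"]  # INT4 uses blockwise quantization
    if use_gqa:
        cmd.append("--use_gqa")
    if cache_dir is not None:
        cmd += ["--cache_dir", cache_dir]
    return cmd


# Reproduces the INT4 CUDA example above, using the current Python interpreter
cmd = build_convert_command("meta-llama/Llama-2-7b-hf", "llama2-7b-int4-gpu", "int4", "cuda", use_gqa=True)
print(shlex.join(cmd))
```

You can paste the printed line into a terminal, or pass the list to `subprocess.run(cmd, check=True)` if you prefer to launch the export from Python.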

In [13]:
# Add/change/remove flags to run below command in your terminal for your scenario
!{sys.executable} -m onnxruntime.transformers.models.llama.convert_to_onnx -m "meta-llama/Llama-2-7b-hf" --cache_dir "./cache_dir" --output "./llama2-7b-fp16-gqa" --precision "fp16" --execution_provider "cuda" --use_gqa

PyTorch Version:2.2.0.dev20231111+cu118
Transformers Version:4.35.0
OnnxRuntime Version:1.16.2
Arguments: Namespace(model_name='meta-llama/Llama-2-7b-hf', input='.', output='./llama2-7b-fp16-gqa', precision=<Precision.FLOAT16: 'fp16'>, execution_provider='cuda', reexport=False, use_gqa=True, no_merged=False, quantization_method='', block_size=32, smooth_quant_alpha=0.8, smooth_quant_dataset='NeelNanda/pile-10k', pad_max=196, calibration_sampling_size=8, nc_workspace='./nc_workspace', quantize_embedding_layer=False, quantize_per_channel=False, quantize_reduce_range=False, verbose=False, use_dynamo_export=False, cache_dir='./cache_dir')
world_size: 1
Downloading (…)fetensors.index.json: 100%|█| 26.8k/26.8k [00:00<00:00, 14.3MB/s]
Downloading shards:   0%|                                 | 0/2 [00:00<?, ?it/s]
Downloading (…)of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]
Downloading (…)of-00002.safetensors:   0%|  | 31.5M/9.98G [00:00<00:37, 266MB/s]
Downloading (

You can also pick up an already exported + optimized ONNX model from [Microsoft's repository](https://github.com/microsoft/Llama-2-Onnx/tree/main-CUDA_CPU).

## 2. Run LLaMA-2 End-to-End

Now that the model is exported to ONNX, you can run it end-to-end. For this example, you can use the LLaMA-2 7B FP16 CUDA model with GQA.

Let's first import the necessary libraries.

In [14]:
from transformers import LlamaConfig, LlamaTokenizer
import numpy as np
import onnxruntime as ort
import torch

Next, you can set the main settings that you want to run your end-to-end scenario with.

In [15]:
# Change the below settings to your desired scenario
model_name = "meta-llama/Llama-2-7b-hf"  # Model name in Hugging Face
onnx_model_path = "./llama2-7b-fp16-gqa/rank_0_Llama-2-7b-hf_decoder_merged_model_fp16.onnx"  # Path to exported ONNX model on disk
use_fp16 = True  # True when KV cache inputs/outputs are in float16
use_buffer_share = True  # True when --use_gqa was passed during export

prompt = ["ONNX Runtime is ", "I want to book a vacation to Hawaii. First, I need to ", "A good workout routine is ", "How are astronauts launched into space? "] # List of prompts to use
max_length = 64  # max(prompt length + generation length)

device_id = 0
device = torch.device(f"cuda:{device_id}")  # Change to torch.device("cpu") if running on CPU

With your main settings finalized, you can import the model's configuration and tokenizer.

In [16]:
config = LlamaConfig.from_pretrained(model_name, use_auth_token=True, cache_dir=cache_dir)
tokenizer = LlamaTokenizer.from_pretrained(model_name, use_auth_token=True, cache_dir=cache_dir)




Next, prepare the inputs for the model. This involves several steps.
1. Tokenize the prompts into token ids.
2. Pad the tokenized output so that each batch entry has the same length.
3. Pre-allocate on-device memory to store the inputs and outputs.

   a. In a typical transformer model, the present KV cache outputs are passed as the past KV cache inputs in the next iteration. But for most models, you need to pre-allocate separate on-device memory for the past KV cache inputs and present KV cache outputs. With GroupQueryAttention, you can share this memory. This allows you to pass the present KV cache outputs directly to the past KV cache inputs.

   b. Because GroupQueryAttention writes the present KV cache into the same buffer as the past KV cache, that buffer must be large enough for any prompt length + generation length. The code below therefore sizes the past KV cache for the largest sequence length the model supports, `config.max_position_embeddings`.
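Sizing the shared buffer for the full `max_position_embeddings` has a real memory cost. The sketch below estimates it with a hypothetical helper, assuming the LLaMA-2-7B configuration values (32 layers, 32 key/value heads, hidden size 4096 so head size 4096 // 32 = 128, `max_position_embeddings` = 4096), FP16 storage (2 bytes per element), and the batch of four prompts used later in this notebook.

```python
def kv_cache_bytes(num_layers, batch_size, num_kv_heads, head_size, sequence_length, bytes_per_element):
    """Bytes needed for one key and one value tensor per layer,
    each of shape (batch_size, num_kv_heads, sequence_length, head_size)."""
    elements_per_tensor = batch_size * num_kv_heads * sequence_length * head_size
    return 2 * num_layers * elements_per_tensor * bytes_per_element


# Assumed LLaMA-2-7B values, FP16, batch of 4 prompts
shared = kv_cache_bytes(32, 4, 32, 128, 4096, 2)
print(f"{shared / 1024**3:.1f} GiB")  # 8.0 GiB
```

The estimate grows linearly with batch size and allocated sequence length, which is why the on-device KV cache can take several gigabytes on top of the model weights.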

In [17]:
def get_initial_inputs_and_outputs(config, tokenizer, prompt, device, use_fp16, use_buffer_share):
    tokenizer.pad_token = "[PAD]"  # Set pad token for tokenizer
    encodings_dict = tokenizer.batch_encode_plus(prompt, padding=True)
    torch_dtype = torch.float16 if use_fp16 else torch.float32

    # Move inputs from tokenizer to on-device memory
    input_ids = torch.tensor(encodings_dict["input_ids"], device=device, dtype=torch.int64)
    attention_mask = torch.tensor(encodings_dict["attention_mask"], device=device, dtype=torch.int64)
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    inputs = {
        "input_ids": input_ids.contiguous(),
        "attention_mask": attention_mask.contiguous(),
        "position_ids": position_ids.contiguous(),
    }

    # Pre-allocate on-device memory for past_key_values (past KV cache)
    # Share on-device memory if use_buffer_share is True
    batch_size, sequence_length = input_ids.shape
    max_sequence_length = config.max_position_embeddings
    num_heads, head_size = config.num_key_value_heads, config.hidden_size // config.num_attention_heads
    for i in range(config.num_hidden_layers):
        past_key = torch.zeros(batch_size, num_heads, max_sequence_length if use_buffer_share else 0, head_size, device=device, dtype=torch_dtype)
        past_value = torch.zeros(batch_size, num_heads, max_sequence_length if use_buffer_share else 0, head_size, device=device, dtype=torch_dtype)
        inputs.update({
            f"past_key_values.{i}.key": past_key.contiguous(),
            f"past_key_values.{i}.value": past_value.contiguous()
        })
    
    # Pre-allocate on-device memory for logits
    logits = torch.zeros(batch_size, sequence_length, config.vocab_size, device=device, dtype=torch_dtype)
    outputs = {
        "logits": logits.contiguous()
    }

    # Pre-allocate on-device memory for present KV cache if use_buffer_share is False
    if not use_buffer_share:
        for i in range(config.num_hidden_layers):
            present_key = torch.zeros(batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype)
            present_value = torch.zeros(batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype)
            outputs.update({
                f"present.{i}.key": present_key.contiguous(),
                f"present.{i}.value": present_value.contiguous()
            })

    return inputs, outputs

In [18]:
inputs, outputs = get_initial_inputs_and_outputs(config, tokenizer, prompt, device, use_fp16, use_buffer_share)
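The `position_ids` line in `get_initial_inputs_and_outputs` numbers the real tokens of each prompt from 0 and fills padding positions with 1. The plain-Python sketch below (a hypothetical helper, not part of the notebook's code) mirrors `attention_mask.cumsum(-1) - 1` followed by `masked_fill_(attention_mask == 0, 1)` so you can check the result on a small right-padded batch.

```python
def position_ids_from_mask(attention_mask):
    """Mirror `attention_mask.cumsum(-1) - 1` followed by `masked_fill_(attention_mask == 0, 1)`."""
    all_ids = []
    for row in attention_mask:
        running = 0
        ids = []
        for m in row:
            running += m
            ids.append(running - 1 if m else 1)  # padding positions get the filler value 1
        all_ids.append(ids)
    return all_ids


# Two prompts padded on the right to length 5
print(position_ids_from_mask([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]))
# [[0, 1, 2, 1, 1], [0, 1, 2, 3, 4]]
```

With right padding, the last real token of each row sits at index `sum(mask) - 1`, which is the index the generation loop uses to pick the logits for the first generated token.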

Once the on-device memory has been allocated, you can load the LLaMA-2 ONNX model.

In [19]:
sess_options = ort.SessionOptions()
ep = ("CUDAExecutionProvider", {"device_id": device_id})  # change to ep = "CPUExecutionProvider" for CPU
model = ort.InferenceSession(onnx_model_path, sess_options=sess_options, providers=[ep])

2023-11-12 09:18:02.339043617 [W:onnxruntime:, session_state.cc:1162 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2023-11-12 09:18:02.339077992 [W:onnxruntime:, session_state.cc:1164 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.


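Now that the ONNX model is loaded, you need to bind the inputs and outputs to the model. IO binding lets ONNX Runtime read from and write to the pre-allocated on-device tensors directly instead of copying them on every run.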

In [20]:
def apply_io_binding(model, inputs, outputs, use_fp16, use_buffer_share):
    # Check that all model inputs will be provided
    model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
    user_inputs = set(inputs.keys())
    missing_inputs = model_inputs - user_inputs
    if len(missing_inputs):
        print(f"The following model inputs are missing: {missing_inputs}")
        raise Exception("There are missing inputs to the model. Please add them and try again.")

    # Remove unnecessary inputs from model inputs
    unnecessary_inputs = user_inputs - model_inputs
    if len(unnecessary_inputs):
        for unnecessary_input in unnecessary_inputs:
            print(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
            del inputs[unnecessary_input]

    # Bind inputs/outputs to IO binding
    io_binding = model.io_binding()
    device = None
    pt_to_np = {
        "torch.int64": np.int64,
        "torch.float32": np.float32,
        "torch.float16": np.float16
    }

    for k, v in inputs.items():
        io_binding.bind_input(
            name=k,
            device_type=v.device.type,
            device_id=0 if v.device.type == "cpu" else v.device.index,
            element_type=pt_to_np[repr(v.dtype)],
            shape=tuple(v.shape),
            buffer_ptr=v.data_ptr()
        )
        device = v.device

    for output in model.get_outputs():
        name = output.name
        if use_buffer_share and "present" in name:
            # Bind KV cache outputs to KV cache inputs
            v = inputs[name.replace("present", "past_key_values")]
            io_binding.bind_output(
                name=name,
                device_type=v.device.type,
                device_id=v.device.index,
                element_type=pt_to_np[repr(v.dtype)],  # match the dtype of the shared KV cache buffer
                shape=tuple(v.shape),
                buffer_ptr=v.data_ptr()
            )
        else:
            v = outputs[name]
            io_binding.bind_output(
                name=name,
                device_type=device.type,
                device_id=0 if device.type == "cpu" else device.index,
                element_type=(np.float16 if use_fp16 else np.float32),
                shape=tuple(v.shape),
                buffer_ptr=v.data_ptr()
            )

    return io_binding
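To see which buffer each output is written into, the routing logic of `apply_io_binding` can be isolated from ONNX Runtime. This is a minimal sketch with a hypothetical helper and a hypothetical single-layer list of output names; it assumes output names follow the `present.{i}.key` / `past_key_values.{i}.key` pattern used in `get_initial_inputs_and_outputs`.

```python
def kv_binding_plan(output_names, use_buffer_share):
    """Map each model output to the dict and key whose tensor it is written into,
    following the same rule as apply_io_binding."""
    plan = {}
    for name in output_names:
        if use_buffer_share and "present" in name:
            # Shared buffer: the present KV output is written into the past KV input tensor
            plan[name] = ("inputs", name.replace("present", "past_key_values"))
        else:
            # Separate buffer pre-allocated in the outputs dict
            plan[name] = ("outputs", name)
    return plan


names = ["logits", "present.0.key", "present.0.value"]  # hypothetical single-layer output list
print(kv_binding_plan(names, use_buffer_share=True))
```

With `use_buffer_share=False`, every output maps to its own entry in `outputs`, which is why the generation loop has to copy the present KV cache into the past KV cache inputs and re-allocate the present buffers after each step.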

You are almost ready to run inference with your ONNX model. You need to store the token ids that are generated and keep track of each batch entry to see whether it has completed generation or not.

In [21]:
all_token_ids = inputs["input_ids"].clone()  # store prompt token ids + generated token ids for transcription at the end
batch_size, sequence_length = all_token_ids.shape
max_sequence_length = config.max_position_embeddings
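num_heads, head_size = config.num_key_value_heads, config.hidden_size // config.num_attention_heads
torch_dtype = torch.float16 if use_fp16 else torch.float32  # used to re-allocate the present KV cache when buffer sharing is off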

current_length = sequence_length  # keep track of current length (prompt length + generation length)
has_eos = torch.zeros(batch_size, device=device, dtype=torch.bool)  # keep track of each batch entry's status and whether it has reached end-of-sequence (EOS) or not

Now you can run inference and generate tokens.

In [22]:
while current_length <= max_length:
    # Run inference
    io_binding = apply_io_binding(model, inputs, outputs, use_fp16, use_buffer_share)
    io_binding.synchronize_inputs()
    model.run_with_iobinding(io_binding)
    io_binding.synchronize_outputs()

    # Sample/choose next token with argmax (greedy search)
    if outputs["logits"].shape[1] > 1:
        prompt_end_indices = inputs["attention_mask"].sum(1) - 1
        idxs = prompt_end_indices.unsqueeze(dim=1).repeat(1, config.vocab_size).view(batch_size, 1, config.vocab_size)
        next_token_logits = torch.gather(outputs["logits"], 1, idxs).squeeze(1)  # squeeze only the sequence dim so batch_size=1 keeps its batch dim
    else:
        next_token_logits = outputs["logits"][:, -1, :]
    next_tokens = torch.argmax(next_token_logits, dim=-1)

    # Check if we previously reached EOS token id or if generated token id is EOS token id
    has_eos = has_eos | (next_tokens == tokenizer.eos_token_id)  # parentheses needed: `|` binds tighter than `==`

    # Determine which new tokens to add to list of all token ids
    # Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't)
    tokens_to_add = next_tokens.masked_fill(has_eos, tokenizer.eos_token_id).reshape([batch_size, 1])
    all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)

    # Return early if:
    # 1) all batch entries have reached EOS token id or 
    # 2) we have reached the max length of a batch entry (prompt length + generation length) or
    # 3) max sequence length that the model can support
    current_length += 1
    if torch.all(has_eos) or current_length > max_length or current_length > max_sequence_length:
        break

    # Update inputs for next inference run
    inputs["input_ids"] = tokens_to_add
    inputs["position_ids"] = torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1
    inputs["attention_mask"] = torch.cat([inputs["attention_mask"], (~has_eos).to(torch.int64).reshape(batch_size, 1)], 1)

    # Set logits to zeros for next inference run and re-use memory buffer
    if outputs["logits"].shape[1] != 1:
        outputs["logits"] = outputs["logits"][:, :1, :].contiguous()
    outputs["logits"].zero_()

    # If buffer sharing is off, pass the present KV cache from previous iteration as the past KV cache for next iteration
    if not use_buffer_share:
        for i in range(config.num_hidden_layers):
            inputs[f"past_key_values.{i}.key"] = outputs[f"present.{i}.key"]
            inputs[f"past_key_values.{i}.value"] = outputs[f"present.{i}.value"]

        new_sequence_length = inputs["attention_mask"].shape[1]
        for i in range(config.num_hidden_layers):
            present_key = torch.zeros(batch_size, num_heads, new_sequence_length, head_size, device=device, dtype=torch_dtype)
            present_value = torch.zeros(batch_size, num_heads, new_sequence_length, head_size, device=device, dtype=torch_dtype)
            outputs.update({
                f"present.{i}.key": present_key.contiguous(),
                f"present.{i}.value": present_value.contiguous()
            })
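The end-of-sequence bookkeeping in the loop can be checked without a model. The sketch below is a hypothetical plain-Python version of one greedy step on small lists: argmax per row, mark rows that are finished or just produced EOS, keep emitting EOS for finished rows, and build the new attention-mask column. In the PyTorch version the EOS comparison must be parenthesized because `|` binds tighter than `==`; here `or` has lower precedence than `==`, so no parentheses are needed.

```python
def greedy_step(next_token_logits, has_eos, eos_token_id):
    """One greedy-search bookkeeping step for a batch, using plain Python lists.

    Returns the tokens to append, the updated end-of-sequence flags, and the new
    attention-mask column.
    """
    # Argmax over the vocabulary for each batch entry
    next_tokens = [max(range(len(row)), key=row.__getitem__) for row in next_token_logits]
    # An entry is finished if it already was, or if it just produced EOS
    has_eos = [done or tok == eos_token_id for done, tok in zip(has_eos, next_tokens)]
    # Finished entries keep emitting EOS so every row grows by one token
    tokens_to_add = [eos_token_id if done else tok for done, tok in zip(has_eos, next_tokens)]
    # New attention-mask column: 1 for active entries, 0 for finished ones
    mask_column = [0 if done else 1 for done in has_eos]
    return tokens_to_add, has_eos, mask_column


# Batch of 2 over a toy 3-token vocabulary where token id 2 is EOS
print(greedy_step([[0.1, 0.9, 0.0], [0.8, 0.1, 0.1]], [False, True], eos_token_id=2))
```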

Once generation is complete, you can batch decode all of the token ids to see what the model produced.

In [23]:
tokenizer.batch_decode(all_token_ids, skip_special_tokens=True)

['ONNX Runtime is 100% open source and free to use.\nONNX Runtime is a cross-platform runtime that can be used to run ONNX models on a variety of platforms.\nONNX Runtime is a cross-',
 "I want to book a vacation to Hawaii. First, I need to 1) find a good travel agent, 2) find a good hotel, and 3) find a good flight.\nI've been to Hawaii before, so I know what I like. I'm looking for",
 'A good workout routine is 30 minutes of cardio and 30 minutes of strength training.\nA good workout routine is 30 minutes of cardio and 30 minutes of strength training. This is the best way to get in shape',
 'How are astronauts launched into space? 1. How are astronauts launched into space? 2. How do astronauts get to the moon? 3. How do astronauts get to the moon? 4. How do astronauts get to']

Congratulations! You have successfully run an end-to-end example using LLaMA-2 in ONNX Runtime. For your convenience, the above code blocks to run the end-to-end example are combined into one code block.

In [None]:
from transformers import LlamaConfig, LlamaTokenizer
import numpy as np
import onnxruntime as ort
import torch


def get_initial_inputs_and_outputs(config, tokenizer, prompt, device, use_fp16, use_buffer_share):
    tokenizer.pad_token = "[PAD]"
    encodings_dict = tokenizer.batch_encode_plus(prompt, padding=True)
    torch_dtype = torch.float16 if use_fp16 else torch.float32

    input_ids = torch.tensor(encodings_dict["input_ids"], device=device, dtype=torch.int64)
    attention_mask = torch.tensor(encodings_dict["attention_mask"], device=device, dtype=torch.int64)
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    
    inputs = {
        "input_ids": input_ids.contiguous(),
        "attention_mask": attention_mask.contiguous(),
        "position_ids": position_ids.contiguous(),
    }

    batch_size, sequence_length = input_ids.shape
    max_sequence_length = config.max_position_embeddings
    num_heads, head_size = config.num_key_value_heads, config.hidden_size // config.num_attention_heads
    for i in range(config.num_hidden_layers):
        past_key = torch.zeros(batch_size, num_heads, max_sequence_length if use_buffer_share else 0, head_size, device=device, dtype=torch_dtype)
        past_value = torch.zeros(batch_size, num_heads, max_sequence_length if use_buffer_share else 0, head_size, device=device, dtype=torch_dtype)
        inputs.update({
            f"past_key_values.{i}.key": past_key.contiguous(),
            f"past_key_values.{i}.value": past_value.contiguous()
        })

    logits = torch.zeros(batch_size, sequence_length, config.vocab_size, device=device, dtype=torch_dtype)
    outputs = {
        "logits": logits.contiguous()
    }
    if not use_buffer_share:
        for i in range(config.num_hidden_layers):
            present_key = torch.zeros(batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype)
            present_value = torch.zeros(batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype)
            outputs.update({
                f"present.{i}.key": present_key.contiguous(),
                f"present.{i}.value": present_value.contiguous()
            })

    return inputs, outputs


def apply_io_binding(model, inputs, outputs, use_fp16, use_buffer_share):
    # Check that all model inputs will be provided
    model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
    user_inputs = set(inputs.keys())
    missing_inputs = model_inputs - user_inputs
    if len(missing_inputs):
        print(f"The following model inputs are missing: {missing_inputs}")
        raise Exception("There are missing inputs to the model. Please add them and try again.")

    # Remove unnecessary inputs from model inputs
    unnecessary_inputs = user_inputs - model_inputs
    if len(unnecessary_inputs):
        for unnecessary_input in unnecessary_inputs:
            print(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
            del inputs[unnecessary_input]

    # Bind inputs/outputs to IO binding
    io_binding = model.io_binding()
    device = None

    for k, v in inputs.items():
        io_binding.bind_input(
            name=k,
            device_type=v.device.type,
            device_id=0 if v.device.type == "cpu" else v.device.index,
            element_type=pt_to_np[repr(v.dtype)],
            shape=tuple(v.shape),
            buffer_ptr=v.data_ptr()
        )
        device = v.device

    for output in model.get_outputs():
        name = output.name
        if use_buffer_share and "present" in name:
            # Bind KV cache outputs to KV cache inputs
            v = inputs[name.replace("present", "past_key_values")]
            io_binding.bind_output(
                name=name,
                device_type=v.device.type,
                device_id=0 if v.device.type == "cpu" else v.device.index,
                element_type=pt_to_np[repr(v.dtype)],  # match the dtype of the shared KV cache buffer
                shape=tuple(v.shape),
                buffer_ptr=v.data_ptr()
            )
        else:
            v = outputs[name]
            io_binding.bind_output(
                name=name,
                device_type=device.type,
                device_id=0 if device.type == "cpu" else device.index,
                element_type=(np.float16 if use_fp16 else np.float32),
                shape=tuple(v.shape),
                buffer_ptr=v.data_ptr()
            )

    return io_binding

def main():
    # User settings
    model_name = "meta-llama/Llama-2-7b-hf"
    onnx_model_path = "./llama2-7b-fp16-gqa/rank_0_Llama-2-7b-hf_decoder_merged_model_fp16.onnx"
    cache_dir = "./cache_dir"  # Hugging Face cache directory for the config and tokenizer
    use_fp16 = True  # True when KV cache inputs/outputs are in float16
    use_buffer_share = True  # True when --use_gqa was passed during export

    prompt = ["ONNX Runtime is ", "I want to book a vacation to Hawaii. First, I need to ", "A good workout routine is ", "How are astronauts launched into space? "]
    max_length = 64  # max(prompt length + generation length)

    device_id = 0
    device = torch.device(f"cuda:{device_id}")  # Change to torch.device("cpu") if running on CPU

    config = LlamaConfig.from_pretrained(model_name, use_auth_token=True, cache_dir=cache_dir)
    tokenizer = LlamaTokenizer.from_pretrained(model_name, use_auth_token=True, cache_dir=cache_dir)
    torch_dtype = torch.float16 if use_fp16 else torch.float32

    # Get model and its initial inputs/outputs
    inputs, outputs = get_initial_inputs_and_outputs(config, tokenizer, prompt, device, use_fp16, use_buffer_share)

    sess_options = ort.SessionOptions()
    ep = ("CUDAExecutionProvider", {"device_id": device_id})  # change to ep = "CPUExecutionProvider" for CPU
    model = ort.InferenceSession(onnx_model_path, sess_options=sess_options, providers=[ep])

    all_token_ids = inputs["input_ids"].clone()
    batch_size, sequence_length = all_token_ids.shape
    num_heads, head_size = config.num_key_value_heads, config.hidden_size // config.num_attention_heads

    current_length = sequence_length
    has_eos = torch.zeros(batch_size, device=device, dtype=torch.bool)

    while current_length <= max_length:
        # Run inference
        io_binding = apply_io_binding(model, inputs, outputs, use_fp16, use_buffer_share)
        io_binding.synchronize_inputs()
        model.run_with_iobinding(io_binding)
        io_binding.synchronize_outputs()

        # Sample with argmax (greedy search)
        if outputs["logits"].shape[1] > 1:
            prompt_end_indices = inputs["attention_mask"].sum(1) - 1
            idxs = prompt_end_indices.unsqueeze(dim=1).repeat(1, config.vocab_size).view(batch_size, 1, config.vocab_size)
            next_token_logits = torch.gather(outputs["logits"], 1, idxs).squeeze(1)  # squeeze only the sequence dim so batch_size=1 still works
        else:
            next_token_logits = outputs["logits"][:, -1, :]
        next_tokens = torch.argmax(next_token_logits, dim=-1)

        # Mark batch entries that either reached EOS earlier or just generated the EOS token id.
        # The parentheses are required: in Python, `|` binds tighter than `==`.
        has_eos = has_eos | (next_tokens == tokenizer.eos_token_id)

        # Determine which new tokens to add to the list of all token ids.
        # In a ragged batch some entries finish before others; finished entries keep receiving the EOS token id.
        tokens_to_add = next_tokens.masked_fill(has_eos, tokenizer.eos_token_id).reshape([batch_size, 1])
        all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)

        # Stop once every batch entry has reached the EOS token id or max_length is hit
        current_length += 1
        if torch.all(has_eos) or current_length > max_length:
            break

        # Update inputs for next inference run
        inputs["input_ids"] = tokens_to_add
        inputs["position_ids"] = torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1
        inputs["attention_mask"] = torch.cat([inputs["attention_mask"], (~has_eos).to(torch.int64).reshape(batch_size, 1)], 1)

        # Set logits to zeros for next inference run and re-use memory buffer
        if outputs["logits"].shape[1] != 1:
            outputs["logits"] = outputs["logits"][:, :1, :].contiguous()
        outputs["logits"].zero_()

        if not use_buffer_share:
            for i in range(config.num_hidden_layers):
                inputs[f"past_key_values.{i}.key"] = outputs[f"present.{i}.key"]
                inputs[f"past_key_values.{i}.value"] = outputs[f"present.{i}.value"]

            new_sequence_length = inputs["attention_mask"].shape[1]
            for i in range(config.num_hidden_layers):
                present_key = torch.zeros(batch_size, num_heads, new_sequence_length, head_size, device=device, dtype=torch_dtype)
                present_value = torch.zeros(batch_size, num_heads, new_sequence_length, head_size, device=device, dtype=torch_dtype)
                outputs.update({
                    f"present.{i}.key": present_key.contiguous(),
                    f"present.{i}.value": present_value.contiguous()
                })

    # Batch decoding at end of generation
    print(tokenizer.batch_decode(all_token_ids, skip_special_tokens=True))

pt_to_np = {
    "torch.int64": np.int64,
    "torch.float32": np.float32,
    "torch.float16": np.float16
}
main()
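
The greedy step in the loop above does two pieces of bookkeeping. On the first run it reads each batch entry's logits at its last real prompt token. The `attention_mask.sum(1) - 1` indexing assumes the prompts are right-padded. On every run it keeps entries that have already finished emitting EOS. Below is a minimal, framework-free sketch of that logic under those assumptions. Plain Python lists stand in for torch tensors, and the helper names `last_prompt_logits` and `greedy_step` are hypothetical, not part of ONNX Runtime or Transformers.

```python
# Framework-free sketch of the greedy decoding step used in the loop above.
# Plain nested lists stand in for torch tensors; helper names are hypothetical.


def last_prompt_logits(logits, attention_mask):
    """Return each batch entry's logits at its last non-padding position.

    Assumes right padding, so the last real token is at index sum(mask) - 1.
    logits: [batch][seq][vocab]; attention_mask: [batch][seq] of 0/1.
    """
    return [row[sum(mask) - 1] for row, mask in zip(logits, attention_mask)]


def greedy_step(next_token_logits, has_eos, eos_token_id):
    """Pick the argmax token per entry and keep finished entries emitting EOS."""
    # Argmax over the vocabulary for each batch entry
    next_tokens = [max(range(len(row)), key=row.__getitem__) for row in next_token_logits]
    # Logical OR of "already finished" and "just produced EOS"
    has_eos = [done or tok == eos_token_id for done, tok in zip(has_eos, next_tokens)]
    # Finished entries get EOS so every row grows by exactly one token
    tokens_to_add = [eos_token_id if done else tok for done, tok in zip(has_eos, next_tokens)]
    return tokens_to_add, has_eos


# Why the tensor version needs parentheses: `|` binds tighter than `==`,
# so `a | b == c` parses as `(a | b) == c`.
PRECEDENCE_PITFALL = (True | 5 == 2, True | (5 == 2))
```

In the tensor version, `has_eos` and `next_tokens == eos_token_id` are boolean tensors. The comparison must therefore be parenthesized before it is combined with `|`. The `PRECEDENCE_PITFALL` tuple shows how Python groups the unparenthesized form: it evaluates to `(False, True)`.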
    

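Note that you can also disable buffer sharing when using GroupQueryAttention by setting `use_buffer_share = False` in `main()`. Inference still works, but performance is worse. With buffer sharing, the past and present KV caches share one preallocated buffer of size `max_sequence_length`. Without it, the loop allocates new `present` buffers on every generation step.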
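To get a feel for the difference, the back-of-the-envelope sketch below counts only the KV cache bytes that the Python loop allocates. It ignores ONNX Runtime's internal memory, so it is not a benchmark. The model shape is LLaMA-2-7B in float16: 32 layers, 32 key/value heads, and a head size of 4096 // 32 = 128. The batch size of 4 and `max_length = 64` match the settings above. The prompt length of 16 tokens is a hypothetical value chosen for illustration. The function names are also illustrative.

```python
# Back-of-the-envelope KV cache arithmetic (illustrative, not a benchmark).


def kv_cache_bytes(batch_size, seq_len, num_layers, num_kv_heads, head_size, bytes_per_elem):
    """Bytes for all key and value tensors, each shaped [batch, kv_heads, seq_len, head_size]."""
    return 2 * num_layers * batch_size * num_kv_heads * seq_len * head_size * bytes_per_elem


def cumulative_kv_bytes(prompt_len, max_length, use_buffer_share, **shape):
    """Total KV cache bytes allocated by the generation loop over one run."""
    if use_buffer_share:
        # One buffer sized for max_length, updated in place on every step
        return kv_cache_bytes(seq_len=max_length, **shape)
    # New present buffers on each step, sized prompt_len, prompt_len + 1, ..., max_length
    return sum(kv_cache_bytes(seq_len=n, **shape) for n in range(prompt_len, max_length + 1))


# LLaMA-2-7B in float16, batch of 4 prompts
LLAMA2_7B_FP16 = dict(batch_size=4, num_layers=32, num_kv_heads=32, head_size=128, bytes_per_elem=2)

shared = cumulative_kv_bytes(16, 64, True, **LLAMA2_7B_FP16)     # hypothetical 16-token prompt
unshared = cumulative_kv_bytes(16, 64, False, **LLAMA2_7B_FP16)
print(f"shared:   {shared / 2**20:.0f} MiB allocated once")
print(f"unshared: {unshared / 2**20:.0f} MiB allocated across all steps")
```

With these assumptions, buffer sharing allocates 128 MiB once. Without it, the loop allocates 3920 MiB cumulatively across steps, and each allocation also triggers a fresh IO binding.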

Without buffer sharing, you can also run your model using [Hugging Face's Optimum](https://github.com/huggingface/optimum). Here's how you can use Optimum.

```python
from transformers import LlamaConfig, LlamaTokenizer
from optimum.onnxruntime import ORTModelForCausalLM
import torch

# User settings
model_name = "meta-llama/Llama-2-7b-hf"
onnx_model_dir = "./llama2-7b-fp16-gqa/"
cache_dir = "./cache_dir"

device_id = 0
device = torch.device(f"cuda:{device_id}")  # Change to torch.device("cpu") if running on CPU

ep = "CUDAExecutionProvider"  # Change to "CPUExecutionProvider" if running on CPU
ep_options = {"device_id": device_id}

prompt = ["ONNX Runtime is ", "I want to book a vacation to Hawaii. First, I need to ", "A good workout routine is ", "How are astronauts launched into space? "]
max_length = 64  # max(prompt length + generation length)

config = LlamaConfig.from_pretrained(model_name, use_auth_token=True, cache_dir=cache_dir)
config.save_pretrained(onnx_model_dir)  # Save config file in ONNX model directory
tokenizer = LlamaTokenizer.from_pretrained(model_name, use_auth_token=True, cache_dir=cache_dir)
tokenizer.pad_token = "[PAD]"  # LLaMA-2's tokenizer has no pad token; one is needed to pad a batch of prompts

model = ORTModelForCausalLM.from_pretrained(
    onnx_model_dir,
    use_auth_token=True,
    use_io_binding=True,
    provider=ep,
    provider_options=ep_options,  # Comment out if running on CPU
)
inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device)

print("-------------")
generate_ids = model.generate(**inputs, do_sample=False, max_length=max_length)
generated_text = tokenizer.batch_decode(generate_ids, skip_special_tokens=True)
print(generated_text)
print("-------------")
```