# LLaMA Model Optimization (microsoft#18021)
### Description
This PR contains fusion-level and kernel-level optimizations for [Meta's
LLaMA-2](https://blogs.microsoft.com/blog/2023/07/18/microsoft-and-meta-expand-their-ai-partnership-with-llama-2-on-azure-and-windows/).

Some of the added optimizations include:

- SimplifiedLayerNorm changes
  - Fusions for multiple variants
- SkipSimplifiedLayerNorm changes
  - Kernel support for CPU
- Rotary embeddings (previously did not exist)
  - Fusions for multiple variants
  - CPU and CUDA kernels
  - Support for interleaved and non-interleaved rotation in the same kernels
  - Optimized cos/sin cache that requires half of its originally exported size (see the cache sketch after this list)
    - Reduced from `(max_sequence_length, head_size)` to `(max_sequence_length, head_size / 2)`
- Multi-head attention
  - Support for 2D and 3D attention masks
- Group query attention (for FP16 CUDA and INT4 CUDA)
  - Integration with flash attention v2 and past-present buffer sharing
  - Removes the need for the `attention_mask` input, as masking is handled inside the kernel
- 4-bit quantization
  - `block_size` parameter is available for customization
- Support for the new changes in the [Microsoft version](https://github.com/microsoft/Llama-2-Onnx)
- Support for combinations of the variants below (e.g., export with the ORT version and run with Optimum)
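
The halved cache works because RoPE rotates coordinate pairs that share one angle, so only `head_size / 2` distinct cos/sin values are needed per position. A minimal NumPy sketch of building such a cache, assuming the standard RoPE base of 10000 (`build_rope_cache` is an illustrative helper, not part of this PR):

```python
import numpy as np

def build_rope_cache(max_sequence_length, head_size, base=10000.0):
    # One inverse frequency per rotated coordinate pair, so the caches
    # need only head_size // 2 columns instead of head_size.
    inv_freq = 1.0 / (base ** (np.arange(0, head_size, 2) / head_size))
    angles = np.outer(np.arange(max_sequence_length), inv_freq)
    # Each cache has shape (max_sequence_length, head_size // 2)
    return np.cos(angles), np.sin(angles)
```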

Supported variants of LLaMA-2 include:
- [ORT version](https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/llama)
  - Produces one ONNX file that is already optimized (and quantized if requested)
  - Integrates with Optimum
- [Another Microsoft version](https://github.com/microsoft/Llama-2-Onnx)
  - Already exported and available off-the-shelf
  - Faster versions of those models will be uploaded there soon
- [Hugging Face version](https://huggingface.co/meta-llama)
  - Models that end with `-hf`
- Some older and current versions of [`transformers`](https://github.com/huggingface/transformers) and [`optimum`](https://github.com/huggingface/optimum) that export the model to ONNX differently
  - Note that while some older versions are supported, it is recommended to use the latest package versions.

### Usage

To use the optimizations, see `README.md` for details. Note the various `requirements.txt` files, which list the package versions recommended for these changes.

To run the ORT transformer optimizer separately, run the script as
follows:
```
$ cd onnxruntime/onnxruntime/python/tools/transformers/
$ python3 optimizer.py --input <filename>.onnx --output <filename>.onnx --model_type gpt2 --num_heads <number of attention heads> --hidden_size <attention hidden size> --use_external_data_format --opt_level 0
```
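
Equivalently, the optimizer can be driven from Python. A hedged sketch using the `onnxruntime.transformers` API (exact option names may vary across releases; the head count and hidden size shown are the LLaMA-2 7B values):

```python
from onnxruntime.transformers import optimizer

# LLaMA-2 7B uses 32 attention heads and a hidden size of 4096
opt_model = optimizer.optimize_model(
    "model.onnx",
    model_type="gpt2",
    num_heads=32,
    hidden_size=4096,
    opt_level=0,
)
opt_model.save_model_to_file("model_opt.onnx", use_external_data_format=True)
```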

### Motivation and Context
This PR helps address the following issues:
- microsoft#14997
- microsoft#16254
- microsoft#17681
- microsoft#17925
- microsoft/onnxruntime-inference-examples#320

This PR uses changes from the following PRs:
- pytorch/pytorch#104468
- pytorch/pytorch#109759
- microsoft#17020
- microsoft#17674
- microsoft#17890
- microsoft#17920
- huggingface/transformers#26162
- huggingface/optimum#1257
- huggingface/optimum#1289
- huggingface/optimum#1462

### New TorchDynamo Exporter (experimental)

This PR uses changes from the following issues and PRs to begin
supporting the [new TorchDynamo
exporter](https://pytorch.org/docs/stable/onnx.html#torchdynamo-based-onnx-exporter):
- huggingface/transformers#26307
- pytorch/pytorch#104903
- pytorch/pytorch#105040
- microsoft/onnxscript#847
- microsoft/onnxscript#862
- microsoft/onnxscript#493
kunal-vaishnavi authored and kleiti committed Mar 22, 2024
1 parent e68bb7b commit 07bfbbb
Showing 49 changed files with 5,897 additions and 563 deletions.
51 changes: 50 additions & 1 deletion docs/ContribOperators.md
@@ -90,6 +90,7 @@ Do not modify directly.*
* <a href="#com.microsoft.RemovePadding">com.microsoft.RemovePadding</a>
* <a href="#com.microsoft.RestorePadding">com.microsoft.RestorePadding</a>
* <a href="#com.microsoft.Rfft">com.microsoft.Rfft</a>
* <a href="#com.microsoft.RotaryEmbedding">com.microsoft.RotaryEmbedding</a>
* <a href="#com.microsoft.SampleOp">com.microsoft.SampleOp</a>
* <a href="#com.microsoft.Sampling">com.microsoft.Sampling</a>
* <a href="#com.microsoft.SkipLayerNormalization">com.microsoft.SkipLayerNormalization</a>
@@ -2834,7 +2835,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dt><tt>bias</tt> (optional) : T</dt>
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
<dt><tt>key_padding_mask</tt> (optional) : M</dt>
<dd>Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)</dd>
<dd>Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), or (batch_size, sequence_length, total_sequence_length)</dd>
<dt><tt>relative_position_bias</tt> (optional) : T</dt>
<dd>relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)</dd>
<dt><tt>past_key</tt> (optional) : T</dt>
@@ -4796,6 +4797,54 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
</dl>


### <a name="com.microsoft.RotaryEmbedding"></a><a name="com.microsoft.rotaryembedding">**com.microsoft.RotaryEmbedding**</a>

RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). Positions are represented as rotation matrices
that are applied to the query and key before their inner product is taken.

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>interleaved</tt> : int</dt>
<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1.0.</dd>
</dl>

#### Inputs

<dl>
<dt><tt>input</tt> : T</dt>
<dd>3D tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>position_ids</tt> : M</dt>
<dd>1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)</dd>
<dt><tt>cos_cache</tt> : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
<dt><tt>sin_cache</tt> : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D tensor with shape (batch_size, sequence_length, hidden_size)</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
<dt><tt>M</tt> : tensor(int64)</dt>
<dd>Constrain position_ids type to integer tensors.</dd>
</dl>
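
For reference, a minimal NumPy sketch of the semantics specified above (illustrative, not ORT source; it mirrors the CPU kernel added later in this commit):

```python
import numpy as np

def rotary_embedding_ref(x, position_ids, cos_cache, sin_cache, interleaved=False):
    """x: (batch, seq, hidden); cos/sin caches: (max_seq, head_size // 2)."""
    batch, seq, hidden = x.shape
    half = cos_cache.shape[1]
    head_size = 2 * half
    heads = x.reshape(batch, seq, hidden // head_size, head_size)
    out = np.empty_like(heads)
    for b in range(batch):
        for s in range(seq):
            # position_ids holds a start offset (shape (1,)) or per-token ids
            pos = int(position_ids[0]) + s if position_ids.ndim == 1 else int(position_ids[b, s])
            cos, sin = cos_cache[pos], sin_cache[pos]
            for n in range(heads.shape[2]):
                v = heads[b, s, n]
                for i in range(head_size):
                    if interleaved:  # rotated pairs are (0, 1), (2, 3), ...
                        idx = (i // 2) % half
                        sign = -1.0 if i % 2 == 0 else 1.0
                        j = i + 1 if i % 2 == 0 else i - 1
                    else:  # rotated pairs are (i, i + half)
                        idx = i % half
                        sign = -1.0 if i < half else 1.0
                        j = (i + half) % head_size
                    out[b, s, n, i] = v[i] * cos[idx] + sign * v[j] * sin[idx]
    return out.reshape(batch, seq, hidden)
```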


### <a name="com.microsoft.SampleOp"></a><a name="com.microsoft.sampleop">**com.microsoft.SampleOp**</a>

Sample echo operator.
3 changes: 3 additions & 0 deletions docs/OperatorKernels.md
@@ -477,9 +477,11 @@ Do not modify directly.*
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float)|
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|SparseToDenseMatMul|*in* A:**T**<br> *in* B:**T1**<br> *out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
@@ -866,6 +868,7 @@ Do not modify directly.*
|RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -16,7 +16,6 @@

#include <unsupported/Eigen/SpecialFunctions>
#include <vector>
#include <iostream>

using onnxruntime::concurrency::ThreadPool;

14 changes: 11 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -206,6 +206,7 @@ Status CheckInputs(const T* query,
}
}

int total_sequence_length = past_sequence_length + kv_sequence_length;
AttentionMaskType mask_type = AttentionMaskType::MASK_NONE;
if (key_padding_mask != nullptr) {
mask_type = AttentionMaskType::MASK_UNKNOWN;
@@ -216,13 +217,21 @@
} else if (mask_dims[0] == static_cast<int64_t>(3) * static_cast<int64_t>(batch_size) + static_cast<int64_t>(2)) {
mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START;
}
} else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) && mask_dims[1] == static_cast<int64_t>(kv_sequence_length)) {
} else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) &&
mask_dims[1] == static_cast<int64_t>(kv_sequence_length)) {
mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
} else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) &&
mask_dims[1] == static_cast<int64_t>(total_sequence_length)) {
mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
} else if (mask_dims.size() == 3 && mask_dims[0] == static_cast<int64_t>(batch_size) &&
mask_dims[1] == static_cast<int64_t>(sequence_length) &&
mask_dims[2] == static_cast<int64_t>(total_sequence_length)) {
mask_type = AttentionMaskType::MASK_3D_ATTENTION;
}

if (mask_type == AttentionMaskType::MASK_UNKNOWN) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key_padding_mask' shape shall be (batch_size) or (batch_size, kv_sequence_length)");
"Input 'key_padding_mask' shape shall be 1D, 2D, or 3D");
}
}

@@ -257,7 +266,6 @@ Status CheckInputs(const T* query,
}
}

int total_sequence_length = past_sequence_length + kv_sequence_length;
bool broadcast_res_pos_bias = false;
if (relative_position_bias != nullptr) {
const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims();
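
For illustration, a hedged sketch of the `key_padding_mask` shapes the updated check accepts (NumPy, with assumed dtypes; not ORT test code):

```python
import numpy as np

batch, seq_len, kv_seq_len, past_len = 2, 4, 4, 8
total_len = past_len + kv_seq_len

# (batch_size): right-padded key lengths per batch entry
mask_1d = np.array([kv_seq_len, kv_seq_len - 1], dtype=np.int32)

# (batch_size, kv_sequence_length) or (batch_size, total_sequence_length):
# 1 = attend, 0 = padding
mask_2d_kv = np.ones((batch, kv_seq_len), dtype=np.int32)
mask_2d_total = np.ones((batch, total_len), dtype=np.int32)

# (batch_size, sequence_length, total_sequence_length): full per-query mask,
# e.g. causal over past + new tokens
causal = np.tril(np.ones((total_len, total_len), dtype=np.int32))[-seq_len:]
mask_3d = np.broadcast_to(causal, (batch, seq_len, total_len)).copy()
```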
115 changes: 115 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
@@ -0,0 +1,115 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cpu/bert/rotary_embedding.h"
#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"

#include "core/platform/threadpool.h"

using onnxruntime::concurrency::ThreadPool;
using namespace onnxruntime::contrib::rotary_embedding_helper;

namespace onnxruntime {
namespace contrib {

// These ops are internal-only, so register outside of onnx
ONNX_OPERATOR_TYPED_KERNEL_EX(
    RotaryEmbedding,
    kMSDomain,
    1,
    float,
    kCpuExecutionProvider,
    KernelDefBuilder()
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
        .TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()),
    RotaryEmbedding<float>);

template <typename T>
RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
  scale = info.GetAttrOrDefault<float>("scale", 1.0f);
  interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
}

template <typename T>
Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
  const Tensor* input = context->Input<Tensor>(0);
  const Tensor* position_ids = context->Input<Tensor>(1);
  const Tensor* cos_cache = context->Input<Tensor>(2);
  const Tensor* sin_cache = context->Input<Tensor>(3);

  RotaryParameters parameters = {};
  ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs<Tensor>(input,
                                                                   position_ids,
                                                                   cos_cache,
                                                                   sin_cache,
                                                                   &parameters));

  Tensor* output = context->Output(0, input->Shape());

  if (parameters.sequence_length > parameters.max_sequence_length) {
    // Launch update_cos_sin_cache kernel with scale
    ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported");
  }

  const T* input_src = input->Data<T>();
  const int64_t* pos_ids_data = position_ids->Data<int64_t>();
  const T* cos_cache_data = cos_cache->Data<T>();
  const T* sin_cache_data = sin_cache->Data<T>();
  T* output_dest = output->MutableData<T>();

  const int batch_size = parameters.batch_size;
  const int sequence_length = parameters.sequence_length;
  const int num_heads = parameters.num_heads;
  const int head_size = parameters.head_size;
  const int position_ids_format = parameters.position_ids_format;
  const int half_head_size = head_size / 2;

  AllocatorPtr allocator;
  ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
  auto* tp = context->GetOperatorThreadPool();

  const int loop_len = batch_size * sequence_length * num_heads;
  const double cost = static_cast<double>(head_size);
  ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
    for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
      const int b = static_cast<int>((ptr / num_heads) / sequence_length);
      const int s = static_cast<int>((ptr / num_heads) % sequence_length);
      const int n = static_cast<int>(ptr % num_heads);

      const int block_offset = b * sequence_length * num_heads + s * num_heads + n;
      const int data_offset = block_offset * head_size;

      const T* input_data = input_src + data_offset;
      T* output_data = output_dest + data_offset;

      // Cache is (M, H/2)
      const int position_id = (position_ids_format == 0)
                                  ? static_cast<int>(pos_ids_data[0]) + s
                                  : static_cast<int>(pos_ids_data[b * sequence_length + s]);
      const int cache_offset = position_id * half_head_size;
      const T* cos_data = cos_cache_data + cache_offset;
      const T* sin_data = sin_cache_data + cache_offset;

      int cache_idx = 0;
      T sign = 0;
      int j = 0;
      for (int i = 0; i < head_size; i++) {
        if (interleaved) {
          cache_idx = (i / 2) % half_head_size;
          sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1);
          j = (i % 2 == 0) ? i + 1 : i - 1;  // i - sign
        } else {
          cache_idx = i % half_head_size;
          sign = (i < half_head_size) ? static_cast<T>(-1) : static_cast<T>(1);
          j = (i + half_head_size) % head_size;
        }
        output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
      }
    }
  });

  return Status::OK();
}

}  // namespace contrib
}  // namespace onnxruntime
23 changes: 23 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"

namespace onnxruntime {
namespace contrib {

template <typename T>
class RotaryEmbedding final : public OpKernel {
 public:
  RotaryEmbedding(const OpKernelInfo& info);
  Status Compute(OpKernelContext* context) const override;

 protected:
  float scale;
  bool interleaved;
};

}  // namespace contrib
}  // namespace onnxruntime
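
To exercise the new CPU kernel end to end, one can build a one-node model with the contrib op and run it through ONNX Runtime. A hedged sketch (model construction is illustrative, and `build_rope_cache` is the hypothetical helper sketched in the Description):

```python
import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper

batch, seq, num_heads, head_size, max_seq = 1, 3, 2, 8, 16
hidden = num_heads * head_size

node = helper.make_node(
    "RotaryEmbedding",
    ["input", "position_ids", "cos_cache", "sin_cache"], ["output"],
    domain="com.microsoft", interleaved=0,
)
graph = helper.make_graph(
    [node], "rope_test",
    [
        helper.make_tensor_value_info("input", TensorProto.FLOAT, [batch, seq, hidden]),
        helper.make_tensor_value_info("position_ids", TensorProto.INT64, [batch, seq]),
        helper.make_tensor_value_info("cos_cache", TensorProto.FLOAT, [max_seq, head_size // 2]),
        helper.make_tensor_value_info("sin_cache", TensorProto.FLOAT, [max_seq, head_size // 2]),
    ],
    [helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch, seq, hidden])],
)
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)],
)

cos, sin = build_rope_cache(max_seq, head_size)
sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
(out,) = sess.run(None, {
    "input": np.random.randn(batch, seq, hidden).astype(np.float32),
    "position_ids": np.tile(np.arange(seq, dtype=np.int64), (batch, 1)),
    "cos_cache": cos.astype(np.float32),
    "sin_cache": sin.astype(np.float32),
})
print(out.shape)  # (1, 3, 16)
```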
