Add custom ops Rotary #738

Open · xadupre wants to merge 13 commits into base: main
6 changes: 6 additions & 0 deletions operators/cuda/cuda_ops.cc
@@ -7,7 +7,11 @@
#include "cuda/add_mul.h"
#include "cuda/fast_gelu.h"
#include "cuda/negxplus1.h"
#include "cuda/rotary.h"
#include "cuda/scatter_nd_of_shape.h"
#include "cuda/transpose_cast.h"
#endif

@@ -33,6 +37,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<float>),
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
CustomCudaStructV2("Rotary", contrib::Rotary<float>),
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<float>),
#if ORT_API_VERSION >= 16

@@ -42,6 +47,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<ortc::MFloat16>),
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
CustomCudaStructV2("Rotary", contrib::Rotary<ortc::MFloat16>),
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<ortc::MFloat16>),
CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type)
83 changes: 83 additions & 0 deletions operators/cuda/rotary.h
@@ -0,0 +1,83 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "ocos.h"
#include "rotary_impl.cuh"
#include "ortx_common.h"

namespace contrib {

/**
 * Y = Rotary(X) is equivalent, when side == LEFT, to:
 *
 *     N = X.shape[-1]
 *     Y = X.copy()
 *     Y[..., :N/2] = X[..., N/2:]
 *     Y[..., N/2:] = -X[..., :N/2]
 *
 * and to the opposite when side == RIGHT:
 *
 *     Y[..., :N/2] = -X[..., N/2:]
 *     Y[..., N/2:] = X[..., :N/2]
 *
 * The second input `split` must hold two equal values whose sum is N;
 * unequal splits are rejected at runtime.
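 *
 * Example (side == LEFT): X = [x0, x1, x2, x3] gives Y = [x2, x3, -x0, -x1].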
*/
template <typename T>
struct Rotary {
  template <typename TDict>
  OrtxStatus OnModelAttach(const TDict& dict) {
    std::string empty;
    std::string side = dict.TryToGetAttributeWithDefault("side", empty);
Contributor:

Maybe we should do LeftRotary and RightRotary so that the compute function becomes stateless.

Member Author (xadupre):

I think this is done later here: https://github.com/microsoft/onnxruntime-extensions/pull/738/files#diff-643fdcb552aafbcb9a86bc8d48cab50dea3723b20b3105880e081b2e5b1b8da9R24. I hope the compiler is able to remove the unnecessary code at compile time.

Contributor:

It is more reasonable to split them at the beginning instead of relying on the compiler. If they are two ops, we define two ops. There is no need for a mega op that supports several cases; that will make future improvements harder and more error-prone.

Member Author (xadupre):

That's not the only place I did that. From ORT's point of view there is only one op, Rotary, but we have different implementations depending on the argument value; the same goes for the input types. Rotary(X, side=LEFT) = Neg(Rotary(X, side=RIGHT)). Do you want distinct ONNX names as well?
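
A quick NumPy check of that identity (an illustrative snippet, not part of the PR):

import numpy as np

def rotary(x, side):
    # Swap the two halves of the last axis, negating one of them.
    n = x.shape[-1] // 2
    if side == "left":
        return np.concatenate([x[..., n:], -x[..., :n]], axis=-1)
    return np.concatenate([-x[..., n:], x[..., :n]], axis=-1)

x = np.arange(8.0).reshape(2, 4)
assert np.allclose(rotary(x, "left"), -rotary(x, "right"))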

    if (side == "left") {
      side_ = RotarySide::LEFT;
    } else if (side == "right") {
      side_ = RotarySide::RIGHT;
    } else {
      return {kOrtxErrorInvalidArgument, "side must be 'left' or 'right'."};
    }

    return {};
  }
  OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
Contributor:

How is split related to the formula above?

Member Author (@xadupre, Jun 27, 2024):

The Rotary op replaces something like:

x1, x2 = Split(X, split, axis=-1)
Y = Concat(-x2, x1, axis=-1)

When I implemented this kernel, I found the splits were equal on llama, but I wondered whether they could be unequal, so I chose to leave this parameter in. That way I can still use the optimized kernel and let it fail when the splits are not equal; then I would know that unequal splits must be implemented. I can remove it if you know this case never happens.
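
For reference, a minimal NumPy sketch of that Split/Concat pattern (illustrative only, not part of the PR):

import numpy as np

def rotary_as_split_concat(x, split):
    # Mirrors the pattern above: x1, x2 = Split(X); Y = Concat(-x2, x1).
    x1, x2 = np.split(x, [int(split[0])], axis=-1)
    return np.concatenate([-x2, x1], axis=-1)

x = np.arange(8.0).reshape(2, 4)
print(rotary_as_split_concat(x, np.array([2, 2])))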

Contributor:

Please make sure the exact behavior is described in the comments for this op. I don't see the use and explanation of split in the equations. In addition to math symbols, you can use an ONNX sub-graph to describe them.

Contributor:

I think we can assume an even split and throw otherwise. With split, you would need to introduce a sub-graph to compute the right split from the shape of X. That would slow down ORT.

Member Author (xadupre):

I don't think so. The formula remains the same whether the splits are equal or not.

                     const ortc::Tensor<T>& input,
                     const ortc::Tensor<int64_t>& split,
                     ortc::Tensor<T>& output) const {
    const T* input_data = input.Data();
    auto input_shape = input.Shape();
    T* output_data = output.Allocate(input_shape);
    auto input_length = input.NumberOfElement();
    if (0 == input_length) {
      return {};
    }

    auto shape_split = split.Shape();
    if (shape_split.size() != 1 || shape_split[0] != 2) {
      return {kOrtxErrorInvalidArgument, "Rotary only works when there are two sides."};
    }
    const int64_t* split_data = split.Data();
    if (split_data[0] != split_data[1]) {
      return {kOrtxErrorInvalidArgument, "Only equal splits are allowed."};
    }
    if (split_data[0] * 2 != input_shape[input_shape.size() - 1]) {
      return {kOrtxErrorInvalidArgument, "The sum of the splits is not equal to the last dimension."};
    }

    LaunchRotaryKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
                          input_length,
                          static_cast<int>(input_shape[input_shape.size() - 1]),
                          input_data,
                          split_data,
                          output_data,
                          side_);
    return {};
  }

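  // The split sizes are validated on the host in Compute, so the split input
  // is requested as a CPU tensor.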
  static OrtMemType GetInputMemoryType(size_t input_index) {
    if (input_index == 1)  // split
      return OrtMemType::OrtMemTypeCPUInput;
    return OrtMemType::OrtMemTypeDefault;
  }

 private:
  RotarySide side_;
};

} // namespace contrib
81 changes: 81 additions & 0 deletions operators/cuda/rotary_impl.cu
@@ -0,0 +1,81 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "device_prop.cuh"
#include "utils.cuh"
#include "rotary_impl.cuh"
#include "cuda_type.h"

#ifndef CUDA_LONG
#define CUDA_LONG int32_t
#endif

using namespace Ort::Custom;

template <typename T> __device__ __inline__ T _neg(const T x) { return -x; }

#if __CUDA_ARCH__ < 700
template <> __device__ __inline__ half _neg(const half x) {
  return __float2half(-__half2float(x));
}
#endif

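// One thread per (id, id + half_stride) pair: `last` is the offset inside the
// half-row, and `id` is remapped from the compact index over all half-rows to
// the absolute index of the element in the full tensor.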
template <typename T, RotarySide side>
__global__ void RotaryKernel(T* output_data, const T* input_data, CUDA_LONG half_N, CUDA_LONG half_stride) {
  CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
  if (id >= half_N)
    return;
  CUDA_LONG last = id % half_stride;
  id = (id - last) * 2 + last;
  if (side == RotarySide::RIGHT) {
    output_data[id + half_stride] = input_data[id];
    output_data[id] = _neg(input_data[id + half_stride]);
  } else {
    output_data[id + half_stride] = _neg(input_data[id]);
    output_data[id] = input_data[id + half_stride];
  }
}

template <typename T>
cudaError_t _LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim,
                                const T* input_data, const int64_t* /* split_data */, T* output_data, RotarySide side) {
  if (input_length == 0)
    return cudaGetLastError();
  using TT = typename contrib::CudaT<T>::MappedType;

  CUDA_LONG N = static_cast<CUDA_LONG>(input_length);
  CUDA_LONG stride = static_cast<CUDA_LONG>(last_dim);

  // One thread per pair of elements, so N / 2 threads spread over the grid.
  const int num_threads_per_block = 256;
  const int num_blocks = (N / 2 + num_threads_per_block - 1) / num_threads_per_block;

  switch (side) {
    case RotarySide::LEFT:
      RotaryKernel<TT, RotarySide::LEFT>
          <<<num_blocks, num_threads_per_block, 0, stream>>>(reinterpret_cast<TT*>(output_data),
                                                             reinterpret_cast<const TT*>(input_data),
                                                             N / 2, stride / 2);
      break;
    case RotarySide::RIGHT:
      RotaryKernel<TT, RotarySide::RIGHT>
          <<<num_blocks, num_threads_per_block, 0, stream>>>(reinterpret_cast<TT*>(output_data),
                                                             reinterpret_cast<const TT*>(input_data),
                                                             N / 2, stride / 2);
      break;
  }
  return cudaGetLastError();
}

template <>
cudaError_t LaunchRotaryKernel<float>(cudaStream_t stream, int input_length, int last_dim,
                                      const float* input_data, const int64_t* split_data, float* output_data, RotarySide side) {
  return _LaunchRotaryKernel(stream, input_length, last_dim, input_data, split_data, output_data, side);
}

template <>
cudaError_t LaunchRotaryKernel<ortc::MFloat16>(cudaStream_t stream, int input_length, int last_dim,
                                               const ortc::MFloat16* input_data, const int64_t* split_data,
                                               ortc::MFloat16* output_data, RotarySide side) {
  return _LaunchRotaryKernel(stream, input_length, last_dim, input_data, split_data, output_data, side);
}
15 changes: 15 additions & 0 deletions operators/cuda/rotary_impl.cuh
@@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <cuda.h>
#include <cuda_runtime.h>

enum class RotarySide : int {
  LEFT = 1,
  RIGHT = 2,
};

template <typename T>
cudaError_t LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim,
                               const T* input_data, const int64_t* split_data, T* output_data, RotarySide side);
61 changes: 61 additions & 0 deletions test/cuda/test_cudaops.py
@@ -471,6 +471,67 @@ def test_masked_scatternd_of_shape_standalone_cuda_big(self):
        self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT, True)
        self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16, True)

    def _rotary_cuda(self, itype, side, input_shape=(3, 2, 3, 4)):
        model2 = helper.make_model(
            helper.make_graph(
                [
                    helper.make_node(
                        "Rotary",
                        ["X", "splits"],
                        ["Y"],
                        domain="ai.onnx.contrib",
                        side=side,
                    )
                ],
                "nd",
                [
                    helper.make_tensor_value_info("X", itype, [None, None, None, None]),
                    helper.make_tensor_value_info("splits", TensorProto.INT64, [2]),
                ],
                [helper.make_tensor_value_info("Y", itype, [None, None, None, None])],
            ),
            opset_imports=[
                helper.make_opsetid("", 18),
                helper.make_opsetid("ai.onnx.contrib", 1),
            ],
            ir_version=9,
        )

        dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
        x = (np.arange(np.prod(input_shape)) + 1).reshape(input_shape).astype(dtype)
        splits = np.array([x.shape[-1] // 2, x.shape[-1] // 2], dtype=np.int64)

        expected = x.copy()
        half = x.shape[-1] // 2
        if side == "left":
            expected[:, :, :, :half] = x[:, :, :, half:]
            expected[:, :, :, half:] = -x[:, :, :, :half]
        else:
            expected[:, :, :, :half] = -x[:, :, :, half:]
            expected[:, :, :, half:] = x[:, :, :, :half]

        feeds = dict(X=x, splits=splits)
        opts = _ort.SessionOptions()
        opts.register_custom_ops_library(_get_library_path())
        sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
        got = sess.run(None, feeds)[0]
        assert_almost_equal(expected, got)

    @unittest.skipIf(not has_cuda(), reason="cuda not available")
    def test_rotary_cuda(self):
        self._rotary_cuda(TensorProto.FLOAT, "left")
        self._rotary_cuda(TensorProto.FLOAT, "right")
        self._rotary_cuda(TensorProto.FLOAT16, "left")
        self._rotary_cuda(TensorProto.FLOAT16, "right")

    @unittest.skipIf(not has_cuda(), reason="cuda not available")
    def test_bigger_rotary_cuda(self):
        sh = (2, 2, 1024, 8)
        self._rotary_cuda(TensorProto.FLOAT, "left", input_shape=sh)
        self._rotary_cuda(TensorProto.FLOAT, "right", input_shape=sh)
        self._rotary_cuda(TensorProto.FLOAT16, "left", input_shape=sh)
        self._rotary_cuda(TensorProto.FLOAT16, "right", input_shape=sh)

    def _transpose_cast_cuda(self, itype):
        dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
        itype2 = TensorProto.FLOAT if itype == TensorProto.FLOAT16 else TensorProto.FLOAT16