Add custom ops Rotary #738

Open
wants to merge 13 commits into base: main
3 changes: 3 additions & 0 deletions operators/cuda/cuda_ops.cc
@@ -9,6 +9,7 @@
#include "cuda/mul_sigmoid.h"
#include "cuda/negxplus1.h"
#include "cuda/replace_zero.h"
#include "cuda/rotary.h"
#include "cuda/scatter_nd_of_shape.h"
#include "cuda/transpose_cast.h"
#endif
@@ -36,6 +37,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<float>),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero<float>),
CustomCudaStructV2("Rotary", contrib::Rotary<float>),
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<float>),
#if ORT_API_VERSION >= 16

@@ -48,6 +50,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<ortc::MFloat16>),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero<ortc::MFloat16>),
CustomCudaStructV2("Rotary", contrib::Rotary<ortc::MFloat16>),
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<ortc::MFloat16>),
CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type)
81 changes: 81 additions & 0 deletions operators/cuda/rotary.h
@@ -0,0 +1,81 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "ocos.h"
#include "rotary_impl.cuh"
#include "ortx_common.h"

namespace contrib {

/**
* Y = Rotary(X) is equivalent to the following when side == LEFT:
*
* N = X.shape[-1]
* Y = X.copy()
* Y[...,:N/2] = X[...,N/2:]
* Y[...,N/2:] = -X[...,:N/2]
*
* And the opposite if side == RIGHT:
*
* N = X.shape[-1]
* Y = X.copy()
* Y[...,:N/2] = -X[...,N/2:]
* Y[...,N/2:] = X[...,:N/2]
*/
template <typename T>
struct Rotary {
template <typename TDict>
OrtxStatus OnModelAttach(const TDict& dict) {
std::string empty;
std::string side = dict.TryToGetAttributeWithDefault("side", empty);
Contributor:
Maybe we should do LeftRotary and RightRotary so that the compute function becomes stateless.

Member Author:
I think this is done later here: https://github.com/microsoft/onnxruntime-extensions/pull/738/files#diff-643fdcb552aafbcb9a86bc8d48cab50dea3723b20b3105880e081b2e5b1b8da9R24. I hope the compiler is able to remove the unnecessary code at compile time.

Contributor:
It is more reasonable to split them from the beginning instead of relying on the compiler. If they are two ops, we define two ops. There is no need to make a mega op that supports several cases; that will make future improvements harder and more error-prone.

Member Author:
That's not the only place where I did that. There is only one op from ORT's point of view, and it is Rotary, but we have different implementations depending on the attribute value; the same goes for the input types. Rotary(X, side=LEFT) = Neg(Rotary(X, side=RIGHT)). Do you want distinct ONNX names as well?
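
For reference, a quick NumPy check of the identity mentioned above, following the semantics described in the doc comment; this snippet is only illustrative and not part of the PR:

```python
import numpy as np

def rotary(x, side):
    # Reference semantics: swap the two halves of the last axis and negate
    # one of them, depending on `side`.
    half = x.shape[-1] // 2
    left, right = x[..., :half], x[..., half:]
    if side == "left":
        return np.concatenate([right, -left], axis=-1)
    return np.concatenate([-right, left], axis=-1)

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
assert np.allclose(rotary(x, "left"), -rotary(x, "right"))
```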

if (side == "left") {
side_ = RotarySide::LEFT;
} else if (side == "right") {
side_ = RotarySide::RIGHT;
} else {
return {kOrtxErrorInvalidArgument, "side must be 'left' or 'right'."};
}

return {};
}

OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx, const ortc::Tensor<T>& input,
const ortc::Tensor<int64_t>& split, ortc::Tensor<T>& output) const {
const T* input_data = input.Data();
auto input_shape = input.Shape();
T* output_data = output.Allocate(input_shape);
auto input_length = input.NumberOfElement();
if (0 == input_length) {
return {};
}

auto shape_split = split.Shape();
if (shape_split.size() != 1 || shape_split[0] != 2) {
return {kOrtxErrorInvalidArgument, "Rotary only works when there are two sides."};
}
const int64_t* split_data = split.Data();
if (split_data[0] != split_data[1]) {
return {kOrtxErrorInvalidArgument, "Only equal split is allowed."};
}
if (split_data[0] + split_data[1] != input_shape[input_shape.size() - 1]) {
return {kOrtxErrorInvalidArgument, "Sum of the splits is not equal to the last dimension."};
}

LaunchRotaryKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()), input_length,
static_cast<int>(input_shape[input_shape.size() - 1]), input_data, split_data, output_data,
side_);
return {};
}

static OrtMemType GetInputMemoryType(size_t input_index) {
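// The split sizes are read on the host in Compute for validation, so the
// split input must be provided in CPU memory.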
if (input_index == 1) // split
return OrtMemType::OrtMemTypeCPUInput;
return OrtMemType::OrtMemTypeDefault;
}

private:
RotarySide side_;
};

} // namespace contrib
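
A small Python sketch of the split-tensor validation that Compute performs, for readability; the helper name is illustrative only and not part of the PR:

```python
import numpy as np

def validate_split(split: np.ndarray, last_dim: int) -> None:
    # Mirrors the checks in Rotary<T>::Compute: split must be a 1-D tensor
    # holding two equal values whose sum is the size of the last dimension.
    if split.ndim != 1 or split.shape[0] != 2:
        raise ValueError("Rotary only works when there are two sides.")
    if split[0] != split[1]:
        raise ValueError("Only equal split is allowed.")
    if split[0] + split[1] != last_dim:
        raise ValueError("Sum of the splits is not equal to the last dimension.")

validate_split(np.array([2, 2], dtype=np.int64), last_dim=4)  # passes
```
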
81 changes: 81 additions & 0 deletions operators/cuda/rotary_impl.cu
@@ -0,0 +1,81 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "device_prop.cuh"
#include "utils.cuh"
#include "rotary_impl.cuh"
#include "cuda_type.h"

#ifndef CUDA_LONG
#define CUDA_LONG int32_t
#endif

using namespace Ort::Custom;

template <typename T> __device__ __inline__ T _neg(const T x) { return -x; }

#if __CUDA_ARCH__ < 700
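// Fall back to negating through float on architectures without native half arithmetic.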
template <> __device__ __inline__ half _neg(const half x) {
return __float2half(-__half2float(x));
}
#endif

template <typename T, RotarySide side>
__global__ void RotaryKernel(T *output_data, const T *input_data, CUDA_LONG half_N, CUDA_LONG half_stride) {
CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
if (id >= half_N)
return;
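// Each thread handles one element of the first half of a last-dimension row and
// its counterpart half_stride positions later; `last` is the offset within the
// half-row and `id` is remapped to the element's index in the full row.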
CUDA_LONG last = id % half_stride;
id = (id - last) * 2 + last;
if (side == RotarySide::RIGHT) {
output_data[id + half_stride] = input_data[id];
output_data[id] = _neg(input_data[id + half_stride]);
} else {
output_data[id + half_stride] = _neg(input_data[id]);
output_data[id] = input_data[id + half_stride];
}
}

template <typename T>
cudaError_t _LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim,
const T* input_data, const int64_t* /* split_data */, T* output_data, RotarySide side) {
if (input_length == 0)
return cudaGetLastError();
using TT = typename contrib::CudaT<T>::MappedType;

CUDA_LONG N = static_cast<CUDA_LONG>(input_length);
CUDA_LONG stride = static_cast<CUDA_LONG>(last_dim);
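// Launch one thread per element of the first half of the flattened tensor; each
// thread writes both that element and its counterpart in the second half of the row.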

const int num_threads_per_block = 256;
const int num_blocks =
(N / 2 + num_threads_per_block - 1) / num_threads_per_block;

switch (side) {
case RotarySide::LEFT:
RotaryKernel<TT, RotarySide::LEFT>
<<<num_blocks, num_threads_per_block, 0, stream>>>(reinterpret_cast<TT*>(output_data),
reinterpret_cast<const TT*>(input_data),
N / 2, stride / 2);
break;
case RotarySide::RIGHT:
RotaryKernel<TT, RotarySide::RIGHT>
<<<num_blocks, num_threads_per_block, 0, stream>>>(reinterpret_cast<TT*>(output_data),
reinterpret_cast<const TT*>(input_data),
N / 2, stride / 2);
break;
}
return cudaGetLastError();
}

template <>
cudaError_t LaunchRotaryKernel<float>(cudaStream_t stream, int input_length, int last_dim,
const float* input_data, const int64_t* split_data, float* output_data, RotarySide side) {
return _LaunchRotaryKernel(stream, input_length, last_dim, input_data, split_data, output_data, side);
}

template <>
cudaError_t LaunchRotaryKernel<ortc::MFloat16>(cudaStream_t stream, int input_length, int last_dim,
const ortc::MFloat16* input_data, const int64_t* split_data,
ortc::MFloat16* output_data, RotarySide side) {
return _LaunchRotaryKernel(stream, input_length, last_dim, input_data, split_data, output_data, side);
}
15 changes: 15 additions & 0 deletions operators/cuda/rotary_impl.cuh
@@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <cuda.h>
#include <cuda_runtime.h>

enum class RotarySide : int {
LEFT = 1,
RIGHT = 2,
};

template <typename T>
cudaError_t LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim,
const T* input_data, const int64_t* split_data, T* output_data, RotarySide side);
61 changes: 61 additions & 0 deletions test/cuda/test_cudaops.py
@@ -596,6 +596,67 @@ def test_masked_scatternd_of_shape_standalone_cuda_big(self):
self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT, True)
self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16, True)

def _rotary_cuda(self, itype, side, input_shape=(3, 2, 3, 4)):
model2 = helper.make_model(
helper.make_graph(
[
helper.make_node(
"Rotary",
["X", "splits"],
["Y"],
domain="ai.onnx.contrib",
side=side,
)
],
"nd",
[
helper.make_tensor_value_info("X", itype, [None, None, None, None]),
helper.make_tensor_value_info("splits", TensorProto.INT64, [2]),
],
[helper.make_tensor_value_info("Y", itype, [None, None, None, None])],
),
opset_imports=[
helper.make_opsetid("", 18),
helper.make_opsetid("ai.onnx.contrib", 1),
],
ir_version=9,
)

dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
x = (np.arange(np.prod(input_shape)) + 1).reshape(input_shape).astype(dtype)
splits = np.array([x.shape[-1] // 2, x.shape[-1] // 2], dtype=np.int64)

expected = x.copy()
half = x.shape[-1] // 2
if side == "left":
expected[:, :, :, :half] = x[:, :, :, half:]
expected[:, :, :, half:] = -x[:, :, :, :half]
else:
expected[:, :, :, :half] = -x[:, :, :, half:]
expected[:, :, :, half:] = x[:, :, :, :half]

feeds = dict(X=x, splits=splits)
opts = _ort.SessionOptions()
opts.register_custom_ops_library(_get_library_path())
sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
got = sess.run(None, feeds)[0]
assert_almost_equal(expected, got)

@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_rotary_cuda(self):
self._rotary_cuda(TensorProto.FLOAT, "left")
self._rotary_cuda(TensorProto.FLOAT, "right")
self._rotary_cuda(TensorProto.FLOAT16, "left")
self._rotary_cuda(TensorProto.FLOAT16, "right")

@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_bigger_rotary_cuda(self):
sh = (2, 2, 1024, 8)
self._rotary_cuda(TensorProto.FLOAT, "left", input_shape=sh)
self._rotary_cuda(TensorProto.FLOAT, "right", input_shape=sh)
self._rotary_cuda(TensorProto.FLOAT16, "left", input_shape=sh)
self._rotary_cuda(TensorProto.FLOAT16, "right", input_shape=sh)

def _transpose_cast_cuda(self, itype):
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
itype2 = TensorProto.FLOAT if itype == TensorProto.FLOAT16 else TensorProto.FLOAT16