
Commit

Merge pull request PaddlePaddle#20 from mthreads/operators
Operators
caizhi-mt authored and mt-robot committed Aug 4, 2023
2 parents fa1d95c + 3d50ac0 commit 76797b7
Showing 2 changed files with 27 additions and 63 deletions.
48 changes: 24 additions & 24 deletions paddle/phi/kernels/funcs/elementwise_base.h
@@ -711,30 +711,30 @@ __global__ void VectorizedElementwiseKernel(
     kps::IndexType main_offset,
     int read_lens,
     Functor func) {
-  //kps::IndexType data_offset =
-  //    static_cast<kps::IndexType>(BLOCK_ID_X) * BLOCK_NUM_X * read_lens;
-  //kps::IndexType stride =
-  //    static_cast<kps::IndexType>(BLOCK_NUM_X) * GRID_NUM_X * read_lens;
-  //for (; data_offset < main_offset; data_offset += stride) {
-  //  VectorizedElementwiseKernelImpl<OutT,
-  //                                  Functor,
-  //                                  Arity,
-  //                                  NumOuts,
-  //                                  VecSize,
-  //                                  false>(
-  //      ins, outs, data_offset, read_lens * BLOCK_NUM_X, read_lens, func);
-  //}
-
-  //kps::IndexType remain = numel - data_offset;
-  //if (remain > 0) {
-  //  VectorizedElementwiseKernelImpl<OutT,
-  //                                  Functor,
-  //                                  Arity,
-  //                                  NumOuts,
-  //                                  VecSize,
-  //                                  true>(
-  //      ins, outs, data_offset, static_cast<int>(remain), read_lens, func);
-  //}
+  kps::IndexType data_offset =
+      static_cast<kps::IndexType>(BLOCK_ID_X) * BLOCK_NUM_X * read_lens;
+  kps::IndexType stride =
+      static_cast<kps::IndexType>(BLOCK_NUM_X) * GRID_NUM_X * read_lens;
+  for (; data_offset < main_offset; data_offset += stride) {
+    VectorizedElementwiseKernelImpl<OutT,
+                                    Functor,
+                                    Arity,
+                                    NumOuts,
+                                    VecSize,
+                                    false>(
+        ins, outs, data_offset, read_lens * BLOCK_NUM_X, read_lens, func);
+  }
+
+  kps::IndexType remain = numel - data_offset;
+  if (remain > 0) {
+    VectorizedElementwiseKernelImpl<OutT,
+                                    Functor,
+                                    Arity,
+                                    NumOuts,
+                                    VecSize,
+                                    true>(
+        ins, outs, data_offset, static_cast<int>(remain), read_lens, func);
+  }
 }
 
 template <typename OutT, typename Functor, int Arity, int NumOuts, int VecSize>
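For context, the re-enabled body above is the standard CUDA grid-stride pattern: the main loop lets each block consume one vectorized tile of read_lens * BLOCK_NUM_X elements and advance by the full grid width, while a final call with the boundary flag set to true handles the bounds-checked remainder. A minimal standalone sketch of the same idea in plain CUDA (illustrative names, scalar loads, and a single bounds-checked loop standing in for the kps:: main/tail split):

// Sketch only: scalar grid-stride elementwise add. One thread visits
// many elements, so any launch configuration covers any numel.
__global__ void AddGridStride(const float* x,
                              const float* y,
                              float* out,
                              long long numel) {
  long long offset =
      static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
  long long stride = static_cast<long long>(blockDim.x) * gridDim.x;
  for (long long i = offset; i < numel; i += stride) {
    out[i] = x[i] + y[i];
  }
}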
42 changes: 3 additions & 39 deletions paddle/phi/kernels/kps/elementwise_add_kernel.cu
@@ -69,57 +69,21 @@ void Float32Bfloat16OrFloat16AddCudaFunctor(const Context& dev_ctx,
   }
 }
 
-// TODO(MTAI): The following code is temporary, which is just a demo for MUSA.
-// It will be removed later.
-using muTensor = ::musa::dnn::Tensor;
-using BINARY_MODE = ::musa::dnn::Binary::Mode;
-muTensor CreateMUTensor(const DenseTensor& tensor) {
-  muTensor mu_tensor;
-  mu_tensor.SetNdInfo(tensor.dims().size(), tensor.dims().Get());
-  switch (tensor.dtype()) {
-    case DataType::FLOAT32:
-      mu_tensor.SetType(muTensor::Type::FLOAT);
-      break;
-    case DataType::INT32:
-      mu_tensor.SetType(muTensor::Type::INT32);
-      break;
-    case DataType::INT64:
-      mu_tensor.SetType(muTensor::Type::INT64);
-      break;
-    default:
-      std::cerr << "=========mismatch dtype in add kernel=====\n";
-      throw;
-  }
-  mu_tensor.SetAddr(tensor.data());
-  return mu_tensor;
-}
-
 template <typename T, typename Context>
 void AddKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const DenseTensor& y,
                DenseTensor* out) {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_MUSA)
   if (x.dtype() == phi::DataType::FLOAT32 &&
       (y.dtype() == phi::DataType::BFLOAT16 ||
        y.dtype() == phi::DataType::FLOAT16)) {
     using Type = DataTypeToCppType<phi::DataType::FLOAT32>::type;
     Float32Bfloat16OrFloat16AddCudaFunctor<Type, Context>(dev_ctx, x, y, out);
   } else {
 #endif
-    // AddCudaFunctor<T, Context>(dev_ctx, x, y, -1, out);
-    dev_ctx.template Alloc<T>(out);
-    using muHandle = ::musa::dnn::Handle;
-    ::musa::dnn::Handle h;
-    muTensor musa_self = CreateMUTensor(x);
-    muTensor musa_other = CreateMUTensor(y);
-    muTensor musa_out = CreateMUTensor(*out);
-
-    ::musa::dnn::Binary binary_op;
-    binary_op.SetMode(BINARY_MODE::ADD);
-    binary_op.Run(h, musa_out, musa_self, musa_other);
-
-#ifdef PADDLE_WITH_CUDA
+    AddCudaFunctor<T, Context>(dev_ctx, x, y, -1, out);
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_MUSA)
   }
 #endif
 }
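A detail worth noting in the restored AddKernel: the "} else {" opener and its matching "}" are wrapped in the same #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_MUSA) guard, so on builds with neither macro defined the fast-path branch disappears entirely and the AddCudaFunctor call becomes the unconditional body. The same brace-across-guard shape in isolation (a hedged sketch with stand-in behavior, not Paddle code):

#include <cstdio>

// PADDLE_WITH_CUDA / PADDLE_WITH_MUSA stand in for the real build flags;
// compile with neither defined and the branch collapses to the generic path.
void DispatchAdd(bool mixed_precision) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_MUSA)
  if (mixed_precision) {
    std::puts("fp32 + (bf16|fp16) fast path");  // guarded branch
  } else {
#endif
    (void)mixed_precision;  // avoids an unused-parameter warning when the guard is off
    // Compiled either way; this is the sole body when the guard is off.
    std::puts("generic elementwise add");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_MUSA)
  }
#endif
}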
