
[add] Add operator gemmfastgelu for ROCM#12101

Merged
PeixuanZuo merged 15 commits into master from peixuanzuo/add_gemmfastgelu
Jul 13, 2022

Conversation

@PeixuanZuo
Contributor

Description

Add a new operator, GemmFastGelu, to the ROCm EP. It fuses MatMul + FastGelu.
This is a base implementation; its performance is the same as running MatMul and FastGelu sequentially.
A composable_kernel (https://github.com/ROCmSoftwarePlatform/composable_kernel) implementation of the fused GemmFastGelu will be added later.
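To make the fused semantics concrete, here is a minimal CPU reference sketch of what the op computes: Y = FastGelu(X·W + bias), where FastGelu is the common tanh approximation of GELU. The names `FastGelu` and `GemmFastGeluRef` are invented for this illustration and are not the actual ROCm kernel entry points.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Tanh approximation of GELU used by FastGelu:
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
float FastGelu(float x) {
  const float kAlpha = 0.7978845608028654f;  // sqrt(2/pi)
  return 0.5f * x * (1.0f + std::tanh(kAlpha * (x + 0.044715f * x * x * x)));
}

// Hypothetical reference: Y (M x N) = FastGelu(X (M x K) * W (K x N) + bias (N)).
// Row-major layout; bias may be empty.
std::vector<float> GemmFastGeluRef(const std::vector<float>& X,
                                   const std::vector<float>& W,
                                   const std::vector<float>& bias,
                                   int M, int K, int N) {
  std::vector<float> Y(static_cast<size_t>(M) * N);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = bias.empty() ? 0.0f : bias[n];
      for (int k = 0; k < K; ++k) acc += X[m * K + k] * W[k * N + n];
      Y[m * N + n] = FastGelu(acc);
    }
  }
  return Y;
}
```

The base implementation described above launches the GEMM and the FastGelu elementwise pass as separate kernels, so it matches this reference numerically; the planned composable_kernel version would fuse the epilogue into the GEMM itself.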

Motivation and Context


@PeixuanZuo PeixuanZuo requested review from pengwa and zhangyaobit July 6, 2022 07:25
@PeixuanZuo PeixuanZuo force-pushed the peixuanzuo/add_gemmfastgelu branch from 75480b6 to a47b599 July 7, 2022 07:44
template <typename T>
GemmFastGelu<T>::GemmFastGelu(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {
  const TransformerOptions* options = TransformerOptions::GetInstance();
  tuning_ = options->IsTuningEnabled();
}
Contributor


I did not see tuning_ used anywhere. I suggest removing it for now; you can add it back when you add the tuning logic.

Contributor Author


Yes, it's a variable intended for future use; I'll delete it in this PR.

} else if (CanUseStridedBatchedGemm(X->Shape(), W->Shape(),
transa, transb, stride_A, stride_B, stride_C, batch_count)) {
ROCBLAS_RETURN_IF_ERROR(rocblasGemmStridedBatchedHelper(RocblasHandle(),
transB,
Contributor


alignment


#if defined(USE_ROCM)
static void RunGemmFastGeluGpuTest(const std::vector<float>& input_data, const std::vector<float>& weight_data,
const std::vector<float>& bias_data, const std::vector<float>& output_data,
Contributor


alignment

@PeixuanZuo PeixuanZuo force-pushed the peixuanzuo/add_gemmfastgelu branch from ac0ee68 to 9fd5e4d July 12, 2022 03:59
@zhangyaobit zhangyaobit self-requested a review July 12, 2022 19:25
zhangyaobit previously approved these changes Jul 12, 2022
@PeixuanZuo PeixuanZuo merged commit 5579d81 into master Jul 13, 2022
@PeixuanZuo PeixuanZuo deleted the peixuanzuo/add_gemmfastgelu branch July 13, 2022 07:40