Skip to content

Commit

Permalink
fix scale type in alpha and beta (PaddlePaddle#48887)
Browse files — browse the repository at this point in the history
  • Loading branch information
MARD1NO committed Dec 9, 2022
1 parent 9896b29 commit b561b32
Showing 1 changed file with 5 additions and 10 deletions.
15 changes: 5 additions & 10 deletions paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -1530,21 +1530,16 @@ class CublasFusedMLP {
beta16 =
add_residual ? static_cast<half>(1.0) : static_cast<half>(0.0);

void *alpha = nullptr, *beta = nullptr;
void *alpha = &alpha32, *beta = &beta32;
if (std::is_same<T, double>::value) {
alpha = &alpha64;
beta = &beta64;
} else if (std::is_same<T, float>::value) {
alpha = &alpha64;
beta = &beta64;
} else if (std::is_same<T, phi::dtype::float16>::value) {
}

if (std::is_same<T, phi::dtype::float16>::value &&
FLAGS_gemm_use_half_precision_compute_type) {
alpha = &alpha16;
beta = &beta16;
} else {
PADDLE_ENFORCE_EQ(true,
false,
platform::errors::InvalidArgument(
"Only support double, float, half data type. "));
}

const auto *x_data = x->data<T>();
Expand Down

0 comments on commit b561b32

Please sign in to comment.