[BUG] Unable to convert output dtype from FP32 to BF16 in Group GEMM epilogue #500

@sanchitintel

Description

Which component has the problem?

CUTLASS C++

Bug Report

Describe the bug
Computing a Group GEMM with BF16 A and B matrices and BFloat16 as the output dtype (converting from the FP32 accumulator in the epilogue) does not work.
This was tried with an epilogue created directly with cutlass::epilogue::collective::CollectiveEpilogue.

Steps/Code to reproduce bug
Please apply this small diff to the Group GEMM example and compile it (the diff modifies that same file):

diff --git a/examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp b/examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp
index bdda0536..860cbf0c 100644
--- a/examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp
+++ b/examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp
@@ -92,7 +92,7 @@ using ElementAccumulator = float; // <- data type of accumulator
 using ElementComputeEpilogue = float; // <- data type of epilogue operations
 using ElementA = bfloat16_t; // <- data type of elements in input matrix A
 using ElementB = bfloat16_t; // <- data type of elements in input matrix B
-using ElementOutput = float; // <- data type of elements in output matrix D
+using ElementOutput = bfloat16_t; // <- data type of elements in output matrix D

 ///////////////////////////////////////////////////////////////////////////////////////////////////

@@ -198,8 +198,8 @@ struct ExampleRunner {
   using LayoutD = typename Gemm::LayoutD;

   using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
-  using ElementOutput = typename CollectiveEpilogue::ElementOutput;
-  using ElementAccumulator = ElementOutput;
+  using ElementOutput = bfloat16_t;
+  using ElementAccumulator = float_t;

   using StrideA = typename Gemm::GemmKernel::InternalStrideA;
   using StrideB = typename Gemm::GemmKernel::InternalStrideB;
@@ -361,7 +361,7 @@ void initialize(const Options &options) {
   std::vector<ElementA *> ptr_A_host(options.groups);
   std::vector<ElementB *> ptr_B_host(options.groups);
   std::vector<ElementC *> ptr_C_host(options.groups);
-  std::vector<ElementC *> ptr_D_host(options.groups);
+  std::vector<ElementOutput *> ptr_D_host(options.groups);
   std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
   std::vector<ElementAccumulator *> ptr_beta_host(options.groups);

@@ -599,7 +599,7 @@ int main(int argc, const char** argv)
       FusionCallBacks,
       XE_2D_U32x8x16_LD_N,
       void, void,
-      XE_2D_U32x8x16_ST_N,
+      XE_2D_U16x8x16_ST_N,
       void, void>;
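
The last hunk swaps the store atom from XE_2D_U32x8x16_ST_N to XE_2D_U16x8x16_ST_N. The intent appears to be matching the store width to the 16-bit output element. The sketch below illustrates that constraint as a compile-time check. It assumes that U32/U16 in the atom names denotes the element width in bits; StoreAtomSketch, bf16_sketch, and store_width_matches are hypothetical stand-ins for illustration and are not CUTLASS types.

```cpp
#include <climits>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for a 2D block-store atom; only the element width is modeled.
template <std::size_t Bits>
struct StoreAtomSketch {
  static constexpr std::size_t element_bits = Bits;
};

using StoreU32Sketch = StoreAtomSketch<32>;  // stand-in for a 32-bit store atom
using StoreU16Sketch = StoreAtomSketch<16>;  // stand-in for a 16-bit store atom

// Hypothetical 16-bit storage type standing in for bfloat16_t.
struct bf16_sketch {
  std::uint16_t storage;
};

// True when the output element is exactly as wide as the store atom's element.
template <class Element, class StoreAtom>
constexpr bool store_width_matches() {
  return sizeof(Element) * CHAR_BIT == StoreAtom::element_bits;
}

// FP32 output pairs with the 32-bit atom, BF16 output with the 16-bit atom.
static_assert(store_width_matches<float, StoreU32Sketch>(), "float needs a 32-bit store");
static_assert(store_width_matches<bf16_sketch, StoreU16Sketch>(), "bf16 needs a 16-bit store");
static_assert(!store_width_matches<bf16_sketch, StoreU32Sketch>(), "bf16 must not use a 32-bit store");
```

Under that assumption, changing ElementOutput to a 16-bit type without also changing the store atom would leave the two mismatched, which is why the diff changes both.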
    

Expected behavior
The epilogue should support converting the MMA output (the FP32 accumulator) to a different output dtype, such as BF16.
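
For reference, the per-element conversion being requested can be sketched on the host in plain C++. This is a minimal illustration under stated assumptions, not the CUTLASS implementation: float_to_bf16_bits, bf16_bits_to_float, and convert_tile are hypothetical names, BF16 values are held as raw 16-bit patterns, and rounding is round-to-nearest-even.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical helper: FP32 -> BF16 bit pattern with round-to-nearest-even.
inline std::uint16_t float_to_bf16_bits(float value) {
  std::uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  // NaN input: truncate and force a quiet-NaN mantissa bit so it stays a NaN.
  if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
    return static_cast<std::uint16_t>((bits >> 16) | 0x0040u);
  }
  // Add 0x7FFF plus the lowest kept bit, so exact ties round to an even result.
  const std::uint32_t rounding_bias = 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<std::uint16_t>((bits + rounding_bias) >> 16);
}

// Hypothetical helper: BF16 bit pattern -> FP32 (exact, since BF16 is the top half of FP32).
inline float bf16_bits_to_float(std::uint16_t half_bits) {
  const std::uint32_t bits = static_cast<std::uint32_t>(half_bits) << 16;
  float value;
  std::memcpy(&value, &bits, sizeof(value));
  return value;
}

// Hypothetical stand-in for the epilogue step: convert a tile of FP32 accumulators to BF16.
inline void convert_tile(const std::vector<float>& accumulators,
                         std::vector<std::uint16_t>& output) {
  assert(output.size() == accumulators.size());
  for (std::size_t i = 0; i < accumulators.size(); ++i) {
    output[i] = float_to_bf16_bits(accumulators[i]);
  }
}
```

Values such as 1.0f or -2.0f are exactly representable in BF16 and survive the round trip unchanged. Other values lose their lower 16 mantissa bits to rounding.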

Environment details
PVC GPU

Additional context
Using the main branch.
