4 changes: 2 additions & 2 deletions examples/sycl/00_bmg_gemm/00_bmg_gemm.cpp
@@ -156,7 +156,7 @@ struct ExampleRunner {
using ElementAcc = typename Gemm::ElementAccumulator;

using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
using ElementC = typename Gemm::ElementC;
using ElementC = typename CollectiveEpilogue::ElementOutput;

I think this change is wrong. Why change this? Did you check that this is still correct when the dtypes of C and the output are different?

I think this issue could be detected by pre-CI checks.

Author: It assumes C and D have the same dtype here. I also think we need to support more dtype combinations.

Why change it if you assume C and D are the same? We should add a static_assert if it only works under that assumption.
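For concreteness, a minimal sketch of what such a guard could look like, placed next to the type aliases in this example; the exact wording and placement are assumptions, not part of this PR:

```cpp
// Sketch only: make the "C and D share a dtype" assumption explicit, so a
// mismatched instantiation fails at compile time instead of silently treating
// the C operand as if it had the output type.
static_assert(cute::is_same_v<typename Gemm::ElementC,
                              typename Gemm::CollectiveEpilogue::ElementOutput>,
              "This example assumes the source (C) and output (D) element types match.");
```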

using ElementOutput = typename CollectiveEpilogue::ElementOutput;
using ElementCompute = typename CollectiveEpilogue::ElementCompute;
using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;
@@ -375,7 +375,7 @@ int main(int argc, const char** argv)
// aside from the (A*B), which is handled by the GEMM. See 05_bmg_gemm_with_epilogues for more
// complex epilogue examples.
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
ElementOutput, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;

// FusionCallbacks ties the EpilogueOp to an implementation (based on the dispatch
// policy/architecture) and defines the epilogue arguments.
6 changes: 3 additions & 3 deletions examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp
@@ -190,7 +190,6 @@ struct ExampleRunner {

using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementC = typename Gemm::ElementC;

using LayoutA = typename Gemm::LayoutA;
using LayoutB = typename Gemm::LayoutB;
@@ -199,7 +198,8 @@

using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
using ElementOutput = typename CollectiveEpilogue::ElementOutput;
using ElementAccumulator = ElementOutput;
using ElementAccumulator = ElementAccumulator;
using ElementC = typename CollectiveEpilogue::ElementOutput;

using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
@@ -585,7 +585,7 @@ int main(int argc, const char** argv)
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group;

using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
ElementOutput, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
decltype(tile_shape(TiledMma()))>;
24 changes: 20 additions & 4 deletions include/cutlass/epilogue/collective/xe_array_epilogue.hpp
@@ -90,7 +90,7 @@ class CollectiveEpilogue<
using DispatchPolicy = IntelXeXMX16Group;
using CtaTileMNK = CtaTileMNK_;
using FusionCallbacks = FusionCallbacks_;
using ElementC = ElementC_;
using ElementC = typename FusionCallbacks::ElementSource;
using ElementAccumulator = ElementC_;
using StrideC = StrideC_;
using InternalStrideC = cute::remove_pointer_t<StrideC>;
@@ -115,7 +115,7 @@
static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest;

static_assert(cute::is_same_v<typename FusionCallbacks::Operation,
fusion::LinearCombination<ElementAccumulator, ElementCompute, ElementSource, ElementScalar, RoundStyle>>,
fusion::LinearCombination<ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle>>,
"Only Linear Combination Epilogue is supported for Grouped GEMM at the moment.");

static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
@@ -372,6 +372,7 @@
Tensor tCgD = thread_xe_store_d.partition_D(gD);

Tensor trC = make_tensor<typename TiledMma::ValTypeC>(Shape<Int<FragmentSize>>{});
auto trC_frag = recast<Array<typename TiledMma::ValTypeC, FragmentSize>>(trC);
Tensor trD_compute = make_tensor<ElementCompute>(Shape<Int<FragmentSize>>{});

// Because Sm90 uses shared memory, they are not tied to using the same accumulator values
@@ -421,15 +422,24 @@
constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{});
static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" );

constexpr bool is_same_dtype_accum_and_output = std::is_same_v<typename TiledMma::ValTypeC, ElementC>;

auto synchronize = [&] () {};
CUTLASS_PRAGMA_UNROLL
for (int epi_n = 0; epi_n < FragsN; epi_n++) {
CUTLASS_PRAGMA_UNROLL
for (int epi_m = 0; epi_m < FragsM; epi_m++) {

if (is_C_load_needed) {
//coordinates for C and D are the same
if constexpr (is_same_dtype_accum_and_output) {
//coordinates for C and D are the same
copy(params.xe_load_c.with(get<0>(load_store_tensors)), tCgD(_, epi_m, epi_n), trC);
} else {
Tensor trC_ori = make_tensor<ElementC>(Shape<Int<FragmentSize>>{});
copy(params.xe_load_c.with(get<0>(load_store_tensors)), tCgD(_, epi_m, epi_n), trC_ori);
auto trC_ori_frag = recast<Array<ElementC, FragmentSize>>(trC_ori);
*(trC_frag.data()) = cutlass::NumericArrayConverter<typename TiledMma::ValTypeC, ElementC, FragmentSize>{}(*(trC_ori_frag.data()));
}
}

cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed);
@@ -438,7 +448,13 @@

CUTLASS_PRAGMA_UNROLL
for (int epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) {
trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
if constexpr (is_same_dtype_accum_and_output) {
trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
} else {
// align dtypes first
auto tmp = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
trD_compute_frag(epi_v) = cutlass::NumericArrayConverter<ElementCompute, ElementOutput, FragmentSize>{}(tmp);
}
}
cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag);

24 changes: 20 additions & 4 deletions include/cutlass/epilogue/collective/xe_epilogue.hpp
@@ -89,7 +89,7 @@ class CollectiveEpilogue<
using DispatchPolicy = IntelXeXMX16;
using CtaTileMNK = CtaTileMNK_;
using FusionCallbacks = FusionCallbacks_;
using ElementC = ElementC_;
using ElementC = typename FusionCallbacks::ElementSource;
using ElementAccumulator = ElementC_;
using StrideC = StrideC_;
using ElementD = ElementD_;
@@ -350,6 +350,7 @@
Tensor tCgD = thread_xe_store_d.partition_D(gD);

Tensor trC = make_tensor<typename TiledMma::ValTypeC>(Shape<Int<FragmentSize>>{});
auto trC_frag = recast<Array<typename TiledMma::ValTypeC, FragmentSize>>(trC);
Tensor trD_compute = make_tensor<ElementCompute>(Shape<Int<FragmentSize>>{});

// Because Sm90 uses shared memory, they are not tied to using the same accumulator values
@@ -398,7 +399,9 @@
FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K;
constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{});
static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" );


constexpr bool is_same_dtype_accum_and_output = std::is_same_v<typename TiledMma::ValTypeC, ElementC>;

auto synchronize = [&] () {};
CUTLASS_PRAGMA_UNROLL
for (int epi_n = 0; epi_n < FragsN; epi_n++) {
@@ -407,7 +410,14 @@
cst_callbacks.begin_loop(epi_m, epi_n);

if (is_C_load_needed) {
copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC);
if constexpr (is_same_dtype_accum_and_output) {
copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC);
} else {
Tensor trC_ori = make_tensor<ElementC>(Shape<Int<FragmentSize>>{});
copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC_ori);
auto trC_ori_frag = recast<Array<ElementC, FragmentSize>>(trC_ori);
*(trC_frag.data()) = cutlass::NumericArrayConverter<typename TiledMma::ValTypeC, ElementC, FragmentSize>{}(*(trC_ori_frag.data()));
}
}

cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed);
@@ -416,7 +426,13 @@

CUTLASS_PRAGMA_UNROLL
for (int epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) {
trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
if constexpr (is_same_dtype_accum_and_output) {
trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
} else {
// align dtypes first
auto tmp = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
trD_compute_frag(epi_v) = cutlass::NumericArrayConverter<ElementCompute, ElementOutput, FragmentSize>{}(tmp);
}
}
cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag);

11 changes: 6 additions & 5 deletions include/cutlass/epilogue/fusion/xe_callbacks.hpp
@@ -343,7 +343,7 @@ template <
class ElementOutput_,
class ElementCompute_,
class ElementAux,
class ElementSource,
class ElementSource_,
class ElementScalar,
int AlignmentAux,
FloatRoundStyle RoundStyle,
@@ -355,28 +355,29 @@ struct FusionCallbacks<
epilogue::IntelXeXMX16,
fusion::LinCombDeEltAct<
GmemLayoutTagAux, ActivationFn, ElementOutput_, ElementCompute_,
ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle
ElementAux, ElementSource_, ElementScalar, AlignmentAux, RoundStyle
>,
CtaTileShapeMNK,
EpilogueTile,
CopyOpG2R
> : XeLinCombDeEltAct<
cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, CopyOpG2R, ActivationFn, ElementOutput_,
ElementCompute_, ElementAux, ElementSource, ElementScalar, RoundStyle
ElementCompute_, ElementAux, ElementSource_, ElementScalar, RoundStyle
> {

using ElementOutput = ElementOutput_;
using ElementCompute = ElementCompute_;
using ElementSource = ElementSource_;

using Impl =
XeLinCombDeEltAct<
cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, CopyOpG2R, ActivationFn, ElementOutput,
ElementCompute, ElementAux, ElementSource, ElementScalar, RoundStyle
ElementCompute, ElementAux, ElementSource_, ElementScalar, RoundStyle
>;
using Operation =
fusion::LinCombDeEltAct<
GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute,
ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle
ElementAux, ElementSource_, ElementScalar, AlignmentAux, RoundStyle
>;

struct Arguments {
4 changes: 2 additions & 2 deletions test/unit/gemm/device/default_gemm_group_configuration.hpp
@@ -87,7 +87,7 @@ struct DefaultGemmGroupConfiguration<

using TiledMma = typename CollectiveMainloop::TiledMma;

using EpilogueOp = epilogue::fusion::LinearCombination<float, float>;
@sanchitintel (Sep 12, 2025):

Currently, while the case of BF16 A/B, FP32 C, and BF16 D/output is supported in the main branch, the epilogue::fusion::LinearCombination usage at this line is non-intuitive: the main branch uses a hack that deviates from the intended/documented use of this API, whose first template parameter is meant to be the output dtype.

The unwritten/implicit contract for this code in the current main branch seems to be:

  1. intuitively, think of it as computing D = alpha * Accum + beta * C in float,
  2. and then setting the correct ElementD parameter on cutlass::epilogue::collective::CollectiveEpilogue can be thought of as converting to the correct output dtype (which is ElementOutput in this file).

It seems that once this PR is ready, it will rectify the usage of cutlass::epilogue::fusion::LinearCombination in this repo.

Thanks!
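To make the contrast concrete, here is a hypothetical side-by-side of the two usages described above, assuming ElementOutput is the BF16 output dtype and float is the compute dtype (illustrative only, not the exact main-branch code; the alias names are made up):

```cpp
// Main-branch style flagged above as non-intuitive: the first template
// parameter (documented as the output dtype) is set to float, and conversion
// to the real output dtype is deferred to the CollectiveEpilogue's ElementD.
using EpilogueOpImplicit = cutlass::epilogue::fusion::LinearCombination<float, float>;

// Intended/documented usage, matching the change in this diff: the first
// template parameter is the actual output dtype.
using EpilogueOpExplicit =
    cutlass::epilogue::fusion::LinearCombination<ElementOutput, float>;
```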

using EpilogueOp = epilogue::fusion::LinearCombination<ElementOutput, float>;

using FusionCallBacks = epilogue::fusion::FusionCallbacks<
epilogue::IntelXeXMX16Group,
@@ -101,7 +101,7 @@ struct DefaultGemmGroupConfiguration<
TileShape, Shape<_1, _1, _1>,
epilogue::collective::EpilogueTileAuto,
float, float,
float, LayoutC, 1,
ElementOutput, LayoutC, 1,
ElementOutput, LayoutC, 1,
epilogue::IntelXeXMX16Group,
EpilogueOp