Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 50 additions & 28 deletions examples/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,10 @@ void initialize(const Options &options) {
}

/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options, const cutlass::KernelHardwareInfo& hw_info, bool host_problem_shapes_available = true)
typename Gemm::Arguments args_from_options(const Options &options,
const cutlass::KernelHardwareInfo& hw_info,
bool host_problem_shapes_available = true,
bool use_nullptr_c = false)
{
typename Gemm::Arguments arguments;
decltype(arguments.epilogue.thread) fusion_args;
Expand Down Expand Up @@ -458,7 +461,7 @@ void initialize(const Options &options) {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
{fusion_args, use_nullptr_c ? nullptr : ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info,
{1, RasterOrderOptions::AlongN}
};
Expand All @@ -468,7 +471,7 @@ void initialize(const Options &options) {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), nullptr},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
{fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
{fusion_args, use_nullptr_c ? nullptr : ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info,
{1, RasterOrderOptions::AlongN}
};
Expand All @@ -477,13 +480,16 @@ void initialize(const Options &options) {
return arguments;
}

cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info, bool host_problem_shapes_available = true) {
cutlass::Status run(const Options& options,
const cutlass::KernelHardwareInfo& hw_info,
bool host_problem_shapes_available = true,
bool use_nullptr_c = false) {
allocate(options);
initialize(options);

Gemm gemm_op;

auto arguments = args_from_options(options, hw_info, host_problem_shapes_available);
auto arguments = args_from_options(options, hw_info, host_problem_shapes_available, use_nullptr_c);

size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
Expand Down Expand Up @@ -530,26 +536,8 @@ void initialize(const Options &options) {

};

int main(int argc, const char** argv)
{
//
// Parse options
//

Options options;

options.parse(argc, argv);

if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}

if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}

template<bool use_nullptr_c=false>
void launcher(Options& options) {
//
// Run examples
//
Expand Down Expand Up @@ -584,8 +572,11 @@ int main(int argc, const char** argv)
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16Group<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group;

using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
using EpilogueOp = cute::conditional_t<use_nullptr_c,
cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest, false>,
cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest, true>>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
decltype(tile_shape(TiledMma()))>;
Expand Down Expand Up @@ -626,7 +617,38 @@ int main(int argc, const char** argv)

ExampleRunner<Gemm> runner;

CUTLASS_CHECK(runner.run(options, hw_info));
CUTLASS_CHECK(runner.run(options, hw_info, true, /* use_nullptr_c = */use_nullptr_c));
}


/// Example entry point.
///
/// Parses the command-line options and dispatches to `launcher`. When `options.beta` is exactly
/// zero, the example is run twice: first with `launcher<true>` (nullptr passed as the epilogue's
/// ptr_C argument) and then with `launcher<false>` (the actual ptr_C). For any other beta only
/// `launcher<false>` is run.
///
/// @param argc Number of command-line arguments.
/// @param argv Command-line argument strings, forwarded to `Options::parse`.
/// @return 0 when help was requested or once the launcher call(s) have returned;
///         -1 when `options.error` is set after parsing.
int main(int argc, const char** argv)
{
  //
  // Parse options
  //

  Options options;

  options.parse(argc, argv);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  if (options.error) {
    std::cerr << "Aborting execution." << std::endl;
    return -1;
  }

  // Exact comparison is intentional: the nullptr-C path is only exercised when the user
  // asked for a beta of precisely zero.
  if (options.beta == 0.f) {
    // The reference kernel doesn't accept nullptr for C, so the nullptr ptr_C epilogue
    // argument is only tested when beta is 0.
    std::cout << "\n\nUse a nullptr as argument ptr_C of the group GEMM epilogue collective\n\n";
    launcher<true>(options);
    std::cout << "\n\nPass actual ptr_C as an argument to the group GEMM epilogue collective\n\n";
  }
  launcher<false>(options);

  return 0;
}
8 changes: 5 additions & 3 deletions include/cutlass/epilogue/collective/builders/xe_builder.inl
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@ namespace detail {
template <
class ElementD,
class ElementCompute,
class ElementC
class ElementC,
cutlass::FloatRoundStyle RoundStyle_,
bool supportSource_
>
struct FusionOpInfo<cutlass::epilogue::fusion::LinearCombination<
ElementD, ElementCompute, ElementC, ElementCompute
ElementD, ElementCompute, ElementC, ElementCompute, RoundStyle_, supportSource_
>> {
constexpr static bool HasBuilder = true;

Expand All @@ -63,7 +65,7 @@ namespace detail {
class>
using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks<
DispatchPolicy,
cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementCompute>,
cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementCompute, RoundStyle_, supportSource_>,
TileShape_MNK,
EpilogueTile
>;
Expand Down
7 changes: 4 additions & 3 deletions include/cutlass/epilogue/collective/xe_array_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,9 @@ class CollectiveEpilogue<
using ElementScalar = typename FusionCallbacks::ElementScalar;
static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest;

static_assert(cute::is_same_v<typename FusionCallbacks::Operation,
fusion::LinearCombination<ElementAccumulator, ElementCompute, ElementSource, ElementScalar, RoundStyle>>,
static_assert(cute::is_any_of_v<typename FusionCallbacks::Operation,
fusion::LinearCombination<ElementAccumulator, ElementCompute, ElementSource, ElementScalar, RoundStyle, false>,
fusion::LinearCombination<ElementAccumulator, ElementCompute, ElementSource, ElementScalar, RoundStyle, true>>,
"Only Linear Combination Epilogue is supported for Grouped GEMM at the moment.");

static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
Expand All @@ -139,7 +140,7 @@ class CollectiveEpilogue<
Layout<CopyThreadShape>{},
make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{}))));
private:
constexpr static bool is_source_supported = not cute::is_void_v<ElementC>;
constexpr static bool is_source_supported = not cute::is_void_v<ElementC> && FusionCallbacks::Operation::IsSourceSupported;
constexpr static bool is_destination_supported = not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>;

public:
Expand Down
5 changes: 3 additions & 2 deletions include/cutlass/epilogue/fusion/operations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,13 @@ template<
class ElementCompute_,
class ElementSource_ = ElementOutput_,
class ElementScalar_ = ElementCompute_,
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest,
bool supportSource_ = true
>
struct LinearCombination
: ScaledAcc<ElementOutput_, ElementCompute_, ElementScalar_, RoundStyle_> {
using ElementSource = ElementSource_;
static constexpr bool IsSourceSupported = true;
static constexpr bool IsSourceSupported = supportSource_;
};

// D = activation(alpha * acc + beta * C)
Expand Down
14 changes: 8 additions & 6 deletions include/cutlass/epilogue/fusion/xe_callbacks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ template <
class ElementScalar_,
FloatRoundStyle RoundStyle_,
class CtaTileShapeMNK_,
class EpilogueTile_
class EpilogueTile_,
bool supportSource_
>
struct FusionCallbacks<
epilogue::IntelXeXMX16,
fusion::LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_>,
fusion::LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_, supportSource_>,
CtaTileShapeMNK_,
EpilogueTile_
> : Sm90LinearCombination<typename cutlass::detail::get_unpacked_element_type<ElementOutput_>::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> {
Expand All @@ -77,7 +78,7 @@ struct FusionCallbacks<
using ElementCompute = ElementCompute_;
using ElementSource = ElementSource_;
using ElementScalar = ElementScalar_;
using Operation = fusion::LinearCombination<ElementOutput, ElementCompute, ElementSource_, ElementScalar, RoundStyle_>;
using Operation = fusion::LinearCombination<ElementOutput, ElementCompute, ElementSource_, ElementScalar, RoundStyle_, supportSource_>;

struct Arguments {
ElementScalar alpha = ElementScalar(1);
Expand Down Expand Up @@ -730,11 +731,12 @@ template <
class ElementScalar_,
FloatRoundStyle RoundStyle_,
class CtaTileShapeMNK_,
class EpilogueTile_
class EpilogueTile_,
bool supportSource_
>
struct FusionCallbacks<
epilogue::IntelXeXMX16Group,
fusion::LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_>,
fusion::LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_, supportSource_>,
CtaTileShapeMNK_,
EpilogueTile_
> : Sm90LinearCombinationPtrArray<typename cutlass::detail::get_unpacked_element_type<ElementOutput_>::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> {
Expand All @@ -744,7 +746,7 @@ struct FusionCallbacks<
using ElementCompute = ElementCompute_;
using ElementSource = ElementSource_;
using ElementScalar = ElementScalar_;
using Operation = fusion::LinearCombination<ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle_>;
using Operation = fusion::LinearCombination<ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle_, supportSource_>;

struct Arguments {
ElementScalar alpha = ElementScalar(1);
Expand Down