Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions benchmarks/gemm/benchmark_runner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,16 @@ namespace cutlass::benchmark {

///////////////////////////////////////////////////////////////////////////////////////////////////

#if defined(SYCL_INTEL_TARGET)
template <class T, int Stages = 0>
template <class T>
static constexpr auto is_mixed_dtype = false;

#if defined(SYCL_INTEL_TARGET)
template <int Stages>
static constexpr auto is_mixed_dtype<cutlass::gemm::MainloopIntelXeXMX16MixedPrecision<Stages>> = true;
#else
template <class T, int Stages = 0>
static constexpr auto is_mixed_dtype = false;
#endif

template <class T, class = void>
// ScaleType
template <class, class = void>
struct ScaleType {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here. Also, this should be a type alias.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you mark this as resolved? You didn't address it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we discussed earlier, I will mark this PR as a draft because Peter's PR may cover my improvement. I will re-check this PR after Peter's PR is merged.

using type = int;
};
Expand All @@ -81,7 +79,8 @@ struct ScaleType<T, cute::void_t<typename T::ElementScale>> {
using type = typename T::ElementScale;
};

template <class T, class = void>
// ZeroType
template <class, class = void>
struct ZeroType {
using type = int;
};
Expand All @@ -90,7 +89,8 @@ struct ZeroType<T, cute::void_t<typename T::ElementZero>> {
using type = typename T::ElementZero;
};

template <class T, class = void>
// ScaleStride
template <class, class = void>
struct ScaleStride {
using type = int;
};
Expand All @@ -99,7 +99,8 @@ struct ScaleStride<T, cute::void_t<typename T::StrideScale>> {
using type = typename T::StrideScale;
};

template <class T, class = void>
// ZeroStride
template <class, class = void>
struct ZeroStride {
using type = int;
};
Expand Down
24 changes: 12 additions & 12 deletions benchmarks/gemm/benchmarks_sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,8 @@ using PvcMixedPrecisionGemmFP16U4FP16S8FP16S4_RCR_1 = cutlass::gemm::device::Mix
typename TiledMMAHelper<MMA_Atom<XE_8x16x32_S32S8S8S32_TT>, Layout<Shape<_32, _128, _32>>,
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA,
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N,
cutlass::epilogue::fusion::LinearCombination<int, int,
int, int, cutlass::FloatRoundStyle::round_to_nearest>,
cutlass::epilogue::fusion::LinearCombination<int32_t, int32_t,
int32_t, int32_t, cutlass::FloatRoundStyle::round_to_nearest>,
2
>;

Expand All @@ -373,8 +373,8 @@ using PvcMixedPrecisionGemmFP16U4S8S8FP16S4_RCR_1 = cutlass::gemm::device::Mixed
typename TiledMMAHelper<MMA_Atom<XE_8x16x32_S32S8S8S32_TT>, Layout<Shape<_32, _128, _32>>,
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA,
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U8x8x16_ST_N,
cutlass::epilogue::fusion::LinearCombination<int, int,
int, int, cutlass::FloatRoundStyle::round_to_nearest>,
cutlass::epilogue::fusion::LinearCombination<int32_t, int32_t,
int32_t, int32_t, cutlass::FloatRoundStyle::round_to_nearest>,
2
>;

Expand All @@ -389,8 +389,8 @@ using PvcMixedPrecisionGemmBF16U4BF16S8BF16S4_RCR_1 = cutlass::gemm::device::Mix
typename TiledMMAHelper<MMA_Atom<XE_8x16x32_S32S8S8S32_TT>, Layout<Shape<_32, _128, _32>>,
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA,
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N,
cutlass::epilogue::fusion::LinearCombination<int, int,
int, int, cutlass::FloatRoundStyle::round_to_nearest>,
cutlass::epilogue::fusion::LinearCombination<int32_t, int32_t,
int32_t, int32_t, cutlass::FloatRoundStyle::round_to_nearest>,
2
>;

Expand All @@ -405,8 +405,8 @@ using PvcMixedPrecisionGemmBF16U4S8S8BF16S4_RCR_1 = cutlass::gemm::device::Mixed
typename TiledMMAHelper<MMA_Atom<XE_8x16x32_S32S8S8S32_TT>, Layout<Shape<_32, _128, _32>>,
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA,
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U8x8x16_ST_N,
cutlass::epilogue::fusion::LinearCombination<int, int,
int, int, cutlass::FloatRoundStyle::round_to_nearest>,
cutlass::epilogue::fusion::LinearCombination<int32_t, int32_t,
int32_t, int32_t, cutlass::FloatRoundStyle::round_to_nearest>,
2
>;

Expand All @@ -421,8 +421,8 @@ using PvcMixedPrecisionGemmBF16S8BF16S8BF16S8_RCR_1 = cutlass::gemm::device::Mix
typename TiledMMAHelper<MMA_Atom<XE_8x16x32_S32S8S8S32_TT>, Layout<Shape<_32, _128, _32>>,
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA,
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U8x16x32_LD_T, XE_2D_U16x8x16_ST_N,
cutlass::epilogue::fusion::LinearCombination<int, int,
int, int, cutlass::FloatRoundStyle::round_to_nearest>,
cutlass::epilogue::fusion::LinearCombination<int32_t, int32_t,
int32_t, int32_t, cutlass::FloatRoundStyle::round_to_nearest>,
2
>;

Expand All @@ -437,8 +437,8 @@ using PvcMixedPrecisionGemmFP16S8FP16S8FP16S8_RCR_1 = cutlass::gemm::device::Mix
typename TiledMMAHelper<MMA_Atom<XE_8x16x32_S32S8S8S32_TT>, Layout<Shape<_32, _128, _32>>,
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA,
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U8x16x32_LD_T, XE_2D_U16x8x16_ST_N,
cutlass::epilogue::fusion::LinearCombination<int, int,
int, int, cutlass::FloatRoundStyle::round_to_nearest>,
cutlass::epilogue::fusion::LinearCombination<int32_t, int32_t,
int32_t, int32_t, cutlass::FloatRoundStyle::round_to_nearest>,
2
>;

Expand Down
108 changes: 30 additions & 78 deletions examples/sycl/02_bmg_gemm_mixed_dtype/02_bmg_gemm_bf16_s8_bf16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,86 +253,38 @@ struct ExampleRunner {
// Methods
//

bool verify(const Options &options) {

//
// Compute reference output (default gemm kernel w/ ElementA == ElementB)
//

using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;

// Workgroup-level tile
using TileShape = Shape<_256, _256, _32>;

using TiledMma =
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;

constexpr int PipelineStages = 3;
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;

using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementCompute,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
decltype(tile_shape(TiledMma()))>;

using CollectiveEpilogueRef = cutlass::epilogue::collective::CollectiveEpilogue<
EpilogueDispatchPolicy,
TileShape,
ElementAccumulator,
cutlass::gemm::TagToStrideC_t<LayoutC>,
ElementOutput,
cutlass::gemm::TagToStrideC_t<LayoutD>,
FusionCallBacks,
XE_2D_U32x8x16_LD_N,
void, void,
XE_2D_U32x8x16_ST_N,
void, void>;

// Mainloop
using CollectiveMainloopRef = cutlass::gemm::collective::CollectiveMma<
GEMMDispatchPolicy,
TileShape,
ElementMMA,
cutlass::gemm::TagToStrideA_t<LayoutA>,
ElementMMA,
cutlass::gemm::TagToStrideB_t<LayoutB>,
TiledMma,
GmemTiledCopyA, void, void, cute::identity, // A
GmemTiledCopyB, void, void, cute::identity // B
>;

using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>,
CollectiveMainloopRef,
CollectiveEpilogueRef
>;

using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;

typename GemmRef::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{block_A_dq.get(), stride_A, block_B_dq.get(), stride_B},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_ref_D.get(), stride_D}
};

// Run the gemm where the scaling is performed outside of the kernel.
GemmRef gemm_ref;
size_t workspace_size = GemmRef::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm_ref.run());
bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) {
auto [M, N, K, L] = problem_size;

cutlass::TensorRef ref_A(block_A_dq.get(), LayoutA::packed({M, K}));
cutlass::TensorRef ref_B(block_B_dq.get(), LayoutB::packed({K, N}));
cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N}));
cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N}));

cutlass::reference::device::GemmComplex(
{M, N, K},
alpha,
ref_A,
cutlass::ComplexTransform::kNone,
ref_B,
cutlass::ComplexTransform::kNone,
beta,
ref_C,
ref_D,
ElementAccumulator(0),
L, // batch_count
M * K, // batch_stride_A
K * N, // batch_stride_B
M * N, // batch_stride_C
M * N // batch_stride_D
);

// CUTLASS on SYCL uses the compatibility library syclcompat for e.g. default in-order queue
syclcompat::wait();

// compare_reference
ElementOutput const epsilon(1e-2f);
ElementOutput const non_zero_floor(1e-4f);
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
return passed;
return cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
}

template <class Element>
Expand Down Expand Up @@ -462,7 +414,7 @@ struct ExampleRunner {
syclcompat::wait();

// Verify that the result is correct
bool passed = verify(options);
bool passed = verify(problem_size, options.alpha, options.beta);
std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl;

if(!passed) return cutlass::Status::kErrorInternal;
Expand Down
108 changes: 30 additions & 78 deletions examples/sycl/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_f16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,86 +242,38 @@ struct ExampleRunner {
// Methods
//

bool verify(const Options &options) {

//
// Compute reference output (default gemm kernel w/ ElementA == ElementB)
//

using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
using GmemTiledCopyB = XE_2D_U16x16x16_LD_T;

// Workgroup-level tile
using TileShape = Shape<_256, _256, _32>;

using TiledMma =
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32F16F16F32_TT>, Layout<TileShape>,
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;

constexpr int PipelineStages = 3;
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;

using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementAccumulator, ElementCompute,
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;

using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
decltype(tile_shape(TiledMma()))>;

using CollectiveEpilogueRef = cutlass::epilogue::collective::CollectiveEpilogue<
EpilogueDispatchPolicy,
TileShape,
ElementAccumulator,
cutlass::gemm::TagToStrideC_t<LayoutC>,
ElementOutput,
cutlass::gemm::TagToStrideC_t<LayoutD>,
FusionCallBacks,
XE_2D_U32x8x16_LD_N,
void, void,
XE_2D_U16x8x16_ST_N,
void, void>;

// Mainloop
using CollectiveMainloopRef = cutlass::gemm::collective::CollectiveMma<
GEMMDispatchPolicy,
TileShape,
ElementMMA,
cutlass::gemm::TagToStrideA_t<LayoutA>,
ElementMMA,
cutlass::gemm::TagToStrideB_t<LayoutB>,
TiledMma,
GmemTiledCopyA, void, void, cute::identity, // A
GmemTiledCopyB, void, void, cute::identity // B
>;

using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>,
CollectiveMainloopRef,
CollectiveEpilogueRef
>;

using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;

typename GemmRef::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{block_A_dq.get(), stride_A, block_B_dq.get(), stride_B},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_ref_D.get(), stride_D}
};

// Run the gemm where the scaling is performed outside of the kernel.
GemmRef gemm_ref;
size_t workspace_size = GemmRef::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm_ref.run());
bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) {
auto [M, N, K, L] = problem_size;

cutlass::TensorRef ref_A(block_A_dq.get(), LayoutA::packed({M, K}));
cutlass::TensorRef ref_B(block_B_dq.get(), LayoutB::packed({K, N}));
cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N}));
cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N}));

cutlass::reference::device::GemmComplex(
{M, N, K},
alpha,
ref_A,
cutlass::ComplexTransform::kNone,
ref_B,
cutlass::ComplexTransform::kNone,
beta,
ref_C,
ref_D,
ElementAccumulator(0),
L, // batch_count
M * K, // batch_stride_A
K * N, // batch_stride_B
M * N, // batch_stride_C
M * N // batch_stride_D
);

// CUTLASS on SYCL uses the compatibility library syclcompat for e.g. default in-order queue
syclcompat::wait();

// compare_reference
ElementOutput const epsilon(1e-2f);
ElementOutput const non_zero_floor(1e-4f);
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
return passed;
return cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
}

template <class Element>
Expand Down Expand Up @@ -553,7 +505,7 @@ struct ExampleRunner {
syclcompat::wait();

// Verify that the result is correct
bool passed = verify(options);
bool passed = verify(problem_size, options.alpha, options.beta);
std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl;

if(!passed) return cutlass::Status::kErrorInternal;
Expand Down
Loading
Loading