From 078762410554afb426e2ce9f8d37c3330081a73f Mon Sep 17 00:00:00 2001 From: taozha2 Date: Thu, 24 Jul 2025 08:38:40 +0800 Subject: [PATCH 1/9] enable mixed dtype benchmark --- .../input_sglang_gemm_mixed_dtype.in | 18 + benchmarks/gemm/CMakeLists.txt | 3 + benchmarks/gemm/benchmark_runner.hpp | 398 ++++++++++++++++-- benchmarks/gemm/benchmarks_sycl.hpp | 146 +++++++ benchmarks/gemm/gemm_configuration_sycl.hpp | 89 ++++ examples/common/sycl_common.hpp | 112 +---- .../00_bmg_gemm_with_sycl_queue.cpp | 6 +- .../epilogue/collective/xe_epilogue.hpp | 6 +- .../cutlass/util/mixed_dtype_utils.hpp | 141 +++++++ 9 files changed, 766 insertions(+), 153 deletions(-) create mode 100755 benchmarks/device/bmg/input_files/input_sglang_gemm_mixed_dtype.in diff --git a/benchmarks/device/bmg/input_files/input_sglang_gemm_mixed_dtype.in b/benchmarks/device/bmg/input_files/input_sglang_gemm_mixed_dtype.in new file mode 100755 index 0000000000..4177851cd3 --- /dev/null +++ b/benchmarks/device/bmg/input_files/input_sglang_gemm_mixed_dtype.in @@ -0,0 +1,18 @@ +############################################################################# +### Benchmarks for required shapes for second SGLang release ### +############################################################################# + +# The data type " FP16U4FP16F16FP16S4 " in the benchmark are: A, B, C, Mma, Scale, Zero + + +# int4 +PvcMixedPrecisionGemmFP16U4FP16F16FP16S4_RCR_1 --bm_name=mixed_dtype_int4 --m=32 --k=4096 --n=14336 --l=1 +PvcMixedPrecisionGemmBF16U4BF16BF16BF16S4_RCR_1 --bm_name=mixed_dtype_int4 --m=32 --k=4096 --n=14336 --l=1 +PvcMixedPrecisionGemmFP16U4FP16S8FP16S4_RCR_1 --bm_name=mixed_dtype_int4 --m=32 --k=4096 --n=14336 --l=1 +PvcMixedPrecisionGemmFP16U4S8S8FP16S4_RCR_1 --bm_name=mixed_dtype_int4 --m=32 --k=4096 --n=14336 --l=1 +PvcMixedPrecisionGemmBF16U4BF16S8BF16S4_RCR_1 --bm_name=mixed_dtype_int4 --m=32 --k=4096 --n=14336 --l=1 +PvcMixedPrecisionGemmBF16U4S8S8BF16S4_RCR_1 --bm_name=mixed_dtype_int4 --m=32 --k=4096 --n=14336 --l=1 + +# int8 +PvcMixedPrecisionGemmBF16S8BF16S8BF16S8_RCR_1 --bm_name=mixed_dtype_int4 --m=32 --k=4096 --n=14336 --l=1 +PvcMixedPrecisionGemmFP16S8FP16S8FP16S8_RCR_1 --bm_name=mixed_dtype_int4 --m=32 --k=4096 --n=14336 --l=1 diff --git a/benchmarks/gemm/CMakeLists.txt b/benchmarks/gemm/CMakeLists.txt index 1833646822..84a98ec419 100644 --- a/benchmarks/gemm/CMakeLists.txt +++ b/benchmarks/gemm/CMakeLists.txt @@ -35,6 +35,8 @@ set(CONFIG_FILE_INTEL_SGLANG_SPLITK --config_file=${CMAKE_SOURCE_DIR}/benchmarks set(CONFIG_FILE_CUDA --config_file=${CMAKE_SOURCE_DIR}/benchmarks/device/ampere/input_files/input_gemm.in) +set(CONFIG_FILE_INTEL_MIXED_DTYPE --config_file=${CMAKE_SOURCE_DIR}/benchmarks/device/bmg/input_files/input_sglang_gemm_mixed_dtype.in) + cutlass_benchmark_add_suite(cutlass_benchmarks_gemm) if(CUTLASS_ENABLE_SYCL) @@ -51,6 +53,7 @@ cutlass_benchmark_add_executable( CONFIG_FILE_INTEL_PYTORCH CONFIG_FILE_INTEL_SGLANG CONFIG_FILE_INTEL_SGLANG_SPLITK + CONFIG_FILE_INTEL_MIXED_DTYPE ) else() diff --git a/benchmarks/gemm/benchmark_runner.hpp b/benchmarks/gemm/benchmark_runner.hpp index 0dfb588f14..9cb1c2a5b9 100644 --- a/benchmarks/gemm/benchmark_runner.hpp +++ b/benchmarks/gemm/benchmark_runner.hpp @@ -49,6 +49,7 @@ #include "cutlass/util/reference/device/tensor_compare.h" #include "cutlass/util/reference/device/tensor_fill.h" #include "cutlass/util/reference/device/tensor_silu.h" +#include "cutlass/util/mixed_dtype_utils.hpp" #include "../common.hpp" @@ -60,30 +61,47 @@ namespace cutlass::benchmark { 
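+// The initialize_block helper removed below now lives in
+// cutlass/util/mixed_dtype_utils.hpp (included above); the traits that
+// follow detect a mixed-precision mainloop and expose its scale/zero
+// element and stride types, defaulting to int for regular GEMMs.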
/////////////////////////////////////////////////////////////////////////////////////////////////// -/// Helper to initialize a block of device data -template -bool initialize_block( - cutlass::DeviceAllocation& block, - uint64_t seed=2023) { - - Element scope_max, scope_min; - int bits_input = cutlass::sizeof_bits::value; - - if (bits_input == 1) { - scope_max = Element(2); - scope_min = Element(0); - } else if (bits_input <= 8) { - scope_max = Element(2); - scope_min = Element(-2); - } else { - scope_max = Element(8); - scope_min = Element(-8); - } +template +static constexpr auto is_mixed_dtype = false; - reference::device::BlockFillRandomUniform( - block.get(), block.size(), seed, scope_max, scope_min, 0); - return true; -} +template +static constexpr auto is_mixed_dtype> = true; + +template +struct ScaleType { + using type = int; +}; +template +struct ScaleType> { + using type = typename T::ElementScale; +}; + +template +struct ZeroType { + using type = int; +}; +template +struct ZeroType> { + using type = typename T::ElementZero; +}; + +template +struct ScaleStride { + using type = int; +}; +template +struct ScaleStride> { + using type = typename T::StrideScale; +}; + +template +struct ZeroStride { + using type = int; +}; +template +struct ZeroStride> { + using type = typename T::StrideZero; +}; /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -150,6 +168,15 @@ struct BenchmarkRunnerGemm { using ElementB = typename Gemm::ElementB; using ElementAcc = typename Gemm::ElementAccumulator; + using CollectiveMainloop = typename Gemm::GemmKernel::CollectiveMainloop; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementMma = CollectiveMainloop::TiledMma::ValTypeA; + + using ElementScale = ScaleType::type; + using ElementZero = ZeroType::type; + using StrideS = ScaleStride::type; + using StrideZ = ZeroStride::type; + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; using ElementC = typename Gemm::ElementC; using ElementOutput = typename CollectiveEpilogue::ElementOutput; @@ -168,7 +195,7 @@ struct BenchmarkRunnerGemm { using FusionDeEltMul = cutlass::epilogue::fusion::LinCombDeEltAct; using FusionLinComb = epilogue::fusion::LinearCombination< - ElementOutput, ElementCompute, ElementAccumulator, ElementAccumulator, + ElementAccumulator, ElementCompute, ElementAccumulator, ElementAccumulator, FloatRoundStyle::round_to_nearest>; // Epilogue used in ampere/gemm_configuration.hpp @@ -201,6 +228,10 @@ struct BenchmarkRunnerGemm { StrideC stride_C; StrideD stride_D; + StrideS stride_S; + StrideZ stride_Z; + + uint64_t seed; std::vector> block_A; @@ -210,20 +241,267 @@ struct BenchmarkRunnerGemm { DeviceAllocation block_ref_D; std::vector> block_Aux; + cutlass::DeviceAllocation block_scale; + cutlass::DeviceAllocation block_zero; + + DeviceAllocation block_A_verify; + DeviceAllocation block_B_verify; + BenchmarkRunnerGemm() : seed(0) {}; // // Methods // + template < + class QuantizedElement, + class DequantizedElement, + class OperandLayout, + class ElementScale, + class ElementZero, + class ScaleLayout, + class ZeroLayout> + static auto dequantize_A(DequantizedElement* dq_buffer, + QuantizedElement const* q_buffer, + OperandLayout const operand_layout, + ElementScale const* scale_buffer, + ElementZero const* zero_buffer, + ScaleLayout const scale_layout, + ZeroLayout const zero_layout, + int const group_size) { + if constexpr (std::is_same_v) { + return dq_buffer; + } + + std::vector 
dst(size(operand_layout) * sizeof_bits_v / 8, 0); + cutlass::device_memory::copy_to_host(dst.data(), (uint8_t*)dq_buffer, dst.size()); + + std::vector src(size(operand_layout) * sizeof_bits_v / 8, 0); + cutlass::device_memory::copy_to_host(src.data(), (uint8_t*)q_buffer, src.size()); + + std::vector scale(size(scale_layout) * sizeof_bits_v / 8, 0); + cutlass::device_memory::copy_to_host(scale.data(), (uint8_t*)scale_buffer, scale.size()); + + std::vector zero(size(zero_layout) * sizeof_bits_v / 8, 0); + cutlass::device_memory::copy_to_host(zero.data(), (uint8_t*)zero_buffer, zero.size()); + + syclcompat::wait(); + + auto dst_tensor = make_tensor(make_gmem_ptr(reinterpret_cast(dst.data())), select<1, 0, 2>(operand_layout)); + + auto src_tensor = [&]() { + if constexpr (sizeof_bits_v < 8) { + return make_tensor(cute::subbyte_iterator(src.data()), operand_layout); + } else { + return make_tensor(make_gmem_ptr(reinterpret_cast(src.data())), select<1, 0, 2>(operand_layout)); + } + }(); + + auto scale_tensor = make_tensor(make_gmem_ptr(reinterpret_cast(scale.data())), scale_layout); + + auto zero_tensor = [&]() { + if constexpr (sizeof_bits_v < 8) { + auto flatten_tensor = flatten(make_tensor(cute::subbyte_iterator(zero.data()), zero_layout)); + static_assert(rank(flatten_tensor.layout()) == 4); + return make_tensor(flatten_tensor.data(), select<1, 0, 2, 3>(flatten_tensor.layout())); + } else { + return make_tensor(make_gmem_ptr(reinterpret_cast(zero.data())), zero_layout); + } + }(); + + auto M = size<1>(src_tensor); + auto K = size<0>(src_tensor); + auto L = size<2>(src_tensor); + + static constexpr bool is_qnt = cutlass::platform::numeric_limits::is_integer; + + for (int l = 0; l < L; l++) { + for (int k= 0; k < K; k++) { + for (int m = 0; m < M; m++) { + auto src_data = [&]() { + if constexpr (is_qnt) { + if constexpr (sizeof_bits_v >= 8) { + return src_tensor(k, m, l); + } else { + return src_tensor(k, m, l).get(); + } + } else { + using ret_type = cute::conditional_t >= 8, ElementZero, int8_t>; + if constexpr (sizeof_bits_v >= 8) { + return (ret_type)(src_tensor(k, m, l)); + } else { + return (ret_type)(src_tensor(k, m, l).get()); + } + } + }(); + + auto scale_data = scale_tensor(m, k / group_size, l); + + using ret_type = cute::conditional_t >= 8, ElementZero, int8_t>; + ret_type zero_data = [&]() { + if constexpr (sizeof_bits_v >= 8) { + return zero_tensor(m, k / group_size, l); + } else { + auto zero_elements_packed_along_k = get<0>(zero_tensor.shape()); + return (ret_type)(zero_tensor((k / group_size) % zero_elements_packed_along_k, m, k / group_size / zero_elements_packed_along_k, l).get()); + } + }(); + + if constexpr (is_qnt) { + dst_tensor(k, m, l) = ((int)(src_data / scale_data)) + zero_data; + } else { + dst_tensor(k, m, l) = (src_data - zero_data) * scale_data; + } + } + } + } + + cutlass::device_memory::copy_to_device(dq_buffer, (DequantizedElement*)(raw_pointer_cast(dst_tensor.data())), dst_tensor.size()); + syclcompat::wait(); + return dq_buffer; + } + + template < + class QuantizedElement, + class DequantizedElement, + class OperandLayout, + class ElementScale, + class ElementZero, + class ScaleLayout, + class ZeroLayout> + static auto dequantize_B(DequantizedElement* dq_buffer, + QuantizedElement const* q_buffer, + OperandLayout const operand_layout, + ElementScale const* scale_buffer, + ElementZero const* zero_buffer, + ScaleLayout const scale_layout, + ZeroLayout const zero_layout, + int const group_size) { + std::vector dst(size(operand_layout) * sizeof_bits_v / 8, 0); + 
cutlass::device_memory::copy_to_host(dst.data(), (uint8_t*)dq_buffer, dst.size()); + + std::vector src(size(operand_layout) * sizeof_bits_v / 8, 0); + cutlass::device_memory::copy_to_host(src.data(), (uint8_t*)q_buffer, src.size()); + + std::vector scale(size(scale_layout) * sizeof_bits_v / 8, 0); + cutlass::device_memory::copy_to_host(scale.data(), (uint8_t*)scale_buffer, scale.size()); + + std::vector zero(size(zero_layout) * sizeof_bits_v / 8, 0); + cutlass::device_memory::copy_to_host(zero.data(), (uint8_t*)zero_buffer, zero.size()); + + syclcompat::wait(); + + auto dst_tensor = make_tensor(make_gmem_ptr(reinterpret_cast(dst.data())), operand_layout); + + auto src_tensor = [&]() { + if constexpr (sizeof_bits_v < 8) { + return make_tensor(cute::subbyte_iterator(src.data()), operand_layout); + } else { + return make_tensor(make_gmem_ptr(reinterpret_cast(src.data())), operand_layout); + } + }(); + + auto scale_tensor = make_tensor(make_gmem_ptr(reinterpret_cast(scale.data())), scale_layout); + + auto zero_tensor = [&]() { + if constexpr (sizeof_bits_v < 8) { + auto flatten_tensor = flatten(make_tensor(cute::subbyte_iterator(zero.data()), zero_layout)); + static_assert(rank(flatten_tensor.layout()) == 4); + return make_tensor(flatten_tensor.data(), select<1, 0, 2, 3>(flatten_tensor.layout())); + } else { + return make_tensor(make_gmem_ptr(reinterpret_cast(zero.data())), zero_layout); + } + }(); + + auto N = size<0>(src_tensor); + auto K = size<1>(src_tensor); + auto L = size<2>(src_tensor); + + for (int l = 0; l < L; l++) { + for (int k= 0; k < K; k++) { + for (int n = 0; n < N; n++) { + using ret_type = cute::conditional_t >= 8, ElementZero, int8_t>; + ret_type a = [&]() { + if constexpr (sizeof_bits_v >= 8) { + return (ret_type)(src_tensor(n, k, l)); + } else { + return (ret_type)(src_tensor(n, k, l).get()); + }}(); + + ret_type b = [&]() { + if constexpr (sizeof_bits_v >= 8) { + return (ret_type)(zero_tensor(n, k / group_size, l)); + } else { + auto k_packed = get<0>(zero_tensor.shape()); + return (ret_type)(zero_tensor((k / group_size) % k_packed, n, k / group_size / k_packed, l).get()); + } + }(); + + dst_tensor(n, k, l) = ((ElementScale)(a - b)) * scale_tensor(n, k / group_size, l); + } + } + } + + cutlass::device_memory::copy_to_device(dq_buffer, (DequantizedElement*)(raw_pointer_cast(dst_tensor.data())), dst_tensor.size()); + syclcompat::wait(); + return dq_buffer; + } + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { auto [M, N, K, L] = problem_size; - TensorRef ref_A(block_A[0].get(), LayoutA::packed({M, K})); - TensorRef ref_B(block_B[0].get(), LayoutB::packed({K, N})); TensorRef ref_C(block_C[0].get(), LayoutC::packed({M, N})); TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + auto [ptr_A, ptr_B] = [&]() { + if constexpr (!is_mixed_dtype) { + return make_tuple(block_A[0].get(), block_B[0].get()); + } else { + static constexpr bool IsAQuant = cutlass::platform::numeric_limits::is_integer + ^ cutlass::platform::numeric_limits::is_integer; + static constexpr bool IsBQuant = cutlass::platform::numeric_limits::is_integer + ^ cutlass::platform::numeric_limits::is_integer; + + static constexpr bool IsATransformed = CollectiveMainloop::IsATransformed; + auto dq_mn_size = IsATransformed ? 
M : N; + + auto shape_ab = cute::make_shape(dq_mn_size, K, L); + auto shape_scale = cute::make_shape(dq_mn_size, K / 128, L); + static constexpr auto k_packed = CollectiveMainloop::zero_elements_packed_along_k; + auto shape_zero = [&]() { + if constexpr (is_tuple_v(stride_Z))>>) { + return cute::make_shape(dq_mn_size, cute::make_shape(k_packed, + cute::max(1, K / 128 / k_packed)), L); + } else { + return shape_scale; + } + }(); + + auto ptr_A = [&]() { + if constexpr (IsAQuant) { + return dequantize_A(block_A_verify.get(), block_A[0].get(), make_layout(shape_ab, stride_A), block_scale.get(), + block_zero.get(), make_layout(shape_scale, stride_S), make_layout(shape_zero, stride_Z), 128); + } else { + return block_A_verify.get(); + } + }(); + + auto ptr_B = [&]() { + if constexpr (IsBQuant) { + return dequantize_B(block_B_verify.get(), block_B[0].get(), make_layout(shape_ab, stride_B), block_scale.get(), + block_zero.get(), make_layout(shape_scale, stride_S), make_layout(shape_zero, stride_Z), 128); + } else { + return block_B_verify.get(); + } + }(); + + return make_tuple(ptr_A, ptr_B); + } + }(); + + TensorRef ref_A(ptr_A, LayoutA::packed({M, K})); + TensorRef ref_B(ptr_B, LayoutB::packed({K, N})); + reference::device::GemmComplex( {M, N, K}, alpha, @@ -283,10 +561,40 @@ struct BenchmarkRunnerGemm { std::size_t size_A = cute::cosize(make_layout(cute::make_shape(M, K, L), stride_A)); std::size_t size_B = cute::cosize(make_layout(cute::make_shape(N, K, L), stride_B)); std::size_t size_C = cute::cosize(make_layout(cute::make_shape(M, N, L), stride_C)); - std::size_t mem_occupied_ABC = (size_A * sizeof(ElementA)) + (size_B * sizeof(ElementB)) + - (size_C * sizeof(ElementC)); + std::size_t mem_occupied_ABC = ((size_A * sizeof_bits_v) + (size_B * sizeof_bits_v) + + (size_C * sizeof_bits_v)) / sizeof_bits_v; count = std::ceil(static_cast(cutlass::get_llc_size()) / static_cast(mem_occupied_ABC)) + 1; + if constexpr (is_mixed_dtype) { + static constexpr bool IsATransformed = CollectiveMainloop::IsATransformed; + + auto dq_mn_size = IsATransformed ? 
M : N; + auto scale_k = K / 128; + + static constexpr auto k_packed = CollectiveMainloop::zero_elements_packed_along_k; + static constexpr auto is_tuple_z = is_tuple_v(StrideZ{}))>>; + + auto shape_scale = cute::make_shape(dq_mn_size, scale_k, L); + + stride_S = cutlass::make_cute_packed_stride(StrideS{}, shape_scale); + stride_Z = [&]() { + if constexpr (is_tuple_z) { + return make_stride(Int{}, make_stride(_1{}, int64_t(k_packed * dq_mn_size)), int64_t(dq_mn_size * scale_k)); + } else { + return stride_S; + } + }(); + + block_A_verify.reset(size_A); + block_B_verify.reset(size_B); + + block_scale.reset(static_cast(scale_k) * L * dq_mn_size); + block_zero.reset(static_cast(scale_k) * L * dq_mn_size); + + initialize_block(block_scale, seed, ElementScale(1), ElementScale(4)); + initialize_block(block_zero, seed); + } + for(int i=0; i < count; i++) { block_A.emplace_back(); block_B.emplace_back(); @@ -301,8 +609,13 @@ struct BenchmarkRunnerGemm { block_A[i].reset(size_A); block_B[i].reset(size_B); block_C[i].reset(size_C); - initialize_block(block_A[i], seed + i); - initialize_block(block_B[i], seed + i); + if (is_mixed_dtype && i == 0) { + initialize_mixed_dtype_block(block_A[i], block_A_verify, seed + i); + initialize_mixed_dtype_block(block_B[i], block_B_verify, seed + i); + } else { + initialize_block(block_A[i], seed + i); + initialize_block(block_B[i], seed + i); + } initialize_block(block_C[i], seed + i); if constexpr (epi_is_deeltactmul) { block_Aux[i].reset(size_C); @@ -325,8 +638,15 @@ struct BenchmarkRunnerGemm { typename Gemm::GemmKernel::Arguments arguments = GemmConfiguration::defaultArguments(); arguments.mode = gemm::GemmUniversalMode::kGemm; arguments.problem_shape = problem_size; - arguments.mainloop = {block_A[0].get(), stride_A, block_B[0].get(), stride_B}; - arguments.epilogue = {{options.alpha, options.beta}, block_C[0].get(), stride_C, block_D.get(), stride_D}; + + if constexpr (!is_mixed_dtype) { + arguments.mainloop = {block_A[0].get(), stride_A, block_B[0].get(), stride_B}; + } else { + arguments.mainloop = {block_A[0].get(), stride_A, block_B[0].get(), stride_B, block_scale.get(), + stride_S, block_zero.get(), stride_Z, 128}; + } + + arguments.epilogue = {{ElementAcc(options.alpha), ElementAcc(options.beta)}, block_C[0].get(), stride_C, block_D.get(), stride_D}; arguments.hw_info = hw_info; if constexpr(epi_is_deeltactmul){ @@ -394,10 +714,10 @@ struct BenchmarkRunnerGemm { auto gflop = 2.0 * options.m * options.n * options.k * options.l * 1e-9; auto mega_bytes_transferred = static_cast( - options.m * options.k * sizeof(ElementA) + - options.k * options.n * sizeof(ElementB) + - (options.beta != 0 ? 2 : 1) * options.m * options.n * sizeof(ElementC) - ) * 1e-6 * options.l; + options.m * options.k * sizeof_bits_v + + options.k * options.n * sizeof_bits_v + + (options.beta != 0 ? 
2 : 1) * options.m * options.n * sizeof_bits_v + ) * 1e-6 * options.l / sizeof_bits_v; initialize_counters(state); int32_t counter = 1; @@ -408,9 +728,13 @@ struct BenchmarkRunnerGemm { gemm::GemmUniversalMode::kGemm, problem_size, {block_A[input_num].get(), stride_A, block_B[input_num].get(), stride_B}, - {{options.alpha, options.beta}, block_C[input_num].get(), stride_C, block_D.get(), stride_D}, + {{ElementAcc(options.alpha), ElementAcc(options.beta)}, block_C[input_num].get(), stride_C, block_D.get(), stride_D}, hw_info }; + if constexpr (is_mixed_dtype) { + arguments.mainloop = {block_A[input_num].get(), stride_A, block_B[input_num].get(), stride_B, block_scale.get(), + stride_S, block_zero.get(), stride_Z, 128}; + } if constexpr(epi_is_deeltactmul){ arguments.epilogue.thread.aux_ptr = block_Aux[input_num].get(); arguments.epilogue.thread.dAux = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); diff --git a/benchmarks/gemm/benchmarks_sycl.hpp b/benchmarks/gemm/benchmarks_sycl.hpp index fbdbfb4277..ee28c1cfce 100644 --- a/benchmarks/gemm/benchmarks_sycl.hpp +++ b/benchmarks/gemm/benchmarks_sycl.hpp @@ -314,6 +314,134 @@ using PvcGemmFP16FP16FP32_SplitK_RCR_5 = cutlass::gemm::device::GemmConfiguratio XE_2D_U16x8x32_LD_N, XE_2D_U16x16x16_LD_T >; +using PvcMixedPrecisionGemmFP16U4FP16F16FP16S4_RCR_1 = cutlass::gemm::device::MixedPrecisionGemmConfiguration< + cutlass::arch::IntelXe, + cutlass::half_t, cutlass::layout::RowMajor, + cutlass::uint4_t, cutlass::layout::ColumnMajor, + cutlass::half_t, cutlass::layout::RowMajor, + cutlass::half_t, cute::Stride<_1, int64_t, int64_t>, + cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>, + Shape<_32, _128, _32>, Scheduler::Gemm, + typename TiledMMAHelper, Layout>, + Layout, Stride<_4, _1, _0>>>::TiledMMA, + XE_2D_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N, + cutlass::epilogue::fusion::LinearCombination, + 2 + >; + +using PvcMixedPrecisionGemmBF16U4BF16BF16BF16S4_RCR_1 = cutlass::gemm::device::MixedPrecisionGemmConfiguration< + cutlass::arch::IntelXe, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::uint4_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::bfloat16_t, cute::Stride<_1, int64_t, int64_t>, + cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>, + Shape<_32, _128, _32>, Scheduler::Gemm, + typename TiledMMAHelper, Layout>, + Layout, Stride<_4, _1, _0>>>::TiledMMA, + XE_2D_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N, + cutlass::epilogue::fusion::LinearCombination, + 2 + >; + +using PvcMixedPrecisionGemmFP16U4FP16S8FP16S4_RCR_1 = cutlass::gemm::device::MixedPrecisionGemmConfiguration< + cutlass::arch::IntelXe, + cutlass::half_t, cutlass::layout::RowMajor, + cutlass::uint4_t, cutlass::layout::ColumnMajor, + cutlass::half_t, cutlass::layout::RowMajor, + cutlass::half_t, cute::Stride<_1, int64_t, int64_t>, + cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>, + Shape<_32, _128, _32>, Scheduler::Gemm, + typename TiledMMAHelper, Layout>, + Layout, Stride<_4, _1, _0>>>::TiledMMA, + XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N, + cutlass::epilogue::fusion::LinearCombination, + 2 + >; + +using PvcMixedPrecisionGemmFP16U4S8S8FP16S4_RCR_1 = cutlass::gemm::device::MixedPrecisionGemmConfiguration< + cutlass::arch::IntelXe, + cutlass::half_t, cutlass::layout::RowMajor, + cutlass::uint4_t, cutlass::layout::ColumnMajor, + cutlass::int8_t, 
cutlass::layout::RowMajor, + cutlass::half_t, cute::Stride<_1, int64_t, int64_t>, + cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>, + Shape<_32, _128, _32>, Scheduler::Gemm, + typename TiledMMAHelper, Layout>, + Layout, Stride<_4, _1, _0>>>::TiledMMA, + XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U8x8x16_ST_N, + cutlass::epilogue::fusion::LinearCombination, + 2 + >; + +using PvcMixedPrecisionGemmBF16U4BF16S8BF16S4_RCR_1 = cutlass::gemm::device::MixedPrecisionGemmConfiguration< + cutlass::arch::IntelXe, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::uint4_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::bfloat16_t, cute::Stride<_1, int64_t, int64_t>, + cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>, + Shape<_32, _128, _32>, Scheduler::Gemm, + typename TiledMMAHelper, Layout>, + Layout, Stride<_4, _1, _0>>>::TiledMMA, + XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N, + cutlass::epilogue::fusion::LinearCombination, + 2 + >; + +using PvcMixedPrecisionGemmBF16U4S8S8BF16S4_RCR_1 = cutlass::gemm::device::MixedPrecisionGemmConfiguration< + cutlass::arch::IntelXe, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::uint4_t, cutlass::layout::ColumnMajor, + cutlass::int8_t, cutlass::layout::RowMajor, + cutlass::bfloat16_t, cute::Stride<_1, int64_t, int64_t>, + cutlass::int4_t, cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>, + Shape<_32, _128, _32>, Scheduler::Gemm, + typename TiledMMAHelper, Layout>, + Layout, Stride<_4, _1, _0>>>::TiledMMA, + XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U8x8x16_ST_N, + cutlass::epilogue::fusion::LinearCombination, + 2 + >; + +using PvcMixedPrecisionGemmBF16S8BF16S8BF16S8_RCR_1 = cutlass::gemm::device::MixedPrecisionGemmConfiguration< + cutlass::arch::IntelXe, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::int8_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::bfloat16_t, cute::Stride<_1, int64_t, int64_t>, + cutlass::int8_t, cute::Stride<_1, int64_t, int64_t>, + Shape<_32, _128, _32>, Scheduler::Gemm, + typename TiledMMAHelper, Layout>, + Layout, Stride<_4, _1, _0>>>::TiledMMA, + XE_2D_Packed_U16x32x32_LD_N, XE_2D_U8x16x32_LD_T, XE_2D_U16x8x16_ST_N, + cutlass::epilogue::fusion::LinearCombination, + 2 + >; + +using PvcMixedPrecisionGemmFP16S8FP16S8FP16S8_RCR_1 = cutlass::gemm::device::MixedPrecisionGemmConfiguration< + cutlass::arch::IntelXe, + cutlass::half_t, cutlass::layout::RowMajor, + cutlass::int8_t, cutlass::layout::ColumnMajor, + cutlass::half_t, cutlass::layout::RowMajor, + cutlass::half_t, cute::Stride<_1, int64_t, int64_t>, + cutlass::int8_t, cute::Stride<_1, int64_t, int64_t>, + Shape<_32, _128, _32>, Scheduler::Gemm, + typename TiledMMAHelper, Layout>, + Layout, Stride<_4, _1, _0>>>::TiledMMA, + XE_2D_Packed_U16x32x32_LD_N, XE_2D_U8x16x32_LD_T, XE_2D_U16x8x16_ST_N, + cutlass::epilogue::fusion::LinearCombination, + 2 + >; + CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmFP16FP16FP32_RCR_5); CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmFP16FP16FP32_RCR_7); CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmFP16FP16FP32_RCR_9); @@ -322,6 +450,16 @@ CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmFP16FP16FP32_RCR_7_mul); CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmFP16FP16FP32_RCR_8_silu); CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmFP16FP16FP32_SplitK_RCR_5); +// Below are MixedPrecisionGemm, the data type are A, B, C, Mma, Scale, Zero 
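+// e.g. FP16U4FP16F16FP16S4 reads as: A=FP16 (half_t), B=U4 (uint4_t, quantized),
+// C=FP16, Mma=F16, Scale=FP16, Zero=S4 (int4_t)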
+CUTLASS_CREATE_GEMM_BENCHMARK(PvcMixedPrecisionGemmFP16U4FP16F16FP16S4_RCR_1); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcMixedPrecisionGemmBF16U4BF16BF16BF16S4_RCR_1); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcMixedPrecisionGemmFP16U4FP16S8FP16S4_RCR_1); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcMixedPrecisionGemmFP16U4S8S8FP16S4_RCR_1); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcMixedPrecisionGemmBF16U4BF16S8BF16S4_RCR_1); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcMixedPrecisionGemmBF16U4S8S8BF16S4_RCR_1); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcMixedPrecisionGemmBF16S8BF16S8BF16S8_RCR_1); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcMixedPrecisionGemmFP16S8FP16S8FP16S8_RCR_1); + using PvcGemmBF16BF16FP32_SplitK_RRR_1 = cutlass::gemm::device::GemmConfiguration< cutlass::arch::IntelXe, cutlass::bfloat16_t, cutlass::layout::RowMajor, @@ -381,6 +519,14 @@ static void register_gemm_benchmarks() { CUTLASS_BENCHMARK(PvcGemmFP16FP16FP32_RCR_7_mul); CUTLASS_BENCHMARK(PvcGemmFP16FP16FP32_RCR_8_silu); CUTLASS_BENCHMARK(PvcGemmFP16FP16FP32_SplitK_RCR_5); + CUTLASS_BENCHMARK(PvcMixedPrecisionGemmFP16U4FP16F16FP16S4_RCR_1); + CUTLASS_BENCHMARK(PvcMixedPrecisionGemmBF16U4BF16BF16BF16S4_RCR_1); + CUTLASS_BENCHMARK(PvcMixedPrecisionGemmFP16U4FP16S8FP16S4_RCR_1); + CUTLASS_BENCHMARK(PvcMixedPrecisionGemmFP16U4S8S8FP16S4_RCR_1); + CUTLASS_BENCHMARK(PvcMixedPrecisionGemmBF16U4BF16S8BF16S4_RCR_1); + CUTLASS_BENCHMARK(PvcMixedPrecisionGemmBF16U4S8S8BF16S4_RCR_1); + CUTLASS_BENCHMARK(PvcMixedPrecisionGemmBF16S8BF16S8BF16S8_RCR_1); + CUTLASS_BENCHMARK(PvcMixedPrecisionGemmFP16S8FP16S8FP16S8_RCR_1); // CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RCR_Linear); // CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RCR_Linear_MoE); diff --git a/benchmarks/gemm/gemm_configuration_sycl.hpp b/benchmarks/gemm/gemm_configuration_sycl.hpp index a58439c72d..0576e33382 100644 --- a/benchmarks/gemm/gemm_configuration_sycl.hpp +++ b/benchmarks/gemm/gemm_configuration_sycl.hpp @@ -66,6 +66,19 @@ struct GemmConfiguration { static_assert(sizeof(ElementA) == 0, "No valid GemmConfiguration configuration exists."); }; +template< + class ArchTag, + class ElementA, class LayoutA, + class ElementB, class LayoutB, class ElementC, typename LayoutC, + class ElementScale, typename StrideS, + class ElementZero, typename StrideZ, + class TileShape, Scheduler TileScheduler, + class TiledMma, class GmemTiledCopyA, class GmemTiledCopyB, + class GmemTiledCopyC, class EpilogueOp, int Stages = 3> +struct MixedPrecisionGemmConfiguration{ + static_assert(sizeof(ElementA) == 0, "No valid MixedPrecisionGemmConfiguration configuration exists."); +}; + ///////////////////////////////////////////////////////////////////////// // bfloat16 @@ -159,4 +172,80 @@ struct GemmConfiguration< } }; +template +struct MixedPrecisionGemmConfiguration< + arch::IntelXe, + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementScale, StrideS, + ElementZero, StrideZ, + TileShape, TileScheduler, TiledMma, + GmemTiledCopyA, GmemTiledCopyB, + GmemTiledCopyC, EpilogueOp, Stages> +{ + using LayoutD = LayoutC; + + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16MixedPrecision; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + using ElementAccumulator = typename TiledMma::ValTypeD; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementC, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + 
XE_2D_U32x8x16_LD_N, + void, void, + GmemTiledCopyC, + void, void>; + + static constexpr bool IsAQuant = cutlass::platform::numeric_limits::is_integer + ^ cutlass::platform::numeric_limits::is_integer; + static constexpr bool IsBQuant = cutlass::platform::numeric_limits::is_integer + ^ cutlass::platform::numeric_limits::is_integer; + + using CollectiveMainloop = collective::CollectiveMma, ElementA>, + cutlass::gemm::TagToStrideA_t, + cute::conditional_t, ElementB>, + cutlass::gemm::TagToStrideB_t, TiledMma, + GmemTiledCopyA, void, void, cute::identity, GmemTiledCopyB, void, void, + cute::identity>; + + using GemmKernel = kernel::GemmUniversal, CollectiveMainloop, CollectiveEpilogue>; + + using Gemm = device::GemmUniversalAdapter; + + constexpr static typename GemmKernel::Arguments defaultArguments() { + using StreamKMode = + cutlass::gemm::kernel::detail::PersistentTileSchedulerXeStreamKParams::DecompositionMode; + if constexpr (TileScheduler == Scheduler::Gemm) { + return {}; + } else if constexpr (TileScheduler == Scheduler::GemmStreamK) { + typename GemmKernel::Arguments arguments{}; + arguments.scheduler = {1, StreamKMode::StreamK}; + return arguments; + } else { + static_assert(TileScheduler == Scheduler::GemmSplitK); + typename GemmKernel::Arguments arguments{}; + arguments.scheduler = {2, StreamKMode::SplitK}; + return arguments; + } + } +}; + } // namespace cutlass::gemm::device diff --git a/examples/common/sycl_common.hpp b/examples/common/sycl_common.hpp index 8f441408a5..9593e37d27 100644 --- a/examples/common/sycl_common.hpp +++ b/examples/common/sycl_common.hpp @@ -34,117 +34,7 @@ #include "cutlass/cutlass.h" #include "cutlass/util/device_memory.h" #include "cutlass/util/reference/device/sycl_tensor_fill.h" - -template -static constexpr auto is_signed_v = cute::is_signed::value; - -template -static constexpr auto digits = std::numeric_limits::digits > 0 ? std::numeric_limits::digits : cute::numeric_limits::digits; - -template -auto max_for_test = T(cute::sizeof_bits_v >= 8 ? 1 << cute::ceil_div(digits , 4) : cutlass::platform::numeric_limits::max() / 2); - -/// Helper to initialize a block of device data -template -bool initialize_block(Element* block, std::size_t size, uint64_t seed, Args_t&&... args) { - - static_assert(sizeof...(Args_t) == 0 || sizeof...(Args_t) == 2); - - Element scope_max; - Element scope_min; - - if constexpr ( sizeof...(Args_t) == 2) { - auto tuple_args = std::forward_as_tuple(std::forward(args)...); - scope_min = std::get<0>(tuple_args); - scope_max = std::get<1>(tuple_args); - } else { - scope_max = max_for_test; - scope_min = is_signed_v ? Element(-scope_max) : Element(1); - } - - cutlass::reference::device::BlockFillRandomUniform( - block, size, seed, scope_max, scope_min, 0); - - syclcompat::wait(); - return true; -} - -template -bool initialize_block(cutlass::DeviceAllocation& block, uint64_t seed, Args_t&&... args) { - return initialize_block(block.get(), block.size(), seed, args...); -} - -template -void initialize_mixed_dtype_block(cutlass::DeviceAllocation& block_device, - cutlass::DeviceAllocation& block_device_dq, - uint64_t seed, - Args_t&&... 
args) { - static_assert(cute::sizeof_bits_v >= 8); - static_assert(sizeof...(Args_t) == 0 || sizeof...(Args_t) == 2); - - T1 scope_max; - T1 scope_min; - - if constexpr ( sizeof...(Args_t) == 2) { - auto tuple_args = std::forward_as_tuple(std::forward(args)...); - scope_min = std::get<0>(tuple_args); - scope_max = std::get<1>(tuple_args); - } else { - scope_max = max_for_test; - scope_min = is_signed_v ? T1(-scope_max) : T1(1); - } - - std::uniform_int_distribution<> dist(scope_min, scope_max); - - std::ranlux24_base rng(std::random_device{}()); - rng.seed(seed); - - if constexpr (cute::sizeof_bits_v >= 8) { - auto block_host = std::vector(block_device.size()); - auto block_host_dq = std::vector(block_device.size()); - for (int i = 0; i < block_host.size(); ++i) { - block_host[i] = static_cast(dist(rng)); - block_host_dq[i] = static_cast(block_host[i]); - } - - block_device.copy_from_host(block_host.data()); - block_device_dq.copy_from_host(block_host_dq.data()); - } else { - static constexpr auto array_size = 1024; - - cute::array_subbyte block_host{}; - auto block_host_dq = std::vector(array_size); - - for (int i = 0; i < block_host.size(); ++i) { - block_host[i] = static_cast(dist(rng)); - block_host_dq[i] = static_cast(block_host[i].get()); - } - - static constexpr auto elements_per_byte = cute::sizeof_bits_v / cute::sizeof_bits_v; - - int loop_cnt = block_device.size() / array_size; - for (int i = 0; i < loop_cnt; i++) { - cutlass::device_memory::copy_to_device(((uint8_t*)(block_device.get())) + (i * array_size) / elements_per_byte, - (uint8_t*)(raw_pointer_cast(block_host.begin())), - array_size / elements_per_byte); - cutlass::device_memory::copy_to_device(block_device_dq.get() + i * array_size, - block_host_dq.data(), - array_size); - } - - auto tail_size = block_device.size() % array_size; - if (tail_size) { - cutlass::device_memory::copy_to_device(((uint8_t*)block_device.get()) + (loop_cnt * array_size) / elements_per_byte, - (uint8_t*)(raw_pointer_cast(block_host.begin())), - tail_size / elements_per_byte); - cutlass::device_memory::copy_to_device(block_device_dq.get() + loop_cnt * array_size, - block_host_dq.data(), - tail_size); - } - } - - syclcompat::wait(); -} +#include "cutlass/util/mixed_dtype_utils.hpp" template inline diff --git a/examples/sycl/00_bmg_gemm/00_bmg_gemm_with_sycl_queue.cpp b/examples/sycl/00_bmg_gemm/00_bmg_gemm_with_sycl_queue.cpp index 093dabaf8e..a9c3f246b8 100644 --- a/examples/sycl/00_bmg_gemm/00_bmg_gemm_with_sycl_queue.cpp +++ b/examples/sycl/00_bmg_gemm/00_bmg_gemm_with_sycl_queue.cpp @@ -235,9 +235,9 @@ struct ExampleRunner { stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); - initialize_block(mem.block_A, M * K * L, seed + 2023); - initialize_block(mem.block_B, N * K * L, seed + 2022); - initialize_block(mem.block_C, M * N * L, seed + 2021); + cutlass::initialize_block(mem.block_A, M * K * L, seed + 2023); + cutlass::initialize_block(mem.block_B, N * K * L, seed + 2022); + cutlass::initialize_block(mem.block_C, M * N * L, seed + 2021); } cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index de6e524a73..003f5de776 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -343,6 +343,9 @@ class 
CollectiveEpilogue< // Tile the output tensor per SG and select tile for the current SG Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(m_sg,n_sg)); // (SG_M,SG_N) + auto thread_xe_load_c = params.xe_load_c.get_thread_slice(thread_idx); + Tensor tCgC = thread_xe_load_c.partition_S(gD); + auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx); Tensor tCgD = thread_xe_store_d.partition_D(gD); @@ -404,8 +407,7 @@ class CollectiveEpilogue< cst_callbacks.begin_loop(epi_m, epi_n); if (is_C_load_needed) { - //cordinates for C and D are the same - copy(params.xe_load_c, tCgD(_, epi_m, epi_n), trC); + copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC); } cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); diff --git a/tools/util/include/cutlass/util/mixed_dtype_utils.hpp b/tools/util/include/cutlass/util/mixed_dtype_utils.hpp index 3b0f884b47..86ea19915a 100644 --- a/tools/util/include/cutlass/util/mixed_dtype_utils.hpp +++ b/tools/util/include/cutlass/util/mixed_dtype_utils.hpp @@ -513,6 +513,147 @@ void reorder_tensor( cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast(size(layout_src))); } + +template +static constexpr auto is_signed_v = cute::is_signed::value; + +template +static constexpr auto digits = std::numeric_limits::digits > 0 ? std::numeric_limits::digits : cute::numeric_limits::digits; + +template +auto max_for_test = T(cute::sizeof_bits_v >= 8 ? 1 << cute::ceil_div(digits , 4) : cutlass::platform::numeric_limits::max() / 2); + +/// Helper to initialize a block of device data +template +bool initialize_block(Element* block, std::size_t size, uint64_t seed, Args_t&&... args) { + static_assert(sizeof...(Args_t) == 0 || sizeof...(Args_t) == 2); + + Element scope_max; + Element scope_min; + + if constexpr ( sizeof...(Args_t) == 2) { + auto tuple_args = std::forward_as_tuple(std::forward(args)...); + scope_min = std::get<0>(tuple_args); + scope_max = std::get<1>(tuple_args); + } else { + scope_max = max_for_test; + scope_min = is_signed_v ? Element(-scope_max) : Element(1); + } + + if constexpr (cute::sizeof_bits_v >= 8) { + cutlass::reference::device::BlockFillRandomUniform(block, size, seed, scope_max, scope_min, 0); + } else { + std::uniform_int_distribution<> dist(scope_min, scope_max); + + std::ranlux24_base rng(std::random_device{}()); + rng.seed(seed); + + static constexpr auto array_size = 1024; + + cute::array_subbyte block_host{}; + + for (int i = 0; i < block_host.size(); ++i) { + block_host[i] = static_cast(dist(rng)); + } + + static constexpr auto elements_per_byte = cute::sizeof_bits_v / cute::sizeof_bits_v; + + int loop_cnt = size / array_size; + for (int i = 0; i < loop_cnt; i++) { + cutlass::device_memory::copy_to_device(((uint8_t*)(block)) + (i * array_size) / elements_per_byte, + (uint8_t*)(raw_pointer_cast(block_host.begin())), + array_size / elements_per_byte); + } + + auto tail_size = size % array_size; + if (tail_size) { + cutlass::device_memory::copy_to_device(((uint8_t*)block) + (loop_cnt * array_size) / elements_per_byte, + (uint8_t*)(raw_pointer_cast(block_host.begin())), + tail_size / elements_per_byte); + } + } + + syclcompat::wait(); + return true; +} + +template +bool initialize_block(cutlass::DeviceAllocation& block, uint64_t seed, Args_t&&... args) { + return initialize_block(block.get(), block.size(), seed, args...); +} + +template +void initialize_mixed_dtype_block(cutlass::DeviceAllocation& block_device, + cutlass::DeviceAllocation& block_device_dq, + uint64_t seed, + Args_t&&... 
args) { + static_assert(cute::sizeof_bits_v >= 8); + static_assert(sizeof...(Args_t) == 0 || sizeof...(Args_t) == 2); + + T1 scope_max; + T1 scope_min; + + if constexpr ( sizeof...(Args_t) == 2) { + auto tuple_args = std::forward_as_tuple(std::forward(args)...); + scope_min = std::get<0>(tuple_args); + scope_max = std::get<1>(tuple_args); + } else { + scope_max = max_for_test; + scope_min = is_signed_v ? T1(-scope_max) : T1(1); + } + + std::uniform_int_distribution<> dist(scope_min, scope_max); + + std::ranlux24_base rng(std::random_device{}()); + rng.seed(seed); + + if constexpr (cute::sizeof_bits_v >= 8) { + auto block_host = std::vector(block_device.size()); + auto block_host_dq = std::vector(block_device.size()); + for (int i = 0; i < block_host.size(); ++i) { + block_host[i] = static_cast(dist(rng)); + block_host_dq[i] = static_cast(block_host[i]); + } + + block_device.copy_from_host(block_host.data()); + block_device_dq.copy_from_host(block_host_dq.data()); + } else { + static constexpr auto array_size = 1024; + + cute::array_subbyte block_host{}; + auto block_host_dq = std::vector(array_size); + + for (int i = 0; i < block_host.size(); ++i) { + block_host[i] = static_cast(dist(rng)); + block_host_dq[i] = static_cast(block_host[i].get()); + } + + static constexpr auto elements_per_byte = cute::sizeof_bits_v / cute::sizeof_bits_v; + + int loop_cnt = block_device.size() / array_size; + for (int i = 0; i < loop_cnt; i++) { + cutlass::device_memory::copy_to_device(((uint8_t*)(block_device.get())) + (i * array_size) / elements_per_byte, + (uint8_t*)(raw_pointer_cast(block_host.begin())), + array_size / elements_per_byte); + cutlass::device_memory::copy_to_device(block_device_dq.get() + i * array_size, + block_host_dq.data(), + array_size); + } + + auto tail_size = block_device.size() % array_size; + if (tail_size) { + cutlass::device_memory::copy_to_device(((uint8_t*)block_device.get()) + (loop_cnt * array_size) / elements_per_byte, + (uint8_t*)(raw_pointer_cast(block_host.begin())), + tail_size / elements_per_byte); + cutlass::device_memory::copy_to_device(block_device_dq.get() + loop_cnt * array_size, + block_host_dq.data(), + tail_size); + } + } + + syclcompat::wait(); +} + #undef CUDA_CHECK } // namespace cutlass From 07e1d2b53887e3de078ad3bdc2fe33be4631911a Mon Sep 17 00:00:00 2001 From: taozha2 Date: Thu, 24 Jul 2025 11:43:46 +0800 Subject: [PATCH 2/9] update --- benchmarks/gemm/benchmark_runner.hpp | 2 ++ tools/util/include/cutlass/util/mixed_dtype_utils.hpp | 2 ++ 2 files changed, 4 insertions(+) diff --git a/benchmarks/gemm/benchmark_runner.hpp b/benchmarks/gemm/benchmark_runner.hpp index 9cb1c2a5b9..c8b9cb9de7 100644 --- a/benchmarks/gemm/benchmark_runner.hpp +++ b/benchmarks/gemm/benchmark_runner.hpp @@ -61,11 +61,13 @@ namespace cutlass::benchmark { /////////////////////////////////////////////////////////////////////////////////////////////////// +#if defined(CUTLASS_ENABLE_SYCL) template static constexpr auto is_mixed_dtype = false; template static constexpr auto is_mixed_dtype> = true; +#endif template struct ScaleType { diff --git a/tools/util/include/cutlass/util/mixed_dtype_utils.hpp b/tools/util/include/cutlass/util/mixed_dtype_utils.hpp index 86ea19915a..5082181e4a 100644 --- a/tools/util/include/cutlass/util/mixed_dtype_utils.hpp +++ b/tools/util/include/cutlass/util/mixed_dtype_utils.hpp @@ -513,6 +513,7 @@ void reorder_tensor( cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast(size(layout_src))); } +#ifdef CUTLASS_ENABLE_SYCL template static 
constexpr auto is_signed_v = cute::is_signed::value; @@ -653,6 +654,7 @@ void initialize_mixed_dtype_block(cutlass::DeviceAllocation& block_device, syclcompat::wait(); } +#endif #undef CUDA_CHECK From 408af57d50b9c5ac6f519ce88918a7a0d3b36e35 Mon Sep 17 00:00:00 2001 From: taozha2 Date: Thu, 24 Jul 2025 13:12:15 +0800 Subject: [PATCH 3/9] Update benchmark_runner.hpp --- benchmarks/gemm/benchmark_runner.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/gemm/benchmark_runner.hpp b/benchmarks/gemm/benchmark_runner.hpp index c8b9cb9de7..25a706bf5d 100644 --- a/benchmarks/gemm/benchmark_runner.hpp +++ b/benchmarks/gemm/benchmark_runner.hpp @@ -61,7 +61,7 @@ namespace cutlass::benchmark { /////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTLASS_ENABLE_SYCL) +#if defined(SYCL_INTEL_TARGET) template static constexpr auto is_mixed_dtype = false; From 34e205909058d2f4eeb76d292cbf7e235ebf8ecf Mon Sep 17 00:00:00 2001 From: taozha2 Date: Thu, 24 Jul 2025 13:13:12 +0800 Subject: [PATCH 4/9] Update mixed_dtype_utils.hpp --- tools/util/include/cutlass/util/mixed_dtype_utils.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/util/include/cutlass/util/mixed_dtype_utils.hpp b/tools/util/include/cutlass/util/mixed_dtype_utils.hpp index 5082181e4a..5c1ce88302 100644 --- a/tools/util/include/cutlass/util/mixed_dtype_utils.hpp +++ b/tools/util/include/cutlass/util/mixed_dtype_utils.hpp @@ -513,7 +513,7 @@ void reorder_tensor( cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast(size(layout_src))); } -#ifdef CUTLASS_ENABLE_SYCL +#if defined(SYCL_INTEL_TARGET) template static constexpr auto is_signed_v = cute::is_signed::value; From cf38651cc070a7c82dd096ab1ee0c9c604b08565 Mon Sep 17 00:00:00 2001 From: taozha2 Date: Thu, 24 Jul 2025 13:40:05 +0800 Subject: [PATCH 5/9] update --- benchmarks/gemm/benchmark_runner.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/benchmarks/gemm/benchmark_runner.hpp b/benchmarks/gemm/benchmark_runner.hpp index 25a706bf5d..1d5e29a843 100644 --- a/benchmarks/gemm/benchmark_runner.hpp +++ b/benchmarks/gemm/benchmark_runner.hpp @@ -67,6 +67,9 @@ static constexpr auto is_mixed_dtype = false; template static constexpr auto is_mixed_dtype> = true; +#else +template +static constexpr auto is_mixed_dtype = false; #endif template From b3da8393ec4701cbc26d2deb73aa2de1773c0c4a Mon Sep 17 00:00:00 2001 From: taozha2 Date: Thu, 24 Jul 2025 14:10:15 +0800 Subject: [PATCH 6/9] update --- benchmarks/gemm/benchmark_runner.hpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/benchmarks/gemm/benchmark_runner.hpp b/benchmarks/gemm/benchmark_runner.hpp index 1d5e29a843..96ebaabbe7 100644 --- a/benchmarks/gemm/benchmark_runner.hpp +++ b/benchmarks/gemm/benchmark_runner.hpp @@ -614,9 +614,14 @@ struct BenchmarkRunnerGemm { block_A[i].reset(size_A); block_B[i].reset(size_B); block_C[i].reset(size_C); - if (is_mixed_dtype && i == 0) { - initialize_mixed_dtype_block(block_A[i], block_A_verify, seed + i); - initialize_mixed_dtype_block(block_B[i], block_B_verify, seed + i); + if constexpr (is_mixed_dtype) { + if (i == 0) { + initialize_mixed_dtype_block(block_A[i], block_A_verify, seed + i); + initialize_mixed_dtype_block(block_B[i], block_B_verify, seed + i); + } else { + initialize_block(block_A[i], seed + i); + initialize_block(block_B[i], seed + i); + } } else { initialize_block(block_A[i], seed + i); initialize_block(block_B[i], 
seed + i); From 90123849d813f3d032640bad64f97f6edbdc715e Mon Sep 17 00:00:00 2001 From: taozha2 Date: Thu, 24 Jul 2025 14:40:34 +0800 Subject: [PATCH 7/9] Update mixed_dtype_utils.hpp --- tools/util/include/cutlass/util/mixed_dtype_utils.hpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/tools/util/include/cutlass/util/mixed_dtype_utils.hpp b/tools/util/include/cutlass/util/mixed_dtype_utils.hpp index 5c1ce88302..a7f49d604a 100644 --- a/tools/util/include/cutlass/util/mixed_dtype_utils.hpp +++ b/tools/util/include/cutlass/util/mixed_dtype_utils.hpp @@ -513,8 +513,6 @@ void reorder_tensor( cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast(size(layout_src))); } -#if defined(SYCL_INTEL_TARGET) - template static constexpr auto is_signed_v = cute::is_signed::value; @@ -654,7 +652,6 @@ void initialize_mixed_dtype_block(cutlass::DeviceAllocation& block_device, syclcompat::wait(); } -#endif #undef CUDA_CHECK From 22a90dc9ebb31e9a9e0474372cb340a9ae76d6f0 Mon Sep 17 00:00:00 2001 From: taozha2 Date: Thu, 24 Jul 2025 15:11:31 +0800 Subject: [PATCH 8/9] update --- tools/util/include/cutlass/util/mixed_dtype_utils.hpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tools/util/include/cutlass/util/mixed_dtype_utils.hpp b/tools/util/include/cutlass/util/mixed_dtype_utils.hpp index a7f49d604a..2b3018bcd4 100644 --- a/tools/util/include/cutlass/util/mixed_dtype_utils.hpp +++ b/tools/util/include/cutlass/util/mixed_dtype_utils.hpp @@ -519,8 +519,11 @@ static constexpr auto is_signed_v = cute::is_signed::value; template static constexpr auto digits = std::numeric_limits::digits > 0 ? std::numeric_limits::digits : cute::numeric_limits::digits; +template +auto max_for_test = T(1 << cute::ceil_div(digits , 4)); + template -auto max_for_test = T(cute::sizeof_bits_v >= 8 ? 1 << cute::ceil_div(digits , 4) : cutlass::platform::numeric_limits::max() / 2); +auto max_for_test < 8>> = T(cutlass::platform::numeric_limits::max() / 2); /// Helper to initialize a block of device data template From b894b6dfb5aadbfcdcbd5e57248fd534ee0a8dc3 Mon Sep 17 00:00:00 2001 From: taozha2 Date: Mon, 28 Jul 2025 10:04:53 +0800 Subject: [PATCH 9/9] Update input_sglang_gemm_mixed_dtype.in --- .../device/bmg/input_files/input_sglang_gemm_mixed_dtype.in | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/device/bmg/input_files/input_sglang_gemm_mixed_dtype.in b/benchmarks/device/bmg/input_files/input_sglang_gemm_mixed_dtype.in index 4177851cd3..559660a569 100755 --- a/benchmarks/device/bmg/input_files/input_sglang_gemm_mixed_dtype.in +++ b/benchmarks/device/bmg/input_files/input_sglang_gemm_mixed_dtype.in @@ -14,5 +14,5 @@ PvcMixedPrecisionGemmBF16U4BF16S8BF16S4_RCR_1 --bm_name=mixed_dtype_int4 --m=32 PvcMixedPrecisionGemmBF16U4S8S8BF16S4_RCR_1 --bm_name=mixed_dtype_int4 --m=32 --k=4096 --n=14336 --l=1 # int8 -PvcMixedPrecisionGemmBF16S8BF16S8BF16S8_RCR_1 --bm_name=mixed_dtype_int4 --m=32 --k=4096 --n=14336 --l=1 -PvcMixedPrecisionGemmFP16S8FP16S8FP16S8_RCR_1 --bm_name=mixed_dtype_int4 --m=32 --k=4096 --n=14336 --l=1 +PvcMixedPrecisionGemmBF16S8BF16S8BF16S8_RCR_1 --bm_name=mixed_dtype_int8 --m=32 --k=4096 --n=14336 --l=1 +PvcMixedPrecisionGemmFP16S8FP16S8FP16S8_RCR_1 --bm_name=mixed_dtype_int8 --m=32 --k=4096 --n=14336 --l=1
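For reference, the group-wise scheme that the new verification path implements (dequantize_A / dequantize_B in PATCH 1/9, with group_size fixed at 128 by the callers) reduces to two per-group formulas: for a float Mma type the quantized operand is widened as dst = (src - zero) * scale, and for an integer Mma type the wide operand is narrowed as dst = (int)(src / scale) + zero. A minimal standalone sketch of that math, using plain arrays and illustrative sizes in place of the patch's CuTe tensors and layouts:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  constexpr int K = 8, group_size = 4;                   // two quantization groups along K
  std::vector<int8_t> q     = {1, 3, 5, 7, 2, 4, 6, 0};  // quantized narrow operand (one column)
  std::vector<float>  scale = {0.5f, 0.25f};             // one scale per K-group
  std::vector<int8_t> zero  = {2, 1};                    // one zero-point per K-group

  // Float Mma path (as in dequantize_B): widen the quantized operand.
  std::vector<float> dq(K);
  for (int k = 0; k < K; ++k) {
    int g = k / group_size;
    dq[k] = float(q[k] - zero[g]) * scale[g];            // (src - zero) * scale
  }

  // Integer Mma path (the is_qnt branch of dequantize_A): narrow the wide operand.
  std::vector<float>  wide = {0.5f, 1.0f, -0.25f, 2.0f, 0.75f, -1.5f, 0.0f, 1.25f};
  std::vector<int8_t> qn(K);
  for (int k = 0; k < K; ++k) {
    int g = k / group_size;
    qn[k] = int8_t(int(wide[k] / scale[g]) + zero[g]);   // (int)(src / scale) + zero
  }

  for (int k = 0; k < K; ++k)
    printf("dq[%d] = %.3f   qn[%d] = %d\n", k, dq[k], k, int(qn[k]));
}

In the patch itself these loops run per (mn, k, l) coordinate over CuTe tensors built from stride_S / stride_Z, with subbyte_iterator handling the sub-byte int4 and packed-zero cases.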