Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions benchmarks/gemm/benchmark_runner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,12 @@ struct BenchmarkRunnerGemm {

using CollectiveMainloop = typename Gemm::GemmKernel::CollectiveMainloop;
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
using ElementMma = CollectiveMainloop::TiledMma::ValTypeA;
using ElementMma = typename CollectiveMainloop::TiledMma::ValTypeA;

using ElementScale = ScaleType<CollectiveMainloop>::type;
using ElementZero = ZeroType<CollectiveMainloop>::type;
using StrideS = ScaleStride<CollectiveMainloop>::type;
using StrideZ = ZeroStride<CollectiveMainloop>::type;
using ElementScale = typename ScaleType<CollectiveMainloop>::type;
using ElementZero = typename ZeroType<CollectiveMainloop>::type;
using StrideS = typename ScaleStride<CollectiveMainloop>::type;
using StrideZ = typename ZeroStride<CollectiveMainloop>::type;

using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
using ElementC = typename Gemm::ElementC;
Expand Down Expand Up @@ -453,7 +453,10 @@ struct BenchmarkRunnerGemm {
}

bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) {
auto [M, N, K, L] = problem_size;
auto& M = cute::get<0>(problem_size);
auto& N = cute::get<1>(problem_size);
auto& K = cute::get<2>(problem_size);
auto& L = cute::get<3>(problem_size);

TensorRef ref_C(block_C[0].get(), LayoutC::packed({M, N}));
TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N}));
Expand Down
15 changes: 9 additions & 6 deletions cmake/FindDPCPP.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ add_library(DPCPP::DPCPP INTERFACE IMPORTED)

set(DPCPP_FLAGS "-fsycl;")
set(DPCPP_COMPILE_ONLY_FLAGS "")
set(DPCPP_LINK_ONLY_FLAGS "")

if(NOT "${DPCPP_SYCL_TARGET}" STREQUAL "")
list(APPEND DPCPP_FLAGS "-fsycl-targets=${DPCPP_SYCL_TARGET};")
Expand All @@ -63,10 +64,10 @@ if("${DPCPP_SYCL_TARGET}" STREQUAL "intel_gpu_pvc" OR
"${DPCPP_SYCL_TARGET}" STREQUAL "spir64" OR
"${DPCPP_SYCL_TARGET}" STREQUAL "intel_gpu_bmg_g21")
if ((CMAKE_CXX_COMPILER_ID MATCHES "IntelLLVM" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 2025.2) OR CUTLASS_SYCL_BUILTIN_ENABLE)
list(APPEND DPCPP_FLAGS "-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier")
list(APPEND DPCPP_LINK_ONLY_FLAGS "-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier")
else()
list(APPEND DPCPP_FLAGS "-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier,+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate")
endif()
list(APPEND DPCPP_LINK_ONLY_FLAGS "-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier,+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate")
endif()
if(DPCPP_DISABLE_ITT_FOR_CUTLASS)
list(APPEND DPCPP_FLAGS "-fno-sycl-instrument-device-code")
endif()
Expand All @@ -76,14 +77,16 @@ endif()
if(UNIX)
set_target_properties(DPCPP::DPCPP PROPERTIES
INTERFACE_COMPILE_OPTIONS "${DPCPP_FLAGS};${DPCPP_COMPILE_ONLY_FLAGS}"
INTERFACE_LINK_OPTIONS "${DPCPP_FLAGS}"
INTERFACE_LINK_OPTIONS "${DPCPP_FLAGS};${DPCPP_LINK_ONLY_FLAGS}"
INTERFACE_LINK_LIBRARIES ${DPCPP_LIB_DIR}
INTERFACE_INCLUDE_DIRECTORIES "${DPCPP_BIN_DIR}/../include/sycl;${DPCPP_BIN_DIR}/../include")
message(STATUS "DPCPP INCLUDE DIR: ${DPCPP_BIN_DIR}/../include/sycl;${DPCPP_BIN_DIR}/../include")
message(STATUS "Using DPCPP flags: ${DPCPP_FLAGS};${DPCPP_COMPILE_ONLY_FLAGS}")
message(STATUS "Using DPCPP compile flags: ${DPCPP_FLAGS};${DPCPP_COMPILE_ONLY_FLAGS}")
message(STATUS "Using DPCPP link flags: ${DPCPP_FLAGS};${DPCPP_LINK_ONLY_FLAGS}")
else()
set_target_properties(DPCPP::DPCPP PROPERTIES
INTERFACE_COMPILE_OPTIONS "${DPCPP_FLAGS};${DPCPP_COMPILE_ONLY_FLAGS}"
INTERFACE_LINK_OPTIONS "${DPCPP_FLAGS};${DPCPP_LINK_ONLY_FLAGS}"
INTERFACE_LINK_LIBRARIES ${DPCPP_LIB_DIR}
INTERFACE_INCLUDE_DIRECTORIES "${DPCPP_BIN_DIR}/../include/sycl")
endif()
Expand All @@ -105,7 +108,7 @@ function(add_sycl_to_target)
)
get_target_property(target_type ${CUTLASS_ADD_SYCL_TARGET} TYPE)
if (NOT target_type STREQUAL "OBJECT_LIBRARY")
target_link_options(${CUTLASS_ADD_SYCL_TARGET} PUBLIC ${DPCPP_FLAGS})
target_link_options(${CUTLASS_ADD_SYCL_TARGET} PUBLIC ${DPCPP_FLAGS} ${DPCPP_LINK_ONLY_FLAGS})
endif()
endfunction()

Expand Down
11 changes: 10 additions & 1 deletion cmake/googletest.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ FetchContent_Declare(

FetchContent_MakeAvailable(googletest)

if (CMAKE_CXX_COMPILER_ID STREQUAL "IntelLLVM")
if (TARGET gtest)
# Ignore unsupported warning flags on IntelLLVM
target_compile_options(gtest PRIVATE -Wno-unknown-warning-option)
# Show -Winline warnings, but don’t let them become errors
target_compile_options(gtest PRIVATE -Wno-error=inline)
endif()
endif()

if (MSVC)
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
endif()
endif()
6 changes: 5 additions & 1 deletion examples/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_f16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,11 @@ struct ExampleRunner {

/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(Options const& options) {
auto [M, N, K, L] = ProblemShapeType{options.m, options.n, options.k, options.l};
auto problem_shape = ProblemShapeType{options.m, options.n, options.k, options.l};
auto& M = cute::get<0>(problem_shape);
auto& N = cute::get<1>(problem_shape);
auto& K = cute::get<2>(problem_shape);
auto& L = cute::get<3>(problem_shape);

auto zero_elements_packed_along_k = get<0>(StrideZero{});
const int scale_k = cute::ceil_div(options.k, options.g);
Expand Down
4 changes: 2 additions & 2 deletions include/cutlass/gemm/collective/xe_mma_mixed_input.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ struct CollectiveMma<
}();

if constexpr (ModeScale) {
return Params{tiled_copy_a, tiled_copy_b, tiled_copy_scale, {}, args.group_size};
return Params{tiled_copy_a, tiled_copy_b, {tiled_copy_scale}, {}, args.group_size};
} else {
auto ptr_Z = [&]() {
if constexpr (sizeof_bits_v<NonVoidElementZero> < 8) {
Expand All @@ -353,7 +353,7 @@ struct CollectiveMma<
}
}();

return Params{tiled_copy_a, tiled_copy_b, tiled_copy_scale, tiled_copy_zero, args.group_size};
return Params{tiled_copy_a, tiled_copy_b, {tiled_copy_scale}, {tiled_copy_zero}, args.group_size};
}
}
}
Expand Down
1 change: 1 addition & 0 deletions test/unit/cute/intel_xe/mma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ void gemm_device(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n,

#define CUTLASS_ENABLE_DEBUG_PRINTS (0)

#undef LOG_THREAD
#define LOG_THREAD (16)

#if CUTLASS_ENABLE_DEBUG_PRINTS
Expand Down
1 change: 1 addition & 0 deletions test/unit/cute/intel_xe/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ using namespace syclcompat::experimental;

#define CUTLASS_ENABLE_DEBUG_PRINTS (0)
#define LOG_GROUP (0)
#undef LOG_THREAD
#define LOG_THREAD (0)

template <class atype, class btype, class ctype>
Expand Down
Loading