From 19c68baffcb2f4ba84c88c18947dddc304a3d1ec Mon Sep 17 00:00:00 2001 From: "Sidorov, Dmitry" Date: Wed, 25 Sep 2024 08:42:23 -0700 Subject: [PATCH 1/2] [SYCL][Matrix] Fix bfloat16 component type matrix muladd Signed-off-by: Sidorov, Dmitry --- sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index d3d57f24c56e6..7b967cb8a050c 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -530,8 +530,8 @@ joint_matrix_mad( else D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm); #else - if constexpr (std::is_same::value && - std::is_same::value && + if constexpr (std::is_same::value && + std::is_same::value && std::is_same::value) { constexpr uint32_t MatrixOperand = static_cast( __spv::MatrixOperands::MatrixAAndBBFloat16ComponentsINTEL); From a0488cbe112e9168ac64412e882e329f9d40ed9e Mon Sep 17 00:00:00 2001 From: "Sidorov, Dmitry" Date: Thu, 26 Sep 2024 03:25:27 -0700 Subject: [PATCH 2/2] Add function for matrix operands Signed-off-by: Sidorov, Dmitry --- .../oneapi/matrix/matrix-unified-utils.hpp | 24 +++++++++++++ .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 34 +++---------------- 2 files changed, 28 insertions(+), 30 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp index 3c8fef515fae0..0510d71b0b564 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp @@ -82,6 +82,30 @@ inline __SYCL_ALWAYS_INLINE __spv::MatrixLayout joint_matrix_layout_to_spv( } } +#ifdef __SPIRV_USE_COOPERATIVE_MATRIX +template +constexpr uint32_t CalculateMatrixOperand() { + if constexpr (std::is_same::value && + std::is_same::value && + std::is_same::value) + return static_cast( + __spv::MatrixOperands::MatrixAAndBBFloat16ComponentsINTEL); + if constexpr (std::is_signed::value && std::is_unsigned::value) + return static_cast( + __spv::MatrixOperands::MatrixASignedComponentsKHR); + if constexpr (std::is_unsigned::value && std::is_signed::value) + return static_cast( + __spv::MatrixOperands::MatrixBSignedComponentsKHR); + if constexpr (std::is_signed::value && std::is_signed::value) { + return static_cast( + __spv::MatrixOperands::MatrixASignedComponentsKHR) + + static_cast( + __spv::MatrixOperands::MatrixBSignedComponentsKHR); + } + return 0; +} +#endif // __SPIRV_USE_COOPERATIVE_MATRIX + } // namespace detail } // namespace _V1 } // namespace sycl diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 7b967cb8a050c..c8d2918b6b105 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -530,36 +530,10 @@ joint_matrix_mad( else D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm); #else - if constexpr (std::is_same::value && - std::is_same::value && - std::is_same::value) { - constexpr uint32_t MatrixOperand = static_cast( - __spv::MatrixOperands::MatrixAAndBBFloat16ComponentsINTEL); - D.spvm = __spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm, - MatrixOperand); - } else if constexpr (std::is_signed::value && - std::is_unsigned::value) { - constexpr uint32_t MatrixOperand = static_cast( - __spv::MatrixOperands::MatrixASignedComponentsKHR); - D.spvm = __spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm, - MatrixOperand); - } else if constexpr (std::is_unsigned::value && - std::is_signed::value) { - constexpr uint32_t MatrixOperand = static_cast( - __spv::MatrixOperands::MatrixBSignedComponentsKHR); - D.spvm = __spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm, - MatrixOperand); - } else if constexpr (std::is_signed::value && std::is_signed::value) { - constexpr uint32_t MatrixOperand = - static_cast( - __spv::MatrixOperands::MatrixASignedComponentsKHR) + - static_cast( - __spv::MatrixOperands::MatrixBSignedComponentsKHR); - D.spvm = __spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm, - MatrixOperand); - } else { - D.spvm = __spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm); - } + constexpr uint32_t MatrixOperand = + sycl::detail::CalculateMatrixOperand(); + D.spvm = __spirv_CooperativeMatrixMulAddKHR(A.spvm, B.spvm, C.spvm, + MatrixOperand); #endif // __SPIRV_USE_COOPERATIVE_MATRIX #endif // defined(__NVPTX__) #else