From 3c05d0cbaa97565b5f52f8977eda374e73749c98 Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Fri, 3 Nov 2023 16:46:41 +0000 Subject: [PATCH] - Replace for loop to copy results with std::memcpy, which showed slightly better performance. --- .../sycl/ext/oneapi/matrix/matrix-hip.hpp | 21 +++++++------------ 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp index 7f9f9b1219cf4..57a052d8fb9f3 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp @@ -344,15 +344,13 @@ void joint_matrix_mad_hip( *reinterpret_cast(&A.wi_marray), *reinterpret_cast(&B.wi_marray), *reinterpret_cast(&C.wi_marray), 0, 0, 0); - for (int i = 0; i < 4; ++i) - D.wi_marray[i] = result[i]; + std::memcpy(&D.wi_marray, &result, 4 * sizeof(float)); } else if constexpr (M == 32 && N == 32) { auto result = __builtin_amdgcn_mfma_f32_32x32x8f16( *reinterpret_cast(&A.wi_marray), *reinterpret_cast(&B.wi_marray), *reinterpret_cast(&C.wi_marray), 0, 0, 0); - for (int i = 0; i < 16; ++i) - D.wi_marray[i] = result[i]; + std::memcpy(&D.wi_marray, &result, 16 * sizeof(float)); } } else if constexpr (std::is_same_v) { if constexpr (M == 16 && N == 16) { @@ -360,23 +358,20 @@ void joint_matrix_mad_hip( *reinterpret_cast(&A.wi_marray), *reinterpret_cast(&B.wi_marray), *reinterpret_cast(&C.wi_marray), 0, 0, 0); - for (int i = 0; i < 4; ++i) - D.wi_marray[i] = result[i]; + std::memcpy(&D.wi_marray, &result, 4 * sizeof(float)); } else if constexpr (M == 32 && N == 32) { auto result = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k( *reinterpret_cast(&A.wi_marray), *reinterpret_cast(&B.wi_marray), *reinterpret_cast(&C.wi_marray), 0, 0, 0); - for (int i = 0; i < 16; ++i) - D.wi_marray[i] = result[i]; + std::memcpy(&D.wi_marray, &result, 16 * sizeof(float)); } } else if constexpr (std::is_same_v) { if constexpr (M == 16 && N == 16) { auto result = 
__builtin_amdgcn_mfma_f64_16x16x4f64( A.wi_marray[0], B.wi_marray[0], *reinterpret_cast(&C.wi_marray), 0, 0, 0); - for (int i = 0; i < 4; ++i) - D.wi_marray[i] = result[i]; + std::memcpy(&D.wi_marray, &result, 4 * sizeof(double)); } } else if constexpr (std::is_same_v) { if constexpr (M == 16 && N == 16) { @@ -384,15 +379,13 @@ void joint_matrix_mad_hip( *reinterpret_cast(&A.wi_marray), *reinterpret_cast(&B.wi_marray), *reinterpret_cast(&C.wi_marray), 0, 0, 0); - for (int i = 0; i < 4; ++i) - D.wi_marray[i] = result[i]; + std::memcpy(&D.wi_marray, &result, 4 * sizeof(int32_t)); } else if constexpr (M == 32 && N == 32) { auto result = __builtin_amdgcn_mfma_i32_32x32x8i8( *reinterpret_cast(&A.wi_marray), *reinterpret_cast(&B.wi_marray), *reinterpret_cast(&C.wi_marray), 0, 0, 0); - for (int i = 0; i < 16; ++i) - D.wi_marray[i] = result[i]; + std::memcpy(&D.wi_marray, &result, 16 * sizeof(int32_t)); } } }