From 87641e5af1f48f1cc25981a39223f36d7dc9f226 Mon Sep 17 00:00:00 2001 From: aacostadiaz Date: Thu, 4 Apr 2024 17:19:57 +0100 Subject: [PATCH] Remove the sgemm_nt_1_sycl PoC (#15) * Remove sgemm_nt_1 PoC * Fix build issues * Fix code style format * Remove ENABLE_NVPTX flag * Update include/cute/util/debug.hpp Co-authored-by: Mehdi Goli * Cosmetic --------- Co-authored-by: Mehdi Goli --- examples/cute/tutorial/CMakeLists.txt | 47 +- examples/cute/tutorial/sgemm_nt_1_sycl.cpp | 446 ------------------ include/cute/arch/copy_sm75.hpp | 7 +- include/cute/arch/copy_sm80.hpp | 12 +- include/cute/arch/mma_sm80.hpp | 123 ++--- include/cute/config.hpp | 9 - include/cute/util/debug.hpp | 14 +- include/cutlass/arch/arch.h | 6 +- include/cutlass/arch/memory.h | 48 +- include/cutlass/barrier.h | 2 - include/cutlass/detail/helper_macros.hpp | 12 - tools/util/include/cutlass/util/GPU_Clock.hpp | 4 + 12 files changed, 123 insertions(+), 607 deletions(-) delete mode 100644 examples/cute/tutorial/sgemm_nt_1_sycl.cpp diff --git a/examples/cute/tutorial/CMakeLists.txt b/examples/cute/tutorial/CMakeLists.txt index 2389d419d..c91d98109 100644 --- a/examples/cute/tutorial/CMakeLists.txt +++ b/examples/cute/tutorial/CMakeLists.txt @@ -27,34 +27,27 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -if (CUTLASS_ENABLE_SYCL) - cutlass_example_add_executable( - sgemm_nt_1_sycl - sgemm_nt_1_sycl.cpp - ) -else () - cutlass_example_add_executable( - sgemm_1 - sgemm_1.cu - ) +cutlass_example_add_executable( + sgemm_1 + sgemm_1.cu +) - cutlass_example_add_executable( - sgemm_2 - sgemm_2.cu - ) +cutlass_example_add_executable( + sgemm_2 + sgemm_2.cu +) - cutlass_example_add_executable( - sgemm_sm70 - sgemm_sm70.cu - ) +cutlass_example_add_executable( + sgemm_sm70 + sgemm_sm70.cu +) - cutlass_example_add_executable( - sgemm_sm80 - sgemm_sm80.cu - ) +cutlass_example_add_executable( + sgemm_sm80 + sgemm_sm80.cu +) - cutlass_example_add_executable( - tiled_copy - tiled_copy.cu - ) -endif () +cutlass_example_add_executable( + tiled_copy + tiled_copy.cu +) diff --git a/examples/cute/tutorial/sgemm_nt_1_sycl.cpp b/examples/cute/tutorial/sgemm_nt_1_sycl.cpp deleted file mode 100644 index 56e316098..000000000 --- a/examples/cute/tutorial/sgemm_nt_1_sycl.cpp +++ /dev/null @@ -1,446 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- **************************************************************************************************/
-#include
-#include
-
-#include
-#include
-#include
-
-#include "cutlass/util/print_error.hpp"
-#include "cutlass/util/GPU_Clock.hpp"
-#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
-# include "cutlass/util/cublas_wrappers.hpp"
-#endif
-
-
-template <class MShape, class NShape, class KShape,
-          class TA, class AStride, class ABlockLayout, class AThreadLayout,
-          class TB, class BStride, class BBlockLayout, class BThreadLayout,
-          class TC, class CStride, class CBlockLayout, class CThreadLayout,
-          class Alpha, class Beta>
-void gemm_device(MShape M, NShape N, KShape K,
-                 TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
-                 TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
-                 TC * C, CStride dC, CBlockLayout , CThreadLayout tC,
-                 Alpha alpha, Beta beta)
-{
-  using namespace cute;
-  using X = Underscore;
-
-  // Preconditions
-  CUTE_STATIC_ASSERT(is_static<ABlockLayout>::value);
-  CUTE_STATIC_ASSERT(is_static<BBlockLayout>::value);
-  CUTE_STATIC_ASSERT(is_static<CBlockLayout>::value);
-
-  CUTE_STATIC_ASSERT(is_static<AThreadLayout>::value);
-  CUTE_STATIC_ASSERT(is_static<BThreadLayout>::value);
-  CUTE_STATIC_ASSERT(is_static<CThreadLayout>::value);
-
-  CUTE_STATIC_ASSERT_V(size(tA) == size(tC));
-  CUTE_STATIC_ASSERT_V(size(tB) == size(tC));
-
-  //CUTE_STATIC_ASSERT_V(shape<0>(blockA) == shape<0>(blockC)); // BLK_M
-  //CUTE_STATIC_ASSERT_V(shape<0>(blockB) == shape<1>(blockC)); // BLK_N
-  CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB)); // BLK_K
-
-  // Shared memory buffers
-  auto smemA = syclcompat::local_mem<TA[cosize_v<ABlockLayout>]>();
-  auto smemB = syclcompat::local_mem<TB[cosize_v<BBlockLayout>]>();
-  auto sA = make_tensor(make_smem_ptr(smemA), blockA); // (BLK_M,BLK_K)
-  auto sB = make_tensor(make_smem_ptr(smemB), blockB); // (BLK_N,BLK_K)
-
-  // Represent the full tensors
-  auto mA = make_tensor(make_gmem_ptr(A), make_shape(M,K), dA); // (M,K)
-  auto mB = make_tensor(make_gmem_ptr(B), make_shape(N,K), dB); // (N,K)
-  auto mC = make_tensor(make_gmem_ptr(C), make_shape(M,N), dC); // (M,N)
-
-  // Get the appropriate blocks for this thread block --
-  // potential for thread block locality
-  auto blk_shape = make_shape(size<0>(sA), size<0>(sB), size<1>(sB));// (BLK_M,BLK_N,BLK_K)
-  auto blk_coord = make_coord(syclcompat::work_group_id::x(), syclcompat::work_group_id::y(), _); // (m,n,k)
-
-  auto gA = local_tile(mA, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
-  auto gB = local_tile(mB, blk_shape, blk_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
-  auto gC = local_tile(mC, blk_shape, blk_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)
-
-  //
-  // Partition the copying of A and B tiles across the threads
-  //
-
-  // TUTORIAL: Example of simple partitioning of A|B tiles over tA|tB
-  //   Default is a raked partition, but can be changed with Step<X,Y> parameter
-
-  auto tAgA = local_partition(gA, tA, syclcompat::local_id::x()); // (THR_M,THR_K,k)
-  auto tAsA = local_partition(sA, tA, syclcompat::local_id::x()); // (THR_M,THR_K)
-
-  auto tBgB =
local_partition(gB, tB, syclcompat::local_id::x()); // (THR_N,THR_K,k) - auto tBsB = local_partition(sB, tB, syclcompat::local_id::x()); // (THR_N,THR_K) - - // - // Define C accumulators and A/B partitioning - // - - // TUTORIAL: Example of partitioning via projections of tC - - // Partition sA (M,K) by the rows of tC - auto tCsA = local_partition(sA, tC, syclcompat::local_id::x(), Step<_1, X>{}); // (THR_M,BLK_K) - // Partition sB (N,K) by the cols of tC - auto tCsB = local_partition(sB, tC, syclcompat::local_id::x(), Step< X,_1>{}); // (THR_N,BLK_K) - // Partition gC (M,N) by the tile of tC - auto tCgC = local_partition(gC, tC, syclcompat::local_id::x(), Step<_1,_1>{}); // (THR_M,THR_N) - - // Allocate the accumulators -- same size as the projected data - auto tCrC = make_fragment_like(tCgC); // (THR_M,THR_N) - - // Clear the accumulators - clear(tCrC); - -#if 0 - if(thread0()) { - print("mA\n"); - print(mA.shape()); print("\n"); print(mA.stride()); - print("\n\ngA\n"); - print(gA.shape()); print("\n"); print(gA.stride()); - print("\n\ntAgA\n"); - print(tAgA.shape()); print("\n"); print(tAgA.stride()); - print("\n\nsA\n"); - print(sA.shape()); print("\n"); print(sA.stride()); - print("\n\ntAsA\n"); - print(tAsA.shape()); print("\n"); print(tAsA.stride()); - print("\n\n"); - } -#endif - -#if 0 - if(thread0()) { - print("mB\n"); - print(mB.shape()); print("\n"); print(mB.stride()); - print("\n\ngB\n"); - print(gB.shape()); print("\n"); print(gB.stride()); - print("\n\ntBgB\n"); - print(tBgB.shape()); print("\n"); print(tBgB.stride()); - print("\n\nsB\n"); - print(sB.shape()); print("\n"); print(sB.stride()); - print("\n\ntBsB\n"); - print(tBsB.shape()); print("\n"); print(tBsB.stride()); - print("\n\n"); - } -#endif - -#if 0 - if(thread0()) { - print("mC\n"); - print(mC.shape()); print("\n"); print(mC.stride()); - print("\n\ngC\n"); - print(gC.shape()); print("\n"); print(gC.stride()); - print("\n\ntCsA\n"); - print(tCsA.shape()); print("\n"); print(tCsA.stride()); - print("\n\ntCsB\n"); - print(tCsB.shape()); print("\n"); print(tCsB.stride()); - print("\n\ntCgC\n"); - print(tCgC.shape()); print("\n"); print(tCgC.stride()); - print("\n\ntCrC\n"); - print(tCrC.shape()); print("\n"); print(tCrC.stride()); - print("\n\n"); - } -#endif - -#if 1 - - // TUTORIAL: Example of a very simple compute loop - // Data is read from global to shared memory via the tA|tB partitioning - // gemm(.) operates on the shared memory directly via the tC partitioning - - auto k_max = size<2>(tAgA); - - for (int k = 0; k < k_max; ++k) - { - // Copy gmem to smem - copy(tAgA(_,_,k), tAsA); - copy(tBgB(_,_,k), tBsB); - - // In case copy uses cp.async, make sure that the cp.async - // instructions are ordered with respect to other cp.async - // instructions (fence), then wait on all the outstanding copy - // operations (wait<0>()). __syncthreads() alone does not do - // this. - // - // NOTE: cp_async_wait<0>() currently issues cp.async.wait_all. - // This is equivalent to cp.async.commit_group followed by - // cp.async_wait_group 0. This should make the first - // cp_async_fence() (which also issues cp.async.commit_group) - // redundant. The tutorial works as-is, so we'll leave the - // redundant fence in for now and study its removal later. 
-    cp_async_fence();
-    cp_async_wait<0>();
-
-    syclcompat::wg_barrier();
-
-    // Compute gemm on smem
-    gemm(tCsA, tCsB, tCrC);
-
-    syclcompat::wg_barrier();
-  }
-
-#endif
-
-  //
-  // Epilogue
-  //
-
-  axpby(alpha, tCrC, beta, tCgC);
-}
-
-
-template <class TA, class TB, class TC,
-          class Alpha, class Beta>
-void
-gemm(sycl::queue q, int m, int n, int k,
-     Alpha alpha,
-     TA const* A, int ldA,
-     TB const* B, int ldB,
-     Beta beta,
-     TC * C, int ldC)
-{
-  using namespace cute;
-
-  // Define shapes (dynamic)
-  auto M = int(m);
-  auto N = int(n);
-  auto K = int(k);
-
-  // Define strides (mixed)
-  auto dA = make_stride(Int<1>{}, ldA);
-  auto dB = make_stride(Int<1>{}, ldB);
-  auto dC = make_stride(Int<1>{}, ldC);
-
-  // Define block sizes (static)
-  auto bM = Int<128>{};
-  auto bN = Int<128>{};
-  auto bK = Int< 8>{};
-
-  // Define the block layouts (static)
-  auto sA = make_layout(make_shape(bM,bK));
-  auto sB = make_layout(make_shape(bN,bK));
-  auto sC = make_layout(make_shape(bM,bN));
-
-  // Define the thread layouts (static)
-  auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}));
-  auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}));
-  auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));
-
-  const auto block = syclcompat::dim3(size(tC));
-  const auto grid = syclcompat::dim3(ceil_div(size(M), size(bM)),
-                                     ceil_div(size(N), size(bN)));
-
-  syclcompat::launch<
-      gemm_device<
-          int, int, int,
-          TA, decltype(dA), decltype(sA), decltype(tA),
-          TB, decltype(dB), decltype(sB), decltype(tB),
-          TC, decltype(dC), decltype(sC), decltype(tC),
-          Alpha, Beta
-      >
-  >(grid, block, q, M, N, K,
-    A, dA, sA, tA,
-    B, dB, sB, tB,
-    C, dC, sC, tC,
-    alpha, beta);
-}
-
-#include
-#include
-#include
-#include
-
-void test_gemm(int m, int n, int k)
-{
-  auto q = sycl::queue { sycl::gpu_selector_v } ;
-
-  std::cout << "M = " << m << std::endl;
-  std::cout << "N = " << n << std::endl;
-  std::cout << "K = " << k << std::endl;
-
-  using TA = float;
-  using TB = float;
-  using TC = float;
-  using TI = float;
-
-  thrust::host_vector<TA> h_A(m*k);
-  thrust::host_vector<TB> h_B(n*k);
-  thrust::host_vector<TC> h_C(m*n);
-
-  for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
-  for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
-  for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);
-
-  auto d_A = sycl::malloc_device<TA>(m*k, q);
-  auto d_B = sycl::malloc_device<TB>(n*k, q);
-  auto d_C = sycl::malloc_device<TC>(m*n, q);
-
-  q.memcpy(d_A, h_A.data(), m*k * sizeof(TA)).wait();
-  q.memcpy(d_B, h_B.data(), n*k * sizeof(TB)).wait();
-  q.memcpy(d_C, h_C.data(), m*n * sizeof(TC)).wait();
-
-  TI alpha = 1.0;
-  TI beta = 0.0;
-
-  double tflops = (2.0*m*n*k) * 1e-12;
-
-  const int timing_iterations = 100;
-  GPU_Clock timer;
-
-#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
-  //
-  // cuBLAS
-  //
-
-  cublasHandle_t handle;
-  cublasCreate(&handle);
-
-  thrust::device_vector<TA> dc_A = h_A;
-  thrust::device_vector<TB> dc_B = h_B;
-  thrust::device_vector<TC> dc_C = h_C;
-
-  // Run once
-  blam::cublas::gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T,
-                     m, n, k,
-                     &alpha,
-                     dc_A.data().get(), m,
-                     dc_B.data().get(), n,
-                     &beta,
-                     dc_C.data().get(), m);
-  CUTE_CHECK_LAST();
-
-  thrust::host_vector<TC> cublas_result = dc_C;
-
-  // Timing iterations
-  timer.start();
-  for (int i = 0; i < timing_iterations; ++i) {
-    blam::cublas::gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T,
-                       m, n, k,
-                       &alpha,
-                       dc_A.data().get(), m,
-                       dc_B.data().get(), n,
-                       &beta,
-                       dc_C.data().get(), m);
-  }
-  double cublas_time = timer.seconds() / timing_iterations;
-  CUTE_CHECK_LAST();
-  printf("CUBLAS_GEMM: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cublas_time, cublas_time*1000);
-
-#else
-
-  std::cout << "Verification by comparison with cuBLAS is disabled, "
-    "either because the CMake option CUTLASS_ENABLE_CUBLAS "
-    "was explicitly set to OFF, or because CMake could not find cuBLAS. "
-    "If you would like to enable verification with cuBLAS, "
-    "please set the CMake option CUTLASS_ENABLE_CUBLAS to ON, "
-    "rerun CMake, and recompile this example.\n";
-
-#endif // CUTLASS_ENABLE_CUBLAS
-
-  //
-  // CuTe
-  //
-
-  // Run once (and check)
-
-  gemm(q, m, n, k,
-       alpha,
-       d_A, m,
-       d_B, n,
-       beta,
-       d_C, m);
-  CUTE_CHECK_LAST();
-  q.wait_and_throw();
-
-  // Timing iterations
-
-  timer.start();
-  for (int i = 0; i < timing_iterations; ++i) {
-    gemm(q, m, n, k,
-         alpha,
-         d_A, m,
-         d_B, n,
-         beta,
-         d_C, m);
-  }
-
-  q.wait();
-
-  double cute_time = timer.seconds() / timing_iterations;
-  CUTE_CHECK_LAST();
-  printf("SYCL_CUTE_GEMM: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000);
-
-  std::vector<TC> cute_result(m*n);
-  q.memcpy(cute_result.data(), d_C, m*n * sizeof(TC)).wait();
-
-#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
-  printf("Empirical Perf: %.1f%%\n", (cublas_time / cute_time) * 100);
-
-  auto host_matrix_to_const_column_major_cute_tensor =
-    [](const auto& X, int num_rows, int num_cols, int LDX) {
-      const auto shape = cute::Shape<int, int>{num_rows, num_cols};
-      const auto strides = cute::Stride<int, int>{1, LDX};
-      return cute::make_tensor(X.data(), cute::make_layout(shape, strides));
-    };
-
-  const auto A_view = host_matrix_to_const_column_major_cute_tensor(h_A, m, k, m);
-  // B^T is k x n, so B is n x k.
-  const auto B_view = host_matrix_to_const_column_major_cute_tensor(h_B, n, k, n);
-  const auto C_computed_view = host_matrix_to_const_column_major_cute_tensor(cute_result, m, n, m);
-  const auto C_expected_view = host_matrix_to_const_column_major_cute_tensor(cublas_result, m, n, m);
-  print_matrix_multiply_mollified_relative_error("float", A_view, B_view, C_computed_view, C_expected_view);
-
-#endif // CUTLASS_ENABLE_CUBLAS
-}
-
-
-int main(int argc, char** argv)
-{
-  int m = 5120;
-  if (argc >= 2)
-    sscanf(argv[1], "%d", &m);
-
-  int n = 5120;
-  if (argc >= 3)
-    sscanf(argv[2], "%d", &n);
-
-  int k = 4096;
-  if (argc >= 4)
-    sscanf(argv[3], "%d", &k);
-
-  test_gemm(m, n, k);
-
-  return 0;
-}
diff --git a/include/cute/arch/copy_sm75.hpp b/include/cute/arch/copy_sm75.hpp
index c244abce0..30d86b7ba 100644
--- a/include/cute/arch/copy_sm75.hpp
+++ b/include/cute/arch/copy_sm75.hpp
@@ -56,11 +56,8 @@
   #define CUTE_ARCH_LDSM_SM75_ENABLED (CUTE_ARCH_LDSM_SM75_SUPPORTED)
 #endif
 
-#if (CUTE_ARCH_LDSM_SM75_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
-  #define CUTE_ARCH_LDSM_SM75_ACTIVATED 1
-#endif
-
-#if ((CUTE_ARCH_LDSM_SM75_ENABLED) && defined(__SYCL_CUDA_ARCH__) && __SYCL_CUDA_ARCH__ >= 750)
+#if (CUTE_ARCH_LDSM_SM75_ENABLED) && \
+    ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750) || (defined(__SYCL_CUDA_ARCH__) && __SYCL_CUDA_ARCH__ >= 750))
   #define CUTE_ARCH_LDSM_SM75_ACTIVATED 1
 #endif
 
diff --git a/include/cute/arch/copy_sm80.hpp b/include/cute/arch/copy_sm80.hpp
index ae488b53d..9b7ab1168 100644
--- a/include/cute/arch/copy_sm80.hpp
+++ b/include/cute/arch/copy_sm80.hpp
@@ -57,7 +57,7 @@ struct SM80_CP_ASYNC_CACHEALWAYS
   copy(TS const& gmem_src,
        TD & smem_dst)
   {
-#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) && defined(ENABLE_NVPTX)
+#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
     TS const* gmem_ptr = &gmem_src;
     uint32_t
smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" @@ -84,7 +84,7 @@ struct SM80_CP_ASYNC_CACHEGLOBAL copy(TS const& gmem_src, TD & smem_dst) { -#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) TS const* gmem_ptr = &gmem_src; uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" @@ -112,7 +112,7 @@ struct SM80_CP_ASYNC_CACHEALWAYS_ZFILL TD & smem_dst, bool pred) { -#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) TS const* gmem_ptr = &gmem_src; uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); int src_size = pred ? sizeof(TS) : 0; @@ -142,7 +142,7 @@ struct SM80_CP_ASYNC_CACHEGLOBAL_ZFILL TD & smem_dst, bool pred) { -#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) TS const* gmem_ptr = &gmem_src; uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); int src_size = pred ? sizeof(TS) : 0; @@ -164,7 +164,7 @@ CUTE_HOST_DEVICE void cp_async_fence() { -#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) asm volatile("cp.async.commit_group;\n" ::); #endif } @@ -177,7 +177,7 @@ CUTE_HOST_DEVICE void cp_async_wait() { -#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) if constexpr (N == 0) { asm volatile("cp.async.wait_all;\n" ::); } else { diff --git a/include/cute/arch/mma_sm80.hpp b/include/cute/arch/mma_sm80.hpp index 4aff3fe56..c217c95f3 100644 --- a/include/cute/arch/mma_sm80.hpp +++ b/include/cute/arch/mma_sm80.hpp @@ -36,7 +36,8 @@ #include // Config -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) || (defined(__SYCL_CUDA_ARCH__) && (__SYCL_CUDA_ARCH__ >= 800)) +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) || \ + (defined(__SYCL_CUDA_ARCH__) && (__SYCL_CUDA_ARCH__ >= 800)) # define CUTE_ARCH_MMA_SM80_ENABLED #if (__CUDA_ARCH__ <= 900) @@ -69,7 +70,7 @@ struct SM80_16x8x8_F16F16F16F16_TN uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " "{%0, %1}," @@ -102,7 +103,7 @@ struct SM80_16x8x16_F16F16F16F16_TN uint32_t const& b0, uint32_t const& b1, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " "{%0, %1}," @@ -135,7 +136,7 @@ struct SM80_16x8x8_F32F16F16F32_TN uint32_t const& b0, float const & c0, float const & c1, float const & c2, float const & c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3}," @@ -168,7 +169,7 @@ struct SM80_16x8x16_F32F16F16F32_TN uint32_t const& b0, uint32_t const& b1, float const & c0, float const & c1, float const & c2, float const & c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3}," @@ -201,7 +202,7 @@ struct SM80_16x8x8_F32BF16BF16F32_TN 
uint32_t const& b0, float const & c0, float const & c1, float const & c2, float const & c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " "{%0, %1, %2, %3}," @@ -234,7 +235,7 @@ struct SM80_16x8x16_F32BF16BF16F32_TN uint32_t const& b0, uint32_t const& b1, float const & c0, float const & c1, float const & c2, float const & c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " "{%0, %1, %2, %3}," @@ -267,7 +268,7 @@ struct SM80_16x8x4_F32TF32TF32F32_TN uint32_t const& b0, float const & c0, float const & c1, float const & c2, float const & c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 " "{%0, %1, %2, %3}," @@ -300,7 +301,7 @@ struct SM80_16x8x8_F32TF32TF32F32_TN uint32_t const& b0, uint32_t const& b1, float const & c0, float const & c1, float const & c2, float const & c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " "{%0, %1, %2, %3}," @@ -333,7 +334,7 @@ struct SM80_8x8x4_F64F64F64F64_TN double const& b0, double const& c0, double const& c1) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 " "{%0, %1}," @@ -470,7 +471,7 @@ struct SM80_8x8x16_S32S8S8S32_TN uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 " "{%0, %1}," @@ -503,7 +504,7 @@ struct SM80_8x8x16_S32S8S8S32_TN_SATURATE uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " "{%0, %1}," @@ -536,7 +537,7 @@ struct SM80_16x8x16_S32S8S8S32_TN uint32_t const& b0, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 " "{%0, %1, %2, %3}," @@ -569,7 +570,7 @@ struct SM80_16x8x16_S32S8S8S32_TN_SATURATE uint32_t const& b0, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " "{%0, %1, %2, %3}," @@ -602,7 +603,7 @@ struct SM80_16x8x32_S32S8S8S32_TN uint32_t const& b0, uint32_t const& b1, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " "{%0, %1, %2, %3}," @@ -635,7 +636,7 @@ struct SM80_16x8x32_S32S8S8S32_TN_SATURATE uint32_t const& b0, uint32_t const& b1, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if 
defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " "{%0, %1, %2, %3}," @@ -668,7 +669,7 @@ struct SM80_8x8x16_S32S8U8S32_TN uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32 " "{%0, %1}," @@ -701,7 +702,7 @@ struct SM80_8x8x16_S32S8U8S32_TN_SATURATE uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32.satfinite " "{%0, %1}," @@ -734,7 +735,7 @@ struct SM80_16x8x16_S32S8U8S32_TN uint32_t const& b0, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32 " "{%0, %1, %2, %3}," @@ -767,7 +768,7 @@ struct SM80_16x8x16_S32S8U8S32_TN_SATURATE uint32_t const& b0, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite " "{%0, %1, %2, %3}," @@ -800,7 +801,7 @@ struct SM80_16x8x32_S32S8U8S32_TN uint32_t const& b0, uint32_t const& b1, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32 " "{%0, %1, %2, %3}," @@ -833,7 +834,7 @@ struct SM80_16x8x32_S32S8U8S32_TN_SATURATE uint32_t const& b0, uint32_t const& b1, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite " "{%0, %1, %2, %3}," @@ -866,7 +867,7 @@ struct SM80_8x8x16_S32U8S8S32_TN uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 " "{%0, %1}," @@ -899,7 +900,7 @@ struct SM80_8x8x16_S32U8S8S32_TN_SATURATE uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32.satfinite " "{%0, %1}," @@ -932,7 +933,7 @@ struct SM80_16x8x16_S32U8S8S32_TN uint32_t const& b0, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32 " "{%0, %1, %2, %3}," @@ -965,7 +966,7 @@ struct SM80_16x8x16_S32U8S8S32_TN_SATURATE uint32_t const& b0, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( 
"mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite " "{%0, %1, %2, %3}," @@ -998,7 +999,7 @@ struct SM80_16x8x32_S32U8S8S32_TN uint32_t const& b0, uint32_t const& b1, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32 " "{%0, %1, %2, %3}," @@ -1031,7 +1032,7 @@ struct SM80_16x8x32_S32U8S8S32_TN_SATURATE uint32_t const& b0, uint32_t const& b1, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite " "{%0, %1, %2, %3}," @@ -1064,7 +1065,7 @@ struct SM80_8x8x16_S32U8U8S32_TN uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 " "{%0, %1}," @@ -1097,7 +1098,7 @@ struct SM80_8x8x16_S32U8U8S32_TN_SATURATE uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32.satfinite " "{%0, %1}," @@ -1130,7 +1131,7 @@ struct SM80_16x8x16_S32U8U8S32_TN uint32_t const& b0, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 " "{%0, %1, %2, %3}," @@ -1163,7 +1164,7 @@ struct SM80_16x8x16_S32U8U8S32_TN_SATURATE uint32_t const& b0, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite " "{%0, %1, %2, %3}," @@ -1196,7 +1197,7 @@ struct SM80_16x8x32_S32U8U8S32_TN uint32_t const& b0, uint32_t const& b1, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 " "{%0, %1, %2, %3}," @@ -1229,7 +1230,7 @@ struct SM80_16x8x32_S32U8U8S32_TN_SATURATE uint32_t const& b0, uint32_t const& b1, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite " "{%0, %1, %2, %3}," @@ -1262,7 +1263,7 @@ struct SM80_8x8x32_S32S4S4S32_TN uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 " "{%0, %1}," @@ -1295,7 +1296,7 @@ struct SM80_8x8x32_S32S4S4S32_TN_SATURATE uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32.satfinite " "{%0, %1}," 
@@ -1328,7 +1329,7 @@ struct SM80_16x8x32_S32S4S4S32_TN uint32_t const& b0, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k32.row.col.s32.s4.s4.s32 " "{%0, %1, %2, %3}," @@ -1361,7 +1362,7 @@ struct SM80_16x8x32_S32S4S4S32_TN_SATURATE uint32_t const& b0, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k32.row.col.s32.s4.s4.s32.satfinite " "{%0, %1, %2, %3}," @@ -1394,7 +1395,7 @@ struct SM80_16x8x64_S32S4S4S32_TN uint32_t const& b0, uint32_t const& b1, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " "{%0, %1, %2, %3}," @@ -1427,7 +1428,7 @@ struct SM80_16x8x64_S32S4S4S32_TN_SATURATE uint32_t const& b0, uint32_t const& b1, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite " "{%0, %1, %2, %3}," @@ -1460,7 +1461,7 @@ struct SM80_8x8x32_S32S4U4S32_TN uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 " "{%0, %1}," @@ -1493,7 +1494,7 @@ struct SM80_8x8x32_S32S4U4S32_TN_SATURATE uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32.satfinite " "{%0, %1}," @@ -1526,7 +1527,7 @@ struct SM80_16x8x32_S32S4U4S32_TN uint32_t const& b0, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k32.row.col.s32.s4.u4.s32 " "{%0, %1, %2, %3}," @@ -1559,7 +1560,7 @@ struct SM80_16x8x32_S32S4U4S32_TN_SATURATE uint32_t const& b0, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k32.row.col.s32.s4.u4.s32.satfinite " "{%0, %1, %2, %3}," @@ -1592,7 +1593,7 @@ struct SM80_16x8x64_S32S4U4S32_TN uint32_t const& b0, uint32_t const& b1, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32 " "{%0, %1, %2, %3}," @@ -1625,7 +1626,7 @@ struct SM80_16x8x64_S32S4U4S32_TN_SATURATE uint32_t const& b0, uint32_t const& b1, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite " 
"{%0, %1, %2, %3}," @@ -1658,7 +1659,7 @@ struct SM80_8x8x32_S32U4S4S32_TN uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 " "{%0, %1}," @@ -1691,7 +1692,7 @@ struct SM80_8x8x32_S32U4S4S32_TN_SATURATE uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32.satfinite " "{%0, %1}," @@ -1724,7 +1725,7 @@ struct SM80_16x8x32_S32U4S4S32_TN uint32_t const& b0, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k32.row.col.s32.u4.s4.s32 " "{%0, %1, %2, %3}," @@ -1757,7 +1758,7 @@ struct SM80_16x8x32_S32U4S4S32_TN_SATURATE uint32_t const& b0, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k32.row.col.s32.u4.s4.s32.satfinite " "{%0, %1, %2, %3}," @@ -1790,7 +1791,7 @@ struct SM80_16x8x64_S32U4S4S32_TN uint32_t const& b0, uint32_t const& b1, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 " "{%0, %1, %2, %3}," @@ -1823,7 +1824,7 @@ struct SM80_16x8x64_S32U4S4S32_TN_SATURATE uint32_t const& b0, uint32_t const& b1, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite " "{%0, %1, %2, %3}," @@ -1856,7 +1857,7 @@ struct SM80_8x8x32_S32U4U4S32_TN uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 " "{%0, %1}," @@ -1889,7 +1890,7 @@ struct SM80_8x8x32_S32U4U4S32_TN_SATURATE uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32.satfinite " "{%0, %1}," @@ -1922,7 +1923,7 @@ struct SM80_16x8x32_S32U4U4S32_TN uint32_t const& b0, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k32.row.col.s32.u4.u4.s32 " "{%0, %1, %2, %3}," @@ -1955,7 +1956,7 @@ struct SM80_16x8x32_S32U4U4S32_TN_SATURATE uint32_t const& b0, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k32.row.col.s32.u4.u4.s32.satfinite " "{%0, %1, %2, %3}," @@ -1988,7 +1989,7 @@ struct SM80_16x8x64_S32U4U4S32_TN uint32_t const& b0, uint32_t const& b1, 
uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 " "{%0, %1, %2, %3}," @@ -2021,7 +2022,7 @@ struct SM80_16x8x64_S32U4U4S32_TN_SATURATE uint32_t const& b0, uint32_t const& b1, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite " "{%0, %1, %2, %3}," @@ -2056,7 +2057,7 @@ struct SM80_8x8x128_S32U1U1S32_TN_XORPOPC uint32_t const& b0, uint32_t const& c0, uint32_t const& c1) { -#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) asm volatile( "mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc " "{%0, %1}," @@ -2089,7 +2090,7 @@ struct SM80_16x8x128_S32U1U1S32_TN_XORPOPC uint32_t const& b0, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc " "{%0, %1, %2, %3}," @@ -2122,7 +2123,7 @@ struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC uint32_t const& b0, uint32_t const& b1, uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) { -#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) && defined(ENABLE_NVPTX) +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) asm volatile( "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc " "{%0, %1, %2, %3}," diff --git a/include/cute/config.hpp b/include/cute/config.hpp index 750cdf798..4cf38929e 100644 --- a/include/cute/config.hpp +++ b/include/cute/config.hpp @@ -44,15 +44,6 @@ # define CUTE_HOST inline #endif // CUTE_HOST_DEVICE, CUTE_DEVICE -#if defined(CUTLASS_ENABLE_SYCL) -// the flag ENABLE_NVPTX should be set to 1 for SYCL Nvidia backend and CUDA backend. 
However, this flag will be set to 0 for SYCL backend on non-Nvidia devices -#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) -# define ENABLE_NVPTX 1 -#endif -#else -# define ENABLE_NVPTX 1 -#endif - #if defined(__CUDACC_RTC__) # define CUTE_HOST_RTC CUTE_HOST_DEVICE #else diff --git a/include/cute/util/debug.hpp b/include/cute/util/debug.hpp index 5e421ec3b..35ce315c9 100644 --- a/include/cute/util/debug.hpp +++ b/include/cute/util/debug.hpp @@ -128,14 +128,9 @@ block(int bid) { #if defined(CUTLASS_ENABLE_SYCL) using sycl::ext::oneapi::experimental::this_nd_item; - return (this_nd_item<3>.get_linear_id()==bid); + return (this_nd_item<3>().get_group_linear_id()==bid); #elif defined(__CUDA_ARCH__) return blockIdx.x + blockIdx.y*gridDim.x + blockIdx.z*gridDim.x*gridDim.y == bid; -#elif defined(CUTLASS_ENABLE_SYCL) - using namespace syclcompat; - return (work_group_id::x() + - work_group_id::y() * global_range::x() + - work_group_id::z() * global_range::x() * global_range::y() == bid); #else return true; #endif @@ -147,14 +142,9 @@ thread(int tid, int bid) { #if defined(CUTLASS_ENABLE_SYCL) using sycl::ext::oneapi::experimental::this_nd_item; - return (this_nd_item<3>.get_linear_id()==bid); + return (this_nd_item<3>().get_global_linear_id()==bid); #elif defined(__CUDA_ARCH__) return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == tid) && block(bid); -#elif defined(CUTLASS_ENABLE_SYCL) - using namespace syclcompat; - return (local_id::x() + - local_id::y() * work_group_range::x() + - local_id::z() * work_group_range::x() * work_group_range::y() == tid) && block(bid); #else return true; #endif diff --git a/include/cutlass/arch/arch.h b/include/cutlass/arch/arch.h index 6ee097eab..6c7941735 100644 --- a/include/cutlass/arch/arch.h +++ b/include/cutlass/arch/arch.h @@ -47,7 +47,7 @@ namespace arch { CUTLASS_DEVICE int LaneId() { int ret; -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm ("mov.u32 %0, %%laneid;" : "=r"(ret) : ); #endif return ret; @@ -57,7 +57,7 @@ int LaneId() { CUTLASS_DEVICE int SmId() { int ret; -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm ("mov.u32 %0, %%smid;" : "=r"(ret) : ); #endif return ret; @@ -100,7 +100,7 @@ struct Sm90 { /// Triggers a breakpoint on the device CUTLASS_DEVICE void device_breakpoint() { -#if (defined(__CUDA_ARCH__) || defined(__SYCL_CUDA_ARCH__)) && defined(ENABLE_NVPTX) +#if defined(__CUDA_ARCH__) || defined(__SYCL_CUDA_ARCH__) asm volatile (" brkpt;\n"); #endif } diff --git a/include/cutlass/arch/memory.h b/include/cutlass/arch/memory.h index a3beec223..7a7c7bb98 100644 --- a/include/cutlass/arch/memory.h +++ b/include/cutlass/arch/memory.h @@ -81,7 +81,7 @@ struct global_load(&D); -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .pred p;\n" @@ -119,7 +119,7 @@ struct global_load(&D); -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .pred p;\n" @@ -152,7 +152,7 @@ struct global_load(D); -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .pred p;\n" @@ -181,7 +181,7 @@ struct global_load(D); -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .pred p;\n" @@ -206,7 +206,7 @@ struct global_load(D); -#if defined(ENABLE_NVPTX) +#if 
!defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .pred p;\n" @@ -233,7 +233,7 @@ struct global_load(D); -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .pred p;\n" @@ -256,7 +256,7 @@ struct global_load(D); -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .pred p;\n" @@ -282,7 +282,7 @@ struct global_load(D); -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .pred p;\n" @@ -304,7 +304,7 @@ struct global_load(D); -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .pred p;\n" @@ -330,7 +330,7 @@ struct global_load(D); -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .pred p;\n" @@ -377,7 +377,7 @@ struct global_store { CUTLASS_DEVICE global_store(AccessType const &D, void *ptr, bool pred_guard) { uint4 const *data = reinterpret_cast(&D); -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .pred p;\n" @@ -405,7 +405,7 @@ struct global_store { CUTLASS_DEVICE global_store(AccessType const &D, void *ptr, bool pred_guard) { uint4 const *data = reinterpret_cast(&D); -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .pred p;\n" @@ -426,7 +426,7 @@ struct global_store { CUTLASS_DEVICE global_store(AccessType const &D, void *ptr, bool pred_guard) { uint4 const &data = reinterpret_cast(D); -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .pred p;\n" @@ -444,7 +444,7 @@ struct global_store { CUTLASS_DEVICE global_store(AccessType const &D, void *ptr, bool pred_guard) { uint2 const &data = reinterpret_cast(D); -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .pred p;\n" @@ -462,7 +462,7 @@ struct global_store { CUTLASS_DEVICE global_store(AccessType const &D, void *ptr, bool pred_guard) { uint32_t const &data = reinterpret_cast(D); -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .pred p;\n" @@ -480,7 +480,7 @@ struct global_store { CUTLASS_DEVICE global_store(AccessType const &D, void *ptr, bool pred_guard) { uint16_t const &data = reinterpret_cast(D); -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile( "{\n" " .reg .pred p;\n" @@ -513,7 +513,7 @@ void shared_load(void *dst, uint32_t ptr); template <> CUTLASS_DEVICE void shared_load<2>(void *dst, uint32_t ptr) { -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile("ld.shared.u16 %0, [%1];\n" : "=h"(*reinterpret_cast(dst)) : "r"(ptr)); @@ -524,7 +524,7 @@ void shared_load<2>(void *dst, uint32_t ptr) { template <> CUTLASS_DEVICE void shared_load<4>(void *dst, uint32_t ptr) { -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile("ld.shared.u32 %0, [%1];\n" : "=r"(*reinterpret_cast(dst)) : "r"(ptr)); @@ -536,7 +536,7 @@ template <> CUTLASS_DEVICE void shared_load<8>(void *dst, uint32_t ptr) { uint2 *dst_u64 = reinterpret_cast(dst); -#if defined(ENABLE_NVPTX) +#if 
!defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile("ld.shared.v2.u32 {%0, %1}, [%2];\n" : "=r"(dst_u64->x), @@ -550,7 +550,7 @@ template <> CUTLASS_DEVICE void shared_load<16>(void *dst, uint32_t ptr) { uint4 *dst_u128 = reinterpret_cast(dst); -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];\n" : "=r"(dst_u128->x), @@ -573,7 +573,7 @@ void shared_store(uint32_t ptr, void const *src); template <> CUTLASS_DEVICE void shared_store<2>(uint32_t ptr, void const *src) { -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile("st.shared.u16 [%0], %1;\n" : : "r"(ptr), @@ -586,7 +586,7 @@ void shared_store<2>(uint32_t ptr, void const *src) { template <> CUTLASS_DEVICE void shared_store<4>(uint32_t ptr, void const *src) { -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile("st.shared.u32 [%0], %1;\n" : : "r"(ptr), @@ -600,7 +600,7 @@ template <> CUTLASS_DEVICE void shared_store<8>(uint32_t ptr, void const *src) { uint2 const *dst_u64 = reinterpret_cast(src); -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n" : : "r"(ptr), @@ -615,7 +615,7 @@ template <> CUTLASS_DEVICE void shared_store<16>(uint32_t ptr, void const *src) { uint4 const *dst_u128 = reinterpret_cast(src); -#if defined(ENABLE_NVPTX) +#if !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__) asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n" : : "r"(ptr), diff --git a/include/cutlass/barrier.h b/include/cutlass/barrier.h index 1e688ebb5..9b2362a9c 100644 --- a/include/cutlass/barrier.h +++ b/include/cutlass/barrier.h @@ -120,10 +120,8 @@ struct GenericBarrier { // Release pattern using acq_rel fence + relaxed modifier. (The fence also releases data // that was weakly-written by other threads prior to the last syncthreads) -#if defined(ENABLE_NVPTX) asm volatile ("fence.acq_rel.gpu;\n"); asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val)); -#endif #else threadfence(); atomicAdd(ptr, val); diff --git a/include/cutlass/detail/helper_macros.hpp b/include/cutlass/detail/helper_macros.hpp index ac64b9adc..2c755b865 100644 --- a/include/cutlass/detail/helper_macros.hpp +++ b/include/cutlass/detail/helper_macros.hpp @@ -159,18 +159,6 @@ namespace cutlass { //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTLASS_ENABLE_SYCL) -// the flag ENABLE_NVPTX should be set to 1 for SYCL Nvidia backend and CUDA backend. However, this flag will be set to 0 for SYCL backend on non-Nvidia devices -#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) -// the flag ENABLE_NVPTX should be set to 1 for SYCL Nvidia backend and CUDA backend. 
However, this flag will be set to 0 for SYCL backend on non-Nvidia devices -# define ENABLE_NVPTX 1 -#endif -#else -# define ENABLE_NVPTX 1 -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - }; // namespace cutlass //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/GPU_Clock.hpp b/tools/util/include/cutlass/util/GPU_Clock.hpp index 72cfec3d5..76f0c86df 100644 --- a/tools/util/include/cutlass/util/GPU_Clock.hpp +++ b/tools/util/include/cutlass/util/GPU_Clock.hpp @@ -33,6 +33,10 @@ #include +#ifdef CUTLASS_ENABLE_SYCL +#include +#endif + struct GPU_Clock { GPU_Clock() {
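A note on the guard pattern that runs through this patch: with ENABLE_NVPTX removed, inline PTX is gated directly on the compiler's architecture macros. The sketch below is a minimal illustration of that pattern, assuming only the macros that appear in the diffs above (__CUDA_ARCH__, __SYCL_CUDA_ARCH__, CUTLASS_ENABLE_SYCL); example_lane_id is a hypothetical name for illustration, not a symbol introduced by this patch.

  // Hedged sketch, not code from the patch: how a PTX helper is guarded once
  // ENABLE_NVPTX is gone. CUTLASS_DEVICE is the library's device-inline macro.
  CUTLASS_DEVICE int example_lane_id() {
    int ret = 0;
  #if defined(__CUDA_ARCH__) || defined(__SYCL_CUDA_ARCH__)
    // Emitted for CUDA device code and for the SYCL CUDA (NVPTX) backend only;
    // a SYCL build targeting a non-NVIDIA device compiles this branch away.
    asm volatile("mov.u32 %0, %%laneid;" : "=r"(ret));
  #endif
    return ret;
  }

Headers that must also compile for non-CUDA SYCL targets use the second flavor seen in cutlass/arch/arch.h and cutlass/arch/memory.h, !defined(CUTLASS_ENABLE_SYCL) || defined(__SYCL_CUDA_ARCH__): a plain CUDA build (where CUTLASS_ENABLE_SYCL is unset) keeps the asm unconditionally, while a SYCL build keeps it only when lowering through NVPTX.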