[MLIR] SM_90 integration test of TMA 128x64xf16 and 64x64xf16 with 128b Swizzling #65954
Conversation
@llvm/pr-subscribers-mlir

Changes

#65953 added a test `128x64xf16` that does a single TMA load. This PR adds a more complex test that performs two additional TMA loads with 128B swizzling:

    TMA Load: Matrix-A[0:128][0:64]
    TMA Load: Matrix-B[0:64][0:64]
    TMA Load: Matrix-B[64:128][0:64]

The program tests the loaded data for Matrix-B.

Patch is 64.16 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/65954.diff

1 File Affected:
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tmaload_64_64_swizzle128b.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tmaload_64_64_swizzle128b.mlir
new file mode 100644
index 000000000000000..b3c55c87405c507
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tmaload_64_64_swizzle128b.mlir
@@ -0,0 +1,245 @@
+// RUN: mlir-opt %s --convert-nvgpu-to-nvvm \
+// RUN:   -convert-linalg-to-loops \
+// RUN:   -canonicalize -cse \
+// RUN:   -gpu-kernel-outlining \
+// RUN:   -canonicalize -cse \
+// RUN:   -convert-vector-to-scf \
+// RUN:   -canonicalize -cse \
+// RUN:   -lower-affine \
+// RUN:   -canonicalize -cse \
+// RUN:   -convert-scf-to-cf \
+// RUN:   -canonicalize -cse \
+// RUN:   -convert-nvvm-to-llvm \
+// RUN:   -canonicalize -cse \
+// RUN:   -convert-nvgpu-to-nvvm \
+// RUN:   -canonicalize -cse \
+// RUN:   -convert-scf-to-cf \
+// RUN:   -convert-vector-to-llvm \
+// RUN:   -canonicalize -cse \
+// RUN:   -convert-math-to-llvm \
+// RUN:   -canonicalize -cse \
+// RUN:   -lower-affine \
+// RUN:   -convert-index-to-llvm=index-bitwidth=32 \
+// RUN:   -convert-arith-to-llvm \
+// RUN:   -finalize-memref-to-llvm='use-opaque-pointers=1' \
+// RUN:   -convert-func-to-llvm \
+// RUN:   -canonicalize -cse \
+// RUN:   -expand-strided-metadata --nvvm-attach-target="module=main_kernel features=+ptx80 chip=sm_90 O=3" \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-index-to-llvm{index-bitwidth=32},canonicalize,cse))' \
+// RUN: | mlir-opt --gpu-to-llvm --gpu-module-to-binary -canonicalize -cse -reconcile-unrealized-casts \
+// RUN: | mlir-translate --mlir-to-llvmir -o %t.ll
+// RUN: clang %t.ll -O3 %mlir_cuda_runtime %mlir_runner_utils -o %t.exe
+// RUN: LD_LIBRARY_PATH=%shlibdir %t.exe | FileCheck %s
+
+// This does 3 TMA loads with 128B Swizzling :
+//   TMA Load: Matrix-A[0:128][0:64]
+//   TMA Load: Matrix-B[0:64][0:64]
+//   TMA Load: Matrix-B[64:128][0:64]
+
+// Test swizzling with TMA load
+// 128B Swizzle Each numbered cell is 16 byte
+// |-------------------------------|
+// | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
+// | 1 | 0 | 3 | 2 | 5 | 4 | 7 | 6 |
+// | 2 | 3 | 0 | 1 | 6 | 7 | 4 | 5 |
+// | 3 | 2 | 1 | 0 | 7 | 6 | 5 | 4 |
+// | 4 | 5 | 6 | 7 | 0 | 1 | 2 | 3 |
+// | 5 | 4 | 7 | 6 | 1 | 0 | 3 | 2 |
+// | 6 | 7 | 4 | 5 | 2 | 3 | 0 | 1 |
+// |-------------------------------|
+// |    ... pattern repeats ...    |
+// |-------------------------------|
+
+
+!barrierType = !nvgpu.mbarrier.barrier>
+!tokenType = !nvgpu.mbarrier.token
+
+!lhs = memref<128x64xf16>
+!shmemlhs = memref<128x64xf16, 3>
+!lhsTensorMap = !nvgpu.tensormap.descriptor
+
+!rhs = memref<128x64xf16>
+!shmemrhs = memref<128x64xf16, 3>
+!rhsTensorMap = !nvgpu.tensormap.descriptor
+
+module @mymod {
+  func.func private @printMemrefF32(memref<*xf32>)
+  memref.global "private" @bufferLhsGlobal : !shmemlhs
+  memref.global "private" @bufferRhsGlobal : !shmemrhs
+  llvm.func @printf(!llvm.ptr, ...) -> i32
+  func.func @main() {
+    %c32768 = arith.constant 32768 : index
+    %c-1_i32 = arith.constant -1 : i32
+    %c10000000 = arith.constant 10000000 : index
+    %c64 = arith.constant 64 : index
+    %c1 = arith.constant 1 : index
+    %c32 = arith.constant 32 : index
+    %c0 = arith.constant 0 : index
+    %c128 = arith.constant 128 : index
+    %c8 = arith.constant 8 : index
+
+    // Step 1. Allocate host data and initialize it.
+    %lhs = memref.alloc() : !lhs
+    %rhs = memref.alloc() : !rhs
+    %lhs32 = memref.alloc() : memref<128x64xf32>
+    %rhs32 = memref.alloc() : memref<64x128xf32>
+    scf.for %i = %c0 to %c64 step %c1 {
+      scf.for %j = %c0 to %c128 step %c1 {
+        %v0 = arith.muli %i, %c128 : index
+        %v00 = arith.addi %v0, %j : index
+        %v01 = arith.divui %v00, %c8 : index
+        %v2 = arith.index_cast %v01 : index to i32
+        %vR = arith.sitofp %v2 : i32 to f16
+        memref.store %vR, %rhs[%i, %j] : !rhs
+        %vR32 = arith.extf %vR : f16 to f32
+        memref.store %vR32, %rhs32[%i, %j] : memref<64x128xf32>
+      }
+    }
+    scf.for %j = %c0 to %c128 step %c1 {
+      scf.for %i = %c0 to %c64 step %c1 {
+        %b0 = arith.muli %j, %c64 : index
+        %b00 = arith.addi %b0, %i : index
+        %b01 = arith.divui %b00, %c8 : index
+        %v1 = arith.index_cast %b01 : index to i32
+        %vL = arith.sitofp %v1 : i32 to f16
+        memref.store %vL, %lhs[%j, %i] : !lhs
+        %vL32 = arith.extf %vL : f16 to f32
+        memref.store %vL32, %lhs32[%j, %i] : memref<128x64xf32>
+      }
+    }
+
+    // Step 2. Print on the host
+    %lhs32_unranked = memref.cast %lhs32 : memref<128x64xf32> to memref<*xf32>
+    call @printMemrefF32(%lhs32_unranked) : (memref<*xf32>) -> ()
+    %rhs32_unranked = memref.cast %rhs32 : memref<64x128xf32> to memref<*xf32>
+    call @printMemrefF32(%rhs32_unranked) : (memref<*xf32>) -> ()
+
+    // Step 3. Copy host to device
+    %0 = gpu.wait async
+    %d_glbmem_lhs, %asyncToken = gpu.alloc async [%0] () : !lhs
+    %d_glbmem_rhs, %asyncToken_2 = gpu.alloc async [%0] () : !rhs
+    %1 = gpu.memcpy async [%0] %d_glbmem_lhs, %lhs : !lhs, !lhs
+    %2 = gpu.memcpy async [%0] %d_glbmem_rhs, %rhs : !rhs, !rhs
+
+    // Step 4. Create TMA tensor descriptor
+    %d_lhs_unranked = memref.cast %d_glbmem_lhs :!lhs to memref<*xf16>
+    %d_rhs_unranked = memref.cast %d_glbmem_rhs :!rhs to memref<*xf16>
+
+    %d_lhsTensorMap = nvgpu.tma.create.descriptor %d_lhs_unranked box[%c128, %c64] : memref<*xf16> -> !lhsTensorMap
+    %d_rhsTensorMap = nvgpu.tma.create.descriptor %d_rhs_unranked box[%c64, %c64] : memref<*xf16> -> !rhsTensorMap
+
+    // Step 5. Launch a GPU kernel
+    gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %c1, %arg7 = %c1, %arg8 = %c1) threads(%arg3, %arg4, %arg5) in (%arg9 = %c128, %arg10 = %c1, %arg11 = %c1) {
+      %5 = gpu.block_dim x
+      %6 = gpu.thread_id x
+      %lhsShmem = memref.get_global @bufferLhsGlobal : !shmemlhs
+      %rhsShmem = memref.get_global @bufferRhsGlobal : !shmemrhs
+      %rhsShmem2 = memref.subview %rhsShmem[%c32, %c0][%c32, %c128][%c1, %c1] : !shmemrhs to memref, 3>
+
+      // Step 6. Initialize the mbarrier
+      %9 = nvgpu.mbarrier.create -> !barrierType
+      nvgpu.mbarrier.init %9, %5 : !barrierType
+      %10 = arith.cmpi eq, %6, %c0 : index
+
+
+      // Step 7. First thread does TMA load
+      scf.if %10 {
+        gpu.printf "[GPU] TMA SIZE %d\0A" %c32768 : index
+        nvgpu.tma.async.load %d_lhsTensorMap[%c0, %c0], %9 to %lhsShmem : !lhsTensorMap, !barrierType -> !shmemlhs
+        nvgpu.tma.async.load %d_rhsTensorMap[%c0, %c0], %9 to %rhsShmem : !rhsTensorMap, !barrierType -> !shmemrhs
+        nvgpu.tma.async.load %d_rhsTensorMap[%c64, %c0], %9 to %rhsShmem2 : !rhsTensorMap, !barrierType -> memref, 3>
+        nvgpu.mbarrier.arrive.expect_tx %9, %c32768 : !barrierType
+      } else {
+        nvgpu.mbarrier.arrive.expect_tx %9, %c0 : !barrierType
+      }
+
+      // Step 8. Wait until TMA is done
+      nvgpu.mbarrier.try_wait.parity %9, %c0, %c10000000 : !barrierType
+
+      // Step 9.
Print loaded data in 128b swizzled + scf.if %10 { + gpu.printf "===--- Matrix B ---=== %d \n" %c-1_i32 : i32 + scf.for %ii = %c0 to %c64 step %c1 { + scf.for %j = %c0 to %c128 step %c1 { + %lhs0 = memref.load %rhsShmem[%ii, %j] : !shmemrhs + %lhs032 = arith.extf %lhs0: f16 to f32 + gpu.printf "%.0f, " %lhs032 : f32 + } + gpu.printf "%d\n" %c-1_i32 : i32 + } + gpu.printf "===----------------=== %d \n" %c-1_i32 : i32 + } + gpu.terminator + } + return + } +} + + +// CHECK: [GPU] TMA SIZE 32768 +// CHECK: ===--- Matrix B ---=== -1 +// CHECK: 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 17, 17, 17, 17, 17, 17, 17, 17, 16, 16, 16, 16, 16, 16, 16, 16, 19, 19, 19, 19, 19, 19, 19, 19, 18, 18, 18, 18, 18, 18, 18, 18, 21, 21, 21, 21, 21, 21, 21, 21, 20, 20, 20, 20, 20, 20, 20, 20, 23, 23, 23, 23, 23, 23, 23, 23, 22, 22, 22, 22, 22, 22, 22, 22, -1 +// CHECK: 17, 17, 17, 17, 17, 17, 17, 17, 16, 16, 16, 16, 16, 16, 16, 16, 19, 19, 19, 19, 19, 19, 19, 19, 18, 18, 18, 18, 18, 18, 18, 18, 21, 21, 21, 21, 21, 21, 21, 21, 20, 20, 20, 20, 20, 20, 20, 20, 23, 23, 23, 23, 23, 23, 23, 23, 22, 22, 22, 22, 22, 22, 22, 22, 34, 34, 34, 34, 34, 34, 34, 34, 35, 35, 35, 35, 35, 35, 35, 35, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 38, 38, 38, 38, 38, 38, 38, 38, 39, 39, 39, 39, 39, 39, 39, 39, 36, 36, 36, 36, 36, 36, 36, 36, 37, 37, 37, 37, 37, 37, 37, 37, -1 +// CHECK: 34, 34, 34, 34, 34, 34, 34, 34, 35, 35, 35, 35, 35, 35, 35, 35, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 38, 38, 38, 38, 38, 38, 38, 38, 39, 39, 39, 39, 39, 39, 39, 39, 36, 36, 36, 36, 36, 36, 36, 36, 37, 37, 37, 37, 37, 37, 37, 37, 51, 51, 51, 51, 51, 51, 51, 51, 50, 50, 50, 50, 50, 50, 50, 50, 49, 49, 49, 49, 49, 49, 49, 49, 48, 48, 48, 48, 48, 48, 48, 48, 55, 55, 55, 55, 55, 55, 55, 55, 54, 54, 54, 54, 54, 54, 54, 54, 53, 53, 53, 53, 53, 53, 53, 53, 52, 52, 52, 52, 52, 52, 52, 52, -1 +// CHECK: 51, 51, 51, 51, 51, 51, 51, 51, 50, 50, 50, 50, 50, 50, 50, 50, 49, 49, 49, 49, 49, 49, 49, 49, 48, 48, 48, 48, 48, 48, 48, 48, 55, 55, 55, 55, 55, 55, 55, 55, 54, 54, 54, 54, 54, 54, 54, 54, 53, 53, 53, 53, 53, 53, 53, 53, 52, 52, 52, 52, 52, 52, 52, 52, 68, 68, 68, 68, 68, 68, 68, 68, 69, 69, 69, 69, 69, 69, 69, 69, 70, 70, 70, 70, 70, 70, 70, 70, 71, 71, 71, 71, 71, 71, 71, 71, 64, 64, 64, 64, 64, 64, 64, 64, 65, 65, 65, 65, 65, 65, 65, 65, 66, 66, 66, 66, 66, 66, 66, 66, 67, 67, 67, 67, 67, 67, 67, 67, -1 +// CHECK: 68, 68, 68, 68, 68, 68, 68, 68, 69, 69, 69, 69, 69, 69, 69, 69, 70, 70, 70, 70, 70, 70, 70, 70, 71, 71, 71, 71, 71, 71, 71, 71, 64, 64, 64, 64, 64, 64, 64, 64, 65, 65, 65, 65, 65, 65, 65, 65, 66, 66, 66, 66, 66, 66, 66, 66, 67, 67, 67, 67, 67, 67, 67, 67, 85, 85, 85, 85, 85, 85, 85, 85, 84, 84, 84, 84, 84, 84, 84, 84, 87, 87, 87, 87, 87, 87, 87, 87, 86, 86, 86, 86, 86, 86, 86, 86, 81, 81, 81, 81, 81, 81, 81, 81, 80, 80, 80, 80, 80, 80, 80, 80, 83, 83, 83, 83, 83, 83, 83, 83, 82, 82, 82, 82, 82, 82, 82, 82, -1 +// CHECK: 85, 85, 85, 85, 85, 85, 85, 85, 84, 84, 84, 84, 84, 84, 84, 84, 87, 87, 87, 87, 87, 87, 87, 87, 86, 86, 86, 86, 86, 86, 86, 86, 81, 81, 81, 81, 81, 81, 81, 81, 80, 80, 80, 80, 80, 80, 80, 80, 83, 83, 83, 83, 83, 83, 83, 83, 82, 82, 82, 82, 82, 82, 82, 82, 102, 102, 102, 102, 102, 102, 102, 102, 103, 103, 103, 103, 103, 103, 103, 103, 100, 100, 100, 100, 100, 100, 100, 100, 101, 101, 101, 101, 101, 101, 101, 101, 98, 98, 98, 98, 98, 98, 98, 
98, 99, 99, 99, 99, 99, 99, 99, 99, 96, 96, 96, 96, 96, 96, 96, 96, 97, 97, 97, 97, 97, 97, 97, 97, -1 +// CHECK: 102, 102, 102, 102, 102, 102, 102, 102, 103, 103, 103, 103, 103, 103, 103, 103, 100, 100, 100, 100, 100, 100, 100, 100, 101, 101, 101, 101, 101, 101, 101, 101, 98, 98, 98, 98, 98, 98, 98, 98, 99, 99, 99, 99, 99, 99, 99, 99, 96, 96, 96, 96, 96, 96, 96, 96, 97, 97, 97, 97, 97, 97, 97, 97, 119, 119, 119, 119, 119, 119, 119, 119, 118, 118, 118, 118, 118, 118, 118, 118, 117, 117, 117, 117, 117, 117, 117, 117, 116, 116, 116, 116, 116, 116, 116, 116, 115, 115, 115, 115, 115, 115, 115, 115, 114, 114, 114, 114, 114, 114, 114, 114, 113, 113, 113, 113, 113, 113, 113, 113, 112, 112, 112, 112, 112, 112, 112, 112, -1 +// CHECK: 119, 119, 119, 119, 119, 119, 119, 119, 118, 118, 118, 118, 118, 118, 118, 118, 117, 117, 117, 117, 117, 117, 117, 117, 116, 116, 116, 116, 116, 116, 116, 116, 115, 115, 115, 115, 115, 115, 115, 115, 114, 114, 114, 114, 114, 114, 114, 114, 113, 113, 113, 113, 113, 113, 113, 113, 112, 112, 112, 112, 112, 112, 112, 112, 128, 128, 128, 128, 128, 128, 128, 128, 129, 129, 129, 129, 129, 129, 129, 129, 130, 130, 130, 130, 130, 130, 130, 130, 131, 131, 131, 131, 131, 131, 131, 131, 132, 132, 132, 132, 132, 132, 132, 132, 133, 133, 133, 133, 133, 133, 133, 133, 134, 134, 134, 134, 134, 134, 134, 134, 135, 135, 135, 135, 135, 135, 135, 135, -1 +// CHECK: 128, 128, 128, 128, 128, 128, 128, 128, 129, 129, 129, 129, 129, 129, 129, 129, 130, 130, 130, 130, 130, 130, 130, 130, 131, 131, 131, 131, 131, 131, 131, 131, 132, 132, 132, 132, 132, 132, 132, 132, 133, 133, 133, 133, 133, 133, 133, 133, 134, 134, 134, 134, 134, 134, 134, 134, 135, 135, 135, 135, 135, 135, 135, 135, 145, 145, 145, 145, 145, 145, 145, 145, 144, 144, 144, 144, 144, 144, 144, 144, 147, 147, 147, 147, 147, 147, 147, 147, 146, 146, 146, 146, 146, 146, 146, 146, 149, 149, 149, 149, 149, 149, 149, 149, 148, 148, 148, 148, 148, 148, 148, 148, 151, 151, 151, 151, 151, 151, 151, 151, 150, 150, 150, 150, 150, 150, 150, 150, -1 +// CHECK: 145, 145, 145, 145, 145, 145, 145, 145, 144, 144, 144, 144, 144, 144, 144, 144, 147, 147, 147, 147, 147, 147, 147, 147, 146, 146, 146, 146, 146, 146, 146, 146, 149, 149, 149, 149, 149, 149, 149, 149, 148, 148, 148, 148, 148, 148, 148, 148, 151, 151, 151, 151, 151, 151, 151, 151, 150, 150, 150, 150, 150, 150, 150, 150, 162, 162, 162, 162, 162, 162, 162, 162, 163, 163, 163, 163, 163, 163, 163, 163, 160, 160, 160, 160, 160, 160, 160, 160, 161, 161, 161, 161, 161, 161, 161, 161, 166, 166, 166, 166, 166, 166, 166, 166, 167, 167, 167, 167, 167, 167, 167, 167, 164, 164, 164, 164, 164, 164, 164, 164, 165, 165, 165, 165, 165, 165, 165, 165, -1 +// CHECK: 162, 162, 162, 162, 162, 162, 162, 162, 163, 163, 163, 163, 163, 163, 163, 163, 160, 160, 160, 160, 160, 160, 160, 160, 161, 161, 161, 161, 161, 161, 161, 161, 166, 166, 166, 166, 166, 166, 166, 166, 167, 167, 167, 167, 167, 167, 167, 167, 164, 164, 164, 164, 164, 164, 164, 164, 165, 165, 165, 165, 165, 165, 165, 165, 179, 179, 179, 179, 179, 179, 179, 179, 178, 178, 178, 178, 178, 178, 178, 178, 177, 177, 177, 177, 177, 177, 177, 177, 176, 176, 176, 176, 176, 176, 176, 176, 183, 183, 183, 183, 183, 183, 183, 183, 182, 182, 182, 182, 182, 182, 182, 182, 181, 181, 181, 181, 181, 181, 181, 181, 180, 180, 180, 180, 180, 180, 180, 180, -1 +// CHECK: 179, 179, 179, 179, 179, 179, 179, 179, 178, 178, 178, 178, 178, 178, 178, 178, 177, 177, 177, 177, 177, 177, 177, 177, 176, 176, 176, 176, 176, 176, 176, 176, 183, 183, 183, 183, 183, 183, 183, 183, 
182, 182, 182, 182, 182, 182, 182, 182, 181, 181, 181, 181, 181, 181, 181, 181, 180, 180, 180, 180, 180, 180, 180, 180, 196, 196, 196, 196, 196, 196, 196, 196, 197, 197, 197, 197, 197, 197, 197, 197, 198, 198, 198, 198, 198, 198, 198, 198, 199, 199, 199, 199, 199, 199, 199, 199, 192, 192, 192, 192, 192, 192, 192, 192, 193, 193, 193, 193, 193, 193, 193, 193, 194, 194, 194, 194, 194, 194, 194, 194, 195, 195, 195, 195, 195, 195, 195, 195, -1 +// CHECK: 196, 196, 196, 196, 196, 196, 196, 196, 197, 197, 197, 197, 197, 197, 197, 197, 198, 198, 198, 198, 198, 198, 198, 198, 199, 199, 199, 199, 199, 199, 199, 199, 192, 192, 192, 192, 192, 192, 192, 192, 193, 193, 193, 193, 193, 193, 193, 193, 194, 194, 194, 194, 194, 194, 194, 194, 195, 195, 195, 195, 195, 195, 195, 195, 213, 213, 213, 213, 213, 213, 213, 213, 212, 212, 212, 212, 212, 212, 212, 212, 215, 215, 215, 215, 215, 215, 215, 215, 214, 214, 214, 214, 214, 214, 214, 214, 209, 209, 209, 209, 209, 209, 209, 209, 208, 208, 208, 208, 208, 208, 208, 208, 211, 211, 211, 211, 211, 211, 211, 211, 210, 210, 210, 210, 210, 210, 210, 210, -1 +// CHECK: 213, 213, 213, 213, 213, 213, 213, 213, 212, 212, 212, 212, 212, 212, 212, 212, 215, 215, 215, 215, 215, 215, 215, 215, 214, 214, 214, 214, 214, 214, 214, 214, 209, 209, 209, 209, 209, 209, 209, 209, 208, 208, 208, 208, 208, 208, 208, 208, 211, 211, 211, 211, 211, 211, 211, 211, 210, 210, 210, 210, 210, 210, 210, 210, 230, 230, 230, 230, 230, 230, 230, 230, 231, 231, 231, 231, 231, 231, 231, 231, 228, 228, 228, 228, 228, 228, 228, 228, 229, 229, 229, 229, 229, 229, 229, 229, 226, 226, 226, 226, 226, 226, 226, 226, 227, 227, 227, 227, 227, 227, 227, 227, 224, 224, 224, 224, 224, 224, 224, 224, 225, 225, 225, 225, 225, 225, 225, 225, -1 +// CHECK: 230, 230, 230, 230, 230, 230, 230, 230, 231, 231, 231, 231, 231, 231, 231, 231, 228, 228, 228, 228, 228, 228, 228, 228... |
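The CHECK values above follow the 16-byte-cell permutation sketched in the swizzle diagram at the top of the test. As an illustration only (not part of the patch), here is a minimal C sketch that regenerates that cell table; it assumes the usual formulation in which the physical cell index is the logical cell index XORed with the row index modulo 8:

```c
#include <stdio.h>

/* Prints the 128B-swizzle cell table from the test's comment block.
 * Assumption (not stated in the patch): phys_cell = cell ^ (row % 8),
 * where each cell is 16 bytes and a 128-byte row holds 8 cells. */
int main(void) {
  const int cells_per_row = 8;            /* 128 bytes / 16 bytes per cell */
  for (int row = 0; row < 8; ++row) {
    for (int cell = 0; cell < cells_per_row; ++cell)
      printf("| %d ", cell ^ (row % 8));  /* swizzled cell within this row */
    printf("|\n");
  }
  return 0;
}
```

Each printed row matches a row of the diagram (which shows the first rows of the repeating pattern). Because one 16-byte cell holds eight f16 values, the swizzle shows up in the CHECK lines as runs of eight equal numbers.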
LGTM
// RUN: -convert-arith-to-llvm \
// RUN: -finalize-memref-to-llvm='use-opaque-pointers=1' \
// RUN: -convert-func-to-llvm \
// RUN: -canonicalize -cse \
This pipeline really seems messy: there seems to be a clear abuse of `canonicalize, cse`, and the overall ordering isn't clear to me (why is `convert-nvvm-to-llvm` before `convert-nvgpu-to-nvvm`? Why is `convert-scf-to-cf` duplicated?)
Also I would think that `convert-to-llvm` should cover most of the LLVM conversion. Can you please clean this up and make this a single `mlir-opt` invocation with a `--pass-pipeline`?
Thanks for bringing it to my attention. Indeed there were redundancies in the pass pipeline. I cleaned them up in #67416.
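For readers following along, the suggestion above amounts to collapsing the long chain of flags into one textual pipeline. A rough, untested sketch of that shape is shown below; it only reuses pass names that already appear in this test's RUN lines, the ordering and options are illustrative, and it is not the cleanup that actually landed in #67416:

```
mlir-opt tmaload_64_64_swizzle128b.mlir \
  -pass-pipeline='builtin.module(gpu-kernel-outlining,convert-nvgpu-to-nvvm,convert-scf-to-cf,convert-nvvm-to-llvm,convert-vector-to-llvm,convert-math-to-llvm,lower-affine,nvvm-attach-target{module=main_kernel features=+ptx80 chip=sm_90 O=3},gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-index-to-llvm{index-bitwidth=32}),expand-strided-metadata,convert-arith-to-llvm,finalize-memref-to-llvm{use-opaque-pointers=1},convert-func-to-llvm,canonicalize,cse,gpu-to-llvm,gpu-module-to-binary,reconcile-unrealized-casts)'
```

A single `-pass-pipeline` string makes the ordering explicit and removes the repeated `canonicalize,cse` invocations the reviewer pointed out.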