
Conversation

@grypp (Member) commented Sep 11, 2023

PR #65953 added a 128x64xf16 test that performs a single TMA load. This PR adds a more complex test that performs two additional TMA loads with 128B swizzling:

    TMA Load: Matrix-A[0:128][0:64]
    TMA Load: Matrix-B[0:64][0:64]
    TMA Load: Matrix-B[64:128][0:64]

The program verifies the data loaded for Matrix-B.

@llvmbot commented Sep 11, 2023

@llvm/pr-subscribers-mlir
Patch is 64.16 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/65954.diff

1 file affected:

  • (added) mlir/test/Integration/GPU/CUDA/sm90/tmaload_64_64_swizzle128b.mlir (+245)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tmaload_64_64_swizzle128b.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tmaload_64_64_swizzle128b.mlir
new file mode 100644
index 000000000000000..b3c55c87405c507
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tmaload_64_64_swizzle128b.mlir
@@ -0,0 +1,245 @@
+// RUN: mlir-opt %s --convert-nvgpu-to-nvvm \
+// RUN:         -convert-linalg-to-loops \
+// RUN:         -canonicalize -cse \
+// RUN:         -gpu-kernel-outlining \
+// RUN:         -canonicalize -cse \
+// RUN:         -convert-vector-to-scf  \
+// RUN:         -canonicalize -cse \
+// RUN:         -lower-affine \
+// RUN:         -canonicalize -cse \
+// RUN:         -convert-scf-to-cf \
+// RUN:         -canonicalize -cse \
+// RUN:         -convert-nvvm-to-llvm \
+// RUN:         -canonicalize -cse \
+// RUN:         -convert-nvgpu-to-nvvm \
+// RUN:         -canonicalize -cse \
+// RUN:         -convert-scf-to-cf  \
+// RUN:         -convert-vector-to-llvm \
+// RUN:         -canonicalize -cse \
+// RUN:         -convert-math-to-llvm \
+// RUN:         -canonicalize -cse \
+// RUN:         -lower-affine \
+// RUN:         -convert-index-to-llvm=index-bitwidth=32 \
+// RUN:         -convert-arith-to-llvm \
+// RUN:         -finalize-memref-to-llvm='use-opaque-pointers=1' \
+// RUN:         -convert-func-to-llvm \
+// RUN:         -canonicalize -cse \
+// RUN:         -expand-strided-metadata --nvvm-attach-target="module=main_kernel features=+ptx80 chip=sm_90 O=3" \
+// RUN:  | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-index-to-llvm{index-bitwidth=32},canonicalize,cse))' \
+// RUN:  | mlir-opt --gpu-to-llvm --gpu-module-to-binary -canonicalize -cse -reconcile-unrealized-casts \
+// RUN:  | mlir-translate --mlir-to-llvmir -o %t.ll 
+// RUN: clang %t.ll -O3 %mlir_cuda_runtime %mlir_runner_utils -o %t.exe
+// RUN: LD_LIBRARY_PATH=%shlibdir %t.exe | FileCheck %s
+
+// This test performs 3 TMA loads with 128B swizzling:
+//    TMA Load: Matrix-A[0:128][0:64]
+//    TMA Load: Matrix-B[0:64][0:64]
+//    TMA Load: Matrix-B[64:128][0:64]
+
+// Tests swizzling with TMA load.
+// 128B swizzle: each numbered cell is 16 bytes.
+// |-------------------------------|
+// | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 
+// | 1 | 0 | 3 | 2 | 5 | 4 | 7 | 6 |
+// | 2 | 3 | 0 | 1 | 6 | 7 | 4 | 5 |
+// | 3 | 2 | 1 | 0 | 7 | 6 | 5 | 4 | 
+// | 4 | 5 | 6 | 7 | 0 | 1 | 2 | 3 |
+// | 5 | 4 | 7 | 6 | 1 | 0 | 3 | 2 | 
+// | 6 | 7 | 4 | 5 | 2 | 3 | 0 | 1 |
+// | 7 | 6 | 5 | 4 | 3 | 2 | 1 | 0 |
+// |-------------------------------|
+// | ... pattern repeats ...       |
+// |-------------------------------|
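+// Mapping: 16B cell c of row r is stored at cell position c XOR (r mod 8),
+// so the permutation repeats with period 8 in the row dimension.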
+
+
+!barrierType = !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
+!tokenType = !nvgpu.mbarrier.token
+
+!lhs = memref<128x64xf16>
+!shmemlhs = memref<128x64xf16, 3>
+!lhsTensorMap = !nvgpu.tensormap.descriptor<tensor = !shmemlhs, swizzle = none, l2promo = none, oob = zero, interchange = none>
+
+!rhs = memref<64x128xf16>
+!shmemrhs = memref<64x128xf16, 3>
+!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interchange = none>
+
+module @mymod {
+  func.func private @printMemrefF32(memref<*xf32>)
+  memref.global "private" @bufferLhsGlobal : !shmemlhs
+  memref.global "private" @bufferRhsGlobal : !shmemrhs
+  llvm.func @printf(!llvm.ptr, ...) -> i32
+  func.func @main() {
+    %c32768 = arith.constant 32768 : index
+    %c-1_i32 = arith.constant -1 : i32
+    %c10000000 = arith.constant 10000000 : index
+    %c64 = arith.constant 64 : index
+    %c1 = arith.constant 1 : index
+    %c32 = arith.constant 32 : index
+    %c0 = arith.constant 0 : index
+    %c128 = arith.constant 128 : index
+    %c8 = arith.constant 8 : index
+    
+    // Step 1. Allocate host data and initialize it.
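+    // Values are (linear element index) / 8, so every run of 8 consecutive
+    // f16 elements (one 16B cell) shares a value, matching the cells in the
+    // swizzle diagram above.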
+    %lhs = memref.alloc() : !lhs
+    %rhs = memref.alloc() : !rhs
+    %lhs32 = memref.alloc() : memref<128x64xf32>
+    %rhs32 = memref.alloc() : memref<64x128xf32>
+    scf.for %i = %c0 to %c64 step %c1 {
+      scf.for %j = %c0 to %c128 step %c1 {
+        %v0 = arith.muli %i, %c128 : index
+        %v00 = arith.addi %v0, %j : index     
+        %v01 = arith.divui %v00, %c8 : index 
+        %v2 = arith.index_cast %v01 : index to i32
+        %vR = arith.sitofp %v2 : i32 to f16      
+        memref.store %vR, %rhs[%i, %j] : !rhs      
+        %vR32 = arith.extf %vR : f16 to f32      
+        memref.store %vR32, %rhs32[%i, %j] : memref<64x128xf32>      
+      }
+    }
+    scf.for %j = %c0 to %c128 step %c1 {
+      scf.for %i = %c0 to %c64 step %c1 {    
+        %b0 = arith.muli %j, %c64 : index
+        %b00 = arith.addi %b0, %i : index
+        %b01 = arith.divui %b00, %c8 : index
+        %v1 = arith.index_cast %b01 : index to i32
+        %vL = arith.sitofp %v1 : i32 to f16
+        memref.store %vL, %lhs[%j, %i] : !lhs
+        %vL32 = arith.extf %vL : f16 to f32
+        memref.store %vL32, %lhs32[%j, %i] : memref<128x64xf32>
+      }
+    }
+
+    // Step 2. Print on the host
+    %lhs32_unranked = memref.cast %lhs32 : memref<128x64xf32> to memref<*xf32>
+    call @printMemrefF32(%lhs32_unranked) : (memref<*xf32>) -> ()
+    %rhs32_unranked = memref.cast %rhs32 : memref<64x128xf32> to memref<*xf32>
+    call @printMemrefF32(%rhs32_unranked) : (memref<*xf32>) -> ()
+    
+    // Step 3. Copy host to device
+    %0 = gpu.wait async
+    %d_glbmem_lhs, %asyncToken = gpu.alloc async [%0] () : !lhs
+    %d_glbmem_rhs, %asyncToken_2 = gpu.alloc async [%0] () : !rhs
+    %1 = gpu.memcpy async [%0] %d_glbmem_lhs, %lhs : !lhs, !lhs
+    %2 = gpu.memcpy async [%0] %d_glbmem_rhs, %rhs : !rhs, !rhs
+    
+    // Step 4. Create TMA tensor descriptor
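+    // Matrix-A uses a single 128x64 box; Matrix-B reuses one descriptor with
+    // a 64x64 box for both of its loads below.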
+    %d_lhs_unranked = memref.cast %d_glbmem_lhs : !lhs to memref<*xf16>
+    %d_rhs_unranked = memref.cast %d_glbmem_rhs : !rhs to memref<*xf16>
+
+    %d_lhsTensorMap = nvgpu.tma.create.descriptor %d_lhs_unranked box[%c128, %c64] : memref<*xf16> -> !lhsTensorMap
+    %d_rhsTensorMap = nvgpu.tma.create.descriptor %d_rhs_unranked box[%c64, %c64] : memref<*xf16> -> !rhsTensorMap
+
+    // Step 5. Launch a GPU kernel
+    gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %c1, %arg7 = %c1, %arg8 = %c1) threads(%arg3, %arg4, %arg5) in (%arg9 = %c128, %arg10 = %c1, %arg11 = %c1) {
+      %5 = gpu.block_dim  x
+      %6 = gpu.thread_id  x
+      %lhsShmem = memref.get_global @bufferLhsGlobal : !shmemlhs
+      %rhsShmem = memref.get_global @bufferRhsGlobal : !shmemrhs
+      %rhsShmem2 = memref.subview %rhsShmem[%c32, %c0][%c32, %c128][%c1, %c1] : !shmemrhs to memref<?x?xf16, strided<[?, ?], offset: ?>, 3>
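+      // %rhsShmem2 covers the second half of the shared B buffer (rows 32-63
+      // of the 64x128 view), the destination of the second 64x64 box.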
+    
+      // Step 6. Initialize the mbarrier
+      %9 = nvgpu.mbarrier.create -> !barrierType
+      nvgpu.mbarrier.init %9, %5 : !barrierType
+      %10 = arith.cmpi eq, %6, %c0 : index
+      
+      
+      // Step 7. The first thread performs the three TMA loads
+      scf.if %10 {
+        gpu.printf "[GPU] TMA SIZE %d\0A" %c32768 : index
+        nvgpu.tma.async.load %d_lhsTensorMap[%c0, %c0], %9 to %lhsShmem : !lhsTensorMap, !barrierType -> !shmemlhs
+        nvgpu.tma.async.load %d_rhsTensorMap[%c0, %c0], %9 to %rhsShmem : !rhsTensorMap, !barrierType -> !shmemrhs
+        nvgpu.tma.async.load %d_rhsTensorMap[%c64, %c0], %9 to %rhsShmem2 : !rhsTensorMap, !barrierType -> memref<?x?xf16, strided<[?, ?], offset: ?>, 3>
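+        // Expected transaction bytes: 128x64xf16 (16384B) for A plus
+        // two 64x64xf16 loads (8192B each) for B = 32768 bytes total.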
+        nvgpu.mbarrier.arrive.expect_tx %9, %c32768 : !barrierType
+      } else {
+        nvgpu.mbarrier.arrive.expect_tx %9, %c0 : !barrierType
+      }
+
+      // Step 8. Wait until TMA is done
+      nvgpu.mbarrier.try_wait.parity %9, %c0, %c10000000 : !barrierType
+
+      // Step 9. Print the loaded Matrix-B data in its 128B-swizzled layout
+      scf.if %10 {        
+        gpu.printf "===--- Matrix B ---=== %d \n" %c-1_i32 : i32
+        scf.for %ii = %c0 to %c64 step %c1 {
+          scf.for %j = %c0 to %c128 step %c1 {
+            %rhs0 = memref.load %rhsShmem[%ii, %j] : !shmemrhs
+            %rhs032 = arith.extf %rhs0 : f16 to f32
+            gpu.printf "%.0f,   " %rhs032 : f32
+          }
+          gpu.printf "%d\n" %c-1_i32 : i32
+        }
+        gpu.printf "===----------------=== %d \n" %c-1_i32 : i32
+      }
+      gpu.terminator
+    }
+    return
+  }
+}
+
+
+// CHECK: [GPU] TMA SIZE 32768
+// CHECK: ===--- Matrix B ---=== -1 
+// CHECK: 0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1,   1,   2,   2,   2,   2,   2,   2,   2,   2,   3,   3,   3,   3,   3,   3,   3,   3,   4,   4,   4,   4,   4,   4,   4,   4,   5,   5,   5,   5,   5,   5,   5,   5,   6,   6,   6,   6,   6,   6,   6,   6,   7,   7,   7,   7,   7,   7,   7,   7,   17,   17,   17,   17,   17,   17,   17,   17,   16,   16,   16,   16,   16,   16,   16,   16,   19,   19,   19,   19,   19,   19,   19,   19,   18,   18,   18,   18,   18,   18,   18,   18,   21,   21,   21,   21,   21,   21,   21,   21,   20,   20,   20,   20,   20,   20,   20,   20,   23,   23,   23,   23,   23,   23,   23,   23,   22,   22,   22,   22,   22,   22,   22,   22,   -1
+// CHECK: 17,   17,   17,   17,   17,   17,   17,   17,   16,   16,   16,   16,   16,   16,   16,   16,   19,   19,   19,   19,   19,   19,   19,   19,   18,   18,   18,   18,   18,   18,   18,   18,   21,   21,   21,   21,   21,   21,   21,   21,   20,   20,   20,   20,   20,   20,   20,   20,   23,   23,   23,   23,   23,   23,   23,   23,   22,   22,   22,   22,   22,   22,   22,   22,   34,   34,   34,   34,   34,   34,   34,   34,   35,   35,   35,   35,   35,   35,   35,   35,   32,   32,   32,   32,   32,   32,   32,   32,   33,   33,   33,   33,   33,   33,   33,   33,   38,   38,   38,   38,   38,   38,   38,   38,   39,   39,   39,   39,   39,   39,   39,   39,   36,   36,   36,   36,   36,   36,   36,   36,   37,   37,   37,   37,   37,   37,   37,   37,   -1
+// CHECK: 34,   34,   34,   34,   34,   34,   34,   34,   35,   35,   35,   35,   35,   35,   35,   35,   32,   32,   32,   32,   32,   32,   32,   32,   33,   33,   33,   33,   33,   33,   33,   33,   38,   38,   38,   38,   38,   38,   38,   38,   39,   39,   39,   39,   39,   39,   39,   39,   36,   36,   36,   36,   36,   36,   36,   36,   37,   37,   37,   37,   37,   37,   37,   37,   51,   51,   51,   51,   51,   51,   51,   51,   50,   50,   50,   50,   50,   50,   50,   50,   49,   49,   49,   49,   49,   49,   49,   49,   48,   48,   48,   48,   48,   48,   48,   48,   55,   55,   55,   55,   55,   55,   55,   55,   54,   54,   54,   54,   54,   54,   54,   54,   53,   53,   53,   53,   53,   53,   53,   53,   52,   52,   52,   52,   52,   52,   52,   52,   -1
+// CHECK: 51,   51,   51,   51,   51,   51,   51,   51,   50,   50,   50,   50,   50,   50,   50,   50,   49,   49,   49,   49,   49,   49,   49,   49,   48,   48,   48,   48,   48,   48,   48,   48,   55,   55,   55,   55,   55,   55,   55,   55,   54,   54,   54,   54,   54,   54,   54,   54,   53,   53,   53,   53,   53,   53,   53,   53,   52,   52,   52,   52,   52,   52,   52,   52,   68,   68,   68,   68,   68,   68,   68,   68,   69,   69,   69,   69,   69,   69,   69,   69,   70,   70,   70,   70,   70,   70,   70,   70,   71,   71,   71,   71,   71,   71,   71,   71,   64,   64,   64,   64,   64,   64,   64,   64,   65,   65,   65,   65,   65,   65,   65,   65,   66,   66,   66,   66,   66,   66,   66,   66,   67,   67,   67,   67,   67,   67,   67,   67,   -1
+// CHECK: 68,   68,   68,   68,   68,   68,   68,   68,   69,   69,   69,   69,   69,   69,   69,   69,   70,   70,   70,   70,   70,   70,   70,   70,   71,   71,   71,   71,   71,   71,   71,   71,   64,   64,   64,   64,   64,   64,   64,   64,   65,   65,   65,   65,   65,   65,   65,   65,   66,   66,   66,   66,   66,   66,   66,   66,   67,   67,   67,   67,   67,   67,   67,   67,   85,   85,   85,   85,   85,   85,   85,   85,   84,   84,   84,   84,   84,   84,   84,   84,   87,   87,   87,   87,   87,   87,   87,   87,   86,   86,   86,   86,   86,   86,   86,   86,   81,   81,   81,   81,   81,   81,   81,   81,   80,   80,   80,   80,   80,   80,   80,   80,   83,   83,   83,   83,   83,   83,   83,   83,   82,   82,   82,   82,   82,   82,   82,   82,   -1
+// CHECK: 85,   85,   85,   85,   85,   85,   85,   85,   84,   84,   84,   84,   84,   84,   84,   84,   87,   87,   87,   87,   87,   87,   87,   87,   86,   86,   86,   86,   86,   86,   86,   86,   81,   81,   81,   81,   81,   81,   81,   81,   80,   80,   80,   80,   80,   80,   80,   80,   83,   83,   83,   83,   83,   83,   83,   83,   82,   82,   82,   82,   82,   82,   82,   82,   102,   102,   102,   102,   102,   102,   102,   102,   103,   103,   103,   103,   103,   103,   103,   103,   100,   100,   100,   100,   100,   100,   100,   100,   101,   101,   101,   101,   101,   101,   101,   101,   98,   98,   98,   98,   98,   98,   98,   98,   99,   99,   99,   99,   99,   99,   99,   99,   96,   96,   96,   96,   96,   96,   96,   96,   97,   97,   97,   97,   97,   97,   97,   97,   -1
+// CHECK: 102,   102,   102,   102,   102,   102,   102,   102,   103,   103,   103,   103,   103,   103,   103,   103,   100,   100,   100,   100,   100,   100,   100,   100,   101,   101,   101,   101,   101,   101,   101,   101,   98,   98,   98,   98,   98,   98,   98,   98,   99,   99,   99,   99,   99,   99,   99,   99,   96,   96,   96,   96,   96,   96,   96,   96,   97,   97,   97,   97,   97,   97,   97,   97,   119,   119,   119,   119,   119,   119,   119,   119,   118,   118,   118,   118,   118,   118,   118,   118,   117,   117,   117,   117,   117,   117,   117,   117,   116,   116,   116,   116,   116,   116,   116,   116,   115,   115,   115,   115,   115,   115,   115,   115,   114,   114,   114,   114,   114,   114,   114,   114,   113,   113,   113,   113,   113,   113,   113,   113,   112,   112,   112,   112,   112,   112,   112,   112,   -1
+// CHECK: 119,   119,   119,   119,   119,   119,   119,   119,   118,   118,   118,   118,   118,   118,   118,   118,   117,   117,   117,   117,   117,   117,   117,   117,   116,   116,   116,   116,   116,   116,   116,   116,   115,   115,   115,   115,   115,   115,   115,   115,   114,   114,   114,   114,   114,   114,   114,   114,   113,   113,   113,   113,   113,   113,   113,   113,   112,   112,   112,   112,   112,   112,   112,   112,   128,   128,   128,   128,   128,   128,   128,   128,   129,   129,   129,   129,   129,   129,   129,   129,   130,   130,   130,   130,   130,   130,   130,   130,   131,   131,   131,   131,   131,   131,   131,   131,   132,   132,   132,   132,   132,   132,   132,   132,   133,   133,   133,   133,   133,   133,   133,   133,   134,   134,   134,   134,   134,   134,   134,   134,   135,   135,   135,   135,   135,   135,   135,   135,   -1
+// CHECK: 128,   128,   128,   128,   128,   128,   128,   128,   129,   129,   129,   129,   129,   129,   129,   129,   130,   130,   130,   130,   130,   130,   130,   130,   131,   131,   131,   131,   131,   131,   131,   131,   132,   132,   132,   132,   132,   132,   132,   132,   133,   133,   133,   133,   133,   133,   133,   133,   134,   134,   134,   134,   134,   134,   134,   134,   135,   135,   135,   135,   135,   135,   135,   135,   145,   145,   145,   145,   145,   145,   145,   145,   144,   144,   144,   144,   144,   144,   144,   144,   147,   147,   147,   147,   147,   147,   147,   147,   146,   146,   146,   146,   146,   146,   146,   146,   149,   149,   149,   149,   149,   149,   149,   149,   148,   148,   148,   148,   148,   148,   148,   148,   151,   151,   151,   151,   151,   151,   151,   151,   150,   150,   150,   150,   150,   150,   150,   150,   -1
+// CHECK: 145,   145,   145,   145,   145,   145,   145,   145,   144,   144,   144,   144,   144,   144,   144,   144,   147,   147,   147,   147,   147,   147,   147,   147,   146,   146,   146,   146,   146,   146,   146,   146,   149,   149,   149,   149,   149,   149,   149,   149,   148,   148,   148,   148,   148,   148,   148,   148,   151,   151,   151,   151,   151,   151,   151,   151,   150,   150,   150,   150,   150,   150,   150,   150,   162,   162,   162,   162,   162,   162,   162,   162,   163,   163,   163,   163,   163,   163,   163,   163,   160,   160,   160,   160,   160,   160,   160,   160,   161,   161,   161,   161,   161,   161,   161,   161,   166,   166,   166,   166,   166,   166,   166,   166,   167,   167,   167,   167,   167,   167,   167,   167,   164,   164,   164,   164,   164,   164,   164,   164,   165,   165,   165,   165,   165,   165,   165,   165,   -1
+// CHECK: 162,   162,   162,   162,   162,   162,   162,   162,   163,   163,   163,   163,   163,   163,   163,   163,   160,   160,   160,   160,   160,   160,   160,   160,   161,   161,   161,   161,   161,   161,   161,   161,   166,   166,   166,   166,   166,   166,   166,   166,   167,   167,   167,   167,   167,   167,   167,   167,   164,   164,   164,   164,   164,   164,   164,   164,   165,   165,   165,   165,   165,   165,   165,   165,   179,   179,   179,   179,   179,   179,   179,   179,   178,   178,   178,   178,   178,   178,   178,   178,   177,   177,   177,   177,   177,   177,   177,   177,   176,   176,   176,   176,   176,   176,   176,   176,   183,   183,   183,   183,   183,   183,   183,   183,   182,   182,   182,   182,   182,   182,   182,   182,   181,   181,   181,   181,   181,   181,   181,   181,   180,   180,   180,   180,   180,   180,   180,   180,   -1
+// CHECK: 179,   179,   179,   179,   179,   179,   179,   179,   178,   178,   178,   178,   178,   178,   178,   178,   177,   177,   177,   177,   177,   177,   177,   177,   176,   176,   176,   176,   176,   176,   176,   176,   183,   183,   183,   183,   183,   183,   183,   183,   182,   182,   182,   182,   182,   182,   182,   182,   181,   181,   181,   181,   181,   181,   181,   181,   180,   180,   180,   180,   180,   180,   180,   180,   196,   196,   196,   196,   196,   196,   196,   196,   197,   197,   197,   197,   197,   197,   197,   197,   198,   198,   198,   198,   198,   198,   198,   198,   199,   199,   199,   199,   199,   199,   199,   199,   192,   192,   192,   192,   192,   192,   192,   192,   193,   193,   193,   193,   193,   193,   193,   193,   194,   194,   194,   194,   194,   194,   194,   194,   195,   195,   195,   195,   195,   195,   195,   195,   -1
+// CHECK: 196,   196,   196,   196,   196,   196,   196,   196,   197,   197,   197,   197,   197,   197,   197,   197,   198,   198,   198,   198,   198,   198,   198,   198,   199,   199,   199,   199,   199,   199,   199,   199,   192,   192,   192,   192,   192,   192,   192,   192,   193,   193,   193,   193,   193,   193,   193,   193,   194,   194,   194,   194,   194,   194,   194,   194,   195,   195,   195,   195,   195,   195,   195,   195,   213,   213,   213,   213,   213,   213,   213,   213,   212,   212,   212,   212,   212,   212,   212,   212,   215,   215,   215,   215,   215,   215,   215,   215,   214,   214,   214,   214,   214,   214,   214,   214,   209,   209,   209,   209,   209,   209,   209,   209,   208,   208,   208,   208,   208,   208,   208,   208,   211,   211,   211,   211,   211,   211,   211,   211,   210,   210,   210,   210,   210,   210,   210,   210,   -1
+// CHECK: 213,   213,   213,   213,   213,   213,   213,   213,   212,   212,   212,   212,   212,   212,   212,   212,   215,   215,   215,   215,   215,   215,   215,   215,   214,   214,   214,   214,   214,   214,   214,   214,   209,   209,   209,   209,   209,   209,   209,   209,   208,   208,   208,   208,   208,   208,   208,   208,   211,   211,   211,   211,   211,   211,   211,   211,   210,   210,   210,   210,   210,   210,   210,   210,   230,   230,   230,   230,   230,   230,   230,   230,   231,   231,   231,   231,   231,   231,   231,   231,   228,   228,   228,   228,   228,   228,   228,   228,   229,   229,   229,   229,   229,   229,   229,   229,   226,   226,   226,   226,   226,   226,   226,   226,   227,   227,   227,   227,   227,   227,   227,   227,   224,   224,   224,   224,   224,   224,   224,   224,   225,   225,   225,   225,   225,   225,   225,   225,   -1
+// CHECK: 230,   230,   230,   230,   230,   230,   230,   230,   231,   231,   231,   231,   231,   231,   231,   231,   228,   228,   228,   228,   228,   228,   228,   228...
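For reference, the swizzled Matrix-B pattern in the CHECK lines can be reproduced with a short host-side sketch. This is illustrative only and not part of the patch; it assumes the value scheme from the test (every value is a global 16-byte-cell index, i.e. (i*128 + j) / 8) and the cell mapping from the swizzle diagram (cell c of box row r lands at cell c ^ (r % 8)):

```python
# Sketch: reconstruct the expected contents of the 64x128 shared B buffer
# after the two swizzled 64x64 TMA loads. Not part of the test itself.
def expected_shared_b():
    out = [[0] * 128 for _ in range(64)]
    for box in range(2):                      # the two 64x64 TMA loads of B
        for r in range(64):                   # row within the box
            seg = box * 64 + r                # 128B segment index in shared memory
            row, half = divmod(seg, 2)        # two 128B segments per 256B shared row
            for c in range(8):                # 16B cell within the segment
                value = r * 16 + box * 8 + c  # == (r*128 + box*64 + c*8) // 8
                dst = c ^ (r % 8)             # 128B swizzle permutation
                for e in range(8):            # 8 f16 elements per 16B cell
                    out[row][half * 64 + dst * 8 + e] = value
    return out

print(",   ".join(str(v) for v in expected_shared_b()[0]))
```

Row 0 of the sketch prints cells 0-7 in order (box row 0, phase 0) followed by 17, 16, 19, 18, 21, 20, 23, 22 (box row 1, phase 1), matching the first CHECK row.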

@qcolombet (Collaborator) left a comment

LGTM

…ith 128b Swizzling

llvm#65953 added a `128x64xf16` test that does a single TMA load. This PR adds a more complex test that does two additional TMA loads with 128B swizzling:
```
    TMA Load: Matrix-A[0:128][0:64]
    TMA Load: Matrix-B[0:64][0:64]
    TMA Load: Matrix-B[64:128][0:64]
```

The program tests the loaded data for Matrix-B.
@grypp force-pushed the ex_tma2_swizzle_128 branch from 92426af to 2689497 on September 15, 2023 at 07:30
@grypp merged commit 9420fc4 into llvm:main on Sep 15, 2023
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
…ith 128b Swizzling (llvm#65954)

llvm#65953 added a `128x64xf16` test that does a single TMA load. This PR adds a more complex test that does two additional TMA loads with 128B swizzling:
```
    TMA Load: Matrix-A[0:128][0:64]
    TMA Load: Matrix-B[0:64][0:64]
    TMA Load: Matrix-B[64:128][0:64]
```

The program tests the loaded data for Matrix-B.
// RUN: -convert-arith-to-llvm \
// RUN: -finalize-memref-to-llvm='use-opaque-pointers=1' \
// RUN: -convert-func-to-llvm \
// RUN: -canonicalize -cse \
Collaborator
This pipeline really seems messy: there seems to be a clear abuse of "canonicalize, cse", and the overall ordering isn't clear to me (why is convert-nvvm-to-llvm before convert-nvgpu-to-nvvm? Why is convert-scf-to-cf duplicated?)

Also, I would think that convert-to-llvm should cover most of the LLVM conversion.

@grypp, can you please clean this up and make this a single mlir-opt invocation with a --pass-pipeline?
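For illustration, a single-invocation form could look roughly like the sketch below; the pass list here is abridged and hypothetical (the actual cleanup landed later in #67416):

```
mlir-opt tmaload_64_64_swizzle128b.mlir \
  --pass-pipeline='builtin.module(gpu-kernel-outlining, convert-nvgpu-to-nvvm,
      convert-scf-to-cf, gpu.module(strip-debuginfo, convert-gpu-to-nvvm),
      gpu-to-llvm, canonicalize, cse, reconcile-unrealized-casts)'
```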

@grypp (Member, Author) replied:

Thanks for bringing this to my attention. Indeed, there were redundancies in the pass pipeline; I cleaned them up in #67416.
