
[mlir][nvgpu] Improve tensormap.descriptor Type Verifier #77904


Merged: 3 commits into llvm:main on Feb 5, 2024

Conversation

grypp (Member) commented on Jan 12, 2024

This PR improves the verifier for the nvgpu.tensormap.descriptor type. The descriptor carries the information TMA needs, and the compile-time check enforces its restrictions, such as the last memory dimension spanning 128 bytes. This prevents runtime crashes.

See the CUDA Driver API documentation for more explanation:
https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
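
For illustration, a minimal sketch of what the verifier now accepts and rejects (memref shapes taken from the tests in this PR; the !okDesc and !badDesc aliases are hypothetical):

// Accepted: with f16 elements, the last dimension spans 64 x 2 = 128 bytes.
!okDesc = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>

// Rejected when used by ops such as nvgpu.tma.create.descriptor: the last
// dimension spans 128 x 2 = 256 bytes, so verification fails with "the
// tensormap descriptor must have last dimension of 128 bytes but it is 256 bytes".
!badDesc = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16, 3>, swizzle = swizzle_32b, l2promo = none, oob = zero, interleave = none>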

llvmbot (Member) commented on Jan 12, 2024

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-nvgpu

@llvm/pr-subscribers-mlir

Author: Guray Ozen (grypp)

Changes

This PR improves the verifier for the nvgpu.tensormap.descriptor type. The descriptor carries the information TMA needs, and the compile-time check enforces its restrictions, such as the last memory dimension spanning 128 bytes. This prevents runtime crashes.

See the CUDA Driver API documentation for more explanation:
https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
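
As a worked check of the byte computation in the diff below (elementTypeBitWidth x lastDimSize / 8, with values from the new tests): memref<64x128xf16> gives 16 x 128 / 8 = 256 bytes and memref<64x128xf32> gives 32 x 128 / 8 = 512 bytes, so both are rejected, while an f16 descriptor with a last dimension of 64 elements passes, since 16 x 64 / 8 = 128 bytes.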


Full diff: https://github.com/llvm/llvm-project/pull/77904.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h (+4)
  • (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+16)
  • (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+7-10)
  • (modified) mlir/test/Dialect/NVGPU/invalid.mlir (+20)
  • (modified) mlir/test/Dialect/NVGPU/tmaload-transform.mlir (+18-18)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 2888fed2779575..cc41e17a8f59cf 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -27,6 +27,10 @@ constexpr int kWarpSize = 32;
 constexpr int kWgmmaSizeM = 64;
 /// Maximum tensor dimension that TMA supports
 constexpr int kMaxTMATensorDimension = 5;
+/// Maximum size of any dimension that TMA supports
+constexpr unsigned kMaxTMADimension = 256;
+/// Last dimension of 2D+ TMA must be 128 bytes
+constexpr unsigned kMaxTMALastdimByte = 128;
 
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index b0a4ed1cc2697c..cbbdbb5c948e36 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -355,6 +355,22 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
   if (!descMemref.hasStaticShape())
     return op->emitError() << "the tensor map descriptor must be static shaped";
 
+  for (auto dim : descMemref.getShape()) {
+    if (dim <= 0 || dim > kMaxTMADimension) {
+      return op->emitError() << "the tensor map descriptor dimensions must be "
+                                "in the range [1, "
+                             << kMaxTMADimension << "]";
+    }
+  }
+  if (descMemref.getRank() > 1) {
+    unsigned lastDimensionByte =
+        descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
+    if (lastDimensionByte != kMaxTMALastdimByte)
+      return op->emitError() << "the tensormap descriptor must have last "
+                                "dimension of "
+                             << kMaxTMALastdimByte << " bytes but it is "
+                             << lastDimensionByte << " bytes";
+  }
+
   // No verification if memref type is not provided
   if (!memrefType.has_value())
     return std::nullopt;
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index b8a0f75d1cc8b9..3f859a2c2be884 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -753,10 +753,7 @@ func.func @tma_prefetch(%tensorMap1d: !tensorMap1d, %p : i1) {
 }
 
 !lhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
-!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16, strided<[128, 1], offset: 8192>, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
-
-!shmemlhs = memref<128x64xf16,3>
-!shmemrhs = memref<64x128xf16, strided<[128, 1], offset: 8192>, 3>
+!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
 
 module @mymodule {
   // Dynamic Shared memory
@@ -765,17 +762,17 @@ module @mymodule {
   func.func @async_tma_load(%lhsTensorMap: !lhsTensorMap, %rhsTensorMap: !rhsTensorMap, %mbarrier: !barrierType) {
     %c0 = arith.constant 0 : index
     %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
-    %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to !shmemlhs
-    %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2,64,128],  strides: [8192,128,1] : memref<0xf16, 3> to memref<2x64x128xf16,3>
-    %rhsShmem3 = memref.subview %rhsShmem2[1,0,0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16,3> to memref<1x64x128xf16, strided<[8192, 128, 1], offset: 8192>, 3>
-    %rhsShmem = memref.subview %rhsShmem3[0,0,0][1, 64, 128][1, 1, 1] : memref<1x64x128xf16, strided<[8192, 128, 1], offset: 8192>, 3> to !shmemrhs
+    %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to memref<128x64xf16,3>
+    %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 64],  strides: [4096, 64, 1] : memref<0xf16, 3> to memref<4x64x64xf16,3>
+    %rhsShmem3 = memref.subview %rhsShmem2[2, 0, 0][1, 64, 64][1, 1, 1] : memref<4x64x64xf16,3> to memref<1x64x64xf16, strided<[4096, 64, 1], offset: 8192>, 3>
+    %rhsShmem = memref.subview %rhsShmem3[0, 0, 0][1, 64, 64][1, 1, 1]  : memref<1x64x64xf16, strided<[4096, 64, 1], offset: 8192>, 3> to memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>
     // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global
-    nvgpu.tma.async.load %lhsTensorMap[%c0, %c0], %mbarrier[%c0] to %lhsShmem : !lhsTensorMap, !barrierType -> !shmemlhs
+    nvgpu.tma.async.load %lhsTensorMap[%c0, %c0], %mbarrier[%c0] to %lhsShmem : !lhsTensorMap, !barrierType -> memref<128x64xf16,3>
     // CHECK: %[[desc:.+]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
     // CHECK: %[[c8192:.+]] = llvm.mlir.constant(8192 : index) : i64
     // CHECK: %[[shmemOfset:.+]] = llvm.getelementptr %[[desc]][%[[c8192]]] : (!llvm.ptr<3>, i64)
     // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %[[shmemOfset]], %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}]
-    nvgpu.tma.async.load %rhsTensorMap[%c0, %c0], %mbarrier[%c0] to %rhsShmem : !rhsTensorMap, !barrierType -> !shmemrhs
+    nvgpu.tma.async.load %rhsTensorMap[%c0, %c0], %mbarrier[%c0] to %rhsShmem : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>
     return
   }
 }
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index e1949fcfad7ad6..4c070e9a0fad35 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -316,3 +316,23 @@ func.func @tma_load_4(%desc: !desc,  %buffer1: memref<128xf32,3>, %buffer2: memr
   nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer1 : !desc, !mbarrier -> memref<128xf32,3>
   return
 }
+
+// -----
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+func.func @tma_generate_descriptor_incorrect_last_dim(%b0 : index, %b1 : index, %mem : memref<*xf16>) {
+  // expected-error @+1 {{the tensormap descriptor must have last dimension of 128 bytes but it is 256 bytes}}
+  %descA = nvgpu.tma.create.descriptor %mem box[%b0, %b1] : memref<*xf16> -> !desc
+  return
+}
+// -----
+
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
+func.func @tma_generate_descriptor_incorrect_last_dim(%desc: !desc,  %buffer2: memref<64x128xf32,3>, %mbarrier: !mbarrier) {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{the tensormap descriptor must have last dimension of 128 bytes but it is 512 bytes}}
+  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<64x128xf32,3>
+  return
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
index ab6483151a63f2..5f3074cad926c9 100644
--- a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
+++ b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
@@ -3,8 +3,8 @@
 // RUN:     -test-transform-dialect-erase-schedule \
 // RUN: | FileCheck %s
 
-memref.global "private" @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
-memref.global "private" @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
+memref.global "private" @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+memref.global "private" @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
 
 // CHECK-LABEL: func.func @main()
 func.func @main() {
@@ -12,26 +12,26 @@ func.func @main() {
   %c128 = arith.constant 128 : index
 
   %0 = gpu.wait async
-  %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x8xf32>
-  %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x128xf32>
+  %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x32xf32>
+  %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x32xf32>
 
-  //      CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x8xf32> to memref<*xf32>
+  //      CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x32xf32> to memref<*xf32>
   //      CHECK: %[[c64:.*]] = arith.constant 64 : index
   //      CHECK: %[[c8:.*]] = arith.constant 8 : index
   //      CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box[%[[c64]], %[[c8]]]
-  // CHECK-SAME:   : memref<*xf32> -> <tensor = memref<64x8xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
-  //      CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x128xf32> to memref<*xf32>
+  // CHECK-SAME:   : memref<*xf32> -> <tensor = memref<64x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+  //      CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x32xf32> to memref<*xf32>
   //      CHECK: %[[c8_2:.*]] = arith.constant 8 : index
   //      CHECK: %[[c128_2:.*]] = arith.constant 128 : index
   //      CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box[%[[c8_2]], %[[c128_2]]]
-  // CHECK-SAME:   : memref<*xf32> -> <tensor = memref<8x128xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+  // CHECK-SAME:   : memref<*xf32> -> <tensor = memref<8x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
   // CHECK: gpu.launch
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
             threads(%tx, %ty, %tz) in (%block_x = %c128, %block_y = %c1, %block_z = %c1) {
-    //      CHECK: %[[G1:.*]] = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
-    //      CHECK: %[[G2:.*]] = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
-    %out = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
-    %out_1 = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
+    //      CHECK: %[[G1:.*]] = memref.get_global @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+    //      CHECK: %[[G2:.*]] = memref.get_global @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
+    %out = memref.get_global @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+    %out_1 = memref.get_global @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
 
     //      CHECK: %[[B:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>
     //      CHECK: nvgpu.mbarrier.init %[[B]][%{{.*}}], %{{.*}} : <memorySpace = #gpu.address_space<workgroup>
@@ -45,15 +45,15 @@ func.func @main() {
     //
     //      CHECK:   %[[c0_7:.*]] = arith.constant 0 : index
     //      CHECK:   nvgpu.tma.async.load %[[D1]][%[[c0_7]], %[[c0_7]]], %[[B]][%{{.*}}] to %[[G1]]
-    // CHECK-SAME:     : <tensor = memref<64x8xf32, #gpu.address_space<workgroup>>,
+    // CHECK-SAME:     : <tensor = memref<64x32xf32, #gpu.address_space<workgroup>>,
     // CHECK-SAME:        swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>
-    // CHECK-SAME:     -> memref<64x8xf32, #gpu.address_space<workgroup>>
+    // CHECK-SAME:     -> memref<64x32xf32, #gpu.address_space<workgroup>>
     //
     //      CHECK:   %[[c0_8:.*]] = arith.constant 0 : index
     //      CHECK:   nvgpu.tma.async.load %[[D2]][%[[c0_8]], %[[c0_8]]], %[[B]][%{{.*}}] to %[[G2]]
-    // CHECK-SAME:     : <tensor = memref<8x128xf32, #gpu.address_space<workgroup>>,
+    // CHECK-SAME:     : <tensor = memref<8x32xf32, #gpu.address_space<workgroup>>,
     // CHECK-SAME:         swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>
-    // CHECK-SAME:    -> memref<8x128xf32, #gpu.address_space<workgroup>>
+    // CHECK-SAME:    -> memref<8x32xf32, #gpu.address_space<workgroup>>
     //
     //      CHECK:   %[[c6144:.*]] = arith.constant 6144 : index
     //      CHECK:   nvgpu.mbarrier.arrive.expect_tx %[[B]][%{{.*}}], %[[c6144]] : <memorySpace = #gpu.address_space<workgroup>
@@ -67,8 +67,8 @@ func.func @main() {
     //      CHECK: nvgpu.mbarrier.try_wait.parity %[[B]][%{{.*}}], %[[c0_6]], %[[c10000000]] : <memorySpace = #gpu.address_space<workgroup>
 
     /// Both copies are matched and end up in the same async group.
-    linalg.copy ins(%memref: memref<64x8xf32>) outs(%out: memref<64x8xf32, #gpu.address_space<workgroup>>)
-    linalg.copy ins(%memref_1: memref<8x128xf32>) outs(%out_1: memref<8x128xf32, #gpu.address_space<workgroup>>)
+    linalg.copy ins(%memref: memref<64x32xf32>) outs(%out: memref<64x32xf32, #gpu.address_space<workgroup>>)
+    linalg.copy ins(%memref_1: memref<8x32xf32>) outs(%out_1: memref<8x32xf32, #gpu.address_space<workgroup>>)
 
     gpu.terminator
   }

grypp added 2 commits January 19, 2024 13:25
This PR improves the verifier for the `nvgpu.tensormap.descriptor` type. The descriptor carries the information TMA needs, and the compile-time check enforces its restrictions, such as the last memory dimension spanning 128 bytes. This prevents runtime crashes.

See the CUDA Driver API documentation for more explanation:
https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
grypp force-pushed the verify-tma-descriptor branch from 3612abb to 51b47fa on January 19, 2024
jpienaar (Member) left a comment

I don't know this part that well, but this seems like a good tightening of the verification.

grypp merged commit f33a0a4 into llvm:main on Feb 5, 2024
grypp deleted the verify-tma-descriptor branch on February 5, 2024
agozillon pushed a commit to agozillon/llvm-project that referenced this pull request Feb 5, 2024
This PR improves the verifier for the `nvgpu.tensormap.descriptor` type.
The descriptor carries the information TMA needs, and the compile-time
check enforces its restrictions, such as the last memory dimension
spanning 128 bytes. This prevents runtime crashes.

See the CUDA Driver API documentation for more explanation:

https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7