[mlir][nvgpu] Improve tensormap.descriptor Type Verifier #77904
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir
Author: Guray Ozen (grypp)
Changes
This PR improves the verifier for the `nvgpu.tensormap.descriptor` type. The descriptor contains the information needed for TMA, and the compile-time check enforces its restrictions, such as the last memory dimension being 128 bytes. This prevents runtime crashes. See the CUDA driver documentation for more explanation: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
Full diff: https://github.com/llvm/llvm-project/pull/77904.diff
5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 2888fed2779575..cc41e17a8f59cf 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -27,6 +27,10 @@ constexpr int kWarpSize = 32;
constexpr int kWgmmaSizeM = 64;
/// Maximum tensor dimension that TMA supports
constexpr int kMaxTMATensorDimension = 5;
+/// Maximum size of any individual dimension that TMA supports
+constexpr unsigned kMaxTMADimension = 256;
+/// The last dimension of a 2D+ TMA descriptor must be 128 bytes
+constexpr unsigned kMaxTMALastdimByte = 128;
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index b0a4ed1cc2697c..cbbdbb5c948e36 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -355,6 +355,22 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
if (!descMemref.hasStaticShape())
return op->emitError() << "the tensor map descriptor must be static shaped";
+ for (auto dim : descMemref.getShape()) {
+ if (dim <= 0 || dim > kMaxTMADimension) {
+ return op->emitError() << "the tensor map descriptor dimensions must be "
+ "in the range [1, "
+ << kMaxTMADimension << "] but got " << dim;
+ }
+ }
+ if (descMemref.getRank() > 1) {
+ unsigned lastDimensionByte =
+ descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
+ if (lastDimensionByte != kMaxTMALastdimByte)
+ return op->emitError() << "the tensormap descriptor must have last "
+ "dimension of "
+ << kMaxTMALastdimByte << " bytes but it is "
+ << lastDimensionByte << " bytes";
+ }
+
// No verification if memref type is not provided
if (!memrefType.has_value())
return std::nullopt;
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index b8a0f75d1cc8b9..3f859a2c2be884 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -753,10 +753,7 @@ func.func @tma_prefetch(%tensorMap1d: !tensorMap1d, %p : i1) {
}
!lhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
-!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16, strided<[128, 1], offset: 8192>, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
-
-!shmemlhs = memref<128x64xf16,3>
-!shmemrhs = memref<64x128xf16, strided<[128, 1], offset: 8192>, 3>
+!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
module @mymodule {
// Dynamic Shared memory
@@ -765,17 +762,17 @@ module @mymodule {
func.func @async_tma_load(%lhsTensorMap: !lhsTensorMap, %rhsTensorMap: !rhsTensorMap, %mbarrier: !barrierType) {
%c0 = arith.constant 0 : index
%dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
- %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to !shmemlhs
- %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2,64,128], strides: [8192,128,1] : memref<0xf16, 3> to memref<2x64x128xf16,3>
- %rhsShmem3 = memref.subview %rhsShmem2[1,0,0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16,3> to memref<1x64x128xf16, strided<[8192, 128, 1], offset: 8192>, 3>
- %rhsShmem = memref.subview %rhsShmem3[0,0,0][1, 64, 128][1, 1, 1] : memref<1x64x128xf16, strided<[8192, 128, 1], offset: 8192>, 3> to !shmemrhs
+ %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to memref<128x64xf16,3>
+ %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 64], strides: [4096, 64, 1] : memref<0xf16, 3> to memref<4x64x64xf16,3>
+ %rhsShmem3 = memref.subview %rhsShmem2[2, 0, 0][1, 64, 64][1, 1, 1] : memref<4x64x64xf16,3> to memref<1x64x64xf16, strided<[4096, 64, 1], offset: 8192>, 3>
+ %rhsShmem = memref.subview %rhsShmem3[0, 0, 0][1, 64, 64][1, 1, 1] : memref<1x64x64xf16, strided<[4096, 64, 1], offset: 8192>, 3> to memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>
// CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global
- nvgpu.tma.async.load %lhsTensorMap[%c0, %c0], %mbarrier[%c0] to %lhsShmem : !lhsTensorMap, !barrierType -> !shmemlhs
+ nvgpu.tma.async.load %lhsTensorMap[%c0, %c0], %mbarrier[%c0] to %lhsShmem : !lhsTensorMap, !barrierType -> memref<128x64xf16,3>
// CHECK: %[[desc:.+]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[c8192:.+]] = llvm.mlir.constant(8192 : index) : i64
// CHECK: %[[shmemOfset:.+]] = llvm.getelementptr %[[desc]][%[[c8192]]] : (!llvm.ptr<3>, i64)
// CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %[[shmemOfset]], %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}]
- nvgpu.tma.async.load %rhsTensorMap[%c0, %c0], %mbarrier[%c0] to %rhsShmem : !rhsTensorMap, !barrierType -> !shmemrhs
+ nvgpu.tma.async.load %rhsTensorMap[%c0, %c0], %mbarrier[%c0] to %rhsShmem : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>
return
}
}
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index e1949fcfad7ad6..4c070e9a0fad35 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -316,3 +316,23 @@ func.func @tma_load_4(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memr
nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer1 : !desc, !mbarrier -> memref<128xf32,3>
return
}
+
+// -----
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+func.func @tma_generate_descriptor_incorrect_last_dim(%b0 : index, %b1 : index, %mem : memref<*xf16>) {
+ // expected-error @+1 {{the tensormap descriptor must have last dimension of 128 bytes but it is 256 bytes}}
+ %descA = nvgpu.tma.create.descriptor %mem box[%b0, %b1] : memref<*xf16> -> !desc
+ return
+}
+// -----
+
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
+func.func @tma_generate_descriptor_incorrect_last_dim(%desc: !desc, %buffer2: memref<64x128xf32,3>, %mbarrier: !mbarrier) {
+ %c0 = arith.constant 0 : index
+ // expected-error @+1 {{the tensormap descriptor must have last dimension of 128 bytes but it is 512 bytes}}
+ nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<64x128xf32,3>
+ return
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
index ab6483151a63f2..5f3074cad926c9 100644
--- a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
+++ b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
@@ -3,8 +3,8 @@
// RUN: -test-transform-dialect-erase-schedule \
// RUN: | FileCheck %s
-memref.global "private" @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
-memref.global "private" @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
+memref.global "private" @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+memref.global "private" @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
// CHECK-LABEL: func.func @main()
func.func @main() {
@@ -12,26 +12,26 @@ func.func @main() {
%c128 = arith.constant 128 : index
%0 = gpu.wait async
- %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x8xf32>
- %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x128xf32>
+ %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x32xf32>
+ %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x32xf32>
- // CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x8xf32> to memref<*xf32>
+ // CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x32xf32> to memref<*xf32>
// CHECK: %[[c64:.*]] = arith.constant 64 : index
// CHECK: %[[c8:.*]] = arith.constant 8 : index
// CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box[%[[c64]], %[[c8]]]
- // CHECK-SAME: : memref<*xf32> -> <tensor = memref<64x8xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
- // CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x128xf32> to memref<*xf32>
+ // CHECK-SAME: : memref<*xf32> -> <tensor = memref<64x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+ // CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x32xf32> to memref<*xf32>
// CHECK: %[[c8_2:.*]] = arith.constant 8 : index
// CHECK: %[[c128_2:.*]] = arith.constant 128 : index
// CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box[%[[c8_2]], %[[c128_2]]]
- // CHECK-SAME: : memref<*xf32> -> <tensor = memref<8x128xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+ // CHECK-SAME: : memref<*xf32> -> <tensor = memref<8x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
// CHECK: gpu.launch
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
threads(%tx, %ty, %tz) in (%block_x = %c128, %block_y = %c1, %block_z = %c1) {
- // CHECK: %[[G1:.*]] = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
- // CHECK: %[[G2:.*]] = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
- %out = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
- %out_1 = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
+ // CHECK: %[[G1:.*]] = memref.get_global @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+ // CHECK: %[[G2:.*]] = memref.get_global @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
+ %out = memref.get_global @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+ %out_1 = memref.get_global @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
// CHECK: %[[B:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>
// CHECK: nvgpu.mbarrier.init %[[B]][%{{.*}}], %{{.*}} : <memorySpace = #gpu.address_space<workgroup>
@@ -45,15 +45,15 @@ func.func @main() {
//
// CHECK: %[[c0_7:.*]] = arith.constant 0 : index
// CHECK: nvgpu.tma.async.load %[[D1]][%[[c0_7]], %[[c0_7]]], %[[B]][%{{.*}}] to %[[G1]]
- // CHECK-SAME: : <tensor = memref<64x8xf32, #gpu.address_space<workgroup>>,
+ // CHECK-SAME: : <tensor = memref<64x32xf32, #gpu.address_space<workgroup>>,
// CHECK-SAME: swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>
- // CHECK-SAME: -> memref<64x8xf32, #gpu.address_space<workgroup>>
+ // CHECK-SAME: -> memref<64x32xf32, #gpu.address_space<workgroup>>
//
// CHECK: %[[c0_8:.*]] = arith.constant 0 : index
// CHECK: nvgpu.tma.async.load %[[D2]][%[[c0_8]], %[[c0_8]]], %[[B]][%{{.*}}] to %[[G2]]
- // CHECK-SAME: : <tensor = memref<8x128xf32, #gpu.address_space<workgroup>>,
+ // CHECK-SAME: : <tensor = memref<8x32xf32, #gpu.address_space<workgroup>>,
// CHECK-SAME: swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>
- // CHECK-SAME: -> memref<8x128xf32, #gpu.address_space<workgroup>>
+ // CHECK-SAME: -> memref<8x32xf32, #gpu.address_space<workgroup>>
//
// CHECK: %[[c6144:.*]] = arith.constant 6144 : index
// CHECK: nvgpu.mbarrier.arrive.expect_tx %[[B]][%{{.*}}], %[[c6144]] : <memorySpace = #gpu.address_space<workgroup>
@@ -67,8 +67,8 @@ func.func @main() {
// CHECK: nvgpu.mbarrier.try_wait.parity %[[B]][%{{.*}}], %[[c0_6]], %[[c10000000]] : <memorySpace = #gpu.address_space<workgroup>
/// Both copies are matched and end up in the same async group.
- linalg.copy ins(%memref: memref<64x8xf32>) outs(%out: memref<64x8xf32, #gpu.address_space<workgroup>>)
- linalg.copy ins(%memref_1: memref<8x128xf32>) outs(%out_1: memref<8x128xf32, #gpu.address_space<workgroup>>)
+ linalg.copy ins(%memref: memref<64x32xf32>) outs(%out: memref<64x32xf32, #gpu.address_space<workgroup>>)
+ linalg.copy ins(%memref_1: memref<8x32xf32>) outs(%out_1: memref<8x32xf32, #gpu.address_space<workgroup>>)
gpu.terminator
}
This PR improves the verifier for the `nvgpu.tensormap.descriptor` type. The descriptor contains the information needed for TMA, and the compile-time check enforces its restrictions, such as the last memory dimension being 128 bytes. This prevents runtime crashes. See the CUDA driver documentation for more explanation: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
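As a quick illustration of the byte computation the new check performs (a sketch, not code from this PR; the `!ok_desc` and `!bad_desc` aliases are hypothetical):

// With f16 elements, an innermost extent of 64 gives 64 * 16 / 8 = 128 bytes,
// so ops built with this descriptor pass the new verifier.
!ok_desc = !nvgpu.tensormap.descriptor<tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
// An innermost extent of 128 gives 128 * 16 / 8 = 256 bytes, so an op such as
// nvgpu.tma.create.descriptor using this type is now rejected at compile time,
// matching the invalid.mlir test above.
!bad_desc = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>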
I don't know this part that well, but this seems like a good tightening of the verification.