Commit a579278

[mlir][nvgpu] Fix nvgpu integration test (#154748)
Fix the nvgpu MLIR integration tests. The tests obtained their shared-memory staging buffers through memref.get_global on a zero-sized global, reshaped with memref.reinterpret_cast; this PR removes that global and instead builds the buffers as memref.view ops over gpu.dynamic_shared_memory.
1 parent c535fc9 commit a579278
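
A minimal sketch of the new pattern, assembled from the diffs below (%c0 and %c32768 are assumed to be index constants defined earlier in each test):

// Bind the kernel's dynamic shared memory as a raw byte buffer in workgroup space.
%dynsmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
// Carve typed staging buffers out of it at byte offsets: lhs at byte 0, rhs at byte 32768.
%lhsShmem = memref.view %dynsmem[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<2x128x64xf16, #gpu.address_space<workgroup>>
%rhsShmem = memref.view %dynsmem[%c32768][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<2x64x128xf16, #gpu.address_space<workgroup>>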

2 files changed: +16 -20 lines

mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir

Lines changed: 9 additions & 11 deletions
@@ -57,7 +57,6 @@
 
 func.func private @printMemrefF32(memref<*xf32>)
 
-memref.global "private" @dynamicShmem : memref<0xf16, 3> {alignment = 16 : i64}
 memref.global "private" @accShmem : memref<0xf32, 3> {alignment = 16 : i64}
 
 func.func @main() {
@@ -148,12 +147,11 @@ func.func @main() {
 %c57344 = arith.constant 57344 : index
 %c40960 = arith.constant 40960 : index
 
-%tidx = gpu.thread_id x
-%dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
-%lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<2x128x64xf16, 3>
-%rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 128], strides: [8192,128,1] : memref<0xf16, 3> to memref<4x64x128xf16,3>
-%rhsShmem = memref.subview %rhsShmem2[2, 0, 0][2, 64, 128][1, 1, 1] : memref<4x64x128xf16,3> to memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3>
+%tidx = gpu.thread_id x
 %dynsmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+%lhsShmem = memref.view %dynsmem[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<2x128x64xf16, #gpu.address_space<workgroup>>
+%rhsShmem = memref.view %dynsmem[%c32768][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<2x64x128xf16, #gpu.address_space<workgroup>>
+
 // Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
 %barrier = nvgpu.mbarrier.create -> !barrierType
 %cnd = arith.cmpi eq, %tidx, %c0 : index
@@ -202,11 +200,11 @@ func.func @main() {
 // TMA wait
 %phase_c0 = arith.constant 0 : i1
 nvgpu.mbarrier.try_wait.parity %barrier[%i], %phase_c0, %ticks : !barrierType
-%lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: ?>, 3>
-%rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: ?>, 3>
+%lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, #gpu.address_space<workgroup>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #gpu.address_space<workgroup>>
+%rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, #gpu.address_space<workgroup>> to memref<64x128xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
 // Descriptor WGMMA
-%dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<128x64xf16, strided<[64, 1], offset: ?>, 3>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, 3>>
-%dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<64x128xf16, strided<[128, 1], offset: ?>, 3>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<64x128xf16, 3>>
+%dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<128x64xf16, strided<[64, 1], offset: ?>, #gpu.address_space<workgroup>>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, 3>>
+%dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<64x128xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<64x128xf16, 3>>
 // Perform WGMMA 128x128x64
 %md = nvgpu.warpgroup.mma %dA, %dB, %mc {transposeB} : <tensor = memref<128x64xf16,3>>, <tensor = memref<64x128xf16,3>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
 scf.yield %md : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
@@ -271,7 +269,7 @@ func.func @main() {
 vector.print str "Correct Results :"
 vector.print %correctCount : i32
 vector.print str "Incorrect Results :"
-vector.print %errorCount : i32
+vector.print %errorCount : i32
 
 return
 }

mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir

Lines changed: 7 additions & 9 deletions
@@ -57,7 +57,6 @@
 
 func.func private @printMemrefF32(memref<*xf32>)
 
-memref.global "private" @dynamicShmem : memref<0xf16, 3> {alignment = 16 : i64}
 memref.global "private" @accShmem : memref<0xf32, 3> {alignment = 16 : i64}
 
 func.func @main() {
@@ -149,11 +148,10 @@ func.func @main() {
 %c40960 = arith.constant 40960 : index
 
 %tidx = gpu.thread_id x
-%dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
-%lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<2x128x64xf16, 3>
-%rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 128], strides: [8192,128,1] : memref<0xf16, 3> to memref<4x64x128xf16,3>
-%rhsShmem = memref.subview %rhsShmem2[2, 0, 0][2, 64, 128][1, 1, 1] : memref<4x64x128xf16,3> to memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3>
 %dynsmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+%lhsShmem = memref.view %dynsmem[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<2x128x64xf16, #gpu.address_space<workgroup>>
+%rhsShmem = memref.view %dynsmem[%c32768][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<2x64x128xf16, #gpu.address_space<workgroup>>
+
 // Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
 %barrier = nvgpu.mbarrier.create -> !barrierType

@@ -210,11 +208,11 @@ func.func @main() {
 // TMA wait
 %phase_c0 = arith.constant 0 : i1
 nvgpu.mbarrier.try_wait.parity %barrier[%i], %phase_c0, %ticks : !barrierType
-%lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: ?>, 3>
-%rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: ?>, 3>
+%lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, #gpu.address_space<workgroup>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #gpu.address_space<workgroup>>
+%rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, #gpu.address_space<workgroup>> to memref<64x128xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>
 // Descriptor WGMMA
-%dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<128x64xf16, strided<[64, 1], offset: ?>, 3>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, 3>>
-%dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<64x128xf16, strided<[128, 1], offset: ?>, 3>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<64x128xf16, 3>>
+%dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<128x64xf16, strided<[64, 1], offset: ?>, #gpu.address_space<workgroup>>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, 3>>
+%dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<64x128xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<64x128xf16, 3>>
 // Perform WGMMA 128x128x64
 %md = nvgpu.warpgroup.mma %dA, %dB, %mc {transposeB} : <tensor = memref<128x64xf16,3>>, <tensor = memref<64x128xf16,3>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
 scf.yield %md : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
