|
57 | 57 |
|
58 | 58 | func.func private @printMemrefF32(memref<*xf32>)
|
59 | 59 |
|
60 |
| -memref.global "private" @dynamicShmem : memref<0xf16, 3> {alignment = 16 : i64} |
61 | 60 | memref.global "private" @accShmem : memref<0xf32, 3> {alignment = 16 : i64}
|
62 | 61 |
|
63 | 62 | func.func @main() {
|
@@ -148,12 +147,11 @@ func.func @main() {
|
148 | 147 | %c57344 = arith.constant 57344 : index
|
149 | 148 | %c40960 = arith.constant 40960 : index
|
150 | 149 |
|
151 |
| - %tidx = gpu.thread_id x |
152 |
| - %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3> |
153 |
| - %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<2x128x64xf16, 3> |
154 |
| - %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 128], strides: [8192,128,1] : memref<0xf16, 3> to memref<4x64x128xf16,3> |
155 |
| - %rhsShmem = memref.subview %rhsShmem2[2, 0, 0][2, 64, 128][1, 1, 1] : memref<4x64x128xf16,3> to memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> |
| 150 | + %tidx = gpu.thread_id x |
156 | 151 | %dynsmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
|
| 152 | + %lhsShmem = memref.view %dynsmem[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<2x128x64xf16, #gpu.address_space<workgroup>> |
| 153 | + %rhsShmem = memref.view %dynsmem[%c32768][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<2x64x128xf16, #gpu.address_space<workgroup>> |
| 154 | + |
157 | 155 | // Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
|
158 | 156 | %barrier = nvgpu.mbarrier.create -> !barrierType
|
159 | 157 | %cnd = arith.cmpi eq, %tidx, %c0 : index
|
@@ -202,11 +200,11 @@ func.func @main() {
|
202 | 200 | // TMA wait
|
203 | 201 | %phase_c0 = arith.constant 0 : i1
|
204 | 202 | nvgpu.mbarrier.try_wait.parity %barrier[%i], %phase_c0, %ticks : !barrierType
|
205 |
| - %lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: ?>, 3> |
206 |
| - %rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: ?>, 3> |
| 203 | + %lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, #gpu.address_space<workgroup>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #gpu.address_space<workgroup>> |
| 204 | + %rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, #gpu.address_space<workgroup>> to memref<64x128xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>> |
207 | 205 | // Descriptor WGMMA
|
208 |
| - %dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<128x64xf16, strided<[64, 1], offset: ?>, 3>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, 3>> |
209 |
| - %dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<64x128xf16, strided<[128, 1], offset: ?>, 3>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<64x128xf16, 3>> |
| 206 | + %dA = nvgpu.warpgroup.generate.descriptor %lhsSlice, %descA : memref<128x64xf16, strided<[64, 1], offset: ?>, #gpu.address_space<workgroup>>, !lhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, 3>> |
| 207 | + %dB = nvgpu.warpgroup.generate.descriptor %rhsSlice, %descB : memref<64x128xf16, strided<[128, 1], offset: ?>, #gpu.address_space<workgroup>>, !rhsTensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<64x128xf16, 3>> |
210 | 208 | // Perform WGMMA 128x128x64
|
211 | 209 | %md = nvgpu.warpgroup.mma %dA, %dB, %mc {transposeB} : <tensor = memref<128x64xf16,3>>, <tensor = memref<64x128xf16,3>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
|
212 | 210 | scf.yield %md : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
|
@@ -271,7 +269,7 @@ func.func @main() {
|
271 | 269 | vector.print str "Correct Results :"
|
272 | 270 | vector.print %correctCount : i32
|
273 | 271 | vector.print str "Incorrect Results :"
|
274 |
| - vector.print %errorCount : i32 |
| 272 | + vector.print %errorCount : i32 |
275 | 273 |
|
276 | 274 | return
|
277 | 275 | }
|
0 commit comments