Skip to content

Commit 3ae5f27

Browse files
authored
[ROCDL] Added LDS barrier ops to ROCDL (gfx1250) (#171810)
Added `ds.atomic.barrier.arrive.rtn.b64` and `ds.atomic.async.barrier.arrive.b64` to ROCDL. These are parts of the LDS memory barrier concept in GFX1250. Also added alias analysis to `global/flat` data prefetch ops. Extended rocdl tests.
1 parent 7a43921 commit 3ae5f27

File tree

3 files changed

+78
-4
lines changed

3 files changed

+78
-4
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,25 +1192,81 @@ def ROCDL_RawBufferAtomicCmpSwap :
11921192
// Memory prefetch intrinsics
11931193

11941194
def ROCDL_GlobalPrefetchOp :
1195-
ROCDL_IntrOp<"global.prefetch", [], [], [], 0, 0, 0, 0, [1], ["scope"]>,
1196-
Arguments<(ins Arg<LLVM_PointerInAddressSpace<1>, "", []>:$ptr, I32Attr:$scope)> {
1195+
ROCDL_IntrOp<"global.prefetch", [], [], [], 0, 0, 1, 0, [1], ["scope"]> {
1196+
dag args = (ins Arg<LLVM_PointerInAddressSpace<1>, "", [MemRead]>:$ptr,
1197+
I32Attr:$scope);
1198+
let arguments = !con(args, baseArgs);
11971199
let description = [{
11981200
Prefetches 1 byte of data per lane from global memory into the WGP-cache or L2-cache.
11991201
Available on gfx1250+.
12001202
}];
12011203
let results = (outs);
12021204
let assemblyFormat = "$ptr `,` `scope` $scope attr-dict `:` qualified(type($ptr))";
1205+
let extraClassDefinition = [{
1206+
SmallVector<Value> $cppClass::getAccessedOperands() {
1207+
return {getPtr()};
1208+
}
1209+
}];
12031210
}
12041211

12051212
def ROCDL_FlatPrefetchOp :
1206-
ROCDL_IntrOp<"flat.prefetch", [], [], [], 0, 0, 0, 0, [1], ["scope"]>,
1207-
Arguments<(ins Arg<LLVM_PointerInAddressSpace<0>, "", []>:$ptr, I32Attr:$scope)> {
1213+
ROCDL_IntrOp<"flat.prefetch", [], [], [], 0, 0, 1, 0, [1], ["scope"]> {
1214+
dag args = (ins Arg<LLVM_PointerInAddressSpace<0>, "", [MemRead]>:$ptr,
1215+
I32Attr:$scope);
1216+
let arguments = !con(args, baseArgs);
12081217
let description = [{
12091218
Prefetches 1 byte of data per lane using flat-memory addresses into the WGP-cache or L2-cache.
12101219
Available on gfx1250+.
12111220
}];
12121221
let results = (outs);
12131222
let assemblyFormat = "$ptr `,` `scope` $scope attr-dict `:` qualified(type($ptr))";
1223+
let extraClassDefinition = [{
1224+
SmallVector<Value> $cppClass::getAccessedOperands() {
1225+
return {getPtr()};
1226+
}
1227+
}];
1228+
}
1229+
1230+
//===---------------------------------------------------------------------===//
1231+
// Atomic barrier intrinsic (LDS memory barriers).
1232+
1233+
def ROCDL_DsAtomicBarrierArriveRtnOp :
1234+
ROCDL_IntrOp<"ds.atomic.barrier.arrive.rtn.b64", [], [], [], 1, 0, 1, 0, [], []> {
1235+
dag args = (ins Arg<ROCDLBufferLDS, "", [MemRead, MemWrite]>:$barrierPtr,
1236+
I64:$val);
1237+
let arguments = !con(args, baseArgs);
1238+
let description = [{
1239+
Waits on a given DS barrier and decrements its pending count by a given value. Note, the barrier state
1240+
is given as a 64-bit structure containing pending count, phase and init count. The op returns the old
1241+
barrier state. The op is executed as an ordinary LDS operations and it is ordered with other LDS operations.
1242+
Thus, check DSCNT to determine when this instruction has executed.
1243+
Available on gfx1250+.
1244+
}];
1245+
let results = (outs I64:$res);
1246+
let assemblyFormat = "$barrierPtr `,` $val attr-dict `:` qualified(type($barrierPtr)) `,` type($val) `->` type($res)";
1247+
let extraClassDefinition = [{
1248+
SmallVector<Value> $cppClass::getAccessedOperands() {
1249+
return {getBarrierPtr()};
1250+
}
1251+
}];
1252+
}
1253+
1254+
def ROCDL_DsAtomicAsyncBarrierArriveOp :
1255+
ROCDL_IntrOp<"ds.atomic.async.barrier.arrive.b64", [], [], [], 0, 0, 1, 0, [], []> {
1256+
dag args = (ins Arg<ROCDLBufferLDS, "", [MemWrite]>:$barrierPtr);
1257+
let arguments = !con(args, baseArgs);
1258+
let description = [{
1259+
Waits on a given DS barrier and decrements pending count by -1.
1260+
Stays in order with ASYNC loads to LDS, and uses ASYNCcnt to track its completion.
1261+
Available on gfx1250+.
1262+
}];
1263+
let results = (outs);
1264+
let assemblyFormat = "$barrierPtr attr-dict `:` qualified(type($barrierPtr))";
1265+
let extraClassDefinition = [{
1266+
SmallVector<Value> $cppClass::getAccessedOperands() {
1267+
return {getBarrierPtr()};
1268+
}
1269+
}];
12141270
}
12151271

12161272
//===---------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -966,6 +966,15 @@ llvm.func @rocdl.flat.prefetch(%ptr : !llvm.ptr) {
966966
llvm.return
967967
}
968968

969+
llvm.func @rocdl.atomic.barriers.arrive(%ptr : !llvm.ptr<3>, %val : i64) {
970+
// CHECK-LABEL: rocdl.atomic.barriers.arrive
971+
// CHECK: rocdl.ds.atomic.async.barrier.arrive.b64 %{{.*}} : !llvm.ptr<3>
972+
// CHECK: %{{.*}} = rocdl.ds.atomic.barrier.arrive.rtn.b64 %{{.*}}, %{{.*}} : !llvm.ptr<3>, i64 -> i64
973+
rocdl.ds.atomic.async.barrier.arrive.b64 %ptr : !llvm.ptr<3>
974+
%res = rocdl.ds.atomic.barrier.arrive.rtn.b64 %ptr, %val : !llvm.ptr<3>, i64 -> i64
975+
llvm.return
976+
}
977+
969978
// -----
970979

971980
llvm.func @rocdl.raw.buffer.f32(%rsrc : vector<4xi32>,

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1428,6 +1428,15 @@ llvm.func @rocdl.flat.prefetch(%ptr : !llvm.ptr) {
14281428
llvm.return
14291429
}
14301430

1431+
llvm.func @rocdl.atomic.barriers.arrive(%ptr : !llvm.ptr<3>, %val : i64) {
1432+
// CHECK-LABEL: rocdl.atomic.barriers.arrive
1433+
// CHECK: call void @llvm.amdgcn.ds.atomic.async.barrier.arrive.b64(ptr addrspace(3) %{{.*}})
1434+
// CHECK: %{{.*}} = call i64 @llvm.amdgcn.ds.atomic.barrier.arrive.rtn.b64(ptr addrspace(3) %{{.*}}, i64 %{{.*}})
1435+
rocdl.ds.atomic.async.barrier.arrive.b64 %ptr : !llvm.ptr<3>
1436+
%res = rocdl.ds.atomic.barrier.arrive.rtn.b64 %ptr, %val : !llvm.ptr<3>, i64 -> i64
1437+
llvm.return
1438+
}
1439+
14311440
llvm.func @rocdl.wmma.scale(%arg0: i32, %arg1: vector<4xf32>, %arg2: vector<8xi32>,
14321441
%arg3: vector<12xi32>, %arg5: vector<16xi32>,
14331442
%arg8: i64, %arg9: vector<8xf32>) -> vector<4xf32> {

0 commit comments

Comments
 (0)