-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[flang][cuda][NFC] Use NVVM barrier op with reduction #167940
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-flang-fir-hlfir Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesSimplify the lowering by using the barrier op from NVVM updated in #167036 Full diff: https://github.com/llvm/llvm-project/pull/167940.diff 2 Files Affected:
diff --git a/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp b/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
index 323d1ef78e65d..f67129dfa6730 100644
--- a/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
@@ -1080,42 +1080,39 @@ void CUDAIntrinsicLibrary::genSyncThreads(
mlir::Value
CUDAIntrinsicLibrary::genSyncThreadsAnd(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
- constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.and";
- mlir::MLIRContext *context = builder.getContext();
- mlir::Type i32 = builder.getI32Type();
- mlir::FunctionType ftype =
- mlir::FunctionType::get(context, {resultType}, {i32});
- auto funcOp = builder.createFunction(loc, funcName, ftype);
- mlir::Value arg = builder.createConvert(loc, i32, args[0]);
- return fir::CallOp::create(builder, loc, funcOp, {arg}).getResult(0);
+ mlir::Value arg = builder.createConvert(loc, builder.getI32Type(), args[0]);
+ return mlir::NVVM::BarrierOp::create(
+ builder, loc, resultType, {}, {},
+ mlir::NVVM::BarrierReductionAttr::get(
+ builder.getContext(), mlir::NVVM::BarrierReduction::AND),
+ arg)
+ .getResult(0);
}
// SYNCTHREADS_COUNT
mlir::Value
CUDAIntrinsicLibrary::genSyncThreadsCount(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
- constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.popc";
- mlir::MLIRContext *context = builder.getContext();
- mlir::Type i32 = builder.getI32Type();
- mlir::FunctionType ftype =
- mlir::FunctionType::get(context, {resultType}, {i32});
- auto funcOp = builder.createFunction(loc, funcName, ftype);
- mlir::Value arg = builder.createConvert(loc, i32, args[0]);
- return fir::CallOp::create(builder, loc, funcOp, {arg}).getResult(0);
+ mlir::Value arg = builder.createConvert(loc, builder.getI32Type(), args[0]);
+ return mlir::NVVM::BarrierOp::create(
+ builder, loc, resultType, {}, {},
+ mlir::NVVM::BarrierReductionAttr::get(
+ builder.getContext(), mlir::NVVM::BarrierReduction::POPC),
+ arg)
+ .getResult(0);
}
// SYNCTHREADS_OR
mlir::Value
CUDAIntrinsicLibrary::genSyncThreadsOr(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
- constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.or";
- mlir::MLIRContext *context = builder.getContext();
- mlir::Type i32 = builder.getI32Type();
- mlir::FunctionType ftype =
- mlir::FunctionType::get(context, {resultType}, {i32});
- auto funcOp = builder.createFunction(loc, funcName, ftype);
- mlir::Value arg = builder.createConvert(loc, i32, args[0]);
- return fir::CallOp::create(builder, loc, funcOp, {arg}).getResult(0);
+ mlir::Value arg = builder.createConvert(loc, builder.getI32Type(), args[0]);
+ return mlir::NVVM::BarrierOp::create(
+ builder, loc, resultType, {}, {},
+ mlir::NVVM::BarrierReductionAttr::get(
+ builder.getContext(), mlir::NVVM::BarrierReduction::OR),
+ arg)
+ .getResult(0);
}
// SYNCWARP
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 3a255afd59263..ef15bf8d7726d 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -103,24 +103,24 @@ end
! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
! CHECK: nvvm.barrier0
! CHECK: nvvm.bar.warp.sync %c1{{.*}} : i32
-! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.and(%c1{{.*}}) fastmath<contract> : (i32) -> i32
+! CHECK: %{{.*}} = nvvm.barrier <and> %c1{{.*}} -> i32
! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32
! CHECK: %[[CONV:.*]] = fir.convert %[[CMP]] : (i1) -> i32
-! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.and(%[[CONV]])
-! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
+! CHECK: %{{.*}} = nvvm.barrier <and> %[[CONV]] -> i32
+! CHECK: %{{.*}} = nvvm.barrier <popc> %c1{{.*}} -> i32
! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32
! CHECK: %[[CONV:.*]] = fir.convert %[[CMP]] : (i1) -> i32
-! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.popc(%[[CONV]]) fastmath<contract> : (i32) -> i32
-! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
+! CHECK: %{{.*}} = nvvm.barrier <popc> %[[CONV]] -> i32
+! CHECK: %{{.*}} = nvvm.barrier <or> %c1{{.*}} -> i32
! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32
! CHECK: %[[CONV:.*]] = fir.convert %[[CMP]] : (i1) -> i32
-! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.or(%[[CONV]]) fastmath<contract> : (i32) -> i32
+! CHECK: %{{.*}} = nvvm.barrier <or> %[[CONV]] -> i32
! CHECK: %{{.*}} = llvm.atomicrmw add %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
! CHECK: %{{.*}} = llvm.atomicrmw add %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i64
! CHECK: %{{.*}} = llvm.atomicrmw fadd %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, f32
@@ -214,9 +214,9 @@ end
! CHECK: cuf.kernel
! CHECK: nvvm.barrier0
! CHECK: nvvm.bar.warp.sync %c1{{.*}} : i32
-! CHECK: fir.call @llvm.nvvm.barrier0.and(%c1{{.*}}) fastmath<contract> : (i32) -> i32
-! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
-! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
+! CHECK: nvvm.barrier <and> %c1{{.*}} -> i32
+! CHECK: nvvm.barrier <popc> %c1{{.*}} -> i32
+! CHECK: nvvm.barrier <or> %c1{{.*}} -> i32
attributes(device) subroutine testMatch()
integer :: a, ipred, mask, v32
|
razvanlupusoru
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice improvement! Thank you!
vzakhari
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great!
Simplify the lowering by using the barrier op from NVVM updated in #167036