New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[OpenMP][mlir] Add translation for if
in omp.teams
#69404
Conversation
This patch adds translation for `if` clause on `teams` construct in OpenMP Dialect.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Shraiysh (shraiysh) ChangesThis patch adds translation for Full diff: https://github.com/llvm/llvm-project/pull/69404.diff 2 Files Affected:
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index e3dc68a1b8b7d72..025f102f0fbf809 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -666,7 +666,7 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
LogicalResult bodyGenStatus = success();
- if (op.getIfExpr() || !op.getAllocatorsVars().empty() || op.getReductions())
+ if (!op.getAllocatorsVars().empty() || op.getReductions())
return op.emitError("unhandled clauses for translation to LLVM IR");
auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
@@ -689,9 +689,13 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
if (Value threadLimitVar = op.getThreadLimit())
threadLimit = moduleTranslation.lookupValue(threadLimitVar);
+ llvm::Value *ifExpr = nullptr;
+ if (Value ifExprVar = op.getIfExpr())
+ ifExpr = moduleTranslation.lookupValue(ifExprVar);
+
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTeams(
- ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit));
+ ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr));
return bodyGenStatus;
}
diff --git a/mlir/test/Target/LLVMIR/openmp-teams.mlir b/mlir/test/Target/LLVMIR/openmp-teams.mlir
index 87ef90223ed704a..cc12dff6df1e376 100644
--- a/mlir/test/Target/LLVMIR/openmp-teams.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-teams.mlir
@@ -235,3 +235,51 @@ llvm.func @omp_teams_num_teams_and_thread_limit(%numTeamsLower: i32, %numTeamsUp
// CHECK: define internal void [[OUTLINED_FN]](ptr {{.+}}, ptr {{.+}})
// CHECK: call void @duringTeams()
// CHECK: ret void
+
+// -----
+
+llvm.func @beforeTeams()
+llvm.func @duringTeams()
+llvm.func @afterTeams()
+
+// CHECK-LABEL: @teams_if
+// CHECK-SAME: (i1 [[ARG:.+]])
+llvm.func @teams_if(%arg : i1) {
+ // CHECK-NEXT: call void @beforeTeams()
+ llvm.call @beforeTeams() : () -> ()
+ // CHECK: [[NUM_TEAMS_UPPER:%.+]] = select i1 [[ARG]], i32 0, i32 1
+ // CHECK: [[NUM_TEAMS_LOWER:%.+]] = select i1 [[ARG]], i32 0, i32 1
+ // CHECK: call void @__kmpc_push_num_teams_51(ptr {{.+}}, i32 {{.+}}, i32 [[NUM_TEAMS_LOWER]], i32 [[NUM_TEAMS_UPPER]], i32 0)
+ // CHECK: call void {{.+}} @__kmpc_fork_teams({{.+}})
+ omp.teams if(%arg) {
+ llvm.call @duringTeams() : () -> ()
+ omp.terminator
+ }
+ // CHECK: call void @afterTeams()
+ llvm.call @afterTeams() : () -> ()
+ llvm.return
+}
+
+// -----
+
+llvm.func @beforeTeams()
+llvm.func @duringTeams()
+llvm.func @afterTeams()
+
+// CHECK-LABEL: @teams_if_with_num_teams
+// CHECK-SAME: (i1 [[CONDITION:.+]], i32 [[NUM_TEAMS_LOWER:.+]], i32 [[NUM_TEAMS_UPPER:.+]], i32 [[THREAD_LIMIT:.+]])
+llvm.func @teams_if_with_num_teams(%condition: i1, %numTeamsLower: i32, %numTeamsUpper: i32, %threadLimit: i32) {
+ // CHECK: call void @beforeTeams()
+ llvm.call @beforeTeams() : () -> ()
+ // CHECK: [[NUM_TEAMS_UPPER_NEW:%.+]] = select i1 [[CONDITION]], i32 [[NUM_TEAMS_UPPER]], i32 1
+ // CHECK: [[NUM_TEAMS_LOWER_NEW:%.+]] = select i1 [[CONDITION]], i32 [[NUM_TEAMS_LOWER]], i32 1
+ // CHECK: call void @__kmpc_push_num_teams_51(ptr {{.+}}, i32 {{.+}}, i32 [[NUM_TEAMS_LOWER_NEW]], i32 [[NUM_TEAMS_UPPER_NEW]], i32 [[THREAD_LIMIT]])
+ // CHECK: call void {{.+}} @__kmpc_fork_teams({{.+}})
+ omp.teams if(%condition) num_teams(%numTeamsLower: i32 to %numTeamsUpper: i32) thread_limit(%threadLimit: i32) {
+ llvm.call @duringTeams() : () -> ()
+ omp.terminator
+ }
+ // CHECK: call void @afterTeams()
+ llvm.call @afterTeams() : () -> ()
+ llvm.return
+}
|
// CHECK: [[NUM_TEAMS_UPPER:%.+]] = select i1 [[ARG]], i32 0, i32 1 | ||
// CHECK: [[NUM_TEAMS_LOWER:%.+]] = select i1 [[ARG]], i32 0, i32 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is it 0
if ARG
is true?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Zero does not mean zero teams. The runtime call interprets zero as "not-specified". So, in that case, an "implementation-defined" number of teams are created.
When the condition is true, and there is no num_teams clause, the number of teams should be whatever they would be without the if clause - which in this case falls to the implementation defined number of teams. For LLVM OpenMP library, internally, it sets number of teams to one if zero is passed here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a comment in the test regarding this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG.
This patch adds translation for
if
clause onteams
construct in OpenMP Dialect.