diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index e3dc68a1b8b7d..025f102f0fbf8 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -666,7 +666,7 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; LogicalResult bodyGenStatus = success(); - if (op.getIfExpr() || !op.getAllocatorsVars().empty() || op.getReductions()) + if (!op.getAllocatorsVars().empty() || op.getReductions()) return op.emitError("unhandled clauses for translation to LLVM IR"); auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) { @@ -689,9 +689,13 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, if (Value threadLimitVar = op.getThreadLimit()) threadLimit = moduleTranslation.lookupValue(threadLimitVar); + llvm::Value *ifExpr = nullptr; + if (Value ifExprVar = op.getIfExpr()) + ifExpr = moduleTranslation.lookupValue(ifExprVar); + llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTeams( - ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit)); + ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr)); return bodyGenStatus; } diff --git a/mlir/test/Target/LLVMIR/openmp-teams.mlir b/mlir/test/Target/LLVMIR/openmp-teams.mlir index 87ef90223ed70..973d2ed395eb0 100644 --- a/mlir/test/Target/LLVMIR/openmp-teams.mlir +++ b/mlir/test/Target/LLVMIR/openmp-teams.mlir @@ -235,3 +235,54 @@ llvm.func @omp_teams_num_teams_and_thread_limit(%numTeamsLower: i32, %numTeamsUp // CHECK: define internal void [[OUTLINED_FN]](ptr {{.+}}, ptr {{.+}}) // CHECK: call void @duringTeams() // CHECK: ret void + +// ----- + +llvm.func @beforeTeams() +llvm.func @duringTeams() +llvm.func 
@afterTeams() + +// CHECK-LABEL: @teams_if +// CHECK-SAME: (i1 [[ARG:.+]]) +llvm.func @teams_if(%arg : i1) { + // CHECK-NEXT: call void @beforeTeams() + llvm.call @beforeTeams() : () -> () + // No num_teams clause is present, so when the `if` condition is true both bounds are zero, which means "implementation-defined"; when it is false the selects below force both bounds to 1. + // On seeing zero, the runtime picks its default number of teams, as the standard permits. + // The same zero-means-default convention applies to `thread_limit`, which is passed as 0 here. + // CHECK: [[NUM_TEAMS_UPPER:%.+]] = select i1 [[ARG]], i32 0, i32 1 + // CHECK: [[NUM_TEAMS_LOWER:%.+]] = select i1 [[ARG]], i32 0, i32 1 + // CHECK: call void @__kmpc_push_num_teams_51(ptr {{.+}}, i32 {{.+}}, i32 [[NUM_TEAMS_LOWER]], i32 [[NUM_TEAMS_UPPER]], i32 0) + // CHECK: call void {{.+}} @__kmpc_fork_teams({{.+}}) + omp.teams if(%arg) { + llvm.call @duringTeams() : () -> () + omp.terminator + } + // CHECK: call void @afterTeams() + llvm.call @afterTeams() : () -> () + llvm.return +} + +// ----- + +llvm.func @beforeTeams() +llvm.func @duringTeams() +llvm.func @afterTeams() + +// CHECK-LABEL: @teams_if_with_num_teams +// CHECK-SAME: (i1 [[CONDITION:.+]], i32 [[NUM_TEAMS_LOWER:.+]], i32 [[NUM_TEAMS_UPPER:.+]], i32 [[THREAD_LIMIT:.+]]) +llvm.func @teams_if_with_num_teams(%condition: i1, %numTeamsLower: i32, %numTeamsUpper: i32, %threadLimit: i32) { + // CHECK: call void @beforeTeams() + llvm.call @beforeTeams() : () -> () + // CHECK: [[NUM_TEAMS_UPPER_NEW:%.+]] = select i1 [[CONDITION]], i32 [[NUM_TEAMS_UPPER]], i32 1 + // CHECK: [[NUM_TEAMS_LOWER_NEW:%.+]] = select i1 [[CONDITION]], i32 [[NUM_TEAMS_LOWER]], i32 1 + // CHECK: call void @__kmpc_push_num_teams_51(ptr {{.+}}, i32 {{.+}}, i32 [[NUM_TEAMS_LOWER_NEW]], i32 [[NUM_TEAMS_UPPER_NEW]], i32 [[THREAD_LIMIT]]) + // CHECK: call void {{.+}} @__kmpc_fork_teams({{.+}}) + omp.teams if(%condition) num_teams(%numTeamsLower: i32 to %numTeamsUpper: i32) thread_limit(%threadLimit: i32) { + llvm.call @duringTeams() : () -> () + omp.terminator + } + // CHECK: call 
void @afterTeams() + llvm.call @afterTeams() : () -> () + llvm.return +}