Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][OpenMP] Added translation for omp.teams to LLVM IR #68042

Merged
merged 5 commits into from
Oct 3, 2023

Conversation

shraiysh
Copy link
Member

@shraiysh shraiysh commented Oct 2, 2023

This patch adds translation from omp.teams operation to LLVM IR using OpenMPIRBuilder. The clauses are not handled in this patch.

shraiysh and others added 2 commits October 2, 2023 16:43
This patch adds translation from `omp.teams` operation to LLVM IR using
OpenMPIRBuilder.

The clauses are not handled in this patch.
@llvmbot
Copy link
Collaborator

llvmbot commented Oct 2, 2023

@llvm/pr-subscribers-mlir-openmp
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-flang-openmp

Changes

This patch adds translation from omp.teams operation to LLVM IR using OpenMPIRBuilder. The clauses are not handled in this patch.


Full diff: https://github.com/llvm/llvm-project/pull/68042.diff

2 Files Affected:

  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+21)
  • (added) mlir/test/Target/LLVMIR/openmp-teams.mlir (+136)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 8f7f1963b3e5a4f..b9643be40e13c01 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -661,6 +661,24 @@ convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
   return bodyGenStatus;
 }
 
+// Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
+static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) {
+  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
+  LogicalResult bodyGenStatus = success();
+  if(op.getNumTeamsLower() || op.getNumTeamsUpper() || op.getIfExpr() || op.getThreadLimit() || !op.getAllocatorsVars().empty() || op.getReductions()) {
+    return op.emitError("unhandled clauses for translation to LLVM IR");
+  }
+  auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP){
+    LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(moduleTranslation, allocaIP);
+    builder.restoreIP(codegenIP);
+    convertOmpOpRegions(op.getRegion(), "omp.teams.region", builder, moduleTranslation, bodyGenStatus);
+  };
+
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTeams(ompLoc, bodyCB));
+  return bodyGenStatus;
+}
+
 /// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
@@ -2406,6 +2424,9 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
       .Case([&](omp::SingleOp op) {
         return convertOmpSingle(op, builder, moduleTranslation);
       })
+      .Case([&](omp::TeamsOp op) {
+        return convertOmpTeams(op, builder, moduleTranslation);
+      })
       .Case([&](omp::TaskOp op) {
         return convertOmpTaskOp(op, builder, moduleTranslation);
       })
diff --git a/mlir/test/Target/LLVMIR/openmp-teams.mlir b/mlir/test/Target/LLVMIR/openmp-teams.mlir
new file mode 100644
index 000000000000000..c9005fca94a7c20
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-teams.mlir
@@ -0,0 +1,136 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+llvm.func @foo()
+
+// CHECK-LABEL: @omp_teams_simple
+// CHECK: call void {{.*}} @__kmpc_fork_teams(ptr @{{.+}}, i32 0, ptr [[wrapperfn:.+]])
+// CHECK: ret void
+llvm.func @omp_teams_simple() {
+    omp.teams {
+        llvm.call @foo() : () -> ()
+        omp.terminator
+    }
+    llvm.return
+}
+
+// CHECK: define internal void @[[outlinedfn:.+]]()
+// CHECK:   call void @foo()
+// CHECK:   ret void
+// CHECK: define void [[wrapperfn]](ptr %[[global_tid:.+]], ptr %[[bound_tid:.+]])
+// CHECK:   call void @[[outlinedfn]]
+// CHECK:   ret void
+
+// -----
+
+llvm.func @foo(i32) -> ()
+
+// CHECK-LABEL: @omp_teams_shared_simple
+// CHECK-SAME: (i32 [[arg0:%.+]])
+// CHECK: [[structArg:%.+]] = alloca { i32 }
+// CHECK: br
+// CHECK: [[gep:%.+]] = getelementptr { i32 }, ptr [[structArg]], i32 0, i32 0
+// CHECK: store i32 [[arg0]], ptr [[gep]]
+// CHECK: call void {{.+}} @__kmpc_fork_teams(ptr @{{.+}}, i32 1, ptr [[wrapperfn:.+]], ptr [[structArg]])
+// CHECK: ret void
+llvm.func @omp_teams_shared_simple(%arg0: i32) {
+    omp.teams {
+        llvm.call @foo(%arg0) : (i32) -> ()
+        omp.terminator
+    }
+    llvm.return
+}
+
+// CHECK: define internal void [[outlinedfn:@.+]](ptr [[structArg:%.+]])
+// CHECK:   [[gep:%.+]] = getelementptr { i32 }, ptr [[structArg]], i32 0, i32 0
+// CHECK:   [[loadgep:%.+]] = load i32, ptr [[gep]]
+// CHECK:   call void @foo(i32 [[loadgep]])
+// CHECK:   ret void
+// CHECK: define void [[wrapperfn]](ptr [[global_tid:.+]], ptr [[bound_tid:.+]], ptr [[structArg:.+]])
+// CHECK:   call void [[outlinedfn]](ptr [[structArg]])
+// CHECK:   ret void
+
+// -----
+
+llvm.func @my_alloca_fn() -> !llvm.ptr<i32>
+llvm.func @foo(i32, f32, !llvm.ptr<i32>, f128, !llvm.ptr<i32>, i32) -> ()
+llvm.func @bar()
+
+// CHECK-LABEL: @omp_teams_branching_shared
+// CHECK-SAME: (i1 [[condition:%.+]], i32 [[arg0:%.+]], float [[arg1:%.+]], ptr [[arg2:%.+]], fp128 [[arg3:%.+]])
+
+// Checking that the allocation for struct argument happens in the alloca block.
+// CHECK: [[structArg:%.+]] = alloca { i1, i32, float, ptr, fp128, ptr, i32 }
+// CHECK: [[allocated:%.+]] = call ptr @my_alloca_fn()
+// CHECK: [[loaded:%.+]] = load i32, ptr [[allocated]]
+// CHECK: br label
+
+// Checking that the shared values are stored properly in the struct arg.
+// CHECK: [[conditionPtr:%.+]] = getelementptr {{.+}}, ptr [[structArg]]
+// CHECK: store i1 [[condition]], ptr [[conditionPtr]]
+// CHECK: [[arg0ptr:%.+]] = getelementptr {{.+}}, ptr [[structArg]], i32 0, i32 1
+// CHECK: store i32 [[arg0]], ptr [[arg0ptr]]
+// CHECK: [[arg1ptr:%.+]] = getelementptr {{.+}}, ptr [[structArg]], i32 0, i32 2
+// CHECK: store float [[arg1]], ptr [[arg1ptr]]
+// CHECK: [[arg2ptr:%.+]] = getelementptr {{.+}}, ptr [[structArg]], i32 0, i32 3
+// CHECK: store ptr [[arg2]], ptr [[arg2ptr]]
+// CHECK: [[arg3ptr:%.+]] = getelementptr {{.+}}, ptr [[structArg]], i32 0, i32 4
+// CHECK: store fp128 [[arg3]], ptr [[arg3ptr]]
+// CHECK: [[allocatedPtr:%.+]] = getelementptr {{.+}}, ptr [[structArg]], i32 0, i32 5
+// CHECK: store ptr [[allocated]], ptr [[allocatedPtr]]
+// CHECK: [[loadedPtr:%.+]] = getelementptr {{.+}}, ptr [[structArg]], i32 0, i32 6
+// CHECK: store i32 [[loaded]], ptr [[loadedPtr]]
+
+// Runtime call.
+// CHECK: call void {{.+}} @__kmpc_fork_teams(ptr @{{.+}}, i32 1, ptr [[wrapperfn:@.+]], ptr [[structArg]])
+// CHECK: br label
+// CHECK: call void @bar()
+// CHECK: ret void
+llvm.func @omp_teams_branching_shared(%condition: i1, %arg0: i32, %arg1: f32, %arg2: !llvm.ptr<i32>, %arg3: f128) {
+    %allocated = llvm.call @my_alloca_fn(): () -> !llvm.ptr<i32>
+    %loaded = llvm.load %allocated : !llvm.ptr<i32>
+    llvm.br ^codegenBlock
+^codegenBlock:
+    omp.teams {
+        llvm.cond_br %condition, ^true_block, ^false_block
+    ^true_block:
+        llvm.call @foo(%arg0, %arg1, %arg2, %arg3, %allocated, %loaded) : (i32, f32, !llvm.ptr<i32>, f128, !llvm.ptr<i32>, i32) -> ()
+        llvm.br ^exit
+    ^false_block:
+        llvm.br ^exit
+    ^exit:
+        omp.terminator
+    }
+    llvm.call @bar() : () -> ()
+    llvm.return
+}
+
+// Check the outlined function.
+// CHECK: define internal void [[outlinedfn:@.+]](ptr [[data:%.+]])
+// CHECK:   [[conditionPtr:%.+]] = getelementptr {{.+}}, ptr [[data]]
+// CHECK:   [[condition:%.+]] = load i1, ptr [[conditionPtr]]
+// CHECK:   [[arg0ptr:%.+]] = getelementptr {{.+}}, ptr [[data]], i32 0, i32 1
+// CHECK:   [[arg0:%.+]] = load i32, ptr [[arg0ptr]]
+// CHECK:   [[arg1ptr:%.+]] = getelementptr {{.+}}, ptr [[data]], i32 0, i32 2
+// CHECK:   [[arg1:%.+]] = load float, ptr [[arg1ptr]]
+// CHECK:   [[arg2ptr:%.+]] = getelementptr {{.+}}, ptr [[data]], i32 0, i32 3
+// CHECK:   [[arg2:%.+]] = load ptr, ptr [[arg2ptr]]
+// CHECK:   [[arg3ptr:%.+]] = getelementptr {{.+}}, ptr [[data]], i32 0, i32 4
+// CHECK:   [[arg3:%.+]] = load fp128, ptr [[arg3ptr]]
+// CHECK:   [[allocatedPtr:%.+]] = getelementptr {{.+}}, ptr [[data]], i32 0, i32 5
+// CHECK:   [[allocated:%.+]] = load ptr, ptr [[allocatedPtr]]
+// CHECK:   [[loadedPtr:%.+]] = getelementptr {{.+}}, ptr [[data]], i32 0, i32 6
+// CHECK:   [[loaded:%.+]] = load i32, ptr [[loadedPtr]]
+// CHECK:   br label
+
+// CHECK:   br i1 [[condition]], label %[[true:.+]], label %[[false:.+]]
+// CHECK: [[false]]:
+// CHECK-NEXT: br label
+// CHECK: [[true]]:
+// CHECK:   call void @foo(i32 [[arg0]], float [[arg1]], ptr [[arg2]], fp128 [[arg3]], ptr [[allocated]], i32 [[loaded]])
+// CHECK-NEXT: br label
+// CHECK: ret void
+
+// Check the wrapper function
+// CHECK: define void [[wrapperfn]](ptr [[globalTID:%.+]], ptr [[boundTID:%.+]], ptr [[data:%.+]])
+// CHECK:   call void [[outlinedfn]](ptr [[data]])
+// CHECK:   ret void

@github-actions
Copy link

github-actions bot commented Oct 2, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@kiranchandramohan kiranchandramohan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Will the wrapper function stay or be removed?

llvm.func @foo()

// CHECK-LABEL: @omp_teams_simple
// CHECK: call void {{.*}} @__kmpc_fork_teams(ptr @{{.+}}, i32 0, ptr [[wrapperfn:.+]])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you convert the captured variables (eg. wrapperfn) to caps? This is to distinguish them easily from code.

@shraiysh
Copy link
Member Author

shraiysh commented Oct 3, 2023

Will the wrapper function stay or be removed?

I would like to get it removed, because that is unnecessary (I did not realize this earlier while submitting the patch). But because nobody is actively reviewing #67723, I do not want to delay progress for the construct. So, I will keep updating that PR with updates to testcases while I wait on reviews. The functions should not need any further changes there.

@shraiysh shraiysh merged commit f7d4f86 into llvm:main Oct 3, 2023
3 checks passed
@shraiysh shraiysh deleted the teams-lowering branch October 3, 2023 22:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants