-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[MLIR][OpenMP] Refactor omp.target_allocmem to allow reuse, NFC #161861
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/skatrak/flang-generic-07-dealloc-point
Are you sure you want to change the base?
[MLIR][OpenMP] Refactor omp.target_allocmem to allow reuse, NFC #161861
Conversation
This patch moves tablegen definitions that could be used for all kinds of heap allocations out of `omp.target_allocmem` and into a new `OpenMP_HeapAllocClause` that can be reused. Descriptions are updated to follow the format of most other operations and the custom verifier for `omp.target_allocmem` is removed as it only made a redundant check on its result type.
@llvm/pr-subscribers-flang-openmp @llvm/pr-subscribers-mlir Author: Sergio Afonso (skatrak) ChangesThis patch moves tablegen definitions that could be used for all kinds of heap allocations out of Descriptions are updated to follow the format of most other operations and the custom verifier for Full diff: https://github.com/llvm/llvm-project/pull/161861.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 1eda5e4bc1618..3b6ecceb1dfb3 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -20,6 +20,7 @@
#define OPENMP_CLAUSES
include "mlir/Dialect/OpenMP/OpenMPOpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===//
@@ -547,6 +548,58 @@ class OpenMP_HasDeviceAddrClauseSkip<
def OpenMP_HasDeviceAddrClause : OpenMP_HasDeviceAddrClauseSkip<>;
+//===----------------------------------------------------------------------===//
+// Not in the spec: Clause-like structure to hold heap allocation information.
+//===----------------------------------------------------------------------===//
+
+class OpenMP_HeapAllocClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+ let traits = [
+ MemoryEffects<[MemAlloc<DefaultResource>]>
+ ];
+
+ let arguments = (ins
+ TypeAttr:$in_type,
+ OptionalAttr<StrAttr>:$uniq_name,
+ OptionalAttr<StrAttr>:$bindc_name,
+ Variadic<IntLikeType>:$typeparams,
+ Variadic<IntLikeType>:$shape
+ );
+
+ // The custom parser doesn't parse `uniq_name` and `bindc_name`. This is
+ // handled by the attr-dict, which must be present in the operation's
+ // `assemblyFormat`.
+ let reqAssemblyFormat = [{
+ custom<HeapAllocClause>($in_type, $typeparams, type($typeparams), $shape,
+ type($shape))
+ }];
+
+ let extraClassDeclaration = [{
+ mlir::Type getAllocatedType() { return getInTypeAttr().getValue(); }
+ }];
+
+ let description = [{
+ The `in_type` is the type of the object for which memory is being allocated.
+ For arrays, this can be a static or dynamic array type.
+
+ The optional `uniq_name` is a unique name for the allocated memory.
+
+ The optional `bindc_name` is a name used for C interoperability.
+
+ The `typeparams` are runtime type parameters for polymorphic or
+ parameterized types. These are typically integer values that define aspects
+ of a type not fixed at compile time.
+
+ The `shape` holds runtime shape operands for dynamic arrays. Each operand is
+ an integer value representing the extent of a specific dimension.
+ }];
+}
+
+def OpenMP_HeapAllocClause : OpenMP_HeapAllocClauseSkip<>;
+
//===----------------------------------------------------------------------===//
// V5.2: [5.4.7] `inclusive` clause
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 9003fb2ef7959..8b206f58c7733 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -2128,59 +2128,45 @@ def AllocateDirOp : OpenMP_Op<"allocate_dir", [AttrSizedOperandSegments], clause
// TargetAllocMemOp
//===----------------------------------------------------------------------===//
-def TargetAllocMemOp : OpenMP_Op<"target_allocmem",
- [MemoryEffects<[MemAlloc<DefaultResource>]>, AttrSizedOperandSegments]> {
+def TargetAllocMemOp : OpenMP_Op<"target_allocmem", traits = [
+ AttrSizedOperandSegments
+ ], clauses = [
+ OpenMP_HeapAllocClause
+ ]> {
let summary = "allocate storage on an openmp device for an object of a given type";
let description = [{
- Allocates memory on the specified OpenMP device for an object of the given type.
- Returns an integer value representing the device pointer to the allocated memory.
- The memory is uninitialized after allocation. Operations must be paired with
- `omp.target_freemem` to avoid memory leaks.
-
- * `$device`: The integer ID of the OpenMP device where the memory will be allocated.
- * `$in_type`: The type of the object for which memory is being allocated.
- For arrays, this can be a static or dynamic array type.
- * `$uniq_name`: An optional unique name for the allocated memory.
- * `$bindc_name`: An optional name used for C interoperability.
- * `$typeparams`: Runtime type parameters for polymorphic or parameterized types.
- These are typically integer values that define aspects of a type not fixed at compile time.
- * `$shape`: Runtime shape operands for dynamic arrays.
- Each operand is an integer value representing the extent of a specific dimension.
-
- ```mlir
- // Allocate a static 3x3 integer vector on device 0
- %device_0 = arith.constant 0 : i32
- %ptr_static = omp.target_allocmem %device_0 : i32, vector<3x3xi32>
- // ... use %ptr_static ...
- omp.target_freemem %device_0, %ptr_static : i32, i64
-
- // Allocate a dynamic 2D Fortran array (fir.array) on device 1
- %device_1 = arith.constant 1 : i32
- %rows = arith.constant 10 : index
- %cols = arith.constant 20 : index
- %ptr_dynamic = omp.target_allocmem %device_1 : i32, !fir.array<?x?xf32>, %rows, %cols : index, index
- // ... use %ptr_dynamic ...
- omp.target_freemem %device_1, %ptr_dynamic : i32, i64
- ```
- }];
+ Allocates memory on the specified OpenMP device for an object of the given
+ type. Returns an integer value representing the device pointer to the
+ allocated memory. The memory is uninitialized after allocation. Operations
+ must be paired with `omp.target_freemem` to avoid memory leaks.
- let arguments = (ins
- Arg<AnyInteger>:$device,
- TypeAttr:$in_type,
- OptionalAttr<StrAttr>:$uniq_name,
- OptionalAttr<StrAttr>:$bindc_name,
- Variadic<IntLikeType>:$typeparams,
- Variadic<IntLikeType>:$shape
- );
- let results = (outs I64);
+ ```mlir
+ // Allocate a static 3x3 integer vector on device 0
+ %device_0 = arith.constant 0 : i32
+ %ptr_static = omp.target_allocmem %device_0 : i32, vector<3x3xi32>
+ // ... use %ptr_static ...
+ omp.target_freemem %device_0, %ptr_static : i32, i64
+
+ // Allocate a dynamic 2D Fortran array (fir.array) on device 1
+ %device_1 = arith.constant 1 : i32
+ %rows = arith.constant 10 : index
+ %cols = arith.constant 20 : index
+ %ptr_dynamic = omp.target_allocmem %device_1 : i32, !fir.array<?x?xf32>, %rows, %cols : index, index
+ // ... use %ptr_dynamic ...
+ omp.target_freemem %device_1, %ptr_dynamic : i32, i64
+ ```
- let hasCustomAssemblyFormat = 1;
- let hasVerifier = 1;
+ The `device` is an integer ID of the OpenMP device where the memory will be
+ allocated.
+ }] # clausesDescription;
- let extraClassDeclaration = [{
- mlir::Type getAllocatedType();
- }];
+ let arguments = !con((ins Arg<AnyInteger>:$device), clausesArgs);
+ let results = (outs I64);
+
+ // Override inherited assembly format to include `device`.
+ let assemblyFormat = " $device `:` type($device) `,` "
+ # clausesReqAssemblyFormat # " attr-dict";
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 8640c4ba0b757..fabb1b8c173a2 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -797,6 +797,58 @@ static void printNumTasksClause(OpAsmPrinter &p, Operation *op,
p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
}
+//===----------------------------------------------------------------------===//
+// Parser and printer for Heap Alloc Clause
+//===----------------------------------------------------------------------===//
+
+/// operation ::= $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
+static ParseResult parseHeapAllocClause(
+ OpAsmParser &parser, TypeAttr &inTypeAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &typeparams,
+ SmallVectorImpl<Type> &typeparamsTypes,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &shape,
+ SmallVectorImpl<Type> &shapeTypes) {
+ mlir::Type inType;
+ if (parser.parseType(inType))
+ return mlir::failure();
+ inTypeAttr = TypeAttr::get(inType);
+
+ if (!parser.parseOptionalLParen()) {
+ // parse the LEN params of the derived type. (<params> : <types>)
+ if (parser.parseOperandList(typeparams, OpAsmParser::Delimiter::None) ||
+ parser.parseColonTypeList(typeparamsTypes) || parser.parseRParen())
+ return failure();
+ }
+
+ if (!parser.parseOptionalComma()) {
+ // parse size to scale by, vector of n dimensions of type index
+ if (parser.parseOperandList(shape, OpAsmParser::Delimiter::None))
+ return failure();
+
+ // TODO: This overrides the actual types of the operands, which might cause
+ // issues when they don't match. At the moment this is done in place of
+ // making the corresponding operand type `Variadic<Index>` because index
+ // types are lowered to I64 prior to LLVM IR translation.
+ shapeTypes.append(shape.size(), IndexType::get(parser.getContext()));
+ }
+
+ return success();
+}
+
+static void printHeapAllocClause(OpAsmPrinter &p, Operation *op,
+ TypeAttr inType, ValueRange typeparams,
+ TypeRange typeparamsTypes, ValueRange shape,
+ TypeRange shapeTypes) {
+ p << inType;
+ if (!typeparams.empty()) {
+ p << '(' << typeparams << " : " << typeparamsTypes << ')';
+ }
+ for (auto sh : shape) {
+ p << ", ";
+ p.printOperand(sh);
+ }
+}
+
//===----------------------------------------------------------------------===//
// Parsers for operations including clauses that define entry block arguments.
//===----------------------------------------------------------------------===//
@@ -4109,107 +4161,6 @@ LogicalResult AllocateDirOp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// TargetAllocMemOp
-//===----------------------------------------------------------------------===//
-
-mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
- return getInTypeAttr().getValue();
-}
-
-/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype,
-/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
-/// attr-dict-without-keyword
-static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser,
- mlir::OperationState &result) {
- auto &builder = parser.getBuilder();
- bool hasOperands = false;
- std::int32_t typeparamsSize = 0;
-
- // Parse device number as a new operand
- mlir::OpAsmParser::UnresolvedOperand deviceOperand;
- mlir::Type deviceType;
- if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
- return mlir::failure();
- if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
- return mlir::failure();
- if (parser.parseComma())
- return mlir::failure();
-
- mlir::Type intype;
- if (parser.parseType(intype))
- return mlir::failure();
- result.addAttribute("in_type", mlir::TypeAttr::get(intype));
- llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
- llvm::SmallVector<mlir::Type> typeVec;
- if (!parser.parseOptionalLParen()) {
- // parse the LEN params of the derived type. (<params> : <types>)
- if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) ||
- parser.parseColonTypeList(typeVec) || parser.parseRParen())
- return mlir::failure();
- typeparamsSize = operands.size();
- hasOperands = true;
- }
- std::int32_t shapeSize = 0;
- if (!parser.parseOptionalComma()) {
- // parse size to scale by, vector of n dimensions of type index
- if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None))
- return mlir::failure();
- shapeSize = operands.size() - typeparamsSize;
- auto idxTy = builder.getIndexType();
- for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
- typeVec.push_back(idxTy);
- hasOperands = true;
- }
- if (hasOperands &&
- parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
- result.operands))
- return mlir::failure();
-
- mlir::Type restype = builder.getIntegerType(64);
- if (!restype) {
- parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
- return mlir::failure();
- }
- llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize};
- result.addAttribute("operandSegmentSizes",
- builder.getDenseI32ArrayAttr(segmentSizes));
- if (parser.parseOptionalAttrDict(result.attributes) ||
- parser.addTypeToList(restype, result.types))
- return mlir::failure();
- return mlir::success();
-}
-
-mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser,
- mlir::OperationState &result) {
- return parseTargetAllocMemOp(parser, result);
-}
-
-void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) {
- p << " ";
- p.printOperand(getDevice());
- p << " : ";
- p << getDevice().getType();
- p << ", ";
- p << getInType();
- if (!getTypeparams().empty()) {
- p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
- }
- for (auto sh : getShape()) {
- p << ", ";
- p.printOperand(sh);
- }
- p.printOptionalAttrDict((*this)->getAttrs(),
- {"in_type", "operandSegmentSizes"});
-}
-
-llvm::LogicalResult omp::TargetAllocMemOp::verify() {
- mlir::Type outType = getType();
- if (!mlir::dyn_cast<IntegerType>(outType))
- return emitOpError("must be a integer type");
- return mlir::success();
-}
-
//===----------------------------------------------------------------------===//
// WorkdistributeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index af24d969064ab..0cc4b522db466 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -3139,3 +3139,17 @@ func.func @invalid_workdistribute() -> () {
}
return
}
+
+// -----
+func.func @target_allocmem_invalid_uniq_name(%device : i32) -> () {
+// expected-error @below {{op attribute 'uniq_name' failed to satisfy constraint: string attribute}}
+ %0 = omp.target_allocmem %device : i32, i64 {uniq_name=2}
+ return
+}
+
+// -----
+func.func @target_allocmem_invalid_bindc_name(%device : i32) -> () {
+// expected-error @below {{op attribute 'bindc_name' failed to satisfy constraint: string attribute}}
+ %0 = omp.target_allocmem %device : i32, i64 {bindc_name=2}
+ return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index cbd863f88fd1f..9e7287178ff66 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -3321,3 +3321,27 @@ func.func @omp_workdistribute() {
}
return
}
+
+// CHECK-LABEL: func.func @omp_target_allocmem(
+// CHECK-SAME: %[[DEVICE:.*]]: i32, %[[X:.*]]: index, %[[Y:.*]]: index, %[[Z:.*]]: i32) {
+func.func @omp_target_allocmem(%device: i32, %x: index, %y: index, %z: i32) {
+ // CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, i64
+ %0 = omp.target_allocmem %device : i32, i64
+ // CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, vector<16x16xf32> {bindc_name = "bindc", uniq_name = "uniq"}
+ %1 = omp.target_allocmem %device : i32, vector<16x16xf32> {uniq_name="uniq", bindc_name="bindc"}
+ // CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32)
+ %2 = omp.target_allocmem %device : i32, !llvm.ptr(%x, %y, %z : index, index, i32)
+ // CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, !llvm.ptr, %[[X]], %[[Y]]
+ %3 = omp.target_allocmem %device : i32, !llvm.ptr, %x, %y
+ // CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32), %[[X]], %[[Y]]
+ %4 = omp.target_allocmem %device : i32, !llvm.ptr(%x, %y, %z : index, index, i32), %x, %y
+ return
+}
+
+// CHECK-LABEL: func.func @omp_target_freemem(
+// CHECK-SAME: %[[DEVICE:.*]]: i32, %[[PTR:.*]]: i64) {
+func.func @omp_target_freemem(%device : i32, %ptr : i64) {
+ // CHECK: omp.target_freemem %[[DEVICE]], %[[PTR]] : i32, i64
+ omp.target_freemem %device, %ptr : i32, i64
+ return
+}
|
// The custom parser doesn't parse `uniq_name` and `bindc_name`. This is | ||
// handled by the attr-dict, which must be present in the operation's | ||
// `assemblyFormat`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the reason to not parse it in parseHeapAllocClause
?
This patch moves tablegen definitions that could be used for all kinds of heap allocations out of
omp.target_allocmem
and into a newOpenMP_HeapAllocClause
that can be reused.Descriptions are updated to follow the format of most other operations and the custom verifier for
omp.target_allocmem
is removed as it only made a redundant check on its result type.