From 1eccb260d6d32f2870f8056580e02dec7f7fd19f Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Fri, 12 Sep 2025 11:26:40 +0100 Subject: [PATCH] [MLIR][OpenMP] Refactor omp.target_allocmem to allow reuse, NFC This patch moves tablegen definitions that could be used for all kinds of heap allocations out of `omp.target_allocmem` and into a new `OpenMP_HeapAllocClause` that can be reused. Descriptions are updated to follow the format of most other operations and the custom verifier for `omp.target_allocmem` is removed as it only made a redundant check on its result type. --- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 53 ++++++ mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 80 ++++----- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 153 ++++++------------ mlir/test/Dialect/OpenMP/invalid.mlir | 14 ++ mlir/test/Dialect/OpenMP/ops.mlir | 24 +++ 5 files changed, 176 insertions(+), 148 deletions(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 1eda5e4bc1618..3b6ecceb1dfb3 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -20,6 +20,7 @@ #define OPENMP_CLAUSES include "mlir/Dialect/OpenMP/OpenMPOpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/SymbolInterfaces.td" //===----------------------------------------------------------------------===// @@ -547,6 +548,58 @@ class OpenMP_HasDeviceAddrClauseSkip< def OpenMP_HasDeviceAddrClause : OpenMP_HasDeviceAddrClauseSkip<>; +//===----------------------------------------------------------------------===// +// Not in the spec: Clause-like structure to hold heap allocation information. +//===----------------------------------------------------------------------===// + +class OpenMP_HeapAllocClauseSkip< + bit traits = false, bit arguments = false, bit assemblyFormat = false, + bit description = false, bit extraClassDeclaration = false + > : OpenMP_Clause { + let traits = [ + MemoryEffects<[MemAlloc]> + ]; + + let arguments = (ins + TypeAttr:$in_type, + OptionalAttr:$uniq_name, + OptionalAttr:$bindc_name, + Variadic:$typeparams, + Variadic:$shape + ); + + // The custom parser doesn't parse `uniq_name` and `bindc_name`. This is + // handled by the attr-dict, which must be present in the operation's + // `assemblyFormat`. + let reqAssemblyFormat = [{ + custom($in_type, $typeparams, type($typeparams), $shape, + type($shape)) + }]; + + let extraClassDeclaration = [{ + mlir::Type getAllocatedType() { return getInTypeAttr().getValue(); } + }]; + + let description = [{ + The `in_type` is the type of the object for which memory is being allocated. + For arrays, this can be a static or dynamic array type. + + The optional `uniq_name` is a unique name for the allocated memory. + + The optional `bindc_name` is a name used for C interoperability. + + The `typeparams` are runtime type parameters for polymorphic or + parameterized types. These are typically integer values that define aspects + of a type not fixed at compile time. + + The `shape` holds runtime shape operands for dynamic arrays. Each operand is + an integer value representing the extent of a specific dimension. + }]; +} + +def OpenMP_HeapAllocClause : OpenMP_HeapAllocClauseSkip<>; + //===----------------------------------------------------------------------===// // V5.2: [5.4.7] `inclusive` clause //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 9003fb2ef7959..8b206f58c7733 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -2128,59 +2128,45 @@ def AllocateDirOp : OpenMP_Op<"allocate_dir", [AttrSizedOperandSegments], clause // TargetAllocMemOp //===----------------------------------------------------------------------===// -def TargetAllocMemOp : OpenMP_Op<"target_allocmem", - [MemoryEffects<[MemAlloc]>, AttrSizedOperandSegments]> { +def TargetAllocMemOp : OpenMP_Op<"target_allocmem", traits = [ + AttrSizedOperandSegments + ], clauses = [ + OpenMP_HeapAllocClause + ]> { let summary = "allocate storage on an openmp device for an object of a given type"; let description = [{ - Allocates memory on the specified OpenMP device for an object of the given type. - Returns an integer value representing the device pointer to the allocated memory. - The memory is uninitialized after allocation. Operations must be paired with - `omp.target_freemem` to avoid memory leaks. - - * `$device`: The integer ID of the OpenMP device where the memory will be allocated. - * `$in_type`: The type of the object for which memory is being allocated. - For arrays, this can be a static or dynamic array type. - * `$uniq_name`: An optional unique name for the allocated memory. - * `$bindc_name`: An optional name used for C interoperability. - * `$typeparams`: Runtime type parameters for polymorphic or parameterized types. - These are typically integer values that define aspects of a type not fixed at compile time. - * `$shape`: Runtime shape operands for dynamic arrays. - Each operand is an integer value representing the extent of a specific dimension. - - ```mlir - // Allocate a static 3x3 integer vector on device 0 - %device_0 = arith.constant 0 : i32 - %ptr_static = omp.target_allocmem %device_0 : i32, vector<3x3xi32> - // ... use %ptr_static ... - omp.target_freemem %device_0, %ptr_static : i32, i64 - - // Allocate a dynamic 2D Fortran array (fir.array) on device 1 - %device_1 = arith.constant 1 : i32 - %rows = arith.constant 10 : index - %cols = arith.constant 20 : index - %ptr_dynamic = omp.target_allocmem %device_1 : i32, !fir.array, %rows, %cols : index, index - // ... use %ptr_dynamic ... - omp.target_freemem %device_1, %ptr_dynamic : i32, i64 - ``` - }]; + Allocates memory on the specified OpenMP device for an object of the given + type. Returns an integer value representing the device pointer to the + allocated memory. The memory is uninitialized after allocation. Operations + must be paired with `omp.target_freemem` to avoid memory leaks. - let arguments = (ins - Arg:$device, - TypeAttr:$in_type, - OptionalAttr:$uniq_name, - OptionalAttr:$bindc_name, - Variadic:$typeparams, - Variadic:$shape - ); - let results = (outs I64); + ```mlir + // Allocate a static 3x3 integer vector on device 0 + %device_0 = arith.constant 0 : i32 + %ptr_static = omp.target_allocmem %device_0 : i32, vector<3x3xi32> + // ... use %ptr_static ... + omp.target_freemem %device_0, %ptr_static : i32, i64 + + // Allocate a dynamic 2D Fortran array (fir.array) on device 1 + %device_1 = arith.constant 1 : i32 + %rows = arith.constant 10 : index + %cols = arith.constant 20 : index + %ptr_dynamic = omp.target_allocmem %device_1 : i32, !fir.array, %rows, %cols : index, index + // ... use %ptr_dynamic ... + omp.target_freemem %device_1, %ptr_dynamic : i32, i64 + ``` - let hasCustomAssemblyFormat = 1; - let hasVerifier = 1; + The `device` is an integer ID of the OpenMP device where the memory will be + allocated. + }] # clausesDescription; - let extraClassDeclaration = [{ - mlir::Type getAllocatedType(); - }]; + let arguments = !con((ins Arg:$device), clausesArgs); + let results = (outs I64); + + // Override inherited assembly format to include `device`. + let assemblyFormat = " $device `:` type($device) `,` " + # clausesReqAssemblyFormat # " attr-dict"; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 8640c4ba0b757..fabb1b8c173a2 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -797,6 +797,58 @@ static void printNumTasksClause(OpAsmPrinter &p, Operation *op, p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType); } +//===----------------------------------------------------------------------===// +// Parser and printer for Heap Alloc Clause +//===----------------------------------------------------------------------===// + +/// operation ::= $in_type ( `(` $typeparams `)` )? ( `,` $shape )? +static ParseResult parseHeapAllocClause( + OpAsmParser &parser, TypeAttr &inTypeAttr, + SmallVectorImpl &typeparams, + SmallVectorImpl &typeparamsTypes, + SmallVectorImpl &shape, + SmallVectorImpl &shapeTypes) { + mlir::Type inType; + if (parser.parseType(inType)) + return mlir::failure(); + inTypeAttr = TypeAttr::get(inType); + + if (!parser.parseOptionalLParen()) { + // parse the LEN params of the derived type. ( : ) + if (parser.parseOperandList(typeparams, OpAsmParser::Delimiter::None) || + parser.parseColonTypeList(typeparamsTypes) || parser.parseRParen()) + return failure(); + } + + if (!parser.parseOptionalComma()) { + // parse size to scale by, vector of n dimensions of type index + if (parser.parseOperandList(shape, OpAsmParser::Delimiter::None)) + return failure(); + + // TODO: This overrides the actual types of the operands, which might cause + // issues when they don't match. At the moment this is done in place of + // making the corresponding operand type `Variadic` because index + // types are lowered to I64 prior to LLVM IR translation. + shapeTypes.append(shape.size(), IndexType::get(parser.getContext())); + } + + return success(); +} + +static void printHeapAllocClause(OpAsmPrinter &p, Operation *op, + TypeAttr inType, ValueRange typeparams, + TypeRange typeparamsTypes, ValueRange shape, + TypeRange shapeTypes) { + p << inType; + if (!typeparams.empty()) { + p << '(' << typeparams << " : " << typeparamsTypes << ')'; + } + for (auto sh : shape) { + p << ", "; + p.printOperand(sh); + } +} + //===----------------------------------------------------------------------===// // Parsers for operations including clauses that define entry block arguments. //===----------------------------------------------------------------------===// @@ -4109,107 +4161,6 @@ LogicalResult AllocateDirOp::verify() { return success(); } -//===----------------------------------------------------------------------===// -// TargetAllocMemOp -//===----------------------------------------------------------------------===// - -mlir::Type omp::TargetAllocMemOp::getAllocatedType() { - return getInTypeAttr().getValue(); -} - -/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype, -/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )? -/// attr-dict-without-keyword -static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, - mlir::OperationState &result) { - auto &builder = parser.getBuilder(); - bool hasOperands = false; - std::int32_t typeparamsSize = 0; - - // Parse device number as a new operand - mlir::OpAsmParser::UnresolvedOperand deviceOperand; - mlir::Type deviceType; - if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType)) - return mlir::failure(); - if (parser.resolveOperand(deviceOperand, deviceType, result.operands)) - return mlir::failure(); - if (parser.parseComma()) - return mlir::failure(); - - mlir::Type intype; - if (parser.parseType(intype)) - return mlir::failure(); - result.addAttribute("in_type", mlir::TypeAttr::get(intype)); - llvm::SmallVector operands; - llvm::SmallVector typeVec; - if (!parser.parseOptionalLParen()) { - // parse the LEN params of the derived type. ( : ) - if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || - parser.parseColonTypeList(typeVec) || parser.parseRParen()) - return mlir::failure(); - typeparamsSize = operands.size(); - hasOperands = true; - } - std::int32_t shapeSize = 0; - if (!parser.parseOptionalComma()) { - // parse size to scale by, vector of n dimensions of type index - if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None)) - return mlir::failure(); - shapeSize = operands.size() - typeparamsSize; - auto idxTy = builder.getIndexType(); - for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i) - typeVec.push_back(idxTy); - hasOperands = true; - } - if (hasOperands && - parser.resolveOperands(operands, typeVec, parser.getNameLoc(), - result.operands)) - return mlir::failure(); - - mlir::Type restype = builder.getIntegerType(64); - if (!restype) { - parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype; - return mlir::failure(); - } - llvm::SmallVector segmentSizes{1, typeparamsSize, shapeSize}; - result.addAttribute("operandSegmentSizes", - builder.getDenseI32ArrayAttr(segmentSizes)); - if (parser.parseOptionalAttrDict(result.attributes) || - parser.addTypeToList(restype, result.types)) - return mlir::failure(); - return mlir::success(); -} - -mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser, - mlir::OperationState &result) { - return parseTargetAllocMemOp(parser, result); -} - -void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) { - p << " "; - p.printOperand(getDevice()); - p << " : "; - p << getDevice().getType(); - p << ", "; - p << getInType(); - if (!getTypeparams().empty()) { - p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')'; - } - for (auto sh : getShape()) { - p << ", "; - p.printOperand(sh); - } - p.printOptionalAttrDict((*this)->getAttrs(), - {"in_type", "operandSegmentSizes"}); -} - -llvm::LogicalResult omp::TargetAllocMemOp::verify() { - mlir::Type outType = getType(); - if (!mlir::dyn_cast(outType)) - return emitOpError("must be a integer type"); - return mlir::success(); -} - //===----------------------------------------------------------------------===// // WorkdistributeOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index af24d969064ab..0cc4b522db466 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -3139,3 +3139,17 @@ func.func @invalid_workdistribute() -> () { } return } + +// ----- +func.func @target_allocmem_invalid_uniq_name(%device : i32) -> () { +// expected-error @below {{op attribute 'uniq_name' failed to satisfy constraint: string attribute}} + %0 = omp.target_allocmem %device : i32, i64 {uniq_name=2} + return +} + +// ----- +func.func @target_allocmem_invalid_bindc_name(%device : i32) -> () { +// expected-error @below {{op attribute 'bindc_name' failed to satisfy constraint: string attribute}} + %0 = omp.target_allocmem %device : i32, i64 {bindc_name=2} + return +} diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index cbd863f88fd1f..9e7287178ff66 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -3321,3 +3321,27 @@ func.func @omp_workdistribute() { } return } + +// CHECK-LABEL: func.func @omp_target_allocmem( +// CHECK-SAME: %[[DEVICE:.*]]: i32, %[[X:.*]]: index, %[[Y:.*]]: index, %[[Z:.*]]: i32) { +func.func @omp_target_allocmem(%device: i32, %x: index, %y: index, %z: i32) { + // CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, i64 + %0 = omp.target_allocmem %device : i32, i64 + // CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, vector<16x16xf32> {bindc_name = "bindc", uniq_name = "uniq"} + %1 = omp.target_allocmem %device : i32, vector<16x16xf32> {uniq_name="uniq", bindc_name="bindc"} + // CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32) + %2 = omp.target_allocmem %device : i32, !llvm.ptr(%x, %y, %z : index, index, i32) + // CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, !llvm.ptr, %[[X]], %[[Y]] + %3 = omp.target_allocmem %device : i32, !llvm.ptr, %x, %y + // CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32), %[[X]], %[[Y]] + %4 = omp.target_allocmem %device : i32, !llvm.ptr(%x, %y, %z : index, index, i32), %x, %y + return +} + +// CHECK-LABEL: func.func @omp_target_freemem( +// CHECK-SAME: %[[DEVICE:.*]]: i32, %[[PTR:.*]]: i64) { +func.func @omp_target_freemem(%device : i32, %ptr : i64) { + // CHECK: omp.target_freemem %[[DEVICE]], %[[PTR]] : i32, i64 + omp.target_freemem %device, %ptr : i32, i64 + return +}