diff --git a/mlir/docs/Dialects/LLVM.md b/mlir/docs/Dialects/LLVM.md index 376b8b6e1bb681..d5d46b181bc65d 100644 --- a/mlir/docs/Dialects/LLVM.md +++ b/mlir/docs/Dialects/LLVM.md @@ -194,8 +194,8 @@ objects, the creation and manipulation of LLVM dialect types is thread-safe. MLIR does not support module-scoped named type declarations, e.g. `%s = type {i32, i32}` in LLVM IR. Instead, types must be fully specified at each use, except for recursive types where only the first reference to a named type needs -to be fully specified. MLIR [type aliases](../LangRef.md/#type-aliases) can be used -to achieve more compact syntax. +to be fully specified. MLIR [type aliases](../LangRef.md/#type-aliases) can be +used to achieve more compact syntax. The general syntax of LLVM dialect types is `!llvm.`, followed by a type kind identifier (e.g., `ptr` for pointer or `struct` for structure) and by an @@ -257,17 +257,24 @@ the element type, which can be either compatible built-in or LLVM dialect types. Pointer types specify an address in memory. -Pointer types are parametric types parameterized by the element type and the -address space. The address space is an integer, but this choice may be -reconsidered if MLIR implements named address spaces. Their syntax is as -follows: +Both opaque and type-parameterized pointer types are supported. +[Opaque pointers](https://llvm.org/docs/OpaquePointers.html) do not indicate the +type of the data pointed to, and are intended to simplify LLVM IR by encoding +behavior relevant to the pointee type into operations rather than into types. +Non-opaque pointer types carry the pointee type as a type parameter. Both kinds +of pointers may be additionally parameterized by an address space. The address +space is an integer, but this choice may be reconsidered if MLIR implements +named address spaces. The syntax of pointer types is as follows: ``` - llvm-ptr-type ::= `!llvm.ptr<` type (`,` integer-literal)? `>` + llvm-ptr-type ::= `!llvm.ptr` (`<` integer-literal `>`)? + | `!llvm.ptr<` type (`,` integer-literal)? `>` ``` -where the optional integer literal corresponds to the memory space. Both cases -are represented by `LLVMPointerType` internally. +where the former case is the opaque pointer type and the latter case is the +non-opaque pointer type; the optional group containing the integer literal +corresponds to the memory space. All cases are represented by `LLVMPointerType` +internally. #### Array Types diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td index a3554f79f253ed..38a163435e5a9f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -97,10 +97,11 @@ def LLVM_AnyPointer : Type()">, // on the element type. class LLVM_PointerTo : Type< And<[LLVM_AnyPointer.predicate, - SubstLeaves< - "$_self", - "$_self.cast<::mlir::LLVM::LLVMPointerType>().getElementType()", - pointee.predicate>]>, + Or<[CPred<"$_self.cast<::mlir::LLVM::LLVMPointerType>().isOpaque()">, + SubstLeaves< + "$_self", + "$_self.cast<::mlir::LLVM::LLVMPointerType>().getElementType()", + pointee.predicate>]>]>, "LLVM pointer to " # pointee.summary>; // Type constraints accepting LLVM pointer type to integer of a specific width. diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index e5b0fc577a63dd..ff1b57f52085ae 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -293,12 +293,16 @@ class MemoryOpWithAlignmentAndAttributes : MemoryOpWithAlignmentBase { // Memory-related operations. def LLVM_AllocaOp : LLVM_Op<"alloca">, MemoryOpWithAlignmentBase { let arguments = (ins AnyInteger:$arraySize, - OptionalAttr:$alignment); - let results = (outs Res]>:$res); + OptionalAttr:$alignment, + OptionalAttr:$elem_type); + let results = (outs Res]>:$res); string llvmBuilder = [{ auto addrSpace = $_resultType->getPointerAddressSpace(); - auto *inst = builder.CreateAlloca( - $_resultType->getPointerElementType(), addrSpace, $arraySize); + llvm::Type *elementType = moduleTranslation.convertType( + $elem_type ? *$elem_type + : op.getType().cast().getElementType()); + auto *inst = builder.CreateAlloca(elementType, addrSpace, $arraySize); }] # setAlignmentCode # [{ $res = inst; }]; @@ -307,17 +311,20 @@ def LLVM_AllocaOp : LLVM_Op<"alloca">, MemoryOpWithAlignmentBase { "unsigned":$alignment), [{ if (alignment == 0) - return build($_builder, $_state, resultType, arraySize, IntegerAttr()); + return build($_builder, $_state, resultType, arraySize, IntegerAttr(), + TypeAttr()); build($_builder, $_state, resultType, arraySize, - $_builder.getI64IntegerAttr(alignment)); + $_builder.getI64IntegerAttr(alignment), TypeAttr()); }]>]; let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; } def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]> { let arguments = (ins LLVM_ScalarOrVectorOf:$base, Variadic>:$indices, - I32ElementsAttr:$structIndices); + I32ElementsAttr:$structIndices, + OptionalAttr:$elem_type); let results = (outs LLVM_ScalarOrVectorOf:$res); let skipDefaultBuilders = 1; let builders = [ @@ -337,16 +344,18 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]> { else indices.push_back(builder.getInt32(structIndex)); } - $res = builder.CreateGEP( - $base->getType()->getPointerElementType(), $base, indices); + Type baseElementType = op.getSourceElementType(); + llvm::Type *elementType = moduleTranslation.convertType(baseElementType); + $res = builder.CreateGEP(elementType, $base, indices); }]; let assemblyFormat = [{ $base `[` custom($indices, $structIndices) `]` attr-dict - `:` functional-type(operands, results) + `:` functional-type(operands, results) (`,` $elem_type^)? }]; let extraClassDeclaration = [{ constexpr static int kDynamicIndex = std::numeric_limits::min(); + Type getSourceElementType(); }]; let hasFolder = 1; let hasVerifier = 1; @@ -361,8 +370,7 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes { UnitAttr:$nontemporal); let results = (outs LLVM_LoadableType:$res); string llvmBuilder = [{ - auto *inst = builder.CreateLoad( - $addr->getType()->getPointerElementType(), $addr, $volatile_); + auto *inst = builder.CreateLoad($_resultType, $addr, $volatile_); }] # setAlignmentCode # setNonTemporalMetadataCode # setAccessGroupsMetadataCode @@ -375,11 +383,13 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes { CArg<"bool", "false">:$isVolatile, CArg<"bool", "false">:$isNonTemporal), [{ auto type = addr.getType().cast().getElementType(); + assert(type && "must provide explicit element type to the constructor " + "when the pointer type is opaque"); build($_builder, $_state, type, addr, alignment, isVolatile, isNonTemporal); }]>, OpBuilder<(ins "Type":$t, "Value":$addr, CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile, - CArg<"bool", "false">:$isNonTemporal)>]; + CArg<"bool", "false">:$isNonTemporal)>,]; let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -1592,13 +1602,14 @@ def LLVM_vector_reduce_fmul : LLVM_VectorReductionAcc<"fmul">; def LLVM_MatrixColumnMajorLoadOp : LLVM_Op<"intr.matrix.column.major.load"> { let arguments = (ins LLVM_Type:$data, LLVM_Type:$stride, I1Attr:$isVolatile, I32Attr:$rows, I32Attr:$columns); - let results = (outs LLVM_Type:$res); + let results = (outs LLVM_AnyVector:$res); let builders = [LLVM_OneResultOpBuilder]; string llvmBuilder = [{ llvm::MatrixBuilder mb(builder); const llvm::DataLayout &dl = builder.GetInsertBlock()->getModule()->getDataLayout(); - llvm::Type *ElemTy = $data->getType()->getPointerElementType(); + llvm::Type *ElemTy = moduleTranslation.convertType( + getVectorElementType(op.getType())); llvm::Align align = dl.getABITypeAlign(ElemTy); $res = mb.CreateColumnMajorLoad( ElemTy, $data, align, $stride, $isVolatile, $rows, @@ -1617,15 +1628,17 @@ def LLVM_MatrixColumnMajorLoadOp : LLVM_Op<"intr.matrix.column.major.load"> { /// columns - Number of columns in matrix (must be a constant) /// stride - Space between columns def LLVM_MatrixColumnMajorStoreOp : LLVM_Op<"intr.matrix.column.major.store"> { - let arguments = (ins LLVM_Type:$matrix, LLVM_Type:$data, LLVM_Type:$stride, - I1Attr:$isVolatile, I32Attr:$rows, I32Attr:$columns); + let arguments = (ins LLVM_AnyVector:$matrix, LLVM_Type:$data, + LLVM_Type:$stride, I1Attr:$isVolatile, I32Attr:$rows, + I32Attr:$columns); let builders = [LLVM_VoidResultTypeOpBuilder, LLVM_ZeroResultOpBuilder]; string llvmBuilder = [{ llvm::MatrixBuilder mb(builder); const llvm::DataLayout &dl = builder.GetInsertBlock()->getModule()->getDataLayout(); + Type elementType = getVectorElementType(op.getMatrix().getType()); llvm::Align align = dl.getABITypeAlign( - $data->getType()->getPointerElementType()); + moduleTranslation.convertType(elementType)); mb.CreateColumnMajorStore( $matrix, $data, align, $stride, $isVolatile, $rows, $columns); @@ -1681,14 +1694,12 @@ def LLVM_GetActiveLaneMaskOp def LLVM_MaskedLoadOp : LLVM_Op<"intr.masked.load"> { let arguments = (ins LLVM_Type:$data, LLVM_Type:$mask, Variadic:$pass_thru, I32Attr:$alignment); - let results = (outs LLVM_Type:$res); - let builders = [LLVM_OneResultOpBuilder]; + let results = (outs LLVM_AnyVector:$res); string llvmBuilder = [{ - llvm::Type *Ty = $data->getType()->getPointerElementType(); $res = $pass_thru.empty() ? builder.CreateMaskedLoad( - Ty, $data, llvm::Align($alignment), $mask) : + $_resultType, $data, llvm::Align($alignment), $mask) : builder.CreateMaskedLoad( - Ty, $data, llvm::Align($alignment), $mask, $pass_thru[0]); + $_resultType, $data, llvm::Align($alignment), $mask, $pass_thru[0]); }]; let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; @@ -1709,19 +1720,15 @@ def LLVM_MaskedStoreOp : LLVM_Op<"intr.masked.store"> { /// Create a call to Masked Gather intrinsic. def LLVM_masked_gather : LLVM_Op<"intr.masked.gather"> { - let arguments = (ins LLVM_Type:$ptrs, LLVM_Type:$mask, + let arguments = (ins LLVM_AnyVector:$ptrs, LLVM_Type:$mask, Variadic:$pass_thru, I32Attr:$alignment); let results = (outs LLVM_Type:$res); let builders = [LLVM_OneResultOpBuilder]; string llvmBuilder = [{ - llvm::VectorType *PtrVecTy = cast($ptrs->getType()); - llvm::Type *Ty = llvm::VectorType::get( - PtrVecTy->getElementType()->getPointerElementType(), - PtrVecTy->getElementCount()); $res = $pass_thru.empty() ? builder.CreateMaskedGather( - Ty, $ptrs, llvm::Align($alignment), $mask) : + $_resultType, $ptrs, llvm::Align($alignment), $mask) : builder.CreateMaskedGather( - Ty, $ptrs, llvm::Align($alignment), $mask, $pass_thru[0]); + $_resultType, $ptrs, llvm::Align($alignment), $mask, $pass_thru[0]); }]; let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h index cc2e30a821e63a..9c975289c18c08 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -176,36 +176,49 @@ class LLVMFunctionType //===----------------------------------------------------------------------===// /// LLVM dialect pointer type. This type typically represents a reference to an -/// object in memory. It is parameterized by the element type and the address -/// space. +/// object in memory. Pointers may be opaque or parameterized by the element +/// type. Both opaque and non-opaque pointers are additionally parameterized by +/// the address space. class LLVMPointerType : public Type::TypeBase { public: /// Inherit base constructors. using Base::Base; - using Base::getChecked; /// Checks if the given type can have a pointer type pointing to it. static bool isValidElementType(Type type); /// Gets or creates an instance of LLVM dialect pointer type pointing to an /// object of `pointee` type in the given address space. The pointer type is - /// created in the same context as `pointee`. + /// created in the same context as `pointee`. If the pointee is not provided, + /// creates an opaque pointer in the given context and address space. + static LLVMPointerType get(MLIRContext *context, unsigned addressSpace = 0); static LLVMPointerType get(Type pointee, unsigned addressSpace = 0); static LLVMPointerType getChecked(function_ref emitError, Type pointee, unsigned addressSpace = 0); + static LLVMPointerType + getChecked(function_ref emitError, MLIRContext *context, + unsigned addressSpace = 0); - /// Returns the pointed-to type. + /// Returns the pointed-to type. It may be null if the pointer is opaque. Type getElementType() const; + /// Returns `true` if this type is the opaque pointer type, i.e., it has no + /// pointed-to type. + bool isOpaque() const; + /// Returns the address space of the pointer. unsigned getAddressSpace() const; /// Verifies that the type about to be constructed is well-formed. static LogicalResult verify(function_ref emitError, Type pointee, unsigned); + static LogicalResult verify(function_ref emitError, + MLIRContext *context, unsigned) { + return success(); + } /// Hooks for DataLayoutTypeInterface. Should not be called directly. Obtain a /// DataLayout instance and query it instead. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 3fc042d17c5726..393d53c83e69ae 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -42,6 +42,7 @@ using mlir::LLVM::linkage::getMaxEnumValForLinkage; static constexpr const char kVolatileAttrName[] = "volatile_"; static constexpr const char kNonTemporalAttrName[] = "nontemporal"; +static constexpr const char kElemTypeAttrName[] = "elem_type"; #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc" #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc" @@ -180,11 +181,13 @@ ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) { } //===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::AllocaOp. +// Printing, parsing and verification for LLVM::AllocaOp. //===----------------------------------------------------------------------===// void AllocaOp::print(OpAsmPrinter &p) { - auto elemTy = getType().cast().getElementType(); + Type elemTy = getType().cast().getElementType(); + if (!elemTy) + elemTy = *getElemType(); auto funcTy = FunctionType::get(getContext(), {getArraySize().getType()}, {getType()}); @@ -193,7 +196,8 @@ void AllocaOp::print(OpAsmPrinter &p) { if (getAlignment().hasValue() && *getAlignment() != 0) p.printOptionalAttrDict((*this)->getAttrs()); else - p.printOptionalAttrDict((*this)->getAttrs(), {"alignment"}); + p.printOptionalAttrDict((*this)->getAttrs(), + {"alignment", kElemTypeAttrName}); p << " : " << funcTy; } @@ -232,10 +236,37 @@ ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands)) return failure(); + Type resultType = funcType.getResult(0); + if (auto ptrResultType = resultType.dyn_cast()) { + if (ptrResultType.isOpaque()) + result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType)); + } + result.addTypes({funcType.getResult(0)}); return success(); } +/// Checks that the elemental type is present in either the pointer type or +/// the attribute, but not both. +static LogicalResult verifyOpaquePtr(Operation *op, LLVMPointerType ptrType, + Optional ptrElementType) { + if (ptrType.isOpaque() && !ptrElementType.hasValue()) { + return op->emitOpError() << "expected '" << kElemTypeAttrName + << "' attribute if opaque pointer type is used"; + } + if (!ptrType.isOpaque() && ptrElementType.hasValue()) { + return op->emitOpError() + << "unexpected '" << kElemTypeAttrName + << "' attribute when non-opaque pointer type is used"; + } + return success(); +} + +LogicalResult AllocaOp::verify() { + return verifyOpaquePtr(getOperation(), getType().cast(), + getElemType()); +} + //===----------------------------------------------------------------------===// // LLVM::BrOp //===----------------------------------------------------------------------===// @@ -402,20 +433,12 @@ static void recordStructIndices(Type type, unsigned currentIndex, /// provided, it is populated with sizes of the indexed structs for bounds /// verification purposes. static void -findKnownStructIndices(Type baseGEPType, SmallVectorImpl &indices, +findKnownStructIndices(Type sourceElementType, + SmallVectorImpl &indices, SmallVectorImpl *structSizes = nullptr) { - Type type = baseGEPType; - if (auto vectorType = type.dyn_cast()) - type = vectorType.getElementType(); - if (auto scalableVectorType = type.dyn_cast()) - type = scalableVectorType.getElementType(); - if (auto fixedVectorType = type.dyn_cast()) - type = fixedVectorType.getElementType(); - - Type pointeeType = type.cast().getElementType(); SmallPtrSet visited; - recordStructIndices(pointeeType, /*currentIndex=*/1, indices, structSizes, - visited); + recordStructIndices(sourceElementType, /*currentIndex=*/1, indices, + structSizes, visited); } void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, @@ -426,6 +449,17 @@ void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, attributes); } +/// Returns the elemental type of any LLVM-compatible vector type or self. +static Type extractVectorElementType(Type type) { + if (auto vectorType = type.dyn_cast()) + return vectorType.getElementType(); + if (auto scalableVectorType = type.dyn_cast()) + return scalableVectorType.getElementType(); + if (auto fixedVectorType = type.dyn_cast()) + return fixedVectorType.getElementType(); + return type; +} + void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, Value basePtr, ValueRange indices, ArrayRef structIndices, @@ -434,7 +468,10 @@ void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, SmallVector updatedStructIndices(structIndices.begin(), structIndices.end()); SmallVector structRelatedPositions; - findKnownStructIndices(basePtr.getType(), structRelatedPositions); + auto ptrType = + extractVectorElementType(basePtr.getType()).cast(); + assert(!ptrType.isOpaque() && "expected non-opaque pointer"); + findKnownStructIndices(ptrType.getElementType(), structRelatedPositions); SmallVector operandsToErase; for (unsigned pos : structRelatedPositions) { @@ -523,9 +560,15 @@ static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, } LogicalResult LLVM::GEPOp::verify() { + if (failed(verifyOpaquePtr( + getOperation(), + extractVectorElementType(getType()).cast(), + getElemType()))) + return failure(); + SmallVector indices; SmallVector structSizes; - findKnownStructIndices(getBase().getType(), indices, &structSizes); + findKnownStructIndices(getSourceElementType(), indices, &structSizes); DenseIntElementsAttr structIndices = getStructIndices(); for (unsigned i : llvm::seq(0, indices.size())) { unsigned index = indices[i]; @@ -544,6 +587,15 @@ LogicalResult LLVM::GEPOp::verify() { return success(); } +Type LLVM::GEPOp::getSourceElementType() { + if (Optional elemType = getElemType()) + return *elemType; + + return extractVectorElementType(getBase().getType()) + .cast() + .getElementType(); +} + //===----------------------------------------------------------------------===// // Builder, printer and parser for for LLVM::LoadOp. //===----------------------------------------------------------------------===// @@ -637,22 +689,28 @@ void LoadOp::print(OpAsmPrinter &p) { if (getVolatile_()) p << "volatile "; p << getAddr(); - p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName}); + p.printOptionalAttrDict((*this)->getAttrs(), + {kVolatileAttrName, kElemTypeAttrName}); p << " : " << getAddr().getType(); + if (getAddr().getType().cast().isOpaque()) + p << " -> " << getType(); } -// Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return -// the resulting type wrapped in MLIR, or nullptr on error. -static Type getLoadStoreElementType(OpAsmParser &parser, Type type, - SMLoc trailingTypeLoc) { +// Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return +// the resulting type if any, null type if opaque pointers are used, and None +// if the given type is not the pointer type. +static Optional getLoadStoreElementType(OpAsmParser &parser, Type type, + SMLoc trailingTypeLoc) { auto llvmTy = type.dyn_cast(); - if (!llvmTy) - return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"), - nullptr; + if (!llvmTy) { + parser.emitError(trailingTypeLoc, "expected LLVM pointer type"); + return llvm::None; + } return llvmTy.getElementType(); } // ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type +// (`->` type)? ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand addr; Type type; @@ -667,9 +725,19 @@ ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { parser.resolveOperand(addr, type, result.operands)) return failure(); - Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); + Optional elemTy = + getLoadStoreElementType(parser, type, trailingTypeLoc); + if (!elemTy) + return failure(); + if (*elemTy) { + result.addTypes(*elemTy); + return success(); + } - result.addTypes(elemTy); + Type trailingType; + if (parser.parseArrow() || parser.parseType(trailingType)) + return failure(); + result.addTypes(trailingType); return success(); } @@ -698,11 +766,14 @@ void StoreOp::print(OpAsmPrinter &p) { p << "volatile "; p << getValue() << ", " << getAddr(); p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName}); - p << " : " << getAddr().getType(); + p << " : "; + if (getAddr().getType().cast().isOpaque()) + p << getValue().getType() << ", "; + p << getAddr().getType(); } // ::= `llvm.store` `volatile` ssa-use `,` ssa-use -// attribute-dict? `:` type +// attribute-dict? `:` type (`,` type)? ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand addr, value; Type type; @@ -717,11 +788,20 @@ ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) return failure(); - Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); - if (!elemTy) - return failure(); + Type operandType; + if (succeeded(parser.parseOptionalComma())) { + operandType = type; + if (parser.parseType(type)) + return failure(); + } else { + Optional maybeOperandType = + getLoadStoreElementType(parser, type, trailingTypeLoc); + if (!maybeOperandType) + return failure(); + operandType = *maybeOperandType; + } - if (parser.resolveOperand(value, elemTy, result.operands) || + if (parser.resolveOperand(value, operandType, result.operands) || parser.resolveOperand(addr, type, result.operands)) return failure(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp index 1c91007897c094..04015efea3e9f5 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -141,6 +141,12 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) { printer << getTypeKeyword(type); if (auto ptrType = type.dyn_cast()) { + if (ptrType.isOpaque()) { + if (ptrType.getAddressSpace() != 0) + printer << '<' << ptrType.getAddressSpace() << '>'; + return; + } + printer << '<'; dispatchPrint(printer, ptrType.getElementType()); if (ptrType.getAddressSpace() != 0) @@ -215,13 +221,27 @@ static LLVMFunctionType parseFunctionType(AsmParser &parser) { /// Parses an LLVM dialect pointer type. /// llvm-type ::= `ptr<` llvm-type (`,` integer)? `>` +/// | `ptr` (`<` integer `>`)? static LLVMPointerType parsePointerType(AsmParser &parser) { SMLoc loc = parser.getCurrentLocation(); Type elementType; - if (parser.parseLess() || dispatchParse(parser, elementType)) - return LLVMPointerType(); + if (parser.parseOptionalLess()) { + return parser.getChecked(loc, parser.getContext(), + /*addressSpace=*/0); + } unsigned addressSpace = 0; + OptionalParseResult opr = parser.parseOptionalInteger(addressSpace); + if (opr.hasValue()) { + if (failed(*opr) || parser.parseGreater()) + return LLVMPointerType(); + return parser.getChecked(loc, parser.getContext(), + addressSpace); + } + + if (dispatchParse(parser, elementType)) + return LLVMPointerType(); + if (succeeded(parser.parseOptionalComma()) && failed(parser.parseInteger(addressSpace))) return LLVMPointerType(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index 2698a8e9a9a528..101e255f99efda 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -156,6 +156,8 @@ LLVMFunctionType::verify(function_ref emitError, //===----------------------------------------------------------------------===// bool LLVMPointerType::isValidElementType(Type type) { + if (!type) + return true; return isCompatibleOuterType(type) ? !type.isa() @@ -163,10 +165,16 @@ bool LLVMPointerType::isValidElementType(Type type) { } LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) { - assert(pointee && "expected non-null subtype"); + assert(pointee && "expected non-null subtype, pass the context instead if " + "the opaque pointer type is desired"); return Base::get(pointee.getContext(), pointee, addressSpace); } +LLVMPointerType LLVMPointerType::get(MLIRContext *context, + unsigned addressSpace) { + return Base::get(context, Type(), addressSpace); +} + LLVMPointerType LLVMPointerType::getChecked(function_ref emitError, Type pointee, unsigned addressSpace) { @@ -174,8 +182,16 @@ LLVMPointerType::getChecked(function_ref emitError, addressSpace); } +LLVMPointerType +LLVMPointerType::getChecked(function_ref emitError, + MLIRContext *context, unsigned addressSpace) { + return Base::getChecked(emitError, context, Type(), addressSpace); +} + Type LLVMPointerType::getElementType() const { return getImpl()->pointeeType; } +bool LLVMPointerType::isOpaque() const { return !getImpl()->pointeeType; } + unsigned LLVMPointerType::getAddressSpace() const { return getImpl()->addressSpace; } @@ -722,9 +738,13 @@ static bool isCompatibleImpl(Type type, SetVector &callstack) { .Case([&](auto vecType) { return vecType.getRank() == 1 && isCompatible(vecType.getElementType()); }) + .Case([&](auto pointerType) { + if (pointerType.isOpaque()) + return true; + return isCompatible(pointerType.getElementType()); + }) // clang-format off .Case< - LLVMPointerType, LLVMFixedVectorType, LLVMScalableVectorType, LLVMArrayType diff --git a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h index 7e89c6c7a9a569..7ebfe8e69f8cc8 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h +++ b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h @@ -382,7 +382,8 @@ struct LLVMFunctionTypeStorage : public TypeStorage { //===----------------------------------------------------------------------===// /// Storage type for LLVM dialect pointer types. These are uniqued by a pair of -/// element type and address space. +/// element type and address space. The element type may be null indicating that +/// the pointer is opaque. struct LLVMPointerTypeStorage : public TypeStorage { using KeyTy = std::tuple; diff --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp index 2cd98691e0c3b7..a65757a1bbd5f4 100644 --- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp +++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp @@ -106,6 +106,8 @@ class TypeToLLVMIRTranslatorImpl { /// Translates the given pointer type. llvm::Type *translate(LLVM::LLVMPointerType type) { + if (type.isOpaque()) + return llvm::PointerType::get(context, type.getAddressSpace()); return llvm::PointerType::get(translateType(type.getElementType()), type.getAddressSpace()); } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index cc4ae43b7ae3ef..c260568b17888b 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -104,6 +104,20 @@ func @alloca_non_integer_alignment() { // ----- +func @alloca_opaque_ptr_no_type(%sz : i64) { + // expected-error@below {{expected 'elem_type' attribute if opaque pointer type is used}} + "llvm.alloca"(%sz) : (i64) -> !llvm.ptr +} + +// ----- + +func @alloca_ptr_type_attr_non_opaque_ptr(%sz : i64) { + // expected-error@below {{unexpected 'elem_type' attribute when non-opaque pointer type is used}} + "llvm.alloca"(%sz) { elem_type = i32 } : (i64) -> !llvm.ptr +} + +// ----- + func @gep_missing_input_result_type(%pos : i64, %base : !llvm.ptr) { // expected-error@+1 {{2 operands present, but expected 0}} llvm.getelementptr %base[%pos] : () -> () @@ -160,6 +174,13 @@ func @store_non_ptr_type(%foo : f32, %bar : f32) { // ----- +func @store_malformed_elem_type(%foo: !llvm.ptr, %bar: f32) { + // expected-error@+1 {{expected non-function type}} + llvm.store %bar, %foo : !llvm.ptr, "f32" +} + +// ----- + func @call_non_function_type(%callee : !llvm.func, %arg : i8) { // expected-error@+1 {{expected function type}} llvm.call %callee(%arg) : !llvm.func diff --git a/mlir/test/Dialect/LLVMIR/opaque-ptr.mlir b/mlir/test/Dialect/LLVMIR/opaque-ptr.mlir new file mode 100644 index 00000000000000..373931c747fc3e --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/opaque-ptr.mlir @@ -0,0 +1,77 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s + +// CHECK-LABEL: @opaque_ptr_load +llvm.func @opaque_ptr_load(%arg0: !llvm.ptr) -> i32 { + // CHECK: = llvm.load %{{.*}} : !llvm.ptr -> i32 + %0 = llvm.load %arg0 : !llvm.ptr -> i32 + llvm.return %0 : i32 +} + +// CHECK-LABEL: @opaque_ptr_store +llvm.func @opaque_ptr_store(%arg0: i32, %arg1: !llvm.ptr){ + // CHECK: llvm.store %{{.*}}, %{{.*}} : i32, !llvm.ptr + llvm.store %arg0, %arg1 : i32, !llvm.ptr + llvm.return +} + +// CHECK-LABEL: @opaque_ptr_ptr_store +llvm.func @opaque_ptr_ptr_store(%arg0: !llvm.ptr, %arg1: !llvm.ptr) { + // CHECK: llvm.store %{{.*}}, %{{.*}} : !llvm.ptr, !llvm.ptr + llvm.store %arg0, %arg1 : !llvm.ptr, !llvm.ptr + llvm.return +} + +// CHECK-LABEL: @opaque_ptr_alloca +llvm.func @opaque_ptr_alloca(%arg0: i32) -> !llvm.ptr { + // CHECK: llvm.alloca %{{.*}} x f32 : (i32) -> !llvm.ptr + %0 = llvm.alloca %arg0 x f32 : (i32) -> !llvm.ptr + llvm.return %0 : !llvm.ptr +} + +// CHECK-LABEL: @opaque_ptr_gep +llvm.func @opaque_ptr_gep(%arg0: !llvm.ptr, %arg1: i32) -> !llvm.ptr { + // CHECK: = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i32) -> !llvm.ptr, f32 + %0 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i32) -> !llvm.ptr, f32 + llvm.return %0 : !llvm.ptr +} + +// CHECK-LABEL: @opaque_ptr_gep_struct +llvm.func @opaque_ptr_gep_struct(%arg0: !llvm.ptr, %arg1: i32) -> !llvm.ptr { + // CHECK: = llvm.getelementptr %{{.*}}[%{{.*}}, 0, 1] + // CHECK: : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(struct<(f32, f64)>, struct<(i32, i64)>)> + %0 = llvm.getelementptr %arg0[%arg1, 0, 1] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(struct<(f32, f64)>, struct<(i32, i64)>)> + llvm.return %0 : !llvm.ptr +} + +// CHECK-LABEL: @opaque_ptr_matrix_load_store +llvm.func @opaque_ptr_matrix_load_store(%ptr: !llvm.ptr, %stride: i64) -> vector<48 x f32> { + // CHECK: = llvm.intr.matrix.column.major.load + // CHECK: vector<48xf32> from !llvm.ptr stride i64 + %0 = llvm.intr.matrix.column.major.load %ptr, + { isVolatile = 0: i1, rows = 3: i32, columns = 16: i32} : + vector<48 x f32> from !llvm.ptr stride i64 + // CHECK: llvm.intr.matrix.column.major.store + // CHECK: vector<48xf32> to !llvm.ptr stride i64 + llvm.intr.matrix.column.major.store %0, %ptr, + { isVolatile = 0: i1, rows = 3: i32, columns = 16: i32} : + vector<48 x f32> to !llvm.ptr stride i64 + llvm.return %0 : vector<48 x f32> +} + +// CHECK-LABEL: @opaque_ptr_masked_load +llvm.func @opaque_ptr_masked_load(%arg0: !llvm.ptr, %arg1: vector<7xi1>) -> vector<7xf32> { + // CHECK: = llvm.intr.masked.load + // CHECK: (!llvm.ptr, vector<7xi1>) -> vector<7xf32> + %0 = llvm.intr.masked.load %arg0, %arg1 { alignment = 1: i32} : + (!llvm.ptr, vector<7xi1>) -> vector<7xf32> + llvm.return %0 : vector<7 x f32> +} + +// CHECK-LABEL: @opaque_ptr_gather +llvm.func @opaque_ptr_gather(%M: !llvm.vec<7 x ptr>, %mask: vector<7xi1>) -> vector<7xf32> { + // CHECK: = llvm.intr.masked.gather + // CHECK: (!llvm.vec<7 x ptr>, vector<7xi1>) -> vector<7xf32> + %a = llvm.intr.masked.gather %M, %mask { alignment = 1: i32} : + (!llvm.vec<7 x ptr>, vector<7xi1>) -> vector<7xf32> + llvm.return %a : vector<7xf32> +} diff --git a/mlir/test/Dialect/LLVMIR/types.mlir b/mlir/test/Dialect/LLVMIR/types.mlir index 6e1a571d0ca1e3..d67d2b6c4f71a8 100644 --- a/mlir/test/Dialect/LLVMIR/types.mlir +++ b/mlir/test/Dialect/LLVMIR/types.mlir @@ -73,6 +73,10 @@ func @ptr() { "some.op"() : () -> !llvm.ptr // CHECK: !llvm.ptr, 9> "some.op"() : () -> !llvm.ptr, 9> + // CHECK: !llvm.ptr + "some.op"() : () -> !llvm.ptr + // CHECK: !llvm.ptr<42> + "some.op"() : () -> !llvm.ptr<42> return } diff --git a/mlir/test/Target/LLVMIR/opaque-ptr.mlir b/mlir/test/Target/LLVMIR/opaque-ptr.mlir new file mode 100644 index 00000000000000..9af79cff01ef94 --- /dev/null +++ b/mlir/test/Target/LLVMIR/opaque-ptr.mlir @@ -0,0 +1,74 @@ +// RUN: mlir-translate -mlir-to-llvmir -opaque-pointers %s | FileCheck %s + +// CHECK-LABEL: @opaque_ptr_load +llvm.func @opaque_ptr_load(%arg0: !llvm.ptr) -> i32 { + // CHECK: load i32, ptr %{{.*}} + %0 = llvm.load %arg0 : !llvm.ptr -> i32 + llvm.return %0 : i32 +} + +// CHECK-LABEL: @opaque_ptr_store +llvm.func @opaque_ptr_store(%arg0: i32, %arg1: !llvm.ptr){ + // CHECK: store i32 %{{.*}}, ptr %{{.*}} + llvm.store %arg0, %arg1 : i32, !llvm.ptr + llvm.return +} + +// CHECK-LABEL: @opaque_ptr_ptr_store +llvm.func @opaque_ptr_ptr_store(%arg0: !llvm.ptr, %arg1: !llvm.ptr) { + // CHECK: store ptr %{{.*}}, ptr %{{.*}} + llvm.store %arg0, %arg1 : !llvm.ptr, !llvm.ptr + llvm.return +} + +// CHECK-LABEL: @opaque_ptr_alloca +llvm.func @opaque_ptr_alloca(%arg0: i32) -> !llvm.ptr { + // CHECK: alloca float, i32 %{{.*}} + %0 = llvm.alloca %arg0 x f32 : (i32) -> !llvm.ptr + llvm.return %0 : !llvm.ptr +} + +// CHECK-LABEL: @opaque_ptr_gep +llvm.func @opaque_ptr_gep(%arg0: !llvm.ptr, %arg1: i32) -> !llvm.ptr { + // CHECK: getelementptr float, ptr %{{.*}}, i32 %{{.*}} + %0 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i32) -> !llvm.ptr, f32 + llvm.return %0 : !llvm.ptr +} + +// CHECK-LABEL: @opaque_ptr_gep_struct +llvm.func @opaque_ptr_gep_struct(%arg0: !llvm.ptr, %arg1: i32) -> !llvm.ptr { + // CHECK: getelementptr { { float, double }, { i32, i64 } }, ptr %{{.*}}, i32 0, i32 1 + %0 = llvm.getelementptr %arg0[%arg1, 0, 1] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(struct<(f32, f64)>, struct<(i32, i64)>)> + llvm.return %0 : !llvm.ptr +} + +// CHECK-LABEL: @opaque_ptr_matrix_load_store +llvm.func @opaque_ptr_matrix_load_store(%ptr: !llvm.ptr, %stride: i64) -> vector<48 x f32> { + // CHECK: call <48 x float> @llvm.matrix.column.major.load.v48f32.i64 + // CHECK: (ptr {{.*}}, i64 %{{.*}} + %0 = llvm.intr.matrix.column.major.load %ptr, + { isVolatile = 0: i1, rows = 3: i32, columns = 16: i32} : + vector<48 x f32> from !llvm.ptr stride i64 + // CHECK: call void @llvm.matrix.column.major.store.v48f32.i64 + // CHECK: <48 x float> %{{.*}}, ptr {{.*}}, i64 + llvm.intr.matrix.column.major.store %0, %ptr, + { isVolatile = 0: i1, rows = 3: i32, columns = 16: i32} : + vector<48 x f32> to !llvm.ptr stride i64 + llvm.return %0 : vector<48 x f32> +} + +// CHECK-LABEL: @opaque_ptr_masked_load +llvm.func @opaque_ptr_masked_load(%arg0: !llvm.ptr, %arg1: vector<7xi1>) -> vector<7xf32> { + // CHECK: call <7 x float> @llvm.masked.load.v7f32.p0(ptr + %0 = llvm.intr.masked.load %arg0, %arg1 { alignment = 1: i32} : + (!llvm.ptr, vector<7xi1>) -> vector<7xf32> + llvm.return %0 : vector<7 x f32> +} + +// CHECK-LABEL: @opaque_ptr_gather +llvm.func @opaque_ptr_gather(%M: !llvm.vec<7 x ptr>, %mask: vector<7xi1>) -> vector<7xf32> { + // CHECK: call <7 x float> @llvm.masked.gather.v7f32.v7p0(<7 x ptr> {{.*}}, i32 + %a = llvm.intr.masked.gather %M, %mask { alignment = 1: i32} : + (!llvm.vec<7 x ptr>, vector<7xi1>) -> vector<7xf32> + llvm.return %a : vector<7xf32> +}