diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td index bfce904a18d4f..32670a78a40f9 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -11,6 +11,7 @@ include "mlir/Dialect/LLVMIR/LLVMDialect.td" include "mlir/Dialect/LLVMIR/LLVMInterfaces.td" +include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/CommonAttrConstraints.td" include "mlir/Interfaces/DataLayoutInterfaces.td" @@ -23,6 +24,41 @@ class LLVM_Attr + ]> { + let summary = "LLVM address space"; + let description = [{ + The `address_space` attribute represents an LLVM address space. It takes an + unsigned integer parameter that specifies the address space number. + + Different address spaces in LLVM can have different properties: + - Address space 0 is the default/generic address space + - Other address spaces may have specific semantics (e.g., shared memory, + constant memory, etc.) depending on the target architecture + + Example: + + ```mlir + // Address space 0 (default) + #llvm.address_space<0> + + // Address space 1 (e.g., global memory on some targets) + #llvm.address_space<1> + + // Address space 3 (e.g., shared memory on some GPU targets) + #llvm.address_space<3> + ``` + }]; + let parameters = (ins "unsigned":$addressSpace); + let assemblyFormat = "`<` $addressSpace `>`"; +} + //===----------------------------------------------------------------------===// // CConvAttr //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h index 1ceeb7e4ba2a5..fafccf304e1b4 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h @@ -15,6 +15,7 @@ #define MLIR_DIALECT_LLVMIR_LLVMATTRS_H_ #include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" #include diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h index 17561f79d135a..a1506497dc85c 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -28,6 +28,7 @@ namespace mlir { class AsmParser; class AsmPrinter; +class DataLayout; namespace LLVM { class LLVMDialect; @@ -111,6 +112,15 @@ bool isCompatibleFloatingPointType(Type type); /// dialect pointers and LLVM dialect scalable vector types. bool isCompatibleVectorType(Type type); +/// Returns `true` if the given type is a loadable type compatible with the LLVM +/// dialect. +bool isLoadableType(Type type); + +/// Returns true if the given type is supported by atomic operations. All +/// integer, float, and pointer types with a power-of-two bitsize and a minimal +/// size of 8 bits are supported. +bool isTypeCompatibleWithAtomicOp(Type type, const DataLayout &dataLayout); + /// Returns the element count of any LLVM-compatible vector type. llvm::ElementCount getVectorNumElements(Type type); diff --git a/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt index fa4914b179b7a..388d735843c4e 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt @@ -7,8 +7,6 @@ mlir_tablegen(PtrOpsAttrs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=ptr) add_mlir_dialect_tablegen_target(MLIRPtrOpsAttributesIncGen) set(LLVM_TARGET_DEFINITIONS MemorySpaceInterfaces.td) -mlir_tablegen(MemorySpaceInterfaces.h.inc -gen-op-interface-decls) -mlir_tablegen(MemorySpaceInterfaces.cpp.inc -gen-op-interface-defs) mlir_tablegen(MemorySpaceAttrInterfaces.h.inc -gen-attr-interface-decls) mlir_tablegen(MemorySpaceAttrInterfaces.cpp.inc -gen-attr-interface-defs) add_mlir_dialect_tablegen_target(MLIRPtrMemorySpaceInterfacesIncGen) diff --git a/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h index 3e6754c6bec99..4d65c8d807cb9 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h +++ b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h @@ -17,8 +17,12 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/OpDefinition.h" +#include +#include + namespace mlir { class Operation; +class DataLayout; namespace ptr { enum class AtomicBinOp : uint32_t; enum class AtomicOrdering : uint32_t; @@ -27,6 +31,4 @@ enum class AtomicOrdering : uint32_t; #include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.h.inc" -#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h.inc" - #endif // MLIR_DIALECT_PTR_IR_MEMORYSPACEINTERFACES_H diff --git a/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td index 0171b9ca2e5dc..5231231564fb1 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td @@ -43,6 +43,7 @@ def MemorySpaceAttrInterface : AttrInterface<"MemorySpaceAttrInterface"> { /*args=*/ (ins "::mlir::Type":$type, "::mlir::ptr::AtomicOrdering":$ordering, "std::optional":$alignment, + "const ::mlir::DataLayout *":$dataLayout, "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError) >, InterfaceMethod< @@ -58,6 +59,7 @@ def MemorySpaceAttrInterface : AttrInterface<"MemorySpaceAttrInterface"> { /*args=*/ (ins "::mlir::Type":$type, "::mlir::ptr::AtomicOrdering":$ordering, "std::optional":$alignment, + "const ::mlir::DataLayout *":$dataLayout, "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError) >, InterfaceMethod< @@ -74,6 +76,7 @@ def MemorySpaceAttrInterface : AttrInterface<"MemorySpaceAttrInterface"> { "::mlir::Type":$type, "::mlir::ptr::AtomicOrdering":$ordering, "std::optional":$alignment, + "const ::mlir::DataLayout *":$dataLayout, "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError) >, InterfaceMethod< @@ -91,6 +94,7 @@ def MemorySpaceAttrInterface : AttrInterface<"MemorySpaceAttrInterface"> { "::mlir::ptr::AtomicOrdering":$successOrdering, "::mlir::ptr::AtomicOrdering":$failureOrdering, "std::optional":$alignment, + "const ::mlir::DataLayout *":$dataLayout, "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError) >, InterfaceMethod< diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h index dc0a3ffd4ae33..bb01ceaaeea54 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h @@ -19,10 +19,9 @@ #include "llvm/Support/TypeSize.h" #include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h" +#include "mlir/Dialect/Ptr/IR/PtrEnums.h" #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.h.inc" -#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.h.inc" - #endif // MLIR_DIALECT_PTR_IR_PTRATTRS_H diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.h new file mode 100644 index 0000000000000..2e98df8654b71 --- /dev/null +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.h @@ -0,0 +1,21 @@ +//===- PtrEnums.h - `ptr` dialect enums -------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the `ptr` dialect enums. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PTR_IR_PTRENUMS_H +#define MLIR_DIALECT_PTR_IR_PTRENUMS_H + +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/OpImplementation.h" + +#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.h.inc" + +#endif // MLIR_DIALECT_PTR_IR_PTRENUMS_H diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h index e4670cb1a9622..05b66ace902ad 100644 --- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h @@ -25,6 +25,7 @@ #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/SPIRV/SPIRVToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.h" @@ -45,6 +46,7 @@ static inline void registerAllToLLVMIRTranslations(DialectRegistry ®istry) { registerNVVMDialectTranslation(registry); registerOpenACCDialectTranslation(registry); registerOpenMPDialectTranslation(registry); + registerPtrDialectTranslation(registry); registerROCDLDialectTranslation(registry); registerSPIRVDialectTranslation(registry); registerVCIXDialectTranslation(registry); diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.h new file mode 100644 index 0000000000000..5c81762ba1a8a --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.h @@ -0,0 +1,31 @@ +//===- PtrToLLVMIRTranslation.h - `ptr` to LLVM IR --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This provides registration calls for `ptr` dialect to LLVM IR translation. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_DIALECT_PTR_PTRTOLLVMIRTRANSLATION_H +#define MLIR_TARGET_LLVMIR_DIALECT_PTR_PTRTOLLVMIRTRANSLATION_H + +namespace mlir { + +class DialectRegistry; +class MLIRContext; + +/// Register the `ptr` dialect and the translation from it to the LLVM IR in the +/// given registry; +void registerPtrDialectTranslation(DialectRegistry ®istry); + +/// Register the `ptr` dialect and the translation from it in the registry +/// associated with the given context. +void registerPtrDialectTranslation(MLIRContext &context); + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_DIALECT_PTR_PTRTOLLVMIRTRANSLATION_H diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt index ff55f17315cfd..ec581ac7277e3 100644 --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRLLVMDialect MLIRInferTypeOpInterface MLIRIR MLIRMemorySlotInterfaces + MLIRPtrMemorySpaceInterfaces MLIRSideEffectInterfaces MLIRSupport ) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp index 9d7a23f028cb0..e268e8f36de01 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp @@ -12,6 +12,8 @@ #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/Ptr/IR/PtrEnums.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/Interfaces/FunctionInterfaces.h" @@ -50,6 +52,87 @@ void LLVMDialect::registerAttributes() { >(); } +//===----------------------------------------------------------------------===// +// AddressSpaceAttr +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is an LLVM type that can be loaded or stored. +static bool isValidLoadStoreImpl(Type type, ptr::AtomicOrdering ordering, + std::optional alignment, + const ::mlir::DataLayout *dataLayout, + function_ref emitError) { + if (!isLoadableType(type)) { + if (emitError) + emitError() << "type must be LLVM type with size, but got " << type; + return false; + } + if (ordering == ptr::AtomicOrdering::not_atomic) + return true; + + // To check atomic validity we need a datalayout. + if (!dataLayout) { + if (emitError) + emitError() << "expected a valid data layout"; + return false; + } + if (!isTypeCompatibleWithAtomicOp(type, *dataLayout)) { + if (emitError) + emitError() << "unsupported type " << type << " for atomic access"; + return false; + } + return true; +} + +bool AddressSpaceAttr::isValidLoad( + Type type, ptr::AtomicOrdering ordering, std::optional alignment, + const ::mlir::DataLayout *dataLayout, + function_ref emitError) const { + return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError); +} + +bool AddressSpaceAttr::isValidStore( + Type type, ptr::AtomicOrdering ordering, std::optional alignment, + const ::mlir::DataLayout *dataLayout, + function_ref emitError) const { + return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError); +} + +bool AddressSpaceAttr::isValidAtomicOp( + ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering, + std::optional alignment, const ::mlir::DataLayout *dataLayout, + function_ref emitError) const { + // TODO: update this method once `ptr.atomic_rmw` is implemented. + assert(false && "unimplemented, see TODO in the source."); + return false; +} + +bool AddressSpaceAttr::isValidAtomicXchg( + Type type, ptr::AtomicOrdering successOrdering, + ptr::AtomicOrdering failureOrdering, std::optional alignment, + const ::mlir::DataLayout *dataLayout, + function_ref emitError) const { + // TODO: update this method once `ptr.atomic_cmpxchg` is implemented. + assert(false && "unimplemented, see TODO in the source."); + return false; +} + +bool AddressSpaceAttr::isValidAddrSpaceCast( + Type tgt, Type src, function_ref emitError) const { + // TODO: update this method once the `ptr.addrspace_cast` op is added to the + // dialect. + assert(false && "unimplemented, see TODO in the source."); + return false; +} + +bool AddressSpaceAttr::isValidPtrIntCast( + Type intLikeTy, Type ptrLikeTy, + function_ref emitError) const { + // TODO: update this method once the int-cast ops are added to the `ptr` + // dialect. + assert(false && "unimplemented, see TODO in the source."); + return false; +} + //===----------------------------------------------------------------------===// // AliasScopeAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index a0b755bc63736..ef2707089a45c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -853,8 +853,8 @@ void LoadOp::getEffects( /// Returns true if the given type is supported by atomic operations. All /// integer, float, and pointer types with a power-of-two bitsize and a minimal /// size of 8 bits are supported. -static bool isTypeCompatibleWithAtomicOp(Type type, - const DataLayout &dataLayout) { +bool LLVM::isTypeCompatibleWithAtomicOp(Type type, + const DataLayout &dataLayout) { if (!isa(type)) if (!isCompatibleFloatingPointType(type)) return false; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp index 78b44116bb4fa..297640cdd49d0 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -24,7 +24,9 @@ using namespace mlir::LLVM; /// prints it as usual. static void dispatchPrint(AsmPrinter &printer, Type type) { if (isCompatibleType(type) && - !llvm::isa(type)) + !(llvm::isa(type) || + (llvm::isa(type) && + !llvm::isa(type)))) return mlir::LLVM::detail::printType(type, printer); printer.printType(type); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index fee2d3ed62930..2dd0132a65bc4 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -13,6 +13,7 @@ #include "TypeDetail.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/BuiltinTypes.h" @@ -701,6 +702,17 @@ const llvm::fltSemantics &LLVMPPCFP128Type::getFloatSemantics() const { // Utility functions. //===----------------------------------------------------------------------===// +/// Check whether type is a compatible ptr type. These are pointer-like types +/// with no element type, no metadata, and using the LLVM AddressSpaceAttr +/// memory space. +static bool isCompatiblePtrType(Type type) { + auto ptrTy = dyn_cast(type); + if (!ptrTy) + return false; + return !ptrTy.hasPtrMetadata() && ptrTy.getElementType() == nullptr && + isa(ptrTy.getMemorySpace()); +} + bool mlir::LLVM::isCompatibleOuterType(Type type) { // clang-format off if (llvm::isa< @@ -734,7 +746,7 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) { if (auto vecType = llvm::dyn_cast(type)) return vecType.getRank() == 1; - return false; + return isCompatiblePtrType(type); } static bool isCompatibleImpl(Type type, DenseSet &compatibleTypes) { @@ -784,6 +796,8 @@ static bool isCompatibleImpl(Type type, DenseSet &compatibleTypes) { LLVMX86AMXType >([](Type) { return true; }) // clang-format on + .Case( + [](Type type) { return isCompatiblePtrType(type); }) .Default([](Type) { return false; }); if (!result) @@ -805,6 +819,18 @@ bool mlir::LLVM::isCompatibleType(Type type) { return LLVMDialect::isCompatibleType(type); } +bool mlir::LLVM::isLoadableType(Type type) { + return /*LLVM_PrimitiveType*/ ( + LLVM::isCompatibleOuterType(type) && + !isa(type)) && + /*LLVM_OpaqueStruct*/ + !(isa(type) && + cast(type).isOpaque()) && + /*LLVM_AnyTargetExt*/ + !(isa(type) && + !cast(type).supportsMemOps()); +} + bool mlir::LLVM::isCompatibleFloatingPointType(Type type) { return llvm::isa(type); @@ -818,7 +844,8 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) { if (auto intType = llvm::dyn_cast(elementType)) return intType.isSignless(); return llvm::isa(elementType); + Float80Type, Float128Type, LLVMPointerType>(elementType) || + isCompatiblePtrType(elementType); } return false; } diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt index 497468b9391db..bd1e655fc6b5e 100644 --- a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt @@ -1,3 +1,22 @@ +set(LLVM_OPTIONAL_SOURCES + MemorySpaceInterfaces.cpp + PtrAttrs.cpp + PtrTypes.cpp + PtrDialect.cpp +) + +add_mlir_dialect_library( + MLIRPtrMemorySpaceInterfaces + MemorySpaceInterfaces.cpp + + DEPENDS + MLIRPtrOpsEnumsGen + MLIRPtrMemorySpaceInterfacesIncGen + LINK_LIBS + PUBLIC + MLIRIR +) + add_mlir_dialect_library( MLIRPtrDialect PtrAttrs.cpp @@ -15,4 +34,5 @@ add_mlir_dialect_library( MLIRDataLayoutInterfaces MLIRMemorySlotInterfaces MLIRViewLikeInterface + MLIRPtrMemorySpaceInterfaces ) diff --git a/mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp b/mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp new file mode 100644 index 0000000000000..059e67ffb9f66 --- /dev/null +++ b/mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp @@ -0,0 +1,15 @@ +//===-- MemorySpaceInterfaces.cpp - ptr memory space interfaces -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the ptr dialect memory space interfaces. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h" + +#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.cpp.inc" diff --git a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp index dd4e906536cfc..ac3bcd6cea87e 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp @@ -23,19 +23,21 @@ constexpr const static unsigned kBitsInByte = 8; bool GenericSpaceAttr::isValidLoad( Type type, ptr::AtomicOrdering ordering, std::optional alignment, + const ::mlir::DataLayout *dataLayout, function_ref emitError) const { return true; } bool GenericSpaceAttr::isValidStore( Type type, ptr::AtomicOrdering ordering, std::optional alignment, + const ::mlir::DataLayout *dataLayout, function_ref emitError) const { return true; } bool GenericSpaceAttr::isValidAtomicOp( ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering, - std::optional alignment, + std::optional alignment, const ::mlir::DataLayout *dataLayout, function_ref emitError) const { return true; } @@ -43,6 +45,7 @@ bool GenericSpaceAttr::isValidAtomicOp( bool GenericSpaceAttr::isValidAtomicXchg( Type type, ptr::AtomicOrdering successOrdering, ptr::AtomicOrdering failureOrdering, std::optional alignment, + const ::mlir::DataLayout *dataLayout, function_ref emitError) const { return true; } diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp index bf87f83afc273..d5976b9a41ff6 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -139,8 +139,9 @@ void LoadOp::getEffects( LogicalResult LoadOp::verify() { auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); }; MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace(); + DataLayout dataLayout = DataLayout::closest(*this); if (!ms.isValidLoad(getResult().getType(), getOrdering(), getAlignment(), - emitDiag)) + &dataLayout, emitDiag)) return failure(); if (failed(verifyAlignment(getAlignment(), emitDiag))) return failure(); @@ -181,8 +182,9 @@ void StoreOp::getEffects( LogicalResult StoreOp::verify() { auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); }; MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace(); + DataLayout dataLayout = DataLayout::closest(*this); if (!ms.isValidStore(getValue().getType(), getOrdering(), getAlignment(), - emitDiag)) + &dataLayout, emitDiag)) return failure(); if (failed(verifyAlignment(getAlignment(), emitDiag))) return failure(); @@ -268,10 +270,6 @@ llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional layout) { #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc" -#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp.inc" - -#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.cpp.inc" - #include "mlir/Dialect/Ptr/IR/PtrOpsEnums.cpp.inc" #define GET_TYPEDEF_CLASSES diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt index d39b35526daf9..a73a78d17c134 100644 --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -59,6 +59,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration MLIROpenACCToLLVMIRTranslation MLIROpenMPToLLVMIRTranslation MLIRROCDLToLLVMIRTranslation + MLIRPtrToLLVMIRTranslation MLIRSPIRVToLLVMIRTranslation MLIRVCIXToLLVMIRTranslation MLIRXeVMToLLVMIRTranslation diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt index 86c731a1074c3..a102c4323075b 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -8,6 +8,7 @@ add_subdirectory(NVVM) add_subdirectory(OpenACC) add_subdirectory(OpenMP) add_subdirectory(ROCDL) +add_subdirectory(Ptr) add_subdirectory(SPIRV) add_subdirectory(VCIX) add_subdirectory(XeVM) diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/Ptr/CMakeLists.txt new file mode 100644 index 0000000000000..f94410d1f8a78 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_translation_library(MLIRPtrToLLVMIRTranslation + PtrToLLVMIRTranslation.cpp + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRPtrDialect + MLIRSupport + MLIRTargetLLVMIRExport + ) diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp new file mode 100644 index 0000000000000..7b89ec8fcbffb --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp @@ -0,0 +1,66 @@ +//===- PtrToLLVMIRTranslation.cpp - Translate `ptr` to LLVM IR ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the MLIR `ptr` dialect and +// LLVM IR. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.h" +#include "mlir/Dialect/Ptr/IR/PtrOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +using namespace mlir; +using namespace mlir::ptr; + +namespace { +/// Implementation of the dialect interface that converts operations belonging +/// to the `ptr` dialect to LLVM IR. +class PtrDialectLLVMIRTranslationInterface + : public LLVMTranslationDialectInterface { +public: + using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; + + /// Translates the given operation to LLVM IR using the provided IR builder + /// and saving the state in `moduleTranslation`. + LogicalResult + convertOperation(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) const final { + // Translation for ptr dialect operations to LLVM IR is currently + // unimplemented. + return op->emitError("Translation for ptr dialect operations to LLVM IR is " + "not implemented."); + } + + /// Attaches module-level metadata for functions marked as kernels. + LogicalResult + amendOperation(Operation *op, ArrayRef instructions, + NamedAttribute attribute, + LLVM::ModuleTranslation &moduleTranslation) const final { + // Translation for ptr dialect operations to LLVM IR is currently + // unimplemented. + return op->emitError("Translation for ptr dialect operations to LLVM IR is " + "not implemented."); + } +}; +} // namespace + +void mlir::registerPtrDialectTranslation(DialectRegistry ®istry) { + registry.insert(); + registry.addExtension(+[](MLIRContext *ctx, ptr::PtrDialect *dialect) { + dialect->addInterfaces(); + }); +} + +void mlir::registerPtrDialectTranslation(MLIRContext &context) { + DialectRegistry registry; + registerPtrDialectTranslation(registry); + context.appendDialectRegistry(registry); +} diff --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp index e4ba478f1d3b5..ddd5946ce5d63 100644 --- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp +++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Target/LLVMIR/TypeToLLVM.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/BuiltinTypes.h" @@ -71,7 +72,7 @@ class TypeToLLVMIRTranslatorImpl { }) .Case( + LLVM::LLVMTargetExtType, PtrLikeTypeInterface>( [this](auto type) { return this->translate(type); }) .Default([](Type t) -> llvm::Type * { llvm_unreachable("unknown LLVM dialect type"); @@ -149,6 +150,14 @@ class TypeToLLVMIRTranslatorImpl { type.getIntParams()); } + /// Translates the given ptr type. + llvm::Type *translate(PtrLikeTypeInterface type) { + auto memSpace = dyn_cast(type.getMemorySpace()); + assert(memSpace && "expected pointer with the LLVM address space"); + assert(!type.hasPtrMetadata() && "expected pointer without metadata"); + return llvm::PointerType::get(context, memSpace.getAddressSpace()); + } + /// Translates a list of types. void translateTypes(ArrayRef types, SmallVectorImpl &result) { diff --git a/mlir/test/Dialect/LLVMIR/ptr.mlir b/mlir/test/Dialect/LLVMIR/ptr.mlir new file mode 100644 index 0000000000000..3c208ae9d3211 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/ptr.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-opt %s --verify-roundtrip + +// Check that LLVM ops accept ptr values. +llvm.func @llvm_ops_with_ptr_values(%arg0: !llvm.ptr, %arg1: !llvm.struct<(!ptr.ptr<#llvm.address_space<3>>)>) { + %1 = llvm.load %arg0 : !llvm.ptr -> !ptr.ptr<#llvm.address_space<1>> + llvm.store %1, %arg0 : !ptr.ptr<#llvm.address_space<1>>, !llvm.ptr + llvm.store %arg1, %arg0 : !llvm.struct<(!ptr.ptr<#llvm.address_space<3>>)>, !llvm.ptr + llvm.return +} diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir index 5702097567fc7..0c34ae43bf6be 100644 --- a/mlir/test/Dialect/Ptr/invalid.mlir +++ b/mlir/test/Dialect/Ptr/invalid.mlir @@ -38,3 +38,19 @@ func.func @store_const(%arg0: !ptr.ptr<#test.const_memory_space>, %arg1: i64) { ptr.store %arg1, %arg0 atomic monotonic alignment = 8 : i64, !ptr.ptr<#test.const_memory_space> return } + +// ----- + +func.func @llvm_load(%arg0: !ptr.ptr<#llvm.address_space<1>>) -> (memref) { + // expected-error@+1 {{type must be LLVM type with size, but got 'memref'}} + %0 = ptr.load %arg0 : !ptr.ptr<#llvm.address_space<1>> -> memref + return %0 : memref +} + +// ----- + +func.func @llvm_store(%arg0: !ptr.ptr<#llvm.address_space<1>>, %arg1: memref) { + // expected-error@+1 {{type must be LLVM type with size, but got 'memref'}} + ptr.store %arg1, %arg0 : memref, !ptr.ptr<#llvm.address_space<1>> + return +} diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir index dc89489be9efc..3f3ad05c46acc 100644 --- a/mlir/test/Dialect/Ptr/ops.mlir +++ b/mlir/test/Dialect/Ptr/ops.mlir @@ -42,3 +42,17 @@ func.func @store_ops(%arg0: !ptr.ptr<#ptr.generic_space>, %arg1: f32, %arg2: i64 ptr.store volatile %arg3, %arg0 atomic syncscope("workgroup") release nontemporal alignment = 4 : i32, !ptr.ptr<#ptr.generic_space> return } + +/// Test load operations with llvm.address_space memory space +func.func @llvm_load(%arg0: !ptr.ptr<#llvm.address_space<1>>) -> (f32, i32) { + %0 = ptr.load %arg0 : !ptr.ptr<#llvm.address_space<1>> -> f32 + %1 = ptr.load volatile %arg0 atomic acquire alignment = 4 : !ptr.ptr<#llvm.address_space<1>> -> i32 + return %0, %1 : f32, i32 +} + +/// Test store operations with llvm.address_space memory space +func.func @llvm_store(%arg0: !ptr.ptr<#llvm.address_space<2>>, %arg1: f32, %arg2: i64) { + ptr.store %arg1, %arg0 : f32, !ptr.ptr<#llvm.address_space<2>> + ptr.store %arg2, %arg0 atomic release alignment = 8 : i64, !ptr.ptr<#llvm.address_space<2>> + return +} diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir new file mode 100644 index 0000000000000..c1620cb9ed313 --- /dev/null +++ b/mlir/test/Target/LLVMIR/ptr.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: declare ptr @llvm_ptr_address_space(ptr addrspace(1), ptr addrspace(3)) +llvm.func @llvm_ptr_address_space(!ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<3>>) -> !ptr.ptr<#llvm.address_space<0>> + +// CHECK-LABEL: define void @llvm_ops_with_ptr_values +// CHECK-SAME: (ptr %[[ARG:.*]]) { +// CHECK-NEXT: %[[V0:.*]] = load ptr addrspace(1), ptr %[[ARG]], align 8 +// CHECK-NEXT: store ptr addrspace(1) %[[V0]], ptr %[[ARG]], align 8 +// CHECK-NEXT: ret void +// CHECK-NEXT: } +llvm.func @llvm_ops_with_ptr_values(%arg0: !llvm.ptr) { + %1 = llvm.load %arg0 : !llvm.ptr -> !ptr.ptr<#llvm.address_space<1>> + llvm.store %1, %arg0 : !ptr.ptr<#llvm.address_space<1>>, !llvm.ptr + llvm.return +} diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp index af5f1a3a699a0..fe1e9166a3099 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -386,14 +386,14 @@ TestOpAsmAttrInterfaceAttr::getAlias(::llvm::raw_ostream &os) const { bool TestConstMemorySpaceAttr::isValidLoad( Type type, mlir::ptr::AtomicOrdering ordering, - std::optional alignment, + std::optional alignment, const ::mlir::DataLayout *dataLayout, function_ref emitError) const { return true; } bool TestConstMemorySpaceAttr::isValidStore( Type type, mlir::ptr::AtomicOrdering ordering, - std::optional alignment, + std::optional alignment, const ::mlir::DataLayout *dataLayout, function_ref emitError) const { if (emitError) emitError() << "memory space is read-only"; @@ -402,7 +402,7 @@ bool TestConstMemorySpaceAttr::isValidStore( bool TestConstMemorySpaceAttr::isValidAtomicOp( mlir::ptr::AtomicBinOp binOp, Type type, mlir::ptr::AtomicOrdering ordering, - std::optional alignment, + std::optional alignment, const ::mlir::DataLayout *dataLayout, function_ref emitError) const { if (emitError) emitError() << "memory space is read-only"; @@ -412,6 +412,7 @@ bool TestConstMemorySpaceAttr::isValidAtomicOp( bool TestConstMemorySpaceAttr::isValidAtomicXchg( Type type, mlir::ptr::AtomicOrdering successOrdering, mlir::ptr::AtomicOrdering failureOrdering, std::optional alignment, + const ::mlir::DataLayout *dataLayout, function_ref emitError) const { if (emitError) emitError() << "memory space is read-only";