diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 2cf521ea7ff37..5e39bfa0e5c2b 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -122,6 +122,9 @@ /mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @banach-space @dcaballe @MaheshRavishankar @nicolasvasilache /mlir/**/*EmulateNarrowType* @dcaballe +# Polygeist dialect in MLIR +/mlir/**/*Polygeist* @wsmoses @ftynse @ivanradanov @chelini + # Presburger library in MLIR /mlir/**/*Presburger* @Groverkss @Superty diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt index d2505877e2dd0..1750bcfe1da54 100644 --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -27,6 +27,7 @@ add_subdirectory(OpenACCMPCommon) add_subdirectory(OpenMP) add_subdirectory(PDL) add_subdirectory(PDLInterp) +add_subdirectory(Polygeist) add_subdirectory(Ptr) add_subdirectory(Quant) add_subdirectory(SCF) diff --git a/mlir/include/mlir/Dialect/Polygeist/CMakeLists.txt b/mlir/include/mlir/Dialect/Polygeist/CMakeLists.txt new file mode 100644 index 0000000000000..f33061b2d87cf --- /dev/null +++ b/mlir/include/mlir/Dialect/Polygeist/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/include/mlir/Dialect/Polygeist/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Polygeist/IR/CMakeLists.txt new file mode 100644 index 0000000000000..d21dc0f83acd1 --- /dev/null +++ b/mlir/include/mlir/Dialect/Polygeist/IR/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS PolygeistOps.td) +add_mlir_dialect(PolygeistOps polygeist) +add_mlir_doc(PolygeistOps PolygeistOps Dialects/ -gen-dialect-doc) diff --git a/mlir/include/mlir/Dialect/Polygeist/IR/Polygeist.h b/mlir/include/mlir/Dialect/Polygeist/IR/Polygeist.h new file mode 100644 index 0000000000000..edc92fc9e2c47 --- /dev/null +++ b/mlir/include/mlir/Dialect/Polygeist/IR/Polygeist.h @@ -0,0 +1,15 @@ +#ifndef MLIR_DIALECT_POLYGEIST_IR_POLYGEIST_H_ +#define MLIR_DIALECT_POLYGEIST_IR_POLYGEIST_H_ + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ViewLikeInterface.h" + +#include "mlir/Dialect/Polygeist/IR/PolygeistOpsDialect.h.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/Polygeist/IR/PolygeistOps.h.inc" + +#endif // MLIR_DIALECT_POLYGEIST_IR_POLYGEIST_H_ diff --git a/mlir/include/mlir/Dialect/Polygeist/IR/PolygeistBase.td b/mlir/include/mlir/Dialect/Polygeist/IR/PolygeistBase.td new file mode 100644 index 0000000000000..a324a60634ad5 --- /dev/null +++ b/mlir/include/mlir/Dialect/Polygeist/IR/PolygeistBase.td @@ -0,0 +1,15 @@ +#ifndef POLYGEIST_BASE +#define POLYGEIST_BASE + +include "mlir/IR/OpBase.td" + +def Polygeist_Dialect : Dialect { + let name = "polygeist"; + let cppNamespace = "::mlir::polygeist"; + let summary = "The Polygeist dialect."; + let description = [{ + The Polygeist dialect contains operations for raising low-level code to higher-level forms, and performing parallel and device transformations (including polyhedral). + }]; +} + +#endif // POLYGEIST_BASE diff --git a/mlir/include/mlir/Dialect/Polygeist/IR/PolygeistOps.td b/mlir/include/mlir/Dialect/Polygeist/IR/PolygeistOps.td new file mode 100644 index 0000000000000..a872d0516705f --- /dev/null +++ b/mlir/include/mlir/Dialect/Polygeist/IR/PolygeistOps.td @@ -0,0 +1,40 @@ +#ifndef POLYGEIST_OPS +#define POLYGEIST_OPS + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/ViewLikeInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Dialect/Polygeist/IR/PolygeistBase.td" + +def Memref2PointerOp + : Op { + let summary = "Extract an LLVM pointer from a MemRef"; + + let arguments = (ins AnyMemRef:$source); + let results = (outs LLVM_AnyPointer:$result); + + let hasFolder = 1; + let hasCanonicalizer = 1; + + let extraClassDeclaration = [{ + ::mlir::Value getViewSource() { return getSource(); } + }]; +} + +def Pointer2MemrefOp + : Op { + let summary = "Upgrade a pointer to a memref"; + + let arguments = (ins LLVM_AnyPointer:$source); + let results = (outs AnyMemRef:$result); + + let hasFolder = 1; + let hasCanonicalizer = 1; + + let extraClassDeclaration = [{ + ::mlir::Value getViewSource() { return getSource(); } + }]; +} + +#endif // POLYGEIST_OPS diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt index 66f68c369f81f..ee2838d8aaa70 100644 --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -27,6 +27,7 @@ add_subdirectory(OpenACCMPCommon) add_subdirectory(OpenMP) add_subdirectory(PDL) add_subdirectory(PDLInterp) +add_subdirectory(Polygeist) add_subdirectory(Ptr) add_subdirectory(Quant) add_subdirectory(SCF) diff --git a/mlir/lib/Dialect/Polygeist/CMakeLists.txt b/mlir/lib/Dialect/Polygeist/CMakeLists.txt new file mode 100644 index 0000000000000..f33061b2d87cf --- /dev/null +++ b/mlir/lib/Dialect/Polygeist/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/lib/Dialect/Polygeist/IR/CMakeLists.txt b/mlir/lib/Dialect/Polygeist/IR/CMakeLists.txt new file mode 100644 index 0000000000000..647c83a18e6e3 --- /dev/null +++ b/mlir/lib/Dialect/Polygeist/IR/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_dialect_library(MLIRPolygeistDialect + PolygeistOps.cpp + PolygeistDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Polygeist + + DEPENDS + MLIRPolygeistOpsIncGen + + LINK_LIBS PUBLIC + MLIRDialect + MLIRIR + MLIRMemRefDialect + MLIRLLVMDialect + MLIRArithDialect + MLIRAffineDialect + MLIRSCFDialect + ) diff --git a/mlir/lib/Dialect/Polygeist/IR/PolygeistDialect.cpp b/mlir/lib/Dialect/Polygeist/IR/PolygeistDialect.cpp new file mode 100644 index 0000000000000..a50fe4e2aec01 --- /dev/null +++ b/mlir/lib/Dialect/Polygeist/IR/PolygeistDialect.cpp @@ -0,0 +1,15 @@ +#include "mlir/Dialect/Polygeist/IR/Polygeist.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" + +using namespace mlir; +using namespace mlir::polygeist; + +#include "mlir/Dialect/Polygeist/IR/PolygeistOpsDialect.cpp.inc" + +void PolygeistDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/Polygeist/IR/PolygeistOps.cpp.inc" + >(); +} diff --git a/mlir/lib/Dialect/Polygeist/IR/PolygeistOps.cpp b/mlir/lib/Dialect/Polygeist/IR/PolygeistOps.cpp new file mode 100644 index 0000000000000..137ae4b65aa5d --- /dev/null +++ b/mlir/lib/Dialect/Polygeist/IR/PolygeistOps.cpp @@ -0,0 +1,379 @@ +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Polygeist/IR/Polygeist.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/TypeSwitch.h" +#include + +using namespace mlir; +using namespace mlir::polygeist; + +#define GET_OP_CLASSES +#include "mlir/Dialect/Polygeist/IR/PolygeistOps.cpp.inc" + +namespace { +/// Simplify pointer2memref(memref2pointer(x)) to cast(x) +class Memref2Pointer2MemrefCast final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(Pointer2MemrefOp op, + PatternRewriter &rewriter) const override { + auto src = op.getSource().getDefiningOp(); + if (!src) + return failure(); + auto smt = cast(src.getSource().getType()); + auto omt = cast(op.getType()); + if (smt.getShape().size() != omt.getShape().size()) + return failure(); + for (unsigned i = 1; i < smt.getShape().size(); i++) { + if (smt.getShape()[i] != omt.getShape()[i]) + return failure(); + } + if (smt.getElementType() != omt.getElementType()) + return failure(); + if (smt.getMemorySpace() != omt.getMemorySpace()) + return failure(); + + rewriter.replaceOpWithNewOp(op, op.getType(), + src.getSource()); + return success(); + } +}; + +/// Simplify memref2pointer(pointer2memref(x)) to cast(x) +class Memref2PointerBitCast final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LLVM::BitcastOp op, + PatternRewriter &rewriter) const override { + auto src = op.getOperand().getDefiningOp(); + if (!src) + return failure(); + + rewriter.replaceOpWithNewOp(op, op.getType(), + src.getSource()); + return success(); + } +}; + +/// Simplify load(pointer2memref(gep(...(x)))) to load(x, idx) +template +class LoadStorePointer2MemrefGEP final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + SmallVector newIndex(T op, Value finalIndex, + PatternRewriter &rewriter) const; + + void createNewOp(T op, Value baseMemref, SmallVector vals, + PatternRewriter &rewriter) const; + + Value getMemref(T op) const; + + LogicalResult matchAndRewrite(T op, + PatternRewriter &rewriter) const override { + if (op.getMemRefType().getRank() != 1) + return failure(); + + auto src = getMemref(op).template getDefiningOp(); + if (!src) + return failure(); + + Type elementType = op.getMemRefType().getElementType(); + unsigned elementSize = elementType.isIntOrFloat() + ? elementType.getIntOrFloatBitWidth() / 8 + : 0; + if (elementSize == 0) + return failure(); + + SmallVector> gepOps; + Value ptr = src.getSource(); + + while (auto gep = ptr.getDefiningOp()) { + if (gep.getIndices().size() != 1) + break; + + unsigned gepElemSize = 1; + auto elemTy = gep.getElemType(); + if (elemTy.isIntOrFloat()) { + gepElemSize = elemTy.getIntOrFloatBitWidth() / 8; + } else if (auto arrayTy = dyn_cast(elemTy)) { + auto baseTy = arrayTy.getElementType(); + if (baseTy.isIntOrFloat()) { + gepElemSize = + (baseTy.getIntOrFloatBitWidth() / 8) * arrayTy.getNumElements(); + } else { + break; + } + } else { + break; + } + + gepOps.emplace_back(gep, gepElemSize); + ptr = gep.getBase(); + } + + if (gepOps.empty()) + return failure(); + + Location loc = op.getLoc(); + auto baseMemref = Pointer2MemrefOp::create( + rewriter, loc, cast(src.getType()), ptr); + + Value finalIndex = nullptr; + for (auto [gep, gepElemSize] : llvm::reverse(gepOps)) { + PointerUnion rawIdx = gep.getIndices()[0]; + Value idx = dyn_cast_if_present(rawIdx); + if (!idx) + idx = arith::ConstantIndexOp::create( + rewriter, loc, cast(rawIdx).getValue().getSExtValue()); + + if (auto constIdx = idx.getDefiningOp()) { + if ((constIdx.value() * gepElemSize) % elementSize != 0) { + return failure(); + } + } + + if (!idx.getType().isIndex()) { + idx = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(), + idx); + } + + unsigned gcd = std::gcd(gepElemSize, elementSize); + unsigned scaledGep = gepElemSize / gcd; + unsigned scaledElement = elementSize / gcd; + + Value scaledIdx = + (scaledGep != 1) + ? arith::MulIOp::create( + rewriter, loc, idx, + arith::ConstantIndexOp::create(rewriter, loc, scaledGep)) + : idx; + + Value elemOffset = + (scaledElement != 1) + ? arith::DivSIOp::create( + rewriter, loc, scaledIdx, + arith::ConstantIndexOp::create(rewriter, loc, scaledElement)) + : scaledIdx; + + if (finalIndex) + finalIndex = + arith::AddIOp::create(rewriter, loc, finalIndex, elemOffset); + else + finalIndex = elemOffset; + } + + createNewOp(op, baseMemref, newIndex(op, finalIndex, rewriter), rewriter); + return success(); + } +}; + +template <> +Value LoadStorePointer2MemrefGEP::getMemref( + memref::LoadOp op) const { + return op.getMemref(); +} + +template <> +Value LoadStorePointer2MemrefGEP::getMemref( + memref::StoreOp op) const { + return op.getMemref(); +} + +template <> +Value LoadStorePointer2MemrefGEP::getMemref( + affine::AffineLoadOp op) const { + return op.getMemref(); +} + +template <> +Value LoadStorePointer2MemrefGEP::getMemref( + affine::AffineStoreOp op) const { + return op.getMemref(); +} + +template <> +SmallVector LoadStorePointer2MemrefGEP::newIndex( + memref::LoadOp op, Value finalIndex, PatternRewriter &rewriter) const { + auto operands = llvm::to_vector(op.getIndices()); + operands[0] = + arith::AddIOp::create(rewriter, op.getLoc(), operands[0], finalIndex); + return operands; +} + +template <> +SmallVector LoadStorePointer2MemrefGEP::newIndex( + affine::AffineLoadOp op, Value finalIndex, + PatternRewriter &rewriter) const { + auto apply = affine::AffineApplyOp::create( + rewriter, op.getLoc(), op.getAffineMap(), op.getMapOperands()); + + SmallVector operands; + for (auto op : apply->getResults()) + operands.push_back(op); + operands[0] = + arith::AddIOp::create(rewriter, op.getLoc(), operands[0], finalIndex); + return operands; +} + +template <> +SmallVector LoadStorePointer2MemrefGEP::newIndex( + memref::StoreOp op, Value finalIndex, PatternRewriter &rewriter) const { + auto operands = llvm::to_vector(op.getIndices()); + operands[0] = + arith::AddIOp::create(rewriter, op.getLoc(), operands[0], finalIndex); + return operands; +} + +template <> +SmallVector LoadStorePointer2MemrefGEP::newIndex( + affine::AffineStoreOp op, Value finalIndex, + PatternRewriter &rewriter) const { + auto apply = affine::AffineApplyOp::create( + rewriter, op.getLoc(), op.getAffineMap(), op.getMapOperands()); + + SmallVector operands; + for (auto op : apply->getResults()) + operands.push_back(op); + operands[0] = + arith::AddIOp::create(rewriter, op.getLoc(), operands[0], finalIndex); + return operands; +} + +template <> +void LoadStorePointer2MemrefGEP::createNewOp( + memref::LoadOp op, Value baseMemref, SmallVector idxs, + PatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp(op, baseMemref, idxs); +} + +template <> +void LoadStorePointer2MemrefGEP::createNewOp( + affine::AffineLoadOp op, Value baseMemref, SmallVector idxs, + PatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp(op, baseMemref, idxs); +} + +template <> +void LoadStorePointer2MemrefGEP::createNewOp( + memref::StoreOp op, Value baseMemref, SmallVector idxs, + PatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp(op, op.getValue(), baseMemref, + idxs); +} + +template <> +void LoadStorePointer2MemrefGEP::createNewOp( + affine::AffineStoreOp op, Value baseMemref, SmallVector idxs, + PatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp(op, op.getValue(), baseMemref, + idxs); +} + +/// Simplify cast(pointer2memref(x)) to pointer2memref(x) +class Pointer2MemrefCast final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::CastOp op, + PatternRewriter &rewriter) const override { + auto src = op.getSource().getDefiningOp(); + if (!src) + return failure(); + + rewriter.replaceOpWithNewOp(op, op.getType(), + src.getSource()); + return success(); + } +}; + +/// Simplify memref2pointer(pointer2memref(x)) to cast(x) +class Pointer2Memref2PointerCast final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(Memref2PointerOp op, + PatternRewriter &rewriter) const override { + auto src = op.getSource().getDefiningOp(); + if (!src) + return failure(); + + rewriter.replaceOpWithNewOp(op, op.getType(), + src.getSource()); + return success(); + } +}; + +} // namespace + +void Memref2PointerOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert(context); +} + +OpFoldResult Memref2PointerOp::fold(FoldAdaptor adaptor) { + /// Simplify memref2pointer(cast(x)) to memref2pointer(x) + if (auto mc = getSource().getDefiningOp()) { + getSourceMutable().assign(mc.getSource()); + return getResult(); + } + if (auto mc = getSource().getDefiningOp()) { + if (mc.getSource().getType() == getType()) { + return mc.getSource(); + } + } + return nullptr; +} + +void Pointer2MemrefOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert, + LoadStorePointer2MemrefGEP, + LoadStorePointer2MemrefGEP, + LoadStorePointer2MemrefGEP>(context); +} + +OpFoldResult Pointer2MemrefOp::fold(FoldAdaptor adaptor) { + /// Simplify pointer2memref(cast(x)) to pointer2memref(x) + if (auto mc = getSource().getDefiningOp()) { + getSourceMutable().assign(mc.getOperand()); + return getResult(); + } + if (auto mc = getSource().getDefiningOp()) { + getSourceMutable().assign(mc.getOperand()); + return getResult(); + } + if (auto mc = getSource().getDefiningOp()) { + for (auto idx : mc.getDynamicIndices()) { + assert(idx); + if (!matchPattern(idx, m_Zero())) + return nullptr; + } + auto staticIndices = mc.getRawConstantIndices(); + for (auto pair : llvm::enumerate(staticIndices)) { + if (pair.value() != LLVM::GEPOp::kDynamicIndex) + if (pair.value() != 0) + return nullptr; + } + + getSourceMutable().assign(mc.getBase()); + return getResult(); + } + if (auto mc = getSource().getDefiningOp()) { + if (mc.getSource().getType() == getType()) { + return mc.getSource(); + } + } + return nullptr; +} diff --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp index 2f55296f424cd..456ea0610940d 100644 --- a/mlir/lib/RegisterAllDialects.cpp +++ b/mlir/lib/RegisterAllDialects.cpp @@ -66,6 +66,7 @@ #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" +#include "mlir/Dialect/Polygeist/IR/Polygeist.h" #include "mlir/Dialect/Ptr/IR/PtrDialect.h" #include "mlir/Dialect/Quant/IR/Quant.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -140,6 +141,7 @@ void mlir::registerAllDialects(DialectRegistry ®istry) { omp::OpenMPDialect, pdl::PDLDialect, pdl_interp::PDLInterpDialect, + polygeist::PolygeistDialect, ptr::PtrDialect, quant::QuantDialect, ROCDL::ROCDLDialect, diff --git a/mlir/test/Dialect/Polygeist/canonicalize-memref2pointer.mlir b/mlir/test/Dialect/Polygeist/canonicalize-memref2pointer.mlir new file mode 100644 index 0000000000000..9e86d78341879 --- /dev/null +++ b/mlir/test/Dialect/Polygeist/canonicalize-memref2pointer.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt --canonicalize -split-input-file %s | FileCheck %s + +// CHECK-LABEL: func @fold_memref2pointer_pointer2memref( +// CHECK-SAME: %[[PTR:.*]]: !llvm.ptr +func.func @fold_memref2pointer_pointer2memref(%ptr: !llvm.ptr) -> !llvm.ptr { + // CHECK-NOT: polygeist.pointer2memref + // CHECK-NOT: polygeist.memref2pointer + // CHECK: return %[[PTR]] + %memref = "polygeist.pointer2memref"(%ptr) : (!llvm.ptr) -> memref + %ptr2 = "polygeist.memref2pointer"(%memref) : (memref) -> !llvm.ptr + func.return %ptr2 : !llvm.ptr +} + +// ----- + +// CHECK-LABEL: func @fold_memref2pointer_cast( +// CHECK-SAME: %[[MEMREF:.*]]: memref<10xf64> +func.func @fold_memref2pointer_cast(%memref: memref<10xf64>) -> !llvm.ptr { + // CHECK: %[[RES:.*]] = "polygeist.memref2pointer"(%[[MEMREF]]) + // CHECK: return %[[RES]] + %cast = memref.cast %memref : memref<10xf64> to memref + %ptr = "polygeist.memref2pointer"(%cast) : (memref) -> !llvm.ptr + func.return %ptr : !llvm.ptr +} diff --git a/mlir/test/Dialect/Polygeist/canonicalize-pointer2memref.mlir b/mlir/test/Dialect/Polygeist/canonicalize-pointer2memref.mlir new file mode 100644 index 0000000000000..b9ad57199077b --- /dev/null +++ b/mlir/test/Dialect/Polygeist/canonicalize-pointer2memref.mlir @@ -0,0 +1,226 @@ +// RUN: mlir-opt --canonicalize -split-input-file %s | FileCheck %s + +llvm.mlir.global internal unnamed_addr constant @foo(dense<[1.5903078570611027E-10, -2.5050911383645487E-8, 2.7557314984630029E-6, -1.984126983447703E-4, 0.0083333333333293485, -0.16666666666666663, 0.000000e+00, 0.000000e+00, -1.1367817304626284E-11, 2.08758833785978E-9, -2.7557315542999557E-7, 2.4801587293618683E-5, -0.0013888888888880667, 0.041666666666666637, -5.000000e-01, 1.000000e+00]> : tensor<16xf64>) {addr_space = 1 : i32, alignment = 8 : i64, dso_local} : !llvm.array<16 x f64> + +// CHECK-LABEL: func @constant_load_idx( +// CHECK-SAME: %[[IDX:.*]]: i64 +func.func @constant_load_idx(%idx: i64) -> f64 { + // CHECK: %[[C9:.*]] = arith.constant 9 + %c8 = arith.constant 8 : index + + // CHECK: %[[BASE:.*]] = llvm.mlir.addressof + %ptr = llvm.mlir.addressof @foo : !llvm.ptr<1> + + // CHECK-NOT: llvm.getelementptr + %ptr_i8 = llvm.getelementptr inbounds %ptr[%idx] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f64 + %ptr_offset = llvm.getelementptr inbounds %ptr_i8[8] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i8 + + // CHECK: %[[MEMREF:.*]] = "polygeist.pointer2memref"(%[[BASE]]) + %memref = "polygeist.pointer2memref"(%ptr_offset) : (!llvm.ptr<1>) -> memref + + // CHECK: %[[IDX_CAST:.*]] = arith.index_cast %[[IDX]] + // CHECK: %[[OFFSET:.*]] = arith.addi %[[IDX_CAST]], %[[C9]] : index + // CHECK: memref.load %[[MEMREF]][%[[OFFSET]]] : memref + %val = memref.load %memref[%c8] : memref + + func.return %val : f64 +} + +// ----- + +// CHECK-LABEL: func @dynamic_load_idx( +// CHECK-SAME: %[[IDX1:.*]]: i64 +// CHECK-SAME: %[[IDX2:.*]]: i32 +// CHECK-SAME: %[[IDX3:.*]]: i8 +// CHECK-SAME: %[[LOAD_IDX:.*]]: index +func.func @dynamic_load_idx(%idx1: i64, %idx2: i32, %idx3: i8, %load_idx: index) -> f16 { + // CHECK: %[[C2:.*]] = arith.constant 2 + %c16 = llvm.mlir.constant(16: i32) : i32 + + // CHECK: %[[BASE:.*]] = llvm.alloca + %ptr = llvm.alloca %c16 x f16 {alignment = 4 : i64} : (i32) -> !llvm.ptr + + // CHECK-NOT: llvm.getelementptr + %ptr_f32 = llvm.getelementptr %ptr[%idx1] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + %ptr_i16 = llvm.getelementptr %ptr_f32[%idx2] : (!llvm.ptr, i32) -> !llvm.ptr, i16 + %ptr_i8 = llvm.getelementptr inbounds %ptr_i16[%idx3] : (!llvm.ptr, i8) -> !llvm.ptr, i8 + + // CHECK: %[[MEMREF:.*]] = "polygeist.pointer2memref"(%[[BASE]]) + %memref = "polygeist.pointer2memref"(%ptr_i8) : (!llvm.ptr) -> memref + + // CHECK: %[[IDX_CAST1:.*]] = arith.index_cast %[[IDX1]] + // CHECK: %[[SCALED1:.*]] = arith.muli %[[IDX_CAST1]], %[[C2]] : index + + // CHECK: %[[IDX_CAST2:.*]] = arith.index_cast %[[IDX2]] + // CHECK: %[[OFFSET2:.*]] = arith.addi %[[SCALED1]], %[[IDX_CAST2]] : index + + // CHECK: %[[IDX_CAST3:.*]] = arith.index_cast %[[IDX3]] + // CHECK: %[[SCALED3:.*]] = arith.divsi %[[IDX_CAST3]], %[[C2]] : index + // CHECK: %[[OFFSET3:.*]] = arith.addi %[[OFFSET2]], %[[SCALED3]] : index + + // CHECK: %[[EIDX:.*]] = arith.addi %[[LOAD_IDX]], %[[OFFSET3]] : index + + // CHECK: memref.load %[[MEMREF]][%[[EIDX]]] : memref + %val = memref.load %memref[%load_idx] : memref + + func.return %val : f16 +} + +// ----- + +// CHECK-LABEL: func @reject_unaligned_gep_cst_idx( +// CHECK-SAME: %[[LOAD_IDX:.*]]: index +func.func @reject_unaligned_gep_cst_idx(%load_idx: index) -> f16 { + %c1 = arith.constant 1 : i32 + %c16 = llvm.mlir.constant(16: i32) : i32 + + %ptr = llvm.alloca %c16 x f16 {alignment = 4 : i64} : (i32) -> !llvm.ptr + + // CHECK: %[[GEP:.*]] = llvm.getelementptr + %ptr_i8 = llvm.getelementptr %ptr[%c1] : (!llvm.ptr, i32) -> !llvm.ptr, i8 + + // CHECK: %[[MEMREF:.*]] = "polygeist.pointer2memref"(%[[GEP]]) + %memref = "polygeist.pointer2memref"(%ptr_i8) : (!llvm.ptr) -> memref + + // CHECK: memref.load %[[MEMREF]][%[[LOAD_IDX]]] : memref + %val = memref.load %memref[%load_idx] : memref + + func.return %val : f16 +} + +// ----- + +// CHECK-LABEL: func @reject_unaligned_gep_cst_scalar( +// CHECK-SAME: %[[LOAD_IDX:.*]]: index +func.func @reject_unaligned_gep_cst_scalar(%load_idx: index) -> f16 { + %c16 = llvm.mlir.constant(16: i32) : i32 + + %ptr = llvm.alloca %c16 x f16 {alignment = 4 : i64} : (i32) -> !llvm.ptr + + // CHECK: %[[GEP:.*]] = llvm.getelementptr + %ptr_i8 = llvm.getelementptr %ptr[1] : (!llvm.ptr) -> !llvm.ptr, i8 + + // CHECK: %[[MEMREF:.*]] = "polygeist.pointer2memref"(%[[GEP]]) + %memref = "polygeist.pointer2memref"(%ptr_i8) : (!llvm.ptr) -> memref + + // CHECK: memref.load %[[MEMREF]][%[[LOAD_IDX]]] : memref + %val = memref.load %memref[%load_idx] : memref + + func.return %val : f16 +} + +// ----- + +// CHECK-LABEL: func @array_load_idx( +// CHECK-SAME: %[[IDX:.*]]: i64 +func.func @array_load_idx(%idx: i64) -> f64 { + %c0 = arith.constant 0 : index + %c16 = llvm.mlir.constant(16: i32) : i32 + + %ptr = llvm.alloca %c16 x !llvm.array<8 x i8> {alignment = 8 : i64} : (i32) -> !llvm.ptr + + // CHECK-NOT: llvm.getelementptr + %ptr_array = llvm.getelementptr inbounds %ptr[%idx] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.array<8 x i8> + + // CHECK: %[[MEMREF:.*]] = "polygeist.pointer2memref" + %memref = "polygeist.pointer2memref"(%ptr_array) : (!llvm.ptr) -> memref + + // CHECK: %[[IDX_CAST:.*]] = arith.index_cast %[[IDX]] + // CHECK: memref.load %[[MEMREF]][%[[IDX_CAST]]] : memref + %val = memref.load %memref[%c0] : memref + + func.return %val : f64 +} + +// ----- + +llvm.mlir.global internal unnamed_addr constant @foo(dense<[1.5903078570611027E-10, -2.5050911383645487E-8, 2.7557314984630029E-6, -1.984126983447703E-4, 0.0083333333333293485, -0.16666666666666663, 0.000000e+00, 0.000000e+00, -1.1367817304626284E-11, 2.08758833785978E-9, -2.7557315542999557E-7, 2.4801587293618683E-5, -0.0013888888888880667, 0.041666666666666637, -5.000000e-01, 1.000000e+00]> : tensor<16xf64>) {addr_space = 1 : i32, alignment = 8 : i64, dso_local} : !llvm.array<16 x f64> + +// CHECK-LABEL: func @constant_store_idx( +// CHECK-SAME: %[[IDX:.*]]: i64, %[[VAL:.*]]: f64 +func.func @constant_store_idx(%idx: i64, %val: f64) { + // CHECK: %[[C9:.*]] = arith.constant 9 + %c8 = arith.constant 8 : index + + // CHECK: %[[BASE:.*]] = llvm.mlir.addressof + %ptr = llvm.mlir.addressof @foo : !llvm.ptr<1> + + // CHECK-NOT: llvm.getelementptr + %ptr_i8 = llvm.getelementptr inbounds %ptr[%idx] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f64 + %ptr_offset = llvm.getelementptr inbounds %ptr_i8[8] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i8 + + // CHECK: %[[MEMREF:.*]] = "polygeist.pointer2memref"(%[[BASE]]) + %memref = "polygeist.pointer2memref"(%ptr_offset) : (!llvm.ptr<1>) -> memref + + // CHECK: %[[IDX_CAST:.*]] = arith.index_cast %[[IDX]] + // CHECK: %[[OFFSET:.*]] = arith.addi %[[IDX_CAST]], %[[C9]] : index + // CHECK: memref.store %[[VAL]], %[[MEMREF]][%[[OFFSET]]] : memref + memref.store %val, %memref[%c8] : memref + + func.return +} + +// ----- + +// CHECK-LABEL: func @dynamic_store_idx( +// CHECK-SAME: %[[IDX1:.*]]: i64 +// CHECK-SAME: %[[IDX2:.*]]: i32 +// CHECK-SAME: %[[IDX3:.*]]: i8 +// CHECK-SAME: %[[LOAD_IDX:.*]]: index +// CHECK-SAME: %[[VAL:.*]]: f16 +func.func @dynamic_store_idx(%idx1: i64, %idx2: i32, %idx3: i8, %load_idx: index, %val: f16) { + // CHECK: %[[C2:.*]] = arith.constant 2 + %c16 = llvm.mlir.constant(16: i32) : i32 + + // CHECK: %[[BASE:.*]] = llvm.alloca + %ptr = llvm.alloca %c16 x f16 {alignment = 4 : i64} : (i32) -> !llvm.ptr + + // CHECK-NOT: llvm.getelementptr + %ptr_f32 = llvm.getelementptr %ptr[%idx1] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + %ptr_i16 = llvm.getelementptr %ptr_f32[%idx2] : (!llvm.ptr, i32) -> !llvm.ptr, i16 + %ptr_i8 = llvm.getelementptr inbounds %ptr_i16[%idx3] : (!llvm.ptr, i8) -> !llvm.ptr, i8 + + // CHECK: %[[MEMREF:.*]] = "polygeist.pointer2memref"(%[[BASE]]) + %memref = "polygeist.pointer2memref"(%ptr_i8) : (!llvm.ptr) -> memref + + // CHECK: %[[IDX_CAST1:.*]] = arith.index_cast %[[IDX1]] + // CHECK: %[[SCALED1:.*]] = arith.muli %[[IDX_CAST1]], %[[C2]] : index + + // CHECK: %[[IDX_CAST2:.*]] = arith.index_cast %[[IDX2]] + // CHECK: %[[OFFSET2:.*]] = arith.addi %[[SCALED1]], %[[IDX_CAST2]] : index + + // CHECK: %[[IDX_CAST3:.*]] = arith.index_cast %[[IDX3]] + // CHECK: %[[SCALED3:.*]] = arith.divsi %[[IDX_CAST3]], %[[C2]] : index + // CHECK: %[[OFFSET3:.*]] = arith.addi %[[OFFSET2]], %[[SCALED3]] : index + + // CHECK: %[[EIDX:.*]] = arith.addi %[[LOAD_IDX]], %[[OFFSET3]] : index + + // CHECK: memref.store %[[VAL]], %[[MEMREF]][%[[EIDX]]] : memref + memref.store %val, %memref[%load_idx] : memref + + func.return +} + +// ----- + +// CHECK-LABEL: func @array_store_idx( +// CHECK-SAME: %[[IDX:.*]]: i64 +// CHECK-SAME: %[[VAL:.*]]: f64 +func.func @array_store_idx(%idx: i64, %val: f64) { + %c0 = arith.constant 0 : index + %c16 = llvm.mlir.constant(16: i32) : i32 + + %ptr = llvm.alloca %c16 x !llvm.array<8 x i8> {alignment = 8 : i64} : (i32) -> !llvm.ptr + + // CHECK-NOT: llvm.getelementptr + %ptr_array = llvm.getelementptr inbounds %ptr[%idx] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.array<8 x i8> + + // CHECK: %[[MEMREF:.*]] = "polygeist.pointer2memref" + %memref = "polygeist.pointer2memref"(%ptr_array) : (!llvm.ptr) -> memref + + // CHECK: %[[IDX_CAST:.*]] = arith.index_cast %[[IDX]] + // CHECK: memref.store %[[VAL]], %[[MEMREF]][%[[IDX_CAST]]] : memref + memref.store %val, %memref[%c0] : memref + + func.return +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index cf73b8d2c72da..d87bd208fed50 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6775,6 +6775,71 @@ gentbl_cc_library( deps = [":PDLInterpOpsTdFiles"], ) +td_library( + name = "PolygeistOpsTdFiles", + srcs = [ + "include/mlir/Dialect/Polygeist/IR/PolygeistBase.td", + "include/mlir/Dialect/Polygeist/IR/PolygeistOps.td", + ], + includes = ["include"], + deps = [ + ":BuiltinDialectTdFiles", + ":LLVMOpsTdFiles", + ":OpBaseTdFiles", + ":SideEffectInterfacesTdFiles", + ":ViewLikeInterfaceTdFiles", + ], +) + +gentbl_cc_library( + name = "PolygeistBaseIncGen", + tbl_outs = { + "include/mlir/Dialect/Polygeist/IR/PolygeistOpsDialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=polygeist", + ], + "include/mlir/Dialect/Polygeist/IR/PolygeistOpsDialect.cpp.inc": [ + "-gen-dialect-defs", + "-dialect=polygeist", + ], + }, + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Polygeist/IR/PolygeistBase.td", + deps = [":PolygeistOpsTdFiles"], +) + +gentbl_cc_library( + name = "PolygeistOpsIncGen", + tbl_outs = { + "include/mlir/Dialect/Polygeist/IR/PolygeistOps.h.inc": ["-gen-op-decls"], + "include/mlir/Dialect/Polygeist/IR/PolygeistOps.cpp.inc": ["-gen-op-defs"], + }, + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Polygeist/IR/PolygeistOps.td", + deps = [":PolygeistOpsTdFiles"], +) + +cc_library( + name = "PolygeistDialect", + srcs = glob(["lib/Dialect/Polygeist/IR/*.cpp"]), + hdrs = ["include/mlir/Dialect/Polygeist/IR/Polygeist.h"], + includes = ["include"], + deps = [ + ":AffineDialect", + ":ArithDialect", + ":IR", + ":LLVMDialect", + ":MemRefDialect", + ":PolygeistBaseIncGen", + ":PolygeistOpsIncGen", + ":SCFDialect", + ":SideEffectInterfaces", + ":Support", + ":ViewLikeInterface", + "//llvm:Support", + ], +) + td_library( name = "PtrTdFiles", srcs = [ @@ -9666,6 +9731,7 @@ cc_library( ":OpenMPDialect", ":PDLDialect", ":PDLInterpDialect", + ":PolygeistDialect", ":PtrDialect", ":QuantOps", ":ROCDLDialect",