diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td index 3be84ae654f6a..20dd45272898d 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td @@ -19,7 +19,14 @@ def MemRef_Dialect : Dialect { manipulation ops, which are not strongly associated with any particular other dialect or domain abstraction. }]; - let dependentDialects = ["arith::ArithDialect"]; + let dependentDialects = [ + // `arith` is a dependency because it is used to materialize constants, + // and in some canonicalization patterns. + "arith::ArithDialect", + // `ub` is a dependency because `AllocaOp::getDefaultValue` can produce a + // `ub.poison` value. + "ub::UBDialect" + ]; let hasConstantMaterializer = 1; } diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt index 1382c7aceea79..d358362f1984b 100644 --- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt @@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRMemRefDialect MLIRMemorySlotInterfaces MLIRShapedOpInterfaces MLIRSideEffectInterfaces + MLIRUBDialect MLIRValueBoundsOpInterface MLIRViewLikeInterface ) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp index 6ff63df258c79..a1e3f10a871c1 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp index dfa2e4e0376ed..540423831937e 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" @@ -61,15 +62,8 @@ static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef shape, // Interfaces for AllocaOp //===----------------------------------------------------------------------===// -static bool isSupportedElementType(Type type) { - return llvm::isa(type) || - OpBuilder(type.getContext()).getZeroAttr(type); -} - SmallVector memref::AllocaOp::getPromotableSlots() { MemRefType type = getType(); - if (!isSupportedElementType(type.getElementType())) - return {}; if (!type.hasStaticShape()) return {}; // Make sure the memref contains only a single element. @@ -81,16 +75,7 @@ SmallVector memref::AllocaOp::getPromotableSlots() { Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) { - assert(isSupportedElementType(slot.elemType)); - // TODO: support more types. - return TypeSwitch(slot.elemType) - .Case([&](MemRefType t) { - return memref::AllocaOp::create(builder, getLoc(), t); - }) - .Default([&](Type t) { - return arith::ConstantOp::create(builder, getLoc(), t, - builder.getZeroAttr(t)); - }); + return ub::PoisonOp::create(builder, getLoc(), slot.elemType); } std::optional diff --git a/mlir/test/Dialect/MemRef/mem2reg.mlir b/mlir/test/Dialect/MemRef/mem2reg.mlir index d300699f6f342..dd68675cc4441 100644 --- a/mlir/test/Dialect/MemRef/mem2reg.mlir +++ b/mlir/test/Dialect/MemRef/mem2reg.mlir @@ -18,7 +18,7 @@ func.func @basic() -> i32 { // CHECK-LABEL: func.func @basic_default func.func @basic_default() -> i32 { // CHECK-NOT: = memref.alloca - // CHECK: %[[RES:.*]] = arith.constant 0 : i32 + // CHECK: %[[RES:.*]] = ub.poison : i32 // CHECK-NOT: = memref.alloca %0 = arith.constant 5 : i32 %1 = memref.alloca() : memref