diff --git a/mlir/include/mlir/Dialect/UB/IR/UBMatchers.h b/mlir/include/mlir/Dialect/UB/IR/UBMatchers.h new file mode 100644 index 0000000000000..9dc29ac37691c --- /dev/null +++ b/mlir/include/mlir/Dialect/UB/IR/UBMatchers.h @@ -0,0 +1,52 @@ +//===- UBMatchers.h - UB Dialect matchers -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides matchers for the UB dialect, in particular for matching +// poison values and attributes. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_UB_IR_UBMATCHERS_H +#define MLIR_DIALECT_UB_IR_UBMATCHERS_H + +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/Matchers.h" + +namespace mlir::ub { +namespace detail { + +/// Matches a poison attribute (any attribute implementing PoisonAttrInterface). +/// Supports matching against both Attribute and Operation* (via constant +/// folding). +struct poison_attr_matcher { + bool match(Attribute attr) { return isa(attr); } + + bool match(Operation *op) { + Attribute attr; + if (!::mlir::detail::constant_op_binder(&attr).match(op)) + return false; + return match(attr); + } +}; + +} // namespace detail + +/// Matches a poison constant (any attribute implementing PoisonAttrInterface). +/// Works with `matchPattern` on Value, Operation*, and Attribute. +/// +/// Examples: +/// matchPattern(value, ub::m_Poison()) // Matches ub.poison op via Value. +/// matchPattern(op, ub::m_Poison()) // Matches ub.poison op directly. +/// matchPattern(attr, ub::m_Poison()) // Matches PoisonAttr(Interface). +inline detail::poison_attr_matcher m_Poison() { + return detail::poison_attr_matcher(); +} + +} // namespace mlir::ub + +#endif // MLIR_DIALECT_UB_IR_UBMATCHERS_H diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index b99f77fdc8b30..73fa0df80bcba 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -13,7 +13,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/CommonFolders.h" -#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/Dialect/UB/IR/UBMatchers.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" @@ -460,7 +460,7 @@ arith::AddUIExtendedOp::fold(FoldAdaptor adaptor, adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) + b; })) { // If any operand is poison, propagate poison to both results. - if (isa(sumAttr)) { + if (matchPattern(sumAttr, ub::m_Poison())) { results.push_back(sumAttr); results.push_back(sumAttr); return success(); @@ -1919,7 +1919,7 @@ OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) { return {}; /// Bitcast poison. - if (llvm::isa(operand)) + if (matchPattern(operand, ub::m_Poison())) return ub::PoisonAttr::get(getContext()); /// Bitcast integer or float to integer or float. @@ -2503,10 +2503,10 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) { return falseVal; // If either operand is fully poisoned, return the other. - if (isa_and_nonnull(adaptor.getTrueValue())) + if (matchPattern(adaptor.getTrueValue(), ub::m_Poison())) return falseVal; - if (isa_and_nonnull(adaptor.getFalseValue())) + if (matchPattern(adaptor.getFalseValue(), ub::m_Poison())) return trueVal; // select %x, true, false => %x diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index ae65afac6b54c..d84408c024e25 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -27,7 +27,7 @@ #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Transform/Utils/Utils.h" -#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/Dialect/UB/IR/UBMatchers.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" @@ -2199,7 +2199,7 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter, for (auto const &[untypedAttr, elementOrTensorType] : llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) { - if (isa(untypedAttr)) { + if (matchPattern(untypedAttr, ub::m_Poison())) { paddingValues.push_back(untypedAttr); continue; } @@ -2452,7 +2452,7 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter, auto attr = dyn_cast(untypedAttr); Type elementType = getElementTypeOrSelf(elementOrTensorType); - if (isa(untypedAttr)) { + if (matchPattern(untypedAttr, ub::m_Poison())) { paddingValues.push_back(untypedAttr); continue; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp index 52ab92f180575..ab2629b41a463 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp @@ -11,7 +11,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/Dialect/UB/IR/UBMatchers.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/BuiltinAttributes.h" @@ -235,7 +235,7 @@ static Value padOperand(OpBuilder &builder, TilingInterface opToPad, paddingValue = complex::ConstantOp::create(builder, opToPad.getLoc(), complexTy, complexAttr); } - } else if (isa(paddingValueAttr)) { + } else if (matchPattern(paddingValueAttr, ub::m_Poison())) { paddingValue = ub::PoisonOp::create(builder, opToPad.getLoc(), getElementTypeOrSelf(v.getType())); } else if (auto typedAttr = dyn_cast(paddingValueAttr)) { diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index a3077d7313b93..7afcd7d3f88fb 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -20,7 +20,7 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/Dialect/UB/IR/UBMatchers.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Utils/VerificationUtils.h" @@ -497,7 +497,7 @@ void VectorDialect::initialize() { Operation *VectorDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - if (isa(value)) + if (matchPattern(value, ub::m_Poison())) return value.getDialect().materializeConstant(builder, value, type, loc); return arith::ConstantOp::materialize(builder, value, type, loc); @@ -2112,7 +2112,7 @@ static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, /// Fold a vector extract from is a poison source. static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) { - if (isa_and_nonnull(srcAttr)) + if (matchPattern(srcAttr, ub::m_Poison())) return srcAttr; return {}; @@ -2719,7 +2719,7 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, ArrayRef elements) { // Check for null or poison attributes before any processing. if (llvm::any_of(elements, [](Attribute attr) { - return !attr || isa(attr); + return !attr || matchPattern(attr, ub::m_Poison()); })) return {}; @@ -3168,7 +3168,7 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { } if (auto attr = llvm::dyn_cast(adaptor.getSource())) return DenseElementsAttr::get(vectorType, attr.getSplatValue()); - if (llvm::dyn_cast(adaptor.getSource())) + if (matchPattern(adaptor.getSource(), ub::m_Poison())) return ub::PoisonAttr::get(getContext()); return {}; } @@ -3323,8 +3323,8 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) { return {}; // Fold shuffle poison, poison -> poison. - bool isV1Poison = isa(v1Attr); - bool isV2Poison = isa(v2Attr); + bool isV1Poison = matchPattern(v1Attr, ub::m_Poison()); + bool isV2Poison = matchPattern(v2Attr, ub::m_Poison()); if (isV1Poison && isV2Poison) return ub::PoisonAttr::get(getContext()); @@ -4092,7 +4092,8 @@ class InsertStridedSliceConstantFolder final return failure(); // TODO: Support poison. - if (isa(vectorDestCst) || isa(sourceCst)) + if (matchPattern(vectorDestCst, ub::m_Poison()) || + matchPattern(sourceCst, ub::m_Poison())) return failure(); // TODO: Handle non-unit strides when they become available. @@ -6584,7 +6585,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { return denseAttr.reshape(getType()); // shape_cast(poison) -> poison - if (llvm::dyn_cast_if_present(adaptor.getSource())) + if (matchPattern(adaptor.getSource(), ub::m_Poison())) return ub::PoisonAttr::get(getContext()); return {}; @@ -6940,7 +6941,7 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) { return splat.reshape(getResultVectorType()); // Eliminate poison transpose ops. - if (llvm::dyn_cast_if_present(adaptor.getVector())) + if (matchPattern(adaptor.getVector(), ub::m_Poison())) return ub::PoisonAttr::get(getContext()); // Eliminate identity transposes, and more generally any transposes that diff --git a/mlir/test/IR/test-matchers.mlir b/mlir/test/IR/test-matchers.mlir index 31f1b6de18c4e..589ea58c38470 100644 --- a/mlir/test/IR/test-matchers.mlir +++ b/mlir/test/IR/test-matchers.mlir @@ -52,3 +52,13 @@ func.func @test3(%a: f32) -> f32 { // CHECK-LABEL: test3 // CHECK: Pattern mul(*, add(*, m_Op("test.name"))) matched // CHECK: Pattern m_Attr("fastmath") matched and bound value to: fast + +func.func @test4(%a: f32) -> f32 { + %0 = ub.poison : f32 + %1 = arith.constant 1.0 : f32 + %2 = arith.addf %a, %1 : f32 + return %2 : f32 +} + +// CHECK-LABEL: test4 +// CHECK: Pattern m_Poison() matched 1 times diff --git a/mlir/test/lib/IR/TestMatchers.cpp b/mlir/test/lib/IR/TestMatchers.cpp index 60b962bbac091..fdd786f4c0d3a 100644 --- a/mlir/test/lib/IR/TestMatchers.cpp +++ b/mlir/test/lib/IR/TestMatchers.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/UB/IR/UBMatchers.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Matchers.h" #include "mlir/Interfaces/FunctionInterfaces.h" @@ -132,7 +133,7 @@ static void test1(FunctionOpInterface f) { << countMatches(f, p17) << " times\n"; } -void test2(FunctionOpInterface f) { +static void test2(FunctionOpInterface f) { auto a = m_Val(f.getArgument(0)); FloatAttr floatAttr; auto p = @@ -148,7 +149,7 @@ void test2(FunctionOpInterface f) { llvm::outs() << "Pattern add(add(a, constant), a) matched\n"; } -void test3(FunctionOpInterface f) { +static void test3(FunctionOpInterface f) { arith::FastMathFlagsAttr fastMathAttr; auto p = m_Op(m_Any(), m_Op(m_Any(), m_Op("test.name"))); @@ -163,6 +164,12 @@ void test3(FunctionOpInterface f) { << fastMathAttr.getValue() << "\n"; } +static void test4(FunctionOpInterface f) { + auto poison = ub::m_Poison(); + llvm::outs() << "Pattern m_Poison() matched " << countMatches(f, poison) + << " times\n"; +} + void TestMatchers::runOnOperation() { auto f = getOperation(); llvm::outs() << f.getName() << "\n"; @@ -172,6 +179,8 @@ void TestMatchers::runOnOperation() { test2(f); if (f.getName() == "test3") test3(f); + if (f.getName() == "test4") + test4(f); } namespace mlir {