3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Func/IR/FuncOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,7 @@ LogicalResult ConstantOp::verify() {
return success();
}

OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "constant has no operands");
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
return getValueAttr();
}

Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1286,12 +1286,12 @@ LogicalResult SubgroupMmaComputeOp::verify() {
return success();
}

LogicalResult MemcpyOp::fold(ArrayRef<Attribute> operands,
LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<::mlir::OpFoldResult> &results) {
return memref::foldMemRefCast(*this);
}

LogicalResult MemsetOp::fold(ArrayRef<Attribute> operands,
LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<::mlir::OpFoldResult> &results) {
return memref::foldMemRefCast(*this);
}
Expand Down
123 changes: 67 additions & 56 deletions mlir/lib/Dialect/Index/IR/IndexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,36 +115,40 @@ static OpFoldResult foldBinaryOpChecked(
// AddOp
//===----------------------------------------------------------------------===//

OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
return foldBinaryOpUnchecked(
operands, [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; });
adaptor.getOperands(),
[](const APInt &lhs, const APInt &rhs) { return lhs + rhs; });
}

//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//

OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
return foldBinaryOpUnchecked(
operands, [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; });
adaptor.getOperands(),
[](const APInt &lhs, const APInt &rhs) { return lhs - rhs; });
}

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
return foldBinaryOpUnchecked(
operands, [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; });
adaptor.getOperands(),
[](const APInt &lhs, const APInt &rhs) { return lhs * rhs; });
}

//===----------------------------------------------------------------------===//
// DivSOp
//===----------------------------------------------------------------------===//

OpFoldResult DivSOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
return foldBinaryOpChecked(
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
adaptor.getOperands(),
[](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
// Don't fold division by zero.
if (rhs.isZero())
return std::nullopt;
Expand All @@ -156,9 +160,10 @@ OpFoldResult DivSOp::fold(ArrayRef<Attribute> operands) {
// DivUOp
//===----------------------------------------------------------------------===//

OpFoldResult DivUOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
return foldBinaryOpChecked(
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
adaptor.getOperands(),
[](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
// Don't fold division by zero.
if (rhs.isZero())
return std::nullopt;
Expand Down Expand Up @@ -193,18 +198,19 @@ static Optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) {
return (n + x).sdiv(m) + 1;
}

OpFoldResult CeilDivSOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(operands, calculateCeilDivS);
OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) {
return foldBinaryOpChecked(adaptor.getOperands(), calculateCeilDivS);
}

//===----------------------------------------------------------------------===//
// CeilDivUOp
//===----------------------------------------------------------------------===//

OpFoldResult CeilDivUOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) {
// Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
return foldBinaryOpChecked(
operands, [](const APInt &n, const APInt &m) -> Optional<APInt> {
adaptor.getOperands(),
[](const APInt &n, const APInt &m) -> Optional<APInt> {
// Don't fold division by zero.
if (m.isZero())
return std::nullopt;
Expand Down Expand Up @@ -242,56 +248,58 @@ static Optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) {
return -1 - (x - n).sdiv(m);
}

OpFoldResult FloorDivSOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(operands, calculateFloorDivS);
OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) {
return foldBinaryOpChecked(adaptor.getOperands(), calculateFloorDivS);
}

//===----------------------------------------------------------------------===//
// RemSOp
//===----------------------------------------------------------------------===//

OpFoldResult RemSOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
return lhs.srem(rhs);
});
OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {
return foldBinaryOpChecked(
adaptor.getOperands(),
[](const APInt &lhs, const APInt &rhs) { return lhs.srem(rhs); });
}

//===----------------------------------------------------------------------===//
// RemUOp
//===----------------------------------------------------------------------===//

OpFoldResult RemUOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
return lhs.urem(rhs);
});
OpFoldResult RemUOp::fold(FoldAdaptor adaptor) {
return foldBinaryOpChecked(
adaptor.getOperands(),
[](const APInt &lhs, const APInt &rhs) { return lhs.urem(rhs); });
}

//===----------------------------------------------------------------------===//
// MaxSOp
//===----------------------------------------------------------------------===//

OpFoldResult MaxSOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
return lhs.sgt(rhs) ? lhs : rhs;
});
OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
return foldBinaryOpChecked(adaptor.getOperands(),
[](const APInt &lhs, const APInt &rhs) {
return lhs.sgt(rhs) ? lhs : rhs;
});
}

//===----------------------------------------------------------------------===//
// MaxUOp
//===----------------------------------------------------------------------===//

OpFoldResult MaxUOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
return lhs.ugt(rhs) ? lhs : rhs;
});
OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
return foldBinaryOpChecked(adaptor.getOperands(),
[](const APInt &lhs, const APInt &rhs) {
return lhs.ugt(rhs) ? lhs : rhs;
});
}

//===----------------------------------------------------------------------===//
// MinSOp
//===----------------------------------------------------------------------===//

OpFoldResult MinSOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
return foldBinaryOpChecked(adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) {
return lhs.slt(rhs) ? lhs : rhs;
});
}
Expand All @@ -300,8 +308,8 @@ OpFoldResult MinSOp::fold(ArrayRef<Attribute> operands) {
// MinUOp
//===----------------------------------------------------------------------===//

OpFoldResult MinUOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
return foldBinaryOpChecked(adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) {
return lhs.ult(rhs) ? lhs : rhs;
});
}
Expand All @@ -310,9 +318,10 @@ OpFoldResult MinUOp::fold(ArrayRef<Attribute> operands) {
// ShlOp
//===----------------------------------------------------------------------===//

OpFoldResult ShlOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
return foldBinaryOpUnchecked(
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
adaptor.getOperands(),
[](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
// We cannot fold if the RHS is greater than or equal to 32 because
// this would be UB in 32-bit systems but not on 64-bit systems. RHS is
// already treated as unsigned.
Expand All @@ -326,9 +335,10 @@ OpFoldResult ShlOp::fold(ArrayRef<Attribute> operands) {
// ShrSOp
//===----------------------------------------------------------------------===//

OpFoldResult ShrSOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
return foldBinaryOpChecked(
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
adaptor.getOperands(),
[](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
// Don't fold if RHS is greater than or equal to 32.
if (rhs.uge(32))
return {};
Expand All @@ -340,9 +350,10 @@ OpFoldResult ShrSOp::fold(ArrayRef<Attribute> operands) {
// ShrUOp
//===----------------------------------------------------------------------===//

OpFoldResult ShrUOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
return foldBinaryOpChecked(
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
adaptor.getOperands(),
[](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
// Don't fold if RHS is greater than or equal to 32.
if (rhs.uge(32))
return {};
Expand All @@ -354,27 +365,30 @@ OpFoldResult ShrUOp::fold(ArrayRef<Attribute> operands) {
// AndOp
//===----------------------------------------------------------------------===//

OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
return foldBinaryOpUnchecked(
operands, [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
adaptor.getOperands(),
[](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
}

//===----------------------------------------------------------------------===//
// OrOp
//===----------------------------------------------------------------------===//

OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
return foldBinaryOpUnchecked(
operands, [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
adaptor.getOperands(),
[](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
}

//===----------------------------------------------------------------------===//
// XOrOp
//===----------------------------------------------------------------------===//

OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
return foldBinaryOpUnchecked(
operands, [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
adaptor.getOperands(),
[](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -425,10 +439,9 @@ bool compareIndices(const APInt &lhs, const APInt &rhs,
llvm_unreachable("unhandled IndexCmpPredicate predicate");
}

OpFoldResult CmpOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "compare expected 2 operands");
auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
if (!lhs || !rhs)
return {};

Expand All @@ -453,9 +466,7 @@ void ConstantOp::getAsmResultNames(
setNameFn(getResult(), specialName.str());
}

OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
return getValueAttr();
}
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
build(b, state, b.getIndexType(), b.getIndexAttr(value));
Expand All @@ -465,7 +476,7 @@ void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
// BoolConstantOp
//===----------------------------------------------------------------------===//

OpFoldResult BoolConstantOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
return getValueAttr();
}

Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1441,7 +1441,7 @@ static Type getInsertExtractValueElementType(Type llvmType,
return llvmType;
}

OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
OpFoldResult result = {};
while (insertValueOp) {
Expand Down Expand Up @@ -2275,7 +2275,7 @@ LogicalResult LLVM::ConstantOp::verify() {
}

// Constant op constant-folds to its value.
OpFoldResult LLVM::ConstantOp::fold(ArrayRef<Attribute>) { return getValue(); }
OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); }

//===----------------------------------------------------------------------===//
// Utility functions for parsing atomic ops
Expand Down Expand Up @@ -2513,7 +2513,7 @@ LogicalResult FenceOp::verify() {
// Folder for LLVM::BitcastOp
//===----------------------------------------------------------------------===//

OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
// bitcast(x : T0, T0) -> x
if (getArg().getType() == getType())
return getArg();
Expand All @@ -2528,7 +2528,7 @@ OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) {
// Folder for LLVM::AddrSpaceCastOp
//===----------------------------------------------------------------------===//

OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) {
// addrcast(x : T0, T0) -> x
if (getArg().getType() == getType())
return getArg();
Expand All @@ -2543,9 +2543,9 @@ OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) {
// Folder for LLVM::GEPOp
//===----------------------------------------------------------------------===//

OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
GEPIndicesAdaptor<ArrayRef<Attribute>> indices(getRawConstantIndicesAttr(),
operands.drop_front());
adaptor.getDynamicIndices());

// gep %x:T, 0 -> %x
if (getBase().getType() == getType() && indices.size() == 1)
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -980,8 +980,7 @@ void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<EraseIdentityGenericOp>(context);
}

LogicalResult GenericOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
return memref::foldMemRefCast(*this);
}

Expand Down
42 changes: 22 additions & 20 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return false;
}

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

Expand Down Expand Up @@ -883,7 +883,7 @@ void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<FoldCopyOfCast, FoldSelfCopy>(context);
}

LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands,
LogicalResult CopyOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
/// copy(memrefcast) -> copy
bool folded = false;
Expand All @@ -902,7 +902,7 @@ LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands,
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
/// dealloc(memrefcast) -> dealloc
return foldMemRefCast(*this);
Expand Down Expand Up @@ -1056,9 +1056,9 @@ llvm::SmallBitVector SubViewOp::getDroppedDims() {
return *unusedDims;
}

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
// All forms of folding require a known index.
auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
if (!index)
return {};

Expand Down Expand Up @@ -1322,7 +1322,7 @@ LogicalResult DmaStartOp::verify() {
return success();
}

LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
/// dma_start(memrefcast) -> dma_start
return foldMemRefCast(*this);
Expand All @@ -1332,7 +1332,7 @@ LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
// DmaWaitOp
// ---------------------------------------------------------------------------

LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
/// dma_wait(memrefcast) -> dma_wait
return foldMemRefCast(*this);
Expand Down Expand Up @@ -1433,7 +1433,7 @@ static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
}

LogicalResult
ExtractStridedMetadataOp::fold(ArrayRef<Attribute> cstOperands,
ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
OpBuilder builder(*this);

Expand Down Expand Up @@ -1677,7 +1677,7 @@ LogicalResult LoadOp::verify() {
return success();
}

OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
/// load(memrefcast) -> load
if (succeeded(foldMemRefCast(*this)))
return getResult();
Expand Down Expand Up @@ -1747,7 +1747,7 @@ LogicalResult PrefetchOp::verify() {
return success();
}

LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
// prefetch(memrefcast) -> prefetch
return foldMemRefCast(*this);
Expand All @@ -1757,7 +1757,7 @@ LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
// RankOp
//===----------------------------------------------------------------------===//

OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
// Constant fold rank when the rank of the operand is known.
auto type = getOperand().getType();
auto shapedType = type.dyn_cast<ShapedType>();
Expand Down Expand Up @@ -1881,7 +1881,7 @@ LogicalResult ReinterpretCastOp::verify() {
return success();
}

OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
Value src = getSource();
auto getPrevSrc = [&]() -> Value {
// reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
Expand Down Expand Up @@ -2465,12 +2465,14 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
CollapseShapeOpMemRefCastFolder>(context);
}

OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
adaptor.getOperands());
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2522,7 +2524,7 @@ LogicalResult StoreOp::verify() {
return success();
}

LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
LogicalResult StoreOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
/// store(memrefcast) -> store
return foldMemRefCast(*this, getValueToStore());
Expand Down Expand Up @@ -3101,7 +3103,7 @@ void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
}

OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
auto resultShapedType = getResult().getType().cast<ShapedType>();
auto sourceShapedType = getSource().getType().cast<ShapedType>();

Expand Down Expand Up @@ -3217,7 +3219,7 @@ LogicalResult TransposeOp::verify() {
return success();
}

OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
OpFoldResult TransposeOp::fold(FoldAdaptor) {
if (succeeded(foldMemRefCast(*this)))
return getResult();
return {};
Expand Down Expand Up @@ -3393,7 +3395,7 @@ LogicalResult AtomicRMWOp::verify() {
return success();
}

OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
/// atomicrmw(memrefcast) -> atomicrmw
if (succeeded(foldMemRefCast(*this, getValue())))
return getResult();
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Quant/IR/QuantOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void QuantizationDialect::initialize() {
addBytecodeInterface(this);
}

OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
// Matches x -> [scast -> scast] -> y, replacing the second scast with the
// value of x if the casts invert each other.
auto srcScastOp = getArg().getDefiningOp<StorageCastOp>();
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1598,7 +1598,7 @@ void IfOp::getSuccessorRegions(std::optional<unsigned> index,
regions.push_back(RegionSuccessor(condition ? &getThenRegion() : elseRegion));
}

LogicalResult IfOp::fold(ArrayRef<Attribute> operands,
LogicalResult IfOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
// if (!c) then A() else B() -> if c then B() else A()
if (getElseRegion().empty())
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ LogicalResult ConvertOp::verify() {
return emitError("unexpected type in convert");
}

OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
Type dstType = getType();
// Fold trivial dense-to-dense convert and leave trivial sparse-to-sparse
// convert for codegen to remove. This is because we use trivial
Expand Down Expand Up @@ -531,7 +531,7 @@ static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
}

OpFoldResult GetStorageSpecifierOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
StorageSpecifierKind kind = getSpecifierKind();
std::optional<APInt> dim = getDim();
for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ void transform::MergeHandlesOp::getEffects(
// manipulation.
}

OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
if (getDeduplicate() || getHandles().size() != 1)
return {};

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/IR/BuiltinDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ LogicalResult ModuleOp::verify() {
//===----------------------------------------------------------------------===//

LogicalResult
UnrealizedConversionCastOp::fold(ArrayRef<Attribute> attrOperands,
UnrealizedConversionCastOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &foldResults) {
OperandRange operands = getInputs();
ResultRange results = getOutputs();
Expand Down
15 changes: 7 additions & 8 deletions mlir/test/lib/Dialect/Test/TestDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1099,32 +1099,31 @@ void TestOpWithRegionPattern::getCanonicalizationPatterns(
results.add<TestRemoveOpWithInnerOps>(context);
}

OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) {
return getOperand();
}

OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) {
return getValue();
}

LogicalResult TestOpWithVariadicResultsAndFolder::fold(
ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) {
for (Value input : this->getOperands()) {
results.push_back(input);
}
return success();
}

OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 1);
if (operands.front()) {
(*this)->setAttr("attr", operands.front());
OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
if (adaptor.getOp()) {
(*this)->setAttr("attr", adaptor.getOp());
return getResult();
}
return {};
}

OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
OpFoldResult TestPassthroughFold::fold(FoldAdaptor adaptor) {
return getOperand();
}

Expand Down
1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/Test/TestDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def Test_Dialect : Dialect {
let hasNonDefaultDestructor = 1;
let useDefaultTypePrinterParser = 0;
let useDefaultAttributePrinterParser = 1;
let useFoldAPI = kEmitFoldAdaptorFolder;
let isExtensible = 1;
let dependentDialects = ["::mlir::DLTIDialect"];

Expand Down
8 changes: 2 additions & 6 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1290,7 +1290,7 @@ def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {
let results = (outs Variadic<I1>);
let hasFolder = 1;
let extraClassDefinition = [{
::mlir::LogicalResult $cppClass::fold(ArrayRef<Attribute> operands,
::mlir::LogicalResult $cppClass::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
return success();
}
Expand All @@ -1315,11 +1315,7 @@ def TestOpFoldWithFoldAdaptor
$op `,` `[` $variadic `]` `,` `{` $var_of_var `}` $body attr-dict-with-keyword
}];

let hasFolder = 0;

let extraClassDeclaration = [{
::mlir::OpFoldResult fold(FoldAdaptor adaptor);
}];
let hasFolder = 1;
}

// An op that always fold itself.
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/lib/Dialect/Test/TestTraits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ using namespace test;
//===----------------------------------------------------------------------===//

OpFoldResult TestInvolutionTraitFailingOperationFolderOp::fold(
ArrayRef<Attribute> operands) {
FoldAdaptor adaptor) {
// This failure should cause the trait fold to run instead.
return {};
}

OpFoldResult TestInvolutionTraitSuccesfulOperationFolderOp::fold(
ArrayRef<Attribute> operands) {
FoldAdaptor adaptor) {
auto argumentOp = getOperand();
// The success case should cause the trait fold to be supressed.
return argumentOp.getDefiningOp() ? argumentOp : OpFoldResult{};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ ArrayAttr {0}::getIndexingMaps() {{
// Parameters:
// {0}: Class name
const char structuredOpFoldersFormat[] = R"FMT(
LogicalResult {0}::fold(ArrayRef<Attribute>,
LogicalResult {0}::fold(FoldAdaptor,
SmallVectorImpl<OpFoldResult> &) {{
return memref::foldMemRefCast(*this);
}
Expand Down