diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index 2a0bd67c1cb0d..bd80b2b582d11 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -166,6 +166,11 @@ class Value { replaceAllUsesExcept(Value newValue, const SmallPtrSetImpl &exceptions) const; + /// Replace all uses of 'this' value with 'newValue', updating anything in the + /// IR that uses 'this' to use the other value instead except if the user is + /// 'exceptedUser'. + void replaceAllUsesExcept(Value newValue, Operation *exceptedUser) const; + /// Replace all uses of 'this' value with 'newValue' if the given callback /// returns true. void replaceUsesWithIf(Value newValue, diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp index 8653bcf2ad638..1a785a03df763 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp @@ -72,7 +72,7 @@ void mlir::normalizeAffineParallel(AffineParallelOp op) { applyOperands.push_back(iv); applyOperands.append(symbolOperands.begin(), symbolOperands.end()); auto apply = builder.create(op.getLoc(), map, applyOperands); - iv.replaceAllUsesExcept(apply, SmallPtrSet{apply}); + iv.replaceAllUsesExcept(apply, apply); } SmallVector newSteps(op.getNumDims(), 1); @@ -181,8 +181,7 @@ static void normalizeAffineFor(AffineForOp op) { AffineMap ivMap = AffineMap::get(origLbMap.getNumDims() + 1, origLbMap.getNumSymbols(), newIVExpr); Operation *newIV = opBuilder.create(loc, ivMap, lbOperands); - op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), - SmallPtrSet{newIV}); + op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV); } namespace { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 79903c06883db..4c06d3dc35042 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -191,8 +191,7 @@ static LinalgOp fuse(OpBuilder &builder, LinalgOp producer, AffineApplyOp applyOp = builder.create( indexOp.getLoc(), index + offset, ValueRange{indexOp.getResult(), loopRanges[indexOp.dim()].offset}); - indexOp.getResult().replaceAllUsesExcept( - applyOp, SmallPtrSet{applyOp}); + indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp); } } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 0479ab6543110..bdc1d7097ccd9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -155,8 +155,7 @@ transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl &ivs, AffineApplyOp applyOp = b.create( indexOp.getLoc(), index + iv, ValueRange{indexOp.getResult(), ivs[rangeIndex->second]}); - indexOp.getResult().replaceAllUsesExcept( - applyOp.getResult(), SmallPtrSet{applyOp}); + indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp); } } diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp index cdb4afa929cf2..8282c0771f302 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp @@ -121,8 +121,7 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef tileSizes) { Value inner_index = std::get<0>(ivs); AddIOp newIndex = b.create(op.getLoc(), std::get<0>(ivs), std::get<1>(ivs)); - inner_index.replaceAllUsesExcept( - newIndex, SmallPtrSet{newIndex.getOperation()}); + inner_index.replaceAllUsesExcept(newIndex, newIndex); } op.erase(); diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index e28ab9ba470d4..a4baa93110019 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -63,12 +63,23 @@ void Value::replaceAllUsesWith(Value newValue) const { /// listed in 'exceptions' . void Value::replaceAllUsesExcept( Value newValue, const SmallPtrSetImpl &exceptions) const { - for (auto &use : llvm::make_early_inc_range(getUses())) { + for (OpOperand &use : llvm::make_early_inc_range(getUses())) { if (exceptions.count(use.getOwner()) == 0) use.set(newValue); } } +/// Replace all uses of 'this' value with 'newValue', updating anything in the +/// IR that uses 'this' to use the other value instead except if the user is +/// 'exceptedUser'. +void Value::replaceAllUsesExcept(Value newValue, + Operation *exceptedUser) const { + for (OpOperand &use : llvm::make_early_inc_range(getUses())) { + if (use.getOwner() != exceptedUser) + use.set(newValue); + } +} + /// Replace all uses of 'this' value with 'newValue' if the given callback /// returns true. void Value::replaceUsesWithIf(Value newValue,