diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index b4d500212c770..2e0e650f2bb9c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -1978,11 +1978,20 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { // Update any users of escaping values that were forwarded to the // inner `WarpOp`. These values are arguments of the inner `WarpOp`. innerWarp.walk([&](Operation *op) { + SmallVector> replacements; for (OpOperand &operand : op->getOpOperands()) { auto it = escapeValToBlockArgIndex.find(operand.get()); if (it == escapeValToBlockArgIndex.end()) continue; - operand.set(innerWarp.getBodyRegion().getArgument(it->second)); + replacements.emplace_back( + operand.getOperandNumber(), + innerWarp.getBodyRegion().getArgument(it->second)); + } + if (!replacements.empty()) { + rewriter.modifyOpInPlace(op, [&]() { + for (auto [idx, newVal] : replacements) + op->setOperand(idx, newVal); + }); } }); mlir::vector::moveScalarUniformCode(innerWarp); @@ -2218,11 +2227,20 @@ struct WarpOpScfForOp : public WarpDistributionPattern { // Update any users of escaping values that were forwarded to the // inner `WarpOp`. These values are now arguments of the inner `WarpOp`. newForOp.walk([&](Operation *op) { + SmallVector> replacements; for (OpOperand &operand : op->getOpOperands()) { auto it = argIndexMapping.find(operand.get()); if (it == argIndexMapping.end()) continue; - operand.set(innerWarp.getBodyRegion().getArgument(it->second)); + replacements.emplace_back( + operand.getOperandNumber(), + innerWarp.getBodyRegion().getArgument(it->second)); + } + if (!replacements.empty()) { + rewriter.modifyOpInPlace(op, [&]() { + for (auto [idx, newVal] : replacements) + op->setOperand(idx, newVal); + }); } });