Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 20 additions & 39 deletions mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -934,11 +934,13 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
// 3. skipping from the new result types / new yielded values any result
// that has no use or whose yielded value has already been seen.
for (OpResult result : warpOp.getResults()) {
if (result.use_empty())
continue;
Value yieldOperand = yield.getOperand(result.getResultNumber());
auto it = dedupYieldOperandPositionMap.insert(
std::make_pair(yieldOperand, newResultTypes.size()));
dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
if (result.use_empty() || !it.second)
if (!it.second)
continue;
newResultTypes.push_back(result.getType());
newYieldValues.push_back(yieldOperand);
Expand Down Expand Up @@ -1843,16 +1845,16 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
escapingValueDistTypesElse.end());

llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
for (auto [idx, val] :
llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
newWarpOpYieldValues.push_back(val);
newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
}
// Create the new `WarpOp` with the updated yield values and types.
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
// Replace the old `WarpOp` with the new one that has additional yield
// values and types.
SmallVector<size_t> newIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
// `ifOp` returns the result of the inner warp op.
SmallVector<Type> newIfOpDistResTypes;
for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
Expand All @@ -1870,8 +1872,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
auto newIfOp = scf::IfOp::create(
rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
static_cast<bool>(ifOp.thenBlock()),
rewriter, ifOp.getLoc(), newIfOpDistResTypes,
newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()),
static_cast<bool>(ifOp.elseBlock()));
auto encloseRegionInWarpOp =
[&](Block *oldIfBranch, Block *newIfBranch,
Expand All @@ -1888,7 +1890,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
for (size_t i = 0; i < escapingValues.size();
++i, ++warpResRangeStart) {
innerWarpInputVals.push_back(
newWarpOp.getResult(warpResRangeStart));
newWarpOp.getResult(newIndices[warpResRangeStart]));
escapeValToBlockArgIndex[escapingValues[i]] =
innerWarpInputTypes.size();
innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
Expand Down Expand Up @@ -1936,17 +1938,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
// Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
// result.
for (auto [origIdx, newIdx] : ifResultMapping)
rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
newIfOp.getResult(newIdx), newIfOp);
// Similarly, update any users of the `WarpOp` results that were not
// results of the `IfOp`.
for (auto [origIdx, newIdx] : origToNewYieldIdx)
rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
newWarpOp.getResult(newIdx));
// Remove the original `WarpOp` and `IfOp`, they should not have any uses
// at this point.
rewriter.eraseOp(ifOp);
rewriter.eraseOp(warpOp);
return success();
}

Expand Down Expand Up @@ -2065,19 +2058,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
escapingValueDistTypes.begin(),
escapingValueDistTypes.end());
// Next, we insert all non-`ForOp` yielded values and their distributed
// types. We also create a mapping between the non-`ForOp` yielded value
// index and the corresponding new `WarpOp` yield value index (needed to
// update users later).
llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
// types.
for (auto [i, v] :
llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
nonForResultMapping[i] = newWarpOpYieldValues.size();
newWarpOpYieldValues.push_back(v);
newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
}
// Create the new `WarpOp` with the updated yield values and types.
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
SmallVector<size_t> newIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);

// Next, we create a new `ForOp` with the init args yielded by the new
// `WarpOp`.
Expand All @@ -2086,7 +2076,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
// escaping values in the new `WarpOp`.
SmallVector<Value> newForOpOperands;
for (size_t i = 0; i < escapingValuesStartIdx; ++i)
newForOpOperands.push_back(newWarpOp.getResult(i));
newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));

// Create a new `ForOp` outside the new `WarpOp` region.
OpBuilder::InsertionGuard g(rewriter);
Expand All @@ -2110,7 +2100,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
for (size_t i = escapingValuesStartIdx;
i < escapingValuesStartIdx + escapingValues.size(); ++i) {
innerWarpInput.push_back(newWarpOp.getResult(i));
innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
innerWarpInputType.size();
innerWarpInputType.push_back(
Expand Down Expand Up @@ -2146,20 +2136,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
if (!innerWarp.getResults().empty())
scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());

// Update the users of original `WarpOp` results that were coming from the
// Update the users of the new `WarpOp` results that were coming from the
// original `ForOp` to the corresponding new `ForOp` result.
for (auto [origIdx, newIdx] : forResultMapping)
rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
newForOp.getResult(newIdx), newForOp);
// Similarly, update any users of the `WarpOp` results that were not
// results of the `ForOp`.
for (auto [origIdx, newIdx] : nonForResultMapping)
rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
newWarpOp.getResult(newIdx));
// Remove the original `WarpOp` and `ForOp`, they should not have any uses
// at this point.
rewriter.eraseOp(forOp);
rewriter.eraseOp(warpOp);
// Update any users of escaping values that were forwarded to the
// inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
newForOp.walk([&](Operation *op) {
Expand Down
19 changes: 19 additions & 0 deletions mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1925,3 +1925,22 @@ func.func @warp_scf_if_distribute(%pred : i1) {
// CHECK-PROP: "some_use"(%[[IF_YIELD_DIST]]) : (vector<1xf32>) -> ()
// CHECK-PROP: return
// CHECK-PROP: }

// -----
func.func @dedup_unused_result(%laneid : index) -> (vector<1xf32>) {
%r:3 = gpu.warp_execute_on_lane_0(%laneid)[32] ->
(vector<1xf32>, vector<2xf32>, vector<1xf32>) {
%2 = "some_def"() : () -> (vector<32xf32>)
%3 = "some_def"() : () -> (vector<64xf32>)
gpu.yield %2, %3, %2 : vector<32xf32>, vector<64xf32>, vector<32xf32>
}
%r0 = "some_use"(%r#2, %r#2) : (vector<1xf32>, vector<1xf32>) -> (vector<1xf32>)
return %r0 : vector<1xf32>
}

// CHECK-PROP: func @dedup_unused_result
// CHECK-PROP: %[[R:.*]] = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<1xf32>)
// CHECK-PROP: %[[Y0:.*]] = "some_def"() : () -> vector<32xf32>
// CHECK-PROP: %[[Y1:.*]] = "some_def"() : () -> vector<64xf32>
// CHECK-PROP: gpu.yield %[[Y0]] : vector<32xf32>
// CHECK-PROP: "some_use"(%[[R]], %[[R]]) : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>