Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 246 additions & 19 deletions mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,38 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
return targetType;
}

/// Given a `warpOp` that contains an op with regions, one of that op's
/// regions (`innerRegion`), and the `distributionMapFn`, collect all values
/// used by `innerRegion` that are defined within the `warpOp` but outside
/// `innerRegion`.
/// Returns the set of escaping values, their original types, and their
/// distributed types. A distributed type entry is a null `Type` when the
/// value is a vector that could not be distributed; callers must check for
/// that (e.g. via `llvm::is_contained(..., Type{})`) before proceeding.
static std::tuple<llvm::SmallSetVector<Value, 32>, SmallVector<Type>,
                  SmallVector<Type>>
getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
                             DistributionMapFn distributionMapFn) {
  llvm::SmallSetVector<Value, 32> escapingValues;
  SmallVector<Type> escapingValueTypes;
  SmallVector<Type> escapingValueDistTypes; // to yield from the new warpOp
  // Nothing can escape into an absent region.
  if (innerRegion.empty())
    return {std::move(escapingValues), std::move(escapingValueTypes),
            std::move(escapingValueDistTypes)};
  mlir::visitUsedValuesDefinedAbove(innerRegion, [&](OpOperand *operand) {
    Value value = operand->get();
    // Only consider values defined inside the warpOp. An op counts as its
    // own ancestor, so values defined directly in the warpOp body qualify;
    // values defined above the warpOp (e.g. function arguments) do not.
    Operation *parent = value.getParentRegion()->getParentOp();
    if (!warpOp->isAncestor(parent))
      return;
    // Deduplicate: a value used by several operands is recorded once.
    if (!escapingValues.insert(value))
      return;
    Type distType = value.getType();
    if (auto vecType = dyn_cast<VectorType>(distType)) {
      // Vectors are distributed across lanes according to the caller-provided
      // distribution map; may yield a null Type on failure (see doc above).
      AffineMap map = distributionMapFn(value);
      distType = getDistributedType(vecType, map, warpOp.getWarpSize());
    }
    escapingValueTypes.push_back(value.getType());
    escapingValueDistTypes.push_back(distType);
  });
  return {std::move(escapingValues), std::move(escapingValueTypes),
          std::move(escapingValueDistTypes)};
}

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`. Writes of size more than `maxNumElementToExtract`
/// will not be distributed (it should be less than the warp size).
Expand Down Expand Up @@ -1713,6 +1745,215 @@ struct WarpOpInsert : public WarpDistributionPattern {
}
};

/// Sink scf.if out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.if is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.if after the
/// WarpExecuteOnLane0Op. Each branch of the new scf.if is enclosed in
/// an "inner" WarpExecuteOnLane0Op. Example:
/// ```
/// gpu.warp_execute_on_lane_0(%laneid)[32] {
///   %payload = ... : vector<32xindex>
///   scf.if %pred {
///     vector.store %payload, %buffer[%idx] : memref<128xindex>,
///     vector<32xindex>
///   }
///   gpu.yield
/// }
/// ```
/// is rewritten to:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] {
///   %payload = ... : vector<32xindex>
///   gpu.yield %payload : vector<32xindex>
/// }
/// scf.if %pred {
///   gpu.warp_execute_on_lane_0(%laneid)[32] args(%r : vector<1xindex>) {
///   ^bb0(%arg1: vector<32xindex>):
///     vector.store %arg1, %buffer[%idx] : memref<128xindex>,
///     vector<32xindex>
///   }
/// }
/// ```
struct WarpOpScfIfOp : public WarpDistributionPattern {
  WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp warpOpYield = warpOp.getTerminator();
    // Only pick up `IfOp` if it is the last op in the region; sinking an
    // `IfOp` followed by other ops would change the order of execution.
    Operation *lastNode = warpOpYield->getPrevNode();
    auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
    if (!ifOp)
      return failure();

    // The current `WarpOp` can yield two types of values:
    // 1. Not results of `IfOp`:
    //    Preserve them in the new `WarpOp`.
    //    Collect their yield index to remap the usages.
    // 2. Results of `IfOp`:
    //    They are not part of the new `WarpOp` results.
    //    Map current warp's yield operand index to `IfOp` result idx.
    SmallVector<Value> nonIfYieldValues;
    SmallVector<unsigned> nonIfYieldIndices;
    llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
    llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
    for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
      const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
      if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
        nonIfYieldValues.push_back(yieldOperand.get());
        nonIfYieldIndices.push_back(yieldOperandIdx);
        continue;
      }
      OpResult ifResult = cast<OpResult>(yieldOperand.get());
      const unsigned ifResultIdx = ifResult.getResultNumber();
      ifResultMapping[yieldOperandIdx] = ifResultIdx;
      // If this `ifOp` result is vector type and it is yielded by the
      // `WarpOp`, we keep track of the distributed type for this result,
      // taken from the corresponding `WarpOp` result type.
      if (!isa<VectorType>(ifResult.getType()))
        continue;
      VectorType distType =
          cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
      ifResultDistTypes[ifResultIdx] = distType;
    }

    // Collect `WarpOp`-defined values used inside each branch of `ifOp`; the
    // new `WarpOp` must yield them (in distributed form) so the branches,
    // once wrapped in inner warp ops, can receive them as arguments.
    auto [escapingValuesThen, escapingValueInputTypesThen,
          escapingValueDistTypesThen] =
        getInnerRegionEscapingValues(warpOp, ifOp.getThenRegion(),
                                     distributionMapFn);
    auto [escapingValuesElse, escapingValueInputTypesElse,
          escapingValueDistTypesElse] =
        getInnerRegionEscapingValues(warpOp, ifOp.getElseRegion(),
                                     distributionMapFn);
    // A null distributed type means a vector escaping value could not be
    // distributed; bail out in that case.
    if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
        llvm::is_contained(escapingValueDistTypesElse, Type{}))
      return failure();

    // The new `WarpOp` yields values grouped in the following order:
    // 1. Branch condition
    // 2. Escaping values of the then branch
    // 3. Escaping values of the else branch
    // 4. All non-`ifOp` yielded values.
    SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
    newWarpOpYieldValues.append(escapingValuesThen.begin(),
                                escapingValuesThen.end());
    newWarpOpYieldValues.append(escapingValuesElse.begin(),
                                escapingValuesElse.end());
    SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
    newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
                              escapingValueDistTypesThen.end());
    newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
                              escapingValueDistTypesElse.end());

    // Map each preserved (non-`ifOp`) value's original yield index to its
    // position in the new `WarpOp` yield, for use-remapping below.
    llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
    for (auto [idx, val] :
         llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
      origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
      newWarpOpYieldValues.push_back(val);
      newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
    }
    // Create the new `WarpOp` with the updated yield values and types.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
    // The new `ifOp` returns the result of the inner warp op, so its result
    // types are the distributed types of the original `ifOp` results.
    SmallVector<Type> newIfOpDistResTypes;
    for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
      Type distType = cast<Value>(res).getType();
      if (auto vecType = dyn_cast<VectorType>(distType)) {
        AffineMap map = distributionMapFn(cast<Value>(res));
        // Fallback to affine map if the dist result was not previously
        // recorded (i.e. this `ifOp` result was not yielded by the `WarpOp`).
        // NOTE(review): the fallback result is not checked for null here,
        // unlike the escaping-value types above — confirm it cannot fail.
        distType = ifResultDistTypes.count(i)
                       ? ifResultDistTypes[i]
                       : getDistributedType(vecType, map, warpOp.getWarpSize());
      }
      newIfOpDistResTypes.push_back(distType);
    }
    // Create a new `IfOp` outside the new `WarpOp` region, conditioned on the
    // new `WarpOp`'s first result (the yielded branch condition), carrying a
    // then/else region only when the original had one.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newIfOp = scf::IfOp::create(
        rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
        static_cast<bool>(ifOp.thenBlock()),
        static_cast<bool>(ifOp.elseBlock()));
    // Wrap one original branch region in an inner `WarpOp` placed in the
    // corresponding branch of the new `IfOp`. `warpResRangeStart` is the
    // index of the first new-`WarpOp` result carrying this branch's escaping
    // values (see the yield grouping above).
    auto encloseRegionInWarpOp =
        [&](Block *oldIfBranch, Block *newIfBranch,
            llvm::SmallSetVector<Value, 32> &escapingValues,
            SmallVector<Type> &escapingValueInputTypes,
            size_t warpResRangeStart) {
          OpBuilder::InsertionGuard g(rewriter);
          if (!newIfBranch)
            return;
          rewriter.setInsertionPointToStart(newIfBranch);
          // Escaping value -> index of the inner `WarpOp` block argument
          // that replaces it inside the moved branch body.
          llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
          SmallVector<Value> innerWarpInputVals;
          SmallVector<Type> innerWarpInputTypes;
          for (size_t i = 0; i < escapingValues.size();
               ++i, ++warpResRangeStart) {
            innerWarpInputVals.push_back(
                newWarpOp.getResult(warpResRangeStart));
            escapeValToBlockArgIndex[escapingValues[i]] =
                innerWarpInputTypes.size();
            innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
          }
          auto innerWarp = WarpExecuteOnLane0Op::create(
              rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
              newWarpOp.getLaneid(), newWarpOp.getWarpSize(),
              innerWarpInputVals, innerWarpInputTypes);

          // Move the original branch body into the inner `WarpOp` and add
          // block arguments (at original, undistributed types) for the
          // forwarded escaping values.
          innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent());
          innerWarp.getWarpRegion().addArguments(
              innerWarpInputTypes,
              SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));

          // Replace terminators: the old branch's yield operands become the
          // inner `WarpOp` yield, and the new `IfOp` branch yields the inner
          // `WarpOp` results.
          SmallVector<Value> yieldOperands;
          for (Value operand : oldIfBranch->getTerminator()->getOperands())
            yieldOperands.push_back(operand);
          rewriter.eraseOp(oldIfBranch->getTerminator());

          rewriter.setInsertionPointToEnd(innerWarp.getBody());
          gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
          rewriter.setInsertionPointAfter(innerWarp);
          scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());

          // Update any users of escaping values that were forwarded to the
          // inner `WarpOp`. These values are arguments of the inner `WarpOp`.
          innerWarp.walk([&](Operation *op) {
            for (OpOperand &operand : op->getOpOperands()) {
              auto it = escapeValToBlockArgIndex.find(operand.get());
              if (it == escapeValToBlockArgIndex.end())
                continue;
              operand.set(innerWarp.getBodyRegion().getArgument(it->second));
            }
          });
          // Hoist lane-uniform scalar code out of the inner `WarpOp`.
          mlir::vector::moveScalarUniformCode(innerWarp);
        };
    // Then-branch escaping values start at new-`WarpOp` result 1 (result 0
    // is the branch condition); else-branch values follow them.
    encloseRegionInWarpOp(&ifOp.getThenRegion().front(),
                          &newIfOp.getThenRegion().front(), escapingValuesThen,
                          escapingValueInputTypesThen, 1);
    if (!ifOp.getElseRegion().empty())
      encloseRegionInWarpOp(&ifOp.getElseRegion().front(),
                            &newIfOp.getElseRegion().front(),
                            escapingValuesElse, escapingValueInputTypesElse,
                            1 + escapingValuesThen.size());
    // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new
    // `IfOp` result.
    for (auto [origIdx, newIdx] : ifResultMapping)
      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
                                    newIfOp.getResult(newIdx), newIfOp);
    // Similarly, update any users of the `WarpOp` results that were not
    // results of the `IfOp`.
    for (auto [origIdx, newIdx] : origToNewYieldIdx)
      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
                                  newWarpOp.getResult(newIdx));
    // Remove the original `WarpOp` and `IfOp`, they should not have any uses
    // at this point.
    rewriter.eraseOp(ifOp);
    rewriter.eraseOp(warpOp);
    return success();
  }

private:
  DistributionMapFn distributionMapFn;
};

/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
Expand Down Expand Up @@ -1759,25 +2000,9 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
return failure();
// Collect Values that come from the `WarpOp` but are outside the `ForOp`.
// Those Values need to be returned by the new warp op.
llvm::SmallSetVector<Value, 32> escapingValues;
SmallVector<Type> escapingValueInputTypes;
SmallVector<Type> escapingValueDistTypes;
mlir::visitUsedValuesDefinedAbove(
forOp.getBodyRegion(), [&](OpOperand *operand) {
Operation *parent = operand->get().getParentRegion()->getParentOp();
if (warpOp->isAncestor(parent)) {
if (!escapingValues.insert(operand->get()))
return;
Type distType = operand->get().getType();
if (auto vecType = dyn_cast<VectorType>(distType)) {
AffineMap map = distributionMapFn(operand->get());
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
}
escapingValueInputTypes.push_back(operand->get().getType());
escapingValueDistTypes.push_back(distType);
}
});

auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] =
getInnerRegionEscapingValues(warpOp, forOp.getBodyRegion(),
distributionMapFn);
if (llvm::is_contained(escapingValueDistTypes, Type{}))
return failure();
// `WarpOp` can yield two types of values:
Expand Down Expand Up @@ -2068,6 +2293,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
benefit);
patterns.add<WarpOpScfIfOp>(patterns.getContext(), distributionMapFn,
benefit);
}

void mlir::vector::populateDistributeReduction(
Expand Down
69 changes: 69 additions & 0 deletions mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1856,3 +1856,72 @@ func.func @negative_warp_step_more_than_warp_size(%laneid: index, %buffer: memre
// CHECK-PROP-LABEL: @negative_warp_step_more_than_warp_size
// CHECK-PROP-NOT: vector.broadcast
// CHECK-PROP: vector.step : vector<64xindex>

// -----

// The scf.if yields no results; %seq is defined in the warp region and
// escapes into the then-branch. After distribution, the warp op yields %seq
// in distributed form (vector<1xindex> for 32 lanes), the scf.if is hoisted
// out, and the store runs inside an inner warp op that receives %seq as a
// block argument at its original vector<32xindex> type.
func.func @warp_scf_if_no_yield_distribute(%buffer: memref<128xindex>, %pred : i1) {
  %laneid = gpu.lane_id
  %c0 = arith.constant 0 : index

  gpu.warp_execute_on_lane_0(%laneid)[32] {
    %seq = vector.step : vector<32xindex>
    scf.if %pred {
      vector.store %seq, %buffer[%c0] : memref<128xindex>, vector<32xindex>
    }
    gpu.yield
  }
  return
}

// CHECK-PROP-LABEL: func.func @warp_scf_if_no_yield_distribute(
// CHECK-PROP-SAME: %[[ARG0:.+]]: memref<128xindex>, %[[ARG1:.+]]: i1
// CHECK-PROP: scf.if %[[ARG1]] {
// CHECK-PROP: gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<1xindex>) {
// CHECK-PROP: ^bb0(%[[ARG2:.+]]: vector<32xindex>):
// CHECK-PROP: vector.store %[[ARG2]], %[[ARG0]][%{{.*}}] : memref<128xindex>, vector<32xindex>

// -----

// The scf.if yields a vector result that the warp op in turn yields. Each
// branch uses a different warp-region value (%seq1 in then, %seq2 in else).
// The expected output shows the hoisted scf.if returning the distributed
// vector<1xf32>, each branch wrapped in its own inner warp op, and the
// uniform constant %seq2 hoisted out of the warp region entirely.
func.func @warp_scf_if_distribute(%pred : i1) {
  %laneid = gpu.lane_id
  %c0 = arith.constant 0 : index

  %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> vector<1xf32> {
    %seq1 = vector.step : vector<32xindex>
    %seq2 = arith.constant dense<2> : vector<32xindex>
    %0 = scf.if %pred -> (vector<32xf32>) {
      %1 = "some_op"(%seq1) : (vector<32xindex>) -> (vector<32xf32>)
      scf.yield %1 : vector<32xf32>
    } else {
      %2 = "other_op"(%seq2) : (vector<32xindex>) -> (vector<32xf32>)
      scf.yield %2 : vector<32xf32>
    }
    gpu.yield %0 : vector<32xf32>
  }
  "some_use"(%0) : (vector<1xf32>) -> ()

  return
}

// CHECK-PROP-LABEL: func.func @warp_scf_if_distribute(
// CHECK-PROP-SAME: %[[ARG0:.+]]: i1
// CHECK-PROP: %[[SEQ2:.+]] = arith.constant dense<2> : vector<32xindex>
// CHECK-PROP: %[[LANE_ID:.+]] = gpu.lane_id
// CHECK-PROP: %[[SEQ1:.+]] = vector.broadcast %[[LANE_ID]] : index to vector<1xindex>
// CHECK-PROP: %[[IF_YIELD_DIST:.+]] = scf.if %[[ARG0]] -> (vector<1xf32>) {
// CHECK-PROP: %[[THEN_DIST:.+]] = gpu.warp_execute_on_lane_0(%[[LANE_ID]])[32] args(%[[SEQ1]] : vector<1xindex>) -> (vector<1xf32>) {
// CHECK-PROP: ^bb0(%[[ARG1:.+]]: vector<32xindex>):
// CHECK-PROP: %{{.*}} = "some_op"(%[[ARG1]]) : (vector<32xindex>) -> vector<32xf32>
// CHECK-PROP: gpu.yield %{{.*}} : vector<32xf32>
// CHECK-PROP: }
// CHECK-PROP: scf.yield %[[THEN_DIST]] : vector<1xf32>
// CHECK-PROP: } else {
// CHECK-PROP: %[[ELSE_DIST:.+]] = gpu.warp_execute_on_lane_0(%[[LANE_ID]])[32] -> (vector<1xf32>) {
// CHECK-PROP: %{{.*}} = "other_op"(%[[SEQ2]]) : (vector<32xindex>) -> vector<32xf32>
// CHECK-PROP: gpu.yield %{{.*}} : vector<32xf32>
// CHECK-PROP: }
// CHECK-PROP: scf.yield %[[ELSE_DIST]] : vector<1xf32>
// CHECK-PROP: }
// CHECK-PROP: "some_use"(%[[IF_YIELD_DIST]]) : (vector<1xf32>) -> ()
// CHECK-PROP: return
// CHECK-PROP: }
Loading
Loading