mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td (4 changes: 0 additions & 4 deletions)
@@ -27,10 +27,6 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
   }];
   let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
                            "vector::VectorDialect"];
-  let options = [Option<
-      "enableSGReductions", "enable-sg-reductions", "bool",
-      /*default=*/"true",
-      "Enable subgroup reductions using subgroup shuffles.">];
 }
 
 def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp (67 changes: 43 additions & 24 deletions)
@@ -875,23 +875,29 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
           storeScatterOp,
           "Some vector operands have no layouts, using defaults instead.");
     }
-    VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
-    VectorType expectedPayloadTy = VectorType::get(
-        {distPayloadTy.getNumElements()}, distPayloadTy.getElementType());
+    // Distributed store payload type according to the lane layout.
+    VectorType distPayloadTyByWarpOp = distStoreVecByWarpOpOrFailure.value();
+    // Expected distributed payload type is always 1D.
+    VectorType expectedPayloadTy =
+        VectorType::get({distPayloadTyByWarpOp.getNumElements()},
+                        distPayloadTyByWarpOp.getElementType());
 
     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands = storeScatterOp->getOperands();
     SmallVector<Type> operandTypesToYield = {
-        expectedPayloadTy, operands[1].getType(),
+        distPayloadTyByWarpOp, operands[1].getType(),
         distOffsetsByWarpOpOrFailure.value(),
         distMaskByWarpOpOrFailure.value()};
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
     SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
         newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
 
+    // The payload operand may need type adjustment due to mismatch between
+    // warp distributed type and expected SIMT type.
     rewriter.setInsertionPointAfter(newWarpOp);
+    newStoreScatterOpOperands[0] = resolveDistributedTy(
+        newStoreScatterOpOperands[0], expectedPayloadTy, rewriter);
     xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
         rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
         storeScatterOp->getAttrs());
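
The store pattern now yields the warp-distributed payload type (distPayloadTyByWarpOp) from the warp op and only afterwards reconciles it with the flat 1D expectedPayloadTy that xegpu.store_scatter consumes, via resolveDistributedTy. As a reading aid, here is a minimal sketch of what such a reconciliation helper can look like; the name, signature, and non-vector fallback below are assumptions for illustration, not the file's verbatim helper:

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;

// Sketch only: reconcile a warp-distributed value with the type a SIMT-level
// op expects, e.g. vector<1x8xf16> -> vector<8xf16> via a shape cast.
static Value resolveDistributedTySketch(Value orig, Type expected,
                                        PatternRewriter &rewriter) {
  if (orig.getType() == expected)
    return orig; // Types already agree; nothing to reconcile.
  if (isa<VectorType>(orig.getType()))
    return vector::ShapeCastOp::create(rewriter, orig.getLoc(), expected, orig)
        .getResult();
  // Assumed fallback for non-vector mismatches (e.g. tensor descriptors):
  // an unrealized cast that later cleanup is expected to fold away.
  SmallVector<Type> resultTypes{expected};
  SmallVector<Value> inputs{orig};
  return UnrealizedConversionCastOp::create(rewriter, orig.getLoc(),
                                            resultTypes, inputs)
      .getResult(0);
}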
@@ -976,8 +982,11 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
         distMaskByWarpOpOrFailure.value()};
 
     const unsigned operandIdx = producedByLastLoad->getOperandNumber();
-    VectorType loadVecTy =
+    VectorType distResultTy =
         cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    // Distributed load op will always be 1D.
+    VectorType loadVecTy = VectorType::get({distResultTy.getNumElements()},
+                                           distResultTy.getElementType());
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -991,13 +1000,16 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
         loadGatherOp->getAttrs());
     xegpu::removeLayoutAttrs(newOp);
     Value distributedVal = newWarpOp.getResult(operandIdx);
-    rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
+    // Resolve the output type and replace all uses.
+    rewriter.replaceAllUsesWith(
+        distributedVal,
+        resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
     return success();
   }
 };
 
 /// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
-/// VectorReductionOps.
+/// VectorReductionOps. We also insert layouts for the newly created ops.
 static Value lowerToVectorReductions(TypedValue<VectorType> src,
                                      TypedValue<VectorType> acc,
                                      vector::CombiningKind kind,
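
The load side is symmetric: the gather is created inside the new warp op with the flattened 1D loadVecTy, and resolveDistributedTy casts the loaded value back to the warp-distributed distResultTy before replacing all uses. A small worked example of that type arithmetic, with shapes invented for illustration:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;

int main() {
  MLIRContext ctx;
  // Hypothetical warp-distributed result type coming out of the warp op.
  VectorType distResultTy = VectorType::get({2, 8}, Float16Type::get(&ctx));
  // The SIMT-level gather always produces a flat 1D vector.
  VectorType loadVecTy = VectorType::get({distResultTy.getNumElements()},
                                         distResultTy.getElementType());
  // Prints: vector<2x8xf16> -> vector<16xf16>. The shape cast inserted by
  // resolveDistributedTy undoes exactly this flattening at the use sites.
  llvm::outs() << distResultTy << " -> " << loadVecTy << "\n";
}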
@@ -1014,6 +1026,9 @@ static Value lowerToVectorReductions(TypedValue<VectorType> src,
   Value reductionResult = arith::ConstantOp::create(
       rewriter, loc, acc.getType(),
       DenseElementsAttr::get(acc.getType(), zeroAttr));
+  // Reduction result should have the same layout as the accumulator.
+  xegpu::setDistributeLayoutAttr(cast<OpResult>(reductionResult),
+                                 xegpu::getDistributeLayoutAttr(acc));
   // For each slice of the source, extract the slice vector, do a reduction
   // and, insert the reduced value back to the result vector.
   for (int i = 0; i < nSlices; ++i) {
@@ -1029,13 +1044,23 @@ static Value lowerToVectorReductions(TypedValue<VectorType> src,
         vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
                                               sliceSizes, {1, 1});
     int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
-    Value slice = vector::ShapeCastOp::create(
+    vector::ShapeCastOp slice = vector::ShapeCastOp::create(
         rewriter, loc,
         VectorType::get({nSliceElements}, sourceType.getElementType()),
         extractOp.getResult());
+    // Shape cast is currently handled in xegpu side. So layouts must be
+    // retained during lowering. Shape cast output has the same layout as the
+    // accumulator. Shape cast source has the same layout as the original
+    // reduction source.
+    // TODO: other ops generated here may also need layout attributes.
+    xegpu::setDistributeLayoutAttr(slice->getOpOperand(0),
+                                   xegpu::getDistributeLayoutAttr(src));
+    xegpu::setDistributeLayoutAttr(slice->getOpResult(0),
+                                   xegpu::getDistributeLayoutAttr(acc));
+    // Extract and reduction results in scalars, so no result layout is needed.
     Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
-    Value reduction =
-        vector::ReductionOp::create(rewriter, loc, kind, slice, accExtract);
+    Value reduction = vector::ReductionOp::create(
+        rewriter, loc, kind, slice.getResult(), accExtract);
     reductionResult =
         vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
   }
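
Concretely, each loop iteration above turns one row of the 2D source into an extract_strided_slice, a shape cast to 1D, a vector.reduction seeded with the matching accumulator element, and an insert into the result vector. The plain-C++ sketch below mirrors the arithmetic that emitted IR computes, with sizes and values invented for illustration:

#include <array>
#include <cstdio>

int main() {
  // Stand-ins for a vector<2x16xf32> source and a vector<2xf32> accumulator.
  std::array<std::array<float, 16>, 2> src;
  for (auto &row : src)
    row.fill(0.5f);
  std::array<float, 2> acc{1.0f, 2.0f};
  std::array<float, 2> result{}; // arith.constant dense<0.0> : vector<2xf32>
  for (int i = 0; i < 2; ++i) {  // one 1D reduction per slice
    float sum = acc[i];          // vector.extract %acc[i]
    for (float v : src[i])       // vector.reduction <add> over the 1D slice
      sum += v;
    result[i] = sum;             // vector.insert into the result vector
  }
  std::printf("%g %g\n", result[0], result[1]); // prints: 9 10
}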
@@ -1107,7 +1132,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
       return failure();
     auto reductionOp =
         cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
-    unsigned operandNumber = yieldOperand->getOperandNumber();
+    unsigned operandIdx = yieldOperand->getOperandNumber();
     VectorType sourceType = reductionOp.getSourceVectorType();
     // Only 2D vectors are supported.
     if (sourceType.getRank() != 2)
@@ -1121,7 +1146,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
warpOp, "Only 1 reduction dimension is supported.");
int64_t reductionDim = reductionDims[0];
VectorType distributedResultType =
cast<VectorType>(warpOp.getResult(operandNumber).getType());
cast<VectorType>(warpOp.getResult(operandIdx).getType());
VectorType resultType = cast<VectorType>(reductionOp.getType());
xegpu::DistributeLayoutAttr sourceLayout =
xegpu::getDistributeLayoutAttr(reductionOp.getSource());
@@ -1184,7 +1209,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
         cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
         reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
     // Replace the warp op result with the final result.
-    rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
+    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), result);
     return success();
   }
   // For non-lane-local case, we simply rewrite the MultiReductionOp in terms
@@ -1217,7 +1242,7 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
     auto resultDistTy =
         cast<VectorType>(warpOp.getResult(operandNumber).getType());
     xegpu::DistributeLayoutAttr sourceLayout =
-        xegpu::getDistributeLayoutAttr(shapeCastOp.getSource());
+        xegpu::getDistributeLayoutAttr(shapeCastOp->getOpOperand(0));
     xegpu::DistributeLayoutAttr resultLayout =
         xegpu::getDistributeLayoutAttr(shapeCastOp.getResult());
     if (!sourceLayout || !resultLayout)
@@ -1403,11 +1428,6 @@ namespace {
 struct XeGPUSubgroupDistributePass final
     : public xegpu::impl::XeGPUSubgroupDistributeBase<
           XeGPUSubgroupDistributePass> {
-  XeGPUSubgroupDistributePass() = default;
-  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
-      default;
-  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
-      : XeGPUSubgroupDistributeBase(options) {}
   void runOnOperation() override;
 };
 } // namespace
@@ -1515,10 +1535,9 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     return laneVal;
   };
 
-  if (enableSGReductions)
-    vector::populateDistributeReduction(
-        patterns, warpReduction,
-        /*pattern benefit=*/regularPatternBenefit);
+  vector::populateDistributeReduction(
+      patterns, warpReduction,
+      /*pattern benefit=*/regularPatternBenefit);
 
   vector::populatePropagateWarpVectorDistributionPatterns(
       patterns, distributionFn, shuffleFn,
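
With the enable-sg-reductions option gone, the reduction distribution patterns are registered unconditionally. The warpReduction callback whose tail is visible above (return laneVal;) reduces each lane's vector fragment to a scalar and then combines scalars across the subgroup with XOR shuffles, which is what the deleted option's description called "subgroup reductions using subgroup shuffles". A hedged sketch of that butterfly shape, assuming a power-of-two subgroup size and not claiming to be the verbatim lambda:

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
using namespace mlir;

// Sketch only: butterfly reduction across a subgroup. After log2(size)
// XOR-shuffle steps, every lane holds the reduction over all lanes.
static Value warpReductionSketch(Location loc, OpBuilder &builder, Value input,
                                 vector::CombiningKind kind, uint32_t size) {
  // Reduce this lane's vector fragment to a scalar first.
  Value laneVal =
      vector::ReductionOp::create(builder, loc, kind, input).getResult();
  for (uint32_t i = 1; i < size; i <<= 1) {
    Value other = gpu::ShuffleOp::create(builder, loc, laneVal, i,
                                         /*width=*/size,
                                         /*mode=*/gpu::ShuffleMode::XOR)
                      .getShuffleResult();
    laneVal = vector::makeArithReduction(builder, loc, kind, laneVal, other);
  }
  return laneVal;
}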