Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -635,13 +635,17 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
parent = parent.dropSgLayoutAndData();
if (!parent)
return nullptr;
return SliceAttr::get(getContext(), parent, attr.getDims());
}

SliceAttr dropInstData() {
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
parent = parent.dropInstData();
if (!parent)
return nullptr;
return SliceAttr::get(getContext(), parent, attr.getDims());
}

Expand Down
65 changes: 25 additions & 40 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,10 +489,8 @@ struct WgToSgVectorBroadcastOp
for (auto operand : adaptor.getOperands().front()) {
auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
newResultType, operand);
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
!layout.getEffectiveInstDataAsInt().empty())
xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
layout.dropSgLayoutAndData());
xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
layout.dropSgLayoutAndData());

newBroadcastOps.push_back(newBroadcast.getResult());
}
Expand Down Expand Up @@ -738,27 +736,24 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
Location loc = op.getLoc();
auto eltType = vecType.getElementType();

auto setLayoutIfNeeded = [&](Value val) {
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
!layout.getEffectiveInstDataAsInt().empty()) {
xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
layout.dropSgLayoutAndData());
}
auto setLayout = [&](Value val) {
xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
layout.dropSgLayoutAndData());
};

if (vecAttr.isSplat()) {
// Splat: single value for all subgroups
Attribute singleVal = vecAttr.getSplatValue<Attribute>();
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
setLayoutIfNeeded(cstOp->getResult(0));
setLayout(cstOp->getResult(0));
rewriter.replaceOp(op, cstOp);
return success();
} else if (sgShape == wgShape) { // if the entire vector is shared by all
// subgroups, don't distribute
auto newConstOp =
arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
setLayoutIfNeeded(newConstOp->getResult(0));
setLayout(newConstOp->getResult(0));
rewriter.replaceOp(op, newConstOp);
return success();
} else {
Expand Down Expand Up @@ -860,9 +855,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
rewriter, loc, baseConstVec.getType(), mulOffset);
auto finalConst =
arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
setLayoutIfNeeded(baseConstVec);
setLayoutIfNeeded(bcastOffset);
setLayoutIfNeeded(finalConst);
setLayout(baseConstVec);
setLayout(bcastOffset);
setLayout(finalConst);
newConstOps.push_back(finalConst);
}
rewriter.replaceOpWithMultiple(op, {newConstOps});
Expand Down Expand Up @@ -969,14 +964,11 @@ struct WgToSgStoreScatterOpWithOffset
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
layout.dropSgLayoutAndData());
// Update the layout attribute to drop sg_layout and sg_data.
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
!layout.getEffectiveInstDataAsInt().empty()) {
for (OpOperand &operand : store->getOpOperands()) {
// Skip for operand one (memref)
if (operand.getOperandNumber() == 1)
continue;
xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
}
for (OpOperand &operand : store->getOpOperands()) {
// Skip for operand one (memref)
if (operand.getOperandNumber() == 1)
continue;
xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
}
}
rewriter.eraseOp(op);
Expand Down Expand Up @@ -1069,15 +1061,12 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
auto finalSteps =
arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
!layout.getEffectiveInstDataAsInt().empty()) {
xegpu::setDistributeLayoutAttr(steps->getResult(0),
layout.dropSgLayoutAndData());
xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
layout.dropSgLayoutAndData());
xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
layout.dropSgLayoutAndData());
}
xegpu::setDistributeLayoutAttr(steps->getResult(0),
layout.dropSgLayoutAndData());
xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
layout.dropSgLayoutAndData());
xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
layout.dropSgLayoutAndData());
newOps.push_back(finalSteps);
}

Expand Down Expand Up @@ -1145,10 +1134,8 @@ struct WgToSgVectorShapeCastOp
for (auto src : adaptor.getSource()) {
auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
newResultType, src);
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
!layout.getEffectiveInstDataAsInt().empty())
xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
layout.dropSgLayoutAndData());
xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
layout.dropSgLayoutAndData());
newShapeCastOps.push_back(newShapeCast.getResult());
}

Expand Down Expand Up @@ -1209,10 +1196,8 @@ struct WgToSgMultiDimReductionOp
auto newOp = vector::MultiDimReductionOp::create(
rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
adaptor.getAcc()[0], op.getReductionDims());
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
!layout.getEffectiveInstDataAsInt().empty())
xegpu::setDistributeLayoutAttr(newOp->getResult(0),
layout.dropSgLayoutAndData());
xegpu::setDistributeLayoutAttr(newOp->getResult(0),
layout.dropSgLayoutAndData());
newReductions.push_back(newOp.getResult());
}

Expand Down