From 97c9c9507c50ddf954b06c9bd0b3881e927cc63f Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 18 Nov 2025 22:22:26 +0000 Subject: [PATCH 1/2] Fix dropSgLayoutAndData & dropInstData in SliceAttr --- .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 8 +++ .../Transforms/XeGPUWgToSgDistribute.cpp | 65 +++++++------------ 2 files changed, 33 insertions(+), 40 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 3f27d690f949b..c464c156e1fad 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -635,6 +635,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> { SliceAttr attr = flatten(); auto parent = dyn_cast(attr.getParent()); parent = parent.dropSgLayoutAndData(); + if (!parent) + return nullptr; + if(!parent.getInstData() && !parent.getLaneLayout()) + return nullptr; return SliceAttr::get(getContext(), parent, attr.getDims()); } @@ -642,6 +646,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> { SliceAttr attr = flatten(); auto parent = dyn_cast(attr.getParent()); parent = parent.dropInstData(); + if (!parent) + return nullptr; + if (!parent.getSgLayout() && !parent.getLaneLayout()) + return nullptr; return SliceAttr::get(getContext(), parent, attr.getDims()); } diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 33d4b0457e5d3..8d4e1d11e2873 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -489,10 +489,8 @@ struct WgToSgVectorBroadcastOp for (auto operand : adaptor.getOperands().front()) { auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(), newResultType, operand); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) - xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), - layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), + layout.dropSgLayoutAndData()); newBroadcastOps.push_back(newBroadcast.getResult()); } @@ -738,12 +736,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern { Location loc = op.getLoc(); auto eltType = vecType.getElementType(); - auto setLayoutIfNeeded = [&](Value val) { - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) { - xegpu::setDistributeLayoutAttr(llvm::dyn_cast(val), - layout.dropSgLayoutAndData()); - } + auto setLayout = [&](Value val) { + xegpu::setDistributeLayoutAttr(llvm::dyn_cast(val), + layout.dropSgLayoutAndData()); }; if (vecAttr.isSplat()) { @@ -751,14 +746,14 @@ struct WgToSgArithConstantOp : public OpConversionPattern { Attribute singleVal = vecAttr.getSplatValue(); auto sgAttr = DenseElementsAttr::get(newType, singleVal); auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr); - setLayoutIfNeeded(cstOp->getResult(0)); + setLayout(cstOp->getResult(0)); rewriter.replaceOp(op, cstOp); return success(); } else if (sgShape == wgShape) { // if the entire vector is shared by all // subgroups, don't distribute auto newConstOp = arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr); - setLayoutIfNeeded(newConstOp->getResult(0)); + setLayout(newConstOp->getResult(0)); rewriter.replaceOp(op, newConstOp); return success(); } else { @@ -860,9 +855,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern { rewriter, loc, baseConstVec.getType(), mulOffset); auto finalConst = arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset); - setLayoutIfNeeded(baseConstVec); - setLayoutIfNeeded(bcastOffset); - setLayoutIfNeeded(finalConst); + setLayout(baseConstVec); + setLayout(bcastOffset); + setLayout(finalConst); newConstOps.push_back(finalConst); } rewriter.replaceOpWithMultiple(op, {newConstOps}); @@ -969,14 +964,11 @@ struct WgToSgStoreScatterOpWithOffset op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), layout.dropSgLayoutAndData()); // Update the layout attribute to drop sg_layout and sg_data. - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) { - for (OpOperand &operand : store->getOpOperands()) { - // Skip for operand one (memref) - if (operand.getOperandNumber() == 1) - continue; - xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData()); - } + for (OpOperand &operand : store->getOpOperands()) { + // Skip for operand one (memref) + if (operand.getOperandNumber() == 1) + continue; + xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData()); } } rewriter.eraseOp(op); @@ -1069,15 +1061,12 @@ struct WgToSgVectorStepOp : public OpConversionPattern { vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]); auto finalSteps = arith::AddIOp::create(rewriter, loc, steps, bcastOffset); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) { - xegpu::setDistributeLayoutAttr(steps->getResult(0), - layout.dropSgLayoutAndData()); - xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0), - layout.dropSgLayoutAndData()); - xegpu::setDistributeLayoutAttr(finalSteps->getResult(0), - layout.dropSgLayoutAndData()); - } + xegpu::setDistributeLayoutAttr(steps->getResult(0), + layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0), + layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(finalSteps->getResult(0), + layout.dropSgLayoutAndData()); newOps.push_back(finalSteps); } @@ -1145,10 +1134,8 @@ struct WgToSgVectorShapeCastOp for (auto src : adaptor.getSource()) { auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(), newResultType, src); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) - xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), - layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), + layout.dropSgLayoutAndData()); newShapeCastOps.push_back(newShapeCast.getResult()); } @@ -1209,10 +1196,8 @@ struct WgToSgMultiDimReductionOp auto newOp = vector::MultiDimReductionOp::create( rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc, adaptor.getAcc()[0], op.getReductionDims()); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) - xegpu::setDistributeLayoutAttr(newOp->getResult(0), - layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(newOp->getResult(0), + layout.dropSgLayoutAndData()); newReductions.push_back(newOp.getResult()); } From b40b44cfbbf8b884d00a035066398240df3817a3 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 18 Nov 2025 23:07:30 +0000 Subject: [PATCH 2/2] Fix check --- mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index c464c156e1fad..93c5187b00756 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -637,8 +637,6 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> { parent = parent.dropSgLayoutAndData(); if (!parent) return nullptr; - if(!parent.getInstData() && !parent.getLaneLayout()) - return nullptr; return SliceAttr::get(getContext(), parent, attr.getDims()); } @@ -648,8 +646,6 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> { parent = parent.dropInstData(); if (!parent) return nullptr; - if (!parent.getSgLayout() && !parent.getLaneLayout()) - return nullptr; return SliceAttr::get(getContext(), parent, attr.getDims()); }