From eaaca7f54a9333b1841283b4483cb9c8f91f9f6b Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 Aug 2025 16:42:52 +0000 Subject: [PATCH 01/26] save --- .../Vector/Transforms/VectorDistribute.cpp | 242 +++++++++++++++--- 1 file changed, 213 insertions(+), 29 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index be0d28a91cba7..2d9fcaee37282 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -15,13 +15,19 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" +#include #include using namespace mlir; @@ -939,8 +945,40 @@ struct WarpOpForwardOperand : public WarpDistributionPattern { } }; +static VectorType +tryFindDistributedType(TypedValue source, + WarpExecuteOnLane0Op warpOp, + const DistributionMapFn &distributionMapFn) { + VectorType distributedType = source.getType(); + // Check if the source is yielded from the warp op. + gpu::YieldOp yieldOp = cast( + warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + auto *it = llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) { + return operand.get() == source; + }); + + if (it != yieldOp->getOpOperands().end()) { + // If the source is yielded from the warp op, we can use the matching + // warp result type as the distributed source type. + distributedType = + cast(warpOp->getResultTypes()[it->getOperandNumber()]); + } else { + // If the source is not yielded from the warp op, we need to compute + // the distributed source type based on the distribution map and the + // warp size. + AffineMap map = distributionMapFn(source); + VectorType computed = + getDistributedType(source.getType(), map, warpOp.getWarpSize()); + if (!computed) + return source.getType(); + distributedType = computed; + } + return distributedType; +} + struct WarpOpBroadcast : public WarpDistributionPattern { - using Base::Base; + WarpOpBroadcast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1) + : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {} LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { OpOperand *operand = @@ -953,18 +991,23 @@ struct WarpOpBroadcast : public WarpDistributionPattern { auto destVecType = cast(warpOp->getResultTypes()[operandNumber]); Value broadcastSrc = broadcastOp.getSource(); - Type broadcastSrcType = broadcastSrc.getType(); + Type srcDistributedType = broadcastSrc.getType(); + + if (isa(srcDistributedType)) + srcDistributedType = + tryFindDistributedType(cast>(broadcastSrc), + warpOp, distributionMapFn); // Check that the broadcast actually spans a set of values uniformly across // all threads. In other words, check that each thread can reconstruct // their own broadcast. // For that we simply check that the broadcast we want to build makes sense. 
- if (vector::isBroadcastableTo(broadcastSrcType, destVecType) != + if (vector::isBroadcastableTo(srcDistributedType, destVecType) != vector::BroadcastableToResult::Success) return failure(); SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices); + rewriter, warpOp, {broadcastSrc}, {srcDistributedType}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); Value broadcasted = vector::BroadcastOp::create( rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0])); @@ -972,49 +1015,83 @@ struct WarpOpBroadcast : public WarpDistributionPattern { broadcasted); return success(); } + +private: + DistributionMapFn distributionMapFn; }; /// Pattern to move shape cast out of the warp op. shape cast is basically a /// no-op for warp distribution; we need to handle the shape though. struct WarpOpShapeCast : public WarpDistributionPattern { - using Base::Base; + + WarpOpShapeCast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1) + : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {} LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred); if (!operand) return failure(); - auto oldCastOp = operand->get().getDefiningOp(); unsigned int operandNumber = operand->getOperandNumber(); - auto castDistributedType = + VectorType sourceType = oldCastOp.getSourceVectorType(); + VectorType distributedResultType = cast(warpOp->getResultTypes()[operandNumber]); - VectorType castOriginalType = oldCastOp.getSourceVectorType(); - VectorType castResultType = castDistributedType; - - // We expect the distributed type to have a smaller rank than the original - // type. Prepend with size-one dimensions to make them the same. - unsigned castDistributedRank = castDistributedType.getRank(); - unsigned castOriginalRank = castOriginalType.getRank(); - if (castDistributedRank < castOriginalRank) { - SmallVector shape(castOriginalRank - castDistributedRank, 1); - llvm::append_range(shape, castDistributedType.getShape()); - castDistributedType = - VectorType::get(shape, castDistributedType.getElementType()); + VectorType distributedSourceType = sourceType; + bool isResultDistributed = distributedResultType.getNumElements() < + oldCastOp.getResultVectorType().getNumElements(); + + // If the result is not distributed, source distribted type is the same + // as the source type. If the result is distributed, we need to compute the + // distributed source type according to following rules: + // 1. If the source type is yielded from the warp op, we can use the + // matching warp result type as the distributed source type. + // 2. If the source type is not yielded from the warp op, we need + // to compute the distributed source type based on the distribution map + // and the warp size. + if (isResultDistributed) { + // Check if the source is yielded from the warp op. + gpu::YieldOp yieldOp = cast( + warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + auto *it = + llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) { + return operand.get() == oldCastOp.getSource(); + }); + + if (it != yieldOp->getOpOperands().end()) { + // If the source is yielded from the warp op, we can use the matching + // warp result type as the distributed source type. 
+ distributedSourceType = + cast(warpOp->getResultTypes()[it->getOperandNumber()]); + } else { + // If the source is not yielded from the warp op, we need to compute + // the distributed source type based on the distribution map and the + // warp size. + AffineMap map = distributionMapFn(oldCastOp.getSource()); + distributedSourceType = + getDistributedType(sourceType, map, warpOp.getWarpSize()); + if (!distributedSourceType) + return rewriter.notifyMatchFailure( + oldCastOp, + "cannot compute distributed source type for shape cast"); + } } SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType}, + rewriter, warpOp, {oldCastOp.getSource()}, {distributedSourceType}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); Value newCast = vector::ShapeCastOp::create( - rewriter, oldCastOp.getLoc(), castResultType, + rewriter, oldCastOp.getLoc(), distributedResultType, newWarpOp->getResult(newRetIndices[0])); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast); return success(); } + +private: + DistributionMapFn distributionMapFn; }; /// Sink out vector.create_mask op feeding into a warp op yield. @@ -1996,6 +2073,114 @@ struct WarpOpReduction : public WarpDistributionPattern { DistributedReductionFn distributedReductionFn; }; +struct VectorMultiDimReductionDistribution : public WarpDistributionPattern { + VectorMultiDimReductionDistribution(MLIRContext *context, + PatternBenefit benefit = 1) + : WarpDistributionPattern(context, benefit) {} + LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *yieldOperand = + getWarpResult(warpOp, llvm::IsaPred); + if (!yieldOperand) + return failure(); + auto reductionOp = + cast(yieldOperand->get().getDefiningOp()); + unsigned operandNumber = yieldOperand->getOperandNumber(); + VectorType sourceType = reductionOp.getSourceVectorType(); + VectorType distributedResultType = + cast(warpOp.getResult(operandNumber).getType()); + Type elementType = distributedResultType.getElementType(); + // Only 2D vectors are supported. + if (sourceType.getRank() != 2) + return rewriter.notifyMatchFailure(warpOp, + "Only 2D reductions are supported."); + ArrayRef reductionDims = reductionOp.getReductionDims(); + // Only 1 reduction dimension supported. + if (reductionDims.size() != 1) + return rewriter.notifyMatchFailure( + warpOp, "Only 1 reduction dimension is supported."); + + // Col reduction. + if (reductionDims[0] == 0) { + // Yield the source vector and the accumulator. + if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0) + return rewriter.notifyMatchFailure( + warpOp, "Source vector dimension must be divisible by warp size."); + SmallVector shape(sourceType.getShape()); + shape[1] = shape[1] / warpOp.getWarpSize(); + auto sourceDistributedType = VectorType::get(shape, elementType); + SmallVector newRetIndices; + auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()}, + {sourceDistributedType, distributedResultType}, newRetIndices); + rewriter.setInsertionPointAfter(newWarpOp); + // Create new reduction op. 
+ // auto newOp = vector::MultiDimReductionOp::create( + // rewriter, reductionOp.getLoc(), distributedResultType, + // reductionOp.getKind(), + // /** source = **/ newWarpOp.getResult(newRetIndices[0]), + // /** accumulator = **/ newWarpOp.getResult(newRetIndices[1]), + // reductionDims); + // Create a constant zero value for storing the reduction result. + // rewriter.setInsertionPointAfter(reductionOp); + auto zeroAttr = + rewriter.getZeroAttr(distributedResultType.getElementType()); + Value result = arith::ConstantOp::create( + rewriter, reductionOp->getLoc(), distributedResultType, + DenseElementsAttr::get(distributedResultType, zeroAttr)); + int nCols = sourceDistributedType.getShape()[1]; + Value source = newWarpOp.getResult(newRetIndices[0]); + Value acc = newWarpOp.getResult(newRetIndices[1]); + for (int i = 0; i < nCols; ++i) { + Value col = vector::ExtractStridedSliceOp::create( + rewriter, reductionOp.getLoc(), source, {0, i}, + {sourceDistributedType.getShape()[0], 1}, {1, 1}); + col = vector::ShapeCastOp::create( + rewriter, reductionOp.getLoc(), + VectorType::get({sourceDistributedType.getShape()[0]}, elementType), + col); + Value accCol = + vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i); + Value colReduce = vector::ReductionOp::create( + rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol); + // Insert the reduced column into the result. + result = vector::InsertOp::create(rewriter, reductionOp.getLoc(), + colReduce, result, i); + } + // Replace the warp op result with the new reduction op. + rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result); + return success(); + } + // Row reduction. + // Create a constant zero value for storing the reduction result. + rewriter.setInsertionPointAfter(reductionOp); + auto zeroAttr = + rewriter.getZeroAttr(distributedResultType.getElementType()); + Value result = arith::ConstantOp::create( + rewriter, reductionOp->getLoc(), distributedResultType, + DenseElementsAttr::get(distributedResultType, zeroAttr)); + // Value result = arith::ConstantOp::create( + // rewriter, reductionOp.getLoc(), + // rewriter.getIntegerAttr(reductionOp.getType(), 0)); + int nRows = sourceType.getShape()[0]; + // For each row, do a vector reduction. + for (int i = 0; i < nRows; ++i) { + Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(), + reductionOp.getSource(), i); + Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(), + reductionOp.getAcc(), i); + Value rowReduce = vector::ReductionOp::create( + rewriter, reductionOp.getLoc(), reductionOp.getKind(), source, acc); + result = vector::InsertOp::create(rewriter, reductionOp.getLoc(), + rowReduce, result, i); + } + // Replace the warp op result with the final result. 
+ rewriter.replaceAllUsesWith(reductionOp.getResult(), result); + + return success(); + } +}; + } // namespace void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern( @@ -2016,16 +2201,15 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns( const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit, PatternBenefit readBenefit) { patterns.add(patterns.getContext(), readBenefit); - patterns - .add( - patterns.getContext(), benefit); + patterns.add( + patterns.getContext(), benefit); patterns.add(patterns.getContext(), warpShuffleFromIdxFn, benefit); - patterns.add(patterns.getContext(), distributionMapFn, - benefit); + patterns.add( + patterns.getContext(), distributionMapFn, benefit); } void mlir::vector::populateDistributeReduction( From 56c3441e9443660788e51064f8206c5e4ac9fbaf Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 Aug 2025 23:07:10 +0000 Subject: [PATCH 02/26] save --- .../Vector/Transforms/VectorDistribute.cpp | 110 +++++------------ .../Vector/vector-warp-distribute.mlir | 111 ++++++++++++++++++ 2 files changed, 143 insertions(+), 78 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 2d9fcaee37282..6410a895fc9ae 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -945,40 +945,8 @@ struct WarpOpForwardOperand : public WarpDistributionPattern { } }; -static VectorType -tryFindDistributedType(TypedValue source, - WarpExecuteOnLane0Op warpOp, - const DistributionMapFn &distributionMapFn) { - VectorType distributedType = source.getType(); - // Check if the source is yielded from the warp op. - gpu::YieldOp yieldOp = cast( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - auto *it = llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) { - return operand.get() == source; - }); - - if (it != yieldOp->getOpOperands().end()) { - // If the source is yielded from the warp op, we can use the matching - // warp result type as the distributed source type. - distributedType = - cast(warpOp->getResultTypes()[it->getOperandNumber()]); - } else { - // If the source is not yielded from the warp op, we need to compute - // the distributed source type based on the distribution map and the - // warp size. - AffineMap map = distributionMapFn(source); - VectorType computed = - getDistributedType(source.getType(), map, warpOp.getWarpSize()); - if (!computed) - return source.getType(); - distributedType = computed; - } - return distributedType; -} - struct WarpOpBroadcast : public WarpDistributionPattern { - WarpOpBroadcast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1) - : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {} + using Base::Base; LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { OpOperand *operand = @@ -991,23 +959,18 @@ struct WarpOpBroadcast : public WarpDistributionPattern { auto destVecType = cast(warpOp->getResultTypes()[operandNumber]); Value broadcastSrc = broadcastOp.getSource(); - Type srcDistributedType = broadcastSrc.getType(); - - if (isa(srcDistributedType)) - srcDistributedType = - tryFindDistributedType(cast>(broadcastSrc), - warpOp, distributionMapFn); + Type broadcastSrcType = broadcastSrc.getType(); // Check that the broadcast actually spans a set of values uniformly across // all threads. 
In other words, check that each thread can reconstruct // their own broadcast. // For that we simply check that the broadcast we want to build makes sense. - if (vector::isBroadcastableTo(srcDistributedType, destVecType) != + if (vector::isBroadcastableTo(broadcastSrcType, destVecType) != vector::BroadcastableToResult::Success) return failure(); SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, {broadcastSrc}, {srcDistributedType}, newRetIndices); + rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); Value broadcasted = vector::BroadcastOp::create( rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0])); @@ -1015,9 +978,6 @@ struct WarpOpBroadcast : public WarpDistributionPattern { broadcasted); return success(); } - -private: - DistributionMapFn distributionMapFn; }; /// Pattern to move shape cast out of the warp op. shape cast is basically a @@ -2100,37 +2060,37 @@ struct VectorMultiDimReductionDistribution : public WarpDistributionPattern { return rewriter.notifyMatchFailure( warpOp, "Only 1 reduction dimension is supported."); + // Create a constant vector to store the result of the reduction per lane. + TypedAttr zeroAttr = + rewriter.getZeroAttr(distributedResultType.getElementType()); + Value result = arith::ConstantOp::create( + rewriter, reductionOp->getLoc(), distributedResultType, + DenseElementsAttr::get(distributedResultType, zeroAttr)); + // Col reduction. if (reductionDims[0] == 0) { - // Yield the source vector and the accumulator. + // Source vector must be distributable to lanes in the col dimension. if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0) return rewriter.notifyMatchFailure( warpOp, "Source vector dimension must be divisible by warp size."); + // Compute source distributed type. SmallVector shape(sourceType.getShape()); shape[1] = shape[1] / warpOp.getWarpSize(); auto sourceDistributedType = VectorType::get(shape, elementType); + + // Yield the source and acc vectors from the WarpOp. SmallVector newRetIndices; auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns( rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()}, {sourceDistributedType, distributedResultType}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); - // Create new reduction op. - // auto newOp = vector::MultiDimReductionOp::create( - // rewriter, reductionOp.getLoc(), distributedResultType, - // reductionOp.getKind(), - // /** source = **/ newWarpOp.getResult(newRetIndices[0]), - // /** accumulator = **/ newWarpOp.getResult(newRetIndices[1]), - // reductionDims); - // Create a constant zero value for storing the reduction result. - // rewriter.setInsertionPointAfter(reductionOp); - auto zeroAttr = - rewriter.getZeroAttr(distributedResultType.getElementType()); - Value result = arith::ConstantOp::create( - rewriter, reductionOp->getLoc(), distributedResultType, - DenseElementsAttr::get(distributedResultType, zeroAttr)); + int nCols = sourceDistributedType.getShape()[1]; Value source = newWarpOp.getResult(newRetIndices[0]); Value acc = newWarpOp.getResult(newRetIndices[1]); + // For each column owned by a lane, extract the column (of size nRows x + // 1), shape cast to 1D (nRows), do a vector.reduction and, insert the + // result back to the result vector. 
for (int i = 0; i < nCols; ++i) { Value col = vector::ExtractStridedSliceOp::create( rewriter, reductionOp.getLoc(), source, {0, i}, @@ -2143,7 +2103,6 @@ struct VectorMultiDimReductionDistribution : public WarpDistributionPattern { vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i); Value colReduce = vector::ReductionOp::create( rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol); - // Insert the reduced column into the result. result = vector::InsertOp::create(rewriter, reductionOp.getLoc(), colReduce, result, i); } @@ -2151,19 +2110,13 @@ struct VectorMultiDimReductionDistribution : public WarpDistributionPattern { rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result); return success(); } - // Row reduction. - // Create a constant zero value for storing the reduction result. + // For row reductions, we simply rewrite the MultiReductionOp in terms of + // multiple ReductionOps. Actual distribution is done by the WarpOpReduction + // pattern. rewriter.setInsertionPointAfter(reductionOp); - auto zeroAttr = - rewriter.getZeroAttr(distributedResultType.getElementType()); - Value result = arith::ConstantOp::create( - rewriter, reductionOp->getLoc(), distributedResultType, - DenseElementsAttr::get(distributedResultType, zeroAttr)); - // Value result = arith::ConstantOp::create( - // rewriter, reductionOp.getLoc(), - // rewriter.getIntegerAttr(reductionOp.getType(), 0)); int nRows = sourceType.getShape()[0]; - // For each row, do a vector reduction. + // For each row of the source, extract the row vector, do a reduction and, + // insert the result back to the result. for (int i = 0; i < nRows; ++i) { Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(), reductionOp.getSource(), i); @@ -2201,15 +2154,16 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns( const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit, PatternBenefit readBenefit) { patterns.add(patterns.getContext(), readBenefit); - patterns.add( - patterns.getContext(), benefit); + patterns + .add( + patterns.getContext(), benefit); patterns.add(patterns.getContext(), warpShuffleFromIdxFn, benefit); - patterns.add( - patterns.getContext(), distributionMapFn, benefit); + patterns.add(patterns.getContext(), + distributionMapFn, benefit); } void mlir::vector::populateDistributeReduction( diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index 4d2c964a6df3c..bf70fbbd27244 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -850,6 +850,83 @@ func.func @vector_reduction_acc(%laneid: index) -> (f32) { return %r : f32 } +// ----- +// CHECK-PROP-LABEL: func.func @vector_multi_reduction_col_reduce +// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) { +// CHECK-PROP: %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32> +// CHECK-PROP: %[[ACC:.*]] = "some_def"() : () -> vector<64xf32> +// CHECK-PROP: gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32> +// CHECK-PROP: } +// CHECK-PROP: %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32> +// CHECK-PROP: %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32> +// CHECK-PROP: %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32> +// CHECK-PROP: 
%[[REDUCE0:.*]] = vector.reduction , %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32 +// CHECK-PROP: %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32> +// CHECK-PROP: %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32> +// CHECK-PROP: %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32> +// CHECK-PROP: %[[REDUCE1:.*]] = vector.reduction , %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32 +// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32> +// CHECK-PROP: return %[[R]] : vector<2xf32> +func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> { + %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { + %0 = "some_def"() : () -> (vector<32x64xf32>) + %acc = "some_def"() : () -> (vector<64xf32>) + %1 = vector.multi_reduction , %0, %acc [0] : vector<32x64xf32> to vector<64xf32> + gpu.yield %1 : vector<64xf32> + } + return %r : vector<2xf32> +} + +// ----- +// CHECK-PROP-LABEL: func.func @vector_multi_reduction_row_reduce +// CHECK-PROP: %[[C16:.*]] = arith.constant 16 : i32 +// CHECK-PROP: %[[C8:.*]] = arith.constant 8 : i32 +// CHECK-PROP: %[[C4:.*]] = arith.constant 4 : i32 +// CHECK-PROP: %[[C2:.*]] = arith.constant 2 : i32 +// CHECK-PROP: %[[C1:.*]] = arith.constant 1 : i32 +// CHECK-PROP: %[[C32:.*]] = arith.constant 32 : i32 +// CHECK-PROP: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) { +// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32> +// CHECK-PROP: gpu.yield %[[SRC]] : vector<2x32xf32> +// CHECK-PROP: } +// CHECK-PROP: %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 from vector<2x1xf32> +// CHECK-PROP: %[[SR:.*]], %{{.*}} = gpu.shuffle xor %[[T1]], %[[C1]], %[[C32]] : f32 +// CHECK-PROP: %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32 +// CHECK-PROP: %[[SR0:.*]], %{{.*}} = gpu.shuffle xor %[[T2]], %[[C2]], %[[C32]] : f32 +// CHECK-PROP: %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32 +// CHECK-PROP: %[[SR2:.*]], %{{.*}} = gpu.shuffle xor %[[T3]], %[[C4]], %[[C32]] : f32 +// CHECK-PROP: %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32 +// CHECK-PROP: %[[SR4:.*]], %{{.*}} = gpu.shuffle xor %[[T4]], %[[C8]], %[[C32]] : f32 +// CHECK-PROP: %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32 +// CHECK-PROP: %[[SR6:.*]], %{{.*}} = gpu.shuffle xor %[[T5]], %[[C16]], %[[C32]] : f32 +// CHECK-PROP: %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32 +// CHECK-PROP: %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32 +// +// CHECK-PROP: %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32> +// CHECK-PROP: %[[SR8:.*]], %{{.*}} = gpu.shuffle xor %[[T8]], %[[C1]], %[[C32]] : f32 +// CHECK-PROP: %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32 +// CHECK-PROP: %[[SR10:.*]], %{{.*}} = gpu.shuffle xor %[[T9]], %[[C2]], %[[C32]] : f32 +// CHECK-PROP: %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32 +// CHECK-PROP: %[[SR12:.*]], %{{.*}} = gpu.shuffle xor %[[T10]], %[[C4]], %[[C32]] : f32 +// CHECK-PROP: %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32 +// CHECK-PROP: %[[SR14:.*]], %{{.*}} = gpu.shuffle xor %[[T11]], %[[C8]], %[[C32]] : f32 +// CHECK-PROP: %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32 +// CHECK-PROP: %[[SR16:.*]], %{{.*}} = gpu.shuffle xor %[[T12]], %[[C16]], %[[C32]] : f32 +// CHECK-PROP: %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32 +// CHECK-PROP: 
%[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32 +// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32> +// CHECK-PROP: return %[[R]] : vector<2xf32> +func.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> { + %zero = arith.constant dense<0.0> : vector<2xf32> + %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { + %0 = "some_def"() : () -> (vector<2x32xf32>) + %1 = vector.multi_reduction , %0, %zero [1] : vector<2x32xf32> to vector<2xf32> + gpu.yield %1 : vector<2xf32> + } + return %r : vector<2xf32> +} + // ----- // CHECK-PROP-LABEL: func @warp_duplicate_yield( @@ -1567,6 +1644,40 @@ func.func @warp_propagate_shape_cast(%laneid: index, %src: memref<32x4x32xf32>) // CHECK-PROP: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<1x1x4xf32> to vector<4xf32> // CHECK-PROP: return %[[CAST]] : vector<4xf32> +// ----- +func.func @warp_propagate_shape_cast_2d_to_2d(%laneid: index, %src: memref<64x32xf32>) -> vector<32x2xf32> { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<32x2xf32>) { + %2 = vector.transfer_read %src[%c0, %c0], %cst : memref<64x32xf32>, vector<64x32xf32> + %3 = vector.shape_cast %2 : vector<64x32xf32> to vector<32x64xf32> + gpu.yield %3 : vector<32x64xf32> + } + return %r : vector<32x2xf32> +} + +// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_2d_to_2d +// CHECK-PROP: %[[READ:.*]] = vector.transfer_read {{.*}} {in_bounds = [false, true]} : memref<64x32xf32>, vector<2x32xf32> +// CHECK-PROP: %[[CAST:.*]] = vector.shape_cast %[[READ]] : vector<2x32xf32> to vector<32x2xf32> +// CHECK-PROP: return %[[CAST]] : vector<32x2xf32> + +// ----- +func.func @warp_propagate_shape_cast_non_distributed_result(%laneid: index, %src: memref<64xf32>) -> vector<8x4x2xf32> { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x4x2xf32>) { + %2 = vector.transfer_read %src[%c0], %cst : memref<64xf32>, vector<64xf32> + %3 = vector.shape_cast %2 : vector<64xf32> to vector<8x4x2xf32> + gpu.yield %3 : vector<8x4x2xf32> + } + return %r : vector<8x4x2xf32> +} + +// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_non_distributed_result +// CHECK-PROP: %[[READ:.*]] = vector.transfer_read {{.*}} {in_bounds = [true]} : memref<64xf32>, vector<64xf32> +// CHECK-PROP: %[[CAST:.*]] = vector.shape_cast %[[READ]] : vector<64xf32> to vector<8x4x2xf32> +// CHECK-PROP: return %[[CAST]] : vector<8x4x2xf32> + // ----- func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<4096xf32>, %index: index) -> vector<1xf32> { From 01880b561e94c6cb752e6eddb16957e00dbdc97f Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 Aug 2025 23:26:49 +0000 Subject: [PATCH 03/26] save --- mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 6410a895fc9ae..8dc1418e09006 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -2033,6 +2033,12 @@ struct WarpOpReduction : public WarpDistributionPattern { DistributedReductionFn distributedReductionFn; }; +// This patterns distribute the `vector.multi_reduction` operation across +// lanes in a warp. 
Currently only 2D to 1D reductions are supported and assumes +// that source vector is distributed in column dimension (i.e. Each lane owns +// complete column(s) of the source vector. +// TODO: Add support for the case where source rows are distributed accross +// lanes. Requires DistributionMapFn to express the data distribution. struct VectorMultiDimReductionDistribution : public WarpDistributionPattern { VectorMultiDimReductionDistribution(MLIRContext *context, PatternBenefit benefit = 1) From 53da9928117634d6eb929f81cbfa59ed4c06d884 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 Aug 2025 23:28:13 +0000 Subject: [PATCH 04/26] save --- mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 8dc1418e09006..c88c001f34843 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -2036,9 +2036,9 @@ struct WarpOpReduction : public WarpDistributionPattern { // This patterns distribute the `vector.multi_reduction` operation across // lanes in a warp. Currently only 2D to 1D reductions are supported and assumes // that source vector is distributed in column dimension (i.e. Each lane owns -// complete column(s) of the source vector. -// TODO: Add support for the case where source rows are distributed accross -// lanes. Requires DistributionMapFn to express the data distribution. +// complete column(s) of the source vector). +// TODO: Add support for the case where source rows are distributed across +// lanes. Requires `DistributionMapFn` to express the data distribution. struct VectorMultiDimReductionDistribution : public WarpDistributionPattern { VectorMultiDimReductionDistribution(MLIRContext *context, PatternBenefit benefit = 1) From affd4aadb2e0f3f7cd19b0805b34067c1fa65371 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 19 Aug 2025 23:48:08 +0000 Subject: [PATCH 05/26] save --- .../Vector/vector-warp-distribute.mlir | 74 +++++++++---------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index bf70fbbd27244..bf0191655d654 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -879,44 +879,44 @@ func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> { // ----- // CHECK-PROP-LABEL: func.func @vector_multi_reduction_row_reduce -// CHECK-PROP: %[[C16:.*]] = arith.constant 16 : i32 -// CHECK-PROP: %[[C8:.*]] = arith.constant 8 : i32 -// CHECK-PROP: %[[C4:.*]] = arith.constant 4 : i32 -// CHECK-PROP: %[[C2:.*]] = arith.constant 2 : i32 -// CHECK-PROP: %[[C1:.*]] = arith.constant 1 : i32 -// CHECK-PROP: %[[C32:.*]] = arith.constant 32 : i32 -// CHECK-PROP: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) { -// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32> -// CHECK-PROP: gpu.yield %[[SRC]] : vector<2x32xf32> -// CHECK-PROP: } -// CHECK-PROP: %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 from vector<2x1xf32> -// CHECK-PROP: %[[SR:.*]], %{{.*}} = gpu.shuffle xor %[[T1]], %[[C1]], %[[C32]] : f32 -// CHECK-PROP: %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32 -// CHECK-PROP: %[[SR0:.*]], %{{.*}} = 
gpu.shuffle xor %[[T2]], %[[C2]], %[[C32]] : f32 -// CHECK-PROP: %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32 -// CHECK-PROP: %[[SR2:.*]], %{{.*}} = gpu.shuffle xor %[[T3]], %[[C4]], %[[C32]] : f32 -// CHECK-PROP: %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32 -// CHECK-PROP: %[[SR4:.*]], %{{.*}} = gpu.shuffle xor %[[T4]], %[[C8]], %[[C32]] : f32 -// CHECK-PROP: %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32 -// CHECK-PROP: %[[SR6:.*]], %{{.*}} = gpu.shuffle xor %[[T5]], %[[C16]], %[[C32]] : f32 -// CHECK-PROP: %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32 -// CHECK-PROP: %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32 +// CHECK-PROP-DAG: %[[C16:.*]] = arith.constant 16 : i32 +// CHECK-PROP-DAG: %[[C8:.*]] = arith.constant 8 : i32 +// CHECK-PROP-DAG: %[[C4:.*]] = arith.constant 4 : i32 +// CHECK-PROP-DAG: %[[C2:.*]] = arith.constant 2 : i32 +// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : i32 +// CHECK-PROP-DAG: %[[C32:.*]] = arith.constant 32 : i32 +// CHECK-PROP-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) { +// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32> +// CHECK-PROP: gpu.yield %[[SRC]] : vector<2x32xf32> +// CHECK-PROP: } +// CHECK-PROP: %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 from vector<2x1xf32> +// CHECK-PROP: %[[SR:.*]], %{{.*}} = gpu.shuffle xor %[[T1]], %[[C1]], %[[C32]] : f32 +// CHECK-PROP: %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32 +// CHECK-PROP: %[[SR0:.*]], %{{.*}} = gpu.shuffle xor %[[T2]], %[[C2]], %[[C32]] : f32 +// CHECK-PROP: %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32 +// CHECK-PROP: %[[SR2:.*]], %{{.*}} = gpu.shuffle xor %[[T3]], %[[C4]], %[[C32]] : f32 +// CHECK-PROP: %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32 +// CHECK-PROP: %[[SR4:.*]], %{{.*}} = gpu.shuffle xor %[[T4]], %[[C8]], %[[C32]] : f32 +// CHECK-PROP: %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32 +// CHECK-PROP: %[[SR6:.*]], %{{.*}} = gpu.shuffle xor %[[T5]], %[[C16]], %[[C32]] : f32 +// CHECK-PROP: %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32 +// CHECK-PROP: %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32 // -// CHECK-PROP: %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32> -// CHECK-PROP: %[[SR8:.*]], %{{.*}} = gpu.shuffle xor %[[T8]], %[[C1]], %[[C32]] : f32 -// CHECK-PROP: %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32 -// CHECK-PROP: %[[SR10:.*]], %{{.*}} = gpu.shuffle xor %[[T9]], %[[C2]], %[[C32]] : f32 -// CHECK-PROP: %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32 -// CHECK-PROP: %[[SR12:.*]], %{{.*}} = gpu.shuffle xor %[[T10]], %[[C4]], %[[C32]] : f32 -// CHECK-PROP: %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32 -// CHECK-PROP: %[[SR14:.*]], %{{.*}} = gpu.shuffle xor %[[T11]], %[[C8]], %[[C32]] : f32 -// CHECK-PROP: %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32 -// CHECK-PROP: %[[SR16:.*]], %{{.*}} = gpu.shuffle xor %[[T12]], %[[C16]], %[[C32]] : f32 -// CHECK-PROP: %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32 -// CHECK-PROP: %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32 -// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32> -// CHECK-PROP: return %[[R]] : vector<2xf32> +// CHECK-PROP: %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32> +// CHECK-PROP: %[[SR8:.*]], %{{.*}} = gpu.shuffle xor %[[T8]], %[[C1]], %[[C32]] : f32 +// CHECK-PROP: %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32 +// CHECK-PROP: %[[SR10:.*]], %{{.*}} = gpu.shuffle xor %[[T9]], %[[C2]], 
%[[C32]] : f32 +// CHECK-PROP: %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32 +// CHECK-PROP: %[[SR12:.*]], %{{.*}} = gpu.shuffle xor %[[T10]], %[[C4]], %[[C32]] : f32 +// CHECK-PROP: %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32 +// CHECK-PROP: %[[SR14:.*]], %{{.*}} = gpu.shuffle xor %[[T11]], %[[C8]], %[[C32]] : f32 +// CHECK-PROP: %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32 +// CHECK-PROP: %[[SR16:.*]], %{{.*}} = gpu.shuffle xor %[[T12]], %[[C16]], %[[C32]] : f32 +// CHECK-PROP: %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32 +// CHECK-PROP: %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32 +// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32> +// CHECK-PROP: return %[[R]] : vector<2xf32> func.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> { %zero = arith.constant dense<0.0> : vector<2xf32> %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { From df59c20f5d8020ab9ba78f1c360334c738a60404 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 20 Aug 2025 00:01:32 +0000 Subject: [PATCH 06/26] save --- mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index c88c001f34843..b0b52919c69ce 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -15,19 +15,13 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" #include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Value.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Transforms/RegionUtils.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/raw_ostream.h" -#include #include using namespace mlir; From 55797318492b6a38801aa27bf9ec97d26523322e Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 20 Aug 2025 23:31:35 +0000 Subject: [PATCH 07/26] save --- .../Vector/Transforms/VectorDistribute.cpp | 52 +++++++++++++------ 1 file changed, 37 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index b0b52919c69ce..ab0f1b55d04da 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -2033,9 +2033,8 @@ struct WarpOpReduction : public WarpDistributionPattern { // complete column(s) of the source vector). // TODO: Add support for the case where source rows are distributed across // lanes. Requires `DistributionMapFn` to express the data distribution. 
-struct VectorMultiDimReductionDistribution : public WarpDistributionPattern { - VectorMultiDimReductionDistribution(MLIRContext *context, - PatternBenefit benefit = 1) +struct WarpOpMultiReduction : public WarpDistributionPattern { + WarpOpMultiReduction(MLIRContext *context, PatternBenefit benefit = 1) : WarpDistributionPattern(context, benefit) {} LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { @@ -2047,18 +2046,46 @@ struct VectorMultiDimReductionDistribution : public WarpDistributionPattern { cast(yieldOperand->get().getDefiningOp()); unsigned operandNumber = yieldOperand->getOperandNumber(); VectorType sourceType = reductionOp.getSourceVectorType(); - VectorType distributedResultType = - cast(warpOp.getResult(operandNumber).getType()); - Type elementType = distributedResultType.getElementType(); + // Only 2D vectors are supported. if (sourceType.getRank() != 2) return rewriter.notifyMatchFailure(warpOp, "Only 2D reductions are supported."); ArrayRef reductionDims = reductionOp.getReductionDims(); - // Only 1 reduction dimension supported. + // Only 1 reduction dimension supported. This also ensures that result is + // also vector type. if (reductionDims.size() != 1) return rewriter.notifyMatchFailure( warpOp, "Only 1 reduction dimension is supported."); + int64_t reductionDim = reductionDims[0]; + auto resultType = cast(reductionOp.getType()); + auto distributedResultType = + cast(warpOp.getResult(operandNumber).getType()); + Type elementType = distributedResultType.getElementType(); + + // Currently we make the following assumptions. + // 1. The source vector is distributed in the column dimension. Each lane + // owns complete column(s) of the source vector. + // 2. If the reduction dim == 0, its a lane-local col reduction. In this + // case each lane owns its portion of the result (i.e. result is also + // distributed). + // 3. If reduction dim == 1, its a row reduction that require cross lanes + // shuffles. In this case result is not distributed and broadcasted instead. + // TODO: These assumptions are fairly restrictive. For example, source + // vector can have row distributed layout. Improve support for such cases. + if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0) + return rewriter.notifyMatchFailure( + warpOp, "Source vector dimension must be divisible by warp size."); + bool isResultDistributed = + distributedResultType.getNumElements() < resultType.getNumElements(); + if (reductionDim == 0 && !isResultDistributed) + return rewriter.notifyMatchFailure( + warpOp, + "Expecting result vector to be distributed in a col reduction."); + if (reductionDim == 1 && isResultDistributed) + return rewriter.notifyMatchFailure( + warpOp, + "Expecting result vector to be broadcasted in a row reduction."); // Create a constant vector to store the result of the reduction per lane. TypedAttr zeroAttr = @@ -2066,14 +2093,9 @@ struct VectorMultiDimReductionDistribution : public WarpDistributionPattern { Value result = arith::ConstantOp::create( rewriter, reductionOp->getLoc(), distributedResultType, DenseElementsAttr::get(distributedResultType, zeroAttr)); - // Col reduction. - if (reductionDims[0] == 0) { - // Source vector must be distributable to lanes in the col dimension. - if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0) - return rewriter.notifyMatchFailure( - warpOp, "Source vector dimension must be divisible by warp size."); - // Compute source distributed type. 
+ if (reductionDim == 0) { + // Compute source distributed type assuming each lane owns cols. SmallVector shape(sourceType.getShape()); shape[1] = shape[1] / warpOp.getWarpSize(); auto sourceDistributedType = VectorType::get(shape, elementType); @@ -2158,7 +2180,7 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns( .add( + WarpOpInsertStridedSlice, WarpOpMultiReduction>( patterns.getContext(), benefit); patterns.add(patterns.getContext(), warpShuffleFromIdxFn, benefit); From 07c0364d64109faf740023107ab68ec0f242d9ca Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 20 Aug 2025 23:36:26 +0000 Subject: [PATCH 08/26] save --- .../Vector/Transforms/VectorDistribute.cpp | 3 +- .../Vector/vector-warp-distribute.mlir | 32 ++++++++++--------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index ab0f1b55d04da..aecb6a11a7b36 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -2034,8 +2034,7 @@ struct WarpOpReduction : public WarpDistributionPattern { // TODO: Add support for the case where source rows are distributed across // lanes. Requires `DistributionMapFn` to express the data distribution. struct WarpOpMultiReduction : public WarpDistributionPattern { - WarpOpMultiReduction(MLIRContext *context, PatternBenefit benefit = 1) - : WarpDistributionPattern(context, benefit) {} + using Base::Base; LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { OpOperand *yieldOperand = diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index bf0191655d654..95b8a48404f20 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -852,21 +852,23 @@ func.func @vector_reduction_acc(%laneid: index) -> (f32) { // ----- // CHECK-PROP-LABEL: func.func @vector_multi_reduction_col_reduce -// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) { -// CHECK-PROP: %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32> -// CHECK-PROP: %[[ACC:.*]] = "some_def"() : () -> vector<64xf32> -// CHECK-PROP: gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32> -// CHECK-PROP: } -// CHECK-PROP: %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32> -// CHECK-PROP: %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32> -// CHECK-PROP: %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32> -// CHECK-PROP: %[[REDUCE0:.*]] = vector.reduction , %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32 -// CHECK-PROP: %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32> -// CHECK-PROP: %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32> -// CHECK-PROP: %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32> -// CHECK-PROP: %[[REDUCE1:.*]] = vector.reduction , %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32 -// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32> -// CHECK-PROP: return %[[R]] : vector<2xf32> +// CHECK-PROP : %[[W:.*]]:2 = 
gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) { +// CHECK-PROP : %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32> +// CHECK-PROP : %[[ACC:.*]] = "some_def"() : () -> vector<64xf32> +// CHECK-PROP : gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32> +// CHECK-PROP : } +// CHECK-PROP : %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0 +// CHECK-PROP-SAME : {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32> +// CHECK-PROP : %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32> +// CHECK-PROP : %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32> +// CHECK-PROP : %[[REDUCE0:.*]] = vector.reduction , %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32 +// CHECK-PROP : %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0 +// CHECK-PROP-SAME : {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32> +// CHECK-PROP : %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32> +// CHECK-PROP : %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32> +// CHECK-PROP : %[[REDUCE1:.*]] = vector.reduction , %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32 +// CHECK-PROP : %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32> +// CHECK-PROP : return %[[R]] : vector<2xf32> func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> { %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { %0 = "some_def"() : () -> (vector<32x64xf32>) From 4ed74d89b7808c17bed035a906acf15d2a96c51f Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Thu, 28 Aug 2025 22:21:49 +0000 Subject: [PATCH 09/26] save work --- .../Vector/Transforms/VectorDistribute.cpp | 58 ++++++++++++++++--- 1 file changed, 51 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 2cd743d1ee8e8..4a6ea07c1c236 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -1031,7 +1031,7 @@ struct WarpOpShapeCast : public WarpDistributionPattern { return failure(); auto oldCastOp = operand->get().getDefiningOp(); - unsigned int operandNumber = operand->getOperandNumber(); + unsigned operandNumber = operand->getOperandNumber(); VectorType sourceType = oldCastOp.getSourceVectorType(); VectorType distributedResultType = cast(warpOp->getResultTypes()[operandNumber]); @@ -2069,12 +2069,56 @@ struct WarpOpReduction : public WarpDistributionPattern { DistributedReductionFn distributedReductionFn; }; -// This patterns distribute the `vector.multi_reduction` operation across -// lanes in a warp. Currently only 2D to 1D reductions are supported and assumes -// that source vector is distributed in column dimension (i.e. Each lane owns -// complete column(s) of the source vector). -// TODO: Add support for the case where source rows are distributed across -// lanes. Requires `DistributionMapFn` to express the data distribution. +/// This patterns distribute the `vector.multi_reduction` operation across +/// lanes in a warp. Currently only 2D to 1D reductions are supported and +/// assumes that source vector is distributed in column dimension (i.e. Each +/// lane owns complete column(s) of the source vector). +/// TODO: Add support for the case where source rows are distributed across +/// lanes. 
Requires `DistributionMapFn` to express the data distribution. +/// Example 1 (Col reduction): +/// ``` +/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) { +/// %0 = "some_def"() : () -> (vector<16x32xf32>) +/// %acc = "some_def"() : () -> (vector<32xf32>) +/// %1 = vector.multi_reduction , %0, %acc [0] : vector<16x32xf32> to +/// vector<32xf32> gpu.yield %1 : vector<32xf32> +/// } +/// ``` +/// is lowered to: +/// ``` +/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>, +/// vector<1xf32>) { +/// %0 = "some_def"() : () -> (vector<16x32xf32>) +/// %acc = "some_def"() : () -> (vector<32xf32>) +/// gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32> +/// } +/// %c = arith.constant dense<0.0> : vector<1xf32> +/// %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32> +/// %2 = vector.reduction , %1, %r#1 : vector<16xf32> to f32 +/// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32> +/// ``` +/// Example 2 (Row reduction): +/// ``` +/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { +/// %0 = "some_def"() : () -> (vector<2x32xf32>) +/// %acc = "some_def"() : () -> (vector<2xf32>) +/// %1 = vector.multi_reduction , %0, %acc [1] : vector<2x32xf32> to +/// vector<2xf32> +/// gpu.yield %1 : vector<2xf32> +/// } +/// ``` +/// is lowered to: +/// ``` +/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { +/// %0 = "some_def"() : () -> (vector<2x32xf32>) +/// %acc = "some_def"() : () -> (vector<2xf32>) +/// %1 = arith.constant dense<0.0> : vector<2xf32> +/// %2 = vector.extract %0[0] : vector<32xf32> from > +/// %3 = ("warp.reduction %2") : f32 +/// %4 = vector.insert %3, %1[0] : f32 into vector<2xf32> +/// ... repeat for row 1 +/// gpu.yield %1 : vector<2xf32> +/// } struct WarpOpMultiReduction : public WarpDistributionPattern { using Base::Base; LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, From 116e4bceb48ebe85e35fd61dd9e52867897b0a39 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 2 Sep 2025 20:49:19 +0000 Subject: [PATCH 10/26] save work --- mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 4a6ea07c1c236..dddfcaf4f273d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/SetVector.h" @@ -1039,7 +1040,7 @@ struct WarpOpShapeCast : public WarpDistributionPattern { bool isResultDistributed = distributedResultType.getNumElements() < oldCastOp.getResultVectorType().getNumElements(); - // If the result is not distributed, source distribted type is the same + // If the result is not distributed, source distributed type is the same // as the source type. If the result is distributed, we need to compute the // distributed source type according to following rules: // 1. If the source type is yielded from the warp op, we can use the @@ -1051,7 +1052,7 @@ struct WarpOpShapeCast : public WarpDistributionPattern { // Check if the source is yielded from the warp op. 
gpu::YieldOp yieldOp = cast( warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - auto *it = + OpOperand *it = llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) { return operand.get() == oldCastOp.getSource(); }); @@ -2155,7 +2156,9 @@ struct WarpOpMultiReduction : public WarpDistributionPattern { // case each lane owns its portion of the result (i.e. result is also // distributed). // 3. If reduction dim == 1, its a row reduction that require cross lanes - // shuffles. In this case result is not distributed and broadcasted instead. + // shuffles. In this case, the reduction result is not distributed across + // lanes. Instead each lane owns a complete copy of the result + // (broadcasted). // TODO: These assumptions are fairly restrictive. For example, source // vector can have row distributed layout. Improve support for such cases. if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0) From a78aec590ed062d9770cd747d7bc61114fc6f23e Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 3 Sep 2025 23:07:45 +0000 Subject: [PATCH 11/26] move work --- .../Vector/Transforms/VectorDistribute.cpp | 174 ----------------- .../Transforms/XeGPUSubgroupDistribute.cpp | 183 +++++++++++++++++- .../Vector/vector-warp-distribute.mlir | 113 ----------- .../Dialect/XeGPU/subgroup-distribute.mlir | 79 ++++++++ 4 files changed, 258 insertions(+), 291 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index dddfcaf4f273d..aacbb4c23af3e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -2070,180 +2070,6 @@ struct WarpOpReduction : public WarpDistributionPattern { DistributedReductionFn distributedReductionFn; }; -/// This patterns distribute the `vector.multi_reduction` operation across -/// lanes in a warp. Currently only 2D to 1D reductions are supported and -/// assumes that source vector is distributed in column dimension (i.e. Each -/// lane owns complete column(s) of the source vector). -/// TODO: Add support for the case where source rows are distributed across -/// lanes. Requires `DistributionMapFn` to express the data distribution. 
-/// Example 1 (Col reduction): -/// ``` -/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) { -/// %0 = "some_def"() : () -> (vector<16x32xf32>) -/// %acc = "some_def"() : () -> (vector<32xf32>) -/// %1 = vector.multi_reduction , %0, %acc [0] : vector<16x32xf32> to -/// vector<32xf32> gpu.yield %1 : vector<32xf32> -/// } -/// ``` -/// is lowered to: -/// ``` -/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>, -/// vector<1xf32>) { -/// %0 = "some_def"() : () -> (vector<16x32xf32>) -/// %acc = "some_def"() : () -> (vector<32xf32>) -/// gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32> -/// } -/// %c = arith.constant dense<0.0> : vector<1xf32> -/// %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32> -/// %2 = vector.reduction , %1, %r#1 : vector<16xf32> to f32 -/// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32> -/// ``` -/// Example 2 (Row reduction): -/// ``` -/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { -/// %0 = "some_def"() : () -> (vector<2x32xf32>) -/// %acc = "some_def"() : () -> (vector<2xf32>) -/// %1 = vector.multi_reduction , %0, %acc [1] : vector<2x32xf32> to -/// vector<2xf32> -/// gpu.yield %1 : vector<2xf32> -/// } -/// ``` -/// is lowered to: -/// ``` -/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { -/// %0 = "some_def"() : () -> (vector<2x32xf32>) -/// %acc = "some_def"() : () -> (vector<2xf32>) -/// %1 = arith.constant dense<0.0> : vector<2xf32> -/// %2 = vector.extract %0[0] : vector<32xf32> from > -/// %3 = ("warp.reduction %2") : f32 -/// %4 = vector.insert %3, %1[0] : f32 into vector<2xf32> -/// ... repeat for row 1 -/// gpu.yield %1 : vector<2xf32> -/// } -struct WarpOpMultiReduction : public WarpDistributionPattern { - using Base::Base; - LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, - PatternRewriter &rewriter) const override { - OpOperand *yieldOperand = - getWarpResult(warpOp, llvm::IsaPred); - if (!yieldOperand) - return failure(); - auto reductionOp = - cast(yieldOperand->get().getDefiningOp()); - unsigned operandNumber = yieldOperand->getOperandNumber(); - VectorType sourceType = reductionOp.getSourceVectorType(); - - // Only 2D vectors are supported. - if (sourceType.getRank() != 2) - return rewriter.notifyMatchFailure(warpOp, - "Only 2D reductions are supported."); - ArrayRef reductionDims = reductionOp.getReductionDims(); - // Only 1 reduction dimension supported. This also ensures that result is - // also vector type. - if (reductionDims.size() != 1) - return rewriter.notifyMatchFailure( - warpOp, "Only 1 reduction dimension is supported."); - int64_t reductionDim = reductionDims[0]; - auto resultType = cast(reductionOp.getType()); - auto distributedResultType = - cast(warpOp.getResult(operandNumber).getType()); - Type elementType = distributedResultType.getElementType(); - - // Currently we make the following assumptions. - // 1. The source vector is distributed in the column dimension. Each lane - // owns complete column(s) of the source vector. - // 2. If the reduction dim == 0, its a lane-local col reduction. In this - // case each lane owns its portion of the result (i.e. result is also - // distributed). - // 3. If reduction dim == 1, its a row reduction that require cross lanes - // shuffles. In this case, the reduction result is not distributed across - // lanes. Instead each lane owns a complete copy of the result - // (broadcasted). - // TODO: These assumptions are fairly restrictive. 
For example, source - // vector can have row distributed layout. Improve support for such cases. - if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0) - return rewriter.notifyMatchFailure( - warpOp, "Source vector dimension must be divisible by warp size."); - bool isResultDistributed = - distributedResultType.getNumElements() < resultType.getNumElements(); - if (reductionDim == 0 && !isResultDistributed) - return rewriter.notifyMatchFailure( - warpOp, - "Expecting result vector to be distributed in a col reduction."); - if (reductionDim == 1 && isResultDistributed) - return rewriter.notifyMatchFailure( - warpOp, - "Expecting result vector to be broadcasted in a row reduction."); - - // Create a constant vector to store the result of the reduction per lane. - TypedAttr zeroAttr = - rewriter.getZeroAttr(distributedResultType.getElementType()); - Value result = arith::ConstantOp::create( - rewriter, reductionOp->getLoc(), distributedResultType, - DenseElementsAttr::get(distributedResultType, zeroAttr)); - // Col reduction. - if (reductionDim == 0) { - // Compute source distributed type assuming each lane owns cols. - SmallVector shape(sourceType.getShape()); - shape[1] = shape[1] / warpOp.getWarpSize(); - auto sourceDistributedType = VectorType::get(shape, elementType); - - // Yield the source and acc vectors from the WarpOp. - SmallVector newRetIndices; - auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()}, - {sourceDistributedType, distributedResultType}, newRetIndices); - rewriter.setInsertionPointAfter(newWarpOp); - - int nCols = sourceDistributedType.getShape()[1]; - Value source = newWarpOp.getResult(newRetIndices[0]); - Value acc = newWarpOp.getResult(newRetIndices[1]); - // For each column owned by a lane, extract the column (of size nRows x - // 1), shape cast to 1D (nRows), do a vector.reduction and, insert the - // result back to the result vector. - for (int i = 0; i < nCols; ++i) { - Value col = vector::ExtractStridedSliceOp::create( - rewriter, reductionOp.getLoc(), source, {0, i}, - {sourceDistributedType.getShape()[0], 1}, {1, 1}); - col = vector::ShapeCastOp::create( - rewriter, reductionOp.getLoc(), - VectorType::get({sourceDistributedType.getShape()[0]}, elementType), - col); - Value accCol = - vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i); - Value colReduce = vector::ReductionOp::create( - rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol); - result = vector::InsertOp::create(rewriter, reductionOp.getLoc(), - colReduce, result, i); - } - // Replace the warp op result with the new reduction op. - rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result); - return success(); - } - // For row reductions, we simply rewrite the MultiReductionOp in terms of - // multiple ReductionOps. Actual distribution is done by the WarpOpReduction - // pattern. - rewriter.setInsertionPointAfter(reductionOp); - int nRows = sourceType.getShape()[0]; - // For each row of the source, extract the row vector, do a reduction and, - // insert the result back to the result. 
- for (int i = 0; i < nRows; ++i) { - Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(), - reductionOp.getSource(), i); - Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(), - reductionOp.getAcc(), i); - Value rowReduce = vector::ReductionOp::create( - rewriter, reductionOp.getLoc(), reductionOp.getKind(), source, acc); - result = vector::InsertOp::create(rewriter, reductionOp.getLoc(), - rowReduce, result, i); - } - // Replace the warp op result with the final result. - rewriter.replaceAllUsesWith(reductionOp.getResult(), result); - - return success(); - } -}; - } // namespace void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern( diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index dddb5eaece2cb..050fa0cd1d342 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -807,6 +807,180 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern { } }; +/// This patterns distribute the `vector.multi_reduction` operation across +/// lanes in a warp. Currently only 2D to 1D reductions are supported and +/// assumes that source vector is distributed in column dimension (i.e. Each +/// lane owns complete column(s) of the source vector). +/// TODO: Add support for the case where source rows are distributed across +/// lanes. Requires `DistributionMapFn` to express the data distribution. +/// Example 1 (Col reduction): +/// ``` +/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) { +/// %0 = "some_def"() : () -> (vector<16x32xf32>) +/// %acc = "some_def"() : () -> (vector<32xf32>) +/// %1 = vector.multi_reduction , %0, %acc [0] : vector<16x32xf32> to +/// vector<32xf32> gpu.yield %1 : vector<32xf32> +/// } +/// ``` +/// is lowered to: +/// ``` +/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>, +/// vector<1xf32>) { +/// %0 = "some_def"() : () -> (vector<16x32xf32>) +/// %acc = "some_def"() : () -> (vector<32xf32>) +/// gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32> +/// } +/// %c = arith.constant dense<0.0> : vector<1xf32> +/// %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32> +/// %2 = vector.reduction , %1, %r#1 : vector<16xf32> to f32 +/// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32> +/// ``` +/// Example 2 (Row reduction): +/// ``` +/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { +/// %0 = "some_def"() : () -> (vector<2x32xf32>) +/// %acc = "some_def"() : () -> (vector<2xf32>) +/// %1 = vector.multi_reduction , %0, %acc [1] : vector<2x32xf32> to +/// vector<2xf32> +/// gpu.yield %1 : vector<2xf32> +/// } +/// ``` +/// is lowered to: +/// ``` +/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { +/// %0 = "some_def"() : () -> (vector<2x32xf32>) +/// %acc = "some_def"() : () -> (vector<2xf32>) +/// %1 = arith.constant dense<0.0> : vector<2xf32> +/// %2 = vector.extract %0[0] : vector<32xf32> from > +/// %3 = ("warp.reduction %2") : f32 +/// %4 = vector.insert %3, %1[0] : f32 into vector<2xf32> +/// ... 
repeat for row 1 +/// gpu.yield %1 : vector<2xf32> +/// } +struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern { + using Base::Base; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *yieldOperand = + getWarpResult(warpOp, llvm::IsaPred); + if (!yieldOperand) + return failure(); + auto reductionOp = + cast(yieldOperand->get().getDefiningOp()); + unsigned operandNumber = yieldOperand->getOperandNumber(); + VectorType sourceType = reductionOp.getSourceVectorType(); + + // Only 2D vectors are supported. + if (sourceType.getRank() != 2) + return rewriter.notifyMatchFailure(warpOp, + "Only 2D reductions are supported."); + ArrayRef reductionDims = reductionOp.getReductionDims(); + // Only 1 reduction dimension supported. This also ensures that result is + // also vector type. + if (reductionDims.size() != 1) + return rewriter.notifyMatchFailure( + warpOp, "Only 1 reduction dimension is supported."); + int64_t reductionDim = reductionDims[0]; + auto resultType = cast(reductionOp.getType()); + auto distributedResultType = + cast(warpOp.getResult(operandNumber).getType()); + Type elementType = distributedResultType.getElementType(); + + // Currently we make the following assumptions. + // 1. The source vector is distributed in the column dimension. Each lane + // owns complete column(s) of the source vector. + // 2. If the reduction dim == 0, its a lane-local col reduction. In this + // case each lane owns its portion of the result (i.e. result is also + // distributed). + // 3. If reduction dim == 1, its a row reduction that require cross lanes + // shuffles. In this case, the reduction result is not distributed across + // lanes. Instead each lane owns a complete copy of the result + // (broadcasted). + // TODO: These assumptions are fairly restrictive. For example, source + // vector can have row distributed layout. Improve support for such cases. + if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0) + return rewriter.notifyMatchFailure( + warpOp, "Source vector dimension must be divisible by warp size."); + bool isResultDistributed = + distributedResultType.getNumElements() < resultType.getNumElements(); + if (reductionDim == 0 && !isResultDistributed) + return rewriter.notifyMatchFailure( + warpOp, + "Expecting result vector to be distributed in a col reduction."); + if (reductionDim == 1 && isResultDistributed) + return rewriter.notifyMatchFailure( + warpOp, + "Expecting result vector to be broadcasted in a row reduction."); + + // Create a constant vector to store the result of the reduction per lane. + TypedAttr zeroAttr = + rewriter.getZeroAttr(distributedResultType.getElementType()); + Value result = arith::ConstantOp::create( + rewriter, reductionOp->getLoc(), distributedResultType, + DenseElementsAttr::get(distributedResultType, zeroAttr)); + // Col reduction. + if (reductionDim == 0) { + // Compute source distributed type assuming each lane owns cols. + SmallVector shape(sourceType.getShape()); + shape[1] = shape[1] / warpOp.getWarpSize(); + auto sourceDistributedType = VectorType::get(shape, elementType); + + // Yield the source and acc vectors from the WarpOp. 
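+      // Illustratively, for the 32-lane case exercised in the tests (source
+      // vector<32x64xf32>, acc vector<64xf32>), the rebuilt warp op ends with
+      //   gpu.yield %source, %acc : vector<32x64xf32>, vector<64xf32>
+      // and gains two new results of type vector<32x2xf32> and vector<2xf32>.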
+ SmallVector newRetIndices; + auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()}, + {sourceDistributedType, distributedResultType}, newRetIndices); + rewriter.setInsertionPointAfter(newWarpOp); + + int nCols = sourceDistributedType.getShape()[1]; + Value source = newWarpOp.getResult(newRetIndices[0]); + Value acc = newWarpOp.getResult(newRetIndices[1]); + // For each column owned by a lane, extract the column (of size nRows x + // 1), shape cast to 1D (nRows), do a vector.reduction and, insert the + // result back to the result vector. + for (int i = 0; i < nCols; ++i) { + Value col = vector::ExtractStridedSliceOp::create( + rewriter, reductionOp.getLoc(), source, {0, i}, + {sourceDistributedType.getShape()[0], 1}, {1, 1}); + col = vector::ShapeCastOp::create( + rewriter, reductionOp.getLoc(), + VectorType::get({sourceDistributedType.getShape()[0]}, elementType), + col); + Value accCol = + vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i); + Value colReduce = vector::ReductionOp::create( + rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol); + result = vector::InsertOp::create(rewriter, reductionOp.getLoc(), + colReduce, result, i); + } + // Replace the warp op result with the new reduction op. + rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result); + return success(); + } + // For row reductions, we simply rewrite the MultiReductionOp in terms of + // multiple ReductionOps. Actual distribution is done by the WarpOpReduction + // pattern. + rewriter.setInsertionPointAfter(reductionOp); + int nRows = sourceType.getShape()[0]; + // For each row of the source, extract the row vector, do a reduction and, + // insert the result back to the result. + for (int i = 0; i < nRows; ++i) { + Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(), + reductionOp.getSource(), i); + Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(), + reductionOp.getAcc(), i); + Value rowReduce = vector::ReductionOp::create( + rewriter, reductionOp.getLoc(), reductionOp.getKind(), source, acc); + result = vector::InsertOp::create(rewriter, reductionOp.getLoc(), + rowReduce, result, i); + } + // Replace the warp op result with the final result. 
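+    // The original multi_reduction then has no remaining uses and is expected
+    // to be erased as dead code by the rewrite driver; the per-row
+    // vector.reduction ops created above stay inside the warp op and are
+    // distributed later by the WarpOpReduction pattern.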
+ rewriter.replaceAllUsesWith(reductionOp.getResult(), result); + + return success(); + } +}; + } // namespace namespace { @@ -819,10 +993,11 @@ struct XeGPUSubgroupDistributePass final void xegpu::populateXeGPUSubgroupDistributePatterns( RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); + patterns + .add( + patterns.getContext()); } void XeGPUSubgroupDistributePass::runOnOperation() { diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index 755222494d223..8750582ef1e1f 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -850,85 +850,6 @@ func.func @vector_reduction_acc(%laneid: index) -> (f32) { return %r : f32 } -// ----- -// CHECK-PROP-LABEL: func.func @vector_multi_reduction_col_reduce -// CHECK-PROP : %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) { -// CHECK-PROP : %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32> -// CHECK-PROP : %[[ACC:.*]] = "some_def"() : () -> vector<64xf32> -// CHECK-PROP : gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32> -// CHECK-PROP : } -// CHECK-PROP : %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0 -// CHECK-PROP-SAME : {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32> -// CHECK-PROP : %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32> -// CHECK-PROP : %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32> -// CHECK-PROP : %[[REDUCE0:.*]] = vector.reduction , %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32 -// CHECK-PROP : %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0 -// CHECK-PROP-SAME : {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32> -// CHECK-PROP : %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32> -// CHECK-PROP : %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32> -// CHECK-PROP : %[[REDUCE1:.*]] = vector.reduction , %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32 -// CHECK-PROP : %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32> -// CHECK-PROP : return %[[R]] : vector<2xf32> -func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> { - %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { - %0 = "some_def"() : () -> (vector<32x64xf32>) - %acc = "some_def"() : () -> (vector<64xf32>) - %1 = vector.multi_reduction , %0, %acc [0] : vector<32x64xf32> to vector<64xf32> - gpu.yield %1 : vector<64xf32> - } - return %r : vector<2xf32> -} - -// ----- -// CHECK-PROP-LABEL: func.func @vector_multi_reduction_row_reduce -// CHECK-PROP-DAG: %[[C16:.*]] = arith.constant 16 : i32 -// CHECK-PROP-DAG: %[[C8:.*]] = arith.constant 8 : i32 -// CHECK-PROP-DAG: %[[C4:.*]] = arith.constant 4 : i32 -// CHECK-PROP-DAG: %[[C2:.*]] = arith.constant 2 : i32 -// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : i32 -// CHECK-PROP-DAG: %[[C32:.*]] = arith.constant 32 : i32 -// CHECK-PROP-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) { -// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32> -// CHECK-PROP: gpu.yield %[[SRC]] : vector<2x32xf32> -// CHECK-PROP: } -// CHECK-PROP: %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 from vector<2x1xf32> -// CHECK-PROP: %[[SR:.*]], %{{.*}} = 
gpu.shuffle xor %[[T1]], %[[C1]], %[[C32]] : f32 -// CHECK-PROP: %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32 -// CHECK-PROP: %[[SR0:.*]], %{{.*}} = gpu.shuffle xor %[[T2]], %[[C2]], %[[C32]] : f32 -// CHECK-PROP: %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32 -// CHECK-PROP: %[[SR2:.*]], %{{.*}} = gpu.shuffle xor %[[T3]], %[[C4]], %[[C32]] : f32 -// CHECK-PROP: %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32 -// CHECK-PROP: %[[SR4:.*]], %{{.*}} = gpu.shuffle xor %[[T4]], %[[C8]], %[[C32]] : f32 -// CHECK-PROP: %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32 -// CHECK-PROP: %[[SR6:.*]], %{{.*}} = gpu.shuffle xor %[[T5]], %[[C16]], %[[C32]] : f32 -// CHECK-PROP: %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32 -// CHECK-PROP: %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32 -// -// CHECK-PROP: %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32> -// CHECK-PROP: %[[SR8:.*]], %{{.*}} = gpu.shuffle xor %[[T8]], %[[C1]], %[[C32]] : f32 -// CHECK-PROP: %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32 -// CHECK-PROP: %[[SR10:.*]], %{{.*}} = gpu.shuffle xor %[[T9]], %[[C2]], %[[C32]] : f32 -// CHECK-PROP: %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32 -// CHECK-PROP: %[[SR12:.*]], %{{.*}} = gpu.shuffle xor %[[T10]], %[[C4]], %[[C32]] : f32 -// CHECK-PROP: %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32 -// CHECK-PROP: %[[SR14:.*]], %{{.*}} = gpu.shuffle xor %[[T11]], %[[C8]], %[[C32]] : f32 -// CHECK-PROP: %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32 -// CHECK-PROP: %[[SR16:.*]], %{{.*}} = gpu.shuffle xor %[[T12]], %[[C16]], %[[C32]] : f32 -// CHECK-PROP: %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32 -// CHECK-PROP: %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32 -// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32> -// CHECK-PROP: return %[[R]] : vector<2xf32> -func.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> { - %zero = arith.constant dense<0.0> : vector<2xf32> - %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { - %0 = "some_def"() : () -> (vector<2x32xf32>) - %1 = vector.multi_reduction , %0, %zero [1] : vector<2x32xf32> to vector<2xf32> - gpu.yield %1 : vector<2xf32> - } - return %r : vector<2xf32> -} - // ----- // CHECK-PROP-LABEL: func @warp_duplicate_yield( @@ -1646,40 +1567,6 @@ func.func @warp_propagate_shape_cast(%laneid: index, %src: memref<32x4x32xf32>) // CHECK-PROP: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<1x1x4xf32> to vector<4xf32> // CHECK-PROP: return %[[CAST]] : vector<4xf32> -// ----- -func.func @warp_propagate_shape_cast_2d_to_2d(%laneid: index, %src: memref<64x32xf32>) -> vector<32x2xf32> { - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : f32 - %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<32x2xf32>) { - %2 = vector.transfer_read %src[%c0, %c0], %cst : memref<64x32xf32>, vector<64x32xf32> - %3 = vector.shape_cast %2 : vector<64x32xf32> to vector<32x64xf32> - gpu.yield %3 : vector<32x64xf32> - } - return %r : vector<32x2xf32> -} - -// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_2d_to_2d -// CHECK-PROP: %[[READ:.*]] = vector.transfer_read {{.*}} {in_bounds = [false, true]} : memref<64x32xf32>, vector<2x32xf32> -// CHECK-PROP: %[[CAST:.*]] = vector.shape_cast %[[READ]] : vector<2x32xf32> to vector<32x2xf32> -// CHECK-PROP: return %[[CAST]] : vector<32x2xf32> - -// ----- -func.func @warp_propagate_shape_cast_non_distributed_result(%laneid: index, %src: memref<64xf32>) -> vector<8x4x2xf32> { - %c0 = arith.constant 0 : 
index - %cst = arith.constant 0.000000e+00 : f32 - %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x4x2xf32>) { - %2 = vector.transfer_read %src[%c0], %cst : memref<64xf32>, vector<64xf32> - %3 = vector.shape_cast %2 : vector<64xf32> to vector<8x4x2xf32> - gpu.yield %3 : vector<8x4x2xf32> - } - return %r : vector<8x4x2xf32> -} - -// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_non_distributed_result -// CHECK-PROP: %[[READ:.*]] = vector.transfer_read {{.*}} {in_bounds = [true]} : memref<64xf32>, vector<64xf32> -// CHECK-PROP: %[[CAST:.*]] = vector.shape_cast %[[READ]] : vector<64xf32> to vector<8x4x2xf32> -// CHECK-PROP: return %[[CAST]] : vector<8x4x2xf32> - // ----- func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<4096xf32>, %index: index) -> vector<1xf32> { diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir index 54ef56e013abb..cfb5428c92400 100644 --- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir +++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir @@ -319,3 +319,82 @@ gpu.module @test { gpu.return } } + +// ----- +// CHECK-PROP-LABEL: func.func @vector_multi_reduction_col_reduce +// CHECK-PROP : %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) { +// CHECK-PROP : %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32> +// CHECK-PROP : %[[ACC:.*]] = "some_def"() : () -> vector<64xf32> +// CHECK-PROP : gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32> +// CHECK-PROP : } +// CHECK-PROP : %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0 +// CHECK-PROP-SAME : {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32> +// CHECK-PROP : %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32> +// CHECK-PROP : %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32> +// CHECK-PROP : %[[REDUCE0:.*]] = vector.reduction , %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32 +// CHECK-PROP : %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0 +// CHECK-PROP-SAME : {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32> +// CHECK-PROP : %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32> +// CHECK-PROP : %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32> +// CHECK-PROP : %[[REDUCE1:.*]] = vector.reduction , %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32 +// CHECK-PROP : %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32> +// CHECK-PROP : return %[[R]] : vector<2xf32> +func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> { + %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { + %0 = "some_def"() : () -> (vector<32x64xf32>) + %acc = "some_def"() : () -> (vector<64xf32>) + %1 = vector.multi_reduction , %0, %acc [0] : vector<32x64xf32> to vector<64xf32> + gpu.yield %1 : vector<64xf32> + } + return %r : vector<2xf32> +} + +// ----- +// CHECK-PROP-LABEL: func.func @vector_multi_reduction_row_reduce +// CHECK-PROP-DAG: %[[C16:.*]] = arith.constant 16 : i32 +// CHECK-PROP-DAG: %[[C8:.*]] = arith.constant 8 : i32 +// CHECK-PROP-DAG: %[[C4:.*]] = arith.constant 4 : i32 +// CHECK-PROP-DAG: %[[C2:.*]] = arith.constant 2 : i32 +// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : i32 +// CHECK-PROP-DAG: %[[C32:.*]] = arith.constant 32 : i32 +// CHECK-PROP-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// 
CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) { +// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32> +// CHECK-PROP: gpu.yield %[[SRC]] : vector<2x32xf32> +// CHECK-PROP: } +// CHECK-PROP: %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 from vector<2x1xf32> +// CHECK-PROP: %[[SR:.*]], %{{.*}} = gpu.shuffle xor %[[T1]], %[[C1]], %[[C32]] : f32 +// CHECK-PROP: %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32 +// CHECK-PROP: %[[SR0:.*]], %{{.*}} = gpu.shuffle xor %[[T2]], %[[C2]], %[[C32]] : f32 +// CHECK-PROP: %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32 +// CHECK-PROP: %[[SR2:.*]], %{{.*}} = gpu.shuffle xor %[[T3]], %[[C4]], %[[C32]] : f32 +// CHECK-PROP: %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32 +// CHECK-PROP: %[[SR4:.*]], %{{.*}} = gpu.shuffle xor %[[T4]], %[[C8]], %[[C32]] : f32 +// CHECK-PROP: %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32 +// CHECK-PROP: %[[SR6:.*]], %{{.*}} = gpu.shuffle xor %[[T5]], %[[C16]], %[[C32]] : f32 +// CHECK-PROP: %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32 +// CHECK-PROP: %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32 +// +// CHECK-PROP: %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32> +// CHECK-PROP: %[[SR8:.*]], %{{.*}} = gpu.shuffle xor %[[T8]], %[[C1]], %[[C32]] : f32 +// CHECK-PROP: %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32 +// CHECK-PROP: %[[SR10:.*]], %{{.*}} = gpu.shuffle xor %[[T9]], %[[C2]], %[[C32]] : f32 +// CHECK-PROP: %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32 +// CHECK-PROP: %[[SR12:.*]], %{{.*}} = gpu.shuffle xor %[[T10]], %[[C4]], %[[C32]] : f32 +// CHECK-PROP: %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32 +// CHECK-PROP: %[[SR14:.*]], %{{.*}} = gpu.shuffle xor %[[T11]], %[[C8]], %[[C32]] : f32 +// CHECK-PROP: %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32 +// CHECK-PROP: %[[SR16:.*]], %{{.*}} = gpu.shuffle xor %[[T12]], %[[C16]], %[[C32]] : f32 +// CHECK-PROP: %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32 +// CHECK-PROP: %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32 +// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32> +// CHECK-PROP: return %[[R]] : vector<2xf32> +func.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> { + %zero = arith.constant dense<0.0> : vector<2xf32> + %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { + %0 = "some_def"() : () -> (vector<2x32xf32>) + %1 = vector.multi_reduction , %0, %zero [1] : vector<2x32xf32> to vector<2xf32> + gpu.yield %1 : vector<2xf32> + } + return %r : vector<2xf32> +} From 2ba43fc7827a64550a038453271ab972f2e94b01 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 3 Sep 2025 23:08:26 +0000 Subject: [PATCH 12/26] save work --- .../Vector/Transforms/VectorDistribute.cpp | 80 ++++++------------- 1 file changed, 24 insertions(+), 56 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index aacbb4c23af3e..c84eb2c9f8857 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -17,7 +17,6 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Value.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/SetVector.h" @@ -1021,75 +1020,44 @@ struct WarpOpBroadcast : public WarpDistributionPattern { /// Pattern to move shape cast out of 
the warp op. shape cast is basically a /// no-op for warp distribution; we need to handle the shape though. struct WarpOpShapeCast : public WarpDistributionPattern { - - WarpOpShapeCast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1) - : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {} + using Base::Base; LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred); if (!operand) return failure(); + auto oldCastOp = operand->get().getDefiningOp(); - unsigned operandNumber = operand->getOperandNumber(); - VectorType sourceType = oldCastOp.getSourceVectorType(); - VectorType distributedResultType = + unsigned int operandNumber = operand->getOperandNumber(); + auto castDistributedType = cast(warpOp->getResultTypes()[operandNumber]); - VectorType distributedSourceType = sourceType; - bool isResultDistributed = distributedResultType.getNumElements() < - oldCastOp.getResultVectorType().getNumElements(); - - // If the result is not distributed, source distributed type is the same - // as the source type. If the result is distributed, we need to compute the - // distributed source type according to following rules: - // 1. If the source type is yielded from the warp op, we can use the - // matching warp result type as the distributed source type. - // 2. If the source type is not yielded from the warp op, we need - // to compute the distributed source type based on the distribution map - // and the warp size. - if (isResultDistributed) { - // Check if the source is yielded from the warp op. - gpu::YieldOp yieldOp = cast( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - OpOperand *it = - llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) { - return operand.get() == oldCastOp.getSource(); - }); - - if (it != yieldOp->getOpOperands().end()) { - // If the source is yielded from the warp op, we can use the matching - // warp result type as the distributed source type. - distributedSourceType = - cast(warpOp->getResultTypes()[it->getOperandNumber()]); - } else { - // If the source is not yielded from the warp op, we need to compute - // the distributed source type based on the distribution map and the - // warp size. - AffineMap map = distributionMapFn(oldCastOp.getSource()); - distributedSourceType = - getDistributedType(sourceType, map, warpOp.getWarpSize()); - if (!distributedSourceType) - return rewriter.notifyMatchFailure( - oldCastOp, - "cannot compute distributed source type for shape cast"); - } + VectorType castOriginalType = oldCastOp.getSourceVectorType(); + VectorType castResultType = castDistributedType; + + // We expect the distributed type to have a smaller rank than the original + // type. Prepend with size-one dimensions to make them the same. 
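+    // For example (illustrative shapes only): a vector<32x4xf32> ->
+    // vector<128xf32> shape cast whose result is distributed to vector<4xf32>
+    // across 32 lanes yields its source from the warp op as vector<1x4xf32>.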
+ unsigned castDistributedRank = castDistributedType.getRank(); + unsigned castOriginalRank = castOriginalType.getRank(); + if (castDistributedRank < castOriginalRank) { + SmallVector shape(castOriginalRank - castDistributedRank, 1); + llvm::append_range(shape, castDistributedType.getShape()); + castDistributedType = + VectorType::get(shape, castDistributedType.getElementType()); } SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, {oldCastOp.getSource()}, {distributedSourceType}, + rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); Value newCast = vector::ShapeCastOp::create( - rewriter, oldCastOp.getLoc(), distributedResultType, + rewriter, oldCastOp.getLoc(), castResultType, newWarpOp->getResult(newRetIndices[0])); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast); return success(); } - -private: - DistributionMapFn distributionMapFn; }; /// Sink out vector.create_mask op feeding into a warp op yield. @@ -2091,15 +2059,15 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns( PatternBenefit readBenefit) { patterns.add(patterns.getContext(), readBenefit); patterns - .add( + .add( patterns.getContext(), benefit); patterns.add(patterns.getContext(), warpShuffleFromIdxFn, benefit); - patterns.add(patterns.getContext(), - distributionMapFn, benefit); + patterns.add(patterns.getContext(), distributionMapFn, + benefit); } void mlir::vector::populateDistributeReduction( From 3dea80c8bbaf229235771f2e9c66d36ef075b185 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Thu, 4 Sep 2025 17:43:03 +0000 Subject: [PATCH 13/26] save test --- mlir/test/Dialect/XeGPU/subgroup-distribute.mlir | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir index cfb5428c92400..8e2e96dfc05cc 100644 --- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir +++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir @@ -339,14 +339,16 @@ gpu.module @test { // CHECK-PROP : %[[REDUCE1:.*]] = vector.reduction , %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32 // CHECK-PROP : %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32> // CHECK-PROP : return %[[R]] : vector<2xf32> -func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> { +gpu.module @test { +gpu.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> { %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { %0 = "some_def"() : () -> (vector<32x64xf32>) %acc = "some_def"() : () -> (vector<64xf32>) %1 = vector.multi_reduction , %0, %acc [0] : vector<32x64xf32> to vector<64xf32> gpu.yield %1 : vector<64xf32> } - return %r : vector<2xf32> + gpu.return %r : vector<2xf32> +} } // ----- @@ -389,12 +391,14 @@ func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> { // CHECK-PROP: %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32 // CHECK-PROP: %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32> // CHECK-PROP: return %[[R]] : vector<2xf32> -func.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> { +gpu.module @test { +gpu.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> { %zero = arith.constant dense<0.0> : vector<2xf32> %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { %0 = "some_def"() : () -> 
(vector<2x32xf32>) %1 = vector.multi_reduction , %0, %zero [1] : vector<2x32xf32> to vector<2xf32> gpu.yield %1 : vector<2xf32> } - return %r : vector<2xf32> + gpu.return %r : vector<2xf32> +} } From 3c06f28573cb36e45f152f81373a4e26f385f5b5 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 5 Sep 2025 22:11:31 +0000 Subject: [PATCH 14/26] save work --- .../Transforms/XeGPUSubgroupDistribute.cpp | 243 +++++++++++++----- .../Dialect/XeGPU/subgroup-distribute.mlir | 83 ------ 2 files changed, 179 insertions(+), 147 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index 050fa0cd1d342..5af45d1324b3b 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -34,6 +34,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/LogicalResult.h" namespace mlir { namespace xegpu { @@ -72,27 +73,43 @@ namespace { /// | 32x16 | [2, 8] | 16x2 | /// | 2x32x16 | [1, 16] | 2x32x1 | static FailureOr -getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout, +getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout, VectorType originalType) { if (!layout) return failure(); + assert((isa(layout) || isa(layout)) && + "Expecting a valid layout."); + SmallVector effectiveLaneLayout; + // If the layout is a slice, we need to get effective lane layout by removing + // sliced dims. + if (auto sliceAttr = dyn_cast(layout)) { + ArrayRef slicedDims = sliceAttr.flatten().getDims().asArrayRef(); + llvm::DenseSet lookUp(slicedDims.begin(), slicedDims.end()); + for (auto [i, dim] : + llvm::enumerate(sliceAttr.getParent().getLaneLayoutAsInt())) { + if (!lookUp.contains(i)) + effectiveLaneLayout.push_back(dim); + } + } else { + effectiveLaneLayout = cast(layout).getLaneLayoutAsInt(); + } - auto laneLayout = layout.getLaneLayout().asArrayRef(); - assert(originalType.getShape().size() >= laneLayout.size() && + assert(originalType.getShape().size() >= effectiveLaneLayout.size() && "Rank of the original vector type should be greater or equal to the " "size of the lane layout to distribute the vector type."); SmallVector distributedShape(originalType.getShape()); // Only distribute the last `laneLayout.size()` dimensions. The remaining // dimensions are not distributed. - unsigned distributionStart = originalType.getRank() - laneLayout.size(); + unsigned distributionStart = + originalType.getRank() - effectiveLaneLayout.size(); for (auto [i, dim] : llvm::enumerate(originalType.getShape())) { if (i < distributionStart) continue; // Check if the dimension can be distributed evenly. 
- if (dim % laneLayout[i - distributionStart] != 0) + if (dim % effectiveLaneLayout[i - distributionStart] != 0) return failure(); - distributedShape[i] = dim / laneLayout[i - distributionStart]; + distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart]; } return VectorType::get(distributedShape, originalType.getElementType()); } @@ -858,7 +875,7 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern { /// gpu.yield %1 : vector<2xf32> /// } struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern { - using Base::Base; + using gpu::WarpDistributionPattern::WarpDistributionPattern; LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { OpOperand *yieldOperand = @@ -869,83 +886,108 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern { cast(yieldOperand->get().getDefiningOp()); unsigned operandNumber = yieldOperand->getOperandNumber(); VectorType sourceType = reductionOp.getSourceVectorType(); - // Only 2D vectors are supported. if (sourceType.getRank() != 2) return rewriter.notifyMatchFailure(warpOp, "Only 2D reductions are supported."); ArrayRef reductionDims = reductionOp.getReductionDims(); - // Only 1 reduction dimension supported. This also ensures that result is - // also vector type. + // Only 1 reduction dimension supported. This also ensures that the result + // is vector type. if (reductionDims.size() != 1) return rewriter.notifyMatchFailure( warpOp, "Only 1 reduction dimension is supported."); int64_t reductionDim = reductionDims[0]; - auto resultType = cast(reductionOp.getType()); - auto distributedResultType = + VectorType distributedResultType = cast(warpOp.getResult(operandNumber).getType()); + VectorType resultType = cast(reductionOp.getType()); Type elementType = distributedResultType.getElementType(); + xegpu::DistributeLayoutAttr sourceLayout = + xegpu::getDistributeLayoutAttr(reductionOp.getSource()); - // Currently we make the following assumptions. - // 1. The source vector is distributed in the column dimension. Each lane - // owns complete column(s) of the source vector. - // 2. If the reduction dim == 0, its a lane-local col reduction. In this - // case each lane owns its portion of the result (i.e. result is also - // distributed). - // 3. If reduction dim == 1, its a row reduction that require cross lanes - // shuffles. In this case, the reduction result is not distributed across - // lanes. Instead each lane owns a complete copy of the result - // (broadcasted). - // TODO: These assumptions are fairly restrictive. For example, source - // vector can have row distributed layout. Improve support for such cases. - if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0) + FailureOr sourceDistTypeOrFailure = + getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType); + if (failed(sourceDistTypeOrFailure)) return rewriter.notifyMatchFailure( - warpOp, "Source vector dimension must be divisible by warp size."); - bool isResultDistributed = + warpOp, "Failed to distribute the source vector type."); + VectorType sourceDistType = sourceDistTypeOrFailure.value(); + // Only single dimension distribution is supported. 
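+    // For example (illustrative): a vector<32x64xf32> source distributed to
+    // vector<32x2xf32> is distributed only along dim 1, while vector<1x64xf32>
+    // would mean dim 0; distribution along both dims is rejected below.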
+ bool dim0Distributed = + sourceDistType.getShape()[0] != sourceType.getShape()[0]; + bool dim1Distributed = + sourceDistType.getShape()[1] != sourceType.getShape()[1]; + if (dim0Distributed && dim1Distributed) + return rewriter.notifyMatchFailure( + warpOp, "Expecting source to be distributed in a single dimension."); + int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1); + if (sourceDistDim == -1) + return rewriter.notifyMatchFailure( + warpOp, "Expecting a distributed source vector."); + bool resultDistributed = distributedResultType.getNumElements() < resultType.getNumElements(); - if (reductionDim == 0 && !isResultDistributed) + // If the lane owns all the data required for reduction (i.e. reduction is + // fully parallel accross lanes), then each lane owns part of the result + // (i.e. result is distributed). If the reduction require cross-lane + // shuffling, then the result is shared among all lanes (broadcasted). + // Therefore we expect following cases: + // + // | Source vector | Reduction dim | Result vector | + // |----------------------|----------------|----------------| + // | dim-0 distributed | 0 | broadcasted | + // | dim-0 distributed | 1 | distributed | + // | dim-1 distributed | 0 | distributed | + // | dim-1 distributed | 1 | broadcasted | + + bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) || + (sourceDistDim == 1 && reductionDim == 0); + if (isReductionLaneLocal && !resultDistributed) return rewriter.notifyMatchFailure( - warpOp, - "Expecting result vector to be distributed in a col reduction."); - if (reductionDim == 1 && isResultDistributed) + warpOp, "Expecting a distributed result for lane-local reduction."); + + if (!isReductionLaneLocal && resultDistributed) return rewriter.notifyMatchFailure( warpOp, - "Expecting result vector to be broadcasted in a row reduction."); + "Expecting a broadcasted result for non-lane-local reduction."); // Create a constant vector to store the result of the reduction per lane. + rewriter.setInsertionPoint(warpOp); TypedAttr zeroAttr = rewriter.getZeroAttr(distributedResultType.getElementType()); Value result = arith::ConstantOp::create( rewriter, reductionOp->getLoc(), distributedResultType, DenseElementsAttr::get(distributedResultType, zeroAttr)); - // Col reduction. - if (reductionDim == 0) { - // Compute source distributed type assuming each lane owns cols. - SmallVector shape(sourceType.getShape()); - shape[1] = shape[1] / warpOp.getWarpSize(); - auto sourceDistributedType = VectorType::get(shape, elementType); + // Handle lane-local reduction case. In this case we fully distribute the + // reduction. + if (isReductionLaneLocal) { // Yield the source and acc vectors from the WarpOp. SmallVector newRetIndices; auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns( rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()}, - {sourceDistributedType, distributedResultType}, newRetIndices); + {sourceDistType, distributedResultType}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); - int nCols = sourceDistributedType.getShape()[1]; + int nSlices = sourceDistType.getShape()[sourceDistDim]; Value source = newWarpOp.getResult(newRetIndices[0]); Value acc = newWarpOp.getResult(newRetIndices[1]); - // For each column owned by a lane, extract the column (of size nRows x - // 1), shape cast to 1D (nRows), do a vector.reduction and, insert the - // result back to the result vector. 
- for (int i = 0; i < nCols; ++i) { + // For each slice owned by a lane, extract the slice, shape cast to 1D, do + // a vector.reduction and, insert the result back to the result vector. + for (int i = 0; i < nSlices; ++i) { + SmallVector sliceOffsets, sliceSizes; + if (sourceDistDim == 0) { + sliceOffsets = {i, 0}; + sliceSizes = {1, sourceDistType.getShape()[1]}; + } else { + sliceOffsets = {0, i}; + sliceSizes = {sourceDistType.getShape()[0], 1}; + } Value col = vector::ExtractStridedSliceOp::create( - rewriter, reductionOp.getLoc(), source, {0, i}, - {sourceDistributedType.getShape()[0], 1}, {1, 1}); + rewriter, reductionOp.getLoc(), source, sliceOffsets, sliceSizes, + {1, 1}); + int64_t col1DSize = + sourceDistType.getShape()[sourceDistDim == 1 ? 0 : 1]; col = vector::ShapeCastOp::create( rewriter, reductionOp.getLoc(), - VectorType::get({sourceDistributedType.getShape()[0]}, elementType), - col); + VectorType::get({col1DSize}, elementType), col); Value accCol = vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i); Value colReduce = vector::ReductionOp::create( @@ -957,26 +999,79 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern { rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result); return success(); } - // For row reductions, we simply rewrite the MultiReductionOp in terms of - // multiple ReductionOps. Actual distribution is done by the WarpOpReduction - // pattern. + // For non-lane-local case, we simply rewrite the MultiReductionOp in terms + // of multiple ReductionOps. Actual distribution is done by the + // WarpOpReduction pattern. rewriter.setInsertionPointAfter(reductionOp); - int nRows = sourceType.getShape()[0]; - // For each row of the source, extract the row vector, do a reduction and, - // insert the result back to the result. - for (int i = 0; i < nRows; ++i) { - Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(), - reductionOp.getSource(), i); - Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(), - reductionOp.getAcc(), i); - Value rowReduce = vector::ReductionOp::create( - rewriter, reductionOp.getLoc(), reductionOp.getKind(), source, acc); + int nSlices = sourceType.getShape()[sourceDistDim == 0 ? 1 : 0]; + // For each slice of the source, extract the slice vector, do a reduction + // and, insert the result back to the result. + for (int i = 0; i < nSlices; ++i) { + SmallVector sliceOffsets, sliceSizes; + if (sourceDistDim == 1) { + sliceOffsets = {i, 0}; + sliceSizes = {1, sourceType.getShape()[1]}; + } else { + sliceOffsets = {0, i}; + sliceSizes = {sourceType.getShape()[0], 1}; + } + Value col = vector::ExtractStridedSliceOp::create( + rewriter, reductionOp.getLoc(), reductionOp.getSource(), sliceOffsets, + sliceSizes, {1, 1}); + int64_t col1DSize = sourceType.getShape()[sourceDistDim]; + col = vector::ShapeCastOp::create( + rewriter, reductionOp.getLoc(), + VectorType::get({col1DSize}, elementType), col); + Value accCol = vector::ExtractOp::create(rewriter, reductionOp.getLoc(), + reductionOp.getAcc(), i); + Value colReduce = vector::ReductionOp::create( + rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol); result = vector::InsertOp::create(rewriter, reductionOp.getLoc(), - rowReduce, result, i); + colReduce, result, i); } // Replace the warp op result with the final result. 
rewriter.replaceAllUsesWith(reductionOp.getResult(), result); + return success(); + } +}; +struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *yieldOperand = + getWarpResult(warpOp, llvm::IsaPred); + if (!yieldOperand) + return failure(); + auto shapeCastOp = + cast(yieldOperand->get().getDefiningOp()); + unsigned operandNumber = yieldOperand->getOperandNumber(); + auto resultDistTy = + cast(warpOp.getResult(operandNumber).getType()); + xegpu::DistributeLayoutAttr sourceLayout = + xegpu::getDistributeLayoutAttr(shapeCastOp.getSource()); + if (!sourceLayout) + return rewriter.notifyMatchFailure( + warpOp, "the source of shape_cast op lacks distribution layout"); + FailureOr sourceDistTypeOrFailure = + getDistVecTypeBasedOnLaneLayout(sourceLayout, + shapeCastOp.getSourceVectorType()); + if (failed(sourceDistTypeOrFailure)) + return rewriter.notifyMatchFailure( + warpOp, "failed to get distributed vector type for source"); + VectorType sourceDistType = sourceDistTypeOrFailure.value(); + // Create a new warp op that yields the source of the shape_cast op. + SmallVector newRetIndices; + auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType}, + newRetIndices); + rewriter.setInsertionPointAfter(newWarpOp); + Value source = newWarpOp.getResult(newRetIndices[0]); + // Create a new shape_cast op outside the warp op. + Value newShapeCast = vector::ShapeCastOp::create( + rewriter, shapeCastOp.getLoc(), resultDistTy, source); + rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), + newShapeCast); return success(); } }; @@ -998,6 +1093,8 @@ void xegpu::populateXeGPUSubgroupDistributePatterns( DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution, GpuBarrierDistribution, VectorMultiReductionDistribution>( patterns.getContext()); + patterns.add(patterns.getContext(), + /*benefit=*/2); } void XeGPUSubgroupDistributePass::runOnOperation() { @@ -1012,8 +1109,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() { if (!isa(operand.get().getType())) continue; - auto layout = - xegpu::getDistributeLayoutAttrOfType(operand); + auto layout = xegpu::getDistributeLayoutAttr(operand.get()); if (!layout) { op->emitError("Could not find layout attribute for operand ") << operand.getOperandNumber() << " of operation " << op->getName(); @@ -1074,6 +1170,25 @@ void XeGPUSubgroupDistributePass::runOnOperation() { // TODO: shuffleFn is not used. auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx, int64_t warpSz) { return Value(); }; + + auto warpReduction = [](Location loc, OpBuilder &builder, Value input, + vector::CombiningKind kind, uint32_t size) { + // First reduce on a single thread to get per lane reduction value. + Value laneVal = builder.create(loc, kind, input); + // Parallel reduction using butterfly shuffles. 
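+    // Each round combines with the lane at (lane_id ^ offset); after
+    // log2(size) rounds (offsets 1, 2, 4, 8, 16 for a 32-wide subgroup) every
+    // lane holds the full reduction value.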
+ for (uint64_t i = 1; i < size; i <<= 1) { + Value shuffled = + builder + .create(loc, laneVal, i, + /*width=*/size, + /*mode=*/gpu::ShuffleMode::XOR) + .getShuffleResult(); + laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled); + } + return laneVal; + }; + + vector::populateDistributeReduction(patterns, warpReduction); vector::populatePropagateWarpVectorDistributionPatterns( patterns, distributionFn, shuffleFn); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir index 8e2e96dfc05cc..54ef56e013abb 100644 --- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir +++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir @@ -319,86 +319,3 @@ gpu.module @test { gpu.return } } - -// ----- -// CHECK-PROP-LABEL: func.func @vector_multi_reduction_col_reduce -// CHECK-PROP : %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) { -// CHECK-PROP : %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32> -// CHECK-PROP : %[[ACC:.*]] = "some_def"() : () -> vector<64xf32> -// CHECK-PROP : gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32> -// CHECK-PROP : } -// CHECK-PROP : %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0 -// CHECK-PROP-SAME : {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32> -// CHECK-PROP : %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32> -// CHECK-PROP : %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32> -// CHECK-PROP : %[[REDUCE0:.*]] = vector.reduction , %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32 -// CHECK-PROP : %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0 -// CHECK-PROP-SAME : {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32> -// CHECK-PROP : %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32> -// CHECK-PROP : %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32> -// CHECK-PROP : %[[REDUCE1:.*]] = vector.reduction , %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32 -// CHECK-PROP : %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32> -// CHECK-PROP : return %[[R]] : vector<2xf32> -gpu.module @test { -gpu.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> { - %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { - %0 = "some_def"() : () -> (vector<32x64xf32>) - %acc = "some_def"() : () -> (vector<64xf32>) - %1 = vector.multi_reduction , %0, %acc [0] : vector<32x64xf32> to vector<64xf32> - gpu.yield %1 : vector<64xf32> - } - gpu.return %r : vector<2xf32> -} -} - -// ----- -// CHECK-PROP-LABEL: func.func @vector_multi_reduction_row_reduce -// CHECK-PROP-DAG: %[[C16:.*]] = arith.constant 16 : i32 -// CHECK-PROP-DAG: %[[C8:.*]] = arith.constant 8 : i32 -// CHECK-PROP-DAG: %[[C4:.*]] = arith.constant 4 : i32 -// CHECK-PROP-DAG: %[[C2:.*]] = arith.constant 2 : i32 -// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : i32 -// CHECK-PROP-DAG: %[[C32:.*]] = arith.constant 32 : i32 -// CHECK-PROP-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) { -// CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32> -// CHECK-PROP: gpu.yield %[[SRC]] : vector<2x32xf32> -// CHECK-PROP: } -// CHECK-PROP: %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 
from vector<2x1xf32> -// CHECK-PROP: %[[SR:.*]], %{{.*}} = gpu.shuffle xor %[[T1]], %[[C1]], %[[C32]] : f32 -// CHECK-PROP: %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32 -// CHECK-PROP: %[[SR0:.*]], %{{.*}} = gpu.shuffle xor %[[T2]], %[[C2]], %[[C32]] : f32 -// CHECK-PROP: %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32 -// CHECK-PROP: %[[SR2:.*]], %{{.*}} = gpu.shuffle xor %[[T3]], %[[C4]], %[[C32]] : f32 -// CHECK-PROP: %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32 -// CHECK-PROP: %[[SR4:.*]], %{{.*}} = gpu.shuffle xor %[[T4]], %[[C8]], %[[C32]] : f32 -// CHECK-PROP: %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32 -// CHECK-PROP: %[[SR6:.*]], %{{.*}} = gpu.shuffle xor %[[T5]], %[[C16]], %[[C32]] : f32 -// CHECK-PROP: %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32 -// CHECK-PROP: %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32 -// -// CHECK-PROP: %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32> -// CHECK-PROP: %[[SR8:.*]], %{{.*}} = gpu.shuffle xor %[[T8]], %[[C1]], %[[C32]] : f32 -// CHECK-PROP: %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32 -// CHECK-PROP: %[[SR10:.*]], %{{.*}} = gpu.shuffle xor %[[T9]], %[[C2]], %[[C32]] : f32 -// CHECK-PROP: %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32 -// CHECK-PROP: %[[SR12:.*]], %{{.*}} = gpu.shuffle xor %[[T10]], %[[C4]], %[[C32]] : f32 -// CHECK-PROP: %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32 -// CHECK-PROP: %[[SR14:.*]], %{{.*}} = gpu.shuffle xor %[[T11]], %[[C8]], %[[C32]] : f32 -// CHECK-PROP: %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32 -// CHECK-PROP: %[[SR16:.*]], %{{.*}} = gpu.shuffle xor %[[T12]], %[[C16]], %[[C32]] : f32 -// CHECK-PROP: %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32 -// CHECK-PROP: %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32 -// CHECK-PROP: %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32> -// CHECK-PROP: return %[[R]] : vector<2xf32> -gpu.module @test { -gpu.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> { - %zero = arith.constant dense<0.0> : vector<2xf32> - %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { - %0 = "some_def"() : () -> (vector<2x32xf32>) - %1 = vector.multi_reduction , %0, %zero [1] : vector<2x32xf32> to vector<2xf32> - gpu.yield %1 : vector<2xf32> - } - gpu.return %r : vector<2xf32> -} -} From 8728eee0cab9a6940d8888c0a3c0cb1d36d154a7 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 8 Sep 2025 21:42:56 +0000 Subject: [PATCH 15/26] save work --- .../mlir/Dialect/XeGPU/Transforms/Passes.td | 4 + .../Transforms/XeGPUSubgroupDistribute.cpp | 191 +++++++++--------- .../Dialect/XeGPU/subgroup-distribute.mlir | 113 +++++++++++ 3 files changed, 215 insertions(+), 93 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td index ddf6b4ac85a90..59dca9f0d852a 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td @@ -27,6 +27,10 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> { }]; let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect"]; + let options = [Option< + "enableSGReductions", "enable-sg-reductions", "bool", + /*default=*/"true", + "Enable subgroup reductions using subgroup shuffles.">]; } def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> { diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp 
b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index 5af45d1324b3b..a436a3d35dba6 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -58,6 +58,24 @@ namespace { //===----------------------------------------------------------------------===// // SIMT Distribution Patterns //===----------------------------------------------------------------------===// +static SmallVector +computeEffectiveLaneLayout(const xegpu::DistributeLayoutAttr layout) { + SmallVector effectiveLaneLayout; + // If the layout is a slice, we need to get effective lane layout by removing + // sliced dims. + if (auto sliceAttr = dyn_cast(layout)) { + ArrayRef slicedDims = sliceAttr.flatten().getDims().asArrayRef(); + llvm::DenseSet lookUp(slicedDims.begin(), slicedDims.end()); + for (auto [i, dim] : + llvm::enumerate(sliceAttr.getParent().getLaneLayoutAsInt())) { + if (!lookUp.contains(i)) + effectiveLaneLayout.push_back(dim); + } + } else { + effectiveLaneLayout = cast(layout).getLaneLayoutAsInt(); + } + return effectiveLaneLayout; +} /// Helper function to get distributed vector type for a source vector type /// according to the lane_layout. We simply divide each dimension of tensor @@ -79,20 +97,7 @@ getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout, return failure(); assert((isa(layout) || isa(layout)) && "Expecting a valid layout."); - SmallVector effectiveLaneLayout; - // If the layout is a slice, we need to get effective lane layout by removing - // sliced dims. - if (auto sliceAttr = dyn_cast(layout)) { - ArrayRef slicedDims = sliceAttr.flatten().getDims().asArrayRef(); - llvm::DenseSet lookUp(slicedDims.begin(), slicedDims.end()); - for (auto [i, dim] : - llvm::enumerate(sliceAttr.getParent().getLaneLayoutAsInt())) { - if (!lookUp.contains(i)) - effectiveLaneLayout.push_back(dim); - } - } else { - effectiveLaneLayout = cast(layout).getLaneLayoutAsInt(); - } + SmallVector effectiveLaneLayout = computeEffectiveLaneLayout(layout); assert(originalType.getShape().size() >= effectiveLaneLayout.size() && "Rank of the original vector type should be greater or equal to the " @@ -824,13 +829,64 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern { } }; +/// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D +/// VectorReductionOps. +static Value lowerToVectorReductions(TypedValue src, + TypedValue acc, + vector::CombiningKind kind, + int64_t reductionDim, Location loc, + PatternRewriter &rewriter) { + // Expecting a 2D source vector. + assert(src.getType().getRank() == 2 && "expected a 2D source vector"); + VectorType sourceType = src.getType(); + int64_t sourceH = sourceType.getShape()[0]; + int64_t sourceW = sourceType.getShape()[1]; + int nSlices = (reductionDim == 0) ? sourceW : sourceH; + // Create a constant vector to hold the result of the reduction. + TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType()); + Value reductionResult = arith::ConstantOp::create( + rewriter, loc, acc.getType(), + DenseElementsAttr::get(acc.getType(), zeroAttr)); + // For each slice of the source, extract the slice vector, do a reduction + // and, insert the reduced value back to the result vector. 
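+  // For example (illustrative): reducing a vector<16x2xf32> along dim 0
+  // produces 2 column slices of vector<16xf32>; each slice is reduced with the
+  // matching accumulator element into one element of the vector<2xf32> result.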
+ for (int i = 0; i < nSlices; ++i) { + SmallVector sliceOffsets, sliceSizes; + if (reductionDim == 1) { + sliceOffsets = {i, 0}; + sliceSizes = {1, sourceW}; + } else { + sliceOffsets = {0, i}; + sliceSizes = {sourceH, 1}; + } + vector::ExtractStridedSliceOp extractOp = + vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets, + sliceSizes, {1, 1}); + int64_t nSliceElements = extractOp.getResult().getType().getNumElements(); + Value slice = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get({nSliceElements}, sourceType.getElementType()), + extractOp.getResult()); + Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i); + Value reduction = + vector::ReductionOp::create(rewriter, loc, kind, slice, accExtract); + reductionResult = + vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i); + } + return reductionResult; +} + /// This pattern distributes the `vector.multi_reduction` operation across -/// lanes in a warp. Currently only 2D to 1D reductions are supported and -/// assumes that source vector is distributed in column dimension (i.e. Each -/// lane owns complete column(s) of the source vector). -/// TODO: Add support for the case where source rows are distributed across -/// lanes. Requires `DistributionMapFn` to express the data distribution. -/// Example 1 (Col reduction): +/// lanes in a warp. Currently only 2D to 1D reductions are supported. Given +/// layouts for the source and accumulator vectors, +/// * If the reduction dimension is distributed across lanes, the reduction is +/// non-lane-local and is done using warp shuffles. Here we +/// simply rewrite the MultiDimReductionOp to a sequence of ReductionOps in +/// the warp op body. +/// * If the reduction dimension is not distributed across lanes, the reduction +/// is lane-local. In this case, we yield the source and accumulator vectors +/// from the warp op and perform the lane-local reduction outside the warp op +/// using a sequence of ReductionOps. +/// Example 1 (Reduction is lane-local): /// ``` /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) { /// %0 = "some_def"() : () -> (vector<16x32xf32>) @@ -852,7 +908,7 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern { /// %2 = vector.reduction , %1, %r#1 : vector<16xf32> to f32 /// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32> /// ``` -/// Example 2 (Row reduction): +/// Example 2 (Reduction is non-lane-local): /// ``` /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) { /// %0 = "some_def"() : () -> (vector<2x32xf32>) @@ -900,7 +956,6 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern { VectorType distributedResultType = cast(warpOp.getResult(operandNumber).getType()); VectorType resultType = cast(reductionOp.getType()); - Type elementType = distributedResultType.getElementType(); xegpu::DistributeLayoutAttr sourceLayout = xegpu::getDistributeLayoutAttr(reductionOp.getSource()); @@ -948,16 +1003,8 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern { warpOp, "Expecting a broadcasted result for non-lane-local reduction."); - // Create a constant vector to store the result of the reduction per lane.
- rewriter.setInsertionPoint(warpOp); - TypedAttr zeroAttr = - rewriter.getZeroAttr(distributedResultType.getElementType()); - Value result = arith::ConstantOp::create( - rewriter, reductionOp->getLoc(), distributedResultType, - DenseElementsAttr::get(distributedResultType, zeroAttr)); - // Handle lane-local reduction case. In this case we fully distribute the - // reduction. + // reduction result. if (isReductionLaneLocal) { // Yield the source and acc vectors from the WarpOp. SmallVector newRetIndices; @@ -965,70 +1012,22 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern { rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()}, {sourceDistType, distributedResultType}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); - - int nSlices = sourceDistType.getShape()[sourceDistDim]; - Value source = newWarpOp.getResult(newRetIndices[0]); - Value acc = newWarpOp.getResult(newRetIndices[1]); - // For each slice owned by a lane, extract the slice, shape cast to 1D, do - // a vector.reduction and, insert the result back to the result vector. - for (int i = 0; i < nSlices; ++i) { - SmallVector sliceOffsets, sliceSizes; - if (sourceDistDim == 0) { - sliceOffsets = {i, 0}; - sliceSizes = {1, sourceDistType.getShape()[1]}; - } else { - sliceOffsets = {0, i}; - sliceSizes = {sourceDistType.getShape()[0], 1}; - } - Value col = vector::ExtractStridedSliceOp::create( - rewriter, reductionOp.getLoc(), source, sliceOffsets, sliceSizes, - {1, 1}); - int64_t col1DSize = - sourceDistType.getShape()[sourceDistDim == 1 ? 0 : 1]; - col = vector::ShapeCastOp::create( - rewriter, reductionOp.getLoc(), - VectorType::get({col1DSize}, elementType), col); - Value accCol = - vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i); - Value colReduce = vector::ReductionOp::create( - rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol); - result = vector::InsertOp::create(rewriter, reductionOp.getLoc(), - colReduce, result, i); - } - // Replace the warp op result with the new reduction op. - rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result); + Value result = lowerToVectorReductions( + cast>(newWarpOp->getResult(newRetIndices[0])), + cast>(newWarpOp->getResult(newRetIndices[1])), + reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter); + // Replace the warp op result with the final result. + rewriter.replaceAllUsesWith(reductionOp.getResult(), result); return success(); } // For non-lane-local case, we simply rewrite the MultiReductionOp in terms // of multiple ReductionOps. Actual distribution is done by the // WarpOpReduction pattern. rewriter.setInsertionPointAfter(reductionOp); - int nSlices = sourceType.getShape()[sourceDistDim == 0 ? 1 : 0]; - // For each slice of the source, extract the slice vector, do a reduction - // and, insert the result back to the result. 
- for (int i = 0; i < nSlices; ++i) { - SmallVector sliceOffsets, sliceSizes; - if (sourceDistDim == 1) { - sliceOffsets = {i, 0}; - sliceSizes = {1, sourceType.getShape()[1]}; - } else { - sliceOffsets = {0, i}; - sliceSizes = {sourceType.getShape()[0], 1}; - } - Value col = vector::ExtractStridedSliceOp::create( - rewriter, reductionOp.getLoc(), reductionOp.getSource(), sliceOffsets, - sliceSizes, {1, 1}); - int64_t col1DSize = sourceType.getShape()[sourceDistDim]; - col = vector::ShapeCastOp::create( - rewriter, reductionOp.getLoc(), - VectorType::get({col1DSize}, elementType), col); - Value accCol = vector::ExtractOp::create(rewriter, reductionOp.getLoc(), - reductionOp.getAcc(), i); - Value colReduce = vector::ReductionOp::create( - rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol); - result = vector::InsertOp::create(rewriter, reductionOp.getLoc(), - colReduce, result, i); - } + Value result = lowerToVectorReductions( + cast>(reductionOp.getSource()), + cast>(reductionOp.getAcc()), + reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter); // Replace the warp op result with the final result. rewriter.replaceAllUsesWith(reductionOp.getResult(), result); return success(); @@ -1082,6 +1081,11 @@ namespace { struct XeGPUSubgroupDistributePass final : public xegpu::impl::XeGPUSubgroupDistributeBase< XeGPUSubgroupDistributePass> { + XeGPUSubgroupDistributePass() = default; + XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) = + default; + XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options) + : XeGPUSubgroupDistributeBase(options) {} void runOnOperation() override; }; } // namespace @@ -1150,8 +1154,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() { if (vecRank == 0) return AffineMap::get(val.getContext()); // Get the layout of the vector type. - // TODO: support more layout types - auto layout = xegpu::getDistributeLayoutAttrOfType(val); + xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val); // If no layout is specified, assume the inner most dimension is distributed // for now. if (!layout) @@ -1159,7 +1162,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() { vecRank, {static_cast(vecRank - 1)}, val.getContext()); SmallVector distributedDims; // Get the distributed dimensions based on the layout. 
- ArrayRef laneLayout = layout.getLaneLayout().asArrayRef(); + SmallVector laneLayout = computeEffectiveLaneLayout(layout); for (unsigned i = 0; i < laneLayout.size(); ++i) { if (laneLayout[i] > 1) distributedDims.push_back(i); @@ -1188,7 +1191,9 @@ void XeGPUSubgroupDistributePass::runOnOperation() { return laneVal; }; - vector::populateDistributeReduction(patterns, warpReduction); + if (enableSGReductions) + vector::populateDistributeReduction(patterns, warpReduction); + vector::populatePropagateWarpVectorDistributionPatterns( patterns, distributionFn, shuffleFn); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir index 54ef56e013abb..77a475cef126c 100644 --- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir +++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir @@ -1,5 +1,8 @@ // RUN: mlir-opt -xegpu-subgroup-distribute -allow-unregistered-dialect -canonicalize -cse -split-input-file %s | FileCheck %s +// RUN: mlir-opt -xegpu-subgroup-distribute="enable-sg-reductions=false" -allow-unregistered-dialect \ +// RUN: -canonicalize -cse -split-input-file %s | FileCheck %s --check-prefix=CHECK-REDUCTION + // CHECK-LABEL: gpu.func @store_nd_1d // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16xf32>) { // CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1xf32> @@ -319,3 +322,113 @@ gpu.module @test { gpu.return } } + +// ----- +// CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction +// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> +// CHECK-SAME: (!xegpu.tensor_desc<1x32xf32, #xegpu.layout>, vector<16x2xf32>) { +// CHECK: %[[SRC:.*]] = "some_def"() {layout_result_0 = #xegpu.layout} : () -> vector<16x32xf32> +// CHECK-NEXT: gpu.yield %{{.*}}, %[[SRC]] : !xegpu.tensor_desc<1x32xf32, #xegpu.layout>, vector<16x32xf32> +// CHECK-NEXT: } +// CHECK: %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#1 {offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32> +// CHECK-NEXT: %[[CAST0:.*]] = vector.shape_cast %[[COL0]] : vector<16x1xf32> to vector<16xf32> +// CHECK-NEXT: %[[RED0:.*]] = vector.reduction , %[[CAST0]], %{{.*}} : vector<16xf32> into f32 +// CHECK: %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#1 {offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32> +// CHECK-NEXT: %[[CAST1:.*]] = vector.shape_cast %[[COL1]] : vector<16x1xf32> to vector<16xf32> +// CHECK-NEXT: %[[RED1:.*]] = vector.reduction , %[[CAST1]], %{{.*}} : vector<16xf32> into f32 +// CHECK-NEXT: vector.from_elements %[[RED0]], %[[RED1]] : vector<2xf32> +gpu.module @test { +gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction() { + %0 = "some_def"() : () -> !xegpu.tensor_desc<1x32xf32, #xegpu.layout> + %src = "some_def"() {layout_result_0 = #xegpu.layout} : () -> (vector<16x32xf32>) + %acc = arith.constant {layout_result_0 = #xegpu.layout} dense<0.0> : vector<32xf32> + %1 = vector.multi_reduction , %src, %acc {layout_result_0 = #xegpu.layout} [0] + : vector<16x32xf32> to vector<32xf32> + %3 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout} + : vector<32xf32> to vector<1x32xf32> + xegpu.store_nd %3, %0 : vector<1x32xf32>, !xegpu.tensor_desc<1x32xf32, #xegpu.layout> + gpu.return +} +} + +// ----- +// CHECK-REDUCTION-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction +// CHECK-REDUCTION: %[[W:.*]]:3 = 
gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (!xegpu.tensor_desc<2x16xf32, +// CHECK-REDUCTION-SAME: #xegpu.layout>, f32, f32) { +// CHECK-REDUCTION: %[[SRC:.*]] = "some_def"() {layout_result_0 = #xegpu.layout} : () -> vector<2x16xf32> +// CHECK-REDUCTION-NEXT: %[[ROW0:.*]] = vector.extract %[[SRC]][0] : vector<16xf32> from vector<2x16xf32> +// CHECK-REDUCTION-NEXT: %[[R0:.*]] = vector.reduction , %[[ROW0]], %{{.*}} : vector<16xf32> into f32 +// CHECK-REDUCTION-NEXT: %[[ROW1:.*]] = vector.extract %[[SRC]][1] : vector<16xf32> from vector<2x16xf32> +// CHECK-REDUCTION-NEXT: %[[R1:.*]] = vector.reduction , %[[ROW1]], %{{.*}} : vector<16xf32> into f32 +// CHECK-REDUCTION-NEXT: gpu.yield %4, %[[R1]], %[[R0]] : !xegpu.tensor_desc<2x16xf32, #xegpu.layout>, f32, f32 +// CHECK-REDUCTION-NEXT: } +// CHECK-REDUCTION-NEXT: vector.from_elements %[[W]]#2, %[[W]]#1 : vector<2xf32> +gpu.module @test { +gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction() { + %0 = "some_def"() : () -> !xegpu.tensor_desc<2x16xf32, #xegpu.layout> + %src = "some_def"() {layout_result_0 = #xegpu.layout} : () -> (vector<2x16xf32>) + %acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1]>} dense<0.0> : vector<2xf32> + %1 = vector.multi_reduction , %src, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1]>} + [1] : vector<2x16xf32> to vector<2xf32> + %3 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout} + : vector<2xf32> to vector<2x1xf32> + %4 = vector.broadcast %3 {layout_result_0 = #xegpu.layout} : vector<2x1xf32> to vector<2x16xf32> + xegpu.store_nd %4, %0 : vector<2x16xf32>, !xegpu.tensor_desc<2x16xf32, #xegpu.layout> + gpu.return +} +} + +// ----- +// CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction +// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%0)[16] -> +// CHECK-SAME: (!xegpu.tensor_desc<32x1xf32, #xegpu.layout>, vector<2x16xf32>) { +// CHECK: %[[SRC:.*]] = "some_def"() {layout_result_0 = #xegpu.layout} : () -> vector<32x16xf32> +// CHECK-NEXT: gpu.yield %{{.*}}, %[[SRC]] : !xegpu.tensor_desc<32x1xf32, #xegpu.layout>, vector<32x16xf32> +// CHECK-NEXT: } +// CHECK: %[[ROW0:.*]] = vector.extract %[[W]]#1[0] : vector<16xf32> from vector<2x16xf32> +// CHECK-NEXT: %[[R0:.*]] = vector.reduction , %[[ROW0]], %{{.*}} : vector<16xf32> into f32 +// CHECK: %[[ROW1:.*]] = vector.extract %[[W]]#1[1] : vector<16xf32> from vector<2x16xf32> +// CHECK-NEXT: %[[R1:.*]] = vector.reduction , %[[ROW1]], %{{.*}} : vector<16xf32> into f32 +// CHECK-NEXT: vector.from_elements %[[R0]], %[[R1]] : vector<2xf32> +gpu.module @test { +gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction() { + %0 = "some_def"() : () -> !xegpu.tensor_desc<32x1xf32, #xegpu.layout> + %src = "some_def"() {layout_result_0 = #xegpu.layout} : () -> (vector<32x16xf32>) + %acc = arith.constant {layout_result_0 = #xegpu.layout} dense<0.0> : vector<32xf32> + %1 = vector.multi_reduction , %src, %acc {layout_result_0 = #xegpu.layout} [1] + : vector<32x16xf32> to vector<32xf32> + %3 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout} + : vector<32xf32> to vector<32x1xf32> + xegpu.store_nd %3, %0 : vector<32x1xf32>, !xegpu.tensor_desc<32x1xf32, #xegpu.layout> + gpu.return +} +} + +// ----- +// CHECK-REDUCTION-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction +// CHECK-REDUCTION: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (!xegpu.tensor_desc<16x2xf32, +// CHECK-REDUCTION-SAME: #xegpu.layout>, f32, f32) { +// CHECK-REDUCTION: %[[SRC:.*]] = 
"some_def"() {layout_result_0 = #xegpu.layout} : () -> vector<16x2xf32> +// CHECK-REDUCTION-NEXT: %[[COL0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32> +// CHECK-REDUCTION-NEXT: %[[CAST0:.*]] = vector.shape_cast %[[COL0]] : vector<16x1xf32> to vector<16xf32> +// CHECK-REDUCTION-NEXT: %[[R0:.*]] = vector.reduction , %[[CAST0]], %{{.*}} : vector<16xf32> into f32 +// CHECK-REDUCTION-NEXT: %[[COL1:.*]] = vector.extract_strided_slice %5 {offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32> +// CHECK-REDUCTION-NEXT: %[[CAST1:.*]] = vector.shape_cast %[[COL1]] : vector<16x1xf32> to vector<16xf32> +// CHECK-REDUCTION-NEXT: %[[R1:.*]] = vector.reduction , %[[CAST1]], %cst : vector<16xf32> into f32 +// CHECK-REDUCTION-NEXT: gpu.yield %4, %[[R1]], %[[R0]] : !xegpu.tensor_desc<16x2xf32, #xegpu.layout>, f32, f32 +// CHECK-REDUCTION-NEXT: } +// CHECK-REDUCTION-NEXT: vector.from_elements %[[W]]#2, %[[W]]#1 : vector<2xf32> +gpu.module @test { +gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction() { + %0 = "some_def"() : () -> !xegpu.tensor_desc<16x2xf32, #xegpu.layout> + %src = "some_def"() {layout_result_0 = #xegpu.layout} : () -> (vector<16x2xf32>) + %acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} dense<0.0> : vector<2xf32> + %1 = vector.multi_reduction , %src, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} + [0] : vector<16x2xf32> to vector<2xf32> + %3 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout} + : vector<2xf32> to vector<1x2xf32> + %4 = vector.broadcast %3 {layout_result_0 = #xegpu.layout} : vector<1x2xf32> to vector<16x2xf32> + xegpu.store_nd %4, %0 : vector<16x2xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.layout> + gpu.return +} +} From 9b72ac0300caf075b36dfbb38164e196766addd3 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 8 Sep 2025 22:17:28 +0000 Subject: [PATCH 16/26] save work --- .../XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index a436a3d35dba6..4faaa487f47a7 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -58,6 +58,14 @@ namespace { //===----------------------------------------------------------------------===// // SIMT Distribution Patterns //===----------------------------------------------------------------------===// + +/// In certain cases, we may need to favor XeGPU specific distribution patterns +/// over generic vector distribution patterns. In such cases, we can assign +/// priorities to patterns. +enum class PatternPriority : int { Regular = 1, High = 2 }; + +/// Helper function to compute the effective lane layout from a +/// DistributeLayoutAttr which can be either a LayoutAttr or a SliceAttr. static SmallVector computeEffectiveLaneLayout(const xegpu::DistributeLayoutAttr layout) { SmallVector effectiveLaneLayout; @@ -1034,6 +1042,8 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern { } }; +/// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing +/// `gpu.warp_execute_on_lane_0` region. 
struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern { using gpu::WarpDistributionPattern::WarpDistributionPattern; LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, @@ -1098,7 +1108,7 @@ void xegpu::populateXeGPUSubgroupDistributePatterns( GpuBarrierDistribution, VectorMultiReductionDistribution>( patterns.getContext()); patterns.add(patterns.getContext(), - /*benefit=*/2); + /*benefit=*/PatternPriority::High); } void XeGPUSubgroupDistributePass::runOnOperation() { From 18547136864d07cc4e85abc9bc421c1bfd79290e Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 8 Sep 2025 22:44:34 +0000 Subject: [PATCH 17/26] save work --- .../Transforms/XeGPUSubgroupDistribute.cpp | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index 0e52f8ca744ab..bf4b75f6ae2b0 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -33,6 +33,7 @@ #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/STLForwardCompat.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/LogicalResult.h" @@ -62,7 +63,8 @@ namespace { /// In certain cases, we may need to favor XeGPU specific distribution patterns /// over generic vector distribution patterns. In such cases, we can assign /// priorities to patterns. -enum class PatternPriority : int { Regular = 1, High = 2 }; +static constexpr unsigned regularPatternBenefit = 1; +static constexpr unsigned highPatternBenefit = 2; /// Helper function to compute the effective lane layout from a /// DistributeLayoutAttr which can be either a LayoutAttr or a SliceAttr. 
@@ -1300,9 +1302,12 @@ void xegpu::populateXeGPUSubgroupDistributePatterns( .add(patterns.getContext()); - patterns.add(patterns.getContext(), - /*benefit=*/PatternPriority::High); + LoadDistribution, StoreDistribution>( + patterns.getContext(), + /*pattern benefit=*/regularPatternBenefit); + patterns.add( + patterns.getContext(), + /*pattern benefit=*/highPatternBenefit); } void XeGPUSubgroupDistributePass::runOnOperation() { @@ -1396,10 +1401,13 @@ void XeGPUSubgroupDistributePass::runOnOperation() { }; if (enableSGReductions) - vector::populateDistributeReduction(patterns, warpReduction); + vector::populateDistributeReduction( + patterns, warpReduction, + /*pattern benefit=*/regularPatternBenefit); vector::populatePropagateWarpVectorDistributionPatterns( - patterns, distributionFn, shuffleFn); + patterns, distributionFn, shuffleFn, + /*pattern benefit=*/regularPatternBenefit); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); return; From 797aa3e3d5b63b2c250a0eaac2eb0f561aaecc2e Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Mon, 8 Sep 2025 22:58:52 +0000 Subject: [PATCH 18/26] save work --- mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index bf4b75f6ae2b0..6f06c751549d3 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -33,9 +33,7 @@ #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/STLForwardCompat.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/LogicalResult.h" namespace mlir { namespace xegpu { From 232808e7ad63d256240ad3c33dda4d40f489e5f8 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 9 Sep 2025 18:32:49 +0000 Subject: [PATCH 19/26] save work --- .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 9 ++++++ .../Transforms/XeGPUSubgroupDistribute.cpp | 29 ++++--------------- mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 21 ++++++++++++++ 3 files changed, 35 insertions(+), 24 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h index 04cfd58d846a7..507d748c5f4f6 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h +++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h @@ -162,6 +162,15 @@ SmallVector addElementwise(OpBuilder &builder, Location loc, SmallVector addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef lhs, ArrayRef rhs); + +/// Helper function to compute the effective lane layout from a +/// DistributeLayoutAttr which can be either a LayoutAttr or a SliceAttr. For +/// LayoutAttr, this will simply return the lane layout. For SliceAttr, it will +/// compute the effective lane layout by removing the sliced dimensions from the +/// parent lane layout. 
+SmallVector +computeEffectiveLaneLayout(xegpu::DistributeLayoutAttr layout); + } // namespace xegpu } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index 6f06c751549d3..0f6368583c763 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -64,27 +64,6 @@ namespace { static constexpr unsigned regularPatternBenefit = 1; static constexpr unsigned highPatternBenefit = 2; -/// Helper function to compute the effective lane layout from a -/// DistributeLayoutAttr which can be either a LayoutAttr or a SliceAttr. -static SmallVector -computeEffectiveLaneLayout(const xegpu::DistributeLayoutAttr layout) { - SmallVector effectiveLaneLayout; - // If the layout is a slice, we need to get effective lane layout by removing - // sliced dims. - if (auto sliceAttr = dyn_cast(layout)) { - ArrayRef slicedDims = sliceAttr.flatten().getDims().asArrayRef(); - llvm::DenseSet lookUp(slicedDims.begin(), slicedDims.end()); - for (auto [i, dim] : - llvm::enumerate(sliceAttr.getParent().getLaneLayoutAsInt())) { - if (!lookUp.contains(i)) - effectiveLaneLayout.push_back(dim); - } - } else { - effectiveLaneLayout = cast(layout).getLaneLayoutAsInt(); - } - return effectiveLaneLayout; -} - /// Helper function to get distributed vector type for a source vector type /// according to the lane_layout. We simply divide each dimension of tensor /// descriptor shape by corresponding lane_layout dimension. If @@ -105,9 +84,11 @@ getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout, return failure(); assert((isa(layout) || isa(layout)) && "Expecting a valid layout."); - SmallVector effectiveLaneLayout = computeEffectiveLaneLayout(layout); + SmallVector effectiveLaneLayout = + xegpu::computeEffectiveLaneLayout(layout); - assert(originalType.getShape().size() >= effectiveLaneLayout.size() && + assert(static_cast(originalType.getRank()) >= + effectiveLaneLayout.size() && "Rank of the original vector type should be greater or equal to the " "size of the lane layout to distribute the vector type."); SmallVector distributedShape(originalType.getShape()); @@ -1369,7 +1350,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() { vecRank, {static_cast(vecRank - 1)}, val.getContext()); SmallVector distributedDims; // Get the distributed dimensions based on the layout. - SmallVector laneLayout = computeEffectiveLaneLayout(layout); + SmallVector laneLayout = xegpu::computeEffectiveLaneLayout(layout); for (unsigned i = 0; i < laneLayout.size(); ++i) { if (laneLayout[i] > 1) distributedDims.push_back(i); diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index b72d5648b29f9..aa55ed27d6ecf 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -484,3 +484,24 @@ xegpu::addWithRightAligned(OpBuilder &builder, Location loc, results.append(addElementwise(builder, loc, a, b)); return results; } + +SmallVector +xegpu::computeEffectiveLaneLayout(const xegpu::DistributeLayoutAttr layout) { + if (!layout) + return {}; + SmallVector effectiveLaneLayout; + // If the layout is a slice, we need to get effective lane layout by removing + // sliced dims. 
+ if (auto sliceAttr = dyn_cast(layout)) { + ArrayRef slicedDims = sliceAttr.flatten().getDims().asArrayRef(); + llvm::DenseSet lookUp(slicedDims.begin(), slicedDims.end()); + for (auto [i, dim] : + llvm::enumerate(sliceAttr.getParent().getLaneLayoutAsInt())) { + if (!lookUp.contains(i)) + effectiveLaneLayout.push_back(dim); + } + } else { + effectiveLaneLayout = cast(layout).getLaneLayoutAsInt(); + } + return effectiveLaneLayout; +} From be1c00cc486c3b2fe69c13b5477df5be8bd1c70e Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 10 Sep 2025 17:51:57 +0000 Subject: [PATCH 20/26] add transpose function --- .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 46 ++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index cfe3e800484ce..24756318e4339 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -231,7 +231,51 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> { multiple blocks according to round-robin distribution rules.}], "FailureOr>>", "getOffsets", - (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef":$shape)> + (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef":$shape)>, + InterfaceMethod": $perm), + /*methodBody=*/[{ + if (!other) + return false; + if ($_self.getRank() != other.getRank() || perm.size() != static_cast($_self.getRank())) + return false; + // check if the permutation is valid + int64_t rank = $_self.getRank(); + SmallVector seen(rank, false); + for (const auto &ta : llvm::enumerate(perm)) { + if (ta.value() < 0 || ta.value() >= rank) + return false; + if (seen[ta.value()]) + return false; + seen[ta.value()] = true; + } + auto checkTranspose = [](ArrayRef dst, ArrayRef src, ArrayRef perm) { + for (const auto &ta : llvm::enumerate(perm)) { + if (src[ta.index()] != dst[ta.value()]) + return false; + } + return true; + }; + // check sgLayout + if (!checkTranspose($_self.getSgLayoutAsInt(), other.getSgLayoutAsInt(), perm)) + return false; + // check sgData + if (!checkTranspose($_self.getSgDataAsInt(), other.getSgDataAsInt(), perm)) + return false; + // check instData + if (!checkTranspose($_self.getInstDataAsInt(), other.getInstDataAsInt(), perm)) + return false; + // check laneLayout + if (!checkTranspose($_self.getLaneLayoutAsInt(), other.getLaneLayoutAsInt(), perm)) + return false; + // check laneData + if (!checkTranspose($_self.getLaneDataAsInt(), other.getLaneDataAsInt(), perm)) + return false; + return true; + }]> ]; } From 2ebe31ec0df799bcc5d3ea1cb998acf866a2f5fa Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 10 Sep 2025 19:33:00 +0000 Subject: [PATCH 21/26] fix test --- mlir/test/Dialect/XeGPU/subgroup-distribute.mlir | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir index 6d8d70972de42..b4086df3cfd11 100644 --- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir +++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir @@ -341,8 +341,8 @@ gpu.module @test { gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction() { %0 = "some_def"() : () -> !xegpu.tensor_desc<1x32xf32, #xegpu.layout> %src = "some_def"() {layout_result_0 = #xegpu.layout} : () -> (vector<16x32xf32>) - %acc = arith.constant {layout_result_0 = #xegpu.layout} dense<0.0> : vector<32xf32> - %1 = 
vector.multi_reduction , %src, %acc {layout_result_0 = #xegpu.layout} [0] + %acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} dense<0.0> : vector<32xf32> + %1 = vector.multi_reduction , %src, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} [0] : vector<16x32xf32> to vector<32xf32> %3 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout} : vector<32xf32> to vector<1x32xf32> @@ -394,10 +394,10 @@ gpu.module @test { gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction() { %0 = "some_def"() : () -> !xegpu.tensor_desc<32x1xf32, #xegpu.layout> %src = "some_def"() {layout_result_0 = #xegpu.layout} : () -> (vector<32x16xf32>) - %acc = arith.constant {layout_result_0 = #xegpu.layout} dense<0.0> : vector<32xf32> - %1 = vector.multi_reduction , %src, %acc {layout_result_0 = #xegpu.layout} [1] + %acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1]>} dense<0.0> : vector<32xf32> + %1 = vector.multi_reduction , %src, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1]>} [1] : vector<32x16xf32> to vector<32xf32> - %3 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout} + %3 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout} : vector<32xf32> to vector<32x1xf32> xegpu.store_nd %3, %0 : vector<32x1xf32>, !xegpu.tensor_desc<32x1xf32, #xegpu.layout> gpu.return From 916c75f12298f76b2f8c6e2b5645125e75d34a73 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 10 Sep 2025 23:15:18 +0000 Subject: [PATCH 22/26] add slice attribute utils --- .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 12 ++++++++++- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 21 +++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 24756318e4339..aa3e3c5cddc05 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -275,7 +275,11 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> { if (!checkTranspose($_self.getLaneDataAsInt(), other.getLaneDataAsInt(), perm)) return false; return true; - }]> + }]>, + InterfaceMethod ]; } @@ -477,6 +481,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> { FailureOr>> getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef shape); + /// Check if this is slice of some other layout. + bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; } + }]; let assemblyFormat = "`<` struct(params) `>`"; @@ -638,6 +645,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> { FailureOr>> getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef shape); + /// Check if this is slice of some other layout. 
+ bool isSliceOf(const xegpu::DistributeLayoutAttr &other); + }]; let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`"; diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 7f3be7f91c56b..a3783d5e05df6 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" @@ -409,6 +410,26 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, shape); } +bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) { + auto flattenedThis = flatten(); + // If other is a LayoutAttr, just compare directly with parent of + // flattenedThis. + if (auto otherLayout = dyn_cast(other)) + return flattenedThis.getParent() == otherLayout; + // If other is a SliceAttr, flatten it first before comparing. + auto otherFlattened = dyn_cast(other).flatten(); + // Both must have common parent LayoutAttr. + if (flattenedThis.getParent() != otherFlattened.getParent()) + return false; + // otherFlattened's sliced dims must be a subset of flattenedThis's sliced + // dims. + llvm::SmallDenseSet thisDims( + flattenedThis.getDims().asArrayRef().begin(), + flattenedThis.getDims().asArrayRef().end()); + return llvm::all_of(otherFlattened.getDims().asArrayRef(), + [&](int64_t dim) { return thisDims.contains(dim); }); +} + //===----------------------------------------------------------------------===// // XeGPU_RangeAttr //===----------------------------------------------------------------------===// From 77e8a9477dbd76bf95e5d142a0a6e6a4596ab3d2 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 10 Sep 2025 23:54:57 +0000 Subject: [PATCH 23/26] fix name --- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index a3783d5e05df6..cc133b110c95a 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -417,16 +417,16 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) { if (auto otherLayout = dyn_cast(other)) return flattenedThis.getParent() == otherLayout; // If other is a SliceAttr, flatten it first before comparing. - auto otherFlattened = dyn_cast(other).flatten(); + auto flattenedOther = dyn_cast(other).flatten(); // Both must have common parent LayoutAttr. - if (flattenedThis.getParent() != otherFlattened.getParent()) + if (flattenedThis.getParent() != flattenedOther.getParent()) return false; // otherFlattened's sliced dims must be a subset of flattenedThis's sliced // dims. 
llvm::SmallDenseSet thisDims( flattenedThis.getDims().asArrayRef().begin(), flattenedThis.getDims().asArrayRef().end()); - return llvm::all_of(otherFlattened.getDims().asArrayRef(), + return llvm::all_of(flattenedOther.getDims().asArrayRef(), [&](int64_t dim) { return thisDims.contains(dim); }); } From 6e2f42071f1a161e2f7f2d9f976becb842fa44af Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Thu, 11 Sep 2025 19:17:31 +0000 Subject: [PATCH 24/26] fix func naming --- .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 99 ++++++------------- .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 8 -- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 26 ++--- .../XeGPU/Transforms/XeGPUBlocking.cpp | 20 ++-- .../Transforms/XeGPUSubgroupDistribute.cpp | 30 ++++-- .../Transforms/XeGPUWgToSgDistribute.cpp | 8 +- mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 21 ---- .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 3 +- 8 files changed, 81 insertions(+), 134 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index aa3e3c5cddc05..1f1d367118365 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -194,26 +194,29 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> { InterfaceMethod<"Get the num of effective subgroups", "int64_t", "getNumSubgroups", (ins), [{ - std::optional> sgLayout = llvm::cast(tablegen_opaque_val).getSgLayoutAsInt(); + std::optional> sgLayout = llvm::cast(tablegen_opaque_val).getEffectiveSgLayoutAsInt(); if (sgLayout.has_value()) return computeProduct(*sgLayout); return 0; }], [{}]>, - InterfaceMethod<"Get the SgLayout field of the attribute as integer array", + InterfaceMethod<"Get the order of the layout attribute", + "DenseI32ArrayAttr", + "getOrder">, + InterfaceMethod<"Get the effective SgLayout of the layout attribute as integer array", "SmallVector", - "getSgLayoutAsInt">, - InterfaceMethod<"Get the SgData field of the attribute as integer array", + "getEffectiveSgLayoutAsInt">, + InterfaceMethod<"Get the effective SgData of the layout attribute as integer array", "SmallVector", - "getSgDataAsInt">, - InterfaceMethod<"Get the InstData field of the attribute as integer array", + "getEffectiveSgDataAsInt">, + InterfaceMethod<"Get the effective InstData of the layout attribute as integer array", "SmallVector", - "getInstDataAsInt">, - InterfaceMethod<"Get the LaneLayout field of the attribute as integer array", + "getEffectiveInstDataAsInt">, + InterfaceMethod<"Get the effective LaneLayout of the layout attribute as integer array", "SmallVector", - "getLaneLayoutAsInt">, - InterfaceMethod<"Get the LaneData field of the attribute as integer array", + "getEffectiveLaneLayoutAsInt">, + InterfaceMethod<"Get the effective LaneData of the layout attribute as integer array", "SmallVector", - "getLaneDataAsInt">, + "getEffectiveLaneDataAsInt">, InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData", "xegpu::DistributeLayoutAttr", "dropSgLayoutAndData">, @@ -232,50 +235,6 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> { "FailureOr>>", "getOffsets", (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef":$shape)>, - InterfaceMethod": $perm), - /*methodBody=*/[{ - if (!other) - return false; - if ($_self.getRank() != other.getRank() || perm.size() != static_cast($_self.getRank())) - return false; - // check if the permutation is valid - int64_t rank = $_self.getRank(); - SmallVector seen(rank, false); - 
for (const auto &ta : llvm::enumerate(perm)) { - if (ta.value() < 0 || ta.value() >= rank) - return false; - if (seen[ta.value()]) - return false; - seen[ta.value()] = true; - } - auto checkTranspose = [](ArrayRef dst, ArrayRef src, ArrayRef perm) { - for (const auto &ta : llvm::enumerate(perm)) { - if (src[ta.index()] != dst[ta.value()]) - return false; - } - return true; - }; - // check sgLayout - if (!checkTranspose($_self.getSgLayoutAsInt(), other.getSgLayoutAsInt(), perm)) - return false; - // check sgData - if (!checkTranspose($_self.getSgDataAsInt(), other.getSgDataAsInt(), perm)) - return false; - // check instData - if (!checkTranspose($_self.getInstDataAsInt(), other.getInstDataAsInt(), perm)) - return false; - // check laneLayout - if (!checkTranspose($_self.getLaneLayoutAsInt(), other.getLaneLayoutAsInt(), perm)) - return false; - // check laneData - if (!checkTranspose($_self.getLaneDataAsInt(), other.getLaneDataAsInt(), perm)) - return false; - return true; - }]>, InterfaceMethod { getLaneLayout(), getLaneData(), getOrder()); } - SmallVector getSgLayoutAsInt() const { + SmallVector getEffectiveSgLayoutAsInt() const { if (DenseI32ArrayAttr layout = getSgLayout()) return llvm::to_vector_of(layout.asArrayRef()); return {}; } - SmallVector getSgDataAsInt() const { + SmallVector getEffectiveSgDataAsInt() const { if (DenseI32ArrayAttr data = getSgData()) return llvm::to_vector_of(data.asArrayRef()); return {}; } - SmallVector getInstDataAsInt() const { + SmallVector getEffectiveInstDataAsInt() const { if (DenseI32ArrayAttr inst = getInstData()) return llvm::to_vector_of(inst.asArrayRef()); return {}; } - SmallVector getLaneLayoutAsInt() const { + SmallVector getEffectiveLaneLayoutAsInt() const { if (DenseI32ArrayAttr layout = getLaneLayout()) return llvm::to_vector_of(layout.asArrayRef()); return {}; } - SmallVector getLaneDataAsInt() const { + SmallVector getEffectiveLaneDataAsInt() const { if (DenseI32ArrayAttr data = getLaneData()) return llvm::to_vector_of(data.asArrayRef()); return {}; @@ -550,10 +509,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> { /// Returns the SgLayout of the attribute, computed by applying /// the slice dimensions to the underlying LayoutAttr. - SmallVector getSgLayoutAsInt() const { + SmallVector getEffectiveSgLayoutAsInt() const { SliceAttr attr = flatten(); auto parent = dyn_cast(attr.getParent()); - auto layout = parent.getSgLayoutAsInt(); + auto layout = parent.getEffectiveSgLayoutAsInt(); if (layout.size()) { ArrayRef dims = attr.getDims().asArrayRef(); return XeGPUDialect::slice(ArrayRef(layout), dims); @@ -563,10 +522,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> { /// Returns the SgData of the attribute, computed by applying /// the slice dimensions to the underlying LayoutAttr. - SmallVector getSgDataAsInt() const { + SmallVector getEffectiveSgDataAsInt() const { SliceAttr attr = flatten(); auto parent = dyn_cast(attr.getParent()); - auto data = parent.getSgDataAsInt(); + auto data = parent.getEffectiveSgDataAsInt(); if (data.size()) { ArrayRef dims = attr.getDims().asArrayRef(); return XeGPUDialect::slice(ArrayRef(data), dims); @@ -576,10 +535,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> { /// Returns the InstData of the attribute, computed by applying /// the slice dimensions to the underlying LayoutAttr. 
- SmallVector getInstDataAsInt() const { + SmallVector getEffectiveInstDataAsInt() const { SliceAttr attr = flatten(); auto parent = dyn_cast(attr.getParent()); - auto inst = parent.getInstDataAsInt(); + auto inst = parent.getEffectiveInstDataAsInt(); if (inst.size()) { ArrayRef dims = attr.getDims().asArrayRef(); return XeGPUDialect::slice(llvm::ArrayRef(inst), dims); @@ -589,10 +548,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> { /// Returns the LaneLayout of the attribute, computed by applying /// the slice dimensions to the underlying LayoutAttr. - SmallVector getLaneLayoutAsInt() const { + SmallVector getEffectiveLaneLayoutAsInt() const { SliceAttr attr = flatten(); auto parent = dyn_cast(attr.getParent()); - auto layout = parent.getLaneLayoutAsInt(); + auto layout = parent.getEffectiveLaneLayoutAsInt(); if (layout.size()) { ArrayRef dims = attr.getDims().asArrayRef(); return XeGPUDialect::slice(llvm::ArrayRef(layout), dims); @@ -602,10 +561,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> { /// Returns the LaneData of the attribute, computed by applying /// the slice dimensions to the underlying LayoutAttr. - SmallVector getLaneDataAsInt() const { + SmallVector getEffectiveLaneDataAsInt() const { SliceAttr attr = flatten(); auto parent = dyn_cast(attr.getParent()); - auto data = parent.getLaneDataAsInt(); + auto data = parent.getEffectiveLaneDataAsInt(); if (data.size()) { ArrayRef dims = attr.getDims().asArrayRef(); return XeGPUDialect::slice(llvm::ArrayRef(data), dims); diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h index 507d748c5f4f6..ebdef7315b0a3 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h +++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h @@ -163,14 +163,6 @@ SmallVector addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef lhs, ArrayRef rhs); -/// Helper function to compute the effective lane layout from a -/// DistributeLayoutAttr which can be either a LayoutAttr or a SliceAttr. For -/// LayoutAttr, this will simply return the lane layout. For SliceAttr, it will -/// compute the effective lane layout by removing the sliced dimensions from the -/// parent lane layout. 
-SmallVector -computeEffectiveLaneLayout(xegpu::DistributeLayoutAttr layout); - } // namespace xegpu } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index cc133b110c95a..6094b0fb42f08 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -134,22 +134,23 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef shape, }; // check the sgLayout and sgData - auto maybeSgShape = - tryDistribute(shape, attr.getSgLayoutAsInt(), attr.getSgDataAsInt()); + auto maybeSgShape = tryDistribute(shape, attr.getEffectiveSgLayoutAsInt(), + attr.getEffectiveSgDataAsInt()); if (!maybeSgShape) return false; auto sgShape = maybeSgShape.value(); // check InstData, it neither have layout nor need round-robin auto maybeInstShape = - tryDistribute(sgShape, {}, attr.getInstDataAsInt(), false); + tryDistribute(sgShape, {}, attr.getEffectiveInstDataAsInt(), false); if (!maybeInstShape) return false; auto instShape = maybeInstShape.value(); // check LaneLayout and LaneData - auto maybeLaneShape = tryDistribute(instShape, attr.getLaneLayoutAsInt(), - attr.getLaneDataAsInt(), false); + auto maybeLaneShape = + tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(), + attr.getEffectiveLaneDataAsInt(), false); return maybeLaneShape.has_value(); } @@ -283,9 +284,10 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, if (!hasDefaultOrder()) return mlir::emitError(loc, "order attribute is currently not supported."); - auto dims = llvm::map_to_vector(getSgLayoutAsInt(), [&](int64_t d) -> Value { - return builder.createOrFold(loc, d); - }); + auto dims = + llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value { + return builder.createOrFold(loc, d); + }); return affine::delinearizeIndex(builder, loc, linearId, dims); } @@ -299,8 +301,8 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, if (!isForWorkgroup()) return failure(); - SmallVector sgLayout = getSgLayoutAsInt(); - SmallVector sgShape = getSgDataAsInt(); + SmallVector sgLayout = getEffectiveSgLayoutAsInt(); + SmallVector sgShape = getEffectiveSgDataAsInt(); if (sgShape.empty()) { if (auto derivedShape = computeShapeRatio(shape, sgLayout)) sgShape = derivedShape.value(); @@ -386,8 +388,8 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, if (!isForWorkgroup()) return failure(); - SmallVector sgLayout = getSgLayoutAsInt(); - SmallVector sgShape = getSgDataAsInt(); + SmallVector sgLayout = getEffectiveSgLayoutAsInt(); + SmallVector sgShape = getEffectiveSgDataAsInt(); if (sgShape.empty()) { if (auto derivedShape = computeShapeRatio(shape, sgLayout)) sgShape = derivedShape.value(); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp index 5d5ff69e06886..7efa4b9fbd934 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp @@ -85,16 +85,16 @@ struct ConvertLayoutOpPattern using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op, PatternRewriter &rewriter) const override { - xegpu::DistributeLayoutAttr input_layout = op.getInputLayoutAttr(); - xegpu::DistributeLayoutAttr target_layout = op.getTargetLayoutAttr(); - if (input_layout.getInstDataAsInt().empty() || - target_layout.getInstDataAsInt().empty()) + xegpu::DistributeLayoutAttr inputLayout = op.getInputLayoutAttr(); + 
xegpu::DistributeLayoutAttr targetLayout = op.getTargetLayoutAttr(); + if (inputLayout.getEffectiveInstDataAsInt().empty() || + targetLayout.getEffectiveInstDataAsInt().empty()) return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp."); - input_layout = input_layout.dropInstData(); - target_layout = target_layout.dropInstData(); + inputLayout = inputLayout.dropInstData(); + targetLayout = targetLayout.dropInstData(); auto newOp = rewriter.createOrFold( - op.getLoc(), op.getType(), op.getSource(), input_layout, target_layout); + op.getLoc(), op.getType(), op.getSource(), inputLayout, targetLayout); rewriter.replaceOp(op, newOp); return success(); } @@ -145,8 +145,8 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const { xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(operandOrResult); if (layout && layout.isForSubgroup()) { - if (!layout.getInstDataAsInt().empty()) - return layout.getInstDataAsInt(); + if (!layout.getEffectiveInstDataAsInt().empty()) + return layout.getEffectiveInstDataAsInt(); if (auto type = dyn_cast(value.getType())) return llvm::to_vector(type.getShape()); @@ -226,7 +226,7 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const { Type valTy = value.getType(); if (auto tdescTy = dyn_cast(valTy)) { xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr(); - return layout && !layout.getInstDataAsInt().empty(); + return layout && !layout.getEffectiveInstDataAsInt().empty(); } auto shapedType = dyn_cast(valTy); return shapedType && !llvm::equal(tileShape, shapedType.getShape()); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index 0f6368583c763..21c1583bf2633 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -85,8 +85,7 @@ getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout, assert((isa(layout) || isa(layout)) && "Expecting a valid layout."); SmallVector effectiveLaneLayout = - xegpu::computeEffectiveLaneLayout(layout); - + layout.getEffectiveLaneLayoutAsInt(); assert(static_cast(originalType.getRank()) >= effectiveLaneLayout.size() && "Rank of the original vector type should be greater or equal to the " @@ -1234,9 +1233,26 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern { cast(warpOp.getResult(operandNumber).getType()); xegpu::DistributeLayoutAttr sourceLayout = xegpu::getDistributeLayoutAttr(shapeCastOp.getSource()); - if (!sourceLayout) + xegpu::DistributeLayoutAttr resultLayout = + xegpu::getDistributeLayoutAttr(shapeCastOp.getResult()); + if (!sourceLayout || !resultLayout) + return rewriter.notifyMatchFailure( + warpOp, + "the source or result of shape_cast op lacks distribution layout"); + + // For rank reducing or increasing shape_cast ops, the lower rank layout + // must be a slice of higher rank layout. 
+ int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank(); + int64_t resultRank = shapeCastOp.getResultVectorType().getRank(); + if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout)) + return rewriter.notifyMatchFailure( + warpOp, "shape_cast is rank reducing but source layout is not a " + "slice of result layout"); + if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout)) return rewriter.notifyMatchFailure( - warpOp, "the source of shape_cast op lacks distribution layout"); + warpOp, "shape_cast is rank increasing but result layout is not a " + "slice of source layout"); + FailureOr sourceDistTypeOrFailure = getDistVecTypeBasedOnLaneLayout(sourceLayout, shapeCastOp.getSourceVectorType()); @@ -1349,10 +1365,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() { return AffineMap::getMultiDimMapWithTargets( vecRank, {static_cast(vecRank - 1)}, val.getContext()); SmallVector distributedDims; - // Get the distributed dimensions based on the layout. - SmallVector laneLayout = xegpu::computeEffectiveLaneLayout(layout); - for (unsigned i = 0; i < laneLayout.size(); ++i) { - if (laneLayout[i] > 1) + for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) { + if (v > 1) distributedDims.push_back(i); } return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims, diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 5d0f1d18402f2..3f48400fedf5e 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -52,9 +52,9 @@ getSgShapeAndCount(ArrayRef shape, int count = 1; SmallVector sgShape(shape); if (layout && layout.isForWorkgroup()) { - SmallVector sgLayout = layout.getSgLayoutAsInt(); - if (!layout.getSgDataAsInt().empty()) - sgShape = layout.getSgDataAsInt(); + SmallVector sgLayout = layout.getEffectiveSgLayoutAsInt(); + if (!layout.getEffectiveSgDataAsInt().empty()) + sgShape = layout.getEffectiveSgDataAsInt(); else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout)) sgShape = *maybeDerivedSgData; SmallVector distUnit = computeElementwiseMul(sgLayout, sgShape); @@ -488,7 +488,7 @@ struct WgToSgVectorBroadcastOp VectorType::get(sgShape, resultType.getElementType()); // Check if the output layout is distributable - SmallVector sgLayout = layout.getSgLayoutAsInt(); + SmallVector sgLayout = layout.getEffectiveSgLayoutAsInt(); if (sgLayout.empty()) return failure(); diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index aa55ed27d6ecf..b72d5648b29f9 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -484,24 +484,3 @@ xegpu::addWithRightAligned(OpBuilder &builder, Location loc, results.append(addElementwise(builder, loc, a, b)); return results; } - -SmallVector -xegpu::computeEffectiveLaneLayout(const xegpu::DistributeLayoutAttr layout) { - if (!layout) - return {}; - SmallVector effectiveLaneLayout; - // If the layout is a slice, we need to get effective lane layout by removing - // sliced dims. 
- if (auto sliceAttr = dyn_cast(layout)) { - ArrayRef slicedDims = sliceAttr.flatten().getDims().asArrayRef(); - llvm::DenseSet lookUp(slicedDims.begin(), slicedDims.end()); - for (auto [i, dim] : - llvm::enumerate(sliceAttr.getParent().getLaneLayoutAsInt())) { - if (!lookUp.contains(i)) - effectiveLaneLayout.push_back(dim); - } - } else { - effectiveLaneLayout = cast(layout).getLaneLayoutAsInt(); - } - return effectiveLaneLayout; -} diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp index 200323c7a4e51..e1ba45c60ac36 100644 --- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp +++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp @@ -170,7 +170,8 @@ class TestStepOpPattern : public OpConversionPattern { if (!sliceAttr || sliceAttr.getRank() != 1) return failure(); - std::optional> sgShape = sliceAttr.getSgDataAsInt(); + std::optional> sgShape = + sliceAttr.getEffectiveSgDataAsInt(); if (!sgShape) return failure(); From 1c4f06f60cb0964db6341e8a72b8b71d7a4106a8 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Thu, 11 Sep 2025 19:19:27 +0000 Subject: [PATCH 25/26] fix func naming --- mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h index ebdef7315b0a3..04cfd58d846a7 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h +++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h @@ -162,7 +162,6 @@ SmallVector addElementwise(OpBuilder &builder, Location loc, SmallVector addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef lhs, ArrayRef rhs); - } // namespace xegpu } // namespace mlir From 8febca8260570560920e7e14d96f043487d6978e Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 12 Sep 2025 16:15:38 +0000 Subject: [PATCH 26/26] remove header --- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 6094b0fb42f08..94c5509fd7c29 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -14,7 +14,6 @@ #include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h"
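The patch series above first introduces computeEffectiveLaneLayout() as a file-local helper, then moves it into XeGPUUtils, and finally folds it into the DistributeLayoutAttr interface as getEffectiveLaneLayoutAsInt(). The sketch below is editorial and not part of any patch: a minimal, dependency-free C++ rendering of that dim-filtering step, with plain containers standing in for the MLIR LayoutAttr/SliceAttr types; the parent layout [1, 16] and sliced dim 0 are illustrative values only.

```cpp
// Editorial sketch (not from the patches): effective lane layout of a sliced
// layout, i.e. the parent lane layout with the sliced dimensions removed.
#include <cstdint>
#include <iostream>
#include <unordered_set>
#include <vector>

// Keep each parent lane-layout entry whose dimension index was not sliced away.
static std::vector<int64_t>
effectiveLaneLayout(const std::vector<int64_t> &parentLaneLayout,
                    const std::vector<int64_t> &slicedDims) {
  std::unordered_set<int64_t> sliced(slicedDims.begin(), slicedDims.end());
  std::vector<int64_t> result;
  for (size_t i = 0; i < parentLaneLayout.size(); ++i)
    if (!sliced.count(static_cast<int64_t>(i)))
      result.push_back(parentLaneLayout[i]);
  return result;
}

int main() {
  // Illustrative values: a parent layout spreading dim 1 over 16 lanes,
  // sliced over dim 0 (as a multi_reduction along dim 0 would produce).
  // The effective lane layout left for the 1-D result is [16].
  for (int64_t v : effectiveLaneLayout({1, 16}, {0}))
    std::cout << v << ' '; // prints: 16
  std::cout << '\n';
  return 0;
}
```

This mirrors the subset filtering that the flattened SliceAttr performs in the patches, which is also what lets the distribution map in the pass treat reduced (sliced) values consistently with their parent layouts.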