diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 0961ddfb92040..8c60ced4ed38e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -842,6 +842,70 @@ struct SgToWiStoreScatter : public OpConversionPattern {
   }
 };
 
+/// Distributes a subgroup-level vector.step to workitem level by materializing
+/// each lane's values explicitly (lane-id-derived block starts plus constant
+/// offsets, packed with vector.from_elements). The layout must have exactly
+/// 1 effective lane dimension; only entry [0] of its lane data is consulted.
+struct SgToWiVectorStep : public OpConversionPattern<vector::StepOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::StepOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getTemporaryLayout(op->getResult(0));
+    if (!resultLayout || !resultLayout.isForSubgroup())
+      return rewriter.notifyMatchFailure(
+          op, "the result vector of the step op lacks subgroup layout");
+
+    auto loc = op.getLoc();
+    auto stepResultVecTy = op.getResult().getType();
+    auto wiShapeOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, stepResultVecTy);
+    if (failed(wiShapeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute workitem vector type from the layout");
+    VectorType newVecTy = wiShapeOrFailure.value();
+
+    Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
+                                         /*upperBound=*/mlir::IntegerAttr());
+    auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
+        rewriter, loc, laneId, stepResultVecTy.getShape());
+    if (failed(laneDataBlockCoords))
+      return rewriter.notifyMatchFailure(
+          op, "failed to compute lane data block coordinates");
+
+    const auto &laneDataBlockCoordsVec = laneDataBlockCoords.value();
+    auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
+    assert(static_cast<int64_t>(laneDataBlockCoordsVec.size()) ==
+           newVecTy.getNumElements() / laneDataBlockLength);
+    SmallVector<Value> stepVals;
+    // For each lane_data block, reconstruct its sub-range from the range of
+    // the SG-level vector.step: the block start comes from the distributed
+    // coordinates and the remaining elements are start + 1, start + 2, ...
+    // Example: vector<16xindex> with a slice layout (dims=[0,2]) in which
+    // each logical lane holds 4 elements as 2 blocks of 2 elements each.
+    // The blocks are round-robin distributed, so logical lane id 0
+    // holds values [0,1, 8,9].
+    for (const auto &blockCoords : laneDataBlockCoordsVec) {
+      auto blockStart = blockCoords[0];
+      stepVals.push_back(blockStart);
+      for (int64_t i = 1; i < laneDataBlockLength; ++i) {
+        auto offset = arith::ConstantIndexOp::create(rewriter, loc, i);
+        stepVals.push_back(
+            arith::AddIOp::create(rewriter, loc, blockStart, offset));
+      }
+    }
+    assert(static_cast<int64_t>(stepVals.size()) == newVecTy.getNumElements() &&
+           "Expecting the number of step values to match the number of "
+           "elements in the vector");
+    auto stepOpVal =
+        vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
+    rewriter.replaceOp(op, stepOpVal);
+    return success();
+  }
+};
+
 /// Distributes a subgroup-level vector.extract op to workitem-level. Only
 /// handles sub-vector extraction (result is VectorType, not scalar).
 struct SgToWiVectorExtract : public OpConversionPattern<vector::ExtractOp> {
@@ -876,6 +940,33 @@ struct SgToWiVectorExtract : public OpConversionPattern {
   }
 };
 
+/// This pattern distributes a subgroup-level ShapeCast op to workitem-level.
+struct SgToWiVectorShapeCast : public OpConversionPattern<vector::ShapeCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ShapeCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!resultLayout || !resultLayout.isForSubgroup())
+      return rewriter.notifyMatchFailure(
+          op, "the result vector of the shape_cast op lacks subgroup layout");
+
+    auto distResultTy = xegpu::getDistVecTypeBasedOnLaneLayout(
+        resultLayout, op.getResultVectorType());
+    if (failed(distResultTy))
+      return rewriter.notifyMatchFailure(
+          op, "failed to get distributed vector type for result");
+
+    // The adaptor supplies the converted source; only the result type changes.
+    auto newShapeCast = vector::ShapeCastOp::create(
+        rewriter, op.getLoc(), distResultTy.value(), adaptor.getSource());
+    rewriter.replaceOp(op, newShapeCast);
+    return success();
+  }
+};
+
 /// Distributes a subgroup-level vector.extract_strided_slice op to
 /// workitem-level. If the result is distributed, the offsets and sizes are
 /// adjusted to match the distributed types.
@@ -968,6 +1059,125 @@ struct SgToWiVectorExtractStridedSlice
   }
 };
 
+/// This pattern distributes a subgroup-level `vector.broadcast` op to
+/// workitem-level. The pattern supports three cases:
+///
+/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input
+/// vector must have a slice layout of the result. If the distributed source
+/// and target vector types are identical, this lowers to a no-op; otherwise,
+/// it remains a broadcast but operates on distributed vectors.
+///
+/// 2) Broadcast a same-rank vector with identical layouts for source and
+/// target: The source vector must have unit dimensions, and lane_data must
+/// be unit size for those unit dims. This always lowers to a no-op.
+///
+/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast
+/// from scalar to distributed result type. A scalar source that carries a
+/// layout is rejected with a match failure.
+///
+/// Example 1 (low-rank to high-rank broadcast):
+/// ```
+/// %0 = "some_op"() {layout_result_0 =
+///   #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+///   dims = [0]>} : () -> vector<16xf16>
+/// %1 = vector.broadcast %0 {layout_result_0 =
+///   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+///   : vector<16xf16> to vector<16x16xf16>
+/// ```
+/// is distributed to:
+/// ```
+/// %0 = "some_op"() : () -> vector<1xf16>
+/// %1 = vector.broadcast %0 : vector<1xf16> to vector<16x1xf16>
+/// ```
+///
+/// Example 2 (same-rank broadcast, no-op):
+/// ```
+/// %0 = "some_op"() {layout_result_0 =
+///   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+///   : () -> vector<16x1xf16>
+/// %1 = vector.broadcast %0 {layout_result_0 =
+///   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+///   : vector<16x1xf16> to vector<16x16xf16>
+/// ```
+/// is distributed to (no-op, source already matches distributed result type):
+/// ```
+/// %0 = "some_op"() : () -> vector<16x1xf16>
+/// // broadcast is eliminated, %0 is used directly
+/// ```
+///
+/// Example 3 (scalar to vector broadcast):
+/// ```
+/// %0 = "some_op"() : () -> f16
+/// %1 = vector.broadcast %0 {layout_result_0 =
+///   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+///   : f16 to vector<16x16xf16>
+/// ```
+/// is distributed to:
+/// ```
+/// %0 = "some_op"() : () -> f16
+/// %1 = vector.broadcast %0 : f16 to vector<16x1xf16>
+/// ```
+struct SgToWiBroadcast : public OpConversionPattern<vector::BroadcastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!resultLayout || !resultLayout.isForSubgroup())
+      return rewriter.notifyMatchFailure(
+          op, "result does not have subgroup distribute layout");
+
+    VectorType destType = op.getResultVectorType();
+    VectorType sourceType = dyn_cast<VectorType>(op.getSourceType());
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getTemporaryLayout(op->getOpOperand(0));
+
+    if (sourceType) {
+      int64_t rankDiff = destType.getRank() - sourceType.getRank();
+      if (rankDiff > 0) {
+        // Case 1: Low-rank to high-rank broadcast.
+        if (!sourceLayout || !sourceLayout.isSliceOf(resultLayout))
+          op.emitWarning(
+              "broadcast source layout must be a slice of result layout");
+      } else if (rankDiff == 0) {
+        // Case 2: Same-rank broadcast. Bail out instead of dereferencing a
+        // null layout (the assert below is compiled out in release builds).
+        if (!sourceLayout)
+          return rewriter.notifyMatchFailure(op, "missing source layout");
+        auto unitDimsSet = op.computeBroadcastedUnitDims();
+        SmallVector<int64_t> unitDims(unitDimsSet.begin(), unitDimsSet.end());
+        assert(sourceLayout.isEqualTo(sourceLayout.setUnitDimData(unitDims)) &&
+               "The sg_data for unit dimensions should be set as 1");
+      }
+    } else {
+      // Case 3: Scalar to vector broadcast.
+      if (sourceLayout)
+        return rewriter.notifyMatchFailure(
+            op, "broadcast from scalar must not have a layout attribute");
+    }
+
+    auto destDistType =
+        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
+    if (failed(destDistType))
+      return rewriter.notifyMatchFailure(
+          op, "failed to distribute the result vector type");
+
+    Value source = adaptor.getSource();
+    // If the adapted source already matches the dest dist type, it's a no-op.
+    if (source.getType() == destDistType.value()) {
+      rewriter.replaceOp(op, source);
+      return success();
+    }
+
+    auto newOp = vector::BroadcastOp::create(rewriter, op.getLoc(),
+                                             destDistType.value(), source);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+};
+
 /// Distributes a subgroup-level vector.insert_strided_slice op to
 /// workitem-level. If the dest is distributed, the offsets are adjusted to
 /// match the distributed types.
@@ -1322,6 +1532,14 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
       [=](vector::MultiDimReductionOp op) -> bool {
         return !isValidSubgroupMultiReductionOp(op);
       });
+  target.addDynamicallyLegalOp(
+      [=](Operation *op) -> bool { // Legal when result 0 has no layout.
+        return !xegpu::getTemporaryLayout(dyn_cast(op->getResult(0)));
+      });
+  target.addDynamicallyLegalOp(
+      [=](vector::BroadcastOp op) -> bool { // Legal when result 0 has no layout.
+        return !xegpu::getTemporaryLayout(op->getResult(0));
+      });
   target.addDynamicallyLegalOp(
       [=](vector::ExtractOp op) -> bool {
         if (!isa(op.getType()))
@@ -1346,6 +1564,7 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
        SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
        SgToWiMultiDimReduction, SgToWiVectorExtract, SgToWiVectorInsert,
        SgToWiVectorExtractStridedSlice, SgToWiVectorInsertStridedSlice,
-       SgToWiLoadMatrix, SgToWiStoreMatrix, SgToWiConvertLayout>(
+       SgToWiLoadMatrix, SgToWiStoreMatrix, SgToWiConvertLayout,
+       SgToWiVectorStep, SgToWiVectorShapeCast, SgToWiBroadcast>(
       typeConverter, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 417d1d121ee5e..9c4f469ea475a 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -748,3 +748,165 @@ gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layou
   gpu.return
   }
 }
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_step_slice
+// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[REM:.*]] = arith.remui %[[LANE_ID]], %[[C16]] : index
+// CHECK: %[[REM2:.*]] = arith.remui %[[REM]], %[[C16]]{{.*}} : index
+// CHECK: %[[VEC:.*]] = vector.from_elements %[[REM2]] : vector<1xindex>
+gpu.func @vector_step_slice() { // 16-element step; one lane-id-derived index per lane expected.
+  %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0, 1, 2]>} : vector<16xindex>
+  gpu.return
+}
+}
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_step_slice_unit
+// CHECK: %[[VEC:.*]] = vector.from_elements %{{.*}} : vector<1xindex>
+gpu.func @vector_step_slice_unit() { // Single-element step; result stays vector<1xindex>.
+  %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0, 1, 3]>} : vector<1xindex>
+  gpu.return
+}
+}
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_step_slice_multi_dist
+// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK: %[[MULI:.*]] = arith.muli %{{.*}}, %{{.*}} : index
+// CHECK: %[[V0:.*]] = arith.remui %[[MULI]], %{{.*}} : index
+// CHECK: %[[SUM1:.*]] = arith.addi %[[MULI]], %{{.*}} : index
+// CHECK: %[[V2:.*]] = arith.remui %[[SUM1]], %{{.*}} : index
+// CHECK: %[[V1:.*]] = arith.addi %[[V0]], %{{.*}} : index
+// CHECK: %[[V3:.*]] = arith.addi %[[V2]], %{{.*}} : index
+// CHECK: %[[VEC:.*]] = vector.from_elements %[[V0]], %[[V1]], %[[V2]], %[[V3]] : vector<4xindex>
+gpu.func @vector_step_slice_multi_dist() { // Four values per lane expected: two block starts, each followed by start + offset.
+  %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0, 2]>} : vector<16xindex>
+  gpu.return
+}
+}
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_shapecast_rank_increasing
+// CHECK: %[[SC:.*]] = vector.shape_cast %{{.*}} : vector<1xf32> to vector<1x1xf32>
+gpu.func @vector_shapecast_rank_increasing() { // 1-D source with a slice layout cast to 2-D.
+  %cst = "some_op"()
+    {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>}
+    : () -> (vector<16xf32>)
+  %cast = vector.shape_cast %cst
+    {
+      layout_result_0 = #xegpu.layout
+    }
+    : vector<16xf32> to vector<1x16xf32>
+  gpu.return
+}
+}
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_shapecast_rank_reducing
+// CHECK: %[[SC:.*]] = vector.shape_cast %{{.*}} : vector<1x1xf32> to vector<1xf32>
+gpu.func @vector_shapecast_rank_reducing() { // 2-D source cast to 1-D with a slice layout on the result.
+  %cst = "some_op"()
+    {layout_result_0 = #xegpu.layout}
+    : () -> (vector<1x16xf32>)
+  %cast = vector.shape_cast %cst
+    {
+      layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>
+    }
+    : vector<1x16xf32> to vector<16xf32>
+  gpu.return
+}
+}
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_shapecast_rank_increasing_without_slicing_layout
+// CHECK: %[[SC:.*]] = vector.shape_cast %{{.*}} : vector<1xf32> to vector<1x1xf32>
+gpu.func @vector_shapecast_rank_increasing_without_slicing_layout() { // Same cast as above, but the source layout is not a slice.
+  %cst = "some_op"()
+    {layout_result_0 = #xegpu.layout}
+    : () -> (vector<16xf32>)
+  %cast = vector.shape_cast %cst
+    {
+      layout_result_0 = #xegpu.layout
+    }
+    : vector<16xf32> to vector<1x16xf32>
+  gpu.return
+}
+}
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_broadcast_1d_to_2d
+// CHECK: %[[SRC:.*]] = "some_op"()
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : vector<16xf16> to vector<1xf16>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[CAST]] : vector<1xf16> to vector<16x1xf16>
+gpu.func @vector_broadcast_1d_to_2d(%laneid: index) { // Low-rank to high-rank; stays a broadcast on distributed types.
+  %0 = "some_op"() {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} : () -> vector<16xf16>
+  %1 = vector.broadcast %0 {layout_result_0 = #xegpu.layout} : vector<16xf16> to vector<16x16xf16>
+  "some_use"(%1) : (vector<16x16xf16>) -> ()
+  gpu.return
+}
+}
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_broadcast_2d_to_3d
+// CHECK: %[[SRC:.*]] = "some_op"()
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : vector<16x16xf16> to vector<16x1xf16>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[CAST]] : vector<16x1xf16> to vector<1x16x1xf16>
+gpu.func @vector_broadcast_2d_to_3d(%laneid: index) { // Adds a leading unit dimension on distributed types.
+  %0 = "some_op"() {layout_result_0 = #xegpu.layout} : () -> vector<16x16xf16>
+  %1 = vector.broadcast %0 {layout_result_0 = #xegpu.layout} : vector<16x16xf16> to vector<1x16x16xf16>
+  "some_use"(%1) : (vector<1x16x16xf16>) -> ()
+  gpu.return
+}
+}
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_broadcast_2d_to_2d_noop
+// CHECK: %[[SRC:.*]] = "some_op"()
+// CHECK-NOT: vector.broadcast
+gpu.func @vector_broadcast_2d_to_2d_noop(%laneid: index) { // Same-rank case; no broadcast may remain in the output.
+  %0 = "some_op"() {layout_result_0 = #xegpu.layout} : () -> vector<16x1xf16>
+  %1 = vector.broadcast %0 {layout_result_0 = #xegpu.layout} : vector<16x1xf16> to vector<16x16xf16>
+  "some_use"(%1) : (vector<16x16xf16>) -> ()
+  gpu.return
+}
+}
+
+// -----
+// Scalar to vector broadcast where the result carries a layout: the result type is distributed.
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_broadcast_scalar_to_vector
+// CHECK: %[[SRC:.*]] = "some_op"()
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SRC]] : f16 to vector<16x1xf16>
+gpu.func @vector_broadcast_scalar_to_vector(%laneid: index) {
+  %0 = "some_op"() : () -> f16
+  %1 = vector.broadcast %0 {layout_result_0 = #xegpu.layout} : f16 to vector<16x16xf16>
+  "some_use"(%1) : (vector<16x16xf16>) -> ()
+  gpu.return
+}
+}
+
+// -----
+// Scalar to vector broadcast with no layout on the result: treated as uniform and expected to remain unchanged.
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_broadcast_scalar_to_vector_uniform
+// CHECK: %[[SRC:.*]] = "some_op"()
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SRC]] : f16 to vector<16x16xf16>
+// CHECK: "some_use"(%[[BCAST]])
+gpu.func @vector_broadcast_scalar_to_vector_uniform(%laneid: index) {
+  %0 = "some_op"() : () -> f16
+  %1 = vector.broadcast %0 : f16 to vector<16x16xf16>
+  "some_use"(%1) : (vector<16x16xf16>) -> ()
+  gpu.return
+}
+}