diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp index ccac78eb6d9dc..981c250249e5f 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp @@ -698,6 +698,173 @@ struct SgToWiLoadMatrix : public OpConversionPattern { } }; +/// Distributes a subgroup-level vector.transpose op to workitem-level. +struct SgToWiVectorTranspose : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::TransposeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + xegpu::DistributeLayoutAttr sourceLayout = + xegpu::getTemporaryLayout(op->getOpOperand(0)); + xegpu::DistributeLayoutAttr resultLayout = + xegpu::getTemporaryLayout(op->getOpResult(0)); + if (!sourceLayout || !resultLayout) + return rewriter.notifyMatchFailure( + op, "the source or result vector of the transpose op lacks layout " + "attribute"); + ArrayRef perm = op.getPermutation(); + // Result layout must be a transpose of source layout. + if (!resultLayout.isTransposeOf(sourceLayout, perm, + xegpu::LayoutKind::Lane)) + return rewriter.notifyMatchFailure( + op, "the source or result vector layouts must be transposes of " + "each other"); + FailureOr distributedResultTypeOrFailure = + getDistVecTypeBasedOnLaneLayout(resultLayout, op.getResultVectorType()); + if (failed(distributedResultTypeOrFailure)) + return rewriter.notifyMatchFailure( + op, "Failed to distribute the result vector type in " + "vector::Transpose op"); + auto newOp = vector::TransposeOp::create(rewriter, op.getLoc(), + adaptor.getVector(), perm); + rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(), + distributedResultTypeOrFailure.value())); + return success(); + } +}; + +/// Distributes a subgroup-level vector.bitcast op to workitem-level. +/// Bitcast only impacts the innermost dimension of the source/result vectors. +struct SgToWiVectorBitcast : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::BitCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + xegpu::DistributeLayoutAttr resultLayout = + xegpu::getTemporaryLayout(op->getOpResult(0)); + if (!resultLayout) + return rewriter.notifyMatchFailure( + op, "result vector of the bitcast op lacks layout attribute"); + FailureOr distributedResultTypeOrFailure = + getDistVecTypeBasedOnLaneLayout(resultLayout, op.getResultVectorType()); + if (failed(distributedResultTypeOrFailure)) + return rewriter.notifyMatchFailure( + op, "Failed to distribute the result vector type in " + "vector::BitCast op"); + auto newOp = vector::BitCastOp::create( + rewriter, op.getLoc(), distributedResultTypeOrFailure.value(), + adaptor.getSource()); + rewriter.replaceOp(op, newOp.getResult()); + return success(); + } +}; + +/// Distributes a subgroup-level vector.create_mask or vector.constant_mask op +/// to workitem-level. Uses `computeDistributedCoords()` to obtain the +/// coordinates each workitem owns, then compares each coordinate against the +/// original mask bounds using `arith.cmpi slt`. The per-element boolean +/// results are assembled into the distributed mask vector. +/// +/// For multi-dimensional masks, the element is in-bounds when ALL dimensions +/// satisfy `coord[i] < bound[i]`. +/// +/// Example (1D): +/// layout = #xegpu.layout +/// %mask = vector.create_mask %m0 : vector<16xi1> +/// For lane k, computeDistributedCoords gives coord = [k], so: +/// %in_bounds = arith.cmpi slt, %coord, %m0 → i1 +/// %mask = vector.broadcast %in_bounds : i1 to vector<1xi1> +/// +/// Example (2D): +/// layout = #xegpu.layout +/// %mask = vector.create_mask %m0, %m1 : vector<8x4xi1> +/// Each WI owns a 1x2 slice. computeDistributedCoords returns 2 coords: +/// [[r0, c0], [r0, c1]] +/// For each coord: in_bounds = (r < m0) && (c < m1) +/// %mask = vector.from_elements %bit0, %bit1 : vector<1x2xi1> +template ::value>> +struct SgToWiCreateMask : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + xegpu::DistributeLayoutAttr layout = + xegpu::getTemporaryLayout(op->getOpResult(0)); + if (!layout || !layout.isForSubgroup()) + return rewriter.notifyMatchFailure( + op, "operation result does not have subgroup distribute layout"); + + VectorType origType = op.getType(); + FailureOr distTypeOrFailure = + getDistVecTypeBasedOnLaneLayout(layout, origType); + if (failed(distTypeOrFailure)) + return rewriter.notifyMatchFailure( + op, "unable to compute workitem vector type from the layout"); + + VectorType distType = distTypeOrFailure.value(); + Location loc = op.getLoc(); + + // Materialize the original mask bounds as Values. + SmallVector origBounds; + if constexpr (std::is_same_v) { + origBounds.append(op.getOperands().begin(), op.getOperands().end()); + } else { + auto dimSizes = op.getMaskDimSizesAttr().asArrayRef(); + for (auto dimSize : dimSizes) + origBounds.push_back( + arith::ConstantIndexOp::create(rewriter, loc, dimSize).getResult()); + } + + ArrayRef origShape = origType.getShape(); + + // Use computeDistributedCoords to get the coordinates each WI owns. + Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(), + /*upperBound=*/mlir::IntegerAttr()); + auto maybeCoordsVec = + layout.computeDistributedCoords(rewriter, loc, laneId, origShape); + if (failed(maybeCoordsVec)) + return rewriter.notifyMatchFailure( + op, "failed to compute distributed coordinates from layout"); + + SmallVector> coordsVec = maybeCoordsVec.value(); + int64_t numElements = distType.getNumElements(); + assert(static_cast(coordsVec.size()) == numElements && + "number of coordinate sets must match number of distributed " + "elements"); + + // For each element, compare all coordinates against bounds. + Value trueVal = + arith::ConstantIntOp::create(rewriter, loc, /*value=*/1, /*width=*/1); + SmallVector maskBits; + for (auto &coords : coordsVec) { + Value inBounds = trueVal; + for (size_t i = 0; i < coords.size(); ++i) { + Value cmp = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::slt, coords[i], origBounds[i]); + inBounds = arith::AndIOp::create(rewriter, loc, inBounds, cmp); + } + maskBits.push_back(inBounds); + } + + // Build the distributed mask vector. + Value result; + if (numElements == 1) { + result = + vector::BroadcastOp::create(rewriter, loc, distType, maskBits[0]); + } else { + result = + vector::FromElementsOp::create(rewriter, loc, distType, maskBits); + } + rewriter.replaceOp(op, result); + return success(); + } +}; + /// This pattern distributes a subgroup-level StoreMatrix op to workitem-level. struct SgToWiStoreMatrix : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -1543,14 +1710,12 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality( [=](vector::MultiDimReductionOp op) -> bool { return !isValidSubgroupMultiReductionOp(op); }); - target.addDynamicallyLegalOp( - [=](Operation *op) -> bool { - return !xegpu::getTemporaryLayout(dyn_cast(op->getResult(0))); - }); - target.addDynamicallyLegalOp( - [=](vector::BroadcastOp op) -> bool { - return !xegpu::getTemporaryLayout(op->getResult(0)); - }); + target.addDynamicallyLegalOp([=](Operation *op) -> bool { + return !xegpu::getTemporaryLayout(op->getOpResult(0)); + }); target.addDynamicallyLegalOp( [=](vector::ExtractOp op) -> bool { if (!isa(op.getType())) @@ -1576,6 +1741,9 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality( SgToWiMultiDimReduction, SgToWiVectorExtract, SgToWiVectorInsert, SgToWiVectorExtractStridedSlice, SgToWiVectorInsertStridedSlice, SgToWiLoadMatrix, SgToWiStoreMatrix, SgToWiConvertLayout, - SgToWiVectorStep, SgToWiVectorShapeCast, SgToWiBroadcast>( - typeConverter, patterns.getContext()); + SgToWiVectorTranspose, SgToWiVectorBitcast, SgToWiVectorStep, + SgToWiVectorShapeCast, SgToWiBroadcast, + SgToWiCreateMask, + SgToWiCreateMask>(typeConverter, + patterns.getContext()); } diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir index 303214e544031..f1c0a5d445059 100644 --- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir +++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir @@ -461,6 +461,114 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index) gpu.return } +// CHECK-LABEL: gpu.func @vector_transpose +// CHECK: %[[SRC:.*]] = "some_op"() +// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : vector<16x2xf32> to vector<1x2xf32> +// CHECK-NEXT: %[[T:.*]] = vector.transpose %[[CAST]], [1, 0] : vector<1x2xf32> to vector<2x1xf32> +// CHECK-NEXT: gpu.return +gpu.func @vector_transpose() { + %cst = "some_op"() + {layout_result_0 = #xegpu.layout} + : () -> (vector<16x2xf32>) + %transpose = vector.transpose %cst, [1, 0] + { + layout_result_0 = #xegpu.layout + } + : vector<16x2xf32> to vector<2x16xf32> + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_bitcast +// CHECK: %[[SRC:.*]] = "some_op"() +// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : vector<4x32xi8> to vector<4x2xi8> +// CHECK-NEXT: %[[BC:.*]] = vector.bitcast %[[CAST]] : vector<4x2xi8> to vector<4x1xi16> +// CHECK-NEXT: gpu.return +gpu.func @vector_bitcast() { + %cst = "some_op"() + {layout_result_0 = #xegpu.layout} + : () -> (vector<4x32xi8>) + %bitcast = vector.bitcast %cst + { + layout_result_0 = #xegpu.layout + } + : vector<4x32xi8> to vector<4x16xi16> + gpu.return +} + +// CHECK-LABEL: gpu.func @create_mask_1d +// CHECK-SAME: (%[[M0:.*]]: index) +// CHECK-DAG: %[[LANE:.*]] = gpu.lane_id +// CHECK-DAG: %[[TRUE:.*]] = arith.constant true +// CHECK: %[[CMP:.*]] = arith.cmpi slt, %{{.*}}, %[[M0]] : index +// CHECK: %[[AND:.*]] = arith.andi %[[TRUE]], %[[CMP]] : i1 +// CHECK: %[[MASK:.*]] = vector.broadcast %[[AND]] : i1 to vector<1xi1> +// CHECK: gpu.return +gpu.func @create_mask_1d(%m0: index) { + %mask = vector.create_mask %m0 + {layout_result_0 = #xegpu.layout} + : vector<16xi1> + gpu.return +} + +// CHECK-LABEL: gpu.func @constant_mask_1d +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[LANE:.*]] = gpu.lane_id +// CHECK-DAG: %[[TRUE:.*]] = arith.constant true +// CHECK: %[[CMP:.*]] = arith.cmpi slt, %{{.*}}, %[[C4]] : index +// CHECK: %[[AND:.*]] = arith.andi %[[TRUE]], %[[CMP]] : i1 +// CHECK: %[[MASK:.*]] = vector.broadcast %[[AND]] : i1 to vector<1xi1> +// CHECK: gpu.return +gpu.func @constant_mask_1d() { + %mask = vector.constant_mask [4] + {layout_result_0 = #xegpu.layout} + : vector<16xi1> + gpu.return +} + +// CHECK-LABEL: gpu.func @create_mask_2d +// CHECK-SAME: (%[[M0:.*]]: index, %[[M1:.*]]: index) +// CHECK-DAG: %[[LANE:.*]] = gpu.lane_id +// CHECK-DAG: %[[TRUE:.*]] = arith.constant true +// CHECK: %[[CMP_R0:.*]] = arith.cmpi slt, %{{.*}}, %[[M0]] : index +// CHECK: %[[AND0:.*]] = arith.andi %[[TRUE]], %[[CMP_R0]] : i1 +// CHECK: %[[CMP_C0:.*]] = arith.cmpi slt, %{{.*}}, %[[M1]] : index +// CHECK: %[[BIT0:.*]] = arith.andi %[[AND0]], %[[CMP_C0]] : i1 +// CHECK: %[[CMP_R1:.*]] = arith.cmpi slt, %{{.*}}, %[[M0]] : index +// CHECK: %[[AND1:.*]] = arith.andi %[[TRUE]], %[[CMP_R1]] : i1 +// CHECK: %[[CMP_C1:.*]] = arith.cmpi slt, %{{.*}}, %[[M1]] : index +// CHECK: %[[BIT1:.*]] = arith.andi %[[AND1]], %[[CMP_C1]] : i1 +// CHECK: %[[MASK:.*]] = vector.from_elements %[[BIT0]], %[[BIT1]] : vector<1x2xi1> +// CHECK: gpu.return +gpu.func @create_mask_2d(%m0: index, %m1: index) { + %mask = vector.create_mask %m0, %m1 + {layout_result_0 = #xegpu.layout} + : vector<8x4xi1> + gpu.return +} + +// CHECK-LABEL: gpu.func @constant_mask_2d +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[LANE:.*]] = gpu.lane_id +// CHECK-DAG: %[[TRUE:.*]] = arith.constant true +// CHECK: %[[CMP_R0:.*]] = arith.cmpi slt, %{{.*}}, %[[C2]] : index +// CHECK: %[[AND0:.*]] = arith.andi %[[TRUE]], %[[CMP_R0]] : i1 +// CHECK: %[[CMP_C0:.*]] = arith.cmpi slt, %{{.*}}, %[[C3]] : index +// CHECK: %[[BIT0:.*]] = arith.andi %[[AND0]], %[[CMP_C0]] : i1 +// CHECK: %[[CMP_R1:.*]] = arith.cmpi slt, %{{.*}}, %[[C2]] : index +// CHECK: %[[AND1:.*]] = arith.andi %[[TRUE]], %[[CMP_R1]] : i1 +// CHECK: %[[CMP_C1:.*]] = arith.cmpi slt, %{{.*}}, %[[C3]] : index +// CHECK: %[[BIT1:.*]] = arith.andi %[[AND1]], %[[CMP_C1]] : i1 +// CHECK: %[[MASK:.*]] = vector.from_elements %[[BIT0]], %[[BIT1]] : vector<1x2xi1> +// CHECK: gpu.return +gpu.func @constant_mask_2d() { + %mask = vector.constant_mask [2, 3] + {layout_result_0 = #xegpu.layout} + : vector<8x4xi1> + gpu.return +} + + // CHECK-LABEL: gpu.func @vector_multi_reduction_3d_leading_unit_dim_lane_local // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x16x2xf32> // CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1x2xf32>