diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 8b5e950733a22..5334470e2e3a0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1100,12 +1100,14 @@ struct WarpOpShapeCast : public WarpDistributionPattern {
   }
 };
 
-/// Sink out vector.create_mask op feeding into a warp op yield.
+/// Sink out vector.create_mask / vector.constant_mask op feeding into a warp op
+/// yield.
 /// ```
 /// %0 = ...
 /// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
 ///   ...
 ///   %mask = vector.create_mask %0 : vector<32xi1>
+///   // or %mask = vector.constant_mask [2] : vector<32xi1>
 ///   gpu.yield %mask : vector<32xi1>
 /// }
 /// ```
@@ -1118,31 +1120,45 @@ struct WarpOpShapeCast : public WarpDistributionPattern {
 /// %cmp = arith.cmpi ult, %laneid, %0
 /// %ub = arith.select %cmp, %c0, %c1
 /// %1 = vector.create_mask %ub : vector<1xi1>
+template <typename MaskOpTy,
+          typename = std::enable_if_t<llvm::is_one_of<
+              MaskOpTy, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
 struct WarpOpCreateMask : public WarpDistributionPattern {
   using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *yieldOperand =
-        getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
+    OpOperand *yieldOperand = getWarpResult(warpOp, llvm::IsaPred<MaskOpTy>);
     if (!yieldOperand)
       return failure();
 
-    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
+    Operation *mask = yieldOperand->get().getDefiningOp();
 
     // Early exit if any values needed for calculating the new mask indices
     // are defined inside the warp op.
-    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
+    if (mask->getOperands().size() &&
+        !llvm::all_of(mask->getOperands(), [&](Value value) {
           return warpOp.isDefinedOutsideOfRegion(value);
         }))
       return failure();
 
-    Location loc = mask.getLoc();
+    Location loc = mask->getLoc();
     unsigned operandIndex = yieldOperand->getOperandNumber();
 
     auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
-    VectorType seqType = mask.getVectorType();
+    VectorType seqType = cast<VectorType>(mask->getResult(0).getType());
     ArrayRef<int64_t> seqShape = seqType.getShape();
     ArrayRef<int64_t> distShape = distType.getShape();
 
+    SmallVector<Value> materializedOperands;
+    if constexpr (std::is_same_v<MaskOpTy, vector::CreateMaskOp>) {
+      materializedOperands.append(mask->getOperands().begin(),
+                                  mask->getOperands().end());
+    } else {
+      auto constantMaskOp = cast<vector::ConstantMaskOp>(mask);
+      auto dimSizes = constantMaskOp.getMaskDimSizesAttr().asArrayRef();
+      for (auto dimSize : dimSizes)
+        materializedOperands.push_back(
+            arith::ConstantIndexOp::create(rewriter, loc, dimSize).getResult());
+    }
+
     rewriter.setInsertionPointAfter(warpOp);
@@ -1170,7 +1186,7 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
       // mask sizes are always in the range [0, mask_vector_size[i]).
       Value maskDimIdx = affine::makeComposedAffineApply(
           rewriter, loc, s1 - s0 * distShape[i],
-          {delinearizedIds[i], mask.getOperand(i)});
+          {delinearizedIds[i], materializedOperands[i]});
       newOperands.push_back(maskDimIdx);
     }
 
@@ -2282,12 +2298,13 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
     PatternBenefit readBenefit) {
   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
-  patterns
-      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
-           WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
-           WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
-           WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
-          patterns.getContext(), benefit);
+  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
+               WarpOpConstant, WarpOpInsertScalar, WarpOpInsert,
+               WarpOpCreateMask<vector::CreateMaskOp>,
+               WarpOpCreateMask<vector::ConstantMaskOp>,
+               WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
+      patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 0cf6dd151e16c..135db02d543ef 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1779,6 +1779,21 @@ func.func @warp_propagate_unconnected_read_write(%laneid: index, %buffer: memref
 // CHECK-DIST-AND-PROP:   %[[CST:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32>
 // CHECK-DIST-AND-PROP:   vector.transfer_read {{.*}} : memref<128xf32>, vector<4xf32>
 // CHECK-DIST-AND-PROP:   vector.transfer_write %[[CST]], {{.*}} : vector<4xf32>, memref<128xf32>
+
+// -----
+
+func.func @warp_propagate_constant_mask(%laneid: index) -> vector<1xi1> {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xi1>) {
+    %1 = vector.constant_mask [1] : vector<32xi1>
+    gpu.yield %1 : vector<32xi1>
+  }
+  return %r : vector<1xi1>
+}
+
+// CHECK-PROP-DAG: #[[$SUB:.*]] = affine_map<()[s0] -> (-s0 + 1)>
+// CHECK-PROP-LABEL: func @warp_propagate_constant_mask
+// CHECK-PROP-SAME: %[[LANEID:.+]]: index
+// CHECK-PROP: %[[MDIST:.+]] = affine.apply #[[$SUB]]()[%[[LANEID]]]
+// CHECK-PROP: vector.create_mask %[[MDIST]] : vector<1xi1>
 
 // -----
 
@@ -1813,6 +1828,24 @@ func.func @warp_propagate_multi_dim_create_mask(%laneid: index, %m0: index, %m1:
 // CHECK-PROP: %[[DISTM0:.+]] = affine.apply #[[$SUBM0]]()[%[[M0]], %[[LANEID]]]
 // CHECK-PROP: %[[DISTM1:.+]] = affine.apply #[[$SUBM1]]()[%[[M1]], %[[LANEID]]]
 // CHECK-PROP: vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[M2]] : vector<1x2x4xi1>
+
+// -----
+
+func.func @warp_propagate_multi_dim_constant_mask(%laneid: index) -> vector<1x2x4xi1> {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2x4xi1>) {
+    %1 = vector.constant_mask [1, 1, 2] : vector<16x4x4xi1>
+    gpu.yield %1 : vector<16x4x4xi1>
+  }
+  return %r : vector<1x2x4xi1>
+}
+
+// CHECK-PROP-DAG: #[[$SUBM0:.*]] = affine_map<()[s0] -> (-(s0 floordiv 2) + 1)>
+// CHECK-PROP-DAG: #[[$SUBM1:.*]] = affine_map<()[s0] -> (s0 * -2 + (s0 floordiv 2) * 4 + 1)>
+// CHECK-PROP-LABEL: func @warp_propagate_multi_dim_constant_mask
+// CHECK-PROP-SAME: %[[LANEID:.+]]: index
+// CHECK-PROP: %[[CST2:.+]] = arith.constant 2 : index
+// CHECK-PROP: %[[DISTM0:.+]] = affine.apply #[[$SUBM0]]()[%[[LANEID]]]
+// CHECK-PROP: %[[DISTM1:.+]] = affine.apply #[[$SUBM1]]()[%[[LANEID]]]
+// CHECK-PROP: vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[CST2]] : vector<1x2x4xi1>
 
 // -----
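
For reference, here is a sketch of the distributed IR that the new single-dimension `warp_propagate_constant_mask` test expects, reconstructed from its CHECK-PROP lines. The `#map` identifier and SSA value names are illustrative, not part of the patch, and the output assumes the file's usual `-test-vector-warp-distribute=propagate-distribution -canonicalize` RUN pipeline:

```mlir
// Sketch of post-propagation IR, reconstructed from the CHECK-PROP lines.
// The vector<32xi1> constant mask with one leading active element becomes a
// per-lane vector.create_mask on a vector<1xi1> slice: lane 0 computes the
// bound 1 - 0 = 1 (its element is on); every other lane computes a
// non-positive bound, which vector.create_mask clamps to 0 (element off).
#map = affine_map<()[s0] -> (-s0 + 1)>
func.func @warp_propagate_constant_mask(%laneid: index) -> vector<1xi1> {
  %bound = affine.apply #map()[%laneid]
  %mask = vector.create_mask %bound : vector<1xi1>
  return %mask : vector<1xi1>
}
```

The multi-dimensional test follows the same recipe; the only difference is that the trailing, undistributed mask size is first materialized as an `arith.constant 2 : index` (the `%[[CST2]]` capture) before being passed to `vector.create_mask`.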