From 4b202defcaa9dac128041736afa8b0c9481c3bfb Mon Sep 17 00:00:00 2001
From: Artem Kroviakov
Date: Fri, 5 Sep 2025 14:48:43 +0000
Subject: [PATCH 1/5] [MLIR][Vector] Add warp distribution for `scf.if`

---
 .../Vector/Transforms/VectorDistribute.cpp | 201 ++++++++++++++++++
 .../Vector/vector-warp-distribute.mlir     |  69 ++++++
 2 files changed, 270 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index c84eb2c9f8857..cf5928278aa64 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1713,6 +1713,205 @@ struct WarpOpInsert : public WarpDistributionPattern {
   }
 };
 
+struct WarpOpScfIfOp : public WarpDistributionPattern {
+  WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    gpu::YieldOp warpOpYield = warpOp.getTerminator();
+    // Only pick up `IfOp` if it is the last op in the region.
+    Operation *lastNode = warpOpYield->getPrevNode();
+    auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
+    if (!ifOp)
+      return failure();
+
+    // The current `WarpOp` can yield two types of values:
+    // 1. Values that are not results of the `IfOp`:
+    //    Preserve them in the new `WarpOp`.
+    //    Collect their yield index.
+    // 2. Values that are results of the `IfOp`:
+    //    They are not part of the new `WarpOp` results.
+    //    Map the current warp's yield operand index to the `IfOp` result idx.
+    SmallVector<Value> nonIfYieldValues;
+    SmallVector<unsigned> nonIfYieldIndices;
+    llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
+    llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
+    for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
+      const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
+      if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
+        nonIfYieldValues.push_back(yieldOperand.get());
+        nonIfYieldIndices.push_back(yieldOperandIdx);
+        continue;
+      }
+      OpResult ifResult = cast<OpResult>(yieldOperand.get());
+      const unsigned ifResultIdx = ifResult.getResultNumber();
+      ifResultMapping[yieldOperandIdx] = ifResultIdx;
+      // If this `ifOp` result is a vector type and it is yielded by the
+      // `WarpOp`, we keep track of the distributed type for this result.
+      if (!isa<VectorType>(ifResult.getType()))
+        continue;
+      VectorType distType =
+          cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
+      ifResultDistTypes[ifResultIdx] = distType;
+    }
+
+    // Collect `WarpOp`-defined values used in `ifOp`; the new warp op must
+    // return them.
+    auto getEscapingValues = [&](Region &branch,
+                                 llvm::SmallSetVector<Value, 32> &values,
+                                 SmallVector<Type> &inputTypes,
+                                 SmallVector<Type> &distTypes) {
+      if (branch.empty())
+        return;
+      mlir::visitUsedValuesDefinedAbove(branch, [&](OpOperand *operand) {
+        Operation *parent = operand->get().getParentRegion()->getParentOp();
+        if (warpOp->isAncestor(parent)) {
+          if (!values.insert(operand->get()))
+            return;
+          Type distType = operand->get().getType();
+          if (auto vecType = dyn_cast<VectorType>(distType)) {
+            AffineMap map = distributionMapFn(operand->get());
+            distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+          }
+          inputTypes.push_back(operand->get().getType());
+          distTypes.push_back(distType);
+        }
+      });
+    };
+    llvm::SmallSetVector<Value, 32> escapingValuesThen;
+    SmallVector<Type> escapingValueInputTypesThen; // inner warp op block args
+    SmallVector<Type> escapingValueDistTypesThen;  // new warp returns
+    getEscapingValues(ifOp.getThenRegion(), escapingValuesThen,
+                      escapingValueInputTypesThen, escapingValueDistTypesThen);
+    llvm::SmallSetVector<Value, 32> escapingValuesElse;
+    SmallVector<Type> escapingValueInputTypesElse; // inner warp op block args
+    SmallVector<Type> escapingValueDistTypesElse;  // new warp returns
+    getEscapingValues(ifOp.getElseRegion(), escapingValuesElse,
+                      escapingValueInputTypesElse, escapingValueDistTypesElse);
+
+    if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
+        llvm::is_contained(escapingValueDistTypesElse, Type{}))
+      return failure();
+
+    // The new `WarpOp` yields values grouped in the following order:
+    // 1. Escaping values from the then branch.
+    // 2. Escaping values from the else branch.
+    // 3. All non-`ifOp` yielded values.
+    SmallVector<Value> newWarpOpYieldValues{escapingValuesThen.begin(),
+                                            escapingValuesThen.end()};
+    newWarpOpYieldValues.append(escapingValuesElse.begin(),
+                                escapingValuesElse.end());
+    SmallVector<Type> newWarpOpDistTypes = escapingValueDistTypesThen;
+    newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
+                              escapingValueDistTypesElse.end());
+
+    llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
+    for (auto [idx, val] :
+         llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
+      origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
+      newWarpOpYieldValues.push_back(val);
+      newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
+    }
+    // Create the new `WarpOp` with the updated yield values and types.
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+
+    // `ifOp` returns the result of the inner warp op.
+    SmallVector<Type> newIfOpDistResTypes;
+    for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
+      Type distType = cast<Value>(res).getType();
+      if (auto vecType = dyn_cast<VectorType>(distType)) {
+        AffineMap map = distributionMapFn(cast<Value>(res));
+        distType = ifResultDistTypes.count(i)
+                       ? ifResultDistTypes[i]
+                       : getDistributedType(vecType, map, warpOp.getWarpSize());
+      }
+      newIfOpDistResTypes.push_back(distType);
+    }
+    // Create a new `IfOp` outside the new `WarpOp` region.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto newIfOp = scf::IfOp::create(rewriter, ifOp.getLoc(),
+                                     newIfOpDistResTypes, ifOp.getCondition(),
+                                     static_cast<bool>(ifOp.thenBlock()),
+                                     static_cast<bool>(ifOp.elseBlock()));
+
+    auto processBranch = [&](Block *oldIfBranch, Block *newIfBranch,
+                             llvm::SmallSetVector<Value, 32> &escapingValues,
+                             SmallVector<Type> &escapingValueInputTypes) {
+      OpBuilder::InsertionGuard g(rewriter);
+      if (!newIfBranch)
+        return;
+      rewriter.setInsertionPointToStart(newIfBranch);
+      llvm::SmallDenseMap<Value, size_t> escapeValToBlockArgIndex;
+      SmallVector<Value> innerWarpInputVals;
+      SmallVector<Type> innerWarpInputTypes;
+      for (size_t i = 0; i < escapingValues.size(); ++i) {
+        innerWarpInputVals.push_back(newWarpOp.getResult(i));
+        escapeValToBlockArgIndex[escapingValues[i]] =
+            innerWarpInputTypes.size();
+        innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
+      }
+      auto innerWarp = WarpExecuteOnLane0Op::create(
+          rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
+          newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInputVals,
+          innerWarpInputTypes);
+
+      innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent());
+      innerWarp.getWarpRegion().addArguments(
+          innerWarpInputTypes,
+          SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));
+
+      SmallVector<Value> yieldOperands;
+      for (Value operand : oldIfBranch->getTerminator()->getOperands())
+        yieldOperands.push_back(operand);
+      rewriter.eraseOp(oldIfBranch->getTerminator());
+
+      rewriter.setInsertionPointToEnd(innerWarp.getBody());
+      gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
+      rewriter.setInsertionPointAfter(innerWarp);
+      scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
+
+      // Update any users of escaping values that were forwarded to the
+      // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
+      innerWarp.walk([&](Operation *op) {
+        for (OpOperand &operand : op->getOpOperands()) {
+          auto it = escapeValToBlockArgIndex.find(operand.get());
+          if (it == escapeValToBlockArgIndex.end())
+            continue;
+          operand.set(innerWarp.getBodyRegion().getArgument(it->second));
+        }
+      });
+      mlir::vector::moveScalarUniformCode(innerWarp);
+    };
+    processBranch(&ifOp.getThenRegion().front(),
+                  &newIfOp.getThenRegion().front(), escapingValuesThen,
+                  escapingValueInputTypesThen);
+    if (!ifOp.getElseRegion().empty())
+      processBranch(&ifOp.getElseRegion().front(),
+                    &newIfOp.getElseRegion().front(), escapingValuesElse,
+                    escapingValueInputTypesElse);
+    // Update the users of the original `WarpOp` results that were yielded
+    // from the `IfOp` to use the new `IfOp` results.
+    for (auto [origIdx, newIdx] : ifResultMapping)
+      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+                                    newIfOp.getResult(newIdx), newIfOp);
+    // Similarly, update any users of the `WarpOp` results that were not
+    // results of the `IfOp`.
+    for (auto [origIdx, newIdx] : origToNewYieldIdx)
+      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
+                                  newWarpOp.getResult(newIdx));
+    // Remove the original `WarpOp` and `IfOp`; they should not have any uses
+    // at this point.
+    rewriter.eraseOp(ifOp);
+    rewriter.eraseOp(warpOp);
+    return success();
+  }
+
+private:
+  DistributionMapFn distributionMapFn;
+};
+
 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
 /// the scf.ForOp is the last operation in the region so that it doesn't
 /// change the order of execution. This creates a new scf.for region after the
@@ -2068,6 +2267,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
                                benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
                                benefit);
+  patterns.add<WarpOpScfIfOp>(patterns.getContext(), distributionMapFn,
+                              benefit);
 }
 
 void mlir::vector::populateDistributeReduction(
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 8750582ef1e1f..bb7639204022f 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1856,3 +1856,72 @@ func.func @negative_warp_step_more_than_warp_size(%laneid: index, %buffer: memre
 // CHECK-PROP-LABEL: @negative_warp_step_more_than_warp_size
 // CHECK-PROP-NOT: vector.broadcast
 // CHECK-PROP: vector.step : vector<64xindex>
+
+// -----
+
+func.func @warp_scf_if_no_yield_distribute(%buffer: memref<128xindex>, %pred : i1) {
+  %laneid = gpu.lane_id
+  %c0 = arith.constant 0 : index
+
+  gpu.warp_execute_on_lane_0(%laneid)[32] {
+    %seq = vector.step : vector<32xindex>
+    scf.if %pred {
+      vector.store %seq, %buffer[%c0] : memref<128xindex>, vector<32xindex>
+    }
+    gpu.yield
+  }
+  return
+}
+
+// CHECK-PROP-LABEL: func.func @warp_scf_if_no_yield_distribute(
+// CHECK-PROP-SAME: %[[ARG0:.+]]: memref<128xindex>, %[[ARG1:.+]]: i1
+// CHECK-PROP: scf.if %[[ARG1]] {
+// CHECK-PROP:   gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<1xindex>) {
+// CHECK-PROP:   ^bb0(%[[ARG2:.+]]: vector<32xindex>):
+// CHECK-PROP:     vector.store %[[ARG2]], %[[ARG0]][%{{.*}}] : memref<128xindex>, vector<32xindex>
+
+// -----
+
+func.func @warp_scf_if_distribute(%pred : i1) {
+  %laneid = gpu.lane_id
+  %c0 = arith.constant 0 : index
+
+  %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> vector<1xf32> {
+    %seq1 = vector.step : vector<32xindex>
+    %seq2 = arith.constant dense<2> : vector<32xindex>
+    %0 = scf.if %pred -> (vector<32xf32>) {
+      %1 = "some_op"(%seq1) : (vector<32xindex>) -> (vector<32xf32>)
+      scf.yield %1 : vector<32xf32>
+    } else {
+      %2 = "other_op"(%seq2) : (vector<32xindex>) -> (vector<32xf32>)
+      scf.yield %2 : vector<32xf32>
+    }
+    gpu.yield %0 : vector<32xf32>
+  }
+  "some_use"(%0) : (vector<1xf32>) -> ()
+
+  return
+}
+
+// CHECK-PROP-LABEL: func.func @warp_scf_if_distribute(
+// CHECK-PROP-SAME: %[[ARG0:.+]]: i1
+// CHECK-PROP: %[[SEQ2:.+]] = arith.constant dense<2> : vector<32xindex>
+// CHECK-PROP: %[[LANE_ID:.+]] = gpu.lane_id
+// CHECK-PROP: %[[SEQ1:.+]] = vector.broadcast %[[LANE_ID]] : index to vector<1xindex>
+// CHECK-PROP: %[[IF_YIELD_DIST:.+]] = scf.if %[[ARG0]] -> (vector<1xf32>) {
+// CHECK-PROP:   %[[THEN_DIST:.+]] = gpu.warp_execute_on_lane_0(%[[LANE_ID]])[32] args(%[[SEQ1]] : vector<1xindex>) -> (vector<1xf32>) {
+// CHECK-PROP:   ^bb0(%[[ARG1:.+]]: vector<32xindex>):
+// CHECK-PROP:     %{{.*}} = "some_op"(%[[ARG1]]) : (vector<32xindex>) -> vector<32xf32>
+// CHECK-PROP:     gpu.yield %{{.*}} : vector<32xf32>
+// CHECK-PROP:   }
+// CHECK-PROP:   scf.yield %[[THEN_DIST]] : vector<1xf32>
+// CHECK-PROP: } else {
+// CHECK-PROP:   %[[ELSE_DIST:.+]] = gpu.warp_execute_on_lane_0(%[[LANE_ID]])[32] -> (vector<1xf32>) {
+// CHECK-PROP:     %{{.*}} = "other_op"(%[[SEQ2]]) : (vector<32xindex>) -> vector<32xf32>
+// CHECK-PROP:     gpu.yield %{{.*}} : vector<32xf32>
+// CHECK-PROP:   }
+// CHECK-PROP:   scf.yield %[[ELSE_DIST]] : vector<1xf32>
+// CHECK-PROP: }
+// CHECK-PROP: "some_use"(%[[IF_YIELD_DIST]]) : (vector<1xf32>) -> ()
+// CHECK-PROP: return
+// CHECK-PROP: }
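A note for anyone replaying these tests locally: the `CHECK-PROP` prefix is bound to one of the RUN lines at the top of `vector-warp-distribute.mlir`, which this hunk does not show. The sketch below is my guess at its shape, based on the existing `-test-vector-warp-distribute` test pass; the exact flag set and option name (`propagate-distribution`) should be verified against the actual file header.

```mlir
// Hypothetical RUN line for the CHECK-PROP prefix (verify against the file):
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file \
// RUN:   -test-vector-warp-distribute=propagate-distribution -canonicalize \
// RUN:   | FileCheck %s --check-prefixes=CHECK-PROP
```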
From b356d1119d1053a19e1145e0ff135750009f4cce Mon Sep 17 00:00:00 2001
From: Artem Kroviakov
Date: Fri, 5 Sep 2025 16:55:45 +0000
Subject: [PATCH 2/5] Add xegpu tests

---
 .../Dialect/XeGPU/subgroup-distribute.mlir | 58 +++++++++++++++++++
 1 file changed, 58 insertions(+)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index a39aa90bbe3a8..b57903b2eb69b 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -338,6 +338,64 @@ gpu.module @test {
   }
 }
 
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops_scf_yield({{.*}}) {
+// CHECK: %[[DEFAULT:.*]] = arith.constant dense<1.200000e+01> : vector<8xf16>
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[PREDICATE:.*]] = llvm.mlir.poison : i1
+// CHECK: %[[PREDICATED_LOAD:.*]] = scf.if %[[PREDICATE]] -> (vector<8xf16>) {
+// CHECK-NEXT:   %[[LOADED:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK-NEXT:   scf.yield %[[LOADED]] : vector<8xf16>
+// CHECK-NEXT: } else {
+// CHECK-NEXT:   scf.yield %[[DEFAULT]] : vector<8xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: xegpu.store %[[PREDICATED_LOAD]], %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.module @test {
+  gpu.func @scatter_ops_scf_yield(%src: memref<256xf16>) {
+    %pred = llvm.mlir.poison : i1
+    %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1> : vector<16xi1>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+    %loaded = scf.if %pred -> (vector<16x8xf16>) {
+      %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
+        layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+      } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+      scf.yield %3 : vector<16x8xf16>
+    } else {
+      %3 = arith.constant {
+        layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+      } dense<12.> : vector<16x8xf16>
+      scf.yield %3 : vector<16x8xf16>
+    } { layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]> }
+    xegpu.store %loaded, %src[%offset], %1 <{chunk_size=8}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops_scf_non_yield({{.*}}) {
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[PREDICATE:.*]] = llvm.mlir.poison : i1
+// CHECK: scf.if %[[PREDICATE]] {
+// CHECK-NEXT:   %[[LOADED:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK-NEXT:   xegpu.store %[[LOADED]], %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+// CHECK-NEXT: }
+gpu.module @test {
+  gpu.func @scatter_ops_scf_non_yield(%src: memref<256xf16>) {
+    %pred = llvm.mlir.poison : i1
+    %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1> : vector<16xi1>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+    scf.if %pred {
+      %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
+        layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+      } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+      xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+    }
+    gpu.return
+  }
+}
+
 // -----
 // CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
 // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
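The per-lane shapes in the two new tests follow from the `#xegpu.layout` attributes: with a 16-lane subgroup, dimension 0 of each annotated value is split across the lanes. A summary of the type mapping the CHECK lines encode, assuming the `lane_layout = [16]` / `[16, 1]` layouts written above:

```mlir
// Subgroup-level value        per-lane value after distribution
// vector<16x8xf16>        ->  vector<8xf16>    // one chunk_size=8 slice per lane
// vector<16xindex>        ->  vector<1xindex>  // one scatter offset per lane
// vector<16xi1>           ->  vector<1xi1>     // one mask bit per lane
```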
From 7d31574c675ddef0b2f4af310954e822cdadf1ee Mon Sep 17 00:00:00 2001
From: Artem Kroviakov
Date: Mon, 8 Sep 2025 08:38:31 +0000
Subject: [PATCH 3/5] Yield if condition, range-based escaping values for inner warps

---
 .../Vector/Transforms/VectorDistribute.cpp | 36 ++++++++++---------
 .../Dialect/XeGPU/subgroup-distribute.mlir |  7 ++--
 2 files changed, 23 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index cf5928278aa64..db3e9e6922a44 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1794,14 +1794,18 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
       return failure();
 
     // The new `WarpOp` yields values grouped in the following order:
-    // 1. Escaping values from the then branch.
-    // 2. Escaping values from the else branch.
-    // 3. All non-`ifOp` yielded values.
-    SmallVector<Value> newWarpOpYieldValues{escapingValuesThen.begin(),
-                                            escapingValuesThen.end()};
+    // 1. The branch condition.
+    // 2. Escaping values from the then branch.
+    // 3. Escaping values from the else branch.
+    // 4. All non-`ifOp` yielded values.
+    SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
+    newWarpOpYieldValues.append(escapingValuesThen.begin(),
+                                escapingValuesThen.end());
     newWarpOpYieldValues.append(escapingValuesElse.begin(),
                                 escapingValuesElse.end());
-    SmallVector<Type> newWarpOpDistTypes = escapingValueDistTypesThen;
+    SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
+    newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
+                              escapingValueDistTypesThen.end());
     newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
                               escapingValueDistTypesElse.end());
 
@@ -1815,7 +1819,6 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
     // Create the new `WarpOp` with the updated yield values and types.
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
         rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
-
     // `ifOp` returns the result of the inner warp op.
     SmallVector<Type> newIfOpDistResTypes;
     for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
@@ -1831,14 +1834,15 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
     // Create a new `IfOp` outside the new `WarpOp` region.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newWarpOp);
-    auto newIfOp = scf::IfOp::create(rewriter, ifOp.getLoc(),
-                                     newIfOpDistResTypes, ifOp.getCondition(),
-                                     static_cast<bool>(ifOp.thenBlock()),
-                                     static_cast<bool>(ifOp.elseBlock()));
+    auto newIfOp = scf::IfOp::create(
+        rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
+        static_cast<bool>(ifOp.thenBlock()),
+        static_cast<bool>(ifOp.elseBlock()));
 
     auto processBranch = [&](Block *oldIfBranch, Block *newIfBranch,
                              llvm::SmallSetVector<Value, 32> &escapingValues,
-                             SmallVector<Type> &escapingValueInputTypes) {
+                             SmallVector<Type> &escapingValueInputTypes,
+                             size_t warpResRangeStart) {
       OpBuilder::InsertionGuard g(rewriter);
       if (!newIfBranch)
         return;
@@ -1846,8 +1850,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
       llvm::SmallDenseMap<Value, size_t> escapeValToBlockArgIndex;
       SmallVector<Value> innerWarpInputVals;
       SmallVector<Type> innerWarpInputTypes;
-      for (size_t i = 0; i < escapingValues.size(); ++i) {
-        innerWarpInputVals.push_back(newWarpOp.getResult(i));
+      for (size_t i = 0; i < escapingValues.size(); ++i, ++warpResRangeStart) {
+        innerWarpInputVals.push_back(newWarpOp.getResult(warpResRangeStart));
         escapeValToBlockArgIndex[escapingValues[i]] =
             innerWarpInputTypes.size();
         innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
@@ -1886,11 +1890,11 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
     };
     processBranch(&ifOp.getThenRegion().front(),
                   &newIfOp.getThenRegion().front(), escapingValuesThen,
-                  escapingValueInputTypesThen);
+                  escapingValueInputTypesThen, 1);
     if (!ifOp.getElseRegion().empty())
       processBranch(&ifOp.getElseRegion().front(),
                     &newIfOp.getElseRegion().front(), escapingValuesElse,
-                    escapingValueInputTypesElse);
+                    escapingValueInputTypesElse, 1 + escapingValuesThen.size());
     // Update the users of the original `WarpOp` results that were yielded
     // from the `IfOp` to use the new `IfOp` results.
     for (auto [origIdx, newIdx] : ifResultMapping)
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index b57903b2eb69b..60acea06c9a12 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -339,11 +339,11 @@ gpu.module @test {
 }
 
 // -----
-// CHECK-LABEL: gpu.func @scatter_ops_scf_yield({{.*}}) {
+// CHECK-LABEL: gpu.func @scatter_ops_scf_yield({{.*}},
+// CHECK-SAME:  %[[PREDICATE:.*]]: i1) {
 // CHECK: %[[DEFAULT:.*]] = arith.constant dense<1.200000e+01> : vector<8xf16>
 // CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
 // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
-// CHECK: %[[PREDICATE:.*]] = llvm.mlir.poison : i1
 // CHECK: %[[PREDICATED_LOAD:.*]] = scf.if %[[PREDICATE]] -> (vector<8xf16>) {
 // CHECK-NEXT:   %[[LOADED:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
 // CHECK-NEXT:   scf.yield %[[LOADED]] : vector<8xf16>
@@ -352,8 +352,7 @@ gpu.module @test {
-  gpu.func @scatter_ops_scf_yield(%src: memref<256xf16>) {
-    %pred = llvm.mlir.poison : i1
+  gpu.func @scatter_ops_scf_yield(%src: memref<256xf16>, %pred : i1) {
     %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1> : vector<16xi1>
     %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
     %loaded = scf.if %pred -> (vector<16x8xf16>) {
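To make this patch's first change concrete: the branch condition is now yielded by the new `WarpOp` as result #0 instead of being captured directly, which also covers conditions defined inside the warp region. A minimal hand-written sketch of the resulting IR, using unregistered placeholder ops in the style of the tests (not taken from the patch itself):

```mlir
%r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (i1, vector<1xf32>) {
  %pred = "make_pred"() : () -> i1              // condition, yielded first
  %v = "make_payload"() : () -> vector<32xf32>  // escaping value
  gpu.yield %pred, %v : i1, vector<32xf32>
}
scf.if %r#0 {                                   // consumes warp result #0
  gpu.warp_execute_on_lane_0(%laneid)[32] args(%r#1 : vector<1xf32>) {
  ^bb0(%arg0: vector<32xf32>):                  // escaping value re-enters here
    "use"(%arg0) : (vector<32xf32>) -> ()
  }
}
```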
From 784dda109ca2c145b1b9f92f87547c54af43dcc7 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov
Date: Tue, 9 Sep 2025 10:16:03 +0000
Subject: [PATCH 4/5] Address feedback

---
 .../Vector/Transforms/VectorDistribute.cpp | 214 ++++++++++--------
 1 file changed, 125 insertions(+), 89 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index db3e9e6922a44..3ae866aeb2888 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -371,6 +371,36 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
   return targetType;
 }
 
+/// Given a warpOp that contains ops with regions, the corresponding op's
+/// "inner" region and the distributionMapFn, get all values used by the op's
+/// region that are defined within the warpOp. Return the set of values, their
+/// types and their distributed types.
+std::tuple<llvm::SmallSetVector<Value, 32>, SmallVector<Type>,
+           SmallVector<Type>>
+getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
+                             DistributionMapFn distributionMapFn) {
+  llvm::SmallSetVector<Value, 32> escapingValues;
+  SmallVector<Type> escapingValueTypes;
+  SmallVector<Type> escapingValueDistTypes; // to yield from the new warpOp
+  if (innerRegion.empty())
+    return {escapingValues, escapingValueTypes, escapingValueDistTypes};
+  mlir::visitUsedValuesDefinedAbove(innerRegion, [&](OpOperand *operand) {
+    Operation *parent = operand->get().getParentRegion()->getParentOp();
+    if (warpOp->isAncestor(parent)) {
+      if (!escapingValues.insert(operand->get()))
+        return;
+      Type distType = operand->get().getType();
+      if (auto vecType = dyn_cast<VectorType>(distType)) {
+        AffineMap map = distributionMapFn(operand->get());
+        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+      }
+      escapingValueTypes.push_back(operand->get().getType());
+      escapingValueDistTypes.push_back(distType);
+    }
+  });
+  return {escapingValues, escapingValueTypes, escapingValueDistTypes};
+}
+
 /// Distribute transfer_write ops based on the affine map returned by
 /// `distributionMapFn`. Writes of size more than `maxNumElementToExtract`
 /// will not be distributed (it should be less than the warp size).
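To make the helper's type computation concrete, here is a small worked table (my own example; it assumes `distributionMapFn` returns the usual identity map over the distributed dimension). Vector-typed escaping values are shrunk along the mapped dimension by the warp size, while non-vector values keep their type:

```mlir
// Escaping value type   distribution map          warp size   distributed type
// vector<32xf32>        affine_map<(d0) -> (d0)>  32          vector<1xf32>
// vector<64xindex>      affine_map<(d0) -> (d0)>  32          vector<2xindex>
// index (non-vector)    (not consulted)           32          index (unchanged)
```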
@@ -1713,6 +1743,32 @@ struct WarpOpInsert : public WarpDistributionPattern {
   }
 };
 
+/// Sink scf.if out of WarpExecuteOnLane0Op. This can be done only if
+/// the scf.if is the last operation in the region so that it doesn't
+/// change the order of execution. This creates a new scf.if after the
+/// WarpExecuteOnLane0Op. Each branch of the new scf.if is enclosed in
+/// an "inner" WarpExecuteOnLane0Op. Example:
+/// ```
+/// gpu.warp_execute_on_lane_0(%laneid)[32] {
+///   %payload = ... : vector<32xindex>
+///   scf.if %pred {
+///     vector.store %payload, %buffer[%idx] : memref<128xindex>, vector<32xindex>
+///   }
+///   gpu.yield
+/// }
+/// ```
+/// is rewritten to:
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] {
+///   %payload = ... : vector<32xindex>
+///   gpu.yield %payload : vector<32xindex>
+/// }
+/// scf.if %pred {
+///   gpu.warp_execute_on_lane_0(%laneid)[32] args(%r : vector<1xindex>) {
+///   ^bb0(%arg1: vector<32xindex>):
+///     vector.store %arg1, %buffer[%idx] : memref<128xindex>, vector<32xindex>
+///   }
+/// }
+/// ```
 struct WarpOpScfIfOp : public WarpDistributionPattern {
   WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
       : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
@@ -1728,7 +1784,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
     // The current `WarpOp` can yield two types of values:
     // 1. Values that are not results of the `IfOp`:
     //    Preserve them in the new `WarpOp`.
-    //    Collect their yield index.
+    //    Collect their yield index to remap the usages.
     // 2. Values that are results of the `IfOp`:
     //    They are not part of the new `WarpOp` results.
     //    Map the current warp's yield operand index to the `IfOp` result idx.
@@ -1757,38 +1813,14 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
 
     // Collect `WarpOp`-defined values used in `ifOp`; the new warp op must
     // return them.
-    auto getEscapingValues = [&](Region &branch,
-                                 llvm::SmallSetVector<Value, 32> &values,
-                                 SmallVector<Type> &inputTypes,
-                                 SmallVector<Type> &distTypes) {
-      if (branch.empty())
-        return;
-      mlir::visitUsedValuesDefinedAbove(branch, [&](OpOperand *operand) {
-        Operation *parent = operand->get().getParentRegion()->getParentOp();
-        if (warpOp->isAncestor(parent)) {
-          if (!values.insert(operand->get()))
-            return;
-          Type distType = operand->get().getType();
-          if (auto vecType = dyn_cast<VectorType>(distType)) {
-            AffineMap map = distributionMapFn(operand->get());
-            distType = getDistributedType(vecType, map, warpOp.getWarpSize());
-          }
-          inputTypes.push_back(operand->get().getType());
-          distTypes.push_back(distType);
-        }
-      });
-    };
-    llvm::SmallSetVector<Value, 32> escapingValuesThen;
-    SmallVector<Type> escapingValueInputTypesThen; // inner warp op block args
-    SmallVector<Type> escapingValueDistTypesThen;  // new warp returns
-    getEscapingValues(ifOp.getThenRegion(), escapingValuesThen,
-                      escapingValueInputTypesThen, escapingValueDistTypesThen);
-    llvm::SmallSetVector<Value, 32> escapingValuesElse;
-    SmallVector<Type> escapingValueInputTypesElse; // inner warp op block args
-    SmallVector<Type> escapingValueDistTypesElse;  // new warp returns
-    getEscapingValues(ifOp.getElseRegion(), escapingValuesElse,
-                      escapingValueInputTypesElse, escapingValueDistTypesElse);
-
+    auto [escapingValuesThen, escapingValueInputTypesThen,
+          escapingValueDistTypesThen] =
+        getInnerRegionEscapingValues(warpOp, ifOp.getThenRegion(),
+                                     distributionMapFn);
+    auto [escapingValuesElse, escapingValueInputTypesElse,
+          escapingValueDistTypesElse] =
+        getInnerRegionEscapingValues(warpOp, ifOp.getElseRegion(),
+                                     distributionMapFn);
     if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
         llvm::is_contained(escapingValueDistTypesElse, Type{}))
       return failure();
@@ -1825,6 +1857,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
       Type distType = cast<Value>(res).getType();
       if (auto vecType = dyn_cast<VectorType>(distType)) {
         AffineMap map = distributionMapFn(cast<Value>(res));
+        // Fall back to the affine map if no distributed type was recorded.
         distType = ifResultDistTypes.count(i)
                        ? ifResultDistTypes[i]
                        : getDistributedType(vecType, map, warpOp.getWarpSize());
@@ -1838,63 +1871,66 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
         rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
         static_cast<bool>(ifOp.thenBlock()),
         static_cast<bool>(ifOp.elseBlock()));
-
-    auto processBranch = [&](Block *oldIfBranch, Block *newIfBranch,
-                             llvm::SmallSetVector<Value, 32> &escapingValues,
-                             SmallVector<Type> &escapingValueInputTypes,
-                             size_t warpResRangeStart) {
-      OpBuilder::InsertionGuard g(rewriter);
-      if (!newIfBranch)
-        return;
-      rewriter.setInsertionPointToStart(newIfBranch);
-      llvm::SmallDenseMap<Value, size_t> escapeValToBlockArgIndex;
-      SmallVector<Value> innerWarpInputVals;
-      SmallVector<Type> innerWarpInputTypes;
-      for (size_t i = 0; i < escapingValues.size(); ++i, ++warpResRangeStart) {
-        innerWarpInputVals.push_back(newWarpOp.getResult(warpResRangeStart));
-        escapeValToBlockArgIndex[escapingValues[i]] =
-            innerWarpInputTypes.size();
-        innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
-      }
-      auto innerWarp = WarpExecuteOnLane0Op::create(
-          rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
-          newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInputVals,
-          innerWarpInputTypes);
-
-      innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent());
-      innerWarp.getWarpRegion().addArguments(
-          innerWarpInputTypes,
-          SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));
-
-      SmallVector<Value> yieldOperands;
-      for (Value operand : oldIfBranch->getTerminator()->getOperands())
-        yieldOperands.push_back(operand);
-      rewriter.eraseOp(oldIfBranch->getTerminator());
-
-      rewriter.setInsertionPointToEnd(innerWarp.getBody());
-      gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
-      rewriter.setInsertionPointAfter(innerWarp);
-      scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
-
-      // Update any users of escaping values that were forwarded to the
-      // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
-      innerWarp.walk([&](Operation *op) {
-        for (OpOperand &operand : op->getOpOperands()) {
-          auto it = escapeValToBlockArgIndex.find(operand.get());
-          if (it == escapeValToBlockArgIndex.end())
-            continue;
-          operand.set(innerWarp.getBodyRegion().getArgument(it->second));
-        }
-      });
-      mlir::vector::moveScalarUniformCode(innerWarp);
-    };
-    processBranch(&ifOp.getThenRegion().front(),
-                  &newIfOp.getThenRegion().front(), escapingValuesThen,
-                  escapingValueInputTypesThen, 1);
+    auto encloseRegionInWarpOp =
+        [&](Block *oldIfBranch, Block *newIfBranch,
+            llvm::SmallSetVector<Value, 32> &escapingValues,
+            SmallVector<Type> &escapingValueInputTypes,
+            size_t warpResRangeStart) {
+          OpBuilder::InsertionGuard g(rewriter);
+          if (!newIfBranch)
+            return;
+          rewriter.setInsertionPointToStart(newIfBranch);
+          llvm::SmallDenseMap<Value, size_t> escapeValToBlockArgIndex;
+          SmallVector<Value> innerWarpInputVals;
+          SmallVector<Type> innerWarpInputTypes;
+          for (size_t i = 0; i < escapingValues.size();
+               ++i, ++warpResRangeStart) {
+            innerWarpInputVals.push_back(
+                newWarpOp.getResult(warpResRangeStart));
+            escapeValToBlockArgIndex[escapingValues[i]] =
+                innerWarpInputTypes.size();
+            innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
+          }
+          auto innerWarp = WarpExecuteOnLane0Op::create(
+              rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
+              newWarpOp.getLaneid(), newWarpOp.getWarpSize(),
+              innerWarpInputVals, innerWarpInputTypes);
+
+          innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent());
+          innerWarp.getWarpRegion().addArguments(
+              innerWarpInputTypes,
+              SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));
+
+          SmallVector<Value> yieldOperands;
+          for (Value operand : oldIfBranch->getTerminator()->getOperands())
+            yieldOperands.push_back(operand);
+          rewriter.eraseOp(oldIfBranch->getTerminator());
+
+          rewriter.setInsertionPointToEnd(innerWarp.getBody());
+          gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
+          rewriter.setInsertionPointAfter(innerWarp);
+          scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
+
+          // Update any users of escaping values that were forwarded to the
+          // inner `WarpOp`. These values are arguments of the inner `WarpOp`.
+          innerWarp.walk([&](Operation *op) {
+            for (OpOperand &operand : op->getOpOperands()) {
+              auto it = escapeValToBlockArgIndex.find(operand.get());
+              if (it == escapeValToBlockArgIndex.end())
+                continue;
+              operand.set(innerWarp.getBodyRegion().getArgument(it->second));
+            }
+          });
+          mlir::vector::moveScalarUniformCode(innerWarp);
+        };
+    encloseRegionInWarpOp(&ifOp.getThenRegion().front(),
+                          &newIfOp.getThenRegion().front(), escapingValuesThen,
+                          escapingValueInputTypesThen, 1);
     if (!ifOp.getElseRegion().empty())
-      processBranch(&ifOp.getElseRegion().front(),
-                    &newIfOp.getElseRegion().front(), escapingValuesElse,
-                    escapingValueInputTypesElse, 1 + escapingValuesThen.size());
+      encloseRegionInWarpOp(&ifOp.getElseRegion().front(),
+                            &newIfOp.getElseRegion().front(),
+                            escapingValuesElse, escapingValueInputTypesElse,
+                            1 + escapingValuesThen.size());
     // Update the users of the original `WarpOp` results that were yielded
     // from the `IfOp` to use the new `IfOp` results.
     for (auto [origIdx, newIdx] : ifResultMapping)
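One detail of the branch-processing lambda that is easy to miss: after building each inner `WarpOp`, it calls `moveScalarUniformCode`, which hoists side-effect-free scalar (warp-uniform) ops out of the inner warp region. A minimal before/after sketch of that cleanup as I understand the utility (illustrative only, not taken from the tests):

```mlir
// Before: the uniform scalar constant sits inside the inner warp op.
gpu.warp_execute_on_lane_0(%laneid)[32] args(%r : vector<1xf32>) {
^bb0(%arg0: vector<32xf32>):
  %c0 = arith.constant 0 : index
  vector.store %arg0, %buffer[%c0] : memref<32xf32>, vector<32xf32>
}
// After moveScalarUniformCode: the scalar op is hoisted above the warp op.
%c0 = arith.constant 0 : index
gpu.warp_execute_on_lane_0(%laneid)[32] args(%r : vector<1xf32>) {
^bb0(%arg0: vector<32xf32>):
  vector.store %arg0, %buffer[%c0] : memref<32xf32>, vector<32xf32>
}
```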
From 90ef1ab5008f808e50ef498d74ca5218147bd85d Mon Sep 17 00:00:00 2001
From: Artem Kroviakov
Date: Wed, 10 Sep 2025 08:16:46 +0000
Subject: [PATCH 5/5] Address feedback

---
 .../Vector/Transforms/VectorDistribute.cpp | 32 ++++++-------------
 1 file changed, 9 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 3ae866aeb2888..995a2595e5fbb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -373,8 +373,8 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
 
 /// Given a warpOp that contains ops with regions, the corresponding op's
 /// "inner" region and the distributionMapFn, get all values used by the op's
-/// region that are defined within the warpOp. Return the set of values, their
-/// types and their distributed types.
+/// region that are defined within the warpOp, but outside the inner region.
+/// Return the set of values, their types and their distributed types.
 std::tuple<llvm::SmallSetVector<Value, 32>, SmallVector<Type>,
            SmallVector<Type>>
 getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
@@ -383,7 +383,8 @@ getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
   SmallVector<Type> escapingValueTypes;
   SmallVector<Type> escapingValueDistTypes; // to yield from the new warpOp
   if (innerRegion.empty())
-    return {escapingValues, escapingValueTypes, escapingValueDistTypes};
+    return {std::move(escapingValues), std::move(escapingValueTypes),
+            std::move(escapingValueDistTypes)};
   mlir::visitUsedValuesDefinedAbove(innerRegion, [&](OpOperand *operand) {
     Operation *parent = operand->get().getParentRegion()->getParentOp();
     if (warpOp->isAncestor(parent)) {
@@ -398,7 +399,8 @@ getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
       escapingValueDistTypes.push_back(distType);
     }
   });
-  return {escapingValues, escapingValueTypes, escapingValueDistTypes};
+  return {std::move(escapingValues), std::move(escapingValueTypes),
+          std::move(escapingValueDistTypes)};
 }
 
 /// Distribute transfer_write ops based on the affine map returned by
@@ -1998,25 +2000,9 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
       return failure();
     // Collect Values that come from the `WarpOp` but are outside the `ForOp`.
     // Those Values need to be returned by the new warp op.
-    llvm::SmallSetVector<Value, 32> escapingValues;
-    SmallVector<Type> escapingValueInputTypes;
-    SmallVector<Type> escapingValueDistTypes;
-    mlir::visitUsedValuesDefinedAbove(
-        forOp.getBodyRegion(), [&](OpOperand *operand) {
-          Operation *parent = operand->get().getParentRegion()->getParentOp();
-          if (warpOp->isAncestor(parent)) {
-            if (!escapingValues.insert(operand->get()))
-              return;
-            Type distType = operand->get().getType();
-            if (auto vecType = dyn_cast<VectorType>(distType)) {
-              AffineMap map = distributionMapFn(operand->get());
-              distType = getDistributedType(vecType, map, warpOp.getWarpSize());
-            }
-            escapingValueInputTypes.push_back(operand->get().getType());
-            escapingValueDistTypes.push_back(distType);
-          }
-        });
-
+    auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] =
+        getInnerRegionEscapingValues(warpOp, forOp.getBodyRegion(),
+                                     distributionMapFn);
     if (llvm::is_contained(escapingValueDistTypes, Type{}))
       return failure();
     // `WarpOp` can yield two types of values: