Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
// If no layout, not valid.
if (!resLayout || !resLayout.isForSubgroup())
return false;
// Scalar result (e.g., vector<32xf32> to f32) is valid.
if (op.getType().isIntOrFloat())
return op.getReductionDims().size() == 1;
VectorType resTy = dyn_cast<VectorType>(op.getType());
if (!resTy)
return false;
Expand Down Expand Up @@ -600,7 +603,21 @@ struct SgToWiMultiDimReduction
op, "only unit leading dimensions are supported for "
"multi_reduction with rank > 2");
}
if (isReductionLaneLocal(op)) {
// Handle scalar result: full reduction of a distributed vector to a
// scalar. First do a local vector reduction, then cross-lane shuffles.
if (op.getType().isIntOrFloat()) {
auto reductionDim = reductionDims[0];
VectorType origSourceType = op.getSourceVectorType();
int64_t reductionDimSize = origSourceType.getShape()[reductionDim];
// Local reduction to scalar, then cross-lane butterfly shuffles.
result =
xegpu::subgroupReduction(op.getLoc(), rewriter, adaptor.getSource(),
op.getKind(), reductionDimSize);
// Combine with accumulator if present.
if (adaptor.getAcc())
result = vector::makeArithReduction(rewriter, op.getLoc(), op.getKind(),
result, adaptor.getAcc());
} else if (isReductionLaneLocal(op)) {
auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
VectorType resVecTy = dyn_cast<VectorType>(op.getType());
auto resDistVecTyOrFailure =
Expand Down
32 changes: 32 additions & 0 deletions mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1125,3 +1125,35 @@ gpu.func @vector_broadcast_scalar_to_vector_uniform(%laneid: index) {
gpu.return
}
}

// -----
// Test case: a 1-D vector.multi_reduction <add> whose result is a scalar
// (vector<32xf32> reduced over dim 0 into f32, with an f32 accumulator).
// The expected output is a single in-lane vector.reduction followed by a
// chain of gpu.shuffle xor / arith.addf pairs and a final add with the
// accumulator.
gpu.module @xevm_module {
// CHECK-LABEL: gpu.func @vector_multi_reduction_1d_to_scalar
// The vector<32xf32> source is expected to be cast to vector<2xf32>
// (presumably the per-lane slice: 32 elements over lane_layout = [16]).
// CHECK: %[[SRC:.*]] = "some_op"() {{.*}} : () -> vector<32xf32>
// CHECK: %[[DIST:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : vector<32xf32> to vector<2xf32>
// CHECK: %[[ACC:.*]] = arith.constant 0.000000e+00 : f32
// Step 1: reduce the vector<2xf32> value to a single f32.
// CHECK: %[[LANE_RED:.*]] = vector.reduction <add>, %[[DIST]] : vector<2xf32> into f32
// Step 2: five shuffle-xor + add stages, each consuming the previous sum.
// The capture names (C1, C2, C4, C8, C16, C32) suggest the intended shuffle
// offsets and width, but the patterns match any SSA name; the constant
// values themselves are not verified here.
// NOTE(review): five stages corresponds to a reduction size of 32, while the
// layout below declares 16 lanes -- confirm the 16-offset stage is intended.
// CHECK: %[[SHFL1:.*]], %{{.*}} = gpu.shuffle xor %[[LANE_RED]], %[[C1:.*]], %[[C32:.*]] : f32
// CHECK: %[[ADD1:.*]] = arith.addf %[[LANE_RED]], %[[SHFL1]] : f32
// CHECK: %[[SHFL2:.*]], %{{.*}} = gpu.shuffle xor %[[ADD1]], %[[C2:.*]], %[[C32:.*]] : f32
// CHECK: %[[ADD2:.*]] = arith.addf %[[ADD1]], %[[SHFL2]] : f32
// CHECK: %[[SHFL3:.*]], %{{.*}} = gpu.shuffle xor %[[ADD2]], %[[C4:.*]], %[[C32:.*]] : f32
// CHECK: %[[ADD3:.*]] = arith.addf %[[ADD2]], %[[SHFL3]] : f32
// CHECK: %[[SHFL4:.*]], %{{.*}} = gpu.shuffle xor %[[ADD3]], %[[C8:.*]], %[[C32:.*]] : f32
// CHECK: %[[ADD4:.*]] = arith.addf %[[ADD3]], %[[SHFL4]] : f32
// CHECK: %[[SHFL5:.*]], %{{.*}} = gpu.shuffle xor %[[ADD4]], %[[C16:.*]], %[[C32:.*]] : f32
// CHECK: %[[ADD5:.*]] = arith.addf %[[ADD4]], %[[SHFL5]] : f32
// Step 3: the reduced value is combined with the accumulator constant.
// CHECK: %[[FINAL:.*]] = arith.addf %[[ADD5]], %[[ACC]] : f32
gpu.func @vector_multi_reduction_1d_to_scalar() {
// Opaque producer of the source vector, annotated with a 16-lane layout
// holding one element per lane.
%src = "some_op"()
{layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
: () -> vector<32xf32>
%acc = arith.constant 0.0 : f32
// Full reduction over dim 0; the result layout is the source layout sliced
// along dim 0. The result %1 is not used afterwards.
%1 = vector.multi_reduction <add>, %src, %acc
{
layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16], lane_data = [1]>, dims = [0]>
}
[0] : vector<32xf32> to f32
gpu.return
}
}