diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h index 5a806799e896f..0aa2cd45088f3 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h +++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h @@ -147,13 +147,14 @@ Value lowerToVectorReductions(TypedValue src, vector::CombiningKind kind, int64_t reductionDim, Location loc, PatternRewriter &rewriter); -/// Creates a constant vector filled with the neutral (identity) value for the +/// Creates a constant filled with the neutral (identity) value for the /// given reduction kind. For example: 0 for ADD/OR/XOR, 1 for MUL/AND, /// max/min signed/unsigned int for MINSI/MINUI/MAXSI/MAXUI, and +/-infinity -/// for float min/max operations. Returns nullptr if the element type is -/// incompatible with the requested reduction kind. -Value createReductionNeutralValue(OpBuilder &builder, Location loc, - VectorType type, vector::CombiningKind kind); +/// for float min/max operations. If \p type is a VectorType, returns a splat +/// vector constant; otherwise returns a scalar constant. Returns nullptr if +/// the element type is incompatible with the requested reduction kind. +Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, + vector::CombiningKind kind); /// Lowers cross-lane reductions to shuffle operations on a 2D vector. /// Extracts slices along the reduction dimension, performs subgroup reductions diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 6dea94c0c5de3..3d1d1ca3ecf98 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -1233,13 +1233,13 @@ struct WgToSgMultiDimReductionOp Location loc = op.getLoc(); VectorType srcType = op.getSourceVectorType(); - VectorType dstType = dyn_cast(op.getResult().getType()); - if (!dstType) - return failure(); + Type resultTy = op.getResult().getType(); + VectorType dstVecType = dyn_cast(resultTy); + bool isScalarResult = !dstVecType; auto originalSrcShape = srcType.getShape(); - auto originalDstShape = dstType.getShape(); int srcVecRank = originalSrcShape.size(); + Type elemTy = srcType.getElementType(); xegpu::DistributeLayoutAttr layout = xegpu::getTemporaryLayout(dyn_cast(op.getResult())); @@ -1258,25 +1258,33 @@ struct WgToSgMultiDimReductionOp return rewriter.notifyMatchFailure( op, "Reduction should have SliceAttr layout"); - Type elemTy = dstType.getElementType(); - - // Step 1: perform local subgroup reductions with ZERO accumulator + // Step 1: perform local subgroup reductions with neutral accumulator SmallVector localReductions; - SmallVector sgDstShape = - getSgShapeAndCount(originalDstShape, layout).first; auto sgSrcs = adaptor.getSource(); auto sgSrcType = dyn_cast(sgSrcs.front().getType()); SmallVector sgSrcShape(sgSrcType.getShape().begin(), sgSrcType.getShape().end()); - VectorType newDstType = VectorType::get(sgDstShape, elemTy); + // Determine the SG-level destination type. + // For scalar results (all dims reduced), the sg result is also scalar. + // For vector results, compute the sg destination shape from layout. + Type sgDstType; + if (dstVecType) { + auto originalDstShape = dstVecType.getShape(); + SmallVector sgDstShape = + getSgShapeAndCount(originalDstShape, layout).first; + sgDstType = VectorType::get(sgDstShape, elemTy); + } else { + sgDstType = elemTy; + } + for (auto sgSrc : sgSrcs) { - // Create ZERO accumulator for local reduction - auto neutralLocalAcc = xegpu::createReductionNeutralValue( - rewriter, loc, newDstType, op.getKind()); - // Local reduction with ZERO accumulator + // Create neutral accumulator for local reduction + Value neutralLocalAcc = xegpu::createReductionNeutralValue( + rewriter, loc, sgDstType, op.getKind()); + // Local reduction with neutral accumulator auto localReduce = vector::MultiDimReductionOp::create( - rewriter, loc, newDstType, op.getKind(), sgSrc, neutralLocalAcc, + rewriter, loc, sgDstType, op.getKind(), sgSrc, neutralLocalAcc, reductionDims); localReductions.push_back(localReduce.getResult()); } @@ -1310,8 +1318,15 @@ struct WgToSgMultiDimReductionOp for (int64_t dim : reductionDims) slmStoreDataShape[dim] = 1; VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy); - Value slmStoreData = vector::ShapeCastOp::create( - rewriter, loc, slmStoreDataType, localReductions[0]); + Value slmStoreData; + if (isScalarResult) { + // Scalar result: broadcast scalar to vector<1x...x1> for SLM store + slmStoreData = vector::BroadcastOp::create( + rewriter, loc, slmStoreDataType, localReductions[0]); + } else { + slmStoreData = vector::ShapeCastOp::create( + rewriter, loc, slmStoreDataType, localReductions[0]); + } SmallVector slmShape(originalSrcShape.begin(), originalSrcShape.end()); @@ -1393,12 +1408,12 @@ struct WgToSgMultiDimReductionOp rewriter, loc, slmLoadType, memDesc.getResult(), slmLoadOffsets, /*layout=*/nullptr); - // Step 6: Perform final reduction with ZERO accumulator - auto neutralFinalAcc = xegpu::createReductionNeutralValue( - rewriter, loc, newDstType, op.getKind()); + // Step 6: Perform final reduction with neutral accumulator + Value neutralFinalAcc = xegpu::createReductionNeutralValue( + rewriter, loc, sgDstType, op.getKind()); auto finalReduce = vector::MultiDimReductionOp::create( - rewriter, loc, newDstType, op.getKind(), slmLoadOp.getResult(), + rewriter, loc, sgDstType, op.getKind(), slmLoadOp.getResult(), neutralFinalAcc, reductionDims); // Step 7: Add the original accumulator at the end diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index f60635830cc74..6c902f725ca0c 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -801,77 +801,60 @@ Value xegpu::lowerCrossLaneReductionToShuffles( } Value xegpu::createReductionNeutralValue(OpBuilder &builder, Location loc, - VectorType type, + Type type, vector::CombiningKind kind) { - Type elemTy = type.getElementType(); + auto vecTy = dyn_cast(type); + Type elemTy = vecTy ? vecTy.getElementType() : type; + + // Helper to create either a splat vector or scalar constant from an attr. + auto makeConst = [&](Attribute scalarAttr) -> Value { + if (vecTy) + return arith::ConstantOp::create( + builder, loc, vecTy, DenseElementsAttr::get(vecTy, scalarAttr)); + return arith::ConstantOp::create(builder, loc, cast(scalarAttr)); + }; switch (kind) { case vector::CombiningKind::ADD: case vector::CombiningKind::XOR: case vector::CombiningKind::OR: - return arith::ConstantOp::create( - builder, loc, type, - DenseElementsAttr::get(type, builder.getZeroAttr(elemTy))); + case vector::CombiningKind::MAXUI: + return makeConst(builder.getZeroAttr(elemTy)); case vector::CombiningKind::MUL: case vector::CombiningKind::AND: - return arith::ConstantOp::create( - builder, loc, type, - DenseElementsAttr::get(type, builder.getOneAttr(elemTy))); + return makeConst(builder.getOneAttr(elemTy)); case vector::CombiningKind::MINSI: - // Use max signed int value for signed integer min - if (auto intTy = dyn_cast(elemTy)) { - auto maxVal = APInt::getSignedMaxValue(intTy.getWidth()); - return arith::ConstantOp::create( - builder, loc, type, - DenseElementsAttr::get(type, builder.getIntegerAttr(elemTy, maxVal))); - } + if (auto intTy = dyn_cast(elemTy)) + return makeConst(builder.getIntegerAttr( + elemTy, APInt::getSignedMaxValue(intTy.getWidth()))); return nullptr; case vector::CombiningKind::MINUI: - if (auto intTy = dyn_cast(elemTy)) { - auto maxVal = APInt::getMaxValue(intTy.getWidth()); - return arith::ConstantOp::create( - builder, loc, type, - DenseElementsAttr::get(type, builder.getIntegerAttr(elemTy, maxVal))); - } + if (auto intTy = dyn_cast(elemTy)) + return makeConst( + builder.getIntegerAttr(elemTy, APInt::getMaxValue(intTy.getWidth()))); return nullptr; case vector::CombiningKind::MAXSI: - if (auto intTy = dyn_cast(elemTy)) { - auto minVal = APInt::getSignedMinValue(intTy.getWidth()); - return arith::ConstantOp::create( - builder, loc, type, - DenseElementsAttr::get(type, builder.getIntegerAttr(elemTy, minVal))); - } + if (auto intTy = dyn_cast(elemTy)) + return makeConst(builder.getIntegerAttr( + elemTy, APInt::getSignedMinValue(intTy.getWidth()))); return nullptr; - case vector::CombiningKind::MAXUI: - return arith::ConstantOp::create( - builder, loc, type, - DenseElementsAttr::get(type, builder.getZeroAttr(elemTy))); - case vector::CombiningKind::MINNUMF: case vector::CombiningKind::MINIMUMF: - // Use +infinity for float min operations - if (auto floatTy = dyn_cast(elemTy)) { - auto posInf = APFloat::getInf(floatTy.getFloatSemantics()); - return arith::ConstantOp::create( - builder, loc, type, - DenseElementsAttr::get(type, builder.getFloatAttr(elemTy, posInf))); - } + if (auto floatTy = dyn_cast(elemTy)) + return makeConst(builder.getFloatAttr( + elemTy, APFloat::getInf(floatTy.getFloatSemantics()))); return nullptr; case vector::CombiningKind::MAXNUMF: case vector::CombiningKind::MAXIMUMF: - // Use -infinity for float max operations - if (auto floatTy = dyn_cast(elemTy)) { - auto negInf = APFloat::getInf(floatTy.getFloatSemantics(), true); - return arith::ConstantOp::create( - builder, loc, type, - DenseElementsAttr::get(type, builder.getFloatAttr(elemTy, negInf))); - } + if (auto floatTy = dyn_cast(elemTy)) + return makeConst(builder.getFloatAttr( + elemTy, APFloat::getInf(floatTy.getFloatSemantics(), true))); return nullptr; } return nullptr; diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index ecc5fe3dd75e0..950d9ba66f0cc 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -1,10 +1,10 @@ // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s -// CHECK-DAG: #map = affine_map<()[s0] -> (s0 floordiv 32)> -// CHECK-DAG: #map1 = affine_map<()[s0] -> (s0 mod 32)> -// CHECK-DAG: #map2 = affine_map<()[s0] -> (0)> -// CHECK-DAG: #map3 = affine_map<()[s0] -> (s0 floordiv 4)> -// CHECK-DAG: #map4 = affine_map<()[s0] -> (s0 mod 4)> +// CHECK-DAG: #map = affine_map<()[s0] -> (s0 floordiv 4)> +// CHECK-DAG: #map1 = affine_map<()[s0] -> (s0 mod 4)> +// CHECK-DAG: #map2 = affine_map<()[s0] -> (s0 floordiv 32)> +// CHECK-DAG: #map3 = affine_map<()[s0] -> (s0 mod 32)> +// CHECK-DAG: #map4 = affine_map<()[s0] -> (0)> // CHECK-DAG: #map5 = affine_map<()[s0] -> ((s0 mod 32) floordiv 16)> // CHECK-DAG: #map6 = affine_map<()[s0] -> (s0 mod 16)> // CHECK-DAG: #map7 = affine_map<()[s0] -> ((s0 mod 16) floordiv 4)> @@ -412,6 +412,33 @@ gpu.module @test_distribution { gpu.return } + // CHECK-LABEL: gpu.func @vector_reduce_scalar_cross_sg + // CHECK-SAME: (%[[ARG0:.*]]: memref<32x32xf32>) + // CHECK-DAG: %[[CST:.*]] = arith.constant {{.*}} 0.000000e+00 : f32 + // CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x8xf32> -> vector<8x8xf32> + // CHECK-DAG: %[[CST_ACC:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: %[[LOCAL:.*]] = vector.multi_reduction , %[[LOAD]], %[[CST_ACC]] [0, 1] : vector<8x8xf32> to f32 + // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[LOCAL]] : f32 to vector<1x1xf32> + // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<64xi8, 3> + // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<64xi8, 3> -> !xegpu.mem_desc<4x4xf32> + // CHECK-DAG: xegpu.store_matrix %[[BCAST]], %[[MEM_DESC]]{{.*}} : vector<1x1xf32>, !xegpu.mem_desc<4x4xf32> + // CHECK-DAG: gpu.barrier + // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]]{{.*}} -> vector<4x4xf32> + // CHECK-DAG: %[[CST_FINAL:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: %[[FINAL:.*]] = vector.multi_reduction , %[[LOAD_SLM]], %[[CST_FINAL]] [0, 1] : vector<4x4xf32> to f32 + // CHECK-DAG: arith.addf %[[FINAL]], %[[CST]] : f32 + gpu.func @vector_reduce_scalar_cross_sg(%src: memref<32x32xf32>) { + %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0, 1]>} 0.0 : f32 + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<32x32xf32> + -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout> + %load = xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout} + : !xegpu.tensor_desc<32x32xf32, #xegpu.layout> + -> vector<32x32xf32> + %reduce = vector.multi_reduction , %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0, 1]>} [0, 1] + : vector<32x32xf32> to f32 + gpu.return + } + // CHECK-LABEL: vector_step_op gpu.func @vector_step_op_slice_attr() { //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index @@ -654,9 +681,9 @@ gpu.module @test_distribution { // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3> // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<1x32x32xf32> // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index - // CHECK-DAG: %[[AFF0:.*]] = affine.apply #map()[%[[SGID]]] - // CHECK-DAG: %[[AFF1:.*]] = affine.apply #map1()[%[[SGID]]] - // CHECK-DAG: %[[AFF2:.*]] = affine.apply #map2()[%[[SGID]]] + // CHECK-DAG: %[[AFF0:.*]] = affine.apply #map2()[%[[SGID]]] + // CHECK-DAG: %[[AFF1:.*]] = affine.apply #map3()[%[[SGID]]] + // CHECK-DAG: %[[AFF2:.*]] = affine.apply #map4()[%[[SGID]]] // CHECK-DAG: %[[ROW:.*]] = arith.muli %[[AFF0]], %[[C1A:.*]] : index // CHECK-DAG: %[[COL0:.*]] = arith.muli %[[AFF1:.*]], %[[C1B:.*]] : index // CHECK-DAG: %[[COL1:.*]] = arith.muli %[[AFF2]], %[[C32A:.*]] : index