diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index fce14999b4011..c9ff0a9ad1c03 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -915,9 +915,21 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>, let description = [{DPAS performs matrix multiplication on matrix A of `mxk` size, B of `kxn` size, and accumulate on matrix C of `mxn` to the same size - matrix , `m=8`, `n=16` and `k=8 * 32/bit_width_of_elem_type`. So for fp16 - data type, the matrices are `A: vector<8x16xf16>`, `B: vector<16x16xf16>`, - and `C/D: vector<8x16xf32>`. + matrix. + + The operands can be 2D, 3D, or 4D vectors. When the vectors have more than 2 + dimensions, the leading dimensions are treated as **batch dimensions** that + must match across all operands (lhs, rhs, acc, and result). The batch + dimensions represent independent matrix multiplications that are executed in + parallel. For example: + - 2D: `A: vector<8x16xf16>`, `B: vector<16x16xf16>` -> `C: vector<8x16xf32>` + - 3D: `A: vector<4x8x16xf16>`, `B: vector<4x16x16xf16>` -> `C: vector<4x8x16xf32>` + (4 independent 8x16 matrix multiplications) + - 4D: `A: vector<2x4x8x16xf16>`, `B: vector<2x4x16x16xf16>` -> `C: vector<2x4x8x16xf32>` + (2x4=8 independent 8x16 matrix multiplications) + + The last 2 dimensions are always the core matrix multiplication dimensions + (M, K for lhs; K, N for rhs; M, N for result). In lane level code, each lane from a subgroup holds a data fragment for A, B, C and the result, which are represented as 1D vectors. Please refer to [OpenCL Intel extentions] @@ -930,19 +942,21 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>, Arguments: - `lhs`: A vector value representing the left-hand-side matrix tile (A) participating in the - matrix multiply. + matrix multiply. 
Can be 1D, 2D, 3D, or 4D where leading dimensions are batch dimensions. - - `rhs`: A vector value representing the right-hand-side matrix tile (B). + - `rhs`: A vector value representing the right-hand-side matrix tile (B). Can be 1D, 2D, 3D, or 4D + where leading dimensions are batch dimensions; an optional trailing VNNI packing dimension counts toward the rank limit. - `acc`: [optional] A vector value representing the accumulator matrix tile (C). When present, the result is computed as `lhs * rhs + acc`; otherwise, the accumulator is implicitly assumed to be zero. + Must have the same batch dimensions as lhs and rhs. - `layout_a`, `layout_b`, `layout_cd`: [optional] Attributes that identify this operation as anchor for operands A, B, and the accumulator/result, enabling users to assign layouts that govern distribution at the subgroup and/or lane level. Only valid at workgroup and subgroup level. - Example 1 (Workgroup level): + Example 1 (Workgroup level, 2D): ```mlir %d = xegpu.dpas %a, %b, %c <{ @@ -952,12 +966,20 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>, : vector<64x128xf16>, vector<128x128xf16>, vector<64x128xf32> -> vector<64x128xf32> ``` - Example 2 (Lane level): + Example 2 (Lane level, 1D): ```mlir %d = xegpu.dpas %a, %b, %c : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32> ``` + + Example 3 (Workgroup level, 3D with batch): + + ```mlir + // 4 independent 8x16 matrix multiplications + %d = xegpu.dpas %a, %b, %c + : vector<4x8x16xf16>, vector<4x16x16xf16>, vector<4x8x16xf32> -> vector<4x8x16xf32> + ``` }]; let arguments = (ins @@ -1428,6 +1450,21 @@ def XeGPU_DpasMxOp : XeGPU_Op<"dpas_mx", [Pure, AttrSizedOperandSegments, size, B is of size `kxn`, and accumulate on matrix C of size `mxn` to the same size matrix. + The operands can be 2D, 3D, or 4D vectors. 
When the vectors have more than 2 + dimensions, the leading dimensions are treated as **batch dimensions** that + must match across all operands (a, b, acc, result, scale_a, scale_b). The batch + dimensions represent independent matrix multiplications that are executed in + parallel. For example: + - 2D: `A: vector<8x32xf8E5M2>`, `B: vector<32x16xf8E5M2>` -> `C: vector<8x16xbf16>` + - 3D: `A: vector<4x8x32xf8E5M2>`, `B: vector<4x32x16xf8E5M2>` -> `C: vector<4x8x16xbf16>` + (4 independent 8x16 scaled matrix multiplications) + - 4D: `A: vector<2x4x8x32xf8E5M2>`, `B: vector<2x4x32x16xf8E5M2>` -> `C: vector<2x4x8x16xbf16>` + (2x4=8 independent 8x16 scaled matrix multiplications) + + The last 2 dimensions are always the core matrix multiplication dimensions + (M, K for a; K, N for b; M, N for result). The scale vectors also follow the + same batch dimension structure. + In lane level code, each lane from a subgroup holds a data fragment for A, B, Acc and the result, which are represented as 1D vectors. @@ -1437,18 +1474,19 @@ def XeGPU_DpasMxOp : XeGPU_Op<"dpas_mx", [Pure, AttrSizedOperandSegments, Arguments: - `a`: A vector value representing the left-hand-side matrix tile (A) participating in the - matrix multiply. + matrix multiply. Can be 1D, 2D, 3D, or 4D where leading dimensions are batch dimensions. - - `b`: A vector value representing the right-hand-side matrix tile (B). + - `b`: A vector value representing the right-hand-side matrix tile (B). Can be 1D, 2D, 3D, or 4D + where leading dimensions are batch dimensions. - `acc`: A vector value representing the accumulator matrix tile (C). The - result is computed as `a * b + acc`. + result is computed as `a * b + acc`. Must have the same batch dimensions as a and b. - `scale_a`: A floating point vector/scalar value used to scale `a` for - matrix multiplication. + matrix multiplication. When a vector, must have matching batch dimensions. 
- `scale_b`: A floating point vector/scalar value used to scale `b` for - matrix multiplication. + matrix multiplication. When a vector, must have matching batch dimensions. - `layout_a`, `layout_b`, `layout_cd`: [optional] Attributes that identify this operation as anchor for operands A, B, and the accumulator/result, enabling users to assign layouts @@ -1460,9 +1498,9 @@ def XeGPU_DpasMxOp : XeGPU_Op<"dpas_mx", [Pure, AttrSizedOperandSegments, let arguments = (ins XeGPU_DpasOprType:$a, XeGPU_DpasOprType:$b, Optional:$acc, Optional]>>:$scale_a, + VectorOfRankAndType<[1, 2, 3, 4], [F8E8M0FNU]>]>>:$scale_a, Optional]>>:$scale_b, + VectorOfRankAndType<[1, 2, 3, 4], [F8E8M0FNU]>]>>:$scale_b, OptionalAttr:$layout_a, OptionalAttr:$layout_b, OptionalAttr:$layout_cd, diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index 95a3e22cf803b..0423303c23493 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -21,8 +21,8 @@ def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>; def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>; def XeGPU_BaseAddrType : AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, XeGPU_PointerType]>; -def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>; -def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>; +def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3, 4], [XeGPU_ScalarType]>; +def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2, 3, 4], [XeGPU_ScalarType]>; def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>; def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>; def XeGPU_ValueType: VectorOfRankAndType<[1,2,3,4,5,6,7,8], [XeGPU_ScalarType]>; diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 56d340141ee8f..bebdd22e1c087 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ 
b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -720,31 +720,78 @@ static LogicalResult verifyDpasDimensions(Operation *op, if (aRank == 1 && bRank == 1 && resRank == 1) return success(); - // Validate A and B are 2D - if (aRank != 2) - return op->emitOpError("A operand must be a 2D vector."); - if (bRank < 2 || bRank > 3) - return op->emitOpError("B operand must be a 2D or 3D vector."); - if (resRank != 2) - return op->emitOpError("Result must be a 2D vector."); + // A, B, and the result must each be at least 2D; B's core (innermost) + // dims are 2D, or 3D when VNNI-packed. + if (aRank < 2) + return op->emitOpError("A operand must be at least a 2D vector."); + if (bRank < 2) + return op->emitOpError("B operand must be at least a 2D vector."); + if (resRank < 2) + return op->emitOpError("Result must be at least a 2D vector."); + + // Determine batch dimensions. For A[batch..., M, K], B[batch..., K, N] (or + // B[batch..., K/vnni, N, vnni]), result[batch..., M, N]. + // B may have one extra trailing dim for VNNI packing (3D innermost). + // Determine how many trailing dims are the "core" matmul dims. + // A core dims: last 2 (M, K) + // B core dims: last 2 (K, N) or last 3 (K/vnni, N, vnni) for packed + // Result core dims: last 2 (M, N) + int64_t aBatchRank = aRank - 2; + int64_t resBatchRank = resRank - 2; + + // B can have an extra trailing dim for VNNI packing. Determine B's batch + // rank: if B has exactly one more dim than A, it is the VNNI packing dim. + bool bPacked = (bRank == aRank + 1); + int64_t bBatchRank = bPacked ? bRank - 3 : bRank - 2; + + // Batch ranks must match across A, B, and result. + if (aBatchRank != bBatchRank || aBatchRank != resBatchRank) + return op->emitOpError("Batch dimension rank mismatch among A, B, and " + "result."); + + // Verify batch dimensions match. 
+ for (int64_t i = 0; i < aBatchRank; ++i) { + if (aShape[i] != resShape[i]) + return op->emitOpError("Batch dimension mismatch at dim ") + << i << ": A has " << aShape[i] << " but result has " + << resShape[i] << "."; + if (aShape[i] != bShape[i]) + return op->emitOpError("Batch dimension mismatch at dim ") + << i << ": A has " << aShape[i] << " but B has " << bShape[i] + << "."; + } - // Calculate effective K dimension for B (handle 3D packed case) - int64_t bK = bRank == 3 ? bShape[0] * bShape[2] : bShape[0]; + // Now verify the core matmul dimensions (last 2 of A and result, last 2 or + // 3 of B). + int64_t aM = aShape[aBatchRank]; + int64_t aK = aShape[aBatchRank + 1]; + int64_t resM = resShape[resBatchRank]; + int64_t resN = resShape[resBatchRank + 1]; + + // Effective K and N of B; packed B is [K/vnni, N, vnni], so K spans 2 dims. + int64_t bK, bN; + if (bPacked) { + bK = bShape[bBatchRank] * bShape[bBatchRank + 2]; + bN = bShape[bBatchRank + 1]; + } else { + bK = bShape[bBatchRank]; + bN = bShape[bBatchRank + 1]; + } // Verify K dimension match between A and B - if (bK != aShape[1]) + if (bK != aK) return op->emitOpError("K-dimension mismatch: A has K=") - << aShape[1] << " but B has K=" << bK << "."; + << aK << " but B has K=" << bK << "."; // Verify M dimension match between A and result - if (aShape[0] != resShape[0]) + if (aM != resM) return op->emitOpError("M-dimension mismatch: A has M=") - << aShape[0] << " but result has M=" << resShape[0] << "."; + << aM << " but result has M=" << resM << "."; // Verify N dimension match between B and result - if (bShape[1] != resShape[1]) + if (bN != resN) return op->emitOpError("N-dimension mismatch: B has N=") - << bShape[1] << " but result has N=" << resShape[1] << "."; + << bN << " but result has N=" << resN << "."; return success(); } @@ -907,6 +954,9 @@ LogicalResult DpasMxOp::verify() { if (failed(verifyDpasDimensions(*this, aShape, bShape, resShape))) return failure(); + // Determine batch rank from A operand. 
+ int64_t aBatchRank = aShape.size() < 2 ? 0 : aShape.size() - 2; + // Validate scale_a if present if (getScaleA()) { auto scaleAVecType = dyn_cast(getScaleAType()); @@ -914,19 +964,31 @@ LogicalResult DpasMxOp::verify() { if (scaleAVecType && scaleAVecType.getRank() > 1) { auto scaleAShape = scaleAVecType.getShape(); - if (scaleAVecType.getRank() != 2) - return emitOpError("Scale A must be a 2D vector when not a scalar."); + // Leading dims of scale_a are batch dims and must match A's batch dims. + int64_t scaleABatchRank = scaleAVecType.getRank() - 2; + if (scaleABatchRank != aBatchRank) + return emitOpError("Scale A batch rank [") + << scaleABatchRank << "] must match A batch rank [" << aBatchRank + << "]."; + for (int64_t i = 0; i < aBatchRank; ++i) { + if (scaleAShape[i] != aShape[i]) + return emitOpError("Scale A batch dimension mismatch at dim ") + << i << ": scale_a has " << scaleAShape[i] << " but A has " + << aShape[i] << "."; + } // Verify layout distributability for scale_a if (failed(verifyLayoutDistributable(*this, getLayoutAScale(), scaleAShape, "ScaleA"))) return failure(); - // Validate M dimension: scale_a[0] must match a[0] - if (scaleAShape[0] != aShape[0]) + // Validate M dimension: scale_a's M must match A's M (last-1 dim) + // aBatchRank is clamped to 0 for a 1D (lane-level) A, so the index into + // aShape is always in range. 
+ if (scaleAShape[scaleAShape.size() - 2] != aShape[aBatchRank]) return emitOpError("Scale A M dimension [") - << scaleAShape[0] << "] must match A M dimension [" << aShape[0] - << "]."; + << scaleAShape[scaleAShape.size() - 2] + << "] must match A M dimension [" << aShape[aBatchRank] << "]."; } } @@ -937,19 +999,34 @@ LogicalResult DpasMxOp::verify() { if (scaleBVecType && scaleBVecType.getRank() > 1) { auto scaleBShape = scaleBVecType.getShape(); - if (scaleBVecType.getRank() != 2) - return emitOpError("Scale B must be a 2D vector when not a scalar."); + // Leading dims of scale_b are batch dims and must match B's batch dims. + // verifyDpasDimensions has already ensured B's batch rank equals A's. + int64_t scaleBBatchRank = scaleBVecType.getRank() - 2; + if (scaleBBatchRank != aBatchRank) + return emitOpError("Scale B batch rank [") + << scaleBBatchRank << "] must match B batch rank [" << aBatchRank + << "]."; + for (int64_t i = 0; i < aBatchRank; ++i) { + if (scaleBShape[i] != bShape[i]) + return emitOpError("Scale B batch dimension mismatch at dim ") + << i << ": scale_b has " << scaleBShape[i] << " but B has " + << bShape[i] << "."; + } // Verify layout distributability for scale_b if (failed(verifyLayoutDistributable(*this, getLayoutBScale(), scaleBShape, "ScaleB"))) return failure(); - // Validate N dimension: scale_b[1] must match b[1] - if (scaleBShape[1] != bShape[1]) + // Validate N dimension: scale_b's N (last dim) must match B's N. B's N is + // its last dim, or its second-to-last dim when B carries a trailing VNNI + // packing dim (i.e. B has one more dim than A). + bool bPacked = bShape.size() == aShape.size() + 1; + int64_t bN = bPacked ? bShape[bShape.size() - 2] : bShape.back(); + if (scaleBShape.back() != bN) return emitOpError("Scale B N dimension [") - << scaleBShape[1] << "] must match B N dimension [" << bShape[1] - <<
"]."; + << scaleBShape.back() << "] must match B N dimension [" << bN + << "]."; } } @@ -964,11 +1041,12 @@ LogicalResult DpasMxOp::verify() { auto scaleAShape = scaleAVecType.getShape(); auto scaleBShape = scaleBVecType.getShape(); - // Validate scale K dimension compatibility: scale_a[1] must match - // scale_b[0] - if (scaleAShape[1] != scaleBShape[0]) + // Validate scale K dimension compatibility: scale_a's last dim must + // match scale_b's second-to-last dim + if (scaleAShape.back() != scaleBShape[scaleBShape.size() - 2]) return emitOpError("Scale K dimension mismatch: scale_a has K=") - << scaleAShape[1] << " but scale_b has K=" << scaleBShape[0] + << scaleAShape.back() + << " but scale_b has K=" << scaleBShape[scaleBShape.size() - 2] << "."; } } diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 341427e7e1231..438f675ecc3f4 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -394,7 +394,7 @@ func.func @dpas_vc_1(%a : vector<8x8xf16>, %b: vector<8x16x2xf16>) { // ----- func.func @dpas_vc_2(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>) { - // expected-error@+1 {{op A operand must be a 2D vector}} + // expected-error@+1 {{'xegpu.dpas' op Batch dimension rank mismatch among A, B, and result}} %1 = xegpu.dpas %a, %b : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> return } @@ -702,21 +702,21 @@ func.func @dpas_mx_acc_result_type_mismatch(%a : vector<8x16xf8E5M2>, %b: vector // ----- func.func @dpas_mx_a_not_2d(%a : vector<128xf8E5M2>, %b: vector<16x16xf8E5M2>) { - // expected-error@+1 {{A operand must be a 2D vector.}} + // expected-error@+1 {{A operand must be at least a 2D vector.}} %1 = xegpu.dpas_mx %a, %b : (vector<128xf8E5M2>, vector<16x16xf8E5M2>) -> vector<8x16xf32> return } // ----- func.func @dpas_mx_b_not_2d(%a : vector<8x16xf8E5M2>, %b: vector<256xf8E5M2>) { - // expected-error@+1 {{B operand must be a 2D or 3D vector.}} + // 
expected-error@+1 {{B operand must be at least a 2D vector.}} %1 = xegpu.dpas_mx %a, %b : (vector<8x16xf8E5M2>, vector<256xf8E5M2>) -> vector<8x16xf32> return } // ----- func.func @dpas_mx_result_not_2d(%a : vector<8x16xf8E5M2>, %b: vector<16x16xf8E5M2>) { - // expected-error@+1 {{Result must be a 2D vector.}} + // expected-error@+1 {{Result must be at least a 2D vector.}} %1 = xegpu.dpas_mx %a, %b : (vector<8x16xf8E5M2>, vector<16x16xf8E5M2>) -> vector<128xf32> return } diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index 198dc64d2814b..5cc74cdf7c9e4 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -649,4 +649,28 @@ gpu.func @dpas_mx(%a : vector<8x32xf8E5M2>, %b: vector<32x16xf8E5M2>, %acc: vect gpu.return } +// Batched dpas: the leading dim (4) is a batch dim shared by A, B, and the result. +// CHECK-LABEL: gpu.func @dpas_3d_batch +gpu.func @dpas_3d_batch(%a : vector<4x8x16xf16>, %b: vector<4x16x16xf16>) { + // CHECK: %{{.+}} = xegpu.dpas %{{.+}}, %{{.+}} : vector<4x8x16xf16>, vector<4x16x16xf16> -> vector<4x8x16xf32> + %1 = xegpu.dpas %a, %b: vector<4x8x16xf16>, vector<4x16x16xf16> -> vector<4x8x16xf32> + gpu.return +} + +// Batched dpas with an accumulator that carries the same leading batch dim (4). +// CHECK-LABEL: gpu.func @dpas_3d_batch_with_acc +gpu.func @dpas_3d_batch_with_acc(%a : vector<4x8x16xf16>, %b: vector<4x16x16xf16>, %acc: vector<4x8x16xf32>) { + // CHECK: %{{.+}} = xegpu.dpas %{{.+}}, %{{.+}}, %{{.+}} : vector<4x8x16xf16>, vector<4x16x16xf16>, vector<4x8x16xf32> -> vector<4x8x16xf32> + %1 = xegpu.dpas %a, %b, %acc : vector<4x8x16xf16>, vector<4x16x16xf16>, vector<4x8x16xf32> -> vector<4x8x16xf32> + gpu.return +} + +// Batched dpas_mx: scale_a and scale_b carry the same leading batch dim (4) as A and B. 
+// CHECK-LABEL: gpu.func @dpas_mx_3d_batch +gpu.func @dpas_mx_3d_batch(%a : vector<4x8x32xf8E5M2>, %b: vector<4x32x16xf8E5M2>, %acc: vector<4x8x16xbf16>, %a_scale: vector<4x8x1xf8E8M0FNU>, %b_scale: vector<4x1x16xf8E8M0FNU>) { + // CHECK: %{{.+}} = xegpu.dpas_mx %{{.+}}, %{{.+}}, %{{.+}} scale_a = %{{.+}} scale_b = %{{.+}} : (vector<4x8x32xf8E5M2>, vector<4x32x16xf8E5M2>, vector<4x8x16xbf16>, vector<4x8x1xf8E8M0FNU>, 
vector<4x1x16xf8E8M0FNU>) -> vector<4x8x16xbf16> + %1 = xegpu.dpas_mx %a, %b, %acc scale_a = %a_scale scale_b = %b_scale : (vector<4x8x32xf8E5M2>, vector<4x32x16xf8E5M2>, vector<4x8x16xbf16>, vector<4x8x1xf8E8M0FNU>, vector<4x1x16xf8E8M0FNU>) -> vector<4x8x16xbf16> + gpu.return +} + }