-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ArmSME] Support 2-way widening outer products #78975
[mlir][ArmSME] Support 2-way widening outer products #78975
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir-sme Author: Cullen Rhodes (c-rhodes) ChangesThis patch introduces support for widening outer products. This enables the folding of 2 or 4 'arm_sme.outerproduct' operations that are chained via the accumulator into single widened operations. Changes:
For a detailed description of these operations see the 'arm_sme.fmopa_wide_2way' and 'arm_sme.smopa_wide_4way' descriptions. The reason for introducing many operations rather than one is the signed/unsigned variants can't be distinguished with types (e.g., ui16, si16) since 'arith.extui' and 'arith.extsi' only support signless integers. A single operation would require this information and an attribute (for example) for the sign doesn't feel right if floating-point types are also supported where this wouldn't apply. Furthermore, the SME FP8 extensions (FEAT_SME_F8F16, FEAT_SME_F8F32) introduce FMOPA 2-way (FP8 to FP16) and 4-way (FP8 to FP32) variants but no subtract variant. Whilst these are not supported in this patch, it felt simpler to have separate ops for add/subtract given this. Patch is 154.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78975.diff 18 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index d85ef963ae5dc46..f051e03efbcda64 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -105,6 +105,10 @@ def LLVM_aarch64_sme_sumopa_wide : ArmSME_IntrMopOverloadedOp<"sumopa.wide">;
def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
+def LLVM_aarch64_sme_smopa_za32 : ArmSME_IntrMopOverloadedOp<"smopa.za32">;
+def LLVM_aarch64_sme_umopa_za32 : ArmSME_IntrMopOverloadedOp<"umopa.za32">;
+def LLVM_aarch64_sme_smops_za32 : ArmSME_IntrMopOverloadedOp<"smops.za32">;
+def LLVM_aarch64_sme_umops_za32 : ArmSME_IntrMopOverloadedOp<"umops.za32">;
class ArmSME_IntrLoadStoreOp<string mnemonic>
: ArmSME_IntrOp<mnemonic,
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 8a34ad7e52012fe..1365dd38c115ef2 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -736,6 +736,7 @@ class OuterProductResultTileTypeConstraint<string operand> :
def OuterProductOp :
ArmSME_Op<"outerproduct", [
+ Pure,
ArmSMETileOpInterface,
AttrSizedOperandSegments,
AllTypesMatch<["lhs", "rhs"]>,
@@ -814,6 +815,650 @@ let arguments = (ins
}];
}
+class OuterProductWideBase<string mnemonic,
+ list<Type> allowedInputVectorTypes,
+ list<Type> allowedResultVectorTypes,
+ int numOuterProducts> :
+ ArmSME_Op<mnemonic, [
+ Pure,
+ ArmSMETileOpInterface,
+ AttrSizedOperandSegments,
+ AllTypesMatch<["lhs", "rhs"]>,
+ HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
+ HasMatchingMaskTypeConstraint<"rhs", "rhsMask">,
+ PredOpTrait<
+ "both `lhsMask` and `rhsMask` should be provided or neither",
+ CPred<"bool(getLhsMask()) == bool(getRhsMask())">
+ >,
+ OptionalTypesMatchWith<"result and acc have the same type",
+ "result", "acc", "::llvm::cast<Type>($_self)">,
+ // this trait ensures the input type match the correct output type for ops
+ // that takes multiple inputs and outputs (i.e., 4-way).
+ PredOpTrait<
+ "tile element size equals lhs element size * " # numOuterProducts,
+ CPred<"getTileType().getElementTypeBitWidth() == "
+ "(getLhsType().getElementTypeBitWidth() * " # numOuterProducts # ")">
+ >,
+ ]> {
+
+ let arguments = (ins
+ AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVector:$rhs,
+ Optional<AnyVector>:$lhsMask, Optional<AnyVector>:$rhsMask,
+ Optional<AnyVector>:$acc);
+ let results = (outs AnyTypeOf<allowedResultVectorTypes>:$result);
+
+ let assemblyFormat = [{
+ $lhs `,` $rhs
+ oilist(
+ `acc` `` `(` $acc `)`
+ | `masks` `` `(` $lhsMask `,` $rhsMask `)`
+ ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)
+ }];
+
+ let extraClassDeclaration = [{
+ VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
+ VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
+ VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
+ std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+ // The outerproduct op allocates a new tile if no accumulator is passed.
+ if (!getAcc())
+ return arm_sme::getSMETileType(getResultType());
+ return std::nullopt;
+ }
+ VectorType getTileType() {
+ return getResultType();
+ }
+ }];
+}
+
+class OuterProductWide2Way<string mnemonic,
+ list<Type> allowedInputVectorTypes,
+ list<Type> allowedResultVectorTypes>
+ : OuterProductWideBase<mnemonic, allowedInputVectorTypes,
+ allowedResultVectorTypes, /*numOuterProducts=*/2>;
+
+class OuterProductWide4Way<string mnemonic,
+ list<Type> allowedInputVectorTypes,
+ list<Type> allowedResultVectorTypes>
+ : OuterProductWideBase<mnemonic, allowedInputVectorTypes,
+ allowedResultVectorTypes, /*numOuterProducts=*/4>;
+
+def FMopaWide2WayOp
+ : OuterProductWide2Way<"fmopa_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
+ [nxnxv4f32]> {
+ let summary = "Floating-point sum of 2 outer products and accumulate";
+
+ let description = [{
+ This operation represents a sum of 2 widened outer products. It takes 2 1-D
+ scalable vectors as input and a 2-D scalable vector (ZA tile) as output.
+
+ For example (fp16 to fp32):
+
+ ```mlir
+ %result = arm_sme.fmopa_wide_2way %lhs, %rhs :
+ vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ ```
+
+ The `lhs` encodes a matrix of shape SVLSx2 and the `rhs` a matrix of
+ 2xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
+ elements in a vector of SVL bits. To illustrate, below is a breakdown of
+ this operation for SVL=128 (i.e., vscale=1):
+
+ ```
+ LHS RHS
+ [A0 A1 A2 A3 A4 A5 A6 A7] [B0 B1 B2 B3 B4 B5 B6 B7]
+
+ ----------------------------------------------------------------------------
+
+ implicit layout
+
+ [A0 A1] |
+ [A2 A3] | [B0 B2 B4 B6]
+ [A4 A5] | [B1 B3 B5 B7]
+ [A6 A7] |
+
+ ----------------------------------------------------------------------------
+
+ 2 outer products
+
+ Acol0 ⊗ Brow0 | Acol1 ⊗ Brow1
+ ------------- | -------------
+ |
+ [B0 B2 B4 B6] | [B1 B3 B5 B7]
+ |
+ [A0 [A0B0 A0B2 A0B4 A0B6] | [A1 [A1B1 A1B3 A1B5 A1B7]
+ A2 [A2B0 A2B2 A2B4 A2B6] | A3 [A3B1 A3B3 A3B5 A3B7]
+ A4 [A4B0 A4B2 A4B4 A4B6] | A5 [A5B1 A5B3 A5B5 A5B7]
+ A6] [A6B0 A6B2 A6B4 A6B6] | A7] [A7B1 A7B3 A7B5 A7B7]
+ |
+
+ ----------------------------------------------------------------------------
+
+ sum of 2 outer products
+
+ Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1
+
+ [A0B0 + A1B1 A0B2 + A1B3 A0B4 + A1B5 A0B6 + A1B7]
+ [A2B0 + A3B1 A2B2 + A3B3 A2B4 + A3B5 A2B6 + A3B7]
+ [A4B0 + A5B1 A4B2 + A5B3 A4B4 + A5B5 A4B6 + A5B7]
+ [A6B0 + A7B1 A6B2 + A7B3 A6B4 + A7B5 A6B6 + A7B7]
+
+ ----------------------------------------------------------------------------
+ ```
+
+ This operation enables the folding of 2 outer products chained via the
+ accumulator into a single outer product.
+
+ For example:
+
+ ```mlir
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+ %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+ %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
+ ```
+
+ The 2 outer products in the example above can be fused into a single outer
+ product as follows:
+
+ ```mlir
+ %undef = llvm.mlir.undef : vector<[8]xf16>
+ %a0_ins = vector.scalable.ins %a0_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %a1_ins = vector.scalable.ins %a1_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %a_packed = "arm_sve.intr.zip1"(%a0_ins, %a1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %b0_ins = vector.scalable.ins %b0_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %b1_ins = vector.scalable.ins %b1_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %b_packed = "arm_sve.intr.zip1"(%b0_ins, %b1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ ```
+
+ This is implemented in the `-arm-sme-outer-product-widening` pass.
+
+ Example: FP16 to FP32
+ ```mlir
+ %result = arm_sme.fmopa_wide_2way $lhs, $rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ ```
+
+ Example: BF16 to FP32
+ ```mlir
+ %result = arm_sme.fmopa_wide_2way $lhs, $rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+ ```
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [FMOPA (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPA--widening--2-way--FP16-to-FP32---Half-precision-floating-point-sum-of-outer-products-and-accumulate-) | +sme |
+ | [BFMOPA (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BFMOPA--widening---BFloat16-sum-of-outer-products-and-accumulate-) | +sme |
+me
+ [1] https://developer.arm.com/documentation/ddi0616
+ }];
+}
+
+// TODO: support:
+// - FMOPA 2-way FP8 to FP16
+// - FMOPA 4-way FP16 to FP32
+// once intrinsic support lands in the backend.
+
+def FMopsWide2WayOp
+ : OuterProductWide2Way<"fmops_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
+ [nxnxv4f32]> {
+ let summary = "Floating-point sum of 2 outer products and subtract";
+ let description = [{
+ Equivalent to `fmopa_wide_2way` but outer products are subtracted from
+ destination `result`.
+
+ Example: FP16 to FP32
+ ```mlir
+ %result = arm_sme.fmops_wide_2way $lhs, $rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ ```
+
+ Example: BF16 to FP32
+ ```mlir
+ %result = arm_sme.fmops_wide_2way $lhs, $rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+
+ Refer to
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+ detailed description of 2-way outer products.
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [FMOPS (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPS--widening---Half-precision-floating-point-sum-of-outer-products-and-subtract-) | +sme |
+ | [BFMOPS (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BMOPS--Bitwise-exclusive-NOR-population-count-outer-product-and-subtract-) | +sme |
+ ```
+ }];
+}
+
+def SMopaWide2WayOp
+ : OuterProductWide2Way<"smopa_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32]> {
+ let summary = "Signed integer sum of 2 outer products and accumulate";
+ let description = [{
+ Example:
+ ```mlir
+ %result = arm_sme.smopa_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+
+ Refer to
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+ detailed description of 2-way outer products.
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [SMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPA--2-way---Signed-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
+ ```
+ }];
+}
+
+def SMopsWide2WayOp
+ : OuterProductWide2Way<"smops_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32]> {
+ let summary = "Signed integer sum of 2 outer products and subtract";
+ let description = [{
+ Example:
+ ```mlir
+ %result = arm_sme.smops_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+
+ Refer to
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+ detailed description of 2-way outer products.
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [SMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPS--2-way---Signed-integer-sum-of-outer-products-and-subtract-) | +sme2 |
+ ```
+ }];
+}
+
+def UMopaWide2WayOp
+ : OuterProductWide2Way<"umopa_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32]> {
+ let summary = "Unsiged integer sum of 2 outer products and accumulate";
+ let description = [{
+ Example:
+ ```mlir
+ %result = arm_sme.umopa_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+
+ Refer to
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+ detailed description of 2-way outer products.
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [UMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPA--2-way---Unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
+ ```
+ }];
+}
+
+def UMopsWide2WayOp
+ : OuterProductWide2Way<"umops_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32]> {
+ let summary = "Unsiged integer sum of 2 outer products and subtract";
+ let description = [{
+ Example:
+ ```mlir
+ %result = arm_sme.umops_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+
+ Refer to
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+ detailed description of 2-way outer products.
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [UMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPS--2-way---Unsigned-integer-sum-of-outer-products-and-subtract-) | +sme2 |
+ ```
+ }];
+}
+
+def SMopaWide4WayOp
+ : OuterProductWide4Way<"smopa_wide_4way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+ ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32, nxnxv2i64]> {
+ let summary = "Signed integer sum of 4 outer products and accumulate";
+ let description = [{
+ This operation represents a sum of 4 widened outer products. It takes 2 1-D
+ scalable vectors as input and a 2-D scalable vector (ZA tile) as output.
+
+ For example (i8 to i32):
+
+ ```mlir
+ %result = arm_sme.smopa_wide_4way $lhs, $rhs :
+ vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ ```
+
+ The `lhs` encodes a matrix of shape SVLSx4 and the `rhs` a matrix of
+ 4xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
+ elements in a vector of SVL bits. To illustrate, below is a breakdown of
+ this operation for SVL=128 (i.e., vscale=1):
+
+ ```
+ LHS
+ [A0 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 A12 A15 A14 A15]
+
+ RHS
+ [B0 B1 B2 B3 B4 B5 B6 B7 B8 B9 B10 B11 B12 B13 B14 B15]
+
+ ----------------------------------------------------------------------------
+
+ implicit layout
+
+ [A0 A1 A2 A3] | [B0 B4 B8 B12]
+ [A4 A5 A6 A7] | [B1 B5 B9 B13]
+ [A8 A9 A10 A11] | [B2 B6 B10 B14]
+ [A12 A13 A14 A15] | [B3 B7 B11 B15]
+
+ ----------------------------------------------------------------------------
+
+ 4 outer products
+
+ Acol0 ⊗ Brow0 | Acol1 ⊗ Brow1
+ ------------- | -------------
+ |
+ [B0 B4 B8 B12] | [B1 B5 B9 B13]
+ |
+ [A0 [ A0B0 A0B4 A0B8 A0B12] | [A1 [ A1B1 A1B5 A1B9 A1B13]
+ A4 [ A4B0 A4B4 A4B8 A4B12] | A5 [ A5B1 A5B5 A5B9 A5B13]
+ A8 [ A8B0 A8B4 A8B8 A8B12] | A9 [ A9B1 A9B5 A9B9 A9B13]
+ A12] [A12B0 A12B4 A12B8 A12B12] | A13] [A13B1 A13B5 A13B9 A13B13]
+ |
+ Acol2 ⊗ Brow2 | Acol3 ⊗ Brow3
+ ------------- | -------------
+ |
+ [B2, B6, B10, B14] | [B3 B7 B11 B15]
+ |
+ [A2 [ A2B2 A2B6 A2B10 A2B14] | [A3 [ A3B3 A3B7 A3B11 A3B15]
+ A6 [ A6B2 A6B6 A6B10 A6B14] | A7 [ A7B3 A7B7 A7B11 A7B15]
+ A10 [A10B2 A10B6 A10B10 A10B14] | A11 [A11B3 A11B7 A11B11 A11B15]
+ A14] [A14B2 A14B6 A14B10 A14B14] | A15] [A15B3 A15B7 A15B11 A15B15]
+ |
+
+ ----------------------------------------------------------------------------
+
+ sum of 4 outer products
+
+ Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1 + Acol2 ⊗ Brow2 + Acol3 ⊗ Brow3
+
+ [ A0B0 + A1B1 + A2B2 + A3B3 ... ... A0B12 + A1B13 + A2B14 + A3B15]
+ [ A4B0 + A5B1 + A6B2 + A7B3 ... ... A4B12 + A5B13 + A6B14 + A7B15]
+ [ A8B0 + A9B1 + A10B2 + A11B3 ... ... A8B12 + A9B13 + A10B14 + A11B15]
+ [A12B0 + A13B1 + A14B2 + A15B3 ... ... A12B12 + A13B13 + A14B14 + A15B15]
+
+ ----------------------------------------------------------------------------
+ ```
+
+ This operation enables the folding of 4 outer products chained via the
+ accumulator into a single outer product.
+
+ For example:
+
+ ```mlir
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+ ```
+
+ The 4 outer products in the example above can be fused into a single outer
+ product as follows:
+
+ ```mlir
+ %undef = llvm.mlir.undef : vector<[8]xf16>
+ %a0_ins = vector.scalable.ins %a0_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %a1_ins = vector.scalable.ins %a1_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %a2_ins = vector.scalable.ins %a2_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %a3_ins = vector.scalable.ins %a3_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %lhs0 = "arm_sve.intr.zip1"(%a0_ins, %a2_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %lhs1 = "arm_sve.intr.zip1"(%a1_ins, %a3_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %lhs = "arm_sve.intr.zip1"(%lhs0, %lhs1) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+
+ %b0_ins = vector.scalable.ins %b0_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %b1_ins = vector.scalable.ins %b1_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %b2_ins = vector.scalable.ins %b2_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %b3_ins = vector.scalable.ins %b3_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %rhs0 = "arm_sve.intr.zip1"(%b0_ins, %b2_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %rhs1 = "arm_sve.intr.zip1"(%b1_ins, %b3_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %rhs = "arm_sve.intr.zip1"(%rhs0, %rhs1) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+
+ %0 = arm_sme.smopa_wide_4way %lh...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Cullen Rhodes (c-rhodes) ChangesThis patch introduces support for widening outer products. This enables the folding of 2 or 4 'arm_sme.outerproduct' operations that are chained via the accumulator into single widened operations. Changes:
For a detailed description of these operations see the 'arm_sme.fmopa_wide_2way' and 'arm_sme.smopa_wide_4way' descriptions. The reason for introducing many operations rather than one is the signed/unsigned variants can't be distinguished with types (e.g., ui16, si16) since 'arith.extui' and 'arith.extsi' only support signless integers. A single operation would require this information and an attribute (for example) for the sign doesn't feel right if floating-point types are also supported where this wouldn't apply. Furthermore, the SME FP8 extensions (FEAT_SME_F8F16, FEAT_SME_F8F32) introduce FMOPA 2-way (FP8 to FP16) and 4-way (FP8 to FP32) variants but no subtract variant. Whilst these are not supported in this patch, it felt simpler to have separate ops for add/subtract given this. Patch is 154.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78975.diff 18 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index d85ef963ae5dc46..f051e03efbcda64 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -105,6 +105,10 @@ def LLVM_aarch64_sme_sumopa_wide : ArmSME_IntrMopOverloadedOp<"sumopa.wide">;
def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
+def LLVM_aarch64_sme_smopa_za32 : ArmSME_IntrMopOverloadedOp<"smopa.za32">;
+def LLVM_aarch64_sme_umopa_za32 : ArmSME_IntrMopOverloadedOp<"umopa.za32">;
+def LLVM_aarch64_sme_smops_za32 : ArmSME_IntrMopOverloadedOp<"smops.za32">;
+def LLVM_aarch64_sme_umops_za32 : ArmSME_IntrMopOverloadedOp<"umops.za32">;
class ArmSME_IntrLoadStoreOp<string mnemonic>
: ArmSME_IntrOp<mnemonic,
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 8a34ad7e52012fe..1365dd38c115ef2 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -736,6 +736,7 @@ class OuterProductResultTileTypeConstraint<string operand> :
def OuterProductOp :
ArmSME_Op<"outerproduct", [
+ Pure,
ArmSMETileOpInterface,
AttrSizedOperandSegments,
AllTypesMatch<["lhs", "rhs"]>,
@@ -814,6 +815,650 @@ let arguments = (ins
}];
}
+class OuterProductWideBase<string mnemonic,
+ list<Type> allowedInputVectorTypes,
+ list<Type> allowedResultVectorTypes,
+ int numOuterProducts> :
+ ArmSME_Op<mnemonic, [
+ Pure,
+ ArmSMETileOpInterface,
+ AttrSizedOperandSegments,
+ AllTypesMatch<["lhs", "rhs"]>,
+ HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
+ HasMatchingMaskTypeConstraint<"rhs", "rhsMask">,
+ PredOpTrait<
+ "both `lhsMask` and `rhsMask` should be provided or neither",
+ CPred<"bool(getLhsMask()) == bool(getRhsMask())">
+ >,
+ OptionalTypesMatchWith<"result and acc have the same type",
+ "result", "acc", "::llvm::cast<Type>($_self)">,
+ // this trait ensures the input type match the correct output type for ops
+ // that takes multiple inputs and outputs (i.e., 4-way).
+ PredOpTrait<
+ "tile element size equals lhs element size * " # numOuterProducts,
+ CPred<"getTileType().getElementTypeBitWidth() == "
+ "(getLhsType().getElementTypeBitWidth() * " # numOuterProducts # ")">
+ >,
+ ]> {
+
+ let arguments = (ins
+ AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVector:$rhs,
+ Optional<AnyVector>:$lhsMask, Optional<AnyVector>:$rhsMask,
+ Optional<AnyVector>:$acc);
+ let results = (outs AnyTypeOf<allowedResultVectorTypes>:$result);
+
+ let assemblyFormat = [{
+ $lhs `,` $rhs
+ oilist(
+ `acc` `` `(` $acc `)`
+ | `masks` `` `(` $lhsMask `,` $rhsMask `)`
+ ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)
+ }];
+
+ let extraClassDeclaration = [{
+ VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
+ VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
+ VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
+ std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+ // The outerproduct op allocates a new tile if no accumulator is passed.
+ if (!getAcc())
+ return arm_sme::getSMETileType(getResultType());
+ return std::nullopt;
+ }
+ VectorType getTileType() {
+ return getResultType();
+ }
+ }];
+}
+
+class OuterProductWide2Way<string mnemonic,
+ list<Type> allowedInputVectorTypes,
+ list<Type> allowedResultVectorTypes>
+ : OuterProductWideBase<mnemonic, allowedInputVectorTypes,
+ allowedResultVectorTypes, /*numOuterProducts=*/2>;
+
+class OuterProductWide4Way<string mnemonic,
+ list<Type> allowedInputVectorTypes,
+ list<Type> allowedResultVectorTypes>
+ : OuterProductWideBase<mnemonic, allowedInputVectorTypes,
+ allowedResultVectorTypes, /*numOuterProducts=*/4>;
+
+def FMopaWide2WayOp
+ : OuterProductWide2Way<"fmopa_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
+ [nxnxv4f32]> {
+ let summary = "Floating-point sum of 2 outer products and accumulate";
+
+ let description = [{
+ This operation represents a sum of 2 widened outer products. It takes 2 1-D
+ scalable vectors as input and a 2-D scalable vector (ZA tile) as output.
+
+ For example (fp16 to fp32):
+
+ ```mlir
+ %result = arm_sme.fmopa_wide_2way %lhs, %rhs :
+ vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ ```
+
+ The `lhs` encodes a matrix of shape SVLSx2 and the `rhs` a matrix of
+ 2xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
+ elements in a vector of SVL bits. To illustrate, below is a breakdown of
+ this operation for SVL=128 (i.e., vscale=1):
+
+ ```
+ LHS RHS
+ [A0 A1 A2 A3 A4 A5 A6 A7] [B0 B1 B2 B3 B4 B5 B6 B7]
+
+ ----------------------------------------------------------------------------
+
+ implicit layout
+
+ [A0 A1] |
+ [A2 A3] | [B0 B2 B4 B6]
+ [A4 A5] | [B1 B3 B5 B7]
+ [A6 A7] |
+
+ ----------------------------------------------------------------------------
+
+ 2 outer products
+
+ Acol0 ⊗ Brow0 | Acol1 ⊗ Brow1
+ ------------- | -------------
+ |
+ [B0 B2 B4 B6] | [B1 B3 B5 B7]
+ |
+ [A0 [A0B0 A0B2 A0B4 A0B6] | [A1 [A1B1 A1B3 A1B5 A1B7]
+ A2 [A2B0 A2B2 A2B4 A2B6] | A3 [A3B1 A3B3 A3B5 A3B7]
+ A4 [A4B0 A4B2 A4B4 A4B6] | A5 [A5B1 A5B3 A5B5 A5B7]
+ A6] [A6B0 A6B2 A6B4 A6B6] | A7] [A7B1 A7B3 A7B5 A7B7]
+ |
+
+ ----------------------------------------------------------------------------
+
+ sum of 2 outer products
+
+ Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1
+
+ [A0B0 + A1B1 A0B2 + A1B3 A0B4 + A1B5 A0B6 + A1B7]
+ [A2B0 + A3B1 A2B2 + A3B3 A2B4 + A3B5 A2B6 + A3B7]
+ [A4B0 + A5B1 A4B2 + A5B3 A4B4 + A5B5 A4B6 + A5B7]
+ [A6B0 + A7B1 A6B2 + A7B3 A6B4 + A7B5 A6B6 + A7B7]
+
+ ----------------------------------------------------------------------------
+ ```
+
+ This operation enables the folding of 2 outer products chained via the
+ accumulator into a single outer product.
+
+ For example:
+
+ ```mlir
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+ %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+ %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
+ ```
+
+ The 2 outer products in the example above can be fused into a single outer
+ product as follows:
+
+ ```mlir
+ %undef = llvm.mlir.undef : vector<[8]xf16>
+ %a0_ins = vector.scalable.ins %a0_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %a1_ins = vector.scalable.ins %a1_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %a_packed = "arm_sve.intr.zip1"(%a0_ins, %a1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %b0_ins = vector.scalable.ins %b0_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %b1_ins = vector.scalable.ins %b1_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %b_packed = "arm_sve.intr.zip1"(%b0_ins, %b1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ ```
+
+ This is implemented in the `-arm-sme-outer-product-widening` pass.
+
+ Example: FP16 to FP32
+ ```mlir
+ %result = arm_sme.fmopa_wide_2way $lhs, $rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ ```
+
+ Example: BF16 to FP32
+ ```mlir
+ %result = arm_sme.fmopa_wide_2way $lhs, $rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+ ```
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [FMOPA (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPA--widening--2-way--FP16-to-FP32---Half-precision-floating-point-sum-of-outer-products-and-accumulate-) | +sme |
+ | [BFMOPA (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BFMOPA--widening---BFloat16-sum-of-outer-products-and-accumulate-) | +sme |
+me
+ [1] https://developer.arm.com/documentation/ddi0616
+ }];
+}
+
+// TODO: support:
+// - FMOPA 2-way FP8 to FP16
+// - FMOPA 4-way FP16 to FP32
+// once intrinsic support lands in the backend.
+
+def FMopsWide2WayOp
+ : OuterProductWide2Way<"fmops_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
+ [nxnxv4f32]> {
+ let summary = "Floating-point sum of 2 outer products and subtract";
+ let description = [{
+ Equivalent to `fmopa_wide_2way` but outer products are subtracted from
+ destination `result`.
+
+ Example: FP16 to FP32
+ ```mlir
+ %result = arm_sme.fmops_wide_2way $lhs, $rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ ```
+
+ Example: BF16 to FP32
+ ```mlir
+ %result = arm_sme.fmops_wide_2way $lhs, $rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+
+ Refer to
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+ detailed description of 2-way outer products.
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [FMOPS (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPS--widening---Half-precision-floating-point-sum-of-outer-products-and-subtract-) | +sme |
+ | [BFMOPS (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BMOPS--Bitwise-exclusive-NOR-population-count-outer-product-and-subtract-) | +sme |
+ ```
+ }];
+}
+
+def SMopaWide2WayOp
+ : OuterProductWide2Way<"smopa_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32]> {
+ let summary = "Signed integer sum of 2 outer products and accumulate";
+ let description = [{
+ Example:
+ ```mlir
+ %result = arm_sme.smopa_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+
+ Refer to
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+ detailed description of 2-way outer products.
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [SMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPA--2-way---Signed-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
+ ```
+ }];
+}
+
+def SMopsWide2WayOp
+ : OuterProductWide2Way<"smops_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32]> {
+ let summary = "Signed integer sum of 2 outer products and subtract";
+ let description = [{
+ Example:
+ ```mlir
+ %result = arm_sme.smops_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+
+ Refer to
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+ detailed description of 2-way outer products.
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [SMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPS--2-way---Signed-integer-sum-of-outer-products-and-subtract-) | +sme2 |
+ ```
+ }];
+}
+
+def UMopaWide2WayOp
+ : OuterProductWide2Way<"umopa_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32]> {
+ let summary = "Unsiged integer sum of 2 outer products and accumulate";
+ let description = [{
+ Example:
+ ```mlir
+ %result = arm_sme.umopa_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+
+ Refer to
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+ detailed description of 2-way outer products.
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [UMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPA--2-way---Unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
+ ```
+ }];
+}
+
+def UMopsWide2WayOp
+ : OuterProductWide2Way<"umops_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32]> {
+ let summary = "Unsiged integer sum of 2 outer products and subtract";
+ let description = [{
+ Example:
+ ```mlir
+ %result = arm_sme.umops_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+
+ Refer to
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+ detailed description of 2-way outer products.
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [UMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPS--2-way---Unsigned-integer-sum-of-outer-products-and-subtract-) | +sme2 |
+ ```
+ }];
+}
+
+def SMopaWide4WayOp
+ : OuterProductWide4Way<"smopa_wide_4way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+ ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32, nxnxv2i64]> {
+ let summary = "Signed integer sum of 4 outer products and accumulate";
+ let description = [{
+ This operation represents a sum of 4 widened outer products. It takes 2 1-D
+ scalable vectors as input and a 2-D scalable vector (ZA tile) as output.
+
+ For example (i8 to i32):
+
+ ```mlir
+ %result = arm_sme.smopa_wide_4way $lhs, $rhs :
+ vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ ```
+
+ The `lhs` encodes a matrix of shape SVLSx4 and the `rhs` a matrix of
+ 4xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
+ elements in a vector of SVL bits. To illustrate, below is a breakdown of
+ this operation for SVL=128 (i.e., vscale=1):
+
+ ```
+ LHS
+ [A0 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 A12 A15 A14 A15]
+
+ RHS
+ [B0 B1 B2 B3 B4 B5 B6 B7 B8 B9 B10 B11 B12 B13 B14 B15]
+
+ ----------------------------------------------------------------------------
+
+ implicit layout
+
+ [A0 A1 A2 A3] | [B0 B4 B8 B12]
+ [A4 A5 A6 A7] | [B1 B5 B9 B13]
+ [A8 A9 A10 A11] | [B2 B6 B10 B14]
+ [A12 A13 A14 A15] | [B3 B7 B11 B15]
+
+ ----------------------------------------------------------------------------
+
+ 4 outer products
+
+ Acol0 ⊗ Brow0 | Acol1 ⊗ Brow1
+ ------------- | -------------
+ |
+ [B0 B4 B8 B12] | [B1 B5 B9 B13]
+ |
+ [A0 [ A0B0 A0B4 A0B8 A0B12] | [A1 [ A1B1 A1B5 A1B9 A1B13]
+ A4 [ A4B0 A4B4 A4B8 A4B12] | A5 [ A5B1 A5B5 A5B9 A5B13]
+ A8 [ A8B0 A8B4 A8B8 A8B12] | A9 [ A9B1 A9B5 A9B9 A9B13]
+ A12] [A12B0 A12B4 A12B8 A12B12] | A13] [A13B1 A13B5 A13B9 A13B13]
+ |
+ Acol2 ⊗ Brow2 | Acol3 ⊗ Brow3
+ ------------- | -------------
+ |
+ [B2, B6, B10, B14] | [B3 B7 B11 B15]
+ |
+ [A2 [ A2B2 A2B6 A2B10 A2B14] | [A3 [ A3B3 A3B7 A3B11 A3B15]
+ A6 [ A6B2 A6B6 A6B10 A6B14] | A7 [ A7B3 A7B7 A7B11 A7B15]
+ A10 [A10B2 A10B6 A10B10 A10B14] | A11 [A11B3 A11B7 A11B11 A11B15]
+ A14] [A14B2 A14B6 A14B10 A14B14] | A15] [A15B3 A15B7 A15B11 A15B15]
+ |
+
+ ----------------------------------------------------------------------------
+
+ sum of 4 outer products
+
+ Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1 + Acol2 ⊗ Brow2 + Acol3 ⊗ Brow3
+
+ [ A0B0 + A1B1 + A2B2 + A3B3 ... ... A0B12 + A1B13 + A2B14 + A3B15]
+ [ A4B0 + A5B1 + A6B2 + A7B3 ... ... A4B12 + A5B13 + A6B14 + A7B15]
+ [ A8B0 + A9B1 + A10B2 + A11B3 ... ... A8B12 + A9B13 + A10B14 + A11B15]
+ [A12B0 + A13B1 + A14B2 + A15B3 ... ... A12B12 + A13B13 + A14B14 + A15B15]
+
+ ----------------------------------------------------------------------------
+ ```
+
+ This operation enables the folding of 4 outer products chained via the
+ accumulator into a single outer product.
+
+ For example:
+
+ ```mlir
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+ ```
+
+ The 4 outer products in the example above can be fused into a single outer
+ product as follows:
+
+ ```mlir
+ %undef = llvm.mlir.undef : vector<[8]xf16>
+ %a0_ins = vector.scalable.ins %a0_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %a1_ins = vector.scalable.ins %a1_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %a2_ins = vector.scalable.ins %a2_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %a3_ins = vector.scalable.ins %a3_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %lhs0 = "arm_sve.intr.zip1"(%a0_ins, %a2_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %lhs1 = "arm_sve.intr.zip1"(%a1_ins, %a3_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %lhs = "arm_sve.intr.zip1"(%lhs0, %lhs1) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+
+ %b0_ins = vector.scalable.ins %b0_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %b1_ins = vector.scalable.ins %b1_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %b2_ins = vector.scalable.ins %b2_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %b3_ins = vector.scalable.ins %b3_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %rhs0 = "arm_sve.intr.zip1"(%b0_ins, %b2_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %rhs1 = "arm_sve.intr.zip1"(%b1_ins, %b3_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %rhs = "arm_sve.intr.zip1"(%rhs0, %rhs1) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+
+ %0 = arm_sme.smopa_wide_4way %lh...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a quick comment, will have a proper look later:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Few more comments, will look again soon, not looked at everything yet :)
Also, not sure if you've noticed but this test is failing in CI:
Failed Tests (1):
MLIR :: Dialect/ArmSME/outer-product-widening.mlir
(Looks like it's crashing during a rewrite)
Optional<AnyVector>:$lhsMask, Optional<AnyVector>:$rhsMask, | ||
Optional<AnyVector>:$acc); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Optional<AnyVector>:$acc
-> Optional<SMETile>:$acc
?
Optional<AnyVector>:$lhsMask
-> Optional<SVEPredicate>
?
Optional<AnyVector>:$rhsMask
-> Optional<SVEPredicate>
?
AnyVector:$rhs
-> SVEVector
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
functionally it makes no difference, this could literally be AnyType
(and perhaps it should), as AFAIK there's no way of expressing with the type alone that lhs
must equal rhs
without using constraints, unless the op only accepted one input type of course, in that case it could be inputType:$lhs, inputType:$rhs
.
It's a bit unfortunate because the auto-generated op documentation will say rhs
is vector of any type values
, which we know isn't true, but a vector type that matches the size of a SVE vector
isn't true either.
e14d7ed
to
bc9df96
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
bc9df96
to
455ae29
Compare
auto extOp = op.getLhs().getDefiningOp(); | ||
VectorType extSourceVectorType = | ||
cast<VectorType>(extOp->getOperand(0).getType()); | ||
VectorType widenedVectorType = | ||
VectorType::Builder(extSourceVectorType) | ||
.setDim(0, extSourceVectorType.getShape()[0] * 2); | ||
auto lhs = packInputs(widenedVectorType, | ||
op1.getLhs().getDefiningOp()->getOperand(0), | ||
op2.getLhs().getDefiningOp()->getOperand(0)); | ||
auto rhs = packInputs(widenedVectorType, | ||
op1.getRhs().getDefiningOp()->getOperand(0), | ||
op2.getRhs().getDefiningOp()->getOperand(0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
May want to check the getDefiningOp()
results, it can return null (e.g. for function arguments, loop iter_args, etc).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that's checked in isWidenable
so think it's ok here?
455ae29
to
2ca481d
Compare
Rebased and updated to use target-agnostic interleave2 intrinsic instead of zip1. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Few more comments.
Overall I'm happy with the overall design and the execution. I am away next week - don't wait for me to +1 this. As long as my comments and comments from @MacDue are addressed, I am happy for this to be merged.
mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir
Outdated
Show resolved
Hide resolved
This patch introduces support for 2-way widening outer products. This enables the folding of 2 'arm_sme.outerproduct' operations that are chained via the accumulator into single widened operations. Changes: - Add 'llvm.aarch64.sme.[us]mop[as].za32' intrinsics for 2-way variants. These map to instruction variants added in SME2 and use different intrinsics. Intrinsics are already implemented for widening variants from SME1. - Adds the following operations: - fmopa_wide_2way, fmops_wide_2way - smopa_wide_2way, smops_wide_2way - umopa_wide_2way, umops_wide_2way - Implements conversions for the above ops to intrinsics in ArmSMEToLLVM. - Adds a pass 'arm-sme-outer-product' widening that folds 'arm_sme.outerproduct' operations. For a detailed description of these operations see the 'arm_sme.fmopa_wide_2way' description. The reason for introducing many operations rather than one is the signed/unsigned variants can't be distinguished with types (e.g., ui16, si16) since 'arith.extui' and 'arith.extsi' only support signless integers. A single operation would require this information and an attribute (for example) for the sign doesn't feel right if floating-point types are also supported where this wouldn't apply. Furthermore, the SME FP8 extensions (FEAT_SME_F8F16, FEAT_SME_F8F32) introduce FMOPA 2-way (FP8 to FP16) and 4-way (FP8 to FP32) variants but no subtract variant. Whilst these are not supported in this patch, it felt simpler to have separate ops for add/subtract given this.
- hasOneUse - Negative tests for pass - move types and isWidening checks to isSupported method - negative test for unsupported type - move masking check - op1.erase() -> rewriter.eraseOp(op1); - braces round dangling-else - TypeSwitch - rename isWidenable and add comments for clarity - add TODO/REDEFINE for QEMU bug to make it clearer - s/widening/fusion/g - drop wide from op names
2ca481d
to
589c0d2
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! I have a few final nits, but nothing major :)
- Backticks in constraint message. - Fix comment. - Remove ArmSVE from libs now zip1 is gone. - rename 'widenedType' to 'inputTypeX2' and move to 'packInputs'. - Use getIn for arith ext ops. - Add test to verify chain of 4 outer products are fused into 2 2-way outer products.
Believe I've addressed all comments now, thanks for reviewing! Andrzej is off this week so I'll land this tomorrow unless there's any further comments by then. Cheers. |
In mixed matmul lowering (e.g., i8 to i32) we're seeing the following sequence: %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32> %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32> %lhs = vector.scalable.extract %1[0] : vector<[4]xi32> from vector<[8]xi32> ... (same for rhs) %2 = vector.outerproduct %lhs, %rhs, %acc vector<[4]xi32>, vector<[4]xi32> // x4 chained by accumulator This chain of 4 outer products can be fused into a single 4-way widening variant but the pass doesn't match on the IR, as it expects the source of the inputs to be an extend and it can't look through the extracts. This patch fixes this with two rewrites that swaps extract(extend) into extend(extract). Related to llvm#78975, llvm#79288.
In mixed matmul lowering (e.g., i8 to i32) we're seeing the following sequence: %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32> %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32> %lhs = vector.scalable.extract %1[0] : vector<[4]xi32> from vector<[8]xi32> ... (same for rhs) %2 = vector.outerproduct %lhs, %rhs, %acc vector<[4]xi32>, vector<[4]xi32> // x4 chained by accumulator This chain of 4 outer products can be fused into a single 4-way widening variant but the pass doesn't match on the IR, as it expects the source of the inputs to be an extend and it can't look through the extracts. This patch fixes this with two rewrites that swaps extract(extend) into extend(extract). Related to llvm#78975, llvm#79288.
In mixed matmul lowering (e.g., i8 to i32) we're seeing the following sequence: %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32> %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32> %lhs = vector.scalable.extract %1[0] : vector<[4]xi32> from vector<[8]xi32> ... (same for rhs) %2 = vector.outerproduct %lhs, %rhs, %acc vector<[4]xi32>, vector<[4]xi32> // x4 chained by accumulator This chain of 4 outer products can be fused into a single 4-way widening variant but the pass doesn't match on the IR, as it expects the source of the inputs to be an extend and it can't look through the extracts. This patch fixes this with two rewrites that swaps extract(extend) into extend(extract). Related to #78975, #79288.
In mixed matmul lowering (e.g., i8 to i32) we're seeing the following sequence: %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32> %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32> %lhs = vector.scalable.extract %1[0] : vector<[4]xi32> from vector<[8]xi32> ... (same for rhs) %2 = vector.outerproduct %lhs, %rhs, %acc vector<[4]xi32>, vector<[4]xi32> // x4 chained by accumulator This chain of 4 outer products can be fused into a single 4-way widening variant but the pass doesn't match on the IR, as it expects the source of the inputs to be an extend and it can't look through the extracts. This patch fixes this with two rewrites that swaps extract(extend) into extend(extract). Related to llvm#78975, llvm#79288.
This patch introduces support for 2-way widening outer products. This
enables the fusion of 2 'arm_sme.outerproduct' operations that are
chained via the accumulator into a 2-way widening outer product operation.
Changes:
These map to instruction variants added in SME2 and use different
intrinsics. Intrinsics are already implemented for widening variants
from SME1.
'arm_sme.outerproduct' operations.
For a detailed description of these operations see the
'arm_sme.fmopa_2way' description.
The reason for introducing many operations rather than one is the
signed/unsigned variants can't be distinguished with types (e.g., ui16,
si16) since 'arith.extui' and 'arith.extsi' only support signless
integers. A single operation would require this information and an
attribute (for example) for the sign doesn't feel right if
floating-point types are also supported where this wouldn't apply.
Furthermore, the SME FP8 extensions (FEAT_SME_F8F16, FEAT_SME_F8F32)
introduce FMOPA 2-way (FP8 to FP16) and 4-way (FP8 to FP32) variants but
no subtract variant. Whilst these are not supported in this patch, it
felt simpler to have separate ops for add/subtract given this.