-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svusmmla #135634
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir-sve Author: Momchil Velikov (momchil-velikov) ChangesSupersedes #135358 Full diff: https://github.com/llvm/llvm-project/pull/135634.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 1a59062ccc93d..da2a8f89b4cfd 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -273,6 +273,34 @@ def UmmlaOp : ArmSVE_Op<"ummla",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
+def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
+ AllTypesMatch<["src1", "src2"]>,
+ AllTypesMatch<["acc", "dst"]>]> {
+ let summary = "Matrix-matrix multiply and accumulate op";
+ let description = [{
+ USMMLA: Unsigned by signed integer matrix multiply-accumulate.
+
+ The unsigned by signed integer matrix multiply-accumulate operation
+ multiplies the 2×8 matrix of unsigned 8-bit integer values held
+ the first source vector by the 8×2 matrix of signed 8-bit integer
+ values in the second source vector. The resulting 2×2 widened 32-bit
+ integer matrix product is then added to the 32-bit integer matrix
+ accumulator.
+
+ Source:
+ https://developer.arm.com/documentation/100987/0000
+ }];
+ // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+ let arguments = (ins
+ ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
+ ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
+ ScalableVectorOfLengthAndType<[16], [I8]>:$src2
+ );
+ let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
+ let assemblyFormat =
+ "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
"expected corresponding svbool type widened to [16]xi1",
lhsArg, rhsArg,
@@ -568,6 +596,10 @@ def SmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"smmla">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
+def UsmmlaIntrOp :
+ ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
+
def SdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index fe13ed03356b2..b1846e15196fc 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -24,6 +24,7 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
+using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
using DupQLaneLowering =
OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
using ScalableMaskedAddIOpLowering =
@@ -194,6 +195,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
SmmlaOpLowering,
UdotOpLowering,
UmmlaOpLowering,
+ UsmmlaOpLowering,
DupQLaneLowering,
ScalableMaskedAddIOpLowering,
ScalableMaskedAddFOpLowering,
@@ -222,6 +224,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaIntrOp,
UdotIntrOp,
UmmlaIntrOp,
+ UsmmlaIntrOp,
DupQLaneIntrOp,
ScalableMaskedAddIIntrOp,
ScalableMaskedAddFIntrOp,
@@ -242,6 +245,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaOp,
UdotOp,
UmmlaOp,
+ UsmmlaOp,
DupQLaneOp,
ScalableMaskedAddIOp,
ScalableMaskedAddFOp,
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 5d044517e0ea8..47587aa26506c 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -48,6 +48,18 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
// -----
+func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
+ %b: vector<[16]xi8>,
+ %c: vector<[4]xi32>)
+ -> vector<[4]xi32> {
+ // CHECK: arm_sve.intr.usmmla
+ %0 = arm_sve.usmmla %c, %a, %b :
+ vector<[16]xi8> to vector<[4]xi32>
+ return %0 : vector<[4]xi32>
+}
+
+// -----
+
func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 0f0c5a8575772..64e0cff39eb06 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -44,6 +44,17 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
// -----
+func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
+ %b: vector<[16]xi8>,
+ %c: vector<[4]xi32>) -> vector<[4]xi32> {
+ // CHECK: arm_sve.usmmla {{.*}}: vector<[16]xi8> to vector<[4]xi3
+ %0 = arm_sve.usmmla %c, %a, %b :
+ vector<[16]xi8> to vector<[4]xi32>
+ return %0 : vector<[4]xi32>
+}
+
+// -----
+
func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index ced59eb513b57..4d9b0da611cb0 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -48,6 +48,18 @@ llvm.func @arm_sve_ummla(%arg0: vector<[16]xi8>,
llvm.return %0 : vector<[4]xi32>
}
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_usmmla
+llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
+ %arg1: vector<[16]xi8>,
+ %arg2: vector<[4]xi32>)
+ -> vector<[4]xi32> {
+ // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.usmmla.nxv4i32(<vscale x 4
+ %0 = "arm_sve.intr.usmmla"(%arg2, %arg0, %arg1) :
+ (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
+ -> vector<[4]xi32>
+ llvm.return %0 : vector<[4]xi32>
+}
+
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
%arg1: vector<[4]xi32>,
|
@llvm/pr-subscribers-mlir Author: Momchil Velikov (momchil-velikov) ChangesSupersedes #135358 Full diff: https://github.com/llvm/llvm-project/pull/135634.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 1a59062ccc93d..da2a8f89b4cfd 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -273,6 +273,34 @@ def UmmlaOp : ArmSVE_Op<"ummla",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
+def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
+ AllTypesMatch<["src1", "src2"]>,
+ AllTypesMatch<["acc", "dst"]>]> {
+ let summary = "Matrix-matrix multiply and accumulate op";
+ let description = [{
+ USMMLA: Unsigned by signed integer matrix multiply-accumulate.
+
+ The unsigned by signed integer matrix multiply-accumulate operation
+ multiplies the 2×8 matrix of unsigned 8-bit integer values held
+ the first source vector by the 8×2 matrix of signed 8-bit integer
+ values in the second source vector. The resulting 2×2 widened 32-bit
+ integer matrix product is then added to the 32-bit integer matrix
+ accumulator.
+
+ Source:
+ https://developer.arm.com/documentation/100987/0000
+ }];
+ // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+ let arguments = (ins
+ ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
+ ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
+ ScalableVectorOfLengthAndType<[16], [I8]>:$src2
+ );
+ let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
+ let assemblyFormat =
+ "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
"expected corresponding svbool type widened to [16]xi1",
lhsArg, rhsArg,
@@ -568,6 +596,10 @@ def SmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"smmla">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
+def UsmmlaIntrOp :
+ ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
+
def SdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index fe13ed03356b2..b1846e15196fc 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -24,6 +24,7 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
+using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
using DupQLaneLowering =
OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
using ScalableMaskedAddIOpLowering =
@@ -194,6 +195,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
SmmlaOpLowering,
UdotOpLowering,
UmmlaOpLowering,
+ UsmmlaOpLowering,
DupQLaneLowering,
ScalableMaskedAddIOpLowering,
ScalableMaskedAddFOpLowering,
@@ -222,6 +224,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaIntrOp,
UdotIntrOp,
UmmlaIntrOp,
+ UsmmlaIntrOp,
DupQLaneIntrOp,
ScalableMaskedAddIIntrOp,
ScalableMaskedAddFIntrOp,
@@ -242,6 +245,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaOp,
UdotOp,
UmmlaOp,
+ UsmmlaOp,
DupQLaneOp,
ScalableMaskedAddIOp,
ScalableMaskedAddFOp,
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 5d044517e0ea8..47587aa26506c 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -48,6 +48,18 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
// -----
+func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
+ %b: vector<[16]xi8>,
+ %c: vector<[4]xi32>)
+ -> vector<[4]xi32> {
+ // CHECK: arm_sve.intr.usmmla
+ %0 = arm_sve.usmmla %c, %a, %b :
+ vector<[16]xi8> to vector<[4]xi32>
+ return %0 : vector<[4]xi32>
+}
+
+// -----
+
func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 0f0c5a8575772..64e0cff39eb06 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -44,6 +44,17 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
// -----
+func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
+ %b: vector<[16]xi8>,
+ %c: vector<[4]xi32>) -> vector<[4]xi32> {
+ // CHECK: arm_sve.usmmla {{.*}}: vector<[16]xi8> to vector<[4]xi3
+ %0 = arm_sve.usmmla %c, %a, %b :
+ vector<[16]xi8> to vector<[4]xi32>
+ return %0 : vector<[4]xi32>
+}
+
+// -----
+
func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index ced59eb513b57..4d9b0da611cb0 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -48,6 +48,18 @@ llvm.func @arm_sve_ummla(%arg0: vector<[16]xi8>,
llvm.return %0 : vector<[4]xi32>
}
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_usmmla
+llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
+ %arg1: vector<[16]xi8>,
+ %arg2: vector<[4]xi32>)
+ -> vector<[4]xi32> {
+ // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.usmmla.nxv4i32(<vscale x 4
+ %0 = "arm_sve.intr.usmmla"(%arg2, %arg0, %arg1) :
+ (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
+ -> vector<[4]xi32>
+ llvm.return %0 : vector<[4]xi32>
+}
+
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
%arg1: vector<[4]xi32>,
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % nit
Thanks!
71e2f13
to
5e91c2e
Compare
02e68ee
to
7e72ad9
Compare
7e72ad9
to
dbf1aa0
Compare
5282373
to
e60ca5a
Compare
dbf1aa0
to
7c36d36
Compare
e60ca5a
to
9eee3ad
Compare
Supersedes #135358