Skip to content

Conversation

momchil-velikov
Copy link
Collaborator

Supersedes #135358

@llvmbot
Copy link
Member

llvmbot commented Apr 14, 2025

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir-sve

Author: Momchil Velikov (momchil-velikov)

Changes

Supersedes #135358


Full diff: https://github.com/llvm/llvm-project/pull/135634.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td (+32)
  • (modified) mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp (+4)
  • (modified) mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir (+12)
  • (modified) mlir/test/Dialect/ArmSVE/roundtrip.mlir (+11)
  • (modified) mlir/test/Target/LLVMIR/arm-sve.mlir (+12)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 1a59062ccc93d..da2a8f89b4cfd 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -273,6 +273,34 @@ def UmmlaOp : ArmSVE_Op<"ummla",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
+def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
+                                    AllTypesMatch<["src1", "src2"]>,
+                                    AllTypesMatch<["acc", "dst"]>]> {
+  let summary = "Matrix-matrix multiply and accumulate op";
+  let description = [{
+    USMMLA: Unsigned by signed integer matrix multiply-accumulate.
+
+    The unsigned by signed integer matrix multiply-accumulate operation
+    multiplies the 2×8 matrix of unsigned 8-bit integer values held
+    the first source vector by the 8×2 matrix of signed 8-bit integer
+    values in the second source vector. The resulting 2×2 widened 32-bit
+    integer matrix product is then added to the 32-bit integer matrix
+    accumulator.
+
+    Source:
+    https://developer.arm.com/documentation/100987/0000
+  }];
+  // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+  let arguments = (ins
+          ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src2
+  );
+  let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
 class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
       "expected corresponding svbool type widened to [16]xi1",
       lhsArg, rhsArg,
@@ -568,6 +596,10 @@ def SmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"smmla">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
+def UsmmlaIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
+
 def SdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdot">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index fe13ed03356b2..b1846e15196fc 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -24,6 +24,7 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
+using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
 using DupQLaneLowering =
     OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
 using ScalableMaskedAddIOpLowering =
@@ -194,6 +195,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                SmmlaOpLowering,
                UdotOpLowering,
                UmmlaOpLowering,
+               UsmmlaOpLowering,
                DupQLaneLowering,
                ScalableMaskedAddIOpLowering,
                ScalableMaskedAddFOpLowering,
@@ -222,6 +224,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     SmmlaIntrOp,
                     UdotIntrOp,
                     UmmlaIntrOp,
+                    UsmmlaIntrOp,
                     DupQLaneIntrOp,
                     ScalableMaskedAddIIntrOp,
                     ScalableMaskedAddFIntrOp,
@@ -242,6 +245,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
                       SmmlaOp,
                       UdotOp,
                       UmmlaOp,
+                      UsmmlaOp,
                       DupQLaneOp,
                       ScalableMaskedAddIOp,
                       ScalableMaskedAddFOp,
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 5d044517e0ea8..47587aa26506c 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -48,6 +48,18 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
 
 // -----
 
+func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
+                    %b: vector<[16]xi8>,
+                    %c: vector<[4]xi32>)
+    -> vector<[4]xi32> {
+  // CHECK: arm_sve.intr.usmmla
+  %0 = arm_sve.usmmla %c, %a, %b :
+               vector<[16]xi8> to vector<[4]xi32>
+  return %0 : vector<[4]xi32>
+}
+
+// -----
+
 func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
                             %b: vector<[4]xi32>,
                             %c: vector<[4]xi32>,
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 0f0c5a8575772..64e0cff39eb06 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -44,6 +44,17 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
 
 // -----
 
+func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
+                    %b: vector<[16]xi8>,
+                    %c: vector<[4]xi32>) -> vector<[4]xi32> {
+  // CHECK: arm_sve.usmmla {{.*}}: vector<[16]xi8> to vector<[4]xi3
+  %0 = arm_sve.usmmla %c, %a, %b :
+             vector<[16]xi8> to vector<[4]xi32>
+  return %0 : vector<[4]xi32>
+}
+
+// -----
+
 func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
                             %b: vector<[4]xi32>,
                             %c: vector<[4]xi32>,
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index ced59eb513b57..4d9b0da611cb0 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -48,6 +48,18 @@ llvm.func @arm_sve_ummla(%arg0: vector<[16]xi8>,
   llvm.return %0 : vector<[4]xi32>
 }
 
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_usmmla
+llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
+                         %arg1: vector<[16]xi8>,
+                         %arg2: vector<[4]xi32>)
+                         -> vector<[4]xi32> {
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.usmmla.nxv4i32(<vscale x 4
+  %0 = "arm_sve.intr.usmmla"(%arg2, %arg0, %arg1) :
+    (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
+        -> vector<[4]xi32>
+  llvm.return %0 : vector<[4]xi32>
+}
+
 // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
 llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
                           %arg1: vector<[4]xi32>,

@llvmbot
Copy link
Member

llvmbot commented Apr 14, 2025

@llvm/pr-subscribers-mlir

Author: Momchil Velikov (momchil-velikov)

Changes

Supersedes #135358


Full diff: https://github.com/llvm/llvm-project/pull/135634.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td (+32)
  • (modified) mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp (+4)
  • (modified) mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir (+12)
  • (modified) mlir/test/Dialect/ArmSVE/roundtrip.mlir (+11)
  • (modified) mlir/test/Target/LLVMIR/arm-sve.mlir (+12)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 1a59062ccc93d..da2a8f89b4cfd 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -273,6 +273,34 @@ def UmmlaOp : ArmSVE_Op<"ummla",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
+def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
+                                    AllTypesMatch<["src1", "src2"]>,
+                                    AllTypesMatch<["acc", "dst"]>]> {
+  let summary = "Matrix-matrix multiply and accumulate op";
+  let description = [{
+    USMMLA: Unsigned by signed integer matrix multiply-accumulate.
+
+    The unsigned by signed integer matrix multiply-accumulate operation
+    multiplies the 2×8 matrix of unsigned 8-bit integer values held
+    the first source vector by the 8×2 matrix of signed 8-bit integer
+    values in the second source vector. The resulting 2×2 widened 32-bit
+    integer matrix product is then added to the 32-bit integer matrix
+    accumulator.
+
+    Source:
+    https://developer.arm.com/documentation/100987/0000
+  }];
+  // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+  let arguments = (ins
+          ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src2
+  );
+  let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
 class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
       "expected corresponding svbool type widened to [16]xi1",
       lhsArg, rhsArg,
@@ -568,6 +596,10 @@ def SmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"smmla">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
+def UsmmlaIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
+
 def SdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdot">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index fe13ed03356b2..b1846e15196fc 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -24,6 +24,7 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
+using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
 using DupQLaneLowering =
     OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
 using ScalableMaskedAddIOpLowering =
@@ -194,6 +195,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                SmmlaOpLowering,
                UdotOpLowering,
                UmmlaOpLowering,
+               UsmmlaOpLowering,
                DupQLaneLowering,
                ScalableMaskedAddIOpLowering,
                ScalableMaskedAddFOpLowering,
@@ -222,6 +224,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     SmmlaIntrOp,
                     UdotIntrOp,
                     UmmlaIntrOp,
+                    UsmmlaIntrOp,
                     DupQLaneIntrOp,
                     ScalableMaskedAddIIntrOp,
                     ScalableMaskedAddFIntrOp,
@@ -242,6 +245,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
                       SmmlaOp,
                       UdotOp,
                       UmmlaOp,
+                      UsmmlaOp,
                       DupQLaneOp,
                       ScalableMaskedAddIOp,
                       ScalableMaskedAddFOp,
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 5d044517e0ea8..47587aa26506c 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -48,6 +48,18 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
 
 // -----
 
+func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
+                    %b: vector<[16]xi8>,
+                    %c: vector<[4]xi32>)
+    -> vector<[4]xi32> {
+  // CHECK: arm_sve.intr.usmmla
+  %0 = arm_sve.usmmla %c, %a, %b :
+               vector<[16]xi8> to vector<[4]xi32>
+  return %0 : vector<[4]xi32>
+}
+
+// -----
+
 func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
                             %b: vector<[4]xi32>,
                             %c: vector<[4]xi32>,
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 0f0c5a8575772..64e0cff39eb06 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -44,6 +44,17 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
 
 // -----
 
+func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
+                    %b: vector<[16]xi8>,
+                    %c: vector<[4]xi32>) -> vector<[4]xi32> {
+  // CHECK: arm_sve.usmmla {{.*}}: vector<[16]xi8> to vector<[4]xi3
+  %0 = arm_sve.usmmla %c, %a, %b :
+             vector<[16]xi8> to vector<[4]xi32>
+  return %0 : vector<[4]xi32>
+}
+
+// -----
+
 func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
                             %b: vector<[4]xi32>,
                             %c: vector<[4]xi32>,
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index ced59eb513b57..4d9b0da611cb0 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -48,6 +48,18 @@ llvm.func @arm_sve_ummla(%arg0: vector<[16]xi8>,
   llvm.return %0 : vector<[4]xi32>
 }
 
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_usmmla
+llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
+                         %arg1: vector<[16]xi8>,
+                         %arg2: vector<[4]xi32>)
+                         -> vector<[4]xi32> {
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.usmmla.nxv4i32(<vscale x 4
+  %0 = "arm_sve.intr.usmmla"(%arg2, %arg0, %arg1) :
+    (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
+        -> vector<[4]xi32>
+  llvm.return %0 : vector<[4]xi32>
+}
+
 // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
 llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
                           %arg1: vector<[4]xi32>,

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % nit

Thanks!

@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/svusmmla branch from 71e2f13 to 5e91c2e Compare April 15, 2025 15:54
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/svdupq-lane branch from 02e68ee to 7e72ad9 Compare April 15, 2025 15:54
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/svdupq-lane branch from 7e72ad9 to dbf1aa0 Compare May 14, 2025 17:16
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/svusmmla branch 2 times, most recently from 5282373 to e60ca5a Compare May 15, 2025 12:45
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/svdupq-lane branch from dbf1aa0 to 7c36d36 Compare May 15, 2025 12:45
Base automatically changed from users/momchil-velikov/svdupq-lane to main May 16, 2025 15:47
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/svusmmla branch from e60ca5a to 9eee3ad Compare May 16, 2025 15:55
@momchil-velikov momchil-velikov merged commit e9c9c33 into main May 16, 2025
11 checks passed
@momchil-velikov momchil-velikov deleted the users/momchil-velikov/svusmmla branch May 16, 2025 16:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants