Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CodeGen] Improve ExpandMemCmp for more efficient non-register aligned sizes handling #69942

Merged
merged 3 commits into from
Oct 27, 2023

Conversation

igogo-x86
Copy link
Contributor

  • Enhanced the logic of ExpandMemCmp pass to merge contiguous subsequences
    in LoadSequence, based on sizes allowed in AllowedTailExpansions.
  • This enhancement seeks to minimize the number of basic blocks and produce
    optimized code when using memcmp with non-register aligned sizes.
  • Enable this feature for AArch64 with memcmp sizes modulo 8 equal to
    3, 5, and 6.

@llvmbot
Copy link
Collaborator

llvmbot commented Oct 23, 2023

@llvm/pr-subscribers-llvm-analysis
@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-backend-aarch64

Author: Igor Kirillov (igogo-x86)

Changes
  • Enhanced the logic of ExpandMemCmp pass to merge contiguous subsequences
    in LoadSequence, based on sizes allowed in AllowedTailExpansions.
  • This enhancement seeks to minimize the number of basic blocks and produce
    optimized code when using memcmp with non-register aligned sizes.
  • Enable this feature for AArch64 with memcmp sizes modulo 8 equal to
    3, 5, and 6.

Patch is 147.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69942.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+11)
  • (modified) llvm/lib/CodeGen/ExpandMemCmp.cpp (+59-15)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+1)
  • (added) llvm/test/CodeGen/AArch64/memcmp.ll (+3002)
  • (added) llvm/test/Transforms/ExpandMemCmp/AArch64/memcmp.ll (+877)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 5234ef8788d9e96..3ec80d99b392b2e 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -907,6 +907,17 @@ class TargetTransformInfo {
     // be done with two 4-byte compares instead of 4+2+1-byte compares. This
     // requires all loads in LoadSizes to be doable in an unaligned way.
     bool AllowOverlappingLoads = false;
+
+    // Sometimes, the amount of data that needs to be compared is smaller than
+    // the standard register size, but it cannot be loaded with just one load
+    // instruction. For example, if the size of the memory comparison is 6
+    // bytes, we can handle it more efficiently by loading all 6 bytes in a
+    // single block and generating an 8-byte number, instead of generating two
+    // separate blocks with conditional jumps for 4 and 2 byte loads. This
+    // approach simplifies the process and produces the comparison result as
+    // normal. This array lists the allowed sizes of memcmp tails that can be
+    // merged into one block
+    SmallVector<unsigned, 4> AllowedTailExpansions;
   };
   MemCmpExpansionOptions enableMemCmpExpansion(bool OptSize,
                                                bool IsZeroCmp) const;
diff --git a/llvm/lib/CodeGen/ExpandMemCmp.cpp b/llvm/lib/CodeGen/ExpandMemCmp.cpp
index 911ebd41afc5b91..d9c2c6f5f39ba6d 100644
--- a/llvm/lib/CodeGen/ExpandMemCmp.cpp
+++ b/llvm/lib/CodeGen/ExpandMemCmp.cpp
@@ -117,8 +117,8 @@ class MemCmpExpansion {
     Value *Lhs = nullptr;
     Value *Rhs = nullptr;
   };
-  LoadPair getLoadPair(Type *LoadSizeType, bool NeedsBSwap, Type *CmpSizeType,
-                       unsigned OffsetBytes);
+  LoadPair getLoadPair(Type *LoadSizeType, bool NeedsBSwap, Type *BSwapSizeType,
+                       Type *CmpSizeType, unsigned OffsetBytes);
 
   static LoadEntryVector
   computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
@@ -255,6 +255,31 @@ MemCmpExpansion::MemCmpExpansion(
     }
   }
   assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
+  // This part of code attempts to optimize the LoadSequence by merging allowed
+  // subsequences into single loads of allowed sizes from
+  // `AllowedTailExpansions`. If it is for zero comparison or if no allowed tail
+  // expansions are specified, we exit early.
+  if (IsUsedForZeroCmp || !Options.AllowedTailExpansions.size())
+    return;
+
+  while (LoadSequence.size() >= 2) {
+    auto Last = LoadSequence[LoadSequence.size() - 1];
+    auto PreLast = LoadSequence[LoadSequence.size() - 2];
+
+    // Exit the loop if the two sequences are not contiguous
+    if (PreLast.Offset + PreLast.LoadSize != Last.Offset)
+      break;
+
+    auto LoadSize = Last.LoadSize + PreLast.LoadSize;
+    if (find(Options.AllowedTailExpansions, LoadSize) ==
+        Options.AllowedTailExpansions.end())
+      break;
+
+    // Remove the last two sequences and replace with the combined sequence
+    LoadSequence.pop_back();
+    LoadSequence.pop_back();
+    LoadSequence.emplace_back(PreLast.Offset, LoadSize);
+  }
 }
 
 unsigned MemCmpExpansion::getNumBlocks() {
@@ -279,6 +304,7 @@ void MemCmpExpansion::createResultBlock() {
 
 MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
                                                        bool NeedsBSwap,
+                                                       Type *BSwapSizeType,
                                                        Type *CmpSizeType,
                                                        unsigned OffsetBytes) {
   // Get the memory source at offset `OffsetBytes`.
@@ -307,16 +333,22 @@ MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
   if (!Rhs)
     Rhs = Builder.CreateAlignedLoad(LoadSizeType, RhsSource, RhsAlign);
 
+  // Zero extend if Byte Swap intrinsic has different type
+  if (NeedsBSwap && LoadSizeType != BSwapSizeType) {
+    Lhs = Builder.CreateZExt(Lhs, BSwapSizeType);
+    Rhs = Builder.CreateZExt(Rhs, BSwapSizeType);
+  }
+
   // Swap bytes if required.
   if (NeedsBSwap) {
-    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
-                                                Intrinsic::bswap, LoadSizeType);
+    Function *Bswap = Intrinsic::getDeclaration(
+        CI->getModule(), Intrinsic::bswap, BSwapSizeType);
     Lhs = Builder.CreateCall(Bswap, Lhs);
     Rhs = Builder.CreateCall(Bswap, Rhs);
   }
 
   // Zero extend if required.
-  if (CmpSizeType != nullptr && CmpSizeType != LoadSizeType) {
+  if (CmpSizeType != nullptr && CmpSizeType != Lhs->getType()) {
     Lhs = Builder.CreateZExt(Lhs, CmpSizeType);
     Rhs = Builder.CreateZExt(Rhs, CmpSizeType);
   }
@@ -333,7 +365,7 @@ void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
   Builder.SetInsertPoint(BB);
   const LoadPair Loads =
       getLoadPair(Type::getInt8Ty(CI->getContext()), /*NeedsBSwap=*/false,
-                  Type::getInt32Ty(CI->getContext()), OffsetBytes);
+                  nullptr, Type::getInt32Ty(CI->getContext()), OffsetBytes);
   Value *Diff = Builder.CreateSub(Loads.Lhs, Loads.Rhs);
 
   PhiRes->addIncoming(Diff, BB);
@@ -385,11 +417,12 @@ Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
   IntegerType *const MaxLoadType =
       NumLoads == 1 ? nullptr
                     : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
+
   for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
     const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
     const LoadPair Loads = getLoadPair(
         IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8),
-        /*NeedsBSwap=*/false, MaxLoadType, CurLoadEntry.Offset);
+        /*NeedsBSwap=*/false, nullptr, MaxLoadType, CurLoadEntry.Offset);
 
     if (NumLoads != 1) {
       // If we have multiple loads per block, we need to generate a composite
@@ -475,14 +508,18 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
 
   Type *LoadSizeType =
       IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
-  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
+  Type *BSwapSizeType = IntegerType::get(
+      CI->getContext(), PowerOf2Ceil(CurLoadEntry.LoadSize * 8));
+  Type *MaxLoadType = IntegerType::get(
+      CI->getContext(),
+      std::max(MaxLoadSize, (unsigned)PowerOf2Ceil(CurLoadEntry.LoadSize)) * 8);
   assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");
 
   Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
 
   const LoadPair Loads =
-      getLoadPair(LoadSizeType, /*NeedsBSwap=*/DL.isLittleEndian(), MaxLoadType,
-                  CurLoadEntry.Offset);
+      getLoadPair(LoadSizeType, /*NeedsBSwap=*/DL.isLittleEndian(),
+                  BSwapSizeType, MaxLoadType, CurLoadEntry.Offset);
 
   // Add the loaded values to the phi nodes for calculating memcmp result only
   // if result is not used in a zero equality.
@@ -588,19 +625,26 @@ Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
 /// the compare, branch, and phi IR that is required in the general case.
 Value *MemCmpExpansion::getMemCmpOneBlock() {
   Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
+  Type *BSwapSizeType =
+      IntegerType::get(CI->getContext(), PowerOf2Ceil(Size * 8));
+  Type *MaxLoadType =
+      IntegerType::get(CI->getContext(),
+                       std::max(MaxLoadSize, (unsigned)PowerOf2Ceil(Size)) * 8);
+
   bool NeedsBSwap = DL.isLittleEndian() && Size != 1;
 
   // The i8 and i16 cases don't need compares. We zext the loaded values and
   // subtract them to get the suitable negative, zero, or positive i32 result.
   if (Size < 4) {
-    const LoadPair Loads =
-        getLoadPair(LoadSizeType, NeedsBSwap, Builder.getInt32Ty(),
-                    /*Offset*/ 0);
+    const LoadPair Loads = getLoadPair(LoadSizeType, NeedsBSwap, BSwapSizeType,
+                                       Builder.getInt32Ty(),
+                                       /*Offset*/ 0);
     return Builder.CreateSub(Loads.Lhs, Loads.Rhs);
   }
 
-  const LoadPair Loads = getLoadPair(LoadSizeType, NeedsBSwap, LoadSizeType,
-                                     /*Offset*/ 0);
+  const LoadPair Loads =
+      getLoadPair(LoadSizeType, NeedsBSwap, BSwapSizeType, MaxLoadType,
+                  /*Offset*/ 0);
   // The result of memcmp is negative, zero, or positive, so produce that by
   // subtracting 2 extended compare bits: sub (ugt, ult).
   // If a target prefers to use selects to get -1/0/1, they should be able
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index d8a0e68d7123759..388034dd76ca4cb 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -2938,6 +2938,7 @@ AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
   // they may wake up the FP unit, which raises the power consumption.  Perhaps
   // they could be used with no holds barred (-O3).
   Options.LoadSizes = {8, 4, 2, 1};
+  Options.AllowedTailExpansions = {3, 5, 6};
   return Options;
 }
 
diff --git a/llvm/test/CodeGen/AArch64/memcmp.ll b/llvm/test/CodeGen/AArch64/memcmp.ll
new file mode 100644
index 000000000000000..b38acbae1091536
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/memcmp.ll
@@ -0,0 +1,3002 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 3
+; RUN: llc < %s -mtriple=aarch64-unknown-unknown | FileCheck %s
+
+@.str = private constant [513 x i8] c"01234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901\00", align 1
+
+declare dso_local i32 @memcmp(ptr, ptr, i64)
+
+define i32 @length0(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length0:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w0, wzr
+; CHECK-NEXT:    ret
+   %m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 0) nounwind
+   ret i32 %m
+ }
+
+define i1 @length0_eq(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length0_eq:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w0, #1 // =0x1
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 0) nounwind
+  %c = icmp eq i32 %m, 0
+  ret i1 %c
+}
+
+define i1 @length0_lt(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length0_lt:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w0, wzr
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 0) nounwind
+  %c = icmp slt i32 %m, 0
+  ret i1 %c
+}
+
+define i32 @length2(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldrh w8, [x0]
+; CHECK-NEXT:    ldrh w9, [x1]
+; CHECK-NEXT:    rev w8, w8
+; CHECK-NEXT:    rev w9, w9
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    sub w0, w8, w9, lsr #16
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 2) nounwind
+  ret i32 %m
+}
+
+define i32 @length2_const(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length2_const:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldrh w9, [x0]
+; CHECK-NEXT:    mov w8, #-12594 // =0xffffcece
+; CHECK-NEXT:    rev w9, w9
+; CHECK-NEXT:    add w0, w8, w9, lsr #16
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr getelementptr inbounds ([513 x i8], ptr @.str, i32 0, i32 1), i64 2) nounwind
+  ret i32 %m
+}
+
+define i1 @length2_gt_const(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length2_gt_const:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldrh w9, [x0]
+; CHECK-NEXT:    mov w8, #-12594 // =0xffffcece
+; CHECK-NEXT:    rev w9, w9
+; CHECK-NEXT:    add w8, w8, w9, lsr #16
+; CHECK-NEXT:    cmp w8, #0
+; CHECK-NEXT:    cset w0, gt
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr getelementptr inbounds ([513 x i8], ptr @.str, i32 0, i32 1), i64 2) nounwind
+  %c = icmp sgt i32 %m, 0
+  ret i1 %c
+}
+
+define i1 @length2_eq(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length2_eq:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldrh w8, [x0]
+; CHECK-NEXT:    ldrh w9, [x1]
+; CHECK-NEXT:    cmp w8, w9
+; CHECK-NEXT:    cset w0, eq
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 2) nounwind
+  %c = icmp eq i32 %m, 0
+  ret i1 %c
+}
+
+define i1 @length2_lt(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length2_lt:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldrh w8, [x0]
+; CHECK-NEXT:    ldrh w9, [x1]
+; CHECK-NEXT:    rev w8, w8
+; CHECK-NEXT:    rev w9, w9
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    sub w8, w8, w9, lsr #16
+; CHECK-NEXT:    lsr w0, w8, #31
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 2) nounwind
+  %c = icmp slt i32 %m, 0
+  ret i1 %c
+}
+
+define i1 @length2_gt(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length2_gt:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldrh w8, [x0]
+; CHECK-NEXT:    ldrh w9, [x1]
+; CHECK-NEXT:    rev w8, w8
+; CHECK-NEXT:    rev w9, w9
+; CHECK-NEXT:    lsr w8, w8, #16
+; CHECK-NEXT:    sub w8, w8, w9, lsr #16
+; CHECK-NEXT:    cmp w8, #0
+; CHECK-NEXT:    cset w0, gt
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 2) nounwind
+  %c = icmp sgt i32 %m, 0
+  ret i1 %c
+}
+
+define i1 @length2_eq_const(ptr %X) nounwind {
+; CHECK-LABEL: length2_eq_const:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldrh w8, [x0]
+; CHECK-NEXT:    mov w9, #12849 // =0x3231
+; CHECK-NEXT:    cmp w8, w9
+; CHECK-NEXT:    cset w0, ne
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr getelementptr inbounds ([513 x i8], ptr @.str, i32 0, i32 1), i64 2) nounwind
+  %c = icmp ne i32 %m, 0
+  ret i1 %c
+}
+
+define i1 @length2_eq_nobuiltin_attr(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length2_eq_nobuiltin_attr:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    mov w2, #2 // =0x2
+; CHECK-NEXT:    bl memcmp
+; CHECK-NEXT:    cmp w0, #0
+; CHECK-NEXT:    cset w0, eq
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 2) nounwind nobuiltin
+  %c = icmp eq i32 %m, 0
+  ret i1 %c
+}
+
+define i32 @length3(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length3:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldrb w8, [x0, #2]
+; CHECK-NEXT:    ldrh w9, [x0]
+; CHECK-NEXT:    ldrb w10, [x1, #2]
+; CHECK-NEXT:    ldrh w11, [x1]
+; CHECK-NEXT:    orr w8, w9, w8, lsl #16
+; CHECK-NEXT:    orr w9, w11, w10, lsl #16
+; CHECK-NEXT:    rev w8, w8
+; CHECK-NEXT:    rev w9, w9
+; CHECK-NEXT:    sub w0, w8, w9
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 3) nounwind
+  ret i32 %m
+}
+
+define i1 @length3_eq(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length3_eq:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldrh w8, [x0]
+; CHECK-NEXT:    ldrh w9, [x1]
+; CHECK-NEXT:    ldrb w10, [x0, #2]
+; CHECK-NEXT:    ldrb w11, [x1, #2]
+; CHECK-NEXT:    cmp w8, w9
+; CHECK-NEXT:    ccmp w10, w11, #0, eq
+; CHECK-NEXT:    cset w0, ne
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 3) nounwind
+  %c = icmp ne i32 %m, 0
+  ret i1 %c
+}
+
+define i32 @length4(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length4:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldr w8, [x0]
+; CHECK-NEXT:    ldr w9, [x1]
+; CHECK-NEXT:    rev w8, w8
+; CHECK-NEXT:    rev w9, w9
+; CHECK-NEXT:    cmp w8, w9
+; CHECK-NEXT:    cset w8, hi
+; CHECK-NEXT:    cset w9, lo
+; CHECK-NEXT:    sub w0, w8, w9
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 4) nounwind
+  ret i32 %m
+}
+
+define i1 @length4_eq(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length4_eq:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldr w8, [x0]
+; CHECK-NEXT:    ldr w9, [x1]
+; CHECK-NEXT:    cmp w8, w9
+; CHECK-NEXT:    cset w0, ne
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 4) nounwind
+  %c = icmp ne i32 %m, 0
+  ret i1 %c
+}
+
+define i1 @length4_lt(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length4_lt:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldr w8, [x0]
+; CHECK-NEXT:    ldr w9, [x1]
+; CHECK-NEXT:    rev w8, w8
+; CHECK-NEXT:    rev w9, w9
+; CHECK-NEXT:    cmp w8, w9
+; CHECK-NEXT:    cset w8, hi
+; CHECK-NEXT:    cset w9, lo
+; CHECK-NEXT:    sub w8, w8, w9
+; CHECK-NEXT:    lsr w0, w8, #31
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 4) nounwind
+  %c = icmp slt i32 %m, 0
+  ret i1 %c
+}
+
+define i1 @length4_gt(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length4_gt:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldr w8, [x0]
+; CHECK-NEXT:    ldr w9, [x1]
+; CHECK-NEXT:    rev w8, w8
+; CHECK-NEXT:    rev w9, w9
+; CHECK-NEXT:    cmp w8, w9
+; CHECK-NEXT:    cset w8, hi
+; CHECK-NEXT:    cset w9, lo
+; CHECK-NEXT:    sub w8, w8, w9
+; CHECK-NEXT:    cmp w8, #0
+; CHECK-NEXT:    cset w0, gt
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 4) nounwind
+  %c = icmp sgt i32 %m, 0
+  ret i1 %c
+}
+
+define i1 @length4_eq_const(ptr %X) nounwind {
+; CHECK-LABEL: length4_eq_const:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldr w8, [x0]
+; CHECK-NEXT:    mov w9, #12849 // =0x3231
+; CHECK-NEXT:    movk w9, #13363, lsl #16
+; CHECK-NEXT:    cmp w8, w9
+; CHECK-NEXT:    cset w0, eq
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr getelementptr inbounds ([513 x i8], ptr @.str, i32 0, i32 1), i64 4) nounwind
+  %c = icmp eq i32 %m, 0
+  ret i1 %c
+}
+
+define i32 @length5(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length5:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldrb w8, [x0, #4]
+; CHECK-NEXT:    ldr w9, [x0]
+; CHECK-NEXT:    ldrb w10, [x1, #4]
+; CHECK-NEXT:    ldr w11, [x1]
+; CHECK-NEXT:    orr x8, x9, x8, lsl #32
+; CHECK-NEXT:    orr x9, x11, x10, lsl #32
+; CHECK-NEXT:    rev x8, x8
+; CHECK-NEXT:    rev x9, x9
+; CHECK-NEXT:    cmp x8, x9
+; CHECK-NEXT:    cset w8, hi
+; CHECK-NEXT:    cset w9, lo
+; CHECK-NEXT:    sub w0, w8, w9
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 5) nounwind
+  ret i32 %m
+}
+
+define i1 @length5_eq(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length5_eq:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldr w8, [x0]
+; CHECK-NEXT:    ldr w9, [x1]
+; CHECK-NEXT:    ldrb w10, [x0, #4]
+; CHECK-NEXT:    ldrb w11, [x1, #4]
+; CHECK-NEXT:    cmp w8, w9
+; CHECK-NEXT:    ccmp w10, w11, #0, eq
+; CHECK-NEXT:    cset w0, ne
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 5) nounwind
+  %c = icmp ne i32 %m, 0
+  ret i1 %c
+}
+
+define i1 @length5_lt(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length5_lt:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldrb w8, [x0, #4]
+; CHECK-NEXT:    ldr w9, [x0]
+; CHECK-NEXT:    ldrb w10, [x1, #4]
+; CHECK-NEXT:    ldr w11, [x1]
+; CHECK-NEXT:    orr x8, x9, x8, lsl #32
+; CHECK-NEXT:    orr x9, x11, x10, lsl #32
+; CHECK-NEXT:    rev x8, x8
+; CHECK-NEXT:    rev x9, x9
+; CHECK-NEXT:    cmp x8, x9
+; CHECK-NEXT:    cset w8, hi
+; CHECK-NEXT:    cset w9, lo
+; CHECK-NEXT:    sub w8, w8, w9
+; CHECK-NEXT:    lsr w0, w8, #31
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 5) nounwind
+  %c = icmp slt i32 %m, 0
+  ret i1 %c
+}
+
+define i32 @length6(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length6:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldrh w8, [x0, #4]
+; CHECK-NEXT:    ldr w9, [x0]
+; CHECK-NEXT:    ldrh w10, [x1, #4]
+; CHECK-NEXT:    ldr w11, [x1]
+; CHECK-NEXT:    orr x8, x9, x8, lsl #32
+; CHECK-NEXT:    orr x9, x11, x10, lsl #32
+; CHECK-NEXT:    rev x8, x8
+; CHECK-NEXT:    rev x9, x9
+; CHECK-NEXT:    cmp x8, x9
+; CHECK-NEXT:    cset w8, hi
+; CHECK-NEXT:    cset w9, lo
+; CHECK-NEXT:    sub w0, w8, w9
+; CHECK-NEXT:    ret
+  %m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 6) nounwind
+  ret i32 %m
+}
+
+define i32 @length7(ptr %X, ptr %Y) nounwind {
+; CHECK-LABEL: length7:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldr w8, [x0]
+; CHECK-NEXT:   ...
[truncated]

llvm/lib/CodeGen/ExpandMemCmp.cpp Show resolved Hide resolved
llvm/lib/CodeGen/ExpandMemCmp.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/ExpandMemCmp.cpp Outdated Show resolved Hide resolved
llvm/test/CodeGen/AArch64/memcmp.ll Outdated Show resolved Hide resolved
@igogo-x86 igogo-x86 force-pushed the expand-memcpy-branchless-tail branch from f85ef70 to e3dd630 Compare October 25, 2023 18:09
…d sizes handling

* Enhanced the logic of ExpandMemCmp pass to merge contiguous subsequences
  in LoadSequence, based on sizes allowed in `AllowedTailExpansions`.
* This enhancement seeks to minimize the number of basic blocks and produce
  optimized code when using memcmp with non-register aligned sizes.
* Enable this feature for AArch64 with memcmp sizes modulo 8 equal to
  3, 5, and 6.
@igogo-x86 igogo-x86 force-pushed the expand-memcpy-branchless-tail branch from e3dd630 to 7686677 Compare October 27, 2023 10:36
@igogo-x86 igogo-x86 merged commit 9bcb30d into llvm:main Oct 27, 2023
2 of 3 checks passed
@DavidSpickett
Copy link
Collaborator

Linaro is seeing a test suite failure on this change: https://lab.llvm.org/buildbot/#/builders/197/builds/10630

This bot is AArch64 SVE vector length agnostic, however, the change hasn't hit the other bots yet. If we're lucky it's a general issue and not SVE specific.

@DavidSpickett
Copy link
Collaborator

https://lab.llvm.org/buildbot/#/builders/185/builds/5239 is the plain AArch64 build, still in stage 1 at the moment.

igogo-x86 added a commit that referenced this pull request Oct 27, 2023
@igogo-x86
Copy link
Contributor Author

@DavidSpickett, yes, I reverted the commit. mecmp(..., ..., 3) on certain data was not correct. Going to investigate how to handle it properly

igogo-x86 added a commit that referenced this pull request Oct 30, 2023
…d sizes handling (#70469)

* Enhanced the logic of ExpandMemCmp pass to merge contiguous
subsequences
  in LoadSequence, based on sizes allowed in `AllowedTailExpansions`.
* This enhancement seeks to minimize the number of basic blocks and
produce
  optimized code when using memcmp with non-register aligned sizes.
* Enable this feature for AArch64 with memcmp sizes modulo 8 equal to
  3, 5, and 6.

Reapplication of #69942 after fixing a bug
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants