[ExpandMemCmp] Properly constant-fold all compares.
Summary:
This gets rid of duplicated code and of diverging behaviour w.r.t.
constants: every expansion path now builds its loads through a single
getLoadPair() helper, which constant-folds loads from constant addresses
on both sides of the compare.
Fixes PR45086.

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D75519
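
To make the change concrete, here is an illustrative user-level pattern (a sketch, not taken from the commit; the buffer name and length are made up) that benefits: when one memcmp operand is a known constant buffer, the expansion can now fold that side to an immediate instead of emitting a load, as the X86 length2_const test update below shows.

  #include <cstring>

  static const char kConstBuf[3] = "ab";  // hypothetical constant buffer

  int cmpAgainstConst(const char *X) {
    // When this call is expanded inline, the bytes of kConstBuf are folded
    // at compile time; only X is loaded at run time.
    return std::memcmp(X, kConstBuf, 2);
  }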
legrosbuffle committed Mar 9, 2020
1 parent ee4dc98 commit f7e6f5f
Showing 5 changed files with 197 additions and 245 deletions.
172 changes: 74 additions & 98 deletions llvm/lib/CodeGen/ExpandMemCmp.cpp
@@ -103,8 +103,12 @@ class MemCmpExpansion {
Value *getMemCmpExpansionZeroCase();
Value *getMemCmpEqZeroOneBlock();
Value *getMemCmpOneBlock();
Value *getPtrToElementAtOffset(Value *Source, Type *LoadSizeType,
uint64_t OffsetBytes);
struct LoadPair {
Value *Lhs = nullptr;
Value *Rhs = nullptr;
};
LoadPair getLoadPair(Type *LoadSizeType, bool NeedsBSwap, Type *CmpSizeType,
unsigned OffsetBytes);

static LoadEntryVector
computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
@@ -261,18 +265,52 @@ void MemCmpExpansion::createResultBlock() {
EndBlock->getParent(), EndBlock);
}

/// Return a pointer to an element of type `LoadSizeType` at offset
/// `OffsetBytes`.
Value *MemCmpExpansion::getPtrToElementAtOffset(Value *Source,
Type *LoadSizeType,
uint64_t OffsetBytes) {
MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
bool NeedsBSwap,
Type *CmpSizeType,
unsigned OffsetBytes) {
// Get the memory source at offset `OffsetBytes`.
Value *LhsSource = CI->getArgOperand(0);
Value *RhsSource = CI->getArgOperand(1);
if (OffsetBytes > 0) {
auto *ByteType = Type::getInt8Ty(CI->getContext());
Source = Builder.CreateConstGEP1_64(
ByteType, Builder.CreateBitCast(Source, ByteType->getPointerTo()),
LhsSource = Builder.CreateConstGEP1_64(
ByteType, Builder.CreateBitCast(LhsSource, ByteType->getPointerTo()),
OffsetBytes);
RhsSource = Builder.CreateConstGEP1_64(
ByteType, Builder.CreateBitCast(RhsSource, ByteType->getPointerTo()),
OffsetBytes);
}
return Builder.CreateBitCast(Source, LoadSizeType->getPointerTo());
LhsSource = Builder.CreateBitCast(LhsSource, LoadSizeType->getPointerTo());
RhsSource = Builder.CreateBitCast(RhsSource, LoadSizeType->getPointerTo());

// Create a constant or a load from the source.
Value *Lhs = nullptr;
if (auto *C = dyn_cast<Constant>(LhsSource))
Lhs = ConstantFoldLoadFromConstPtr(C, LoadSizeType, DL);
if (!Lhs)
Lhs = Builder.CreateLoad(LoadSizeType, LhsSource);

Value *Rhs = nullptr;
if (auto *C = dyn_cast<Constant>(RhsSource))
Rhs = ConstantFoldLoadFromConstPtr(C, LoadSizeType, DL);
if (!Rhs)
Rhs = Builder.CreateLoad(LoadSizeType, RhsSource);

// Swap bytes if required.
if (NeedsBSwap) {
Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
Intrinsic::bswap, LoadSizeType);
Lhs = Builder.CreateCall(Bswap, Lhs);
Rhs = Builder.CreateCall(Bswap, Rhs);
}

// Zero extend if required.
if (CmpSizeType != nullptr && CmpSizeType != LoadSizeType) {
Lhs = Builder.CreateZExt(Lhs, CmpSizeType);
Rhs = Builder.CreateZExt(Rhs, CmpSizeType);
}
return {Lhs, Rhs};
}
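
As a minimal sketch of the fold-or-load pattern getLoadPair() applies to each side (assuming LLVM headers and an existing IRBuilder; the helper name foldOrLoad is ours, not part of the commit):

  #include "llvm/Analysis/ConstantFolding.h"
  #include "llvm/IR/IRBuilder.h"

  using namespace llvm;

  // If the source pointer is a Constant, try to fold the load at compile time;
  // otherwise emit a real load instruction.
  static Value *foldOrLoad(IRBuilder<> &Builder, Value *Src, Type *LoadTy,
                           const DataLayout &DL) {
    if (auto *C = dyn_cast<Constant>(Src))
      if (Constant *Folded = ConstantFoldLoadFromConstPtr(C, LoadTy, DL))
        return Folded;                      // a load from a constant global becomes an immediate
    return Builder.CreateLoad(LoadTy, Src); // non-constant side keeps its load
  }

Routing every expansion path through the same helper (and giving both operands the same bswap/zext treatment afterwards) is what removes the diverging behaviour w.r.t. constants mentioned in the summary.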

// This function creates the IR instructions for loading and comparing 1 byte.
@@ -282,18 +320,10 @@ Value *MemCmpExpansion::getPtrToElementAtOffset(Value *Source,
void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
unsigned OffsetBytes) {
Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
Value *Source1 =
getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType, OffsetBytes);
Value *Source2 =
getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType, OffsetBytes);

Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

LoadSrc1 = Builder.CreateZExt(LoadSrc1, Type::getInt32Ty(CI->getContext()));
LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext()));
Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2);
const LoadPair Loads =
getLoadPair(Type::getInt8Ty(CI->getContext()), /*NeedsBSwap=*/false,
Type::getInt32Ty(CI->getContext()), OffsetBytes);
Value *Diff = Builder.CreateSub(Loads.Lhs, Loads.Rhs);

PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]);
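
As a plain-C++ sketch of the arithmetic this block builds (names are ours): for a single byte, widening both operands and subtracting already yields a value whose sign is the memcmp result, so no explicit compare is needed.

  #include <cstdint>

  // Equivalent of the zext-to-i32 + sub sequence in the 1-byte block.
  int byteDiff(uint8_t A, uint8_t B) {
    return static_cast<int>(A) - static_cast<int>(B);  // < 0, 0, or > 0
  }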

@@ -340,41 +370,19 @@ Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
: IntegerType::get(CI->getContext(), MaxLoadSize * 8);
for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];

IntegerType *LoadSizeType =
IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);

Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
CurLoadEntry.Offset);
Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
CurLoadEntry.Offset);

// Get a constant or load a value for each source address.
Value *LoadSrc1 = nullptr;
if (auto *Source1C = dyn_cast<Constant>(Source1))
LoadSrc1 = ConstantFoldLoadFromConstPtr(Source1C, LoadSizeType, DL);
if (!LoadSrc1)
LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);

Value *LoadSrc2 = nullptr;
if (auto *Source2C = dyn_cast<Constant>(Source2))
LoadSrc2 = ConstantFoldLoadFromConstPtr(Source2C, LoadSizeType, DL);
if (!LoadSrc2)
LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
const LoadPair Loads = getLoadPair(
IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8),
/*NeedsBSwap=*/false, MaxLoadType, CurLoadEntry.Offset);

if (NumLoads != 1) {
if (LoadSizeType != MaxLoadType) {
LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
}
// If we have multiple loads per block, we need to generate a composite
// comparison using xor+or.
Diff = Builder.CreateXor(LoadSrc1, LoadSrc2);
Diff = Builder.CreateXor(Loads.Lhs, Loads.Rhs);
Diff = Builder.CreateZExt(Diff, MaxLoadType);
XorList.push_back(Diff);
} else {
// If there's only one load per block, we just compare the loaded values.
Cmp = Builder.CreateICmpNE(LoadSrc1, LoadSrc2);
Cmp = Builder.CreateICmpNE(Loads.Lhs, Loads.Rhs);
}
}
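
For reference, the composite comparison built for multi-load blocks reduces to the following identity (plain-C++ sketch with made-up names): OR-ing the XORs of all load pairs is zero exactly when every pair is equal, so a single compare against zero covers several loads.

  #include <cstdint>

  bool pairsEqual(uint64_t A1, uint64_t B1, uint64_t A2, uint64_t B2) {
    // (A1 ^ B1) | (A2 ^ B2) == 0  <=>  A1 == B1 && A2 == B2
    return ((A1 ^ B1) | (A2 ^ B2)) == 0;
  }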

@@ -451,35 +459,18 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {

Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
CurLoadEntry.Offset);
Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
CurLoadEntry.Offset);

// Load LoadSizeType from the base address.
Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

if (DL.isLittleEndian()) {
Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
Intrinsic::bswap, LoadSizeType);
LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
}

if (LoadSizeType != MaxLoadType) {
LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
}
const LoadPair Loads =
getLoadPair(LoadSizeType, /*NeedsBSwap=*/DL.isLittleEndian(), MaxLoadType,
CurLoadEntry.Offset);

// Add the loaded values to the phi nodes for calculating memcmp result only
// if result is not used in a zero equality.
if (!IsUsedForZeroCmp) {
ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]);
ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]);
ResBlock.PhiSrc1->addIncoming(Loads.Lhs, LoadCmpBlocks[BlockIndex]);
ResBlock.PhiSrc2->addIncoming(Loads.Rhs, LoadCmpBlocks[BlockIndex]);
}

Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2);
Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Loads.Lhs, Loads.Rhs);
BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
? EndBlock
: LoadCmpBlocks[BlockIndex + 1];
@@ -568,42 +559,27 @@ Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
/// the compare, branch, and phi IR that is required in the general case.
Value *MemCmpExpansion::getMemCmpOneBlock() {
Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
Value *Source1 = CI->getArgOperand(0);
Value *Source2 = CI->getArgOperand(1);

// Cast source to LoadSizeType*.
if (Source1->getType() != LoadSizeType)
Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
if (Source2->getType() != LoadSizeType)
Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

// Load LoadSizeType from the base address.
Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

if (DL.isLittleEndian() && Size != 1) {
Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
Intrinsic::bswap, LoadSizeType);
LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
}
bool NeedsBSwap = DL.isLittleEndian() && Size != 1;

// The i8 and i16 cases don't need compares. We zext the loaded values and
// subtract them to get the suitable negative, zero, or positive i32 result.
if (Size < 4) {
// The i8 and i16 cases don't need compares. We zext the loaded values and
// subtract them to get the suitable negative, zero, or positive i32 result.
LoadSrc1 = Builder.CreateZExt(LoadSrc1, Builder.getInt32Ty());
LoadSrc2 = Builder.CreateZExt(LoadSrc2, Builder.getInt32Ty());
return Builder.CreateSub(LoadSrc1, LoadSrc2);
const LoadPair Loads =
getLoadPair(LoadSizeType, NeedsBSwap, Builder.getInt32Ty(),
/*Offset*/ 0);
return Builder.CreateSub(Loads.Lhs, Loads.Rhs);
}

const LoadPair Loads = getLoadPair(LoadSizeType, NeedsBSwap, LoadSizeType,
/*Offset*/ 0);
// The result of memcmp is negative, zero, or positive, so produce that by
// subtracting 2 extended compare bits: sub (ugt, ult).
// If a target prefers to use selects to get -1/0/1, they should be able
// to transform this later. The inverse transform (going from selects to math)
// may not be possible in the DAG because the selects got converted into
// branches before we got there.
Value *CmpUGT = Builder.CreateICmpUGT(LoadSrc1, LoadSrc2);
Value *CmpULT = Builder.CreateICmpULT(LoadSrc1, LoadSrc2);
Value *CmpUGT = Builder.CreateICmpUGT(Loads.Lhs, Loads.Rhs);
Value *CmpULT = Builder.CreateICmpULT(Loads.Lhs, Loads.Rhs);
Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
return Builder.CreateSub(ZextUGT, ZextULT);
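
The comment above corresponds to this branch-free identity (plain-C++ sketch with made-up names): zext(a > b) - zext(a < b) is +1, 0, or -1, which is a valid memcmp return value.

  #include <cstdint>

  // Equivalent of sub(zext(icmp ugt), zext(icmp ult)) in the one-block case.
  int cmpSign(uint64_t A, uint64_t B) {
    return static_cast<int>(A > B) - static_cast<int>(A < B);  // +1, 0, or -1
  }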
54 changes: 21 additions & 33 deletions llvm/test/CodeGen/PowerPC/memCmpUsedInZeroEqualityComparison.ll
@@ -90,27 +90,23 @@ define signext i32 @zeroEqualityTest03(i8* %x, i8* %y) {
define signext i32 @zeroEqualityTest04() {
; CHECK-LABEL: zeroEqualityTest04:
; CHECK: # %bb.0:
; CHECK-NEXT: addis 3, 2, .LzeroEqualityTest02.buffer1@toc@ha
; CHECK-NEXT: addis 4, 2, .LzeroEqualityTest02.buffer2@toc@ha
; CHECK-NEXT: addi 6, 3, .LzeroEqualityTest02.buffer1@toc@l
; CHECK-NEXT: addi 5, 4, .LzeroEqualityTest02.buffer2@toc@l
; CHECK-NEXT: ldbrx 3, 0, 6
; CHECK-NEXT: ldbrx 4, 0, 5
; CHECK-NEXT: cmpld 3, 4
; CHECK-NEXT: bne 0, .LBB3_2
; CHECK-NEXT: b .LBB3_2
; CHECK-NEXT: # %bb.1: # %loadbb1
; CHECK-NEXT: li 4, 8
; CHECK-NEXT: ldbrx 3, 6, 4
; CHECK-NEXT: ldbrx 4, 5, 4
; CHECK-NEXT: li 3, 0
; CHECK-NEXT: li 5, 0
; CHECK-NEXT: cmpld 3, 4
; CHECK-NEXT: beq 0, .LBB3_3
; CHECK-NEXT: .LBB3_2: # %res_block
; CHECK-NEXT: li 4, 0
; CHECK-NEXT: b .LBB3_4
; CHECK-NEXT: .LBB3_2:
; CHECK-NEXT: li 3, 1
; CHECK-NEXT: li 4, 3
; CHECK-NEXT: sldi 3, 3, 58
; CHECK-NEXT: sldi 4, 4, 56
; CHECK-NEXT: # %bb.3: # %res_block
; CHECK-NEXT: cmpld 3, 4
; CHECK-NEXT: li 3, 1
; CHECK-NEXT: li 4, -1
; CHECK-NEXT: isel 5, 4, 3, 0
; CHECK-NEXT: .LBB3_3: # %endblock
; CHECK-NEXT: .LBB3_4: # %endblock
; CHECK-NEXT: extsw 3, 5
; CHECK-NEXT: neg 3, 3
; CHECK-NEXT: rldicl 3, 3, 1, 63
@@ -126,28 +122,20 @@ define signext i32 @zeroEqualityTest04() {
define signext i32 @zeroEqualityTest05() {
; CHECK-LABEL: zeroEqualityTest05:
; CHECK: # %bb.0:
; CHECK-NEXT: addis 3, 2, .LzeroEqualityTest03.buffer1@toc@ha
; CHECK-NEXT: addis 4, 2, .LzeroEqualityTest03.buffer2@toc@ha
; CHECK-NEXT: addi 6, 3, .LzeroEqualityTest03.buffer1@toc@l
; CHECK-NEXT: addi 5, 4, .LzeroEqualityTest03.buffer2@toc@l
; CHECK-NEXT: ldbrx 3, 0, 6
; CHECK-NEXT: ldbrx 4, 0, 5
; CHECK-NEXT: cmpld 3, 4
; CHECK-NEXT: bne 0, .LBB4_2
; CHECK-NEXT: li 3, 0
; CHECK-NEXT: li 4, 0
; CHECK-NEXT: # %bb.1: # %loadbb1
; CHECK-NEXT: li 4, 8
; CHECK-NEXT: ldbrx 3, 6, 4
; CHECK-NEXT: ldbrx 4, 5, 4
; CHECK-NEXT: li 5, 0
; CHECK-NEXT: cmpld 3, 4
; CHECK-NEXT: beq 0, .LBB4_3
; CHECK-NEXT: .LBB4_2: # %res_block
; CHECK-NEXT: li 3, 0
; CHECK-NEXT: # %bb.2:
; CHECK-NEXT: lis 3, 768
; CHECK-NEXT: lis 4, 1024
; CHECK-NEXT: # %bb.3: # %res_block
; CHECK-NEXT: cmpld 3, 4
; CHECK-NEXT: li 3, 1
; CHECK-NEXT: li 4, -1
; CHECK-NEXT: isel 5, 4, 3, 0
; CHECK-NEXT: .LBB4_3: # %endblock
; CHECK-NEXT: nor 3, 5, 5
; CHECK-NEXT: isel 3, 4, 3, 0
; CHECK-NEXT: # %bb.4: # %endblock
; CHECK-NEXT: nor 3, 3, 3
; CHECK-NEXT: rlwinm 3, 3, 1, 31, 31
; CHECK-NEXT: blr
%call = tail call signext i32 @memcmp(i8* bitcast ([4 x i32]* @zeroEqualityTest03.buffer1 to i8*), i8* bitcast ([4 x i32]* @zeroEqualityTest03.buffer2 to i8*), i64 16)
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/PowerPC/memcmpIR.ll
@@ -20,8 +20,8 @@ entry:
; CHECK: [[BCC1:%[0-9]+]] = bitcast i32* {{.*}} to i8*
; CHECK-NEXT: [[BCC2:%[0-9]+]] = bitcast i32* {{.*}} to i8*
; CHECK-NEXT: [[GEP1:%[0-9]+]] = getelementptr i8, i8* [[BCC2]], i64 8
; CHECK-NEXT: [[BCL1:%[0-9]+]] = bitcast i8* [[GEP1]] to i64*
; CHECK-NEXT: [[GEP2:%[0-9]+]] = getelementptr i8, i8* [[BCC1]], i64 8
; CHECK-NEXT: [[BCL1:%[0-9]+]] = bitcast i8* [[GEP1]] to i64*
; CHECK-NEXT: [[BCL2:%[0-9]+]] = bitcast i8* [[GEP2]] to i64*
; CHECK-NEXT: [[LOAD1:%[0-9]+]] = load i64, i64* [[BCL1]]
; CHECK-NEXT: [[LOAD2:%[0-9]+]] = load i64, i64* [[BCL2]]
@@ -45,8 +45,8 @@ entry:
; CHECK-BE: [[BCC1:%[0-9]+]] = bitcast i32* {{.*}} to i8*
; CHECK-BE-NEXT: [[BCC2:%[0-9]+]] = bitcast i32* {{.*}} to i8*
; CHECK-BE-NEXT: [[GEP1:%[0-9]+]] = getelementptr i8, i8* [[BCC2]], i64 8
; CHECK-BE-NEXT: [[BCL1:%[0-9]+]] = bitcast i8* [[GEP1]] to i64*
; CHECK-BE-NEXT: [[GEP2:%[0-9]+]] = getelementptr i8, i8* [[BCC1]], i64 8
; CHECK-BE-NEXT: [[BCL1:%[0-9]+]] = bitcast i8* [[GEP1]] to i64*
; CHECK-BE-NEXT: [[BCL2:%[0-9]+]] = bitcast i8* [[GEP2]] to i64*
; CHECK-BE-NEXT: [[LOAD1:%[0-9]+]] = load i64, i64* [[BCL1]]
; CHECK-BE-NEXT: [[LOAD2:%[0-9]+]] = load i64, i64* [[BCL2]]
20 changes: 4 additions & 16 deletions llvm/test/CodeGen/X86/memcmp.ll
@@ -98,23 +98,17 @@ define i32 @length2_const(i8* %X, i8* %Y) nounwind {
; X86: # %bb.0:
; X86-NEXT: movl {{[0-9]+}}(%esp), %eax
; X86-NEXT: movzwl (%eax), %eax
; X86-NEXT: movzwl .L.str+1, %ecx
; X86-NEXT: rolw $8, %ax
; X86-NEXT: rolw $8, %cx
; X86-NEXT: movzwl %ax, %eax
; X86-NEXT: movzwl %cx, %ecx
; X86-NEXT: subl %ecx, %eax
; X86-NEXT: addl $-12594, %eax # imm = 0xCECE
; X86-NEXT: retl
;
; X64-LABEL: length2_const:
; X64: # %bb.0:
; X64-NEXT: movzwl (%rdi), %eax
; X64-NEXT: movzwl .L.str+{{.*}}(%rip), %ecx
; X64-NEXT: rolw $8, %ax
; X64-NEXT: rolw $8, %cx
; X64-NEXT: movzwl %ax, %eax
; X64-NEXT: movzwl %cx, %ecx
; X64-NEXT: subl %ecx, %eax
; X64-NEXT: addl $-12594, %eax # imm = 0xCECE
; X64-NEXT: retq
%m = tail call i32 @memcmp(i8* %X, i8* getelementptr inbounds ([513 x i8], [513 x i8]* @.str, i32 0, i32 1), i64 2) nounwind
ret i32 %m
@@ -125,25 +119,19 @@ define i1 @length2_gt_const(i8* %X, i8* %Y) nounwind {
; X86: # %bb.0:
; X86-NEXT: movl {{[0-9]+}}(%esp), %eax
; X86-NEXT: movzwl (%eax), %eax
; X86-NEXT: movzwl .L.str+1, %ecx
; X86-NEXT: rolw $8, %ax
; X86-NEXT: rolw $8, %cx
; X86-NEXT: movzwl %ax, %eax
; X86-NEXT: movzwl %cx, %ecx
; X86-NEXT: subl %ecx, %eax
; X86-NEXT: addl $-12594, %eax # imm = 0xCECE
; X86-NEXT: testl %eax, %eax
; X86-NEXT: setg %al
; X86-NEXT: retl
;
; X64-LABEL: length2_gt_const:
; X64: # %bb.0:
; X64-NEXT: movzwl (%rdi), %eax
; X64-NEXT: movzwl .L.str+{{.*}}(%rip), %ecx
; X64-NEXT: rolw $8, %ax
; X64-NEXT: rolw $8, %cx
; X64-NEXT: movzwl %ax, %eax
; X64-NEXT: movzwl %cx, %ecx
; X64-NEXT: subl %ecx, %eax
; X64-NEXT: addl $-12594, %eax # imm = 0xCECE
; X64-NEXT: testl %eax, %eax
; X64-NEXT: setg %al
; X64-NEXT: retq