Skip to content

Commit

Permalink
[InstCombine] PR35354: Convert store(bitcast, load bitcast (select (C…
Browse files Browse the repository at this point in the history
…ond, &V1, &V2)) --> store (, load (select(Cond, load &V1, load &V2)))

Summary:
If we have the code like this:
```
float a, b;
a = std::max(a ,b);
```
it is converted into something like this:
```
%call = call dereferenceable(4) float* @_ZSt3maxIfERKT_S2_S2_(float* nonnull dereferenceable(4) %a.addr, float* nonnull dereferenceable(4) %b.addr)
%1 = bitcast float* %call to i32*
%2 = load i32, i32* %1, align 4
%3 = bitcast float* %a.addr to i32*
store i32 %2, i32* %3, align 4
```
After inlinning this code is converted to the next:
```
%1 = load float, float* %a.addr
%2 = load float, float* %b.addr
%cmp.i = fcmp fast olt float %1, %2
%__b.__a.i = select i1 %cmp.i, float* %a.addr, float* %b.addr
%3 = bitcast float* %__b.__a.i to i32*
%4 = load i32, i32* %3, align 4
%5 = bitcast float* %arrayidx to i32*
store i32 %4, i32* %5, align 4

```
This pattern is not recognized as minmax pattern.
Patch solves this problem by converting sequence
```
store (bitcast, (load bitcast (select ((cmp V1, V2), &V1, &V2))))
```
to a sequence
```
store (,load (select((cmp V1, V2), &V1, &V2)))
```
After this the code is recognized as minmax pattern.

Reviewers: RKSimon, spatel

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D40304

llvm-svn: 320157
  • Loading branch information
alexey-bataev committed Dec 8, 2017
1 parent 83708ca commit ec95c6c
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 11 deletions.
57 changes: 56 additions & 1 deletion llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
Expand Up @@ -22,9 +22,11 @@
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "instcombine"

Expand Down Expand Up @@ -561,6 +563,28 @@ static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value
return NewStore;
}

/// Returns true if instruction represent minmax pattern like:
/// select ((cmp load V1, load V2), V1, V2).
static bool isMinMaxWithLoads(Value *V) {
assert(V->getType()->isPointerTy() && "Expected pointer type.");
// Ignore possible ty* to ixx* bitcast.
V = peekThroughBitcast(V);
// Check that select is select ((cmp load V1, load V2), V1, V2) - minmax
// pattern.
CmpInst::Predicate Pred;
Instruction *L1;
Instruction *L2;
Value *LHS;
Value *RHS;
if (!match(V, m_Select(m_Cmp(Pred, m_Instruction(L1), m_Instruction(L2)),
m_Value(LHS), m_Value(RHS))))
return false;
return (match(L1, m_Load(m_Specific(LHS))) &&
match(L2, m_Load(m_Specific(RHS)))) ||
(match(L1, m_Load(m_Specific(RHS))) &&
match(L2, m_Load(m_Specific(LHS))));
}

/// \brief Combine loads to match the type of their uses' value after looking
/// through intervening bitcasts.
///
Expand Down Expand Up @@ -598,10 +622,14 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) {
// integers instead of any other type. We only do this when the loaded type
// is sized and has a size exactly the same as its store size and the store
// size is a legal integer type.
// Do not perform canonicalization if minmax pattern is found (to avoid
// infinite loop).
if (!Ty->isIntegerTy() && Ty->isSized() &&
DL.isLegalInteger(DL.getTypeStoreSizeInBits(Ty)) &&
DL.getTypeStoreSizeInBits(Ty) == DL.getTypeSizeInBits(Ty) &&
!DL.isNonIntegralPointerType(Ty)) {
!DL.isNonIntegralPointerType(Ty) &&
!isMinMaxWithLoads(
peekThroughBitcast(LI.getPointerOperand(), /*OneUseOnly=*/true))) {
if (all_of(LI.users(), [&LI](User *U) {
auto *SI = dyn_cast<StoreInst>(U);
return SI && SI->getPointerOperand() != &LI &&
Expand Down Expand Up @@ -1298,6 +1326,30 @@ static bool equivalentAddressValues(Value *A, Value *B) {
return false;
}

/// Converts store (bitcast (load (bitcast (select ...)))) to
/// store (load (select ...)), where select is minmax:
/// select ((cmp load V1, load V2), V1, V2).
bool removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC, StoreInst &SI) {
// bitcast?
Value *StoreAddr;
if (!match(SI.getPointerOperand(), m_BitCast(m_Value(StoreAddr))))
return false;
// load? integer?
Value *LoadAddr;
if (!match(SI.getValueOperand(), m_Load(m_BitCast(m_Value(LoadAddr)))))
return false;
auto *LI = cast<LoadInst>(SI.getValueOperand());
if (!LI->getType()->isIntegerTy())
return false;
if (!isMinMaxWithLoads(LoadAddr))
return false;

LoadInst *NewLI = combineLoadToNewType(
IC, *LI, LoadAddr->getType()->getPointerElementType());
combineStoreToNewValue(IC, SI, NewLI);
return true;
}

Instruction *InstCombiner::visitStoreInst(StoreInst &SI) {
Value *Val = SI.getOperand(0);
Value *Ptr = SI.getOperand(1);
Expand All @@ -1322,6 +1374,9 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) {
if (unpackStoreToAggregate(*this, SI))
return eraseInstFromFunction(SI);

if (removeBitcastsFromLoadStoreOnMinMax(*this, SI))
return eraseInstFromFunction(SI);

// Replace GEP indices if possible.
if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI)) {
Worklist.Add(NewGEPI);
Expand Down
14 changes: 4 additions & 10 deletions llvm/test/Transforms/InstCombine/load-bitcast-select.ll
Expand Up @@ -21,11 +21,8 @@ define void @_Z3foov() {
; CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[ARRAYIDX]], align 4
; CHECK-NEXT: [[TMP2:%.*]] = load float, float* [[ARRAYIDX2]], align 4
; CHECK-NEXT: [[CMP_I:%.*]] = fcmp fast olt float [[TMP1]], [[TMP2]]
; CHECK-NEXT: [[__B___A_I:%.*]] = select i1 [[CMP_I]], float* [[ARRAYIDX2]], float* [[ARRAYIDX]]
; CHECK-NEXT: [[TMP3:%.*]] = bitcast float* [[__B___A_I]] to i32*
; CHECK-NEXT: [[TMP4:%.*]] = load i32, i32* [[TMP3]], align 4
; CHECK-NEXT: [[TMP5:%.*]] = bitcast float* [[ARRAYIDX]] to i32*
; CHECK-NEXT: store i32 [[TMP4]], i32* [[TMP5]], align 4
; CHECK-NEXT: [[TMP3:%.*]] = select i1 [[CMP_I]], float [[TMP2]], float [[TMP1]]
; CHECK-NEXT: store float [[TMP3]], float* [[ARRAYIDX]], align 4
; CHECK-NEXT: [[INC]] = add nuw nsw i32 [[I_0]], 1
; CHECK-NEXT: br label [[FOR_COND]]
;
Expand Down Expand Up @@ -91,11 +88,8 @@ define void @bitcasted_minmax_with_select_of_pointers(float* %loadaddr1, float*
; CHECK-NEXT: [[LD1:%.*]] = load float, float* [[LOADADDR1:%.*]], align 4
; CHECK-NEXT: [[LD2:%.*]] = load float, float* [[LOADADDR2:%.*]], align 4
; CHECK-NEXT: [[COND:%.*]] = fcmp ogt float [[LD1]], [[LD2]]
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], float* [[LOADADDR1]], float* [[LOADADDR2]]
; CHECK-NEXT: [[INT_LOAD_ADDR:%.*]] = bitcast float* [[SEL]] to i32*
; CHECK-NEXT: [[LD:%.*]] = load i32, i32* [[INT_LOAD_ADDR]], align 4
; CHECK-NEXT: [[INT_STORE_ADDR:%.*]] = bitcast float* [[STOREADDR:%.*]] to i32*
; CHECK-NEXT: store i32 [[LD]], i32* [[INT_STORE_ADDR]], align 4
; CHECK-NEXT: [[LD3:%.*]] = select i1 [[COND]], float [[LD1]], float [[LD2]]
; CHECK-NEXT: store float [[LD3]], float* [[STOREADDR:%.*]], align 4
; CHECK-NEXT: ret void
;
%ld1 = load float, float* %loadaddr1, align 4
Expand Down

0 comments on commit ec95c6c

Please sign in to comment.