259 changes: 259 additions & 0 deletions llvm/lib/Analysis/ScalarEvolutionDivision.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
//===- ScalarEvolutionDivision.h - See below --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the class that knows how to divide SCEV's.
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/ScalarEvolutionDivision.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Constants.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <cassert>
#include <cstdint>

namespace llvm {
class Type;
}

using namespace llvm;

namespace {

static inline int sizeOfSCEV(const SCEV *S) {
struct FindSCEVSize {
int Size = 0;

FindSCEVSize() = default;

bool follow(const SCEV *S) {
++Size;
// Keep looking at all operands of S.
return true;
}

bool isDone() const { return false; }
};

FindSCEVSize F;
SCEVTraversal<FindSCEVSize> ST(F);
ST.visitAll(S);
return F.Size;
}

} // namespace

// Computes the Quotient and Remainder of the division of Numerator by
// Denominator.
void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator,
const SCEV *Denominator, const SCEV **Quotient,
const SCEV **Remainder) {
assert(Numerator && Denominator && "Uninitialized SCEV");

SCEVDivision D(SE, Numerator, Denominator);

// Check for the trivial case here to avoid having to check for it in the
// rest of the code.
if (Numerator == Denominator) {
*Quotient = D.One;
*Remainder = D.Zero;
return;
}

if (Numerator->isZero()) {
*Quotient = D.Zero;
*Remainder = D.Zero;
return;
}

// A simple case when N/1. The quotient is N.
if (Denominator->isOne()) {
*Quotient = Numerator;
*Remainder = D.Zero;
return;
}

// Split the Denominator when it is a product.
if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
const SCEV *Q, *R;
*Quotient = Numerator;
for (const SCEV *Op : T->operands()) {
divide(SE, *Quotient, Op, &Q, &R);
*Quotient = Q;

// Bail out when the Numerator is not divisible by one of the terms of
// the Denominator.
if (!R->isZero()) {
*Quotient = D.Zero;
*Remainder = Numerator;
return;
}
}
*Remainder = D.Zero;
return;
}

D.visit(Numerator);
*Quotient = D.Quotient;
*Remainder = D.Remainder;
}

void SCEVDivision::visitConstant(const SCEVConstant *Numerator) {
if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
APInt NumeratorVal = Numerator->getAPInt();
APInt DenominatorVal = D->getAPInt();
uint32_t NumeratorBW = NumeratorVal.getBitWidth();
uint32_t DenominatorBW = DenominatorVal.getBitWidth();

if (NumeratorBW > DenominatorBW)
DenominatorVal = DenominatorVal.sext(NumeratorBW);
else if (NumeratorBW < DenominatorBW)
NumeratorVal = NumeratorVal.sext(DenominatorBW);

APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
Quotient = SE.getConstant(QuotientVal);
Remainder = SE.getConstant(RemainderVal);
return;
}
}

void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
const SCEV *StartQ, *StartR, *StepQ, *StepR;
if (!Numerator->isAffine())
return cannotDivide(Numerator);
divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
// Bail out if the types do not match.
Type *Ty = Denominator->getType();
if (Ty != StartQ->getType() || Ty != StartR->getType() ||
Ty != StepQ->getType() || Ty != StepR->getType())
return cannotDivide(Numerator);
Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
Numerator->getNoWrapFlags());
Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
Numerator->getNoWrapFlags());
}

void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) {
SmallVector<const SCEV *, 2> Qs, Rs;
Type *Ty = Denominator->getType();

for (const SCEV *Op : Numerator->operands()) {
const SCEV *Q, *R;
divide(SE, Op, Denominator, &Q, &R);

// Bail out if types do not match.
if (Ty != Q->getType() || Ty != R->getType())
return cannotDivide(Numerator);

Qs.push_back(Q);
Rs.push_back(R);
}

if (Qs.size() == 1) {
Quotient = Qs[0];
Remainder = Rs[0];
return;
}

Quotient = SE.getAddExpr(Qs);
Remainder = SE.getAddExpr(Rs);
}

void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) {
SmallVector<const SCEV *, 2> Qs;
Type *Ty = Denominator->getType();

bool FoundDenominatorTerm = false;
for (const SCEV *Op : Numerator->operands()) {
// Bail out if types do not match.
if (Ty != Op->getType())
return cannotDivide(Numerator);

if (FoundDenominatorTerm) {
Qs.push_back(Op);
continue;
}

// Check whether Denominator divides one of the product operands.
const SCEV *Q, *R;
divide(SE, Op, Denominator, &Q, &R);
if (!R->isZero()) {
Qs.push_back(Op);
continue;
}

// Bail out if types do not match.
if (Ty != Q->getType())
return cannotDivide(Numerator);

FoundDenominatorTerm = true;
Qs.push_back(Q);
}

if (FoundDenominatorTerm) {
Remainder = Zero;
if (Qs.size() == 1)
Quotient = Qs[0];
else
Quotient = SE.getMulExpr(Qs);
return;
}

if (!isa<SCEVUnknown>(Denominator))
return cannotDivide(Numerator);

// The Remainder is obtained by replacing Denominator by 0 in Numerator.
ValueToValueMap RewriteMap;
RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
cast<SCEVConstant>(Zero)->getValue();
Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);

if (Remainder->isZero()) {
// The Quotient is obtained by replacing Denominator by 1 in Numerator.
RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
cast<SCEVConstant>(One)->getValue();
Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
return;
}

// Quotient is (Numerator - Remainder) divided by Denominator.
const SCEV *Q, *R;
const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
// This SCEV does not seem to simplify: fail the division here.
if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator))
return cannotDivide(Numerator);
divide(SE, Diff, Denominator, &Q, &R);
if (R != Zero)
return cannotDivide(Numerator);
Quotient = Q;
}

SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
const SCEV *Denominator)
: SE(S), Denominator(Denominator) {
Zero = SE.getZero(Denominator->getType());
One = SE.getOne(Denominator->getType());

// We generally do not know how to divide Expr by Denominator. We initialize
// the division to a "cannot divide" state to simplify the rest of the code.
cannotDivide(Numerator);
}

// Convenience function for giving up on the division. We set the quotient to
// be equal to zero and the remainder to be equal to the numerator.
void SCEVDivision::cannotDivide(const SCEV *Numerator) {
Quotient = Zero;
Remainder = Numerator;
}
2 changes: 2 additions & 0 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4147,6 +4147,8 @@ Value *llvm::GetUnderlyingObject(Value *V, const DataLayout &DL,
} else if (Operator::getOpcode(V) == Instruction::BitCast ||
Operator::getOpcode(V) == Instruction::AddrSpaceCast) {
V = cast<Operator>(V)->getOperand(0);
if (!V->getType()->isPointerTy())
return V;
} else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) {
if (GA->isInterposable())
return V;
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/IR/Value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,8 @@ static const Value *stripPointerCastsAndOffsets(
V = GEP->getPointerOperand();
} else if (Operator::getOpcode(V) == Instruction::BitCast) {
V = cast<Operator>(V)->getOperand(0);
if (!V->getType()->isPointerTy())
return V;
} else if (StripKind != PSK_ZeroIndicesSameRepresentation &&
Operator::getOpcode(V) == Instruction::AddrSpaceCast) {
// TODO: If we know an address space cast will not change the
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2471,8 +2471,9 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
if (DestTy == Src->getType())
return replaceInstUsesWith(CI, Src);

if (PointerType *DstPTy = dyn_cast<PointerType>(DestTy)) {
if (isa<PointerType>(SrcTy) && isa<PointerType>(DestTy)) {
PointerType *SrcPTy = cast<PointerType>(SrcTy);
PointerType *DstPTy = cast<PointerType>(DestTy);
Type *DstElTy = DstPTy->getElementType();
Type *SrcElTy = SrcPTy->getElementType();

Expand Down
32 changes: 10 additions & 22 deletions llvm/lib/Transforms/Scalar/SROA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1747,36 +1747,24 @@ static Value *convertValue(const DataLayout &DL, IRBuilderTy &IRB, Value *V,
assert(!(isa<IntegerType>(OldTy) && isa<IntegerType>(NewTy)) &&
"Integer types must be the exact same to convert.");

// See if we need inttoptr for this type pair. A cast involving both scalars
// and vectors requires and additional bitcast.
// See if we need inttoptr for this type pair. May require additional bitcast.
if (OldTy->isIntOrIntVectorTy() && NewTy->isPtrOrPtrVectorTy()) {
// Expand <2 x i32> to i8* --> <2 x i32> to i64 to i8*
if (OldTy->isVectorTy() && !NewTy->isVectorTy())
return IRB.CreateIntToPtr(IRB.CreateBitCast(V, DL.getIntPtrType(NewTy)),
NewTy);

// Expand i128 to <2 x i8*> --> i128 to <2 x i64> to <2 x i8*>
if (!OldTy->isVectorTy() && NewTy->isVectorTy())
return IRB.CreateIntToPtr(IRB.CreateBitCast(V, DL.getIntPtrType(NewTy)),
NewTy);

return IRB.CreateIntToPtr(V, NewTy);
// Expand <4 x i32> to <2 x i8*> --> <4 x i32> to <2 x i64> to <2 x i8*>
// Directly handle i64 to i8*
return IRB.CreateIntToPtr(IRB.CreateBitCast(V, DL.getIntPtrType(NewTy)),
NewTy);
}

// See if we need ptrtoint for this type pair. A cast involving both scalars
// and vectors requires and additional bitcast.
// See if we need ptrtoint for this type pair. May require additional bitcast.
if (OldTy->isPtrOrPtrVectorTy() && NewTy->isIntOrIntVectorTy()) {
// Expand <2 x i8*> to i128 --> <2 x i8*> to <2 x i64> to i128
if (OldTy->isVectorTy() && !NewTy->isVectorTy())
return IRB.CreateBitCast(IRB.CreatePtrToInt(V, DL.getIntPtrType(OldTy)),
NewTy);

// Expand i8* to <2 x i32> --> i8* to i64 to <2 x i32>
if (!OldTy->isVectorTy() && NewTy->isVectorTy())
return IRB.CreateBitCast(IRB.CreatePtrToInt(V, DL.getIntPtrType(OldTy)),
NewTy);

return IRB.CreatePtrToInt(V, NewTy);
// Expand <2 x i8*> to <4 x i32> --> <2 x i8*> to <2 x i64> to <4 x i32>
// Expand i8* to i64 --> i8* to i64 to i64
return IRB.CreateBitCast(IRB.CreatePtrToInt(V, DL.getIntPtrType(OldTy)),
NewTy);
}

if (OldTy->isPtrOrPtrVectorTy() && NewTy->isPtrOrPtrVectorTy()) {
Expand Down
6 changes: 6 additions & 0 deletions llvm/test/Transforms/InstCombine/bitcast.ll
Original file line number Diff line number Diff line change
Expand Up @@ -561,3 +561,9 @@ define void @constant_fold_vector_to_half() {
store volatile half bitcast (<4 x i4> <i4 0, i4 0, i4 0, i4 4> to half), half* undef
ret void
}

; Ensure that we do not crash when looking at such a weird bitcast.
define i8* @bitcast_from_single_element_pointer_vector_to_pointer(<1 x i8*> %ptrvec) {
%ptr = bitcast <1 x i8*> %ptrvec to i8*
ret i8* %ptr
}
14 changes: 14 additions & 0 deletions llvm/test/Transforms/InstSimplify/icmp.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -instsimplify -S | FileCheck %s

target datalayout = "e-p:64:64:64-p1:16:16:16-p2:32:32:32-p3:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64"

declare void @usei8ptr(i8* %ptr)

; Ensure that we do not crash when looking at such a weird bitcast.
define i1 @bitcast_from_single_element_pointer_vector_to_pointer(<1 x i8*> %ptr1vec, i8* %ptr2) {
%ptr1 = bitcast <1 x i8*> %ptr1vec to i8*
call void @usei8ptr(i8* %ptr1)
%cmp = icmp eq i8* %ptr1, %ptr2
ret i1 %cmp
}
40 changes: 39 additions & 1 deletion llvm/test/Transforms/SROA/vector-conversion.ll
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ define <4 x i32*> @vector_inttoptr({<2 x i64>, <2 x i64>} %x) {
}

define <2 x i64> @vector_ptrtointbitcast({<1 x i32*>, <1 x i32*>} %x) {
; CHECK-LABEL: @vector_ptrtointbitcast
; CHECK-LABEL: @vector_ptrtointbitcast(
%a = alloca {<1 x i32*>, <1 x i32*>}
; CHECK-NOT: alloca

Expand All @@ -51,3 +51,41 @@ define <2 x i64> @vector_ptrtointbitcast({<1 x i32*>, <1 x i32*>} %x) {

ret <2 x i64> %vec
}

define <2 x i8*> @vector_inttoptrbitcast_vector({<16 x i8>, <16 x i8>} %x) {
; CHECK-LABEL: @vector_inttoptrbitcast_vector(
%a = alloca {<16 x i8>, <16 x i8>}
; CHECK-NOT: alloca

store {<16 x i8>, <16 x i8>} %x, {<16 x i8>, <16 x i8>}* %a
; CHECK-NOT: store

%cast = bitcast {<16 x i8>, <16 x i8>}* %a to <2 x i8*>*
%vec = load <2 x i8*>, <2 x i8*>* %cast
; CHECK-NOT: load
; CHECK: extractvalue
; CHECK: extractvalue
; CHECK: bitcast
; CHECK: inttoptr

ret <2 x i8*> %vec
}

define <16 x i8> @vector_ptrtointbitcast_vector({<2 x i8*>, <2 x i8*>} %x) {
; CHECK-LABEL: @vector_ptrtointbitcast_vector(
%a = alloca {<2 x i8*>, <2 x i8*>}
; CHECK-NOT: alloca

store {<2 x i8*>, <2 x i8*>} %x, {<2 x i8*>, <2 x i8*>}* %a
; CHECK-NOT: store

%cast = bitcast {<2 x i8*>, <2 x i8*>}* %a to <16 x i8>*
%vec = load <16 x i8>, <16 x i8>* %cast
; CHECK-NOT: load
; CHECK: extractvalue
; CHECK: ptrtoint
; CHECK: bitcast
; CHECK: extractvalue

ret <16 x i8> %vec
}