diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 9a3e8aeb0c23a..4e08043d718f1 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -1735,6 +1735,28 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { return nullptr; } +/// Return true if the cast from integer to FP can be proven to be exact for all +/// possible inputs (the conversion does lose any precision). +static bool isKnownExactCastIntToFP(CastInst &I) { + CastInst::CastOps Opcode = I.getOpcode(); + assert((Opcode == CastInst::SIToFP || Opcode == CastInst::UIToFP) && + "Unexpected cast"); + Value *Src = I.getOperand(0); + Type *SrcTy = Src->getType(); + Type *FPTy = I.getType(); + bool IsSigned = Opcode == Instruction::SIToFP; + int SrcSize = (int)SrcTy->getScalarSizeInBits() - IsSigned; + + // Easy case - if the source integer type has less bits than the FP mantissa, + // then the cast must be exact. + if (SrcSize <= FPTy->getFPMantissaWidth()) + return true; + + // TODO: + // Try harder to find if the source integer type has less significant bits. + return false; +} + Instruction *InstCombiner::visitFPExt(CastInst &CI) { return commonCastTransforms(CI); } @@ -1746,32 +1768,32 @@ Instruction *InstCombiner::visitFPExt(CastInst &CI) { Instruction *InstCombiner::foldItoFPtoI(CastInst &FI) { if (!isa(FI.getOperand(0)) && !isa(FI.getOperand(0))) return nullptr; - Instruction *OpI = cast(FI.getOperand(0)); + auto *OpI = cast(FI.getOperand(0)); Value *X = OpI->getOperand(0); - Type *DestType = FI.getType(); - Type *FPType = OpI->getType(); Type *XType = X->getType(); - bool IsInputSigned = isa(OpI); + Type *DestType = FI.getType(); bool IsOutputSigned = isa(FI); - // We can safely assume the conversion won't overflow the output range, - // because (for example) (uint8_t)18293.f is undefined behavior. - // Since we can assume the conversion won't overflow, our decision as to // whether the input will fit in the float should depend on the minimum // of the input range and output range. // This means this is also safe for a signed input and unsigned output, since // a negative input would lead to undefined behavior. - int InputSize = (int)XType->getScalarSizeInBits() - IsInputSigned; - int OutputSize = (int)DestType->getScalarSizeInBits() - IsOutputSigned; - int MinIntWidth = std::min(InputSize, OutputSize); - - if (MinIntWidth > FPType->getFPMantissaWidth()) - return nullptr; + if (!isKnownExactCastIntToFP(*OpI)) { + // The first cast may not round exactly based on the source integer width + // and FP width, but the overflow UB rules can still allow this to fold. + // If the destination type is narrow, that means the intermediate FP value + // must be large enough to hold the source value exactly. + // For example, (uint8_t)((float)(uint32_t 16777217) is undefined behavior. + int OutputSize = (int)DestType->getScalarSizeInBits() - IsOutputSigned; + if (OutputSize > OpI->getType()->getFPMantissaWidth()) + return nullptr; + } if (DestType->getScalarSizeInBits() > XType->getScalarSizeInBits()) { + bool IsInputSigned = isa(OpI); if (IsInputSigned && IsOutputSigned) return new SExtInst(X, DestType); return new ZExtInst(X, DestType);