Skip to content

Commit

Permalink
[IR] allow fast-math-flags on select of FP values
Browse files Browse the repository at this point in the history
This is a minimal start to correcting a problem most directly discussed in PR38086:
https://bugs.llvm.org/show_bug.cgi?id=38086

We have been hacking around a limitation for FP select patterns by using the
fast-math-flags on the condition of the select rather than the select itself.
This patch just allows FMF to appear with the 'select' opcode. No changes are
needed to "FPMathOperator": it already includes select-of-FP, since that
definition is based on the (return) value type.

Once we have this ability, we can start correcting and adding IR transforms
to use the FMF on a 'select' instruction. The instcombine and vectorizer test
diffs only show that the IRBuilder change is behaving as expected by applying
an FMF guard value to 'select'.

For reference:
rL241901 - allowed FMF with fcmp
rL255555 - allowed FMF with FP calls

Differential Revision: https://reviews.llvm.org/D61917

llvm-svn: 361401
  • Loading branch information
rotateright committed May 22, 2019
1 parent 63305c8 commit 5a4f7cf
Show file tree
Hide file tree
Showing 11 changed files with 92 additions and 36 deletions.
7 changes: 6 additions & 1 deletion llvm/docs/LangRef.rst
Expand Up @@ -9931,7 +9931,7 @@ Syntax:

::

<result> = select selty <cond>, <ty> <val1>, <ty> <val2> ; yields ty
<result> = select [fast-math flags] selty <cond>, <ty> <val1>, <ty> <val2> ; yields ty

selty is either i1 or {<N x i1>}

Expand All @@ -9948,6 +9948,11 @@ The '``select``' instruction requires an 'i1' value or a vector of 'i1'
values indicating the condition, and two values of the same :ref:`first
class <t_firstclass>` type.

#. The optional ``fast-math flags`` marker indicates that the select has one or more
:ref:`fast-math flags <fastmath>`. These are optimization hints to enable
otherwise unsafe floating-point optimizations. Fast-math flags are only valid
for selects that return a floating-point scalar or vector type.

Semantics:
""""""""""

Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/IR/IRBuilder.h
Expand Up @@ -2067,6 +2067,8 @@ class IRBuilder : public IRBuilderBase, public Inserter {
MDNode *Unpred = MDFrom->getMetadata(LLVMContext::MD_unpredictable);
Sel = addBranchMetadata(Sel, Prof, Unpred);
}
if (isa<FPMathOperator>(Sel))
Sel = cast<SelectInst>(setFPAttrs(Sel, nullptr /* MDNode* */, FMF));
return Insert(Sel, Name);
}

Expand Down
14 changes: 13 additions & 1 deletion llvm/lib/AsmParser/LLParser.cpp
Expand Up @@ -5701,7 +5701,19 @@ int LLParser::ParseInstruction(Instruction *&Inst, BasicBlock *BB,
case lltok::kw_inttoptr:
case lltok::kw_ptrtoint: return ParseCast(Inst, PFS, KeywordVal);
// Other.
case lltok::kw_select: return ParseSelect(Inst, PFS);
case lltok::kw_select: {
FastMathFlags FMF = EatFastMathFlagsIfPresent();
int Res = ParseSelect(Inst, PFS);
if (Res != 0)
return Res;
if (FMF.any()) {
if (!Inst->getType()->isFPOrFPVectorTy())
return Error(Loc, "fast-math-flags specified for select without "
"floating-point scalar or vector return type");
Inst->setFastMathFlags(FMF);
}
return 0;
}
case lltok::kw_va_arg: return ParseVA_Arg(Inst, PFS);
case lltok::kw_extractelement: return ParseExtractElement(Inst, PFS);
case lltok::kw_insertelement: return ParseInsertElement(Inst, PFS);
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Bitcode/Reader/BitcodeReader.cpp
Expand Up @@ -3835,6 +3835,11 @@ Error BitcodeReader::parseFunctionBody(Function *F) {

I = SelectInst::Create(Cond, TrueVal, FalseVal);
InstructionList.push_back(I);
if (OpNum < Record.size() && isa<FPMathOperator>(I)) {
FastMathFlags FMF = getDecodedFastMathFlags(Record[OpNum]);
if (FMF.any())
I->setFastMathFlags(FMF);
}
break;
}

Expand Down
6 changes: 5 additions & 1 deletion llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
Expand Up @@ -2636,12 +2636,16 @@ void ModuleBitcodeWriter::writeInstruction(const Instruction &I,
Vals.append(IVI->idx_begin(), IVI->idx_end());
break;
}
case Instruction::Select:
case Instruction::Select: {
Code = bitc::FUNC_CODE_INST_VSELECT;
pushValueAndType(I.getOperand(1), InstID, Vals);
pushValue(I.getOperand(2), InstID, Vals);
pushValueAndType(I.getOperand(0), InstID, Vals);
uint64_t Flags = getOptimizationFlags(&I);
if (Flags != 0)
Vals.push_back(Flags);
break;
}
case Instruction::ExtractElement:
Code = bitc::FUNC_CODE_INST_EXTRACTELT;
pushValueAndType(I.getOperand(0), InstID, Vals);
Expand Down
28 changes: 28 additions & 0 deletions llvm/test/Bitcode/compatibility.ll
Expand Up @@ -815,6 +815,34 @@ define void @fastmathflags_binops(float %op1, float %op2) {
ret void
}

; Check each fast-math flag on a scalar floating-point select.
; Every flag (nnan, ninf, nsz, arcp, contract, afn, reassoc, fast) is applied
; to its own 'select' of two float operands, and the directive that follows
; each instruction expects it to be printed back with the same flag.
define void @fastmathflags_select(i1 %cond, float %op1, float %op2) {
%f.nnan = select nnan i1 %cond, float %op1, float %op2
; CHECK: %f.nnan = select nnan i1 %cond, float %op1, float %op2
%f.ninf = select ninf i1 %cond, float %op1, float %op2
; CHECK: %f.ninf = select ninf i1 %cond, float %op1, float %op2
%f.nsz = select nsz i1 %cond, float %op1, float %op2
; CHECK: %f.nsz = select nsz i1 %cond, float %op1, float %op2
%f.arcp = select arcp i1 %cond, float %op1, float %op2
; CHECK: %f.arcp = select arcp i1 %cond, float %op1, float %op2
%f.contract = select contract i1 %cond, float %op1, float %op2
; CHECK: %f.contract = select contract i1 %cond, float %op1, float %op2
%f.afn = select afn i1 %cond, float %op1, float %op2
; CHECK: %f.afn = select afn i1 %cond, float %op1, float %op2
%f.reassoc = select reassoc i1 %cond, float %op1, float %op2
; CHECK: %f.reassoc = select reassoc i1 %cond, float %op1, float %op2
%f.fast = select fast i1 %cond, float %op1, float %op2
; CHECK: %f.fast = select fast i1 %cond, float %op1, float %op2
ret void
}

; Check fast-math flags on a vector floating-point select.
; Uses a <2 x i1> condition with <2 x double> operands, once with two flags
; combined (nnan nsz) and once with 'fast'; the directive after each
; instruction expects it to be printed back with the same flags.
define void @fastmathflags_vector_select(<2 x i1> %cond, <2 x double> %op1, <2 x double> %op2) {
%f.nnan.nsz = select nnan nsz <2 x i1> %cond, <2 x double> %op1, <2 x double> %op2
; CHECK: %f.nnan.nsz = select nnan nsz <2 x i1> %cond, <2 x double> %op1, <2 x double> %op2
%f.fast = select fast <2 x i1> %cond, <2 x double> %op1, <2 x double> %op2
; CHECK: %f.fast = select fast <2 x i1> %cond, <2 x double> %op1, <2 x double> %op2
ret void
}

; Check various fast math flags and floating-point types on calls.

declare float @fmf1()
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/Generic/expand-experimental-reductions.ll
Expand Up @@ -277,7 +277,7 @@ define double @fmax_f64(<2 x double> %vec) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[RDX_SHUF:%.*]] = shufflevector <2 x double> [[VEC:%.*]], <2 x double> undef, <2 x i32> <i32 1, i32 undef>
; CHECK-NEXT: [[RDX_MINMAX_CMP:%.*]] = fcmp fast ogt <2 x double> [[VEC]], [[RDX_SHUF]]
; CHECK-NEXT: [[RDX_MINMAX_SELECT:%.*]] = select <2 x i1> [[RDX_MINMAX_CMP]], <2 x double> [[VEC]], <2 x double> [[RDX_SHUF]]
; CHECK-NEXT: [[RDX_MINMAX_SELECT:%.*]] = select fast <2 x i1> [[RDX_MINMAX_CMP]], <2 x double> [[VEC]], <2 x double> [[RDX_SHUF]]
; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x double> [[RDX_MINMAX_SELECT]], i32 0
; CHECK-NEXT: ret double [[TMP0]]
;
Expand All @@ -291,7 +291,7 @@ define double @fmin_f64(<2 x double> %vec) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[RDX_SHUF:%.*]] = shufflevector <2 x double> [[VEC:%.*]], <2 x double> undef, <2 x i32> <i32 1, i32 undef>
; CHECK-NEXT: [[RDX_MINMAX_CMP:%.*]] = fcmp fast olt <2 x double> [[VEC]], [[RDX_SHUF]]
; CHECK-NEXT: [[RDX_MINMAX_SELECT:%.*]] = select <2 x i1> [[RDX_MINMAX_CMP]], <2 x double> [[VEC]], <2 x double> [[RDX_SHUF]]
; CHECK-NEXT: [[RDX_MINMAX_SELECT:%.*]] = select fast <2 x i1> [[RDX_MINMAX_CMP]], <2 x double> [[VEC]], <2 x double> [[RDX_SHUF]]
; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x double> [[RDX_MINMAX_SELECT]], i32 0
; CHECK-NEXT: ret double [[TMP0]]
;
Expand Down
16 changes: 8 additions & 8 deletions llvm/test/Transforms/InstCombine/fast-math.ll
Expand Up @@ -820,7 +820,7 @@ declare fp128 @fminl(fp128, fp128)
define float @max1(float %a, float %b) {
; CHECK-LABEL: @max1(
; CHECK-NEXT: [[TMP1:%.*]] = fcmp fast ogt float [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], float [[A]], float [[B]]
; CHECK-NEXT: [[TMP2:%.*]] = select fast i1 [[TMP1]], float [[A]], float [[B]]
; CHECK-NEXT: ret float [[TMP2]]
;
%c = fpext float %a to double
Expand All @@ -833,7 +833,7 @@ define float @max1(float %a, float %b) {
define float @max2(float %a, float %b) {
; CHECK-LABEL: @max2(
; CHECK-NEXT: [[TMP1:%.*]] = fcmp nnan nsz ogt float [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], float [[A]], float [[B]]
; CHECK-NEXT: [[TMP2:%.*]] = select nnan nsz i1 [[TMP1]], float [[A]], float [[B]]
; CHECK-NEXT: ret float [[TMP2]]
;
%c = call nnan float @fmaxf(float %a, float %b)
Expand All @@ -844,7 +844,7 @@ define float @max2(float %a, float %b) {
define double @max3(double %a, double %b) {
; CHECK-LABEL: @max3(
; CHECK-NEXT: [[TMP1:%.*]] = fcmp fast ogt double [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], double [[A]], double [[B]]
; CHECK-NEXT: [[TMP2:%.*]] = select fast i1 [[TMP1]], double [[A]], double [[B]]
; CHECK-NEXT: ret double [[TMP2]]
;
%c = call fast double @fmax(double %a, double %b)
Expand All @@ -854,7 +854,7 @@ define double @max3(double %a, double %b) {
define fp128 @max4(fp128 %a, fp128 %b) {
; CHECK-LABEL: @max4(
; CHECK-NEXT: [[TMP1:%.*]] = fcmp nnan nsz ogt fp128 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], fp128 [[A]], fp128 [[B]]
; CHECK-NEXT: [[TMP2:%.*]] = select nnan nsz i1 [[TMP1]], fp128 [[A]], fp128 [[B]]
; CHECK-NEXT: ret fp128 [[TMP2]]
;
%c = call nnan fp128 @fmaxl(fp128 %a, fp128 %b)
Expand All @@ -865,7 +865,7 @@ define fp128 @max4(fp128 %a, fp128 %b) {
define float @min1(float %a, float %b) {
; CHECK-LABEL: @min1(
; CHECK-NEXT: [[TMP1:%.*]] = fcmp nnan nsz olt float [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], float [[A]], float [[B]]
; CHECK-NEXT: [[TMP2:%.*]] = select nnan nsz i1 [[TMP1]], float [[A]], float [[B]]
; CHECK-NEXT: ret float [[TMP2]]
;
%c = fpext float %a to double
Expand All @@ -878,7 +878,7 @@ define float @min1(float %a, float %b) {
define float @min2(float %a, float %b) {
; CHECK-LABEL: @min2(
; CHECK-NEXT: [[TMP1:%.*]] = fcmp fast olt float [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], float [[A]], float [[B]]
; CHECK-NEXT: [[TMP2:%.*]] = select fast i1 [[TMP1]], float [[A]], float [[B]]
; CHECK-NEXT: ret float [[TMP2]]
;
%c = call fast float @fminf(float %a, float %b)
Expand All @@ -888,7 +888,7 @@ define float @min2(float %a, float %b) {
define double @min3(double %a, double %b) {
; CHECK-LABEL: @min3(
; CHECK-NEXT: [[TMP1:%.*]] = fcmp nnan nsz olt double [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], double [[A]], double [[B]]
; CHECK-NEXT: [[TMP2:%.*]] = select nnan nsz i1 [[TMP1]], double [[A]], double [[B]]
; CHECK-NEXT: ret double [[TMP2]]
;
%c = call nnan double @fmin(double %a, double %b)
Expand All @@ -898,7 +898,7 @@ define double @min3(double %a, double %b) {
define fp128 @min4(fp128 %a, fp128 %b) {
; CHECK-LABEL: @min4(
; CHECK-NEXT: [[TMP1:%.*]] = fcmp fast olt fp128 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], fp128 [[A]], fp128 [[B]]
; CHECK-NEXT: [[TMP2:%.*]] = select fast i1 [[TMP1]], fp128 [[A]], fp128 [[B]]
; CHECK-NEXT: ret fp128 [[TMP2]]
;
%c = call fast fp128 @fminl(fp128 %a, fp128 %b)
Expand Down
8 changes: 4 additions & 4 deletions llvm/test/Transforms/InstCombine/pow-sqrt.ll
Expand Up @@ -36,7 +36,7 @@ define double @pow_libcall_half_approx(double %x) {
; CHECK-NEXT: [[SQRT:%.*]] = call afn double @sqrt(double [[X:%.*]])
; CHECK-NEXT: [[ABS:%.*]] = call afn double @llvm.fabs.f64(double [[SQRT]])
; CHECK-NEXT: [[ISINF:%.*]] = fcmp afn oeq double [[X]], 0xFFF0000000000000
; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[ISINF]], double 0x7FF0000000000000, double [[ABS]]
; CHECK-NEXT: [[TMP1:%.*]] = select afn i1 [[ISINF]], double 0x7FF0000000000000, double [[ABS]]
; CHECK-NEXT: ret double [[TMP1]]
;
%pow = call afn double @pow(double %x, double 5.0e-01)
Expand All @@ -48,7 +48,7 @@ define <2 x double> @pow_intrinsic_half_approx(<2 x double> %x) {
; CHECK-NEXT: [[SQRT:%.*]] = call afn <2 x double> @llvm.sqrt.v2f64(<2 x double> [[X:%.*]])
; CHECK-NEXT: [[ABS:%.*]] = call afn <2 x double> @llvm.fabs.v2f64(<2 x double> [[SQRT]])
; CHECK-NEXT: [[ISINF:%.*]] = fcmp afn oeq <2 x double> [[X]], <double 0xFFF0000000000000, double 0xFFF0000000000000>
; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> [[ISINF]], <2 x double> <double 0x7FF0000000000000, double 0x7FF0000000000000>, <2 x double> [[ABS]]
; CHECK-NEXT: [[TMP1:%.*]] = select afn <2 x i1> [[ISINF]], <2 x double> <double 0x7FF0000000000000, double 0x7FF0000000000000>, <2 x double> [[ABS]]
; CHECK-NEXT: ret <2 x double> [[TMP1]]
;
%pow = call afn <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double 5.0e-01, double 5.0e-01>)
Expand Down Expand Up @@ -92,7 +92,7 @@ define double @pow_libcall_half_nsz(double %x) {
; CHECK-LABEL: @pow_libcall_half_nsz(
; CHECK-NEXT: [[SQRT:%.*]] = call nsz double @sqrt(double [[X:%.*]])
; CHECK-NEXT: [[ISINF:%.*]] = fcmp nsz oeq double [[X]], 0xFFF0000000000000
; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[ISINF]], double 0x7FF0000000000000, double [[SQRT]]
; CHECK-NEXT: [[TMP1:%.*]] = select nsz i1 [[ISINF]], double 0x7FF0000000000000, double [[SQRT]]
; CHECK-NEXT: ret double [[TMP1]]
;
%pow = call nsz double @pow(double %x, double 5.0e-01)
Expand All @@ -103,7 +103,7 @@ define double @pow_intrinsic_half_nsz(double %x) {
; CHECK-LABEL: @pow_intrinsic_half_nsz(
; CHECK-NEXT: [[SQRT:%.*]] = call nsz double @llvm.sqrt.f64(double [[X:%.*]])
; CHECK-NEXT: [[ISINF:%.*]] = fcmp nsz oeq double [[X]], 0xFFF0000000000000
; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[ISINF]], double 0x7FF0000000000000, double [[SQRT]]
; CHECK-NEXT: [[TMP1:%.*]] = select nsz i1 [[ISINF]], double 0x7FF0000000000000, double [[SQRT]]
; CHECK-NEXT: ret double [[TMP1]]
;
%pow = call nsz double @llvm.pow.f64(double %x, double 5.0e-01)
Expand Down
Expand Up @@ -74,10 +74,10 @@ define float @minloopattr(float* nocapture readonly %arg) #0 {
; CHECK: middle.block:
; CHECK-NEXT: [[RDX_SHUF:%.*]] = shufflevector <4 x float> [[TMP6]], <4 x float> undef, <4 x i32> <i32 2, i32 3, i32 undef, i32 undef>
; CHECK-NEXT: [[RDX_MINMAX_CMP:%.*]] = fcmp fast olt <4 x float> [[TMP6]], [[RDX_SHUF]]
; CHECK-NEXT: [[RDX_MINMAX_SELECT:%.*]] = select <4 x i1> [[RDX_MINMAX_CMP]], <4 x float> [[TMP6]], <4 x float> [[RDX_SHUF]]
; CHECK-NEXT: [[RDX_MINMAX_SELECT:%.*]] = select fast <4 x i1> [[RDX_MINMAX_CMP]], <4 x float> [[TMP6]], <4 x float> [[RDX_SHUF]]
; CHECK-NEXT: [[RDX_SHUF1:%.*]] = shufflevector <4 x float> [[RDX_MINMAX_SELECT]], <4 x float> undef, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
; CHECK-NEXT: [[RDX_MINMAX_CMP2:%.*]] = fcmp fast olt <4 x float> [[RDX_MINMAX_SELECT]], [[RDX_SHUF1]]
; CHECK-NEXT: [[RDX_MINMAX_SELECT3:%.*]] = select <4 x i1> [[RDX_MINMAX_CMP2]], <4 x float> [[RDX_MINMAX_SELECT]], <4 x float> [[RDX_SHUF1]]
; CHECK-NEXT: [[RDX_MINMAX_SELECT3:%.*]] = select fast <4 x i1> [[RDX_MINMAX_CMP2]], <4 x float> [[RDX_MINMAX_SELECT]], <4 x float> [[RDX_SHUF1]]
; CHECK-NEXT: [[TMP8:%.*]] = extractelement <4 x float> [[RDX_MINMAX_SELECT3]], i32 0
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 65536, 65536
; CHECK-NEXT: br i1 [[CMP_N]], label [[OUT:%.*]], label [[SCALAR_PH]]
Expand Down

0 comments on commit 5a4f7cf

Please sign in to comment.