diff --git a/llvm/include/llvm/Support/TargetOpcodes.def b/llvm/include/llvm/Support/TargetOpcodes.def
index e55314568d683..a4caf1621b656 100644
--- a/llvm/include/llvm/Support/TargetOpcodes.def
+++ b/llvm/include/llvm/Support/TargetOpcodes.def
@@ -643,6 +643,7 @@ HANDLE_TARGET_OPCODE(G_FMA)
 
 /// Generic FP multiply and add. Behaves as separate fmul and fadd.
 HANDLE_TARGET_OPCODE(G_FMAD)
+HANDLE_TARGET_OPCODE(G_STRICT_FMULADD)
 
 /// Generic FP division.
 HANDLE_TARGET_OPCODE(G_FDIV)
diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td
index e3f995d53484f..88aa448bc24c2 100644
--- a/llvm/include/llvm/Target/GenericOpcodes.td
+++ b/llvm/include/llvm/Target/GenericOpcodes.td
@@ -1729,6 +1729,7 @@ def G_STRICT_FREM : ConstrainedInstruction;
 def G_STRICT_FMA : ConstrainedInstruction;
 def G_STRICT_FSQRT : ConstrainedInstruction;
 def G_STRICT_FLDEXP : ConstrainedInstruction;
+def G_STRICT_FMULADD : ConstrainedInstruction;
 
 //------------------------------------------------------------------------------
 // Memory intrinsics
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 884c3f1692e94..d40e2f396228c 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2080,6 +2080,8 @@ static unsigned getConstrainedOpcode(Intrinsic::ID ID) {
     return TargetOpcode::G_STRICT_FSQRT;
   case Intrinsic::experimental_constrained_ldexp:
     return TargetOpcode::G_STRICT_FLDEXP;
+  case Intrinsic::experimental_constrained_fmuladd:
+    return TargetOpcode::G_STRICT_FMULADD;
   default:
     return 0;
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 989950fb8f8b5..b042e7d545785 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -237,6 +237,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectExt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
                  bool IsSigned) const;
 
+  bool selectStrictFMulAdd(Register ResVReg, const SPIRVType *ResType,
+                           MachineInstr &I) const;
+
   bool selectTrunc(Register ResVReg, const SPIRVType *ResType,
                    MachineInstr &I) const;
 
@@ -746,6 +749,9 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
   case TargetOpcode::G_FMA:
     return selectExtInst(ResVReg, ResType, I, CL::fma, GL::Fma);
 
+  case TargetOpcode::G_STRICT_FMULADD:
+    return selectStrictFMulAdd(ResVReg, ResType, I);
+
   case TargetOpcode::G_STRICT_FLDEXP:
     return selectExtInst(ResVReg, ResType, I, CL::ldexp);
 
@@ -1193,6 +1199,40 @@ bool SPIRVInstructionSelector::selectOpWithSrcs(Register ResVReg,
   return MIB.constrainAllUses(TII, TRI, RBI);
 }
 
+// Lower G_STRICT_FMULADD as an explicit OpFMul followed by OpFAdd; the
+// constrained fmuladd intrinsic permits unfused mul+add semantics.
+bool SPIRVInstructionSelector::selectStrictFMulAdd(Register ResVReg,
+                                                   const SPIRVType *ResType,
+                                                   MachineInstr &I) const {
+  MachineBasicBlock &BB = *I.getParent();
+  Register MulLHS = I.getOperand(1).getReg();
+  Register MulRHS = I.getOperand(2).getReg();
+  Register AddRHS = I.getOperand(3).getReg();
+  SPIRVType *MulLHSType = GR.getSPIRVTypeForVReg(MulLHS);
+  unsigned MulOpcode, AddOpcode;
+  if (MulLHSType->getOpcode() == SPIRV::OpTypeFloat) {
+    MulOpcode = SPIRV::OpFMulS;
+    AddOpcode = SPIRV::OpFAddS;
+  } else {
+    MulOpcode = SPIRV::OpFMulV;
+    AddOpcode = SPIRV::OpFAddV;
+  }
+  Register MulTemp = MRI->createVirtualRegister(MRI->getRegClass(MulLHS));
+  // Propagate a failed operand constraint instead of silently ignoring it.
+  bool Succeeded = BuildMI(BB, I, I.getDebugLoc(), TII.get(MulOpcode))
+                       .addDef(MulTemp)
+                       .addUse(GR.getSPIRVTypeID(ResType))
+                       .addUse(MulLHS)
+                       .addUse(MulRHS)
+                       .constrainAllUses(TII, TRI, RBI);
+  return Succeeded && BuildMI(BB, I, I.getDebugLoc(), TII.get(AddOpcode))
+                          .addDef(ResVReg)
+                          .addUse(GR.getSPIRVTypeID(ResType))
+                          .addUse(MulTemp)
+                          .addUse(AddRHS)
+                          .constrainAllUses(TII, TRI, RBI);
+}
+
 bool SPIRVInstructionSelector::selectUnOp(Register ResVReg,
                                           const SPIRVType *ResType,
                                           MachineInstr &I,
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 53074ea3b2597..438c2704a6414 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -193,7 +193,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       .legalFor(allIntScalarsAndVectors)
       .legalIf(extendedScalarsAndVectors);
 
-  getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA})
+  getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA, G_STRICT_FMULADD})
       .legalFor(allFloatScalarsAndVectors);
 
   getActionDefinitionsBuilder(G_STRICT_FLDEXP)
diff --git a/llvm/test/CodeGen/SPIRV/llvm-intrinsics/constrained-fmuladd.ll b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/constrained-fmuladd.ll
new file mode 100644
index 0000000000000..340f2d78fc21b
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/constrained-fmuladd.ll
@@ -0,0 +1,64 @@
+; RUN: llc -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpDecorate %[[#]] FPRoundingMode RTE
+; CHECK-DAG: OpDecorate %[[#]] FPRoundingMode RTZ
+; CHECK-DAG: OpDecorate %[[#]] FPRoundingMode RTP
+; CHECK-DAG: OpDecorate %[[#]] FPRoundingMode RTN
+; CHECK-DAG: OpDecorate %[[#]] FPRoundingMode RTE
+
+; CHECK: OpFMul %[[#]] %[[#]] %[[#]]
+; CHECK: OpFAdd %[[#]] %[[#]] %[[#]]
+define spir_kernel void @test_f32(float %a) {
+entry:
+  %r = tail call float @llvm.experimental.constrained.fmuladd.f32(
+      float %a, float %a, float %a,
+      metadata !"round.tonearest", metadata !"fpexcept.strict")
+  ret void
+}
+
+; CHECK: OpFMul %[[#]] %[[#]] %[[#]]
+; CHECK: OpFAdd %[[#]] %[[#]] %[[#]]
+define spir_kernel void @test_f64(double %a) {
+entry:
+  %r = tail call double @llvm.experimental.constrained.fmuladd.f64(
+      double %a, double %a, double %a,
+      metadata !"round.towardzero", metadata !"fpexcept.strict")
+  ret void
+}
+
+; CHECK: OpFMul %[[#]] %[[#]] %[[#]]
+; CHECK: OpFAdd %[[#]] %[[#]] %[[#]]
+define spir_kernel void @test_v2f32(<2 x float> %a) {
+entry:
+  %r = tail call <2 x float> @llvm.experimental.constrained.fmuladd.v2f32(
+      <2 x float> %a, <2 x float> %a, <2 x float> %a,
+      metadata !"round.upward", metadata !"fpexcept.strict")
+  ret void
+}
+
+; CHECK: OpFMul %[[#]] %[[#]] %[[#]]
+; CHECK: OpFAdd %[[#]] %[[#]] %[[#]]
+define spir_kernel void @test_v4f32(<4 x float> %a) {
+entry:
+  %r = tail call <4 x float> @llvm.experimental.constrained.fmuladd.v4f32(
+      <4 x float> %a, <4 x float> %a, <4 x float> %a,
+      metadata !"round.downward", metadata !"fpexcept.strict")
+  ret void
+}
+
+; CHECK: OpFMul %[[#]] %[[#]] %[[#]]
+; CHECK: OpFAdd %[[#]] %[[#]] %[[#]]
+define spir_kernel void @test_v2f64(<2 x double> %a) {
+entry:
+  %r = tail call <2 x double> @llvm.experimental.constrained.fmuladd.v2f64(
+      <2 x double> %a, <2 x double> %a, <2 x double> %a,
+      metadata !"round.tonearest", metadata !"fpexcept.strict")
+  ret void
+}
+
+declare float @llvm.experimental.constrained.fmuladd.f32(float, float, float, metadata, metadata)
+declare double @llvm.experimental.constrained.fmuladd.f64(double, double, double, metadata, metadata)
+declare <2 x float> @llvm.experimental.constrained.fmuladd.v2f32(<2 x float>, <2 x float>, <2 x float>, metadata, metadata)
+declare <4 x float> @llvm.experimental.constrained.fmuladd.v4f32(<4 x float>, <4 x float>, <4 x float>, metadata, metadata)
+declare <2 x double> @llvm.experimental.constrained.fmuladd.v2f64(<2 x double>, <2 x double>, <2 x double>, metadata, metadata)