diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 5259db1ff2dd7..98c7709acf938 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -220,8 +220,10 @@ class SPIRVInstructionSelector : public InstructionSelector { bool selectConst(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; - bool selectSelect(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, - bool IsSigned) const; + bool selectSelect(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + bool selectSelectDefaultArgs(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I, bool IsSigned) const; bool selectIToF(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, bool IsSigned, unsigned Opcode) const; bool selectExt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, @@ -510,7 +512,18 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) { if (isTypeFoldingSupported(Def->getOpcode()) && Def->getOpcode() != TargetOpcode::G_CONSTANT && Def->getOpcode() != TargetOpcode::G_FCONSTANT) { - bool Res = selectImpl(I, *CoverageInfo); + bool Res = false; + if (Def->getOpcode() == TargetOpcode::G_SELECT) { + Register SelectDstReg = Def->getOperand(0).getReg(); + Res = selectSelect(SelectDstReg, GR.getSPIRVTypeForVReg(SelectDstReg), + *Def); + GR.invalidateMachineInstr(Def); + Def->removeFromParent(); + MRI->replaceRegWith(DstReg, SelectDstReg); + GR.invalidateMachineInstr(&I); + I.removeFromParent(); + } else + Res = selectImpl(I, *CoverageInfo); LLVM_DEBUG({ if (!Res && Def->getOpcode() != TargetOpcode::G_CONSTANT) { dbgs() << "Unexpected pattern in ASSIGN_TYPE.\nInstruction: "; @@ -2565,8 +2578,52 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes, bool SPIRVInstructionSelector::selectSelect(Register ResVReg, const SPIRVType *ResType, - MachineInstr &I, - bool IsSigned) const { + MachineInstr &I) const { + Register SelectFirstArg = I.getOperand(2).getReg(); + Register SelectSecondArg = I.getOperand(3).getReg(); + assert(ResType == GR.getSPIRVTypeForVReg(SelectFirstArg) && + ResType == GR.getSPIRVTypeForVReg(SelectSecondArg)); + + bool IsFloatTy = + GR.isScalarOrVectorOfType(SelectFirstArg, SPIRV::OpTypeFloat); + bool IsPtrTy = + GR.isScalarOrVectorOfType(SelectFirstArg, SPIRV::OpTypePointer); + bool IsVectorTy = GR.getSPIRVTypeForVReg(SelectFirstArg)->getOpcode() == + SPIRV::OpTypeVector; + + bool IsScalarBool = + GR.isScalarOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool); + unsigned Opcode; + if (IsVectorTy) { + if (IsFloatTy) { + Opcode = IsScalarBool ? SPIRV::OpSelectVFSCond : SPIRV::OpSelectVFVCond; + } else if (IsPtrTy) { + Opcode = IsScalarBool ? SPIRV::OpSelectVPSCond : SPIRV::OpSelectVPVCond; + } else { + Opcode = IsScalarBool ? SPIRV::OpSelectVISCond : SPIRV::OpSelectVIVCond; + } + } else { + if (IsFloatTy) { + Opcode = IsScalarBool ? SPIRV::OpSelectSFSCond : SPIRV::OpSelectVFVCond; + } else if (IsPtrTy) { + Opcode = IsScalarBool ? SPIRV::OpSelectSPSCond : SPIRV::OpSelectVPVCond; + } else { + Opcode = IsScalarBool ? SPIRV::OpSelectSISCond : SPIRV::OpSelectVIVCond; + } + } + return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(I.getOperand(1).getReg()) + .addUse(SelectFirstArg) + .addUse(SelectSecondArg) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectSelectDefaultArgs(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, + bool IsSigned) const { // To extend a bool, we need to use OpSelect between constants. Register ZeroReg = buildZerosVal(ResType, I); Register OneReg = buildOnesVal(IsSigned, ResType, I); @@ -2598,7 +2655,7 @@ bool SPIRVInstructionSelector::selectIToF(Register ResVReg, TmpType = GR.getOrCreateSPIRVVectorType(TmpType, NumElts, I, TII); } SrcReg = createVirtualRegister(TmpType, &GR, MRI, MRI->getMF()); - selectSelect(SrcReg, TmpType, I, false); + selectSelectDefaultArgs(SrcReg, TmpType, I, false); } return selectOpWithSrcs(ResVReg, ResType, I, {SrcReg}, Opcode); } @@ -2608,7 +2665,7 @@ bool SPIRVInstructionSelector::selectExt(Register ResVReg, MachineInstr &I, bool IsSigned) const { Register SrcReg = I.getOperand(1).getReg(); if (GR.isScalarOrVectorOfType(SrcReg, SPIRV::OpTypeBool)) - return selectSelect(ResVReg, ResType, I, IsSigned); + return selectSelectDefaultArgs(ResVReg, ResType, I, IsSigned); SPIRVType *SrcType = GR.getSPIRVTypeForVReg(SrcReg); if (SrcType == ResType) diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index b62db7fd62b2e..1a08c6ac0dcaf 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -441,13 +441,10 @@ void insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpvType, // Tablegen definition assumes SPIRV::ASSIGN_TYPE pseudo-instruction is // present after each auto-folded instruction to take a type reference from. Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg)); - if (auto *RC = MRI.getRegClassOrNull(Reg)) { - MRI.setRegClass(NewReg, RC); - } else { - auto RegClass = GR->getRegClass(SpvType); - MRI.setRegClass(NewReg, RegClass); - MRI.setRegClass(Reg, RegClass); - } + const auto *RegClass = GR->getRegClass(SpvType); + MRI.setRegClass(NewReg, RegClass); + MRI.setRegClass(Reg, RegClass); + GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF()); // This is to make it convenient for Legalizer to get the SPIRVType // when processing the actual MI (i.e. not pseudo one). diff --git a/llvm/test/CodeGen/SPIRV/instructions/issue-135572-emit-float-opselect.ll b/llvm/test/CodeGen/SPIRV/instructions/issue-135572-emit-float-opselect.ll new file mode 100644 index 0000000000000..69ea054de1e4f --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/instructions/issue-135572-emit-float-opselect.ll @@ -0,0 +1,53 @@ +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val --target-env spv1.6 %} +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val --target-env spv1.6 %} + +; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32 +; CHECK-DAG: %[[#func_type:]] = OpTypeFunction %[[#float_32]] %[[#float_32]] %[[#float_32]] +; CHECK-DAG: %[[#bool:]] = OpTypeBool +; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4 +; CHECK-DAG: %[[#vec_func_type:]] = OpTypeFunction %[[#vec4_float_32]] %[[#vec4_float_32]] %[[#vec4_float_32]] +; CHECK-DAG: %[[#vec_4_bool:]] = OpTypeVector %[[#bool]] 4 + +define spir_func float @opselect_float_scalar_test(float %x, float %y) { +entry: + ; CHECK: %[[#]] = OpFunction %[[#float_32]] None %[[#func_type]] + ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#float_32]] + ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#float_32]] + ; CHECK: %[[#fcmp:]] = OpFOrdGreaterThan %[[#bool]] %[[#arg0]] %[[#arg1]] + ; CHECK: %[[#fselect:]] = OpSelect %[[#float_32]] %[[#fcmp]] %[[#arg0]] %[[#arg1]] + ; CHECK: OpReturnValue %[[#fselect]] + %0 = fcmp ogt float %x, %y + %1 = select i1 %0, float %x, float %y + ret float %1 +} + +define spir_func <4 x float> @opselect_float4_vec_test(<4 x float> %x, <4 x float> %y) { +entry: + ; CHECK: %[[#]] = OpFunction %[[#vec4_float_32]] None %[[#vec_func_type]] + ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]] + ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]] + ; CHECK: %[[#fcmp:]] = OpFOrdGreaterThan %[[#vec_4_bool]] %[[#arg0]] %[[#arg1]] + ; CHECK: %[[#fselect:]] = OpSelect %[[#vec4_float_32]] %[[#fcmp]] %[[#arg0]] %[[#arg1]] + ; CHECK: OpReturnValue %[[#fselect]] + %0 = fcmp ogt <4 x float> %x, %y + %1 = select <4 x i1> %0, <4 x float> %x, <4 x float> %y + ret <4 x float> %1 +} + +define spir_func <4 x float> @opselect_scalar_bool_float4_vec_test(float %a, float %b, <4 x float> %x, <4 x float> %y) { +entry: + ; CHECK: %[[#]] = OpFunction %[[#vec4_float_32]] None %[[#]] + ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#float_32]] + ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#float_32]] + ; CHECK: %[[#arg2:]] = OpFunctionParameter %[[#vec4_float_32]] + ; CHECK: %[[#arg3:]] = OpFunctionParameter %[[#vec4_float_32]] + ; CHECK: %[[#fcmp:]] = OpFOrdGreaterThan %[[#bool]] %[[#arg0]] %[[#arg1]] + ; CHECK: %[[#fselect:]] = OpSelect %[[#vec4_float_32]] %[[#fcmp]] %[[#arg2]] %[[#arg3]] + ; CHECK: OpReturnValue %[[#fselect]] + %0 = fcmp ogt float %a, %b + %1 = select i1 %0, <4 x float> %x, <4 x float> %y + ret <4 x float> %1 +} \ No newline at end of file