diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp index 4e3781cb4e9d59..cd5dc0e01ed0ea 100644 --- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp +++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp @@ -313,13 +313,22 @@ llvm::getIConstantVRegSExtVal(Register VReg, const MachineRegisterInfo &MRI) { namespace { -typedef std::function IsOpcodeFn; -typedef std::function(const MachineInstr *MI)> GetAPCstFn; - -std::optional getConstantVRegValWithLookThrough( - Register VReg, const MachineRegisterInfo &MRI, IsOpcodeFn IsConstantOpcode, - GetAPCstFn getAPCstValue, bool LookThroughInstrs = true, - bool LookThroughAnyExt = false) { +// This function is used in many places, and as such, it has some +// micro-optimizations to try and make it as fast as it can be. +// +// - We use template arguments to avoid an indirect call caused by passing a +// function_ref/std::function +// - GetAPCstValue does not return std::optional as that's expensive. +// Instead it returns true/false and places the result in a pre-constructed +// APInt. +// +// Please change this function carefully and benchmark your changes. +template +std::optional +getConstantVRegValWithLookThrough(Register VReg, const MachineRegisterInfo &MRI, + bool LookThroughInstrs = true, + bool LookThroughAnyExt = false) { SmallVector, 4> SeenOpcodes; MachineInstr *MI; @@ -353,26 +362,25 @@ std::optional getConstantVRegValWithLookThrough( if (!MI || !IsConstantOpcode(MI)) return std::nullopt; - std::optional MaybeVal = getAPCstValue(MI); - if (!MaybeVal) + APInt Val; + if (!GetAPCstValue(MI, Val)) return std::nullopt; - APInt &Val = *MaybeVal; - for (auto [Opcode, Size] : reverse(SeenOpcodes)) { - switch (Opcode) { + for (auto &Pair : reverse(SeenOpcodes)) { + switch (Pair.first) { case TargetOpcode::G_TRUNC: - Val = Val.trunc(Size); + Val = Val.trunc(Pair.second); break; case TargetOpcode::G_ANYEXT: case TargetOpcode::G_SEXT: - Val = Val.sext(Size); + Val = Val.sext(Pair.second); break; case TargetOpcode::G_ZEXT: - Val = Val.zext(Size); + Val = Val.zext(Pair.second); break; } } - return ValueAndVReg{Val, VReg}; + return ValueAndVReg{std::move(Val), VReg}; } bool isIConstant(const MachineInstr *MI) { @@ -394,42 +402,46 @@ bool isAnyConstant(const MachineInstr *MI) { return Opc == TargetOpcode::G_CONSTANT || Opc == TargetOpcode::G_FCONSTANT; } -std::optional getCImmAsAPInt(const MachineInstr *MI) { +bool getCImmAsAPInt(const MachineInstr *MI, APInt &Result) { const MachineOperand &CstVal = MI->getOperand(1); - if (CstVal.isCImm()) - return CstVal.getCImm()->getValue(); - return std::nullopt; + if (!CstVal.isCImm()) + return false; + Result = CstVal.getCImm()->getValue(); + return true; } -std::optional getCImmOrFPImmAsAPInt(const MachineInstr *MI) { +bool getCImmOrFPImmAsAPInt(const MachineInstr *MI, APInt &Result) { const MachineOperand &CstVal = MI->getOperand(1); if (CstVal.isCImm()) - return CstVal.getCImm()->getValue(); - if (CstVal.isFPImm()) - return CstVal.getFPImm()->getValueAPF().bitcastToAPInt(); - return std::nullopt; + Result = CstVal.getCImm()->getValue(); + else if (CstVal.isFPImm()) + Result = CstVal.getFPImm()->getValueAPF().bitcastToAPInt(); + else + return false; + return true; } } // end anonymous namespace std::optional llvm::getIConstantVRegValWithLookThrough( Register VReg, const MachineRegisterInfo &MRI, bool LookThroughInstrs) { - return getConstantVRegValWithLookThrough(VReg, MRI, isIConstant, - getCImmAsAPInt, LookThroughInstrs); + return getConstantVRegValWithLookThrough( + VReg, MRI, LookThroughInstrs); } std::optional llvm::getAnyConstantVRegValWithLookThrough( Register VReg, const MachineRegisterInfo &MRI, bool LookThroughInstrs, bool LookThroughAnyExt) { - return getConstantVRegValWithLookThrough( - VReg, MRI, isAnyConstant, getCImmOrFPImmAsAPInt, LookThroughInstrs, - LookThroughAnyExt); + return getConstantVRegValWithLookThrough( + VReg, MRI, LookThroughInstrs, LookThroughAnyExt); } std::optional llvm::getFConstantVRegValWithLookThrough( Register VReg, const MachineRegisterInfo &MRI, bool LookThroughInstrs) { - auto Reg = getConstantVRegValWithLookThrough( - VReg, MRI, isFConstant, getCImmOrFPImmAsAPInt, LookThroughInstrs); + auto Reg = + getConstantVRegValWithLookThrough( + VReg, MRI, LookThroughInstrs); if (!Reg) return std::nullopt; return FPValueAndVReg{getConstantFPVRegVal(Reg->VReg, MRI)->getValueAPF(),