diff --git a/llvm/lib/Target/X86/X86AsmPrinter.cpp b/llvm/lib/Target/X86/X86AsmPrinter.cpp index ff22ee8c86fac..a7734e9200a19 100644 --- a/llvm/lib/Target/X86/X86AsmPrinter.cpp +++ b/llvm/lib/Target/X86/X86AsmPrinter.cpp @@ -478,9 +478,9 @@ static bool isIndirectBranchOrTailCall(const MachineInstr &MI) { Opc == X86::TAILJMPr64 || Opc == X86::TAILJMPm64 || Opc == X86::TCRETURNri || Opc == X86::TCRETURN_WIN64ri || Opc == X86::TCRETURN_HIPE32ri || Opc == X86::TCRETURNmi || - Opc == X86::TCRETURNri64 || Opc == X86::TCRETURNmi64 || - Opc == X86::TCRETURNri64_ImpCall || Opc == X86::TAILJMPr64_REX || - Opc == X86::TAILJMPm64_REX; + Opc == X86::TCRETURN_WINmi64 || Opc == X86::TCRETURNri64 || + Opc == X86::TCRETURNmi64 || Opc == X86::TCRETURNri64_ImpCall || + Opc == X86::TAILJMPr64_REX || Opc == X86::TAILJMPm64_REX; } void X86AsmPrinter::emitBasicBlockEnd(const MachineBasicBlock &MBB) { diff --git a/llvm/lib/Target/X86/X86ExpandPseudo.cpp b/llvm/lib/Target/X86/X86ExpandPseudo.cpp index 9457e718de699..4a9b824b0db14 100644 --- a/llvm/lib/Target/X86/X86ExpandPseudo.cpp +++ b/llvm/lib/Target/X86/X86ExpandPseudo.cpp @@ -276,8 +276,10 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB, case X86::TCRETURNdi64cc: case X86::TCRETURNri64: case X86::TCRETURNri64_ImpCall: - case X86::TCRETURNmi64: { - bool isMem = Opcode == X86::TCRETURNmi || Opcode == X86::TCRETURNmi64; + case X86::TCRETURNmi64: + case X86::TCRETURN_WINmi64: { + bool isMem = Opcode == X86::TCRETURNmi || Opcode == X86::TCRETURNmi64 || + Opcode == X86::TCRETURN_WINmi64; MachineOperand &JumpTarget = MBBI->getOperand(0); MachineOperand &StackAdjust = MBBI->getOperand(isMem ? 
X86::AddrNumOperands : 1); @@ -341,7 +343,8 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB, MIB.addImm(MBBI->getOperand(2).getImm()); } - } else if (Opcode == X86::TCRETURNmi || Opcode == X86::TCRETURNmi64) { + } else if (Opcode == X86::TCRETURNmi || Opcode == X86::TCRETURNmi64 || + Opcode == X86::TCRETURN_WINmi64) { unsigned Op = (Opcode == X86::TCRETURNmi) ? X86::TAILJMPm : (IsX64 ? X86::TAILJMPm64_REX : X86::TAILJMPm64); diff --git a/llvm/lib/Target/X86/X86FrameLowering.cpp b/llvm/lib/Target/X86/X86FrameLowering.cpp index a293b4c87cfe4..08c9d738baceb 100644 --- a/llvm/lib/Target/X86/X86FrameLowering.cpp +++ b/llvm/lib/Target/X86/X86FrameLowering.cpp @@ -2402,7 +2402,7 @@ static bool isTailCallOpcode(unsigned Opc) { Opc == X86::TCRETURN_HIPE32ri || Opc == X86::TCRETURNdi || Opc == X86::TCRETURNmi || Opc == X86::TCRETURNri64 || Opc == X86::TCRETURNri64_ImpCall || Opc == X86::TCRETURNdi64 || - Opc == X86::TCRETURNmi64; + Opc == X86::TCRETURNmi64 || Opc == X86::TCRETURN_WINmi64; } void X86FrameLowering::emitEpilogue(MachineFunction &MF, diff --git a/llvm/lib/Target/X86/X86InstrCompiler.td b/llvm/lib/Target/X86/X86InstrCompiler.td index 5a0df058b27f6..af7a33abaf758 100644 --- a/llvm/lib/Target/X86/X86InstrCompiler.td +++ b/llvm/lib/Target/X86/X86InstrCompiler.td @@ -1364,15 +1364,19 @@ def : Pat<(X86tcret ptr_rc_tailcall:$dst, timm:$off), // There wouldn't be enough scratch registers for base+index. 
def : Pat<(X86tcret_6regs (load addr:$dst), timm:$off), (TCRETURNmi64 addr:$dst, timm:$off)>, - Requires<[In64BitMode, NotUseIndirectThunkCalls]>; + Requires<[In64BitMode, IsNotWin64CCFunc, NotUseIndirectThunkCalls]>; + +def : Pat<(X86tcret_6regs (load addr:$dst), timm:$off), + (TCRETURN_WINmi64 addr:$dst, timm:$off)>, + Requires<[IsWin64CCFunc, NotUseIndirectThunkCalls]>; def : Pat<(X86tcret ptr_rc_tailcall:$dst, timm:$off), (INDIRECT_THUNK_TCRETURN64 ptr_rc_tailcall:$dst, timm:$off)>, - Requires<[In64BitMode, UseIndirectThunkCalls]>; + Requires<[In64BitMode, IsNotWin64CCFunc, UseIndirectThunkCalls]>; def : Pat<(X86tcret ptr_rc_tailcall:$dst, timm:$off), (INDIRECT_THUNK_TCRETURN32 ptr_rc_tailcall:$dst, timm:$off)>, - Requires<[Not64BitMode, UseIndirectThunkCalls]>; + Requires<[Not64BitMode, IsNotWin64CCFunc, UseIndirectThunkCalls]>; def : Pat<(X86tcret (i64 tglobaladdr:$dst), timm:$off), (TCRETURNdi64 tglobaladdr:$dst, timm:$off)>, @@ -2215,7 +2219,7 @@ let Predicates = [HasZU] in { def : Pat<(i64 (zext (i16 (mul (loadi16 addr:$src1), imm:$src2)))), (SUBREG_TO_REG (i64 0), (IMULZU16rmi addr:$src1, imm:$src2), sub_16bit)>; } - + // mul reg, imm def : Pat<(mul GR16:$src1, imm:$src2), (IMUL16rri GR16:$src1, imm:$src2)>; diff --git a/llvm/lib/Target/X86/X86InstrControl.td b/llvm/lib/Target/X86/X86InstrControl.td index 139aedd473ebc..d962bfff1444d 100644 --- a/llvm/lib/Target/X86/X86InstrControl.td +++ b/llvm/lib/Target/X86/X86InstrControl.td @@ -372,6 +372,9 @@ let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, def TCRETURNmi64 : PseudoI<(outs), (ins i64mem_TC:$dst, i32imm:$offset), []>, Sched<[WriteJumpLd]>; + def TCRETURN_WINmi64 : PseudoI<(outs), + (ins i64mem_w64TC:$dst, i32imm:$offset), + []>, Sched<[WriteJumpLd]>; def TAILJMPd64 : PseudoI<(outs), (ins i64i32imm_brtarget:$dst), []>, Sched<[WriteJump]>; diff --git a/llvm/lib/Target/X86/X86InstrOperands.td b/llvm/lib/Target/X86/X86InstrOperands.td index 53a6b7c4c4c92..80843f6bb80e6 100644 --- 
a/llvm/lib/Target/X86/X86InstrOperands.td +++ b/llvm/lib/Target/X86/X86InstrOperands.td @@ -141,6 +141,11 @@ def i64mem_TC : X86MemOperand<"printqwordmem", X86Mem64AsmOperand, 64> { ptr_rc_tailcall, i32imm, SEGMENT_REG); } +def i64mem_w64TC : X86MemOperand<"printqwordmem", X86Mem64AsmOperand, 64> { + let MIOperandInfo = (ops GR64_TCW64, i8imm, + GR64_TCW64, i32imm, SEGMENT_REG); +} + // Special parser to detect 16-bit mode to select 16-bit displacement. def X86AbsMemMode16AsmOperand : AsmOperandClass { let Name = "AbsMemMode16"; diff --git a/llvm/lib/Target/X86/X86RegisterInfo.cpp b/llvm/lib/Target/X86/X86RegisterInfo.cpp index 9ec04e740a08b..7963dc1b755c9 100644 --- a/llvm/lib/Target/X86/X86RegisterInfo.cpp +++ b/llvm/lib/Target/X86/X86RegisterInfo.cpp @@ -1010,6 +1010,7 @@ unsigned X86RegisterInfo::findDeadCallerSavedReg( case X86::TCRETURNri64: case X86::TCRETURNri64_ImpCall: case X86::TCRETURNmi64: + case X86::TCRETURN_WINmi64: case X86::EH_RETURN: case X86::EH_RETURN64: { LiveRegUnits LRU(*this); diff --git a/llvm/lib/Target/X86/X86SpeculativeLoadHardening.cpp b/llvm/lib/Target/X86/X86SpeculativeLoadHardening.cpp index 4cc456ece77e0..c28de14a97874 100644 --- a/llvm/lib/Target/X86/X86SpeculativeLoadHardening.cpp +++ b/llvm/lib/Target/X86/X86SpeculativeLoadHardening.cpp @@ -893,6 +893,7 @@ void X86SpeculativeLoadHardeningPass::unfoldCallAndJumpLoads( case X86::TAILJMPm64_REX: case X86::TAILJMPm: case X86::TCRETURNmi64: + case X86::TCRETURN_WINmi64: case X86::TCRETURNmi: { // Use the generic unfold logic now that we know we're dealing with // expected instructions. 
diff --git a/llvm/test/CodeGen/X86/win64-tailcall-memory.ll b/llvm/test/CodeGen/X86/win64-tailcall-memory.ll new file mode 100644 index 0000000000000..568f4fe04fea9 --- /dev/null +++ b/llvm/test/CodeGen/X86/win64-tailcall-memory.ll @@ -0,0 +1,48 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; RUN: llc -mtriple=x86_64-unknown-windows-gnu < %s | FileCheck %s + +; Check that a win64 tail call through a pointer loaded from memory addresses +; the load with a register the epilogue does not restore (here %rax, not %rsi). + +declare void @foo(i64, ptr) + +define void @do_tailcall(ptr %objp) nounwind { +; CHECK-LABEL: do_tailcall: +; CHECK: # %bb.0: +; CHECK-NEXT: pushq %rsi +; CHECK-NEXT: subq $32, %rsp +; CHECK-NEXT: movq %rcx, %rsi +; CHECK-NEXT: xorl %ecx, %ecx +; CHECK-NEXT: xorl %edx, %edx +; CHECK-NEXT: callq foo +; CHECK-NEXT: xorl %ecx, %ecx +; CHECK-NEXT: movq %rsi, %rax +; CHECK-NEXT: addq $32, %rsp +; CHECK-NEXT: popq %rsi +; CHECK-NEXT: rex64 jmpq *(%rax) # TAILCALL + tail call void @foo(i64 0, ptr null) + %fptr = load ptr, ptr %objp, align 8 + tail call void %fptr(ptr null) + ret void +} + +; Make sure conventions lowered like ccc (e.g. fastcc) are also treated as win64. +define fastcc void @do_tailcall_fastcc(ptr %objp) nounwind { +; CHECK-LABEL: do_tailcall_fastcc: +; CHECK: # %bb.0: +; CHECK-NEXT: pushq %rsi +; CHECK-NEXT: subq $32, %rsp +; CHECK-NEXT: movq %rcx, %rsi +; CHECK-NEXT: xorl %ecx, %ecx +; CHECK-NEXT: xorl %edx, %edx +; CHECK-NEXT: callq foo +; CHECK-NEXT: xorl %ecx, %ecx +; CHECK-NEXT: movq %rsi, %rax +; CHECK-NEXT: addq $32, %rsp +; CHECK-NEXT: popq %rsi +; CHECK-NEXT: rex64 jmpq *(%rax) # TAILCALL + tail call void @foo(i64 0, ptr null) + %fptr = load ptr, ptr %objp, align 8 + tail call fastcc void %fptr(ptr null) + ret void +}