diff --git a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h index fe9a4bada2430..db4d9edb152ce 100644 --- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h +++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h @@ -435,6 +435,18 @@ inline bool isCallIndirect(unsigned Opc) { } } +inline bool isCallRef(unsigned Opc) { + switch (Opc) { + case WebAssembly::CALL_REF: + case WebAssembly::CALL_REF_S: + case WebAssembly::RET_CALL_REF: + case WebAssembly::RET_CALL_REF_S: + return true; + default: + return false; + } +} + inline bool isBrTable(unsigned Opc) { switch (Opc) { case WebAssembly::BR_TABLE_I32: diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp index 2541b0433ab59..03c90c7160a68 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp @@ -120,60 +120,6 @@ static SDValue getTagSymNode(int Tag, SelectionDAG *DAG) { return DAG->getTargetExternalSymbol(SymName, PtrVT); } -static APInt encodeFunctionSignature(SelectionDAG *DAG, SDLoc &DL, - SmallVector &Returns, - SmallVector &Params) { - auto toWasmValType = [](MVT VT) { - if (VT == MVT::i32) { - return wasm::ValType::I32; - } - if (VT == MVT::i64) { - return wasm::ValType::I64; - } - if (VT == MVT::f32) { - return wasm::ValType::F32; - } - if (VT == MVT::f64) { - return wasm::ValType::F64; - } - if (VT == MVT::externref) { - return wasm::ValType::EXTERNREF; - } - if (VT == MVT::funcref) { - return wasm::ValType::FUNCREF; - } - if (VT == MVT::exnref) { - return wasm::ValType::EXNREF; - } - LLVM_DEBUG(errs() << "Unhandled type for llvm.wasm.ref.test.func: " << VT - << "\n"); - llvm_unreachable("Unhandled type for llvm.wasm.ref.test.func"); - }; - auto NParams = Params.size(); - auto NReturns = Returns.size(); - auto BitWidth = (NParams + NReturns + 2) * 64; - auto Sig = APInt(BitWidth, 0); - - // Annoying special case: if getSignificantBits() <= 64 then InstrEmitter will - // emit an Imm instead of a CImm. It simplifies WebAssemblyMCInstLower if we - // always emit a CImm. So xor NParams with 0x7ffffff to ensure - // getSignificantBits() > 64 - Sig |= NReturns ^ 0x7ffffff; - for (auto &Return : Returns) { - auto V = toWasmValType(Return); - Sig <<= 64; - Sig |= (int64_t)V; - } - Sig <<= 64; - Sig |= NParams; - for (auto &Param : Params) { - auto V = toWasmValType(Param); - Sig <<= 64; - Sig |= (int64_t)V; - } - return Sig; -} - void WebAssemblyDAGToDAGISel::Select(SDNode *Node) { // If we have a custom node, we already have selected! if (Node->isMachineOpcode()) { @@ -288,7 +234,8 @@ void WebAssemblyDAGToDAGISel::Select(SDNode *Node) { Returns.push_back(VT); } } - auto Sig = encodeFunctionSignature(CurDAG, DL, Returns, Params); + auto Sig = + WebAssembly::encodeFunctionSignature(CurDAG, DL, Returns, Params); auto SigOp = CurDAG->getTargetConstant( Sig, DL, EVT::getIntegerVT(*CurDAG->getContext(), Sig.getBitWidth())); diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp index 163bf9ba5b089..bd0733c73f7ed 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -723,6 +723,7 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB, bool IsIndirect = CallParams.getOperand(0).isReg() || CallParams.getOperand(0).isFI(); bool IsRetCall = CallResults.getOpcode() == WebAssembly::RET_CALL_RESULTS; + bool IsCallRef = false; bool IsFuncrefCall = false; if (IsIndirect && CallParams.getOperand(0).isReg()) { @@ -732,10 +733,19 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB, const TargetRegisterClass *TRC = MRI.getRegClass(Reg); IsFuncrefCall = (TRC == &WebAssembly::FUNCREFRegClass); assert(!IsFuncrefCall || Subtarget->hasReferenceTypes()); + + if (IsFuncrefCall && Subtarget->hasGC()) { + IsIndirect = false; + IsCallRef = true; + } } unsigned CallOp; - if (IsIndirect && IsRetCall) { + if (IsCallRef && IsRetCall) { + CallOp = WebAssembly::RET_CALL_REF; + } else if (IsCallRef) { + CallOp = WebAssembly::CALL_REF; + } else if (IsIndirect && IsRetCall) { CallOp = WebAssembly::RET_CALL_INDIRECT; } else if (IsIndirect) { CallOp = WebAssembly::CALL_INDIRECT; @@ -771,6 +781,14 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB, CallParams.addOperand(FnPtr); } + // Move the function pointer to the end of the arguments for funcref calls + if (IsCallRef) { + auto FnRef = CallParams.getOperand(0); + CallParams.removeOperand(0); + + CallParams.addOperand(FnRef); + } + for (auto Def : CallResults.defs()) MIB.add(Def); @@ -795,6 +813,12 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB, } } + if (IsCallRef) { + // Placeholder for the type index. + // This gets replaced with the correct value in WebAssemblyMCInstLower.cpp + MIB.addImm(0); + } + for (auto Use : CallParams.uses()) MIB.add(Use); @@ -1173,6 +1197,60 @@ static bool callingConvSupported(CallingConv::ID CallConv) { CallConv == CallingConv::Swift; } +APInt WebAssembly::encodeFunctionSignature(SelectionDAG *DAG, SDLoc &DL, + SmallVector &Returns, + SmallVector &Params) { + auto toWasmValType = [](MVT VT) { + if (VT == MVT::i32) { + return wasm::ValType::I32; + } + if (VT == MVT::i64) { + return wasm::ValType::I64; + } + if (VT == MVT::f32) { + return wasm::ValType::F32; + } + if (VT == MVT::f64) { + return wasm::ValType::F64; + } + if (VT == MVT::externref) { + return wasm::ValType::EXTERNREF; + } + if (VT == MVT::funcref) { + return wasm::ValType::FUNCREF; + } + if (VT == MVT::exnref) { + return wasm::ValType::EXNREF; + } + LLVM_DEBUG(errs() << "Unhandled type for llvm.wasm.ref.test.func: " << VT + << "\n"); + llvm_unreachable("Unhandled type for llvm.wasm.ref.test.func"); + }; + auto NParams = Params.size(); + auto NReturns = Returns.size(); + auto BitWidth = (NParams + NReturns + 2) * 64; + auto Sig = APInt(BitWidth, 0); + + // Annoying special case: if getSignificantBits() <= 64 then InstrEmitter will + // emit an Imm instead of a CImm. It simplifies WebAssemblyMCInstLower if we + // always emit a CImm. So xor NParams with 0x7ffffff to ensure + // getSignificantBits() > 64 + Sig |= NReturns ^ 0x7ffffff; + for (auto &Return : Returns) { + auto V = toWasmValType(Return); + Sig <<= 64; + Sig |= (int64_t)V; + } + Sig <<= 64; + Sig |= NParams; + for (auto &Param : Params) { + auto V = toWasmValType(Param); + Sig <<= 64; + Sig |= (int64_t)V; + } + return Sig; +} + SDValue WebAssemblyTargetLowering::LowerCall(CallLoweringInfo &CLI, SmallVectorImpl &InVals) const { @@ -1412,33 +1490,58 @@ WebAssemblyTargetLowering::LowerCall(CallLoweringInfo &CLI, InTys.push_back(In.VT); } - // Lastly, if this is a call to a funcref we need to add an instruction - // table.set to the chain and transform the call. + // Lastly, if this is a call to a funcref we need to insert an instruction + // to either cast the funcref to a typed funcref for call_ref, or place it + // into a table for call_indirect if (CLI.CB && WebAssembly::isWebAssemblyFuncrefType( CLI.CB->getCalledOperand()->getType())) { - // In the absence of function references proposal where a funcref call is - // lowered to call_ref, using reference types we generate a table.set to set - // the funcref to a special table used solely for this purpose, followed by - // a call_indirect. Here we just generate the table set, and return the - // SDValue of the table.set so that LowerCall can finalize the lowering by - // generating the call_indirect. - SDValue Chain = Ops[0]; + if (Subtarget->hasGC()) { + // Since LLVM doesn't directly support typed function references, we take + // the untyped funcref and ref.cast it into a typed funcref. + SmallVector Params; + SmallVector Returns; + + for (const auto &Out : Outs) { + Params.push_back(Out.VT); + } + for (const auto &In : Ins) { + Returns.push_back(In.VT); + } - MCSymbolWasm *Table = WebAssembly::getOrCreateFuncrefCallTableSymbol( - MF.getContext(), Subtarget); - SDValue Sym = DAG.getMCSymbol(Table, PtrVT); - SDValue TableSlot = DAG.getConstant(0, DL, MVT::i32); - SDValue TableSetOps[] = {Chain, Sym, TableSlot, Callee}; - SDValue TableSet = DAG.getMemIntrinsicNode( - WebAssemblyISD::TABLE_SET, DL, DAG.getVTList(MVT::Other), TableSetOps, - MVT::funcref, - // Machine Mem Operand args - MachinePointerInfo( - WebAssembly::WasmAddressSpace::WASM_ADDRESS_SPACE_FUNCREF), - CLI.CB->getCalledOperand()->getPointerAlignment(DAG.getDataLayout()), - MachineMemOperand::MOStore); - - Ops[0] = TableSet; // The new chain is the TableSet itself + auto Sig = + WebAssembly::encodeFunctionSignature(&DAG, DL, Returns, Params); + + auto SigOp = DAG.getTargetConstant( + Sig, DL, EVT::getIntegerVT(*DAG.getContext(), Sig.getBitWidth())); + MachineSDNode *RefCastNode = DAG.getMachineNode( + WebAssembly::REF_CAST_FUNCREF, DL, MVT::funcref, {SigOp, Callee}); + + Ops[1] = SDValue(RefCastNode, 0); + } else { + // In the absence of function references proposal where a funcref call is + // lowered to call_ref, using reference types we generate a table.set to + // set the funcref to a special table used solely for this purpose, + // followed by a call_indirect. Here we just generate the table set, and + // return the SDValue of the table.set so that LowerCall can finalize the + // lowering by generating the call_indirect. + SDValue Chain = Ops[0]; + + MCSymbolWasm *Table = WebAssembly::getOrCreateFuncrefCallTableSymbol( + MF.getContext(), Subtarget); + SDValue Sym = DAG.getMCSymbol(Table, PtrVT); + SDValue TableSlot = DAG.getConstant(0, DL, MVT::i32); + SDValue TableSetOps[] = {Chain, Sym, TableSlot, Callee}; + SDValue TableSet = DAG.getMemIntrinsicNode( + WebAssemblyISD::TABLE_SET, DL, DAG.getVTList(MVT::Other), TableSetOps, + MVT::funcref, + // Machine Mem Operand args + MachinePointerInfo( + WebAssembly::WasmAddressSpace::WASM_ADDRESS_SPACE_FUNCREF), + CLI.CB->getCalledOperand()->getPointerAlignment(DAG.getDataLayout()), + MachineMemOperand::MOStore); + + Ops[0] = TableSet; // The new chain is the TableSet itself + } } if (CLI.IsTailCall) { diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h index b33a8530310be..7d2194132f293 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h @@ -141,6 +141,11 @@ class WebAssemblyTargetLowering final : public TargetLowering { namespace WebAssembly { FastISel *createFastISel(FunctionLoweringInfo &funcInfo, const TargetLibraryInfo *libInfo); + +APInt encodeFunctionSignature(SelectionDAG *DAG, SDLoc &DL, + SmallVector &Returns, + SmallVector &Params); + } // end namespace WebAssembly } // end namespace llvm diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td index ca9a5ef9dda1c..81b62f6a682ec 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td @@ -66,6 +66,16 @@ defm CALL_INDIRECT : [], "call_indirect", "call_indirect\t$type, $table", 0x11>; +let variadicOpsAreDefs = 1 in +defm CALL_REF : + I<(outs), + (ins TypeIndex:$type, variable_ops), + (outs), + (ins TypeIndex:$type), + [], + "call_ref", "call_ref\t$type", 0x14>, + Requires<[HasGC]>; + let isReturn = 1, isTerminator = 1, hasCtrlDep = 1, isBarrier = 1 in defm RET_CALL : I<(outs), (ins function32_op:$callee, variable_ops), @@ -81,4 +91,14 @@ defm RET_CALL_INDIRECT : 0x13>, Requires<[HasTailCall]>; +let isReturn = 1, isTerminator = 1, hasCtrlDep = 1, isBarrier = 1 in +defm RET_CALL_REF : + I<(outs), + (ins TypeIndex:$type, variable_ops), + (outs), + (ins TypeIndex:$type), + [], + "return_call_ref", "return_call_ref\t$type", 0x15>, + Requires<[HasTailCall, HasGC]>; + } // Uses = [SP32,SP64], isCall = 1 diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td index fc82e5b4a61da..6fa6ed897d647 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td @@ -41,6 +41,11 @@ defm REF_TEST_FUNCREF : I<(outs I32:$res), (ins TypeIndex:$type, FUNCREF:$ref), "ref.test\t$type, $ref", "ref.test $type", 0xfb14>, Requires<[HasGC]>; +defm REF_CAST_FUNCREF : I<(outs FUNCREF:$res), (ins TypeIndex:$type, FUNCREF:$ref), + (outs), (ins TypeIndex:$type), [], + "ref.cast\t$type, $ref", "ref.cast $type", 0xfb16>, + Requires<[HasGC]>; + defm "" : REF_I; defm "" : REF_I; defm "" : REF_I; diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp index e48283aadb437..1ed15967c01fe 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp @@ -230,7 +230,7 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI, break; } case llvm::MachineOperand::MO_CImmediate: { - // Lower type index placeholder for ref.test + // Lower type index placeholder for ref.test and ref.cast // Currently this is the only way that CImmediates show up so panic if we // get confused. unsigned DescIndex = I - NumVariadicDefs; @@ -266,14 +266,16 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI, Params.push_back(WebAssembly::regClassToValType( MRI.getRegClass(MO.getReg())->getID())); - // call_indirect instructions have a callee operand at the end which - // doesn't count as a param. - if (WebAssembly::isCallIndirect(MI->getOpcode())) + // call_indirect and call_ref instructions have a callee operand at + // the end which doesn't count as a param. + if (WebAssembly::isCallIndirect(MI->getOpcode()) || + WebAssembly::isCallRef(MI->getOpcode())) Params.pop_back(); - // return_call_indirect instructions have the return type of the - // caller - if (MI->getOpcode() == WebAssembly::RET_CALL_INDIRECT) + // return_call_indirect and return_call_ref instructions have the + // return type of the caller + if (MI->getOpcode() == WebAssembly::RET_CALL_INDIRECT || + MI->getOpcode() == WebAssembly::RET_CALL_REF) getFunctionReturns(MI, Returns); MCOp = lowerTypeIndexOperand(std::move(Returns), std::move(Params)); diff --git a/llvm/test/CodeGen/WebAssembly/call-ref.ll b/llvm/test/CodeGen/WebAssembly/call-ref.ll new file mode 100644 index 0000000000000..25fc7440ac64c --- /dev/null +++ b/llvm/test/CodeGen/WebAssembly/call-ref.ll @@ -0,0 +1,51 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; RUN: llc < %s -mattr=+reference-types,-gc | FileCheck --check-prefixes=CHECK,NOGC %s +; RUN: llc < %s -mattr=+reference-types,+gc | FileCheck --check-prefixes=CHECK,GC %s + +; Test that calls through funcref lower to call_ref when GC is available + +target triple = "wasm32-unknown-unknown" + +%funcref = type ptr addrspace(20); + +define void @call_ref_void(%funcref %callee) { +; CHECK-LABEL: call_ref_void: +; CHECK: .functype call_ref_void (funcref) -> () +; CHECK-NEXT: # %bb.0: +; NOGC-NEXT: i32.const 0 +; CHECK-NEXT: local.get 0 +; NOGC-NEXT: table.set __funcref_call_table +; NOGC-NEXT: i32.const 0 +; NOGC-NEXT: call_indirect __funcref_call_table, () -> () +; NOGC-NEXT: i32.const 0 +; NOGC-NEXT: ref.null_func +; NOGC-NEXT: table.set __funcref_call_table +; GC-NEXT: ref.cast () -> () +; GC-NEXT: call_ref () -> () +; CHECK-NEXT: # fallthrough-return + call addrspace(20) void %callee() + ret void +} + +define void @call_ref_with_args_and_ret(%funcref %callee) { +; CHECK-LABEL: call_ref_with_args_and_ret: +; CHECK: .functype call_ref_with_args_and_ret (funcref) -> () +; CHECK-NEXT: # %bb.0: +; NOGC-NEXT: i32.const 0 +; NOGC-NEXT: local.get 0 +; NOGC-NEXT: table.set __funcref_call_table +; CHECK-NEXT: i32.const 1 +; CHECK-NEXT: f64.const 0x1p1 +; NOGC-NEXT: i32.const 0 +; NOGC-NEXT: call_indirect __funcref_call_table, (i32, f64) -> (i32) +; GC-NEXT: local.get 0 +; GC-NEXT: ref.cast (i32, f64) -> (i32) +; GC-NEXT: call_ref (i32, f64) -> (i32) +; CHECK-NEXT: drop +; NOGC-NEXT: i32.const 0 +; NOGC-NEXT: ref.null_func +; NOGC-NEXT: table.set __funcref_call_table +; CHECK-NEXT: # fallthrough-return + %result = call addrspace(20) i32 %callee(i32 1, double 2.0) + ret void +}