diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index bb0b7e46baf8a..7112691de1bcb 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -684,6 +684,74 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const CallNode* op) { return call; } +llvm::Function* CodeGenLLVM::GetIntrinsicDecl( + llvm::Intrinsic::ID id, llvm::Type* ret_type, + llvm::ArrayRef arg_types) { + llvm::Module* module = module_.get(); + + if (!llvm::Intrinsic::isOverloaded(id)) { + return llvm::Intrinsic::getDeclaration(module, id, {}); + } + + llvm::SmallVector infos; + llvm::Intrinsic::getIntrinsicInfoTableEntries(id, infos); + llvm::SmallVector overload_types; + +#if TVM_LLVM_VERSION >= 90 + auto try_match = [&](llvm::FunctionType* f_ty, bool var_arg) { + overload_types.clear(); + llvm::ArrayRef ref(infos); + auto match = + llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types); + if (match == llvm::Intrinsic::MatchIntrinsicTypes_Match) { + bool error = llvm::Intrinsic::matchIntrinsicVarArg(var_arg, ref); + if (error) { + return llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg; + } + } + return match; + }; + + // First, try matching the signature assuming non-vararg case. + auto* fn_ty = llvm::FunctionType::get(ret_type, arg_types, false); + switch (try_match(fn_ty, false)) { + case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchRet: + // The return type doesn't match, there is nothing else to do. + return nullptr; + case llvm::Intrinsic::MatchIntrinsicTypes_Match: + return llvm::Intrinsic::getDeclaration(module, id, overload_types); + case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg: + break; + } + + // Keep adding one type at a time (starting from empty list), and + // try matching the vararg signature. + llvm::SmallVector var_types; + for (int i = 0, e = arg_types.size(); i <= e; ++i) { + if (i > 0) var_types.push_back(arg_types[i - 1]); + auto* ft = llvm::FunctionType::get(ret_type, var_types, true); + if (try_match(ft, true) == llvm::Intrinsic::MatchIntrinsicTypes_Match) { + return llvm::Intrinsic::getDeclaration(module, id, overload_types); + } + } + // Failed to identify the type. + return nullptr; + +#else // TVM_LLVM_VERSION + llvm::ArrayRef ref(infos); + // matchIntrinsicType returns true on error. + if (llvm::Intrinsic::matchIntrinsicType(ret_type, ref, overload_types)) { + return nullptr; + } + for (llvm::Type* t : arg_types) { + if (llvm::Intrinsic::matchIntrinsicType(t, ref, overload_types)) { + return nullptr; + } + } + return llvm::Intrinsic::getDeclaration(module, id, overload_types); +#endif // TVM_LLVM_VERSION +} + llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { if (op->is_intrinsic("llvm_intrin")) { CHECK_GE(op->args.size(), 2U); @@ -691,19 +759,27 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { Downcast(op->args[0])->value); int64_t num_signature = Downcast(op->args[1])->value; std::vector arg_value; - std::vector sig_type; + std::vector arg_type; for (size_t i = 2; i < op->args.size(); ++i) { arg_value.push_back(MakeValue(op->args[i])); if (i - 2 < static_cast(num_signature)) { - sig_type.push_back(arg_value.back()->getType()); + arg_type.push_back(arg_value.back()->getType()); } } - llvm::Type *return_type = GetLLVMType(GetRef(op)); - if (sig_type.size() > 0 && return_type != sig_type[0]) { - sig_type.insert(sig_type.begin(), return_type); - } - llvm::Function* f = llvm::Intrinsic::getDeclaration( - module_.get(), id, sig_type); + // LLVM's prefetch intrinsic returns "void", while TVM's prefetch + // returns int32. This causes problems because prefetch is one of + // those intrinsics that is generated automatically via the + // tvm.intrin.rule mechanism. Any other intrinsic with a type + // mismatch will have to be treated specially here. + // TODO(kparzysz-quic): fix this once TVM prefetch uses the same + // type as LLVM. + llvm::Type *return_type = (id != llvm::Intrinsic::prefetch) + ? GetLLVMType(GetRef(op)) + : llvm::Type::getVoidTy(*ctx_); + + llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type); + CHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: " + << llvm::Intrinsic::getName(id, {}); return builder_->CreateCall(f, arg_value); } else if (op->is_intrinsic(CallNode::bitwise_and)) { return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1])); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 6249aa4f74bc5..e785f3eab2753 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -232,6 +232,21 @@ class CodeGenLLVM : * \param type The corresponding TVM Type. */ llvm::Type* GetLLVMType(const PrimExpr& expr) const; + /*! + * \brief Get the declaration of the LLVM intrinsic based on the intrinsic + * id, and the type of the actual call. + * + * \param id The intrinsic id. + * \param ret_type The call return type. + * \param arg_types The types of the call arguments. + * + * \return Return the llvm::Function pointer, or nullptr if the declaration + * could not be generated (e.g. if the argument/return types do not + * match). + */ + llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id, + llvm::Type* ret_type, + llvm::ArrayRef arg_types); // initialize the function state. void InitFuncState(); // Get alignment given index. diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 880a0fe580008..58bfb371c577d 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -30,7 +30,7 @@ namespace codegen { namespace llvm { TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch") -.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 0>); +.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 4>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); @@ -53,7 +53,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10") }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 1>); +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); @@ -109,7 +109,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh") }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 1>); +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 34135c6ef7ee4..3de1d1679e70e 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -40,6 +40,30 @@ def test_llvm_intrin(): fcode = tvm.build(func, None, "llvm") +def test_llvm_overloaded_intrin(): + # Name lookup for overloaded intrinsics in LLVM 4- requires a name + # that includes the overloaded types. + if tvm.target.codegen.llvm_version_major() < 5: + return + + def use_llvm_intrinsic(A, C): + ib = tvm.tir.ir_builder.create() + L = A.vload((0,0)) + I = tvm.tir.call_llvm_intrin('int32', 'llvm.ctlz', + tvm.tir.const(2, 'uint32'), L, tvm.tir.const(0, 'int1')) + S = C.vstore((0,0), I) + ib.emit(S) + return ib.get() + + A = tvm.te.placeholder((1,1), dtype = 'int32', name = 'A') + C = tvm.te.extern((1,1), [A], + lambda ins, outs: use_llvm_intrinsic(ins[0], outs[0]), + name = 'C' , dtype = 'int32') + + s = tvm.te.create_schedule(C.op) + f = tvm.build(s, [A, C], target = 'llvm') + + def test_llvm_import(): # extern "C" is necessary to get the correct signature cc_code = """ @@ -82,9 +106,9 @@ def check_llvm(use_file): def test_llvm_lookup_intrin(): ib = tvm.tir.ir_builder.create() - m = te.size_var("m") A = ib.pointer("uint8x8", name="A") - x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.tir.const(1, 'uint32'), A) + z = tvm.tir.const(0, 'int32') + x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, 'uint32'), A[z]) ib.emit(x) body = ib.get() func = tvm.testing.MakeAPILegacy(body, "ctpop", [A], 1, True) @@ -680,6 +704,7 @@ def vectorizer(op): test_llvm_vadd_pipeline() test_llvm_add_pipeline() test_llvm_intrin() + test_llvm_overloaded_intrin() test_llvm_flip_pipeline() test_llvm_madd_pipeline() test_llvm_temp_space() diff --git a/topi/python/topi/arm_cpu/bitserial_conv2d.py b/topi/python/topi/arm_cpu/bitserial_conv2d.py index bdda496f8fb81..b7da66f9168f1 100644 --- a/topi/python/topi/arm_cpu/bitserial_conv2d.py +++ b/topi/python/topi/arm_cpu/bitserial_conv2d.py @@ -197,7 +197,6 @@ def _intrin_func(ins, outs): ww, xx = ins zz = outs[0] - args_1 = tvm.tir.const(1, 'uint32') args_2 = tvm.tir.const(2, 'uint32') if unipolar: @@ -237,10 +236,10 @@ def _instr(index): cnts8[i] = upper_half + lower_half for i in range(m//2): cnts4[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd, - args_1, cnts8[i*2], cnts8[i*2+1]) + args_2, cnts8[i*2], cnts8[i*2+1]) for i in range(m//4): cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd, - args_1, cnts4[i*2], cnts4[i*2+1]) + args_2, cnts4[i*2], cnts4[i*2+1]) cnts = tvm.tir.call_pure_intrin( full_dtype, 'vectorcombine', cnts2[0], cnts2[1]) shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype) @@ -257,10 +256,10 @@ def _instr(index): cnts8[i] = tvm.tir.popcount(w_ & x_) for i in range(m//2): cnts4[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd, - args_1, cnts8[i*2], cnts8[i*2+1]) + args_2, cnts8[i*2], cnts8[i*2+1]) for i in range(m//4): cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd, - args_1, cnts4[i*2], cnts4[i*2+1]) + args_2, cnts4[i*2], cnts4[i*2+1]) cnts = tvm.tir.call_pure_intrin( full_dtype, 'vectorcombine', cnts2[0], cnts2[1]) shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype)