From 137d3890a97753f56742556613090792e1569193 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Thu, 10 Jun 2021 13:43:07 -0500 Subject: [PATCH] [LLVM] Fix CodeGenLLVM::LinkParameters (#8213) - Generate valid LLVM IR. - Set proper alignment on the constant variables. --- src/target/llvm/codegen_llvm.cc | 76 +++++++++++++-------------------- 1 file changed, 29 insertions(+), 47 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index d5140677d45a..48ccefafe3c4 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -62,7 +62,7 @@ void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm, md_builder_.reset(new llvm::MDBuilder(*ctx_)); // types t_void_ = llvm::Type::getVoidTy(*ctx_); - t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo(); + t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo(GetGlobalAddressSpace()); t_int_ = llvm::Type::getInt32Ty(*ctx_); t_char_ = llvm::Type::getInt8Ty(*ctx_); t_int8_ = llvm::Type::getInt8Ty(*ctx_); @@ -191,20 +191,10 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { void CodeGenLLVM::LinkParameters(const Map params) { // It would be nice to de-dupe these declarations frm src/tir/transforms/make_packed_api.cc, // but they are at a different layer in the compiler... - std::vector param_types; - // args - param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace())); - // tcodes - param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace())); - // num_args - param_types.push_back(t_int_); - // ret_args - param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace())); - // ret_tcodes - param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace())); - // resource_handle - param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace())); + llvm::Type* t_int_p = t_int_->getPointerTo(GetGlobalAddressSpace()); + // args, tcodes, num_args, ret_value, ret_tcode, resource_handle + std::vector param_types{t_void_p_, t_int_p, t_int_, t_void_p_, t_int_p, t_void_p_}; llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, param_types, false); llvm::Function* function = @@ -215,41 +205,29 @@ void CodeGenLLVM::LinkParameters(const Map params) { llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function); builder_->SetInsertPoint(entry); - std::vector zero_index_list{llvm::ConstantInt::get(t_int32_, 0)}; - std::vector zero_array_index_list{llvm::ConstantInt::get(t_int32_, 0), - llvm::ConstantInt::get(t_int32_, 0)}; - auto args_array = builder_->CreateBitCast( -#if TVM_LLVM_VERSION >= 50 - &function->arg_begin()[0], + + auto getArg = [function](int i) -> llvm::Argument* { +#if TVM_LLVM_VERSION >= 100 + return function->getArg(i); +#elif TVM_LLVM_VERSION >= 50 + return &function->arg_begin()[i]; #else - &(*(function->arg_begin())), + return &*std::next(function->arg_begin(), i); #endif - llvm::ArrayType::get(t_void_->getPointerTo(GetGlobalAddressSpace()), 1)); - llvm::Value* sid = builder_->CreateBitCast( - builder_->CreateLoad(t_void_->getPointerTo(GetGlobalAddressSpace()), - builder_->CreateInBoundsGEP(args_array, zero_index_list)), - t_int64_); + }; + + llvm::Type* t_int64_p = t_int64_->getPointerTo(GetGlobalAddressSpace()); + llvm::Value* sid = builder_->CreateLoad(t_int64_, builder_->CreateBitCast(getArg(0), t_int64_p)); + + auto ret_tcode = builder_->CreateBitCast(getArg(4), t_int_p); + auto ret_value = + builder_->CreateBitCast(getArg(3), t_void_p_->getPointerTo(GetGlobalAddressSpace())); llvm::BasicBlock* default_block = llvm::BasicBlock::Create(*ctx_, "default_block", function); - auto ret_types_array = builder_->CreateBitCast( -#if TVM_LLVM_VERSION >= 50 - &function->arg_begin()[4], -#else - &(*(std::next(function->arg_begin(), 4))), -#endif - llvm::ArrayType::get(t_int_, 1)->getPointerTo()); - auto retval_array = builder_->CreateBitCast( -#if TVM_LLVM_VERSION >= 50 - &function->arg_begin()[3], -#else - &(*std::next(function->arg_begin(), 3)), -#endif - llvm::ArrayType::get(t_void_->getPointerTo(GetGlobalAddressSpace()), 1)->getPointerTo()); llvm::SwitchInst* switch_inst = builder_->CreateSwitch(sid, default_block, params.size() + 1); builder_->SetInsertPoint(default_block); - builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMNullptr), - builder_->CreateInBoundsGEP(ret_types_array, zero_array_index_list)); + builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMNullptr), ret_tcode); builder_->CreateRet(ConstInt32(kTvmErrorNoError)); // Add data to the global section. @@ -258,16 +236,20 @@ void CodeGenLLVM::LinkParameters(const Map params) { std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + kv.first; llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable( *module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name); + auto dtype = tvm::runtime::DataType(kv.second->param->dtype); + size_t align = std::max(tvm::runtime::GetVectorBytes(dtype), tvm::runtime::kAllocAlignment); +#if TVM_LLVM_VERSION >= 100 + param_symbol->setAlignment(llvm::Align(align)); +#else + param_symbol->setAlignment(align); +#endif llvm::BasicBlock* case_block = llvm::BasicBlock::Create(*ctx_, "case_" + symbol_name, function); switch_inst->addCase( llvm::cast(llvm::ConstantInt::get(t_int64_, kv.second->id)), case_block); builder_->SetInsertPoint(case_block); - builder_->CreateStore( - builder_->CreatePointerCast(param_symbol, t_void_->getPointerTo(GetGlobalAddressSpace())), - builder_->CreateInBoundsGEP(retval_array, zero_array_index_list)); - builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMOpaqueHandle), - builder_->CreateInBoundsGEP(ret_types_array, zero_array_index_list)); + builder_->CreateStore(builder_->CreatePointerCast(param_symbol, t_void_p_), ret_value); + builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMOpaqueHandle), ret_tcode); builder_->CreateRet(ConstInt32(0)); } }