Skip to content

Commit

Permalink
[LLVM] Fix CodeGenLLVM::LinkParameters (apache#8213)
Browse files Browse the repository at this point in the history
- Generate valid LLVM IR.
- Set proper alignment on the constant variables.
  • Loading branch information
Krzysztof Parzyszek authored and trevor-m committed Jun 17, 2021
1 parent e0bffca commit 137d389
Showing 1 changed file with 29 additions and 47 deletions.
76 changes: 29 additions & 47 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm,
md_builder_.reset(new llvm::MDBuilder(*ctx_));
// types
t_void_ = llvm::Type::getVoidTy(*ctx_);
t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo();
t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo(GetGlobalAddressSpace());
t_int_ = llvm::Type::getInt32Ty(*ctx_);
t_char_ = llvm::Type::getInt8Ty(*ctx_);
t_int8_ = llvm::Type::getInt8Ty(*ctx_);
Expand Down Expand Up @@ -191,20 +191,10 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
void CodeGenLLVM::LinkParameters(const Map<String, LinkedParam> params) {
// It would be nice to de-dupe these declarations frm src/tir/transforms/make_packed_api.cc,
// but they are at a different layer in the compiler...
std::vector<llvm::Type*> param_types;
// args
param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
// tcodes
param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace()));
// num_args
param_types.push_back(t_int_);
// ret_args
param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
// ret_tcodes
param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace()));
// resource_handle
param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
llvm::Type* t_int_p = t_int_->getPointerTo(GetGlobalAddressSpace());

// args, tcodes, num_args, ret_value, ret_tcode, resource_handle
std::vector<llvm::Type*> param_types{t_void_p_, t_int_p, t_int_, t_void_p_, t_int_p, t_void_p_};
llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, param_types, false);

llvm::Function* function =
Expand All @@ -215,41 +205,29 @@ void CodeGenLLVM::LinkParameters(const Map<String, LinkedParam> params) {

llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function);
builder_->SetInsertPoint(entry);
std::vector<llvm::Value*> zero_index_list{llvm::ConstantInt::get(t_int32_, 0)};
std::vector<llvm::Value*> zero_array_index_list{llvm::ConstantInt::get(t_int32_, 0),
llvm::ConstantInt::get(t_int32_, 0)};
auto args_array = builder_->CreateBitCast(
#if TVM_LLVM_VERSION >= 50
&function->arg_begin()[0],

auto getArg = [function](int i) -> llvm::Argument* {
#if TVM_LLVM_VERSION >= 100
return function->getArg(i);
#elif TVM_LLVM_VERSION >= 50
return &function->arg_begin()[i];
#else
&(*(function->arg_begin())),
return &*std::next(function->arg_begin(), i);
#endif
llvm::ArrayType::get(t_void_->getPointerTo(GetGlobalAddressSpace()), 1));
llvm::Value* sid = builder_->CreateBitCast(
builder_->CreateLoad(t_void_->getPointerTo(GetGlobalAddressSpace()),
builder_->CreateInBoundsGEP(args_array, zero_index_list)),
t_int64_);
};

llvm::Type* t_int64_p = t_int64_->getPointerTo(GetGlobalAddressSpace());
llvm::Value* sid = builder_->CreateLoad(t_int64_, builder_->CreateBitCast(getArg(0), t_int64_p));

auto ret_tcode = builder_->CreateBitCast(getArg(4), t_int_p);
auto ret_value =
builder_->CreateBitCast(getArg(3), t_void_p_->getPointerTo(GetGlobalAddressSpace()));

llvm::BasicBlock* default_block = llvm::BasicBlock::Create(*ctx_, "default_block", function);
auto ret_types_array = builder_->CreateBitCast(
#if TVM_LLVM_VERSION >= 50
&function->arg_begin()[4],
#else
&(*(std::next(function->arg_begin(), 4))),
#endif
llvm::ArrayType::get(t_int_, 1)->getPointerTo());
auto retval_array = builder_->CreateBitCast(
#if TVM_LLVM_VERSION >= 50
&function->arg_begin()[3],
#else
&(*std::next(function->arg_begin(), 3)),
#endif
llvm::ArrayType::get(t_void_->getPointerTo(GetGlobalAddressSpace()), 1)->getPointerTo());
llvm::SwitchInst* switch_inst = builder_->CreateSwitch(sid, default_block, params.size() + 1);

builder_->SetInsertPoint(default_block);
builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMNullptr),
builder_->CreateInBoundsGEP(ret_types_array, zero_array_index_list));
builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMNullptr), ret_tcode);
builder_->CreateRet(ConstInt32(kTvmErrorNoError));

// Add data to the global section.
Expand All @@ -258,16 +236,20 @@ void CodeGenLLVM::LinkParameters(const Map<String, LinkedParam> params) {
std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + kv.first;
llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable(
*module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name);
auto dtype = tvm::runtime::DataType(kv.second->param->dtype);
size_t align = std::max(tvm::runtime::GetVectorBytes(dtype), tvm::runtime::kAllocAlignment);
#if TVM_LLVM_VERSION >= 100
param_symbol->setAlignment(llvm::Align(align));
#else
param_symbol->setAlignment(align);
#endif

llvm::BasicBlock* case_block = llvm::BasicBlock::Create(*ctx_, "case_" + symbol_name, function);
switch_inst->addCase(
llvm::cast<llvm::ConstantInt>(llvm::ConstantInt::get(t_int64_, kv.second->id)), case_block);
builder_->SetInsertPoint(case_block);
builder_->CreateStore(
builder_->CreatePointerCast(param_symbol, t_void_->getPointerTo(GetGlobalAddressSpace())),
builder_->CreateInBoundsGEP(retval_array, zero_array_index_list));
builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMOpaqueHandle),
builder_->CreateInBoundsGEP(ret_types_array, zero_array_index_list));
builder_->CreateStore(builder_->CreatePointerCast(param_symbol, t_void_p_), ret_value);
builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMOpaqueHandle), ret_tcode);
builder_->CreateRet(ConstInt32(0));
}
}
Expand Down

0 comments on commit 137d389

Please sign in to comment.