Skip to content

Commit

Permalink
[Codegen][LLVM] Allow void return type from PackedFunc (apache#14958)
Browse files Browse the repository at this point in the history
* [Bugfix] Avoid symbol conflicts in MakePackedAPI/MakeUnpackedAPI

PRs apache#14913 and
apache#14914 made analogous changes to
`MakePackedAPI` and `MakeUnpackedAPI` to handle subroutine calls.
Both PRs introduced the same symbol,
`tvm::tir::SubroutineCallRewriter`, a local utility to update internal
calls to a modified function.  While each PR passed CI individually,
and was therefore able to merge, having both changes caused a
duplicate symbol.

This commit updates `MakePackedAPI` and `MakeUnpackedAPI` to place
their local utilities into anonymous namespaces, avoiding the
conflict.

* [Codegen][LLVM] Allow void return type from PackedFunc

Previously, calling a packed func that returns void would result in
an error being raised from `tir::APIType`, as there is no runtime
representation of a void type.  This commit updates
`CodeGenCPU::MakeCallPackedLowered` to only read the return value and
type fo a `PackedFunc` when the TIR return type is non-void.
  • Loading branch information
Lunderberg authored and mei-ye committed Jun 1, 2023
1 parent 359a50c commit 050fecb
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 14 deletions.
38 changes: 24 additions & 14 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -879,8 +879,13 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>&
const DataType& r_type,
const int64_t begin, const int64_t end,
bool use_string_lookup) {
PackedCall pc;
std::string func_name = args[0].as<StringImmNode>()->value;
std::string func_name = [&]() {
auto ptr = args[0].as<StringImmNode>();
ICHECK(ptr) << "Expected first argument of tir::Call to be "
<< "a string containing the callee's name, "
<< "but instead contained " << args[0];
return ptr->value;
}();
// call the function
int64_t nargs = end - begin;
ICHECK_GE(nargs, 0);
Expand Down Expand Up @@ -936,27 +941,32 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>&

llvm::BasicBlock* end_block = CheckCallSuccess(call);

// Load the return value and cast it to the designated type (r_type).
DataType r_api_type = tir::APIType(r_type);
llvm::Type* llvm_r_api_type = DTypeToLLVMType(r_api_type);
llvm::Value* load_ptr = builder_->CreatePointerCast(ret_value, llvm_r_api_type->getPointerTo());
PackedCall pc = {0};

if (!r_type.is_void()) {
// Load the return value and cast it to the designated type (r_type).
DataType r_api_type = tir::APIType(r_type);
llvm::Type* llvm_r_api_type = DTypeToLLVMType(r_api_type);
llvm::Value* load_ptr = builder_->CreatePointerCast(ret_value, llvm_r_api_type->getPointerTo());
#if TVM_LLVM_VERSION >= 110
llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, llvm::Align(8));
llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, llvm::Align(8));
#elif TVM_LLVM_VERSION >= 80
llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, 8);
llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, 8);
#else
llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8);
llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8);
#endif
pc.ret_value = CreateCast(r_api_type, r_type, rvalue);

// Load the return type code.
pc.ret_value = CreateCast(r_api_type, r_type, rvalue);

// Load the return type code.
#if TVM_LLVM_VERSION >= 110
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, llvm::Align(8));
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, llvm::Align(8));
#elif TVM_LLVM_VERSION >= 80
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, 8);
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, 8);
#else
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.addr, 8);
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.addr, 8);
#endif
}

pc.end_block = end_block;
return pc;
Expand Down
1 change: 1 addition & 0 deletions src/tir/transforms/ir_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ inline Stmt TVMStructSet(Var handle, int index, builtin::TVMStructFieldKind kind
* \return The corresponding API type.
*/
inline DataType APIType(DataType t) {
ICHECK(!t.is_void()) << "Cannot pass void type through packed API.";
if (t.is_handle()) return t;
ICHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API.";
if (t.is_uint() || t.is_int()) return DataType::Int(64);
Expand Down
60 changes: 60 additions & 0 deletions tests/python/unittest/test_target_codegen_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,5 +1049,65 @@ def subroutine(A_data: T.handle("float32")):
assert arr.numpy()[0] == 42.0


@tvm.testing.requires_llvm
def test_call_packed_returning_void():
"""Allow codegen of PackedFunc calls returning void
The LLVM codegen uses the CallNode's dtype to cast the return type
of the PackedFunc into the appropriate LLVM output type. However,
there is no API type for `DataType::Void()`. When the return type
of a PackedFunc is void, the generated code should not attempt to
read the return value.
While `T.call_packed()` will produce a CallNode with an output
dtype of "int32", the use of other return types is valid in TIR.
This test case uses `T.Call` directly to allow an explicit dtype
for the packed function call.
"""

@T.prim_func
def func():
T.Call(
"void",
tvm.ir.Op.get("tir.tvm_call_packed"),
["dummy_function_name"],
)

# Error occurred during build, as part of
# CodeGenCPU::MakeCallPackedLowered.
built = tvm.build(func, target="llvm")


@tvm.testing.requires_llvm
def test_call_packed_without_string_arg():
"""The first argument to tvm_call_packed must be a string
Even if the invalid TIR is constructed, this should throw an
exception to exit cleanly. Previously, use of
`args[0].as<StringImmNode>()` without a null check resulted in
a segfault during codegen.
"""

@T.prim_func
def func(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "func"})
T.Call("int32", tvm.ir.Op.get("tir.tvm_call_packed"), [A.data])

with pytest.raises(tvm.TVMError):
built = tvm.build(func, target="llvm")


@tvm.testing.requires_llvm
def test_call_extern_returning_void():
"""Like test_call_packed_returning_void, but for call_extern"""

@T.prim_func
def func():
T.func_attr({"global_symbol": "func"})
T.Call("void", tvm.ir.Op.get("tir.call_extern"), ["dummy_function_name"])

built = tvm.build(func, target="llvm")


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 050fecb

Please sign in to comment.