New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement C++ registry to back Python target.generic_func #892
Conversation
OK I've fixed a couple of issues that came up when integrating the TOPI C++ schedules, but it should now be working. The python generic_func is still compatible with all existing usage in the python codebase. The TOPI C++ library now registers all its schedules and dense ops through the GenericFunc registry. The Python TOPI code is configured to override all of these, so C++ clients will see the C++ schedules, and Python clients will continue to see the Python schedules as before. |
include/tvm/build_module.h
Outdated
* these arguments. | ||
* \param ret The return value | ||
*/ | ||
TVM_DLL void invoke_func(const tvm::Target& target, TVMArgs args, TVMRetValue* ret) const; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we would prefer hide target as a context, possibly in Thread Local storage, which can be queried. The invoke can simply be operator()().
include/tvm/build_module.h
Outdated
* false, an error will be logged if the call would override a previously registered function. | ||
* \return reference to self. | ||
*/ | ||
TVM_DLL GenericFunc& set_generic_func(const PackedFunc value, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
set_default_func
Ok I've addressed those - the only thing I wasn't sure about was what the signature of operator() should look like - so I've mirrored the style of PackedFunc with invoke_packed(TVMArgs args, TVMRetValue* ret) and then a separate operator()(Args&& ...args) |
Also fixed 2 issues that cropped up testing integrating this upstream:
|
Removed the empty exported function from (2) above - that doesn't work with python. Instead NNVM can load tvm_topi.dll dynamically. |
@tqchen Sorry to hassle you but could this be reviewed again? I think the issues raised should all be addressed. Let me know if there's anything that needs fixing |
just get back from a conference trip, on my todo list |
Ah no problem, I just wondered if it had got lost somehow. No rush! |
include/tvm/build_module.h
Outdated
* \param args The arguments to pass to the function. | ||
* \param ret The return value | ||
*/ | ||
TVM_DLL void invoke_packed(TVMArgs args, TVMRetValue* ret) const; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use CallPacked, same as PackedFunc
Some thoughts on how we can support target context. I think the Target context query TLS should belong to tvm::Target . Instead of using setter getter, the target can be supported via scoping backed by a stack in TLS, which looks like We can have RAII style version of scoping. Like
Which set target at construction time, and call Exit at destructor |
Ok the python registration is updated to use a reference to a GenericFunc node instead of repeatedly using func_name |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for keep improving the PR. I feel we are very close, here is another batch of comment
src/codegen/build_module.cc
Outdated
entry->context_stack.pop(); | ||
} | ||
|
||
tvm::Target* Target::current_target(bool allow_null) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why pointer instead of const reference?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is because it needs to be able to return nullptr if the stack is empty and allow_null is true. This is used preserve the behavior of python current_target() when the allow_none argument is true.
src/codegen/build_module.cc
Outdated
@@ -344,4 +384,151 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) | |||
p->stream << ")"; | |||
}); | |||
|
|||
struct GenericFuncNode::Manager { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we only need one Get function, that is a member of GenericFunc
src/codegen/build_module.cc
Outdated
@@ -344,4 +384,151 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) | |||
p->stream << ")"; | |||
}); | |||
|
|||
struct GenericFuncNode::Manager { | |||
std::unordered_map<std::string, std::shared_ptr<GenericFuncNode>> fmap; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
```>> -> > > `` (add space)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep I keep forgetting that one :)
return *this; | ||
} | ||
|
||
GenericFunc& GenericFunc::register_func(const std::vector<std::string>& tags, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rename to register, to be consistent with python API
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't seem to work - register is a reserved keyword
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, get it. sorry i forget
src/codegen/build_module.cc
Outdated
return GenericFunc(GenericFuncNode::Get(name)); | ||
} | ||
|
||
GenericFunc& GenericFunc::set_default_func(const PackedFunc value, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
set_default_func->set_default
python/tvm/target.py
Outdated
The register function is necessary. | ||
""" | ||
def _do_reg(myf): | ||
key_list = [key] if isinstance(key, str) else key |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
move this logic to GenericFunc.register
python/tvm/target.py
Outdated
_api_internal._GenericFuncRegisterFunc(self, func, key_list, allow_override) | ||
|
||
|
||
def generic_func(name=None, override=False): | ||
"""Wrap a target generic function. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This logic seems to be overly complicated, due to the fact that now it has two functionalities:
- Construct a generic function object
- Retrieve a generic function from C++ by its name.
I think we want to simplify the case by only support one use-case. Keep generic_func as its original use-case (define a generic func), except that this time it is backed by C++.
Like get_global_func and register_func, we can add functions get_generic_func and register_generic_func to support fetch and store global registry
python/tvm/target.py
Outdated
def set_default(self, func, allow_override): | ||
_api_internal._GenericFuncSetDefault(self, func, allow_override) | ||
|
||
def register_func(self, func, key_list, allow_override): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
make API consistent with old generic_func both in c++ and python
- register
- set_default
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One last batch of changes requested, let us move the target to Node system as well so we don't have to exchange the strings, only pointers are needed. It also makes current_target API more natural
topi/src/topi.cc
Outdated
Tensor data = args[0]; | ||
Tensor weight = args[1]; | ||
Tensor bias_val; | ||
Tensor *bias; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tensor can be None, use bias.defined() to check if it is none
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do not use pointer here
src/codegen/build_module.cc
Outdated
*ret = GenericFunc(std::make_shared<GenericFuncNode>()); | ||
}); | ||
|
||
TVM_REGISTER_API("_GenericFuncAddToRegistry") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
GenericFuncRegisterGlobal
src/codegen/build_module.cc
Outdated
GenericFunc::RegisterGenericFunc(func, func_name); | ||
}); | ||
|
||
TVM_REGISTER_API("_GenericFuncGet") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_GenericFuncGetGlobal
include/tvm/build_module.h
Outdated
@@ -28,7 +29,7 @@ struct Target { | |||
/*! \brief The warp size that should be used by the LowerThreadAllreduce pass */ | |||
int thread_warp_size = 1; | |||
/*! \brief Keys for this target */ | |||
std::unordered_set<std::string> keys; | |||
std::vector<std::string> keys; | |||
/*! \brief Options for this target */ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it might make sense to also bring Target to the Node system, so current Target can simply return a None which can be checked with target.defined()
Need rebase from master and updates according to the last comments |
…ble. Changed tvm::GenericFunc to mirror tvm::runtime::Registry as dmlc::Registry fails when used across .dll boundaries
…ent to invoke, and removed target as injected first argument to callbacks. Target is now a thread local context. Renamed invoke_func -> invoke_packed, and added operator() mirroring PackedFunc.
Library can be loaded dynamically by NNVM This reverts commit bc8f3d7.
…or cross-plat way to ensure TOPI is loaded and schedules/ops are registered
@alex-weaver can you make the change to bring back python's generic_func? Thanks! |
@tqchen yep I'll sort that. I was thinking the C++ GenericFunc could be renamed NativeGenericFunc, and this could be used for schedules across both C++ and Python. That would allow the C++ schedules to be overridden by python if NNVM is changed to use the C++ ones as a base. |
I would say just keep the name GenericFunc for now in c++ side, but we could name it NativeGenericFunc in python side. |
@tqchen the previous python generic_func has been restored, and the python attribute for C++ GenericFunc has been changed to override_native_generic_func - this is used to provide python overrides for the schedules (+dense op) provided by C++. Now, if NNVM registers the C++ schedules, they will be overridden with the python versions for python clients of NNVM or TOPI. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
some final changes to be made
python/tvm/target.py
Outdated
""" | ||
return _api_internal._GenericFuncGetGlobal(name) | ||
|
||
def register_native_generic_func(func, name): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not necessary, as we can use get_native_generic_func(name).register. remove the correspond PackedFunc
topi/src/topi.cc
Outdated
Tensor data = args[0]; | ||
Tensor weight = args[1]; | ||
Tensor bias_val; | ||
Tensor *bias; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do not use pointer here
python/tvm/target.py
Outdated
|
||
@property | ||
def keys(self): | ||
return [k.value for k in self.keys_array] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do a cache so it does not fetch each time,
if self._keys is None:
self._keys = tuple(k.value for k in self.keys_array)
return self._keys
python/tvm/target.py
Outdated
|
||
def __repr__(self): | ||
return self.__str__() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add
def __init__(self, handle):
super(Target, self).__init__(handle)
self._keys = None
...
python/tvm/target.py
Outdated
def keys(self): | ||
return [k.value for k in self.keys_array] | ||
|
||
@property |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add unit test cases to tests/python/unittest/test_target.py
to test
- parsing from string
- assert each field is correct
include/tvm/build_module.h
Outdated
@@ -13,76 +13,141 @@ | |||
#include "./tvm/lowered_func.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use relative path, instead of absolute one, which means we should do include "./lowered_func.h"
I have made some final comments, please fix these |
…e ops; cache properties in python Target; add test case for parsing Target from string
topi/src/topi.cc
Outdated
@@ -557,7 +526,7 @@ TVM_REGISTER_GENERIC_FUNC(dense) | |||
.set_default(WrapDenseOp([](const Target& target, | |||
const tvm::Tensor& data, | |||
const tvm::Tensor& weight, | |||
tvm::Tensor* bias) { | |||
tvm::Tensor bias) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const tvm::Tensor&
topi/src/topi.cc
Outdated
@@ -524,7 +500,7 @@ TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense) | |||
using FTVMDenseOpBuilder = std::function<tvm::Tensor(const Target& target, | |||
const tvm::Tensor& data, | |||
const tvm::Tensor& weight, | |||
tvm::Tensor* bias)>; | |||
tvm::Tensor bias)>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const tvm::Tensor&
topi/include/topi/rocm/dense.h
Outdated
} | ||
|
||
auto batch = data->shape[0]; | ||
auto in_dim = data->shape[1]; | ||
auto out_dim = weight->shape[0]; | ||
|
||
if (target.libs.count("rocblas") > 0) { | ||
if (target->libs().count("rocblas") > 0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove >0, simply use count(condition)
topi/src/topi.cc
Outdated
|
||
TVM_REGISTER_GENERIC_FUNC(dense) | ||
.set_default(WrapDenseOp([](const Target& target, | ||
const tvm::Tensor& data, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
alignment, align to const Target& target,
topi/src/topi.cc
Outdated
.set_default(WrapDenseOp([](const Target& target, | ||
const tvm::Tensor& data, | ||
const tvm::Tensor& weight, | ||
tvm::Tensor bias) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const tvm::Tensor& bias
topi/src/topi.cc
Outdated
using FTVMDenseOpBuilder = std::function<tvm::Tensor(const Target& target, | ||
const tvm::Tensor& data, | ||
const tvm::Tensor& weight, | ||
tvm::Tensor bias)>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const tvm::Tensor& bias
Thanks for all the hard work in improving this PR, this is now merged! |
Excellent! Only a few more to go for a C++ NNVM build function ;) |
This is a step towards implementing NNVM issue https://github.com/dmlc/nnvm/issues/343
This PR sets up a C++ registry to back the target.generic_func mechanism in Python, which allows C++ to also register target-specialized functions (schedules and ops) in a Python compatible way. This mechanism provides compatibility in both directions: C++ can call functions registered in Python and vice versa. Python code can also override any function registered via C++ if necessary.
In the case of C++ calling Python, a target string is provided which is parsed and entered on the python side. In the case of Python calling C++, the global current target object is stringified and passed to the registered function as an extra first argument.