Implement C++ registry to back Python target.generic_func #892

Merged
merged 50 commits into from Mar 19, 2018

alex-weaver
Contributor

This is a step towards implementing NNVM issue https://github.com/dmlc/nnvm/issues/343

This PR sets up a C++ registry to back the target.generic_func mechanism in Python, which allows C++ to also register target-specialized functions (schedules and ops) in a Python compatible way. This mechanism provides compatibility in both directions: C++ can call functions registered in Python and vice versa. Python code can also override any function registered via C++ if necessary.

In the case of C++ calling Python, a target string is provided which is parsed and entered on the python side. In the case of Python calling C++, the global current target object is stringified and passed to the registered function as an extra first argument.
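For readers following along, here is a minimal Python sketch of the dispatch idea described above: a generic function holds a default plus per-target-key specialisations, and a call picks the first matching key. The GenericFunc class below, its method signatures, and the explicit target_keys argument are illustrative assumptions, not the actual TVM implementation (which lives in C++).

```python
class GenericFunc:
    """Toy stand-in for a target-generic function (not the TVM class)."""

    def __init__(self, name):
        self.name = name
        self._default = None
        self._dispatch = {}  # target key -> specialised function

    def set_default(self, func):
        """Set the fallback used when no target key matches."""
        self._default = func
        return self

    def register(self, keys, func):
        """Register func for one key or a list of keys."""
        for key in ([keys] if isinstance(keys, str) else keys):
            self._dispatch[key] = func
        return self

    def __call__(self, target_keys, *args):
        # Try the target's keys in order; fall back to the default.
        for key in target_keys:
            if key in self._dispatch:
                return self._dispatch[key](*args)
        if self._default is None:
            raise RuntimeError("no default set for " + self.name)
        return self._default(*args)


schedule_dense = GenericFunc("schedule_dense")
schedule_dense.set_default(lambda outs: "generic schedule for " + outs)
schedule_dense.register(["cuda", "gpu"], lambda outs: "gpu schedule for " + outs)

result_gpu = schedule_dense(["cuda", "gpu"], "dense")
result_cpu = schedule_dense(["llvm", "cpu"], "dense")
```

In this sketch the target keys are passed explicitly to keep the example self-contained; how the target reaches the callee is a separate design question in this PR.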

@alex-weaver
Contributor Author

OK I've fixed a couple of issues that came up when integrating the TOPI C++ schedules, but it should now be working. The python generic_func is still compatible with all existing usage in the python codebase.

The TOPI C++ library now registers all its schedules and dense ops through the GenericFunc registry. The Python TOPI code is configured to override all of these, so C++ clients will see the C++ schedules, and Python clients will continue to see the Python schedules as before.

* these arguments.
* \param ret The return value
*/
TVM_DLL void invoke_func(const tvm::Target& target, TVMArgs args, TVMRetValue* ret) const;
Member

I think we would prefer to hide the target as a context, possibly in thread-local storage, which can be queried. The invoke can then simply be operator()().

* false, an error will be logged if the call would override a previously registered function.
* \return reference to self.
*/
TVM_DLL GenericFunc& set_generic_func(const PackedFunc value,
Member

set_default_func

@alex-weaver
Contributor Author

alex-weaver commented Feb 13, 2018

Ok I've addressed those - the only thing I wasn't sure about was what the signature of operator() should look like - so I've mirrored the style of PackedFunc with invoke_packed(TVMArgs args, TVMRetValue* ret) and then a separate operator()(Args&& ...args)

@alex-weaver
Contributor Author

alex-weaver commented Feb 13, 2018

Also fixed 2 issues that cropped up testing integrating this upstream:

  1. WrapSchedule in topi.cc did not correctly test whether it was passed Array<Tensor> or just Tensor.
  2. Added an empty exported function to TOPI so that NNVM has a cross-platform way to make sure the library has been loaded, and therefore make sure the schedules have been registered.

@alex-weaver
Contributor Author

Removed the empty exported function from (2) above - that doesn't work with python. Instead NNVM can load tvm_topi.dll dynamically.

@alex-weaver
Contributor Author

@tqchen Sorry to hassle you but could this be reviewed again? I think the issues raised should all be addressed. Let me know if there's anything that needs fixing

@tqchen
Member

tqchen commented Feb 20, 2018

Just got back from a conference trip; this is on my to-do list.

@alex-weaver
Contributor Author

Ah no problem, I just wondered if it had got lost somehow. No rush!

* \param args The arguments to pass to the function.
* \param ret The return value
*/
TVM_DLL void invoke_packed(TVMArgs args, TVMRetValue* ret) const;
Member

Use CallPacked, same as PackedFunc

@tqchen
Member

tqchen commented Feb 22, 2018

Some thoughts on how we can support the target context. I think the TLS that backs the target context query should belong to tvm::Target.

Instead of using a setter and getter, the target can be supported via scoping backed by a stack in TLS, which looks like Target::EnterTargetScope(target); Target::ExitTargetScope().

We can also have an RAII-style version of the scoping, like

tvm::TargetContext ctx(target);

which sets the target at construction time and calls Exit in the destructor.
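As an illustration only, the scoping idea can be sketched in Python, where a context manager plays the role of the RAII object. The function names enter_target_scope, exit_target_scope and target_context are hypothetical stand-ins for the C++ names mentioned in the comment, and plain strings stand in for target objects.

```python
import threading
from contextlib import contextmanager

# One stack of active targets per thread.
_tls = threading.local()


def _stack():
    """Return this thread's target stack, creating it on first use."""
    if not hasattr(_tls, "stack"):
        _tls.stack = []
    return _tls.stack


def enter_target_scope(target):
    _stack().append(target)


def exit_target_scope():
    _stack().pop()


@contextmanager
def target_context(target):
    """Scoped version: enter on entry, always exit on the way out."""
    enter_target_scope(target)
    try:
        yield target
    finally:
        exit_target_scope()


with target_context("cuda"):
    with target_context("llvm"):
        inner = _stack()[-1]   # innermost scope wins
    outer = _stack()[-1]       # previous target is restored
depth_after = len(_stack())
```

The try/finally ensures the stack is popped even if the body raises, which is the same guarantee a destructor gives in the C++ version.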

@alex-weaver
Contributor Author

Ok the python registration is updated to use a reference to a GenericFunc node instead of repeatedly using func_name

Member

@tqchen tqchen left a comment

Thanks for continuing to improve the PR. I feel we are very close; here is another batch of comments.

entry->context_stack.pop();
}

tvm::Target* Target::current_target(bool allow_null) {
Member

why pointer instead of const reference?

Contributor Author

This is because it needs to be able to return nullptr if the stack is empty and allow_null is true. This is used to preserve the behavior of Python's current_target() when the allow_none argument is true.
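A small Python sketch of the behavior being preserved, under the assumption that an empty stack yields None when allow_none is true and an error otherwise. The module-level list is a simplification (no thread-local storage) and the error type is an assumption.

```python
# Simplified: a plain list instead of thread-local storage.
_target_stack = []


def current_target(allow_none=True):
    """Return the innermost target, or None / raise if there is none."""
    if _target_stack:
        return _target_stack[-1]
    if allow_none:
        return None
    raise RuntimeError("current_target() called outside any target scope")


no_target = current_target()          # empty stack, allow_none=True
_target_stack.append("cuda")
active_target = current_target(allow_none=False)
_target_stack.pop()
```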

@@ -344,4 +384,151 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << ")";
});

struct GenericFuncNode::Manager {
Member

I think we only need one Get function, that is a member of GenericFunc

@@ -344,4 +384,151 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << ")";
});

struct GenericFuncNode::Manager {
std::unordered_map<std::string, std::shared_ptr<GenericFuncNode>> fmap;
Member

`>>` -> `> >` (add space)

Contributor Author

Yep I keep forgetting that one :)

return *this;
}

GenericFunc& GenericFunc::register_func(const std::vector<std::string>& tags,
Member

rename to register, to be consistent with python API

Contributor Author

This doesn't seem to work - register is a reserved keyword in C++

Member

OK, got it. Sorry, I forgot.

return GenericFunc(GenericFuncNode::Get(name));
}

GenericFunc& GenericFunc::set_default_func(const PackedFunc value,
Member

set_default_func->set_default

The register function is necessary.
"""
def _do_reg(myf):
    key_list = [key] if isinstance(key, str) else key
Member

move this logic to GenericFunc.register

_api_internal._GenericFuncRegisterFunc(self, func, key_list, allow_override)


def generic_func(name=None, override=False):
    """Wrap a target generic function.
Member

This logic seems to be overly complicated, due to the fact that now it has two functionalities:

  • Construct a generic function object
  • Retrieve a generic function from C++ by its name.

I think we want to simplify this by supporting only one use case. Keep generic_func for its original use case (defining a generic func), except that this time it is backed by C++.

Like get_global_func and register_func, we can add functions get_generic_func and register_generic_func to fetch from and store into the global registry
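To make the suggested split concrete, here is a hedged sketch of a name-keyed global registry with separate fetch and store functions. The names get_generic_func and register_generic_func come from the comment above; their signatures, the dictionary-based storage, and the error behavior are assumptions for illustration.

```python
# Hypothetical global registry: name -> generic function object.
_GLOBAL_GENERIC_FUNCS = {}


def register_generic_func(name, func):
    """Store func under name; refuse to silently replace an entry."""
    if name in _GLOBAL_GENERIC_FUNCS:
        raise ValueError("generic func already registered: " + name)
    _GLOBAL_GENERIC_FUNCS[name] = func
    return func


def get_generic_func(name):
    """Fetch a previously registered generic function by name."""
    if name not in _GLOBAL_GENERIC_FUNCS:
        raise KeyError("no generic func registered as: " + name)
    return _GLOBAL_GENERIC_FUNCS[name]


def schedule_injective(outs):
    return "scheduled " + outs


register_generic_func("schedule_injective", schedule_injective)
fetched = get_generic_func("schedule_injective")
fetched_result = fetched("relu")
```

Keeping definition (generic_func) apart from lookup by name means each function has a single job, which is the simplification the comment asks for.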

def set_default(self, func, allow_override):
    _api_internal._GenericFuncSetDefault(self, func, allow_override)

def register_func(self, func, key_list, allow_override):
Member

make API consistent with old generic_func both in c++ and python

  • register
  • set_default

Member

@tqchen tqchen left a comment

One last batch of changes requested: let us move the target to the Node system as well, so that we only need to exchange pointers rather than strings. It also makes the current_target API more natural.

topi/src/topi.cc Outdated
Tensor data = args[0];
Tensor weight = args[1];
Tensor bias_val;
Tensor *bias;
Member

Tensor can be None; use bias.defined() to check whether it is None

Member

do not use pointer here

*ret = GenericFunc(std::make_shared<GenericFuncNode>());
});

TVM_REGISTER_API("_GenericFuncAddToRegistry")
Member

GenericFuncRegisterGlobal

GenericFunc::RegisterGenericFunc(func, func_name);
});

TVM_REGISTER_API("_GenericFuncGet")
Member

_GenericFuncGetGlobal

@@ -28,7 +29,7 @@ struct Target {
/*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
int thread_warp_size = 1;
/*! \brief Keys for this target */
- std::unordered_set<std::string> keys;
+ std::vector<std::string> keys;
/*! \brief Options for this target */
Member

I think it might make sense to also bring Target to the Node system, so current Target can simply return a None which can be checked with target.defined()

@tqchen
Member

tqchen commented Mar 10, 2018

This needs a rebase on master and updates according to the latest comments

…ble. Changed tvm::GenericFunc to mirror tvm::runtime::Registry as dmlc::Registry fails when used across .dll boundaries
…ent to invoke, and removed target as injected first argument to callbacks. Target is now a thread local context. Renamed invoke_func -> invoke_packed, and added operator() mirroring PackedFunc.
Library can be loaded dynamically by NNVM
This reverts commit bc8f3d7.
…or cross-plat way to ensure TOPI is loaded and schedules/ops are registered
@tqchen
Member

tqchen commented Mar 16, 2018

@alex-weaver can you make the change to bring back python's generic_func? Thanks!

@alex-weaver
Contributor Author

@tqchen yep I'll sort that. I was thinking the C++ GenericFunc could be renamed NativeGenericFunc, and this could be used for schedules across both C++ and Python. That would allow the C++ schedules to be overridden by python if NNVM is changed to use the C++ ones as a base.

@tqchen
Member

tqchen commented Mar 16, 2018

I would say just keep the name GenericFunc for now on the C++ side, but we could name it NativeGenericFunc on the Python side.

@alex-weaver
Contributor Author

@tqchen the previous python generic_func has been restored, and the python attribute for C++ GenericFunc has been changed to override_native_generic_func - this is used to provide python overrides for the schedules (+dense op) provided by C++. Now, if NNVM registers the C++ schedules, they will be overridden with the python versions for python clients of NNVM or TOPI.

Member

@tqchen tqchen left a comment

some final changes to be made

"""
return _api_internal._GenericFuncGetGlobal(name)

def register_native_generic_func(func, name):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not necessary, as we can use get_native_generic_func(name).register. Remove the corresponding PackedFunc

topi/src/topi.cc Outdated
Tensor data = args[0];
Tensor weight = args[1];
Tensor bias_val;
Tensor *bias;
Member

do not use pointer here


@property
def keys(self):
    return [k.value for k in self.keys_array]
Member

Cache the result so that it is not fetched each time:

if self._keys is None:
    self._keys = tuple(k.value for k in self.keys_array)
return self._keys


def __repr__(self):
    return self.__str__()

Member

add

def __init__(self, handle):
    super(Target, self).__init__(handle)
    self._keys = None
    ...
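Put together, the two suggestions above amount to a lazily cached property. The sketch below is a standalone illustration: the Target and Key classes here are toy stand-ins (not the TVM node classes), and fetch_count is a hypothetical counter added only to make the caching observable.

```python
class Key:
    """Toy stand-in for a node that wraps a string value."""

    def __init__(self, value):
        self.value = value


class Target:
    """Toy stand-in showing the cached-property pattern only."""

    def __init__(self, keys_array):
        self.keys_array = keys_array
        self._keys = None      # cache, filled on first access
        self.fetch_count = 0   # hypothetical: counts real conversions

    @property
    def keys(self):
        if self._keys is None:
            self.fetch_count += 1
            self._keys = tuple(k.value for k in self.keys_array)
        return self._keys


target = Target([Key("cuda"), Key("gpu")])
first = target.keys
second = target.keys   # served from the cache
```

Returning a tuple rather than a list also prevents callers from mutating the cached value by accident.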

def keys(self):
    return [k.value for k in self.keys_array]

@property
Member

add unit test cases to tests/python/unittest/test_target.py to test

  • parsing from string
  • assert each field is correct
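The shape of such a test might look like the sketch below. To stay runnable without TVM it uses a hypothetical toy parser, parse_target, which is not the TVM parser and only assumes a "name -opt=value ..." string layout; a real test would call the project's own target-creation function and check its fields instead.

```python
def parse_target(target_str):
    """Hypothetical toy parser: 'name -opt=value ...' -> dict."""
    parts = target_str.split()
    if not parts:
        raise ValueError("empty target string")
    parsed = {"name": parts[0], "options": {}}
    for token in parts[1:]:
        if not token.startswith("-") or "=" not in token:
            raise ValueError("unrecognised option: " + token)
        key, value = token[1:].split("=", 1)
        parsed["options"][key] = value
    return parsed


def test_parse_target():
    # Parse from a string, then assert each field individually.
    parsed = parse_target("cuda -model=unknown -libs=cublas")
    assert parsed["name"] == "cuda"
    assert parsed["options"]["model"] == "unknown"
    assert parsed["options"]["libs"] == "cublas"


test_parse_target()
```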

@@ -13,76 +13,141 @@
#include "./tvm/lowered_func.h"
Member

Use a relative path instead of an absolute one, i.e. include "./lowered_func.h"

@tqchen
Member

tqchen commented Mar 17, 2018

I have made some final comments, please fix these

…e ops; cache properties in python Target; add test case for parsing Target from string
topi/src/topi.cc Outdated
@@ -557,7 +526,7 @@ TVM_REGISTER_GENERIC_FUNC(dense)
.set_default(WrapDenseOp([](const Target& target,
const tvm::Tensor& data,
const tvm::Tensor& weight,
- tvm::Tensor* bias) {
+ tvm::Tensor bias) {
Member

const tvm::Tensor&

topi/src/topi.cc Outdated
@@ -524,7 +500,7 @@ TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense)
using FTVMDenseOpBuilder = std::function<tvm::Tensor(const Target& target,
const tvm::Tensor& data,
const tvm::Tensor& weight,
- tvm::Tensor* bias)>;
+ tvm::Tensor bias)>;
Member

const tvm::Tensor&

}

auto batch = data->shape[0];
auto in_dim = data->shape[1];
auto out_dim = weight->shape[0];

- if (target.libs.count("rocblas") > 0) {
+ if (target->libs().count("rocblas") > 0) {
Member

Remove the > 0; simply use the count(...) call itself as the condition

topi/src/topi.cc Outdated

TVM_REGISTER_GENERIC_FUNC(dense)
.set_default(WrapDenseOp([](const Target& target,
const tvm::Tensor& data,
Member

Fix the alignment: align this line with const Target& target,

topi/src/topi.cc Outdated
.set_default(WrapDenseOp([](const Target& target,
const tvm::Tensor& data,
const tvm::Tensor& weight,
tvm::Tensor bias) {
Member

const tvm::Tensor& bias

topi/src/topi.cc Outdated
using FTVMDenseOpBuilder = std::function<tvm::Tensor(const Target& target,
const tvm::Tensor& data,
const tvm::Tensor& weight,
tvm::Tensor bias)>;
Member

const tvm::Tensor& bias

@tqchen tqchen merged commit 4afc2f9 into apache:master Mar 19, 2018
@tqchen
Member

tqchen commented Mar 19, 2018

Thanks for all the hard work in improving this PR, this is now merged!

@alex-weaver
Contributor Author

Excellent! Only a few more to go for a C++ NNVM build function ;)
