Permalink
Browse files

Register custom datatypes with TVM; specify Cast and Add lowering

This commit adds functionality for registering custom datatypes with TVM, and
furthermore adding custom lowering functions to lower those custom datatypes.
This commit only adds lowering for the Cast and Add ops; more ops will be added
soon.

Check out some custom datatype samples in my repository of samples:
https://github.com/gussmith23/tvm-custom-datatype-samples
  • Loading branch information...
gussmith23 committed Dec 12, 2018
1 parent 85a63dc commit cfefc6d394bc73c1d3f9b61445bfabb44cb2d291
@@ -472,6 +472,17 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
*/
LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);

/*!
* \brief Lower custom datatypes.
*
* See tvm::DatatypeRegistry for more information on adding custom datatypes.
*
* \param f The device function to be lowered.
* \param target The target device.
* \return Transformed function.
*/
LoweredFunc LowerDatatypes(LoweredFunc f, const std::string& target);

/*!
* \brief Verify if memory accesses are legal for a specific target device type.
*
@@ -18,8 +18,9 @@
#include "module.h"
#include "ndarray.h"
#include "node_base.h"
// TODO(gus): ...
#include "../../../src/codegen/datatype/datatype_registry.h"

extern "C" std::string GetTypeName(uint8_t);
extern "C" uint8_t GetTypeCode(const std::string& type_name);

namespace HalideIR {
// Forward declare type for extensions
@@ -878,7 +879,7 @@ inline const char* TypeCode2Str(int type_code) {

// TODO(gus): handle code-not-found error

auto type_name = DatatypeRegistry::GetTypeName((uint8_t)type_code);
auto type_name = GetTypeName(type_code);
std::ostringstream ss;
ss << "custom[" << type_name << "]";
auto str = ss.str();
@@ -961,7 +962,7 @@ inline TVMType String2TVMType(std::string s) {
scan += custom_name_len + 1;

auto type_name = s.substr(7, custom_name_len);
t.code = DatatypeRegistry::GetTypeCode(type_name);
t.code = GetTypeCode(type_name);
} else {
scan = s.c_str();
LOG(FATAL) << "unknown type " << s;
@@ -264,6 +264,12 @@ def asnumpy(self):
check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
return np_arr

def copybytesto(self, target):
assert(target.flags['C_CONTIGUOUS'])
data = target.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(target.size * target.dtype.itemsize)
_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes)

def copyto(self, target):
"""Copy array to target
@@ -73,6 +73,13 @@ def __init__(self, type_str):
self.type_code = 4
bits = 64
head = ""
elif head.startswith("custom"):
low, high = head.find('['), head.find(']')
if not low or not high or low >= high:
raise ValueError("Badly formatted custom type string %s" % type_str)
type_name = head[low+1:high]
self.type_code = _api_internal._get_type_code(type_name)
head = head[high+1:]
else:
raise ValueError("Do not know how to handle type %s" % type_str)
bits = int(head) if head else bits
@@ -82,7 +89,11 @@ def __init__(self, type_str):
def __repr__(self):
if self.bits == 1 and self.lanes == 1:
return "bool"
x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
if self.type_code in TVMType.CODE2STR:
type_name = TVMType.CODE2STR[self.type_code]
else:
type_name = "custom[%s]"%_api_internal._get_type_name(self.type_code)
x = "%s%d" % (type_name, self.bits)
if self.lanes != 1:
x += "x%d" % self.lanes
return x
@@ -150,6 +150,7 @@ REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS2(LowerWarpMemory);
REGISTER_PASS2(RemapThreadAxis);
REGISTER_PASS2(LowerIntrin);
REGISTER_PASS2(LowerDatatypes);
REGISTER_PASS1(LowerTVMBuiltin);
REGISTER_PASS1(CombineContextCall);
REGISTER_PASS2(VerifyMemory);
@@ -1,28 +1,70 @@
#include "datatype_registry.h"
#include <tvm/api_registry.h>
#include <iostream>

namespace tvm {

TVM_REGISTER_API("_register_datatype")
.set_body([](TVMArgs args, TVMRetValue* ret) {
DatatypeRegistry::RegisterDatatype(args[0], (uint8_t)args[1].operator int());
});
TVM_REGISTER_GLOBAL("_register_datatype")
.set_body([](TVMArgs args, TVMRetValue* ret) {
DatatypeRegistry::Global()->RegisterDatatype(
args[0], (uint8_t)args[1].operator int());
});

void DatatypeRegistry::RegisterDatatype(const std::string& type_name, uint8_t type_code) {
auto inst = Global();
inst->code_to_name[type_code] = type_name;
inst->name_to_code[type_name] = type_code;
TVM_REGISTER_GLOBAL("_get_type_code")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = DatatypeRegistry::Global()->GetTypeCode(args[0]);
});

TVM_REGISTER_GLOBAL("_get_type_name")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = DatatypeRegistry::Global()->GetTypeName(args[0].operator int());
});

void DatatypeRegistry::RegisterDatatype(const std::string& type_name,
uint8_t type_code) {
code_to_name[type_code] = type_name;
name_to_code[type_name] = type_code;
}

uint8_t DatatypeRegistry::GetTypeCode(const std::string& type_name) {
auto inst = Global();
return inst->name_to_code[type_name];
return name_to_code[type_name];
}

std::string DatatypeRegistry::GetTypeName(uint8_t type_code) {
auto inst = Global();
return inst->code_to_name[type_code];
return code_to_name[type_code];
}

const runtime::PackedFunc* GetCastLowerFunc(const std::string& target,
uint8_t type_code,
uint8_t src_type_code) {
std::ostringstream ss;
ss << "tvm.datatypes.lower.";
ss << target << ".";
ss << "cast"
<< ".";

if (DatatypeRegistry::Global()->DatatypeRegistered(type_code)) {
ss << DatatypeRegistry::Global()->GetTypeName(type_code);
} else {
ss << runtime::TypeCode2Str(type_code);
}

ss << ".";

if (DatatypeRegistry::Global()->DatatypeRegistered(src_type_code)) {
ss << DatatypeRegistry::Global()->GetTypeName(src_type_code);
} else {
ss << runtime::TypeCode2Str(src_type_code);
}

return runtime::Registry::Get(ss.str());
}

const runtime::PackedFunc* GetAddLowerFunc(const std::string& target,
uint8_t type_code) {
internal_assert(DatatypeRegistry::Global()->DatatypeRegistered(type_code));
return runtime::Registry::Get(
"tvm.datatypes." + target + ".lower.add." +
DatatypeRegistry::Global()->GetTypeName(type_code));
}

} // namespace tvm
} // namespace tvm
@@ -1,28 +1,58 @@
#ifndef DATATYPE_REGISTRY_H_
#define DATATYPE_REGISTRY_H_

#include <unordered_map>
#include <tvm/runtime/packed_func.h>
#include <string>
#include <unordered_map>

namespace tvm {

const runtime::PackedFunc* GetCastLowerFunc(const std::string& target,
uint8_t type_code,
uint8_t src_type_code);
const runtime::PackedFunc* GetAddLowerFunc(const std::string& target,
uint8_t type_code);

/*!
* \brief Registry for custom datatypes.
*
* Adding custom datatypes currently requires two steps:
* 1. Register the datatype with the registry via a call to
* DatatypeRegistry::RegisterDatatype. This can also be done in Python
* directly---see the TVM globals registered in the corresponding .cc file.
* Currently, user should manually choose a type name and a type code,
* ensuring that neither conflict with existing types.
* 2. Use TVM_REGISTER_GLOBAL to register the lowering functions needed to
* lower the custom datatype. In general, these will look like:
* For Casts: tvm.datatypes.lower.cast.<target>.<type>.<src_type>
* Example: tvm.datatypes.lower.cast.llvm.myfloat.float for a Cast from
* float to myfloat.
* For other ops: tvm.datatypes.lower.<op>.<target>.<type>
* Example: tvm.datatypes.lower.add.llvm.myfloat
*/
class DatatypeRegistry {
public:
static void RegisterDatatype(const std::string& type_name, uint8_t type_code);
static uint8_t GetTypeCode(const std::string& type_name);
static std::string GetTypeName(uint8_t type_code);

private:
static inline DatatypeRegistry* Global() {
static DatatypeRegistry inst;
return &inst;
}

// TODO(gus): ...what's the normal way to do this?
void RegisterDatatype(const std::string& type_name, uint8_t type_code);

uint8_t GetTypeCode(const std::string& type_name);

std::string GetTypeName(uint8_t type_code);

inline bool DatatypeRegistered(uint8_t type_code) {
return code_to_name.find(type_code) != code_to_name.end();
}

private:
// TODO(gus) is there a typedef for the code?
std::unordered_map<uint8_t, std::string> code_to_name;
std::unordered_map<std::string, uint8_t> name_to_code;
};

} // namespace tvm
} // namespace tvm

#endif
@@ -0,0 +1,60 @@
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>

// TODO(gus) how to do these imports correctly?
#include "../codegen/datatype/datatype_registry.h"

namespace tvm {
namespace ir {

/*!
* \brief Helper mutator to implement lowering of custom datatypes.
*
* Lowering datatypes works as follows: for every expression containing a custom
* datatype, we search for a global (registered by the implementer of the custom
* datatype) for lowering this type of expression, and uses it to lower the
* expression.
*/
class DatatypesLowerer : public IRMutator {
public:
DatatypesLowerer(const std::string& target) : target_(target) {}

inline Expr Mutate_(const Cast* op, const Expr& e) final {
auto type_code = op->type.code();
auto src_type_code = op->value.type().code();
// If either datatype is a registered custom datatype, we must lower.
if (DatatypeRegistry::Global()->DatatypeRegistered(type_code) ||
DatatypeRegistry::Global()->DatatypeRegistered(src_type_code)) {
auto lower = GetCastLowerFunc(target_, type_code, src_type_code);
internal_assert(lower);
// TODO(gus) they use this->Mutate; why?
Expr r = (*lower)(e);
return Mutate(r);
}
return e;
}

inline Expr Mutate_(const Add* op, const Expr& e) final {
auto type_code = op->type.code();
if (DatatypeRegistry::Global()->DatatypeRegistered(type_code)) {
auto lower = GetAddLowerFunc(target_, type_code);
internal_assert(lower);
Expr r = (*lower)(e);
return Mutate(r);
}
return e;
}

private:
std::string target_;
};

LoweredFunc LowerDatatypes(LoweredFunc f, const std::string& target) {
auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = DatatypesLowerer(target).Mutate(n->body);
return LoweredFunc(n);
}

} // namespace ir
} // namespace tvm
@@ -0,0 +1,13 @@
#include <tvm/runtime/registry.h>
#include <string>

// TODO(gus) this is generating warnings due to returning a string.
extern "C" std::string GetTypeName(uint8_t type_code) {
return (*tvm::runtime::Registry::Get("_get_type_name"))(type_code).
operator std::string();
}

extern "C" uint8_t GetTypeCode(const std::string& type_name) {
return (*tvm::runtime::Registry::Get("_get_type_code"))(type_name).
operator int();
}

0 comments on commit cfefc6d

Please sign in to comment.