Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Extension of initializing arbitrary bit integer in hcl.scalar from string #431

Open
wants to merge 7 commits into
base: tvm
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/heterocl/compute_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,11 @@ def compute_body(name,
index, _, _ = get_index(shape, indices, 0)
stage.emit(_make.Store(buffer_var, _make.Cast(dtype, ret), index))
stmt = make_for(indices, stage.pop_stmt(), 0, name)
elif isinstance(ret, str):
indices = lambda_ivs
index, _, _ = get_index(shape, indices, 0)
stage.emit(_make.Store(buffer_var, _make.CastStr(dtype, ret), index))
stmt = make_for(indices, stage.pop_stmt(), 0, name)
elif isinstance(ret, Tensor): # reduction
ret_ivs = [_IterVar((0, ret.shape[i]), ret.name+"_i" + str(i), 0)
for i in range(0, len(ret.shape))]
Expand Down
167 changes: 167 additions & 0 deletions tests/test_scalar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import heterocl as hcl
import numpy as np
hcl.init()

def test_int7(A):
def doit(x):
x = 0xFA_FF00_FFFF
v = hcl.scalar("0x4A", "v", dtype=hcl.Int(7))
return v.v
return hcl.compute(A.shape, lambda i: doit(i), "doit", dtype=hcl.Int(7))

def test_uint7(A):
def doit(x):
x = 0xFA_FF00_FFFF
v = hcl.scalar("0x4A", "v", dtype=hcl.UInt(7))
return v.v
return hcl.compute(A.shape, lambda i: doit(i), "doit", dtype=hcl.UInt(7))

def test_int15(A):
def doit(x):
x = 0xFA_FF00_FFFF
v = hcl.scalar("0x40FF", "v", dtype=hcl.Int(15))
return v.v
return hcl.compute(A.shape, lambda i: doit(i), "doit", dtype=hcl.Int(15))

def test_uint15(A):
def doit(x):
x = 0xFA_FF00_FFFF
v = hcl.scalar("0x40FF", "v", dtype=hcl.UInt(15))
return v.v
return hcl.compute(A.shape, lambda i: doit(i), "doit", dtype=hcl.UInt(15))

def test_int31(A):
def doit(x):
x = 0xFA_FF00_FFFF
v = hcl.scalar("0x4F0000FF", "v", dtype=hcl.Int(31))
return v.v
return hcl.compute(A.shape, lambda i: doit(i), "doit", dtype=hcl.Int(31))

def test_uint31(A):
def doit(x):
x = 0xFA_FF00_FFFF
v = hcl.scalar("0x4F0000FF", "v", dtype=hcl.UInt(31))
return v.v
return hcl.compute(A.shape, lambda i: doit(i), "doit", dtype=hcl.UInt(31))

def test_int62(A):
def doit(x):
x = 0xFA_FF00_FFFF
v = hcl.scalar("0x2A000000FF0000FF", "v", dtype=hcl.Int(62))
return v.v
return hcl.compute(A.shape, lambda i: doit(i), "doit", dtype=hcl.Int(62))

def test_uint62(A):
def doit(x):
x = 0xFA_FF00_FFFF
v = hcl.scalar("0x2A000000FF0000FF", "v", dtype=hcl.UInt(62))
return v.v
return hcl.compute(A.shape, lambda i: doit(i), "doit", dtype=hcl.UInt(62))

A_int7 = hcl.placeholder((1,), "A_int7", dtype=hcl.Int(7))
s_int7 = hcl.create_schedule([A_int7], test_int7)
m_int7 = hcl.build (s_int7)

A_uint7 = hcl.placeholder((1,), "A_uint7", dtype=hcl.UInt(7))
s_uint7 = hcl.create_schedule([A_uint7], test_uint7)
m_uint7 = hcl.build (s_uint7)

A_int15 = hcl.placeholder((1,), "A_int15", dtype=hcl.Int(15))
s_int15 = hcl.create_schedule([A_int15], test_int15)
m_int15 = hcl.build (s_int15)

A_uint15 = hcl.placeholder((1,), "A_uint15", dtype=hcl.UInt(15))
s_uint15 = hcl.create_schedule([A_uint15], test_uint15)
m_uint15 = hcl.build (s_uint15)

A_int31 = hcl.placeholder((1,), "A_int31", dtype=hcl.Int(31))
s_int31 = hcl.create_schedule([A_int31], test_int31)
m_int31 = hcl.build (s_int31)

A_uint31 = hcl.placeholder((1,), "A_uint31", dtype=hcl.UInt(31))
s_uint31 = hcl.create_schedule([A_uint31], test_uint31)
m_uint31 = hcl.build (s_uint31)

A_int62 = hcl.placeholder((1,), "A_int62", dtype=hcl.Int(62))
s_int62 = hcl.create_schedule([A_int62], test_int62)
m_int62 = hcl.build (s_int62)

A_uint62 = hcl.placeholder((1,), "A_uint62", dtype=hcl.UInt(62))
s_uint62 = hcl.create_schedule([A_uint62], test_uint62)
m_uint62 = hcl.build (s_uint62)

A_int7 = hcl.asarray([0xA0A0], dtype=A_int7.dtype)
R_int7 = hcl.asarray([99], dtype=hcl.Int(7))
m_int7(A_int7, R_int7)

A_uint7 = hcl.asarray([0xA0A0], dtype=A_uint7.dtype)
R_uint7 = hcl.asarray([99], dtype=hcl.UInt(7))
m_uint7(A_uint7, R_uint7)

A_int15 = hcl.asarray([0xA0A0], dtype=A_int15.dtype)
R_int15 = hcl.asarray([99], dtype=hcl.Int(15))
m_int15(A_int15, R_int15)

A_uint15 = hcl.asarray([0xA0A0], dtype=A_uint15.dtype)
R_uint15 = hcl.asarray([99], dtype=hcl.UInt(15))
m_uint15(A_uint15, R_uint15)

A_int31 = hcl.asarray([0xA0A0], dtype=A_int31.dtype)
R_int31 = hcl.asarray([99], dtype=hcl.Int(31))
m_int31(A_int31, R_int31)

A_uint31 = hcl.asarray([0xA0A0], dtype=A_uint31.dtype)
R_uint31 = hcl.asarray([99], dtype=hcl.UInt(31))
m_uint31(A_uint31, R_uint31)

A_int62 = hcl.asarray([0xA0A0], dtype=A_int62.dtype)
R_int62 = hcl.asarray([99], dtype=hcl.Int(62))
m_int62(A_int62, R_int62)

A_uint62 = hcl.asarray([0xA0A0], dtype=A_uint62.dtype)
R_uint62 = hcl.asarray([99], dtype=hcl.UInt(62))
m_uint62(A_uint62, R_uint62)

print(f"R_int7 = {[bin(i) for i in R_int7.asnumpy()]}")
print(f"R_uint7 = {[hex(i) for i in R_uint7.asnumpy()]}")
print(f"R_int15 = {[hex(i) for i in R_int15.asnumpy()]}")
print(f"R_uint15 = {[hex(i) for i in R_uint15.asnumpy()]}")
print(f"R_int31 = {[hex(i) for i in R_int31.asnumpy()]}")
print(f"R_uint31 = {[hex(i) for i in R_uint31.asnumpy()]}")
print(f"R_int62 = {[hex(i) for i in R_int62.asnumpy()]}")
print(f"R_uint62 = {[hex(i) for i in R_uint62.asnumpy()]}")

def test_int127_lower(A):
def doit(x):
x = 0xFA_FF00_FFFF
v = hcl.scalar("0x8ABCFFFFFFBAFFFA000A", "v", dtype=hcl.UInt(127))
b = hcl.scalar(v >> 64, "b", dtype=hcl.UInt(63))
c = hcl.scalar(v & 0xFFFFFFFFFFFFFFFF, "c", dtype=hcl.UInt(64))
return c.v
return hcl.compute(A.shape, lambda i: doit(i), "doit", dtype=hcl.UInt(64))

def test_int127_upper(A):
def doit(x):
x = 0xFA_FF00_FFFF
v = hcl.scalar("0x8ABCFFFFFFBAFFFA000A", "v", dtype=hcl.UInt(127))
b = hcl.scalar(v >> 64, "b", dtype=hcl.UInt(63))
c = hcl.scalar(v & 0x7FFFFFFFFFFFFFFF, "c", dtype=hcl.UInt(64))
return b.v
return hcl.compute(A.shape, lambda i: doit(i), "doit", dtype=hcl.UInt(63))

A = hcl.placeholder((1,), "A", dtype=hcl.UInt(63))
s_lower = hcl.create_schedule([A], test_int127_lower)
s_upper = hcl.create_schedule([A], test_int127_upper)
m_lower = hcl.build(s_lower)
m_upper = hcl.build(s_upper)

hcl_A = hcl.asarray([0], hcl.UInt(63))
hcl_R_lower = hcl.asarray([0], hcl.UInt(64))
hcl_R_upper = hcl.asarray([0], hcl.UInt(63))

m_lower(hcl_A, hcl_R_lower)
m_upper(hcl_A, hcl_R_upper)

print(f"hcl_R_lower = {[hex(i) for i in hcl_R_lower.asnumpy()]}")
print(f"hcl_R_upper = {[hex(i) for i in hcl_R_upper.asnumpy()]}")

1 change: 1 addition & 0 deletions tvm/HalideIR/src/ir/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ enum class IRNodeType : int {
FloatImm,
StringImm,
Cast,
CastStr,
Variable,
Add,
Sub,
Expand Down
11 changes: 11 additions & 0 deletions tvm/HalideIR/src/ir/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ Expr Cast::make(Type t, Expr v) {
return Expr(node);
}

Expr CastStr::make(Type t, const std::string &val) {
std::shared_ptr<CastStr> node = std::make_shared<CastStr>();
node->type = t;
node->value = val;
return Expr(node);
}

Expr And::make(Expr a, Expr b) {
internal_assert(a.defined()) << "And of undefined\n";
internal_assert(b.defined()) << "And of undefined\n";
Expand Down Expand Up @@ -1050,6 +1057,10 @@ void ExprNode<Cast>::accept(IRVisitor *v, const Expr &e) const {
v->visit((const Cast *)this, e);
}
template <>
void ExprNode<CastStr>::accept(IRVisitor *v, const Expr &e) const {
v->visit((const CastStr *)this, e);
}
template <>
void ExprNode<Variable>::accept(IRVisitor *v, const Expr &e) const {
v->visit((const Variable *)this, e);
}
Expand Down
13 changes: 13 additions & 0 deletions tvm/HalideIR/src/ir/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,19 @@ struct Cast : public ExprNode<Cast> {
static constexpr const char *_type_key = "Cast";
};

/** Cast a node from string to other datatype. */
struct CastStr : public ExprNode<CastStr> {
std::string value;
EXPORT static Expr make(Type t, const std::string &val);
void VisitAttrs(IR::AttrVisitor *v) final {
v->Visit("dtype", &type);
v->Visit("value", &value);
}
static const IRNodeType _type_info = IRNodeType::CastStr;
static constexpr const char *_type_key = "CastStr";
};


/** base class of all Binary arithematic ops */
template <typename T>
struct BinaryOpNode : public ExprNode<T> {
Expand Down
5 changes: 5 additions & 0 deletions tvm/HalideIR/src/ir/IREquality.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class IRComparer : public IRVisitor {
void visit(const FloatImm *, const Expr &);
void visit(const StringImm *, const Expr &);
void visit(const Cast *, const Expr &);
void visit(const CastStr *, const Expr &);
void visit(const Variable *, const Expr &);
void visit(const Add *, const Expr &);
void visit(const Sub *, const Expr &);
Expand Down Expand Up @@ -293,6 +294,10 @@ void IRComparer::visit(const StringImm *op, const Expr &e) {
compare_names(node->value, op->value);
}

void IRComparer::visit(const CastStr *op, const Expr &e) {
const CastStr *node = expr_.as<CastStr>();
compare_names(node->value, op->value);
}
void IRComparer::visit(const Cast *op, const Expr &e) {
compare_expr(expr_.as<Cast>()->value, op->value);
}
Expand Down
9 changes: 9 additions & 0 deletions tvm/HalideIR/src/ir/IRMutator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ void IRMutator::visit(const Cast *op, const Expr &e) {
}
}

void IRMutator::visit(const CastStr *op, const Expr &e) {
std::string value = op->value;
if (value == op->value) {
expr = e;
} else {
expr = CastStr::make(op->type, value);
}
}

// use macro to access private function.
#define MUTATE_BINARY_OP(op, e, T) \
Expr a = mutate(op->a); \
Expand Down
1 change: 1 addition & 0 deletions tvm/HalideIR/src/ir/IRMutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class IRMutator : public IRVisitor {
EXPORT virtual void visit(const FloatImm *, const Expr &);
EXPORT virtual void visit(const StringImm *, const Expr &);
EXPORT virtual void visit(const Cast *, const Expr &);
EXPORT virtual void visit(const CastStr *, const Expr &);
EXPORT virtual void visit(const Variable *, const Expr &);
EXPORT virtual void visit(const Add *, const Expr &);
EXPORT virtual void visit(const Sub *, const Expr &);
Expand Down
35 changes: 35 additions & 0 deletions tvm/HalideIR/src/ir/IRPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,41 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->print(op->value);
p->stream << ')';
})
.set_dispatch<CastStr>([](const CastStr *op, IRPrinter *p) {
p->stream << op->type << '(';
auto &stream = p->stream;
stream << '"';
for (size_t i = 0; i < op->value.size(); i++) {
unsigned char c = op->value[i];
if (c >= ' ' && c <= '~' && c != '\\' && c != '"') {
stream << c;
} else {
stream << '\\';
switch (c) {
case '"':
stream << '"';
break;
case '\\':
stream << '\\';
break;
case '\t':
stream << 't';
break;
case '\r':
stream << 'r';
break;
case '\n':
stream << 'n';
break;
default:
string hex_digits = "0123456789ABCDEF";
stream << 'x' << hex_digits[c >> 4] << hex_digits[c & 0xf];
}
}
}
stream << '"';
p->stream << ')';
})
.set_dispatch<Variable>([](const Variable *op, IRPrinter *p) {
// omit the type
// stream << op->name << "." << op->type;
Expand Down
3 changes: 3 additions & 0 deletions tvm/HalideIR/src/ir/IRVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ void IRVisitor::visit(const FloatImm *, const Expr &) {}
void IRVisitor::visit(const StringImm *, const Expr &) {}

void IRVisitor::visit(const Cast *op, const Expr &) { op->value.accept(this); }
void IRVisitor::visit(const CastStr *op, const Expr &) { }

void IRVisitor::visit(const Variable *, const Expr &) {}

Expand Down Expand Up @@ -356,6 +357,8 @@ void IRGraphVisitor::visit(const StringImm *, const Expr &) {}

void IRGraphVisitor::visit(const Cast *op, const Expr &) { include(op->value); }

void IRGraphVisitor::visit(const CastStr *op, const Expr &) { }

void IRGraphVisitor::visit(const Variable *op, const Expr &) {}

void IRGraphVisitor::visit(const Add *op, const Expr &) {
Expand Down
2 changes: 2 additions & 0 deletions tvm/HalideIR/src/ir/IRVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class IRVisitor {
EXPORT virtual void visit(const FloatImm *, const Expr &);
EXPORT virtual void visit(const StringImm *, const Expr &);
EXPORT virtual void visit(const Cast *, const Expr &);
EXPORT virtual void visit(const CastStr *, const Expr &);
EXPORT virtual void visit(const Variable *, const Expr &);
EXPORT virtual void visit(const Add *, const Expr &);
EXPORT virtual void visit(const Sub *, const Expr &);
Expand Down Expand Up @@ -116,6 +117,7 @@ class IRGraphVisitor : public IRVisitor {
EXPORT virtual void visit(const FloatImm *, const Expr &);
EXPORT virtual void visit(const StringImm *, const Expr &);
EXPORT virtual void visit(const Cast *, const Expr &);
EXPORT virtual void visit(const CastStr *, const Expr &);
EXPORT virtual void visit(const Variable *, const Expr &);
EXPORT virtual void visit(const Add *, const Expr &);
EXPORT virtual void visit(const Sub *, const Expr &);
Expand Down
1 change: 1 addition & 0 deletions tvm/include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,7 @@ using Halide::Internal::Break;
using Halide::Internal::Broadcast;
using Halide::Internal::Call;
using Halide::Internal::Cast;
using Halide::Internal::CastStr;
using Halide::Internal::Div;
using Halide::Internal::EQ;
using Halide::Internal::Evaluate;
Expand Down
2 changes: 2 additions & 0 deletions tvm/include/tvm/ir_functor_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const Or* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Reduce* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Cast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const CastStr* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Not* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Select* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Ramp* op, Args... args) EXPR_FUNCTOR_DEFAULT;
Expand Down Expand Up @@ -179,6 +180,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
IR_EXPR_FUNCTOR_DISPATCH(Or);
IR_EXPR_FUNCTOR_DISPATCH(Reduce);
IR_EXPR_FUNCTOR_DISPATCH(Cast);
IR_EXPR_FUNCTOR_DISPATCH(CastStr);
IR_EXPR_FUNCTOR_DISPATCH(Not);
IR_EXPR_FUNCTOR_DISPATCH(Select);
IR_EXPR_FUNCTOR_DISPATCH(Ramp);
Expand Down
1 change: 1 addition & 0 deletions tvm/include/tvm/ir_mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class TVM_DLL IRMutator {
virtual Expr Mutate_(const Or* op, const Expr& e);
virtual Expr Mutate_(const Reduce* op, const Expr& e);
virtual Expr Mutate_(const Cast* op, const Expr& e);
virtual Expr Mutate_(const CastStr* op, const Expr& e);
virtual Expr Mutate_(const Not* op, const Expr& e);
virtual Expr Mutate_(const Select* op, const Expr& e);
virtual Expr Mutate_(const Ramp* op, const Expr& e);
Expand Down
1 change: 1 addition & 0 deletions tvm/include/tvm/ir_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class TVM_DLL IRVisitor {
virtual void Visit_(const Or* op);
virtual void Visit_(const Reduce* op);
virtual void Visit_(const Cast* op);
virtual void Visit_(const CastStr* op);
virtual void Visit_(const Not* op);
virtual void Visit_(const Select* op);
virtual void Visit_(const Ramp* op);
Expand Down
1 change: 1 addition & 0 deletions tvm/src/api/api_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ REGISTER_MAKE_BINARY_OP(Or);
REGISTER_MAKE1(Not);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Cast);
REGISTER_MAKE2(CastStr);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
Expand Down
Loading