Skip to content

Commit

Permalink
combine FuncRefVar and FuncRefExpr
Browse files Browse the repository at this point in the history
  • Loading branch information
psuriana committed Jul 20, 2016
1 parent 06a84d4 commit a546872
Show file tree
Hide file tree
Showing 25 changed files with 486 additions and 620 deletions.
30 changes: 15 additions & 15 deletions apps/fft/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,12 @@ vector<Expr> A(vector<Expr> l, const vector<T> &r) {

// Get call references to the first N elements of dimension dim of x. If temps
// is set, grab references to elements [-N, -1] instead.
typedef FuncRefExprT<ComplexExpr> ComplexFuncRefExpr;
vector<ComplexFuncRefExpr> get_func_refs(ComplexFunc x, int N, bool temps = false) {
typedef FuncRefT<ComplexExpr> ComplexFuncRef;
vector<ComplexFuncRef> get_func_refs(ComplexFunc x, int N, bool temps = false) {
vector<Var> args(x.args());
args.erase(args.begin());

vector<ComplexFuncRefExpr> refs;
vector<ComplexFuncRef> refs;
for (int i = 0; i < N; i++) {
if (temps) {
refs.push_back(x(A({Expr(-i - 1)}, args)));
Expand All @@ -112,8 +112,8 @@ ComplexFunc dft2(ComplexFunc f, const string& prefix) {
ComplexFunc F(prefix + "X2");
F(f.args()) = undef_z(type);

vector<ComplexFuncRefExpr> x = get_func_refs(f, 2);
vector<ComplexFuncRefExpr> X = get_func_refs(F, 2);
vector<ComplexFuncRef> x = get_func_refs(f, 2);
vector<ComplexFuncRef> X = get_func_refs(F, 2);

X[0] = x[0] + x[1];
X[1] = x[0] - x[1];
Expand All @@ -127,9 +127,9 @@ ComplexFunc dft4(ComplexFunc f, int sign, const string& prefix) {
ComplexFunc F(prefix + "X4");
F(f.args()) = undef_z(type);

vector<ComplexFuncRefExpr> x = get_func_refs(f, 4);
vector<ComplexFuncRefExpr> X = get_func_refs(F, 4);
vector<ComplexFuncRefExpr> T = get_func_refs(F, 2, true);
vector<ComplexFuncRef> x = get_func_refs(f, 4);
vector<ComplexFuncRef> X = get_func_refs(F, 4);
vector<ComplexFuncRef> T = get_func_refs(F, 2, true);
// We can re-use these two temps. T[0], T[2] and T[1], T[3] do not have
// overlapping lifetime.
T.push_back(T[1]);
Expand Down Expand Up @@ -161,9 +161,9 @@ ComplexFunc dft6(ComplexFunc f, int sign, const string& prefix) {
ComplexFunc F(prefix + "X8");
F(f.args()) = undef_z(type);

vector<ComplexFuncRefExpr> x = get_func_refs(f, 6);
vector<ComplexFuncRefExpr> X = get_func_refs(F, 6);
vector<ComplexFuncRefExpr> T = get_func_refs(F, 6, true);
vector<ComplexFuncRef> x = get_func_refs(f, 6);
vector<ComplexFuncRef> X = get_func_refs(F, 6);
vector<ComplexFuncRef> T = get_func_refs(F, 6, true);

// Prime factor FFT, N=2*3, no twiddle factors!
T[0] = (x[0] + x[3]);
Expand Down Expand Up @@ -192,9 +192,9 @@ ComplexFunc dft8(ComplexFunc f, int sign, const string& prefix) {
ComplexFunc F(prefix + "X8");
F(f.args()) = undef_z(type);

vector<ComplexFuncRefExpr> x = get_func_refs(f, 8);
vector<ComplexFuncRefExpr> X = get_func_refs(F, 8);
vector<ComplexFuncRefExpr> T = get_func_refs(F, 8, true);
vector<ComplexFuncRef> x = get_func_refs(f, 8);
vector<ComplexFuncRef> X = get_func_refs(F, 8);
vector<ComplexFuncRef> T = get_func_refs(F, 8, true);

X[0] = (x[0] + x[4]);
X[2] = (x[2] + x[6]);
Expand Down Expand Up @@ -1056,4 +1056,4 @@ Func fft2d_c2r(ComplexFunc c,
const Target& target,
const Fft2dDesc& desc) {
return fft2d_c2r(c, radix_factor(N0), radix_factor(N1), target, desc);
}
}
151 changes: 37 additions & 114 deletions apps/fft/funct.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,17 @@
#include <Halide.h>

template <typename T>
class FuncRefVarT : public T {
Halide::FuncRefVar untyped;
class FuncRefT : public T {
Halide::FuncRef untyped;

public:
typedef Halide::Stage Stage;
typedef Halide::Tuple Tuple;

FuncRefVarT(const Halide::FuncRefVar& untyped)
FuncRefT(const Halide::FuncRef& untyped)
: T(untyped.function().has_pure_definition() ? T(Tuple(untyped)) : T()),
untyped(untyped) {}

Stage operator=(T x) { return untyped = x; }
Stage operator+=(T x) { return untyped = T(Tuple(untyped)) + x; }
Stage operator-=(T x) { return untyped = T(Tuple(untyped)) - x; }
Stage operator*=(T x) { return untyped = T(Tuple(untyped)) * x; }
Stage operator/=(T x) { return untyped = T(Tuple(untyped)) / x; }
};

template <typename T>
class FuncRefExprT : public T {
Halide::FuncRefExpr untyped;

public:
typedef Halide::Stage Stage;
typedef Halide::Tuple Tuple;

FuncRefExprT(const Halide::FuncRefExpr& untyped)
: T(Tuple(untyped)), untyped(untyped) {}

Stage operator=(T x) { return untyped = x; }
Stage operator+=(T x) { return untyped = T(Tuple(untyped)) + x; }
Stage operator-=(T x) { return untyped = T(Tuple(untyped)) - x;}
Expand All @@ -56,129 +38,70 @@ class FuncT : public Halide::Func {
explicit FuncT(Func f) : Func(f) {}
explicit FuncT(Halide::Internal::Function f) : Func(f) {}

FuncRefVarT<T> operator()() const { return Func::operator()(); }
FuncRefVarT<T> operator()(Var x) const { return Func::operator()(x); }
FuncRefVarT<T> operator()(Var x, Var y) const { return Func::operator()(x, y); }
FuncRefVarT<T> operator()(Var x, Var y, Var z) const { return Func::operator()(x, y, z); }
FuncRefVarT<T> operator()(Var x, Var y, Var z, Var w) const { return Func::operator()(x, y, z, w); }
FuncRefVarT<T> operator()(Var x, Var y, Var z, Var w, Var u) const { return Func::operator()(x, y, z, w, u); }
FuncRefVarT<T> operator()(Var x, Var y, Var z, Var w, Var u, Var v) const { return Func::operator()(x, y, z, w, u, v); }
FuncRefVarT<T> operator()(std::vector<Var> vars) const { return Func::operator()(vars); }

FuncRefExprT<T> operator()(Expr x) const { return Func::operator()(x); }
FuncRefExprT<T> operator()(Expr x, Expr y) const { return Func::operator()(x, y); }
FuncRefExprT<T> operator()(Expr x, Expr y, Expr z) const { return Func::operator()(x, y, z); }
FuncRefExprT<T> operator()(Expr x, Expr y, Expr z, Expr w) const { return Func::operator()(x, y, z, w); }
FuncRefExprT<T> operator()(Expr x, Expr y, Expr z, Expr w, Expr u) const { return Func::operator()(x, y, z, w, u); }
FuncRefExprT<T> operator()(Expr x, Expr y, Expr z, Expr w, Expr u, Expr v) const { return Func::operator()(x, y, z, w, u, v); }
FuncRefExprT<T> operator()(std::vector<Expr> vars) const { return Func::operator()(vars); }
FuncRefT<T> operator()(Expr x) const { return Func::operator()(x); }
FuncRefT<T> operator()(Expr x, Expr y) const { return Func::operator()(x, y); }
FuncRefT<T> operator()(Expr x, Expr y, Expr z) const { return Func::operator()(x, y, z); }
FuncRefT<T> operator()(Expr x, Expr y, Expr z, Expr w) const { return Func::operator()(x, y, z, w); }
FuncRefT<T> operator()(Expr x, Expr y, Expr z, Expr w, Expr u) const { return Func::operator()(x, y, z, w, u); }
FuncRefT<T> operator()(Expr x, Expr y, Expr z, Expr w, Expr u, Expr v) const { return Func::operator()(x, y, z, w, u, v); }
FuncRefT<T> operator()(std::vector<Expr> vars) const { return Func::operator()(vars); }
FuncRefT<T> operator()(std::vector<Var> vars) const { return Func::operator()(vars); }
};

// Forward operator overload invocations on FuncRefVarT/FuncRefExprT to
// Forward operator overload invocations on FuncRefT to
// the type the user intended (T).

// TODO(dsharlet): This is obscene. Find a better way... but it is unlikely
// there is one.
template <typename T>
T operator-(FuncRefVarT<T> x) { return -static_cast<T>(x); }
template <typename T>
T operator~(FuncRefVarT<T> x) { return ~static_cast<T>(x); }

template <typename T>
T operator+(FuncRefVarT<T> a, T b) { return static_cast<T>(a) + b; }
template <typename T>
T operator-(FuncRefVarT<T> a, T b) { return static_cast<T>(a) - b; }
template <typename T>
T operator*(FuncRefVarT<T> a, T b) { return static_cast<T>(a) * b; }
template <typename T>
T operator/(FuncRefVarT<T> a, T b) { return static_cast<T>(a) / b; }
template <typename T>
T operator%(FuncRefVarT<T> a, T b) { return static_cast<T>(a) % b; }
template <typename T>
T operator+(T a, FuncRefVarT<T> b) { return a + static_cast<T>(b); }
template <typename T>
T operator-(T a, FuncRefVarT<T> b) { return a - static_cast<T>(b); }
template <typename T>
T operator*(T a, FuncRefVarT<T> b) { return a * static_cast<T>(b); }
template <typename T>
T operator/(T a, FuncRefVarT<T> b) { return a / static_cast<T>(b); }
template <typename T>
T operator%(T a, FuncRefVarT<T> b) { return a % static_cast<T>(b); }

template <typename T>
Halide::Expr operator==(FuncRefVarT<T> a, T b) { return static_cast<T>(a) == b; }
template <typename T>
Halide::Expr operator!=(FuncRefVarT<T> a, T b) { return static_cast<T>(a) != b; }
template <typename T>
Halide::Expr operator<=(FuncRefVarT<T> a, T b) { return static_cast<T>(a) <= b; }
template <typename T>
Halide::Expr operator>=(FuncRefVarT<T> a, T b) { return static_cast<T>(a) >= b; }
template <typename T>
Halide::Expr operator<(FuncRefVarT<T> a, T b) { return static_cast<T>(a) < b; }
template <typename T>
Halide::Expr operator>(FuncRefVarT<T> a, T b) { return static_cast<T>(a) > b; }
template <typename T>
Halide::Expr operator==(T a, FuncRefVarT<T> b) { return a == static_cast<T>(b); }
template <typename T>
Halide::Expr operator!=(T a, FuncRefVarT<T> b) { return a != static_cast<T>(b); }
template <typename T>
Halide::Expr operator<=(T a, FuncRefVarT<T> b) { return a <= static_cast<T>(b); }
template <typename T>
Halide::Expr operator>=(T a, FuncRefVarT<T> b) { return a >= static_cast<T>(b); }
template <typename T>
Halide::Expr operator<(T a, FuncRefVarT<T> b) { return a < static_cast<T>(b); }
template <typename T>
Halide::Expr operator>(T a, FuncRefVarT<T> b) { return a > static_cast<T>(b); }

template <typename T>
T operator-(FuncRefExprT<T> x) { return -static_cast<T>(x); }
T operator-(FuncRefT<T> x) { return -static_cast<T>(x); }
template <typename T>
T operator~(FuncRefExprT<T> x) { return ~static_cast<T>(x); }
T operator~(FuncRefT<T> x) { return ~static_cast<T>(x); }

template <typename T>
T operator+(FuncRefExprT<T> a, T b) { return static_cast<T>(a) + b; }
T operator+(FuncRefT<T> a, T b) { return static_cast<T>(a) + b; }
template <typename T>
T operator-(FuncRefExprT<T> a, T b) { return static_cast<T>(a) - b; }
T operator-(FuncRefT<T> a, T b) { return static_cast<T>(a) - b; }
template <typename T>
T operator*(FuncRefExprT<T> a, T b) { return static_cast<T>(a) * b; }
T operator*(FuncRefT<T> a, T b) { return static_cast<T>(a) * b; }
template <typename T>
T operator/(FuncRefExprT<T> a, T b) { return static_cast<T>(a) / b; }
T operator/(FuncRefT<T> a, T b) { return static_cast<T>(a) / b; }
template <typename T>
T operator%(FuncRefExprT<T> a, T b) { return static_cast<T>(a) % b; }
T operator%(FuncRefT<T> a, T b) { return static_cast<T>(a) % b; }
template <typename T>
T operator+(T a, FuncRefExprT<T> b) { return a + static_cast<T>(b); }
T operator+(T a, FuncRefT<T> b) { return a + static_cast<T>(b); }
template <typename T>
T operator-(T a, FuncRefExprT<T> b) { return a - static_cast<T>(b); }
T operator-(T a, FuncRefT<T> b) { return a - static_cast<T>(b); }
template <typename T>
T operator*(T a, FuncRefExprT<T> b) { return a * static_cast<T>(b); }
T operator*(T a, FuncRefT<T> b) { return a * static_cast<T>(b); }
template <typename T>
T operator/(T a, FuncRefExprT<T> b) { return a / static_cast<T>(b); }
T operator/(T a, FuncRefT<T> b) { return a / static_cast<T>(b); }
template <typename T>
T operator%(T a, FuncRefExprT<T> b) { return a % static_cast<T>(b); }
T operator%(T a, FuncRefT<T> b) { return a % static_cast<T>(b); }

template <typename T>
Halide::Expr operator==(FuncRefExprT<T> a, T b) { return static_cast<T>(a) == b; }
Halide::Expr operator==(FuncRefT<T> a, T b) { return static_cast<T>(a) == b; }
template <typename T>
Halide::Expr operator!=(FuncRefExprT<T> a, T b) { return static_cast<T>(a) != b; }
Halide::Expr operator!=(FuncRefT<T> a, T b) { return static_cast<T>(a) != b; }
template <typename T>
Halide::Expr operator<=(FuncRefExprT<T> a, T b) { return static_cast<T>(a) <= b; }
Halide::Expr operator<=(FuncRefT<T> a, T b) { return static_cast<T>(a) <= b; }
template <typename T>
Halide::Expr operator>=(FuncRefExprT<T> a, T b) { return static_cast<T>(a) >= b; }
Halide::Expr operator>=(FuncRefT<T> a, T b) { return static_cast<T>(a) >= b; }
template <typename T>
Halide::Expr operator<(FuncRefExprT<T> a, T b) { return static_cast<T>(a) < b; }
Halide::Expr operator<(FuncRefT<T> a, T b) { return static_cast<T>(a) < b; }
template <typename T>
Halide::Expr operator>(FuncRefExprT<T> a, T b) { return static_cast<T>(a) > b; }
Halide::Expr operator>(FuncRefT<T> a, T b) { return static_cast<T>(a) > b; }
template <typename T>
Halide::Expr operator==(T a, FuncRefExprT<T> b) { return a == static_cast<T>(b); }
Halide::Expr operator==(T a, FuncRefT<T> b) { return a == static_cast<T>(b); }
template <typename T>
Halide::Expr operator!=(T a, FuncRefExprT<T> b) { return a != static_cast<T>(b); }
Halide::Expr operator!=(T a, FuncRefT<T> b) { return a != static_cast<T>(b); }
template <typename T>
Halide::Expr operator<=(T a, FuncRefExprT<T> b) { return a <= static_cast<T>(b); }
Halide::Expr operator<=(T a, FuncRefT<T> b) { return a <= static_cast<T>(b); }
template <typename T>
Halide::Expr operator>=(T a, FuncRefExprT<T> b) { return a >= static_cast<T>(b); }
Halide::Expr operator>=(T a, FuncRefT<T> b) { return a >= static_cast<T>(b); }
template <typename T>
Halide::Expr operator<(T a, FuncRefExprT<T> b) { return a < static_cast<T>(b); }
Halide::Expr operator<(T a, FuncRefT<T> b) { return a < static_cast<T>(b); }
template <typename T>
Halide::Expr operator>(T a, FuncRefExprT<T> b) { return a > static_cast<T>(b); }
Halide::Expr operator>(T a, FuncRefT<T> b) { return a > static_cast<T>(b); }

#endif
#endif
5 changes: 2 additions & 3 deletions python_bindings/doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@ Main classes Description
Helper Classes Description
==================================== ================================
:py:class:`halide.buffer_t` The raw representation of an image passed around by generated Halide code.
:py:class:`halide.FuncRefExpr`
:py:class:`halide.FuncRef`
:py:class:`halide.InternalFunction`
:py:class:`halide.FuncRefVar`
:py:class:`halide.Stage` A single definition of a Func.
:py:class:`halide.VarOrRVar`
==================================== ================================
Expand Down Expand Up @@ -91,6 +90,6 @@ Indices and tables
.. toctree::
:hidden:
:maxdepth: 2

self

20 changes: 9 additions & 11 deletions python_bindings/python/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,19 +247,19 @@ p::object func_getitem_operator0(h::Func &that, p::tuple args_passed)
// We prioritize Args over Expr variant
if(var_args.size() == args_len)
{
h::FuncRefVar ret = that(var_args);
h::FuncRef ret = that(var_args);

p::copy_non_const_reference::apply<h::FuncRefVar &>::type converter;
p::copy_non_const_reference::apply<h::FuncRef &>::type converter;
PyObject* obj = converter( ret );
return_object = p::object( p::handle<>( obj ) );
}
else
{ user_assert(expr_args.size() == args_len)
<< "Not all func_getitem_operator0 arguments where converted to Expr "
<< "( expr_args.size() " << expr_args.size() << "!= args_len " << args_len << ")";
h::FuncRefExpr ret = that(expr_args);
h::FuncRef ret = that(expr_args);

p::copy_non_const_reference::apply<h::FuncRefExpr &>::type converter;
p::copy_non_const_reference::apply<h::FuncRef &>::type converter;
PyObject* obj = converter( ret );
return_object = p::object( p::handle<>( obj ) );
}
Expand Down Expand Up @@ -308,14 +308,14 @@ h::Stage func_setitem_operator0(h::Func &that, p::tuple args_passed, T right_han
// We prioritize Args
if(var_args.size() == args_len)
{
h::FuncRefVar ret = that(var_args);
h::FuncRef ret = that(var_args);
h::Stage s = (ret = right_hand);
return s;
}
else
{ user_assert(expr_args.size() == args_len) << "Not all func_setitem_operator0 arguments where converted to Expr";

h::FuncRefExpr ret = that(expr_args);
h::FuncRef ret = that(expr_args);
h::Stage s = (ret = right_hand);
return s;
}
Expand Down Expand Up @@ -598,12 +598,10 @@ void defineFunc()
.def("__getitem__", &func_getitem_operator1); // handles the case where a single index object is given

func_class
.def("__setitem__", &func_setitem_operator0<h::FuncRefVar>)
.def("__setitem__", &func_setitem_operator0<h::FuncRefExpr>)
.def("__setitem__", &func_setitem_operator0<h::FuncRef>)
.def("__setitem__", &func_setitem_operator0<h::Expr>)
.def("__setitem__", &func_setitem_operator0<h::Tuple>)
.def("__setitem__", &func_setitem_operator1<h::FuncRefVar>) // handles the case where a single index object is given
.def("__setitem__", &func_setitem_operator1<h::FuncRefExpr>)
.def("__setitem__", &func_setitem_operator1<h::FuncRef>) // handles the case where a single index object is given
.def("__setitem__", &func_setitem_operator1<h::Expr>)
.def("__setitem__", &func_setitem_operator1<h::Tuple>);

Expand Down Expand Up @@ -790,4 +788,4 @@ void defineFunc()
defineStage();
defineVarOrRVar();
defineFuncRef();
}
}
Loading

0 comments on commit a546872

Please sign in to comment.