Skip to content

Commit

Permalink
Organized definition of binary operations
Browse files Browse the repository at this point in the history
Defined all of OpenCL binary operators.
  • Loading branch information
ddemidov committed Jun 10, 2012
1 parent 1651194 commit 2eab656
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 188 deletions.
2 changes: 2 additions & 0 deletions examples/utests.cpp
Expand Up @@ -482,6 +482,8 @@ int main() {
Reductor<size_t,SUM> sum(ctx.queue()); Reductor<size_t,SUM> sum(ctx.queue());
rc = rc && sum(chk_if_greater(x, y)) == 0; rc = rc && sum(chk_if_greater(x, y)) == 0;
rc = rc && sum(chk_if_greater(y, x)) == N; rc = rc && sum(chk_if_greater(y, x)) == N;
rc = rc && sum(x > y) == 0;
rc = rc && sum(x < y) == N;
return rc; return rc;
}); });


Expand Down
319 changes: 131 additions & 188 deletions vexcl/vector.hpp
Expand Up @@ -570,25 +570,24 @@ class vector : public expression {
return *this; return *this;
} }


template <class Expr> #define COMPOUND_ASSIGNMENT(cop, op) \
const vector& operator+=(const Expr &expr) { template <class Expr> \
return *this = *this + expr; const vector& operator cop(const Expr &expr) { \
} return *this = *this op expr; \

template <class Expr>
const vector& operator*=(const Expr &expr) {
return *this = *this * expr;
} }


template <class Expr> COMPOUND_ASSIGNMENT(+=, +);
const vector& operator/=(const Expr &expr) { COMPOUND_ASSIGNMENT(-=, -);
return *this = *this / expr; COMPOUND_ASSIGNMENT(*=, *);
} COMPOUND_ASSIGNMENT(/=, /);
COMPOUND_ASSIGNMENT(%=, %);
COMPOUND_ASSIGNMENT(&=, &);
COMPOUND_ASSIGNMENT(|=, |);
COMPOUND_ASSIGNMENT(^=, ^);
COMPOUND_ASSIGNMENT(<<=, <<);
COMPOUND_ASSIGNMENT(>>=, >>);


template <class Expr> #undef COMPOUND_ASSIGNMENT
const vector& operator-=(const Expr &expr) {
return *this = *this - expr;
}


template <class Expr, typename column_t> template <class Expr, typename column_t>
const vector& operator=(const ExSpMV<Expr,T,column_t> &xmv); const vector& operator=(const ExSpMV<Expr,T,column_t> &xmv);
Expand Down Expand Up @@ -766,8 +765,58 @@ void swap(vector<T> &x, vector<T> &y) {
x.swap(y); x.swap(y);
} }


namespace binop {
enum kind {
Add,
Subtract,
Multiply,
Divide,
Remainder,
Greater,
Less,
GreaterEqual,
LessEqual,
Equal,
NotEqual,
BitwiseAnd,
BitwiseOr,
BitwiseXor,
LogicalAnd,
LogicalOr,
RightShift,
LeftShift
};

template <kind> struct traits {};

#define BOP_TRAITS(kind, op, nm) \
template <> struct traits<kind> { \
static std::string oper() { return op; } \
static std::string name() { return nm; } \
};

BOP_TRAITS(Add, "+", "Add_")
BOP_TRAITS(Subtract, "-", "Sub_")
BOP_TRAITS(Multiply, "*", "Mul_")
BOP_TRAITS(Divide, "/", "Div_")
BOP_TRAITS(Remainder, "%", "Mod_")
BOP_TRAITS(Greater, ">", "Gtr_")
BOP_TRAITS(Less, "<", "Lss_")
BOP_TRAITS(GreaterEqual, ">=", "Geq_")
BOP_TRAITS(LessEqual, "<=", "Leq_")
BOP_TRAITS(Equal, "==", "Equ_")
BOP_TRAITS(NotEqual, "!=", "Neq_")
BOP_TRAITS(BitwiseAnd, "&", "BAnd_")
BOP_TRAITS(BitwiseOr, "|", "BOr_")
BOP_TRAITS(BitwiseXor, "^", "BXor_")
BOP_TRAITS(LogicalAnd, "&&", "LAnd_")
BOP_TRAITS(LogicalOr, "||", "LOr_")
BOP_TRAITS(RightShift, ">>", "Rsh_")
BOP_TRAITS(LeftShift, "<<", "Lsh_")
}

/// Expression template. /// Expression template.
template <class LHS, char OP, class RHS> template <class LHS, binop::kind OP, class RHS>
struct BinaryExpression : public expression { struct BinaryExpression : public expression {
BinaryExpression(const LHS &lhs, const RHS &rhs) : lhs(lhs), rhs(rhs) {} BinaryExpression(const LHS &lhs, const RHS &rhs) : lhs(lhs), rhs(rhs) {}


Expand All @@ -777,19 +826,7 @@ struct BinaryExpression : public expression {
} }


std::string kernel_name() const { std::string kernel_name() const {
// Polish notation. return binop::traits<OP>::name() + lhs.kernel_name() + rhs.kernel_name();
switch (OP) {
case '+':
return "p" + lhs.kernel_name() + rhs.kernel_name();
case '-':
return "m" + lhs.kernel_name() + rhs.kernel_name();
case '*':
return "t" + lhs.kernel_name() + rhs.kernel_name();
case '/':
return "d" + lhs.kernel_name() + rhs.kernel_name();
default:
throw "unknown operation";
}
} }


void kernel_prm(std::ostream &os, std::string name = "") const { void kernel_prm(std::ostream &os, std::string name = "") const {
Expand All @@ -800,7 +837,7 @@ struct BinaryExpression : public expression {
void kernel_expr(std::ostream &os, std::string name = "") const { void kernel_expr(std::ostream &os, std::string name = "") const {
os << "("; os << "(";
lhs.kernel_expr(os, name + "l"); lhs.kernel_expr(os, name + "l");
os << " " << OP << " "; os << " " << binop::traits<OP>::oper() << " ";
rhs.kernel_expr(os, name + "r"); rhs.kernel_expr(os, name + "r");
os << ")"; os << ")";
} }
Expand Down Expand Up @@ -831,46 +868,6 @@ struct valid_expr<T,
typename std::enable_if<std::is_arithmetic<T>::value>::type typename std::enable_if<std::is_arithmetic<T>::value>::type
> : std::true_type{}; > : std::true_type{};


/// Sum of two expressions.
template <class LHS, class RHS>
typename std::enable_if<
valid_expr<LHS>::value && valid_expr<RHS>::value,
BinaryExpression<LHS, '+', RHS>
>::type
operator+(const LHS &lhs, const RHS &rhs) {
return BinaryExpression<LHS,'+',RHS>(lhs, rhs);
}

/// Difference of two expressions.
template <class LHS, class RHS>
typename std::enable_if<
valid_expr<LHS>::value && valid_expr<RHS>::value,
BinaryExpression<LHS, '-', RHS>
>::type
operator-(const LHS &lhs, const RHS &rhs) {
return BinaryExpression<LHS,'-',RHS>(lhs, rhs);
}

/// Product of two expressions.
template <class LHS, class RHS>
typename std::enable_if<
valid_expr<LHS>::value && valid_expr<RHS>::value,
BinaryExpression<LHS, '*', RHS>
>::type
operator*(const LHS &lhs, const RHS &rhs) {
return BinaryExpression<LHS,'*',RHS>(lhs, rhs);
}

/// Division of two expressions.
template <class LHS, class RHS>
typename std::enable_if<
valid_expr<LHS>::value && valid_expr<RHS>::value,
BinaryExpression<LHS, '/', RHS>
>::type
operator/(const LHS &lhs, const RHS &rhs) {
return BinaryExpression<LHS,'/',RHS>(lhs, rhs);
}

//--------------------------------------------------------------------------- //---------------------------------------------------------------------------
// Multivector // Multivector
//--------------------------------------------------------------------------- //---------------------------------------------------------------------------
Expand Down Expand Up @@ -1154,25 +1151,24 @@ class multivector {
return *this; return *this;
} }


template <class Expr> #define COMPOUND_ASSIGNMENT(cop, op) \
const multivector& operator+=(const Expr &expr) { template <class Expr> \
return *this = *this + expr; const multivector& operator cop(const Expr &expr) { \
return *this = *this op expr; \
} }


template <class Expr> COMPOUND_ASSIGNMENT(+=, +);
const multivector& operator*=(const Expr &expr) { COMPOUND_ASSIGNMENT(-=, -);
return *this = *this * expr; COMPOUND_ASSIGNMENT(*=, *);
} COMPOUND_ASSIGNMENT(/=, /);
COMPOUND_ASSIGNMENT(%=, %);
COMPOUND_ASSIGNMENT(&=, &);
COMPOUND_ASSIGNMENT(|=, |);
COMPOUND_ASSIGNMENT(^=, ^);
COMPOUND_ASSIGNMENT(<<=, <<);
COMPOUND_ASSIGNMENT(>>=, >>);


template <class Expr> #undef COMPOUND_ASSIGNMENT
const multivector& operator/=(const Expr &expr) {
return *this = *this / expr;
}

template <class Expr>
const multivector& operator-=(const Expr &expr) {
return *this = *this - expr;
}


template <typename column_t> template <typename column_t>
const multivector& operator=(const MultiSpMV<T,column_t,N> &spmv); const multivector& operator=(const MultiSpMV<T,column_t,N> &spmv);
Expand Down Expand Up @@ -1270,109 +1266,56 @@ struct multiex_dim {
//--------------------------------------------------------------------------- //---------------------------------------------------------------------------
// Arithmetic expressions // Arithmetic expressions
//--------------------------------------------------------------------------- //---------------------------------------------------------------------------
template <class LHS, class RHS> #define DEFINE_BINARY_OP(kind, oper) \
typename std::enable_if<compatible_multiex<LHS, RHS>::value, template <class LHS, class RHS> \
MultiExpression< typename std::enable_if< \
BinaryExpression< valid_expr<LHS>::value && valid_expr<RHS>::value, \
typename multiex_traits<LHS>::subtype, BinaryExpression<LHS, kind, RHS> \
'+', >::type \
typename multiex_traits<RHS>::subtype operator oper(const LHS &lhs, const RHS &rhs) { \
>, return BinaryExpression<LHS, kind, RHS>(lhs, rhs); \
multiex_dim<LHS, RHS>::dim } \
>>::type template <class LHS, class RHS> \
operator+(const LHS &lhs, const RHS &rhs) { typename std::enable_if<compatible_multiex<LHS, RHS>::value, \
typedef BinaryExpression< MultiExpression< \
typename multiex_traits<LHS>::subtype, BinaryExpression< \
'+', typename multiex_traits<LHS>::subtype, \
typename multiex_traits<RHS>::subtype kind, \
> subtype; typename multiex_traits<RHS>::subtype \

>, \
std::array<std::unique_ptr<subtype>, multiex_dim<LHS, RHS>::dim> ex; multiex_dim<LHS, RHS>::dim \

>>::type \
for(uint i = 0; i < multiex_dim<LHS, RHS>::dim; i++) operator oper(const LHS &lhs, const RHS &rhs) { \
ex[i].reset(new subtype( typedef BinaryExpression< \
extract_component(lhs, i), extract_component(rhs, i))); typename multiex_traits<LHS>::subtype, \

kind, \
return MultiExpression<subtype, multiex_dim<LHS, RHS>::dim>(ex); typename multiex_traits<RHS>::subtype \
} > subtype; \

std::array<std::unique_ptr<subtype>, multiex_dim<LHS, RHS>::dim> ex; \
template <class LHS, class RHS> for(uint i = 0; i < multiex_dim<LHS, RHS>::dim; i++) \
typename std::enable_if<compatible_multiex<LHS, RHS>::value, ex[i].reset(new subtype( \
MultiExpression< extract_component(lhs, i), extract_component(rhs, i))); \
BinaryExpression< return MultiExpression<subtype, multiex_dim<LHS, RHS>::dim>(ex); \
typename multiex_traits<LHS>::subtype,
'-',
typename multiex_traits<RHS>::subtype
>,
multiex_dim<LHS, RHS>::dim
>>::type
operator-(const LHS &lhs, const RHS &rhs) {
typedef BinaryExpression<
typename multiex_traits<LHS>::subtype,
'-',
typename multiex_traits<RHS>::subtype
> subtype;

std::array<std::unique_ptr<subtype>, multiex_dim<LHS, RHS>::dim> ex;

for(uint i = 0; i < multiex_dim<LHS, RHS>::dim; i++)
ex[i].reset(new subtype(
extract_component(lhs, i), extract_component(rhs, i)));

return MultiExpression<subtype, multiex_dim<LHS, RHS>::dim>(ex);
}

template <class LHS, class RHS>
typename std::enable_if<compatible_multiex<LHS, RHS>::value,
MultiExpression<
BinaryExpression<
typename multiex_traits<LHS>::subtype,
'*',
typename multiex_traits<RHS>::subtype
>,
multiex_dim<LHS, RHS>::dim
>>::type
operator*(const LHS &lhs, const RHS &rhs) {
typedef BinaryExpression<
typename multiex_traits<LHS>::subtype,
'*',
typename multiex_traits<RHS>::subtype
> subtype;

std::array<std::unique_ptr<subtype>, multiex_dim<LHS, RHS>::dim> ex;

for(uint i = 0; i < multiex_dim<LHS, RHS>::dim; i++)
ex[i].reset(new subtype(
extract_component(lhs, i), extract_component(rhs, i)));

return MultiExpression<subtype, multiex_dim<LHS, RHS>::dim>(ex);
} }


template <class LHS, class RHS> DEFINE_BINARY_OP(binop::Add, + )
typename std::enable_if<compatible_multiex<LHS, RHS>::value, DEFINE_BINARY_OP(binop::Subtract, - )
MultiExpression< DEFINE_BINARY_OP(binop::Multiply, * )
BinaryExpression< DEFINE_BINARY_OP(binop::Divide, / )
typename multiex_traits<LHS>::subtype, DEFINE_BINARY_OP(binop::Remainder, % )
'/', DEFINE_BINARY_OP(binop::Greater, > )
typename multiex_traits<RHS>::subtype DEFINE_BINARY_OP(binop::Less, < )
>, DEFINE_BINARY_OP(binop::GreaterEqual, >=)
multiex_dim<LHS, RHS>::dim DEFINE_BINARY_OP(binop::LessEqual, <=)
>>::type DEFINE_BINARY_OP(binop::Equal, ==)
operator/(const LHS &lhs, const RHS &rhs) { DEFINE_BINARY_OP(binop::NotEqual, !=)
typedef BinaryExpression< DEFINE_BINARY_OP(binop::BitwiseAnd, & )
typename multiex_traits<LHS>::subtype, DEFINE_BINARY_OP(binop::BitwiseOr, | )
'/', DEFINE_BINARY_OP(binop::BitwiseXor, ^ )
typename multiex_traits<RHS>::subtype DEFINE_BINARY_OP(binop::LogicalAnd, &&)
> subtype; DEFINE_BINARY_OP(binop::LogicalOr, ||)

DEFINE_BINARY_OP(binop::RightShift, >>)
std::array<std::unique_ptr<subtype>, multiex_dim<LHS, RHS>::dim> ex; DEFINE_BINARY_OP(binop::LeftShift, <<)

for(uint i = 0; i < multiex_dim<LHS, RHS>::dim; i++)
ex[i].reset(new subtype(
extract_component(lhs, i), extract_component(rhs, i)));

return MultiExpression<subtype, multiex_dim<LHS, RHS>::dim>(ex);
}


//--------------------------------------------------------------------------- //---------------------------------------------------------------------------
// Builtin functions // Builtin functions
Expand Down

0 comments on commit 2eab656

Please sign in to comment.