Permalink
Browse files

Organized definition of binary operations

Defined all of OpenCL binary operators.
  • Loading branch information...
ddemidov committed Jun 10, 2012
1 parent 1651194 commit 2eab65612a9afc96f674f9647d02985492bf3458
Showing with 133 additions and 188 deletions.
  1. +2 −0 examples/utests.cpp
  2. +131 −188 vexcl/vector.hpp
View
@@ -482,6 +482,8 @@ int main() {
Reductor<size_t,SUM> sum(ctx.queue());
rc = rc && sum(chk_if_greater(x, y)) == 0;
rc = rc && sum(chk_if_greater(y, x)) == N;
+ rc = rc && sum(x > y) == 0;
+ rc = rc && sum(x < y) == N;
return rc;
});
View
@@ -570,25 +570,24 @@ class vector : public expression {
return *this;
}
- template <class Expr>
- const vector& operator+=(const Expr &expr) {
- return *this = *this + expr;
- }
-
- template <class Expr>
- const vector& operator*=(const Expr &expr) {
- return *this = *this * expr;
+#define COMPOUND_ASSIGNMENT(cop, op) \
+ template <class Expr> \
+ const vector& operator cop(const Expr &expr) { \
+ return *this = *this op expr; \
}
- template <class Expr>
- const vector& operator/=(const Expr &expr) {
- return *this = *this / expr;
- }
+ COMPOUND_ASSIGNMENT(+=, +);
+ COMPOUND_ASSIGNMENT(-=, -);
+ COMPOUND_ASSIGNMENT(*=, *);
+ COMPOUND_ASSIGNMENT(/=, /);
+ COMPOUND_ASSIGNMENT(%=, %);
+ COMPOUND_ASSIGNMENT(&=, &);
+ COMPOUND_ASSIGNMENT(|=, |);
+ COMPOUND_ASSIGNMENT(^=, ^);
+ COMPOUND_ASSIGNMENT(<<=, <<);
+ COMPOUND_ASSIGNMENT(>>=, >>);
- template <class Expr>
- const vector& operator-=(const Expr &expr) {
- return *this = *this - expr;
- }
+#undef COMPOUND_ASSIGNMENT
template <class Expr, typename column_t>
const vector& operator=(const ExSpMV<Expr,T,column_t> &xmv);
@@ -766,8 +765,58 @@ void swap(vector<T> &x, vector<T> &y) {
x.swap(y);
}
+namespace binop {
+ enum kind {
+ Add,
+ Subtract,
+ Multiply,
+ Divide,
+ Remainder,
+ Greater,
+ Less,
+ GreaterEqual,
+ LessEqual,
+ Equal,
+ NotEqual,
+ BitwiseAnd,
+ BitwiseOr,
+ BitwiseXor,
+ LogicalAnd,
+ LogicalOr,
+ RightShift,
+ LeftShift
+ };
+
+ template <kind> struct traits {};
+
+#define BOP_TRAITS(kind, op, nm) \
+ template <> struct traits<kind> { \
+ static std::string oper() { return op; } \
+ static std::string name() { return nm; } \
+ };
+
+ BOP_TRAITS(Add, "+", "Add_")
+ BOP_TRAITS(Subtract, "-", "Sub_")
+ BOP_TRAITS(Multiply, "*", "Mul_")
+ BOP_TRAITS(Divide, "/", "Div_")
+ BOP_TRAITS(Remainder, "%", "Mod_")
+ BOP_TRAITS(Greater, ">", "Gtr_")
+ BOP_TRAITS(Less, "<", "Lss_")
+ BOP_TRAITS(GreaterEqual, ">=", "Geq_")
+ BOP_TRAITS(LessEqual, "<=", "Leq_")
+ BOP_TRAITS(Equal, "==", "Equ_")
+ BOP_TRAITS(NotEqual, "!=", "Neq_")
+ BOP_TRAITS(BitwiseAnd, "&", "BAnd_")
+ BOP_TRAITS(BitwiseOr, "|", "BOr_")
+ BOP_TRAITS(BitwiseXor, "^", "BXor_")
+ BOP_TRAITS(LogicalAnd, "&&", "LAnd_")
+ BOP_TRAITS(LogicalOr, "||", "LOr_")
+ BOP_TRAITS(RightShift, ">>", "Rsh_")
+ BOP_TRAITS(LeftShift, "<<", "Lsh_")
+}
+
/// Expression template.
-template <class LHS, char OP, class RHS>
+template <class LHS, binop::kind OP, class RHS>
struct BinaryExpression : public expression {
BinaryExpression(const LHS &lhs, const RHS &rhs) : lhs(lhs), rhs(rhs) {}
@@ -777,19 +826,7 @@ struct BinaryExpression : public expression {
}
std::string kernel_name() const {
- // Polish notation.
- switch (OP) {
- case '+':
- return "p" + lhs.kernel_name() + rhs.kernel_name();
- case '-':
- return "m" + lhs.kernel_name() + rhs.kernel_name();
- case '*':
- return "t" + lhs.kernel_name() + rhs.kernel_name();
- case '/':
- return "d" + lhs.kernel_name() + rhs.kernel_name();
- default:
- throw "unknown operation";
- }
+ return binop::traits<OP>::name() + lhs.kernel_name() + rhs.kernel_name();
}
void kernel_prm(std::ostream &os, std::string name = "") const {
@@ -800,7 +837,7 @@ struct BinaryExpression : public expression {
void kernel_expr(std::ostream &os, std::string name = "") const {
os << "(";
lhs.kernel_expr(os, name + "l");
- os << " " << OP << " ";
+ os << " " << binop::traits<OP>::oper() << " ";
rhs.kernel_expr(os, name + "r");
os << ")";
}
@@ -831,46 +868,6 @@ struct valid_expr<T,
typename std::enable_if<std::is_arithmetic<T>::value>::type
> : std::true_type{};
-/// Sum of two expressions.
-template <class LHS, class RHS>
-typename std::enable_if<
- valid_expr<LHS>::value && valid_expr<RHS>::value,
- BinaryExpression<LHS, '+', RHS>
- >::type
- operator+(const LHS &lhs, const RHS &rhs) {
- return BinaryExpression<LHS,'+',RHS>(lhs, rhs);
- }
-
-/// Difference of two expressions.
-template <class LHS, class RHS>
-typename std::enable_if<
- valid_expr<LHS>::value && valid_expr<RHS>::value,
- BinaryExpression<LHS, '-', RHS>
- >::type
- operator-(const LHS &lhs, const RHS &rhs) {
- return BinaryExpression<LHS,'-',RHS>(lhs, rhs);
- }
-
-/// Product of two expressions.
-template <class LHS, class RHS>
-typename std::enable_if<
- valid_expr<LHS>::value && valid_expr<RHS>::value,
- BinaryExpression<LHS, '*', RHS>
- >::type
- operator*(const LHS &lhs, const RHS &rhs) {
- return BinaryExpression<LHS,'*',RHS>(lhs, rhs);
- }
-
-/// Division of two expressions.
-template <class LHS, class RHS>
-typename std::enable_if<
- valid_expr<LHS>::value && valid_expr<RHS>::value,
- BinaryExpression<LHS, '/', RHS>
- >::type
- operator/(const LHS &lhs, const RHS &rhs) {
- return BinaryExpression<LHS,'/',RHS>(lhs, rhs);
- }
-
//---------------------------------------------------------------------------
// Multivector
//---------------------------------------------------------------------------
@@ -1154,25 +1151,24 @@ class multivector {
return *this;
}
- template <class Expr>
- const multivector& operator+=(const Expr &expr) {
- return *this = *this + expr;
+#define COMPOUND_ASSIGNMENT(cop, op) \
+ template <class Expr> \
+ const multivector& operator cop(const Expr &expr) { \
+ return *this = *this op expr; \
}
- template <class Expr>
- const multivector& operator*=(const Expr &expr) {
- return *this = *this * expr;
- }
+ COMPOUND_ASSIGNMENT(+=, +);
+ COMPOUND_ASSIGNMENT(-=, -);
+ COMPOUND_ASSIGNMENT(*=, *);
+ COMPOUND_ASSIGNMENT(/=, /);
+ COMPOUND_ASSIGNMENT(%=, %);
+ COMPOUND_ASSIGNMENT(&=, &);
+ COMPOUND_ASSIGNMENT(|=, |);
+ COMPOUND_ASSIGNMENT(^=, ^);
+ COMPOUND_ASSIGNMENT(<<=, <<);
+ COMPOUND_ASSIGNMENT(>>=, >>);
- template <class Expr>
- const multivector& operator/=(const Expr &expr) {
- return *this = *this / expr;
- }
-
- template <class Expr>
- const multivector& operator-=(const Expr &expr) {
- return *this = *this - expr;
- }
+#undef COMPOUND_ASSIGNMENT
template <typename column_t>
const multivector& operator=(const MultiSpMV<T,column_t,N> &spmv);
@@ -1270,109 +1266,56 @@ struct multiex_dim {
//---------------------------------------------------------------------------
// Arithmetic expressions
//---------------------------------------------------------------------------
-template <class LHS, class RHS>
-typename std::enable_if<compatible_multiex<LHS, RHS>::value,
- MultiExpression<
- BinaryExpression<
- typename multiex_traits<LHS>::subtype,
- '+',
- typename multiex_traits<RHS>::subtype
- >,
- multiex_dim<LHS, RHS>::dim
- >>::type
-operator+(const LHS &lhs, const RHS &rhs) {
- typedef BinaryExpression<
- typename multiex_traits<LHS>::subtype,
- '+',
- typename multiex_traits<RHS>::subtype
- > subtype;
-
- std::array<std::unique_ptr<subtype>, multiex_dim<LHS, RHS>::dim> ex;
-
- for(uint i = 0; i < multiex_dim<LHS, RHS>::dim; i++)
- ex[i].reset(new subtype(
- extract_component(lhs, i), extract_component(rhs, i)));
-
- return MultiExpression<subtype, multiex_dim<LHS, RHS>::dim>(ex);
-}
-
-template <class LHS, class RHS>
-typename std::enable_if<compatible_multiex<LHS, RHS>::value,
- MultiExpression<
- BinaryExpression<
- typename multiex_traits<LHS>::subtype,
- '-',
- typename multiex_traits<RHS>::subtype
- >,
- multiex_dim<LHS, RHS>::dim
- >>::type
-operator-(const LHS &lhs, const RHS &rhs) {
- typedef BinaryExpression<
- typename multiex_traits<LHS>::subtype,
- '-',
- typename multiex_traits<RHS>::subtype
- > subtype;
-
- std::array<std::unique_ptr<subtype>, multiex_dim<LHS, RHS>::dim> ex;
-
- for(uint i = 0; i < multiex_dim<LHS, RHS>::dim; i++)
- ex[i].reset(new subtype(
- extract_component(lhs, i), extract_component(rhs, i)));
-
- return MultiExpression<subtype, multiex_dim<LHS, RHS>::dim>(ex);
-}
-
-template <class LHS, class RHS>
-typename std::enable_if<compatible_multiex<LHS, RHS>::value,
- MultiExpression<
- BinaryExpression<
- typename multiex_traits<LHS>::subtype,
- '*',
- typename multiex_traits<RHS>::subtype
- >,
- multiex_dim<LHS, RHS>::dim
- >>::type
-operator*(const LHS &lhs, const RHS &rhs) {
- typedef BinaryExpression<
- typename multiex_traits<LHS>::subtype,
- '*',
- typename multiex_traits<RHS>::subtype
- > subtype;
-
- std::array<std::unique_ptr<subtype>, multiex_dim<LHS, RHS>::dim> ex;
-
- for(uint i = 0; i < multiex_dim<LHS, RHS>::dim; i++)
- ex[i].reset(new subtype(
- extract_component(lhs, i), extract_component(rhs, i)));
-
- return MultiExpression<subtype, multiex_dim<LHS, RHS>::dim>(ex);
+#define DEFINE_BINARY_OP(kind, oper) \
+template <class LHS, class RHS> \
+typename std::enable_if< \
+ valid_expr<LHS>::value && valid_expr<RHS>::value, \
+ BinaryExpression<LHS, kind, RHS> \
+ >::type \
+ operator oper(const LHS &lhs, const RHS &rhs) { \
+ return BinaryExpression<LHS, kind, RHS>(lhs, rhs); \
+ } \
+template <class LHS, class RHS> \
+typename std::enable_if<compatible_multiex<LHS, RHS>::value, \
+ MultiExpression< \
+ BinaryExpression< \
+ typename multiex_traits<LHS>::subtype, \
+ kind, \
+ typename multiex_traits<RHS>::subtype \
+ >, \
+ multiex_dim<LHS, RHS>::dim \
+ >>::type \
+operator oper(const LHS &lhs, const RHS &rhs) { \
+ typedef BinaryExpression< \
+ typename multiex_traits<LHS>::subtype, \
+ kind, \
+ typename multiex_traits<RHS>::subtype \
+ > subtype; \
+ std::array<std::unique_ptr<subtype>, multiex_dim<LHS, RHS>::dim> ex; \
+ for(uint i = 0; i < multiex_dim<LHS, RHS>::dim; i++) \
+ ex[i].reset(new subtype( \
+ extract_component(lhs, i), extract_component(rhs, i))); \
+ return MultiExpression<subtype, multiex_dim<LHS, RHS>::dim>(ex); \
}
-template <class LHS, class RHS>
-typename std::enable_if<compatible_multiex<LHS, RHS>::value,
- MultiExpression<
- BinaryExpression<
- typename multiex_traits<LHS>::subtype,
- '/',
- typename multiex_traits<RHS>::subtype
- >,
- multiex_dim<LHS, RHS>::dim
- >>::type
-operator/(const LHS &lhs, const RHS &rhs) {
- typedef BinaryExpression<
- typename multiex_traits<LHS>::subtype,
- '/',
- typename multiex_traits<RHS>::subtype
- > subtype;
-
- std::array<std::unique_ptr<subtype>, multiex_dim<LHS, RHS>::dim> ex;
-
- for(uint i = 0; i < multiex_dim<LHS, RHS>::dim; i++)
- ex[i].reset(new subtype(
- extract_component(lhs, i), extract_component(rhs, i)));
-
- return MultiExpression<subtype, multiex_dim<LHS, RHS>::dim>(ex);
-}
+DEFINE_BINARY_OP(binop::Add, + )
+DEFINE_BINARY_OP(binop::Subtract, - )
+DEFINE_BINARY_OP(binop::Multiply, * )
+DEFINE_BINARY_OP(binop::Divide, / )
+DEFINE_BINARY_OP(binop::Remainder, % )
+DEFINE_BINARY_OP(binop::Greater, > )
+DEFINE_BINARY_OP(binop::Less, < )
+DEFINE_BINARY_OP(binop::GreaterEqual, >=)
+DEFINE_BINARY_OP(binop::LessEqual, <=)
+DEFINE_BINARY_OP(binop::Equal, ==)
+DEFINE_BINARY_OP(binop::NotEqual, !=)
+DEFINE_BINARY_OP(binop::BitwiseAnd, & )
+DEFINE_BINARY_OP(binop::BitwiseOr, | )
+DEFINE_BINARY_OP(binop::BitwiseXor, ^ )
+DEFINE_BINARY_OP(binop::LogicalAnd, &&)
+DEFINE_BINARY_OP(binop::LogicalOr, ||)
+DEFINE_BINARY_OP(binop::RightShift, >>)
+DEFINE_BINARY_OP(binop::LeftShift, <<)
//---------------------------------------------------------------------------
// Builtin functions

0 comments on commit 2eab656

Please sign in to comment.