Skip to content

Commit

Permalink
Generalized stencils with user functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ddemidov committed Jun 26, 2012
1 parent d9b3429 commit 256c1bb
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 31 deletions.
54 changes: 44 additions & 10 deletions examples/utests.cpp
Expand Up @@ -25,13 +25,13 @@ bool run_test(const std::string &name, std::function<bool()> test) {
return rc;
}

extern const char chk_if_gr_body[] = "return prm1 > prm2 ? 1 : 0;";
extern const char greater_body[] = "return prm1 > prm2 ? 1 : 0;";
UserFunction<greater_body, size_t(double, double)> greater;

int main(int argc, char *argv[]) {
uint seed = argc > 1 ? atoi(argv[1]) : static_cast<uint>(time(0));
std::cout << "seed: " << seed << std::endl;
srand(seed);
extern const char pow3_body[] = "return pow(prm1, 3);";
UserFunction<pow3_body, double(double)> pow3;

int main(int argc, char *argv[]) {
try {
vex::Context ctx(Filter::DoublePrecision && Filter::Env);
std::cout << ctx << std::endl;
Expand All @@ -41,6 +41,10 @@ int main(int argc, char *argv[]) {
return 1;
}

uint seed = argc > 1 ? atoi(argv[1]) : static_cast<uint>(time(0));
std::cout << "seed: " << seed << std::endl << std::endl;
srand(seed);

run_test("Empty vector construction", [&]() -> bool {
bool rc = true;
vex::vector<double> x;
Expand Down Expand Up @@ -780,10 +784,9 @@ int main(int argc, char *argv[]) {
vex::vector<double> y(ctx.queue(), N);
x = 1;
y = 2;
UserFunction<chk_if_gr_body, size_t(double, double)> chk_if_greater;
Reductor<size_t,SUM> sum(ctx.queue());
rc = rc && sum(chk_if_greater(x, y)) == 0;
rc = rc && sum(chk_if_greater(y, x)) == N;
rc = rc && sum(greater(x, y)) == 0;
rc = rc && sum(greater(y, x)) == N;
rc = rc && sum(x > y) == 0;
rc = rc && sum(x < y) == N;
return rc;
Expand Down Expand Up @@ -816,8 +819,7 @@ int main(int argc, char *argv[]) {
multivector<double, m> y(ctx.queue(), n);
x = 1;
y = 2;
UserFunction<chk_if_gr_body, size_t(double, double)> chk_if_greater;
x = chk_if_greater(x, y);
x = greater(x, y);
for(size_t k = 0; k < 10; k++) {
size_t i = rand() % n;
std::array<double,m> val = x[i];
Expand Down Expand Up @@ -1065,6 +1067,38 @@ int main(int argc, char *argv[]) {
return rc;
});

#ifdef VEXCL_VARIADIC_TEMPLATES
// Evaluates Y = X + pow3(X * S) on the device and compares it with a
// host-side reference. The reference treats the {1, 0, 1} stencil as the sum
// of the left and right neighbours, with neighbour indices clamped to
// [0, n - 1] at the ends of the vector.
run_test("Generalized stencil with user function convolution", [&]() -> bool {
    const int n = 1 << 20;

    double weights[] = {1, 0, 1};
    gstencil<double> S(ctx.queue(), 1, 3, 1, weights, weights + 3);

    std::vector<double> host_in(n);
    std::vector<double> host_out(n);
    // Filled front to back so the sequence of rand() calls is unchanged.
    for (std::vector<double>::iterator it = host_in.begin(); it != host_in.end(); ++it)
        *it = static_cast<double>(rand()) / RAND_MAX;

    vex::vector<double> X(ctx.queue(), host_in);
    vex::vector<double> Y(ctx.queue(), n);

    Y = X + pow3(X * S);

    copy(Y, host_out);

    // Largest absolute deviation from the host reference.
    double max_err = 0;
    for (int i = 0; i < n; ++i) {
        const int lo = (i > 0) ? i - 1 : 0;
        const int hi = (i + 1 < n) ? i + 1 : n - 1;

        const double expected = host_in[i] + pow(host_in[lo] + host_in[hi], 3);
        const double err = fabs(expected - host_out[i]);
        if (err > max_err) max_err = err;
    }
    return max_err < 1e-8;
});
#endif

} catch (const cl::Error &err) {
std::cerr << "OpenCL error: " << err << std::endl;
return 1;
Expand Down
1 change: 1 addition & 0 deletions vexcl/stencil.hpp
Expand Up @@ -887,6 +887,7 @@ void gstencil<T>::convolve(const vex::vector<T> &x, vex::vector<T> &y,

source << standard_kernel_header <<
"typedef " << type_name<T>() << " real;\n"
<< UserFunctionDeclaration<func>::get() <<
"real read_x(\n"
" long g_id,\n"
" " << type_name<size_t>() << " n,\n"
Expand Down
84 changes: 63 additions & 21 deletions vexcl/vector.hpp
Expand Up @@ -74,6 +74,8 @@ template <class T, uint N> struct MultiConv;
template <class Expr, class T, uint N> struct MultiExConv;
template <class func, class T, uint N> struct MultiGConv;
template <class Expr, class func, class T, uint N> struct MultiExGConv;
template <class T> struct GStencilProd;
template <class T, uint N> struct MultiGStencilProd;

/// Base class for a member of an expression.
/**
Expand Down Expand Up @@ -1532,6 +1534,8 @@ struct All<Head, Tail...>

#define DEFINE_BUILTIN_FUNCTION(name) \
struct name##_name { \
static const bool is_builtin = true; \
static const bool is_userfun = false; \
static const char* value() { \
return #name; \
} \
Expand Down Expand Up @@ -1681,6 +1685,8 @@ struct BuiltinFunction : public expression {

#define DEFINE_BUILTIN_FUNCTION(name) \
struct name##_name { \
static const bool is_builtin = true; \
static const bool is_userfun = false; \
static const char* value() { \
return #name; \
} \
Expand Down Expand Up @@ -1765,9 +1771,23 @@ DEFINE_BUILTIN_FUNCTION(trunc)
//---------------------------------------------------------------------------
// User-defined functions.
//---------------------------------------------------------------------------
/// Fallback declaration generator: for any type T without a dedicated
/// specialization there is nothing to declare, so get() yields an empty string.
template <class T>
struct UserFunctionDeclaration {
    static std::string get() {
        return std::string();
    }
};

#ifdef VEXCL_VARIADIC_TEMPLATES
/// \cond INTERNAL

// Primary template: intentionally empty. Only the partial specialization for
// a function signature RetType(ArgType...) (declared below) carries members.
template <const char *body, class T>
struct UserFunction {};

// Forward declaration of the function-signature specialization; its
// definition appears further down in this file.
template<const char *body, class RetType, class... ArgType>
struct UserFunction<body, RetType(ArgType...)>;

// Forward declaration of the declaration generator for user functions, so
// that code placed before its definition can already name it.
template <const char *body, class RetType, class... ArgType>
struct UserFunctionDeclaration<UserFunction<body, RetType(ArgType...)>>;

/// Custom user function expression template
template<class RetType, class... ArgType>
struct UserFunctionFamily {
Expand All @@ -1779,11 +1799,9 @@ struct UserFunctionFamily {
void preamble(std::ostream &os, std::string name) const {
build_preamble<0>(os, name);

os << type_name<RetType>() << " " << name << "_fun(";

build_arg_list<ArgType...>(os, 0);

os << "\n\t)\n{\n" << body << "\n}\n";
os << UserFunctionDeclaration<
UserFunction<body, RetType(ArgType...)>
>::get(name);
}

std::string kernel_name() const {
Expand Down Expand Up @@ -1880,19 +1898,6 @@ struct UserFunctionFamily {
build_preamble<num + 1>(os, name);
}

//------------------------------------------------------------
template <class T>
void build_arg_list(std::ostream &os, uint num) const {
os << "\n\t" << type_name<T>() << " prm" << num + 1;
}

template <class T, class... Args>
typename std::enable_if<sizeof...(Args), void>::type
build_arg_list(std::ostream &os, uint num) const {
os << "\n\t" << type_name<T>() << " prm" << num + 1 << ",";
build_arg_list<Args...>(os, num + 1);
}

//------------------------------------------------------------
template <int num>
typename std::enable_if<num == sizeof...(Expr), size_t>::type
Expand All @@ -1912,9 +1917,6 @@ struct UserFunctionFamily {
};
};

template <const char *body, class T>
struct UserFunction {};

/// \endcond

/// Custom user function
Expand All @@ -1941,6 +1943,9 @@ struct UserFunction {};
*/
template<const char *body, class RetType, class... ArgType>
struct UserFunction<body, RetType(ArgType...)> {
// Trait flags mirroring the ones defined by the builtin function name
// structs: a UserFunction is a user-defined function, not a builtin.
// (A stray trailing line-continuation backslash was removed from the first
// declaration; it spliced both declarations into one logical line.)
static const bool is_builtin = false;
static const bool is_userfun = true;

/// Apply user function to the list of expressions.
/**
* Number of expressions in the list has to coincide with number of
Expand Down Expand Up @@ -1978,6 +1983,43 @@ struct UserFunction<body, RetType(ArgType...)> {
typename UserFunctionFamily<RetType, ArgType...>::template Function<body, typename multiex_traits<Expr>::subtype...>,
multiex_dim<Expr...>::dim >(ex);
}

template <class T>
GConv<UserFunction,T> operator()(const GStencilProd<T> &s) const {
return GConv<UserFunction, T>(s);
}

template <class T, uint N>
MultiGConv<UserFunction,T,N> operator()(const MultiGStencilProd<T,N> &s) const {
return MultiGConv<UserFunction, T, N>(s);
}

static const char* value() {
return "user_fun";
}
};

/// Generates the source text of a user-defined function.
/**
 * The produced text has the form
 * \code
 * <RetType> <name>_fun(
 *     <ArgType1> prm1,
 *     <ArgType2> prm2
 *     )
 * {
 * <body>
 * }
 * \endcode
 * where the type spellings are obtained from type_name<T>().
 */
template <const char *body, class RetType, class... ArgType>
struct UserFunctionDeclaration<UserFunction<body, RetType(ArgType...)>> {
    /// Returns the definition of the function named "<name>_fun".
    /**
     * \param name prefix of the generated function name; defaults to "user".
     */
    static std::string get(const std::string &name = "user") {
        std::ostringstream decl;
        decl << type_name<RetType>() << " " << name << "_fun(";
        build_arg_list<ArgType...>(decl, 0);
        decl << "\n\t)\n{\n" << body << "\n}\n";
        return decl.str();
    }

    /// Writes the last parameter of the list (no trailing comma).
    template <class T>
    static void build_arg_list(std::ostream &os, uint num) {
        os << "\n\t" << type_name<T>() << " prm" << num + 1;
    }

    /// Writes one parameter followed by a comma, then recurses on the rest.
    /**
     * The condition is written as an explicit boolean comparison: passing
     * sizeof...(Args) directly as the bool template argument is a narrowing
     * conversion for values >= 2, which some compilers reject and which would
     * discard this overload for functions with three or more parameters.
     */
    template <class T, class... Args>
    static typename std::enable_if<(sizeof...(Args) > 0), void>::type
    build_arg_list(std::ostream &os, uint num) {
        os << "\n\t" << type_name<T>() << " prm" << num + 1 << ",";
        build_arg_list<Args...>(os, num + 1);
    }
};
#endif

Expand Down

0 comments on commit 256c1bb

Please sign in to comment.