diff --git a/examples/utests.cpp b/examples/utests.cpp index d949eae34..8b9552dcb 100644 --- a/examples/utests.cpp +++ b/examples/utests.cpp @@ -25,13 +25,13 @@ bool run_test(const std::string &name, std::function test) { return rc; } -extern const char chk_if_gr_body[] = "return prm1 > prm2 ? 1 : 0;"; +extern const char greater_body[] = "return prm1 > prm2 ? 1 : 0;"; +UserFunction greater; -int main(int argc, char *argv[]) { - uint seed = argc > 1 ? atoi(argv[1]) : static_cast(time(0)); - std::cout << "seed: " << seed << std::endl; - srand(seed); +extern const char pow3_body[] = "return pow(prm1, 3);"; +UserFunction pow3; +int main(int argc, char *argv[]) { try { vex::Context ctx(Filter::DoublePrecision && Filter::Env); std::cout << ctx << std::endl; @@ -41,6 +41,10 @@ int main(int argc, char *argv[]) { return 1; } + uint seed = argc > 1 ? atoi(argv[1]) : static_cast(time(0)); + std::cout << "seed: " << seed << std::endl << std::endl; + srand(seed); + run_test("Empty vector construction", [&]() -> bool { bool rc = true; vex::vector x; @@ -780,10 +784,9 @@ int main(int argc, char *argv[]) { vex::vector y(ctx.queue(), N); x = 1; y = 2; - UserFunction chk_if_greater; Reductor sum(ctx.queue()); - rc = rc && sum(chk_if_greater(x, y)) == 0; - rc = rc && sum(chk_if_greater(y, x)) == N; + rc = rc && sum(greater(x, y)) == 0; + rc = rc && sum(greater(y, x)) == N; rc = rc && sum(x > y) == 0; rc = rc && sum(x < y) == N; return rc; @@ -816,8 +819,7 @@ int main(int argc, char *argv[]) { multivector y(ctx.queue(), n); x = 1; y = 2; - UserFunction chk_if_greater; - x = chk_if_greater(x, y); + x = greater(x, y); for(size_t k = 0; k < 10; k++) { size_t i = rand() % n; std::array val = x[i]; @@ -1065,6 +1067,38 @@ int main(int argc, char *argv[]) { return rc; }); +#ifdef VEXCL_VARIADIC_TEMPLATES + run_test("Generalized stencil with user function convolution", [&]() -> bool { + bool rc = true; + const int n = 1 << 20; + + double sdata[] = {1, 0, 1}; + gstencil S(ctx.queue(), 1, 3, 1, sdata, sdata + 3); + + std::vector x(n); + std::vector y(n); + std::generate(x.begin(), x.end(), [](){ return (double)rand() / RAND_MAX; }); + + vex::vector X(ctx.queue(), x); + vex::vector Y(ctx.queue(), n); + + Y = X + pow3(X * S); + + copy(Y, y); + + double res = 0; + for(int i = 0; i < n; i++) { + int left = std::max(0, i - 1); + int right = std::min(n - 1, i + 1); + + double sum = x[i] + pow(x[left] + x[right], 3); + res = std::max(res, fabs(sum - y[i])); + } + rc = rc && res < 1e-8; + return rc; + }); +#endif + } catch (const cl::Error &err) { std::cerr << "OpenCL error: " << err << std::endl; return 1; diff --git a/vexcl/stencil.hpp b/vexcl/stencil.hpp index 5c2c0ae92..0edd929f6 100644 --- a/vexcl/stencil.hpp +++ b/vexcl/stencil.hpp @@ -887,6 +887,7 @@ void gstencil::convolve(const vex::vector &x, vex::vector &y, source << standard_kernel_header << "typedef " << type_name() << " real;\n" + << UserFunctionDeclaration::get() << "real read_x(\n" " long g_id,\n" " " << type_name() << " n,\n" diff --git a/vexcl/vector.hpp b/vexcl/vector.hpp index 05e5eedf3..a3a2c2658 100644 --- a/vexcl/vector.hpp +++ b/vexcl/vector.hpp @@ -74,6 +74,8 @@ template struct MultiConv; template struct MultiExConv; template struct MultiGConv; template struct MultiExGConv; +template struct GStencilProd; +template struct MultiGStencilProd; /// Base class for a member of an expression. /** @@ -1532,6 +1534,8 @@ struct All #define DEFINE_BUILTIN_FUNCTION(name) \ struct name##_name { \ + static const bool is_builtin = true; \ + static const bool is_userfun = false; \ static const char* value() { \ return #name; \ } \ @@ -1681,6 +1685,8 @@ struct BuiltinFunction : public expression { #define DEFINE_BUILTIN_FUNCTION(name) \ struct name##_name { \ + static const bool is_builtin = true; \ + static const bool is_userfun = false; \ static const char* value() { \ return #name; \ } \ @@ -1765,9 +1771,23 @@ DEFINE_BUILTIN_FUNCTION(trunc) //--------------------------------------------------------------------------- // User-defined functions. //--------------------------------------------------------------------------- +template +struct UserFunctionDeclaration { + static std::string get() { return ""; } +}; + #ifdef VEXCL_VARIADIC_TEMPLATES /// \cond INTERNAL +template +struct UserFunction {}; + +template +struct UserFunction; + +template +struct UserFunctionDeclaration>; + /// Custom user function expression template template struct UserFunctionFamily { @@ -1779,11 +1799,9 @@ struct UserFunctionFamily { void preamble(std::ostream &os, std::string name) const { build_preamble<0>(os, name); - os << type_name() << " " << name << "_fun("; - - build_arg_list(os, 0); - - os << "\n\t)\n{\n" << body << "\n}\n"; + os << UserFunctionDeclaration< + UserFunction + >::get(name); } std::string kernel_name() const { @@ -1880,19 +1898,6 @@ struct UserFunctionFamily { build_preamble(os, name); } - //------------------------------------------------------------ - template - void build_arg_list(std::ostream &os, uint num) const { - os << "\n\t" << type_name() << " prm" << num + 1; - } - - template - typename std::enable_if::type - build_arg_list(std::ostream &os, uint num) const { - os << "\n\t" << type_name() << " prm" << num + 1 << ","; - build_arg_list(os, num + 1); - } - //------------------------------------------------------------ template typename std::enable_if::type @@ -1912,9 +1917,6 @@ struct UserFunctionFamily { }; }; -template -struct UserFunction {}; - /// \endcond /// Custom user function @@ -1941,6 +1943,9 @@ struct UserFunction {}; */ template struct UserFunction { + static const bool is_builtin = false; \ + static const bool is_userfun = true; + /// Apply user function to the list of expressions. /** * Number of expressions in the list has to coincide with number of @@ -1978,6 +1983,43 @@ struct UserFunction { typename UserFunctionFamily::template Function::subtype...>, multiex_dim::dim >(ex); } + + template + GConv operator()(const GStencilProd &s) const { + return GConv(s); + } + + template + MultiGConv operator()(const MultiGStencilProd &s) const { + return MultiGConv(s); + } + + static const char* value() { + return "user_fun"; + } +}; + +template +struct UserFunctionDeclaration> { + static std::string get(const std::string &name = "user") { + std::ostringstream decl; + decl << type_name() << " " << name << "_fun("; + build_arg_list(decl, 0); + decl << "\n\t)\n{\n" << body << "\n}\n"; + return decl.str(); + } + + template + static void build_arg_list(std::ostream &os, uint num) { + os << "\n\t" << type_name() << " prm" << num + 1; + } + + template + static typename std::enable_if::type + build_arg_list(std::ostream &os, uint num) { + os << "\n\t" << type_name() << " prm" << num + 1 << ","; + build_arg_list(os, num + 1); + } }; #endif