Skip to content

Commit

Permalink
Fix test naming error with half and double (#340) (#342)
Browse files Browse the repository at this point in the history
* Fix test naming error with half and double (#340)

* Fixes bug #340 Generated names for half and double are invalid
* Implements additional specializations of dump_args for double and
  cl::sycl::half.
  * A more general approach does not work well because the standard
    library does not work with with cl::sycl::half.

* Change PR to use partial specialization w/ dump_arg_helper struct
  • Loading branch information
hjabird committed Oct 18, 2022
1 parent c5e50c3 commit fc1fe35
Showing 1 changed file with 68 additions and 28 deletions.
96 changes: 68 additions & 28 deletions test/blas_test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ using namespace blas;
using test_executor_t =
blas::Executor<blas::PolicyHandler<blas::codeplay_policy>>;


using index_t = BLAS_INDEX_T;

/**
Expand Down Expand Up @@ -192,41 +192,81 @@ static inline void fill_trsm_matrix(std::vector<scalar_t> &A, size_t k,
}

/**
* @brief Dump an argument to a stream.
* @brief Helper class for dumping arguments to a stream, in a format compatible
* with google test test names.
*
* @param ss Output stream
* @param arg Argument to format
* @tparam T is the argument type to dump to the stream.
* @tparam Enable is a helper for partial template specialization.
*/
template <class T>
inline void dump_arg(std::ostream &ss, T arg) {
ss << arg;
}
template <class T, typename Enable = void>
struct dump_arg_helper {
/** Dump the argument to the stream.
*
* @param ss Output stream
* @param arg Argument to format
**/
inline void operator()(std::ostream &ss, T arg) { ss << arg; }
};

/** Specialization of dump_arg_helper for float and double. NB this is not a
* specialization for half. std::is_floating_point<cl::sycl::half>::value will
* return false.
*
* @tparam StdFloat A standard floating point type.
**/
template <class StdFloat>
struct dump_arg_helper<
StdFloat,
typename std::enable_if<std::is_floating_point<StdFloat>::value>::type> {
/**
* @brief Dump an argument to a stream.
* Format floating point numbers for GTest. A test name cannot contain
* "-" nor "." so they are replaced with "m" and "p" respectively. The
* fractional part is ignored if null otherwise it is printed with 2 digits.
*
* @param ss Output stream
* @param f Floating point number to format
*/
inline void operator()(std::ostream &ss, StdFloat f) {
static_assert(!std::is_same<StdFloat, cl::sycl::half>::value,
"std library functions will not work with half.");
if (std::isnan(f)) {
ss << "nan";
return;
}
if (f < 0) {
ss << "m";
f = std::fabs(f);
}
StdFloat int_part;
StdFloat frac_part = std::modf(f, &int_part);
ss << int_part;
if (frac_part > 0) {
ss << "p" << (int)(frac_part * 100);
}
}
};

/** Specialization of dump_arg_helper for cl::sycl::half.
* This is required since half will not work with standard library functions.
**/
template <>
struct dump_arg_helper<cl::sycl::half> {
inline void operator()(std::ostream &ss, cl::sycl::half f) {
dump_arg_helper<float>{}(ss, f);
}
};

/**
* @brief Dump an argument to a stream.
* Format floating point numbers for GTest. A test name cannot contain
* "-" nor "." so they are replaced with "m" and "p" respectively. The
* fractional part is ignored if null otherwise it is printed with 2 digits.
*
* @tparam T is the type of the argument to format.
* @param ss Output stream
* @param f Floating point number to format
* @param arg Argument to format
*/
template <>
inline void dump_arg<float>(std::ostream &ss, float f) {
if (std::isnan(f)) {
ss << "nan";
return;
}
if (f < 0) {
ss << "m";
f = std::fabs(f);
}
float int_part;
float frac_part = modff(f, &int_part);
ss << int_part;
if (frac_part > 0) {
ss << "p" << (int)(frac_part * 100);
}
template <class T>
inline void dump_arg(std::ostream &ss, T arg) {
dump_arg_helper<T>{}(ss, arg);
}

/**
Expand Down

0 comments on commit fc1fe35

Please sign in to comment.