Skip to content

Commit

Permalink
Data validation macro generalisation (#1)
Browse files Browse the repository at this point in the history
* add prefixed macro definition

* Data validation macro generalisation fix windows (#2)

* check if linebreak causes windows build to break

* check if linebreak causes windows build to break

* check if empty macro arg causes windows build to break

* test if removing VA_ARGS resolves msvc error

* test PASS_ON hack

* Add missing PASS_ON

* Add tests for complex values is_finite tests

* Refactor macro expansion macro

* Test if PASS_ON is needed when calling GKO_APPLY_MACRO

* Refactor CALL_AND_RETURN_IF_CASTABLE macro

* Moves GKO_CALL_FOR_... macros to types.hpp
  • Loading branch information
greole committed May 20, 2021
1 parent d44edde commit 36d2357
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 35 deletions.
23 changes: 4 additions & 19 deletions core/components/validation_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,29 +62,14 @@ bool is_symmetric_impl(const LinOp *matrix, const float tolerance)
});
}


#define GKO_CALL_FOR_EACH_NON_COMPLEX_VALUE_AND_INDEX_TYPE(_macro, ...) \
_macro(float, int32, ##__VA_ARGS__); \
_macro(double, int32, ##__VA_ARGS__); \
_macro(float, int64, ##__VA_ARGS__); \
_macro(double, int64, ##__VA_ARGS__)


#define GKO_CALL_FOR_EACH_VALUE_AND_INDEX_TYPE(_macro, ...) \
GKO_CALL_FOR_EACH_NON_COMPLEX_VALUE_AND_INDEX_TYPE(_macro, ##__VA_ARGS__); \
_macro(std::complex<float>, int32, ##__VA_ARGS__); \
_macro(std::complex<double>, int32, ##__VA_ARGS__); \
_macro(std::complex<float>, int64, ##__VA_ARGS__); \
_macro(std::complex<double>, int64, ##__VA_ARGS__)

#define CALL_AND_RETURN_IF_CASTABLE(T1, T2, func, matrix, ...) \
#define GKO_CALL_AND_RETURN_IF_CASTABLE(T1, T2, func, matrix, ...) \
if (dynamic_cast<const WritableToMatrixData<T1, T2> *>(matrix)) { \
return func<T1, T2>(matrix, ##__VA_ARGS__); \
}

bool is_symmetric(const LinOp *matrix, const float tolerance)
{
GKO_CALL_FOR_EACH_VALUE_AND_INDEX_TYPE(CALL_AND_RETURN_IF_CASTABLE,
GKO_CALL_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_CALL_AND_RETURN_IF_CASTABLE,
is_symmetric_impl, matrix, tolerance)
return false;
}
Expand All @@ -106,11 +91,11 @@ bool has_non_zero_diagonal_impl(const LinOp *matrix)

bool has_non_zero_diagonal(const LinOp *matrix)
{
GKO_CALL_FOR_EACH_VALUE_AND_INDEX_TYPE(CALL_AND_RETURN_IF_CASTABLE,
GKO_CALL_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_CALL_AND_RETURN_IF_CASTABLE,
has_non_zero_diagonal_impl, matrix)
return false;
}
#undef CALL_AND_RETURN_IF_CASTABLE
#undef GKO_CALL_AND_RETURN_IF_CASTABLE

template <typename IndexType>
bool is_row_ordered(const IndexType *row_ptrs, const size_type num_entries)
Expand Down
48 changes: 43 additions & 5 deletions core/test/components/validation_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,21 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "core/components/validation_helpers.hpp"
#include "core/test/utils.hpp"

#include <complex>
#include <limits>

namespace gko {
namespace test {
using RealValueTypes =
#if GINKGO_DPCPP_SINGLE_MODE
::testing::Types<float>;
#else
::testing::Types<float, double>;
#endif
} // namespace test
} // namespace gko


namespace {

#define GKO_DEFINE_ISSYMMETRIC(MATRIX_TYPE) \
Expand Down Expand Up @@ -191,18 +206,41 @@ TYPED_TEST(IndexTypeTest, IsWithinBoundsReturnsFalseUpperBound)
// ValueType Tests

template <typename T>
class ValueTypeTest : public ::testing::Test {
class RealValueTypeTest : public ::testing::Test {
protected:
ValueTypeTest() : exec(gko::ReferenceExecutor::create()) {}
RealValueTypeTest() : exec(gko::ReferenceExecutor::create()) {}

std::shared_ptr<const gko::Executor> exec;
};

TYPED_TEST_SUITE(ValueTypeTest, gko::test::ValueTypes);
TYPED_TEST_SUITE(RealValueTypeTest, gko::test::RealValueTypes);

TYPED_TEST(RealValueTypeTest, IsFiniteReturnsFalseOnInf)
{
TypeParam inf = std::numeric_limits<TypeParam>::infinity();
gko::Array<TypeParam> a{this->exec, {1., 3., 6.}};
a.get_data()[2] = inf;


ASSERT_EQ(gko::validate::is_finite(a.get_const_data(), a.get_num_elems()),
false);
}
template <typename T>
class ComplexValueTypeTest : public ::testing::Test {
protected:
ComplexValueTypeTest() : exec(gko::ReferenceExecutor::create()) {}

std::shared_ptr<const gko::Executor> exec;
};

TYPED_TEST(ValueTypeTest, IsFiniteReturnsFalseOnInf)
TYPED_TEST_SUITE(ComplexValueTypeTest, gko::test::ComplexValueTypes);
TYPED_TEST(ComplexValueTypeTest, IsFiniteReturnsFalseOnInf)
{
gko::Array<TypeParam> a{this->exec, {0., 1., 1.0 / 0.0}};
TypeParam inf =
std::numeric_limits<typename TypeParam::value_type>::infinity();
gko::Array<TypeParam> a{this->exec, {0., 1., 0.}};
a.get_data()[0] = std::complex<typename TypeParam::value_type>(inf);


ASSERT_EQ(gko::validate::is_finite(a.get_const_data(), a.get_num_elems()),
false);
Expand Down
76 changes: 65 additions & 11 deletions include/ginkgo/core/base/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,68 @@ GKO_ATTRIBUTES constexpr bool operator!=(precision_reduction x,
_enable_macro(DpcppExecutor, dpcpp); \
_enable_macro(CudaExecutor, cuda)

/**
* This helper macro handles consistent expansion of __VA_ARGS__ for MSVC and
* other compilers.
*
* Consider the following example given on https://bit.ly/3wp8PqT
*
* #define MACRO_WITH_3_PARAMS(p1, p2, p3) P1 = p1 | P2 = p2 | P3 = p3
* #define MACRO_VA_ARGS(...) MACRO_WITH_3_PARAMS( __VA_ARGS__)
* MACRO_VA_ARGS(foo, bar, baz)
*
* On MSVC this is expanded into
* P1 = foo, bar, baz | P2 = | P3 =
*
* In order to expand the macro consitently PASS_ON can be applied
* #define MACRO_VA_ARGS(...) PASS_ON(PASS_ON(MACRO_WITH_3_PARAMS)(
* __VA_ARGS__))
*/
#define PASS_ON(...) __VA_ARGS__

/**
* Calls a macro with variable number of arguments.
*
* The macro uses PASS_ON to expand __VA_ARGS__. This is needed to have
* consistent behaviour among MSVC and other compilers.
*
* @param _macro The macro which is to be expanded.
*/
#define GKO_APPLY_MACRO(_macro, ...) PASS_ON(PASS_ON(_macro)(__VA_ARGS__))

/**
* Expands a macro for each non_complex value and index type compiled by
* Ginkgo.
*
* @param _macro A macro which expands the template instantiation
* (not including the leading `template` specifier).
* Should take two arguments, which are replaced by the
* value and index types.
* @param ... Optional parameters which are passed to _macro
* @param ... Optional parameters which are passed to _macro
*/
#define GKO_CALL_FOR_EACH_NON_COMPLEX_VALUE_AND_INDEX_TYPE(_macro, ...) \
GKO_APPLY_MACRO(_macro, float, int32, ##__VA_ARGS__); \
GKO_APPLY_MACRO(_macro, double, int32, ##__VA_ARGS__); \
GKO_APPLY_MACRO(_macro, float, int64, ##__VA_ARGS__); \
GKO_APPLY_MACRO(_macro, double, int64, ##__VA_ARGS__)


/**
* Expands a macro for each value and index type compiled by Ginkgo.
*
* @param _macro A macro which expands the template instantiation
* (not including the leading `template` specifier).
* Should take two arguments, which are replaced by the
* value and index types.
* @param ... Optional parameters which are passed to _macro
*/
#define GKO_CALL_FOR_EACH_VALUE_AND_INDEX_TYPE(_macro, ...) \
GKO_CALL_FOR_EACH_NON_COMPLEX_VALUE_AND_INDEX_TYPE(_macro, ##__VA_ARGS__); \
GKO_APPLY_MACRO(_macro, std::complex<float>, int32, ##__VA_ARGS__); \
GKO_APPLY_MACRO(_macro, std::complex<float>, int64, ##__VA_ARGS__); \
GKO_APPLY_MACRO(_macro, std::complex<double>, int32, ##__VA_ARGS__); \
GKO_APPLY_MACRO(_macro, std::complex<double>, int64, ##__VA_ARGS__)

/**
* Instantiates a template for each non-complex value type compiled by Ginkgo.
Expand Down Expand Up @@ -490,10 +552,7 @@ GKO_ATTRIBUTES constexpr bool operator!=(precision_reduction x,
_macro(double, int64) GKO_NOT_IMPLEMENTED
#else
#define GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_AND_INDEX_TYPE(_macro) \
template _macro(float, int32); \
template _macro(double, int32); \
template _macro(float, int64); \
template _macro(double, int64)
GKO_CALL_FOR_EACH_NON_COMPLEX_VALUE_AND_INDEX_TYPE(template _macro)
#endif


Expand All @@ -515,12 +574,8 @@ GKO_ATTRIBUTES constexpr bool operator!=(precision_reduction x,
template <> \
_macro(std::complex<double>, int64) GKO_NOT_IMPLEMENTED
#else
#define GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(_macro) \
GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_AND_INDEX_TYPE(_macro); \
template _macro(std::complex<float>, int32); \
template _macro(std::complex<double>, int32); \
template _macro(std::complex<float>, int64); \
template _macro(std::complex<double>, int64)
#define GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(_macro) \
GKO_CALL_FOR_EACH_VALUE_AND_INDEX_TYPE(template _macro)
#endif


Expand Down Expand Up @@ -568,5 +623,4 @@ GKO_ATTRIBUTES constexpr bool operator!=(precision_reduction x,

} // namespace gko


#endif // GKO_PUBLIC_CORE_BASE_TYPES_HPP_

0 comments on commit 36d2357

Please sign in to comment.