Skip to content
Permalink
Browse files

expression_scalar: support for vec<T>

  • Loading branch information...
dlevin256 committed Apr 1, 2019
1 parent 1f2d997 commit 125177d1b75db4af7ffa1761ec31ed1769bb8bb1
Showing with 10 additions and 28 deletions.
  1. +8 −26 include/kfr/base/expression.hpp
  2. +2 −2 tests/complex_test.cpp
@@ -286,8 +286,8 @@ struct expression_with_arguments : input_expression
vec_shape<U, N>) const
{
static_assert(ArgIndex < count, "Incorrect ArgIndex");
return get_elements(
static_cast<vec<U, N>>(std::get<ArgIndex>(this->args), cinput, index, vec_shape<T, N>()));
return static_cast<vec<U, N>>(
get_elements(std::get<ArgIndex>(this->args), cinput, index, vec_shape<T, N>()));
}
template <typename U, size_t N,
typename T = value_type_of<typename details::get_nth_type<0, Args...>::type>>
@@ -317,20 +317,19 @@ struct expression_with_arguments : input_expression
}
};

template <typename T, size_t width = 1>
template <typename T>
struct expression_scalar : input_expression
{
using value_type = T;
expression_scalar() = delete;
constexpr expression_scalar(const T& val) CMT_NOEXCEPT : val(val) {}
constexpr expression_scalar(const vec<T, width>& val) CMT_NOEXCEPT : val(val) {}
vec<T, width> val;
T val;

template <size_t N>
friend KFR_INTRINSIC vec<T, N> get_elements(const expression_scalar& self, cinput_t, size_t,
vec_shape<T, N>)
{
return resize<N>(self.val);
return broadcast<N>(self.val);
}
};

@@ -341,23 +340,11 @@ struct arg_impl
};

template <typename T1, typename T2>
struct arg_impl<T1, T2, void_t<enable_if<is_number<T1>::value>>>
struct arg_impl<T1, T2, void_t<enable_if<is_vec_element<T1>::value>>>
{
using type = expression_scalar<T1>;
};

template <typename T1, typename T2>
struct arg_impl<complex<T1>, T2>
{
using type = expression_scalar<complex<T1>>;
};

template <typename T1, typename T2, size_t N>
struct arg_impl<vec<T1, N>, T2>
{
using type = expression_scalar<T1, N>;
};

template <typename T>
using arg = typename internal::arg_impl<decay<T>, T>::type;

@@ -404,12 +391,6 @@ CMT_INTRINSIC internal::expression_scalar<T> scalar(const T& val)
return internal::expression_scalar<T>(val);
}

template <typename T, size_t N>
CMT_INTRINSIC internal::expression_scalar<T, N> scalar(const vec<T, N>& val)
{
return internal::expression_scalar<T, N>(val);
}

template <typename Fn, typename... Args>
CMT_INTRINSIC internal::expression_function<decay<Fn>, Args...> bind_expression(Fn&& fn, Args&&... args)
{
@@ -428,7 +409,8 @@ CMT_INTRINSIC internal::expression_function<Fn, NewArgs...> rebind(
return internal::expression_function<Fn, NewArgs...>(e.get_fn(), std::forward<NewArgs>(args)...);
}

template <size_t width = 0, typename OutputExpr, typename InputExpr, size_t groupsize = 1>
template <size_t width = 0, typename OutputExpr, typename InputExpr, size_t groupsize = 1,
typename Tvec = vec<value_type_of<InputExpr>, 1>>
CMT_INTRINSIC static size_t process(OutputExpr&& out, const InputExpr& in, size_t start = 0,
size_t size = infinite_size, coutput_t coutput = nullptr,
cinput_t cinput = nullptr, csize_t<groupsize> = csize_t<groupsize>())
@@ -200,9 +200,9 @@ TEST(static_tests)
testo::assert_is_same<ftype<vec<complex<i32>, 4>>, vec<complex<f32>, 4>>();
testo::assert_is_same<ftype<vec<complex<i64>, 8>>, vec<complex<f64>, 8>>();

testo::assert_is_same<kfr::internal::arg<int>, kfr::internal::expression_scalar<int, 1>>();
testo::assert_is_same<kfr::internal::arg<int>, kfr::internal::expression_scalar<int>>();
testo::assert_is_same<kfr::internal::arg<complex<int>>,
kfr::internal::expression_scalar<kfr::complex<int>, 1>>();
kfr::internal::expression_scalar<kfr::complex<int>>>();

testo::assert_is_same<kfr::common_type<complex<int>, double>, complex<double>>();
}

0 comments on commit 125177d

Please sign in to comment.
You can’t perform that action at this time.