Skip to content

Commit

Permalink
Use general ConfigSet
Browse files Browse the repository at this point in the history
Co-authored-by: Terry Cojean <terry.cojean@kit.edu>
  • Loading branch information
yhmtsai and tcojean committed May 17, 2021
1 parent 3c03405 commit 0814b7a
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 57 deletions.
61 changes: 30 additions & 31 deletions core/synthesizer/implementation_selection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,37 +70,36 @@ namespace syn {
} \
}

#define GKO_ENABLE_IMPLEMENTATION_CONFIG_SELECTION(_name, _callable) \
template <typename Predicate, bool... BoolArgs, int... IntArgs, \
gko::size_type... SizeTArgs, typename... TArgs, \
typename... InferredArgs> \
inline void _name(::gko::syn::value_list<gko::Config>, Predicate, \
::gko::syn::value_list<bool, BoolArgs...>, \
::gko::syn::value_list<int, IntArgs...>, \
::gko::syn::value_list<gko::size_type, SizeTArgs...>, \
::gko::syn::type_list<TArgs...>, InferredArgs...) \
GKO_KERNEL_NOT_FOUND; \
\
template <gko::Config K, gko::Config... Rest, typename Predicate, \
bool... BoolArgs, int... IntArgs, gko::size_type... SizeTArgs, \
typename... TArgs, typename... InferredArgs> \
inline void _name( \
::gko::syn::value_list<gko::Config, K, Rest...>, \
Predicate is_eligible, \
::gko::syn::value_list<bool, BoolArgs...> bool_args, \
::gko::syn::value_list<int, IntArgs...> int_args, \
::gko::syn::value_list<gko::size_type, SizeTArgs...> size_args, \
::gko::syn::type_list<TArgs...> type_args, InferredArgs... args) \
{ \
if (is_eligible(K)) { \
std::cout << "call " << K << std::endl; \
_callable<BoolArgs..., IntArgs..., SizeTArgs..., TArgs..., K>( \
std::forward<InferredArgs>(args)...); \
} else { \
_name(::gko::syn::value_list<gko::Config, Rest...>(), is_eligible, \
bool_args, int_args, size_args, type_args, \
std::forward<InferredArgs>(args)...); \
} \
#define GKO_ENABLE_IMPLEMENTATION_CONFIG_SELECTION(_name, _callable) \
template <typename Predicate, bool... BoolArgs, int... IntArgs, \
gko::size_type... SizeTArgs, typename... TArgs, \
typename... InferredArgs> \
inline void _name(::gko::syn::value_list<int>, Predicate, \
::gko::syn::value_list<bool, BoolArgs...>, \
::gko::syn::value_list<int, IntArgs...>, \
::gko::syn::value_list<gko::size_type, SizeTArgs...>, \
::gko::syn::type_list<TArgs...>, InferredArgs...) \
GKO_KERNEL_NOT_FOUND; \
\
template <int K, int... Rest, typename Predicate, bool... BoolArgs, \
int... IntArgs, gko::size_type... SizeTArgs, typename... TArgs, \
typename... InferredArgs> \
inline void _name( \
::gko::syn::value_list<int, K, Rest...>, Predicate is_eligible, \
::gko::syn::value_list<bool, BoolArgs...> bool_args, \
::gko::syn::value_list<int, IntArgs...> int_args, \
::gko::syn::value_list<gko::size_type, SizeTArgs...> size_args, \
::gko::syn::type_list<TArgs...> type_args, InferredArgs... args) \
{ \
if (is_eligible(K)) { \
std::cout << "call " << K << std::endl; \
_callable<BoolArgs..., IntArgs..., SizeTArgs..., TArgs..., K>( \
std::forward<InferredArgs>(args)...); \
} else { \
_name(::gko::syn::value_list<int, Rest...>(), is_eligible, \
bool_args, int_args, size_args, type_args, \
std::forward<InferredArgs>(args)...); \
} \
}


Expand Down
30 changes: 15 additions & 15 deletions dpcpp/test/components/cooperative_groups_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ namespace {


using namespace gko::kernels::dpcpp;
using KCfg = gko::ConfigSet<12, 7>;


class CooperativeGroups : public ::testing::Test {
Expand Down Expand Up @@ -109,23 +110,22 @@ void test_assert(bool *success, bool partial)
}

// kernel implementation
template <gko::Config config>
[[intel::reqd_work_group_size(1, 1, gko::get_warp_size(config))]] void
cg_shuffle(bool *s, sycl::nd_item<3> item_ct1)
template <int config>
[[intel::reqd_work_group_size(1, 1, KCfg::decode<0>(config))]] void cg_shuffle(
bool *s, sycl::nd_item<3> item_ct1)
{
auto group = group::tiled_partition<gko::get_warp_size(config)>(
auto group = group::tiled_partition<KCfg::decode<0>(config)>(
group::this_thread_block(item_ct1));
auto i = int(group.thread_rank());
test_assert(s, group.shfl_up(i, 1) == sycl::max(0, (int)(i - 1)));
test_assert(s,
group.shfl_down(i, 1) ==
sycl::min((unsigned int)(i + 1),
(unsigned int)(gko::get_warp_size(config) - 1)));
test_assert(s, group.shfl_down(i, 1) ==
sycl::min((unsigned int)(i + 1),
(unsigned int)(KCfg::decode<0>(config) - 1)));
test_assert(s, group.shfl(i, 0) == 0);
}

// group all kernel things together
template <gko::Config config>
template <int config>
void cg_shuffle_host(dim3 grid, dim3 block, size_t dynamic_shared_memory,
sycl::queue *stream, bool *s)
{
Expand All @@ -148,16 +148,16 @@ void cg_shuffle_config_call(dim3 grid, dim3 block, size_t dynamic_shared_memory,
{
auto exec_info = exec->get_const_exec_info();
constexpr auto default_config_list =
::gko::syn::value_list<gko::Config, gko::config_set(32, 32),
gko::config_set(16, 16), gko::config_set(8, 8),
gko::config_set(4, 4)>();
::gko::syn::value_list<int, KCfg::encode(32, 32), KCfg::encode(16, 16),
KCfg::encode(8, 8), KCfg::encode(4, 4)>();
std::cout << "block.x " << block.x << std::endl;
cg_shuffle_config(
default_config_list,
// validate
[&exec_info, &block](gko::Config config) {
return exec_info.validate(config) &&
(gko::get_warp_size(config) == block.x);
[&exec_info, &block](int config) {
return exec_info.validate(KCfg::decode<0>(config),
KCfg::decode<1>(config)) &&
(KCfg::decode<1>(config) == block.x);
},
::gko::syn::value_list<bool>(), ::gko::syn::value_list<int>(),
::gko::syn::value_list<gko::size_type>(), ::gko::syn::type_list<>(),
Expand Down
4 changes: 1 addition & 3 deletions include/ginkgo/core/base/executor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -815,11 +815,9 @@ class Executor : public log::EnableLogging<Executor> {
/**
* The validate function for Config
*/
bool validate(Config config)
bool validate(int blocksize, int warpsize)
{
bool allowed = false;
auto blocksize = get_block_size(config);
auto warpsize = get_warp_size(config);
for (auto &i : subgroup_sizes) {
allowed |= (i == warpsize);
}
Expand Down
66 changes: 58 additions & 8 deletions include/ginkgo/core/base/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define GKO_PUBLIC_CORE_BASE_TYPES_HPP_


#include <array>
#include <cassert>
#include <climits>
#include <complex>
Expand Down Expand Up @@ -196,25 +197,74 @@ constexpr size_type byte_size = CHAR_BIT;
/**
* Config type
*/
using Config = int;
namespace helper {

constexpr int config_scaler = 1000;

constexpr Config config_set(int block_size, int warp_size)
template <size_type current_size>
constexpr int mask()
{
return block_size * config_scaler + warp_size;
return (1 << (current_size - 1)) | mask<current_size - 1>();
}

constexpr int get_warp_size(Config config_set)
template <>
constexpr int mask<0>()
{
return config_set % config_scaler;
return 0;
}

constexpr int get_block_size(Config config_set)
template <int num_groups, int current_shift>
constexpr std::enable_if_t<(num_groups == current_shift + 1), int> shift(
const std::array<char, num_groups> &bits)
{
return config_set / config_scaler;
return 0;
}

template <int num_groups, int current_shift>
constexpr std::enable_if_t<(num_groups > current_shift + 1), int> shift(
const std::array<char, num_groups> &bits)
{
return bits[current_shift + 1] +
shift<num_groups, (current_shift + 1)>(bits);
}


} // namespace helper


template <int... num_bits>
class ConfigSet {
public:
static constexpr size_type num_groups = sizeof...(num_bits);
static constexpr std::array<char, num_groups> bits{num_bits...};

template <int position>
static constexpr int decode(int encoded)
{
static_assert(position < num_groups,
"This position is over the bounds.");
constexpr int shift = helper::shift<num_groups, position>(bits);
constexpr int mask = helper::mask<bits[position]>();
return (encoded >> shift) & mask;
}

template <size_type current_iter>
static constexpr std::enable_if_t<(current_iter == num_groups), int>
encode()
{
return 0;
}

template <size_type current_iter = 0, typename First, typename... Rest>
static constexpr std::enable_if_t<(current_iter < num_groups), int> encode(
First first, Rest &&... rest)
{
constexpr int shift = helper::shift<num_groups, current_iter>(bits);
return (first << shift) |
encode<current_iter + 1>(std::forward<Rest>(rest)...);
}
};


/**
* Evaluates if all template arguments Args fulfill std::is_integral. If that is
* the case, this class inherits from `std::true_type`, otherwise, it inherits
Expand Down

0 comments on commit 0814b7a

Please sign in to comment.