Skip to content

Commit

Permalink
Merge symmetric permutation kernels
Browse files Browse the repository at this point in the history
This PR adds symmetric permutation functions and slightly improves the documentation for Permutable.

Related PR: #684
  • Loading branch information
upsj committed Jan 6, 2021
2 parents 4a44dbd + d509d71 commit 5259383
Show file tree
Hide file tree
Showing 28 changed files with 973 additions and 127 deletions.
26 changes: 26 additions & 0 deletions common/matrix/csr_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -1062,4 +1062,30 @@ __global__ __launch_bounds__(default_block_size) void inv_row_permute_kernel(
out_cols[out_begin + i] = in_cols[in_begin + i];
out_vals[out_begin + i] = in_vals[in_begin + i];
}
}


template <int subwarp_size, typename ValueType, typename IndexType>
__global__ __launch_bounds__(default_block_size) void inv_symm_permute_kernel(
size_type num_rows, const IndexType *__restrict__ permutation,
const IndexType *__restrict__ in_row_ptrs,
const IndexType *__restrict__ in_cols,
const ValueType *__restrict__ in_vals,
const IndexType *__restrict__ out_row_ptrs,
IndexType *__restrict__ out_cols, ValueType *__restrict__ out_vals)
{
auto tid = thread::get_subwarp_id_flat<subwarp_size>();
if (tid >= num_rows) {
return;
}
auto lane = threadIdx.x % subwarp_size;
auto in_row = tid;
auto out_row = permutation[tid];
auto in_begin = in_row_ptrs[in_row];
auto in_size = in_row_ptrs[in_row + 1] - in_begin;
auto out_begin = out_row_ptrs[out_row];
for (IndexType i = lane; i < in_size; i += subwarp_size) {
out_cols[out_begin + i] = permutation[in_cols[in_begin + i]];
out_vals[out_begin + i] = in_vals[in_begin + i];
}
}
66 changes: 46 additions & 20 deletions common/matrix/dense_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -421,16 +421,48 @@ __global__ __launch_bounds__(default_block_size) void reduce_total_cols(
}


template <size_type block_size, typename IndexType, typename ValueType>
__global__ __launch_bounds__(block_size) void row_gather(
template <typename IndexType, typename ValueType>
__global__ __launch_bounds__(default_block_size) void symm_permute(
size_type num_rows, size_type num_cols,
const IndexType *__restrict__ perm_idxs, const ValueType *__restrict__ orig,
size_type stride_orig, ValueType *__restrict__ result,
size_type stride_result)
{
constexpr auto warps_per_block = block_size / config::warp_size;
const auto global_id =
thread::get_thread_id<config::warp_size, warps_per_block>();
const auto global_id = thread::get_thread_id_flat();
const auto row_id = global_id / num_cols;
const auto col_id = global_id % num_cols;
if (row_id < num_rows) {
result[row_id * stride_result + col_id] =
orig[perm_idxs[row_id] * stride_orig + perm_idxs[col_id]];
}
}


template <typename IndexType, typename ValueType>
__global__ __launch_bounds__(default_block_size) void inv_symm_permute(
size_type num_rows, size_type num_cols,
const IndexType *__restrict__ perm_idxs, const ValueType *__restrict__ orig,
size_type stride_orig, ValueType *__restrict__ result,
size_type stride_result)
{
const auto global_id = thread::get_thread_id_flat();
const auto row_id = global_id / num_cols;
const auto col_id = global_id % num_cols;
if (row_id < num_rows) {
result[perm_idxs[row_id] * stride_result + perm_idxs[col_id]] =
orig[row_id * stride_orig + col_id];
}
}


template <typename IndexType, typename ValueType>
__global__ __launch_bounds__(default_block_size) void row_gather(
size_type num_rows, size_type num_cols,
const IndexType *__restrict__ perm_idxs, const ValueType *__restrict__ orig,
size_type stride_orig, ValueType *__restrict__ result,
size_type stride_result)
{
const auto global_id = thread::get_thread_id_flat();
const auto row_id = global_id / num_cols;
const auto col_id = global_id % num_cols;
if (row_id < num_rows) {
Expand All @@ -440,16 +472,14 @@ __global__ __launch_bounds__(block_size) void row_gather(
}


template <size_type block_size, typename IndexType, typename ValueType>
__global__ __launch_bounds__(block_size) void column_permute(
template <typename IndexType, typename ValueType>
__global__ __launch_bounds__(default_block_size) void column_permute(
size_type num_rows, size_type num_cols,
const IndexType *__restrict__ perm_idxs, const ValueType *__restrict__ orig,
size_type stride_orig, ValueType *__restrict__ result,
size_type stride_result)
{
constexpr auto warps_per_block = block_size / config::warp_size;
const auto global_id =
thread::get_thread_id<config::warp_size, warps_per_block>();
const auto global_id = thread::get_thread_id_flat();
const auto row_id = global_id / num_cols;
const auto col_id = global_id % num_cols;
if (row_id < num_rows) {
Expand All @@ -459,16 +489,14 @@ __global__ __launch_bounds__(block_size) void column_permute(
}


template <size_type block_size, typename IndexType, typename ValueType>
__global__ __launch_bounds__(block_size) void inverse_row_permute(
template <typename IndexType, typename ValueType>
__global__ __launch_bounds__(default_block_size) void inverse_row_permute(
size_type num_rows, size_type num_cols,
const IndexType *__restrict__ perm_idxs, const ValueType *__restrict__ orig,
size_type stride_orig, ValueType *__restrict__ result,
size_type stride_result)
{
constexpr auto warps_per_block = block_size / config::warp_size;
const auto global_id =
thread::get_thread_id<config::warp_size, warps_per_block>();
const auto global_id = thread::get_thread_id_flat();
const auto row_id = global_id / num_cols;
const auto col_id = global_id % num_cols;
if (row_id < num_rows) {
Expand All @@ -478,16 +506,14 @@ __global__ __launch_bounds__(block_size) void inverse_row_permute(
}


template <size_type block_size, typename IndexType, typename ValueType>
__global__ __launch_bounds__(block_size) void inverse_column_permute(
template <typename IndexType, typename ValueType>
__global__ __launch_bounds__(default_block_size) void inverse_column_permute(
size_type num_rows, size_type num_cols,
const IndexType *__restrict__ perm_idxs, const ValueType *__restrict__ orig,
size_type stride_orig, ValueType *__restrict__ result,
size_type stride_result)
{
constexpr auto warps_per_block = block_size / config::warp_size;
const auto global_id =
thread::get_thread_id<config::warp_size, warps_per_block>();
const auto global_id = thread::get_thread_id_flat();
const auto row_id = global_id / num_cols;
const auto col_id = global_id % num_cols;
if (row_id < num_rows) {
Expand Down
18 changes: 18 additions & 0 deletions core/device_hooks/common_kernels.inc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,18 @@ GKO_DECLARE_DENSE_CONJ_TRANSPOSE_KERNEL(ValueType)
GKO_NOT_COMPILED(GKO_HOOK_MODULE);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_CONJ_TRANSPOSE_KERNEL);

template <typename ValueType, typename IndexType>
GKO_DECLARE_DENSE_SYMM_PERMUTE_KERNEL(ValueType, IndexType)
GKO_NOT_COMPILED(GKO_HOOK_MODULE);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_DECLARE_DENSE_SYMM_PERMUTE_KERNEL);

template <typename ValueType, typename IndexType>
GKO_DECLARE_DENSE_INV_SYMM_PERMUTE_KERNEL(ValueType, IndexType)
GKO_NOT_COMPILED(GKO_HOOK_MODULE);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_DECLARE_DENSE_INV_SYMM_PERMUTE_KERNEL);

template <typename ValueType, typename IndexType>
GKO_DECLARE_DENSE_ROW_GATHER_KERNEL(ValueType, IndexType)
GKO_NOT_COMPILED(GKO_HOOK_MODULE);
Expand Down Expand Up @@ -688,6 +700,12 @@ GKO_NOT_COMPILED(GKO_HOOK_MODULE);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_DECLARE_CSR_CONJ_TRANSPOSE_KERNEL);

template <typename ValueType, typename IndexType>
GKO_DECLARE_CSR_INV_SYMM_PERMUTE_KERNEL(ValueType, IndexType)
GKO_NOT_COMPILED(GKO_HOOK_MODULE);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_DECLARE_CSR_INV_SYMM_PERMUTE_KERNEL);

template <typename ValueType, typename IndexType>
GKO_DECLARE_CSR_ROW_PERMUTE_KERNEL(ValueType, IndexType)
GKO_NOT_COMPILED(GKO_HOOK_MODULE);
Expand Down
63 changes: 51 additions & 12 deletions core/matrix/csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ GKO_REGISTER_OPERATION(convert_to_ell, csr::convert_to_ell);
GKO_REGISTER_OPERATION(convert_to_hybrid, csr::convert_to_hybrid);
GKO_REGISTER_OPERATION(transpose, csr::transpose);
GKO_REGISTER_OPERATION(conj_transpose, csr::conj_transpose);
GKO_REGISTER_OPERATION(inv_symm_permute, csr::inv_symm_permute);
GKO_REGISTER_OPERATION(row_permute, csr::row_permute);
GKO_REGISTER_OPERATION(inverse_row_permute, csr::inverse_row_permute);
GKO_REGISTER_OPERATION(inverse_column_permute, csr::inverse_column_permute);
Expand Down Expand Up @@ -402,6 +403,48 @@ std::unique_ptr<LinOp> Csr<ValueType, IndexType>::conj_transpose() const
}


template <typename ValueType, typename IndexType>
std::unique_ptr<LinOp> Csr<ValueType, IndexType>::permute(
const Array<IndexType> *permutation_indices) const
{
GKO_ASSERT_IS_SQUARE_MATRIX(this);
GKO_ASSERT_EQ(permutation_indices->get_num_elems(), this->get_size()[0]);
auto exec = this->get_executor();
auto permute_cpy =
Csr::create(exec, this->get_size(), this->get_num_stored_elements(),
this->get_strategy());
Array<IndexType> inv_permutation(exec, this->get_size()[1]);

exec->run(csr::make_invert_permutation(
this->get_size()[1],
make_temporary_clone(exec, permutation_indices)->get_const_data(),
inv_permutation.get_data()));
exec->run(csr::make_inv_symm_permute(inv_permutation.get_const_data(), this,
permute_cpy.get()));
permute_cpy->make_srow();
return std::move(permute_cpy);
}


template <typename ValueType, typename IndexType>
std::unique_ptr<LinOp> Csr<ValueType, IndexType>::inverse_permute(
const Array<IndexType> *permutation_indices) const
{
GKO_ASSERT_IS_SQUARE_MATRIX(this);
GKO_ASSERT_EQ(permutation_indices->get_num_elems(), this->get_size()[0]);
auto exec = this->get_executor();
auto permute_cpy =
Csr::create(exec, this->get_size(), this->get_num_stored_elements(),
this->get_strategy());

exec->run(csr::make_inv_symm_permute(
make_temporary_clone(exec, permutation_indices)->get_const_data(), this,
permute_cpy.get()));
permute_cpy->make_srow();
return std::move(permute_cpy);
}


template <typename ValueType, typename IndexType>
std::unique_ptr<LinOp> Csr<ValueType, IndexType>::row_permute(
const Array<IndexType> *permutation_indices) const
Expand Down Expand Up @@ -445,39 +488,35 @@ std::unique_ptr<LinOp> Csr<ValueType, IndexType>::column_permute(

template <typename ValueType, typename IndexType>
std::unique_ptr<LinOp> Csr<ValueType, IndexType>::inverse_row_permute(
const Array<IndexType> *inverse_permutation_indices) const
const Array<IndexType> *permutation_indices) const
{
GKO_ASSERT_EQ(inverse_permutation_indices->get_num_elems(),
this->get_size()[0]);
GKO_ASSERT_EQ(permutation_indices->get_num_elems(), this->get_size()[0]);
auto exec = this->get_executor();
auto inverse_permute_cpy =
Csr::create(exec, this->get_size(), this->get_num_stored_elements(),
this->get_strategy());

exec->run(csr::make_inverse_row_permute(
make_temporary_clone(exec, inverse_permutation_indices)
->get_const_data(),
this, inverse_permute_cpy.get()));
make_temporary_clone(exec, permutation_indices)->get_const_data(), this,
inverse_permute_cpy.get()));
inverse_permute_cpy->make_srow();
return std::move(inverse_permute_cpy);
}


template <typename ValueType, typename IndexType>
std::unique_ptr<LinOp> Csr<ValueType, IndexType>::inverse_column_permute(
const Array<IndexType> *inverse_permutation_indices) const
const Array<IndexType> *permutation_indices) const
{
GKO_ASSERT_EQ(inverse_permutation_indices->get_num_elems(),
this->get_size()[1]);
GKO_ASSERT_EQ(permutation_indices->get_num_elems(), this->get_size()[1]);
auto exec = this->get_executor();
auto inverse_permute_cpy =
Csr::create(exec, this->get_size(), this->get_num_stored_elements(),
this->get_strategy());

exec->run(csr::make_inverse_column_permute(
make_temporary_clone(exec, inverse_permutation_indices)
->get_const_data(),
this, inverse_permute_cpy.get()));
make_temporary_clone(exec, permutation_indices)->get_const_data(), this,
inverse_permute_cpy.get()));
inverse_permute_cpy->make_srow();
inverse_permute_cpy->sort_by_column_index();
return std::move(inverse_permute_cpy);
Expand Down
8 changes: 8 additions & 0 deletions core/matrix/csr_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,12 @@ namespace kernels {
const matrix::Csr<ValueType, IndexType> *orig, \
matrix::Csr<ValueType, IndexType> *trans)

#define GKO_DECLARE_CSR_INV_SYMM_PERMUTE_KERNEL(ValueType, IndexType) \
void inv_symm_permute(std::shared_ptr<const DefaultExecutor> exec, \
const IndexType *permutation_indices, \
const matrix::Csr<ValueType, IndexType> *orig, \
matrix::Csr<ValueType, IndexType> *permuted)

#define GKO_DECLARE_CSR_ROW_PERMUTE_KERNEL(ValueType, IndexType) \
void row_permute(std::shared_ptr<const DefaultExecutor> exec, \
const IndexType *permutation_indices, \
Expand Down Expand Up @@ -207,6 +213,8 @@ namespace kernels {
template <typename ValueType, typename IndexType> \
GKO_DECLARE_CSR_CONJ_TRANSPOSE_KERNEL(ValueType, IndexType); \
template <typename ValueType, typename IndexType> \
GKO_DECLARE_CSR_INV_SYMM_PERMUTE_KERNEL(ValueType, IndexType); \
template <typename ValueType, typename IndexType> \
GKO_DECLARE_CSR_ROW_PERMUTE_KERNEL(ValueType, IndexType); \
template <typename ValueType, typename IndexType> \
GKO_DECLARE_CSR_INVERSE_ROW_PERMUTE_KERNEL(ValueType, IndexType); \
Expand Down
Loading

0 comments on commit 5259383

Please sign in to comment.