Skip to content

Commit

Permalink
add dpcpp is_sorted_by_col_idxs kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
upsj committed Jun 18, 2021
1 parent c6bc571 commit c25639a
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 8 deletions.
26 changes: 24 additions & 2 deletions dpcpp/matrix/csr_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -745,8 +745,30 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
template <typename ValueType, typename IndexType>
void is_sorted_by_column_index(
std::shared_ptr<const DpcppExecutor> exec,
const matrix::Csr<ValueType, IndexType> *to_check,
bool *is_sorted) GKO_NOT_IMPLEMENTED;
const matrix::Csr<ValueType, IndexType> *to_check, bool *is_sorted)
{
Array<bool> is_sorted_device_array{exec, {true}};
const auto num_rows = to_check->get_size()[0];
const auto row_ptrs = to_check->get_const_row_ptrs();
const auto cols = to_check->get_const_col_idxs();
auto is_sorted_device = is_sorted_device_array.get_data();
exec->get_queue()->submit([&](sycl::handler &cgh) {
cgh.parallel_for(sycl::range<1>{num_rows}, [=](sycl::id<1> idx) {
const auto row = static_cast<size_type>(idx[0]);
const auto begin = row_ptrs[row];
const auto end = row_ptrs[row + 1];
if (*is_sorted_device) {
for (auto i = begin; i < end - 1; i++) {
if (cols[i] > cols[i + 1]) {
*is_sorted_device = false;
break;
}
}
}
});
});
exec->get_master()->copy_from(exec.get(), 1, is_sorted_device, is_sorted);
};

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_DECLARE_CSR_IS_SORTED_BY_COLUMN_INDEX);
Expand Down
57 changes: 51 additions & 6 deletions dpcpp/test/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


#include "core/test/utils.hpp"
#include "core/test/utils/unsort_matrix.hpp"


namespace {
Expand Down Expand Up @@ -115,6 +116,24 @@ class Csr : public ::testing::Test {
dbeta->copy_from(beta.get());
}

struct matrix_pair {
std::unique_ptr<Mtx> ref;
std::unique_ptr<Mtx> dpcpp;
};

matrix_pair gen_unsorted_mtx()
{
constexpr int min_nnz_per_row{2};
auto local_mtx_ref =
gen_mtx<Mtx>(mtx_size[0], mtx_size[1], min_nnz_per_row);
gko::test::unsort_matrix(gko::lend(local_mtx_ref), rand_engine);

auto local_mtx_dpcpp = Mtx::create(dpcpp);
local_mtx_dpcpp->copy_from(local_mtx_ref.get());

return {std::move(local_mtx_ref), std::move(local_mtx_dpcpp)};
}

std::shared_ptr<gko::ReferenceExecutor> ref;
std::shared_ptr<const gko::DpcppExecutor> dpcpp;

Expand Down Expand Up @@ -144,7 +163,7 @@ TEST_F(Csr, AdvancedApplyToCsrMatrixIsEquivalentToRef)

GKO_ASSERT_MTX_NEAR(square_dmtx, square_mtx, r<value_type>::value);
GKO_ASSERT_MTX_EQ_SPARSITY(square_dmtx, square_mtx);
ASSERT_TRUE(gko::clone(ref, square_dmtx)->is_sorted_by_column_index());
ASSERT_TRUE(square_dmtx->is_sorted_by_column_index());
}


Expand All @@ -159,7 +178,7 @@ TEST_F(Csr, SimpleApplyToCsrMatrixIsEquivalentToRef)

GKO_ASSERT_MTX_NEAR(square_dmtx, square_mtx, r<value_type>::value);
GKO_ASSERT_MTX_EQ_SPARSITY(square_dmtx, square_mtx);
ASSERT_TRUE(gko::clone(ref, square_dmtx)->is_sorted_by_column_index());
ASSERT_TRUE(square_dmtx->is_sorted_by_column_index());
}


Expand All @@ -176,7 +195,7 @@ TEST_F(Csr, SimpleApplyToSparseCsrMatrixIsEquivalentToRef)

GKO_ASSERT_MTX_EQ_SPARSITY(square_dmtx, square_mtx);
GKO_ASSERT_MTX_NEAR(square_dmtx, square_mtx, r<value_type>::value);
ASSERT_TRUE(gko::clone(ref, square_dmtx)->is_sorted_by_column_index());
ASSERT_TRUE(square_dmtx->is_sorted_by_column_index());
}


Expand All @@ -196,7 +215,7 @@ TEST_F(Csr, SimpleApplySparseToSparseCsrMatrixIsEquivalentToRef)

GKO_ASSERT_MTX_EQ_SPARSITY(square_dmtx, square_mtx);
GKO_ASSERT_MTX_NEAR(square_dmtx, square_mtx, r<value_type>::value);
ASSERT_TRUE(gko::clone(ref, square_dmtx)->is_sorted_by_column_index());
ASSERT_TRUE(square_dmtx->is_sorted_by_column_index());
}


Expand All @@ -213,7 +232,7 @@ TEST_F(Csr, SimpleApplyToEmptyCsrMatrixIsEquivalentToRef)

GKO_ASSERT_MTX_EQ_SPARSITY(square_dmtx, square_mtx);
GKO_ASSERT_MTX_NEAR(square_dmtx, square_mtx, r<value_type>::value);
ASSERT_TRUE(gko::clone(ref, square_dmtx)->is_sorted_by_column_index());
ASSERT_TRUE(square_dmtx->is_sorted_by_column_index());
}


Expand All @@ -235,7 +254,33 @@ TEST_F(Csr, AdvancedApplyToIdentityMatrixIsEquivalentToRef)

GKO_ASSERT_MTX_NEAR(b, db, r<value_type>::value);
GKO_ASSERT_MTX_EQ_SPARSITY(b, db);
ASSERT_TRUE(gko::clone(ref, db)->is_sorted_by_column_index());
ASSERT_TRUE(db->is_sorted_by_column_index());
}


TEST_F(Csr, RecognizeSortedMatrixIsEquivalentToRef)
{
set_up_apply_data();
bool is_sorted_dpcpp{};
bool is_sorted_ref{};

is_sorted_ref = mtx->is_sorted_by_column_index();
is_sorted_dpcpp = dmtx->is_sorted_by_column_index();

ASSERT_EQ(is_sorted_ref, is_sorted_dpcpp);
}


TEST_F(Csr, RecognizeUnsortedMatrixIsEquivalentToRef)
{
auto uns_mtx = gen_unsorted_mtx();
bool is_sorted_dpcpp{};
bool is_sorted_ref{};

is_sorted_ref = uns_mtx.ref->is_sorted_by_column_index();
is_sorted_dpcpp = uns_mtx.dpcpp->is_sorted_by_column_index();

ASSERT_EQ(is_sorted_ref, is_sorted_dpcpp);
}


Expand Down

0 comments on commit c25639a

Please sign in to comment.