From f3cbe1800e97c5493d52437daa2489abe9336e04 Mon Sep 17 00:00:00 2001 From: KulikovNikita Date: Wed, 14 Dec 2022 12:01:21 +0000 Subject: [PATCH] Fixes in DBSCAN algorithm (#2170) --- cpp/oneapi/dal/algo/dbscan/BUILD | 3 +- .../dbscan/backend/gpu/kernel_fp_impl.hpp | 64 +++++----- .../dal/algo/dbscan/backend/gpu/results.hpp | 114 +++++++++++------- cpp/oneapi/dal/backend/primitives/ndarray.hpp | 22 ++++ .../dal/backend/primitives/selection.hpp | 1 + 5 files changed, 126 insertions(+), 78 deletions(-) diff --git a/cpp/oneapi/dal/algo/dbscan/BUILD b/cpp/oneapi/dal/algo/dbscan/BUILD index 6b838e55a8b..2fe1fae2e6d 100644 --- a/cpp/oneapi/dal/algo/dbscan/BUILD +++ b/cpp/oneapi/dal/algo/dbscan/BUILD @@ -9,7 +9,8 @@ dal_module( auto = True, dal_deps = [ "@onedal//cpp/oneapi/dal:core", - "@onedal//cpp/oneapi/dal/backend/primitives", + "@onedal//cpp/oneapi/dal/backend/primitives:common", + "@onedal//cpp/oneapi/dal/backend/primitives:selection", ], extra_deps = [ "@onedal//cpp/daal/src/algorithms/dbscan:kernel", diff --git a/cpp/oneapi/dal/algo/dbscan/backend/gpu/kernel_fp_impl.hpp b/cpp/oneapi/dal/algo/dbscan/backend/gpu/kernel_fp_impl.hpp index 8695b140d49..6931ed6fb71 100644 --- a/cpp/oneapi/dal/algo/dbscan/backend/gpu/kernel_fp_impl.hpp +++ b/cpp/oneapi/dal/algo/dbscan/backend/gpu/kernel_fp_impl.hpp @@ -227,53 +227,51 @@ std::int32_t kernels_fp::start_next_cluster(sycl::queue& queue, const pr::ndview& cores, pr::ndview& responses, const bk::event_vector& deps) { + using oneapi::dal::backend::operator+; ONEDAL_PROFILER_TASK(start_next_cluster, queue); ONEDAL_ASSERT(cores.get_dimension(0) > 0); ONEDAL_ASSERT(cores.get_dimension(0) == responses.get_dimension(0)); std::int64_t block_size = cores.get_dimension(0); auto [start_index, start_index_event] = - pr::ndarray::full(queue, 1, block_size); - start_index_event.wait_and_throw(); + pr::ndarray::full(queue, { 1 }, block_size, sycl::usm::alloc::device); auto start_index_ptr = start_index.get_mutable_data(); - start_index_ptr[0] = block_size; const std::int32_t* cores_ptr = cores.get_data(); std::int32_t* responses_ptr = responses.get_mutable_data(); std::int64_t wg_size = get_recommended_sg_size(queue); - queue - .submit([&](sycl::handler& cgh) { - cgh.depends_on(deps); - cgh.parallel_for( - bk::make_multiple_nd_range_2d({ wg_size, 1 }, { wg_size, 1 }), - [=](sycl::nd_item<2> item) { - auto sg = item.get_sub_group(); - const std::uint32_t sg_id = sg.get_group_id()[0]; - if (sg_id > 0) - return; - const std::uint32_t local_id = sg.get_local_id(); - const std::uint32_t local_size = sg.get_local_range()[0]; - std::int32_t adjusted_block_size = - local_size * (block_size / local_size + bool(block_size % local_size)); + auto full_deps = deps + bk::event_vector{ start_index_event }; + auto index_event = queue.submit([&](sycl::handler& cgh) { + cgh.depends_on(full_deps); + cgh.parallel_for( + bk::make_multiple_nd_range_2d({ wg_size, 1 }, { wg_size, 1 }), + [=](sycl::nd_item<2> item) { + auto sg = item.get_sub_group(); + const std::uint32_t sg_id = sg.get_group_id()[0]; + if (sg_id > 0) + return; + const std::int32_t local_id = sg.get_local_id(); + const std::int32_t local_size = sg.get_local_range()[0]; + std::int32_t adjusted_block_size = + local_size * (block_size / local_size + bool(block_size % local_size)); - for (int32_t i = local_id; i < adjusted_block_size; i += local_size) { - const bool found = - i < block_size ? cores_ptr[i] == 1 && responses_ptr[i] < 0 : false; - const std::int32_t index = - sycl::reduce_over_group(sg, - (std::int32_t)(found ? i : block_size), - sycl::ext::oneapi::minimum()); - if (index < block_size) { - if (local_id == 0) { - start_index_ptr[0] = index; - } - break; + for (std::int32_t i = local_id; i < adjusted_block_size; i += local_size) { + const bool found = + i < block_size ? cores_ptr[i] == 1 && responses_ptr[i] < 0 : false; + const std::int32_t index = + sycl::reduce_over_group(sg, + (std::int32_t)(found ? i : block_size), + sycl::ext::oneapi::minimum()); + if (index < block_size) { + if (local_id == 0) { + *start_index_ptr = index; } + break; } - }); - }) - .wait_and_throw(); - return *start_index_ptr; + } + }); + }); + return start_index.to_host(queue, { index_event }).at(0); } sycl::event set_queue_ptr(sycl::queue& queue, diff --git a/cpp/oneapi/dal/algo/dbscan/backend/gpu/results.hpp b/cpp/oneapi/dal/algo/dbscan/backend/gpu/results.hpp index 51a917e6532..1980fccd2d1 100644 --- a/cpp/oneapi/dal/algo/dbscan/backend/gpu/results.hpp +++ b/cpp/oneapi/dal/algo/dbscan/backend/gpu/results.hpp @@ -20,6 +20,7 @@ #include "oneapi/dal/backend/common.hpp" #include "oneapi/dal/backend/primitives/ndarray.hpp" +#include "oneapi/dal/backend/primitives/selection.hpp" #include "oneapi/dal/backend/memory.hpp" #include "oneapi/dal/algo/dbscan/backend/gpu/kernels_fp.hpp" @@ -32,13 +33,55 @@ namespace oneapi::dal::dbscan::backend { using descriptor_t = detail::descriptor_base; using result_t = compute_result; -template +template +inline auto output_core_indices(sycl::queue& queue, + std::int64_t block_size, + std::int64_t core_count, + const pr::ndview& cores, + const bk::event_vector& deps = {}) { + using oneapi::dal::backend::operator+; + + ONEDAL_ASSERT(block_size > 0); + ONEDAL_ASSERT(core_count > 0); + ONEDAL_ASSERT(cores.has_data()); + + auto [res, res_event] = + pr::ndarray::zeros(queue, core_count, sycl::usm::alloc::device); + auto [err, err_event] = pr::ndarray::full(queue, 1, false, sycl::usm::alloc::device); + + auto* const err_ptr = err.get_mutable_data(); + auto* const res_ptr = res.get_mutable_data(); + const auto* const cores_ptr = cores.get_data(); + auto full_deps = deps + bk::event_vector{ err_event, res_event }; + auto event = queue.submit([&](sycl::handler& h) { + h.depends_on(deps); + h.single_task([=]() { + std::int64_t pos = 0; + for (std::int64_t i = 0; i < block_size; i++) { + if (*(cores_ptr + i) > 0) { + if (pos < core_count) { + *(res_ptr + pos) = i; + pos++; + } + else { + *err_ptr = true; + break; + } + } + } + }); + }); + + ONEDAL_ASSERT(err.to_host(queue, { event }).at(0)); + return std::make_tuple(res, event); +} + +template inline auto make_results(sycl::queue& queue, const descriptor_t& desc, const pr::ndarray data, - const pr::ndarray responses, - const pr::ndarray cores, - + const pr::ndarray responses, + const pr::ndarray cores, std::int64_t cluster_count, std::int64_t core_count = -1) { const std::int64_t column_count = data.get_dimension(1); @@ -64,8 +107,7 @@ inline auto make_results(sycl::queue& queue, if (core_count == -1) { core_count = count_cores(queue, cores); } - ONEDAL_ASSERT(block_size >= core_count); - if (core_count == 0) { + else if (core_count == 0) { if (return_core_indices) { results.set_core_observation_indices(dal::homogen_table{}); } @@ -73,46 +115,30 @@ inline auto make_results(sycl::queue& queue, results.set_core_observations(dal::homogen_table{}); } } - if (return_core_indices) { - auto host_indices = array::empty(core_count); - auto host_indices_ptr = host_indices.get_mutable_data(); - std::int64_t pos = 0; - auto host_cores = cores.to_host(queue); - auto host_cores_ptr = host_cores.get_data(); - for (std::int64_t i = 0; i < block_size; i++) { - if (host_cores_ptr[i] > 0) { - ONEDAL_ASSERT(pos < core_count); - host_indices_ptr[pos] = i; - pos++; - } + else { + ONEDAL_ASSERT(core_count > 0); + ONEDAL_ASSERT(block_size >= core_count); + + auto [ids_array, ids_event] = output_core_indices(queue, block_size, core_count, cores); + + if (return_core_indices) { + results.set_core_observation_indices( + dal::homogen_table::wrap(ids_array.flatten(queue, { ids_event }), + core_count, + 1)); } - auto device_indices = - pr::ndarray::empty(queue, core_count, sycl::usm::alloc::device); - dal::detail::memcpy_host2usm(queue, - device_indices.get_mutable_data(), - host_indices_ptr, - core_count * sizeof(std::int32_t)); - results.set_core_observation_indices( - dal::homogen_table::wrap(device_indices.flatten(queue), core_count, 1)); - } - if (return_core_observations) { - auto observations = pr::ndarray::empty(queue, core_count * column_count); - auto observations_ptr = observations.get_mutable_data(); - std::int64_t pos = 0; - auto host_cores = cores.to_host(queue); - auto host_cores_ptr = host_cores.get_data(); - for (std::int64_t i = 0; i < block_size; i++) { - if (host_cores_ptr[i] > 0) { - ONEDAL_ASSERT(pos < core_count * column_count); - bk::memcpy(queue, - observations_ptr + pos * column_count, - data.get_data() + i * column_count, - std::size_t(column_count) * sizeof(Float)); - pos += column_count; - } + if (return_core_observations) { + auto res = pr::ndarray::empty(queue, + { core_count, column_count }, + sycl::usm::alloc::device); + + auto event = pr::select_indexed_rows(queue, ids_array, data, res, { ids_event }); + + results.set_core_observations( + dal::homogen_table::wrap(res.flatten(queue, { event }), + core_count, + column_count)); } - results.set_core_observations( - dal::homogen_table::wrap(observations.flatten(queue), core_count, column_count)); } } return results; diff --git a/cpp/oneapi/dal/backend/primitives/ndarray.hpp b/cpp/oneapi/dal/backend/primitives/ndarray.hpp index 5c4af1f2131..3f0c631c5d1 100644 --- a/cpp/oneapi/dal/backend/primitives/ndarray.hpp +++ b/cpp/oneapi/dal/backend/primitives/ndarray.hpp @@ -303,6 +303,28 @@ class ndview : public ndarray_base { } #endif + template > + T* begin() { + ONEDAL_ASSERT(data_is_mutable_); + return get_mutable_data(); + } + + template > + T* end() { + ONEDAL_ASSERT(data_is_mutable_); + return get_mutable_data() + this->get_count(); + } + + template > + const T* cbegin() const { + return get_data(); + } + + template > + const T* cend() const { + return get_data() + this->get_count(); + } + protected: explicit ndview(const T* data, const shape_t& shape, diff --git a/cpp/oneapi/dal/backend/primitives/selection.hpp b/cpp/oneapi/dal/backend/primitives/selection.hpp index 089bb82793d..842992373f7 100644 --- a/cpp/oneapi/dal/backend/primitives/selection.hpp +++ b/cpp/oneapi/dal/backend/primitives/selection.hpp @@ -18,3 +18,4 @@ #include "oneapi/dal/backend/primitives/selection/kselect_by_rows.hpp" #include "oneapi/dal/backend/primitives/selection/select_indexed.hpp" +#include "oneapi/dal/backend/primitives/selection/select_indexed_rows.hpp"