Skip to content

Commit

Permalink
Improve nth_element performance
Browse files Browse the repository at this point in the history
  • Loading branch information
roshanr95 committed Jul 31, 2014
1 parent d356a2c commit d898dc3
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 6 deletions.
89 changes: 85 additions & 4 deletions include/boost/compute/algorithm/nth_element.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
#define BOOST_COMPUTE_ALGORITHM_NTH_ELEMENT_HPP

#include <boost/compute/command_queue.hpp>
#include <boost/compute/algorithm/fill_n.hpp>
#include <boost/compute/algorithm/find.hpp>
#include <boost/compute/algorithm/partition.hpp>
#include <boost/compute/algorithm/sort.hpp>
#include <boost/compute/detail/read_write_single_value.hpp>
#include <boost/compute/detail/print_range.hpp>

namespace boost {
namespace compute {
Expand All @@ -26,9 +31,47 @@ inline void nth_element(Iterator first,
Compare compare,
command_queue &queue = system::default_queue())
{
(void) nth;
if(nth == last) return;

sort(first, last, compare, queue);
typedef typename std::iterator_traits<Iterator>::value_type value_type;

// Selection loop: each pass picks the element at nth as a pivot,
// partitions the current [first, last) range around it, and then narrows
// the range to the side that still contains nth. The loop exits once the
// pivot value is the value stored at nth.
while(1)
{
// Pivot: the value stored at nth's position, addressed as first's buffer
// index plus the distance from first to nth.
value_type value = detail::read_single_value<value_type>(
first.get_buffer(),
first.get_index()+std::distance(first, nth),
queue
);

// Group the elements for which compare(element, value) holds at the front.
// new_nth is the iterator returned by partition() -- presumably, as with
// std::partition, the start of the second group (confirm against
// boost::compute::partition).
Iterator new_nth = partition(first, last, compare(_1, value), queue);

// Locate an element equal to the pivot value within [new_nth, last).
Iterator old_nth = find(new_nth, last, value, queue);

// Value currently stored at new_nth; it is about to be overwritten.
value_type new_value = detail::read_single_value<value_type>(
first.get_buffer(),
first.get_index()+std::distance(first, new_nth),
queue
);

// Swap the two elements using single-element fills: the pivot value goes
// to new_nth and the value displaced from new_nth goes to old_nth.
fill_n(new_nth, 1, value, queue);
fill_n(old_nth, 1, new_value, queue);

// Re-read the element at nth (new_value is reused as scratch storage).
// If it equals the pivot value, the loop is finished.
new_value = detail::read_single_value<value_type>(
first.get_buffer(),
first.get_index()+std::distance(first, nth),
queue
);
if(value == new_value) break;

// Otherwise continue on the side of new_nth that contains nth:
// [first, new_nth) when nth lies before new_nth, else [new_nth, last).
if(std::distance(first, nth) < std::distance(first, new_nth))
{
last = new_nth;
}
else
{
first = new_nth;
}
}
}

/// \overload
Expand All @@ -38,9 +81,47 @@ inline void nth_element(Iterator first,
Iterator last,
command_queue &queue = system::default_queue())
{
(void) nth;
if(nth == last) return;

typedef typename std::iterator_traits<Iterator>::value_type value_type;

// Selection loop for the default ordering (operator<): each pass picks the
// element at nth as a pivot, partitions the current [first, last) range
// around it, and then narrows the range to the side that still contains
// nth. The loop exits once the pivot value is the value stored at nth.
while(1)
{
// Pivot: the value stored at nth's position, addressed as first's buffer
// index plus the distance from first to nth.
value_type value = detail::read_single_value<value_type>(
first.get_buffer(),
first.get_index()+std::distance(first, nth),
queue
);

// Group the elements that are less than the pivot value at the front.
// new_nth is the iterator returned by partition() -- presumably, as with
// std::partition, the start of the second group (confirm against
// boost::compute::partition).
Iterator new_nth = partition(first, last, _1 < value, queue);

// Locate an element equal to the pivot value within [new_nth, last).
Iterator old_nth = find(new_nth, last, value, queue);

// Value currently stored at new_nth; it is about to be overwritten.
value_type new_value = detail::read_single_value<value_type>(
first.get_buffer(),
first.get_index()+std::distance(first, new_nth),
queue
);

// Swap the two elements using single-element fills: the pivot value goes
// to new_nth and the value displaced from new_nth goes to old_nth.
fill_n(new_nth, 1, value, queue);
fill_n(old_nth, 1, new_value, queue);

// Re-read the element at nth (new_value is reused as scratch storage).
// If it equals the pivot value, the loop is finished.
new_value = detail::read_single_value<value_type>(
first.get_buffer(),
first.get_index()+std::distance(first, nth),
queue
);
if(value == new_value) break;

// NOTE(review): this call sorts the whole current [first, last) range on
// every pass that does not break, and the comparator overload's loop has
// no equivalent statement -- confirm whether it is intended here.
sort(first, last, queue);
// Continue on the side of new_nth that contains nth:
// [first, new_nth) when nth lies before new_nth, else [new_nth, last).
if(std::distance(first, nth) < std::distance(first, new_nth))
{
last = new_nth;
}
else
{
first = new_nth;
}
}
}

} // end compute namespace
Expand Down
30 changes: 28 additions & 2 deletions test/test_nth_element.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@ BOOST_AUTO_TEST_CASE(nth_element_int)
boost::compute::nth_element(
vector.begin(), vector.begin() + 5, vector.end(), queue
);
CHECK_RANGE_EQUAL(int, 5, vector, (1, 1, 4, 4, 9));
BOOST_CHECK_EQUAL(vector[5], 9);

boost::compute::copy_n(data, 10, vector.begin(), queue);

boost::compute::nth_element(
vector.begin(), vector.end(), vector.end(), queue
);
CHECK_RANGE_EQUAL(int, 10, vector, (1, 1, 4, 4, 9, 9, 9, 12, 15, 15));
CHECK_RANGE_EQUAL(int, 10, vector, (9, 15, 1, 4, 9, 9, 4, 15, 12, 1));
}

BOOST_AUTO_TEST_CASE(nth_element_median)
Expand All @@ -49,4 +51,28 @@ BOOST_AUTO_TEST_CASE(nth_element_median)
BOOST_CHECK_EQUAL(data[v.size()/2], 5);
}

// bool less_than(int a, int b) {
// return a < b;
// }

// BOOST_AUTO_TEST_CASE(nth_element_comparator)
// {
// int data[] = { 9, 15, 1, 4, 9, 9, 4, 15, 12, 1 };
// boost::compute::vector<int> vector(10, context);

// boost::compute::copy_n(data, 10, vector.begin(), queue);

// boost::compute::nth_element(
// vector.begin(), vector.begin() + 5, vector.end(), less_than, queue
// );
// BOOST_CHECK_EQUAL(vector[5], 9);

// boost::compute::copy_n(data, 10, vector.begin(), queue);

// boost::compute::nth_element(
// vector.begin(), vector.end(), vector.end(), less_than, queue
// );
// CHECK_RANGE_EQUAL(int, 10, vector, (9, 15, 1, 4, 9, 9, 4, 15, 12, 1));
// }

BOOST_AUTO_TEST_SUITE_END()

0 comments on commit d898dc3

Please sign in to comment.