From 198cec80a434b3d88a993e3d67a18b778ebc07f1 Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Thu, 21 Jul 2016 10:52:06 -0300 Subject: [PATCH 01/15] Use a priority queue (heap) to store the list of candidates while searching. This makes the code more efficient, especially when k is large. For example, for knn, given a list of k candidate neighbors, we need to do 2 fast operations: - know the furthest of them. - insert a new candidate. This is the appropriate situation for using a heap. --- .../neighbor_search/neighbor_search_impl.hpp | 26 +++--- .../neighbor_search/neighbor_search_rules.hpp | 53 +++++++++--- .../neighbor_search_rules_impl.hpp | 81 +++++++++++-------- 3 files changed, 105 insertions(+), 55 deletions(-) diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp index 73560e224db..8d0c694d8d8 100644 --- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp @@ -376,8 +376,7 @@ Search(const MatType& querySet, if (naive) { // Create the helper object for the tree traversal. - RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric, - epsilon); + RuleType rules(*referenceSet, querySet, k, metric, epsilon); // The naive brute-force traversal. for (size_t i = 0; i < querySet.n_cols; ++i) @@ -385,12 +384,13 @@ Search(const MatType& querySet, rules.BaseCase(i, j); baseCases += querySet.n_cols * referenceSet->n_cols; + + rules.GetResults(*neighborPtr, *distancePtr); } else if (singleMode) { // Create the helper object for the tree traversal. - RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric, - epsilon); + RuleType rules(*referenceSet, querySet, k, metric, epsilon); // Create the traverser. 
typename Tree::template SingleTreeTraverser traverser(rules); @@ -404,6 +404,8 @@ Search(const MatType& querySet, Log::Info << rules.Scores() << " node combinations were scored.\n"; Log::Info << rules.BaseCases() << " base cases were calculated.\n"; + + rules.GetResults(*neighborPtr, *distancePtr); } else // Dual-tree recursion. { @@ -415,8 +417,7 @@ Search(const MatType& querySet, Timer::Start("computing_neighbors"); // Create the helper object for the tree traversal. - RuleType rules(*referenceSet, queryTree->Dataset(), *neighborPtr, - *distancePtr, metric, epsilon); + RuleType rules(*referenceSet, queryTree->Dataset(), k, metric, epsilon); // Create the traverser. TraversalType traverser(rules); @@ -429,6 +430,8 @@ Search(const MatType& querySet, Log::Info << rules.Scores() << " node combinations were scored.\n"; Log::Info << rules.BaseCases() << " base cases were calculated.\n"; + rules.GetResults(*neighborPtr, *distancePtr); + delete queryTree; } @@ -541,8 +544,7 @@ Search(Tree* queryTree, // Create the helper object for the traversal. typedef NeighborSearchRules RuleType; - RuleType rules(*referenceSet, querySet, *neighborPtr, distances, metric, - epsilon); + RuleType rules(*referenceSet, querySet, k, metric, epsilon); // Create the traverser. TraversalType traverser(rules); @@ -551,6 +553,8 @@ Search(Tree* queryTree, scores += rules.Scores(); baseCases += rules.BaseCases(); + rules.GetResults(*neighborPtr, distances); + Timer::Stop("computing_neighbors"); // Do we need to map indices? @@ -612,8 +616,8 @@ Search(const size_t k, // Create the helper object for the traversal. 
typedef NeighborSearchRules RuleType; - RuleType rules(*referenceSet, *referenceSet, *neighborPtr, *distancePtr, - metric, epsilon, true /* don't return the same point as nearest neighbor */); + RuleType rules(*referenceSet, *referenceSet, k, metric, epsilon, + true /* don't return the same point as nearest neighbor */); if (naive) { @@ -676,6 +680,8 @@ Search(const size_t k, treeNeedsReset = true; } + rules.GetResults(*neighborPtr, *distancePtr); + Timer::Stop("computing_neighbors"); // Do we need to map the reference indices? diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp index 47a7933dd04..0bcdc49d24d 100644 --- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp +++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp @@ -9,6 +9,8 @@ #define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP #include +#include +#include namespace mlpack { namespace neighbor { @@ -19,11 +21,20 @@ class NeighborSearchRules public: NeighborSearchRules(const typename TreeType::Mat& referenceSet, const typename TreeType::Mat& querySet, - arma::Mat& neighbors, - arma::mat& distances, + const size_t k, MetricType& metric, const double epsilon = 0, const bool sameSet = false); + + /** + * Store the list of candidates for each query point in the given matrices. + * + * @param neighbors Matrix storing lists of neighbors for each query point. + * @param distances Matrix storing distances of neighbors for each query + * point. + */ + void GetResults(arma::Mat& neighbors, arma::mat& distances); + /** * Get the distance from the query point to the reference point. * This will update the "neighbor" matrix with the new point if appropriate @@ -109,11 +120,34 @@ class NeighborSearchRules //! The query set. const typename TreeType::Mat& querySet; - //! The matrix the resultant neighbor indices should be stored in. - arma::Mat& neighbors; - - //! 
The matrix the resultant neighbor distances should be stored in. - arma::mat& distances; + //! Candidate represents a possible candidate neighbor (from the reference + // set). + struct Candidate + { + //! Distance between the reference point and the query point. + double dist; + //! Index of the reference point. + size_t index; + //! Trivial constructor. + Candidate(double d, size_t i) : + dist(d), + index(i) + {}; + //! Compare the distance of two candidates. + friend bool operator<(const Candidate& l, const Candidate& r) + { + return !SortPolicy::IsBetter(r.dist, l.dist); + }; + }; + + //! Use a priority queue to represent the list of candidate neighbors. + typedef std::priority_queue CandidateList; + + //! Set of candidate neighbors for each point. + std::vector candidates; + + //! Number of neighbors to search for. + const size_t k; //! The instantiated metric. MetricType& metric; @@ -146,16 +180,13 @@ class NeighborSearchRules double CalculateBound(TreeType& queryNode) const; /** - * Insert a point into the neighbors and distances matrices; this is a helper - * function. + * Helper function to insert a point into the list of candidate points. * * @param queryIndex Index of point whose neighbors we are inserting into. - * @param pos Position in list to insert into. * @param neighbor Index of reference point which is being inserted. * @param distance Distance from query point to reference point. 
*/ void InsertNeighbor(const size_t queryIndex, - const size_t pos, const size_t neighbor, const double distance); }; diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp index 24f94856f5a..65d258eee63 100644 --- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp +++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp @@ -17,15 +17,13 @@ template NeighborSearchRules::NeighborSearchRules( const typename TreeType::Mat& referenceSet, const typename TreeType::Mat& querySet, - arma::Mat& neighbors, - arma::mat& distances, + const size_t k, MetricType& metric, const double epsilon, const bool sameSet) : referenceSet(referenceSet), querySet(querySet), - neighbors(neighbors), - distances(distances), + k(k), metric(metric), sameSet(sameSet), epsilon(epsilon), @@ -39,8 +37,41 @@ NeighborSearchRules::NeighborSearchRules( // use the this pointer. traversalInfo.LastQueryNode() = (TreeType*) this; traversalInfo.LastReferenceNode() = (TreeType*) this; + + // Let's build the list of candidate neighbors for each query point. + // It will be initialized with k candidates: (WorstDistance, size_t() - 1) + // The list of candidates will be updated when visiting new points with the + // BaseCase() method. 
+ const Candidate def(SortPolicy::WorstDistance(), size_t() - 1); + + std::vector vect(k, def); + CandidateList pqueue(std::less(), std::move(vect)); + + candidates.reserve(querySet.n_cols); + for (size_t i = 0; i < querySet.n_cols; i++) + candidates.push_back(pqueue); } +template +void NeighborSearchRules::GetResults( + arma::Mat& neighbors, + arma::mat& distances) +{ + neighbors.set_size(k, querySet.n_cols); + distances.set_size(k, querySet.n_cols); + + for (size_t i = 0; i < querySet.n_cols; i++) + { + CandidateList& pqueue = candidates[i]; + for (size_t j = 1; j <= k; j++) + { + neighbors(k - j, i) = pqueue.top().index; + distances(k - j, i) = pqueue.top().dist; + pqueue.pop(); + } + } +}; + template inline force_inline // Absolutely MUST be inline so optimizations can happen. double NeighborSearchRules:: @@ -59,16 +90,7 @@ BaseCase(const size_t queryIndex, const size_t referenceIndex) referenceSet.col(referenceIndex)); ++baseCases; - // If this distance is better than any of the current candidates, the - // SortDistance() function will give us the position to insert it into. - arma::vec queryDist = distances.unsafe_col(queryIndex); - arma::Col queryIndices = neighbors.unsafe_col(queryIndex); - const size_t insertPosition = SortPolicy::SortDistance(queryDist, - queryIndices, distance); - - // SortDistance() returns (size_t() - 1) if we shouldn't add it. - if (insertPosition != (size_t() - 1)) - InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance); + InsertNeighbor(queryIndex, referenceIndex, distance); // Cache this information for the next time BaseCase() is called. lastQueryIndex = queryIndex; @@ -114,7 +136,7 @@ inline double NeighborSearchRules::Score( } // Compare against the best k'th distance for this query point so far. 
- double bestDistance = distances(distances.n_rows - 1, queryIndex); + double bestDistance = candidates[queryIndex].top().dist; bestDistance = SortPolicy::Relax(bestDistance, epsilon); return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX; @@ -131,7 +153,7 @@ inline double NeighborSearchRules::Rescore( return oldScore; // Just check the score again against the distances. - double bestDistance = distances(distances.n_rows - 1, queryIndex); + double bestDistance = candidates[queryIndex].top().dist; bestDistance = SortPolicy::Relax(bestDistance, epsilon); return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX; @@ -354,7 +376,7 @@ inline double NeighborSearchRules:: // Loop over points held in the node. for (size_t i = 0; i < queryNode.NumPoints(); ++i) { - const double distance = distances(distances.n_rows - 1, queryNode.Point(i)); + const double distance = candidates[queryNode.Point(i)].top().dist; if (SortPolicy::IsBetter(worstDistance, distance)) worstDistance = distance; if (SortPolicy::IsBetter(distance, bestPointDistance)) @@ -432,35 +454,26 @@ inline double NeighborSearchRules:: } /** - * Helper function to insert a point into the neighbors and distances matrices. + * Helper function to insert a point into the list of candidate points. * * @param queryIndex Index of point whose neighbors we are inserting into. - * @param pos Position in list to insert into. * @param neighbor Index of reference point which is being inserted. * @param distance Distance from query point to reference point. */ template -void NeighborSearchRules::InsertNeighbor( +inline void NeighborSearchRules:: +InsertNeighbor( const size_t queryIndex, - const size_t pos, const size_t neighbor, const double distance) { - // We only memmove() if there is actually a need to shift something. 
- if (pos < (distances.n_rows - 1)) + Candidate c(distance, neighbor); + CandidateList& pqueue = candidates[queryIndex]; + if (c < pqueue.top()) { - int len = (distances.n_rows - 1) - pos; - memmove(distances.colptr(queryIndex) + (pos + 1), - distances.colptr(queryIndex) + pos, - sizeof(double) * len); - memmove(neighbors.colptr(queryIndex) + (pos + 1), - neighbors.colptr(queryIndex) + pos, - sizeof(size_t) * len); + pqueue.pop(); + pqueue.push(c); } - - // Now put the new information in the right index. - distances(pos, queryIndex) = distance; - neighbors(pos, queryIndex) = neighbor; } } // namespace neighbor From 9b03a968285e466e38fd85399fdc3da0edfbafbe Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Thu, 21 Jul 2016 11:08:22 -0300 Subject: [PATCH 02/15] Add more documentation for NeighborSearchRules. --- .../neighbor_search/neighbor_search_rules.hpp | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp index 0bcdc49d24d..8dead3389fa 100644 --- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp +++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp @@ -15,10 +15,33 @@ namespace mlpack { namespace neighbor { +/** + * The NeighborSearchRules class is a template helper class used by + * NeighborSearch class when performing distance-based neighbor searches. For + * each point in the query dataset, it keeps track of the k neighbors in the + * reference dataset which have the 'best' distance according to a given sorting + * policy. + * + * @tparam SortPolicy The sort policy for distances. + * @tparam MetricType The metric to use for computation. + * @tparam TreeType The tree type to use; must adhere to the TreeType API. + */ template class NeighborSearchRules { public: + /** + * Construct the NeighborSearchRules object. This is usually done from within + * the NeighborSearch class at search time. 
+ * + * @param referenceSet Set of reference data. + * @param querySet Set of query data. + * @param k Number of neighbors to search for. + * @param metric Instantiated metric. + * @param epsilon Relative approximate error. + * @param sameSet If true, the query and reference set are taken to be the + * same, and a query point will not return itself in the results. + */ NeighborSearchRules(const typename TreeType::Mat& referenceSet, const typename TreeType::Mat& querySet, const size_t k, From 07c13815a10c70ec57c16f6665d0a155071e9a74 Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Thu, 21 Jul 2016 11:09:11 -0300 Subject: [PATCH 03/15] Remove unnecessary fill (they will be filled when calling GetResults()). --- src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp index 8d0c694d8d8..79a31d03860 100644 --- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp @@ -367,9 +367,7 @@ Search(const MatType& querySet, // Set the size of the neighbor and distance matrices. neighborPtr->set_size(k, querySet.n_cols); - neighborPtr->fill(size_t() - 1); distancePtr->set_size(k, querySet.n_cols); - distancePtr->fill(SortPolicy::WorstDistance()); typedef NeighborSearchRules RuleType; @@ -538,9 +536,7 @@ Search(Tree* queryTree, neighborPtr = new arma::Mat; neighborPtr->set_size(k, querySet.n_cols); - neighborPtr->fill(size_t() - 1); distances.set_size(k, querySet.n_cols); - distances.fill(SortPolicy::WorstDistance()); // Create the helper object for the traversal. typedef NeighborSearchRules RuleType; @@ -610,9 +606,7 @@ Search(const size_t k, // Initialize results. 
neighborPtr->set_size(k, referenceSet->n_cols); - neighborPtr->fill(size_t() - 1); distancePtr->set_size(k, referenceSet->n_cols); - distancePtr->fill(SortPolicy::WorstDistance()); // Create the helper object for the traversal. typedef NeighborSearchRules RuleType; From 31aeb5e813413505150a14be5c08979920c0d261 Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Thu, 21 Jul 2016 12:52:40 -0300 Subject: [PATCH 04/15] Use a priority queue (heap) to store the list of candidates while searching. --- src/mlpack/methods/rann/ra_search_impl.hpp | 35 ++++---- src/mlpack/methods/rann/ra_search_rules.hpp | 52 ++++++++--- .../methods/rann/ra_search_rules_impl.hpp | 90 +++++++++++-------- 3 files changed, 112 insertions(+), 65 deletions(-) diff --git a/src/mlpack/methods/rann/ra_search_impl.hpp b/src/mlpack/methods/rann/ra_search_impl.hpp index aa8daa5010d..16360b5ee09 100644 --- a/src/mlpack/methods/rann/ra_search_impl.hpp +++ b/src/mlpack/methods/rann/ra_search_impl.hpp @@ -360,9 +360,8 @@ Search(const MatType& querySet, if (naive) { - RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric, - tau, alpha, naive, sampleAtLeaves, firstLeafExact, - singleSampleLimit, false); + RuleType rules(*referenceSet, querySet, k, metric, tau, alpha, naive, + sampleAtLeaves, firstLeafExact, singleSampleLimit, false); // Find how many samples from the reference set we need and sample uniformly // from the reference set without replacement. 
@@ -377,12 +376,13 @@ Search(const MatType& querySet, for (size_t i = 0; i < querySet.n_cols; ++i) for (size_t j = 0; j < distinctSamples.n_elem; ++j) rules.BaseCase(i, (size_t) distinctSamples[j]); + + rules.GetResults(*neighborPtr, *distancePtr); } else if (singleMode) { - RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric, - tau, alpha, naive, sampleAtLeaves, firstLeafExact, - singleSampleLimit, false); + RuleType rules(*referenceSet, querySet, k, metric, tau, alpha, naive, + sampleAtLeaves, firstLeafExact, singleSampleLimit, false); // If the reference root node is a leaf, then the sampling has already been // done in the RASearchRules constructor. This happens when naive = true. @@ -402,6 +402,8 @@ Search(const MatType& querySet, << (rules.NumDistComputations() / querySet.n_cols) << "." << std::endl; } + + rules.GetResults(*neighborPtr, *distancePtr); } else // Dual-tree recursion. { @@ -415,9 +417,8 @@ Search(const MatType& querySet, Timer::Stop("tree_building"); Timer::Start("computing_neighbors"); - RuleType rules(*referenceSet, queryTree->Dataset(), *neighborPtr, - *distancePtr, metric, tau, alpha, naive, sampleAtLeaves, - firstLeafExact, singleSampleLimit, false); + RuleType rules(*referenceSet, queryTree->Dataset(), k, metric, tau, alpha, + naive, sampleAtLeaves, firstLeafExact, singleSampleLimit, false); typename Tree::template DualTreeTraverser traverser(rules); Log::Info << "Query statistic pre-search: " @@ -429,6 +430,8 @@ Search(const MatType& querySet, Log::Info << "Average number of distance calculations per query point: " << (rules.NumDistComputations() / querySet.n_cols) << "." << std::endl; + rules.GetResults(*neighborPtr, *distancePtr); + delete queryTree; } @@ -529,14 +532,15 @@ void RASearch::Search( // Create the helper object for the tree traversal. 
typedef RASearchRules RuleType; - RuleType rules(*referenceSet, queryTree->Dataset(), *neighborPtr, distances, - metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact, - singleSampleLimit, false); + RuleType rules(*referenceSet, queryTree->Dataset(), k, metric, tau, alpha, + naive, sampleAtLeaves, firstLeafExact, singleSampleLimit, false); // Create the traverser. typename Tree::template DualTreeTraverser traverser(rules); traverser.Traverse(*queryTree, *referenceTree); + rules.GetResults(*neighborPtr, distances); + Timer::Stop("computing_neighbors"); // Do we need to map indices? @@ -586,9 +590,8 @@ void RASearch::Search( // Create the helper object for the tree traversal. typedef RASearchRules RuleType; - RuleType rules(*referenceSet, *referenceSet, *neighborPtr, *distancePtr, - metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact, - singleSampleLimit, true /* sets are the same */); + RuleType rules(*referenceSet, *referenceSet, k, metric, tau, alpha, naive, + sampleAtLeaves, firstLeafExact, singleSampleLimit, true /* same sets */); if (naive) { @@ -622,6 +625,8 @@ void RASearch::Search( traverser.Traverse(*referenceTree, *referenceTree); } + rules.GetResults(*neighborPtr, *distancePtr); + Timer::Stop("computing_neighbors"); // Do we need to map the reference indices? 
diff --git a/src/mlpack/methods/rann/ra_search_rules.hpp b/src/mlpack/methods/rann/ra_search_rules.hpp index 1037af4b821..b04f9bca367 100644 --- a/src/mlpack/methods/rann/ra_search_rules.hpp +++ b/src/mlpack/methods/rann/ra_search_rules.hpp @@ -10,6 +10,8 @@ #define MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP #include +#include +#include namespace mlpack { namespace neighbor { @@ -20,8 +22,7 @@ class RASearchRules public: RASearchRules(const arma::mat& referenceSet, const arma::mat& querySet, - arma::Mat& neighbors, - arma::mat& distances, + const size_t k, MetricType& metric, const double tau = 5, const double alpha = 0.95, @@ -31,6 +32,15 @@ class RASearchRules const size_t singleSampleLimit = 20, const bool sameSet = false); + /** + * Store the list of candidates for each query point in the given matrices. + * + * @param neighbors Matrix storing lists of neighbors for each query point. + * @param distances Matrix storing distances of neighbors for each query + * point. + */ + void GetResults(arma::Mat& neighbors, arma::mat& distances); + double BaseCase(const size_t queryIndex, const size_t referenceIndex); /** @@ -197,11 +207,34 @@ class RASearchRules //! The query set. const arma::mat& querySet; - //! The matrix the resultant neighbor indices should be stored in. - arma::Mat& neighbors; - - //! The matrix the resultant neighbor distances should be stored in. - arma::mat& distances; + //! Candidate represents a possible candidate neighbor (from the reference + // set). + struct Candidate + { + //! Distance between the reference point and the query point. + double dist; + //! Index of the reference point. + size_t index; + //! Trivial constructor. + Candidate(double d, size_t i) : + dist(d), + index(i) + {}; + //! Compare the distance of two candidates. + friend bool operator<(const Candidate& l, const Candidate& r) + { + return !SortPolicy::IsBetter(r.dist, l.dist); + }; + }; + + //! Use a priority queue to represent the list of candidate neighbors. 
+ typedef std::priority_queue CandidateList; + + //! Set of candidate neighbors for each point. + std::vector candidates; + + //! Number of neighbors to search for. + const size_t k; //! The instantiated metric. MetricType& metric; @@ -233,16 +266,13 @@ class RASearchRules TraversalInfoType traversalInfo; /** - * Insert a point into the neighbors and distances matrices; this is a helper - * function. + * Helper function to insert a point into the list of candidate points. * * @param queryIndex Index of point whose neighbors we are inserting into. - * @param pos Position in list to insert into. * @param neighbor Index of reference point which is being inserted. * @param distance Distance from query point to reference point. */ void InsertNeighbor(const size_t queryIndex, - const size_t pos, const size_t neighbor, const double distance); diff --git a/src/mlpack/methods/rann/ra_search_rules_impl.hpp b/src/mlpack/methods/rann/ra_search_rules_impl.hpp index 2071de1b98e..bad3e24932a 100644 --- a/src/mlpack/methods/rann/ra_search_rules_impl.hpp +++ b/src/mlpack/methods/rann/ra_search_rules_impl.hpp @@ -17,8 +17,7 @@ template RASearchRules:: RASearchRules(const arma::mat& referenceSet, const arma::mat& querySet, - arma::Mat& neighbors, - arma::mat& distances, + const size_t k, MetricType& metric, const double tau, const double alpha, @@ -29,8 +28,7 @@ RASearchRules(const arma::mat& referenceSet, const bool sameSet) : referenceSet(referenceSet), querySet(querySet), - neighbors(neighbors), - distances(distances), + k(k), metric(metric), sampleAtLeaves(sampleAtLeaves), firstLeafExact(firstLeafExact), @@ -42,7 +40,6 @@ RASearchRules(const arma::mat& referenceSet, // The rank approximation. 
const size_t n = referenceSet.n_cols; - const size_t k = neighbors.n_rows; const size_t t = (size_t) std::ceil(tau * (double) n / 100.0); if (t < k) { @@ -68,7 +65,20 @@ RASearchRules(const arma::mat& referenceSet, Log::Info << "Minimum samples required per query: " << numSamplesReqd << ", sampling ratio: " << samplingRatio << std::endl; - if (naive) // No tree traversal; just do naive sampling here. + // Let's build the list of candidate neighbors for each query point. + // It will be initialized with k candidates: (WorstDistance, size_t() - 1) + // The list of candidates will be updated when visiting new points with the + // BaseCase() method. + const Candidate def(SortPolicy::WorstDistance(), size_t() - 1); + + std::vector vect(k, def); + CandidateList pqueue(std::less(), std::move(vect)); + + candidates.reserve(querySet.n_cols); + for (size_t i = 0; i < querySet.n_cols; i++) + candidates.push_back(pqueue); + + if (naive)// No tree traversal; just do naive sampling here. { // Sample enough points. for (size_t i = 0; i < querySet.n_cols; ++i) @@ -81,6 +91,26 @@ RASearchRules(const arma::mat& referenceSet, } } +template +void RASearchRules::GetResults( + arma::Mat& neighbors, + arma::mat& distances) +{ + neighbors.set_size(k, querySet.n_cols); + distances.set_size(k, querySet.n_cols); + + for (size_t i = 0; i < querySet.n_cols; i++) + { + CandidateList& pqueue = candidates[i]; + for (size_t j = 1; j <= k; j++) + { + neighbors(k - j, i) = pqueue.top().index; + distances(k - j, i) = pqueue.top().dist; + pqueue.pop(); + } + } +}; + template inline force_inline double RASearchRules::BaseCase( @@ -95,16 +125,7 @@ double RASearchRules::BaseCase( double distance = metric.Evaluate(querySet.unsafe_col(queryIndex), referenceSet.unsafe_col(referenceIndex)); - // If this distance is better than any of the current candidates, the - // SortDistance() function will give us the position to insert it into. 
- arma::vec queryDist = distances.unsafe_col(queryIndex); - arma::Col queryIndices = neighbors.unsafe_col(queryIndex); - size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices, - distance); - - // SortDistance() returns (size_t() - 1) if we shouldn't add it. - if (insertPosition != (size_t() - 1)) - InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance); + InsertNeighbor(queryIndex, referenceIndex, distance); numSamplesMade[queryIndex]++; @@ -122,7 +143,7 @@ inline double RASearchRules::Score( const arma::vec queryPoint = querySet.unsafe_col(queryIndex); const double distance = SortPolicy::BestPointToNodeDistance(queryPoint, &referenceNode); - const double bestDistance = distances(distances.n_rows - 1, queryIndex); + const double bestDistance = candidates[queryIndex].top().dist; return Score(queryIndex, referenceNode, distance, bestDistance); } @@ -136,7 +157,7 @@ inline double RASearchRules::Score( const arma::vec queryPoint = querySet.unsafe_col(queryIndex); const double distance = SortPolicy::BestPointToNodeDistance(queryPoint, &referenceNode, baseCaseResult); - const double bestDistance = distances(distances.n_rows - 1, queryIndex); + const double bestDistance = candidates[queryIndex].top().dist; return Score(queryIndex, referenceNode, distance, bestDistance); } @@ -250,7 +271,7 @@ Rescore(const size_t queryIndex, return oldScore; // Just check the score again against the distances. - const double bestDistance = distances(distances.n_rows - 1, queryIndex); + const double bestDistance = candidates[queryIndex].top().dist; // If this is better than the best distance we've seen so far, // maybe there will be something down this node. 
@@ -350,7 +371,7 @@ inline double RASearchRules::Score( for (size_t i = 0; i < queryNode.NumPoints(); i++) { - const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) + const double bound = candidates[queryNode.Point(i)].top().dist + maxDescendantDistance; if (bound < pointBound) pointBound = bound; @@ -389,7 +410,7 @@ inline double RASearchRules::Score( for (size_t i = 0; i < queryNode.NumPoints(); i++) { - const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) + const double bound = candidates[queryNode.Point(i)].top().dist + maxDescendantDistance; if (bound < pointBound) pointBound = bound; @@ -603,7 +624,7 @@ Rescore(TreeType& queryNode, for (size_t i = 0; i < queryNode.NumPoints(); i++) { - const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) + const double bound = candidates[queryNode.Point(i)].top().dist + maxDescendantDistance; if (bound < pointBound) pointBound = bound; @@ -775,35 +796,26 @@ Rescore(TreeType& queryNode, } // Rescore(node, node, oldScore) /** - * Helper function to insert a point into the neighbors and distances matrices. + * Helper function to insert a point into the list of candidate points. * * @param queryIndex Index of point whose neighbors we are inserting into. - * @param pos Position in list to insert into. * @param neighbor Index of reference point which is being inserted. * @param distance Distance from query point to reference point. */ template -void RASearchRules::InsertNeighbor( +inline void RASearchRules:: +InsertNeighbor( const size_t queryIndex, - const size_t pos, const size_t neighbor, const double distance) { - // We only memmove() if there is actually a need to shift something. 
- if (pos < (distances.n_rows - 1)) + Candidate c(distance, neighbor); + CandidateList& pqueue = candidates[queryIndex]; + if (c < pqueue.top()) { - int len = (distances.n_rows - 1) - pos; - memmove(distances.colptr(queryIndex) + (pos + 1), - distances.colptr(queryIndex) + pos, - sizeof(double) * len); - memmove(neighbors.colptr(queryIndex) + (pos + 1), - neighbors.colptr(queryIndex) + pos, - sizeof(size_t) * len); + pqueue.pop(); + pqueue.push(c); } - - // Now put the new information in the right index. - distances(pos, queryIndex) = distance; - neighbors(pos, queryIndex) = neighbor; } } // namespace neighbor From 15ec97921fb2741dfa8f549201e1899a8a19e3a3 Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Thu, 21 Jul 2016 12:53:57 -0300 Subject: [PATCH 05/15] Remove unnecessary fill (they will be filled when calling GetResults()). --- src/mlpack/methods/rann/ra_search_impl.hpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/mlpack/methods/rann/ra_search_impl.hpp b/src/mlpack/methods/rann/ra_search_impl.hpp index 16360b5ee09..6fa366e2ce3 100644 --- a/src/mlpack/methods/rann/ra_search_impl.hpp +++ b/src/mlpack/methods/rann/ra_search_impl.hpp @@ -354,7 +354,6 @@ Search(const MatType& querySet, // Set the size of the neighbor and distance matrices. neighborPtr->set_size(k, querySet.n_cols); distancePtr->set_size(k, querySet.n_cols); - distancePtr->fill(SortPolicy::WorstDistance()); typedef RASearchRules RuleType; @@ -526,9 +525,7 @@ void RASearch::Search( neighborPtr = new arma::Mat; neighborPtr->set_size(k, querySet.n_cols); - neighborPtr->fill(size_t() - 1); distances.set_size(k, querySet.n_cols); - distances.fill(SortPolicy::WorstDistance()); // Create the helper object for the tree traversal. typedef RASearchRules RuleType; @@ -584,9 +581,7 @@ void RASearch::Search( // Initialize results. 
neighborPtr->set_size(k, referenceSet->n_cols); - neighborPtr->fill(size_t() - 1); distancePtr->set_size(k, referenceSet->n_cols); - distancePtr->fill(SortPolicy::WorstDistance()); // Create the helper object for the tree traversal. typedef RASearchRules RuleType; From ff5c089ff6a84c9aa8e11006dd05353c7926aabf Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Thu, 21 Jul 2016 13:03:35 -0300 Subject: [PATCH 06/15] Add more documentation to the RASearchRules class. --- src/mlpack/methods/rann/ra_search_rules.hpp | 36 +++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/src/mlpack/methods/rann/ra_search_rules.hpp b/src/mlpack/methods/rann/ra_search_rules.hpp index b04f9bca367..c7edeba9ebc 100644 --- a/src/mlpack/methods/rann/ra_search_rules.hpp +++ b/src/mlpack/methods/rann/ra_search_rules.hpp @@ -16,10 +16,39 @@ namespace mlpack { namespace neighbor { +/** + * The RASearchRules class is a template helper class used by RASearch class + * when performing rank-approximate search via random-sampling. + * + * @tparam SortPolicy The sort policy for distances. + * @tparam MetricType The metric to use for computation. + * @tparam TreeType The tree type to use; must adhere to the TreeType API. + */ template class RASearchRules { public: + /** + * Construct the RASearchRules object. This is usually done from within + * the RASearch class at search time. + * + * @param referenceSet Set of reference data. + * @param querySet Set of query data. + * @param k Number of neighbors to search for. + * @param metric Instantiated metric. + * @param tau The rank-approximation in percentile of the data. + * @param alpha The desired success probability. + * @param naive If true, the rank-approximate search will be performed by + * directly sampling the whole set instead of using the stratified + * sampling on the tree. + * @param sampleAtLeaves Sample at leaves for faster but less accurate + * computation. 
+ * @param firstLeafExact Traverse to the first leaf without approximation. + * @param singleSampleLimit The limit on the largest node that can be + * approximated by sampling. + * @param sameSet If true, the query and reference set are taken to be the + * same, and a query point will not return itself in the results. + */ RASearchRules(const arma::mat& referenceSet, const arma::mat& querySet, const size_t k, @@ -41,6 +70,13 @@ class RASearchRules */ void GetResults(arma::Mat& neighbors, arma::mat& distances); + /** + * Get the distance from the query point to the reference point. + * This will update the list of candidates with the new point if appropriate. + * + * @param queryIndex Index of query point. + * @param referenceIndex Index of reference point. + */ double BaseCase(const size_t queryIndex, const size_t referenceIndex); /** From e52aae63607c61871ff7a331469aa5247d3b5a21 Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Thu, 21 Jul 2016 13:04:36 -0300 Subject: [PATCH 07/15] Detail in the documentation of BaseCase(). --- src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp index 8dead3389fa..25e717529cd 100644 --- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp +++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp @@ -60,7 +60,7 @@ class NeighborSearchRules /** * Get the distance from the query point to the reference point. - * This will update the "neighbor" matrix with the new point if appropriate + * This will update the list of candidates with the new point if appropriate * and will track the number of base cases (number of points evaluated). * * @param queryIndex Index of query point. 
From b64e36591fcd53db0bddba6a7c2a9119eb1125f9 Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Thu, 21 Jul 2016 16:28:53 -0300 Subject: [PATCH 08/15] Use a priority queue (heap) to search for the best valued recommendations. --- src/mlpack/methods/cf/cf.cpp | 75 ++++++++++++------------------------ src/mlpack/methods/cf/cf.hpp | 34 ++++++++-------- 2 files changed, 43 insertions(+), 66 deletions(-) diff --git a/src/mlpack/methods/cf/cf.cpp b/src/mlpack/methods/cf/cf.cpp index daf121c3101..121cf139c4c 100644 --- a/src/mlpack/methods/cf/cf.cpp +++ b/src/mlpack/methods/cf/cf.cpp @@ -9,6 +9,8 @@ * specified data set. */ #include "cf.hpp" +#include +#include namespace mlpack { namespace cf { @@ -79,9 +81,8 @@ void CF::GetRecommendations(const size_t numRecs, // Generate recommendations for each query user by finding the maximum numRecs // elements in the averages matrix. recommendations.set_size(numRecs, users.n_elem); - recommendations.fill(cleanedData.n_rows); // Invalid item number. arma::mat values(numRecs, users.n_elem); - values.fill(-DBL_MAX); // The smallest possible value. + for (size_t i = 0; i < users.n_elem; i++) { // First, calculate average of neighborhood values. @@ -92,6 +93,14 @@ void CF::GetRecommendations(const size_t numRecs, averages += w * h.col(neighborhood(j, i)); averages /= neighborhood.n_rows; + // Let's build the list of candidate recommendations for the given user. + // Default candidate: the smallest possible value and invalid item number. + const Candidate def(-DBL_MAX, cleanedData.n_rows); + std::vector vect(numRecs, def); + typedef std::priority_queue, + std::greater> CandidateList; + CandidateList pqueue(std::greater(), std::move(vect)); + // Look through the averages column corresponding to the current user. for (size_t j = 0; j < averages.n_rows; ++j) { @@ -99,29 +108,27 @@ void CF::GetRecommendations(const size_t numRecs, if (cleanedData(j, users(i)) != 0.0) continue; // The user already rated the item. 
+ Candidate c(averages[j], j); + // Is the estimated value better than the worst candidate? - const double value = averages[j]; - if (value > values(values.n_rows - 1, i)) + if (c > pqueue.top()) { - // It should be inserted. Which position? - size_t insertPosition = values.n_rows - 1; - while (insertPosition > 0) - { - if (value <= values(insertPosition - 1, i)) - break; // The current value is the right one. - insertPosition--; - } - - // Now insert it into the list. - InsertNeighbor(i, insertPosition, j, value, recommendations, - values); + pqueue.pop(); + pqueue.push(c); } } + for (size_t p = 1; p <= numRecs; p++) + { + recommendations(numRecs - p, i) = pqueue.top().item; + values(numRecs - p, i) = pqueue.top().value; + pqueue.pop(); + } + // If we were not able to come up with enough recommendations, issue a // warning. - if (recommendations(values.n_rows - 1, i) == cleanedData.n_rows + 1) - Log::Warn << "Could not provide " << values.n_rows << " recommendations " + if (recommendations(numRecs - 1, i) == def.item) + Log::Warn << "Could not provide " << numRecs << " recommendations " << "for user " << users(i) << " (not enough un-rated items)!" << std::endl; } @@ -247,37 +254,5 @@ void CF::CleanData(const arma::mat& data, arma::sp_mat& cleanedData) cleanedData = arma::sp_mat(locations, values, maxItemID, maxUserID); } -/** - * Helper function to insert a point into the recommendation matrices. - * - * @param queryIndex Index of point whose recommendations we are inserting into. - * @param pos Position in list to insert into. - * @param neighbor Index of item being inserted as a recommendation. - * @param value Value of recommendation. - */ -void CF::InsertNeighbor(const size_t queryIndex, - const size_t pos, - const size_t neighbor, - const double value, - arma::Mat& recommendations, - arma::mat& values) const -{ - // We only memmove() if there is actually a need to shift something. 
- if (pos < (recommendations.n_rows - 1)) - { - const int len = (values.n_rows - 1) - pos; - memmove(values.colptr(queryIndex) + (pos + 1), - values.colptr(queryIndex) + pos, - sizeof(double) * len); - memmove(recommendations.colptr(queryIndex) + (pos + 1), - recommendations.colptr(queryIndex) + pos, - sizeof(size_t) * len); - } - - // Now put the new information in the right index. - values(pos, queryIndex) = value; - recommendations(pos, queryIndex) = neighbor; -} - } // namespace mlpack } // namespace cf diff --git a/src/mlpack/methods/cf/cf.hpp b/src/mlpack/methods/cf/cf.hpp index 22cb6fc3f30..42b1a4b7673 100644 --- a/src/mlpack/methods/cf/cf.hpp +++ b/src/mlpack/methods/cf/cf.hpp @@ -258,22 +258,24 @@ class CF //! Cleaned data matrix. arma::sp_mat cleanedData; - /** - * Helper function to insert a point into the recommendation matrices. - * - * @param queryIndex Index of point whose recommendations we are inserting - * into. - * @param pos Position in list to insert into. - * @param neighbor Index of item being inserted as a recommendation. - * @param value Value of recommendation. - */ - void InsertNeighbor(const size_t queryIndex, - const size_t pos, - const size_t neighbor, - const double value, - arma::Mat& recommendations, - arma::mat& values) const; - + //! Candidate represents a possible recommendation. + struct Candidate + { + //! Value of this recommendation. + double value; + //! Item of this recommendation. + size_t item; + //! Trivial constructor. + Candidate(double value, size_t item) : + value(value), + item(item) + {}; + //! Compare the value of two candidates. + friend bool operator>(const Candidate& l, const Candidate& r) + { + return l.value > r.value; + }; + }; }; // class CF } // namespace cf From fdeeb88bea641cce93f83556b0b1b146f55c1383 Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Thu, 21 Jul 2016 22:14:12 -0300 Subject: [PATCH 09/15] Use a priority queue (heap) to store the list of candidates while searching. 
--- src/mlpack/methods/lsh/lsh_search.hpp | 48 ++++++----- src/mlpack/methods/lsh/lsh_search_impl.hpp | 98 ++++++++++------------ 2 files changed, 72 insertions(+), 74 deletions(-) diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp index 4e6cc97b3d9..45284bae294 100644 --- a/src/mlpack/methods/lsh/lsh_search.hpp +++ b/src/mlpack/methods/lsh/lsh_search.hpp @@ -41,6 +41,7 @@ #include #include #include +#include #include #include @@ -298,11 +299,13 @@ class LSHSearch * @param queryIndex The index of the query in question * @param referenceIndices The vector of indices of candidate neighbors for * the query. + * @param k Number of neighbors to search for. * @param neighbors Matrix holding output neighbors. * @param distances Matrix holding output distances. */ void BaseCase(const size_t queryIndex, const arma::uvec& referenceIndices, + const size_t k, arma::Mat& neighbors, arma::mat& distances) const; @@ -315,37 +318,18 @@ class LSHSearch * @param queryIndex The index of the query in question * @param referenceIndices The vector of indices of candidate neighbors for * the query. + * @param k Number of neighbors to search for. * @param querySet Set of query points. * @param neighbors Matrix holding output neighbors. * @param distances Matrix holding output distances. */ void BaseCase(const size_t queryIndex, const arma::uvec& referenceIndices, + const size_t k, const arma::mat& querySet, arma::Mat& neighbors, arma::mat& distances) const; - /** - * This is a helper function that efficiently inserts better neighbor - * candidates into an existing set of neighbor candidates. This function is - * only called by the 'BaseCase' function. - * - * @param distances Matrix holding output distances. - * @param neighbors Matrix holding output neighbors. - * @param queryIndex This is the index of the query being processed currently - * @param pos The position of the neighbor candidate in the current list of - * neighbor candidates. 
- * @param neighbor The neighbor candidate that is being inserted into the list - * of the best 'k' candidates for the query in question. - * @param distance The distance of the query to the neighbor candidate. - */ - void InsertNeighbor(arma::mat& distances, - arma::Mat& neighbors, - const size_t queryIndex, - const size_t pos, - const size_t neighbor, - const double distance) const; - /** * This function implements the core idea behind Multiprobe LSH. It is called * by ReturnIndicesFromTables when T > 0. Given a query's code and its @@ -444,6 +428,28 @@ class LSHSearch //! The number of distance evaluations. size_t distanceEvaluations; + //! Candidate represents a possible candidate neighbor (from the reference + // set). + struct Candidate + { + //! Distance between the reference point and the query point. + double dist; + //! Index of the reference point. + size_t index; + //! Trivial constructor. + Candidate(double d, size_t i) : + dist(d), + index(i) + {}; + //! Compare the distance of two candidates. + friend bool operator<(const Candidate& l, const Candidate& r) + { + return !SortPolicy::IsBetter(r.dist, l.dist); + }; + }; + + //! Use a priority queue to represent the list of candidate neighbors. + typedef std::priority_queue CandidateList; }; // class LSHSearch } // namespace neighbor diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp index bcb57953bd9..d0b53aeebe8 100644 --- a/src/mlpack/methods/lsh/lsh_search_impl.hpp +++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp @@ -262,40 +262,22 @@ void LSHSearch::Train(const arma::mat& referenceSet, << std::endl; } -template -void LSHSearch::InsertNeighbor(arma::mat& distances, - arma::Mat& neighbors, - const size_t queryIndex, - const size_t pos, - const size_t neighbor, - const double distance) const -{ - // We only memmove() if there is actually a need to shift something. 
- if (pos < (distances.n_rows - 1)) - { - const size_t len = (distances.n_rows - 1) - pos; - memmove(distances.colptr(queryIndex) + (pos + 1), - distances.colptr(queryIndex) + pos, - sizeof(double) * len); - memmove(neighbors.colptr(queryIndex) + (pos + 1), - neighbors.colptr(queryIndex) + pos, - sizeof(size_t) * len); - } - - // Now put the new information in the right index. - distances(pos, queryIndex) = distance; - neighbors(pos, queryIndex) = neighbor; -} - // Base case where the query set is the reference set. (So, we can't return // ourselves as the nearest neighbor.) template inline force_inline void LSHSearch::BaseCase(const size_t queryIndex, const arma::uvec& referenceIndices, + const size_t k, arma::Mat& neighbors, arma::mat& distances) const { + // Let's build the list of candidate neighbors for the given query point. + // It will be initialized with k candidates: + // (WorstDistance, referenceSet->n_cols) + const Candidate def(SortPolicy::WorstDistance(), referenceSet->n_cols); + std::vector vect(k, def); + CandidateList pqueue(std::less(), std::move(vect)); for (size_t j = 0; j < referenceIndices.n_elem; ++j) { @@ -308,17 +290,20 @@ void LSHSearch::BaseCase(const size_t queryIndex, referenceSet->unsafe_col(queryIndex), referenceSet->unsafe_col(referenceIndex)); - // If this distance is better than any of the current candidates, the - // SortDistance() function will give us the position to insert it into. - arma::vec queryDist = distances.unsafe_col(queryIndex); - arma::Col queryIndices = neighbors.unsafe_col(queryIndex); - size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices, - distance); - - // SortDistance() returns (size_t() - 1) if we shouldn't add it. - if (insertPosition != (size_t() - 1)) - InsertNeighbor(distances, neighbors, queryIndex, insertPosition, - referenceIndex, distance); + Candidate c(distance, referenceIndex); + // If this distance is better than the worst candidate, let's insert it. 
+ if (c < pqueue.top()) + { + pqueue.pop(); + pqueue.push(c); + } + } + + for (size_t j = 1; j <= k; j++) + { + neighbors(k - j, queryIndex) = pqueue.top().index; + distances(k - j, queryIndex) = pqueue.top().dist; + pqueue.pop(); } } @@ -327,10 +312,18 @@ template inline force_inline void LSHSearch::BaseCase(const size_t queryIndex, const arma::uvec& referenceIndices, + const size_t k, const arma::mat& querySet, arma::Mat& neighbors, arma::mat& distances) const { + // Let's build the list of candidate neighbors for the given query point. + // It will be initialized with k candidates: + // (WorstDistance, referenceSet->n_cols) + const Candidate def(SortPolicy::WorstDistance(), referenceSet->n_cols); + std::vector vect(k, def); + CandidateList pqueue(std::less(), std::move(vect)); + for (size_t j = 0; j < referenceIndices.n_elem; ++j) { const size_t referenceIndex = referenceIndices[j]; @@ -338,20 +331,23 @@ void LSHSearch::BaseCase(const size_t queryIndex, querySet.unsafe_col(queryIndex), referenceSet->unsafe_col(referenceIndex)); - // If this distance is better than any of the current candidates, the - // SortDistance() function will give us the position to insert it into. - arma::vec queryDist = distances.unsafe_col(queryIndex); - arma::Col queryIndices = neighbors.unsafe_col(queryIndex); - size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices, - distance); - - // SortDistance() returns (size_t() - 1) if we shouldn't add it. - if (insertPosition != (size_t() - 1)) - InsertNeighbor(distances, neighbors, queryIndex, insertPosition, - referenceIndex, distance); + Candidate c(distance, referenceIndex); + // If this distance is better than the worst candidate, let's insert it. 
+ if (c < pqueue.top()) + { + pqueue.pop(); + pqueue.push(c); + } + } + for (size_t j = 1; j <= k; j++) + { + neighbors(k - j, queryIndex) = pqueue.top().index; + distances(k - j, queryIndex) = pqueue.top().dist; + pqueue.pop(); } } + template inline force_inline double LSHSearch::PerturbationScore( @@ -794,8 +790,6 @@ void LSHSearch::Search(const arma::mat& querySet, // Set the size of the neighbor and distance matrices. resultingNeighbors.set_size(k, querySet.n_cols); distances.set_size(k, querySet.n_cols); - distances.fill(SortPolicy::WorstDistance()); - resultingNeighbors.fill(referenceSet->n_cols); // If the user asked for 0 nearest neighbors... uh... we're done. if (k == 0) @@ -854,7 +848,7 @@ void LSHSearch::Search(const arma::mat& querySet, // Sequentially go through all the candidates and save the best 'k' // candidates. - BaseCase(i, refIndices, querySet, resultingNeighbors, distances); + BaseCase(i, refIndices, k, querySet, resultingNeighbors, distances); } Timer::Stop("computing_neighbors"); @@ -877,8 +871,6 @@ Search(const size_t k, // This is monochromatic search; the query set is the reference set. resultingNeighbors.set_size(k, referenceSet->n_cols); distances.set_size(k, referenceSet->n_cols); - distances.fill(SortPolicy::WorstDistance()); - resultingNeighbors.fill(referenceSet->n_cols); // If the user requested more than the available number of additional probing // bins, set Teffective to maximum T. Maximum T is 2^numProj - 1 @@ -933,7 +925,7 @@ Search(const size_t k, // Sequentially go through all the candidates and save the best 'k' // candidates. - BaseCase(i, refIndices, resultingNeighbors, distances); + BaseCase(i, refIndices, k, resultingNeighbors, distances); } Timer::Stop("computing_neighbors"); From bccf3e0d442ba554a2f0276822f102fcf2a2218a Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Fri, 22 Jul 2016 01:39:22 -0300 Subject: [PATCH 10/15] Use a priority queue (heap) to store the list of candidates while searching fastmks. 
--- src/mlpack/methods/fastmks/fastmks.hpp | 30 +++-- src/mlpack/methods/fastmks/fastmks_impl.hpp | 107 +++++++----------- src/mlpack/methods/fastmks/fastmks_rules.hpp | 59 ++++++++-- .../methods/fastmks/fastmks_rules_impl.hpp | 104 ++++++++++------- 4 files changed, 174 insertions(+), 126 deletions(-) diff --git a/src/mlpack/methods/fastmks/fastmks.hpp b/src/mlpack/methods/fastmks/fastmks.hpp index e849635d0e9..796f0db1ba2 100644 --- a/src/mlpack/methods/fastmks/fastmks.hpp +++ b/src/mlpack/methods/fastmks/fastmks.hpp @@ -12,6 +12,7 @@ #include #include "fastmks_stat.hpp" #include +#include namespace mlpack { namespace fastmks /** Fast max-kernel search. */ { @@ -250,13 +251,28 @@ class FastMKS //! The instantiated inner-product metric induced by the given kernel. metric::IPMetric metric; - //! Utility function. Copied too many times from too many places. - void InsertNeighbor(arma::Mat& indices, - arma::mat& products, - const size_t queryIndex, - const size_t pos, - const size_t neighbor, - const double distance); + //! Candidate point from the reference set. + struct Candidate + { + //! Kernel value calculated between a reference point and the query point. + double product; + //! Index of the reference point. + size_t index; + //! Trivial constructor. + Candidate(double p, size_t i) : + product(p), + index(i) + {}; + //! Compare two candidates. + friend bool operator>(const Candidate& l, const Candidate& r) + { + return l.product > r.product; + }; + }; + + //! Use a priority queue to represent the list of candidate points. 
+ typedef std::priority_queue, + std::greater> CandidateList; }; } // namespace fastmks diff --git a/src/mlpack/methods/fastmks/fastmks_impl.hpp b/src/mlpack/methods/fastmks/fastmks_impl.hpp index de85cdc56c3..f0f321f4ea6 100644 --- a/src/mlpack/methods/fastmks/fastmks_impl.hpp +++ b/src/mlpack/methods/fastmks/fastmks_impl.hpp @@ -13,7 +13,6 @@ #include "fastmks_rules.hpp" #include -#include namespace mlpack { namespace fastmks { @@ -221,25 +220,31 @@ void FastMKS::Search( // Naive implementation. if (naive) { - // Fill kernels. - kernels.fill(-DBL_MAX); - // Simple double loop. Stupid, slow, but a good benchmark. for (size_t q = 0; q < querySet.n_cols; ++q) { + const Candidate def(-DBL_MAX, size_t() - 1); + std::vector cList(k, def); + CandidateList pqueue(std::greater(), std::move(cList)); + for (size_t r = 0; r < referenceSet->n_cols; ++r) { const double eval = metric.Kernel().Evaluate(querySet.col(q), referenceSet->col(r)); - size_t insertPosition; - for (insertPosition = 0; insertPosition < indices.n_rows; - ++insertPosition) - if (eval > kernels(insertPosition, q)) - break; + Candidate c(eval, r); + if (c > pqueue.top()) + { + pqueue.pop(); + pqueue.push(c); + } + } - if (insertPosition < indices.n_rows) - InsertNeighbor(indices, kernels, q, insertPosition, r, eval); + for (size_t j = 1; j <= k; j++) + { + indices(k - j, q) = pqueue.top().index; + kernels(k - j, q) = pqueue.top().product; + pqueue.pop(); } } @@ -251,13 +256,10 @@ void FastMKS::Search( // Single-tree implementation. if (singleMode) { - // Fill kernels. - kernels.fill(-DBL_MAX); - // Create rules object (this will store the results). This constructor // precalculates each self-kernel value. 
typedef FastMKSRules RuleType; - RuleType rules(*referenceSet, querySet, indices, kernels, metric.Kernel()); + RuleType rules(*referenceSet, querySet, k, metric.Kernel()); typename Tree::template SingleTreeTraverser traverser(rules); @@ -267,6 +269,8 @@ void FastMKS::Search( Log::Info << rules.BaseCases() << " base cases." << std::endl; Log::Info << rules.Scores() << " scores." << std::endl; + rules.GetResults(indices, kernels); + Timer::Stop("computing_products"); return; } @@ -310,12 +314,10 @@ void FastMKS::Search( // No remapping will be necessary because we are using the cover tree. indices.set_size(k, queryTree->Dataset().n_cols); kernels.set_size(k, queryTree->Dataset().n_cols); - kernels.fill(-DBL_MAX); Timer::Start("computing_products"); typedef FastMKSRules RuleType; - RuleType rules(*referenceSet, queryTree->Dataset(), indices, kernels, - metric.Kernel()); + RuleType rules(*referenceSet, queryTree->Dataset(), k, metric.Kernel()); typename Tree::template DualTreeTraverser traverser(rules); @@ -324,6 +326,8 @@ void FastMKS::Search( Log::Info << rules.BaseCases() << " base cases." << std::endl; Log::Info << rules.Scores() << " scores." << std::endl; + rules.GetResults(indices, kernels); + Timer::Stop("computing_products"); } @@ -341,7 +345,6 @@ void FastMKS::Search( Timer::Start("computing_products"); indices.set_size(k, referenceSet->n_cols); kernels.set_size(k, referenceSet->n_cols); - kernels.fill(-DBL_MAX); // Naive implementation. if (naive) @@ -349,6 +352,10 @@ void FastMKS::Search( // Simple double loop. Stupid, slow, but a good benchmark. 
for (size_t q = 0; q < referenceSet->n_cols; ++q) { + const Candidate def(-DBL_MAX, size_t() - 1); + std::vector cList(k, def); + CandidateList pqueue(std::greater(), std::move(cList)); + for (size_t r = 0; r < referenceSet->n_cols; ++r) { if (q == r) @@ -357,14 +364,19 @@ void FastMKS::Search( const double eval = metric.Kernel().Evaluate(referenceSet->col(q), referenceSet->col(r)); - size_t insertPosition; - for (insertPosition = 0; insertPosition < indices.n_rows; - ++insertPosition) - if (eval > kernels(insertPosition, q)) - break; + Candidate c(eval, r); + if (c > pqueue.top()) + { + pqueue.pop(); + pqueue.push(c); + } + } - if (insertPosition < indices.n_rows) - InsertNeighbor(indices, kernels, q, insertPosition, r, eval); + for (size_t j = 1; j <= k; j++) + { + indices(k - j, q) = pqueue.top().index; + kernels(k - j, q) = pqueue.top().product; + pqueue.pop(); } } @@ -379,8 +391,7 @@ void FastMKS::Search( // Create rules object (this will store the results). This constructor // precalculates each self-kernel value. typedef FastMKSRules RuleType; - RuleType rules(*referenceSet, *referenceSet, indices, kernels, - metric.Kernel()); + RuleType rules(*referenceSet, *referenceSet, k, metric.Kernel()); typename Tree::template SingleTreeTraverser traverser(rules); @@ -395,6 +406,8 @@ void FastMKS::Search( Log::Info << rules.BaseCases() << " base cases." << std::endl; Log::Info << rules.Scores() << " scores." << std::endl; + rules.GetResults(indices, kernels); + Timer::Stop("computing_products"); return; } @@ -405,44 +418,6 @@ void FastMKS::Search( Search(referenceTree, k, indices, kernels); } -/** - * Helper function to insert a point into the neighbors and distances matrices. - * - * @param queryIndex Index of point whose neighbors we are inserting into. - * @param pos Position in list to insert into. - * @param neighbor Index of reference point which is being inserted. - * @param distance Distance from query point to reference point. 
- */ -template class TreeType> -void FastMKS::InsertNeighbor( - arma::Mat& indices, - arma::mat& products, - const size_t queryIndex, - const size_t pos, - const size_t neighbor, - const double distance) -{ - // We only memmove() if there is actually a need to shift something. - if (pos < (products.n_rows - 1)) - { - int len = (products.n_rows - 1) - pos; - memmove(products.colptr(queryIndex) + (pos + 1), - products.colptr(queryIndex) + pos, - sizeof(double) * len); - memmove(indices.colptr(queryIndex) + (pos + 1), - indices.colptr(queryIndex) + pos, - sizeof(size_t) * len); - } - - // Now put the new information in the right index. - products(pos, queryIndex) = distance; - indices(pos, queryIndex) = neighbor; -} - //! Serialize the model. template #include #include +#include namespace mlpack { namespace fastmks { @@ -23,10 +24,17 @@ class FastMKSRules public: FastMKSRules(const typename TreeType::Mat& referenceSet, const typename TreeType::Mat& querySet, - arma::Mat& indices, - arma::mat& products, + const size_t k, KernelType& kernel); + /** + * Store the list of candidates for each query point in the given matrices. + * + * @param indices Matrix storing lists of candidate points for each query point. + * @param products Matrix storing kernel value for each candidate. + */ + void GetResults(arma::Mat& indices, arma::mat& products); + //! Compute the base case (kernel value) between two points. double BaseCase(const size_t queryIndex, const size_t referenceIndex); @@ -101,10 +109,36 @@ class FastMKSRules //! The query dataset. const typename TreeType::Mat& querySet; - //! The indices of the maximum kernel results. - arma::Mat& indices; - //! The maximum kernels. - arma::mat& products; + //! Candidate point from the reference set. + struct Candidate + { + //! Kernel value calculated between a reference point and the query point. + double product; + //! Index of the reference point. + size_t index; + //! Trivial constructor. 
+ Candidate(double p, size_t i) : + product(p), + index(i) + {}; + //! Compare two candidates. + friend bool operator>(const Candidate& l, const Candidate& r) + { + return l.product > r.product; + }; + }; + + //! Use a min heap to represent the list of candidate points. + //! We will use a vector and the std functions: push_heap() pop_heap(). + //! We can not use a priority queue because we need to iterate over all the + //! candidates and std::priority_queue doesn't provide that interface. + typedef std::vector CandidateList; + + //! Set of candidates for each point. + std::vector candidates; + + //! Number of points to search for. + const size_t k; //! Cached query set self-kernels (|| q || for each q). arma::vec queryKernels; @@ -124,11 +158,16 @@ class FastMKSRules //! Calculate the bound for a given query node. double CalculateBound(TreeType& queryNode) const; - //! Utility function to insert neighbor into list of results. + /** + * Helper function to insert a point into the list of candidate points. + * + * @param queryIndex Index of point whose neighbors we are inserting into. + * @param index Index of reference point which is being inserted. + * @param product Kernel value for given candidate. + */ void InsertNeighbor(const size_t queryIndex, - const size_t pos, - const size_t neighbor, - const double distance); + const size_t index, + const double product); //! For benchmarking. size_t baseCases; diff --git a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp index 27abacf971d..a5cf68143aa 100644 --- a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp +++ b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp @@ -9,6 +9,7 @@ // In case it hasn't already been included. 
#include "fastmks_rules.hpp" +#include namespace mlpack { namespace fastmks { @@ -17,13 +18,11 @@ template FastMKSRules::FastMKSRules( const typename TreeType::Mat& referenceSet, const typename TreeType::Mat& querySet, - arma::Mat& indices, - arma::mat& products, + const size_t k, KernelType& kernel) : referenceSet(referenceSet), querySet(querySet), - indices(indices), - products(products), + k(k), kernel(kernel), lastQueryIndex(-1), lastReferenceIndex(-1), @@ -46,6 +45,41 @@ FastMKSRules::FastMKSRules( // dereference null pointers. traversalInfo.LastQueryNode() = (TreeType*) this; traversalInfo.LastReferenceNode() = (TreeType*) this; + + // Let's build the list of candidate points for each query point. + // It will be initialized with k candidates: (-DBL_MAX, size_t() - 1) + // The list of candidates will be updated when visiting new points with the + // BaseCase() method. + const Candidate def(-DBL_MAX, size_t() - 1); + + CandidateList cList(k, def); + std::vector tmp(querySet.n_cols, cList); + candidates.swap(tmp); +} + +template +void FastMKSRules::GetResults( + arma::Mat& indices, + arma::mat& products) +{ + indices.set_size(k, querySet.n_cols); + products.set_size(k, querySet.n_cols); + + for (size_t i = 0; i < querySet.n_cols; i++) + { + CandidateList& pqueue = candidates[i]; + std::greater greater; + typedef typename CandidateList::iterator Iterator; + + for (Iterator end = pqueue.end(); end != pqueue.begin(); --end) + std::pop_heap(pqueue.begin(), end, greater); + + for (size_t j = 0; j < k; j++) + { + indices(j, i) = pqueue[j].index; + products(j, i) = pqueue[j].product; + } + } } template @@ -83,16 +117,7 @@ double FastMKSRules::BaseCase( if ((&querySet == &referenceSet) && (queryIndex == referenceIndex)) return kernelEval; - // If this is a better candidate, insert it into the list. 
- if (kernelEval < products(products.n_rows - 1, queryIndex)) - return kernelEval; - - size_t insertPosition = 0; - for ( ; insertPosition < products.n_rows; ++insertPosition) - if (kernelEval >= products(insertPosition, queryIndex)) - break; - - InsertNeighbor(queryIndex, insertPosition, referenceIndex, kernelEval); + InsertNeighbor(queryIndex, referenceIndex, kernelEval); return kernelEval; } @@ -102,7 +127,7 @@ double FastMKSRules::Score(const size_t queryIndex, TreeType& referenceNode) { // Compare with the current best. - const double bestKernel = products(products.n_rows - 1, queryIndex); + const double bestKernel = candidates[queryIndex].front().product; // See if we can perform a parent-child prune. const double furthestDist = referenceNode.FurthestDescendantDistance(); @@ -385,7 +410,7 @@ double FastMKSRules::Rescore(const size_t queryIndex, TreeType& /*referenceNode*/, const double oldScore) const { - const double bestKernel = products(products.n_rows - 1, queryIndex); + const double bestKernel = candidates[queryIndex].front().product; return ((1.0 / oldScore) >= bestKernel) ? oldScore : DBL_MAX; } @@ -432,10 +457,11 @@ double FastMKSRules::CalculateBound(TreeType& queryNode) for (size_t i = 0; i < queryNode.NumPoints(); ++i) { const size_t point = queryNode.Point(i); - if (products(products.n_rows - 1, point) < worstPointKernel) - worstPointKernel = products(products.n_rows - 1, point); + const CandidateList& candidatesPoints = candidates[point]; + if (candidatesPoints.front().product < worstPointKernel) + worstPointKernel = candidatesPoints.front().product; - if (products(products.n_rows - 1, point) == -DBL_MAX) + if (candidatesPoints.front().product == -DBL_MAX) continue; // Avoid underflow. // This should be (queryDescendantDistance + centroidDistance) for any tree @@ -450,10 +476,10 @@ double FastMKSRules::CalculateBound(TreeType& queryNode) // where p_j^*(p_q) is the j'th kernel candidate for query point p_q and // k_j^*(p_q) is K(p_q, p_j^*(p_q)). 
double worstPointCandidateKernel = DBL_MAX; - for (size_t j = 0; j < products.n_rows; ++j) + for (size_t j = 0; j < candidatesPoints.size(); ++j) { - const double candidateKernel = products(j, point) - - queryDescendantDistance * referenceKernels[indices(j, point)]; + const double candidateKernel = candidatesPoints[j].product - + queryDescendantDistance * referenceKernels[candidatesPoints[j].index]; if (candidateKernel < worstPointCandidateKernel) worstPointCandidateKernel = candidateKernel; } @@ -488,34 +514,26 @@ double FastMKSRules::CalculateBound(TreeType& queryNode) } /** - * Helper function to insert a point into the neighbors and distances matrices. + * Helper function to insert a point into the list of candidate points. * * @param queryIndex Index of point whose neighbors we are inserting into. - * @param pos Position in list to insert into. - * @param neighbor Index of reference point which is being inserted. - * @param distance Distance from query point to reference point. + * @param index Index of reference point which is being inserted. + * @param product Kernel value for given candidate. */ template -void FastMKSRules::InsertNeighbor(const size_t queryIndex, - const size_t pos, - const size_t neighbor, - const double distance) +inline void FastMKSRules::InsertNeighbor( + const size_t queryIndex, + const size_t index, + const double product) { - // We only memmove() if there is actually a need to shift something. 
- if (pos < (products.n_rows - 1)) + Candidate c(product, index); + CandidateList& pqueue = candidates[queryIndex]; + if (c > pqueue.front()) { - int len = (products.n_rows - 1) - pos; - memmove(products.colptr(queryIndex) + (pos + 1), - products.colptr(queryIndex) + pos, - sizeof(double) * len); - memmove(indices.colptr(queryIndex) + (pos + 1), - indices.colptr(queryIndex) + pos, - sizeof(size_t) * len); + std::pop_heap(pqueue.begin(), pqueue.end(), std::greater()); + pqueue.back() = c; + std::push_heap(pqueue.begin(), pqueue.end(), std::greater()); } - - // Now put the new information in the right index. - products(pos, queryIndex) = distance; - indices(pos, queryIndex) = neighbor; } } // namespace fastmks From 70c67cdd3b5d609fc1a075d49d68620a99855741 Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Fri, 22 Jul 2016 01:44:21 -0300 Subject: [PATCH 11/15] Remove unnecessary method from SortPolicies. --- .../methods/neighbor_search/CMakeLists.txt | 2 - .../sort_policies/furthest_neighbor_sort.cpp | 27 ------- .../sort_policies/furthest_neighbor_sort.hpp | 18 ----- .../sort_policies/nearest_neighbor_sort.cpp | 27 ------- .../sort_policies/nearest_neighbor_sort.hpp | 18 ----- src/mlpack/tests/sort_policy_test.cpp | 72 ------------------- 6 files changed, 164 deletions(-) delete mode 100644 src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.cpp delete mode 100644 src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.cpp diff --git a/src/mlpack/methods/neighbor_search/CMakeLists.txt b/src/mlpack/methods/neighbor_search/CMakeLists.txt index 4d0d44b0805..1c51ce45f81 100644 --- a/src/mlpack/methods/neighbor_search/CMakeLists.txt +++ b/src/mlpack/methods/neighbor_search/CMakeLists.txt @@ -9,10 +9,8 @@ set(SOURCES ns_model.hpp ns_model_impl.hpp sort_policies/nearest_neighbor_sort.hpp - sort_policies/nearest_neighbor_sort.cpp sort_policies/nearest_neighbor_sort_impl.hpp sort_policies/furthest_neighbor_sort.hpp - 
sort_policies/furthest_neighbor_sort.cpp sort_policies/furthest_neighbor_sort_impl.hpp typedef.hpp unmap.hpp diff --git a/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.cpp b/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.cpp deleted file mode 100644 index f58e4d2c22c..00000000000 --- a/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.cpp +++ /dev/null @@ -1,27 +0,0 @@ -/*** - * @file furthest_neighbor_sort.cpp - * @author Ryan Curtin - * - * Implementation of the simple FurthestNeighborSort policy class. - */ -#include "furthest_neighbor_sort.hpp" - -using namespace mlpack::neighbor; - -size_t FurthestNeighborSort::SortDistance(const arma::vec& list, - const arma::Col& indices, - double newDistance) -{ - // The first element in the list is the furthest neighbor. We only want to - // insert if the new distance is greater than the last element in the list. - if (newDistance < list[list.n_elem - 1]) - return (size_t() - 1); // Do not insert. - - // Search from the beginning. This may not be the best way. - for (size_t i = 0; i < list.n_elem; i++) - if (newDistance >= list[i] || indices[i] == (size_t() - 1)) - return i; - - // Control should never reach here. - return (size_t() - 1); -} diff --git a/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp b/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp index a69c1679213..09c3d960144 100644 --- a/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp +++ b/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp @@ -22,24 +22,6 @@ namespace neighbor { class FurthestNeighborSort { public: - /** - * Return the index in the vector where the new distance should be inserted, - * or size_t() - 1 if it should not be inserted (i.e. if it is not any better - * than any of the existing points in the list). 
The list should be sorted - * such that the best point is the first in the list. The actual insertion is - * not performed. - * - * @param list Vector of existing distance points, sorted such that the best - * point is the first in the list. - * @param new_distance Distance to try to insert. - * - * @return size_t containing the position to insert into, or (size_t() - 1) - * if the new distance should not be inserted. - */ - static size_t SortDistance(const arma::vec& list, - const arma::Col& indices, - double newDistance); - /** * Return whether or not value is "better" than ref. In this case, that means * that the value is greater than or equal to the reference. diff --git a/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.cpp b/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.cpp deleted file mode 100644 index 4a755706bda..00000000000 --- a/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.cpp +++ /dev/null @@ -1,27 +0,0 @@ -/** - * @file nearest_neighbor_sort.cpp - * @author Ryan Curtin - * - * Implementation of the simple NearestNeighborSort policy class. - */ -#include "nearest_neighbor_sort.hpp" - -using namespace mlpack::neighbor; - -size_t NearestNeighborSort::SortDistance(const arma::vec& list, - const arma::Col& indices, - double newDistance) -{ - // The first element in the list is the nearest neighbor. We only want to - // insert if the new distance is less than the last element in the list. - if (newDistance > list[list.n_elem - 1]) - return (size_t() - 1); // Do not insert. - - // Search from the beginning. This may not be the best way. - for (size_t i = 0; i < list.n_elem; i++) - if (newDistance <= list[i] || indices[i] == (size_t() - 1)) - return i; - - // Control should never reach here. 
- return (size_t() - 1); -} diff --git a/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp b/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp index 42a08b06411..5ccce6dd6ad 100644 --- a/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp +++ b/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp @@ -26,24 +26,6 @@ namespace neighbor { class NearestNeighborSort { public: - /** - * Return the index in the vector where the new distance should be inserted, - * or (size_t() - 1) if it should not be inserted (i.e. if it is not any - * better than any of the existing points in the list). The list should be - * sorted such that the best point is the first in the list. The actual - * insertion is not performed. - * - * @param list Vector of existing distance points, sorted such that the best - * point is first in the list. - * @param new_distance Distance to try to insert - * - * @return size_t containing the position to insert into, or (size_t() - 1) - * if the new distance should not be inserted. - */ - static size_t SortDistance(const arma::vec& list, - const arma::Col& indices, - double newDistance); - /** * Return whether or not value is "better" than ref. In this case, that means * that the value is less than or equal to the reference. diff --git a/src/mlpack/tests/sort_policy_test.cpp b/src/mlpack/tests/sort_policy_test.cpp index e336a76170b..5cf4b5606e1 100644 --- a/src/mlpack/tests/sort_policy_test.cpp +++ b/src/mlpack/tests/sort_policy_test.cpp @@ -56,42 +56,6 @@ BOOST_AUTO_TEST_CASE(NnsIsBetterNotStrict) BOOST_WARN(NearestNeighborSort::IsBetter(6.0, 6.0) == true); } -/** - * A simple test case of where to insert when all the values in the list are - * DBL_MAX. - */ -BOOST_AUTO_TEST_CASE(NnsSortDistanceAllDblMax) -{ - arma::vec list(5); - list.fill(DBL_MAX); - arma::Col indices(5); - indices.fill(0); - - // Should be inserted at the head of the list. 
- BOOST_REQUIRE(NearestNeighborSort::SortDistance(list, indices, 5.0) == 0); -} - -/** - * Another test case, where we are just putting the new value in the middle of - * the list. - */ -BOOST_AUTO_TEST_CASE(NnsSortDistance2) -{ - arma::vec list(3); - list[0] = 0.66; - list[1] = 0.89; - list[2] = 1.14; - arma::Col indices(3); - indices.fill(0); - - // Run a couple possibilities through. - BOOST_REQUIRE(NearestNeighborSort::SortDistance(list, indices, 0.61) == 0); - BOOST_REQUIRE(NearestNeighborSort::SortDistance(list, indices, 0.76) == 1); - BOOST_REQUIRE(NearestNeighborSort::SortDistance(list, indices, 0.99) == 2); - BOOST_REQUIRE(NearestNeighborSort::SortDistance(list, indices, 1.22) == - (size_t() - 1)); -} - /** * Very simple sanity check to ensure that bounds are working alright. We will * use a one-dimensional bound for simplicity. @@ -218,42 +182,6 @@ BOOST_AUTO_TEST_CASE(FnsIsBetterNotStrict) BOOST_WARN(FurthestNeighborSort::IsBetter(6.0, 6.0) == true); } -/** - * A simple test case of where to insert when all the values in the list are - * 0. - */ -BOOST_AUTO_TEST_CASE(FnsSortDistanceAllZero) -{ - arma::vec list(5); - list.fill(0); - arma::Col indices(5); - indices.fill(0); - - // Should be inserted at the head of the list. - BOOST_REQUIRE(FurthestNeighborSort::SortDistance(list, indices, 5.0) == 0); -} - -/** - * Another test case, where we are just putting the new value in the middle of - * the list. - */ -BOOST_AUTO_TEST_CASE(FnsSortDistance2) -{ - arma::vec list(3); - list[0] = 1.14; - list[1] = 0.89; - list[2] = 0.66; - arma::Col indices(3); - indices.fill(0); - - // Run a couple possibilities through. 
- BOOST_REQUIRE(FurthestNeighborSort::SortDistance(list, indices, 1.22) == 0); - BOOST_REQUIRE(FurthestNeighborSort::SortDistance(list, indices, 0.93) == 1); - BOOST_REQUIRE(FurthestNeighborSort::SortDistance(list, indices, 0.68) == 2); - BOOST_REQUIRE(FurthestNeighborSort::SortDistance(list, indices, 0.62) == - (size_t() - 1)); -} - /** * Very simple sanity check to ensure that bounds are working alright. We will * use a one-dimensional bound for simplicity. From d0d22f72f79bea8510d2cb8862353eef3413808a Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Fri, 22 Jul 2016 01:58:34 -0300 Subject: [PATCH 12/15] Add more documentation for FastMKSRules. --- src/mlpack/methods/fastmks/fastmks_rules.hpp | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/mlpack/methods/fastmks/fastmks_rules.hpp b/src/mlpack/methods/fastmks/fastmks_rules.hpp index 13be5f55a81..11b9604cb6c 100644 --- a/src/mlpack/methods/fastmks/fastmks_rules.hpp +++ b/src/mlpack/methods/fastmks/fastmks_rules.hpp @@ -16,12 +16,27 @@ namespace mlpack { namespace fastmks { /** - * The base case and pruning rules for FastMKS (fast max-kernel search). + * The FastMKSRules class is a template helper class used by FastMKS class when + * performing exact max-kernel search. For each point in the query dataset, it + * keeps track of the k best candidates in the reference dataset. + * + * @tparam KernelType Type of kernel to run FastMKS with. + * @tparam TreeType Type of tree to run FastMKS with; it must satisfy the + * TreeType policy API. */ template class FastMKSRules { public: + /** + * Construct the FastMKSRules object. This is usually done from within the + * FastMKS class at search time. + * + * @param referenceSet Set of reference data. + * @param querySet Set of query data. + * @param k Number of candidates to search for. + * @param kernel Kernel to run FastMKS with. 
+ */ FastMKSRules(const typename TreeType::Mat& referenceSet, const typename TreeType::Mat& querySet, const size_t k, @@ -30,7 +45,7 @@ class FastMKSRules /** * Store the list of candidates for each query point in the given matrices. * - * @param indices Matrix storing lists of candidate points for each query point. + * @param indices Matrix storing lists of candidates for each query point. * @param products Matrix storing kernel value for each candidate. */ void GetResults(arma::Mat& indices, arma::mat& products); From c86275b4c7e5af0d69d3b504f9341785e163c443 Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Fri, 22 Jul 2016 02:03:09 -0300 Subject: [PATCH 13/15] Add more documentation for RangeSearchRules. --- src/mlpack/methods/range_search/range_search_rules.hpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/mlpack/methods/range_search/range_search_rules.hpp b/src/mlpack/methods/range_search/range_search_rules.hpp index e392ee42ac2..0f492a8ed9d 100644 --- a/src/mlpack/methods/range_search/range_search_rules.hpp +++ b/src/mlpack/methods/range_search/range_search_rules.hpp @@ -13,6 +13,13 @@ namespace mlpack { namespace range { +/** + * The RangeSearchRules class is a template helper class used by RangeSearch + * class when performing range searches. + * + * @tparam MetricType The metric to use for computation. + * @tparam TreeType The tree type to use; must adhere to the TreeType API. + */ template class RangeSearchRules { From d1eadad6908d4d95a147455d7a6e60e2cca238f8 Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Fri, 22 Jul 2016 16:48:31 -0300 Subject: [PATCH 14/15] Use std::pair instead of Candidate struct.
--- src/mlpack/methods/cf/cf.cpp | 18 +++++------ src/mlpack/methods/cf/cf.hpp | 22 +++++-------- src/mlpack/methods/fastmks/fastmks.hpp | 24 +++++--------- src/mlpack/methods/fastmks/fastmks_impl.hpp | 24 +++++++------- src/mlpack/methods/fastmks/fastmks_rules.hpp | 22 +++++-------- .../methods/fastmks/fastmks_rules_impl.hpp | 31 +++++++++---------- src/mlpack/methods/lsh/lsh_search.hpp | 27 ++++++---------- src/mlpack/methods/lsh/lsh_search_impl.hpp | 26 +++++++++------- .../neighbor_search/neighbor_search_rules.hpp | 26 ++++++---------- .../neighbor_search_rules_impl.hpp | 20 ++++++------ src/mlpack/methods/rann/ra_search_rules.hpp | 26 ++++++---------- .../methods/rann/ra_search_rules_impl.hpp | 26 +++++++++------- 12 files changed, 125 insertions(+), 167 deletions(-) diff --git a/src/mlpack/methods/cf/cf.cpp b/src/mlpack/methods/cf/cf.cpp index 121cf139c4c..01a4c844d62 100644 --- a/src/mlpack/methods/cf/cf.cpp +++ b/src/mlpack/methods/cf/cf.cpp @@ -95,11 +95,11 @@ void CF::GetRecommendations(const size_t numRecs, // Let's build the list of candidate recomendations for the given user. // Default candidate: the smallest possible value and invalid item number. - const Candidate def(-DBL_MAX, cleanedData.n_rows); + const Candidate def = std::make_pair(-DBL_MAX, cleanedData.n_rows); std::vector vect(numRecs, def); - typedef std::priority_queue, - std::greater> CandidateList; - CandidateList pqueue(std::greater(), std::move(vect)); + typedef std::priority_queue, CandidateCmp> + CandidateList; + CandidateList pqueue(CandidateCmp(), std::move(vect)); // Look through the averages column corresponding to the current user. for (size_t j = 0; j < averages.n_rows; ++j) @@ -108,11 +108,11 @@ void CF::GetRecommendations(const size_t numRecs, if (cleanedData(j, users(i)) != 0.0) continue; // The user already rated the item. - Candidate c(averages[j], j); // Is the estimated value better than the worst candidate? 
- if (c > pqueue.top()) + if (averages[j] > pqueue.top().first) { + Candidate c = std::make_pair(averages[j], j); pqueue.pop(); pqueue.push(c); } @@ -120,14 +120,14 @@ void CF::GetRecommendations(const size_t numRecs, for (size_t p = 1; p <= numRecs; p++) { - recommendations(numRecs - p, i) = pqueue.top().item; - values(numRecs - p, i) = pqueue.top().value; + recommendations(numRecs - p, i) = pqueue.top().second; + values(numRecs - p, i) = pqueue.top().first; pqueue.pop(); } // If we were not able to come up with enough recommendations, issue a // warning. - if (recommendations(numRecs - 1, i) == def.item) + if (recommendations(numRecs - 1, i) == def.second) Log::Warn << "Could not provide " << numRecs << " recommendations " << "for user " << users(i) << " (not enough un-rated items)!" << std::endl; diff --git a/src/mlpack/methods/cf/cf.hpp b/src/mlpack/methods/cf/cf.hpp index 42b1a4b7673..82d624cae3e 100644 --- a/src/mlpack/methods/cf/cf.hpp +++ b/src/mlpack/methods/cf/cf.hpp @@ -258,22 +258,14 @@ class CF //! Cleaned data matrix. arma::sp_mat cleanedData; - //! Candidate represents a possible recommendation. - struct Candidate - { - //! Value of this recommendation. - double value; - //! Item of this recommendation. - size_t item; - //! Trivial constructor. - Candidate(double value, size_t item) : - value(value), - item(item) - {}; - //! Compare the value of two candidates. - friend bool operator>(const Candidate& l, const Candidate& r) + //! Candidate represents a possible recommendation (value, item). + typedef std::pair Candidate; + + //! Compare two candidates based on the value.
+ struct CandidateCmp { + bool operator()(const Candidate& c1, const Candidate& c2) const { - return l.value > r.value; + return c1.first > c2.first; }; }; }; // class CF diff --git a/src/mlpack/methods/fastmks/fastmks.hpp b/src/mlpack/methods/fastmks/fastmks.hpp index 796f0db1ba2..031fe43c39c 100644 --- a/src/mlpack/methods/fastmks/fastmks.hpp +++ b/src/mlpack/methods/fastmks/fastmks.hpp @@ -251,28 +251,20 @@ class FastMKS //! The instantiated inner-product metric induced by the given kernel. metric::IPMetric metric; - //! Candidate point from the reference set. - struct Candidate - { - //! Kernel value calculated between a reference point and the query point. - double product; - //! Index of the reference point. - size_t index; - //! Trivial constructor. - Candidate(double p, size_t i) : - product(p), - index(i) - {}; - //! Compare two candidates. - friend bool operator>(const Candidate& l, const Candidate& r) + //! Candidate represents a possible candidate point (value, index). + typedef std::pair Candidate; + + //! Compare two candidates based on the value. + struct CandidateCmp { + bool operator()(const Candidate& c1, const Candidate& c2) const { - return l.product > r.product; + return c1.first > c2.first; }; }; //! Use a priority queue to represent the list of candidate points. typedef std::priority_queue, - std::greater> CandidateList; + CandidateCmp> CandidateList; }; } // namespace fastmks diff --git a/src/mlpack/methods/fastmks/fastmks_impl.hpp b/src/mlpack/methods/fastmks/fastmks_impl.hpp index f0f321f4ea6..993ba547dbf 100644 --- a/src/mlpack/methods/fastmks/fastmks_impl.hpp +++ b/src/mlpack/methods/fastmks/fastmks_impl.hpp @@ -223,18 +223,18 @@ void FastMKS::Search( // Simple double loop. Stupid, slow, but a good benchmark.
for (size_t q = 0; q < querySet.n_cols; ++q) { - const Candidate def(-DBL_MAX, size_t() - 1); + const Candidate def = std::make_pair(-DBL_MAX, size_t() - 1); std::vector cList(k, def); - CandidateList pqueue(std::greater(), std::move(cList)); + CandidateList pqueue(CandidateCmp(), std::move(cList)); for (size_t r = 0; r < referenceSet->n_cols; ++r) { const double eval = metric.Kernel().Evaluate(querySet.col(q), referenceSet->col(r)); - Candidate c(eval, r); - if (c > pqueue.top()) + if (eval > pqueue.top().first) { + Candidate c = std::make_pair(eval, r); pqueue.pop(); pqueue.push(c); } @@ -242,8 +242,8 @@ void FastMKS::Search( for (size_t j = 1; j <= k; j++) { - indices(k - j, q) = pqueue.top().index; - kernels(k - j, q) = pqueue.top().product; + indices(k - j, q) = pqueue.top().second; + kernels(k - j, q) = pqueue.top().first; pqueue.pop(); } } @@ -352,9 +352,9 @@ void FastMKS::Search( // Simple double loop. Stupid, slow, but a good benchmark. for (size_t q = 0; q < referenceSet->n_cols; ++q) { - const Candidate def(-DBL_MAX, size_t() - 1); + const Candidate def = std::make_pair(-DBL_MAX, size_t() - 1); std::vector cList(k, def); - CandidateList pqueue(std::greater(), std::move(cList)); + CandidateList pqueue(CandidateCmp(), std::move(cList)); for (size_t r = 0; r < referenceSet->n_cols; ++r) { @@ -364,9 +364,9 @@ void FastMKS::Search( const double eval = metric.Kernel().Evaluate(referenceSet->col(q), referenceSet->col(r)); - Candidate c(eval, r); - if (c > pqueue.top()) + if (eval > pqueue.top().first) { + Candidate c = std::make_pair(eval, r); pqueue.pop(); pqueue.push(c); } @@ -374,8 +374,8 @@ void FastMKS::Search( for (size_t j = 1; j <= k; j++) { - indices(k - j, q) = pqueue.top().index; - kernels(k - j, q) = pqueue.top().product; + indices(k - j, q) = pqueue.top().second; + kernels(k - j, q) = pqueue.top().first; pqueue.pop(); } } diff --git a/src/mlpack/methods/fastmks/fastmks_rules.hpp b/src/mlpack/methods/fastmks/fastmks_rules.hpp index 
11b9604cb6c..9aca42d3140 100644 --- a/src/mlpack/methods/fastmks/fastmks_rules.hpp +++ b/src/mlpack/methods/fastmks/fastmks_rules.hpp @@ -124,22 +124,14 @@ class FastMKSRules //! The query dataset. const typename TreeType::Mat& querySet; - //! Candidate point from the reference set. - struct Candidate - { - //! Kernel value calculated between a reference point and the query point. - double product; - //! Index of the reference point. - size_t index; - //! Trivial constructor. - Candidate(double p, size_t i) : - product(p), - index(i) - {}; - //! Compare two candidates. - friend bool operator>(const Candidate& l, const Candidate& r) + //! Candidate represents a possible candidate point (value, index). + typedef std::pair Candidate; + + //! Compare two candidates based on the value. + struct CandidateCmp { + bool operator()(const Candidate& c1, const Candidate& c2) const { - return l.product > r.product; + return c1.first > c2.first; }; }; diff --git a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp index a5cf68143aa..6efc9a7d96d 100644 --- a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp +++ b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp @@ -50,7 +50,7 @@ FastMKSRules::FastMKSRules( // It will be initialized with k candidates: (-DBL_MAX, size_t() - 1) // The list of candidates will be updated when visiting new points with the // BaseCase() method.
- const Candidate def(-DBL_MAX, size_t() - 1); + const Candidate def = std::make_pair(-DBL_MAX, size_t() - 1); CandidateList cList(k, def); std::vector tmp(querySet.n_cols, cList); @@ -68,16 +68,15 @@ void FastMKSRules::GetResults( for (size_t i = 0; i < querySet.n_cols; i++) { CandidateList& pqueue = candidates[i]; - std::greater greater; typedef typename CandidateList::iterator Iterator; for (Iterator end = pqueue.end(); end != pqueue.begin(); --end) - std::pop_heap(pqueue.begin(), end, greater); + std::pop_heap(pqueue.begin(), end, CandidateCmp()); for (size_t j = 0; j < k; j++) { - indices(j, i) = pqueue[j].index; - products(j, i) = pqueue[j].product; + indices(j, i) = pqueue[j].second; + products(j, i) = pqueue[j].first; } } } @@ -127,7 +126,7 @@ double FastMKSRules::Score(const size_t queryIndex, TreeType& referenceNode) { // Compare with the current best. - const double bestKernel = candidates[queryIndex].front().product; + const double bestKernel = candidates[queryIndex].front().first; // See if we can perform a parent-child prune. const double furthestDist = referenceNode.FurthestDescendantDistance(); @@ -410,7 +409,7 @@ double FastMKSRules::Rescore(const size_t queryIndex, TreeType& /*referenceNode*/, const double oldScore) const { - const double bestKernel = candidates[queryIndex].front().product; + const double bestKernel = candidates[queryIndex].front().first; return ((1.0 / oldScore) >= bestKernel) ? oldScore : DBL_MAX; } @@ -458,10 +457,10 @@ double FastMKSRules::CalculateBound(TreeType& queryNode) { const size_t point = queryNode.Point(i); const CandidateList& candidatesPoints = candidates[point]; - if (candidatesPoints.front().product < worstPointKernel) - worstPointKernel = candidatesPoints.front().product; + if (candidatesPoints.front().first < worstPointKernel) + worstPointKernel = candidatesPoints.front().first; - if (candidatesPoints.front().product == -DBL_MAX) + if (candidatesPoints.front().first == -DBL_MAX) continue; // Avoid underflow. 
// This should be (queryDescendantDistance + centroidDistance) for any tree @@ -478,8 +477,8 @@ double FastMKSRules::CalculateBound(TreeType& queryNode) double worstPointCandidateKernel = DBL_MAX; for (size_t j = 0; j < candidatesPoints.size(); ++j) { - const double candidateKernel = candidatesPoints[j].product - - queryDescendantDistance * referenceKernels[candidatesPoints[j].index]; + const double candidateKernel = candidatesPoints[j].first - + queryDescendantDistance * referenceKernels[candidatesPoints[j].second]; if (candidateKernel < worstPointCandidateKernel) worstPointCandidateKernel = candidateKernel; } @@ -526,13 +525,13 @@ inline void FastMKSRules::InsertNeighbor( const size_t index, const double product) { - Candidate c(product, index); CandidateList& pqueue = candidates[queryIndex]; - if (c > pqueue.front()) + if (product > pqueue.front().first) { - std::pop_heap(pqueue.begin(), pqueue.end(), std::greater()); + Candidate c = std::make_pair(product, index); + std::pop_heap(pqueue.begin(), pqueue.end(), CandidateCmp()); pqueue.back() = c; - std::push_heap(pqueue.begin(), pqueue.end(), std::greater()); + std::push_heap(pqueue.begin(), pqueue.end(), CandidateCmp()); } } diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp index 45284bae294..c622fbb2069 100644 --- a/src/mlpack/methods/lsh/lsh_search.hpp +++ b/src/mlpack/methods/lsh/lsh_search.hpp @@ -428,28 +428,21 @@ class LSHSearch //! The number of distance evaluations. size_t distanceEvaluations; - //! Candidate represents a possible candidate neighbor (from the reference - // set). - struct Candidate - { - //! Distance between the reference point and the query point. - double dist; - //! Index of the reference point. - size_t index; - //! Trivial constructor. - Candidate(double d, size_t i) : - dist(d), - index(i) - {}; - //! Compare the distance of two candidates. - friend bool operator<(const Candidate& l, const Candidate& r) + //! 
Candidate represents a possible candidate neighbor (distance, index). + typedef std::pair Candidate; + + //! Compare two candidates based on the distance. + struct CandidateCmp { + bool operator()(const Candidate& c1, const Candidate& c2) { - return !SortPolicy::IsBetter(r.dist, l.dist); + return !SortPolicy::IsBetter(c2.first, c1.first); }; }; //! Use a priority queue to represent the list of candidate neighbors. - typedef std::priority_queue CandidateList; + typedef std::priority_queue, CandidateCmp> + CandidateList; + }; // class LSHSearch } // namespace neighbor diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp index d0b53aeebe8..bdef1be5b53 100644 --- a/src/mlpack/methods/lsh/lsh_search_impl.hpp +++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp @@ -275,9 +275,10 @@ void LSHSearch::BaseCase(const size_t queryIndex, // Let's build the list of candidate neighbors for the given query point. // It will be initialized with k candidates: // (WorstDistance, referenceSet->n_cols) - const Candidate def(SortPolicy::WorstDistance(), referenceSet->n_cols); + const Candidate def = std::make_pair(SortPolicy::WorstDistance(), + referenceSet->n_cols); std::vector vect(k, def); - CandidateList pqueue(std::less(), std::move(vect)); + CandidateList pqueue(CandidateCmp(), std::move(vect)); for (size_t j = 0; j < referenceIndices.n_elem; ++j) { @@ -290,9 +291,9 @@ void LSHSearch::BaseCase(const size_t queryIndex, referenceSet->unsafe_col(queryIndex), referenceSet->unsafe_col(referenceIndex)); - Candidate c(distance, referenceIndex); + Candidate c = std::make_pair(distance, referenceIndex); // If this distance is better than the worst candidate, let's insert it. 
- if (c < pqueue.top()) + if (CandidateCmp()(c, pqueue.top())) { pqueue.pop(); pqueue.push(c); @@ -301,8 +302,8 @@ void LSHSearch::BaseCase(const size_t queryIndex, for (size_t j = 1; j <= k; j++) { - neighbors(k - j, queryIndex) = pqueue.top().index; - distances(k - j, queryIndex) = pqueue.top().dist; + neighbors(k - j, queryIndex) = pqueue.top().second; + distances(k - j, queryIndex) = pqueue.top().first; pqueue.pop(); } } @@ -320,9 +321,10 @@ void LSHSearch::BaseCase(const size_t queryIndex, // Let's build the list of candidate neighbors for the given query point. // It will be initialized with k candidates: // (WorstDistance, referenceSet->n_cols) - const Candidate def(SortPolicy::WorstDistance(), referenceSet->n_cols); + const Candidate def = std::make_pair(SortPolicy::WorstDistance(), + referenceSet->n_cols); std::vector vect(k, def); - CandidateList pqueue(std::less(), std::move(vect)); + CandidateList pqueue(CandidateCmp(), std::move(vect)); for (size_t j = 0; j < referenceIndices.n_elem; ++j) { @@ -331,9 +333,9 @@ void LSHSearch::BaseCase(const size_t queryIndex, querySet.unsafe_col(queryIndex), referenceSet->unsafe_col(referenceIndex)); - Candidate c(distance, referenceIndex); + Candidate c = std::make_pair(distance, referenceIndex); // If this distance is better than the worst candidate, let's insert it. 
- if (c < pqueue.top()) + if (CandidateCmp()(c, pqueue.top())) { pqueue.pop(); pqueue.push(c); @@ -342,8 +344,8 @@ void LSHSearch::BaseCase(const size_t queryIndex, for (size_t j = 1; j <= k; j++) { - neighbors(k - j, queryIndex) = pqueue.top().index; - distances(k - j, queryIndex) = pqueue.top().dist; + neighbors(k - j, queryIndex) = pqueue.top().second; + distances(k - j, queryIndex) = pqueue.top().first; pqueue.pop(); } } diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp index 25e717529cd..a44d06a1efa 100644 --- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp +++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp @@ -143,28 +143,20 @@ class NeighborSearchRules //! The query set. const typename TreeType::Mat& querySet; - //! Candidate represents a possible candidate neighbor (from the reference - // set). - struct Candidate - { - //! Distance between the reference point and the query point. - double dist; - //! Index of the reference point. - size_t index; - //! Trivial constructor. - Candidate(double d, size_t i) : - dist(d), - index(i) - {}; - //! Compare the distance of two candidates. - friend bool operator<(const Candidate& l, const Candidate& r) + //! Candidate represents a possible candidate neighbor (distance, index). + typedef std::pair Candidate; + + //! Compare two candidates based on the distance. + struct CandidateCmp { + bool operator()(const Candidate& c1, const Candidate& c2) { - return !SortPolicy::IsBetter(r.dist, l.dist); + return !SortPolicy::IsBetter(c2.first, c1.first); }; }; //! Use a priority queue to represent the list of candidate neighbors. - typedef std::priority_queue CandidateList; + typedef std::priority_queue, CandidateCmp> + CandidateList; //! Set of candidate neighbors for each point. 
std::vector candidates; diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp index 65d258eee63..e40d09eda84 100644 --- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp +++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp @@ -42,10 +42,11 @@ NeighborSearchRules::NeighborSearchRules( // It will be initialized with k candidates: (WorstDistance, size_t() - 1) // The list of candidates will be updated when visiting new points with the // BaseCase() method. - const Candidate def(SortPolicy::WorstDistance(), size_t() - 1); + const Candidate def = std::make_pair(SortPolicy::WorstDistance(), + size_t() - 1); std::vector vect(k, def); - CandidateList pqueue(std::less(), std::move(vect)); + CandidateList pqueue(CandidateCmp(), std::move(vect)); candidates.reserve(querySet.n_cols); for (size_t i = 0; i < querySet.n_cols; i++) @@ -65,8 +66,8 @@ void NeighborSearchRules::GetResults( CandidateList& pqueue = candidates[i]; for (size_t j = 1; j <= k; j++) { - neighbors(k - j, i) = pqueue.top().index; - distances(k - j, i) = pqueue.top().dist; + neighbors(k - j, i) = pqueue.top().second; + distances(k - j, i) = pqueue.top().first; pqueue.pop(); } } @@ -136,7 +137,7 @@ inline double NeighborSearchRules::Score( } // Compare against the best k'th distance for this query point so far. - double bestDistance = candidates[queryIndex].top().dist; + double bestDistance = candidates[queryIndex].top().first; bestDistance = SortPolicy::Relax(bestDistance, epsilon); return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX; @@ -153,7 +154,7 @@ inline double NeighborSearchRules::Rescore( return oldScore; // Just check the score again against the distances. 
- double bestDistance = candidates[queryIndex].top().dist; + double bestDistance = candidates[queryIndex].top().first; bestDistance = SortPolicy::Relax(bestDistance, epsilon); return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX; @@ -376,7 +377,7 @@ inline double NeighborSearchRules:: // Loop over points held in the node. for (size_t i = 0; i < queryNode.NumPoints(); ++i) { - const double distance = candidates[queryNode.Point(i)].top().dist; + const double distance = candidates[queryNode.Point(i)].top().first; if (SortPolicy::IsBetter(worstDistance, distance)) worstDistance = distance; if (SortPolicy::IsBetter(distance, bestPointDistance)) @@ -467,9 +468,10 @@ InsertNeighbor( const size_t neighbor, const double distance) { - Candidate c(distance, neighbor); CandidateList& pqueue = candidates[queryIndex]; - if (c < pqueue.top()) + Candidate c = std::make_pair(distance, neighbor); + + if (CandidateCmp()(c, pqueue.top())) { pqueue.pop(); pqueue.push(c); diff --git a/src/mlpack/methods/rann/ra_search_rules.hpp b/src/mlpack/methods/rann/ra_search_rules.hpp index c7edeba9ebc..93b7a8c2ebc 100644 --- a/src/mlpack/methods/rann/ra_search_rules.hpp +++ b/src/mlpack/methods/rann/ra_search_rules.hpp @@ -243,28 +243,20 @@ class RASearchRules //! The query set. const arma::mat& querySet; - //! Candidate represents a possible candidate neighbor (from the reference - // set). - struct Candidate - { - //! Distance between the reference point and the query point. - double dist; - //! Index of the reference point. - size_t index; - //! Trivial constructor. - Candidate(double d, size_t i) : - dist(d), - index(i) - {}; - //! Compare the distance of two candidates. - friend bool operator<(const Candidate& l, const Candidate& r) + //! Candidate represents a possible candidate neighbor (distance, index). + typedef std::pair Candidate; + + //! Compare two candidates based on the distance. 
+ struct CandidateCmp { + bool operator()(const Candidate& c1, const Candidate& c2) { - return !SortPolicy::IsBetter(r.dist, l.dist); + return !SortPolicy::IsBetter(c2.first, c1.first); }; }; //! Use a priority queue to represent the list of candidate neighbors. - typedef std::priority_queue CandidateList; + typedef std::priority_queue, CandidateCmp> + CandidateList; //! Set of candidate neighbors for each point. std::vector candidates; diff --git a/src/mlpack/methods/rann/ra_search_rules_impl.hpp b/src/mlpack/methods/rann/ra_search_rules_impl.hpp index bad3e24932a..9a886db53d8 100644 --- a/src/mlpack/methods/rann/ra_search_rules_impl.hpp +++ b/src/mlpack/methods/rann/ra_search_rules_impl.hpp @@ -69,10 +69,11 @@ RASearchRules(const arma::mat& referenceSet, // It will be initialized with k candidates: (WorstDistance, size_t() - 1) // The list of candidates will be updated when visiting new points with the // BaseCase() method. - const Candidate def(SortPolicy::WorstDistance(), size_t() - 1); + const Candidate def = std::make_pair(SortPolicy::WorstDistance(), + size_t() - 1); std::vector vect(k, def); - CandidateList pqueue(std::less(), std::move(vect)); + CandidateList pqueue(CandidateCmp(), std::move(vect)); candidates.reserve(querySet.n_cols); for (size_t i = 0; i < querySet.n_cols; i++) @@ -104,8 +105,8 @@ void RASearchRules::GetResults( CandidateList& pqueue = candidates[i]; for (size_t j = 1; j <= k; j++) { - neighbors(k - j, i) = pqueue.top().index; - distances(k - j, i) = pqueue.top().dist; + neighbors(k - j, i) = pqueue.top().second; + distances(k - j, i) = pqueue.top().first; pqueue.pop(); } } @@ -143,7 +144,7 @@ inline double RASearchRules::Score( const arma::vec queryPoint = querySet.unsafe_col(queryIndex); const double distance = SortPolicy::BestPointToNodeDistance(queryPoint, &referenceNode); - const double bestDistance = candidates[queryIndex].top().dist; + const double bestDistance = candidates[queryIndex].top().first; return Score(queryIndex, 
referenceNode, distance, bestDistance); } @@ -157,7 +158,7 @@ inline double RASearchRules::Score( const arma::vec queryPoint = querySet.unsafe_col(queryIndex); const double distance = SortPolicy::BestPointToNodeDistance(queryPoint, &referenceNode, baseCaseResult); - const double bestDistance = candidates[queryIndex].top().dist; + const double bestDistance = candidates[queryIndex].top().first; return Score(queryIndex, referenceNode, distance, bestDistance); } @@ -271,7 +272,7 @@ Rescore(const size_t queryIndex, return oldScore; // Just check the score again against the distances. - const double bestDistance = candidates[queryIndex].top().dist; + const double bestDistance = candidates[queryIndex].top().first; // If this is better than the best distance we've seen so far, // maybe there will be something down this node. @@ -371,7 +372,7 @@ inline double RASearchRules::Score( for (size_t i = 0; i < queryNode.NumPoints(); i++) { - const double bound = candidates[queryNode.Point(i)].top().dist + const double bound = candidates[queryNode.Point(i)].top().first + maxDescendantDistance; if (bound < pointBound) pointBound = bound; @@ -410,7 +411,7 @@ inline double RASearchRules::Score( for (size_t i = 0; i < queryNode.NumPoints(); i++) { - const double bound = candidates[queryNode.Point(i)].top().dist + const double bound = candidates[queryNode.Point(i)].top().first + maxDescendantDistance; if (bound < pointBound) pointBound = bound; @@ -624,7 +625,7 @@ Rescore(TreeType& queryNode, for (size_t i = 0; i < queryNode.NumPoints(); i++) { - const double bound = candidates[queryNode.Point(i)].top().dist + const double bound = candidates[queryNode.Point(i)].top().first + maxDescendantDistance; if (bound < pointBound) pointBound = bound; @@ -809,9 +810,10 @@ InsertNeighbor( const size_t neighbor, const double distance) { - Candidate c(distance, neighbor); CandidateList& pqueue = candidates[queryIndex]; - if (c < pqueue.top()) + Candidate c = std::make_pair(distance, neighbor); + + if 
(CandidateCmp()(c, pqueue.top())) { pqueue.pop(); pqueue.push(c); From a79ae634f673051ce6a8d319e5b36f182f23afc1 Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Tue, 26 Jul 2016 16:29:19 -0300 Subject: [PATCH 15/15] Use boost::heap::priority_queue instead of push_heap()/pop_heap(). --- src/mlpack/methods/fastmks/fastmks_rules.hpp | 12 +++--- .../methods/fastmks/fastmks_rules_impl.hpp | 43 +++++++++---------- 2 files changed, 28 insertions(+), 27 deletions(-) diff --git a/src/mlpack/methods/fastmks/fastmks_rules.hpp b/src/mlpack/methods/fastmks/fastmks_rules.hpp index 9aca42d3140..49071c946b1 100644 --- a/src/mlpack/methods/fastmks/fastmks_rules.hpp +++ b/src/mlpack/methods/fastmks/fastmks_rules.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include namespace mlpack { @@ -129,17 +130,18 @@ class FastMKSRules //! Compare two candidates based on the value. struct CandidateCmp { - bool operator()(const Candidate& c1, const Candidate& c2) + bool operator()(const Candidate& c1, const Candidate& c2) const { return c1.first > c2.first; }; }; //! Use a min heap to represent the list of candidate points. - //! We will use a vector and the std functions: push_heap() pop_heap(). - //! We can not use a priority queue because we need to iterate over all the - //! candidates and std::priority_queue doesn't provide that interface. - typedef std::vector CandidateList; + //! We will use a boost::heap::priority_queue instead of a std::priority_queue + //! because we need to iterate over all the candidates and std::priority_queue + //! doesn't provide that interface. + typedef boost::heap::priority_queue> CandidateList; //! Set of candidates for each point. 
std::vector candidates; diff --git a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp index 6efc9a7d96d..fe6e09b8065 100644 --- a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp +++ b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp @@ -52,8 +52,11 @@ FastMKSRules::FastMKSRules( // BaseCase() method. const Candidate def = std::make_pair(-DBL_MAX, size_t() - 1); - CandidateList cList(k, def); - std::vector tmp(querySet.n_cols, cList); + CandidateList pqueue; + pqueue.reserve(k); + for (size_t i = 0; i < k; i++) + pqueue.push(def); + std::vector tmp(querySet.n_cols, pqueue); candidates.swap(tmp); } @@ -68,15 +71,11 @@ void FastMKSRules::GetResults( for (size_t i = 0; i < querySet.n_cols; i++) { CandidateList& pqueue = candidates[i]; - typedef typename CandidateList::iterator Iterator; - - for (Iterator end = pqueue.end(); end != pqueue.begin(); --end) - std::pop_heap(pqueue.begin(), end, CandidateCmp()); - - for (size_t j = 0; j < k; j++) + for (size_t j = 1; j <= k; j++) { - indices(j, i) = pqueue[j].second; - products(j, i) = pqueue[j].first; + indices(k - j, i) = pqueue.top().second; + products(k - j, i) = pqueue.top().first; + pqueue.pop(); } } } @@ -126,7 +125,7 @@ double FastMKSRules::Score(const size_t queryIndex, TreeType& referenceNode) { // Compare with the current best. - const double bestKernel = candidates[queryIndex].front().first; + const double bestKernel = candidates[queryIndex].top().first; // See if we can perform a parent-child prune. const double furthestDist = referenceNode.FurthestDescendantDistance(); @@ -409,7 +408,7 @@ double FastMKSRules::Rescore(const size_t queryIndex, TreeType& /*referenceNode*/, const double oldScore) const { - const double bestKernel = candidates[queryIndex].front().first; + const double bestKernel = candidates[queryIndex].top().first; return ((1.0 / oldScore) >= bestKernel) ? 
oldScore : DBL_MAX; } @@ -457,10 +456,10 @@ double FastMKSRules::CalculateBound(TreeType& queryNode) { const size_t point = queryNode.Point(i); const CandidateList& candidatesPoints = candidates[point]; - if (candidatesPoints.front().first < worstPointKernel) - worstPointKernel = candidatesPoints.front().first; + if (candidatesPoints.top().first < worstPointKernel) + worstPointKernel = candidatesPoints.top().first; - if (candidatesPoints.front().first == -DBL_MAX) + if (candidatesPoints.top().first == -DBL_MAX) continue; // Avoid underflow. // This should be (queryDescendantDistance + centroidDistance) for any tree @@ -475,10 +474,11 @@ double FastMKSRules::CalculateBound(TreeType& queryNode) // where p_j^*(p_q) is the j'th kernel candidate for query point p_q and // k_j^*(p_q) is K(p_q, p_j^*(p_q)). double worstPointCandidateKernel = DBL_MAX; - for (size_t j = 0; j < candidatesPoints.size(); ++j) + typedef typename CandidateList::const_iterator iter; + for (iter it = candidatesPoints.begin(); it != candidatesPoints.end(); ++it) { - const double candidateKernel = candidatesPoints[j].first - - queryDescendantDistance * referenceKernels[candidatesPoints[j].second]; + const double candidateKernel = it->first - queryDescendantDistance * + referenceKernels[it->second]; if (candidateKernel < worstPointCandidateKernel) worstPointCandidateKernel = candidateKernel; } @@ -526,12 +526,11 @@ inline void FastMKSRules::InsertNeighbor( const double product) { CandidateList& pqueue = candidates[queryIndex]; - if (product > pqueue.front().first) + if (product > pqueue.top().first) { Candidate c = std::make_pair(product, index); - std::pop_heap(pqueue.begin(), pqueue.end(), CandidateCmp()); - pqueue.back() = c; - std::push_heap(pqueue.begin(), pqueue.end(), CandidateCmp()); + pqueue.pop(); + pqueue.push(c); } }