From 88f1bb955493fe06a5c6521142d188a135972474 Mon Sep 17 00:00:00 2001 From: Mikhail Lozhnikov Date: Mon, 27 Jun 2016 00:32:44 +0300 Subject: [PATCH 01/12] Added vantage point trees --- src/mlpack/core/tree/CMakeLists.txt | 2 + src/mlpack/core/tree/binary_space_tree.hpp | 1 + .../core/tree/binary_space_tree/typedef.hpp | 7 + .../binary_space_tree/vantage_point_split.hpp | 108 ++++++++ .../vantage_point_split_impl.hpp | 242 ++++++++++++++++++ src/mlpack/tests/tree_test.cpp | 46 ++++ 6 files changed, 406 insertions(+) create mode 100644 src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp create mode 100644 src/mlpack/core/tree/binary_space_tree/vantage_point_split_impl.hpp diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt index 36f9e836f61..62696e65b19 100644 --- a/src/mlpack/core/tree/CMakeLists.txt +++ b/src/mlpack/core/tree/CMakeLists.txt @@ -16,6 +16,8 @@ set(SOURCES binary_space_tree/midpoint_split_impl.hpp binary_space_tree/single_tree_traverser.hpp binary_space_tree/single_tree_traverser_impl.hpp + binary_space_tree/vantage_point_split.hpp + binary_space_tree/vantage_point_split_impl.hpp binary_space_tree/traits.hpp binary_space_tree/typedef.hpp bounds.hpp diff --git a/src/mlpack/core/tree/binary_space_tree.hpp b/src/mlpack/core/tree/binary_space_tree.hpp index e37b7b25670..ba549817a24 100644 --- a/src/mlpack/core/tree/binary_space_tree.hpp +++ b/src/mlpack/core/tree/binary_space_tree.hpp @@ -11,6 +11,7 @@ #include "bounds.hpp" #include "binary_space_tree/midpoint_split.hpp" #include "binary_space_tree/mean_split.hpp" +#include "binary_space_tree/vantage_point_split.hpp" #include "binary_space_tree/binary_space_tree.hpp" #include "binary_space_tree/single_tree_traverser.hpp" #include "binary_space_tree/single_tree_traverser_impl.hpp" diff --git a/src/mlpack/core/tree/binary_space_tree/typedef.hpp b/src/mlpack/core/tree/binary_space_tree/typedef.hpp index 28145d11bb2..c9d5b215f72 100644 --- 
a/src/mlpack/core/tree/binary_space_tree/typedef.hpp +++ b/src/mlpack/core/tree/binary_space_tree/typedef.hpp @@ -135,6 +135,13 @@ using MeanSplitBallTree = BinarySpaceTree; +template +using VantagePointTree = BinarySpaceTree; + } // namespace tree } // namespace mlpack diff --git a/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp b/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp new file mode 100644 index 00000000000..625c26d8b75 --- /dev/null +++ b/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp @@ -0,0 +1,108 @@ +/** + * @file vantage_point_split.hpp + * @author Mikhail Lozhnikov + * + */ +#ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_HPP +#define MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_HPP + +#include + +namespace mlpack { +namespace tree /** Trees and tree-building procedures. */ { + +template +class VantagePointSplit +{ + public: + typedef typename MatType::elem_type ElemType; + /** + * + * @param bound The bound used for this node. + * @param data The dataset used by the binary space tree. + * @param begin Index of the starting point in the dataset that belongs to + * this node. + * @param count Number of points in this node. + * @param splitCol The index at which the dataset is divided into two parts + * after the rearrangement. + */ + static bool SplitNode(const BoundType& bound, + MatType& data, + const size_t begin, + const size_t count, + size_t& splitCol); + + /** + * + * @param bound The bound used for this node. + * @param data The dataset used by the binary space tree. + * @param begin Index of the starting point in the dataset that belongs to + * this node. + * @param count Number of points in this node. + * @param splitCol The index at which the dataset is divided into two parts + * after the rearrangement. + * @param oldFromNew Vector which will be filled with the old positions for + * each new point. 
+ */ + static bool SplitNode(const BoundType& bound, + MatType& data, + const size_t begin, + const size_t count, + size_t& splitCol, + std::vector& oldFromNew); + private: + static const size_t maxNumSamples = 1000; + + template + struct SortStruct + { + size_t point; + size_t n; + ElemType dist; + }; + + template + static bool StructComp(const SortStruct& s1, + const SortStruct& s2) + { + return (s1.dist < s2.dist); + }; + + static void SelectVantagePoint(const BoundType& bound, const MatType& data, + const size_t begin, const size_t count, size_t& vantagePoint, ElemType& mu); + + static void GetDistinctSamples(arma::uvec& distinctSamples, + const size_t numSamples, const size_t begin, const size_t upperBound); + + static void GetMedian(const BoundType& bound, const MatType& data, + const arma::uvec& samples, const size_t vantagePoint, ElemType& mu); + + static ElemType GetSecondMoment(const BoundType& bound, const MatType& data, + const arma::uvec& samples, const size_t vantagePoint); + + static bool IsContainedInBall(const BoundType& bound, const MatType& mat, + const size_t vantagePoint, const size_t point, const ElemType mu); + + static size_t PerformSplit(const BoundType& bound, + MatType& data, + const size_t begin, + const size_t count, + const size_t vantagePoint, + const ElemType mu); + + static size_t PerformSplit(const BoundType& bound, + MatType& data, + const size_t begin, + const size_t count, + const size_t vantagePoint, + const ElemType mu, + std::vector& oldFromNew); +}; + +} // namespace tree +} // namespace mlpack + +// Include implementation. 
+#include "vantage_point_split_impl.hpp" + +#endif // MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_HPP diff --git a/src/mlpack/core/tree/binary_space_tree/vantage_point_split_impl.hpp b/src/mlpack/core/tree/binary_space_tree/vantage_point_split_impl.hpp new file mode 100644 index 00000000000..3b0ceffc38c --- /dev/null +++ b/src/mlpack/core/tree/binary_space_tree/vantage_point_split_impl.hpp @@ -0,0 +1,242 @@ +/** + * @file vantage_point_split_impl.hpp + * @author Mikhail Lozhnikov + * + * Implementation of class (VantagePointSplit) to split a binary space partition + * tree. + */ +#ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_IMPL_HPP +#define MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_IMPL_HPP + +#include "vantage_point_split.hpp" +#include + +namespace mlpack { +namespace tree { + +template +bool VantagePointSplit:: +SplitNode(const BoundType& bound, MatType& data, const size_t begin, + const size_t count, size_t& splitCol) +{ + typename BoundType::ElemType mu; + size_t vantagePoint; + + SelectVantagePoint(bound, data, begin, count, vantagePoint, mu); + + splitCol = PerformSplit(bound, data, begin, count, vantagePoint, mu); + return true; +} + +template +bool VantagePointSplit:: +SplitNode(const BoundType& bound, MatType& data, const size_t begin, + const size_t count, size_t& splitCol, std::vector& oldFromNew) +{ + ElemType mu; + size_t vantagePoint; + + SelectVantagePoint(bound, data, begin, count, vantagePoint, mu); + + splitCol = PerformSplit(bound, data, begin, count, vantagePoint, mu, oldFromNew); + return true; +} + +template +size_t VantagePointSplit::PerformSplit(const BoundType& bound, + MatType& data, + const size_t begin, + const size_t count, + const size_t vantagePoint, + const ElemType mu) +{ + // This method modifies the input dataset. We loop both from the left and + // right sides of the points contained in this node. 
The points less than + // splitVal should be on the left side of the matrix, and the points greater + // than splitVal should be on the right side of the matrix. + size_t left = begin; + size_t right = begin + count - 1; + + // First half-iteration of the loop is out here because the termination + // condition is in the middle. + while (IsContainedInBall(bound, data, vantagePoint, left, mu) && (left <= right)) + left++; + while ((!IsContainedInBall(bound, data, vantagePoint, right, mu)) && (left <= right) && (right > 0)) + right--; + + while (left <= right) + { + // Swap columns. + data.swap_cols(left, right); + + // See how many points on the left are correct. When they are correct, + // increase the left counter accordingly. When we encounter one that isn't + // correct, stop. We will switch it later. + while ((IsContainedInBall(bound, data, vantagePoint, left, mu)) && (left <= right)) + left++; + + // Now see how many points on the right are correct. When they are correct, + // decrease the right counter accordingly. When we encounter one that isn't + // correct, stop. We will switch it with the wrong point we found in the + // previous loop. + while ((!IsContainedInBall(bound, data, vantagePoint, right, mu)) && (left <= right)) + right--; + } + + Log::Assert(left == right + 1); + + return left; +} + +template +size_t VantagePointSplit::PerformSplit(const BoundType& bound, + MatType& data, + const size_t begin, + const size_t count, + const size_t vantagePoint, + const ElemType mu, + std::vector& oldFromNew) +{ + // This method modifies the input dataset. We loop both from the left and + // right sides of the points contained in this node. The points less than + // splitVal should be on the left side of the matrix, and the points greater + // than splitVal should be on the right side of the matrix. 
+ size_t left = begin; + size_t right = begin + count - 1; + + // First half-iteration of the loop is out here because the termination + // condition is in the middle. + while (IsContainedInBall(bound, data, vantagePoint, left, mu) && (left <= right)) + left++; + while (!IsContainedInBall(bound, data, vantagePoint, right, mu) && (left <= right) && (right > 0)) + right--; + + while (left <= right) + { + // Swap columns. + data.swap_cols(left, right); + + // Update the indices for what we changed. + size_t t = oldFromNew[left]; + oldFromNew[left] = oldFromNew[right]; + oldFromNew[right] = t; + + // See how many points on the left are correct. When they are correct, + // increase the left counter accordingly. When we encounter one that isn't + // correct, stop. We will switch it later. + while (IsContainedInBall(bound, data, vantagePoint, left, mu) && (left <= right)) + left++; + + // Now see how many points on the right are correct. When they are correct, + // decrease the right counter accordingly. When we encounter one that isn't + // correct, stop. We will switch it with the wrong point we found in the + // previous loop. 
+ while (!IsContainedInBall(bound, data, vantagePoint, right, mu) && (left <= right)) + right--; + } + + Log::Assert(left == right + 1); + + return left; +} + +template +void VantagePointSplit:: +SelectVantagePoint(const BoundType& bound, const MatType& data, + const size_t begin, const size_t count, size_t& vantagePoint, ElemType& mu) +{ + arma::uvec vantagePointCandidates; + + GetDistinctSamples(vantagePointCandidates, maxNumSamples, begin, count); + + ElemType bestSpread = 0; + + for (size_t i = 0; i < vantagePointCandidates.n_rows; i++) + { + arma::uvec samples; + + GetDistinctSamples(samples, maxNumSamples, begin, count); + + const ElemType spread = GetSecondMoment(bound, data, samples, + vantagePointCandidates[i]); + + if (spread > bestSpread) + { + bestSpread = spread; + vantagePoint = vantagePointCandidates[i]; + GetMedian(bound, data, samples, vantagePoint, mu); + } + } +} + +template +void VantagePointSplit:: +GetDistinctSamples(arma::uvec& distinctSamples, const size_t numSamples, + const size_t begin, const size_t upperBound) +{ + arma::Col samples; + + samples.zeros(upperBound); + + for (size_t i = 0; i < numSamples; i++) + samples [ (size_t) math::RandInt(upperBound) ]++; + + distinctSamples = arma::find(samples > 0); + + distinctSamples += begin; +} + +template +void VantagePointSplit:: +GetMedian(const BoundType& bound, const MatType& data, + const arma::uvec& samples, const size_t vantagePoint, ElemType& mu) +{ + std::vector> sorted(samples.n_rows); + + for (size_t i = 0; i < samples.n_rows; i++) + { + sorted[i].point = samples[i]; + sorted[i].dist = bound.Metric().Evaluate(data.col(vantagePoint), + data.col(samples[i])); + sorted[i].n = i; + } + + std::sort(sorted.begin(), sorted.end(), StructComp); + + mu = bound.Metric().Evaluate(data.col(vantagePoint), + data.col(sorted[sorted.size() / 2].n)); +} + +template +typename MatType::elem_type VantagePointSplit:: +GetSecondMoment(const BoundType& bound, const MatType& data, + const arma::uvec& 
samples, const size_t vantagePoint) +{ + ElemType moment = 0; + + for (size_t i = 0; i < samples.size(); i++) + { + const ElemType dist = + bound.Metric().Evaluate(data.col(vantagePoint), data.col(samples[i])); + + moment += dist * dist; + } + + return moment; +} + +template +bool VantagePointSplit:: +IsContainedInBall(const BoundType& bound, const MatType& mat, + const size_t vantagePoint, const size_t point, const ElemType mu) +{ + if (bound.Metric().Evaluate(mat.col(vantagePoint), mat.col(point)) < mu) + return true; + + return false; +} + +} // namespace tree +} // namespace mlpack + +#endif // MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_IMPL_HPP diff --git a/src/mlpack/tests/tree_test.cpp b/src/mlpack/tests/tree_test.cpp index 81a94463b25..12f3678c338 100644 --- a/src/mlpack/tests/tree_test.cpp +++ b/src/mlpack/tests/tree_test.cpp @@ -1430,6 +1430,52 @@ BOOST_AUTO_TEST_CASE(BallTreeTest) } } +BOOST_AUTO_TEST_CASE(VantagePointTreeTest) +{ + typedef VantagePointTree TreeType; + + size_t maxRuns = 10; // Ten total tests. + size_t pointIncrements = 1000; // Range is from 2000 points to 11000. + + // We use the default leaf size of 20. + for (size_t run = 0; run < maxRuns; run++) + { + size_t dimensions = run + 2; + size_t maxPoints = (run + 1) * pointIncrements; + + size_t size = maxPoints; + arma::mat dataset = arma::mat(dimensions, size); + arma::mat datacopy; // Used to test mappings. + + // Mappings for post-sort verification of data. + std::vector newToOld; + std::vector oldToNew; + + // Generate data. + dataset.randu(); + + // Build the tree itself. + TreeType root(dataset, newToOld, oldToNew); + const arma::mat& treeset = root.Dataset(); + + // Ensure the size of the tree is correct. + BOOST_REQUIRE_EQUAL(root.NumDescendants(), size); + + // Check the forward and backward mappings for correctness. 
+ for(size_t i = 0; i < size; i++) + { + for(size_t j = 0; j < dimensions; j++) + { + BOOST_REQUIRE_EQUAL(treeset(j, i), dataset(j, newToOld[i])); + BOOST_REQUIRE_EQUAL(treeset(j, oldToNew[i]), dataset(j, i)); + } + } + + // Now check that each point is contained inside of all bounds above it. + CheckPointBounds(root); + } +} + template bool DoBoundsIntersect(HRectBound& a, HRectBound& b) From 6d8976e211e46c57a11912eb8fce0ab2f891b8d3 Mon Sep 17 00:00:00 2001 From: Mikhail Lozhnikov Date: Mon, 27 Jun 2016 15:10:52 +0300 Subject: [PATCH 02/12] Added some vantage point tree fixes and comments. --- .../binary_space_tree/vantage_point_split.hpp | 101 +++++++++++++++- .../vantage_point_split_impl.hpp | 111 +++++++++++++----- 2 files changed, 174 insertions(+), 38 deletions(-) diff --git a/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp b/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp index 625c26d8b75..f587e6592ce 100644 --- a/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp +++ b/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp @@ -2,6 +2,8 @@ * @file vantage_point_split.hpp * @author Mikhail Lozhnikov * + * Definition of class VantagePointSplit, a class that splits a binary space + * partitioning into two parts using the distance to a certain vantage point. */ #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_HPP #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_HPP @@ -17,6 +19,7 @@ class VantagePointSplit public: typedef typename MatType::elem_type ElemType; /** + * Split the node according to the distance to a vantage point. * * @param bound The bound used for this node. * @param data The dataset used by the binary space tree. @@ -33,6 +36,7 @@ class VantagePointSplit size_t& splitCol); /** + * Split the node according to the distance to a vantage point. * * @param bound The bound used for this node. * @param data The dataset used by the binary space tree. 
@@ -51,13 +55,16 @@ class VantagePointSplit size_t& splitCol, std::vector& oldFromNew); private: - static const size_t maxNumSamples = 1000; + /** + * The maximum number of samples used for vantage point estimation and for + * estimation of the median. + */ + static const size_t maxNumSamples = 100; template struct SortStruct { size_t point; - size_t n; ElemType dist; }; @@ -68,33 +75,115 @@ class VantagePointSplit return (s1.dist < s2.dist); }; + /** + * Select the best vantage point i.e. the point with the largest second moment + * of the distance from a number of random node points to the vantage point. + * First, this method selects no more than maxNumSamples random points. + * Then it evaluates each point i.e. calculates the corresponding second + * moment and selects the point with the largest moment. Each random point + * belongs to the node. + * + * @param bound The bound used for this node. + * @param data The dataset used by the binary space tree. + * @param begin Index of the starting point in the dataset that belongs to + * this node. + * @param count Number of points in this node. + * @param vantagePoint The index of the vantage point in the dataset. + * @param mu The median value of distance from the vantage point to + * a number of random points. + */ static void SelectVantagePoint(const BoundType& bound, const MatType& data, const size_t begin, const size_t count, size_t& vantagePoint, ElemType& mu); + /** + * Find no more than min(numSamples, upperBound) random samples i.e. + * random points that belong to the node. Each sample belongs to + * the interval [begin, begin + upperBound) + * + * @param distinctSamples The vector of samples indices. + * @param numSamples Maximum number of samples. + * @param begin The least index. + * @param upperBound The upper bound of indices. 
+ */ static void GetDistinctSamples(arma::uvec& distinctSamples, const size_t numSamples, const size_t begin, const size_t upperBound); + /** + * Get the median value of the distance from a certain vantage point to a + * number of samples. + * + * @param bound The bound used for this node. + * @param data The dataset used by the binary space tree. + * @param samples The indices of random samples. + * @param vantagePoint The vantage point. + * @param mu The median value. + */ static void GetMedian(const BoundType& bound, const MatType& data, const arma::uvec& samples, const size_t vantagePoint, ElemType& mu); + /** + * Calculate the second moment of the distance from a certain vantage point to + * a number of random samples. + * + * @param bound The bound used for this node. + * @param data The dataset used by the binary space tree. + * @param samples The indices of random samples. + * @param vantagePoint The vantage point. + */ static ElemType GetSecondMoment(const BoundType& bound, const MatType& data, const arma::uvec& samples, const size_t vantagePoint); - static bool IsContainedInBall(const BoundType& bound, const MatType& mat, - const size_t vantagePoint, const size_t point, const ElemType mu); + /** + * This method returns true if a point should be assigned to the left subtree + * i.e. the distance from the point to the vantage point is less than + * the median value. Otherwise it returns false. + * + * @param bound The bound used for this node. + * @param mat The dataset used by the binary space tree. + * @param vantagePoint The vantage point. + * @param point The point that is being assigned. + * @param mu The median value. + */ + template + static bool AssignToLeftSubtree(const BoundType& bound, const MatType& mat, + const VecType& vantagePoint, const size_t point, const ElemType mu); + /** + * Perform split according to the median value and the vantage point. + * + * @param data The dataset used by the binary space tree. 
+ * @param begin Index of the starting point in the dataset that belongs to + * this node. + * @param count Number of points in this node. + * @param vantagePoint The vantage point. + * @param mu The median value. + */ + template static size_t PerformSplit(const BoundType& bound, MatType& data, const size_t begin, const size_t count, - const size_t vantagePoint, + const VecType& vantagePoint, const ElemType mu); + /** + * Perform split according to the median value and the vantage point. + * + * @param data The dataset used by the binary space tree. + * @param begin Index of the starting point in the dataset that belongs to + * this node. + * @param count Number of points in this node. + * @param vantagePoint The vantage point. + * @param mu The median value. + * @param oldFromNew Vector which will be filled with the old positions for + * each new point. + */ + template static size_t PerformSplit(const BoundType& bound, MatType& data, const size_t begin, const size_t count, - const size_t vantagePoint, + const VecType& vantagePoint, const ElemType mu, std::vector& oldFromNew); }; diff --git a/src/mlpack/core/tree/binary_space_tree/vantage_point_split_impl.hpp b/src/mlpack/core/tree/binary_space_tree/vantage_point_split_impl.hpp index 3b0ceffc38c..e762eb9affa 100644 --- a/src/mlpack/core/tree/binary_space_tree/vantage_point_split_impl.hpp +++ b/src/mlpack/core/tree/binary_space_tree/vantage_point_split_impl.hpp @@ -3,7 +3,7 @@ * @author Mikhail Lozhnikov * * Implementation of class (VantagePointSplit) to split a binary space partition - * tree. + * tree according to the median value of the distance to a certain vantage point. 
*/ #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_IMPL_HPP #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_IMPL_HPP @@ -20,11 +20,20 @@ SplitNode(const BoundType& bound, MatType& data, const size_t begin, const size_t count, size_t& splitCol) { typename BoundType::ElemType mu; - size_t vantagePoint; + size_t vantagePointIndex; - SelectVantagePoint(bound, data, begin, count, vantagePoint, mu); + // Find the best vantage point + SelectVantagePoint(bound, data, begin, count, vantagePointIndex, mu); + // All points are equal + if (mu == 0) + return false; + + arma::Col vantagePoint = data.col(vantagePointIndex); splitCol = PerformSplit(bound, data, begin, count, vantagePoint, mu); + + assert(splitCol > begin); + assert(splitCol < begin + count); return true; } @@ -34,34 +43,46 @@ SplitNode(const BoundType& bound, MatType& data, const size_t begin, const size_t count, size_t& splitCol, std::vector& oldFromNew) { ElemType mu; - size_t vantagePoint; + size_t vantagePointIndex; + + // Find the best vantage point + SelectVantagePoint(bound, data, begin, count, vantagePointIndex, mu); + + // All points are equal + if (mu == 0) + return false; - SelectVantagePoint(bound, data, begin, count, vantagePoint, mu); + arma::Col vantagePoint = data.col(vantagePointIndex); splitCol = PerformSplit(bound, data, begin, count, vantagePoint, mu, oldFromNew); + + assert(splitCol > begin); + assert(splitCol < begin + count); return true; } template +template size_t VantagePointSplit::PerformSplit(const BoundType& bound, MatType& data, const size_t begin, const size_t count, - const size_t vantagePoint, + const VecType& vantagePoint, const ElemType mu) { // This method modifies the input dataset. We loop both from the left and - // right sides of the points contained in this node. The points less than - // splitVal should be on the left side of the matrix, and the points greater - // than splitVal should be on the right side of the matrix. 
+ // right sides of the points contained in this node. The points closer to + // the vantage point should be on the left side of the matrix, and the farther + // from the vantage point should be on the right side of the matrix. size_t left = begin; size_t right = begin + count - 1; // First half-iteration of the loop is out here because the termination // condition is in the middle. - while (IsContainedInBall(bound, data, vantagePoint, left, mu) && (left <= right)) + while (AssignToLeftSubtree(bound, data, vantagePoint, left, mu) && (left <= right)) left++; - while ((!IsContainedInBall(bound, data, vantagePoint, right, mu)) && (left <= right) && (right > 0)) + + while ((!AssignToLeftSubtree(bound, data, vantagePoint, right, mu)) && (left <= right) && (right > 0)) right--; while (left <= right) @@ -72,14 +93,14 @@ size_t VantagePointSplit::PerformSplit(const BoundType& boun // See how many points on the left are correct. When they are correct, // increase the left counter accordingly. When we encounter one that isn't // correct, stop. We will switch it later. - while ((IsContainedInBall(bound, data, vantagePoint, left, mu)) && (left <= right)) + while ((AssignToLeftSubtree(bound, data, vantagePoint, left, mu)) && (left <= right)) left++; // Now see how many points on the right are correct. When they are correct, // decrease the right counter accordingly. When we encounter one that isn't // correct, stop. We will switch it with the wrong point we found in the // previous loop. 
- while ((!IsContainedInBall(bound, data, vantagePoint, right, mu)) && (left <= right)) + while ((!AssignToLeftSubtree(bound, data, vantagePoint, right, mu)) && (left <= right)) right--; } @@ -89,26 +110,29 @@ size_t VantagePointSplit::PerformSplit(const BoundType& boun } template +template size_t VantagePointSplit::PerformSplit(const BoundType& bound, MatType& data, const size_t begin, const size_t count, - const size_t vantagePoint, + const VecType& vantagePoint, const ElemType mu, std::vector& oldFromNew) { // This method modifies the input dataset. We loop both from the left and - // right sides of the points contained in this node. The points less than - // splitVal should be on the left side of the matrix, and the points greater - // than splitVal should be on the right side of the matrix. + // right sides of the points contained in this node. The points closer to + // the vantage point should be on the left side of the matrix, and the farther + // from the vantage point should be on the right side of the matrix. size_t left = begin; size_t right = begin + count - 1; // First half-iteration of the loop is out here because the termination // condition is in the middle. - while (IsContainedInBall(bound, data, vantagePoint, left, mu) && (left <= right)) + + while (AssignToLeftSubtree(bound, data, vantagePoint, left, mu) && (left <= right)) left++; - while (!IsContainedInBall(bound, data, vantagePoint, right, mu) && (left <= right) && (right > 0)) + + while ((!AssignToLeftSubtree(bound, data, vantagePoint, right, mu)) && (left <= right) && (right > 0)) right--; while (left <= right) @@ -124,14 +148,14 @@ size_t VantagePointSplit::PerformSplit(const BoundType& boun // See how many points on the left are correct. When they are correct, // increase the left counter accordingly. When we encounter one that isn't // correct, stop. We will switch it later. 
- while (IsContainedInBall(bound, data, vantagePoint, left, mu) && (left <= right)) + while (AssignToLeftSubtree(bound, data, vantagePoint, left, mu) && (left <= right)) left++; // Now see how many points on the right are correct. When they are correct, // decrease the right counter accordingly. When we encounter one that isn't // correct, stop. We will switch it with the wrong point we found in the // previous loop. - while (!IsContainedInBall(bound, data, vantagePoint, right, mu) && (left <= right)) + while ((!AssignToLeftSubtree(bound, data, vantagePoint, right, mu)) && (left <= right)) right--; } @@ -147,16 +171,21 @@ SelectVantagePoint(const BoundType& bound, const MatType& data, { arma::uvec vantagePointCandidates; + // Get no more than max(maxNumSamples, count) vantage point candidates GetDistinctSamples(vantagePointCandidates, maxNumSamples, begin, count); ElemType bestSpread = 0; + // Evaluate eache candidate for (size_t i = 0; i < vantagePointCandidates.n_rows; i++) { arma::uvec samples; + // Get no more than max(maxNumSamples, count) random samples GetDistinctSamples(samples, maxNumSamples, begin, count); + // Calculate the second moment of the distance to the vantage point candidate + // using these random samples const ElemType spread = GetSecondMoment(bound, data, samples, vantagePointCandidates[i]); @@ -164,9 +193,12 @@ SelectVantagePoint(const BoundType& bound, const MatType& data, { bestSpread = spread; vantagePoint = vantagePointCandidates[i]; + // Calculate the median value of the distance from the vantage point candidate + // to these samples GetMedian(bound, data, samples, vantagePoint, mu); } } + assert(bestSpread > 0); } template @@ -174,16 +206,26 @@ void VantagePointSplit:: GetDistinctSamples(arma::uvec& distinctSamples, const size_t numSamples, const size_t begin, const size_t upperBound) { - arma::Col samples; + if (upperBound > numSamples) + { + arma::Col samples; - samples.zeros(upperBound); + samples.zeros(upperBound); - for (size_t i 
= 0; i < numSamples; i++) - samples [ (size_t) math::RandInt(upperBound) ]++; + for (size_t i = 0; i < numSamples; i++) + samples [ (size_t) math::RandInt(upperBound) ]++; - distinctSamples = arma::find(samples > 0); + distinctSamples = arma::find(samples > 0); - distinctSamples += begin; + distinctSamples += begin; + } + else + { + // The node contains fewer points than requested + distinctSamples.set_size(upperBound); + for (size_t i = 0; i < upperBound; i++) + distinctSamples[i] = begin + i; + } } template @@ -198,13 +240,14 @@ GetMedian(const BoundType& bound, const MatType& data, sorted[i].point = samples[i]; sorted[i].dist = bound.Metric().Evaluate(data.col(vantagePoint), data.col(samples[i])); - sorted[i].n = i; } + // Sort samples according to the distance to the vantage point std::sort(sorted.begin(), sorted.end(), StructComp); + // Get the median value mu = bound.Metric().Evaluate(data.col(vantagePoint), - data.col(sorted[sorted.size() / 2].n)); + data.col(sorted[sorted.size() / 2].point)); } template @@ -222,15 +265,19 @@ GetSecondMoment(const BoundType& bound, const MatType& data, moment += dist * dist; } + moment /= samples.size(); + return moment; } template +template bool VantagePointSplit:: -IsContainedInBall(const BoundType& bound, const MatType& mat, - const size_t vantagePoint, const size_t point, const ElemType mu) +AssignToLeftSubtree(const BoundType& bound, const MatType& mat, + const VecType& vantagePoint, const size_t point, const ElemType mu) { - if (bound.Metric().Evaluate(mat.col(vantagePoint), mat.col(point)) < mu) + // Return true if the point is close to the vantage point + if (bound.Metric().Evaluate(vantagePoint, mat.col(point)) < mu) return true; return false; From 430cbe34de83c12a83f30fe91b588459fd968f74 Mon Sep 17 00:00:00 2001 From: Mikhail Lozhnikov Date: Mon, 11 Jul 2016 23:33:48 +0300 Subject: [PATCH 03/12] Added a separate class for the vantage point tree. Added HollowBallBound. Added tests. 
--- src/mlpack/core/tree/CMakeLists.txt | 15 +- src/mlpack/core/tree/binary_space_tree.hpp | 1 - .../core/tree/binary_space_tree/typedef.hpp | 7 - src/mlpack/core/tree/bounds.hpp | 1 + src/mlpack/core/tree/hollow_ball_bound.hpp | 222 +++++ .../core/tree/hollow_ball_bound_impl.hpp | 388 ++++++++ src/mlpack/core/tree/vantage_point_tree.hpp | 21 + .../dual_tree_traverser.hpp | 102 ++ .../dual_tree_traverser_impl.hpp | 538 +++++++++++ .../single_tree_traverser.hpp | 63 ++ .../single_tree_traverser_impl.hpp | 113 +++ .../core/tree/vantage_point_tree/traits.hpp | 60 ++ .../core/tree/vantage_point_tree/typedef.hpp | 26 + .../vantage_point_split.hpp | 0 .../vantage_point_split_impl.hpp | 11 +- .../vantage_point_tree/vantage_point_tree.hpp | 218 +++++ .../vantage_point_tree_impl.hpp | 879 ++++++++++++++++++ src/mlpack/tests/CMakeLists.txt | 1 + src/mlpack/tests/tree_test.cpp | 46 - src/mlpack/tests/vantage_point_tree_test.cpp | 291 ++++++ 20 files changed, 2945 insertions(+), 58 deletions(-) create mode 100644 src/mlpack/core/tree/hollow_ball_bound.hpp create mode 100644 src/mlpack/core/tree/hollow_ball_bound_impl.hpp create mode 100644 src/mlpack/core/tree/vantage_point_tree.hpp create mode 100644 src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser.hpp create mode 100644 src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser_impl.hpp create mode 100644 src/mlpack/core/tree/vantage_point_tree/single_tree_traverser.hpp create mode 100644 src/mlpack/core/tree/vantage_point_tree/single_tree_traverser_impl.hpp create mode 100644 src/mlpack/core/tree/vantage_point_tree/traits.hpp create mode 100644 src/mlpack/core/tree/vantage_point_tree/typedef.hpp rename src/mlpack/core/tree/{binary_space_tree => vantage_point_tree}/vantage_point_split.hpp (100%) rename src/mlpack/core/tree/{binary_space_tree => vantage_point_tree}/vantage_point_split_impl.hpp (96%) create mode 100644 src/mlpack/core/tree/vantage_point_tree/vantage_point_tree.hpp create mode 100644 
src/mlpack/core/tree/vantage_point_tree/vantage_point_tree_impl.hpp create mode 100644 src/mlpack/tests/vantage_point_tree_test.cpp diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt index 62696e65b19..ea1a1f7f2d5 100644 --- a/src/mlpack/core/tree/CMakeLists.txt +++ b/src/mlpack/core/tree/CMakeLists.txt @@ -16,8 +16,6 @@ set(SOURCES binary_space_tree/midpoint_split_impl.hpp binary_space_tree/single_tree_traverser.hpp binary_space_tree/single_tree_traverser_impl.hpp - binary_space_tree/vantage_point_split.hpp - binary_space_tree/vantage_point_split_impl.hpp binary_space_tree/traits.hpp binary_space_tree/typedef.hpp bounds.hpp @@ -34,6 +32,8 @@ set(SOURCES cover_tree/traits.hpp cover_tree/typedef.hpp example_tree.hpp + hollow_ball_bound.hpp + hollow_ball_bound_impl.hpp hrectbound.hpp hrectbound_impl.hpp rectangle_tree.hpp @@ -56,6 +56,17 @@ set(SOURCES statistic.hpp traversal_info.hpp tree_traits.hpp + vantage_point_tree.hpp + vantage_point_tree/dual_tree_traverser.hpp + vantage_point_tree/dual_tree_traverser_impl.hpp + vantage_point_tree/single_tree_traverser.hpp + vantage_point_tree/single_tree_traverser_impl.hpp + vantage_point_tree/traits.hpp + vantage_point_tree/typedef.hpp + vantage_point_tree/vantage_point_split.hpp + vantage_point_tree/vantage_point_split_impl.hpp + vantage_point_tree/vantage_point_tree.hpp + vantage_point_tree/vantage_point_tree_impl.hpp ) # add directory name to sources diff --git a/src/mlpack/core/tree/binary_space_tree.hpp b/src/mlpack/core/tree/binary_space_tree.hpp index ba549817a24..e37b7b25670 100644 --- a/src/mlpack/core/tree/binary_space_tree.hpp +++ b/src/mlpack/core/tree/binary_space_tree.hpp @@ -11,7 +11,6 @@ #include "bounds.hpp" #include "binary_space_tree/midpoint_split.hpp" #include "binary_space_tree/mean_split.hpp" -#include "binary_space_tree/vantage_point_split.hpp" #include "binary_space_tree/binary_space_tree.hpp" #include "binary_space_tree/single_tree_traverser.hpp" #include 
"binary_space_tree/single_tree_traverser_impl.hpp" diff --git a/src/mlpack/core/tree/binary_space_tree/typedef.hpp b/src/mlpack/core/tree/binary_space_tree/typedef.hpp index c9d5b215f72..28145d11bb2 100644 --- a/src/mlpack/core/tree/binary_space_tree/typedef.hpp +++ b/src/mlpack/core/tree/binary_space_tree/typedef.hpp @@ -135,13 +135,6 @@ using MeanSplitBallTree = BinarySpaceTree; -template -using VantagePointTree = BinarySpaceTree; - } // namespace tree } // namespace mlpack diff --git a/src/mlpack/core/tree/bounds.hpp b/src/mlpack/core/tree/bounds.hpp index 79584c98259..bedaae4b1c0 100644 --- a/src/mlpack/core/tree/bounds.hpp +++ b/src/mlpack/core/tree/bounds.hpp @@ -13,5 +13,6 @@ #include "bound_traits.hpp" #include "hrectbound.hpp" #include "ballbound.hpp" +#include "hollow_ball_bound.hpp" #endif // MLPACK_CORE_TREE_BOUNDS_HPP diff --git a/src/mlpack/core/tree/hollow_ball_bound.hpp b/src/mlpack/core/tree/hollow_ball_bound.hpp new file mode 100644 index 00000000000..5770acbef28 --- /dev/null +++ b/src/mlpack/core/tree/hollow_ball_bound.hpp @@ -0,0 +1,222 @@ +/** + * @file hollow_ball_bound.hpp + * + * Bounds that are useful for binary space partitioning trees. + * Interface to a ball bound that works in arbitrary metric spaces. + */ +#ifndef MLPACK_CORE_TREE_HOLLOW_BALL_BOUND_HPP +#define MLPACK_CORE_TREE_HOLLOW_BALL_BOUND_HPP + +#include +#include +#include "bound_traits.hpp" + +namespace mlpack { +namespace bound { + +/** + * Ball bound encloses a set of points at a specific distance (radius) from a + * specific point (center). MetricType is the custom metric type that defaults + * to the Euclidean (L2) distance. + * + * @tparam MetricType metric type used in the distance measure. + * @tparam VecType Type of vector (arma::vec or arma::sp_vec or similar). + */ +template, + typename VecType = arma::vec> +class HollowBallBound +{ + public: + //! The underlying data type. + typedef typename VecType::elem_type ElemType; + //! A public version of the vector type. 
+ typedef VecType Vec; + + private: + //! The radius of the inner ball bound. + ElemType innerRadius; + //! The radius of the outer ball bound. + ElemType outerRadius; + //! The center of the ball bound. + VecType center; + //! The metric used in this bound. + MetricType* metric; + + /** + * To know whether this object allocated memory to the metric member + * variable. This will be true except in the copy constructor and the + * overloaded assignment operator. We need this to know whether we should + * delete the metric member variable in the destructor. + */ + bool ownsMetric; + + public: + + //! Empty Constructor. + HollowBallBound(); + + /** + * Create the ball bound with the specified dimensionality. + * + * @param dimension Dimensionality of ball bound. + */ + HollowBallBound(const size_t dimension); + + /** + * Create the ball bound with the specified radius and center. + * + * @param innerRradius Inner radius of ball bound. + * @param outerRradius Outer radius of ball bound. + * @param center Center of ball bound. + */ + HollowBallBound(const ElemType innerRadius, + const ElemType outerRadius, + const VecType& center); + + //! Copy constructor. To prevent memory leaks. + HollowBallBound(const HollowBallBound& other); + + //! For the same reason as the copy constructor: to prevent memory leaks. + HollowBallBound& operator=(const HollowBallBound& other); + + //! Move constructor: take possession of another bound. + HollowBallBound(HollowBallBound&& other); + + //! Destructor to release allocated memory. + ~HollowBallBound(); + + //! Get the outer radius of the ball. + ElemType OuterRadius() const { return outerRadius; } + //! Modify the outer radius of the ball. + ElemType& OuterRadius() { return outerRadius; } + + //! Get the innner radius of the ball. + ElemType InnerRadius() const { return innerRadius; } + //! Modify the inner radius of the ball. + ElemType& InnerRadius() { return innerRadius; } + + //! Get the center point of the ball. 
+ const VecType& Center() const { return center; } + //! Modify the center point of the ball. + VecType& Center() { return center; } + + //! Get the dimensionality of the ball. + size_t Dim() const { return center.n_elem; } + + /** + * Get the minimum width of the bound (this is same as the diameter). + * For ball bounds, width along all dimensions remain same. + */ + ElemType MinWidth() const { return outerRadius * 2.0; } + + //! Get the range in a certain dimension. + math::RangeType operator[](const size_t i) const; + + /** + * Determines if a point is within this bound. + */ + bool Contains(const VecType& point) const; + + /** + * Determines if another bound is within this bound. + */ + bool Contains(const HollowBallBound& other) const; + + /** + * Place the center of BallBound into the given vector. + * + * @param center Vector which the centroid will be written to. + */ + void Center(VecType& center) const { center = this->center; } + + /** + * Calculates minimum bound-to-point squared distance. + */ + template + ElemType MinDistance(const OtherVecType& point, + typename boost::enable_if>* = 0) + const; + + /** + * Calculates minimum bound-to-bound squared distance. + */ + ElemType MinDistance(const HollowBallBound& other) const; + + /** + * Computes maximum distance. + */ + template + ElemType MaxDistance(const OtherVecType& point, + typename boost::enable_if>* = 0) + const; + + /** + * Computes maximum distance. + */ + ElemType MaxDistance(const HollowBallBound& other) const; + + /** + * Calculates minimum and maximum bound-to-point distance. + */ + template + math::RangeType RangeDistance( + const OtherVecType& other, + typename boost::enable_if>* = 0) const; + + /** + * Calculates minimum and maximum bound-to-bound distance. + * + * Example: bound1.MinDistanceSq(other) for minimum distance. + */ + math::RangeType RangeDistance(const HollowBallBound& other) const; + + /** + * Expand the bound to include the given point. The centroid will not be + * moved. 
+ * + * @tparam MatType Type of matrix; could be arma::mat, arma::spmat, or a + * vector. + * @tparam data Data points to add. + */ + template + const HollowBallBound& operator|=(const MatType& data); + + /** + * Expand the bound to include the given bound. The centroid will not be + * moved. + * + * @tparam MatType Type of matrix; could be arma::mat, arma::spmat, or a + * vector. + * @tparam data Data points to add. + */ + const HollowBallBound& operator|=(const HollowBallBound& other); + + /** + * Returns the diameter of the ballbound. + */ + ElemType Diameter() const { return 2 * outerRadius; } + + //! Returns the distance metric used in this bound. + const MetricType& Metric() const { return *metric; } + //! Modify the distance metric used in this bound. + MetricType& Metric() { return *metric; } + + //! Serialize the bound. + template + void Serialize(Archive& ar, const unsigned int version); +}; + +//! A specialization of BoundTraits for this bound type. +template +struct BoundTraits> +{ + //! These bounds are potentially loose in some dimensions. + const static bool HasTightBounds = false; +}; + +} // namespace bound +} // namespace mlpack + +#include "hollow_ball_bound_impl.hpp" + +#endif // MLPACK_CORE_TREE_HOLLOW_BALL_BOUND_HPP diff --git a/src/mlpack/core/tree/hollow_ball_bound_impl.hpp b/src/mlpack/core/tree/hollow_ball_bound_impl.hpp new file mode 100644 index 00000000000..570d1806f50 --- /dev/null +++ b/src/mlpack/core/tree/hollow_ball_bound_impl.hpp @@ -0,0 +1,388 @@ +/** + * @file hollow_ball_bound_impl.hpp + * + * Bounds that are useful for binary space partitioning trees. + * Implementation of HollowBallBound ball bound metric policy class. + * + * @experimental + */ +#ifndef MLPACK_CORE_TREE_HOLLOW_BALL_BOUND_IMPL_HPP +#define MLPACK_CORE_TREE_HOLLOW_BALL_BOUND_IMPL_HPP + +// In case it hasn't been included already. +#include "hollow_ball_bound.hpp" + +#include + +namespace mlpack { +namespace bound { + +//! Empty Constructor. 
+template +HollowBallBound::HollowBallBound() : + innerRadius(std::numeric_limits::lowest()), + outerRadius(std::numeric_limits::lowest()), + metric(new MetricType()), + ownsMetric(true) +{ /* Nothing to do. */ } + +/** + * Create the hollow ball bound with the specified dimensionality. + * + * @param dimension Dimensionality of ball bound. + */ +template +HollowBallBound::HollowBallBound(const size_t dimension) : + innerRadius(std::numeric_limits::lowest()), + outerRadius(std::numeric_limits::lowest()), + center(dimension), + metric(new MetricType()), + ownsMetric(true) +{ /* Nothing to do. */ } + +/** + * Create the hollow ball bound with the specified radii and center. + * + * @param innerRadius Inner radius of hollow ball bound. + * @param outerRadius Outer radius of hollow ball bound. + * @param center Center of hollow ball bound. + */ +template +HollowBallBound:: +HollowBallBound(const ElemType innerRadius, + const ElemType outerRadius, + const VecType& center) : + innerRadius(innerRadius), + outerRadius(outerRadius), + center(center), + metric(new MetricType()), + ownsMetric(true) +{ /* Nothing to do. */ } + +//! Copy Constructor. To prevent memory leaks. +template +HollowBallBound::HollowBallBound( + const HollowBallBound& other) : + innerRadius(other.innerRadius), + outerRadius(other.outerRadius), + center(other.center), + metric(other.metric), + ownsMetric(false) +{ /* Nothing to do. */ } + +//! For the same reason as the copy constructor: to prevent memory leaks. +template +HollowBallBound& HollowBallBound:: +operator=(const HollowBallBound& other) +{ + innerRadius = other.innerRadius; + outerRadius = other.outerRadius; + center = other.center; + metric = other.metric; + ownsMetric = false; + + return *this; +} + +//! Move constructor. 
+template +HollowBallBound::HollowBallBound(HollowBallBound&& other) : + innerRadius(other.innerRadius), + outerRadius(other.outerRadius), + center(other.center), + metric(other.metric), + ownsMetric(other.ownsMetric) +{ + // Fix the other bound. + other.innerRadius = 0.0; + other.outerRadius = 0.0; + other.center = VecType(); + other.metric = NULL; + other.ownsMetric = false; +} + +//! Destructor to release allocated memory. +template +HollowBallBound::~HollowBallBound() +{ + if (ownsMetric) + delete metric; +} + +//! Get the range in a certain dimension. +template +math::RangeType::ElemType> +HollowBallBound::operator[](const size_t i) const +{ + if (outerRadius < 0) + return math::Range(); + else + return math::Range(center[i] - outerRadius, center[i] + outerRadius); +} + +/** + * Determines if a point is within the bound. + */ +template +bool HollowBallBound::Contains(const VecType& point) const +{ + if (outerRadius < 0) + return false; + else + { + const ElemType dist = metric->Evaluate(center, point); + return ((dist <= outerRadius) && (dist >= innerRadius)); + } +} + +/** + * Determines if another bound is within this bound. + */ +template +bool HollowBallBound::Contains( + const HollowBallBound& other) const +{ + if (outerRadius < 0) + return false; + else + { + const ElemType dist = metric->Evaluate(center, other.center); + + bool containOnOneSide = (dist - other.outerRadius >= innerRadius) && + (dist + other.outerRadius <= outerRadius); + bool containOnEverySide = (dist + innerRadius <= other.innerRadius) && + (dist + other.outerRadius <= outerRadius); + + bool containAsBall = (innerRadius == 0) && + (dist + other.outerRadius <= outerRadius); + + return (containOnOneSide || containOnEverySide || containAsBall); + } +} + + +/** + * Calculates minimum bound-to-point squared distance. 
+ */ +template +template +typename HollowBallBound::ElemType +HollowBallBound::MinDistance( + const OtherVecType& point, + typename boost::enable_if>* /* junk */) const +{ + if (outerRadius < 0) + return std::numeric_limits::max(); + else + { + const ElemType dist = metric->Evaluate(point, center); + + const ElemType outerDistance = math::ClampNonNegative(dist - outerRadius); + const ElemType innerDistance = math::ClampNonNegative(innerRadius - dist); + + return innerDistance + outerDistance; + } +} + +/** + * Calculates minimum bound-to-bound squared distance. + */ +template +typename HollowBallBound::ElemType +HollowBallBound::MinDistance(const HollowBallBound& other) + const +{ + if (outerRadius < 0 || other.outerRadius < 0) + return std::numeric_limits::max(); + else + { + const ElemType centerDistance = metric->Evaluate(center, other.center); + + const ElemType outerDistance = math::ClampNonNegative(centerDistance - + outerRadius - other.outerRadius); + const ElemType innerDistance1 = math::ClampNonNegative(other.innerRadius - + centerDistance - outerRadius); + const ElemType innerDistance2 = math::ClampNonNegative(innerRadius - + centerDistance - other.outerRadius); + + return outerDistance + innerDistance1 + innerDistance2; + } +} + +/** + * Computes maximum distance. + */ +template +template +typename HollowBallBound::ElemType +HollowBallBound::MaxDistance( + const OtherVecType& point, + typename boost::enable_if >* /* junk */) const +{ + if (outerRadius < 0) + return std::numeric_limits::max(); + else + return metric->Evaluate(point, center) + outerRadius; +} + +/** + * Computes maximum distance. + */ +template +typename HollowBallBound::ElemType +HollowBallBound::MaxDistance(const HollowBallBound& other) + const +{ + if (outerRadius < 0) + return std::numeric_limits::max(); + else + return metric->Evaluate(other.center, center) + outerRadius + + other.outerRadius; +} + +/** + * Calculates minimum and maximum bound-to-bound squared distance. 
+ * + * Example: bound1.MinDistanceSq(other) for minimum squared distance. + */ +template +template +math::RangeType::ElemType> +HollowBallBound::RangeDistance( + const OtherVecType& point, + typename boost::enable_if >* /* junk */) const +{ + if (outerRadius < 0) + return math::Range(std::numeric_limits::max(), + std::numeric_limits::max()); + else + { + const ElemType dist = metric->Evaluate(center, point); + return math::Range(math::ClampNonNegative(dist - outerRadius) + + math::ClampNonNegative(innerRadius - dist), + dist + outerRadius); + } +} + +template +math::RangeType::ElemType> +HollowBallBound::RangeDistance( + const HollowBallBound& other) const +{ + if (outerRadius < 0) + return math::Range(std::numeric_limits::max(), + std::numeric_limits::max()); + else + { + const ElemType dist = metric->Evaluate(center, other.center); + const ElemType sumradius = outerRadius + other.outerRadius; + return math::Range(MinDistance(other), dist + sumradius); + } +} + +/** + * Expand the bound to include the given point. Algorithm adapted from + * Jack Ritter, "An Efficient Bounding Sphere" in Graphics Gems (1990). + * The difference lies in the way we initialize the ball bound. The way we + * expand the bound is same. + */ +template +template +const HollowBallBound& +HollowBallBound::operator|=(const MatType& data) +{ + if (outerRadius < 0) + { + center = data.col(0); + outerRadius = 0; + innerRadius = 0; + + // Now iteratively add points. + for (size_t i = 0; i < data.n_cols; ++i) + { + const ElemType dist = metric->Evaluate(center, (VecType) data.col(i)); + + // See if the new point lies outside the bound. + if (dist > outerRadius) + { + // Move towards the new point and increase the radius just enough to + // accommodate the new point. + const VecType diff = data.col(i) - center; + center += ((dist - outerRadius) / (2 * dist)) * diff; + outerRadius = 0.5 * (dist + outerRadius); + } + } + } + else + { + // Now iteratively add points. 
+ for (size_t i = 0; i < data.n_cols; ++i) + { + const ElemType dist = metric->Evaluate(center, data.col(i)); + + // See if the new point lies outside the bound. + if (dist > outerRadius) + outerRadius = dist; + if (dist < innerRadius) + innerRadius = dist; + } + } + + return *this; +} + +/** + * Expand the bound to include the given bound. + */ +template +const HollowBallBound& +HollowBallBound::operator|=(const HollowBallBound& other) +{ + if (outerRadius < 0) + { + center = other.center; + outerRadius = other.outerRadius; + innerRadius = other.innerRadius; + return *this; + } + + const ElemType dist = metric->Evaluate(center, other.center); + + if (outerRadius < dist + other.outerRadius) + outerRadius = dist + other.outerRadius; + + const ElemType innerDist = math::ClampNonNegative(other.innerRadius - dist); + + if (innerRadius > innerDist) + innerRadius = innerDist; + + return *this; +} + + +//! Serialize the BallBound. +template +template +void HollowBallBound::Serialize( + Archive& ar, + const unsigned int /* version */) +{ + ar & data::CreateNVP(innerRadius, "innerRadius"); + ar & data::CreateNVP(outerRadius, "outerRadius"); + ar & data::CreateNVP(center, "center"); + + if (Archive::is_loading::value) + { + // If we're loading, delete the local metric since we'll have a new one. + if (ownsMetric) + delete metric; + } + + ar & data::CreateNVP(metric, "metric"); + ar & data::CreateNVP(ownsMetric, "ownsMetric"); +} + +} // namespace bound +} // namespace mlpack + +#endif // MLPACK_CORE_TREE_HOLLOW_BALL_BOUND_IMPL_HPP diff --git a/src/mlpack/core/tree/vantage_point_tree.hpp b/src/mlpack/core/tree/vantage_point_tree.hpp new file mode 100644 index 00000000000..afa97f6405c --- /dev/null +++ b/src/mlpack/core/tree/vantage_point_tree.hpp @@ -0,0 +1,21 @@ +/** + * @file vantage_point_tree.hpp + * + * Include all the necessary files to use the BinarySpaceTree class. 
+ */ +#ifndef MLPACK_CORE_TREE_VANTAGE_POINT_TREE_HPP +#define MLPACK_CORE_TREE_VANTAGE_POINT_TREE_HPP + +#include +#include "bounds.hpp" +#include "vantage_point_tree/single_tree_traverser.hpp" +#include "vantage_point_tree/single_tree_traverser_impl.hpp" +#include "vantage_point_tree/dual_tree_traverser.hpp" +#include "vantage_point_tree/dual_tree_traverser_impl.hpp" +#include "vantage_point_tree/vantage_point_split.hpp" +#include "vantage_point_tree/vantage_point_tree.hpp" +#include "vantage_point_tree/traits.hpp" +#include "vantage_point_tree/typedef.hpp" + +#endif // MLPACK_CORE_TREE_VANTAGE_POINT_TREE_HPP + diff --git a/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser.hpp new file mode 100644 index 00000000000..1440b337319 --- /dev/null +++ b/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser.hpp @@ -0,0 +1,102 @@ +/** + * @file dual_tree_traverser.hpp + * + * Defines the DualTreeTraverser for the VantagePointTree tree type. This is a + * nested class of VantagePointTree which traverses two trees in a depth-first + * manner with a given set of rules which indicate the branches which can be + * pruned and the order in which to recurse. + */ +#ifndef MLPACK_CORE_TREE_VANTAGE_POINT_TREE_DUAL_TREE_TRAVERSER_HPP +#define MLPACK_CORE_TREE_VANTAGE_POINT_TREE_DUAL_TREE_TRAVERSER_HPP + +#include + +#include "vantage_point_tree.hpp" + +namespace mlpack { +namespace tree { + +template class BoundType, + template + class SplitType> +template +class VantagePointTree::DualTreeTraverser +{ + public: + /** + * Instantiate the dual-tree traverser with the given rule set. + */ + DualTreeTraverser(RuleType& rule); + + /** + * Traverse the two trees. This does not reset the number of prunes. + * + * @param queryNode The query node to be traversed. + * @param referenceNode The reference node to be traversed. 
+ */ + void Traverse(VantagePointTree& queryNode, + VantagePointTree& referenceNode); + + //! Get the number of prunes. + size_t NumPrunes() const { return numPrunes; } + //! Modify the number of prunes. + size_t& NumPrunes() { return numPrunes; } + + //! Get the number of visited combinations. + size_t NumVisited() const { return numVisited; } + //! Modify the number of visited combinations. + size_t& NumVisited() { return numVisited; } + + //! Get the number of times a node combination was scored. + size_t NumScores() const { return numScores; } + //! Modify the number of times a node combination was scored. + size_t& NumScores() { return numScores; } + + //! Get the number of times a base case was calculated. + size_t NumBaseCases() const { return numBaseCases; } + //! Modify the number of times a base case was calculated. + size_t& NumBaseCases() { return numBaseCases; } + + private: + //! Reference to the rules with which the trees will be traversed. + RuleType& rule; + + //! The number of prunes. + size_t numPrunes; + + //! The number of node combinations that have been visited during traversal. + size_t numVisited; + + //! The number of times a node combination was scored. + size_t numScores; + + //! The number of times a base case was calculated. + size_t numBaseCases; + + //! Traversal information, held in the class so that it isn't continually + //! being reallocated. + typename RuleType::TraversalInfoType traversalInfo; + + /** + * Traverse the reference tree. + * + * @param queryIndex The index of the point in the query set which is being + * used as the query point. + * @param referenceNode The reference node to be traversed. + */ + void Traverse(const size_t queryIndex, VantagePointTree& referenceNode); +}; + +} // namespace tree +} // namespace mlpack + +// Include implementation. 
+#include "dual_tree_traverser_impl.hpp" + +#endif // MLPACK_CORE_TREE_VANTAGE_POINT_TREE_DUAL_TREE_TRAVERSER_HPP + + diff --git a/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser_impl.hpp new file mode 100644 index 00000000000..7b729aeaee9 --- /dev/null +++ b/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser_impl.hpp @@ -0,0 +1,538 @@ +/** + * @file dual_tree_traverser_impl.hpp + * + * Implementation of the DualTreeTraverser for VantagePointTree. This is a way + * to perform a dual-tree traversal of two trees. The trees must be the same + * type. + */ +#ifndef MLPACK_CORE_TREE_VANTAGE_POINT_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP +#define MLPACK_CORE_TREE_VANTAGE_POINT_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP + +// In case it hasn't been included yet. +#include "dual_tree_traverser.hpp" + +namespace mlpack { +namespace tree { + +template class BoundType, + template + class SplitType> +template +VantagePointTree:: +DualTreeTraverser::DualTreeTraverser(RuleType& rule) : + rule(rule), + numPrunes(0), + numVisited(0), + numScores(0), + numBaseCases(0) +{ /* Nothing to do. */ } + +template class BoundType, + template + class SplitType> +template +void VantagePointTree:: +DualTreeTraverser::Traverse( + VantagePointTree& + queryNode, + VantagePointTree& + referenceNode) +{ + // Increment the visit counter. + ++numVisited; + + // Store the current traversal info. + traversalInfo = rule.TraversalInfo(); + + // If both are leaves, we must evaluate the base case. + if (queryNode.IsLeaf() && referenceNode.IsLeaf()) + { + // Loop through each of the points in each node. + const size_t queryEnd = queryNode.Begin() + queryNode.Count(); + const size_t refEnd = referenceNode.Begin() + referenceNode.Count(); + for (size_t query = queryNode.Begin(); query < queryEnd; ++query) + { + // See if we need to investigate this point (this function should be + // implemented for the single-tree recursion too). 
Restore the traversal + // information first. + rule.TraversalInfo() = traversalInfo; + const double childScore = rule.Score(query, referenceNode); + + if (childScore == DBL_MAX) + continue; // We can't improve this particular point. + + for (size_t ref = referenceNode.Begin(); ref < refEnd; ++ref) + rule.BaseCase(query, ref); + + numBaseCases += referenceNode.Count(); + } + } + else if (((!queryNode.IsLeaf()) && referenceNode.IsLeaf()) || + (queryNode.NumDescendants() > 3 * referenceNode.NumDescendants() && + !queryNode.IsLeaf() && !referenceNode.IsLeaf())) + { + // We have to recurse down the query node. In this case the recursion order + // does not matter. + const double pointScore = rule.Score(queryNode.Point(0), referenceNode); + ++numScores; + + if (pointScore != DBL_MAX) + Traverse(queryNode.Point(0), referenceNode); + else + ++numPrunes; + + // Before recursing, we have to set the traversal information correctly. + rule.TraversalInfo() = traversalInfo; + const double leftScore = rule.Score(*queryNode.Left(), referenceNode); + ++numScores; + + if (leftScore != DBL_MAX) + Traverse(*queryNode.Left(), referenceNode); + else + ++numPrunes; + + // Before recursing, we have to set the traversal information correctly. + rule.TraversalInfo() = traversalInfo; + const double rightScore = rule.Score(*queryNode.Right(), referenceNode); + ++numScores; + + if (rightScore != DBL_MAX) + Traverse(*queryNode.Right(), referenceNode); + else + ++numPrunes; + } + else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf())) + { + const size_t queryEnd = queryNode.Begin() + queryNode.Count(); + for (size_t query = queryNode.Begin(); query < queryEnd; ++query) + rule.BaseCase(query, referenceNode.Point(0)); + numBaseCases += queryNode.Count(); + // We have to recurse down the reference node. In this case the recursion + // order does matter. Before recursing, though, we have to set the + // traversal information correctly. 
+ double leftScore = rule.Score(queryNode, *referenceNode.Left()); + typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo(); + rule.TraversalInfo() = traversalInfo; + double rightScore = rule.Score(queryNode, *referenceNode.Right()); + numScores += 2; + + if (leftScore < rightScore) + { + // Recurse to the left. Restore the left traversal info. Store the right + // traversal info. + traversalInfo = rule.TraversalInfo(); + rule.TraversalInfo() = leftInfo; + Traverse(queryNode, *referenceNode.Left()); + + // Is it still valid to recurse to the right? + rightScore = rule.Rescore(queryNode, *referenceNode.Right(), rightScore); + + if (rightScore != DBL_MAX) + { + // Restore the right traversal info. + rule.TraversalInfo() = traversalInfo; + Traverse(queryNode, *referenceNode.Right()); + } + else + ++numPrunes; + } + else if (rightScore < leftScore) + { + // Recurse to the right. + Traverse(queryNode, *referenceNode.Right()); + + // Is it still valid to recurse to the left? + leftScore = rule.Rescore(queryNode, *referenceNode.Left(), leftScore); + + if (leftScore != DBL_MAX) + { + // Restore the left traversal info. + rule.TraversalInfo() = leftInfo; + Traverse(queryNode, *referenceNode.Left()); + } + else + ++numPrunes; + } + else // leftScore is equal to rightScore. + { + if (leftScore == DBL_MAX) + { + numPrunes += 2; + } + else + { + // Choose the left first. Restore the left traversal info. Store the + // right traversal info. + traversalInfo = rule.TraversalInfo(); + rule.TraversalInfo() = leftInfo; + Traverse(queryNode, *referenceNode.Left()); + + rightScore = rule.Rescore(queryNode, *referenceNode.Right(), + rightScore); + + if (rightScore != DBL_MAX) + { + // Restore the right traversal info. 
+ rule.TraversalInfo() = traversalInfo; + Traverse(queryNode, *referenceNode.Right()); + } + else + ++numPrunes; + } + } + } + else + { + for (size_t i = 0; i < queryNode.NumDescendants(); ++i) + rule.BaseCase(queryNode.Descendant(i), referenceNode.Point(0)); + numBaseCases += queryNode.NumDescendants(); + // We have to recurse down both query and reference nodes. Because the + // query descent order does not matter, we will go to the left query child + // first. Before recursing, we have to set the traversal information + // correctly. + double leftScore = rule.Score(queryNode.Point(0), *referenceNode.Left()); + typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo(); + rule.TraversalInfo() = traversalInfo; + double rightScore = rule.Score(queryNode.Point(0), *referenceNode.Right()); + typename RuleType::TraversalInfoType rightInfo; + numScores += 2; + + if (leftScore < rightScore) + { + // Recurse to the left. Restore the left traversal info. Store the right + // traversal info. + rightInfo = rule.TraversalInfo(); + rule.TraversalInfo() = leftInfo; + Traverse(queryNode.Point(0), *referenceNode.Left()); + + // Is it still valid to recurse to the right? + rightScore = rule.Rescore(queryNode.Point(0), *referenceNode.Right(), + rightScore); + + if (rightScore != DBL_MAX) + { + // Restore the right traversal info. + rule.TraversalInfo() = rightInfo; + Traverse(queryNode.Point(0), *referenceNode.Right()); + } + else + ++numPrunes; + } + else if (rightScore < leftScore) + { + // Recurse to the right. + Traverse(queryNode.Point(0), *referenceNode.Right()); + + // Is it still valid to recurse to the left? + leftScore = rule.Rescore(queryNode.Point(0), *referenceNode.Left(), + leftScore); + + if (leftScore != DBL_MAX) + { + // Restore the left traversal info. 
+ rule.TraversalInfo() = leftInfo; + Traverse(queryNode.Point(0), *referenceNode.Left()); + } + else + ++numPrunes; + } + else + { + if (leftScore == DBL_MAX) + { + numPrunes += 2; + } + else + { + // Choose the left first. Restore the left traversal info and store the + // right traversal info. + rightInfo = rule.TraversalInfo(); + rule.TraversalInfo() = leftInfo; + Traverse(queryNode.Point(0), *referenceNode.Left()); + + // Is it still valid to recurse to the right? + rightScore = rule.Rescore(queryNode.Point(0), *referenceNode.Right(), + rightScore); + + if (rightScore != DBL_MAX) + { + // Restore the right traversal information. + rule.TraversalInfo() = rightInfo; + Traverse(queryNode.Point(0), *referenceNode.Right()); + } + else + ++numPrunes; + } + } + + // Restore the main traversal information. + rule.TraversalInfo() = traversalInfo; + + // Now recurse down the left node. + leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left()); + leftInfo = rule.TraversalInfo(); + rule.TraversalInfo() = traversalInfo; + rightScore = rule.Score(*queryNode.Left(), *referenceNode.Right()); + numScores += 2; + + if (leftScore < rightScore) + { + // Recurse to the left. Restore the left traversal info. Store the right + // traversal info. + rightInfo = rule.TraversalInfo(); + rule.TraversalInfo() = leftInfo; + Traverse(*queryNode.Left(), *referenceNode.Left()); + + // Is it still valid to recurse to the right? + rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(), + rightScore); + + if (rightScore != DBL_MAX) + { + // Restore the right traversal info. + rule.TraversalInfo() = rightInfo; + Traverse(*queryNode.Left(), *referenceNode.Right()); + } + else + ++numPrunes; + } + else if (rightScore < leftScore) + { + // Recurse to the right. + Traverse(*queryNode.Left(), *referenceNode.Right()); + + // Is it still valid to recurse to the left? 
+ leftScore = rule.Rescore(*queryNode.Left(), *referenceNode.Left(), + leftScore); + + if (leftScore != DBL_MAX) + { + // Restore the left traversal info. + rule.TraversalInfo() = leftInfo; + Traverse(*queryNode.Left(), *referenceNode.Left()); + } + else + ++numPrunes; + } + else + { + if (leftScore == DBL_MAX) + { + numPrunes += 2; + } + else + { + // Choose the left first. Restore the left traversal info and store the + // right traversal info. + rightInfo = rule.TraversalInfo(); + rule.TraversalInfo() = leftInfo; + Traverse(*queryNode.Left(), *referenceNode.Left()); + + // Is it still valid to recurse to the right? + rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(), + rightScore); + + if (rightScore != DBL_MAX) + { + // Restore the right traversal information. + rule.TraversalInfo() = rightInfo; + Traverse(*queryNode.Left(), *referenceNode.Right()); + } + else + ++numPrunes; + } + } + + // Restore the main traversal information. + rule.TraversalInfo() = traversalInfo; + + // Now recurse down the right query node. + leftScore = rule.Score(*queryNode.Right(), *referenceNode.Left()); + leftInfo = rule.TraversalInfo(); + rule.TraversalInfo() = traversalInfo; + rightScore = rule.Score(*queryNode.Right(), *referenceNode.Right()); + numScores += 2; + + if (leftScore < rightScore) + { + // Recurse to the left. Restore the left traversal info. Store the right + // traversal info. + rightInfo = rule.TraversalInfo(); + rule.TraversalInfo() = leftInfo; + Traverse(*queryNode.Right(), *referenceNode.Left()); + + // Is it still valid to recurse to the right? + rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(), + rightScore); + + if (rightScore != DBL_MAX) + { + // Restore the right traversal info. + rule.TraversalInfo() = rightInfo; + Traverse(*queryNode.Right(), *referenceNode.Right()); + } + else + ++numPrunes; + } + else if (rightScore < leftScore) + { + // Recurse to the right. 
+ Traverse(*queryNode.Right(), *referenceNode.Right()); + + // Is it still valid to recurse to the left? + leftScore = rule.Rescore(*queryNode.Right(), *referenceNode.Left(), + leftScore); + + if (leftScore != DBL_MAX) + { + // Restore the left traversal info. + rule.TraversalInfo() = leftInfo; + Traverse(*queryNode.Right(), *referenceNode.Left()); + } + else + ++numPrunes; + } + else + { + if (leftScore == DBL_MAX) + { + numPrunes += 2; + } + else + { + // Choose the left first. Restore the left traversal info. Store the + // right traversal info. + rightInfo = rule.TraversalInfo(); + rule.TraversalInfo() = leftInfo; + Traverse(*queryNode.Right(), *referenceNode.Left()); + + // Is it still valid to recurse to the right? + rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(), + rightScore); + + if (rightScore != DBL_MAX) + { + // Restore the right traversal info. + rule.TraversalInfo() = rightInfo; + Traverse(*queryNode.Right(), *referenceNode.Right()); + } + else + ++numPrunes; + } + } + } +} + +template class BoundType, + template + class SplitType> +template +void VantagePointTree:: +DualTreeTraverser::Traverse( + const size_t queryIndex, + VantagePointTree& + referenceNode) +{ + // If we are a leaf, run the base case as necessary. + if (referenceNode.IsLeaf()) + { + const size_t refEnd = referenceNode.Begin() + referenceNode.Count(); + for (size_t i = referenceNode.Begin(); i < refEnd; ++i) + rule.BaseCase(queryIndex, i); + numBaseCases += referenceNode.Count(); + return; + } + + rule.BaseCase(queryIndex, referenceNode.Point(0)); + numBaseCases++; + + // Store the current traversal info. + traversalInfo = rule.TraversalInfo(); + + // If either score is DBL_MAX, we do not recurse into that node. 
+ double leftScore = rule.Score(queryIndex, *referenceNode.Left()); + typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo(); + rule.TraversalInfo() = traversalInfo; + double rightScore = rule.Score(queryIndex, *referenceNode.Right()); + typename RuleType::TraversalInfoType rightInfo; + + if (leftScore < rightScore) + { + rightInfo = rule.TraversalInfo(); + rule.TraversalInfo() = leftInfo; + // Recurse to the left. + Traverse(queryIndex, *referenceNode.Left()); + + // Is it still valid to recurse to the right? + rightScore = rule.Rescore(queryIndex, *referenceNode.Right(), rightScore); + + if (rightScore != DBL_MAX) + { + // Restore the right traversal info. + rule.TraversalInfo() = rightInfo; + Traverse(queryIndex, *referenceNode.Right()); // Recurse to the right. + } + else + ++numPrunes; + } + else if (rightScore < leftScore) + { + // Recurse to the right. + Traverse(queryIndex, *referenceNode.Right()); + + // Is it still valid to recurse to the left? + leftScore = rule.Rescore(queryIndex, *referenceNode.Left(), leftScore); + + if (leftScore != DBL_MAX) + { + // Restore the left traversal info. + rule.TraversalInfo() = leftInfo; + Traverse(queryIndex, *referenceNode.Left()); // Recurse to the left. + } + else + ++numPrunes; + } + else // leftScore is equal to rightScore. + { + if (leftScore == DBL_MAX) + { + numPrunes += 2; // Pruned both left and right. + } + else + { + // Choose the left first. + rightInfo = rule.TraversalInfo(); + rule.TraversalInfo() = leftInfo; + Traverse(queryIndex, *referenceNode.Left()); + + // Is it still valid to recurse to the right? + rightScore = rule.Rescore(queryIndex, *referenceNode.Right(), + rightScore); + + if (rightScore != DBL_MAX) + { + // Restore the right traversal info. 
+ rule.TraversalInfo() = rightInfo; + Traverse(queryIndex, *referenceNode.Right()); + } + else + ++numPrunes; + } + } +} + + +} // namespace tree +} // namespace mlpack + +#endif // MLPACK_CORE_TREE_VANTAGE_POINT_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP + diff --git a/src/mlpack/core/tree/vantage_point_tree/single_tree_traverser.hpp b/src/mlpack/core/tree/vantage_point_tree/single_tree_traverser.hpp new file mode 100644 index 00000000000..2c9f4ff7034 --- /dev/null +++ b/src/mlpack/core/tree/vantage_point_tree/single_tree_traverser.hpp @@ -0,0 +1,63 @@ +/** + * @file single_tree_traverser.hpp + * + * A nested class of VantagePointTree which traverses the entire tree with a + * given set of rules which indicate the branches which can be pruned and the + * order in which to recurse. This traverser is a depth-first traverser. + */ +#ifndef MLPACK_CORE_TREE_VANTAGE_POINT_TREE_SINGLE_TREE_TRAVERSER_HPP +#define MLPACK_CORE_TREE_VANTAGE_POINT_TREE_SINGLE_TREE_TRAVERSER_HPP + +#include + +#include "vantage_point_tree.hpp" + +namespace mlpack { +namespace tree { + +template class BoundType, + template + class SplitType> +template +class VantagePointTree::SingleTreeTraverser +{ + public: + /** + * Instantiate the single tree traverser with the given rule set. + */ + SingleTreeTraverser(RuleType& rule); + + /** + * Traverse the tree with the given point. + * + * @param queryIndex The index of the point in the query set which is being + * used as the query point. + * @param referenceNode The tree node to be traversed. + */ + void Traverse(const size_t queryIndex, VantagePointTree& referenceNode); + + //! Get the number of prunes. + size_t NumPrunes() const { return numPrunes; } + //! Modify the number of prunes. + size_t& NumPrunes() { return numPrunes; } + + private: + //! Reference to the rules with which the tree will be traversed. + RuleType& rule; + + //! The number of nodes which have been pruned during traversal. 
+ size_t numPrunes; +}; + +} // namespace tree +} // namespace mlpack + +// Include implementation. +#include "single_tree_traverser_impl.hpp" + +#endif // MLPACK_CORE_TREE_VANTAGE_POINT_TREE_SINGLE_TREE_TRAVERSER_HPP + diff --git a/src/mlpack/core/tree/vantage_point_tree/single_tree_traverser_impl.hpp b/src/mlpack/core/tree/vantage_point_tree/single_tree_traverser_impl.hpp new file mode 100644 index 00000000000..38f126a0e8d --- /dev/null +++ b/src/mlpack/core/tree/vantage_point_tree/single_tree_traverser_impl.hpp @@ -0,0 +1,113 @@ +/** + * @file single_tree_traverser_impl.hpp + * + * A nested class of VantagePointTree which traverses the entire tree with a + * given set of rules which indicate the branches which can be pruned and the + * order in which to recurse. This traverser is a depth-first traverser. + */ +#ifndef MLPACK_CORE_TREE_VANTAGE_POINT_TREE_SINGLE_TREE_TRAVERSER_IMPL_HPP +#define MLPACK_CORE_TREE_VANTAGE_POINT_TREE_SINGLE_TREE_TRAVERSER_IMPL_HPP + +// In case it hasn't been included yet. +#include "single_tree_traverser.hpp" + +#include + +namespace mlpack { +namespace tree { + +template class BoundType, + template + class SplitType> +template +VantagePointTree:: +SingleTreeTraverser::SingleTreeTraverser(RuleType& rule) : + rule(rule), + numPrunes(0) +{ /* Nothing to do. */ } + +template class BoundType, + template + class SplitType> +template +void VantagePointTree:: +SingleTreeTraverser::Traverse( + const size_t queryIndex, + VantagePointTree& + referenceNode) +{ + // If we are a leaf, run the base case as necessary. + if (referenceNode.IsLeaf()) + { + const size_t refEnd = referenceNode.Begin() + referenceNode.Count(); + for (size_t i = referenceNode.Begin(); i < refEnd; ++i) + rule.BaseCase(queryIndex, i); + return; + } + + rule.BaseCase(queryIndex, referenceNode.Point(0)); + + // If either score is DBL_MAX, we do not recurse into that node. 
+ double leftScore = rule.Score(queryIndex, *referenceNode.Left()); + double rightScore = rule.Score(queryIndex, *referenceNode.Right()); + + if (leftScore < rightScore) + { + // Recurse to the left. + Traverse(queryIndex, *referenceNode.Left()); + + // Is it still valid to recurse to the right? + rightScore = rule.Rescore(queryIndex, *referenceNode.Right(), rightScore); + + if (rightScore != DBL_MAX) + Traverse(queryIndex, *referenceNode.Right()); // Recurse to the right. + else + ++numPrunes; + } + else if (rightScore < leftScore) + { + // Recurse to the right. + Traverse(queryIndex, *referenceNode.Right()); + + // Is it still valid to recurse to the left? + leftScore = rule.Rescore(queryIndex, *referenceNode.Left(), leftScore); + + if (leftScore != DBL_MAX) + Traverse(queryIndex, *referenceNode.Left()); // Recurse to the left. + else + ++numPrunes; + } + else // leftScore is equal to rightScore. + { + if (leftScore == DBL_MAX) + { + numPrunes += 2; // Pruned both left and right. + } + else + { + // Choose the left first. + Traverse(queryIndex, *referenceNode.Left()); + + // Is it still valid to recurse to the right? + rightScore = rule.Rescore(queryIndex, *referenceNode.Right(), + rightScore); + + if (rightScore != DBL_MAX) + Traverse(queryIndex, *referenceNode.Right()); + else + ++numPrunes; + } + } +} + +} // namespace tree +} // namespace mlpack + +#endif // MLPACK_CORE_TREE_VANTAGE_POINT_TREE_SINGLE_TREE_TRAVERSER_IMPL_HPP + diff --git a/src/mlpack/core/tree/vantage_point_tree/traits.hpp b/src/mlpack/core/tree/vantage_point_tree/traits.hpp new file mode 100644 index 00000000000..5901e715242 --- /dev/null +++ b/src/mlpack/core/tree/vantage_point_tree/traits.hpp @@ -0,0 +1,60 @@ +/** + * @file traits.hpp + * + * Specialization of the TreeTraits class for the VantagePointTree type of tree. 
+ */ +#ifndef MLPACK_CORE_TREE_VANTAGE_POINT_TREE_TRAITS_HPP +#define MLPACK_CORE_TREE_VANTAGE_POINT_TREE_TRAITS_HPP + +#include + +namespace mlpack { +namespace tree { + +/** + * This is a specialization of the TreeTraits class for the VantagePointTree + * tree type. It defines characteristics of the vantage point tree, and is used + * to help write tree-independent (but still optimized) tree-based algorithms. + * See mlpack/core/tree/tree_traits.hpp for more information. + */ +template class BoundType, + template + class SplitType> +class TreeTraits> +{ + public: + /** + * Children nodes may overlap each other. + */ + static const bool HasOverlappingChildren = true; + + /** + * TODO: The first point of each node is centroid. + */ + static const bool FirstPointIsCentroid = false; + + /** + * Points are not contained at multiple levels of the vantage point tree. + */ + static const bool HasSelfChildren = false; + + /** + * Points are rearranged during building of the tree. + */ + static const bool RearrangesDataset = true; + + /** + * This is always a binary tree. + */ + static const bool BinaryTree = true; +}; + +} // namespace tree +} // namespace mlpack + +#endif // MLPACK_CORE_TREE_VANTAGE_POINT_TREE_TRAITS_HPP + diff --git a/src/mlpack/core/tree/vantage_point_tree/typedef.hpp b/src/mlpack/core/tree/vantage_point_tree/typedef.hpp new file mode 100644 index 00000000000..df0454387d1 --- /dev/null +++ b/src/mlpack/core/tree/vantage_point_tree/typedef.hpp @@ -0,0 +1,26 @@ +/** + * @file typedef.hpp + * + * Template typedefs for the VantagePointTree class that satisfy the + * requirements of the TreeType policy class. + */ +#ifndef MLPACK_CORE_TREE_VANTAGE_POINT_TREE_TYPEDEF_HPP +#define MLPACK_CORE_TREE_VANTAGE_POINT_TREE_TYPEDEF_HPP + +// In case it hasn't been included yet.
+#include "../vantage_point_tree.hpp" + +namespace mlpack { +namespace tree { + +template +using VPTree = VantagePointTree; + +} // namespace tree +} // namespace mlpack + +#endif // MLPACK_CORE_TREE_VANTAGE_POINT_TREE_TYPEDEF_HPP diff --git a/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp b/src/mlpack/core/tree/vantage_point_tree/vantage_point_split.hpp similarity index 100% rename from src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp rename to src/mlpack/core/tree/vantage_point_tree/vantage_point_split.hpp diff --git a/src/mlpack/core/tree/binary_space_tree/vantage_point_split_impl.hpp b/src/mlpack/core/tree/vantage_point_tree/vantage_point_split_impl.hpp similarity index 96% rename from src/mlpack/core/tree/binary_space_tree/vantage_point_split_impl.hpp rename to src/mlpack/core/tree/vantage_point_tree/vantage_point_split_impl.hpp index e762eb9affa..ce0d082129a 100644 --- a/src/mlpack/core/tree/binary_space_tree/vantage_point_split_impl.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/vantage_point_split_impl.hpp @@ -29,7 +29,9 @@ SplitNode(const BoundType& bound, MatType& data, const size_t begin, if (mu == 0) return false; - arma::Col vantagePoint = data.col(vantagePointIndex); + data.swap_cols(begin, vantagePointIndex); + + arma::Col vantagePoint = data.col(begin); splitCol = PerformSplit(bound, data, begin, count, vantagePoint, mu); assert(splitCol > begin); @@ -52,7 +54,12 @@ SplitNode(const BoundType& bound, MatType& data, const size_t begin, if (mu == 0) return false; - arma::Col vantagePoint = data.col(vantagePointIndex); + data.swap_cols(begin, vantagePointIndex); + size_t t = oldFromNew[begin]; + oldFromNew[begin] = oldFromNew[vantagePointIndex]; + oldFromNew[vantagePointIndex] = t; + + arma::Col vantagePoint = data.col(begin); splitCol = PerformSplit(bound, data, begin, count, vantagePoint, mu, oldFromNew); diff --git a/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree.hpp 
b/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree.hpp new file mode 100644 index 00000000000..c6ddbde1c82 --- /dev/null +++ b/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree.hpp @@ -0,0 +1,218 @@ +/** + * @file vantage_point_tree.hpp + */ + +#ifndef MLPACK_CORE_TREE_VANTAGE_POINT_TREE_VANTAGE_POINT_TREE_HPP +#define MLPACK_CORE_TREE_VANTAGE_POINT_TREE_VANTAGE_POINT_TREE_HPP + +namespace mlpack { +namespace tree /** Trees and tree-building procedures. */ { + +template class BoundType = + bound::HRectBound, + template + class SplitType = MidpointSplit> +class VantagePointTree +{ + public: + typedef MatType Mat; + typedef typename MatType::elem_type ElemType; + + private: + VantagePointTree* left; + VantagePointTree* right; + VantagePointTree* parent; + size_t begin; + size_t count; + BoundType bound; + StatisticType stat; + ElemType parentDistance; + ElemType furthestDescendantDistance; + ElemType minimumBoundDistance; + MatType* dataset; + + public: + template + class SingleTreeTraverser; + + template + class DualTreeTraverser; + + VantagePointTree(const MatType& data, const size_t maxLeafSize = 20); + + VantagePointTree(const MatType& data, + std::vector& oldFromNew, + const size_t maxLeafSize = 20); + + VantagePointTree(const MatType& data, + std::vector& oldFromNew, + std::vector& newFromOld, + const size_t maxLeafSize = 20); + + VantagePointTree(MatType&& data, + const size_t maxLeafSize = 20); + + VantagePointTree(MatType&& data, + std::vector& oldFromNew, + const size_t maxLeafSize = 20); + + VantagePointTree(MatType&& data, + std::vector& oldFromNew, + std::vector& newFromOld, + const size_t maxLeafSize = 20); + + VantagePointTree(VantagePointTree* parent, + const size_t begin, + const size_t count, + SplitType, MatType>& splitter, + const size_t maxLeafSize = 20); + + VantagePointTree(VantagePointTree* parent, + const size_t begin, + const size_t count, + std::vector& oldFromNew, + SplitType, MatType>& splitter, + const size_t 
maxLeafSize = 20); + + VantagePointTree(VantagePointTree* parent, + const size_t begin, + const size_t count, + std::vector& oldFromNew, + std::vector& newFromOld, + SplitType, MatType>& splitter, + const size_t maxLeafSize = 20); + + VantagePointTree(const VantagePointTree& other); + + VantagePointTree(VantagePointTree&& other); + + template + VantagePointTree( + Archive& ar, + const typename boost::enable_if::type* = 0); + + ~VantagePointTree(); + + const BoundType& Bound() const { return bound; } + BoundType& Bound() { return bound; } + + const StatisticType& Stat() const { return stat; } + StatisticType& Stat() { return stat; } + + bool IsLeaf() const; + + VantagePointTree* Left() const { return left; } + VantagePointTree*& Left() { return left; } + + VantagePointTree* Right() const { return right; } + VantagePointTree*& Right() { return right; } + + VantagePointTree* Parent() const { return parent; } + VantagePointTree*& Parent() { return parent; } + + const MatType& Dataset() const { return *dataset; } + MatType& Dataset() { return *dataset; } + + MetricType Metric() const { return MetricType(); } + + size_t NumChildren() const; + + ElemType FurthestPointDistance() const; + + ElemType FurthestDescendantDistance() const; + + ElemType MinimumBoundDistance() const; + + ElemType ParentDistance() const { return parentDistance; } + ElemType& ParentDistance() { return parentDistance; } + + VantagePointTree& Child(const size_t child) const; + + VantagePointTree*& ChildPtr(const size_t child) + { return (child == 0) ? 
left : right; } + + size_t NumPoints() const; + + size_t NumDescendants() const; + + size_t Descendant(const size_t index) const; + + size_t Point(const size_t index) const; + + ElemType MinDistance(const VantagePointTree* other) const + { + return bound.MinDistance(other->Bound()); + } + + ElemType MaxDistance(const VantagePointTree* other) const + { + return bound.MaxDistance(other->Bound()); + } + + math::RangeType RangeDistance(const VantagePointTree* other) const + { + return bound.RangeDistance(other->Bound()); + } + + template + ElemType MinDistance(const VecType& point, + typename boost::enable_if >::type* = 0) + const + { + return bound.MinDistance(point); + } + + template + ElemType MaxDistance(const VecType& point, + typename boost::enable_if >::type* = 0) + const + { + return bound.MaxDistance(point); + } + + template + math::RangeType + RangeDistance(const VecType& point, + typename boost::enable_if >::type* = 0) const + { + return bound.RangeDistance(point); + } + + size_t Begin() const { return begin; } + size_t& Begin() { return begin; } + + size_t Count() const { return count; } + size_t& Count() { return count; } + + static bool HasSelfChildren() { return false; } + + void Center(arma::vec& center) { bound.Center(center); } + + private: + void SplitNode(const size_t maxLeafSize, + SplitType, MatType>& splitter); + + void SplitNode(std::vector& oldFromNew, + const size_t maxLeafSize, + SplitType, MatType>& splitter); + + protected: + VantagePointTree(); + + friend class boost::serialization::access; + + public: + template + void Serialize(Archive& ar, const unsigned int version); +}; + +} // namespace tree +} // namespace mlpack + +// Include implementation. 
+#include "vantage_point_tree_impl.hpp" + +#endif // MLPACK_CORE_TREE_VANTAGE_POINT_TREE_VANTAGE_POINT_TREE_HPP diff --git a/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree_impl.hpp b/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree_impl.hpp new file mode 100644 index 00000000000..8cc320ae1d5 --- /dev/null +++ b/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree_impl.hpp @@ -0,0 +1,879 @@ +/** + * @file vantage_point_tree_impl.hpp + */ +#ifndef MLPACK_CORE_TREE_VANTAGE_POINT_TREE_VANTAGE_POINT_TREE_IMPL_HPP +#define MLPACK_CORE_TREE_VANTAGE_POINT_TREE_VANTAGE_POINT_TREE_IMPL_HPP + +// In case it wasn't included already for some reason. +#include "vantage_point_tree.hpp" +#include + +namespace mlpack { +namespace tree { + +// Each of these overloads is kept as a separate function to keep the overhead +// from the two std::vectors out, if possible. +template class BoundType, + template + class SplitType> +VantagePointTree:: +VantagePointTree( + const MatType& data, + const size_t maxLeafSize) : + left(NULL), + right(NULL), + parent(NULL), + begin(0), /* This root node starts at index 0, */ + count(data.n_cols), /* and spans all of the dataset. */ + bound(data.n_rows), + parentDistance(0), // Parent distance for the root is 0: it has no parent. + dataset(new MatType(data)) // Copies the dataset. +{ + // Do the actual splitting of this node. + SplitType, MatType> splitter; + SplitNode(maxLeafSize, splitter); + + // Create the statistic depending on if we are a leaf or not. + stat = StatisticType(*this); +} + +template class BoundType, + template + class SplitType> +VantagePointTree:: +VantagePointTree( + const MatType& data, + std::vector& oldFromNew, + const size_t maxLeafSize) : + left(NULL), + right(NULL), + parent(NULL), + begin(0), + count(data.n_cols), + bound(data.n_rows), + parentDistance(0), // Parent distance for the root is 0: it has no parent. + dataset(new MatType(data)) // Copies the dataset. 
+{ + // Initialize oldFromNew correctly. + oldFromNew.resize(data.n_cols); + for (size_t i = 0; i < data.n_cols; i++) + oldFromNew[i] = i; // Fill with unharmed indices. + + // Now do the actual splitting. + SplitType, MatType> splitter; + SplitNode(oldFromNew, maxLeafSize, splitter); + + // Create the statistic depending on if we are a leaf or not. + stat = StatisticType(*this); +} + +template class BoundType, + template + class SplitType> +VantagePointTree:: +VantagePointTree( + const MatType& data, + std::vector& oldFromNew, + std::vector& newFromOld, + const size_t maxLeafSize) : + left(NULL), + right(NULL), + parent(NULL), + begin(0), + count(data.n_cols), + bound(data.n_rows), + parentDistance(0), // Parent distance for the root is 0: it has no parent. + dataset(new MatType(data)) // Copies the dataset. +{ + // Initialize the oldFromNew vector correctly. + oldFromNew.resize(data.n_cols); + for (size_t i = 0; i < data.n_cols; i++) + oldFromNew[i] = i; // Fill with unharmed indices. + + // Now do the actual splitting. + SplitType, MatType> splitter; + SplitNode(oldFromNew, maxLeafSize, splitter); + + // Create the statistic depending on if we are a leaf or not. + stat = StatisticType(*this); + + // Map the newFromOld indices correctly. + newFromOld.resize(data.n_cols); + for (size_t i = 0; i < data.n_cols; i++) + newFromOld[oldFromNew[i]] = i; +} + +template class BoundType, + template + class SplitType> +VantagePointTree:: +VantagePointTree(MatType&& data, const size_t maxLeafSize) : + left(NULL), + right(NULL), + parent(NULL), + begin(0), + count(data.n_cols), + bound(data.n_rows), + parentDistance(0), // Parent distance for the root is 0: it has no parent. + dataset(new MatType(std::move(data))) +{ + // Do the actual splitting of this node. + SplitType, MatType> splitter; + SplitNode(maxLeafSize, splitter); + + // Create the statistic depending on if we are a leaf or not. 
+ stat = StatisticType(*this); +} + +template class BoundType, + template + class SplitType> +VantagePointTree:: +VantagePointTree( + MatType&& data, + std::vector& oldFromNew, + const size_t maxLeafSize) : + left(NULL), + right(NULL), + parent(NULL), + begin(0), + count(data.n_cols), + bound(data.n_rows), + parentDistance(0), // Parent distance for the root is 0: it has no parent. + dataset(new MatType(std::move(data))) +{ + // Initialize oldFromNew correctly. + oldFromNew.resize(dataset->n_cols); + for (size_t i = 0; i < dataset->n_cols; i++) + oldFromNew[i] = i; // Fill with unharmed indices. + + // Now do the actual splitting. + SplitType, MatType> splitter; + SplitNode(oldFromNew, maxLeafSize, splitter); + + // Create the statistic depending on if we are a leaf or not. + stat = StatisticType(*this); +} + +template class BoundType, + template + class SplitType> +VantagePointTree:: +VantagePointTree( + MatType&& data, + std::vector& oldFromNew, + std::vector& newFromOld, + const size_t maxLeafSize) : + left(NULL), + right(NULL), + parent(NULL), + begin(0), + count(data.n_cols), + bound(data.n_rows), + parentDistance(0), // Parent distance for the root is 0: it has no parent. + dataset(new MatType(std::move(data))) +{ + // Initialize the oldFromNew vector correctly. + oldFromNew.resize(dataset->n_cols); + for (size_t i = 0; i < dataset->n_cols; i++) + oldFromNew[i] = i; // Fill with unharmed indices. + + // Now do the actual splitting. + SplitType, MatType> splitter; + SplitNode(oldFromNew, maxLeafSize, splitter); + + // Create the statistic depending on if we are a leaf or not. + stat = StatisticType(*this); + + // Map the newFromOld indices correctly. 
+ newFromOld.resize(dataset->n_cols); + for (size_t i = 0; i < dataset->n_cols; i++) + newFromOld[oldFromNew[i]] = i; +} + +template class BoundType, + template + class SplitType> +VantagePointTree:: +VantagePointTree( + VantagePointTree* parent, + const size_t begin, + const size_t count, + SplitType, MatType>& splitter, + const size_t maxLeafSize) : + left(NULL), + right(NULL), + parent(parent), + begin(begin), + count(count), + bound(parent->Dataset().n_rows), + dataset(&parent->Dataset()) // Point to the parent's dataset. +{ + // Perform the actual splitting. + SplitNode(maxLeafSize, splitter); + + // Create the statistic depending on if we are a leaf or not. + stat = StatisticType(*this); +} + +template class BoundType, + template + class SplitType> +VantagePointTree:: +VantagePointTree( + VantagePointTree* parent, + const size_t begin, + const size_t count, + std::vector& oldFromNew, + SplitType, MatType>& splitter, + const size_t maxLeafSize) : + left(NULL), + right(NULL), + parent(parent), + begin(begin), + count(count), + bound(parent->Dataset().n_rows), + dataset(&parent->Dataset()) +{ + // Hopefully the vector is initialized correctly! We can't check that + // entirely but we can do a minor sanity check. + assert(oldFromNew.size() == dataset->n_cols); + + // Perform the actual splitting. + SplitNode(oldFromNew, maxLeafSize, splitter); + + // Create the statistic depending on if we are a leaf or not. + stat = StatisticType(*this); +} + +template class BoundType, + template + class SplitType> +VantagePointTree:: +VantagePointTree( + VantagePointTree* parent, + const size_t begin, + const size_t count, + std::vector& oldFromNew, + std::vector& newFromOld, + SplitType, MatType>& splitter, + const size_t maxLeafSize) : + left(NULL), + right(NULL), + parent(parent), + begin(begin), + count(count), + bound(parent->Dataset().n_rows), + dataset(&parent->Dataset()) +{ + // Hopefully the vector is initialized correctly!
We can't check that + // entirely but we can do a minor sanity check. + Log::Assert(oldFromNew.size() == dataset->n_cols); + + // Perform the actual splitting. + SplitNode(oldFromNew, maxLeafSize, splitter); + + // Create the statistic depending on if we are a leaf or not. + stat = StatisticType(*this); + + // Map the newFromOld indices correctly. + newFromOld.resize(dataset->n_cols); + for (size_t i = 0; i < dataset->n_cols; i++) + newFromOld[oldFromNew[i]] = i; +} + +/** + * Create a vantage point tree by copying the other tree. Be careful! This + * can take a long time and use a lot of memory. + */ +template class BoundType, + template + class SplitType> +VantagePointTree:: +VantagePointTree( + const VantagePointTree& other) : + left(NULL), + right(NULL), + parent(other.parent), + begin(other.begin), + count(other.count), + bound(other.bound), + stat(other.stat), + parentDistance(other.parentDistance), + furthestDescendantDistance(other.furthestDescendantDistance), + // Copy matrix, but only if we are the root. + dataset((other.parent == NULL) ? new MatType(*other.dataset) : NULL) +{ + // Create left and right children (if any). + if (other.Left()) + { + left = new VantagePointTree(*other.Left()); + left->Parent() = this; // Set parent to this, not other tree. + } + + if (other.Right()) + { + right = new VantagePointTree(*other.Right()); + right->Parent() = this; // Set parent to this, not other tree. + } + + // Propagate matrix, but only if we are the root. + if (parent == NULL) + { + std::queue queue; + if (left) + queue.push(left); + if (right) + queue.push(right); + while (!queue.empty()) + { + VantagePointTree* node = queue.front(); + queue.pop(); + + node->dataset = dataset; + if (node->left) + queue.push(node->left); + if (node->right) + queue.push(node->right); + } + } +} + +/** + * Move constructor.
+ */ +template class BoundType, + template + class SplitType> +VantagePointTree:: +VantagePointTree(VantagePointTree&& other) : + left(other.left), + right(other.right), + parent(other.parent), + begin(other.begin), + count(other.count), + bound(std::move(other.bound)), + stat(std::move(other.stat)), + parentDistance(other.parentDistance), + furthestDescendantDistance(other.furthestDescendantDistance), + minimumBoundDistance(other.minimumBoundDistance), + dataset(other.dataset) +{ + // We now hold other's children and dataset; reset other so its destructor + // frees nothing. NOTE(review): children's parent pointers are not updated. + other.left = NULL; + other.right = NULL; + other.begin = 0; + other.count = 0; + other.parentDistance = 0.0; + other.furthestDescendantDistance = 0.0; + other.minimumBoundDistance = 0.0; + other.dataset = NULL; +} + +/** + * Initialize the tree from an archive. + */ +template class BoundType, + template + class SplitType> +template +VantagePointTree:: +VantagePointTree( + Archive& ar, + const typename boost::enable_if::type*) : + VantagePointTree() // Create an empty VantagePointTree. +{ + // We've delegated to the constructor which gives us an empty tree, and now we + // can serialize from it. + ar >> data::CreateNVP(*this, "tree"); +} + +/** + * Deletes this node, deallocating the memory for the children and calling their + * destructors in turn. This will invalidate any pointers or references to any + * nodes which are children of this one. + */ +template class BoundType, + template + class SplitType> +VantagePointTree:: + ~VantagePointTree() +{ + delete left; + delete right; + + // If we're the root, delete the matrix. + if (!parent) + delete dataset; +} + +template class BoundType, + template + class SplitType> +inline bool VantagePointTree::IsLeaf() const +{ + return !left; +} + +/** + * Returns the number of children in this node.
+ */ +template class BoundType, + template + class SplitType> +inline size_t VantagePointTree::NumChildren() const +{ + if (left && right) + return 2; + if (left) + return 1; + + return 0; +} + +/** + * Return a bound on the furthest point in the node from the center. This + * returns 0 unless the node is a leaf. + */ +template class BoundType, + template + class SplitType> +inline +typename VantagePointTree::ElemType +VantagePointTree::FurthestPointDistance() const +{ + if (!IsLeaf()) + return 0.0; + + // Otherwise return the distance from the center to a corner of the bound. + return 0.5 * bound.Diameter(); +} + +/** + * Return the furthest possible descendant distance. This returns the maximum + * distance from the center to the edge of the bound and not the empirical + * quantity which is the actual furthest descendant distance. So the actual + * furthest descendant distance may be less than what this method returns (but + * it will never be greater than this). + */ +template class BoundType, + template + class SplitType> +inline +typename VantagePointTree::ElemType +VantagePointTree::FurthestDescendantDistance() const +{ + return furthestDescendantDistance; +} + +//! Return the minimum distance from the center to any bound edge. +template class BoundType, + template + class SplitType> +inline +typename VantagePointTree::ElemType +VantagePointTree::MinimumBoundDistance() const +{ + return bound.MinWidth() / 2.0; +} + +/** + * Return the specified child. + */ +template class BoundType, + template + class SplitType> +inline VantagePointTree& + VantagePointTree::Child(const size_t child) const +{ + if (child == 0) + return *left; + else + return *right; +} + +/** + * Return the number of points contained in this node. + */ +template class BoundType, + template + class SplitType> +inline size_t VantagePointTree::NumPoints() const +{ + // Each intermediate node contains exactly one point. 
+ if (left) + return 1; + + return count; +} + +/** + * Return the number of descendants contained in the node. + */ +template class BoundType, + template + class SplitType> +inline size_t VantagePointTree::NumDescendants() const +{ + return count; +} + +/** + * Return the index of a particular descendant contained in this node. + */ +template class BoundType, + template + class SplitType> +inline size_t VantagePointTree::Descendant(const size_t index) const +{ + return (begin + index); +} + +/** + * Return the index of a particular point contained in this node. + */ +template class BoundType, + template + class SplitType> +inline size_t VantagePointTree::Point(const size_t index) const +{ + return (begin + index); +} + +template class BoundType, + template + class SplitType> +void VantagePointTree:: + SplitNode(const size_t maxLeafSize, + SplitType, MatType>& splitter) +{ + // We need to expand the bounds of this node properly. + if (parent) + { + bound.Center() = dataset->col(parent->begin); + bound.OuterRadius() = 0; + bound.InnerRadius() = std::numeric_limits::max(); + } + + if (count > 0) + bound |= dataset->cols(begin, begin + count - 1); + + VantagePointTree* tree = this; + + while (tree->Parent() != NULL) + { + tree->Parent()->Bound() |= tree->Bound(); + tree = tree->Parent(); + } + // Calculate the furthest descendant distance. + furthestDescendantDistance = 0.5 * bound.Diameter(); + + // Now, check if we need to split at all. + if (count <= maxLeafSize) + return; // We can't split this. + + // splitCol denotes the two partitions of the dataset after the split. The + // points on its left go to the left child and the others go to the right + // child. + size_t splitCol; + + // Split the node. The elements of 'data' are reordered by the splitting + // algorithm. This function call updates splitCol. + const bool split = splitter.SplitNode(bound, *dataset, begin, count, + splitCol); + + // The node may not be always split. 
For instance, if all the points are the + // same, we can't split them. + if (!split) + return; + + // Now that we know the split column, we will recursively split the children + // by calling their constructors (which perform this splitting process). + left = new VantagePointTree(this, begin + 1, splitCol - begin - 1, splitter, + maxLeafSize); + right = new VantagePointTree(this, splitCol, begin + count - splitCol, + splitter, maxLeafSize); + + // Calculate parent distances for those two nodes. + arma::vec center, leftCenter, rightCenter; + Center(center); + left->Center(leftCenter); + right->Center(rightCenter); + + const ElemType leftParentDistance = MetricType::Evaluate(center, leftCenter); + const ElemType rightParentDistance = MetricType::Evaluate(center, + rightCenter); + + left->ParentDistance() = leftParentDistance; + right->ParentDistance() = rightParentDistance; +} + +template class BoundType, + template + class SplitType> +void VantagePointTree:: +SplitNode(std::vector& oldFromNew, + const size_t maxLeafSize, + SplitType, MatType>& splitter) +{ + // This should be a single function for Bound. + // We need to expand the bounds of this node properly. + + if (parent) + { + bound.Center() = dataset->col(parent->begin); + bound.OuterRadius() = 0; + bound.InnerRadius() = std::numeric_limits::max(); + } + + if (count > 0) + bound |= dataset->cols(begin, begin + count - 1); + + VantagePointTree* tree = this; + + while (tree->Parent() != NULL) + { + tree->Parent()->Bound() |= tree->Bound(); + tree = tree->Parent(); + } + + // Calculate the furthest descendant distance. + furthestDescendantDistance = 0.5 * bound.Diameter(); + + // First, check if we need to split at all. + if (count <= maxLeafSize) + return; // We can't split this. + + // splitCol denotes the two partitions of the dataset after the split. The + // points on its left go to the left child and the others go to the right + // child. + size_t splitCol; + + // Split the node. 
The elements of 'data' are reordered by the splitting + // algorithm. This function call updates splitCol and oldFromNew. + const bool split = splitter.SplitNode(bound, *dataset, begin, count, splitCol, + oldFromNew); + + // The node may not be always split. For instance, if all the points are the + // same, we can't split them. + if (!split) + return; + + // Now that we know the split column, we will recursively split the children + // by calling their constructors (which perform this splitting process). + left = new VantagePointTree(this, begin + 1, splitCol - begin - 1, oldFromNew, + splitter, maxLeafSize); + right = new VantagePointTree(this, splitCol, begin + count - splitCol, + oldFromNew, splitter, maxLeafSize); + + + // Calculate parent distances for those two nodes. + arma::vec center, leftCenter, rightCenter; + Center(center); + left->Center(leftCenter); + right->Center(rightCenter); + + const ElemType leftParentDistance = MetricType::Evaluate(center, leftCenter); + const ElemType rightParentDistance = MetricType::Evaluate(center, + rightCenter); + + left->ParentDistance() = leftParentDistance; + right->ParentDistance() = rightParentDistance; +} + +// Default constructor (private), for boost::serialization. +template class BoundType, + template + class SplitType> +VantagePointTree:: + VantagePointTree() : + left(NULL), + right(NULL), + parent(NULL), + begin(0), + count(0), + stat(*this), + parentDistance(0), + furthestDescendantDistance(0), + dataset(NULL) +{ + // Nothing to do. +} + +/** + * Serialize the tree. + */ +template class BoundType, + template + class SplitType> +template +void VantagePointTree:: + Serialize(Archive& ar, const unsigned int /* version */) +{ + using data::CreateNVP; + + // If we're loading, and we have children, they need to be deleted. 
+ if (Archive::is_loading::value) + { + if (left) + delete left; + if (right) + delete right; + if (!parent) + delete dataset; + } + + ar & CreateNVP(parent, "parent"); + ar & CreateNVP(begin, "begin"); + ar & CreateNVP(count, "count"); + ar & CreateNVP(bound, "bound"); + ar & CreateNVP(stat, "statistic"); + ar & CreateNVP(parentDistance, "parentDistance"); + ar & CreateNVP(furthestDescendantDistance, "furthestDescendantDistance"); + ar & CreateNVP(dataset, "dataset"); + + // Save children last; otherwise boost::serialization gets confused. + ar & CreateNVP(left, "left"); + ar & CreateNVP(right, "right"); + + // Due to quirks of boost::serialization, if a tree is saved as an object and + // not a pointer, the first level of the tree will be duplicated on load. + // Therefore, if we are the root of the tree, then we need to make sure our + // children's parent links are correct, and delete the duplicated node if + // necessary. + if (Archive::is_loading::value) + { + // Get parents of left and right children, or, NULL, if they don't exist. + VantagePointTree* leftParent = left ? left->Parent() : NULL; + VantagePointTree* rightParent = right ? right->Parent() : NULL; + + // Reassign parent links if necessary. + if (left && left->Parent() != this) + left->Parent() = this; + if (right && right->Parent() != this) + right->Parent() = this; + + // Do we need to delete the left parent? + if (leftParent != NULL && leftParent != this) + { + // Sever the duplicate parent's children. Ensure we don't delete the + // dataset, by faking the duplicated parent's parent (that is, we need to + // set the parent to something non-NULL; 'this' works). + leftParent->Parent() = this; + leftParent->Left() = NULL; + leftParent->Right() = NULL; + delete leftParent; + } + + // Do we need to delete the right parent? + if (rightParent != NULL && rightParent != this && rightParent != leftParent) + { + // Sever the duplicate parent's children, in the same way as above. 
+ rightParent->Parent() = this; + rightParent->Left() = NULL; + rightParent->Right() = NULL; + delete rightParent; + } + } +} + +} // namespace tree +} // namespace mlpack + +#endif // MLPACK_CORE_TREE_VANTAGE_POINT_TREE_VANTAGE_POINT_TREE_IMPL_HPP diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt index 967edeee4e1..d98a67ef08e 100644 --- a/src/mlpack/tests/CMakeLists.txt +++ b/src/mlpack/tests/CMakeLists.txt @@ -81,6 +81,7 @@ add_executable(mlpack_test nystroem_method_test.cpp armadillo_svd_test.cpp recurrent_network_test.cpp + vantage_point_tree_test.cpp ) # Link dependencies of test executable. target_link_libraries(mlpack_test diff --git a/src/mlpack/tests/tree_test.cpp b/src/mlpack/tests/tree_test.cpp index 12f3678c338..81a94463b25 100644 --- a/src/mlpack/tests/tree_test.cpp +++ b/src/mlpack/tests/tree_test.cpp @@ -1430,52 +1430,6 @@ BOOST_AUTO_TEST_CASE(BallTreeTest) } } -BOOST_AUTO_TEST_CASE(VantagePointTreeTest) -{ - typedef VantagePointTree TreeType; - - size_t maxRuns = 10; // Ten total tests. - size_t pointIncrements = 1000; // Range is from 2000 points to 11000. - - // We use the default leaf size of 20. - for (size_t run = 0; run < maxRuns; run++) - { - size_t dimensions = run + 2; - size_t maxPoints = (run + 1) * pointIncrements; - - size_t size = maxPoints; - arma::mat dataset = arma::mat(dimensions, size); - arma::mat datacopy; // Used to test mappings. - - // Mappings for post-sort verification of data. - std::vector newToOld; - std::vector oldToNew; - - // Generate data. - dataset.randu(); - - // Build the tree itself. - TreeType root(dataset, newToOld, oldToNew); - const arma::mat& treeset = root.Dataset(); - - // Ensure the size of the tree is correct. - BOOST_REQUIRE_EQUAL(root.NumDescendants(), size); - - // Check the forward and backward mappings for correctness. 
- for(size_t i = 0; i < size; i++) - { - for(size_t j = 0; j < dimensions; j++) - { - BOOST_REQUIRE_EQUAL(treeset(j, i), dataset(j, newToOld[i])); - BOOST_REQUIRE_EQUAL(treeset(j, oldToNew[i]), dataset(j, i)); - } - } - - // Now check that each point is contained inside of all bounds above it. - CheckPointBounds(root); - } -} - template bool DoBoundsIntersect(HRectBound& a, HRectBound& b) diff --git a/src/mlpack/tests/vantage_point_tree_test.cpp b/src/mlpack/tests/vantage_point_tree_test.cpp new file mode 100644 index 00000000000..1570f6c2015 --- /dev/null +++ b/src/mlpack/tests/vantage_point_tree_test.cpp @@ -0,0 +1,291 @@ +/** + * @file tree_test.cpp + * + * Tests for tree-building methods. + */ +#include +#include +#include +#include +#include + +#include +#include "test_tools.hpp" + +using namespace mlpack; +using namespace mlpack::math; +using namespace mlpack::tree; +using namespace mlpack::neighbor; +using namespace mlpack::metric; +using namespace mlpack::bound; + +BOOST_AUTO_TEST_SUITE(VantagePointTreeTest); + +BOOST_AUTO_TEST_CASE(VPTreeTraitsTest) +{ + typedef VPTree TreeType; + + bool b = TreeTraits::HasOverlappingChildren; + BOOST_REQUIRE_EQUAL(b, true); +// b = TreeTraits::FirstPointIsCentroid; +// BOOST_REQUIRE_EQUAL(b, true); + b = TreeTraits::HasSelfChildren; + BOOST_REQUIRE_EQUAL(b, false); + b = TreeTraits::RearrangesDataset; + BOOST_REQUIRE_EQUAL(b, true); + b = TreeTraits::BinaryTree; + BOOST_REQUIRE_EQUAL(b, true); +} + +BOOST_AUTO_TEST_CASE(HollowBallBoundTest) +{ + HollowBallBound b(2, 4, "1.0 2.0 3.0 4.0 5.0"); + + BOOST_REQUIRE_EQUAL(b.Contains("1.0 2.0 3.0 7.0 5.0"), true); + + BOOST_REQUIRE_EQUAL(b.Contains("1.0 2.0 3.0 9.0 5.0"), false); + + BOOST_REQUIRE_EQUAL(b.Contains("1.0 2.0 3.0 5.0 5.0"), false); + + HollowBallBound b2(0.5, 1, "1.0 2.0 3.0 7.0 5.0"); + BOOST_REQUIRE_EQUAL(b.Contains(b2), true); + + b2 = HollowBallBound(2.5, 3.5, "1.0 2.0 3.0 4.5 5.0"); + BOOST_REQUIRE_EQUAL(b.Contains(b2), true); + + b2 = HollowBallBound(2.0, 
3.5, "1.0 2.0 3.0 4.5 5.0"); + BOOST_REQUIRE_EQUAL(b.Contains(b2), false); + + BOOST_REQUIRE_CLOSE(b.MinDistance(arma::vec("1.0 2.0 8.0 4.0 5.0")), 1.0, + 1e-5); + BOOST_REQUIRE_CLOSE(b.MinDistance(arma::vec("1.0 2.0 4.0 4.0 5.0")), 1.0, + 1e-5); + BOOST_REQUIRE_CLOSE(b.MinDistance(arma::vec("1.0 2.0 3.0 4.0 5.0")), 2.0, + 1e-5); + BOOST_REQUIRE_CLOSE(b.MinDistance(arma::vec("1.0 2.0 5.0 4.0 5.0")), 0.0, + 1e-5); + BOOST_REQUIRE_CLOSE(b.MinDistance(arma::vec("5.0 2.0 3.0 4.0 5.0")), 0.0, + 1e-5); + BOOST_REQUIRE_CLOSE(b.MinDistance(arma::vec("3.0 2.0 3.0 4.0 5.0")), 0.0, + 1e-5); + + BOOST_REQUIRE_CLOSE(b.MaxDistance(arma::vec("1.0 2.0 4.0 4.0 5.0")), 5.0, + 1e-5); + BOOST_REQUIRE_CLOSE(b.MaxDistance(arma::vec("1.0 2.0 8.0 4.0 5.0")), 9.0, + 1e-5); + BOOST_REQUIRE_CLOSE(b.MaxDistance(arma::vec("1.0 2.0 3.0 4.0 5.0")), 4.0, + 1e-5); + + b2 = HollowBallBound(3, 4, "1.0 2.0 3.0 5.0 5.0"); + BOOST_REQUIRE_CLOSE(b.MinDistance(b2), 0.0, 1e-5); + + b2 = HollowBallBound(1, 2, "1.0 2.0 3.0 4.0 5.0"); + BOOST_REQUIRE_CLOSE(b.MinDistance(b2), 0.0, 1e-5); + + b2 = HollowBallBound(0.5, 1.0, "1.0 2.5 3.0 4.0 5.0"); + BOOST_REQUIRE_CLOSE(b.MinDistance(b2), 0.5, 1e-5); + + b2 = HollowBallBound(0.5, 1.0, "1.0 8.0 3.0 4.0 5.0"); + BOOST_REQUIRE_CLOSE(b.MinDistance(b2), 1.0, 1e-5); + + b2 = HollowBallBound(0.5, 2.0, "1.0 8.0 3.0 4.0 5.0"); + BOOST_REQUIRE_CLOSE(b.MinDistance(b2), 0.0, 1e-5); + + b2 = HollowBallBound(0.5, 2.0, "1.0 8.0 3.0 4.0 5.0"); + BOOST_REQUIRE_CLOSE(b.MaxDistance(b2), 12.0, 1e-5); + + b2 = HollowBallBound(0.5, 2.0, "1.0 3.0 3.0 4.0 5.0"); + BOOST_REQUIRE_CLOSE(b.MaxDistance(b2), 7.0, 1e-5); + + HollowBallBound b1 = b; + b2 = HollowBallBound(1.0, 2.0, "1.0 2.5 3.0 4.0 5.0"); + + b1 |= b2; + BOOST_REQUIRE_CLOSE(b1.InnerRadius(), 0.5, 1e-5); + + b1 = b; + b2 = HollowBallBound(0.5, 2.0, "1.0 3.0 3.0 4.0 5.0"); + b1 |= b2; + BOOST_REQUIRE_CLOSE(b1.InnerRadius(), 0.0, 1e-5); + + b1 = b; + b2 = HollowBallBound(0.5, 4.0, "1.0 3.0 3.0 4.0 5.0"); + b1 |= b2; + 
BOOST_REQUIRE_CLOSE(b1.OuterRadius(), 5.0, 1e-5); +} + +template +void CheckBound(TreeType& tree) +{ + if (tree.IsLeaf()) + { + for (size_t i = 0; i < tree.NumPoints(); i++) + BOOST_REQUIRE_EQUAL(true, + tree.Bound().Contains(tree.Dataset().col(tree.Point(i)))); + } + else + { + BOOST_REQUIRE_EQUAL(tree.NumPoints(), 1); + BOOST_REQUIRE_EQUAL(true, + tree.Bound().Contains(tree.Dataset().col(tree.Point(0)))); + + BOOST_REQUIRE_EQUAL(tree.Bound().Contains(tree.Left()->Bound()), true); + BOOST_REQUIRE_EQUAL(tree.Bound().Contains(tree.Right()->Bound()), true); + + CheckBound(*tree.Left()); + CheckBound(*tree.Right()); + } +} + +BOOST_AUTO_TEST_CASE(VPTreeBoundTest) +{ + typedef VPTree TreeType; + + arma::mat dataset(8, 1000); + dataset.randu(); + + TreeType tree(dataset); + CheckBound(tree); +} + +template +void CheckSplit(TreeType& tree) +{ + if(tree.IsLeaf()) + return; + + typename TreeType::ElemType maxDist = 0; + + size_t pointsEnd = tree.Left()->Begin() + tree.Left()->Count(); + for (size_t i = tree.Left()->Begin(); i < pointsEnd; i++) + { + typename TreeType::ElemType dist = + tree.Bound().Metric().Evaluate(tree.Dataset().col(tree.Begin()), + tree.Dataset().col(i)); + + if (dist > maxDist) + maxDist = dist; + } + + pointsEnd = tree.Right()->Begin() + tree.Right()->Count(); + for (size_t i = tree.Right()->Begin(); i < pointsEnd; i++) + { + typename TreeType::ElemType dist = + tree.Bound().Metric().Evaluate(tree.Dataset().col(tree.Begin()), + tree.Dataset().col(i)); + BOOST_REQUIRE_LE(maxDist, dist); + } + + CheckSplit(*tree.Left()); + CheckSplit(*tree.Right()); +} + +BOOST_AUTO_TEST_CASE(VPTreeSplitTest) +{ + typedef VPTree TreeType; + + arma::mat dataset(8, 1000); + dataset.randu(); + + TreeType tree(dataset); + CheckSplit(tree); +} + +BOOST_AUTO_TEST_CASE(VPTreeTest) +{ + typedef VPTree TreeType; + + size_t maxRuns = 10; // Ten total tests. + size_t pointIncrements = 1000; // Range is from 2000 points to 11000. + + // We use the default leaf size of 20. 
+ for (size_t run = 0; run < maxRuns; run++) + { + size_t dimensions = run + 2; + size_t maxPoints = (run + 1) * pointIncrements; + + size_t size = maxPoints; + arma::mat dataset = arma::mat(dimensions, size); + arma::mat datacopy; // Used to test mappings. + + // Mappings for post-sort verification of data. + std::vector newToOld; + std::vector oldToNew; + + // Generate data. + dataset.randu(); + + // Build the tree itself. + TreeType root(dataset, newToOld, oldToNew); + const arma::mat& treeset = root.Dataset(); + + // Ensure the size of the tree is correct. + BOOST_REQUIRE_EQUAL(root.NumDescendants(), size); + + // Check the forward and backward mappings for correctness. + for(size_t i = 0; i < size; i++) + { + for(size_t j = 0; j < dimensions; j++) + { + BOOST_REQUIRE_EQUAL(treeset(j, i), dataset(j, newToOld[i])); + BOOST_REQUIRE_EQUAL(treeset(j, oldToNew[i]), dataset(j, i)); + } + } + } +} + +BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest) +{ + arma::mat dataset; + dataset.randu(8, 1000); // 1000 points in 8 dimensions. + arma::Mat neighbors1; + arma::mat distances1; + arma::Mat neighbors2; + arma::mat distances2; + + // Nearest neighbor search with the VP tree. + NeighborSearch, arma::mat, + VPTree> knn1(dataset, false, true); + + knn1.Search(5, neighbors1, distances1); + + // Nearest neighbor search the naive way. + KNN knn2(dataset, true, true); + + knn2.Search(5, neighbors2, distances2); + + for (size_t i = 0; i < neighbors1.size(); i++) + { + BOOST_REQUIRE_EQUAL(neighbors1[i], neighbors2[i]); + BOOST_REQUIRE_EQUAL(distances1[i], distances2[i]); + } +} + +BOOST_AUTO_TEST_CASE(DualTreeTraverserTest) +{ + arma::mat dataset; + dataset.randu(8, 1000); // 1000 points in 8 dimensions. + arma::Mat neighbors1; + arma::mat distances1; + arma::Mat neighbors2; + arma::mat distances2; + + // Nearest neighbor search with the VP tree. 
+ NeighborSearch, arma::mat, + VPTree> knn1(dataset, false, false); + + knn1.Search(5, neighbors1, distances1); + + // Nearest neighbor search the naive way. + KNN knn2(dataset, true, true); + + knn2.Search(5, neighbors2, distances2); + + for (size_t i = 0; i < neighbors1.size(); i++) + { + BOOST_REQUIRE_EQUAL(neighbors1[i], neighbors2[i]); + BOOST_REQUIRE_EQUAL(distances1[i], distances2[i]); + } +} + +BOOST_AUTO_TEST_SUITE_END(); From dbdf3ca4ebb753616dbfc1db832aa48bce4099b3 Mon Sep 17 00:00:00 2001 From: Mikhail Lozhnikov Date: Sat, 16 Jul 2016 19:32:14 +0300 Subject: [PATCH 04/12] Added the vantage point tree to RSModel and NSModel. --- .../methods/neighbor_search/kfn_main.cpp | 12 ++++++----- .../methods/neighbor_search/knn_main.cpp | 12 ++++++----- .../methods/neighbor_search/ns_model.hpp | 7 +++++-- .../methods/neighbor_search/ns_model_impl.hpp | 6 ++++++ .../range_search/range_search_main.cpp | 12 ++++++----- src/mlpack/methods/range_search/rs_model.cpp | 21 ++++++++++++++++++- src/mlpack/methods/range_search/rs_model.hpp | 6 +++++- .../methods/range_search/rs_model_impl.hpp | 14 +++++++++++++ src/mlpack/tests/aknn_test.cpp | 16 ++++++++++---- src/mlpack/tests/knn_test.cpp | 12 +++++++---- src/mlpack/tests/range_search_test.cpp | 12 +++++++---- 11 files changed, 99 insertions(+), 31 deletions(-) diff --git a/src/mlpack/methods/neighbor_search/kfn_main.cpp b/src/mlpack/methods/neighbor_search/kfn_main.cpp index 6adbc566f70..08cd95f4dbb 100644 --- a/src/mlpack/methods/neighbor_search/kfn_main.cpp +++ b/src/mlpack/methods/neighbor_search/kfn_main.cpp @@ -62,10 +62,10 @@ PARAM_INT("k", "Number of furthest neighbors to find.", "k", 0); // The user may specify the type of tree to use, and a few pararmeters for tree // building. 
PARAM_STRING("tree_type", "Type of tree to use: 'kd', 'cover', 'r', 'r-star', " - "'x', 'ball', 'hilbert-r', 'r-plus', 'r-plus-plus'.", "t", "kd"); -PARAM_INT("leaf_size", "Leaf size for tree building (used for kd-trees, R " - "trees, R* trees, X trees, Hilbert R trees, R+ trees and R++ trees).", "l", - 20); + "'x', 'ball', 'hilbert-r', 'r-plus', 'r-plus-plus', 'vp'.", "t", "kd"); +PARAM_INT("leaf_size", "Leaf size for tree building (used for kd-trees, vp " + "trees, R trees, R* trees, X trees, Hilbert R trees, R+ trees and R++ " + "trees).", "l", 20); PARAM_FLAG("random_basis", "Before tree-building, project the data onto a " "random orthogonal basis.", "R"); PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0); @@ -194,10 +194,12 @@ int main(int argc, char *argv[]) tree = KFNModel::R_PLUS_TREE; else if (treeType == "r-plus-plus") tree = KFNModel::R_PLUS_PLUS_TREE; + else if (treeType == "vp") + tree = KFNModel::VP_TREE; else Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are " << "'kd', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', " - << "'r-plus' and 'r-plus-plus'." << endl; + << "'r-plus', 'r-plus-plus' and 'vp'." << endl; kfn.TreeType() = tree; kfn.RandomBasis() = randomBasis; diff --git a/src/mlpack/methods/neighbor_search/knn_main.cpp b/src/mlpack/methods/neighbor_search/knn_main.cpp index 87bdb32e798..62fff18ec02 100644 --- a/src/mlpack/methods/neighbor_search/knn_main.cpp +++ b/src/mlpack/methods/neighbor_search/knn_main.cpp @@ -63,10 +63,10 @@ PARAM_INT("k", "Number of nearest neighbors to find.", "k", 0); // The user may specify the type of tree to use, and a few parameters for tree // building. 
PARAM_STRING("tree_type", "Type of tree to use: 'kd', 'cover', 'r', 'r-star', " - "'x', 'ball', 'hilbert-r', 'r-plus', 'r-plus-plus'.", "t", "kd"); -PARAM_INT("leaf_size", "Leaf size for tree building (used for kd-trees, R " - "trees, R* trees, X trees, Hilbert R trees, R+ trees and R++ trees).", "l", - 20); + "'x', 'ball', 'hilbert-r', 'r-plus', 'r-plus-plus', 'vp'.", "t", "kd"); +PARAM_INT("leaf_size", "Leaf size for tree building (used for kd-trees, vp " + "trees, R trees, R* trees, X trees, Hilbert R trees, R+ trees and R++ " + "trees).", "l", 20); PARAM_FLAG("random_basis", "Before tree-building, project the data onto a " "random orthogonal basis.", "R"); PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0); @@ -179,10 +179,12 @@ int main(int argc, char *argv[]) tree = KNNModel::R_PLUS_TREE; else if (treeType == "r-plus-plus") tree = KNNModel::R_PLUS_PLUS_TREE; + else if (treeType == "vp") + tree = KNNModel::VP_TREE; else Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are " << "'kd', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', " - << "'r-plus' and 'r-plus-plus'." << endl; + << "'r-plus', 'r-plus-plus' and 'vp'." 
<< endl; knn.TreeType() = tree; knn.RandomBasis() = randomBasis; diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp index 38e474898b9..fc782fed4f9 100644 --- a/src/mlpack/methods/neighbor_search/ns_model.hpp +++ b/src/mlpack/methods/neighbor_search/ns_model.hpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include "neighbor_search.hpp" @@ -256,7 +257,8 @@ class NSModel X_TREE, HILBERT_R_TREE, R_PLUS_TREE, - R_PLUS_PLUS_TREE + R_PLUS_PLUS_TREE, + VP_TREE }; private: @@ -284,7 +286,8 @@ class NSModel NSType*, NSType*, NSType*, - NSType*> nSearch; + NSType*, + NSType*> nSearch; public: /** diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp index acbed6ce2a5..2ca8be5d5ab 100644 --- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp +++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp @@ -394,6 +394,10 @@ void NSModel::BuildModel(arma::mat&& referenceSet, nSearch = new NSType(naive, singleMode, epsilon); break; + case VP_TREE: + nSearch = new NSType(naive, singleMode, + epsilon); + break; } TrainVisitor tn(std::move(referenceSet), leafSize); @@ -478,6 +482,8 @@ std::string NSModel::TreeName() const return "R+ tree"; case R_PLUS_PLUS_TREE: return "R++ tree"; + case VP_TREE: + return "Vantage point tree"; default: return "unknown tree"; } diff --git a/src/mlpack/methods/range_search/range_search_main.cpp b/src/mlpack/methods/range_search/range_search_main.cpp index c8ea2a5b83f..4db62f7c4f8 100644 --- a/src/mlpack/methods/range_search/range_search_main.cpp +++ b/src/mlpack/methods/range_search/range_search_main.cpp @@ -70,10 +70,10 @@ PARAM_DOUBLE("min", "Lower bound in range.", "L", 0.0); // The user may specify the type of tree to use, and a few parameters for tree // building. 
PARAM_STRING("tree_type", "Type of tree to use: 'kd', 'cover', 'r', 'r-star', " - "'x', 'ball', 'hilbert-r', 'r-plus', 'r-plus-plus'.", "t", "kd"); -PARAM_INT("leaf_size", "Leaf size for tree building (used for kd-trees, R " - "trees, R* trees, X trees, Hilbert R trees, R+ trees and R++ trees).", "l", - 20); + "'x', 'ball', 'hilbert-r', 'r-plus', 'r-plus-plus', 'vp'.", "t", "kd"); +PARAM_INT("leaf_size", "Leaf size for tree building (used for kd-trees, vp " + "trees, R trees, R* trees, X trees, Hilbert R trees, R+ trees and R++ " + "trees).", "l", 20); PARAM_FLAG("random_basis", "Before tree-building, project the data onto a " "random orthogonal basis.", "R"); PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0); @@ -181,10 +181,12 @@ int main(int argc, char *argv[]) tree = RSModel::R_PLUS_TREE; else if (treeType == "r-plus-plus") tree = RSModel::R_PLUS_PLUS_TREE; + else if (treeType == "vp") + tree = RSModel::VP_TREE; else Log::Fatal << "Unknown tree type '" << treeType << "; valid choices are " << "'kd', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', " - << "'r-plus' and 'r-plus-plus'." << endl; + << "'r-plus', 'r-plus-plus' and 'vp'." << endl; rs.TreeType() = tree; rs.RandomBasis() = randomBasis; diff --git a/src/mlpack/methods/range_search/rs_model.cpp b/src/mlpack/methods/range_search/rs_model.cpp index 1cf3938523a..10514c6de84 100644 --- a/src/mlpack/methods/range_search/rs_model.cpp +++ b/src/mlpack/methods/range_search/rs_model.cpp @@ -25,7 +25,8 @@ RSModel::RSModel(TreeTypes treeType, bool randomBasis) : xTreeRS(NULL), hilbertRTreeRS(NULL), rPlusTreeRS(NULL), - rPlusPlusTreeRS(NULL) + rPlusPlusTreeRS(NULL), + vpTreeRS(NULL) { // Nothing to do. 
} @@ -140,6 +141,11 @@ void RSModel::BuildModel(arma::mat&& referenceSet, rPlusPlusTreeRS = new RSType(move(referenceSet), naive, singleMode); break; + + case VP_TREE: + vpTreeRS = new RSType(move(referenceSet), naive, + singleMode); + break; } if (!naive) @@ -261,6 +267,10 @@ void RSModel::Search(arma::mat&& querySet, case R_PLUS_PLUS_TREE: rPlusPlusTreeRS->Search(querySet, range, neighbors, distances); break; + + case VP_TREE: + vpTreeRS->Search(querySet, range, neighbors, distances); + break; } } @@ -315,6 +325,10 @@ void RSModel::Search(const math::Range& range, case R_PLUS_PLUS_TREE: rPlusPlusTreeRS->Search(range, neighbors, distances); break; + + case VP_TREE: + vpTreeRS->Search(range, neighbors, distances); + break; } } @@ -341,6 +355,8 @@ std::string RSModel::TreeName() const return "R+ tree"; case R_PLUS_PLUS_TREE: return "R++ tree"; + case VP_TREE: + return "Vantage point tree"; default: return "unknown tree"; } @@ -367,6 +383,8 @@ void RSModel::CleanMemory() delete rPlusTreeRS; if (rPlusPlusTreeRS) delete rPlusPlusTreeRS; + if (vpTreeRS) + delete vpTreeRS; kdTreeRS = NULL; coverTreeRS = NULL; @@ -377,4 +395,5 @@ void RSModel::CleanMemory() hilbertRTreeRS = NULL; rPlusTreeRS = NULL; rPlusPlusTreeRS = NULL; + vpTreeRS = NULL; } diff --git a/src/mlpack/methods/range_search/rs_model.hpp b/src/mlpack/methods/range_search/rs_model.hpp index 7903d373c38..8c7af3b4cd0 100644 --- a/src/mlpack/methods/range_search/rs_model.hpp +++ b/src/mlpack/methods/range_search/rs_model.hpp @@ -13,6 +13,7 @@ #include #include #include +#include #include "range_search.hpp" @@ -32,7 +33,8 @@ class RSModel X_TREE, HILBERT_R_TREE, R_PLUS_TREE, - R_PLUS_PLUS_TREE + R_PLUS_PLUS_TREE, + VP_TREE }; private: @@ -69,6 +71,8 @@ class RSModel RSType* rPlusTreeRS; //! R++ tree based range search object (NULL if not in use). RSType* rPlusPlusTreeRS; + //! VP tree based range search object (NULL if not in use). 
+ RSType* vpTreeRS; public: /** diff --git a/src/mlpack/methods/range_search/rs_model_impl.hpp b/src/mlpack/methods/range_search/rs_model_impl.hpp index 98fa7a8224b..fbc658055ca 100644 --- a/src/mlpack/methods/range_search/rs_model_impl.hpp +++ b/src/mlpack/methods/range_search/rs_model_impl.hpp @@ -65,6 +65,10 @@ void RSModel::Serialize(Archive& ar, const unsigned int /* version */) case R_PLUS_PLUS_TREE: ar & CreateNVP(rPlusPlusTreeRS, "range_search_model"); break; + + case VP_TREE: + ar & CreateNVP(vpTreeRS, "range_search_model"); + break; } } @@ -88,6 +92,8 @@ inline const arma::mat& RSModel::Dataset() const return rPlusTreeRS->ReferenceSet(); else if (rPlusPlusTreeRS) return rPlusPlusTreeRS->ReferenceSet(); + else if (vpTreeRS) + return vpTreeRS->ReferenceSet(); throw std::runtime_error("no range search model initialized"); } @@ -112,6 +118,8 @@ inline bool RSModel::SingleMode() const return rPlusTreeRS->SingleMode(); else if (rPlusPlusTreeRS) return rPlusPlusTreeRS->SingleMode(); + else if (vpTreeRS) + return vpTreeRS->SingleMode(); throw std::runtime_error("no range search model initialized"); } @@ -136,6 +144,8 @@ inline bool& RSModel::SingleMode() return rPlusTreeRS->SingleMode(); else if (rPlusPlusTreeRS) return rPlusPlusTreeRS->SingleMode(); + else if (vpTreeRS) + return vpTreeRS->SingleMode(); throw std::runtime_error("no range search model initialized"); } @@ -160,6 +170,8 @@ inline bool RSModel::Naive() const return rPlusTreeRS->Naive(); else if (rPlusPlusTreeRS) return rPlusPlusTreeRS->Naive(); + else if (vpTreeRS) + return vpTreeRS->Naive(); throw std::runtime_error("no range search model initialized"); } @@ -184,6 +196,8 @@ inline bool& RSModel::Naive() return rPlusTreeRS->Naive(); else if (rPlusPlusTreeRS) return rPlusPlusTreeRS->Naive(); + else if (vpTreeRS) + return vpTreeRS->Naive(); throw std::runtime_error("no range search model initialized"); } diff --git a/src/mlpack/tests/aknn_test.cpp b/src/mlpack/tests/aknn_test.cpp index 
f38978dbfeb..98466b4c7fb 100644 --- a/src/mlpack/tests/aknn_test.cpp +++ b/src/mlpack/tests/aknn_test.cpp @@ -287,7 +287,7 @@ BOOST_AUTO_TEST_CASE(KNNModelTest) arma::mat referenceData = arma::randu(10, 200); // Build all the possible models. - KNNModel models[14]; + KNNModel models[20]; models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true); models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false); models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true); @@ -302,6 +302,12 @@ BOOST_AUTO_TEST_CASE(KNNModelTest) models[11] = KNNModel(KNNModel::TreeTypes::BALL_TREE, false); models[12] = KNNModel(KNNModel::TreeTypes::HILBERT_R_TREE, true); models[13] = KNNModel(KNNModel::TreeTypes::HILBERT_R_TREE, false); + models[14] = KNNModel(KNNModel::TreeTypes::R_PLUS_TREE, true); + models[15] = KNNModel(KNNModel::TreeTypes::R_PLUS_TREE, false); + models[16] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, true); + models[17] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, false); + models[18] = KNNModel(KNNModel::TreeTypes::VP_TREE, true); + models[19] = KNNModel(KNNModel::TreeTypes::VP_TREE, false); for (size_t j = 0; j < 3; ++j) { @@ -311,7 +317,7 @@ BOOST_AUTO_TEST_CASE(KNNModelTest) arma::mat distancesExact; aknn.Search(queryData, 3, neighborsExact, distancesExact); - for (size_t i = 0; i < 14; ++i) + for (size_t i = 0; i < 20; ++i) { // We only have std::move() constructors so make a copy of our data. arma::mat referenceCopy(referenceData); @@ -352,7 +358,7 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest) arma::mat referenceData = arma::randu(10, 200); // Build all the possible models. 
- KNNModel models[18]; + KNNModel models[20]; models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true); models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false); models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true); @@ -371,6 +377,8 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest) models[15] = KNNModel(KNNModel::TreeTypes::R_PLUS_TREE, false); models[16] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, true); models[17] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, false); + models[18] = KNNModel(KNNModel::TreeTypes::VP_TREE, true); + models[19] = KNNModel(KNNModel::TreeTypes::VP_TREE, false); for (size_t j = 0; j < 2; ++j) { @@ -380,7 +388,7 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest) arma::mat distancesExact; exact.Search(3, neighborsExact, distancesExact); - for (size_t i = 0; i < 18; ++i) + for (size_t i = 0; i < 20; ++i) { // We only have a std::move() constructor... so copy the data. arma::mat referenceCopy(referenceData); diff --git a/src/mlpack/tests/knn_test.cpp b/src/mlpack/tests/knn_test.cpp index 0de22b8b959..ffb6412d6b2 100644 --- a/src/mlpack/tests/knn_test.cpp +++ b/src/mlpack/tests/knn_test.cpp @@ -977,7 +977,7 @@ BOOST_AUTO_TEST_CASE(KNNModelTest) arma::mat referenceData = arma::randu(10, 200); // Build all the possible models. 
- KNNModel models[18]; + KNNModel models[20]; models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true); models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false); models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true); @@ -996,6 +996,8 @@ BOOST_AUTO_TEST_CASE(KNNModelTest) models[15] = KNNModel(KNNModel::TreeTypes::R_PLUS_TREE, false); models[16] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, true); models[17] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, false); + models[18] = KNNModel(KNNModel::TreeTypes::VP_TREE, true); + models[19] = KNNModel(KNNModel::TreeTypes::VP_TREE, false); for (size_t j = 0; j < 2; ++j) { @@ -1005,7 +1007,7 @@ BOOST_AUTO_TEST_CASE(KNNModelTest) arma::mat baselineDistances; knn.Search(queryData, 3, baselineNeighbors, baselineDistances); - for (size_t i = 0; i < 18; ++i) + for (size_t i = 0; i < 20; ++i) { // We only have std::move() constructors so make a copy of our data. arma::mat referenceCopy(referenceData); @@ -1049,7 +1051,7 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest) arma::mat referenceData = arma::randu(10, 200); // Build all the possible models. 
- KNNModel models[18]; + KNNModel models[20]; models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true); models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false); models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true); @@ -1068,6 +1070,8 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest) models[15] = KNNModel(KNNModel::TreeTypes::R_PLUS_TREE, false); models[16] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, true); models[17] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, false); + models[18] = KNNModel(KNNModel::TreeTypes::VP_TREE, true); + models[19] = KNNModel(KNNModel::TreeTypes::VP_TREE, false); for (size_t j = 0; j < 2; ++j) { @@ -1077,7 +1081,7 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest) arma::mat baselineDistances; knn.Search(3, baselineNeighbors, baselineDistances); - for (size_t i = 0; i < 18; ++i) + for (size_t i = 0; i < 20; ++i) { // We only have a std::move() constructor... so copy the data. arma::mat referenceCopy(referenceData); diff --git a/src/mlpack/tests/range_search_test.cpp b/src/mlpack/tests/range_search_test.cpp index 7f28f2e24eb..037f593e926 100644 --- a/src/mlpack/tests/range_search_test.cpp +++ b/src/mlpack/tests/range_search_test.cpp @@ -1249,7 +1249,7 @@ BOOST_AUTO_TEST_CASE(RSModelTest) arma::mat referenceData = arma::randu(10, 200); // Build all the possible models. 
- RSModel models[18]; + RSModel models[20]; models[0] = RSModel(RSModel::TreeTypes::KD_TREE, true); models[1] = RSModel(RSModel::TreeTypes::KD_TREE, false); models[2] = RSModel(RSModel::TreeTypes::COVER_TREE, true); @@ -1268,6 +1268,8 @@ BOOST_AUTO_TEST_CASE(RSModelTest) models[15] = RSModel(RSModel::TreeTypes::R_PLUS_TREE, false); models[16] = RSModel(RSModel::TreeTypes::R_PLUS_PLUS_TREE, true); models[17] = RSModel(RSModel::TreeTypes::R_PLUS_PLUS_TREE, false); + models[18] = RSModel(RSModel::TreeTypes::VP_TREE, true); + models[19] = RSModel(RSModel::TreeTypes::VP_TREE, false); for (size_t j = 0; j < 2; ++j) { @@ -1281,7 +1283,7 @@ BOOST_AUTO_TEST_CASE(RSModelTest) vector>> baselineSorted; SortResults(baselineNeighbors, baselineDistances, baselineSorted); - for (size_t i = 0; i < 18; ++i) + for (size_t i = 0; i < 20; ++i) { // We only have std::move() constructors, so make a copy of our data. arma::mat referenceCopy(referenceData); @@ -1325,7 +1327,7 @@ BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest) arma::mat referenceData = arma::randu(10, 200); // Build all the possible models. 
- RSModel models[18]; + RSModel models[20]; models[0] = RSModel(RSModel::TreeTypes::KD_TREE, true); models[1] = RSModel(RSModel::TreeTypes::KD_TREE, false); models[2] = RSModel(RSModel::TreeTypes::COVER_TREE, true); @@ -1344,6 +1346,8 @@ BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest) models[15] = RSModel(RSModel::TreeTypes::R_PLUS_TREE, false); models[16] = RSModel(RSModel::TreeTypes::R_PLUS_PLUS_TREE, true); models[17] = RSModel(RSModel::TreeTypes::R_PLUS_PLUS_TREE, false); + models[18] = RSModel(RSModel::TreeTypes::VP_TREE, true); + models[19] = RSModel(RSModel::TreeTypes::VP_TREE, false); for (size_t j = 0; j < 2; ++j) { @@ -1356,7 +1360,7 @@ BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest) vector>> baselineSorted; SortResults(baselineNeighbors, baselineDistances, baselineSorted); - for (size_t i = 0; i < 18; ++i) + for (size_t i = 0; i < 20; ++i) { // We only have std::move() cosntructors, so make a copy of our data. arma::mat referenceCopy(referenceData); From 6c68a88b9a907079c740e0a6e30b17c1de4bbaec Mon Sep 17 00:00:00 2001 From: Mikhail Lozhnikov Date: Sat, 16 Jul 2016 19:33:33 +0300 Subject: [PATCH 05/12] Added TreeType::IsFirstPointCentroid() to the TreeType API. Fixed comments for the vantage point tree. 
--- .../binary_space_tree/binary_space_tree.hpp | 4 + .../core/tree/cover_tree/cover_tree.hpp | 4 + .../tree/rectangle_tree/rectangle_tree.hpp | 4 + .../dual_tree_traverser_impl.hpp | 175 ++++++---- .../single_tree_traverser_impl.hpp | 4 +- .../vantage_point_split_impl.hpp | 2 + .../vantage_point_tree/vantage_point_tree.hpp | 311 +++++++++++++++++- .../vantage_point_tree_impl.hpp | 98 ++++-- .../kmeans/dual_tree_kmeans_rules_impl.hpp | 5 +- .../neighbor_search_rules_impl.hpp | 9 +- .../range_search/range_search_rules_impl.hpp | 6 +- src/mlpack/tests/vantage_point_tree_test.cpp | 22 +- 12 files changed, 520 insertions(+), 124 deletions(-) diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp index d2494683670..81faa7f2df0 100644 --- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp +++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp @@ -456,6 +456,10 @@ class BinarySpaceTree //! Store the center of the bounding region in the given vector. void Center(arma::vec& center) { bound.Center(center); } + //! Returns false: The first point of this node is not the centroid + //! of its bound. + static constexpr bool IsFirstPointCentroid() { return false; } + private: /** * Splits the current node, assigning its left and right children recursively. diff --git a/src/mlpack/core/tree/cover_tree/cover_tree.hpp b/src/mlpack/core/tree/cover_tree/cover_tree.hpp index 82c6a2cd50a..422389d41cf 100644 --- a/src/mlpack/core/tree/cover_tree/cover_tree.hpp +++ b/src/mlpack/core/tree/cover_tree/cover_tree.hpp @@ -374,6 +374,10 @@ class CoverTree //! Get the instantiated metric. MetricType& Metric() const { return *metric; } + //! Returns true: The first point of this node is the centroid + //! of its bound. + static constexpr bool IsFirstPointCentroid() { return true; } + private: //! Reference to the matrix which this tree is built on. 
const MatType* dataset; diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp index bbdebdadc1f..72b379ce91a 100644 --- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp +++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp @@ -492,6 +492,10 @@ class RectangleTree //! Returns false: this tree type does not have self children. static bool HasSelfChildren() { return false; } + //! Returns false: The first point of this node is not the centroid + //! of its bound. + static constexpr bool IsFirstPointCentroid() { return false; } + private: /** * Splits the current node, recursing up the tree. diff --git a/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser_impl.hpp index 7b729aeaee9..6a8a9b58b03 100644 --- a/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser_impl.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser_impl.hpp @@ -79,16 +79,23 @@ DualTreeTraverser::Traverse( { // We have to recurse down the query node. In this case the recursion order // does not matter. - const double pointScore = rule.Score(queryNode.Point(0), referenceNode); - ++numScores; - if (pointScore != DBL_MAX) - Traverse(queryNode.Point(0), referenceNode); - else - ++numPrunes; - - // Before recursing, we have to set the traversal information correctly. - rule.TraversalInfo() = traversalInfo; + // If the first point of the query node is the centroid, the query node + // contains a point. In this case we should run the single tree traverser. + if (queryNode.IsFirstPointCentroid()) + { + const double pointScore = rule.Score(queryNode.Point(0), referenceNode); + ++numScores; + + if (pointScore != DBL_MAX) + Traverse(queryNode.Point(0), referenceNode); + else + ++numPrunes; + + // Before recursing, we have to set the traversal information correctly. 
+ rule.TraversalInfo() = traversalInfo; + } + const double leftScore = rule.Score(*queryNode.Left(), referenceNode); ++numScores; @@ -109,10 +116,15 @@ DualTreeTraverser::Traverse( } else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf())) { - const size_t queryEnd = queryNode.Begin() + queryNode.Count(); - for (size_t query = queryNode.Begin(); query < queryEnd; ++query) - rule.BaseCase(query, referenceNode.Point(0)); - numBaseCases += queryNode.Count(); + // If the reference node contains a point we should calculate all + // base cases with this point. + if (referenceNode.IsFirstPointCentroid()) + { + const size_t queryEnd = queryNode.Begin() + queryNode.Count(); + for (size_t query = queryNode.Begin(); query < queryEnd; ++query) + rule.BaseCase(query, referenceNode.Point(0)); + numBaseCases += queryNode.Count(); + } // We have to recurse down the reference node. In this case the recursion // order does matter. Before recursing, though, we have to set the // traversal information correctly. @@ -189,69 +201,36 @@ DualTreeTraverser::Traverse( } else { - for (size_t i = 0; i < queryNode.NumDescendants(); ++i) - rule.BaseCase(queryNode.Descendant(i), referenceNode.Point(0)); - numBaseCases += queryNode.NumDescendants(); + // If the reference node contains a point we should calculate all + // base cases with this point. + if (referenceNode.IsFirstPointCentroid()) + { + for (size_t i = 0; i < queryNode.NumDescendants(); ++i) + rule.BaseCase(queryNode.Descendant(i), referenceNode.Point(0)); + numBaseCases += queryNode.NumDescendants(); + } // We have to recurse down both query and reference nodes. Because the // query descent order does not matter, we will go to the left query child // first. Before recursing, we have to set the traversal information // correctly. 
- double leftScore = rule.Score(queryNode.Point(0), *referenceNode.Left()); - typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo(); - rule.TraversalInfo() = traversalInfo; - double rightScore = rule.Score(queryNode.Point(0), *referenceNode.Right()); - typename RuleType::TraversalInfoType rightInfo; - numScores += 2; - if (leftScore < rightScore) - { - // Recurse to the left. Restore the left traversal info. Store the right - // traversal info. - rightInfo = rule.TraversalInfo(); - rule.TraversalInfo() = leftInfo; - Traverse(queryNode.Point(0), *referenceNode.Left()); - - // Is it still valid to recurse to the right? - rightScore = rule.Rescore(queryNode.Point(0), *referenceNode.Right(), - rightScore); + double leftScore; + typename RuleType::TraversalInfoType leftInfo; + double rightScore; + typename RuleType::TraversalInfoType rightInfo; - if (rightScore != DBL_MAX) - { - // Restore the right traversal info. - rule.TraversalInfo() = rightInfo; - Traverse(queryNode.Point(0), *referenceNode.Right()); - } - else - ++numPrunes; - } - else if (rightScore < leftScore) + if (queryNode.IsFirstPointCentroid()) { - // Recurse to the right. - Traverse(queryNode.Point(0), *referenceNode.Right()); - - // Is it still valid to recurse to the left? - leftScore = rule.Rescore(queryNode.Point(0), *referenceNode.Left(), - leftScore); + leftScore = rule.Score(queryNode.Point(0), *referenceNode.Left()); + leftInfo = rule.TraversalInfo(); + rule.TraversalInfo() = traversalInfo; + rightScore = rule.Score(queryNode.Point(0), *referenceNode.Right()); + numScores += 2; - if (leftScore != DBL_MAX) + if (leftScore < rightScore) { - // Restore the left traversal info. - rule.TraversalInfo() = leftInfo; - Traverse(queryNode.Point(0), *referenceNode.Left()); - } - else - ++numPrunes; - } - else - { - if (leftScore == DBL_MAX) - { - numPrunes += 2; - } - else - { - // Choose the left first. Restore the left traversal info and store the - // right traversal info. 
+ // Recurse to the left. Restore the left traversal info. Store the right + // traversal info. rightInfo = rule.TraversalInfo(); rule.TraversalInfo() = leftInfo; Traverse(queryNode.Point(0), *referenceNode.Left()); @@ -262,17 +241,63 @@ DualTreeTraverser::Traverse( if (rightScore != DBL_MAX) { - // Restore the right traversal information. + // Restore the right traversal info. rule.TraversalInfo() = rightInfo; Traverse(queryNode.Point(0), *referenceNode.Right()); } else ++numPrunes; } - } + else if (rightScore < leftScore) + { + // Recurse to the right. + Traverse(queryNode.Point(0), *referenceNode.Right()); - // Restore the main traversal information. - rule.TraversalInfo() = traversalInfo; + // Is it still valid to recurse to the left? + leftScore = rule.Rescore(queryNode.Point(0), *referenceNode.Left(), + leftScore); + + if (leftScore != DBL_MAX) + { + // Restore the left traversal info. + rule.TraversalInfo() = leftInfo; + Traverse(queryNode.Point(0), *referenceNode.Left()); + } + else + ++numPrunes; + } + else + { + if (leftScore == DBL_MAX) + { + numPrunes += 2; + } + else + { + // Choose the left first. Restore the left traversal info and store the + // right traversal info. + rightInfo = rule.TraversalInfo(); + rule.TraversalInfo() = leftInfo; + Traverse(queryNode.Point(0), *referenceNode.Left()); + + // Is it still valid to recurse to the right? + rightScore = rule.Rescore(queryNode.Point(0), *referenceNode.Right(), + rightScore); + + if (rightScore != DBL_MAX) + { + // Restore the right traversal information. + rule.TraversalInfo() = rightInfo; + Traverse(queryNode.Point(0), *referenceNode.Right()); + } + else + ++numPrunes; + } + } + + // Restore the main traversal information. + rule.TraversalInfo() = traversalInfo; + } // Now recurse down the left node. 
leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left()); @@ -452,8 +477,12 @@ DualTreeTraverser::Traverse( return; } - rule.BaseCase(queryIndex, referenceNode.Point(0)); - numBaseCases++; + // If the reference node contains a point we should calculate the base case. + if (referenceNode.IsFirstPointCentroid()) + { + rule.BaseCase(queryIndex, referenceNode.Point(0)); + numBaseCases++; + } // Store the current traversal info. traversalInfo = rule.TraversalInfo(); diff --git a/src/mlpack/core/tree/vantage_point_tree/single_tree_traverser_impl.hpp b/src/mlpack/core/tree/vantage_point_tree/single_tree_traverser_impl.hpp index 38f126a0e8d..427f9714ce8 100644 --- a/src/mlpack/core/tree/vantage_point_tree/single_tree_traverser_impl.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/single_tree_traverser_impl.hpp @@ -51,7 +51,9 @@ SingleTreeTraverser::Traverse( return; } - rule.BaseCase(queryIndex, referenceNode.Point(0)); + // If the reference node contains a point we should calculate the base case. + if (referenceNode.IsFirstPointCentroid()) + rule.BaseCase(queryIndex, referenceNode.Point(0)); // If either score is DBL_MAX, we do not recurse into that node. double leftScore = rule.Score(queryIndex, *referenceNode.Left()); diff --git a/src/mlpack/core/tree/vantage_point_tree/vantage_point_split_impl.hpp b/src/mlpack/core/tree/vantage_point_tree/vantage_point_split_impl.hpp index ce0d082129a..03ee1467103 100644 --- a/src/mlpack/core/tree/vantage_point_tree/vantage_point_split_impl.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/vantage_point_split_impl.hpp @@ -29,6 +29,7 @@ SplitNode(const BoundType& bound, MatType& data, const size_t begin, if (mu == 0) return false; + // The first point of the left child is centroid. data.swap_cols(begin, vantagePointIndex); arma::Col vantagePoint = data.col(begin); @@ -54,6 +55,7 @@ SplitNode(const BoundType& bound, MatType& data, const size_t begin, if (mu == 0) return false; + // The first point of the left child is centroid. 
data.swap_cols(begin, vantagePointIndex); size_t t = oldFromNew[begin]; oldFromNew[begin] = oldFromNew[vantagePointIndex]; diff --git a/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree.hpp b/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree.hpp index c6ddbde1c82..2091dc9780d 100644 --- a/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree.hpp @@ -1,162 +1,422 @@ /** * @file vantage_point_tree.hpp + * + * Definition of the vantage point tree. */ #ifndef MLPACK_CORE_TREE_VANTAGE_POINT_TREE_VANTAGE_POINT_TREE_HPP #define MLPACK_CORE_TREE_VANTAGE_POINT_TREE_VANTAGE_POINT_TREE_HPP +#include "vantage_point_split.hpp" + namespace mlpack { namespace tree /** Trees and tree-building procedures. */ { +/** + * The vantage point tree is a variant of a binary space tree. The difference + * from BinarySpaceTree is the presence of points in intermediate nodes. + * If an intermediate node holds a point, this point is the centroid of the + * bound. + * + * This particular tree does not allow growth, so you cannot add or delete nodes + * from it. If you need to add or delete a node, the better procedure is to + * rebuild the tree entirely. + * + * This tree does take one runtime parameter in the constructor, which is the + * max leaf size to be used. + * + * @tparam MetricType The metric used for tree-building. The BoundType may + * place restrictions on the metrics that can be used. + * @tparam StatisticType Extra data contained in the node. See statistic.hpp + * for the necessary skeleton interface. + * @tparam MatType The dataset class. + * @tparam BoundType The bound used for each node. Currently only + * HollowBallBound is supported. + * @tparam SplitType The class that partitions the dataset/points at a + * particular node into two parts. Its definition decides the way this split + * is done.
+ */ template class BoundType = - bound::HRectBound, + bound::HollowBallBound, template - class SplitType = MidpointSplit> + class SplitType = VantagePointSplit> class VantagePointTree { public: + //! So other classes can use TreeType::Mat. typedef MatType Mat; + //! The type of element held in MatType. typedef typename MatType::elem_type ElemType; private: + //! The left child node. VantagePointTree* left; + //! The right child node. VantagePointTree* right; + //! The parent node (NULL if this is the root of the tree). VantagePointTree* parent; + //! The index of the first point in the dataset contained in this node. size_t begin; + //! The number of points of the dataset contained in this node. size_t count; + //! The bound object for this node. BoundType bound; + //! Any extra data contained in the node. StatisticType stat; + //! The distance from the centroid of this node to the centroid of the parent. ElemType parentDistance; + //! The worst possible distance to the furthest descendant, cached to speed + //! things up. ElemType furthestDescendantDistance; + //! The minimum distance from the center to any edge of the bound. ElemType minimumBoundDistance; + //! The dataset. If we are the root of the tree, we own the dataset and must + //! delete it. MatType* dataset; + //! Indicates that the first point of the node is the centroid of its bound. + bool firstPointIsCentroid; public: + //! A single-tree traverser for the vantage point tree; see + //! single_tree_traverser.hpp for implementation. template class SingleTreeTraverser; + //! A dual-tree traverser for the vantage point tree; + //! see dual_tree_traverser.hpp. template class DualTreeTraverser; + /** + * Construct this as the root node of a vantage point tree using the given + * dataset. This will copy the input matrix; if you don't want this, consider + * using the constructor that takes an rvalue reference and use std::move(). + * + * @param data Dataset to create tree from. This will be copied! 
+ * @param maxLeafSize Size of each leaf in the tree. + */ VantagePointTree(const MatType& data, const size_t maxLeafSize = 20); + /** + * Construct this as the root node of a vantage point tree using the given + * dataset. This will copy the input matrix and modify its ordering; a + * mapping of the old point indices to the new point indices is filled. If + * you don't want the matrix to be copied, consider using the constructor that + * takes an rvalue reference and use std::move(). + * + * @param data Dataset to create tree from. This will be copied! + * @param oldFromNew Vector which will be filled with the old positions for + * each new point. + * @param maxLeafSize Size of each leaf in the tree. + */ VantagePointTree(const MatType& data, std::vector& oldFromNew, const size_t maxLeafSize = 20); + /** + * Construct this as the root node of a vantage point tree using the given + * dataset. This will copy the input matrix and modify its ordering; a + * mapping of the old point indices to the new point indices is filled, as + * well as a mapping of the new point indices to the old point indices. If + * you don't want the matrix to be copied, consider using the constructor that + * takes an rvalue reference and use std::move(). + * + * @param data Dataset to create tree from. This will be copied! + * @param oldFromNew Vector which will be filled with the old positions for + * each new point. + * @param newFromOld Vector which will be filled with the new positions for + * each old point. + * @param maxLeafSize Size of each leaf in the tree. + */ VantagePointTree(const MatType& data, std::vector& oldFromNew, std::vector& newFromOld, const size_t maxLeafSize = 20); + /** + * Construct this as the root node of a vantage point tree using the given + * dataset. This will take ownership of the data matrix; if you don't want + * this, consider using the constructor that takes a const reference to a + * dataset. + * + * @param data Dataset to create tree from. 
+ * @param maxLeafSize Size of each leaf in the tree. + */ VantagePointTree(MatType&& data, const size_t maxLeafSize = 20); + /** + * Construct this as the root node of a vantage point tree using the given + * dataset. This will take ownership of the data matrix; a mapping of the + * old point indices to the new point indices is filled. If you don't want + * the matrix to have its ownership taken, consider using the constructor that + * takes a const reference to a dataset. + * + * @param data Dataset to create tree from. + * @param oldFromNew Vector which will be filled with the old positions for + * each new point. + * @param maxLeafSize Size of each leaf in the tree. + */ VantagePointTree(MatType&& data, std::vector& oldFromNew, const size_t maxLeafSize = 20); + /** + * Construct this as the root node of a vantage point tree using the given + * dataset. This will take ownership of the data matrix; a mapping of the old + * point indices to the new point indices is filled, as well as a mapping of + * the new point indices to the old point indices. If you don't want the + * matrix to have its ownership taken, consider using the constructor that + * takes a const reference to a dataset. + * + * @param data Dataset to create tree from. + * @param oldFromNew Vector which will be filled with the old positions for + * each new point. + * @param newFromOld Vector which will be filled with the new positions for + * each old point. + * @param maxLeafSize Size of each leaf in the tree. + */ VantagePointTree(MatType&& data, std::vector& oldFromNew, std::vector& newFromOld, const size_t maxLeafSize = 20); + /** + * Construct this node as a child of the given parent, starting at column + * begin and using count points. The ordering of that subset of points in the + * parent's data matrix will be modified! This is used for recursive + * tree-building by the other constructors which don't specify point indices. + * + * @param parent Parent of this node. 
Its dataset will be modified! + * @param begin Index of point to start tree construction with. + * @param count Number of points to use to construct tree. + * @param maxLeafSize Size of each leaf in the tree. + * @param firstPointIsCentroid Indicates that the first point of the node is + * the centroid of its bound. + */ VantagePointTree(VantagePointTree* parent, const size_t begin, const size_t count, SplitType, MatType>& splitter, - const size_t maxLeafSize = 20); - + const size_t maxLeafSize = 20, + bool firstPointIsCentroid = false); + + /** + * Construct this node as a child of the given parent, starting at column + * begin and using count points. The ordering of that subset of points in the + * parent's data matrix will be modified! This is used for recursive + * tree-building by the other constructors which don't specify point indices. + * + * A mapping of the old point indices to the new point indices is filled, but + * it is expected that the vector is already allocated with size greater than + * or equal to (begin + count), and if that is not true, invalid memory + * reads (and writes) will occur. + * + * @param parent Parent of this node. Its dataset will be modified! + * @param begin Index of point to start tree construction with. + * @param count Number of points to use to construct tree. + * @param oldFromNew Vector which will be filled with the old positions for + * each new point. + * @param maxLeafSize Size of each leaf in the tree. + * @param firstPointIsCentroid Indicates that the first point of the node is + * the centroid of its bound. + */ VantagePointTree(VantagePointTree* parent, const size_t begin, const size_t count, std::vector& oldFromNew, SplitType, MatType>& splitter, - const size_t maxLeafSize = 20); - + const size_t maxLeafSize = 20, + bool firstPointIsCentroid = false); + + /** + * Construct this node as a child of the given parent, starting at column + * begin and using count points.
The ordering of that subset of points in the + * parent's data matrix will be modified! This is used for recursive + * tree-building by the other constructors which don't specify point indices. + * + * A mapping of the old point indices to the new point indices is filled, as + * well as a mapping of the new point indices to the old point indices. It is + * expected that the vector is already allocated with size greater than or + * equal to (begin + count), and if that is not true, invalid memory + * reads (and writes) will occur. + * + * @param parent Parent of this node. Its dataset will be modified! + * @param begin Index of point to start tree construction with. + * @param count Number of points to use to construct tree. + * @param oldFromNew Vector which will be filled with the old positions for + * each new point. + * @param newFromOld Vector which will be filled with the new positions for + * each old point. + * @param maxLeafSize Size of each leaf in the tree. + * @param firstPointIsCentroid Indicates that the first point of the node is + * the centroid of its bound. + */ VantagePointTree(VantagePointTree* parent, const size_t begin, const size_t count, std::vector& oldFromNew, std::vector& newFromOld, SplitType, MatType>& splitter, - const size_t maxLeafSize = 20); - + const size_t maxLeafSize = 20, + bool firstPointIsCentroid = false); + + /** + * Create a vantage point tree by copying the other tree. Be careful! This + * can take a long time and use a lot of memory. + * + * @param other Tree to be replicated. + */ VantagePointTree(const VantagePointTree& other); + /** + * Move constructor for a VantagePointTree; possess all the members of the + * given tree. + */ VantagePointTree(VantagePointTree&& other); + /** + * Initialize the tree from a boost::serialization archive. + * + * @param ar Archive to load tree from. Must be an iarchive, not an oarchive.
+ */ template VantagePointTree( Archive& ar, const typename boost::enable_if::type* = 0); + /** + * Deletes this node, deallocating the memory for the children and calling + * their destructors in turn. This will invalidate any pointers or references + * to any nodes which are children of this one. + */ ~VantagePointTree(); + //! Return the bound object for this node. const BoundType& Bound() const { return bound; } + //! Modify the bound object for this node. BoundType& Bound() { return bound; } + //! Return the statistic object for this node. const StatisticType& Stat() const { return stat; } + //! Modify the statistic object for this node. StatisticType& Stat() { return stat; } + //! Return whether or not this node is a leaf (true if it has no children). bool IsLeaf() const; + //! Gets the left child of this node. VantagePointTree* Left() const { return left; } + //! Modify the left child of this node. VantagePointTree*& Left() { return left; } + //! Gets the right child of this node. VantagePointTree* Right() const { return right; } + //! Modify the right child of this node. VantagePointTree*& Right() { return right; } + //! Gets the parent of this node. VantagePointTree* Parent() const { return parent; } + //! Modify the parent of this node. VantagePointTree*& Parent() { return parent; } + //! Get the dataset which the tree is built on. const MatType& Dataset() const { return *dataset; } + //! Modify the dataset which the tree is built on. Be careful! MatType& Dataset() { return *dataset; } + //! Get the metric that the tree uses. MetricType Metric() const { return MetricType(); } + //! Return the number of children in this node. size_t NumChildren() const; + /** + * Return the furthest distance to a point held in this node. If this is not + * a leaf node, then the distance is 0 because the node holds no points or + * the only point is the centroid. + */ ElemType FurthestPointDistance() const; + /** + * Return the furthest possible descendant distance.
This returns the maximum + * distance from the centroid to the edge of the bound and not the empirical + * quantity which is the actual furthest descendant distance. So the actual + * furthest descendant distance may be less than what this method returns (but + * it will never be greater than this). + */ ElemType FurthestDescendantDistance() const; + //! Return the minimum distance from the center of the node to any bound edge. ElemType MinimumBoundDistance() const; + //! Return the distance from the center of this node to the center of the + //! parent node. ElemType ParentDistance() const { return parentDistance; } + //! Modify the distance from the center of this node to the center of the + //! parent node. ElemType& ParentDistance() { return parentDistance; } + /** + * Return the specified child (0 will be left, 1 will be right). If the index + * is greater than 1, this will return the right child. + * + * @param child Index of child to return. + */ VantagePointTree& Child(const size_t child) const; VantagePointTree*& ChildPtr(const size_t child) { return (child == 0) ? left : right; } + //! Return the number of points in this node. size_t NumPoints() const; + /** + * Return the number of descendants of this node. + */ size_t NumDescendants() const; + /** + * Return the index (with reference to the dataset) of a particular descendant + * of this node. The index must be at least zero and less than the + * number of descendants. + * + * @param index Index of the descendant. + */ size_t Descendant(const size_t index) const; + /** + * Return the index (with reference to the dataset) of a particular point in + * this node. This will happily return invalid indices if the given index is + * greater than the number of points in this node (obtained with NumPoints()) + * -- be careful. + * + * @param index Index of point for which a dataset index is wanted. + */ size_t Point(const size_t index) const; + //! Return the minimum distance to another node.
ElemType MinDistance(const VantagePointTree* other) const { return bound.MinDistance(other->Bound()); } + //! Return the maximum distance to another node. ElemType MaxDistance(const VantagePointTree* other) const { return bound.MaxDistance(other->Bound()); } + //! Return the minimum and maximum distance to another node. math::RangeType RangeDistance(const VantagePointTree* other) const { return bound.RangeDistance(other->Bound()); } + //! Return the minimum distance to another point. template ElemType MinDistance(const VecType& point, typename boost::enable_if >::type* = 0) @@ -165,6 +425,7 @@ class VantagePointTree return bound.MinDistance(point); } + //! Return the maximum distance to another point. template ElemType MaxDistance(const VecType& point, typename boost::enable_if >::type* = 0) @@ -173,6 +434,7 @@ class VantagePointTree return bound.MaxDistance(point); } + //! Return the minimum and maximum distance to another point. template math::RangeType RangeDistance(const VecType& point, @@ -181,30 +443,63 @@ class VantagePointTree return bound.RangeDistance(point); } + //! Return the index of the beginning point of this subset. size_t Begin() const { return begin; } + //! Modify the index of the beginning point of this subset. size_t& Begin() { return begin; } + //! Return the number of points in this subset. size_t Count() const { return count; } + //! Modify the number of points in this subset. size_t& Count() { return count; } + //! Returns false: this tree type does not have self children. static bool HasSelfChildren() { return false; } + //! Store the center of the bounding region in the given vector. void Center(arma::vec& center) { bound.Center(center); } + //! Indicates that the first point of this node is the centroid of its bound. + bool IsFirstPointCentroid() const { return firstPointIsCentroid; } + private: + /** + * Splits the current node, assigning its left and right children recursively. 
+ * + * @param maxLeafSize Maximum number of points held in a leaf. + * @param splitter Instantiated SplitType object. + */ void SplitNode(const size_t maxLeafSize, SplitType, MatType>& splitter); + /** + * Splits the current node, assigning its left and right children recursively. + * Also returns a list of the changed indices. + * + * @param oldFromNew Vector holding permuted indices. + * @param maxLeafSize Maximum number of points held in a leaf. + * @param splitter Instantiated SplitType object. + */ void SplitNode(std::vector& oldFromNew, const size_t maxLeafSize, SplitType, MatType>& splitter); protected: + /** + * A default constructor. This is meant to only be used with + * boost::serialization, which is allowed with the friend declaration below. + * This does not return a valid tree! The method must be protected, so that + * the serialization shim can work with the default constructor. + */ VantagePointTree(); + //! Friend access is given for the default constructor. friend class boost::serialization::access; public: + /** + * Serialize the tree. + */ template void Serialize(Archive& ar, const unsigned int version); }; diff --git a/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree_impl.hpp b/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree_impl.hpp index 8cc320ae1d5..087b5d9645a 100644 --- a/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree_impl.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree_impl.hpp @@ -30,7 +30,8 @@ VantagePointTree( count(data.n_cols), /* and spans all of the dataset. */ bound(data.n_rows), parentDistance(0), // Parent distance for the root is 0: it has no parent. - dataset(new MatType(data)) // Copies the dataset. + dataset(new MatType(data)), // Copies the dataset. + firstPointIsCentroid(false) { // Do the actual splitting of this node. 
SplitType, MatType> splitter; @@ -58,7 +59,8 @@ VantagePointTree( count(data.n_cols), bound(data.n_rows), parentDistance(0), // Parent distance for the root is 0: it has no parent. - dataset(new MatType(data)) // Copies the dataset. + dataset(new MatType(data)), // Copies the dataset. + firstPointIsCentroid(false) { // Initialize oldFromNew correctly. oldFromNew.resize(data.n_cols); @@ -92,7 +94,8 @@ VantagePointTree( count(data.n_cols), bound(data.n_rows), parentDistance(0), // Parent distance for the root is 0: it has no parent. - dataset(new MatType(data)) // Copies the dataset. + dataset(new MatType(data)), // Copies the dataset. + firstPointIsCentroid(false) { // Initialize the oldFromNew vector correctly. oldFromNew.resize(data.n_cols); @@ -127,7 +130,8 @@ VantagePointTree(MatType&& data, const size_t maxLeafSize) : count(data.n_cols), bound(data.n_rows), parentDistance(0), // Parent distance for the root is 0: it has no parent. - dataset(new MatType(std::move(data))) + dataset(new MatType(std::move(data))), + firstPointIsCentroid(false) { // Do the actual splitting of this node. SplitType, MatType> splitter; @@ -155,7 +159,8 @@ VantagePointTree( count(data.n_cols), bound(data.n_rows), parentDistance(0), // Parent distance for the root is 0: it has no parent. - dataset(new MatType(std::move(data))) + dataset(new MatType(std::move(data))), + firstPointIsCentroid(false) { // Initialize oldFromNew correctly. oldFromNew.resize(dataset->n_cols); @@ -189,7 +194,8 @@ VantagePointTree( count(data.n_cols), bound(data.n_rows), parentDistance(0), // Parent distance for the root is 0: it has no parent. - dataset(new MatType(std::move(data))) + dataset(new MatType(std::move(data))), + firstPointIsCentroid(false) { // Initialize the oldFromNew vector correctly. 
oldFromNew.resize(dataset->n_cols); @@ -221,14 +227,16 @@ VantagePointTree( const size_t begin, const size_t count, SplitType, MatType>& splitter, - const size_t maxLeafSize) : + const size_t maxLeafSize, + bool firstPointIsCentroid) : left(NULL), right(NULL), parent(parent), begin(begin), count(count), bound(parent->Dataset().n_rows), - dataset(&parent->Dataset()) // Point to the parent's dataset. + dataset(&parent->Dataset()), // Point to the parent's dataset. + firstPointIsCentroid(firstPointIsCentroid) { // Perform the actual splitting. SplitNode(maxLeafSize, splitter); @@ -250,14 +258,16 @@ VantagePointTree( const size_t count, std::vector& oldFromNew, SplitType, MatType>& splitter, - const size_t maxLeafSize) : + const size_t maxLeafSize, + bool firstPointIsCentroid) : left(NULL), right(NULL), parent(parent), begin(begin), count(count), bound(parent->Dataset().n_rows), - dataset(&parent->Dataset()) + dataset(&parent->Dataset()), + firstPointIsCentroid(firstPointIsCentroid) { // Hopefully the vector is initialized correctly! We can't check that // entirely but we can do a minor sanity check. @@ -284,14 +294,16 @@ VantagePointTree( std::vector& oldFromNew, std::vector& newFromOld, SplitType, MatType>& splitter, - const size_t maxLeafSize) : + const size_t maxLeafSize, + bool firstPointIsCentroid) : left(NULL), right(NULL), parent(parent), begin(begin), count(count), bound(parent->Dataset()->n_rows), - dataset(&parent->Dataset()) + dataset(&parent->Dataset()), + firstPointIsCentroid(firstPointIsCentroid) { // Hopefully the vector is initialized correctly! We can't check that // entirely but we can do a minor sanity check. @@ -332,7 +344,8 @@ VantagePointTree( parentDistance(other.parentDistance), furthestDescendantDistance(other.furthestDescendantDistance), // Copy matrix, but only if we are the root. - dataset((other.parent == NULL) ? new MatType(*other.dataset) : NULL) + dataset((other.parent == NULL) ? 
new MatType(*other.dataset) : NULL), + firstPointIsCentroid(other.firstPointIsCentroid) { // Create left and right children (if any). if (other.Left()) @@ -390,7 +403,8 @@ VantagePointTree(VantagePointTree&& other) : parentDistance(other.parentDistance), furthestDescendantDistance(other.furthestDescendantDistance), minimumBoundDistance(other.minimumBoundDistance), - dataset(other.dataset) + dataset(other.dataset), + firstPointIsCentroid(other.firstPointIsCentroid) { // Now we are a clone of the other tree. But we must also clear the other // tree's contents, so it doesn't delete anything when it is destructed. @@ -573,8 +587,10 @@ inline size_t VantagePointTree::NumPoints() const { // Each intermediate node contains exactly one point. - if (left) + if (left && parent) return 1; + else if(!parent) + return 0; return count; } @@ -637,7 +653,8 @@ void VantagePointTree: // We need to expand the bounds of this node properly. if (parent) { - bound.Center() = dataset->col(parent->begin); + bound.Center() = parent->firstPointIsCentroid ? + dataset->col(parent->begin + 1) : dataset->col(parent->begin); bound.OuterRadius() = 0; bound.InnerRadius() = std::numeric_limits::max(); } @@ -650,6 +667,8 @@ void VantagePointTree: while (tree->Parent() != NULL) { tree->Parent()->Bound() |= tree->Bound(); + tree->Parent()->furthestDescendantDistance = 0.5 * + tree->Parent()->Bound().Diameter(); tree = tree->Parent(); } // Calculate the furthest descendant distance. @@ -666,7 +685,17 @@ void VantagePointTree: // Split the node. The elements of 'data' are reordered by the splitting // algorithm. This function call updates splitCol. - const bool split = splitter.SplitNode(bound, *dataset, begin, count, + + size_t splitBegin = begin; + size_t splitCount = count; + + if (IsFirstPointCentroid()) + { + splitBegin = begin + 1; + splitCount = count - 1; + } + + const bool split = splitter.SplitNode(bound, *dataset, splitBegin, splitCount, splitCol); // The node may not be always split. 
For instance, if all the points are the @@ -676,10 +705,10 @@ void VantagePointTree: // Now that we know the split column, we will recursively split the children // by calling their constructors (which perform this splitting process). - left = new VantagePointTree(this, begin + 1, splitCol - begin - 1, splitter, - maxLeafSize); - right = new VantagePointTree(this, splitCol, begin + count - splitCol, - splitter, maxLeafSize); + left = new VantagePointTree(this, splitBegin, splitCol - splitBegin, splitter, + maxLeafSize, true); + right = new VantagePointTree(this, splitCol, splitBegin + splitCount - splitCol, + splitter, maxLeafSize, false); // Calculate parent distances for those two nodes. arma::vec center, leftCenter, rightCenter; @@ -706,12 +735,12 @@ SplitNode(std::vector& oldFromNew, const size_t maxLeafSize, SplitType, MatType>& splitter) { - // This should be a single function for Bound. // We need to expand the bounds of this node properly. if (parent) { - bound.Center() = dataset->col(parent->begin); + bound.Center() = parent->firstPointIsCentroid ? + dataset->col(parent->begin + 1) : dataset->col(parent->begin); bound.OuterRadius() = 0; bound.InnerRadius() = std::numeric_limits::max(); } @@ -724,6 +753,8 @@ SplitNode(std::vector& oldFromNew, while (tree->Parent() != NULL) { tree->Parent()->Bound() |= tree->Bound(); + tree->Parent()->furthestDescendantDistance = 0.5 * + tree->Parent()->Bound().Diameter(); tree = tree->Parent(); } @@ -741,7 +772,17 @@ SplitNode(std::vector& oldFromNew, // Split the node. The elements of 'data' are reordered by the splitting // algorithm. This function call updates splitCol and oldFromNew. 
- const bool split = splitter.SplitNode(bound, *dataset, begin, count, splitCol, + + size_t splitBegin = begin; + size_t splitCount = count; + + if (IsFirstPointCentroid()) + { + splitBegin = begin + 1; + splitCount = count - 1; + } + + const bool split = splitter.SplitNode(bound, *dataset, splitBegin, splitCount, splitCol, oldFromNew); // The node may not be always split. For instance, if all the points are the @@ -751,10 +792,10 @@ SplitNode(std::vector& oldFromNew, // Now that we know the split column, we will recursively split the children // by calling their constructors (which perform this splitting process). - left = new VantagePointTree(this, begin + 1, splitCol - begin - 1, oldFromNew, - splitter, maxLeafSize); - right = new VantagePointTree(this, splitCol, begin + count - splitCol, - oldFromNew, splitter, maxLeafSize); + left = new VantagePointTree(this, splitBegin, splitCol - splitBegin, oldFromNew, + splitter, maxLeafSize, true); + right = new VantagePointTree(this, splitCol, splitBegin + splitCount - splitCol, + oldFromNew, splitter, maxLeafSize, false); // Calculate parent distances for those two nodes. @@ -831,6 +872,7 @@ void VantagePointTree: // Save children last; otherwise boost::serialization gets confused. ar & CreateNVP(left, "left"); ar & CreateNVP(right, "right"); + ar & CreateNVP(firstPointIsCentroid, "firstPointIsCentroid"); // Due to quirks of boost::serialization, if a tree is saved as an object and // not a pointer, the first level of the tree will be duplicated on load. 
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp index f91b5ef153e..5aa8d2f0aef 100644 --- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp +++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp @@ -131,7 +131,7 @@ inline double DualTreeKMeansRules::Score( // We want to set adjustedScore to be the distance between the centroid of the // last query node and last reference node. We will do this by adjusting the // last score. In some cases, we can just use the last base case. - if (tree::TreeTraits::FirstPointIsCentroid) + if (queryNode.IsFirstPointCentroid() && referenceNode.IsFirstPointCentroid()) { adjustedScore = traversalInfo.LastBaseCase(); } @@ -207,7 +207,8 @@ inline double DualTreeKMeansRules::Score( // Now, check if we can prune. if (adjustedScore > queryNode.Stat().UpperBound()) { - if (!(tree::TreeTraits::FirstPointIsCentroid && score == 0.0)) + if (!(queryNode.IsFirstPointCentroid() && + referenceNode.IsFirstPointCentroid() && score == 0.0)) { // There isn't any need to set the traversal information because no // descendant combinations will be visited, and those are the only diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp index 24f94856f5a..7ae89de9fbf 100644 --- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp +++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp @@ -85,7 +85,7 @@ inline double NeighborSearchRules::Score( { ++scores; // Count number of Score() calls. double distance; - if (tree::TreeTraits::FirstPointIsCentroid) + if (referenceNode.IsFirstPointCentroid()) { // The first point in the tree is the centroid. So we can then calculate // the base case between that and the query point. 
@@ -160,7 +160,7 @@ inline double NeighborSearchRules::Score( // We want to set adjustedScore to be the distance between the centroid of the // last query node and last reference node. We will do this by adjusting the // last score. In some cases, we can just use the last base case. - if (tree::TreeTraits::FirstPointIsCentroid) + if (queryNode.IsFirstPointCentroid() && referenceNode.IsFirstPointCentroid()) { adjustedScore = traversalInfo.LastBaseCase(); } @@ -237,7 +237,8 @@ inline double NeighborSearchRules::Score( // Can we prune? if (!SortPolicy::IsBetter(adjustedScore, bestDistance)) { - if (!(tree::TreeTraits::FirstPointIsCentroid && score == 0.0)) + if (!(queryNode.IsFirstPointCentroid() && + referenceNode.IsFirstPointCentroid() && score == 0.0)) { // There isn't any need to set the traversal information because no // descendant combinations will be visited, and those are the only @@ -247,7 +248,7 @@ inline double NeighborSearchRules::Score( } double distance; - if (tree::TreeTraits::FirstPointIsCentroid) + if (queryNode.IsFirstPointCentroid() && referenceNode.IsFirstPointCentroid()) { // The first point in the node is the centroid, so we can calculate the // distance between the two points using BaseCase() and then find the diff --git a/src/mlpack/methods/range_search/range_search_rules_impl.hpp b/src/mlpack/methods/range_search/range_search_rules_impl.hpp index a0629bbb419..8dd103f5054 100644 --- a/src/mlpack/methods/range_search/range_search_rules_impl.hpp +++ b/src/mlpack/methods/range_search/range_search_rules_impl.hpp @@ -79,7 +79,7 @@ double RangeSearchRules::Score(const size_t queryIndex, // object. math::Range distances; - if (tree::TreeTraits::FirstPointIsCentroid) + if (referenceNode.IsFirstPointCentroid()) { // In this situation, we calculate the base case. So we should check to be // sure we haven't already done that. 
@@ -147,7 +147,7 @@ double RangeSearchRules::Score(TreeType& queryNode, TreeType& referenceNode) { math::Range distances; - if (tree::TreeTraits::FirstPointIsCentroid) + if (queryNode.IsFirstPointCentroid() && referenceNode.IsFirstPointCentroid()) { // It is possible that the base case has already been calculated. double baseCase = 0.0; @@ -224,7 +224,7 @@ void RangeSearchRules::AddResult(const size_t queryIndex, // called, so if the base case has already been calculated, then we must avoid // adding that point to the results again. size_t baseCaseMod = 0; - if (tree::TreeTraits::FirstPointIsCentroid && + if (referenceNode.IsFirstPointCentroid() && (queryIndex == lastQueryIndex) && (referenceNode.Point(0) == lastReferenceIndex)) { diff --git a/src/mlpack/tests/vantage_point_tree_test.cpp b/src/mlpack/tests/vantage_point_tree_test.cpp index 1570f6c2015..04b15553d9e 100644 --- a/src/mlpack/tests/vantage_point_tree_test.cpp +++ b/src/mlpack/tests/vantage_point_tree_test.cpp @@ -125,9 +125,14 @@ void CheckBound(TreeType& tree) } else { - BOOST_REQUIRE_EQUAL(tree.NumPoints(), 1); - BOOST_REQUIRE_EQUAL(true, - tree.Bound().Contains(tree.Dataset().col(tree.Point(0)))); + if (!tree.Parent()) + BOOST_REQUIRE_EQUAL(tree.NumPoints(), 0); + else if (tree.IsFirstPointCentroid()) + { + BOOST_REQUIRE_EQUAL(tree.NumPoints(), 1); + BOOST_REQUIRE_EQUAL(true, + tree.Bound().Contains(tree.Dataset().col(tree.Point(0)))); + } BOOST_REQUIRE_EQUAL(tree.Bound().Contains(tree.Left()->Bound()), true); BOOST_REQUIRE_EQUAL(tree.Bound().Contains(tree.Right()->Bound()), true); @@ -160,7 +165,7 @@ void CheckSplit(TreeType& tree) for (size_t i = tree.Left()->Begin(); i < pointsEnd; i++) { typename TreeType::ElemType dist = - tree.Bound().Metric().Evaluate(tree.Dataset().col(tree.Begin()), + tree.Bound().Metric().Evaluate(tree.Dataset().col(tree.Left()->Begin()), tree.Dataset().col(i)); if (dist > maxDist) @@ -171,11 +176,18 @@ void CheckSplit(TreeType& tree) for (size_t i = tree.Right()->Begin(); 
i < pointsEnd; i++) { typename TreeType::ElemType dist = - tree.Bound().Metric().Evaluate(tree.Dataset().col(tree.Begin()), + tree.Bound().Metric().Evaluate(tree.Dataset().col(tree.Left()->Begin()), tree.Dataset().col(i)); BOOST_REQUIRE_LE(maxDist, dist); } + if (tree.IsFirstPointCentroid()) + { + for (size_t k = 0; k < tree.Bound().Dim(); k++) + BOOST_REQUIRE_EQUAL(tree.Bound().Center()[k], + tree.Dataset().col(tree.Point(0))[k]); + } + CheckSplit(*tree.Left()); CheckSplit(*tree.Right()); } From 300882ac96e7a663e3e303ca0c45c14c6fafe1a6 Mon Sep 17 00:00:00 2001 From: Mikhail Lozhnikov Date: Tue, 19 Jul 2016 16:53:45 +0300 Subject: [PATCH 06/12] Very minor fixes of HollowBallBound. --- src/mlpack/core/tree/hollow_ball_bound.hpp | 22 ++- .../core/tree/hollow_ball_bound_impl.hpp | 126 +++++++++--------- 2 files changed, 71 insertions(+), 77 deletions(-) diff --git a/src/mlpack/core/tree/hollow_ball_bound.hpp b/src/mlpack/core/tree/hollow_ball_bound.hpp index 5770acbef28..53be28712b2 100644 --- a/src/mlpack/core/tree/hollow_ball_bound.hpp +++ b/src/mlpack/core/tree/hollow_ball_bound.hpp @@ -33,10 +33,8 @@ class HollowBallBound typedef VecType Vec; private: - //! The radius of the inner ball bound. - ElemType innerRadius; - //! The radius of the outer ball bound. - ElemType outerRadius; + //! The inner and the outer radii of the bound. + math::RangeType radii; //! The center of the ball bound. VecType center; //! The metric used in this bound. @@ -65,8 +63,8 @@ class HollowBallBound /** * Create the ball bound with the specified radius and center. * - * @param innerRradius Inner radius of ball bound. - * @param outerRradius Outer radius of ball bound. + * @param innerRadius Inner radius of ball bound. + * @param outerRadius Outer radius of ball bound. * @param center Center of ball bound. */ HollowBallBound(const ElemType innerRadius, @@ -86,14 +84,14 @@ class HollowBallBound ~HollowBallBound(); //! Get the outer radius of the ball. 
- ElemType OuterRadius() const { return outerRadius; } + ElemType OuterRadius() const { return radii.Hi(); } //! Modify the outer radius of the ball. - ElemType& OuterRadius() { return outerRadius; } + ElemType& OuterRadius() { return radii.Hi(); } //! Get the innner radius of the ball. - ElemType InnerRadius() const { return innerRadius; } + ElemType InnerRadius() const { return radii.Lo(); } //! Modify the inner radius of the ball. - ElemType& InnerRadius() { return innerRadius; } + ElemType& InnerRadius() { return radii.Lo(); } //! Get the center point of the ball. const VecType& Center() const { return center; } @@ -107,7 +105,7 @@ class HollowBallBound * Get the minimum width of the bound (this is same as the diameter). * For ball bounds, width along all dimensions remain same. */ - ElemType MinWidth() const { return outerRadius * 2.0; } + ElemType MinWidth() const { return radii.Hi() * 2.0; } //! Get the range in a certain dimension. math::RangeType operator[](const size_t i) const; @@ -194,7 +192,7 @@ class HollowBallBound /** * Returns the diameter of the ballbound. */ - ElemType Diameter() const { return 2 * outerRadius; } + ElemType Diameter() const { return 2 * radii.Hi(); } //! Returns the distance metric used in this bound. const MetricType& Metric() const { return *metric; } diff --git a/src/mlpack/core/tree/hollow_ball_bound_impl.hpp b/src/mlpack/core/tree/hollow_ball_bound_impl.hpp index 570d1806f50..ee61daa6a51 100644 --- a/src/mlpack/core/tree/hollow_ball_bound_impl.hpp +++ b/src/mlpack/core/tree/hollow_ball_bound_impl.hpp @@ -20,8 +20,8 @@ namespace bound { //! Empty Constructor. template HollowBallBound::HollowBallBound() : - innerRadius(std::numeric_limits::lowest()), - outerRadius(std::numeric_limits::lowest()), + radii(std::numeric_limits::lowest(), + std::numeric_limits::lowest()), metric(new MetricType()), ownsMetric(true) { /* Nothing to do. 
*/ } @@ -33,8 +33,8 @@ HollowBallBound::HollowBallBound() : */ template HollowBallBound::HollowBallBound(const size_t dimension) : - innerRadius(std::numeric_limits::lowest()), - outerRadius(std::numeric_limits::lowest()), + radii(std::numeric_limits::lowest(), + std::numeric_limits::lowest()), center(dimension), metric(new MetricType()), ownsMetric(true) @@ -52,8 +52,8 @@ HollowBallBound:: HollowBallBound(const ElemType innerRadius, const ElemType outerRadius, const VecType& center) : - innerRadius(innerRadius), - outerRadius(outerRadius), + radii(innerRadius, + outerRadius), center(center), metric(new MetricType()), ownsMetric(true) @@ -63,8 +63,7 @@ HollowBallBound(const ElemType innerRadius, template HollowBallBound::HollowBallBound( const HollowBallBound& other) : - innerRadius(other.innerRadius), - outerRadius(other.outerRadius), + radii(other.radii), center(other.center), metric(other.metric), ownsMetric(false) @@ -75,8 +74,7 @@ template HollowBallBound& HollowBallBound:: operator=(const HollowBallBound& other) { - innerRadius = other.innerRadius; - outerRadius = other.outerRadius; + radii = other.radii; center = other.center; metric = other.metric; ownsMetric = false; @@ -87,15 +85,14 @@ operator=(const HollowBallBound& other) //! Move constructor. template HollowBallBound::HollowBallBound(HollowBallBound&& other) : - innerRadius(other.innerRadius), - outerRadius(other.outerRadius), + radii(other.radii), center(other.center), metric(other.metric), ownsMetric(other.ownsMetric) { // Fix the other bound. 
- other.innerRadius = 0.0; - other.outerRadius = 0.0; + other.radii.Hi() = 0.0; + other.radii.Lo() = 0.0; other.center = VecType(); other.metric = NULL; other.ownsMetric = false; @@ -114,10 +111,10 @@ template math::RangeType::ElemType> HollowBallBound::operator[](const size_t i) const { - if (outerRadius < 0) + if (radii.Hi() < 0) return math::Range(); else - return math::Range(center[i] - outerRadius, center[i] + outerRadius); + return math::Range(center[i] - radii.Hi(), center[i] + radii.Hi()); } /** @@ -126,12 +123,12 @@ HollowBallBound::operator[](const size_t i) const template bool HollowBallBound::Contains(const VecType& point) const { - if (outerRadius < 0) + if (radii.Hi() < 0) return false; else { const ElemType dist = metric->Evaluate(center, point); - return ((dist <= outerRadius) && (dist >= innerRadius)); + return ((dist <= radii.Hi()) && (dist >= radii.Lo())); } } @@ -142,19 +139,19 @@ template bool HollowBallBound::Contains( const HollowBallBound& other) const { - if (outerRadius < 0) + if (radii.Hi() < 0) return false; else { const ElemType dist = metric->Evaluate(center, other.center); - bool containOnOneSide = (dist - other.outerRadius >= innerRadius) && - (dist + other.outerRadius <= outerRadius); - bool containOnEverySide = (dist + innerRadius <= other.innerRadius) && - (dist + other.outerRadius <= outerRadius); + bool containOnOneSide = (dist - other.radii.Hi() >= radii.Lo()) && + (dist + other.radii.Hi() <= radii.Hi()); + bool containOnEverySide = (dist + radii.Lo() <= other.radii.Lo()) && + (dist + other.radii.Hi() <= radii.Hi()); - bool containAsBall = (innerRadius == 0) && - (dist + other.outerRadius <= outerRadius); + bool containAsBall = (radii.Lo() == 0) && + (dist + other.radii.Hi() <= radii.Hi()); return (containOnOneSide || containOnEverySide || containAsBall); } @@ -171,14 +168,14 @@ HollowBallBound::MinDistance( const OtherVecType& point, typename boost::enable_if>* /* junk */) const { - if (outerRadius < 0) + if (radii.Hi() < 0) 
return std::numeric_limits::max(); else { const ElemType dist = metric->Evaluate(point, center); - const ElemType outerDistance = math::ClampNonNegative(dist - outerRadius); - const ElemType innerDistance = math::ClampNonNegative(innerRadius - dist); + const ElemType outerDistance = math::ClampNonNegative(dist - radii.Hi()); + const ElemType innerDistance = math::ClampNonNegative(radii.Lo() - dist); return innerDistance + outerDistance; } @@ -192,18 +189,18 @@ typename HollowBallBound::ElemType HollowBallBound::MinDistance(const HollowBallBound& other) const { - if (outerRadius < 0 || other.outerRadius < 0) + if (radii.Hi() < 0 || other.radii.Hi() < 0) return std::numeric_limits::max(); else { const ElemType centerDistance = metric->Evaluate(center, other.center); const ElemType outerDistance = math::ClampNonNegative(centerDistance - - outerRadius - other.outerRadius); - const ElemType innerDistance1 = math::ClampNonNegative(other.innerRadius - - centerDistance - outerRadius); - const ElemType innerDistance2 = math::ClampNonNegative(innerRadius - - centerDistance - other.outerRadius); + radii.Hi() - other.radii.Hi()); + const ElemType innerDistance1 = math::ClampNonNegative(other.radii.Lo() - + centerDistance - radii.Hi()); + const ElemType innerDistance2 = math::ClampNonNegative(radii.Lo() - + centerDistance - other.radii.Hi()); return outerDistance + innerDistance1 + innerDistance2; } @@ -219,10 +216,10 @@ HollowBallBound::MaxDistance( const OtherVecType& point, typename boost::enable_if >* /* junk */) const { - if (outerRadius < 0) + if (radii.Hi() < 0) return std::numeric_limits::max(); else - return metric->Evaluate(point, center) + outerRadius; + return metric->Evaluate(point, center) + radii.Hi(); } /** @@ -233,11 +230,11 @@ typename HollowBallBound::ElemType HollowBallBound::MaxDistance(const HollowBallBound& other) const { - if (outerRadius < 0) + if (radii.Hi() < 0) return std::numeric_limits::max(); else - return metric->Evaluate(other.center, center) + 
outerRadius + - other.outerRadius; + return metric->Evaluate(other.center, center) + radii.Hi() + + other.radii.Hi(); } /** @@ -252,15 +249,15 @@ HollowBallBound::RangeDistance( const OtherVecType& point, typename boost::enable_if >* /* junk */) const { - if (outerRadius < 0) + if (radii.Hi() < 0) return math::Range(std::numeric_limits::max(), std::numeric_limits::max()); else { const ElemType dist = metric->Evaluate(center, point); - return math::Range(math::ClampNonNegative(dist - outerRadius) + - math::ClampNonNegative(innerRadius - dist), - dist + outerRadius); + return math::Range(math::ClampNonNegative(dist - radii.Hi()) + + math::ClampNonNegative(radii.Lo() - dist), + dist + radii.Hi()); } } @@ -269,13 +266,13 @@ math::RangeType::ElemType> HollowBallBound::RangeDistance( const HollowBallBound& other) const { - if (outerRadius < 0) + if (radii.Hi() < 0) return math::Range(std::numeric_limits::max(), std::numeric_limits::max()); else { const ElemType dist = metric->Evaluate(center, other.center); - const ElemType sumradius = outerRadius + other.outerRadius; + const ElemType sumradius = radii.Hi() + other.radii.Hi(); return math::Range(MinDistance(other), dist + sumradius); } } @@ -291,11 +288,11 @@ template const HollowBallBound& HollowBallBound::operator|=(const MatType& data) { - if (outerRadius < 0) + if (radii.Hi() < 0) { center = data.col(0); - outerRadius = 0; - innerRadius = 0; + radii.Hi() = 0; + radii.Lo() = 0; // Now iteratively add points. for (size_t i = 0; i < data.n_cols; ++i) @@ -303,13 +300,13 @@ HollowBallBound::operator|=(const MatType& data) const ElemType dist = metric->Evaluate(center, (VecType) data.col(i)); // See if the new point lies outside the bound. - if (dist > outerRadius) + if (dist > radii.Hi()) { // Move towards the new point and increase the radius just enough to // accommodate the new point. 
const VecType diff = data.col(i) - center; - center += ((dist - outerRadius) / (2 * dist)) * diff; - outerRadius = 0.5 * (dist + outerRadius); + center += ((dist - radii.Hi()) / (2 * dist)) * diff; + radii.Hi() = 0.5 * (dist + radii.Hi()); } } } @@ -321,10 +318,10 @@ HollowBallBound::operator|=(const MatType& data) const ElemType dist = metric->Evaluate(center, data.col(i)); // See if the new point lies outside the bound. - if (dist > outerRadius) - outerRadius = dist; - if (dist < innerRadius) - innerRadius = dist; + if (dist > radii.Hi()) + radii.Hi() = dist; + if (dist < radii.Lo()) + radii.Lo() = dist; } } @@ -338,23 +335,23 @@ template const HollowBallBound& HollowBallBound::operator|=(const HollowBallBound& other) { - if (outerRadius < 0) + if (radii.Hi() < 0) { center = other.center; - outerRadius = other.outerRadius; - innerRadius = other.innerRadius; + radii.Hi() = other.radii.Hi(); + radii.Lo() = other.radii.Lo(); return *this; } const ElemType dist = metric->Evaluate(center, other.center); - if (outerRadius < dist + other.outerRadius) - outerRadius = dist + other.outerRadius; + if (radii.Hi() < dist + other.radii.Hi()) + radii.Hi() = dist + other.radii.Hi(); - const ElemType innerDist = math::ClampNonNegative(other.innerRadius - dist); + const ElemType innerDist = math::ClampNonNegative(other.radii.Lo() - dist); - if (innerRadius > innerDist) - innerRadius = innerDist; + if (radii.Lo() > innerDist) + radii.Lo() = innerDist; return *this; } @@ -367,8 +364,7 @@ void HollowBallBound::Serialize( Archive& ar, const unsigned int /* version */) { - ar & data::CreateNVP(innerRadius, "innerRadius"); - ar & data::CreateNVP(outerRadius, "outerRadius"); + ar & data::CreateNVP(radii, "radii"); ar & data::CreateNVP(center, "center"); if (Archive::is_loading::value) From 601de29c9eff6e65e9976fd146bc347c4408e8a6 Mon Sep 17 00:00:00 2001 From: Mikhail Lozhnikov Date: Fri, 22 Jul 2016 18:57:47 +0300 Subject: [PATCH 07/12] Added function for obtaining a number of distinct 
samples. --- src/mlpack/core/math/random.hpp | 38 +++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/src/mlpack/core/math/random.hpp b/src/mlpack/core/math/random.hpp index a61be19567b..1df5b01345e 100644 --- a/src/mlpack/core/math/random.hpp +++ b/src/mlpack/core/math/random.hpp @@ -97,6 +97,44 @@ inline double RandNormal(const double mean, const double variance) return variance * randNormalDist(randGen) + mean; } +/** + * Obtains no more than maxNumSamples distinct samples. Each sample belongs to + * [loInclusive, hiExclusive). + * + * @param loInclusive The lower bound (inclusive). + * @param hiExclusive The high bound (exclusive). + * @param maxNumSamples The maximum number of samples to obtain. + * @param distinctSamples The samples that will be obtained. + */ +inline void ObtainDistinctSamples(const size_t loInclusive, + const size_t hiExclusive, + const size_t maxNumSamples, + arma::uvec& distinctSamples) +{ + const size_t samplesRangeSize = hiExclusive - loInclusive; + + if (samplesRangeSize > maxNumSamples) + { + arma::Col samples; + + samples.zeros(samplesRangeSize); + + for (size_t i = 0; i < maxNumSamples; i++) + samples [ (size_t) math::RandInt(samplesRangeSize) ]++; + + distinctSamples = arma::find(samples > 0); + + if (loInclusive > 0) + distinctSamples += loInclusive; + } + else + { + distinctSamples.set_size(samplesRangeSize); + for (size_t i = 0; i < samplesRangeSize; i++) + distinctSamples[i] = loInclusive + i; + } +} + } // namespace math } // namespace mlpack From eea2ea3d767a457b3755070162d862ea53608356 Mon Sep 17 00:00:00 2001 From: Mikhail Lozhnikov Date: Fri, 22 Jul 2016 18:59:50 +0300 Subject: [PATCH 08/12] Various vantage point tree fixes. Replace the TreeTraits::FirstPointIsCentroid variable by a static method. 
--- src/mlpack/core/tree/ballbound_impl.hpp | 2 - .../binary_space_tree/binary_space_tree.hpp | 4 - .../core/tree/binary_space_tree/traits.hpp | 16 +- .../core/tree/cover_tree/cover_tree.hpp | 4 - src/mlpack/core/tree/cover_tree/traits.hpp | 8 +- src/mlpack/core/tree/hollow_ball_bound.hpp | 45 ++--- .../core/tree/hollow_ball_bound_impl.hpp | 122 ++++++------- src/mlpack/core/tree/hrectbound_impl.hpp | 2 - .../tree/rectangle_tree/rectangle_tree.hpp | 4 - .../core/tree/rectangle_tree/traits.hpp | 18 +- src/mlpack/core/tree/tree_traits.hpp | 9 +- .../dual_tree_traverser.hpp | 2 +- .../dual_tree_traverser_impl.hpp | 16 +- .../single_tree_traverser.hpp | 2 +- .../single_tree_traverser_impl.hpp | 6 +- .../core/tree/vantage_point_tree/traits.hpp | 12 +- .../core/tree/vantage_point_tree/typedef.hpp | 45 ++++- .../vantage_point_split.hpp | 63 ++----- .../vantage_point_split_impl.hpp | 171 ++++++------------ .../vantage_point_tree/vantage_point_tree.hpp | 6 +- .../vantage_point_tree_impl.hpp | 116 ++++++------ .../methods/fastmks/fastmks_rules_impl.hpp | 9 +- src/mlpack/methods/fastmks/fastmks_stat.hpp | 2 +- .../kmeans/dual_tree_kmeans_rules_impl.hpp | 8 +- .../neighbor_search_rules_impl.hpp | 13 +- .../range_search/range_search_rules_impl.hpp | 7 +- src/mlpack/methods/rann/ra_search_impl.hpp | 4 +- .../methods/rann/ra_search_rules_impl.hpp | 46 +++-- src/mlpack/methods/rann/ra_util.cpp | 16 -- src/mlpack/tests/tree_traits_test.cpp | 6 +- src/mlpack/tests/vantage_point_tree_test.cpp | 51 ++++-- 31 files changed, 402 insertions(+), 433 deletions(-) diff --git a/src/mlpack/core/tree/ballbound_impl.hpp b/src/mlpack/core/tree/ballbound_impl.hpp index 885acb5a8e0..989658e4a45 100644 --- a/src/mlpack/core/tree/ballbound_impl.hpp +++ b/src/mlpack/core/tree/ballbound_impl.hpp @@ -3,8 +3,6 @@ * * Bounds that are useful for binary space partitioning trees. * Implementation of BallBound ball bound metric policy class. 
- * - * @experimental */ #ifndef MLPACK_CORE_TREE_BALLBOUND_IMPL_HPP #define MLPACK_CORE_TREE_BALLBOUND_IMPL_HPP diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp index 81faa7f2df0..d2494683670 100644 --- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp +++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp @@ -456,10 +456,6 @@ class BinarySpaceTree //! Store the center of the bounding region in the given vector. void Center(arma::vec& center) { bound.Center(center); } - //! Returns false: The first point of this node is not the centroid - //! of its bound. - static constexpr bool IsFirstPointCentroid() { return false; } - private: /** * Splits the current node, assigning its left and right children recursively. diff --git a/src/mlpack/core/tree/binary_space_tree/traits.hpp b/src/mlpack/core/tree/binary_space_tree/traits.hpp index 9a81673c19b..87078e825fd 100644 --- a/src/mlpack/core/tree/binary_space_tree/traits.hpp +++ b/src/mlpack/core/tree/binary_space_tree/traits.hpp @@ -38,8 +38,14 @@ class TreeTraits* /* node */ = NULL) + { + return false; + } /** * Points are not contained at multiple levels of the binary space tree. @@ -74,7 +80,13 @@ class TreeTraits* /* node */ = NULL) + { + return false; + } + static const bool HasSelfChildren = false; static const bool RearrangesDataset = true; static const bool BinaryTree = true; diff --git a/src/mlpack/core/tree/cover_tree/cover_tree.hpp b/src/mlpack/core/tree/cover_tree/cover_tree.hpp index 422389d41cf..82c6a2cd50a 100644 --- a/src/mlpack/core/tree/cover_tree/cover_tree.hpp +++ b/src/mlpack/core/tree/cover_tree/cover_tree.hpp @@ -374,10 +374,6 @@ class CoverTree //! Get the instantiated metric. MetricType& Metric() const { return *metric; } - //! Returns true: The first point of this node is the centroid - //! of its bound. - static constexpr bool IsFirstPointCentroid() { return true; } - private: //! 
Reference to the matrix which this tree is built on. const MatType* dataset; diff --git a/src/mlpack/core/tree/cover_tree/traits.hpp b/src/mlpack/core/tree/cover_tree/traits.hpp index a88c08f302e..64eeeba0e94 100644 --- a/src/mlpack/core/tree/cover_tree/traits.hpp +++ b/src/mlpack/core/tree/cover_tree/traits.hpp @@ -35,8 +35,14 @@ class TreeTraits> /** * Each cover tree node contains only one point, and that point is its * centroid. + * + * @param node The node to check. */ - static const bool FirstPointIsCentroid = true; + static constexpr bool FirstPointIsCentroid(const CoverTree* /* node */ = NULL) + { + return true; + } /** * Cover trees do have self-children. diff --git a/src/mlpack/core/tree/hollow_ball_bound.hpp b/src/mlpack/core/tree/hollow_ball_bound.hpp index 53be28712b2..8dbaa29f865 100644 --- a/src/mlpack/core/tree/hollow_ball_bound.hpp +++ b/src/mlpack/core/tree/hollow_ball_bound.hpp @@ -19,24 +19,22 @@ namespace bound { * specific point (center). MetricType is the custom metric type that defaults * to the Euclidean (L2) distance. * - * @tparam MetricType metric type used in the distance measure. - * @tparam VecType Type of vector (arma::vec or arma::sp_vec or similar). + * @tparam TMetricType metric type used in the distance measure. + * @tparam ElemType Type of element (float or double or similar). */ -template, - typename VecType = arma::vec> +template, + typename ElemType = double> class HollowBallBound { public: - //! The underlying data type. - typedef typename VecType::elem_type ElemType; - //! A public version of the vector type. - typedef VecType Vec; + //! A public version of the metric type. + typedef TMetricType MetricType; private: //! The inner and the outer radii of the bound. math::RangeType radii; //! The center of the ball bound. - VecType center; + arma::Col center; //! The metric used in this bound. MetricType* metric; @@ -67,6 +65,7 @@ class HollowBallBound * @param outerRadius Outer radius of ball bound. 
* @param center Center of ball bound. */ + template HollowBallBound(const ElemType innerRadius, const ElemType outerRadius, const VecType& center); @@ -94,9 +93,9 @@ class HollowBallBound ElemType& InnerRadius() { return radii.Lo(); } //! Get the center point of the ball. - const VecType& Center() const { return center; } + const arma::Col& Center() const { return center; } //! Modify the center point of the ball. - VecType& Center() { return center; } + arma::Col& Center() { return center; } //! Get the dimensionality of the ball. size_t Dim() const { return center.n_elem; } @@ -113,6 +112,7 @@ class HollowBallBound /** * Determines if a point is within this bound. */ + template bool Contains(const VecType& point) const; /** @@ -125,14 +125,15 @@ class HollowBallBound * * @param center Vector which the centroid will be written to. */ + template void Center(VecType& center) const { center = this->center; } /** * Calculates minimum bound-to-point squared distance. */ - template - ElemType MinDistance(const OtherVecType& point, - typename boost::enable_if>* = 0) + template + ElemType MinDistance(const VecType& point, + typename boost::enable_if>* = 0) const; /** @@ -143,9 +144,9 @@ class HollowBallBound /** * Computes maximum distance. */ - template - ElemType MaxDistance(const OtherVecType& point, - typename boost::enable_if>* = 0) + template + ElemType MaxDistance(const VecType& point, + typename boost::enable_if>* = 0) const; /** @@ -156,10 +157,10 @@ class HollowBallBound /** * Calculates minimum and maximum bound-to-point distance. */ - template + template math::RangeType RangeDistance( - const OtherVecType& other, - typename boost::enable_if>* = 0) const; + const VecType& other, + typename boost::enable_if>* = 0) const; /** * Calculates minimum and maximum bound-to-bound distance. @@ -205,8 +206,8 @@ class HollowBallBound }; //! A specialization of BoundTraits for this bound type. -template -struct BoundTraits> +template +struct BoundTraits> { //! 
These bounds are potentially loose in some dimensions. const static bool HasTightBounds = false; diff --git a/src/mlpack/core/tree/hollow_ball_bound_impl.hpp b/src/mlpack/core/tree/hollow_ball_bound_impl.hpp index ee61daa6a51..759e5d29d4a 100644 --- a/src/mlpack/core/tree/hollow_ball_bound_impl.hpp +++ b/src/mlpack/core/tree/hollow_ball_bound_impl.hpp @@ -3,8 +3,6 @@ * * Bounds that are useful for binary space partitioning trees. * Implementation of HollowBallBound ball bound metric policy class. - * - * @experimental */ #ifndef MLPACK_CORE_TREE_HOLLOW_BALL_BOUND_IMPL_HPP #define MLPACK_CORE_TREE_HOLLOW_BALL_BOUND_IMPL_HPP @@ -12,14 +10,12 @@ // In case it hasn't been included already. #include "hollow_ball_bound.hpp" -#include - namespace mlpack { namespace bound { //! Empty Constructor. -template -HollowBallBound::HollowBallBound() : +template +HollowBallBound::HollowBallBound() : radii(std::numeric_limits::lowest(), std::numeric_limits::lowest()), metric(new MetricType()), @@ -31,8 +27,8 @@ HollowBallBound::HollowBallBound() : * * @param dimension Dimensionality of ball bound. */ -template -HollowBallBound::HollowBallBound(const size_t dimension) : +template +HollowBallBound::HollowBallBound(const size_t dimension) : radii(std::numeric_limits::lowest(), std::numeric_limits::lowest()), center(dimension), @@ -47,8 +43,9 @@ HollowBallBound::HollowBallBound(const size_t dimension) : * @param outerRadius Outer radius of hollow ball bound. * @param center Center of hollow ball bound. */ -template -HollowBallBound:: +template +template +HollowBallBound:: HollowBallBound(const ElemType innerRadius, const ElemType outerRadius, const VecType& center) : @@ -60,8 +57,8 @@ HollowBallBound(const ElemType innerRadius, { /* Nothing to do. */ } //! Copy Constructor. To prevent memory leaks. 
-template -HollowBallBound::HollowBallBound( +template +HollowBallBound::HollowBallBound( const HollowBallBound& other) : radii(other.radii), center(other.center), @@ -70,8 +67,8 @@ HollowBallBound::HollowBallBound( { /* Nothing to do. */ } //! For the same reason as the copy constructor: to prevent memory leaks. -template -HollowBallBound& HollowBallBound:: +template +HollowBallBound& HollowBallBound:: operator=(const HollowBallBound& other) { radii = other.radii; @@ -83,8 +80,9 @@ operator=(const HollowBallBound& other) } //! Move constructor. -template -HollowBallBound::HollowBallBound(HollowBallBound&& other) : +template +HollowBallBound::HollowBallBound( + HollowBallBound&& other) : radii(other.radii), center(other.center), metric(other.metric), @@ -93,23 +91,23 @@ HollowBallBound::HollowBallBound(HollowBallBound&& other) : // Fix the other bound. other.radii.Hi() = 0.0; other.radii.Lo() = 0.0; - other.center = VecType(); + other.center = arma::Col(); other.metric = NULL; other.ownsMetric = false; } //! Destructor to release allocated memory. -template -HollowBallBound::~HollowBallBound() +template +HollowBallBound::~HollowBallBound() { if (ownsMetric) delete metric; } //! Get the range in a certain dimension. -template -math::RangeType::ElemType> -HollowBallBound::operator[](const size_t i) const +template +math::RangeType HollowBallBound::operator[]( + const size_t i) const { if (radii.Hi() < 0) return math::Range(); @@ -120,8 +118,10 @@ HollowBallBound::operator[](const size_t i) const /** * Determines if a point is within the bound. */ -template -bool HollowBallBound::Contains(const VecType& point) const +template +template +bool HollowBallBound::Contains( + const VecType& point) const { if (radii.Hi() < 0) return false; @@ -135,8 +135,8 @@ bool HollowBallBound::Contains(const VecType& point) const /** * Determines if another bound is within this bound. 
*/ -template -bool HollowBallBound::Contains( +template +bool HollowBallBound::Contains( const HollowBallBound& other) const { if (radii.Hi() < 0) @@ -161,12 +161,11 @@ bool HollowBallBound::Contains( /** * Calculates minimum bound-to-point squared distance. */ -template -template -typename HollowBallBound::ElemType -HollowBallBound::MinDistance( - const OtherVecType& point, - typename boost::enable_if>* /* junk */) const +template +template +ElemType HollowBallBound::MinDistance( + const VecType& point, + typename boost::enable_if>* /* junk */) const { if (radii.Hi() < 0) return std::numeric_limits::max(); @@ -184,9 +183,9 @@ HollowBallBound::MinDistance( /** * Calculates minimum bound-to-bound squared distance. */ -template -typename HollowBallBound::ElemType -HollowBallBound::MinDistance(const HollowBallBound& other) +template +ElemType HollowBallBound::MinDistance( + const HollowBallBound& other) const { if (radii.Hi() < 0 || other.radii.Hi() < 0) @@ -209,12 +208,11 @@ HollowBallBound::MinDistance(const HollowBallBound& other) /** * Computes maximum distance. */ -template -template -typename HollowBallBound::ElemType -HollowBallBound::MaxDistance( - const OtherVecType& point, - typename boost::enable_if >* /* junk */) const +template +template +ElemType HollowBallBound::MaxDistance( + const VecType& point, + typename boost::enable_if >* /* junk */) const { if (radii.Hi() < 0) return std::numeric_limits::max(); @@ -225,9 +223,9 @@ HollowBallBound::MaxDistance( /** * Computes maximum distance. */ -template -typename HollowBallBound::ElemType -HollowBallBound::MaxDistance(const HollowBallBound& other) +template +ElemType HollowBallBound::MaxDistance( + const HollowBallBound& other) const { if (radii.Hi() < 0) @@ -242,12 +240,11 @@ HollowBallBound::MaxDistance(const HollowBallBound& other) * * Example: bound1.MinDistanceSq(other) for minimum squared distance. 
*/ -template -template -math::RangeType::ElemType> -HollowBallBound::RangeDistance( - const OtherVecType& point, - typename boost::enable_if >* /* junk */) const +template +template +math::RangeType HollowBallBound::RangeDistance( + const VecType& point, + typename boost::enable_if >* /* junk */) const { if (radii.Hi() < 0) return math::Range(std::numeric_limits::max(), @@ -261,9 +258,8 @@ HollowBallBound::RangeDistance( } } -template -math::RangeType::ElemType> -HollowBallBound::RangeDistance( +template +math::RangeType HollowBallBound::RangeDistance( const HollowBallBound& other) const { if (radii.Hi() < 0) @@ -283,10 +279,10 @@ HollowBallBound::RangeDistance( * The difference lies in the way we initialize the ball bound. The way we * expand the bound is same. */ -template +template template -const HollowBallBound& -HollowBallBound::operator|=(const MatType& data) +const HollowBallBound& +HollowBallBound::operator|=(const MatType& data) { if (radii.Hi() < 0) { @@ -297,14 +293,14 @@ HollowBallBound::operator|=(const MatType& data) // Now iteratively add points. for (size_t i = 0; i < data.n_cols; ++i) { - const ElemType dist = metric->Evaluate(center, (VecType) data.col(i)); + const ElemType dist = metric->Evaluate(center, data.col(i)); // See if the new point lies outside the bound. if (dist > radii.Hi()) { // Move towards the new point and increase the radius just enough to // accommodate the new point. - const VecType diff = data.col(i) - center; + const arma::Col diff = data.col(i) - center; center += ((dist - radii.Hi()) / (2 * dist)) * diff; radii.Hi() = 0.5 * (dist + radii.Hi()); } @@ -331,9 +327,9 @@ HollowBallBound::operator|=(const MatType& data) /** * Expand the bound to include the given bound. 
*/ -template -const HollowBallBound& -HollowBallBound::operator|=(const HollowBallBound& other) +template +const HollowBallBound& +HollowBallBound::operator|=(const HollowBallBound& other) { if (radii.Hi() < 0) { @@ -358,9 +354,9 @@ HollowBallBound::operator|=(const HollowBallBound& other) //! Serialize the BallBound. -template +template template -void HollowBallBound::Serialize( +void HollowBallBound::Serialize( Archive& ar, const unsigned int /* version */) { diff --git a/src/mlpack/core/tree/hrectbound_impl.hpp b/src/mlpack/core/tree/hrectbound_impl.hpp index ccfa0265b4f..60aa2b87831 100644 --- a/src/mlpack/core/tree/hrectbound_impl.hpp +++ b/src/mlpack/core/tree/hrectbound_impl.hpp @@ -3,8 +3,6 @@ * * Implementation of hyper-rectangle bound policy class. * Template parameter Power is the metric to use; use 2 for Euclidean (L2). - * - * @experimental */ #ifndef MLPACK_CORE_TREE_HRECTBOUND_IMPL_HPP #define MLPACK_CORE_TREE_HRECTBOUND_IMPL_HPP diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp index 72b379ce91a..bbdebdadc1f 100644 --- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp +++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp @@ -492,10 +492,6 @@ class RectangleTree //! Returns false: this tree type does not have self children. static bool HasSelfChildren() { return false; } - //! Returns false: The first point of this node is not the centroid - //! of its bound. - static constexpr bool IsFirstPointCentroid() { return false; } - private: /** * Splits the current node, recursing up the tree. 
diff --git a/src/mlpack/core/tree/rectangle_tree/traits.hpp b/src/mlpack/core/tree/rectangle_tree/traits.hpp index e4e16f90629..e94aa68ad86 100644 --- a/src/mlpack/core/tree/rectangle_tree/traits.hpp +++ b/src/mlpack/core/tree/rectangle_tree/traits.hpp @@ -35,8 +35,15 @@ class TreeTraits* /* node */ = NULL) + { + return false; + } /** * Points are not contained at multiple levels of the R-tree. @@ -83,8 +90,15 @@ class TreeTraits, + DescentType, AuxiliaryInformationType>* /* node */ = NULL) + { + return false; + } /** * Points are not contained at multiple levels of the R-tree. diff --git a/src/mlpack/core/tree/tree_traits.hpp b/src/mlpack/core/tree/tree_traits.hpp index cb005e63954..190244d0e5a 100644 --- a/src/mlpack/core/tree/tree_traits.hpp +++ b/src/mlpack/core/tree/tree_traits.hpp @@ -79,9 +79,14 @@ class TreeTraits static const bool HasOverlappingChildren = true; /** - * This is true if Point(0) is the centroid of the node. + * Returns true if Point(0) is the centroid of the node. + * + * @param node The node to check. 
*/ - static const bool FirstPointIsCentroid = false; + static constexpr bool FirstPointIsCentroid(const TreeType* /* node */ = NULL) + { + return false; + } /** * This is true if the points contained in the first child of a node diff --git a/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser.hpp index 1440b337319..e3fc2f0c44e 100644 --- a/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser.hpp @@ -20,7 +20,7 @@ template class BoundType, - template + template class SplitType> template class VantagePointTree class BoundType, - template + template class SplitType> template VantagePointTree:: @@ -34,7 +34,7 @@ template class BoundType, - template + template class SplitType> template void VantagePointTree:: @@ -82,7 +82,7 @@ DualTreeTraverser::Traverse( // If the first point of the query node is the centroid, the query node // contains a point. In this case we should run the single tree traverser. - if (queryNode.IsFirstPointCentroid()) + if (queryNode.FirstPointIsCentroid()) { const double pointScore = rule.Score(queryNode.Point(0), referenceNode); ++numScores; @@ -118,7 +118,7 @@ DualTreeTraverser::Traverse( { // If the reference node contains a point we should calculate all // base cases with this point. - if (referenceNode.IsFirstPointCentroid()) + if (referenceNode.FirstPointIsCentroid()) { const size_t queryEnd = queryNode.Begin() + queryNode.Count(); for (size_t query = queryNode.Begin(); query < queryEnd; ++query) @@ -203,7 +203,7 @@ DualTreeTraverser::Traverse( { // If the reference node contains a point we should calculate all // base cases with this point. 
- if (referenceNode.IsFirstPointCentroid()) + if (referenceNode.FirstPointIsCentroid()) { for (size_t i = 0; i < queryNode.NumDescendants(); ++i) rule.BaseCase(queryNode.Descendant(i), referenceNode.Point(0)); @@ -219,7 +219,7 @@ DualTreeTraverser::Traverse( double rightScore; typename RuleType::TraversalInfoType rightInfo; - if (queryNode.IsFirstPointCentroid()) + if (queryNode.FirstPointIsCentroid()) { leftScore = rule.Score(queryNode.Point(0), *referenceNode.Left()); leftInfo = rule.TraversalInfo(); @@ -458,7 +458,7 @@ template class BoundType, - template + template class SplitType> template void VantagePointTree:: @@ -478,7 +478,7 @@ DualTreeTraverser::Traverse( } // If the reference node contains a point we should calculate the base case. - if (referenceNode.IsFirstPointCentroid()) + if (referenceNode.FirstPointIsCentroid()) { rule.BaseCase(queryIndex, referenceNode.Point(0)); numBaseCases++; diff --git a/src/mlpack/core/tree/vantage_point_tree/single_tree_traverser.hpp b/src/mlpack/core/tree/vantage_point_tree/single_tree_traverser.hpp index 2c9f4ff7034..bc161657696 100644 --- a/src/mlpack/core/tree/vantage_point_tree/single_tree_traverser.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/single_tree_traverser.hpp @@ -19,7 +19,7 @@ template class BoundType, - template + template class SplitType> template class VantagePointTree class BoundType, - template + template class SplitType> template VantagePointTree:: @@ -33,7 +33,7 @@ template class BoundType, - template + template class SplitType> template void VantagePointTree:: @@ -52,7 +52,7 @@ SingleTreeTraverser::Traverse( } // If the reference node contains a point we should calculate the base case. - if (referenceNode.IsFirstPointCentroid()) + if (referenceNode.FirstPointIsCentroid()) rule.BaseCase(queryIndex, referenceNode.Point(0)); // If either score is DBL_MAX, we do not recurse into that node. 
diff --git a/src/mlpack/core/tree/vantage_point_tree/traits.hpp b/src/mlpack/core/tree/vantage_point_tree/traits.hpp index 5901e715242..ab844d7ec12 100644 --- a/src/mlpack/core/tree/vantage_point_tree/traits.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/traits.hpp @@ -21,7 +21,7 @@ template class BoundType, - template + template class SplitType> class TreeTraits> @@ -33,9 +33,15 @@ class TreeTraits* node) + { + return node->FirstPointIsCentroid(); + } /** * Points are not contained at multiple levels of the vantage point tree. diff --git a/src/mlpack/core/tree/vantage_point_tree/typedef.hpp b/src/mlpack/core/tree/vantage_point_tree/typedef.hpp index df0454387d1..cadde7fa1eb 100644 --- a/src/mlpack/core/tree/vantage_point_tree/typedef.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/typedef.hpp @@ -12,13 +12,48 @@ namespace mlpack { namespace tree { - + +/** + * The vantage point tree is a kind of the binary space tree. In contrast to + * BinarySpaceTree, each left intermediate node of the vantage point tree + * contains a point which is the centroid of the node. When recursively + * splitting nodes, the VPTree class selects a vantage point and splits the node + * according to the distance to this point. Thus, points that are closer to the + * vantage point form the left subtree (and the vantage point is the only point + * that the left node contains). Other points form the right subtree. + * In such a way, the bound of each left node is a ball and the vantage point is + * the centroid of the bound. The bound of each right node is a hollow ball + * centered at the vantage point. + * + * For more information, see the following paper. 
+ * + * @code + * @inproceedings{yianilos1993vptrees, + * author = {Yianilos, Peter N.}, + * title = {Data Structures and Algorithms for Nearest Neighbor Search in + * General Metric Spaces}, + * booktitle = {Proceedings of the Fourth Annual ACM-SIAM Symposium on + * Discrete Algorithms}, + * series = {SODA '93}, + * year = {1993}, + * isbn = {0-89871-313-7}, + * pages = {311--321}, + * numpages = {11}, + * publisher = {Society for Industrial and Applied Mathematics}, + * address = {Philadelphia, PA, USA} + * } + * @endcode + * + * This template typedef satisfies the TreeType policy API. + * + * @see @ref trees, BinarySpaceTree, VantagePointTree, VPTree + */ template using VPTree = VantagePointTree; + StatisticType, + MatType, + bound::HollowBallBound, + VantagePointSplit>; } // namespace tree } // namespace mlpack diff --git a/src/mlpack/core/tree/vantage_point_tree/vantage_point_split.hpp b/src/mlpack/core/tree/vantage_point_tree/vantage_point_split.hpp index f587e6592ce..66456b37fdb 100644 --- a/src/mlpack/core/tree/vantage_point_tree/vantage_point_split.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/vantage_point_split.hpp @@ -13,11 +13,14 @@ namespace mlpack { namespace tree /** Trees and tree-building procedures. */ { -template +template class VantagePointSplit { public: typedef typename MatType::elem_type ElemType; + typedef typename BoundType::MetricType MetricType; /** * Split the node according to the distance to a vantage point. * @@ -55,26 +58,6 @@ class VantagePointSplit size_t& splitCol, std::vector& oldFromNew); private: - /** - * The maximum number of samples used for vantage point estimation and for - * estimation of the median. - */ - static const size_t maxNumSamples = 100; - - template - struct SortStruct - { - size_t point; - ElemType dist; - }; - - template - static bool StructComp(const SortStruct& s1, - const SortStruct& s2) - { - return (s1.dist < s2.dist); - }; - /** * Select the best vantage point i.e. 
the point with the largest second moment * of the distance from a number of random node points to the vantage point. @@ -92,7 +75,7 @@ class VantagePointSplit * @param mu The median value of distance form the vantage point to * a number of random points. */ - static void SelectVantagePoint(const BoundType& bound, const MatType& data, + static void SelectVantagePoint(const MetricType& metric, const MatType& data, const size_t begin, const size_t count, size_t& vantagePoint, ElemType& mu); /** @@ -108,31 +91,6 @@ class VantagePointSplit static void GetDistinctSamples(arma::uvec& distinctSamples, const size_t numSamples, const size_t begin, const size_t upperBound); - /** - * Get the median value of the distance from a certain vantage point to a - * number of samples. - * - * @param bound The bound used for this node. - * @param data The dataset used by the binary space tree. - * @param samples The indices of random samples. - * @param vantagePoint The vantage point. - * @param mu The median value. - */ - static void GetMedian(const BoundType& bound, const MatType& data, - const arma::uvec& samples, const size_t vantagePoint, ElemType& mu); - - /** - * Calculate the second moment of the distance from a certain vantage point to - * a number of random samples. - * - * @param bound The bound used for this node. - * @param data The dataset used by the binary space tree. - * @param samples The indices of random samples. - * @param vantagePoint The vantage point. - */ - static ElemType GetSecondMoment(const BoundType& bound, const MatType& data, - const arma::uvec& samples, const size_t vantagePoint); - /** * This method returns true if a point should be assigned to the left subtree * i.e. the distance from the point to the vantage point is less then @@ -145,8 +103,11 @@ class VantagePointSplit * @param mu The median value. 
*/ template - static bool AssignToLeftSubtree(const BoundType& bound, const MatType& mat, - const VecType& vantagePoint, const size_t point, const ElemType mu); + static bool AssignToLeftSubtree(const MetricType& metric, const MatType& mat, + const VecType& vantagePoint, const size_t point, const ElemType mu) + { + return (metric.Evaluate(vantagePoint, mat.col(point)) < mu); + } /** * Perform split according to the median value and the vantage point. @@ -159,7 +120,7 @@ class VantagePointSplit * @param mu The median value. */ template - static size_t PerformSplit(const BoundType& bound, + static size_t PerformSplit(const MetricType& metric, MatType& data, const size_t begin, const size_t count, @@ -179,7 +140,7 @@ class VantagePointSplit * each new point. */ template - static size_t PerformSplit(const BoundType& bound, + static size_t PerformSplit(const MetricType& metric, MatType& data, const size_t begin, const size_t count, diff --git a/src/mlpack/core/tree/vantage_point_tree/vantage_point_split_impl.hpp b/src/mlpack/core/tree/vantage_point_tree/vantage_point_split_impl.hpp index 03ee1467103..93ec45605e1 100644 --- a/src/mlpack/core/tree/vantage_point_tree/vantage_point_split_impl.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/vantage_point_split_impl.hpp @@ -14,16 +14,16 @@ namespace mlpack { namespace tree { -template -bool VantagePointSplit:: +template +bool VantagePointSplit:: SplitNode(const BoundType& bound, MatType& data, const size_t begin, const size_t count, size_t& splitCol) { - typename BoundType::ElemType mu; + ElemType mu = 0; size_t vantagePointIndex; // Find the best vantage point - SelectVantagePoint(bound, data, begin, count, vantagePointIndex, mu); + SelectVantagePoint(bound.Metric(), data, begin, count, vantagePointIndex, mu); // All points are equal if (mu == 0) @@ -33,23 +33,23 @@ SplitNode(const BoundType& bound, MatType& data, const size_t begin, data.swap_cols(begin, vantagePointIndex); arma::Col vantagePoint = data.col(begin); - 
splitCol = PerformSplit(bound, data, begin, count, vantagePoint, mu); + splitCol = PerformSplit(bound.Metric(), data, begin, count, vantagePoint, mu); assert(splitCol > begin); assert(splitCol < begin + count); return true; } -template -bool VantagePointSplit:: +template +bool VantagePointSplit:: SplitNode(const BoundType& bound, MatType& data, const size_t begin, const size_t count, size_t& splitCol, std::vector& oldFromNew) { - ElemType mu; + ElemType mu = 0; size_t vantagePointIndex; // Find the best vantage point - SelectVantagePoint(bound, data, begin, count, vantagePointIndex, mu); + SelectVantagePoint(bound.Metric(), data, begin, count, vantagePointIndex, mu); // All points are equal if (mu == 0) @@ -63,16 +63,18 @@ SplitNode(const BoundType& bound, MatType& data, const size_t begin, arma::Col vantagePoint = data.col(begin); - splitCol = PerformSplit(bound, data, begin, count, vantagePoint, mu, oldFromNew); + splitCol = PerformSplit(bound.Metric(), data, begin, count, vantagePoint, mu, + oldFromNew); assert(splitCol > begin); assert(splitCol < begin + count); return true; } -template +template template -size_t VantagePointSplit::PerformSplit(const BoundType& bound, +size_t VantagePointSplit::PerformSplit( + const MetricType& metric, MatType& data, const size_t begin, const size_t count, @@ -88,10 +90,12 @@ size_t VantagePointSplit::PerformSplit(const BoundType& boun // First half-iteration of the loop is out here because the termination // condition is in the middle. 
- while (AssignToLeftSubtree(bound, data, vantagePoint, left, mu) && (left <= right)) + while (AssignToLeftSubtree(metric, data, vantagePoint, left, mu) && + (left <= right)) left++; - while ((!AssignToLeftSubtree(bound, data, vantagePoint, right, mu)) && (left <= right) && (right > 0)) + while ((!AssignToLeftSubtree(metric, data, vantagePoint, right, mu)) && + (left <= right) && (right > 0)) right--; while (left <= right) @@ -102,14 +106,16 @@ size_t VantagePointSplit::PerformSplit(const BoundType& boun // See how many points on the left are correct. When they are correct, // increase the left counter accordingly. When we encounter one that isn't // correct, stop. We will switch it later. - while ((AssignToLeftSubtree(bound, data, vantagePoint, left, mu)) && (left <= right)) + while ((AssignToLeftSubtree(metric, data, vantagePoint, left, mu)) && + (left <= right)) left++; // Now see how many points on the right are correct. When they are correct, // decrease the right counter accordingly. When we encounter one that isn't // correct, stop. We will switch it with the wrong point we found in the // previous loop. - while ((!AssignToLeftSubtree(bound, data, vantagePoint, right, mu)) && (left <= right)) + while ((!AssignToLeftSubtree(metric, data, vantagePoint, right, mu)) && + (left <= right)) right--; } @@ -118,9 +124,10 @@ size_t VantagePointSplit::PerformSplit(const BoundType& boun return left; } -template +template template -size_t VantagePointSplit::PerformSplit(const BoundType& bound, +size_t VantagePointSplit::PerformSplit( + const MetricType& metric, MatType& data, const size_t begin, const size_t count, @@ -138,10 +145,12 @@ size_t VantagePointSplit::PerformSplit(const BoundType& boun // First half-iteration of the loop is out here because the termination // condition is in the middle. 
- while (AssignToLeftSubtree(bound, data, vantagePoint, left, mu) && (left <= right)) + while (AssignToLeftSubtree(metric, data, vantagePoint, left, mu) && + (left <= right)) left++; - while ((!AssignToLeftSubtree(bound, data, vantagePoint, right, mu)) && (left <= right) && (right > 0)) + while ((!AssignToLeftSubtree(metric, data, vantagePoint, right, mu)) && + (left <= right) && (right > 0)) right--; while (left <= right) @@ -157,14 +166,16 @@ size_t VantagePointSplit::PerformSplit(const BoundType& boun // See how many points on the left are correct. When they are correct, // increase the left counter accordingly. When we encounter one that isn't // correct, stop. We will switch it later. - while (AssignToLeftSubtree(bound, data, vantagePoint, left, mu) && (left <= right)) + while (AssignToLeftSubtree(metric, data, vantagePoint, left, mu) && + (left <= right)) left++; // Now see how many points on the right are correct. When they are correct, // decrease the right counter accordingly. When we encounter one that isn't // correct, stop. We will switch it with the wrong point we found in the // previous loop. 
- while ((!AssignToLeftSubtree(bound, data, vantagePoint, right, mu)) && (left <= right)) + while ((!AssignToLeftSubtree(metric, data, vantagePoint, right, mu)) && + (left <= right)) right--; } @@ -173,30 +184,36 @@ size_t VantagePointSplit::PerformSplit(const BoundType& boun return left; } -template -void VantagePointSplit:: -SelectVantagePoint(const BoundType& bound, const MatType& data, +template +void VantagePointSplit:: +SelectVantagePoint(const MetricType& metric, const MatType& data, const size_t begin, const size_t count, size_t& vantagePoint, ElemType& mu) { arma::uvec vantagePointCandidates; + arma::Col distances(maxNumSamples); // Get no more than max(maxNumSamples, count) vantage point candidates - GetDistinctSamples(vantagePointCandidates, maxNumSamples, begin, count); + math::ObtainDistinctSamples(begin, begin + count, maxNumSamples, + vantagePointCandidates); ElemType bestSpread = 0; - // Evaluate eache candidate - for (size_t i = 0; i < vantagePointCandidates.n_rows; i++) + arma::uvec samples; + // Evaluate each candidate + for (size_t i = 0; i < vantagePointCandidates.n_elem; i++) { - arma::uvec samples; - - // Get no more than max(maxNumSamples, count) random samples - GetDistinctSamples(samples, maxNumSamples, begin, count); + // Get no more than min(maxNumSamples, count) random samples + math::ObtainDistinctSamples(begin, begin + count, maxNumSamples, samples); // Calculate the second moment of the distance to the vantage point candidate // using these random samples - const ElemType spread = GetSecondMoment(bound, data, samples, - vantagePointCandidates[i]); + distances.set_size(samples.n_elem); + + for (size_t j = 0; j < samples.n_elem; j++) + distances[j] = metric.Evaluate(data.col(vantagePointCandidates[i]), + data.col(samples[j])); + + const ElemType spread = arma::sum(distances % distances) / samples.n_elem; if (spread > bestSpread) { @@ -204,94 +221,12 @@ SelectVantagePoint(const BoundType& bound, const MatType& data, vantagePoint = 
vantagePointCandidates[i]; // Calculate the median value of the distance from the vantage point candidate // to these samples - GetMedian(bound, data, samples, vantagePoint, mu); - } + mu = arma::median(distances); + } } assert(bestSpread > 0); } -template -void VantagePointSplit:: -GetDistinctSamples(arma::uvec& distinctSamples, const size_t numSamples, - const size_t begin, const size_t upperBound) -{ - if (upperBound > numSamples) - { - arma::Col samples; - - samples.zeros(upperBound); - - for (size_t i = 0; i < numSamples; i++) - samples [ (size_t) math::RandInt(upperBound) ]++; - - distinctSamples = arma::find(samples > 0); - - distinctSamples += begin; - } - else - { - // The node contains less points than requested - distinctSamples.set_size(upperBound); - for (size_t i = 0; i < upperBound; i++) - distinctSamples[i] = begin + i; - } -} - -template -void VantagePointSplit:: -GetMedian(const BoundType& bound, const MatType& data, - const arma::uvec& samples, const size_t vantagePoint, ElemType& mu) -{ - std::vector> sorted(samples.n_rows); - - for (size_t i = 0; i < samples.n_rows; i++) - { - sorted[i].point = samples[i]; - sorted[i].dist = bound.Metric().Evaluate(data.col(vantagePoint), - data.col(samples[i])); - } - - // Sort samples according to the distance to the vantage point - std::sort(sorted.begin(), sorted.end(), StructComp); - - // Get the midian value - mu = bound.Metric().Evaluate(data.col(vantagePoint), - data.col(sorted[sorted.size() / 2].point)); -} - -template -typename MatType::elem_type VantagePointSplit:: -GetSecondMoment(const BoundType& bound, const MatType& data, - const arma::uvec& samples, const size_t vantagePoint) -{ - ElemType moment = 0; - - for (size_t i = 0; i < samples.size(); i++) - { - const ElemType dist = - bound.Metric().Evaluate(data.col(vantagePoint), data.col(samples[i])); - - moment += dist * dist; - } - - moment /= samples.size(); - - return moment; -} - -template -template -bool VantagePointSplit:: 
-AssignToLeftSubtree(const BoundType& bound, const MatType& mat, - const VecType& vantagePoint, const size_t point, const ElemType mu) -{ - // Return true if the point is close to the vantage point - if (bound.Metric().Evaluate(vantagePoint, mat.col(point)) < mu) - return true; - - return false; -} - } // namespace tree } // namespace mlpack diff --git a/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree.hpp b/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree.hpp index 2091dc9780d..5737e6eaaa9 100644 --- a/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree.hpp @@ -41,7 +41,7 @@ template class BoundType = bound::HollowBallBound, - template + template class SplitType = VantagePointSplit> class VantagePointTree { @@ -63,7 +63,7 @@ class VantagePointTree //! The number of points of the dataset contained in this node. size_t count; //! The bound object for this node. - BoundType bound; + BoundType bound; //! Any extra data contained in the node. StatisticType stat; //! The distance from the centroid of this node to the centroid of the parent. @@ -460,7 +460,7 @@ class VantagePointTree void Center(arma::vec& center) { bound.Center(center); } //! Indicates that the first point of this node is the centroid of its bound. 
- bool IsFirstPointCentroid() const { return firstPointIsCentroid; } + bool FirstPointIsCentroid() const { return firstPointIsCentroid; } private: /** diff --git a/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree_impl.hpp b/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree_impl.hpp index 087b5d9645a..587305d9b20 100644 --- a/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree_impl.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree_impl.hpp @@ -17,7 +17,7 @@ template class BoundType, - template + template class SplitType> VantagePointTree:: VantagePointTree( @@ -45,7 +45,7 @@ template class BoundType, - template + template class SplitType> VantagePointTree:: VantagePointTree( @@ -79,7 +79,7 @@ template class BoundType, - template + template class SplitType> VantagePointTree:: VantagePointTree( @@ -119,7 +119,7 @@ template class BoundType, - template + template class SplitType> VantagePointTree:: VantagePointTree(MatType&& data, const size_t maxLeafSize) : @@ -145,7 +145,7 @@ template class BoundType, - template + template class SplitType> VantagePointTree:: VantagePointTree( @@ -179,7 +179,7 @@ template class BoundType, - template + template class SplitType> VantagePointTree:: VantagePointTree( @@ -219,7 +219,7 @@ template class BoundType, - template + template class SplitType> VantagePointTree:: VantagePointTree( @@ -249,7 +249,7 @@ template class BoundType, - template + template class SplitType> VantagePointTree:: VantagePointTree( @@ -284,7 +284,7 @@ template class BoundType, - template + template class SplitType> VantagePointTree:: VantagePointTree( @@ -329,7 +329,7 @@ template class BoundType, - template + template class SplitType> VantagePointTree:: VantagePointTree( @@ -389,7 +389,7 @@ template class BoundType, - template + template class SplitType> VantagePointTree:: VantagePointTree(VantagePointTree&& other) : @@ -425,7 +425,7 @@ template class BoundType, - template + template class SplitType> template 
VantagePointTree:: @@ -448,7 +448,7 @@ template class BoundType, - template + template class SplitType> VantagePointTree:: ~VantagePointTree() @@ -465,7 +465,7 @@ template class BoundType, - template + template class SplitType> inline bool VantagePointTree::IsLeaf() const @@ -480,7 +480,7 @@ template class BoundType, - template + template class SplitType> inline size_t VantagePointTree::NumChildren() const @@ -501,7 +501,7 @@ template class BoundType, - template + template class SplitType> inline typename VantagePointTree class BoundType, - template + template class SplitType> inline typename VantagePointTree class BoundType, - template + template class SplitType> inline typename VantagePointTree::MinimumBoundDistance() const { - return bound.MinWidth() / 2.0; + return bound.OuterRadius(); } /** @@ -561,7 +561,7 @@ template class BoundType, - template + template class SplitType> inline VantagePointTree& @@ -581,17 +581,19 @@ template class BoundType, - template + template class SplitType> inline size_t VantagePointTree::NumPoints() const { - // Each intermediate node contains exactly one point. - if (left && parent) + // Each left intermediate node contains exactly one point. + // Each right intermediate node contains no points. + if (firstPointIsCentroid && left) return 1; - else if(!parent) - return 0; + else if(left) + return 0; + // This is a leaf node. 
return count; } @@ -602,7 +604,7 @@ template class BoundType, - template + template class SplitType> inline size_t VantagePointTree::NumDescendants() const @@ -617,7 +619,7 @@ template class BoundType, - template + template class SplitType> inline size_t VantagePointTree::Descendant(const size_t index) const @@ -632,7 +634,7 @@ template class BoundType, - template + template class SplitType> inline size_t VantagePointTree::Point(const size_t index) const @@ -644,7 +646,7 @@ template class BoundType, - template + template class SplitType> void VantagePointTree:: SplitNode(const size_t maxLeafSize, @@ -689,7 +691,7 @@ void VantagePointTree: size_t splitBegin = begin; size_t splitCount = count; - if (IsFirstPointCentroid()) + if (FirstPointIsCentroid()) { splitBegin = begin + 1; splitCount = count - 1; @@ -707,8 +709,8 @@ void VantagePointTree: // by calling their constructors (which perform this splitting process). left = new VantagePointTree(this, splitBegin, splitCol - splitBegin, splitter, maxLeafSize, true); - right = new VantagePointTree(this, splitCol, splitBegin + splitCount - splitCol, - splitter, maxLeafSize, false); + right = new VantagePointTree(this, splitCol, + splitBegin + splitCount - splitCol, splitter, maxLeafSize, false); // Calculate parent distances for those two nodes. arma::vec center, leftCenter, rightCenter; @@ -728,7 +730,7 @@ template class BoundType, - template + template class SplitType> void VantagePointTree:: SplitNode(std::vector& oldFromNew, @@ -776,14 +778,14 @@ SplitNode(std::vector& oldFromNew, size_t splitBegin = begin; size_t splitCount = count; - if (IsFirstPointCentroid()) + if (FirstPointIsCentroid()) { splitBegin = begin + 1; splitCount = count - 1; } - const bool split = splitter.SplitNode(bound, *dataset, splitBegin, splitCount, splitCol, - oldFromNew); + const bool split = splitter.SplitNode(bound, *dataset, splitBegin, splitCount, + splitCol, oldFromNew); // The node may not be always split. 
For instance, if all the points are the // same, we can't split them. @@ -792,24 +794,34 @@ SplitNode(std::vector& oldFromNew, // Now that we know the split column, we will recursively split the children // by calling their constructors (which perform this splitting process). - left = new VantagePointTree(this, splitBegin, splitCol - splitBegin, oldFromNew, - splitter, maxLeafSize, true); - right = new VantagePointTree(this, splitCol, splitBegin + splitCount - splitCol, - oldFromNew, splitter, maxLeafSize, false); + left = new VantagePointTree(this, splitBegin, splitCol - splitBegin, + oldFromNew, splitter, maxLeafSize, true); + right = new VantagePointTree(this, splitCol, + splitBegin + splitCount - splitCol, oldFromNew, splitter, maxLeafSize, + false); // Calculate parent distances for those two nodes. - arma::vec center, leftCenter, rightCenter; - Center(center); - left->Center(leftCenter); - right->Center(rightCenter); + ElemType parentDistance; + if (firstPointIsCentroid) + { + assert(left->firstPointIsCentroid == true); - const ElemType leftParentDistance = MetricType::Evaluate(center, leftCenter); - const ElemType rightParentDistance = MetricType::Evaluate(center, - rightCenter); + parentDistance = MetricType::Evaluate(dataset->col(begin), + dataset->col(left->begin)); + } + else + { + arma::vec center; + Center(center); - left->ParentDistance() = leftParentDistance; - right->ParentDistance() = rightParentDistance; + assert(left->firstPointIsCentroid == true); + + parentDistance = MetricType::Evaluate(center, dataset->col(left->begin)); + } + + left->ParentDistance() = parentDistance; + right->ParentDistance() = parentDistance; } // Default constructor (private), for boost::serialization. 
@@ -817,7 +829,7 @@ template class BoundType, - template + template class SplitType> VantagePointTree:: VantagePointTree() : @@ -841,7 +853,7 @@ template class BoundType, - template + template class SplitType> template void VantagePointTree:: diff --git a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp index 27abacf971d..eaeff5fac96 100644 --- a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp +++ b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp @@ -58,7 +58,7 @@ double FastMKSRules::BaseCase( // cover trees, the kernel evaluation between the two centroid points already // happened. So we don't need to do it. Note that this optimizes out if the // first conditional is false (its result is known at compile time). - if (tree::TreeTraits::FirstPointIsCentroid) + if (tree::TreeTraits::FirstPointIsCentroid()) { if ((queryIndex == lastQueryIndex) && (referenceIndex == lastReferenceIndex)) @@ -74,7 +74,7 @@ double FastMKSRules::BaseCase( referenceSet.col(referenceIndex)); // Update the last kernel value, if we need to. - if (tree::TreeTraits::FirstPointIsCentroid) + if (tree::TreeTraits::FirstPointIsCentroid()) lastKernel = kernelEval; // If the reference and query sets are identical, we still need to compute the @@ -141,7 +141,7 @@ double FastMKSRules::Score(const size_t queryIndex, // centroid or, if the centroid is a point, use that. ++scores; double kernelEval; - if (tree::TreeTraits::FirstPointIsCentroid) + if (tree::TreeTraits::FirstPointIsCentroid(&referenceNode)) { // Could it be that this kernel evaluation has already been calculated? if (tree::TreeTraits::HasSelfChildren && @@ -295,7 +295,8 @@ double FastMKSRules::Score(TreeType& queryNode, // We were unable to perform a parent-child or parent-parent prune, so now we // must calculate kernel evaluation, if necessary. 
double kernelEval = 0.0; - if (tree::TreeTraits::FirstPointIsCentroid) + if (tree::TreeTraits::FirstPointIsCentroid(&queryNode) && + tree::TreeTraits::FirstPointIsCentroid(&referenceNode)) { // For this type of tree, we may have already calculated the base case in // the parents. diff --git a/src/mlpack/methods/fastmks/fastmks_stat.hpp b/src/mlpack/methods/fastmks/fastmks_stat.hpp index 318e5c9b621..2086d2d5f82 100644 --- a/src/mlpack/methods/fastmks/fastmks_stat.hpp +++ b/src/mlpack/methods/fastmks/fastmks_stat.hpp @@ -44,7 +44,7 @@ class FastMKSStat lastKernelNode(NULL) { // Do we have to calculate the centroid? - if (tree::TreeTraits::FirstPointIsCentroid) + if (tree::TreeTraits::FirstPointIsCentroid(&node)) { // If this type of tree has self-children, then maybe the evaluation is // already done. These statistics are built bottom-up, so the child stat diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp index 5aa8d2f0aef..5622e4474bd 100644 --- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp +++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp @@ -131,7 +131,8 @@ inline double DualTreeKMeansRules::Score( // We want to set adjustedScore to be the distance between the centroid of the // last query node and last reference node. We will do this by adjusting the // last score. In some cases, we can just use the last base case. - if (queryNode.IsFirstPointCentroid() && referenceNode.IsFirstPointCentroid()) + if (tree::TreeTraits::FirstPointIsCentroid(&queryNode) && + tree::TreeTraits::FirstPointIsCentroid(&referenceNode)) { adjustedScore = traversalInfo.LastBaseCase(); } @@ -207,8 +208,9 @@ inline double DualTreeKMeansRules::Score( // Now, check if we can prune. 
if (adjustedScore > queryNode.Stat().UpperBound()) { - if (!(queryNode.IsFirstPointCentroid() && - referenceNode.IsFirstPointCentroid() && score == 0.0)) + if (!(tree::TreeTraits::FirstPointIsCentroid(&queryNode) && + tree::TreeTraits::FirstPointIsCentroid(&referenceNode) && + score == 0.0)) { // There isn't any need to set the traversal information because no // descendant combinations will be visited, and those are the only diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp index 7ae89de9fbf..5f3d2386878 100644 --- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp +++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp @@ -85,7 +85,7 @@ inline double NeighborSearchRules::Score( { ++scores; // Count number of Score() calls. double distance; - if (referenceNode.IsFirstPointCentroid()) + if (tree::TreeTraits::FirstPointIsCentroid(&referenceNode)) { // The first point in the tree is the centroid. So we can then calculate // the base case between that and the query point. @@ -160,7 +160,8 @@ inline double NeighborSearchRules::Score( // We want to set adjustedScore to be the distance between the centroid of the // last query node and last reference node. We will do this by adjusting the // last score. In some cases, we can just use the last base case. - if (queryNode.IsFirstPointCentroid() && referenceNode.IsFirstPointCentroid()) + if (tree::TreeTraits::FirstPointIsCentroid(&queryNode) && + tree::TreeTraits::FirstPointIsCentroid(&referenceNode)) { adjustedScore = traversalInfo.LastBaseCase(); } @@ -237,8 +238,9 @@ inline double NeighborSearchRules::Score( // Can we prune? 
if (!SortPolicy::IsBetter(adjustedScore, bestDistance)) { - if (!(queryNode.IsFirstPointCentroid() && - referenceNode.IsFirstPointCentroid() && score == 0.0)) + if (!(tree::TreeTraits::FirstPointIsCentroid(&queryNode) && + tree::TreeTraits::FirstPointIsCentroid(&referenceNode) && + score == 0.0)) { // There isn't any need to set the traversal information because no // descendant combinations will be visited, and those are the only @@ -248,7 +250,8 @@ inline double NeighborSearchRules::Score( } double distance; - if (queryNode.IsFirstPointCentroid() && referenceNode.IsFirstPointCentroid()) + if (tree::TreeTraits::FirstPointIsCentroid(&queryNode) && + tree::TreeTraits::FirstPointIsCentroid(&referenceNode)) { // The first point in the node is the centroid, so we can calculate the // distance between the two points using BaseCase() and then find the diff --git a/src/mlpack/methods/range_search/range_search_rules_impl.hpp b/src/mlpack/methods/range_search/range_search_rules_impl.hpp index 8dd103f5054..2b24a70dca4 100644 --- a/src/mlpack/methods/range_search/range_search_rules_impl.hpp +++ b/src/mlpack/methods/range_search/range_search_rules_impl.hpp @@ -79,7 +79,7 @@ double RangeSearchRules::Score(const size_t queryIndex, // object. math::Range distances; - if (referenceNode.IsFirstPointCentroid()) + if (tree::TreeTraits::FirstPointIsCentroid(&referenceNode)) { // In this situation, we calculate the base case. So we should check to be // sure we haven't already done that. @@ -147,7 +147,8 @@ double RangeSearchRules::Score(TreeType& queryNode, TreeType& referenceNode) { math::Range distances; - if (queryNode.IsFirstPointCentroid() && referenceNode.IsFirstPointCentroid()) + if (tree::TreeTraits::FirstPointIsCentroid(&queryNode) && + tree::TreeTraits::FirstPointIsCentroid(&referenceNode)) { // It is possible that the base case has already been calculated. 
double baseCase = 0.0; @@ -224,7 +225,7 @@ void RangeSearchRules::AddResult(const size_t queryIndex, // called, so if the base case has already been calculated, then we must avoid // adding that point to the results again. size_t baseCaseMod = 0; - if (referenceNode.IsFirstPointCentroid() && + if (tree::TreeTraits::FirstPointIsCentroid(&referenceNode) && (queryIndex == lastQueryIndex) && (referenceNode.Point(0) == lastReferenceIndex)) { diff --git a/src/mlpack/methods/rann/ra_search_impl.hpp b/src/mlpack/methods/rann/ra_search_impl.hpp index aa8daa5010d..88bccca0e9e 100644 --- a/src/mlpack/methods/rann/ra_search_impl.hpp +++ b/src/mlpack/methods/rann/ra_search_impl.hpp @@ -369,7 +369,7 @@ Search(const MatType& querySet, const size_t numSamples = RAUtil::MinimumSamplesReqd(referenceSet->n_cols, k, tau, alpha); arma::uvec distinctSamples; - RAUtil::ObtainDistinctSamples(numSamples, referenceSet->n_cols, + math::ObtainDistinctSamples(0, referenceSet->n_cols, numSamples, distinctSamples); // Run the base case on each combination of query point and sampled @@ -597,7 +597,7 @@ void RASearch::Search( const size_t numSamples = RAUtil::MinimumSamplesReqd(referenceSet->n_cols, k, tau, alpha); arma::uvec distinctSamples; - RAUtil::ObtainDistinctSamples(numSamples, referenceSet->n_cols, + math::ObtainDistinctSamples(0, referenceSet->n_cols, numSamples, distinctSamples); // The naive brute-force solution. diff --git a/src/mlpack/methods/rann/ra_search_rules_impl.hpp b/src/mlpack/methods/rann/ra_search_rules_impl.hpp index 2071de1b98e..22739405d3a 100644 --- a/src/mlpack/methods/rann/ra_search_rules_impl.hpp +++ b/src/mlpack/methods/rann/ra_search_rules_impl.hpp @@ -71,10 +71,10 @@ RASearchRules(const arma::mat& referenceSet, if (naive) // No tree traversal; just do naive sampling here. { // Sample enough points. 
+ arma::uvec distinctSamples; for (size_t i = 0; i < querySet.n_cols; ++i) { - arma::uvec distinctSamples; - RAUtil::ObtainDistinctSamples(numSamplesReqd, n, distinctSamples); + math::ObtainDistinctSamples(0, n, numSamplesReqd, distinctSamples); for (size_t j = 0; j < distinctSamples.n_elem; j++) BaseCase(i, (size_t) distinctSamples[j]); } @@ -178,9 +178,8 @@ inline double RASearchRules::Score( // Then samplesReqd <= singleSampleLimit. // Hence, approximate the node by sampling enough number of points. arma::uvec distinctSamples; - RAUtil::ObtainDistinctSamples(samplesReqd, - referenceNode.NumDescendants(), - distinctSamples); + math::ObtainDistinctSamples(0, referenceNode.NumDescendants(), + samplesReqd, distinctSamples); for (size_t i = 0; i < distinctSamples.n_elem; i++) // The counting of the samples are done in the 'BaseCase' function // so no book-keeping is required here. @@ -195,9 +194,8 @@ inline double RASearchRules::Score( { // Approximate node by sampling enough number of points. arma::uvec distinctSamples; - RAUtil::ObtainDistinctSamples(samplesReqd, - referenceNode.NumDescendants(), - distinctSamples); + math::ObtainDistinctSamples(0, referenceNode.NumDescendants(), + samplesReqd, distinctSamples); for (size_t i = 0; i < distinctSamples.n_elem; i++) // The counting of the samples are done in the 'BaseCase' function // so no book-keeping is required here. @@ -284,8 +282,8 @@ Rescore(const size_t queryIndex, // Then, samplesReqd <= singleSampleLimit. Hence, approximate the node // by sampling enough number of points. arma::uvec distinctSamples; - RAUtil::ObtainDistinctSamples(samplesReqd, - referenceNode.NumDescendants(), distinctSamples); + math::ObtainDistinctSamples(0, referenceNode.NumDescendants(), + samplesReqd, distinctSamples); for (size_t i = 0; i < distinctSamples.n_elem; i++) // The counting of the samples are done in the 'BaseCase' function so // no book-keeping is required here. 
@@ -300,8 +298,8 @@ Rescore(const size_t queryIndex, { // Approximate node by sampling enough points. arma::uvec distinctSamples; - RAUtil::ObtainDistinctSamples(samplesReqd, - referenceNode.NumDescendants(), distinctSamples); + math::ObtainDistinctSamples(0, referenceNode.NumDescendants(), + samplesReqd, distinctSamples); for (size_t i = 0; i < distinctSamples.n_elem; i++) // The counting of the samples are done in the 'BaseCase' function // so no book-keeping is required here. @@ -483,12 +481,12 @@ inline double RASearchRules::Score( { // Then samplesReqd <= singleSampleLimit. Hence, approximate node by // sampling enough number of points for every query in the query node. + arma::uvec distinctSamples; for (size_t i = 0; i < queryNode.NumDescendants(); ++i) { const size_t queryIndex = queryNode.Descendant(i); - arma::uvec distinctSamples; - RAUtil::ObtainDistinctSamples(samplesReqd, - referenceNode.NumDescendants(), distinctSamples); + math::ObtainDistinctSamples(0, referenceNode.NumDescendants(), + samplesReqd, distinctSamples); for (size_t j = 0; j < distinctSamples.n_elem; j++) // The counting of the samples are done in the 'BaseCase' function // so no book-keeping is required here. @@ -513,12 +511,12 @@ inline double RASearchRules::Score( { // Approximate node by sampling enough number of points for every // query in the query node. + arma::uvec distinctSamples; for (size_t i = 0; i < queryNode.NumDescendants(); ++i) { const size_t queryIndex = queryNode.Descendant(i); - arma::uvec distinctSamples; - RAUtil::ObtainDistinctSamples(samplesReqd, - referenceNode.NumDescendants(), distinctSamples); + math::ObtainDistinctSamples(0, referenceNode.NumDescendants(), + samplesReqd, distinctSamples); for (size_t j = 0; j < distinctSamples.n_elem; j++) // The counting of the samples are done in the 'BaseCase' // function so no book-keeping is required here. @@ -688,12 +686,12 @@ Rescore(TreeType& queryNode, { // then samplesReqd <= singleSampleLimit. 
Hence, approximate the node // by sampling enough points for every query in the query node. + arma::uvec distinctSamples; for (size_t i = 0; i < queryNode.NumDescendants(); ++i) { const size_t queryIndex = queryNode.Descendant(i); - arma::uvec distinctSamples; - RAUtil::ObtainDistinctSamples(samplesReqd, - referenceNode.NumDescendants(), distinctSamples); + math::ObtainDistinctSamples(0, referenceNode.NumDescendants(), + samplesReqd, distinctSamples); for (size_t j = 0; j < distinctSamples.n_elem; j++) // The counting of the samples are done in the 'BaseCase' // function so no book-keeping is required here. @@ -717,12 +715,12 @@ Rescore(TreeType& queryNode, { // Approximate node by sampling enough points for every query in the // query node. + arma::uvec distinctSamples; for (size_t i = 0; i < queryNode.NumDescendants(); ++i) { const size_t queryIndex = queryNode.Descendant(i); - arma::uvec distinctSamples; - RAUtil::ObtainDistinctSamples(samplesReqd, - referenceNode.NumDescendants(), distinctSamples); + math::ObtainDistinctSamples(0, referenceNode.NumDescendants(), + samplesReqd, distinctSamples); for (size_t j = 0; j < distinctSamples.n_elem; j++) // The counting of the samples are done in BaseCase() so no // book-keeping is required here. diff --git a/src/mlpack/methods/rann/ra_util.cpp b/src/mlpack/methods/rann/ra_util.cpp index 579e03c2b1c..c85edee0299 100644 --- a/src/mlpack/methods/rann/ra_util.cpp +++ b/src/mlpack/methods/rann/ra_util.cpp @@ -164,19 +164,3 @@ double mlpack::neighbor::RAUtil::SuccessProbability(const size_t n, return sum; } // For k > 1. } - -void mlpack::neighbor::RAUtil::ObtainDistinctSamples( - const size_t numSamples, - const size_t rangeUpperBound, - arma::uvec& distinctSamples) -{ - // Keep track of the points that are sampled. 
- arma::Col sampledPoints; - sampledPoints.zeros(rangeUpperBound); - - for (size_t i = 0; i < numSamples; i++) - sampledPoints[(size_t) math::RandInt(rangeUpperBound)]++; - - distinctSamples = arma::find(sampledPoints > 0); - return; -} diff --git a/src/mlpack/tests/tree_traits_test.cpp b/src/mlpack/tests/tree_traits_test.cpp index cf0395d19f5..16ecf70e662 100644 --- a/src/mlpack/tests/tree_traits_test.cpp +++ b/src/mlpack/tests/tree_traits_test.cpp @@ -36,7 +36,7 @@ BOOST_AUTO_TEST_CASE(DefaultsTraitsTest) BOOST_REQUIRE_EQUAL(b, true); b = TreeTraits::HasSelfChildren; BOOST_REQUIRE_EQUAL(b, false); - b = TreeTraits::FirstPointIsCentroid; + b = TreeTraits::FirstPointIsCentroid(); BOOST_REQUIRE_EQUAL(b, false); b = TreeTraits::RearrangesDataset; BOOST_REQUIRE_EQUAL(b, false); @@ -58,7 +58,7 @@ BOOST_AUTO_TEST_CASE(BinarySpaceTreeTraitsTest) BOOST_REQUIRE_EQUAL(b, false); // The first point is not the centroid. - b = TreeTraits::FirstPointIsCentroid; + b = TreeTraits::FirstPointIsCentroid(); BOOST_REQUIRE_EQUAL(b, false); // The dataset gets rearranged at build time. @@ -82,7 +82,7 @@ BOOST_AUTO_TEST_CASE(CoverTreeTraitsTest) BOOST_REQUIRE_EQUAL(b, true); // The first point is the center of the node. 
- b = TreeTraits>::FirstPointIsCentroid; + b = TreeTraits>::FirstPointIsCentroid(); BOOST_REQUIRE_EQUAL(b, true); b = TreeTraits>::RearrangesDataset; diff --git a/src/mlpack/tests/vantage_point_tree_test.cpp b/src/mlpack/tests/vantage_point_tree_test.cpp index 04b15553d9e..9b9da8c6244 100644 --- a/src/mlpack/tests/vantage_point_tree_test.cpp +++ b/src/mlpack/tests/vantage_point_tree_test.cpp @@ -39,21 +39,24 @@ BOOST_AUTO_TEST_CASE(VPTreeTraitsTest) BOOST_AUTO_TEST_CASE(HollowBallBoundTest) { - HollowBallBound b(2, 4, "1.0 2.0 3.0 4.0 5.0"); + HollowBallBound b(2, 4, arma::vec("1.0 2.0 3.0 4.0 5.0")); - BOOST_REQUIRE_EQUAL(b.Contains("1.0 2.0 3.0 7.0 5.0"), true); + BOOST_REQUIRE_EQUAL(b.Contains(arma::vec("1.0 2.0 3.0 7.0 5.0")), true); - BOOST_REQUIRE_EQUAL(b.Contains("1.0 2.0 3.0 9.0 5.0"), false); + BOOST_REQUIRE_EQUAL(b.Contains(arma::vec("1.0 2.0 3.0 9.0 5.0")), false); - BOOST_REQUIRE_EQUAL(b.Contains("1.0 2.0 3.0 5.0 5.0"), false); + BOOST_REQUIRE_EQUAL(b.Contains(arma::vec("1.0 2.0 3.0 5.0 5.0")), false); - HollowBallBound b2(0.5, 1, "1.0 2.0 3.0 7.0 5.0"); + HollowBallBound b2(0.5, 1, + arma::vec("1.0 2.0 3.0 7.0 5.0")); BOOST_REQUIRE_EQUAL(b.Contains(b2), true); - b2 = HollowBallBound(2.5, 3.5, "1.0 2.0 3.0 4.5 5.0"); + b2 = HollowBallBound(2.5, 3.5, + arma::vec("1.0 2.0 3.0 4.5 5.0")); BOOST_REQUIRE_EQUAL(b.Contains(b2), true); - b2 = HollowBallBound(2.0, 3.5, "1.0 2.0 3.0 4.5 5.0"); + b2 = HollowBallBound(2.0, 3.5, + arma::vec("1.0 2.0 3.0 4.5 5.0")); BOOST_REQUIRE_EQUAL(b.Contains(b2), false); BOOST_REQUIRE_CLOSE(b.MinDistance(arma::vec("1.0 2.0 8.0 4.0 5.0")), 1.0, @@ -76,40 +79,50 @@ BOOST_AUTO_TEST_CASE(HollowBallBoundTest) BOOST_REQUIRE_CLOSE(b.MaxDistance(arma::vec("1.0 2.0 3.0 4.0 5.0")), 4.0, 1e-5); - b2 = HollowBallBound(3, 4, "1.0 2.0 3.0 5.0 5.0"); + b2 = HollowBallBound(3, 4, + arma::vec("1.0 2.0 3.0 5.0 5.0")); BOOST_REQUIRE_CLOSE(b.MinDistance(b2), 0.0, 1e-5); - b2 = HollowBallBound(1, 2, "1.0 2.0 3.0 4.0 5.0"); + b2 = HollowBallBound(1, 
2, + arma::vec("1.0 2.0 3.0 4.0 5.0")); BOOST_REQUIRE_CLOSE(b.MinDistance(b2), 0.0, 1e-5); - b2 = HollowBallBound(0.5, 1.0, "1.0 2.5 3.0 4.0 5.0"); + b2 = HollowBallBound(0.5, 1.0, + arma::vec("1.0 2.5 3.0 4.0 5.0")); BOOST_REQUIRE_CLOSE(b.MinDistance(b2), 0.5, 1e-5); - b2 = HollowBallBound(0.5, 1.0, "1.0 8.0 3.0 4.0 5.0"); + b2 = HollowBallBound(0.5, 1.0, + arma::vec("1.0 8.0 3.0 4.0 5.0")); BOOST_REQUIRE_CLOSE(b.MinDistance(b2), 1.0, 1e-5); - b2 = HollowBallBound(0.5, 2.0, "1.0 8.0 3.0 4.0 5.0"); + b2 = HollowBallBound(0.5, 2.0, + arma::vec("1.0 8.0 3.0 4.0 5.0")); BOOST_REQUIRE_CLOSE(b.MinDistance(b2), 0.0, 1e-5); - b2 = HollowBallBound(0.5, 2.0, "1.0 8.0 3.0 4.0 5.0"); + b2 = HollowBallBound(0.5, 2.0, + arma::vec("1.0 8.0 3.0 4.0 5.0")); BOOST_REQUIRE_CLOSE(b.MaxDistance(b2), 12.0, 1e-5); - b2 = HollowBallBound(0.5, 2.0, "1.0 3.0 3.0 4.0 5.0"); + b2 = HollowBallBound(0.5, 2.0, + arma::vec("1.0 3.0 3.0 4.0 5.0")); BOOST_REQUIRE_CLOSE(b.MaxDistance(b2), 7.0, 1e-5); HollowBallBound b1 = b; - b2 = HollowBallBound(1.0, 2.0, "1.0 2.5 3.0 4.0 5.0"); + b2 = HollowBallBound(1.0, 2.0, + arma::vec("1.0 2.5 3.0 4.0 5.0")); b1 |= b2; BOOST_REQUIRE_CLOSE(b1.InnerRadius(), 0.5, 1e-5); b1 = b; - b2 = HollowBallBound(0.5, 2.0, "1.0 3.0 3.0 4.0 5.0"); + b2 = HollowBallBound(0.5, 2.0, + arma::vec("1.0 3.0 3.0 4.0 5.0")); b1 |= b2; BOOST_REQUIRE_CLOSE(b1.InnerRadius(), 0.0, 1e-5); b1 = b; - b2 = HollowBallBound(0.5, 4.0, "1.0 3.0 3.0 4.0 5.0"); + b2 = HollowBallBound(0.5, 4.0, + arma::vec("1.0 3.0 3.0 4.0 5.0")); b1 |= b2; BOOST_REQUIRE_CLOSE(b1.OuterRadius(), 5.0, 1e-5); } @@ -127,7 +140,7 @@ void CheckBound(TreeType& tree) { if (!tree.Parent()) BOOST_REQUIRE_EQUAL(tree.NumPoints(), 0); - else if (tree.IsFirstPointCentroid()) + else if (tree.FirstPointIsCentroid()) { BOOST_REQUIRE_EQUAL(tree.NumPoints(), 1); BOOST_REQUIRE_EQUAL(true, @@ -181,7 +194,7 @@ void CheckSplit(TreeType& tree) BOOST_REQUIRE_LE(maxDist, dist); } - if (tree.IsFirstPointCentroid()) + if 
(tree.FirstPointIsCentroid()) { for (size_t k = 0; k < tree.Bound().Dim(); k++) BOOST_REQUIRE_EQUAL(tree.Bound().Center()[k], From cc326f35d6ed41968ee1948bcaad38c364bab84d Mon Sep 17 00:00:00 2001 From: Mikhail Lozhnikov Date: Mon, 25 Jul 2016 01:26:33 +0300 Subject: [PATCH 09/12] Added TreeTraits::FirstSiblingFirstPointIsCentroid. Vantage points moved to a separate tree node. Thus, each intermediate node of the vantage point tree has three children. --- .../core/tree/binary_space_tree/traits.hpp | 23 +- src/mlpack/core/tree/cover_tree/traits.hpp | 14 +- .../core/tree/rectangle_tree/traits.hpp | 30 +- src/mlpack/core/tree/tree_traits.hpp | 15 +- .../dual_tree_traverser.hpp | 10 +- .../dual_tree_traverser_impl.hpp | 518 ++++-------------- .../single_tree_traverser_impl.hpp | 56 +- .../core/tree/vantage_point_tree/traits.hpp | 16 +- .../core/tree/vantage_point_tree/typedef.hpp | 43 +- .../vantage_point_split.hpp | 4 +- .../vantage_point_split_impl.hpp | 2 +- .../vantage_point_tree/vantage_point_tree.hpp | 66 ++- .../vantage_point_tree_impl.hpp | 320 ++++++----- .../methods/fastmks/fastmks_rules_impl.hpp | 9 +- src/mlpack/methods/fastmks/fastmks_stat.hpp | 2 +- .../kmeans/dual_tree_kmeans_rules_impl.hpp | 7 +- .../neighbor_search_rules_impl.hpp | 59 +- .../range_search/range_search_rules_impl.hpp | 45 +- src/mlpack/tests/tree_traits_test.cpp | 6 +- src/mlpack/tests/vantage_point_tree_test.cpp | 62 ++- 20 files changed, 535 insertions(+), 772 deletions(-) diff --git a/src/mlpack/core/tree/binary_space_tree/traits.hpp b/src/mlpack/core/tree/binary_space_tree/traits.hpp index 87078e825fd..611447f090f 100644 --- a/src/mlpack/core/tree/binary_space_tree/traits.hpp +++ b/src/mlpack/core/tree/binary_space_tree/traits.hpp @@ -38,14 +38,14 @@ class TreeTraits* /* node */ = NULL) - { - return false; - } + static const bool FirstPointIsCentroid = false; + + /** + * There is no guarantee that the first point of the first sibling is the + * centroid of other siblings. 
+ */ + static const bool FirstSiblingFirstPointIsCentroid = false; /** * Points are not contained at multiple levels of the binary space tree. @@ -80,13 +80,8 @@ class TreeTraits* /* node */ = NULL) - { - return false; - } - + static const bool FirstPointIsCentroid = false; + static const bool FirstSiblingFirstPointIsCentroid = false; static const bool HasSelfChildren = false; static const bool RearrangesDataset = true; static const bool BinaryTree = true; diff --git a/src/mlpack/core/tree/cover_tree/traits.hpp b/src/mlpack/core/tree/cover_tree/traits.hpp index 64eeeba0e94..b28b3b3a667 100644 --- a/src/mlpack/core/tree/cover_tree/traits.hpp +++ b/src/mlpack/core/tree/cover_tree/traits.hpp @@ -35,14 +35,14 @@ class TreeTraits> /** * Each cover tree node contains only one point, and that point is its * centroid. - * - * @param node The node to check. */ - static constexpr bool FirstPointIsCentroid(const CoverTree* /* node */ = NULL) - { - return true; - } + static const bool FirstPointIsCentroid = true; + + /** + * There is no guarantee that the first point of the first sibling is the + * centroid of other siblings. + */ + static const bool FirstSiblingFirstPointIsCentroid = false; /** * Cover trees do have self-children. diff --git a/src/mlpack/core/tree/rectangle_tree/traits.hpp b/src/mlpack/core/tree/rectangle_tree/traits.hpp index e94aa68ad86..6f49a0cc85e 100644 --- a/src/mlpack/core/tree/rectangle_tree/traits.hpp +++ b/src/mlpack/core/tree/rectangle_tree/traits.hpp @@ -35,15 +35,14 @@ class TreeTraits* /* node */ = NULL) - { - return false; - } + static const bool FirstPointIsCentroid = false; + + /** + * There is no guarantee that the first point of the first sibling is the + * centroid of a node. + */ + static const bool FirstSiblingFirstPointIsCentroid = false; /** * Points are not contained at multiple levels of the R-tree. 
@@ -90,15 +89,14 @@ class TreeTraits, - DescentType, AuxiliaryInformationType>* /* node */ = NULL) - { - return false; - } + static const bool FirstPointIsCentroid = false; + + /** + * There is no guarantee that the first point of the first sibling is the + * centroid of other siblings. + */ + static const bool FirstSiblingFirstPointIsCentroid = false; /** * Points are not contained at multiple levels of the R-tree. diff --git a/src/mlpack/core/tree/tree_traits.hpp b/src/mlpack/core/tree/tree_traits.hpp index 190244d0e5a..4f6a0b00a4a 100644 --- a/src/mlpack/core/tree/tree_traits.hpp +++ b/src/mlpack/core/tree/tree_traits.hpp @@ -79,14 +79,15 @@ class TreeTraits static const bool HasOverlappingChildren = true; /** - * Returns true if Point(0) is the centroid of the node. - * - * @param node The node to check. + * This is true if the first point of each node is the centroid of its bound. */ - static constexpr bool FirstPointIsCentroid(const TreeType* /* node */ = NULL) - { - return false; - } + static const bool FirstPointIsCentroid = false; + + /** + * This is true if the first point of the first sibling is the centroid of + * other siblings. + */ + static const bool FirstSiblingFirstPointIsCentroid = false; /** * This is true if the points contained in the first child of a node diff --git a/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser.hpp index e3fc2f0c44e..6c462494f2d 100644 --- a/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser.hpp @@ -77,18 +77,14 @@ class VantagePointTree::Traverse( VantagePointTree& referenceNode) { + typedef typename RuleType::TraversalInfoType TravInfo; + // This tree traverser use TraversalInfoType slightly different than other + // traversers. All children at one level use the same traversal information + // since their centroids are equal. + // Increment the visit counter. 
++numVisited; - // Store the current traversal info. - traversalInfo = rule.TraversalInfo(); - // If both are leaves, we must evaluate the base case. if (queryNode.IsLeaf() && referenceNode.IsLeaf()) { + TravInfo traversalInfo = rule.TraversalInfo(); + + if (traversalInfo.LastQueryNode() == &queryNode && + traversalInfo.LastReferenceNode() == &referenceNode) + return; // We have already calculated this base case. + // Loop through each of the points in each node. const size_t queryEnd = queryNode.Begin() + queryNode.Count(); const size_t refEnd = referenceNode.Begin() + referenceNode.Count(); for (size_t query = queryNode.Begin(); query < queryEnd; ++query) { // See if we need to investigate this point (this function should be - // implemented for the single-tree recursion too). Restore the traversal - // information first. - rule.TraversalInfo() = traversalInfo; + // implemented for the single-tree recursion too). const double childScore = rule.Score(query, referenceNode); if (childScore == DBL_MAX) @@ -80,377 +86,79 @@ DualTreeTraverser::Traverse( // We have to recurse down the query node. In this case the recursion order // does not matter. - // If the first point of the query node is the centroid, the query node - // contains a point. In this case we should run the single tree traverser. - if (queryNode.FirstPointIsCentroid()) - { - const double pointScore = rule.Score(queryNode.Point(0), referenceNode); - ++numScores; + const double pointScore = rule.Score(*queryNode.Central(), referenceNode); + ++numScores; - if (pointScore != DBL_MAX) - Traverse(queryNode.Point(0), referenceNode); - else - ++numPrunes; + // The traversal information is the same for all children. + TravInfo traversalInfo = rule.TraversalInfo(); - // Before recursing, we have to set the traversal information correctly. 
- rule.TraversalInfo() = traversalInfo; - } + if (pointScore != DBL_MAX) + Traverse(*queryNode.Central(), referenceNode); + else + ++numPrunes; - const double leftScore = rule.Score(*queryNode.Left(), referenceNode); + // Before recursing, we have to set the traversal information correctly. + rule.TraversalInfo() = traversalInfo; + + const double innerScore = rule.Score(*queryNode.Inner(), referenceNode); ++numScores; - if (leftScore != DBL_MAX) - Traverse(*queryNode.Left(), referenceNode); + if (innerScore != DBL_MAX) + Traverse(*queryNode.Inner(), referenceNode); else ++numPrunes; // Before recursing, we have to set the traversal information correctly. rule.TraversalInfo() = traversalInfo; - const double rightScore = rule.Score(*queryNode.Right(), referenceNode); + const double outerScore = rule.Score(*queryNode.Outer(), referenceNode); ++numScores; - if (rightScore != DBL_MAX) - Traverse(*queryNode.Right(), referenceNode); + if (outerScore != DBL_MAX) + Traverse(*queryNode.Outer(), referenceNode); else ++numPrunes; } else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf())) { - // If the reference node contains a point we should calculate all - // base cases with this point. - if (referenceNode.FirstPointIsCentroid()) - { - const size_t queryEnd = queryNode.Begin() + queryNode.Count(); - for (size_t query = queryNode.Begin(); query < queryEnd; ++query) - rule.BaseCase(query, referenceNode.Point(0)); - numBaseCases += queryNode.Count(); - } // We have to recurse down the reference node. In this case the recursion - // order does matter. Before recursing, though, we have to set the - // traversal information correctly. - double leftScore = rule.Score(queryNode, *referenceNode.Left()); - typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo(); - rule.TraversalInfo() = traversalInfo; - double rightScore = rule.Score(queryNode, *referenceNode.Right()); - numScores += 2; - - if (leftScore < rightScore) - { - // Recurse to the left. 
Restore the left traversal info. Store the right - // traversal info. - traversalInfo = rule.TraversalInfo(); - rule.TraversalInfo() = leftInfo; - Traverse(queryNode, *referenceNode.Left()); - - // Is it still valid to recurse to the right? - rightScore = rule.Rescore(queryNode, *referenceNode.Right(), rightScore); - - if (rightScore != DBL_MAX) - { - // Restore the right traversal info. - rule.TraversalInfo() = traversalInfo; - Traverse(queryNode, *referenceNode.Right()); - } - else - ++numPrunes; - } - else if (rightScore < leftScore) - { - // Recurse to the right. - Traverse(queryNode, *referenceNode.Right()); - - // Is it still valid to recurse to the left? - leftScore = rule.Rescore(queryNode, *referenceNode.Left(), leftScore); - - if (leftScore != DBL_MAX) - { - // Restore the left traversal info. - rule.TraversalInfo() = leftInfo; - Traverse(queryNode, *referenceNode.Left()); - } - else - ++numPrunes; - } - else // leftScore is equal to rightScore. - { - if (leftScore == DBL_MAX) - { - numPrunes += 2; - } - else - { - // Choose the left first. Restore the left traversal info. Store the - // right traversal info. - traversalInfo = rule.TraversalInfo(); - rule.TraversalInfo() = leftInfo; - Traverse(queryNode, *referenceNode.Left()); - - rightScore = rule.Rescore(queryNode, *referenceNode.Right(), - rightScore); - - if (rightScore != DBL_MAX) - { - // Restore the right traversal info. - rule.TraversalInfo() = traversalInfo; - Traverse(queryNode, *referenceNode.Right()); - } - else - ++numPrunes; - } - } + // order does matter. + TraverseReferenceNode(queryNode, referenceNode); } else { - // If the reference node contains a point we should calculate all - // base cases with this point. - if (referenceNode.FirstPointIsCentroid()) - { - for (size_t i = 0; i < queryNode.NumDescendants(); ++i) - rule.BaseCase(queryNode.Descendant(i), referenceNode.Point(0)); - numBaseCases += queryNode.NumDescendants(); - } // We have to recurse down both query and reference nodes. 
Because the - // query descent order does not matter, we will go to the left query child - // first. Before recursing, we have to set the traversal information + // query descent order does not matter, we will go to the central query + // child first. Before recursing, we have to set the traversal information // correctly. - double leftScore; - typename RuleType::TraversalInfoType leftInfo; - double rightScore; - typename RuleType::TraversalInfoType rightInfo; - - if (queryNode.FirstPointIsCentroid()) - { - leftScore = rule.Score(queryNode.Point(0), *referenceNode.Left()); - leftInfo = rule.TraversalInfo(); - rule.TraversalInfo() = traversalInfo; - rightScore = rule.Score(queryNode.Point(0), *referenceNode.Right()); - numScores += 2; - - if (leftScore < rightScore) - { - // Recurse to the left. Restore the left traversal info. Store the right - // traversal info. - rightInfo = rule.TraversalInfo(); - rule.TraversalInfo() = leftInfo; - Traverse(queryNode.Point(0), *referenceNode.Left()); - - // Is it still valid to recurse to the right? - rightScore = rule.Rescore(queryNode.Point(0), *referenceNode.Right(), - rightScore); - - if (rightScore != DBL_MAX) - { - // Restore the right traversal info. - rule.TraversalInfo() = rightInfo; - Traverse(queryNode.Point(0), *referenceNode.Right()); - } - else - ++numPrunes; - } - else if (rightScore < leftScore) - { - // Recurse to the right. - Traverse(queryNode.Point(0), *referenceNode.Right()); - - // Is it still valid to recurse to the left? - leftScore = rule.Rescore(queryNode.Point(0), *referenceNode.Left(), - leftScore); - - if (leftScore != DBL_MAX) - { - // Restore the left traversal info. - rule.TraversalInfo() = leftInfo; - Traverse(queryNode.Point(0), *referenceNode.Left()); - } - else - ++numPrunes; - } - else - { - if (leftScore == DBL_MAX) - { - numPrunes += 2; - } - else - { - // Choose the left first. Restore the left traversal info and store the - // right traversal info. 
- rightInfo = rule.TraversalInfo(); - rule.TraversalInfo() = leftInfo; - Traverse(queryNode.Point(0), *referenceNode.Left()); - - // Is it still valid to recurse to the right? - rightScore = rule.Rescore(queryNode.Point(0), *referenceNode.Right(), - rightScore); - - if (rightScore != DBL_MAX) - { - // Restore the right traversal information. - rule.TraversalInfo() = rightInfo; - Traverse(queryNode.Point(0), *referenceNode.Right()); - } - else - ++numPrunes; - } - } - - // Restore the main traversal information. - rule.TraversalInfo() = traversalInfo; - } + typename RuleType::TraversalInfoType traversalInfo; + + // We have to calculate the base case with the central reference node. + // All children of a vantage point tree node use the same traversal + // information. + traversalInfo.LastQueryNode() = queryNode.Central(); + traversalInfo.LastReferenceNode() = referenceNode.Central(); + traversalInfo.LastBaseCase() = rule.BaseCase(queryNode.Central()->Point(0), + referenceNode.Central()->Point(0)); + traversalInfo.LastScore() = traversalInfo.LastBaseCase(); + numBaseCases++; - // Now recurse down the left node. - leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left()); - leftInfo = rule.TraversalInfo(); rule.TraversalInfo() = traversalInfo; - rightScore = rule.Score(*queryNode.Left(), *referenceNode.Right()); - numScores += 2; - if (leftScore < rightScore) - { - // Recurse to the left. Restore the left traversal info. Store the right - // traversal info. - rightInfo = rule.TraversalInfo(); - rule.TraversalInfo() = leftInfo; - Traverse(*queryNode.Left(), *referenceNode.Left()); - - // Is it still valid to recurse to the right? - rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(), - rightScore); - - if (rightScore != DBL_MAX) - { - // Restore the right traversal info. 
- rule.TraversalInfo() = rightInfo; - Traverse(*queryNode.Left(), *referenceNode.Right()); - } - else - ++numPrunes; - } - else if (rightScore < leftScore) - { - // Recurse to the right. - Traverse(*queryNode.Left(), *referenceNode.Right()); - - // Is it still valid to recurse to the left? - leftScore = rule.Rescore(*queryNode.Left(), *referenceNode.Left(), - leftScore); - - if (leftScore != DBL_MAX) - { - // Restore the left traversal info. - rule.TraversalInfo() = leftInfo; - Traverse(*queryNode.Left(), *referenceNode.Left()); - } - else - ++numPrunes; - } - else - { - if (leftScore == DBL_MAX) - { - numPrunes += 2; - } - else - { - // Choose the left first. Restore the left traversal info and store the - // right traversal info. - rightInfo = rule.TraversalInfo(); - rule.TraversalInfo() = leftInfo; - Traverse(*queryNode.Left(), *referenceNode.Left()); - - // Is it still valid to recurse to the right? - rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(), - rightScore); - - if (rightScore != DBL_MAX) - { - // Restore the right traversal information. - rule.TraversalInfo() = rightInfo; - Traverse(*queryNode.Left(), *referenceNode.Right()); - } - else - ++numPrunes; - } - } + // Now recurse down the central node. + TraverseReferenceNode(*queryNode.Central(), referenceNode); // Restore the main traversal information. rule.TraversalInfo() = traversalInfo; - // Now recurse down the right query node. - leftScore = rule.Score(*queryNode.Right(), *referenceNode.Left()); - leftInfo = rule.TraversalInfo(); + // Now recurse down the inner node. + TraverseReferenceNode(*queryNode.Inner(), referenceNode); + + // Restore the main traversal information. rule.TraversalInfo() = traversalInfo; - rightScore = rule.Score(*queryNode.Right(), *referenceNode.Right()); - numScores += 2; - if (leftScore < rightScore) - { - // Recurse to the left. Restore the left traversal info. Store the right - // traversal info. 
- rightInfo = rule.TraversalInfo(); - rule.TraversalInfo() = leftInfo; - Traverse(*queryNode.Right(), *referenceNode.Left()); - - // Is it still valid to recurse to the right? - rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(), - rightScore); - - if (rightScore != DBL_MAX) - { - // Restore the right traversal info. - rule.TraversalInfo() = rightInfo; - Traverse(*queryNode.Right(), *referenceNode.Right()); - } - else - ++numPrunes; - } - else if (rightScore < leftScore) - { - // Recurse to the right. - Traverse(*queryNode.Right(), *referenceNode.Right()); - - // Is it still valid to recurse to the left? - leftScore = rule.Rescore(*queryNode.Right(), *referenceNode.Left(), - leftScore); - - if (leftScore != DBL_MAX) - { - // Restore the left traversal info. - rule.TraversalInfo() = leftInfo; - Traverse(*queryNode.Right(), *referenceNode.Left()); - } - else - ++numPrunes; - } - else - { - if (leftScore == DBL_MAX) - { - numPrunes += 2; - } - else - { - // Choose the left first. Restore the left traversal info. Store the - // right traversal info. - rightInfo = rule.TraversalInfo(); - rule.TraversalInfo() = leftInfo; - Traverse(*queryNode.Right(), *referenceNode.Left()); - - // Is it still valid to recurse to the right? - rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(), - rightScore); - - if (rightScore != DBL_MAX) - { - // Restore the right traversal info. - rule.TraversalInfo() = rightInfo; - Traverse(*queryNode.Right(), *referenceNode.Right()); - } - else - ++numPrunes; - } - } + // Now recurse down the outer query node. + TraverseReferenceNode(*queryNode.Outer(), referenceNode); } } @@ -462,104 +170,68 @@ template template void VantagePointTree:: -DualTreeTraverser::Traverse( - const size_t queryIndex, +DualTreeTraverser::TraverseReferenceNode( + VantagePointTree& + queryNode, VantagePointTree& referenceNode) { - // If we are a leaf, run the base case as necessary. 
- if (referenceNode.IsLeaf()) - { - const size_t refEnd = referenceNode.Begin() + referenceNode.Count(); - for (size_t i = referenceNode.Begin(); i < refEnd; ++i) - rule.BaseCase(queryIndex, i); - numBaseCases += referenceNode.Count(); - return; - } + typedef VantagePointTree TreeType; + typedef typename RuleType::TraversalInfoType TravInfo; - // If the reference node contains a point we should calculate the base case. - if (referenceNode.FirstPointIsCentroid()) - { - rule.BaseCase(queryIndex, referenceNode.Point(0)); - numBaseCases++; - } + // We have to recurse down the reference node. In this case the recursion + // order does matter. Before recursing, though, we have to set the + // traversal information correctly. - // Store the current traversal info. - traversalInfo = rule.TraversalInfo(); + std::array, 3> scores; - // If either score is DBL_MAX, we do not recurse into that node. - double leftScore = rule.Score(queryIndex, *referenceNode.Left()); - typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo(); - rule.TraversalInfo() = traversalInfo; - double rightScore = rule.Score(queryIndex, *referenceNode.Right()); - typename RuleType::TraversalInfoType rightInfo; + double score = rule.Score(queryNode, *referenceNode.Central()); + scores[0] = std::make_tuple(score, referenceNode.Central()); - if (leftScore < rightScore) - { - rightInfo = rule.TraversalInfo(); - rule.TraversalInfo() = leftInfo; - // Recurse to the left. - Traverse(queryIndex, *referenceNode.Left()); + // All children of a vantage point tree use the same traversal info. + TravInfo traversalInfo = rule.TraversalInfo(); - // Is it still valid to recurse to the right? - rightScore = rule.Rescore(queryIndex, *referenceNode.Right(), rightScore); + score = rule.Score(queryNode, *referenceNode.Inner()); + scores[1] = std::make_tuple(score, referenceNode.Inner()); - if (rightScore != DBL_MAX) - { - // Restore the right traversal info. 
- rule.TraversalInfo() = rightInfo; - Traverse(queryIndex, *referenceNode.Right()); // Recurse to the right. - } - else - ++numPrunes; - } - else if (rightScore < leftScore) - { - // Recurse to the right. - Traverse(queryIndex, *referenceNode.Right()); + score = rule.Score(queryNode, *referenceNode.Outer()); + scores[2] = std::make_tuple(score, referenceNode.Outer()); + numScores += 3; - // Is it still valid to recurse to the left? - leftScore = rule.Rescore(queryIndex, *referenceNode.Left(), leftScore); + // Sort the array according to the score. + if (std::get<0>(scores[0]) > std::get<0>(scores[1])) + std::swap(scores[0], scores[1]); + if (std::get<0>(scores[1]) > std::get<0>(scores[2])) + std::swap(scores[1], scores[2]); + if (std::get<0>(scores[0]) > std::get<0>(scores[1])) + std::swap(scores[0], scores[1]); - if (leftScore != DBL_MAX) - { - // Restore the left traversal info. - rule.TraversalInfo() = leftInfo; - Traverse(queryIndex, *referenceNode.Left()); // Recurse to the left. - } - else - ++numPrunes; - } - else // leftScore is equal to rightScore. + for (size_t i = 0; i < 3; i++) { - if (leftScore == DBL_MAX) + if (std::get<0>(scores[i]) == DBL_MAX) { - numPrunes += 2; // Pruned both left and right. + numPrunes += 3 - i; + break; } - else + + // Is it still valid to recurse to the node? + double rescore = 0; + if (i > 0) + rescore = rule.Rescore(queryNode, *std::get<1>(scores[i]), + std::get<0>(scores[i])); + + if (rescore != DBL_MAX) { - // Choose the left first. - rightInfo = rule.TraversalInfo(); - rule.TraversalInfo() = leftInfo; - Traverse(queryIndex, *referenceNode.Left()); - - // Is it still valid to recurse to the right? - rightScore = rule.Rescore(queryIndex, *referenceNode.Right(), - rightScore); - - if (rightScore != DBL_MAX) - { - // Restore the right traversal info. - rule.TraversalInfo() = rightInfo; - Traverse(queryIndex, *referenceNode.Right()); - } - else - ++numPrunes; + // Restore the traversal info. 
+ rule.TraversalInfo() = traversalInfo; + Traverse(queryNode, *std::get<1>(scores[i])); } + else + numPrunes++; } } - } // namespace tree } // namespace mlpack diff --git a/src/mlpack/core/tree/vantage_point_tree/single_tree_traverser_impl.hpp b/src/mlpack/core/tree/vantage_point_tree/single_tree_traverser_impl.hpp index 149715fe887..b1dca54f816 100644 --- a/src/mlpack/core/tree/vantage_point_tree/single_tree_traverser_impl.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/single_tree_traverser_impl.hpp @@ -51,57 +51,55 @@ SingleTreeTraverser::Traverse( return; } - // If the reference node contains a point we should calculate the base case. - if (referenceNode.FirstPointIsCentroid()) - rule.BaseCase(queryIndex, referenceNode.Point(0)); + rule.BaseCase(queryIndex, referenceNode.Central()->Point(0)); // If either score is DBL_MAX, we do not recurse into that node. - double leftScore = rule.Score(queryIndex, *referenceNode.Left()); - double rightScore = rule.Score(queryIndex, *referenceNode.Right()); + double innerScore = rule.Score(queryIndex, *referenceNode.Inner()); + double outerScore = rule.Score(queryIndex, *referenceNode.Outer()); - if (leftScore < rightScore) + if (innerScore < outerScore) { - // Recurse to the left. - Traverse(queryIndex, *referenceNode.Left()); + // Recurse to the inner node. + Traverse(queryIndex, *referenceNode.Inner()); - // Is it still valid to recurse to the right? - rightScore = rule.Rescore(queryIndex, *referenceNode.Right(), rightScore); + // Is it still valid to recurse to the outer node? + outerScore = rule.Rescore(queryIndex, *referenceNode.Outer(), outerScore); - if (rightScore != DBL_MAX) - Traverse(queryIndex, *referenceNode.Right()); // Recurse to the right. + if (outerScore != DBL_MAX) + Traverse(queryIndex, *referenceNode.Outer()); // Recurse to the outer. else ++numPrunes; } - else if (rightScore < leftScore) + else if (outerScore < innerScore) { - // Recurse to the right. 
- Traverse(queryIndex, *referenceNode.Right()); + // Recurse to the outer node. + Traverse(queryIndex, *referenceNode.Outer()); - // Is it still valid to recurse to the left? - leftScore = rule.Rescore(queryIndex, *referenceNode.Left(), leftScore); + // Is it still valid to recurse to the inner node? + innerScore = rule.Rescore(queryIndex, *referenceNode.Inner(), innerScore); - if (leftScore != DBL_MAX) - Traverse(queryIndex, *referenceNode.Left()); // Recurse to the left. + if (innerScore != DBL_MAX) + Traverse(queryIndex, *referenceNode.Inner()); // Recurse to the inner. else ++numPrunes; } - else // leftScore is equal to rightScore. + else // innerScore is equal to outerScore. { - if (leftScore == DBL_MAX) + if (innerScore == DBL_MAX) { - numPrunes += 2; // Pruned both left and right. + numPrunes += 2; // Pruned both inner and outer nodes. } else { - // Choose the left first. - Traverse(queryIndex, *referenceNode.Left()); + // Choose the inner node first. + Traverse(queryIndex, *referenceNode.Inner()); - // Is it still valid to recurse to the right? - rightScore = rule.Rescore(queryIndex, *referenceNode.Right(), - rightScore); + // Is it still valid to recurse to the outer node? + outerScore = rule.Rescore(queryIndex, *referenceNode.Outer(), + outerScore); - if (rightScore != DBL_MAX) - Traverse(queryIndex, *referenceNode.Right()); + if (outerScore != DBL_MAX) + Traverse(queryIndex, *referenceNode.Outer()); else ++numPrunes; } diff --git a/src/mlpack/core/tree/vantage_point_tree/traits.hpp b/src/mlpack/core/tree/vantage_point_tree/traits.hpp index ab844d7ec12..1b5da5c6549 100644 --- a/src/mlpack/core/tree/vantage_point_tree/traits.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/traits.hpp @@ -33,15 +33,15 @@ class TreeTraits* node) - { - return node->FirstPointIsCentroid(); - } + static const bool FirstPointIsCentroid = false; + + /** + * The first point of the central node (vantage point) is the centroid of + * its siblings. 
+ */ + static const bool FirstSiblingFirstPointIsCentroid = true; /** * Points are not contained at multiple levels of the vantage point tree. diff --git a/src/mlpack/core/tree/vantage_point_tree/typedef.hpp b/src/mlpack/core/tree/vantage_point_tree/typedef.hpp index cadde7fa1eb..4af223766d5 100644 --- a/src/mlpack/core/tree/vantage_point_tree/typedef.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/typedef.hpp @@ -14,26 +14,30 @@ namespace mlpack { namespace tree { /** - * The vantage point tree is a kind of the binary space tree. In contrast to - * BinarySpaceTree, each left intermediate node of the vantage point tree - * contains a point which is the centroid of the node. When recursively - * splitting nodes, the VPTree class selects a vantage point and splits the node - * according to the distance to this point. Thus, points that are closer to the - * vantage point form the left subtree (and the vantage point is the only point - * that the left node contains). Other points form the right subtree. - * In such a way, the bound of each left node is a ball and the vantage point is - * the centroid of the bound. The bound of each right node is a hollow ball - * centered at the vantage point. + * The vantage point tree (which is also called the metric tree; vantage point + * trees and metric trees were invented independently by Yianilos and Uhlmann) is + * a kind of the binary space tree. In contrast to BinarySpaceTree, each + * intermediate node of the vantage point tree contains three children. + * The first child contains exactly one point (the vantage point). When + * recursively splitting nodes, the VPTree class selects the vantage point and + * splits the node according to the distance to this point. Thus, points that + * are closer to the vantage point form the inner subtree. Other points form the + * outer subtree. The vantage point is contained in the first (central) node. 
+ * In such a way, the bound of each inner and outer nodes is a hollow ball and + * the vantage point is the centroid of the bound. + * + * This implementation differs from the original algorithms. Namely, the central + * node was introduced in order to simplify dual-tree traversers. * - * For more information, see the following paper. + * For more information, see the following papers. * * @code * @inproceedings{yianilos1993vptrees, * author = {Yianilos, Peter N.}, * title = {Data Structures and Algorithms for Nearest Neighbor Search in - * General Metric Spaces}, + * General Metric Spaces}, * booktitle = {Proceedings of the Fourth Annual ACM-SIAM Symposium on - * Discrete Algorithms}, + * Discrete Algorithms}, * series = {SODA '93}, * year = {1993}, * isbn = {0-89871-313-7}, @@ -41,7 +45,18 @@ namespace tree { * numpages = {11}, * publisher = {Society for Industrial and Applied Mathematics}, * address = {Philadelphia, PA, USA} - * } + * } + * + * @article{uhlmann1991metrictrees, + * author = {Jeffrey K. Uhlmann}, + * title = {Satisfying general proximity / similarity queries with metric + * trees}, + * journal = {Information Processing Letters}, + * volume = {40}, + * number = {4}, + * pages = {175 - 179}, + * year = {1991}, + * } * @endcode * * This template typedef satisfies the TreeType policy API. diff --git a/src/mlpack/core/tree/vantage_point_tree/vantage_point_split.hpp b/src/mlpack/core/tree/vantage_point_tree/vantage_point_split.hpp index 66456b37fdb..4f9a9871eb2 100644 --- a/src/mlpack/core/tree/vantage_point_tree/vantage_point_split.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/vantage_point_split.hpp @@ -2,8 +2,8 @@ * @file vantage_point_split.hpp * @author Mikhail Lozhnikov * - * Definition of class VantagePointSplit, a class that splits a binary space - * partitioning into two parts using the distance to a certain vantage point. 
+ * Definition of class VantagePointSplit, a class that splits a vantage point + * tree into two parts using the distance to a certain vantage point. */ #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_HPP #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_HPP diff --git a/src/mlpack/core/tree/vantage_point_tree/vantage_point_split_impl.hpp b/src/mlpack/core/tree/vantage_point_tree/vantage_point_split_impl.hpp index 93ec45605e1..6e5b1a24100 100644 --- a/src/mlpack/core/tree/vantage_point_tree/vantage_point_split_impl.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/vantage_point_split_impl.hpp @@ -2,7 +2,7 @@ * @file vantage_point_split_impl.hpp * @author Mikhail Lozhnikov * - * Implementation of class (VantagePointSplit) to split a binary space partition + * Implementation of class (VantagePointSplit) to split a vantage point * tree according to the median value of the distance to a certain vantage point. */ #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_IMPL_HPP diff --git a/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree.hpp b/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree.hpp index 5737e6eaaa9..16aef9ea4d6 100644 --- a/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree.hpp @@ -13,10 +13,12 @@ namespace mlpack { namespace tree /** Trees and tree-building procedures. */ { /** - * The vantage point tree is a variant of a binary space tree. The difference - * from BinarySpaceTree is a presence of points in intermediate nodes. - * If an intermediate node holds a point, this point is the centroid of the - * bound. + * The vantage point tree (vantage point trees are also called metric trees) + * is a variant of a binary space tree. In contrast to BinarySpaceTree, + * each intermediate node of the vantage point tree contains exactly three + * children. The first child contains only one point (the vantage point). 
+ * Two other children (inner and outer) may be leaf or intermediate nodes. + * Thus, the point of the first (central) child is the centroid of its siblings. * * This particular tree does not allow growth, so you cannot add or delete nodes * from it. If you need to add or delete a node, the better procedure is to @@ -24,8 +26,8 @@ namespace tree /** Trees and tree-building procedures. */ { * * This tree does take one runtime parameter in the constructor, which is the * max leaf size to be used. - - * @tparam MetricType The metric used for tree-building. The BoundType may + * + * @tparam MetricType The metric used for tree-building. The BoundType may * place restrictions on the metrics that can be used. * @tparam StatisticType Extra data contained in the node. See statistic.hpp * for the necessary skeleton interface. @@ -52,10 +54,12 @@ class VantagePointTree typedef typename MatType::elem_type ElemType; private: - //! The left child node. - VantagePointTree* left; - //! The right child node. - VantagePointTree* right; + //! The child node that contains only the vantage point. + VantagePointTree* central; + //! The inner child node. + VantagePointTree* inner; + //! The outer child node. + VantagePointTree* outer; //! The parent node (NULL if this is the root of the tree). VantagePointTree* parent; //! The index of the first point in the dataset contained in this node. @@ -194,15 +198,12 @@ class VantagePointTree * @param begin Index of point to start tree construction with. * @param count Number of points to use to construct tree. * @param maxLeafSize Size of each leaf in the tree. - * @param firstPointIsCentroid Indicates that the first point of the node is - * the centroid of its bound. 
*/ VantagePointTree(VantagePointTree* parent, const size_t begin, const size_t count, SplitType, MatType>& splitter, - const size_t maxLeafSize = 20, - bool firstPointIsCentroid = false); + const size_t maxLeafSize = 20); /** * Construct this node as a child of the given parent, starting at column @@ -221,16 +222,13 @@ class VantagePointTree * @param oldFromNew Vector which will be filled with the old positions for * each new point. * @param maxLeafSize Size of each leaf in the tree. - * @param firstPointIsCentroid Indicates that the first point of the node is - * the centroid of its bound. */ VantagePointTree(VantagePointTree* parent, const size_t begin, const size_t count, std::vector& oldFromNew, SplitType, MatType>& splitter, - const size_t maxLeafSize = 20, - bool firstPointIsCentroid = false); + const size_t maxLeafSize = 20); /** * Construct this node as a child of the given parent, starting at column @@ -252,8 +250,6 @@ class VantagePointTree * @param newFromOld Vector which will be filled with the new positions for * each old point. * @param maxLeafSize Size of each leaf in the tree. - * @param firstPointIsCentroid Indicates that the first point of the node is - * the centroid of its bound. */ VantagePointTree(VantagePointTree* parent, const size_t begin, @@ -261,8 +257,7 @@ class VantagePointTree std::vector& oldFromNew, std::vector& newFromOld, SplitType, MatType>& splitter, - const size_t maxLeafSize = 20, - bool firstPointIsCentroid = false); + const size_t maxLeafSize = 20); /** * Create a vantage point tree by copying the other tree. Be careful! This @@ -308,15 +303,21 @@ class VantagePointTree //! Return whether or not this node is a leaf (true if it has no children). bool IsLeaf() const; - //! Gets the left child of this node. - VantagePointTree* Left() const { return left; } - //! Modify the left child of this node. - VantagePointTree*& Left() { return left; } + //! Gets the central child of this node. 
+ VantagePointTree* Central() const { return central; } + //! Modify the central child of this node. + VantagePointTree*& Central() { return central; } + - //! Gets the right child of this node. - VantagePointTree* Right() const { return right; } - //! Modify the right child of this node. - VantagePointTree*& Right() { return right; } + //! Gets the inner child of this node. + VantagePointTree* Inner() const { return inner; } + //! Modify the inner child of this node. + VantagePointTree*& Inner() { return inner; } + + //! Gets the outer child of this node. + VantagePointTree* Outer() const { return outer; } + //! Modify the outer child of this node. + VantagePointTree*& Outer() { return outer; } //! Gets the parent of this node. VantagePointTree* Parent() const { return parent; } @@ -369,7 +370,7 @@ class VantagePointTree VantagePointTree& Child(const size_t child) const; VantagePointTree*& ChildPtr(const size_t child) - { return (child == 0) ? left : right; } + { return (child == 0) ? central : (child == 1 ? inner : outer); } //! Return the number of points in this node. size_t NumPoints() const; @@ -459,9 +460,6 @@ class VantagePointTree //! Store the center of the bounding region in the given vector. void Center(arma::vec& center) { bound.Center(center); } - //! Indicates that the first point of this node is the centroid of its bound. - bool FirstPointIsCentroid() const { return firstPointIsCentroid; } - private: /** * Splits the current node, assigning its left and right children recursively. 
diff --git a/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree_impl.hpp b/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree_impl.hpp index 587305d9b20..ffba3ec2a64 100644 --- a/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree_impl.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree_impl.hpp @@ -23,8 +23,9 @@ VantagePointTree:: VantagePointTree( const MatType& data, const size_t maxLeafSize) : - left(NULL), - right(NULL), + central(NULL), + inner(NULL), + outer(NULL), parent(NULL), begin(0), /* This root node starts at index 0, */ count(data.n_cols), /* and spans all of the dataset. */ @@ -52,8 +53,9 @@ VantagePointTree( const MatType& data, std::vector& oldFromNew, const size_t maxLeafSize) : - left(NULL), - right(NULL), + central(NULL), + inner(NULL), + outer(NULL), parent(NULL), begin(0), count(data.n_cols), @@ -87,8 +89,9 @@ VantagePointTree( std::vector& oldFromNew, std::vector& newFromOld, const size_t maxLeafSize) : - left(NULL), - right(NULL), + central(NULL), + inner(NULL), + outer(NULL), parent(NULL), begin(0), count(data.n_cols), @@ -123,8 +126,9 @@ template VantagePointTree:: VantagePointTree(MatType&& data, const size_t maxLeafSize) : - left(NULL), - right(NULL), + central(NULL), + inner(NULL), + outer(NULL), parent(NULL), begin(0), count(data.n_cols), @@ -152,8 +156,9 @@ VantagePointTree( MatType&& data, std::vector& oldFromNew, const size_t maxLeafSize) : - left(NULL), - right(NULL), + central(NULL), + inner(NULL), + outer(NULL), parent(NULL), begin(0), count(data.n_cols), @@ -187,8 +192,9 @@ VantagePointTree( std::vector& oldFromNew, std::vector& newFromOld, const size_t maxLeafSize) : - left(NULL), - right(NULL), + central(NULL), + inner(NULL), + outer(NULL), parent(NULL), begin(0), count(data.n_cols), @@ -227,16 +233,15 @@ VantagePointTree( const size_t begin, const size_t count, SplitType, MatType>& splitter, - const size_t maxLeafSize, - bool firstPointIsCentroid) : - left(NULL), - right(NULL), + const 
size_t maxLeafSize) : + central(NULL), + inner(NULL), + outer(NULL), parent(parent), begin(begin), count(count), bound(parent->Dataset().n_rows), - dataset(&parent->Dataset()), // Point to the parent's dataset. - firstPointIsCentroid(firstPointIsCentroid) + dataset(&parent->Dataset()) // Point to the parent's dataset. { // Perform the actual splitting. SplitNode(maxLeafSize, splitter); @@ -258,16 +263,15 @@ VantagePointTree( const size_t count, std::vector& oldFromNew, SplitType, MatType>& splitter, - const size_t maxLeafSize, - bool firstPointIsCentroid) : - left(NULL), - right(NULL), + const size_t maxLeafSize) : + central(NULL), + inner(NULL), + outer(NULL), parent(parent), begin(begin), count(count), bound(parent->Dataset().n_rows), - dataset(&parent->Dataset()), - firstPointIsCentroid(firstPointIsCentroid) + dataset(&parent->Dataset()) { // Hopefully the vector is initialized correctly! We can't check that // entirely but we can do a minor sanity check. @@ -294,16 +298,15 @@ VantagePointTree( std::vector& oldFromNew, std::vector& newFromOld, SplitType, MatType>& splitter, - const size_t maxLeafSize, - bool firstPointIsCentroid) : - left(NULL), - right(NULL), + const size_t maxLeafSize) : + central(NULL), + inner(NULL), + outer(NULL), parent(parent), begin(begin), count(count), bound(parent->Dataset()->n_rows), - dataset(&parent->Dataset()), - firstPointIsCentroid(firstPointIsCentroid) + dataset(&parent->Dataset()) { // Hopefully the vector is initialized correctly! We can't check that // entirely but we can do a minor sanity check. @@ -334,8 +337,9 @@ template:: VantagePointTree( const VantagePointTree& other) : - left(NULL), - right(NULL), + central(NULL), + inner(NULL), + outer(NULL), parent(other.parent), begin(other.begin), count(other.count), @@ -347,37 +351,47 @@ VantagePointTree( dataset((other.parent == NULL) ? new MatType(*other.dataset) : NULL), firstPointIsCentroid(other.firstPointIsCentroid) { - // Create left and right children (if any). 
- if (other.Left()) + // Create central, inner and outer children (if any). + if (other.Central()) { - left = new VantagePointTree(*other.Left()); - left->Parent() = this; // Set parent to this, not other tree. + central = new VantagePointTree(*other.Central()); + central->Parent() = this; // Set parent to this, not other tree. } - if (other.Right()) + if (other.Inner()) { - right = new VantagePointTree(*other.Right()); - right->Parent() = this; // Set parent to this, not other tree. + inner = new VantagePointTree(*other.Inner()); + inner->Parent() = this; // Set parent to this, not other tree. + } + + if (other.Outer()) + { + outer = new VantagePointTree(*other.Outer()); + outer->Parent() = this; // Set parent to this, not other tree. } // Propagate matrix, but only if we are the root. if (parent == NULL) { std::queue queue; - if (left) - queue.push(left); - if (right) - queue.push(right); + if (central) + queue.push(central); + if (inner) + queue.push(inner); + if (outer) + queue.push(outer); while (!queue.empty()) { VantagePointTree* node = queue.front(); queue.pop(); node->dataset = dataset; - if (node->left) - queue.push(node->left); - if (node->right) - queue.push(node->right); + if (node->central) + queue.push(node->central); + if (node->inner) + queue.push(node->inner); + if (node->outer) + queue.push(node->outer); } } } @@ -393,8 +407,9 @@ template VantagePointTree:: VantagePointTree(VantagePointTree&& other) : - left(other.left), - right(other.right), + central(other.central), + inner(other.inner), + outer(other.outer), parent(other.parent), begin(other.begin), count(other.count), @@ -408,8 +423,9 @@ VantagePointTree(VantagePointTree&& other) : { // Now we are a clone of the other tree. But we must also clear the other // tree's contents, so it doesn't delete anything when it is destructed. 
- other.left = NULL; - other.right = NULL; + other.central = NULL; + other.inner = NULL; + other.outer = NULL; other.begin = 0; other.count = 0; other.parentDistance = 0.0; @@ -453,8 +469,9 @@ template:: ~VantagePointTree() { - delete left; - delete right; + delete central; + delete inner; + delete outer; // If we're the root, delete the matrix. if (!parent) @@ -470,7 +487,7 @@ template::IsLeaf() const { - return !left; + return !central; } /** @@ -485,9 +502,11 @@ template::NumChildren() const { - if (left && right) + if (central && inner && outer) + return 3; + if (central && inner) return 2; - if (left) + if (central) return 1; return 0; @@ -569,9 +588,11 @@ inline VantagePointTree::Child(const size_t child) const { if (child == 0) - return *left; + return *central; + else if(child == 1) + return *inner; else - return *right; + return *outer; } /** @@ -586,11 +607,7 @@ template::NumPoints() const { - // Each left intermediate node contains exactly one point. - // Each right intermediate node contains no points. - if (firstPointIsCentroid && left) - return 1; - else if(left) + if (!IsLeaf()) return 0; // This is a leaf node. @@ -655,8 +672,7 @@ void VantagePointTree: // We need to expand the bounds of this node properly. if (parent) { - bound.Center() = parent->firstPointIsCentroid ? - dataset->col(parent->begin + 1) : dataset->col(parent->begin); + bound.Center() = dataset->col(parent->begin); bound.OuterRadius() = 0; bound.InnerRadius() = std::numeric_limits::max(); } @@ -669,7 +685,7 @@ void VantagePointTree: while (tree->Parent() != NULL) { tree->Parent()->Bound() |= tree->Bound(); - tree->Parent()->furthestDescendantDistance = 0.5 * + tree->Parent()->furthestDescendantDistance = 0.5 * tree->Parent()->Bound().Diameter(); tree = tree->Parent(); } @@ -681,23 +697,14 @@ void VantagePointTree: return; // We can't split this. // splitCol denotes the two partitions of the dataset after the split. 
The - // points on its left go to the left child and the others go to the right + // points on its left go to the inner child and the others go to the outer // child. size_t splitCol; // Split the node. The elements of 'data' are reordered by the splitting // algorithm. This function call updates splitCol. - size_t splitBegin = begin; - size_t splitCount = count; - - if (FirstPointIsCentroid()) - { - splitBegin = begin + 1; - splitCount = count - 1; - } - - const bool split = splitter.SplitNode(bound, *dataset, splitBegin, splitCount, + const bool split = splitter.SplitNode(bound, *dataset, begin, count, splitCol); // The node may not be always split. For instance, if all the points are the @@ -707,23 +714,30 @@ void VantagePointTree: // Now that we know the split column, we will recursively split the children // by calling their constructors (which perform this splitting process). - left = new VantagePointTree(this, splitBegin, splitCol - splitBegin, splitter, - maxLeafSize, true); - right = new VantagePointTree(this, splitCol, - splitBegin + splitCount - splitCol, splitter, maxLeafSize, false); + central = new VantagePointTree(this, begin, 1, splitter, maxLeafSize); + inner = new VantagePointTree(this, begin + 1, splitCol - begin - 1, + splitter, maxLeafSize); + outer = new VantagePointTree(this, splitCol, begin + count - splitCol, + splitter, maxLeafSize); - // Calculate parent distances for those two nodes. - arma::vec center, leftCenter, rightCenter; - Center(center); - left->Center(leftCenter); - right->Center(rightCenter); + // Calculate parent distances for those three nodes. 
- const ElemType leftParentDistance = MetricType::Evaluate(center, leftCenter); - const ElemType rightParentDistance = MetricType::Evaluate(center, - rightCenter); + ElemType parentDistance; - left->ParentDistance() = leftParentDistance; - right->ParentDistance() = rightParentDistance; + if (parent) + parentDistance = MetricType::Evaluate(dataset->col(parent->begin), + dataset->col(begin)); + else + { + arma::vec center; + Center(center); + + parentDistance = MetricType::Evaluate(center, dataset->col(begin)); + } + + central->ParentDistance() = parentDistance; + inner->ParentDistance() = parentDistance; + outer->ParentDistance() = parentDistance; } template& oldFromNew, if (parent) { - bound.Center() = parent->firstPointIsCentroid ? - dataset->col(parent->begin + 1) : dataset->col(parent->begin); + bound.Center() = dataset->col(parent->begin); bound.OuterRadius() = 0; bound.InnerRadius() = std::numeric_limits::max(); } @@ -755,7 +768,7 @@ SplitNode(std::vector& oldFromNew, while (tree->Parent() != NULL) { tree->Parent()->Bound() |= tree->Bound(); - tree->Parent()->furthestDescendantDistance = 0.5 * + tree->Parent()->furthestDescendantDistance = 0.5 * tree->Parent()->Bound().Diameter(); tree = tree->Parent(); } @@ -768,23 +781,14 @@ SplitNode(std::vector& oldFromNew, return; // We can't split this. // splitCol denotes the two partitions of the dataset after the split. The - // points on its left go to the left child and the others go to the right + // points on its left go to the inner child and the others go to the outer // child. size_t splitCol; // Split the node. The elements of 'data' are reordered by the splitting // algorithm. This function call updates splitCol and oldFromNew. 
- size_t splitBegin = begin; - size_t splitCount = count; - - if (FirstPointIsCentroid()) - { - splitBegin = begin + 1; - splitCount = count - 1; - } - - const bool split = splitter.SplitNode(bound, *dataset, splitBegin, splitCount, + const bool split = splitter.SplitNode(bound, *dataset, begin, count, splitCol, oldFromNew); // The node may not be always split. For instance, if all the points are the @@ -794,34 +798,31 @@ SplitNode(std::vector& oldFromNew, // Now that we know the split column, we will recursively split the children // by calling their constructors (which perform this splitting process). - left = new VantagePointTree(this, splitBegin, splitCol - splitBegin, - oldFromNew, splitter, maxLeafSize, true); - right = new VantagePointTree(this, splitCol, - splitBegin + splitCount - splitCol, oldFromNew, splitter, maxLeafSize, - false); + central = new VantagePointTree(this, begin, 1, oldFromNew, splitter, + maxLeafSize); + inner = new VantagePointTree(this, begin + 1, splitCol - begin - 1, + oldFromNew, splitter, maxLeafSize); + outer = new VantagePointTree(this, splitCol, begin + count - splitCol, + oldFromNew, splitter, maxLeafSize); // Calculate parent distances for those two nodes. 
ElemType parentDistance; - if (firstPointIsCentroid) - { - assert(left->firstPointIsCentroid == true); - parentDistance = MetricType::Evaluate(dataset->col(begin), - dataset->col(left->begin)); - } + if (parent) + parentDistance = MetricType::Evaluate(dataset->col(parent->begin), + dataset->col(begin)); else { arma::vec center; Center(center); - assert(left->firstPointIsCentroid == true); - - parentDistance = MetricType::Evaluate(center, dataset->col(left->begin)); + parentDistance = MetricType::Evaluate(center, dataset->col(begin)); } - left->ParentDistance() = parentDistance; - right->ParentDistance() = parentDistance; + central->ParentDistance() = parentDistance; + inner->ParentDistance() = parentDistance; + outer->ParentDistance() = parentDistance; } // Default constructor (private), for boost::serialization. @@ -833,8 +834,9 @@ template VantagePointTree:: VantagePointTree() : - left(NULL), - right(NULL), + central(NULL), + inner(NULL), + outer(NULL), parent(NULL), begin(0), count(0), @@ -864,10 +866,12 @@ void VantagePointTree: // If we're loading, and we have children, they need to be deleted. if (Archive::is_loading::value) { - if (left) - delete left; - if (right) - delete right; + if (central) + delete central; + if (inner) + delete inner; + if (outer) + delete outer; if (!parent) delete dataset; } @@ -882,9 +886,9 @@ void VantagePointTree: ar & CreateNVP(dataset, "dataset"); // Save children last; otherwise boost::serialization gets confused. - ar & CreateNVP(left, "left"); - ar & CreateNVP(right, "right"); - ar & CreateNVP(firstPointIsCentroid, "firstPointIsCentroid"); + ar & CreateNVP(central, "central"); + ar & CreateNVP(inner, "inner"); + ar & CreateNVP(outer, "outer"); // Due to quirks of boost::serialization, if a tree is saved as an object and // not a pointer, the first level of the tree will be duplicated on load. @@ -893,36 +897,54 @@ void VantagePointTree: // necessary. 
if (Archive::is_loading::value) { - // Get parents of left and right children, or, NULL, if they don't exist. - VantagePointTree* leftParent = left ? left->Parent() : NULL; - VantagePointTree* rightParent = right ? right->Parent() : NULL; + // Get parents of central, inner and outer children, or, NULL, + // if they don't exist. + VantagePointTree* centralParent = central ? central->Parent() : NULL; + VantagePointTree* innerParent = inner ? inner->Parent() : NULL; + VantagePointTree* outerParent = outer ? outer->Parent() : NULL; // Reassign parent links if necessary. - if (left && left->Parent() != this) - left->Parent() = this; - if (right && right->Parent() != this) - right->Parent() = this; + if (central && central->Parent() != this) + central->Parent() = this; + if (inner && inner->Parent() != this) + inner->Parent() = this; + if (outer && outer->Parent() != this) + outer->Parent() = this; + + // Do we need to delete the central parent? + if (centralParent != NULL && centralParent != this) + { + // Sever the duplicate parent's children. Ensure we don't delete the + // dataset, by faking the duplicated parent's parent (that is, we need to + // set the parent to something non-NULL; 'this' works). + centralParent->Parent() = this; + centralParent->Inner() = NULL; + centralParent->Outer() = NULL; + delete centralParent; + } - // Do we need to delete the left parent? - if (leftParent != NULL && leftParent != this) + // Do we need to delete the inner parent? + if (innerParent != NULL && innerParent != this && + innerParent != centralParent) { // Sever the duplicate parent's children. Ensure we don't delete the // dataset, by faking the duplicated parent's parent (that is, we need to // set the parent to something non-NULL; 'this' works). 
- leftParent->Parent() = this; - leftParent->Left() = NULL; - leftParent->Right() = NULL; - delete leftParent; + innerParent->Parent() = this; + innerParent->Inner() = NULL; + innerParent->Outer() = NULL; + delete innerParent; } - // Do we need to delete the right parent? - if (rightParent != NULL && rightParent != this && rightParent != leftParent) + // Do we need to delete the outer parent? + if (outerParent != NULL && outerParent != this && + outerParent != centralParent) { // Sever the duplicate parent's children, in the same way as above. - rightParent->Parent() = this; - rightParent->Left() = NULL; - rightParent->Right() = NULL; - delete rightParent; + outerParent->Parent() = this; + outerParent->Inner() = NULL; + outerParent->Outer() = NULL; + delete outerParent; } } } diff --git a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp index eaeff5fac96..27abacf971d 100644 --- a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp +++ b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp @@ -58,7 +58,7 @@ double FastMKSRules::BaseCase( // cover trees, the kernel evaluation between the two centroid points already // happened. So we don't need to do it. Note that this optimizes out if the // first conditional is false (its result is known at compile time). - if (tree::TreeTraits::FirstPointIsCentroid()) + if (tree::TreeTraits::FirstPointIsCentroid) { if ((queryIndex == lastQueryIndex) && (referenceIndex == lastReferenceIndex)) @@ -74,7 +74,7 @@ double FastMKSRules::BaseCase( referenceSet.col(referenceIndex)); // Update the last kernel value, if we need to. - if (tree::TreeTraits::FirstPointIsCentroid()) + if (tree::TreeTraits::FirstPointIsCentroid) lastKernel = kernelEval; // If the reference and query sets are identical, we still need to compute the @@ -141,7 +141,7 @@ double FastMKSRules::Score(const size_t queryIndex, // centroid or, if the centroid is a point, use that. 
++scores; double kernelEval; - if (tree::TreeTraits::FirstPointIsCentroid(&referenceNode)) + if (tree::TreeTraits::FirstPointIsCentroid) { // Could it be that this kernel evaluation has already been calculated? if (tree::TreeTraits::HasSelfChildren && @@ -295,8 +295,7 @@ double FastMKSRules::Score(TreeType& queryNode, // We were unable to perform a parent-child or parent-parent prune, so now we // must calculate kernel evaluation, if necessary. double kernelEval = 0.0; - if (tree::TreeTraits::FirstPointIsCentroid(&queryNode) && - tree::TreeTraits::FirstPointIsCentroid(&referenceNode)) + if (tree::TreeTraits::FirstPointIsCentroid) { // For this type of tree, we may have already calculated the base case in // the parents. diff --git a/src/mlpack/methods/fastmks/fastmks_stat.hpp b/src/mlpack/methods/fastmks/fastmks_stat.hpp index 2086d2d5f82..318e5c9b621 100644 --- a/src/mlpack/methods/fastmks/fastmks_stat.hpp +++ b/src/mlpack/methods/fastmks/fastmks_stat.hpp @@ -44,7 +44,7 @@ class FastMKSStat lastKernelNode(NULL) { // Do we have to calculate the centroid? - if (tree::TreeTraits::FirstPointIsCentroid(&node)) + if (tree::TreeTraits::FirstPointIsCentroid) { // If this type of tree has self-children, then maybe the evaluation is // already done. These statistics are built bottom-up, so the child stat diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp index 5622e4474bd..f91b5ef153e 100644 --- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp +++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp @@ -131,8 +131,7 @@ inline double DualTreeKMeansRules::Score( // We want to set adjustedScore to be the distance between the centroid of the // last query node and last reference node. We will do this by adjusting the // last score. In some cases, we can just use the last base case. 
- if (tree::TreeTraits::FirstPointIsCentroid(&queryNode) && - tree::TreeTraits::FirstPointIsCentroid(&referenceNode)) + if (tree::TreeTraits::FirstPointIsCentroid) { adjustedScore = traversalInfo.LastBaseCase(); } @@ -208,9 +207,7 @@ inline double DualTreeKMeansRules::Score( // Now, check if we can prune. if (adjustedScore > queryNode.Stat().UpperBound()) { - if (!(tree::TreeTraits::FirstPointIsCentroid(&queryNode) && - tree::TreeTraits::FirstPointIsCentroid(&referenceNode) && - score == 0.0)) + if (!(tree::TreeTraits::FirstPointIsCentroid && score == 0.0)) { // There isn't any need to set the traversal information because no // descendant combinations will be visited, and those are the only diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp index 5f3d2386878..09483223578 100644 --- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp +++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp @@ -85,7 +85,7 @@ inline double NeighborSearchRules::Score( { ++scores; // Count number of Score() calls. double distance; - if (tree::TreeTraits::FirstPointIsCentroid(&referenceNode)) + if (tree::TreeTraits::FirstPointIsCentroid) { // The first point in the tree is the centroid. So we can then calculate // the base case between that and the query point. @@ -160,8 +160,11 @@ inline double NeighborSearchRules::Score( // We want to set adjustedScore to be the distance between the centroid of the // last query node and last reference node. We will do this by adjusting the // last score. In some cases, we can just use the last base case. 
- if (tree::TreeTraits::FirstPointIsCentroid(&queryNode) && - tree::TreeTraits::FirstPointIsCentroid(&referenceNode)) + if (tree::TreeTraits::FirstPointIsCentroid) + { + adjustedScore = traversalInfo.LastBaseCase(); + } + else if (tree::TreeTraits::FirstSiblingFirstPointIsCentroid) { adjustedScore = traversalInfo.LastBaseCase(); } @@ -238,9 +241,9 @@ inline double NeighborSearchRules::Score( // Can we prune? if (!SortPolicy::IsBetter(adjustedScore, bestDistance)) { - if (!(tree::TreeTraits::FirstPointIsCentroid(&queryNode) && - tree::TreeTraits::FirstPointIsCentroid(&referenceNode) && - score == 0.0)) + if (!((tree::TreeTraits::FirstPointIsCentroid && score == 0.0) || + (tree::TreeTraits::FirstSiblingFirstPointIsCentroid && + score == 0.0))) { // There isn't any need to set the traversal information because no // descendant combinations will be visited, and those are the only @@ -250,8 +253,7 @@ inline double NeighborSearchRules::Score( } double distance; - if (tree::TreeTraits::FirstPointIsCentroid(&queryNode) && - tree::TreeTraits::FirstPointIsCentroid(&referenceNode)) + if (tree::TreeTraits::FirstPointIsCentroid) { // The first point in the node is the centroid, so we can calculate the // distance between the two points using BaseCase() and then find the @@ -279,6 +281,38 @@ inline double NeighborSearchRules::Score( traversalInfo.LastBaseCase() = baseCase; } + else if (tree::TreeTraits::FirstSiblingFirstPointIsCentroid && + queryNode.Parent() && referenceNode.Parent()) + { + // The first point of the first sibling is the centroid, so we have to + // calculate the distance between the centroids if we have not calculated + // that yet. 
+ double baseCase; + + TreeType* firstQuerySibling = &queryNode.Parent()->Child(0); + TreeType* firstReferenceSibling = &referenceNode.Parent()->Child(0); + + if (firstQuerySibling != traversalInfo.LastQueryNode() || + firstReferenceSibling != traversalInfo.LastReferenceNode()) + { + baseCase = BaseCase(firstQuerySibling->Point(0), + firstReferenceSibling->Point(0)); + + // We update the traversal information only if we come across new + // centroids. + traversalInfo.LastQueryNode() = firstQuerySibling; + traversalInfo.LastReferenceNode() = firstReferenceSibling; + traversalInfo.LastBaseCase() = baseCase; + } + else + baseCase = traversalInfo.LastBaseCase(); + + distance = SortPolicy::CombineBest(baseCase, + queryNode.FurthestDescendantDistance() + + referenceNode.FurthestDescendantDistance()); + + traversalInfo.LastScore() = distance; + } else { distance = SortPolicy::BestNodeToNodeDistance(&queryNode, &referenceNode); @@ -287,9 +321,12 @@ inline double NeighborSearchRules::Score( if (SortPolicy::IsBetter(distance, bestDistance)) { // Set traversal information. - traversalInfo.LastQueryNode() = &queryNode; - traversalInfo.LastReferenceNode() = &referenceNode; - traversalInfo.LastScore() = distance; + if (!tree::TreeTraits::FirstSiblingFirstPointIsCentroid) + { + traversalInfo.LastQueryNode() = &queryNode; + traversalInfo.LastReferenceNode() = &referenceNode; + traversalInfo.LastScore() = distance; + } return distance; } diff --git a/src/mlpack/methods/range_search/range_search_rules_impl.hpp b/src/mlpack/methods/range_search/range_search_rules_impl.hpp index 2b24a70dca4..b0b56a7a5c2 100644 --- a/src/mlpack/methods/range_search/range_search_rules_impl.hpp +++ b/src/mlpack/methods/range_search/range_search_rules_impl.hpp @@ -79,7 +79,7 @@ double RangeSearchRules::Score(const size_t queryIndex, // object. 
math::Range distances; - if (tree::TreeTraits::FirstPointIsCentroid(&referenceNode)) + if (tree::TreeTraits::FirstPointIsCentroid) { // In this situation, we calculate the base case. So we should check to be // sure we haven't already done that. @@ -147,8 +147,7 @@ double RangeSearchRules::Score(TreeType& queryNode, TreeType& referenceNode) { math::Range distances; - if (tree::TreeTraits::FirstPointIsCentroid(&queryNode) && - tree::TreeTraits::FirstPointIsCentroid(&referenceNode)) + if (tree::TreeTraits::FirstPointIsCentroid) { // It is possible that the base case has already been calculated. double baseCase = 0.0; @@ -177,6 +176,37 @@ double RangeSearchRules::Score(TreeType& queryNode, // Update the last distances performed for the query and reference node. traversalInfo.LastBaseCase() = baseCase; } + else if (tree::TreeTraits::FirstSiblingFirstPointIsCentroid && + queryNode.Parent() && referenceNode.Parent()) + { + // The first point of the first sibling is the centroid, so we have to + // calculate the distance between the centroids if we have not calculated + // that yet. + double baseCase; + + TreeType* firstQuerySibling = &queryNode.Parent()->Child(0); + TreeType* firstReferenceSibling = &referenceNode.Parent()->Child(0); + + if (firstQuerySibling != traversalInfo.LastQueryNode() || + firstReferenceSibling != traversalInfo.LastReferenceNode()) + { + baseCase = BaseCase(firstQuerySibling->Point(0), + firstReferenceSibling->Point(0)); + + // We update the traversal information only if we come across new + // centroids. 
+ traversalInfo.LastQueryNode() = firstQuerySibling; + traversalInfo.LastReferenceNode() = firstReferenceSibling; + traversalInfo.LastBaseCase() = baseCase; + } + else + baseCase = traversalInfo.LastBaseCase(); + + distances.Lo() = baseCase - queryNode.FurthestDescendantDistance() + - referenceNode.FurthestDescendantDistance(); + distances.Hi() = baseCase + queryNode.FurthestDescendantDistance() + + referenceNode.FurthestDescendantDistance(); + } else { // Just perform the calculation. @@ -199,8 +229,11 @@ double RangeSearchRules::Score(TreeType& queryNode, // Otherwise the score doesn't matter. Recursion order is irrelevant in range // search. - traversalInfo.LastQueryNode() = &queryNode; - traversalInfo.LastReferenceNode() = &referenceNode; + if (!tree::TreeTraits::FirstSiblingFirstPointIsCentroid) + { + traversalInfo.LastQueryNode() = &queryNode; + traversalInfo.LastReferenceNode() = &referenceNode; + } return 0.0; } @@ -225,7 +258,7 @@ void RangeSearchRules::AddResult(const size_t queryIndex, // called, so if the base case has already been calculated, then we must avoid // adding that point to the results again. 
size_t baseCaseMod = 0; - if (tree::TreeTraits::FirstPointIsCentroid(&referenceNode) && + if (tree::TreeTraits::FirstPointIsCentroid && (queryIndex == lastQueryIndex) && (referenceNode.Point(0) == lastReferenceIndex)) { diff --git a/src/mlpack/tests/tree_traits_test.cpp b/src/mlpack/tests/tree_traits_test.cpp index 16ecf70e662..cf0395d19f5 100644 --- a/src/mlpack/tests/tree_traits_test.cpp +++ b/src/mlpack/tests/tree_traits_test.cpp @@ -36,7 +36,7 @@ BOOST_AUTO_TEST_CASE(DefaultsTraitsTest) BOOST_REQUIRE_EQUAL(b, true); b = TreeTraits::HasSelfChildren; BOOST_REQUIRE_EQUAL(b, false); - b = TreeTraits::FirstPointIsCentroid(); + b = TreeTraits::FirstPointIsCentroid; BOOST_REQUIRE_EQUAL(b, false); b = TreeTraits::RearrangesDataset; BOOST_REQUIRE_EQUAL(b, false); @@ -58,7 +58,7 @@ BOOST_AUTO_TEST_CASE(BinarySpaceTreeTraitsTest) BOOST_REQUIRE_EQUAL(b, false); // The first point is not the centroid. - b = TreeTraits::FirstPointIsCentroid(); + b = TreeTraits::FirstPointIsCentroid; BOOST_REQUIRE_EQUAL(b, false); // The dataset gets rearranged at build time. @@ -82,7 +82,7 @@ BOOST_AUTO_TEST_CASE(CoverTreeTraitsTest) BOOST_REQUIRE_EQUAL(b, true); // The first point is the center of the node. 
- b = TreeTraits>::FirstPointIsCentroid(); + b = TreeTraits>::FirstPointIsCentroid; BOOST_REQUIRE_EQUAL(b, true); b = TreeTraits>::RearrangesDataset; diff --git a/src/mlpack/tests/vantage_point_tree_test.cpp b/src/mlpack/tests/vantage_point_tree_test.cpp index 9b9da8c6244..3097af274b8 100644 --- a/src/mlpack/tests/vantage_point_tree_test.cpp +++ b/src/mlpack/tests/vantage_point_tree_test.cpp @@ -27,8 +27,8 @@ BOOST_AUTO_TEST_CASE(VPTreeTraitsTest) bool b = TreeTraits::HasOverlappingChildren; BOOST_REQUIRE_EQUAL(b, true); -// b = TreeTraits::FirstPointIsCentroid; -// BOOST_REQUIRE_EQUAL(b, true); + b = TreeTraits::FirstPointIsCentroid; + BOOST_REQUIRE_EQUAL(b, false); b = TreeTraits::HasSelfChildren; BOOST_REQUIRE_EQUAL(b, false); b = TreeTraits::RearrangesDataset; @@ -138,20 +138,19 @@ void CheckBound(TreeType& tree) } else { - if (!tree.Parent()) - BOOST_REQUIRE_EQUAL(tree.NumPoints(), 0); - else if (tree.FirstPointIsCentroid()) - { - BOOST_REQUIRE_EQUAL(tree.NumPoints(), 1); - BOOST_REQUIRE_EQUAL(true, - tree.Bound().Contains(tree.Dataset().col(tree.Point(0)))); - } - - BOOST_REQUIRE_EQUAL(tree.Bound().Contains(tree.Left()->Bound()), true); - BOOST_REQUIRE_EQUAL(tree.Bound().Contains(tree.Right()->Bound()), true); - - CheckBound(*tree.Left()); - CheckBound(*tree.Right()); + TreeType* central = tree.Central(); + BOOST_REQUIRE_EQUAL(central->NumPoints(), 1); + BOOST_REQUIRE_EQUAL(true, + central->Bound().Contains(tree.Dataset().col(central->Point(0)))); + BOOST_REQUIRE_EQUAL(central->Bound().InnerRadius(), 0.0); + BOOST_REQUIRE_EQUAL(central->Bound().OuterRadius(), 0.0); + + BOOST_REQUIRE_EQUAL(tree.Bound().Contains(tree.Central()->Bound()), true); + BOOST_REQUIRE_EQUAL(tree.Bound().Contains(tree.Inner()->Bound()), true); + BOOST_REQUIRE_EQUAL(tree.Bound().Contains(tree.Outer()->Bound()), true); + + CheckBound(*tree.Inner()); + CheckBound(*tree.Outer()); } } @@ -174,35 +173,38 @@ void CheckSplit(TreeType& tree) typename TreeType::ElemType maxDist = 0; - size_t 
pointsEnd = tree.Left()->Begin() + tree.Left()->Count(); - for (size_t i = tree.Left()->Begin(); i < pointsEnd; i++) + size_t pointsEnd = tree.Inner()->Begin() + tree.Inner()->Count(); + for (size_t i = tree.Inner()->Begin(); i < pointsEnd; i++) { typename TreeType::ElemType dist = - tree.Bound().Metric().Evaluate(tree.Dataset().col(tree.Left()->Begin()), - tree.Dataset().col(i)); + tree.Bound().Metric().Evaluate(tree.Dataset().col(i), + tree.Dataset().col(tree.Central()->Begin())); if (dist > maxDist) maxDist = dist; } - pointsEnd = tree.Right()->Begin() + tree.Right()->Count(); - for (size_t i = tree.Right()->Begin(); i < pointsEnd; i++) + pointsEnd = tree.Outer()->Begin() + tree.Outer()->Count(); + for (size_t i = tree.Outer()->Begin(); i < pointsEnd; i++) { typename TreeType::ElemType dist = - tree.Bound().Metric().Evaluate(tree.Dataset().col(tree.Left()->Begin()), - tree.Dataset().col(i)); + tree.Bound().Metric().Evaluate(tree.Dataset().col(i), + tree.Dataset().col(tree.Central()->Begin())); BOOST_REQUIRE_LE(maxDist, dist); } - if (tree.FirstPointIsCentroid()) + for (size_t k = 0; k < tree.Bound().Dim(); k++) { - for (size_t k = 0; k < tree.Bound().Dim(); k++) - BOOST_REQUIRE_EQUAL(tree.Bound().Center()[k], - tree.Dataset().col(tree.Point(0))[k]); + BOOST_REQUIRE_EQUAL(tree.Inner()->Bound().Center()[k], + tree.Dataset().col(tree.Central()->Point(0))[k]); + BOOST_REQUIRE_EQUAL(tree.Outer()->Bound().Center()[k], + tree.Dataset().col(tree.Central()->Point(0))[k]); + BOOST_REQUIRE_EQUAL(tree.Central()->Bound().Center()[k], + tree.Dataset().col(tree.Central()->Point(0))[k]); } - CheckSplit(*tree.Left()); - CheckSplit(*tree.Right()); + CheckSplit(*tree.Inner()); + CheckSplit(*tree.Outer()); } BOOST_AUTO_TEST_CASE(VPTreeSplitTest) From d5fb40fd9bd4beee24f849f1b8c433d46842cebc Mon Sep 17 00:00:00 2001 From: Mikhail Lozhnikov Date: Tue, 26 Jul 2016 10:33:02 +0300 Subject: [PATCH 10/12] Very minor fixes. 
--- src/mlpack/core/tree/vantage_point_tree/traits.hpp | 4 ++-- .../tree/vantage_point_tree/vantage_point_split.hpp | 13 ------------- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/src/mlpack/core/tree/vantage_point_tree/traits.hpp b/src/mlpack/core/tree/vantage_point_tree/traits.hpp index 1b5da5c6549..99803a8f519 100644 --- a/src/mlpack/core/tree/vantage_point_tree/traits.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/traits.hpp @@ -54,9 +54,9 @@ class TreeTraits Date: Sat, 6 Aug 2016 20:45:11 +0300 Subject: [PATCH 11/12] Fixed an error with duplicated base cases. Fixed tests. --- .../tree/vantage_point_tree/dual_tree_traverser_impl.hpp | 4 +--- .../methods/neighbor_search/neighbor_search_rules_impl.hpp | 6 +++++- src/mlpack/methods/range_search/range_search_rules_impl.hpp | 6 +++++- src/mlpack/tests/vantage_point_tree_test.cpp | 2 +- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser_impl.hpp index 3c1c503f1de..4a9e39e94bb 100644 --- a/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser_impl.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser_impl.hpp @@ -79,9 +79,7 @@ DualTreeTraverser::Traverse( numBaseCases += referenceNode.Count(); } } - else if (((!queryNode.IsLeaf()) && referenceNode.IsLeaf()) || - (queryNode.NumDescendants() > 3 * referenceNode.NumDescendants() && - !queryNode.IsLeaf() && !referenceNode.IsLeaf())) + else if ((!queryNode.IsLeaf()) && referenceNode.IsLeaf()) { // We have to recurse down the query node. In this case the recursion order // does not matter. 
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp index 09483223578..88fbdf69955 100644 --- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp +++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp @@ -282,11 +282,15 @@ inline double NeighborSearchRules::Score( traversalInfo.LastBaseCase() = baseCase; } else if (tree::TreeTraits::FirstSiblingFirstPointIsCentroid && - queryNode.Parent() && referenceNode.Parent()) + queryNode.Parent() && referenceNode.Parent() && + !queryNode.IsLeaf() && !referenceNode.IsLeaf()) { // The first point of the first sibling is the centroid, so we have to // calculate the distance between the centroids if we have not calculated // that yet. + // We can not use this property if the traverser does not recurse down + // the query or the reference node since two siblings may be traversed + // in two different branches of the recursion. double baseCase; TreeType* firstQuerySibling = &queryNode.Parent()->Child(0); diff --git a/src/mlpack/methods/range_search/range_search_rules_impl.hpp b/src/mlpack/methods/range_search/range_search_rules_impl.hpp index b0b56a7a5c2..fddcad3c45b 100644 --- a/src/mlpack/methods/range_search/range_search_rules_impl.hpp +++ b/src/mlpack/methods/range_search/range_search_rules_impl.hpp @@ -177,11 +177,15 @@ double RangeSearchRules::Score(TreeType& queryNode, traversalInfo.LastBaseCase() = baseCase; } else if (tree::TreeTraits::FirstSiblingFirstPointIsCentroid && - queryNode.Parent() && referenceNode.Parent()) + queryNode.Parent() && referenceNode.Parent() && + !queryNode.IsLeaf() && !referenceNode.IsLeaf()) { // The first point of the first sibling is the centroid, so we have to // calculate the distance between the centroids if we have not calculated // that yet. 
+ // We can not use this property if the traverser does not recurse down + // the query or the reference node since two siblings may be traversed + // in two different branches of the recursion. double baseCase; TreeType* firstQuerySibling = &queryNode.Parent()->Child(0); diff --git a/src/mlpack/tests/vantage_point_tree_test.cpp b/src/mlpack/tests/vantage_point_tree_test.cpp index 3097af274b8..696313ade3c 100644 --- a/src/mlpack/tests/vantage_point_tree_test.cpp +++ b/src/mlpack/tests/vantage_point_tree_test.cpp @@ -34,7 +34,7 @@ BOOST_AUTO_TEST_CASE(VPTreeTraitsTest) b = TreeTraits::RearrangesDataset; BOOST_REQUIRE_EQUAL(b, true); b = TreeTraits::BinaryTree; - BOOST_REQUIRE_EQUAL(b, true); + BOOST_REQUIRE_EQUAL(b, false); } BOOST_AUTO_TEST_CASE(HollowBallBoundTest) From add74d0ea472287a58472cd8cb2573b3bf7a7af8 Mon Sep 17 00:00:00 2001 From: Mikhail Lozhnikov Date: Sat, 6 Aug 2016 21:12:23 +0300 Subject: [PATCH 12/12] Removed the property that each child bound is contained entirely in the parent bound. --- .../vantage_point_tree_impl.hpp | 19 ------------ src/mlpack/tests/vantage_point_tree_test.cpp | 29 +++++++++++++++---- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree_impl.hpp b/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree_impl.hpp index ffba3ec2a64..fc5c5711605 100644 --- a/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree_impl.hpp +++ b/src/mlpack/core/tree/vantage_point_tree/vantage_point_tree_impl.hpp @@ -680,15 +680,6 @@ void VantagePointTree: if (count > 0) bound |= dataset->cols(begin, begin + count - 1); - VantagePointTree* tree = this; - - while (tree->Parent() != NULL) - { - tree->Parent()->Bound() |= tree->Bound(); - tree->Parent()->furthestDescendantDistance = 0.5 * - tree->Parent()->Bound().Diameter(); - tree = tree->Parent(); - } // Calculate the furthest descendant distance. 
furthestDescendantDistance = 0.5 * bound.Diameter(); @@ -763,16 +754,6 @@ SplitNode(std::vector& oldFromNew, if (count > 0) bound |= dataset->cols(begin, begin + count - 1); - VantagePointTree* tree = this; - - while (tree->Parent() != NULL) - { - tree->Parent()->Bound() |= tree->Bound(); - tree->Parent()->furthestDescendantDistance = 0.5 * - tree->Parent()->Bound().Diameter(); - tree = tree->Parent(); - } - // Calculate the furthest descendant distance. furthestDescendantDistance = 0.5 * bound.Diameter(); diff --git a/src/mlpack/tests/vantage_point_tree_test.cpp b/src/mlpack/tests/vantage_point_tree_test.cpp index 696313ade3c..fe7cb424c26 100644 --- a/src/mlpack/tests/vantage_point_tree_test.cpp +++ b/src/mlpack/tests/vantage_point_tree_test.cpp @@ -130,11 +130,21 @@ BOOST_AUTO_TEST_CASE(HollowBallBoundTest) template void CheckBound(TreeType& tree) { + typedef typename TreeType::ElemType ElemType; if (tree.IsLeaf()) { + // Ensure that the bound contains all descendant points. for (size_t i = 0; i < tree.NumPoints(); i++) - BOOST_REQUIRE_EQUAL(true, - tree.Bound().Contains(tree.Dataset().col(tree.Point(i)))); + { + ElemType dist = tree.Bound().Metric().Evaluate(tree.Bound().Center(), + tree.Dataset().col(tree.Point(i))); + + BOOST_REQUIRE_LE(tree.Bound().InnerRadius(), dist * + (1.0 + 10.0 * std::numeric_limits::epsilon())); + + BOOST_REQUIRE_LE(dist, tree.Bound().OuterRadius() * + (1.0 + 10.0 * std::numeric_limits::epsilon())); + } } else { @@ -145,9 +155,18 @@ void CheckBound(TreeType& tree) BOOST_REQUIRE_EQUAL(central->Bound().InnerRadius(), 0.0); BOOST_REQUIRE_EQUAL(central->Bound().OuterRadius(), 0.0); - BOOST_REQUIRE_EQUAL(tree.Bound().Contains(tree.Central()->Bound()), true); - BOOST_REQUIRE_EQUAL(tree.Bound().Contains(tree.Inner()->Bound()), true); - BOOST_REQUIRE_EQUAL(tree.Bound().Contains(tree.Outer()->Bound()), true); + // Ensure that the bound contains all descendant points. 
+ for (size_t i = 0; i < tree.NumDescendants(); i++) + { + ElemType dist = tree.Bound().Metric().Evaluate(tree.Bound().Center(), + tree.Dataset().col(tree.Descendant(i))); + + BOOST_REQUIRE_LE(tree.Bound().InnerRadius(), dist * + (1.0 + 10.0 * std::numeric_limits::epsilon())); + + BOOST_REQUIRE_LE(dist, tree.Bound().OuterRadius() * + (1.0 + 10.0 * std::numeric_limits::epsilon())); + } CheckBound(*tree.Inner()); CheckBound(*tree.Outer());