From b8da9a9c01630b5455ac31347f69865b85bce370 Mon Sep 17 00:00:00 2001 From: Mikhail Lozhnikov Date: Thu, 16 Jun 2016 05:45:59 +0300 Subject: [PATCH 1/8] R+ tree implementation --- src/mlpack/core/tree/CMakeLists.txt | 4 + src/mlpack/core/tree/rectangle_tree.hpp | 2 + .../hilbert_r_tree_split_impl.hpp | 2 + .../r_plus_tree_descent_heuristic.hpp | 49 ++ .../r_plus_tree_descent_heuristic_impl.hpp | 96 ++++ .../tree/rectangle_tree/r_plus_tree_split.hpp | 95 ++++ .../rectangle_tree/r_plus_tree_split_impl.hpp | 441 ++++++++++++++++++ .../rectangle_tree/r_star_tree_split_impl.hpp | 3 + .../tree/rectangle_tree/r_tree_split_impl.hpp | 2 + .../rectangle_tree/rectangle_tree_impl.hpp | 4 +- .../core/tree/rectangle_tree/typedef.hpp | 7 + .../tree/rectangle_tree/x_tree_split_impl.hpp | 3 + src/mlpack/tests/rectangle_tree_test.cpp | 90 ++++ 13 files changed, 796 insertions(+), 2 deletions(-) create mode 100644 src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic.hpp create mode 100644 src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp create mode 100644 src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp create mode 100644 src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt index 28415d528dd..527dd59b2d5 100644 --- a/src/mlpack/core/tree/CMakeLists.txt +++ b/src/mlpack/core/tree/CMakeLists.txt @@ -63,6 +63,10 @@ set(SOURCES rectangle_tree/recursive_hilbert_value_impl.hpp rectangle_tree/discrete_hilbert_value.hpp rectangle_tree/discrete_hilbert_value_impl.hpp + rectangle_tree/r_plus_tree_descent_heuristic.hpp + rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp + rectangle_tree/r_plus_tree_split.hpp + rectangle_tree/r_plus_tree_split_impl.hpp statistic.hpp traversal_info.hpp tree_traits.hpp diff --git a/src/mlpack/core/tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree.hpp index de236ad40ed..a28cd9f759d 100644 --- 
a/src/mlpack/core/tree/rectangle_tree.hpp +++ b/src/mlpack/core/tree/rectangle_tree.hpp @@ -30,6 +30,8 @@ #include "rectangle_tree/hilbert_r_tree_auxiliary_information.hpp" #include "rectangle_tree/recursive_hilbert_value.hpp" #include "rectangle_tree/discrete_hilbert_value.hpp" +#include "rectangle_tree/r_plus_tree_descent_heuristic.hpp" +#include "rectangle_tree/r_plus_tree_split.hpp" #include "rectangle_tree/typedef.hpp" #endif diff --git a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp index fd399610942..0d4ed5f7a18 100644 --- a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp +++ b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp @@ -18,6 +18,8 @@ template void HilbertRTreeSplit:: SplitLeafNode(TreeType* tree, std::vector& relevels) { + if (tree->Count() <= tree->MaxLeafSize()) + return; // If we are splitting the root node, we need will do things differently so // that the constructor and other methods don't confuse the end user by giving // an address of another node. diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic.hpp new file mode 100644 index 00000000000..dfe8e0acc9e --- /dev/null +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic.hpp @@ -0,0 +1,49 @@ +/** + * @file r_plus_tree_descent_heuristic.hpp + * @author Mikhail Lozhnikov + * + * Definition of RPlusTreeDescentHeuristic, a class that chooses the best child of a + * node in an R tree when inserting a new point. + */ +#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_HPP +#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_HPP + +#include + +namespace mlpack { +namespace tree { + +class RPlusTreeDescentHeuristic +{ + public: + /** + * Evaluate the node using a heuristic. 
Returns the index of the first + * child whose bound already contains the point; failing that, the first + * child whose bound can be enlarged to hold the point without overlapping + * any sibling. If no child qualifies, new empty nodes are created under the + * node (one per remaining tree level) and the index of the node's last child + * is returned. + * + * @param node The node that is being evaluated. + * @param point The index (dataset column) of the point that is being inserted. + */ + template + static size_t ChooseDescentNode(TreeType* node, const size_t point); + + /** + * Choose the child into which a whole node should be inserted. The current + * definition is a stub: it asserts false and otherwise returns 0. + * + * @param node The node that is being evaluated. + * @param insertedNode The node that is being inserted. + */ + template + static size_t ChooseDescentNode(const TreeType* node, + const TreeType* insertedNode); + +}; + +} // namespace tree +} // namespace mlpack + +#include "r_plus_tree_descent_heuristic_impl.hpp" + +#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_HPP diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp new file mode 100644 index 00000000000..265b739a0d5 --- /dev/null +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp @@ -0,0 +1,96 @@ +/** + * @file r_plus_tree_descent_heuristic_impl.hpp + * @author Mikhail Lozhnikov + * + * Implementation of RPlusTreeDescentHeuristic, a class that chooses the best child + * of a node in an R tree when inserting a new point. 
+ */ +#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP +#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP + +#include "r_plus_tree_descent_heuristic.hpp" +#include "../hrectbound.hpp" + +namespace mlpack { +namespace tree { + +/** + * Pick the child of node that should receive the given point: the first child + * whose bound contains it, else the first child whose enlarged bound stays + * disjoint from every sibling, else a newly created chain of empty nodes. + */ +template +size_t RPlusTreeDescentHeuristic:: +ChooseDescentNode(TreeType* node, const size_t point) +{ + typedef typename TreeType::ElemType ElemType; + size_t bestIndex = 0; + bool success; + + /* Pass 1: a child that already contains the point. */ + for (bestIndex = 0; bestIndex < node->NumChildren(); bestIndex++) + { + if (node->Children()[bestIndex]->Bound().Contains(node->Dataset().col(point))) + return bestIndex; + } + + /* Pass 2: a child whose bound, enlarged by the point, overlaps no sibling. */ + for (bestIndex = 0; bestIndex < node->NumChildren(); bestIndex++) + { + bound::HRectBound bound = + node->Children()[bestIndex]->Bound(); + bound |= node->Dataset().col(point); + + success = true; + + for (size_t j = 0; j < node->NumChildren(); j++) + { + if (j == bestIndex) + continue; + success = false; + /* Two boxes are disjoint as soon as they are separated along one axis. */ + for (size_t k = 0; k < node->Bound().Dim(); k++) + { + if (bound[k].Lo() >= node->Children()[j]->Bound()[k].Hi() || + node->Children()[j]->Bound()[k].Lo() >= bound[k].Hi()) + { + success = true; + break; + } + } + if (!success) + break; + } + if (success) + break; + } + + if (!success) + { + size_t depth = node->TreeDepth(); + + TreeType* tree = node; + while (depth > 1) + { + /* Each new link is stored in tree's children, so tree (not the chain's root node) must be passed as its parent, as in the other new TreeType(parent) calls of this patch. */ + TreeType* child = new TreeType(tree); + + tree->Children()[tree->NumChildren()++] = child; + tree = child; + depth--; + } + return node->NumChildren()-1; + } + + assert(bestIndex < node->NumChildren()); + + return bestIndex; +} + +/** + * Stub for node insertion: asserts false and returns 0. + */ +template +size_t RPlusTreeDescentHeuristic:: +ChooseDescentNode(const TreeType* node, const TreeType* insertedNode) +{ + size_t bestIndex = 0; + + assert(false); + + return bestIndex; +} + + +} // namespace tree +} // namespace mlpack + +#endif //MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp 
b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp new file mode 100644 index 00000000000..f06b813a66a --- /dev/null +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp @@ -0,0 +1,95 @@ +/** + * @file r_plus_tree_split.hpp + * @author Mikhail Lozhnikov + * + * Definition of the RPlusTreeSplit class, a class that splits the nodes of an R + * tree, starting at a leaf node and moving upwards if necessary. + */ +#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_HPP +#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_HPP + +#include + +/** Fraction of a node's capacity (MaxLeafSize or MaxNumChildren) used as the fill factor when partitioning. */ +const double fillFactorFraction = 0.5; + +namespace mlpack { +namespace tree /** Trees and tree-building procedures. */ { + +class RPlusTreeSplit +{ + public: + /** + * Split a leaf node using the "default" algorithm. If necessary, this split + * will propagate upwards through the tree. + * @param tree The node that is being split. + * @param relevels Not used by the split itself; only forwarded to recursive calls. + */ + template + static void SplitLeafNode(TreeType *tree,std::vector& relevels); + + /** + * Split a non-leaf node using the "default" algorithm. If this is a root + * node, the tree increases in depth. + * @param tree The node that is being split. + * @param relevels Not used by the split itself; only forwarded to recursive calls. 
+ */ + template + static bool SplitNonLeafNode(TreeType *tree,std::vector& relevels); + + + + private: + + template + struct SortStruct + { + ElemType d; + int n; + }; + + template + static bool StructComp(const SortStruct& s1, + const SortStruct& s2) + { + return s1.d < s2.d; + } + + template + static void SplitLeafNodeAlongPartition(TreeType* tree, + TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut); + + template + static void SplitNonLeafNodeAlongPartition(TreeType* tree, + TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut); + + template + static bool PartitionNode(const TreeType* node, size_t fillFactor, + size_t& minCutAxis, double& minCut); + + template + static double SweepLeafNode(size_t axis, const TreeType* node, + size_t fillFactor, double& axisCut); + + template + static double SweepNonLeafNode(size_t axis, const TreeType* node, + size_t fillFactor, double& axisCut); + + template + static void InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode); + + template + static bool CheckNonLeafSweep(const TreeType* node, + size_t cutAxis, double cut); + + template + static bool CheckLeafSweep(const TreeType* node, size_t cutAxis, double cut); +}; + +} // namespace tree +} // namespace mlpack + +// Include implementation +#include "r_plus_tree_split_impl.hpp" + +#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_HPP + diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp new file mode 100644 index 00000000000..fca51fc0eb8 --- /dev/null +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp @@ -0,0 +1,441 @@ +/** + * @file r_plus_tree_split_impl.hpp + * @author Mikhail Lozhnikov + * + * Implementation of class (RPlusTreeSplit) to split a RectangleTree. 
+ */ +#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_IMPL_HPP +#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_IMPL_HPP + +#include "r_plus_tree_split.hpp" +#include "rectangle_tree.hpp" + +namespace mlpack { +namespace tree { + +template +void RPlusTreeSplit::SplitLeafNode(TreeType* tree, std::vector& relevels) +{ + if (tree->Count() == 1) + { + TreeType* node = tree->Parent(); + + while (node != NULL) + { + if (node->NumChildren() == node->MaxNumChildren() + 1) + { + RPlusTreeSplit::SplitNonLeafNode(node,relevels); + return; + } + node = node->Parent(); + } + return; + } + else if (tree->Count() <= tree->MaxLeafSize()) + return; + // If we are splitting the root node, we need will do things differently so + // that the constructor and other methods don't confuse the end user by giving + // an address of another node. + if (tree->Parent() == NULL) + { + // We actually want to copy this way. Pointers and everything. + TreeType* copy = new TreeType(*tree, false); + copy->Parent() = tree; + tree->Count() = 0; + tree->NullifyData(); + // Because this was a leaf node, numChildren must be 0. 
+ tree->Children()[(tree->NumChildren())++] = copy; + assert(tree->NumChildren() == 1); + + RPlusTreeSplit::SplitLeafNode(copy,relevels); + return; + } + + const size_t fillFactor = tree->MaxLeafSize() * fillFactorFraction; + size_t cutAxis; + double cut; + + if ( !PartitionNode(tree, fillFactor, cutAxis, cut)) + return; + + assert(cutAxis < tree->Bound().Dim()); + + TreeType* treeOne = new TreeType(tree->Parent()); + TreeType* treeTwo = new TreeType(tree->Parent()); + treeOne->MinLeafSize() = 0; + treeOne->MinNumChildren() = 0; + treeTwo->MinLeafSize() = 0; + treeTwo->MinNumChildren() = 0; + + SplitLeafNodeAlongPartition(tree, treeOne, treeTwo, cutAxis, cut); + + TreeType* parent = tree->Parent(); + size_t i = 0; + while (parent->Children()[i] != tree) + i++; + + assert(i < parent->NumChildren()); + + parent->Children()[i] = parent->Children()[--parent->NumChildren()]; + + InsertNodeIntoTree(parent, treeOne); + InsertNodeIntoTree(parent, treeTwo); + + assert(parent->NumChildren() <= parent->MaxNumChildren() + 1); + if (parent->NumChildren() == parent->MaxNumChildren() + 1) + RPlusTreeSplit::SplitNonLeafNode(parent, relevels); + + tree->SoftDelete(); +} + +template +bool RPlusTreeSplit::SplitNonLeafNode(TreeType* tree, + std::vector& relevels) +{ + // If we are splitting the root node, we need will do things differently so + // that the constructor and other methods don't confuse the end user by giving + // an address of another node. + if (tree->Parent() == NULL) + { + // We actually want to copy this way. Pointers and everything. 
+ TreeType* copy = new TreeType(*tree, false); + + copy->Parent() = tree; + tree->NumChildren() = 0; + tree->NullifyData(); + tree->Children()[(tree->NumChildren())++] = copy; + + RPlusTreeSplit::SplitNonLeafNode(copy,relevels); + return true; + } + const size_t fillFactor = tree->MaxNumChildren() * fillFactorFraction; + size_t cutAxis; + double cut; + + if ( !PartitionNode(tree, fillFactor, cutAxis, cut)) + return false; + + assert(cutAxis < tree->Bound().Dim()); + + TreeType* treeOne = new TreeType(tree->Parent()); + TreeType* treeTwo = new TreeType(tree->Parent()); + treeOne->MinLeafSize() = 0; + treeOne->MinNumChildren() = 0; + treeTwo->MinLeafSize() = 0; + treeTwo->MinNumChildren() = 0; + + SplitNonLeafNodeAlongPartition(tree, treeOne, treeTwo, cutAxis, cut); + + TreeType* parent = tree->Parent(); + size_t i = 0; + while (parent->Children()[i] != tree) + i++; + + assert(i < parent->NumChildren()); + + parent->Children()[i] = parent->Children()[--parent->NumChildren()]; + + InsertNodeIntoTree(parent, treeOne); + InsertNodeIntoTree(parent, treeTwo); + + tree->SoftDelete(); + + assert(parent->NumChildren() <= parent->MaxNumChildren() + 1); + + if (parent->NumChildren() == parent->MaxNumChildren() + 1) + RPlusTreeSplit::SplitNonLeafNode(parent, relevels); + + return false; +} + +template +void RPlusTreeSplit::SplitLeafNodeAlongPartition(TreeType* tree, + TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut) +{ + for (size_t i = 0; i < tree->NumPoints(); i++) + { + if (tree->Dataset().col(tree->Point(i))[cutAxis] <= cut) + { + treeOne->Points()[treeOne->Count()++] = tree->Point(i); + treeOne->Bound() |= tree->Dataset().col(tree->Point(i)); + } + else + { + treeTwo->Points()[treeTwo->Count()++] = tree->Point(i); + treeTwo->Bound() |= tree->Dataset().col(tree->Point(i)); + } + } + assert(treeOne->Count() <= treeOne->MaxLeafSize()); + assert(treeTwo->Count() <= treeTwo->MaxLeafSize()); + + assert(tree->Count() == treeOne->Count() + treeTwo->Count()); + 
assert(treeOne->Bound()[cutAxis].Hi() < treeTwo->Bound()[cutAxis].Lo()); +} + +/** + * Distribute the children of tree between treeOne and treeTwo: children lying + * entirely at or below the cut go to treeOne, children lying entirely at or + * above it go to treeTwo, and children straddling the cut are split recursively. + */ +template +void RPlusTreeSplit::SplitNonLeafNodeAlongPartition(TreeType* tree, + TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut) +{ + for (size_t i = 0; i < tree->NumChildren(); i++) + { + TreeType* child = tree->Children()[i]; + if (child->Bound()[cutAxis].Hi() <= cut) + { + InsertNodeIntoTree(treeOne, child); + child->Parent() = treeOne; + } + else if (child->Bound()[cutAxis].Lo() >= cut) + { + InsertNodeIntoTree(treeTwo, child); + child->Parent() = treeTwo; + } + else + { + TreeType* childOne = new TreeType(treeOne); + TreeType* childTwo = new TreeType(treeTwo); + /* Relax the minimum sizes of the freshly created halves, mirroring what SplitLeafNode and SplitNonLeafNode do for the nodes they create. */ + childOne->MinLeafSize() = 0; + childOne->MinNumChildren() = 0; + childTwo->MinLeafSize() = 0; + childTwo->MinNumChildren() = 0; + + if (child->IsLeaf()) + SplitLeafNodeAlongPartition(child, childOne, childTwo, cutAxis, cut); + else + SplitNonLeafNodeAlongPartition(child, childOne, childTwo, cutAxis, cut); + + InsertNodeIntoTree(treeOne, childOne); + InsertNodeIntoTree(treeTwo, childTwo); + + child->SoftDelete(); + } + } + assert(treeOne->NumChildren() <= treeOne->MaxNumChildren()); + assert(treeTwo->NumChildren() <= treeTwo->MaxNumChildren()); +} + +/** + * Check that cutting node along cutAxis at cut leaves both sides with at least + * one and at most MaxNumChildren() children (a straddling child counts on both sides). + */ +template +bool RPlusTreeSplit::CheckNonLeafSweep(const TreeType* node, + size_t cutAxis, double cut) +{ + size_t numTreeOneChildren = 0; + size_t numTreeTwoChildren = 0; + + for (size_t i = 0; i < node->NumChildren(); i++) + { + TreeType* child = node->Children()[i]; + if (child->Bound()[cutAxis].Hi() <= cut) + numTreeOneChildren++; + else if (child->Bound()[cutAxis].Lo() >= cut) + numTreeTwoChildren++; + else + { + numTreeOneChildren++; + numTreeTwoChildren++; + } + } + + if (numTreeOneChildren <= node->MaxNumChildren() && numTreeOneChildren > 0 && + numTreeTwoChildren <= node->MaxNumChildren() && numTreeTwoChildren > 0) + return true; + return false; +} + +template +bool RPlusTreeSplit::CheckLeafSweep(const TreeType* node, + size_t cutAxis, double cut) +{ + size_t 
numTreeOnePoints = 0; + size_t numTreeTwoPoints = 0; + + for (size_t i = 0; i < node->NumPoints(); i++) + { + if (node->Dataset().col(node->Point(i))[cutAxis] <= cut) + numTreeOnePoints++; + else + numTreeTwoPoints++; + } + + if (numTreeOnePoints <= node->MaxLeafSize() && numTreeOnePoints > 0 && + numTreeTwoPoints <= node->MaxLeafSize() && numTreeTwoPoints > 0) + return true; + return false; +} + +template +bool RPlusTreeSplit::PartitionNode(const TreeType* node, size_t fillFactor, + size_t& minCutAxis, double& minCut) +{ + if ((node->NumChildren() <= fillFactor && !node->IsLeaf()) || + (node->Count() <= fillFactor && node->IsLeaf())) + return false; + + double minCost = std::numeric_limits::max(); + minCutAxis = node->Bound().Dim(); + + for (size_t k = 0; k < node->Bound().Dim(); k++) + { + double cut; + double cost; + + if (node->IsLeaf()) + cost = SweepLeafNode(k, node, fillFactor, cut); + else + cost = SweepNonLeafNode(k, node, fillFactor, cut); + + + if (cost < minCost) + { + minCost = cost; + minCutAxis = k; + minCut = cut; + } + } + return true; +} + +template +double RPlusTreeSplit::SweepNonLeafNode(size_t axis, const TreeType* node, + size_t fillFactor, double& axisCut) +{ + typedef typename TreeType::ElemType ElemType; + + std::vector> sorted(node->NumChildren()); + + for (size_t i = 0; i < node->NumChildren(); i++) + { + sorted[i].d = node->Children()[i]->Bound()[axis].Hi(); + sorted[i].n = i; + } + std::sort(sorted.begin(), sorted.end(), StructComp); + + axisCut = sorted[fillFactor - 1].d; + + if (!CheckNonLeafSweep(node, axis, axisCut)) + return std::numeric_limits::max(); + + std::vector lowerBound1(node->Bound().Dim()); + std::vector highBound1(node->Bound().Dim()); + std::vector lowerBound2(node->Bound().Dim()); + std::vector highBound2(node->Bound().Dim()); + + for (size_t k = 0; k < node->Bound().Dim(); k++) + { + lowerBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Lo(); + highBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Hi(); + + 
for (size_t i = 1; i < fillFactor; i++) + { + if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound1[k]) + lowerBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Lo(); + if (node->Children()[sorted[i].n]->Bound()[k].Hi() > highBound1[k]) + highBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Hi(); + } + + lowerBound2[k] = node->Children()[sorted[fillFactor].n]->Bound()[k].Lo(); + highBound2[k] = node->Children()[sorted[fillFactor].n]->Bound()[k].Hi(); + + for (size_t i = fillFactor + 1; i < node->NumChildren(); i++) + { + if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound2[k]) + lowerBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Lo(); + if (node->Children()[sorted[i].n]->Bound()[k].Hi() > highBound2[k]) + highBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Hi(); + } + } + + ElemType area1 = 1.0, area2 = 1.0; + ElemType overlappedArea = 1.0; + + for (size_t k = 0; k < node->Bound().Dim(); k++) + { + area1 *= highBound1[k] - lowerBound1[k]; + area2 *= highBound2[k] - lowerBound2[k]; + + /* The two boxes are disjoint along k when either one starts after the other ends; comparing box 2 against its own upper bound would never detect that and would multiply in a negative extent. */ + if (lowerBound1[k] > highBound2[k] || lowerBound2[k] > highBound1[k]) + overlappedArea *= 0; + else + overlappedArea *= std::min(highBound1[k], highBound2[k]) - + std::max(lowerBound1[k], lowerBound2[k]); + } + + /* Cost of the cut: total volume of both groups minus their shared volume. */ + return area1 + area2 - overlappedArea; +} + +template +double RPlusTreeSplit::SweepLeafNode(size_t axis, const TreeType* node, + size_t fillFactor, double& axisCut) +{ + typedef typename TreeType::ElemType ElemType; + + std::vector> sorted(node->Count()); + + sorted.resize(node->Count()); + + for (size_t i = 0; i < node->NumPoints(); i++) + { + sorted[i].d = node->Dataset().col(node->Point(i))[axis]; + sorted[i].n = i; + } + + std::sort(sorted.begin(), sorted.end(), StructComp); + + axisCut = sorted[fillFactor - 1].d; + + if (!CheckLeafSweep(node, axis, axisCut)) + return std::numeric_limits::max(); + + std::vector lowerBound1(node->Bound().Dim()); + std::vector highBound1(node->Bound().Dim()); + std::vector 
lowerBound2(node->Bound().Dim()); + std::vector highBound2(node->Bound().Dim()); + + for (size_t k = 0; k < node->Bound().Dim(); k++) + { + lowerBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k]; + highBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k]; + + for (size_t i = 1; i < fillFactor; i++) + { + if (node->Dataset().col(node->Point(sorted[i].n))[k] < lowerBound1[k]) + lowerBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k]; + if (node->Dataset().col(node->Point(sorted[i].n))[k] > highBound1[k]) + highBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k]; + } + + lowerBound2[k] = node->Dataset().col(node->Point(sorted[fillFactor].n))[k]; + highBound2[k] = node->Dataset().col(node->Point(sorted[fillFactor].n))[k]; + + /* Walk the remaining points of the second group; a leaf has no children, so bounding this loop by NumChildren() would skip all of them. */ + for (size_t i = fillFactor + 1; i < node->NumPoints(); i++) + { + if (node->Dataset().col(node->Point(sorted[i].n))[k] < lowerBound2[k]) + lowerBound2[k] = node->Dataset().col(node->Point(sorted[i].n))[k]; + if (node->Dataset().col(node->Point(sorted[i].n))[k] > highBound2[k]) + highBound2[k] = node->Dataset().col(node->Point(sorted[i].n))[k]; + } + } + + ElemType area1 = 1.0, area2 = 1.0; + ElemType overlappedArea = 1.0; + + for (size_t k = 0; k < node->Bound().Dim(); k++) + { + area1 *= highBound1[k] - lowerBound1[k]; + area2 *= highBound2[k] - lowerBound2[k]; + } + + return area1 + area2 - overlappedArea; +} + +/** + * Append srcNode to the children of destTree and grow destTree's bound to cover it. + */ +template +void RPlusTreeSplit:: +InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode) +{ + destTree->Bound() |= srcNode->Bound(); + destTree->Children()[destTree->NumChildren()++] = srcNode; +} + + +} // namespace tree +} // namespace mlpack + +#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_IMPL_HPP diff --git a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp index 0ec5c51e4e3..6b8e73d66ac 100644 --- a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp +++ 
b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp @@ -27,6 +27,9 @@ void RStarTreeSplit::SplitLeafNode(TreeType *tree,std::vector& relevels) // Convenience typedef. typedef typename TreeType::ElemType ElemType; + if (tree->Count() <= tree->MaxLeafSize()) + return; + // If we are splitting the root node, we need will do things differently so // that the constructor and other methods don't confuse the end user by giving // an address of another node. diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp index db67fbedb95..3ad36292fc0 100644 --- a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp +++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp @@ -23,6 +23,8 @@ namespace tree { template void RTreeSplit::SplitLeafNode(TreeType *tree,std::vector& relevels) { + if (tree->Count() <= tree->MaxLeafSize()) + return; // If we are splitting the root node, we need will do things differently so // that the constructor and other methods don't confuse the end user by giving // an address of another node. diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp index b087f794e2b..c6d8ac2a540 100644 --- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp +++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp @@ -694,8 +694,8 @@ void RectangleTree; +template +using RPlusTree = RectangleTree; } // namespace tree } // namespace mlpack diff --git a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp index 89372a450e9..58591d76c0c 100644 --- a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp +++ b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp @@ -26,6 +26,9 @@ void XTreeSplit::SplitLeafNode(TreeType *tree,std::vector& relevels) // Convenience typedef. 
typedef typename TreeType::ElemType ElemType; + if (tree->Count() <= tree->MaxLeafSize()) + return; + // If we are splitting the root node, we need will do things differently so // that the constructor and other methods don't confuse the end user by giving // an address of another node. diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp index 2af8ec33d56..084514707a6 100644 --- a/src/mlpack/tests/rectangle_tree_test.cpp +++ b/src/mlpack/tests/rectangle_tree_test.cpp @@ -763,6 +763,96 @@ BOOST_AUTO_TEST_CASE(DiscreteHilbertValueTest) BOOST_REQUIRE_EQUAL(DiscreteHilbertValue::ComparePoints(point1,point2), 1); } +/** + * Verify, recursively for every node of the tree, that no two sibling bounds + * overlap (bounds that merely touch are accepted). + */ +template +void CheckOverlap(TreeType* tree) +{ + bool success = true; + + for (size_t i = 0; i < tree->NumChildren(); i++) + { + success = true; + + for (size_t j = 0; j < tree->NumChildren(); j++) + { + if (j == i) + continue; + success = false; + /* Children i and j are disjoint as soon as one axis separates them. */ + for (size_t k = 0; k < tree->Bound().Dim(); k++) + { + if (tree->Children()[i]->Bound()[k].Lo() >= tree->Children()[j]->Bound()[k].Hi() || + tree->Children()[j]->Bound()[k].Lo() >= tree->Children()[i]->Bound()[k].Hi()) + { + success = true; + break; + } + } + if (!success) + break; + } + /* Stop at the first child that overlaps a sibling so the failure is reported; breaking on success would leave the remaining children unchecked. */ + if (!success) + break; + } + /* Use the test framework's check: a plain assert is compiled out in release builds. */ + BOOST_REQUIRE_EQUAL(success, true); + + for (size_t i = 0; i < tree->NumChildren(); i++) + CheckOverlap(tree->Children()[i]); +} + +BOOST_AUTO_TEST_CASE(RPlusTreeOverlapTest) +{ + arma::mat dataset; + dataset.randu(8, 1000); // 1000 points in 8 dimensions. + + typedef RPlusTree,arma::mat> TreeType; + TreeType rPlusTree(dataset, 20, 6, 5, 2, 0); + + CheckOverlap(&rPlusTree); +} + + +BOOST_AUTO_TEST_CASE(RPlusTreeTraverserTest) +{ + arma::mat dataset; + + const int numP = 1000; + + dataset.randu(8, numP); // 1000 points in 8 dimensions. + arma::Mat neighbors1; + arma::mat distances1; + arma::Mat neighbors2; + arma::mat distances2; + + typedef RPlusTree, + arma::mat> TreeType; + TreeType rPlusTree(dataset, 20, 6, 5, 2, 0); + + // Nearest neighbor search with the R+ tree. 
+ + NeighborSearch, arma::mat, RPlusTree > + knn1(&rPlusTree, true); + + BOOST_REQUIRE_EQUAL(rPlusTree.NumDescendants(), numP); + + CheckContainment(rPlusTree); + CheckExactContainment(rPlusTree); + CheckHierarchy(rPlusTree); + CheckOverlap(&rPlusTree); + + knn1.Search(5, neighbors1, distances1); + + // Nearest neighbor search the naive way. + KNN knn2(dataset, true, true); + + knn2.Search(5, neighbors2, distances2); + + for (size_t i = 0; i < neighbors1.size(); i++) + { + BOOST_REQUIRE_EQUAL(neighbors1[i], neighbors2[i]); + BOOST_REQUIRE_EQUAL(distances1[i], distances2[i]); + } +} + // Test the tree splitting. We set MaxLeafSize and MaxNumChildren rather low // to allow us to test by hand without adding hundreds of points. BOOST_AUTO_TEST_CASE(RTreeSplitTest) From 147617add993a5b3d037aa883412c6ec3f672bbb Mon Sep 17 00:00:00 2001 From: Mikhail Lozhnikov Date: Sun, 19 Jun 2016 16:34:28 +0300 Subject: [PATCH 2/8] R++ tree implementation --- src/mlpack/core/tree/CMakeLists.txt | 6 + src/mlpack/core/tree/rectangle_tree.hpp | 4 + .../no_auxiliary_information.hpp | 7 + ...r_plus_plus_tree_auxiliary_information.hpp | 88 ++++++++++++ ...s_plus_tree_auxiliary_information_impl.hpp | 131 +++++++++++++++++ .../r_plus_plus_tree_descent_heuristic.hpp | 49 +++++++ ..._plus_plus_tree_descent_heuristic_impl.hpp | 47 +++++++ .../r_plus_plus_tree_split_policy.hpp | 46 ++++++ .../tree/rectangle_tree/r_plus_tree_split.hpp | 4 + .../rectangle_tree/r_plus_tree_split_impl.hpp | 133 ++++++++++++++---- .../r_plus_tree_split_policy.hpp | 46 ++++++ .../core/tree/rectangle_tree/typedef.hpp | 9 +- src/mlpack/tests/rectangle_tree_test.cpp | 117 ++++++++++++++- 13 files changed, 658 insertions(+), 29 deletions(-) create mode 100644 src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp create mode 100644 src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information_impl.hpp create mode 100644 
src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic.hpp create mode 100644 src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic_impl.hpp create mode 100644 src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_split_policy.hpp create mode 100644 src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_policy.hpp diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt index 527dd59b2d5..1cefc805a49 100644 --- a/src/mlpack/core/tree/CMakeLists.txt +++ b/src/mlpack/core/tree/CMakeLists.txt @@ -67,6 +67,12 @@ set(SOURCES rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp rectangle_tree/r_plus_tree_split.hpp rectangle_tree/r_plus_tree_split_impl.hpp + rectangle_tree/r_plus_tree_split_policy.hpp + rectangle_tree/r_plus_plus_tree_descent_heuristic.hpp + rectangle_tree/r_plus_plus_tree_descent_heuristic_impl.hpp + rectangle_tree/r_plus_plus_tree_split_policy.hpp + rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp + rectangle_tree/r_plus_plus_tree_auxiliary_information_impl.hpp statistic.hpp traversal_info.hpp tree_traits.hpp diff --git a/src/mlpack/core/tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree.hpp index a28cd9f759d..d4183145f89 100644 --- a/src/mlpack/core/tree/rectangle_tree.hpp +++ b/src/mlpack/core/tree/rectangle_tree.hpp @@ -31,7 +31,11 @@ #include "rectangle_tree/recursive_hilbert_value.hpp" #include "rectangle_tree/discrete_hilbert_value.hpp" #include "rectangle_tree/r_plus_tree_descent_heuristic.hpp" +#include "rectangle_tree/r_plus_tree_split_policy.hpp" #include "rectangle_tree/r_plus_tree_split.hpp" +#include "rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp" +#include "rectangle_tree/r_plus_plus_tree_descent_heuristic.hpp" +#include "rectangle_tree/r_plus_plus_tree_split_policy.hpp" #include "rectangle_tree/typedef.hpp" #endif diff --git a/src/mlpack/core/tree/rectangle_tree/no_auxiliary_information.hpp 
b/src/mlpack/core/tree/rectangle_tree/no_auxiliary_information.hpp index ac37908b2fb..046075d1261 100644 --- a/src/mlpack/core/tree/rectangle_tree/no_auxiliary_information.hpp +++ b/src/mlpack/core/tree/rectangle_tree/no_auxiliary_information.hpp @@ -64,6 +64,13 @@ class NoAuxiliaryInformation return false; } + /** + * Nothing to split. + */ + void SplitAuxiliaryInfo(TreeType* , TreeType* , size_t , + typename TreeType::ElemType) + { } + /** * Nothing to copy. */ diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp new file mode 100644 index 00000000000..926c300e050 --- /dev/null +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp @@ -0,0 +1,88 @@ +/** + * @file r_plus_plus_tree_auxiliary_information.hpp + * @author Mikhail Lozhnikov + * + * Definition of the RPlusPlusTreeAuxiliaryInformation class, + * a class that provides some r++-tree specific information + * about the nodes. + */ +#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_AUXILIARY_INFORMATION_HPP +#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_AUXILIARY_INFORMATION_HPP + +#include +#include "../hrectbound.hpp" + +namespace mlpack { +namespace tree { + +template +class RPlusPlusTreeAuxiliaryInformation +{ + public: + typedef typename TreeType::ElemType ElemType; + + RPlusPlusTreeAuxiliaryInformation(); + RPlusPlusTreeAuxiliaryInformation(const TreeType* ); + RPlusPlusTreeAuxiliaryInformation(const RPlusPlusTreeAuxiliaryInformation& ); + + /** + * Some tree types require to save some properties at the insertion process. + * This method should return false if it does not handle the process. + */ + bool HandlePointInsertion(TreeType* , const size_t); + + /** + * Some tree types require to save some properties at the insertion process. + * This method should return false if it does not handle the process. 
+ */ + bool HandleNodeInsertion(TreeType* , TreeType* ,bool); + + /** + * Some tree types require to save some properties at the deletion process. + * This method should return false if it does not handle the process. + */ + bool HandlePointDeletion(TreeType* , const size_t); + + /** + * Some tree types require to save some properties at the deletion process. + * This method should return false if it does not handle the process. + */ + bool HandleNodeRemoval(TreeType* , const size_t); + + + /** + * Some tree types require to propagate the information downward. + * This method should return false if this is not the case. + */ + bool UpdateAuxiliaryInfo(TreeType* ); + + void SplitAuxiliaryInfo(TreeType* treeOne, TreeType* treeTwo, + size_t axis, ElemType cut); + + static void Copy(TreeType* ,const TreeType* ); + + void NullifyData(); + + + bound::HRectBound& OuterBound() + { return outerBound; } + + const bound::HRectBound& OuterBound() const + { return outerBound; } + private: + + bound::HRectBound outerBound; + public: + /** + * Serialize the information. + */ + template + void Serialize(Archive &, const unsigned int /* version */); +}; + +} // namespace tree +} // namespace mlpack + +#include "r_plus_plus_tree_auxiliary_information_impl.hpp" + +#endif//MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_AUXILIARY_INFORMATION_HPP diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information_impl.hpp new file mode 100644 index 00000000000..b6b2aea0e20 --- /dev/null +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information_impl.hpp @@ -0,0 +1,131 @@ +/** + * @file r_plus_plus_tree_auxiliary_information.hpp + * @author Mikhail Lozhnikov + * + * Implementation of the RPlusPlusTreeAuxiliaryInformation class, + * a class that provides some r++-tree specific information + * about the nodes. 
+ */ +#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_AUXILIARY_INFORMATION_IMPL_HPP +#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_AUXILIARY_INFORMATION_IMPL_HPP + +#include "r_plus_plus_tree_auxiliary_information.hpp" + +namespace mlpack { +namespace tree { + +template +RPlusPlusTreeAuxiliaryInformation:: +RPlusPlusTreeAuxiliaryInformation() : + outerBound(0) +{ + +} + +template +RPlusPlusTreeAuxiliaryInformation:: +RPlusPlusTreeAuxiliaryInformation(const TreeType* tree) : + outerBound(tree->Parent() ? + tree->Parent()->AuxiliaryInfo().OuterBound() : + tree->Bound().Dim()) +{ + if (!tree->Parent()) + for (size_t k = 0; k < outerBound.Dim(); k++) + { + outerBound[k].Lo() = std::numeric_limits::lowest(); + outerBound[k].Hi() = std::numeric_limits::max(); + } +} + +template +RPlusPlusTreeAuxiliaryInformation:: +RPlusPlusTreeAuxiliaryInformation(const RPlusPlusTreeAuxiliaryInformation& other) : + outerBound(other.OuterBound()) +{ + +} + +template +bool RPlusPlusTreeAuxiliaryInformation:: +HandlePointInsertion(TreeType* , const size_t ) +{ + return false; +} + +template +bool RPlusPlusTreeAuxiliaryInformation:: +HandleNodeInsertion(TreeType* , TreeType* ,bool) +{ + assert(false); + return false; +} + +template +bool RPlusPlusTreeAuxiliaryInformation:: +HandlePointDeletion(TreeType* , const size_t) +{ + return false; +} + +template +bool RPlusPlusTreeAuxiliaryInformation:: +HandleNodeRemoval(TreeType* , const size_t) +{ + return false; +} + +template +bool RPlusPlusTreeAuxiliaryInformation:: +UpdateAuxiliaryInfo(TreeType* ) +{ + return false; +} + +template +void RPlusPlusTreeAuxiliaryInformation:: +SplitAuxiliaryInfo(TreeType* treeOne, TreeType* treeTwo, size_t axis, + typename TreeType::ElemType cut) +{ + typedef bound::HRectBound Bound; + Bound& treeOneBound = treeOne->AuxiliaryInfo().OuterBound(); + Bound& treeTwoBound = treeTwo->AuxiliaryInfo().OuterBound(); + + treeOneBound = outerBound; + treeTwoBound = outerBound; + + treeOneBound[axis].Hi() 
= cut; + treeTwoBound[axis].Lo() = cut; +} + + +template +void RPlusPlusTreeAuxiliaryInformation:: +Copy(TreeType* dst, const TreeType* src) +{ + dst->AuxiliaryInfo().OuterBound() = src->AuxiliaryInfo().OuterBound(); +} + +template +void RPlusPlusTreeAuxiliaryInformation:: +NullifyData() +{ + +} + +/** + * Serialize the information. + */ +template +template +void RPlusPlusTreeAuxiliaryInformation:: +Serialize(Archive& ar, const unsigned int /* version */) +{ + using data::CreateNVP; + + ar & CreateNVP(outerBound, "outerBound"); +} + +} // namespace tree +} // namespace mlpack + +#endif//MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_AUXILIARY_INFORMATION_IMPL_HPP diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic.hpp new file mode 100644 index 00000000000..18166f4a4ff --- /dev/null +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic.hpp @@ -0,0 +1,49 @@ +/** + * @file r_plus_plus_tree_descent_heuristic.hpp + * @author Mikhail Lozhnikov + * + * Definition of RPlusPlusTreeDescentHeuristic, a class that chooses the best child of a + * node in an R++ tree when inserting a new point. + */ +#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_HPP +#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_HPP + +#include + +namespace mlpack { +namespace tree { + +class RPlusPlusTreeDescentHeuristic +{ + public: + /** + * Evaluate the node using a heuristic. Returns the index of the first + * child whose outer bound (kept in its auxiliary information) contains + * the point being inserted. + * + * @param node The node that is being evaluated. + * @param point The number of the point that is being inserted. + */ + template + static size_t ChooseDescentNode(TreeType* node, const size_t point); + + /** + * Evaluate the node using a heuristic.
Returns the index of the child + * into which insertedNode should descend. The current implementation does + * not support this overload: it only asserts and returns 0. + * + * @param node The node that is being evaluated. + * @param insertedNode The node that is being inserted. + */ + template + static size_t ChooseDescentNode(const TreeType* node, + const TreeType* insertedNode); + +}; + +} // namespace tree +} // namespace mlpack + +#include "r_plus_plus_tree_descent_heuristic_impl.hpp" + +#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_HPP diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic_impl.hpp new file mode 100644 index 00000000000..566f686efcd --- /dev/null +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic_impl.hpp @@ -0,0 +1,47 @@ +/** + * @file r_plus_plus_tree_descent_heuristic_impl.hpp + * @author Mikhail Lozhnikov + * + * Implementation of RPlusPlusTreeDescentHeuristic, a class that chooses the best child + * of a node in an R++ tree when inserting a new point.
+ */ +#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP +#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP + +#include "r_plus_plus_tree_descent_heuristic.hpp" +#include "../hrectbound.hpp" + +namespace mlpack { +namespace tree { + +template +size_t RPlusPlusTreeDescentHeuristic:: +ChooseDescentNode(TreeType* node, const size_t point) +{ + for (size_t bestIndex = 0; bestIndex < node->NumChildren(); bestIndex++) + { + if (node->Children()[bestIndex]->AuxiliaryInfo().OuterBound().Contains(node->Dataset().col(point))) + return bestIndex; + } + + assert(false); + + return 0; +} + +template +size_t RPlusPlusTreeDescentHeuristic:: +ChooseDescentNode(const TreeType* , const TreeType* ) +{ + size_t bestIndex = 0; + + assert(false); + + return bestIndex; +} + + +} // namespace tree +} // namespace mlpack + +#endif //MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_split_policy.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_split_policy.hpp new file mode 100644 index 00000000000..e0484afe483 --- /dev/null +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_split_policy.hpp @@ -0,0 +1,46 @@ +/** + * @file r_plus_plus_tree_split_policy.hpp + * @author Mikhail Lozhnikov + * + * Defintion and implementation of the RPlusPlusTreeSplitPolicy class, a class that + * helps to determine the node into which we should insert an intermediate node. 
+ */ +#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_SPLIT_POLICY_HPP +#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_SPLIT_POLICY_HPP + +namespace mlpack { +namespace tree { + +class RPlusPlusTreeSplitPolicy +{ + public: + static const int SplitRequired = 0; + static const int AssignToFirstTree = 1; + static const int AssignToSecondTree = 2; + + template + static int GetSplitPolicy(const TreeType* child, size_t axis, + typename TreeType::ElemType cut) + { + if (child->AuxiliaryInfo().OuterBound()[axis].Hi() <= cut) + return AssignToFirstTree; + else if (child->AuxiliaryInfo().OuterBound()[axis].Lo() >= cut) + return AssignToSecondTree; + + return SplitRequired; + } + + template + static const + bound::HRectBound& + Bound(const TreeType* node) + { + return node->AuxiliaryInfo().OuterBound(); + } +}; + +} // namespace tree +} // namespace mlpack +#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_SPLIT_POLICY_HPP + + diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp index f06b813a66a..8c2d0566728 100644 --- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp @@ -15,6 +15,7 @@ const double fillFactorFraction = 0.5; namespace mlpack { namespace tree /** Trees and tree-building procedures. 
*/ { +template class RPlusTreeSplit { public: @@ -62,6 +63,9 @@ class RPlusTreeSplit static void SplitNonLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut); + template + static void AddFakeNodes(const TreeType* tree, TreeType* emptyTree); + template static bool PartitionNode(const TreeType* node, size_t fillFactor, size_t& minCutAxis, double& minCut); diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp index fca51fc0eb8..e48487288b6 100644 --- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp @@ -9,12 +9,16 @@ #include "r_plus_tree_split.hpp" #include "rectangle_tree.hpp" +#include "r_plus_plus_tree_auxiliary_information.hpp" +#include "r_plus_tree_split_policy.hpp" +#include "r_plus_plus_tree_split_policy.hpp" namespace mlpack { namespace tree { +template template -void RPlusTreeSplit::SplitLeafNode(TreeType* tree, std::vector& relevels) +void RPlusTreeSplit::SplitLeafNode(TreeType* tree, std::vector& relevels) { if (tree->Count() == 1) { @@ -88,8 +92,9 @@ void RPlusTreeSplit::SplitLeafNode(TreeType* tree, std::vector& relevels) tree->SoftDelete(); } +template template -bool RPlusTreeSplit::SplitNonLeafNode(TreeType* tree, +bool RPlusTreeSplit::SplitNonLeafNode(TreeType* tree, std::vector& relevels) { // If we are splitting the root node, we need will do things differently so @@ -148,10 +153,13 @@ bool RPlusTreeSplit::SplitNonLeafNode(TreeType* tree, return false; } +template template -void RPlusTreeSplit::SplitLeafNodeAlongPartition(TreeType* tree, +void RPlusTreeSplit::SplitLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut) { + tree->AuxiliaryInfo().SplitAuxiliaryInfo(treeOne, treeTwo, cutAxis, cut); + for (size_t i = 0; i < tree->NumPoints(); i++) { if 
(tree->Dataset().col(tree->Point(i))[cutAxis] <= cut) @@ -172,19 +180,24 @@ void RPlusTreeSplit::SplitLeafNodeAlongPartition(TreeType* tree, assert(treeOne->Bound()[cutAxis].Hi() < treeTwo->Bound()[cutAxis].Lo()); } +template template -void RPlusTreeSplit::SplitNonLeafNodeAlongPartition(TreeType* tree, +void RPlusTreeSplit::SplitNonLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut) { + tree->AuxiliaryInfo().SplitAuxiliaryInfo(treeOne, treeTwo, cutAxis, cut); + for (size_t i = 0; i < tree->NumChildren(); i++) { TreeType* child = tree->Children()[i]; - if (child->Bound()[cutAxis].Hi() <= cut) + int policy = SplitPolicyType::GetSplitPolicy(child, cutAxis, cut); + + if (policy == SplitPolicyType::AssignToFirstTree) { InsertNodeIntoTree(treeOne, child); child->Parent() = treeOne; } - else if (child->Bound()[cutAxis].Lo() >= cut) + else if (policy == SplitPolicyType::AssignToSecondTree) { InsertNodeIntoTree(treeTwo, child); child->Parent() = treeTwo; @@ -209,12 +222,46 @@ void RPlusTreeSplit::SplitNonLeafNodeAlongPartition(TreeType* tree, child->SoftDelete(); } } + + assert(treeOne->NumChildren() + treeTwo->NumChildren() != 0); + + if (treeOne->NumChildren() == 0) + AddFakeNodes(treeTwo, treeOne); + else if (treeTwo->NumChildren() == 0) + AddFakeNodes(treeOne, treeTwo); + assert(treeOne->NumChildren() <= treeOne->MaxNumChildren()); assert(treeTwo->NumChildren() <= treeTwo->MaxNumChildren()); } +template template -bool RPlusTreeSplit::CheckNonLeafSweep(const TreeType* node, +void RPlusTreeSplit:: +AddFakeNodes(const TreeType* tree, TreeType* emptyTree) +{ + size_t numDescendantNodes = 1; + + TreeType* node = tree->Children()[0]; + + while (!node->IsLeaf()) + { + numDescendantNodes++; + node = node->Children()[0]; + } + + node = emptyTree; + for (size_t i = 0; i < numDescendantNodes; i++) + { + TreeType* child = new TreeType(node); + InsertNodeIntoTree(node, child); + node = child; + } +} + + +template +template +bool
RPlusTreeSplit::CheckNonLeafSweep(const TreeType* node, size_t cutAxis, double cut) { size_t numTreeOneChildren = 0; @@ -223,9 +270,10 @@ bool RPlusTreeSplit::CheckNonLeafSweep(const TreeType* node, for (size_t i = 0; i < node->NumChildren(); i++) { TreeType* child = node->Children()[i]; - if (child->Bound()[cutAxis].Hi() <= cut) + int policy = SplitPolicyType::GetSplitPolicy(child, cutAxis, cut); + if (policy == SplitPolicyType::AssignToFirstTree) numTreeOneChildren++; - else if (child->Bound()[cutAxis].Lo() >= cut) + else if (policy == SplitPolicyType::AssignToSecondTree) numTreeTwoChildren++; else { @@ -240,8 +288,9 @@ bool RPlusTreeSplit::CheckNonLeafSweep(const TreeType* node, return false; } +template template -bool RPlusTreeSplit::CheckLeafSweep(const TreeType* node, +bool RPlusTreeSplit::CheckLeafSweep(const TreeType* node, size_t cutAxis, double cut) { size_t numTreeOnePoints = 0; @@ -261,8 +310,9 @@ bool RPlusTreeSplit::CheckLeafSweep(const TreeType* node, return false; } +template template -bool RPlusTreeSplit::PartitionNode(const TreeType* node, size_t fillFactor, +bool RPlusTreeSplit::PartitionNode(const TreeType* node, size_t fillFactor, size_t& minCutAxis, double& minCut) { if ((node->NumChildren() <= fillFactor && !node->IsLeaf()) || @@ -293,8 +343,9 @@ bool RPlusTreeSplit::PartitionNode(const TreeType* node, size_t fillFactor, return true; } +template template -double RPlusTreeSplit::SweepNonLeafNode(size_t axis, const TreeType* node, +double RPlusTreeSplit::SweepNonLeafNode(size_t axis, const TreeType* node, size_t fillFactor, double& axisCut) { typedef typename TreeType::ElemType ElemType; @@ -303,15 +354,27 @@ double RPlusTreeSplit::SweepNonLeafNode(size_t axis, const TreeType* node, for (size_t i = 0; i < node->NumChildren(); i++) { - sorted[i].d = node->Children()[i]->Bound()[axis].Hi(); + sorted[i].d = SplitPolicyType::Bound(node->Children()[i])[axis].Hi(); sorted[i].n = i; } std::sort(sorted.begin(), sorted.end(), StructComp); - axisCut = 
sorted[fillFactor - 1].d; + size_t splitPointer = fillFactor; + + axisCut = sorted[splitPointer - 1].d; if (!CheckNonLeafSweep(node, axis, axisCut)) - return std::numeric_limits::max(); + { + for (splitPointer = 1; splitPointer < node->NumChildren(); splitPointer++) + { + axisCut = sorted[splitPointer - 1].d; + if (CheckNonLeafSweep(node, axis, axisCut)) + break; + } + + if (splitPointer == node->NumChildren()) + return std::numeric_limits::max(); + } std::vector lowerBound1(node->Bound().Dim()); std::vector highBound1(node->Bound().Dim()); @@ -323,7 +386,7 @@ double RPlusTreeSplit::SweepNonLeafNode(size_t axis, const TreeType* node, lowerBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Lo(); highBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Hi(); - for (size_t i = 1; i < fillFactor; i++) + for (size_t i = 1; i < splitPointer; i++) { if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound1[k]) lowerBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Lo(); @@ -331,10 +394,10 @@ double RPlusTreeSplit::SweepNonLeafNode(size_t axis, const TreeType* node, highBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Hi(); } - lowerBound2[k] = node->Children()[sorted[fillFactor].n]->Bound()[k].Lo(); - highBound2[k] = node->Children()[sorted[fillFactor].n]->Bound()[k].Hi(); + lowerBound2[k] = node->Children()[sorted[splitPointer].n]->Bound()[k].Lo(); + highBound2[k] = node->Children()[sorted[splitPointer].n]->Bound()[k].Hi(); - for (size_t i = fillFactor + 1; i < node->NumChildren(); i++) + for (size_t i = splitPointer + 1; i < node->NumChildren(); i++) { if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound2[k]) lowerBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Lo(); @@ -348,21 +411,38 @@ double RPlusTreeSplit::SweepNonLeafNode(size_t axis, const TreeType* node, for (size_t k = 0; k < node->Bound().Dim(); k++) { - area1 *= highBound1[k] - lowerBound1[k]; - area2 *= highBound2[k] - lowerBound2[k]; + if (lowerBound1[k] >= 
highBound1[k]) + { + overlappedArea *= 0; + area1 *= 0; + } + else + area1 *= highBound1[k] - lowerBound1[k]; - if (lowerBound1[k] > highBound2[k] || lowerBound2[k] > highBound2[k]) + if (lowerBound2[k] >= highBound2[k]) + { overlappedArea *= 0; + area2 *= 0; + } else - overlappedArea *= std::min(highBound1[k], highBound2[k]) - - std::max(lowerBound1[k], lowerBound2[k]); + area2 *= highBound2[k] - lowerBound2[k]; + + if (lowerBound1[k] < highBound1[k] && lowerBound2[k] < highBound2[k]) + { + if (lowerBound1[k] > highBound2[k] || lowerBound2[k] > highBound1[k]) + overlappedArea *= 0; + else + overlappedArea *= std::min(highBound1[k], highBound2[k]) - + std::max(lowerBound1[k], lowerBound2[k]); + } } return area1 + area2 - overlappedArea; } +template template -double RPlusTreeSplit::SweepLeafNode(size_t axis, const TreeType* node, +double RPlusTreeSplit::SweepLeafNode(size_t axis, const TreeType* node, size_t fillFactor, double& axisCut) { typedef typename TreeType::ElemType ElemType; @@ -426,8 +506,9 @@ double RPlusTreeSplit::SweepLeafNode(size_t axis, const TreeType* node, return area1 + area2 - overlappedArea; } +template template -void RPlusTreeSplit:: +void RPlusTreeSplit:: InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode) { destTree->Bound() |= srcNode->Bound(); diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_policy.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_policy.hpp new file mode 100644 index 00000000000..219eb1a390b --- /dev/null +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_policy.hpp @@ -0,0 +1,46 @@ +/** + * @file r_plus_tree_split_policy.hpp + * @author Mikhail Lozhnikov + * + * Definition and implementation of the RPlusTreeSplitPolicy class, a class that + * helps to determine the node into which we should insert an intermediate node.
+ */ +#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_POLICY_HPP +#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_POLICY_HPP + +namespace mlpack { +namespace tree { + +class RPlusTreeSplitPolicy +{ + public: + static const int SplitRequired = 0; + static const int AssignToFirstTree = 1; + static const int AssignToSecondTree = 2; + + template + static int GetSplitPolicy(const TreeType* child, size_t axis, + typename TreeType::ElemType cut) + { + if (child->Bound()[axis].Hi() <= cut) + return AssignToFirstTree; + else if (child->Bound()[axis].Lo() >= cut) + return AssignToSecondTree; + + return SplitRequired; + } + + template + static const + bound::HRectBound& + Bound(const TreeType* node) + { + return node->Bound(); + } +}; + +} // namespace tree +} // namespace mlpack +#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_POLICY_HPP + + diff --git a/src/mlpack/core/tree/rectangle_tree/typedef.hpp b/src/mlpack/core/tree/rectangle_tree/typedef.hpp index 4737d5a9a75..b22c3b105e9 100644 --- a/src/mlpack/core/tree/rectangle_tree/typedef.hpp +++ b/src/mlpack/core/tree/rectangle_tree/typedef.hpp @@ -129,10 +129,17 @@ template using RPlusTree = RectangleTree, RPlusTreeDescentHeuristic, NoAuxiliaryInformation>; +template +using RPlusPlusTree = RectangleTree, + RPlusPlusTreeDescentHeuristic, + RPlusPlusTreeAuxiliaryInformation>; } // namespace tree } // namespace mlpack diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp index 084514707a6..3568a78ac21 100644 --- a/src/mlpack/tests/rectangle_tree_test.cpp +++ b/src/mlpack/tests/rectangle_tree_test.cpp @@ -789,10 +789,10 @@ void CheckOverlap(TreeType* tree) if (!success) break; } - if (success) + if (!success) break; } - assert(success == true); + BOOST_REQUIRE_EQUAL(success, true); for (size_t i = 0; i < tree->NumChildren(); i++) CheckOverlap(tree->Children()[i]); @@ -853,6 +853,119 @@ BOOST_AUTO_TEST_CASE(RPlusTreeTraverserTest) } } +template +void 
CheckRPlusPlusTreeBound(const TreeType* tree) +{ + typedef bound::HRectBound Bound; + + bool success = true; + + for (size_t k = 0; k < tree->Bound().Dim(); k++) + { + BOOST_REQUIRE_LE(tree->Bound()[k].Hi(), + tree->AuxiliaryInfo().OuterBound()[k].Hi()); + BOOST_REQUIRE_LE(tree->AuxiliaryInfo().OuterBound()[k].Lo(), + tree->Bound()[k].Lo()); + } + + if (tree->IsLeaf()) + { + for (size_t i = 0; i < tree->Count(); i++) + BOOST_REQUIRE_EQUAL(true, + tree->Bound().Contains(tree->Dataset().col(tree->Points()[i]))); + + return; + } + + for (size_t i = 0; i < tree->NumChildren(); i++) + { + const Bound& bound1 = tree->Children()[i]->AuxiliaryInfo().OuterBound(); + success = true; + + for (size_t j = 0; j < tree->NumChildren(); j++) + { + if (j == i) + continue; + const Bound& bound2 = tree->Children()[j]->AuxiliaryInfo().OuterBound(); + + success = false; + for (size_t k = 0; k < tree->Bound().Dim(); k++) + { + if (bound1[k].Lo() >= bound2[k].Hi() || + bound2[k].Lo() >= bound1[k].Hi()) + { + success = true; + break; + } + } + if (!success) + break; + } + if (!success) + break; + } + BOOST_REQUIRE_EQUAL(success, true); + + for (size_t i = 0; i < tree->NumChildren(); i++) + CheckRPlusPlusTreeBound(tree->Children()[i]); +} + +BOOST_AUTO_TEST_CASE(RPlusPlusTreeBoundTest) +{ + arma::mat dataset; + dataset.randu(8, 1000); // 1000 points in 8 dimensions. + + typedef RPlusPlusTree,arma::mat> TreeType; + TreeType rPlusPlusTree(dataset, 20, 6, 5, 2, 0); + + CheckRPlusPlusTreeBound(&rPlusPlusTree); +} + +BOOST_AUTO_TEST_CASE(RPlusPlusTreeTraverserTest) +{ + arma::mat dataset; + + const int numP = 1000; + + dataset.randu(8, numP); // 1000 points in 8 dimensions. + arma::Mat neighbors1; + arma::mat distances1; + arma::Mat neighbors2; + arma::mat distances2; + + typedef RPlusPlusTree, + arma::mat> TreeType; + TreeType rPlusPlusTree(dataset, 20, 6, 5, 2, 0); + + // Nearest neighbor search with the X tree. 
+ + NeighborSearch, + arma::mat, RPlusPlusTree > knn1(&rPlusPlusTree, true); + + BOOST_REQUIRE_EQUAL(rPlusPlusTree.NumDescendants(), numP); + + CheckContainment(rPlusPlusTree); + CheckExactContainment(rPlusPlusTree); + CheckHierarchy(rPlusPlusTree); + CheckRPlusPlusTreeBound(&rPlusPlusTree); + + knn1.Search(5, neighbors1, distances1); + + // Nearest neighbor search the naive way. + KNN knn2(dataset, true, true); + + knn2.Search(5, neighbors2, distances2); + + for (size_t i = 0; i < neighbors1.size(); i++) + { + BOOST_REQUIRE_EQUAL(neighbors1[i], neighbors2[i]); + BOOST_REQUIRE_EQUAL(distances1[i], distances2[i]); + } +} + + // Test the tree splitting. We set MaxLeafSize and MaxNumChildren rather low // to allow us to test by hand without adding hundreds of points. BOOST_AUTO_TEST_CASE(RTreeSplitTest) From da4b598cc0eacec8807611824887feea1fd397ad Mon Sep 17 00:00:00 2001 From: Mikhail Lozhnikov Date: Thu, 23 Jun 2016 22:59:29 +0300 Subject: [PATCH 3/8] R+ and R++ trees refactoring. Added a template parameter SweepType. Implemented MinimalSplitsNumberSweep. 
--- src/mlpack/core/tree/CMakeLists.txt | 4 + src/mlpack/core/tree/rectangle_tree.hpp | 2 + .../rectangle_tree/minimal_coverage_sweep.hpp | 64 ++++ .../minimal_coverage_sweep_impl.hpp | 236 +++++++++++++++ .../minimal_splits_number_sweep.hpp | 53 ++++ .../minimal_splits_number_sweep_impl.hpp | 89 ++++++ .../r_plus_tree_descent_heuristic_impl.hpp | 2 +- .../tree/rectangle_tree/r_plus_tree_split.hpp | 39 +-- .../rectangle_tree/r_plus_tree_split_impl.hpp | 285 +++--------------- .../core/tree/rectangle_tree/typedef.hpp | 6 +- src/mlpack/tests/rectangle_tree_test.cpp | 23 +- 11 files changed, 525 insertions(+), 278 deletions(-) create mode 100644 src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep.hpp create mode 100644 src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp create mode 100644 src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep.hpp create mode 100644 src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt index 1cefc805a49..0a46a1eaae4 100644 --- a/src/mlpack/core/tree/CMakeLists.txt +++ b/src/mlpack/core/tree/CMakeLists.txt @@ -65,6 +65,10 @@ set(SOURCES rectangle_tree/discrete_hilbert_value_impl.hpp rectangle_tree/r_plus_tree_descent_heuristic.hpp rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp + rectangle_tree/minimal_coverage_sweep.hpp + rectangle_tree/minimal_coverage_sweep_impl.hpp + rectangle_tree/minimal_splits_number_sweep.hpp + rectangle_tree/minimal_splits_number_sweep_impl.hpp rectangle_tree/r_plus_tree_split.hpp rectangle_tree/r_plus_tree_split_impl.hpp rectangle_tree/r_plus_tree_split_policy.hpp diff --git a/src/mlpack/core/tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree.hpp index d4183145f89..a31b75b558e 100644 --- a/src/mlpack/core/tree/rectangle_tree.hpp +++ b/src/mlpack/core/tree/rectangle_tree.hpp @@ -32,6 +32,8 @@ #include "rectangle_tree/discrete_hilbert_value.hpp" #include 
"rectangle_tree/r_plus_tree_descent_heuristic.hpp" #include "rectangle_tree/r_plus_tree_split_policy.hpp" +#include "rectangle_tree/minimal_coverage_sweep.hpp" +#include "rectangle_tree/minimal_splits_number_sweep.hpp" #include "rectangle_tree/r_plus_tree_split.hpp" #include "rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp" #include "rectangle_tree/r_plus_plus_tree_descent_heuristic.hpp" diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep.hpp new file mode 100644 index 00000000000..4e9ce5a60e2 --- /dev/null +++ b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep.hpp @@ -0,0 +1,64 @@ +/** + * @file minimal_coverage_sweep.hpp + * @author Mikhail Lozhnikov + * + */ +#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_HPP +#define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_HPP + +namespace mlpack { +namespace tree { + +constexpr double fillFactor = 0.5; + +template +class MinimalCoverageSweep +{ + private: + template + struct SortStruct + { + ElemType d; + int n; + }; + + template + static bool StructComp(const SortStruct& s1, + const SortStruct& s2) + { + return s1.d < s2.d; + } + + public: + + template + struct SweepCost + { + typedef typename TreeType::ElemType type; + }; + + template + static typename TreeType::ElemType SweepNonLeafNode(size_t axis, + const TreeType* node, typename TreeType::ElemType& axisCut); + + template + static typename TreeType::ElemType SweepLeafNode(size_t axis, + const TreeType* node, typename TreeType::ElemType& axisCut); + + template + static bool CheckNonLeafSweep(const TreeType* node, size_t cutAxis, + ElemType cut); + + template + static bool CheckLeafSweep(const TreeType* node, size_t cutAxis, + ElemType cut); +}; + +} // namespace tree +} // namespace mlpack + +// Include implementation +#include "minimal_coverage_sweep_impl.hpp" + +#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_HPP + diff 
--git a/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp new file mode 100644 index 00000000000..b300cfa3044 --- /dev/null +++ b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp @@ -0,0 +1,236 @@ +/** + * @file minimal_coverage_sweep_impl.hpp + * @author Mikhail Lozhnikov + * + */ +#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_IMPL_HPP +#define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_IMPL_HPP + +#include "minimal_coverage_sweep.hpp" + +namespace mlpack { +namespace tree { + +template +template +typename TreeType::ElemType MinimalCoverageSweep:: +SweepNonLeafNode(size_t axis, const TreeType* node, + typename TreeType::ElemType& axisCut) +{ + typedef typename TreeType::ElemType ElemType; + + std::vector> sorted(node->NumChildren()); + + for (size_t i = 0; i < node->NumChildren(); i++) + { + sorted[i].d = SplitPolicy::Bound(node->Children()[i])[axis].Hi(); + sorted[i].n = i; + } + std::sort(sorted.begin(), sorted.end(), StructComp); + + size_t splitPointer = fillFactor * node->NumChildren(); + + axisCut = sorted[splitPointer - 1].d; + + if (!CheckNonLeafSweep(node, axis, axisCut)) + { + for (splitPointer = 1; splitPointer < sorted.size(); splitPointer++) + { + axisCut = sorted[splitPointer - 1].d; + if (CheckNonLeafSweep(node, axis, axisCut)) + break; + } + + if (splitPointer == node->NumChildren()) + return std::numeric_limits::max(); + } + + std::vector lowerBound1(node->Bound().Dim()); + std::vector highBound1(node->Bound().Dim()); + std::vector lowerBound2(node->Bound().Dim()); + std::vector highBound2(node->Bound().Dim()); + + for (size_t k = 0; k < node->Bound().Dim(); k++) + { + lowerBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Lo(); + highBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Hi(); + + for (size_t i = 1; i < splitPointer; i++) + { + if (node->Children()[sorted[i].n]->Bound()[k].Lo() < 
lowerBound1[k]) + lowerBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Lo(); + if (node->Children()[sorted[i].n]->Bound()[k].Hi() > highBound1[k]) + highBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Hi(); + } + + lowerBound2[k] = node->Children()[sorted[splitPointer].n]->Bound()[k].Lo(); + highBound2[k] = node->Children()[sorted[splitPointer].n]->Bound()[k].Hi(); + + for (size_t i = splitPointer + 1; i < node->NumChildren(); i++) + { + if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound2[k]) + lowerBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Lo(); + if (node->Children()[sorted[i].n]->Bound()[k].Hi() > highBound2[k]) + highBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Hi(); + } + } + + ElemType area1 = 1.0, area2 = 1.0; + ElemType overlappedArea = 1.0; + + for (size_t k = 0; k < node->Bound().Dim(); k++) + { + if (lowerBound1[k] >= highBound1[k]) + { + overlappedArea *= 0; + area1 *= 0; + } + else + area1 *= highBound1[k] - lowerBound1[k]; + + if (lowerBound2[k] >= highBound2[k]) + { + overlappedArea *= 0; + area2 *= 0; + } + else + area2 *= highBound2[k] - lowerBound2[k]; + + if (lowerBound1[k] < highBound1[k] && lowerBound2[k] < highBound2[k]) + { + if (lowerBound1[k] > highBound2[k] || lowerBound2[k] > highBound1[k]) + overlappedArea *= 0; + else + overlappedArea *= std::min(highBound1[k], highBound2[k]) - + std::max(lowerBound1[k], lowerBound2[k]); + } + } + + return area1 + area2 - overlappedArea; +} + +template +template +typename TreeType::ElemType MinimalCoverageSweep:: +SweepLeafNode(size_t axis, const TreeType* node, + typename TreeType::ElemType& axisCut) +{ + typedef typename TreeType::ElemType ElemType; + + std::vector> sorted(node->Count()); + + sorted.resize(node->Count()); + + for (size_t i = 0; i < node->NumPoints(); i++) + { + sorted[i].d = node->Dataset().col(node->Point(i))[axis]; + sorted[i].n = i; + } + + std::sort(sorted.begin(), sorted.end(), StructComp); + + size_t splitPointer = fillFactor *
node->Count(); + + axisCut = sorted[splitPointer - 1].d; + + if (!CheckLeafSweep(node, axis, axisCut)) + return std::numeric_limits::max(); + + std::vector lowerBound1(node->Bound().Dim()); + std::vector highBound1(node->Bound().Dim()); + std::vector lowerBound2(node->Bound().Dim()); + std::vector highBound2(node->Bound().Dim()); + + for (size_t k = 0; k < node->Bound().Dim(); k++) + { + lowerBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k]; + highBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k]; + + for (size_t i = 1; i < splitPointer; i++) + { + if (node->Dataset().col(node->Point(sorted[i].n))[k] < lowerBound1[k]) + lowerBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k]; + if (node->Dataset().col(node->Point(sorted[i].n))[k] > highBound1[k]) + highBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k]; + } + + lowerBound2[k] = node->Dataset().col(node->Point(sorted[splitPointer].n))[k]; + highBound2[k] = node->Dataset().col(node->Point(sorted[splitPointer].n))[k]; + + for (size_t i = splitPointer + 1; i < node->NumChildren(); i++) + { + if (node->Dataset().col(node->Point(sorted[i].n))[k] < lowerBound2[k]) + lowerBound2[k] = node->Dataset().col(node->Point(sorted[i].n))[k]; + if (node->Dataset().col(node->Point(sorted[i].n))[k] > highBound2[k]) + highBound2[k] = node->Dataset().col(node->Point(sorted[i].n))[k]; + } + } + + ElemType area1 = 1.0, area2 = 1.0; + ElemType overlappedArea = 1.0; + + for (size_t k = 0; k < node->Bound().Dim(); k++) + { + area1 *= highBound1[k] - lowerBound1[k]; + area2 *= highBound2[k] - lowerBound2[k]; + } + + return area1 + area2 - overlappedArea; +} + +template +template +bool MinimalCoverageSweep:: +CheckNonLeafSweep(const TreeType* node, size_t cutAxis, ElemType cut) +{ + size_t numTreeOneChildren = 0; + size_t numTreeTwoChildren = 0; + + for (size_t i = 0; i < node->NumChildren(); i++) + { + TreeType* child = node->Children()[i]; + int policy = SplitPolicy::GetSplitPolicy(child, cutAxis, 
cut); + if (policy == SplitPolicy::AssignToFirstTree) + numTreeOneChildren++; + else if (policy == SplitPolicy::AssignToSecondTree) + numTreeTwoChildren++; + else + { + numTreeOneChildren++; + numTreeTwoChildren++; + } + } + + if (numTreeOneChildren <= node->MaxNumChildren() && numTreeOneChildren > 0 && + numTreeTwoChildren <= node->MaxNumChildren() && numTreeTwoChildren > 0) + return true; + return false; +} + +template +template +bool MinimalCoverageSweep:: +CheckLeafSweep(const TreeType* node, size_t cutAxis, ElemType cut) +{ + size_t numTreeOnePoints = 0; + size_t numTreeTwoPoints = 0; + + for (size_t i = 0; i < node->NumPoints(); i++) + { + if (node->Dataset().col(node->Point(i))[cutAxis] <= cut) + numTreeOnePoints++; + else + numTreeTwoPoints++; + } + + if (numTreeOnePoints <= node->MaxLeafSize() && numTreeOnePoints > 0 && + numTreeTwoPoints <= node->MaxLeafSize() && numTreeTwoPoints > 0) + return true; + return false; +} + +} // namespace tree +} // namespace mlpack + +#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_IMPL_HPP + diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep.hpp new file mode 100644 index 00000000000..60cbe2dd7f2 --- /dev/null +++ b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep.hpp @@ -0,0 +1,53 @@ +/** + * @file minimal_splits_number_sweep.hpp + * @author Mikhail Lozhnikov + * + */ +#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_HPP +#define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_HPP + +namespace mlpack { +namespace tree { + +template +class MinimalSplitsNumberSweep +{ + private: + template + struct SortStruct + { + ElemType d; + int n; + }; + + template + static bool StructComp(const SortStruct& s1, + const SortStruct& s2) + { + return s1.d < s2.d; + } + public: + template + struct SweepCost + { + typedef size_t type; + }; + + template + static size_t 
SweepNonLeafNode(size_t axis, const TreeType* node, + typename TreeType::ElemType& axisCut); + + template + static size_t SweepLeafNode(size_t axis, const TreeType* node, + typename TreeType::ElemType& axisCut); +}; + +} // namespace tree +} // namespace mlpack + +// Include implementation +#include "minimal_splits_number_sweep_impl.hpp" + +#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_HPP + + diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp new file mode 100644 index 00000000000..21c0dcb79b3 --- /dev/null +++ b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp @@ -0,0 +1,89 @@ +/** + * @file minimal_splits_number_sweep_impl.hpp + * @author Mikhail Lozhnikov + * + */ +#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_IMPL_HPP +#define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_IMPL_HPP + +#include "minimal_splits_number_sweep.hpp" + +namespace mlpack { +namespace tree { + +template +template +size_t MinimalSplitsNumberSweep:: +SweepNonLeafNode(size_t axis, const TreeType* node, + typename TreeType::ElemType& axisCut) +{ + typedef typename TreeType::ElemType ElemType; + + std::vector> sorted(node->NumChildren()); + + for (size_t i = 0; i < node->NumChildren(); i++) + { + sorted[i].d = SplitPolicy::Bound(node->Children()[i])[axis].Hi(); + sorted[i].n = i; + } + std::sort(sorted.begin(), sorted.end(), StructComp); + + size_t minCost = SIZE_MAX; + + for (size_t i = 0; i < sorted.size(); i++) + { + size_t numTreeOneChildren = 0; + size_t numTreeTwoChildren = 0; + size_t numSplits = 0; + + for (size_t j = 0; j < node->NumChildren(); j++) + { + TreeType* child = node->Children()[j]; + int policy = SplitPolicy::GetSplitPolicy(child, axis, sorted[i].d); + if (policy == SplitPolicy::AssignToFirstTree) + numTreeOneChildren++; + else if (policy == SplitPolicy::AssignToSecondTree) + 
numTreeTwoChildren++; + else + { + numTreeOneChildren++; + numTreeTwoChildren++; + numSplits++; + } + } + + if (numTreeOneChildren <= node->MaxNumChildren() && numTreeOneChildren > 0 && + numTreeTwoChildren <= node->MaxNumChildren() && numTreeTwoChildren > 0) + { + size_t cost = numSplits * (std::abs(sorted.size() / 2 - i)); + if (cost < minCost) + { + minCost = cost; + axisCut = sorted[i].d; + } + } + } + return minCost; +} + +template +template +size_t MinimalSplitsNumberSweep:: +SweepLeafNode(size_t axis, const TreeType* node, + typename TreeType::ElemType& axisCut) +{ + axisCut = (node->Bound()[axis].Lo() + node->Bound()[axis].Hi()) * 0.5; + + if (node->Bound()[axis].Lo() == axisCut) + return SIZE_MAX; + + return 0; +} + + +} // namespace tree +} // namespace mlpack + +#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_IMPL_HPP + + diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp index 265b739a0d5..3d495f993d6 100644 --- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp @@ -64,7 +64,7 @@ ChooseDescentNode(TreeType* node, const size_t point) TreeType* tree = node; while (depth > 1) { - TreeType* child = new TreeType(node); + TreeType* child = new TreeType(tree); tree->Children()[tree->NumChildren()++] = child; tree = child; diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp index 8c2d0566728..5d5e2737cf9 100644 --- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp @@ -10,15 +10,15 @@ #include -const double fillFactorFraction = 0.5; - namespace mlpack { namespace tree /** Trees and tree-building procedures. 
*/ { -template +template class SweepType> class RPlusTreeSplit { public: + typedef SplitPolicyType SplitPolicy; /** * Split a leaf node using the "default" algorithm. If necessary, this split * will propagate upwards through the tree. @@ -48,45 +48,24 @@ class RPlusTreeSplit int n; }; - template - static bool StructComp(const SortStruct& s1, - const SortStruct& s2) - { - return s1.d < s2.d; - } - template - static void SplitLeafNodeAlongPartition(TreeType* tree, - TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut); + static void SplitLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne, + TreeType* treeTwo, size_t cutAxis, typename TreeType::ElemType cut); template - static void SplitNonLeafNodeAlongPartition(TreeType* tree, - TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut); + static void SplitNonLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne, + TreeType* treeTwo, size_t cutAxis, typename TreeType::ElemType cut); template static void AddFakeNodes(const TreeType* tree, TreeType* emptyTree); template - static bool PartitionNode(const TreeType* node, size_t fillFactor, - size_t& minCutAxis, double& minCut); - - template - static double SweepLeafNode(size_t axis, const TreeType* node, - size_t fillFactor, double& axisCut); - - template - static double SweepNonLeafNode(size_t axis, const TreeType* node, - size_t fillFactor, double& axisCut); + static bool PartitionNode(const TreeType* node, size_t& minCutAxis, + typename TreeType::ElemType& minCut); template static void InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode); - template - static bool CheckNonLeafSweep(const TreeType* node, - size_t cutAxis, double cut); - - template - static bool CheckLeafSweep(const TreeType* node, size_t cutAxis, double cut); }; } // namespace tree diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp index e48487288b6..901059578bb 100644 --- 
a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp @@ -16,9 +16,11 @@ namespace mlpack { namespace tree { -template +template class SweepType> template -void RPlusTreeSplit::SplitLeafNode(TreeType* tree, std::vector& relevels) +void RPlusTreeSplit:: +SplitLeafNode(TreeType* tree, std::vector& relevels) { if (tree->Count() == 1) { @@ -55,11 +57,10 @@ void RPlusTreeSplit::SplitLeafNode(TreeType* tree, std::vector< return; } - const size_t fillFactor = tree->MaxLeafSize() * fillFactorFraction; size_t cutAxis; - double cut; + typename TreeType::ElemType cut; - if ( !PartitionNode(tree, fillFactor, cutAxis, cut)) + if ( !PartitionNode(tree, cutAxis, cut)) return; assert(cutAxis < tree->Bound().Dim()); @@ -92,10 +93,11 @@ void RPlusTreeSplit::SplitLeafNode(TreeType* tree, std::vector< tree->SoftDelete(); } -template +template class SweepType> template -bool RPlusTreeSplit::SplitNonLeafNode(TreeType* tree, - std::vector& relevels) +bool RPlusTreeSplit:: +SplitNonLeafNode(TreeType* tree, std::vector& relevels) { // If we are splitting the root node, we need will do things differently so // that the constructor and other methods don't confuse the end user by giving @@ -113,11 +115,10 @@ bool RPlusTreeSplit::SplitNonLeafNode(TreeType* tree, RPlusTreeSplit::SplitNonLeafNode(copy,relevels); return true; } - const size_t fillFactor = tree->MaxNumChildren() * fillFactorFraction; size_t cutAxis; - double cut; + typename TreeType::ElemType cut; - if ( !PartitionNode(tree, fillFactor, cutAxis, cut)) + if ( !PartitionNode(tree, cutAxis, cut)) return false; assert(cutAxis < tree->Bound().Dim()); @@ -153,10 +154,12 @@ bool RPlusTreeSplit::SplitNonLeafNode(TreeType* tree, return false; } -template +template class SweepType> template -void RPlusTreeSplit::SplitLeafNodeAlongPartition(TreeType* tree, - TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut) +void RPlusTreeSplit:: 
+SplitLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne, + TreeType* treeTwo, size_t cutAxis, typename TreeType::ElemType cut) { tree->AuxiliaryInfo().SplitAuxiliaryInfo(treeOne, treeTwo, cutAxis, cut); @@ -180,10 +183,12 @@ void RPlusTreeSplit::SplitLeafNodeAlongPartition(TreeType* tree assert(treeOne->Bound()[cutAxis].Hi() < treeTwo->Bound()[cutAxis].Lo()); } -template +template class SweepType> template -void RPlusTreeSplit::SplitNonLeafNodeAlongPartition(TreeType* tree, - TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut) +void RPlusTreeSplit:: +SplitNonLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne, + TreeType* treeTwo, size_t cutAxis, typename TreeType::ElemType cut) { tree->AuxiliaryInfo().SplitAuxiliaryInfo(treeOne, treeTwo, cutAxis, cut); @@ -234,9 +239,10 @@ void RPlusTreeSplit::SplitNonLeafNodeAlongPartition(TreeType* t assert(treeTwo->NumChildren() <= treeTwo->MaxNumChildren()); } -template +template class SweepType> template -void RPlusTreeSplit:: +void RPlusTreeSplit:: AddFakeNodes(const TreeType* tree, TreeType* emptyTree) { size_t numDescendantNodes = 1; @@ -258,79 +264,32 @@ AddFakeNodes(const TreeType* tree, TreeType* emptyTree) } } - -template -template -bool RPlusTreeSplit::CheckNonLeafSweep(const TreeType* node, - size_t cutAxis, double cut) -{ - size_t numTreeOneChildren = 0; - size_t numTreeTwoChildren = 0; - - for (size_t i = 0; i < node->NumChildren(); i++) - { - TreeType* child = node->Children()[i]; - int policy = SplitPolicyType::GetSplitPolicy(child, cutAxis, cut); - if (policy == SplitPolicyType::AssignToFirstTree) - numTreeOneChildren++; - else if (policy == SplitPolicyType::AssignToSecondTree) - numTreeTwoChildren++; - else - { - numTreeOneChildren++; - numTreeTwoChildren++; - } - } - - if (numTreeOneChildren <= node->MaxNumChildren() && numTreeOneChildren > 0 && - numTreeTwoChildren <= node->MaxNumChildren() && numTreeTwoChildren > 0) - return true; - return false; -} - -template -template -bool 
RPlusTreeSplit::CheckLeafSweep(const TreeType* node, - size_t cutAxis, double cut) -{ - size_t numTreeOnePoints = 0; - size_t numTreeTwoPoints = 0; - - for (size_t i = 0; i < node->NumPoints(); i++) - { - if (node->Dataset().col(node->Point(i))[cutAxis] <= cut) - numTreeOnePoints++; - else - numTreeTwoPoints++; - } - - if (numTreeOnePoints <= node->MaxLeafSize() && numTreeOnePoints > 0 && - numTreeTwoPoints <= node->MaxLeafSize() && numTreeTwoPoints > 0) - return true; - return false; -} - -template +template class SweepType> template -bool RPlusTreeSplit::PartitionNode(const TreeType* node, size_t fillFactor, - size_t& minCutAxis, double& minCut) +bool RPlusTreeSplit:: +PartitionNode(const TreeType* node, size_t& minCutAxis, + typename TreeType::ElemType& minCut) { if ((node->NumChildren() <= fillFactor && !node->IsLeaf()) || (node->Count() <= fillFactor && node->IsLeaf())) return false; - double minCost = std::numeric_limits::max(); + typedef typename SweepType::template SweepCost::type + SweepCostType; + + SweepCostType minCost = std::numeric_limits::max(); minCutAxis = node->Bound().Dim(); for (size_t k = 0; k < node->Bound().Dim(); k++) { - double cut; - double cost; + typename TreeType::ElemType cut; + SweepCostType cost; if (node->IsLeaf()) - cost = SweepLeafNode(k, node, fillFactor, cut); + cost = SweepType::SweepLeafNode(k, node, cut); else - cost = SweepNonLeafNode(k, node, fillFactor, cut); + cost = SweepType::SweepNonLeafNode(k, node, cut); if (cost < minCost) @@ -343,172 +302,10 @@ bool RPlusTreeSplit::PartitionNode(const TreeType* node, size_t return true; } -template -template -double RPlusTreeSplit::SweepNonLeafNode(size_t axis, const TreeType* node, - size_t fillFactor, double& axisCut) -{ - typedef typename TreeType::ElemType ElemType; - - std::vector> sorted(node->NumChildren()); - - for (size_t i = 0; i < node->NumChildren(); i++) - { - sorted[i].d = SplitPolicyType::Bound(node->Children()[i])[axis].Hi(); - sorted[i].n = i; - } - 
std::sort(sorted.begin(), sorted.end(), StructComp); - - size_t splitPointer = fillFactor; - - axisCut = sorted[splitPointer - 1].d; - - if (!CheckNonLeafSweep(node, axis, axisCut)) - { - for (splitPointer = 1; splitPointer < node->NumChildren(); splitPointer++) - { - axisCut = sorted[splitPointer - 1].d; - if (CheckNonLeafSweep(node, axis, axisCut)) - break; - } - - if (splitPointer == node->NumChildren()) - return std::numeric_limits::max(); - } - - std::vector lowerBound1(node->Bound().Dim()); - std::vector highBound1(node->Bound().Dim()); - std::vector lowerBound2(node->Bound().Dim()); - std::vector highBound2(node->Bound().Dim()); - - for (size_t k = 0; k < node->Bound().Dim(); k++) - { - lowerBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Lo(); - highBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Hi(); - - for (size_t i = 1; i < splitPointer; i++) - { - if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound1[k]) - lowerBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Lo(); - if (node->Children()[sorted[i].n]->Bound()[k].Hi() > highBound1[k]) - highBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Hi(); - } - - lowerBound2[k] = node->Children()[sorted[splitPointer].n]->Bound()[k].Lo(); - highBound2[k] = node->Children()[sorted[splitPointer].n]->Bound()[k].Hi(); - - for (size_t i = splitPointer + 1; i < node->NumChildren(); i++) - { - if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound2[k]) - lowerBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Lo(); - if (node->Children()[sorted[i].n]->Bound()[k].Hi() > highBound2[k]) - highBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Hi(); - } - } - - ElemType area1 = 1.0, area2 = 1.0; - ElemType overlappedArea = 1.0; - - for (size_t k = 0; k < node->Bound().Dim(); k++) - { - if (lowerBound1[k] >= highBound1[k]) - { - overlappedArea *= 0; - area1 *= 0; - } - else - area1 *= highBound1[k] - lowerBound1[k]; - - if (lowerBound2[k] >= highBound2[k]) - { - overlappedArea 
*= 0; - area1 *= 0; - } - else - area2 *= highBound2[k] - lowerBound2[k]; - - if (lowerBound1[k] < highBound1[k] && lowerBound2[k] < highBound2[k]) - { - if (lowerBound1[k] > highBound2[k] || lowerBound2[k] > highBound2[k]) - overlappedArea *= 0; - else - overlappedArea *= std::min(highBound1[k], highBound2[k]) - - std::max(lowerBound1[k], lowerBound2[k]); - } - } - - return area1 + area2 - overlappedArea; -} - -template -template -double RPlusTreeSplit::SweepLeafNode(size_t axis, const TreeType* node, - size_t fillFactor, double& axisCut) -{ - typedef typename TreeType::ElemType ElemType; - - std::vector> sorted(node->Count()); - - sorted.resize(node->Count()); - - for (size_t i = 0; i < node->NumPoints(); i++) - { - sorted[i].d = node->Dataset().col(node->Point(i))[axis]; - sorted[i].n = i; - } - - std::sort(sorted.begin(), sorted.end(), StructComp); - - axisCut = sorted[fillFactor - 1].d; - - if (!CheckLeafSweep(node, axis, axisCut)) - return std::numeric_limits::max(); - - std::vector lowerBound1(node->Bound().Dim()); - std::vector highBound1(node->Bound().Dim()); - std::vector lowerBound2(node->Bound().Dim()); - std::vector highBound2(node->Bound().Dim()); - - for (size_t k = 0; k < node->Bound().Dim(); k++) - { - lowerBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k]; - highBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k]; - - for (size_t i = 1; i < fillFactor; i++) - { - if (node->Dataset().col(node->Point(sorted[i].n))[k] < lowerBound1[k]) - lowerBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k]; - if (node->Dataset().col(node->Point(sorted[i].n))[k] > highBound1[k]) - highBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k]; - } - - lowerBound2[k] = node->Dataset().col(node->Point(sorted[fillFactor].n))[k]; - highBound2[k] = node->Dataset().col(node->Point(sorted[fillFactor].n))[k]; - - for (size_t i = fillFactor + 1; i < node->NumChildren(); i++) - { - if (node->Dataset().col(node->Point(sorted[i].n))[k] < 
lowerBound2[k]) - lowerBound2[k] = node->Dataset().col(node->Point(sorted[i].n))[k]; - if (node->Dataset().col(node->Point(sorted[i].n))[k] > highBound2[k]) - highBound2[k] = node->Dataset().col(node->Point(sorted[i].n))[k]; - } - } - - ElemType area1 = 1.0, area2 = 1.0; - ElemType overlappedArea = 1.0; - - for (size_t k = 0; k < node->Bound().Dim(); k++) - { - area1 *= highBound1[k] - lowerBound1[k]; - area2 *= highBound2[k] - lowerBound2[k]; - } - - return area1 + area2 - overlappedArea; -} - -template +template class SweepType> template -void RPlusTreeSplit:: +void RPlusTreeSplit:: InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode) { destTree->Bound() |= srcNode->Bound(); diff --git a/src/mlpack/core/tree/rectangle_tree/typedef.hpp b/src/mlpack/core/tree/rectangle_tree/typedef.hpp index b22c3b105e9..6371d4d678b 100644 --- a/src/mlpack/core/tree/rectangle_tree/typedef.hpp +++ b/src/mlpack/core/tree/rectangle_tree/typedef.hpp @@ -129,7 +129,8 @@ template using RPlusTree = RectangleTree, + RPlusTreeSplit, RPlusTreeDescentHeuristic, NoAuxiliaryInformation>; @@ -137,7 +138,8 @@ template using RPlusPlusTree = RectangleTree, + RPlusTreeSplit, RPlusPlusTreeDescentHeuristic, RPlusPlusTreeAuxiliaryInformation>; } // namespace tree diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp index 3568a78ac21..0aedb5c081c 100644 --- a/src/mlpack/tests/rectangle_tree_test.cpp +++ b/src/mlpack/tests/rectangle_tree_test.cpp @@ -137,7 +137,17 @@ void CheckContainment(const TreeType& tree) for (size_t i = 0; i < tree.NumChildren(); i++) { for (size_t j = 0; j < tree.Bound().Dim(); j++) - BOOST_REQUIRE(tree.Bound()[j].Contains(tree.Children()[i]->Bound()[j])); + { + // All children should be covered by the parent node. 
+ // Some children can be empty (only in case of the R++ tree) + bool success = (tree.Children()[i]->Bound()[j].Hi() == + std::numeric_limits::lowest() && + tree.Children()[i]->Bound()[j].Lo() == + std::numeric_limits::max()) || + tree.Bound()[j].Contains(tree.Children()[i]->Bound()[j]); + + BOOST_REQUIRE(success); + } CheckContainment(*(tree.Children()[i])); } @@ -921,6 +931,17 @@ BOOST_AUTO_TEST_CASE(RPlusPlusTreeBoundTest) TreeType rPlusPlusTree(dataset, 20, 6, 5, 2, 0); CheckRPlusPlusTreeBound(&rPlusPlusTree); + + typedef RectangleTree, arma::mat, + RPlusTreeSplit, + RPlusPlusTreeDescentHeuristic, RPlusPlusTreeAuxiliaryInformation> + RPlusPlusTreeMinimalSplits; + + RPlusPlusTreeMinimalSplits rPlusPlusTree2(dataset, 20, 6, 5, 2, 0); + + CheckRPlusPlusTreeBound(&rPlusPlusTree2); + } BOOST_AUTO_TEST_CASE(RPlusPlusTreeTraverserTest) From 0f16e54611c229e39757aeb675b42c5a17bdde80 Mon Sep 17 00:00:00 2001 From: Mikhail Lozhnikov Date: Wed, 29 Jun 2016 00:45:03 +0300 Subject: [PATCH 4/8] Added some comments and style fixes for the R+/R++ tree. 
--- .../rectangle_tree/minimal_coverage_sweep.hpp | 72 +++++++++++-- .../minimal_coverage_sweep_impl.hpp | 38 +++++-- .../minimal_splits_number_sweep.hpp | 41 ++++++- .../minimal_splits_number_sweep_impl.hpp | 19 +++- .../no_auxiliary_information.hpp | 18 +++- ...r_plus_plus_tree_auxiliary_information.hpp | 100 ++++++++++++++---- ...s_plus_tree_auxiliary_information_impl.hpp | 47 ++++---- ..._plus_plus_tree_descent_heuristic_impl.hpp | 22 ++-- .../r_plus_plus_tree_split_policy.hpp | 37 ++++++- .../r_plus_tree_descent_heuristic.hpp | 8 +- .../r_plus_tree_descent_heuristic_impl.hpp | 26 +++-- .../tree/rectangle_tree/r_plus_tree_split.hpp | 82 +++++++++++--- .../rectangle_tree/r_plus_tree_split_impl.hpp | 64 +++++++---- .../r_plus_tree_split_policy.hpp | 35 +++++- .../rectangle_tree/rectangle_tree_impl.hpp | 9 +- src/mlpack/tests/rectangle_tree_test.cpp | 88 +++++++++------ 16 files changed, 536 insertions(+), 170 deletions(-) diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep.hpp index 4e9ce5a60e2..488d5d8f011 100644 --- a/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep.hpp +++ b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep.hpp @@ -2,6 +2,8 @@ * @file minimal_coverage_sweep.hpp * @author Mikhail Lozhnikov * + * Definition of the MinimalCoverageSweep class, a class that finds a partition + * of a node along an axis. */ #ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_HPP #define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_HPP @@ -11,10 +13,21 @@ namespace tree { constexpr double fillFactor = 0.5; +/** + * The MinimalCoverageSweep class finds a partition along which we + * can split a node according to the coverage of two resulting nodes. + * Moreover, the class evaluates the cost of each split. + * + * @tparam SplitPolicy The class that provides rules for inserting children of + * a node that is being split into two new subtrees. 
+ */ template class MinimalCoverageSweep { private: + /** + * Class to allow for faster sorting. + */ template struct SortStruct { @@ -22,6 +35,9 @@ class MinimalCoverageSweep int n; }; + /** + * Comparator for sorting with SortStruct. + */ template static bool StructComp(const SortStruct& s1, const SortStruct& s2) @@ -30,28 +46,66 @@ class MinimalCoverageSweep } public: - + //! A struct that provides the type of the sweep cost. template struct SweepCost { typedef typename TreeType::ElemType type; }; + /** + * Find a suitable partition of a non-leaf node along the provided axis. + * The method returns the cost of the split. + * + * @param axis The axis along which we are finding a partition. + * @param node The node that is being split. + * @param axisCut The coordinate at which the node may be split. + */ template - static typename TreeType::ElemType SweepNonLeafNode(size_t axis, - const TreeType* node, typename TreeType::ElemType& axisCut); + static typename TreeType::ElemType SweepNonLeafNode( + const size_t axis, + const TreeType* node, + typename TreeType::ElemType& axisCut); + /** + * Find a suitable partition of a leaf node along the provided axis. + * The method returns the cost of the split. + * + * @param axis The axis along which we are finding a partition. + * @param node The node that is being split. + * @param axisCut The coordinate at which the node may be split. + */ template - static typename TreeType::ElemType SweepLeafNode(size_t axis, - const TreeType* node, typename TreeType::ElemType& axisCut); + static typename TreeType::ElemType SweepLeafNode( + const size_t axis, + const TreeType* node, + typename TreeType::ElemType& axisCut); + /** + * Check if an intermediate node can be split along the axis at the provided + * coordinate. + * + * @param node The node that is being split. + * @param cutAxis The axis that we want to check. + * @param cut The coordinate that we want to check. 
+ */ template - static bool CheckNonLeafSweep(const TreeType* node, size_t cutAxis, - ElemType cut); + static bool CheckNonLeafSweep(const TreeType* node, + const size_t cutAxis, + const ElemType cut); + /** + * Check if a leaf node can be split along the axis at the provided + * coordinate. + * + * @param node The node that is being split. + * @param cutAxis The axis that we want to check. + * @param cut The coordinate that we want to check. + */ template - static bool CheckLeafSweep(const TreeType* node, size_t cutAxis, - ElemType cut); + static bool CheckLeafSweep(const TreeType* node, + const size_t cutAxis, + const ElemType cut); }; } // namespace tree diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp index b300cfa3044..50b04b9a0ef 100644 --- a/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp +++ b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp @@ -2,6 +2,8 @@ * @file minimal_coverage_sweep_impl.hpp * @author Mikhail Lozhnikov * + * Implementation of the MinimalCoverageSweep class, a class that finds a + * partition of a node along an axis. */ #ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_IMPL_HPP #define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_IMPL_HPP @@ -14,8 +16,9 @@ namespace tree { template template typename TreeType::ElemType MinimalCoverageSweep:: -SweepNonLeafNode(size_t axis, const TreeType* node, - typename TreeType::ElemType& axisCut) +SweepNonLeafNode(const size_t axis, + const TreeType* node, + typename TreeType::ElemType& axisCut) { typedef typename TreeType::ElemType ElemType; @@ -26,12 +29,14 @@ SweepNonLeafNode(size_t axis, const TreeType* node, sorted[i].d = SplitPolicy::Bound(node->Children()[i])[axis].Hi(); sorted[i].n = i; } + // Sort high bounds of children. 
std::sort(sorted.begin(), sorted.end(), StructComp); size_t splitPointer = fillFactor * node->NumChildren(); axisCut = sorted[splitPointer - 1].d; + // Check if the partition is suitable. if (!CheckNonLeafSweep(node, axis, axisCut)) { for (splitPointer = 1; splitPointer < sorted.size(); splitPointer++) @@ -50,6 +55,7 @@ SweepNonLeafNode(size_t axis, const TreeType* node, std::vector lowerBound2(node->Bound().Dim()); std::vector highBound2(node->Bound().Dim()); + // Find lower and high bounds of two resulting nodes. for (size_t k = 0; k < node->Bound().Dim(); k++) { lowerBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Lo(); @@ -75,6 +81,9 @@ SweepNonLeafNode(size_t axis, const TreeType* node, } } + // Evaluate the cost of the split i.e. calculate the total coverage + // of two resulting nodes. + ElemType area1 = 1.0, area2 = 1.0; ElemType overlappedArea = 1.0; @@ -112,8 +121,9 @@ SweepNonLeafNode(size_t axis, const TreeType* node, template template typename TreeType::ElemType MinimalCoverageSweep:: -SweepLeafNode(size_t axis, const TreeType* node, - typename TreeType::ElemType& axisCut) +SweepLeafNode(const size_t axis, + const TreeType* node, + typename TreeType::ElemType& axisCut) { typedef typename TreeType::ElemType ElemType; @@ -127,12 +137,14 @@ SweepLeafNode(size_t axis, const TreeType* node, sorted[i].n = i; } + // Sort high bounds of children. std::sort(sorted.begin(), sorted.end(), StructComp); size_t splitPointer = fillFactor * node->Count(); axisCut = sorted[splitPointer - 1].d; + // Check if the partition is suitable. if (!CheckLeafSweep(node, axis, axisCut)) return std::numeric_limits::max(); @@ -141,6 +153,7 @@ SweepLeafNode(size_t axis, const TreeType* node, std::vector lowerBound2(node->Bound().Dim()); std::vector highBound2(node->Bound().Dim()); + // Find lower and high bounds of two resulting nodes. 
for (size_t k = 0; k < node->Bound().Dim(); k++) { lowerBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k]; @@ -154,7 +167,8 @@ SweepLeafNode(size_t axis, const TreeType* node, highBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k]; } - lowerBound2[k] = node->Dataset().col(node->Point(sorted[splitPointer].n))[k]; + lowerBound2[k] = node->Dataset().col( + node->Point(sorted[splitPointer].n))[k]; highBound2[k] = node->Dataset().col(node->Point(sorted[splitPointer].n))[k]; for (size_t i = splitPointer + 1; i < node->NumChildren(); i++) @@ -166,6 +180,9 @@ SweepLeafNode(size_t axis, const TreeType* node, } } + // Evaluate the cost of the split i.e. calculate the total coverage + // of two resulting nodes. + ElemType area1 = 1.0, area2 = 1.0; ElemType overlappedArea = 1.0; @@ -181,11 +198,14 @@ SweepLeafNode(size_t axis, const TreeType* node, template template bool MinimalCoverageSweep:: -CheckNonLeafSweep(const TreeType* node, size_t cutAxis, ElemType cut) +CheckNonLeafSweep(const TreeType* node, + const size_t cutAxis, + const ElemType cut) { size_t numTreeOneChildren = 0; size_t numTreeTwoChildren = 0; + // Calculate the number of children in the resulting nodes. for (size_t i = 0; i < node->NumChildren(); i++) { TreeType* child = node->Children()[i]; @@ -196,6 +216,7 @@ CheckNonLeafSweep(const TreeType* node, size_t cutAxis, ElemType cut) numTreeTwoChildren++; else { + // The split is required. numTreeOneChildren++; numTreeTwoChildren++; } @@ -210,11 +231,14 @@ CheckNonLeafSweep(const TreeType* node, size_t cutAxis, ElemType cut) template template bool MinimalCoverageSweep:: -CheckLeafSweep(const TreeType* node, size_t cutAxis, ElemType cut) +CheckLeafSweep(const TreeType* node, + const size_t cutAxis, + const ElemType cut) { size_t numTreeOnePoints = 0; size_t numTreeTwoPoints = 0; + // Calculate the number of points in the resulting nodes. 
for (size_t i = 0; i < node->NumPoints(); i++) { if (node->Dataset().col(node->Point(i))[cutAxis] <= cut) diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep.hpp index 60cbe2dd7f2..0470f522445 100644 --- a/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep.hpp +++ b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep.hpp @@ -2,6 +2,8 @@ * @file minimal_splits_number_sweep.hpp * @author Mikhail Lozhnikov * + * Definition of the MinimalSplitsNumberSweep class, a class that finds a + * partition of a node along an axis. */ #ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_HPP #define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_HPP @@ -9,10 +11,21 @@ namespace mlpack { namespace tree { +/** + * The MinimalSplitsNumberSweep class finds a partition along which we + * can split a node according to the number of required splits of the node. + * Moreover, the class evaluates the cost of each split. + * + * @tparam SplitPolicy The class that provides rules for inserting children of + * a node that is being split into two new subtrees. + */ template class MinimalSplitsNumberSweep { private: + /** + * Class to allow for faster sorting. + */ template struct SortStruct { @@ -20,6 +33,9 @@ class MinimalSplitsNumberSweep int n; }; + /** + * Comparator for sorting with SortStruct. + */ template static bool StructComp(const SortStruct& s1, const SortStruct& s2) @@ -27,18 +43,39 @@ class MinimalSplitsNumberSweep return s1.d < s2.d; } public: + //! A struct that provides the type of the sweep cost. template struct SweepCost { typedef size_t type; }; + /** + * Find a suitable partition of a non-leaf node along the provided axis. + * The method returns the cost of the split. + * + * @param axis The axis along which we are finding a partition. + * @param node The node that is being split. 
+ * @param axisCut The coordinate at which the node may be split. + */ template - static size_t SweepNonLeafNode(size_t axis, const TreeType* node, + static size_t SweepNonLeafNode( + const size_t axis, + const TreeType* node, typename TreeType::ElemType& axisCut); + /** + * Find a suitable partition of a leaf node along the provided axis. + * The method returns the cost of the split. + * + * @param axis The axis along which we are finding a partition. + * @param node The node that is being split. + * @param axisCut The coordinate at which the node may be split. + */ template - static size_t SweepLeafNode(size_t axis, const TreeType* node, + static size_t SweepLeafNode( + const size_t axis, + const TreeType* node, typename TreeType::ElemType& axisCut); }; diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp index 21c0dcb79b3..6328abba680 100644 --- a/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp +++ b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp @@ -2,6 +2,8 @@ * @file minimal_splits_number_sweep_impl.hpp * @author Mikhail Lozhnikov * + * Implementation of the MinimalSplitsNumberSweep class, a class that finds a + * partition of a node along an axis. 
*/ #ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_IMPL_HPP #define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_IMPL_HPP @@ -13,8 +15,9 @@ namespace tree { template template -size_t MinimalSplitsNumberSweep:: -SweepNonLeafNode(size_t axis, const TreeType* node, +size_t MinimalSplitsNumberSweep::SweepNonLeafNode( + const size_t axis, + const TreeType* node, typename TreeType::ElemType& axisCut) { typedef typename TreeType::ElemType ElemType; @@ -26,16 +29,20 @@ SweepNonLeafNode(size_t axis, const TreeType* node, sorted[i].d = SplitPolicy::Bound(node->Children()[i])[axis].Hi(); sorted[i].n = i; } + + // Sort candidates in order to check balancing. std::sort(sorted.begin(), sorted.end(), StructComp); size_t minCost = SIZE_MAX; + // Find a split with the minimal cost. for (size_t i = 0; i < sorted.size(); i++) { size_t numTreeOneChildren = 0; size_t numTreeTwoChildren = 0; size_t numSplits = 0; + // Calculate the number of splits. for (size_t j = 0; j < node->NumChildren(); j++) { TreeType* child = node->Children()[j]; @@ -52,9 +59,11 @@ SweepNonLeafNode(size_t axis, const TreeType* node, } } + // Check if the split is possible. if (numTreeOneChildren <= node->MaxNumChildren() && numTreeOneChildren > 0 && numTreeTwoChildren <= node->MaxNumChildren() && numTreeTwoChildren > 0) { + // Evaluate the cost using the number of splits and balancing. size_t cost = numSplits * (std::abs(sorted.size() / 2 - i)); if (cost < minCost) { @@ -68,10 +77,12 @@ SweepNonLeafNode(size_t axis, const TreeType* node, template template -size_t MinimalSplitsNumberSweep:: -SweepLeafNode(size_t axis, const TreeType* node, +size_t MinimalSplitsNumberSweep::SweepLeafNode( + const size_t axis, + const TreeType* node, typename TreeType::ElemType& axisCut) { + // Split at the midpoint of the node's bound along the axis.
axisCut = (node->Bound()[axis].Lo() + node->Bound()[axis].Hi()) * 0.5; if (node->Bound()[axis].Lo() == axisCut) diff --git a/src/mlpack/core/tree/rectangle_tree/no_auxiliary_information.hpp b/src/mlpack/core/tree/rectangle_tree/no_auxiliary_information.hpp index e3b5fd455b0..c07a4145ca9 100644 --- a/src/mlpack/core/tree/rectangle_tree/no_auxiliary_information.hpp +++ b/src/mlpack/core/tree/rectangle_tree/no_auxiliary_information.hpp @@ -32,7 +32,7 @@ class NoAuxiliaryInformation * @param node The node in which the point is being inserted. * @param point The global number of the point being inserted. */ - bool HandlePointInsertion(TreeType* , const size_t) + bool HandlePointInsertion(TreeType* /* node */, const size_t /* point */) { return false; } @@ -98,6 +98,22 @@ class NoAuxiliaryInformation return false; } + /** + * The R++ tree requires to split the maximum bounding rectangle of a node + * that is being split. This method is intended for that. + * + * @param treeOne The first subtree. + * @param treeTwo The second subtree. + * @param axis The axis along which the split is performed. + * @param cut The coordinate at which the node is split. + */ + void SplitAuxiliaryInfo(TreeType* /* treeOne */, + TreeType* /* treeTwo */, + size_t /* axis */, + typename TreeType::ElemType /* cut */) + { } + + /** * Nullify the auxiliary information in order to prevent an invalid free. */ diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp index 926c300e050..d791fa89e95 100644 --- a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp @@ -19,58 +19,120 @@ template class RPlusPlusTreeAuxiliaryInformation { public: + //! The element type held by the tree. typedef typename TreeType::ElemType ElemType; + //! Construct the auxiliary information object. 
RPlusPlusTreeAuxiliaryInformation(); - RPlusPlusTreeAuxiliaryInformation(const TreeType* ); - RPlusPlusTreeAuxiliaryInformation(const RPlusPlusTreeAuxiliaryInformation& ); + + /** + * Construct this as an auxiliary information for the given node. + * + * @param node The node that stores this auxiliary information. + */ + RPlusPlusTreeAuxiliaryInformation(const TreeType* /* node */); + + /** + * Create an auxiliary information object by copying from another node. + * + * @param other The auxiliary information object from which the information + * will be copied. + */ + RPlusPlusTreeAuxiliaryInformation( + const RPlusPlusTreeAuxiliaryInformation& other); /** * Some tree types require to save some properties at the insertion process. - * This method should return false if it does not handle the process. + * This method allows the auxiliary information the option of manipulating + * the tree in order to perform the insertion process. If the auxiliary + * information does that, then the method should return true; if the method + * returns false the RectangleTree performs its default behavior. + * + * @param node The node in which the point is being inserted. + * @param point The global number of the point being inserted. */ - bool HandlePointInsertion(TreeType* , const size_t); + bool HandlePointInsertion(TreeType* /* node */, const size_t /* point */); /** * Some tree types require to save some properties at the insertion process. - * This method should return false if it does not handle the process. + * This method allows the auxiliary information the option of manipulating + * the tree in order to perform the insertion process. If the auxiliary + * information does that, then the method should return true; if the method + * returns false the RectangleTree performs its default behavior. + * + * @param node The node in which the nodeToInsert is being inserted. + * @param nodeToInsert The node being inserted. 
+ * @param insertionLevel The level of the tree at which the nodeToInsert + * should be inserted. */ - bool HandleNodeInsertion(TreeType* , TreeType* ,bool); + bool HandleNodeInsertion(TreeType* /* node */, + TreeType* /* nodeToInsert */, + bool /* insertionLevel */); /** * Some tree types require to save some properties at the deletion process. - * This method should return false if it does not handle the process. + * This method allows the auxiliary information the option of manipulating + * the tree in order to perform the deletion process. If the auxiliary + * information does that, then the method should return true; if the method + * returns false the RectangleTree performs its default behavior. + * + * @param node The node from which the point is being deleted. + * @param localIndex The local index of the point being deleted. */ - bool HandlePointDeletion(TreeType* , const size_t); + bool HandlePointDeletion(TreeType* /* node */, const size_t /* localIndex */); /** * Some tree types require to save some properties at the deletion process. - * This method should return false if it does not handle the process. + * This method allows the auxiliary information the option of manipulating + * the tree in order to perform the deletion process. If the auxiliary + * information does that, then the method should return true; if the method + * returns false the RectangleTree performs its default behavior. + * + * @param node The node from which the node is being deleted. + * @param nodeIndex The local index of the node being deleted. */ - bool HandleNodeRemoval(TreeType* , const size_t); + bool HandleNodeRemoval(TreeType* /* node */, const size_t /* nodeIndex */); /** - * Some tree types require to propagate the information downward. - * This method should return false if this is not the case. + * Some tree types require to propagate the information upward. + * This method should return false if this is not the case. 
If true is + * returned, the update will be propagated upward. + * + * @param node The node in which the auxiliary information is being updated. */ - bool UpdateAuxiliaryInfo(TreeType* ); + bool UpdateAuxiliaryInfo(TreeType* /* node */); - void SplitAuxiliaryInfo(TreeType* treeOne, TreeType* treeTwo, - size_t axis, ElemType cut); + /** + * The R++ tree requires to split the maximum bounding rectangle of a node + * that is being split. This method is intended for that. + * + * @param treeOne The first subtree. + * @param treeTwo The second subtree. + * @param axis The axis along which the split is performed. + * @param cut The coordinate at which the node is split. + */ + void SplitAuxiliaryInfo(TreeType* treeOne, + TreeType* treeTwo, + const size_t axis, + const ElemType cut); - static void Copy(TreeType* ,const TreeType* ); + /** + * Nullify the auxiliary information in order to prevent an invalid free. + */ void NullifyData(); - + //! Modify the maximum bounding rectangle. bound::HRectBound& OuterBound() { return outerBound; } - const bound::HRectBound& OuterBound() const + //! Return the maximum bounding rectangle. + const bound::HRectBound& + OuterBound() const { return outerBound; } private: - + //! The maximum bounding rectangle.
bound::HRectBound outerBound; public: /** diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information_impl.hpp index b6b2aea0e20..683645daf4b 100644 --- a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information_impl.hpp +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information_impl.hpp @@ -29,6 +29,7 @@ RPlusPlusTreeAuxiliaryInformation(const TreeType* tree) : tree->Parent()->AuxiliaryInfo().OuterBound() : tree->Bound().Dim()) { + // Initialize the maximum bounding rectangle if the node is the root if (!tree->Parent()) for (size_t k = 0; k < outerBound.Dim(); k++) { @@ -39,75 +40,73 @@ RPlusPlusTreeAuxiliaryInformation(const TreeType* tree) : template RPlusPlusTreeAuxiliaryInformation:: -RPlusPlusTreeAuxiliaryInformation(const RPlusPlusTreeAuxiliaryInformation& other) : +RPlusPlusTreeAuxiliaryInformation( + const RPlusPlusTreeAuxiliaryInformation& other) : outerBound(other.OuterBound()) { } template -bool RPlusPlusTreeAuxiliaryInformation:: -HandlePointInsertion(TreeType* , const size_t ) +bool RPlusPlusTreeAuxiliaryInformation::HandlePointInsertion( + TreeType* /* node */, const size_t /* point */) { return false; } template -bool RPlusPlusTreeAuxiliaryInformation:: -HandleNodeInsertion(TreeType* , TreeType* ,bool) +bool RPlusPlusTreeAuxiliaryInformation::HandleNodeInsertion( + TreeType* /* node */, + TreeType* /* nodeToInsert */, + bool /* insertionLevel */) { assert(false); return false; } template -bool RPlusPlusTreeAuxiliaryInformation:: -HandlePointDeletion(TreeType* , const size_t) +bool RPlusPlusTreeAuxiliaryInformation::HandlePointDeletion( + TreeType* /* node */, const size_t /* localIndex */) { return false; } template -bool RPlusPlusTreeAuxiliaryInformation:: -HandleNodeRemoval(TreeType* , const size_t) +bool RPlusPlusTreeAuxiliaryInformation::HandleNodeRemoval( + TreeType* /* node */, const size_t 
/* nodeIndex */) { return false; } template -bool RPlusPlusTreeAuxiliaryInformation:: -UpdateAuxiliaryInfo(TreeType* ) +bool RPlusPlusTreeAuxiliaryInformation::UpdateAuxiliaryInfo( + TreeType* /* node */) { return false; } template -void RPlusPlusTreeAuxiliaryInformation:: -SplitAuxiliaryInfo(TreeType* treeOne, TreeType* treeTwo, size_t axis, - typename TreeType::ElemType cut) +void RPlusPlusTreeAuxiliaryInformation::SplitAuxiliaryInfo( + TreeType* treeOne, + TreeType* treeTwo, + const size_t axis, + const typename TreeType::ElemType cut) { typedef bound::HRectBound Bound; Bound& treeOneBound = treeOne->AuxiliaryInfo().OuterBound(); Bound& treeTwoBound = treeTwo->AuxiliaryInfo().OuterBound(); + // Copy the maximum bounding rectangle treeOneBound = outerBound; treeTwoBound = outerBound; + // Set proper limits treeOneBound[axis].Hi() = cut; treeTwoBound[axis].Lo() = cut; } - -template -void RPlusPlusTreeAuxiliaryInformation:: -Copy(TreeType* dst, const TreeType* src) -{ - dst.OuterBound() = src.OuterBound(); -} - template -void RPlusPlusTreeAuxiliaryInformation:: -NullifyData() +void RPlusPlusTreeAuxiliaryInformation::NullifyData() { } diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic_impl.hpp index 566f686efcd..c2a5456b5ae 100644 --- a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic_impl.hpp +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic_impl.hpp @@ -2,8 +2,8 @@ * @file r_plus_plus_tree_descent_heuristic_impl.hpp * @author Mikhail Lozhnikov * - * Implementation of RPlusPlusTreeDescentHeuristic, a class that chooses the best child - * of a node in an R++ tree when inserting a new point. + * Implementation of RPlusPlusTreeDescentHeuristic, a class that chooses the + * best child of a node in an R++ tree when inserting a new point. 
*/ #ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP #define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP @@ -15,29 +15,31 @@ namespace mlpack { namespace tree { template -size_t RPlusPlusTreeDescentHeuristic:: -ChooseDescentNode(TreeType* node, const size_t point) +size_t RPlusPlusTreeDescentHeuristic::ChooseDescentNode( + TreeType* node, const size_t point) { + // Find the node whose maximum bounding rectangle contains the point. for (size_t bestIndex = 0; bestIndex < node->NumChildren(); bestIndex++) { - if (node->Children()[bestIndex]->AuxiliaryInfo().OuterBound().Contains(node->Dataset().col(point))) + if (node->Children()[bestIndex]->AuxiliaryInfo().OuterBound().Contains( + node->Dataset().col(point))) return bestIndex; } + // We should never reach this point. assert(false); return 0; } template -size_t RPlusPlusTreeDescentHeuristic:: -ChooseDescentNode(const TreeType* , const TreeType* ) +size_t RPlusPlusTreeDescentHeuristic::ChooseDescentNode( + const TreeType* /* node */, const TreeType* /* insertedNode */) { - size_t bestIndex = 0; - + // Should never be used. assert(false); - return bestIndex; + return 0; } diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_split_policy.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_split_policy.hpp index e0484afe483..f4c6e3ef961 100644 --- a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_split_policy.hpp +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_split_policy.hpp @@ -2,8 +2,9 @@ * @file r_plus_plus_tree_split_policy.hpp * @author Mikhail Lozhnikov * - * Defintion and implementation of the RPlusPlusTreeSplitPolicy class, a class that - * helps to determine the node into which we should insert an intermediate node. + * Defintion and implementation of the RPlusPlusTreeSplitPolicy class, a class + * that helps to determine the subtree into which we should insert an + * intermediate node. 
*/ #ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_SPLIT_POLICY_HPP #define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_SPLIT_POLICY_HPP @@ -11,16 +12,37 @@ namespace mlpack { namespace tree { +/** + * The RPlusPlusTreeSplitPolicy helps to determine the subtree into which + * we should insert a child of an intermediate node that is being split. + * This class is designed for the R++ tree. + */ class RPlusPlusTreeSplitPolicy { public: + //! Indicate that the child should be split. static const int SplitRequired = 0; + //! Indicate that the child should be inserted to the first subtree. static const int AssignToFirstTree = 1; + //! Indicate that the child should be inserted to the second subtree. static const int AssignToSecondTree = 2; + /** + * This method returns SplitRequired if a child of an intermediate node should + * be split, AssignToFirstTree if the child should be inserted to the first + * subtree, AssignToSecondTree if the child should be inserted to the second + * subtree. The method makes the decision according to the maximum bounding + * rectangle of the child, the axis along which the intermediate node is being + * split and the coordinate at which the node is being split. + * + * @param child A child of the node that is being split. + * @param axis The axis along which the node is being split. + * @param cut The coordinate at which the node is being split. + */ template - static int GetSplitPolicy(const TreeType* child, size_t axis, - typename TreeType::ElemType cut) + static int GetSplitPolicy(const TreeType* child, + const size_t axis, + const typename TreeType::ElemType cut) { if (child->AuxiliaryInfo().OuterBound()[axis].Hi() <= cut) return AssignToFirstTree; @@ -30,6 +52,13 @@ class RPlusPlusTreeSplitPolicy return SplitRequired; } + /** + * Return the maximum bounding rectangle of the node. + * This method should always return the bound that is used for the + * decision-making in GetSplitPolicy().
+ * + * @param node The node whose bound is requested. + */ template static const bound::HRectBound& diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic.hpp index dfe8e0acc9e..219c85b7b19 100644 --- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic.hpp +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic.hpp @@ -2,8 +2,8 @@ * @file r_plus_tree_descent_heuristic.hpp * @author Mikhail Lozhnikov * - * Definition of RPlusTreeDescentHeuristic, a class that chooses the best child of a - * node in an R tree when inserting a new point. + * Definition of RPlusTreeDescentHeuristic, a class that chooses the best child + * of a node in an R+ tree when inserting a new point. */ #ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_HPP #define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_HPP @@ -36,8 +36,8 @@ class RPlusTreeDescentHeuristic * @param insertedNode The node that is being inserted. */ template - static size_t ChooseDescentNode(const TreeType* node, - const TreeType* insertedNode); + static size_t ChooseDescentNode(const TreeType* /* node */, + const TreeType* /*insertedNode */); }; diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp index 3d495f993d6..afe3879bc89 100644 --- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp @@ -22,12 +22,16 @@ ChooseDescentNode(TreeType* node, const size_t point) size_t bestIndex = 0; bool success; + // Try to find a node that contains the point. 
for (bestIndex = 0; bestIndex < node->NumChildren(); bestIndex++) { - if (node->Children()[bestIndex]->Bound().Contains(node->Dataset().col(point))) + if (node->Children()[bestIndex]->Bound().Contains( + node->Dataset().col(point))) return bestIndex; } + // No node contains the point. Try to enlarge a node in such a way that + // the resulting node does not overlap other nodes. for (bestIndex = 0; bestIndex < node->NumChildren(); bestIndex++) { bound::HRectBound bound = @@ -41,26 +45,31 @@ ChooseDescentNode(TreeType* node, const size_t point) if (j == bestIndex) continue; success = false; + // Two nodes overlap if and only if there is no dimension in which + // they do not overlap each other. for (size_t k = 0; k < node->Bound().Dim(); k++) { if (bound[k].Lo() >= node->Children()[j]->Bound()[k].Hi() || node->Children()[j]->Bound()[k].Lo() >= bound[k].Hi()) { + // We found the dimension in which these nodes do not overlap + // each other. success = true; break; } } - if (!success) + if (!success) // These two nodes overlap each other. break; } - if (success) + if (success) // We found two nodes that do not overlap each other. break; } - if (!success) + if (!success) // We could not find two nodes that do not overlap each other. { size_t depth = node->TreeDepth(); + // Create a new node into which we will insert the point. TreeType* tree = node; while (depth > 1) { @@ -79,14 +88,13 @@ ChooseDescentNode(TreeType* node, const size_t point) } template -size_t RPlusTreeDescentHeuristic:: -ChooseDescentNode(const TreeType* node, const TreeType* insertedNode) +size_t RPlusTreeDescentHeuristic::ChooseDescentNode( + const TreeType* /* node */, const TreeType* /*insertedNode */) { - size_t bestIndex = 0; - + // Should never be used.
assert(false); - return bestIndex; + return 0; } diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp index 5d5e2737cf9..0d0c672fb4e 100644 --- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp @@ -13,6 +13,13 @@ namespace mlpack { namespace tree /** Trees and tree-building procedures. */ { +/** + * The RPlusTreeSplit class performs the split process of a node on overflow. + * + * @tparam SplitPolicyType The class that helps to determine the subtree into + * which we should insert a child node. + * @tparam SweepType The class that sweeps a node along an axis. + */ template class SweepType> class RPlusTreeSplit @@ -37,32 +44,75 @@ class RPlusTreeSplit template static bool SplitNonLeafNode(TreeType *tree,std::vector& relevels); - - private: - - template - struct SortStruct - { - ElemType d; - int n; - }; - + /** + * Split a leaf node along an axis. + * + * @param tree The node that is being split into two new nodes. + * @param treeOne The first subtree of two resulting subtrees. + * @param treeTwo The second subtree of two resulting subtrees. + * @param cutAxis The axis along which the node is being split. + * @param cut The coordinate at which the node is being split. + */ template - static void SplitLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne, - TreeType* treeTwo, size_t cutAxis, typename TreeType::ElemType cut); + static void SplitLeafNodeAlongPartition( + TreeType* tree, + TreeType* treeOne, + TreeType* treeTwo, + const size_t cutAxis, + const typename TreeType::ElemType cut); + /** + * Split a non-leaf node along an axis. This method propagates the split + * downward up to a leaf node if necessary. + * + * @param tree The node that is being split into two new nodes. + * @param treeOne The first subtree of two resulting subtrees. + * @param treeTwo The second subtree of two resulting subtrees.
+ * @param cutAxis The axis along which the node is being split. + * @param cut The coordinate at which the node is being split. + */ template - static void SplitNonLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne, - TreeType* treeTwo, size_t cutAxis, typename TreeType::ElemType cut); + static void SplitNonLeafNodeAlongPartition( + TreeType* tree, + TreeType* treeOne, + TreeType* treeTwo, + const size_t cutAxis, + const typename TreeType::ElemType cut); + /** + * This method should be invoked in order to make the tree balanced if + * one of two resulting subtrees is empty after the split process + * (i.e. the subtree contains no children). + * The method converts the empty node into an empty subtree (increases the + * depth of the node). + * + * @param tree One of two subtrees that is not empty. + * @param emptyTree The empty subtree. + */ template static void AddFakeNodes(const TreeType* tree, TreeType* emptyTree); + /** + * Partition a node using SweepType. This method invokes + * SweepType::Sweep(Non)LeafNode() for each dimension and chooses the + * best one. The method returns false if the node does not need partitioning. + * Otherwise, the method returns true. If the method fails to find + * an acceptable partition, minCutAxis will be equal to the number of + * dimensions. + * + * @param node The node that is being split. + * @param minCutAxis The axis along which the node will be split. + * @param minCut The coordinate at which the node will be split. + */ template - static bool PartitionNode(const TreeType* node, size_t& minCutAxis, - typename TreeType::ElemType& minCut); + static bool PartitionNode(const TreeType* node, + size_t& minCutAxis, + typename TreeType::ElemType& minCut); + /** + * Insert a node into another node.
+ */ template static void InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode); diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp index 901059578bb..2cca91c576e 100644 --- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp @@ -24,12 +24,16 @@ SplitLeafNode(TreeType* tree, std::vector& relevels) { if (tree->Count() == 1) { + // Check if an intermediate node was added during the insertion process. + // i.e. we couldn't enlarge a node of the R+ tree. So, one of the intermediate + // nodes may be overflowed. TreeType* node = tree->Parent(); while (node != NULL) { if (node->NumChildren() == node->MaxNumChildren() + 1) { + // Split the overflowed node. RPlusTreeSplit::SplitNonLeafNode(node,relevels); return; } @@ -60,6 +64,7 @@ SplitLeafNode(TreeType* tree, std::vector& relevels) size_t cutAxis; typename TreeType::ElemType cut; + // Try to find a partition of the node. if ( !PartitionNode(tree, cutAxis, cut)) return; @@ -72,6 +77,7 @@ SplitLeafNode(TreeType* tree, std::vector& relevels) treeTwo->MinLeafSize() = 0; treeTwo->MinNumChildren() = 0; + // Split the node into two new nodes. SplitLeafNodeAlongPartition(tree, treeOne, treeTwo, cutAxis, cut); TreeType* parent = tree->Parent(); @@ -81,12 +87,16 @@ SplitLeafNode(TreeType* tree, std::vector& relevels) assert(i < parent->NumChildren()); + // Remove the node from the tree. parent->Children()[i] = parent->Children()[--parent->NumChildren()]; + // Insert two new nodes to the tree. InsertNodeIntoTree(parent, treeOne); InsertNodeIntoTree(parent, treeTwo); assert(parent->NumChildren() <= parent->MaxNumChildren() + 1); + + // Propagate the split upward if necessary.
if (parent->NumChildren() == parent->MaxNumChildren() + 1) RPlusTreeSplit::SplitNonLeafNode(parent, relevels); @@ -118,6 +128,7 @@ SplitNonLeafNode(TreeType* tree, std::vector& relevels) size_t cutAxis; typename TreeType::ElemType cut; + // Try to find a partition of the node. if ( !PartitionNode(tree, cutAxis, cut)) return false; @@ -130,6 +141,7 @@ SplitNonLeafNode(TreeType* tree, std::vector& relevels) treeTwo->MinLeafSize() = 0; treeTwo->MinNumChildren() = 0; + // Split the node into two new nodes. SplitNonLeafNodeAlongPartition(tree, treeOne, treeTwo, cutAxis, cut); TreeType* parent = tree->Parent(); @@ -139,8 +151,10 @@ SplitNonLeafNode(TreeType* tree, std::vector& relevels) assert(i < parent->NumChildren()); + // Remove the node from the tree. parent->Children()[i] = parent->Children()[--parent->NumChildren()]; + // Insert two new nodes to the tree. InsertNodeIntoTree(parent, treeOne); InsertNodeIntoTree(parent, treeTwo); @@ -148,6 +162,7 @@ SplitNonLeafNode(TreeType* tree, std::vector& relevels) assert(parent->NumChildren() <= parent->MaxNumChildren() + 1); + // Propagate the split upward if necessary. if (parent->NumChildren() == parent->MaxNumChildren() + 1) RPlusTreeSplit::SplitNonLeafNode(parent, relevels); @@ -157,22 +172,27 @@ SplitNonLeafNode(TreeType* tree, std::vector& relevels) template class SweepType> template -void RPlusTreeSplit:: -SplitLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne, - TreeType* treeTwo, size_t cutAxis, typename TreeType::ElemType cut) +void RPlusTreeSplit::SplitLeafNodeAlongPartition( + TreeType* tree, + TreeType* treeOne, + TreeType* treeTwo, + const size_t cutAxis, + const typename TreeType::ElemType cut) { + // Split the auxiliary information. tree->AuxiliaryInfo().SplitAuxiliaryInfo(treeOne, treeTwo, cutAxis, cut); + // Insert points into the corresponding subtree.
for (size_t i = 0; i < tree->NumPoints(); i++) { if (tree->Dataset().col(tree->Point(i))[cutAxis] <= cut) { - treeOne->Points()[treeOne->Count()++] = tree->Point(i); + treeOne->Point(treeOne->Count()++) = tree->Point(i); treeOne->Bound() |= tree->Dataset().col(tree->Point(i)); } else { - treeTwo->Points()[treeTwo->Count()++] = tree->Point(i); + treeTwo->Point(treeTwo->Count()++) = tree->Point(i); treeTwo->Bound() |= tree->Dataset().col(tree->Point(i)); } } @@ -186,12 +206,17 @@ SplitLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne, template class SweepType> template -void RPlusTreeSplit:: -SplitNonLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne, - TreeType* treeTwo, size_t cutAxis, typename TreeType::ElemType cut) +void RPlusTreeSplit::SplitNonLeafNodeAlongPartition( + TreeType* tree, + TreeType* treeOne, + TreeType* treeTwo, + const size_t cutAxis, + const typename TreeType::ElemType cut) { + // Split the auxiliary information. tree->AuxiliaryInfo().SplitAuxiliaryInfo(treeOne, treeTwo, cutAxis, cut); + // Insert children into the corresponding subtree. for (size_t i = 0; i < tree->NumChildren(); i++) { TreeType* child = tree->Children()[i]; @@ -209,6 +234,7 @@ SplitNonLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne, } else { + // The child should be split (i.e. the partition divides its bound). TreeType* childOne = new TreeType(treeOne); TreeType* childTwo = new TreeType(treeTwo); treeOne->MinLeafSize() = 0; @@ -216,6 +242,7 @@ SplitNonLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne, treeTwo->MinLeafSize() = 0; treeTwo->MinNumChildren() = 0; + // Propagate the split downward. if (child->IsLeaf()) SplitLeafNodeAlongPartition(child, childOne, childTwo, cutAxis, cut); else @@ -230,6 +257,7 @@ SplitNonLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne, assert(treeOne->NumChildren() + treeTwo->NumChildren() != 0); + // Add a fake subtree if one of the subtrees is empty. 
if (treeOne->NumChildren() == 0) AddFakeNodes(treeTwo, treeOne); else if (treeTwo->NumChildren() == 0) @@ -245,20 +273,13 @@ template void RPlusTreeSplit:: AddFakeNodes(const TreeType* tree, TreeType* emptyTree) { - size_t numDescendantNodes = 1; - - TreeType* node = tree->Children()[0]; + size_t numDescendantNodes = tree->TreeDepth() - 1; - while (!node->IsLeaf()) - { - numDescendantNodes++; - node = node->Children()[0]; - } - - node = emptyTree; + TreeType* node = emptyTree; for (size_t i = 0; i < numDescendantNodes; i++) { TreeType* child = new TreeType(node); + node->Children()[node->NumChildren()++] = child; node = child; } @@ -273,14 +294,17 @@ PartitionNode(const TreeType* node, size_t& minCutAxis, { if ((node->NumChildren() <= fillFactor && !node->IsLeaf()) || (node->Count() <= fillFactor && node->IsLeaf())) - return false; + return false; // No partition required. - typedef typename SweepType::template SweepCost::type + // Define the type of the sweep cost. + typedef typename + SweepType::template SweepCost::type SweepCostType; SweepCostType minCost = std::numeric_limits::max(); minCutAxis = node->Bound().Dim(); + // Find the sweep with a minimal cost. for (size_t k = 0; k < node->Bound().Dim(); k++) { typename TreeType::ElemType cut; diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_policy.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_policy.hpp index 219eb1a390b..6d17338d6fc 100644 --- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_policy.hpp +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_policy.hpp @@ -3,7 +3,8 @@ * @author Mikhail Lozhnikov * * Defintion and implementation of the RPlusTreeSplitPolicy class, a class that - * helps to determine the node into which we should insert an intermediate node. + * helps to determine the subtree into which we should insert an intermediate + * node. 
*/ #ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_POLICY_HPP #define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_POLICY_HPP @@ -11,16 +12,37 @@ namespace mlpack { namespace tree { +/** + * The RPlusPlusTreeSplitPolicy helps to determine the subtree into which + * we should insert a child of an intermediate node that is being split. + * This class is designed for the R+ tree. + */ class RPlusTreeSplitPolicy { public: + //! Indicate that the child should be split. static const int SplitRequired = 0; + //! Indicate that the child should be inserted to the first subtree. static const int AssignToFirstTree = 1; + //! Indicate that the child should be inserted to the second subtree. static const int AssignToSecondTree = 2; + /** + * This method returns SplitRequired if a child of an intermediate node should + * be split, AssignToFirstTree if the child should be inserted to the first + * subtree, AssignToSecondTree if the child should be inserted to the second + * subtree. The method makes desicion according to the minimum bounding + * rectangle of the child, the axis along which the intermediate node is being + * split and the coordinate at which the node is being split. + * + * @param child A child of the node that is being split. + * @param axis The axis along which the node is being split. + * @param cut The coordinate at which the node is being split. + */ template - static int GetSplitPolicy(const TreeType* child, size_t axis, - typename TreeType::ElemType cut) + static int GetSplitPolicy(const TreeType* child, + const size_t axis, + const typename TreeType::ElemType cut) { if (child->Bound()[axis].Hi() <= cut) return AssignToFirstTree; @@ -30,6 +52,13 @@ class RPlusTreeSplitPolicy return SplitRequired; } + /** + * Return the minimum bounding rectangle of the node. + * This method should always return the bound that is used for the + * desicion-making in GetSplitPolicy(). + * + * @param node The node whose bound is requested. 
+ */ template static const bound::HRectBound& diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp index 73887d7d8c9..9b67f6a0a6d 100644 --- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp +++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp @@ -675,12 +675,11 @@ void RectangleTree -void CheckOverlap(TreeType* tree) +void CheckOverlap(const TreeType& tree) { bool success = true; - for (size_t i = 0; i < tree->NumChildren(); i++) + // Check if two nodes overlap each other. + for (size_t i = 0; i < tree.NumChildren(); i++) { success = true; - for (size_t j = 0; j < tree->NumChildren(); j++) + for (size_t j = 0; j < tree.NumChildren(); j++) { if (j == i) continue; success = false; - for (size_t k = 0; k < tree->Bound().Dim(); k++) + // Two nodes overlap each other if and only if there are no dimension + // in which they do not overlap each other. + for (size_t k = 0; k < tree.Bound().Dim(); k++) { - if (tree->Children()[i]->Bound()[k].Lo() >= tree->Children()[j]->Bound()[k].Hi() || - tree->Children()[j]->Bound()[k].Lo() >= tree->Children()[i]->Bound()[k].Hi()) + if ((tree.Children()[i]->Bound()[k].Lo() >= + tree.Children()[j]->Bound()[k].Hi()) || + (tree.Children()[j]->Bound()[k].Lo() >= + tree.Children()[i]->Bound()[k].Hi())) { success = true; break; @@ -868,8 +873,8 @@ void CheckOverlap(TreeType* tree) } BOOST_REQUIRE_EQUAL(success, true); - for (size_t i = 0; i < tree->NumChildren(); i++) - CheckOverlap(tree->Children()[i]); + for (size_t i = 0; i < tree.NumChildren(); i++) + CheckOverlap(tree.Child(i)); } BOOST_AUTO_TEST_CASE(RPlusTreeOverlapTest) @@ -881,7 +886,11 @@ BOOST_AUTO_TEST_CASE(RPlusTreeOverlapTest) NeighborSearchStat,arma::mat> TreeType; TreeType rPlusTree(dataset, 20, 6, 5, 2, 0); - CheckOverlap(&rPlusTree); + CheckOverlap(rPlusTree); + + // Ensure that all leaf nodes are at the same level. 
+ BOOST_REQUIRE_EQUAL(GetMinLevel(rPlusTree), GetMaxLevel(rPlusTree)); + BOOST_REQUIRE_EQUAL(rPlusTree.TreeDepth(), GetMinLevel(rPlusTree)); } @@ -903,15 +912,15 @@ BOOST_AUTO_TEST_CASE(RPlusTreeTraverserTest) // Nearest neighbor search with the X tree. - NeighborSearch, arma::mat, RPlusTree > - knn1(&rPlusTree, true); + NeighborSearch, arma::mat, + RPlusTree > knn1(&rPlusTree, true); BOOST_REQUIRE_EQUAL(rPlusTree.NumDescendants(), numP); CheckContainment(rPlusTree); CheckExactContainment(rPlusTree); CheckHierarchy(rPlusTree); - CheckOverlap(&rPlusTree); + CheckOverlap(rPlusTree); knn1.Search(5, neighbors1, distances1); @@ -928,43 +937,49 @@ BOOST_AUTO_TEST_CASE(RPlusTreeTraverserTest) } template -void CheckRPlusPlusTreeBound(const TreeType* tree) +void CheckRPlusPlusTreeBound(const TreeType& tree) { typedef bound::HRectBound Bound; bool success = true; - for (size_t k = 0; k < tree->Bound().Dim(); k++) + // Ensure that the maximum bounding rectangle contains all children. + for (size_t k = 0; k < tree.Bound().Dim(); k++) { - BOOST_REQUIRE_LE(tree->Bound()[k].Hi(), - tree->AuxiliaryInfo().OuterBound()[k].Hi()); - BOOST_REQUIRE_LE(tree->AuxiliaryInfo().OuterBound()[k].Lo(), - tree->Bound()[k].Lo()); + BOOST_REQUIRE_LE(tree.Bound()[k].Hi(), + tree.AuxiliaryInfo().OuterBound()[k].Hi()); + BOOST_REQUIRE_LE(tree.AuxiliaryInfo().OuterBound()[k].Lo(), + tree.Bound()[k].Lo()); } - if (tree->IsLeaf()) + if (tree.IsLeaf()) { - for (size_t i = 0; i < tree->Count(); i++) + // Ensure that the maximum bounding rectangle contains all points. + for (size_t i = 0; i < tree.Count(); i++) BOOST_REQUIRE_EQUAL(true, - tree->Bound().Contains(tree->Dataset().col(tree->Points()[i]))); + tree.Bound().Contains(tree.Dataset().col(tree.Point(i)))); return; } - for (size_t i = 0; i < tree->NumChildren(); i++) + // Ensure that two children's maximum bounding rectangles do not overlap + // each other. 
+ for (size_t i = 0; i < tree.NumChildren(); i++) { - const Bound& bound1 = tree->Children()[i]->AuxiliaryInfo().OuterBound(); + const Bound& bound1 = tree.Children()[i]->AuxiliaryInfo().OuterBound(); success = true; - for (size_t j = 0; j < tree->NumChildren(); j++) + for (size_t j = 0; j < tree.NumChildren(); j++) { if (j == i) continue; - const Bound& bound2 = tree->Children()[j]->AuxiliaryInfo().OuterBound(); + const Bound& bound2 = tree.Children()[j]->AuxiliaryInfo().OuterBound(); + // Two bounds overlap each other if and only if there are no dimension + // in which they do not overlap each other. success = false; - for (size_t k = 0; k < tree->Bound().Dim(); k++) + for (size_t k = 0; k < tree.Bound().Dim(); k++) { if (bound1[k].Lo() >= bound2[k].Hi() || bound2[k].Lo() >= bound1[k].Hi()) @@ -981,8 +996,8 @@ void CheckRPlusPlusTreeBound(const TreeType* tree) } BOOST_REQUIRE_EQUAL(success, true); - for (size_t i = 0; i < tree->NumChildren(); i++) - CheckRPlusPlusTreeBound(tree->Children()[i]); + for (size_t i = 0; i < tree.NumChildren(); i++) + CheckRPlusPlusTreeBound(tree.Child(i)); } BOOST_AUTO_TEST_CASE(RPlusPlusTreeBoundTest) @@ -990,12 +1005,17 @@ BOOST_AUTO_TEST_CASE(RPlusPlusTreeBoundTest) arma::mat dataset; dataset.randu(8, 1000); // 1000 points in 8 dimensions. + // Check the MinimalCoverageSweep. typedef RPlusPlusTree,arma::mat> TreeType; TreeType rPlusPlusTree(dataset, 20, 6, 5, 2, 0); - CheckRPlusPlusTreeBound(&rPlusPlusTree); + CheckRPlusPlusTreeBound(rPlusPlusTree); + + BOOST_REQUIRE_EQUAL(GetMinLevel(rPlusPlusTree), GetMaxLevel(rPlusPlusTree)); + BOOST_REQUIRE_EQUAL(rPlusPlusTree.TreeDepth(), GetMinLevel(rPlusPlusTree)); + // Check the MinimalSplitsNumberSweep. 
typedef RectangleTree, arma::mat, RPlusTreeSplit, @@ -1004,8 +1024,10 @@ BOOST_AUTO_TEST_CASE(RPlusPlusTreeBoundTest) RPlusPlusTreeMinimalSplits rPlusPlusTree2(dataset, 20, 6, 5, 2, 0); - CheckRPlusPlusTreeBound(&rPlusPlusTree2); + CheckRPlusPlusTreeBound(rPlusPlusTree2); + BOOST_REQUIRE_EQUAL(GetMinLevel(rPlusPlusTree2), GetMaxLevel(rPlusPlusTree2)); + BOOST_REQUIRE_EQUAL(rPlusPlusTree2.TreeDepth(), GetMinLevel(rPlusPlusTree2)); } BOOST_AUTO_TEST_CASE(RPlusPlusTreeTraverserTest) @@ -1020,8 +1042,8 @@ BOOST_AUTO_TEST_CASE(RPlusPlusTreeTraverserTest) arma::Mat neighbors2; arma::mat distances2; - typedef RPlusPlusTree, - arma::mat> TreeType; + typedef RPlusPlusTree, arma::mat> TreeType; TreeType rPlusPlusTree(dataset, 20, 6, 5, 2, 0); // Nearest neighbor search with the X tree. @@ -1034,7 +1056,7 @@ BOOST_AUTO_TEST_CASE(RPlusPlusTreeTraverserTest) CheckContainment(rPlusPlusTree); CheckExactContainment(rPlusPlusTree); CheckHierarchy(rPlusPlusTree); - CheckRPlusPlusTreeBound(&rPlusPlusTree); + CheckRPlusPlusTreeBound(rPlusPlusTree); knn1.Search(5, neighbors1, distances1); From c2d7a5543a848a4db8473964d0d0b3b05036fb0e Mon Sep 17 00:00:00 2001 From: Mikhail Lozhnikov Date: Wed, 29 Jun 2016 14:35:33 +0300 Subject: [PATCH 5/8] Fix error in MinimalSplitsNumberSweep with casting negative numbers to size_t. 
--- .../rectangle_tree/minimal_splits_number_sweep_impl.hpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp index 6328abba680..7f635681b2f 100644 --- a/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp +++ b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp @@ -64,7 +64,14 @@ size_t MinimalSplitsNumberSweep::SweepNonLeafNode( numTreeTwoChildren <= node->MaxNumChildren() && numTreeTwoChildren > 0) { // Evaluate the cost using the number of splits and balancing. - size_t cost = numSplits * (std::abs(sorted.size() / 2 - i)); + size_t balance; + + if (sorted.size() / 2 > i ) + balance = sorted.size() / 2 - i; + else + balance = i - sorted.size() / 2; + + size_t cost = numSplits * balance; if (cost < minCost) { minCost = cost; From 44b3b2035fa6d9fe0a71c6339b4cb26329a32d88 Mon Sep 17 00:00:00 2001 From: Mikhail Lozhnikov Date: Mon, 4 Jul 2016 18:40:29 +0300 Subject: [PATCH 6/8] Added some R+/R++ tree fixes. Updated NS/RS/RA models. 
--- src/mlpack/core/tree/hrectbound.hpp | 15 ++ src/mlpack/core/tree/hrectbound_impl.hpp | 54 ++++++ .../rectangle_tree/minimal_coverage_sweep.hpp | 29 +--- .../minimal_coverage_sweep_impl.hpp | 156 +++++------------- .../minimal_splits_number_sweep.hpp | 25 +-- .../minimal_splits_number_sweep_impl.hpp | 17 +- .../no_auxiliary_information.hpp | 4 +- ...r_plus_plus_tree_auxiliary_information.hpp | 11 +- .../tree/rectangle_tree/r_plus_tree_split.hpp | 9 +- .../rectangle_tree/r_plus_tree_split_impl.hpp | 24 ++- .../core/tree/rectangle_tree/typedef.hpp | 46 +++++- .../methods/neighbor_search/kfn_main.cpp | 13 +- .../methods/neighbor_search/knn_main.cpp | 12 +- .../methods/neighbor_search/ns_model.hpp | 8 +- .../methods/neighbor_search/ns_model_impl.hpp | 12 ++ .../range_search/range_search_main.cpp | 13 +- src/mlpack/methods/range_search/rs_model.cpp | 40 ++++- src/mlpack/methods/range_search/rs_model.hpp | 8 +- .../methods/range_search/rs_model_impl.hpp | 28 ++++ src/mlpack/methods/rann/krann_main.cpp | 12 +- src/mlpack/methods/rann/ra_model.hpp | 8 +- src/mlpack/methods/rann/ra_model_impl.hpp | 110 +++++++++++- src/mlpack/tests/aknn_test.cpp | 8 +- src/mlpack/tests/knn_test.cpp | 16 +- src/mlpack/tests/krann_search_test.cpp | 8 +- src/mlpack/tests/range_search_test.cpp | 16 +- src/mlpack/tests/rectangle_tree_test.cpp | 41 ++--- src/mlpack/tests/tree_test.cpp | 28 +--- 28 files changed, 505 insertions(+), 266 deletions(-) diff --git a/src/mlpack/core/tree/hrectbound.hpp b/src/mlpack/core/tree/hrectbound.hpp index fb189c182a7..948174a153e 100644 --- a/src/mlpack/core/tree/hrectbound.hpp +++ b/src/mlpack/core/tree/hrectbound.hpp @@ -182,6 +182,21 @@ class HRectBound template bool Contains(const VecType& point) const; + /** + * Determines if this bound partially contains a bound. + */ + bool Contains(const HRectBound& bound) const; + + /** + * Returns the intersection of this bound and another. 
+ */ + HRectBound Intersect(const HRectBound& bound) const; + + /** + * Returns the volume of overlap of this bound and another. + */ + ElemType Overlap(const HRectBound& bound) const; + /** * Returns the diameter of the hyperrectangle (that is, the longest diagonal). */ diff --git a/src/mlpack/core/tree/hrectbound_impl.hpp b/src/mlpack/core/tree/hrectbound_impl.hpp index 822877c9b8c..9b3ee1ed835 100644 --- a/src/mlpack/core/tree/hrectbound_impl.hpp +++ b/src/mlpack/core/tree/hrectbound_impl.hpp @@ -143,7 +143,12 @@ inline ElemType HRectBound::Volume() const { ElemType volume = 1.0; for (size_t i = 0; i < dim; ++i) + { + if (bounds[i].Lo() >= bounds[i].Hi()) + return 0; + volume *= (bounds[i].Hi() - bounds[i].Lo()); + } return volume; } @@ -430,6 +435,55 @@ inline bool HRectBound::Contains(const VecType& point) con return true; } +template +inline bool HRectBound::Contains( + const HRectBound& bound) const +{ + for (size_t i = 0; i < dim; i++) + { + const math::RangeType& r_a = bounds[i]; + const math::RangeType& r_b = bound.bounds[i]; + + if (r_a.Hi() <= r_b.Lo() || r_a.Lo() >= r_b.Hi()) // If a does not overlap b at all. + return false; + } + + return true; +} + +template +inline HRectBound HRectBound:: +Intersect(const HRectBound& bound) const +{ + HRectBound result(dim); + + for (size_t k = 0; k < dim; k++) + { + result[k].Lo() = std::max(bounds[k].Lo(), bound.bounds[k].Lo()); + result[k].Hi() = std::min(bounds[k].Hi(), bound.bounds[k].Hi()); + } + return result; +} + +template +inline ElemType HRectBound::Overlap( + const HRectBound& bound) const +{ + ElemType volume = 1.0; + + for (size_t k = 0; k < dim; k++) + { + ElemType lo = std::max(bounds[k].Lo(), bound.bounds[k].Lo()); + ElemType hi = std::min(bounds[k].Hi(), bound.bounds[k].Hi()); + + if ( hi <= lo) + return 0; + + volume *= hi - lo; + } + return volume; +} + /** * Returns the diameter of the hyperrectangle (that is, the longest diagonal). 
*/ diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep.hpp index 488d5d8f011..ab2f664dc55 100644 --- a/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep.hpp +++ b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep.hpp @@ -11,12 +11,12 @@ namespace mlpack { namespace tree { -constexpr double fillFactor = 0.5; - /** * The MinimalCoverageSweep class finds a partition along which we - * can split a node according to the coverage of two resulting nodes. - * Moreover, the class evaluates the cost of each split. + * can split a node according to the coverage of two resulting nodes. The class + * finds a partition along a given axis. Moreover, the class evaluates the cost + * of each split. The cost is proportional to the total coverage of resulting + * nodes. If the resulting nodes are overflowed the maximum cost is returned. * * @tparam SplitPolicy The class that provides rules for inserting children of * a node that is being split into two new subtrees. @@ -24,27 +24,6 @@ constexpr double fillFactor = 0.5; template class MinimalCoverageSweep { - private: - /** - * Class to allow for faster sorting. - */ - template - struct SortStruct - { - ElemType d; - int n; - }; - - /** - * Comparator for sorting with SortStruct. - */ - template - static bool StructComp(const SortStruct& s1, - const SortStruct& s2) - { - return s1.d < s2.d; - } - public: //! A struct that provides the type of the sweep cost. 
template diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp index e0723d5f1c0..5c899097a3a 100644 --- a/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp +++ b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp @@ -21,27 +21,34 @@ SweepNonLeafNode(const size_t axis, typename TreeType::ElemType& axisCut) { typedef typename TreeType::ElemType ElemType; + typedef bound::HRectBound BoundType; - std::vector> sorted(node->NumChildren()); + std::vector> sorted(node->NumChildren()); for (size_t i = 0; i < node->NumChildren(); i++) { - sorted[i].d = SplitPolicy::Bound(node->Child(i))[axis].Hi(); - sorted[i].n = i; + sorted[i].first = SplitPolicy::Bound(node->Child(i))[axis].Hi(); + sorted[i].second = i; } // Sort high bounds of children. - std::sort(sorted.begin(), sorted.end(), StructComp); + std::sort(sorted.begin(), sorted.end(), + [] (std::pair& s1, + std::pair& s2) + { + return s1.first < s2.first; + }); - size_t splitPointer = fillFactor * node->NumChildren(); + size_t splitPointer = node->NumChildren() / 2; - axisCut = sorted[splitPointer - 1].d; + axisCut = sorted[splitPointer - 1].first; - // Check if the partition is suitable. + // Check if the midpoint split is suitable. if (!CheckNonLeafSweep(node, axis, axisCut)) { + // Find any suitable partition if the default partition is not acceptable. 
for (splitPointer = 1; splitPointer < sorted.size(); splitPointer++) { - axisCut = sorted[splitPointer - 1].d; + axisCut = sorted[splitPointer - 1].first; if (CheckNonLeafSweep(node, axis, axisCut)) break; } @@ -50,72 +57,24 @@ SweepNonLeafNode(const size_t axis, return std::numeric_limits::max(); } - std::vector lowerBound1(node->Bound().Dim()); - std::vector highBound1(node->Bound().Dim()); - std::vector lowerBound2(node->Bound().Dim()); - std::vector highBound2(node->Bound().Dim()); + BoundType bound1(node->Bound().Dim()); + BoundType bound2(node->Bound().Dim()); - // Find lower and high bounds of two resulting nodes. - for (size_t k = 0; k < node->Bound().Dim(); k++) - { - lowerBound1[k] = node->Child(sorted[0].n).Bound()[k].Lo(); - highBound1[k] = node->Child(sorted[0].n).Bound()[k].Hi(); + // Find bounds of two resulting nodes. + for (size_t i = 0; i < splitPointer; i++) + bound1 |= node->Child(sorted[i].second).Bound(); - for (size_t i = 1; i < splitPointer; i++) - { - if (node->Child(sorted[i].n).Bound()[k].Lo() < lowerBound1[k]) - lowerBound1[k] = node->Child(sorted[i].n).Bound()[k].Lo(); - if (node->Child(sorted[i].n).Bound()[k].Hi() > highBound1[k]) - highBound1[k] = node->Child(sorted[i].n).Bound()[k].Hi(); - } + for (size_t i = splitPointer; i < node->NumChildren(); i++) + bound2 |= node->Child(sorted[i].second).Bound(); - lowerBound2[k] = node->Child(sorted[splitPointer].n).Bound()[k].Lo(); - highBound2[k] = node->Child(sorted[splitPointer].n).Bound()[k].Hi(); - - for (size_t i = splitPointer + 1; i < node->NumChildren(); i++) - { - if (node->Child(sorted[i].n).Bound()[k].Lo() < lowerBound2[k]) - lowerBound2[k] = node->Child(sorted[i].n).Bound()[k].Lo(); - if (node->Child(sorted[i].n).Bound()[k].Hi() > highBound2[k]) - highBound2[k] = node->Child(sorted[i].n).Bound()[k].Hi(); - } - } // Evaluate the cost of the split i.e. calculate the total coverage // of two resulting nodes. 
- ElemType area1 = 1.0, area2 = 1.0; - ElemType overlappedArea = 1.0; - - for (size_t k = 0; k < node->Bound().Dim(); k++) - { - if (lowerBound1[k] >= highBound1[k]) - { - overlappedArea *= 0; - area1 *= 0; - } - else - area1 *= highBound1[k] - lowerBound1[k]; - - if (lowerBound2[k] >= highBound2[k]) - { - overlappedArea *= 0; - area1 *= 0; - } - else - area2 *= highBound2[k] - lowerBound2[k]; + ElemType area1 = bound1.Volume(); + ElemType area2 = bound2.Volume(); - if (lowerBound1[k] < highBound1[k] && lowerBound2[k] < highBound2[k]) - { - if (lowerBound1[k] > highBound2[k] || lowerBound2[k] > highBound2[k]) - overlappedArea *= 0; - else - overlappedArea *= std::min(highBound1[k], highBound2[k]) - - std::max(lowerBound1[k], lowerBound2[k]); - } - } - - return area1 + area2 - overlappedArea; + return area1 + area2; } template @@ -126,73 +85,48 @@ SweepLeafNode(const size_t axis, typename TreeType::ElemType& axisCut) { typedef typename TreeType::ElemType ElemType; + typedef bound::HRectBound BoundType; - std::vector> sorted(node->Count()); + std::vector> sorted(node->Count()); sorted.resize(node->Count()); for (size_t i = 0; i < node->NumPoints(); i++) { - sorted[i].d = node->Dataset().col(node->Point(i))[axis]; - sorted[i].n = i; + sorted[i].first = node->Dataset().col(node->Point(i))[axis]; + sorted[i].second = i; } // Sort high bounds of children. - std::sort(sorted.begin(), sorted.end(), StructComp); + std::sort(sorted.begin(), sorted.end(), + [] (std::pair& s1, + std::pair& s2) + { + return s1.first < s2.first; + }); - size_t splitPointer = fillFactor * node->Count(); + size_t splitPointer = node->Count() / 2; - axisCut = sorted[splitPointer - 1].d; + axisCut = sorted[splitPointer - 1].first; // Check if the partition is suitable. 
if (!CheckLeafSweep(node, axis, axisCut)) return std::numeric_limits::max(); - std::vector lowerBound1(node->Bound().Dim()); - std::vector highBound1(node->Bound().Dim()); - std::vector lowerBound2(node->Bound().Dim()); - std::vector highBound2(node->Bound().Dim()); - - // Find lower and high bounds of two resulting nodes. - for (size_t k = 0; k < node->Bound().Dim(); k++) - { - lowerBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k]; - highBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k]; - - for (size_t i = 1; i < splitPointer; i++) - { - if (node->Dataset().col(node->Point(sorted[i].n))[k] < lowerBound1[k]) - lowerBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k]; - if (node->Dataset().col(node->Point(sorted[i].n))[k] > highBound1[k]) - highBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k]; - } + BoundType bound1(node->Bound().Dim()); + BoundType bound2(node->Bound().Dim()); - lowerBound2[k] = node->Dataset().col( - node->Point(sorted[splitPointer].n))[k]; - highBound2[k] = node->Dataset().col(node->Point(sorted[splitPointer].n))[k]; + // Find bounds of two resulting nodes. + for (size_t i = 0; i < splitPointer; i++) + bound1 |= node->Dataset().col(node->Point(sorted[i].second)); - for (size_t i = splitPointer + 1; i < node->NumChildren(); i++) - { - if (node->Dataset().col(node->Point(sorted[i].n))[k] < lowerBound2[k]) - lowerBound2[k] = node->Dataset().col(node->Point(sorted[i].n))[k]; - if (node->Dataset().col(node->Point(sorted[i].n))[k] > highBound2[k]) - highBound2[k] = node->Dataset().col(node->Point(sorted[i].n))[k]; - } - } + for (size_t i = splitPointer; i < node->NumChildren(); i++) + bound2 |= node->Dataset().col(node->Point(sorted[i].second)); // Evaluate the cost of the split i.e. calculate the total coverage // of two resulting nodes. 
- ElemType area1 = 1.0, area2 = 1.0; - ElemType overlappedArea = 1.0; - - for (size_t k = 0; k < node->Bound().Dim(); k++) - { - area1 *= highBound1[k] - lowerBound1[k]; - area2 *= highBound2[k] - lowerBound2[k]; - } - - return area1 + area2 - overlappedArea; + return bound1.Volume() + bound2.Volume(); } template diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep.hpp index 0470f522445..7134db8d3de 100644 --- a/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep.hpp +++ b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep.hpp @@ -14,7 +14,10 @@ namespace tree { /** * The MinimalSplitsNumberSweep class finds a partition along which we * can split a node according to the number of required splits of the node. - * Moreover, the class evaluates the cost of each split. + * The class finds a partition along a given axis. Moreover, the class evaluates + * the cost of each split. The cost is proportional to the number of required + * splits and the difference of sizes of resulting nodes. If the resulting nodes + * are overflowed the maximum cost is returned. * * @tparam SplitPolicy The class that provides rules for inserting children of * a node that is being split into two new subtrees. @@ -22,26 +25,6 @@ namespace tree { template class MinimalSplitsNumberSweep { - private: - /** - * Class to allow for faster sorting. - */ - template - struct SortStruct - { - ElemType d; - int n; - }; - - /** - * Comparator for sorting with SortStruct. - */ - template - static bool StructComp(const SortStruct& s1, - const SortStruct& s2) - { - return s1.d < s2.d; - } public: //! A struct that provides the type of the sweep cost. 
template diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp index 8bcc0a64a8c..b1b3f0f880d 100644 --- a/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp +++ b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp @@ -22,16 +22,21 @@ size_t MinimalSplitsNumberSweep::SweepNonLeafNode( { typedef typename TreeType::ElemType ElemType; - std::vector> sorted(node->NumChildren()); + std::vector> sorted(node->NumChildren()); for (size_t i = 0; i < node->NumChildren(); i++) { - sorted[i].d = SplitPolicy::Bound(node->Child(i))[axis].Hi(); - sorted[i].n = i; + sorted[i].first = SplitPolicy::Bound(node->Child(i))[axis].Hi(); + sorted[i].second = i; } // Sort candidates in order to check balancing. - std::sort(sorted.begin(), sorted.end(), StructComp); + std::sort(sorted.begin(), sorted.end(), + [] (std::pair& s1, + std::pair& s2) + { + return s1.first < s2.first; + }); size_t minCost = SIZE_MAX; @@ -46,7 +51,7 @@ size_t MinimalSplitsNumberSweep::SweepNonLeafNode( for (size_t j = 0; j < node->NumChildren(); j++) { const TreeType& child = node->Child(j); - int policy = SplitPolicy::GetSplitPolicy(child, axis, sorted[i].d); + int policy = SplitPolicy::GetSplitPolicy(child, axis, sorted[i].first); if (policy == SplitPolicy::AssignToFirstTree) numTreeOneChildren++; else if (policy == SplitPolicy::AssignToSecondTree) @@ -75,7 +80,7 @@ size_t MinimalSplitsNumberSweep::SweepNonLeafNode( if (cost < minCost) { minCost = cost; - axisCut = sorted[i].d; + axisCut = sorted[i].first; } } } diff --git a/src/mlpack/core/tree/rectangle_tree/no_auxiliary_information.hpp b/src/mlpack/core/tree/rectangle_tree/no_auxiliary_information.hpp index c07a4145ca9..e282a4cd113 100644 --- a/src/mlpack/core/tree/rectangle_tree/no_auxiliary_information.hpp +++ b/src/mlpack/core/tree/rectangle_tree/no_auxiliary_information.hpp @@ -100,7 +100,9 @@ class 
NoAuxiliaryInformation /** * The R++ tree requires to split the maximum bounding rectangle of a node - * that is being split. This method is intended for that. + * that is being split. This method is intended for that. This method is only + * necessary for an AuxiliaryInformationType that is being used in conjunction + * with RPlusTreeSplit. * * @param treeOne The first subtree. * @param treeTwo The second subtree. diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp index d791fa89e95..fbb82a66ddb 100644 --- a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp @@ -21,6 +21,8 @@ class RPlusPlusTreeAuxiliaryInformation public: //! The element type held by the tree. typedef typename TreeType::ElemType ElemType; + //! The bound type held by the auxiliary information. + typedef bound::HRectBound BoundType; //! Construct the auxiliary information object. RPlusPlusTreeAuxiliaryInformation(); @@ -124,16 +126,13 @@ class RPlusPlusTreeAuxiliaryInformation void NullifyData(); //! Return the maximum bounding rectangle. - bound::HRectBound& OuterBound() - { return outerBound; } + BoundType& OuterBound() { return outerBound; } //! Modify the maximum bounding rectangle. - const bound::HRectBound& - OuterBound() const - { return outerBound; } + const BoundType& OuterBound() const { return outerBound; } private: //! The maximum bounding rectangle. - bound::HRectBound outerBound; + BoundType outerBound; public: /** * Serialize the information. 
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp index 0d0c672fb4e..8e5b8ce731e 100644 --- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp @@ -18,7 +18,9 @@ namespace tree /** Trees and tree-building procedures. */ { * * @tparam SplitPolicyType The class that helps to determine the subtree into * which we should insert a child node. - * @tparam SweepType The class that sweeps a node along an axis. + * @tparam SweepType The class that finds the partition of a node along a + * given axis. The partition algorithm tries to find a partition along each + * axis, evaluates each partition and chooses the best one. */ template class SweepType> @@ -81,8 +83,9 @@ class RPlusTreeSplit const typename TreeType::ElemType cut); /** - * This method should be invoked in order to make the tree balanced if - * one of two resulting subtrees is empty after the split process + * This method is used to make sure that the tree has equivalent maximum depth + * in every branch. The method should be invoked if one of two resulting + * subtrees is empty after the split process * (i.e. the subtree contains no children). * The method convert the empty node into an empty subtree (increase the node * in depth). diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp index 9fefbbc8f51..6cb4d5c806e 100644 --- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp +++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp @@ -68,7 +68,15 @@ SplitLeafNode(TreeType* tree, std::vector& relevels) if ( !PartitionNode(tree, cutAxis, cut)) return; - assert(cutAxis < tree->Bound().Dim()); + // If we could not find a suitable partition. 
+ if (cutAxis == tree->Bound().Dim()) + { + tree->MaxLeafSize()++; + tree->points.resize(tree->MaxLeafSize() + 1); + Log::Warn << "Could not find an acceptable partition." + "The size of the node will be increased."; + return; + } TreeType* treeOne = new TreeType(tree->Parent()); TreeType* treeTwo = new TreeType(tree->Parent()); @@ -129,7 +137,15 @@ SplitNonLeafNode(TreeType* tree, std::vector& relevels) if ( !PartitionNode(tree, cutAxis, cut)) return false; - assert(cutAxis < tree->Bound().Dim()); + // If we could not find a suitable partition. + if (cutAxis == tree->Bound().Dim()) + { + tree->MaxNumChildren()++; + tree->children.resize(tree->MaxNumChildren() + 1); + Log::Warn << "Could not find an acceptable partition." + "The size of the node will be increased."; + return false; + } TreeType* treeOne = new TreeType(tree->Parent()); TreeType* treeTwo = new TreeType(tree->Parent()); @@ -290,8 +306,8 @@ bool RPlusTreeSplit:: PartitionNode(const TreeType* node, size_t& minCutAxis, typename TreeType::ElemType& minCut) { - if ((node->NumChildren() <= fillFactor && !node->IsLeaf()) || - (node->Count() <= fillFactor && node->IsLeaf())) + if ((node->NumChildren() <= node->MaxNumChildren() && !node->IsLeaf()) || + (node->Count() <= node->MaxLeafSize() && node->IsLeaf())) return false; // No partition required. // Define the type of the sweep cost. diff --git a/src/mlpack/core/tree/rectangle_tree/typedef.hpp b/src/mlpack/core/tree/rectangle_tree/typedef.hpp index 27f624e3138..1994afd65b4 100644 --- a/src/mlpack/core/tree/rectangle_tree/typedef.hpp +++ b/src/mlpack/core/tree/rectangle_tree/typedef.hpp @@ -130,6 +130,29 @@ using HilbertRTree = RectangleTree; +/** + * The R+ tree, a variant of the R tree that avoids overlapping rectangles. + * The implementation is modified from the original paper implementation. + * This template typedef satisfies the TreeType policy API. + * + * @code + * @inproceedings{sellis1987r, + * author = {Sellis, Timos K. 
and Roussopoulos, Nick and Faloutsos, Christos}, + * title = {The R+-Tree: A Dynamic Index for Multi-Dimensional Objects}, + * booktitle = {Proceedings of the 13th International Conference on Very + * Large Data Bases}, + * series = {VLDB '87}, + * year = {1987}, + * isbn = {0-934613-46-X}, + * pages = {507--518}, + * numpages = {12}, + * publisher = {Morgan Kaufmann Publishers Inc.}, + * address = {San Francisco, CA, USA}, + * } + * @endcode + * + * @see @ref trees, RTree, RTree, RPlusTree + */ template using RPlusTree = RectangleTree; -template +/** + * The R++ tree, a variant of the R+ tree with maximum bounding rectangles. + * This template typedef satisfies the TreeType policy API. + * + * @code + * @inproceedings{sumak2014r, + * author = {{\v{S}}um{\'a}k, Martin and Gursk{\'y}, Peter}, + * title = {R++-Tree: An Efficient Spatial Access Method for Highly Redundant + * Point Data}, + * booktitle = {New Trends in Databases and Information Systems: 17th East + * European Conference on Advances in Databases and Information Systems}, + * year = {2014}, + * isbn = {978-3-319-01863-8}, + * pages = {37--44}, + * publisher = {Springer International Publishing}, + * } + * @endcode + * + * @see @ref trees, RTree, RTree, RPlusTree, RPlusPlusTree + */template using RPlusPlusTree = RectangleTree, + MinimalSplitsNumberSweep>, RPlusPlusTreeDescentHeuristic, RPlusPlusTreeAuxiliaryInformation>; } // namespace tree diff --git a/src/mlpack/methods/neighbor_search/kfn_main.cpp b/src/mlpack/methods/neighbor_search/kfn_main.cpp index 254dc52b505..6adbc566f70 100644 --- a/src/mlpack/methods/neighbor_search/kfn_main.cpp +++ b/src/mlpack/methods/neighbor_search/kfn_main.cpp @@ -62,8 +62,10 @@ PARAM_INT("k", "Number of furthest neighbors to find.", "k", 0); // The user may specify the type of tree to use, and a few pararmeters for tree // building. 
PARAM_STRING("tree_type", "Type of tree to use: 'kd', 'cover', 'r', 'r-star', " - "'x', 'ball', 'hilbert-r'.", "t", "kd"); -PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20); + "'x', 'ball', 'hilbert-r', 'r-plus', 'r-plus-plus'.", "t", "kd"); +PARAM_INT("leaf_size", "Leaf size for tree building (used for kd-trees, R " + "trees, R* trees, X trees, Hilbert R trees, R+ trees and R++ trees).", "l", + 20); PARAM_FLAG("random_basis", "Before tree-building, project the data onto a " "random orthogonal basis.", "R"); PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0); @@ -188,9 +190,14 @@ int main(int argc, char *argv[]) tree = KFNModel::X_TREE; else if (treeType == "hilbert-r") tree = KFNModel::HILBERT_R_TREE; + else if (treeType == "r-plus") + tree = KFNModel::R_PLUS_TREE; + else if (treeType == "r-plus-plus") + tree = KFNModel::R_PLUS_PLUS_TREE; else Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are " - << "'kd', 'cover', 'r', 'r-star', 'x', 'ball' and 'hilbert-r'." << endl; + << "'kd', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', " + << "'r-plus' and 'r-plus-plus'." << endl; kfn.TreeType() = tree; kfn.RandomBasis() = randomBasis; diff --git a/src/mlpack/methods/neighbor_search/knn_main.cpp b/src/mlpack/methods/neighbor_search/knn_main.cpp index aad19911554..87bdb32e798 100644 --- a/src/mlpack/methods/neighbor_search/knn_main.cpp +++ b/src/mlpack/methods/neighbor_search/knn_main.cpp @@ -63,9 +63,10 @@ PARAM_INT("k", "Number of nearest neighbors to find.", "k", 0); // The user may specify the type of tree to use, and a few parameters for tree // building. 
PARAM_STRING("tree_type", "Type of tree to use: 'kd', 'cover', 'r', 'r-star', " - "'x', 'ball', 'hilbert-r'.", "t", "kd"); + "'x', 'ball', 'hilbert-r', 'r-plus', 'r-plus-plus'.", "t", "kd"); PARAM_INT("leaf_size", "Leaf size for tree building (used for kd-trees, R " - "trees, and R* trees).", "l", 20); + "trees, R* trees, X trees, Hilbert R trees, R+ trees and R++ trees).", "l", + 20); PARAM_FLAG("random_basis", "Before tree-building, project the data onto a " "random orthogonal basis.", "R"); PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0); @@ -174,9 +175,14 @@ int main(int argc, char *argv[]) tree = KNNModel::X_TREE; else if (treeType == "hilbert-r") tree = KNNModel::HILBERT_R_TREE; + else if (treeType == "r-plus") + tree = KNNModel::R_PLUS_TREE; + else if (treeType == "r-plus-plus") + tree = KNNModel::R_PLUS_PLUS_TREE; else Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are " - << "'kd', 'cover', 'r', 'r-star', 'x', 'ball' and 'hilbert-r'." << endl; + << "'kd', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', " + << "'r-plus' and 'r-plus-plus'." 
<< endl; knn.TreeType() = tree; knn.RandomBasis() = randomBasis; diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp index 55e9e91dbf4..5f0fbf20017 100644 --- a/src/mlpack/methods/neighbor_search/ns_model.hpp +++ b/src/mlpack/methods/neighbor_search/ns_model.hpp @@ -230,7 +230,9 @@ class NSModel R_STAR_TREE, BALL_TREE, X_TREE, - HILBERT_R_TREE + HILBERT_R_TREE, + R_PLUS_TREE, + R_PLUS_PLUS_TREE }; private: @@ -255,7 +257,9 @@ class NSModel NSType*, NSType*, NSType*, - NSType*> nSearch; + NSType*, + NSType*, + NSType*> nSearch; public: /** diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp index ae34feba759..acbed6ce2a5 100644 --- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp +++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp @@ -386,6 +386,14 @@ void NSModel::BuildModel(arma::mat&& referenceSet, nSearch = new NSType(naive, singleMode, epsilon); break; + case R_PLUS_TREE: + nSearch = new NSType(naive, singleMode, + epsilon); + break; + case R_PLUS_PLUS_TREE: + nSearch = new NSType(naive, singleMode, + epsilon); + break; } TrainVisitor tn(std::move(referenceSet), leafSize); @@ -466,6 +474,10 @@ std::string NSModel::TreeName() const return "X tree"; case HILBERT_R_TREE: return "Hilbert R tree"; + case R_PLUS_TREE: + return "R+ tree"; + case R_PLUS_PLUS_TREE: + return "R++ tree"; default: return "unknown tree"; } diff --git a/src/mlpack/methods/range_search/range_search_main.cpp b/src/mlpack/methods/range_search/range_search_main.cpp index ca3f4330de4..c8ea2a5b83f 100644 --- a/src/mlpack/methods/range_search/range_search_main.cpp +++ b/src/mlpack/methods/range_search/range_search_main.cpp @@ -70,8 +70,10 @@ PARAM_DOUBLE("min", "Lower bound in range.", "L", 0.0); // The user may specify the type of tree to use, and a few parameters for tree // building. 
PARAM_STRING("tree_type", "Type of tree to use: 'kd', 'cover', 'r', 'r-star', " - "'x', 'ball', 'hilbert-r'.", "t", "kd"); -PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20); + "'x', 'ball', 'hilbert-r', 'r-plus', 'r-plus-plus'.", "t", "kd"); +PARAM_INT("leaf_size", "Leaf size for tree building (used for kd-trees, R " + "trees, R* trees, X trees, Hilbert R trees, R+ trees and R++ trees).", "l", + 20); PARAM_FLAG("random_basis", "Before tree-building, project the data onto a " "random orthogonal basis.", "R"); PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0); @@ -175,9 +177,14 @@ int main(int argc, char *argv[]) tree = RSModel::X_TREE; else if (treeType == "hilbert-r") tree = RSModel::HILBERT_R_TREE; + else if (treeType == "r-plus") + tree = RSModel::R_PLUS_TREE; + else if (treeType == "r-plus-plus") + tree = RSModel::R_PLUS_PLUS_TREE; else Log::Fatal << "Unknown tree type '" << treeType << "; valid choices are " - << "'kd', 'cover', 'r', 'r-star', 'x', 'ball' and 'hilbert-r'." << endl; + << "'kd', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', " + << "'r-plus' and 'r-plus-plus'." << endl; rs.TreeType() = tree; rs.RandomBasis() = randomBasis; diff --git a/src/mlpack/methods/range_search/rs_model.cpp b/src/mlpack/methods/range_search/rs_model.cpp index 1bf565ab672..1cf3938523a 100644 --- a/src/mlpack/methods/range_search/rs_model.cpp +++ b/src/mlpack/methods/range_search/rs_model.cpp @@ -23,7 +23,9 @@ RSModel::RSModel(TreeTypes treeType, bool randomBasis) : rStarTreeRS(NULL), ballTreeRS(NULL), xTreeRS(NULL), - hilbertRTreeRS(NULL) + hilbertRTreeRS(NULL), + rPlusTreeRS(NULL), + rPlusPlusTreeRS(NULL) { // Nothing to do. 
} @@ -128,6 +130,16 @@ void RSModel::BuildModel(arma::mat&& referenceSet, hilbertRTreeRS = new RSType(move(referenceSet), naive, singleMode); break; + + case R_PLUS_TREE: + rPlusTreeRS = new RSType(move(referenceSet), naive, + singleMode); + break; + + case R_PLUS_PLUS_TREE: + rPlusPlusTreeRS = new RSType(move(referenceSet), naive, + singleMode); + break; } if (!naive) @@ -241,6 +253,14 @@ void RSModel::Search(arma::mat&& querySet, case HILBERT_R_TREE: hilbertRTreeRS->Search(querySet, range, neighbors, distances); break; + + case R_PLUS_TREE: + rPlusTreeRS->Search(querySet, range, neighbors, distances); + break; + + case R_PLUS_PLUS_TREE: + rPlusPlusTreeRS->Search(querySet, range, neighbors, distances); + break; } } @@ -287,6 +307,14 @@ void RSModel::Search(const math::Range& range, case HILBERT_R_TREE: hilbertRTreeRS->Search(range, neighbors, distances); break; + + case R_PLUS_TREE: + rPlusTreeRS->Search(range, neighbors, distances); + break; + + case R_PLUS_PLUS_TREE: + rPlusPlusTreeRS->Search(range, neighbors, distances); + break; } } @@ -309,6 +337,10 @@ std::string RSModel::TreeName() const return "X tree"; case HILBERT_R_TREE: return "Hilbert R tree"; + case R_PLUS_TREE: + return "R+ tree"; + case R_PLUS_PLUS_TREE: + return "R++ tree"; default: return "unknown tree"; } @@ -331,6 +363,10 @@ void RSModel::CleanMemory() delete xTreeRS; if (hilbertRTreeRS) delete hilbertRTreeRS; + if (rPlusTreeRS) + delete rPlusTreeRS; + if (rPlusPlusTreeRS) + delete rPlusPlusTreeRS; kdTreeRS = NULL; coverTreeRS = NULL; @@ -339,4 +375,6 @@ void RSModel::CleanMemory() ballTreeRS = NULL; xTreeRS = NULL; hilbertRTreeRS = NULL; + rPlusTreeRS = NULL; + rPlusPlusTreeRS = NULL; } diff --git a/src/mlpack/methods/range_search/rs_model.hpp b/src/mlpack/methods/range_search/rs_model.hpp index d256c319840..7903d373c38 100644 --- a/src/mlpack/methods/range_search/rs_model.hpp +++ b/src/mlpack/methods/range_search/rs_model.hpp @@ -30,7 +30,9 @@ class RSModel R_STAR_TREE, BALL_TREE, X_TREE, - 
HILBERT_R_TREE + HILBERT_R_TREE, + R_PLUS_TREE, + R_PLUS_PLUS_TREE }; private: @@ -63,6 +65,10 @@ class RSModel RSType* xTreeRS; //! Hilbert R tree based range search object (NULL if not in use). RSType* hilbertRTreeRS; + //! R+ tree based range search object (NULL if not in use). + RSType* rPlusTreeRS; + //! R++ tree based range search object (NULL if not in use). + RSType* rPlusPlusTreeRS; public: /** diff --git a/src/mlpack/methods/range_search/rs_model_impl.hpp b/src/mlpack/methods/range_search/rs_model_impl.hpp index 0f308d47121..98fa7a8224b 100644 --- a/src/mlpack/methods/range_search/rs_model_impl.hpp +++ b/src/mlpack/methods/range_search/rs_model_impl.hpp @@ -57,6 +57,14 @@ void RSModel::Serialize(Archive& ar, const unsigned int /* version */) case HILBERT_R_TREE: ar & CreateNVP(hilbertRTreeRS, "range_search_model"); break; + + case R_PLUS_TREE: + ar & CreateNVP(rPlusTreeRS, "range_search_model"); + break; + + case R_PLUS_PLUS_TREE: + ar & CreateNVP(rPlusPlusTreeRS, "range_search_model"); + break; } } @@ -76,6 +84,10 @@ inline const arma::mat& RSModel::Dataset() const return xTreeRS->ReferenceSet(); else if (hilbertRTreeRS) return hilbertRTreeRS->ReferenceSet(); + else if (rPlusTreeRS) + return rPlusTreeRS->ReferenceSet(); + else if (rPlusPlusTreeRS) + return rPlusPlusTreeRS->ReferenceSet(); throw std::runtime_error("no range search model initialized"); } @@ -96,6 +108,10 @@ inline bool RSModel::SingleMode() const return xTreeRS->SingleMode(); else if (hilbertRTreeRS) return hilbertRTreeRS->SingleMode(); + else if (rPlusTreeRS) + return rPlusTreeRS->SingleMode(); + else if (rPlusPlusTreeRS) + return rPlusPlusTreeRS->SingleMode(); throw std::runtime_error("no range search model initialized"); } @@ -116,6 +132,10 @@ inline bool& RSModel::SingleMode() return xTreeRS->SingleMode(); else if (hilbertRTreeRS) return hilbertRTreeRS->SingleMode(); + else if (rPlusTreeRS) + return rPlusTreeRS->SingleMode(); + else if (rPlusPlusTreeRS) + return 
rPlusPlusTreeRS->SingleMode(); throw std::runtime_error("no range search model initialized"); } @@ -136,6 +156,10 @@ inline bool RSModel::Naive() const return xTreeRS->Naive(); else if (hilbertRTreeRS) return hilbertRTreeRS->Naive(); + else if (rPlusTreeRS) + return rPlusTreeRS->Naive(); + else if (rPlusPlusTreeRS) + return rPlusPlusTreeRS->Naive(); throw std::runtime_error("no range search model initialized"); } @@ -156,6 +180,10 @@ inline bool& RSModel::Naive() return xTreeRS->Naive(); else if (hilbertRTreeRS) return hilbertRTreeRS->Naive(); + else if (rPlusTreeRS) + return rPlusTreeRS->Naive(); + else if (rPlusPlusTreeRS) + return rPlusPlusTreeRS->Naive(); throw std::runtime_error("no range search model initialized"); } diff --git a/src/mlpack/methods/rann/krann_main.cpp b/src/mlpack/methods/rann/krann_main.cpp index bcd57e188b8..f9a6c675477 100644 --- a/src/mlpack/methods/rann/krann_main.cpp +++ b/src/mlpack/methods/rann/krann_main.cpp @@ -64,9 +64,10 @@ PARAM_INT("k", "Number of nearest neighbors to find.", "k", 0); // The user may specify the type of tree to use, and a few parameters for tree // building. 
PARAM_STRING("tree_type", "Type of tree to use: 'kd', 'cover', 'r', or " - "'x', 'r-star', 'hilbert-r'.", "t", "kd"); + "'x', 'r-star', 'hilbert-r', 'r-plus', 'r-plus-plus'.", "t", "kd"); PARAM_INT("leaf_size", "Leaf size for tree building (used for kd-trees, R " - "trees, and R* trees).", "l", 20); + "trees, R* trees, X trees, Hilbert R trees, R+ trees and R++ trees).", "l", + 20); PARAM_FLAG("random_basis", "Before tree-building, project the data onto a " "random orthogonal basis.", "R"); PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0); @@ -174,9 +175,14 @@ int main(int argc, char *argv[]) tree = RANNModel::X_TREE; else if (treeType == "hilbert-r") tree = RANNModel::HILBERT_R_TREE; + else if (treeType == "r-plus") + tree = RANNModel::R_PLUS_TREE; + else if (treeType == "r-plus-plus") + tree = RANNModel::R_PLUS_PLUS_TREE; else Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are " - << "'kd', 'cover', 'r', 'r-star', 'x' and 'hilbert-r'." << endl; + << "'kd', 'cover', 'r', 'r-star', 'x', 'hilbert-r', " + << "'r-plus' and 'r-plus-plus'." << endl; rann.TreeType() = tree; rann.RandomBasis() = randomBasis; diff --git a/src/mlpack/methods/rann/ra_model.hpp b/src/mlpack/methods/rann/ra_model.hpp index 2c929796d99..1e755d3db49 100644 --- a/src/mlpack/methods/rann/ra_model.hpp +++ b/src/mlpack/methods/rann/ra_model.hpp @@ -41,7 +41,9 @@ class RAModel R_TREE, R_STAR_TREE, X_TREE, - HILBERT_R_TREE + HILBERT_R_TREE, + R_PLUS_TREE, + R_PLUS_PLUS_TREE }; private: @@ -76,6 +78,10 @@ class RAModel RAType* xTreeRA; //! Non-NULL if the Hilbert R tree is used. RAType* hilbertRTreeRA; + //! Non-NULL if the R+ tree is used. + RAType* rPlusTreeRA; + //! Non-NULL if the R++ tree is used. 
+ RAType* rPlusPlusTreeRA; public: /** diff --git a/src/mlpack/methods/rann/ra_model_impl.hpp b/src/mlpack/methods/rann/ra_model_impl.hpp index edaf03866bc..f096540614e 100644 --- a/src/mlpack/methods/rann/ra_model_impl.hpp +++ b/src/mlpack/methods/rann/ra_model_impl.hpp @@ -23,7 +23,9 @@ RAModel::RAModel(const TreeTypes treeType, const bool randomBasis) : rTreeRA(NULL), rStarTreeRA(NULL), xTreeRA(NULL), - hilbertRTreeRA(NULL) + hilbertRTreeRA(NULL), + rPlusTreeRA(NULL), + rPlusPlusTreeRA(NULL) { // Nothing to do. } @@ -43,6 +45,10 @@ RAModel::~RAModel() delete xTreeRA; if (hilbertRTreeRA) delete hilbertRTreeRA; + if (rPlusTreeRA) + delete rPlusTreeRA; + if (rPlusPlusTreeRA) + delete rPlusPlusTreeRA; } template @@ -69,6 +75,10 @@ void RAModel::Serialize(Archive& ar, delete xTreeRA; if (hilbertRTreeRA) delete hilbertRTreeRA; + if (rPlusTreeRA) + delete rPlusTreeRA; + if (rPlusPlusTreeRA) + delete rPlusPlusTreeRA; // Set all the pointers to NULL. kdTreeRA = NULL; @@ -77,6 +87,8 @@ void RAModel::Serialize(Archive& ar, rStarTreeRA = NULL; xTreeRA = NULL; hilbertRTreeRA = NULL; + rPlusPlusTreeRA = NULL; + rPlusTreeRA = NULL; } // We only need to serialize one of the kRANN objects. 
@@ -100,6 +112,12 @@ void RAModel::Serialize(Archive& ar, case HILBERT_R_TREE: ar & data::CreateNVP(hilbertRTreeRA, "ra_model"); break; + case R_PLUS_TREE: + ar & data::CreateNVP(rPlusTreeRA, "ra_model"); + break; + case R_PLUS_PLUS_TREE: + ar & data::CreateNVP(rPlusPlusTreeRA, "ra_model"); + break; } } @@ -118,6 +136,10 @@ const arma::mat& RAModel::Dataset() const return xTreeRA->ReferenceSet(); else if (hilbertRTreeRA) return hilbertRTreeRA->ReferenceSet(); + else if (rPlusTreeRA) + return rPlusTreeRA->ReferenceSet(); + else if (rPlusPlusTreeRA) + return rPlusPlusTreeRA->ReferenceSet(); throw std::runtime_error("no rank-approximate nearest neighbor search model " "initialized"); @@ -138,6 +160,10 @@ bool RAModel::Naive() const return xTreeRA->Naive(); else if (hilbertRTreeRA) return hilbertRTreeRA->Naive(); + else if (rPlusTreeRA) + return rPlusTreeRA->Naive(); + else if (rPlusPlusTreeRA) + return rPlusPlusTreeRA->Naive(); throw std::runtime_error("no rank-approximate nearest neighbor search model " "initialized"); @@ -158,6 +184,10 @@ bool& RAModel::Naive() return xTreeRA->Naive(); else if (hilbertRTreeRA) return hilbertRTreeRA->Naive(); + else if (rPlusTreeRA) + return rPlusTreeRA->Naive(); + else if (rPlusPlusTreeRA) + return rPlusPlusTreeRA->Naive(); throw std::runtime_error("no rank-approximate nearest neighbor search model " "initialized"); @@ -178,6 +208,10 @@ bool RAModel::SingleMode() const return xTreeRA->SingleMode(); else if (hilbertRTreeRA) return hilbertRTreeRA->SingleMode(); + else if (rPlusTreeRA) + return rPlusTreeRA->SingleMode(); + else if (rPlusPlusTreeRA) + return rPlusPlusTreeRA->SingleMode(); throw std::runtime_error("no rank-approximate nearest neighbor search model " "initialized"); @@ -198,6 +232,10 @@ bool& RAModel::SingleMode() return xTreeRA->SingleMode(); else if (hilbertRTreeRA) return hilbertRTreeRA->SingleMode(); + else if (rPlusTreeRA) + return rPlusTreeRA->SingleMode(); + else if (rPlusPlusTreeRA) + return 
rPlusPlusTreeRA->SingleMode(); throw std::runtime_error("no rank-approximate nearest neighbor search model " "initialized"); @@ -218,6 +256,10 @@ double RAModel::Tau() const return xTreeRA->Tau(); else if (hilbertRTreeRA) return hilbertRTreeRA->Tau(); + else if (rPlusTreeRA) + return rPlusTreeRA->Tau(); + else if (rPlusPlusTreeRA) + return rPlusPlusTreeRA->Tau(); throw std::runtime_error("no rank-approximate nearest neighbor search model " "initialized"); @@ -238,6 +280,10 @@ double& RAModel::Tau() return xTreeRA->Tau(); else if (hilbertRTreeRA) return hilbertRTreeRA->Tau(); + else if (rPlusTreeRA) + return rPlusTreeRA->Tau(); + else if (rPlusPlusTreeRA) + return rPlusPlusTreeRA->Tau(); throw std::runtime_error("no rank-approximate nearest neighbor search model " "initialized"); @@ -258,6 +304,10 @@ double RAModel::Alpha() const return xTreeRA->Alpha(); else if (hilbertRTreeRA) return hilbertRTreeRA->Alpha(); + else if (rPlusTreeRA) + return rPlusTreeRA->Alpha(); + else if (rPlusPlusTreeRA) + return rPlusPlusTreeRA->Alpha(); throw std::runtime_error("no rank-approximate nearest neighbor search model " "initialized"); @@ -278,6 +328,10 @@ double& RAModel::Alpha() return xTreeRA->Alpha(); else if (hilbertRTreeRA) return hilbertRTreeRA->Alpha(); + else if (rPlusTreeRA) + return rPlusTreeRA->Alpha(); + else if (rPlusPlusTreeRA) + return rPlusPlusTreeRA->Alpha(); throw std::runtime_error("no rank-approximate nearest neighbor search model " "initialized"); @@ -298,6 +352,10 @@ bool RAModel::SampleAtLeaves() const return xTreeRA->SampleAtLeaves(); else if (hilbertRTreeRA) return hilbertRTreeRA->SampleAtLeaves(); + else if (rPlusTreeRA) + return rPlusTreeRA->SampleAtLeaves(); + else if (rPlusPlusTreeRA) + return rPlusPlusTreeRA->SampleAtLeaves(); throw std::runtime_error("no rank-approximate nearest neighbor search model " "initialized"); @@ -318,6 +376,10 @@ bool& RAModel::SampleAtLeaves() return xTreeRA->SampleAtLeaves(); else if (hilbertRTreeRA) return 
hilbertRTreeRA->SampleAtLeaves(); + else if (rPlusTreeRA) + return rPlusTreeRA->SampleAtLeaves(); + else if (rPlusPlusTreeRA) + return rPlusPlusTreeRA->SampleAtLeaves(); throw std::runtime_error("no rank-approximate nearest neighbor search model " "initialized"); @@ -338,6 +400,10 @@ bool RAModel::FirstLeafExact() const return xTreeRA->FirstLeafExact(); else if (hilbertRTreeRA) return hilbertRTreeRA->FirstLeafExact(); + else if (rPlusTreeRA) + return rPlusTreeRA->FirstLeafExact(); + else if (rPlusPlusTreeRA) + return rPlusPlusTreeRA->FirstLeafExact(); throw std::runtime_error("no rank-approximate nearest neighbor search model " "initialized"); @@ -358,6 +424,10 @@ bool& RAModel::FirstLeafExact() return xTreeRA->FirstLeafExact(); else if (hilbertRTreeRA) return hilbertRTreeRA->FirstLeafExact(); + else if (rPlusTreeRA) + return rPlusTreeRA->FirstLeafExact(); + else if (rPlusPlusTreeRA) + return rPlusPlusTreeRA->FirstLeafExact(); throw std::runtime_error("no rank-approximate nearest neighbor search model " "initialized"); @@ -378,6 +448,10 @@ size_t RAModel::SingleSampleLimit() const return xTreeRA->SingleSampleLimit(); else if (hilbertRTreeRA) return hilbertRTreeRA->SingleSampleLimit(); + else if (rPlusTreeRA) + return rPlusTreeRA->SingleSampleLimit(); + else if (rPlusPlusTreeRA) + return rPlusPlusTreeRA->SingleSampleLimit(); throw std::runtime_error("no rank-approximate nearest neighbor search model " "initialized"); @@ -398,6 +472,10 @@ size_t& RAModel::SingleSampleLimit() return xTreeRA->SingleSampleLimit(); else if (hilbertRTreeRA) return hilbertRTreeRA->SingleSampleLimit(); + else if (rPlusTreeRA) + return rPlusTreeRA->SingleSampleLimit(); + else if (rPlusPlusTreeRA) + return rPlusPlusTreeRA->SingleSampleLimit(); throw std::runtime_error("no rank-approximate nearest neighbor search model " "initialized"); @@ -465,6 +543,10 @@ void RAModel::BuildModel(arma::mat&& referenceSet, delete xTreeRA; if (hilbertRTreeRA) delete hilbertRTreeRA; + if (rPlusTreeRA) + delete 
rPlusTreeRA; + if (rPlusPlusTreeRA) + delete rPlusPlusTreeRA; if (randomBasis) referenceSet = q * referenceSet; @@ -517,6 +599,14 @@ void RAModel::BuildModel(arma::mat&& referenceSet, hilbertRTreeRA = new RAType(std::move(referenceSet), naive, singleMode); break; + case R_PLUS_TREE: + rPlusTreeRA = new RAType(std::move(referenceSet), + naive, singleMode); + break; + case R_PLUS_PLUS_TREE: + rPlusPlusTreeRA = new RAType(std::move(referenceSet), + naive, singleMode); + break; } if (!naive) @@ -598,6 +688,14 @@ void RAModel::Search(arma::mat&& querySet, // No mapping necessary. hilbertRTreeRA->Search(querySet, k, neighbors, distances); break; + case R_PLUS_TREE: + // No mapping necessary. + rPlusTreeRA->Search(querySet, k, neighbors, distances); + break; + case R_PLUS_PLUS_TREE: + // No mapping necessary. + rPlusPlusTreeRA->Search(querySet, k, neighbors, distances); + break; } } @@ -635,6 +733,12 @@ void RAModel::Search(const size_t k, case HILBERT_R_TREE: hilbertRTreeRA->Search(k, neighbors, distances); break; + case R_PLUS_TREE: + rPlusTreeRA->Search(k, neighbors, distances); + break; + case R_PLUS_PLUS_TREE: + rPlusPlusTreeRA->Search(k, neighbors, distances); + break; } } @@ -655,6 +759,10 @@ std::string RAModel::TreeName() const return "X tree"; case HILBERT_R_TREE: return "Hilbert R tree"; + case R_PLUS_TREE: + return "R+ tree"; + case R_PLUS_PLUS_TREE: + return "R++ tree"; default: return "unknown tree"; } diff --git a/src/mlpack/tests/aknn_test.cpp b/src/mlpack/tests/aknn_test.cpp index 4af732b707b..6e0911635bc 100644 --- a/src/mlpack/tests/aknn_test.cpp +++ b/src/mlpack/tests/aknn_test.cpp @@ -351,7 +351,7 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest) arma::mat referenceData = arma::randu(10, 200); // Build all the possible models. 
- KNNModel models[14]; + KNNModel models[18]; models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true); models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false); models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true); @@ -366,6 +366,10 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest) models[11] = KNNModel(KNNModel::TreeTypes::BALL_TREE, false); models[12] = KNNModel(KNNModel::TreeTypes::HILBERT_R_TREE, true); models[13] = KNNModel(KNNModel::TreeTypes::HILBERT_R_TREE, false); + models[14] = KNNModel(KNNModel::TreeTypes::R_PLUS_TREE, true); + models[15] = KNNModel(KNNModel::TreeTypes::R_PLUS_TREE, false); + models[16] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, true); + models[17] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, false); for (size_t j = 0; j < 2; ++j) { @@ -375,7 +379,7 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest) arma::mat distancesExact; exact.Search(3, neighborsExact, distancesExact); - for (size_t i = 0; i < 14; ++i) + for (size_t i = 0; i < 18; ++i) { // We only have a std::move() constructor... so copy the data. arma::mat referenceCopy(referenceData); diff --git a/src/mlpack/tests/knn_test.cpp b/src/mlpack/tests/knn_test.cpp index 398aee508bc..0de22b8b959 100644 --- a/src/mlpack/tests/knn_test.cpp +++ b/src/mlpack/tests/knn_test.cpp @@ -977,7 +977,7 @@ BOOST_AUTO_TEST_CASE(KNNModelTest) arma::mat referenceData = arma::randu(10, 200); // Build all the possible models. 
- KNNModel models[14]; + KNNModel models[18]; models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true); models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false); models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true); @@ -992,6 +992,10 @@ BOOST_AUTO_TEST_CASE(KNNModelTest) models[11] = KNNModel(KNNModel::TreeTypes::BALL_TREE, false); models[12] = KNNModel(KNNModel::TreeTypes::HILBERT_R_TREE, true); models[13] = KNNModel(KNNModel::TreeTypes::HILBERT_R_TREE, false); + models[14] = KNNModel(KNNModel::TreeTypes::R_PLUS_TREE, true); + models[15] = KNNModel(KNNModel::TreeTypes::R_PLUS_TREE, false); + models[16] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, true); + models[17] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, false); for (size_t j = 0; j < 2; ++j) { @@ -1001,7 +1005,7 @@ BOOST_AUTO_TEST_CASE(KNNModelTest) arma::mat baselineDistances; knn.Search(queryData, 3, baselineNeighbors, baselineDistances); - for (size_t i = 0; i < 14; ++i) + for (size_t i = 0; i < 18; ++i) { // We only have std::move() constructors so make a copy of our data. arma::mat referenceCopy(referenceData); @@ -1045,7 +1049,7 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest) arma::mat referenceData = arma::randu(10, 200); // Build all the possible models. 
- KNNModel models[14]; + KNNModel models[18]; models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true); models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false); models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true); @@ -1060,6 +1064,10 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest) models[11] = KNNModel(KNNModel::TreeTypes::BALL_TREE, false); models[12] = KNNModel(KNNModel::TreeTypes::HILBERT_R_TREE, true); models[13] = KNNModel(KNNModel::TreeTypes::HILBERT_R_TREE, false); + models[14] = KNNModel(KNNModel::TreeTypes::R_PLUS_TREE, true); + models[15] = KNNModel(KNNModel::TreeTypes::R_PLUS_TREE, false); + models[16] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, true); + models[17] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, false); for (size_t j = 0; j < 2; ++j) { @@ -1069,7 +1077,7 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest) arma::mat baselineDistances; knn.Search(3, baselineNeighbors, baselineDistances); - for (size_t i = 0; i < 14; ++i) + for (size_t i = 0; i < 18; ++i) { // We only have a std::move() constructor... so copy the data. arma::mat referenceCopy(referenceData); diff --git a/src/mlpack/tests/krann_search_test.cpp b/src/mlpack/tests/krann_search_test.cpp index fa95c543f10..2a7f263c590 100644 --- a/src/mlpack/tests/krann_search_test.cpp +++ b/src/mlpack/tests/krann_search_test.cpp @@ -625,7 +625,7 @@ BOOST_AUTO_TEST_CASE(RAModelTest) data::Load("rann_test_q_3_100.csv", queryData, true); // Build all the possible models. 
- KNNModel models[12]; + KNNModel models[16]; models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, false); models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, true); models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, false); @@ -638,13 +638,17 @@ BOOST_AUTO_TEST_CASE(RAModelTest) models[9] = KNNModel(KNNModel::TreeTypes::X_TREE, true); models[10] = KNNModel(KNNModel::TreeTypes::HILBERT_R_TREE, false); models[11] = KNNModel(KNNModel::TreeTypes::HILBERT_R_TREE, true); + models[12] = KNNModel(KNNModel::TreeTypes::R_PLUS_TREE, false); + models[13] = KNNModel(KNNModel::TreeTypes::R_PLUS_TREE, true); + models[14] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, false); + models[15] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, true); arma::Mat qrRanks; data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose. for (size_t j = 0; j < 3; ++j) { - for (size_t i = 0; i < 12; ++i) + for (size_t i = 0; i < 16; ++i) { // We only have std::move() constructors so make a copy of our data. arma::mat referenceCopy(referenceData); diff --git a/src/mlpack/tests/range_search_test.cpp b/src/mlpack/tests/range_search_test.cpp index 1c9f73b3cca..a36d34dadf8 100644 --- a/src/mlpack/tests/range_search_test.cpp +++ b/src/mlpack/tests/range_search_test.cpp @@ -1251,7 +1251,7 @@ BOOST_AUTO_TEST_CASE(RSModelTest) arma::mat referenceData = arma::randu(10, 200); // Build all the possible models. 
- RSModel models[14]; + RSModel models[18]; models[0] = RSModel(RSModel::TreeTypes::KD_TREE, true); models[1] = RSModel(RSModel::TreeTypes::KD_TREE, false); models[2] = RSModel(RSModel::TreeTypes::COVER_TREE, true); @@ -1266,6 +1266,10 @@ BOOST_AUTO_TEST_CASE(RSModelTest) models[11] = RSModel(RSModel::TreeTypes::BALL_TREE, false); models[12] = RSModel(RSModel::TreeTypes::HILBERT_R_TREE, true); models[13] = RSModel(RSModel::TreeTypes::HILBERT_R_TREE, false); + models[14] = RSModel(RSModel::TreeTypes::R_PLUS_TREE, true); + models[15] = RSModel(RSModel::TreeTypes::R_PLUS_TREE, false); + models[16] = RSModel(RSModel::TreeTypes::R_PLUS_PLUS_TREE, true); + models[17] = RSModel(RSModel::TreeTypes::R_PLUS_PLUS_TREE, false); for (size_t j = 0; j < 2; ++j) { @@ -1279,7 +1283,7 @@ BOOST_AUTO_TEST_CASE(RSModelTest) vector>> baselineSorted; SortResults(baselineNeighbors, baselineDistances, baselineSorted); - for (size_t i = 0; i < 14; ++i) + for (size_t i = 0; i < 18; ++i) { // We only have std::move() constructors, so make a copy of our data. arma::mat referenceCopy(referenceData); @@ -1323,7 +1327,7 @@ BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest) arma::mat referenceData = arma::randu(10, 200); // Build all the possible models. 
- RSModel models[14]; + RSModel models[18]; models[0] = RSModel(RSModel::TreeTypes::KD_TREE, true); models[1] = RSModel(RSModel::TreeTypes::KD_TREE, false); models[2] = RSModel(RSModel::TreeTypes::COVER_TREE, true); @@ -1338,6 +1342,10 @@ BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest) models[11] = RSModel(RSModel::TreeTypes::BALL_TREE, false); models[12] = RSModel(RSModel::TreeTypes::HILBERT_R_TREE, true); models[13] = RSModel(RSModel::TreeTypes::HILBERT_R_TREE, false); + models[14] = RSModel(RSModel::TreeTypes::R_PLUS_TREE, true); + models[15] = RSModel(RSModel::TreeTypes::R_PLUS_TREE, false); + models[16] = RSModel(RSModel::TreeTypes::R_PLUS_PLUS_TREE, true); + models[17] = RSModel(RSModel::TreeTypes::R_PLUS_PLUS_TREE, false); for (size_t j = 0; j < 2; ++j) { @@ -1350,7 +1358,7 @@ BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest) vector>> baselineSorted; SortResults(baselineNeighbors, baselineDistances, baselineSorted); - for (size_t i = 0; i < 14; ++i) + for (size_t i = 0; i < 18; ++i) { // We only have std::move() cosntructors, so make a copy of our data. arma::mat referenceCopy(referenceData); diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp index d331bb317c4..5ba1ec8aec1 100644 --- a/src/mlpack/tests/rectangle_tree_test.cpp +++ b/src/mlpack/tests/rectangle_tree_test.cpp @@ -878,20 +878,9 @@ void CheckOverlap(const TreeType& tree) { if (j == i) continue; - success = false; - // Two nodes overlap each other if and only if there are no dimension - // in which they do not overlap each other. 
- for (size_t k = 0; k < tree.Bound().Dim(); k++) - { - if ((tree.Child(i).Bound()[k].Lo() >= - tree.Child(j).Bound()[k].Hi()) || - (tree.Child(j).Bound()[k].Lo() >= - tree.Child(i).Bound()[k].Hi())) - { - success = true; - break; - } - } + + success = !tree.Child(i).Bound().Contains(tree.Child(j).Bound()); + if (!success) break; } @@ -934,10 +923,10 @@ BOOST_AUTO_TEST_CASE(RPlusTreeTraverserTest) arma::mat distances2; typedef RPlusTree, - arma::mat> TreeType; + arma::mat > TreeType; TreeType rPlusTree(dataset, 20, 6, 5, 2, 0); - // Nearest neighbor search with the X tree. + // Nearest neighbor search with the R+ tree. NeighborSearch, arma::mat, RPlusTree > knn1(&rPlusTree, true); @@ -1004,18 +993,8 @@ void CheckRPlusPlusTreeBound(const TreeType& tree) continue; const Bound& bound2 = tree.Child(j).AuxiliaryInfo().OuterBound(); - // Two bounds overlap each other if and only if there are no dimension - // in which they do not overlap each other. - success = false; - for (size_t k = 0; k < tree.Bound().Dim(); k++) - { - if (bound1[k].Lo() >= bound2[k].Hi() || - bound2[k].Lo() >= bound1[k].Hi()) - { - success = true; - break; - } - } + success = !bound1.Contains(bound2); + if (!success) break; } @@ -1046,7 +1025,7 @@ BOOST_AUTO_TEST_CASE(RPlusPlusTreeBoundTest) // Check the MinimalSplitsNumberSweep. typedef RectangleTree, arma::mat, - RPlusTreeSplit, + RPlusTreeSplit, RPlusPlusTreeDescentHeuristic, RPlusPlusTreeAuxiliaryInformation> RPlusPlusTreeMinimalSplits; @@ -1071,10 +1050,10 @@ BOOST_AUTO_TEST_CASE(RPlusPlusTreeTraverserTest) arma::mat distances2; typedef RPlusPlusTree, arma::mat> TreeType; + NeighborSearchStat, arma::mat > TreeType; TreeType rPlusPlusTree(dataset, 20, 6, 5, 2, 0); - // Nearest neighbor search with the X tree. + // Nearest neighbor search with the R++ tree. 
NeighborSearch, arma::mat, RPlusPlusTree > knn1(&rPlusPlusTree, true); diff --git a/src/mlpack/tests/tree_test.cpp b/src/mlpack/tests/tree_test.cpp index 81a94463b25..4c68bb93d18 100644 --- a/src/mlpack/tests/tree_test.cpp +++ b/src/mlpack/tests/tree_test.cpp @@ -1272,10 +1272,6 @@ void GenerateVectorOfTree(TreeType* node, size_t depth, std::vector& v); -template -bool DoBoundsIntersect(HRectBound& a, - HRectBound& b); - /** * Exhaustive kd-tree test based on #125. * @@ -1344,7 +1340,7 @@ BOOST_AUTO_TEST_CASE(KdTreeTest) for (size_t i = depth; i < 2 * depth && i < v.size(); i++) for (size_t j = i + 1; j < 2 * depth && j < v.size(); j++) if (v[i] != NULL && v[j] != NULL) - BOOST_REQUIRE(!DoBoundsIntersect(v[i]->Bound(), v[j]->Bound())); + BOOST_REQUIRE(!v[i]->Bound().Contains(v[j]->Bound())); depth *= 2; } @@ -1430,26 +1426,6 @@ BOOST_AUTO_TEST_CASE(BallTreeTest) } } -template -bool DoBoundsIntersect(HRectBound& a, - HRectBound& b) -{ - size_t dimensionality = a.Dim(); - - Range r_a; - Range r_b; - - for (size_t i = 0; i < dimensionality; i++) - { - r_a = a[i]; - r_b = b[i]; - if (r_a < r_b || r_a > r_b) // If a does not overlap b at all. 
- return false; - } - - return true; -} - template void GenerateVectorOfTree(TreeType* node, size_t depth, @@ -1541,7 +1517,7 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSparseKDTreeTest) for (size_t i = depth; i < 2 * depth && i < v.size(); i++) for (size_t j = i + 1; j < 2 * depth && j < v.size(); j++) if (v[i] != NULL && v[j] != NULL) - BOOST_REQUIRE(!DoBoundsIntersect(v[i]->Bound(), v[j]->Bound())); + BOOST_REQUIRE(!v[i]->Bound().Contains(v[j]->Bound())); depth *= 2; } From e75a0879a12b6ddabc1046d2a0fabde080b69543 Mon Sep 17 00:00:00 2001 From: Mikhail Lozhnikov Date: Tue, 5 Jul 2016 12:21:12 +0300 Subject: [PATCH 7/8] Replace HRectBound::Intersect() by operator&() --- src/mlpack/core/tree/hrectbound.hpp | 7 ++++++- src/mlpack/core/tree/hrectbound_impl.hpp | 26 +++++++++++++++++++++++- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/mlpack/core/tree/hrectbound.hpp b/src/mlpack/core/tree/hrectbound.hpp index 948174a153e..7a0823ff332 100644 --- a/src/mlpack/core/tree/hrectbound.hpp +++ b/src/mlpack/core/tree/hrectbound.hpp @@ -190,7 +190,12 @@ class HRectBound /** * Returns the intersection of this bound and another. */ - HRectBound Intersect(const HRectBound& bound) const; + HRectBound operator&(const HRectBound& bound) const; + + /** + * Intersects this bound with another. + */ + HRectBound& operator&=(const HRectBound& bound); /** * Returns the volume of overlap of this bound and another. diff --git a/src/mlpack/core/tree/hrectbound_impl.hpp b/src/mlpack/core/tree/hrectbound_impl.hpp index 9b3ee1ed835..ccfa0265b4f 100644 --- a/src/mlpack/core/tree/hrectbound_impl.hpp +++ b/src/mlpack/core/tree/hrectbound_impl.hpp @@ -435,6 +435,9 @@ inline bool HRectBound::Contains(const VecType& point) con return true; } +/** + * Determines if this bound partially contains a bound. 
+ */ template inline bool HRectBound::Contains( const HRectBound& bound) const @@ -451,9 +454,12 @@ inline bool HRectBound::Contains( return true; } +/** + * Returns the intersection of this bound and another. + */ template inline HRectBound HRectBound:: -Intersect(const HRectBound& bound) const +operator&(const HRectBound& bound) const { HRectBound result(dim); @@ -465,6 +471,24 @@ Intersect(const HRectBound& bound) const return result; } +/** + * Intersects this bound with another. + */ +template +inline HRectBound& HRectBound:: +operator&=(const HRectBound& bound) +{ + for (size_t k = 0; k < dim; k++) + { + bounds[k].Lo() = std::max(bounds[k].Lo(), bound.bounds[k].Lo()); + bounds[k].Hi() = std::min(bounds[k].Hi(), bound.bounds[k].Hi()); + } + return *this; +} + +/** + * Returns the volume of overlap of this bound and another. + */ template inline ElemType HRectBound::Overlap( const HRectBound& bound) const From cdb6b5f36cf19efbb0eeb711e47c75313c30bdd7 Mon Sep 17 00:00:00 2001 From: Mikhail Lozhnikov Date: Thu, 7 Jul 2016 23:05:59 +0300 Subject: [PATCH 8/8] Add const modifiers. --- .../tree/rectangle_tree/minimal_coverage_sweep_impl.hpp | 8 ++++---- .../rectangle_tree/minimal_splits_number_sweep_impl.hpp | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp index 5c899097a3a..66ceec53adb 100644 --- a/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp +++ b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp @@ -32,8 +32,8 @@ SweepNonLeafNode(const size_t axis, } // Sort high bounds of children. std::sort(sorted.begin(), sorted.end(), - [] (std::pair& s1, - std::pair& s2) + [] (const std::pair& s1, + const std::pair& s2) { return s1.first < s2.first; }); @@ -99,8 +99,8 @@ SweepLeafNode(const size_t axis, // Sort high bounds of children. 
std::sort(sorted.begin(), sorted.end(), - [] (std::pair& s1, - std::pair& s2) + [] (const std::pair& s1, + const std::pair& s2) { return s1.first < s2.first; }); diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp index b1b3f0f880d..f320ac43d50 100644 --- a/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp +++ b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp @@ -32,8 +32,8 @@ size_t MinimalSplitsNumberSweep::SweepNonLeafNode( // Sort candidates in order to check balancing. std::sort(sorted.begin(), sorted.end(), - [] (std::pair& s1, - std::pair& s2) + [] (const std::pair& s1, + const std::pair& s2) { return s1.first < s2.first; });