R+ and R++ trees implementation #699

Merged
merged 12 commits into from Jul 7, 2016
72 changes: 63 additions & 9 deletions src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep.hpp
@@ -2,6 +2,8 @@
* @file minimal_coverage_sweep.hpp
* @author Mikhail Lozhnikov
*
* Definition of the MinimalCoverageSweep class, a class that finds a partition
* of a node along an axis.
*/
#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_HPP
#define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_HPP
@@ -11,17 +13,31 @@ namespace tree {

constexpr double fillFactor = 0.5;
Contributor Author

I think fillFactor is not needed at all.


/**
* The MinimalCoverageSweep class finds a partition along which we
* can split a node according to the coverage of two resulting nodes.
* Moreover, the class evaluates the cost of each split.
*
* @tparam SplitPolicy The class that provides rules for inserting children of
* a node that is being split into two new subtrees.
*/
template<typename SplitPolicy>
class MinimalCoverageSweep
{
private:
/**
* Class to allow for faster sorting.
*/
template<typename ElemType>
struct SortStruct
{
ElemType d;
int n;
};
Member

I think that SortStruct is used in a few other places, but does it give us any advantage over std::pair<ElemType, int>? It's possible that SortStruct is faster, but I am not sure of that. Have you played with that at all?

Contributor Author

No, I haven't. I think the two approaches are equivalent.

Member

Okay, I opened #712 about this and we can figure it out there instead of inside this PR, so we can get this one done more quickly. It is probably easier to do comparisons by making a separate branch from this one. Don't feel obligated to do that; I will eventually get around to it if you don't want to. But it seems like you never turn things down, so I guess you will probably do it anyway. :)
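
To make the question in this thread concrete, here is a minimal sketch of the two approaches side by side. It is not code from the PR: `OrderWithStruct` and `OrderWithPair` are hypothetical helper names, and the `SortStruct` and `StructComp` definitions are standalone copies for illustration. One behavioural difference worth noting is that the default `operator<` of `std::pair` also compares the index when two values are equal, whereas `StructComp` compares only the value, so the order of ties is unspecified with `std::sort`.

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Standalone copy of the helper struct from the PR, for illustration only.
template<typename ElemType>
struct SortStruct
{
  ElemType d;  // Value used as the sort key.
  int n;       // Index of the original element.
};

// Compares by value only; ties are left in unspecified order by std::sort.
template<typename ElemType>
bool StructComp(const SortStruct<ElemType>& s1,
                const SortStruct<ElemType>& s2)
{
  return s1.d < s2.d;
}

// Hypothetical helper: returns the original indices ordered by value,
// using SortStruct and an explicit comparator.
template<typename ElemType>
std::vector<int> OrderWithStruct(const std::vector<ElemType>& values)
{
  std::vector<SortStruct<ElemType>> sorted(values.size());
  for (std::size_t i = 0; i < values.size(); ++i)
  {
    sorted[i].d = values[i];
    sorted[i].n = static_cast<int>(i);
  }
  std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);

  std::vector<int> order;
  for (const SortStruct<ElemType>& s : sorted)
    order.push_back(s.n);
  return order;
}

// Hypothetical helper: same result using std::pair and its default
// lexicographic operator<, which breaks ties by index.
template<typename ElemType>
std::vector<int> OrderWithPair(const std::vector<ElemType>& values)
{
  std::vector<std::pair<ElemType, int>> sorted(values.size());
  for (std::size_t i = 0; i < values.size(); ++i)
    sorted[i] = std::make_pair(values[i], static_cast<int>(i));
  std::sort(sorted.begin(), sorted.end());

  std::vector<int> order;
  for (const std::pair<ElemType, int>& p : sorted)
    order.push_back(p.second);
  return order;
}
```

For inputs without duplicate values both helpers return the same ordering, so any speed difference would have to be measured rather than assumed, which is what the follow-up issue is meant to do.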


/**
* Comparator for sorting with SortStruct.
*/
template<typename ElemType>
static bool StructComp(const SortStruct<ElemType>& s1,
const SortStruct<ElemType>& s2)
@@ -30,28 +46,66 @@ class MinimalCoverageSweep
}

public:

//! A struct that provides the type of the sweep cost.
template<typename TreeType>
struct SweepCost
{
typedef typename TreeType::ElemType type;
};

/**
* Find a suitable partition of a non-leaf node along the provided axis.
* The method returns the cost of the split.
*
* @param axis The axis along which we are finding a partition.
* @param node The node that is being split.
* @param axisCut The coordinate at which the node may be split.
*/
template<typename TreeType>
static typename TreeType::ElemType SweepNonLeafNode(size_t axis,
const TreeType* node, typename TreeType::ElemType& axisCut);
static typename TreeType::ElemType SweepNonLeafNode(
const size_t axis,
const TreeType* node,
typename TreeType::ElemType& axisCut);

/**
* Find a suitable partition of a leaf node along the provided axis.
* The method returns the cost of the split.
*
* @param axis The axis along which we are finding a partition.
* @param node The node that is being split.
* @param axisCut The coordinate at which the node may be split.
*/
template<typename TreeType>
static typename TreeType::ElemType SweepLeafNode(size_t axis,
const TreeType* node, typename TreeType::ElemType& axisCut);
static typename TreeType::ElemType SweepLeafNode(
const size_t axis,
const TreeType* node,
typename TreeType::ElemType& axisCut);

/**
* Check if an intermediate node can be split along the axis at the provided
* coordinate.
*
* @param node The node that is being split.
* @param cutAxis The axis that we want to check.
* @param cut The coordinate that we want to check.
*/
template<typename TreeType, typename ElemType>
static bool CheckNonLeafSweep(const TreeType* node, size_t cutAxis,
ElemType cut);
static bool CheckNonLeafSweep(const TreeType* node,
const size_t cutAxis,
const ElemType cut);

/**
* Check if a leaf node can be split along the axis at the provided
* coordinate.
*
* @param node The node that is being split.
* @param cutAxis The axis that we want to check.
* @param cut The coordinate that we want to check.
*/
template<typename TreeType, typename ElemType>
static bool CheckLeafSweep(const TreeType* node, size_t cutAxis,
ElemType cut);
static bool CheckLeafSweep(const TreeType* node,
const size_t cutAxis,
const ElemType cut);
};

} // namespace tree
@@ -2,6 +2,8 @@
* @file minimal_coverage_sweep_impl.hpp
* @author Mikhail Lozhnikov
*
* Implementation of the MinimalCoverageSweep class, a class that finds a
* partition of a node along an axis.
*/
#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_IMPL_HPP
#define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_IMPL_HPP
@@ -14,8 +16,9 @@ namespace tree {
template<typename SplitPolicy>
template<typename TreeType>
typename TreeType::ElemType MinimalCoverageSweep<SplitPolicy>::
SweepNonLeafNode(size_t axis, const TreeType* node,
typename TreeType::ElemType& axisCut)
SweepNonLeafNode(const size_t axis,
const TreeType* node,
typename TreeType::ElemType& axisCut)
{
typedef typename TreeType::ElemType ElemType;

@@ -26,12 +29,14 @@ SweepNonLeafNode(size_t axis, const TreeType* node,
sorted[i].d = SplitPolicy::Bound(node->Children()[i])[axis].Hi();
sorted[i].n = i;
}
// Sort high bounds of children.
std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);

size_t splitPointer = fillFactor * node->NumChildren();

axisCut = sorted[splitPointer - 1].d;

// Check if the partition is suitable.
if (!CheckNonLeafSweep(node, axis, axisCut))
Member

Why are you doing this check first before iterating over different values of splitPointer? I guess I don't understand the reasoning. It seems like this entire block of code is just to check that there does exist a valid split on this axis, so I guess the first if is just a shortcut to avoid checking every possible cut. If by default you want to go with a midpoint split, you can hardcode the 0.5 and remove fillFactor; I agree with that. But I guess if the default split does not work, you are trying all other possible splits until you find one that works?

Contributor Author

Absolutely so, I'll add a comment in order to clarify that.
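
The strategy confirmed in this thread (try the midpoint cut first, and scan the remaining candidates only if it fails) can be sketched as follows. This is a simplified illustration, not the PR's implementation: `ChooseCut` is a hypothetical name, `sortedHi` stands in for the sorted upper bounds of the children, and `isValid` stands in for `CheckNonLeafSweep` with the node and axis already bound. The midpoint is hardcoded as `size / 2`, as suggested above, instead of using `fillFactor`.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: picks a cut coordinate from upper bounds that are
// already sorted in ascending order. Returns true and stores the chosen
// coordinate in `cut` if some candidate satisfies `isValid`. When it returns
// false, `cut` holds the last candidate tried and should be ignored.
template<typename ElemType, typename Predicate>
bool ChooseCut(const std::vector<ElemType>& sortedHi,
               Predicate isValid,
               ElemType& cut)
{
  // At least two entries are needed to form two non-empty groups.
  if (sortedHi.size() < 2)
    return false;

  // Shortcut: try the midpoint split first.
  std::size_t splitPointer = sortedHi.size() / 2;
  cut = sortedHi[splitPointer - 1];
  if (isValid(cut))
    return true;

  // Fallback: try every possible cut until one is acceptable.
  for (splitPointer = 1; splitPointer < sortedHi.size(); ++splitPointer)
  {
    cut = sortedHi[splitPointer - 1];
    if (isValid(cut))
      return true;
  }

  return false;
}
```

The fallback loop revisits the midpoint candidate once more, which costs one redundant check but keeps the loop bounds identical to the ones shown in the diff.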

{
for (splitPointer = 1; splitPointer < sorted.size(); splitPointer++)
@@ -50,6 +55,7 @@ SweepNonLeafNode(size_t axis, const TreeType* node,
std::vector<ElemType> lowerBound2(node->Bound().Dim());
std::vector<ElemType> highBound2(node->Bound().Dim());

// Find lower and high bounds of two resulting nodes.
for (size_t k = 0; k < node->Bound().Dim(); k++)
{
lowerBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Lo();
@@ -75,6 +81,9 @@ SweepNonLeafNode(size_t axis, const TreeType* node,
}
}

// Evaluate the cost of the split i.e. calculate the total coverage
// of two resulting nodes.

ElemType area1 = 1.0, area2 = 1.0;
ElemType overlappedArea = 1.0;

@@ -112,8 +121,9 @@ SweepNonLeafNode(size_t axis, const TreeType* node,
template<typename SplitPolicy>
template<typename TreeType>
typename TreeType::ElemType MinimalCoverageSweep<SplitPolicy>::
SweepLeafNode(size_t axis, const TreeType* node,
typename TreeType::ElemType& axisCut)
SweepLeafNode(const size_t axis,
const TreeType* node,
typename TreeType::ElemType& axisCut)
{
typedef typename TreeType::ElemType ElemType;

@@ -127,12 +137,14 @@ SweepLeafNode(size_t axis, const TreeType* node,
sorted[i].n = i;
}

// Sort high bounds of children.
std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);

size_t splitPointer = fillFactor * node->Count();

axisCut = sorted[splitPointer - 1].d;

// Check if the partition is suitable.
if (!CheckLeafSweep(node, axis, axisCut))
return std::numeric_limits<ElemType>::max();

@@ -141,6 +153,7 @@ SweepLeafNode(size_t axis, const TreeType* node,
std::vector<ElemType> lowerBound2(node->Bound().Dim());
std::vector<ElemType> highBound2(node->Bound().Dim());

// Find lower and high bounds of two resulting nodes.
for (size_t k = 0; k < node->Bound().Dim(); k++)
{
lowerBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k];
@@ -154,7 +167,8 @@ SweepLeafNode(size_t axis, const TreeType* node,
highBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
}

lowerBound2[k] = node->Dataset().col(node->Point(sorted[splitPointer].n))[k];
lowerBound2[k] = node->Dataset().col(
node->Point(sorted[splitPointer].n))[k];
highBound2[k] = node->Dataset().col(node->Point(sorted[splitPointer].n))[k];

for (size_t i = splitPointer + 1; i < node->NumChildren(); i++)
@@ -166,6 +180,9 @@ SweepLeafNode(size_t axis, const TreeType* node,
}
}

// Evaluate the cost of the split i.e. calculate the total coverage
// of two resulting nodes.

ElemType area1 = 1.0, area2 = 1.0;
ElemType overlappedArea = 1.0;

@@ -181,11 +198,14 @@ SweepLeafNode(size_t axis, const TreeType* node,
template<typename SplitPolicy>
template<typename TreeType, typename ElemType>
bool MinimalCoverageSweep<SplitPolicy>::
CheckNonLeafSweep(const TreeType* node, size_t cutAxis, ElemType cut)
CheckNonLeafSweep(const TreeType* node,
const size_t cutAxis,
const ElemType cut)
{
size_t numTreeOneChildren = 0;
size_t numTreeTwoChildren = 0;

// Calculate the number of children in the resulting nodes.
for (size_t i = 0; i < node->NumChildren(); i++)
{
TreeType* child = node->Children()[i];
@@ -196,6 +216,7 @@ CheckNonLeafSweep(const TreeType* node, size_t cutAxis, ElemType cut)
numTreeTwoChildren++;
else
{
// The split is required.
numTreeOneChildren++;
numTreeTwoChildren++;
}
@@ -210,11 +231,14 @@ CheckNonLeafSweep(const TreeType* node, size_t cutAxis, ElemType cut)
template<typename SplitPolicy>
template<typename TreeType, typename ElemType>
bool MinimalCoverageSweep<SplitPolicy>::
CheckLeafSweep(const TreeType* node, size_t cutAxis, ElemType cut)
CheckLeafSweep(const TreeType* node,
const size_t cutAxis,
const ElemType cut)
{
size_t numTreeOnePoints = 0;
size_t numTreeTwoPoints = 0;

// Calculate the number of points in the resulting nodes.
for (size_t i = 0; i < node->NumPoints(); i++)
{
if (node->Dataset().col(node->Point(i))[cutAxis] <= cut)
@@ -2,43 +2,80 @@
* @file minimal_splits_number_sweep.hpp
* @author Mikhail Lozhnikov
*
* Definition of the MinimalSplitsNumberSweep class, a class that finds a
* partition of a node along an axis.
*/
#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_HPP
#define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_HPP

namespace mlpack {
namespace tree {

/**
* The MinimalSplitsNumberSweep class finds a partition along which we
* can split a node according to the number of required splits of the node.
* Moreover, the class evaluates the cost of each split.
*
* @tparam SplitPolicy The class that provides rules for inserting children of
* a node that is being split into two new subtrees.
*/
template<typename SplitPolicy>
class MinimalSplitsNumberSweep
{
private:
/**
* Class to allow for faster sorting.
*/
template<typename ElemType>
struct SortStruct
{
ElemType d;
int n;
};

/**
* Comparator for sorting with SortStruct.
*/
template<typename ElemType>
static bool StructComp(const SortStruct<ElemType>& s1,
const SortStruct<ElemType>& s2)
{
return s1.d < s2.d;
}
public:
//! A struct that provides the type of the sweep cost.
template<typename>
struct SweepCost
{
typedef size_t type;
};

/**
* Find a suitable partition of a non-leaf node along the provided axis.
* The method returns the cost of the split.
*
* @param axis The axis along which we are finding a partition.
* @param node The node that is being split.
* @param axisCut The coordinate at which the node may be split.
*/
template<typename TreeType>
static size_t SweepNonLeafNode(size_t axis, const TreeType* node,
static size_t SweepNonLeafNode(
const size_t axis,
const TreeType* node,
typename TreeType::ElemType& axisCut);

/**
* Find a suitable partition of a leaf node along the provided axis.
* The method returns the cost of the split.
*
* @param axis The axis along which we are finding a partition.
* @param node The node that is being split.
* @param axisCut The coordinate at which the node may be split.
*/
template<typename TreeType>
static size_t SweepLeafNode(size_t axis, const TreeType* node,
static size_t SweepLeafNode(
const size_t axis,
const TreeType* node,
typename TreeType::ElemType& axisCut);
};
