Vantage point tree #708

Merged
merged 14 commits into from Aug 8, 2016
4 changes: 4 additions & 0 deletions src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
@@ -456,6 +456,10 @@ class BinarySpaceTree
//! Store the center of the bounding region in the given vector.
void Center(arma::vec& center) { bound.Center(center); }

//! Returns false: The first point of this node is not the centroid
//! of its bound.
static constexpr bool IsFirstPointCentroid() { return false; }

private:
/**
* Splits the current node, assigning its left and right children recursively.
4 changes: 4 additions & 0 deletions src/mlpack/core/tree/cover_tree/cover_tree.hpp
@@ -374,6 +374,10 @@ class CoverTree
//! Get the instantiated metric.
MetricType& Metric() const { return *metric; }

//! Returns true: The first point of this node is the centroid
//! of its bound.
static constexpr bool IsFirstPointCentroid() { return true; }

private:
//! Reference to the matrix which this tree is built on.
const MatType* dataset;
4 changes: 4 additions & 0 deletions src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -492,6 +492,10 @@ class RectangleTree
//! Returns false: this tree type does not have self children.
static bool HasSelfChildren() { return false; }

//! Returns false: The first point of this node is not the centroid
//! of its bound.
static constexpr bool IsFirstPointCentroid() { return false; }
Member

The idea I had here for IsFirstPointCentroid() was to keep it as part of TreeTraits, but make it a function:

template<typename TreeType>
struct TreeTraits
{
  static constexpr bool FirstPointIsCentroid(TreeType*) { return false; }
  ...
};

and then for the VP tree we can specialize and make it non-constexpr:

template<>
struct TreeTraits<VPTree>
{
  static bool FirstPointIsCentroid(VPTree* v) { ... }
};

It changes the syntax of all of the traits, but it keeps them out of the definition of each tree itself, and we are guaranteed that there is a value for these even if the person who wrote the tree did not specify them. Does that seem reasonable to you?

Contributor Author

Yes, I agree. That should simplify the TreeType API.
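
The traits-as-function idea discussed in this thread can be sketched in isolation. The snippet below is only an illustration, not mlpack code: `ToyKDTree`, `ToyVPTree`, the `numPoints` field, and `ShouldRunBaseCase()` are hypothetical names, and the rule "a node holds a vantage point when it stores exactly one point" is an assumption standing in for the body that the comment above leaves elided.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-ins for real tree types; these are not mlpack classes.
struct ToyKDTree {};

struct ToyVPTree
{
  // Number of points stored directly in this node (assumed field).
  std::size_t numPoints;
};

// Generic traits: any tree type without a specialization gets this default,
// so a value exists even if the tree author never specified one.
template<typename TreeType>
struct TreeTraits
{
  static constexpr bool FirstPointIsCentroid(const TreeType* /* node */)
  {
    return false;
  }
};

// Specialization for the toy VP tree: the answer depends on the node itself,
// so it is computed at run time and is not constexpr.
template<>
struct TreeTraits<ToyVPTree>
{
  static bool FirstPointIsCentroid(const ToyVPTree* node)
  {
    assert(node != nullptr);
    // Assumption: a node that stores exactly one point holds a vantage point.
    return node->numPoints == 1;
  }
};

// Generic caller, in the way a traverser might query the trait.
template<typename TreeType>
bool ShouldRunBaseCase(const TreeType& node)
{
  return TreeTraits<TreeType>::FirstPointIsCentroid(&node);
}
```

With this shape, generic traversal code calls the same function for every tree type, and only the VP tree pays for a run-time check.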


private:
/**
* Splits the current node, recursing up the tree.
175 changes: 102 additions & 73 deletions src/mlpack/core/tree/vantage_point_tree/dual_tree_traverser_impl.hpp
@@ -79,16 +79,23 @@ DualTreeTraverser<RuleType>::Traverse(
{
// We have to recurse down the query node. In this case the recursion order
// does not matter.
const double pointScore = rule.Score(queryNode.Point(0), referenceNode);
++numScores;

if (pointScore != DBL_MAX)
Traverse(queryNode.Point(0), referenceNode);
else
++numPrunes;

// Before recursing, we have to set the traversal information correctly.
rule.TraversalInfo() = traversalInfo;
// If the first point of the query node is the centroid, the query node
// contains a point. In this case we should run the single tree traverser.
if (queryNode.IsFirstPointCentroid())
Contributor Author

If the query node contains a point, we have to use the single-tree traverser.

{
const double pointScore = rule.Score(queryNode.Point(0), referenceNode);
++numScores;

if (pointScore != DBL_MAX)
Traverse(queryNode.Point(0), referenceNode);
else
++numPrunes;

// Before recursing, we have to set the traversal information correctly.
rule.TraversalInfo() = traversalInfo;
}

const double leftScore = rule.Score(*queryNode.Left(), referenceNode);
++numScores;

@@ -109,10 +116,15 @@ DualTreeTraverser<RuleType>::Traverse(
}
else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf()))
{
const size_t queryEnd = queryNode.Begin() + queryNode.Count();
for (size_t query = queryNode.Begin(); query < queryEnd; ++query)
rule.BaseCase(query, referenceNode.Point(0));
numBaseCases += queryNode.Count();
// If the reference node contains a point we should calculate all
// base cases with this point.
if (referenceNode.IsFirstPointCentroid())
Contributor Author

If the reference node contains a point, we have to calculate all base cases with it.

{
const size_t queryEnd = queryNode.Begin() + queryNode.Count();
for (size_t query = queryNode.Begin(); query < queryEnd; ++query)
rule.BaseCase(query, referenceNode.Point(0));
numBaseCases += queryNode.Count();
}
// We have to recurse down the reference node. In this case the recursion
// order does matter. Before recursing, though, we have to set the
// traversal information correctly.
@@ -189,69 +201,36 @@ DualTreeTraverser<RuleType>::Traverse(
}
else
{
for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
rule.BaseCase(queryNode.Descendant(i), referenceNode.Point(0));
numBaseCases += queryNode.NumDescendants();
// If the reference node contains a point we should calculate all
// base cases with this point.
if (referenceNode.IsFirstPointCentroid())
Contributor Author

I think this is inefficient. Is it possible to implement the score algorithm for a reference point and a query node?

Member

I agree that this is inefficient. I think that the better solution is to do this: during recursion, for all query points in a node, perform a base case with all parent vantage points. Something like this:

for (size_t i = 0; i < queryNode.NumPoints(); ++i)
{
  // Walk up the ancestors; the loop body is skipped if there is no parent.
  for (TreeType* p = referenceNode.Parent(); p != NULL; p = p->Parent())
  {
    if (p->NumPoints() > 0) // We are holding a vantage point.
      BaseCase(queryNode.Point(i), p->Point(0)); // p->NumPoints() should never be greater than 1.
  }
}

That would have to come after the base case section but before any recursions, so I guess up by line 50 or so would be the right place.

Member

Ah hang on, but we could end up with duplicated base cases like this. Let me think for a while about whether there is a better way to avoid duplications while preserving this type of technique. Maybe you are right that traversing with a query node and a reference point could work, but that would require a lot of refactoring of all the dual-tree algorithms, so I'd prefer to avoid that...
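
The duplication concern raised here can be reproduced with a toy model. The sketch below is not mlpack code: `ToyNode`, `BaseCaseLog`, and `BaseCasesWithAncestors()` are hypothetical names, and the tree is reduced to parent links plus an optional vantage point. It walks the ancestors of a reference node and records each (query, reference) pair, so a pair evaluated twice shows up as a duplicate.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <utility>
#include <vector>

// Hypothetical toy node; not an mlpack type.
struct ToyNode
{
  ToyNode* parent;            // nullptr for the root.
  bool holdsVantagePoint;     // True if this node stores a vantage point.
  std::size_t vantagePoint;   // Point index; meaningful only if the flag is set.
};

// Records every (query, reference) pair and counts repeated pairs.
struct BaseCaseLog
{
  std::set<std::pair<std::size_t, std::size_t>> seen;
  std::size_t calls = 0;
  std::size_t duplicates = 0;

  void BaseCase(std::size_t query, std::size_t reference)
  {
    ++calls;
    if (!seen.insert(std::make_pair(query, reference)).second)
      ++duplicates;
  }
};

// For each query point, run a base case against every vantage point held by
// an ancestor of referenceNode. The loop is skipped when there is no parent.
void BaseCasesWithAncestors(const std::vector<std::size_t>& queryPoints,
                            const ToyNode& referenceNode,
                            BaseCaseLog& log)
{
  for (std::size_t query : queryPoints)
    for (const ToyNode* p = referenceNode.parent; p != nullptr; p = p->parent)
      if (p->holdsVantagePoint)
        log.BaseCase(query, p->vantagePoint);
}
```

In this model, calling `BaseCasesWithAncestors()` for two sibling reference nodes with the same query points visits their shared parent twice, so every pair from the second call is counted as a duplicate. A real traversal would need some way to remember which ancestors were already handled.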

{
for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
rule.BaseCase(queryNode.Descendant(i), referenceNode.Point(0));
numBaseCases += queryNode.NumDescendants();
}
// We have to recurse down both query and reference nodes. Because the
// query descent order does not matter, we will go to the left query child
// first. Before recursing, we have to set the traversal information
// correctly.
double leftScore = rule.Score(queryNode.Point(0), *referenceNode.Left());
typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
rule.TraversalInfo() = traversalInfo;
double rightScore = rule.Score(queryNode.Point(0), *referenceNode.Right());
typename RuleType::TraversalInfoType rightInfo;
numScores += 2;

if (leftScore < rightScore)
{
// Recurse to the left. Restore the left traversal info. Store the right
// traversal info.
rightInfo = rule.TraversalInfo();
rule.TraversalInfo() = leftInfo;
Traverse(queryNode.Point(0), *referenceNode.Left());

// Is it still valid to recurse to the right?
rightScore = rule.Rescore(queryNode.Point(0), *referenceNode.Right(),
rightScore);
double leftScore;
typename RuleType::TraversalInfoType leftInfo;
double rightScore;
typename RuleType::TraversalInfoType rightInfo;

if (rightScore != DBL_MAX)
{
// Restore the right traversal info.
rule.TraversalInfo() = rightInfo;
Traverse(queryNode.Point(0), *referenceNode.Right());
}
else
++numPrunes;
}
else if (rightScore < leftScore)
if (queryNode.IsFirstPointCentroid())
Contributor Author

Use the single-tree traverser if the query node contains a point.

{
// Recurse to the right.
Traverse(queryNode.Point(0), *referenceNode.Right());

// Is it still valid to recurse to the left?
leftScore = rule.Rescore(queryNode.Point(0), *referenceNode.Left(),
leftScore);
leftScore = rule.Score(queryNode.Point(0), *referenceNode.Left());
leftInfo = rule.TraversalInfo();
rule.TraversalInfo() = traversalInfo;
rightScore = rule.Score(queryNode.Point(0), *referenceNode.Right());
numScores += 2;

if (leftScore != DBL_MAX)
if (leftScore < rightScore)
{
// Restore the left traversal info.
rule.TraversalInfo() = leftInfo;
Traverse(queryNode.Point(0), *referenceNode.Left());
}
else
++numPrunes;
}
else
{
if (leftScore == DBL_MAX)
{
numPrunes += 2;
}
else
{
// Choose the left first. Restore the left traversal info and store the
// right traversal info.
// Recurse to the left. Restore the left traversal info. Store the right
// traversal info.
rightInfo = rule.TraversalInfo();
rule.TraversalInfo() = leftInfo;
Traverse(queryNode.Point(0), *referenceNode.Left());
@@ -262,17 +241,63 @@ DualTreeTraverser<RuleType>::Traverse(

if (rightScore != DBL_MAX)
{
// Restore the right traversal information.
// Restore the right traversal info.
rule.TraversalInfo() = rightInfo;
Traverse(queryNode.Point(0), *referenceNode.Right());
}
else
++numPrunes;
}
}
else if (rightScore < leftScore)
{
// Recurse to the right.
Traverse(queryNode.Point(0), *referenceNode.Right());

// Restore the main traversal information.
rule.TraversalInfo() = traversalInfo;
// Is it still valid to recurse to the left?
leftScore = rule.Rescore(queryNode.Point(0), *referenceNode.Left(),
leftScore);

if (leftScore != DBL_MAX)
{
// Restore the left traversal info.
rule.TraversalInfo() = leftInfo;
Traverse(queryNode.Point(0), *referenceNode.Left());
}
else
++numPrunes;
}
else
{
if (leftScore == DBL_MAX)
{
numPrunes += 2;
}
else
{
// Choose the left first. Restore the left traversal info and store the
// right traversal info.
rightInfo = rule.TraversalInfo();
rule.TraversalInfo() = leftInfo;
Traverse(queryNode.Point(0), *referenceNode.Left());

// Is it still valid to recurse to the right?
rightScore = rule.Rescore(queryNode.Point(0), *referenceNode.Right(),
rightScore);

if (rightScore != DBL_MAX)
{
// Restore the right traversal information.
rule.TraversalInfo() = rightInfo;
Traverse(queryNode.Point(0), *referenceNode.Right());
}
else
++numPrunes;
}
}

// Restore the main traversal information.
rule.TraversalInfo() = traversalInfo;
}

// Now recurse down the left node.
leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left());
@@ -452,8 +477,12 @@ DualTreeTraverser<RuleType>::Traverse(
return;
}

rule.BaseCase(queryIndex, referenceNode.Point(0));
numBaseCases++;
// If the reference node contains a point we should calculate the base case.
if (referenceNode.IsFirstPointCentroid())
{
rule.BaseCase(queryIndex, referenceNode.Point(0));
numBaseCases++;
}

// Store the current traversal info.
traversalInfo = rule.TraversalInfo();
@@ -51,7 +51,9 @@ SingleTreeTraverser<RuleType>::Traverse(
return;
}

rule.BaseCase(queryIndex, referenceNode.Point(0));
// If the reference node contains a point we should calculate the base case.
if (referenceNode.IsFirstPointCentroid())
Contributor Author

If the reference node contains a point, we have to calculate the base case.

rule.BaseCase(queryIndex, referenceNode.Point(0));

// If either score is DBL_MAX, we do not recurse into that node.
double leftScore = rule.Score(queryIndex, *referenceNode.Left());
@@ -29,6 +29,7 @@ SplitNode(const BoundType& bound, MatType& data, const size_t begin,
if (mu == 0)
return false;

// The first point of the left child is centroid.
data.swap_cols(begin, vantagePointIndex);

arma::Col<ElemType> vantagePoint = data.col(begin);
@@ -54,6 +55,7 @@ SplitNode(const BoundType& bound, MatType& data, const size_t begin,
if (mu == 0)
return false;

// The first point of the left child is centroid.
data.swap_cols(begin, vantagePointIndex);
size_t t = oldFromNew[begin];
oldFromNew[begin] = oldFromNew[vantagePointIndex];
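
The two SplitNode hunks above move the vantage point to the first position of the node and, in the second overload, mirror that swap in `oldFromNew`. A minimal sketch of the same bookkeeping follows, with `std::vector` standing in for the Armadillo matrix columns. `MoveVantagePointToFront()` is a hypothetical helper name, not a function from the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Move the chosen vantage point to position `begin` and apply the same swap
// to the new-to-old index mapping so the two stay consistent.
// `points` stands in for the columns of the dataset (one value per point).
void MoveVantagePointToFront(std::vector<double>& points,
                             std::vector<std::size_t>& oldFromNew,
                             const std::size_t begin,
                             const std::size_t vantagePointIndex)
{
  assert(points.size() == oldFromNew.size());
  assert(begin < points.size() && vantagePointIndex < points.size());

  // Equivalent in spirit to data.swap_cols(begin, vantagePointIndex).
  std::swap(points[begin], points[vantagePointIndex]);

  // Keep the mapping in step with the data, as the second hunk does by hand.
  std::swap(oldFromNew[begin], oldFromNew[vantagePointIndex]);
}
```

After the call, `oldFromNew[begin]` gives the original index of the vantage point, so results computed on the reordered data can still be mapped back.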