Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spill trees #747

Merged
merged 84 commits into from
Aug 18, 2016
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
84 commits
Select commit Hold shift + click to select a range
8adae19
Add Spill Trees, based on binary space trees.
MarcosPividori Jul 4, 2016
634e271
Fix details.
MarcosPividori Jul 6, 2016
8b2dc10
Improve expansion of node's bound. Do not include the overlapping buf…
MarcosPividori Jul 7, 2016
7aaf969
Fix simple errors.
MarcosPividori Jul 8, 2016
8d5994c
Add support for Hybrid SP-Tree Search (Single tree traverser).
MarcosPividori Jul 12, 2016
2e67697
Add support for spill trees in knn search.
MarcosPividori Jul 12, 2016
c23f0b5
Syntax details.
MarcosPividori Jul 13, 2016
91334cb
Fix error in the order of parameters.
MarcosPividori Jul 13, 2016
061ee92
Add tests for spill tree.
MarcosPividori Jul 16, 2016
7bd4416
Add approximate knn search tests for hybrid spill trees.
MarcosPividori Jul 16, 2016
946eddf
Add exact knn search tests for hybrid spill trees.
MarcosPividori Jul 16, 2016
c9f6cd3
Include missing headers.
MarcosPividori Jul 16, 2016
e33a112
Use a simple tree traverser for spill tree (Remove defeatist seach in…
MarcosPividori Jul 18, 2016
e1c01ca
Implement defeatist search in the Rescore() method, with a specializa…
MarcosPividori Jul 18, 2016
723bec5
Remove unnecessary BoundType template parameter.
MarcosPividori Jul 27, 2016
2af816a
Record splitDimension and splitValue.
MarcosPividori Jul 27, 2016
5266096
Set Overlapping node when percentage greater than rho but not because of
MarcosPividori Jul 27, 2016
a599bf8
Add NeighborSearchRules specialization for Spill Trees.
MarcosPividori Jul 27, 2016
b4d5a1d
Add dual tree traverser for Spill Trees.
MarcosPividori Jul 27, 2016
9551cc6
Set non-overlapping for spill query tree.
MarcosPividori Jul 27, 2016
9c1ef2a
Avoid calculating bounds when oldScore is the best possible.
MarcosPividori Jul 27, 2016
498e74d
Remove unnecessary copying of dataset in SpillTrees.
MarcosPividori Jul 28, 2016
e95e3b4
Create a new class SpillSearch that encapsulates an instance of Neigh…
MarcosPividori Jul 28, 2016
0120357
Update NSModel to use SpillSearch class.
MarcosPividori Jul 28, 2016
5d68382
Log Num of BaseCases/Score.
MarcosPividori Jul 29, 2016
bff4eca
Add more tests for SpillSearch
MarcosPividori Jul 29, 2016
c86b32c
Update documentation.
MarcosPividori Aug 1, 2016
ce34be6
Remove B_2 bound for neighbor search with spill trees.
MarcosPividori Aug 1, 2016
ea37f0d
Change spill rules's filename.
MarcosPividori Aug 1, 2016
f778e8e
Fix SpillTree's move constructor
MarcosPividori Aug 1, 2016
a4ac0b2
Include overlapping points in each child's bounding box.
MarcosPividori Aug 1, 2016
9588a6d
Generalize Spill Trees, to consider general splitting hyperplanes, no…
MarcosPividori Aug 4, 2016
4a66098
Remove unnecessary code.
MarcosPividori Aug 4, 2016
b9e23d4
Update SpillSearch to the general definition of Spill Trees.
MarcosPividori Aug 4, 2016
a71b57c
Fix serialization.
MarcosPividori Aug 4, 2016
54614ba
Update documentation.
MarcosPividori Aug 5, 2016
388941a
Use arma::Col instead of std::vector. Also, avoid resizing in SplitPo…
MarcosPividori Aug 6, 2016
72274e7
Properly consider all points to the left when hyperplane was not init…
MarcosPividori Aug 6, 2016
fdc8e9c
Avoid calculating the score of both child nodes when not necessary, w…
MarcosPividori Aug 8, 2016
8f32156
Split implementation of different split methods in separated files, a…
MarcosPividori Aug 8, 2016
02d4a6d
Add a new test to check the splitting of points in Spill Trees.
MarcosPividori Aug 10, 2016
e81731d
Always normalise a projection vector.
MarcosPividori Aug 11, 2016
17d2c2f
Add Tests for Hyperplane class.
MarcosPividori Aug 11, 2016
d8b16a0
Fix file order.
MarcosPividori Aug 11, 2016
3b15fdf
Add new element to TreeTraits, to know if NumDescendants() includes d…
MarcosPividori Aug 11, 2016
fa81be4
Add rho and leafSize members on SpillSearch, and a command line param…
MarcosPividori Aug 11, 2016
9a8a9eb
Add command line parameter to calculate the effective error.
MarcosPividori Aug 11, 2016
f33a19d
Shouldn't modify parameters of trained model.
MarcosPividori Aug 11, 2016
d22e65a
Merge branch 'master' into spill-trees
MarcosPividori Aug 11, 2016
69804c6
Properly count the number of cases when calculating the effective err…
MarcosPividori Aug 11, 2016
180fc11
Merge branch 'master' into spill-trees
MarcosPividori Aug 16, 2016
c42e40b
This should make appVeyor succeed.
MarcosPividori Aug 16, 2016
7e5e054
Remove old include.
MarcosPividori Aug 16, 2016
d3ab34a
Add compiler option for MSVC.
MarcosPividori Aug 16, 2016
e06cfc8
Many syntax details.
MarcosPividori Aug 16, 2016
722d211
Fix comment.
MarcosPividori Aug 16, 2016
89725f4
Syntax details.
MarcosPividori Aug 16, 2016
0757955
Use linspace to fill a Col vector.
MarcosPividori Aug 16, 2016
dce67ec
Simplify the code.
MarcosPividori Aug 16, 2016
9215515
Add EffectiveError() to NeighborSearch class, and a proper command li…
MarcosPividori Aug 16, 2016
3e79d27
Add Recall() to NeighborSearch class, and a proper command line optio…
MarcosPividori Aug 16, 2016
99de521
Emphasize command line comments.
MarcosPividori Aug 16, 2016
ac5f836
Simple syntax detail.
MarcosPividori Aug 16, 2016
207c1d4
Split test in 2 tests.
MarcosPividori Aug 16, 2016
471bb24
Simplify knn test.
MarcosPividori Aug 16, 2016
f1253d3
Add tests for Move and Copy constructors of SpilTrees.
MarcosPividori Aug 17, 2016
f989164
Details in spilltree and knn tests.
MarcosPividori Aug 17, 2016
02fb618
Split line.
MarcosPividori Aug 17, 2016
0267acb
Properly define the Tree Traversers for Spill trees.
MarcosPividori Aug 17, 2016
70fbeab
Add a new template parameter to NeighborSearch class, to set a specif…
MarcosPividori Aug 17, 2016
b334674
Set defeatist traverser for spill search class.
MarcosPividori Aug 17, 2016
e4ce9be
Remove specialization of NeighborSearchRules for SpillTrees.
MarcosPividori Aug 17, 2016
2bad753
Avoid B_2 bound for Spill Trees.
MarcosPividori Aug 17, 2016
b55534d
Properly consider ElemType parameter for bounds.
MarcosPividori Aug 17, 2016
2698225
Improve reverse compatibility of NSModel.
MarcosPividori Aug 17, 2016
e4cf1fd
Avoid calculating the bounds when oldScore is the best possible.
MarcosPividori Aug 17, 2016
ecaf1bf
Properly update search model.
MarcosPividori Aug 17, 2016
c5fe485
Add Warning when true_neighbors_file and true_distances_file are prov…
MarcosPividori Aug 17, 2016
5795c74
Remove SpillSearch class. Use NeighborSearch class and typedefs for d…
MarcosPividori Aug 18, 2016
ae84513
Update template name.
MarcosPividori Aug 18, 2016
b1530ea
Update tests to consider SpillKNN.
MarcosPividori Aug 18, 2016
da35f94
Remove outdated information.
MarcosPividori Aug 18, 2016
8a5ec21
Doesn't copy the dataset in the copy constructor when it is not neces…
MarcosPividori Aug 18, 2016
815391b
Fix Copy constructor. Properly set pointer to dataset.
MarcosPividori Aug 18, 2016
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/mlpack/methods/neighbor_search/knn_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, R "
"trees, R* trees, X trees, Hilbert R trees, R+ trees, R++ trees, and Spill "
"trees).", "l", 20);
PARAM_DOUBLE_IN("tau", "Overlapping size (for spill trees).", "u", 0);
PARAM_DOUBLE_IN("rho", "Balance threshold (for spill trees).", "b", 0.7);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To emphasize where these are valid, I might consider writing only valid for spill trees instead of just for spill trees.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. Done in: 99de521


PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
"random orthogonal basis.", "R");
Expand Down Expand Up @@ -117,6 +118,9 @@ int main(int argc, char *argv[])
if (CLI::HasParam("tau"))
Log::Warn << "--tau (-u) will be ignored because --input_model_file"
<< " is specified." << endl;
if (CLI::HasParam("rho"))
Log::Warn << "--rho (-b) will be ignored because --input_model_file"
<< " is specified." << endl;
if (CLI::HasParam("random_basis"))
Log::Warn << "--random_basis (-R) will be ignored because "
<< "--input_model_file is specified." << endl;
Expand Down Expand Up @@ -157,6 +161,14 @@ int main(int argc, char *argv[])
if (CLI::HasParam("tau") && "spill" != CLI::GetParam<string>("tree_type"))
Log::Fatal << "Tau parameter is only valid for spill trees." << endl;

// Sanity check on rho.
const double rho = CLI::GetParam<double>("rho");
if (rho < 0 || rho > 1)
Log::Fatal << "Invalid rho: " << rho << ". Must be in the range [0,1]. "
<< endl;
if (CLI::HasParam("rho") && "spill" != CLI::GetParam<string>("tree_type"))
Log::Fatal << "Rho parameter is only valid for spill trees." << endl;

// Sanity check on epsilon.
const double epsilon = CLI::GetParam<double>("epsilon");
if (epsilon < 0)
Expand Down Expand Up @@ -204,6 +216,7 @@ int main(int argc, char *argv[])
knn.RandomBasis() = randomBasis;
knn.LeafSize() = size_t(lsInt);
knn.Tau() = tau;
knn.Rho() = rho;

arma::mat referenceSet;
data::Load(referenceFile, referenceSet, true);
Expand Down Expand Up @@ -231,6 +244,7 @@ int main(int argc, char *argv[])
knn.LeafSize() = size_t(lsInt);
knn.Epsilon() = epsilon;
knn.Tau() = tau;
knn.Rho() = rho;
}

// Perform search, if desired.
Expand Down
18 changes: 15 additions & 3 deletions src/mlpack/methods/neighbor_search/ns_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ class BiSearchVisitor : public boost::static_visitor<void>
const size_t leafSize;
//! Overlapping size (for spill trees).
const double tau;
//! Balance threshold (for spill trees).
const double rho;

//! Bichromatic neighbor search on the given NSType considering the leafSize.
template<typename NSType>
Expand Down Expand Up @@ -143,7 +145,8 @@ class BiSearchVisitor : public boost::static_visitor<void>
arma::Mat<size_t>& neighbors,
arma::mat& distances,
const size_t leafSize,
const double tau);
const double tau,
const double rho);
};

/**
Expand All @@ -162,6 +165,8 @@ class TrainVisitor : public boost::static_visitor<void>
size_t leafSize;
//! Overlapping size (for spill trees).
const double tau;
//! Balance threshold (for spill trees).
const double rho;

//! Train on the given NSType considering the leafSize.
template<typename NSType>
Expand Down Expand Up @@ -190,10 +195,11 @@ class TrainVisitor : public boost::static_visitor<void>
void operator()(NSSpillType* ns) const;

//! Construct the TrainVisitor object with the given reference set, leafSize
//! for BinarySpaceTrees, and tau for spill trees.
//! for BinarySpaceTrees, and tau and rho for spill trees.
TrainVisitor(arma::mat&& referenceSet,
const size_t leafSize,
const double tau);
const double tau,
const double rho);
};

/**
Expand Down Expand Up @@ -289,6 +295,8 @@ class NSModel

//! Overlapping size (for spill trees).
double tau;
//! Balance threshold (for spill trees).
double rho;

//! If true, random projections are used.
bool randomBasis;
Expand Down Expand Up @@ -348,6 +356,10 @@ class NSModel
double Tau() const { return tau; }
double& Tau() { return tau; }

//! Expose rho.
double Rho() const { return rho; }
double& Rho() { return rho; }

//! Expose treeType.
TreeTypes TreeType() const { return treeType; }
TreeTypes& TreeType() { return treeType; }
Expand Down
24 changes: 15 additions & 9 deletions src/mlpack/methods/neighbor_search/ns_model_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ BiSearchVisitor<SortPolicy>::BiSearchVisitor(const arma::mat& querySet,
arma::Mat<size_t>& neighbors,
arma::mat& distances,
const size_t leafSize,
const double tau) :
const double tau,
const double rho) :
querySet(querySet),
k(k),
neighbors(neighbors),
distances(distances),
leafSize(leafSize),
tau(tau)
tau(tau),
rho(rho)
{}

//! Default Bichromatic neighbor search on the given NSType instance.
Expand Down Expand Up @@ -84,7 +86,7 @@ void BiSearchVisitor<SortPolicy>::operator()(NSSpillType* ns) const
// For Dual Tree Search on SpillTrees, the queryTree must be built with
// non overlapping (tau = 0).
typename NSSpillType::Tree queryTree(std::move(querySet), 0 /* tau*/,
leafSize);
leafSize, rho);
ns->Search(&queryTree, k, neighbors, distances);
}
else
Expand Down Expand Up @@ -126,10 +128,12 @@ void BiSearchVisitor<SortPolicy>::SearchLeaf(NSType* ns) const
template<typename SortPolicy>
TrainVisitor<SortPolicy>::TrainVisitor(arma::mat&& referenceSet,
const size_t leafSize,
const double tau) :
const double tau,
const double rho) :
referenceSet(std::move(referenceSet)),
leafSize(leafSize),
tau(tau)
tau(tau),
rho(rho)
{}

//! Default Train on the given NSType instance.
Expand Down Expand Up @@ -173,7 +177,7 @@ void TrainVisitor<SortPolicy>::operator ()(NSSpillType* ns) const
else
{
typename NSSpillType::Tree* tree = new typename NSSpillType::Tree(
std::move(referenceSet), tau, leafSize);
std::move(referenceSet), tau, leafSize, rho);
ns->Train(tree);
// Give the model ownership of the tree.
ns->neighborSearch.treeOwner = true;
Expand Down Expand Up @@ -257,6 +261,7 @@ NSModel<SortPolicy>::NSModel(TreeTypes treeType, bool randomBasis) :
treeType(treeType),
leafSize(20),
tau(0),
rho(0.7),
randomBasis(randomBasis)
{
// Nothing to do.
Expand Down Expand Up @@ -307,6 +312,7 @@ void NSModel<SortPolicy>::Serialize(Archive& ar,
ar & data::CreateNVP(treeType, "treeType");
ar & data::CreateNVP(leafSize, "leafSize");
ar & data::CreateNVP(tau, "tau");
ar & data::CreateNVP(rho, "rho");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hate to always be talking about reverse compatibility, but this is reverse-incompatible, since tau, rho, and leafSize aren't part of the previous saved models. So the thing to do is to uncomment the version parameter, and use version == 0 to assume that leafSize, tau, and rho aren't available (and set them to defaults), and version == 1 to assume that leafSize, tau, and rho are available.

You'll need to then set the serialization version. For LSHSearch, that's done like this:


//! Set the serialization version of the LSHSearch class.
BOOST_TEMPLATE_CLASS_VERSION(template<typename SortPolicy>,
    mlpack::neighbor::LSHSearch<SortPolicy>, 1);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. Done in: 2698225

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. I noticed that the previous revision did not serialize leafSize, so this fixes a bug too!

ar & data::CreateNVP(randomBasis, "randomBasis");
ar & data::CreateNVP(q, "q");

Expand Down Expand Up @@ -454,11 +460,11 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
epsilon);
break;
case SPILL_TREE:
nSearch = new NSSpillType(naive, singleMode, tau, epsilon);
nSearch = new NSSpillType(naive, singleMode, tau, leafSize, rho, epsilon);
break;
}

TrainVisitor<SortPolicy> tn(std::move(referenceSet), leafSize, tau);
TrainVisitor<SortPolicy> tn(std::move(referenceSet), leafSize, tau, rho);
boost::apply_visitor(tn, nSearch);

if (!naive)
Expand Down Expand Up @@ -491,7 +497,7 @@ void NSModel<SortPolicy>::Search(arma::mat&& querySet,
<< std::endl;

BiSearchVisitor<SortPolicy> search(querySet, k, neighbors, distances,
leafSize, tau);
leafSize, tau, rho);
boost::apply_visitor(search, nSearch);
}

Expand Down
28 changes: 28 additions & 0 deletions src/mlpack/methods/neighbor_search/spill_search.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,17 @@ class SpillSearch
* @param singleMode If true, single-tree search will be used (as opposed to
* dual-tree search).
* @param tau Overlapping size (non-negative).
* @param leafSize Max size of each leaf in the tree.
* @param rho Balance threshold (non-negative).
* @param epsilon Relative approximate error (non-negative).
* @param metric An optional instance of the MetricType class.
*/
SpillSearch(const MatType& referenceSet,
const bool naive = false,
const bool singleMode = false,
const double tau = 0,
const double leafSize = 20,
const double rho = 0.7,
const double epsilon = 0,
const MetricType metric = MetricType());

Expand All @@ -91,13 +95,17 @@ class SpillSearch
* @param singleMode If true, single-tree search will be used (as opposed to
* dual-tree search).
* @param tau Overlapping size (non-negative).
* @param leafSize Max size of each leaf in the tree.
* @param rho Balance threshold (non-negative).
* @param epsilon Relative approximate error (non-negative).
* @param metric An optional instance of the MetricType class.
*/
SpillSearch(MatType&& referenceSet,
const bool naive = false,
const bool singleMode = false,
const double tau = 0,
const double leafSize = 20,
const double rho = 0.7,
const double epsilon = 0,
const MetricType metric = MetricType());

Expand All @@ -113,12 +121,16 @@ class SpillSearch
* @param singleMode Whether single-tree computation should be used (as
* opposed to dual-tree computation).
* @param tau Overlapping size (non-negative).
* @param leafSize Max size of each leaf in the tree.
* @param rho Balance threshold (non-negative).
* @param epsilon Relative approximate error (non-negative).
* @param metric Instantiated distance metric.
*/
SpillSearch(Tree* referenceTree,
const bool singleMode = false,
const double tau = 0,
const double leafSize = 20,
const double rho = 0.7,
const double epsilon = 0,
const MetricType metric = MetricType());

Expand All @@ -131,12 +143,16 @@ class SpillSearch
* @param singleMode Whether single-tree computation should be used (as
* opposed to dual-tree computation).
* @param tau Overlapping size (non-negative).
* @param leafSize Max size of each leaf in the tree.
* @param rho Balance threshold (non-negative).
* @param epsilon Relative approximate error (non-negative).
* @param metric Instantiated metric.
*/
SpillSearch(const bool naive = false,
const bool singleMode = false,
const double tau = 0,
const double leafSize = 20,
const double rho = 0.7,
const double epsilon = 0,
const MetricType metric = MetricType());

Expand Down Expand Up @@ -262,6 +278,12 @@ class SpillSearch
//! Access the overlapping size.
double Tau() const { return tau; }

//! Access the balance threshold.
double Rho() const { return rho; }

//! Access the leaf size.
double LeafSize() const { return leafSize; }

//! Access the reference dataset.
const MatType& ReferenceSet() const { return neighborSearch.ReferenceSet(); }

Expand All @@ -277,6 +299,12 @@ class SpillSearch
//! Overlapping size.
double tau;

//! Balance threshold.
double rho;

//! Max leaf size.
double leafSize;

//! The NSModel class should have access to internal members.
template<typename SortPolicy>
friend class TrainVisitor;
Expand Down
Loading