From 86a9852f19daf58e4ae5a3d3ca745deee43d8d16 Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Mon, 13 Jun 2016 10:31:58 -0300 Subject: [PATCH 1/6] Modify NSModel to use boost variant. --- .../neighbor_search/neighbor_search.hpp | 5 +- .../methods/neighbor_search/ns_model.hpp | 153 +++++- .../methods/neighbor_search/ns_model_impl.hpp | 506 ++++++++---------- 3 files changed, 345 insertions(+), 319 deletions(-) diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp index e84b74cc9ec..bf62ae7ee65 100644 --- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp +++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp @@ -27,8 +27,7 @@ namespace neighbor /** Neighbor-search routines. These include * searches. */ { // Forward declaration. -template -class NSModel; +class TrainVisitor; /** * The NeighborSearch class is a template class for performing distance-based @@ -308,7 +307,7 @@ class NeighborSearch bool treeNeedsReset; //! The NSModel class should have access to internal members. - friend class NSModel; + friend class TrainVisitor; }; // class NeighborSearch } // namespace neighbor diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp index 9c16199aabb..458bf97318a 100644 --- a/src/mlpack/methods/neighbor_search/ns_model.hpp +++ b/src/mlpack/methods/neighbor_search/ns_model.hpp @@ -13,12 +13,24 @@ #include #include #include - +#include #include "neighbor_search.hpp" namespace mlpack { namespace neighbor { +template class TreeType> +using NSType = NeighborSearch, + arma::mat>::template DualTreeTraverser>; + template struct NSModelName { @@ -37,6 +49,121 @@ struct NSModelName static const std::string Name() { return "furthest_neighbor_search_model"; } }; +class SearchKVisitor : public boost::static_visitor +{ + private: + const size_t k; + arma::Mat& neighbors; + arma::mat& distances; + + public: + template + void operator()(NSType *ns) const; + + SearchKVisitor(const size_t k, + arma::Mat& neighbors, + arma::mat& distances); +}; + +class SearchVisitor : public boost::static_visitor +{ + private: + const arma::mat& querySet; + const size_t k; + arma::Mat& neighbors; + arma::mat& distances; + const size_t leafSize; + + template + void SearchLeaf(NSType *ns) const; + + public: + template class TreeType> + void operator()(NSType *ns) const; + + template + void operator()(NSType *ns) const; + + template + void operator()(NSType *ns) const; + + SearchVisitor(const arma::mat& querySet, + const size_t k, + arma::Mat& neighbors, + arma::mat& distances, + const size_t leafSize); +}; + +class TrainVisitor : public boost::static_visitor +{ + private: + arma::mat&& referenceSet; + size_t leafSize; + + template + void TrainLeaf(NSType* ns) const; + + public: + template class TreeType> + void operator()(NSType *ns) const; + + template + void operator()(NSType *ns) const; + + template + void operator()(NSType *ns) const; + + TrainVisitor(arma::mat&& referenceSet, const size_t leafSize); +}; + +class SingleModeVisitor : public boost::static_visitor +{ + public: + template + bool& operator()(NSType *ns) const; +}; + +class NaiveVisitor : public boost::static_visitor +{ + public: + template + bool& operator()(NSType *ns) const; +}; + +class ReferenceSetVisitor : public boost::static_visitor +{ + public: + template + const arma::mat& operator()(NSType *ns) const; +}; + +class DeleteVisitor : public boost::static_visitor +{ + public: + template + void operator()(NSType *ns) const; +}; + +template +class SerializeVisitor : public boost::static_visitor +{ + private: + Archive& ar; + const std::string& name; + + public: + template + void operator()(NSType *ns) const; + + SerializeVisitor(Archive& ar, const std::string& name); +}; + template class NSModel { @@ -59,24 +186,12 @@ class NSModel bool randomBasis; arma::mat q; - template class TreeType> - using NSType = NeighborSearch, - arma::mat>::template DualTreeTraverser>; - - // Only one of these pointers will be non-NULL. - NSType* kdTreeNS; - NSType* coverTreeNS; - NSType* rTreeNS; - NSType* rStarTreeNS; - NSType* ballTreeNS; - NSType* xTreeNS; + boost::variant*, + NSType*, + NSType*, + NSType*, + NSType*, + NSType*> nSearch; public: /** diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp index e4aa7c179fe..ea2206eb961 100644 --- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp +++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp @@ -16,6 +16,190 @@ namespace mlpack { namespace neighbor { +SearchKVisitor::SearchKVisitor(const size_t k, + arma::Mat& neighbors, + arma::mat& distances) : + k(k), + neighbors(neighbors), + distances(distances) +{} + +template +void SearchKVisitor::operator()(NSType *ns) const +{ + if (ns) + return ns->Search(k, neighbors, distances); + throw std::runtime_error("no neighbor search model initialized"); +} + + +SearchVisitor::SearchVisitor(const arma::mat& querySet, + const size_t k, + arma::Mat& neighbors, + arma::mat& distances, + const size_t leafSize) : + querySet(querySet), + k(k), + neighbors(neighbors), + distances(distances), + leafSize(leafSize) +{} + +template class TreeType> +void SearchVisitor::operator()(NSType *ns) const +{ + if (ns) + return ns->Search(querySet, k, neighbors, distances); + throw std::runtime_error("no neighbor search model initialized"); +} + +template +void SearchVisitor::operator()(NSType *ns) const +{ + if (ns) + return SearchLeaf(ns); + throw std::runtime_error("no neighbor search model initialized"); +} + +template +void SearchVisitor::operator()(NSType *ns) const +{ + if (ns) + return SearchLeaf(ns); + throw std::runtime_error("no neighbor search model initialized"); +} + +template +void SearchVisitor::SearchLeaf(NSType *ns) const +{ + if (!ns->Naive() && !ns->SingleMode()) + { + std::vector oldFromNewQueries; + typename NSType::Tree queryTree(std::move(querySet), oldFromNewQueries, + leafSize); + + arma::Mat neighborsOut; + arma::mat distancesOut; + ns->Search(&queryTree, k, neighborsOut, distancesOut); + + // Unmap the query points. + distances.set_size(distancesOut.n_rows, distancesOut.n_cols); + neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols); + for (size_t i = 0; i < neighborsOut.n_cols; ++i) + { + neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i); + distances.col(oldFromNewQueries[i]) = distancesOut.col(i); + } + } + else + ns->Search(querySet, k, neighbors, distances); +} + + +TrainVisitor::TrainVisitor(arma::mat&& referenceSet, const size_t leafSize) : + referenceSet(std::move(referenceSet)), + leafSize(leafSize) +{} + +template class TreeType> +void TrainVisitor::operator()(NSType *ns) const +{ + if (ns) + return ns->Train(std::move(referenceSet)); + throw std::runtime_error("no neighbor search model initialized"); +} + +template +void TrainVisitor::operator ()(NSType *ns) const +{ + if (ns) + return TrainLeaf(ns); + throw std::runtime_error("no neighbor search model initialized"); +} + +template +void TrainVisitor::operator ()(NSType *ns) const +{ + if (ns) + return TrainLeaf(ns); + throw std::runtime_error("no neighbor search model initialized"); +} + +template +void TrainVisitor::TrainLeaf(NSType* ns) const +{ + if (ns->Naive()) + ns->Train(std::move(referenceSet)); + else + { + std::vector oldFromNewReferences; + typename NSType::Tree* tree = + new typename NSType::Tree(std::move(referenceSet), + oldFromNewReferences, leafSize); + ns->Train(tree); + + // Give the model ownership of the tree and the mappings. + ns->treeOwner = true; + ns->oldFromNewReferences = std::move(oldFromNewReferences); + } +} + + +template +bool& SingleModeVisitor::operator()(NSType *ns) const +{ + if (ns) + return ns->SingleMode(); + throw std::runtime_error("no neighbor search model initialized"); +} + + +template +bool& NaiveVisitor::operator()(NSType *ns) const +{ + if (ns) + return ns->Naive(); + throw std::runtime_error("no neighbor search model initialized"); +} + + +template +const arma::mat& ReferenceSetVisitor::operator()(NSType *ns) const +{ + if (ns) + return ns->ReferenceSet(); + throw std::runtime_error("no neighbor search model initialized"); +} + + +template +void DeleteVisitor::operator()(NSType *ns) const +{ + if (ns) + delete ns; +} + + +template +SerializeVisitor::SerializeVisitor(Archive& ar, + const std::string& name) : + ar(ar), + name(name) +{} + +template +template +void SerializeVisitor::operator()(NSType *ns) const +{ + ar & data::CreateNVP(ns, name); +} + /** * Initialize the NSModel with the given type and whether or not a random * basis should be used. @@ -23,13 +207,7 @@ namespace neighbor { template NSModel::NSModel(TreeTypes treeType, bool randomBasis) : treeType(treeType), - randomBasis(randomBasis), - kdTreeNS(NULL), - coverTreeNS(NULL), - rTreeNS(NULL), - rStarTreeNS(NULL), - ballTreeNS(NULL), - xTreeNS(NULL) + randomBasis(randomBasis) { // Nothing to do. } @@ -38,18 +216,7 @@ NSModel::NSModel(TreeTypes treeType, bool randomBasis) : template NSModel::~NSModel() { - if (kdTreeNS) - delete kdTreeNS; - if (coverTreeNS) - delete coverTreeNS; - if (rTreeNS) - delete rTreeNS; - if (rStarTreeNS) - delete rStarTreeNS; - if (ballTreeNS) - delete ballTreeNS; - if (xTreeNS) - delete xTreeNS; + boost::apply_visitor(DeleteVisitor(), nSearch); } //! Serialize the kNN model. @@ -64,148 +231,43 @@ void NSModel::Serialize(Archive& ar, // This should never happen, but just in case, be clean with memory. if (Archive::is_loading::value) - { - if (kdTreeNS) - delete kdTreeNS; - if (coverTreeNS) - delete coverTreeNS; - if (rTreeNS) - delete rTreeNS; - if (rStarTreeNS) - delete rStarTreeNS; - if (ballTreeNS) - delete ballTreeNS; - if (xTreeNS) - delete xTreeNS; - - // Set all the pointers to NULL. - kdTreeNS = NULL; - coverTreeNS = NULL; - rTreeNS = NULL; - rStarTreeNS = NULL; - ballTreeNS = NULL; - xTreeNS = NULL; - } + boost::apply_visitor(DeleteVisitor(), nSearch); // We'll only need to serialize one of the kNN objects, based on the type. const std::string& name = NSModelName::Name(); - switch (treeType) - { - case KD_TREE: - ar & data::CreateNVP(kdTreeNS, name); - break; - case COVER_TREE: - ar & data::CreateNVP(coverTreeNS, name); - break; - case R_TREE: - ar & data::CreateNVP(rTreeNS, name); - break; - case R_STAR_TREE: - ar & data::CreateNVP(rStarTreeNS, name); - break; - case BALL_TREE: - ar & data::CreateNVP(ballTreeNS, name); - break; - case X_TREE: - ar & data::CreateNVP(xTreeNS, name); - break; - } + SerializeVisitor s(ar, name); + boost::apply_visitor(s, nSearch); } template const arma::mat& NSModel::Dataset() const { - if (kdTreeNS) - return kdTreeNS->ReferenceSet(); - else if (coverTreeNS) - return coverTreeNS->ReferenceSet(); - else if (rTreeNS) - return rTreeNS->ReferenceSet(); - else if (rStarTreeNS) - return rStarTreeNS->ReferenceSet(); - else if (ballTreeNS) - return ballTreeNS->ReferenceSet(); - else if (xTreeNS) - return xTreeNS->ReferenceSet(); - - throw std::runtime_error("no neighbor search model initialized"); + return boost::apply_visitor(ReferenceSetVisitor(), nSearch); } //! Expose singleMode. template bool NSModel::SingleMode() const { - if (kdTreeNS) - return kdTreeNS->SingleMode(); - else if (coverTreeNS) - return coverTreeNS->SingleMode(); - else if (rTreeNS) - return rTreeNS->SingleMode(); - else if (rStarTreeNS) - return rStarTreeNS->SingleMode(); - else if (ballTreeNS) - return ballTreeNS->SingleMode(); - else if (xTreeNS) - return xTreeNS->SingleMode(); - - throw std::runtime_error("no neighbor search model initialized"); + return boost::apply_visitor(SingleModeVisitor(), nSearch); } template bool& NSModel::SingleMode() { - if (kdTreeNS) - return kdTreeNS->SingleMode(); - else if (coverTreeNS) - return coverTreeNS->SingleMode(); - else if (rTreeNS) - return rTreeNS->SingleMode(); - else if (rStarTreeNS) - return rStarTreeNS->SingleMode(); - else if (ballTreeNS) - return ballTreeNS->SingleMode(); - else if (xTreeNS) - return xTreeNS->SingleMode(); - - throw std::runtime_error("no neighbor search model initialized"); + return boost::apply_visitor(SingleModeVisitor(), nSearch); } template bool NSModel::Naive() const { - if (kdTreeNS) - return kdTreeNS->Naive(); - else if (coverTreeNS) - return coverTreeNS->Naive(); - else if (rTreeNS) - return rTreeNS->Naive(); - else if (rStarTreeNS) - return rStarTreeNS->Naive(); - else if (ballTreeNS) - return ballTreeNS->Naive(); - else if (xTreeNS) - return xTreeNS->Naive(); - - throw std::runtime_error("no neighbor search model initialized"); + return boost::apply_visitor(NaiveVisitor(), nSearch); } template bool& NSModel::Naive() { - if (kdTreeNS) - return kdTreeNS->Naive(); - else if (coverTreeNS) - return coverTreeNS->Naive(); - else if (rTreeNS) - return rTreeNS->Naive(); - else if (rStarTreeNS) - return rStarTreeNS->Naive(); - else if (ballTreeNS) - return ballTreeNS->Naive(); - else if (xTreeNS) - return xTreeNS->Naive(); - - throw std::runtime_error("no neighbor search model initialized"); + return boost::apply_visitor(NaiveVisitor(), nSearch); } //! Build the reference tree. @@ -248,18 +310,7 @@ void NSModel::BuildModel(arma::mat&& referenceSet, } // Clean memory, if necessary. - if (kdTreeNS) - delete kdTreeNS; - if (coverTreeNS) - delete coverTreeNS; - if (rTreeNS) - delete rTreeNS; - if (rStarTreeNS) - delete rStarTreeNS; - if (ballTreeNS) - delete ballTreeNS; - if (xTreeNS) - delete xTreeNS; + boost::apply_visitor(DeleteVisitor(), nSearch); // Do we need to modify the reference set? if (randomBasis) @@ -274,69 +325,29 @@ void NSModel::BuildModel(arma::mat&& referenceSet, switch (treeType) { case KD_TREE: - // If necessary, build the kd-tree. - if (naive) - { - kdTreeNS = new NSType(std::move(referenceSet), naive, - singleMode); - } - else - { - std::vector oldFromNewReferences; - typename NSType::Tree* kdTree = - new typename NSType::Tree(std::move(referenceSet), - oldFromNewReferences, leafSize); - kdTreeNS = new NSType(kdTree, singleMode); - - // Give the model ownership of the tree and the mappings. - kdTreeNS->treeOwner = true; - kdTreeNS->oldFromNewReferences = std::move(oldFromNewReferences); - } - + nSearch = new NSType(naive, singleMode); break; case COVER_TREE: - // If necessary, build the cover tree. - coverTreeNS = new NSType(std::move(referenceSet), - naive, singleMode); + nSearch = new NSType(naive, + singleMode); break; case R_TREE: - // If necessary, build the R tree. - rTreeNS = new NSType(std::move(referenceSet), naive, - singleMode); + nSearch = new NSType(naive, singleMode); break; case R_STAR_TREE: - // If necessary, build the R* tree. - rStarTreeNS = new NSType(std::move(referenceSet), naive, - singleMode); + nSearch = new NSType(naive, singleMode); break; case BALL_TREE: - // If necessary, build the ball tree. - if (naive) - { - ballTreeNS = new NSType(std::move(referenceSet), naive, - singleMode); - } - else - { - std::vector oldFromNewReferences; - typename NSType::Tree* ballTree = - new typename NSType::Tree(std::move(referenceSet), - oldFromNewReferences, leafSize); - ballTreeNS = new NSType(ballTree, singleMode); - - // Give the model ownership of the tree and the mappings. - ballTreeNS->treeOwner = true; - ballTreeNS->oldFromNewReferences = std::move(oldFromNewReferences); - } - + nSearch = new NSType(naive, singleMode); break; case X_TREE: - // If necessary, build the X tree. - xTreeNS = new NSType(std::move(referenceSet), naive, - singleMode); + nSearch = new NSType(naive, singleMode); break; } + TrainVisitor tn(std::move(referenceSet),leafSize); + boost::apply_visitor(tn, nSearch); + if (!naive) { Timer::Stop("tree_building"); @@ -363,88 +374,8 @@ void NSModel::Search(arma::mat&& querySet, else Log::Info << "brute-force (naive) search..." << std::endl; - switch (treeType) - { - case KD_TREE: - if (!kdTreeNS->Naive() && !kdTreeNS->SingleMode()) - { - // Build a second tree and search. - Timer::Start("tree_building"); - Log::Info << "Building query tree..." << std::endl; - std::vector oldFromNewQueries; - typename NSType::Tree queryTree(std::move(querySet), - oldFromNewQueries, leafSize); - Log::Info << "Tree built." << std::endl; - Timer::Stop("tree_building"); - - arma::Mat neighborsOut; - arma::mat distancesOut; - kdTreeNS->Search(&queryTree, k, neighborsOut, distancesOut); - - // Unmap the query points. - distances.set_size(distancesOut.n_rows, distancesOut.n_cols); - neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols); - for (size_t i = 0; i < neighborsOut.n_cols; ++i) - { - neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i); - distances.col(oldFromNewQueries[i]) = distancesOut.col(i); - } - } - else - { - // Search without building a second tree. - kdTreeNS->Search(querySet, k, neighbors, distances); - } - break; - case COVER_TREE: - // No mapping necessary. - coverTreeNS->Search(querySet, k, neighbors, distances); - break; - case R_TREE: - // No mapping necessary. - rTreeNS->Search(querySet, k, neighbors, distances); - break; - case R_STAR_TREE: - // No mapping necessary. - rStarTreeNS->Search(querySet, k, neighbors, distances); - break; - case BALL_TREE: - if (!ballTreeNS->Naive() && !ballTreeNS->SingleMode()) - { - // Build a second tree and search. - Timer::Start("tree_building"); - Log::Info << "Building query tree..." << std::endl; - std::vector oldFromNewQueries; - typename NSType::Tree queryTree(std::move(querySet), - oldFromNewQueries, leafSize); - Log::Info << "Tree built." << std::endl; - Timer::Stop("tree_building"); - - arma::Mat neighborsOut; - arma::mat distancesOut; - ballTreeNS->Search(&queryTree, k, neighborsOut, distancesOut); - - // Unmap the query points. - distances.set_size(distancesOut.n_rows, distancesOut.n_cols); - neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols); - for (size_t i = 0; i < neighborsOut.n_cols; ++i) - { - neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i); - distances.col(oldFromNewQueries[i]) = distancesOut.col(i); - } - } - else - { - // Search without building a second tree. - ballTreeNS->Search(querySet, k, neighbors, distances); - } - - break; - case X_TREE: - // No mapping necessary. - xTreeNS->Search(querySet, k, neighbors, distances); - break; - } + SearchVisitor search(querySet, k, neighbors, distances, leafSize); + boost::apply_visitor(search, nSearch); } //! Perform neighbor search. @@ -461,27 +392,8 @@ void NSModel::Search(const size_t k, else Log::Info << "brute-force (naive) search..." << std::endl; - switch (treeType) - { - case KD_TREE: - kdTreeNS->Search(k, neighbors, distances); - break; - case COVER_TREE: - coverTreeNS->Search(k, neighbors, distances); - break; - case R_TREE: - rTreeNS->Search(k, neighbors, distances); - break; - case R_STAR_TREE: - rStarTreeNS->Search(k, neighbors, distances); - break; - case BALL_TREE: - ballTreeNS->Search(k, neighbors, distances); - break; - case X_TREE: - xTreeNS->Search(k, neighbors, distances); - break; - } + SearchKVisitor search(k, neighbors, distances); + boost::apply_visitor(search, nSearch); } //! Get the name of the tree type. From a0bcee91b243d7f0518032ddd51de34b3b92cea9 Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Mon, 13 Jun 2016 17:55:29 -0300 Subject: [PATCH 2/6] Improve name of Search visitors. --- .../methods/neighbor_search/ns_model.hpp | 20 ++++++------- .../methods/neighbor_search/ns_model_impl.hpp | 30 +++++++++---------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp index 458bf97318a..f269e1330ad 100644 --- a/src/mlpack/methods/neighbor_search/ns_model.hpp +++ b/src/mlpack/methods/neighbor_search/ns_model.hpp @@ -49,7 +49,7 @@ struct NSModelName static const std::string Name() { return "furthest_neighbor_search_model"; } }; -class SearchKVisitor : public boost::static_visitor +class MonoSearchVisitor : public boost::static_visitor { private: const size_t k; @@ -60,12 +60,12 @@ class SearchKVisitor : public boost::static_visitor template void operator()(NSType *ns) const; - SearchKVisitor(const size_t k, - arma::Mat& neighbors, - arma::mat& distances); + MonoSearchVisitor(const size_t k, + arma::Mat& neighbors, + arma::mat& distances); }; -class SearchVisitor : public boost::static_visitor +class BiSearchVisitor : public boost::static_visitor { private: const arma::mat& querySet; @@ -90,11 +90,11 @@ class SearchVisitor : public boost::static_visitor template void operator()(NSType *ns) const; - SearchVisitor(const arma::mat& querySet, - const size_t k, - arma::Mat& neighbors, - arma::mat& distances, - const size_t leafSize); + BiSearchVisitor(const arma::mat& querySet, + const size_t k, + arma::Mat& neighbors, + arma::mat& distances, + const size_t leafSize); }; class TrainVisitor : public boost::static_visitor diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp index ea2206eb961..ef67d18da84 100644 --- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp +++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp @@ -16,16 +16,16 @@ namespace mlpack { namespace neighbor { -SearchKVisitor::SearchKVisitor(const size_t k, - arma::Mat& neighbors, - arma::mat& distances) : +MonoSearchVisitor::MonoSearchVisitor(const size_t k, + arma::Mat& neighbors, + arma::mat& distances) : k(k), neighbors(neighbors), distances(distances) {} template -void SearchKVisitor::operator()(NSType *ns) const +void MonoSearchVisitor::operator()(NSType *ns) const { if (ns) return ns->Search(k, neighbors, distances); @@ -33,11 +33,11 @@ void SearchKVisitor::operator()(NSType *ns) const } -SearchVisitor::SearchVisitor(const arma::mat& querySet, - const size_t k, - arma::Mat& neighbors, - arma::mat& distances, - const size_t leafSize) : +BiSearchVisitor::BiSearchVisitor(const arma::mat& querySet, + const size_t k, + arma::Mat& neighbors, + arma::mat& distances, + const size_t leafSize) : querySet(querySet), k(k), neighbors(neighbors), @@ -49,7 +49,7 @@ template class TreeType> -void SearchVisitor::operator()(NSType *ns) const +void BiSearchVisitor::operator()(NSType *ns) const { if (ns) return ns->Search(querySet, k, neighbors, distances); @@ -57,7 +57,7 @@ void SearchVisitor::operator()(NSType *ns) const } template -void SearchVisitor::operator()(NSType *ns) const +void BiSearchVisitor::operator()(NSType *ns) const { if (ns) return SearchLeaf(ns); @@ -65,7 +65,7 @@ void SearchVisitor::operator()(NSType *ns) const } template -void SearchVisitor::operator()(NSType *ns) const +void BiSearchVisitor::operator()(NSType *ns) const { if (ns) return SearchLeaf(ns); @@ -73,7 +73,7 @@ void SearchVisitor::operator()(NSType *ns) const } template -void SearchVisitor::SearchLeaf(NSType *ns) const +void BiSearchVisitor::SearchLeaf(NSType *ns) const { if (!ns->Naive() && !ns->SingleMode()) { @@ -374,7 +374,7 @@ void NSModel::Search(arma::mat&& querySet, else Log::Info << "brute-force (naive) search..." << std::endl; - SearchVisitor search(querySet, k, neighbors, distances, leafSize); + BiSearchVisitor search(querySet, k, neighbors, distances, leafSize); boost::apply_visitor(search, nSearch); } @@ -392,7 +392,7 @@ void NSModel::Search(const size_t k, else Log::Info << "brute-force (naive) search..." << std::endl; - SearchKVisitor search(k, neighbors, distances); + MonoSearchVisitor search(k, neighbors, distances); boost::apply_visitor(search, nSearch); } From ddb252ce03520b699e8335f5261ecd08b32d5a60 Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Wed, 15 Jun 2016 11:44:16 -0300 Subject: [PATCH 3/6] Set SortPolicy as template parameter of classes BiSearchVisitor and TrainVisitor. Also, add a more specific definition of NSTypeT, to avoid vc compiler errors. --- .../neighbor_search/neighbor_search.hpp | 3 +- .../methods/neighbor_search/ns_model.hpp | 124 +++++++++--------- .../methods/neighbor_search/ns_model_impl.hpp | 47 ++++--- 3 files changed, 93 insertions(+), 81 deletions(-) diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp index bf62ae7ee65..999f261c8f0 100644 --- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp +++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp @@ -27,6 +27,7 @@ namespace neighbor /** Neighbor-search routines. These include * searches. */ { // Forward declaration. +template class TrainVisitor; /** @@ -307,7 +308,7 @@ class NeighborSearch bool treeNeedsReset; //! The NSModel class should have access to internal members. - friend class TrainVisitor; + friend class TrainVisitor; }; // class NeighborSearch } // namespace neighbor diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp index f269e1330ad..613ef45efdc 100644 --- a/src/mlpack/methods/neighbor_search/ns_model.hpp +++ b/src/mlpack/methods/neighbor_search/ns_model.hpp @@ -52,116 +52,122 @@ struct NSModelName class MonoSearchVisitor : public boost::static_visitor { private: - const size_t k; - arma::Mat& neighbors; - arma::mat& distances; + const size_t k; + arma::Mat& neighbors; + arma::mat& distances; public: - template - void operator()(NSType *ns) const; + template + void operator()(NSType *ns) const; - MonoSearchVisitor(const size_t k, - arma::Mat& neighbors, - arma::mat& distances); + MonoSearchVisitor(const size_t k, + arma::Mat& neighbors, + arma::mat& distances); }; +template class BiSearchVisitor : public boost::static_visitor { private: - const arma::mat& querySet; - const size_t k; - arma::Mat& neighbors; - arma::mat& distances; - const size_t leafSize; + const arma::mat& querySet; + const size_t k; + arma::Mat& neighbors; + arma::mat& distances; + const size_t leafSize; - template - void SearchLeaf(NSType *ns) const; + template + void SearchLeaf(NSType* ns) const; public: - template class TreeType> - void operator()(NSType *ns) const; - - template - void operator()(NSType *ns) const; - - template - void operator()(NSType *ns) const; - - BiSearchVisitor(const arma::mat& querySet, - const size_t k, - arma::Mat& neighbors, - arma::mat& distances, - const size_t leafSize); + template class TreeType> + using NSTypeT = NSType; + + template class TreeType> + void operator()(NSTypeT* ns) const; + + void operator()(NSTypeT* ns) const; + + void operator()(NSTypeT* ns) const; + + BiSearchVisitor(const arma::mat& querySet, + const size_t k, + arma::Mat& neighbors, + arma::mat& distances, + const size_t leafSize); }; +template class TrainVisitor : public boost::static_visitor { private: - arma::mat&& referenceSet; - size_t leafSize; + arma::mat&& referenceSet; + size_t leafSize; - template - void TrainLeaf(NSType* ns) const; + template + void TrainLeaf(NSType* ns) const; public: - template class TreeType> - void operator()(NSType *ns) const; + template class TreeType> + using NSTypeT = NSType; + + template class TreeType> + void operator()(NSTypeT* ns) const; - template - void operator()(NSType *ns) const; + void operator()(NSTypeT* ns) const; - template - void operator()(NSType *ns) const; + void operator()(NSTypeT* ns) const; - TrainVisitor(arma::mat&& referenceSet, const size_t leafSize); + TrainVisitor(arma::mat&& referenceSet, const size_t leafSize); }; class SingleModeVisitor : public boost::static_visitor { public: - template - bool& operator()(NSType *ns) const; + template + bool& operator()(NSType *ns) const; }; class NaiveVisitor : public boost::static_visitor { public: - template - bool& operator()(NSType *ns) const; + template + bool& operator()(NSType *ns) const; }; class ReferenceSetVisitor : public boost::static_visitor { public: - template - const arma::mat& operator()(NSType *ns) const; + template + const arma::mat& operator()(NSType *ns) const; }; class DeleteVisitor : public boost::static_visitor { public: - template - void operator()(NSType *ns) const; + template + void operator()(NSType *ns) const; }; template class SerializeVisitor : public boost::static_visitor { private: - Archive& ar; - const std::string& name; + Archive& ar; + const std::string& name; public: - template - void operator()(NSType *ns) const; + template + void operator()(NSType *ns) const; - SerializeVisitor(Archive& ar, const std::string& name); + SerializeVisitor(Archive& ar, const std::string& name); }; template diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp index ef67d18da84..bcded0ec2ed 100644 --- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp +++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp @@ -32,12 +32,12 @@ void MonoSearchVisitor::operator()(NSType *ns) const throw std::runtime_error("no neighbor search model initialized"); } - -BiSearchVisitor::BiSearchVisitor(const arma::mat& querySet, - const size_t k, - arma::Mat& neighbors, - arma::mat& distances, - const size_t leafSize) : +template +BiSearchVisitor::BiSearchVisitor(const arma::mat& querySet, + const size_t k, + arma::Mat& neighbors, + arma::mat& distances, + const size_t leafSize) : querySet(querySet), k(k), neighbors(neighbors), @@ -45,11 +45,11 @@ BiSearchVisitor::BiSearchVisitor(const arma::mat& querySet, leafSize(leafSize) {} -template +template class TreeType> -void BiSearchVisitor::operator()(NSType *ns) const +void BiSearchVisitor::operator()(NSTypeT* ns) const { if (ns) return ns->Search(querySet, k, neighbors, distances); @@ -57,7 +57,7 @@ void BiSearchVisitor::operator()(NSType *ns) const } template -void BiSearchVisitor::operator()(NSType *ns) const +void BiSearchVisitor::operator()(NSTypeT* ns) const { if (ns) return SearchLeaf(ns); @@ -65,15 +65,16 @@ void BiSearchVisitor::operator()(NSType *ns) const } template -void BiSearchVisitor::operator()(NSType *ns) const +void BiSearchVisitor::operator()(NSTypeT* ns) const { if (ns) return SearchLeaf(ns); throw std::runtime_error("no neighbor search model initialized"); } +template template -void BiSearchVisitor::SearchLeaf(NSType *ns) const +void BiSearchVisitor::SearchLeaf(NSType *ns) const { if (!ns->Naive() && !ns->SingleMode()) { @@ -99,16 +100,18 @@ void BiSearchVisitor::SearchLeaf(NSType *ns) const } -TrainVisitor::TrainVisitor(arma::mat&& referenceSet, const size_t leafSize) : +template +TrainVisitor::TrainVisitor(arma::mat&& referenceSet, + const size_t leafSize) : referenceSet(std::move(referenceSet)), leafSize(leafSize) {} -template +template class TreeType> -void TrainVisitor::operator()(NSType *ns) const +void TrainVisitor::operator()(NSTypeT* ns) const { if (ns) return ns->Train(std::move(referenceSet)); @@ -116,7 +119,7 @@ void TrainVisitor::operator()(NSType *ns) const } template -void TrainVisitor::operator ()(NSType *ns) const +void TrainVisitor::operator ()(NSTypeT* ns) const { if (ns) return TrainLeaf(ns); @@ -124,15 +127,16 @@ void TrainVisitor::operator ()(NSType *ns) const } template -void TrainVisitor::operator ()(NSType *ns) const +void TrainVisitor::operator ()(NSTypeT* ns) const { if (ns) return TrainLeaf(ns); throw std::runtime_error("no neighbor search model initialized"); } +template template -void TrainVisitor::TrainLeaf(NSType* ns) const +void TrainVisitor::TrainLeaf(NSType* ns) const { if (ns->Naive()) ns->Train(std::move(referenceSet)); @@ -345,7 +349,7 @@ void NSModel::BuildModel(arma::mat&& referenceSet, break; } - TrainVisitor tn(std::move(referenceSet),leafSize); + TrainVisitor tn(std::move(referenceSet),leafSize); boost::apply_visitor(tn, nSearch); if (!naive) @@ -374,7 +378,8 @@ void NSModel::Search(arma::mat&& querySet, else Log::Info << "brute-force (naive) search..." << std::endl; - BiSearchVisitor search(querySet, k, neighbors, distances, leafSize); + BiSearchVisitor search(querySet, k, neighbors, distances, + leafSize); boost::apply_visitor(search, nSearch); } From 83f1925bb246471323c893b7ead2f7ebf39124b1 Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Wed, 15 Jun 2016 23:13:35 -0300 Subject: [PATCH 4/6] Details to improve code and documentation. --- src/mlpack/methods/neighbor_search/ns_model.hpp | 8 ++++++-- .../methods/neighbor_search/ns_model_impl.hpp | 14 +++++++------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp index 613ef45efdc..f0db0d115fe 100644 --- a/src/mlpack/methods/neighbor_search/ns_model.hpp +++ b/src/mlpack/methods/neighbor_search/ns_model.hpp @@ -58,7 +58,7 @@ class MonoSearchVisitor : public boost::static_visitor public: template - void operator()(NSType *ns) const; + void operator()(NSType* ns) const; MonoSearchVisitor(const size_t k, arma::Mat& neighbors, @@ -132,7 +132,7 @@ class SingleModeVisitor : public boost::static_visitor { public: template - bool& operator()(NSType *ns) const; + bool& operator()(NSType* ns) const; }; class NaiveVisitor : public boost::static_visitor @@ -220,15 +220,19 @@ class NSModel bool SingleMode() const; bool& SingleMode(); + //! Expose naiveMode. bool Naive() const; bool& Naive(); + //! Expose leafSize. size_t LeafSize() const { return leafSize; } size_t& LeafSize() { return leafSize; } + //! Expose treeType. TreeTypes TreeType() const { return treeType; } TreeTypes& TreeType() { return treeType; } + //! Expose randomBasis. bool RandomBasis() const { return randomBasis; } bool& RandomBasis() { return randomBasis; } diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp index bcded0ec2ed..8c7400e0ada 100644 --- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp +++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp @@ -74,7 +74,7 @@ void BiSearchVisitor::operator()(NSTypeT* ns) const template template -void BiSearchVisitor::SearchLeaf(NSType *ns) const +void BiSearchVisitor::SearchLeaf(NSType* ns) const { if (!ns->Naive() && !ns->SingleMode()) { @@ -156,7 +156,7 @@ void TrainVisitor::TrainLeaf(NSType* ns) const template -bool& SingleModeVisitor::operator()(NSType *ns) const +bool& SingleModeVisitor::operator()(NSType* ns) const { if (ns) return ns->SingleMode(); @@ -165,7 +165,7 @@ bool& SingleModeVisitor::operator()(NSType *ns) const template -bool& NaiveVisitor::operator()(NSType *ns) const +bool& NaiveVisitor::operator()(NSType* ns) const { if (ns) return ns->Naive(); @@ -174,7 +174,7 @@ bool& NaiveVisitor::operator()(NSType *ns) const template -const arma::mat& ReferenceSetVisitor::operator()(NSType *ns) const +const arma::mat& ReferenceSetVisitor::operator()(NSType* ns) const { if (ns) return ns->ReferenceSet(); @@ -183,7 +183,7 @@ const arma::mat& ReferenceSetVisitor::operator()(NSType *ns) const template -void DeleteVisitor::operator()(NSType *ns) const +void DeleteVisitor::operator()(NSType* ns) const { if (ns) delete ns; @@ -199,7 +199,7 @@ SerializeVisitor::SerializeVisitor(Archive& ar, template template -void SerializeVisitor::operator()(NSType *ns) const +void SerializeVisitor::operator()(NSType* ns) const { ar & data::CreateNVP(ns, name); } @@ -349,7 +349,7 @@ void NSModel::BuildModel(arma::mat&& referenceSet, break; } - TrainVisitor tn(std::move(referenceSet),leafSize); + TrainVisitor tn(std::move(referenceSet), leafSize); boost::apply_visitor(tn, nSearch); if (!naive) From 6c2c3ca83d8d955b6b6f8634d9f88df06f6e3cca Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Thu, 16 Jun 2016 23:55:08 -0300 Subject: [PATCH 5/6] Improve documentation. --- .../methods/neighbor_search/ns_model.hpp | 64 ++++++++++++++++++- .../methods/neighbor_search/ns_model_impl.hpp | 26 ++++++-- 2 files changed, 82 insertions(+), 8 deletions(-) diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp index f0db0d115fe..0203e71ccba 100644 --- a/src/mlpack/methods/neighbor_search/ns_model.hpp +++ b/src/mlpack/methods/neighbor_search/ns_model.hpp @@ -19,6 +19,9 @@ namespace mlpack { namespace neighbor { +/** + * Alias template for euclidean neighbor search. + */ template static const std::string Name() { return "furthest_neighbor_search_model"; } }; +/** + * MonoSearchVisitor executes a monochromatic neighbor search on the given + * NSType. We don't make any difference for different instantiations of NSType. + */ class MonoSearchVisitor : public boost::static_visitor { private: @@ -65,6 +72,12 @@ class MonoSearchVisitor : public boost::static_visitor arma::mat& distances); }; +/** + * BiSearchVisitor executes a bichromatic neighbor search on the given NSType. + * We use template specialization to differenciate those tree types that + * accept leafSize as a parameter. In these cases, before doing neighbor search, + * a query tree with proper leafSize is built from the querySet. + */ template class BiSearchVisitor : public boost::static_visitor { @@ -75,22 +88,27 @@ class BiSearchVisitor : public boost::static_visitor arma::mat& distances; const size_t leafSize; + //! Bichromatic neighbor search on the given NSType considering the leafSize. template void SearchLeaf(NSType* ns) const; public: + //! Alias template necessary for visual c++ compiler. template class TreeType> using NSTypeT = NSType; + //! Default Bichromatic neighbor search on the given NSType instance. template class TreeType> void operator()(NSTypeT* ns) const; + //! Bichromatic neighbor search on the given NSType specialized for KDTrees. void operator()(NSTypeT* ns) const; + //! Bichromatic neighbor search on the given NSType specialized for BallTrees. void operator()(NSTypeT* ns) const; BiSearchVisitor(const arma::mat& querySet, @@ -100,6 +118,12 @@ class BiSearchVisitor : public boost::static_visitor const size_t leafSize); }; +/** + * TrainVisitor sets the reference set to a new reference set on the given + * NSType. We use template specialization to differenciate those tree types that + * accept leafSize as a parameter. In these cases, a reference tree with proper + * leafSize is built from the referenceSet. + */ template class TrainVisitor : public boost::static_visitor { @@ -107,27 +131,35 @@ class TrainVisitor : public boost::static_visitor arma::mat&& referenceSet; size_t leafSize; + //! Train on the given NSType considering the leafSize. template void TrainLeaf(NSType* ns) const; public: + //! Alias template necessary for visual c++ compiler. template class TreeType> using NSTypeT = NSType; + //! Default Train on the given NSType instance. template class TreeType> void operator()(NSTypeT* ns) const; + //! Train on the given NSType specialized for KDTrees. void operator()(NSTypeT* ns) const; + //! Train on the given NSType specialized for BallTrees. void operator()(NSTypeT* ns) const; TrainVisitor(arma::mat&& referenceSet, const size_t leafSize); }; +/** + * SingleModeVisitor exposes the SingleMode method of the given NSType. + */ class SingleModeVisitor : public boost::static_visitor { public: @@ -135,6 +167,9 @@ class SingleModeVisitor : public boost::static_visitor bool& operator()(NSType* ns) const; }; +/** + * NaiveVisitor exposes the Naive method of the given NSType. + */ class NaiveVisitor : public boost::static_visitor { public: @@ -142,6 +177,9 @@ class NaiveVisitor : public boost::static_visitor bool& operator()(NSType *ns) const; }; +/** + * ReferenceSetVisitor exposes the referenceSet of the given NSType. + */ class ReferenceSetVisitor : public boost::static_visitor { public: @@ -149,6 +187,9 @@ class ReferenceSetVisitor : public boost::static_visitor const arma::mat& operator()(NSType *ns) const; }; +/** + * DeleteVisitor deletes the given NSType instance. + */ class DeleteVisitor : public boost::static_visitor { public: @@ -156,6 +197,9 @@ class DeleteVisitor : public boost::static_visitor void operator()(NSType *ns) const; }; +/** + * SerializeVisitor serializes the given NSType instance. + */ template class SerializeVisitor : public boost::static_visitor { @@ -170,10 +214,17 @@ class SerializeVisitor : public boost::static_visitor SerializeVisitor(Archive& ar, const std::string& name); }; +/** + * The NSModel class provides an easy way to serialize a model, abstracts away + * the different types of trees, and also reflects the NeighborSearch API. + * + * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort. + */ template class NSModel { public: + //! Enum type to identify each accepted tree type. enum TreeTypes { KD_TREE, @@ -185,13 +236,21 @@ class NSModel }; private: + //! Tree type considered for neighbor search. TreeTypes treeType; + + //! For tree types that accept the maxLeafSize parameter. size_t leafSize; - // For random projections. + //! For random projections. bool randomBasis; arma::mat q; + /** + * nSearch holds an instance of the NeigborSearch class for the current + * treeType. It is initialized every time BuildModel is executed. + * We access to the contained value through the visitor classes defined above. + */ boost::variant*, NSType*, NSType*, @@ -248,11 +307,12 @@ class NSModel arma::Mat& neighbors, arma::mat& distances); - //! Perform neighbor search. + //! Perform monochromatic neighbor search. void Search(const size_t k, arma::Mat& neighbors, arma::mat& distances); + //! Return a string representation of the current tree type. std::string TreeName() const; }; diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp index 8c7400e0ada..1001ff4b204 100644 --- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp +++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp @@ -16,6 +16,7 @@ namespace mlpack { namespace neighbor { +//! Save parameters for monochromatic neighbor search. MonoSearchVisitor::MonoSearchVisitor(const size_t k, arma::Mat& neighbors, arma::mat& distances) : @@ -24,6 +25,7 @@ MonoSearchVisitor::MonoSearchVisitor(const size_t k, distances(distances) {} +//! Monochromatic neighbor search on the given NSType instance. template void MonoSearchVisitor::operator()(NSType *ns) const { @@ -32,6 +34,7 @@ void MonoSearchVisitor::operator()(NSType *ns) const throw std::runtime_error("no neighbor search model initialized"); } +//! Save parameters for bichromatic neighbor search. template BiSearchVisitor::BiSearchVisitor(const arma::mat& querySet, const size_t k, @@ -45,6 +48,7 @@ BiSearchVisitor::BiSearchVisitor(const arma::mat& querySet, leafSize(leafSize) {} +//! Default Bichromatic neighbor search on the given NSType instance. template template::operator()(NSTypeT* ns) const throw std::runtime_error("no neighbor search model initialized"); } +//! Bichromatic neighbor search on the given NSType specialized for KDTrees. template void BiSearchVisitor::operator()(NSTypeT* ns) const { @@ -64,6 +69,7 @@ void BiSearchVisitor::operator()(NSTypeT* ns) const throw std::runtime_error("no neighbor search model initialized"); } +//! Bichromatic neighbor search on the given NSType specialized for BallTrees. template void BiSearchVisitor::operator()(NSTypeT* ns) const { @@ -72,6 +78,7 @@ void BiSearchVisitor::operator()(NSTypeT* ns) const throw std::runtime_error("no neighbor search model initialized"); } +//! Bichromatic neighbor search on the given NSType considering the leafSize. template template void BiSearchVisitor::SearchLeaf(NSType* ns) const @@ -99,7 +106,7 @@ void BiSearchVisitor::SearchLeaf(NSType* ns) const ns->Search(querySet, k, neighbors, distances); } - +//! Save parameters for Train. template TrainVisitor::TrainVisitor(arma::mat&& referenceSet, const size_t leafSize) : @@ -107,6 +114,7 @@ TrainVisitor::TrainVisitor(arma::mat&& referenceSet, leafSize(leafSize) {} +//! Default Train on the given NSType instance. template template::operator()(NSTypeT* ns) const throw std::runtime_error("no neighbor search model initialized"); } +//! Train on the given NSType specialized for KDTrees. template void TrainVisitor::operator ()(NSTypeT* ns) const { @@ -126,6 +135,7 @@ void TrainVisitor::operator ()(NSTypeT* ns) const throw std::runtime_error("no neighbor search model initialized"); } +//! Train on the given NSType specialized for BallTrees. template void TrainVisitor::operator ()(NSTypeT* ns) const { @@ -134,6 +144,7 @@ void TrainVisitor::operator ()(NSTypeT* ns) const throw std::runtime_error("no neighbor search model initialized"); } +//! Train on the given NSType considering the leafSize. template template void TrainVisitor::TrainLeaf(NSType* ns) const @@ -154,7 +165,7 @@ void TrainVisitor::TrainLeaf(NSType* ns) const } } - +//! Expose the SingleMode method of the given NSType. template bool& SingleModeVisitor::operator()(NSType* ns) const { @@ -163,7 +174,7 @@ bool& SingleModeVisitor::operator()(NSType* ns) const throw std::runtime_error("no neighbor search model initialized"); } - +//! Expose the Naive method of the given NSType. template bool& NaiveVisitor::operator()(NSType* ns) const { @@ -172,7 +183,7 @@ bool& NaiveVisitor::operator()(NSType* ns) const throw std::runtime_error("no neighbor search model initialized"); } - +//! Expose the referenceSet of the given NSType. template const arma::mat& ReferenceSetVisitor::operator()(NSType* ns) const { @@ -181,7 +192,7 @@ const arma::mat& ReferenceSetVisitor::operator()(NSType* ns) const throw std::runtime_error("no neighbor search model initialized"); } - +//! Clean memory, if necessary. template void DeleteVisitor::operator()(NSType* ns) const { @@ -189,7 +200,7 @@ void DeleteVisitor::operator()(NSType* ns) const delete ns; } - +//! Save parameters for serialization. template SerializeVisitor::SerializeVisitor(Archive& ar, const std::string& name) : @@ -197,6 +208,7 @@ SerializeVisitor::SerializeVisitor(Archive& ar, name(name) {} +//! Serialize the given NSType instance. template template void SerializeVisitor::operator()(NSType* ns) const @@ -243,6 +255,7 @@ void NSModel::Serialize(Archive& ar, boost::apply_visitor(s, nSearch); } +//! Expose the dataset. template const arma::mat& NSModel::Dataset() const { @@ -262,6 +275,7 @@ bool& NSModel::SingleMode() return boost::apply_visitor(SingleModeVisitor(), nSearch); } +//! Expose Naive. template bool NSModel::Naive() const { From b34dac8e4bceab76ff860af035ef9d0bee1ac1fa Mon Sep 17 00:00:00 2001 From: MarcosPividori Date: Fri, 17 Jun 2016 03:03:59 -0300 Subject: [PATCH 6/6] Fix NSModel serialization. Use boost serialize function for variants. --- .../methods/neighbor_search/ns_model.hpp | 17 -------- .../methods/neighbor_search/ns_model_impl.hpp | 43 +++++++++++-------- 2 files changed, 24 insertions(+), 36 deletions(-) diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp index 0203e71ccba..d87549e9208 100644 --- a/src/mlpack/methods/neighbor_search/ns_model.hpp +++ b/src/mlpack/methods/neighbor_search/ns_model.hpp @@ -197,23 +197,6 @@ class DeleteVisitor : public boost::static_visitor void operator()(NSType *ns) const; }; -/** - * SerializeVisitor serializes the given NSType instance. - */ -template -class SerializeVisitor : public boost::static_visitor -{ - private: - Archive& ar; - const std::string& name; - - public: - template - void operator()(NSType *ns) const; - - SerializeVisitor(Archive& ar, const std::string& name); -}; - /** * The NSModel class provides an easy way to serialize a model, abstracts away * the different types of trees, and also reflects the NeighborSearch API. diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp index 1001ff4b204..5ed97721cd0 100644 --- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp +++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp @@ -13,6 +13,8 @@ // In case it hasn't been included yet. #include "ns_model.hpp" +#include + namespace mlpack { namespace neighbor { @@ -200,22 +202,6 @@ void DeleteVisitor::operator()(NSType* ns) const delete ns; } -//! Save parameters for serialization. -template -SerializeVisitor::SerializeVisitor(Archive& ar, - const std::string& name) : - ar(ar), - name(name) -{} - -//! Serialize the given NSType instance. -template -template -void SerializeVisitor::operator()(NSType* ns) const -{ - ar & data::CreateNVP(ns, name); -} - /** * Initialize the NSModel with the given type and whether or not a random * basis should be used. @@ -235,6 +221,27 @@ NSModel::~NSModel() boost::apply_visitor(DeleteVisitor(), nSearch); } +/** + * Non-intrusive serialization for Neighbor Search class. We need this + * definition because we are going to use the serialize function for boost + * variant, which will look for a serialize function for its member types. + */ +template class TreeType, + template class TraversalType> +void serialize( + Archive& ar, + NeighborSearch& ns, + const unsigned int version) +{ + ns.Serialize(ar, version); +} + //! Serialize the kNN model. template template @@ -249,10 +256,8 @@ void NSModel::Serialize(Archive& ar, if (Archive::is_loading::value) boost::apply_visitor(DeleteVisitor(), nSearch); - // We'll only need to serialize one of the kNN objects, based on the type. const std::string& name = NSModelName::Name(); - SerializeVisitor s(ar, name); - boost::apply_visitor(s, nSearch); + ar & data::CreateNVP(nSearch, name); } //! Expose the dataset.