Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify NSModel to use boost variant. #693

Merged
merged 6 commits into from Jun 20, 2016
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
64 changes: 62 additions & 2 deletions src/mlpack/methods/neighbor_search/ns_model.hpp
Expand Up @@ -19,6 +19,9 @@
namespace mlpack {
namespace neighbor {

/**
* Alias template for euclidean neighbor search.
*/
template<typename SortPolicy,
template<typename TreeMetricType,
typename TreeStatType,
Expand Down Expand Up @@ -49,6 +52,10 @@ struct NSModelName<FurthestNeighborSort>
static const std::string Name() { return "furthest_neighbor_search_model"; }
};

/**
* MonoSearchVisitor executes a monochromatic neighbor search on the given
* NSType. We don't make any difference for different instantiations of NSType.
*/
class MonoSearchVisitor : public boost::static_visitor<void>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I hate to be picky, so I hope I am not giving too much work, but do you think you could add documentation for these classes and their members and member functions? I am not sure everyone who looks through this file will be familiar with the visitor paradigm; there is no need to explain what that is in your comments, but it may be useful to have comments along the lines of "MonoSearchVisitor executes a monochromatic neighbor search on the given NSType", or something like this. Maybe a few more words are useful to explain that better. :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I should add, if you don't want to or are busy doing other things, I can do the documentation here, it will only take a few minutes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rcurtin Sure!! I am here for this! I agree it can be confusing for someone with no knowledge of boost variants. I will add documentation.
Thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in: 6c2c3ca

{
private:
Expand All @@ -65,6 +72,12 @@ class MonoSearchVisitor : public boost::static_visitor<void>
arma::mat& distances);
};

/**
* BiSearchVisitor executes a bichromatic neighbor search on the given NSType.
* We use template specialization to differenciate those tree types that
* accept leafSize as a parameter. In these cases, before doing neighbor search,
* a query tree with proper leafSize is built from the querySet.
*/
template<typename SortPolicy>
class BiSearchVisitor : public boost::static_visitor<void>
{
Expand All @@ -75,22 +88,27 @@ class BiSearchVisitor : public boost::static_visitor<void>
arma::mat& distances;
const size_t leafSize;

//! Bichromatic neighbor search on the given NSType considering the leafSize.
template<typename NSType>
void SearchLeaf(NSType* ns) const;

public:
//! Alias template necessary for visual c++ compiler.
template<template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType>
using NSTypeT = NSType<SortPolicy, TreeType>;

//! Default Bichromatic neighbor search on the given NSType instance.
template<template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType>
void operator()(NSTypeT<TreeType>* ns) const;

//! Bichromatic neighbor search on the given NSType specialized for KDTrees.
void operator()(NSTypeT<tree::KDTree>* ns) const;

//! Bichromatic neighbor search on the given NSType specialized for BallTrees.
void operator()(NSTypeT<tree::BallTree>* ns) const;

BiSearchVisitor(const arma::mat& querySet,
Expand All @@ -100,62 +118,88 @@ class BiSearchVisitor : public boost::static_visitor<void>
const size_t leafSize);
};

/**
* TrainVisitor sets the reference set to a new reference set on the given
* NSType. We use template specialization to differenciate those tree types that
* accept leafSize as a parameter. In these cases, a reference tree with proper
* leafSize is built from the referenceSet.
*/
template<typename SortPolicy>
class TrainVisitor : public boost::static_visitor<void>
{
private:
arma::mat&& referenceSet;
size_t leafSize;

//! Train on the given NSType considering the leafSize.
template<typename NSType>
void TrainLeaf(NSType* ns) const;

public:
//! Alias template necessary for visual c++ compiler.
template<template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType>
using NSTypeT = NSType<SortPolicy, TreeType>;

//! Default Train on the given NSType instance.
template<template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType>
void operator()(NSTypeT<TreeType>* ns) const;

//! Train on the given NSType specialized for KDTrees.
void operator()(NSTypeT<tree::KDTree>* ns) const;

//! Train on the given NSType specialized for BallTrees.
void operator()(NSTypeT<tree::BallTree>* ns) const;

TrainVisitor(arma::mat&& referenceSet, const size_t leafSize);
};

/**
* SingleModeVisitor exposes the SingleMode method of the given NSType.
*/
class SingleModeVisitor : public boost::static_visitor<bool&>
{
public:
template<typename NSType>
bool& operator()(NSType* ns) const;
};

/**
* NaiveVisitor exposes the Naive method of the given NSType.
*/
class NaiveVisitor : public boost::static_visitor<bool&>
{
public:
template<typename NSType>
bool& operator()(NSType *ns) const;
};

/**
* ReferenceSetVisitor exposes the referenceSet of the given NSType.
*/
class ReferenceSetVisitor : public boost::static_visitor<const arma::mat&>
{
public:
template<typename NSType>
const arma::mat& operator()(NSType *ns) const;
};

/**
* DeleteVisitor deletes the given NSType instance.
*/
class DeleteVisitor : public boost::static_visitor<void>
{
public:
template<typename NSType>
void operator()(NSType *ns) const;
};

/**
* SerializeVisitor serializes the given NSType instance.
*/
template<typename Archive>
class SerializeVisitor : public boost::static_visitor<void>
{
Expand All @@ -170,10 +214,17 @@ class SerializeVisitor : public boost::static_visitor<void>
SerializeVisitor(Archive& ar, const std::string& name);
};

/**
* The NSModel class provides an easy way to serialize a model, abstracts away
* the different types of trees, and also reflects the NeighborSearch API.
*
* @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
*/
template<typename SortPolicy>
class NSModel
{
public:
//! Enum type to identify each accepted tree type.
enum TreeTypes
{
KD_TREE,
Expand All @@ -185,13 +236,21 @@ class NSModel
};

private:
//! Tree type considered for neighbor search.
TreeTypes treeType;

//! For tree types that accept the maxLeafSize parameter.
size_t leafSize;

// For random projections.
//! For random projections.
bool randomBasis;
arma::mat q;

/**
* nSearch holds an instance of the NeigborSearch class for the current
* treeType. It is initialized every time BuildModel is executed.
* We access to the contained value through the visitor classes defined above.
*/
boost::variant<NSType<SortPolicy, tree::KDTree>*,
NSType<SortPolicy, tree::StandardCoverTree>*,
NSType<SortPolicy, tree::RTree>*,
Expand Down Expand Up @@ -248,11 +307,12 @@ class NSModel
arma::Mat<size_t>& neighbors,
arma::mat& distances);

//! Perform neighbor search.
//! Perform monochromatic neighbor search.
void Search(const size_t k,
arma::Mat<size_t>& neighbors,
arma::mat& distances);

//! Return a string representation of the current tree type.
std::string TreeName() const;
};

Expand Down
26 changes: 20 additions & 6 deletions src/mlpack/methods/neighbor_search/ns_model_impl.hpp
Expand Up @@ -16,6 +16,7 @@
namespace mlpack {
namespace neighbor {

//! Save parameters for monochromatic neighbor search.
MonoSearchVisitor::MonoSearchVisitor(const size_t k,
arma::Mat<size_t>& neighbors,
arma::mat& distances) :
Expand All @@ -24,6 +25,7 @@ MonoSearchVisitor::MonoSearchVisitor(const size_t k,
distances(distances)
{}

//! Monochromatic neighbor search on the given NSType instance.
template<typename NSType>
void MonoSearchVisitor::operator()(NSType *ns) const
{
Expand All @@ -32,6 +34,7 @@ void MonoSearchVisitor::operator()(NSType *ns) const
throw std::runtime_error("no neighbor search model initialized");
}

//! Save parameters for bichromatic neighbor search.
template<typename SortPolicy>
BiSearchVisitor<SortPolicy>::BiSearchVisitor(const arma::mat& querySet,
const size_t k,
Expand All @@ -45,6 +48,7 @@ BiSearchVisitor<SortPolicy>::BiSearchVisitor(const arma::mat& querySet,
leafSize(leafSize)
{}

//! Default Bichromatic neighbor search on the given NSType instance.
template<typename SortPolicy>
template<template<typename TreeMetricType,
typename TreeStatType,
Expand All @@ -56,6 +60,7 @@ void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<TreeType>* ns) const
throw std::runtime_error("no neighbor search model initialized");
}

//! Bichromatic neighbor search on the given NSType specialized for KDTrees.
template<typename SortPolicy>
void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::KDTree>* ns) const
{
Expand All @@ -64,6 +69,7 @@ void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::KDTree>* ns) const
throw std::runtime_error("no neighbor search model initialized");
}

//! Bichromatic neighbor search on the given NSType specialized for BallTrees.
template<typename SortPolicy>
void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::BallTree>* ns) const
{
Expand All @@ -72,6 +78,7 @@ void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::BallTree>* ns) const
throw std::runtime_error("no neighbor search model initialized");
}

//! Bichromatic neighbor search on the given NSType considering the leafSize.
template<typename SortPolicy>
template<typename NSType>
void BiSearchVisitor<SortPolicy>::SearchLeaf(NSType* ns) const
Expand Down Expand Up @@ -99,14 +106,15 @@ void BiSearchVisitor<SortPolicy>::SearchLeaf(NSType* ns) const
ns->Search(querySet, k, neighbors, distances);
}


//! Save parameters for Train.
template<typename SortPolicy>
TrainVisitor<SortPolicy>::TrainVisitor(arma::mat&& referenceSet,
const size_t leafSize) :
referenceSet(std::move(referenceSet)),
leafSize(leafSize)
{}

//! Default Train on the given NSType instance.
template<typename SortPolicy>
template<template<typename TreeMetricType,
typename TreeStatType,
Expand All @@ -118,6 +126,7 @@ void TrainVisitor<SortPolicy>::operator()(NSTypeT<TreeType>* ns) const
throw std::runtime_error("no neighbor search model initialized");
}

//! Train on the given NSType specialized for KDTrees.
template<typename SortPolicy>
void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::KDTree>* ns) const
{
Expand All @@ -126,6 +135,7 @@ void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::KDTree>* ns) const
throw std::runtime_error("no neighbor search model initialized");
}

//! Train on the given NSType specialized for BallTrees.
template<typename SortPolicy>
void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::BallTree>* ns) const
{
Expand All @@ -134,6 +144,7 @@ void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::BallTree>* ns) const
throw std::runtime_error("no neighbor search model initialized");
}

//! Train on the given NSType considering the leafSize.
template<typename SortPolicy>
template<typename NSType>
void TrainVisitor<SortPolicy>::TrainLeaf(NSType* ns) const
Expand All @@ -154,7 +165,7 @@ void TrainVisitor<SortPolicy>::TrainLeaf(NSType* ns) const
}
}


//! Expose the SingleMode method of the given NSType.
template<typename NSType>
bool& SingleModeVisitor::operator()(NSType* ns) const
{
Expand All @@ -163,7 +174,7 @@ bool& SingleModeVisitor::operator()(NSType* ns) const
throw std::runtime_error("no neighbor search model initialized");
}


//! Expose the Naive method of the given NSType.
template<typename NSType>
bool& NaiveVisitor::operator()(NSType* ns) const
{
Expand All @@ -172,7 +183,7 @@ bool& NaiveVisitor::operator()(NSType* ns) const
throw std::runtime_error("no neighbor search model initialized");
}


//! Expose the referenceSet of the given NSType.
template<typename NSType>
const arma::mat& ReferenceSetVisitor::operator()(NSType* ns) const
{
Expand All @@ -181,22 +192,23 @@ const arma::mat& ReferenceSetVisitor::operator()(NSType* ns) const
throw std::runtime_error("no neighbor search model initialized");
}


//! Clean memory, if necessary.
template<typename NSType>
void DeleteVisitor::operator()(NSType* ns) const
{
if (ns)
delete ns;
}


//! Save parameters for serialization.
template<typename Archive>
SerializeVisitor<Archive>::SerializeVisitor(Archive& ar,
const std::string& name) :
ar(ar),
name(name)
{}

//! Serialize the given NSType instance.
template<typename Archive>
template<typename NSType>
void SerializeVisitor<Archive>::operator()(NSType* ns) const
Expand Down Expand Up @@ -243,6 +255,7 @@ void NSModel<SortPolicy>::Serialize(Archive& ar,
boost::apply_visitor(s, nSearch);
}

//! Expose the dataset.
template<typename SortPolicy>
const arma::mat& NSModel<SortPolicy>::Dataset() const
{
Expand All @@ -262,6 +275,7 @@ bool& NSModel<SortPolicy>::SingleMode()
return boost::apply_visitor(SingleModeVisitor(), nSearch);
}

//! Expose Naive.
template<typename SortPolicy>
bool NSModel<SortPolicy>::Naive() const
{
Expand Down