diff --git a/src/mlpack/methods/neighbor_search/knn_main.cpp b/src/mlpack/methods/neighbor_search/knn_main.cpp index f02b96c2bca..9dd77609f01 100644 --- a/src/mlpack/methods/neighbor_search/knn_main.cpp +++ b/src/mlpack/methods/neighbor_search/knn_main.cpp @@ -153,7 +153,7 @@ int main(int argc, char *argv[]) const string treeType = CLI::GetParam("tree_type"); const bool randomBasis = CLI::HasParam("random_basis"); - int tree = 0; + KNNModel::TreeTypes tree = KNNModel::KD_TREE; if (treeType == "kd") tree = KNNModel::KD_TREE; else if (treeType == "cover") diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp index df90f97968f..9c16199aabb 100644 --- a/src/mlpack/methods/neighbor_search/ns_model.hpp +++ b/src/mlpack/methods/neighbor_search/ns_model.hpp @@ -52,7 +52,7 @@ class NSModel }; private: - int treeType; + TreeTypes treeType; size_t leafSize; // For random projections. @@ -83,7 +83,7 @@ class NSModel * Initialize the NSModel with the given type and whether or not a random * basis should be used. */ - NSModel(int treeType = TreeTypes::KD_TREE, bool randomBasis = false); + NSModel(TreeTypes treeType = TreeTypes::KD_TREE, bool randomBasis = false); //! Clean memory, if necessary. ~NSModel(); @@ -105,8 +105,8 @@ class NSModel size_t LeafSize() const { return leafSize; } size_t& LeafSize() { return leafSize; } - int TreeType() const { return treeType; } - int& TreeType() { return treeType; } + TreeTypes TreeType() const { return treeType; } + TreeTypes& TreeType() { return treeType; } bool RandomBasis() const { return randomBasis; } bool& RandomBasis() { return randomBasis; } diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp index b25cf0eac0e..28c5a0bf8f0 100644 --- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp +++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp @@ -21,7 +21,7 @@ namespace neighbor { * basis should be used. */ template -NSModel::NSModel(int treeType, bool randomBasis) : +NSModel::NSModel(TreeTypes treeType, bool randomBasis) : treeType(treeType), randomBasis(randomBasis), kdTreeNS(NULL),