diff --git a/COPYRIGHT.txt b/COPYRIGHT.txt index a2f8cfcd25b..b41979c8f4d 100644 --- a/COPYRIGHT.txt +++ b/COPYRIGHT.txt @@ -93,7 +93,7 @@ Copyright: Copyright 2018, B Kartheek Reddy Copyright 2018, Atharva Khandait Copyright 2018, Wenhao Huang - Copyright 2018, Roberto Hueso + Copyright 2018-2019, Roberto Hueso Copyright 2018, Prabhat Sharma Copyright 2018, Tan Jun An Copyright 2018, Moksh Jain diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt index b1820894ed9..db569c3a37c 100644 --- a/src/mlpack/methods/CMakeLists.txt +++ b/src/mlpack/methods/CMakeLists.txt @@ -17,6 +17,7 @@ set(DIRS gmm hmm hoeffding_trees + kde kernel_pca kmeans lars diff --git a/src/mlpack/methods/kde/CMakeLists.txt b/src/mlpack/methods/kde/CMakeLists.txt new file mode 100644 index 00000000000..fa5977534b9 --- /dev/null +++ b/src/mlpack/methods/kde/CMakeLists.txt @@ -0,0 +1,23 @@ +# Define the files we need to compile. +# Anything not in this list will not be compiled into mlpack. +set(SOURCES + kde.hpp + kde_impl.hpp + kde_rules.hpp + kde_rules_impl.hpp + kde_stat.hpp + kde_model.hpp + kde_model_impl.hpp +) + +# Add directory name to sources. +set(DIR_SRCS) +foreach(file ${SOURCES}) + set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) +endforeach() +# Append sources (with directory name) to list of all mlpack sources (used at +# the parent scope). +set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE) + +add_cli_executable(kde) +add_python_binding(kde) diff --git a/src/mlpack/methods/kde/kde.hpp b/src/mlpack/methods/kde/kde.hpp new file mode 100644 index 00000000000..2691671aeab --- /dev/null +++ b/src/mlpack/methods/kde/kde.hpp @@ -0,0 +1,263 @@ +/** + * @file kde.hpp + * @author Roberto Hueso + * + * Kernel Density Estimation. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. 
If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ + +#ifndef MLPACK_METHODS_KDE_KDE_HPP +#define MLPACK_METHODS_KDE_KDE_HPP + +#include +#include + +#include "kde_stat.hpp" + +namespace mlpack { +namespace kde /** Kernel Density Estimation. */ { + +//! KDEMode represents the ways in which KDE algorithm can be executed. +enum KDEMode +{ + DUAL_TREE_MODE, + SINGLE_TREE_MODE +}; + +/** + * The KDE class is a template class for performing Kernel Density Estimations. + * In statistics, kernel density estimation is a way to estimate the + * probability density function of a variable in a non parametric way. + * This implementation performs this estimation using a tree-independent + * dual-tree algorithm. Details about this algorithm are available in KDERules. + * + * @tparam KernelType Kernel function to use for KDE calculations. + * @tparam MetricType Metric to use for KDE calculations. + * @tparam MatType Type of data to use. + * @tparam TreeType Type of tree to use; must satisfy the TreeType policy API. + * @tparam DualTreeTraversalType Type of dual-tree traversal to use. + * @tparam SingleTreeTraversalType Type of single-tree traversal to use. + */ +template class TreeType = tree::KDTree, + template class DualTreeTraversalType = + TreeType::template DualTreeTraverser, + template class SingleTreeTraversalType = + TreeType::template SingleTreeTraverser> +class KDE +{ + public: + //! Convenience typedef. + typedef TreeType Tree; + + /** + * Initialize KDE object using custom instantiated Metric and Kernel objects. + * + * @param relError Relative error tolerance of the model. + * @param absError Absolute error tolerance of the model. + * @param kernel Instantiated kernel object. + * @param mode Mode for the algorithm. + * @param metric Instantiated metric object. 
+ */ + KDE(const double relError = 0.05, + const double absError = 0, + KernelType kernel = KernelType(), + const KDEMode mode = DUAL_TREE_MODE, + MetricType metric = MetricType()); + + /** + * Construct KDE object as a copy of the given model. This may be + * computationally intensive! + * + * @param other KDE object to copy. + */ + KDE(const KDE& other); + + /** + * Construct KDE object taking ownership of the given model. + * + * @param other KDE object to take ownership of. + */ + KDE(KDE&& other); + + /** + * Copy a KDE model. + * + * Use std::move if the object to copy is no longer needed. + * + * @param other KDE model to copy. + */ + KDE& operator=(KDE other); + + /** + * Destroy the KDE object. If this object created any trees, they will be + * deleted. If you created the trees then you have to delete them yourself. + */ + ~KDE(); + + /** + * Trains the KDE model. It builds a tree using a reference set. + * + * Use std::move if the reference set is no longer needed. + * + * @param referenceSet Set of reference data. + */ + void Train(MatType referenceSet); + + /** + * Trains the KDE model. Sets the reference tree to an already created tree. + * + * - If TreeTraits::RearrangesDataset is False then it is possible + * to use an empty oldFromNewReferences vector. + * + * @param referenceTree Built reference tree. + * @param oldFromNewReferences Permutations of reference points obtained + * during tree generation. + */ + void Train(Tree* referenceTree, std::vector* oldFromNewReferences); + + /** + * Estimate density of each point in the query set given the data of the + * reference set. The result is stored in an estimations vector. + * Estimations might not be normalized. + * + * - Dimension of each point in the query set must match the dimension of each + * point in the reference set. + * + * - Use std::move if the query set is no longer needed. + * + * @pre The model has to be previously trained. + * @param querySet Set of query points to get the density of. 
+ * @param estimations Object which will hold the density of each query point. + */ + void Evaluate(MatType querySet, arma::vec& estimations); + + /** + * Estimate density of each point in the query set given the data of an + * already created query tree. The result is stored in an estimations vector. + * Estimations might not be normalized. + * + * - Dimension of each point in the queryTree dataset must match the dimension + * of each point in the reference set. + * + * - Use std::move if the query tree is no longer needed. + * + * @pre The model has to be previously trained and mode has to be dual-tree. + * @param queryTree Tree of query points to get the density of. + * @param oldFromNewQueries Mappings of query points to the tree dataset. + * @param estimations Object which will hold the density of each query point. + */ + void Evaluate(Tree* queryTree, + const std::vector& oldFromNewQueries, + arma::vec& estimations); + + /** + * Estimate density of each point in the reference set given the data of the + * reference set. It does not compute the estimation of a point with itself. + * The result is stored in an estimations vector. Estimations might not be + * normalized. + * + * @pre The model has to be previously trained. + * @param estimations Object which will hold the density of each reference + * point. + */ + void Evaluate(arma::vec& estimations); + + //! Get the kernel. + const KernelType& Kernel() const { return kernel; } + + //! Modify the kernel. + KernelType& Kernel() { return kernel; } + + //! Get the metric. + const MetricType& Metric() const { return metric; } + + //! Modify the metric. + MetricType& Metric() { return metric; } + + //! Get the reference tree. + Tree* ReferenceTree() { return referenceTree; } + + //! Get relative error tolerance. + double RelativeError() const { return relError; } + + //! Modify relative error tolerance (0 <= newError <= 1). + void RelativeError(const double newError); + + //! Get absolute error tolerance. 
+ double AbsoluteError() const { return absError; } + + //! Modify absolute error tolerance (0 <= newError). + void AbsoluteError(const double newError); + + //! Check whether reference tree is owned by the KDE model. + bool OwnsReferenceTree() const { return ownsReferenceTree; } + + //! Check whether KDE model is trained or not. + bool IsTrained() const { return trained; } + + //! Get the mode of KDE. + KDEMode Mode() const { return mode; } + + //! Modify the mode of KDE. + KDEMode& Mode() { return mode; } + + //! Serialize the model. + template + void serialize(Archive& ar, const unsigned int /* version */); + + private: + //! Kernel. + KernelType kernel; + + //! Metric. + MetricType metric; + + //! Reference tree. + Tree* referenceTree; + + //! Permutations of reference points. + std::vector* oldFromNewReferences; + + //! Relative error tolerance. + double relError; + + //! Absolute error tolerance. + double absError; + + //! If true, the KDE object is responsible for deleting the reference tree. + bool ownsReferenceTree; + + //! If true, the KDE object is trained. + bool trained; + + //! Mode of the KDE algorithm. + KDEMode mode; + + //! Check whether absolute and relative error values are compatible. + static void CheckErrorValues(const double relError, const double absError); + + //! Rearrange estimations vector if required. + static void RearrangeEstimations(const std::vector& oldFromNew, + arma::vec& estimations); +}; + +} // namespace kde +} // namespace mlpack + +// Include implementation. +#include "kde_impl.hpp" + +#endif // MLPACK_METHODS_KDE_KDE_HPP diff --git a/src/mlpack/methods/kde/kde_impl.hpp b/src/mlpack/methods/kde/kde_impl.hpp new file mode 100644 index 00000000000..5ffd8d2c500 --- /dev/null +++ b/src/mlpack/methods/kde/kde_impl.hpp @@ -0,0 +1,645 @@ +/** + * @file kde_impl.hpp + * @author Roberto Hueso + * + * Implementation of Kernel Density Estimation. 
+ * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ + +#include "kde.hpp" +#include "kde_rules.hpp" + +namespace mlpack { +namespace kde { + +//! Construct tree that rearranges the dataset. +template +TreeType* BuildTree( + MatType&& dataset, + std::vector& oldFromNew, + const typename std::enable_if< + tree::TreeTraits::RearrangesDataset>::type* = 0) +{ + return new TreeType(std::forward(dataset), oldFromNew); +} + +//! Construct tree that doesn't rearrange the dataset. +template +TreeType* BuildTree( + MatType&& dataset, + const std::vector& /* oldFromNew */, + const typename std::enable_if< + !tree::TreeTraits::RearrangesDataset>::type* = 0) +{ + return new TreeType(std::forward(dataset)); +} + +template class TreeType, + template class DualTreeTraversalType, + template class SingleTreeTraversalType> +KDE:: +KDE(const double relError, + const double absError, + KernelType kernel, + const KDEMode mode, + MetricType metric) : + kernel(kernel), + metric(metric), + referenceTree(nullptr), + oldFromNewReferences(nullptr), + relError(relError), + absError(absError), + ownsReferenceTree(false), + trained(false), + mode(mode) +{ + CheckErrorValues(relError, absError); +} + +template class TreeType, + template class DualTreeTraversalType, + template class SingleTreeTraversalType> +KDE:: +KDE(const KDE& other) : + kernel(KernelType(other.kernel)), + metric(MetricType(other.metric)), + // Start from null so that a copy of an untrained model never holds + // indeterminate pointers (the destructor may delete them). + referenceTree(nullptr), + oldFromNewReferences(nullptr), + relError(other.relError), + absError(other.absError), + ownsReferenceTree(other.ownsReferenceTree), + trained(other.trained), + mode(other.mode) +{ + if (trained) + { + if (ownsReferenceTree) + { + // The source owns its tree: make a deep copy that this object owns. + oldFromNewReferences = + new std::vector(*other.oldFromNewReferences); + referenceTree = new Tree(*other.referenceTree); + } + else + { 
+ // The source only borrows its tree: share the same external tree. + oldFromNewReferences = other.oldFromNewReferences; + referenceTree = other.referenceTree; + } + } +} + +template class TreeType, + template class DualTreeTraversalType, + template class SingleTreeTraversalType> +KDE:: +KDE(KDE&& other) : + kernel(std::move(other.kernel)), + metric(std::move(other.metric)), + referenceTree(other.referenceTree), + oldFromNewReferences(other.oldFromNewReferences), + relError(other.relError), + absError(other.absError), + ownsReferenceTree(other.ownsReferenceTree), + trained(other.trained), + mode(other.mode) +{ + other.kernel = std::move(KernelType()); + other.metric = std::move(MetricType()); + other.referenceTree = nullptr; + other.oldFromNewReferences = nullptr; + other.ownsReferenceTree = false; + other.trained = false; +} + +template class TreeType, + template class DualTreeTraversalType, + template class SingleTreeTraversalType> +KDE& +KDE:: +operator=(KDE other) +{ + // Clean memory + if (ownsReferenceTree) + { + delete referenceTree; + delete oldFromNewReferences; + } + + // Move + this->kernel = std::move(other.kernel); + this->metric = std::move(other.metric); + this->referenceTree = std::move(other.referenceTree); + this->oldFromNewReferences = std::move(other.oldFromNewReferences); + this->relError = other.relError; + this->absError = other.absError; + this->ownsReferenceTree = other.ownsReferenceTree; + this->trained = other.trained; + this->mode = other.mode; + + // 'other' is a by-value parameter that is destroyed on return. Release its + // ownership, otherwise its destructor frees the tree this object now uses. + other.referenceTree = nullptr; + other.oldFromNewReferences = nullptr; + other.ownsReferenceTree = false; + + return *this; +} + +template class TreeType, + template class DualTreeTraversalType, + template class SingleTreeTraversalType> +KDE:: +~KDE() +{ + if (ownsReferenceTree) + { + delete referenceTree; + delete oldFromNewReferences; + } +} + +template class TreeType, + template class DualTreeTraversalType, + template class SingleTreeTraversalType> +void KDE:: +Train(MatType referenceSet) +{ + // Check if referenceSet is not an empty set. 
+ if (referenceSet.n_cols == 0) + throw std::invalid_argument("cannot train KDE model with an empty " + "reference set"); + + if (ownsReferenceTree) + { + delete referenceTree; + delete oldFromNewReferences; + } + + // Drop the stale pointers and the trained flag first, so the object stays + // consistent (and safely destructible) if building the tree below throws. + this->referenceTree = nullptr; + this->oldFromNewReferences = nullptr; + this->trained = false; + + this->ownsReferenceTree = true; + Timer::Start("building_reference_tree"); + this->oldFromNewReferences = new std::vector; + this->referenceTree = BuildTree(std::move(referenceSet), + *oldFromNewReferences); + Timer::Stop("building_reference_tree"); + this->trained = true; +} + +template class TreeType, + template class DualTreeTraversalType, + template class SingleTreeTraversalType> +void KDE:: +Train(Tree* referenceTree, std::vector* oldFromNewReferences) +{ + // Check if referenceTree dataset is not an empty set. + if (referenceTree->Dataset().n_cols == 0) + throw std::invalid_argument("cannot train KDE model with an empty " + "reference set"); + + if (ownsReferenceTree == true) + { + delete this->referenceTree; + delete this->oldFromNewReferences; + } + + // The caller keeps ownership of the given tree and mappings; this object + // only stores the pointers and will not delete them. + this->ownsReferenceTree = false; + this->referenceTree = referenceTree; + this->oldFromNewReferences = oldFromNewReferences; + this->trained = true; +} + +template class TreeType, + template class DualTreeTraversalType, + template class SingleTreeTraversalType> +void KDE:: +Evaluate(MatType querySet, arma::vec& estimations) +{ + if (mode == DUAL_TREE_MODE) + { + Timer::Start("building_query_tree"); + std::vector oldFromNewQueries; + Tree* queryTree = BuildTree(std::move(querySet), oldFromNewQueries); + Timer::Stop("building_query_tree"); + // Evaluate(Tree*, ...) throws on an untrained model or on a dimension + // mismatch; free the temporary query tree in that case too. + try + { + this->Evaluate(queryTree, oldFromNewQueries, estimations); + } + catch (...) + { + delete queryTree; + throw; + } + delete queryTree; + } + else if (mode == SINGLE_TREE_MODE) + { + // Get estimations vector ready. + estimations.clear(); + estimations.set_size(querySet.n_cols); + estimations.fill(arma::fill::zeros); + + // Check whether has already been trained. 
+ if (!trained) + { + throw std::runtime_error("cannot evaluate KDE model: model needs to be " + "trained before evaluation"); + } + + // Check querySet has at least 1 element to evaluate. + if (querySet.n_cols == 0) + { + Log::Warn << "KDE::Evaluate(): querySet is empty, no predictions will " + << "be returned" << std::endl; + return; + } + + // Check whether dimensions match. + if (querySet.n_rows != referenceTree->Dataset().n_rows) + { + throw std::invalid_argument("cannot evaluate KDE model: querySet and " + "referenceSet dimensions don't match"); + } + + Timer::Start("computing_kde"); + // Evaluate + typedef KDERules RuleType; + RuleType rules = RuleType(referenceTree->Dataset(), + querySet, + estimations, + relError, + absError, + metric, + kernel, + false); + + // Create traverser. + SingleTreeTraversalType traverser(rules); + + // Traverse for each point. + for (size_t i = 0; i < querySet.n_cols; ++i) + traverser.Traverse(i, *referenceTree); + + estimations /= referenceTree->Dataset().n_cols; + Timer::Stop("computing_kde"); + + Log::Info << rules.Scores() << " node combinations were scored." + << std::endl; + Log::Info << rules.BaseCases() << " base cases were calculated." + << std::endl; + } +} + +template class TreeType, + template class DualTreeTraversalType, + template class SingleTreeTraversalType> +void KDE:: +Evaluate(Tree* queryTree, + const std::vector& oldFromNewQueries, + arma::vec& estimations) +{ + // Get estimations vector ready. + estimations.clear(); + estimations.set_size(queryTree->Dataset().n_cols); + estimations.fill(arma::fill::zeros); + + // Check whether has already been trained. + if (!trained) + { + throw std::runtime_error("cannot evaluate KDE model: model needs to be " + "trained before evaluation"); + } + + // Check querySet has at least 1 element to evaluate. 
+ if (queryTree->Dataset().n_cols == 0) + { + Log::Warn << "KDE::Evaluate(): querySet is empty, no predictions will " + << "be returned" << std::endl; + return; + } + + // Check whether dimensions match. + if (queryTree->Dataset().n_rows != referenceTree->Dataset().n_rows) + { + throw std::invalid_argument("cannot evaluate KDE model: querySet and " + "referenceSet dimensions don't match"); + } + + // Check the mode is correct. + if (mode != DUAL_TREE_MODE) + { + throw std::invalid_argument("cannot evaluate KDE model: cannot use " + "a query tree when mode is different from " + "dual-tree"); + } + + Timer::Start("computing_kde"); + + // Evaluate. + typedef KDERules RuleType; + RuleType rules = RuleType(referenceTree->Dataset(), + queryTree->Dataset(), + estimations, + relError, + absError, + metric, + kernel, + false); + + // Create traverser. + DualTreeTraversalType traverser(rules); + traverser.Traverse(*queryTree, *referenceTree); + estimations /= referenceTree->Dataset().n_cols; + Timer::Stop("computing_kde"); + + // Rearrange if necessary. + RearrangeEstimations(oldFromNewQueries, estimations); + + Log::Info << rules.Scores() << " node combinations were scored." << std::endl; + Log::Info << rules.BaseCases() << " base cases were calculated." << std::endl; +} + +template class TreeType, + template class DualTreeTraversalType, + template class SingleTreeTraversalType> +void KDE:: +Evaluate(arma::vec& estimations) +{ + // Check whether has already been trained. + if (!trained) + { + throw std::runtime_error("cannot evaluate KDE model: model needs to be " + "trained before evaluation"); + } + + // Get estimations vector ready. 
+ estimations.clear(); + estimations.set_size(referenceTree->Dataset().n_cols); + estimations.fill(arma::fill::zeros); + + Timer::Start("computing_kde"); + // Evaluate + typedef KDERules RuleType; + RuleType rules = RuleType(referenceTree->Dataset(), + referenceTree->Dataset(), + estimations, + relError, + absError, + metric, + kernel, + true); + + if (mode == DUAL_TREE_MODE) + { + // Create traverser. + DualTreeTraversalType traverser(rules); + traverser.Traverse(*referenceTree, *referenceTree); + } + else if (mode == SINGLE_TREE_MODE) + { + SingleTreeTraversalType traverser(rules); + for (size_t i = 0; i < referenceTree->Dataset().n_cols; ++i) + traverser.Traverse(i, *referenceTree); + } + + estimations /= referenceTree->Dataset().n_cols; + // Rearrange if necessary. + RearrangeEstimations(*oldFromNewReferences, estimations); + Timer::Stop("computing_kde"); + + Log::Info << rules.Scores() << " node combinations were scored." << std::endl; + Log::Info << rules.BaseCases() << " base cases were calculated." << std::endl; +} + +template class TreeType, + template class DualTreeTraversalType, + template class SingleTreeTraversalType> +void KDE:: +RelativeError(const double newError) +{ + CheckErrorValues(newError, absError); + relError = newError; +} + +template class TreeType, + template class DualTreeTraversalType, + template class SingleTreeTraversalType> +void KDE:: +AbsoluteError(const double newError) +{ + CheckErrorValues(relError, newError); + absError = newError; +} + +template class TreeType, + template class DualTreeTraversalType, + template class SingleTreeTraversalType> +template +void KDE:: +serialize(Archive& ar, const unsigned int /* version */) +{ + // Serialize preferences. + ar & BOOST_SERIALIZATION_NVP(relError); + ar & BOOST_SERIALIZATION_NVP(absError); + ar & BOOST_SERIALIZATION_NVP(trained); + ar & BOOST_SERIALIZATION_NVP(mode); + + // If we are loading, clean up memory if necessary. 
+ if (Archive::is_loading::value) + { + if (ownsReferenceTree && referenceTree) + { + delete referenceTree; + delete oldFromNewReferences; + } + // After loading tree, we own it. + ownsReferenceTree = true; + } + + // Serialize the rest of values. + ar & BOOST_SERIALIZATION_NVP(kernel); + ar & BOOST_SERIALIZATION_NVP(metric); + ar & BOOST_SERIALIZATION_NVP(referenceTree); + ar & BOOST_SERIALIZATION_NVP(oldFromNewReferences); +} + +template class TreeType, + template class DualTreeTraversalType, + template class SingleTreeTraversalType> +void KDE:: +CheckErrorValues(const double relError, const double absError) +{ + if (relError < 0 || relError > 1) + { + throw std::invalid_argument("Relative error tolerance must be a value " + "between 0 and 1"); + } + if (absError < 0) + { + throw std::invalid_argument("Absolute error tolerance must be a value " + "greater or equal to 0"); + } +} + +template class TreeType, + template class DualTreeTraversalType, + template class SingleTreeTraversalType> +void KDE:: +RearrangeEstimations(const std::vector& oldFromNew, + arma::vec& estimations) +{ + if (tree::TreeTraits::RearrangesDataset) + { + const size_t nQueries = oldFromNew.size(); + arma::vec rearrangedEstimations(nQueries); + + // Remap vector. + for (size_t i = 0; i < nQueries; ++i) + rearrangedEstimations(oldFromNew.at(i)) = estimations(i); + + estimations = std::move(rearrangedEstimations); + } +} + +} // namespace kde +} // namespace mlpack diff --git a/src/mlpack/methods/kde/kde_main.cpp b/src/mlpack/methods/kde/kde_main.cpp new file mode 100644 index 00000000000..6f9b09bf0ec --- /dev/null +++ b/src/mlpack/methods/kde/kde_main.cpp @@ -0,0 +1,208 @@ +/** + * @file kde_main.cpp + * @author Roberto Hueso + * + * Executable for running Kernel Density Estimation. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. 
If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ + +#include + +#include "kde.hpp" +#include "kde_model.hpp" + +using namespace mlpack; +using namespace mlpack::kde; +using namespace mlpack::util; +using namespace std; + +// Define parameters for the executable. +PROGRAM_INFO("Kernel Density Estimation", + "This program performs a Kernel Density Estimation. KDE is a " + "non-parametric way of estimating probability density function. " + "For each query point the program will estimate its probability density " + "by applying a kernel function to each reference point. The computational " + "complexity of this is O(N^2) where there are N query points and N " + "reference points, but this implementation will typically see better " + "performance as it uses an approximate dual or single tree algorithm for " + "acceleration." + "\n\n" + "Dual or single tree optimization allows to avoid lots of barely relevant " + "calculations (as kernel function values decrease with distance), so it is " + "an approximate computation. You can specify the maximum relative error " + "tolerance for each query value with " + PRINT_PARAM_STRING("rel_error") + + " as well as the maximum absolute error tolerance with the parameter " + + PRINT_PARAM_STRING("abs_error") + ". This program runs using an Euclidean " + "metric. Kernel function can be selected using the " + + PRINT_PARAM_STRING("kernel") + " option. You can also choose what which " + "type of tree to use for the dual-tree algorithm with " + + PRINT_PARAM_STRING("tree") + ". It is also possible to select whether to " + "use dual-tree algorithm or single-tree algorithm using the " + + PRINT_PARAM_STRING("algorithm") + " option." + "\n\n" + "For example, the following will run KDE using the data in " + + PRINT_DATASET("ref_data") + " for training and the data in " + + PRINT_DATASET("qu_data") + " as query data. 
It will apply an Epanechnikov " + "kernel with a 0.2 bandwidth to each reference point and use a KD-Tree for " + "the dual-tree optimization. The returned predictions will be within 5% of " + "the real KDE value for each query point." + "\n\n" + + PRINT_CALL("kde", "reference", "ref_data", "query", "qu_data", "bandwidth", + 0.2, "kernel", "epanechnikov", "tree", "kd-tree", "rel_error", + 0.05, "predictions", "out_data") + + "\n\n" + "the predicted density estimations will be stored in " + + PRINT_DATASET("out_data") + "." + "\n" + "If no " + PRINT_PARAM_STRING("query") + " is provided, then KDE will be " + "computed on the " + PRINT_PARAM_STRING("reference") + " dataset." + "\n" + "It is possible to select either a reference dataset or an input model " + "but not both at the same time."); + +// Required options. +PARAM_MATRIX_IN("reference", "Input reference dataset use for KDE.", "r"); +PARAM_MATRIX_IN("query", "Query dataset to KDE on.", "q"); +PARAM_DOUBLE_IN("bandwidth", "Bandwidth of the kernel.", "b", 1.0); + +// Load or save models. +PARAM_MODEL_IN(KDEModel, + "input_model", + "Contains pre-trained KDE model.", + "m"); +PARAM_MODEL_OUT(KDEModel, + "output_model", + "If specified, the KDE model will be saved here.", + "M"); + +// Configuration options. +PARAM_STRING_IN("kernel", "Kernel to use for the prediction." + "('gaussian', 'epanechnikov', 'laplacian', 'spherical', 'triangular').", + "k", "gaussian"); +PARAM_STRING_IN("tree", "Tree to use for the prediction." + "('kd-tree', 'ball-tree', 'cover-tree', 'octree', 'r-tree').", + "t", "kd-tree"); +PARAM_STRING_IN("algorithm", "Algorithm to use for the prediction." + "('dual-tree', 'single-tree').", + "a", "dual-tree"); +PARAM_DOUBLE_IN("rel_error", + "Relative error tolerance for the prediction.", + "e", + 0.05); +PARAM_DOUBLE_IN("abs_error", + "Relative error tolerance for the prediction.", + "E", + 0.0); + +// Output predictions options. 
+PARAM_COL_OUT("predictions", "Vector to store density predictions.", + "p"); + +// Maybe, in the future, it could be interesting to implement different metrics. + +static void mlpackMain() +{ + // Get some parameters. + const double bandwidth = CLI::GetParam("bandwidth"); + const std::string kernelStr = CLI::GetParam("kernel"); + const std::string treeStr = CLI::GetParam("tree"); + const std::string modeStr = CLI::GetParam("algorithm"); + const double relError = CLI::GetParam("rel_error"); + const double absError = CLI::GetParam("abs_error"); + + // Initialize results vector. + arma::vec estimations; + + // You can only specify reference data or a pre-trained model. + RequireOnlyOnePassed({ "reference", "input_model" }, true); + ReportIgnoredParam({{ "input_model", true }}, "tree"); + ReportIgnoredParam({{ "input_model", true }}, "kernel"); + ReportIgnoredParam({{ "input_model", true }}, "rel_error"); + ReportIgnoredParam({{ "input_model", true }}, "abs_error"); + + // Requirements for parameter values. + RequireParamInSet("kernel", { "gaussian", "epanechnikov", + "laplacian", "spherical", "triangular" }, true, "unknown kernel type"); + RequireParamInSet("tree", { "kd-tree", "ball-tree", "cover-tree", + "octree", "r-tree"}, true, "unknown tree type"); + RequireParamInSet("algorithm", { "dual-tree", "single-tree"}, + true, "unknown algorithm"); + RequireParamValue("rel_error", [](double x){return x >= 0 && x <= 1;}, + true, "relative error must be between 0 and 1"); + RequireParamValue("abs_error", [](double x){return x >= 0;}, + true, "absolute error must be equal or greater than 0"); + + KDEModel* kde; + + if (CLI::HasParam("reference")) + { + arma::mat reference = std::move(CLI::GetParam("reference")); + + kde = new KDEModel(); + // Set parameters. + kde->Bandwidth() = bandwidth; + kde->RelativeError() = relError; + kde->AbsoluteError() = absError; + + // Set KernelType. 
+ if (kernelStr == "gaussian") + kde->KernelType() = KDEModel::GAUSSIAN_KERNEL; + else if (kernelStr == "epanechnikov") + kde->KernelType() = KDEModel::EPANECHNIKOV_KERNEL; + else if (kernelStr == "laplacian") + kde->KernelType() = KDEModel::LAPLACIAN_KERNEL; + else if (kernelStr == "spherical") + kde->KernelType() = KDEModel::SPHERICAL_KERNEL; + else if (kernelStr == "triangular") + kde->KernelType() = KDEModel::TRIANGULAR_KERNEL; + + // Set TreeType. + if (treeStr == "kd-tree") + kde->TreeType() = KDEModel::KD_TREE; + else if (treeStr == "ball-tree") + kde->TreeType() = KDEModel::BALL_TREE; + else if (treeStr == "cover-tree") + kde->TreeType() = KDEModel::COVER_TREE; + else if (treeStr == "octree") + kde->TreeType() = KDEModel::OCTREE; + else if (treeStr == "r-tree") + kde->TreeType() = KDEModel::R_TREE; + + // Build model. + kde->BuildModel(std::move(reference)); + + // Set Mode. + if (modeStr == "dual-tree") + kde->Mode() = KDEMode::DUAL_TREE_MODE; + else if (modeStr == "single-tree") + kde->Mode() = KDEMode::SINGLE_TREE_MODE; + } + else + { + // Load model. + kde = CLI::GetParam("input_model"); + } + + // Evaluation. + if (CLI::HasParam("query")) + { + arma::mat query = std::move(CLI::GetParam("query")); + kde->Evaluate(std::move(query), estimations); + } + else + { + kde->Evaluate(estimations); + } + + // Output predictions if needed. + if (CLI::HasParam("predictions")) + CLI::GetParam("predictions") = std::move(estimations); + + // Save model. + if (CLI::HasParam("output_model")) + CLI::GetParam("output_model") = kde; +} diff --git a/src/mlpack/methods/kde/kde_model.hpp b/src/mlpack/methods/kde/kde_model.hpp new file mode 100644 index 00000000000..89d49e25782 --- /dev/null +++ b/src/mlpack/methods/kde/kde_model.hpp @@ -0,0 +1,381 @@ +/** + * @file kde_model.hpp + * @author Roberto Hueso + * + * Model for KDE. It abstracts different types of tree, kernels, etc. 
+ * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#ifndef MLPACK_METHODS_KDE_MODEL_HPP +#define MLPACK_METHODS_KDE_MODEL_HPP + +// Include trees. +#include +#include +#include +#include + +// Include core. +#include + +// Remaining includes. +#include +#include "kde.hpp" + +namespace mlpack { +namespace kde { + +//! Alias template. +template class TreeType> +using KDEType = KDE::template DualTreeTraverser, + TreeType::template SingleTreeTraverser>; +/** + * KernelNormalizer holds a set of methods to normalize estimations applying + * in each case the appropriate kernel normalizer function. + */ +class KernelNormalizer +{ + private: + // SFINAE check if Normalizer function is present. + HAS_MEM_FUNC(Normalizer, HasNormalizer); + + public: + //! No-op overload, selected when KernelType has no Normalizer() member: + //! the estimations are left untouched. + template + static void ApplyNormalizer( + KernelType& /* kernel */, + const size_t /* dimension */, + arma::vec& /* estimations */, + const typename std::enable_if< + !HasNormalizer::value>:: + type* = 0) + { return; } + + //! Overload selected when KernelType has a Normalizer() member: divides the + //! estimations by kernel.Normalizer(dimension). + template + static void ApplyNormalizer( + KernelType& kernel, + const size_t dimension, + arma::vec& estimations, + const typename std::enable_if< + HasNormalizer::value>:: + type* = 0) + { + estimations /= kernel.Normalizer(dimension); + } +}; + +/** + * DualMonoKDE computes a Kernel Density Estimation on the given KDEType. + * It performs a monochromatic KDE. + */ +class DualMonoKDE : public boost::static_visitor +{ + private: + //! Vector to store the KDE results. + arma::vec& estimations; + + public: + //! Alias template necessary for Visual C++ compiler. + template class TreeType> + using KDETypeT = KDEType; + + //! Default DualMonoKDE on some KDEType. 
+ template class TreeType> + void operator()(KDETypeT* kde) const; + + // TODO Implement specific cases where a leaf size can be selected. + + //! DualMonoKDE constructor. + DualMonoKDE(arma::vec& estimations); +}; + +/** + * DualBiKDE computes a Kernel Density Estimation on the given KDEType. + * It performs a bichromatic KDE. + */ +class DualBiKDE : public boost::static_visitor +{ + private: + //! Query set dimensionality. + const size_t dimension; + + //! The query set for the KDE (held by reference, not owned). + const arma::mat& querySet; + + //! Vector to store the KDE results. + arma::vec& estimations; + + public: + //! Alias template necessary for Visual C++ compiler. + template class TreeType> + using KDETypeT = KDEType; + + //! Default DualBiKDE on some KDEType. + template class TreeType> + void operator()(KDETypeT* kde) const; + + // TODO Implement specific cases where a leaf size can be selected. + + //! DualBiKDE constructor. Only stores a reference to querySet. + DualBiKDE(arma::mat&& querySet, arma::vec& estimations); +}; + +/** + * TrainVisitor trains a given KDEType using a reference set. + */ +class TrainVisitor : public boost::static_visitor +{ + private: + //! The reference set used for training (rvalue reference, not owned). + arma::mat&& referenceSet; + + public: + //! Default TrainVisitor on some KDEType. + template class TreeType> + void operator()(KDEType* kde) const; + + // TODO Implement specific cases where a leaf size can be selected. + + //! TrainVisitor constructor. Only stores a reference to referenceSet. + TrainVisitor(arma::mat&& referenceSet); +}; + +/** + * ModeVisitor exposes the Mode() method of the KDEType. + */ +class ModeVisitor : public boost::static_visitor +{ + public: + //! Return mode of KDEType instance. + template + KDEMode& operator()(KDEType* kde) const; +}; + +class DeleteVisitor : public boost::static_visitor +{ + public: + //! Delete KDEType instance.
+ template + void operator()(KDEType* kde) const; +}; + +class KDEModel +{ + public: + enum TreeTypes + { + KD_TREE, + BALL_TREE, + COVER_TREE, + OCTREE, + R_TREE + }; + + enum KernelTypes + { + GAUSSIAN_KERNEL, + EPANECHNIKOV_KERNEL, + LAPLACIAN_KERNEL, + SPHERICAL_KERNEL, + TRIANGULAR_KERNEL + }; + + private: + //! Bandwidth of the kernel. + double bandwidth; + + //! Relative error tolerance. + double relError; + + //! Absolute error tolerance. + double absError; + + //! Type of kernel. + KernelTypes kernelType; + + //! Type of tree. + TreeTypes treeType; + + /** + * kdeModel holds a pointer to the KDE instance of exactly one combination of + * KernelType and TreeType. It is initialized using BuildModel. + */ + boost::variant*, + KDEType*, + KDEType*, + KDEType*, + KDEType*, + KDEType*, + KDEType*, + KDEType*, + KDEType*, + KDEType*, + KDEType*, + KDEType*, + KDEType*, + KDEType*, + KDEType*, + KDEType*, + KDEType*, + KDEType*, + KDEType*, + KDEType*, + KDEType*, + KDEType*, + KDEType*, + KDEType*, + KDEType*> kdeModel; + + public: + /** + * Initialize KDEModel. + * + * @param bandwidth Bandwidth to use for the kernel. + * @param relError Maximum relative error tolerance for each point in the + * model. For example, 0.05 means that each value must be + * within 5% of the true KDE value. + * @param absError Maximum absolute error tolerance for each point in the + * model. For example, 0.1 means that for each point the + * value can have a maximum error of 0.1 units. + * @param kernelType Type of kernel to use. + * @param treeType Type of tree to use. + */ + KDEModel(const double bandwidth = 1.0, + const double relError = 0.05, + const double absError = 0, + const KernelTypes kernelType = KernelTypes::GAUSSIAN_KERNEL, + const TreeTypes treeType = TreeTypes::KD_TREE); + + //! Copy constructor. NOTE(review): copies parameters only, not kdeModel. + KDEModel(const KDEModel& other); + + //! Move constructor of the given model. Takes ownership of the model.
+ KDEModel(KDEModel&& other); + + /** + * Copy the given model. + * + * Use std::move if the object to copy is no longer needed. + * + * @param other KDEModel to copy. + */ + KDEModel& operator=(KDEModel other); + + //! Destroy the KDEModel object. + ~KDEModel(); + + //! Serialize the KDE model. + template + void serialize(Archive& ar, const unsigned int /* version */); + + //! Get the bandwidth of the kernel. + double Bandwidth() const { return bandwidth; } + + //! Modify the bandwidth of the kernel. + double& Bandwidth() { return bandwidth; } + + //! Get the relative error tolerance. + double RelativeError() const { return relError; } + + //! Modify the relative error tolerance. + double& RelativeError() { return relError; } + + //! Get the absolute error tolerance. + double AbsoluteError() const { return absError; } + + //! Modify the absolute error tolerance. + double& AbsoluteError() { return absError; } + + //! Get the tree type of the model. + TreeTypes TreeType() const { return treeType; } + + //! Modify the tree type of the model. + TreeTypes& TreeType() { return treeType; } + + //! Get the kernel type of the model. + KernelTypes KernelType() const { return kernelType; } + + //! Modify the kernel type of the model. + KernelTypes& KernelType() { return kernelType; } + + //! Get the mode of the model. + KDEMode Mode() const; + + //! Modify the mode of the model. + KDEMode& Mode(); + + /** + * Build the KDE model with the given parameters and then trains it with the + * given reference data. + * Takes possession of the reference set to avoid a copy, so the reference set + * will not be usable after this. + * + * @param referenceSet Set of reference points. + */ + void BuildModel(arma::mat&& referenceSet); + + /** + * Perform kernel density estimation on the given query set. + * Takes possession of the query set to avoid a copy, so the query set + * will not be usable after this. If possible, it returns normalized + * estimations. 
+ * + * @pre The model has to be previously created with BuildModel. + * @param querySet Set of query points. + * @param estimations Vector where the results will be stored in the same + * order as the query points. + */ + void Evaluate(arma::mat&& querySet, arma::vec& estimations); + + /** + * Perform kernel density estimation on the reference set. + * If possible, it returns normalized estimations. + * + * @pre The model has to be previously created with BuildModel. + * @param estimations Vector where the results will be stored in the same + * order as the reference points. + */ + void Evaluate(arma::vec& estimations); + + + private: + //! Clean memory. + void CleanMemory(); +}; + +} // namespace kde +} // namespace mlpack + +#include "kde_model_impl.hpp" + +#endif diff --git a/src/mlpack/methods/kde/kde_model_impl.hpp b/src/mlpack/methods/kde/kde_model_impl.hpp new file mode 100644 index 00000000000..a4ab7236284 --- /dev/null +++ b/src/mlpack/methods/kde/kde_model_impl.hpp @@ -0,0 +1,365 @@ +/** + * @file kde_model_impl.hpp + * @author Roberto Hueso + * + * Implementation of KDE Model. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#ifndef MLPACK_METHODS_KDE_MODEL_IMPL_HPP +#define MLPACK_METHODS_KDE_MODEL_IMPL_HPP + +// In case it hasn't been included yet. +#include "kde_model.hpp" + +#include + +namespace mlpack { +namespace kde { + +//! Initialize the KDEModel with the given parameters. +inline KDEModel::KDEModel(const double bandwidth, + const double relError, + const double absError, + const KernelTypes kernelType, + const TreeTypes treeType) : + bandwidth(bandwidth), + relError(relError), + absError(absError), + kernelType(kernelType), + treeType(treeType) +{ + // Nothing to do. +} + +// Copy constructor.
+inline KDEModel::KDEModel(const KDEModel& other) : + bandwidth(other.bandwidth), + relError(other.relError), + absError(other.absError), + kernelType(other.kernelType), + treeType(other.treeType) +{ + // NOTE(review): kdeModel is not copied here; only the parameters are. +} + +// Move constructor. +inline KDEModel::KDEModel(KDEModel&& other) : + bandwidth(other.bandwidth), + relError(other.relError), + absError(other.absError), + kernelType(other.kernelType), + treeType(other.treeType), + kdeModel(std::move(other.kdeModel)) +{ + // Reset other model. + other.bandwidth = 1.0; + other.relError = 0.05; + other.absError = 0; + other.kernelType = KernelTypes::GAUSSIAN_KERNEL; + other.treeType = TreeTypes::KD_TREE; + other.kdeModel = decltype(other.kdeModel)(); +} + +inline KDEModel& KDEModel::operator=(KDEModel other) +{ + // Copy-and-swap: `other` frees our previous model when it is destroyed. + bandwidth = other.bandwidth; + relError = other.relError; + absError = other.absError; + kernelType = other.kernelType; + treeType = other.treeType; + kdeModel.swap(other.kdeModel); + return *this; +} + +// Clean memory. +inline KDEModel::~KDEModel() +{ + boost::apply_visitor(DeleteVisitor(), kdeModel); +} + +inline void KDEModel::BuildModel(arma::mat&& referenceSet) +{ + // Clean memory, if necessary. + boost::apply_visitor(DeleteVisitor(), kdeModel); + + // Build the actual model.
+ if (kernelType == GAUSSIAN_KERNEL && treeType == KD_TREE) + { + kdeModel = new KDEType + (relError, absError, kernel::GaussianKernel(bandwidth)); + } + else if (kernelType == GAUSSIAN_KERNEL && treeType == BALL_TREE) + { + kdeModel = new KDEType + (relError, absError, kernel::GaussianKernel(bandwidth)); + } + else if (kernelType == GAUSSIAN_KERNEL && treeType == COVER_TREE) + { + kdeModel = new KDEType + (relError, absError, kernel::GaussianKernel(bandwidth)); + } + else if (kernelType == GAUSSIAN_KERNEL && treeType == OCTREE) + { + kdeModel = new KDEType + (relError, absError, kernel::GaussianKernel(bandwidth)); + } + else if (kernelType == GAUSSIAN_KERNEL && treeType == R_TREE) + { + kdeModel = new KDEType + (relError, absError, kernel::GaussianKernel(bandwidth)); + } + else if (kernelType == EPANECHNIKOV_KERNEL && treeType == KD_TREE) + { + kdeModel = new KDEType + (relError, absError, kernel::EpanechnikovKernel(bandwidth)); + } + else if (kernelType == EPANECHNIKOV_KERNEL && treeType == BALL_TREE) + { + kdeModel = new KDEType + (relError, absError, kernel::EpanechnikovKernel(bandwidth)); + } + else if (kernelType == EPANECHNIKOV_KERNEL && treeType == COVER_TREE) + { + kdeModel = new KDEType + (relError, absError, kernel::EpanechnikovKernel(bandwidth)); + } + else if (kernelType == EPANECHNIKOV_KERNEL && treeType == OCTREE) + { + kdeModel = new KDEType + (relError, absError, kernel::EpanechnikovKernel(bandwidth)); + } + else if (kernelType == EPANECHNIKOV_KERNEL && treeType == R_TREE) + { + kdeModel = new KDEType + (relError, absError, kernel::EpanechnikovKernel(bandwidth)); + } + else if (kernelType == LAPLACIAN_KERNEL && treeType == KD_TREE) + { + kdeModel = new KDEType + (relError, absError, kernel::LaplacianKernel(bandwidth)); + } + else if (kernelType == LAPLACIAN_KERNEL && treeType == BALL_TREE) + { + kdeModel = new KDEType + (relError, absError, kernel::LaplacianKernel(bandwidth)); + } + else if (kernelType == LAPLACIAN_KERNEL && treeType == COVER_TREE) 
+ { + kdeModel = new KDEType + (relError, absError, kernel::LaplacianKernel(bandwidth)); + } + else if (kernelType == LAPLACIAN_KERNEL && treeType == OCTREE) + { + kdeModel = new KDEType + (relError, absError, kernel::LaplacianKernel(bandwidth)); + } + else if (kernelType == LAPLACIAN_KERNEL && treeType == R_TREE) + { + kdeModel = new KDEType + (relError, absError, kernel::LaplacianKernel(bandwidth)); + } + else if (kernelType == SPHERICAL_KERNEL && treeType == KD_TREE) + { + kdeModel = new KDEType + (relError, absError, kernel::SphericalKernel(bandwidth)); + } + else if (kernelType == SPHERICAL_KERNEL && treeType == BALL_TREE) + { + kdeModel = new KDEType + (relError, absError, kernel::SphericalKernel(bandwidth)); + } + else if (kernelType == SPHERICAL_KERNEL && treeType == COVER_TREE) + { + kdeModel = new KDEType + (relError, absError, kernel::SphericalKernel(bandwidth)); + } + else if (kernelType == SPHERICAL_KERNEL && treeType == OCTREE) + { + kdeModel = new KDEType + (relError, absError, kernel::SphericalKernel(bandwidth)); + } + else if (kernelType == SPHERICAL_KERNEL && treeType == R_TREE) + { + kdeModel = new KDEType + (relError, absError, kernel::SphericalKernel(bandwidth)); + } + else if (kernelType == TRIANGULAR_KERNEL && treeType == KD_TREE) + { + kdeModel = new KDEType + (relError, absError, kernel::TriangularKernel(bandwidth)); + } + else if (kernelType == TRIANGULAR_KERNEL && treeType == BALL_TREE) + { + kdeModel = new KDEType + (relError, absError, kernel::TriangularKernel(bandwidth)); + } + else if (kernelType == TRIANGULAR_KERNEL && treeType == COVER_TREE) + { + kdeModel = new KDEType + (relError, absError, kernel::TriangularKernel(bandwidth)); + } + else if (kernelType == TRIANGULAR_KERNEL && treeType == OCTREE) + { + kdeModel = new KDEType + (relError, absError, kernel::TriangularKernel(bandwidth)); + } + else if (kernelType == TRIANGULAR_KERNEL && treeType == R_TREE) + { + kdeModel = new KDEType + (relError, absError, 
kernel::TriangularKernel(bandwidth)); + } + + // Train the model. + TrainVisitor train(std::move(referenceSet)); + boost::apply_visitor(train, kdeModel); +} + +// Perform bichromatic evaluation. +inline void KDEModel::Evaluate(arma::mat&& querySet, arma::vec& estimations) +{ + Log::Info << "Evaluating KDE..." << std::endl; + DualBiKDE eval(std::move(querySet), estimations); + boost::apply_visitor(eval, kdeModel); +} + +// Perform monochromatic evaluation. +inline void KDEModel::Evaluate(arma::vec& estimations) +{ + Log::Info << "Evaluating KDE..." << std::endl; + DualMonoKDE eval(estimations); + boost::apply_visitor(eval, kdeModel); +} + +// Clean memory. +inline void KDEModel::CleanMemory() +{ + boost::apply_visitor(DeleteVisitor(), kdeModel); +} + +// Parameters for KDE evaluation (inline: this definition lives in a header). +inline DualMonoKDE::DualMonoKDE(arma::vec& estimations): + estimations(estimations) +{} + +// Default KDE evaluation. +template class TreeType> +void DualMonoKDE::operator()(KDETypeT* kde) const +{ + if (kde) + { + kde->Evaluate(estimations); + const size_t dimension = (kde->ReferenceTree())->Dataset().n_rows; + KernelNormalizer::ApplyNormalizer(kde->Kernel(), + dimension, + estimations); + } + else + { + throw std::runtime_error("no KDE model initialized"); + } +} + +// Parameters for KDE evaluation (inline: this definition lives in a header). +inline DualBiKDE::DualBiKDE(arma::mat&& querySet, arma::vec& estimations): + dimension(querySet.n_rows), + querySet(std::move(querySet)), + estimations(estimations) +{} + +// Default KDE evaluation. +template class TreeType> +void DualBiKDE::operator()(KDETypeT* kde) const +{ + if (kde) + { + kde->Evaluate(std::move(querySet), estimations); + KernelNormalizer::ApplyNormalizer(kde->Kernel(), + dimension, + estimations); + } + else + { + throw std::runtime_error("no KDE model initialized"); + } +} + +// Parameters for Train (inline: this definition lives in a header). +inline TrainVisitor::TrainVisitor(arma::mat&& referenceSet) : + referenceSet(std::move(referenceSet)) +{} + +// Default Train.
+template class TreeType> +void TrainVisitor::operator()(KDEType* kde) const +{ + Log::Info << "Training KDE model..." << std::endl; + if (kde) + kde->Train(std::move(referenceSet)); + else + throw std::runtime_error("no KDE model initialized"); +} + +// Delete model. +template +void DeleteVisitor::operator()(KDEType* kde) const +{ + if (kde) + delete kde; +} + +// Mode of model. +template +KDEMode& ModeVisitor::operator()(KDEType* kde) const +{ + if (kde) + return kde->Mode(); + else + throw std::runtime_error("no KDE model initialized"); +} + +// Get mode of model (inline: this definition lives in a header). +inline KDEMode KDEModel::Mode() const +{ + return boost::apply_visitor(ModeVisitor(), kdeModel); +} + +// Modify mode of model (inline: this definition lives in a header). +inline KDEMode& KDEModel::Mode() +{ + return boost::apply_visitor(ModeVisitor(), kdeModel); +} + +// Serialize the model. +template +void KDEModel::serialize(Archive& ar, const unsigned int /* version */) +{ + ar & BOOST_SERIALIZATION_NVP(bandwidth); + ar & BOOST_SERIALIZATION_NVP(relError); + ar & BOOST_SERIALIZATION_NVP(absError); + ar & BOOST_SERIALIZATION_NVP(kernelType); + ar & BOOST_SERIALIZATION_NVP(treeType); + + if (Archive::is_loading::value) + boost::apply_visitor(DeleteVisitor(), kdeModel); + + ar & BOOST_SERIALIZATION_NVP(kdeModel); +} + +} // namespace kde +} // namespace mlpack + +#endif diff --git a/src/mlpack/methods/kde/kde_rules.hpp b/src/mlpack/methods/kde/kde_rules.hpp new file mode 100644 index 00000000000..f5eb608f9c9 --- /dev/null +++ b/src/mlpack/methods/kde/kde_rules.hpp @@ -0,0 +1,136 @@ +/** + * @file kde_rules.hpp + * @author Roberto Hueso + * + * Rules for Kernel Density Estimation, so that it can be done with arbitrary + * tree types. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ */ + +#ifndef MLPACK_METHODS_KDE_RULES_HPP +#define MLPACK_METHODS_KDE_RULES_HPP + +#include + +namespace mlpack { +namespace kde { + +template +class KDERules +{ + public: + /** + * Construct KDERules. + * + * @param referenceSet Reference set data. + * @param querySet Query set data. + * @param densities Vector where estimations will be written. + * @param relError Relative error tolerance. + * @param absError Absolute error tolerance. + * @param metric Instantiated metric. + * @param kernel Instantiated kernel. + * @param sameSet True if query and reference sets are the same + * (monochromatic evaluation). + */ + KDERules(const arma::mat& referenceSet, + const arma::mat& querySet, + arma::vec& densities, + const double relError, + const double absError, + MetricType& metric, + KernelType& kernel, + const bool sameSet); + + //! Base Case. + double BaseCase(const size_t queryIndex, const size_t referenceIndex); + + //! SingleTree Score. + double Score(const size_t queryIndex, TreeType& referenceNode); + + //! SingleTree Rescore. + double Rescore(const size_t queryIndex, + TreeType& referenceNode, + const double oldScore) const; + + //! DoubleTree Score. + double Score(TreeType& queryNode, TreeType& referenceNode); + + //! DoubleTree Rescore. + double Rescore(TreeType& queryNode, + TreeType& referenceNode, + const double oldScore) const; + + typedef typename tree::TraversalInfo TraversalInfoType; + + //! Get traversal information. + const TraversalInfoType& TraversalInfo() const { return traversalInfo; } + + //! Modify traversal information. + TraversalInfoType& TraversalInfo() { return traversalInfo; } + + //! Get the number of base cases. + size_t BaseCases() const { return baseCases; } + + //! Get the number of scores. + size_t Scores() const { return scores; } + + private: + //! Evaluate kernel value of 2 points given their indexes. + double EvaluateKernel(const size_t queryIndex, + const size_t referenceIndex) const; + + //! 
Evaluate kernel value of 2 points. + double EvaluateKernel(const arma::vec& query, + const arma::vec& reference) const; + + //! The reference set. + const arma::mat& referenceSet; + + //! The query set. + const arma::mat& querySet; + + //! Density values. + arma::vec& densities; + + //! Absolute error tolerance. + const double absError; + + //! Relative error tolerance. + const double relError; + + //! Instantiated metric. + MetricType& metric; + + //! Instantiated kernel. + KernelType& kernel; + + //! Whether reference and query sets are the same. + const bool sameSet; + + //! The last query index. + size_t lastQueryIndex; + + //! The last reference index. + size_t lastReferenceIndex; + + //! Traversal information. + TraversalInfoType traversalInfo; + + //! The number of base cases. + size_t baseCases; + + //! The number of scores. + size_t scores; +}; + +} // namespace kde +} // namespace mlpack + +// Include implementation. +#include "kde_rules_impl.hpp" + +#endif diff --git a/src/mlpack/methods/kde/kde_rules_impl.hpp b/src/mlpack/methods/kde/kde_rules_impl.hpp new file mode 100644 index 00000000000..87273ebfc96 --- /dev/null +++ b/src/mlpack/methods/kde/kde_rules_impl.hpp @@ -0,0 +1,246 @@ +/** + * @file kde_rules_impl.hpp + * @author Roberto Hueso + * + * Implementation of rules for Kernel Density Estimation with generic trees. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ + +#ifndef MLPACK_METHODS_KDE_RULES_IMPL_HPP +#define MLPACK_METHODS_KDE_RULES_IMPL_HPP + +// In case it hasn't been included yet.
+#include "kde_rules.hpp" + +namespace mlpack { +namespace kde { + +template +KDERules::KDERules( + const arma::mat& referenceSet, + const arma::mat& querySet, + arma::vec& densities, + const double relError, + const double absError, + MetricType& metric, + KernelType& kernel, + const bool sameSet) : + referenceSet(referenceSet), + querySet(querySet), + densities(densities), + absError(absError), + relError(relError), + metric(metric), + kernel(kernel), + sameSet(sameSet), + lastQueryIndex(querySet.n_cols), + lastReferenceIndex(referenceSet.n_cols), + baseCases(0), + scores(0) +{ + // Nothing to do. +} + +//! The base case. +template +inline force_inline +double KDERules::BaseCase( + const size_t queryIndex, + const size_t referenceIndex) +{ + // If reference and query sets are the same we don't want to compute the + // estimation of a point with itself. + if (sameSet && (queryIndex == referenceIndex)) + return 0.0; + + // Avoid duplicated calculations. + if ((lastQueryIndex == queryIndex) && (lastReferenceIndex == referenceIndex)) + return 0.0; + + // Calculations. + const double distance = metric.Evaluate(querySet.col(queryIndex), + referenceSet.col(referenceIndex)); + densities(queryIndex) += kernel.Evaluate(distance); + + ++baseCases; + lastQueryIndex = queryIndex; + lastReferenceIndex = referenceIndex; + return distance; +} + +//! Single-tree scoring function. +template +inline double KDERules:: +Score(const size_t queryIndex, TreeType& referenceNode) +{ + double score, maxKernel, minKernel, bound; + const arma::vec& queryPoint = querySet.unsafe_col(queryIndex); + const double minDistance = referenceNode.MinDistance(queryPoint); + bool newCalculations = true; + + if (tree::TreeTraits::FirstPointIsCentroid && + lastQueryIndex == queryIndex && + traversalInfo.LastReferenceNode() != NULL && + traversalInfo.LastReferenceNode()->Point(0) == referenceNode.Point(0)) + { + // Don't duplicate calculations. 
+ newCalculations = false; + lastQueryIndex = queryIndex; + lastReferenceIndex = referenceNode.Point(0); + } + else + { + // Calculations are new. + maxKernel = kernel.Evaluate(minDistance); + minKernel = kernel.Evaluate(referenceNode.MaxDistance(queryPoint)); + bound = maxKernel - minKernel; + } + + if (newCalculations && + bound <= (absError + relError * minKernel) / referenceSet.n_cols) + { + // Estimate values. + double kernelValue; + + // Calculate kernel value based on reference node centroid. + if (tree::TreeTraits::FirstPointIsCentroid) + { + kernelValue = EvaluateKernel(queryIndex, referenceNode.Point(0)); + } + else + { + kde::KDEStat& referenceStat = referenceNode.Stat(); + kernelValue = EvaluateKernel(queryPoint, referenceStat.Centroid()); + } + + densities(queryIndex) += referenceNode.NumDescendants() * kernelValue; + + // Don't explore this tree branch. + score = DBL_MAX; + } + else + { + score = minDistance; + } + + ++scores; + traversalInfo.LastReferenceNode() = &referenceNode; + traversalInfo.LastScore() = score; + return score; +} + +template +inline double KDERules::Rescore( + const size_t /* queryIndex */, + TreeType& /* referenceNode */, + const double oldScore) const +{ + // If it's pruned it continues to be pruned. + return oldScore; +} + +//! Double-tree scoring function. +template +inline double KDERules:: +Score(TreeType& queryNode, TreeType& referenceNode) +{ + double score, maxKernel, minKernel, bound; + const double minDistance = queryNode.MinDistance(referenceNode); + // Calculations are not duplicated. + bool newCalculations = true; + + if (tree::TreeTraits::FirstPointIsCentroid && + (traversalInfo.LastQueryNode() != NULL) && + (traversalInfo.LastReferenceNode() != NULL) && + (traversalInfo.LastQueryNode()->Point(0) == queryNode.Point(0)) && + (traversalInfo.LastReferenceNode()->Point(0) == referenceNode.Point(0))) + { + // Don't duplicate calculations. 
+ newCalculations = false; + lastQueryIndex = queryNode.Point(0); + lastReferenceIndex = referenceNode.Point(0); + } + else + { + // Calculations are new. + maxKernel = kernel.Evaluate(minDistance); + minKernel = kernel.Evaluate(queryNode.MaxDistance(referenceNode)); + bound = maxKernel - minKernel; + } + + // If possible, avoid some calculations because of the error tolerance. + if (newCalculations && + bound <= (absError + relError * minKernel) / referenceSet.n_cols) + { + // Auxiliary variables. + double kernelValue; + kde::KDEStat& referenceStat = referenceNode.Stat(); + kde::KDEStat& queryStat = queryNode.Stat(); + + // If calculating a center is not required. + if (tree::TreeTraits::FirstPointIsCentroid) + { + kernelValue = EvaluateKernel(queryNode.Point(0), referenceNode.Point(0)); + } + // Sadly, we have no choice but to calculate the center. + else + { + kernelValue = EvaluateKernel(queryStat.Centroid(), + referenceStat.Centroid()); + } + + // Sum up estimations. + for (size_t i = 0; i < queryNode.NumDescendants(); ++i) + { + densities(queryNode.Descendant(i)) += + referenceNode.NumDescendants() * kernelValue; + } + score = DBL_MAX; + } + else + { + score = minDistance; + } + + ++scores; + traversalInfo.LastQueryNode() = &queryNode; + traversalInfo.LastReferenceNode() = &referenceNode; + traversalInfo.LastScore() = score; + return score; +} + +//! Double-tree rescore. +template +inline double KDERules:: +Rescore(TreeType& /*queryNode*/, + TreeType& /*referenceNode*/, + const double oldScore) const +{ + // If a branch is pruned then it continues to be pruned. 
+ return oldScore; +} + +template +inline force_inline double KDERules:: +EvaluateKernel(const size_t queryIndex, + const size_t referenceIndex) const +{ + return EvaluateKernel(querySet.unsafe_col(queryIndex), + referenceSet.unsafe_col(referenceIndex)); +} + +template +inline force_inline double KDERules:: +EvaluateKernel(const arma::vec& query, const arma::vec& reference) const +{ + return kernel.Evaluate(metric.Evaluate(query, reference)); +} + +} // namespace kde +} // namespace mlpack + +#endif diff --git a/src/mlpack/methods/kde/kde_stat.hpp b/src/mlpack/methods/kde/kde_stat.hpp new file mode 100644 index 00000000000..92d6a118156 --- /dev/null +++ b/src/mlpack/methods/kde/kde_stat.hpp @@ -0,0 +1,83 @@ +/** + * @file kde_stat.hpp + * @author Roberto Hueso + * + * Defines TreeStatType for KDE. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#ifndef MLPACK_METHODS_KDE_STAT_HPP +#define MLPACK_METHODS_KDE_STAT_HPP + +#include + +namespace mlpack { +namespace kde { + +/** + * Extra data for each node in the tree. + */ +class KDEStat +{ + public: + //! Initialize the statistic. + KDEStat() : validCentroid(false) { } + + //! Initialization for a fully initialized node. + template + KDEStat(TreeType& node) + { + // Calculate centroid if necessary. + if (!tree::TreeTraits::FirstPointIsCentroid) + { + node.Center(centroid); + validCentroid = true; + } + else + { + validCentroid = false; + } + } + + //! Get the centroid of the node. + inline const arma::vec& Centroid() const + { + if (validCentroid) + return centroid; + throw std::logic_error("Centroid must be assigned before requesting its " + "value"); + } + + //! Modify the centroid of the node. 
+ void SetCentroid(arma::vec newCentroid) + { + validCentroid = true; + centroid = std::move(newCentroid); + } + + //! Get whether the centroid is valid. + inline bool ValidCentroid() const { return validCentroid; } + + //! Serialize the statistic to/from an archive. + template + void serialize(Archive& ar, const unsigned int /* version */) + { + ar & BOOST_SERIALIZATION_NVP(centroid); + ar & BOOST_SERIALIZATION_NVP(validCentroid); + } + + private: + //! Node centroid. + arma::vec centroid; + + //! Whether the centroid is updated or is junk. + bool validCentroid; +}; + +} // namespace kde +} // namespace mlpack + +#endif diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt index 4bc7cbc5241..3178180b0fa 100644 --- a/src/mlpack/tests/CMakeLists.txt +++ b/src/mlpack/tests/CMakeLists.txt @@ -39,6 +39,7 @@ add_executable(mlpack_test hyperplane_test.cpp imputation_test.cpp init_rules_test.cpp + kde_test.cpp kernel_pca_test.cpp kernel_test.cpp kernel_traits_test.cpp @@ -119,6 +120,7 @@ add_executable(mlpack_test main_tests/det_test.cpp main_tests/decision_tree_test.cpp main_tests/decision_stump_test.cpp + main_tests/kde_test.cpp main_tests/linear_regression_test.cpp main_tests/logistic_regression_test.cpp main_tests/lmnn_test.cpp diff --git a/src/mlpack/tests/kde_test.cpp b/src/mlpack/tests/kde_test.cpp new file mode 100644 index 00000000000..3d1cecb7d9a --- /dev/null +++ b/src/mlpack/tests/kde_test.cpp @@ -0,0 +1,821 @@ +/** + * @file kde_test.cpp + * @author Roberto Hueso + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. 
+ */ +#include + +#include +#include +#include +#include +#include + +#include +#include "test_tools.hpp" +#include "serialization.hpp" + +using namespace mlpack; +using namespace mlpack::kde; +using namespace mlpack::metric; +using namespace mlpack::tree; +using namespace mlpack::kernel; + +using namespace boost::serialization; + +BOOST_AUTO_TEST_SUITE(KDETest); + +// Brute force gaussian KDE. +template +void BruteForceKDE(const arma::mat& reference, + const arma::mat& query, + arma::vec& densities, + KernelType& kernel) +{ + metric::EuclideanDistance metric; + for (size_t i = 0; i < query.n_cols; ++i) + { + for (size_t j = 0; j < reference.n_cols; ++j) + { + double distance = metric.Evaluate(query.col(i),reference.col(j)); + densities(i) += kernel.Evaluate(distance); + } + } + densities /= reference.n_cols; +} + +/** + * Test if simple case is correct according to manually calculated results. + */ +BOOST_AUTO_TEST_CASE(KDESimpleTest) +{ + // Transposed reference and query sets because it's easier to read. + arma::mat reference = { {-1.0, -1.0}, + {-2.0, -1.0}, + {-3.0, -2.0}, + { 1.0, 1.0}, + { 2.0, 1.0}, + { 3.0, 2.0} }; + arma::mat query = { { 0.0, 0.5}, + { 0.4, -3.0}, + { 0.0, 0.0}, + {-2.1, 1.0} }; + arma::inplace_trans(reference); + arma::inplace_trans(query); + arma::vec estimations; + // Manually calculated results. + arma::vec estimationsResult = {0.08323668699564207296148765, + 0.00167470061366603324010116, + 0.07658867126520703394465527, + 0.01028120384800740999553525}; + KDE + kde(0.0, 0.01, GaussianKernel(0.8)); + kde.Train(reference); + kde.Evaluate(query, estimations); + for (size_t i = 0; i < query.n_cols; ++i) + BOOST_REQUIRE_CLOSE(estimations[i], estimationsResult[i], 0.01); +} + +/** + * Test Train(Tree...) and Evaluate(Tree...). + */ +BOOST_AUTO_TEST_CASE(KDETreeAsArguments) +{ + // Transposed reference and query sets because it's easier to read. 
+ arma::mat reference = { {-1.0, -1.0}, + {-2.0, -1.0}, + {-3.0, -2.0}, + { 1.0, 1.0}, + { 2.0, 1.0}, + { 3.0, 2.0} }; + arma::mat query = { { 0.0, 0.5}, + { 0.4, -3.0}, + { 0.0, 0.0}, + {-2.1, 1.0} }; + arma::inplace_trans(reference); + arma::inplace_trans(query); + arma::vec estimations = arma::vec(query.n_cols, arma::fill::zeros); + arma::vec estimationsResult = arma::vec(query.n_cols, arma::fill::zeros); + const double kernelBandwidth = 0.8; + + // Get brute force results. + GaussianKernel kernel(kernelBandwidth); + BruteForceKDE(reference, + query, + estimationsResult, + kernel); + + // Get dual-tree results. + typedef KDTree Tree; + std::vector oldFromNewQueries, oldFromNewReferences; + Tree* queryTree = new Tree(query, oldFromNewQueries, 2); + Tree* referenceTree = new Tree(reference, oldFromNewReferences, 2); + KDE + kde(0.0, 1e-6, GaussianKernel(kernelBandwidth)); + kde.Train(referenceTree, &oldFromNewReferences); + kde.Evaluate(queryTree, std::move(oldFromNewQueries), estimations); + for (size_t i = 0; i < query.n_cols; ++i) + BOOST_REQUIRE_CLOSE(estimations[i], estimationsResult[i], 0.01); + delete queryTree; + delete referenceTree; +} + +/** + * Test dual-tree implementation results against brute force results. + */ +BOOST_AUTO_TEST_CASE(GaussianKDEBruteForceTest) +{ + arma::mat reference = arma::randu(2, 200); + arma::mat query = arma::randu(2, 60); + arma::vec bfEstimations = arma::vec(query.n_cols, arma::fill::zeros); + arma::vec treeEstimations = arma::vec(query.n_cols, arma::fill::zeros); + const double kernelBandwidth = 0.3; + const double relError = 0.01; + + // Brute force KDE. + GaussianKernel kernel(kernelBandwidth); + BruteForceKDE(reference, + query, + bfEstimations, + kernel); + + // Optimized KDE. + metric::EuclideanDistance metric; + KDE + kde(relError, 0.0, kernel, KDEMode::DUAL_TREE_MODE, metric); + kde.Train(reference); + kde.Evaluate(query, treeEstimations); + + // Check whether results are equal. 
+ for (size_t i = 0; i < query.n_cols; ++i) + BOOST_REQUIRE_CLOSE(bfEstimations[i], treeEstimations[i], relError*100); +} + +/** + * Test single-tree implementation results against brute force results. + */ +BOOST_AUTO_TEST_CASE(GaussianSingleKDEBruteForceTest) +{ + arma::mat reference = arma::randu(2, 300); + arma::mat query = arma::randu(2, 100); + arma::vec bfEstimations = arma::vec(query.n_cols, arma::fill::zeros); + arma::vec treeEstimations = arma::vec(query.n_cols, arma::fill::zeros); + const double kernelBandwidth = 0.3; + const double relError = 0.01; + + // Brute force KDE. + GaussianKernel kernel(kernelBandwidth); + BruteForceKDE(reference, + query, + bfEstimations, + kernel); + + // Optimized KDE. + metric::EuclideanDistance metric; + KDE + kde(relError, 0.0, kernel, KDEMode::SINGLE_TREE_MODE, metric); + kde.Train(reference); + kde.Evaluate(query, treeEstimations); + + // Check whether results are equal. + for (size_t i = 0; i < query.n_cols; ++i) + BOOST_REQUIRE_CLOSE(bfEstimations[i], treeEstimations[i], relError*100); +} + +/** + * Test single-tree implementation results against brute force results using + * a cover-tree and Epanechnikov kernel. + */ +BOOST_AUTO_TEST_CASE(EpanechnikovCoverSingleKDETest) +{ + arma::mat reference = arma::randu(2, 300); + arma::mat query = arma::randu(2, 100); + arma::vec bfEstimations = arma::vec(query.n_cols, arma::fill::zeros); + arma::vec treeEstimations = arma::vec(query.n_cols, arma::fill::zeros); + const double kernelBandwidth = 1.1; + const double relError = 0.08; + + // Brute force KDE. + EpanechnikovKernel kernel(kernelBandwidth); + BruteForceKDE(reference, + query, + bfEstimations, + kernel); + + // Optimized KDE. + metric::EuclideanDistance metric; + KDE + kde(relError, 0.0, kernel, KDEMode::SINGLE_TREE_MODE, metric); + kde.Train(reference); + kde.Evaluate(query, treeEstimations); + + // Check whether results are equal. 
+ for (size_t i = 0; i < query.n_cols; ++i) + BOOST_REQUIRE_CLOSE(bfEstimations[i], treeEstimations[i], relError*100); +} + +/** + * Test single-tree implementation results against brute force results using + * an octree and Epanechnikov kernel. + */ +BOOST_AUTO_TEST_CASE(EpanechnikovOctreeSingleKDETest) +{ + arma::mat reference = arma::randu(2, 300); + arma::mat query = arma::randu(2, 100); + arma::vec bfEstimations = arma::vec(query.n_cols, arma::fill::zeros); + arma::vec treeEstimations = arma::vec(query.n_cols, arma::fill::zeros); + const double kernelBandwidth = 1.0; + const double relError = 0.05; + + // Brute force KDE. + EpanechnikovKernel kernel(kernelBandwidth); + BruteForceKDE(reference, + query, + bfEstimations, + kernel); + + // Optimized KDE. + metric::EuclideanDistance metric; + KDE + kde(relError, 0.0, kernel, KDEMode::SINGLE_TREE_MODE, metric); + kde.Train(reference); + kde.Evaluate(query, treeEstimations); + + // Check whether results are equal. + for (size_t i = 0; i < query.n_cols; ++i) + BOOST_REQUIRE_CLOSE(bfEstimations[i], treeEstimations[i], relError*100); +} + +/** + * Test BallTree dual-tree implementation results against brute force results. + */ +BOOST_AUTO_TEST_CASE(BallTreeGaussianKDETest) +{ + arma::mat reference = arma::randu(2, 200); + arma::mat query = arma::randu(2, 60); + arma::vec bfEstimations = arma::vec(query.n_cols, arma::fill::zeros); + arma::vec treeEstimations = arma::vec(query.n_cols, arma::fill::zeros); + const double kernelBandwidth = 0.4; + const double relError = 0.05; + + // Brute force KDE. + GaussianKernel kernel(kernelBandwidth); + BruteForceKDE(reference, + query, + bfEstimations, + kernel); + + // BallTree KDE. 
+ typedef BallTree Tree; + std::vector oldFromNewQueries, oldFromNewReferences; + Tree* queryTree = new Tree(query, oldFromNewQueries, 2); + Tree* referenceTree = new Tree(reference, oldFromNewReferences, 2); + KDE + kde(relError, 0.0, GaussianKernel(kernelBandwidth)); + kde.Train(referenceTree, &oldFromNewReferences); + kde.Evaluate(queryTree, std::move(oldFromNewQueries), treeEstimations); + + // Check whether results are equal. + for (size_t i = 0; i < query.n_cols; ++i) + BOOST_REQUIRE_CLOSE(bfEstimations[i], treeEstimations[i], relError*100); + + delete queryTree; + delete referenceTree; +} + +/** + * Test Octree dual-tree implementation results against brute force results. + */ +BOOST_AUTO_TEST_CASE(OctreeGaussianKDETest) +{ + arma::mat reference = arma::randu(2, 500); + arma::mat query = arma::randu(2, 200); + arma::vec bfEstimations = arma::vec(query.n_cols, arma::fill::zeros); + arma::vec treeEstimations = arma::vec(query.n_cols, arma::fill::zeros); + const double kernelBandwidth = 0.3; + const double relError = 0.01; + + // Brute force KDE. + GaussianKernel kernel(kernelBandwidth); + BruteForceKDE(reference, + query, + bfEstimations, + kernel); + + // Optimized KDE. + metric::EuclideanDistance metric; + KDE + kde(relError, 0.0, kernel, KDEMode::DUAL_TREE_MODE, metric); + kde.Train(reference); + kde.Evaluate(query, treeEstimations); + + // Check whether results are equal. + for (size_t i = 0; i < query.n_cols; ++i) + BOOST_REQUIRE_CLOSE(bfEstimations[i], treeEstimations[i], relError*100); +} + +/** + * Test RTree dual-tree implementation results against brute force results. + */ +BOOST_AUTO_TEST_CASE(RTreeGaussianKDETest) +{ + arma::mat reference = arma::randu(2, 500); + arma::mat query = arma::randu(2, 200); + arma::vec bfEstimations = arma::vec(query.n_cols, arma::fill::zeros); + arma::vec treeEstimations = arma::vec(query.n_cols, arma::fill::zeros); + const double kernelBandwidth = 0.3; + const double relError = 0.01; + + // Brute force KDE. 
+ GaussianKernel kernel(kernelBandwidth); + BruteForceKDE(reference, + query, + bfEstimations, + kernel); + + // Optimized KDE. + metric::EuclideanDistance metric; + KDE + kde(relError, 0.0, kernel, KDEMode::DUAL_TREE_MODE, metric); + kde.Train(reference); + kde.Evaluate(query, treeEstimations); + + // Check whether results are equal. + for (size_t i = 0; i < query.n_cols; ++i) + BOOST_REQUIRE_CLOSE(bfEstimations[i], treeEstimations[i], relError*100); +} + +/** + * Test Standard Cover Tree dual-tree implementation results against brute + * force results. + */ +BOOST_AUTO_TEST_CASE(StandardCoverTreeGaussianKDETest) +{ + arma::mat reference = arma::randu(2, 500); + arma::mat query = arma::randu(2, 200); + arma::vec bfEstimations = arma::vec(query.n_cols, arma::fill::zeros); + arma::vec treeEstimations = arma::vec(query.n_cols, arma::fill::zeros); + const double kernelBandwidth = 0.3; + const double relError = 0.01; + + // Brute force KDE. + GaussianKernel kernel(kernelBandwidth); + BruteForceKDE(reference, + query, + bfEstimations, + kernel); + + // Optimized KDE. + metric::EuclideanDistance metric; + KDE + kde(relError, 0.0, kernel, KDEMode::DUAL_TREE_MODE, metric); + kde.Train(reference); + kde.Evaluate(query, treeEstimations); + + // Check whether results are equal. + for (size_t i = 0; i < query.n_cols; ++i) + BOOST_REQUIRE_CLOSE(bfEstimations[i], treeEstimations[i], relError*100); +} + +/** + * Test duplicated value in reference matrix. + */ +BOOST_AUTO_TEST_CASE(DuplicatedReferenceSampleKDETest) +{ + arma::mat reference = arma::randu(2, 30); + arma::mat query = arma::randu(2, 10); + arma::vec bfEstimations = arma::vec(query.n_cols, arma::fill::zeros); + arma::vec treeEstimations = arma::vec(query.n_cols, arma::fill::zeros); + const double kernelBandwidth = 0.4; + const double relError = 0.05; + + // Duplicate value. + reference.col(2) = reference.col(3); + + // Brute force KDE. 
+ GaussianKernel kernel(kernelBandwidth); + BruteForceKDE(reference, + query, + bfEstimations, + kernel); + + // Dual-tree KDE. + typedef KDTree Tree; + std::vector oldFromNewQueries, oldFromNewReferences; + Tree* queryTree = new Tree(query, oldFromNewQueries, 2); + Tree* referenceTree = new Tree(reference, oldFromNewReferences, 2); + KDE + kde(relError, 0.0, GaussianKernel(kernelBandwidth)); + kde.Train(referenceTree, &oldFromNewReferences); + kde.Evaluate(queryTree, oldFromNewQueries, treeEstimations); + + // Check whether results are equal. + for (size_t i = 0; i < query.n_cols; ++i) + BOOST_REQUIRE_CLOSE(bfEstimations[i], treeEstimations[i], relError*100); + + delete queryTree; + delete referenceTree; +} + +/** + * Test duplicated value in query matrix. + */ +BOOST_AUTO_TEST_CASE(DuplicatedQuerySampleKDETest) +{ + arma::mat reference = arma::randu(2, 30); + arma::mat query = arma::randu(2, 10); + arma::vec estimations = arma::vec(query.n_cols, arma::fill::zeros); + const double kernelBandwidth = 0.4; + const double relError = 0.05; + + // Duplicate value. + query.col(2) = query.col(3); + + // Dual-tree KDE. + typedef KDTree Tree; + std::vector oldFromNewQueries, oldFromNewReferences; + Tree* queryTree = new Tree(query, oldFromNewQueries, 2); + Tree* referenceTree = new Tree(reference, oldFromNewReferences, 2); + KDE + kde(relError, 0.0, GaussianKernel(kernelBandwidth)); + kde.Train(referenceTree, &oldFromNewReferences); + kde.Evaluate(queryTree, oldFromNewQueries, estimations); + + // Check whether results are equal. + BOOST_REQUIRE_CLOSE(estimations[2], estimations[3], relError*100); + + delete queryTree; + delete referenceTree; +} + +/** + * Test dual-tree breadth-first implementation results against brute force + * results. 
+ */ +BOOST_AUTO_TEST_CASE(BreadthFirstKDETest) +{ + arma::mat reference = arma::randu(2, 200); + arma::mat query = arma::randu(2, 60); + arma::vec bfEstimations = arma::vec(query.n_cols, arma::fill::zeros); + arma::vec treeEstimations = arma::vec(query.n_cols, arma::fill::zeros); + const double kernelBandwidth = 0.8; + const double relError = 0.01; + + // Brute force KDE. + GaussianKernel kernel(kernelBandwidth); + BruteForceKDE(reference, + query, + bfEstimations, + kernel); + + // Breadth-First KDE. + metric::EuclideanDistance metric; + KDE::template BreadthFirstDualTreeTraverser> + kde(relError, 0.0, kernel, KDEMode::DUAL_TREE_MODE, metric); + kde.Train(reference); + kde.Evaluate(query, treeEstimations); + + // Check whether results are equal. + for (size_t i = 0; i < query.n_cols; ++i) + BOOST_REQUIRE_CLOSE(bfEstimations[i], treeEstimations[i], relError*100); +} + +/** + * Test 1-dimensional implementation results against brute force results. + */ +BOOST_AUTO_TEST_CASE(OneDimensionalTest) +{ + arma::mat reference = arma::randu(1, 200); + arma::mat query = arma::randu(1, 60); + arma::vec bfEstimations = arma::vec(query.n_cols, arma::fill::zeros); + arma::vec treeEstimations = arma::vec(query.n_cols, arma::fill::zeros); + const double kernelBandwidth = 0.7; + const double relError = 0.01; + + // Brute force KDE. + GaussianKernel kernel(kernelBandwidth); + BruteForceKDE(reference, + query, + bfEstimations, + kernel); + + // Optimized KDE. + metric::EuclideanDistance metric; + KDE + kde(relError, 0.0, kernel, KDEMode::DUAL_TREE_MODE, metric); + kde.Train(reference); + kde.Evaluate(query, treeEstimations); + + // Check whether results are equal. + for (size_t i = 0; i < query.n_cols; ++i) + BOOST_REQUIRE_CLOSE(bfEstimations[i], treeEstimations[i], relError*100); +} + +/** + * Test a case where an empty reference set is given to train the model. 
+ */ +BOOST_AUTO_TEST_CASE(EmptyReferenceTest) +{ + arma::mat reference; + arma::mat query = arma::randu(1, 10); + arma::vec estimations = arma::vec(query.n_cols, arma::fill::zeros); + const double kernelBandwidth = 0.7; + const double relError = 0.01; + + // KDE. + metric::EuclideanDistance metric; + GaussianKernel kernel(kernelBandwidth); + KDE + kde(relError, 0.0, kernel, KDEMode::DUAL_TREE_MODE, metric); + + // When training using the dataset matrix. + BOOST_REQUIRE_THROW(kde.Train(reference), std::invalid_argument); + + // When training using a tree. + std::vector oldFromNewReferences; + typedef KDTree Tree; + Tree* referenceTree = new Tree(reference, oldFromNewReferences, 2); + BOOST_REQUIRE_THROW( + kde.Train(referenceTree, &oldFromNewReferences), std::invalid_argument); + + delete referenceTree; +} + +/** + * Tests when reference set values and query set values dimensions don't match. + */ +BOOST_AUTO_TEST_CASE(EvaluationMatchDimensionsTest) +{ + arma::mat reference = arma::randu(3, 10); + arma::mat query = arma::randu(1, 10); + arma::vec estimations = arma::vec(query.n_cols, arma::fill::zeros); + const double kernelBandwidth = 0.7; + const double relError = 0.01; + + // KDE. + metric::EuclideanDistance metric; + GaussianKernel kernel(kernelBandwidth); + KDE + kde(relError, 0.0, kernel, KDEMode::DUAL_TREE_MODE, metric); + kde.Train(reference); + + // When evaluating using the query dataset matrix. + BOOST_REQUIRE_THROW(kde.Evaluate(query, estimations), + std::invalid_argument); + + // When evaluating using a query tree. + typedef KDTree Tree; + std::vector oldFromNewQueries; + Tree* queryTree = new Tree(query, oldFromNewQueries, 3); + BOOST_REQUIRE_THROW(kde.Evaluate(queryTree, oldFromNewQueries, estimations), + std::invalid_argument); + delete queryTree; +} + +/** + * Tests when an empty query set is given to be evaluated. 
+ */ +BOOST_AUTO_TEST_CASE(EmptyQuerySetTest) +{ + arma::mat reference = arma::randu(1, 10); + arma::mat query; + // Set estimations to the wrong size. + arma::vec estimations(33, arma::fill::zeros); + const double kernelBandwidth = 0.7; + const double relError = 0.01; + + // KDE. + metric::EuclideanDistance metric; + GaussianKernel kernel(kernelBandwidth); + KDE + kde(relError, 0.0, kernel, KDEMode::DUAL_TREE_MODE, metric); + kde.Train(reference); + + // The query set must be empty. + BOOST_REQUIRE_EQUAL(query.n_cols, 0); + // When evaluating using the query dataset matrix. + BOOST_REQUIRE_NO_THROW(kde.Evaluate(query, estimations)); + + // When evaluating using a query tree. + typedef KDTree Tree; + std::vector oldFromNewQueries; + Tree* queryTree = new Tree(query, oldFromNewQueries, 3); + BOOST_REQUIRE_NO_THROW( + kde.Evaluate(queryTree, oldFromNewQueries, estimations)); + delete queryTree; + + // Estimations must be empty. + BOOST_REQUIRE_EQUAL(estimations.size(), 0); +} + +/** + * Tests serialization of KDE models. + */ +BOOST_AUTO_TEST_CASE(SerializationTest) +{ + // Initial KDE model to be serialized. + const double relError = 0.25; + const double absError = 0.0; + arma::mat reference = arma::randu(4, 800); + KDE + kde(relError, absError, GaussianKernel(0.25)); + kde.Train(reference); + + // Get estimations to compare. + arma::mat query = arma::randu(4, 100);; + arma::vec estimations = arma::vec(query.n_cols, arma::fill::zeros); + kde.Evaluate(query, estimations); + + // Initialize serialized objects. + KDE kdeXml, kdeText, kdeBinary; + SerializeObjectAll(kde, kdeXml, kdeText, kdeBinary); + + // Check everything is correct.
+ BOOST_REQUIRE_CLOSE(kde.RelativeError(), relError, 1e-8); + BOOST_REQUIRE_CLOSE(kdeXml.RelativeError(), relError, 1e-8); + BOOST_REQUIRE_CLOSE(kdeText.RelativeError(), relError, 1e-8); + BOOST_REQUIRE_CLOSE(kdeBinary.RelativeError(), relError, 1e-8); + + BOOST_REQUIRE_CLOSE(kde.AbsoluteError(), absError, 1e-8); + BOOST_REQUIRE_CLOSE(kdeXml.AbsoluteError(), absError, 1e-8); + BOOST_REQUIRE_CLOSE(kdeText.AbsoluteError(), absError, 1e-8); + BOOST_REQUIRE_CLOSE(kdeBinary.AbsoluteError(), absError, 1e-8); + + BOOST_REQUIRE_EQUAL(kde.IsTrained(), true); + BOOST_REQUIRE_EQUAL(kdeXml.IsTrained(), true); + BOOST_REQUIRE_EQUAL(kdeText.IsTrained(), true); + BOOST_REQUIRE_EQUAL(kdeBinary.IsTrained(), true); + + const KDEMode mode = KDEMode::DUAL_TREE_MODE; + BOOST_REQUIRE_EQUAL(kde.Mode(), mode); + BOOST_REQUIRE_EQUAL(kdeXml.Mode(), mode); + BOOST_REQUIRE_EQUAL(kdeText.Mode(), mode); + BOOST_REQUIRE_EQUAL(kdeBinary.Mode(), mode); + + // Test if execution gives the same result. + arma::vec xmlEstimations = arma::vec(query.n_cols, arma::fill::zeros); + arma::vec textEstimations = arma::vec(query.n_cols, arma::fill::zeros); + arma::vec binEstimations = arma::vec(query.n_cols, arma::fill::zeros); + + kdeXml.Evaluate(query, xmlEstimations); + kdeText.Evaluate(query, textEstimations); + kdeBinary.Evaluate(query, binEstimations); + + for (size_t i = 0; i < query.n_cols; ++i) + { + BOOST_REQUIRE_CLOSE(estimations[i], xmlEstimations[i], relError*100); + BOOST_REQUIRE_CLOSE(estimations[i], textEstimations[i], relError*100); + BOOST_REQUIRE_CLOSE(estimations[i], binEstimations[i], relError*100); + } +} + +/** + * Test if the copy constructor and copy operator work properly. + */ +BOOST_AUTO_TEST_CASE(CopyConstructor) +{ + arma::mat reference = arma::randu(2, 300); + arma::mat query = arma::randu(2, 100); + arma::vec estimations1, estimations2, estimations3; + const double kernelBandwidth = 1.5; + const double relError = 0.05; + + typedef KDE + KDEType; + + // KDE.
+ KDEType kde(relError, 0, kernel::GaussianKernel(kernelBandwidth)); + kde.Train(std::move(reference)); + + // Copy constructor KDE. + KDEType constructor(kde); + + // Copy operator KDE. + KDEType oper = kde; + + // Evaluations. + kde.Evaluate(query, estimations1); + constructor.Evaluate(query, estimations2); + oper.Evaluate(query, estimations3); + + // Check results. + for (size_t i = 0; i < query.n_cols; ++i) + { + BOOST_REQUIRE_CLOSE(estimations1[i], estimations2[i], 1e-10); + BOOST_REQUIRE_CLOSE(estimations2[i], estimations3[i], 1e-10); + } +} + +/** + * Test if the move constructor works properly. + */ +BOOST_AUTO_TEST_CASE(MoveConstructor) +{ + arma::mat reference = arma::randu(2, 300); + arma::mat query = arma::randu(2, 100); + arma::vec estimations1, estimations2, estimations3; + const double kernelBandwidth = 1.2; + const double relError = 0.05; + + typedef KDE + KDEType; + + // KDE. + KDEType kde(relError, 0, kernel::EpanechnikovKernel(kernelBandwidth)); + kde.Train(std::move(reference)); + kde.Evaluate(query, estimations1); + + // Move constructor KDE. + KDEType constructor(std::move(kde)); + constructor.Evaluate(query, estimations2); + + // Check results. + BOOST_REQUIRE_THROW(kde.Evaluate(query, estimations3), std::runtime_error); + for (size_t i = 0; i < query.n_cols; ++i) + BOOST_REQUIRE_CLOSE(estimations1[i], estimations2[i], 1e-10); +} + +/** + * Test if an untrained KDE works properly. + */ +BOOST_AUTO_TEST_CASE(NotTrained) +{ + arma::mat query = arma::randu(1, 10); + std::vector oldFromNew; + arma::vec estimations; + + KDE<> kde; + KDE<>::Tree queryTree(query, oldFromNew); + + // Check results. 
+ BOOST_REQUIRE_THROW(kde.Evaluate(query, estimations), std::runtime_error); + BOOST_REQUIRE_THROW(kde.Evaluate(&queryTree, oldFromNew, estimations), + std::runtime_error); + BOOST_REQUIRE_THROW(kde.Evaluate(estimations), std::runtime_error); +} + +BOOST_AUTO_TEST_SUITE_END(); diff --git a/src/mlpack/tests/main_tests/kde_test.cpp b/src/mlpack/tests/main_tests/kde_test.cpp new file mode 100644 index 00000000000..c61ebc054e2 --- /dev/null +++ b/src/mlpack/tests/main_tests/kde_test.cpp @@ -0,0 +1,408 @@ +/** + * @file kde_test.cpp + * @author Roberto Hueso + * + * Test mlpackMain() of kde_main.cpp + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#include + +#define BINDING_TYPE BINDING_TYPE_TEST + +static const std::string testName = "KDE"; + +#include +#include +#include "test_helper.hpp" +#include + +#include +#include "../test_tools.hpp" + +using namespace mlpack; + +struct KDETestFixture +{ + public: + KDETestFixture() + { + // Cache in the options for this program. + CLI::RestoreSettings(testName); + } + + ~KDETestFixture() + { + // Clear the settings. + CLI::ClearSettings(); + } +}; + +void ResetKDESettings() +{ + CLI::ClearSettings(); + CLI::RestoreSettings(testName); +} + +BOOST_FIXTURE_TEST_SUITE(KDEMainTest, KDETestFixture); + +/** + * Ensure that the estimations we get for KDEMain, are the same as the ones we + * get from the KDE class without any wrappers. Requires normalization. + **/ +BOOST_AUTO_TEST_CASE(KDEGaussianRTreeResultsMain) +{ + // Datasets. 
+ arma::mat reference = arma::randu(3, 500); + arma::mat query = arma::randu(3, 100); + arma::vec kdeEstimations, mainEstimations; + double kernelBandwidth = 1.5; + double relError = 0.05; + + kernel::GaussianKernel kernel(kernelBandwidth); + metric::EuclideanDistance metric; + KDE + kde(relError, 0.0, kernel, KDEMode::DUAL_TREE_MODE, metric); + kde.Train(reference); + kde.Evaluate(query, kdeEstimations); + // Normalize estimations. + kdeEstimations /= kernel.Normalizer(reference.n_rows); + + // Main estimations. + SetInputParam("reference", reference); + SetInputParam("query", query); + SetInputParam("kernel", std::string("gaussian")); + SetInputParam("tree", std::string("r-tree")); + SetInputParam("rel_error", relError); + SetInputParam("bandwidth", kernelBandwidth); + + mlpackMain(); + + mainEstimations = std::move(CLI::GetParam("predictions")); + + // Check whether results are equal. + for (size_t i = 0; i < query.n_cols; ++i) + BOOST_REQUIRE_CLOSE(kdeEstimations[i], mainEstimations[i], relError); +} + +/** + * Ensure that the estimations we get for KDEMain, are the same as the ones we + * get from the KDE class without any wrappers. Doesn't require normalization. + **/ +BOOST_AUTO_TEST_CASE(KDETriangularBallTreeResultsMain) +{ + // Datasets. + arma::mat reference = arma::randu(3, 300); + arma::mat query = arma::randu(3, 100); + arma::vec kdeEstimations, mainEstimations; + double kernelBandwidth = 3.0; + double relError = 0.06; + + kernel::TriangularKernel kernel(kernelBandwidth); + metric::EuclideanDistance metric; + KDE + kde(relError, 0.0, kernel, KDEMode::DUAL_TREE_MODE, metric); + kde.Train(reference); + kde.Evaluate(query, kdeEstimations); + + // Main estimations. 
+ SetInputParam("reference", reference); + SetInputParam("query", query); + SetInputParam("kernel", std::string("triangular")); + SetInputParam("tree", std::string("ball-tree")); + SetInputParam("rel_error", relError); + SetInputParam("bandwidth", kernelBandwidth); + + mlpackMain(); + + mainEstimations = std::move(CLI::GetParam("predictions")); + + // Check whether results are equal. + for (size_t i = 0; i < query.n_cols; ++i) + BOOST_REQUIRE_CLOSE(kdeEstimations[i], mainEstimations[i], relError); +} + +/** + * Ensure that the estimations we get for KDEMain, are the same as the ones we + * get from the KDE class without any wrappers in the monochromatic case. + **/ +BOOST_AUTO_TEST_CASE(KDEMonoResultsMain) +{ + // Datasets. + arma::mat reference = arma::randu(2, 300); + arma::vec kdeEstimations, mainEstimations; + double kernelBandwidth = 2.3; + double relError = 0.05; + + kernel::EpanechnikovKernel kernel(kernelBandwidth); + metric::EuclideanDistance metric; + KDE + kde(relError, 0.0, kernel, KDEMode::DUAL_TREE_MODE, metric); + kde.Train(reference); + // Perform monochromatic KDE. + kde.Evaluate(kdeEstimations); + // Normalize. + kdeEstimations /= kernel.Normalizer(reference.n_rows); + + // Main estimations. + SetInputParam("reference", reference); + SetInputParam("kernel", std::string("epanechnikov")); + SetInputParam("tree", std::string("cover-tree")); + SetInputParam("rel_error", relError); + SetInputParam("bandwidth", kernelBandwidth); + + mlpackMain(); + + mainEstimations = std::move(CLI::GetParam("predictions")); + + // Check whether results are equal. + for (size_t i = 0; i < reference.n_cols; ++i) + BOOST_REQUIRE_CLOSE(kdeEstimations[i], mainEstimations[i], relError); +} + +/** + * Ensuring that absence of input data is checked. + **/ +BOOST_AUTO_TEST_CASE(KDENoInputData) +{ + // No input data is provided. Should throw a runtime error.
+ Log::Fatal.ignoreInput = true; + BOOST_REQUIRE_THROW(mlpackMain(), std::runtime_error); + Log::Fatal.ignoreInput = false; +} + +/** + * Check result has as many densities as query points. + **/ +BOOST_AUTO_TEST_CASE(KDEOutputSize) +{ + const size_t dim = 3; + const size_t samples = 110; + arma::mat reference = arma::randu(dim, 325); + arma::mat query = arma::randu(dim, samples); + + // Main params. + SetInputParam("reference", reference); + SetInputParam("query", query); + + mlpackMain(); + // Check number of output elements. + BOOST_REQUIRE_EQUAL(CLI::GetParam("predictions").size(), samples); +} + +/** + * Check that saved model can be reused. + **/ +BOOST_AUTO_TEST_CASE(KDEModelReuse) +{ + const size_t dim = 3; + const size_t samples = 100; + const double relError = 0.05; + arma::mat reference = arma::randu(dim, 300); + arma::mat query = arma::randu(dim, samples); + + // Main params. + SetInputParam("reference", reference); + SetInputParam("query", query); + SetInputParam("bandwidth", 2.4); + SetInputParam("rel_error", 0.05); + + mlpackMain(); + + arma::vec oldEstimations = std::move(CLI::GetParam("predictions")); + + // Change parameters and load model. + CLI::GetSingleton().Parameters()["reference"].wasPassed = false; + SetInputParam("bandwidth", 0.5); + SetInputParam("query", query); + SetInputParam("input_model", + std::move(CLI::GetParam("output_model"))); + + mlpackMain(); + + arma::vec newEstimations = std::move(CLI::GetParam("predictions")); + + // Check estimations are the same. + for (size_t i = 0; i < samples; ++i) + BOOST_REQUIRE_CLOSE(oldEstimations[i], newEstimations[i], relError); +} + +/** + * Ensure that the estimations we get for KDEMain, are the same as the ones we + * get from the KDE class without any wrappers using single-tree mode. + **/ +BOOST_AUTO_TEST_CASE(KDEGaussianSingleKDTreeResultsMain) +{ + // Datasets. 
+ arma::mat reference = arma::randu(3, 400); + arma::mat query = arma::randu(3, 400); + arma::vec kdeEstimations, mainEstimations; + double kernelBandwidth = 3.0; + double relError = 0.06; + + kernel::GaussianKernel kernel(kernelBandwidth); + metric::EuclideanDistance metric; + KDE + kde(relError, 0.0, kernel, KDEMode::SINGLE_TREE_MODE, metric); + kde.Train(reference); + kde.Evaluate(query, kdeEstimations); + kdeEstimations /= kernel.Normalizer(reference.n_rows); + + // Main estimations. + SetInputParam("reference", reference); + SetInputParam("query", query); + SetInputParam("kernel", std::string("gaussian")); + SetInputParam("tree", std::string("kd-tree")); + SetInputParam("algorithm", std::string("single-tree")); + SetInputParam("rel_error", relError); + SetInputParam("bandwidth", kernelBandwidth); + + mlpackMain(); + + mainEstimations = std::move(CLI::GetParam("predictions")); + + // Check whether results are equal. + for (size_t i = 0; i < query.n_cols; ++i) + BOOST_REQUIRE_CLOSE(kdeEstimations[i], mainEstimations[i], relError); +} + +/** + * Ensure we get an exception when an invalid kernel is specified. + **/ +BOOST_AUTO_TEST_CASE(KDEMainInvalidKernel) +{ + arma::mat reference = arma::randu(2, 10); + arma::mat query = arma::randu(2, 5); + + // Main params. + SetInputParam("reference", reference); + SetInputParam("query", query); + SetInputParam("kernel", std::string("linux")); + + Log::Fatal.ignoreInput = true; + BOOST_REQUIRE_THROW(mlpackMain(), std::runtime_error); + Log::Fatal.ignoreInput = false; +} + +/** + * Ensure we get an exception when an invalid tree is specified. + **/ +BOOST_AUTO_TEST_CASE(KDEMainInvalidTree) +{ + arma::mat reference = arma::randu(2, 10); + arma::mat query = arma::randu(2, 5); + + // Main params. 
+ SetInputParam("reference", reference); + SetInputParam("query", query); + SetInputParam("tree", std::string("olive")); + + Log::Fatal.ignoreInput = true; + BOOST_REQUIRE_THROW(mlpackMain(), std::runtime_error); + Log::Fatal.ignoreInput = false; +} + +/** + * Ensure we get an exception when an invalid algorithm is specified. + **/ +BOOST_AUTO_TEST_CASE(KDEMainInvalidAlgorithm) +{ + arma::mat reference = arma::randu(2, 10); + arma::mat query = arma::randu(2, 5); + + // Main params. + SetInputParam("reference", reference); + SetInputParam("query", query); + SetInputParam("algorithm", std::string("bogosort")); + + Log::Fatal.ignoreInput = true; + BOOST_REQUIRE_THROW(mlpackMain(), std::runtime_error); + Log::Fatal.ignoreInput = false; +} + +/** + * Ensure we get an exception when both reference and input_model are + * specified. + **/ +BOOST_AUTO_TEST_CASE(KDEMainReferenceAndModel) +{ + arma::mat reference = arma::randu(2, 10); + arma::mat query = arma::randu(2, 5); + KDEModel* model = new KDEModel(); + + // Main params. + SetInputParam("reference", reference); + SetInputParam("query", query); + SetInputParam("input_model", model); + + Log::Fatal.ignoreInput = true; + BOOST_REQUIRE_THROW(mlpackMain(), std::runtime_error); + Log::Fatal.ignoreInput = false; +} + +/** + * Ensure we get an exception when an invalid absolute error is specified. + **/ +BOOST_AUTO_TEST_CASE(KDEMainInvalidAbsoluteError) +{ + arma::mat reference = arma::randu(1, 10); + arma::mat query = arma::randu(1, 5); + + // Main params. + SetInputParam("reference", reference); + SetInputParam("query", query); + + Log::Fatal.ignoreInput = true; + // Invalid value. + SetInputParam("abs_error", -0.1); + BOOST_REQUIRE_THROW(mlpackMain(), std::runtime_error); + + // Valid value. + SetInputParam("abs_error", 5.8); + BOOST_REQUIRE_NO_THROW(mlpackMain()); + Log::Fatal.ignoreInput = false; +} + +/** + * Ensure we get an exception when an invalid relative error is specified. 
+ **/ +BOOST_AUTO_TEST_CASE(KDEMainInvalidRelativeError) +{ + arma::mat reference = arma::randu(1, 10); + arma::mat query = arma::randu(1, 5); + + // Main params. + SetInputParam("reference", reference); + SetInputParam("query", query); + + Log::Fatal.ignoreInput = true; + // Invalid under 0. + SetInputParam("rel_error", -0.1); + BOOST_REQUIRE_THROW(mlpackMain(), std::runtime_error); + + // Invalid over 1. + SetInputParam("rel_error", 1.1); + BOOST_REQUIRE_THROW(mlpackMain(), std::runtime_error); + + // Valid value. + SetInputParam("rel_error", 0.3); + BOOST_REQUIRE_NO_THROW(mlpackMain()); + Log::Fatal.ignoreInput = false; +} + +BOOST_AUTO_TEST_SUITE_END();