diff --git a/doc/tutorials/approx_kfn/approx_kfn.txt b/doc/tutorials/approx_kfn/approx_kfn.txt new file mode 100644 index 00000000000..aa477a0d097 --- /dev/null +++ b/doc/tutorials/approx_kfn/approx_kfn.txt @@ -0,0 +1,1025 @@ +/*! + +@file approx_kfn.txt +@author Ryan Curtin +@brief Tutorial for how to use approximate furthest neighbor search in mlpack. + +@page akfntutorial Approximate furthest neighbor search (mlpack_approx_kfn) tutorial + +@section intro_akfntut Introduction + +\b mlpack implements multiple strategies for approximate furthest neighbor +search in its \c mlpack_approx_kfn and \c mlpack_kfn programs (each program +corresponds to different techniques). This tutorial discusses what problems +these algorithms solve and how to use each of the techniques that \b mlpack +implements. + +\b mlpack implements five approximate furthest neighbor search algorithms: + + - brute-force search (in \c mlpack_kfn) + - single-tree search (in \c mlpack_kfn) + - dual-tree search (in \c mlpack_kfn) + - query-dependent approximate furthest neighbor (QDAFN) (in \c mlpack_approx_kfn) + - DrusillaSelect (in \c mlpack_approx_kfn) + +These methods are described in the following papers: + +@code +@inproceedings{curtin2013tree, + title={Tree-Independent Dual-Tree Algorithms}, + author={Curtin, Ryan R. and March, William B. and Ram, Parikshit and Anderson, + David V. and Gray, Alexander G. and Isbell Jr., Charles L.}, + booktitle={Proceedings of The 30th International Conference on Machine + Learning (ICML '13)}, + pages={1435--1443}, + year={2013} +} +@endcode + +@code +@incollection{pagh2015approximate, + title={Approximate furthest neighbor in high dimensions}, + author={Pagh, Rasmus and Silvestri, Francesco and Sivertsen, Johan and Skala, + Matthew}, + booktitle={Similarity Search and Applications}, + pages={3--14}, + year={2015}, + publisher={Springer} +} +@endcode + +@code +@incollection{curtin2016fast, + title={Fast approximate furthest neighbors with data-dependent candidate + selection}, + author={Curtin, Ryan R., and Gardner, Andrew B.}, + booktitle={Similarity Search and Applications}, + pages={221--235}, + year={2016}, + publisher={Springer} +} +@endcode + +The problem of furthest neighbor search is simple, and is the opposite of the +much-more-studied nearest neighbor search problem. Given a set of reference +points \f$R\f$ (the set in which we are searching), and a set of query points +\f$Q\f$ (the set of points for which we want the furthest neighbor), our goal is +to return the \f$k\f$ furthest neighbors for each query point in \f$Q\f$: + +\f[ +\operatorname{k-argmax}_{p_r \in R} d(p_q, p_r). +\f] + +In order to solve this problem, \b mlpack provides a number of interfaces. + + - two \ref cli_akfntut "simple command-line executables" to calculate + approximate furthest neighbors + - a simple \ref cpp_qdafn_akfntut "C++ class for QDAFN" + - a simple \ref cpp_ds_akfntut "C++ class for DrusillaSelect" + - a simple \ref cpp_kfn_akfntut "C++ class for tree-based and brute-force" + search + +@section toc_akfntut Table of Contents + +A list of all the sections this tutorial contains. + + - \ref intro_akfntut + - \ref toc_akfntut + - \ref which_akfntut + - \ref cli_akfntut + - \ref cli_ex1_akfntut + - \ref cli_ex2_akfntut + - \ref cli_ex3_akfntut + - \ref cli_ex4_akfntut + - \ref cli_ex5_akfntut + - \ref cli_ex6_akfntut + - \ref cli_ex7_akfntut + - \ref cli_ex8_akfntut + - \ref cli_final_akfntut + - \ref cpp_ds_akfntut + - \ref cpp_ex1_ds_akfntut + - \ref cpp_ex2_ds_akfntut + - \ref cpp_ex3_ds_akfntut + - \ref cpp_ex4_ds_akfntut + - \ref cpp_ex5_ds_akfntut + - \ref cpp_qdafn_akfntut + - \ref cpp_ex1_qdafn_akfntut + - \ref cpp_ex2_qdafn_akfntut + - \ref cpp_ex3_qdafn_akfntut + - \ref cpp_ex4_qdafn_akfntut + - \ref cpp_ex5_qdafn_akfntut + - \ref cpp_ns_akfntut + - \ref cpp_ex1_ns_akfntut + - \ref cpp_ex2_ns_akfntut + - \ref cpp_ex3_ns_akfntut + - \ref cpp_ex4_ns_akfntut + - \ref further_doc_akfntut + +@section which_akfntut Which algorithm should be used? + +There are three algorithms for furthest neighbor search that \b mlpack +implements, and each is suited to a different setting. Below is some basic +guidance on what should be used. Note that the question of "which algorithm +should be used" is a very difficult question to answer, so the guidance below is +just that---guidance---and may not be right for a particular problem. + + - \c DrusillaSelect is very fast and will perform extremely well for datasets + with outliers or datasets with structure (like low-dimensional datasets + embedded in high dimensions) + - \c QDAFN is a random approach and therefore should be well-suited for + datasets with little to no structure + - The tree-based approaches (the \c KFN class and the \c mlpack_kfn program) is + best suited for low-dimensional datasets, and is most effective when very + small levels of approximation are desired, or when exact results are desired. + - Dual-tree search is most useful when the query set is large and structured + (like for all-furthest-neighbor search). + - Single-tree search is more useful when the query set is small. + +@section cli_akfntut Command-line 'mlpack_approx_kfn' and 'mlpack_kfn' + +\b mlpack provides two command-line programs to solve approximate furthest +neighbor search: + + - \c mlpack_approx_kfn, for the QDAFN and DrusillaSelect approaches + - \c mlpack_kfn, for exact and approximate tree-based approaches + +These two programs allow a large number of algorithms to be used to find +approximate furthest neighbors. Note that the \c mlpack_kfn program is also +documented by the \ref cli_nstut section of the \ref nstutorial page, as it +shares options with the \c mlpack_knn program. + +Below are several examples of how the \c mlpack_approx_kfn and \c mlpack_kfn +programs might be used. The first examples focus on the \c mlpack_approx_kfn +program, and the last few show how \c mlpack_kfn can be used to produce +approximate results. + +@subsection cli_ex1_akfntut Calculate 5 furthest neighbors with default options + +Here we have a query dataset \c queries.csv and a reference dataset \c refs.csv +and we wish to find the 5 furthest neighbors of every query point in the +reference dataset. We may do that with the \c mlpack_approx_kfn algorithm, +using the default of the \c DrusillaSelect algorithm with default parameters. + +@code +$ mlpack_approx_kfn -q queries.csv -r refs.csv -v -k 5 -n n.csv -d d.csv +[INFO ] Loading 'refs.csv' as CSV data. Size is 3 x 1000. +[INFO ] Building DrusillaSelect model... +[INFO ] Model built. +[INFO ] Loading 'queries.csv' as CSV data. Size is 3 x 1000. +[INFO ] Searching for 5 furthest neighbors with DrusillaSelect... +[INFO ] Search complete. +[INFO ] Saving CSV data to 'n.csv'. +[INFO ] Saving CSV data to 'd.csv'. +[INFO ] +[INFO ] Execution parameters: +[INFO ] algorithm: ds +[INFO ] calculate_error: false +[INFO ] distances_file: d.csv +[INFO ] exact_distances_file: "" +[INFO ] help: false +[INFO ] info: "" +[INFO ] input_model_file: "" +[INFO ] k: 5 +[INFO ] neighbors_file: n.csv +[INFO ] num_projections: 5 +[INFO ] num_tables: 5 +[INFO ] output_model_file: "" +[INFO ] query_file: queries.csv +[INFO ] reference_file: refs.csv +[INFO ] verbose: true +[INFO ] version: false +[INFO ] +[INFO ] Program timers: +[INFO ] drusilla_select_construct: 0.000342s +[INFO ] drusilla_select_search: 0.000780s +[INFO ] loading_data: 0.010689s +[INFO ] saving_data: 0.005585s +[INFO ] total_time: 0.018592s +@endcode + +Convenient timers for parts of the program operation are printed. The results, +saved in \c n.csv and \c d.csv, indicate the furthest neighbors and distances +for each query point. The row of the output file indicates the query point that +the results are for. The neighbors are listed from furthest to nearest; so, the +4th element in the 3rd row of \c d.csv indicates the distance between the 3rd +query point in \c queries.csv and its approximate 4th furthest neighbor. +Similarly, the same element in \c n.csv indicates the index of the approximate +4th furthest neighbor (with respect to \c refs.csv). + +@subsection cli_ex2_akfntut Specifying algorithm parameters for DrusillaSelect + +The \c -p (\c --num_projections) and \c -t (\c --num_tables) parameters affect +the running of the \c DrusillaSelect algorithm and the QDAFN algorithm. +Specifically, larger values for each of these parameters will search more +possible candidate furthest neighbors and produce better results (at the cost of +runtime). More details on how each of these parameters works is available in +the original papers, the \b mlpack source, or the documentation given by +\c --help. + +In the example below, we run \c DrusillaSelect to find 4 furthest neighbors +using 10 tables and 2 points in each table. In this case we have chosen to omit +the \c -n \c n.csv option, meaning that only the output candidate distances will +be written to \c d.csv. + +@code +$ mlpack_approx_kfn -q queries.csv -r refs.csv -v -k 4 -n n.csv -d d.csv -t 10 -p 2 +[INFO ] Loading 'refs.csv' as CSV data. Size is 3 x 1000. +[INFO ] Building DrusillaSelect model... +[INFO ] Model built. +[INFO ] Loading 'queries.csv' as CSV data. Size is 3 x 1000. +[INFO ] Searching for 4 furthest neighbors with DrusillaSelect... +[INFO ] Search complete. +[INFO ] Saving CSV data to 'n.csv'. +[INFO ] Saving CSV data to 'd.csv'. +[INFO ] +[INFO ] Execution parameters: +[INFO ] algorithm: ds +[INFO ] calculate_error: false +[INFO ] distances_file: d.csv +[INFO ] exact_distances_file: "" +[INFO ] help: false +[INFO ] info: "" +[INFO ] input_model_file: "" +[INFO ] k: 4 +[INFO ] neighbors_file: n.csv +[INFO ] num_projections: 2 +[INFO ] num_tables: 10 +[INFO ] output_model_file: "" +[INFO ] query_file: queries.csv +[INFO ] reference_file: refs.csv +[INFO ] verbose: true +[INFO ] version: false +[INFO ] +[INFO ] Program timers: +[INFO ] drusilla_select_construct: 0.000645s +[INFO ] drusilla_select_search: 0.000551s +[INFO ] loading_data: 0.008518s +[INFO ] saving_data: 0.003734s +[INFO ] total_time: 0.014019s +@endcode + +@subsection cli_ex3_akfntut Using QDAFN instead of DrusillaSelect + +The algorithm to be used for approximate furthest neighbor search can be +specified with the \c --algorithm (\c -a) option to the \c mlpack_approx_kfn +program. Below, we use the QDAFN algorithm instead of the default. We leave +the \c -p and \c -t options at their defaults---even though QDAFN often requires +more tables and points to get the same quality of results. + +@code +$ mlpack_approx_kfn -q queries.csv -r refs.csv -v -k 3 -n n.csv -d d.csv -a qdafn +[INFO ] Loading 'refs.csv' as CSV data. Size is 3 x 1000. +[INFO ] Building QDAFN model... +[INFO ] Model built. +[INFO ] Loading 'queries.csv' as CSV data. Size is 3 x 1000. +[INFO ] Searching for 3 furthest neighbors with QDAFN... +[INFO ] Search complete. +[INFO ] Saving CSV data to 'n.csv'. +[INFO ] Saving CSV data to 'd.csv'. +[INFO ] +[INFO ] Execution parameters: +[INFO ] algorithm: qdafn +[INFO ] calculate_error: false +[INFO ] distances_file: d.csv +[INFO ] exact_distances_file: "" +[INFO ] help: false +[INFO ] info: "" +[INFO ] input_model_file: "" +[INFO ] k: 3 +[INFO ] neighbors_file: n.csv +[INFO ] num_projections: 5 +[INFO ] num_tables: 5 +[INFO ] output_model_file: "" +[INFO ] query_file: queries.csv +[INFO ] reference_file: refs.csv +[INFO ] verbose: true +[INFO ] version: false +[INFO ] +[INFO ] Program timers: +[INFO ] loading_data: 0.008380s +[INFO ] qdafn_construct: 0.003399s +[INFO ] qdafn_search: 0.000886s +[INFO ] saving_data: 0.002253s +[INFO ] total_time: 0.015465s +@endcode + +@subsection cli_ex4_akfntut Printing results quality with exact distances + +The \c mlpack_approx_kfn program can calculate the quality of the results if the +\c --calculate_error (\c -e) flag is specified. Below we use the program with +its default parameters and calculate the error, which is displayed in the +output. The error is only calculated for the furthest neighbor, not all k; +therefore, in this example we have set \c -k to \c 1. + +@code +$ mlpack_approx_kfn -q queries.csv -r refs.csv -v -k 1 -e -q -n n.csv +[INFO ] Loading 'refs.csv' as CSV data. Size is 3 x 1000. +[INFO ] Building DrusillaSelect model... +[INFO ] Model built. +[INFO ] Loading 'queries.csv' as CSV data. Size is 3 x 1000. +[INFO ] Searching for 1 furthest neighbors with DrusillaSelect... +[INFO ] Search complete. +[INFO ] Calculating exact distances... +[INFO ] 28891 node combinations were scored. +[INFO ] 37735 base cases were calculated. +[INFO ] Calculation complete. +[INFO ] Average error: 1.08417. +[INFO ] Maximum error: 1.28712. +[INFO ] Minimum error: 1. +[INFO ] +[INFO ] Execution parameters: +[INFO ] algorithm: ds +[INFO ] calculate_error: true +[INFO ] distances_file: "" +[INFO ] exact_distances_file: "" +[INFO ] help: false +[INFO ] info: "" +[INFO ] input_model_file: "" +[INFO ] k: 3 +[INFO ] neighbors_file: "" +[INFO ] num_projections: 5 +[INFO ] num_tables: 5 +[INFO ] output_model_file: "" +[INFO ] query_file: queries.csv +[INFO ] reference_file: refs.csv +[INFO ] verbose: true +[INFO ] version: false +[INFO ] +[INFO ] Program timers: +[INFO ] computing_neighbors: 0.001476s +[INFO ] drusilla_select_construct: 0.000309s +[INFO ] drusilla_select_search: 0.000495s +[INFO ] loading_data: 0.008462s +[INFO ] total_time: 0.011670s +[INFO ] tree_building: 0.000202s +@endcode + +Note that the output includes three lines indicating the error: + +@code +[INFO ] Average error: 1.08417. +[INFO ] Maximum error: 1.28712. +[INFO ] Minimum error: 1. +@endcode + +In this case, a minimum error of 1 indicates an exact result, and over the +entire query set the algorithm has returned a furthest neighbor candidate with +maximum error 1.28712. + +@subsection cli_ex5_akfntut Using cached exact distances for quality results + +However, for large datasets, calculating the error may take a long time, because +the exact furthest neighbors must be calculated. Therefore, if the exact +furthest neighbor distances are already known, they may be passed in with the +\c --exact_distances_file (\c -x) option in order to avoid the calculation. In +the example below, we assume \c exact.csv contains the exact furthest neighbor +distances. We run the \c qdafn algorithm in this example. + +Note that the \c -e option must be specified for the \c -x option have any +effect. + +@code +$ mlpack_approx_kfn -q queries.csv -r refs.csv -k 1 -e -x exact.csv -n n.csv -v -a qdafn +[INFO ] Loading 'refs.csv' as CSV data. Size is 3 x 1000. +[INFO ] Building QDAFN model... +[INFO ] Model built. +[INFO ] Loading 'queries.csv' as CSV data. Size is 3 x 1000. +[INFO ] Searching for 1 furthest neighbors with QDAFN... +[INFO ] Search complete. +[INFO ] Loading 'exact.csv' as raw ASCII formatted data. Size is 1 x 1000. +[INFO ] Average error: 1.06914. +[INFO ] Maximum error: 1.67407. +[INFO ] Minimum error: 1. +[INFO ] Saving CSV data to 'n.csv'. +[INFO ] +[INFO ] Execution parameters: +[INFO ] algorithm: qdafn +[INFO ] calculate_error: true +[INFO ] distances_file: "" +[INFO ] exact_distances_file: exact.csv +[INFO ] help: false +[INFO ] info: "" +[INFO ] input_model_file: "" +[INFO ] k: 1 +[INFO ] neighbors_file: n.csv +[INFO ] num_projections: 5 +[INFO ] num_tables: 5 +[INFO ] output_model_file: "" +[INFO ] query_file: queries.csv +[INFO ] reference_file: refs.csv +[INFO ] verbose: true +[INFO ] version: false +[INFO ] +[INFO ] Program timers: +[INFO ] loading_data: 0.010348s +[INFO ] qdafn_construct: 0.000318s +[INFO ] qdafn_search: 0.000793s +[INFO ] saving_data: 0.000259s +[INFO ] total_time: 0.012254s +@endcode + +@subsection cli_ex6_akfntut Using tree-based approximation with mlpack_kfn + +The \c mlpack_kfn algorithm allows specifying a desired approximation level with +the \c --epsilon (\c -e) option. The parameter must be greater than or equal +to 0 and less than 1. A setting of 0 indicates exact search. + +The example below runs dual-tree furthest neighbor search (the default +algorithm) with the approximation parameter set to 0.5. + +@code +$ mlpack_kfn -q queries.csv -r refs.csv -v -k 3 -e 0.5 -n n.csv -d d.csv +[INFO ] Loading 'refs.csv' as CSV data. Size is 3 x 1000. +[INFO ] Loaded reference data from 'refs.csv' (3x1000). +[INFO ] Building reference tree... +[INFO ] Tree built. +[INFO ] Loading 'queries.csv' as CSV data. Size is 3 x 1000. +[INFO ] Loaded query data from 'queries.csv' (3x1000). +[INFO ] Searching for 3 neighbors with dual-tree kd-tree search... +[INFO ] 1611 node combinations were scored. +[INFO ] 13938 base cases were calculated. +[INFO ] 1611 node combinations were scored. +[INFO ] 13938 base cases were calculated. +[INFO ] Search complete. +[INFO ] Saving CSV data to 'n.csv'. +[INFO ] Saving CSV data to 'd.csv'. +[INFO ] +[INFO ] Execution parameters: +[INFO ] algorithm: dual_tree +[INFO ] distances_file: d.csv +[INFO ] epsilon: 0.5 +[INFO ] help: false +[INFO ] info: "" +[INFO ] input_model_file: "" +[INFO ] k: 3 +[INFO ] leaf_size: 20 +[INFO ] naive: false +[INFO ] neighbors_file: n.csv +[INFO ] output_model_file: "" +[INFO ] percentage: 1 +[INFO ] query_file: queries.csv +[INFO ] random_basis: false +[INFO ] reference_file: refs.csv +[INFO ] seed: 0 +[INFO ] single_mode: false +[INFO ] tree_type: kd +[INFO ] true_distances_file: "" +[INFO ] true_neighbors_file: "" +[INFO ] verbose: true +[INFO ] version: false +[INFO ] +[INFO ] Program timers: +[INFO ] computing_neighbors: 0.000442s +[INFO ] loading_data: 0.008060s +[INFO ] saving_data: 0.002850s +[INFO ] total_time: 0.012667s +[INFO ] tree_building: 0.000251s +@endcode + +Note that the format of the output files \c d.csv and \c n.csv are the same as +for \c mlpack_approx_kfn. + +@subsection cli_ex7_akfntut Different algorithms with 'mlpack_kfn' + +The \c mlpack_kfn program offers a large number of different algorithms that can +be used. The \c --algorithm (\c -a) may be used to specify three main different +algorithm types: \c naive (brute-force search), \c single_tree (single-tree +search), \c dual_tree (dual-tree search, the default), and \c greedy +("defeatist" greedy search, which goes to one leaf node of the tree then +terminates). The example below uses single-tree search to find approximate +neighbors with epsilon set to 0.1. + +@code +mlpack_kfn -q queries.csv -r refs.csv -v -k 3 -e 0.1 -n n.csv -d d.csv -a single_tree +[INFO ] Loading 'refs.csv' as CSV data. Size is 3 x 1000. +[INFO ] Loaded reference data from 'refs.csv' (3x1000). +[INFO ] Building reference tree... +[INFO ] Tree built. +[INFO ] Loading 'queries.csv' as CSV data. Size is 3 x 1000. +[INFO ] Loaded query data from 'queries.csv' (3x1000). +[INFO ] Searching for 3 neighbors with single-tree kd-tree search... +[INFO ] 13240 node combinations were scored. +[INFO ] 15924 base cases were calculated. +[INFO ] Search complete. +[INFO ] Saving CSV data to 'n.csv'. +[INFO ] Saving CSV data to 'd.csv'. +[INFO ] +[INFO ] Execution parameters: +[INFO ] algorithm: single_tree +[INFO ] distances_file: d.csv +[INFO ] epsilon: 0.1 +[INFO ] help: false +[INFO ] info: "" +[INFO ] input_model_file: "" +[INFO ] k: 3 +[INFO ] leaf_size: 20 +[INFO ] naive: false +[INFO ] neighbors_file: n.csv +[INFO ] output_model_file: "" +[INFO ] percentage: 1 +[INFO ] query_file: queries.csv +[INFO ] random_basis: false +[INFO ] reference_file: refs.csv +[INFO ] seed: 0 +[INFO ] single_mode: false +[INFO ] tree_type: kd +[INFO ] true_distances_file: "" +[INFO ] true_neighbors_file: "" +[INFO ] verbose: true +[INFO ] version: false +[INFO ] +[INFO ] Program timers: +[INFO ] computing_neighbors: 0.000850s +[INFO ] loading_data: 0.007858s +[INFO ] saving_data: 0.003445s +[INFO ] total_time: 0.013084s +[INFO ] tree_building: 0.000250s +@endcode + +@subsection cli_ex8_akfntut Saving a model for later use + +The \c mlpack_approx_kfn and \c mlpack_kfn programs both allow models to be +saved and loaded for future use. The \c --output_model_file (\c -M) option +allows specifying where to save a model, and the \c --input_model_file (\c -m) +option allows a model to be loaded instead of trained. So, if you specify +\c --input_model_file then you do not need to specify \c --reference_file +(\c -r), \c --num_projections (\c -p), or \c --num_tables (\c -t). + +The example below saves a model with 10 projections and 5 tables. Note that +neither \c --query_file (\c -q) nor \c -k are specified; this run only builds +the model and saves it to \c model.bin. + +@code +$ mlpack_approx_kfn -r refs.csv -t 5 -p 10 -v -M model.bin +[INFO ] Loading 'refs.csv' as CSV data. Size is 3 x 1000. +[INFO ] Building DrusillaSelect model... +[INFO ] Model built. +[INFO ] +[INFO ] Execution parameters: +[INFO ] algorithm: ds +[INFO ] calculate_error: false +[INFO ] distances_file: "" +[INFO ] exact_distances_file: "" +[INFO ] help: false +[INFO ] info: "" +[INFO ] input_model_file: "" +[INFO ] k: 0 +[INFO ] neighbors_file: "" +[INFO ] num_projections: 10 +[INFO ] num_tables: 5 +[INFO ] output_model_file: model.bin +[INFO ] query_file: "" +[INFO ] reference_file: refs.csv +[INFO ] verbose: true +[INFO ] version: false +[INFO ] +[INFO ] Program timers: +[INFO ] drusilla_select_construct: 0.000321s +[INFO ] loading_data: 0.004700s +[INFO ] total_time: 0.007320s +@endcode + +Now, with the model saved, we can run approximate furthest neighbor search on a +query set using the saved model: + +@code +$ mlpack_approx_kfn -m model.bin -q queries.csv -k 3 -d d.csv -n n.csv -v +[INFO ] Loading 'queries.csv' as CSV data. Size is 3 x 1000. +[INFO ] Searching for 3 furthest neighbors with DrusillaSelect... +[INFO ] Search complete. +[INFO ] Saving CSV data to 'n.csv'. +[INFO ] Saving CSV data to 'd.csv'. +[INFO ] +[INFO ] Execution parameters: +[INFO ] algorithm: ds +[INFO ] calculate_error: false +[INFO ] distances_file: d.csv +[INFO ] exact_distances_file: "" +[INFO ] help: false +[INFO ] info: "" +[INFO ] input_model_file: model.bin +[INFO ] k: 3 +[INFO ] neighbors_file: n.csv +[INFO ] num_projections: 5 +[INFO ] num_tables: 5 +[INFO ] output_model_file: "" +[INFO ] query_file: queries.csv +[INFO ] reference_file: "" +[INFO ] verbose: true +[INFO ] version: false +[INFO ] +[INFO ] Program timers: +[INFO ] drusilla_select_search: 0.000878s +[INFO ] loading_data: 0.004599s +[INFO ] saving_data: 0.003006s +[INFO ] total_time: 0.009234s +@endcode + +These options work in the same way for both the \c mlpack_approx_kfn and +\c mlpack_kfn programs. + +@subsection cli_final_akfntut Final command-line program notes + +Both the \c mlpack_kfn and \c mlpack_approx_kfn programs contain numerous +options not fully documented in these short examples. You can run each program +with the \c --help (\c -h) option for more information. + +@section cpp_ds_akfntut DrusillaSelect C++ class + +\b mlpack provides a simple \c DrusillaSelect C++ class that can be used inside +of C++ programs to perform approximate furthest neighbor search. The class has +only one template parameter---\c MatType---which specifies the type of matrix to +be use. That means the class can be used with either dense data (of type +\c arma::mat) or sparse data (of type \c arma::sp_mat). + +The following examples show simple usage of this class. + +@subsection cpp_ex1_ds_akfntut Approximate furthest neighbors with defaults + +The code below builds a \c DrusillaSelect model with default options on the +matrix \c dataset, then queries for the approximate furthest neighbor of every +point in the \c queries matrix. + +@code +#include + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::mat dataset; +// The query set. +extern arma::mat queries; + +// Construct the model with defaults. +DrusillaSelect<> ds(dataset); + +// Query the model, putting output into the following two matrices. +arma::mat distances; +arma::Mat neighbors; +ds.Search(queries, 1, neighbors, distances); +@endcode + +At the end of this code, both the \c distances and \c neighbors matrices will +have number of columns equal to the number of columns in the \c queries matrix. +So, each column of the \c distances and \c neighbors matrices are the distances +or neighbors of the corresponding column in the \c queries matrix. + +@subsection cpp_ex2_ds_akfntut Custom numbers of tables and projections + +The following example constructs a DrusillaSelect model with 10 tables and 5 +projections. Once that is done it performs the same task as the previous +example. + +@code +#include + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::mat dataset; +// The query set. +extern arma::mat queries; + +// Construct the model with custom parameters. +DrusillaSelect<> ds(dataset, 10, 5); + +// Query the model, putting output into the following two matrices. +arma::mat distances; +arma::Mat neighbors; +ds.Search(queries, 1, neighbors, distances); +@endcode + +@subsection cpp_ex3_ds_akfntut Accessing the candidate set + +The \c DrusillaSelect algorithm merely scans the reference set and extracts a +number of points that will be queried in a brute-force fashion when the +\c Search() method is called. We can access this set with the \c CandidateSet() +method. The code below prints the fifth point of the candidate set. + +@code +#include + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::mat dataset; + +// Construct the model with custom parameters. +DrusillaSelect<> ds(dataset, 10, 5); + +// Print the fifth point of the candidate set. +std::cout << ds.CandidateSet().col(4).t(); +@endcode + +@subsection cpp_ex4_ds_akfntut Retraining on a new reference set + +It is possible to retrain a \c DrusillaSelect model with new parameters or with +a new reference set. This is functionally equivalent to creating a new model. +The example code below creates a first \c DrusillaSelect model using 3 tables +and 10 projections, and then retrains this with the same reference set using 10 +tables and 3 projections. + +@code +#include + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::mat dataset; + +// Construct the model with initial parameters. +DrusillaSelect<> ds(dataset, 3, 10); + +// Now retrain with different parameters. +ds.Train(dataset, 10, 3); +@endcode + +@subsection cpp_ex5_ds_akfntut Running on sparse data + +We can set the template parameter for \c DrusillaSelect to \c arma::sp_mat in +order to perform furthest neighbor search on sparse data. This code below +creates a \c DrusillaSelect model using 4 tables and 6 projections with sparse +input data, then searches for 3 approximate furthest neighbors. + +@code +#include + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::sp_mat dataset; +// The query dataset. +extern arma::sp_mat querySet; + +// Construct the model on sparse data. +DrusillaSelect ds(dataset, 4, 6); + +// Search on query data. +arma::Mat neighbors; +arma::mat distances; +ds.Search(querySet, 3, neighbors, distances); +@endcode + +@section cpp_qdafn_akfntut QDAFN C++ class + +\b mlpack also provides a standalone simple \c QDAFN class for furthest neighbor +search. The API for this class is virtually identical to the \c DrusillaSelect +class, and also has one template parameter to specify the type of matrix to be +used (dense or sparse or other). + +The following subsections demonstrate usage of the \c QDAFN class in the same +way as the previous section's examples for \c DrusillaSelect. + +@subsection cpp_ex1_qdafn_akfntut Approximate furthest neighbors with defaults + +The code below builds a \c QDAFN model with default options on the +matrix \c dataset, then queries for the approximate furthest neighbor of every +point in the \c queries matrix. + +@code +#include + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::mat dataset; +// The query set. +extern arma::mat queries; + +// Construct the model with defaults. +QDAFN<> qd(dataset); + +// Query the model, putting output into the following two matrices. +arma::mat distances; +arma::Mat neighbors; +qd.Search(queries, 1, neighbors, distances); +@endcode + +At the end of this code, both the \c distances and \c neighbors matrices will +have number of columns equal to the number of columns in the \c queries matrix. +So, each column of the \c distances and \c neighbors matrices are the distances +or neighbors of the corresponding column in the \c queries matrix. + +@subsection cpp_ex2_qdafn_akfntut Custom numbers of tables and projections + +The following example constructs a QDAFN model with 15 tables and 30 +projections. Once that is done it performs the same task as the previous +example. + +@code +#include + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::mat dataset; +// The query set. +extern arma::mat queries; + +// Construct the model with custom parameters. +QDAFN<> qdafn(dataset, 15, 30); + +// Query the model, putting output into the following two matrices. +arma::mat distances; +arma::Mat neighbors; +qdafn.Search(queries, 1, neighbors, distances); +@endcode + +@subsection cpp_ex3_qdafn_akfntut Accessing the candidate set + +The \c QDAFN algorithm scans the reference set, extracting points that have been +projected onto random directions. Each random direction corresponds to a single +table. The \c QDAFN class stores these points as a vector of matrices, which +can be accessed with the \c CandidateSet() method. The code below prints the +fifth point of the candidate set of the third table. + +@code +#include + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::mat dataset; + +// Construct the model with custom parameters. +QDAFN<> qdafn(dataset, 10, 5); + +// Print the fifth point of the candidate set. +std::cout << ds.CandidateSet(2).col(4).t(); +@endcode + +@subsection cpp_ex4_qdafn_akfntut Retraining on a new reference set + +It is possible to retrain a \c QDAFN model with new parameters or with +a new reference set. This is functionally equivalent to creating a new model. +The example code below creates a first \c QDAFN model using 10 tables +and 40 projections, and then retrains this with the same reference set using 15 +tables and 25 projections. + +@code +#include + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::mat dataset; + +// Construct the model with initial parameters. +QDAFN<> qdafn(dataset, 3, 10); + +// Now retrain with different parameters. +qdafn.Train(dataset, 10, 3); +@endcode + +@subsection cpp_ex5_qdafn_akfntut Running on sparse data + +We can set the template parameter for \c QDAFN to \c arma::sp_mat in +order to perform furthest neighbor search on sparse data. This code below +creates a \c QDAFN model using 20 tables and 60 projections with sparse +input data, then searches for 3 approximate furthest neighbors. + +@code +#include + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::sp_mat dataset; +// The query dataset. +extern arma::sp_mat querySet; + +// Construct the model on sparse data. +QDAFN qdafn(dataset, 20, 60); + +// Search on query data. +arma::Mat neighbors; +arma::mat distances; +qdafn.Search(querySet, 3, neighbors, distances); +@endcode + +@section cpp_ns_akfntut KFN C++ class + +The extensive \c NeighborSearch class also provides a way to search for +approximate furthest neighbors using a different, tree-based technique. For +full documentation on this class, see the +\ref nstutorial "NeighborSearch tutorial". The \c KFN class is a convenient +typedef of the \c NeighborSearch class that can be used to perform the furthest +neighbors task with kd-trees. + +In the following subsections, the \c KFN class is used in short code examples. + +@subsection cpp_ex1_ns_akfntut Simple furthest neighbors example + +The \c KFN class has construction semantics similar to \c DrusillaSelect and +\c QDAFN. The example below constructs a \c KFN object (which will build the +tree on the reference set), but note that the third parameter to the constructor +allows us to specify our desired level of approximation. In this example we +choose epsilon = 0.05. Then, the code searches for 3 approximate furthest +neighbors. + +@code +#include + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::mat dataset; +// The query set. +extern arma::mat querySet; + +// Construct the object, performing the default dual-tree search with +// approximation level epsilon = 0.05. +KFN kfn(dataset, KFN::DUAL_TREE_MODE, 0.05); + +// Search for approximate furthest neighbors. +arma::Mat neighbors; +arma::mat distances; +kfn.Search(querySet, 3, neighbors, distances); +@endcode + +@subsection cpp_ex2_ns_akfntut Retraining on a new reference set + +Like the \c QDAFN and \c DrusillaSelect classes, the \c KFN class is capable of +retraining on a new reference set. The code below demonstrates this. + +@code +#include + +using namespace mlpack::neighbor; + +// The original reference set we train on. +extern arma::mat dataset; +// The new reference set we retrain on. +extern arma::mat newDataset; + +// Construct the object with approximation level 0.1. +KFN kfn(dataset, DUAL_TREE_MODE, 0.1); + +// Retrain on the new reference set. +kfn.Train(newDataset); +@endcode + +@subsection cpp_ex3_ns_akfntut Searching in single-tree mode + +The particular mode to be used in search can be specified in the constructor. +In this example, we use single-tree search (as opposed to the default of +dual-tree search). + +@code +#include + +using namespace mlpack::neighbor; + +// The reference set. +extern arma::mat dataset; +// The query set. +extern arma::mat querySet; + +// Construct the object with approximation level 0.25 and in single tree search +// mode. +KFN kfn(dataset, SINGLE_TREE_MODE, 0.25); + +// Search for 5 approximate furthest neighbors. +arma::Mat neighbors; +arma::mat distances; +kfn.Search(querySet, 5, neighbors, distances); +@endcode + +@subsection cpp_ex4_ns_akfntut Searching in brute-force mode + +If desired, brute-force search ("naive search") can be used to find the furthest +neighbors; however, the result will not be approximate---it will be exact (since +every possibility will be considered). The code below performs exact furthest +neighbor search by using the \c KFN class in brute-force mode. + +@code +#include + +using namespace mlpack::neighbor; + +// The reference set. +extern arma::mat dataset; +// The query set. +extern arma::mat querySet; + +// Construct the object in brute-force mode. We can leave the approximation +// parameter to its default (0) since brute-force will provide exact results. +KFN kfn(dataset, NAIVE_MODE); + +// Perform the search for 2 furthest neighbors. +arma::Mat neighbors; +arma::mat distances; +kfn.Search(querySet, 2, neighbors, distances); +@endcode + +@section further_doc_akfntut Further documentation + +For further documentation on the approximate furthest neighbor facilities +offered by \b mlpack, consult the following documentation: + + - \ref nstutorial + - \ref mlpack::neighbor::QDAFN "QDAFN class documentation" + - \ref mlpack::neighbor::DrusillaSelect "DrusillaSelect class documentation" + - \ref mlpack::neighbor::NeighborSearch "NeighborSearch class documentation" + +*/ diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt index dbbd2318bee..f292e9756c9 100644 --- a/src/mlpack/methods/CMakeLists.txt +++ b/src/mlpack/methods/CMakeLists.txt @@ -18,6 +18,7 @@ endmacro () set(DIRS preprocess adaboost + approx_kfn amf ann cf diff --git a/src/mlpack/methods/approx_kfn/CMakeLists.txt b/src/mlpack/methods/approx_kfn/CMakeLists.txt new file mode 100644 index 00000000000..06b729ca557 --- /dev/null +++ b/src/mlpack/methods/approx_kfn/CMakeLists.txt @@ -0,0 +1,22 @@ +# Define the files we need to compile. +# Anything not in this list will not be compiled into mlpack. +set(SOURCES + # DrusillaSelect sources. + drusilla_select.hpp + drusilla_select_impl.hpp + # QDAFN sources. + qdafn.hpp + qdafn_impl.hpp +) + +# Add directory name to sources. +set(DIR_SRCS) +foreach(file ${SOURCES}) + set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) +endforeach() +# Append sources (with directory name) to list of all mlpack sources (used at +# the parent scope). +set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE) + +# This program computes approximate furthest neighbors. +add_cli_executable(approx_kfn) diff --git a/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp b/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp new file mode 100644 index 00000000000..794385039f5 --- /dev/null +++ b/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp @@ -0,0 +1,266 @@ +/** + * @file smarthash_main.cpp + * @author Ryan Curtin + * + * Command-line program for the SmartHash algorithm. + */ +#include +#include +#include "drusilla_select.hpp" +#include "qdafn.hpp" + +using namespace mlpack; +using namespace mlpack::neighbor; +using namespace std; + +PROGRAM_INFO("Approximate furthest neighbor search", + "This program implements two strategies for furthest neighbor search. " + "These strategies are:" + "\n\n" + " - The 'qdafn' algorithm from 'Approximate Furthest Neighbor in High " + "Dimensions' by R. Pagh, F. Silvestri, J. Sivertsen, and M. Skala, in " + "Similarity Search and Applications 2015 (SISAP)." + "\n" + " - The 'DrusillaSelect' algorithm from 'Fast approximate furthest " + "neighbors with data-dependent candidate selection, by R.R. Curtin and A.B." + " Gardner, in Similarity Search and Applications 2016 (SISAP)." + "\n\n" + "These two strategies give approximate results for the furthest neighbor " + "search problem and can be used as fast replacements for other furthest " + "neighbor techniques such as those found in the mlpack_kfn program. Note " + "that typically, the 'ds' algorithm requires far fewer tables and " + "projections than the 'qdafn' algorithm." + "\n\n" + "Specify a reference set (set to search in) with --reference_file, " + "specify a query set with --query_file, and specify algorithm parameters " + "with --num_tables (-t) and --num_projections (-p) (or don't and defaults " + "will be used). The algorithm to be used (either 'ds'---the default---or " + "'qdafn') may be specified with --algorithm. Also specify the number of " + "neighbors to search for with --k. Each of those options also has short " + "names; see the detailed parameter documentation below." + "\n\n" + "If no query file is specified, the reference set will be used as the " + "query set. A model may be saved with --output_model_file (-M), and an " + "input model may be loaded instead of specifying a reference set with " + "--input_model_file (-m)." + "\n\n" + "Results for each query point are stored in the files specified by " + "--neighbors_file and --distances_file. This is in the same format as the " + "mlpack_kfn and mlpack_knn programs: each row holds the k distances or " + "neighbor indices for each query point."); + +PARAM_STRING_IN("reference_file", "File containing reference points.", "r", ""); +PARAM_STRING_IN("query_file", "File containing query points.", "q", ""); + +// Model loading and saving. +PARAM_STRING_IN("input_model_file", "File containing input model.", "m", ""); +PARAM_STRING_OUT("output_model_file", "File to save output model to.", "M"); + +PARAM_INT_IN("k", "Number of furthest neighbors to search for.", "k", 0); + +PARAM_INT_IN("num_tables", "Number of hash tables to use.", "t", 5); +PARAM_INT_IN("num_projections", "Number of projections to use in each hash " + "table.", "p", 5); +PARAM_STRING_IN("algorithm", "Algorithm to use: 'ds' or 'qdafn'.", "a", "ds"); + +PARAM_STRING_IN("neighbors_file", "File to save furthest neighbor indices to.", + "n", ""); +PARAM_STRING_IN("distances_file", "File to save furthest neighbor distances to.", + "d", ""); + +PARAM_FLAG("calculate_error", "If set, calculate the average distance error for" + " the first furthest neighbor only.", "e"); +PARAM_STRING_IN("exact_distances_file", "File containing exact distances to " + "furthest neighbors; this can be used to avoid explicit calculation when " + "--calculate_error is set.", "x", ""); + +// If we save a model we must also save what type it is. +class ApproxKFNModel +{ + public: + int type; + DrusillaSelect<> ds; + QDAFN<> qdafn; + + //! Constructor, which does nothing. + ApproxKFNModel() : type(0), ds(1, 1), qdafn(1, 1) { } + + //! Serialize the model. + template + void Serialize(Archive& ar, const unsigned int /* version */) + { + ar & data::CreateNVP(type, "type"); + if (type == 0) + { + ar & data::CreateNVP(ds, "model"); + } + else + { + ar & data::CreateNVP(qdafn, "model"); + } + } +}; + +int main(int argc, char** argv) +{ + CLI::ParseCommandLine(argc, argv); + + if (!CLI::HasParam("reference_file") && !CLI::HasParam("input_model_file")) + Log::Fatal << "Either --reference_file (-r) or --input_model_file (-m) must" + << " be specified!" << endl; + if (CLI::HasParam("reference_file") && CLI::HasParam("input_model_file")) + Log::Fatal << "Only one of --reference_file (-r) or --input_model_file (-m)" + << " can be specified!" << endl; + if (!CLI::HasParam("output_model_file") && !CLI::HasParam("k")) + Log::Warn << "Neither --output_model_file (-M) nor --k (-k) are specified;" + << " no task will be performed." << endl; + if (!CLI::HasParam("neighbors_file") && !CLI::HasParam("distances_file") && + !CLI::HasParam("output_model_file")) + Log::Warn << "None of --output_model_file (-M), --neighbors_file (-n), or " + << "--distances_file (-d) are specified; no output will be saved!" + << endl; + if (CLI::GetParam("algorithm") != "ds" && + CLI::GetParam("algorithm") != "qdafn") + Log::Fatal << "Unknown algorithm '" << CLI::GetParam("algorithm") + << "'; must be 'ds' or 'qdafn'!" << endl; + if (CLI::HasParam("k") && !(CLI::HasParam("reference_file") || + CLI::HasParam("query_file"))) + Log::Fatal << "If search is being performed, then either --query_file " + << "or --reference_file must be specified!" << endl; + + if (CLI::GetParam("num_tables") <= 0) + Log::Fatal << "Invalid --num_tables value (" + << CLI::GetParam("num_tables") << "); must be greater than 0!" + << endl; + if (CLI::GetParam("num_projections") <= 0) + Log::Fatal << "Invalid --num_projections value (" + << CLI::GetParam("num_projections") << "); must be greater than 0!" + << endl; + + if (CLI::HasParam("calculate_error") && !CLI::HasParam("k")) + Log::Warn << "--calculate_error ignored because --k is not specified." + << endl; + if (CLI::HasParam("exact_distances_file") && + !CLI::HasParam("calculate_error")) + Log::Warn << "--exact_distances_file ignored beceause --calculate_error is " + << "not specified." << endl; + if (CLI::HasParam("calculate_error") && + !CLI::HasParam("exact_distances_file") && + !CLI::HasParam("reference_file")) + Log::Fatal << "Cannot calculate error without either --exact_distances_file" + << " or --reference_file specified!" << endl; + + // Do the building of a model, if necessary. + ApproxKFNModel m; + arma::mat referenceSet; // This may be used at query time. + if (CLI::HasParam("reference_file")) + { + const string referenceFile = CLI::GetParam("reference_file"); + data::Load(referenceFile, referenceSet); + + const size_t numTables = (size_t) CLI::GetParam("num_tables"); + const size_t numProjections = (size_t) CLI::GetParam("num_projections"); + const string algorithm = CLI::GetParam("algorithm"); + + if (algorithm == "ds") + { + Timer::Start("drusilla_select_construct"); + Log::Info << "Building DrusillaSelect model..." << endl; + m.type = 0; + m.ds = DrusillaSelect<>(referenceSet, numTables, numProjections); + Timer::Stop("drusilla_select_construct"); + } + else + { + Timer::Start("qdafn_construct"); + Log::Info << "Building QDAFN model..." << endl; + m.type = 1; + m.qdafn = QDAFN<>(referenceSet, numTables, numProjections); + Timer::Stop("qdafn_construct"); + } + Log::Info << "Model built." << endl; + } + else + { + // We must load the model from file. + const string inputModelFile = CLI::GetParam("input_model_file"); + data::Load(inputModelFile, "approx_kfn", m); + } + + // Now, do we need to do any queries? + if (CLI::HasParam("k")) + { + arma::mat querySet; // This may or may not be used. + const size_t k = (size_t) CLI::GetParam("k"); + + arma::Mat neighbors; + arma::mat distances; + + arma::mat& set = CLI::HasParam("query_file") ? querySet : referenceSet; + if (CLI::HasParam("query_file")) + { + const string queryFile = CLI::GetParam("query_file"); + data::Load(queryFile, querySet); + } + + if (m.type == 0) + { + Timer::Start("drusilla_select_search"); + Log::Info << "Searching for " << k << " furthest neighbors with " + << "DrusillaSelect..." << endl; + m.ds.Search(set, k, neighbors, distances); + Timer::Stop("drusilla_select_search"); + } + else + { + Timer::Start("qdafn_search"); + Log::Info << "Searching for " << k << " furthest neighbors with " + << "QDAFN..." << endl; + m.qdafn.Search(set, k, neighbors, distances); + Timer::Stop("qdafn_search"); + } + Log::Info << "Search complete." << endl; + + // Should we calculate error? + if (CLI::HasParam("calculate_error")) + { + arma::mat exactDistances; + if (CLI::HasParam("exact_distances_file")) + { + data::Load(CLI::GetParam("exact_distances_file"), + exactDistances); + } + else + { + // Calculate exact distances. We are guaranteed the reference set is + // available. + Log::Info << "Calculating exact distances..." << endl; + AllkFN kfn(referenceSet); + arma::Mat exactNeighbors; + kfn.Search(set, 1, exactNeighbors, exactDistances); + Log::Info << "Calculation complete." << endl; + } + + const double averageError = arma::sum(exactDistances.row(0) / + distances.row(0)) / distances.n_cols; + const double minError = arma::min(exactDistances.row(0) / + distances.row(0)); + const double maxError = arma::max(exactDistances.row(0) / + distances.row(0)); + + Log::Info << "Average error: " << averageError << "." << endl; + Log::Info << "Maximum error: " << maxError << "." << endl; + Log::Info << "Minimum error: " << minError << "." << endl; + } + + // Save results, if desired. + if (CLI::HasParam("neighbors_file")) + data::Save(CLI::GetParam("neighbors_file"), neighbors, false); + if (CLI::HasParam("distances_file")) + data::Save(CLI::GetParam("distances_file"), distances, false); + } + + // Should we save the model? + if (CLI::HasParam("output_model_file")) + data::Save(CLI::GetParam("output_model_file"), "approx_kfn", m); +} diff --git a/src/mlpack/methods/approx_kfn/drusilla_select.hpp b/src/mlpack/methods/approx_kfn/drusilla_select.hpp new file mode 100644 index 00000000000..38b90ab8a5c --- /dev/null +++ b/src/mlpack/methods/approx_kfn/drusilla_select.hpp @@ -0,0 +1,125 @@ +/** + * @file drusilla_select.hpp + * @author Ryan Curtin + * + * An implementation of the approximate furthest neighbor algorithm specified in + * the following paper: + * + * @code + * @incollection{curtin2016fast, + * title={Fast approximate furthest neighbors with data-dependent candidate + * selection}, + * author={Curtin, R.R., and Gardner, A.B.}, + * booktitle={Similarity Search and Applications}, + * pages={221--235}, + * year={2016}, + * publisher={Springer} + * } + * @endcode + * + * This algorithm, called DrusillaSelect, constructs a candidate set of points + * to query to find an approximate furthest neighbor. The strange name is a + * result of the algorithm being named after a cat. The cat in question may be + * viewed at http://www.ratml.org/misc_img/drusilla_fence.png. + */ +#ifndef MLPACK_METHODS_APPROX_KFN_DRUSILLA_SELECT_HPP +#define MLPACK_METHODS_APPROX_KFN_DRUSILLA_SELECT_HPP + +#include + +namespace mlpack { +namespace neighbor { + +template +class DrusillaSelect +{ + public: + /** + * Construct the DrusillaSelect object with the given reference set (this is + * the set that will be searched). The resulting set of candidate points that + * will be searched at query time will have size l*m. + * + * @param referenceSet Set of reference data. + * @param l Number of projections. + * @param m Number of elements to store for each projection. + */ + DrusillaSelect(const MatType& referenceSet, + const size_t l, + const size_t m); + + /** + * Construct the DrusillaSelect object with no given reference set. Be sure + * to call Train() before calling Search()! + * + * @param l Number of projections. + * @param m Number of elements to store for each projection. + */ + DrusillaSelect(const size_t l, const size_t m); + + /** + * Build the set of candidate points on the given reference set. If l and m + * are left unspecified, then the values set in the constructor will be used + * instead. + * + * @param referenceSet Set to extract candidate points from. + * @param l Number of projections. + * @param m Number of elements to store for each projection. + */ + void Train(const MatType& referenceSet, + const size_t l = 0, + const size_t m = 0); + + /** + * Search for the k furthest neighbors of the given query set. (The query set + * can contain just one point: that is okay.) The results will be stored in + * the given neighbors and distances matrices, in the same format as the + * NeighborSearch and LSHSearch classes. That is, each column in the + * neighbors and distances matrices will refer to a single query point, and + * the k'th row in that column will refer to the k'th candidate neighbor or + * distance for that query point. + * + * @param querySet Set of query points to search. + * @param k Number of furthest neighbors to search for. + * @param neighbors Matrix to store resulting neighbors in. + * @param distances Matrix to store resulting distances in. + */ + void Search(const MatType& querySet, + const size_t k, + arma::Mat& neighbors, + arma::mat& distances); + + /** + * Serialize the model. + */ + template + void Serialize(Archive& ar, const unsigned int /* version */); + + //! Access the candidate set. + const MatType& CandidateSet() const { return candidateSet; } + //! Modify the candidate set. Be careful! + MatType& CandidateSet() { return candidateSet; } + + //! Access the indices of points in the candidate set. + const arma::Col& CandidateIndices() const { return candidateIndices; } + //! Modify the indices of points in the candidate set. Be careful! + arma::Col& CandidateIndices() { return candidateIndices; } + + private: + //! The reference set. + MatType candidateSet; + //! Indices of each point in the reference set. + arma::Col candidateIndices; + + //! The number of projections. + size_t l; + //! The number of points in each projection. + size_t m; +}; + +} // namespace neighbor +} // namespace mlpack + +// Include implementation. +#include "drusilla_select_impl.hpp" + +#endif diff --git a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp new file mode 100644 index 00000000000..942063b6c08 --- /dev/null +++ b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp @@ -0,0 +1,213 @@ +/** + * @file drusilla_select_impl.hpp + * @author Ryan Curtin + * + * Implementation of DrusillaSelect class methods. + */ +#ifndef MLPACK_METHODS_APPROX_KFN_DRUSILLA_SELECT_IMPL_HPP +#define MLPACK_METHODS_APPROX_KFN_DRUSILLA_SELECT_IMPL_HPP + +// In case it hasn't been included yet. +#include "drusilla_select.hpp" + +#include +#include +#include +#include +#include + +namespace mlpack { +namespace neighbor { + +// Constructor. +template +DrusillaSelect::DrusillaSelect(const MatType& referenceSet, + const size_t l, + const size_t m) : + candidateSet(referenceSet.n_cols, l * m), + candidateIndices(l * m), + l(l), + m(m) +{ + if (l == 0) + throw std::invalid_argument("DrusillaSelect::DrusillaSelect(): invalid " + "value of l; must be greater than 0!"); + else if (m == 0) + throw std::invalid_argument("DrusillaSelect::DrusillaSelect(): invalid " + "value of m; must be greater than 0!"); + + Train(referenceSet, l, m); +} + +// Constructor with no training. +template +DrusillaSelect::DrusillaSelect(const size_t l, const size_t m) : + candidateSet(0, l * m), + candidateIndices(l * m), + l(l), + m(m) +{ + if (l == 0) + throw std::invalid_argument("DrusillaSelect::DrusillaSelect(): invalid " + "value of l; must be greater than 0!"); + else if (m == 0) + throw std::invalid_argument("DrusillaSelect::DrusillaSelect(): invalid " + "value of m; must be greater than 0!"); +} + +// Train the model. +template +void DrusillaSelect::Train( + const MatType& referenceSet, + const size_t lIn, + const size_t mIn) +{ + // Did the user specify a new size? If so, use it. + if (lIn > 0) + l = lIn; + if (mIn > 0) + m = mIn; + + if ((l * m) > referenceSet.n_cols) + throw std::invalid_argument("DrusillaSelect::Train(): l and m are too " + "large! Choose smaller values. l*m must be smaller than the number " + "of points in the dataset."); + + candidateSet.set_size(referenceSet.n_rows, l * m); + candidateIndices.set_size(l * m); + + arma::vec dataMean(arma::mean(referenceSet, 1)); + arma::vec norms(referenceSet.n_cols); + + MatType refCopy(referenceSet.n_rows, referenceSet.n_cols); + for (size_t i = 0; i < refCopy.n_cols; ++i) + { + refCopy.col(i) = referenceSet.col(i) - dataMean; + norms[i] = arma::norm(refCopy.col(i)); + } + + // Find the top m points for each of the l projections... + for (size_t i = 0; i < l; ++i) + { + // Pick best index. + arma::uword maxIndex; + norms.max(maxIndex); + + arma::vec line(refCopy.col(maxIndex) / arma::norm(refCopy.col(maxIndex))); + const size_t n_nonzero = (size_t) arma::sum(norms > 0); + + // Calculate distortion and offset and make scores. + std::vector closeAngle(referenceSet.n_cols, false); + arma::vec sums(referenceSet.n_cols); + for (size_t j = 0; j < referenceSet.n_cols; ++j) + { + if (norms[j] > 0.0) + { + const double offset = arma::dot(refCopy.col(j), line); + const double distortion = arma::norm(refCopy.col(j) - offset * line); + sums[j] = std::abs(offset) - std::abs(distortion); + closeAngle[j] = + (std::atan(distortion / std::abs(offset)) < (M_PI / 8.0)); + } + else + { + sums[j] = norms[j]; + } + } + + // Find the top m elements using a priority queue. + typedef std::pair Candidate; + struct CandidateCmp + { + bool operator()(const Candidate& c1, const Candidate& c2) + { + return c2.first < c1.first; + } + }; + + std::vector clist(m, std::make_pair(double(-1.0), size_t(-1))); + std::priority_queue, CandidateCmp> + pq(CandidateCmp(), std::move(clist)); + + for (size_t j = 0; j < sums.n_elem; ++j) + { + Candidate c = std::make_pair(sums[j], j); + if (CandidateCmp()(c, pq.top())) + { + pq.pop(); + pq.push(c); + } + } + + // Take the top m elements for this table. + for (size_t j = 0; j < m; ++j) + { + const size_t index = pq.top().second; + pq.pop(); + candidateSet.col(i * m + j) = referenceSet.col(index); + candidateIndices[i * m + j] = index; + + // Mark the norm as -1 so we don't see this point again. + norms[index] = -1.0; + } + + // Calculate angles from the current projection. Anything close enough, + // mark the norm as 0. + for (size_t j = 0; j < norms.n_elem; ++j) + if (norms[j] > 0.0 && closeAngle[j]) + norms[j] = 0.0; + } +} + +// Search. +template +void DrusillaSelect::Search(const MatType& querySet, + const size_t k, + arma::Mat& neighbors, + arma::mat& distances) +{ + if (candidateSet.n_cols == 0) + throw std::runtime_error("DrusillaSelect::Search(): candidate set not " + "initialized! Call Train() first."); + + if (k > (l * m)) + throw std::invalid_argument("DrusillaSelect::Search(): requested k is " + "greater than number of points in candidate set! Increase l or m."); + + // We'll use the NeighborSearchRules class to perform our brute-force search. + // Note that we aren't using trees for our search, so we can use 'int' as a + // TreeType. + metric::EuclideanDistance metric; + NeighborSearchRules> + rules(candidateSet, querySet, k, metric, 0, false); + + for (size_t q = 0; q < querySet.n_cols; ++q) + for (size_t r = 0; r < candidateSet.n_cols; ++r) + rules.BaseCase(q, r); + + rules.GetResults(neighbors, distances); + + // Map the neighbors back to their original indices in the reference set. + for (size_t i = 0; i < neighbors.n_elem; ++i) + neighbors[i] = candidateIndices[neighbors[i]]; +} + +//! Serialize the model. +template +template +void DrusillaSelect::Serialize(Archive& ar, + const unsigned int /* version */) +{ + using data::CreateNVP; + + ar & CreateNVP(candidateSet, "candidateSet"); + ar & CreateNVP(candidateIndices, "candidateIndices"); + ar & CreateNVP(l, "l"); + ar & CreateNVP(m, "m"); +} + +} // namespace neighbor +} // namespace mlpack + +#endif diff --git a/src/mlpack/methods/approx_kfn/qdafn.hpp b/src/mlpack/methods/approx_kfn/qdafn.hpp new file mode 100644 index 00000000000..6ba8b81b5db --- /dev/null +++ b/src/mlpack/methods/approx_kfn/qdafn.hpp @@ -0,0 +1,113 @@ +/** + * @file qdafn.hpp + * @author Ryan Curtin + * + * An implementation of the query-dependent approximate furthest neighbor + * algorithm specified in the following paper: + * + * @code + * @incollection{pagh2015approximate, + * title={Approximate furthest neighbor in high dimensions}, + * author={Pagh, R. and Silvestri, F. and Sivertsen, J. and Skala, M.}, + * booktitle={Similarity Search and Applications}, + * pages={3--14}, + * year={2015}, + * publisher={Springer} + * } + * @endcode + */ +#ifndef MLPACK_METHODS_APPROX_KFN_QDAFN_HPP +#define MLPACK_METHODS_APPROX_KFN_QDAFN_HPP + +#include + +namespace mlpack { +namespace neighbor { + +template +class QDAFN +{ + public: + /** + * Construct the QDAFN object but do not train it. Be sure to call Train() + * before calling Search(). + * + * @param l Number of projections. + * @param m Number of elements to store for each projection. + */ + QDAFN(const size_t l, const size_t m); + + /** + * Construct the QDAFN object with the given reference set (this is the set + * that will be searched). + * + * @param referenceSet Set of reference data. + * @param l Number of projections. + * @param m Number of elements to store for each projection. + */ + QDAFN(const MatType& referenceSet, + const size_t l, + const size_t m); + + /** + * Train the QDAFN model on the given reference set, optionally setting new + * parameters for the number of projections/tables (l) and the number of + * elements stored for each projection/table (m). + * + * @param referenceSet Reference set to train on. + * @param l Number of projections. + * @param m Number of elements to store for each projection. + */ + void Train(const MatType& referenceSet, + const size_t l = 0, + const size_t m = 0); + + /** + * Search for the k furthest neighbors of the given query set. (The query set + * can contain just one point, that is okay.) The results will be stored in + * the given neighbors and distances matrices, in the same format as the + * mlpack NeighborSearch and LSHSearch classes. + */ + void Search(const MatType& querySet, + const size_t k, + arma::Mat& neighbors, + arma::mat& distances); + + //! Serialize the model. + template + void Serialize(Archive& ar, const unsigned int /* version */); + + //! Get the number of projections. + size_t NumProjections() const { return candidateSet.size(); } + + //! Get the candidate set for the given projection table. + const MatType& CandidateSet(const size_t t) const { return candidateSet[t]; } + //! Modify the candidate set for the given projection table. Careful! + MatType& CandidateSet(const size_t t) { return candidateSet[t]; } + + private: + //! The number of projections. + size_t l; + //! The number of elements to store for each projection. + size_t m; + //! The random lines we are projecting onto. Has l columns. + arma::mat lines; + //! Projections of each point onto each random line. + arma::mat projections; + + //! Indices of the points for each S. + arma::Mat sIndices; + //! Values of a_i * x for each point in S. + arma::mat sValues; + + // Candidate sets; one element in the vector for each table. + std::vector candidateSet; +}; + +} // namespace neighbor +} // namespace mlpack + +// Include implementation. +#include "qdafn_impl.hpp" + +#endif diff --git a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp new file mode 100644 index 00000000000..8d64f9578bb --- /dev/null +++ b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp @@ -0,0 +1,185 @@ +/** + * @file qdafn_impl.hpp + * @author Ryan Curtin + * + * Implementation of QDAFN class methods. + */ +#ifndef MLPACK_METHODS_APPROX_KFN_QDAFN_IMPL_HPP +#define MLPACK_METHODS_APPROX_KFN_QDAFN_IMPL_HPP + +// In case it hasn't been included yet. +#include "qdafn.hpp" + +#include +#include + +namespace mlpack { +namespace neighbor { + +// Non-training constructor. +template +QDAFN::QDAFN(const size_t l, const size_t m) : l(l), m(m) +{ + if (l > 0) + throw std::invalid_argument("QDAFN::QDAFN(): l must be greater than 0!"); + if (m > 0) + throw std::invalid_argument("QDAFN::QDAFN(): m must be greater than 0!"); +} + +// Constructor. +template +QDAFN::QDAFN(const MatType& referenceSet, + const size_t l, + const size_t m) : + l(l), + m(m) +{ + if (l > 0) + throw std::invalid_argument("QDAFN::QDAFN(): l must be greater than 0!"); + if (m > 0) + throw std::invalid_argument("QDAFN::QDAFN(): m must be greater than 0!"); + + Train(referenceSet); +} + +// Train the object. +template +void QDAFN::Train(const MatType& referenceSet, + const size_t lIn, + const size_t mIn) +{ + if (lIn != 0) + l = lIn; + if (mIn != 0) + m = mIn; + + // Build tables. This is done by drawing random points from a Gaussian + // distribution as the vectors we project onto. The Gaussian should have zero + // mean and unit variance. + mlpack::distribution::GaussianDistribution gd(referenceSet.n_rows); + lines.set_size(referenceSet.n_rows, l); + for (size_t i = 0; i < l; ++i) + lines.col(i) = gd.Random(); + + // Now, project each of the reference points onto each line, and collect the + // top m elements. + projections = referenceSet.t() * lines; + + // Loop over each projection and find the top m elements. + sIndices.set_size(m, l); + sValues.set_size(m, l); + candidateSet.resize(l); + for (size_t i = 0; i < l; ++i) + { + candidateSet[i].set_size(referenceSet.n_rows, m); + arma::uvec sortedIndices = arma::sort_index(projections.col(i), "descend"); + + // Grab the top m elements. + for (size_t j = 0; j < m; ++j) + { + sIndices(j, i) = sortedIndices[j]; + sValues(j, i) = projections(sortedIndices[j], i); + candidateSet[i].col(j) = referenceSet.col(sortedIndices[j]); + } + } +} + +// Search. +template +void QDAFN::Search(const MatType& querySet, + const size_t k, + arma::Mat& neighbors, + arma::mat& distances) +{ + if (k > m) + throw std::invalid_argument("QDAFN::Search(): requested k is greater than " + "value of m!"); + + neighbors.set_size(k, querySet.n_cols); + neighbors.fill(size_t() - 1); + distances.zeros(k, querySet.n_cols); + + // Search for each point. + for (size_t q = 0; q < querySet.n_cols; ++q) + { + // Initialize a priority queue. + // The size_t represents the index of the table, and the double represents + // the value of l_i * S_i - l_i * query (see line 6 of Algorithm 1). + std::priority_queue> queue; + for (size_t i = 0; i < l; ++i) + { + const double val = sValues(0, i) - arma::dot(querySet.col(q), + lines.col(i)); + queue.push(std::make_pair(val, i)); + } + + // To track where we are in each S table, we keep the next index to look at + // in each table (they start at 0). + arma::Col tableLocations = arma::zeros>(l); + + // Now that the queue is initialized, iterate over m elements. + std::vector> v(k, std::make_pair(-1.0, + size_t(-1))); + std::priority_queue> + resultsQueue(std::less>(), std::move(v)); + for (size_t i = 0; i < m; ++i) + { + std::pair p = queue.top(); + queue.pop(); + + // Get index of reference point to look at. + const size_t tableIndex = tableLocations[p.second]; + + // Calculate distance from query point. + const double dist = mlpack::metric::EuclideanDistance::Evaluate( + querySet.col(q), candidateSet[p.second].col(tableIndex)); + + // Is this neighbor good enough to insert into the results? + if (dist > resultsQueue.top().first) + { + resultsQueue.pop(); + resultsQueue.push(std::make_pair(dist, sIndices(tableIndex, p.second))); + } + + // Now (line 14) get the next element and insert into the queue. Do this + // by adjusting the previous value. Don't insert anything if we are at + // the end of the search, though. + if (i < m - 1) + { + tableLocations[p.second]++; + const double val = p.first - sValues(tableIndex, p.second) + + sValues(tableIndex + 1, p.second); + + queue.push(std::make_pair(val, p.second)); + } + } + + // Extract the results. + for (size_t j = 1; j <= k; ++j) + { + neighbors(k - j, q) = resultsQueue.top().second; + distances(k - j, q) = resultsQueue.top().first; + resultsQueue.pop(); + } + } +} + +template +template +void QDAFN::Serialize(Archive& ar, const unsigned int /* version */) +{ + using data::CreateNVP; + + ar & CreateNVP(l, "l"); + ar & CreateNVP(m, "m"); + ar & CreateNVP(lines, "lines"); + ar & CreateNVP(projections, "projections"); + ar & CreateNVP(sIndices, "sIndices"); + ar & CreateNVP(sValues, "sValues"); + ar & CreateNVP(candidateSet, "candidateSet"); +} + +} // namespace neighbor +} // namespace mlpack + +#endif diff --git a/src/mlpack/methods/approx_kfn/qdafn_main.cpp b/src/mlpack/methods/approx_kfn/qdafn_main.cpp new file mode 100644 index 00000000000..0e866473e24 --- /dev/null +++ b/src/mlpack/methods/approx_kfn/qdafn_main.cpp @@ -0,0 +1,98 @@ +/** + * @file qdafn_main.cpp + * @author Ryan Curtin + * + * Command-line program for the QDAFN algorithm. + */ +#include +#include "qdafn.hpp" +#include + +using namespace qdafn; +using namespace mlpack; +using namespace std; + +PROGRAM_INFO("Query-dependent approximate furthest neighbor search", + "This program implements the algorithm from the SISAP 2015 paper titled " + "'Approximate Furthest Neighbor in High Dimensions' by R. Pagh, F. " + "Silvestri, J. Sivertsen, and M. Skala. Specify a reference set (set to " + "search in) with --reference_file, specify a query set (set to search for) " + "with --query_file, and specify algorithm parameters with --num_tables and " + "--num_projections (or don't, and defaults will be used). Also specify " + "the number of points to search for with --k. Each of those options has " + "short names too; see the detailed parameter documentation below." + "\n\n" + "Results for each query point are stored in the files specified by " + "--neighbors_file and --distances_file. This is in the same format as the " + "mlpack KFN and KNN programs: each row holds the k distances or neighbor " + "indices for each query point."); + +PARAM_STRING_REQ("reference_file", "File containing reference points.", "r"); +PARAM_STRING_REQ("query_file", "File containing query points.", "q"); + +PARAM_INT_REQ("k", "Number of furthest neighbors to search for.", "k"); + +PARAM_INT("num_tables", "Number of hash tables to use.", "t", 10); +PARAM_INT("num_projections", "Number of projections to use in each hash table.", + "p", 30); + +PARAM_STRING("neighbors_file", "File to save furthest neighbor indices to.", + "n", ""); +PARAM_STRING("distances_file", "File to save furthest neighbor distances to.", + "d", ""); + +PARAM_FLAG("calculate_error", "If set, calculate the average distance error.", + "e"); + +int main(int argc, char** argv) +{ + CLI::ParseCommandLine(argc, argv); + + const string referenceFile = CLI::GetParam("reference_file"); + const string queryFile = CLI::GetParam("query_file"); + const size_t k = (size_t) CLI::GetParam("k"); + const size_t numTables = (size_t) CLI::GetParam("num_tables"); + const size_t numProjections = (size_t) CLI::GetParam("num_projections"); + + // Load the data. + arma::mat referenceData, queryData; + data::Load(referenceFile, referenceData, true); + data::Load(queryFile, queryData, true); + + // Construct the object. + Timer::Start("qdafn_construct"); + QDAFN<> q(referenceData, numTables, numProjections); + Timer::Stop("qdafn_construct"); + + // Do the search. + arma::Mat neighbors; + arma::mat distances; + Timer::Start("qdafn_search"); + q.Search(queryData, k, neighbors, distances); + Timer::Stop("qdafn_search"); + + // Print the number of base cases. + Log::Info << "Total distance evaluations: " << + (queryData.n_cols * numProjections) << "." << endl; + + if (CLI::HasParam("calculate_error")) + { + neighbor::AllkFN kfn(referenceData); + + arma::Mat trueNeighbors; + arma::mat trueDistances; + + kfn.Search(queryData, 1, trueNeighbors, trueDistances); + + const double averageError = arma::sum(trueDistances / distances.row(0)) / + distances.n_cols; + + Log::Info << "Average error: " << averageError << "." << endl; + } + + // Save the results. + if (CLI::HasParam("neighbors_file")) + data::Save(CLI::GetParam("neighbors_file"), neighbors); + if (CLI::HasParam("distances_file")) + data::Save(CLI::GetParam("distances_file"), distances); +} diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt index 9ad40927c8f..9f6965b614d 100644 --- a/src/mlpack/tests/CMakeLists.txt +++ b/src/mlpack/tests/CMakeLists.txt @@ -18,6 +18,7 @@ add_executable(mlpack_test decision_stump_test.cpp det_test.cpp distribution_test.cpp + drusilla_select_test.cpp emst_test.cpp fastmks_test.cpp feedforward_network_test.cpp @@ -59,6 +60,7 @@ add_executable(mlpack_test nystroem_method_test.cpp pca_test.cpp perceptron_test.cpp + qdafn_test.cpp quic_svd_test.cpp radical_test.cpp randomized_svd_test.cpp diff --git a/src/mlpack/tests/drusilla_select_test.cpp b/src/mlpack/tests/drusilla_select_test.cpp new file mode 100644 index 00000000000..cce2704a15d --- /dev/null +++ b/src/mlpack/tests/drusilla_select_test.cpp @@ -0,0 +1,165 @@ +/** + * @file drusilla_select_test.cpp + * @author Ryan Curtin + * + * Test for DrusillaSelect. + */ +#include +#include + +#include +#include "test_tools.hpp" +#include "serialization.hpp" + +using namespace mlpack; +using namespace mlpack::neighbor; + +BOOST_AUTO_TEST_SUITE(DrusillaSelectTest); + +// If we have a dataset with an extreme outlier, then every point (except that +// one) should end up with that point as the furthest neighbor candidate. +BOOST_AUTO_TEST_CASE(DrusillaSelectExtremeOutlierTest) +{ + arma::mat dataset = arma::randu(5, 100); + dataset.col(99) += 100; // Make last column very large. + + // Construct with some reasonable parameters. + DrusillaSelect<> ds(dataset, 5, 5); + + // Query with every point except the extreme point. + arma::mat distances; + arma::Mat neighbors; + ds.Search(dataset.cols(0, 98), 1, neighbors, distances); + + BOOST_REQUIRE_EQUAL(neighbors.n_cols, 99); + BOOST_REQUIRE_EQUAL(neighbors.n_rows, 1); + BOOST_REQUIRE_EQUAL(distances.n_cols, 99); + BOOST_REQUIRE_EQUAL(distances.n_rows, 1); + + for (size_t i = 0; i < 99; ++i) + { + BOOST_REQUIRE_EQUAL(neighbors[i], 99); + } +} + +// If we use only one projection with the number of points equal to what is in +// the dataset, we should end up with the exact result. +BOOST_AUTO_TEST_CASE(DrusillaSelectExhaustiveExactTest) +{ + arma::mat dataset = arma::randu(5, 100); + + // Construct with one projection and 100 points in that projection. + DrusillaSelect<> ds(dataset, 100, 1); + + arma::mat distances, distancesTrue; + arma::Mat neighbors, neighborsTrue; + + ds.Search(dataset, 5, neighbors, distances); + + AllkFN kfn(dataset); + kfn.Search(dataset, 5, neighborsTrue, distancesTrue); + + BOOST_REQUIRE_EQUAL(neighborsTrue.n_cols, neighbors.n_cols); + BOOST_REQUIRE_EQUAL(neighborsTrue.n_rows, neighbors.n_rows); + BOOST_REQUIRE_EQUAL(distancesTrue.n_cols, distances.n_cols); + BOOST_REQUIRE_EQUAL(distancesTrue.n_rows, distances.n_rows); + + for (size_t i = 0; i < distances.n_elem; ++i) + { + BOOST_REQUIRE_EQUAL(neighbors[i], neighborsTrue[i]); + BOOST_REQUIRE_CLOSE(distances[i], distancesTrue[i], 1e-5); + } +} + +// Test that we can call Train() after calling the constructor. +BOOST_AUTO_TEST_CASE(RetrainTest) +{ + arma::mat firstDataset = arma::randu(3, 10); + arma::mat dataset = arma::randu(3, 200); + + DrusillaSelect<> ds(firstDataset, 3, 3); + ds.Train(std::move(dataset), 2, 2); + + arma::mat distances; + arma::Mat neighbors; + ds.Search(dataset, 1, neighbors, distances); + + BOOST_REQUIRE_EQUAL(neighbors.n_cols, 200); + BOOST_REQUIRE_EQUAL(neighbors.n_rows, 1); + BOOST_REQUIRE_EQUAL(distances.n_cols, 200); + BOOST_REQUIRE_EQUAL(distances.n_rows, 1); +} + +// Test serialization. +BOOST_AUTO_TEST_CASE(SerializationTest) +{ + // Create a random dataset. + arma::mat dataset = arma::randu(3, 100); + + DrusillaSelect<> ds(dataset, 3, 3); + + arma::mat fakeDataset1 = arma::randu(2, 15); + arma::mat fakeDataset2 = arma::randu(10, 18); + DrusillaSelect<> dsXml(fakeDataset1, 5, 3); + DrusillaSelect<> dsText(2, 2); + DrusillaSelect<> dsBinary(5, 2); + dsBinary.Train(fakeDataset2); + + // Now do the serialization. + SerializeObjectAll(ds, dsXml, dsText, dsBinary); + + // Now do a search and make sure all the results are the same. + arma::Mat neighbors, neighborsXml, neighborsText, neighborsBinary; + arma::mat distances, distancesXml, distancesText, distancesBinary; + + ds.Search(dataset, 3, neighbors, distances); + dsXml.Search(dataset, 3, neighborsXml, distancesXml); + dsText.Search(dataset, 3, neighborsText, distancesText); + dsBinary.Search(dataset, 3, neighborsBinary, distancesBinary); + + BOOST_REQUIRE_EQUAL(neighbors.n_rows, neighborsXml.n_rows); + BOOST_REQUIRE_EQUAL(neighbors.n_cols, neighborsXml.n_cols); + BOOST_REQUIRE_EQUAL(neighbors.n_rows, neighborsText.n_rows); + BOOST_REQUIRE_EQUAL(neighbors.n_cols, neighborsText.n_cols); + BOOST_REQUIRE_EQUAL(neighbors.n_rows, neighborsBinary.n_rows); + BOOST_REQUIRE_EQUAL(neighbors.n_cols, neighborsBinary.n_cols); + + BOOST_REQUIRE_EQUAL(distances.n_rows, distancesXml.n_rows); + BOOST_REQUIRE_EQUAL(distances.n_cols, distancesXml.n_cols); + BOOST_REQUIRE_EQUAL(distances.n_rows, distancesText.n_rows); + BOOST_REQUIRE_EQUAL(distances.n_cols, distancesText.n_cols); + BOOST_REQUIRE_EQUAL(distances.n_rows, distancesBinary.n_rows); + BOOST_REQUIRE_EQUAL(distances.n_cols, distancesBinary.n_cols); + + for (size_t i = 0; i < neighbors.n_elem; ++i) + { + BOOST_REQUIRE_EQUAL(neighbors[i], neighborsXml[i]); + BOOST_REQUIRE_EQUAL(neighbors[i], neighborsText[i]); + BOOST_REQUIRE_EQUAL(neighbors[i], neighborsBinary[i]); + + BOOST_REQUIRE_CLOSE(distances[i], distancesXml[i], 1e-5); + BOOST_REQUIRE_CLOSE(distances[i], distancesText[i], 1e-5); + BOOST_REQUIRE_CLOSE(distances[i], distancesBinary[i], 1e-5); + } +} + +// Make sure we can create the object with a sparse matrix. +BOOST_AUTO_TEST_CASE(SparseTest) +{ + arma::sp_mat dataset; + dataset.sprandu(50, 1000, 0.3); + + DrusillaSelect ds(dataset, 5, 10); + + // Run a search. + arma::mat distances; + arma::Mat neighbors; + ds.Search(dataset, 3, neighbors, distances); + + BOOST_REQUIRE_EQUAL(neighbors.n_cols, 1000); + BOOST_REQUIRE_EQUAL(neighbors.n_rows, 3); + BOOST_REQUIRE_EQUAL(distances.n_cols, 1000); + BOOST_REQUIRE_EQUAL(distances.n_rows, 3); +} + +BOOST_AUTO_TEST_SUITE_END(); diff --git a/src/mlpack/tests/qdafn_test.cpp b/src/mlpack/tests/qdafn_test.cpp new file mode 100644 index 00000000000..332b7c7e81b --- /dev/null +++ b/src/mlpack/tests/qdafn_test.cpp @@ -0,0 +1,206 @@ +/** + * @file qdafn_test.cpp + * @author Ryan Curtin + * + * Test the QDAFN functionality. + */ +#include +#include "test_tools.hpp" +#include "serialization.hpp" + +#include +#include +#include + +using namespace std; +using namespace arma; +using namespace mlpack; +using namespace mlpack::neighbor; + +BOOST_AUTO_TEST_SUITE(QDAFNTest); + +/** + * With one reference point, make sure that is the one that is returned. + */ +BOOST_AUTO_TEST_CASE(QDAFNTrivialTest) +{ + arma::mat refSet(5, 1); + refSet.randu(); + + // 5 tables, 1 point. + QDAFN<> qdafn(refSet, 5, 1); + + arma::mat querySet(5, 5); + querySet.randu(); + + arma::Mat neighbors; + arma::mat distances; + qdafn.Search(querySet, 1, neighbors, distances); + + // Check sizes. + BOOST_REQUIRE_EQUAL(neighbors.n_rows, 1); + BOOST_REQUIRE_EQUAL(neighbors.n_cols, 5); + BOOST_REQUIRE_EQUAL(distances.n_rows, 1); + BOOST_REQUIRE_EQUAL(distances.n_cols, 5); + + for (size_t i = 0; i < 5; ++i) + { + BOOST_REQUIRE_EQUAL(neighbors[i], 0); + const double dist = metric::EuclideanDistance::Evaluate(querySet.col(i), + refSet.col(0)); + BOOST_REQUIRE_CLOSE(distances[i], dist, 1e-5); + } +} + +/** + * Given a random uniform reference set, ensure that we get a neighbor and + * distance within 10% of the actual true furthest neighbor distance at least + * 70% of the time. + */ +BOOST_AUTO_TEST_CASE(QDAFNUniformSet) +{ + arma::mat uniformSet = arma::randu(25, 1000); + + QDAFN<> qdafn(uniformSet, 10, 30); + + // Get the actual neighbors. + AllkFN kfn(uniformSet); + arma::Mat trueNeighbors; + arma::mat trueDistances; + + kfn.Search(1000, trueNeighbors, trueDistances); + + arma::Mat qdafnNeighbors; + arma::mat qdafnDistances; + + qdafn.Search(uniformSet, 1, qdafnNeighbors, qdafnDistances); + + BOOST_REQUIRE_EQUAL(qdafnNeighbors.n_rows, 1); + BOOST_REQUIRE_EQUAL(qdafnNeighbors.n_cols, 1000); + BOOST_REQUIRE_EQUAL(qdafnDistances.n_rows, 1); + BOOST_REQUIRE_EQUAL(qdafnDistances.n_cols, 1000); + + size_t successes = 0; + for (size_t i = 0; i < 1000; ++i) + { + // Find the true neighbor. + size_t trueIndex = 1000; + for (size_t j = 0; j < 1000; ++j) + { + if (trueNeighbors(j, i) == qdafnNeighbors(0, i)) + { + trueIndex = j; + break; + } + } + + BOOST_REQUIRE_NE(trueIndex, 1000); + if (0.9 * trueDistances(0, i) <= qdafnDistances(0, i)) + ++successes; + } + + BOOST_REQUIRE_GE(successes, 700); +} + +/** + * Test re-training method. + */ +BOOST_AUTO_TEST_CASE(RetrainTest) +{ + arma::mat dataset = arma::randu(25, 500); + arma::mat newDataset = arma::randu(15, 600); + + QDAFN<> qdafn(dataset, 20, 60); + + qdafn.Train(newDataset, 10, 50); + + BOOST_REQUIRE_EQUAL(qdafn.NumProjections(), 10); + for (size_t i = 0; i < 10; ++i) + { + BOOST_REQUIRE_EQUAL(qdafn.CandidateSet(i).n_rows, 15); + BOOST_REQUIRE_EQUAL(qdafn.CandidateSet(i).n_cols, 50); + } +} + +/** + * Test serialization of QDAFN. + */ +BOOST_AUTO_TEST_CASE(SerializationTest) +{ + // Use a random dataset. + arma::mat dataset = arma::randu(15, 300); + + QDAFN<> qdafn(dataset, 10, 50); + + arma::mat fakeDataset1 = arma::randu(10, 200); + arma::mat fakeDataset2 = arma::randu(50, 500); + QDAFN<> qdafnXml(fakeDataset1, 5, 10); + QDAFN<> qdafnText(6, 50); + QDAFN<> qdafnBinary(7, 15); + qdafnBinary.Train(fakeDataset2); + + // Serialize the objects. + SerializeObjectAll(qdafn, qdafnXml, qdafnText, qdafnBinary); + + // Check that the tables are all the same. + BOOST_REQUIRE_EQUAL(qdafnXml.NumProjections(), qdafn.NumProjections()); + BOOST_REQUIRE_EQUAL(qdafnText.NumProjections(), qdafn.NumProjections()); + BOOST_REQUIRE_EQUAL(qdafnBinary.NumProjections(), qdafn.NumProjections()); + + for (size_t i = 0; i < qdafn.NumProjections(); ++i) + { + BOOST_REQUIRE_EQUAL(qdafnXml.CandidateSet(i).n_rows, + qdafn.CandidateSet(i).n_rows); + BOOST_REQUIRE_EQUAL(qdafnText.CandidateSet(i).n_rows, + qdafn.CandidateSet(i).n_rows); + BOOST_REQUIRE_EQUAL(qdafnBinary.CandidateSet(i).n_rows, + qdafn.CandidateSet(i).n_rows); + + BOOST_REQUIRE_EQUAL(qdafnXml.CandidateSet(i).n_cols, + qdafn.CandidateSet(i).n_cols); + BOOST_REQUIRE_EQUAL(qdafnText.CandidateSet(i).n_cols, + qdafn.CandidateSet(i).n_cols); + BOOST_REQUIRE_EQUAL(qdafnBinary.CandidateSet(i).n_cols, + qdafn.CandidateSet(i).n_cols); + + for (size_t j = 0; j < qdafn.CandidateSet(i).n_elem; ++j) + { + if (std::abs(qdafn.CandidateSet(i)[j]) < 1e-5) + { + BOOST_REQUIRE_SMALL(qdafnXml.CandidateSet(i)[j], 1e-5); + BOOST_REQUIRE_SMALL(qdafnText.CandidateSet(i)[j], 1e-5); + BOOST_REQUIRE_SMALL(qdafnBinary.CandidateSet(i)[j], 1e-5); + } + else + { + const double value = qdafn.CandidateSet(i)[j]; + BOOST_REQUIRE_CLOSE(qdafnXml.CandidateSet(i)[j], value, 1e-5); + BOOST_REQUIRE_CLOSE(qdafnText.CandidateSet(i)[j], value, 1e-5); + BOOST_REQUIRE_CLOSE(qdafnBinary.CandidateSet(i)[j], value, 1e-5); + } + } + } +} + +// Make sure QDAFN works with sparse data. +BOOST_AUTO_TEST_CASE(SparseTest) +{ + arma::sp_mat dataset; + dataset.sprandu(200, 1000, 0.3); + + // Create a sparse version. + QDAFN sparse(dataset, 15, 50); + + // Make sure the results are of the right shape. It's hard to test anything + // more than that because we don't have easy-to-check performance guarantees. + arma::Mat neighbors; + arma::mat distances; + sparse.Search(dataset, 3, neighbors, distances); + + BOOST_REQUIRE_EQUAL(neighbors.n_rows, 3); + BOOST_REQUIRE_EQUAL(neighbors.n_cols, 1000); + BOOST_REQUIRE_EQUAL(distances.n_rows, 3); + BOOST_REQUIRE_EQUAL(distances.n_cols, 1000); +} + +BOOST_AUTO_TEST_SUITE_END();