From e3b7b4e31a16ec87bb87ffc8d1dc0b948ea35599 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Wed, 13 Apr 2016 08:08:29 -0700 Subject: [PATCH 01/30] Add README. --- README.txt | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 README.txt diff --git a/README.txt b/README.txt new file mode 100644 index 00000000000..eb1bbf8ab2b --- /dev/null +++ b/README.txt @@ -0,0 +1,16 @@ +This repository contains an implementation of the hashing algorithm for +approximate furthest neighbor search detailed in the paper + +"Approximate Furthest Neighbor in High Dimensions" +by Rasmus Pagh, Francesco Silverstri, Johan Siversten, and Matthew Skala +presented at SISAP 2015. + +There is another implementation available here: +https://github.com/johanvts/FN-Implementations + +but I wanted to re-implement this to ensure that I understood it correctly, and +so that I could get a better comparison. + +This code is built using mlpack and Armadillo, so when you configure with CMake +you may have to specify the installation directory of mlpack and Armadillo, if +they are not already installed on the system. From b18b24ea628dc0193d291992601d3466cfd9dec9 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Wed, 13 Apr 2016 08:16:22 -0700 Subject: [PATCH 02/30] Add basic CMake configuration. --- CMakeLists.txt | 252 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 252 insertions(+) create mode 100644 CMakeLists.txt diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 00000000000..289630c1974 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,252 @@ +# Much of this is borrowed from mlpack's CMakeLists.txt. +cmake_minimum_required(VERSION 2.8.5) +project(qdafn C CXX) + +# Ensure that we have a C++11 compiler. +include(CMake/CXX11.cmake) +check_for_cxx11_compiler(HAS_CXX11) +if(NOT HAS_CXX11) + message(FATAL_ERROR "No C++11 compiler available!") +endif() +enable_cxx11() + +# Define compilation options. 
+option(DEBUG "Compile with debugging information" ON) +option(PROFILE "Compile with profiling information" ON) + +# Set the CFLAGS and CXXFLAGS depending on the options the user specified. +# Only GCC-like compilers support -Wextra, and other compilers give tons of +# output for -Wall, so only -Wall and -Wextra on GCC. +if(CMAKE_COMPILER_IS_GNUCC OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -ftemplate-depth=1000") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -Wextra") +endif() + +# If using clang, we have to link against libc++ depending on the +# OS (at least on some systems). Further, gcc sometimes optimizes calls to +# math.h functions, making -lm unnecessary with gcc, but it may still be +# necessary with clang. +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + if (APPLE) + # detect OS X version. Use '/usr/bin/sw_vers -productVersion' to + # extract V from '10.V.x'.) + exec_program(/usr/bin/sw_vers ARGS + -productVersion OUTPUT_VARIABLE MACOSX_VERSION_RAW) + string(REGEX REPLACE + "10\\.([0-9]+).*" "\\1" + MACOSX_VERSION + "${MACOSX_VERSION_RAW}") + + # OSX Lion (10.7) and OS X Mountain Lion (10.8) doesn't automatically + # select the right stdlib. + if(${MACOSX_VERSION} LESS 9) + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++") + set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} +-stdlib=libc++") + set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} +-stdlib=libc++") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++") + endif() + endif() + + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -lm") + set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -lm") + set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} -lm") +endif() + +# Debugging CFLAGS. Turn optimizations off; turn debugging symbols on. 
+if(DEBUG) + add_definitions(-DDEBUG) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0 -ftemplate-backtrace-limit=0") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=c99 -g -O0") + # mlpack uses it's own mlpack::backtrace class based on Binary File Descriptor + # and linux Dynamic Loader and more portable version in + # future + if(CMAKE_SYSTEM_NAME STREQUAL "Linux") + find_package(Bfd) + find_package(LibDL) + if(LIBBFD_FOUND AND LIBDL_FOUND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic") + include_directories(${LIBBFD_INCLUDE_DIRS}) + include_directories(${LIBDL_INCLUDE_DIRS}) + add_definitions(-DHAS_BFD_DL) + else() + message(WARNING "No libBFD and/or libDL has been found!") + endif() + endif() +else() + add_definitions(-DARMA_NO_DEBUG) + add_definitions(-DNDEBUG) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=c99 -O3") +endif() + +# Profiling CFLAGS. Turn profiling information on. +if(CMAKE_COMPILER_IS_GNUCC AND PROFILE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pg") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pg") + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -pg") +endif() + +# Find dependencies. This isn't very robust. +find_package(Armadillo 4.100.0 REQUIRED) + +# If Armadillo was compiled without ARMA_64BIT_WORD and we are on a 64-bit +# system (where size_t will be 64 bits), suggest to the user that they should +# compile Armadillo with 64-bit words. Note that with Armadillo 5.000.0 and +# newer, ARMA_64BIT_WORD is enabled by default. +if(CMAKE_SIZEOF_VOID_P EQUAL 8) + # Check the version, to see if ARMA_64BIT_WORD is enabled by default. + set(ARMA_HAS_64BIT_WORD 0) + if(NOT (${ARMADILLO_VERSION_MAJOR} LESS 5)) + set(ARMA_HAS_64BIT_WORD 1) + else() + # Can we open the configuration file? If not, issue a warning. 
+ if(NOT EXISTS "${ARMADILLO_INCLUDE_DIR}/armadillo_bits/config.hpp") + message(WARNING "Armadillo configuration file " + "(${ARMADILLO_INCLUDE_DIR}/armadillo_bits/config.hpp) does not +exist!") + else() + # We are on a 64-bit system. Does Armadillo have ARMA_64BIT_WORD enabled? + file(READ "${ARMADILLO_INCLUDE_DIR}/armadillo_bits/config.hpp" +ARMA_CONFIG) + string(REGEX MATCH + "[\r\n][ ]*#define ARMA_64BIT_WORD" + ARMA_HAS_64BIT_WORD_PRE + "${ARMA_CONFIG}") + + string(LENGTH "${ARMA_HAS_64BIT_WORD_PRE}" ARMA_HAS_64BIT_WORD) + endif() + endif() + + if(ARMA_HAS_64BIT_WORD EQUAL 0) + message(WARNING "This is a 64-bit system, but Armadillo was compiled " + "without 64-bit index support. Consider recompiling Armadillo with " + "ARMA_64BIT_WORD to enable 64-bit indices (large matrix support). " + "mlpack will still work without ARMA_64BIT_WORD defined, but will not " + "scale to matrices with more than 4 billion elements.") + endif() +else() + # If we are on a 32-bit system, we must manually specify the size of the word + # to be 32 bits, since otherwise Armadillo will produce a warning that it is + # disabling 64-bit support. + if (CMAKE_SIZEOF_VOID_P EQUAL 4) + add_definitions(-DARMA_32BIT_WORD) + endif () +endif() + + +# On Windows, Armadillo should be using LAPACK and BLAS but we still need to +# link against it. We don't want to use the FindLAPACK or FindBLAS modules +# because then we are required to have a FORTRAN compiler (argh!) so we will try +# and find LAPACK and BLAS ourselves, using a slightly modified variant of the +# script Armadillo uses to find these. 
+if (WIN32) + find_library(LAPACK_LIBRARY + NAMES lapack liblapack lapack_win32_MT lapack_win32 + PATHS "C:/Program Files/Armadillo" + PATH_SUFFIXES "examples/lib_win32/") + + if (NOT LAPACK_LIBRARY) + message(FATAL_ERROR "Cannot find LAPACK library (.lib)!") + endif () + + find_library(BLAS_LIBRARY + NAMES blas libblas blas_win32_MT blas_win32 + PATHS "C:/Program Files/Armadillo" + PATH_SUFFIXES "examples/lib_win32/") + + if (NOT BLAS_LIBRARY) + message(FATAL_ERROR "Cannot find BLAS library (.lib)!") + endif () + + # Piggyback LAPACK and BLAS linking into Armadillo link. + set(ARMADILLO_LIBRARIES + ${ARMADILLO_LIBRARIES} ${BLAS_LIBRARY} ${LAPACK_LIBRARY}) +endif () + +# Include directories for the previous dependencies. +include_directories(${ARMADILLO_INCLUDE_DIRS}) + +# Unfortunately this configuration variable is necessary and will need to be +# updated as time goes on and new versions are released. +set(Boost_ADDITIONAL_VERSIONS + "1.49.0" "1.50.0" "1.51.0" "1.52.0" "1.53.0" "1.54.0" "1.55.0") +find_package(Boost 1.49 + COMPONENTS + program_options + unit_test_framework + serialization + REQUIRED +) +include_directories(${Boost_INCLUDE_DIRS}) + +link_directories(${Boost_LIBRARY_DIRS}) + +# In Visual Studio, automatic linking is performed, so we don't need to worry +# about it. Clear the list of libraries to link against and let Visual Studio +# handle it. +if (MSVC) + link_directories(${Boost_LIBRARY_DIRS}) + set(Boost_LIBRARIES "") +endif () + +# For Boost testing framework (will have no effect on non-testing executables). +# This specifies to Boost that we are dynamically linking to the Boost test +# library. +add_definitions(-DBOOST_TEST_DYN_LINK) + +# On Windows, things end up under Debug/ or Release/. 
+if (WIN32) + set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) + set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) +else () + # If not on Windows, put them under more standard UNIX-like places. This is + # necessary, otherwise they would all end up in + # ${CMAKE_BINARY_DIR}/src/mlpack/methods/... or somewhere else random like + # that. + set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib/) + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin/) + set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib/) +endif () + +# Find the mlpack library and include directory. +find_library(MLPACK_LIBRARY + NAMES mlpack + PATHS /usr/lib64 /usr/lib /usr/local/lib64 /usr/local/lib +) + +find_path(MLPACK_INCLUDE_DIR mlpack/core.hpp + /usr/include/ + /usr/local/include/ +) + +if (MLPACK_LIBRARY and MLPACK_INCLUDE_DIR) + mark_as_advanced(MLPACK_LIBRARY MLPACK_INCLUDE_DIR) + include_directories(${MLPACK_INCLUDE_DIR}) +else () + message(FATAL_ERROR "Could not find mlpack; try specifying MLPACK_LIBRARY and" + " MLPACK_INCLUDE_DIR") +endif () + +# Finally! Definitions of the files we are building. +add_executable(qdafn + qdafn_main.cpp + qdafn.hpp + qdafn_impl.hpp +) +target_link_libraries(qdafn + ${MLPACK_LIBRARY} + ${Boost_LIBRARIES} + ${ARMADILLO_LIBRARIES} +) + +add_executable(qdafn_test + qdafn_test.cpp +) +target_link_libraries(qdafn_test + ${MLPACK_LIBRARY} + ${Boost_LIBRARIES} + ${ARMADILLO_LIBRARIES} +) From 06b46e87e22b2fc94ec8e1c1d64d742e4f35236a Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Wed, 13 Apr 2016 12:47:14 -0700 Subject: [PATCH 03/30] Add implementation, not yet tested. 
--- qdafn.hpp | 83 ++++++++++++++++++++++++++++ qdafn_impl.hpp | 147 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 230 insertions(+) create mode 100644 qdafn.hpp create mode 100644 qdafn_impl.hpp diff --git a/qdafn.hpp b/qdafn.hpp new file mode 100644 index 00000000000..557b421eeb4 --- /dev/null +++ b/qdafn.hpp @@ -0,0 +1,83 @@ +/** + * @file qdafn.hpp + * @author Ryan Curtin + * + * An implementation of the query-dependent approximate furthest neighbor + * algorithm specified in the following paper: + * + * @code + * @incollection{pagh2015approximate, + * title={Approximate furthest neighbor in high dimensions}, + * author={Pagh, R. and Silvestri, F. and Sivertsen, J. and Skala, M.}, + * booktitle={Similarity Search and Applications}, + * pages={3--14}, + * year={2015}, + * publisher={Springer} + * } + * @endcode + */ +#ifndef QDAFN_HPP +#define QDAFN_HPP + +#include + +namespace qdafn { + +template +class QDAFN +{ + public: + /** + * Construct the QDAFN object with the given reference set (this is the set + * that will be searched). + * + * @param referenceSet Set of reference data. + * @param l Number of projections. + * @param m Number of elements to store for each projection. + */ + QDAFN(const MatType& referenceSet, + const size_t l, + const size_t m); + + /** + * Search for the k furthest neighbors of the given query set. (The query set + * can contain just one point, that is okay.) The results will be stored in + * the given neighbors and distances matrices, in the same format as the + * mlpack NeighborSearch and LSHSearch classes. + */ + void Search(const MatType& querySet, + const size_t k, + arma::Mat& neighbors, + arma::mat& distances); + + private: + //! The reference set. + const MatType& referenceSet; + + //! The number of projections. + const size_t l; + //! The number of elements to store for each projection. + const size_t m; + //! The random lines we are projecting onto. Has l columns. + arma::mat lines; + + //! 
Indices of the points for each S. + arma::Mat sIndices; + //! Values of a_i * x for each point in S. + arma::mat sValues; + + //! Insert a neighbor into a set of results for a given query point. + void InsertNeighbor(arma::mat& distances, + arma::Mat& neighbors, + const size_t queryIndex, + const size_t pos, + const size_t neighbor, + const double distance) const; +}; + +} // namespace qdafn + +// Include implementation. +#include "qdafn_impl.hpp" + +#endif diff --git a/qdafn_impl.hpp b/qdafn_impl.hpp new file mode 100644 index 00000000000..1b8cfaa41f7 --- /dev/null +++ b/qdafn_impl.hpp @@ -0,0 +1,147 @@ +/** + * @file qdafn_impl.hpp + * @author Ryan Curtin + * + * Implementation of QDAFN class methods. + */ +#ifndef QDAFN_IMPL_HPP +#define QDAFN_IMPL_HPP + +// In case it hasn't been included yet. +#include "qdafn.hpp" + +namespace qdafn { + +// Constructor. +template +QDAFN::QDAFN(const MatType& referenceSet, + const size_t l, + const size_t m) : + referenceSet(referenceSet), + l(l), + m(m) +{ + // Build tables. This is done by drawing random points from a Gaussian + // distribution as the vectors we project onto. The Gaussian should have zero + // mean and unit variance. + mlpack::distribution::GaussianDistribution gd(referenceSet.n_rows); + lines.set_size(referenceSet.n_rows, l); + for (size_t i = 0; i < l; ++i) + lines.col(i) = gd.Random(); + + // Now, project each of the reference points onto each line, and collect the + // top m elements. + arma::mat projections = lines.t() * referenceSet; + + // Loop over each projection and find the top m elements. + sIndices.set_size(m, l); + sValues.set_size(m, l); + for (size_t i = 0; i < l; ++i) + { + arma::uvec sortedIndices = arma::sort_index(projections.col(i), "descend"); + + // Grab the top m elements. + for (size_t j = 0; j < m; ++j) + { + sIndices[j] = sortedIndices[j]; + sValues[j] = projections(sortedIndices[j], i); + } + } +} + +// Search. 
+template +void QDAFN::Search(const MatType& querySet, + const size_t k, + arma::Mat& neighbors, + arma::mat& distances) +{ + if (k > m) + throw std::invalid_argument("QDAFN::Search(): requested k is greater than " + "value of m!"); + + neighbors.set_size(k, querySet.n_cols); + distances.zeros(k, querySet.n_cols); + + // Search for each point. + for (size_t q = 0; q < querySet.n_cols; ++q) + { + // Initialize a priority queue. + // The size_t represents the index of the table, and the double represents + // the value of l_i * S_i - l_i * query (see line 6 of Algorithm 1). + std::priority_queue> queue; + for (size_t i = 0; i < l; ++i) + { + const double val = projections(0, i) - arma::dot(querySet.col(q), + lines.col(i)); + queue.push(std::make_pair(val, i)); + } + + // To track where we are in each S table, we keep the next index to look at + // in each table (they start at 0). + arma::Col tableLocations = arma::zeros>(l); + + // Now that the queue is initialized, iterate over m elements. + for (size_t i = 0; i < m; ++i) + { + std::pair p = queue.top(); + queue.pop(); + + // Get index of reference point to look at. + size_t referenceIndex = sIndices(tableLocations[p.second], p.second); + + // Calculate distance from query point. + const double dist = mlpack::metric::EuclideanDistance::Evaluate( + querySet.col(q), referenceSet.col(referenceIndex)); + + // Is this neighbor good enough to insert into the results? + arma::vec queryDist = distances.unsafe_col(q); + arma::Col queryIndices = neighbors.unsafe_col(q); + const size_t insertPosition = + mlpack::neighbor::FurthestNeighborSort::SortDistance(queryDist, + queryIndices, dist); + + // SortDistance() returns (size_t() - 1) if we shouldn't add it. + if (insertPosition != (size_t() - 1)) + InsertNeighbor(distances, neighbors, q, referenceIndex, dist); + + // Now (line 14) get the next element and insert into the queue. Do this + // by adjusting the previous value. 
+ tableLocations[p.second]++; + const double val = p.first - + projections(tableLocations[p.second] - 1, p.second) + + projections(tableLocations[p.second], p.second); + + queue.push(std::make_pair(val, p.second)); + } + } +} + +template +void QDAFN::InsertNeighbor(arma::mat& distances, + arma::Mat& neighbors, + const size_t queryIndex, + const size_t pos, + const size_t neighbor, + const double distance) const +{ + // We only memmove() if there is actually a need to shift something. + if (pos < (distances.n_rows - 1)) + { + const size_t len = (distances.n_rows - 1) - pos; + memmove(distances.colptr(queryIndex) + (pos + 1), + distances.colptr(queryIndex) + pos, + sizeof(double) * len); + memmove(neighbors.colptr(queryIndex) + (pos + 1), + neighbors.colptr(queryIndex) + pos, + sizeof(size_t) * len); + } + + // Now put the new information in the right index. + distances(pos, queryIndex) = distance; + neighbors(pos, queryIndex) = neighbor; +} + +} // namespace qdafn + +#endif From d5c51fce4d33d00cbfa68798a6f13630ae02c9a5 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Wed, 13 Apr 2016 12:51:54 -0700 Subject: [PATCH 04/30] Add utility script to check C++11 support. --- CXX11.cmake | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 CXX11.cmake diff --git a/CXX11.cmake b/CXX11.cmake new file mode 100644 index 00000000000..2dbfcc4b7bb --- /dev/null +++ b/CXX11.cmake @@ -0,0 +1,48 @@ +# This is cloned from +# https://github.com/nitroshare/CXX11-CMake-Macros +# until C++11 support finally hits CMake stable (should be 3.1, I think). 
+ +# Copyright (c) 2013 Nathan Osman + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +# Determines whether or not the compiler supports C++11 +macro(check_for_cxx11_compiler _VAR) + message(STATUS "Checking for C++11 compiler") + set(${_VAR}) + if((MSVC AND (MSVC10 OR MSVC11 OR MSVC12 OR MSVC14)) OR + (CMAKE_COMPILER_IS_GNUCXX AND NOT ${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.6) OR + (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND NOT ${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.1) OR + (CMAKE_CXX_COMPILER_ID STREQUAL "Intel" AND NOT ${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 12.0)) + set(${_VAR} 1) + message(STATUS "Checking for C++11 compiler - available") + else() + message(STATUS "Checking for C++11 compiler - unavailable") + endif() +endmacro() + +# Sets the appropriate flag to enable C++11 support +macro(enable_cxx11) + if(CMAKE_COMPILER_IS_GNUCXX OR + CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR + CMAKE_CXX_COMPILER_ID STREQUAL "Intel") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++0x") + endif() +endmacro() + From 452943f059afbcaabe0c143676f35d91f54a7733 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Wed, 13 Apr 2016 13:04:33 -0700 Subject: [PATCH 05/30] Fix some bugs with the trivial test. --- qdafn.hpp | 4 +++- qdafn_impl.hpp | 31 ++++++++++++++++++++----------- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/qdafn.hpp b/qdafn.hpp index 557b421eeb4..860acb83562 100644 --- a/qdafn.hpp +++ b/qdafn.hpp @@ -23,7 +23,7 @@ namespace qdafn { -template +template class QDAFN { public: @@ -60,6 +60,8 @@ class QDAFN const size_t m; //! The random lines we are projecting onto. Has l columns. arma::mat lines; + //! Projections of each point onto each random line. + arma::mat projections; //! Indices of the points for each S. arma::Mat sIndices; diff --git a/qdafn_impl.hpp b/qdafn_impl.hpp index 1b8cfaa41f7..368b84c3dc1 100644 --- a/qdafn_impl.hpp +++ b/qdafn_impl.hpp @@ -10,6 +10,9 @@ // In case it hasn't been included yet. #include "qdafn.hpp" +#include +#include + namespace qdafn { // Constructor. 
@@ -31,7 +34,7 @@ QDAFN::QDAFN(const MatType& referenceSet, // Now, project each of the reference points onto each line, and collect the // top m elements. - arma::mat projections = lines.t() * referenceSet; + projections = referenceSet.t() * lines; // Loop over each projection and find the top m elements. sIndices.set_size(m, l); @@ -43,8 +46,8 @@ QDAFN::QDAFN(const MatType& referenceSet, // Grab the top m elements. for (size_t j = 0; j < m; ++j) { - sIndices[j] = sortedIndices[j]; - sValues[j] = projections(sortedIndices[j], i); + sIndices(j, i) = sortedIndices[j]; + sValues(j, i) = projections(sortedIndices[j], i); } } } @@ -61,6 +64,7 @@ void QDAFN::Search(const MatType& querySet, "value of m!"); neighbors.set_size(k, querySet.n_cols); + neighbors.fill(size_t() - 1); distances.zeros(k, querySet.n_cols); // Search for each point. @@ -103,16 +107,21 @@ void QDAFN::Search(const MatType& querySet, // SortDistance() returns (size_t() - 1) if we shouldn't add it. if (insertPosition != (size_t() - 1)) - InsertNeighbor(distances, neighbors, q, referenceIndex, dist); + InsertNeighbor(distances, neighbors, q, insertPosition, referenceIndex, + dist); // Now (line 14) get the next element and insert into the queue. Do this - // by adjusting the previous value. - tableLocations[p.second]++; - const double val = p.first - - projections(tableLocations[p.second] - 1, p.second) + - projections(tableLocations[p.second], p.second); - - queue.push(std::make_pair(val, p.second)); + // by adjusting the previous value. Don't insert anything if we are at + // the end of the search, though. 
+ if (i < m - 1) + { + tableLocations[p.second]++; + const double val = p.first - + projections(tableLocations[p.second] - 1, p.second) + + projections(tableLocations[p.second], p.second); + + queue.push(std::make_pair(val, p.second)); + } } } } From bfdfdb9b22c53fe2473d5851d3abbaab0bc46248 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Wed, 13 Apr 2016 13:04:43 -0700 Subject: [PATCH 06/30] Minor changes for better configuration. --- CMakeLists.txt | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 289630c1974..d01fca78d34 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 2.8.5) project(qdafn C CXX) # Ensure that we have a C++11 compiler. -include(CMake/CXX11.cmake) +include(CXX11.cmake) check_for_cxx11_compiler(HAS_CXX11) if(NOT HAS_CXX11) message(FATAL_ERROR "No C++11 compiler available!") @@ -212,6 +212,8 @@ else () endif () # Find the mlpack library and include directory. +set(MLPACK_LIBRARY "") +set(MLPACK_INCLUDE_DIR "") find_library(MLPACK_LIBRARY NAMES mlpack PATHS /usr/lib64 /usr/lib /usr/local/lib64 /usr/local/lib @@ -222,7 +224,7 @@ find_path(MLPACK_INCLUDE_DIR mlpack/core.hpp /usr/local/include/ ) -if (MLPACK_LIBRARY and MLPACK_INCLUDE_DIR) +if (MLPACK_LIBRARY AND MLPACK_INCLUDE_DIR) mark_as_advanced(MLPACK_LIBRARY MLPACK_INCLUDE_DIR) include_directories(${MLPACK_INCLUDE_DIR}) else () From 443028c17710e77caea5eeddbcdc2419791c8c45 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Wed, 13 Apr 2016 13:04:53 -0700 Subject: [PATCH 07/30] First test case. --- qdafn_test.cpp | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 qdafn_test.cpp diff --git a/qdafn_test.cpp b/qdafn_test.cpp new file mode 100644 index 00000000000..83c1da43cb6 --- /dev/null +++ b/qdafn_test.cpp @@ -0,0 +1,50 @@ +/** + * @file qdafn_test.cpp + * @author Ryan Curtin + * + * Test the QDAFN functionality. 
+ */ +#define BOOST_TEST_MODULE QDAFNTest + +#include + +#include +#include "qdafn.hpp" + +using namespace std; +using namespace arma; +using namespace mlpack; +using namespace qdafn; + +/** + * With one reference point, make sure that is the one that is returned. + */ +BOOST_AUTO_TEST_CASE(QDAFNTrivialTest) +{ + arma::mat refSet(5, 1); + refSet.randu(); + + // 5 tables, 1 point. + QDAFN<> qdafn(refSet, 5, 1); + + arma::mat querySet(5, 5); + querySet.randu(); + + arma::Mat neighbors; + arma::mat distances; + qdafn.Search(querySet, 1, neighbors, distances); + + // Check sizes. + BOOST_REQUIRE_EQUAL(neighbors.n_rows, 1); + BOOST_REQUIRE_EQUAL(neighbors.n_cols, 5); + BOOST_REQUIRE_EQUAL(distances.n_rows, 1); + BOOST_REQUIRE_EQUAL(distances.n_cols, 5); + + for (size_t i = 0; i < 5; ++i) + { + BOOST_REQUIRE_EQUAL(neighbors[i], 0); + const double dist = metric::EuclideanDistance::Evaluate(querySet.col(i), + refSet.col(0)); + BOOST_REQUIRE_CLOSE(distances[i], dist, 1e-5); + } +} From 3f682d51724cf4eb01a06616c10feab79729898a Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Wed, 13 Apr 2016 13:05:00 -0700 Subject: [PATCH 08/30] Nothing here yet, but required for CMake to configure correctly... --- qdafn_main.cpp | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 qdafn_main.cpp diff --git a/qdafn_main.cpp b/qdafn_main.cpp new file mode 100644 index 00000000000..e69de29bb2d From 19c6965d55f76793702b22257b2db17b9c955e93 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Tue, 19 Apr 2016 13:18:28 -0700 Subject: [PATCH 09/30] A better test. I'm reasonably convinced this works right now. 
--- qdafn_test.cpp | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/qdafn_test.cpp b/qdafn_test.cpp index 83c1da43cb6..ee106102f22 100644 --- a/qdafn_test.cpp +++ b/qdafn_test.cpp @@ -10,11 +10,13 @@ #include #include "qdafn.hpp" +#include using namespace std; using namespace arma; using namespace mlpack; using namespace qdafn; +using namespace mlpack::neighbor; /** * With one reference point, make sure that is the one that is returned. @@ -48,3 +50,53 @@ BOOST_AUTO_TEST_CASE(QDAFNTrivialTest) BOOST_REQUIRE_CLOSE(distances[i], dist, 1e-5); } } + +/** + * Given a random uniform reference set, ensure that we get a neighbor and + * distance within 10% of the actual true furthest neighbor distance at least + * 70% of the time. + */ +BOOST_AUTO_TEST_CASE(QDAFNUniformSet) +{ + arma::mat uniformSet = arma::randu(25, 1000); + + QDAFN<> qdafn(uniformSet, 10, 30); + + // Get the actual neighbors. + AllkFN kfn(uniformSet); + arma::Mat trueNeighbors; + arma::mat trueDistances; + + kfn.Search(1000, trueNeighbors, trueDistances); + + arma::Mat qdafnNeighbors; + arma::mat qdafnDistances; + + qdafn.Search(uniformSet, 1, qdafnNeighbors, qdafnDistances); + + BOOST_REQUIRE_EQUAL(qdafnNeighbors.n_rows, 1); + BOOST_REQUIRE_EQUAL(qdafnNeighbors.n_cols, 1000); + BOOST_REQUIRE_EQUAL(qdafnDistances.n_rows, 1); + BOOST_REQUIRE_EQUAL(qdafnDistances.n_cols, 1000); + + size_t successes = 0; + for (size_t i = 0; i < 1000; ++i) + { + // Find the true neighbor. + size_t trueIndex = 1000; + for (size_t j = 0; j < 1000; ++j) + { + if (trueNeighbors(j, i) == qdafnNeighbors(0, i)) + { + trueIndex = j; + break; + } + } + + BOOST_REQUIRE_NE(trueIndex, 1000); + if (0.9 * trueDistances(0, i) <= qdafnDistances(0, i)) + ++successes; + } + + BOOST_REQUIRE_GE(successes, 700); +} From 680f59bab89929098c8890b4499cd7303727fc5e Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Tue, 19 Apr 2016 13:43:17 -0700 Subject: [PATCH 10/30] Don't return duplicate points. 
--- qdafn_impl.hpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/qdafn_impl.hpp b/qdafn_impl.hpp index 368b84c3dc1..47a698cde0e 100644 --- a/qdafn_impl.hpp +++ b/qdafn_impl.hpp @@ -104,9 +104,18 @@ void QDAFN::Search(const MatType& querySet, const size_t insertPosition = mlpack::neighbor::FurthestNeighborSort::SortDistance(queryDist, queryIndices, dist); + bool found = false; + for (size_t j = 0; j < neighbors.n_rows; ++j) + { + if (neighbors(j, q) == referenceIndex) + { + found = true; + break; + } + } // SortDistance() returns (size_t() - 1) if we shouldn't add it. - if (insertPosition != (size_t() - 1)) + if (insertPosition != (size_t() - 1) && !found) InsertNeighbor(distances, neighbors, q, insertPosition, referenceIndex, dist); From 921c45646bbf4ea9cd51c14376a2fc03f6aaa125 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Tue, 19 Apr 2016 13:43:30 -0700 Subject: [PATCH 11/30] Add main program. --- qdafn_main.cpp | 75 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/qdafn_main.cpp b/qdafn_main.cpp index e69de29bb2d..47969c81a77 100644 --- a/qdafn_main.cpp +++ b/qdafn_main.cpp @@ -0,0 +1,75 @@ +/** + * @file qdafn_main.cpp + * @author Ryan Curtin + * + * Command-line program for the QDAFN algorithm. + */ +#include +#include "qdafn.hpp" + +using namespace qdafn; +using namespace mlpack; +using namespace std; + +PROGRAM_INFO("Query-dependent approximate furthest neighbor search", + "This program implements the algorithm from the SISAP 2015 paper titled " + "'Approximate Furthest Neighbor in High Dimensions' by R. Pagh, F. " + "Silvestri, J. Sivertsen, and M. Skala. Specify a reference set (set to " + "search in) with --reference_file, specify a query set (set to search for) " + "with --query_file, and specify algorithm parameters with --num_tables and " + "--num_projections (or don't, and defaults will be used). Also specify " + "the number of points to search for with --k. 
Each of those options has " + "short names too; see the detailed parameter documentation below." + "\n\n" + "Results for each query point are stored in the files specified by " + "--neighbors_file and --distances_file. This is in the same format as the " + "mlpack KFN and KNN programs: each row holds the k distances or neighbor " + "indices for each query point."); + +PARAM_STRING_REQ("reference_file", "File containing reference points.", "r"); +PARAM_STRING_REQ("query_file", "File containing query points.", "q"); + +PARAM_INT_REQ("k", "Number of furthest neighbors to search for.", "k"); + +PARAM_INT("num_tables", "Number of hash tables to use.", "t", 10); +PARAM_INT("num_projections", "Number of projections to use in each hash table.", + "p", 30); + +PARAM_STRING("neighbors_file", "File to save furthest neighbor indices to.", + "n", ""); +PARAM_STRING("distances_file", "File to save furthest neighbor distances to.", + "d", ""); + +int main(int argc, char** argv) +{ + CLI::ParseCommandLine(argc, argv); + + const string referenceFile = CLI::GetParam("reference_file"); + const string queryFile = CLI::GetParam("query_file"); + const size_t k = (size_t) CLI::GetParam("k"); + const size_t numTables = (size_t) CLI::GetParam("num_tables"); + const size_t numProjections = (size_t) CLI::GetParam("num_projections"); + + // Load the data. + arma::mat referenceData, queryData; + data::Load(referenceFile, referenceData, true); + data::Load(queryFile, queryData, true); + + // Construct the object. + Timer::Start("qdafn_construct"); + QDAFN<> q(referenceData, numTables, numProjections); + Timer::Stop("qdafn_construct"); + + // Do the search. + arma::Mat neighbors; + arma::mat distances; + Timer::Start("qdafn_search"); + q.Search(queryData, k, neighbors, distances); + Timer::Stop("qdafn_search"); + + // Save the results. 
+ if (CLI::HasParam("neighbors_file")) + data::Save(CLI::GetParam("neighbors_file"), neighbors); + if (CLI::HasParam("distances_file")) + data::Save(CLI::GetParam("distances_file"), distances); +} From 5249e985a26a7a5cb00ba261db89c1c0e9c5f24e Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Tue, 19 Apr 2016 20:17:47 -0400 Subject: [PATCH 12/30] Add flag to print test error. --- qdafn_main.cpp | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/qdafn_main.cpp b/qdafn_main.cpp index 47969c81a77..0e866473e24 100644 --- a/qdafn_main.cpp +++ b/qdafn_main.cpp @@ -6,6 +6,7 @@ */ #include #include "qdafn.hpp" +#include using namespace qdafn; using namespace mlpack; @@ -40,6 +41,9 @@ PARAM_STRING("neighbors_file", "File to save furthest neighbor indices to.", PARAM_STRING("distances_file", "File to save furthest neighbor distances to.", "d", ""); +PARAM_FLAG("calculate_error", "If set, calculate the average distance error.", + "e"); + int main(int argc, char** argv) { CLI::ParseCommandLine(argc, argv); @@ -67,6 +71,25 @@ int main(int argc, char** argv) q.Search(queryData, k, neighbors, distances); Timer::Stop("qdafn_search"); + // Print the number of base cases. + Log::Info << "Total distance evaluations: " << + (queryData.n_cols * numProjections) << "." << endl; + + if (CLI::HasParam("calculate_error")) + { + neighbor::AllkFN kfn(referenceData); + + arma::Mat trueNeighbors; + arma::mat trueDistances; + + kfn.Search(queryData, 1, trueNeighbors, trueDistances); + + const double averageError = arma::sum(trueDistances / distances.row(0)) / + distances.n_cols; + + Log::Info << "Average error: " << averageError << "." << endl; + } + // Save the results. if (CLI::HasParam("neighbors_file")) data::Save(CLI::GetParam("neighbors_file"), neighbors); From 5e0db4c90f96fe58ca31a2723a1a4e686043950d Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Mon, 24 Oct 2016 16:27:10 +0900 Subject: [PATCH 13/30] Add DrusillaSelect implementation. 
--- src/mlpack/methods/CMakeLists.txt | 1 + .../approx_kfn/.drusilla_select.hpp.swp | Bin 0 -> 16384 bytes .../approx_kfn/.drusilla_select_impl.hpp.swo | Bin 0 -> 45056 bytes .../approx_kfn/.drusilla_select_impl.hpp.swp | Bin 0 -> 20480 bytes src/mlpack/methods/approx_kfn/CMakeLists.txt | 20 ++ .../methods/approx_kfn/drusilla_select.hpp | 125 +++++++++++ .../approx_kfn/drusilla_select_impl.hpp | 210 ++++++++++++++++++ .../approx_kfn/drusilla_select_main.cpp | 100 +++++++++ src/mlpack/tests/CMakeLists.txt | 1 + src/mlpack/tests/drusilla_select_test.cpp | 145 ++++++++++++ 10 files changed, 602 insertions(+) create mode 100644 src/mlpack/methods/approx_kfn/.drusilla_select.hpp.swp create mode 100644 src/mlpack/methods/approx_kfn/.drusilla_select_impl.hpp.swo create mode 100644 src/mlpack/methods/approx_kfn/.drusilla_select_impl.hpp.swp create mode 100644 src/mlpack/methods/approx_kfn/CMakeLists.txt create mode 100644 src/mlpack/methods/approx_kfn/drusilla_select.hpp create mode 100644 src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp create mode 100644 src/mlpack/methods/approx_kfn/drusilla_select_main.cpp create mode 100644 src/mlpack/tests/drusilla_select_test.cpp diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt index dbbd2318bee..f292e9756c9 100644 --- a/src/mlpack/methods/CMakeLists.txt +++ b/src/mlpack/methods/CMakeLists.txt @@ -18,6 +18,7 @@ endmacro () set(DIRS preprocess adaboost + approx_kfn amf ann cf diff --git a/src/mlpack/methods/approx_kfn/.drusilla_select.hpp.swp b/src/mlpack/methods/approx_kfn/.drusilla_select.hpp.swp new file mode 100644 index 0000000000000000000000000000000000000000..ae44b2878c4201185817841f1f41ab0b2f854f98 GIT binary patch literal 16384 zcmeI2`-@#y6~|9o(^hMuwHBcw*-0>+8D{RygEVFGAd|_{CX+EUNvo7z&fL52+><$v z>pADnWHQG70s4yw75u6ClchzhApYPFzEBie!Iu6Y6tuL6DAGppQ3>d0?S0Oz^n#*4$10@3+GSEqzV~=ii_V3%f zN91>N(M>>N(M>>N(M>>N(TPF7)aZ9I#1K3JF{+`%KqOm 
z@c*vt_rdJ{fr01Q&+=c%K*>PKK*>PKK*>PKK*>PKK*>PKK*>PKK*_-WAOmj0aelSM zakkupah(6>`~RCCa-83QpM#%*6g&w&1ExV0ys^!3egVD)G?)U1!A|fAa6i}vez29h z1Yjq4^MmLG&w&KQ;BjynOo9p+19yWP_c_kn;CJA+;LG3`xEI_7Uj2aMd<#Ut1Bbvv z;2-aIoWFw`;2R(X5%?6i4{QOi-s?E8fS15^Pyq^j9J~$}-vd`c49wc%9)y{!S z6*Vtt`*O2lr&Fz}vze3>>OhC7n|e`Lfh)z+&<%9L)eWue4UWOmPvwb_ zEx0vLji{@p1?|>bRxxI%nM$Fer^70NAg`_qRoPF`mA164ZVfi-&{C*b)OFQxQ}?`f zL$OaRk?(n3d4`qUb|DQ4nv+p`IhU93xjc__tnnP7tEU~Q=6hk+##s!D-LMe_wp=Y~ z$idh+y26bjFPv8qyjPSu<0i_-5B*LMCSr*7)Xb5B+zL?`c7iz_i!EGBl1Y*%hE4`R z_pq6@AhTycnZLHO+$P4*bjcdP2+S_G^=_fV_IXC9y$XtQ%LQHY7Tr{j3;V_^XLRhj zKC@?NBCfaaq}@<3n)R;}CfJ5_2}I?)Oyh z^BwwB2=&^)IHlsL`f3`x^;FfF7#>rvJ9pzKkbzSp<_|uRq}pxRA~SD%-Aole)7Evb z>FLI(jCg2$x8c9;!iDvi zVK^bg!xd!`jb^lQ+~FNodQx7OLZM-zzke#4lNbiIX-kFCIy@aGMs{M|B}w&v`kXo# zba|2Olkmmm`|Ws7xIO;W^&olO^!j=3oQWZyV90C4zHC#EkY^sZ+(noqiYk93n2We3pJFsyhM2k>SRu+HtQ^}KJKhIb>B^r zfsu1ohxhDtLvSlX^-;&^F8uQ4ms;s5w5vgiy35SNwyavmxn*^Ul3fkm)Ku-n#Y;11&z+m7&73@Ua&oqIW_lWpP1a@2+pJ|? zO}x!Ky<~^8m8R|SYISL8sS>+s;8&t}z8ZK*o%Qd0_o8uA78R9tIFEabWJ8xPI$VD~ zikWKxb5@2^NWoQ_DQf|j37lpF&NqpGI8Y)BG8MSX)DV5l{miW*1r<~dmA&rlv2tPY zz@6(BwTeYyrjZcfhVv{uD+>(XV|VFwjKF{{QX zT(Q^VrTa~k8wQSL+=;Nv&nD8*l6XqnZR>bk$|pkAzqKly%iMZj#|^r}tco`;E`~y5 zm?cYH}#K6-z)q7{Cxjk?C;+N-vqMf-v)lRjXg0K z0e6Etz{~9EzXC3SkAeHaAK1IU1fB&e;2KE5XTdBu2>!F0Lv7yA4HILY@D zU=nNtTfsjm|GsbnLhvLw1jfOG;7{!N%TLKb$w0|K$-sMvfz2jCHzl|0mQt$)R+kLC zm)Fn2QuRNm-J8R`9W}Qjo2QL$v(7EWjliwhUZ1+&&JU)lH0wr=vzuAOD6 zZ_}rJ3BNq+uuj!N)~5RGUTZS)^XzhW>A8N=wLfVqE6(l?m*u@|aP!?<(FrI(gb= z$KQb#n}OJzK~{DXW=+(wgv>3iZ_+GDW@e1`Z~>)j@*q^N#ldE}x?U^mb*`_io{sgm zCH9vYVL5*)_+29-T%=C()2T$nWAAD=$=^-sbY(n~=VyzTWfJp5$wsxiGJ~}jdUCs~ zvI6$F!Q`ssyuoI&@^a=IwP^i3>d{jva%!#E7Nx)c+?m;# z&AyL7$`Nz-oZs8GbMM@__kU;Z-2ean@4dsl7q42Q-duQNhR?~F%!lqg>5|#+S(@p6 zJd^Rpty;0v^cqf0LMQwTjaS!Lr46oI8FL!l1HB6`y!gEHmz{s^#=-O{Z%-uH<9nr^ zYNc+K#=7~h^{c09H%8sE-(%J5p1Z3!Hd5;;drjY|RIH+JSL{-w=v3>K!f3soPTu@7 z5-<{&NCM5qNcXuXW!|`OL9axa(=}V2ddfu;#Wu%{1dIfX1dIfX1dIfX1dIfX1dIfJ 
z3nkE~pOAS6<$8How&TO^m$iK#*A{+4TlnvUzh{Ntm=c0dlui_XW_bLpw(wVlzrFCg zyDj{c;qNFPJKDmdd`IDpw(wVkzh4u6H`~Hzguh=KeqYfR9{x06Mgm3xMgm3xMgm3x zMgm3xMgm3xMgm3xMglKM30UP!W)@2ET2YKq+9N&sS$_X3JOGE_YPbp{-o}s~zE6cy z;AQaXS7$PR0e=qff_K7pxERiWr%uXb4#JJ_LD&oLht;qEX2CPRo5}nF_QPHnhLtcE zetBXhb3fb%x4<7k5jMd97(1gkNINc^DpqJD~&@ zz#oEw$l#+?&33kq4!fRT7{SLP>pNH2#YQ%#y48Z*qD(T}w;M$X&W`2P_NMKP57~_Z z9wKhVCniz#r)!5*Y1+yiQPtU9K@v(eMm=|@@*Cy8KBq>wQ!ZNG)@IeNHL_hpRm*Fv zwEWS&z9HN4N~76apYm)j;WzBE8dHw1wtBYJusss8YO0HCIU_1NrutPim)XtFi9eQC zwfg$1R%5OzJAT8em25vZjWgC*jm26R{c~GF%1?S`$*uVfrO#4h`OFU2DXX}6i%D|B zc-`)g%cg=<( z>ARgmZ?MLsSm4%GRpCq}4c{<(kHQMfPkO(URSIQ77)NR@HJ)nk{fk!BRjpk# zwf3$sW+3MR2eWMfT4K$eu2<%zrq5ZFkt(LVgHCBIc!X&DNK0OEQhCXi*^xVPk#cC6 zD^d}T)S=F5+_!401G-H-Rh1xTTqNBf6bhjf^_}6!9Y4X{$vdiBmFWYlv8`H=+5{Pk z5L)sjUWlbnyQpOyy{kI~ zaj!_fqC*h*TJ7>~b$8HJ6;UO+FTuS(!oAVTRs04ngnig*mhPjN{;RK zjD#J4Vk;W&E%6{vCW*e$wEGiR?us*19H~j{F6n%)s;QDi=hJB@@+zIhVcV|p(39Oc z&Ir%`h+4CHaOuDW#Wl+}tX#KjsJL`+aQ(VViWjU{TU@sOqM=o*S1&CNEnmHSU_)`$ zn!(k@m4kyBHNOYcD*F9FugsE;6K$nh7COHV1-|d}d@bKsp_4p~46m8_>aA82r+H7U z$75d&Xpen4j7CR~b|UZjp84}LqW?dR-rfIdt^Y^k-@itm{}ucY9u?gl#zA!cZg?$- z?*9YMyA`g5ozR2^Y=#xE93<^uagA@lpTRY-0$vUepx5t#h43`G`&Z#B@FjQ@MAyF) zu7TBX4m^o2{}cEQJO-bIkHNKY8LWrX;aT+hpTUE06KsS9@H9I7UC;m@-UvTIU;hjo zfUDplSOZz;g6GiJzX?ykC*T8MLm9eZA()SmfRTWaz;CSt=AhU)BQ)63-)+6O^}D$J zX!Ne-g*wgiCW|&%*x2Ehmgv(tWg2mGlxak)Sy80u`$*(;q~}amsJd%KWgO`R0;8}t zF-BGCXX~z?f1wE-an3vChyR=pMG}CBBi~iH!@^`X=MzPzfZ>knNasPWqAqm}%OeDr56IizX0yFdiPY;I9 zjz${#)|z7$qYIfG1MkLv4-^SCH9(nR)>$-cF9q*?Qw ztu^dXGKSPMAN!N%i&j6sC#Q!7)4F|Xe9&_kC1tF2Ddr4khdmc7S|__o7~DVkh~aF8 z48sHi>r>S&+m(VeJ)-|Vjeb54T|PPf{}8_)gpa}=xDv)d;{A~Lhu~iLBwP<4hWEl4 z%!Qw#+wX_>!h2vPbi*m|pXmBOgWKRr*Z@QDdUz83{&v_8H-ZoAVJ(R6eb=EHGt0KNaia2;%g)8I950=yEw!C3wsa2s3&Z-d#O;QNf_-wOxeE*OJBm<3;A z4F68}NBDcV4gLo9!ZJ7>?q(c+9CpG+I1#?hgo|&%_hjNlbcc31HOAqB`>8xziDkZ` zHh2taV#OU1!)ZZ3C#i9IuUTt2Rl8_=p6g}1(zj|nf}QKb8Wui=jG8tW<88#_^Kl4I zR_G|8mS6&(YREpycl@m_$c#}0Lj0XFlt!ccY561`<#a&1 
z-8Ax}==5xd#?^i^1}+ej(`#kZ=Xy?bN!@CLi%{#WQ37$EQa6MRXIh+Mz6L=2&|Htq}*bjSQGbnfi zi2wiR;073h%V7odLk8|a@4pFdfNS9WP=i5O3Fkou?nCcC1p8ndHp3=Z1h0ly!Vl2> z#b4ldi2MZ3;r}0@_kRbz4o|=*;E&-Q@CR@%{15v76Yynt4Bic+P=Ft!_ls@d^YAD< z0^rRa;H%gHu7y7Z7v{i!VGsDZ*aYMgJ?O?R5WT7ron~VfFjE1hJryAN($=oS z@{U$#E4+czj*vlU!cD&A3V&>=^GD*0pVVJFO&~~haEy<>@zKYz)cEKdAN{_*qwJ%P ze?;2zUe62DGp7PSZya^S*}Rf)`1W5!_x#^TT|MY$R(TIrSJlLF?J$Rhe}fquL0E;b zN-Kw%s%A)GTkj82IiL8yE}dRY3xW9f@n04=XzS!Ti$OHHyg+x02iIVJLv z4V{Q3`u3b;K$aE>R`e>!vLSxwmB|uSaCFu<3yPeLGL=u2D9^fIGCkjOO2JAO=|X}3|BO)27yX|h=HPQ1$A)1j3Olx*TT@yi0V4q;0V4q; z0V4q;0V4q;0V4q;0V9E9TLLUE9H{mkbm=feY%wOYT&NJ}|8vpt#lJtP|3Ag=zXH+!gRo<}-c9L^ z1dIfX1dIfX1dIfX1dIfX1dIfX1dIfX1llE_!cSLf&$3#xI?O~C<_QH`N(D|8u_uH$ zRLqv`l}ap#h+1F42SZgmt;%-z?rG^sCdq}_+O|+j?wZC#4_Qkh7wG?=!(V?MKK!Ep z&kRfR9D4uv;d}61xCd^6x4}AC0n6YV_&3fy03V09K@KE6SqJbUSPAEYtPk)nTLg41^hDH1+pK&RyZ5x!^`1u=J$Ubw!;`)2&>>! z__ECHhdba_ki7t|f-xw;7B~aG#oYdH!Xfww?13v`IlLC0WnTX;;12jS?1gv1Zny&S zAoKoZ&i{URAG{MTg!#|~C&RPM@&7se6rO^6;0Cx39FR2v&xAAJ+c;!A00-f2aA5 z(A@vj#Ih;Z0@SWLW(~mB-cUz*+dn+k>|Jdee+3?I?K==of49H5krsEgV0W-dr}zgs zi2c3l+Y5EKo{fK}ZkHe`PHseEL%wW|4TT6Viv?Nl9hLTn`+r;HPcqr`N3*3kn@aMO z%J}XeQ+f{PK(yfJK$To)aZHO!zzaq4E=!C?Yl4<2V6g9JmCXl>b<6RBo#L3L&8~pq z1$E*{-@imz58oUw6I_Bn_3hdCLP1_IoXZxYRrX)X;#220qK({w!*$P*jfBRF!TLLk z<5R=4^zKlt!~OYIF-Top?vzgdu$b7}xtJ_s++I|@n`-xhT%MgJn`~Bg%?d-76=wiC?7+wzL|WQKLT!DJjIy0-%Xzi9gu9{zxb zij1_|W(|Z1_9_e-Q+Xb{QktKP+scxM;z%E@y&c4=ur0QYDmXCe$_~fuKv$`tHU=vK z$UcXa`PFzyfG9B76I)UYbxcw8QcHPHtu%(Kf%l<#^5$j*BgUwi5a5POdGBfMlxh}B zh?h5?bg_iAT0Yg1iQ$&v>~)apAqI;!PQDlTKx;EawUb>ES%96{1FY@Nn%M=txrnaZ zc=iVUSa!K&R)Dl4GXMX6)WMI7PKf^RCG`I%(d{3DTVXFW;4;_@;xlj>ybAsoz5b{0 z6g&b~LN^=_uK?L!@D{ikTv!Sx!mHq)(d|D4H^a5C6D|W;_wOVSo&PcT1ndIY1MoDM z3D2P0KMe<9CzL_f`-^n{9{!&RC%_E24;}wq@J^8R|K197pbMTvzyA#EgO9>y*Z{qd zh1sCsaddpyA5hl&yA--%9z26y{}?<1cfdaQ2z&@W2(nM$Qdk7CKfrV7_V>em@NN*j zzXZeZI*>I055YKCumv{3rLYd(1iwJPzYpFH1F#fk!`G;_Y8VInP++{RTCb$m zDS^hO_j4#{t77VPUn|z!n#_)-4bN&lBNxsjrqR+?nqb*IZO1cKig2F37Q|sz7$>K? 
zk79fZs@cZWIak?H28u&YA&sJ_Il6e=4IYjZm%wcT*lfSi6fev=oZ)DJEL fk}`E*ZmJOZDRvChO8(^A0P0QOf=L6>4uJm$Z&XZV literal 0 HcmV?d00001 diff --git a/src/mlpack/methods/approx_kfn/.drusilla_select_impl.hpp.swp b/src/mlpack/methods/approx_kfn/.drusilla_select_impl.hpp.swp new file mode 100644 index 0000000000000000000000000000000000000000..9d5090f40630218c445a14d7018811412de030e6 GIT binary patch literal 20480 zcmeI3dyE}b8Nde+QC=b-CJOO%7t6iv-MibO(rmX2+bwMucH5OLg{0J**?Z^i&hEW4 zw=;9wF7-k2AH+W;@=zW@#DoYj3WBjU@~{S@{}3@T(F6zGZ(Pu?clW^>TyjJj0l&pL63Qt4547UN8)A!E|c%t{2#j zq%M0+Ewm@h`VQA^&D%lW=)l@f@7%a$)0TBNOy-MxKJ&tU->dhxTOG4L-^ahwU;S+> zm~|U|f5W$0Rz2u%cwOIawal9SU9;PrR&}=1$-le!RcN5lz;QIt4Vr!H&bj{bwQB|> zS8vY>b;0@9AE!W(RA`{kK%s#`1BC_(4HOzEG*D=u(7^wi27=CM#uup6sZqtwi0-E> zy1ys-?M3(ZFG@cx`aK@qZ(fuh{VZ;U1_})n8YnbSXrRzQp@BjJg$4=@6dEWrP-vjg zz<)pkX2USneZVj}Ap8HZ|NjquzXZ?25x5KP0LgblL_hZn;Cwg*9y`}C?uNVIRyYWI zU?+SO-Z;lF4#U0h6}TO4gK<~`7r~on8^+(^akw3(VJobJx6U$*XW(gg1a5&E?1E7! z!+CJ@Ov88zJ_`q63VPv9f)>xfA$Sn(ft%nX@Xz;=4zI(j@H6-}JOuZ_O>hZZ4DX_( z7vU**67GaKXuy?l1$+>Wyw@HRbh2Hr=;^nk1LzD{4>I@)oA7pvnmmx8;+QtorKNYqq+Ua+|8X zq9=TXQo*d}?o)oy7#y@6((Ojg^k%wk%Lz(7Q*F}=wwnIz;NX;HdiB{-c~E&4pYQ{# zq2`tCs~OKS1Iyz@rlWfJmfcjPc{QxsW#a%pm%mJ}Z4M5$&0wW!*nVI-b;~cGM2QJA z7}ClZUYRPXa@>-2*YN|TOR4#avDdX5DyiNOuiUZFv4)d+>ip5;m5EX-MqqKug$fS^ za;kn6 zB~1Gzx7nn1{c>3iZ&0O)+T<9?>#76gsQq-#y;ePazuFL$rxJT>fjgAsp?*!zZI?m; z5SQkir{SznNM-tJpc*HQnyq?QDymFpre*m`3Wj+(cjVDKRL67YI0l5oRi#Fpg<#fp zW>mf9`j)aBw>vXi(OpmLA=1RIY5LBX)39mmU30s4BkVwRAXW)7O&lYW@?EvhQVrL+ zj3%@!q@dq*M3>5(F>MEZYkm5gQZ1I|D#|YN`iy25e?0!Izg=rNDHCaZ+5AxjlKlBf z>DL|zg=)HSXOtXb|sqKRWyj(W{?OTr-rrEqq-|Y-$!FsV+vuDu`iEm2x;vv zNhK<(r(yYZ%W1?De=;eA?Vpr0Gjk4F(z)beW*9{kiThuCSGl1!w`gc9)}3%FqpQg{ zNC$n~SXRroqV`K)Dec*^9m}t>nw9K$+LOl_iS#?QurlG6JgYhCb{4`ybFyS+KEBG3 z(sE@U#&c%F4HgxyEdBjhY8flx94|5qOTA8_Um>2wBwDx9g*wx}(x!<76p0q+&-Eth z1Q~T4CqI)-I<==onB^OhI-)+D8)i<;ah;nj&!y#}X`meS4SJWZzf?XFD%)VHKUQtn zFxjM~a&N%CNSm&0?q{mC_D4BGJy&^HT9lwQ*4^iN4Hh+BX1yUzOnoQq`g}NqnEZH` 
zEOF9!*_qjv9=TrG%*-um2wU1vi<;P0i`oFQiBZ)GCF3Lc387kzw5WTAEj@mj-pLx( zryAG*6PD?OZ(ByNSTEYHOQeh+TV{P$EIu(1S@fb8iV~t6h>)#FGiE#^j9uzW0o&9;Z-#ro}7n2dp|I?^1;TGuCER`g%2A3s!Db`h^8 z+Gk$Wx{lZCrIP3?VplAYgk0k;Ypc5#Q&m)zF<-*I-=uG(QY9;JHL_vbCCpjH)he~h z+SJS{GpCCvU2-pd<%U+ftOkvx^Tkhd`CAoZNxL;R)k`{eWR^1lN5%gik75wN#-IN_JOOvZ95g|E{wkaaubpccFT){t9BzS0*bX0qU*XR` z0QbW#xD;N&k3R$**bXb;!|*nKyZH4_z^xF#jUaygCqaDs68siF{-^LoxCyRH_rtxg6V^f%-o{sd7Q~O=4;$ekcoUuc75)N;;SQJw8)jiU zYy;8VW+<*g1BC_(4g4Q!pcj|VZW53ZOL&PDU5BdT<<~EDT&&tKhKMnR;f=``CyK== z9zz(YTPkCmC>iwH#3J{woqT21gRa>k*-T7!f7WgidUM6BA;=dS#kwhu|5w*UYcc6PWgWypnv=V7 z>WG^%ZK@N!Bv(A?8Ye0x>A8|=7d!b+YY!_E?8!x}myJR%3g_v!Fmy7J^5fd_nfP-& zq~d#aki{xUT$ympSHgCRT<$o_B8igDSV^V_nKzD!FB)-q-8ok@-LBKno|I0ZJLckj zWgM2E>k$Rdha07f;zathZnxDC+um$6u=K0hC}dhid3XL75lS{H=4GU0uTUOF7haE* zPm}SKUSX6c-j9*6?;d8xj6iC}G76_rfy;wvxN*#B0fHXI(CSRc~SksCgXWZaM_OZVj7+I4)hX@y~qJ(&)g#xW8hD3%znw%L^W!QTyVRP=OpC}HYn@rvv+IL{S9j}d#}R1} zrOKkqBiX@YQ%3LLNSjiJd^}Gw>dl%^7kZy0GfbjAu)Jt|-KPtSzV)6)+R^yH7Asn# z#nnv;MmsPlrscZx%495xKzb#T`!d}=8%fu_IGb_FNL*OWjb1mf`p8TD!Y~etR2*i~ z&9SH(gbhN?LOn>JRia~z+cexJwGp#p};*XqDjFtT{l9}>!1v%D=)gvJFPsh%7kCs7!Z>V(3S0_D@%3@wgl09iq zoJ5eBEsmToXV>XGOaINSiN#3-Q4&i`aS}oImUwp~Iav^IQD*k>*=XDyuh02an*}DH zPiK=QF2BfStkCRoOhXhW5yCT9;*%675%f8;c(%j;EKVX6ClQL12u!BMNdykpq;64o xULy8YwoPUzW}JT8SMZ7 literal 0 HcmV?d00001 diff --git a/src/mlpack/methods/approx_kfn/CMakeLists.txt b/src/mlpack/methods/approx_kfn/CMakeLists.txt new file mode 100644 index 00000000000..0e907d6659f --- /dev/null +++ b/src/mlpack/methods/approx_kfn/CMakeLists.txt @@ -0,0 +1,20 @@ +# Define the files we need to compile. +# Anything not in this list will not be compiled into mlpack. +set(SOURCES + # DrusillaSelect sources. + drusilla_select.hpp + drusilla_select_impl.hpp +) + +# Add directory name to sources. 
+set(DIR_SRCS) +foreach(file ${SOURCES}) + set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) +endforeach() +# Append sources (with directory name) to list of all mlpack sources (used at +# the parent scope). +set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE) + +# The code to compute the approximate neighbor for the given query and reference +# sets with p-stable LSH. +add_cli_executable(drusilla_select) diff --git a/src/mlpack/methods/approx_kfn/drusilla_select.hpp b/src/mlpack/methods/approx_kfn/drusilla_select.hpp new file mode 100644 index 00000000000..38b90ab8a5c --- /dev/null +++ b/src/mlpack/methods/approx_kfn/drusilla_select.hpp @@ -0,0 +1,125 @@ +/** + * @file drusilla_select.hpp + * @author Ryan Curtin + * + * An implementation of the approximate furthest neighbor algorithm specified in + * the following paper: + * + * @code + * @incollection{curtin2016fast, + * title={Fast approximate furthest neighbors with data-dependent candidate + * selection}, + * author={Curtin, R.R., and Gardner, A.B.}, + * booktitle={Similarity Search and Applications}, + * pages={221--235}, + * year={2016}, + * publisher={Springer} + * } + * @endcode + * + * This algorithm, called DrusillaSelect, constructs a candidate set of points + * to query to find an approximate furthest neighbor. The strange name is a + * result of the algorithm being named after a cat. The cat in question may be + * viewed at http://www.ratml.org/misc_img/drusilla_fence.png. + */ +#ifndef MLPACK_METHODS_APPROX_KFN_DRUSILLA_SELECT_HPP +#define MLPACK_METHODS_APPROX_KFN_DRUSILLA_SELECT_HPP + +#include + +namespace mlpack { +namespace neighbor { + +template +class DrusillaSelect +{ + public: + /** + * Construct the DrusillaSelect object with the given reference set (this is + * the set that will be searched). The resulting set of candidate points that + * will be searched at query time will have size l*m. + * + * @param referenceSet Set of reference data. 
+ * @param l Number of projections. + * @param m Number of elements to store for each projection. + */ + DrusillaSelect(const MatType& referenceSet, + const size_t l, + const size_t m); + + /** + * Construct the DrusillaSelect object with no given reference set. Be sure + * to call Train() before calling Search()! + * + * @param l Number of projections. + * @param m Number of elements to store for each projection. + */ + DrusillaSelect(const size_t l, const size_t m); + + /** + * Build the set of candidate points on the given reference set. If l and m + * are left unspecified, then the values set in the constructor will be used + * instead. + * + * @param referenceSet Set to extract candidate points from. + * @param l Number of projections. + * @param m Number of elements to store for each projection. + */ + void Train(const MatType& referenceSet, + const size_t l = 0, + const size_t m = 0); + + /** + * Search for the k furthest neighbors of the given query set. (The query set + * can contain just one point: that is okay.) The results will be stored in + * the given neighbors and distances matrices, in the same format as the + * NeighborSearch and LSHSearch classes. That is, each column in the + * neighbors and distances matrices will refer to a single query point, and + * the k'th row in that column will refer to the k'th candidate neighbor or + * distance for that query point. + * + * @param querySet Set of query points to search. + * @param k Number of furthest neighbors to search for. + * @param neighbors Matrix to store resulting neighbors in. + * @param distances Matrix to store resulting distances in. + */ + void Search(const MatType& querySet, + const size_t k, + arma::Mat& neighbors, + arma::mat& distances); + + /** + * Serialize the model. + */ + template + void Serialize(Archive& ar, const unsigned int /* version */); + + //! Access the candidate set. + const MatType& CandidateSet() const { return candidateSet; } + //! Modify the candidate set. 
Be careful! + MatType& CandidateSet() { return candidateSet; } + + //! Access the indices of points in the candidate set. + const arma::Col& CandidateIndices() const { return candidateIndices; } + //! Modify the indices of points in the candidate set. Be careful! + arma::Col& CandidateIndices() { return candidateIndices; } + + private: + //! The reference set. + MatType candidateSet; + //! Indices of each point in the reference set. + arma::Col candidateIndices; + + //! The number of projections. + size_t l; + //! The number of points in each projection. + size_t m; +}; + +} // namespace neighbor +} // namespace mlpack + +// Include implementation. +#include "drusilla_select_impl.hpp" + +#endif diff --git a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp new file mode 100644 index 00000000000..a84b30467bf --- /dev/null +++ b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp @@ -0,0 +1,210 @@ +/** + * @file drusilla_select_impl.hpp + * @author Ryan Curtin + * + * Implementation of DrusillaSelect class methods. + */ +#ifndef MLPACK_METHODS_APPROX_KFN_DRUSILLA_SELECT_IMPL_HPP +#define MLPACK_METHODS_APPROX_KFN_DRUSILLA_SELECT_IMPL_HPP + +// In case it hasn't been included yet. +#include "drusilla_select.hpp" + +#include +#include +#include +#include +#include + +namespace mlpack { +namespace neighbor { + +// Constructor. +template +DrusillaSelect::DrusillaSelect(const MatType& referenceSet, + const size_t l, + const size_t m) : + candidateSet(referenceSet.n_rows, l * m), + l(l), + m(m) +{ + if (l == 0) + throw std::invalid_argument("DrusillaSelect::DrusillaSelect(): invalid " + "value of l; must be greater than 0!"); + else if (m == 0) + throw std::invalid_argument("DrusillaSelect::DrusillaSelect(): invalid " + "value of m; must be greater than 0!"); + + Train(referenceSet, l, m); +} + +// Constructor with no training. 
+template +DrusillaSelect::DrusillaSelect(const size_t l, const size_t m) : + l(l), + m(m) +{ + if (l == 0) + throw std::invalid_argument("DrusillaSelect::DrusillaSelect(): invalid " + "value of l; must be greater than 0!"); + else if (m == 0) + throw std::invalid_argument("DrusillaSelect::DrusillaSelect(): invalid " + "value of m; must be greater than 0!"); +} + +// Train the model. +template +void DrusillaSelect::Train( + const MatType& referenceSet, + const size_t lIn, + const size_t mIn) +{ + // Did the user specify a new size? If so, use it. + if (lIn > 0) + l = lIn; + if (mIn > 0) + m = mIn; + + if ((l * m) > referenceSet.n_cols) + throw std::invalid_argument("DrusillaSelect::Train(): l and m are too " + "large! Choose smaller values. l*m must be smaller than the number " + "of points in the dataset."); + + arma::vec dataMean = arma::mean(referenceSet, 1); + arma::vec norms(referenceSet.n_cols); + + arma::mat refCopy = referenceSet.each_col() - dataMean; + for (size_t i = 0; i < refCopy.n_cols; ++i) + norms[i] = arma::norm(refCopy.col(i) - dataMean); + + // Find the top m points for each of the l projections... + for (size_t i = 0; i < l; ++i) + { + // Pick best index. + arma::uword maxIndex; + norms.max(maxIndex); + + arma::vec line = refCopy.col(maxIndex) / arma::norm(refCopy.col(maxIndex)); + const size_t n_nonzero = (size_t) arma::sum(norms > 0); + + // Calculate distortion and offset. 
+ arma::vec distortions(referenceSet.n_cols); + arma::vec offsets(referenceSet.n_cols); + for (size_t j = 0; j < referenceSet.n_cols; ++j) + { + if (norms[j] > 0.0) + { + offsets[j] = arma::dot(refCopy.col(j), line); + distortions[j] = arma::norm(refCopy.col(j) - offsets[j] * + line); + } + else + { + offsets[j] = 0.0; + distortions[j] = 0.0; + } + } + arma::vec sums = arma::abs(offsets) - arma::abs(distortions); + arma::uvec sortedSums = arma::sort_index(sums, "descend"); + + arma::vec bestSums(m); + arma::Col bestIndices(m); + bestSums.fill(-DBL_MAX); + + // Find the top m elements using a priority queue. + typedef std::pair Candidate; + struct CandidateCmp + { + bool operator()(const Candidate& c1, const Candidate& c2) + { + return c2.first > c1.first; + } + }; + + std::vector clist(m, std::make_pair(size_t(-1), double(0.0))); + std::priority_queue, CandidateCmp> + pq(CandidateCmp(), std::move(clist)); + + for (size_t j = 0; j < sums.n_elem; ++j) + { + Candidate c = std::make_pair(sums[j], j); + if (CandidateCmp()(c, pq.top())) + { + pq.pop(); + pq.push(c); + } + } + + // Take the top m elements for this table. + for (size_t j = 0; j < m; ++j) + { + const size_t index = pq.top().second; + pq.pop(); + candidateSet.col(i * m + j) = referenceSet.col(index); + + // Mark the norm as 0 so we don't see this point again. + norms[index] = 0.0; + } + + // Calculate angles from the current projection. Anything close enough, + // mark the norm as 0. + arma::vec farPoints = arma::conv_to::from( + arma::atan(distortions / arma::abs(offsets)) >= (M_PI / 8.0)); + norms %= farPoints; + } +} + +// Search. +template +void DrusillaSelect::Search(const MatType& querySet, + const size_t k, + arma::Mat& neighbors, + arma::mat& distances) +{ + if (candidateSet.n_cols == 0) + throw std::runtime_error("DrusillaSelect::Search(): candidate set not " + "initialized! 
Call Train() first."); + + if (k > (l * m)) + throw std::invalid_argument("DrusillaSelect::Search(): requested k is " + "greater than number of points in candidate set! Increase l or m."); + + // We'll use the NeighborSearchRules class to perform our brute-force search. + // Note that we aren't using trees for our search, so we can use 'int' as a + // TreeType. + metric::EuclideanDistance metric; + NeighborSearchRules> + rules(querySet, candidateSet, k, metric, 0, false); + + neighbors.set_size(k, querySet.n_cols); + neighbors.fill(size_t() - 1); + distances.zeros(k, querySet.n_cols); + + for (size_t q = 0; q < querySet.n_cols; ++q) + for (size_t r = 0; r < candidateSet.n_cols; ++r) + rules.BaseCase(q, r); + + // Map the neighbors back to their original indices in the reference set. + for (size_t i = 0; i < neighbors.n_elem; ++i) + neighbors[i] = candidateIndices[neighbors[i]]; +} + +//! Serialize the model. +template +template +void DrusillaSelect::Serialize(Archive& ar, + const unsigned int /* version */) +{ + using data::CreateNVP; + + ar & CreateNVP(candidateSet, "candidateSet"); + ar & CreateNVP(candidateIndices, "candidateIndices"); + ar & CreateNVP(l, "l"); + ar & CreateNVP(m, "m"); +} + +} // namespace neighbor +} // namespace mlpack + +#endif diff --git a/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp b/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp new file mode 100644 index 00000000000..9e55ec721d7 --- /dev/null +++ b/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp @@ -0,0 +1,100 @@ +/** + * @file smarthash_main.cpp + * @author Ryan Curtin + * + * Command-line program for the SmartHash algorithm. + */ +#include +#include "smarthash_fn.hpp" +#include + +using namespace smarthash; +using namespace mlpack; +using namespace std; + +PROGRAM_INFO("Query-dependent approximate furthest neighbor search", + "This program implements the algorithm from the SISAP 2015 paper titled " + "'Approximate Furthest Neighbor in High Dimensions' by R. 
Pagh, F. " + "Silvestri, J. Sivertsen, and M. Skala. Specify a reference set (set to " + "search in) with --reference_file, specify a query set (set to search for) " + "with --query_file, and specify algorithm parameters with --num_tables and " + "--num_projections (or don't, and defaults will be used). Also specify " + "the number of points to search for with --k. Each of those options has " + "short names too; see the detailed parameter documentation below." + "\n\n" + "Results for each query point are stored in the files specified by " + "--neighbors_file and --distances_file. This is in the same format as the " + "mlpack KFN and KNN programs: each row holds the k distances or neighbor " + "indices for each query point."); + +PARAM_STRING_REQ("reference_file", "File containing reference points.", "r"); +PARAM_STRING_REQ("query_file", "File containing query points.", "q"); + +PARAM_INT_REQ("k", "Number of furthest neighbors to search for.", "k"); + +PARAM_INT("num_tables", "Number of hash tables to use.", "t", 10); +PARAM_INT("num_projections", "Number of projections to use in each hash table.", + "p", 30); + +PARAM_STRING("neighbors_file", "File to save furthest neighbor indices to.", + "n", ""); +PARAM_STRING("distances_file", "File to save furthest neighbor distances to.", + "d", ""); + +PARAM_FLAG("calculate_error", "If set, calculate the average distance error.", + "e"); +PARAM_STRING("exact_distances_file", "File containing exact distances", "x", ""); + +int main(int argc, char** argv) +{ + CLI::ParseCommandLine(argc, argv); + + const string referenceFile = CLI::GetParam("reference_file"); + const string queryFile = CLI::GetParam("query_file"); + const size_t k = (size_t) CLI::GetParam("k"); + const size_t numTables = (size_t) CLI::GetParam("num_tables"); + const size_t numProjections = (size_t) CLI::GetParam("num_projections"); + + // Load the data. 
+ arma::mat referenceData, queryData; + data::Load(referenceFile, referenceData, true); + data::Load(queryFile, queryData, true); + + // Construct the object. + Timer::Start("smarthash_construct"); + SmartHash<> q(referenceData, numTables, numProjections); + Timer::Stop("smarthash_construct"); + + // Do the search. + arma::Mat neighbors; + arma::mat distances; + Timer::Start("smarthash_search"); + q.Search(queryData, k, neighbors, distances); + Timer::Stop("smarthash_search"); + + if (CLI::HasParam("calculate_error")) + { +// neighbor::AllkFN kfn(referenceData); + +// arma::Mat trueNeighbors; + arma::mat trueDistances; + data::Load(CLI::GetParam("exact_distances_file"), trueDistances); + +// kfn.Search(queryData, 1, trueNeighbors, trueDistances); + + const double averageError = arma::sum(trueDistances / distances.row(0)) / + distances.n_cols; + const double minError = arma::min(trueDistances / distances.row(0)); + const double maxError = arma::max(trueDistances / distances.row(0)); + + Log::Info << "Average error: " << averageError << "." << endl; + Log::Info << "Maximum error: " << maxError << "." << endl; + Log::Info << "Minimum error: " << minError << "." << endl; + } + + // Save the results. 
+ if (CLI::HasParam("neighbors_file")) + data::Save(CLI::GetParam("neighbors_file"), neighbors); + if (CLI::HasParam("distances_file")) + data::Save(CLI::GetParam("distances_file"), distances); +} diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt index 9ad40927c8f..a93f7bf9e8b 100644 --- a/src/mlpack/tests/CMakeLists.txt +++ b/src/mlpack/tests/CMakeLists.txt @@ -18,6 +18,7 @@ add_executable(mlpack_test decision_stump_test.cpp det_test.cpp distribution_test.cpp + drusilla_select_test.cpp emst_test.cpp fastmks_test.cpp feedforward_network_test.cpp diff --git a/src/mlpack/tests/drusilla_select_test.cpp b/src/mlpack/tests/drusilla_select_test.cpp new file mode 100644 index 00000000000..504fd6247ed --- /dev/null +++ b/src/mlpack/tests/drusilla_select_test.cpp @@ -0,0 +1,145 @@ +/** + * @file drusilla_select_test.cpp + * @author Ryan Curtin + * + * Test for DrusillaSelect. + */ +#include +#include + +#include +#include "test_tools.hpp" +#include "serialization.hpp" + +using namespace mlpack; +using namespace mlpack::neighbor; + +BOOST_AUTO_TEST_SUITE(DrusillaSelectTest); + +// If we have a dataset with an extreme outlier, then every point (except that +// one) should end up with that point as the furthest neighbor candidate. +BOOST_AUTO_TEST_CASE(DrusillaSelectExtremeOutlierTest) +{ + arma::mat dataset = arma::randu(5, 100); + dataset.col(100) += 100; // Make last column very large. + + // Construct with some reasonable parameters. + DrusillaSelect<> ds(dataset, 5, 5); + + // Query with every point except the extreme point. 
+ arma::mat distances; + arma::Mat neighbors; + ds.Search(dataset.cols(0, 99), 1, neighbors, distances); + + BOOST_REQUIRE_EQUAL(neighbors.n_cols, 99); + BOOST_REQUIRE_EQUAL(neighbors.n_rows, 1); + BOOST_REQUIRE_EQUAL(distances.n_cols, 99); + BOOST_REQUIRE_EQUAL(distances.n_rows, 1); + + for (size_t i = 0; i < 99; ++i) + BOOST_REQUIRE_EQUAL(neighbors[i], 100); +} + +// If we use only one projection with the number of points equal to what is in +// the dataset, we should end up with the exact result. +BOOST_AUTO_TEST_CASE(DrusillaSelectExhaustiveExactTest) +{ + arma::mat dataset = arma::randu(5, 100); + + // Construct with one projection and 100 points in that projection. + DrusillaSelect<> ds(dataset, 100, 1); + + arma::mat distances, distancesTrue; + arma::Mat neighbors, neighborsTrue; + + ds.Search(dataset, 5, neighbors, distances); + + AllkFN kfn(dataset); + kfn.Search(dataset, 5, neighborsTrue, distancesTrue); + + BOOST_REQUIRE_EQUAL(neighborsTrue.n_cols, neighbors.n_cols); + BOOST_REQUIRE_EQUAL(neighborsTrue.n_rows, neighbors.n_rows); + BOOST_REQUIRE_EQUAL(distancesTrue.n_cols, distances.n_cols); + BOOST_REQUIRE_EQUAL(distancesTrue.n_rows, distances.n_rows); + + for (size_t i = 0; i < distances.n_elem; ++i) + { + BOOST_REQUIRE_EQUAL(neighbors[i], neighborsTrue[i]); + BOOST_REQUIRE_CLOSE(distances[i], distancesTrue[i], 1e-5); + } +} + +// Test that we can call Train() after calling the constructor. +BOOST_AUTO_TEST_CASE(RetrainTest) +{ + arma::mat firstDataset = arma::randu(3, 10); + arma::mat dataset = arma::randu(3, 200); + + DrusillaSelect<> ds(firstDataset, 3, 3); + ds.Train(std::move(dataset), 2, 2); + + arma::mat distances; + arma::Mat neighbors; + ds.Search(dataset, 1, neighbors, distances); + + BOOST_REQUIRE_EQUAL(dataset.n_elem, 0); + BOOST_REQUIRE_EQUAL(neighbors.n_cols, 200); + BOOST_REQUIRE_EQUAL(neighbors.n_rows, 1); + BOOST_REQUIRE_EQUAL(distances.n_cols, 200); + BOOST_REQUIRE_EQUAL(distances.n_rows, 1); +} + +// Test serialization. 
+BOOST_AUTO_TEST_CASE(SerializationTest) +{ + // Create a random dataset. + arma::mat dataset = arma::randu(3, 100); + + DrusillaSelect<> ds(dataset, 3, 3); + + arma::mat fakeDataset1 = arma::randu(2, 5); + arma::mat fakeDataset2 = arma::randu(10, 8); + DrusillaSelect<> dsXml(fakeDataset1, 10, 10); + DrusillaSelect<> dsText(2, 2); + DrusillaSelect<> dsBinary(5, 6); + dsBinary.Train(fakeDataset2); + + // Now do the serialization. + SerializeObjectAll(ds, dsXml, dsText, dsBinary); + + // Now do a search and make sure all the results are the same. + arma::Mat neighbors, neighborsXml, neighborsText, neighborsBinary; + arma::mat distances, distancesXml, distancesText, distancesBinary; + + ds.Search(dataset, 3, neighbors, distances); + dsXml.Search(dataset, 3, neighborsXml, distancesXml); + dsText.Search(dataset, 3, neighborsText, distancesText); + dsBinary.Search(dataset, 3, neighborsBinary, distancesBinary); + + BOOST_REQUIRE_EQUAL(neighbors.n_rows, neighborsXml.n_rows); + BOOST_REQUIRE_EQUAL(neighbors.n_cols, neighborsXml.n_cols); + BOOST_REQUIRE_EQUAL(neighbors.n_rows, neighborsText.n_rows); + BOOST_REQUIRE_EQUAL(neighbors.n_cols, neighborsText.n_cols); + BOOST_REQUIRE_EQUAL(neighbors.n_rows, neighborsBinary.n_rows); + BOOST_REQUIRE_EQUAL(neighbors.n_cols, neighborsBinary.n_cols); + + BOOST_REQUIRE_EQUAL(distances.n_rows, distancesXml.n_rows); + BOOST_REQUIRE_EQUAL(distances.n_cols, distancesXml.n_cols); + BOOST_REQUIRE_EQUAL(distances.n_rows, distancesText.n_rows); + BOOST_REQUIRE_EQUAL(distances.n_cols, distancesText.n_cols); + BOOST_REQUIRE_EQUAL(distances.n_rows, distancesBinary.n_rows); + BOOST_REQUIRE_EQUAL(distances.n_cols, distancesBinary.n_cols); + + for (size_t i = 0; i < neighbors.n_elem; ++i) + { + BOOST_REQUIRE_EQUAL(neighbors[i], neighborsXml[i]); + BOOST_REQUIRE_EQUAL(neighbors[i], neighborsText[i]); + BOOST_REQUIRE_EQUAL(neighbors[i], neighborsBinary[i]); + + BOOST_REQUIRE_CLOSE(distances[i], distancesXml[i], 1e-5); + BOOST_REQUIRE_CLOSE(distances[i], 
distancesText[i], 1e-5); + BOOST_REQUIRE_CLOSE(distances[i], distancesBinary[i], 1e-5); + } +} + +BOOST_AUTO_TEST_SUITE_END(); From c671a4ada5c1f0854f267b4d59f555314b0ed74f Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Mon, 24 Oct 2016 03:45:00 -0400 Subject: [PATCH 14/30] Move things into the right place after subtree merge. --- src/mlpack/methods/approx_kfn/CMakeLists.txt | 3 + .../methods/approx_kfn/{qdafn => }/qdafn.hpp | 0 .../methods/approx_kfn/qdafn/CMakeLists.txt | 254 ------------------ .../methods/approx_kfn/qdafn/CXX11.cmake | 48 ---- .../methods/approx_kfn/qdafn/README.txt | 16 -- .../approx_kfn/{qdafn => }/qdafn_impl.hpp | 0 .../approx_kfn/{qdafn => }/qdafn_main.cpp | 0 .../approx_kfn/qdafn => tests}/qdafn_test.cpp | 0 8 files changed, 3 insertions(+), 318 deletions(-) rename src/mlpack/methods/approx_kfn/{qdafn => }/qdafn.hpp (100%) delete mode 100644 src/mlpack/methods/approx_kfn/qdafn/CMakeLists.txt delete mode 100644 src/mlpack/methods/approx_kfn/qdafn/CXX11.cmake delete mode 100644 src/mlpack/methods/approx_kfn/qdafn/README.txt rename src/mlpack/methods/approx_kfn/{qdafn => }/qdafn_impl.hpp (100%) rename src/mlpack/methods/approx_kfn/{qdafn => }/qdafn_main.cpp (100%) rename src/mlpack/{methods/approx_kfn/qdafn => tests}/qdafn_test.cpp (100%) diff --git a/src/mlpack/methods/approx_kfn/CMakeLists.txt b/src/mlpack/methods/approx_kfn/CMakeLists.txt index 0e907d6659f..fa7846211f1 100644 --- a/src/mlpack/methods/approx_kfn/CMakeLists.txt +++ b/src/mlpack/methods/approx_kfn/CMakeLists.txt @@ -4,6 +4,9 @@ set(SOURCES # DrusillaSelect sources. drusilla_select.hpp drusilla_select_impl.hpp + # QDAFN sources. + qdafn.hpp + qdafn_impl.hpp ) # Add directory name to sources. 
diff --git a/src/mlpack/methods/approx_kfn/qdafn/qdafn.hpp b/src/mlpack/methods/approx_kfn/qdafn.hpp similarity index 100% rename from src/mlpack/methods/approx_kfn/qdafn/qdafn.hpp rename to src/mlpack/methods/approx_kfn/qdafn.hpp diff --git a/src/mlpack/methods/approx_kfn/qdafn/CMakeLists.txt b/src/mlpack/methods/approx_kfn/qdafn/CMakeLists.txt deleted file mode 100644 index d01fca78d34..00000000000 --- a/src/mlpack/methods/approx_kfn/qdafn/CMakeLists.txt +++ /dev/null @@ -1,254 +0,0 @@ -# Much of this is borrowed from mlpack's CMakeLists.txt. -cmake_minimum_required(VERSION 2.8.5) -project(qdafn C CXX) - -# Ensure that we have a C++11 compiler. -include(CXX11.cmake) -check_for_cxx11_compiler(HAS_CXX11) -if(NOT HAS_CXX11) - message(FATAL_ERROR "No C++11 compiler available!") -endif() -enable_cxx11() - -# Define compilation options. -option(DEBUG "Compile with debugging information" ON) -option(PROFILE "Compile with profiling information" ON) - -# Set the CFLAGS and CXXFLAGS depending on the options the user specified. -# Only GCC-like compilers support -Wextra, and other compilers give tons of -# output for -Wall, so only -Wall and -Wextra on GCC. -if(CMAKE_COMPILER_IS_GNUCC OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -ftemplate-depth=1000") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -Wextra") -endif() - -# If using clang, we have to link against libc++ depending on the -# OS (at least on some systems). Further, gcc sometimes optimizes calls to -# math.h functions, making -lm unnecessary with gcc, but it may still be -# necessary with clang. -if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") - if (APPLE) - # detect OS X version. Use '/usr/bin/sw_vers -productVersion' to - # extract V from '10.V.x'.) 
- exec_program(/usr/bin/sw_vers ARGS - -productVersion OUTPUT_VARIABLE MACOSX_VERSION_RAW) - string(REGEX REPLACE - "10\\.([0-9]+).*" "\\1" - MACOSX_VERSION - "${MACOSX_VERSION_RAW}") - - # OSX Lion (10.7) and OS X Mountain Lion (10.8) doesn't automatically - # select the right stdlib. - if(${MACOSX_VERSION} LESS 9) - set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++") - set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} --stdlib=libc++") - set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} --stdlib=libc++") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++") - endif() - endif() - - set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -lm") - set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -lm") - set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} -lm") -endif() - -# Debugging CFLAGS. Turn optimizations off; turn debugging symbols on. -if(DEBUG) - add_definitions(-DDEBUG) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0 -ftemplate-backtrace-limit=0") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=c99 -g -O0") - # mlpack uses it's own mlpack::backtrace class based on Binary File Descriptor - # and linux Dynamic Loader and more portable version in - # future - if(CMAKE_SYSTEM_NAME STREQUAL "Linux") - find_package(Bfd) - find_package(LibDL) - if(LIBBFD_FOUND AND LIBDL_FOUND) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic") - include_directories(${LIBBFD_INCLUDE_DIRS}) - include_directories(${LIBDL_INCLUDE_DIRS}) - add_definitions(-DHAS_BFD_DL) - else() - message(WARNING "No libBFD and/or libDL has been found!") - endif() - endif() -else() - add_definitions(-DARMA_NO_DEBUG) - add_definitions(-DNDEBUG) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=c99 -O3") -endif() - -# Profiling CFLAGS. Turn profiling information on. 
-if(CMAKE_COMPILER_IS_GNUCC AND PROFILE) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pg") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pg") - set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -pg") -endif() - -# Find dependencies. This isn't very robust. -find_package(Armadillo 4.100.0 REQUIRED) - -# If Armadillo was compiled without ARMA_64BIT_WORD and we are on a 64-bit -# system (where size_t will be 64 bits), suggest to the user that they should -# compile Armadillo with 64-bit words. Note that with Armadillo 5.000.0 and -# newer, ARMA_64BIT_WORD is enabled by default. -if(CMAKE_SIZEOF_VOID_P EQUAL 8) - # Check the version, to see if ARMA_64BIT_WORD is enabled by default. - set(ARMA_HAS_64BIT_WORD 0) - if(NOT (${ARMADILLO_VERSION_MAJOR} LESS 5)) - set(ARMA_HAS_64BIT_WORD 1) - else() - # Can we open the configuration file? If not, issue a warning. - if(NOT EXISTS "${ARMADILLO_INCLUDE_DIR}/armadillo_bits/config.hpp") - message(WARNING "Armadillo configuration file " - "(${ARMADILLO_INCLUDE_DIR}/armadillo_bits/config.hpp) does not -exist!") - else() - # We are on a 64-bit system. Does Armadillo have ARMA_64BIT_WORD enabled? - file(READ "${ARMADILLO_INCLUDE_DIR}/armadillo_bits/config.hpp" -ARMA_CONFIG) - string(REGEX MATCH - "[\r\n][ ]*#define ARMA_64BIT_WORD" - ARMA_HAS_64BIT_WORD_PRE - "${ARMA_CONFIG}") - - string(LENGTH "${ARMA_HAS_64BIT_WORD_PRE}" ARMA_HAS_64BIT_WORD) - endif() - endif() - - if(ARMA_HAS_64BIT_WORD EQUAL 0) - message(WARNING "This is a 64-bit system, but Armadillo was compiled " - "without 64-bit index support. Consider recompiling Armadillo with " - "ARMA_64BIT_WORD to enable 64-bit indices (large matrix support). 
" - "mlpack will still work without ARMA_64BIT_WORD defined, but will not " - "scale to matrices with more than 4 billion elements.") - endif() -else() - # If we are on a 32-bit system, we must manually specify the size of the word - # to be 32 bits, since otherwise Armadillo will produce a warning that it is - # disabling 64-bit support. - if (CMAKE_SIZEOF_VOID_P EQUAL 4) - add_definitions(-DARMA_32BIT_WORD) - endif () -endif() - - -# On Windows, Armadillo should be using LAPACK and BLAS but we still need to -# link against it. We don't want to use the FindLAPACK or FindBLAS modules -# because then we are required to have a FORTRAN compiler (argh!) so we will try -# and find LAPACK and BLAS ourselves, using a slightly modified variant of the -# script Armadillo uses to find these. -if (WIN32) - find_library(LAPACK_LIBRARY - NAMES lapack liblapack lapack_win32_MT lapack_win32 - PATHS "C:/Program Files/Armadillo" - PATH_SUFFIXES "examples/lib_win32/") - - if (NOT LAPACK_LIBRARY) - message(FATAL_ERROR "Cannot find LAPACK library (.lib)!") - endif () - - find_library(BLAS_LIBRARY - NAMES blas libblas blas_win32_MT blas_win32 - PATHS "C:/Program Files/Armadillo" - PATH_SUFFIXES "examples/lib_win32/") - - if (NOT BLAS_LIBRARY) - message(FATAL_ERROR "Cannot find BLAS library (.lib)!") - endif () - - # Piggyback LAPACK and BLAS linking into Armadillo link. - set(ARMADILLO_LIBRARIES - ${ARMADILLO_LIBRARIES} ${BLAS_LIBRARY} ${LAPACK_LIBRARY}) -endif () - -# Include directories for the previous dependencies. -include_directories(${ARMADILLO_INCLUDE_DIRS}) - -# Unfortunately this configuration variable is necessary and will need to be -# updated as time goes on and new versions are released. 
-set(Boost_ADDITIONAL_VERSIONS - "1.49.0" "1.50.0" "1.51.0" "1.52.0" "1.53.0" "1.54.0" "1.55.0") -find_package(Boost 1.49 - COMPONENTS - program_options - unit_test_framework - serialization - REQUIRED -) -include_directories(${Boost_INCLUDE_DIRS}) - -link_directories(${Boost_LIBRARY_DIRS}) - -# In Visual Studio, automatic linking is performed, so we don't need to worry -# about it. Clear the list of libraries to link against and let Visual Studio -# handle it. -if (MSVC) - link_directories(${Boost_LIBRARY_DIRS}) - set(Boost_LIBRARIES "") -endif () - -# For Boost testing framework (will have no effect on non-testing executables). -# This specifies to Boost that we are dynamically linking to the Boost test -# library. -add_definitions(-DBOOST_TEST_DYN_LINK) - -# On Windows, things end up under Debug/ or Release/. -if (WIN32) - set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) - set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) - set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) -else () - # If not on Windows, put them under more standard UNIX-like places. This is - # necessary, otherwise they would all end up in - # ${CMAKE_BINARY_DIR}/src/mlpack/methods/... or somewhere else random like - # that. - set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib/) - set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin/) - set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib/) -endif () - -# Find the mlpack library and include directory. 
-set(MLPACK_LIBRARY "") -set(MLPACK_INCLUDE_DIR "") -find_library(MLPACK_LIBRARY - NAMES mlpack - PATHS /usr/lib64 /usr/lib /usr/local/lib64 /usr/local/lib -) - -find_path(MLPACK_INCLUDE_DIR mlpack/core.hpp - /usr/include/ - /usr/local/include/ -) - -if (MLPACK_LIBRARY AND MLPACK_INCLUDE_DIR) - mark_as_advanced(MLPACK_LIBRARY MLPACK_INCLUDE_DIR) - include_directories(${MLPACK_INCLUDE_DIR}) -else () - message(FATAL_ERROR "Could not find mlpack; try specifying MLPACK_LIBRARY and" - " MLPACK_INCLUDE_DIR") -endif () - -# Finally! Definitions of the files we are building. -add_executable(qdafn - qdafn_main.cpp - qdafn.hpp - qdafn_impl.hpp -) -target_link_libraries(qdafn - ${MLPACK_LIBRARY} - ${Boost_LIBRARIES} - ${ARMADILLO_LIBRARIES} -) - -add_executable(qdafn_test - qdafn_test.cpp -) -target_link_libraries(qdafn_test - ${MLPACK_LIBRARY} - ${Boost_LIBRARIES} - ${ARMADILLO_LIBRARIES} -) diff --git a/src/mlpack/methods/approx_kfn/qdafn/CXX11.cmake b/src/mlpack/methods/approx_kfn/qdafn/CXX11.cmake deleted file mode 100644 index 2dbfcc4b7bb..00000000000 --- a/src/mlpack/methods/approx_kfn/qdafn/CXX11.cmake +++ /dev/null @@ -1,48 +0,0 @@ -# This is cloned from -# https://github.com/nitroshare/CXX11-CMake-Macros -# until C++11 support finally hits CMake stable (should be 3.1, I think). - -# Copyright (c) 2013 Nathan Osman - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. 
- -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -# Determines whether or not the compiler supports C++11 -macro(check_for_cxx11_compiler _VAR) - message(STATUS "Checking for C++11 compiler") - set(${_VAR}) - if((MSVC AND (MSVC10 OR MSVC11 OR MSVC12 OR MSVC14)) OR - (CMAKE_COMPILER_IS_GNUCXX AND NOT ${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.6) OR - (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND NOT ${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.1) OR - (CMAKE_CXX_COMPILER_ID STREQUAL "Intel" AND NOT ${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 12.0)) - set(${_VAR} 1) - message(STATUS "Checking for C++11 compiler - available") - else() - message(STATUS "Checking for C++11 compiler - unavailable") - endif() -endmacro() - -# Sets the appropriate flag to enable C++11 support -macro(enable_cxx11) - if(CMAKE_COMPILER_IS_GNUCXX OR - CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR - CMAKE_CXX_COMPILER_ID STREQUAL "Intel") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++0x") - endif() -endmacro() - diff --git a/src/mlpack/methods/approx_kfn/qdafn/README.txt b/src/mlpack/methods/approx_kfn/qdafn/README.txt deleted file mode 100644 index eb1bbf8ab2b..00000000000 --- a/src/mlpack/methods/approx_kfn/qdafn/README.txt +++ /dev/null @@ -1,16 +0,0 @@ -This repository contains an implementation of the hashing algorithm for -approximate furthest neighbor search detailed in the paper - -"Approximate Furthest Neighbor in High Dimensions" -by Rasmus Pagh, Francesco Silverstri, Johan Siversten, and Matthew Skala -presented at SISAP 2015. 
- -There is another implementation available here: -https://github.com/johanvts/FN-Implementations - -but I wanted to re-implement this to ensure that I understood it correctly, and -so that I could get a better comparison. - -This code is built using mlpack and Armadillo, so when you configure with CMake -you may have to specify the installation directory of mlpack and Armadillo, if -they are not already installed on the system. diff --git a/src/mlpack/methods/approx_kfn/qdafn/qdafn_impl.hpp b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp similarity index 100% rename from src/mlpack/methods/approx_kfn/qdafn/qdafn_impl.hpp rename to src/mlpack/methods/approx_kfn/qdafn_impl.hpp diff --git a/src/mlpack/methods/approx_kfn/qdafn/qdafn_main.cpp b/src/mlpack/methods/approx_kfn/qdafn_main.cpp similarity index 100% rename from src/mlpack/methods/approx_kfn/qdafn/qdafn_main.cpp rename to src/mlpack/methods/approx_kfn/qdafn_main.cpp diff --git a/src/mlpack/methods/approx_kfn/qdafn/qdafn_test.cpp b/src/mlpack/tests/qdafn_test.cpp similarity index 100% rename from src/mlpack/methods/approx_kfn/qdafn/qdafn_test.cpp rename to src/mlpack/tests/qdafn_test.cpp From 84bed62e1a30dbc7b2cea5db7333e78c980d4529 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Mon, 24 Oct 2016 05:08:49 -0400 Subject: [PATCH 15/30] Fix failing tests and bugs. 
--- .../approx_kfn/drusilla_select_impl.hpp | 54 +++++++++---------- src/mlpack/tests/drusilla_select_test.cpp | 17 +++--- 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp index a84b30467bf..f264e64a779 100644 --- a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +++ b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp @@ -24,7 +24,8 @@ template DrusillaSelect::DrusillaSelect(const MatType& referenceSet, const size_t l, const size_t m) : - candidateSet(referenceSet.n_rows, l * m), + candidateSet(referenceSet.n_cols, l * m), + candidateIndices(l * m), l(l), m(m) { @@ -41,6 +42,8 @@ DrusillaSelect::DrusillaSelect(const MatType& referenceSet, // Constructor with no training. template DrusillaSelect::DrusillaSelect(const size_t l, const size_t m) : + candidateSet(0, l * m), + candidateIndices(l * m), l(l), m(m) { @@ -70,6 +73,9 @@ void DrusillaSelect::Train( "large! Choose smaller values. l*m must be smaller than the number " "of points in the dataset."); + candidateSet.set_size(referenceSet.n_rows, l * m); + candidateIndices.set_size(l * m); + arma::vec dataMean = arma::mean(referenceSet, 1); arma::vec norms(referenceSet.n_cols); @@ -87,29 +93,24 @@ void DrusillaSelect::Train( arma::vec line = refCopy.col(maxIndex) / arma::norm(refCopy.col(maxIndex)); const size_t n_nonzero = (size_t) arma::sum(norms > 0); - // Calculate distortion and offset. - arma::vec distortions(referenceSet.n_cols); - arma::vec offsets(referenceSet.n_cols); + // Calculate distortion and offset and make scores. 
+ std::vector closeAngle(referenceSet.n_cols, false); + arma::vec sums(referenceSet.n_cols); for (size_t j = 0; j < referenceSet.n_cols; ++j) { if (norms[j] > 0.0) { - offsets[j] = arma::dot(refCopy.col(j), line); - distortions[j] = arma::norm(refCopy.col(j) - offsets[j] * - line); + const double offset = arma::dot(refCopy.col(j), line); + const double distortion = arma::norm(refCopy.col(j) - offset * line); + sums[j] = std::abs(offset) - std::abs(distortion); + closeAngle[j] = + (std::atan(distortion / std::abs(offset)) >= (M_PI / 8.0)); } else { - offsets[j] = 0.0; - distortions[j] = 0.0; + sums[j] = norms[j]; } } - arma::vec sums = arma::abs(offsets) - arma::abs(distortions); - arma::uvec sortedSums = arma::sort_index(sums, "descend"); - - arma::vec bestSums(m); - arma::Col bestIndices(m); - bestSums.fill(-DBL_MAX); // Find the top m elements using a priority queue. typedef std::pair Candidate; @@ -117,11 +118,11 @@ void DrusillaSelect::Train( { bool operator()(const Candidate& c1, const Candidate& c2) { - return c2.first > c1.first; + return c2.first < c1.first; } }; - std::vector clist(m, std::make_pair(size_t(-1), double(0.0))); + std::vector clist(m, std::make_pair(double(-1.0), size_t(-1))); std::priority_queue, CandidateCmp> pq(CandidateCmp(), std::move(clist)); @@ -141,16 +142,17 @@ void DrusillaSelect::Train( const size_t index = pq.top().second; pq.pop(); candidateSet.col(i * m + j) = referenceSet.col(index); + candidateIndices[i * m + j] = index; - // Mark the norm as 0 so we don't see this point again. - norms[index] = 0.0; + // Mark the norm as -1 so we don't see this point again. + norms[index] = -1.0; } // Calculate angles from the current projection. Anything close enough, // mark the norm as 0. 
- arma::vec farPoints = arma::conv_to::from( - arma::atan(distortions / arma::abs(offsets)) >= (M_PI / 8.0)); - norms %= farPoints; + for (size_t j = 0; j < norms.n_elem; ++j) + if (norms[j] > 0.0 && closeAngle[j]) + norms[j] = 0.0; } } @@ -175,16 +177,14 @@ void DrusillaSelect::Search(const MatType& querySet, metric::EuclideanDistance metric; NeighborSearchRules> - rules(querySet, candidateSet, k, metric, 0, false); - - neighbors.set_size(k, querySet.n_cols); - neighbors.fill(size_t() - 1); - distances.zeros(k, querySet.n_cols); + rules(candidateSet, querySet, k, metric, 0, false); for (size_t q = 0; q < querySet.n_cols; ++q) for (size_t r = 0; r < candidateSet.n_cols; ++r) rules.BaseCase(q, r); + rules.GetResults(neighbors, distances); + // Map the neighbors back to their original indices in the reference set. for (size_t i = 0; i < neighbors.n_elem; ++i) neighbors[i] = candidateIndices[neighbors[i]]; diff --git a/src/mlpack/tests/drusilla_select_test.cpp b/src/mlpack/tests/drusilla_select_test.cpp index 504fd6247ed..b60a1ad1281 100644 --- a/src/mlpack/tests/drusilla_select_test.cpp +++ b/src/mlpack/tests/drusilla_select_test.cpp @@ -21,7 +21,7 @@ BOOST_AUTO_TEST_SUITE(DrusillaSelectTest); BOOST_AUTO_TEST_CASE(DrusillaSelectExtremeOutlierTest) { arma::mat dataset = arma::randu(5, 100); - dataset.col(100) += 100; // Make last column very large. + dataset.col(99) += 100; // Make last column very large. // Construct with some reasonable parameters. DrusillaSelect<> ds(dataset, 5, 5); @@ -29,7 +29,7 @@ BOOST_AUTO_TEST_CASE(DrusillaSelectExtremeOutlierTest) // Query with every point except the extreme point. 
arma::mat distances; arma::Mat neighbors; - ds.Search(dataset.cols(0, 99), 1, neighbors, distances); + ds.Search(dataset.cols(0, 98), 1, neighbors, distances); BOOST_REQUIRE_EQUAL(neighbors.n_cols, 99); BOOST_REQUIRE_EQUAL(neighbors.n_rows, 1); @@ -37,7 +37,9 @@ BOOST_AUTO_TEST_CASE(DrusillaSelectExtremeOutlierTest) BOOST_REQUIRE_EQUAL(distances.n_rows, 1); for (size_t i = 0; i < 99; ++i) - BOOST_REQUIRE_EQUAL(neighbors[i], 100); + { + BOOST_REQUIRE_EQUAL(neighbors[i], 99); + } } // If we use only one projection with the number of points equal to what is in @@ -82,7 +84,6 @@ BOOST_AUTO_TEST_CASE(RetrainTest) arma::Mat neighbors; ds.Search(dataset, 1, neighbors, distances); - BOOST_REQUIRE_EQUAL(dataset.n_elem, 0); BOOST_REQUIRE_EQUAL(neighbors.n_cols, 200); BOOST_REQUIRE_EQUAL(neighbors.n_rows, 1); BOOST_REQUIRE_EQUAL(distances.n_cols, 200); @@ -97,11 +98,11 @@ BOOST_AUTO_TEST_CASE(SerializationTest) DrusillaSelect<> ds(dataset, 3, 3); - arma::mat fakeDataset1 = arma::randu(2, 5); - arma::mat fakeDataset2 = arma::randu(10, 8); - DrusillaSelect<> dsXml(fakeDataset1, 10, 10); + arma::mat fakeDataset1 = arma::randu(2, 15); + arma::mat fakeDataset2 = arma::randu(10, 18); + DrusillaSelect<> dsXml(fakeDataset1, 5, 3); DrusillaSelect<> dsText(2, 2); - DrusillaSelect<> dsBinary(5, 6); + DrusillaSelect<> dsBinary(5, 2); dsBinary.Train(fakeDataset2); // Now do the serialization. From ab2c213bf22e64ecd01f7a643cbd22a2e1b7e01b Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Mon, 24 Oct 2016 18:13:02 +0900 Subject: [PATCH 16/30] Remove accidental swap files. 
--- .../methods/approx_kfn/.drusilla_select.hpp.swp | Bin 16384 -> 0 bytes .../approx_kfn/.drusilla_select_impl.hpp.swo | Bin 45056 -> 0 bytes .../approx_kfn/.drusilla_select_impl.hpp.swp | Bin 20480 -> 0 bytes 3 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/mlpack/methods/approx_kfn/.drusilla_select.hpp.swp delete mode 100644 src/mlpack/methods/approx_kfn/.drusilla_select_impl.hpp.swo delete mode 100644 src/mlpack/methods/approx_kfn/.drusilla_select_impl.hpp.swp diff --git a/src/mlpack/methods/approx_kfn/.drusilla_select.hpp.swp b/src/mlpack/methods/approx_kfn/.drusilla_select.hpp.swp deleted file mode 100644 index ae44b2878c4201185817841f1f41ab0b2f854f98..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 16384 zcmeI2`-@#y6~|9o(^hMuwHBcw*-0>+8D{RygEVFGAd|_{CX+EUNvo7z&fL52+><$v z>pADnWHQG70s4yw75u6ClchzhApYPFzEBie!Iu6Y6tuL6DAGppQ3>d0?S0Oz^n#*4$10@3+GSEqzV~=ii_V3%f zN91>N(M>>N(M>>N(M>>N(TPF7)aZ9I#1K3JF{+`%KqOm z@c*vt_rdJ{fr01Q&+=c%K*>PKK*>PKK*>PKK*>PKK*>PKK*>PKK*_-WAOmj0aelSM zakkupah(6>`~RCCa-83QpM#%*6g&w&1ExV0ys^!3egVD)G?)U1!A|fAa6i}vez29h z1Yjq4^MmLG&w&KQ;BjynOo9p+19yWP_c_kn;CJA+;LG3`xEI_7Uj2aMd<#Ut1Bbvv z;2-aIoWFw`;2R(X5%?6i4{QOi-s?E8fS15^Pyq^j9J~$}-vd`c49wc%9)y{!S z6*Vtt`*O2lr&Fz}vze3>>OhC7n|e`Lfh)z+&<%9L)eWue4UWOmPvwb_ zEx0vLji{@p1?|>bRxxI%nM$Fer^70NAg`_qRoPF`mA164ZVfi-&{C*b)OFQxQ}?`f zL$OaRk?(n3d4`qUb|DQ4nv+p`IhU93xjc__tnnP7tEU~Q=6hk+##s!D-LMe_wp=Y~ z$idh+y26bjFPv8qyjPSu<0i_-5B*LMCSr*7)Xb5B+zL?`c7iz_i!EGBl1Y*%hE4`R z_pq6@AhTycnZLHO+$P4*bjcdP2+S_G^=_fV_IXC9y$XtQ%LQHY7Tr{j3;V_^XLRhj zKC@?NBCfaaq}@<3n)R;}CfJ5_2}I?)Oyh z^BwwB2=&^)IHlsL`f3`x^;FfF7#>rvJ9pzKkbzSp<_|uRq}pxRA~SD%-Aole)7Evb z>FLI(jCg2$x8c9;!iDvi zVK^bg!xd!`jb^lQ+~FNodQx7OLZM-zzke#4lNbiIX-kFCIy@aGMs{M|B}w&v`kXo# zba|2Olkmmm`|Ws7xIO;W^&olO^!j=3oQWZyV90C4zHC#EkY^sZ+(noqiYk93n2We3pJFsyhM2k>SRu+HtQ^}KJKhIb>B^r zfsu1ohxhDtLvSlX^-;&^F8uQ4ms;s5w5vgiy35SNwyavmxn*^Ul3fkm)Ku-n#Y;11&z+m7&73@Ua&oqIW_lWpP1a@2+pJ|? 
zO}x!Ky<~^8m8R|SYISL8sS>+s;8&t}z8ZK*o%Qd0_o8uA78R9tIFEabWJ8xPI$VD~ zikWKxb5@2^NWoQ_DQf|j37lpF&NqpGI8Y)BG8MSX)DV5l{miW*1r<~dmA&rlv2tPY zz@6(BwTeYyrjZcfhVv{uD+>(XV|VFwjKF{{QX zT(Q^VrTa~k8wQSL+=;Nv&nD8*l6XqnZR>bk$|pkAzqKly%iMZj#|^r}tco`;E`~y5 zm?cYH}#K6-z)q7{Cxjk?C;+N-vqMf-v)lRjXg0K z0e6Etz{~9EzXC3SkAeHaAK1IU1fB&e;2KE5XTdBu2>!F0Lv7yA4HILY@D zU=nNtTfsjm|GsbnLhvLw1jfOG;7{!N%TLKb$w0|K$-sMvfz2jCHzl|0mQt$)R+kLC zm)Fn2QuRNm-J8R`9W}Qjo2QL$v(7EWjliwhUZ1+&&JU)lH0wr=vzuAOD6 zZ_}rJ3BNq+uuj!N)~5RGUTZS)^XzhW>A8N=wLfVqE6(l?m*u@|aP!?<(FrI(gb= z$KQb#n}OJzK~{DXW=+(wgv>3iZ_+GDW@e1`Z~>)j@*q^N#ldE}x?U^mb*`_io{sgm zCH9vYVL5*)_+29-T%=C()2T$nWAAD=$=^-sbY(n~=VyzTWfJp5$wsxiGJ~}jdUCs~ zvI6$F!Q`ssyuoI&@^a=IwP^i3>d{jva%!#E7Nx)c+?m;# z&AyL7$`Nz-oZs8GbMM@__kU;Z-2ean@4dsl7q42Q-duQNhR?~F%!lqg>5|#+S(@p6 zJd^Rpty;0v^cqf0LMQwTjaS!Lr46oI8FL!l1HB6`y!gEHmz{s^#=-O{Z%-uH<9nr^ zYNc+K#=7~h^{c09H%8sE-(%J5p1Z3!Hd5;;drjY|RIH+JSL{-w=v3>K!f3soPTu@7 z5-<{&NCM5qNcXuXW!|`OL9axa(=}V2ddfu;#Wu%{1dIfX1dIfX1dIfX1dIfX1dIfJ z3nkE~pOAS6<$8How&TO^m$iK#*A{+4TlnvUzh{Ntm=c0dlui_XW_bLpw(wVlzrFCg zyDj{c;qNFPJKDmdd`IDpw(wVkzh4u6H`~Hzguh=KeqYfR9{x06Mgm3xMgm3xMgm3x zMgm3xMgm3xMgm3xMglKM30UP!W)@2ET2YKq+9N&sS$_X3JOGE_YPbp{-o}s~zE6cy z;AQaXS7$PR0e=qff_K7pxERiWr%uXb4#JJ_LD&oLht;qEX2CPRo5}nF_QPHnhLtcE zetBXhb3fb%x4<7k5jMd97(1gkNINc^DpqJD~&@ zz#oEw$l#+?&33kq4!fRT7{SLP>pNH2#YQ%#y48Z*qD(T}w;M$X&W`2P_NMKP57~_Z z9wKhVCniz#r)!5*Y1+yiQPtU9K@v(eMm=|@@*Cy8KBq>wQ!ZNG)@IeNHL_hpRm*Fv zwEWS&z9HN4N~76apYm)j;WzBE8dHw1wtBYJusss8YO0HCIU_1NrutPim)XtFi9eQC zwfg$1R%5OzJAT8em25vZjWgC*jm26R{c~GF%1?S`$*uVfrO#4h`OFU2DXX}6i%D|B zc-`)g%cg=<( z>ARgmZ?MLsSm4%GRpCq}4c{<(kHQMfPkO(URSIQ77)NR@HJ)nk{fk!BRjpk# zwf3$sW+3MR2eWMfT4K$eu2<%zrq5ZFkt(LVgHCBIc!X&DNK0OEQhCXi*^xVPk#cC6 zD^d}T)S=F5+_!401G-H-Rh1xTTqNBf6bhjf^_}6!9Y4X{$vdiBmFWYlv8`H=+5{Pk z5L)sjUWlbnyQpOyy{kI~ zaj!_fqC*h*TJ7>~b$8HJ6;UO+FTuS(!oAVTRs04ngnig*mhPjN{;RK zjD#J4Vk;W&E%6{vCW*e$wEGiR?us*19H~j{F6n%)s;QDi=hJB@@+zIhVcV|p(39Oc z&Ir%`h+4CHaOuDW#Wl+}tX#KjsJL`+aQ(VViWjU{TU@sOqM=o*S1&CNEnmHSU_)`$ 
zn!(k@m4kyBHNOYcD*F9FugsE;6K$nh7COHV1-|d}d@bKsp_4p~46m8_>aA82r+H7U z$75d&Xpen4j7CR~b|UZjp84}LqW?dR-rfIdt^Y^k-@itm{}ucY9u?gl#zA!cZg?$- z?*9YMyA`g5ozR2^Y=#xE93<^uagA@lpTRY-0$vUepx5t#h43`G`&Z#B@FjQ@MAyF) zu7TBX4m^o2{}cEQJO-bIkHNKY8LWrX;aT+hpTUE06KsS9@H9I7UC;m@-UvTIU;hjo zfUDplSOZz;g6GiJzX?ykC*T8MLm9eZA()SmfRTWaz;CSt=AhU)BQ)63-)+6O^}D$J zX!Ne-g*wgiCW|&%*x2Ehmgv(tWg2mGlxak)Sy80u`$*(;q~}amsJd%KWgO`R0;8}t zF-BGCXX~z?f1wE-an3vChyR=pMG}CBBi~iH!@^`X=MzPzfZ>knNasPWqAqm}%OeDr56IizX0yFdiPY;I9 zjz${#)|z7$qYIfG1MkLv4-^SCH9(nR)>$-cF9q*?Qw ztu^dXGKSPMAN!N%i&j6sC#Q!7)4F|Xe9&_kC1tF2Ddr4khdmc7S|__o7~DVkh~aF8 z48sHi>r>S&+m(VeJ)-|Vjeb54T|PPf{}8_)gpa}=xDv)d;{A~Lhu~iLBwP<4hWEl4 z%!Qw#+wX_>!h2vPbi*m|pXmBOgWKRr*Z@QDdUz83{&v_8H-ZoAVJ(R6eb=EHGt0KNaia2;%g)8I950=yEw!C3wsa2s3&Z-d#O;QNf_-wOxeE*OJBm<3;A z4F68}NBDcV4gLo9!ZJ7>?q(c+9CpG+I1#?hgo|&%_hjNlbcc31HOAqB`>8xziDkZ` zHh2taV#OU1!)ZZ3C#i9IuUTt2Rl8_=p6g}1(zj|nf}QKb8Wui=jG8tW<88#_^Kl4I zR_G|8mS6&(YREpycl@m_$c#}0Lj0XFlt!ccY561`<#a&1 z-8Ax}==5xd#?^i^1}+ej(`#kZ=Xy?bN!@CLi%{#WQ37$EQa6MRXIh+Mz6L=2&|Htq}*bjSQGbnfi zi2wiR;073h%V7odLk8|a@4pFdfNS9WP=i5O3Fkou?nCcC1p8ndHp3=Z1h0ly!Vl2> z#b4ldi2MZ3;r}0@_kRbz4o|=*;E&-Q@CR@%{15v76Yynt4Bic+P=Ft!_ls@d^YAD< z0^rRa;H%gHu7y7Z7v{i!VGsDZ*aYMgJ?O?R5WT7ron~VfFjE1hJryAN($=oS z@{U$#E4+czj*vlU!cD&A3V&>=^GD*0pVVJFO&~~haEy<>@zKYz)cEKdAN{_*qwJ%P ze?;2zUe62DGp7PSZya^S*}Rf)`1W5!_x#^TT|MY$R(TIrSJlLF?J$Rhe}fquL0E;b zN-Kw%s%A)GTkj82IiL8yE}dRY3xW9f@n04=XzS!Ti$OHHyg+x02iIVJLv z4V{Q3`u3b;K$aE>R`e>!vLSxwmB|uSaCFu<3yPeLGL=u2D9^fIGCkjOO2JAO=|X}3|BO)27yX|h=HPQ1$A)1j3Olx*TT@yi0V4q;0V4q; z0V4q;0V4q;0V4q;0V9E9TLLUE9H{mkbm=feY%wOYT&NJ}|8vpt#lJtP|3Ag=zXH+!gRo<}-c9L^ z1dIfX1dIfX1dIfX1dIfX1dIfX1dIfX1llE_!cSLf&$3#xI?O~C<_QH`N(D|8u_uH$ zRLqv`l}ap#h+1F42SZgmt;%-z?rG^sCdq}_+O|+j?wZC#4_Qkh7wG?=!(V?MKK!Ep z&kRfR9D4uv;d}61xCd^6x4}AC0n6YV_&3fy03V09K@KE6SqJbUSPAEYtPk)nTLg41^hDH1+pK&RyZ5x!^`1u=J$Ubw!;`)2&>>! 
z__ECHhdba_ki7t|f-xw;7B~aG#oYdH!Xfww?13v`IlLC0WnTX;;12jS?1gv1Zny&S zAoKoZ&i{URAG{MTg!#|~C&RPM@&7se6rO^6;0Cx39FR2v&xAAJ+c;!A00-f2aA5 z(A@vj#Ih;Z0@SWLW(~mB-cUz*+dn+k>|Jdee+3?I?K==of49H5krsEgV0W-dr}zgs zi2c3l+Y5EKo{fK}ZkHe`PHseEL%wW|4TT6Viv?Nl9hLTn`+r;HPcqr`N3*3kn@aMO z%J}XeQ+f{PK(yfJK$To)aZHO!zzaq4E=!C?Yl4<2V6g9JmCXl>b<6RBo#L3L&8~pq z1$E*{-@imz58oUw6I_Bn_3hdCLP1_IoXZxYRrX)X;#220qK({w!*$P*jfBRF!TLLk z<5R=4^zKlt!~OYIF-Top?vzgdu$b7}xtJ_s++I|@n`-xhT%MgJn`~Bg%?d-76=wiC?7+wzL|WQKLT!DJjIy0-%Xzi9gu9{zxb zij1_|W(|Z1_9_e-Q+Xb{QktKP+scxM;z%E@y&c4=ur0QYDmXCe$_~fuKv$`tHU=vK z$UcXa`PFzyfG9B76I)UYbxcw8QcHPHtu%(Kf%l<#^5$j*BgUwi5a5POdGBfMlxh}B zh?h5?bg_iAT0Yg1iQ$&v>~)apAqI;!PQDlTKx;EawUb>ES%96{1FY@Nn%M=txrnaZ zc=iVUSa!K&R)Dl4GXMX6)WMI7PKf^RCG`I%(d{3DTVXFW;4;_@;xlj>ybAsoz5b{0 z6g&b~LN^=_uK?L!@D{ikTv!Sx!mHq)(d|D4H^a5C6D|W;_wOVSo&PcT1ndIY1MoDM z3D2P0KMe<9CzL_f`-^n{9{!&RC%_E24;}wq@J^8R|K197pbMTvzyA#EgO9>y*Z{qd zh1sCsaddpyA5hl&yA--%9z26y{}?<1cfdaQ2z&@W2(nM$Qdk7CKfrV7_V>em@NN*j zzXZeZI*>I055YKCumv{3rLYd(1iwJPzYpFH1F#fk!`G;_Y8VInP++{RTCb$m zDS^hO_j4#{t77VPUn|z!n#_)-4bN&lBNxsjrqR+?nqb*IZO1cKig2F37Q|sz7$>K? 
zk79fZs@cZWIak?H28u&YA&sJ_Il6e=4IYjZm%wcT*lfSi6fev=oZ)DJEL fk}`E*ZmJOZDRvChO8(^A0P0QOf=L6>4uJm$Z&XZV diff --git a/src/mlpack/methods/approx_kfn/.drusilla_select_impl.hpp.swp b/src/mlpack/methods/approx_kfn/.drusilla_select_impl.hpp.swp deleted file mode 100644 index 9d5090f40630218c445a14d7018811412de030e6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 20480 zcmeI3dyE}b8Nde+QC=b-CJOO%7t6iv-MibO(rmX2+bwMucH5OLg{0J**?Z^i&hEW4 zw=;9wF7-k2AH+W;@=zW@#DoYj3WBjU@~{S@{}3@T(F6zGZ(Pu?clW^>TyjJj0l&pL63Qt4547UN8)A!E|c%t{2#j zq%M0+Ewm@h`VQA^&D%lW=)l@f@7%a$)0TBNOy-MxKJ&tU->dhxTOG4L-^ahwU;S+> zm~|U|f5W$0Rz2u%cwOIawal9SU9;PrR&}=1$-le!RcN5lz;QIt4Vr!H&bj{bwQB|> zS8vY>b;0@9AE!W(RA`{kK%s#`1BC_(4HOzEG*D=u(7^wi27=CM#uup6sZqtwi0-E> zy1ys-?M3(ZFG@cx`aK@qZ(fuh{VZ;U1_})n8YnbSXrRzQp@BjJg$4=@6dEWrP-vjg zz<)pkX2USneZVj}Ap8HZ|NjquzXZ?25x5KP0LgblL_hZn;Cwg*9y`}C?uNVIRyYWI zU?+SO-Z;lF4#U0h6}TO4gK<~`7r~on8^+(^akw3(VJobJx6U$*XW(gg1a5&E?1E7! z!+CJ@Ov88zJ_`q63VPv9f)>xfA$Sn(ft%nX@Xz;=4zI(j@H6-}JOuZ_O>hZZ4DX_( z7vU**67GaKXuy?l1$+>Wyw@HRbh2Hr=;^nk1LzD{4>I@)oA7pvnmmx8;+QtorKNYqq+Ua+|8X zq9=TXQo*d}?o)oy7#y@6((Ojg^k%wk%Lz(7Q*F}=wwnIz;NX;HdiB{-c~E&4pYQ{# zq2`tCs~OKS1Iyz@rlWfJmfcjPc{QxsW#a%pm%mJ}Z4M5$&0wW!*nVI-b;~cGM2QJA z7}ClZUYRPXa@>-2*YN|TOR4#avDdX5DyiNOuiUZFv4)d+>ip5;m5EX-MqqKug$fS^ za;kn6 zB~1Gzx7nn1{c>3iZ&0O)+T<9?>#76gsQq-#y;ePazuFL$rxJT>fjgAsp?*!zZI?m; z5SQkir{SznNM-tJpc*HQnyq?QDymFpre*m`3Wj+(cjVDKRL67YI0l5oRi#Fpg<#fp zW>mf9`j)aBw>vXi(OpmLA=1RIY5LBX)39mmU30s4BkVwRAXW)7O&lYW@?EvhQVrL+ zj3%@!q@dq*M3>5(F>MEZYkm5gQZ1I|D#|YN`iy25e?0!Izg=rNDHCaZ+5AxjlKlBf z>DL|zg=)HSXOtXb|sqKRWyj(W{?OTr-rrEqq-|Y-$!FsV+vuDu`iEm2x;vv zNhK<(r(yYZ%W1?De=;eA?Vpr0Gjk4F(z)beW*9{kiThuCSGl1!w`gc9)}3%FqpQg{ zNC$n~SXRroqV`K)Dec*^9m}t>nw9K$+LOl_iS#?QurlG6JgYhCb{4`ybFyS+KEBG3 z(sE@U#&c%F4HgxyEdBjhY8flx94|5qOTA8_Um>2wBwDx9g*wx}(x!<76p0q+&-Eth z1Q~T4CqI)-I<==onB^OhI-)+D8)i<;ah;nj&!y#}X`meS4SJWZzf?XFD%)VHKUQtn zFxjM~a&N%CNSm&0?q{mC_D4BGJy&^HT9lwQ*4^iN4Hh+BX1yUzOnoQq`g}NqnEZH` 
zEOF9!*_qjv9=TrG%*-um2wU1vi<;P0i`oFQiBZ)GCF3Lc387kzw5WTAEj@mj-pLx( zryAG*6PD?OZ(ByNSTEYHOQeh+TV{P$EIu(1S@fb8iV~t6h>)#FGiE#^j9uzW0o&9;Z-#ro}7n2dp|I?^1;TGuCER`g%2A3s!Db`h^8 z+Gk$Wx{lZCrIP3?VplAYgk0k;Ypc5#Q&m)zF<-*I-=uG(QY9;JHL_vbCCpjH)he~h z+SJS{GpCCvU2-pd<%U+ftOkvx^Tkhd`CAoZNxL;R)k`{eWR^1lN5%gik75wN#-IN_JOOvZ95g|E{wkaaubpccFT){t9BzS0*bX0qU*XR` z0QbW#xD;N&k3R$**bXb;!|*nKyZH4_z^xF#jUaygCqaDs68siF{-^LoxCyRH_rtxg6V^f%-o{sd7Q~O=4;$ekcoUuc75)N;;SQJw8)jiU zYy;8VW+<*g1BC_(4g4Q!pcj|VZW53ZOL&PDU5BdT<<~EDT&&tKhKMnR;f=``CyK== z9zz(YTPkCmC>iwH#3J{woqT21gRa>k*-T7!f7WgidUM6BA;=dS#kwhu|5w*UYcc6PWgWypnv=V7 z>WG^%ZK@N!Bv(A?8Ye0x>A8|=7d!b+YY!_E?8!x}myJR%3g_v!Fmy7J^5fd_nfP-& zq~d#aki{xUT$ympSHgCRT<$o_B8igDSV^V_nKzD!FB)-q-8ok@-LBKno|I0ZJLckj zWgM2E>k$Rdha07f;zathZnxDC+um$6u=K0hC}dhid3XL75lS{H=4GU0uTUOF7haE* zPm}SKUSX6c-j9*6?;d8xj6iC}G76_rfy;wvxN*#B0fHXI(CSRc~SksCgXWZaM_OZVj7+I4)hX@y~qJ(&)g#xW8hD3%znw%L^W!QTyVRP=OpC}HYn@rvv+IL{S9j}d#}R1} zrOKkqBiX@YQ%3LLNSjiJd^}Gw>dl%^7kZy0GfbjAu)Jt|-KPtSzV)6)+R^yH7Asn# z#nnv;MmsPlrscZx%495xKzb#T`!d}=8%fu_IGb_FNL*OWjb1mf`p8TD!Y~etR2*i~ z&9SH(gbhN?LOn>JRia~z+cexJwGp#p};*XqDjFtT{l9}>!1v%D=)gvJFPsh%7kCs7!Z>V(3S0_D@%3@wgl09iq zoJ5eBEsmToXV>XGOaINSiN#3-Q4&i`aS}oImUwp~Iav^IQD*k>*=XDyuh02an*}DH zPiK=QF2BfStkCRoOhXhW5yCT9;*%675%f8;c(%j;EKVX6ClQL12u!BMNdykpq;64o xULy8YwoPUzW}JT8SMZ7 From a9452fc85fb9ffa58408c6e400b68a83d9ac0c9e Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Tue, 25 Oct 2016 13:33:55 +0900 Subject: [PATCH 17/30] Move into mlpack namespaces. 
--- src/mlpack/methods/approx_kfn/qdafn.hpp | 10 +++-- src/mlpack/methods/approx_kfn/qdafn_impl.hpp | 42 ++++++++++---------- src/mlpack/tests/CMakeLists.txt | 1 + src/mlpack/tests/qdafn_test.cpp | 11 +++-- 4 files changed, 35 insertions(+), 29 deletions(-) diff --git a/src/mlpack/methods/approx_kfn/qdafn.hpp b/src/mlpack/methods/approx_kfn/qdafn.hpp index 860acb83562..694fbf9b4b2 100644 --- a/src/mlpack/methods/approx_kfn/qdafn.hpp +++ b/src/mlpack/methods/approx_kfn/qdafn.hpp @@ -16,12 +16,13 @@ * } * @endcode */ -#ifndef QDAFN_HPP -#define QDAFN_HPP +#ifndef MLPACK_METHODS_APPROX_KFN_QDAFN_HPP +#define MLPACK_METHODS_APPROX_KFN_QDAFN_HPP #include -namespace qdafn { +namespace mlpack { +namespace neighbor { template class QDAFN @@ -77,7 +78,8 @@ class QDAFN const double distance) const; }; -} // namespace qdafn +} // namespace neighbor +} // namespace mlpack // Include implementation. #include "qdafn_impl.hpp" diff --git a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp index 47a698cde0e..bf462da9d3e 100644 --- a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp +++ b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp @@ -4,8 +4,8 @@ * * Implementation of QDAFN class methods. */ -#ifndef QDAFN_IMPL_HPP -#define QDAFN_IMPL_HPP +#ifndef MLPACK_METHODS_APPROX_KFN_QDAFN_IMPL_HPP +#define MLPACK_METHODS_APPROX_KFN_QDAFN_IMPL_HPP // In case it hasn't been included yet. #include "qdafn.hpp" @@ -13,7 +13,8 @@ #include #include -namespace qdafn { +namespace mlpack { +namespace neighbor { // Constructor. template @@ -86,6 +87,10 @@ void QDAFN::Search(const MatType& querySet, arma::Col tableLocations = arma::zeros>(l); // Now that the queue is initialized, iterate over m elements. 
+ std::vector> v(k, std::make_pair(-1.0, + size_t(-1))); + std::priority_queue> + resultsQueue(std::less>(), std::move(v)); for (size_t i = 0; i < m; ++i) { std::pair p = queue.top(); @@ -99,26 +104,12 @@ void QDAFN::Search(const MatType& querySet, querySet.col(q), referenceSet.col(referenceIndex)); // Is this neighbor good enough to insert into the results? - arma::vec queryDist = distances.unsafe_col(q); - arma::Col queryIndices = neighbors.unsafe_col(q); - const size_t insertPosition = - mlpack::neighbor::FurthestNeighborSort::SortDistance(queryDist, - queryIndices, dist); - bool found = false; - for (size_t j = 0; j < neighbors.n_rows; ++j) + if (dist > resultsQueue.top().first) { - if (neighbors(j, q) == referenceIndex) - { - found = true; - break; - } + resultsQueue.pop(); + resultsQueue.push(std::make_pair(dist, referenceIndex)); } - // SortDistance() returns (size_t() - 1) if we shouldn't add it. - if (insertPosition != (size_t() - 1) && !found) - InsertNeighbor(distances, neighbors, q, insertPosition, referenceIndex, - dist); - // Now (line 14) get the next element and insert into the queue. Do this // by adjusting the previous value. Don't insert anything if we are at // the end of the search, though. @@ -132,6 +123,14 @@ void QDAFN::Search(const MatType& querySet, queue.push(std::make_pair(val, p.second)); } } + + // Extract the results. 
+ for (size_t j = 1; j <= k; ++j) + { + neighbors(k - j, q) = resultsQueue.top().second; + distances(k - j, q) = resultsQueue.top().first; + resultsQueue.pop(); + } } } @@ -160,6 +159,7 @@ void QDAFN::InsertNeighbor(arma::mat& distances, neighbors(pos, queryIndex) = neighbor; } -} // namespace qdafn +} // namespace neighbor +} // namespace mlpack #endif diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt index a93f7bf9e8b..9f6965b614d 100644 --- a/src/mlpack/tests/CMakeLists.txt +++ b/src/mlpack/tests/CMakeLists.txt @@ -60,6 +60,7 @@ add_executable(mlpack_test nystroem_method_test.cpp pca_test.cpp perceptron_test.cpp + qdafn_test.cpp quic_svd_test.cpp radical_test.cpp randomized_svd_test.cpp diff --git a/src/mlpack/tests/qdafn_test.cpp b/src/mlpack/tests/qdafn_test.cpp index ee106102f22..ea64b526852 100644 --- a/src/mlpack/tests/qdafn_test.cpp +++ b/src/mlpack/tests/qdafn_test.cpp @@ -4,20 +4,21 @@ * * Test the QDAFN functionality. */ -#define BOOST_TEST_MODULE QDAFNTest - #include +#include "test_tools.hpp" +#include "serialization.hpp" #include -#include "qdafn.hpp" +#include #include using namespace std; using namespace arma; using namespace mlpack; -using namespace qdafn; using namespace mlpack::neighbor; +BOOST_AUTO_TEST_SUITE(QDAFNTest); + /** * With one reference point, make sure that is the one that is returned. */ @@ -100,3 +101,5 @@ BOOST_AUTO_TEST_CASE(QDAFNUniformSet) BOOST_REQUIRE_GE(successes, 700); } + +BOOST_AUTO_TEST_SUITE_END(); From 6c317b8b9a4a2a9fea729222878a549bc72bc93b Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Tue, 25 Oct 2016 16:15:17 +0900 Subject: [PATCH 18/30] Refactor main program to include QDAFN. 
--- .../approx_kfn/drusilla_select_main.cpp | 289 ++++++++++++++---- 1 file changed, 226 insertions(+), 63 deletions(-) diff --git a/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp b/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp index 9e55ec721d7..4d6ef67c586 100644 --- a/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp +++ b/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp @@ -5,96 +5,259 @@ * Command-line program for the SmartHash algorithm. */ #include -#include "smarthash_fn.hpp" #include +#include "drusilla_select.hpp" +#include "qdafn.hpp" -using namespace smarthash; using namespace mlpack; +using namespace mlpack::neighbor; using namespace std; -PROGRAM_INFO("Query-dependent approximate furthest neighbor search", - "This program implements the algorithm from the SISAP 2015 paper titled " - "'Approximate Furthest Neighbor in High Dimensions' by R. Pagh, F. " - "Silvestri, J. Sivertsen, and M. Skala. Specify a reference set (set to " - "search in) with --reference_file, specify a query set (set to search for) " - "with --query_file, and specify algorithm parameters with --num_tables and " - "--num_projections (or don't, and defaults will be used). Also specify " - "the number of points to search for with --k. Each of those options has " - "short names too; see the detailed parameter documentation below." +PROGRAM_INFO("Approximate furthest neighbor search", + "This program implements two strategies for furthest neighbor search. " + "These strategies are:" + "\n\n" + " - The 'qdafn' algorithm from 'Approximate Furthest Neighbor in High " + "Dimensions' by R. Pagh, F. Silvestri, J. Sivertsen, and M. Skala, in " + "Similarity Search and Applications 2015 (SISAP)." + "\n" + " - The 'DrusillaSelect' algorithm from 'Fast approximate furthest " + "neighbors with data-dependent candidate selection, by R.R. Curtin and A.B." + " Gardner, in Similarity Search and Applications 2016 (SISAP)." 
+ "\n\n" + "These two strategies give approximate results for the furthest neighbor " + "search problem and can be used as fast replacements for other furthest " + "neighbor techniques such as those found in the mlpack_kfn program. Note " + "that typically, the 'ds' algorithm requires far fewer tables and " + "projections than the 'qdafn' algorithm." + "\n\n" + "Specify a reference set (set to search in) with --reference_file, " + "specify a query set with --query_file, and specify algorithm parameters " + "with --num_tables (-l) and --num_projections (-m) (or don't and defaults " + "will be used). The algorithm to be used (either 'ds'---the default---or " + "'qdafn') may be specified with --algorithm. Also specify the number of " + "neighbors to search for with --k. Each of those options also has short " + "names; see the detailed parameter documentation below." + "\n\n" + "If no query file is specified, the reference set will be used as the " + "query set. A model may be saved with --output_model_file (-M), and an " + "input model may be loaded instead of specifying a reference set with " + "--input_model_file (-m)." "\n\n" "Results for each query point are stored in the files specified by " "--neighbors_file and --distances_file. This is in the same format as the " - "mlpack KFN and KNN programs: each row holds the k distances or neighbor " - "indices for each query point."); + "mlpack_kfn and mlpack_knn programs: each row holds the k distances or " + "neighbor indices for each query point."); + +PARAM_STRING_IN("reference_file", "File containing reference points.", "r", ""); +PARAM_STRING_IN("query_file", "File containing query points.", "q", ""); -PARAM_STRING_REQ("reference_file", "File containing reference points.", "r"); -PARAM_STRING_REQ("query_file", "File containing query points.", "q"); +// Model loading and saving. 
+PARAM_STRING_IN("input_model_file", "File containing input model.", "m", ""); +PARAM_STRING_OUT("output_model_file", "File to save output model to.", "M", ""); -PARAM_INT_REQ("k", "Number of furthest neighbors to search for.", "k"); +PARAM_INT_IN("k", "Number of furthest neighbors to search for.", "k"); -PARAM_INT("num_tables", "Number of hash tables to use.", "t", 10); -PARAM_INT("num_projections", "Number of projections to use in each hash table.", - "p", 30); +PARAM_INT_IN("num_tables", "Number of hash tables to use.", "l", 5); +PARAM_INT_IN("num_projections", "Number of projections to use in each hash " + "table.", "m", 5); +PARAM_STRING_IN("algorithm", "Algorithm to use: 'ds' or 'qdafn'.", "a", "ds"); -PARAM_STRING("neighbors_file", "File to save furthest neighbor indices to.", +PARAM_STRING_IN("neighbors_file", "File to save furthest neighbor indices to.", "n", ""); -PARAM_STRING("distances_file", "File to save furthest neighbor distances to.", +PARAM_STRING_IN("distances_file", "File to save furthest neighbor distances to.", "d", ""); PARAM_FLAG("calculate_error", "If set, calculate the average distance error.", "e"); -PARAM_STRING("exact_distances_file", "File containing exact distances", "x", ""); +PARAM_STRING_IN("exact_distances_file", "File containing exact distances to " + "furthest neighbors; this can be used to avoid explicit calculation when " + "--calculate_error is set.", "x", ""); + +// If we save a model we must also save what type it is. +class ApproxKFNModel +{ + public: + int type; + boost::any model; + + //! Constructor, which does nothing. + ApproxKFNModel() : type(0) { /* Nothing to do. */ } + + //! Serialize the model. 
+ template + void Serialize(Archive& ar, const unsigned int /* version */) + { + ar & data::CreateNVP(type, "type"); + if (type == 0) + ar & data::CreateNVP(boost::any_cast>(model), "model"); + else + ar & data::CreateNVP(boost::any_cast>(model), "model"); + } +}; int main(int argc, char** argv) { CLI::ParseCommandLine(argc, argv); - const string referenceFile = CLI::GetParam("reference_file"); - const string queryFile = CLI::GetParam("query_file"); - const size_t k = (size_t) CLI::GetParam("k"); - const size_t numTables = (size_t) CLI::GetParam("num_tables"); - const size_t numProjections = (size_t) CLI::GetParam("num_projections"); - - // Load the data. - arma::mat referenceData, queryData; - data::Load(referenceFile, referenceData, true); - data::Load(queryFile, queryData, true); - - // Construct the object. - Timer::Start("smarthash_construct"); - SmartHash<> q(referenceData, numTables, numProjections); - Timer::Stop("smarthash_construct"); - - // Do the search. - arma::Mat neighbors; - arma::mat distances; - Timer::Start("smarthash_search"); - q.Search(queryData, k, neighbors, distances); - Timer::Stop("smarthash_search"); - - if (CLI::HasParam("calculate_error")) + if (!CLI::HasParam("reference_file") && !CLI::HasParam("input_model_file")) + Log::Fatal << "Either --reference_file (-r) or --input_model_file (-m) must" + << " be specified!" << endl; + if (CLI::HasParam("reference_file") && CLI::HasParam("input_model_file")) + Log::Fatal << "Only one of --reference_file (-r) or --input_model_file (-m)" + << " can be specified!" << endl; + if (!CLI::HasParam("output_model_file") && !CLI::HasParam("k")) + Log::Warn << "Neither --output_model_file (-M) nor --k (-k) are specified;" + << " no task will be performed." 
<< endl; + if (!CLI::HasParam("neighbors_file") && !CLI::HasParam("distances_file") && + !CLI::HasParam("output_model_file")) + Log::Warn << "None of --output_model_file (-M), --neighbors_file (-n), or " + << "--distances_file (-d) are specified; no output will be saved!" + << endl; + if (CLI::GetParam("algorithm") != "ds" && + CLI::GetParam("algorithm") != "qdafn") + Log::Fatal << "Unknown algorithm '" << CLI::GetParam("algorithm") + << "'; must be 'ds' or 'qdafn'!" << endl; + if (CLI::HasParam("k") && !(CLI::HasParam("reference_file") || + CLI::HasParam("query_file"))) + Log::Fatal << "If search is being performed, then either --query_file " + << "or --reference_file must be specified!" << endl; + + if (CLI::GetParam("num_tables") <= 0) + Log::Fatal << "Invalid --num_tables value (" + << CLI::GetParam("num_tables") << "); must be greater than 0!" + << endl; + if (CLI::GetParam("num_projections") <= 0) + Log::Fatal << "Invalid --num_projections value (" + << CLI::GetParam("num_projections") << "); must be greater than 0!" + << endl; + + if (CLI::HasParam("calculate_error") && !CLI::HasParam("k")) + Log::Warn << "--calculate_error ignored because --k is not specified." + << endl; + if (CLI::HasParam("exact_distances_file") && + !CLI::HasParam("calculate_error")) + Log::Warn << "--exact_distances_file ignored beceause --calculate_error is " + << "not specified." << endl; + if (CLI::HasParam("calculate_error") && + !CLI::HasParam("exact_distances_file") && + !CLI::HasParam("reference_file")) + Log::Fatal << "Cannot calculate error without either --exact_distances_file" + << " or --reference_file specified!" << endl; + + // Do the building of a model, if necessary. + ApproxKFNModel m; + arma::mat referenceSet; // This may be used at query time. 
+ if (CLI::HasParam("reference_file")) + { + const string referenceFile = CLI::GetParam("reference_file"); + data::Load(referenceFile, referenceSet); + + const size_t numTables = (size_t) CLI::GetParam("num_tables"); + const size_t numProjections = (size_t) CLI::GetParam("num_projections"); + const string algorithm = CLI::GetParam("algorithm"); + + if (algorithm == "ds") + { + Timer::Start("drusilla_select_construct"); + Log::Info << "Building DrusillaSelect model..." << endl; + m.type = 0; + m.model = boost::any(DrusillaSelect<>(referenceSet, numTables, + numProjections)); + Timer::Stop("drusilla_select_construct"); + } + else + { + Timer::Start("qdafn_construct"); + Log::Info << "Building QDAFN model..." << endl; + m.type = 1; + m.model = boost::any(QDAFN<>(referenceSet, numTables, numProjections)); + Timer::Stop("qdafn_construct"); + } + } + else { -// neighbor::AllkFN kfn(referenceData); + // We must load the model from file. + const string inputModelFile = CLI::GetParam("input_model_file"); + data::Load(inputModelFile, m); + } + + // Now, do we need to do any queries? + if (CLI::HasParam("k")) + { + const size_t k = (size_t) CLI::GetParam("k"); + + arma::Mat neighbors; + arma::mat distances; + + if (CLI::HasParam("query_file")) + { + const string queryFile = CLI::GetParam("query_file"); + arma::mat querySet; + data::Load(querySet, queryFile); + + if (m.type == 0) + { + Timer::Start("drusilla_select_search"); + boost::any_cast>(m.model).Search(querySet, k, + neighbors, distances); + Timer::Stop("drusilla_select_search"); + } + else + { + Timer::Start("qdafn_search"); + boost::any_cast>(m.model).Search(querySet, k, neighbors, + distances); + Timer::Stop("qdafn_search"); + } + } + else + { + // We will do search with the reference set. 
+ if (m.type == 0) + boost::any_cast>(m.model).Search(k, neighbors, + distances); + else + boost::any_cast>(m.model).Search(k, neighbors, distances); + } -// arma::Mat trueNeighbors; - arma::mat trueDistances; - data::Load(CLI::GetParam("exact_distances_file"), trueDistances); + // Should we calculate error? + if (CLI::HasParam("calculate_error")) + { + arma::mat& set = CLI::HasParam("query_file") ? querySet : referenceSet; + arma::mat exactDistances; + if (CLI::HasParam("exact_distances_file")) + { + data::Load(CLI::GetParam("exact_distances_file"), + exactDistances); + } + else + { + // Calculate exact distances. We are guaranteed the reference set is + // available. + AllkFN kfn(referenceSet); + arma::Mat exactNeighbors; + kfn.Search(set, k, exactNeighbors, exactDistances); -// kfn.Search(queryData, 1, trueNeighbors, trueDistances); + const double averageError = arma::sum(trueDistances / distances.row(0)) + / distances.n_cols; + const double minError = arma::min(trueDistances / distances.row(0)); + const double maxError = arma::max(trueDistances / distances.row(0)); - const double averageError = arma::sum(trueDistances / distances.row(0)) / - distances.n_cols; - const double minError = arma::min(trueDistances / distances.row(0)); - const double maxError = arma::max(trueDistances / distances.row(0)); + Log::Info << "Average error: " << averageError << "." << endl; + Log::Info << "Maximum error: " << maxError << "." << endl; + Log::Info << "Minimum error: " << minError << "." << endl; + } + } - Log::Info << "Average error: " << averageError << "." << endl; - Log::Info << "Maximum error: " << maxError << "." << endl; - Log::Info << "Minimum error: " << minError << "." << endl; + // Save results, if desired. + if (CLI::HasParam("neighbors_file")) + data::Save(CLI::GetParam("neighbors_file"), neighbors, false); + if (CLI::HasParam("distances_file")) + data::Save(CLI::GetParam("distances_file"), distances, false); } - // Save the results. 
- if (CLI::HasParam("neighbors_file")) - data::Save(CLI::GetParam("neighbors_file"), neighbors); - if (CLI::HasParam("distances_file")) - data::Save(CLI::GetParam("distances_file"), distances); + // Should we save the model? + if (CLI::HasParam("output_model_file")) + data::Save(CLI::GetParam("output_model_file"), m); } From 8fb09a11a6c4ec80983b56043a95bfa0865dd6b7 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Tue, 25 Oct 2016 16:15:52 +0900 Subject: [PATCH 19/30] Move name of main program. --- src/mlpack/methods/approx_kfn/CMakeLists.txt | 5 ++--- .../{drusilla_select_main.cpp => approx_kfn_main.cpp} | 0 2 files changed, 2 insertions(+), 3 deletions(-) rename src/mlpack/methods/approx_kfn/{drusilla_select_main.cpp => approx_kfn_main.cpp} (100%) diff --git a/src/mlpack/methods/approx_kfn/CMakeLists.txt b/src/mlpack/methods/approx_kfn/CMakeLists.txt index fa7846211f1..06b729ca557 100644 --- a/src/mlpack/methods/approx_kfn/CMakeLists.txt +++ b/src/mlpack/methods/approx_kfn/CMakeLists.txt @@ -18,6 +18,5 @@ endforeach() # the parent scope). set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE) -# The code to compute the approximate neighbor for the given query and reference -# sets with p-stable LSH. -add_cli_executable(drusilla_select) +# This program computes approximate furthest neighbors. +add_cli_executable(approx_kfn) diff --git a/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp b/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp similarity index 100% rename from src/mlpack/methods/approx_kfn/drusilla_select_main.cpp rename to src/mlpack/methods/approx_kfn/approx_kfn_main.cpp From 6d7e0ee10359adff4a3dd15fa1cdeaf9f2f58921 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Tue, 25 Oct 2016 16:57:05 +0900 Subject: [PATCH 20/30] Fix potential bug and simplify memory requirements. 
--- src/mlpack/methods/approx_kfn/qdafn.hpp | 15 ++---- src/mlpack/methods/approx_kfn/qdafn_impl.hpp | 50 ++++++++------------ 2 files changed, 25 insertions(+), 40 deletions(-) diff --git a/src/mlpack/methods/approx_kfn/qdafn.hpp b/src/mlpack/methods/approx_kfn/qdafn.hpp index 694fbf9b4b2..7617fc2b653 100644 --- a/src/mlpack/methods/approx_kfn/qdafn.hpp +++ b/src/mlpack/methods/approx_kfn/qdafn.hpp @@ -51,10 +51,11 @@ class QDAFN arma::Mat& neighbors, arma::mat& distances); - private: - //! The reference set. - const MatType& referenceSet; + //! Serialize the model. + template + void Serialize(Archive& ar, const unsigned int /* version */); + private: //! The number of projections. const size_t l; //! The number of elements to store for each projection. @@ -69,13 +70,7 @@ class QDAFN //! Values of a_i * x for each point in S. arma::mat sValues; - //! Insert a neighbor into a set of results for a given query point. - void InsertNeighbor(arma::mat& distances, - arma::Mat& neighbors, - const size_t queryIndex, - const size_t pos, - const size_t neighbor, - const double distance) const; + arma::cube candidateSet; }; } // namespace neighbor diff --git a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp index bf462da9d3e..f1d04faf37e 100644 --- a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp +++ b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp @@ -21,7 +21,6 @@ template QDAFN::QDAFN(const MatType& referenceSet, const size_t l, const size_t m) : - referenceSet(referenceSet), l(l), m(m) { @@ -40,6 +39,7 @@ QDAFN::QDAFN(const MatType& referenceSet, // Loop over each projection and find the top m elements. 
sIndices.set_size(m, l); sValues.set_size(m, l); + candidateSet.set_size(referenceSet.n_rows, m, l); for (size_t i = 0; i < l; ++i) { arma::uvec sortedIndices = arma::sort_index(projections.col(i), "descend"); @@ -49,6 +49,7 @@ QDAFN::QDAFN(const MatType& referenceSet, { sIndices(j, i) = sortedIndices[j]; sValues(j, i) = projections(sortedIndices[j], i); + candidateSet.slice(l).col(j) = referenceSet.col(sortedIndices[j]); } } } @@ -77,8 +78,8 @@ void QDAFN::Search(const MatType& querySet, std::priority_queue> queue; for (size_t i = 0; i < l; ++i) { - const double val = projections(0, i) - arma::dot(querySet.col(q), - lines.col(i)); + const double val = sValues(0, i) - arma::dot(querySet.col(q), + lines.col(i)); queue.push(std::make_pair(val, i)); } @@ -97,17 +98,17 @@ void QDAFN::Search(const MatType& querySet, queue.pop(); // Get index of reference point to look at. - size_t referenceIndex = sIndices(tableLocations[p.second], p.second); + const size_t tableIndex = tableLocations[p.second]; // Calculate distance from query point. const double dist = mlpack::metric::EuclideanDistance::Evaluate( - querySet.col(q), referenceSet.col(referenceIndex)); + querySet.col(q), candidateSet.slice(p.second).col(tableIndex)); // Is this neighbor good enough to insert into the results? if (dist > resultsQueue.top().first) { resultsQueue.pop(); - resultsQueue.push(std::make_pair(dist, referenceIndex)); + resultsQueue.push(std::make_pair(dist, sIndices(tableIndex, p.second))); } // Now (line 14) get the next element and insert into the queue. 
Do this @@ -116,9 +117,8 @@ void QDAFN::Search(const MatType& querySet, if (i < m - 1) { tableLocations[p.second]++; - const double val = p.first - - projections(tableLocations[p.second] - 1, p.second) + - projections(tableLocations[p.second], p.second); + const double val = p.first - sValues(tableIndex, p.second) + + sValues(tableIndex + 1, p.second); queue.push(std::make_pair(val, p.second)); } @@ -135,28 +135,18 @@ void QDAFN::Search(const MatType& querySet, } template -void QDAFN::InsertNeighbor(arma::mat& distances, - arma::Mat& neighbors, - const size_t queryIndex, - const size_t pos, - const size_t neighbor, - const double distance) const +template +void QDAFN::Serialize(Archive& ar, const unsigned int /* version */) { - // We only memmove() if there is actually a need to shift something. - if (pos < (distances.n_rows - 1)) - { - const size_t len = (distances.n_rows - 1) - pos; - memmove(distances.colptr(queryIndex) + (pos + 1), - distances.colptr(queryIndex) + pos, - sizeof(double) * len); - memmove(neighbors.colptr(queryIndex) + (pos + 1), - neighbors.colptr(queryIndex) + pos, - sizeof(size_t) * len); - } - - // Now put the new information in the right index. - distances(pos, queryIndex) = distance; - neighbors(pos, queryIndex) = neighbor; + using data::CreateNVP; + + ar & CreateNVP(l, "l"); + ar & CreateNVP(m, "m"); + ar & CreateNVP(lines, "lines"); + ar & CreateNVP(projections, "projections"); + ar & CreateNVP(sIndices, "sIndices"); + ar & CreateNVP(sValues, "sValues"); + ar & CreateNVP(candidateSet, "candidateSet"); } } // namespace neighbor From 1795d726eb9eb83d357dbbdc5fb013c929d263f1 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Tue, 25 Oct 2016 04:04:19 -0400 Subject: [PATCH 21/30] Fix invalid access. 
--- src/mlpack/methods/approx_kfn/qdafn_impl.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp index f1d04faf37e..9220989d82b 100644 --- a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp +++ b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp @@ -49,7 +49,7 @@ QDAFN::QDAFN(const MatType& referenceSet, { sIndices(j, i) = sortedIndices[j]; sValues(j, i) = projections(sortedIndices[j], i); - candidateSet.slice(l).col(j) = referenceSet.col(sortedIndices[j]); + candidateSet.slice(i).col(j) = referenceSet.col(sortedIndices[j]); } } } From fc4c3b87846e1b001fb1fcc64ccce8c11adad886 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Tue, 25 Oct 2016 04:35:34 -0400 Subject: [PATCH 22/30] Fix incorrect inequality. --- .../methods/approx_kfn/approx_kfn_main.cpp | 73 +++++++++---------- .../approx_kfn/drusilla_select_impl.hpp | 2 +- src/mlpack/methods/approx_kfn/qdafn.hpp | 13 +++- src/mlpack/methods/approx_kfn/qdafn_impl.hpp | 4 + 4 files changed, 49 insertions(+), 43 deletions(-) diff --git a/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp b/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp index 4d6ef67c586..b5f2ac2406d 100644 --- a/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp +++ b/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp @@ -33,7 +33,7 @@ PROGRAM_INFO("Approximate furthest neighbor search", "\n\n" "Specify a reference set (set to search in) with --reference_file, " "specify a query set with --query_file, and specify algorithm parameters " - "with --num_tables (-l) and --num_projections (-m) (or don't and defaults " + "with --num_tables (-t) and --num_projections (-p) (or don't and defaults " "will be used). The algorithm to be used (either 'ds'---the default---or " "'qdafn') may be specified with --algorithm. Also specify the number of " "neighbors to search for with --k. 
Each of those options also has short " @@ -54,13 +54,13 @@ PARAM_STRING_IN("query_file", "File containing query points.", "q", ""); // Model loading and saving. PARAM_STRING_IN("input_model_file", "File containing input model.", "m", ""); -PARAM_STRING_OUT("output_model_file", "File to save output model to.", "M", ""); +PARAM_STRING_OUT("output_model_file", "File to save output model to.", "M"); -PARAM_INT_IN("k", "Number of furthest neighbors to search for.", "k"); +PARAM_INT_IN("k", "Number of furthest neighbors to search for.", "k", 0); -PARAM_INT_IN("num_tables", "Number of hash tables to use.", "l", 5); +PARAM_INT_IN("num_tables", "Number of hash tables to use.", "t", 5); PARAM_INT_IN("num_projections", "Number of projections to use in each hash " - "table.", "m", 5); + "table.", "p", 5); PARAM_STRING_IN("algorithm", "Algorithm to use: 'ds' or 'qdafn'.", "a", "ds"); PARAM_STRING_IN("neighbors_file", "File to save furthest neighbor indices to.", @@ -79,10 +79,11 @@ class ApproxKFNModel { public: int type; - boost::any model; + DrusillaSelect<> ds; + QDAFN<> qdafn; //! Constructor, which does nothing. - ApproxKFNModel() : type(0) { /* Nothing to do. */ } + ApproxKFNModel() : type(0), ds(1, 1), qdafn(1, 1) { } //! Serialize the model. template @@ -90,9 +91,13 @@ class ApproxKFNModel { ar & data::CreateNVP(type, "type"); if (type == 0) - ar & data::CreateNVP(boost::any_cast>(model), "model"); + { + ar & data::CreateNVP(ds, "model"); + } else - ar & data::CreateNVP(boost::any_cast>(model), "model"); + { + ar & data::CreateNVP(qdafn, "model"); + } } }; @@ -162,8 +167,7 @@ int main(int argc, char** argv) Timer::Start("drusilla_select_construct"); Log::Info << "Building DrusillaSelect model..." 
<< endl; m.type = 0; - m.model = boost::any(DrusillaSelect<>(referenceSet, numTables, - numProjections)); + m.ds = DrusillaSelect<>(referenceSet, numTables, numProjections); Timer::Stop("drusilla_select_construct"); } else @@ -171,7 +175,7 @@ int main(int argc, char** argv) Timer::Start("qdafn_construct"); Log::Info << "Building QDAFN model..." << endl; m.type = 1; - m.model = boost::any(QDAFN<>(referenceSet, numTables, numProjections)); + m.qdafn = QDAFN<>(referenceSet, numTables, numProjections); Timer::Stop("qdafn_construct"); } } @@ -179,52 +183,41 @@ int main(int argc, char** argv) { // We must load the model from file. const string inputModelFile = CLI::GetParam("input_model_file"); - data::Load(inputModelFile, m); + data::Load(inputModelFile, "approx_kfn", m); } // Now, do we need to do any queries? if (CLI::HasParam("k")) { + arma::mat querySet; // This may or may not be used. const size_t k = (size_t) CLI::GetParam("k"); arma::Mat neighbors; arma::mat distances; + arma::mat& set = CLI::HasParam("query_file") ? querySet : referenceSet; if (CLI::HasParam("query_file")) { const string queryFile = CLI::GetParam("query_file"); - arma::mat querySet; - data::Load(querySet, queryFile); + data::Load(queryFile, querySet); + } - if (m.type == 0) - { - Timer::Start("drusilla_select_search"); - boost::any_cast>(m.model).Search(querySet, k, - neighbors, distances); - Timer::Stop("drusilla_select_search"); - } - else - { - Timer::Start("qdafn_search"); - boost::any_cast>(m.model).Search(querySet, k, neighbors, - distances); - Timer::Stop("qdafn_search"); - } + if (m.type == 0) + { + Timer::Start("drusilla_select_search"); + m.ds.Search(set, k, neighbors, distances); + Timer::Stop("drusilla_select_search"); } else { - // We will do search with the reference set. 
- if (m.type == 0) - boost::any_cast>(m.model).Search(k, neighbors, - distances); - else - boost::any_cast>(m.model).Search(k, neighbors, distances); + Timer::Start("qdafn_search"); + m.qdafn.Search(set, k, neighbors, distances); + Timer::Stop("qdafn_search"); } // Should we calculate error? if (CLI::HasParam("calculate_error")) { - arma::mat& set = CLI::HasParam("query_file") ? querySet : referenceSet; arma::mat exactDistances; if (CLI::HasParam("exact_distances_file")) { @@ -239,10 +232,10 @@ int main(int argc, char** argv) arma::Mat exactNeighbors; kfn.Search(set, k, exactNeighbors, exactDistances); - const double averageError = arma::sum(trueDistances / distances.row(0)) + const double averageError = arma::sum(exactDistances / distances.row(0)) / distances.n_cols; - const double minError = arma::min(trueDistances / distances.row(0)); - const double maxError = arma::max(trueDistances / distances.row(0)); + const double minError = arma::min(exactDistances / distances.row(0)); + const double maxError = arma::max(exactDistances / distances.row(0)); Log::Info << "Average error: " << averageError << "." << endl; Log::Info << "Maximum error: " << maxError << "." << endl; @@ -259,5 +252,5 @@ int main(int argc, char** argv) // Should we save the model? 
if (CLI::HasParam("output_model_file")) - data::Save(CLI::GetParam("output_model_file"), m); + data::Save(CLI::GetParam("output_model_file"), "approx_kfn", m); } diff --git a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp index f264e64a779..95953745ea5 100644 --- a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +++ b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp @@ -104,7 +104,7 @@ void DrusillaSelect::Train( const double distortion = arma::norm(refCopy.col(j) - offset * line); sums[j] = std::abs(offset) - std::abs(distortion); closeAngle[j] = - (std::atan(distortion / std::abs(offset)) >= (M_PI / 8.0)); + (std::atan(distortion / std::abs(offset)) < (M_PI / 8.0)); } else { diff --git a/src/mlpack/methods/approx_kfn/qdafn.hpp b/src/mlpack/methods/approx_kfn/qdafn.hpp index 7617fc2b653..ad9e2062559 100644 --- a/src/mlpack/methods/approx_kfn/qdafn.hpp +++ b/src/mlpack/methods/approx_kfn/qdafn.hpp @@ -28,6 +28,15 @@ template class QDAFN { public: + /** + * Construct the QDAFN object but do not train it. Be sure to call Train() + * before calling Search(). + * + * @param l Number of projections. + * @param m Number of elements to store for each projection. + */ + QDAFN(const size_t l, const size_t m); + /** * Construct the QDAFN object with the given reference set (this is the set * that will be searched). @@ -57,9 +66,9 @@ class QDAFN private: //! The number of projections. - const size_t l; + size_t l; //! The number of elements to store for each projection. - const size_t m; + size_t m; //! The random lines we are projecting onto. Has l columns. arma::mat lines; //! Projections of each point onto each random line. 
diff --git a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp index 9220989d82b..85ec99a6a16 100644 --- a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp +++ b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp @@ -16,6 +16,10 @@ namespace mlpack { namespace neighbor { +// Non-training constructor. +template +QDAFN::QDAFN(const size_t l, const size_t m) : l(l), m(m) { } + // Constructor. template QDAFN::QDAFN(const MatType& referenceSet, From 11326ef212e0982f611a40f6a1261651f654db8b Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Tue, 25 Oct 2016 04:44:53 -0400 Subject: [PATCH 23/30] Fix output and documentation. Also ensure that error is computed when a file is supplied. --- .../methods/approx_kfn/approx_kfn_main.cpp | 30 ++++++++++++------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp b/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp index b5f2ac2406d..97de7c5a84d 100644 --- a/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp +++ b/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp @@ -68,8 +68,8 @@ PARAM_STRING_IN("neighbors_file", "File to save furthest neighbor indices to.", PARAM_STRING_IN("distances_file", "File to save furthest neighbor distances to.", "d", ""); -PARAM_FLAG("calculate_error", "If set, calculate the average distance error.", - "e"); +PARAM_FLAG("calculate_error", "If set, calculate the average distance error for" + " the first furthest neighbor only.", "e"); PARAM_STRING_IN("exact_distances_file", "File containing exact distances to " "furthest neighbors; this can be used to avoid explicit calculation when " "--calculate_error is set.", "x", ""); @@ -178,6 +178,7 @@ int main(int argc, char** argv) m.qdafn = QDAFN<>(referenceSet, numTables, numProjections); Timer::Stop("qdafn_construct"); } + Log::Info << "Model built." 
<< endl; } else { @@ -205,15 +206,20 @@ int main(int argc, char** argv) if (m.type == 0) { Timer::Start("drusilla_select_search"); + Log::Info << "Searching for " << k << " furthest neighbors with " + << "DrusillaSelect..." << endl; m.ds.Search(set, k, neighbors, distances); Timer::Stop("drusilla_select_search"); } else { Timer::Start("qdafn_search"); + Log::Info << "Searching for " << k << " furthest neighbors with " + << "QDAFN..." << endl; m.qdafn.Search(set, k, neighbors, distances); Timer::Stop("qdafn_search"); } + Log::Info << "Search complete." << endl; // Should we calculate error? if (CLI::HasParam("calculate_error")) @@ -228,19 +234,21 @@ int main(int argc, char** argv) { // Calculate exact distances. We are guaranteed the reference set is // available. + Log::Info << "Calculating exact distances..." << endl; AllkFN kfn(referenceSet); arma::Mat exactNeighbors; - kfn.Search(set, k, exactNeighbors, exactDistances); + kfn.Search(set, 1, exactNeighbors, exactDistances); + Log::Info << "Calculation complete." << endl; + } - const double averageError = arma::sum(exactDistances / distances.row(0)) - / distances.n_cols; - const double minError = arma::min(exactDistances / distances.row(0)); - const double maxError = arma::max(exactDistances / distances.row(0)); + const double averageError = arma::sum(exactDistances / distances.row(0)) / + distances.n_cols; + const double minError = arma::min(exactDistances / distances.row(0)); + const double maxError = arma::max(exactDistances / distances.row(0)); - Log::Info << "Average error: " << averageError << "." << endl; - Log::Info << "Maximum error: " << maxError << "." << endl; - Log::Info << "Minimum error: " << minError << "." << endl; - } + Log::Info << "Average error: " << averageError << "." << endl; + Log::Info << "Maximum error: " << maxError << "." << endl; + Log::Info << "Minimum error: " << minError << "." << endl; } // Save results, if desired. 
From c3d38252d944d0902f91c98cd589addc352abac7 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Sun, 30 Oct 2016 20:50:24 +0900 Subject: [PATCH 24/30] Handle situations where the user passes in a distances matrix not a distance column. --- src/mlpack/methods/approx_kfn/approx_kfn_main.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp b/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp index 97de7c5a84d..794385039f5 100644 --- a/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp +++ b/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp @@ -241,10 +241,12 @@ int main(int argc, char** argv) Log::Info << "Calculation complete." << endl; } - const double averageError = arma::sum(exactDistances / distances.row(0)) / - distances.n_cols; - const double minError = arma::min(exactDistances / distances.row(0)); - const double maxError = arma::max(exactDistances / distances.row(0)); + const double averageError = arma::sum(exactDistances.row(0) / + distances.row(0)) / distances.n_cols; + const double minError = arma::min(exactDistances.row(0) / + distances.row(0)); + const double maxError = arma::max(exactDistances.row(0) / + distances.row(0)); Log::Info << "Average error: " << averageError << "." << endl; Log::Info << "Maximum error: " << maxError << "." << endl; From 4b35ecc9e8490fcb2fa499c685494b6604789ce4 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Sun, 30 Oct 2016 20:50:40 +0900 Subject: [PATCH 25/30] Refactor QDAFN to better handle sparse data matrices. 
--- src/mlpack/methods/approx_kfn/qdafn.hpp | 8 +++++++- src/mlpack/methods/approx_kfn/qdafn_impl.hpp | 7 ++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/mlpack/methods/approx_kfn/qdafn.hpp b/src/mlpack/methods/approx_kfn/qdafn.hpp index ad9e2062559..f7949db6e7d 100644 --- a/src/mlpack/methods/approx_kfn/qdafn.hpp +++ b/src/mlpack/methods/approx_kfn/qdafn.hpp @@ -64,6 +64,11 @@ class QDAFN template void Serialize(Archive& ar, const unsigned int /* version */); + //! Get the candidate set for the given projection table. + const MatType& CandidateSet(const size_t t) const { return candidateSet[t]; } + //! Modify the candidate set for the given projection table. Careful! + MatType& CandidateSet(const size_t t) { return candidateSet[t]; } + private: //! The number of projections. size_t l; @@ -79,7 +84,8 @@ class QDAFN //! Values of a_i * x for each point in S. arma::mat sValues; - arma::cube candidateSet; + // Candidate sets; one element in the vector for each table. + std::vector candidateSet; }; } // namespace neighbor diff --git a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp index 85ec99a6a16..de6c882108f 100644 --- a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp +++ b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp @@ -43,9 +43,10 @@ QDAFN::QDAFN(const MatType& referenceSet, // Loop over each projection and find the top m elements. sIndices.set_size(m, l); sValues.set_size(m, l); - candidateSet.set_size(referenceSet.n_rows, m, l); + candidateSet.resize(l); for (size_t i = 0; i < l; ++i) { + candidateSet[i].set_size(referenceSet.n_rows, m); arma::uvec sortedIndices = arma::sort_index(projections.col(i), "descend"); // Grab the top m elements. 
@@ -53,7 +54,7 @@ QDAFN::QDAFN(const MatType& referenceSet, { sIndices(j, i) = sortedIndices[j]; sValues(j, i) = projections(sortedIndices[j], i); - candidateSet.slice(i).col(j) = referenceSet.col(sortedIndices[j]); + candidateSet[i].col(j) = referenceSet.col(sortedIndices[j]); } } } @@ -106,7 +107,7 @@ void QDAFN::Search(const MatType& querySet, // Calculate distance from query point. const double dist = mlpack::metric::EuclideanDistance::Evaluate( - querySet.col(q), candidateSet.slice(p.second).col(tableIndex)); + querySet.col(q), candidateSet[p.second].col(tableIndex)); // Is this neighbor good enough to insert into the results? if (dist > resultsQueue.top().first) From 5ca6936f03af839e3bf58752c1b7125fc6e20960 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Sun, 30 Oct 2016 20:50:52 +0900 Subject: [PATCH 26/30] Add approximate furthest neighbor search tutorial. --- doc/tutorials/approx_kfn/approx_kfn.txt | 1025 +++++++++++++++++++++++ 1 file changed, 1025 insertions(+) create mode 100644 doc/tutorials/approx_kfn/approx_kfn.txt diff --git a/doc/tutorials/approx_kfn/approx_kfn.txt b/doc/tutorials/approx_kfn/approx_kfn.txt new file mode 100644 index 00000000000..aa477a0d097 --- /dev/null +++ b/doc/tutorials/approx_kfn/approx_kfn.txt @@ -0,0 +1,1025 @@ +/*! + +@file approx_kfn.txt +@author Ryan Curtin +@brief Tutorial for how to use approximate furthest neighbor search in mlpack. + +@page akfntutorial Approximate furthest neighbor search (mlpack_approx_kfn) tutorial + +@section intro_akfntut Introduction + +\b mlpack implements multiple strategies for approximate furthest neighbor +search in its \c mlpack_approx_kfn and \c mlpack_kfn programs (each program +corresponds to different techniques). This tutorial discusses what problems +these algorithms solve and how to use each of the techniques that \b mlpack +implements. 
+ +\b mlpack implements five approximate furthest neighbor search algorithms: + + - brute-force search (in \c mlpack_kfn) + - single-tree search (in \c mlpack_kfn) + - dual-tree search (in \c mlpack_kfn) + - query-dependent approximate furthest neighbor (QDAFN) (in \c mlpack_approx_kfn) + - DrusillaSelect (in \c mlpack_approx_kfn) + +These methods are described in the following papers: + +@code +@inproceedings{curtin2013tree, + title={Tree-Independent Dual-Tree Algorithms}, + author={Curtin, Ryan R. and March, William B. and Ram, Parikshit and Anderson, + David V. and Gray, Alexander G. and Isbell Jr., Charles L.}, + booktitle={Proceedings of The 30th International Conference on Machine + Learning (ICML '13)}, + pages={1435--1443}, + year={2013} +} +@endcode + +@code +@incollection{pagh2015approximate, + title={Approximate furthest neighbor in high dimensions}, + author={Pagh, Rasmus and Silvestri, Francesco and Sivertsen, Johan and Skala, + Matthew}, + booktitle={Similarity Search and Applications}, + pages={3--14}, + year={2015}, + publisher={Springer} +} +@endcode + +@code +@incollection{curtin2016fast, + title={Fast approximate furthest neighbors with data-dependent candidate + selection}, + author={Curtin, Ryan R. and Gardner, Andrew B.}, + booktitle={Similarity Search and Applications}, + pages={221--235}, + year={2016}, + publisher={Springer} +} +@endcode + +The problem of furthest neighbor search is simple, and is the opposite of the +much-more-studied nearest neighbor search problem. Given a set of reference +points \f$R\f$ (the set in which we are searching), and a set of query points +\f$Q\f$ (the set of points for which we want the furthest neighbor), our goal is +to return the \f$k\f$ furthest neighbors for each query point in \f$Q\f$: + +\f[ +\operatorname{k-argmax}_{p_r \in R} d(p_q, p_r). +\f] + +In order to solve this problem, \b mlpack provides a number of interfaces.
+ + - two \ref cli_akfntut "simple command-line executables" to calculate + approximate furthest neighbors + - a simple \ref cpp_qdafn_akfntut "C++ class for QDAFN" + - a simple \ref cpp_ds_akfntut "C++ class for DrusillaSelect" + - a simple \ref cpp_ns_akfntut "C++ class for tree-based and brute-force" + search + +@section toc_akfntut Table of Contents + +A list of all the sections this tutorial contains. + + - \ref intro_akfntut + - \ref toc_akfntut + - \ref which_akfntut + - \ref cli_akfntut + - \ref cli_ex1_akfntut + - \ref cli_ex2_akfntut + - \ref cli_ex3_akfntut + - \ref cli_ex4_akfntut + - \ref cli_ex5_akfntut + - \ref cli_ex6_akfntut + - \ref cli_ex7_akfntut + - \ref cli_ex8_akfntut + - \ref cli_final_akfntut + - \ref cpp_ds_akfntut + - \ref cpp_ex1_ds_akfntut + - \ref cpp_ex2_ds_akfntut + - \ref cpp_ex3_ds_akfntut + - \ref cpp_ex4_ds_akfntut + - \ref cpp_ex5_ds_akfntut + - \ref cpp_qdafn_akfntut + - \ref cpp_ex1_qdafn_akfntut + - \ref cpp_ex2_qdafn_akfntut + - \ref cpp_ex3_qdafn_akfntut + - \ref cpp_ex4_qdafn_akfntut + - \ref cpp_ex5_qdafn_akfntut + - \ref cpp_ns_akfntut + - \ref cpp_ex1_ns_akfntut + - \ref cpp_ex2_ns_akfntut + - \ref cpp_ex3_ns_akfntut + - \ref cpp_ex4_ns_akfntut + - \ref further_doc_akfntut + +@section which_akfntut Which algorithm should be used? + +There are three families of algorithms for furthest neighbor search that \b mlpack +implements, and each is suited to a different setting. Below is some basic +guidance on what should be used. Note that the question of "which algorithm +should be used" is a very difficult question to answer, so the guidance below is +just that---guidance---and may not be right for a particular problem.
+ + - \c DrusillaSelect is very fast and will perform extremely well for datasets + with outliers or datasets with structure (like low-dimensional datasets + embedded in high dimensions) + - \c QDAFN is a random approach and therefore should be well-suited for + datasets with little to no structure + - The tree-based approaches (the \c KFN class and the \c mlpack_kfn program) are + best suited for low-dimensional datasets, and are most effective when very + small levels of approximation are desired, or when exact results are desired. + - Dual-tree search is most useful when the query set is large and structured + (like for all-furthest-neighbor search). + - Single-tree search is more useful when the query set is small. + +@section cli_akfntut Command-line 'mlpack_approx_kfn' and 'mlpack_kfn' + +\b mlpack provides two command-line programs to solve approximate furthest +neighbor search: + + - \c mlpack_approx_kfn, for the QDAFN and DrusillaSelect approaches + - \c mlpack_kfn, for exact and approximate tree-based approaches + +These two programs allow a large number of algorithms to be used to find +approximate furthest neighbors. Note that the \c mlpack_kfn program is also +documented by the \ref cli_nstut section of the \ref nstutorial page, as it +shares options with the \c mlpack_knn program. + +Below are several examples of how the \c mlpack_approx_kfn and \c mlpack_kfn +programs might be used. The first examples focus on the \c mlpack_approx_kfn +program, and the last few show how \c mlpack_kfn can be used to produce +approximate results. + +@subsection cli_ex1_akfntut Calculate 5 furthest neighbors with default options + +Here we have a query dataset \c queries.csv and a reference dataset \c refs.csv +and we wish to find the 5 furthest neighbors of every query point in the +reference dataset. We may do that with the \c mlpack_approx_kfn program, +using the default of the \c DrusillaSelect algorithm with default parameters.
+ +@code +$ mlpack_approx_kfn -q queries.csv -r refs.csv -v -k 5 -n n.csv -d d.csv +[INFO ] Loading 'refs.csv' as CSV data. Size is 3 x 1000. +[INFO ] Building DrusillaSelect model... +[INFO ] Model built. +[INFO ] Loading 'queries.csv' as CSV data. Size is 3 x 1000. +[INFO ] Searching for 5 furthest neighbors with DrusillaSelect... +[INFO ] Search complete. +[INFO ] Saving CSV data to 'n.csv'. +[INFO ] Saving CSV data to 'd.csv'. +[INFO ] +[INFO ] Execution parameters: +[INFO ] algorithm: ds +[INFO ] calculate_error: false +[INFO ] distances_file: d.csv +[INFO ] exact_distances_file: "" +[INFO ] help: false +[INFO ] info: "" +[INFO ] input_model_file: "" +[INFO ] k: 5 +[INFO ] neighbors_file: n.csv +[INFO ] num_projections: 5 +[INFO ] num_tables: 5 +[INFO ] output_model_file: "" +[INFO ] query_file: queries.csv +[INFO ] reference_file: refs.csv +[INFO ] verbose: true +[INFO ] version: false +[INFO ] +[INFO ] Program timers: +[INFO ] drusilla_select_construct: 0.000342s +[INFO ] drusilla_select_search: 0.000780s +[INFO ] loading_data: 0.010689s +[INFO ] saving_data: 0.005585s +[INFO ] total_time: 0.018592s +@endcode + +Convenient timers for parts of the program operation are printed. The results, +saved in \c n.csv and \c d.csv, indicate the furthest neighbors and distances +for each query point. The row of the output file indicates the query point that +the results are for. The neighbors are listed from furthest to nearest; so, the +4th element in the 3rd row of \c d.csv indicates the distance between the 3rd +query point in \c queries.csv and its approximate 4th furthest neighbor. +Similarly, the same element in \c n.csv indicates the index of the approximate +4th furthest neighbor (with respect to \c refs.csv). + +@subsection cli_ex2_akfntut Specifying algorithm parameters for DrusillaSelect + +The \c -p (\c --num_projections) and \c -t (\c --num_tables) parameters affect +the running of the \c DrusillaSelect algorithm and the QDAFN algorithm. 
+Specifically, larger values for each of these parameters will search more +possible candidate furthest neighbors and produce better results (at the cost of +runtime). More details on how each of these parameters works are available in +the original papers, the \b mlpack source, or the documentation given by +\c --help. + +In the example below, we run \c DrusillaSelect to find 4 furthest neighbors +using 10 tables and 2 points in each table. As in the previous example, the +candidate neighbor indices are written to \c n.csv and the candidate distances +are written to \c d.csv. + +@code +$ mlpack_approx_kfn -q queries.csv -r refs.csv -v -k 4 -n n.csv -d d.csv -t 10 -p 2 +[INFO ] Loading 'refs.csv' as CSV data. Size is 3 x 1000. +[INFO ] Building DrusillaSelect model... +[INFO ] Model built. +[INFO ] Loading 'queries.csv' as CSV data. Size is 3 x 1000. +[INFO ] Searching for 4 furthest neighbors with DrusillaSelect... +[INFO ] Search complete. +[INFO ] Saving CSV data to 'n.csv'. +[INFO ] Saving CSV data to 'd.csv'. +[INFO ] +[INFO ] Execution parameters: +[INFO ] algorithm: ds +[INFO ] calculate_error: false +[INFO ] distances_file: d.csv +[INFO ] exact_distances_file: "" +[INFO ] help: false +[INFO ] info: "" +[INFO ] input_model_file: "" +[INFO ] k: 4 +[INFO ] neighbors_file: n.csv +[INFO ] num_projections: 2 +[INFO ] num_tables: 10 +[INFO ] output_model_file: "" +[INFO ] query_file: queries.csv +[INFO ] reference_file: refs.csv +[INFO ] verbose: true +[INFO ] version: false +[INFO ] +[INFO ] Program timers: +[INFO ] drusilla_select_construct: 0.000645s +[INFO ] drusilla_select_search: 0.000551s +[INFO ] loading_data: 0.008518s +[INFO ] saving_data: 0.003734s +[INFO ] total_time: 0.014019s +@endcode + +@subsection cli_ex3_akfntut Using QDAFN instead of DrusillaSelect + +The algorithm to be used for approximate furthest neighbor search can be +specified with the \c --algorithm (\c -a) option to the \c mlpack_approx_kfn +program.
Below, we use the QDAFN algorithm instead of the default. We leave +the \c -p and \c -t options at their defaults---even though QDAFN often requires +more tables and points to get the same quality of results. + +@code +$ mlpack_approx_kfn -q queries.csv -r refs.csv -v -k 3 -n n.csv -d d.csv -a qdafn +[INFO ] Loading 'refs.csv' as CSV data. Size is 3 x 1000. +[INFO ] Building QDAFN model... +[INFO ] Model built. +[INFO ] Loading 'queries.csv' as CSV data. Size is 3 x 1000. +[INFO ] Searching for 3 furthest neighbors with QDAFN... +[INFO ] Search complete. +[INFO ] Saving CSV data to 'n.csv'. +[INFO ] Saving CSV data to 'd.csv'. +[INFO ] +[INFO ] Execution parameters: +[INFO ] algorithm: qdafn +[INFO ] calculate_error: false +[INFO ] distances_file: d.csv +[INFO ] exact_distances_file: "" +[INFO ] help: false +[INFO ] info: "" +[INFO ] input_model_file: "" +[INFO ] k: 3 +[INFO ] neighbors_file: n.csv +[INFO ] num_projections: 5 +[INFO ] num_tables: 5 +[INFO ] output_model_file: "" +[INFO ] query_file: queries.csv +[INFO ] reference_file: refs.csv +[INFO ] verbose: true +[INFO ] version: false +[INFO ] +[INFO ] Program timers: +[INFO ] loading_data: 0.008380s +[INFO ] qdafn_construct: 0.003399s +[INFO ] qdafn_search: 0.000886s +[INFO ] saving_data: 0.002253s +[INFO ] total_time: 0.015465s +@endcode + +@subsection cli_ex4_akfntut Printing results quality with exact distances + +The \c mlpack_approx_kfn program can calculate the quality of the results if the +\c --calculate_error (\c -e) flag is specified. Below we use the program with +its default parameters and calculate the error, which is displayed in the +output. The error is only calculated for the furthest neighbor, not all k; +therefore, in this example we have set \c -k to \c 1. + +@code +$ mlpack_approx_kfn -q queries.csv -r refs.csv -v -k 1 -e +[INFO ] Loading 'refs.csv' as CSV data. Size is 3 x 1000. +[INFO ] Building DrusillaSelect model... +[INFO ] Model built.
+[INFO ] Loading 'queries.csv' as CSV data. Size is 3 x 1000. +[INFO ] Searching for 1 furthest neighbors with DrusillaSelect... +[INFO ] Search complete. +[INFO ] Calculating exact distances... +[INFO ] 28891 node combinations were scored. +[INFO ] 37735 base cases were calculated. +[INFO ] Calculation complete. +[INFO ] Average error: 1.08417. +[INFO ] Maximum error: 1.28712. +[INFO ] Minimum error: 1. +[INFO ] +[INFO ] Execution parameters: +[INFO ] algorithm: ds +[INFO ] calculate_error: true +[INFO ] distances_file: "" +[INFO ] exact_distances_file: "" +[INFO ] help: false +[INFO ] info: "" +[INFO ] input_model_file: "" +[INFO ] k: 1 +[INFO ] neighbors_file: "" +[INFO ] num_projections: 5 +[INFO ] num_tables: 5 +[INFO ] output_model_file: "" +[INFO ] query_file: queries.csv +[INFO ] reference_file: refs.csv +[INFO ] verbose: true +[INFO ] version: false +[INFO ] +[INFO ] Program timers: +[INFO ] computing_neighbors: 0.001476s +[INFO ] drusilla_select_construct: 0.000309s +[INFO ] drusilla_select_search: 0.000495s +[INFO ] loading_data: 0.008462s +[INFO ] total_time: 0.011670s +[INFO ] tree_building: 0.000202s +@endcode + +Note that the output includes three lines indicating the error: + +@code +[INFO ] Average error: 1.08417. +[INFO ] Maximum error: 1.28712. +[INFO ] Minimum error: 1. +@endcode + +In this case, a minimum error of 1 indicates an exact result, and over the +entire query set the algorithm has returned a furthest neighbor candidate with +maximum error 1.28712. + +@subsection cli_ex5_akfntut Using cached exact distances for quality results + +However, for large datasets, calculating the error may take a long time, because +the exact furthest neighbors must be calculated. Therefore, if the exact +furthest neighbor distances are already known, they may be passed in with the +\c --exact_distances_file (\c -x) option in order to avoid the calculation. In +the example below, we assume \c exact.csv contains the exact furthest neighbor +distances.
We run the \c qdafn algorithm in this example. + +Note that the \c -e option must be specified for the \c -x option to have any +effect. + +@code +$ mlpack_approx_kfn -q queries.csv -r refs.csv -k 1 -e -x exact.csv -n n.csv -v -a qdafn +[INFO ] Loading 'refs.csv' as CSV data. Size is 3 x 1000. +[INFO ] Building QDAFN model... +[INFO ] Model built. +[INFO ] Loading 'queries.csv' as CSV data. Size is 3 x 1000. +[INFO ] Searching for 1 furthest neighbors with QDAFN... +[INFO ] Search complete. +[INFO ] Loading 'exact.csv' as raw ASCII formatted data. Size is 1 x 1000. +[INFO ] Average error: 1.06914. +[INFO ] Maximum error: 1.67407. +[INFO ] Minimum error: 1. +[INFO ] Saving CSV data to 'n.csv'. +[INFO ] +[INFO ] Execution parameters: +[INFO ] algorithm: qdafn +[INFO ] calculate_error: true +[INFO ] distances_file: "" +[INFO ] exact_distances_file: exact.csv +[INFO ] help: false +[INFO ] info: "" +[INFO ] input_model_file: "" +[INFO ] k: 1 +[INFO ] neighbors_file: n.csv +[INFO ] num_projections: 5 +[INFO ] num_tables: 5 +[INFO ] output_model_file: "" +[INFO ] query_file: queries.csv +[INFO ] reference_file: refs.csv +[INFO ] verbose: true +[INFO ] version: false +[INFO ] +[INFO ] Program timers: +[INFO ] loading_data: 0.010348s +[INFO ] qdafn_construct: 0.000318s +[INFO ] qdafn_search: 0.000793s +[INFO ] saving_data: 0.000259s +[INFO ] total_time: 0.012254s +@endcode + +@subsection cli_ex6_akfntut Using tree-based approximation with mlpack_kfn + +The \c mlpack_kfn program allows specifying a desired approximation level with +the \c --epsilon (\c -e) option. The parameter must be greater than or equal +to 0 and less than 1. A setting of 0 indicates exact search. + +The example below runs dual-tree furthest neighbor search (the default +algorithm) with the approximation parameter set to 0.5. + +@code +$ mlpack_kfn -q queries.csv -r refs.csv -v -k 3 -e 0.5 -n n.csv -d d.csv +[INFO ] Loading 'refs.csv' as CSV data. Size is 3 x 1000.
+[INFO ] Loaded reference data from 'refs.csv' (3x1000). +[INFO ] Building reference tree... +[INFO ] Tree built. +[INFO ] Loading 'queries.csv' as CSV data. Size is 3 x 1000. +[INFO ] Loaded query data from 'queries.csv' (3x1000). +[INFO ] Searching for 3 neighbors with dual-tree kd-tree search... +[INFO ] 1611 node combinations were scored. +[INFO ] 13938 base cases were calculated. +[INFO ] 1611 node combinations were scored. +[INFO ] 13938 base cases were calculated. +[INFO ] Search complete. +[INFO ] Saving CSV data to 'n.csv'. +[INFO ] Saving CSV data to 'd.csv'. +[INFO ] +[INFO ] Execution parameters: +[INFO ] algorithm: dual_tree +[INFO ] distances_file: d.csv +[INFO ] epsilon: 0.5 +[INFO ] help: false +[INFO ] info: "" +[INFO ] input_model_file: "" +[INFO ] k: 3 +[INFO ] leaf_size: 20 +[INFO ] naive: false +[INFO ] neighbors_file: n.csv +[INFO ] output_model_file: "" +[INFO ] percentage: 1 +[INFO ] query_file: queries.csv +[INFO ] random_basis: false +[INFO ] reference_file: refs.csv +[INFO ] seed: 0 +[INFO ] single_mode: false +[INFO ] tree_type: kd +[INFO ] true_distances_file: "" +[INFO ] true_neighbors_file: "" +[INFO ] verbose: true +[INFO ] version: false +[INFO ] +[INFO ] Program timers: +[INFO ] computing_neighbors: 0.000442s +[INFO ] loading_data: 0.008060s +[INFO ] saving_data: 0.002850s +[INFO ] total_time: 0.012667s +[INFO ] tree_building: 0.000251s +@endcode + +Note that the format of the output files \c d.csv and \c n.csv is the same as +for \c mlpack_approx_kfn. + +@subsection cli_ex7_akfntut Different algorithms with 'mlpack_kfn' + +The \c mlpack_kfn program offers a large number of different algorithms that can +be used. The \c --algorithm (\c -a) option may be used to specify four main +algorithm types: \c naive (brute-force search), \c single_tree (single-tree +search), \c dual_tree (dual-tree search, the default), and \c greedy +("defeatist" greedy search, which goes to one leaf node of the tree then +terminates).
The example below uses single-tree search to find approximate +neighbors with epsilon set to 0.1. + +@code +mlpack_kfn -q queries.csv -r refs.csv -v -k 3 -e 0.1 -n n.csv -d d.csv -a single_tree +[INFO ] Loading 'refs.csv' as CSV data. Size is 3 x 1000. +[INFO ] Loaded reference data from 'refs.csv' (3x1000). +[INFO ] Building reference tree... +[INFO ] Tree built. +[INFO ] Loading 'queries.csv' as CSV data. Size is 3 x 1000. +[INFO ] Loaded query data from 'queries.csv' (3x1000). +[INFO ] Searching for 3 neighbors with single-tree kd-tree search... +[INFO ] 13240 node combinations were scored. +[INFO ] 15924 base cases were calculated. +[INFO ] Search complete. +[INFO ] Saving CSV data to 'n.csv'. +[INFO ] Saving CSV data to 'd.csv'. +[INFO ] +[INFO ] Execution parameters: +[INFO ] algorithm: single_tree +[INFO ] distances_file: d.csv +[INFO ] epsilon: 0.1 +[INFO ] help: false +[INFO ] info: "" +[INFO ] input_model_file: "" +[INFO ] k: 3 +[INFO ] leaf_size: 20 +[INFO ] naive: false +[INFO ] neighbors_file: n.csv +[INFO ] output_model_file: "" +[INFO ] percentage: 1 +[INFO ] query_file: queries.csv +[INFO ] random_basis: false +[INFO ] reference_file: refs.csv +[INFO ] seed: 0 +[INFO ] single_mode: false +[INFO ] tree_type: kd +[INFO ] true_distances_file: "" +[INFO ] true_neighbors_file: "" +[INFO ] verbose: true +[INFO ] version: false +[INFO ] +[INFO ] Program timers: +[INFO ] computing_neighbors: 0.000850s +[INFO ] loading_data: 0.007858s +[INFO ] saving_data: 0.003445s +[INFO ] total_time: 0.013084s +[INFO ] tree_building: 0.000250s +@endcode + +@subsection cli_ex8_akfntut Saving a model for later use + +The \c mlpack_approx_kfn and \c mlpack_kfn programs both allow models to be +saved and loaded for future use. The \c --output_model_file (\c -M) option +allows specifying where to save a model, and the \c --input_model_file (\c -m) +option allows a model to be loaded instead of trained. 
So, if you specify +\c --input_model_file then you do not need to specify \c --reference_file +(\c -r), \c --num_projections (\c -p), or \c --num_tables (\c -t). + +The example below saves a model with 10 projections and 5 tables. Note that +neither \c --query_file (\c -q) nor \c -k are specified; this run only builds +the model and saves it to \c model.bin. + +@code +$ mlpack_approx_kfn -r refs.csv -t 5 -p 10 -v -M model.bin +[INFO ] Loading 'refs.csv' as CSV data. Size is 3 x 1000. +[INFO ] Building DrusillaSelect model... +[INFO ] Model built. +[INFO ] +[INFO ] Execution parameters: +[INFO ] algorithm: ds +[INFO ] calculate_error: false +[INFO ] distances_file: "" +[INFO ] exact_distances_file: "" +[INFO ] help: false +[INFO ] info: "" +[INFO ] input_model_file: "" +[INFO ] k: 0 +[INFO ] neighbors_file: "" +[INFO ] num_projections: 10 +[INFO ] num_tables: 5 +[INFO ] output_model_file: model.bin +[INFO ] query_file: "" +[INFO ] reference_file: refs.csv +[INFO ] verbose: true +[INFO ] version: false +[INFO ] +[INFO ] Program timers: +[INFO ] drusilla_select_construct: 0.000321s +[INFO ] loading_data: 0.004700s +[INFO ] total_time: 0.007320s +@endcode + +Now, with the model saved, we can run approximate furthest neighbor search on a +query set using the saved model: + +@code +$ mlpack_approx_kfn -m model.bin -q queries.csv -k 3 -d d.csv -n n.csv -v +[INFO ] Loading 'queries.csv' as CSV data. Size is 3 x 1000. +[INFO ] Searching for 3 furthest neighbors with DrusillaSelect... +[INFO ] Search complete. +[INFO ] Saving CSV data to 'n.csv'. +[INFO ] Saving CSV data to 'd.csv'. 
+[INFO ] +[INFO ] Execution parameters: +[INFO ] algorithm: ds +[INFO ] calculate_error: false +[INFO ] distances_file: d.csv +[INFO ] exact_distances_file: "" +[INFO ] help: false +[INFO ] info: "" +[INFO ] input_model_file: model.bin +[INFO ] k: 3 +[INFO ] neighbors_file: n.csv +[INFO ] num_projections: 5 +[INFO ] num_tables: 5 +[INFO ] output_model_file: "" +[INFO ] query_file: queries.csv +[INFO ] reference_file: "" +[INFO ] verbose: true +[INFO ] version: false +[INFO ] +[INFO ] Program timers: +[INFO ] drusilla_select_search: 0.000878s +[INFO ] loading_data: 0.004599s +[INFO ] saving_data: 0.003006s +[INFO ] total_time: 0.009234s +@endcode + +These options work in the same way for both the \c mlpack_approx_kfn and +\c mlpack_kfn programs. + +@subsection cli_final_akfntut Final command-line program notes + +Both the \c mlpack_kfn and \c mlpack_approx_kfn programs contain numerous +options not fully documented in these short examples. You can run each program +with the \c --help (\c -h) option for more information. + +@section cpp_ds_akfntut DrusillaSelect C++ class + +\b mlpack provides a simple \c DrusillaSelect C++ class that can be used inside +of C++ programs to perform approximate furthest neighbor search. The class has +only one template parameter---\c MatType---which specifies the type of matrix to +be used. That means the class can be used with either dense data (of type +\c arma::mat) or sparse data (of type \c arma::sp_mat). + +The following examples show simple usage of this class. + +@subsection cpp_ex1_ds_akfntut Approximate furthest neighbors with defaults + +The code below builds a \c DrusillaSelect model with default options on the +matrix \c dataset, then queries for the approximate furthest neighbor of every +point in the \c queries matrix. + +@code +#include <mlpack/methods/approx_kfn/drusilla_select.hpp> + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::mat dataset; +// The query set. +extern arma::mat queries; + +// Construct the model with defaults.
+DrusillaSelect<> ds(dataset); + +// Query the model, putting output into the following two matrices. +arma::mat distances; +arma::Mat<size_t> neighbors; +ds.Search(queries, 1, neighbors, distances); +@endcode + +At the end of this code, both the \c distances and \c neighbors matrices will +have a number of columns equal to the number of columns in the \c queries matrix. +So, each column of the \c distances and \c neighbors matrices holds the distances +or neighbors of the corresponding column in the \c queries matrix. + +@subsection cpp_ex2_ds_akfntut Custom numbers of tables and projections + +The following example constructs a DrusillaSelect model with 10 tables and 5 +projections. Once that is done it performs the same task as the previous +example. + +@code +#include <mlpack/methods/approx_kfn/drusilla_select.hpp> + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::mat dataset; +// The query set. +extern arma::mat queries; + +// Construct the model with custom parameters. +DrusillaSelect<> ds(dataset, 10, 5); + +// Query the model, putting output into the following two matrices. +arma::mat distances; +arma::Mat<size_t> neighbors; +ds.Search(queries, 1, neighbors, distances); +@endcode + +@subsection cpp_ex3_ds_akfntut Accessing the candidate set + +The \c DrusillaSelect algorithm merely scans the reference set and extracts a +number of points that will be queried in a brute-force fashion when the +\c Search() method is called. We can access this set with the \c CandidateSet() +method. The code below prints the fifth point of the candidate set. + +@code +#include <mlpack/methods/approx_kfn/drusilla_select.hpp> + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::mat dataset; + +// Construct the model with custom parameters. +DrusillaSelect<> ds(dataset, 10, 5); + +// Print the fifth point of the candidate set.
+std::cout << ds.CandidateSet().col(4).t(); +@endcode + +@subsection cpp_ex4_ds_akfntut Retraining on a new reference set + +It is possible to retrain a \c DrusillaSelect model with new parameters or with +a new reference set. This is functionally equivalent to creating a new model. +The example code below creates a first \c DrusillaSelect model using 3 tables +and 10 projections, and then retrains this with the same reference set using 10 +tables and 3 projections. + +@code +#include <mlpack/methods/approx_kfn/drusilla_select.hpp> + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::mat dataset; + +// Construct the model with initial parameters. +DrusillaSelect<> ds(dataset, 3, 10); + +// Now retrain with different parameters. +ds.Train(dataset, 10, 3); +@endcode + +@subsection cpp_ex5_ds_akfntut Running on sparse data + +We can set the template parameter for \c DrusillaSelect to \c arma::sp_mat in +order to perform furthest neighbor search on sparse data. The code below +creates a \c DrusillaSelect model using 4 tables and 6 projections with sparse +input data, then searches for 3 approximate furthest neighbors. + +@code +#include <mlpack/methods/approx_kfn/drusilla_select.hpp> + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::sp_mat dataset; +// The query dataset. +extern arma::sp_mat querySet; + +// Construct the model on sparse data. +DrusillaSelect<arma::sp_mat> ds(dataset, 4, 6); + +// Search on query data. +arma::Mat<size_t> neighbors; +arma::mat distances; +ds.Search(querySet, 3, neighbors, distances); +@endcode + +@section cpp_qdafn_akfntut QDAFN C++ class + +\b mlpack also provides a standalone simple \c QDAFN class for furthest neighbor +search. The API for this class is virtually identical to the \c DrusillaSelect +class, and also has one template parameter to specify the type of matrix to be +used (dense or sparse or other). + +The following subsections demonstrate usage of the \c QDAFN class in the same +way as the previous section's examples for \c DrusillaSelect.
+ +@subsection cpp_ex1_qdafn_akfntut Approximate furthest neighbors with defaults + +The code below builds a \c QDAFN model with default options on the +matrix \c dataset, then queries for the approximate furthest neighbor of every +point in the \c queries matrix. + +@code +#include <mlpack/methods/approx_kfn/qdafn.hpp> + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::mat dataset; +// The query set. +extern arma::mat queries; + +// Construct the model with defaults. +QDAFN<> qd(dataset); + +// Query the model, putting output into the following two matrices. +arma::mat distances; +arma::Mat<size_t> neighbors; +qd.Search(queries, 1, neighbors, distances); +@endcode + +At the end of this code, both the \c distances and \c neighbors matrices will +have number of columns equal to the number of columns in the \c queries matrix. +So, each column of the \c distances and \c neighbors matrices are the distances +or neighbors of the corresponding column in the \c queries matrix. + +@subsection cpp_ex2_qdafn_akfntut Custom numbers of tables and projections + +The following example constructs a QDAFN model with 15 tables and 30 +projections. Once that is done it performs the same task as the previous +example. + +@code +#include <mlpack/methods/approx_kfn/qdafn.hpp> + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::mat dataset; +// The query set. +extern arma::mat queries; + +// Construct the model with custom parameters. +QDAFN<> qdafn(dataset, 15, 30); + +// Query the model, putting output into the following two matrices. +arma::mat distances; +arma::Mat<size_t> neighbors; +qdafn.Search(queries, 1, neighbors, distances); +@endcode + +@subsection cpp_ex3_qdafn_akfntut Accessing the candidate set + +The \c QDAFN algorithm scans the reference set, extracting points that have been +projected onto random directions. Each random direction corresponds to a single +table. The \c QDAFN class stores these points as a vector of matrices, which +can be accessed with the \c CandidateSet() method.
The code below prints the +fifth point of the candidate set of the third table. + +@code +#include <mlpack/methods/approx_kfn/qdafn.hpp> + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::mat dataset; + +// Construct the model with custom parameters. +QDAFN<> qdafn(dataset, 10, 5); + +// Print the fifth point of the candidate set of the third table. +std::cout << qdafn.CandidateSet(2).col(4).t(); +@endcode + +@subsection cpp_ex4_qdafn_akfntut Retraining on a new reference set + +It is possible to retrain a \c QDAFN model with new parameters or with +a new reference set. This is functionally equivalent to creating a new model. +The example code below creates a first \c QDAFN model using 10 tables +and 40 projections, and then retrains this with the same reference set using 15 +tables and 25 projections. + +@code +#include <mlpack/methods/approx_kfn/qdafn.hpp> + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::mat dataset; + +// Construct the model with initial parameters. +QDAFN<> qdafn(dataset, 10, 40); + +// Now retrain with different parameters. +qdafn.Train(dataset, 15, 25); +@endcode + +@subsection cpp_ex5_qdafn_akfntut Running on sparse data + +We can set the template parameter for \c QDAFN to \c arma::sp_mat in +order to perform furthest neighbor search on sparse data. The code below +creates a \c QDAFN model using 20 tables and 60 projections with sparse +input data, then searches for 3 approximate furthest neighbors. + +@code +#include <mlpack/methods/approx_kfn/qdafn.hpp> + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::sp_mat dataset; +// The query dataset. +extern arma::sp_mat querySet; + +// Construct the model on sparse data. +QDAFN<arma::sp_mat> qdafn(dataset, 20, 60); + +// Search on query data. +arma::Mat<size_t> neighbors; +arma::mat distances; +qdafn.Search(querySet, 3, neighbors, distances); +@endcode + +@section cpp_ns_akfntut KFN C++ class + +The extensive \c NeighborSearch class also provides a way to search for +approximate furthest neighbors using a different, tree-based technique.
For +full documentation on this class, see the +\ref nstutorial "NeighborSearch tutorial". The \c KFN class is a convenient +typedef of the \c NeighborSearch class that can be used to perform the furthest +neighbors task with kd-trees. + +In the following subsections, the \c KFN class is used in short code examples. + +@subsection cpp_ex1_ns_akfntut Simple furthest neighbors example + +The \c KFN class has construction semantics similar to \c DrusillaSelect and +\c QDAFN. The example below constructs a \c KFN object (which will build the +tree on the reference set), but note that the third parameter to the constructor +allows us to specify our desired level of approximation. In this example we +choose epsilon = 0.05. Then, the code searches for 3 approximate furthest +neighbors. + +@code +#include <mlpack/methods/neighbor_search/neighbor_search.hpp> + +using namespace mlpack::neighbor; + +// The reference dataset. +extern arma::mat dataset; +// The query set. +extern arma::mat querySet; + +// Construct the object, performing the default dual-tree search with +// approximation level epsilon = 0.05. +KFN kfn(dataset, DUAL_TREE_MODE, 0.05); + +// Search for approximate furthest neighbors. +arma::Mat<size_t> neighbors; +arma::mat distances; +kfn.Search(querySet, 3, neighbors, distances); +@endcode + +@subsection cpp_ex2_ns_akfntut Retraining on a new reference set + +Like the \c QDAFN and \c DrusillaSelect classes, the \c KFN class is capable of +retraining on a new reference set. The code below demonstrates this. + +@code +#include <mlpack/methods/neighbor_search/neighbor_search.hpp> + +using namespace mlpack::neighbor; + +// The original reference set we train on. +extern arma::mat dataset; +// The new reference set we retrain on. +extern arma::mat newDataset; + +// Construct the object with approximation level 0.1. +KFN kfn(dataset, DUAL_TREE_MODE, 0.1); + +// Retrain on the new reference set. +kfn.Train(newDataset); +@endcode + +@subsection cpp_ex3_ns_akfntut Searching in single-tree mode + +The particular mode to be used in search can be specified in the constructor.
+In this example, we use single-tree search (as opposed to the default of +dual-tree search). + +@code +#include <mlpack/methods/neighbor_search/neighbor_search.hpp> + +using namespace mlpack::neighbor; + +// The reference set. +extern arma::mat dataset; +// The query set. +extern arma::mat querySet; + +// Construct the object with approximation level 0.25 and in single tree search +// mode. +KFN kfn(dataset, SINGLE_TREE_MODE, 0.25); + +// Search for 5 approximate furthest neighbors. +arma::Mat<size_t> neighbors; +arma::mat distances; +kfn.Search(querySet, 5, neighbors, distances); +@endcode + +@subsection cpp_ex4_ns_akfntut Searching in brute-force mode + +If desired, brute-force search ("naive search") can be used to find the furthest +neighbors; however, the result will not be approximate---it will be exact (since +every possibility will be considered). The code below performs exact furthest +neighbor search by using the \c KFN class in brute-force mode. + +@code +#include <mlpack/methods/neighbor_search/neighbor_search.hpp> + +using namespace mlpack::neighbor; + +// The reference set. +extern arma::mat dataset; +// The query set. +extern arma::mat querySet; + +// Construct the object in brute-force mode. We can leave the approximation +// parameter to its default (0) since brute-force will provide exact results. +KFN kfn(dataset, NAIVE_MODE); + +// Perform the search for 2 furthest neighbors.
+arma::Mat<size_t> neighbors; +arma::mat distances; +kfn.Search(querySet, 2, neighbors, distances); +@endcode + +@section further_doc_akfntut Further documentation + +For further documentation on the approximate furthest neighbor facilities +offered by \b mlpack, consult the following documentation: + + - \ref nstutorial + - \ref mlpack::neighbor::QDAFN "QDAFN class documentation" + - \ref mlpack::neighbor::DrusillaSelect "DrusillaSelect class documentation" + - \ref mlpack::neighbor::NeighborSearch "NeighborSearch class documentation" + +*/ From d01b20fe87591e51421308dde5340816193429e3 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Sun, 30 Oct 2016 21:30:30 +0900 Subject: [PATCH 27/30] Add tests for sparse operation and fix sparse bugs. --- .../approx_kfn/drusilla_select_impl.hpp | 13 ++++++++----- src/mlpack/tests/drusilla_select_test.cpp | 19 +++++++++++++++++++ 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp index 95953745ea5..942063b6c08 100644 --- a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp +++ b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp @@ -76,12 +76,15 @@ void DrusillaSelect::Train( candidateSet.set_size(referenceSet.n_rows, l * m); candidateIndices.set_size(l * m); - arma::vec dataMean = arma::mean(referenceSet, 1); + arma::vec dataMean(arma::mean(referenceSet, 1)); arma::vec norms(referenceSet.n_cols); - arma::mat refCopy = referenceSet.each_col() - dataMean; + MatType refCopy(referenceSet.n_rows, referenceSet.n_cols); for (size_t i = 0; i < refCopy.n_cols; ++i) - norms[i] = arma::norm(refCopy.col(i) - dataMean); + { + refCopy.col(i) = referenceSet.col(i) - dataMean; + norms[i] = arma::norm(refCopy.col(i)); + } // Find the top m points for each of the l projections...
for (size_t i = 0; i < l; ++i) @@ -90,7 +93,7 @@ void DrusillaSelect::Train( arma::uword maxIndex; norms.max(maxIndex); - arma::vec line = refCopy.col(maxIndex) / arma::norm(refCopy.col(maxIndex)); + arma::vec line(refCopy.col(maxIndex) / arma::norm(refCopy.col(maxIndex))); const size_t n_nonzero = (size_t) arma::sum(norms > 0); // Calculate distortion and offset and make scores. @@ -176,7 +179,7 @@ void DrusillaSelect::Search(const MatType& querySet, // TreeType. metric::EuclideanDistance metric; NeighborSearchRules> + tree::KDTree> rules(candidateSet, querySet, k, metric, 0, false); for (size_t q = 0; q < querySet.n_cols; ++q) diff --git a/src/mlpack/tests/drusilla_select_test.cpp b/src/mlpack/tests/drusilla_select_test.cpp index b60a1ad1281..cce2704a15d 100644 --- a/src/mlpack/tests/drusilla_select_test.cpp +++ b/src/mlpack/tests/drusilla_select_test.cpp @@ -143,4 +143,23 @@ BOOST_AUTO_TEST_CASE(SerializationTest) } } +// Make sure we can create the object with a sparse matrix. +BOOST_AUTO_TEST_CASE(SparseTest) +{ + arma::sp_mat dataset; + dataset.sprandu(50, 1000, 0.3); + + DrusillaSelect ds(dataset, 5, 10); + + // Run a search. + arma::mat distances; + arma::Mat neighbors; + ds.Search(dataset, 3, neighbors, distances); + + BOOST_REQUIRE_EQUAL(neighbors.n_cols, 1000); + BOOST_REQUIRE_EQUAL(neighbors.n_rows, 3); + BOOST_REQUIRE_EQUAL(distances.n_cols, 1000); + BOOST_REQUIRE_EQUAL(distances.n_rows, 3); +} + BOOST_AUTO_TEST_SUITE_END(); From 15f4b073adc1410d182cc94d5fded590331eff71 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Sun, 30 Oct 2016 21:32:24 +0900 Subject: [PATCH 28/30] Add comprehensive tests for QDAFN. There is a bug now, but I have to push this to be able to solve it so I can get to a system that has working gdb. 
--- src/mlpack/methods/approx_kfn/qdafn.hpp | 16 +++ src/mlpack/methods/approx_kfn/qdafn_impl.hpp | 9 ++ src/mlpack/tests/qdafn_test.cpp | 101 +++++++++++++++++++ 3 files changed, 126 insertions(+) diff --git a/src/mlpack/methods/approx_kfn/qdafn.hpp b/src/mlpack/methods/approx_kfn/qdafn.hpp index f7949db6e7d..6ba8b81b5db 100644 --- a/src/mlpack/methods/approx_kfn/qdafn.hpp +++ b/src/mlpack/methods/approx_kfn/qdafn.hpp @@ -49,6 +49,19 @@ class QDAFN const size_t l, const size_t m); + /** + * Train the QDAFN model on the given reference set, optionally setting new + * parameters for the number of projections/tables (l) and the number of + * elements stored for each projection/table (m). + * + * @param referenceSet Reference set to train on. + * @param l Number of projections. + * @param m Number of elements to store for each projection. + */ + void Train(const MatType& referenceSet, + const size_t l = 0, + const size_t m = 0); + /** * Search for the k furthest neighbors of the given query set. (The query set * can contain just one point, that is okay.) The results will be stored in @@ -64,6 +77,9 @@ class QDAFN template void Serialize(Archive& ar, const unsigned int /* version */); + //! Get the number of projections. + size_t NumProjections() const { return candidateSet.size(); } + //! Get the candidate set for the given projection table. const MatType& CandidateSet(const size_t t) const { return candidateSet[t]; } //! Modify the candidate set for the given projection table. Careful! diff --git a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp index de6c882108f..475538c81f3 100644 --- a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp +++ b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp @@ -27,6 +27,15 @@ QDAFN::QDAFN(const MatType& referenceSet, const size_t m) : l(l), m(m) +{ + Train(referenceSet); +} + +// Train the object. 
+template +void QDAFN::Train(const MatType& referenceSet, + const size_t l, + const size_t m) { // Build tables. This is done by drawing random points from a Gaussian // distribution as the vectors we project onto. The Gaussian should have zero diff --git a/src/mlpack/tests/qdafn_test.cpp b/src/mlpack/tests/qdafn_test.cpp index ea64b526852..332b7c7e81b 100644 --- a/src/mlpack/tests/qdafn_test.cpp +++ b/src/mlpack/tests/qdafn_test.cpp @@ -102,4 +102,105 @@ BOOST_AUTO_TEST_CASE(QDAFNUniformSet) BOOST_REQUIRE_GE(successes, 700); } +/** + * Test re-training method. + */ +BOOST_AUTO_TEST_CASE(RetrainTest) +{ + arma::mat dataset = arma::randu(25, 500); + arma::mat newDataset = arma::randu(15, 600); + + QDAFN<> qdafn(dataset, 20, 60); + + qdafn.Train(newDataset, 10, 50); + + BOOST_REQUIRE_EQUAL(qdafn.NumProjections(), 10); + for (size_t i = 0; i < 10; ++i) + { + BOOST_REQUIRE_EQUAL(qdafn.CandidateSet(i).n_rows, 15); + BOOST_REQUIRE_EQUAL(qdafn.CandidateSet(i).n_cols, 50); + } +} + +/** + * Test serialization of QDAFN. + */ +BOOST_AUTO_TEST_CASE(SerializationTest) +{ + // Use a random dataset. + arma::mat dataset = arma::randu(15, 300); + + QDAFN<> qdafn(dataset, 10, 50); + + arma::mat fakeDataset1 = arma::randu(10, 200); + arma::mat fakeDataset2 = arma::randu(50, 500); + QDAFN<> qdafnXml(fakeDataset1, 5, 10); + QDAFN<> qdafnText(6, 50); + QDAFN<> qdafnBinary(7, 15); + qdafnBinary.Train(fakeDataset2); + + // Serialize the objects. + SerializeObjectAll(qdafn, qdafnXml, qdafnText, qdafnBinary); + + // Check that the tables are all the same. 
+ BOOST_REQUIRE_EQUAL(qdafnXml.NumProjections(), qdafn.NumProjections()); + BOOST_REQUIRE_EQUAL(qdafnText.NumProjections(), qdafn.NumProjections()); + BOOST_REQUIRE_EQUAL(qdafnBinary.NumProjections(), qdafn.NumProjections()); + + for (size_t i = 0; i < qdafn.NumProjections(); ++i) + { + BOOST_REQUIRE_EQUAL(qdafnXml.CandidateSet(i).n_rows, + qdafn.CandidateSet(i).n_rows); + BOOST_REQUIRE_EQUAL(qdafnText.CandidateSet(i).n_rows, + qdafn.CandidateSet(i).n_rows); + BOOST_REQUIRE_EQUAL(qdafnBinary.CandidateSet(i).n_rows, + qdafn.CandidateSet(i).n_rows); + + BOOST_REQUIRE_EQUAL(qdafnXml.CandidateSet(i).n_cols, + qdafn.CandidateSet(i).n_cols); + BOOST_REQUIRE_EQUAL(qdafnText.CandidateSet(i).n_cols, + qdafn.CandidateSet(i).n_cols); + BOOST_REQUIRE_EQUAL(qdafnBinary.CandidateSet(i).n_cols, + qdafn.CandidateSet(i).n_cols); + + for (size_t j = 0; j < qdafn.CandidateSet(i).n_elem; ++j) + { + if (std::abs(qdafn.CandidateSet(i)[j]) < 1e-5) + { + BOOST_REQUIRE_SMALL(qdafnXml.CandidateSet(i)[j], 1e-5); + BOOST_REQUIRE_SMALL(qdafnText.CandidateSet(i)[j], 1e-5); + BOOST_REQUIRE_SMALL(qdafnBinary.CandidateSet(i)[j], 1e-5); + } + else + { + const double value = qdafn.CandidateSet(i)[j]; + BOOST_REQUIRE_CLOSE(qdafnXml.CandidateSet(i)[j], value, 1e-5); + BOOST_REQUIRE_CLOSE(qdafnText.CandidateSet(i)[j], value, 1e-5); + BOOST_REQUIRE_CLOSE(qdafnBinary.CandidateSet(i)[j], value, 1e-5); + } + } + } +} + +// Make sure QDAFN works with sparse data. +BOOST_AUTO_TEST_CASE(SparseTest) +{ + arma::sp_mat dataset; + dataset.sprandu(200, 1000, 0.3); + + // Create a sparse version. + QDAFN sparse(dataset, 15, 50); + + // Make sure the results are of the right shape. It's hard to test anything + // more than that because we don't have easy-to-check performance guarantees. 
+ arma::Mat neighbors; + arma::mat distances; + sparse.Search(dataset, 3, neighbors, distances); + + BOOST_REQUIRE_EQUAL(neighbors.n_rows, 3); + BOOST_REQUIRE_EQUAL(neighbors.n_cols, 1000); + BOOST_REQUIRE_EQUAL(distances.n_rows, 3); + BOOST_REQUIRE_EQUAL(distances.n_cols, 1000); +} + BOOST_AUTO_TEST_SUITE_END(); From 6376680fa1463574bd151424b8973edf4a69a3dd Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Sun, 30 Oct 2016 21:49:12 +0900 Subject: [PATCH 29/30] Fix test bugs. --- src/mlpack/methods/approx_kfn/qdafn_impl.hpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp index 475538c81f3..4e2104e9c3e 100644 --- a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp +++ b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp @@ -34,9 +34,14 @@ QDAFN::QDAFN(const MatType& referenceSet, // Train the object. template void QDAFN::Train(const MatType& referenceSet, - const size_t l, - const size_t m) + const size_t lIn, + const size_t mIn) { + if (lIn != 0) + l = lIn; + if (mIn != 0) + m = mIn; + // Build tables. This is done by drawing random points from a Gaussian // distribution as the vectors we project onto. The Gaussian should have zero // mean and unit variance. From 86e5edf60e2066a464f9e8862278ffd63af75741 Mon Sep 17 00:00:00 2001 From: Ryan Curtin Date: Sun, 30 Oct 2016 21:50:06 +0900 Subject: [PATCH 30/30] Add error checking to constructor. --- src/mlpack/methods/approx_kfn/qdafn_impl.hpp | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp index 4e2104e9c3e..8d64f9578bb 100644 --- a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp +++ b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp @@ -18,7 +18,13 @@ namespace neighbor { // Non-training constructor. 
template -QDAFN::QDAFN(const size_t l, const size_t m) : l(l), m(m) { } +QDAFN::QDAFN(const size_t l, const size_t m) : l(l), m(m) +{ + if (l == 0) + throw std::invalid_argument("QDAFN::QDAFN(): l must be greater than 0!"); + if (m == 0) + throw std::invalid_argument("QDAFN::QDAFN(): m must be greater than 0!"); +} // Constructor. template @@ -28,6 +34,11 @@ QDAFN::QDAFN(const MatType& referenceSet, l(l), m(m) { + if (l == 0) + throw std::invalid_argument("QDAFN::QDAFN(): l must be greater than 0!"); + if (m == 0) + throw std::invalid_argument("QDAFN::QDAFN(): m must be greater than 0!"); + Train(referenceSet); }