diff --git a/src/mlpack/methods/lsh/lsh_main.cpp b/src/mlpack/methods/lsh/lsh_main.cpp index 2894411cece..4f4040abf55 100644 --- a/src/mlpack/methods/lsh/lsh_main.cpp +++ b/src/mlpack/methods/lsh/lsh_main.cpp @@ -48,6 +48,7 @@ PARAM_STRING("reference_file", "File containing the reference dataset.", "r", ""); PARAM_STRING("distances_file", "File to output distances into.", "d", ""); PARAM_STRING("neighbors_file", "File to output neighbors into.", "n", ""); +PARAM_STRING("true_neighbors_file", "File of real neighbors to compute recall (printed with -v).", "t", ""); // We can load or save models. PARAM_STRING("input_model_file", "File to load LSH model from. (Cannot be " @@ -188,6 +189,25 @@ int main(int argc, char *argv[]) Log::Info << "Neighbors computed." << endl; + // Compute recall, if desired. + if (CLI::HasParam("t")) + { + // read specified filename + const string trueNeighborsFile = + CLI::GetParam("true_neighbors_file"); + + // load the data + arma::Mat trueNeighbors; + data::Load(trueNeighborsFile, trueNeighbors, true); + Log::Info << "Loaded true neighbor indices from '" + << trueNeighborsFile << "'." << endl; + + // Compute Recall and log + double recallPercentage = 100 * allkann.ComputeRecall(neighbors, trueNeighbors); + + Log::Info << "Recall: " << recallPercentage << endl; + } + // Save output, if desired. if (CLI::HasParam("distances_file")) data::Save(distancesFile, distances); diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp index b42bb7a81e0..fbdf6f35fae 100644 --- a/src/mlpack/methods/lsh/lsh_search.hpp +++ b/src/mlpack/methods/lsh/lsh_search.hpp @@ -168,6 +168,14 @@ class LSHSearch arma::mat& distances, const size_t numTablesToSearch = 0); + /** + * Compute the recall (% of neighbors found) given the neighbors returned by + * LSHSearch::Search and a "ground truth" file. Recall in [0, 1] + */ + double ComputeRecall(const arma::Mat& foundNeighbors, + const arma::Mat& realNeighbors); + + /** * Serialize the LSH model. * diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp index ad698e12800..14617c04205 100644 --- a/src/mlpack/methods/lsh/lsh_search_impl.hpp +++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp @@ -571,6 +571,34 @@ Search(const size_t k, std::endl; } +template +double LSHSearch::ComputeRecall( + const arma::Mat& foundNeighbors, + const arma::Mat& realNeighbors) +{ + if (foundNeighbors.n_rows != realNeighbors.n_rows || + foundNeighbors.n_cols != realNeighbors.n_cols) + throw std::invalid_argument("LSHSearch::ComputeRecall(): matrices provided" + " must have equal size"); + + const size_t queries = foundNeighbors.n_cols; + const size_t neighbors= foundNeighbors.n_rows; //k + + // recall is set intersection of found and real neighbors + double found = 0; + for (size_t col = 0; col < queries; ++col) + for (size_t row = 0; row < neighbors; ++row) + for (size_t nei = 0; nei < realNeighbors.n_rows; ++nei) + if (realNeighbors(row, col) == foundNeighbors(nei, col)) + { + found++; + break; + } + + return found/realNeighbors.n_elem; + +} + template template void LSHSearch::Serialize(Archive& ar, diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp index d42566694fb..0988059bf34 100644 --- a/src/mlpack/tests/lsh_test.cpp +++ b/src/mlpack/tests/lsh_test.cpp @@ -326,6 +326,104 @@ BOOST_AUTO_TEST_CASE(LSHTrainTest) BOOST_REQUIRE_EQUAL(distances.n_rows, 3); } +/** + * Test: this verifies ComputeRecall works correctly by providing two identical + * vectors and requiring that Recall is equal to 1. + */ +BOOST_AUTO_TEST_CASE(RecallTestIdentical) +{ + const size_t k = 5; // 5 nearest neighbors + const size_t numQueries = 1; + + // base = [1; 2; 3; 4; 5] + arma::Mat base; + base.set_size(k, numQueries); + base.col(0) = arma::linspace< arma::Col >(1, k, k); + + // q1 = [1; 2; 3; 4; 5]. Expect recall = 1 + arma::Mat q1; + q1.set_size(k, numQueries); + q1.col(0) = arma::linspace< arma::Col >(1, k, k); + + LSHSearch<> lsh; + BOOST_REQUIRE_EQUAL(lsh.ComputeRecall(base, q1), 1); +} + +/** + * Test: this verifies ComputeRecall returns correct values for partially + * correct found neighbors. This is important because this is a good example of + * how the recall and accuracy metrics differ - accuracy in this case would be + * 0, recall should not be + */ +BOOST_AUTO_TEST_CASE(RecallTestPartiallyCorrect) +{ + const size_t k = 5; // 5 nearest neighbors + const size_t numQueries = 1; + + // base = [1; 2; 3; 4; 5] + arma::Mat base; + base.set_size(k, numQueries); + base.col(0) = arma::linspace< arma::Col >(1, k, k); + + // q2 = [2; 3; 4; 6; 7]. Expect recall = 0.6. This is important because this + // is a good example of how recall and accuracy differ. Accuracy here would + // be 0 but recall should not be. + arma::Mat q2; + q2.set_size(k, numQueries); + q2 << + 2 << arma::endr << + 3 << arma::endr << + 4 << arma::endr << + 6 << arma::endr << + 7 << arma::endr; + + LSHSearch<> lsh; + BOOST_REQUIRE_CLOSE(lsh.ComputeRecall(base, q2), 0.6, 0.0001); +} + +/** + * Test: If given a completely wrong vector, ComputeRecall should return 0 + */ +BOOST_AUTO_TEST_CASE(RecallTestIncorrect) +{ + const size_t k = 5; // 5 nearest neighbors + const size_t numQueries = 1; + + // base = [1; 2; 3; 4; 5] + arma::Mat base; + base.set_size(k, numQueries); + base.col(0) = arma::linspace< arma::Col >(1, k, k); + // q3 = [6; 7; 8; 9; 10]. Expected recall = 0 + arma::Mat q3; + q3.set_size(k, numQueries); + q3.col(0) = arma::linspace< arma::Col >(k + 1, 2 * k, k); + + LSHSearch<> lsh; + BOOST_REQUIRE_EQUAL(lsh.ComputeRecall(base, q3), 0); +} + +/** + * Test: If given a vector of wrong shape, ComputeRecall should throw an + * exception + */ +BOOST_AUTO_TEST_CASE(RecallTestException) +{ + const size_t k = 5; // 5 nearest neighbors + const size_t numQueries = 1; + + // base = [1; 2; 3; 4; 5] + arma::Mat base; + base.set_size(k, numQueries); + base.col(0) = arma::linspace< arma::Col >(1, k, k); + // verify that nonsense arguments throw exception + arma::Mat q4; + q4.set_size(2 * k, numQueries); + + LSHSearch<> lsh; + BOOST_REQUIRE_THROW(lsh.ComputeRecall(base, q4), std::invalid_argument); + +} + BOOST_AUTO_TEST_CASE(EmptyConstructorTest) { // If we create an empty LSH model and then call Search(), it should throw an