New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
LSHSearch Parallelization #700
Changes from 2 commits
72999dd
95d417f
afcc881
5db5423
abef504
4cbd43e
a60ff91
6152527
7cf77cd
2ca48c6
a6aca41
3d536c7
c04b073
65983d1
b95a3ce
0d38271
3af80c3
b02e2f3
a1e9c28
ad8e6d3
c4c8ff9
074d726
f982ca5
1fb998f
b92d465
2fee61e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -293,6 +293,7 @@ void LSHSearch<SortPolicy>::InsertNeighbor(arma::mat& distances, | |
neighbors(pos, queryIndex) = neighbor; | ||
} | ||
|
||
/* | ||
// Base case where the query set is the reference set. (So, we can't return | ||
// ourselves as the nearest neighbor.) | ||
template<typename SortPolicy> | ||
|
@@ -319,8 +320,13 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex, | |
|
||
// SortDistance() returns (size_t() - 1) if we shouldn't add it. | ||
if (insertPosition != (size_t() - 1)) | ||
InsertNeighbor(distances, neighbors, queryIndex, insertPosition, | ||
referenceIndex, distance); | ||
{ | ||
#pragma omp critical | ||
{ | ||
InsertNeighbor(distances, neighbors, queryIndex, insertPosition, | ||
referenceIndex, distance); | ||
} | ||
} | ||
} | ||
|
||
// Base case for bichromatic search. | ||
|
@@ -345,10 +351,79 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex, | |
|
||
// SortDistance() returns (size_t() - 1) if we shouldn't add it. | ||
if (insertPosition != (size_t() - 1)) | ||
InsertNeighbor(distances, neighbors, queryIndex, insertPosition, | ||
referenceIndex, distance); | ||
{ | ||
#pragma omp critical | ||
{ | ||
InsertNeighbor(distances, neighbors, queryIndex, insertPosition, | ||
referenceIndex, distance); | ||
} | ||
} | ||
} | ||
*/ | ||
|
||
// Base case where the query set is the reference set. (So, we can't return | ||
// ourselves as the nearest neighbor.) | ||
template<typename SortPolicy> | ||
inline force_inline | ||
void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex, | ||
const arma::uvec& referenceIndices, | ||
arma::Mat<size_t>& neighbors, | ||
arma::mat& distances) const | ||
{ | ||
for (size_t j = 0; j < referenceIndices.n_elem; ++j) | ||
{ | ||
const size_t referenceIndex = referenceIndices[j]; | ||
// If the points are the same, skip this point. | ||
if (queryIndex == referenceIndex) | ||
continue; | ||
|
||
const double distance = metric::EuclideanDistance::Evaluate( | ||
referenceSet->unsafe_col(queryIndex), | ||
referenceSet->unsafe_col(referenceIndex)); | ||
|
||
// If this distance is better than any of the current candidates, the | ||
// SortDistance() function will give us the position to insert it into. | ||
arma::vec queryDist = distances.unsafe_col(queryIndex); | ||
arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex); | ||
size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices, | ||
distance); | ||
|
||
// SortDistance() returns (size_t() - 1) if we shouldn't add it. | ||
if (insertPosition != (size_t() - 1)) | ||
InsertNeighbor(distances, neighbors, queryIndex, insertPosition, | ||
referenceIndex, distance); | ||
} | ||
} | ||
|
||
// Base case for bichromatic search. | ||
template<typename SortPolicy> | ||
inline force_inline | ||
void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex, | ||
const arma::uvec& referenceIndices, | ||
const arma::mat& querySet, | ||
arma::Mat<size_t>& neighbors, | ||
arma::mat& distances) const | ||
{ | ||
for (size_t j = 0; j < referenceIndices.n_elem; ++j) | ||
{ | ||
const size_t referenceIndex = referenceIndices[j]; | ||
const double distance = metric::EuclideanDistance::Evaluate( | ||
querySet.unsafe_col(queryIndex), | ||
referenceSet->unsafe_col(referenceIndex)); | ||
|
||
// If this distance is better than any of the current candidates, the | ||
// SortDistance() function will give us the position to insert it into. | ||
arma::vec queryDist = distances.unsafe_col(queryIndex); | ||
arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex); | ||
size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices, | ||
distance); | ||
|
||
// SortDistance() returns (size_t() - 1) if we shouldn't add it. | ||
if (insertPosition != (size_t() - 1)) | ||
InsertNeighbor(distances, neighbors, queryIndex, insertPosition, | ||
referenceIndex, distance); | ||
} | ||
} | ||
template<typename SortPolicy> | ||
template<typename VecType> | ||
void LSHSearch<SortPolicy>::ReturnIndicesFromTable( | ||
|
@@ -416,45 +491,21 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable( | |
arma::Col<size_t> refPointsConsidered; | ||
refPointsConsidered.zeros(referenceSet->n_cols); | ||
|
||
// Define the number of threads used to process this. | ||
size_t numThreadsUsed = std::min(maxThreads, numTablesToSearch); | ||
|
||
// Parallelization: By default nested parallelism is off, so this won't be | ||
// parallel. The user might turn nested parallelism on if (for example) they | ||
// have a query-by-query processing scheme and so processing more than one | ||
// query at the same time doesn't make sense for them. | ||
|
||
#pragma omp parallel for \ | ||
num_threads (numThreadsUsed) \ | ||
shared (hashVec, refPointsConsidered) \ | ||
schedule(dynamic) | ||
for (size_t i = 0; i < numTablesToSearch; ++i) | ||
for (long long int i = 0; i < numTablesToSearch; ++i) | ||
{ | ||
|
||
const size_t hashInd = (size_t) hashVec[i]; | ||
const size_t tableRow = bucketRowInHashTable[hashInd]; | ||
|
||
// Pick the indices in the bucket corresponding to 'hashInd'. | ||
if (tableRow != secondHashSize) | ||
{ | ||
for (size_t j = 0; j < bucketContentSize[tableRow]; j++) | ||
{ | ||
#pragma omp atomic | ||
refPointsConsidered[secondHashTable[tableRow](j)]++; | ||
} | ||
} | ||
} | ||
|
||
// Only keep reference points found in at least one bucket. If OpenMP is | ||
// found, do it in parallel | ||
#ifdef OPENMP_FOUND | ||
// TODO: change this to our own function? | ||
referenceIndices = arma::find(refPointsConsidered > 0); | ||
return; | ||
#else | ||
referenceIndices = arma::find(refPointsConsidered > 0); | ||
return; | ||
#endif | ||
// Only keep reference points found in at least one bucket. | ||
referenceIndices = arma::find(refPointsConsidered > 0); | ||
return; | ||
} | ||
else | ||
{ | ||
|
@@ -467,45 +518,20 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable( | |
// Retrieve candidates. | ||
size_t start = 0; | ||
|
||
// Define the number of threads used to process this. | ||
size_t numThreadsUsed = std::min(maxThreads, numTablesToSearch); | ||
|
||
// Parallelization: By default nested parallelism is off, so this won't be | ||
// parallel. The user might turn nested parallelism on if (for example) they | ||
// have a query-by-query processing scheme and so processing more than one | ||
// query at the same time doesn't make sense for them. | ||
|
||
#pragma omp parallel for \ | ||
num_threads (numThreadsUsed) \ | ||
shared (hashVec, refPointsConsideredSmall, start) \ | ||
schedule(dynamic) | ||
for (size_t i = 0; i < numTablesToSearch; ++i) // For all tables | ||
for (long long int i = 0; i < numTablesToSearch; ++i) // For all tables | ||
{ | ||
const size_t hashInd = (size_t) hashVec[i]; // Find the query's bucket. | ||
const size_t tableRow = bucketRowInHashTable[hashInd]; | ||
|
||
// Store all secondHashTable points in the candidates set. | ||
if (tableRow != secondHashSize) | ||
{ | ||
for (size_t j = 0; j < bucketContentSize[tableRow]; ++j) | ||
{ | ||
#pragma omp critical | ||
{ | ||
refPointsConsideredSmall(start++) = secondHashTable[tableRow][j]; | ||
} | ||
} | ||
} | ||
refPointsConsideredSmall(start++) = secondHashTable[tableRow][j]; | ||
} | ||
|
||
// Only keep unique candidates. If OpenMP is found, do it in parallel. | ||
#ifdef OPENMP_FOUND | ||
// TODO: change this to our own function? | ||
referenceIndices = arma::unique(refPointsConsideredSmall); | ||
return; | ||
#else | ||
referenceIndices = arma::unique(refPointsConsideredSmall); | ||
return; | ||
#endif | ||
// Keep only one copy of each candidate. | ||
referenceIndices = arma::unique(refPointsConsideredSmall); | ||
return; | ||
} | ||
} | ||
|
||
|
@@ -557,8 +583,9 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet, | |
num_threads ( numThreadsUsed )\ | ||
shared(avgIndicesReturned, resultingNeighbors, distances) \ | ||
schedule(dynamic) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Two questions---
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The problem with static scheduling is it doesn't leave room for work-stealing. Since queries get unequal sizes of candidate sets, in static scheduling some threads will finish their chunks quickly and then sit idle. In dynamic scheduling, the OpenMP runtime will detect idle threads and give them more work to do.
Yes I think I can simplify the code more now that we're not doing nested parallelism. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. About static vs dynamic scheduling, I ran some tests: Sift100k
phy
Corel
Miniboone
In the first 3, I'd say dynamic is slightly faster. It's hard to tell for Miniboone because the standard deviation is much larger than the difference. I'll run covertype and pokerhand in a while when my PC is not used. |
||
// Go through every query point. | ||
for (size_t i = 0; i < querySet.n_cols; i++) | ||
// Go through every query point. Use long int because some compilers complain | ||
// for openMP unsigned index variables. | ||
for (long long int i = 0; i < querySet.n_cols; i++) | ||
{ | ||
|
||
// Hash every query into every hash table and eventually into the | ||
|
@@ -574,9 +601,17 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet, | |
|
||
// Sequentially go through all the candidates and save the best 'k' | ||
// candidates. | ||
/* | ||
numTheadsUsed = std::min( (arma::uword) maxThreads, refIndices.n_elem); | ||
#pragma omp parallel for\ | ||
num_threads( numThreadsUsed )\ | ||
shared(refIndices, resultingNeighbors, distances, querySet)\ | ||
schedule(dynamic) | ||
for (size_t j = 0; j < refIndices.n_elem; j++) | ||
BaseCase(i, (size_t) refIndices[j], querySet, resultingNeighbors, | ||
distances); | ||
*/ | ||
BaseCase(i, refIndices, querySet, resultingNeighbors, distances); | ||
} | ||
|
||
Timer::Stop("computing_neighbors"); | ||
|
@@ -613,8 +648,9 @@ Search(const size_t k, | |
num_threads ( numThreadsUsed )\ | ||
shared(avgIndicesReturned, resultingNeighbors, distances) \ | ||
schedule(dynamic) | ||
// Go through every query point. | ||
for (size_t i = 0; i < referenceSet->n_cols; i++) | ||
// Go through every query point. Use long int because some compilers complain | ||
// for openMP unsigned index variables. | ||
for (long long int i = 0; i < referenceSet->n_cols; i++) | ||
{ | ||
// Hash every query into every hash table and eventually into the | ||
// 'secondHashTable' to obtain the neighbor candidates. | ||
|
@@ -629,8 +665,17 @@ Search(const size_t k, | |
|
||
// Sequentially go through all the candidates and save the best 'k' | ||
// candidates. | ||
|
||
/* | ||
numTheadsUsed = std::min( (arma::uword) maxThreads, refIndices.n_elem); | ||
#pragma omp parallel for\ | ||
num_threads( numThreadsUsed )\ | ||
shared(refIndices, resultingNeighbors, distances)\ | ||
schedule(dynamic) | ||
for (size_t j = 0; j < refIndices.n_elem; j++) | ||
BaseCase(i, (size_t) refIndices[j], resultingNeighbors, distances); | ||
*/ | ||
BaseCase(i, refIndices, resultingNeighbors, distances); | ||
|
||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The restriction to `long long` here to support Windows really irks me, and as far as I can tell there is no reasonable and portable way to get something that's the same size as `size_t` but signed (there is `ssize_t`, but its support on Windows is unclear to me). I have half a mind to just require OpenMP 3.0 support to get rid of this stupid restriction, which would disable OpenMP support with Visual Studio. I am not sure I am bothered by that; OpenMP 3 is almost a decade old at this point and the Visual Studio team still doesn't have support for it, so I am not sure I want to keep restricting us to such a legacy version. Windows users, if they need parallelism, can always switch to using MinGW or ICC or something like that. What do you think?

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's pretty reasonable. I can go back to using `size_t` and add a check that detects OpenMP version 3 and up.