New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implementation of Multiprobe LSH #691
Changes from 1 commit
8dd409d
8264347
f24ac2c
27510fe
a48c842
4adaf8f
8bc5ced
a048f46
d332ea1
00377a8
f0638a0
6d11521
fa7f62d
ae81ee5
840efe9
2af985c
75dead3
89b3c7b
79b954a
e2596c5
5d603b2
039af23
71eda99
2aa839c
cc6a5c2
4ba6f6d
71edfbb
d418c09
f1e11fd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -340,25 +340,11 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex, | |
referenceIndex, distance); | ||
} | ||
|
||
|
||
// Compare class for <double, size_t> pair, used in GetAdditionalProbingBins. | ||
class CompareGreater | ||
{ | ||
public: | ||
bool operator()( | ||
std::pair<double, size_t> p1, | ||
std::pair<double, size_t> p2){ | ||
//only compare the double values | ||
return p1.first > p2.first; | ||
} | ||
}; | ||
|
||
//Returns the score of a perturbation vector generated by perturbation set A. | ||
//The score of a pertubation set (vector) is the sum of scores of the | ||
//participating actions. | ||
inline double perturbationScore( | ||
const std::vector<size_t> &A, | ||
const arma::vec &scores) | ||
inline double perturbationScore(const std::vector<size_t>& A, | ||
const arma::vec& scores) | ||
{ | ||
double score = 0.0; | ||
for (size_t i = 0; i < A.size(); ++i) | ||
|
@@ -368,7 +354,7 @@ inline double perturbationScore( | |
|
||
// Inline function used by GetAdditionalProbingBins. The vector shift operation | ||
// replaces the largest element of a vector A with (largest element) + 1. | ||
inline void perturbationShift(std::vector<size_t> &A) | ||
inline void perturbationShift(std::vector<size_t>& A) | ||
{ | ||
size_t max_pos = 0; | ||
size_t max = A[0]; | ||
|
@@ -386,56 +372,49 @@ inline void perturbationShift(std::vector<size_t> &A) | |
// Inline function used by GetAdditionalProbingBins. The vector expansion | ||
// operation adds the element [1 + (largest_element)] to a vector A, where | ||
// largest_element is the largest element of A. | ||
inline void perturbationExpand(std::vector<size_t> &A) | ||
inline void perturbationExpand(std::vector<size_t>& A) | ||
{ | ||
size_t max = A[0]; | ||
for (size_t i = 1; i < A.size(); ++i) | ||
if (A[i] > max) | ||
max = A[i]; | ||
A.push_back(max+1); | ||
A.push_back(max + 1); | ||
} | ||
|
||
// Return true if perturbation set A is valid. A perturbation set is invalid if | ||
// it contains two (or more) actions for the same dimension or dimensions that | ||
// are larger than the queryCode's dimensions. | ||
inline bool perturbationValid( | ||
const std::vector<size_t> &A, | ||
const size_t numProj) | ||
inline bool perturbationValid(const std::vector<size_t>& A, | ||
const size_t numProj) | ||
{ | ||
// Stack allocation and initialization to 0 (bool check[numProj] = {0}) made | ||
// some compilers complain. We use new just to be safe. | ||
bool *check = new bool[numProj](); | ||
// some compilers complain, and std::vector might even be compressed (depends | ||
// on implementation) so this saves some space. | ||
std::vector<bool> check(numProj); | ||
|
||
for (size_t i = 0; i < A.size(); ++i) | ||
{ | ||
// Check that we only use valid dimensions. If not, vector is not valid. | ||
if ( A[i] >= 2*numProj) | ||
{ | ||
delete []check; | ||
if (A[i] >= 2 * numProj) | ||
return false; | ||
} | ||
|
||
// Check that we only see each dimension once. If not, vector is not valid. | ||
if (check[A[i] % numProj ] == 0) | ||
check[A[i] % numProj ] = 1; | ||
else | ||
{ | ||
delete []check; | ||
return false; | ||
} | ||
} | ||
delete []check; | ||
return true; | ||
} | ||
|
||
|
||
// Compute additional probing bins for a query | ||
template<typename SortPolicy> | ||
void LSHSearch<SortPolicy>::GetAdditionalProbingBins( | ||
const arma::vec &queryCode, | ||
const arma::vec &queryCodeNotFloored, | ||
const arma::vec& queryCode, | ||
const arma::vec& queryCodeNotFloored, | ||
const size_t T, | ||
arma::mat &additionalProbingBins) const | ||
arma::mat& additionalProbingBins) const | ||
{ | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Extra line :) |
||
// No additional bins requested. Our work is done. | ||
|
@@ -461,15 +440,15 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins( | |
// calculate scores = distances^2 | ||
arma::vec scores(2 * numProj); | ||
scores.rows(0, numProj - 1) = arma::pow(limLow, 2); | ||
scores.rows(numProj, 2 * numProj - 1) = arma::pow(limHigh, 2); | ||
scores.rows(numProj, (2 * numProj) - 1) = arma::pow(limHigh, 2); | ||
|
||
// actions vector shows what transformation to apply to a coordinate | ||
arma::Col<short int> actions(2 * numProj); // will be [-1 ... 1 ...] | ||
|
||
actions.rows(0, numProj - 1) = // first numProj rows | ||
-1 * arma::ones< arma::Col<short int> > (numProj); // -1s | ||
|
||
actions.rows(numProj, 2 * numProj - 1) = // last numProj rows | ||
actions.rows(numProj, (2 * numProj) - 1) = // last numProj rows | ||
arma::ones< arma::Col<short int> > (numProj); // 1s | ||
|
||
|
||
|
@@ -491,7 +470,7 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins( | |
// find location and value of smallest element of scores vector | ||
double minscore = scores[0]; | ||
size_t minloc = 0; | ||
for (size_t s = 1; s < 2 * numProj; ++s) | ||
for (size_t s = 1; s < (2 * numProj); ++s) | ||
{ | ||
if (minscore > scores[s]) | ||
{ | ||
|
@@ -518,7 +497,7 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins( | |
|
||
double minscore2 = scores[0]; | ||
size_t minloc2 = 0; | ||
for (size_t s = 0; s < 2 * numProj; ++s) // here we can't start from 1 | ||
for (size_t s = 0; s < (2 * numProj); ++s) // here we can't start from 1 | ||
{ | ||
if ( minscore2 > scores[s] && s != minloc) //second smallest | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Extra space after opening parenthesis. |
||
{ | ||
|
@@ -530,8 +509,8 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins( | |
// add or subtract 1 to create second-lowest scoring vector | ||
additionalProbingBins(positions[minloc2], 1) += actions[minloc2]; | ||
return; | ||
|
||
} | ||
// General case: more than 2 perturbation vectors require use of minheap. | ||
|
||
// sort everything in increasing order | ||
arma::Col<long long unsigned int> sortidx = arma::sort_index(scores); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep - I had trouble when using size_t, the compiler didn't like it. I replaced it with arma::uword. I'm not sure why size_t is not allowed... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this was mentioned somewhere else, so forgive any redundancy, but probably |
||
|
@@ -561,23 +540,22 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins( | |
std::vector< std::vector<size_t> > perturbationSets; | ||
perturbationSets.push_back(Ao); // storage of perturbation sets | ||
|
||
|
||
// define a priority queue with CompareGreater as a minheap | ||
std::priority_queue< | ||
std::pair<double, size_t>, // contents: pairs of (score, index) | ||
std::vector< // container: vector of pairs | ||
std::pair<double, size_t> | ||
>, | ||
mlpack::neighbor::CompareGreater // comparator of pairs(compare scores) | ||
std::greater< std::pair<double, size_t> > // comparator of pairs(compare scores) | ||
> minHeap; // our minheap | ||
|
||
// Start by adding the lowest scoring set to the minheap | ||
std::pair<double, size_t> pair0( perturbationScore(Ao, scores), 0 ); | ||
minHeap.push(pair0); | ||
// std::pair<double, size_t> pair0( perturbationScore(Ao, scores), 0 ); | ||
// minHeap.push(pair0); | ||
minHeap.push( std::make_pair(perturbationScore(Ao, scores), 0) ); | ||
|
||
double prevScore = 0; // store score of smallest inserted vector (for assert) | ||
// loop invariable: after pvec iterations, additionalProbingBins contains pvec | ||
// valid codes of the highest-scoring bins | ||
// valid codes of the lowest-scoring bins (bins most likely to contain | ||
// neighbors of the query). | ||
for (size_t pvec = 0; pvec < T; ++pvec) | ||
{ | ||
std::vector<size_t> Ai; | ||
|
@@ -587,17 +565,17 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins( | |
Ai = perturbationSets[ minHeap.top().second ]; | ||
minHeap.pop(); // .top() returns, .pop() removes | ||
|
||
|
||
// modify Ai (shift) | ||
std::vector<size_t> As = Ai; | ||
perturbationShift(As); | ||
if ( perturbationValid(As, numProj) ) | ||
{ | ||
perturbationSets.push_back(As); // add shifted set to sets | ||
std::pair<double, size_t> shifted( | ||
perturbationScore(As, scores), | ||
perturbationSets.size() - 1); // (score, position) pair for shift | ||
minHeap.push(shifted); | ||
minHeap.push( | ||
std::make_pair( | ||
perturbationScore(As, scores), | ||
perturbationSets.size() - 1) | ||
); | ||
} | ||
|
||
// modify Ai (expand) | ||
|
@@ -606,20 +584,15 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins( | |
if ( perturbationValid(Ae, numProj) ) | ||
{ | ||
perturbationSets.push_back(Ae); // add expanded set to sets | ||
std::pair<double, size_t> expanded( | ||
perturbationScore(Ae, scores), | ||
perturbationSets.size() - 1); // (score, position) pair for expand | ||
minHeap.push(expanded); | ||
minHeap.push( | ||
std::make_pair( | ||
perturbationScore(Ae, scores), | ||
perturbationSets.size() - 1) | ||
); | ||
} | ||
|
||
|
||
}while (! perturbationValid(Ai, numProj) );//Discard invalid perturbations | ||
|
||
// a valid perturbation must have higher score than previous valid ones, | ||
// meaning the bin it corresponds to is less likely to hold neighbors | ||
assert ( perturbationScore(Ai, scores) >= prevScore ); | ||
prevScore = perturbationScore(Ai, scores); | ||
|
||
// add perturbation vector to probing sequence if valid | ||
for (size_t i = 0; i < Ai.size(); ++i) | ||
additionalProbingBins(positions(Ai[i]), pvec) += actions(Ai[i]); | ||
|
@@ -658,18 +631,20 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable( | |
queryCodesNotFloored += offsets.cols(0, numTablesToSearch - 1); | ||
allProjInTables = arma::floor(queryCodesNotFloored/hashWidth); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should put a space on either side of the division operator |
||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Extra line |
||
// Compute the hash value of each key of the query into a bucket of the | ||
// 'secondHashTable' using the 'secondHashWeights'. | ||
arma::rowvec hashVec = secondHashWeights.t() * allProjInTables; | ||
|
||
// mod and floor hashVec to compute 2nd-level codes | ||
arma::Row<size_t> hashVec = | ||
arma::conv_to< arma::Row<size_t> >:: | ||
from( secondHashWeights.t() * allProjInTables ); // typecast to floor | ||
// mod to compute 2nd-level codes | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again I feel really pedantic saying this but if this comment can be a complete sentence it would be more clear. :) |
||
for (size_t i = 0; i < hashVec.n_elem; i++) | ||
hashVec[i] = (double) ((size_t) hashVec[i] % secondHashSize); | ||
hashVec[i] = (hashVec[i] % secondHashSize); | ||
|
||
Log::Assert(hashVec.n_elem == numTablesToSearch); | ||
|
||
// Compute hashVectors of additional probing bins | ||
arma::mat hashMat; | ||
arma::Mat<size_t> hashMat; | ||
if (T > 0) | ||
{ | ||
hashMat.zeros(T, numTablesToSearch); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't this be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I concatenate a (T x numTablesToSearch) matrix with a (1 x numTablesToSearch) vector, won't I get a T+1 row matrix? I understood from armadillo documnetation that |
||
|
@@ -678,18 +653,17 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable( | |
{ | ||
// construct this table's probing sequence of length T | ||
arma::mat additionalProbingBins; | ||
//arma::vec dummyBins; | ||
//dummyBins.zeros(T, 1); | ||
GetAdditionalProbingBins(allProjInTables.unsafe_col(i), | ||
queryCodesNotFloored.unsafe_col(i), | ||
T, | ||
additionalProbingBins); | ||
|
||
// map the probing bin to second hash table bins | ||
hashMat.col(i) = additionalProbingBins.t() * secondHashWeights; | ||
//hashMat.col(i) = dummyBins; | ||
hashMat.col(i) = | ||
arma::conv_to< arma::Col<size_t> >:: | ||
from(additionalProbingBins.t() * secondHashWeights); // typecast floor | ||
for (size_t p = 0; p < T; ++p) | ||
hashMat(p, i) = (double) ((size_t) hashMat(p, i) % secondHashSize); | ||
hashMat(p, i) = (hashMat(p, i) % secondHashSize); | ||
} | ||
|
||
// top row of hashMat is primary bins for each table | ||
|
@@ -772,7 +746,7 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable( | |
{ | ||
for (size_t p = 0; p < T + 1; ++p) | ||
{ | ||
size_t hashInd = (size_t) hashMat(p, i); // Find the query's bucket. | ||
size_t hashInd = hashMat(p, i); // Find the query's bucket. | ||
|
||
if (bucketContentSize[hashInd] > 0) | ||
{ | ||
|
@@ -801,7 +775,7 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet, | |
arma::Mat<size_t>& resultingNeighbors, | ||
arma::mat& distances, | ||
const size_t numTablesToSearch, | ||
size_t T) | ||
const size_t T) | ||
{ | ||
// Ensure the dimensionality of the query set is correct. | ||
if (querySet.n_rows != referenceSet->n_rows) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The spacing seems odd here; if we can't fit the arguments on the same line as the function name, then we should just indent them four spaces from the start of the line.