Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of Multiprobe LSH #691

Merged
merged 29 commits into from Jun 30, 2016
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
8dd409d
Implementation of Multiprobe LSH, version 1
mentekid Jun 8, 2016
8264347
Merge branch 'lsh-computeRecall' into lsh-multiprobe
mentekid Jun 8, 2016
f24ac2c
Merge branch 'lsh-deterministictest' into lsh-multiprobe
mentekid Jun 8, 2016
27510fe
There is a bug in ReturnIndicesFromTable
mentekid Jun 8, 2016
a48c842
There is a bug in ReturnIndicesFromTable
mentekid Jun 8, 2016
4adaf8f
Fixes two minor bugs causing major headaches
mentekid Jun 8, 2016
8bc5ced
Adds a first multiprobe test
mentekid Jun 9, 2016
a048f46
Merge branch 'lsh-computeRecall' into lsh-multiprobe
mentekid Jun 9, 2016
d332ea1
Add code that replaces multiprobe codes with zeros for bottleneck pro…
mentekid Jun 13, 2016
00377a8
Fixes minor typo that caused multiprobe test to not happen
mentekid Jun 13, 2016
f0638a0
Adds a deterministic test for multiprobe
mentekid Jun 14, 2016
6d11521
Removes some redundant code, fixes comments
mentekid Jun 14, 2016
fa7f62d
Fixes style, adds documentation
mentekid Jun 17, 2016
ae81ee5
Fixes style issues, optimizes code a bit
mentekid Jun 22, 2016
840efe9
Merged changes from #690 and #675
mentekid Jun 22, 2016
2af985c
Fixes bug caused merge. Changes code in lsh_test.cpp that caused test…
mentekid Jun 23, 2016
75dead3
Uses arma::Row<char> instead of std::vector for perturbation sets
mentekid Jun 23, 2016
89b3c7b
Fixes bug in perturbationValid
mentekid Jun 23, 2016
79b954a
Replaces arma::Row<bool> with std::vector<bool> to conserve space
mentekid Jun 23, 2016
e2596c5
Workaround to avoid copy of hashMat
mentekid Jun 23, 2016
5d603b2
More style fixes
mentekid Jun 24, 2016
039af23
Typo fix
mentekid Jun 24, 2016
71eda99
Removes MultiprobeDeterministicTest because it is not correct.
mentekid Jun 24, 2016
2aa839c
Makes Perturbation functions members of LSHSearch
mentekid Jun 28, 2016
cc6a5c2
Fixes a lot of style issues
mentekid Jun 28, 2016
4ba6f6d
Re-adds Deterministic Multiprobe Test
mentekid Jun 28, 2016
71edfbb
Merge branch 'master' into lsh-multiprobe
mentekid Jun 29, 2016
d418c09
Merge branch 'master' into lsh-multiprobe
mentekid Jun 29, 2016
f1e11fd
Fixes style issues in LSH Tests and LSH Class
mentekid Jun 29, 2016
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/mlpack/methods/lsh/lsh_search.hpp
Expand Up @@ -154,13 +154,15 @@ class LSHSearch
* available without having to build hashing for every table size.
* By default, this is set to zero in which case all tables are
* considered.
* @param T The number of additional probing bins to examine with multiprobe
* LSH. If T = 0, classic single-probe LSH is run (default).
*/
void Search(const arma::mat& querySet,
const size_t k,
arma::Mat<size_t>& resultingNeighbors,
arma::mat& distances,
const size_t numTablesToSearch = 0,
size_t T = 0);
const size_t T = 0);

/**
* Compute the nearest neighbors and store the output in the given matrices.
Expand Down
122 changes: 48 additions & 74 deletions src/mlpack/methods/lsh/lsh_search_impl.hpp
Expand Up @@ -340,25 +340,11 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
referenceIndex, distance);
}


// Compare class for <double, size_t> pair, used in GetAdditionalProbingBins.
class CompareGreater
{
public:
bool operator()(
std::pair<double, size_t> p1,
std::pair<double, size_t> p2){
//only compare the double values
return p1.first > p2.first;
}
};

//Returns the score of a perturbation vector generated by perturbation set A.
//The score of a pertubation set (vector) is the sum of scores of the
//participating actions.
inline double perturbationScore(
const std::vector<size_t> &A,
const arma::vec &scores)
inline double perturbationScore(const std::vector<size_t>& A,
const arma::vec& scores)
{
double score = 0.0;
for (size_t i = 0; i < A.size(); ++i)
Expand All @@ -368,7 +354,7 @@ inline double perturbationScore(

// Inline function used by GetAdditionalProbingBins. The vector shift operation
// replaces the largest element of a vector A with (largest element) + 1.
inline void perturbationShift(std::vector<size_t> &A)
inline void perturbationShift(std::vector<size_t>& A)
{
size_t max_pos = 0;
size_t max = A[0];
Expand All @@ -386,56 +372,49 @@ inline void perturbationShift(std::vector<size_t> &A)
// Inline function used by GetAdditionalProbingBins. The vector expansion
// operation adds the element [1 + (largest_element)] to a vector A, where
// largest_element is the largest element of A.
inline void perturbationExpand(std::vector<size_t> &A)
inline void perturbationExpand(std::vector<size_t>& A)
{
size_t max = A[0];
for (size_t i = 1; i < A.size(); ++i)
if (A[i] > max)
max = A[i];
A.push_back(max+1);
A.push_back(max + 1);
}

// Return true if perturbation set A is valid. A perturbation set is invalid if
// it contains two (or more) actions for the same dimension or dimensions that
// are larger than the queryCode's dimensions.
inline bool perturbationValid(
const std::vector<size_t> &A,
const size_t numProj)
inline bool perturbationValid(const std::vector<size_t>& A,
const size_t numProj)
{
// Stack allocation and initialization to 0 (bool check[numProj] = {0}) made
// some compilers complain. We use new just to be safe.
bool *check = new bool[numProj]();
// some compilers complain, and std::vector might even be compressed (depends
// on implementation) so this saves some space.
std::vector<bool> check(numProj);

for (size_t i = 0; i < A.size(); ++i)
{
// Check that we only use valid dimensions. If not, vector is not valid.
if ( A[i] >= 2*numProj)
{
delete []check;
if (A[i] >= 2 * numProj)
return false;
}

// Check that we only see each dimension once. If not, vector is not valid.
if (check[A[i] % numProj ] == 0)
check[A[i] % numProj ] = 1;
else
{
delete []check;
return false;
}
}
delete []check;
return true;
}


// Compute additional probing bins for a query
template<typename SortPolicy>
void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
const arma::vec &queryCode,
const arma::vec &queryCodeNotFloored,
const arma::vec& queryCode,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The spacing seems odd here; if we can't fit the arguments on the same line as the function name, then we should just indent them four spaces from the start of the line.

const arma::vec& queryCodeNotFloored,
const size_t T,
arma::mat &additionalProbingBins) const
arma::mat& additionalProbingBins) const
{

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extra line :)

// No additional bins requested. Our work is done.
Expand All @@ -461,15 +440,15 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
// calculate scores = distances^2
arma::vec scores(2 * numProj);
scores.rows(0, numProj - 1) = arma::pow(limLow, 2);
scores.rows(numProj, 2 * numProj - 1) = arma::pow(limHigh, 2);
scores.rows(numProj, (2 * numProj) - 1) = arma::pow(limHigh, 2);

// actions vector shows what transformation to apply to a coordinate
arma::Col<short int> actions(2 * numProj); // will be [-1 ... 1 ...]

actions.rows(0, numProj - 1) = // first numProj rows
-1 * arma::ones< arma::Col<short int> > (numProj); // -1s

actions.rows(numProj, 2 * numProj - 1) = // last numProj rows
actions.rows(numProj, (2 * numProj) - 1) = // last numProj rows
arma::ones< arma::Col<short int> > (numProj); // 1s


Expand All @@ -491,7 +470,7 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
// find location and value of smallest element of scores vector
double minscore = scores[0];
size_t minloc = 0;
for (size_t s = 1; s < 2 * numProj; ++s)
for (size_t s = 1; s < (2 * numProj); ++s)
{
if (minscore > scores[s])
{
Expand All @@ -518,7 +497,7 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(

double minscore2 = scores[0];
size_t minloc2 = 0;
for (size_t s = 0; s < 2 * numProj; ++s) // here we can't start from 1
for (size_t s = 0; s < (2 * numProj); ++s) // here we can't start from 1
{
if ( minscore2 > scores[s] && s != minloc) //second smallest
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extra space after opening parenthesis.

{
Expand All @@ -530,8 +509,8 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
// add or subtract 1 to create second-lowest scoring vector
additionalProbingBins(positions[minloc2], 1) += actions[minloc2];
return;

}
// General case: more than 2 perturbation vectors require use of minheap.

// sort everything in increasing order
arma::Col<long long unsigned int> sortidx = arma::sort_index(scores);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why long long unsigned int? Either size_t or possibly arma::uword (I prefer the former but sometimes it is not possible) should work here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep - I had trouble when using size_t, the compiler didn't like it. I replaced it with arma::uword. I'm not sure why size_t is not allowed...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was mentioned somewhere else, so forgive any redundancy, but probably arma::uword should be used here instead of long long unsigned int. You can just use arma::uvec since that is a typedef of arma::Col<arma::uword>.

Expand Down Expand Up @@ -561,23 +540,22 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
std::vector< std::vector<size_t> > perturbationSets;
perturbationSets.push_back(Ao); // storage of perturbation sets


// define a priority queue with CompareGreater as a minheap
std::priority_queue<
std::pair<double, size_t>, // contents: pairs of (score, index)
std::vector< // container: vector of pairs
std::pair<double, size_t>
>,
mlpack::neighbor::CompareGreater // comparator of pairs(compare scores)
std::greater< std::pair<double, size_t> > // comparator of pairs(compare scores)
> minHeap; // our minheap

// Start by adding the lowest scoring set to the minheap
std::pair<double, size_t> pair0( perturbationScore(Ao, scores), 0 );
minHeap.push(pair0);
// std::pair<double, size_t> pair0( perturbationScore(Ao, scores), 0 );
// minHeap.push(pair0);
minHeap.push( std::make_pair(perturbationScore(Ao, scores), 0) );

double prevScore = 0; // store score of smallest inserted vector (for assert)
// loop invariable: after pvec iterations, additionalProbingBins contains pvec
// valid codes of the highest-scoring bins
// valid codes of the lowest-scoring bins (bins most likely to contain
// neighbors of the query).
for (size_t pvec = 0; pvec < T; ++pvec)
{
std::vector<size_t> Ai;
Expand All @@ -587,17 +565,17 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
Ai = perturbationSets[ minHeap.top().second ];
minHeap.pop(); // .top() returns, .pop() removes


// modify Ai (shift)
std::vector<size_t> As = Ai;
perturbationShift(As);
if ( perturbationValid(As, numProj) )
{
perturbationSets.push_back(As); // add shifted set to sets
std::pair<double, size_t> shifted(
perturbationScore(As, scores),
perturbationSets.size() - 1); // (score, position) pair for shift
minHeap.push(shifted);
minHeap.push(
std::make_pair(
perturbationScore(As, scores),
perturbationSets.size() - 1)
);
}

// modify Ai (expand)
Expand All @@ -606,20 +584,15 @@ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
if ( perturbationValid(Ae, numProj) )
{
perturbationSets.push_back(Ae); // add expanded set to sets
std::pair<double, size_t> expanded(
perturbationScore(Ae, scores),
perturbationSets.size() - 1); // (score, position) pair for expand
minHeap.push(expanded);
minHeap.push(
std::make_pair(
perturbationScore(Ae, scores),
perturbationSets.size() - 1)
);
}


}while (! perturbationValid(Ai, numProj) );//Discard invalid perturbations

// a valid perturbation must have higher score than previous valid ones,
// meaning the bin it corresponds to is less likely to hold neighbors
assert ( perturbationScore(Ai, scores) >= prevScore );
prevScore = perturbationScore(Ai, scores);

// add perturbation vector to probing sequence if valid
for (size_t i = 0; i < Ai.size(); ++i)
additionalProbingBins(positions(Ai[i]), pvec) += actions(Ai[i]);
Expand Down Expand Up @@ -658,18 +631,20 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
queryCodesNotFloored += offsets.cols(0, numTablesToSearch - 1);
allProjInTables = arma::floor(queryCodesNotFloored/hashWidth);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should put a space on either side of the division operator /.



Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extra line

// Compute the hash value of each key of the query into a bucket of the
// 'secondHashTable' using the 'secondHashWeights'.
arma::rowvec hashVec = secondHashWeights.t() * allProjInTables;

// mod and floor hashVec to compute 2nd-level codes
arma::Row<size_t> hashVec =
arma::conv_to< arma::Row<size_t> >::
from( secondHashWeights.t() * allProjInTables ); // typecast to floor
// mod to compute 2nd-level codes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again I feel really pedantic saying this but if this comment can be a complete sentence it would be more clear. :)

for (size_t i = 0; i < hashVec.n_elem; i++)
hashVec[i] = (double) ((size_t) hashVec[i] % secondHashSize);
hashVec[i] = (hashVec[i] % secondHashSize);

Log::Assert(hashVec.n_elem == numTablesToSearch);

// Compute hashVectors of additional probing bins
arma::mat hashMat;
arma::Mat<size_t> hashMat;
if (T > 0)
{
hashMat.zeros(T, numTablesToSearch);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be T + 1 rows, since T is the number of additional bins that are probed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I concatenate a (T x numTablesToSearch) matrix with a (1 x numTablesToSearch) vector, won't I get a T+1 row matrix?

I understood from armadillo documnetation that join_vert resizes hashMat. Am I throwing the last vector away like that?

Expand All @@ -678,18 +653,17 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
{
// construct this table's probing sequence of length T
arma::mat additionalProbingBins;
//arma::vec dummyBins;
//dummyBins.zeros(T, 1);
GetAdditionalProbingBins(allProjInTables.unsafe_col(i),
queryCodesNotFloored.unsafe_col(i),
T,
additionalProbingBins);

// map the probing bin to second hash table bins
hashMat.col(i) = additionalProbingBins.t() * secondHashWeights;
//hashMat.col(i) = dummyBins;
hashMat.col(i) =
arma::conv_to< arma::Col<size_t> >::
from(additionalProbingBins.t() * secondHashWeights); // typecast floor
for (size_t p = 0; p < T; ++p)
hashMat(p, i) = (double) ((size_t) hashMat(p, i) % secondHashSize);
hashMat(p, i) = (hashMat(p, i) % secondHashSize);
}

// top row of hashMat is primary bins for each table
Expand Down Expand Up @@ -772,7 +746,7 @@ void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
{
for (size_t p = 0; p < T + 1; ++p)
{
size_t hashInd = (size_t) hashMat(p, i); // Find the query's bucket.
size_t hashInd = hashMat(p, i); // Find the query's bucket.

if (bucketContentSize[hashInd] > 0)
{
Expand Down Expand Up @@ -801,7 +775,7 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
arma::Mat<size_t>& resultingNeighbors,
arma::mat& distances,
const size_t numTablesToSearch,
size_t T)
const size_t T)
{
// Ensure the dimensionality of the query set is correct.
if (querySet.n_rows != referenceSet->n_rows)
Expand Down