-
Notifications
You must be signed in to change notification settings - Fork 100
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Integrate recall benchmark functionality
- Loading branch information
Showing
2 changed files
with
253 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
224 changes: 224 additions & 0 deletions
224
src/main/java/com/github/jbellis/jvector/example/Bench.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.github.jbellis.jvector.example; | ||
|
||
import com.github.jbellis.jvector.graph.*; | ||
import com.github.jbellis.jvector.vector.VectorEncoding; | ||
import com.github.jbellis.jvector.vector.VectorSimilarityFunction; | ||
import com.github.jbellis.jvector.vector.VectorUtil; | ||
import io.jhdf.HdfFile; | ||
|
||
import java.io.IOException; | ||
import java.nio.file.Paths; | ||
import java.util.*; | ||
import java.util.concurrent.ExecutionException; | ||
import java.util.concurrent.atomic.LongAdder; | ||
import java.util.stream.Collectors; | ||
import java.util.stream.IntStream; | ||
|
||
/** | ||
* Tests GraphIndexes against vectors from various datasets | ||
*/ | ||
public class Bench { | ||
private static void testRecall(int M, int efConstruction, List<Integer> efSearchOptions, DataSet ds) | ||
{ | ||
var ravv = new ListRandomAccessVectorValues(ds.baseVectors, ds.baseVectors.get(0).length); | ||
var topK = ds.groundTruth.get(0).size(); | ||
|
||
// build the graphs on multiple threads | ||
var start = System.nanoTime(); | ||
var builder = new GraphIndexBuilder<>(ravv, VectorEncoding.FLOAT32, ds.similarityFunction, M, efConstruction, 1.5f, 1.4f); | ||
var index = builder.build(); | ||
long buildNanos = System.nanoTime() - start; | ||
|
||
int queryRuns = 10; | ||
for (int overquery : efSearchOptions) { | ||
start = System.nanoTime(); | ||
var pqr = performQueries(ds, ravv, index, topK, topK * overquery, queryRuns); | ||
var recall = ((double) pqr.topKFound) / (queryRuns * ds.queryVectors.size() * topK); | ||
System.out.format("Index M=%d ef=%d: top %d/%d recall %.4f, build %.2fs, query %.2fs. %s nodes visited%n", | ||
M, efConstruction, topK, overquery, recall, buildNanos / 1_000_000_000.0, (System.nanoTime() - start) / 1_000_000_000.0, pqr.nodesVisited); | ||
} | ||
} | ||
|
||
private static float normOf(float[] baseVector) { | ||
float norm = 0; | ||
for (float v : baseVector) { | ||
norm += v * v; | ||
} | ||
return (float) Math.sqrt(norm); | ||
} | ||
|
||
private record ResultSummary(int topKFound, int nodesVisited) { } | ||
|
||
private static long topKCorrect(int topK, int[] resultNodes, Set<Integer> gt) { | ||
int count = Math.min(resultNodes.length, topK); | ||
// stream the first count results into a Set | ||
var resultSet = Arrays.stream(resultNodes, 0, count) | ||
.boxed() | ||
.collect(Collectors.toSet()); | ||
assert resultSet.size() == count : String.format("%s duplicate results out of %s", count - resultSet.size(), count); | ||
return resultSet.stream().filter(gt::contains).count(); | ||
} | ||
|
||
private static long topKCorrect(int topK, NeighborQueue nn, Set<Integer> gt) { | ||
var a = new int[nn.size()]; | ||
for (int j = a.length - 1; j >= 0; j--) { | ||
a[j] = nn.pop(); | ||
} | ||
return topKCorrect(topK, a, gt); | ||
} | ||
|
||
private static ResultSummary performQueries(DataSet ds, ListRandomAccessVectorValues ravv, GraphIndex index, int topK, int efSearch, int queryRuns) { | ||
assert efSearch >= topK; | ||
LongAdder topKfound = new LongAdder(); | ||
LongAdder nodesVisited = new LongAdder(); | ||
for (int k = 0; k < queryRuns; k++) { | ||
IntStream.range(0, ds.queryVectors.size()).parallel().forEach(i -> { | ||
var queryVector = ds.queryVectors.get(i); | ||
NeighborQueue nn; | ||
nn = GraphSearcher.search(queryVector, efSearch, ravv, VectorEncoding.FLOAT32, ds.similarityFunction, index, null, Integer.MAX_VALUE); | ||
var gt = ds.groundTruth.get(i); | ||
var n = topKCorrect(topK, nn, gt); | ||
topKfound.add(n); | ||
nodesVisited.add(nn.visitedCount()); | ||
}); | ||
} | ||
return new ResultSummary((int) topKfound.sum(), (int) nodesVisited.sum()); | ||
} | ||
|
||
record DataSet(VectorSimilarityFunction similarityFunction, List<float[]> baseVectors, List<float[]> queryVectors, List<Set<Integer>> groundTruth) { } | ||
|
||
private static DataSet load(String pathStr) { | ||
// infer the similarity | ||
VectorSimilarityFunction similarityFunction; | ||
if (pathStr.contains("angular")) { | ||
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; | ||
} else if (pathStr.contains("euclidean")) { | ||
similarityFunction = VectorSimilarityFunction.EUCLIDEAN; | ||
} else { | ||
throw new IllegalArgumentException("Unknown similarity function -- expected angular or euclidean for " + pathStr); | ||
} | ||
|
||
// read the data | ||
float[][] baseVectors; | ||
float[][] queryVectors; | ||
int[][] groundTruth; | ||
try (HdfFile hdf = new HdfFile(Paths.get(pathStr))) { | ||
baseVectors = (float[][]) hdf.getDatasetByPath("train").getData(); | ||
queryVectors = (float[][]) hdf.getDatasetByPath("test").getData(); | ||
groundTruth = (int[][]) hdf.getDatasetByPath("neighbors").getData(); | ||
} | ||
|
||
List<float[]> scrubbedBaseVectors; | ||
List<float[]> scrubbedQueryVectors; | ||
List<Set<Integer>> gtSet; | ||
if (similarityFunction == VectorSimilarityFunction.DOT_PRODUCT) { | ||
// verify that vectors are normalized and sane | ||
scrubbedBaseVectors = new ArrayList<>(baseVectors.length); | ||
scrubbedQueryVectors = new ArrayList<>(queryVectors.length); | ||
gtSet = new ArrayList<>(groundTruth.length); | ||
// remove zero vectors, noting that this will change the indexes of the ground truth answers | ||
Map<Integer, Integer> rawToScrubbed = new HashMap<>(); | ||
{ | ||
int j = 0; | ||
for (int i = 0; i < baseVectors.length; i++) { | ||
float[] v = baseVectors[i]; | ||
if (Math.abs(normOf(v)) > 1e-5) { | ||
scrubbedBaseVectors.add(v); | ||
rawToScrubbed.put(i, j++); | ||
} | ||
} | ||
} | ||
for (int i = 0; i < queryVectors.length; i++) { | ||
float[] v = queryVectors[i]; | ||
if (Math.abs(normOf(v)) > 1e-5) { | ||
scrubbedQueryVectors.add(v); | ||
var gt = new HashSet<Integer>(); | ||
for (int j = 0; j < groundTruth[i].length; j++) { | ||
gt.add(rawToScrubbed.get(groundTruth[i][j])); | ||
} | ||
gtSet.add(gt); | ||
} | ||
} | ||
// now that the zero vectors are removed, we can normalize | ||
if (Math.abs(normOf(baseVectors[0]) - 1.0) > 1e-5) { | ||
normalizeAll(scrubbedBaseVectors); | ||
normalizeAll(scrubbedQueryVectors); | ||
} | ||
assert scrubbedQueryVectors.size() == gtSet.size(); | ||
} else { | ||
scrubbedBaseVectors = Arrays.asList(baseVectors); | ||
scrubbedQueryVectors = Arrays.asList(queryVectors); | ||
gtSet = new ArrayList<>(groundTruth.length); | ||
for (int[] gt : groundTruth) { | ||
var gtSetForQuery = new HashSet<Integer>(); | ||
for (int i : gt) { | ||
gtSetForQuery.add(i); | ||
} | ||
gtSet.add(gtSetForQuery); | ||
} | ||
} | ||
|
||
System.out.format("%n%s: %d base and %d query vectors loaded, dimensions %d%n", | ||
pathStr, scrubbedBaseVectors.size(), scrubbedQueryVectors.size(), scrubbedBaseVectors.get(0).length); | ||
|
||
return new DataSet(similarityFunction, scrubbedBaseVectors, scrubbedQueryVectors, gtSet); | ||
} | ||
|
||
private static void normalizeAll(Iterable<float[]> vectors) { | ||
for (float[] v : vectors) { | ||
VectorUtil.l2normalize(v); | ||
} | ||
} | ||
|
||
public static void main(String[] args) throws ExecutionException, InterruptedException { | ||
System.out.println("Heap space available is " + Runtime.getRuntime().maxMemory()); | ||
var files = List.of( | ||
"hdf5/nytimes-256-angular.hdf5", | ||
"hdf5/glove-100-angular.hdf5", | ||
"hdf5/glove-200-angular.hdf5", | ||
"hdf5/sift-128-euclidean.hdf5"); | ||
var mGrid = List.of(8, 12, 16, 24, 32, 48, 64); | ||
var efConstructionGrid = List.of(60, 80, 100, 120, 160, 200, 400, 600, 800); | ||
var efSearchFactor = List.of(1, 2, 4); | ||
// large files not yet supported | ||
// "hdf5/deep-image-96-angular.hdf5", | ||
// "hdf5/gist-960-euclidean.hdf5"); | ||
for (var f : files) { | ||
gridSearch(f, mGrid, efConstructionGrid, efSearchFactor); | ||
} | ||
|
||
// tiny dataset, don't waste time building a huge index | ||
files = List.of("hdf5/fashion-mnist-784-euclidean.hdf5"); | ||
mGrid = List.of(8, 12, 16, 24); | ||
efConstructionGrid = List.of(40, 60, 80, 100, 120, 160); | ||
for (var f : files) { | ||
gridSearch(f, mGrid, efConstructionGrid, efSearchFactor); | ||
} | ||
} | ||
|
||
private static void gridSearch(String f, List<Integer> mGrid, List<Integer> efConstructionGrid, List<Integer> efSearchFactor) throws ExecutionException, InterruptedException { | ||
var ds = load(f); | ||
for (int M : mGrid) { | ||
for (int beamWidth : efConstructionGrid) { | ||
testRecall(M, beamWidth, efSearchFactor, ds); | ||
} | ||
} | ||
} | ||
} |