Skip to content

Commit

Permalink
Integrate recall benchmark functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
jkni committed Aug 28, 2023
1 parent 080c25d commit 0401011
Show file tree
Hide file tree
Showing 2 changed files with 253 additions and 0 deletions.
29 changes: 29 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@
<artifactId>high-scale-lib</artifactId>
<version>1.0.6</version>
</dependency>
<!-- HDF5 reader: the Bench example uses io.jhdf.HdfFile to load the .hdf5 benchmark datasets -->
<dependency>
<groupId>io.jhdf</groupId>
<artifactId>jhdf</artifactId>
<version>0.6.10</version>
</dependency>
</dependencies>

<build>
Expand All @@ -72,6 +77,30 @@
</compilerArgs>
</configuration>
</plugin>
<!-- Launches the recall benchmark (com.github.jbellis.jvector.example.Bench) as an external
     java process through the "bench" execution, e.g. mvn exec:exec@bench. -->
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>exec-maven-plugin</artifactId>
<version>3.1.0</version>
<executions>
<execution>
<id>bench</id>
<goals>
<goal>exec</goal>
</goals>
<configuration>
<executable>java</executable>
<arguments>
<!-- the empty classpath element is filled in by the plugin with the project's classpath -->
<argument>-classpath</argument>
<classpath/>
<!-- Disabled: <argument>- -add-modules=jdk.incubator.vector</argument>
     Re-enable once the vectorization provider is fixed. ("- -" stands for a double hyphen,
     which cannot be written inside an XML comment.) -->
<!-- large heap for the benchmark, assertions enabled -->
<argument>-Xmx32G</argument>
<argument>-ea</argument>
<argument>com.github.jbellis.jvector.example.Bench</argument>
</arguments>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
224 changes: 224 additions & 0 deletions src/main/java/com/github/jbellis/jvector/example/Bench.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.github.jbellis.jvector.example;

import com.github.jbellis.jvector.graph.*;
import com.github.jbellis.jvector.vector.VectorEncoding;
import com.github.jbellis.jvector.vector.VectorSimilarityFunction;
import com.github.jbellis.jvector.vector.VectorUtil;
import io.jhdf.HdfFile;

import java.io.IOException;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
* Tests GraphIndexes against vectors from various datasets
*/
public class Bench {
/**
 * Builds a single graph index over the dataset's base vectors and reports recall for each
 * requested overquery factor.
 * <p>
 * The index is built once per call; for every entry of {@code efSearchOptions} the full query
 * set is searched ten times with a search width of {@code topK * factor}, and one summary
 * line is printed to standard output.
 *
 * @param M               graph connectivity parameter handed to {@code GraphIndexBuilder}
 * @param efConstruction  construction-time beam width handed to {@code GraphIndexBuilder}
 * @param efSearchOptions multipliers applied to {@code topK} to obtain the search width
 * @param ds              dataset providing base vectors, query vectors and ground truth
 */
private static void testRecall(int M, int efConstruction, List<Integer> efSearchOptions, DataSet ds) {
    final double nanosPerSecond = 1_000_000_000.0;
    final int queryRuns = 10;

    // dimension is taken from the first base vector; topK from the first ground-truth set
    int dimension = ds.baseVectors().get(0).length;
    var vectors = new ListRandomAccessVectorValues(ds.baseVectors(), dimension);
    int topK = ds.groundTruth().get(0).size();

    // time the index construction (the original author notes the build is multi-threaded);
    // 1.5f and 1.4f are tuning parameters whose meaning is defined by GraphIndexBuilder
    long clock = System.nanoTime();
    var indexBuilder = new GraphIndexBuilder<>(vectors, VectorEncoding.FLOAT32, ds.similarityFunction(), M, efConstruction, 1.5f, 1.4f);
    var index = indexBuilder.build();
    double buildSeconds = (System.nanoTime() - clock) / nanosPerSecond;

    for (int factor : efSearchOptions) {
        clock = System.nanoTime();
        var summary = performQueries(ds, vectors, index, topK, topK * factor, queryRuns);
        // fraction of all expected ground-truth hits that were actually returned
        int expectedHits = queryRuns * ds.queryVectors().size() * topK;
        double recall = (double) summary.topKFound() / expectedHits;
        double querySeconds = (System.nanoTime() - clock) / nanosPerSecond;
        System.out.format("Index M=%d ef=%d: top %d/%d recall %.4f, build %.2fs, query %.2fs. %s nodes visited%n",
                M, efConstruction, topK, factor, recall, buildSeconds, querySeconds, summary.nodesVisited());
    }
}

/**
 * Computes the Euclidean (L2) norm of a vector.
 * <p>
 * Each component is widened to {@code double} before it is squared and summed, so large
 * components do not overflow to infinity and small contributions are not lost to
 * {@code float} rounding. Only the final square root is narrowed back to {@code float}.
 *
 * @param baseVector the vector to measure; an empty array yields {@code 0}
 * @return the square root of the sum of the squared components
 */
private static float normOf(float[] baseVector) {
    double sumOfSquares = 0;
    for (float v : baseVector) {
        sumOfSquares += (double) v * v;
    }
    return (float) Math.sqrt(sumOfSquares);
}

/**
 * Aggregated outcome of a batch of queries.
 * <p>
 * Both counters are {@code long} because they are summed over every query of every run;
 * the visited-node total in particular can exceed {@code Integer.MAX_VALUE}, and narrowing
 * it to {@code int} would silently wrap around.
 *
 * @param topKFound    total number of returned results that were present in the ground truth
 * @param nodesVisited total of {@code visitedCount()} over all searches
 */
private record ResultSummary(long topKFound, long nodesVisited) { }

/**
 * Counts how many of the first {@code topK} result nodes appear in the ground truth.
 *
 * @param topK        maximum number of leading results to consider
 * @param resultNodes node ids returned by a search; only the first {@code topK} are examined
 * @param gt          ground-truth node ids for the query
 * @return number of examined results contained in {@code gt}
 */
private static long topKCorrect(int topK, int[] resultNodes, Set<Integer> gt) {
    int count = Math.min(resultNodes.length, topK);
    // stream the first count results into a Set
    var resultSet = Arrays.stream(resultNodes, 0, count)
            .boxed()
            .collect(Collectors.toSet());
    // a search must never return the same node twice
    assert resultSet.size() == count : String.format("%s duplicate results out of %s", count - resultSet.size(), count);
    return resultSet.stream().filter(gt::contains).count();
}

/**
 * Drains {@code nn} into an array and counts how many of its first {@code topK} entries
 * appear in the ground truth.
 * <p>
 * The array is filled from the last slot to the first, so the node popped last ends up at
 * index 0. NOTE(review): presumably {@code pop()} yields the worst remaining candidate
 * first, which makes the array best-first; confirm against {@code NeighborQueue}.
 * The queue is empty when this method returns.
 *
 * @param topK maximum number of leading results to consider
 * @param nn   search result queue; consumed by this call
 * @param gt   ground-truth node ids for the query
 * @return number of examined results contained in {@code gt}
 */
private static long topKCorrect(int topK, NeighborQueue nn, Set<Integer> gt) {
    var a = new int[nn.size()];
    for (int j = a.length - 1; j >= 0; j--) {
        a[j] = nn.pop();
    }
    return topKCorrect(topK, a, gt);
}

/**
 * Searches the index with every query vector, {@code queryRuns} times over, and totals the
 * number of correct results and the number of visited nodes.
 * <p>
 * Within each run the queries are executed through a parallel stream, so the totals are
 * accumulated in {@link LongAdder}s. The sums are returned as {@code long} without
 * narrowing, so large totals are reported correctly instead of wrapping.
 *
 * @param ds        dataset providing query vectors, ground truth and similarity function
 * @param ravv      random-access view of the base vectors the index was built from
 * @param index     the graph index to search
 * @param topK      number of leading results scored against the ground truth
 * @param efSearch  number of candidates requested from each search; must be at least {@code topK}
 * @param queryRuns how many times the whole query set is executed
 * @return the accumulated counts over all runs and queries
 */
private static ResultSummary performQueries(DataSet ds, ListRandomAccessVectorValues ravv, GraphIndex index, int topK, int efSearch, int queryRuns) {
    assert efSearch >= topK;
    LongAdder topKfound = new LongAdder();
    LongAdder nodesVisited = new LongAdder();
    for (int k = 0; k < queryRuns; k++) {
        IntStream.range(0, ds.queryVectors.size()).parallel().forEach(i -> {
            var queryVector = ds.queryVectors.get(i);
            NeighborQueue nn = GraphSearcher.search(queryVector, efSearch, ravv, VectorEncoding.FLOAT32, ds.similarityFunction, index, null, Integer.MAX_VALUE);
            var gt = ds.groundTruth.get(i);
            // topKCorrect drains nn; visitedCount() is read afterwards, as in the original
            var n = topKCorrect(topK, nn, gt);
            topKfound.add(n);
            nodesVisited.add(nn.visitedCount());
        });
    }
    return new ResultSummary(topKfound.sum(), nodesVisited.sum());
}

/**
 * An in-memory benchmark dataset as produced by {@code load}.
 *
 * @param similarityFunction similarity used both to build and to search the index
 * @param baseVectors        vectors the index is built from
 * @param queryVectors       vectors used as search queries
 * @param groundTruth        for each query vector (same position), the ids of its true nearest base vectors
 */
record DataSet(
        VectorSimilarityFunction similarityFunction,
        List<float[]> baseVectors,
        List<float[]> queryVectors,
        List<Set<Integer>> groundTruth) {
}

/**
 * Reads a benchmark dataset from an HDF5 file.
 * <p>
 * The similarity function is inferred from the path: a path containing {@code "angular"}
 * selects {@code DOT_PRODUCT}, otherwise one containing {@code "euclidean"} selects
 * {@code EUCLIDEAN}. The file must provide the datasets {@code train}, {@code test} and
 * {@code neighbors}.
 * <p>
 * For angular data, near-zero base and query vectors are removed, ground-truth ids are
 * remapped to the positions of the surviving base vectors, and all vectors are normalized
 * in place when the first raw base vector is not already of unit length. Euclidean data is
 * used as read.
 *
 * @param pathStr path of the HDF5 file
 * @return the loaded dataset
 * @throws IllegalArgumentException if the path names neither similarity
 */
private static DataSet load(String pathStr) {
    // the file name tells us which similarity the dataset was prepared for
    boolean angular = pathStr.contains("angular");
    if (!angular && !pathStr.contains("euclidean")) {
        throw new IllegalArgumentException("Unknown similarity function -- expected angular or euclidean for " + pathStr);
    }
    VectorSimilarityFunction similarityFunction = angular
            ? VectorSimilarityFunction.DOT_PRODUCT
            : VectorSimilarityFunction.EUCLIDEAN;

    // pull the three arrays out of the file
    float[][] rawBase;
    float[][] rawQueries;
    int[][] rawNeighbors;
    try (HdfFile hdf = new HdfFile(Paths.get(pathStr))) {
        rawBase = (float[][]) hdf.getDatasetByPath("train").getData();
        rawQueries = (float[][]) hdf.getDatasetByPath("test").getData();
        rawNeighbors = (int[][]) hdf.getDatasetByPath("neighbors").getData();
    }

    List<float[]> scrubbedBaseVectors;
    List<float[]> scrubbedQueryVectors;
    List<Set<Integer>> gtSet;
    if (angular) {
        scrubbedBaseVectors = new ArrayList<>(rawBase.length);
        scrubbedQueryVectors = new ArrayList<>(rawQueries.length);
        gtSet = new ArrayList<>(rawNeighbors.length);

        // drop zero base vectors; this shifts positions, so remember raw index -> new index
        Map<Integer, Integer> rawToScrubbed = new HashMap<>();
        for (int rawIndex = 0; rawIndex < rawBase.length; rawIndex++) {
            float[] candidate = rawBase[rawIndex];
            if (Math.abs(normOf(candidate)) > 1e-5) {
                rawToScrubbed.put(rawIndex, scrubbedBaseVectors.size());
                scrubbedBaseVectors.add(candidate);
            }
        }

        // drop zero query vectors together with their ground truth, remapping the rest
        for (int q = 0; q < rawQueries.length; q++) {
            float[] query = rawQueries[q];
            if (Math.abs(normOf(query)) > 1e-5) {
                scrubbedQueryVectors.add(query);
                Set<Integer> remapped = new HashSet<>();
                for (int rawId : rawNeighbors[q]) {
                    // NOTE(review): an id that referred to a removed zero vector has no
                    // mapping, so null is stored in the set here -- confirm this is intended
                    remapped.add(rawToScrubbed.get(rawId));
                }
                gtSet.add(remapped);
            }
        }

        // with the zero vectors gone, normalizing is safe; the first raw base vector decides
        // whether the data is treated as already normalized
        boolean alreadyNormalized = !(Math.abs(normOf(rawBase[0]) - 1.0) > 1e-5);
        if (!alreadyNormalized) {
            normalizeAll(scrubbedBaseVectors);
            normalizeAll(scrubbedQueryVectors);
        }
        assert scrubbedQueryVectors.size() == gtSet.size();
    } else {
        scrubbedBaseVectors = Arrays.asList(rawBase);
        scrubbedQueryVectors = Arrays.asList(rawQueries);
        gtSet = new ArrayList<>(rawNeighbors.length);
        for (int[] neighbors : rawNeighbors) {
            gtSet.add(Arrays.stream(neighbors).boxed().collect(Collectors.toCollection(HashSet::new)));
        }
    }

    System.out.format("%n%s: %d base and %d query vectors loaded, dimensions %d%n",
            pathStr, scrubbedBaseVectors.size(), scrubbedQueryVectors.size(), scrubbedBaseVectors.get(0).length);

    return new DataSet(similarityFunction, scrubbedBaseVectors, scrubbedQueryVectors, gtSet);
}

/**
 * Passes every vector of the collection to {@code VectorUtil.l2normalize}, in iteration order.
 *
 * @param vectors the vectors to normalize
 */
private static void normalizeAll(Iterable<float[]> vectors) {
    vectors.forEach(vector -> VectorUtil.l2normalize(vector));
}

/**
 * Entry point of the benchmark: prints the maximum heap size, then runs a grid search over
 * every configured dataset.
 * <p>
 * The regular datasets are swept over a wide grid; the small fashion-mnist dataset is
 * swept over a reduced grid afterwards. All datasets share the same overquery factors.
 *
 * @param args ignored
 */
public static void main(String[] args) throws ExecutionException, InterruptedException {
    System.out.println("Heap space available is " + Runtime.getRuntime().maxMemory());

    List<Integer> overqueryFactors = List.of(1, 2, 4);

    // regular datasets; the large files are not supported yet:
    //   "hdf5/deep-image-96-angular.hdf5"
    //   "hdf5/gist-960-euclidean.hdf5"
    List<String> regularDatasets = List.of(
            "hdf5/nytimes-256-angular.hdf5",
            "hdf5/glove-100-angular.hdf5",
            "hdf5/glove-200-angular.hdf5",
            "hdf5/sift-128-euclidean.hdf5");
    List<Integer> regularM = List.of(8, 12, 16, 24, 32, 48, 64);
    List<Integer> regularEfConstruction = List.of(60, 80, 100, 120, 160, 200, 400, 600, 800);
    for (String dataset : regularDatasets) {
        gridSearch(dataset, regularM, regularEfConstruction, overqueryFactors);
    }

    // tiny dataset, so there is no point spending time on a huge index
    List<String> tinyDatasets = List.of("hdf5/fashion-mnist-784-euclidean.hdf5");
    List<Integer> tinyM = List.of(8, 12, 16, 24);
    List<Integer> tinyEfConstruction = List.of(40, 60, 80, 100, 120, 160);
    for (String dataset : tinyDatasets) {
        gridSearch(dataset, tinyM, tinyEfConstruction, overqueryFactors);
    }
}

/**
 * Loads one dataset and runs {@code testRecall} for every combination of {@code M} and
 * construction beam width, with {@code M} as the outer loop.
 *
 * @param f                  path of the HDF5 dataset file
 * @param mGrid              values of {@code M} to try
 * @param efConstructionGrid construction beam widths to try
 * @param efSearchFactor     overquery factors passed through to each recall test
 */
private static void gridSearch(String f, List<Integer> mGrid, List<Integer> efConstructionGrid, List<Integer> efSearchFactor) throws ExecutionException, InterruptedException {
    final DataSet dataSet = load(f);
    mGrid.forEach(m ->
            efConstructionGrid.forEach(efConstruction ->
                    testRecall(m, efConstruction, efSearchFactor, dataSet)));
}
}

0 comments on commit 0401011

Please sign in to comment.