From e043a63cf17b698e4c4d6aacbf954d08554bb697 Mon Sep 17 00:00:00 2001 From: Jonathan Shook Date: Fri, 19 Sep 2025 21:32:52 +0000 Subject: [PATCH 01/12] support OnHeapGraphReconstruction get incremental graph generation working add serialization for NeighborsCache add ord-mapping logic from PQ vectors to RAVV add ord mapping from RAVV to graph creation Signed-off-by: Samuel Herman --- .../jvector/graph/ConcurrentNeighborMap.java | 6 +- .../jvector/graph/GraphIndexBuilder.java | 52 +++- .../jvector/graph/OnHeapGraphIndex.java | 76 ++++- .../graph/disk/NeighborsScoreCache.java | 117 ++++++++ .../graph/similarity/BuildScoreProvider.java | 19 +- .../DefaultSearchScoreProvider.java | 16 ++ .../jvector/quantization/PQVectors.java | 21 +- .../jvector/graph/OnHeapGraphIndexTest.java | 264 ++++++++++++++++++ .../similarity/BuildScoreProviderTest.java | 72 +++++ 9 files changed, 633 insertions(+), 10 deletions(-) create mode 100644 jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/NeighborsScoreCache.java create mode 100644 jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java create mode 100644 jvector-tests/src/test/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProviderTest.java diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java index 891fda756..ad6d137a0 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java @@ -351,7 +351,7 @@ public NeighborWithShortEdges(Neighbors neighbors, double shortEdges) { } } - private static class NeighborIterator implements NodesIterator { + public static class NeighborIterator implements NodesIterator { private final NodeArray neighbors; private int i; @@ -374,5 +374,9 @@ public boolean hasNext() { public int nextInt() { return neighbors.getNode(i++); } + + public NodeArray merge(NodeArray other) { + return NodeArray.merge(neighbors, other); + } } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index 2cff2de4a..f1bd45d5e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -20,6 +20,8 @@ import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.ImmutableGraphIndex.NodeAtLevel; import io.github.jbellis.jvector.graph.SearchResult.NodeScore; +import io.github.jbellis.jvector.graph.disk.NeighborsScoreCache; +import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.diversity.VamanaDiversityProvider; import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; @@ -30,6 +32,7 @@ import io.github.jbellis.jvector.util.PhysicalCoreExecutor; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.types.VectorFloat; +import org.agrona.collections.IntArrayList; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -338,6 +341,50 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, this.rng = new Random(0); } + /** + * Create this builder from an existing {@link io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex}, this is useful when we just loaded a graph from disk + * copy it into {@link OnHeapGraphIndex} and then start mutating it with minimal overhead of recreating the mutable {@link OnHeapGraphIndex} used in the new GraphIndexBuilder object + * + * @param buildScoreProvider the provider responsible for calculating build scores. + * @param onDiskGraphIndex the on-disk representation of the graph index to be processed and converted. + * @param perLevelNeighborsScoreCache the cache containing pre-computed neighbor scores, + * organized by levels and nodes. + * @param beamWidth the width of the beam used during the graph building process. + * @param neighborOverflow the factor determining how many additional neighbors are allowed beyond the configured limit. + * @param alpha the weight factor for balancing score computations. + * @param addHierarchy whether to add hierarchical structures while building the graph. + * @param refineFinalGraph whether to perform a refinement step on the final graph structure. + * @param simdExecutor the ForkJoinPool executor used for SIMD tasks during graph building. + * @param parallelExecutor the ForkJoinPool executor used for general parallelization during graph building. + * + * @throws IOException if an I/O error occurs during the graph loading or conversion process. + */ + public GraphIndexBuilder(BuildScoreProvider buildScoreProvider, OnDiskGraphIndex onDiskGraphIndex, NeighborsScoreCache perLevelNeighborsScoreCache, int beamWidth, float neighborOverflow, float alpha, boolean addHierarchy, boolean refineFinalGraph, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) throws IOException { + this.scoreProvider = buildScoreProvider; + this.neighborOverflow = neighborOverflow; + this.dimension = onDiskGraphIndex.getDimension(); + this.alpha = alpha; + this.addHierarchy = addHierarchy; + this.refineFinalGraph = refineFinalGraph; + this.beamWidth = beamWidth; + this.simdExecutor = simdExecutor; + this.parallelExecutor = parallelExecutor; + + this.graph = OnHeapGraphIndex.convertToHeap(onDiskGraphIndex, perLevelNeighborsScoreCache, buildScoreProvider, neighborOverflow, alpha); + + this.searchers = ExplicitThreadLocal.withInitial(() -> { + var gs = new GraphSearcher(graph); + gs.usePruning(false); + return gs; + }); + + // in scratch, we store candidates in reverse order: worse candidates are first + this.naturalScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1))); + this.concurrentScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1))); + + this.rng = new Random(0); + } + // used by Cassandra when it fine-tunes the PQ codebook public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvider newProvider) { var newBuilder = new GraphIndexBuilder(newProvider, @@ -750,7 +797,7 @@ public synchronized long removeDeletedNodes() { return memorySize; } - private void updateNeighbors(int level, int nodeId, NodeArray natural, NodeArray concurrent) { + private void updateNeighbors(int layer, int nodeId, NodeArray natural, NodeArray concurrent) { // if either natural or concurrent is empty, skip the merge NodeArray toMerge; if (concurrent.size() == 0) { @@ -761,7 +808,7 @@ private void updateNeighbors(int level, int nodeId, NodeArray natural, NodeArray toMerge = NodeArray.merge(natural, concurrent); } // toMerge may be approximate-scored, but insertDiverse will compute exact scores for the diverse ones - graph.addEdges(level, nodeId, toMerge, neighborOverflow); + graph.addEdges(layer, nodeId, toMerge, neighborOverflow); } private static NodeArray toScratchCandidates(NodeScore[] candidates, NodeArray scratch) { @@ -876,6 +923,7 @@ private void loadV4(RandomAccessReader in) throws IOException { graph.updateEntryNode(new NodeAtLevel(graph.getMaxLevel(), entryNode)); } + @Deprecated private void loadV3(RandomAccessReader in, int size) throws IOException { if (graph.size() != 0) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index 7ddbf7897..3cef67cc6 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -25,6 +25,8 @@ package io.github.jbellis.jvector.graph; import io.github.jbellis.jvector.graph.ConcurrentNeighborMap.Neighbors; +import io.github.jbellis.jvector.graph.disk.NeighborsScoreCache; +import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.diversity.DiversityProvider; import io.github.jbellis.jvector.util.Accountable; import io.github.jbellis.jvector.util.BitSet; @@ -33,6 +35,11 @@ import io.github.jbellis.jvector.util.RamUsageEstimator; import io.github.jbellis.jvector.util.SparseIntMap; import io.github.jbellis.jvector.util.ThreadSafeGrowableBitSet; +import io.github.jbellis.jvector.graph.diversity.VamanaDiversityProvider; +import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; +import io.github.jbellis.jvector.util.*; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.types.VectorFloat; import org.agrona.collections.IntArrayList; import java.io.DataOutput; @@ -40,7 +47,10 @@ import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.NoSuchElementException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicIntegerArray; import java.util.concurrent.atomic.AtomicReference; @@ -441,7 +451,7 @@ public boolean hasNext() { } } - private class FrozenView implements View { + public class FrozenView implements View { @Override public NodesIterator getNeighborsIterator(int level, int node) { return OnHeapGraphIndex.this.getNeighborsIterator(level, node); @@ -598,4 +608,68 @@ private void ensureCapacity(int node) { } } } + + /** + * Converts an OnDiskGraphIndex to an OnHeapGraphIndex by copying all nodes, their levels, and neighbors, + * along with other configuration details, from disk-based storage to heap-based storage. + * + * @param diskIndex the disk-based index to be converted + * @param perLevelNeighborsScoreCache the cache containing pre-computed neighbor scores, + * organized by levels and nodes. + * @param bsp The build score provider to be used for + * @param overflowRatio usually 1.2f + * @param alpha usually 1.2f + * @return an OnHeapGraphIndex that is equivalent to the provided OnDiskGraphIndex but operates in heap memory + * @throws IOException if an I/O error occurs during the conversion process + */ + public static OnHeapGraphIndex convertToHeap(OnDiskGraphIndex diskIndex, + NeighborsScoreCache perLevelNeighborsScoreCache, + BuildScoreProvider bsp, + float overflowRatio, + float alpha) throws IOException { + + // Create a new OnHeapGraphIndex with the appropriate configuration + List maxDegrees = new ArrayList<>(); + for (int level = 0; level <= diskIndex.getMaxLevel(); level++) { + maxDegrees.add(diskIndex.getDegree(level)); + } + + OnHeapGraphIndex heapIndex = new OnHeapGraphIndex( + maxDegrees, + overflowRatio, // overflow ratio + new VamanaDiversityProvider(bsp, alpha) // diversity provider - can be null for basic usage + ); + + // Copy all nodes and their connections from disk to heap + try (var view = diskIndex.getView()) { + // Copy nodes level by level + for (int level = 0; level <= diskIndex.getMaxLevel(); level++) { + final NodesIterator nodesIterator = diskIndex.getNodes(level); + final Map levelNeighborsScoreCache = perLevelNeighborsScoreCache.getNeighborsScoresInLevel(level); + if (levelNeighborsScoreCache == null) { + throw new IllegalStateException("No neighbors score cache found for level " + level); + } + if (nodesIterator.size() != levelNeighborsScoreCache.size()) { + throw new IllegalStateException("Neighbors score cache size mismatch for level " + level + + ". Expected (currently in index): " + nodesIterator.size() + ", but got (in cache): " + levelNeighborsScoreCache.size()); + } + + while (nodesIterator.hasNext()) { + int nodeId = nodesIterator.next(); + + // Copy neighbors + final NodeArray neighbors = levelNeighborsScoreCache.get(nodeId).copy(); + + // Add the node with its neighbors + heapIndex.addNode(level, nodeId, neighbors); + heapIndex.markComplete(new NodeAtLevel(level, nodeId)); + } + } + + // Set the entry point + heapIndex.updateEntryNode(view.entryNode()); + } + + return heapIndex; + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/NeighborsScoreCache.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/NeighborsScoreCache.java new file mode 100644 index 000000000..55fdcf082 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/NeighborsScoreCache.java @@ -0,0 +1,117 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.graph.disk; + +import io.github.jbellis.jvector.disk.IndexWriter; +import io.github.jbellis.jvector.disk.RandomAccessReader; +import io.github.jbellis.jvector.graph.*; +import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** + * Cache containing pre-computed neighbor scores, organized by levels and nodes. + *

+ * This cache bridges the gap between {@link OnDiskGraphIndex} and {@link OnHeapGraphIndex}: + *

    + *
  • {@link OnDiskGraphIndex} stores only neighbor IDs (not scores) for space efficiency
  • + *
  • {@link OnHeapGraphIndex} requires neighbor scores for pruning operations
  • + *
+ *

+ * When converting from disk to heap representation, this cache avoids expensive score + * recomputation by providing pre-calculated neighbor scores for all graph levels. + * + * @see OnHeapGraphIndex#convertToHeap(OnDiskGraphIndex, NeighborsScoreCache, BuildScoreProvider, float, float) + * + * This is particularly useful when merging new nodes into an existing graph. + * @see GraphIndexBuilder#buildAndMergeNewNodes(OnDiskGraphIndex, NeighborsScoreCache, RandomAccessVectorValues, BuildScoreProvider, int, int[], int, float, float, boolean) + */ +public class NeighborsScoreCache { + private final Map> perLevelNeighborsScoreCache; + + public NeighborsScoreCache(OnHeapGraphIndex graphIndex) throws IOException { + try (OnHeapGraphIndex.FrozenView view = graphIndex.getFrozenView()) { + final Map> perLevelNeighborsScoreCache = new HashMap<>(graphIndex.getMaxLevel() + 1); + for (int level = 0; level <= graphIndex.getMaxLevel(); level++) { + final Map levelNeighborsScores = new HashMap<>(graphIndex.size(level) + 1); + final NodesIterator nodesIterator = graphIndex.getNodes(level); + while (nodesIterator.hasNext()) { + final int nodeId = nodesIterator.nextInt(); + + ConcurrentNeighborMap.NeighborIterator neighborIterator = (ConcurrentNeighborMap.NeighborIterator) view.getNeighborsIterator(level, nodeId); + final NodeArray neighbours = neighborIterator.merge(new NodeArray(neighborIterator.size())); + levelNeighborsScores.put(nodeId, neighbours); + } + + perLevelNeighborsScoreCache.put(level, levelNeighborsScores); + } + + this.perLevelNeighborsScoreCache = perLevelNeighborsScoreCache; + } + } + + public NeighborsScoreCache(RandomAccessReader in) throws IOException { + final int numberOfLevels = in.readInt(); + perLevelNeighborsScoreCache = new HashMap<>(numberOfLevels); + for (int i = 0; i < numberOfLevels; i++) { + final int level = in.readInt(); + final int numberOfNodesInLevel = in.readInt(); + final Map levelNeighborsScores = new HashMap<>(numberOfNodesInLevel); + for (int j = 0; j < numberOfNodesInLevel; j++) { + final int nodeId = in.readInt(); + final int numberOfNeighbors = in.readInt(); + final NodeArray nodeArray = new NodeArray(numberOfNeighbors); + for (int k = 0; k < numberOfNeighbors; k++) { + final int neighborNodeId = in.readInt(); + final float neighborScore = in.readFloat(); + nodeArray.insertSorted(neighborNodeId, neighborScore); + } + levelNeighborsScores.put(nodeId, nodeArray); + } + perLevelNeighborsScoreCache.put(level, levelNeighborsScores); + } + } + + public void write(IndexWriter out) throws IOException { + out.writeInt(perLevelNeighborsScoreCache.size()); // write the number of levels + for (Map.Entry> levelNeighborsScores : perLevelNeighborsScoreCache.entrySet()) { + final int level = levelNeighborsScores.getKey(); + out.writeInt(level); + out.writeInt(levelNeighborsScores.getValue().size()); // write the number of nodes in the level + // Write the neighborhoods for each node in the level + for (Map.Entry nodeArrayEntry : levelNeighborsScores.getValue().entrySet()) { + final int nodeId = nodeArrayEntry.getKey(); + out.writeInt(nodeId); + final NodeArray nodeArray = nodeArrayEntry.getValue(); + out.writeInt(nodeArray.size()); // write the number of neighbors for the node + // Write the nodeArray(neighbors) + for (int i = 0; i < nodeArray.size(); i++) { + out.writeInt(nodeArray.getNode(i)); + out.writeFloat(nodeArray.getScore(i)); + } + } + } + } + + public Map getNeighborsScoresInLevel(int level) { + return perLevelNeighborsScoreCache.get(level); + } + + +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java index b8ec5fa5f..4656a4fcc 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java @@ -25,6 +25,8 @@ import io.github.jbellis.jvector.vector.types.VectorFloat; import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import java.util.stream.IntStream; + /** * Encapsulates comparing node distances for GraphIndexBuilder. */ @@ -83,8 +85,17 @@ public interface BuildScoreProvider { /** * Returns a BSP that performs exact score comparisons using the given RandomAccessVectorValues and VectorSimilarityFunction. + * + * Helper method for the special case that mapping between graph node IDs and ravv ordinals is the identity function. */ static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, VectorSimilarityFunction similarityFunction) { + return randomAccessScoreProvider(ravv, IntStream.range(0, ravv.size()).toArray(), similarityFunction); + } + + /** + * Returns a BSP that performs exact score comparisons using the given RandomAccessVectorValues and VectorSimilarityFunction. + */ + static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap, VectorSimilarityFunction similarityFunction) { // We need two sources of vectors in order to perform diversity check comparisons without // colliding. ThreadLocalSupplier makes this a no-op if the RAVV is actually un-shared. var vectors = ravv.threadLocalSupplier(); @@ -113,22 +124,22 @@ public VectorFloat approximateCentroid() { @Override public SearchScoreProvider searchProviderFor(VectorFloat vector) { var vc = vectorsCopy.get(); - return DefaultSearchScoreProvider.exact(vector, similarityFunction, vc); + return DefaultSearchScoreProvider.exact(vector, graphToRavvOrdMap, similarityFunction, vc); } @Override public SearchScoreProvider searchProviderFor(int node1) { RandomAccessVectorValues randomAccessVectorValues = vectors.get(); - var v = randomAccessVectorValues.getVector(node1); + var v = randomAccessVectorValues.getVector(graphToRavvOrdMap[node1]); return searchProviderFor(v); } @Override public SearchScoreProvider diversityProviderFor(int node1) { RandomAccessVectorValues randomAccessVectorValues = vectors.get(); - var v = randomAccessVectorValues.getVector(node1); + var v = randomAccessVectorValues.getVector(graphToRavvOrdMap[node1]); var vc = vectorsCopy.get(); - return DefaultSearchScoreProvider.exact(v, similarityFunction, vc); + return DefaultSearchScoreProvider.exact(v, graphToRavvOrdMap, similarityFunction, vc); } }; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java index 0754b39d7..de46762b2 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java @@ -78,4 +78,20 @@ public float similarityTo(int node2) { }; return new DefaultSearchScoreProvider(sf); } + + /** + * A SearchScoreProvider for a single-pass search based on exact similarity. + * Generally only suitable when your RandomAccessVectorValues is entirely in-memory, + * e.g. during construction. + */ + public static DefaultSearchScoreProvider exact(VectorFloat v, int[] graphToRavvOrdMap ,VectorSimilarityFunction vsf, RandomAccessVectorValues ravv) { + // don't use ESF.reranker, we need thread safety here + var sf = new ScoreFunction.ExactScoreFunction() { + @Override + public float similarityTo(int node2) { + return vsf.compare(v, ravv.getVector(graphToRavvOrdMap[node2])); + } + }; + return new DefaultSearchScoreProvider(sf); + } } \ No newline at end of file diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java index f66a2c6e4..73e59b20f 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java @@ -77,6 +77,8 @@ public static PQVectors load(RandomAccessReader in, long offset) throws IOExcept * Build a PQVectors instance from the given RandomAccessVectorValues. The vectors are encoded in parallel * and split into chunks to avoid exceeding the maximum array size. * + * This is a helper method for the special case where the ordinals mapping in the graph and the RAVV/PQVectors are the same. + * * @param pq the ProductQuantization to use * @param vectorCount the number of vectors to encode * @param ravv the RandomAccessVectorValues to encode @@ -84,6 +86,21 @@ public static PQVectors load(RandomAccessReader in, long offset) throws IOExcept * @return the PQVectors instance */ public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vectorCount, RandomAccessVectorValues ravv, ForkJoinPool simdExecutor) { + return encodeAndBuild(pq, vectorCount, IntStream.range(0, vectorCount).toArray(), ravv, simdExecutor); + } + + /** + * Build a PQVectors instance from the given RandomAccessVectorValues. The vectors are encoded in parallel + * and split into chunks to avoid exceeding the maximum array size. + * + * @param pq the ProductQuantization to use + * @param vectorCount the number of vectors to encode + * @param ravv the RandomAccessVectorValues to encode + * @param simdExecutor the ForkJoinPool to use for SIMD operations + * @param ordinalsMapping the graph ordinals to RAVV mapping + * @return the PQVectors instance + */ + public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vectorCount, int[] ordinalsMapping, RandomAccessVectorValues ravv, ForkJoinPool simdExecutor) { int compressedDimension = pq.compressedVectorSize(); PQLayout layout = new PQLayout(vectorCount,compressedDimension); final ByteSequence[] chunks = new ByteSequence[layout.totalChunks]; @@ -98,13 +115,13 @@ public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vect // The changes are concurrent, but because they are coordinated and do not overlap, we can use parallel streams // and then we are guaranteed safe publication because we join the thread after completion. var ravvCopy = ravv.threadLocalSupplier(); - simdExecutor.submit(() -> IntStream.range(0, ravv.size()) + simdExecutor.submit(() -> IntStream.range(0, ordinalsMapping.length) .parallel() .forEach(ordinal -> { // Retrieve the slice and mutate it. var localRavv = ravvCopy.get(); var slice = PQVectors.get(chunks, ordinal, layout.fullChunkVectors, pq.getSubspaceCount()); - var vector = localRavv.getVector(ordinal); + var vector = localRavv.getVector(ordinalsMapping[ordinal]); if (vector != null) pq.encodeTo(vector, slice); else diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java new file mode 100644 index 000000000..b4efb543d --- /dev/null +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java @@ -0,0 +1,264 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.graph; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import io.github.jbellis.jvector.TestUtil; +import io.github.jbellis.jvector.disk.SimpleMappedReader; +import io.github.jbellis.jvector.disk.SimpleWriter; +import io.github.jbellis.jvector.graph.disk.NeighborsScoreCache; +import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; +import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; +import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; +import io.github.jbellis.jvector.util.Bits; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import org.apache.logging.log4j.Logger; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class OnHeapGraphIndexTest extends RandomizedTest { + private final static Logger log = org.apache.logging.log4j.LogManager.getLogger(OnHeapGraphIndexTest.class); + private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider.getInstance().getVectorTypeSupport(); + private static final int NUM_BASE_VECTORS = 100; + private static final int NUM_NEW_VECTORS = 100; + private static final int NUM_ALL_VECTORS = NUM_BASE_VECTORS + NUM_NEW_VECTORS; + private static final int DIMENSION = 16; + private static final int M = 8; + private static final int BEAM_WIDTH = 100; + private static final float ALPHA = 1.2f; + private static final float NEIGHBOR_OVERFLOW = 1.2f; + private static final boolean ADD_HIERARCHY = false; + private static final int TOP_K = 10; + + private Path testDirectory; + + private ArrayList> baseVectors; + private ArrayList> newVectors; + private ArrayList> allVectors; + private RandomAccessVectorValues baseVectorsRavv; + private RandomAccessVectorValues newVectorsRavv; + private RandomAccessVectorValues allVectorsRavv; + private VectorFloat queryVector; + private int[] groundTruthBaseVectors; + private int[] groundTruthAllVectors; + private BuildScoreProvider baseBuildScoreProvider; + private BuildScoreProvider newBuildScoreProvider; + private BuildScoreProvider allBuildScoreProvider; + private OnHeapGraphIndex baseGraphIndex; + private OnHeapGraphIndex newGraphIndex; + private OnHeapGraphIndex allGraphIndex; + + @Before + public void setup() throws IOException { + testDirectory = Files.createTempDirectory(this.getClass().getSimpleName()); + baseVectors = new ArrayList<>(NUM_BASE_VECTORS); + newVectors = new ArrayList<>(NUM_NEW_VECTORS); + allVectors = new ArrayList<>(NUM_ALL_VECTORS); + for (int i = 0; i < NUM_BASE_VECTORS; i++) { + VectorFloat vector = createRandomVector(DIMENSION); + baseVectors.add(vector); + allVectors.add(vector); + } + for (int i = 0; i < NUM_NEW_VECTORS; i++) { + VectorFloat vector = createRandomVector(DIMENSION); + newVectors.add(vector); + allVectors.add(vector); + } + + // wrap the raw vectors in a RandomAccessVectorValues + baseVectorsRavv = new ListRandomAccessVectorValues(baseVectors, DIMENSION); + newVectorsRavv = new ListRandomAccessVectorValues(newVectors, DIMENSION); + allVectorsRavv = new ListRandomAccessVectorValues(allVectors, DIMENSION); + + queryVector = createRandomVector(DIMENSION); + groundTruthBaseVectors = getGroundTruth(baseVectorsRavv, queryVector, TOP_K, VectorSimilarityFunction.EUCLIDEAN); + groundTruthAllVectors = getGroundTruth(allVectorsRavv, queryVector, TOP_K, VectorSimilarityFunction.EUCLIDEAN); + + // score provider using the raw, in-memory vectors + baseBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(baseVectorsRavv, VectorSimilarityFunction.EUCLIDEAN); + newBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(newVectorsRavv, VectorSimilarityFunction.EUCLIDEAN); + allBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(allVectorsRavv, VectorSimilarityFunction.EUCLIDEAN); + var baseGraphIndexBuilder = new GraphIndexBuilder(baseBuildScoreProvider, + baseVectorsRavv.dimension(), + M, // graph degree + BEAM_WIDTH, // construction search depth + NEIGHBOR_OVERFLOW, // allow degree overflow during construction by this factor + ALPHA, // relax neighbor diversity requirement by this factor + ADD_HIERARCHY); // add the hierarchy + var allGraphIndexBuilder = new GraphIndexBuilder(allBuildScoreProvider, + allVectorsRavv.dimension(), + M, // graph degree + BEAM_WIDTH, // construction search depth + NEIGHBOR_OVERFLOW, // allow degree overflow during construction by this factor + ALPHA, // relax neighbor diversity requirement by this factor + ADD_HIERARCHY); // add the hierarchy + + baseGraphIndex = baseGraphIndexBuilder.build(baseVectorsRavv); + allGraphIndex = allGraphIndexBuilder.build(allVectorsRavv); + } + + @After + public void tearDown() { + TestUtil.deleteQuietly(testDirectory); + } + + + /** + * Create an {@link OnHeapGraphIndex} persist it as a {@link OnDiskGraphIndex} and reconstruct back to a mutable {@link OnHeapGraphIndex} + * Make sure that both graphs are equivalent + * @throws IOException + */ + @Test + public void testReconstructionOfOnHeapGraphIndex() throws IOException { + var graphOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName()); + var neighborsScoreCacheOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + NeighborsScoreCache.class.getSimpleName()); + log.info("Writing graph to {}", graphOutputPath); + TestUtil.writeGraph(baseGraphIndex, baseVectorsRavv, graphOutputPath); + + log.info("Writing neighbors score cache to {}", neighborsScoreCacheOutputPath); + final NeighborsScoreCache neighborsScoreCache = new NeighborsScoreCache(baseGraphIndex); + try (SimpleWriter writer = new SimpleWriter(neighborsScoreCacheOutputPath.toAbsolutePath())) { + neighborsScoreCache.write(writer); + } + + log.info("Reading neighbors score cache from {}", neighborsScoreCacheOutputPath); + final NeighborsScoreCache neighborsScoreCacheRead; + try (var readerSupplier = new SimpleMappedReader.Supplier(neighborsScoreCacheOutputPath.toAbsolutePath())) { + neighborsScoreCacheRead = new NeighborsScoreCache(readerSupplier.get()); + } + + try (var readerSupplier = new SimpleMappedReader.Supplier(graphOutputPath.toAbsolutePath()); + var onDiskGraph = OnDiskGraphIndex.load(readerSupplier)) { + TestUtil.assertGraphEquals(baseGraphIndex, onDiskGraph); + try (var onDiskView = onDiskGraph.getView()) { + validateVectors(onDiskView, baseVectorsRavv); + } + + OnHeapGraphIndex reconstructedOnHeapGraphIndex = OnHeapGraphIndex.convertToHeap(onDiskGraph, neighborsScoreCacheRead, baseBuildScoreProvider, NEIGHBOR_OVERFLOW, ALPHA); + TestUtil.assertGraphEquals(baseGraphIndex, reconstructedOnHeapGraphIndex); + TestUtil.assertGraphEquals(onDiskGraph, reconstructedOnHeapGraphIndex); + + } + } + + /** + * Create {@link OnDiskGraphIndex} then append to it via {@link GraphIndexBuilder#buildAndMergeNewNodes} + * Verify that the resulting OnHeapGraphIndex is equivalent to the graph that would have been alternatively generated by bulk index into a new {@link OnDiskGraphIndex} + */ + @Test + public void testIncrementalInsertionFromOnDiskIndex() throws IOException { + var outputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName()); + log.info("Writing graph to {}", outputPath); + final NeighborsScoreCache neighborsScoreCache = new NeighborsScoreCache(baseGraphIndex); + TestUtil.writeGraph(baseGraphIndex, baseVectorsRavv, outputPath); + try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath.toAbsolutePath()); + var onDiskGraph = OnDiskGraphIndex.load(readerSupplier)) { + TestUtil.assertGraphEquals(baseGraphIndex, onDiskGraph); + // We will create a trivial 1:1 mapping between the new graph and the ravv + final int[] graphToRavvOrdMap = IntStream.range(0, allVectorsRavv.size()).toArray(); + OnHeapGraphIndex reconstructedAllNodeOnHeapGraphIndex = GraphIndexBuilder.buildAndMergeNewNodes(onDiskGraph, neighborsScoreCache, allVectorsRavv, allBuildScoreProvider, NUM_BASE_VECTORS, graphToRavvOrdMap, BEAM_WIDTH, NEIGHBOR_OVERFLOW, ALPHA, ADD_HIERARCHY); + + // Verify that the recall is similar + float recallFromReconstructedAllNodeOnHeapGraphIndex = calculateRecall(reconstructedAllNodeOnHeapGraphIndex, allBuildScoreProvider, queryVector, groundTruthAllVectors, TOP_K); + float recallFromAllGraphIndex = calculateRecall(allGraphIndex, allBuildScoreProvider, queryVector, groundTruthAllVectors, TOP_K); + Assert.assertEquals(recallFromReconstructedAllNodeOnHeapGraphIndex, recallFromAllGraphIndex, 0.01f); + } + } + + public static void validateVectors(OnDiskGraphIndex.View view, RandomAccessVectorValues ravv) { + for (int i = 0; i < view.size(); i++) { + assertEquals("Incorrect vector at " + i, ravv.getVector(i), view.getVector(i)); + } + } + + private VectorFloat createRandomVector(int dimension) { + VectorFloat vector = VECTOR_TYPE_SUPPORT.createFloatVector(dimension); + for (int i = 0; i < dimension; i++) { + vector.set(i, (float) Math.random()); + } + return vector; + } + + /** + * Get the ground truth for a query vector + * @param ravv the vectors to search + * @param queryVector the query vector + * @param topK the number of results to return + * @param similarityFunction the similarity function to use + + * @return the ground truth + */ + private static int[] getGroundTruth(RandomAccessVectorValues ravv, VectorFloat queryVector, int topK, VectorSimilarityFunction similarityFunction) { + var exactResults = new ArrayList(); + for (int i = 0; i < ravv.size(); i++) { + float similarityScore = similarityFunction.compare(queryVector, ravv.getVector(i)); + exactResults.add(new SearchResult.NodeScore(i, similarityScore)); + } + exactResults.sort((a, b) -> Float.compare(b.score, a.score)); + return exactResults.stream().limit(topK).mapToInt(nodeScore -> nodeScore.node).toArray(); + } + + private static float calculateRecall(OnHeapGraphIndex graphIndex, BuildScoreProvider buildScoreProvider, VectorFloat queryVector, int[] groundTruth, int k) throws IOException { + try (GraphSearcher graphSearcher = new GraphSearcher(graphIndex)){ + SearchScoreProvider ssp = buildScoreProvider.searchProviderFor(queryVector); + var searchResults = graphSearcher.search(ssp, k, Bits.ALL); + var predicted = Arrays.stream(searchResults.getNodes()).mapToInt(nodeScore -> nodeScore.node).boxed().collect(Collectors.toSet()); + return calculateRecall(predicted, groundTruth, k); + } + } + /** + * Calculate the recall for a set of predicted results + * @param predicted the predicted results + * @param groundTruth the ground truth + * @param k the number of results to consider + * @return the recall + */ + private static float calculateRecall(Set predicted, int[] groundTruth, int k) { + int hits = 0; + int actualK = Math.min(k, Math.min(predicted.size(), groundTruth.length)); + + for (int i = 0; i < actualK; i++) { + for (int j = 0; j < actualK; j++) { + if (predicted.contains(groundTruth[j])) { + hits++; + break; + } + } + } + + return ((float) hits) / (float) actualK; + } +} diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProviderTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProviderTest.java new file mode 100644 index 000000000..4942b8efb --- /dev/null +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProviderTest.java @@ -0,0 +1,72 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.graph.similarity; + +import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +public class BuildScoreProviderTest { + private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport(); + + /** + * Test that the ordinal mapping is correctly applied when creating search and diversity score providers. + */ + @Test + public void testOrdinalMapping() { + final VectorSimilarityFunction vsf = VectorSimilarityFunction.DOT_PRODUCT; + + // Create test vectors + final List> vectors = new ArrayList<>(); + vectors.add(vts.createFloatVector(new float[]{1.0f, 0.0f})); + vectors.add(vts.createFloatVector(new float[]{0.0f, 1.0f})); + vectors.add(vts.createFloatVector(new float[]{-1.0f, 0.0f})); + var ravv = new ListRandomAccessVectorValues(vectors, 2); + + // Create non-identity mapping: graph node 0 -> ravv ordinal 2, graph node 1 -> ravv ordinal 0, graph node 2 -> ravv ordinal 1 + int[] graphToRavvOrdMap = {2, 0, 1}; + + var bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, graphToRavvOrdMap, vsf); + + // Test that searchProviderFor(graphNode) uses the correct RAVV ordinal + var ssp0 = bsp.searchProviderFor(0); // should use ravv ordinal 2 (vector [-1, 0]) + var ssp1 = bsp.searchProviderFor(1); // should use ravv ordinal 0 (vector [1, 0]) + var ssp2 = bsp.searchProviderFor(2); // should use ravv ordinal 1 (vector [0, 1]) + + // Verify by computing similarity between graph nodes + // Graph node 0 (vector 2:[-1, 0]) vs graph node 1 (vector 0:[1, 0]) + assertEquals(vsf.compare(vectors.get(2), vectors.get(0)), ssp0.exactScoreFunction().similarityTo(1), 1e-6f); + + // Graph node 1 (vector 0:[1, 0]) vs graph node 0 (vector 2:[-1, 0]) + assertEquals(vsf.compare(vectors.get(0), vectors.get(2)), ssp1.exactScoreFunction().similarityTo(0), 1e-6f); + + // Graph node 2 (vector 1:[0, 1]) vs graph node 1 (vector 0:[1, 0]) + assertEquals(vsf.compare(vectors.get(1), vectors.get(0)), ssp2.exactScoreFunction().similarityTo(1), 1e-6f); + + // Test diversityProviderFor uses same mapping, Graph node 0 (vector 2:[-1, 0]) vs graph node 1 (vector 0:[1, 0]) + var dsp0 = bsp.diversityProviderFor(0); + assertEquals(vsf.compare(vectors.get(2), vectors.get(0)), dsp0.exactScoreFunction().similarityTo(1), 1e-6f); + } +} \ No newline at end of file From 279b5aae1e55cf042f1f9549cdf75b770a4c6a8c Mon Sep 17 00:00:00 2001 From: Samuel Herman Date: Tue, 7 Oct 2025 08:17:38 -0700 Subject: [PATCH 02/12] rebase fix Signed-off-by: Samuel Herman --- .../jvector/graph/GraphIndexBuilder.java | 54 +++++++++++++++++++ .../jvector/graph/OnHeapGraphIndex.java | 6 ++- .../jvector/graph/OnHeapGraphIndexTest.java | 14 ++--- 3 files changed, 66 insertions(+), 8 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index f1bd45d5e..99a8c85fb 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -957,4 +957,58 @@ private void loadV3(RandomAccessReader in, int size) throws IOException { graph.updateEntryNode(new NodeAtLevel(0, entryNode)); graph.setDegrees(List.of(maxDegree)); } + + /** + * Convenience method to build a new graph from an existing one, with the addition of new nodes. + * This is useful when we want to merge a new set of vectors into an existing graph that is already on disk. + * + * @param onDiskGraphIndex the on-disk representation of the graph index to be processed and converted. + * @param perLevelNeighborsScoreCache the cache containing pre-computed neighbor scores, + * @param newVectors a super set RAVV containing the new vectors to be added to the graph as well as the old ones that are already in the graph + * @param buildScoreProvider the provider responsible for calculating build scores. + * @param startingNodeOffset the offset in the newVectors RAVV where the new vectors start + * @param graphToRavvOrdMap a mapping from the old graph's node ids to the newVectors RAVV node ids + * @param beamWidth the width of the beam used during the graph building process. + * @param overflowRatio the ratio of extra neighbors to allow temporarily when inserting a node. + * @param alpha the weight factor for balancing score computations. + * @param addHierarchy whether to add hierarchical structures while building the graph. + * + * @return the in-memory representation of the graph index. + * @throws IOException if an I/O error occurs during the graph loading or conversion process. + */ + public static ImmutableGraphIndex buildAndMergeNewNodes(OnDiskGraphIndex onDiskGraphIndex, + NeighborsScoreCache perLevelNeighborsScoreCache, + RandomAccessVectorValues newVectors, + BuildScoreProvider buildScoreProvider, + int startingNodeOffset, + int[] graphToRavvOrdMap, + int beamWidth, + float overflowRatio, + float alpha, + boolean addHierarchy) throws IOException { + + + + try (GraphIndexBuilder builder = new GraphIndexBuilder(buildScoreProvider, + onDiskGraphIndex, + perLevelNeighborsScoreCache, + beamWidth, + overflowRatio, + alpha, + addHierarchy, + true, + PhysicalCoreExecutor.pool(), + ForkJoinPool.commonPool())) { + + var vv = newVectors.threadLocalSupplier(); + + // parallel graph construction from the merge documents Ids + PhysicalCoreExecutor.pool().submit(() -> IntStream.range(startingNodeOffset, newVectors.size()).parallel().forEach(ord -> { + builder.addGraphNode(ord, vv.get().getVector(graphToRavvOrdMap[ord])); + })).join(); + + builder.cleanup(); + return builder.getGraph(); + } + } } \ No newline at end of file diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index 3cef67cc6..576c19d78 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -303,6 +303,10 @@ public View getView() { } } + public FrozenView getFrozenView() { + return new FrozenView(); + } + public ThreadSafeGrowableBitSet getDeletedNodes() { return deletedNodes; } @@ -661,7 +665,7 @@ public static OnHeapGraphIndex convertToHeap(OnDiskGraphIndex diskIndex, final NodeArray neighbors = levelNeighborsScoreCache.get(nodeId).copy(); // Add the node with its neighbors - heapIndex.addNode(level, nodeId, neighbors); + heapIndex.connectNode(level, nodeId, neighbors); heapIndex.markComplete(new NodeAtLevel(level, nodeId)); } } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java index b4efb543d..385588283 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java @@ -77,9 +77,9 @@ public class OnHeapGraphIndexTest extends RandomizedTest { private BuildScoreProvider baseBuildScoreProvider; private BuildScoreProvider newBuildScoreProvider; private BuildScoreProvider allBuildScoreProvider; - private OnHeapGraphIndex baseGraphIndex; - private OnHeapGraphIndex newGraphIndex; - private OnHeapGraphIndex allGraphIndex; + private ImmutableGraphIndex baseGraphIndex; + private ImmutableGraphIndex newGraphIndex; + private ImmutableGraphIndex allGraphIndex; @Before public void setup() throws IOException { @@ -149,7 +149,7 @@ public void testReconstructionOfOnHeapGraphIndex() throws IOException { TestUtil.writeGraph(baseGraphIndex, baseVectorsRavv, graphOutputPath); log.info("Writing neighbors score cache to {}", neighborsScoreCacheOutputPath); - final NeighborsScoreCache neighborsScoreCache = new NeighborsScoreCache(baseGraphIndex); + final NeighborsScoreCache neighborsScoreCache = new NeighborsScoreCache((OnHeapGraphIndex) baseGraphIndex); try (SimpleWriter writer = new SimpleWriter(neighborsScoreCacheOutputPath.toAbsolutePath())) { neighborsScoreCache.write(writer); } @@ -182,14 +182,14 @@ public void testReconstructionOfOnHeapGraphIndex() throws IOException { public void testIncrementalInsertionFromOnDiskIndex() throws IOException { var outputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName()); log.info("Writing graph to {}", outputPath); - final NeighborsScoreCache neighborsScoreCache = new NeighborsScoreCache(baseGraphIndex); + final NeighborsScoreCache neighborsScoreCache = new NeighborsScoreCache((OnHeapGraphIndex) baseGraphIndex); TestUtil.writeGraph(baseGraphIndex, baseVectorsRavv, outputPath); try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath.toAbsolutePath()); var onDiskGraph = OnDiskGraphIndex.load(readerSupplier)) { TestUtil.assertGraphEquals(baseGraphIndex, onDiskGraph); // We will create a trivial 1:1 mapping between the new graph and the ravv final int[] graphToRavvOrdMap = IntStream.range(0, allVectorsRavv.size()).toArray(); - OnHeapGraphIndex reconstructedAllNodeOnHeapGraphIndex = GraphIndexBuilder.buildAndMergeNewNodes(onDiskGraph, neighborsScoreCache, allVectorsRavv, allBuildScoreProvider, NUM_BASE_VECTORS, graphToRavvOrdMap, BEAM_WIDTH, NEIGHBOR_OVERFLOW, ALPHA, ADD_HIERARCHY); + ImmutableGraphIndex reconstructedAllNodeOnHeapGraphIndex = GraphIndexBuilder.buildAndMergeNewNodes(onDiskGraph, neighborsScoreCache, allVectorsRavv, allBuildScoreProvider, NUM_BASE_VECTORS, graphToRavvOrdMap, BEAM_WIDTH, NEIGHBOR_OVERFLOW, ALPHA, ADD_HIERARCHY); // Verify that the recall is similar float recallFromReconstructedAllNodeOnHeapGraphIndex = calculateRecall(reconstructedAllNodeOnHeapGraphIndex, allBuildScoreProvider, queryVector, groundTruthAllVectors, TOP_K); @@ -231,7 +231,7 @@ private static int[] getGroundTruth(RandomAccessVectorValues ravv, VectorFloat nodeScore.node).toArray(); } - private static float calculateRecall(OnHeapGraphIndex graphIndex, BuildScoreProvider buildScoreProvider, VectorFloat queryVector, int[] groundTruth, int k) throws IOException { + private static float calculateRecall(ImmutableGraphIndex graphIndex, BuildScoreProvider buildScoreProvider, VectorFloat queryVector, int[] groundTruth, int k) throws IOException { try (GraphSearcher graphSearcher = new GraphSearcher(graphIndex)){ SearchScoreProvider ssp = buildScoreProvider.searchProviderFor(queryVector); var searchResults = graphSearcher.search(ssp, k, Bits.ALL); From 8cdb9c94faafe1c56c9c6a8607186933d3d4bb0d Mon Sep 17 00:00:00 2001 From: Samuel Herman Date: Tue, 7 Oct 2025 08:45:38 -0700 Subject: [PATCH 03/12] switch interface for convert graph Signed-off-by: Samuel Herman --- .../jbellis/jvector/graph/OnHeapGraphIndex.java | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index 576c19d78..9b1194b1c 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -617,7 +617,7 @@ private void ensureCapacity(int node) { * Converts an OnDiskGraphIndex to an OnHeapGraphIndex by copying all nodes, their levels, and neighbors, * along with other configuration details, from disk-based storage to heap-based storage. * - * @param diskIndex the disk-based index to be converted + * @param immutableGraphIndex the disk-based index to be converted * @param perLevelNeighborsScoreCache the cache containing pre-computed neighbor scores, * organized by levels and nodes. * @param bsp The build score provider to be used for @@ -626,7 +626,7 @@ private void ensureCapacity(int node) { * @return an OnHeapGraphIndex that is equivalent to the provided OnDiskGraphIndex but operates in heap memory * @throws IOException if an I/O error occurs during the conversion process */ - public static OnHeapGraphIndex convertToHeap(OnDiskGraphIndex diskIndex, + public static OnHeapGraphIndex convertToHeap(ImmutableGraphIndex immutableGraphIndex, NeighborsScoreCache perLevelNeighborsScoreCache, BuildScoreProvider bsp, float overflowRatio, @@ -634,8 +634,8 @@ public static OnHeapGraphIndex convertToHeap(OnDiskGraphIndex diskIndex, // Create a new OnHeapGraphIndex with the appropriate configuration List maxDegrees = new ArrayList<>(); - for (int level = 0; level <= diskIndex.getMaxLevel(); level++) { - maxDegrees.add(diskIndex.getDegree(level)); + for (int level = 0; level <= immutableGraphIndex.getMaxLevel(); level++) { + maxDegrees.add(immutableGraphIndex.getDegree(level)); } OnHeapGraphIndex heapIndex = new OnHeapGraphIndex( @@ -645,10 +645,10 @@ public static OnHeapGraphIndex convertToHeap(OnDiskGraphIndex diskIndex, ); // Copy all nodes and their connections from disk to heap - try (var view = diskIndex.getView()) { + try (var view = immutableGraphIndex.getView()) { // Copy nodes level by level - for (int level = 0; level <= diskIndex.getMaxLevel(); level++) { - final NodesIterator nodesIterator = diskIndex.getNodes(level); + for (int level = 0; level <= immutableGraphIndex.getMaxLevel(); level++) { + final NodesIterator nodesIterator = immutableGraphIndex.getNodes(level); final Map levelNeighborsScoreCache = perLevelNeighborsScoreCache.getNeighborsScoresInLevel(level); if (levelNeighborsScoreCache == null) { throw new IllegalStateException("No neighbors score cache found for level " + level); From a670cc9009fbad2d5998cb2142fb72b65be184f2 Mon Sep 17 00:00:00 2001 From: Samuel Herman Date: Tue, 7 Oct 2025 08:56:14 -0700 Subject: [PATCH 04/12] remove explicit mentioning of OnDiskGraphIndex in builder Signed-off-by: Samuel Herman --- .../jbellis/jvector/graph/GraphIndexBuilder.java | 16 ++++++++-------- .../jvector/graph/ImmutableGraphIndex.java | 3 +++ .../jbellis/jvector/graph/OnHeapGraphIndex.java | 10 +++++++++- .../jvector/graph/disk/OnDiskGraphIndex.java | 1 + 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index 99a8c85fb..9be45a7d6 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -327,7 +327,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, this.simdExecutor = simdExecutor; this.parallelExecutor = parallelExecutor; - this.graph = new OnHeapGraphIndex(maxDegrees, neighborOverflow, new VamanaDiversityProvider(scoreProvider, alpha)); + this.graph = new OnHeapGraphIndex(dimension, maxDegrees, neighborOverflow, new VamanaDiversityProvider(scoreProvider, alpha)); this.searchers = ExplicitThreadLocal.withInitial(() -> { var gs = new GraphSearcher(graph); gs.usePruning(false); @@ -346,7 +346,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, * copy it into {@link OnHeapGraphIndex} and then start mutating it with minimal overhead of recreating the mutable {@link OnHeapGraphIndex} used in the new GraphIndexBuilder object * * @param buildScoreProvider the provider responsible for calculating build scores. - * @param onDiskGraphIndex the on-disk representation of the graph index to be processed and converted. + * @param immutableGraphIndex the on-disk representation of the graph index to be processed and converted. * @param perLevelNeighborsScoreCache the cache containing pre-computed neighbor scores, * organized by levels and nodes. * @param beamWidth the width of the beam used during the graph building process. @@ -359,10 +359,10 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, * * @throws IOException if an I/O error occurs during the graph loading or conversion process. */ - public GraphIndexBuilder(BuildScoreProvider buildScoreProvider, OnDiskGraphIndex onDiskGraphIndex, NeighborsScoreCache perLevelNeighborsScoreCache, int beamWidth, float neighborOverflow, float alpha, boolean addHierarchy, boolean refineFinalGraph, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) throws IOException { + public GraphIndexBuilder(BuildScoreProvider buildScoreProvider, ImmutableGraphIndex immutableGraphIndex, NeighborsScoreCache perLevelNeighborsScoreCache, int beamWidth, float neighborOverflow, float alpha, boolean addHierarchy, boolean refineFinalGraph, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) throws IOException { this.scoreProvider = buildScoreProvider; this.neighborOverflow = neighborOverflow; - this.dimension = onDiskGraphIndex.getDimension(); + this.dimension = immutableGraphIndex.getDimension(); this.alpha = alpha; this.addHierarchy = addHierarchy; this.refineFinalGraph = refineFinalGraph; @@ -370,7 +370,7 @@ public GraphIndexBuilder(BuildScoreProvider buildScoreProvider, OnDiskGraphIndex this.simdExecutor = simdExecutor; this.parallelExecutor = parallelExecutor; - this.graph = OnHeapGraphIndex.convertToHeap(onDiskGraphIndex, perLevelNeighborsScoreCache, buildScoreProvider, neighborOverflow, alpha); + this.graph = OnHeapGraphIndex.convertToHeap(immutableGraphIndex, perLevelNeighborsScoreCache, buildScoreProvider, neighborOverflow, alpha); this.searchers = ExplicitThreadLocal.withInitial(() -> { var gs = new GraphSearcher(graph); @@ -962,7 +962,7 @@ private void loadV3(RandomAccessReader in, int size) throws IOException { * Convenience method to build a new graph from an existing one, with the addition of new nodes. * This is useful when we want to merge a new set of vectors into an existing graph that is already on disk. * - * @param onDiskGraphIndex the on-disk representation of the graph index to be processed and converted. + * @param immutableGraphIndex the immutable (usually on-disk) representation of the graph index to be processed and converted. * @param perLevelNeighborsScoreCache the cache containing pre-computed neighbor scores, * @param newVectors a super set RAVV containing the new vectors to be added to the graph as well as the old ones that are already in the graph * @param buildScoreProvider the provider responsible for calculating build scores. @@ -976,7 +976,7 @@ private void loadV3(RandomAccessReader in, int size) throws IOException { * @return the in-memory representation of the graph index. * @throws IOException if an I/O error occurs during the graph loading or conversion process. */ - public static ImmutableGraphIndex buildAndMergeNewNodes(OnDiskGraphIndex onDiskGraphIndex, + public static ImmutableGraphIndex buildAndMergeNewNodes(ImmutableGraphIndex immutableGraphIndex, NeighborsScoreCache perLevelNeighborsScoreCache, RandomAccessVectorValues newVectors, BuildScoreProvider buildScoreProvider, @@ -990,7 +990,7 @@ public static ImmutableGraphIndex buildAndMergeNewNodes(OnDiskGraphIndex onDiskG try (GraphIndexBuilder builder = new GraphIndexBuilder(buildScoreProvider, - onDiskGraphIndex, + immutableGraphIndex, perLevelNeighborsScoreCache, beamWidth, overflowRatio, diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java index 088f9a1af..1b08876a4 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java @@ -61,6 +61,9 @@ default int size() { */ NodesIterator getNodes(int level); + /** Return the dimension of the vectors in the graph */ + int getDimension(); + /** * Return a View with which to navigate the graph. Views are not threadsafe -- that is, * only one search at a time should be run per View. diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index 9b1194b1c..c79848f3f 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -77,6 +77,7 @@ public class OnHeapGraphIndex implements MutableGraphIndex { private final CompletionTracker completions; private final ThreadSafeGrowableBitSet deletedNodes = new ThreadSafeGrowableBitSet(0); private final AtomicInteger maxNodeId = new AtomicInteger(-1); + private final int dimension; // Maximum number of neighbors (edges) per node per layer final List maxDegrees; @@ -86,9 +87,10 @@ public class OnHeapGraphIndex implements MutableGraphIndex { private volatile boolean allMutationsCompleted = false; - OnHeapGraphIndex(List maxDegrees, double overflowRatio, DiversityProvider diversityProvider) { + OnHeapGraphIndex(int dimension, List maxDegrees, double overflowRatio, DiversityProvider diversityProvider) { this.overflowRatio = overflowRatio; this.maxDegrees = new IntArrayList(); + this.dimension = dimension; setDegrees(maxDegrees); entryPoint = new AtomicReference<>(); this.completions = new CompletionTracker(1024); @@ -235,6 +237,11 @@ public NodesIterator getNodes(int level) { layers.get(level).size()); } + @Override + public int getDimension() { + return dimension; + } + @Override public IntStream nodeStream(int level) { var layer = layers.get(level); @@ -639,6 +646,7 @@ public static OnHeapGraphIndex convertToHeap(ImmutableGraphIndex immutableGraphI } OnHeapGraphIndex heapIndex = new OnHeapGraphIndex( + immutableGraphIndex.getDimension(), maxDegrees, overflowRatio, // overflow ratio new VamanaDiversityProvider(bsp, alpha) // diversity provider - can be null for basic usage diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java index a597aa78f..7365be63e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java @@ -225,6 +225,7 @@ public Set getFeatureSet() { return features.keySet(); } + @Override public int getDimension() { return dimension; } From 4cb1353c8b2dd13feec6383bde864b839bd54114 Mon Sep 17 00:00:00 2001 From: Samuel Herman Date: Tue, 7 Oct 2025 09:05:06 -0700 Subject: [PATCH 05/12] add dimension Signed-off-by: Samuel Herman --- .../test/java/io/github/jbellis/jvector/TestUtil.java | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java index 21de0fede..127106bf4 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java @@ -280,6 +280,11 @@ public NodesIterator getNodes(int level) { return new NodesIterator.ArrayNodesIterator(IntStream.range(0, n).toArray(), n); } + @Override + public int getDimension() { + throw new NotImplementedException(); + } + @Override public View getView() { return new FullyConnectedGraphIndexView(); @@ -399,6 +404,11 @@ public NodesIterator getNodes(int level) { return new NodesIterator.ArrayNodesIterator(IntStream.range(0, sz).toArray(), sz); } + @Override + public int getDimension() { + throw new NotImplementedException(); + } + @Override public View getView() { return new RandomlyConnectedGraphIndexView(); From c1a45c6af16624c312b76aed6c1aee3202e7206c Mon Sep 17 00:00:00 2001 From: Mariano Tepper Date: Tue, 7 Oct 2025 16:03:46 -0700 Subject: [PATCH 06/12] Remove NeighborsCache. Refactoring to respect interface boundaries --- .../jvector/graph/ConcurrentNeighborMap.java | 6 +- .../jvector/graph/GraphIndexBuilder.java | 78 ++++++------ .../jvector/graph/GraphIndexConverter.java | 73 +++++++++++ .../jvector/graph/ImmutableGraphIndex.java | 3 - .../jvector/graph/OnHeapGraphIndex.java | 90 +------------- .../graph/disk/NeighborsScoreCache.java | 117 ------------------ .../jvector/graph/disk/OnDiskGraphIndex.java | 5 - .../graph/similarity/BuildScoreProvider.java | 1 + .../io/github/jbellis/jvector/TestUtil.java | 10 -- .../jvector/graph/OnHeapGraphIndexTest.java | 26 +--- 10 files changed, 120 insertions(+), 289 deletions(-) create mode 100644 jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexConverter.java delete mode 100644 jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/NeighborsScoreCache.java diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java index ad6d137a0..891fda756 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java @@ -351,7 +351,7 @@ public NeighborWithShortEdges(Neighbors neighbors, double shortEdges) { } } - public static class NeighborIterator implements NodesIterator { + private static class NeighborIterator implements NodesIterator { private final NodeArray neighbors; private int i; @@ -374,9 +374,5 @@ public boolean hasNext() { public int nextInt() { return neighbors.getNode(i++); } - - public NodeArray merge(NodeArray other) { - return NodeArray.merge(neighbors, other); - } } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index 9be45a7d6..aeca2e522 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -20,8 +20,6 @@ import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.ImmutableGraphIndex.NodeAtLevel; import io.github.jbellis.jvector.graph.SearchResult.NodeScore; -import io.github.jbellis.jvector.graph.disk.NeighborsScoreCache; -import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.diversity.VamanaDiversityProvider; import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; @@ -32,7 +30,6 @@ import io.github.jbellis.jvector.util.PhysicalCoreExecutor; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.types.VectorFloat; -import org.agrona.collections.IntArrayList; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -327,7 +324,8 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, this.simdExecutor = simdExecutor; this.parallelExecutor = parallelExecutor; - this.graph = new OnHeapGraphIndex(dimension, maxDegrees, neighborOverflow, new VamanaDiversityProvider(scoreProvider, alpha)); + this.graph = new OnHeapGraphIndex(maxDegrees, neighborOverflow, new VamanaDiversityProvider(scoreProvider, alpha)); + this.searchers = ExplicitThreadLocal.withInitial(() -> { var gs = new GraphSearcher(graph); gs.usePruning(false); @@ -346,9 +344,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, * copy it into {@link OnHeapGraphIndex} and then start mutating it with minimal overhead of recreating the mutable {@link OnHeapGraphIndex} used in the new GraphIndexBuilder object * * @param buildScoreProvider the provider responsible for calculating build scores. - * @param immutableGraphIndex the on-disk representation of the graph index to be processed and converted. - * @param perLevelNeighborsScoreCache the cache containing pre-computed neighbor scores, - * organized by levels and nodes. + * @param mutableGraphIndex a mutable graph index. * @param beamWidth the width of the beam used during the graph building process. * @param neighborOverflow the factor determining how many additional neighbors are allowed beyond the configured limit. * @param alpha the weight factor for balancing score computations. @@ -359,10 +355,20 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, * * @throws IOException if an I/O error occurs during the graph loading or conversion process. */ - public GraphIndexBuilder(BuildScoreProvider buildScoreProvider, ImmutableGraphIndex immutableGraphIndex, NeighborsScoreCache perLevelNeighborsScoreCache, int beamWidth, float neighborOverflow, float alpha, boolean addHierarchy, boolean refineFinalGraph, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) throws IOException { + private GraphIndexBuilder(BuildScoreProvider buildScoreProvider, int dimension, MutableGraphIndex mutableGraphIndex, int beamWidth, float neighborOverflow, float alpha, boolean addHierarchy, boolean refineFinalGraph, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) throws IOException { + if (beamWidth <= 0) { + throw new IllegalArgumentException("beamWidth must be positive"); + } + if (neighborOverflow < 1.0f) { + throw new IllegalArgumentException("neighborOverflow must be >= 1.0"); + } + if (alpha <= 0) { + throw new IllegalArgumentException("alpha must be positive"); + } + this.scoreProvider = buildScoreProvider; this.neighborOverflow = neighborOverflow; - this.dimension = immutableGraphIndex.getDimension(); + this.dimension = dimension; this.alpha = alpha; this.addHierarchy = addHierarchy; this.refineFinalGraph = refineFinalGraph; @@ -370,7 +376,7 @@ public GraphIndexBuilder(BuildScoreProvider buildScoreProvider, ImmutableGraphIn this.simdExecutor = simdExecutor; this.parallelExecutor = parallelExecutor; - this.graph = OnHeapGraphIndex.convertToHeap(immutableGraphIndex, perLevelNeighborsScoreCache, buildScoreProvider, neighborOverflow, alpha); + this.graph = mutableGraphIndex; this.searchers = ExplicitThreadLocal.withInitial(() -> { var gs = new GraphSearcher(graph); @@ -497,7 +503,7 @@ public void cleanup() { // clean up overflowed neighbor lists parallelExecutor.submit(() -> { IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(id -> { - for (int layer = 0; layer <= graph.getMaxLevel(); layer++) { + for (int level = 0; level <= graph.getMaxLevel(); level++) { graph.enforceDegree(id); } }); @@ -797,7 +803,7 @@ public synchronized long removeDeletedNodes() { return memorySize; } - private void updateNeighbors(int layer, int nodeId, NodeArray natural, NodeArray concurrent) { + private void updateNeighbors(int level, int nodeId, NodeArray natural, NodeArray concurrent) { // if either natural or concurrent is empty, skip the merge NodeArray toMerge; if (concurrent.size() == 0) { @@ -808,7 +814,7 @@ private void updateNeighbors(int layer, int nodeId, NodeArray natural, NodeArray toMerge = NodeArray.merge(natural, concurrent); } // toMerge may be approximate-scored, but insertDiverse will compute exact scores for the diverse ones - graph.addEdges(layer, nodeId, toMerge, neighborOverflow); + graph.addEdges(level, nodeId, toMerge, neighborOverflow); } private static NodeArray toScratchCandidates(NodeScore[] candidates, NodeArray scratch) { @@ -923,7 +929,6 @@ private void loadV4(RandomAccessReader in) throws IOException { graph.updateEntryNode(new NodeAtLevel(graph.getMaxLevel(), entryNode)); } - @Deprecated private void loadV3(RandomAccessReader in, int size) throws IOException { if (graph.size() != 0) { @@ -963,7 +968,6 @@ private void loadV3(RandomAccessReader in, int size) throws IOException { * This is useful when we want to merge a new set of vectors into an existing graph that is already on disk. * * @param immutableGraphIndex the immutable (usually on-disk) representation of the graph index to be processed and converted. - * @param perLevelNeighborsScoreCache the cache containing pre-computed neighbor scores, * @param newVectors a super set RAVV containing the new vectors to be added to the graph as well as the old ones that are already in the graph * @param buildScoreProvider the provider responsible for calculating build scores. * @param startingNodeOffset the offset in the newVectors RAVV where the new vectors start @@ -977,28 +981,28 @@ private void loadV3(RandomAccessReader in, int size) throws IOException { * @throws IOException if an I/O error occurs during the graph loading or conversion process. */ public static ImmutableGraphIndex buildAndMergeNewNodes(ImmutableGraphIndex immutableGraphIndex, - NeighborsScoreCache perLevelNeighborsScoreCache, - RandomAccessVectorValues newVectors, - BuildScoreProvider buildScoreProvider, - int startingNodeOffset, - int[] graphToRavvOrdMap, - int beamWidth, - float overflowRatio, - float alpha, - boolean addHierarchy) throws IOException { - - - - try (GraphIndexBuilder builder = new GraphIndexBuilder(buildScoreProvider, - immutableGraphIndex, - perLevelNeighborsScoreCache, - beamWidth, - overflowRatio, - alpha, - addHierarchy, - true, - PhysicalCoreExecutor.pool(), - ForkJoinPool.commonPool())) { + RandomAccessVectorValues newVectors, + BuildScoreProvider buildScoreProvider, + int startingNodeOffset, + int[] graphToRavvOrdMap, + int beamWidth, + float overflowRatio, + float alpha, + boolean addHierarchy) throws IOException { + + try (var graph = GraphIndexConverter.convertToHeap(immutableGraphIndex, buildScoreProvider, overflowRatio, alpha)) { + GraphIndexBuilder builder = new GraphIndexBuilder( + buildScoreProvider, + newVectors.dimension(), + graph, + beamWidth, + overflowRatio, + alpha, + addHierarchy, + true, + PhysicalCoreExecutor.pool(), + ForkJoinPool.commonPool() + ); var vv = newVectors.threadLocalSupplier(); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexConverter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexConverter.java new file mode 100644 index 000000000..684db71b2 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexConverter.java @@ -0,0 +1,73 @@ +package io.github.jbellis.jvector.graph; + +import io.github.jbellis.jvector.graph.diversity.VamanaDiversityProvider; +import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static java.lang.Math.max; + +class GraphIndexConverter { + /** + * Converts an OnDiskGraphIndex to an OnHeapGraphIndex by copying all nodes, their levels, and neighbors, + * along with other configuration details, from disk-based storage to heap-based storage. + * + * @param immutableGraphIndex the disk-based index to be converted + * @param bsp The build score provider to be used for + * @param overflowRatio usually 1.2f + * @param alpha usually 1.2f + * @return an OnHeapGraphIndex that is equivalent to the provided OnDiskGraphIndex but operates in heap memory + * @throws IOException if an I/O error occurs during the conversion process + */ + public static MutableGraphIndex convertToHeap(ImmutableGraphIndex immutableGraphIndex, + BuildScoreProvider bsp, + float overflowRatio, + float alpha) throws IOException { + + // Create a new OnHeapGraphIndex with the appropriate configuration + List maxDegrees = new ArrayList<>(); + for (int level = 0; level <= immutableGraphIndex.getMaxLevel(); level++) { + maxDegrees.add(immutableGraphIndex.getDegree(level)); + } + + MutableGraphIndex index = new OnHeapGraphIndex( + maxDegrees, + overflowRatio, // overflow ratio + new VamanaDiversityProvider(bsp, alpha) // diversity provider - can be null for basic usage + ); + + // Copy all nodes and their connections from disk to heap + try (var view = immutableGraphIndex.getView()) { + // Copy nodes level by level + for (int level = 0; level <= immutableGraphIndex.getMaxLevel(); level++) { + final NodesIterator nodesIterator = immutableGraphIndex.getNodes(level); + + while (nodesIterator.hasNext()) { + int nodeId = nodesIterator.next(); + + var sf = bsp.searchProviderFor(nodeId).scoreFunction(); + + var neighborsIterator = view.getNeighborsIterator(level, nodeId); + + NodeArray nodeArray = new NodeArray(neighborsIterator.size()); + while(neighborsIterator.hasNext()) { + int neighbor = neighborsIterator.nextInt(); + float score = sf.similarityTo(neighbor); + nodeArray.addInOrder(neighbor, score); + } + + // Add the node with its neighbors + index.connectNode(level, nodeId, nodeArray); + index.markComplete(new ImmutableGraphIndex.NodeAtLevel(level, nodeId)); + } + } + + // Set the entry point + index.updateEntryNode(view.entryNode()); + } + + return index; + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java index 1b08876a4..088f9a1af 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java @@ -61,9 +61,6 @@ default int size() { */ NodesIterator getNodes(int level); - /** Return the dimension of the vectors in the graph */ - int getDimension(); - /** * Return a View with which to navigate the graph. Views are not threadsafe -- that is, * only one search at a time should be run per View. diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index c79848f3f..7ddbf7897 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -25,8 +25,6 @@ package io.github.jbellis.jvector.graph; import io.github.jbellis.jvector.graph.ConcurrentNeighborMap.Neighbors; -import io.github.jbellis.jvector.graph.disk.NeighborsScoreCache; -import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.diversity.DiversityProvider; import io.github.jbellis.jvector.util.Accountable; import io.github.jbellis.jvector.util.BitSet; @@ -35,11 +33,6 @@ import io.github.jbellis.jvector.util.RamUsageEstimator; import io.github.jbellis.jvector.util.SparseIntMap; import io.github.jbellis.jvector.util.ThreadSafeGrowableBitSet; -import io.github.jbellis.jvector.graph.diversity.VamanaDiversityProvider; -import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; -import io.github.jbellis.jvector.util.*; -import io.github.jbellis.jvector.vector.VectorSimilarityFunction; -import io.github.jbellis.jvector.vector.types.VectorFloat; import org.agrona.collections.IntArrayList; import java.io.DataOutput; @@ -47,10 +40,7 @@ import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.NoSuchElementException; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicIntegerArray; import java.util.concurrent.atomic.AtomicReference; @@ -77,7 +67,6 @@ public class OnHeapGraphIndex implements MutableGraphIndex { private final CompletionTracker completions; private final ThreadSafeGrowableBitSet deletedNodes = new ThreadSafeGrowableBitSet(0); private final AtomicInteger maxNodeId = new AtomicInteger(-1); - private final int dimension; // Maximum number of neighbors (edges) per node per layer final List maxDegrees; @@ -87,10 +76,9 @@ public class OnHeapGraphIndex implements MutableGraphIndex { private volatile boolean allMutationsCompleted = false; - OnHeapGraphIndex(int dimension, List maxDegrees, double overflowRatio, DiversityProvider diversityProvider) { + OnHeapGraphIndex(List maxDegrees, double overflowRatio, DiversityProvider diversityProvider) { this.overflowRatio = overflowRatio; this.maxDegrees = new IntArrayList(); - this.dimension = dimension; setDegrees(maxDegrees); entryPoint = new AtomicReference<>(); this.completions = new CompletionTracker(1024); @@ -237,11 +225,6 @@ public NodesIterator getNodes(int level) { layers.get(level).size()); } - @Override - public int getDimension() { - return dimension; - } - @Override public IntStream nodeStream(int level) { var layer = layers.get(level); @@ -310,10 +293,6 @@ public View getView() { } } - public FrozenView getFrozenView() { - return new FrozenView(); - } - public ThreadSafeGrowableBitSet getDeletedNodes() { return deletedNodes; } @@ -462,7 +441,7 @@ public boolean hasNext() { } } - public class FrozenView implements View { + private class FrozenView implements View { @Override public NodesIterator getNeighborsIterator(int level, int node) { return OnHeapGraphIndex.this.getNeighborsIterator(level, node); @@ -619,69 +598,4 @@ private void ensureCapacity(int node) { } } } - - /** - * Converts an OnDiskGraphIndex to an OnHeapGraphIndex by copying all nodes, their levels, and neighbors, - * along with other configuration details, from disk-based storage to heap-based storage. - * - * @param immutableGraphIndex the disk-based index to be converted - * @param perLevelNeighborsScoreCache the cache containing pre-computed neighbor scores, - * organized by levels and nodes. - * @param bsp The build score provider to be used for - * @param overflowRatio usually 1.2f - * @param alpha usually 1.2f - * @return an OnHeapGraphIndex that is equivalent to the provided OnDiskGraphIndex but operates in heap memory - * @throws IOException if an I/O error occurs during the conversion process - */ - public static OnHeapGraphIndex convertToHeap(ImmutableGraphIndex immutableGraphIndex, - NeighborsScoreCache perLevelNeighborsScoreCache, - BuildScoreProvider bsp, - float overflowRatio, - float alpha) throws IOException { - - // Create a new OnHeapGraphIndex with the appropriate configuration - List maxDegrees = new ArrayList<>(); - for (int level = 0; level <= immutableGraphIndex.getMaxLevel(); level++) { - maxDegrees.add(immutableGraphIndex.getDegree(level)); - } - - OnHeapGraphIndex heapIndex = new OnHeapGraphIndex( - immutableGraphIndex.getDimension(), - maxDegrees, - overflowRatio, // overflow ratio - new VamanaDiversityProvider(bsp, alpha) // diversity provider - can be null for basic usage - ); - - // Copy all nodes and their connections from disk to heap - try (var view = immutableGraphIndex.getView()) { - // Copy nodes level by level - for (int level = 0; level <= immutableGraphIndex.getMaxLevel(); level++) { - final NodesIterator nodesIterator = immutableGraphIndex.getNodes(level); - final Map levelNeighborsScoreCache = perLevelNeighborsScoreCache.getNeighborsScoresInLevel(level); - if (levelNeighborsScoreCache == null) { - throw new IllegalStateException("No neighbors score cache found for level " + level); - } - if (nodesIterator.size() != levelNeighborsScoreCache.size()) { - throw new IllegalStateException("Neighbors score cache size mismatch for level " + level + - ". Expected (currently in index): " + nodesIterator.size() + ", but got (in cache): " + levelNeighborsScoreCache.size()); - } - - while (nodesIterator.hasNext()) { - int nodeId = nodesIterator.next(); - - // Copy neighbors - final NodeArray neighbors = levelNeighborsScoreCache.get(nodeId).copy(); - - // Add the node with its neighbors - heapIndex.connectNode(level, nodeId, neighbors); - heapIndex.markComplete(new NodeAtLevel(level, nodeId)); - } - } - - // Set the entry point - heapIndex.updateEntryNode(view.entryNode()); - } - - return heapIndex; - } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/NeighborsScoreCache.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/NeighborsScoreCache.java deleted file mode 100644 index 55fdcf082..000000000 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/NeighborsScoreCache.java +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Copyright DataStax, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.github.jbellis.jvector.graph.disk; - -import io.github.jbellis.jvector.disk.IndexWriter; -import io.github.jbellis.jvector.disk.RandomAccessReader; -import io.github.jbellis.jvector.graph.*; -import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; - -/** - * Cache containing pre-computed neighbor scores, organized by levels and nodes. - *

- * This cache bridges the gap between {@link OnDiskGraphIndex} and {@link OnHeapGraphIndex}: - *

    - *
  • {@link OnDiskGraphIndex} stores only neighbor IDs (not scores) for space efficiency
  • - *
  • {@link OnHeapGraphIndex} requires neighbor scores for pruning operations
  • - *
- *

- * When converting from disk to heap representation, this cache avoids expensive score - * recomputation by providing pre-calculated neighbor scores for all graph levels. - * - * @see OnHeapGraphIndex#convertToHeap(OnDiskGraphIndex, NeighborsScoreCache, BuildScoreProvider, float, float) - * - * This is particularly useful when merging new nodes into an existing graph. - * @see GraphIndexBuilder#buildAndMergeNewNodes(OnDiskGraphIndex, NeighborsScoreCache, RandomAccessVectorValues, BuildScoreProvider, int, int[], int, float, float, boolean) - */ -public class NeighborsScoreCache { - private final Map> perLevelNeighborsScoreCache; - - public NeighborsScoreCache(OnHeapGraphIndex graphIndex) throws IOException { - try (OnHeapGraphIndex.FrozenView view = graphIndex.getFrozenView()) { - final Map> perLevelNeighborsScoreCache = new HashMap<>(graphIndex.getMaxLevel() + 1); - for (int level = 0; level <= graphIndex.getMaxLevel(); level++) { - final Map levelNeighborsScores = new HashMap<>(graphIndex.size(level) + 1); - final NodesIterator nodesIterator = graphIndex.getNodes(level); - while (nodesIterator.hasNext()) { - final int nodeId = nodesIterator.nextInt(); - - ConcurrentNeighborMap.NeighborIterator neighborIterator = (ConcurrentNeighborMap.NeighborIterator) view.getNeighborsIterator(level, nodeId); - final NodeArray neighbours = neighborIterator.merge(new NodeArray(neighborIterator.size())); - levelNeighborsScores.put(nodeId, neighbours); - } - - perLevelNeighborsScoreCache.put(level, levelNeighborsScores); - } - - this.perLevelNeighborsScoreCache = perLevelNeighborsScoreCache; - } - } - - public NeighborsScoreCache(RandomAccessReader in) throws IOException { - final int numberOfLevels = in.readInt(); - perLevelNeighborsScoreCache = new HashMap<>(numberOfLevels); - for (int i = 0; i < numberOfLevels; i++) { - final int level = in.readInt(); - final int numberOfNodesInLevel = in.readInt(); - final Map levelNeighborsScores = new HashMap<>(numberOfNodesInLevel); - for (int j = 0; j < numberOfNodesInLevel; j++) { - final int nodeId = in.readInt(); - final int numberOfNeighbors = in.readInt(); - final NodeArray nodeArray = new NodeArray(numberOfNeighbors); - for (int k = 0; k < numberOfNeighbors; k++) { - final int neighborNodeId = in.readInt(); - final float neighborScore = in.readFloat(); - nodeArray.insertSorted(neighborNodeId, neighborScore); - } - levelNeighborsScores.put(nodeId, nodeArray); - } - perLevelNeighborsScoreCache.put(level, levelNeighborsScores); - } - } - - public void write(IndexWriter out) throws IOException { - out.writeInt(perLevelNeighborsScoreCache.size()); // write the number of levels - for (Map.Entry> levelNeighborsScores : perLevelNeighborsScoreCache.entrySet()) { - final int level = levelNeighborsScores.getKey(); - out.writeInt(level); - out.writeInt(levelNeighborsScores.getValue().size()); // write the number of nodes in the level - // Write the neighborhoods for each node in the level - for (Map.Entry nodeArrayEntry : levelNeighborsScores.getValue().entrySet()) { - final int nodeId = nodeArrayEntry.getKey(); - out.writeInt(nodeId); - final NodeArray nodeArray = nodeArrayEntry.getValue(); - out.writeInt(nodeArray.size()); // write the number of neighbors for the node - // Write the nodeArray(neighbors) - for (int i = 0; i < nodeArray.size(); i++) { - out.writeInt(nodeArray.getNode(i)); - out.writeFloat(nodeArray.getScore(i)); - } - } - } - } - - public Map getNeighborsScoresInLevel(int level) { - return perLevelNeighborsScoreCache.get(level); - } - - -} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java index 7365be63e..161fb0f07 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java @@ -225,11 +225,6 @@ public Set getFeatureSet() { return features.keySet(); } - @Override - public int getDimension() { - return dimension; - } - @Override public int size(int level) { return layerInfo.get(level).size; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java index 4656a4fcc..0ffdf72eb 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java @@ -94,6 +94,7 @@ static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues rav /** * Returns a BSP that performs exact score comparisons using the given RandomAccessVectorValues and VectorSimilarityFunction. + * graphToRavvOrdMap maps graph node IDs to ravv ordinals. */ static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap, VectorSimilarityFunction similarityFunction) { // We need two sources of vectors in order to perform diversity check comparisons without diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java index 127106bf4..21de0fede 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java @@ -280,11 +280,6 @@ public NodesIterator getNodes(int level) { return new NodesIterator.ArrayNodesIterator(IntStream.range(0, n).toArray(), n); } - @Override - public int getDimension() { - throw new NotImplementedException(); - } - @Override public View getView() { return new FullyConnectedGraphIndexView(); @@ -404,11 +399,6 @@ public NodesIterator getNodes(int level) { return new NodesIterator.ArrayNodesIterator(IntStream.range(0, sz).toArray(), sz); } - @Override - public int getDimension() { - throw new NotImplementedException(); - } - @Override public View getView() { return new RandomlyConnectedGraphIndexView(); diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java index 385588283..91bac66d4 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java @@ -20,8 +20,6 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import io.github.jbellis.jvector.TestUtil; import io.github.jbellis.jvector.disk.SimpleMappedReader; -import io.github.jbellis.jvector.disk.SimpleWriter; -import io.github.jbellis.jvector.graph.disk.NeighborsScoreCache; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; @@ -46,7 +44,6 @@ import java.util.stream.IntStream; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class OnHeapGraphIndexTest extends RandomizedTest { @@ -72,13 +69,10 @@ public class OnHeapGraphIndexTest extends RandomizedTest { private RandomAccessVectorValues newVectorsRavv; private RandomAccessVectorValues allVectorsRavv; private VectorFloat queryVector; - private int[] groundTruthBaseVectors; private int[] groundTruthAllVectors; private BuildScoreProvider baseBuildScoreProvider; - private BuildScoreProvider newBuildScoreProvider; private BuildScoreProvider allBuildScoreProvider; private ImmutableGraphIndex baseGraphIndex; - private ImmutableGraphIndex newGraphIndex; private ImmutableGraphIndex allGraphIndex; @Before @@ -104,12 +98,10 @@ public void setup() throws IOException { allVectorsRavv = new ListRandomAccessVectorValues(allVectors, DIMENSION); queryVector = createRandomVector(DIMENSION); - groundTruthBaseVectors = getGroundTruth(baseVectorsRavv, queryVector, TOP_K, VectorSimilarityFunction.EUCLIDEAN); groundTruthAllVectors = getGroundTruth(allVectorsRavv, queryVector, TOP_K, VectorSimilarityFunction.EUCLIDEAN); // score provider using the raw, in-memory vectors baseBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(baseVectorsRavv, VectorSimilarityFunction.EUCLIDEAN); - newBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(newVectorsRavv, VectorSimilarityFunction.EUCLIDEAN); allBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(allVectorsRavv, VectorSimilarityFunction.EUCLIDEAN); var baseGraphIndexBuilder = new GraphIndexBuilder(baseBuildScoreProvider, baseVectorsRavv.dimension(), @@ -144,22 +136,9 @@ public void tearDown() { @Test public void testReconstructionOfOnHeapGraphIndex() throws IOException { var graphOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName()); - var neighborsScoreCacheOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + NeighborsScoreCache.class.getSimpleName()); log.info("Writing graph to {}", graphOutputPath); TestUtil.writeGraph(baseGraphIndex, baseVectorsRavv, graphOutputPath); - log.info("Writing neighbors score cache to {}", neighborsScoreCacheOutputPath); - final NeighborsScoreCache neighborsScoreCache = new NeighborsScoreCache((OnHeapGraphIndex) baseGraphIndex); - try (SimpleWriter writer = new SimpleWriter(neighborsScoreCacheOutputPath.toAbsolutePath())) { - neighborsScoreCache.write(writer); - } - - log.info("Reading neighbors score cache from {}", neighborsScoreCacheOutputPath); - final NeighborsScoreCache neighborsScoreCacheRead; - try (var readerSupplier = new SimpleMappedReader.Supplier(neighborsScoreCacheOutputPath.toAbsolutePath())) { - neighborsScoreCacheRead = new NeighborsScoreCache(readerSupplier.get()); - } - try (var readerSupplier = new SimpleMappedReader.Supplier(graphOutputPath.toAbsolutePath()); var onDiskGraph = OnDiskGraphIndex.load(readerSupplier)) { TestUtil.assertGraphEquals(baseGraphIndex, onDiskGraph); @@ -167,7 +146,7 @@ public void testReconstructionOfOnHeapGraphIndex() throws IOException { validateVectors(onDiskView, baseVectorsRavv); } - OnHeapGraphIndex reconstructedOnHeapGraphIndex = OnHeapGraphIndex.convertToHeap(onDiskGraph, neighborsScoreCacheRead, baseBuildScoreProvider, NEIGHBOR_OVERFLOW, ALPHA); + MutableGraphIndex reconstructedOnHeapGraphIndex = GraphIndexConverter.convertToHeap(onDiskGraph, baseBuildScoreProvider, NEIGHBOR_OVERFLOW, ALPHA); TestUtil.assertGraphEquals(baseGraphIndex, reconstructedOnHeapGraphIndex); TestUtil.assertGraphEquals(onDiskGraph, reconstructedOnHeapGraphIndex); @@ -182,14 +161,13 @@ public void testReconstructionOfOnHeapGraphIndex() throws IOException { public void testIncrementalInsertionFromOnDiskIndex() throws IOException { var outputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName()); log.info("Writing graph to {}", outputPath); - final NeighborsScoreCache neighborsScoreCache = new NeighborsScoreCache((OnHeapGraphIndex) baseGraphIndex); TestUtil.writeGraph(baseGraphIndex, baseVectorsRavv, outputPath); try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath.toAbsolutePath()); var onDiskGraph = OnDiskGraphIndex.load(readerSupplier)) { TestUtil.assertGraphEquals(baseGraphIndex, onDiskGraph); // We will create a trivial 1:1 mapping between the new graph and the ravv final int[] graphToRavvOrdMap = IntStream.range(0, allVectorsRavv.size()).toArray(); - ImmutableGraphIndex reconstructedAllNodeOnHeapGraphIndex = GraphIndexBuilder.buildAndMergeNewNodes(onDiskGraph, neighborsScoreCache, allVectorsRavv, allBuildScoreProvider, NUM_BASE_VECTORS, graphToRavvOrdMap, BEAM_WIDTH, NEIGHBOR_OVERFLOW, ALPHA, ADD_HIERARCHY); + ImmutableGraphIndex reconstructedAllNodeOnHeapGraphIndex = GraphIndexBuilder.buildAndMergeNewNodes(onDiskGraph, allVectorsRavv, allBuildScoreProvider, NUM_BASE_VECTORS, graphToRavvOrdMap, BEAM_WIDTH, NEIGHBOR_OVERFLOW, ALPHA, ADD_HIERARCHY); // Verify that the recall is similar float recallFromReconstructedAllNodeOnHeapGraphIndex = calculateRecall(reconstructedAllNodeOnHeapGraphIndex, allBuildScoreProvider, queryVector, groundTruthAllVectors, TOP_K); From 3764116ae6b3003548fb26caffbb4a2117c89101 Mon Sep 17 00:00:00 2001 From: Mariano Tepper Date: Wed, 8 Oct 2025 14:18:49 -0700 Subject: [PATCH 07/12] Implementation that serializes/deserializes the OnHeapGraphIndex --- .../jvector/graph/GraphIndexBuilder.java | 16 ++- .../jvector/graph/GraphIndexConverter.java | 73 ---------- .../jvector/graph/MutableGraphIndex.java | 8 +- .../jvector/graph/OnHeapGraphIndex.java | 135 +++++++++++++----- .../jvector/graph/OnHeapGraphIndexTest.java | 32 ++++- 5 files changed, 142 insertions(+), 122 deletions(-) delete mode 100644 jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexConverter.java diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index aeca2e522..2158034a3 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -355,7 +355,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, * * @throws IOException if an I/O error occurs during the graph loading or conversion process. */ - private GraphIndexBuilder(BuildScoreProvider buildScoreProvider, int dimension, MutableGraphIndex mutableGraphIndex, int beamWidth, float neighborOverflow, float alpha, boolean addHierarchy, boolean refineFinalGraph, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) throws IOException { + private GraphIndexBuilder(BuildScoreProvider buildScoreProvider, int dimension, MutableGraphIndex mutableGraphIndex, int beamWidth, float neighborOverflow, float alpha, boolean addHierarchy, boolean refineFinalGraph, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) { if (beamWidth <= 0) { throw new IllegalArgumentException("beamWidth must be positive"); } @@ -509,7 +509,7 @@ public void cleanup() { }); }).join(); - graph.allMutationsCompleted(); + graph.setAllMutationsCompleted(); } private void improveConnections(int node) { @@ -878,6 +878,9 @@ public void load(RandomAccessReader in) throws IOException { loadV3(in, size); } else { version = in.readInt(); + if (version != 4) { + throw new IOException("Unsupported version: " + version); + } loadV4(in); } } @@ -967,7 +970,7 @@ private void loadV3(RandomAccessReader in, int size) throws IOException { * Convenience method to build a new graph from an existing one, with the addition of new nodes. * This is useful when we want to merge a new set of vectors into an existing graph that is already on disk. * - * @param immutableGraphIndex the immutable (usually on-disk) representation of the graph index to be processed and converted. + * @param in a reader from which to read the on-heap graph. * @param newVectors a super set RAVV containing the new vectors to be added to the graph as well as the old ones that are already in the graph * @param buildScoreProvider the provider responsible for calculating build scores. * @param startingNodeOffset the offset in the newVectors RAVV where the new vectors start @@ -980,7 +983,7 @@ private void loadV3(RandomAccessReader in, int size) throws IOException { * @return the in-memory representation of the graph index. * @throws IOException if an I/O error occurs during the graph loading or conversion process. */ - public static ImmutableGraphIndex buildAndMergeNewNodes(ImmutableGraphIndex immutableGraphIndex, + public static ImmutableGraphIndex buildAndMergeNewNodes(RandomAccessReader in, RandomAccessVectorValues newVectors, BuildScoreProvider buildScoreProvider, int startingNodeOffset, @@ -990,7 +993,10 @@ public static ImmutableGraphIndex buildAndMergeNewNodes(ImmutableGraphIndex immu float alpha, boolean addHierarchy) throws IOException { - try (var graph = GraphIndexConverter.convertToHeap(immutableGraphIndex, buildScoreProvider, overflowRatio, alpha)) { + var diversityProvider = new VamanaDiversityProvider(buildScoreProvider, alpha); + + try (MutableGraphIndex graph = OnHeapGraphIndex.load(in, overflowRatio, diversityProvider);) { + GraphIndexBuilder builder = new GraphIndexBuilder( buildScoreProvider, newVectors.dimension(), diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexConverter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexConverter.java deleted file mode 100644 index 684db71b2..000000000 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexConverter.java +++ /dev/null @@ -1,73 +0,0 @@ -package io.github.jbellis.jvector.graph; - -import io.github.jbellis.jvector.graph.diversity.VamanaDiversityProvider; -import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -import static java.lang.Math.max; - -class GraphIndexConverter { - /** - * Converts an OnDiskGraphIndex to an OnHeapGraphIndex by copying all nodes, their levels, and neighbors, - * along with other configuration details, from disk-based storage to heap-based storage. - * - * @param immutableGraphIndex the disk-based index to be converted - * @param bsp The build score provider to be used for - * @param overflowRatio usually 1.2f - * @param alpha usually 1.2f - * @return an OnHeapGraphIndex that is equivalent to the provided OnDiskGraphIndex but operates in heap memory - * @throws IOException if an I/O error occurs during the conversion process - */ - public static MutableGraphIndex convertToHeap(ImmutableGraphIndex immutableGraphIndex, - BuildScoreProvider bsp, - float overflowRatio, - float alpha) throws IOException { - - // Create a new OnHeapGraphIndex with the appropriate configuration - List maxDegrees = new ArrayList<>(); - for (int level = 0; level <= immutableGraphIndex.getMaxLevel(); level++) { - maxDegrees.add(immutableGraphIndex.getDegree(level)); - } - - MutableGraphIndex index = new OnHeapGraphIndex( - maxDegrees, - overflowRatio, // overflow ratio - new VamanaDiversityProvider(bsp, alpha) // diversity provider - can be null for basic usage - ); - - // Copy all nodes and their connections from disk to heap - try (var view = immutableGraphIndex.getView()) { - // Copy nodes level by level - for (int level = 0; level <= immutableGraphIndex.getMaxLevel(); level++) { - final NodesIterator nodesIterator = immutableGraphIndex.getNodes(level); - - while (nodesIterator.hasNext()) { - int nodeId = nodesIterator.next(); - - var sf = bsp.searchProviderFor(nodeId).scoreFunction(); - - var neighborsIterator = view.getNeighborsIterator(level, nodeId); - - NodeArray nodeArray = new NodeArray(neighborsIterator.size()); - while(neighborsIterator.hasNext()) { - int neighbor = neighborsIterator.nextInt(); - float score = sf.similarityTo(neighbor); - nodeArray.addInOrder(neighbor, score); - } - - // Add the node with its neighbors - index.connectNode(level, nodeId, nodeArray); - index.markComplete(new ImmutableGraphIndex.NodeAtLevel(level, nodeId)); - } - } - - // Set the entry point - index.updateEntryNode(view.entryNode()); - } - - return index; - } -} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MutableGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MutableGraphIndex.java index 2e88e6dd4..1e30fcd2a 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MutableGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MutableGraphIndex.java @@ -166,5 +166,11 @@ interface MutableGraphIndex extends ImmutableGraphIndex { * Signals that all mutations have been completed and the graph will not be mutated any further. * Should be called by the builder after all mutations are completed (during cleanup). */ - void allMutationsCompleted(); + void setAllMutationsCompleted(); + + /** + * Signals that all mutations have been completed and the graph will not be mutated any further. + * Should be called by the builder after all mutations are completed (during cleanup). + */ + boolean allMutationsCompleted(); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index 7ddbf7897..976aa1bb7 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -24,6 +24,7 @@ package io.github.jbellis.jvector.graph; +import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.ConcurrentNeighborMap.Neighbors; import io.github.jbellis.jvector.graph.diversity.DiversityProvider; import io.github.jbellis.jvector.util.Accountable; @@ -37,9 +38,10 @@ import java.io.DataOutput; import java.io.IOException; -import java.io.UncheckedIOException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.NoSuchElementException; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicIntegerArray; @@ -367,10 +369,14 @@ public void setDegrees(List layerDegrees) { } @Override - public void allMutationsCompleted() { + public void setAllMutationsCompleted() { allMutationsCompleted = true; } + @Override + public boolean allMutationsCompleted() { + return allMutationsCompleted; + } /** * A concurrent View of the graph that is safe to search concurrently with updates and with other @@ -491,43 +497,98 @@ public String toString() { * Saves the graph to the given DataOutput for reloading into memory later */ @Deprecated - public void save(DataOutput out) { - if (deletedNodes.cardinality() > 0) { - throw new IllegalStateException("Cannot save a graph that has deleted nodes. Call cleanup() first"); - } - - try (var view = getView()) { - out.writeInt(OnHeapGraphIndex.MAGIC); // the magic number - out.writeInt(4); // The version - - // Write graph-level properties. - out.writeInt(layers.size()); - assert view.entryNode().level == getMaxLevel(); - out.writeInt(view.entryNode().node); - - for (int level = 0; level < layers.size(); level++) { - out.writeInt(size(level)); - out.writeInt(getDegree(level)); - - // Save neighbors from the layer. - var baseLayer = layers.get(level); - baseLayer.forEach((nodeId, neighbors) -> { - try { - NodesIterator iterator = neighbors.iterator(); - out.writeInt(nodeId); - out.writeInt(iterator.size()); - for (int n = 0; n < iterator.size(); n++) { - out.writeInt(iterator.nextInt()); - } - assert !iterator.hasNext(); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - }); + public void save(DataOutput out) throws IOException { + if (allMutationsCompleted()) { + throw new IllegalStateException("Cannot save a graph with pending mutations. Call cleanup() first"); + } + + out.writeInt(OnHeapGraphIndex.MAGIC); // the magic number + out.writeInt(4); // The version + + // Write graph-level properties. + out.writeInt(layers.size()); + for (int level = 0; level < layers.size(); level++) { + out.writeInt(getDegree(level)); + } + + var entryNode = entryPoint.get(); + assert entryNode.level == getMaxLevel(); + out.writeInt(entryNode.node); + + for (int level = 0; level < layers.size(); level++) { + out.writeInt(size(level)); + + // Save neighbors from the layer. + var it = nodeStream(level).iterator(); + while (it.hasNext()) { + int nodeId = it.nextInt(); + var neighbors = layers.get(level).get(nodeId); + out.writeInt(nodeId); + out.writeInt(neighbors.size()); + + for (int n = 0; n < neighbors.size(); n++) { + out.writeInt(neighbors.getNode(n)); + out.writeFloat(neighbors.getScore(n)); + } + } + } + } + + /** + * Saves the graph to the given DataOutput for reloading into memory later + */ + @Deprecated + public static OnHeapGraphIndex load(RandomAccessReader in, double overflowRatio, DiversityProvider diversityProvider) throws IOException { + int magic = in.readInt(); // the magic number + if (magic != OnHeapGraphIndex.MAGIC) { + throw new IOException("Unsupported magic number: " + magic); + } + + int version = in.readInt(); // The version + if (version != 4) { + throw new IOException("Unsupported version: " + version); + } + + // Write graph-level properties. + int layerCount = in.readInt(); + var layerDegrees = new ArrayList(layerCount); + for (int level = 0; level < layerCount; level++) { + layerDegrees.add(in.readInt()); + } + + int entryNode = in.readInt(); + + var graph = new OnHeapGraphIndex(layerDegrees, overflowRatio, diversityProvider); + + Map nodeLevelMap = new HashMap<>(); + + for (int level = 0; level < layerCount; level++) { + int layerSize = in.readInt(); + + for (int i = 0; i < layerSize; i++) { + int nodeId = in.readInt(); + int nNeighbors = in.readInt(); + + var ca = new NodeArray(nNeighbors); + for (int j = 0; j < nNeighbors; j++) { + int neighbor = in.readInt(); + float score = in.readFloat(); + ca.addInOrder(neighbor, score); + } + graph.connectNode(level, nodeId, ca); + nodeLevelMap.put(nodeId, level); } - } catch (IOException e) { - throw new UncheckedIOException(e); } + + for (var k : nodeLevelMap.keySet()) { + NodeAtLevel nal = new NodeAtLevel(nodeLevelMap.get(k), k); + graph.markComplete(nal); + } + + graph.setDegrees(layerDegrees); + graph.updateEntryNode(new NodeAtLevel(graph.getMaxLevel(), entryNode)); + + return graph; } /** diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java index 91bac66d4..c2cf9fc90 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java @@ -20,7 +20,9 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import io.github.jbellis.jvector.TestUtil; import io.github.jbellis.jvector.disk.SimpleMappedReader; +import io.github.jbellis.jvector.disk.SimpleWriter; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; +import io.github.jbellis.jvector.graph.diversity.VamanaDiversityProvider; import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; import io.github.jbellis.jvector.util.Bits; @@ -136,9 +138,22 @@ public void tearDown() { @Test public void testReconstructionOfOnHeapGraphIndex() throws IOException { var graphOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName()); + var heapGraphOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName() + "_onHeap"); + log.info("Writing graph to {}", graphOutputPath); TestUtil.writeGraph(baseGraphIndex, baseVectorsRavv, graphOutputPath); + log.info("Writing on-heap graph to {}", heapGraphOutputPath); + try (SimpleWriter writer = new SimpleWriter(heapGraphOutputPath.toAbsolutePath())) { + ((OnHeapGraphIndex) baseGraphIndex).save(writer); + } + + log.info("Reading on-heap graph from {}", heapGraphOutputPath); + MutableGraphIndex reconstructedOnHeapGraphIndex; + try (var readerSupplier = new SimpleMappedReader.Supplier(heapGraphOutputPath.toAbsolutePath())) { + reconstructedOnHeapGraphIndex = OnHeapGraphIndex.load(readerSupplier.get(), NEIGHBOR_OVERFLOW, new VamanaDiversityProvider(baseBuildScoreProvider, ALPHA)); + } + try (var readerSupplier = new SimpleMappedReader.Supplier(graphOutputPath.toAbsolutePath()); var onDiskGraph = OnDiskGraphIndex.load(readerSupplier)) { TestUtil.assertGraphEquals(baseGraphIndex, onDiskGraph); @@ -146,10 +161,8 @@ public void testReconstructionOfOnHeapGraphIndex() throws IOException { validateVectors(onDiskView, baseVectorsRavv); } - MutableGraphIndex reconstructedOnHeapGraphIndex = GraphIndexConverter.convertToHeap(onDiskGraph, baseBuildScoreProvider, NEIGHBOR_OVERFLOW, ALPHA); TestUtil.assertGraphEquals(baseGraphIndex, reconstructedOnHeapGraphIndex); TestUtil.assertGraphEquals(onDiskGraph, reconstructedOnHeapGraphIndex); - } } @@ -160,14 +173,21 @@ public void testReconstructionOfOnHeapGraphIndex() throws IOException { @Test public void testIncrementalInsertionFromOnDiskIndex() throws IOException { var outputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName()); + var heapGraphOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName() + "_onHeap"); + log.info("Writing graph to {}", outputPath); TestUtil.writeGraph(baseGraphIndex, baseVectorsRavv, outputPath); - try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath.toAbsolutePath()); - var onDiskGraph = OnDiskGraphIndex.load(readerSupplier)) { - TestUtil.assertGraphEquals(baseGraphIndex, onDiskGraph); + + log.info("Writing on-heap graph to {}", heapGraphOutputPath); + try (SimpleWriter writer = new SimpleWriter(heapGraphOutputPath.toAbsolutePath())) { + ((OnHeapGraphIndex) baseGraphIndex).save(writer); + } + + log.info("Reading on-heap graph from {}", heapGraphOutputPath); + try (var readerSupplier = new SimpleMappedReader.Supplier(heapGraphOutputPath.toAbsolutePath())) { // We will create a trivial 1:1 mapping between the new graph and the ravv final int[] graphToRavvOrdMap = IntStream.range(0, allVectorsRavv.size()).toArray(); - ImmutableGraphIndex reconstructedAllNodeOnHeapGraphIndex = GraphIndexBuilder.buildAndMergeNewNodes(onDiskGraph, allVectorsRavv, allBuildScoreProvider, NUM_BASE_VECTORS, graphToRavvOrdMap, BEAM_WIDTH, NEIGHBOR_OVERFLOW, ALPHA, ADD_HIERARCHY); + ImmutableGraphIndex reconstructedAllNodeOnHeapGraphIndex = GraphIndexBuilder.buildAndMergeNewNodes(readerSupplier.get(), allVectorsRavv, allBuildScoreProvider, NUM_BASE_VECTORS, graphToRavvOrdMap, BEAM_WIDTH, NEIGHBOR_OVERFLOW, ALPHA, ADD_HIERARCHY); // Verify that the recall is similar float recallFromReconstructedAllNodeOnHeapGraphIndex = calculateRecall(reconstructedAllNodeOnHeapGraphIndex, allBuildScoreProvider, queryVector, groundTruthAllVectors, TOP_K); From 093691ededc5cbe2858501c54a70766635305f7f Mon Sep 17 00:00:00 2001 From: Mariano Tepper Date: Wed, 8 Oct 2025 14:23:27 -0700 Subject: [PATCH 08/12] Label OnHeapGraphIndex.save and OnHeapGraphIndex.load as experimental. --- .../io/github/jbellis/jvector/graph/OnHeapGraphIndex.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index 976aa1bb7..9d95ce098 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -24,6 +24,7 @@ package io.github.jbellis.jvector.graph; +import io.github.jbellis.jvector.annotations.Experimental; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.ConcurrentNeighborMap.Neighbors; import io.github.jbellis.jvector.graph.diversity.DiversityProvider; @@ -496,7 +497,7 @@ public String toString() { /** * Saves the graph to the given DataOutput for reloading into memory later */ - @Deprecated + @Experimental public void save(DataOutput out) throws IOException { if (allMutationsCompleted()) { throw new IllegalStateException("Cannot save a graph with pending mutations. Call cleanup() first"); @@ -537,7 +538,7 @@ public void save(DataOutput out) throws IOException { /** * Saves the graph to the given DataOutput for reloading into memory later */ - @Deprecated + @Experimental public static OnHeapGraphIndex load(RandomAccessReader in, double overflowRatio, DiversityProvider diversityProvider) throws IOException { int magic = in.readInt(); // the magic number if (magic != OnHeapGraphIndex.MAGIC) { From 572066040b6ff825b09828dcb310b2b3271361d8 Mon Sep 17 00:00:00 2001 From: Mariano Tepper Date: Thu, 9 Oct 2025 08:04:45 -0700 Subject: [PATCH 09/12] Add experimental tag to GraphIndexBuilder.buildAndMergeNewNodes --- .../java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index 2158034a3..0d94d6af2 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -16,6 +16,7 @@ package io.github.jbellis.jvector.graph; +import io.github.jbellis.jvector.annotations.Experimental; import io.github.jbellis.jvector.annotations.VisibleForTesting; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.ImmutableGraphIndex.NodeAtLevel; @@ -983,6 +984,7 @@ private void loadV3(RandomAccessReader in, int size) throws IOException { * @return the in-memory representation of the graph index. * @throws IOException if an I/O error occurs during the graph loading or conversion process. */ + @Experimental public static ImmutableGraphIndex buildAndMergeNewNodes(RandomAccessReader in, RandomAccessVectorValues newVectors, BuildScoreProvider buildScoreProvider, From 1e3847f7bec1bd90264d2e63139a220b9b1925b4 Mon Sep 17 00:00:00 2001 From: Mariano Tepper Date: Thu, 9 Oct 2025 08:46:26 -0700 Subject: [PATCH 10/12] Bug fixes to make tests pass --- .../github/jbellis/jvector/graph/GraphIndexBuilder.java | 8 ++++++-- .../io/github/jbellis/jvector/graph/OnHeapGraphIndex.java | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index 0d94d6af2..a43a31ed7 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -893,15 +893,18 @@ private void loadV4(RandomAccessReader in) throws IOException { } int layerCount = in.readInt(); - int entryNode = in.readInt(); var layerDegrees = new ArrayList(layerCount); + for (int level = 0; level < layerCount; level++) { + layerDegrees.add(in.readInt()); + } + + int entryNode = in.readInt(); Map nodeLevelMap = new HashMap<>(); // Read layer info for (int level = 0; level < layerCount; level++) { int layerSize = in.readInt(); - layerDegrees.add(in.readInt()); for (int i = 0; i < layerSize; i++) { int nodeId = in.readInt(); int nNeighbors = in.readInt(); @@ -917,6 +920,7 @@ private void loadV4(RandomAccessReader in) throws IOException { var ca = new NodeArray(nNeighbors); for (int j = 0; j < nNeighbors; j++) { int neighbor = in.readInt(); + float score = in.readFloat(); ca.addInOrder(neighbor, sf.similarityTo(neighbor)); } graph.connectNode(level, nodeId, ca); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index 9d95ce098..6142adef3 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -499,7 +499,7 @@ public String toString() { */ @Experimental public void save(DataOutput out) throws IOException { - if (allMutationsCompleted()) { + if (!allMutationsCompleted()) { throw new IllegalStateException("Cannot save a graph with pending mutations. Call cleanup() first"); } From 819f1a9cccc4e9e67fcaff765a15ae33a4ac169f Mon Sep 17 00:00:00 2001 From: Mariano Tepper Date: Thu, 9 Oct 2025 10:13:27 -0700 Subject: [PATCH 11/12] Added deprecated tags to OnHeapGraphIndex.load and OnHeapGraphIndex.save --- .../java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index 6142adef3..29999bfde 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -498,6 +498,7 @@ public String toString() { * Saves the graph to the given DataOutput for reloading into memory later */ @Experimental + @Deprecated public void save(DataOutput out) throws IOException { if (!allMutationsCompleted()) { throw new IllegalStateException("Cannot save a graph with pending mutations. Call cleanup() first"); @@ -539,6 +540,7 @@ public void save(DataOutput out) throws IOException { * Saves the graph to the given DataOutput for reloading into memory later */ @Experimental + @Deprecated public static OnHeapGraphIndex load(RandomAccessReader in, double overflowRatio, DiversityProvider diversityProvider) throws IOException { int magic = in.readInt(); // the magic number if (magic != OnHeapGraphIndex.MAGIC) { From 89c5bf31de6d01cb77685e9715dcc9d1287c1f96 Mon Sep 17 00:00:00 2001 From: Mariano Tepper Date: Fri, 10 Oct 2025 10:58:49 -0700 Subject: [PATCH 12/12] Fix documentation of MutableGraphIndex.allMutationsCompleted --- .../io/github/jbellis/jvector/graph/MutableGraphIndex.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MutableGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MutableGraphIndex.java index 1e30fcd2a..36ec49a16 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MutableGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MutableGraphIndex.java @@ -169,8 +169,7 @@ interface MutableGraphIndex extends ImmutableGraphIndex { void setAllMutationsCompleted(); /** - * Signals that all mutations have been completed and the graph will not be mutated any further. - * Should be called by the builder after all mutations are completed (during cleanup). + * Returns true if all mutations have been completed. This is signaled by calling setAllMutationsCompleted. */ boolean allMutationsCompleted(); }