Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@
import io.github.jbellis.jvector.util.AbstractLongHeap;
import io.github.jbellis.jvector.util.BoundedLongHeap;
import io.github.jbellis.jvector.util.NumericUtils;
import java.util.PrimitiveIterator;
import org.agrona.collections.Int2ObjectHashMap;

import java.util.Arrays;

import static java.lang.Math.min;

/**
Expand Down Expand Up @@ -90,6 +89,16 @@ public boolean push(int newNode, float newScore) {
return heap.push(encode(newNode, newScore));
}

/**
 * Bulk-inserts node/score pairs into this queue. Each pair produced by the iterator is
 * encoded as a single long (the same encoding used by {@code push}) and the encoded
 * values are handed to the underlying heap's bulk insertion.
 *
 * @param nodeScoreIterator the node and score pairs to add
 * @param count the number of elements to add
 */
public void pushAll(NodeScoreIterator nodeScoreIterator, int count) {
    // Wrap the pairs so the heap sees a plain stream of encoded longs.
    PrimitiveIterator.OfLong encodedPairs = new NodeScoreIteratorConverter(nodeScoreIterator, this);
    heap.pushAll(encodedPairs, count);
}

/**
* Encodes the node ID and its similarity score as a long. If two scores are equal,
* the smaller node ID wins.
Expand Down Expand Up @@ -260,6 +269,18 @@ public interface NodeConsumer {
void accept(int node, float score);
}

/**
 * Iterator over node and score pairs.
 * <p>
 * A pair is consumed with two calls: {@link #nextNode()} to read the node id, followed by
 * {@link #nextScore()} to read the score. Only {@code nextScore()} moves the iterator on to
 * the following pair, so {@code nextNode()} must be called first.
 */
public interface NodeScoreIterator {
/** @return true if there are more elements */
boolean hasNext();

/** @return the next node id; does not advance the iterator */
int nextNode();

/** @return the next node score and advance the iterator */
float nextScore();
}

/**
* Copies the other NodeQueue to this one. If its order (MIN_HEAP or MAX_HEAP) is the same as this,
* it is copied verbatim. If it differs, every element is re-inserted into this.
Expand All @@ -274,4 +295,28 @@ public void copyFrom(NodeQueue other) {
other.foreach(this::push);
}
}

/**
 * Adapter that exposes a NodeScoreIterator as a PrimitiveIterator.OfLong. Every
 * (node, score) pair is turned into one long using the owning queue's encoding.
 */
private static class NodeScoreIteratorConverter implements PrimitiveIterator.OfLong {
    /** Supplier of the raw (node, score) pairs. */
    private final NodeScoreIterator source;
    /** Queue whose encoding is applied to each pair. */
    private final NodeQueue encoder;

    public NodeScoreIteratorConverter(NodeScoreIterator it, NodeQueue queue) {
        this.source = it;
        this.encoder = queue;
    }

    @Override
    public boolean hasNext() {
        return source.hasNext();
    }

    @Override
    public long nextLong() {
        // Read the node first: nextScore() is the call that advances the iterator.
        final int node = source.nextNode();
        final float score = source.nextScore();
        return encoder.encode(node, score);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
package io.github.jbellis.jvector.util;

import io.github.jbellis.jvector.annotations.VisibleForTesting;
import java.util.PrimitiveIterator;

/**
* A min heap that stores longs; a primitive priority queue that like all priority queues maintains
Expand Down Expand Up @@ -64,6 +65,14 @@ public AbstractLongHeap(int initialSize) {
*/
public abstract boolean push(long element);

/**
 * Adds all elements from the given iterator to this heap, in bulk.
 * <p>
 * Each concrete heap decides how the request is honored; a size-limited implementation
 * may reject it with an {@link IllegalArgumentException} when the elements would not fit.
 *
 * @param elements the elements to add
 * @param elementsSize the number of elements to add
 */
public abstract void pushAll(PrimitiveIterator.OfLong elements, int elementsSize);

protected long add(long element) {
size++;
if (size == heap.length) {
Expand All @@ -74,6 +83,38 @@ protected long add(long element) {
return heap[1];
}

/**
 * Bulk-adds all elements from the given iterator to this heap, then re-heapifies
 * in O(n) time (Floyd's build-heap). For a proof explaining the linear time
 * complexity, see <a href="https://stackoverflow.com/a/18742428">this stackoverflow answer</a>.
 * <p>
 * If the iterator is already exhausted on entry the heap is left untouched. Capacity is
 * reserved for {@code elementsSize} new elements only, so the iterator is expected to
 * yield no more than that many values.
 *
 * @param elements the elements to add
 * @param elementsSize the number of elements to add
 */
protected void addAll(PrimitiveIterator.OfLong elements, int elementsSize) {
    if (!elements.hasNext()) {
        return; // nothing to do
    }

    // 1) Ensure we have enough capacity.
    //    Elements live in heap[1..size] (slot 0 is unused), so holding newSize elements
    //    needs newSize + 1 slots. Requesting only newSize would leave the array one slot
    //    short whenever newSize == heap.length.
    int newSize = size + elementsSize;
    if (newSize >= heap.length) {
        heap = ArrayUtil.grow(heap, newSize + 1);
    }

    // 2) Copy the new elements directly into the array, ignoring heap order for now
    while (elements.hasNext()) {
        heap[++size] = elements.nextLong();
    }

    // 3) "Bottom-up" re-heapify:
    //    Start from the last non-leaf node (size >>> 1) down to the root (1).
    //    This is Floyd's build-heap algorithm.
    for (int i = size >>> 1; i >= 1; i--) {
        downHeap(i);
    }
}

/**
* Returns the least element of the LongHeap in constant time. It is up to the caller to verify
* that the heap is not empty; no checking is done, and if no elements have been added, 0 is
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
package io.github.jbellis.jvector.util;

import io.github.jbellis.jvector.annotations.VisibleForTesting;
import java.util.PrimitiveIterator;

/**
* An AbstractLongHeap with an adjustable maximum size.
Expand Down Expand Up @@ -67,6 +68,15 @@ public boolean push(long value) {
return true;
}

/**
 * Adds all elements from the given iterator to this heap, in bulk, provided they fit
 * within the configured maximum size. Filling the heap up to exactly {@code maxSize}
 * is allowed; only a request that would exceed it is rejected.
 *
 * @param elements the elements to add
 * @param elementsSize the number of elements to add
 * @throws IllegalArgumentException if the current size plus {@code elementsSize} is
 *         greater than {@code maxSize}; the heap is not modified in that case
 */
@Override
public void pushAll(PrimitiveIterator.OfLong elements, int elementsSize) {
    // Widen to long so a very large elementsSize cannot wrap around and slip past the check.
    if ((long) elementsSize + size > maxSize) {
        throw new IllegalArgumentException("Cannot add more elements than maxSize");
    }
    addAll(elements, elementsSize);
}

/**
* Replace the top of the heap with {@code newTop}, and enforce the heap invariant.
* Should be called when the top value changes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

package io.github.jbellis.jvector.util;

import java.util.PrimitiveIterator;

/**
* An AbstractLongHeap that can grow in size (unbounded, except for memory and array size limits).
*/
Expand All @@ -47,4 +49,10 @@ public boolean push(long element) {
add(element);
return true;
}

/**
 * Adds all elements from the given iterator to this heap, in bulk. No size limit is
 * checked here; the work is delegated to the shared bulk-add routine.
 *
 * @param elements the elements to add
 * @param elementsSize the number of elements to add
 */
@Override
public void pushAll(PrimitiveIterator.OfLong elements, int elementsSize) {
    this.addAll(elements, elementsSize);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,68 @@ public void testUnboundedQueue() {
assertEquals(maxNode, nn.topNode());
}

/** Bulk-adding into a MIN_HEAP queue must surface the lowest-scoring pair at the top. */
@Test
public void testPushAllMinHeap() {
    // Input pairs: the lowest score (-1.0) belongs to node 1
    final int[] nodeIds = { 5, 1, 3, 2, 8 };
    final float[] nodeScores = { 2.2f, -1.0f, 0.5f, 2.1f, -0.9f };

    // Growable backing heap that starts smaller than the input, MIN_HEAP ordering
    final NodeQueue minQueue = new NodeQueue(new GrowableLongHeap(2), NodeQueue.Order.MIN_HEAP);

    // Insert every pair with a single bulk call
    minQueue.pushAll(new TestNodeScoreIterator(nodeIds, nodeScores), nodeIds.length);

    // All five pairs must have been stored
    assertEquals(5, minQueue.size());

    // The root of a MIN_HEAP is the smallest score and its node
    assertEquals(-1.0f, minQueue.topScore(), 0.000001);
    assertEquals(1, minQueue.topNode());
}

/** Bulk-adding into a MAX_HEAP queue must surface the highest-scoring pair at the top. */
@Test
public void testPushAllMaxHeap() {
    // Input pairs: the highest score (3.0) belongs to node 50
    final int[] nodeIds = { 10, 20, 30, 40, 50 };
    final float[] nodeScores = { -2.5f, 1.0f, 0.0f, 1.5f, 3.0f };

    // Growable backing heap that starts smaller than the input, MAX_HEAP ordering
    final NodeQueue maxQueue = new NodeQueue(new GrowableLongHeap(2), NodeQueue.Order.MAX_HEAP);

    // Insert every pair with a single bulk call
    maxQueue.pushAll(new TestNodeScoreIterator(nodeIds, nodeScores), nodeIds.length);

    // All five pairs must have been stored
    assertEquals(5, maxQueue.size());

    // The root of a MAX_HEAP is the largest score and its node
    assertEquals(3.0f, maxQueue.topScore(), 0.000001);
    assertEquals(50, maxQueue.topNode());
}

/** A bounded queue must reject a bulk add that would push it past its maximum size. */
@Test
public void testPushAllBoundedHeapExceedsCapacity() {
    // Case 1: empty queue bounded at 2, three elements offered at once
    assertThrows(IllegalArgumentException.class, () -> {
        NodeQueue emptyQueue = new NodeQueue(new BoundedLongHeap(2), NodeQueue.Order.MAX_HEAP);
        TestNodeScoreIterator threePairs =
                new TestNodeScoreIterator(new int[] { 1, 2, 3 }, new float[] { 1, 2, 3 });
        emptyQueue.pushAll(threePairs, 3);
    });

    // Case 2: queue bounded at 2 that already holds one element, two more offered
    NodeQueue partlyFilled = new NodeQueue(new BoundedLongHeap(2), NodeQueue.Order.MAX_HEAP);
    partlyFilled.push(1, 1);
    assertThrows(IllegalArgumentException.class, () -> {
        TestNodeScoreIterator twoPairs =
                new TestNodeScoreIterator(new int[] { 1, 2 }, new float[] { 1, 2 });
        partlyFilled.pushAll(twoPairs, 2);
    });
}

@Test
public void testInvalidArguments() {
assertThrows(IllegalArgumentException.class, () -> new NodeQueue(new GrowableLongHeap(0), NodeQueue.Order.MIN_HEAP));
Expand All @@ -127,4 +189,36 @@ public void testInvalidArguments() {
public void testToString() {
    // A freshly created queue holds no nodes and must render as "Nodes[0]"
    NodeQueue emptyQueue = new NodeQueue(new GrowableLongHeap(2), NodeQueue.Order.MIN_HEAP);
    assertEquals("Nodes[0]", emptyQueue.toString());
}

/**
 * Array-backed NodeScoreIterator used to exercise pushAll. It walks two parallel arrays
 * of node ids and scores; reading the score is what moves on to the next pair.
 */
static class TestNodeScoreIterator implements NodeQueue.NodeScoreIterator {
    private final int[] nodeIds;
    private final float[] nodeScores;
    /** Position of the pair that will be returned next. */
    private int cursor = 0;

    TestNodeScoreIterator(int[] nodes, float[] scores) {
        assert nodes.length == scores.length;
        this.nodeIds = nodes;
        this.nodeScores = scores;
    }

    @Override
    public boolean hasNext() {
        return cursor < nodeIds.length;
    }

    @Override
    public int nextNode() {
        // Peek only: the cursor stays on the current pair
        return nodeIds[cursor];
    }

    @Override
    public float nextScore() {
        // Return the current pair's score, then step to the next pair
        final float score = nodeScores[cursor];
        cursor++;
        return score;
    }
}

}