@@ -9,11 +9,33 @@

package org.elasticsearch.index.codec.vectors.diskbbq;

public record CentroidAssignments(int numCentroids, float[][] centroids, int[] assignments, int[] overspillAssignments) {
public record CentroidAssignments(
int numCentroids,
float[][] centroids,
int[] assignments,
int[] overspillAssignments,
float[] globalCentroid
) {

public CentroidAssignments(float[][] centroids, int[] assignments, int[] overspillAssignments) {
this(centroids.length, centroids, assignments, overspillAssignments);
public CentroidAssignments(int dims, float[][] centroids, int[] assignments, int[] overspillAssignments) {
this(centroids.length, centroids, assignments, overspillAssignments, computeGlobalCentroid(dims, centroids));
assert assignments.length == overspillAssignments.length || overspillAssignments.length == 0
: "assignments and overspillAssignments must have the same length";

}

private static float[] computeGlobalCentroid(int dims, float[][] centroids) {
final float[] globalCentroid = new float[dims];
// TODO: push this logic into vector util?
for (float[] centroid : centroids) {
assert centroid.length == dims;
for (int j = 0; j < centroid.length; j++) {
globalCentroid[j] += centroid[j];
}
}
for (int j = 0; j < globalCentroid.length; j++) {
globalCentroid[j] /= centroids.length;
}
return globalCentroid;
}
}
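
For reference, the following is a standalone sketch (not part of the PR; the class name and sample values are made up) of the component-wise mean that the new computeGlobalCentroid helper takes over the cluster centroids, so the record can carry the global centroid instead of the caller filling an out-parameter:

import java.util.Arrays;

// Minimal sketch of the mean computed by CentroidAssignments#computeGlobalCentroid.
public class GlobalCentroidSketch {

    static float[] globalCentroid(int dims, float[][] centroids) {
        float[] global = new float[dims];
        for (float[] centroid : centroids) {
            // every centroid is expected to have `dims` components
            for (int j = 0; j < dims; j++) {
                global[j] += centroid[j];
            }
        }
        for (int j = 0; j < dims; j++) {
            global[j] /= centroids.length;
        }
        return global;
    }

    public static void main(String[] args) {
        float[][] centroids = { { 1f, 2f }, { 3f, 6f } };
        // prints [2.0, 4.0], the per-dimension mean of the two centroids
        System.out.println(Arrays.toString(globalCentroid(2, centroids)));
    }
}
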
@@ -523,47 +523,34 @@ public int size() {
return new CentroidGroups(kMeansResult.centroids(), vectorsPerCentroid, maxVectorsPerCentroidLength);
}

@Override
public CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, MergeState mergeState)
throws IOException {
return calculateCentroids(fieldInfo, floatVectorValues);
}

/**
* Calculate the centroids for the given field.
* We use the {@link HierarchicalKMeans} algorithm to partition the space of all vectors across merging segments
*
* @param fieldInfo merging field info
* @param floatVectorValues the float vector values to merge
* @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids
* @return the vector assignments, soar assignments, and if asked the centroids themselves that were computed
* @throws IOException if an I/O error occurs
*/
@Override
public CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid)
throws IOException {

public CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues) throws IOException {
// TODO: consider hinting / bootstrapping hierarchical kmeans with the prior segments centroids
CentroidAssignments centroidAssignments = buildCentroidAssignments(floatVectorValues, vectorPerCluster);
float[][] centroids = centroidAssignments.centroids();
// TODO: for flush we are doing this over the vectors and here centroids which seems duplicative
// preliminary tests suggest recall is good using only centroids but need to do further evaluation
// TODO: push this logic into vector util?
for (float[] centroid : centroids) {
for (int j = 0; j < centroid.length; j++) {
globalCentroid[j] += centroid[j];
}
}
for (int j = 0; j < globalCentroid.length; j++) {
globalCentroid[j] /= centroids.length;
}

KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
float[][] centroids = kMeansResult.centroids();
if (logger.isDebugEnabled()) {
logger.debug("final centroid count: {}", centroids.length);
}
return centroidAssignments;
}

static CentroidAssignments buildCentroidAssignments(FloatVectorValues floatVectorValues, int vectorPerCluster) throws IOException {
KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
float[][] centroids = kMeansResult.centroids();
int[] assignments = kMeansResult.assignments();
int[] soarAssignments = kMeansResult.soarAssignments();
return new CentroidAssignments(centroids, assignments, soarAssignments);
return new CentroidAssignments(fieldInfo.getVectorDimension(), centroids, assignments, soarAssignments);
}

static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
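
The assignments and overspillAssignments arrays returned above are parallel arrays indexed by vector ordinal: assignments[i] is the primary centroid of vector i, and overspillAssignments is either empty or the same length, holding the secondary (SOAR) assignment, as the assert in the new record constructor enforces. A small standalone sketch of reading them back (names and sample values are illustrative only, not taken from the PR):

// Illustrative only: interprets the parallel arrays carried by CentroidAssignments.
public class AssignmentSketch {

    public static void main(String[] args) {
        int[] assignments = { 0, 0, 1 };             // primary centroid per vector ordinal
        int[] overspillAssignments = { 1, 1, 0 };    // secondary (SOAR) centroid, or an empty array when unused

        for (int vector = 0; vector < assignments.length; vector++) {
            String overspill = overspillAssignments.length == 0 ? "none" : Integer.toString(overspillAssignments[vector]);
            System.out.println("vector " + vector + " -> centroid " + assignments[vector] + ", overspill " + overspill);
        }
    }
}
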
@@ -141,7 +141,9 @@ public final KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOExc
return rawVectorDelegate;
}

public abstract CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid)
public abstract CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues) throws IOException;

public abstract CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, MergeState mergeState)
throws IOException;

public record CentroidOffsetAndLength(LongValues offsets, LongValues lengths) {}
@@ -191,11 +193,10 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
writeMeta(fieldWriter.fieldInfo, 0, 0, 0, 0, 0, null);
continue;
}
final float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()];
// build a float vector values with random access
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc);
// build centroids
final CentroidAssignments centroidAssignments = calculateCentroids(fieldWriter.fieldInfo, floatVectorValues, globalCentroid);
final CentroidAssignments centroidAssignments = calculateCentroids(fieldWriter.fieldInfo, floatVectorValues);
// wrap centroids with a supplier
final CentroidSupplier centroidSupplier = CentroidSupplier.fromArray(centroidAssignments.centroids());
// write posting lists
@@ -211,6 +212,7 @@
);
final long postingListLength = ivfClusters.getFilePointer() - postingListOffset;
// write centroids
final float[] globalCentroid = centroidAssignments.globalCentroid();
final long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
writeCentroids(fieldWriter.fieldInfo, centroidSupplier, globalCentroid, centroidOffsetAndLength, ivfCentroids);
final long centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
@@ -377,7 +379,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
final int numCentroids;
final int[] assignments;
final int[] overspillAssignments;
final float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()];
final float[] calculatedGlobalCentroid;
String centroidTempName = null;
IndexOutput centroidTemp = null;
success = false;
@@ -387,7 +389,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
CentroidAssignments centroidAssignments = calculateCentroids(
fieldInfo,
getFloatVectorValues(fieldInfo, docs, vectors, numVectors),
calculatedGlobalCentroid
mergeState
);
// write the centroids to a temporary file so we are not holding them on heap
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
@@ -397,6 +399,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
}
numCentroids = centroidAssignments.numCentroids();
assignments = centroidAssignments.assignments();
calculatedGlobalCentroid = centroidAssignments.globalCentroid();
overspillAssignments = centroidAssignments.overspillAssignments();
success = true;
} finally {
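
The merge path above stages each centroid in a temporary file as little-endian floats (the ByteBuffer allocation shown just above) and now reads the global centroid back from the returned CentroidAssignments rather than filling a pre-allocated array. Below is a standalone sketch of that float packing and its round trip; the actual temp-file writes are elided from the diff, so this only mirrors the buffer setup:

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

// Sketch of the little-endian float packing used when staging centroids off-heap.
public class CentroidPackingSketch {

    static byte[] pack(float[] centroid) {
        ByteBuffer buffer = ByteBuffer.allocate(centroid.length * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
        buffer.asFloatBuffer().put(centroid);
        return buffer.array();
    }

    public static void main(String[] args) {
        float[] centroid = { 1.5f, -2.0f, 0.25f };
        byte[] bytes = pack(centroid);
        System.out.println(bytes.length + " bytes");

        // round-trip to verify the encoding
        float[] decoded = new float[centroid.length];
        ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().get(decoded);
        System.out.println(Arrays.toString(decoded));
    }
}
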
@@ -511,47 +511,34 @@ public int size() {
return new CentroidGroups(kMeansResult.centroids(), vectorsPerCentroid, maxVectorsPerCentroidLength);
}

@Override
public CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, MergeState mergeState)
throws IOException {
return calculateCentroids(fieldInfo, floatVectorValues);
}

/**
* Calculate the centroids for the given field.
* We use the {@link HierarchicalKMeans} algorithm to partition the space of all vectors across merging segments
*
* @param fieldInfo merging field info
* @param floatVectorValues the float vector values to merge
* @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids
* @return the vector assignments, soar assignments, and if asked the centroids themselves that were computed
* @throws IOException if an I/O error occurs
*/
@Override
public CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid)
throws IOException {

public CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues) throws IOException {
// TODO: consider hinting / bootstrapping hierarchical kmeans with the prior segments centroids
CentroidAssignments centroidAssignments = buildCentroidAssignments(floatVectorValues, vectorPerCluster);
float[][] centroids = centroidAssignments.centroids();
// TODO: for flush we are doing this over the vectors and here centroids which seems duplicative
// preliminary tests suggest recall is good using only centroids but need to do further evaluation
// TODO: push this logic into vector util?
for (float[] centroid : centroids) {
for (int j = 0; j < centroid.length; j++) {
globalCentroid[j] += centroid[j];
}
}
for (int j = 0; j < globalCentroid.length; j++) {
globalCentroid[j] /= centroids.length;
}

KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
float[][] centroids = kMeansResult.centroids();
if (logger.isDebugEnabled()) {
logger.debug("final centroid count: {}", centroids.length);
}
return centroidAssignments;
}

static CentroidAssignments buildCentroidAssignments(FloatVectorValues floatVectorValues, int vectorPerCluster) throws IOException {
KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
float[][] centroids = kMeansResult.centroids();
int[] assignments = kMeansResult.assignments();
int[] soarAssignments = kMeansResult.soarAssignments();
return new CentroidAssignments(centroids, assignments, soarAssignments);
return new CentroidAssignments(fieldInfo.getVectorDimension(), centroids, assignments, soarAssignments);
}

static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)