diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/CentroidAssignments.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/CentroidAssignments.java
index 33ea8085191c3..32044e30a9aab 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/CentroidAssignments.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/CentroidAssignments.java
@@ -9,11 +9,33 @@
 package org.elasticsearch.index.codec.vectors.diskbbq;
 
-public record CentroidAssignments(int numCentroids, float[][] centroids, int[] assignments, int[] overspillAssignments) {
+public record CentroidAssignments(
+    int numCentroids,
+    float[][] centroids,
+    int[] assignments,
+    int[] overspillAssignments,
+    float[] globalCentroid
+) {
 
-    public CentroidAssignments(float[][] centroids, int[] assignments, int[] overspillAssignments) {
-        this(centroids.length, centroids, assignments, overspillAssignments);
+    public CentroidAssignments(int dims, float[][] centroids, int[] assignments, int[] overspillAssignments) {
+        this(centroids.length, centroids, assignments, overspillAssignments, computeGlobalCentroid(dims, centroids));
         assert assignments.length == overspillAssignments.length || overspillAssignments.length == 0
             : "assignments and overspillAssignments must have the same length";
+
+    }
+
+    private static float[] computeGlobalCentroid(int dims, float[][] centroids) {
+        final float[] globalCentroid = new float[dims];
+        // TODO: push this logic into vector util?
+        for (float[] centroid : centroids) {
+            assert centroid.length == dims;
+            for (int j = 0; j < centroid.length; j++) {
+                globalCentroid[j] += centroid[j];
+            }
+        }
+        for (int j = 0; j < globalCentroid.length; j++) {
+            globalCentroid[j] /= centroids.length;
+        }
+        return globalCentroid;
     }
 }
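Note: the global centroid the record now derives is the unweighted, component-wise mean of the cluster centroids (each centroid contributes once, regardless of how many vectors were assigned to it), which generally differs from the mean of the raw vectors unless clusters are equally sized. A minimal standalone sketch of that arithmetic, using a hypothetical class name and toy data rather than the actual record:

// GlobalCentroidSketch.java -- hypothetical standalone class, not part of this diff.
// It mirrors the arithmetic of CentroidAssignments#computeGlobalCentroid: sum the
// centroids component-wise, then divide by the number of centroids.
public class GlobalCentroidSketch {

    static float[] meanOfCentroids(int dims, float[][] centroids) {
        final float[] global = new float[dims];
        for (float[] centroid : centroids) {
            assert centroid.length == dims;
            for (int j = 0; j < dims; j++) {
                global[j] += centroid[j];
            }
        }
        for (int j = 0; j < dims; j++) {
            global[j] /= centroids.length;
        }
        return global;
    }

    public static void main(String[] args) {
        float[][] centroids = { { 1f, 2f }, { 3f, 6f }, { 5f, 1f } };
        // Prints [3.0, 3.0]: the component-wise average of the three centroids.
        System.out.println(java.util.Arrays.toString(meanOfCentroids(2, centroids)));
    }
}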
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsWriter.java
index 68d6acc248a07..125144138f952 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsWriter.java
@@ -523,47 +523,34 @@ public int size() {
         return new CentroidGroups(kMeansResult.centroids(), vectorsPerCentroid, maxVectorsPerCentroidLength);
     }
 
+    @Override
+    public CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, MergeState mergeState)
+        throws IOException {
+        return calculateCentroids(fieldInfo, floatVectorValues);
+    }
+
     /**
      * Calculate the centroids for the given field.
      * We use the {@link HierarchicalKMeans} algorithm to partition the space of all vectors across merging segments
      *
      * @param fieldInfo merging field info
      * @param floatVectorValues the float vector values to merge
-     * @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids
      * @return the vector assignments, soar assignments, and if asked the centroids themselves that were computed
      * @throws IOException if an I/O error occurs
      */
     @Override
-    public CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid)
-        throws IOException {
-
+    public CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues) throws IOException {
         // TODO: consider hinting / bootstrapping hierarchical kmeans with the prior segments centroids
-        CentroidAssignments centroidAssignments = buildCentroidAssignments(floatVectorValues, vectorPerCluster);
-        float[][] centroids = centroidAssignments.centroids();
         // TODO: for flush we are doing this over the vectors and here centroids which seems duplicative
         // preliminary tests suggest recall is good using only centroids but need to do further evaluation
-        // TODO: push this logic into vector util?
-        for (float[] centroid : centroids) {
-            for (int j = 0; j < centroid.length; j++) {
-                globalCentroid[j] += centroid[j];
-            }
-        }
-        for (int j = 0; j < globalCentroid.length; j++) {
-            globalCentroid[j] /= centroids.length;
-        }
-
+        KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
+        float[][] centroids = kMeansResult.centroids();
         if (logger.isDebugEnabled()) {
             logger.debug("final centroid count: {}", centroids.length);
         }
-        return centroidAssignments;
-    }
-
-    static CentroidAssignments buildCentroidAssignments(FloatVectorValues floatVectorValues, int vectorPerCluster) throws IOException {
-        KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
-        float[][] centroids = kMeansResult.centroids();
         int[] assignments = kMeansResult.assignments();
         int[] soarAssignments = kMeansResult.soarAssignments();
-        return new CentroidAssignments(centroids, assignments, soarAssignments);
+        return new CentroidAssignments(fieldInfo.getVectorDimension(), centroids, assignments, soarAssignments);
     }
 
     static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsWriter.java
index 87b01f3800068..ee4aaca10227b 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsWriter.java
@@ -141,7 +141,9 @@ public final KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOExc
         return rawVectorDelegate;
     }
 
-    public abstract CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid)
+    public abstract CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues) throws IOException;
+
+    public abstract CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, MergeState mergeState)
         throws IOException;
 
     public record CentroidOffsetAndLength(LongValues offsets, LongValues lengths) {}
 
@@ -191,11 +193,10 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
                 writeMeta(fieldWriter.fieldInfo, 0, 0, 0, 0, 0, null);
                 continue;
             }
-            final float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()];
             // build a float vector values with random access
            final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc);
             // build centroids
-            final CentroidAssignments centroidAssignments = calculateCentroids(fieldWriter.fieldInfo, floatVectorValues, globalCentroid);
+            final CentroidAssignments centroidAssignments = calculateCentroids(fieldWriter.fieldInfo, floatVectorValues);
             // wrap centroids with a supplier
             final CentroidSupplier centroidSupplier = CentroidSupplier.fromArray(centroidAssignments.centroids());
             // write posting lists
@@ -211,6 +212,7 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
             );
             final long postingListLength = ivfClusters.getFilePointer() - postingListOffset;
             // write centroids
+            final float[] globalCentroid = centroidAssignments.globalCentroid();
             final long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
             writeCentroids(fieldWriter.fieldInfo, centroidSupplier, globalCentroid, centroidOffsetAndLength, ivfCentroids);
             final long centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
@@ -377,7 +379,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
         final int numCentroids;
         final int[] assignments;
         final int[] overspillAssignments;
-        final float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()];
+        final float[] calculatedGlobalCentroid;
         String centroidTempName = null;
         IndexOutput centroidTemp = null;
         success = false;
@@ -387,7 +389,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
             CentroidAssignments centroidAssignments = calculateCentroids(
                 fieldInfo,
                 getFloatVectorValues(fieldInfo, docs, vectors, numVectors),
-                calculatedGlobalCentroid
+                mergeState
             );
             // write the centroids to a temporary file so we are not holding them on heap
             final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
@@ -397,6 +399,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
             }
             numCentroids = centroidAssignments.numCentroids();
             assignments = centroidAssignments.assignments();
+            calculatedGlobalCentroid = centroidAssignments.globalCentroid();
             overspillAssignments = centroidAssignments.overspillAssignments();
             success = true;
         } finally {
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsWriter.java
index 5b19f62f00167..70f140294b380 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsWriter.java
@@ -511,47 +511,34 @@ public int size() {
         return new CentroidGroups(kMeansResult.centroids(), vectorsPerCentroid, maxVectorsPerCentroidLength);
     }
 
+    @Override
+    public CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, MergeState mergeState)
+        throws IOException {
+        return calculateCentroids(fieldInfo, floatVectorValues);
+    }
+
     /**
      * Calculate the centroids for the given field.
      * We use the {@link HierarchicalKMeans} algorithm to partition the space of all vectors across merging segments
      *
     * @param fieldInfo merging field info
      * @param floatVectorValues the float vector values to merge
-     * @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids
     * @return the vector assignments, soar assignments, and if asked the centroids themselves that were computed
     * @throws IOException if an I/O error occurs
     */
     @Override
-    public CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid)
-        throws IOException {
-
+    public CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues) throws IOException {
         // TODO: consider hinting / bootstrapping hierarchical kmeans with the prior segments centroids
-        CentroidAssignments centroidAssignments = buildCentroidAssignments(floatVectorValues, vectorPerCluster);
-        float[][] centroids = centroidAssignments.centroids();
         // TODO: for flush we are doing this over the vectors and here centroids which seems duplicative
         // preliminary tests suggest recall is good using only centroids but need to do further evaluation
-        // TODO: push this logic into vector util?
-        for (float[] centroid : centroids) {
-            for (int j = 0; j < centroid.length; j++) {
-                globalCentroid[j] += centroid[j];
-            }
-        }
-        for (int j = 0; j < globalCentroid.length; j++) {
-            globalCentroid[j] /= centroids.length;
-        }
-
+        KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
+        float[][] centroids = kMeansResult.centroids();
         if (logger.isDebugEnabled()) {
             logger.debug("final centroid count: {}", centroids.length);
         }
-        return centroidAssignments;
-    }
-
-    static CentroidAssignments buildCentroidAssignments(FloatVectorValues floatVectorValues, int vectorPerCluster) throws IOException {
-        KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster);
-        float[][] centroids = kMeansResult.centroids();
         int[] assignments = kMeansResult.assignments();
         int[] soarAssignments = kMeansResult.soarAssignments();
-        return new CentroidAssignments(centroids, assignments, soarAssignments);
+        return new CentroidAssignments(fieldInfo.getVectorDimension(), centroids, assignments, soarAssignments);
     }
 
     static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
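Since both writers previously accumulated the global centroid in place and now read it from the record, the refactor is easy to sanity-check. A hedged, hypothetical test sketch (not part of this diff; assumes a plain JUnit 4 test on the server test classpath) comparing the old in-place accumulation against the value the new constructor derives:

// Hypothetical test class, illustration only.
import static org.junit.Assert.assertArrayEquals;

import org.elasticsearch.index.codec.vectors.diskbbq.CentroidAssignments;
import org.junit.Test;

public class CentroidAssignmentsGlobalCentroidTests {

    @Test
    public void globalCentroidMatchesInPlaceAccumulation() {
        float[][] centroids = { { 1f, 2f }, { 3f, 6f } };

        // Pre-change convention: the caller allocated the array and the writer accumulated into it.
        float[] expected = new float[2];
        for (float[] centroid : centroids) {
            for (int j = 0; j < centroid.length; j++) {
                expected[j] += centroid[j];
            }
        }
        for (int j = 0; j < expected.length; j++) {
            expected[j] /= centroids.length;
        }

        // Post-change convention: the record computes the same value at construction time.
        CentroidAssignments centroidAssignments = new CentroidAssignments(2, centroids, new int[0], new int[0]);
        assertArrayEquals(expected, centroidAssignments.globalCentroid(), 0f);
    }
}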