From 7d9e975bce7be0ae434e8bd3c6bc07253e55eb88 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Fri, 26 Sep 2025 15:06:50 +0300 Subject: [PATCH 01/21] iter --- .../xpack/esql/core/type/DataType.java | 2 +- .../vector/VectorSimilarityFunctionsIT.java | 27 ++- .../vector/VectorSimilarityFunction.java | 160 ++++++++++++++---- 3 files changed, 140 insertions(+), 49 deletions(-) diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java index b3b36ac0fa160..2bf9021b65ca3 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java @@ -309,7 +309,7 @@ public enum DataType { AGGREGATE_METRIC_DOUBLE(builder().esType("aggregate_metric_double").estimatedSize(Double.BYTES * 3 + Integer.BYTES)), /** - * Fields with this type are dense vectors, represented as an array of double values. + * Fields with this type are dense vectors, represented as an array of float values. */ DENSE_VECTOR(builder().esType("dense_vector").estimatedSize(4096)); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java index e886e2baf06c4..70ebbb1633c40 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java @@ -24,9 +24,6 @@ import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase; import org.elasticsearch.xpack.esql.action.EsqlCapabilities; -import org.elasticsearch.xpack.esql.expression.function.vector.Hamming; -import org.elasticsearch.xpack.esql.expression.function.vector.L1Norm; -import org.elasticsearch.xpack.esql.expression.function.vector.L2Norm; import org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction.SimilarityEvaluatorFunction; import org.junit.Before; @@ -48,18 +45,18 @@ public static Iterable parameters() throws Exception { if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { params.add(new Object[] { "v_cosine", (SimilarityEvaluatorFunction) VectorSimilarityFunction.COSINE::compare }); } - if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - params.add(new Object[] { "v_dot_product", (SimilarityEvaluatorFunction) VectorSimilarityFunction.DOT_PRODUCT::compare }); - } - if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - params.add(new Object[] { "v_l1_norm", (SimilarityEvaluatorFunction) L1Norm::calculateSimilarity }); - } - if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - params.add(new Object[] { "v_l2_norm", (SimilarityEvaluatorFunction) L2Norm::calculateSimilarity }); - } - if (EsqlCapabilities.Cap.HAMMING_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - params.add(new Object[] { "v_hamming", (SimilarityEvaluatorFunction) Hamming::calculateSimilarity }); - } + // if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + // params.add(new Object[] { "v_dot_product", (SimilarityEvaluatorFunction) VectorSimilarityFunction.DOT_PRODUCT::compare }); + // } + // if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + // params.add(new Object[] { "v_l1_norm", (SimilarityEvaluatorFunction) L1Norm::calculateSimilarity }); + // } + // if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + // params.add(new Object[] { "v_l2_norm", (SimilarityEvaluatorFunction) L2Norm::calculateSimilarity }); + // } + // if (EsqlCapabilities.Cap.HAMMING_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + // params.add(new Object[] { "v_hamming", (SimilarityEvaluatorFunction) Hamming::calculateSimilarity }); + // } return params; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java index 1e9727398d458..c2cf72eb01dc0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java @@ -16,10 +16,12 @@ import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.xpack.esql.EsqlClientException; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -27,6 +29,9 @@ import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.function.Function; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; @@ -80,8 +85,9 @@ public Object fold(FoldContext ctx) { @Override public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) { return new SimilarityEvaluatorFactory( - toEvaluator.apply(left()), - toEvaluator.apply(right()), + left(), + right(), + toEvaluator::apply, getSimilarityFunction(), getClass().getSimpleName() + "Evaluator" ); @@ -92,19 +98,33 @@ public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMappe */ protected abstract SimilarityEvaluatorFunction getSimilarityFunction(); + @SuppressWarnings("unchecked") private record SimilarityEvaluatorFactory( - EvalOperator.ExpressionEvaluator.Factory left, - EvalOperator.ExpressionEvaluator.Factory right, + Expression leftExpression, + Expression rightExpression, + Function toEvaluator, SimilarityEvaluatorFunction similarityFunction, String evaluatorName ) implements EvalOperator.ExpressionEvaluator.Factory { @Override public EvalOperator.ExpressionEvaluator get(DriverContext context) { + VectorValueProvider left; + VectorValueProvider right; + if (leftExpression instanceof Literal && leftExpression.dataType() == DENSE_VECTOR) { + left = new VectorValueProvider((ArrayList) ((Literal) leftExpression).value(), null); + } else { + left = new VectorValueProvider(null, toEvaluator.apply(leftExpression).get(context)); + } + if (rightExpression instanceof Literal && rightExpression.dataType() == DENSE_VECTOR) { + right = new VectorValueProvider((ArrayList) ((Literal) rightExpression).value(), null); + } else { + right = new VectorValueProvider(null, toEvaluator.apply(rightExpression).get(context)); + } // TODO check whether to use this custom evaluator or reuse / define an existing one return new SimilarityEvaluator( - left.get(context), - right.get(context), + left, + right, similarityFunction, evaluatorName, context.blockFactory() @@ -113,13 +133,96 @@ public EvalOperator.ExpressionEvaluator get(DriverContext context) { @Override public String toString() { - return evaluatorName() + "[left=" + left + ", right=" + right + "]"; + return evaluatorName() + "[left=" + "left" + ", right=" + "right" + "]"; + } + } + + private static class VectorValueProvider implements Releasable { + + private final float[] constantVector; + private final EvalOperator.ExpressionEvaluator expressionEvaluator; + private FloatBlock block; + float[] scratch; + + VectorValueProvider(ArrayList constantVector, EvalOperator.ExpressionEvaluator expressionEvaluator) { + if(constantVector != null) { + this.constantVector = new float[constantVector.size()]; + for (int i = 0; i < constantVector.size(); i++) { + this.constantVector[i] = constantVector.get(i); + } + }else { + this.constantVector = null; + } + this.expressionEvaluator = expressionEvaluator; + } + + private void eval(Page page) { + if(expressionEvaluator != null) { + block = (FloatBlock) expressionEvaluator.eval(page); + } + } + + private float[] getVector(int position) { + if (constantVector != null) { + return constantVector; + } else if (block != null) { + if (block.isNull(position)) { + return null; + } + if (scratch == null) { + int dims = block.getValueCount(position); + if (dims > 0) { + scratch = new float[dims]; + } + } + if (scratch != null) { + readFloatArray(block, block.getFirstValueIndex(position), scratch.length, scratch); + } + return scratch; + } else { + throw new EsqlClientException("VectorValueProvider not initialized"); + } + } + + private static void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) { + for (int i = 0; i < dimensions; i++) { + scratch[i] = block.getFloat(position + i); + } + } + + public int getDimensions(){ + if (constantVector != null) { + return constantVector.length; + } else if (block != null) { + for (int p = 0; p < block.getPositionCount(); p++) { + int dims = block.getValueCount(p); + if (dims > 0) { + return dims; + } + } + return 0; + } else { + throw new EsqlClientException("VectorValueProvider not initialized"); + } + } + + public long baseRamBytesUsed() { + return (constantVector == null ? 0 : RamUsageEstimator.shallowSizeOf(constantVector)) + (expressionEvaluator == null ? 0 : expressionEvaluator.baseRamBytesUsed()); + } + + public void close() { + Releasables.close(block, expressionEvaluator); + } + + @Override + public String toString() { + return "constant_vector=" + Arrays.toString(constantVector) + ", expressionEvaluator=[" + expressionEvaluator + "]"; } } private record SimilarityEvaluator( - EvalOperator.ExpressionEvaluator left, - EvalOperator.ExpressionEvaluator right, + VectorValueProvider left, + VectorValueProvider right, SimilarityEvaluatorFunction similarityFunction, String evaluatorName, BlockFactory blockFactory @@ -129,26 +232,23 @@ private record SimilarityEvaluator( @Override public Block eval(Page page) { - try (FloatBlock leftBlock = (FloatBlock) left.eval(page); FloatBlock rightBlock = (FloatBlock) right.eval(page)) { + try { + left.eval(page); + right.eval(page); + + int dimensions = left.getDimensions(); int positionCount = page.getPositionCount(); - int dimensions = 0; - // Get the first non-empty vector to calculate the dimension - for (int p = 0; p < positionCount; p++) { - if (leftBlock.getValueCount(p) != 0) { - dimensions = leftBlock.getValueCount(p); - break; - } - } if (dimensions == 0) { return blockFactory.newConstantNullBlock(positionCount); } - float[] leftScratch = new float[dimensions]; - float[] rightScratch = new float[dimensions]; - try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(positionCount * dimensions)) { + try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(positionCount)) { for (int p = 0; p < positionCount; p++) { - int dimsLeft = leftBlock.getValueCount(p); - int dimsRight = rightBlock.getValueCount(p); + float[] leftVector = left.getVector(p); + float[] rightvector = right.getVector(p); + + int dimsLeft = leftVector == null ? 0 : leftVector.length; + int dimsRight = rightvector == null ? 0 : rightvector.length; if (dimsLeft == 0 || dimsRight == 0) { // A null value on the left or right vector. Similarity is null @@ -161,19 +261,13 @@ public Block eval(Page page) { dimsRight ); } - readFloatArray(leftBlock, leftBlock.getFirstValueIndex(p), dimensions, leftScratch); - readFloatArray(rightBlock, rightBlock.getFirstValueIndex(p), dimensions, rightScratch); - float result = similarityFunction.calculateSimilarity(leftScratch, rightScratch); + float result = similarityFunction.calculateSimilarity(leftVector, rightvector); builder.appendDouble(result); } return builder.build(); } - } - } - - private static void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) { - for (int i = 0; i < dimensions; i++) { - scratch[i] = block.getFloat(position + i); + } finally { + Releasables.close(left, right); } } @@ -189,7 +283,7 @@ public String toString() { @Override public void close() { - Releasables.close(left, right); +// Releasables.close(left, right); } } } From abe7f01a57c4360ecee9407719a267591f5f61ab Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Fri, 26 Sep 2025 16:52:05 +0300 Subject: [PATCH 02/21] iter --- .../vector/VectorSimilarityFunction.java | 36 ++++++++++--------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java index c2cf72eb01dc0..1eb5629af6539 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java @@ -122,13 +122,7 @@ public EvalOperator.ExpressionEvaluator get(DriverContext context) { right = new VectorValueProvider(null, toEvaluator.apply(rightExpression).get(context)); } // TODO check whether to use this custom evaluator or reuse / define an existing one - return new SimilarityEvaluator( - left, - right, - similarityFunction, - evaluatorName, - context.blockFactory() - ); + return new SimilarityEvaluator(left, right, similarityFunction, evaluatorName, context.blockFactory()); } @Override @@ -145,20 +139,20 @@ private static class VectorValueProvider implements Releasable { float[] scratch; VectorValueProvider(ArrayList constantVector, EvalOperator.ExpressionEvaluator expressionEvaluator) { - if(constantVector != null) { + if (constantVector != null) { this.constantVector = new float[constantVector.size()]; for (int i = 0; i < constantVector.size(); i++) { this.constantVector[i] = constantVector.get(i); } - }else { + } else { this.constantVector = null; } this.expressionEvaluator = expressionEvaluator; } private void eval(Page page) { - if(expressionEvaluator != null) { - block = (FloatBlock) expressionEvaluator.eval(page); + if (expressionEvaluator != null) { + block = (FloatBlock) expressionEvaluator.eval(page); } } @@ -190,7 +184,7 @@ private static void readFloatArray(FloatBlock block, int position, int dimension } } - public int getDimensions(){ + public int getDimensions() { if (constantVector != null) { return constantVector.length; } else if (block != null) { @@ -207,11 +201,20 @@ public int getDimensions(){ } public long baseRamBytesUsed() { - return (constantVector == null ? 0 : RamUsageEstimator.shallowSizeOf(constantVector)) + (expressionEvaluator == null ? 0 : expressionEvaluator.baseRamBytesUsed()); + return (constantVector == null ? 0 : RamUsageEstimator.shallowSizeOf(constantVector)) + (expressionEvaluator == null + ? 0 + : expressionEvaluator.baseRamBytesUsed()); + } + + public void finishBlock() { + if (block != null) { + block.close(); + block = null; + } } public void close() { - Releasables.close(block, expressionEvaluator); + Releasables.close(expressionEvaluator); } @Override @@ -267,7 +270,8 @@ public Block eval(Page page) { return builder.build(); } } finally { - Releasables.close(left, right); + left.finishBlock(); + right.finishBlock(); } } @@ -283,7 +287,7 @@ public String toString() { @Override public void close() { -// Releasables.close(left, right); + Releasables.close(left, right); } } } From 989235e44f82800b149e2898aee464b20a244514 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Fri, 26 Sep 2025 16:52:55 +0300 Subject: [PATCH 03/21] iter --- .../vector/VectorSimilarityFunctionsIT.java | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java index 70ebbb1633c40..38c5431fc4579 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java @@ -24,6 +24,9 @@ import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase; import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.expression.function.vector.Hamming; +import org.elasticsearch.xpack.esql.expression.function.vector.L1Norm; +import org.elasticsearch.xpack.esql.expression.function.vector.L2Norm; import org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction.SimilarityEvaluatorFunction; import org.junit.Before; @@ -45,18 +48,18 @@ public static Iterable parameters() throws Exception { if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { params.add(new Object[] { "v_cosine", (SimilarityEvaluatorFunction) VectorSimilarityFunction.COSINE::compare }); } - // if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - // params.add(new Object[] { "v_dot_product", (SimilarityEvaluatorFunction) VectorSimilarityFunction.DOT_PRODUCT::compare }); - // } - // if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - // params.add(new Object[] { "v_l1_norm", (SimilarityEvaluatorFunction) L1Norm::calculateSimilarity }); - // } - // if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - // params.add(new Object[] { "v_l2_norm", (SimilarityEvaluatorFunction) L2Norm::calculateSimilarity }); - // } - // if (EsqlCapabilities.Cap.HAMMING_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - // params.add(new Object[] { "v_hamming", (SimilarityEvaluatorFunction) Hamming::calculateSimilarity }); - // } + if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + params.add(new Object[] { "v_dot_product", (SimilarityEvaluatorFunction) VectorSimilarityFunction.DOT_PRODUCT::compare }); + } + if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + params.add(new Object[] { "v_l1_norm", (SimilarityEvaluatorFunction) L1Norm::calculateSimilarity }); + } + if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + params.add(new Object[] { "v_l2_norm", (SimilarityEvaluatorFunction) L2Norm::calculateSimilarity }); + } + if (EsqlCapabilities.Cap.HAMMING_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + params.add(new Object[] { "v_hamming", (SimilarityEvaluatorFunction) Hamming::calculateSimilarity }); + } return params; } From aeebc6e39a08b48fb948603be9df87395f3c7f05 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Fri, 26 Sep 2025 16:53:22 +0300 Subject: [PATCH 04/21] iter --- .../vector/VectorSimilarityFunctionsIT.java | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java index 38c5431fc4579..e886e2baf06c4 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java @@ -48,18 +48,18 @@ public static Iterable parameters() throws Exception { if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { params.add(new Object[] { "v_cosine", (SimilarityEvaluatorFunction) VectorSimilarityFunction.COSINE::compare }); } - if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - params.add(new Object[] { "v_dot_product", (SimilarityEvaluatorFunction) VectorSimilarityFunction.DOT_PRODUCT::compare }); - } - if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - params.add(new Object[] { "v_l1_norm", (SimilarityEvaluatorFunction) L1Norm::calculateSimilarity }); - } - if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - params.add(new Object[] { "v_l2_norm", (SimilarityEvaluatorFunction) L2Norm::calculateSimilarity }); - } - if (EsqlCapabilities.Cap.HAMMING_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - params.add(new Object[] { "v_hamming", (SimilarityEvaluatorFunction) Hamming::calculateSimilarity }); - } + if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + params.add(new Object[] { "v_dot_product", (SimilarityEvaluatorFunction) VectorSimilarityFunction.DOT_PRODUCT::compare }); + } + if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + params.add(new Object[] { "v_l1_norm", (SimilarityEvaluatorFunction) L1Norm::calculateSimilarity }); + } + if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + params.add(new Object[] { "v_l2_norm", (SimilarityEvaluatorFunction) L2Norm::calculateSimilarity }); + } + if (EsqlCapabilities.Cap.HAMMING_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + params.add(new Object[] { "v_hamming", (SimilarityEvaluatorFunction) Hamming::calculateSimilarity }); + } return params; } From 240dd53a87287d9caecad067375cb7e26436cd0b Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Fri, 26 Sep 2025 17:02:27 +0300 Subject: [PATCH 05/21] iter --- .../function/vector/VectorSimilarityFunction.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java index 1eb5629af6539..de6be8cd79489 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java @@ -248,10 +248,10 @@ public Block eval(Page page) { try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(positionCount)) { for (int p = 0; p < positionCount; p++) { float[] leftVector = left.getVector(p); - float[] rightvector = right.getVector(p); + float[] rightVector = right.getVector(p); int dimsLeft = leftVector == null ? 0 : leftVector.length; - int dimsRight = rightvector == null ? 0 : rightvector.length; + int dimsRight = rightVector == null ? 0 : rightVector.length; if (dimsLeft == 0 || dimsRight == 0) { // A null value on the left or right vector. Similarity is null @@ -264,7 +264,7 @@ public Block eval(Page page) { dimsRight ); } - float result = similarityFunction.calculateSimilarity(leftVector, rightvector); + float result = similarityFunction.calculateSimilarity(leftVector, rightVector); builder.appendDouble(result); } return builder.build(); From 4a759bab2dc6b056010c32b578e296db55657276 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Fri, 26 Sep 2025 17:37:26 +0300 Subject: [PATCH 06/21] iter --- .../function/vector/VectorSimilarityFunction.java | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java index de6be8cd79489..8a6a57a209c2a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java @@ -170,17 +170,17 @@ private float[] getVector(int position) { } } if (scratch != null) { - readFloatArray(block, block.getFirstValueIndex(position), scratch.length, scratch); + readFloatArray(block, block.getFirstValueIndex(position), scratch); } return scratch; } else { - throw new EsqlClientException("VectorValueProvider not initialized"); + throw new EsqlClientException("VectorValueProvider not properly initialized. Both [constantVector] and [block] are null."); } } - private static void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) { - for (int i = 0; i < dimensions; i++) { - scratch[i] = block.getFloat(position + i); + private static void readFloatArray(FloatBlock block, int firstValueIndex, float[] scratch) { + for (int i = 0; i < scratch.length; i++) { + scratch[i] = block.getFloat(firstValueIndex + i); } } @@ -196,7 +196,7 @@ public int getDimensions() { } return 0; } else { - throw new EsqlClientException("VectorValueProvider not initialized"); + throw new EsqlClientException("[VectorValueProvider] not initialized"); } } @@ -210,6 +210,7 @@ public void finishBlock() { if (block != null) { block.close(); block = null; + scratch = null; } } From 0301c5b360a5d238b582e0829ddd0ad2e954d9c0 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Mon, 29 Sep 2025 12:25:03 +0300 Subject: [PATCH 07/21] iter --- .../vector/VectorSimilarityFunction.java | 49 +++++++++++++------ 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java index 8a6a57a209c2a..a0fa7f956c53b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java @@ -109,34 +109,54 @@ private record SimilarityEvaluatorFactory( @Override public EvalOperator.ExpressionEvaluator get(DriverContext context) { - VectorValueProvider left; - VectorValueProvider right; + VectorValueProvider.Builder left = new VectorValueProvider.Builder(); + VectorValueProvider.Builder right = new VectorValueProvider.Builder(); if (leftExpression instanceof Literal && leftExpression.dataType() == DENSE_VECTOR) { - left = new VectorValueProvider((ArrayList) ((Literal) leftExpression).value(), null); + left.constantVector((ArrayList) ((Literal) leftExpression).value()); } else { - left = new VectorValueProvider(null, toEvaluator.apply(leftExpression).get(context)); + left.expressionEvaluator(toEvaluator.apply(leftExpression).get(context)); } if (rightExpression instanceof Literal && rightExpression.dataType() == DENSE_VECTOR) { - right = new VectorValueProvider((ArrayList) ((Literal) rightExpression).value(), null); + right.constantVector((ArrayList) ((Literal) rightExpression).value()); } else { - right = new VectorValueProvider(null, toEvaluator.apply(rightExpression).get(context)); + right.expressionEvaluator(toEvaluator.apply(rightExpression).get(context)); } // TODO check whether to use this custom evaluator or reuse / define an existing one - return new SimilarityEvaluator(left, right, similarityFunction, evaluatorName, context.blockFactory()); + return new SimilarityEvaluator(left.build(), right.build(), similarityFunction, evaluatorName, context.blockFactory()); } @Override public String toString() { - return evaluatorName() + "[left=" + "left" + ", right=" + "right" + "]"; + return evaluatorName() + "[left=" + leftExpression + ", right=" + rightExpression + "]"; } } private static class VectorValueProvider implements Releasable { + private static final class Builder { + private ArrayList constantVector; + private EvalOperator.ExpressionEvaluator expressionEvaluator; + + void constantVector(ArrayList constantVector) { + this.constantVector = constantVector; + } + + void expressionEvaluator(EvalOperator.ExpressionEvaluator expressionEvaluator) { + this.expressionEvaluator = expressionEvaluator; + } + + VectorValueProvider build() { + if (false == (constantVector == null ^ expressionEvaluator == null)) { + throw new IllegalArgumentException("One of [constantVector] or [expressionEvaluator] must be set, but not both."); + } + return new VectorValueProvider(constantVector, expressionEvaluator); + } + } + private final float[] constantVector; private final EvalOperator.ExpressionEvaluator expressionEvaluator; private FloatBlock block; - float[] scratch; + private float[] scratch; VectorValueProvider(ArrayList constantVector, EvalOperator.ExpressionEvaluator expressionEvaluator) { if (constantVector != null) { @@ -174,7 +194,7 @@ private float[] getVector(int position) { } return scratch; } else { - throw new EsqlClientException("VectorValueProvider not properly initialized. Both [constantVector] and [block] are null."); + throw new EsqlClientException("[" + getClass() + "] not properly initialized. Both [constantVector] and [block] are null."); } } @@ -201,12 +221,13 @@ public int getDimensions() { } public long baseRamBytesUsed() { - return (constantVector == null ? 0 : RamUsageEstimator.shallowSizeOf(constantVector)) + (expressionEvaluator == null - ? 0 - : expressionEvaluator.baseRamBytesUsed()); + return (constantVector == null ? 0 : RamUsageEstimator.shallowSizeOf(constantVector)) + + (expressionEvaluator == null ? 0 : expressionEvaluator.baseRamBytesUsed()); } - public void finishBlock() { + // Once we're doing processing a block, ensure that we dereference it and reset scratch so that it can be + // filled by the next page through `eval` + void finishBlock() { if (block != null) { block.close(); block = null; From cafafcb85be5e73f9507edc508573cdf8dfe1e5b Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Mon, 29 Sep 2025 09:56:29 +0000 Subject: [PATCH 08/21] [CI] Auto commit changes from spotless --- .../expression/function/vector/VectorSimilarityFunction.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java index a0fa7f956c53b..5afc50065ce28 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java @@ -221,8 +221,9 @@ public int getDimensions() { } public long baseRamBytesUsed() { - return (constantVector == null ? 0 : RamUsageEstimator.shallowSizeOf(constantVector)) - + (expressionEvaluator == null ? 0 : expressionEvaluator.baseRamBytesUsed()); + return (constantVector == null ? 0 : RamUsageEstimator.shallowSizeOf(constantVector)) + (expressionEvaluator == null + ? 0 + : expressionEvaluator.baseRamBytesUsed()); } // Once we're doing processing a block, ensure that we dereference it and reset scratch so that it can be From 9f12dbe0dec94f714ac8160d55aa81899d6475af Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Mon, 29 Sep 2025 14:43:28 +0300 Subject: [PATCH 09/21] iter --- .../vector/VectorSimilarityFunction.java | 90 ++++++++++++------- ...tractVectorSimilarityFunctionTestCase.java | 4 +- 2 files changed, 59 insertions(+), 35 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java index a0fa7f956c53b..b0f4e5e25f0dc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java @@ -31,7 +31,6 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; -import java.util.function.Function; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; @@ -83,11 +82,23 @@ public Object fold(FoldContext ctx) { } @Override + @SuppressWarnings("unchecked") public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) { + VectorValueProvider.Builder leftVectorProviderBuilder = new VectorValueProvider.Builder(); + VectorValueProvider.Builder rightVectorProviderBuilder = new VectorValueProvider.Builder(); + if (left() instanceof Literal && left().dataType() == DENSE_VECTOR) { + leftVectorProviderBuilder.constantVector((ArrayList) ((Literal) left()).value()); + } else { + leftVectorProviderBuilder.expressionEvaluatorFactory(toEvaluator.apply(left())); + } + if (right() instanceof Literal && right().dataType() == DENSE_VECTOR) { + rightVectorProviderBuilder.constantVector((ArrayList) ((Literal) right()).value()); + } else { + rightVectorProviderBuilder.expressionEvaluatorFactory(toEvaluator.apply(right())); + } return new SimilarityEvaluatorFactory( - left(), - right(), - toEvaluator::apply, + leftVectorProviderBuilder, + rightVectorProviderBuilder, getSimilarityFunction(), getClass().getSimpleName() + "Evaluator" ); @@ -98,36 +109,28 @@ public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMappe */ protected abstract SimilarityEvaluatorFunction getSimilarityFunction(); - @SuppressWarnings("unchecked") private record SimilarityEvaluatorFactory( - Expression leftExpression, - Expression rightExpression, - Function toEvaluator, + VectorValueProvider.Builder leftVectorProviderBuilder, + VectorValueProvider.Builder rightVectorProviderBuilder, SimilarityEvaluatorFunction similarityFunction, String evaluatorName ) implements EvalOperator.ExpressionEvaluator.Factory { @Override public EvalOperator.ExpressionEvaluator get(DriverContext context) { - VectorValueProvider.Builder left = new VectorValueProvider.Builder(); - VectorValueProvider.Builder right = new VectorValueProvider.Builder(); - if (leftExpression instanceof Literal && leftExpression.dataType() == DENSE_VECTOR) { - left.constantVector((ArrayList) ((Literal) leftExpression).value()); - } else { - left.expressionEvaluator(toEvaluator.apply(leftExpression).get(context)); - } - if (rightExpression instanceof Literal && rightExpression.dataType() == DENSE_VECTOR) { - right.constantVector((ArrayList) ((Literal) rightExpression).value()); - } else { - right.expressionEvaluator(toEvaluator.apply(rightExpression).get(context)); - } // TODO check whether to use this custom evaluator or reuse / define an existing one - return new SimilarityEvaluator(left.build(), right.build(), similarityFunction, evaluatorName, context.blockFactory()); + return new SimilarityEvaluator( + leftVectorProviderBuilder.build(context), + rightVectorProviderBuilder.build(context), + similarityFunction, + evaluatorName, + context.blockFactory() + ); } @Override public String toString() { - return evaluatorName() + "[left=" + leftExpression + ", right=" + rightExpression + "]"; + return evaluatorName() + "[left=" + leftVectorProviderBuilder + ", right=" + rightVectorProviderBuilder + "]"; } } @@ -135,21 +138,35 @@ private static class VectorValueProvider implements Releasable { private static final class Builder { private ArrayList constantVector; - private EvalOperator.ExpressionEvaluator expressionEvaluator; + private EvalOperator.ExpressionEvaluator.Factory expressionEvaluatorFactory; void constantVector(ArrayList constantVector) { this.constantVector = constantVector; } - void expressionEvaluator(EvalOperator.ExpressionEvaluator expressionEvaluator) { - this.expressionEvaluator = expressionEvaluator; + void expressionEvaluatorFactory(EvalOperator.ExpressionEvaluator.Factory expressionEvaluatorFactory) { + this.expressionEvaluatorFactory = expressionEvaluatorFactory; } - VectorValueProvider build() { - if (false == (constantVector == null ^ expressionEvaluator == null)) { - throw new IllegalArgumentException("One of [constantVector] or [expressionEvaluator] must be set, but not both."); + VectorValueProvider build(DriverContext context) { + if (false == (constantVector == null ^ expressionEvaluatorFactory == null)) { + throw new IllegalArgumentException( + "One of [constantVector] or [expressionEvaluatorFactory] must be set, but not both." + ); } - return new VectorValueProvider(constantVector, expressionEvaluator); + return new VectorValueProvider( + constantVector, + expressionEvaluatorFactory == null ? null : expressionEvaluatorFactory.get(context) + ); + } + + @Override + public String toString() { + return "constantVector=" + + (constantVector == null ? "[null]" : constantVector) + + ", expressionEvaluator=[" + + expressionEvaluatorFactory + + "]"; } } @@ -194,7 +211,9 @@ private float[] getVector(int position) { } return scratch; } else { - throw new EsqlClientException("[" + getClass() + "] not properly initialized. Both [constantVector] and [block] are null."); + throw new EsqlClientException( + "[" + getClass().getSimpleName() + "] not properly initialized. Both [constantVector] and [block] are null." + ); } } @@ -216,13 +235,16 @@ public int getDimensions() { } return 0; } else { - throw new EsqlClientException("[VectorValueProvider] not initialized"); + throw new EsqlClientException( + "[" + getClass().getSimpleName() + "] not properly initialized. Both [constantVector] and [block] are null." + ); } } public long baseRamBytesUsed() { - return (constantVector == null ? 0 : RamUsageEstimator.shallowSizeOf(constantVector)) - + (expressionEvaluator == null ? 0 : expressionEvaluator.baseRamBytesUsed()); + return (constantVector == null ? 0 : RamUsageEstimator.shallowSizeOf(constantVector)) + (expressionEvaluator == null + ? 0 + : expressionEvaluator.baseRamBytesUsed()); } // Once we're doing processing a block, ensure that we dereference it and reset scratch so that it can be @@ -241,7 +263,7 @@ public void close() { @Override public String toString() { - return "constant_vector=" + Arrays.toString(constantVector) + ", expressionEvaluator=[" + expressionEvaluator + "]"; + return "constantVector=[" + Arrays.toString(constantVector) + "], expressionEvaluator=[" + expressionEvaluator + "]"; } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java index 6b0faaaf6d53e..0e4ab70b7bce4 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java @@ -42,7 +42,9 @@ protected static Iterable similarityParameters( VectorSimilarityFunction.SimilarityEvaluatorFunction similarityFunction ) { - final String evaluatorName = className + "Evaluator" + "[left=Attribute[channel=0], right=Attribute[channel=1]]"; + final String evaluatorName = className + + "Evaluator" + + "[left=constantVector=[null], expressionEvaluator=[Attribute[channel=0]], right=constantVector=[null], expressionEvaluator=[Attribute[channel=1]]]"; List suppliers = new ArrayList<>(); From 1fc70a2f1ad84c35d875c84c223441b7d70396c8 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Mon, 29 Sep 2025 14:58:29 +0300 Subject: [PATCH 10/21] checkstyle --- .../vector/AbstractVectorSimilarityFunctionTestCase.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java index 0e4ab70b7bce4..ea8986acce173 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java @@ -43,8 +43,8 @@ protected static Iterable similarityParameters( ) { final String evaluatorName = className - + "Evaluator" - + "[left=constantVector=[null], expressionEvaluator=[Attribute[channel=0]], right=constantVector=[null], expressionEvaluator=[Attribute[channel=1]]]"; + + "Evaluator [left=constantVector=[null], expressionEvaluator=[Attribute[channel=0]], " + + "right=constantVector=[null], expressionEvaluator=[Attribute[channel=1]]]"; List suppliers = new ArrayList<>(); From 7b5b760570bf71bfb7990829a312acf539bfa6c7 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Mon, 29 Sep 2025 12:10:43 +0000 Subject: [PATCH 11/21] [CI] Auto commit changes from spotless --- .../vector/AbstractVectorSimilarityFunctionTestCase.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java index ea8986acce173..f00fe4df5c529 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java @@ -43,8 +43,8 @@ protected static Iterable similarityParameters( ) { final String evaluatorName = className - + "Evaluator [left=constantVector=[null], expressionEvaluator=[Attribute[channel=0]], " + - "right=constantVector=[null], expressionEvaluator=[Attribute[channel=1]]]"; + + "Evaluator [left=constantVector=[null], expressionEvaluator=[Attribute[channel=0]], " + + "right=constantVector=[null], expressionEvaluator=[Attribute[channel=1]]]"; List suppliers = new ArrayList<>(); From 9d51231eb120fb837667b5f5b6b7292296aa400b Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Mon, 29 Sep 2025 16:03:28 +0300 Subject: [PATCH 12/21] checkstyle --- .../vector/AbstractVectorSimilarityFunctionTestCase.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java index f00fe4df5c529..c078ae55c17bb 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java @@ -43,7 +43,7 @@ protected static Iterable similarityParameters( ) { final String evaluatorName = className - + "Evaluator [left=constantVector=[null], expressionEvaluator=[Attribute[channel=0]], " + + "Evaluator[left=constantVector=[null], expressionEvaluator=[Attribute[channel=0]], " + "right=constantVector=[null], expressionEvaluator=[Attribute[channel=1]]]"; List suppliers = new ArrayList<>(); From 504c92eb28344db67829fff0dca44221347e4a46 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Mon, 29 Sep 2025 16:19:32 +0300 Subject: [PATCH 13/21] checkstyle --- .../function/vector/VectorSimilarityFunction.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java index b0f4e5e25f0dc..05a3988320d2c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java @@ -263,7 +263,11 @@ public void close() { @Override public String toString() { - return "constantVector=[" + Arrays.toString(constantVector) + "], expressionEvaluator=[" + expressionEvaluator + "]"; + return "constantVector=" + + (constantVector == null ? "[null]" : Arrays.toString(constantVector)) + + ", expressionEvaluator=[" + + expressionEvaluator + + "]"; } } From 59cdda97d136c2b52bf1cf072e1d4738a00ece79 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Tue, 30 Sep 2025 16:16:48 +0300 Subject: [PATCH 14/21] iter --- .../vector/VectorSimilarityFunction.java | 326 ++++++++++-------- ...tractVectorSimilarityFunctionTestCase.java | 4 +- 2 files changed, 179 insertions(+), 151 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java index 05a3988320d2c..766ad83cfe710 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java @@ -31,6 +31,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; @@ -84,17 +85,17 @@ public Object fold(FoldContext ctx) { @Override @SuppressWarnings("unchecked") public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) { - VectorValueProvider.Builder leftVectorProviderBuilder = new VectorValueProvider.Builder(); - VectorValueProvider.Builder rightVectorProviderBuilder = new VectorValueProvider.Builder(); - if (left() instanceof Literal && left().dataType() == DENSE_VECTOR) { - leftVectorProviderBuilder.constantVector((ArrayList) ((Literal) left()).value()); + VectorValueProviderBuilder leftVectorProviderBuilder; + VectorValueProviderBuilder rightVectorProviderBuilder; + if (left() instanceof Literal) { + leftVectorProviderBuilder = new VectorValueProviderBuilder((ArrayList) ((Literal) left()).value(), null); } else { - leftVectorProviderBuilder.expressionEvaluatorFactory(toEvaluator.apply(left())); + leftVectorProviderBuilder = new VectorValueProviderBuilder(null, toEvaluator.apply(left())); } - if (right() instanceof Literal && right().dataType() == DENSE_VECTOR) { - rightVectorProviderBuilder.constantVector((ArrayList) ((Literal) right()).value()); + if (right() instanceof Literal) { + rightVectorProviderBuilder = new VectorValueProviderBuilder((ArrayList) ((Literal) right()).value(), null); } else { - rightVectorProviderBuilder.expressionEvaluatorFactory(toEvaluator.apply(right())); + rightVectorProviderBuilder = new VectorValueProviderBuilder(null, toEvaluator.apply(right())); } return new SimilarityEvaluatorFactory( leftVectorProviderBuilder, @@ -110,8 +111,8 @@ public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMappe protected abstract SimilarityEvaluatorFunction getSimilarityFunction(); private record SimilarityEvaluatorFactory( - VectorValueProvider.Builder leftVectorProviderBuilder, - VectorValueProvider.Builder rightVectorProviderBuilder, + VectorValueProviderBuilder leftVectorProviderBuilder, + VectorValueProviderBuilder rightVectorProviderBuilder, SimilarityEvaluatorFunction similarityFunction, String evaluatorName ) implements EvalOperator.ExpressionEvaluator.Factory { @@ -134,143 +135,6 @@ public String toString() { } } - private static class VectorValueProvider implements Releasable { - - private static final class Builder { - private ArrayList constantVector; - private EvalOperator.ExpressionEvaluator.Factory expressionEvaluatorFactory; - - void constantVector(ArrayList constantVector) { - this.constantVector = constantVector; - } - - void expressionEvaluatorFactory(EvalOperator.ExpressionEvaluator.Factory expressionEvaluatorFactory) { - this.expressionEvaluatorFactory = expressionEvaluatorFactory; - } - - VectorValueProvider build(DriverContext context) { - if (false == (constantVector == null ^ expressionEvaluatorFactory == null)) { - throw new IllegalArgumentException( - "One of [constantVector] or [expressionEvaluatorFactory] must be set, but not both." - ); - } - return new VectorValueProvider( - constantVector, - expressionEvaluatorFactory == null ? null : expressionEvaluatorFactory.get(context) - ); - } - - @Override - public String toString() { - return "constantVector=" - + (constantVector == null ? "[null]" : constantVector) - + ", expressionEvaluator=[" - + expressionEvaluatorFactory - + "]"; - } - } - - private final float[] constantVector; - private final EvalOperator.ExpressionEvaluator expressionEvaluator; - private FloatBlock block; - private float[] scratch; - - VectorValueProvider(ArrayList constantVector, EvalOperator.ExpressionEvaluator expressionEvaluator) { - if (constantVector != null) { - this.constantVector = new float[constantVector.size()]; - for (int i = 0; i < constantVector.size(); i++) { - this.constantVector[i] = constantVector.get(i); - } - } else { - this.constantVector = null; - } - this.expressionEvaluator = expressionEvaluator; - } - - private void eval(Page page) { - if (expressionEvaluator != null) { - block = (FloatBlock) expressionEvaluator.eval(page); - } - } - - private float[] getVector(int position) { - if (constantVector != null) { - return constantVector; - } else if (block != null) { - if (block.isNull(position)) { - return null; - } - if (scratch == null) { - int dims = block.getValueCount(position); - if (dims > 0) { - scratch = new float[dims]; - } - } - if (scratch != null) { - readFloatArray(block, block.getFirstValueIndex(position), scratch); - } - return scratch; - } else { - throw new EsqlClientException( - "[" + getClass().getSimpleName() + "] not properly initialized. Both [constantVector] and [block] are null." - ); - } - } - - private static void readFloatArray(FloatBlock block, int firstValueIndex, float[] scratch) { - for (int i = 0; i < scratch.length; i++) { - scratch[i] = block.getFloat(firstValueIndex + i); - } - } - - public int getDimensions() { - if (constantVector != null) { - return constantVector.length; - } else if (block != null) { - for (int p = 0; p < block.getPositionCount(); p++) { - int dims = block.getValueCount(p); - if (dims > 0) { - return dims; - } - } - return 0; - } else { - throw new EsqlClientException( - "[" + getClass().getSimpleName() + "] not properly initialized. Both [constantVector] and [block] are null." - ); - } - } - - public long baseRamBytesUsed() { - return (constantVector == null ? 0 : RamUsageEstimator.shallowSizeOf(constantVector)) + (expressionEvaluator == null - ? 0 - : expressionEvaluator.baseRamBytesUsed()); - } - - // Once we're doing processing a block, ensure that we dereference it and reset scratch so that it can be - // filled by the next page through `eval` - void finishBlock() { - if (block != null) { - block.close(); - block = null; - scratch = null; - } - } - - public void close() { - Releasables.close(expressionEvaluator); - } - - @Override - public String toString() { - return "constantVector=" - + (constantVector == null ? "[null]" : Arrays.toString(constantVector)) - + ", expressionEvaluator=[" - + expressionEvaluator - + "]"; - } - } - private record SimilarityEvaluator( VectorValueProvider left, VectorValueProvider right, @@ -318,8 +182,8 @@ public Block eval(Page page) { return builder.build(); } } finally { - left.finishBlock(); - right.finishBlock(); + left.finish(); + right.finish(); } } @@ -338,4 +202,168 @@ public void close() { Releasables.close(left, right); } } + + private record VectorValueProviderBuilder( + ArrayList vector, + EvalOperator.ExpressionEvaluator.Factory expressionEvaluatorFactory + ) { + VectorValueProvider build(DriverContext context) { + if (vector != null) { + return new ConstantVectorProvider(vector); + } else { + return new ExpressionVectorProvider(expressionEvaluatorFactory.get(context)); + } + } + + @Override + public String toString() { + if (vector != null) { + return "ConstantVectorProvider[vector=" + Arrays.toString(vector.toArray()) + "]"; + } else { + return "ExpressionVectorProvider[expressionEvaluator=[" + expressionEvaluatorFactory + "]]"; + } + } + } + + interface VectorValueProvider extends Releasable { + + void eval(Page page); + + float[] getVector(int position); + + int getDimensions(); + + void finish(); + + long baseRamBytesUsed(); + } + + private static class ConstantVectorProvider implements VectorValueProvider { + + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ConstantVectorProvider.class); + + private final float[] vector; + + ConstantVectorProvider(List vector) { + assert vector != null; + this.vector = new float[vector.size()]; + for (int i = 0; i < vector.size(); i++) { + this.vector[i] = vector.get(i); + } + } + + @Override + public void eval(Page page) { + // no-op + } + + @Override + public float[] getVector(int position) { + return vector; + } + + @Override + public int getDimensions() { + return vector.length; + } + + @Override + public void finish() { + // no-op + } + + @Override + public void close() { + // no-op + } + + @Override + public long baseRamBytesUsed() { + return BASE_RAM_BYTES_USED + RamUsageEstimator.shallowSizeOf(vector); + } + + @Override + public String toString() { + return this.getClass().getSimpleName() + "[vector=" + (vector == null ? "[null]" : Arrays.toString(vector)) + "]"; + } + } + + private static class ExpressionVectorProvider implements VectorValueProvider { + + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ExpressionVectorProvider.class); + + private final EvalOperator.ExpressionEvaluator expressionEvaluator; + private FloatBlock block; + private float[] scratch; + + ExpressionVectorProvider(EvalOperator.ExpressionEvaluator expressionEvaluator) { + this.expressionEvaluator = expressionEvaluator; + } + + @Override + public void eval(Page page) { + if (expressionEvaluator != null) { + block = (FloatBlock) expressionEvaluator.eval(page); + } + } + + @Override + public float[] getVector(int position) { + if (block.isNull(position)) { + return null; + } + if (scratch == null) { + int dims = block.getValueCount(position); + if (dims > 0) { + scratch = new float[dims]; + } + } + if (scratch != null) { + readFloatArray(block, block.getFirstValueIndex(position), scratch); + } + return scratch; + } + + @Override + public int getDimensions() { + for (int p = 0; p < block.getPositionCount(); p++) { + int dims = block.getValueCount(p); + if (dims > 0) { + return dims; + } + } + return 0; + } + + @Override + public void finish() { + if (block != null) { + block.close(); + block = null; + scratch = null; + } + } + + @Override + public long baseRamBytesUsed() { + return BASE_RAM_BYTES_USED + expressionEvaluator.baseRamBytesUsed() + (block == null ? 0 : block.ramBytesUsed()) + + (scratch == null ? 0 : RamUsageEstimator.shallowSizeOf(scratch)); + } + + @Override + public void close() { + Releasables.close(expressionEvaluator); + } + + private static void readFloatArray(FloatBlock block, int firstValueIndex, float[] scratch) { + for (int i = 0; i < scratch.length; i++) { + scratch[i] = block.getFloat(firstValueIndex + i); + } + } + + @Override + public String toString() { + return this.getClass().getSimpleName() + "[expressionEvaluator=[" + expressionEvaluator + "]]"; + } + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java index c078ae55c17bb..ace11049057b1 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java @@ -43,8 +43,8 @@ protected static Iterable similarityParameters( ) { final String evaluatorName = className - + "Evaluator[left=constantVector=[null], expressionEvaluator=[Attribute[channel=0]], " - + "right=constantVector=[null], expressionEvaluator=[Attribute[channel=1]]]"; + + "Evaluator[left=ExpressionVectorProvider[expressionEvaluator=[Attribute[channel=0]]]," + + " right=ExpressionVectorProvider[expressionEvaluator=[Attribute[channel=1]]]]"; List suppliers = new ArrayList<>(); From 6da8fb49ec842d0958875a9fb2a188befb37915d Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Tue, 30 Sep 2025 16:19:29 +0300 Subject: [PATCH 15/21] iter --- .../expression/function/vector/VectorSimilarityFunction.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java index 766ad83cfe710..6c7bea8f5d932 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java @@ -284,7 +284,7 @@ public long baseRamBytesUsed() { @Override public String toString() { - return this.getClass().getSimpleName() + "[vector=" + (vector == null ? "[null]" : Arrays.toString(vector)) + "]"; + return this.getClass().getSimpleName() + "[vector=" + Arrays.toString(vector) + "]"; } } @@ -297,6 +297,7 @@ private static class ExpressionVectorProvider implements VectorValueProvider { private float[] scratch; ExpressionVectorProvider(EvalOperator.ExpressionEvaluator expressionEvaluator) { + assert expressionEvaluator != null; this.expressionEvaluator = expressionEvaluator; } From 42dcf15660f5efe0f3442adfee5e4f8a5ab14fdb Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Tue, 30 Sep 2025 16:21:25 +0300 Subject: [PATCH 16/21] iter --- .../expression/function/vector/VectorSimilarityFunction.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java index 6c7bea8f5d932..86a9a91c565e4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java @@ -303,9 +303,7 @@ private static class ExpressionVectorProvider implements VectorValueProvider { @Override public void eval(Page page) { - if (expressionEvaluator != null) { - block = (FloatBlock) expressionEvaluator.eval(page); - } + block = (FloatBlock) expressionEvaluator.eval(page); } @Override From 4640bb8565cb3b7d1ad13d2afe86ab506e010c5d Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Wed, 1 Oct 2025 11:46:31 +0300 Subject: [PATCH 17/21] iter --- .../expression/function/vector/VectorSimilarityFunction.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java index 86a9a91c565e4..b5d287018dde4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java @@ -218,9 +218,9 @@ VectorValueProvider build(DriverContext context) { @Override public String toString() { if (vector != null) { - return "ConstantVectorProvider[vector=" + Arrays.toString(vector.toArray()) + "]"; + return ConstantVectorProvider.class.getSimpleName() + "[vector=" + Arrays.toString(vector.toArray()) + "]"; } else { - return "ExpressionVectorProvider[expressionEvaluator=[" + expressionEvaluatorFactory + "]]"; + return ExpressionVectorProvider.class.getSimpleName() + "[expressionEvaluator=[" + expressionEvaluatorFactory + "]]"; } } } From 8d72e5c586dd9914a7063b35fc3dcc675be377be Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Wed, 1 Oct 2025 12:08:31 +0300 Subject: [PATCH 18/21] adding test cases for toString when one vector is a literal --- ...tractVectorSimilarityFunctionTestCase.java | 57 ++++++++++++++++++- .../vector/CosineSimilarityTests.java | 5 ++ .../vector/DotProductSimilarityTests.java | 5 ++ .../vector/HammingSimilarityTests.java | 5 ++ .../vector/L1NormSimilarityTests.java | 5 ++ .../vector/L2NormSimilarityTests.java | 5 ++ 6 files changed, 80 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java index ace11049057b1..ac5a05ddd1040 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java @@ -9,7 +9,9 @@ import com.carrotsearch.randomizedtesting.annotations.Name; +import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; import org.junit.Before; @@ -32,6 +34,8 @@ public void checkCapability() { assumeTrue("Similarity function is not enabled", capability().isEnabled()); } + public abstract String getBaseEvaluatorName(); + /** * Get the capability of the vector similarity function to check */ @@ -43,8 +47,10 @@ protected static Iterable similarityParameters( ) { final String evaluatorName = className - + "Evaluator[left=ExpressionVectorProvider[expressionEvaluator=[Attribute[channel=0]]]," - + " right=ExpressionVectorProvider[expressionEvaluator=[Attribute[channel=1]]]]"; + + "Evaluator[" + + "left=ExpressionVectorProvider[expressionEvaluator=[Attribute[channel=0]]], " + + "right=ExpressionVectorProvider[expressionEvaluator=[Attribute[channel=1]]]" + + "]"; List suppliers = new ArrayList<>(); @@ -69,4 +75,51 @@ protected static Iterable similarityParameters( return parameterSuppliersFromTypedData(suppliers); } + + public final void testEvaluatorToStringWhenOneVectorIsLiteral() { + Expression literal = buildLiteralExpression(testCase).children().getFirst(); + Expression field = buildFieldExpression(testCase).children().getLast(); + var expression = build(testCase.getSource(), List.of(literal, field)); + if (testCase.getExpectedTypeError() != null) { + assertTypeResolutionFailure(expression); + return; + } + assumeTrue("Can't build evaluator", testCase.canBuildEvaluator()); + var factory = evaluator(expression); + final String evaluatorName = getBaseEvaluatorName() + "Evaluator" + + "[left=ConstantVectorProvider[vector=" + + testCase.getData().getFirst().getValue() + + "]," + + " right=ExpressionVectorProvider[expressionEvaluator=[Attribute[channel=0]]]]"; + + try (EvalOperator.ExpressionEvaluator ev = factory.get(driverContext())) { + if (testCase.getExpectedBuildEvaluatorWarnings() != null) { + assertWarnings(testCase.getExpectedBuildEvaluatorWarnings()); + } + assertThat(ev.toString(), equalTo(evaluatorName)); + } + } + + public final void testFactoryToStringWhenOneVectorIsLiteral() { + + Expression literal = buildLiteralExpression(testCase).children().getFirst(); + Expression field = buildFieldExpression(testCase).children().getLast(); + var expression = build(testCase.getSource(), List.of(literal, field)); + if (testCase.getExpectedTypeError() != null) { + assertTypeResolutionFailure(expression); + return; + } + assumeTrue("Can't build evaluator", testCase.canBuildEvaluator()); + var factory = evaluator(expression); + if (testCase.getExpectedBuildEvaluatorWarnings() != null) { + assertWarnings(testCase.getExpectedBuildEvaluatorWarnings()); + } + final String evaluatorName = getBaseEvaluatorName() + "Evaluator" + + "[left=ConstantVectorProvider[vector=" + + testCase.getData().getFirst().getValue() + + "]," + + " right=ExpressionVectorProvider[expressionEvaluator=[Attribute[channel=0]]]]"; + + assertThat(factory.toString(), equalTo(evaluatorName)); + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarityTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarityTests.java index 32ba95ee0af27..e257a0a262646 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarityTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarityTests.java @@ -26,6 +26,11 @@ public CosineSimilarityTests(@Name("TestCase") Supplier parameters() { return similarityParameters(CosineSimilarity.class.getSimpleName(), CosineSimilarity.SIMILARITY_FUNCTION); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/DotProductSimilarityTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/DotProductSimilarityTests.java index 20e275d1280b1..091b497c5acee 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/DotProductSimilarityTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/DotProductSimilarityTests.java @@ -26,6 +26,11 @@ public DotProductSimilarityTests(@Name("TestCase") Supplier parameters() { return similarityParameters(DotProduct.class.getSimpleName(), DotProduct.SIMILARITY_FUNCTION); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/HammingSimilarityTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/HammingSimilarityTests.java index 203c0171dc5f4..b41f8e4ec32d5 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/HammingSimilarityTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/HammingSimilarityTests.java @@ -26,6 +26,11 @@ public HammingSimilarityTests(@Name("TestCase") Supplier parameters() { return similarityParameters(Hamming.class.getSimpleName(), Hamming.SIMILARITY_FUNCTION); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/L1NormSimilarityTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/L1NormSimilarityTests.java index 149f5a1cd650a..e8e7f49dbadc7 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/L1NormSimilarityTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/L1NormSimilarityTests.java @@ -26,6 +26,11 @@ public L1NormSimilarityTests(@Name("TestCase") Supplier parameters() { return similarityParameters(L1Norm.class.getSimpleName(), L1Norm.SIMILARITY_FUNCTION); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/L2NormSimilarityTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/L2NormSimilarityTests.java index 2dc7929ffa05e..fee60152c3893 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/L2NormSimilarityTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/L2NormSimilarityTests.java @@ -26,6 +26,11 @@ public L2NormSimilarityTests(@Name("TestCase") Supplier parameters() { return similarityParameters(L2Norm.class.getSimpleName(), L2Norm.SIMILARITY_FUNCTION); From c25f0bd7a8b1a368ac2e77386b29e36bb73e2f67 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 1 Oct 2025 09:16:08 +0000 Subject: [PATCH 19/21] [CI] Auto commit changes from spotless --- .../vector/AbstractVectorSimilarityFunctionTestCase.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java index ac5a05ddd1040..bfdda32889d66 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java @@ -86,7 +86,8 @@ public final void testEvaluatorToStringWhenOneVectorIsLiteral() { } assumeTrue("Can't build evaluator", testCase.canBuildEvaluator()); var factory = evaluator(expression); - final String evaluatorName = getBaseEvaluatorName() + "Evaluator" + final String evaluatorName = getBaseEvaluatorName() + + "Evaluator" + "[left=ConstantVectorProvider[vector=" + testCase.getData().getFirst().getValue() + "]," @@ -114,7 +115,8 @@ public final void testFactoryToStringWhenOneVectorIsLiteral() { if (testCase.getExpectedBuildEvaluatorWarnings() != null) { assertWarnings(testCase.getExpectedBuildEvaluatorWarnings()); } - final String evaluatorName = getBaseEvaluatorName() + "Evaluator" + final String evaluatorName = getBaseEvaluatorName() + + "Evaluator" + "[left=ConstantVectorProvider[vector=" + testCase.getData().getFirst().getValue() + "]," From ee019f2fb8a76926ef3de2503bc9a892f29144d6 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 2 Oct 2025 12:55:48 +0300 Subject: [PATCH 20/21] addressing PR comments - adding common factory interface --- .../vector/VectorSimilarityFunction.java | 74 ++++++++++--------- 1 file changed, 39 insertions(+), 35 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java index b5d287018dde4..d4e808cddfea6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java @@ -85,21 +85,21 @@ public Object fold(FoldContext ctx) { @Override @SuppressWarnings("unchecked") public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) { - VectorValueProviderBuilder leftVectorProviderBuilder; - VectorValueProviderBuilder rightVectorProviderBuilder; + VectorValueProviderFactory leftVectorProviderFactory; + VectorValueProviderFactory rightVectorProviderFactory; if (left() instanceof Literal) { - leftVectorProviderBuilder = new VectorValueProviderBuilder((ArrayList) ((Literal) left()).value(), null); + leftVectorProviderFactory = new ConstantVectorProvider.Factory((ArrayList) ((Literal) left()).value()); } else { - leftVectorProviderBuilder = new VectorValueProviderBuilder(null, toEvaluator.apply(left())); + leftVectorProviderFactory = new ExpressionVectorProvider.Factory(toEvaluator.apply(left())); } if (right() instanceof Literal) { - rightVectorProviderBuilder = new VectorValueProviderBuilder((ArrayList) ((Literal) right()).value(), null); + rightVectorProviderFactory = new ConstantVectorProvider.Factory((ArrayList) ((Literal) right()).value()); } else { - rightVectorProviderBuilder = new VectorValueProviderBuilder(null, toEvaluator.apply(right())); + rightVectorProviderFactory = new ExpressionVectorProvider.Factory(toEvaluator.apply(right())); } return new SimilarityEvaluatorFactory( - leftVectorProviderBuilder, - rightVectorProviderBuilder, + leftVectorProviderFactory, + rightVectorProviderFactory, getSimilarityFunction(), getClass().getSimpleName() + "Evaluator" ); @@ -111,8 +111,8 @@ public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMappe protected abstract SimilarityEvaluatorFunction getSimilarityFunction(); private record SimilarityEvaluatorFactory( - VectorValueProviderBuilder leftVectorProviderBuilder, - VectorValueProviderBuilder rightVectorProviderBuilder, + VectorValueProviderFactory leftVectorProviderFactory, + VectorValueProviderFactory rightVectorProviderFactory, SimilarityEvaluatorFunction similarityFunction, String evaluatorName ) implements EvalOperator.ExpressionEvaluator.Factory { @@ -121,8 +121,8 @@ private record SimilarityEvaluatorFactory( public EvalOperator.ExpressionEvaluator get(DriverContext context) { // TODO check whether to use this custom evaluator or reuse / define an existing one return new SimilarityEvaluator( - leftVectorProviderBuilder.build(context), - rightVectorProviderBuilder.build(context), + leftVectorProviderFactory.build(context), + rightVectorProviderFactory.build(context), similarityFunction, evaluatorName, context.blockFactory() @@ -131,7 +131,7 @@ public EvalOperator.ExpressionEvaluator get(DriverContext context) { @Override public String toString() { - return evaluatorName() + "[left=" + leftVectorProviderBuilder + ", right=" + rightVectorProviderBuilder + "]"; + return evaluatorName() + "[left=" + leftVectorProviderFactory + ", right=" + rightVectorProviderFactory + "]"; } } @@ -203,28 +203,6 @@ public void close() { } } - private record VectorValueProviderBuilder( - ArrayList vector, - EvalOperator.ExpressionEvaluator.Factory expressionEvaluatorFactory - ) { - VectorValueProvider build(DriverContext context) { - if (vector != null) { - return new ConstantVectorProvider(vector); - } else { - return new ExpressionVectorProvider(expressionEvaluatorFactory.get(context)); - } - } - - @Override - public String toString() { - if (vector != null) { - return ConstantVectorProvider.class.getSimpleName() + "[vector=" + Arrays.toString(vector.toArray()) + "]"; - } else { - return ExpressionVectorProvider.class.getSimpleName() + "[expressionEvaluator=[" + expressionEvaluatorFactory + "]]"; - } - } - } - interface VectorValueProvider extends Releasable { void eval(Page page); @@ -238,8 +216,23 @@ interface VectorValueProvider extends Releasable { long baseRamBytesUsed(); } + interface VectorValueProviderFactory { + VectorValueProvider build(DriverContext context); + } + private static class ConstantVectorProvider implements VectorValueProvider { + record Factory(List vector) implements VectorValueProviderFactory { + public VectorValueProvider build(DriverContext context) { + return new ConstantVectorProvider(vector); + } + + @Override + public String toString() { + return ConstantVectorProvider.class.getSimpleName() + "[vector=" + Arrays.toString(vector.toArray()) + "]"; + } + } + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ConstantVectorProvider.class); private final float[] vector; @@ -290,6 +283,17 @@ public String toString() { private static class ExpressionVectorProvider implements VectorValueProvider { + record Factory(EvalOperator.ExpressionEvaluator.Factory expressionEvaluatorFactory) implements VectorValueProviderFactory { + public VectorValueProvider build(DriverContext context) { + return new ExpressionVectorProvider(expressionEvaluatorFactory.get(context)); + } + + @Override + public String toString() { + return ExpressionVectorProvider.class.getSimpleName() + "[expressionEvaluator=[" + expressionEvaluatorFactory + "]]"; + } + } + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ExpressionVectorProvider.class); private final EvalOperator.ExpressionEvaluator expressionEvaluator; From 9f47ab2b027dd4159cd5e457253fa2a13f2d4d91 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 2 Oct 2025 13:16:53 +0300 Subject: [PATCH 21/21] extracting method for generating the appropriate factory --- .../vector/VectorSimilarityFunction.java | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java index d4e808cddfea6..42d215cc3d201 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java @@ -83,20 +83,9 @@ public Object fold(FoldContext ctx) { } @Override - @SuppressWarnings("unchecked") public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) { - VectorValueProviderFactory leftVectorProviderFactory; - VectorValueProviderFactory rightVectorProviderFactory; - if (left() instanceof Literal) { - leftVectorProviderFactory = new ConstantVectorProvider.Factory((ArrayList) ((Literal) left()).value()); - } else { - leftVectorProviderFactory = new ExpressionVectorProvider.Factory(toEvaluator.apply(left())); - } - if (right() instanceof Literal) { - rightVectorProviderFactory = new ConstantVectorProvider.Factory((ArrayList) ((Literal) right()).value()); - } else { - rightVectorProviderFactory = new ExpressionVectorProvider.Factory(toEvaluator.apply(right())); - } + VectorValueProviderFactory leftVectorProviderFactory = getVectorValueProviderFactory(left(), toEvaluator); + VectorValueProviderFactory rightVectorProviderFactory = getVectorValueProviderFactory(right(), toEvaluator); return new SimilarityEvaluatorFactory( leftVectorProviderFactory, rightVectorProviderFactory, @@ -105,6 +94,18 @@ public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMappe ); } + @SuppressWarnings("unchecked") + private static VectorValueProviderFactory getVectorValueProviderFactory( + Expression expression, + EvaluatorMapper.ToEvaluator toEvaluator + ) { + if (expression instanceof Literal) { + return new ConstantVectorProvider.Factory((ArrayList) ((Literal) expression).value()); + } else { + return new ExpressionVectorProvider.Factory(toEvaluator.apply(expression)); + } + } + /** * Returns the similarity function to be used for evaluating the similarity between two vectors. */