diff --git a/docs/reference/query-languages/esql/_snippets/functions/examples/knn.md b/docs/reference/query-languages/esql/_snippets/functions/examples/knn.md
index df15bde7deb55..9b6d20b551e7a 100644
--- a/docs/reference/query-languages/esql/_snippets/functions/examples/knn.md
+++ b/docs/reference/query-languages/esql/_snippets/functions/examples/knn.md
@@ -4,7 +4,7 @@
```esql
from colors metadata _score
-| where knn(rgb_vector, [0, 120, 0], 10)
+| where knn(rgb_vector, [0, 120, 0])
| sort _score desc, color asc
```
diff --git a/docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/knn.md b/docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/knn.md
index 1e87271707676..f38a8e8d84584 100644
--- a/docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/knn.md
+++ b/docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/knn.md
@@ -2,12 +2,12 @@
**Supported function named parameters**
-`num_candidates`
-: (integer) The number of nearest neighbor candidates to consider per shard while doing knn search. Cannot exceed 10,000. Increasing num_candidates tends to improve the accuracy of the final results. Defaults to 1.5 * k
-
`boost`
: (float) Floating point number used to decrease or increase the relevance scores of the query.Defaults to 1.0.
+`min_candidates`
+: (integer) The minimum number of nearest neighbor candidates to consider per shard while doing knn search. KNN may use a higher number of candidates in case the query can't use a approximate results. Cannot exceed 10,000. Increasing min_candidates tends to improve the accuracy of the final results. Defaults to 1.5 * LIMIT used for the query.
+
`rescore_oversample`
: (double) Applies the specified oversampling for rescoring quantized vectors. See [oversampling and rescoring quantized vectors](docs-content://solutions/search/vector/knn.md#dense-vector-knn-search-rescoring) for details.
diff --git a/docs/reference/query-languages/esql/_snippets/functions/parameters/knn.md b/docs/reference/query-languages/esql/_snippets/functions/parameters/knn.md
index e33acabbd014f..fb1b98a1e8a7a 100644
--- a/docs/reference/query-languages/esql/_snippets/functions/parameters/knn.md
+++ b/docs/reference/query-languages/esql/_snippets/functions/parameters/knn.md
@@ -8,9 +8,6 @@
`query`
: Vector value to find top nearest neighbours for.
-`k`
-: The number of nearest neighbors to return from each shard. Elasticsearch collects k results from each shard, then merges them to find the global top results. This value must be less than or equal to num_candidates.
-
`options`
: (Optional) kNN additional options as [function named parameters](/reference/query-languages/esql/esql-syntax.md#esql-function-named-params). See [knn query](/reference/query-languages/query-dsl/query-dsl-match-query.md#query-dsl-knn-query) for more information.
diff --git a/docs/reference/query-languages/esql/images/functions/knn.svg b/docs/reference/query-languages/esql/images/functions/knn.svg
index 6e20dbc217206..75a104a7cdcfa 100644
--- a/docs/reference/query-languages/esql/images/functions/knn.svg
+++ b/docs/reference/query-languages/esql/images/functions/knn.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/knn.json b/docs/reference/query-languages/esql/kibana/definition/functions/knn.json
index d347891393dcf..f4b77305a200b 100644
--- a/docs/reference/query-languages/esql/kibana/definition/functions/knn.json
+++ b/docs/reference/query-languages/esql/kibana/definition/functions/knn.json
@@ -5,7 +5,7 @@
"description" : "Finds the k nearest vectors to a query vector, as measured by a similarity metric. knn function finds nearest vectors through approximate search on indexed dense_vectors.",
"signatures" : [ ],
"examples" : [
- "from colors metadata _score\n| where knn(rgb_vector, [0, 120, 0], 10)\n| sort _score desc, color asc"
+ "from colors metadata _score\n| where knn(rgb_vector, [0, 120, 0])\n| sort _score desc, color asc"
],
"preview" : true,
"snapshot_only" : true
diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/knn.md b/docs/reference/query-languages/esql/kibana/docs/functions/knn.md
index f32319b080dbb..bea09b0bf50de 100644
--- a/docs/reference/query-languages/esql/kibana/docs/functions/knn.md
+++ b/docs/reference/query-languages/esql/kibana/docs/functions/knn.md
@@ -5,6 +5,6 @@ Finds the k nearest vectors to a query vector, as measured by a similarity metri
```esql
from colors metadata _score
-| where knn(rgb_vector, [0, 120, 0], 10)
+| where knn(rgb_vector, [0, 120, 0])
| sort _score desc, color asc
```
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java
index c7f187c6c4a8f..a4561978bedff 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java
@@ -60,10 +60,16 @@ protected LuceneQueryEvaluator(BlockFactory blockFactory, ShardConfig[] shards)
}
public Block executeQuery(Page page) {
- // Lucene based operators retrieve DocVectors as first block
- Block block = page.getBlock(0);
- assert block instanceof DocBlock : "LuceneQueryExpressionEvaluator expects DocBlock as input";
- DocVector docs = (DocVector) block.asVector();
+ // Search for DocVector block
+ Block docBlock = null;
+ for (int i = 0; i < page.getBlockCount(); i++) {
+ if (page.getBlock(i) instanceof DocBlock) {
+ docBlock = page.getBlock(i);
+ break;
+ }
+ }
+ assert docBlock != null : "LuceneQueryExpressionEvaluator expects a DocBlock";
+ DocVector docs = (DocVector) docBlock.asVector();
try {
if (docs.singleSegmentNonDecreasing()) {
return evalSingleSegmentNonDecreasing(docs);
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java
index 2afc885d71124..1c3d522fda5ab 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java
@@ -9,7 +9,6 @@
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
-import org.elasticsearch.compute.data.DocVector;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.Page;
@@ -46,9 +45,9 @@ public ScoreOperator(BlockFactory blockFactory, ExpressionScorer scorer, int sco
@Override
protected Page process(Page page) {
- assert page.getBlockCount() >= 2 : "Expected at least 2 blocks, got " + page.getBlockCount();
- assert page.getBlock(0).asVector() instanceof DocVector : "Expected a DocVector, got " + page.getBlock(0).asVector();
- assert page.getBlock(1).asVector() instanceof DoubleVector : "Expected a DoubleVector, got " + page.getBlock(1).asVector();
+ assert page.getBlockCount() > scoreBlockPosition : "Expected to get a score block in position " + scoreBlockPosition;
+ assert page.getBlock(scoreBlockPosition).asVector() instanceof DoubleVector
+ : "Expected a DoubleVector as a score block, got " + page.getBlock(scoreBlockPosition).asVector();
Block[] blocks = new Block[page.getBlockCount()];
for (int i = 0; i < page.getBlockCount(); i++) {
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec
index 2cad34e324fda..7a0e854f63f90 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec
@@ -3,11 +3,11 @@
# top-n query at the shard level
knnSearch
-required_capability: knn_function_v3
+required_capability: knn_function_v4
// tag::knn-function[]
from colors metadata _score
-| where knn(rgb_vector, [0, 120, 0], 10)
+| where knn(rgb_vector, [0, 120, 0])
| sort _score desc, color asc
// end::knn-function[]
| keep color, rgb_vector
@@ -30,10 +30,10 @@ chartreuse | [127.0, 255.0, 0.0]
;
knnSearchWithSimilarityOption
-required_capability: knn_function_v3
+required_capability: knn_function_v4
from colors metadata _score
-| where knn(rgb_vector, [255,192,203], 140, {"similarity": 40})
+| where knn(rgb_vector, [255,192,203], {"similarity": 40})
| sort _score desc, color asc
| keep color, rgb_vector
;
@@ -46,13 +46,14 @@ wheat | [245.0, 222.0, 179.0]
;
knnHybridSearch
-required_capability: knn_function_v3
+required_capability: knn_function_v4
from colors metadata _score
-| where match(color, "blue") or knn(rgb_vector, [65,105,225], 10)
+| where match(color, "blue") or knn(rgb_vector, [65,105,225])
| where primary == true
| sort _score desc, color asc
| keep color, rgb_vector
+| limit 10
;
color:text | rgb_vector:dense_vector
@@ -68,10 +69,10 @@ yellow | [255.0, 255.0, 0.0]
;
knnWithPrefilter
-required_capability: knn_function_v3
+required_capability: knn_function_v4
from colors
-| where knn(rgb_vector, [120,180,0], 10) and (match(color, "olive") or match(color, "green"))
+| where knn(rgb_vector, [120,180,0]) and (match(color, "olive") or match(color, "green"))
| sort color asc
| keep color
;
@@ -82,10 +83,10 @@ olive
;
knnWithNegatedPrefilter
-required_capability: knn_function_v3
+required_capability: knn_function_v4
from colors metadata _score
-| where knn(rgb_vector, [128,128,0], 10) and not (match(color, "olive") or match(color, "chocolate"))
+| where knn(rgb_vector, [128,128,0]) and not (match(color, "olive") or match(color, "chocolate"))
| sort _score desc, color asc
| keep color, rgb_vector
| LIMIT 10
@@ -105,11 +106,11 @@ orange | [255.0, 165.0, 0.0]
;
knnAfterKeep
-required_capability: knn_function_v3
+required_capability: knn_function_v4
from colors metadata _score
| keep rgb_vector, color, _score
-| where knn(rgb_vector, [128,255,0], 140)
+| where knn(rgb_vector, [128,255,0])
| sort _score desc, color asc
| keep rgb_vector
| limit 5
@@ -124,11 +125,11 @@ rgb_vector:dense_vector
;
knnAfterDrop
-required_capability: knn_function_v3
+required_capability: knn_function_v4
from colors metadata _score
| drop primary
-| where knn(rgb_vector, [128,250,0], 140)
+| where knn(rgb_vector, [128,250,0])
| sort _score desc, color asc
| keep color, rgb_vector
| limit 5
@@ -143,11 +144,11 @@ lime | [0.0, 255.0, 0.0]
;
knnAfterEval
-required_capability: knn_function_v3
+required_capability: knn_function_v4
from colors metadata _score
| eval composed_name = locate(color, " ") > 0
-| where knn(rgb_vector, [128,128,0], 140)
+| where knn(rgb_vector, [128,128,0])
| sort _score desc, color asc
| keep color, composed_name
| limit 5
@@ -162,12 +163,13 @@ golden rod | true
;
knnWithConjunction
-required_capability: knn_function_v3
+required_capability: knn_function_v4
from colors metadata _score
-| where knn(rgb_vector, [255,255,238], 10) and hex_code like "#FFF*"
+| where knn(rgb_vector, [255,255,238]) and hex_code like "#FFF*"
| sort _score desc, color asc
| keep color, hex_code, rgb_vector
+| limit 10
;
color:text | hex_code:keyword | rgb_vector:dense_vector
@@ -181,10 +183,10 @@ yellow | #FFFF00 | [255.0, 255.0, 0.0]
;
knnWithDisjunctionAndFiltersConjunction
-required_capability: knn_function_v3
+required_capability: knn_function_v4
from colors metadata _score
-| where (knn(rgb_vector, [0,255,255], 140) or knn(rgb_vector, [128, 0, 255], 10)) and primary == true
+| where (knn(rgb_vector, [0,255,255]) or knn(rgb_vector, [128, 0, 255])) and primary == true
| keep color, rgb_vector, _score
| sort _score desc, color asc
| drop _score
@@ -204,10 +206,10 @@ yellow | [255.0, 255.0, 0.0]
;
knnWithNegationsAndFiltersConjunction
-required_capability: knn_function_v3
+required_capability: knn_function_v4
from colors metadata _score
-| where (knn(rgb_vector, [0,255,255], 140) and not(primary == true and match(color, "blue")))
+| where (knn(rgb_vector, [0,255,255]) and not(primary == true and match(color, "blue")))
| sort _score desc, color asc
| keep color, rgb_vector
| limit 10
@@ -227,11 +229,11 @@ azure | [240.0, 255.0, 255.0]
;
knnWithNonPushableConjunction
-required_capability: knn_function_v3
+required_capability: knn_function_v4
from colors metadata _score
| eval composed_name = locate(color, " ") > 0
-| where knn(rgb_vector, [128,128,0], 140) and composed_name == false
+| where knn(rgb_vector, [128,128,0], {"min_candidates": 100}) and composed_name == false
| sort _score desc, color asc
| keep color, composed_name
| limit 10
@@ -251,58 +253,88 @@ maroon | false
;
testKnnWithNonPushableDisjunctions
-required_capability: knn_function_v3
+required_capability: knn_function_v4
from colors metadata _score
-| where knn(rgb_vector, [128,128,0], 140, {"similarity": 30}) or length(color) > 10
+| where knn(rgb_vector, [128,128,0]) or length(color) > 10
| sort _score desc, color asc
-| keep color
+| keep color
+| limit 10
;
color:text
-olive
-aqua marine
-lemon chiffon
-papaya whip
+olive
+sienna
+chocolate
+peru
+golden rod
+brown
+firebrick
+chartreuse
+gray
+green
;
-testKnnWithNonPushableDisjunctionsOnComplexExpressions
-required_capability: knn_function_v3
+testKnnWithNonPushableDisjunctionsAndMinCandidates
+required_capability: knn_function_v4
from colors metadata _score
-| where (knn(rgb_vector, [128,128,0], 140, {"similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], 140, {"similarity": 60}) and primary == false)
+| where (knn(rgb_vector, [128,128,0], {"min_candidates": 2}) and length(color) > 10) or (knn(rgb_vector, [128,0,128], {"min_candidates": 2}) and primary == true)
| sort _score desc, color asc
| keep color, primary
;
color:text | primary:boolean
-olive | false
-purple | false
-indigo | false
-;
+gray | true
+green | true
+red | true
+black | true
+magenta | true
+yellow | true
+blue | true
+aqua marine | false
+papaya whip | false
+lemon chiffon | false
+white | true
+cyan | true
+;
+
+testKnnWithStats
+required_capability: knn_function_v4
-testKnnInStatsNonPushable
-required_capability: knn_function_v3
-
-from colors
-| where length(color) < 10
-| stats c = count(*) where knn(rgb_vector, [128,128,255], 140)
+from colors metadata _score
+| where knn(rgb_vector, [128,128,0])
+| sort _score desc, color asc
+| limit 15
+| stats c = count(*)
;
-c: long
-50
+c:long
+15
;
-testKnnInStatsWithGrouping
-required_capability: knn_function_v3
-required_capability: full_text_functions_in_stats_where
+testKnnWithRerank
+required_capability: knn_function_v4
+required_capability: rerank
-from colors
-| where length(color) < 10
-| stats c = count(*) where knn(rgb_vector, [128,128,255], 140) by primary
+from colors metadata _score
+| where knn(rgb_vector, [100,120,0])
+| sort _score desc, color asc
+| limit 10
+| rerank rerank_score = "deepest blue" ON color WITH { "inference_id" : "test_reranker" }
+| sort rerank_score desc, color asc
+| keep color
;
-c: long | primary: boolean
-41 | false
-9 | true
+color:text
+gray
+peru
+brown
+green
+olive
+maroon
+sienna
+chocolate
+firebrick
+golden rod
;
diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java
index d44a9b458b082..21ec240d9f8f4 100644
--- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java
+++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java
@@ -74,9 +74,10 @@ public void testKnnDefaults() {
var query = String.format(Locale.ROOT, """
FROM test METADATA _score
- | WHERE knn(vector, %s, 10)
+ | WHERE knn(vector, %s)
| KEEP id, _score, vector
| SORT _score DESC
+ | LIMIT 10
""", Arrays.toString(queryVector));
try (var resp = run(query)) {
@@ -113,9 +114,10 @@ public void testKnnOptions() {
var query = String.format(Locale.ROOT, """
FROM test METADATA _score
- | WHERE knn(vector, %s, 5)
+ | WHERE knn(vector, %s)
| KEEP id, _score, vector
| SORT _score DESC
+ | LIMIT 5
""", Arrays.toString(queryVector));
try (var resp = run(query)) {
@@ -131,12 +133,12 @@ public void testKnnNonPushedDown() {
float[] queryVector = new float[numDims];
Arrays.fill(queryVector, 0.0f);
- // TODO we need to decide what to do when / if user uses k for limit, as no more than k results will be returned from knn query
var query = String.format(Locale.ROOT, """
FROM test METADATA _score
- | WHERE knn(vector, %s, 5) OR id > 100
+ | WHERE knn(vector, %s) OR id > 100
| KEEP id, _score, vector
| SORT _score DESC
+ | LIMIT 5
""", Arrays.toString(queryVector));
try (var resp = run(query)) {
@@ -155,7 +157,7 @@ public void testKnnWithPrefilters() {
// We retrieve 5 from knn, but must be prefiltered with id > 5 or no result will be returned as it would be post-filtered
var query = String.format(Locale.ROOT, """
FROM test METADATA _score
- | WHERE knn(vector, %s, 5) AND id > 5 AND id <= 10
+ | WHERE knn(vector, %s) AND id > 5 AND id <= 10
| KEEP id, _score, vector
| SORT _score DESC
| LIMIT 5
@@ -178,7 +180,8 @@ public void testKnnWithLookupJoin() {
var query = String.format(Locale.ROOT, """
FROM test
| LOOKUP JOIN test_lookup ON id
- | WHERE KNN(lookup_vector, %s, 5) OR id > 100
+ | WHERE KNN(lookup_vector, %s) OR id > 100
+ | LIMIT 5
""", Arrays.toString(queryVector));
var error = expectThrows(VerificationException.class, () -> run(query));
@@ -193,7 +196,7 @@ public void testKnnWithLookupJoin() {
@Before
public void setup() throws IOException {
- assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+ assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
var indexName = "test";
var client = client().admin().indices();
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
index 19eac8bd9ad03..9a69e9c86fe10 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
@@ -1291,7 +1291,7 @@ public enum Cap {
/**
* Support knn function
*/
- KNN_FUNCTION_V3(Build.current().isSnapshot()),
+ KNN_FUNCTION_V4(Build.current().isSnapshot()),
/**
* Support for the LIKE operator with a list of wildcards.
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
index 9b794d9b9b7b5..f4d20dcafd1a0 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
@@ -505,7 +505,7 @@ private static FunctionDefinition[][] snapshotFunctions() {
def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time"),
def(Score.class, uni(Score::new), Score.NAME),
def(Term.class, bi(Term::new), "term"),
- def(Knn.class, quad(Knn::new), "knn"),
+ def(Knn.class, tri(Knn::new), "knn"),
def(ToGeohash.class, ToGeohash::new, "to_geohash"),
def(ToGeotile.class, ToGeotile::new, "to_geotile"),
def(ToGeohex.class, ToGeohex::new, "to_geohex"),
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java
index c273da317dec2..c9e23fdd29387 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java
@@ -384,18 +384,29 @@ public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvalua
ShardConfig[] shardConfigs = new ShardConfig[shardContexts.size()];
int i = 0;
for (EsPhysicalOperationProviders.ShardContext shardContext : shardContexts) {
- shardConfigs[i++] = new ShardConfig(shardContext.toQuery(queryBuilder()), shardContext.searcher());
+ shardConfigs[i++] = new ShardConfig(shardContext.toQuery(evaluatorQueryBuilder()), shardContext.searcher());
}
return new LuceneQueryExpressionEvaluator.Factory(shardConfigs);
}
+ /**
+ * Returns the query builder to be used when the function cannot be pushed down to Lucene, but uses a
+ * {@link org.elasticsearch.compute.lucene.LuceneQueryEvaluator} instead
+ *
+ * @return the query builder to be used in the {@link org.elasticsearch.compute.lucene.LuceneQueryEvaluator}
+ */
+ protected QueryBuilder evaluatorQueryBuilder() {
+ // Use the same query builder as for the translation by default
+ return queryBuilder();
+ }
+
@Override
public ScoreOperator.ExpressionScorer.Factory toScorer(ToScorer toScorer) {
List shardContexts = toScorer.shardContexts();
ShardConfig[] shardConfigs = new ShardConfig[shardContexts.size()];
int i = 0;
for (EsPhysicalOperationProviders.ShardContext shardContext : shardContexts) {
- shardConfigs[i++] = new ShardConfig(shardContext.toQuery(queryBuilder()), shardContext.searcher());
+ shardConfigs[i++] = new ShardConfig(shardContext.toQuery(evaluatorQueryBuilder()), shardContext.searcher());
}
return new LuceneQueryScoreEvaluator.Factory(shardConfigs);
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java
index 0b64fb43909df..9add14da034b5 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java
@@ -7,15 +7,17 @@
package org.elasticsearch.xpack.esql.expression.function.vector;
-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
+import org.elasticsearch.search.vectors.VectorData;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware;
+import org.elasticsearch.xpack.esql.capabilities.PostOptimizationVerificationAware;
import org.elasticsearch.xpack.esql.capabilities.TranslationAware;
+import org.elasticsearch.xpack.esql.common.Failure;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
@@ -54,14 +56,11 @@
import static java.util.Map.entry;
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
-import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD;
-import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FOURTH;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD;
-import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
@@ -70,20 +69,26 @@
import static org.elasticsearch.xpack.esql.expression.Foldables.TypeResolutionValidator.forPreOptimizationValidation;
import static org.elasticsearch.xpack.esql.expression.Foldables.resolveTypeQuery;
-public class Knn extends FullTextFunction implements OptionalArgument, VectorFunction, PostAnalysisPlanVerificationAware {
- private final Logger log = LogManager.getLogger(getClass());
+public class Knn extends FullTextFunction
+ implements
+ OptionalArgument,
+ VectorFunction,
+ PostAnalysisPlanVerificationAware,
+ PostOptimizationVerificationAware {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom);
private final Expression field;
// k is not serialized as it's already included in the query builder on the rewrite step before being sent to data nodes
- private final transient Expression k;
+ private final transient Integer k;
private final Expression options;
// Expressions to be used as prefilters in knn query
private final List filterExpressions;
+ public static final String MIN_CANDIDATES_OPTION = "min_candidates";
+
public static final Map ALLOWED_OPTIONS = Map.ofEntries(
- entry(NUM_CANDS_FIELD.getPreferredName(), INTEGER),
+ entry(MIN_CANDIDATES_OPTION, INTEGER),
entry(VECTOR_SIMILARITY_FIELD.getPreferredName(), FLOAT),
entry(BOOST_FIELD.getPreferredName(), FLOAT),
entry(KnnQuery.RESCORE_OVERSAMPLE_FIELD, FLOAT)
@@ -105,13 +110,6 @@ public Knn(
type = { "dense_vector" },
description = "Vector value to find top nearest neighbours for."
) Expression query,
- @Param(
- name = "k",
- type = { "integer" },
- description = "The number of nearest neighbors to return from each shard. "
- + "Elasticsearch collects k results from each shard, then merges them to find the global top results. "
- + "This value must be less than or equal to num_candidates."
- ) Expression k,
@MapParam(
name = "options",
params = {
@@ -123,12 +121,13 @@ public Knn(
+ "Defaults to 1.0."
),
@MapParam.MapParamEntry(
- name = "num_candidates",
+ name = "min_candidates",
type = "integer",
valueHint = { "10" },
- description = "The number of nearest neighbor candidates to consider per shard while doing knn search. "
- + "Cannot exceed 10,000. Increasing num_candidates tends to improve the accuracy of the final results. "
- + "Defaults to 1.5 * k"
+ description = "The minimum number of nearest neighbor candidates to consider per shard while doing knn search. "
+ + " KNN may use a higher number of candidates in case the query can't use a approximate results. "
+ + "Cannot exceed 10,000. Increasing min_candidates tends to improve the accuracy of the final results. "
+ + "Defaults to 1.5 * LIMIT used for the query."
),
@MapParam.MapParamEntry(
name = "similarity",
@@ -150,32 +149,29 @@ public Knn(
optional = true
) Expression options
) {
- this(source, field, query, k, options, null, List.of());
+ this(source, field, query, options, null, null, List.of());
}
public Knn(
Source source,
Expression field,
Expression query,
- Expression k,
Expression options,
+ Integer k,
QueryBuilder queryBuilder,
List filterExpressions
) {
- super(source, query, expressionList(field, query, k, options), queryBuilder);
+ super(source, query, expressionList(field, query, options), queryBuilder);
this.field = field;
this.k = k;
this.options = options;
this.filterExpressions = filterExpressions;
}
- private static List expressionList(Expression field, Expression query, Expression k, Expression options) {
+ private static List expressionList(Expression field, Expression query, Expression options) {
List result = new ArrayList<>();
result.add(field);
result.add(query);
- if (k != null) {
- result.add(k);
- }
if (options != null) {
result.add(options);
}
@@ -186,7 +182,7 @@ public Expression field() {
return field;
}
- public Expression k() {
+ public Integer k() {
return k;
}
@@ -205,7 +201,7 @@ public DataType dataType() {
@Override
protected TypeResolution resolveParams() {
- return resolveField().and(resolveQuery()).and(resolveK()).and(Options.resolve(options(), source(), FOURTH, ALLOWED_OPTIONS));
+ return resolveField().and(resolveQuery()).and(Options.resolve(options(), source(), THIRD, ALLOWED_OPTIONS));
}
private TypeResolution resolveField() {
@@ -225,14 +221,9 @@ private TypeResolution resolveQuery() {
return TypeResolution.TYPE_RESOLVED;
}
- private TypeResolution resolveK() {
- if (k == null) {
- // Function has already been rewritten and included in QueryBuilder - otherwise parsing would have failed
- return TypeResolution.TYPE_RESOLVED;
- }
-
- return isType(k(), dt -> dt == INTEGER, sourceText(), THIRD, "integer").and(isFoldable(k(), sourceText(), THIRD))
- .and(isNotNull(k(), sourceText(), THIRD));
+ public Knn replaceK(Integer k) {
+ Check.notNull(k, "k must not be null");
+ return new Knn(source(), field(), query(), options(), k, queryBuilder(), filterExpressions());
}
public List queryAsObject() {
@@ -246,16 +237,9 @@ public List queryAsObject() {
throw new EsqlIllegalArgumentException(format(null, "Query value must be a list of numbers in [{}], found [{}]", source(), query));
}
- int getKIntValue() {
- if (k() instanceof Literal literal) {
- return (int) (Number) literal.value();
- }
- throw new EsqlIllegalArgumentException(format(null, "K value must be a constant integer in [{}], found [{}]", source(), k()));
- }
-
@Override
public Expression replaceQueryBuilder(QueryBuilder queryBuilder) {
- return new Knn(source(), field(), query(), k(), options(), queryBuilder, filterExpressions());
+ return new Knn(source(), field(), query(), options(), k(), queryBuilder, filterExpressions());
}
@Override
@@ -271,37 +255,39 @@ public Translatable translatable(LucenePushdownPredicates pushdownPredicates) {
@Override
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
+ assert k() != null : "Knn function must have a k value set before translation";
var fieldAttribute = Match.fieldAsFieldAttribute(field());
Check.notNull(fieldAttribute, "Knn must have a field attribute as the first argument");
String fieldName = getNameFromFieldAttribute(fieldAttribute);
- List queryFolded = queryAsObject();
- float[] queryAsFloats = new float[queryFolded.size()];
- for (int i = 0; i < queryFolded.size(); i++) {
- queryAsFloats[i] = queryFolded.get(i).floatValue();
- }
- int kValue = getKIntValue();
-
- Map opts = queryOptions();
- opts.put(K_FIELD.getPreferredName(), kValue);
+ float[] queryAsFloats = queryAsFloats();
List filterQueries = new ArrayList<>();
for (Expression filterExpression : filterExpressions()) {
if (filterExpression instanceof TranslationAware translationAware) {
// We can only translate filter expressions that are translatable. In case any is not translatable,
- // Knn won't be pushed down as it will not be translatable so it's safe not to translate all filters and check them
- // when creating an evaluator for the non-pushed down query
+ // Knn won't be pushed down so it's safe not to translate all filters and check them when creating an evaluator
+ // for the non-pushed down query
if (translationAware.translatable(pushdownPredicates) == Translatable.YES) {
filterQueries.add(handler.asQuery(pushdownPredicates, filterExpression).toQueryBuilder());
}
}
}
- return new KnnQuery(source(), fieldName, queryAsFloats, opts, filterQueries);
+ return new KnnQuery(source(), fieldName, queryAsFloats, k(), queryOptions(), filterQueries);
+ }
+
+ private float[] queryAsFloats() {
+ List queryFolded = queryAsObject();
+ float[] queryAsFloats = new float[queryFolded.size()];
+ for (int i = 0; i < queryFolded.size(); i++) {
+ queryAsFloats[i] = queryFolded.get(i).floatValue();
+ }
+ return queryAsFloats;
}
public Expression withFilters(List filterExpressions) {
- return new Knn(source(), field(), query(), k(), options(), queryBuilder(), filterExpressions);
+ return new Knn(source(), field(), query(), options(), k(), queryBuilder(), filterExpressions);
}
private Map queryOptions() throws InvalidArgumentException {
@@ -312,6 +298,17 @@ private Map queryOptions() throws InvalidArgumentException {
return options;
}
+ protected QueryBuilder evaluatorQueryBuilder() {
+ // Either we couldn't push down due to non-pushable filters, or because it's part of a disjuncion.
+ // Uses a nearest neighbors exact query instead of an approximate one
+ var fieldAttribute = Match.fieldAsFieldAttribute(field());
+ Check.notNull(fieldAttribute, "Knn must have a field attribute as the first argument");
+ String fieldName = getNameFromFieldAttribute(fieldAttribute);
+ Map opts = queryOptions();
+
+ return new ExactKnnQueryBuilder(VectorData.fromFloats(queryAsFloats()), fieldName, (Float) opts.get(VECTOR_SIMILARITY_FIELD));
+ }
+
@Override
public BiConsumer postAnalysisPlanVerification() {
return (plan, failures) -> {
@@ -320,14 +317,24 @@ public BiConsumer postAnalysisPlanVerification() {
};
}
+ @Override
+ public void postOptimizationVerification(Failures failures) {
+ // Check that a k has been set
+ if (k() == null) {
+ failures.add(
+ Failure.fail(this, "Knn function must be used with a LIMIT clause after it to set the number of nearest neighbors to find")
+ );
+ }
+ }
+
@Override
public Expression replaceChildren(List newChildren) {
return new Knn(
source(),
newChildren.get(0),
newChildren.get(1),
- newChildren.get(2),
- newChildren.size() > 3 ? newChildren.get(3) : null,
+ newChildren.size() > 2 ? newChildren.get(2) : null,
+ k(),
queryBuilder(),
filterExpressions()
);
@@ -335,7 +342,7 @@ public Expression replaceChildren(List newChildren) {
@Override
protected NodeInfo extends Expression> info() {
- return NodeInfo.create(this, Knn::new, field(), query(), k(), options(), queryBuilder(), filterExpressions());
+ return NodeInfo.create(this, Knn::new, field(), query(), options(), k(), queryBuilder(), filterExpressions());
}
@Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java
index f4353c28476d2..ab41201ceb328 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java
@@ -27,7 +27,7 @@ private VectorWritables() {
public static List getNamedWritables() {
List entries = new ArrayList<>();
- if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
+ if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) {
entries.add(Knn.ENTRY);
}
if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java
index dac533f872022..6f550524c5ca5 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java
@@ -44,6 +44,7 @@
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownInferencePlan;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownJoinPastProject;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownRegexExtract;
+import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushLimitToKnn;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.RemoveStatsOverride;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceAggregateAggExpressionWithEval;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceAggregateNestedExpressionWithEval;
@@ -192,6 +193,7 @@ protected static Batch operators(boolean local) {
new PruneColumns(),
new PruneLiteralsInOrderBy(),
new PushDownAndCombineLimits(),
+ new PushLimitToKnn(),
new PushDownAndCombineFilters(),
new PushDownConjunctionsToKnnPrefilters(),
new PushDownAndCombineSample(),
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushLimitToKnn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushLimitToKnn.java
new file mode 100644
index 0000000000000..a8503c300bfbc
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushLimitToKnn.java
@@ -0,0 +1,69 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.optimizer.rules.logical;
+
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.util.Holder;
+import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
+import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
+import org.elasticsearch.xpack.esql.plan.logical.Filter;
+import org.elasticsearch.xpack.esql.plan.logical.Limit;
+import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
+import org.elasticsearch.xpack.esql.plan.logical.TopN;
+import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
+
+/**
+ * Traverses the logical plan and pushes down the limit to the KNN function(s) in filter expressions, so KNN can use
+ * it to set k if not specified.
+ */
+public class PushLimitToKnn extends OptimizerRules.ParameterizedOptimizerRule {
+
+ public PushLimitToKnn() {
+ super(OptimizerRules.TransformDirection.DOWN);
+ }
+
+ @Override
+ public LogicalPlan rule(Limit limit, LogicalOptimizerContext ctx) {
+ Holder breakerReached = new Holder<>(false);
+ Holder firstLimit = new Holder<>(false);
+ return limit.transformDown(plan -> {
+ if (breakerReached.get()) {
+ // We reached a breaker and don't want to continue processing
+ return plan;
+ }
+ if (plan instanceof Filter filter) {
+ Expression limitAppliedExpression = limitFilterExpressions(filter.condition(), limit, ctx);
+ if (limitAppliedExpression.equals(filter.condition()) == false) {
+ return filter.with(limitAppliedExpression);
+ }
+ } else if (plan instanceof Limit) {
+ // Break if it's not the initial limit
+ breakerReached.set(firstLimit.get());
+ firstLimit.set(true);
+ } else if (plan instanceof TopN || plan instanceof Rerank || plan instanceof Aggregate) {
+ breakerReached.set(true);
+ }
+
+ return plan;
+ });
+ }
+
+ /**
+ * Applies a limit to the filter expressions of a condition. Some filter expressions, such as KNN function,
+ * can be optimized by applying the limit directly to them.
+ */
+ private Expression limitFilterExpressions(Expression condition, Limit limit, LogicalOptimizerContext ctx) {
+ return condition.transformDown(exp -> {
+ if (exp instanceof Knn knn) {
+ return knn.replaceK((Integer) limit.limit().fold(ctx.foldCtx()));
+ }
+ return exp;
+ });
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java
index b218b897121df..fedddfa8bcaa4 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java
@@ -12,6 +12,7 @@
import org.elasticsearch.search.vectors.RescoreVectorBuilder;
import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
import java.util.ArrayList;
import java.util.Arrays;
@@ -20,8 +21,6 @@
import java.util.Objects;
import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
-import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD;
-import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD;
public class KnnQuery extends Query {
@@ -32,9 +31,12 @@ public class KnnQuery extends Query {
private final List filterQueries;
public static final String RESCORE_OVERSAMPLE_FIELD = "rescore_oversample";
+ private final Integer k;
- public KnnQuery(Source source, String field, float[] query, Map options, List filterQueries) {
+ public KnnQuery(Source source, String field, float[] query, Integer k, Map options, List filterQueries) {
super(source);
+ assert k != null && k > 0 : "k must be a positive integer, but was: " + k;
+ this.k = k;
assert options != null;
this.field = field;
this.query = query;
@@ -44,16 +46,24 @@ public KnnQuery(Source source, String field, float[] query, Map
@Override
protected QueryBuilder asBuilder() {
- Integer k = (Integer) options.get(K_FIELD.getPreferredName());
- Integer numCands = (Integer) options.get(NUM_CANDS_FIELD.getPreferredName());
RescoreVectorBuilder rescoreVectorBuilder = null;
Float oversample = (Float) options.get(RESCORE_OVERSAMPLE_FIELD);
if (oversample != null) {
rescoreVectorBuilder = new RescoreVectorBuilder(oversample);
}
Float vectorSimilarity = (Float) options.get(VECTOR_SIMILARITY_FIELD.getPreferredName());
-
- KnnVectorQueryBuilder queryBuilder = new KnnVectorQueryBuilder(field, query, k, numCands, rescoreVectorBuilder, vectorSimilarity);
+ Integer minCandidates = (Integer) options.get(Knn.MIN_CANDIDATES_OPTION);
+ int adjustedK = Math.max(k, minCandidates == null ? 0 : minCandidates);
+ minCandidates = minCandidates == null ? null : Math.max(minCandidates, adjustedK);
+
+ KnnVectorQueryBuilder queryBuilder = new KnnVectorQueryBuilder(
+ field,
+ query,
+ adjustedK,
+ minCandidates,
+ rescoreVectorBuilder,
+ vectorSimilarity
+ );
for (QueryBuilder filter : filterQueries) {
queryBuilder.addFilterQuery(filter);
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
index 869a851a1fb34..97429ea091053 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
@@ -305,7 +305,7 @@ public final void test() throws Throwable {
);
assumeFalse(
"can't use KNN function in csv tests",
- testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V3.capabilityName())
+ testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V4.capabilityName())
);
assumeFalse(
"lookup join disabled for csv tests",
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
index f26c14db41604..95a7204b5c71f 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
@@ -2349,20 +2349,19 @@ public void testImplicitCasting() {
public void testDenseVectorImplicitCastingKnn() {
assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
- assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+ assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
checkDenseVectorCastingKnn("float_vector");
}
private static void checkDenseVectorCastingKnn(String fieldName) {
var plan = analyze(String.format(Locale.ROOT, """
- from test | where knn(%s, [0.342, 0.164, 0.234], 10)
+ from test | where knn(%s, [0.342, 0.164, 0.234])
""", fieldName), "mapping-dense_vector.json");
var limit = as(plan, Limit.class);
var filter = as(limit.child(), Filter.class);
var knn = as(filter.condition(), Knn.class);
- var field = knn.field();
var queryVector = as(knn.query(), Literal.class);
assertEquals(DataType.DENSE_VECTOR, queryVector.dataType());
assertThat(queryVector.value(), equalTo(List.of(0.342f, 0.164f, 0.234f)));
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
index 4e0814d6cc6f5..815ae4bae6b89 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
@@ -1268,8 +1268,8 @@ public void testFieldBasedFullTextFunctions() throws Exception {
checkFieldBasedWithNonIndexedColumn("Term", "term(text, \"cat\")", "function");
checkFieldBasedFunctionNotAllowedAfterCommands("Term", "function", "term(title, \"Meditation\")");
}
- if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
- checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3], 10)");
+ if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) {
+ checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3])");
}
}
@@ -1401,8 +1401,8 @@ public void testFullTextFunctionsOnlyAllowedInWhere() throws Exception {
if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
checkFullTextFunctionsOnlyAllowedInWhere("MultiMatch", "multi_match(\"Meditation\", title, body)", "function");
}
- if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
- checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2], 10)", "function");
+ if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) {
+ checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2])", "function");
}
}
@@ -1456,8 +1456,8 @@ public void testFullTextFunctionsDisjunctions() {
if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
checkWithFullTextFunctionsDisjunctions("term(title, \"Meditation\")");
}
- if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
- checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3], 10)");
+ if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) {
+ checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3])");
}
}
@@ -1521,8 +1521,8 @@ public void testFullTextFunctionsWithNonBooleanFunctions() {
if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
checkFullTextFunctionsWithNonBooleanFunctions("Term", "term(title, \"Meditation\")", "function");
}
- if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
- checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3], 10)", "function");
+ if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) {
+ checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3])", "function");
}
}
@@ -1592,7 +1592,7 @@ public void testFullTextFunctionsTargetsExistingField() throws Exception {
if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
testFullTextFunctionTargetsExistingField("term(fist_name, \"Meditation\")");
}
- if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
+ if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) {
testFullTextFunctionTargetsExistingField("knn(vector, [0, 1, 2], 10)");
}
}
@@ -2189,8 +2189,8 @@ public void testFullTextFunctionOptions() {
if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
checkOptionDataTypes(MultiMatch.OPTIONS, "FROM test | WHERE MULTI_MATCH(\"Jean\", title, body, {\"%s\": %s})");
}
- if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
- checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], 10, {\"%s\": %s})");
+ if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) {
+ checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], {\"%s\": %s})");
}
}
@@ -2282,10 +2282,9 @@ public void testFullTextFunctionsNullArgs() throws Exception {
checkFullTextFunctionNullArgs("term(null, \"query\")", "first");
checkFullTextFunctionNullArgs("term(title, null)", "second");
}
- if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
- checkFullTextFunctionNullArgs("knn(null, [0, 1, 2], 10)", "first");
- checkFullTextFunctionNullArgs("knn(vector, null, 10)", "second");
- checkFullTextFunctionNullArgs("knn(vector, [0, 1, 2], null)", "third");
+ if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) {
+ checkFullTextFunctionNullArgs("knn(null, [0, 1, 2])", "first");
+ checkFullTextFunctionNullArgs("knn(vector, null)", "second");
}
}
@@ -2314,8 +2313,8 @@ public void testFullTextFunctionsInStats() {
if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
checkFullTextFunctionsInStats("multi_match(\"Meditation\", title, body)");
}
- if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
- checkFullTextFunctionsInStats("knn(vector, [0, 1, 2], 10)");
+ if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) {
+ checkFullTextFunctionsInStats("knn(vector, [0, 1, 2])");
}
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java
index 002c519b001f8..f87e278bd4238 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java
@@ -52,7 +52,7 @@ public static Iterable