Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/138372.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 138372
summary: ES|QL - KNN function options support k and visit_percentage parameters
area: "ES|QL"
type: enhancement
issues: []

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,11 @@ azure | [240.0, 255.0, 255.0]

knnWithNonPushableConjunction
required_capability: knn_function_v5
required_capability: knn_function_options_k_visit_percentage

from colors metadata _score
| eval composed_name = locate(color, " ") > 0
| where knn(rgb_vector, [128,128,0], {"min_candidates": 100}) and composed_name == false
| where knn(rgb_vector, [128,128,0], {"k": 100}) and composed_name == false
| sort _score desc, color asc
| keep color, composed_name
| limit 10
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,16 @@ public void testKnnDefaults() {
}
}

public void testKnnOptions() {
public void testKnnKOverridesLimit() {
float[] queryVector = new float[numDims];
Arrays.fill(queryVector, 0.0f);

var query = String.format(Locale.ROOT, """
FROM test METADATA _score
| WHERE knn(vector, %s)
| WHERE knn(vector, %s, {"k": 5, "min_candidates": 20})
| KEEP id, _score, vector
| SORT _score DESC
| LIMIT 5
| LIMIT 10
""", Arrays.toString(queryVector));

try (var resp = run(query)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1679,6 +1679,10 @@ public enum Cap {
*/
PROMQL_V0(Build.current().isSnapshot()),

/**
* KNN function adds support for k and visit_percentage options
*/
KNN_FUNCTION_OPTIONS_K_VISIT_PERCENTAGE,
// Last capability should still have a comma for fewer merge conflicts when adding new ones :)
// This comment prevents the semicolon from being on the previous capability when Spotless formats the file.
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@
import static java.util.Map.entry;
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD;
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD;
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VISIT_PERCENTAGE_FIELD;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FOURTH;
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT;
Expand All @@ -64,16 +66,18 @@ public class Knn extends SingleFieldFullTextFunction implements OptionalArgument

public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom);

// k is not serialized as it's already included in the query builder on the rewrite step before being sent to data nodes
private final transient Integer k;
// Implicit k is not serialized as it's already included in the query builder on the rewrite step before being sent to data nodes
private final transient Integer implicitK;
// Expressions to be used as prefilters in knn query
private final List<Expression> filterExpressions;

public static final String MIN_CANDIDATES_OPTION = "min_candidates";

public static final Map<String, DataType> ALLOWED_OPTIONS = Map.ofEntries(
entry(K_FIELD.getPreferredName(), INTEGER),
entry(MIN_CANDIDATES_OPTION, INTEGER),
entry(VECTOR_SIMILARITY_FIELD.getPreferredName(), FLOAT),
entry(VISIT_PERCENTAGE_FIELD.getPreferredName(), FLOAT),
entry(BOOST_FIELD.getPreferredName(), FLOAT),
entry(KnnQuery.RESCORE_OVERSAMPLE_FIELD, FLOAT)
);
Expand Down Expand Up @@ -102,6 +106,15 @@ public Knn(
@MapParam(
name = "options",
params = {
@MapParam.MapParamEntry(
name = "k",
type = "integer",
valueHint = { "10" },
description = "The number of nearest neighbors to return from each shard. "
+ "Elasticsearch collects k results from each shard, then merges them to find the global top results. "
+ "This value must be less than or equal to num_candidates. "
+ "This value is automatically set with any LIMIT applied to the function."
),
@MapParam.MapParamEntry(
name = "boost",
type = "float",
Expand All @@ -116,7 +129,17 @@ public Knn(
description = "The minimum number of nearest neighbor candidates to consider per shard while doing knn search. "
+ " KNN may use a higher number of candidates in case the query can't use a approximate results. "
+ "Cannot exceed 10,000. Increasing min_candidates tends to improve the accuracy of the final results. "
+ "Defaults to 1.5 * LIMIT used for the query."
+ "Defaults to 1.5 * k (or LIMIT) used for the query."
),
@MapParam.MapParamEntry(
name = "visit_percentage",
type = "float",
valueHint = { "10" },
description = "The percentage of vectors to explore per shard while doing knn search with bbq_disk. "
+ "Must be between 0 and 100. 0 will default to using num_candidates for calculating the percent visited. "
+ "Increasing visit_percentage tends to improve the accuracy of the final results. "
+ "If visit_percentage is set for bbq_disk, num_candidates is ignored. "
+ "Defaults to ~1% per shard for every 1 million vectors"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about this wording and where this default is coming from, @benwtrent can you verify?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it doesn't default to anything. We shouldn't default it to anything. We dynamically set it according to num_candidates, which is dynamically determined via k (if not provided), which is required for knn search to work.

),
@MapParam.MapParamEntry(
name = "similarity",
Expand Down Expand Up @@ -146,12 +169,12 @@ public Knn(
Expression field,
Expression query,
Expression options,
Integer k,
Integer implicitK,
QueryBuilder queryBuilder,
List<Expression> filterExpressions
) {
super(source, field, query, options, expressionList(field, query, options), queryBuilder);
this.k = k;
this.implicitK = implicitK;
this.filterExpressions = filterExpressions;
}

Expand All @@ -165,15 +188,15 @@ private static List<Expression> expressionList(Expression field, Expression quer
return result;
}

public Integer k() {
return k;
public Integer implicitK() {
return implicitK;
}

public List<Expression> filterExpressions() {
return filterExpressions;
}

public Knn replaceK(Integer k) {
public Knn withImplicitK(Integer k) {
Check.notNull(k, "k must not be null");
return new Knn(source(), field(), query(), options(), k, queryBuilder(), filterExpressions());
}
Expand All @@ -191,7 +214,7 @@ public List<Number> queryAsObject() {

@Override
public Expression replaceQueryBuilder(QueryBuilder queryBuilder) {
return new Knn(source(), field(), query(), options(), k(), queryBuilder, filterExpressions());
return new Knn(source(), field(), query(), options(), implicitK(), queryBuilder, filterExpressions());
}

@Override
Expand All @@ -207,7 +230,7 @@ public Translatable translatable(LucenePushdownPredicates pushdownPredicates) {

@Override
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
assert k() != null : "Knn function must have a k value set before translation";
assert implicitK() != null : "Knn function must have a k value set before translation";
var fieldAttribute = fieldAsFieldAttribute(field());

Check.notNull(fieldAttribute, "Knn must have a field attribute as the first argument");
Expand All @@ -226,7 +249,10 @@ protected Query translate(LucenePushdownPredicates pushdownPredicates, Translato
}
}

return new KnnQuery(source(), fieldName, queryAsFloats, k(), queryOptions(), filterQueries);
Map<String, Object> options = queryOptions();
Integer explicitK = (Integer) options.get(K_FIELD.getPreferredName());

return new KnnQuery(source(), fieldName, queryAsFloats, explicitK != null ? explicitK : implicitK(), options, filterQueries);
}

private float[] queryAsFloats() {
Expand All @@ -239,7 +265,7 @@ private float[] queryAsFloats() {
}

public Expression withFilters(List<Expression> filterExpressions) {
return new Knn(source(), field(), query(), options(), k(), queryBuilder(), filterExpressions);
return new Knn(source(), field(), query(), options(), implicitK(), queryBuilder(), filterExpressions);
}

private Map<String, Object> queryOptions() throws InvalidArgumentException {
Expand All @@ -264,7 +290,7 @@ protected QueryBuilder evaluatorQueryBuilder() {
@Override
public void postOptimizationVerification(Failures failures) {
// Check that a k has been set
if (k() == null) {
if (implicitK() == null) {
failures.add(
Failure.fail(this, "Knn function must be used with a LIMIT clause after it to set the number of nearest neighbors to find")
);
Expand All @@ -278,15 +304,15 @@ public Expression replaceChildren(List<Expression> newChildren) {
newChildren.get(0),
newChildren.get(1),
newChildren.size() > 2 ? newChildren.get(2) : null,
k(),
implicitK(),
queryBuilder(),
filterExpressions()
);
}

@Override
protected NodeInfo<? extends Expression> info() {
return NodeInfo.create(this, Knn::new, field(), query(), options(), k(), queryBuilder(), filterExpressions());
return NodeInfo.create(this, Knn::new, field(), query(), options(), implicitK(), queryBuilder(), filterExpressions());
}

@Override
Expand Down Expand Up @@ -334,12 +360,14 @@ public boolean equals(Object o) {
// ignore options when comparing two Knn functions
if (o == null || getClass() != o.getClass()) return false;
Knn knn = (Knn) o;
return super.equals(knn) && Objects.equals(k(), knn.k()) && Objects.equals(filterExpressions(), knn.filterExpressions());
return super.equals(knn)
&& Objects.equals(implicitK(), knn.implicitK())
&& Objects.equals(filterExpressions(), knn.filterExpressions());
}

@Override
public int hashCode() {
return Objects.hash(field(), query(), queryBuilder(), k(), filterExpressions());
return Objects.hash(field(), query(), queryBuilder(), implicitK(), filterExpressions());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public LogicalPlan rule(Limit limit, LogicalOptimizerContext ctx) {
private Expression limitFilterExpressions(Expression condition, Limit limit, LogicalOptimizerContext ctx) {
return condition.transformDown(exp -> {
if (exp instanceof Knn knn) {
return knn.replaceK((Integer) limit.limit().fold(ctx.foldCtx()));
return knn.withImplicitK((Integer) limit.limit().fold(ctx.foldCtx()));
}
return exp;
});
Expand Down
Loading
Loading