diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/fuse/LinearScoreEvalOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/fuse/LinearScoreEvalOperator.java index 300668163c39e..16109fca0a939 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/fuse/LinearScoreEvalOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/fuse/LinearScoreEvalOperator.java @@ -13,12 +13,13 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockUtils; import org.elasticsearch.compute.data.BytesRefBlock; -import org.elasticsearch.compute.data.DoubleVector; -import org.elasticsearch.compute.data.DoubleVectorBlock; +import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.compute.operator.Warnings; import org.elasticsearch.core.Releasables; import org.elasticsearch.core.TimeValue; import org.elasticsearch.xcontent.XContentBuilder; @@ -28,6 +29,7 @@ import java.util.Collection; import java.util.Deque; import java.util.HashMap; +import java.util.List; import java.util.Map; /** @@ -40,11 +42,26 @@ * */ public class LinearScoreEvalOperator implements Operator { - public record Factory(int discriminatorPosition, int scorePosition, LinearConfig linearConfig) implements OperatorFactory { + public record Factory( + int discriminatorPosition, + int scorePosition, + LinearConfig linearConfig, + String sourceText, + int sourceLine, + int sourceColumn + ) implements OperatorFactory { @Override public Operator get(DriverContext driverContext) { - return new LinearScoreEvalOperator(discriminatorPosition, scorePosition, linearConfig); + return new LinearScoreEvalOperator( + driverContext, + discriminatorPosition, + scorePosition, + linearConfig, + sourceText, + sourceLine, + sourceColumn + ); } @Override @@ -74,11 +91,30 @@ public String describe() { private long rowsReceived = 0; private long rowsEmitted = 0; - public LinearScoreEvalOperator(int discriminatorPosition, int scorePosition, LinearConfig config) { + private final String sourceText; + private final int sourceLine; + private final int sourceColumn; + private Warnings warnings; + private final DriverContext driverContext; + + public LinearScoreEvalOperator( + DriverContext driverContext, + int discriminatorPosition, + int scorePosition, + LinearConfig config, + String sourceText, + int sourceLine, + int sourceColumn + ) { this.scorePosition = scorePosition; this.discriminatorPosition = discriminatorPosition; this.config = config; this.normalizer = createNormalizer(config.normalizer()); + this.driverContext = driverContext; + + this.sourceText = sourceText; + this.sourceLine = sourceLine; + this.sourceColumn = sourceColumn; finished = false; inputPages = new ArrayDeque<>(); @@ -123,25 +159,54 @@ private void createOutputPages() { private void processInputPage(Page inputPage) { BytesRefBlock discriminatorBlock = inputPage.getBlock(discriminatorPosition); - DoubleVectorBlock initialScoreBlock = inputPage.getBlock(scorePosition); + DoubleBlock initialScoreBlock = inputPage.getBlock(scorePosition); Page newPage = null; Block scoreBlock = null; - DoubleVector.Builder scores = null; + DoubleBlock.Builder 
scores = null; try { - scores = discriminatorBlock.blockFactory().newDoubleVectorBuilder(discriminatorBlock.getPositionCount()); + scores = discriminatorBlock.blockFactory().newDoubleBlockBuilder(discriminatorBlock.getPositionCount()); for (int i = 0; i < inputPage.getPositionCount(); i++) { - String discriminator = discriminatorBlock.getBytesRef(i, new BytesRef()).utf8ToString(); + Object discriminatorValue = BlockUtils.toJavaObject(discriminatorBlock, i); + + if (discriminatorValue == null) { + warnings().registerException(new IllegalArgumentException("group column has null values; assigning null scores")); + scores.appendNull(); + continue; + } else if (discriminatorValue instanceof List) { + warnings().registerException( + new IllegalArgumentException("group column contains multivalued entries; assigning null scores") + ); + scores.appendNull(); + continue; + } + String discriminator = ((BytesRef) discriminatorValue).utf8ToString(); var weight = config.weights().get(discriminator) == null ? 1.0 : config.weights().get(discriminator); - double score = initialScoreBlock.getDouble(i); + + Object scoreValue = BlockUtils.toJavaObject(initialScoreBlock, i); + if (scoreValue == null) { + warnings().registerException(new IllegalArgumentException("score column has null values; assigning null scores")); + scores.appendNull(); + continue; + } else if (scoreValue instanceof List) { + warnings().registerException( + new IllegalArgumentException("score column contains multivalued entries; assigning null scores") + ); + scores.appendNull(); + continue; + } + + double score = (double) scoreValue; + scores.appendDouble(weight * normalizer.normalize(score, discriminator)); } - scoreBlock = scores.build().asBlock(); + scoreBlock = scores.build(); newPage = inputPage.appendBlock(scoreBlock); int[] projections = new int[newPage.getBlockCount() - 1]; @@ -270,23 +335,43 @@ private Normalizer createNormalizer(LinearConfig.Normalizer normalizer) { }; } - private interface Normalizer { - double normalize(double score, String discriminator); + private abstract static class Normalizer { + abstract double normalize(double score, String discriminator); - void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition); + abstract void preprocess(double score, String discriminator); + + void finalizePreprocess() {} + + void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition) { + for (Page inputPage : inputPages) { + DoubleBlock scoreBlock = inputPage.getBlock(scorePosition); + BytesRefBlock discriminatorBlock = inputPage.getBlock(discriminatorPosition); + + for (int i = 0; i < inputPage.getPositionCount(); i++) { + Object scoreValue = BlockUtils.toJavaObject(scoreBlock, i); + Object discriminatorValue = BlockUtils.toJavaObject(discriminatorBlock, i); + + if (scoreValue instanceof Double score && discriminatorValue instanceof BytesRef discriminator) { + preprocess(score, discriminator.utf8ToString()); + } + } + } + + finalizePreprocess(); + } } - private class NoneNormalizer implements Normalizer { + private static class NoneNormalizer extends Normalizer { @Override public double normalize(double score, String discriminator) { return score; } @Override - public void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition) {} + void preprocess(double score, String discriminator) {} } - private class L2NormNormalizer implements Normalizer { + private static class L2NormNormalizer extends Normalizer {
private final Map<String, Double> l2Norms = new HashMap<>(); @Override @@ -297,24 +382,17 @@ public double normalize(double score, String discriminator) { } @Override - public void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition) { - for (Page inputPage : inputPages) { - DoubleVectorBlock scoreBlock = inputPage.getBlock(scorePosition); - BytesRefBlock discriminatorBlock = inputPage.getBlock(discriminatorPosition); - - for (int i = 0; i < inputPage.getPositionCount(); i++) { - double score = scoreBlock.getDouble(i); - String discriminator = discriminatorBlock.getBytesRef(i, new BytesRef()).utf8ToString(); - - l2Norms.compute(discriminator, (k, v) -> v == null ? score * score : v + score * score); - } - } + void preprocess(double score, String discriminator) { + l2Norms.compute(discriminator, (k, v) -> v == null ? score * score : v + score * score); + } + @Override + void finalizePreprocess() { l2Norms.replaceAll((k, v) -> Math.sqrt(v)); } } - private class MinMaxNormalizer implements Normalizer { + private static class MinMaxNormalizer extends Normalizer { private final Map<String, Double> minScores = new HashMap<>(); private final Map<String, Double> maxScores = new HashMap<>(); @@ -334,19 +412,17 @@ public double normalize(double score, String discriminator) { } @Override - public void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition) { - for (Page inputPage : inputPages) { - DoubleVectorBlock scoreBlock = inputPage.getBlock(scorePosition); - BytesRefBlock discriminatorBlock = inputPage.getBlock(discriminatorPosition); - - for (int i = 0; i < inputPage.getPositionCount(); i++) { - double score = scoreBlock.getDouble(i); - String discriminator = discriminatorBlock.getBytesRef(i, new BytesRef()).utf8ToString(); + void preprocess(double score, String discriminator) { + minScores.compute(discriminator, (key, value) -> value == null ? score : Math.min(value, score)); + maxScores.compute(discriminator, (key, value) -> value == null ? score : Math.max(value, score)); + } + } - minScores.compute(discriminator, (key, value) -> value == null ? score : Math.min(value, score)); - maxScores.compute(discriminator, (key, value) -> value == null ?
score : Math.max(value, score)); - } - } + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings(driverContext.warningsMode(), sourceLine, sourceColumn, sourceText); } + + return warnings; } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/fuse/RrfScoreEvalOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/fuse/RrfScoreEvalOperator.java index 519f04bcf60b7..9c4ca7d928715 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/fuse/RrfScoreEvalOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/fuse/RrfScoreEvalOperator.java @@ -9,15 +9,18 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockUtils; import org.elasticsearch.compute.data.BytesRefBlock; -import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.AbstractPageMappingOperator; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.compute.operator.Warnings; import org.elasticsearch.core.Releasables; import java.util.HashMap; +import java.util.List; /** * Updates the score column with new scores using the RRF formula. @@ -27,10 +30,25 @@ */ public class RrfScoreEvalOperator extends AbstractPageMappingOperator { - public record Factory(int discriminatorPosition, int scorePosition, RrfConfig rrfConfig) implements OperatorFactory { + public record Factory( + int discriminatorPosition, + int scorePosition, + RrfConfig rrfConfig, + String sourceText, + int sourceLine, + int sourceColumn + ) implements OperatorFactory { @Override public Operator get(DriverContext driverContext) { - return new RrfScoreEvalOperator(discriminatorPosition, scorePosition, rrfConfig); + return new RrfScoreEvalOperator( + driverContext, + discriminatorPosition, + scorePosition, + rrfConfig, + sourceText, + sourceLine, + sourceColumn + ); } @Override @@ -48,37 +66,62 @@ public String describe() { private final int scorePosition; private final int discriminatorPosition; private final RrfConfig config; + private Warnings warnings; + private final DriverContext driverContext; + private final String sourceText; + private final int sourceLine; + private final int sourceColumn; private HashMap counters = new HashMap<>(); - public RrfScoreEvalOperator(int discriminatorPosition, int scorePosition, RrfConfig config) { + public RrfScoreEvalOperator( + DriverContext driverContext, + int discriminatorPosition, + int scorePosition, + RrfConfig config, + String sourceText, + int sourceLine, + int sourceColumn + ) { this.scorePosition = scorePosition; this.discriminatorPosition = discriminatorPosition; this.config = config; + this.driverContext = driverContext; + this.sourceText = sourceText; + this.sourceLine = sourceLine; + this.sourceColumn = sourceColumn; } @Override protected Page process(Page page) { - BytesRefBlock discriminatorBlock = (BytesRefBlock) page.getBlock(discriminatorPosition); - - DoubleVector.Builder scores = discriminatorBlock.blockFactory().newDoubleVectorBuilder(discriminatorBlock.getPositionCount()); + BytesRefBlock discriminatorBlock = page.getBlock(discriminatorPosition); + DoubleBlock.Builder scores = 
discriminatorBlock.blockFactory().newDoubleBlockBuilder(discriminatorBlock.getPositionCount()); for (int i = 0; i < page.getPositionCount(); i++) { - String discriminator = discriminatorBlock.getBytesRef(i, new BytesRef()).utf8ToString(); - - int rank = counters.getOrDefault(discriminator, 1); - counters.put(discriminator, rank + 1); - - var weight = config.weights().getOrDefault(discriminator, 1.0); - - scores.appendDouble(1.0 / (config.rankConstant() + rank) * weight); + Object value = BlockUtils.toJavaObject(discriminatorBlock, i); + + if (value == null) { + warnings().registerException(new IllegalArgumentException("group column has null values; assigning null scores")); + scores.appendNull(); + } else if (value instanceof List) { + warnings().registerException( + new IllegalArgumentException("group column contains multivalued entries; assigning null scores") + ); + scores.appendNull(); + } else { + String discriminator = ((BytesRef) value).utf8ToString(); + int rank = counters.getOrDefault(discriminator, 1); + var weight = config.weights().getOrDefault(discriminator, 1.0); + scores.appendDouble(1.0 / (config.rankConstant() + rank) * weight); + counters.put(discriminator, rank + 1); + } } Page newPage = null; Block scoreBlock = null; try { - scoreBlock = scores.build().asBlock(); + scoreBlock = scores.build(); newPage = page.appendBlock(scoreBlock); int[] projections = new int[newPage.getBlockCount() - 1]; @@ -105,4 +148,12 @@ protected Page process(Page page) { public String toString() { return "RrfScoreEvalOperator"; } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings(driverContext.warningsMode(), sourceLine, sourceColumn, sourceText); + } + + return warnings; + } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/fuse/FuseOperatorTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/fuse/FuseOperatorTestCase.java index de42c8106212c..eb17d3aef99bc 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/fuse/FuseOperatorTestCase.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/fuse/FuseOperatorTestCase.java @@ -79,7 +79,7 @@ protected Page createPage(int positionOffset, int length) { if (b == scorePosition) { try (var builder = blockFactory.newDoubleBlockBuilder(length)) { for (int i = 0; i < length; i++) { - builder.appendDouble(randomDouble()); + builder.appendDouble(randomDoubleBetween(-1000, 1000, true)); } blocks[b] = builder.build(); } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/fuse/LinearScoreEvalOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/fuse/LinearScoreEvalOperatorTests.java index 4fdea6353332b..7c6b685758162 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/fuse/LinearScoreEvalOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/fuse/LinearScoreEvalOperatorTests.java @@ -15,6 +15,8 @@ import java.util.List; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.lessThanOrEqualTo; public class LinearScoreEvalOperatorTests extends FuseOperatorTestCase { private LinearConfig config; @@ -28,13 +30,26 @@ public void setup() { protected void assertSimpleOutput(List input, List results) { assertOutput(input, results, 
(discriminator, actualScore, initialScore) -> { var weight = config.weights().getOrDefault(discriminator, 1.0); - assertEquals(actualScore, initialScore * weight, 0.00); + if (config.normalizer() == LinearConfig.Normalizer.NONE) { + assertEquals(actualScore, initialScore * weight, 0.00); + } else if (config.normalizer() == LinearConfig.Normalizer.MINMAX) { + // for min_max, we know the normalized scores will be between 0..1 + // when we apply the weight, the scores should be between 0 and weight + assertThat(actualScore, lessThanOrEqualTo(weight)); + assertThat(actualScore, greaterThanOrEqualTo(0.0)); + } else { + // for l2_norm, we could be dealing with negative scores + // in this case the normalized scores will be between -1 and 1. + // when we apply the weight, the scores should be between -weight and weight + assertThat(actualScore, lessThanOrEqualTo(weight)); + assertThat(actualScore, greaterThanOrEqualTo(-weight)); + } }); } @Override protected Operator.OperatorFactory simple(SimpleOptions options) { - return new LinearScoreEvalOperator.Factory(discriminatorPosition, scorePosition, config); + return new LinearScoreEvalOperator.Factory(discriminatorPosition, scorePosition, config, null, 0, 0); } @Override @@ -64,6 +79,6 @@ protected Matcher expectedToStringOfSimple() { } private LinearConfig randomConfig() { - return new LinearConfig(LinearConfig.Normalizer.NONE, randomWeights()); + return new LinearConfig(randomFrom(LinearConfig.Normalizer.values()), randomWeights()); } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/fuse/RrfScoreEvalOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/fuse/RrfScoreEvalOperatorTests.java index 7f0347897a814..87907cd003493 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/fuse/RrfScoreEvalOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/fuse/RrfScoreEvalOperatorTests.java @@ -40,7 +40,7 @@ protected void assertSimpleOutput(List input, List results) { @Override protected Operator.OperatorFactory simple(SimpleOptions options) { - return new RrfScoreEvalOperator.Factory(discriminatorPosition, scorePosition, config); + return new RrfScoreEvalOperator.Factory(discriminatorPosition, scorePosition, config, null, 0, 0); } @Override diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/fuse.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/fuse.csv-spec index 304f9fa8009b2..fe7788beea1b8 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/fuse.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/fuse.csv-spec @@ -4,7 +4,7 @@ simpleFuse required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: match_operator_colon FROM employees METADATA _id, _index, _score @@ -23,7 +23,7 @@ _score:double | _fork:keyword | emp_no:integer fuseWithMatchAndScore required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: match_operator_colon FROM books METADATA _id, _index, _score @@ -46,7 +46,7 @@ _score:double | _fork:keyword | _id:keyword fuseWithDisjunctionAndPostFilter required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: match_operator_colon FROM books METADATA _id, _index, _score @@ -69,7 +69,7 @@ _score:double | _fork:keyword | _id:keyword fuseWithStats required_capability: fork_v9 
-required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: match_operator_colon FROM books METADATA _id, _index, _score @@ -89,7 +89,7 @@ count_fork:long | _fork:keyword fuseWithMultipleForkBranches required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: match_operator_colon FROM books METADATA _id, _index, _score @@ -116,7 +116,7 @@ _score:double | author:keyword | title:keyword | _fork fuseWithSemanticSearch required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: semantic_text_field_caps required_capability: metadata_score @@ -137,7 +137,7 @@ _fork:keyword | _score:double | _id:keyword | semantic_text_field:keyword fuseWithSimpleRrf required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: semantic_text_field_caps required_capability: metadata_score @@ -161,7 +161,7 @@ _score:double | _fork:keyword | _id:keyword fuseWithRrfAndRankConstant required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: semantic_text_field_caps required_capability: metadata_score @@ -185,7 +185,7 @@ _score:double | _fork:keyword | _id:keyword fuseWithRrfAndWeights required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: semantic_text_field_caps required_capability: metadata_score @@ -209,7 +209,7 @@ _score:double | _fork:keyword | _id:keyword fuseWithRrfRankConstantAndWeights required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: semantic_text_field_caps required_capability: metadata_score @@ -233,7 +233,7 @@ _score:double | _fork:keyword | _id:keyword fuseWithRrfAndScoreColumn required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: semantic_text_field_caps required_capability: metadata_score @@ -257,7 +257,7 @@ my_score:double | _fork:keyword | _id:keyword fuseWithRrfAndDiscriminatorColumn required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: semantic_text_field_caps required_capability: metadata_score @@ -283,7 +283,7 @@ _score:double | new_fork:keyword | _id:keyword fuseWithRrfAndKeyColumns required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: semantic_text_field_caps required_capability: metadata_score @@ -309,7 +309,7 @@ _score:double | _fork:keyword | new_id:keyword fuseWithRrfAllOptionsScoreKeyAndGroupColumns required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: semantic_text_field_caps required_capability: metadata_score @@ -337,7 +337,7 @@ new_score:double | new_fork:keyword | new_id:keyword fuseWithSimpleLinear required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: semantic_text_field_caps required_capability: metadata_score @@ -362,7 +362,7 @@ _score:double | _fork:keyword | _id:keyword fuseWithLinearAndL2Norm required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: semantic_text_field_caps required_capability: metadata_score @@ -387,7 +387,7 @@ _score:double | _fork:keyword | _id:keyword fuseWithLinearAndMinMax required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: 
semantic_text_field_caps required_capability: metadata_score @@ -411,7 +411,7 @@ _score:double | _fork:keyword | _id:keyword fuseWithLinearAndWeights required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: semantic_text_field_caps required_capability: metadata_score @@ -435,7 +435,7 @@ _score:double | _fork:keyword | _id:keyword fuseWithLinearAndPartialWeights required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: semantic_text_field_caps required_capability: metadata_score @@ -459,7 +459,7 @@ _score:double | _fork:keyword | _id:keyword fuseWithLinearWeightsAndMinMax required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: semantic_text_field_caps required_capability: metadata_score @@ -484,7 +484,7 @@ _score:double | _fork:keyword | _id:keyword fuseWithLinearAndScoreColumn required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: semantic_text_field_caps required_capability: metadata_score @@ -509,7 +509,7 @@ my_score:double | _fork:keyword | _id:keyword fuseWithLinearAndDiscriminatorColumn required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: semantic_text_field_caps required_capability: metadata_score @@ -535,7 +535,7 @@ _score:double | new_fork:keyword | _id:keyword fuseWithLinearAndKeyColumns required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: semantic_text_field_caps required_capability: metadata_score @@ -560,7 +560,7 @@ _score:double | _fork:keyword | new_id:keyword fuseWithLinearAllOptionsScoreGroupAndKeyColumns required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: semantic_text_field_caps required_capability: metadata_score @@ -589,7 +589,7 @@ new_score:double | new_fork:keyword | new_id:keyword fuseWithRowAndRRF -required_capability: fuse_v5 +required_capability: fuse_v6 ROW my_score = [0, 1, 2, 3, 4]::double, _index = "my_index", _fork = "foo" | MV_EXPAND my_score @@ -612,7 +612,7 @@ id_4.0 | 0.01538 fuseWithRowLinearAndWeights -required_capability: fuse_v5 +required_capability: fuse_v6 ROW my_score = [0, 1, 2, 3, 4]::double, _index = "my_index", _fork = "foo" | MV_EXPAND my_score @@ -634,7 +634,7 @@ id_0.0 | 0.0 fuseWithRowLinearAndMinMax -required_capability: fuse_v5 +required_capability: fuse_v6 ROW my_score = [0, 1, 2, 3, 4]::double, _index = "my_index", _fork = "foo" | MV_EXPAND my_score @@ -657,7 +657,7 @@ id_0.0 | 0.0 fuseWithRowLinearAndL2Norm -required_capability: fuse_v5 +required_capability: fuse_v6 ROW my_score = [0, 1, 2, 3, 4]::double, _index = "my_index", _fork = "foo" | MV_EXPAND my_score @@ -677,3 +677,115 @@ id_2.0 | 0.36515 id_1.0 | 0.18257 id_0.0 | 0.0 ; + +fuseWithRowLinearAndMultiValueGroupColumn + +required_capability: fuse_v6 + +ROW my_score = [0, 1, 2, 3, 4]::double, _index = "my_index", my_fork = "foo" +| MV_EXPAND my_score +| EVAL _id = CONCAT("id_", my_score::string), my_fork = CASE(my_score == 0, ["foo", "bar"], my_fork) +| SORT my_score +| LIMIT 10 +| FUSE LINEAR SCORE BY my_score GROUP BY my_fork +| EVAL my_score = round(my_score, 5) +| SORT my_score DESC +| KEEP _id, my_score, my_fork +; + +warning:Line 6:3: evaluation of [FUSE LINEAR SCORE BY my_score GROUP BY my_fork] failed, treating result as null. Only first 20 failures recorded. 
+warning:Line 6:3: java.lang.IllegalArgumentException: group column contains multivalued entries; assigning null scores + +_id:keyword | my_score:double | my_fork:keyword +id_0.0 | null | [foo, bar] +id_4.0 | 4.0 | foo +id_3.0 | 3.0 | foo +id_2.0 | 2.0 | foo +id_1.0 | 1.0 | foo +; + +fuseWithRowLinearAndMultiValueScoreColumn + +required_capability: fuse_v6 + +ROW my_score = [0, 1, 2, 3, 4]::double, _index = "my_index", my_fork = "foo" +| MV_EXPAND my_score +| EVAL _id = CONCAT("id_", my_score::string), my_score = CASE(my_score == 0, [1, 2]::double, my_score) +| SORT my_score +| LIMIT 10 +| FUSE LINEAR SCORE BY my_score GROUP BY my_fork +| EVAL my_score = round(my_score, 5) +| SORT my_score DESC +| KEEP _id, my_score, my_fork +; + +warning:Line 6:3: evaluation of [FUSE LINEAR SCORE BY my_score GROUP BY my_fork] failed, treating result as null. Only first 20 failures recorded. +warning:Line 6:3: java.lang.IllegalArgumentException: score column contains multivalued entries; assigning null scores + +_id:keyword | my_score:double | my_fork:keyword +id_0.0 | null | foo +id_4.0 | 4.0 | foo +id_3.0 | 3.0 | foo +id_2.0 | 2.0 | foo +id_1.0 | 1.0 | foo +; + +fuseWithRowRRFAndMultiValueGroupColumn + +required_capability: fuse_v6 + +ROW my_score = [0, 1, 2, 3, 4]::double, _index = "my_index", my_fork = "foo" +| MV_EXPAND my_score +| EVAL _id = CONCAT("id_", my_score::string), my_fork = CASE(my_score == 0, ["foo", "bar"], my_fork) +| SORT my_score +| LIMIT 10 +| FUSE RRF SCORE BY my_score GROUP BY my_fork +| EVAL my_score = round(my_score, 5) +| SORT my_score DESC +| KEEP _id, my_score, my_fork +; + +warning:Line 6:3: evaluation of [FUSE RRF SCORE BY my_score GROUP BY my_fork] failed, treating result as null. Only first 20 failures recorded. +warning:Line 6:3: java.lang.IllegalArgumentException: group column contains multivalued entries; assigning null scores + +_id:keyword | my_score:double | my_fork:keyword +id_0.0 | null | [foo, bar] +id_1.0 | 0.01639 | foo +id_2.0 | 0.01613 | foo +id_3.0 | 0.01587 | foo +id_4.0 | 0.01563 | foo +; + +fuseWithRowLinearL2NormAndZeroScores + +required_capability: fuse_v6 + +ROW _id = ["a", "b", "c"], _score = 0.0, _index = "my_index", my_fork = "foo" +| MV_EXPAND _id +| FUSE LINEAR GROUP BY my_fork WITH { "normalizer": "l2_norm" } +| SORT _id +| KEEP _id, _score +; + +_id:keyword | _score:double +a | 0.0 +b | 0.0 +c | 0.0 +; + +fuseWithRowLinearMinMaxAndZeroScores + +required_capability: fuse_v6 + +ROW _id = ["a", "b", "c"], _score = 0.0, _index = "my_index", my_fork = "foo" +| MV_EXPAND _id +| FUSE LINEAR GROUP BY my_fork WITH { "normalizer": "minmax" } +| SORT _id +| KEEP _id, _score +; + +_id:keyword | _score:double +a | 0.0 +b | 0.0 +c | 0.0 +; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/rerank.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/rerank.csv-spec index 145095a1fe4f0..e7eb7dd3c7021 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/rerank.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/rerank.csv-spec @@ -204,7 +204,7 @@ book_no:keyword | title:text | author reranker after FUSE required_capability: fork_v9 -required_capability: fuse_v5 +required_capability: fuse_v6 required_capability: match_operator_colon required_capability: rerank diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/FuseIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/FuseIT.java index 953678f5a2cf9..bced9d352ca39 100644 
--- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/FuseIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/FuseIT.java @@ -28,7 +28,7 @@ protected Collection> nodePlugins() { @Before public void setupIndex() { - assumeTrue("requires FUSE capability", EsqlCapabilities.Cap.FUSE_V5.isEnabled()); + assumeTrue("requires FUSE capability", EsqlCapabilities.Cap.FUSE_V6.isEnabled()); createAndPopulateIndex(); } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/FuseWithInvalidLicenseIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/FuseWithInvalidLicenseIT.java index 8d1157318e4ef..802a200cda2ad 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/FuseWithInvalidLicenseIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/FuseWithInvalidLicenseIT.java @@ -31,7 +31,7 @@ protected Collection> nodePlugins() { @Before public void setupIndex() { - assumeTrue("requires FUSE capability", EsqlCapabilities.Cap.FUSE_V5.isEnabled()); + assumeTrue("requires FUSE capability", EsqlCapabilities.Cap.FUSE_V6.isEnabled()); var indexName = "test"; var client = client().admin().indices(); var CreateRequest = client.prepareCreate(indexName) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 971aa81eb1074..a1b63f3500885 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1367,7 +1367,7 @@ public enum Cap { /** * FUSE command */ - FUSE_V5(Build.current().isSnapshot()), + FUSE_V6(Build.current().isSnapshot()), /** * Support improved behavior for LIKE operator when used with index fields. 
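For readers skimming the operator changes above, the per-row score math that LinearScoreEvalOperator and RrfScoreEvalOperator implement is compact enough to show on its own. The sketch below is illustrative only: it is plain Java with an invented class name (FuseScoreMathSketch) and invented helper methods rather than the compute-engine types, and it mirrors the formulas visible in the diff: RRF scores a row as weight * 1 / (rankConstant + rank), while LINEAR scales a normalized score by the group weight, with the min_max and l2_norm statistics gathered per group during preprocessing.

import java.util.List;

public class FuseScoreMathSketch {

    // RRF: rows in a group are ranked in arrival order; score = weight * 1 / (rankConstant + rank).
    static double rrfScore(int rank, double rankConstant, double weight) {
        return 1.0 / (rankConstant + rank) * weight;
    }

    // LINEAR + min_max: (score - min) / (max - min), scaled by the group weight.
    // The zero-score csv-spec cases above expect 0.0, so this sketch guards the degenerate group to match.
    static double minMaxScore(double score, double min, double max, double weight) {
        if (max == min) {
            return 0.0;
        }
        return weight * (score - min) / (max - min);
    }

    // LINEAR + l2_norm: score / sqrt(sum of squared scores in the group), scaled by the group weight.
    static double l2NormScore(double score, double sumOfSquares, double weight) {
        double norm = Math.sqrt(sumOfSquares);
        return norm == 0.0 ? 0.0 : weight * score / norm;
    }

    public static void main(String[] args) {
        // Ranks 1..5 with the default rank constant 60 reproduce the 0.01639 .. 0.01538 range of values
        // seen in the row-based RRF expectations.
        for (int rank = 1; rank <= 5; rank++) {
            System.out.printf("rrf rank=%d -> %.5f%n", rank, rrfScore(rank, 60, 1.0));
        }

        // The row-based LINEAR tests fuse the scores [0, 1, 2, 3, 4] within a single group.
        List<Double> scores = List.of(0.0, 1.0, 2.0, 3.0, 4.0);
        double sumOfSquares = scores.stream().mapToDouble(s -> s * s).sum();
        for (double s : scores) {
            System.out.printf("min_max=%.5f l2_norm=%.5f%n", minMaxScore(s, 0.0, 4.0, 1.0), l2NormScore(s, sumOfSquares, 1.0));
        }
    }
}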
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java index 09acb8dcdc041..532a5cbb5e64a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java @@ -330,33 +330,35 @@ private PhysicalOperation planCompletion(CompletionExec completion, LocalExecuti private PhysicalOperation planFuseScoreEvalExec(FuseScoreEvalExec fuse, LocalExecutionPlannerContext context) { PhysicalOperation source = plan(fuse.child(), context); + Layout layout = source.layout; - int scorePosition = -1; - int discriminatorPosition = -1; - int pos = 0; - - for (Attribute attr : fuse.child().output()) { - if (attr.name().equals(fuse.discriminator().name())) { - discriminatorPosition = pos; - } - if (attr.name().equals(fuse.score().name())) { - scorePosition = pos; - } - - pos += 1; - } - - if (scorePosition == -1) { - throw new IllegalStateException("can't find score attribute position"); - } - if (discriminatorPosition == -1) { - throw new IllegalStateException("can't find discriminator attribute position"); - } + int scorePosition = layout.get(fuse.score().id()).channel(); + int discriminatorPosition = layout.get(fuse.discriminator().id()).channel(); if (fuse.fuseConfig() instanceof RrfConfig rrfConfig) { - return source.with(new RrfScoreEvalOperator.Factory(discriminatorPosition, scorePosition, rrfConfig), source.layout); + return source.with( + new RrfScoreEvalOperator.Factory( + discriminatorPosition, + scorePosition, + rrfConfig, + fuse.sourceText(), + fuse.sourceLocation().getLineNumber(), + fuse.sourceLocation().getColumnNumber() + ), + source.layout + ); } else if (fuse.fuseConfig() instanceof LinearConfig linearConfig) { - return source.with(new LinearScoreEvalOperator.Factory(discriminatorPosition, scorePosition, linearConfig), source.layout); + return source.with( + new LinearScoreEvalOperator.Factory( + discriminatorPosition, + scorePosition, + linearConfig, + fuse.sourceText(), + fuse.sourceLocation().getLineNumber(), + fuse.sourceLocation().getColumnNumber() + ), + source.layout + ); } throw new EsqlIllegalArgumentException("unknown FUSE score method [" + fuse.fuseConfig() + "]"); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index f312576baedde..4305b227f298e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -3625,7 +3625,7 @@ public void testForkError() { } public void testValidFuse() { - assumeTrue("requires FUSE capability", EsqlCapabilities.Cap.FUSE_V5.isEnabled()); + assumeTrue("requires FUSE capability", EsqlCapabilities.Cap.FUSE_V6.isEnabled()); LogicalPlan plan = analyze(""" from test metadata _id, _index, _score @@ -3649,7 +3649,7 @@ public void testValidFuse() { } public void testFuseError() { - assumeTrue("requires FUSE capability", EsqlCapabilities.Cap.FUSE_V5.isEnabled()); + assumeTrue("requires FUSE capability", EsqlCapabilities.Cap.FUSE_V6.isEnabled()); var e = expectThrows(VerificationException.class, () -> analyze(""" from test diff --git 
a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index c416248eb4506..bdf824443d0de 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -2569,7 +2569,7 @@ public void testInvalidTBucketCalls() { } public void testFuse() { - assumeTrue("FUSE requires corresponding capability", EsqlCapabilities.Cap.FUSE_V5.isEnabled()); + assumeTrue("FUSE requires corresponding capability", EsqlCapabilities.Cap.FUSE_V6.isEnabled()); String queryPrefix = "from test metadata _score, _index, _id | fork (where true) (where true)"; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java index f29df6d9d4252..bc8f9bda86fbb 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java @@ -4218,7 +4218,7 @@ static Alias alias(String name, Expression value) { } public void testValidFuse() { - assumeTrue("FUSE requires corresponding capability", EsqlCapabilities.Cap.FUSE_V5.isEnabled()); + assumeTrue("FUSE requires corresponding capability", EsqlCapabilities.Cap.FUSE_V6.isEnabled()); LogicalPlan plan = statement(""" FROM foo* METADATA _id, _index, _score @@ -4318,7 +4318,7 @@ public void testValidFuse() { } public void testInvalidFuse() { - assumeTrue("FUSE requires corresponding capability", EsqlCapabilities.Cap.FUSE_V5.isEnabled()); + assumeTrue("FUSE requires corresponding capability", EsqlCapabilities.Cap.FUSE_V6.isEnabled()); String queryPrefix = "from test metadata _score, _index, _id | fork (where true) (where true)"; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/FieldNameUtilsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/FieldNameUtilsTests.java index 4185a2da4e84d..260887c9db2ea 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/FieldNameUtilsTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/FieldNameUtilsTests.java @@ -2225,7 +2225,7 @@ public void testForkRef4() { } public void testRerankerAfterFuse() { - assumeTrue("FUSE required", EsqlCapabilities.Cap.FUSE_V5.isEnabled()); + assumeTrue("FUSE required", EsqlCapabilities.Cap.FUSE_V6.isEnabled()); assertTrue("FORK required", EsqlCapabilities.Cap.FORK_V9.isEnabled()); assertFieldNames(""" FROM books METADATA _id, _index, _score @@ -2240,7 +2240,7 @@ public void testRerankerAfterFuse() { } public void testSimpleFuse() { - assumeTrue("FUSE required", EsqlCapabilities.Cap.FUSE_V5.isEnabled()); + assumeTrue("FUSE required", EsqlCapabilities.Cap.FUSE_V6.isEnabled()); assertTrue("FORK required", EsqlCapabilities.Cap.FORK_V9.isEnabled()); assertFieldNames(""" FROM employees METADATA _id, _index, _score @@ -2253,7 +2253,7 @@ public void testSimpleFuse() { } public void testFuseWithMatchAndScore() { - assumeTrue("FUSE required", EsqlCapabilities.Cap.FUSE_V5.isEnabled()); + assumeTrue("FUSE required", EsqlCapabilities.Cap.FUSE_V6.isEnabled()); assertTrue("FORK required", EsqlCapabilities.Cap.FORK_V9.isEnabled()); assertFieldNames(""" FROM books METADATA _id, 
_index, _score @@ -2267,7 +2267,7 @@ public void testFuseWithMatchAndScore() { } public void testFuseWithDisjunctionAndPostFilter() { - assumeTrue("FUSE required", EsqlCapabilities.Cap.FUSE_V5.isEnabled()); + assumeTrue("FUSE required", EsqlCapabilities.Cap.FUSE_V6.isEnabled()); assertTrue("FORK required", EsqlCapabilities.Cap.FORK_V9.isEnabled()); assertFieldNames(""" FROM books METADATA _id, _index, _score @@ -2282,7 +2282,7 @@ public void testFuseWithDisjunctionAndPostFilter() { } public void testFuseWithStats() { - assumeTrue("FUSE required", EsqlCapabilities.Cap.FUSE_V5.isEnabled()); + assumeTrue("FUSE required", EsqlCapabilities.Cap.FUSE_V6.isEnabled()); assertTrue("FORK required", EsqlCapabilities.Cap.FORK_V9.isEnabled()); assertFieldNames(""" FROM books METADATA _id, _index, _score @@ -2295,7 +2295,7 @@ public void testFuseWithStats() { } public void testFuseWithMultipleForkBranches() { - assumeTrue("FUSE required", EsqlCapabilities.Cap.FUSE_V5.isEnabled()); + assumeTrue("FUSE required", EsqlCapabilities.Cap.FUSE_V6.isEnabled()); assertTrue("FORK required", EsqlCapabilities.Cap.FORK_V9.isEnabled()); assertFieldNames(""" FROM books METADATA _id, _index, _score @@ -2312,7 +2312,7 @@ public void testFuseWithMultipleForkBranches() { } public void testFuseWithSemanticSearch() { - assumeTrue("FUSE required", EsqlCapabilities.Cap.FUSE_V5.isEnabled()); + assumeTrue("FUSE required", EsqlCapabilities.Cap.FUSE_V6.isEnabled()); assertTrue("FORK required", EsqlCapabilities.Cap.FORK_V9.isEnabled()); assertFieldNames(""" FROM semantic_text METADATA _id, _score, _index