From 6455899eb6e4f88cbb099f5da089114d006571cb Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Mon, 13 Oct 2025 15:56:14 +0100 Subject: [PATCH] Add TS, FUSE and INLINE STATS ES|QL commands and recently added functions (#3096) * Various fixes in the ES|QL query builder docstrings * add TS, FUSE and INLINE STATS commands * add new ES|QL functions (cherry picked from commit e62eaf231820c30eb02d1e5ff52e6cfc7fcb1e90) --- elasticsearch/esql/esql.py | 261 +++++++++++++++++++++++++++----- elasticsearch/esql/functions.py | 88 +++++++++++ test_elasticsearch/test_esql.py | 164 ++++++++++++++++++++ 3 files changed, 475 insertions(+), 38 deletions(-) diff --git a/elasticsearch/esql/esql.py b/elasticsearch/esql/esql.py index 6643ddc67..501148e39 100644 --- a/elasticsearch/esql/esql.py +++ b/elasticsearch/esql/esql.py @@ -18,7 +18,7 @@ import json import re from abc import ABC, abstractmethod -from typing import Any, Dict, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union from ..dsl.document_base import DocumentBase, InstrumentedExpression, InstrumentedField @@ -78,6 +78,22 @@ def show(item: str) -> "Show": """ return Show(item) + @staticmethod + def ts(*indices: IndexType) -> "TS": + """The ``TS`` source command is similar to ``FROM``, but for time series indices. + + :param indices: A list of indices, data streams or aliases. Supports wildcards and date math. + + Examples:: + + query = ( + ESQL.ts("metrics") + .where("@timestamp >= now() - 1 day") + .stats("SUM(AVG_OVER_TIME(memory_usage))").by("host", "TBUCKET(1 hour)") + ) + """ + return TS(*indices) + @staticmethod def branch() -> "Branch": """This method can only be used inside a ``FORK`` command to create each branch. 
@@ -284,14 +300,14 @@ def eval(self, *columns: ExpressionType, **named_columns: ExpressionType) -> "Ev def fork( self, - fork1: "ESQLBase", - fork2: Optional["ESQLBase"] = None, - fork3: Optional["ESQLBase"] = None, - fork4: Optional["ESQLBase"] = None, - fork5: Optional["ESQLBase"] = None, - fork6: Optional["ESQLBase"] = None, - fork7: Optional["ESQLBase"] = None, - fork8: Optional["ESQLBase"] = None, + fork1: "Branch", + fork2: Optional["Branch"] = None, + fork3: Optional["Branch"] = None, + fork4: Optional["Branch"] = None, + fork5: Optional["Branch"] = None, + fork6: Optional["Branch"] = None, + fork7: Optional["Branch"] = None, + fork8: Optional["Branch"] = None, ) -> "Fork": """The ``FORK`` processing command creates multiple execution branches to operate on the same input data and combines the results in a single output table. @@ -314,6 +330,51 @@ def fork( raise ValueError("a query can only have one fork") return Fork(self, fork1, fork2, fork3, fork4, fork5, fork6, fork7, fork8) + def fuse(self, method: Optional[str] = None) -> "Fuse": + """The ``FUSE`` processing command merges rows from multiple result sets and assigns + new relevance scores. + + :param method: Defaults to ``RRF``. Can be one of ``RRF`` (for Reciprocal Rank Fusion) + or ``LINEAR`` (for linear combination of scores). Designates which + method to use to assign new relevance scores. 
+ + Examples:: + + query1 = ( + ESQL.from_("books").metadata("_id", "_index", "_score") + .fork( + ESQL.branch().where('title:"Shakespeare"').sort("_score DESC"), + ESQL.branch().where('semantic_title:"Shakespeare"').sort("_score DESC"), + ) + .fuse() + ) + query2 = ( + ESQL.from_("books").metadata("_id", "_index", "_score") + .fork( + ESQL.branch().where('title:"Shakespeare"').sort("_score DESC"), + ESQL.branch().where('semantic_title:"Shakespeare"').sort("_score DESC"), + ) + .fuse("linear") + ) + query3 = ( + ESQL.from_("books").metadata("_id", "_index", "_score") + .fork( + ESQL.branch().where('title:"Shakespeare"').sort("_score DESC"), + ESQL.branch().where('semantic_title:"Shakespeare"').sort("_score DESC"), + ) + .fuse("linear").by("title", "description") + ) + query4 = ( + ESQL.from_("books").metadata("_id", "_index", "_score") + .fork( + ESQL.branch().where('title:"Shakespeare"').sort("_score DESC"), + ESQL.branch().where('semantic_title:"Shakespeare"').sort("_score DESC"), + ) + .fuse("linear").with_(normalizer="minmax") + ) + """ + return Fuse(self, method) + def grok(self, input: FieldType, pattern: str) -> "Grok": """``GROK`` enables you to extract structured data out of a string. @@ -348,6 +409,58 @@ def grok(self, input: FieldType, pattern: str) -> "Grok": """ return Grok(self, input, pattern) + def inline_stats( + self, *expressions: ExpressionType, **named_expressions: ExpressionType + ) -> "Stats": + """The ``INLINE STATS`` processing command groups rows according to a common value + and calculates one or more aggregated values over the grouped rows. + + The command is identical to ``STATS`` except that it preserves all the columns from + the input table. + + :param expressions: A list of expressions, given as positional arguments. + :param named_expressions: A list of expressions, given as keyword arguments. The + argument names are used for the returned aggregated values. 
+ + Note that only one of ``expressions`` and ``named_expressions`` must be provided. + + Examples:: + + query1 = ( + ESQL.from_("employees") + .keep("emp_no", "languages", "salary") + .inline_stats(max_salary=functions.max(E("salary"))).by("languages") + ) + query2 = ( + ESQL.from_("employees") + .keep("emp_no", "languages", "salary") + .inline_stats(max_salary=functions.max(E("salary"))) + ) + query3 = ( + ESQL.from_("employees") + .where("still_hired") + .keep("emp_no", "languages", "salary", "hire_date") + .eval(tenure=functions.date_diff("year", E("hire_date"), "2025-09-18T00:00:00")) + .drop("hire_date") + .inline_stats( + avg_salary=functions.avg(E("salary")), + count=functions.count(E("*")), + ) + .by("languages", "tenure") + ) + query4 = ( + ESQL.from_("employees") + .keep("emp_no", "salary") + .inline_stats( + avg_lt_50=functions.round(functions.avg(E("salary"))).where(E("salary") < 50000), + avg_lt_60=functions.round(functions.avg(E("salary"))).where(E("salary") >= 50000, E("salary") < 60000), + avg_gt_60=functions.round(functions.avg(E("salary"))).where(E("salary") >= 60000), + ) + ) + + """ + return InlineStats(self, *expressions, **named_expressions) + def keep(self, *columns: FieldType) -> "Keep": """The ``KEEP`` processing command enables you to specify what columns are returned and the order in which they are returned. @@ -377,7 +490,7 @@ def limit(self, max_number_of_rows: int) -> "Limit": return Limit(self, max_number_of_rows) def lookup_join(self, lookup_index: IndexType) -> "LookupJoin": - """`LOOKUP JOIN` enables you to add data from another index, AKA a 'lookup' index, + """``LOOKUP JOIN`` enables you to add data from another index, AKA a 'lookup' index, to your ES|QL query results, simplifying data enrichment and analysis workflows. :param lookup_index: The name of the lookup index. 
This must be a specific index @@ -411,7 +524,7 @@ def lookup_join(self, lookup_index: IndexType) -> "LookupJoin": return LookupJoin(self, lookup_index) def mv_expand(self, column: FieldType) -> "MvExpand": - """The `MV_EXPAND` processing command expands multivalued columns into one row per + """The ``MV_EXPAND`` processing command expands multivalued columns into one row per value, duplicating other columns. :param column: The multivalued column to expand. @@ -449,7 +562,7 @@ def rerank(self, *query: ExpressionType, **named_query: ExpressionType) -> "Rera :param named_query: The query text used to rerank the documents, given as a keyword argument. The argument name is used for the column name. If the query is given as a positional argument, the - results will be stored in a column named `_score`. If the + results will be stored in a column named ``_score``. If the specified column already exists, it will be overwritten with the new results. @@ -540,7 +653,7 @@ def stats( :param named_expressions: A list of expressions, given as keyword arguments. The argument names are used for the returned aggregated values. - Note that only one of `expressions` and `named_expressions` must be provided. + Note that only one of ``expressions`` and ``named_expressions`` must be provided. Examples:: @@ -596,7 +709,7 @@ def stats( def where(self, *expressions: ExpressionType) -> "Where": """The ``WHERE`` processing command produces a table that contains all the rows - from the input table for which the provided condition evaluates to `true`. + from the input table for which the provided condition evaluates to ``true``. :param expressions: A list of boolean expressions, given as positional arguments. These expressions are combined with an ``AND`` logical operator. @@ -629,13 +742,15 @@ class From(ESQLBase): in a single expression. 
""" + command_name = "FROM" + def __init__(self, *indices: IndexType): super().__init__() self._indices = indices self._metadata_fields: Tuple[FieldType, ...] = tuple() def metadata(self, *fields: FieldType) -> "From": - """Continuation of the ``FROM`` source command. + """Continuation of the ``FROM`` and ``TS`` source commands. :param fields: metadata fields to retrieve, given as positional arguments. """ @@ -644,7 +759,7 @@ def metadata(self, *fields: FieldType) -> "From": def _render_internal(self) -> str: indices = [self._format_index(index) for index in self._indices] - s = f'{self.__class__.__name__.upper()} {", ".join(indices)}' + s = f'{self.command_name} {", ".join(indices)}' if self._metadata_fields: s = ( s @@ -692,6 +807,17 @@ def _render_internal(self) -> str: return f"SHOW {self._format_id(self._item)}" +class TS(From): + """Implementation of the ``TS`` source command. + + This class inherits from :class:`ESQLBase <elasticsearch.esql.esql.ESQLBase>`, + to make it possible to chain all the commands that belong to an ES|QL query + in a single expression. + """ + + command_name = "TS" + + class Branch(ESQLBase): """Implementation of a branch inside a ``FORK`` processing command. @@ -720,21 +846,22 @@ def __init__(self, parent: ESQLBase, value: FieldType): self._pvalue_name: Optional[str] = None def on(self, key: FieldType) -> "ChangePoint": - """Continuation of the `CHANGE_POINT` command. + """Continuation of the ``CHANGE_POINT`` command. :param key: The column with the key to order the values by. If not specified, - `@timestamp` is used. + ``@timestamp`` is used. """ self._key = key return self def as_(self, type_name: str, pvalue_name: str) -> "ChangePoint": - """Continuation of the `CHANGE_POINT` command. + """Continuation of the ``CHANGE_POINT`` command. :param type_name: The name of the output column with the change point type. - If not specified, `type` is used. + If not specified, ``type`` is used. 
:param pvalue_name: The name of the output column with the p-value that indicates - how extreme the change point is. If not specified, `pvalue` is used. + how extreme the change point is. If not specified, ``pvalue`` + is used. """ self._type_name = type_name self._pvalue_name = pvalue_name @@ -771,10 +898,10 @@ def __init__( self._inference_id: Optional[str] = None def with_(self, inference_id: str) -> "Completion": - """Continuation of the `COMPLETION` command. + """Continuation of the ``COMPLETION`` command. :param inference_id: The ID of the inference endpoint to use for the task. The - inference endpoint must be configured with the completion + inference endpoint must be configured with the ``completion`` task type. """ self._inference_id = inference_id @@ -863,7 +990,7 @@ def on(self, match_field: FieldType) -> "Enrich": :param match_field: The match field. ``ENRICH`` uses its value to look for records in the enrich index. If not specified, the match will be performed on the column with the same name as the - `match_field` defined in the enrich policy. + ``match_field`` defined in the enrich policy. 
""" self._match_field = match_field return self @@ -953,14 +1080,14 @@ class Fork(ESQLBase): def __init__( self, parent: ESQLBase, - fork1: ESQLBase, - fork2: Optional[ESQLBase] = None, - fork3: Optional[ESQLBase] = None, - fork4: Optional[ESQLBase] = None, - fork5: Optional[ESQLBase] = None, - fork6: Optional[ESQLBase] = None, - fork7: Optional[ESQLBase] = None, - fork8: Optional[ESQLBase] = None, + fork1: "Branch", + fork2: Optional["Branch"] = None, + fork3: Optional["Branch"] = None, + fork4: Optional["Branch"] = None, + fork5: Optional["Branch"] = None, + fork6: Optional["Branch"] = None, + fork7: Optional["Branch"] = None, + fork8: Optional["Branch"] = None, ): super().__init__(parent) self._branches = [fork1, fork2, fork3, fork4, fork5, fork6, fork7, fork8] @@ -977,6 +1104,39 @@ def _render_internal(self) -> str: return f"FORK {cmds}" +class Fuse(ESQLBase): + """Implementation of the ``FUSE`` processing command. + + This class inherits from :class:`ESQLBase <elasticsearch.esql.esql.ESQLBase>`, + to make it possible to chain all the commands that belong to an ES|QL query + in a single expression. + """ + + def __init__(self, parent: ESQLBase, method: Optional[str] = None): + super().__init__(parent) + self.method = method + self.by_columns: List[FieldType] = [] + self.options: Dict[str, Any] = {} + + def by(self, *columns: FieldType) -> "Fuse": + self.by_columns += list(columns) + return self + + def with_(self, **options: Any) -> "Fuse": + self.options = options + return self + + def _render_internal(self) -> str: + method = f" {self.method.upper()}" if self.method else "" + by = ( + " " + " ".join([f"BY {column}" for column in self.by_columns]) + if self.by_columns + else "" + ) + with_ = " WITH " + json.dumps(self.options) if self.options else "" + return f"FUSE{method}{by}{with_}" + + class Grok(ESQLBase): """Implementation of the ``GROK`` processing command. 
@@ -1040,7 +1200,7 @@ def __init__(self, parent: ESQLBase, lookup_index: IndexType): self._field: Optional[FieldType] = None def on(self, field: FieldType) -> "LookupJoin": - """Continuation of the `LOOKUP_JOIN` command. + """Continuation of the ``LOOKUP JOIN`` command. :param field: The field to join on. This field must exist in both your current query results and in the lookup index. If the field contains multi-valued @@ -1117,14 +1277,19 @@ def __init__( self._inference_id: Optional[str] = None def on(self, *fields: str) -> "Rerank": + """Continuation of the ``RERANK`` command. + + :param fields: One or more fields to use for reranking. These fields should + contain the text that the reranking model will evaluate. + """ self._fields = fields return self def with_(self, inference_id: str) -> "Rerank": - """Continuation of the `COMPLETION` command. + """Continuation of the ``RERANK`` command. :param inference_id: The ID of the inference endpoint to use for the task. The - inference endpoint must be configured with the completion + inference endpoint must be configured with the ``rerank`` task type. """ self._inference_id = inference_id @@ -1195,6 +1360,8 @@ class Stats(ESQLBase): in a single expression. """ + command_name = "STATS" + def __init__( self, parent: ESQLBase, @@ -1210,6 +1377,12 @@ def __init__( self._grouping_expressions: Optional[Tuple[ExpressionType, ...]] = None def by(self, *grouping_expressions: ExpressionType) -> "Stats": + """Continuation of the ``STATS`` and ``INLINE STATS`` commands. + + :param grouping_expressions: Expressions that output the values to group by. + If their names coincide with one of the computed + columns, that column will be ignored. 
+ """ self._grouping_expressions = grouping_expressions return self @@ -1221,13 +1394,25 @@ def _render_internal(self) -> str: ] else: exprs = [f"{self._format_expr(expr)}" for expr in self._expressions] - expression_separator = ",\n " + indent = " " * (len(self.command_name) + 3) + expression_separator = f",\n{indent}" by = ( "" if self._grouping_expressions is None - else f'\n BY {", ".join([f"{self._format_expr(expr)}" for expr in self._grouping_expressions])}' + else f'\n{indent}BY {", ".join([f"{self._format_expr(expr)}" for expr in self._grouping_expressions])}' ) - return f'STATS {expression_separator.join([f"{expr}" for expr in exprs])}{by}' + return f'{self.command_name} {expression_separator.join([f"{expr}" for expr in exprs])}{by}' + + +class InlineStats(Stats): + """Implementation of the ``INLINE STATS`` processing command. + + This class inherits from :class:`ESQLBase <elasticsearch.esql.esql.ESQLBase>`, + to make it possible to chain all the commands that belong to an ES|QL query + in a single expression. + """ + + command_name = "INLINE STATS" class Where(ESQLBase): diff --git a/elasticsearch/esql/functions.py b/elasticsearch/esql/functions.py index 162d7b95e..46cf5c9d9 100644 --- a/elasticsearch/esql/functions.py +++ b/elasticsearch/esql/functions.py @@ -38,6 +38,20 @@ def abs(number: ExpressionType) -> InstrumentedExpression: return InstrumentedExpression(f"ABS({_render(number)})") +def absent(field: ExpressionType) -> InstrumentedExpression: + """Returns true if the input expression yields no non-null values within the + current aggregation context. + + :param field: Expression that outputs values to be checked for absence. 
+ """ + return InstrumentedExpression(f"ABSENT({_render(field)})") + + +def absent_over_time(field: ExpressionType) -> InstrumentedExpression: + """Calculates the absence of a field in the output result over time range.""" + return InstrumentedExpression(f"ABSENT_OVER_TIME({_render(field)})") + + def acos(number: ExpressionType) -> InstrumentedExpression: """Returns the arccosine of `n` as an angle, expressed in radians. @@ -364,6 +378,11 @@ def exp(number: ExpressionType) -> InstrumentedExpression: return InstrumentedExpression(f"EXP({_render(number)})") +def first(value: ExpressionType, sort: ExpressionType) -> InstrumentedExpression: + """Calculates the earliest value of a field.""" + return InstrumentedExpression(f"FIRST({_render(value)}, {_render(sort)})") + + def first_over_time(field: ExpressionType) -> InstrumentedExpression: """The earliest value of a field, where recency determined by the `@timestamp` field. @@ -463,6 +482,11 @@ def kql(query: ExpressionType) -> InstrumentedExpression: return InstrumentedExpression(f"KQL({_render(query)})") +def last(value: ExpressionType, sort: ExpressionType) -> InstrumentedExpression: + """Calculates the latest value of a field.""" + return InstrumentedExpression(f"LAST({_render(value)}, {_render(sort)})") + + def last_over_time(field: ExpressionType) -> InstrumentedExpression: """The latest value of a field, where recency determined by the `@timestamp` field. @@ -697,6 +721,18 @@ def mv_concat(string: ExpressionType, delim: ExpressionType) -> InstrumentedExpr return InstrumentedExpression(f"MV_CONCAT({_render(string)}, {_render(delim)})") +def mv_contains( + superset: ExpressionType, subset: ExpressionType +) -> InstrumentedExpression: + """Checks if all values yielded by the second multivalue expression are present in the + values yielded by the first multivalue expression. Returns a boolean. Null values are + treated as an empty set. 
+ """ + return InstrumentedExpression( + f"MV_CONTAINS({_render(superset)}, {_render(subset)})" + ) + + def mv_count(field: ExpressionType) -> InstrumentedExpression: """Converts a multivalued expression into a single valued column containing a count of the number of values. @@ -894,6 +930,18 @@ def pow(base: ExpressionType, exponent: ExpressionType) -> InstrumentedExpressio return InstrumentedExpression(f"POW({_render(base)}, {_render(exponent)})") +def present(field: ExpressionType) -> InstrumentedExpression: + """Returns true if the input expression yields any non-null values within the current + aggregation context. Otherwise it returns false. + """ + return InstrumentedExpression(f"PRESENT({_render(field)})") + + +def present_over_time(field: ExpressionType) -> InstrumentedExpression: + """Calculates the presence of a field in the output result over time range.""" + return InstrumentedExpression(f"PRESENT_OVER_TIME({_render(field)})") + + def qstr( query: ExpressionType, options: ExpressionType = None ) -> InstrumentedExpression: @@ -1452,6 +1500,11 @@ def sum(number: ExpressionType) -> InstrumentedExpression: return InstrumentedExpression(f"SUM({_render(number)})") +def sum_over_time(field: ExpressionType) -> InstrumentedExpression: + """Calculates the sum over time value of a field.""" + return InstrumentedExpression(f"SUM_OVER_TIME({_render(field)})") + + def tan(angle: ExpressionType) -> InstrumentedExpression: """Returns the tangent of an angle. @@ -1483,6 +1536,17 @@ def term(field: ExpressionType, query: ExpressionType) -> InstrumentedExpression return InstrumentedExpression(f"TERM({_render(field)}, {_render(query)})") +def text_embedding( + text: ExpressionType, inference_id: ExpressionType +) -> InstrumentedExpression: + """Generates dense vector embeddings from text input using a specified inference endpoint. 
+ Use this function to generate query vectors for KNN searches against your vectorized data + or other dense vector based operations.""" + return InstrumentedExpression( + f"TEXT_EMBEDDING({_render(text)}, {_render(inference_id)})" + ) + + def top( field: ExpressionType, limit: ExpressionType, order: ExpressionType ) -> InstrumentedExpression: @@ -1596,6 +1660,22 @@ def to_double(field: ExpressionType) -> InstrumentedExpression: return InstrumentedExpression(f"TO_DOUBLE({_render(field)})") +def to_geohash(field: ExpressionType) -> InstrumentedExpression: + """Converts an input value to a geohash value. A string will only be successfully + converted if it respects the geohash format, as described for the geohash grid + aggregation. + """ + return InstrumentedExpression(f"TO_GEOHASH({_render(field)})") + + +def to_geohex(field: ExpressionType) -> InstrumentedExpression: + """Converts an input value to a geohex value. A string will only be successfully + converted if it respects the geohex format, as described for the geohex grid + aggregation. + """ + return InstrumentedExpression(f"TO_GEOHEX({_render(field)})") + + def to_geopoint(field: ExpressionType) -> InstrumentedExpression: """Converts an input value to a `geo_point` value. A string will only be successfully converted if it respects the WKT Point format. @@ -1616,6 +1696,14 @@ def to_geoshape(field: ExpressionType) -> InstrumentedExpression: return InstrumentedExpression(f"TO_GEOSHAPE({_render(field)})") +def to_geotile(field: ExpressionType) -> InstrumentedExpression: + """Converts an input value to a geotile value. A string will only be successfully + converted if it respects the geotile format, as described for the geotile grid + aggregation. + """ + return InstrumentedExpression(f"TO_GEOTILE({_render(field)})") + + def to_integer(field: ExpressionType) -> InstrumentedExpression: """Converts an input value to an integer value. 
If the input parameter is of a date type, its value will be interpreted as milliseconds since the diff --git a/test_elasticsearch/test_esql.py b/test_elasticsearch/test_esql.py index e33f288da..06073da5e 100644 --- a/test_elasticsearch/test_esql.py +++ b/test_elasticsearch/test_esql.py @@ -55,6 +55,22 @@ def test_show(): assert query.render() == "SHOW INFO" +def test_ts(): + query = ( + ESQL.ts("metrics") + .where("@timestamp >= now() - 1 day") + .stats("SUM(AVG_OVER_TIME(memory_usage))") + .by("host", "TBUCKET(1 hour)") + ) + assert ( + query.render() + == """TS metrics +| WHERE @timestamp >= now() - 1 day +| STATS SUM(AVG_OVER_TIME(memory_usage)) + BY host, TBUCKET(1 hour)""" + ) + + def test_change_point(): query = ( ESQL.row(key=list(range(1, 26))) @@ -276,6 +292,78 @@ def test_fork(): ) +def test_fuse(): + query = ( + ESQL.from_("books") + .metadata("_id", "_index", "_score") + .fork( + ESQL.branch().where('title:"Shakespeare"').sort("_score DESC"), + ESQL.branch().where('semantic_title:"Shakespeare"').sort("_score DESC"), + ) + .fuse() + ) + assert ( + query.render() + == """FROM books METADATA _id, _index, _score +| FORK ( WHERE title:"Shakespeare" | SORT _score DESC ) + ( WHERE semantic_title:"Shakespeare" | SORT _score DESC ) +| FUSE""" + ) + + query = ( + ESQL.from_("books") + .metadata("_id", "_index", "_score") + .fork( + ESQL.branch().where('title:"Shakespeare"').sort("_score DESC"), + ESQL.branch().where('semantic_title:"Shakespeare"').sort("_score DESC"), + ) + .fuse("linear") + ) + assert ( + query.render() + == """FROM books METADATA _id, _index, _score +| FORK ( WHERE title:"Shakespeare" | SORT _score DESC ) + ( WHERE semantic_title:"Shakespeare" | SORT _score DESC ) +| FUSE LINEAR""" + ) + + query = ( + ESQL.from_("books") + .metadata("_id", "_index", "_score") + .fork( + ESQL.branch().where('title:"Shakespeare"').sort("_score DESC"), + ESQL.branch().where('semantic_title:"Shakespeare"').sort("_score DESC"), + ) + .fuse("linear") + .by("title", 
"description") + ) + assert ( + query.render() + == """FROM books METADATA _id, _index, _score +| FORK ( WHERE title:"Shakespeare" | SORT _score DESC ) + ( WHERE semantic_title:"Shakespeare" | SORT _score DESC ) +| FUSE LINEAR BY title BY description""" + ) + + query = ( + ESQL.from_("books") + .metadata("_id", "_index", "_score") + .fork( + ESQL.branch().where('title:"Shakespeare"').sort("_score DESC"), + ESQL.branch().where('semantic_title:"Shakespeare"').sort("_score DESC"), + ) + .fuse("linear") + .with_(normalizer="minmax") + ) + assert ( + query.render() + == """FROM books METADATA _id, _index, _score +| FORK ( WHERE title:"Shakespeare" | SORT _score DESC ) + ( WHERE semantic_title:"Shakespeare" | SORT _score DESC ) +| FUSE LINEAR WITH {"normalizer": "minmax"}""" + ) + + def test_grok(): query = ( ESQL.row(a="2023-01-23T12:15:00.000Z 127.0.0.1 some.email@foo.com 42") @@ -322,6 +410,82 @@ def test_grok(): ) +def test_inline_stats(): + query = ( + ESQL.from_("employees") + .keep("emp_no", "languages", "salary") + .inline_stats(max_salary=functions.max(E("salary"))) + .by("languages") + ) + assert ( + query.render() + == """FROM employees +| KEEP emp_no, languages, salary +| INLINE STATS max_salary = MAX(salary) + BY languages""" + ) + + query = ( + ESQL.from_("employees") + .keep("emp_no", "languages", "salary") + .inline_stats(max_salary=functions.max(E("salary"))) + ) + assert ( + query.render() + == """FROM employees +| KEEP emp_no, languages, salary +| INLINE STATS max_salary = MAX(salary)""" + ) + + query = ( + ESQL.from_("employees") + .where("still_hired") + .keep("emp_no", "languages", "salary", "hire_date") + .eval(tenure=functions.date_diff("year", E("hire_date"), "2025-09-18T00:00:00")) + .drop("hire_date") + .inline_stats( + avg_salary=functions.avg(E("salary")), + count=functions.count(E("*")), + ) + .by("languages", "tenure") + ) + assert ( + query.render() + == """FROM employees +| WHERE still_hired +| KEEP emp_no, languages, salary, hire_date +| 
EVAL tenure = DATE_DIFF("year", hire_date, "2025-09-18T00:00:00") +| DROP hire_date +| INLINE STATS avg_salary = AVG(salary), + count = COUNT(*) + BY languages, tenure""" + ) + + query = ( + ESQL.from_("employees") + .keep("emp_no", "salary") + .inline_stats( + avg_lt_50=functions.round(functions.avg(E("salary"))).where( + E("salary") < 50000 + ), + avg_lt_60=functions.round(functions.avg(E("salary"))).where( + E("salary") >= 50000, E("salary") < 60000 + ), + avg_gt_60=functions.round(functions.avg(E("salary"))).where( + E("salary") >= 60000 + ), + ) + ) + assert ( + query.render() + == """FROM employees +| KEEP emp_no, salary +| INLINE STATS avg_lt_50 = ROUND(AVG(salary)) WHERE salary < 50000, + avg_lt_60 = ROUND(AVG(salary)) WHERE (salary >= 50000) AND (salary < 60000), + avg_gt_60 = ROUND(AVG(salary)) WHERE salary >= 60000""" + ) + + def test_keep(): query = ESQL.from_("employees").keep("emp_no", "first_name", "last_name", "height") assert (