diff --git a/ibis/backends/flink/compiler.py b/ibis/backends/flink/compiler.py index 1f894487788e..db545a88a718 100644 --- a/ibis/backends/flink/compiler.py +++ b/ibis/backends/flink/compiler.py @@ -145,66 +145,58 @@ def _minimize_spec(start, end, spec): spec.args["end_side"] = None return spec - def visit_TumbleWindowingTVF(self, op, *, table, time_col, window_size, offset): + def visit_WindowAggregate( + self, + op, + *, + parent, + window_type, + time_col, + groups, + metrics, + window_size, + window_slide, + window_offset, + ): + if window_type == "tumble": + assert window_slide is None + args = [ - self.v[f"TABLE {table.this.sql(self.dialect)}"], + self.v[f"TABLE {parent.this.sql(self.dialect)}"], # `time_col` has the table _alias_, instead of the table, but it is # required to be bound to the table, this happens because of the # way we construct the op in the tumble API using bind # # perhaps there's a better way to deal with this self.f.descriptor(time_col.this), + window_slide, window_size, - offset, + window_offset, ] - return sg.select( - sge.Column( - this=STAR, table=sg.to_identifier(table.alias_or_name, quoted=True) - ) - ).from_( - self.f.table(self.f.tumble(*filter(None, args))).as_( - table.alias_or_name, quoted=True - ) - ) + window_func = getattr(self.f, window_type) - def visit_HopWindowingTVF( - self, op, *, table, time_col, window_size, window_slide, offset - ): - args = [ - self.v[f"TABLE {table.this.sql(self.dialect)}"], - self.f.descriptor(time_col.this), - window_slide, - window_size, - offset, - ] - return sg.select( - sge.Column( - this=STAR, table=sg.to_identifier(table.alias_or_name, quoted=True) - ) - ).from_( - self.f.table(self.f.hop(*filter(None, args))).as_( - table.alias_or_name, quoted=True - ) + # create column references to new columns generated by Flink in the output + window_start = sg.column( + "window_start", table=parent.alias_or_name, quoted=True ) - - def visit_CumulateWindowingTVF( - self, op, *, table, time_col, window_size, window_step, offset - ): - args = [ - self.v[f"TABLE {table.this.sql(self.dialect)}"], - self.f.descriptor(time_col.this), - window_step, - window_size, - offset, - ] - return sg.select( - sge.Column( - this=STAR, table=sg.to_identifier(table.alias_or_name, quoted=True) + window_end = sg.column("window_end", table=parent.alias_or_name, quoted=True) + + return ( + sg.select( + window_start, + window_end, + *self._cleanup_names(groups), + *self._cleanup_names(metrics), + copy=False, + ) + .from_( + self.f.table(window_func(*filter(None, args))).as_( + parent.alias_or_name, quoted=True + ) ) - ).from_( - self.f.table(self.f.cumulate(*filter(None, args))).as_( - table.alias_or_name, quoted=True + .group_by( + *self._generate_groups([window_start, window_end, *groups.values()]) ) ) diff --git a/ibis/backends/flink/tests/conftest.py b/ibis/backends/flink/tests/conftest.py index c0a5c2b05aab..03716387b1c4 100644 --- a/ibis/backends/flink/tests/conftest.py +++ b/ibis/backends/flink/tests/conftest.py @@ -13,6 +13,25 @@ if TYPE_CHECKING: from pyflink.table import StreamTableEnvironment +TEST_TABLES["functional_alltypes"] = ibis.schema( + { + "id": "int32", + "bool_col": "boolean", + "tinyint_col": "int8", + "smallint_col": "int16", + "int_col": "int32", + "bigint_col": "int64", + "float_col": "float32", + "double_col": "float64", + "date_string_col": "string", + "string_col": "string", + "timestamp_col": "timestamp(3)", # overriding the higher level fixture with precision because Flink's + # watermark must use a field of type 
TIMESTAMP(p) or TIMESTAMP_LTZ(p), where 'p' is from 0 to 3 + "year": "int32", + "month": "int32", + } +) + def get_table_env( local_env: bool, @@ -152,24 +171,7 @@ def awards_players_schema(): @pytest.fixture def functional_alltypes_schema(): - return ibis.schema( - { - "id": "int32", - "bool_col": "boolean", - "tinyint_col": "int8", - "smallint_col": "int16", - "int_col": "int32", - "bigint_col": "int64", - "float_col": "float32", - "double_col": "float64", - "date_string_col": "string", - "string_col": "string", - "timestamp_col": "timestamp(3)", # overriding the higher level fixture with precision because Flink's - # watermark must use a field of type TIMESTAMP(p) or TIMESTAMP_LTZ(p), where 'p' is from 0 to 3 - "year": "int32", - "month": "int32", - } - ) + return TEST_TABLES["functional_alltypes"] @pytest.fixture @@ -188,3 +190,33 @@ def generate_csv_configs(csv_file): } return generate_csv_configs + + +@pytest.fixture(scope="session") +def functional_alltypes_no_header(tmpdir_factory, data_dir): + file = tmpdir_factory.mktemp("data") / "functional_alltypes.csv" + with ( + open(data_dir / "csv" / "functional_alltypes.csv") as reader, + open(str(file), mode="w") as writer, + ): + reader.readline() # read the first line and discard it + for line in reader: + writer.write(line) + return file + + +@pytest.fixture(scope="session", autouse=True) +def functional_alltypes_with_watermark(con, functional_alltypes_no_header): + # create a streaming table with watermark for testing event-time based ops + t = con.create_table( + "functional_alltypes_with_watermark", + schema=TEST_TABLES["functional_alltypes"], + tbl_properties={ + "connector": "filesystem", + "path": functional_alltypes_no_header, + "format": "csv", + }, + watermark=ibis.watermark("timestamp_col", ibis.interval(seconds=10)), + temp=True, + ) + return t diff --git a/ibis/backends/flink/tests/test_compiler.py b/ibis/backends/flink/tests/test_compiler.py index cdfe9a996d92..49e649877ee8 100644 --- a/ibis/backends/flink/tests/test_compiler.py +++ b/ibis/backends/flink/tests/test_compiler.py @@ -1,13 +1,8 @@ from __future__ import annotations -from operator import methodcaller - import pytest from pytest import param -import ibis -from ibis.common.deferred import _ - def test_sum(simple_table, assert_sql): expr = simple_table.a.sum() @@ -103,48 +98,3 @@ def test_having(simple_table, assert_sql): .aggregate(simple_table.b.sum().name("b_sum")) ) assert_sql(expr) - - -@pytest.mark.parametrize( - "method", - [ - methodcaller("tumble", window_size=ibis.interval(minutes=15)), - methodcaller( - "hop", - window_size=ibis.interval(minutes=15), - window_slide=ibis.interval(minutes=1), - ), - methodcaller( - "cumulate", - window_size=ibis.interval(minutes=1), - window_step=ibis.interval(seconds=10), - ), - ], - ids=["tumble", "hop", "cumulate"], -) -def test_windowing_tvf(simple_table, method, assert_sql): - expr = method(simple_table.window_by(time_col=simple_table.i)) - assert_sql(expr) - - -def test_window_aggregation(simple_table, assert_sql): - expr = ( - simple_table.window_by(time_col=simple_table.i) - .tumble(window_size=ibis.interval(minutes=15)) - .group_by(["window_start", "window_end", "g"]) - .aggregate(mean=_.d.mean()) - ) - assert_sql(expr) - - -def test_window_topn(simple_table, assert_sql): - expr = simple_table.window_by(time_col="i").tumble( - window_size=ibis.interval(seconds=600), - )["a", "b", "c", "d", "g", "window_start", "window_end"] - expr = expr.mutate( - rownum=ibis.row_number().over( - group_by=["window_start", 
"window_end"], order_by=ibis.desc("g") - ) - ) - expr = expr[expr.rownum <= 3] - assert_sql(expr) diff --git a/ibis/backends/flink/tests/test_join.py b/ibis/backends/flink/tests/test_join.py deleted file mode 100644 index d9ea09a819d5..000000000000 --- a/ibis/backends/flink/tests/test_join.py +++ /dev/null @@ -1,166 +0,0 @@ -from __future__ import annotations - -import tempfile - -import numpy as np -import pandas as pd -import pytest - -import ibis -from ibis.backends.flink.tests.conftest import TestConf as tm -from ibis.backends.tests.errors import Py4JJavaError - - -@pytest.fixture(scope="module") -def left_tmp(): - return tempfile.NamedTemporaryFile() - - -@pytest.fixture(scope="module") -def right_tmp(): - return tempfile.NamedTemporaryFile() - - -@pytest.fixture(scope="module") -def left_tumble(con, left_tmp): - left_pd = pd.DataFrame( - { - "row_time": [ - pd.to_datetime("2020-04-15 12:02"), - pd.to_datetime("2020-04-15 12:06"), - pd.to_datetime("2020-04-15 12:03"), - ], - "num": [1, 2, 3], - "id": ["L1", "L2", "L3"], - } - ) - left_pd.to_csv(left_tmp.name, header=False, index=None) - - left_schema = ibis.schema( - { - "row_time": "timestamp(3)", - "num": "int32", - "id": "string", - } - ) - left = con.create_table( - "left", - schema=left_schema, - tbl_properties={ - "connector": "filesystem", - "path": left_tmp.name, - "format": "csv", - }, - watermark=ibis.watermark( - time_col="row_time", - allowed_delay=ibis.interval(seconds=1), - ), - ) - left_tumble = left.window_by(time_col=left.row_time).tumble( - window_size=ibis.interval(minutes=5) - ) - left_tumble = left_tumble[ - left_tumble - ] # this is required in order to avoid `row_time` being an ambiguous reference - return left_tumble - - -@pytest.fixture(scope="module") -def right_tumble(con, right_tmp): - right_pd = pd.DataFrame( - { - "row_time": [ - pd.to_datetime("2020-04-15 12:01"), - pd.to_datetime("2020-04-15 12:04"), - pd.to_datetime("2020-04-15 12:05"), - ], - "num": [2, 3, 4], - "id": ["R2", "R3", "R4"], - } - ) - right_pd.to_csv(right_tmp.name, header=False, index=None) - - right_schema = ibis.schema( - { - "row_time": "timestamp(3)", - "num": "int32", - "id": "string", - } - ) - right = con.create_table( - "right", - schema=right_schema, - tbl_properties={ - "connector": "filesystem", - "path": right_tmp.name, - "format": "csv", - }, - watermark=ibis.watermark( - time_col="row_time", - allowed_delay=ibis.interval(seconds=1), - ), - ) - right_tumble = right.window_by(time_col=right.row_time).tumble( - window_size=ibis.interval(minutes=5) - ) - right_tumble = right_tumble[ - right_tumble - ] # this is required in order to avoid `row_time` being an ambiguous reference - return right_tumble - - -@pytest.fixture(autouse=True, scope="module") -def remove_temp_files(left_tmp, right_tmp): - yield - left_tmp.close() - right_tmp.close() - - -@pytest.mark.xfail( - raises=(Py4JJavaError, AssertionError), - reason="subquery probably uses too much memory/resources, flink complains about network buffers", - strict=False, -) -def test_outer_join(left_tumble, right_tumble): - expr = left_tumble.join( - right_tumble, - ["num", "window_start", "window_end"], - how="outer", - lname="L_{name}", - rname="R_{name}", - ) - expr = expr[ - "L_num", - "L_id", - "R_num", - "R_id", - ibis.coalesce(expr["L_window_start"], expr["R_window_start"]).name( - "window_start" - ), - ibis.coalesce(expr["L_window_end"], expr["R_window_end"]).name("window_end"), - ] - result_df = expr.to_pandas() - - expected_df = pd.DataFrame.from_dict( - { - "L_num": {0: 
np.nan, 1: 1.0, 2: 3.0, 3: 2.0, 4: np.nan}, - "L_id": {0: None, 1: "L1", 2: "L3", 3: "L2", 4: None}, - "R_num": {0: 2.0, 1: np.nan, 2: 3.0, 3: np.nan, 4: 4.0}, - "R_id": {0: "R2", 1: None, 2: "R3", 3: None, 4: "R4"}, - "window_start": { - 0: pd.Timestamp("2020-04-15 12:00:00"), - 1: pd.Timestamp("2020-04-15 12:00:00"), - 2: pd.Timestamp("2020-04-15 12:00:00"), - 3: pd.Timestamp("2020-04-15 12:05:00"), - 4: pd.Timestamp("2020-04-15 12:05:00"), - }, - "window_end": { - 0: pd.Timestamp("2020-04-15 12:05:00"), - 1: pd.Timestamp("2020-04-15 12:05:00"), - 2: pd.Timestamp("2020-04-15 12:05:00"), - 3: pd.Timestamp("2020-04-15 12:10:00"), - 4: pd.Timestamp("2020-04-15 12:10:00"), - }, - } - ) - tm.assert_frame_equal(result_df, expected_df) diff --git a/ibis/backends/flink/tests/test_window.py b/ibis/backends/flink/tests/test_window.py index 9e0e480a603f..a8a20e4b5faa 100644 --- a/ibis/backends/flink/tests/test_window.py +++ b/ibis/backends/flink/tests/test_window.py @@ -5,6 +5,7 @@ from pytest import param import ibis +from ibis import _ from ibis.backends.tests.errors import Py4JJavaError @@ -53,13 +54,37 @@ def test_window_invalid_start_end(con, window): con.execute(expr) -def test_range_window(con, simple_table, assert_sql): +def test_range_window(simple_table, assert_sql): expr = simple_table.f.sum().over( range=(-ibis.interval(minutes=500), 0), order_by=simple_table.f ) assert_sql(expr) -def test_rows_window(con, simple_table, assert_sql): +def test_rows_window(simple_table, assert_sql): expr = simple_table.f.sum().over(rows=(-1000, 0), order_by=simple_table.f) assert_sql(expr) + + +def test_tumble_window_by_grouped_agg(con): + t = con.table("functional_alltypes_with_watermark") + expr = ( + t.window_by(time_col=t.timestamp_col) + .tumble(size=ibis.interval(seconds=30)) + .agg(by=["string_col"], avg=_.float_col.mean()) + ) + result = expr.to_pandas() + assert list(result.columns) == ["window_start", "window_end", "string_col", "avg"] + assert result.shape == (610, 4) + + +def test_tumble_window_by_ungrouped_agg(con): + t = con.table("functional_alltypes_with_watermark") + expr = ( + t.window_by(time_col=t.timestamp_col) + .tumble(size=ibis.interval(seconds=30)) + .agg(avg=_.float_col.mean()) + ) + result = expr.to_pandas() + assert list(result.columns) == ["window_start", "window_end", "avg"] + assert result.shape == (610, 3) diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py index ee4f74f6a29f..aca6e830faa7 100644 --- a/ibis/backends/pyspark/__init__.py +++ b/ibis/backends/pyspark/__init__.py @@ -2,6 +2,7 @@ import contextlib import os +import warnings from pathlib import Path from typing import TYPE_CHECKING, Any, Literal @@ -1044,7 +1045,11 @@ def to_kafka( @util.experimental def read_csv_dir( - self, path: str | Path, table_name: str | None = None, **kwargs: Any + self, + path: str | Path, + table_name: str | None = None, + watermark: Watermark | None = None, + **kwargs: Any, ) -> ir.Table: """Register a CSV directory as a table in the current database. @@ -1055,6 +1060,8 @@ def read_csv_dir( table_name An optional name to use for the created table. This defaults to a random generated name. + watermark + Watermark strategy for the table. kwargs Additional keyword arguments passed to PySpark loading function. 
https://spark.apache.org/docs/latest/api/python/reference/pyspark.ss/api/pyspark.sql.streaming.DataStreamReader.csv.html @@ -1072,10 +1079,17 @@ def read_csv_dir( spark_df = self._session.read.csv( path, inferSchema=inferSchema, header=header, **kwargs ) + if watermark is not None: + warnings.warn("Watermark is not supported in batch mode") elif self.mode == "streaming": spark_df = self._session.readStream.csv( path, inferSchema=inferSchema, header=header, **kwargs ) + if watermark is not None: + spark_df = spark_df.withWatermark( + watermark.time_col, + _interval_to_string(watermark.allowed_delay), + ) table_name = table_name or util.gen_name("read_csv_dir") spark_df.createOrReplaceTempView(table_name) @@ -1086,6 +1100,8 @@ def read_parquet_dir( self, path: str | Path, table_name: str | None = None, + watermark: Watermark | None = None, + schema: sch.Schema | None = None, **kwargs: Any, ) -> ir.Table: """Register a parquet file as a table in the current database. @@ -1097,6 +1113,10 @@ def read_parquet_dir( table_name An optional name to use for the created table. This defaults to a random generated name. + watermark + Watermark strategy for the table. + schema + Schema of the parquet source. kwargs Additional keyword arguments passed to PySpark. https://spark.apache.org/docs/latest/api/python/reference/pyspark.ss/api/pyspark.sql.streaming.DataStreamReader.parquet.html @@ -1109,9 +1129,22 @@ def read_parquet_dir( """ path = util.normalize_filename(path) if self.mode == "batch": - spark_df = self._session.read.parquet(path, **kwargs) + spark_df = self._session.read + if schema is not None: + spark_df = spark_df.schema(PySparkSchema.from_ibis(schema)) + spark_df = spark_df.parquet(path, **kwargs) + if watermark is not None: + warnings.warn("Watermark is not supported in batch mode") elif self.mode == "streaming": - spark_df = self._session.readStream.parquet(path, **kwargs) + spark_df = self._session.readStream + if schema is not None: + spark_df = spark_df.schema(PySparkSchema.from_ibis(schema)) + spark_df = spark_df.parquet(path, **kwargs) + if watermark is not None: + spark_df = spark_df.withWatermark( + watermark.time_col, + _interval_to_string(watermark.allowed_delay), + ) table_name = table_name or util.gen_name("read_parquet_dir") spark_df.createOrReplaceTempView(table_name) @@ -1119,7 +1152,11 @@ def read_parquet_dir( @util.experimental def read_json_dir( - self, path: str | Path, table_name: str | None = None, **kwargs: Any + self, + path: str | Path, + table_name: str | None = None, + watermark: Watermark | None = None, + **kwargs: Any, ) -> ir.Table: """Register a JSON file as a table in the current database. @@ -1130,6 +1167,8 @@ def read_json_dir( table_name An optional name to use for the created table. This defaults to a random generated name. + watermark + Watermark strategy for the table. kwargs Additional keyword arguments passed to PySpark loading function. 
https://spark.apache.org/docs/latest/api/python/reference/pyspark.ss/api/pyspark.sql.streaming.DataStreamReader.json.html @@ -1143,8 +1182,15 @@ def read_json_dir( path = util.normalize_filename(path) if self.mode == "batch": spark_df = self._session.read.json(path, **kwargs) + if watermark is not None: + warnings.warn("Watermark is not supported in batch mode") elif self.mode == "streaming": spark_df = self._session.readStream.json(path, **kwargs) + if watermark is not None: + spark_df = spark_df.withWatermark( + watermark.time_col, + _interval_to_string(watermark.allowed_delay), + ) table_name = table_name or util.gen_name("read_json_dir") spark_df.createOrReplaceTempView(table_name) diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py index 6cbd7f0796a9..0e7658edc0c6 100644 --- a/ibis/backends/pyspark/compiler.py +++ b/ibis/backends/pyspark/compiler.py @@ -472,9 +472,11 @@ def visit_TableUnnest( if overlaps_with_parent: column_alias_or_name = column.alias_or_name selcols.extend( - sg.column(col, table=parent_alias, quoted=quoted) - if col != column_alias_or_name - else computed_column + ( + sg.column(col, table=parent_alias, quoted=quoted) + if col != column_alias_or_name + else computed_column + ) for col in parent_schema.names ) else: @@ -517,3 +519,75 @@ def visit_TableUnnest( ) ) ) + + def visit_WindowAggregate( + self, + op, + *, + parent, + window_type, + time_col, + groups, + metrics, + window_size, + window_slide, + window_offset, + ): + if window_offset is not None: + raise com.UnsupportedOperationError( + "PySpark streaming does not support windowing with offset." + ) + if window_type == "tumble": + assert window_slide is None + + return ( + sg.select( + # the window column needs to be referred to directly as `window` rather + # than `t0`.`window` + sg.alias( + sge.Dot( + this=sge.Column(this="window"), + expression=sge.Identifier(this="start"), + ), + "window_start", + quoted=True, + ), + sg.alias( + sge.Dot( + this=sge.Column(this="window"), + expression=sge.Identifier(this="end"), + ), + "window_end", + quoted=True, + ), + *self._cleanup_names(groups), + *self._cleanup_names(metrics), + copy=False, + ) + .from_(parent.as_(parent.alias_or_name)) + .group_by( + *groups.values(), + self.f.window( + sg.column(time_col.this, table=parent.alias_or_name, quoted=True), + *filter( + None, + [ + self._format_window_interval(window_size), + self._format_window_interval(window_slide), + ], + ), + ), + copy=False, + ) + ) + + def _format_window_interval(self, expression): + if expression is None: + return None + unit = expression.args.get("unit").sql(dialect=self.dialect) + # skip plural conversion + unit = f" {unit}" if unit else "" + + this = expression.this.this # avoid quoting the interval as a string literal + + return f"{this}{unit}" diff --git a/ibis/backends/pyspark/tests/conftest.py b/ibis/backends/pyspark/tests/conftest.py index 7ffce3ac297e..5477471688bb 100644 --- a/ibis/backends/pyspark/tests/conftest.py +++ b/ibis/backends/pyspark/tests/conftest.py @@ -3,6 +3,7 @@ import os from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any +from unittest import mock import numpy as np import pandas as pd @@ -12,6 +13,8 @@ import ibis from ibis import util from ibis.backends.conftest import TEST_TABLES +from ibis.backends.pyspark import Backend +from ibis.backends.pyspark.datatypes import PySparkSchema from ibis.backends.tests.base import BackendTest from ibis.backends.tests.data import json_types, topk, win @@ -189,13 
+192,17 @@ def _load_data(self, **_: Any) -> None: s = self.connection._session num_partitions = 4 - sort_cols = {"functional_alltypes": "id"} + watermark_cols = {"functional_alltypes": "timestamp_col"} - for name in TEST_TABLES: + for name, schema in TEST_TABLES.items(): path = str(self.data_dir / "directory" / "parquet" / name) - t = s.readStream.parquet(path).repartition(num_partitions) - if (sort_col := sort_cols.get(name)) is not None: - t = t.sort(sort_col) + t = ( + s.readStream.schema(PySparkSchema.from_ibis(schema)) + .parquet(path) + .repartition(num_partitions) + ) + if (watermark_col := watermark_cols.get(name)) is not None: + t = t.withWatermark(watermark_col, "10 seconds") t.createOrReplaceTempView(name) @classmethod @@ -409,3 +416,29 @@ def temp_table_db(con, temp_database): yield temp_database, name assert name in con.list_tables(database=temp_database), name con.drop_table(name, database=temp_database) + + +@pytest.fixture(scope="session", autouse=True) +def default_session_fixture(): + with mock.patch.object(Backend, "write_to_memory", write_to_memory, create=True): + yield + + +def write_to_memory(self, expr, table_name): + if self.mode == "batch": + raise NotImplementedError + df = self._session.sql(expr.compile()) + df.writeStream.format("memory").queryName(table_name).start() + + +@pytest.fixture(autouse=True, scope="function") +def stop_active_jobs(con_streaming): + yield + for sq in con_streaming._session.streams.active: + sq.stop() + sq.awaitTermination() + + +@pytest.fixture +def awards_players_schema(): + return TEST_TABLES["awards_players"] diff --git a/ibis/backends/pyspark/tests/test_import_export.py b/ibis/backends/pyspark/tests/test_import_export.py index 9955359a9ac2..bc57ddf8728c 100644 --- a/ibis/backends/pyspark/tests/test_import_export.py +++ b/ibis/backends/pyspark/tests/test_import_export.py @@ -2,42 +2,13 @@ from operator import methodcaller from time import sleep -from unittest import mock import pandas as pd import pytest -from ibis.backends.conftest import TEST_TABLES -from ibis.backends.pyspark import Backend from ibis.backends.pyspark.datatypes import PySparkSchema -@pytest.fixture(scope="session", autouse=True) -def default_session_fixture(): - with mock.patch.object(Backend, "write_to_memory", write_to_memory, create=True): - yield - - -def write_to_memory(self, expr, table_name): - if self.mode == "batch": - raise NotImplementedError - df = self._session.sql(expr.compile()) - df.writeStream.format("memory").queryName(table_name).start() - - -@pytest.fixture(autouse=True, scope="function") -def stop_active_jobs(con_streaming): - yield - for sq in con_streaming._session.streams.active: - sq.stop() - sq.awaitTermination() - - -@pytest.fixture -def awards_players_schema(): - return TEST_TABLES["awards_players"] - - @pytest.mark.parametrize( "method", [ diff --git a/ibis/backends/pyspark/tests/test_window.py b/ibis/backends/pyspark/tests/test_window.py index f7d23e69f02c..b41a381ef3b9 100644 --- a/ibis/backends/pyspark/tests/test_window.py +++ b/ibis/backends/pyspark/tests/test_window.py @@ -1,9 +1,13 @@ from __future__ import annotations +from time import sleep + +import pandas as pd import pandas.testing as tm import pytest import ibis +from ibis import _ pyspark = pytest.importorskip("pyspark") @@ -87,3 +91,46 @@ def test_multiple_windows(t, spark_table, ibis_windows, spark_range): .toPandas() ) tm.assert_frame_equal(result, expected) + + +def test_tumble_window_by_grouped_agg(con_streaming, tmp_path): + t = 
con_streaming.table("functional_alltypes") + expr = ( + t.window_by(time_col=t.timestamp_col) + .tumble(size=ibis.interval(seconds=30)) + .agg(by=["string_col"], avg=_.float_col.mean()) + ) + path = tmp_path / "out" + con_streaming.to_csv_dir( + expr, + path=path, + options={"checkpointLocation": tmp_path / "checkpoint", "header": True}, + ) + sleep(5) + dfs = [pd.read_csv(f) for f in path.glob("*.csv")] + df = pd.concat([df for df in dfs if not df.empty]) + assert list(df.columns) == ["window_start", "window_end", "string_col", "avg"] + # [NOTE] The expected number of rows here is 7299 because when all the data is ready + # at once, no event is dropped as out of order. On the contrary, Flink discards all + # out-of-order events as late arrivals and only emits 610 windows. + assert df.shape == (7299, 4) + + +def test_tumble_window_by_ungrouped_agg(con_streaming, tmp_path): + t = con_streaming.table("functional_alltypes") + expr = ( + t.window_by(time_col=t.timestamp_col) + .tumble(size=ibis.interval(seconds=30)) + .agg(avg=_.float_col.mean()) + ) + path = tmp_path / "out" + con_streaming.to_csv_dir( + expr, + path=path, + options={"checkpointLocation": tmp_path / "checkpoint", "header": True}, + ) + sleep(5) + dfs = [pd.read_csv(f) for f in path.glob("*.csv")] + df = pd.concat([df for df in dfs if not df.empty]) + assert list(df.columns) == ["window_start", "window_end", "avg"] + assert df.shape == (7299, 3) diff --git a/ibis/expr/format.py b/ibis/expr/format.py index 25d52fc721c9..9edd18361d04 100644 --- a/ibis/expr/format.py +++ b/ibis/expr/format.py @@ -211,7 +211,6 @@ def fmt(op, **kwargs): @fmt.register(ops.Relation) -@fmt.register(ops.WindowingTVF) def _relation(op, parent=None, **kwargs): if parent is None: top = f"{op.__class__.__name__}\n" diff --git a/ibis/expr/operations/temporal_windows.py b/ibis/expr/operations/temporal_windows.py index 9687befb170a..ab7b8632f176 100644 --- a/ibis/expr/operations/temporal_windows.py +++ b/ibis/expr/operations/temporal_windows.py @@ -2,68 +2,39 @@ from __future__ import annotations -from typing import Optional +from typing import Literal, Optional from public import public import ibis.expr.datatypes as dt from ibis.common.annotations import attribute +from ibis.common.collections import FrozenOrderedDict from ibis.expr.operations.core import Column, Scalar # noqa: TCH001 -from ibis.expr.operations.relations import Relation +from ibis.expr.operations.relations import Relation, Unaliased from ibis.expr.schema import Schema @public -class WindowingTVF(Relation): - """Generic windowing table-valued function.""" - - # TODO(kszucs): rename to `parent` - table: Relation - time_col: Column[dt.Timestamp] # enforce timestamp column type here +class WindowAggregate(Relation): + parent: Relation + window_type: Literal["tumble", "hop"] + time_col: Unaliased[Column] + groups: FrozenOrderedDict[str, Unaliased[Column]] + metrics: FrozenOrderedDict[str, Unaliased[Scalar]] + window_size: Scalar[dt.Interval] + window_slide: Optional[Scalar[dt.Interval]] = None + window_offset: Optional[Scalar[dt.Interval]] = None @attribute def values(self): - return self.table.fields + return FrozenOrderedDict({**self.groups, **self.metrics}) - @property + @attribute def schema(self): - names = list(self.table.schema.names) - types = list(self.table.schema.types) - - # The return value of windowing TVF is a new relation that includes all columns - # of original relation as well as additional 3 columns named “window_start”, - # “window_end”, “window_time” to indicate the assigned 
window - - # TODO(kszucs): this looks like an implementation detail leaked from the - # flink backend - names.extend(["window_start", "window_end", "window_time"]) - # window_start, window_end, window_time have type TIMESTAMP(3) in Flink - types.extend([dt.timestamp(scale=3)] * 3) - - return Schema.from_tuples(list(zip(names, types))) - - -@public -class TumbleWindowingTVF(WindowingTVF): - """TUMBLE window table-valued function.""" - - window_size: Scalar[dt.Interval] - offset: Optional[Scalar[dt.Interval]] = None - - -@public -class HopWindowingTVF(WindowingTVF): - """HOP window table-valued function.""" - - window_size: Scalar[dt.Interval] - window_slide: Scalar[dt.Interval] - offset: Optional[Scalar[dt.Interval]] = None - - -@public -class CumulateWindowingTVF(WindowingTVF): - """CUMULATE window table-valued function.""" - - window_size: Scalar[dt.Interval] - window_step: Scalar[dt.Interval] - offset: Optional[Scalar[dt.Interval]] = None + field_pairs = { + "window_start": dt.timestamp, + "window_end": dt.timestamp, + **{k: v.dtype for k, v in self.groups.items()}, + **{k: v.dtype for k, v in self.metrics.items()}, + } + return Schema(field_pairs) diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index f577d4177d80..f66d6d64c0c0 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -22,6 +22,7 @@ from ibis.expr.types.core import Expr, _FixedTextJupyterMixin from ibis.expr.types.generic import Value, literal from ibis.expr.types.pretty import to_rich +from ibis.expr.types.temporal import TimestampColumn from ibis.selectors import Selector from ibis.util import deprecated @@ -58,10 +59,10 @@ def _regular_join_method( def f( # noqa: D417 self: ir.Table, right: ir.Table, - predicates: str - | Sequence[ - str | tuple[str | ir.Column, str | ir.Column] | ir.BooleanValue - ] = (), + predicates: ( + str + | Sequence[str | tuple[str | ir.Column, str | ir.Column] | ir.BooleanValue] + ) = (), *, lname: str = "", rname: str = "{name}_right", @@ -220,7 +221,6 @@ def bind(self, *args, **kwargs): args = () else: args = util.promote_list(args[0]) - # bind positional arguments values = [] for arg in args: @@ -2200,10 +2200,12 @@ def select( ) def relabel( self, - substitutions: Mapping[str, str] - | Callable[[str], str | None] - | str - | Literal["snake_case", "ALL_CAPS"], + substitutions: ( + Mapping[str, str] + | Callable[[str], str | None] + | str + | Literal["snake_case", "ALL_CAPS"] + ), ) -> Table: """Deprecated in favor of `Table.rename`.""" if isinstance(substitutions, Mapping): @@ -2212,11 +2214,13 @@ def relabel( def rename( self, - method: str - | Callable[[str], str | None] - | Literal["snake_case", "ALL_CAPS"] - | Mapping[str, str] - | None = None, + method: ( + str + | Callable[[str], str | None] + | Literal["snake_case", "ALL_CAPS"] + | Mapping[str, str] + | None + ) = None, /, **substitutions: str, ) -> Table: @@ -3076,17 +3080,19 @@ def describe( def join( left: Table, right: Table, - predicates: str - | Sequence[ + predicates: ( str - | ir.BooleanColumn - | Literal[True] - | Literal[False] - | tuple[ - str | ir.Column | ir.Deferred, - str | ir.Column | ir.Deferred, + | Sequence[ + str + | ir.BooleanColumn + | Literal[True] + | Literal[False] + | tuple[ + str | ir.Column | ir.Deferred, + str | ir.Column | ir.Deferred, + ] ] - ] = (), + ) = (), how: JoinKind = "inner", *, lname: str = "", @@ -3630,9 +3636,9 @@ def pivot_longer( *, names_to: str | Iterable[str] = "name", names_pattern: str | re.Pattern = r"(.+)", - names_transform: 
Callable[[str], ir.Value] - | Mapping[str, Callable[[str], ir.Value]] - | None = None, + names_transform: ( + Callable[[str], ir.Value] | Mapping[str, Callable[[str], ir.Value]] | None + ) = None, values_to: str = "value", values_transform: Callable[[ir.Value], ir.Value] | Deferred | None = None, ) -> Table: @@ -4606,23 +4612,19 @@ def relocate( return relocated - def window_by(self, time_col: ir.Value) -> WindowedTable: - """Create a windowing table-valued function (TVF) expression. + def window_by( + self, + time_col: str | ir.Value, + ) -> WindowedTable: + from ibis.expr.types.temporal_windows import WindowedTable - Windowing table-valued functions (TVF) assign rows of a table to windows - based on a time attribute column in the table. + time_col = next(iter(self.bind(time_col))) - Parameters - ---------- - time_col - Column of the table that will be mapped to windows. - - Returns - ------- - WindowedTable - WindowedTable expression. - """ - from ibis.expr.types.temporal_windows import WindowedTable + # validate time_col is a timestamp column + if not isinstance(time_col, TimestampColumn): + raise com.IbisInputError( + f"`time_col` must be a timestamp column, got {time_col.type()}" + ) return WindowedTable(self, time_col) diff --git a/ibis/expr/types/temporal_windows.py b/ibis/expr/types/temporal_windows.py index 74560d0f36a2..170b3b29aa99 100644 --- a/ibis/expr/types/temporal_windows.py +++ b/ibis/expr/types/temporal_windows.py @@ -1,135 +1,93 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from public import public import ibis.common.exceptions as com import ibis.expr.operations as ops import ibis.expr.types as ir -from ibis.expr.types.relations import bind +from ibis.common.collections import FrozenOrderedDict # noqa: TCH001 +from ibis.common.grounds import Concrete +from ibis.expr.operations.relations import Unaliased # noqa: TCH001 +from ibis.expr.types.relations import unwrap_aliases if TYPE_CHECKING: - from ibis.expr.types import Table + from collections.abc import Sequence @public -class WindowedTable: +class WindowedTable(Concrete): """An intermediate table expression to hold windowing information.""" - def __init__(self, table: ir.Table, time_col: ir.Value): - self.table = table - self.time_col = next(bind(table, time_col)) - - if self.time_col is None: + parent: ir.Table + time_col: ops.Column + window_type: Literal["tumble", "hop"] | None = None + window_size: ir.IntervalScalar | None = None + window_slide: ir.IntervalScalar | None = None + window_offset: ir.IntervalScalar | None = None + groups: FrozenOrderedDict[str, Unaliased[ops.Column]] | None = None + metrics: FrozenOrderedDict[str, Unaliased[ops.Column]] | None = None + + def __init__(self, time_col: ops.Column, **kwargs): + if time_col is None: raise com.IbisInputError( "Window aggregations require `time_col` as an argument" ) + super().__init__(time_col=time_col, **kwargs) def tumble( self, - window_size: ir.IntervalScalar, + size: ir.IntervalScalar, offset: ir.IntervalScalar | None = None, - ) -> Table: - """Compute a tumble table valued function. - - Tumbling windows have a fixed size and do not overlap. The size of the windows is - determined by `window_size`, optionally shifted by a duration specified by `offset`. - - Parameters - ---------- - window_size - Width of the tumbling windows. - offset - An optional parameter to specify the offset which window start should be shifted by. 
- - Returns - ------- - Table - Table expression after applying tumbling table-valued function. - """ - time_col = next(bind(self.table, self.time_col)) - return ops.TumbleWindowingTVF( - table=self.table, - time_col=time_col, - window_size=window_size, - offset=offset, - ).to_expr() + ) -> WindowedTable: + return self.copy(window_type="tumble", window_size=size, window_offset=offset) def hop( self, - window_size: ir.IntervalScalar, - window_slide: ir.IntervalScalar, + size: ir.IntervalScalar, + slide: ir.IntervalScalar, offset: ir.IntervalScalar | None = None, - ): - """Compute a hop table valued function. - - Hopping windows have a fixed size and can be overlapping if the slide is smaller than the - window size (in which case elements can be assigned to multiple windows). Hopping windows - are also known as sliding windows. The size of the windows is determined by `window_size`, - how frequently a hopping window is started is determined by `window_slide`, and windows can - be optionally shifted by a duration specified by `offset`. - - For example, you could have windows of size 10 minutes that slides by 5 minutes. With this, - you get every 5 minutes a window that contains the events that arrived during the last 10 minutes. - - Parameters - ---------- - window_size - Width of the hopping windows. - window_slide - The duration between the start of sequential hopping windows. - offset - An optional parameter to specify the offset which window start should be shifted by. - - Returns - ------- - Table - Table expression after applying hopping table-valued function. - """ - time_col = next(bind(self.table, self.time_col)) - return ops.HopWindowingTVF( - table=self.table, - time_col=time_col, - window_size=window_size, - window_slide=window_slide, - offset=offset, - ).to_expr() - - def cumulate( + ) -> WindowedTable: + return self.copy( + window_type="hop", + window_size=size, + window_slide=slide, + window_offset=offset, + ) + + def aggregate( self, - window_size: ir.IntervalScalar, - window_step: ir.IntervalScalar, - offset: ir.IntervalScalar | None = None, - ): - """Compute a cumulate table valued function. - - Cumulate windows don't have a fixed size and do overlap. Cumulate windows assign elements to windows - that cover rows within an initial interval of step size and expand to one more step size (keep window - start fixed) every step until the max window size. - - For example, you could have a cumulating window for 1 hour step and 1 day max size, and you will get - windows: [00:00, 01:00), [00:00, 02:00), [00:00, 03:00), …, [00:00, 24:00) for every day. - - Parameters - ---------- - window_size - Max width of the cumulating windows. - window_step - A duration specifying the increased window size between the end of sequential cumulating windows. - offset - An optional parameter to specify the offset which window start should be shifted by. - - Returns - ------- - Table - Table expression after applying cumulate table-valued function. 
- """ - time_col = next(bind(self.table, self.time_col)) - return ops.CumulateWindowingTVF( - table=self.table, - time_col=time_col, - window_size=window_size, - window_step=window_step, - offset=offset, + metrics: Sequence[ir.Scalar] | None = (), + by: str | ir.Value | Sequence[str] | Sequence[ir.Value] | None = (), + **kwargs: ir.Value, + ) -> ir.Table: + by = self.parent.bind(by) + metrics = self.parent.bind(metrics, **kwargs) + + by = unwrap_aliases(by) + metrics = unwrap_aliases(metrics) + + groups = dict(self.groups) if self.groups is not None else {} + groups.update(by) + + return ops.WindowAggregate( + self.parent, + self.window_type, + self.time_col, + groups=groups, + metrics=metrics, + window_size=self.window_size, + window_slide=self.window_slide, + window_offset=self.window_offset, ).to_expr() + + agg = aggregate + + def group_by( + self, *by: str | ir.Value | Sequence[str] | Sequence[ir.Value] + ) -> WindowedTable: + by = tuple(v for v in by if v is not None) + groups = self.parent.bind(*by) + groups = unwrap_aliases(groups) + return self.copy(groups=groups) diff --git a/ibis/tests/expr/test_temporal_windows.py b/ibis/tests/expr/test_temporal_windows.py index 928c26726d84..6df992ac1a18 100644 --- a/ibis/tests/expr/test_temporal_windows.py +++ b/ibis/tests/expr/test_temporal_windows.py @@ -1,76 +1,87 @@ from __future__ import annotations -import datetime +from operator import methodcaller import pytest import ibis import ibis.common.exceptions as com -import ibis.expr.datatypes as dt -import ibis.expr.operations as ops -from ibis import selectors as s -from ibis.common.annotations import ValidationError from ibis.common.deferred import _ -def test_tumble_tvf_schema(schema, table): - expr = table.window_by(time_col=table.i).tumble( - window_size=ibis.interval(minutes=15) - ) +@pytest.mark.parametrize( + "method", + [ + methodcaller("tumble", size=ibis.interval(minutes=15)), + methodcaller( + "hop", size=ibis.interval(minutes=15), slide=ibis.interval(minutes=1) + ), + ], + ids=["tumble", "hop"], +) +@pytest.mark.parametrize("by", ["g", _.g, ["g"]]) +def test_window_by_agg_schema(table, method, by): + expr = method(table.window_by(time_col=table.i)) + expr = expr.agg(by=by, a_sum=_.a.sum()) expected_schema = ibis.schema( - schema - + [ - ("window_start", dt.Timestamp(scale=3)), - ("window_end", dt.Timestamp(scale=3)), - ("window_time", dt.Timestamp(scale=3)), - ] + { + "window_start": "timestamp", + "window_end": "timestamp", + "g": "string", + "a_sum": "int64", + } ) assert expr.schema() == expected_schema -@pytest.mark.parametrize("wrong_type_window_size", ["60", 60]) -def test_create_tumble_tvf_with_wrong_scalar_type(table, wrong_type_window_size): - with pytest.raises(ValidationError, match=".* is not coercible to a .*"): - table.window_by(time_col=table.i).tumble(window_size=wrong_type_window_size) - - -def test_create_tumble_tvf_with_nonexistent_time_col(table): - with pytest.raises(com.IbisTypeError, match="Column .* is not found in table"): - table.window_by(time_col=table["nonexistent"]).tumble( - window_size=datetime.timedelta(seconds=60) - ) - - -def test_create_tumble_tvf_with_nonscalar_window_size(schema): - schema.append(("l", "interval")) - table = ibis.table(schema, name="table") - with pytest.raises(ValidationError, match=".* is not coercible to a .*"): - table.window_by(time_col=table.i).tumble(window_size=table.l) - - -def test_create_tumble_tvf_with_non_timestamp_time_col(table): - with pytest.raises(ValidationError, match=".* is not coercible to a .*"): - 
table.window_by(time_col=table.e).tumble(window_size=ibis.interval(minutes=15)) - - -def test_create_tumble_tvf_with_str_time_col(table): - expr = table.window_by(time_col="i").tumble(window_size=ibis.interval(minutes=15)) - assert isinstance(expr.op(), ops.TumbleWindowingTVF) - assert expr.op().time_col == table.i.op() - - -@pytest.mark.parametrize("deferred", [_["i"], _.i]) -def test_create_tumble_tvf_with_deferred_time_col(table, deferred): - expr = table.window_by(time_col=deferred.resolve(table)).tumble( - window_size=ibis.interval(minutes=15) +def test_window_by_with_non_timestamp_column(table): + with pytest.raises(com.IbisInputError): + table.window_by(time_col=table.a) + + +@pytest.mark.parametrize( + "method", + [ + methodcaller("tumble", size=ibis.interval(minutes=15)), + methodcaller( + "hop", size=ibis.interval(minutes=15), slide=ibis.interval(minutes=1) + ), + ], + ids=["tumble", "hop"], +) +@pytest.mark.parametrize("by", ["g", _.g, ["g"]]) +def test_window_by_grouped_agg(table, method, by): + expr = method(table.window_by(time_col=table.i)) + expr = expr.group_by(by).agg(a_sum=_.a.sum()) + expected_schema = ibis.schema( + { + "window_start": "timestamp", + "window_end": "timestamp", + "g": "string", + "a_sum": "int64", + } ) - assert isinstance(expr.op(), ops.TumbleWindowingTVF) - assert expr.op().time_col == table.i.op() + assert expr.schema() == expected_schema -def test_create_tumble_tvf_with_selector_time_col(table): - expr = table.window_by(time_col=s.c("i")).tumble( - window_size=ibis.interval(minutes=15) +@pytest.mark.parametrize( + "method", + [ + methodcaller("tumble", size=ibis.interval(minutes=15)), + methodcaller( + "hop", size=ibis.interval(minutes=15), slide=ibis.interval(minutes=1) + ), + ], + ids=["tumble", "hop"], +) +def test_window_by_global_agg(table, method): + expr = method(table.window_by(time_col=table.i)) + expr = expr.agg(a_sum=_.a.sum()) + expected_schema = ibis.schema( + { + "window_start": "timestamp", + "window_end": "timestamp", + "a_sum": "int64", + } ) - assert isinstance(expr.op(), ops.TumbleWindowingTVF) - assert expr.op().time_col == table.i.op() + assert expr.schema() == expected_schema
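
Usage sketch for the windowed-aggregation API this patch introduces, following the call pattern exercised in the new tests. Illustrative only: the `events` table and its `event_time`, `user`, and `amount` columns are hypothetical; the one requirement carried over from the patch is that `time_col` resolve to a timestamp column (anything else raises `IbisInputError`).

import ibis
from ibis import _

events = ibis.table(
    {"event_time": "timestamp", "user": "string", "amount": "float64"},
    name="events",
)

# Tumbling window: fixed-size, non-overlapping 15-minute buckets.
tumbled = (
    events.window_by(time_col=events.event_time)
    .tumble(size=ibis.interval(minutes=15))
    .agg(by=["user"], total=_.amount.sum())
)

# Hopping window: 15-minute windows that start every minute.
# group_by().agg() is equivalent to passing `by=` to agg().
hopped = (
    events.window_by(time_col="event_time")
    .hop(size=ibis.interval(minutes=15), slide=ibis.interval(minutes=1))
    .group_by("user")
    .agg(avg_amount=_.amount.mean())
)

# Both expressions resolve to a WindowAggregate relation whose schema is
# window_start, window_end, the group keys, then the metrics; the Flink
# backend compiles it to the TUMBLE/HOP TVFs and the PySpark backend to a
# window() grouping.
print(tumbled.schema())
print(hopped.schema())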
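
A second sketch, for the watermark support added to the PySpark directory readers (read_csv_dir / read_parquet_dir / read_json_dir). The SparkSession setup, path, and column names are hypothetical, and the `mode="streaming"` connection argument is assumed from the streaming fixtures in the test suite; in batch mode the patch only warns and ignores the watermark.

import ibis
from pyspark.sql import SparkSession

session = SparkSession.builder.getOrCreate()
con = ibis.pyspark.connect(session, mode="streaming")  # assumed streaming-mode connection

t = con.read_parquet_dir(
    "/data/events",  # hypothetical directory of parquet files
    table_name="events",
    # the new `schema` parameter avoids schema inference on a streaming file source
    schema=ibis.schema({"event_time": "timestamp", "amount": "float64"}),
    # applied via DataFrame.withWatermark before the temp view is registered
    watermark=ibis.watermark(
        time_col="event_time",
        allowed_delay=ibis.interval(seconds=10),
    ),
)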