Skip to content

Commit

Permalink
refactor WindowedTable class and add group_by method
Browse files Browse the repository at this point in the history
  • Loading branch information
Chloe He committed May 31, 2024
1 parent ae7bd82 commit fc5d2fd
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 21 deletions.
1 change: 0 additions & 1 deletion ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ def bind(self, *args, **kwargs):
args = ()
else:
args = util.promote_list(args[0])

# bind positional arguments
values = []
for arg in args:
Expand Down
55 changes: 37 additions & 18 deletions ibis/expr/types/temporal_windows.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,76 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal

from public import public

import ibis.common.exceptions as com
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.common.collections import FrozenOrderedDict # noqa: TCH001
from ibis.common.grounds import Concrete
from ibis.expr.operations.relations import Unaliased # noqa: TCH001
from ibis.expr.types.relations import unwrap_aliases

if TYPE_CHECKING:
from collections.abc import Sequence


@public
class WindowedTable:
class WindowedTable(Concrete):
"""An intermediate table expression to hold windowing information."""

def __init__(self, parent: ir.Table, time_col: ops.Column):
parent: ir.Table
time_col: ops.Column
window_type: Literal["tumble", "hop"] | None = None
window_size: ir.IntervalScalar | None = None
window_slide: ir.IntervalScalar | None = None
window_offset: ir.IntervalScalar | None = None
groups: FrozenOrderedDict[str, Unaliased[ops.Column]] | None = None
metrics: FrozenOrderedDict[str, Unaliased[ops.Column]] | None = None

def __init__(self, time_col: ops.Column, **kwargs):
    """Validate ``time_col`` and hand field initialisation to the base class.

    Parameters
    ----------
    time_col
        Column forwarded to the base initializer as ``time_col``; must not
        be ``None``.
    **kwargs
        Remaining field values, forwarded unchanged to the base initializer.

    Raises
    ------
    com.IbisInputError
        If ``time_col`` is ``None``.
    """
    if time_col is None:
        raise com.IbisInputError(
            "Window aggregations require `time_col` as an argument"
        )
    # Fields are populated by the base-class initializer. The former direct
    # assignments (`self.parent = parent`, `self.time_col = time_col`) are
    # gone: `parent` is no longer a parameter of this method, so assigning
    # it here would fail with a NameError.
    super().__init__(time_col=time_col, **kwargs)

def tumble(
    self,
    size: ir.IntervalScalar,
    offset: ir.IntervalScalar | None = None,
) -> WindowedTable:
    """Return a copy of this windowed table configured as a tumble window.

    Parameters
    ----------
    size
        Stored as ``window_size`` on the returned copy.
    offset
        Stored as ``window_offset`` on the returned copy. Defaults to
        ``None``.

    Returns
    -------
    WindowedTable
        The result of ``self.copy(...)`` with ``window_type="tumble"``.
    """
    # Derive a new object instead of assigning attributes on ``self`` and
    # returning it. ``window_slide`` is reset explicitly so that a tumble
    # derived from a hop does not carry over a stale slide value.
    return self.copy(
        window_type="tumble",
        window_size=size,
        window_slide=None,
        window_offset=offset,
    )

def hop(
    self,
    size: ir.IntervalScalar,
    slide: ir.IntervalScalar,
    offset: ir.IntervalScalar | None = None,
) -> WindowedTable:
    """Return a copy of this windowed table configured as a hop window.

    Parameters
    ----------
    size
        Stored as ``window_size`` on the returned copy.
    slide
        Stored as ``window_slide`` on the returned copy.
    offset
        Stored as ``window_offset`` on the returned copy. Defaults to
        ``None``.

    Returns
    -------
    WindowedTable
        The result of ``self.copy(...)`` with ``window_type="hop"``.
    """
    # Derive a new object instead of assigning attributes on ``self`` and
    # returning it; the receiver is left untouched.
    return self.copy(
        window_type="hop",
        window_size=size,
        window_slide=slide,
        window_offset=offset,
    )

def aggregate(
self,
metrics: Sequence[ir.Scalar] | None = (),
by: Sequence[ir.Value] | None = (),
by: str | ir.Value | Sequence[str] | Sequence[ir.Value] | None = (),
**kwargs: ir.Value,
) -> ir.Table:
groups = self.parent.bind(by)
by = self.parent.bind(by)
metrics = self.parent.bind(metrics, **kwargs)

groups = unwrap_aliases(groups)
by = unwrap_aliases(by)
metrics = unwrap_aliases(metrics)

groups = dict(self.groups) if self.groups is not None else {}
groups.update(by)

return ops.WindowAggregate(
self.parent,
self.window_type,
Expand All @@ -72,3 +83,11 @@ def aggregate(
).to_expr()

agg = aggregate

def group_by(
    self, *by: str | ir.Value | Sequence[str] | Sequence[ir.Value]
) -> WindowedTable:
    """Return a copy of this windowed table with its grouping keys set.

    ``None`` entries in ``by`` are dropped. The remaining entries are bound
    against ``self.parent``, passed through ``unwrap_aliases``, and the
    result is stored as ``groups`` on the returned copy.
    """
    keys = [key for key in by if key is not None]
    bound = self.parent.bind(*keys)
    return self.copy(groups=unwrap_aliases(bound))
30 changes: 28 additions & 2 deletions ibis/tests/expr/test_temporal_windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
],
ids=["tumble", "hop"],
)
def test_window_by_agg_schema(table, method):
@pytest.mark.parametrize("by", ["g", _.g, ["g"]])
def test_window_by_agg_schema(table, method, by):
expr = method(table.window_by(time_col=table.i))
expr = expr.agg(by=["g"], a_sum=_.a.sum())
expr = expr.agg(by=by, a_sum=_.a.sum())
expected_schema = ibis.schema(
{
"window_start": "timestamp",
Expand All @@ -36,3 +37,28 @@ def test_window_by_agg_schema(table, method):
def test_window_by_with_non_timestamp_column(table):
    """Check that ``window_by`` rejects ``table.a`` as the time column.

    Per the test name, ``a`` is a non-timestamp column of the ``table``
    fixture (the fixture itself is defined outside this view).
    """
    with pytest.raises(com.IbisInputError):
        table.window_by(time_col=table.a)


@pytest.mark.parametrize(
    "method",
    [
        methodcaller("tumble", size=ibis.interval(minutes=15)),
        methodcaller(
            "hop", size=ibis.interval(minutes=15), slide=ibis.interval(minutes=1)
        ),
    ],
    ids=["tumble", "hop"],
)
@pytest.mark.parametrize("by", ["g", _.g, ["g"]])
def test_window_by_group_by_agg(table, method, by):
    """Check the schema produced by ``group_by(...)`` followed by ``agg(...)``.

    Runs for both tumble and hop windows, with the grouping key given as a
    column name, a deferred expression, and a one-element list.
    """
    want = ibis.schema(
        {
            "window_start": "timestamp",
            "window_end": "timestamp",
            "g": "string",
            "a_sum": "int64",
        }
    )
    windowed = method(table.window_by(time_col=table.i))
    aggregated = windowed.group_by(by).agg(a_sum=_.a.sum())
    assert aggregated.schema() == want

0 comments on commit fc5d2fd

Please sign in to comment.