Skip to content

Commit

Permalink
feat(api): support list of strings and single strings in the across
Browse files Browse the repository at this point in the history
… selector
  • Loading branch information
cpcloud authored and kszucs committed Jun 20, 2023
1 parent 374e14b commit a6b60e7
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
13 changes: 9 additions & 4 deletions ibis/selectors.py
Expand Up @@ -399,22 +399,25 @@ def expand(self, table: ir.Table) -> Sequence[ir.Value]:

@public
def across(
selector: Selector,
selector: Selector | Iterable[str] | str,
func: Deferred
| Callable[[ir.Value], ir.Value]
| Mapping[str | None, Deferred | Callable[[ir.Value], ir.Value]],
names: str | Callable[[str, str | None], str] | None = None,
) -> Across:
"""Applies the same data transformation function across multiple columns.
"""Applies data transformations across multiple columns.
Parameters
----------
selector
An expression that selects columns on which the transformation function will be applied.
An expression that selects columns on which the transformation function
will be applied, an iterable of `str` column names or a single `str`
column name.
func
A function (or a dictionary of functions) to use to transform the data.
names
A lambda function or a format string to name the columns created by the transformation function.
A lambda function or a format string to name the columns created by the
transformation function.
Returns
-------
Expand Down Expand Up @@ -455,6 +458,8 @@ def across(
if names is None:
names = lambda col, fn: "_".join(filter(None, (col, fn)))
funcs = frozendict(func if isinstance(func, Mapping) else {None: func})
if not isinstance(selector, Selector):
selector = c(*util.promote_list(selector))
return Across(selector=selector, funcs=funcs, names=names)


Expand Down
12 changes: 12 additions & 0 deletions ibis/tests/expr/test_selectors.py
Expand Up @@ -301,6 +301,18 @@ def test_across_group_by_agg_with_grouped_selectors(penguins, expr_func):
assert expr.equals(expected)


def test_across_list(penguins):
expr = penguins.agg(s.across(["species", "island"], lambda c: c.count()))
expected = penguins.agg(species=_.species.count(), island=_.island.count())
assert expr.equals(expected)


def test_across_str(penguins):
expr = penguins.agg(s.across("species", lambda c: c.count()))
expected = penguins.agg(species=_.species.count())
assert expr.equals(expected)


def test_if_all(penguins):
expr = penguins.filter(s.if_all(s.numeric() & ~s.c("year"), _ > 5))
expected = penguins.filter(
Expand Down

0 comments on commit a6b60e7

Please sign in to comment.