Skip to content

Commit

Permalink
fix(ux): make string range selectors inclusive
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Mar 4, 2023
1 parent 325140f commit 7071669
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 37 deletions.
54 changes: 29 additions & 25 deletions ibis/expr/selectors.py
Expand Up @@ -50,10 +50,8 @@
from typing import Callable, Iterable, Mapping, Optional, Sequence, Union

from public import public
from typing_extensions import Annotated

import ibis.expr.datatypes as dt
import ibis.expr.rules as rlz
import ibis.expr.types as ir
from ibis import util
from ibis.common.annotations import attribute
Expand Down Expand Up @@ -402,48 +400,54 @@ def __precomputed_hash__(self) -> int:
return hash((self.__class__, (self.start, self.stop, self.step)))


class RangeSelector(Selector):
key: Union[str, int, Annotated[slice, rlz.coerced_to(HashableSlice)]]

def expand(self, table: ir.Table) -> Sequence[ir.Value]:
key = self.key

if isinstance(key, (str, int)):
return [table[key]]

start = key.start or 0
stop = key.stop or len(table.columns)
step = key.step or 1
class Sliceable(Singleton):
def __getitem__(self, key: str | int | slice) -> Predicate:
def pred(col):
import ibis.expr.analysis as an

schema = table.schema()
table = an.find_first_base_table(col.op())
schema = table.schema
idxs = schema._name_locs
num_names = len(schema)
colname = col.get_name()
colidx = idxs[colname]

if isinstance(start, str):
start = schema._name_locs[start]
if isinstance(key, str):
return key == colname
elif isinstance(key, int):
return key % num_names == colidx
else:
start = key.start or 0
stop = key.stop or num_names
step = key.step or 1

if isinstance(stop, str):
stop = schema._name_locs[stop]
if isinstance(start, str):
start = idxs[start]

return [table[i] for i in range(start, stop, step)]
if isinstance(stop, str):
stop = idxs[stop] + 1

return colidx in range(start, stop, step)

class Sliceable(Singleton):
def __getitem__(self, key: str | int | slice):
return RangeSelector(key=key)
return where(pred)


r = Sliceable()


@public
def first() -> Selector:
def first() -> Predicate:
"""Return the first column of a table."""
return r[0]


@public
def last() -> Selector:
def last() -> Predicate:
"""Return the last column of a table."""
return r[-1]


@public
def all() -> Predicate:
"""Return every column from a table."""
return r[:]
22 changes: 11 additions & 11 deletions ibis/expr/types/relations.py
Expand Up @@ -358,17 +358,17 @@ def __getitem__(self, what):
│ Adelie │ Torgersen │ 36.7 │ 19.3 │ 193 │ … │
└─────────┴───────────┴────────────────┴───────────────┴───────────────────┴───┘
>>> t[s.r["bill_length_mm":"body_mass_g"]].head()
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ bill_length_mm ┃ bill_depth_mm ┃ flipper_length_mm ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ float64 │ float64 │ int64 │
├────────────────┼───────────────┼───────────────────┤
│ 39.1 │ 18.7 │ 181 │
│ 39.5 │ 17.4 │ 186 │
│ 40.3 │ 18.0 │ 195 │
│ nan │ nan │ ∅ │
│ 36.7 │ 19.3 │ 193 │
└────────────────┴───────────────┴───────────────────┘
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━
┃ bill_length_mm ┃ bill_depth_mm ┃ flipper_length_mm ┃ body_mass_g ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━
│ float64 │ float64 │ int64 │ int64 │
├────────────────┼───────────────┼───────────────────┼─────────────
│ 39.1 │ 18.7 │ 181 │ 3750 │
│ 39.5 │ 17.4 │ 186 │ 3800 │
│ 40.3 │ 18.0 │ 195 │ 3250 │
│ nan │ nan │ ∅ │ ∅ │
│ 36.7 │ 19.3 │ 193 │ 3450 │
└────────────────┴───────────────┴───────────────────┴─────────────
"""
from ibis.expr.types.generic import Column
from ibis.expr.types.logical import BooleanValue
Expand Down
12 changes: 11 additions & 1 deletion ibis/tests/expr/test_selectors.py
Expand Up @@ -289,6 +289,10 @@ def test_if_any(penguins):
assert expr.equals(expected)


def test_negate_range(penguins):
assert penguins.select(~s.r[3:]).equals(penguins.select(0, 1, 2))


def test_string_range_start(penguins):
assert penguins.select(s.r["island":5]).equals(
penguins.select(penguins.columns[penguins.columns.index("island") : 5])
Expand All @@ -297,7 +301,13 @@ def test_string_range_start(penguins):

def test_string_range_end(penguins):
assert penguins.select(s.r[:"island"]).equals(
penguins.select(penguins.columns[: penguins.columns.index("island")])
penguins.select(penguins.columns[: penguins.columns.index("island") + 1])
)


def test_string_element(penguins):
assert penguins.select(~s.r["island"]).equals(
penguins.select([c for c in penguins.columns if c != "island"])
)


Expand Down

0 comments on commit 7071669

Please sign in to comment.