Skip to content

Commit

Permalink
fix(substitute): allow mappings with None keys
Browse files Browse the repository at this point in the history
We don't need to sort this dictionary anymore.
  • Loading branch information
gforsyth authored and cpcloud committed Nov 21, 2023
1 parent 6e3219f commit 4b28ff1
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 6 deletions.
16 changes: 16 additions & 0 deletions ibis/backends/tests/test_generic.py
Expand Up @@ -1608,3 +1608,19 @@ def test_sample_with_seed(backend):
df1 = expr.to_pandas()
df2 = expr.to_pandas()
backend.assert_frame_equal(df1, df2)


@pytest.mark.broken(
["dask"], reason="implementation somehow differs from pandas", raises=ValueError
)
def test_substitute(backend):
val = "400"
t = backend.functional_alltypes
expr = (
t.string_col.nullif("1")
.substitute({None: val})
.name("subs")
.value_counts()
.filter(lambda t: t.subs == val)
)
assert expr["subs_count"].execute()[0] == t.count().execute() // 10
15 changes: 11 additions & 4 deletions ibis/expr/types/generic.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
from typing import TYPE_CHECKING, Any, Iterable, Literal, Sequence

from public import public
Expand Down Expand Up @@ -683,12 +684,18 @@ def substitute(
│ torg │ 52 │
└────────┴──────────────┘
"""
expr = self.case()
if isinstance(value, dict):
for k, v in sorted(value.items()):
expr = expr.when(k, v)
expr = ibis.case()
try:
null_replacement = value.pop(None)
except KeyError:
pass
else:
expr = expr.when(self.isnull(), null_replacement)
for k, v in value.items():
expr = expr.when(self == k, v)
else:
expr = expr.when(value, replacement)
expr = self.case().when(value, replacement)

return expr.else_(else_ if else_ is not None else self).end()

Expand Down
12 changes: 10 additions & 2 deletions ibis/tests/expr/test_value_exprs.py
Expand Up @@ -792,13 +792,21 @@ def test_substitute_dict():

result = table.foo.substitute(subs)
expected = (
table.foo.case().when("a", "one").when("b", table.bar).else_(table.foo).end()
ibis.case()
.when(table.foo == "a", "one")
.when(table.foo == "b", table.bar)
.else_(table.foo)
.end()
)
assert_equal(result, expected)

result = table.foo.substitute(subs, else_=ibis.NA)
expected = (
table.foo.case().when("a", "one").when("b", table.bar).else_(ibis.NA).end()
ibis.case()
.when(table.foo == "a", "one")
.when(table.foo == "b", table.bar)
.else_(ibis.NA)
.end()
)
assert_equal(result, expected)

Expand Down

0 comments on commit 4b28ff1

Please sign in to comment.