diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py
index 1903c8fd9..a66b178dc 100644
--- a/dask_sql/physical/rex/core/call.py
+++ b/dask_sql/physical/rex/core/call.py
@@ -535,6 +535,22 @@ def trim(self, s, search):
         return strip_call(search)
 
 
+class ReplaceOperation(Operation):
+    """The replace operator (replace occurrences of pattern in a string)"""
+
+    def __init__(self):
+        super().__init__(self.replace)
+
+    def replace(self, s, pat, repl):
+        """Replace each literal occurrence of pat in s with repl"""
+        if is_frame(s):
+            # SQL REPLACE matches substrings literally; without regex=False the
+            # string accessor may interpret ``pat`` as a regular expression.
+            return s.str.replace(pat, repl, regex=False)
+
+        return s.replace(pat, repl)
+
+
 class OverlayOperation(Operation):
     """The overlay operator (replace string according to positions)"""
 
@@ -965,6 +981,7 @@ class RexCallPlugin(BaseRexPlugin):
         "substr": SubStringOperation(),
         "substring": SubStringOperation(),
         "initcap": TensorScalarOperation(lambda x: x.str.title(), lambda x: x.title()),
+        "replace": ReplaceOperation(),
         # date/time operations
         "extract": ExtractOperation(),
         "localtime": Operation(lambda *args: pd.Timestamp.now()),
diff --git a/tests/integration/test_rex.py b/tests/integration/test_rex.py
index 655ff69de..b7d455fe3 100644
--- a/tests/integration/test_rex.py
+++ b/tests/integration/test_rex.py
@@ -522,7 +522,10 @@ def test_string_functions(c, gpu):
             SUBSTR(a, 3, 6) AS s,
             INITCAP(a) AS t,
             INITCAP(UPPER(a)) AS u,
-            INITCAP(LOWER(a)) AS v
+            INITCAP(LOWER(a)) AS v,
+            REPLACE(a, 'r', 'l') as w,
+            REPLACE('Another String', 'th', 'b') as x,
+            REPLACE(a, '.', '_') as y
         FROM
             {input_table}
         """
@@ -555,6 +558,9 @@ def test_string_functions(c, gpu):
             "t": ["A Normal String"],
             "u": ["A Normal String"],
             "v": ["A Normal String"],
+            "w": ["a nolmal stling"],
+            "x": ["Anober String"],
+            "y": ["a normal string"],
         }
     )
 
diff --git a/tests/unit/test_call.py b/tests/unit/test_call.py
index 0075c5cb5..05b116af8 100644
--- a/tests/unit/test_call.py
+++ b/tests/unit/test_call.py
@@ -182,6 +182,9 @@ def test_string_operations():
     assert ops_mapping["substring"](a, 2) == " normal string"
     assert ops_mapping["substring"](a, 2, 2) == " n"
     assert ops_mapping["initcap"](a) == "A Normal String"
+    assert ops_mapping["replace"](a, "nor", "") == "a mal string"
+    assert ops_mapping["replace"](a, "normal", "new") == "a new string"
+    assert ops_mapping["replace"]("hello", "", "w") == "whwewlwlwow"
 
 
 def test_dates():