diff --git a/opteryx/functions/__init__.py b/opteryx/functions/__init__.py index 5334bc13..e2bfc3e7 100644 --- a/opteryx/functions/__init__.py +++ b/opteryx/functions/__init__.py @@ -175,7 +175,6 @@ def _coalesce(*args): "LENGTH": _iterate_single_parameter(get_len), # LENGTH(str) -> int "UPPER": compute.utf8_upper, # UPPER(str) -> str "LOWER": compute.utf8_lower, # LOWER(str) -> str - "TRIM": compute.utf8_trim_whitespace, # TRIM(str) -> str "LEFT": string_functions.string_slicer_left, "RIGHT": string_functions.string_slicer_right, "REVERSE": compute.utf8_reverse, @@ -187,6 +186,9 @@ def _coalesce(*args): "ENDS_WITH": string_functions.ends_w, "SUBSTRING": string_functions.substring, "POSITION": _iterate_double_parameter(string_functions.position), + "TRIM": string_functions.trim, + "LTRIM": string_functions.ltrim, + "RTRIM": string_functions.rtrim, # HASHING & ENCODING "HASH": _iterate_single_parameter(lambda x: format(CityHash64(str(x)), "X")), diff --git a/opteryx/functions/string_functions.py b/opteryx/functions/string_functions.py index 3c4bc9a2..3af3e622 100644 --- a/opteryx/functions/string_functions.py +++ b/opteryx/functions/string_functions.py @@ -252,3 +252,21 @@ def position(sub, string): Returns the starting position of the first instance of substring in string. Positions start with 1. If not found, 0 is returned. 
# opteryx/functions/string_functions.py


def trim(*args):
    """
    Strip characters from both ends of each string.

    Parameters:
        args[0]: the strings to trim
        args[1]: (optional) the characters to strip; only its first element
            is read - presumably the value is a literal repeated for every
            row (TODO confirm against the caller)

    Returns:
        The result of `compute.utf8_trim_whitespace` when only the strings
        are supplied, otherwise the result of `compute.utf8_trim`.
    """
    # NOTE(review): the characters look to be treated as a set of individual
    # characters rather than as a substring - confirm against the
    # `utf8_trim` documentation.
    if len(args) != 1:
        strings, characters = args[0], args[1][0]
        return compute.utf8_trim(strings, characters)
    return compute.utf8_trim_whitespace(args[0])


def ltrim(*args):
    """
    Strip characters from the start of each string.

    Parameters:
        args[0]: the strings to trim
        args[1]: (optional) the characters to strip; only its first element
            is read - presumably the value is a literal repeated for every
            row (TODO confirm against the caller)

    Returns:
        The result of `compute.utf8_ltrim_whitespace` when only the strings
        are supplied, otherwise the result of `compute.utf8_ltrim`.
    """
    if len(args) != 1:
        strings, characters = args[0], args[1][0]
        return compute.utf8_ltrim(strings, characters)
    return compute.utf8_ltrim_whitespace(args[0])


def rtrim(*args):
    """
    Strip characters from the end of each string.

    Parameters:
        args[0]: the strings to trim
        args[1]: (optional) the characters to strip; only its first element
            is read - presumably the value is a literal repeated for every
            row (TODO confirm against the caller)

    Returns:
        The result of `compute.utf8_rtrim_whitespace` when only the strings
        are supplied, otherwise the result of `compute.utf8_rtrim`.
    """
    if len(args) != 1:
        strings, characters = args[0], args[1][0]
        return compute.utf8_rtrim(strings, characters)
    return compute.utf8_rtrim_whitespace(args[0])


# opteryx/managers/planner/logical/builders.py


def trim_string(branch, alias=None, key=None):
    """
    Build the FUNCTION node for a SQL `TRIM(...)` expression.

    Parameters:
        branch: the parsed expression; its "trim_what", "expr" and
            "trim_where" entries are read
        alias: passed through to the resulting node
        key: not used by this builder

    Returns:
        An ExpressionTreeNode of type NodeType.FUNCTION whose value is
        "LTRIM" for a "Leading" trim, "RTRIM" for a "Trailing" trim and
        "TRIM" otherwise. Its parameters are the built expression followed,
        when one was built, by the characters to strip.
    """
    # the sub-expressions are built in the same order as before
    characters = build(branch["trim_what"])
    target = build(branch["expr"])

    # a missing/empty side means trim both ends
    side = branch["trim_where"] or "Both"
    if side == "Trailing":
        function_name = "RTRIM"
    elif side == "Leading":
        function_name = "LTRIM"
    else:
        function_name = "TRIM"

    # the characters are only passed on when they were specified
    parameters = [target] if characters is None else [target, characters]

    return ExpressionTreeNode(
        NodeType.FUNCTION,
        value=function_name,
        parameters=parameters,
        alias=alias,
    )
"UnaryOp": unary_op, diff --git a/tests/sql_battery/test_battery_sql92.py b/tests/sql_battery/test_battery_sql92.py index 898f505d..689c8ff7 100644 --- a/tests/sql_battery/test_battery_sql92.py +++ b/tests/sql_battery/test_battery_sql92.py @@ -128,14 +128,14 @@ ("SELECT 'foo' || 'bar'", "E021-07"), ("SELECT LOWER ( 'foo' )", "E021-08"), ("SELECT UPPER ( 'foo' )", "E021-08"), -# ("SELECT TRIM ( 'foo' )", "E021-09"), -# ("SELECT TRIM ( 'foo' FROM 'foo' )", "E021-09"), -# ("SELECT TRIM ( BOTH 'foo' FROM 'foo' )", "E021-09"), + ("SELECT TRIM ( 'foo' )", "E021-09"), + ("SELECT TRIM ( 'foo' FROM 'foo' )", "E021-09"), + ("SELECT TRIM ( BOTH 'foo' FROM 'foo' )", "E021-09"), # ("SELECT TRIM ( BOTH FROM 'foo' )", "E021-09"), -# ("SELECT TRIM ( FROM 'foo' )", "E021-09"), -# ("SELECT TRIM ( LEADING 'foo' FROM 'foo' )", "E021-09"), + # ("SELECT TRIM ( FROM 'foo' )", "E021-09"), + ("SELECT TRIM ( LEADING 'foo' FROM 'foo' )", "E021-09"), # ("SELECT TRIM ( LEADING FROM 'foo' )", "E021-09"), -# ("SELECT TRIM ( TRAILING 'foo' FROM 'foo' )", "E021-09"), + ("SELECT TRIM ( TRAILING 'foo' FROM 'foo' )", "E021-09"), # ("SELECT TRIM ( TRAILING FROM 'foo' )", "E021-09"), ("SELECT POSITION ( 'foo' IN 'bar' )", "E021-11"), # ("SELECT POSITION ( 'foo' IN 'bar' USING CHARACTERS )", "E021-11"), diff --git a/tests/sql_battery/test_shapes_and_errors_battery.py b/tests/sql_battery/test_shapes_and_errors_battery.py index 28ef4838..82bff37d 100644 --- a/tests/sql_battery/test_shapes_and_errors_battery.py +++ b/tests/sql_battery/test_shapes_and_errors_battery.py @@ -666,6 +666,10 @@ ("SELECT COUNT(*), place FROM (SELECT CASE id WHEN 3 THEN 'Earth' WHEN 1 THEN 'Mercury' END as place FROM $planets) GROUP BY place HAVING place IS NULL;", 1, 2, None), ("SELECT COUNT(*), place FROM (SELECT CASE id WHEN 3 THEN 'Earth' WHEN 1 THEN 'Mercury' ELSE 'Elsewhere' END as place FROM $planets) GROUP BY place HAVING place IS NULL;", 0, 2, None), + ("SELECT TRIM(LEADING 'E' FROM name) FROM $planets;", 9, 1, None), + 
("SELECT * FROM $planets WHERE TRIM(TRAILING 'arth' FROM name) = 'E'", 1, 20, None), + ("SELECT * FROM $planets WHERE TRIM(TRAILING 'ahrt' FROM name) = 'E'", 1, 20, None), + # virtual dataset doesn't exist ("SELECT * FROM $RomanGods", None, None, DatasetNotFoundError), # disk dataset doesn't exist diff --git a/tests/sql_battery/tests/feature_tests.run_tests b/tests/sql_battery/tests/feature_tests.run_tests index 1221806d..61fef943 100644 --- a/tests/sql_battery/tests/feature_tests.run_tests +++ b/tests/sql_battery/tests/feature_tests.run_tests @@ -35,4 +35,11 @@ SELECT CASE WHEN id = 3 THEN 'Earth' WHEN id = 1 THEN 'Mercury' ELSE 'Elsewhere' SELECT CASE WHEN id = 3 THEN 'Earth' WHEN id = 1 THEN 'Mercury' END FROM $planets; SELECT * FROM $astronauts WHERE death_date > current_time - interval '50' YEAR; -SELECT * FROM $astronauts WHERE birth_date < current_time + interval '50' YEAR; \ No newline at end of file +SELECT * FROM $astronauts WHERE birth_date < current_time + interval '50' YEAR; + +SELECT LTRIM(' ABC'); +SELECT TRIM(LEADING '_' FROM '___ABC'); +SELECT TRIM('_' FROM '__init__'); +SELECT TRIM(BOTH '_' FROM '____dunder_'); +SELECT TRIM(TRAILING '__' FROM '_dunder_'); +SELECT RTRIM(' dunder ');