Skip to content

Commit

Permalink
Merge pull request #238 from mabel-dev/FEATURE/#237
Browse files Browse the repository at this point in the history
Feature/#237
  • Loading branch information
joocer committed Jun 26, 2022
2 parents 633aa7b + 0a448c3 commit 5a47eb6
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 38 deletions.
1 change: 1 addition & 0 deletions docs/Release Notes/Change Log.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- [[#226](https://github.com/mabel-dev/opteryx/issues/226)] Implement `DATE_TRUNC` function. ([@joocer](https://github.com/joocer))
- [[#230](https://github.com/mabel-dev/opteryx/issues/230)] Allow addressing fields as numbers. ([@joocer](https://github.com/joocer))
- [[#234](https://github.com/mabel-dev/opteryx/issues/234)] Implement `SEARCH` function. ([@joocer](https://github.com/joocer))
- [[#237](https://github.com/mabel-dev/opteryx/issues/237)] Implement `COALESCE` function. ([@joocer](https://github.com/joocer))


**Changed**
Expand Down
1 change: 1 addition & 0 deletions docs/SQL Reference/06 Functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ Function | Description | Exampl
------------------- | ------------------------------------------------- | ---------------------------
`BOOLEAN(str)` | Convert input to a Boolean | `BOOLEAN('true') -> True`
`CAST(any AS type)` | Cast any to type, calls `type(any)` | `CAST(state AS BOOLEAN) -> False`
`COALESCE(args)` | Return the first item from args which is not None | `COALESCE(university, high_school) -> 'Olympia High'`
`GET(list, n)` :fontawesome-solid-asterisk: | Gets the nth element in a list, also `list[n]` | `GET(names, 2) -> 'Joe'`
`GET(struct, a)` :fontawesome-solid-asterisk: | Gets the element called 'a' from a struct, also `struct[a]` | `GET(dict, 'key') -> 'value'`
`HASH(str)` | Calculate the [CityHash](https://opensource.googleblog.com/2011/04/introducing-cityhash.html) (64 bit) of a value | `HASH('hello') -> 'B48BE5A931380CE8'`
Expand Down
1 change: 1 addition & 0 deletions opteryx/engine/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def get_len(obj):
"LIST_CONTAINS_ANY": _iterate_double_parameter(other_functions._list_contains_any),
"LIST_CONTAINS_ALL": _iterate_double_parameter(other_functions._list_contains_all),
"SEARCH": other_functions._search,
"COALESCE": other_functions._coalesce,
# NUMERIC
"ROUND": compute.round,
"FLOOR": compute.floor,
Expand Down
29 changes: 28 additions & 1 deletion opteryx/engine/functions/other_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# limitations under the License.

import numpy
import pyarrow

from pyarrow import compute

Expand Down Expand Up @@ -45,6 +44,8 @@ def _search(array, item):
return None
if array_type == str:
# return True if the value is in the string
# find_substring returns -1 or an index, we need to convert this to a boolean
# and then to a list of lists for pyarrow
res = compute.find_substring(array, pattern=item, ignore_case=True)
res = ~(res.to_numpy() < 0)
return ([r] for r in res)
Expand All @@ -58,3 +59,29 @@ def _search(array, item):
for record in array
)
return [False] * array.shape[0]


def _coalesce(*args):
def _make_list(arr, length):
if not isinstance(arr, numpy.ndarray):
return [arr] * length

cycles = max([0] + [len(a) for a in args if isinstance(a, numpy.ndarray)])
if cycles == 0:
raise Exception("something has gone wrong")

my_args = list(args)

for i in range(len(args)):
if not isinstance(args[i], numpy.ndarray):
my_args[i] = _make_list(args[i], cycles)

def inner_coalesce(iterable):
for element in iterable:
print(element)
if (element is not None) and (element == element): # nosemgrep
return element
return None

for row in zip(*my_args):
yield [inner_coalesce(row)]
28 changes: 16 additions & 12 deletions opteryx/engine/planner/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def __init__(self, statistics, reader, cache, partition_scheme):
self._cache = cache
self._partition_scheme = partition_scheme

self._start_date = datetime.datetime.utcnow().date()
self._end_date = datetime.datetime.utcnow().date()
self.start_date = datetime.datetime.utcnow().date()
self.end_date = datetime.datetime.utcnow().date()

def __repr__(self):
return "QueryPlanner"
Expand All @@ -98,8 +98,8 @@ def copy(self):
cache=self._cache,
partition_scheme=self._partition_scheme,
)
planner._start_date = self._start_date
planner._end_date = self._end_date
planner.start_date = self.start_date
planner.end_date = self.end_date
return planner

def create_plan(self, sql: str = None, ast: dict = None):
Expand All @@ -108,7 +108,7 @@ def create_plan(self, sql: str = None, ast: dict = None):
import sqloxide

# extract temporal filters, this isn't supported by sqloxide
self._start_date, self._end_date, sql = extract_temporal_filters(sql)
self.start_date, self.end_date, sql = extract_temporal_filters(sql)
# Parse the SQL into a AST
try:
self._ast = sqloxide.parse_sql(sql, dialect="mysql")
Expand All @@ -134,7 +134,7 @@ def _extract_value(self, value):
"""
extract values from a value node
"""
if value is None:
if value is None or value in ("None", "Null"):
return (None, None)
if "SingleQuotedString" in value:
# quoted strings are either VARCHAR or TIMESTAMP
Expand All @@ -153,6 +153,10 @@ def _extract_value(self, value):
[self._extract_value(t["Value"])[0] for t in value["Tuple"]],
TOKEN_TYPES.LIST,
)
if "Value" in value:
if value["Value"] == "Null":
return (None, TOKEN_TYPES.OTHER)
return (value["Value"], TOKEN_TYPES.OTHER)

def _build_dnf_filters(self, filters):

Expand Down Expand Up @@ -615,8 +619,8 @@ def _show_columns_planner(self, ast, statistics):
reader=self._reader,
cache=None, # never read from cache
partition_scheme=self._partition_scheme,
start_date=self._start_date,
end_date=self._end_date,
start_date=self.start_date,
end_date=self.end_date,
),
)
last_node = "reader"
Expand Down Expand Up @@ -681,8 +685,8 @@ def _naive_select_planner(self, ast, statistics):
reader=self._reader,
cache=self._cache,
partition_scheme=self._partition_scheme,
start_date=self._start_date,
end_date=self._end_date,
start_date=self.start_date,
end_date=self.end_date,
hints=hints,
),
)
Expand All @@ -708,8 +712,8 @@ def _naive_select_planner(self, ast, statistics):
reader=self._reader,
cache=self._cache,
partition_scheme=self._partition_scheme,
start_date=self._start_date,
end_date=self._end_date,
start_date=self.start_date,
end_date=self.end_date,
hints=right[3],
)

Expand Down
21 changes: 0 additions & 21 deletions opteryx/engine/planner/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,24 +222,3 @@ def extract_temporal_filters(sql):
pass

return start_date, end_date, sql


if __name__ == "__main__":
    # Ad-hoc check: print every day from 1 January of the current year to
    # 1 January of the next year next to the result of _subtract_one_month.

    def date_range(start_date, end_date):
        """Yield each date from start_date to end_date, both inclusive."""
        if start_date > end_date:  # type:ignore
            raise ValueError(
                "date_range: end_date must be the same or later than the start_date "
            )

        total_days = int((end_date - start_date).days) + 1  # type:ignore
        yield from (
            start_date + datetime.timedelta(offset)  # type:ignore
            for offset in range(total_days)
        )

    year_start = datetime.date.today().replace(day=1, month=1)
    next_year_start = year_start.replace(year=year_start.year + 1)
    for day in date_range(year_start, next_year_start):
        print(day, _subtract_one_month(day))
5 changes: 2 additions & 3 deletions opteryx/utils/columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,9 @@ def fuzzy_search(self, column_name):
best_match_column = None
best_match_score = 100

for k, v in self._column_metadata.items():
for alias in v.get("aliases"):
for attributes in self._column_metadata.values():
for alias in attributes.get("aliases"):
my_dist = compare(column_name, alias)
print(alias)
if my_dist > 0 and my_dist < best_match_score:
best_match_score = my_dist
best_match_column = alias
Expand Down
2 changes: 1 addition & 1 deletion opteryx/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
2) we can import it in setup.py for the same reason
"""

__version__ = "0.0.3-beta.20"
__version__ = "0.0.3-beta.21"
3 changes: 3 additions & 0 deletions tests/sql_battery/test_battery_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@
("SELECT DISTINCT * FROM (SELECT DATE_TRUNC('year', birth_date) AS BIRTH_YEAR FROM $astronauts)", 54, 1),
("SELECT DISTINCT * FROM (SELECT DATE_TRUNC('month', birth_date) AS BIRTH_YEAR_MONTH FROM $astronauts)", 247, 1),

("SELECT COALESCE(graduate_major, undergraduate_major, 'high school') as ed FROM $astronauts WHERE ed = 'high school'", 4, 1),
("SELECT COALESCE(graduate_major, undergraduate_major) AS ed, graduate_major, undergraduate_major FROM $astronauts WHERE ed = 'Aeronautical Engineering'", 41, 3),

("SELECT SEARCH(name, 'al'), name FROM $satellites", 177, 2),
("SELECT name FROM $satellites WHERE SEARCH(name, 'al')", 18, 1),
("SELECT SEARCH(missions, 'Apollo 11'), missions FROM $astronauts", 357, 2),
Expand Down

0 comments on commit 5a47eb6

Please sign in to comment.