Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer committed Nov 17, 2022
1 parent 8534e07 commit bff5acf
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 23 deletions.
2 changes: 2 additions & 0 deletions opteryx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def connect(*args, **kwargs):
os.nice(-20 + nice_value)
print(f"Process priority set to {os.nice(0)}.")
except PermissionError:
if nice_value == 0:
nice_value = "0 (normal)"
print(f"Cannot update process priority. Currently set to {nice_value}.")

# Log resource usage
Expand Down
15 changes: 13 additions & 2 deletions opteryx/managers/expression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ def format_expression(root):
return str(root.value)
# INTERAL IDENTIFIERS
if node_type & INTERNAL_TYPE == INTERNAL_TYPE:
if node_type in (NodeType.FUNCTION, NodeType.AGGREGATOR):
if node_type in (
NodeType.FUNCTION,
NodeType.AGGREGATOR,
NodeType.COMPLEX_AGGREGATOR,
):
if root.value == "CASE":
con = [format_expression(a) for a in root.parameters[0].value]
vals = [format_expression(a) for a in root.parameters[1].value]
Expand All @@ -66,6 +70,11 @@ def format_expression(root):
+ "".join([f"WHERE {c} THEN {v} " for c, v in zip(con, vals)])
+ "END"
)
if root.value == "ARRAY_AGG":
distinct = "DISTINCT " if root.parameters[1] else ""
order = f" ORDER BY {root.parameters[2]}" if root.parameters[2] else ""
limit = f" LIMIT {root.parameters[3]}" if root.parameters[3] else ""
return f"{root.value.upper()}({distinct}{format_expression(root.parameters[0])}{order}{limit})"
return f"{root.value.upper()}({','.join([format_expression(e) for e in root.parameters])})"
if node_type == NodeType.WILDCARD:
return "*"
Expand Down Expand Up @@ -128,7 +137,9 @@ class NodeType(int, Enum):
SUBQUERY: int = 114
NESTED: int = 130
AGGREGATOR:int = 146
EXPRESSION_LIST:int = 162 # 1010 0010
COMPLEX_AGGREGATOR: int = 162
EXPRESSION_LIST:int = 178 # 1011 0010


# LITERAL TYPES
# nnnn0100
Expand Down
22 changes: 22 additions & 0 deletions opteryx/managers/planner/logical/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,27 @@ def case_when(value, alias: list = None, key=None):
)


def array_agg(branch, alias=None, key=None):
from opteryx.managers.planner.logical import custom_builders

distinct = branch["distinct"]
expression = build(branch["expr"])
order = None
if branch["order_by"]:
raise UnsupportedSyntaxError("`ORDER BY` not supported in `ARRAY_AGG`.")
# order = custom_builders.extract_order({"Query": {"order_by": [branch["order_by"]]}})
limit = None
if branch["limit"]:
limit = int(build(branch["limit"]).value)

return ExpressionTreeNode(
token_type=NodeType.COMPLEX_AGGREGATOR,
value="ARRAY_AGG",
parameters=(expression, distinct, order, limit),
alias=alias,
)


def unsupported(branch, alias=None, key=None):
"""raise an error"""
raise SqlError(key)
Expand Down Expand Up @@ -544,6 +565,7 @@ def build(value, alias: list = None, key=None):

# parts to build the literal parts of a query
BUILDERS = {
"ArrayAgg": array_agg,
"Between": between,
"BinaryOp": binary_op,
"Boolean": literal_boolean,
Expand Down
13 changes: 1 addition & 12 deletions opteryx/managers/planner/logical/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,22 +160,11 @@ def select_query(ast, properties):
_projection = builders.build(ast["Query"]["body"]["Select"]["projection"])
_groups = builders.build(ast["Query"]["body"]["Select"]["group_by"])
if _groups or get_all_nodes_of_type(
_projection, select_nodes=(NodeType.AGGREGATOR,)
_projection, select_nodes=(NodeType.AGGREGATOR, NodeType.COMPLEX_AGGREGATOR)
):
_aggregates = _projection.copy()
if isinstance(_aggregates, dict):
raise SqlError("GROUP BY cannot be used with SELECT *")
if not any(
a.token_type == NodeType.AGGREGATOR
for a in _aggregates
if isinstance(a, ExpressionTreeNode)
):
wildcard = ExpressionTreeNode(NodeType.WILDCARD)
_aggregates.append(
ExpressionTreeNode(
NodeType.AGGREGATOR, value="COUNT", parameters=[wildcard]
)
)
plan.add_operator(
"agg",
operators.AggregateNode(properties, aggregates=_aggregates, groups=_groups),
Expand Down
40 changes: 33 additions & 7 deletions opteryx/operators/aggregate_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@
"ALL": "all",
"ANY": "any",
"APPROXIMATE_MEDIAN": "approximate_median",
"ARRAY_AGG": "hash_list",
"COUNT": "count", # counts only non nulls
"COUNT_DISTINCT": "count_distinct",
"DISTINCT": "distinct",
"LIST": "hash_list",
"DISTINCT": "distinct", # fated
"LIST": "hash_list", # fated
"MAX": "max",
"MAXIMUM": "max", # alias
"MEAN": "mean",
Expand Down Expand Up @@ -121,10 +122,13 @@ def _build_aggs(aggregators, columns):
for root in aggregators:

for aggregator in get_all_nodes_of_type(
root, select_nodes=(NodeType.AGGREGATOR,)
root, select_nodes=(NodeType.AGGREGATOR, NodeType.COMPLEX_AGGREGATOR)
):

if aggregator.token_type == NodeType.AGGREGATOR:
if aggregator.token_type in (
NodeType.AGGREGATOR,
NodeType.COMPLEX_AGGREGATOR,
):
field_node = aggregator.parameters[0]
display_name = format_expression(field_node)
exists = columns.get_column_from_alias(display_name)
Expand Down Expand Up @@ -152,9 +156,14 @@ def _build_aggs(aggregators, columns):
f"Invalid identifier or literal provided in aggregator function `{display_name}`"
)
function = AGGREGATORS.get(aggregator.value)
if aggregator.value == "ARRAY_AGG":
# if the array agg is distinct, base off that function instead
if aggregator.parameters[1]:
function = "distinct"
aggs.append((field_name, function, count_options))
column_map[
f"{aggregator.value.upper()}({display_field})"
format_expression(aggregator)
# f"{aggregator.value.upper()}({display_field})"
] = f"{field_name}_{function}".replace("_hash_", "_")

return column_map, aggs
Expand All @@ -171,7 +180,7 @@ def _non_group_aggregates(aggregates, table, columns):

for aggregate in aggregates:

if aggregate.token_type == NodeType.AGGREGATOR:
if aggregate.token_type in (NodeType.AGGREGATOR, NodeType.COMPLEX_AGGREGATOR):

column_node = aggregate.parameters[0]
if column_node.token_type == NodeType.LITERAL_NUMERIC:
Expand All @@ -190,7 +199,7 @@ def _non_group_aggregates(aggregates, table, columns):
# pyarrow.compute module
if not hasattr(pyarrow.compute, aggregate_function_name):
raise UnsupportedSyntaxError(
f"Aggregate {aggregate.value} can only be used with GROUP BY"
f"Aggregate `{aggregate.value}` can only be used with GROUP BY"
)
aggregate_function = getattr(pyarrow.compute, aggregate_function_name)
aggregate_column_value = aggregate_function(raw_column_values).as_py()
Expand Down Expand Up @@ -307,6 +316,23 @@ def execute(self) -> Iterable:
groups = table.group_by(group_by_columns)
groups = groups.aggregate(aggs)

# do the secondary activities on ARRAY_AGG
for agg in [a for a in self._aggregates if a.value == "ARRAY_AGG"]:
_, _, order, limit = agg.parameters
if order or limit:
# rip the column out of the table
column_name = column_map[format_expression(agg)]
column_def = groups.field(column_name)
column = groups.column(column_name).to_pylist()
groups = groups.drop([column_name])
# order
if order:
pass
if limit:
column = [c[:limit] for c in column]
# put the new column into the table
groups = groups.append_column(column_def, [column])

# name the aggregate fields
for friendly_name, agg_name in column_map.items():
columns.add_column(agg_name)
Expand Down
1 change: 1 addition & 0 deletions opteryx/operators/projection_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(self, properties: QueryProperties, **config):
elif attribute.token_type in (
NodeType.FUNCTION,
NodeType.AGGREGATOR,
NodeType.COMPLEX_AGGREGATOR,
NodeType.BINARY_OPERATOR,
NodeType.COMPARISON_OPERATOR,
) or (attribute.token_type & LITERAL_TYPE == LITERAL_TYPE):
Expand Down
42 changes: 41 additions & 1 deletion opteryx/third_party/distogram/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,46 @@ def __init__(self, bin_count: int = 100, weighted_diff: bool = False):
self.min_diff: Optional[float] = None
self.weighted_diff: bool = weighted_diff

## all class methods below here have been added for Opteryx
def dump(self): # pragma: no cover
import orjson

return orjson.dumps(
{
"bin_count": self.bin_count,
"bins": self.bins,
"min": self.min,
"max": self.max,
"diffs": self.diffs,
"min_diff": self.min_diff,
"weighted_diff": self.weighted_diff,
}
)

def __add__(self, operand): # pragma: no cover
dgram = merge(self, operand)
# merge estimates min and max, so set them manually
dgram.min = min(self.min, operand.min)
dgram.max = max(self.max, operand.max)
return dgram


# added for opteryx
def load(dic): # pragma: no cover
if not isinstance(dic, dict):
import orjson

dic = orjson.loads(dic)
dgram = Distogram()
dgram.bin_count = dic["bin_count"]
dgram.bins = dic["bins"]
dgram.min = dic["min"]
dgram.max = dic["max"]
dgram.diffs = dic["diffs"]
dgram.min_diff = dic["min_diff"]
dgram.weighted_diff = dic["weighted_diff"]
return dgram


def _linspace(start: float, stop: float, num: int) -> List[float]: # pragma: no cover
if num == 1:
Expand Down Expand Up @@ -284,7 +324,7 @@ def count(h: Distogram) -> float: # pragma: no cover
Returns:
The number of elements in the distribution.
"""
return sum((f for _, f in h.bins))
return sum(f for _, f in h.bins)


def bounds(h: Distogram) -> Tuple[float, float]: # pragma: no cover
Expand Down
14 changes: 13 additions & 1 deletion tests/sql_battery/test_shapes_and_errors_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@
import opteryx

from opteryx.connectors import DiskConnector
from opteryx.exceptions import SqlError, DatasetNotFoundError

from opteryx.exceptions import DatasetNotFoundError
from opteryx.exceptions import SqlError
from opteryx.exceptions import UnsupportedSyntaxError

# fmt:off
STATEMENTS = [
Expand Down Expand Up @@ -647,6 +650,15 @@
("SHOW STORES LIKE 'apple'", None, None, SqlError),
("SELECT name FROM $astronauts WHERE LEFT(name, POSITION(' ' IN name) - 1) = 'Andrew'", 3, 1, None),
("SELECT name FROM $astronauts WHERE LEFT(name, POSITION(' ' IN name)) = 'Andrew '", 3, 1, None),

("SELECT ARRAY_AGG(name) from $satellites GROUP BY planetId", 7, 1, None),
("SELECT ARRAY_AGG(DISTINCT name) from $satellites GROUP BY planetId", 7, 1, None),
("SELECT ARRAY_AGG(name ORDER BY name) from $satellites GROUP BY TRUE", None, None, UnsupportedSyntaxError),
("SELECT ARRAY_AGG(name LIMIT 1) from $satellites GROUP BY planetId", 7, 1, None),
("SELECT ARRAY_AGG(DISTINCT name LIMIT 1) from $satellites GROUP BY planetId", 7, 1, None),
("SELECT COUNT(*), ARRAY_AGG(name) from $satellites GROUP BY planetId", 7, 2, None),
("SELECT planetId, COUNT(*), ARRAY_AGG(name) from $satellites GROUP BY planetId", 7, 3, None),
("SELECT ARRAY_AGG(DISTINCT LEFT(name, 1)) from $satellites GROUP BY planetId", 7, 1, None),

("SELECT COUNT(*), place FROM (SELECT CASE id WHEN 3 THEN 'Earth' WHEN 1 THEN 'Mercury' ELSE 'Elsewhere' END as place FROM $planets) GROUP BY place;", 3, 2, None),
("SELECT COUNT(*), place FROM (SELECT CASE id WHEN 3 THEN 'Earth' WHEN 1 THEN 'Mercury' END as place FROM $planets) GROUP BY place HAVING place IS NULL;", 1, 2, None),
Expand Down

0 comments on commit bff5acf

Please sign in to comment.