Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1479 #1482

Merged
merged 2 commits into from
Feb 27, 2024
Merged

#1479 #1482

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion opteryx/__version__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__build__ = 323
__build__ = 324

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion opteryx/connectors/base/base_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

MIN_CHUNK_SIZE: int = 500
INITIAL_CHUNK_SIZE: int = 500
DEFAULT_MORSEL_SIZE: int = 1024 * 1024
DEFAULT_MORSEL_SIZE: int = 16 * 1024 * 1024


class BaseConnector:
Expand Down
33 changes: 25 additions & 8 deletions opteryx/connectors/sql_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ class SqlConnector(BaseConnector, PredicatePushable):
"LtEq": "<=",
"Like": "LIKE",
"NotLike": "NOT LIKE",
"IsTrue": "IS TRUE",
"IsNotTrue": "IS NOT TRUE",
"IsFalse": "IS FALSE",
"IsNotFalse": "IS NOT FALSE",
"IsNull": "IS NULL",
"IsNotNull": "IS NOT NULL",
}

def __init__(self, *args, connection: str = None, engine=None, **kwargs):
Expand Down Expand Up @@ -104,6 +110,11 @@ def __init__(self, *args, connection: str = None, engine=None, **kwargs):
self.schema = None # type: ignore
self.metadata = MetaData()

def can_push(self, operator: Node, types: set = None) -> bool:
    """
    Decide whether a predicate may be pushed down into the SQL query.

    Anything the parent implementation accepts is accepted here. Otherwise the
    predicate is accepted only when its condition is a unary operator.

    Parameters:
        operator: the node carrying the condition to be tested.
        types: passed through unchanged to the parent implementation.

    Returns:
        True when the parent accepts the predicate, otherwise whether the
        condition's node type is NodeType.UNARY_OPERATOR.
    """
    accepted_by_parent = super().can_push(operator, types)
    if not accepted_by_parent:
        # Fall back to the extra rule: unary-operator conditions are pushable.
        return operator.condition.node_type == NodeType.UNARY_OPERATOR
    return True

def read_dataset( # type:ignore
self,
*,
Expand Down Expand Up @@ -134,14 +145,20 @@ def read_dataset( # type:ignore
# Update SQL if we've pushed predicates
parameters: dict = {}
for predicate in predicates:
left_operand = predicate.left
right_operand = predicate.right
operator = self.OPS_XLAT[predicate.value]
if predicate.node_type == NodeType.UNARY_OPERATOR:
operand = predicate.centre.current_name
operator = self.OPS_XLAT[predicate.value]

query_builder.WHERE(f"{operand} {operator}")
else:
left_operand = predicate.left
right_operand = predicate.right
operator = self.OPS_XLAT[predicate.value]

left_value, parameters = _handle_operand(left_operand, parameters)
right_value, parameters = _handle_operand(right_operand, parameters)
left_value, parameters = _handle_operand(left_operand, parameters)
right_value, parameters = _handle_operand(right_operand, parameters)

query_builder.WHERE(f"{left_value} {operator} {right_value}")
query_builder.WHERE(f"{left_value} {operator} {right_value}")

# Use orso as an intermediary, it's row-based so is well suited to processing
# records coming back from a SQL query, and it has a well-optimized to arrow
Expand All @@ -154,7 +171,7 @@ def read_dataset( # type:ignore
# DEBUG: log ("READ DATASET\n", str(query_builder))
# DEBUG: log ("PARAMETERS\n", parameters)
# Execution Options allows us to handle datasets larger than memory
result = conn.execution_options(stream_results=True, max_row_buffer=500).execute(
result = conn.execution_options(stream_results=True, max_row_buffer=5000).execute(
text(str(query_builder)), parameters=parameters
)

Expand All @@ -171,7 +188,7 @@ def read_dataset( # type:ignore

# Dynamically adjust chunk size based on the data size, we start by downloading
# 500 records to get an idea of the row size, assuming these 500 are
# representative, we work out how many rows fit into 8Mb.
# representative, we work out how many rows fit into 16Mb (check setting).
# Don't keep recalculating, this is not a cheap operation and it's predicting
# the future so isn't going to ever be 100% correct
if self.chunk_size == INITIAL_CHUNK_SIZE and morsel.nbytes() > 0:
Expand Down
109 changes: 70 additions & 39 deletions tests/plan_optimization/test_predicate_pushdown_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@

sys.path.insert(1, os.path.join(sys.path[0], "../.."))

import time
import pytest
import opteryx
from opteryx.connectors import SqlConnector
from opteryx.utils.formatter import format_sql
from tests.tools import is_arm, is_mac, is_version, is_windows, skip_if

POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD")
Expand All @@ -22,48 +25,76 @@
connection=f"postgresql+psycopg2://{POSTGRES_USER}:{POSTGRES_PASSWORD}@trumpet.db.elephantsql.com/{POSTGRES_USER}",
)

# Each entry is a tuple of:
#   (SQL statement, expected number of rows returned, expected "rows_read" statistic)
# The tuples are fed to the parameterized test and to the __main__ runner below.
# In every case the two counts are equal - presumably showing that the filter
# was applied by the database rather than after the rows were read (confirm).
test_cases = [
("SELECT * FROM pg.planets WHERE gravity <= 3.7", 3, 3),
("SELECT * FROM pg.planets WHERE name != 'Earth'", 8, 8),
("SELECT * FROM pg.planets WHERE name != 'E\"arth'", 9, 9),
("SELECT * FROM pg.planets WHERE gravity != 3.7", 7, 7),
("SELECT * FROM pg.planets WHERE gravity < 3.7", 1, 1),
("SELECT * FROM pg.planets WHERE gravity > 3.7", 6, 6),
("SELECT * FROM pg.planets WHERE gravity >= 3.7", 8, 8),
("SELECT * FROM pg.planets WHERE name LIKE '%a%'", 4, 4),
("SELECT * FROM pg.planets WHERE id > gravity", 2, 2),
("SELECT * FROM pg.planets WHERE surface_pressure IS NULL", 4, 4),
]


# skip to reduce contention
@skip_if(is_arm() or is_windows() or is_mac() or not is_version("3.10"))
def test_predicate_pushdown_postgres_other():
res = opteryx.query("SELECT * FROM pg.planets WHERE gravity <= 3.7")
assert res.rowcount == 3, res.rowcount
assert res.stats.get("rows_read", 0) == 3, res.stats

res = opteryx.query("SELECT * FROM pg.planets WHERE name != 'Earth'")
assert res.rowcount == 8, res.rowcount
assert res.stats.get("rows_read", 0) == 8, res.stats

res = opteryx.query("SELECT * FROM pg.planets WHERE name != 'E\"arth'")
assert res.rowcount == 9, res.rowcount
assert res.stats.get("rows_read", 0) == 9, res.stats

res = opteryx.query("SELECT * FROM pg.planets WHERE gravity != 3.7")
assert res.rowcount == 7, res.rowcount
assert res.stats.get("rows_read", 0) == 7, res.stats

res = opteryx.query("SELECT * FROM pg.planets WHERE gravity < 3.7")
assert res.rowcount == 1, res.rowcount
assert res.stats.get("rows_read", 0) == 1, res.stats

res = opteryx.query("SELECT * FROM pg.planets WHERE gravity > 3.7")
assert res.rowcount == 6, res.rowcount
assert res.stats.get("rows_read", 0) == 6, res.stats

res = opteryx.query("SELECT * FROM pg.planets WHERE gravity >= 3.7")
assert res.rowcount == 8, res.rowcount
assert res.stats.get("rows_read", 0) == 8, res.stats

res = opteryx.query("SELECT * FROM pg.planets WHERE name LIKE '%a%'")
assert res.rowcount == 4, res.rowcount
assert res.stats.get("rows_read", 0) == 4, res.stats

res = opteryx.query("SELECT * FROM pg.planets WHERE id > gravity")
assert res.rowcount == 2, res.rowcount
assert res.stats.get("rows_read", 0) == 2, res.stats
@pytest.mark.parametrize("statement,expected_rowcount,expected_rows_read", test_cases)
def test_predicate_pushdown_postgres_parameterized(
    statement, expected_rowcount, expected_rows_read
):
    """
    Execute one statement and verify two things, in order: the number of rows
    in the result, then the "rows_read" entry of the query statistics
    (treated as 0 when the entry is absent).
    """
    res = opteryx.query(statement)

    # First check: size of the result set.
    rowcount_ok = res.rowcount == expected_rowcount
    assert rowcount_ok, f"Expected {expected_rowcount}, got {res.rowcount}"

    # Second check: rows reported as read, only reached if the first passed.
    rows_read_ok = res.stats.get("rows_read", 0) == expected_rows_read
    assert rows_read_ok, f"Expected {expected_rows_read}, got {res.stats.get('rows_read', 0)}"


if __name__ == "__main__": # pragma: no cover
from tests.tools import run_tests

run_tests()
import shutil

from tests.tools import trunc_printable

start_suite = time.monotonic_ns()
passed = 0
failed = 0

width = shutil.get_terminal_size((80, 20))[0] - 15

print(f"RUNNING BATTERY OF {len(test_cases)} TESTS")
for index, (statement, returned_rows, read_rows) in enumerate(test_cases):
print(
f"\033[38;2;255;184;108m{(index + 1):04}\033[0m"
f" {trunc_printable(format_sql(statement), width - 1)}",
end="",
flush=True,
)
try:
start = time.monotonic_ns()
test_predicate_pushdown_postgres_parameterized(statement, returned_rows, read_rows)
print(
f"\033[38;2;26;185;67m{str(int((time.monotonic_ns() - start)/1e6)).rjust(4)}ms\033[0m ✅",
end="",
)
passed += 1
if failed > 0:
print(" \033[0;31m*\033[0m")
else:
print()
except Exception as err:
print(f"\033[0;31m{str(int((time.monotonic_ns() - start)/1e6)).rjust(4)}ms ❌ *\033[0m")
print(">", err)
failed += 1

print("--- ✅ \033[0;32mdone\033[0m")

if failed > 0:
print("\n\033[38;2;139;233;253m\033[3mFAILURES\033[0m")

print(
f"\n\033[38;2;139;233;253m\033[3mCOMPLETE\033[0m ({((time.monotonic_ns() - start_suite) / 1e9):.2f} seconds)\n"
f" \033[38;2;26;185;67m{passed} passed ({(passed * 100) // (passed + failed)}%)\033[0m\n"
f" \033[38;2;255;121;198m{failed} failed\033[0m"
)
14 changes: 14 additions & 0 deletions tests/plan_optimization/test_predicate_pushdown_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,20 @@ def test_predicate_pushdowns_sqlite_eq():
assert cur.rowcount == 1, cur.rowcount
assert cur.stats.get("rows_read", 0) == 2, cur.stats

cur = conn.cursor()
cur.execute("SELECT * FROM sqlite.planets WHERE surfacePressure IS NULL;")
# We push unary ops to SQL
assert cur.rowcount == 4, cur.rowcount
assert cur.stats.get("rows_read", 0) == 4, cur.stats

cur = conn.cursor()
cur.execute(
"SELECT * FROM sqlite.planets WHERE orbitalInclination IS FALSE AND name IN ('Earth', 'Mars');"
)
# We push unary ops to SQL
assert cur.rowcount == 1, cur.rowcount
assert cur.stats.get("rows_read", 0) == 1, cur.stats

conn.close()


Expand Down
3 changes: 2 additions & 1 deletion tests/sql_battery/test_shapes_and_errors_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -1556,7 +1556,7 @@ def test_sql_battery(statement, rows, columns, exception):

print(f"RUNNING BATTERY OF {len(STATEMENTS)} SHAPE TESTS")
for index, (statement, rows, cols, err) in enumerate(STATEMENTS):
start = time.monotonic_ns()

printable = statement
if hasattr(printable, "decode"):
printable = printable.decode()
Expand All @@ -1567,6 +1567,7 @@ def test_sql_battery(statement, rows, columns, exception):
flush=True,
)
try:
start = time.monotonic_ns()
test_sql_battery(statement, rows, cols, err)
print(
f"\033[38;2;26;185;67m{str(int((time.monotonic_ns() - start)/1e6)).rjust(4)}ms\033[0m ✅",
Expand Down
Loading