Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions duckdb/polars_io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations # noqa: D100

import contextlib
import datetime
import json
import typing
Expand Down Expand Up @@ -176,9 +177,12 @@ def _pl_tree_to_sql(tree: _ExpressionTree) -> str:
if dtype.startswith("{'Decimal'") or dtype == "Decimal":
decimal_value = value["Decimal"]
assert isinstance(decimal_value, list), (
f"A {dtype} should be a two member list but got {type(decimal_value)}"
f"A {dtype} should be a two or three member list but got {type(decimal_value)}"
)
return str(Decimal(decimal_value[0]) / Decimal(10 ** decimal_value[1]))
assert 2 <= len(decimal_value) <= 3, (
f"A {dtype} should be a two or three member list but got {len(decimal_value)} member list"
)
return str(Decimal(decimal_value[0]) / Decimal(10 ** decimal_value[-1]))

# Datetime with microseconds since epoch
if dtype.startswith("{'Datetime'") or dtype == "Datetime":
Expand Down Expand Up @@ -260,7 +264,8 @@ def source_generator(
relation_final = relation_final.limit(n_rows)
if predicate is not None:
# We have a predicate, if possible, we push it down to DuckDB
duck_predicate = _predicate_to_expression(predicate)
with contextlib.suppress(AssertionError, KeyError):
duck_predicate = _predicate_to_expression(predicate)
# Try to pushdown filter, if one exists
if duck_predicate is not None:
relation_final = relation_final.filter(duck_predicate)
Expand Down
52 changes: 50 additions & 2 deletions tests/fast/arrow/test_polars.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import json

import pytest

Expand All @@ -8,7 +9,7 @@
arrow = pytest.importorskip("pyarrow")
pl_testing = pytest.importorskip("polars.testing")

from duckdb.polars_io import _predicate_to_expression # noqa: E402
from duckdb.polars_io import _pl_tree_to_sql, _predicate_to_expression # noqa: E402


def valid_filter(filter):
Expand Down Expand Up @@ -175,7 +176,7 @@ def test_polars_column_with_tricky_name(self, duckdb_cursor):
"UBIGINT",
"FLOAT",
"DOUBLE",
# "HUGEINT",
"HUGEINT",
"DECIMAL(4,1)",
"DECIMAL(9,1)",
"DECIMAL(18,4)",
Expand Down Expand Up @@ -605,3 +606,50 @@ def test_polars_lazy_many_batches(self, duckdb_cursor):
correct = duckdb_cursor.execute("FROM t").fetchall()

assert res == correct

def test_invalid_expr_json(self):
bad_key_expr = """
{
"BinaryExpr": {
"left": { "Column": "foo" },
"middle": "Gt",
"right": { "Literal": { "Int": 5 } }
}
}
"""
with pytest.raises(KeyError, match="'op'"):
_pl_tree_to_sql(json.loads(bad_key_expr))

bad_type_expr = """
{
"BinaryExpr": {
"left": { "Column": [ "foo" ] },
"op": "Gt",
"right": { "Literal": { "Int": 5 } }
}
}
"""
with pytest.raises(AssertionError, match="The col name of a Column should be a str but got"):
_pl_tree_to_sql(json.loads(bad_type_expr))

def test_decimal_scale(self):
scalar_decimal_no_scale = """
{ "Scalar": {
"Decimal": [
1,
0
]
} }
"""
assert _pl_tree_to_sql(json.loads(scalar_decimal_no_scale)) == "1"

scalar_decimal_scale = """
{ "Scalar": {
"Decimal": [
1,
38,
0
]
} }
"""
assert _pl_tree_to_sql(json.loads(scalar_decimal_scale)) == "1"