Skip to content

Commit

Permalink
feat(polars): support pyarrow decimal types
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Apr 13, 2023
1 parent 7472dd5 commit 7e6c365
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 8 deletions.
10 changes: 3 additions & 7 deletions ibis/backends/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,13 +336,9 @@ def _to_pyarrow_table(
df = lf.collect()

table = df.to_arrow()
if isinstance(expr, ir.Table):
schema = expr.schema().to_pyarrow()
return table.cast(schema)
elif isinstance(expr, ir.Value):
schema = sch.schema({expr.get_name(): expr.type().to_pyarrow()})
schema = schema.to_pyarrow()
return table.cast(schema)
if isinstance(expr, (ir.Table, ir.Value)):
schema = expr.as_table().schema().to_pyarrow()
return table.rename_columns(schema.names).cast(schema)
else:
raise com.IbisError(f"Cannot execute expression of type: {type(expr)}")

Expand Down
10 changes: 10 additions & 0 deletions ibis/backends/polars/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ def to_polars_type(dtype):
raise NotImplementedError(f"Unsupported type: {dtype!r}")


@to_polars_type.register(dt.Decimal)
def from_ibis_decimal(dtype):
return pl.Decimal(dtype.precision, dtype.scale)


@to_polars_type.register(dt.Timestamp)
def from_ibis_timestamp(dtype):
return pl.Datetime("ns", dtype.timezone)
Expand Down Expand Up @@ -103,6 +108,11 @@ def from_polars_struct(typ):
)


@to_ibis_dtype.register(pl.Decimal)
def from_polars_decimal(typ: pl.Decimal):
return dt.Decimal(precision=typ.precision, scale=typ.scale)


@sch.infer.register(pl.LazyFrame)
def from_polars_schema(df: pl.LazyFrame) -> sch.Schema:
fields = [(name, to_ibis_dtype(typ)) for name, typ in df.schema.items()]
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,9 @@ def test_numeric_literal(con, backend, expr, expected_types):
],
)
@pytest.mark.notimpl(
['polars', 'datafusion'], "Unsupported type", raises=NotImplementedError
['polars', 'datafusion'],
"Unsupported type",
raises=(NotImplementedError, ValueError),
)
def test_decimal_literal(con, backend, expr, expected_types, expected_result):
backend_name = backend.name()
Expand Down

0 comments on commit 7e6c365

Please sign in to comment.