feat(polars): implement .sql methods
cpcloud committed Jun 5, 2023
1 parent 35fc5f7 commit 86f2a34
Showing 4 changed files with 303 additions and 252 deletions.
58 changes: 40 additions & 18 deletions ibis/backends/polars/__init__.py
@@ -14,7 +14,7 @@
import ibis.expr.types as ir
from ibis.backends.base import BaseBackend, Database
from ibis.backends.polars.compiler import translate
-from ibis.backends.polars.datatypes import schema_from_polars
+from ibis.backends.polars.datatypes import dtype_to_polars, schema_from_polars
from ibis.util import gen_name, normalize_filename

if TYPE_CHECKING:
@@ -28,6 +28,7 @@ class Backend(BaseBackend):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._tables = dict()
+        self._context = pl.SQLContext()

def do_connect(
self, tables: MutableMapping[str, pl.LazyFrame] | None = None
@@ -39,9 +40,8 @@ def do_connect(
tables
An optional mapping of string table names to polars LazyFrames.
"""
-        if not tables:
-            tables = {}
-        self._tables.update(tables)
+        for name, table in (tables or {}).items():
+            self._add_table(name, table)

@property
def version(self) -> str:
@@ -57,8 +57,7 @@ def list_tables(self, like=None, database=None):
return self._filter_with_like(list(self._tables.keys()), like)

def table(self, name: str, _schema: sch.Schema = None) -> ir.Table:
-        table = self._tables[name]
-        schema = schema_from_polars(table.schema)
+        schema = schema_from_polars(self._tables[name].schema)
return ops.DatabaseTable(name, schema, self).to_expr()

def register(
@@ -126,6 +125,21 @@ def _register_failure(self):
f"please call one of {msg} directly"
)

+    def _add_table(self, name: str, obj: pl.LazyFrame | pl.DataFrame) -> None:
+        self._tables[name] = obj
+        self._context.register(name, obj)
+
+    def _remove_table(self, name: str) -> None:
+        del self._tables[name]
+        self._context.unregister(name)
+
+    def sql(self, query: str, schema: sch.Schema | None = None) -> ir.Table:
+        if schema is None:
+            schema = self._get_schema_using_query(query)
+        name = ibis.util.gen_name("polars_dot_sql_table")
+        self._add_table(name, self._context.execute(query))
+        return self.table(name)
+
def read_csv(
self, path: str | Path, table_name: str | None = None, **kwargs: Any
) -> ir.Table:
@@ -151,10 +165,10 @@ def read_csv
path = normalize_filename(path)
table_name = table_name or gen_name("read_csv")
try:
-            self._tables[table_name] = pl.scan_csv(path, **kwargs)
+            self._add_table(table_name, pl.scan_csv(path, **kwargs))
except pl.exceptions.ComputeError:
# handles compressed csvs
-            self._tables[table_name] = pl.read_csv(path, **kwargs).lazy()
+            self._add_table(table_name, pl.read_csv(path, **kwargs).lazy())
return self.table(table_name)

def read_pandas(
@@ -180,7 +194,7 @@
The just-registered table
"""
table_name = table_name or gen_name("read_in_memory")
-        self._tables[table_name] = pl.from_pandas(source, **kwargs).lazy()
+        self._add_table(table_name, pl.from_pandas(source, **kwargs).lazy())
return self.table(table_name)

def read_parquet(
@@ -207,7 +221,7 @@
"""
path = normalize_filename(path)
table_name = table_name or gen_name("read_parquet")
-        self._tables[table_name] = pl.scan_parquet(path, **kwargs)
+        self._add_table(table_name, pl.scan_parquet(path, **kwargs))
return self.table(table_name)

def database(self, name=None):
@@ -224,8 +238,9 @@ def create_table(
overwrite: bool = False,
) -> ir.Table:
if schema is not None and obj is None:
-            raise NotImplementedError(
-                "Empty table creation is not yet supported in the Polars backend"
-            )
+            obj = pl.LazyFrame(
+                [],
+                schema={name: dtype_to_polars(dtype) for name, dtype in schema.items()},
+            )

if database is not None:
@@ -251,7 +266,8 @@
if not isinstance(obj, (pl.DataFrame, pl.LazyFrame)):
obj = pl.LazyFrame(obj)

-        self._tables[name] = obj
+        self._add_table(name, obj)
+        return self.table(name)

def get_schema(self, table_name, database=None):
return self._tables[table_name].schema
@@ -275,6 +291,7 @@ def has_operation(cls, operation: type[ops.Value]) -> bool:

def compile(self, expr: ir.Expr, params: Mapping[ir.Expr, object] = None, **_: Any):
node = expr.op()
+        ctx = self._context
if params:
replacements = {}
for p, v in params.items():
@@ -284,22 +301,27 @@ def compile(self, expr: ir.Expr, params: Mapping[ir.Expr, object] = None, **_: Any):
expr = node.to_expr()

if isinstance(expr, ir.Table):
-            return translate(node)
+            return translate(node, ctx=ctx)
elif isinstance(expr, ir.Column):
# expression must be named for the projection
node = expr.as_table().op()
-            return translate(node)
+            return translate(node, ctx=ctx)
elif isinstance(expr, ir.Scalar):
if an.is_scalar_reduction(node):
node = an.reduction_to_aggregation(node).op()
-                return translate(node)
+                return translate(node, ctx=ctx)
else:
# doesn't have any _tables associated so create projection
# based off of an empty table
-                return pl.DataFrame().lazy().select(translate(node))
+                return pl.DataFrame().lazy().select(translate(node, ctx=ctx))
else:
raise com.IbisError(f"Cannot compile expression of type: {type(expr)}")

+    def _get_schema_using_query(self, query: str) -> sch.Schema:
+        return schema_from_polars(
+            self._context.execute(f"SELECT * FROM ({query}) LIMIT 0").schema
+        )
+
def execute(
self,
expr: ir.Expr,
Expand Down Expand Up @@ -382,7 +404,7 @@ def _load_into_cache(self, name, expr):
self.create_table(name, self.compile(expr).cache())

def _clean_up_cached_table(self, op):
-        del self._tables[op.name]
+        self._remove_table(op.name)

def create_view(self, *_, **__) -> ir.Table:
raise NotImplementedError(self.name)
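
As a quick illustration (not part of the diff), here is a minimal usage sketch of the new .sql() method; the connection, table name, and data are hypothetical, and ibis.polars.connect is assumed to be the public entry point for this backend:

import ibis
import polars as pl

# Registering a LazyFrame under a name makes it visible to the backend's SQLContext.
con = ibis.polars.connect(tables={"t": pl.DataFrame({"a": [1, 2, 3]}).lazy()})

# .sql() runs the query through polars' SQLContext, registers the result under a
# generated name, and returns it as an ordinary ibis table expression.
expr = con.sql("SELECT a, a + 1 AS b FROM t")
print(expr.execute())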

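The create_table change above also means a schema-only call now builds an empty LazyFrame instead of raising NotImplementedError. A sketch under the same assumptions (con as above, column names hypothetical):

t = con.create_table("empty_table", schema=ibis.schema({"x": "int64", "y": "string"}))
# The resulting table has no rows but carries the declared schema.
print(t.schema())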