Skip to content

Commit

Permalink
feat(sql): allow any SQL dialect accepted by sqlgllot in Table.sql
Browse files Browse the repository at this point in the history
…and `Backend.sql`
  • Loading branch information
cpcloud authored and kszucs committed Jun 6, 2023
1 parent 134743a commit f38c447
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 26 deletions.
4 changes: 2 additions & 2 deletions ibis/__init__.py
Expand Up @@ -109,8 +109,8 @@ def connect(*args, **kwargs):
proxy.name = name
proxy._from_url = backend._from_url
proxy._to_sql = backend._to_sql
if hasattr(backend, "_sqlglot_dialect"):
proxy._sqlglot_dialect = backend._sqlglot_dialect
if (dialect := getattr(backend, "_sqlglot_dialect", None)) is not None:
proxy._sqlglot_dialect = dialect
# Add any additional methods that should be exposed at the top level
for name in getattr(backend, "_top_level_methods", ()):
setattr(proxy, name, getattr(backend, name))
Expand Down
23 changes: 23 additions & 0 deletions ibis/backends/base/__init__.py
Expand Up @@ -35,6 +35,9 @@
__all__ = ('BaseBackend', 'Database', 'connect')


_IBIS_TO_SQLGLOT_DIALECT = {"mssql": "tsql", "impala": "hive", "pyspark": "spark"}


class Database:
"""Generic Database class."""

Expand Down Expand Up @@ -976,6 +979,26 @@ def _load_into_cache(self, name, expr):
def _clean_up_cached_table(self, op):
raise NotImplementedError(self.name)

def _transpile_sql(self, query: str, *, dialect: str | None = None) -> str:
# only transpile if dialect was passed
if dialect is None:
return query

import sqlglot as sg

# only transpile if the backend dialect doesn't match the input dialect
name = self.name
if (output_dialect := getattr(self, "_sqlglot_dialect", name)) is None:
raise NotImplementedError(f"No known sqlglot dialect for backend {name}")

if dialect != output_dialect:
(query,) = sg.transpile(
query,
read=_IBIS_TO_SQLGLOT_DIALECT.get(dialect, dialect),
write=output_dialect,
)
return query


@functools.lru_cache(maxsize=None)
def _get_backend_names() -> frozenset[str]:
Expand Down
12 changes: 11 additions & 1 deletion ibis/backends/base/sql/__init__.py
Expand Up @@ -31,6 +31,10 @@ class BaseSQLBackend(BaseBackend):

compiler = Compiler

@property
def _sqlglot_dialect(self) -> str:
return self.name

def _from_url(self, url: str, **kwargs: Any) -> BaseBackend:
"""Connect to a backend using a URL `url`.
Expand Down Expand Up @@ -94,7 +98,9 @@ def _fully_qualified_name(self, name, database):
# XXX
return name

def sql(self, query: str, schema: sch.Schema | None = None) -> ir.Table:
def sql(
self, query: str, schema: sch.Schema | None = None, dialect: str | None = None
) -> ir.Table:
"""Convert a SQL query to an Ibis table expression.
Parameters
Expand All @@ -104,12 +110,16 @@ def sql(self, query: str, schema: sch.Schema | None = None) -> ir.Table:
schema
The expected schema for this query. If not provided, will be
inferred automatically if possible.
dialect
Optional string indicating the dialect of `query`. The default
value of `None` will use the backend's native dialect.
Returns
-------
Table
Table expression
"""
query = self._transpile_sql(query, dialect=dialect)
if schema is None:
schema = self._get_schema_using_query(query)
else:
Expand Down
18 changes: 13 additions & 5 deletions ibis/backends/clickhouse/__init__.py
Expand Up @@ -28,6 +28,8 @@
if TYPE_CHECKING:
import pandas as pd

from ibis.common.typing import SupportsSchema


def _to_memtable(v):
return ibis.memtable(v).op() if not isinstance(v, ops.InMemoryTable) else v
Expand Down Expand Up @@ -56,7 +58,7 @@ def insert(self, obj, settings: Mapping[str, Any] | None = None, **kwargs):


class Backend(BaseBackend):
name = 'clickhouse'
name = "clickhouse"

# ClickHouse itself does, but the client driver does not
supports_temporary_tables = False
Expand All @@ -73,14 +75,20 @@ class Options(ibis.config.Config):
bool_type: Literal["Bool", "UInt8", "Int8"] = "Bool"

def _log(self, sql: str) -> None:
"""Log the SQL, usually to the standard output.
"""Log `sql`.
This method can be implemented by subclasses. The logging
happens when `ibis.options.verbose` is `True`.
This method can be implemented by subclasses. Logging occurs when
`ibis.options.verbose` is `True`.
"""
util.log(sql)

def sql(self, query: str, schema=None) -> ir.Table:
def sql(
self,
query: str,
schema: SupportsSchema | None = None,
dialect: str | None = None,
) -> ir.Table:
query = self._transpile_sql(query, dialect=dialect)
if schema is None:
schema = self._get_schema_using_query(query)
return ops.SQLQueryResult(query, ibis.schema(schema), self).to_expr()
Expand Down
6 changes: 3 additions & 3 deletions ibis/backends/impala/__init__.py
Expand Up @@ -182,11 +182,11 @@ def _column_batches_to_dataframe(names, batches):

class Backend(BaseSQLBackend):
name = 'impala'
# not 100% accurate, but very close
_sqlglot_dialect = "hive"
_top_level_methods = ("hdfs_connect",)
compiler = ImpalaCompiler

_sqlglot_dialect = "hive" # not 100% accurate, but very close
_top_level_methods = ("hdfs_connect",)

class Options(ibis.config.Config):
"""Impala specific options.
Expand Down
3 changes: 2 additions & 1 deletion ibis/backends/mssql/__init__.py
Expand Up @@ -15,9 +15,10 @@
class Backend(BaseAlchemyBackend):
name = "mssql"
compiler = MsSqlCompiler
_sqlglot_dialect = "tsql"
supports_create_or_replace = False

_sqlglot_dialect = "tsql"

def do_connect(
self,
host: str = "localhost",
Expand Down
3 changes: 2 additions & 1 deletion ibis/backends/pyspark/__init__.py
Expand Up @@ -101,7 +101,8 @@ class PySparkCompiler(Compiler):

class Backend(BaseSQLBackend):
compiler = PySparkCompiler
name = 'pyspark'
name = "pyspark"
_sqlglot_dialect = "spark"

class Options(ibis.config.Config):
"""PySpark options.
Expand Down
63 changes: 63 additions & 0 deletions ibis/backends/tests/test_dot_sql.py
@@ -1,9 +1,11 @@
import pandas as pd
import pytest
import sqlglot as sg
from pytest import param

import ibis
from ibis import _
from ibis.backends.base import _IBIS_TO_SQLGLOT_DIALECT, _get_backend_names

try:
from polars.exceptions import ComputeError as PolarsComputeError
Expand All @@ -28,6 +30,11 @@
"bigquery": "ibis_gbq_testing.functional_alltypes",
}

try:
from clickhouse_connect.driver.exceptions import DatabaseError
except ImportError:
DatabaseError = None


@dot_sql_notimpl
@dot_sql_notyet
Expand Down Expand Up @@ -229,3 +236,59 @@ def test_dot_sql_reuse_alias_with_different_types(backend, alltypes, df):
expected2 = df.bigint_col.rename("x")
backend.assert_series_equal(foo1.x.execute(), expected1)
backend.assert_series_equal(foo2.x.execute(), expected2)


_NO_SQLGLOT_DIALECT = {"pandas", "dask", "datafusion", "polars", "druid"}
no_sqlglot_dialect = sorted(
param(backend, marks=pytest.mark.xfail) for backend in _NO_SQLGLOT_DIALECT
)


@pytest.mark.parametrize(
"dialect",
[
*sorted(_get_backend_names() - _NO_SQLGLOT_DIALECT),
*no_sqlglot_dialect,
],
)
@pytest.mark.broken(["clickhouse"], raises=DatabaseError)
@pytest.mark.notyet(["trino"], raises=NotImplementedError)
@table_dot_sql_notimpl
@dot_sql_notimpl
@dot_sql_notyet
@dot_sql_never
def test_table_dot_sql_transpile(backend, alltypes, dialect, df):
name = "foo2"
foo = alltypes.select(x=_.int_col + 1).alias(name)
expr = sg.select("x").from_(sg.table(name, quoted=True))
dialect = _IBIS_TO_SQLGLOT_DIALECT.get(dialect, dialect)
sqlstr = expr.sql(dialect=dialect, pretty=True)
dot_sql_expr = foo.sql(sqlstr, dialect=dialect)
result = dot_sql_expr.execute()
expected = df.int_col.add(1).rename("x")
backend.assert_series_equal(result.x, expected)


@pytest.mark.parametrize(
"dialect",
[
*sorted(_get_backend_names() - {"pyspark", *_NO_SQLGLOT_DIALECT}),
*no_sqlglot_dialect,
],
)
@pytest.mark.notyet(["druid"], raises=ValueError)
@pytest.mark.notyet(["snowflake", "bigquery"])
@pytest.mark.notyet(
["oracle"], strict=False, reason="only works with backends that quote everything"
)
@dot_sql_notimpl
@dot_sql_never
def test_con_dot_sql_transpile(backend, con, dialect, df):
t = sg.table("functional_alltypes")
foo = sg.select(sg.alias(sg.column("int_col") + 1, "x")).from_(t)
dialect = _IBIS_TO_SQLGLOT_DIALECT.get(dialect, dialect)
sqlstr = foo.sql(dialect=dialect, pretty=True)
expr = con.sql(sqlstr, dialect=dialect)
result = expr.execute()
expected = df.int_col.add(1).rename("x")
backend.assert_series_equal(result.x, expected)
23 changes: 10 additions & 13 deletions ibis/expr/types/relations.py
Expand Up @@ -2658,23 +2658,19 @@ def alias(self, alias: str) -> ir.Table:
expr.compile()
return expr

def sql(self, query: str) -> ir.Table:
def sql(self, query: str, dialect: str | None = None) -> ir.Table:
"""Run a SQL query against a table expression.
!!! note "The SQL string is backend specific"
`query` must be valid SQL for the execution backend the expression
will run against.
This restriction may be lifted in a future version of ibis.
See [`Table.alias`][ibis.expr.types.relations.Table.alias] for
details on using named table expressions in a SQL string.
Parameters
----------
query
Query string
dialect
Optional string indicating the dialect of `query`. The default
value of `None` will use the backend's native dialect.
Returns
-------
Expand All @@ -2698,11 +2694,12 @@ def sql(self, query: str) -> ir.Table:
│ Torgersen │ 38.950980 │
└───────────┴──────────────────────┘
"""
op = ops.SQLStringView(
child=self,
name=next(_ALIASES),
query=query,
)

# only transpile if dialect was passed
if dialect is not None:
backend = self._find_backend()
query = backend._transpile_sql(query, dialect=dialect)
op = ops.SQLStringView(child=self, name=next(_ALIASES), query=query)
return op.to_expr()

def to_pandas(self, **kwargs) -> pd.DataFrame:
Expand Down

0 comments on commit f38c447

Please sign in to comment.