refactor(sqlalchemy): generalize handling of failed type inference
cpcloud committed Jan 17, 2023
1 parent 23c35e1 commit b0f4e4c
Showing 3 changed files with 50 additions and 75 deletions.
38 changes: 32 additions & 6 deletions ibis/backends/base/sql/alchemy/__init__.py
@@ -3,6 +3,7 @@
 import abc
 import contextlib
 import getpass
+import warnings
 from operator import methodcaller
 from typing import TYPE_CHECKING, Any, Iterable, Literal, Mapping

@@ -372,13 +373,38 @@ def _log(self, sql):
         util.log(query_str)

     def _get_sqla_table(
-        self,
-        name: str,
-        schema: str | None = None,
-        autoload: bool = True,
-        **kwargs: Any,
+        self, name: str, schema: str | None = None, autoload: bool = True, **kwargs: Any
     ) -> sa.Table:
-        return sa.Table(name, self.meta, schema=schema, autoload=autoload)
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                "ignore", message="Did not recognize type", category=sa.exc.SAWarning
+            )
+            table = sa.Table(name, self.meta, schema=schema, autoload=autoload)
+        nulltype_cols = frozenset(
+            col.name for col in table.c if isinstance(col.type, sa.types.NullType)
+        )
+
+        if not nulltype_cols:
+            return table
+        return self._handle_failed_column_type_inference(table, nulltype_cols)
+
+    def _handle_failed_column_type_inference(
+        self, table: sa.Table, nulltype_cols: Iterable[str]
+    ) -> sa.Table:
+        """Handle cases where SQLAlchemy cannot infer the column types of `table`."""
+
+        self.inspector.reflect_table(table, table.columns)
+        quoted_name = self.con.dialect.identifier_preparer.quote(table.name)
+
+        for colname, type in self._metadata(quoted_name):
+            if colname in nulltype_cols:
+                # replace null types discovered by sqlalchemy with non null
+                # types
+                table.append_column(
+                    sa.Column(colname, to_sqla_type(type), nullable=type.nullable),
+                    replace_existing=True,
+                )
+        return table

     def _sqla_table_to_expr(self, table: sa.Table) -> ir.Table:
         schema = self._schemas.get(table.name)
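The hook added above is a detect-and-patch pass: reflect the table, collect the columns that came back as `NullType`, then overwrite them with types obtained from the backend's own `_metadata` query. A minimal, self-contained sketch of the same pattern against an in-memory SQLite database (SQLAlchemy 1.4 style; the hard-coded TEXT lookup is a stand-in for `_metadata`, not ibis code):

import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
with engine.begin() as con:
    # column `b` is declared without a type, so reflection yields NullType
    con.exec_driver_sql("CREATE TABLE t (a INTEGER, b)")

table = sa.Table("t", sa.MetaData(), autoload_with=engine)

# detect: columns whose type inference failed come back as NullType
nulltype_cols = frozenset(
    col.name for col in table.c if isinstance(col.type, sa.types.NullType)
)
assert nulltype_cols == {"b"}

# patch: pretend an out-of-band metadata query reported `b` as TEXT
for colname, satype in [("b", sa.TEXT())]:
    if colname in nulltype_cols:
        table.append_column(sa.Column(colname, satype), replace_existing=True)

assert isinstance(table.c.b.type, sa.TEXT)

The real method additionally suppresses SQLAlchemy's "Did not recognize type" warning during reflection and refreshes the table via `inspector.reflect_table` before patching.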
53 changes: 2 additions & 51 deletions ibis/backends/duckdb/__init__.py
@@ -14,7 +14,6 @@

 import ibis.expr.datatypes as dt
 from ibis import util
-from ibis.backends.base.sql.alchemy.datatypes import to_sqla_type

 if TYPE_CHECKING:
     import duckdb
@@ -371,22 +370,7 @@ def read_postgres(self, uri, table_name=None):
         return self._read(table_name)

     def _read(self, table_name):
-        _table = self.table(table_name)
-        with warnings.catch_warnings():
-            # don't fail or warn if duckdb-engine fails to discover types
-            # mostly (tinyint)
-            warnings.filterwarnings(
-                "ignore",
-                message="Did not recognize type",
-                category=sa.exc.SAWarning,
-            )
-            # We don't rely on index reflection, ignore this warning
-            warnings.filterwarnings(
-                "ignore",
-                message="duckdb-engine doesn't yet support reflection on indices",
-            )
-            self.inspector.reflect_table(_table.op().sqla_table, _table.columns)
         return self.table(table_name)

     def to_pyarrow_batches(
@@ -476,48 +460,15 @@ def _register_in_memory_table(self, table_op):
self.con.execute("register", (table_op.name, df))

def _get_sqla_table(
self,
name: str,
schema: str | None = None,
**kwargs: Any,
self, name: str, schema: str | None = None, **kwargs: Any
) -> sa.Table:
with warnings.catch_warnings():
# don't fail or warn if duckdb-engine fails to discover types
warnings.filterwarnings(
"ignore",
message="Did not recognize type",
category=sa.exc.SAWarning,
)
# We don't rely on index reflection, ignore this warning
warnings.filterwarnings(
"ignore",
message="duckdb-engine doesn't yet support reflection on indices",
)

table = super()._get_sqla_table(name, schema, **kwargs)

nulltype_cols = frozenset(
col.name for col in table.c if isinstance(col.type, sa.types.NullType)
)

if not nulltype_cols:
return table

quoted_name = self.con.dialect.identifier_preparer.quote(name)

for colname, type in self._metadata(quoted_name):
if colname in nulltype_cols:
# replace null types discovered by sqlalchemy with non null
# types
table.append_column(
sa.Column(
colname,
to_sqla_type(type),
nullable=type.nullable,
),
replace_existing=True,
)
return table
return super()._get_sqla_table(name, schema, **kwargs)

def _get_temp_view_definition(
self,
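With the NullType repair hoisted into the base class, the duckdb override shrinks to its one engine-specific concern: silencing the index-reflection warning before delegating to `super()`. Note that `warnings.filterwarnings` treats `message` as a regular expression matched against the start of the warning text, so the filter is narrowly scoped; a standalone illustration (the second warning string is invented for contrast):

import warnings

with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        message="duckdb-engine doesn't yet support reflection on indices",
    )
    # silenced: matches the filter's message regex
    warnings.warn("duckdb-engine doesn't yet support reflection on indices")
    # still emitted: unrelated warnings pass through the filter
    warnings.warn("something else worth hearing about")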
34 changes: 16 additions & 18 deletions ibis/backends/mssql/__init__.py
@@ -2,18 +2,17 @@

 from __future__ import annotations

-import atexit
 import contextlib
-from typing import TYPE_CHECKING, Iterable, Literal
+from typing import TYPE_CHECKING, Literal

 import sqlalchemy as sa

 from ibis.backends.base.sql.alchemy import BaseAlchemyBackend
 from ibis.backends.mssql.compiler import MsSqlCompiler
-from ibis.backends.mssql.datatypes import _type_from_result_set_info
+from ibis.backends.mssql.datatypes import _FieldDescription, _type_from_result_set_info

 if TYPE_CHECKING:
-    import ibis.expr.datatypes as dt
+    pass


 class Backend(BaseAlchemyBackend):
@@ -54,25 +53,24 @@ def begin(self):
         finally:
             bind.execute(f"SET DATEFIRST {previous_datefirst}")

-    def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
+    def _metadata(self, query):
+        if query in self.list_tables():
+            query = f"SELECT * FROM [{query}]"
+
         with self.begin() as bind:
-            for column in bind.execute(
-                f"EXEC sp_describe_first_result_set @tsql = N'{query}';"
-            ).mappings():
-                yield column["name"], _type_from_result_set_info(column)
+            result_set_info: list[_FieldDescription] = (
+                bind.execute(f"EXEC sp_describe_first_result_set @tsql = N'{query}';")
+                .mappings()
+                .fetchall()
+            )
+        return [
+            (column['name'], _type_from_result_set_info(column))
+            for column in result_set_info
+        ]

     def _get_temp_view_definition(
         self,
         name: str,
         definition: sa.sql.compiler.Compiled,
     ) -> str:
         return f"CREATE OR ALTER VIEW {name} AS {definition}"
-
-    def _register_temp_view_cleanup(self, name: str, raw_name: str) -> None:
-        query = f"DROP VIEW IF EXISTS {name}"
-
-        def drop(self, raw_name: str, query: str):
-            self.con.execute(query)
-            self._temp_views.discard(raw_name)
-
-        atexit.register(drop, self, raw_name, query)
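A side note on the `_metadata` rewrite above: the old version was a generator, so the `sp_describe_first_result_set` call only ran once a caller iterated, and the `begin()` context stayed checked out across iteration; the new version materializes the rows while the connection is open and returns a plain list. A toy sketch (not ibis code) of the difference in connection lifetime:

from contextlib import contextmanager

state = {"open_connections": 0}

@contextmanager
def begin():
    # stand-in for the backend's connection/transaction context
    state["open_connections"] += 1
    try:
        yield
    finally:
        state["open_connections"] -= 1

def metadata_lazy():  # old shape: generator
    with begin():
        for name in ("a", "b"):
            yield name

def metadata_eager():  # new shape: materialize, then return
    with begin():
        rows = [name for name in ("a", "b")]
    return rows

metadata_eager()
assert state["open_connections"] == 0  # released before returning

gen = metadata_lazy()
next(gen)
assert state["open_connections"] == 1  # still held mid-iteration
list(gen)  # only exhausting the generator releases it
assert state["open_connections"] == 0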
