Skip to content

Commit

Permalink
fix(trino): support trino 0.323 special tuple type for struct results
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed May 9, 2023
1 parent 89aecf2 commit ea1529d
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 9 deletions.
14 changes: 13 additions & 1 deletion ibis/backends/trino/__init__.py
Expand Up @@ -7,12 +7,13 @@

import sqlalchemy as sa
import toolz
from trino.sqlalchemy.datatype import ROW as _ROW

import ibis.expr.datatypes as dt
from ibis import util
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend
from ibis.backends.trino.compiler import TrinoSQLCompiler
from ibis.backends.trino.datatypes import parse
from ibis.backends.trino.datatypes import ROW, parse


class Backend(BaseAlchemyBackend):
Expand Down Expand Up @@ -76,6 +77,17 @@ def do_connect(
)
)

@staticmethod
def _new_sa_metadata():
meta = sa.MetaData()

@sa.event.listens_for(meta, "column_reflect")
def column_reflect(inspector, table, column_info):
if isinstance(typ := column_info["type"], _ROW):
column_info["type"] = ROW(typ.attr_types)

return meta

def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]:
tmpname = f"_ibis_trino_output_{util.guid()[:6]}"
with self.begin() as con:
Expand Down
27 changes: 23 additions & 4 deletions ibis/backends/trino/datatypes.py
@@ -1,11 +1,14 @@
from __future__ import annotations

from functools import partial
from typing import Any

import parsy
import sqlalchemy as sa
import trino.client
from sqlalchemy.ext.compiler import compiles
from trino.sqlalchemy.datatype import DOUBLE, JSON, MAP, ROW, TIMESTAMP
from trino.sqlalchemy.datatype import DOUBLE, JSON, MAP, TIMESTAMP
from trino.sqlalchemy.datatype import ROW as _ROW
from trino.sqlalchemy.dialect import TrinoDialect

import ibis.expr.datatypes as dt
Expand All @@ -23,6 +26,24 @@
)


class ROW(_ROW):
_result_is_tuple = hasattr(trino.client, "NamedRowTuple")

def result_processor(self, dialect, coltype: str) -> None:
if not coltype.lower().startswith("row"):
return None

def process(
value, result_is_tuple: bool = self._result_is_tuple
) -> dict[str, Any] | None:
if value is None or not result_is_tuple:
return value
else:
return dict(zip(value._names, value))

return process


def parse(text: str, default_decimal_parameters=(18, 3)) -> dt.DataType:
"""Parse a Trino type into an ibis data type."""

Expand Down Expand Up @@ -127,9 +148,7 @@ def _string(_, itype):

@to_sqla_type.register(TrinoDialect, dt.Struct)
def _struct(dialect, itype):
return ROW(
[(name, to_sqla_type(dialect, typ)) for name, typ in itype.fields.items()]
)
return ROW((name, to_sqla_type(dialect, typ)) for name, typ in itype.fields.items())


@to_sqla_type.register(TrinoDialect, dt.Timestamp)
Expand Down
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion requirements.txt

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit ea1529d

Please sign in to comment.