201 changes: 103 additions & 98 deletions ibis/backends/bigquery/datatypes.py
@@ -4,6 +4,7 @@

import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis.formats import SchemaMapper, TypeMapper

_from_bigquery_types = {
"INT64": dt.Int64,
@@ -20,109 +21,113 @@
}


def dtype_from_bigquery(typ: str, nullable=True) -> dt.DataType:
if typ == "DATETIME":
return dt.Timestamp(timezone=None, nullable=nullable)
elif typ == "TIMESTAMP":
return dt.Timestamp(timezone="UTC", nullable=nullable)
elif typ == "NUMERIC":
return dt.Decimal(38, 9, nullable=nullable)
elif typ == "BIGNUMERIC":
return dt.Decimal(76, 38, nullable=nullable)
elif typ == "GEOGRAPHY":
return dt.GeoSpatial(geotype="geography", srid=4326, nullable=nullable)
else:
try:
return _from_bigquery_types[typ](nullable=nullable)
except KeyError:
raise TypeError(f"Unable to convert BigQuery type to ibis: {typ}")


def dtype_from_bigquery_field(field: bq.SchemaField) -> dt.DataType:
typ = field.field_type
if typ == "RECORD":
assert field.fields, "RECORD fields are empty"
fields = {f.name: dtype_from_bigquery_field(f) for f in field.fields}
dtype = dt.Struct(fields)
else:
dtype = dtype_from_bigquery(typ)

mode = field.mode
if mode == "NULLABLE":
return dtype.copy(nullable=True)
elif mode == "REQUIRED":
return dtype.copy(nullable=False)
elif mode == "REPEATED":
return dt.Array(dtype)
else:
raise TypeError(f"Unknown BigQuery field.mode: {mode}")


def dtype_to_bigquery(dtype: dt.DataType) -> str:
if dtype.is_floating():
return "FLOAT64"
elif dtype.is_uint64():
raise TypeError(
"Conversion from uint64 to BigQuery integer type (int64) is lossy"
)
elif dtype.is_integer():
return "INT64"
elif dtype.is_binary():
return "BYTES"
elif dtype.is_date():
return "DATE"
elif dtype.is_timestamp():
if dtype.timezone is None:
return "DATETIME"
elif dtype.timezone == 'UTC':
return "TIMESTAMP"
class BigQueryType(TypeMapper):
@classmethod
def to_ibis(cls, typ: str, nullable: bool = True) -> dt.DataType:
if typ == "DATETIME":
return dt.Timestamp(timezone=None, nullable=nullable)
elif typ == "TIMESTAMP":
return dt.Timestamp(timezone="UTC", nullable=nullable)
elif typ == "NUMERIC":
return dt.Decimal(38, 9, nullable=nullable)
elif typ == "BIGNUMERIC":
return dt.Decimal(76, 38, nullable=nullable)
elif typ == "GEOGRAPHY":
return dt.GeoSpatial(geotype="geography", srid=4326, nullable=nullable)
else:
try:
return _from_bigquery_types[typ](nullable=nullable)
except KeyError:
raise TypeError(f"Unable to convert BigQuery type to ibis: {typ}")

@classmethod
def from_ibis(cls, dtype: dt.DataType) -> str:
if dtype.is_floating():
return "FLOAT64"
elif dtype.is_uint64():
raise TypeError(
"Conversion from uint64 to BigQuery integer type (int64) is lossy"
)
elif dtype.is_integer():
return "INT64"
elif dtype.is_binary():
return "BYTES"
elif dtype.is_date():
return "DATE"
elif dtype.is_timestamp():
if dtype.timezone is None:
return "DATETIME"
elif dtype.timezone == 'UTC':
return "TIMESTAMP"
else:
raise TypeError(
"BigQuery does not support timestamps with timezones other than 'UTC'"
)
elif dtype.is_decimal():
if (dtype.precision, dtype.scale) == (76, 38):
return 'BIGNUMERIC'
if (dtype.precision, dtype.scale) in [(38, 9), (None, None)]:
return "NUMERIC"
raise TypeError(
"BigQuery only supports decimal types with precision of 38 and "
f"scale of 9 (NUMERIC) or precision of 76 and scale of 38 (BIGNUMERIC). "
f"Current precision: {dtype.precision}. Current scale: {dtype.scale}"
)
elif dtype.is_array():
return f"ARRAY<{cls.from_ibis(dtype.value_type)}>"
elif dtype.is_struct():
fields = (f"{k} {cls.from_ibis(v)}" for k, v in dtype.fields.items())
return "STRUCT<{}>".format(", ".join(fields))
elif dtype.is_json():
return "JSON"
elif dtype.is_geospatial():
if (dtype.geotype, dtype.srid) == ("geography", 4326):
return "GEOGRAPHY"
raise TypeError(
"BigQuery geography uses points on WGS84 reference ellipsoid."
f"Current geotype: {dtype.geotype}, Current srid: {dtype.srid}"
)
elif dtype.is_decimal():
if (dtype.precision, dtype.scale) == (76, 38):
return 'BIGNUMERIC'
if (dtype.precision, dtype.scale) in [(38, 9), (None, None)]:
return "NUMERIC"
raise TypeError(
"BigQuery only supports decimal types with precision of 38 and "
f"scale of 9 (NUMERIC) or precision of 76 and scale of 38 (BIGNUMERIC). "
f"Current precision: {dtype.precision}. Current scale: {dtype.scale}"
)
elif dtype.is_array():
return f"ARRAY<{dtype_to_bigquery(dtype.value_type)}>"
elif dtype.is_struct():
fields = (f"{k} {dtype_to_bigquery(v)}" for k, v in dtype.fields.items())
return "STRUCT<{}>".format(", ".join(fields))
elif dtype.is_json():
return "JSON"
elif dtype.is_geospatial():
if (dtype.geotype, dtype.srid) == ("geography", 4326):
return "GEOGRAPHY"
raise TypeError(
"BigQuery geography uses points on WGS84 reference ellipsoid."
f"Current geotype: {dtype.geotype}, Current srid: {dtype.srid}"
)
else:
return str(dtype).upper()


def schema_to_bigquery(schema: sch.Schema) -> list[bq.SchemaField]:
result = []
for name, dtype in schema.items():
if isinstance(dtype, dt.Array):
mode = "REPEATED"
dtype = dtype.value_type
else:
mode = "REQUIRED" if not dtype.nullable else "NULLABLE"
field = bq.SchemaField(name, dtype_to_bigquery(dtype), mode=mode)
result.append(field)
return result

return str(dtype).upper()


class BigQuerySchema(SchemaMapper):
@classmethod
def from_ibis(cls, schema: sch.Schema) -> list[bq.SchemaField]:
result = []
for name, dtype in schema.items():
if isinstance(dtype, dt.Array):
mode = "REPEATED"
dtype = dtype.value_type
else:
mode = "REQUIRED" if not dtype.nullable else "NULLABLE"
field = bq.SchemaField(name, BigQueryType.from_ibis(dtype), mode=mode)
result.append(field)
return result

@classmethod
def _dtype_from_bigquery_field(cls, field: bq.SchemaField) -> dt.DataType:
typ = field.field_type
if typ == "RECORD":
assert field.fields, "RECORD fields are empty"
fields = {f.name: cls._dtype_from_bigquery_field(f) for f in field.fields}
dtype = dt.Struct(fields)
else:
dtype = BigQueryType.to_ibis(typ)

mode = field.mode
if mode == "NULLABLE":
return dtype.copy(nullable=True)
elif mode == "REQUIRED":
return dtype.copy(nullable=False)
elif mode == "REPEATED":
return dt.Array(dtype)
else:
raise TypeError(f"Unknown BigQuery field.mode: {mode}")

def schema_from_bigquery(fields: list[bq.SchemaField]) -> sch.Schema:
return sch.Schema({f.name: dtype_from_bigquery_field(f) for f in fields})
@classmethod
def to_ibis(cls, fields: list[bq.SchemaField]) -> sch.Schema:
return sch.Schema({f.name: cls._dtype_from_bigquery_field(f) for f in fields})


# TODO(kszucs): we can eliminate this function by making dt.DataType traversible
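A quick usage sketch of the new class-based mapper (not part of the diff; assumes ibis is installed with the BigQuery extra so `google.cloud.bigquery` is importable):

```python
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis.backends.bigquery.datatypes import BigQuerySchema, BigQueryType

# Scalar types round-trip through the paired classmethods.
assert BigQueryType.from_ibis(dt.Decimal(38, 9)) == "NUMERIC"
assert BigQueryType.to_ibis("BIGNUMERIC") == dt.Decimal(76, 38)

# Arrays become REPEATED fields rather than a separate ARRAY field type.
(field,) = BigQuerySchema.from_ibis(sch.Schema({"xs": dt.Array(dt.int64)}))
assert field.mode == "REPEATED" and field.field_type == "INT64"
```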
8 changes: 4 additions & 4 deletions ibis/backends/bigquery/registry.py
@@ -20,7 +20,7 @@
reduction,
unary,
)
from ibis.backends.bigquery.datatypes import dtype_to_bigquery
from ibis.backends.bigquery.datatypes import BigQueryType
from ibis.common.temporal import DateUnit, IntervalUnit, TimeUnit

if TYPE_CHECKING:
@@ -67,7 +67,7 @@ def bigquery_cast_floating_to_integer(compiled_arg, from_, to):
@bigquery_cast.register(str, dt.DataType, dt.DataType)
def bigquery_cast_generate(compiled_arg, from_, to):
"""Cast to desired type."""
sql_type = dtype_to_bigquery(to)
sql_type = BigQueryType.from_ibis(to)
return f"CAST({compiled_arg} AS {sql_type})"


Expand Down Expand Up @@ -253,7 +253,7 @@ def _literal(translator, op):
prefix = "-" * value.is_signed()
return f"CAST('{prefix}inf' AS FLOAT64)"
else:
return f"{dtype_to_bigquery(dtype)} '{value}'"
return f"{BigQueryType.from_ibis(dtype)} '{value}'"
elif dtype.is_uuid():
return translator.translate(ops.Literal(str(value), dtype=dt.str))

Expand Down Expand Up @@ -446,7 +446,7 @@ def compiles_string_to_timestamp(translator, op):


def compiles_floor(t, op):
bigquery_type = dtype_to_bigquery(op.output_dtype)
bigquery_type = BigQueryType.from_ibis(op.output_dtype)
arg = op.arg
return f"CAST(FLOOR({t.translate(arg)}) AS {bigquery_type})"

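The registry changes are mechanical renames; the only behavior they rely on is string rendering of target types. A minimal sketch of what `bigquery_cast_generate` produces (illustrative helper, not the actual registered function):

```python
import ibis.expr.datatypes as dt
from ibis.backends.bigquery.datatypes import BigQueryType

def bigquery_cast_sql(compiled_arg: str, to: dt.DataType) -> str:
    # Mirrors bigquery_cast_generate above: render the type, wrap in CAST.
    return f"CAST({compiled_arg} AS {BigQueryType.from_ibis(to)})"

assert bigquery_cast_sql("t.`x`", dt.float64) == "CAST(t.`x` AS FLOAT64)"
```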
6 changes: 3 additions & 3 deletions ibis/backends/bigquery/tests/conftest.py
@@ -16,7 +16,7 @@
import ibis.expr.datatypes as dt
import ibis.selectors as s
from ibis.backends.bigquery import EXTERNAL_DATA_SCOPES, Backend
from ibis.backends.bigquery.datatypes import dtype_to_bigquery
from ibis.backends.bigquery.datatypes import BigQueryType
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.tests.base import BackendTest, RoundAwayFromZero, UnorderedComparator
from ibis.backends.tests.data import json_types, non_null_array_types, struct_types, win
@@ -41,13 +41,13 @@ def ibis_type_to_bq_field(typ: dt.DataType) -> Mapping[str, Any]:

@ibis_type_to_bq_field.register(dt.DataType)
def _(typ: dt.DataType) -> Mapping[str, Any]:
return {"field_type": dtype_to_bigquery(typ)}
return {"field_type": BigQueryType.from_ibis(typ)}


@ibis_type_to_bq_field.register(dt.Array)
def _(typ: dt.Array) -> Mapping[str, Any]:
return {
"field_type": dtype_to_bigquery(typ.value_type),
"field_type": BigQueryType.from_ibis(typ.value_type),
"mode": "REPEATED",
}

6 changes: 3 additions & 3 deletions ibis/backends/bigquery/tests/unit/test_datatypes.py
@@ -3,7 +3,7 @@

import ibis.expr.datatypes as dt
from ibis.backends.bigquery.datatypes import (
dtype_to_bigquery,
BigQueryType,
spread_type,
)

@@ -67,13 +67,13 @@
],
)
def test_simple(datatype, expected):
assert dtype_to_bigquery(datatype) == expected
assert BigQueryType.from_ibis(datatype) == expected


@pytest.mark.parametrize("datatype", [dt.uint64, dt.Decimal(8, 3)])
def test_simple_failure_mode(datatype):
with pytest.raises(TypeError):
dtype_to_bigquery(datatype)
BigQueryType.from_ibis(datatype)


@pytest.mark.parametrize(
10 changes: 5 additions & 5 deletions ibis/backends/bigquery/udf/__init__.py
@@ -7,7 +7,7 @@

import ibis.expr.datatypes as dt
import ibis.expr.rules as rlz
from ibis.backends.bigquery.datatypes import dtype_to_bigquery, spread_type
from ibis.backends.bigquery.datatypes import BigQueryType, spread_type
from ibis.backends.bigquery.operations import BigQueryUDFNode
from ibis.backends.bigquery.udf.core import PythonToJavaScriptTranslator
from ibis.udf.validate import validate_output_type
@@ -286,11 +286,11 @@ def compiles_udf_node(t, op):
bigquery_signature = ", ".join(
"{name} {type}".format(
name=name,
type=dtype_to_bigquery(dt.dtype(type_)),
type=BigQueryType.from_ibis(dt.dtype(type_)),
)
for name, type_ in params.items()
)
return_type = dtype_to_bigquery(dt.dtype(output_type))
return_type = BigQueryType.from_ibis(dt.dtype(output_type))
libraries_opts = (
f"\nOPTIONS (\n library={list(libraries)!r}\n)" if libraries else ""
)
@@ -369,7 +369,7 @@ def sql(
for name, type_ in params.items()
}

return_type = dtype_to_bigquery(dt.dtype(output_type))
return_type = BigQueryType.from_ibis(dt.dtype(output_type))

udf_node_fields["output_dtype"] = output_type
udf_node_fields["output_shape"] = rlz.shape_like("args")
@@ -389,7 +389,7 @@ def compiles_udf_node(t, op):
name=name,
type="ANY TYPE"
if type_ == "ANY TYPE"
else dtype_to_bigquery(dt.dtype(type_)),
else BigQueryType.from_ibis(dt.dtype(type_)),
)
for name, type_ in params.items()
)
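The UDF signature rendering these call sites depend on can be checked in isolation (sketch, assumes the BigQuery extra is installed):

```python
import ibis.expr.datatypes as dt
from ibis.backends.bigquery.datatypes import BigQueryType

params = {"x": dt.int64, "xs": dt.Array(dt.float64)}
signature = ", ".join(
    f"{name} {BigQueryType.from_ibis(typ)}" for name, typ in params.items()
)
assert signature == "x INT64, xs ARRAY<FLOAT64>"
```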
5 changes: 3 additions & 2 deletions ibis/backends/clickhouse/__init__.py
@@ -24,6 +24,7 @@
from ibis.backends.base import BaseBackend
from ibis.backends.clickhouse.compiler import translate
from ibis.backends.clickhouse.datatypes import parse, serialize
from ibis.formats.pandas import PandasData

if TYPE_CHECKING:
import pandas as pd
@@ -391,7 +392,7 @@ def execute(
if df.empty:
df = pd.DataFrame(columns=schema.names)

result = self._pandas_converter.convert_frame(df, schema)
result = PandasData.convert_table(df, schema)
if isinstance(expr, ir.Scalar):
return result.iat[0, 0]
elif isinstance(expr, ir.Column):
@@ -473,7 +474,7 @@ def fetch_from_cursor(self, cursor, schema):
import pandas as pd

df = pd.DataFrame.from_records(iter(cursor), columns=schema.names)
return self._pandas_converter.convert_frame(df, schema)
return PandasData.convert_table(df, schema)

def close(self) -> None:
"""Close ClickHouse connection."""
4 changes: 2 additions & 2 deletions ibis/backends/dask/__init__.py
@@ -17,7 +17,7 @@
from ibis.backends.dask.core import execute_and_reset
from ibis.backends.pandas import BasePandasBackend
from ibis.backends.pandas.core import _apply_schema
from ibis.formats.pandas import schema_from_dask_dataframe
from ibis.formats.pandas import DaskData

# Make sure that the pandas backend options have been loaded
ibis.pandas # noqa: B018
@@ -113,7 +113,7 @@ def compile(
def table(self, name: str, schema: sch.Schema = None):
df = self.dictionary[name]
schema = schema or self.schemas.get(name, None)
schema = schema_from_dask_dataframe(df, schema=schema)
schema = DaskData.infer_table(df, schema=schema)
return ops.DatabaseTable(name, schema, self).to_expr()

@classmethod
Expand Down
14 changes: 7 additions & 7 deletions ibis/backends/datafusion/compiler.py
@@ -12,7 +12,7 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.formats.pyarrow import dtype_to_pyarrow
from ibis.formats.pyarrow import PyArrowType


@functools.singledispatch
Expand Down Expand Up @@ -43,7 +43,7 @@ def literal(op):
else:
value = op.value

arrow_type = dtype_to_pyarrow(op.dtype)
arrow_type = PyArrowType.from_ibis(op.dtype)
arrow_scalar = pa.scalar(value, type=arrow_type)

return df.literal(arrow_scalar)
Expand All @@ -52,7 +52,7 @@ def literal(op):
@translate.register(ops.Cast)
def cast(op):
arg = translate(op.arg)
typ = dtype_to_pyarrow(op.to)
typ = PyArrowType.from_ibis(op.to)
return arg.cast(to=typ)


@@ -458,8 +458,8 @@ def e(_):
def elementwise_udf(op):
udf = df.udf(
op.func,
input_types=list(map(dtype_to_pyarrow, op.input_type)),
return_type=dtype_to_pyarrow(op.return_type),
input_types=list(map(PyArrowType.from_ibis, op.input_type)),
return_type=PyArrowType.from_ibis(op.return_type),
volatility="volatile",
)
args = map(translate, op.func_args)
@@ -504,8 +504,8 @@ def regex_extract(op):
)
string_array_get = df.udf(
lambda arr, index=index: pc.list_element(arr, index),
input_types=[dtype_to_pyarrow(dt.Array(dt.string))],
return_type=dtype_to_pyarrow(dt.string),
input_types=[PyArrowType.from_ibis(dt.Array(dt.string))],
return_type=PyArrowType.from_ibis(dt.string),
volatility="immutable",
name="string_array_get",
)
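All datafusion call sites now route through `PyArrowType.from_ibis`; a sketch of the mapping they depend on (assumes pyarrow, with behavior unchanged from the old free function):

```python
import pyarrow as pa

import ibis.expr.datatypes as dt
from ibis.formats.pyarrow import PyArrowType

# The UDF registrations above hand these Arrow types to datafusion.
assert PyArrowType.from_ibis(dt.float64) == pa.float64()
assert PyArrowType.from_ibis(dt.Array(dt.string)) == pa.list_(pa.string())
```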
7 changes: 4 additions & 3 deletions ibis/backends/duckdb/__init__.py
@@ -23,8 +23,9 @@
from ibis import util
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend
from ibis.backends.duckdb.compiler import DuckDBSQLCompiler
from ibis.backends.duckdb.datatypes import dtype_to_duckdb, parse
from ibis.backends.duckdb.datatypes import DuckDBType, parse
from ibis.expr.operations.relations import PandasDataFrameProxy
from ibis.formats.pandas import PandasData

if TYPE_CHECKING:
import pandas as pd
@@ -114,7 +115,7 @@ def column_reflect(inspector, table, column_info):
_, typ, *_ = con.connection.fetchone()
complex_type_info_cache[colname] = coltype = parse(typ)

column_info["type"] = dtype_to_duckdb(coltype)
column_info["type"] = DuckDBType.from_ibis(coltype)

return meta

@@ -844,7 +845,7 @@ def fetch_from_cursor(
for name, col in zip(table.column_names, table.columns)
}
)
return self._pandas_converter.convert_frame(df, schema)
return PandasData.convert_table(df, schema)

def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]:
with self.begin() as con:
6 changes: 2 additions & 4 deletions ibis/backends/duckdb/compiler.py
@@ -5,7 +5,7 @@
import ibis.backends.base.sql.alchemy.datatypes as sat
import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.duckdb.datatypes import dtype_from_duckdb, dtype_to_duckdb
from ibis.backends.duckdb.datatypes import DuckDBType
from ibis.backends.duckdb.registry import operation_registry


@@ -14,9 +14,7 @@ class DuckDBSQLExprTranslator(AlchemyExprTranslator):
_rewrites = AlchemyExprTranslator._rewrites.copy()
_has_reduction_filter_syntax = True
_dialect_name = "duckdb"

get_sqla_type = staticmethod(dtype_to_duckdb)
get_ibis_type = staticmethod(dtype_from_duckdb)
type_mapper = DuckDBType


@compiles(sat.UInt8, "duckdb")
Expand Down
33 changes: 15 additions & 18 deletions ibis/backends/duckdb/datatypes.py
@@ -6,10 +6,7 @@
import toolz

import ibis.expr.datatypes as dt
from ibis.backends.base.sql.alchemy.datatypes import (
dtype_from_sqlalchemy,
dtype_to_sqlalchemy,
)
from ibis.backends.base.sql.alchemy.datatypes import AlchemyType
from ibis.common.parsing import (
COMMA,
FIELD,
@@ -123,17 +120,17 @@ def parse(text: str, default_decimal_parameters=(18, 3)) -> dt.DataType:
}


def dtype_from_duckdb(typ, nullable=True):
if dtype := _from_duckdb_types.get(type(typ)):
return dtype(nullable=nullable)
else:
return dtype_from_sqlalchemy(
typ, nullable=nullable, converter=dtype_from_duckdb
)


def dtype_to_duckdb(dtype):
if typ := _to_duckdb_types.get(type(dtype)):
return typ
else:
return dtype_to_sqlalchemy(dtype, converter=dtype_to_duckdb)
class DuckDBType(AlchemyType):
@classmethod
def to_ibis(cls, typ, nullable=True):
if dtype := _from_duckdb_types.get(type(typ)):
return dtype(nullable=nullable)
else:
return super().to_ibis(typ, nullable=nullable)

@classmethod
def from_ibis(cls, dtype):
if typ := _to_duckdb_types.get(type(dtype)):
return typ
else:
return super().from_ibis(dtype)
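A sketch of the delegation pattern: DuckDB-specific overrides win, and everything else falls through to the generic `AlchemyType` base class (the TEXT-to-string mapping below is an assumption about that base implementation, not something shown in this diff):

```python
import sqlalchemy.types as sat

import ibis.expr.datatypes as dt
from ibis.backends.duckdb.datatypes import DuckDBType

# No DuckDB-specific entry exists for TEXT, so super().to_ibis handles it.
assert DuckDBType.to_ibis(sat.TEXT()) == dt.string
```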
3 changes: 2 additions & 1 deletion ibis/backends/impala/__init__.py
@@ -48,6 +48,7 @@
wrap_udf,
)
from ibis.config import options
from ibis.formats.pandas import PandasData

if TYPE_CHECKING:
from pathlib import Path
@@ -347,7 +348,7 @@ def fetch_from_cursor(self, cursor, schema):
names = [name for name, *_ in cursor.description]
df = _column_batches_to_dataframe(names, batches)
if schema:
return self._pandas_converter.convert_frame(df, schema)
return PandasData.convert_table(df, schema)
return df

@contextlib.contextmanager
6 changes: 2 additions & 4 deletions ibis/backends/mssql/compiler.py
@@ -4,7 +4,7 @@

import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.mssql.datatypes import dtype_from_mssql, dtype_to_mssql
from ibis.backends.mssql.datatypes import MSSQLType
from ibis.backends.mssql.registry import _timestamp_from_unix, operation_registry


@@ -24,9 +24,7 @@ class MsSqlExprTranslator(AlchemyExprTranslator):
)
_require_order_by = AlchemyExprTranslator._require_order_by + (ops.Reduction,)
_dialect_name = "mssql"

get_sqla_type = staticmethod(dtype_to_mssql)
get_ibis_type = staticmethod(dtype_from_mssql)
type_mapper = MSSQLType


rewrites = MsSqlExprTranslator.rewrites
62 changes: 30 additions & 32 deletions ibis/backends/mssql/datatypes.py
@@ -4,10 +4,7 @@
from sqlalchemy.dialects import mssql

import ibis.expr.datatypes as dt
from ibis.backends.base.sql.alchemy.datatypes import (
dtype_from_sqlalchemy,
dtype_to_sqlalchemy,
)
from ibis.backends.base.sql.alchemy.datatypes import AlchemyType


class _FieldDescription(TypedDict):
@@ -101,21 +98,6 @@ def _type_from_result_set_info(col: _FieldDescription) -> dt.DataType:
dt.String: mssql.NVARCHAR,
}


def dtype_to_mssql(dtype):
if typ := _to_mssql_types.get(type(dtype)):
return typ
elif dtype.is_timestamp():
if (precision := dtype.scale) is None:
precision = 7
if dtype.timezone is not None:
return mssql.DATETIMEOFFSET(precision=precision)
else:
return mssql.DATETIME2(precision=precision)
else:
return dtype_to_sqlalchemy(dtype, converter=dtype_to_mssql)


_from_mssql_types = {
mssql.TINYINT: dt.Int8,
mssql.BIT: dt.Boolean,
@@ -133,16 +115,32 @@ def dtype_to_mssql(dtype):
}


def dtype_from_mssql(typ, nullable=True):
if dtype := _from_mssql_types.get(type(typ)):
return dtype(nullable=nullable)
elif isinstance(typ, mssql.DATETIMEOFFSET):
if (prec := typ.precision) is None:
prec = 7
return dt.Timestamp(scale=prec, timezone="UTC", nullable=nullable)
elif isinstance(typ, mssql.DATETIME2):
if (prec := typ.precision) is None:
prec = 7
return dt.Timestamp(scale=prec, nullable=nullable)
else:
return dtype_from_sqlalchemy(typ, nullable=nullable, converter=dtype_from_mssql)
class MSSQLType(AlchemyType):
@classmethod
def to_ibis(cls, typ, nullable=True):
if dtype := _from_mssql_types.get(type(typ)):
return dtype(nullable=nullable)
elif isinstance(typ, mssql.DATETIMEOFFSET):
if (prec := typ.precision) is None:
prec = 7
return dt.Timestamp(scale=prec, timezone="UTC", nullable=nullable)
elif isinstance(typ, mssql.DATETIME2):
if (prec := typ.precision) is None:
prec = 7
return dt.Timestamp(scale=prec, nullable=nullable)
else:
return super().to_ibis(typ, nullable=nullable)

@classmethod
def from_ibis(cls, dtype):
if typ := _to_mssql_types.get(type(dtype)):
return typ
elif dtype.is_timestamp():
if (precision := dtype.scale) is None:
precision = 7
if dtype.timezone is not None:
return mssql.DATETIMEOFFSET(precision=precision)
else:
return mssql.DATETIME2(precision=precision)
else:
return super().from_ibis(dtype)
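A sketch of the timestamp handling this class preserves from the old free functions (assumes only the `sqlalchemy.dialects.mssql` types shown in the diff):

```python
from sqlalchemy.dialects import mssql

import ibis.expr.datatypes as dt
from ibis.backends.mssql.datatypes import MSSQLType

# Timezone-aware timestamps map to DATETIMEOFFSET, naive ones to DATETIME2;
# an unspecified scale defaults to MSSQL's maximum precision of 7.
typ = MSSQLType.from_ibis(dt.Timestamp(timezone="UTC"))
assert isinstance(typ, mssql.DATETIMEOFFSET) and typ.precision == 7

assert MSSQLType.to_ibis(mssql.DATETIME2(precision=3)) == dt.Timestamp(scale=3)
```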
6 changes: 2 additions & 4 deletions ibis/backends/mysql/compiler.py
@@ -3,7 +3,7 @@
import sqlalchemy as sa

from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.mysql.datatypes import dtype_from_mysql, dtype_to_mysql
from ibis.backends.mysql.datatypes import MySQLType
from ibis.backends.mysql.registry import operation_registry


@@ -14,9 +14,7 @@ class MySQLExprTranslator(AlchemyExprTranslator):
_integer_to_timestamp = sa.func.from_unixtime
native_json_type = False
_dialect_name = "mysql"

get_sqla_type = staticmethod(dtype_to_mysql)
get_ibis_type = staticmethod(dtype_from_mysql)
type_mapper = MySQLType


rewrites = MySQLExprTranslator.rewrites
Expand Down
65 changes: 32 additions & 33 deletions ibis/backends/mysql/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
from sqlalchemy.dialects import mysql

import ibis.expr.datatypes as dt
from ibis.backends.base.sql.alchemy.datatypes import (
dtype_from_sqlalchemy,
dtype_to_sqlalchemy,
)
from ibis.backends.base.sql.alchemy.datatypes import AlchemyType

# binary character set
# used to distinguish blob binary vs blob text
@@ -220,33 +217,35 @@ def result_processor(self, *_):
}


def dtype_to_mysql(dtype):
try:
return _to_mysql_types[type(dtype)]
except KeyError:
return dtype_to_sqlalchemy(dtype, converter=dtype_to_mysql)


def dtype_from_mysql(typ, nullable=True):
if isinstance(typ, (sat.NUMERIC, mysql.NUMERIC, mysql.DECIMAL)):
# https://dev.mysql.com/doc/refman/8.0/en/fixed-point-types.html
return dt.Decimal(typ.precision or 10, typ.scale or 0, nullable=nullable)
elif isinstance(typ, mysql.BIT):
if 1 <= (length := typ.length) <= 8:
return dt.Int8(nullable=nullable)
elif 9 <= length <= 16:
return dt.Int16(nullable=nullable)
elif 17 <= length <= 32:
return dt.Int32(nullable=nullable)
elif 33 <= length <= 64:
return dt.Int64(nullable=nullable)
class MySQLType(AlchemyType):
@classmethod
def from_ibis(cls, dtype):
try:
return _to_mysql_types[type(dtype)]
except KeyError:
return super().from_ibis(dtype)

@classmethod
def to_ibis(cls, typ, nullable=True):
if isinstance(typ, (sat.NUMERIC, mysql.NUMERIC, mysql.DECIMAL)):
# https://dev.mysql.com/doc/refman/8.0/en/fixed-point-types.html
return dt.Decimal(typ.precision or 10, typ.scale or 0, nullable=nullable)
elif isinstance(typ, mysql.BIT):
if 1 <= (length := typ.length) <= 8:
return dt.Int8(nullable=nullable)
elif 9 <= length <= 16:
return dt.Int16(nullable=nullable)
elif 17 <= length <= 32:
return dt.Int32(nullable=nullable)
elif 33 <= length <= 64:
return dt.Int64(nullable=nullable)
else:
raise ValueError(f"Invalid MySQL BIT length: {length:d}")
elif isinstance(typ, mysql.TIMESTAMP):
return dt.Timestamp(timezone="UTC", nullable=nullable)
elif isinstance(typ, mysql.SET):
return dt.Set(dt.string, nullable=nullable)
elif dtype := _from_mysql_types[type(typ)]:
return dtype(nullable=nullable)
else:
raise ValueError(f"Invalid MySQL BIT length: {length:d}")
elif isinstance(typ, mysql.TIMESTAMP):
return dt.Timestamp(timezone="UTC", nullable=nullable)
elif isinstance(typ, mysql.SET):
return dt.Set(dt.string, nullable=nullable)
elif dtype := _from_mysql_types[type(typ)]:
return dtype(nullable=nullable)
else:
return dtype_from_sqlalchemy(dtype, converter=dtype_from_mysql)
return super().to_ibis(typ, nullable=nullable)
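A sketch of the BIT widening rule carried over into `MySQLType.to_ibis`:

```python
from sqlalchemy.dialects import mysql

import ibis.expr.datatypes as dt
from ibis.backends.mysql.datatypes import MySQLType

# BIT(n) widens to the smallest signed integer type that can hold n bits.
assert MySQLType.to_ibis(mysql.BIT(1)) == dt.Int8()
assert MySQLType.to_ibis(mysql.BIT(24)) == dt.Int32()

# MySQL TIMESTAMP columns are always UTC on the wire.
assert MySQLType.to_ibis(mysql.TIMESTAMP()) == dt.Timestamp(timezone="UTC")
```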
6 changes: 2 additions & 4 deletions ibis/backends/oracle/__init__.py
@@ -33,8 +33,7 @@
BaseAlchemyBackend,
)
from ibis.backends.oracle.datatypes import ( # noqa: E402
dtype_from_oracle,
dtype_to_oracle,
OracleType,
parse,
)
from ibis.backends.oracle.registry import operation_registry # noqa: E402
@@ -64,8 +63,7 @@ class OracleExprTranslator(AlchemyExprTranslator):
_quote_column_names = True
_quote_table_names = True

get_sqla_type = staticmethod(dtype_to_oracle)
get_ibis_type = staticmethod(dtype_from_oracle)
type_mapper = OracleType


class OracleCompiler(AlchemyCompiler):
55 changes: 27 additions & 28 deletions ibis/backends/oracle/datatypes.py
@@ -7,41 +7,40 @@
from sqlalchemy.dialects import oracle

import ibis.expr.datatypes as dt
from ibis.backends.base.sql.alchemy.datatypes import (
dtype_from_sqlalchemy,
dtype_to_sqlalchemy,
)
from ibis.backends.base.sql.alchemy.datatypes import AlchemyType

if TYPE_CHECKING:
from oracle.base_impl import DbType


def dtype_from_oracle(typ, nullable=True):
if isinstance(typ, oracle.ROWID):
return dt.String(nullable=nullable)
elif isinstance(typ, sat.Float):
return dt.Float64(nullable=nullable)
elif isinstance(typ, sat.Numeric):
if typ.scale == 0:
# kind of a lie, should be int128 because 38 digits
return dt.Int64(nullable=nullable)
class OracleType(AlchemyType):
@classmethod
def to_ibis(cls, typ, nullable=True):
if isinstance(typ, oracle.ROWID):
return dt.String(nullable=nullable)
elif isinstance(typ, sat.Float):
return dt.Float64(nullable=nullable)
elif isinstance(typ, sat.Numeric):
if typ.scale == 0:
# kind of a lie, should be int128 because 38 digits
return dt.Int64(nullable=nullable)
else:
return dt.Decimal(
precision=typ.precision or 38,
scale=typ.scale or 0,
nullable=nullable,
)
else:
return dt.Decimal(
precision=typ.precision or 38,
scale=typ.scale or 0,
nullable=nullable,
)
else:
return dtype_from_sqlalchemy(typ, converter=dtype_from_oracle)
return super().to_ibis(typ, nullable=nullable)


def dtype_to_oracle(dtype):
if isinstance(dtype, dt.Float64):
return sat.Float(precision=53).with_variant(oracle.FLOAT(14), 'oracle')
elif isinstance(dtype, dt.Float32):
return sat.Float(precision=23).with_variant(oracle.FLOAT(7), 'oracle')
else:
return dtype_to_sqlalchemy(dtype, converter=dtype_to_oracle)
@classmethod
def from_ibis(cls, dtype):
if isinstance(dtype, dt.Float64):
return sat.Float(precision=53).with_variant(oracle.FLOAT(14), 'oracle')
elif isinstance(dtype, dt.Float32):
return sat.Float(precision=23).with_variant(oracle.FLOAT(7), 'oracle')
else:
return super().from_ibis(dtype)


_ORACLE_TYPES = {
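A sketch of the NUMBER handling preserved by `OracleType.to_ibis` (the int64 result for scale-0 numbers is the acknowledged approximation from the inline comment):

```python
import sqlalchemy.types as sat

import ibis.expr.datatypes as dt
from ibis.backends.oracle.datatypes import OracleType

assert OracleType.to_ibis(sat.Numeric(precision=10, scale=0)) == dt.int64
assert OracleType.to_ibis(sat.Numeric(precision=10, scale=2)) == dt.Decimal(10, 2)
```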
12 changes: 6 additions & 6 deletions ibis/backends/pandas/__init__.py
@@ -12,7 +12,7 @@
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis.backends.base import BaseBackend
from ibis.formats.pandas import schema_from_pandas_dataframe, schema_to_pandas
from ibis.formats.pandas import PandasData, PandasSchema

if TYPE_CHECKING:
import pyarrow as pa
@@ -95,17 +95,17 @@ def list_tables(self, like=None, database=None):
def table(self, name: str, schema: sch.Schema = None):
df = self.dictionary[name]
schema = schema or self.schemas.get(name, None)
schema = schema_from_pandas_dataframe(df, schema=schema)
schema = PandasData.infer_table(df, schema=schema)
return ops.DatabaseTable(name, schema, self).to_expr()

def get_schema(self, table_name, database=None):
schemas = self.schemas
try:
schema = schemas[table_name]
except KeyError:
schemas[table_name] = schema = schema_from_pandas_dataframe(
self.dictionary[table_name]
)
df = self.dictionary[table_name]
schemas[table_name] = schema = PandasData.infer_table(df)

return schema

def compile(self, expr, *args, **kwargs):
@@ -143,7 +143,7 @@ def create_table(
)
df = self._convert_object(obj)
else:
dtypes = dict(schema_to_pandas(schema))
dtypes = dict(PandasSchema.from_ibis(schema))
df = self._from_pandas(pd.DataFrame(columns=dtypes.keys()).astype(dtypes))

if name in self.dictionary and not overwrite:
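End to end, the pandas backend now infers schemas through `PandasData.infer_table`; a sketch requiring only pandas and ibis:

```python
import pandas as pd

import ibis

df = pd.DataFrame({"a": [1, 2], "b": ["x", "y"]})
con = ibis.pandas.connect({"t": df})

# table() routes through PandasData.infer_table under the hood.
assert con.table("t").schema() == ibis.schema({"a": "int64", "b": "string"})
```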
6 changes: 3 additions & 3 deletions ibis/backends/pandas/core.py
@@ -491,15 +491,15 @@


def _apply_schema(op: ops.Node, result: pd.DataFrame | pd.Series):
from ibis.formats.pandas import PandasConverter
from ibis.formats.pandas import PandasData

assert isinstance(op, ops.Node), type(op)
if isinstance(result, pd.DataFrame):
df = result.reset_index().loc[:, list(op.schema.names)]
return PandasConverter.convert_frame(df, op.schema)
return PandasData.convert_table(df, op.schema)
elif isinstance(result, pd.Series):
schema = op.to_expr().as_table().schema()
df = PandasConverter.convert_frame(result.to_frame(), schema)
df = PandasData.convert_table(result.to_frame(), schema)
return df.iloc[:, 0].reset_index(drop=True)
else:
return result
5 changes: 2 additions & 3 deletions ibis/backends/postgres/compiler.py
@@ -3,7 +3,7 @@
import ibis.expr.operations as ops
import ibis.expr.rules as rlz
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.postgres.datatypes import dtype_from_postgres, dtype_to_postgres
from ibis.backends.postgres.datatypes import PostgresType
from ibis.backends.postgres.registry import operation_registry


@@ -20,8 +20,7 @@ class PostgreSQLExprTranslator(AlchemyExprTranslator):
# it does support it, but we can't use it because of support for pivot
supports_unnest_in_select = False

get_sqla_type = staticmethod(dtype_to_postgres)
get_ibis_type = staticmethod(dtype_from_postgres)
type_mapper = PostgresType


rewrites = PostgreSQLExprTranslator.rewrites
86 changes: 43 additions & 43 deletions ibis/backends/postgres/datatypes.py
@@ -3,13 +3,11 @@
import parsy
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as psql
import sqlalchemy.types as sat
import toolz

import ibis.expr.datatypes as dt
from ibis.backends.base.sql.alchemy.datatypes import (
dtype_from_sqlalchemy,
dtype_to_sqlalchemy,
)
from ibis.backends.base.sql.alchemy.datatypes import AlchemyType
from ibis.common.parsing import (
COMMA,
LBRACKET,
@@ -121,43 +119,45 @@ def _get_type(typestr: str) -> dt.DataType:
}


def dtype_to_postgres(dtype):
if dtype.is_floating():
if isinstance(dtype, dt.Float64):
return psql.DOUBLE_PRECISION
class PostgresType(AlchemyType):
@classmethod
def from_ibis(cls, dtype: dt.DataType) -> sat.TypeEngine:
if dtype.is_floating():
if isinstance(dtype, dt.Float64):
return psql.DOUBLE_PRECISION
else:
return psql.REAL
elif dtype.is_array():
# Unwrap the array element type because sqlalchemy doesn't allow arrays of
# arrays. This doesn't affect the underlying data.
while dtype.is_array():
dtype = dtype.value_type
return sa.ARRAY(cls.from_ibis(dtype))
elif dtype.is_map():
if not (dtype.key_type.is_string() and dtype.value_type.is_string()):
raise TypeError(
f"PostgreSQL only supports map<string, string>, got: {dtype}"
)
return psql.HSTORE()
elif dtype.is_uuid():
return psql.UUID()
else:
return super().from_ibis(dtype)

@classmethod
def to_ibis(cls, typ: sat.TypeEngine, nullable: bool = True) -> dt.DataType:
if dtype := _from_postgres_types.get(type(typ)):
return dtype(nullable=nullable)
elif isinstance(typ, psql.HSTORE):
return dt.Map(dt.string, dt.string, nullable=nullable)
elif isinstance(typ, psql.INTERVAL):
field = typ.fields.upper()
if (unit := _postgres_interval_fields.get(field, None)) is None:
raise ValueError(f"Unknown PostgreSQL interval field {field!r}")
elif unit in {"Y", "M"}:
raise ValueError(
"Variable length intervals are not yet supported with PostgreSQL"
)
return dt.Interval(unit=unit, nullable=nullable)
else:
return psql.REAL
elif dtype.is_array():
# Unwrap the array element type because sqlalchemy doesn't allow arrays of
# arrays. This doesn't affect the underlying data.
while dtype.is_array():
dtype = dtype.value_type
return sa.ARRAY(dtype_to_postgres(dtype))
elif dtype.is_map():
if not (dtype.key_type.is_string() and dtype.value_type.is_string()):
raise TypeError(
f"PostgreSQL only supports map<string, string>, got: {dtype}"
)
return psql.HSTORE()
elif dtype.is_uuid():
return psql.UUID()
else:
return dtype_to_sqlalchemy(dtype, converter=dtype_to_postgres)


def dtype_from_postgres(typ, nullable=True):
if dtype := _from_postgres_types.get(type(typ)):
return dtype(nullable=nullable)
elif isinstance(typ, psql.HSTORE):
return dt.Map(dt.string, dt.string, nullable=nullable)
elif isinstance(typ, psql.INTERVAL):
field = typ.fields.upper()
if (unit := _postgres_interval_fields.get(field, None)) is None:
raise ValueError(f"Unknown PostgreSQL interval field {field!r}")
elif unit in {"Y", "M"}:
raise ValueError(
"Variable length intervals are not yet supported with PostgreSQL"
)
return dt.Interval(unit=unit, nullable=nullable)
else:
return dtype_from_sqlalchemy(typ, converter=dtype_from_postgres)
return super().to_ibis(typ, nullable=nullable)
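A sketch of the array flattening that `PostgresType.from_ibis` keeps from the old function:

```python
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as psql

import ibis.expr.datatypes as dt
from ibis.backends.postgres.datatypes import PostgresType

# Nested arrays collapse to one sa.ARRAY dimension; SQLAlchemy cannot
# express ARRAY-of-ARRAY, and Postgres stores the dimensions identically.
typ = PostgresType.from_ibis(dt.Array(dt.Array(dt.float64)))
assert isinstance(typ, sa.ARRAY)
assert isinstance(typ.item_type, psql.DOUBLE_PRECISION)
```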
4 changes: 2 additions & 2 deletions ibis/backends/postgres/udf.py
@@ -15,7 +15,7 @@
import ibis.udf.validate as v
from ibis import IbisError
from ibis.backends.postgres.compiler import PostgreSQLExprTranslator, PostgresUDFNode
from ibis.backends.postgres.datatypes import dtype_to_postgres
from ibis.backends.postgres.datatypes import PostgresType

_udf_name_cache: MutableMapping[str, Any] = collections.defaultdict(itertools.count)

@@ -28,7 +28,7 @@ class PostgresUDFError(IbisError):

def _ibis_to_postgres_str(ibis_type):
"""Map an ibis DataType to a Postgres-appropriate string."""
satype = dtype_to_postgres(ibis_type)
satype = PostgresType.from_ibis(ibis_type)
if callable(satype):
satype = satype()
return satype.compile(dialect=_postgres_dialect)
9 changes: 5 additions & 4 deletions ibis/backends/pyspark/__init__.py
@@ -28,7 +28,8 @@
from ibis.backends.pyspark import ddl
from ibis.backends.pyspark.client import PySparkTable
from ibis.backends.pyspark.compiler import PySparkExprTranslator
from ibis.backends.pyspark.datatypes import dtype_from_pyspark
from ibis.backends.pyspark.datatypes import PySparkType
from ibis.formats.pandas import PandasData

if TYPE_CHECKING:
import pandas as pd
@@ -211,7 +212,7 @@ def execute(self, expr: ir.Expr, **kwargs: Any) -> Any:
table_expr = expr.as_table()
df = self.compile(table_expr, **kwargs).toPandas()

result = self._pandas_converter.convert_frame(df, table_expr.schema())
result = PandasData.convert_table(df, table_expr.schema())
if isinstance(expr, ir.Table):
return result
elif isinstance(expr, ir.Column):
@@ -241,7 +242,7 @@ def raw_sql(self, query: str) -> _PySparkCursor:

def _get_schema_using_query(self, query):
cursor = self.raw_sql(f"SELECT * FROM ({query}) t0 LIMIT 0")
struct = dtype_from_pyspark(cursor.query.schema)
struct = PySparkType.to_ibis(cursor.query.schema)
return sch.Schema(struct)

def _get_jtable(self, name, database=None):
@@ -339,7 +340,7 @@ def get_schema(
)

df = self._session.table(table_name)
struct = dtype_from_pyspark(df.schema)
struct = PySparkType.to_ibis(df.schema)
return sch.Schema(struct)

def create_table(
16 changes: 8 additions & 8 deletions ibis/backends/pyspark/compiler.py
@@ -20,7 +20,7 @@
from ibis import interval
from ibis.backends.base.df.timecontext import adjust_context
from ibis.backends.pandas.execution import execute
from ibis.backends.pyspark.datatypes import dtype_to_pyspark
from ibis.backends.pyspark.datatypes import PySparkType
from ibis.backends.pyspark.timecontext import (
combine_time_context,
filter_by_time_context,
@@ -266,7 +266,7 @@ def compile_cast(t, op, **kwargs):
'in the PySpark backend. {} not allowed.'.format(type(op.arg))
)

cast_type = dtype_to_pyspark(op.to)
cast_type = PySparkType.from_ibis(op.to)

src_column = t.translate(op.arg, **kwargs)
return src_column.cast(cast_type)
@@ -756,7 +756,7 @@ def column_max(value, limit):
def clip(column, lower_value, upper_value):
return column_max(column_min(column, F.lit(lower_value)), F.lit(upper_value))

return clip(col, lower, upper).cast(dtype_to_pyspark(op.output_dtype))
return clip(col, lower, upper).cast(PySparkType.from_ibis(op.output_dtype))


@compiles(ops.Round)
Expand Down Expand Up @@ -1614,7 +1614,7 @@ def compile_array_length(t, op, **kwargs):
def compile_array_slice(t, op, **kwargs):
start = op.start.value if op.start is not None else op.start
stop = op.stop.value if op.stop is not None else op.stop
spark_type = dtype_to_pyspark(op.arg.output_dtype)
spark_type = PySparkType.from_ibis(op.arg.output_dtype)

@F.udf(spark_type)
def array_slice(array):
@@ -1751,7 +1751,7 @@ def compile_fillna_table(t, op, **kwargs):

@compiles(ops.ElementWiseVectorizedUDF)
def compile_elementwise_udf(t, op, **kwargs):
spark_output_type = dtype_to_pyspark(op.return_type)
spark_output_type = PySparkType.from_ibis(op.return_type)
func = op.func
spark_udf = pandas_udf(func, spark_output_type, PandasUDFType.SCALAR)
func_args = (t.translate(arg, **kwargs) for arg in op.func_args)
@@ -1760,7 +1760,7 @@ def compile_elementwise_udf(t, op, **kwargs):

@compiles(ops.ReductionVectorizedUDF)
def compile_reduction_udf(t, op, *, aggcontext=None, **kwargs):
spark_output_type = dtype_to_pyspark(op.return_type)
spark_output_type = PySparkType.from_ibis(op.return_type)
spark_udf = pandas_udf(op.func, spark_output_type, PandasUDFType.GROUPED_AGG)
func_args = (t.translate(arg, **kwargs) for arg in op.func_args)

@@ -1895,7 +1895,7 @@ def compile_random(*args, **kwargs):
@compiles(ops.InMemoryTable)
def compile_in_memory_table(t, op, session, **kwargs):
fields = [
pt.StructField(name, dtype_to_pyspark(dtype), dtype.nullable)
pt.StructField(name, PySparkType.from_ibis(dtype), dtype.nullable)
for name, dtype in op.schema.items()
]
schema = pt.StructType(fields)
@@ -1956,7 +1956,7 @@ def compile_dummy_table(t, op, session=None, **kwargs):
def compile_scalar_parameter(t, op, timecontext=None, scope=None, **kwargs):
assert scope is not None, "scope is None"
raw_value = scope.get_value(op, timecontext)
return F.lit(raw_value).cast(dtype_to_pyspark(op.output_dtype))
return F.lit(raw_value).cast(PySparkType.from_ibis(op.output_dtype))


@compiles(ops.E)
116 changes: 61 additions & 55 deletions ibis/backends/pyspark/datatypes.py
@@ -5,6 +5,7 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
from ibis.backends.base.sql.registry import sql_type_names
from ibis.formats import TypeMapper

_sql_type_names = dict(sql_type_names, date='date')

@@ -45,61 +46,66 @@ def type_to_sql_string(tval):
}


def dtype_from_pyspark(typ, nullable=True):
"""Convert a pyspark type to an ibis type."""
if isinstance(typ, pt.DecimalType):
return dt.Decimal(typ.precision, typ.scale, nullable=nullable)
elif isinstance(typ, pt.ArrayType):
return dt.Array(dtype_from_pyspark(typ.elementType), nullable=nullable)
elif isinstance(typ, pt.MapType):
return dt.Map(
dtype_from_pyspark(typ.keyType),
dtype_from_pyspark(typ.valueType),
nullable=nullable,
)
elif isinstance(typ, pt.StructType):
fields = {f.name: dtype_from_pyspark(f.dataType) for f in typ.fields}

return dt.Struct(fields, nullable=nullable)
elif isinstance(typ, pt.DayTimeIntervalType):
if typ.startField == typ.endField and typ.startField in _pyspark_interval_units:
unit = _pyspark_interval_units[typ.startField]
return dt.Interval(unit, nullable=nullable)
else:
raise com.IbisTypeError(f"{typ!r} couldn't be converted to Interval")
elif isinstance(typ, pt.UserDefinedType):
return dtype_from_pyspark(typ.sqlType(), nullable=nullable)
else:
try:
return _from_pyspark_dtypes[type(typ)](nullable=nullable)
except KeyError:
raise NotImplementedError(
f'Unable to convert type {typ} of type {type(typ)} to an ibis type.'
class PySparkType(TypeMapper):
@classmethod
def to_ibis(cls, typ, nullable=True):
"""Convert a pyspark type to an ibis type."""
if isinstance(typ, pt.DecimalType):
return dt.Decimal(typ.precision, typ.scale, nullable=nullable)
elif isinstance(typ, pt.ArrayType):
return dt.Array(cls.to_ibis(typ.elementType), nullable=nullable)
elif isinstance(typ, pt.MapType):
return dt.Map(
cls.to_ibis(typ.keyType),
cls.to_ibis(typ.valueType),
nullable=nullable,
)
elif isinstance(typ, pt.StructType):
fields = {f.name: cls.to_ibis(f.dataType) for f in typ.fields}

return dt.Struct(fields, nullable=nullable)
elif isinstance(typ, pt.DayTimeIntervalType):
if (
typ.startField == typ.endField
and typ.startField in _pyspark_interval_units
):
unit = _pyspark_interval_units[typ.startField]
return dt.Interval(unit, nullable=nullable)
else:
raise com.IbisTypeError(f"{typ!r} couldn't be converted to Interval")
elif isinstance(typ, pt.UserDefinedType):
return cls.to_ibis(typ.sqlType(), nullable=nullable)
else:
try:
return _from_pyspark_dtypes[type(typ)](nullable=nullable)
except KeyError:
raise NotImplementedError(
f'Unable to convert type {typ} of type {type(typ)} to an ibis type.'
)

def dtype_to_pyspark(dtype):
if dtype.is_decimal():
return pt.DecimalType(dtype.precision, dtype.scale)
elif dtype.is_array():
element_type = dtype_to_pyspark(dtype.value_type)
contains_null = dtype.value_type.nullable
return pt.ArrayType(element_type, contains_null)
elif dtype.is_map():
key_type = dtype_to_pyspark(dtype.key_type)
value_type = dtype_to_pyspark(dtype.value_type)
value_contains_null = dtype.value_type.nullable
return pt.MapType(key_type, value_type, value_contains_null)
elif dtype.is_struct():
fields = [
pt.StructField(n, dtype_to_pyspark(t), t.nullable)
for n, t in dtype.fields.items()
]
return pt.StructType(fields)
else:
try:
return _to_pyspark_dtypes[type(dtype)]()
except KeyError:
raise com.IbisTypeError(
f"Unable to convert dtype {dtype!r} to pyspark type"
)
@classmethod
def from_ibis(cls, dtype):
if dtype.is_decimal():
return pt.DecimalType(dtype.precision, dtype.scale)
elif dtype.is_array():
element_type = cls.from_ibis(dtype.value_type)
contains_null = dtype.value_type.nullable
return pt.ArrayType(element_type, contains_null)
elif dtype.is_map():
key_type = cls.from_ibis(dtype.key_type)
value_type = cls.from_ibis(dtype.value_type)
value_contains_null = dtype.value_type.nullable
return pt.MapType(key_type, value_type, value_contains_null)
elif dtype.is_struct():
fields = [
pt.StructField(n, cls.from_ibis(t), t.nullable)
for n, t in dtype.fields.items()
]
return pt.StructType(fields)
else:
try:
return _to_pyspark_dtypes[type(dtype)]()
except KeyError:
raise com.IbisTypeError(
f"Unable to convert dtype {dtype!r} to pyspark type"
)
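A sketch of a nested round-trip through the new classmethods (assumes pyspark is installed; field nullability defaults to True on both sides, so the round-trip is exact here):

```python
import pyspark.sql.types as pt

import ibis.expr.datatypes as dt
from ibis.backends.pyspark.datatypes import PySparkType

ibis_type = dt.Struct({"xs": dt.Array(dt.int64), "m": dt.Map(dt.string, dt.float64)})
spark_type = PySparkType.from_ibis(ibis_type)
assert isinstance(spark_type, pt.StructType)
assert PySparkType.to_ibis(spark_type) == ibis_type
```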
20 changes: 4 additions & 16 deletions ibis/backends/snowflake/__init__.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import contextlib
import functools
import itertools
import os
import shutil
@@ -23,11 +22,8 @@
AlchemyExprTranslator,
BaseAlchemyBackend,
)
from ibis.backends.snowflake.datatypes import (
dtype_from_snowflake,
dtype_to_snowflake,
parse,
)
from ibis.backends.snowflake.converter import SnowflakePandasData
from ibis.backends.snowflake.datatypes import SnowflakeType, parse
from ibis.backends.snowflake.registry import operation_registry

if TYPE_CHECKING:
@@ -50,9 +46,7 @@ class SnowflakeExprTranslator(AlchemyExprTranslator):
_quote_column_names = True
_quote_table_names = True
supports_unnest_in_select = False

get_sqla_type = staticmethod(dtype_to_snowflake)
get_ibis_type = staticmethod(dtype_from_snowflake)
type_mapper = SnowflakeType


class SnowflakeCompiler(AlchemyCompiler):
@@ -94,12 +88,6 @@ class Backend(BaseAlchemyBackend):
compiler = SnowflakeCompiler
supports_create_or_replace = True

@functools.cached_property
def _pandas_converter(self):
from ibis.backends.snowflake.converter import SnowflakePandasConverter

return SnowflakePandasConverter

@property
def _current_schema(self) -> str:
with self.begin() as con:
@@ -272,7 +260,7 @@ def fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame:
if (table := cursor.cursor.fetch_arrow_all()) is None:
table = pa.Table.from_pylist([], schema=schema.to_pyarrow())
df = table.to_pandas(timestamp_as_object=True)
return self._pandas_converter.convert_frame(df, schema)
return SnowflakePandasData.convert_table(df, schema)

def to_pyarrow_batches(
self,
10 changes: 5 additions & 5 deletions ibis/backends/snowflake/converter.py
@@ -1,7 +1,7 @@
from ibis.formats.pandas import PandasConverter
from ibis.formats.pandas import PandasData


class SnowflakePandasConverter(PandasConverter):
convert_Struct = staticmethod(PandasConverter.convert_JSON)
convert_Array = staticmethod(PandasConverter.convert_JSON)
convert_Map = staticmethod(PandasConverter.convert_JSON)
class SnowflakePandasData(PandasData):
convert_Struct = staticmethod(PandasData.convert_JSON)
convert_Array = staticmethod(PandasData.convert_JSON)
convert_Map = staticmethod(PandasData.convert_JSON)
95 changes: 47 additions & 48 deletions ibis/backends/snowflake/datatypes.py
@@ -12,10 +12,7 @@
from sqlalchemy.ext.compiler import compiles

import ibis.expr.datatypes as dt
from ibis.backends.base.sql.alchemy.datatypes import (
dtype_from_sqlalchemy,
dtype_to_sqlalchemy,
)
from ibis.backends.base.sql.alchemy.datatypes import AlchemyType


@compiles(sat.NullType, "snowflake")
@@ -46,50 +43,52 @@ def parse(text: str) -> dt.DataType:
return _SNOWFLAKE_TYPES[text]


def dtype_to_snowflake(dtype):
if dtype.is_array():
return ARRAY
elif dtype.is_map() or dtype.is_struct():
return OBJECT
elif dtype.is_json():
return VARIANT
elif dtype.is_timestamp():
if dtype.timezone is None:
return TIMESTAMP_NTZ
class SnowflakeType(AlchemyType):
@classmethod
def from_ibis(cls, dtype):
if dtype.is_array():
return ARRAY
elif dtype.is_map() or dtype.is_struct():
return OBJECT
elif dtype.is_json():
return VARIANT
elif dtype.is_timestamp():
if dtype.timezone is None:
return TIMESTAMP_NTZ
else:
return TIMESTAMP_TZ
elif dtype.is_string():
# 16MB
return sat.VARCHAR(2**24)
elif dtype.is_binary():
# 8MB
return sat.VARBINARY(2**23)
else:
return TIMESTAMP_TZ
elif dtype.is_string():
# 16MB
return sat.VARCHAR(2**24)
elif dtype.is_binary():
# 8MB
return sat.VARBINARY(2**23)
else:
return dtype_to_sqlalchemy(dtype, converter=dtype_to_snowflake)

return super().from_ibis(dtype)

def dtype_from_snowflake(typ, nullable=True):
if isinstance(typ, (sat.REAL, sat.FLOAT, sat.Float)):
return dt.Float64(nullable=nullable)
elif isinstance(typ, TIMESTAMP_NTZ):
return dt.Timestamp(timezone=None, nullable=nullable)
elif isinstance(typ, (TIMESTAMP_LTZ, TIMESTAMP_TZ)):
return dt.Timestamp(timezone="UTC", nullable=nullable)
elif isinstance(typ, ARRAY):
return dt.Array(dt.json, nullable=nullable)
elif isinstance(typ, OBJECT):
return dt.Map(dt.string, dt.json, nullable=nullable)
elif isinstance(typ, VARIANT):
return dt.JSON(nullable=nullable)
elif isinstance(typ, sat.Numeric):
if (scale := typ.scale) == 0:
# kind of a lie, should be int128 because 38 digits
return dt.Int64(nullable=nullable)
@classmethod
def to_ibis(cls, typ, nullable=True):
if isinstance(typ, (sat.REAL, sat.FLOAT, sat.Float)):
return dt.Float64(nullable=nullable)
elif isinstance(typ, TIMESTAMP_NTZ):
return dt.Timestamp(timezone=None, nullable=nullable)
elif isinstance(typ, (TIMESTAMP_LTZ, TIMESTAMP_TZ)):
return dt.Timestamp(timezone="UTC", nullable=nullable)
elif isinstance(typ, ARRAY):
return dt.Array(dt.json, nullable=nullable)
elif isinstance(typ, OBJECT):
return dt.Map(dt.string, dt.json, nullable=nullable)
elif isinstance(typ, VARIANT):
return dt.JSON(nullable=nullable)
elif isinstance(typ, sat.Numeric):
if (scale := typ.scale) == 0:
# kind of a lie, should be int128 because 38 digits
return dt.Int64(nullable=nullable)
else:
return dt.Decimal(
precision=typ.precision or 38,
scale=scale or 0,
nullable=nullable,
)
else:
return dt.Decimal(
precision=typ.precision or 38,
scale=scale or 0,
nullable=nullable,
)
else:
return dtype_from_sqlalchemy(typ, nullable=nullable)
return super().to_ibis(typ, nullable=nullable)
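A sketch of the numeric and string rules `SnowflakeType` carries over:

```python
import sqlalchemy.types as sat

import ibis.expr.datatypes as dt
from ibis.backends.snowflake.datatypes import SnowflakeType

# NUMBER(p, 0) maps to int64 (the acknowledged approximation: 38 digits
# exceed int64, but ibis has no int128); nonzero scales become decimals.
assert SnowflakeType.to_ibis(sat.Numeric(38, 0)) == dt.int64
assert SnowflakeType.to_ibis(sat.Numeric(10, 2)) == dt.Decimal(10, 2)

# Strings are capped explicitly at Snowflake's 16MB VARCHAR maximum.
assert SnowflakeType.from_ibis(dt.string).length == 2**24
```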
4 changes: 2 additions & 2 deletions ibis/backends/sqlite/__init__.py
@@ -28,7 +28,7 @@
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend
from ibis.backends.sqlite import udf
from ibis.backends.sqlite.compiler import SQLiteCompiler
from ibis.backends.sqlite.datatypes import dtype_to_sqlite, parse
from ibis.backends.sqlite.datatypes import SqliteType, parse

if TYPE_CHECKING:
from pathlib import Path
@@ -125,7 +125,7 @@ def do_connect(
# easier than subclassing the builtin SQLite dialect, and achieves
# the same desired behavior.
def _to_ischema_val(t):
sa_type = dtype_to_sqlite(dt.dtype(t))
sa_type = SqliteType.from_ibis(dt.dtype(t))
if isinstance(sa_type, sa.types.TypeEngine):
# SQLAlchemy expects a callable here, rather than an
# instance. Use a lambda to work around this.
6 changes: 2 additions & 4 deletions ibis/backends/sqlite/compiler.py
@@ -14,17 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.sqlite.datatypes import dtype_from_sqlite, dtype_to_sqlite
from ibis.backends.sqlite.datatypes import SqliteType
from ibis.backends.sqlite.registry import operation_registry


class SQLiteExprTranslator(AlchemyExprTranslator):
_registry = operation_registry
_rewrites = AlchemyExprTranslator._rewrites.copy()
_dialect_name = "sqlite"

get_sqla_type = staticmethod(dtype_to_sqlite)
get_ibis_type = staticmethod(dtype_from_sqlite)
type_mapper = SqliteType


rewrites = SQLiteExprTranslator.rewrites
36 changes: 17 additions & 19 deletions ibis/backends/sqlite/datatypes.py
@@ -2,15 +2,11 @@

from __future__ import annotations

import sqlalchemy as sa
import sqlalchemy.types as sat
from sqlalchemy.dialects import sqlite

import ibis.expr.datatypes as dt
from ibis.backends.base.sql.alchemy.datatypes import (
dtype_from_sqlalchemy,
dtype_to_sqlalchemy,
)
from ibis.backends.base.sql.alchemy.datatypes import AlchemyType


def parse(text: str) -> dt.DataType:
@@ -46,17 +42,19 @@ def parse(text: str) -> dt.DataType:
return dt.decimal


def dtype_to_sqlite(dtype):
if dtype.is_floating():
return sa.REAL
else:
return dtype_to_sqlalchemy(dtype, converter=dtype_to_sqlite)


def dtype_from_sqlite(typ, nullable=True):
if isinstance(typ, sat.REAL):
return dt.Float64(nullable=nullable)
elif isinstance(typ, sqlite.JSON):
return dt.JSON(nullable=nullable)
else:
return dtype_from_sqlalchemy(typ, converter=dtype_from_sqlite)
class SqliteType(AlchemyType):
@classmethod
def from_ibis(cls, dtype: dt.DataType) -> sat.TypeEngine:
if dtype.is_floating():
return sat.REAL
else:
return super().from_ibis(dtype)

@classmethod
def to_ibis(cls, typ: sat.TypeEngine, nullable: bool = True) -> dt.DataType:
if isinstance(typ, sat.REAL):
return dt.Float64(nullable=nullable)
elif isinstance(typ, sqlite.JSON):
return dt.JSON(nullable=nullable)
else:
return super().to_ibis(typ, nullable=nullable)
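A sketch of the float handling: every ibis float maps to REAL, and REAL always reads back as float64, since SQLite stores all floats as 8-byte IEEE values:

```python
import sqlalchemy.types as sat

import ibis.expr.datatypes as dt
from ibis.backends.sqlite.datatypes import SqliteType

assert SqliteType.from_ibis(dt.float32) is sat.REAL  # returned as a class
assert SqliteType.to_ibis(sat.REAL()) == dt.float64
```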
6 changes: 2 additions & 4 deletions ibis/backends/trino/compiler.py
@@ -4,7 +4,7 @@

import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.trino.datatypes import dtype_from_trino, dtype_to_trino
from ibis.backends.trino.datatypes import TrinoType
from ibis.backends.trino.registry import operation_registry


@@ -26,9 +26,7 @@ class TrinoSQLExprTranslator(AlchemyExprTranslator):
)
_dialect_name = "trino"
supports_unnest_in_select = False

get_sqla_type = staticmethod(dtype_to_trino)
get_ibis_type = staticmethod(dtype_from_trino)
type_mapper = TrinoType


rewrites = TrinoSQLExprTranslator.rewrites
91 changes: 45 additions & 46 deletions ibis/backends/trino/datatypes.py
@@ -11,10 +11,7 @@
from trino.sqlalchemy.datatype import ROW as _ROW

import ibis.expr.datatypes as dt
from ibis.backends.base.sql.alchemy.datatypes import (
dtype_from_sqlalchemy,
dtype_to_sqlalchemy,
)
from ibis.backends.base.sql.alchemy.datatypes import AlchemyType
from ibis.common.parsing import (
COMMA,
FIELD,
@@ -145,45 +142,47 @@ def parse(text: str, default_decimal_parameters=(18, 3)) -> dt.DataType:
}


def dtype_from_trino(typ, nullable=True):
if dtype := _from_trino_types.get(type(typ)):
return dtype(nullable=nullable)
elif isinstance(typ, sat.NUMERIC):
return dt.Decimal(typ.precision or 18, typ.scale or 3, nullable=nullable)
elif isinstance(typ, sat.ARRAY):
value_dtype = dtype_from_trino(typ.item_type)
return dt.Array(value_dtype, nullable=nullable)
elif isinstance(typ, ROW):
fields = ((k, dtype_from_trino(v)) for k, v in typ.attr_types)
return dt.Struct.from_tuples(fields, nullable=nullable)
elif isinstance(typ, MAP):
return dt.Map(
dtype_from_trino(typ.key_type),
dtype_from_trino(typ.value_type),
nullable=nullable,
)
elif isinstance(typ, TIMESTAMP):
return dt.Timestamp(
timezone="UTC" if typ.timezone else None,
scale=typ.precision,
nullable=nullable,
)
else:
return dtype_from_sqlalchemy(typ, converter=dtype_from_trino)


def dtype_to_trino(dtype):
if isinstance(dtype, dt.Float64):
return DOUBLE()
elif isinstance(dtype, dt.Float32):
return sat.REAL()
elif dtype.is_string():
return sat.VARCHAR()
elif dtype.is_struct():
return ROW((name, dtype_to_trino(typ)) for name, typ in dtype.fields.items())
elif dtype.is_map():
return MAP(dtype_to_trino(dtype.key_type), dtype_to_trino(dtype.value_type))
elif dtype.is_timestamp():
return TIMESTAMP(precision=dtype.scale, timezone=bool(dtype.timezone))
else:
return dtype_to_sqlalchemy(dtype, converter=dtype_to_trino)
class TrinoType(AlchemyType):
@classmethod
def to_ibis(cls, typ, nullable=True):
if dtype := _from_trino_types.get(type(typ)):
return dtype(nullable=nullable)
elif isinstance(typ, sat.NUMERIC):
return dt.Decimal(typ.precision or 18, typ.scale or 3, nullable=nullable)
elif isinstance(typ, sat.ARRAY):
value_dtype = cls.to_ibis(typ.item_type)
return dt.Array(value_dtype, nullable=nullable)
elif isinstance(typ, ROW):
fields = ((k, cls.to_ibis(v)) for k, v in typ.attr_types)
return dt.Struct.from_tuples(fields, nullable=nullable)
elif isinstance(typ, MAP):
return dt.Map(
cls.to_ibis(typ.key_type),
cls.to_ibis(typ.value_type),
nullable=nullable,
)
elif isinstance(typ, TIMESTAMP):
return dt.Timestamp(
timezone="UTC" if typ.timezone else None,
scale=typ.precision,
nullable=nullable,
)
else:
return super().to_ibis(typ, nullable=nullable)

@classmethod
def from_ibis(cls, dtype):
if isinstance(dtype, dt.Float64):
return DOUBLE()
elif isinstance(dtype, dt.Float32):
return sat.REAL()
elif dtype.is_string():
return sat.VARCHAR()
elif dtype.is_struct():
return ROW((name, cls.from_ibis(typ)) for name, typ in dtype.fields.items())
elif dtype.is_map():
return MAP(cls.from_ibis(dtype.key_type), cls.from_ibis(dtype.value_type))
elif dtype.is_timestamp():
return TIMESTAMP(precision=dtype.scale, timezone=bool(dtype.timezone))
else:
return super().from_ibis(dtype)
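
A hedged sketch of the roundtrip this class is meant to provide (not part of the diff; the import path comes from the compiler change above):

import ibis.expr.datatypes as dt
from ibis.backends.trino.datatypes import TrinoType

dtype = dt.Struct({"x": dt.float64, "y": dt.string})
row_type = TrinoType.from_ibis(dtype)  # ROW with DOUBLE and VARCHAR fields

# with the all-nullable defaults shown above, the roundtrip should compare equal
assert TrinoType.to_ibis(row_type) == dtype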
24 changes: 12 additions & 12 deletions ibis/expr/datatypes/core.py
@@ -166,23 +166,23 @@ def from_typehint(cls, typ, nullable=True) -> Self:
@classmethod
def from_numpy(cls, numpy_type, nullable=True) -> Self:
"""Return the equivalent ibis datatype."""
from ibis.formats.numpy import dtype_from_numpy
from ibis.formats.numpy import NumpyType

return dtype_from_numpy(numpy_type, nullable=nullable)
return NumpyType.to_ibis(numpy_type, nullable=nullable)

@classmethod
def from_pandas(cls, pandas_type, nullable=True) -> Self:
"""Return the equivalent ibis datatype."""
from ibis.formats.pandas import dtype_from_pandas
from ibis.formats.pandas import PandasType

return dtype_from_pandas(pandas_type, nullable=nullable)
return PandasType.to_ibis(pandas_type, nullable=nullable)

@classmethod
def from_pyarrow(cls, arrow_type, nullable=True) -> Self:
"""Return the equivalent ibis datatype."""
from ibis.formats.pyarrow import dtype_to_pyarrow
from ibis.formats.pyarrow import PyArrowType

return dtype_to_pyarrow(arrow_type, nullable=nullable)
return PyArrowType.to_ibis(arrow_type, nullable=nullable)

@classmethod
def from_dask(cls, dask_type, nullable=True) -> Self:
@@ -191,21 +191,21 @@ def from_dask(cls, dask_type, nullable=True) -> Self:

def to_numpy(self):
"""Return the equivalent numpy datatype."""
from ibis.formats.numpy import dtype_to_numpy
from ibis.formats.numpy import NumpyFormat

return dtype_to_numpy(self)
return NumpyFormat.from_dtype(self)

def to_pandas(self):
"""Return the equivalent pandas datatype."""
from ibis.formats.pandas import dtype_to_pandas
from ibis.formats.pandas import PandasType

return dtype_to_pandas(self)
return PandasType.from_ibis(self)

def to_pyarrow(self):
"""Return the equivalent pyarrow datatype."""
from ibis.formats.pyarrow import dtype_to_pyarrow
from ibis.formats.pyarrow import PyArrowType

return dtype_to_pyarrow(self)
return PyArrowType.from_ibis(self)

def to_dask(self):
"""Return the equivalent dask datatype."""
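
The public entry points keep their signatures; only the delegation target changes. A minimal sketch (assuming pyarrow is installed):

import ibis.expr.datatypes as dt

pa_type = dt.Timestamp(timezone="UTC").to_pyarrow()  # timestamp[us, tz=UTC]
roundtripped = dt.DataType.from_pyarrow(pa_type)     # microsecond-scale UTC timestamp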
12 changes: 6 additions & 6 deletions ibis/expr/datatypes/value.py
@@ -142,9 +142,9 @@ def infer_ipaddr(

@infer.register("numpy.generic")
def infer_numpy_scalar(value):
from ibis.formats.numpy import dtype_from_numpy
from ibis.formats.numpy import NumpyType

return dtype_from_numpy(value.dtype)
return NumpyType.to_ibis(value.dtype)


@infer.register("pandas.Timestamp")
@@ -171,13 +171,13 @@ def infer_interval_pandas(value) -> dt.Interval:
@infer.register("numpy.ndarray")
@infer.register("pandas.Series")
def infer_numpy_array(value):
from ibis.formats.numpy import dtype_from_numpy
from ibis.formats.pyarrow import infer_sequence_dtype
from ibis.formats.numpy import NumpyType
from ibis.formats.pyarrow import PyArrowData

if value.dtype.kind == 'O':
value_dtype = infer_sequence_dtype(value)
value_dtype = PyArrowData.infer_column(value)
else:
value_dtype = dtype_from_numpy(value.dtype)
value_dtype = NumpyType.to_ibis(value.dtype)

return dt.Array(value_dtype)

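
Sketch of the rewired inference path (not part of the diff):

import numpy as np

import ibis.expr.datatypes as dt

# a typed ndarray goes through NumpyType.to_ibis for its element type
assert dt.infer(np.array([1, 2, 3])) == dt.Array(dt.int64)
# object-dtype arrays take the PyArrowData.infer_column path instead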
4 changes: 2 additions & 2 deletions ibis/expr/operations/relations.py
@@ -135,9 +135,9 @@ def to_frame(self) -> pd.DataFrame:
def to_pyarrow(self, schema: Schema) -> pa.Table:
import pyarrow as pa

from ibis.formats.pyarrow import schema_to_pyarrow
from ibis.formats.pyarrow import PyArrowSchema

return pa.Table.from_pandas(self._data, schema=schema_to_pyarrow(schema))
return pa.Table.from_pandas(self._data, schema=PyArrowSchema.from_ibis(schema))


@public
44 changes: 22 additions & 22 deletions ibis/expr/schema.py
@@ -132,23 +132,23 @@ def from_tuples(
@classmethod
def from_numpy(cls, numpy_schema):
"""Return the equivalent ibis schema."""
from ibis.formats.numpy import schema_from_numpy
from ibis.formats.numpy import NumpySchema

return schema_from_numpy(numpy_schema)
return NumpySchema.to_ibis(numpy_schema)

@classmethod
def from_pandas(cls, pandas_schema):
"""Return the equivalent ibis schema."""
from ibis.formats.pandas import schema_from_pandas
from ibis.formats.pandas import PandasSchema

return schema_from_pandas(pandas_schema)
return PandasSchema.to_ibis(pandas_schema)

@classmethod
def from_pyarrow(cls, pyarrow_schema):
"""Return the equivalent ibis schema."""
from ibis.formats.pyarrow import schema_from_pyarrow
from ibis.formats.pyarrow import PyArrowSchema

return schema_from_pyarrow(pyarrow_schema)
return PyArrowSchema.to_ibis(pyarrow_schema)

@classmethod
def from_dask(cls, dask_schema):
@@ -157,21 +157,21 @@ def from_dask(cls, dask_schema):

def to_numpy(self):
"""Return the equivalent numpy dtypes."""
from ibis.formats.numpy import schema_to_numpy
from ibis.formats.numpy import NumpySchema

return schema_to_numpy(self)
return NumpySchema.from_ibis(self)

def to_pandas(self):
"""Return the equivalent pandas datatypes."""
from ibis.formats.pandas import schema_to_pandas
from ibis.formats.pandas import PandasSchema

return schema_to_pandas(self)
return PandasSchema.from_ibis(self)

def to_pyarrow(self):
"""Return the equivalent pyarrow schema."""
from ibis.formats.pyarrow import schema_to_pyarrow
from ibis.formats.pyarrow import PyArrowSchema

return schema_to_pyarrow(self)
return PyArrowSchema.from_ibis(self)

def to_dask(self):
"""Return the equivalent dask dtypes."""
@@ -209,9 +209,9 @@ def name_at_position(self, i: int) -> str:
instead="use ibis.formats.pandas.PandasConverter.convert_frame() instead",
)
def apply_to(self, df: pd.DataFrame) -> pd.DataFrame:
from ibis.formats.pandas import PandasConverter
from ibis.formats.pandas import PandasData

return PandasConverter.convert_frame(df, self)
return PandasData.convert_table(df, self)


@lazy_singledispatch
@@ -248,32 +248,32 @@ def from_class(cls):

@schema.register("pandas.Series")
def from_pandas_series(s):
from ibis.formats.pandas import schema_from_pandas
from ibis.formats.pandas import PandasSchema

return schema_from_pandas(s)
return PandasSchema.to_ibis(s)


@schema.register("pyarrow.Schema")
def from_pyarrow_schema(schema):
from ibis.formats.pyarrow import schema_from_pyarrow
from ibis.formats.pyarrow import PyArrowSchema

return schema_from_pyarrow(schema)
return PyArrowSchema.to_ibis(schema)


@infer.register("pandas.DataFrame")
def infer_pandas_dataframe(df, schema=None):
from ibis.formats.pandas import schema_from_pandas_dataframe
from ibis.formats.pandas import PandasData

return schema_from_pandas_dataframe(df, schema)
return PandasData.infer_table(df, schema)


# TODO(kszucs): do we really need the schema kwarg?
@infer.register("pyarrow.Table")
def infer_pyarrow_table(table, schema=None):
from ibis.formats.pyarrow import schema_from_pyarrow
from ibis.formats.pyarrow import PyArrowSchema

schema = schema if schema is not None else table.schema
return schema_from_pyarrow(schema)
return PyArrowSchema.to_ibis(schema)


# lock the dispatchers to avoid adding new implementations
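
A hedged sketch of the renamed schema helpers in use (assumes pyarrow is installed; `ibis.Schema` is the class the classmethods above live on):

import ibis

s = ibis.schema({"a": "int64", "b": "string"})
pa_schema = s.to_pyarrow()  # delegates to PyArrowSchema.from_ibis
assert ibis.Schema.from_pyarrow(pa_schema) == s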
182 changes: 182 additions & 0 deletions ibis/formats/__init__.py
@@ -0,0 +1,182 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

from ibis.expr.datatypes import DataType
from ibis.expr.schema import Schema

C = TypeVar('C')
T = TypeVar('T')
S = TypeVar('S')


class TypeMapper(ABC, Generic[T]):
# `T` is the format-specific type object, e.g. pyarrow.DataType or
# sqlalchemy.types.TypeEngine

@classmethod
@abstractmethod
def from_ibis(cls, dtype: DataType) -> T:
"""Convert an Ibis DataType to a format-specific type object.
Parameters
----------
dtype
The Ibis DataType to convert.
Returns
-------
Format-specific type object.
"""

@classmethod
@abstractmethod
def to_ibis(cls, typ: T, nullable: bool = True) -> DataType:
"""Convert a format-specific type object to an Ibis DataType.
Parameters
----------
typ
The format-specific type object to convert.
nullable
Whether the Ibis DataType should be nullable.
Returns
-------
Ibis DataType.
"""


class SchemaMapper(ABC, Generic[S]):
# `S` is the format-specific schema object, e.g. pyarrow.Schema

@classmethod
@abstractmethod
def from_ibis(cls, schema: Schema) -> S:
"""Convert an Ibis Schema to a format-specific schema object.
Parameters
----------
schema
The Ibis Schema to convert.
Returns
-------
Format-specific schema object.
"""

@classmethod
@abstractmethod
def to_ibis(cls, obj: S) -> Schema:
"""Convert a format-specific schema object to an Ibis Schema.
Parameters
----------
obj
The format-specific schema object to convert.
Returns
-------
Ibis Schema.
"""


class DataMapper(Generic[S, C, T]):
# `S` is the format-specific scalar object, e.g. pyarrow.Scalar
# `C` is the format-specific column object, e.g. pyarrow.Array
# `T` is the format-specific table object, e.g. pyarrow.Table

@classmethod
def convert_scalar(cls, obj: S, dtype: DataType) -> S:
"""Convert a format-specific scalar to the given ibis datatype.
Parameters
----------
obj
The format-specific scalar value to convert.
dtype
The Ibis datatype to convert to.
Returns
-------
Format specific scalar corresponding to the given Ibis datatype.
"""
raise NotImplementedError

@classmethod
def convert_column(cls, obj: C, dtype: DataType) -> C:
"""Convert a format-specific column to the given ibis datatype.
Parameters
----------
obj
The format-specific column value to convert.
dtype
The Ibis datatype to convert to.
Returns
-------
Format specific column corresponding to the given Ibis datatype.
"""
raise NotImplementedError

@classmethod
def convert_table(cls, obj: T, schema: Schema) -> T:
"""Convert a format-specific table to the given ibis schema.
Parameters
----------
obj
The format-specific table-like object to convert.
schema
The Ibis schema to convert to.
Returns
-------
Format specific table-like object corresponding to the given Ibis schema.
"""
raise NotImplementedError

@classmethod
def infer_scalar(cls, obj: S) -> DataType:
"""Infer the Ibis datatype of a format-specific scalar.
Parameters
----------
obj
The format-specific scalar to infer the Ibis datatype of.
Returns
-------
Ibis datatype corresponding to the given format-specific scalar.
"""
raise NotImplementedError

@classmethod
def infer_column(cls, obj: C) -> DataType:
"""Infer the Ibis datatype of a format-specific column.
Parameters
----------
obj
The format-specific column to infer the Ibis datatype of.
Returns
-------
Ibis datatype corresponding to the given format-specific column.
"""
raise NotImplementedError

@classmethod
def infer_table(cls, obj: T) -> Schema:
"""Infer the Ibis schema of a format-specific table.
Parameters
----------
obj
The format-specific table to infer the Ibis schema of.
Returns
-------
Ibis schema corresponding to the given format-specific table.
"""
raise NotImplementedError
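
To make the contract concrete, here is a toy TypeMapper against plain strings (illustration only, not part of the diff; real implementations target numpy, pandas, pyarrow, or SQLAlchemy types):

import ibis.expr.datatypes as dt
from ibis.formats import TypeMapper


class ToyType(TypeMapper[str]):
    """Maps between ibis dtypes and a made-up string 'format'."""

    _to_ibis = {"int": dt.Int64, "str": dt.String}
    _from_ibis = {dt.Int64: "int", dt.String: "str"}

    @classmethod
    def to_ibis(cls, typ: str, nullable: bool = True) -> dt.DataType:
        return cls._to_ibis[typ](nullable=nullable)

    @classmethod
    def from_ibis(cls, dtype: dt.DataType) -> str:
        return cls._from_ibis[type(dtype)]


assert ToyType.to_ibis("int") == dt.Int64(nullable=True)
assert ToyType.from_ibis(dt.string) == "str"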
129 changes: 67 additions & 62 deletions ibis/formats/numpy.py
@@ -3,6 +3,7 @@

import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis.formats import SchemaMapper, TypeMapper

_from_numpy_types = toolz.keymap(
np.dtype,
@@ -26,69 +27,73 @@
_to_numpy_types = {v: k for k, v in _from_numpy_types.items()}


def dtype_from_numpy(typ, nullable=True):
if np.issubdtype(typ, np.datetime64):
        # TODO(kszucs): the following code provides proper timestamp roundtrips
        # between ibis and numpy/pandas but breaks the test suite at several
        # places; we should revisit this later
# unit, _ = np.datetime_data(typ)
# if unit in {'generic', 'Y', 'M', 'D', 'h', 'm'}:
# return dt.Timestamp(nullable=nullable)
# else:
# return dt.Timestamp.from_unit(unit, nullable=nullable)
return dt.Timestamp(nullable=nullable)
elif np.issubdtype(typ, np.timedelta64):
unit, _ = np.datetime_data(typ)
if unit == 'generic':
unit = 's'
return dt.Interval(unit, nullable=nullable)
elif np.issubdtype(typ, np.str_):
return dt.String(nullable=nullable)
elif np.issubdtype(typ, np.bytes_):
return dt.Binary(nullable=nullable)
else:
try:
return _from_numpy_types[typ](nullable=nullable)
except KeyError:
raise TypeError(f"numpy dtype {typ!r} is not supported")
class NumpyType(TypeMapper[np.dtype]):
@classmethod
def to_ibis(cls, typ: np.dtype, nullable: bool = True) -> dt.DataType:
if np.issubdtype(typ, np.datetime64):
            # TODO(kszucs): the following code provides proper timestamp roundtrips
            # between ibis and numpy/pandas but breaks the test suite at several
            # places; we should revisit this later
# unit, _ = np.datetime_data(typ)
# if unit in {'generic', 'Y', 'M', 'D', 'h', 'm'}:
# return dt.Timestamp(nullable=nullable)
# else:
# return dt.Timestamp.from_unit(unit, nullable=nullable)
return dt.Timestamp(nullable=nullable)
elif np.issubdtype(typ, np.timedelta64):
unit, _ = np.datetime_data(typ)
if unit == 'generic':
unit = 's'
return dt.Interval(unit, nullable=nullable)
elif np.issubdtype(typ, np.str_):
return dt.String(nullable=nullable)
elif np.issubdtype(typ, np.bytes_):
return dt.Binary(nullable=nullable)
else:
try:
return _from_numpy_types[typ](nullable=nullable)
except KeyError:
raise TypeError(f"numpy dtype {typ!r} is not supported")

@classmethod
def from_ibis(cls, dtype: dt.DataType) -> np.dtype:
if dtype.is_interval():
return np.dtype(f"timedelta64[{dtype.unit.short}]")
elif dtype.is_timestamp():
            # TODO(kszucs): the following code provides proper timestamp roundtrips
            # between ibis and numpy/pandas but breaks the test suite at several
            # places; we should revisit this later
# return np.dtype(f"datetime64[{dtype.unit.short}]")
return np.dtype("datetime64[ns]")
elif dtype.is_date():
# return np.dtype("datetime64[D]")
return np.dtype("datetime64[ns]")
elif dtype.is_time():
return np.dtype("timedelta64[ns]")
elif (
dtype.is_null()
or dtype.is_decimal()
or dtype.is_struct()
or dtype.is_variadic()
or dtype.is_unknown()
or dtype.is_uuid()
or dtype.is_geospatial()
):
return np.dtype("object")
else:
try:
return _to_numpy_types[type(dtype)]
except KeyError:
raise TypeError(f"ibis dtype {dtype!r} is not supported")

def dtype_to_numpy(dtype):
if dtype.is_interval():
return np.dtype(f"timedelta64[{dtype.unit.short}]")
elif dtype.is_timestamp():
        # TODO(kszucs): the following code provides proper timestamp roundtrips
        # between ibis and numpy/pandas but breaks the test suite at several
        # places; we should revisit this later
# return np.dtype(f"datetime64[{dtype.unit.short}]")
return np.dtype("datetime64[ns]")
elif dtype.is_date():
# return np.dtype("datetime64[D]")
return np.dtype("datetime64[ns]")
elif dtype.is_time():
return np.dtype("timedelta64[ns]")
elif (
dtype.is_null()
or dtype.is_decimal()
or dtype.is_struct()
or dtype.is_variadic()
or dtype.is_unknown()
or dtype.is_uuid()
or dtype.is_geospatial()
):
return np.dtype("object")
else:
try:
return _to_numpy_types[type(dtype)]
except KeyError:
raise TypeError(f"ibis dtype {dtype!r} is not supported")

class NumpySchema(SchemaMapper):
@classmethod
def from_ibis(cls, schema):
numpy_types = map(NumpyType.from_ibis, schema.types)
return list(zip(schema.names, numpy_types))

def schema_to_numpy(schema):
numpy_types = map(dtype_to_numpy, schema.types)
return list(zip(schema.names, numpy_types))


def schema_from_numpy(schema):
ibis_types = {name: dtype_from_numpy(typ) for name, typ in schema}
return sch.Schema(ibis_types)
@classmethod
def to_ibis(cls, schema):
ibis_types = {name: NumpyType.to_ibis(typ) for name, typ in schema}
return sch.Schema(ibis_types)
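
Sketch of the mapper above in action (not part of the diff):

import numpy as np

import ibis.expr.datatypes as dt
from ibis.formats.numpy import NumpyType

assert NumpyType.to_ibis(np.dtype("int64")) == dt.Int64(nullable=True)
# dates (and, for now, all timestamps) land on nanosecond datetime64
assert NumpyType.from_ibis(dt.date) == np.dtype("datetime64[ns]")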
165 changes: 87 additions & 78 deletions ibis/formats/pandas.py
@@ -8,8 +8,9 @@

import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis.formats.numpy import dtype_from_numpy, dtype_to_numpy
from ibis.formats.pyarrow import dtype_from_pyarrow, infer_sequence_dtype
from ibis.formats import DataMapper, SchemaMapper
from ibis.formats.numpy import NumpyType
from ibis.formats.pyarrow import PyArrowData, PyArrowType

_has_arrow_dtype = hasattr(pd, "ArrowDtype")

@@ -20,108 +21,112 @@
)


def dtype_to_pandas(dtype: dt.DataType):
"""Convert ibis dtype to the pandas / numpy alternative."""
assert isinstance(dtype, dt.DataType)

if dtype.is_timestamp() and dtype.timezone:
return pdt.DatetimeTZDtype('ns', dtype.timezone)
elif dtype.is_interval():
return np.dtype(f'timedelta64[{dtype.unit.short}]')
else:
return dtype_to_numpy(dtype)


def dtype_from_pandas(typ, nullable=True):
if pdt.is_datetime64tz_dtype(typ):
return dt.Timestamp(timezone=str(typ.tz), nullable=nullable)
elif pdt.is_datetime64_dtype(typ):
return dt.Timestamp(nullable=nullable)
elif pdt.is_categorical_dtype(typ):
return dt.String(nullable=nullable)
elif pdt.is_extension_array_dtype(typ):
if _has_arrow_dtype and isinstance(typ, pd.ArrowDtype):
return dtype_from_pyarrow(typ.pyarrow_dtype, nullable=nullable)
class PandasType(NumpyType):
@classmethod
def to_ibis(cls, typ, nullable=True):
if pdt.is_datetime64tz_dtype(typ):
return dt.Timestamp(timezone=str(typ.tz), nullable=nullable)
elif pdt.is_datetime64_dtype(typ):
return dt.Timestamp(nullable=nullable)
elif pdt.is_categorical_dtype(typ):
return dt.String(nullable=nullable)
elif pdt.is_extension_array_dtype(typ):
if _has_arrow_dtype and isinstance(typ, pd.ArrowDtype):
return PyArrowType.to_ibis(typ.pyarrow_dtype, nullable=nullable)
else:
name = typ.__class__.__name__.replace("Dtype", "")
klass = getattr(dt, name)
return klass(nullable=nullable)
else:
name = typ.__class__.__name__.replace("Dtype", "")
klass = getattr(dt, name)
return klass(nullable=nullable)
else:
return dtype_from_numpy(typ, nullable=nullable)

return super().to_ibis(typ, nullable=nullable)

def schema_to_pandas(schema):
pandas_types = map(dtype_to_pandas, schema.types)
return list(zip(schema.names, pandas_types))
@classmethod
def from_ibis(cls, dtype):
if dtype.is_timestamp() and dtype.timezone:
return pdt.DatetimeTZDtype('ns', dtype.timezone)
elif dtype.is_interval():
return np.dtype(f'timedelta64[{dtype.unit.short}]')
else:
return super().from_ibis(dtype)


def schema_from_pandas(schema):
ibis_types = {name: dtype_from_pandas(typ) for name, typ in schema}
return sch.schema(ibis_types)
class PandasSchema(SchemaMapper):
@classmethod
def to_ibis(cls, pandas_schema):
if isinstance(pandas_schema, pd.Series):
pandas_schema = pandas_schema.to_list()

fields = {name: PandasType.to_ibis(t) for name, t in pandas_schema}

def schema_from_pandas_dataframe(
df: pd.DataFrame, schema=None, inference_function=infer_sequence_dtype
):
schema = schema if schema is not None else {}
return sch.Schema(fields)

pairs = []
for column_name in df.dtypes.keys():
if not isinstance(column_name, str):
raise TypeError('Column names must be strings to use the pandas backend')
@classmethod
def from_ibis(cls, schema):
names = schema.names
types = [PandasType.from_ibis(t) for t in schema.types]
return list(zip(names, types))

if column_name in schema:
ibis_dtype = schema[column_name]
else:
pandas_column = df[column_name]
pandas_dtype = pandas_column.dtype
if pandas_dtype == np.object_:
ibis_dtype = inference_function(pandas_column.values)
else:
ibis_dtype = dtype_from_pandas(pandas_dtype)

pairs.append((column_name, ibis_dtype))
class PandasData(DataMapper):
@classmethod
def infer_scalar(cls, s):
return PyArrowData.infer_scalar(s)

return sch.schema(pairs)
@classmethod
def infer_column(cls, s):
return PyArrowData.infer_column(s)

@classmethod
def infer_table(cls, df, schema=None):
schema = schema if schema is not None else {}

pairs = []
for column_name in df.dtypes.keys():
if not isinstance(column_name, str):
raise TypeError(
'Column names must be strings to use the pandas backend'
)

if column_name in schema:
ibis_dtype = schema[column_name]
else:
pandas_column = df[column_name]
pandas_dtype = pandas_column.dtype
if pandas_dtype == np.object_:
ibis_dtype = cls.infer_column(pandas_column)
else:
ibis_dtype = PandasType.to_ibis(pandas_dtype)

def schema_from_dask_dataframe(df, schema=None):
# TODO(kszucs): we should limit the computation to the first partition or
# even just the first row if we switch to `pa.infer_type()` in the inference
# function
return schema_from_pandas_dataframe(
df,
schema=schema,
inference_function=lambda s: infer_sequence_dtype(s.compute()),
)
pairs.append((column_name, ibis_dtype))

return sch.Schema.from_tuples(pairs)

class PandasConverter:
@classmethod
def convert_frame(cls, df, schema):
def convert_table(cls, df, schema):
if len(schema) != len(df.columns):
raise ValueError(
"schema column count does not match input data column count"
)

for (name, series), dtype in zip(df.items(), schema.types):
df[name] = cls.convert_series(series, dtype)
df[name] = cls.convert_column(series, dtype)

# return data with the schema's columns which may be different than the input columns
# return data with the schema's columns which may be different than the
# input columns
df.columns = schema.names
return df

@classmethod
def convert_series(cls, s, dtype):
pandas_type = dtype.to_pandas()
def convert_column(cls, obj, dtype):
pandas_type = PandasType.from_ibis(dtype)

if s.dtype == pandas_type and dtype.is_primitive():
return s
if obj.dtype == pandas_type and dtype.is_primitive():
return obj

converter = getattr(
cls, f"convert_{dtype.__class__.__name__}", cls.convert_default
)
return converter(s, dtype, pandas_type)
method_name = f"convert_{dtype.__class__.__name__}"
convert_method = getattr(cls, method_name, cls.convert_default)

return convert_method(obj, dtype, pandas_type)

@staticmethod
def convert_default(s, dtype, pandas_type):
@@ -132,8 +137,6 @@ def convert_default(s, dtype, pandas_type):

@staticmethod
def convert_Boolean(s, dtype, pandas_type):
import pandas.api.types as pdt

if s.empty:
return s.astype(pandas_type)
elif pdt.is_object_dtype(s.dtype):
@@ -224,3 +227,9 @@ def try_json(x):
return x

return s.map(try_json, na_action="ignore").astype("object")


class DaskData(PandasData):
@classmethod
def infer_column(cls, s):
return PyArrowData.infer_column(s.compute())
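
Sketch of the inference entry point above (not part of the diff):

import pandas as pd

from ibis.formats.pandas import PandasData

df = pd.DataFrame({"x": [1, 2], "y": ["a", "b"]})
schema = PandasData.infer_table(df)
# x comes straight from its int64 dtype; the object column y is routed
# through PyArrowData.infer_column and should come back as string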
278 changes: 165 additions & 113 deletions ibis/formats/pyarrow.py
@@ -1,11 +1,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import pyarrow as pa

import ibis.expr.datatypes as dt
from ibis.expr.schema import Schema
from ibis.formats import DataMapper, SchemaMapper, TypeMapper

if TYPE_CHECKING:
from collections.abc import Sequence
@@ -58,126 +59,177 @@
}


def dtype_from_pyarrow(typ: pa.DataType, nullable=True) -> dt.DataType:
"""Convert a pyarrow type to an ibis type."""

if pa.types.is_null(typ):
return dt.null
elif pa.types.is_decimal(typ):
return dt.Decimal(typ.precision, typ.scale, nullable=nullable)
elif pa.types.is_timestamp(typ):
return dt.Timestamp.from_unit(typ.unit, timezone=typ.tz, nullable=nullable)
elif pa.types.is_time(typ):
return dt.Time(nullable=nullable)
elif pa.types.is_duration(typ):
return dt.Interval(typ.unit, nullable=nullable)
elif pa.types.is_interval(typ):
raise ValueError("Arrow interval type is not supported")
elif (
pa.types.is_list(typ)
or pa.types.is_large_list(typ)
or pa.types.is_fixed_size_list(typ)
):
value_dtype = dtype_from_pyarrow(typ.value_type, typ.value_field.nullable)
return dt.Array(value_dtype, nullable=nullable)
elif pa.types.is_struct(typ):
field_dtypes = {
field.name: dtype_from_pyarrow(field.type, field.nullable) for field in typ
}
return dt.Struct(field_dtypes, nullable=nullable)
elif pa.types.is_map(typ):
# TODO(kszucs): keys_sorted has just been exposed in pyarrow
key_dtype = dtype_from_pyarrow(typ.key_type, typ.key_field.nullable)
value_dtype = dtype_from_pyarrow(typ.item_type, typ.item_field.nullable)
return dt.Map(key_dtype, value_dtype, nullable=nullable)
else:
return _from_pyarrow_types[typ](nullable=nullable)


def dtype_to_pyarrow(dtype: dt.DataType) -> pa.DataType:
if dtype.is_decimal():
# set default precision and scale to something; unclear how to choose this
precision = 38 if dtype.precision is None else dtype.precision
scale = 9 if dtype.scale is None else dtype.scale

if precision > 76:
raise ValueError(
f"Unsupported precision {dtype.precision} for decimal type"
class PyArrowType(TypeMapper):
@classmethod
def to_ibis(cls, typ: pa.DataType, nullable=True) -> dt.DataType:
"""Convert a pyarrow type to an ibis type."""

if pa.types.is_null(typ):
return dt.null
elif pa.types.is_decimal(typ):
return dt.Decimal(typ.precision, typ.scale, nullable=nullable)
elif pa.types.is_timestamp(typ):
return dt.Timestamp.from_unit(typ.unit, timezone=typ.tz, nullable=nullable)
elif pa.types.is_time(typ):
return dt.Time(nullable=nullable)
elif pa.types.is_duration(typ):
return dt.Interval(typ.unit, nullable=nullable)
elif pa.types.is_interval(typ):
raise ValueError("Arrow interval type is not supported")
elif (
pa.types.is_list(typ)
or pa.types.is_large_list(typ)
or pa.types.is_fixed_size_list(typ)
):
value_dtype = cls.to_ibis(typ.value_type, typ.value_field.nullable)
return dt.Array(value_dtype, nullable=nullable)
elif pa.types.is_struct(typ):
field_dtypes = {
field.name: cls.to_ibis(field.type, field.nullable) for field in typ
}
return dt.Struct(field_dtypes, nullable=nullable)
elif pa.types.is_map(typ):
# TODO(kszucs): keys_sorted has just been exposed in pyarrow
key_dtype = cls.to_ibis(typ.key_type, typ.key_field.nullable)
value_dtype = cls.to_ibis(typ.item_type, typ.item_field.nullable)
return dt.Map(key_dtype, value_dtype, nullable=nullable)
else:
return _from_pyarrow_types[typ](nullable=nullable)

@classmethod
def from_ibis(cls, dtype: dt.DataType) -> pa.DataType:
"""Convert an ibis type to a pyarrow type."""

if dtype.is_decimal():
# set default precision and scale to something; unclear how to choose this
precision = 38 if dtype.precision is None else dtype.precision
scale = 9 if dtype.scale is None else dtype.scale

if precision > 76:
raise ValueError(
f"Unsupported precision {dtype.precision} for decimal type"
)
elif precision > 38:
return pa.decimal256(precision, scale)
else:
return pa.decimal128(precision, scale)
elif dtype.is_timestamp():
return pa.timestamp(
dtype.unit.short if dtype.scale is not None else "us", tz=dtype.timezone
)
elif dtype.is_interval():
return pa.duration(dtype.unit.short)
elif dtype.is_time():
return pa.time64("ns")
elif dtype.is_date():
return pa.date64()
elif dtype.is_array():
value_field = pa.field(
'item',
cls.from_ibis(dtype.value_type),
nullable=dtype.value_type.nullable,
)
return pa.list_(value_field)
elif dtype.is_struct():
fields = [
pa.field(name, cls.from_ibis(dtype), nullable=dtype.nullable)
for name, dtype in dtype.items()
]
return pa.struct(fields)
elif dtype.is_map():
key_field = pa.field(
'key', cls.from_ibis(dtype.key_type), nullable=dtype.key_type.nullable
)
value_field = pa.field(
'value',
cls.from_ibis(dtype.value_type),
nullable=dtype.value_type.nullable,
)
elif precision > 38:
return pa.decimal256(precision, scale)
return pa.map_(key_field, value_field, keys_sorted=False)
else:
return pa.decimal128(precision, scale)
elif dtype.is_timestamp():
return pa.timestamp(
dtype.unit.short if dtype.scale is not None else "us", tz=dtype.timezone
)
elif dtype.is_interval():
return pa.duration(dtype.unit.short)
elif dtype.is_time():
return pa.time64("ns")
elif dtype.is_date():
return pa.date64()
elif dtype.is_array():
value_field = pa.field(
'item',
dtype_to_pyarrow(dtype.value_type),
nullable=dtype.value_type.nullable,
)
return pa.list_(value_field)
elif dtype.is_struct():
try:
return _to_pyarrow_types[type(dtype)]
except KeyError:
raise NotImplementedError(
f"Converting {dtype} to pyarrow is not supported yet"
)


class PyArrowSchema(SchemaMapper):
@classmethod
def from_ibis(cls, schema: Schema) -> pa.Schema:
"""Convert a schema to a pyarrow schema."""
fields = [
pa.field(name, dtype_to_pyarrow(dtype), nullable=dtype.nullable)
for name, dtype in dtype.items()
pa.field(name, PyArrowType.from_ibis(dtype), nullable=dtype.nullable)
for name, dtype in schema.items()
]
return pa.struct(fields)
elif dtype.is_map():
key_field = pa.field(
'key', dtype_to_pyarrow(dtype.key_type), nullable=dtype.key_type.nullable
)
value_field = pa.field(
'value',
dtype_to_pyarrow(dtype.value_type),
nullable=dtype.value_type.nullable,
)
return pa.map_(key_field, value_field, keys_sorted=False)
else:
try:
return _to_pyarrow_types[type(dtype)]
except KeyError:
raise NotImplementedError(
f"Converting {dtype} to pyarrow is not supported yet"
)

return pa.schema(fields)

def schema_from_pyarrow(schema: pa.Schema) -> Schema:
fields = [(f.name, dtype_from_pyarrow(f.type, f.nullable)) for f in schema]
return Schema.from_tuples(fields)
@classmethod
def to_ibis(cls, schema: pa.Schema) -> Schema:
"""Convert a pyarrow schema to a schema."""
fields = [(f.name, PyArrowType.to_ibis(f.type, f.nullable)) for f in schema]
return Schema.from_tuples(fields)


def schema_to_pyarrow(schema: Schema) -> pa.Schema:
fields = [
pa.field(name, dtype_to_pyarrow(dtype), nullable=dtype.nullable)
for name, dtype in schema.items()
]
return pa.schema(fields)
class PyArrowData(DataMapper):
@classmethod
def infer_scalar(cls, scalar: Any) -> dt.DataType:
"""Infer the ibis type of a scalar."""
return PyArrowType.to_ibis(pa.scalar(scalar).type)

@classmethod
def infer_column(cls, column: Sequence) -> dt.DataType:
"""Infer the ibis type of a sequence."""
if isinstance(column, pa.Array):
return PyArrowType.to_ibis(column.type)

def infer_sequence_dtype(sequence: Sequence) -> dt.DataType:
try:
pyarrow_type = pa.array(sequence, from_pandas=True).type
# pyarrow_type = pa.infer_type(sequence, from_pandas=True)
except pa.ArrowInvalid:
try:
# handle embedded series objects
return dt.highest_precedence(map(dt.infer, sequence))
except TypeError:
# we can still have a type error, e.g., float64 and string in the
# same array
pyarrow_type = pa.array(column, from_pandas=True).type
# pyarrow_type = pa.infer_type(column, from_pandas=True)
except pa.ArrowInvalid:
try:
# handle embedded series objects
return dt.highest_precedence(map(dt.infer, column))
except TypeError:
# we can still have a type error, e.g., float64 and string in the
# same array
return dt.unknown
except pa.ArrowTypeError:
# arrow can't infer the type
return dt.unknown
except pa.ArrowTypeError:
# arrow can't infer the type
return dt.unknown
else:
# arrow inferred the type, now convert that type to an ibis type
return dtype_from_pyarrow(pyarrow_type)
else:
# arrow inferred the type, now convert that type to an ibis type
return PyArrowType.to_ibis(pyarrow_type)

@classmethod
def infer_table(cls, table) -> Schema:
"""Infer the schema of a table."""
if not isinstance(table, pa.Table):
table = pa.table(table)

return PyArrowSchema.to_ibis(table.schema)

@classmethod
def convert_scalar(cls, scalar: pa.Scalar, dtype: dt.DataType) -> pa.Scalar:
desired_type = PyArrowType.from_ibis(dtype)
if scalar.type != desired_type:
return scalar.cast(desired_type)
else:
return scalar

@classmethod
def convert_column(cls, column: pa.Array, dtype: dt.DataType) -> pa.Array:
desired_type = PyArrowType.from_ibis(dtype)
if column.type != desired_type:
return column.cast(desired_type)
else:
return column

@classmethod
def convert_table(cls, table: pa.Table, schema: Schema) -> pa.Table:
desired_schema = PyArrowSchema.from_ibis(schema)
if table.schema != desired_schema:
return table.cast(desired_schema)
else:
return table
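
Sketch of the new conversion path (not part of the diff):

import pyarrow as pa

import ibis
from ibis.formats.pyarrow import PyArrowData, PyArrowSchema

table = pa.table({"x": [1, 2, 3]})  # inferred as int64
target = ibis.schema({"x": "float64"})

converted = PyArrowData.convert_table(table, target)  # casts, since the schemas differ
assert converted.schema == PyArrowSchema.from_ibis(target)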