Skip to content

Commit

Permalink
fix(flink): implement TypeMapper and SchemaMapper for Flink backend
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenzhongxu authored and gforsyth committed Dec 6, 2023
1 parent 1413de9 commit f983bfa
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 42 deletions.
4 changes: 2 additions & 2 deletions ibis/backends/flink/__init__.py
Expand Up @@ -12,14 +12,14 @@
from ibis.backends.base import BaseBackend, CanCreateDatabase
from ibis.backends.base.sql.ddl import fully_qualified_re, is_fully_qualified
from ibis.backends.flink.compiler.core import FlinkCompiler
from ibis.backends.flink.datatypes import FlinkRowSchema
from ibis.backends.flink.ddl import (
CreateDatabase,
CreateTableFromConnector,
DropDatabase,
DropTable,
InsertSelect,
)
from ibis.backends.flink.utils import ibis_schema_to_flink_schema

if TYPE_CHECKING:
from collections.abc import Mapping
Expand Down Expand Up @@ -354,7 +354,7 @@ def create_table(
obj = obj.to_pandas()
if isinstance(obj, pd.DataFrame):
table = self._table_env.from_pandas(
obj, ibis_schema_to_flink_schema(schema)
obj, FlinkRowSchema.from_ibis(schema)
)
if isinstance(obj, ir.Table):
table = obj
Expand Down
93 changes: 93 additions & 0 deletions ibis/backends/flink/datatypes.py
@@ -0,0 +1,93 @@
from __future__ import annotations

import pyflink.table.types as fl

import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis.formats import SchemaMapper, TypeMapper


class FlinkRowSchema(SchemaMapper):
@classmethod
def from_ibis(cls, schema: sch.Schema | None) -> list[fl.RowType]:
if schema is None:
return None

return fl.DataTypes.ROW(
[
fl.DataTypes.FIELD(k, FlinkType.from_ibis(v))
for k, v in schema.fields.items()
]
)


class FlinkType(TypeMapper):
@classmethod
def to_ibis(cls, typ: fl.DataType, nullable=True) -> dt.DataType:
"""Convert a flink type to an ibis type."""
if typ == fl.DataTypes.STRING():
return dt.String(nullable=nullable)
elif typ == fl.DataTypes.BOOLEAN():
return dt.Boolean(nullable=nullable)
elif typ == fl.DataTypes.BYTES():
return dt.Binary(nullable=nullable)
elif typ == fl.DataTypes.TINYINT():
return dt.Int8(nullable=nullable)
elif typ == fl.DataTypes.SMALLINT():
return dt.Int16(nullable=nullable)
elif typ == fl.DataTypes.INT():
return dt.Int32(nullable=nullable)
elif typ == fl.DataTypes.BIGINT():
return dt.Int64(nullable=nullable)
elif typ == fl.DataTypes.FLOAT():
return dt.Float32(nullable=nullable)
elif typ == fl.DataTypes.DOUBLE():
return dt.Float64(nullable=nullable)
elif typ == fl.DataTypes.DATE():
return dt.Date(nullable=nullable)
elif typ == fl.DataTypes.TIME():
return dt.Time(nullable=nullable)
elif typ == fl.DataTypes.TIMESTAMP():
return dt.Timestamp(nullable=nullable)
else:
return super().to_ibis(typ, nullable=nullable)

@classmethod
def from_ibis(cls, dtype: dt.DataType) -> fl.DataType:
"""Convert an ibis type to a flink type."""
if dtype.is_string():
return fl.DataTypes.STRING()
elif dtype.is_boolean():
return fl.DataTypes.BOOLEAN()
elif dtype.is_binary():
return fl.DataTypes.BYTES()
elif dtype.is_int8():
return fl.DataTypes.TINYINT()
elif dtype.is_int16():
return fl.DataTypes.SMALLINT()
elif dtype.is_int32():
return fl.DataTypes.INT()
elif dtype.is_int64():
return fl.DataTypes.BIGINT()
elif dtype.is_uint8():
return fl.DataTypes.TINYINT()
elif dtype.is_uint16():
return fl.DataTypes.SMALLINT()
elif dtype.is_uint32():
return fl.DataTypes.INT()
elif dtype.is_uint64():
return fl.DataTypes.BIGINT()
elif dtype.is_float16():
return fl.DataTypes.FLOAT()
elif dtype.is_float32():
return fl.DataTypes.FLOAT()
elif dtype.is_float64():
return fl.DataTypes.DOUBLE()
elif dtype.is_date():
return fl.DataTypes.DATE()
elif dtype.is_time():
return fl.DataTypes.TIME()
elif dtype.is_timestamp():
return fl.DataTypes.TIMESTAMP()
else:
return super().from_ibis(dtype)
5 changes: 2 additions & 3 deletions ibis/backends/flink/registry.py
Expand Up @@ -9,6 +9,7 @@
operation_registry as base_operation_registry,
)
from ibis.backends.base.sql.registry.main import varargs
from ibis.backends.flink.datatypes import FlinkType
from ibis.common.temporal import TimestampUnit

if TYPE_CHECKING:
Expand Down Expand Up @@ -221,8 +222,6 @@ def _window(translator: ExprTranslator, op: ops.Node) -> str:


def _clip(translator: ExprTranslator, op: ops.Node) -> str:
from ibis.backends.flink.utils import _to_pyflink_types

arg = translator.translate(op.arg)

if op.upper is not None:
Expand All @@ -233,7 +232,7 @@ def _clip(translator: ExprTranslator, op: ops.Node) -> str:
lower = translator.translate(op.lower)
arg = f"IF({arg} < {lower} AND {arg} IS NOT NULL, {lower}, {arg})"

return f"CAST({arg} AS {_to_pyflink_types[type(op.dtype)]!s})"
return f"CAST({arg} AS {FlinkType.from_ibis(op.dtype)!s})"


def _floor_divide(translator: ExprTranslator, op: ops.Node) -> str:
Expand Down
40 changes: 3 additions & 37 deletions ibis/backends/flink/utils.py
Expand Up @@ -5,11 +5,9 @@
from abc import ABC, abstractmethod
from collections import defaultdict

from pyflink.table.types import DataTypes, RowType

import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.schema as sch
from ibis.backends.flink.datatypes import FlinkType
from ibis.common.temporal import IntervalUnit
from ibis.util import convert_unit

Expand Down Expand Up @@ -247,35 +245,14 @@ def _translate_interval(value, dtype):
return interval.format_as_string()


_to_pyflink_types = {
dt.String: DataTypes.STRING(),
dt.Boolean: DataTypes.BOOLEAN(),
dt.Binary: DataTypes.BYTES(),
dt.Int8: DataTypes.TINYINT(),
dt.Int16: DataTypes.SMALLINT(),
dt.Int32: DataTypes.INT(),
dt.Int64: DataTypes.BIGINT(),
dt.UInt8: DataTypes.TINYINT(),
dt.UInt16: DataTypes.SMALLINT(),
dt.UInt32: DataTypes.INT(),
dt.UInt64: DataTypes.BIGINT(),
dt.Float16: DataTypes.FLOAT(),
dt.Float32: DataTypes.FLOAT(),
dt.Float64: DataTypes.DOUBLE(),
dt.Date: DataTypes.DATE(),
dt.Time: DataTypes.TIME(),
dt.Timestamp: DataTypes.TIMESTAMP(),
}


def translate_literal(op: ops.Literal) -> str:
value = op.value
dtype = op.dtype

if value is None:
if dtype.is_null():
return "NULL"
return f"CAST(NULL AS {_to_pyflink_types[type(dtype)]!s})"
return f"CAST(NULL AS {FlinkType.from_ibis(dtype)!s})"

if dtype.is_boolean():
# TODO(chloeh13q): Flink supports a third boolean called "UNKNOWN"
Expand Down Expand Up @@ -305,7 +282,7 @@ def translate_literal(op: ops.Literal) -> str:
raise ValueError("The precision can be up to 38 in Flink")

return f"CAST({value} AS DECIMAL({precision}, {scale}))"
return f"CAST({value} AS {_to_pyflink_types[type(dtype)]!s})"
return f"CAST({value} AS {FlinkType.from_ibis(dtype)!s})"
elif dtype.is_timestamp():
# TODO(chloeh13q): support timestamp with local timezone
if isinstance(value, datetime.datetime):
Expand All @@ -327,14 +304,3 @@ def translate_literal(op: ops.Literal) -> str:
return f"ARRAY{list(value)}"

raise NotImplementedError(f"No translation rule for {dtype}")


def ibis_schema_to_flink_schema(schema: sch.Schema) -> RowType:
if schema is None:
return None
return DataTypes.ROW(
[
DataTypes.FIELD(key, _to_pyflink_types[type(value)])
for key, value in schema.fields.items()
]
)

0 comments on commit f983bfa

Please sign in to comment.