Skip to content

Commit

Permalink
refactor(datatypes): use sqlglot for parsing backend specific types
Browse files Browse the repository at this point in the history
BREAKING CHANGE: The minimum version of `sqlglot` is now 17.2.0, to support much faster and more robust backend type parsing.
  • Loading branch information
cpcloud committed Aug 13, 2023
1 parent 6be6c2b commit fe7ba24
Show file tree
Hide file tree
Showing 11 changed files with 370 additions and 492 deletions.
239 changes: 74 additions & 165 deletions ibis/backends/clickhouse/datatypes.py
@@ -1,192 +1,101 @@
from __future__ import annotations

import functools
from functools import partial
from typing import Literal
from typing import TYPE_CHECKING, Literal, Mapping

import parsy
import sqlglot as sg
from sqlglot.expressions import ColumnDef, DataType

import ibis
import ibis.expr.datatypes as dt
from ibis.common.parsing import (
COMMA,
FIELD,
LPAREN,
NUMBER,
PRECISION,
RAW_NUMBER,
RAW_STRING,
RPAREN,
SCALE,
SPACES,
spaceless_string,
)
from ibis.common.collections import FrozenDict
from ibis.formats.parser import TypeParser

if TYPE_CHECKING:
from sqlglot.expressions import DataTypeSize, Expression


def _bool_type() -> Literal["Bool", "UInt8", "Int8"]:
return getattr(getattr(ibis.options, "clickhouse", None), "bool_type", "Bool")


def parse(text: str) -> dt.DataType:
datetime64_args = LPAREN.then(
parsy.seq(
scale=parsy.decimal_digit.map(int).optional(),
timezone=COMMA.then(RAW_STRING).optional(),
)
).skip(RPAREN)

datetime64 = spaceless_string("datetime64").then(
datetime64_args.optional(default={}).combine_dict(
partial(dt.Timestamp, nullable=False)
)
)

datetime = spaceless_string("datetime").then(
parsy.seq(
timezone=LPAREN.then(RAW_STRING).skip(RPAREN).optional()
).combine_dict(partial(dt.Timestamp, nullable=False))
)

primitive = (
datetime64
| datetime
| spaceless_string("null", "nothing").result(dt.null)
| spaceless_string("bigint", "int64").result(dt.Int64(nullable=False))
| spaceless_string("double", "float64").result(dt.Float64(nullable=False))
| spaceless_string("float32", "float").result(dt.Float32(nullable=False))
| spaceless_string("smallint", "int16", "int2").result(dt.Int16(nullable=False))
| spaceless_string("date32", "date").result(dt.Date(nullable=False))
| spaceless_string("time").result(dt.Time(nullable=False))
| spaceless_string("tinyint", "int8", "int1").result(dt.Int8(nullable=False))
| spaceless_string("boolean", "bool").result(dt.Boolean(nullable=False))
| spaceless_string("integer", "int32", "int4", "int").result(
dt.Int32(nullable=False)
)
| spaceless_string("uint64").result(dt.UInt64(nullable=False))
| spaceless_string("uint32").result(dt.UInt32(nullable=False))
| spaceless_string("uint16").result(dt.UInt16(nullable=False))
| spaceless_string("uint8").result(dt.UInt8(nullable=False))
| spaceless_string("uuid").result(dt.UUID(nullable=False))
| spaceless_string(
"longtext",
"mediumtext",
"tinytext",
"text",
"longblob",
"mediumblob",
"tinyblob",
"blob",
"varchar",
"char",
"string",
).result(dt.String(nullable=False))
)

ty = parsy.forward_declaration()

nullable = (
spaceless_string("nullable")
.then(LPAREN)
.then(ty.map(lambda ty: ty.copy(nullable=True)))
.skip(RPAREN)
)

fixed_string = (
spaceless_string("fixedstring")
.then(LPAREN)
.then(NUMBER)
.then(RPAREN)
.result(dt.String(nullable=False))
)

decimal = (
spaceless_string("decimal", "numeric")
.then(LPAREN)
.then(
parsy.seq(precision=PRECISION.skip(COMMA), scale=SCALE).combine_dict(
partial(dt.Decimal(nullable=False))
)
)
.skip(RPAREN)
)
class ClickHouseTypeParser(TypeParser):
__slots__ = ()

array = spaceless_string("array").then(
LPAREN.then(ty.map(partial(dt.Array, nullable=False))).skip(RPAREN)
)
dialect = "clickhouse"
default_decimal_precision = None
default_decimal_scale = None
default_nullable = False

map = (
spaceless_string("map")
.then(LPAREN)
.then(parsy.seq(ty, COMMA.then(ty)).combine(partial(dt.Map, nullable=False)))
.skip(RPAREN)
short_circuit: Mapping[str, dt.DataType] = FrozenDict(
{
"IPv4": dt.INET(nullable=default_nullable),
"IPv6": dt.INET(nullable=default_nullable),
"Object('json')": dt.JSON(nullable=default_nullable),
"Array(Null)": dt.Array(dt.null, nullable=default_nullable),
"Array(Nothing)": dt.Array(dt.null, nullable=default_nullable),
}
)

at_least_one_space = parsy.regex(r"\s+")

nested = (
spaceless_string("nested")
.then(LPAREN)
.then(
parsy.seq(SPACES.then(FIELD.skip(at_least_one_space)), ty)
.combine(lambda field, ty: (field, dt.Array(ty, nullable=False)))
.sep_by(COMMA)
.map(partial(dt.Struct.from_tuples, nullable=False))
)
.skip(RPAREN)
)

struct = (
spaceless_string("tuple")
.then(LPAREN)
.then(
parsy.seq(
SPACES.then(FIELD.skip(at_least_one_space).optional()),
ty,
@classmethod
def _get_DATETIME(
cls, first: DataTypeSize | None = None, second: DataTypeSize | None = None
) -> dt.Timestamp:
if first is not None and second is not None:
scale = first
timezone = second
elif first is not None and second is None:
timezone, scale = (
(first, second) if first.this.is_string else (second, first)
)
.sep_by(COMMA)
.map(
lambda field_names_types: dt.Struct.from_tuples(
[
(field_name if field_name is not None else f"f{i:d}", typ)
for i, (field_name, typ) in enumerate(field_names_types)
],
nullable=False,
else:
scale = first
timezone = second
return cls._get_TIMESTAMP(scale=scale, timezone=timezone)

@classmethod
def _get_DATETIME64(
cls, scale: DataTypeSize | None = None, timezone: DataTypeSize | None = None
) -> dt.Timestamp:
return cls._get_TIMESTAMP(scale=scale, timezone=timezone)

@classmethod
def _get_NULLABLE(cls, inner_type: DataType) -> dt.DataType:
return cls._get_type(inner_type).copy(nullable=True)

@classmethod
def _get_LOWCARDINALITY(cls, inner_type: DataType) -> dt.DataType:
return cls._get_type(inner_type)

@classmethod
def _get_NESTED(cls, *fields: DataType) -> dt.Struct:
return dt.Struct(
{
field.name: dt.Array(
cls._get_type(field.args["kind"]), nullable=cls.default_nullable
)
)
for field in fields
},
nullable=cls.default_nullable,
)
.skip(RPAREN)
)

enum_value = SPACES.then(RAW_STRING).skip(spaceless_string("=")).then(RAW_NUMBER)
@classmethod
def _get_STRUCT(cls, *fields: Expression) -> dt.Struct:
types = {}

lowcardinality = (
spaceless_string("lowcardinality").then(LPAREN).then(ty).skip(RPAREN)
)
for i, field in enumerate(fields):
if isinstance(field, ColumnDef):
inner_type = field.args["kind"]
name = field.name
else:
inner_type = sg.parse_one(str(field), into=DataType, read="clickhouse")
name = f"f{i:d}"

enum = (
spaceless_string("enum")
.then(RAW_NUMBER)
.then(LPAREN)
.then(enum_value.sep_by(COMMA))
.skip(RPAREN)
.result(dt.String(nullable=False))
)
types[name] = cls._get_type(inner_type)
return dt.Struct(types, nullable=cls.default_nullable)

ty.become(
nullable
| nested
| primitive
| fixed_string
| decimal
| array
| map
| struct
| enum
| lowcardinality
| spaceless_string("IPv4", "IPv6").result(dt.INET(nullable=False))
| spaceless_string("Object('json')", "JSON").result(dt.JSON(nullable=False))
)
return ty.parse(text)

parse = ClickHouseTypeParser.parse


@functools.singledispatch
Expand Down
30 changes: 10 additions & 20 deletions ibis/backends/druid/datatypes.py
@@ -1,18 +1,16 @@
from __future__ import annotations

import parsy
from typing import Mapping

import sqlalchemy as sa
import sqlalchemy.types as sat
from dateutil.parser import parse as timestamp_parse
from sqlalchemy.ext.compiler import compiles

import ibis.expr.datatypes as dt
from ibis.backends.base.sql.alchemy.datatypes import AlchemyType
from ibis.common.parsing import (
LANGLE,
RANGLE,
spaceless_string,
)
from ibis.common.collections import FrozenDict
from ibis.formats.parser import TypeParser


class DruidDateTime(sat.TypeDecorator):
Expand Down Expand Up @@ -59,23 +57,15 @@ def _smallint(element, compiler, **kw):
return "SMALLINT"


def parse(text: str) -> dt.DataType:
"""Parse a Druid type into an ibis data type."""
primitive = (
spaceless_string("string").result(dt.string)
| spaceless_string("double").result(dt.float64)
| spaceless_string("float").result(dt.float32)
| spaceless_string("long").result(dt.int64)
| spaceless_string("json").result(dt.json)
)
class DruidTypeParser(TypeParser):
__slots__ = ()

ty = parsy.forward_declaration()
# druid doesn't have a sophisticated type system and hive is close enough
dialect = "hive"
short_circuit: Mapping[str, dt.DataType] = FrozenDict({"complex<json>": dt.json})

json = spaceless_string("complex").then(LANGLE).then(ty).skip(RANGLE)
array = spaceless_string("array").then(LANGLE).then(ty.map(dt.Array)).skip(RANGLE)

ty.become(primitive | array | json)
return ty.parse(text)
parse = DruidTypeParser.parse


class DruidType(AlchemyType):
Expand Down

0 comments on commit fe7ba24

Please sign in to comment.