@@ -1,21 +1,5 @@
from __future__ import annotations

base_identifiers = [
    "add",
    "aggregate",
@@ -0,0 +1,2 @@
SELECT
  `bqutil`.`fn`.from_hex('face') AS `from_hex_'face'`
@@ -0,0 +1,2 @@
SELECT
  farm_fingerprint(b'Hello, World!') AS `farm_fingerprint_b'Hello_ World_'`
@@ -0,0 +1,31 @@
from __future__ import annotations

import ibis

to_sql = ibis.bigquery.compile


@ibis.udf.scalar.builtin
def farm_fingerprint(value: bytes) -> int:
    ...


@ibis.udf.scalar.builtin(schema="bqutil.fn")
def from_hex(value: str) -> int:
    """Community function to convert from hex string to integer.
    See:
    https://github.com/GoogleCloudPlatform/bigquery-utils/tree/master/udfs/community#from_hexvalue-string
    """


def test_bqutil_fn_from_hex(snapshot):
    # Project ID should be enclosed in backticks.
    expr = from_hex("face")
    snapshot.assert_match(to_sql(expr), "out.sql")


def test_farm_fingerprint(snapshot):
    # No backticks needed if there is no schema defined.
    expr = farm_fingerprint(b"Hello, World!")
    snapshot.assert_match(to_sql(expr), "out.sql")
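The same pattern extends to other BigQuery community UDFs. A minimal sketch, assuming a hypothetical `to_hex` wrapper whose signature is illustrative and not part of this change:

# Illustrative only: wrap another bqutil community function the same way;
# the exact name and signature here are assumptions, not part of this diff.
@ibis.udf.scalar.builtin(schema="bqutil.fn")
def to_hex(value: int) -> str:
    """Community function to convert an integer to a hex string."""


expr = to_hex(64206)
to_sql(expr)  # expected to qualify the call as `bqutil`.`fn`.to_hex(...)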
@@ -1,2 +1,2 @@
SELECT
  now() AS "TimestampNow()"
@@ -1,2 +1,2 @@
SELECT
  toDate(makeDateTime(2009, 5, 17, 12, 34, 56)) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56))"
@@ -1,2 +1,2 @@
SELECT
  toStartOfHour(makeDateTime(2009, 5, 17, 12, 34, 56)) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56))"
@@ -1,2 +1,2 @@
SELECT
  toStartOfMinute(makeDateTime(2009, 5, 17, 12, 34, 56)) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56))"
@@ -1,2 +1,2 @@
SELECT
  toStartOfMinute(makeDateTime(2009, 5, 17, 12, 34, 56)) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56))"
@@ -1,2 +1,2 @@
SELECT
  toMonday(makeDateTime(2009, 5, 17, 12, 34, 56)) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56))"
@@ -1,2 +1,2 @@
SELECT
  toStartOfYear(makeDateTime(2009, 5, 17, 12, 34, 56)) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56))"
@@ -1,2 +1,2 @@
SELECT
  makeDateTime64(2015, 1, 1, 12, 34, 56, 789321, 6) AS "datetime.datetime(2015, 1, 1, 12, 34, 56, 789321)"
@@ -1,2 +1,2 @@
SELECT
  makeDateTime64(2015, 1, 1, 12, 34, 56, 789321, 6, 'UTC') AS "datetime.datetime(2015, 1, 1, 12, 34, 56, 789321, tzinfo=tzutc())"
@@ -1,2 +1,2 @@
SELECT
  makeDateTime64(2015, 1, 1, 12, 34, 56, 789, 3) AS "datetime.datetime(2015, 1, 1, 12, 34, 56, 789000)"
@@ -1,2 +1,2 @@
SELECT
  makeDateTime64(2015, 1, 1, 12, 34, 56, 789, 3, 'UTC') AS "datetime.datetime(2015, 1, 1, 12, 34, 56, 789000, tzinfo=tzutc())"
@@ -1,2 +1,2 @@
SELECT
  FALSE AS False
@@ -1,2 +1,2 @@
SELECT
  1.5 AS "1.5"
@@ -1,2 +1,2 @@
SELECT
  5 AS "5"
@@ -1,2 +1,2 @@
SELECT
  'I can''t' AS """I can't"""
@@ -1,2 +1,2 @@
SELECT
  'An "escape"' AS "'An ""escape""'"
@@ -1,2 +1,2 @@
SELECT
  'simple' AS "'simple'"
@@ -1,2 +1,2 @@
SELECT
  TRUE AS True
@@ -1,2 +1,2 @@
SELECT
  makeDateTime(2015, 1, 1, 12, 34, 56) AS "datetime.datetime(2015, 1, 1, 12, 34, 56)"
@@ -1,2 +1,2 @@
SELECT
  makeDateTime(2015, 1, 1, 12, 34, 56) AS "datetime.datetime(2015, 1, 1, 12, 34, 56)"
@@ -1,2 +1,2 @@
SELECT
  makeDateTime(2015, 1, 1, 12, 34, 56) AS "datetime.datetime(2015, 1, 1, 12, 34, 56)"
@@ -1,2 +1,2 @@
SELECT
  1 + 2 AS "Add(1, 2)"
@@ -1,2 +1,2 @@
SELECT
  now() AS "TimestampNow()"
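Snapshots like these can be reproduced directly from expressions. A minimal sketch, assuming only the public `ibis.to_sql` API:

import ibis

# Compile a simple expression to ClickHouse SQL; per the snapshot above this
# should render as: SELECT now() AS "TimestampNow()"
print(ibis.to_sql(ibis.now(), dialect="clickhouse"))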
@@ -0,0 +1,3 @@
SELECT
  ST_DWITHIN(t0.geom, t0.geom, CAST(3.0 AS REAL(53))) AS tmp
FROM t AS t0
@@ -0,0 +1,3 @@
SELECT
  ST_ASTEXT(t0.geom) AS tmp
FROM t AS t0
@@ -0,0 +1,3 @@
SELECT
  ST_NPOINTS(t0.geom) AS tmp
FROM t AS t0
@@ -0,0 +1,212 @@
from __future__ import annotations

import numpy.testing as npt
import pandas.testing as tm
import pyarrow as pa
import pytest
from pytest import param

import ibis

gpd = pytest.importorskip("geopandas")
gtm = pytest.importorskip("geopandas.testing")
shapely = pytest.importorskip("shapely")


def test_geospatial_point(zones, zones_gdf):
    coord = zones.x_cent.point(zones.y_cent).name("coord")
    # this returns a GeometryArray
    gp_coord = gpd.points_from_xy(zones_gdf.x_cent, zones_gdf.y_cent)

    npt.assert_array_equal(coord.to_pandas().values, gp_coord)


# these functions are not implemented in geopandas
@pytest.mark.parametrize(
    ("operation", "keywords"),
    [
        param("as_text", {}, id="as_text"),
        param("n_points", {}, id="n_points"),
    ],
)
def test_geospatial_unary_snapshot(operation, keywords, snapshot):
    t = ibis.table([("geom", "geometry")], name="t")
    expr = getattr(t.geom, operation)(**keywords).name("tmp")
    snapshot.assert_match(ibis.to_sql(expr), "out.sql")


def test_geospatial_dwithin(snapshot):
    t = ibis.table([("geom", "geometry")], name="t")
    expr = t.geom.d_within(t.geom, 3.0).name("tmp")

    snapshot.assert_match(ibis.to_sql(expr), "out.sql")


# geospatial unary functions that return a non-geometry series
# we can test using pd.testing (tm)
@pytest.mark.parametrize(
    ("op", "keywords", "gp_op"),
    [
        param("area", {}, "area", id="area"),
        param("is_valid", {}, "is_valid", id="is_valid"),
        param(
            "geometry_type",
            {},
            "geom_type",
            id="geometry_type",
            marks=pytest.mark.xfail(raises=pa.lib.ArrowTypeError),
        ),
    ],
)
def test_geospatial_unary_tm(op, keywords, gp_op, zones, zones_gdf):
    expr = getattr(zones.geom, op)(**keywords).name("tmp")
    gp_expr = getattr(zones_gdf.geometry, gp_op)

    tm.assert_series_equal(expr.to_pandas(), gp_expr, check_names=False)


@pytest.mark.parametrize(
    ("op", "keywords", "gp_op"),
    [
        param("x", {}, "x", id="x_coord"),
        param("y", {}, "y", id="y_coord"),
    ],
)
def test_geospatial_xy(op, keywords, gp_op, zones, zones_gdf):
    cen = zones.geom.centroid().name("centroid")
    gp_cen = zones_gdf.geometry.centroid

    expr = getattr(cen, op)(**keywords).name("tmp")
    gp_expr = getattr(gp_cen, gp_op)

    tm.assert_series_equal(expr.to_pandas(), gp_expr, check_names=False)


def test_geospatial_length(lines, lines_gdf):
    # note: ST_LENGTH returns 0 for polygons and multipolygons,
    # while geopandas returns the perimeter.
    length = lines.geom.length().name("length")
    gp_length = lines_gdf.geometry.length

    tm.assert_series_equal(length.to_pandas(), gp_length, check_names=False)


# geospatial binary functions that return a non-geometry series
# we can test using pd.testing (tm)
@pytest.mark.parametrize(
    ("op", "gp_op"),
    [
        param("contains", "contains", id="contains"),
        param("geo_equals", "geom_equals", id="geo_eqs"),
        param("covers", "covers", id="covers"),
        param("covered_by", "covered_by", id="covered_by"),
        param("crosses", "crosses", id="crosses"),
        param("disjoint", "disjoint", id="disjoint"),
        param("distance", "distance", id="distance"),
        param("intersects", "intersects", id="intersects"),
        param("overlaps", "overlaps", id="overlaps"),
        param("touches", "touches", id="touches"),
        param("within", "within", id="within"),
    ],
)
def test_geospatial_binary_tm(op, gp_op, zones, zones_gdf):
    expr = getattr(zones.geom, op)(zones.geom).name("tmp")
    gp_func = getattr(zones_gdf.geometry, gp_op)(zones_gdf.geometry)

    tm.assert_series_equal(expr.to_pandas(), gp_func, check_names=False)


# geospatial unary functions that return a geometry series
# we can test using gpd.testing (gtm)
@pytest.mark.parametrize(
    ("op", "gp_op"),
    [
        param("centroid", "centroid", id="centroid"),
        param("envelope", "envelope", id="envelope"),
    ],
)
def test_geospatial_unary_gtm(op, gp_op, zones, zones_gdf):
    expr = getattr(zones.geom, op)().name("tmp")
    gp_expr = getattr(zones_gdf.geometry, gp_op)

    gtm.assert_geoseries_equal(expr.to_pandas(), gp_expr, check_crs=False)


# geospatial binary functions that return a geometry series
# we can test using gpd.testing (gtm)
@pytest.mark.parametrize(
    ("op", "gp_op"),
    [
        param("difference", "difference", id="difference"),
        param("intersection", "intersection", id="intersection"),
        param("union", "union", id="union"),
    ],
)
def test_geospatial_binary_gtm(op, gp_op, zones, zones_gdf):
    expr = getattr(zones.geom, op)(zones.geom).name("tmp")
    gp_func = getattr(zones_gdf.geometry, gp_op)(zones_gdf.geometry)

    gtm.assert_geoseries_equal(expr.to_pandas(), gp_func, check_crs=False)


def test_geospatial_end_point(lines, lines_gdf):
    epoint = lines.geom.end_point().name("end_point")
    # geopandas does not have end_point; this is a workaround to get it
    gp_epoint = lines_gdf.geometry.boundary.explode(index_parts=True).xs(1, level=1)

    gtm.assert_geoseries_equal(epoint.to_pandas(), gp_epoint, check_crs=False)


def test_geospatial_start_point(lines, lines_gdf):
    spoint = lines.geom.start_point().name("start_point")
    # geopandas does not have start_point; this is a workaround to get it
    gp_spoint = lines_gdf.geometry.boundary.explode(index_parts=True).xs(0, level=1)

    gtm.assert_geoseries_equal(spoint.to_pandas(), gp_spoint, check_crs=False)


# this one takes a bit longer than the rest.
def test_geospatial_unary_union(zones, zones_gdf):
    unary_union = zones.geom.unary_union().name("unary_union")
    # this returns a shapely geometry object
    gp_unary_union = zones_gdf.geometry.unary_union

    # using set_precision because https://github.com/duckdb/duckdb_spatial/issues/189
    assert shapely.equals(
        shapely.set_precision(unary_union.to_pandas(), grid_size=1e-7),
        shapely.set_precision(gp_unary_union, grid_size=1e-7),
    )


def test_geospatial_buffer_point(zones, zones_gdf):
    cen = zones.geom.centroid().name("centroid")
    gp_cen = zones_gdf.geometry.centroid

    buffer = cen.buffer(100.0)
    # geopandas' default resolution is 16, while duckdb's is 8.
    gp_buffer = gp_cen.buffer(100.0, resolution=8)

    gtm.assert_geoseries_equal(buffer.to_pandas(), gp_buffer, check_crs=False)


def test_geospatial_buffer(zones, zones_gdf):
    buffer = zones.geom.buffer(100.0)
    # geopandas' default resolution is 16, while duckdb's is 8.
    gp_buffer = zones_gdf.geometry.buffer(100.0, resolution=8)

    gtm.assert_geoseries_equal(buffer.to_pandas(), gp_buffer, check_crs=False)


# using a smaller dataset for time purposes
def test_geospatial_convert(geotable, gdf):
    # geotable is fabricated but let's say the
    # data is in CRS: EPSG:2263
    # let's transform to EPSG:4326 (latitude-longitude projection)
    geo_ll = geotable.geom.convert("EPSG:2263", "EPSG:4326")

    gdf.crs = "EPSG:2263"
    gdf_ll = gdf.geometry.to_crs(crs=4326)

    gtm.assert_geoseries_equal(
        geo_ll.to_pandas(), gdf_ll, check_less_precise=True, check_crs=False
    )
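The `zones`/`zones_gdf` and `lines`/`lines_gdf` fixtures are not part of this diff (they presumably come from a conftest). A rough sketch of what they could look like, assuming the DuckDB backend with its spatial extension and hypothetical GeoJSON file names:

import geopandas as gpd

import ibis

con = ibis.duckdb.connect()
con.raw_sql("INSTALL spatial")
con.raw_sql("LOAD spatial")

# Hypothetical files: `zones.geojson` would need a geometry column plus
# x_cent/y_cent centroid columns; `lines.geojson` a linestring geometry.
zones = con.sql("SELECT * FROM st_read('zones.geojson')")
zones_gdf = gpd.read_file("zones.geojson")
lines = con.sql("SELECT * FROM st_read('lines.geojson')")
lines_gdf = gpd.read_file("lines.geojson")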
@@ -0,0 +1,234 @@
from __future__ import annotations

import re
import warnings
from collections import ChainMap
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any

import sqlalchemy as sa
import sqlglot as sg

from ibis import util
from ibis.backends.base.sql.alchemy import AlchemyCanCreateSchema, BaseAlchemyBackend
from ibis.backends.base.sqlglot.datatypes import PostgresType
from ibis.backends.exasol.compiler import ExasolCompiler

if TYPE_CHECKING:
    from collections.abc import Iterable, MutableMapping

    from ibis.backends.base import BaseBackend
    from ibis.expr import datatypes as dt


class Backend(BaseAlchemyBackend, AlchemyCanCreateSchema):
    name = "exasol"
    compiler = ExasolCompiler
    supports_temporary_tables = False
    supports_create_or_replace = False
    supports_in_memory_tables = False
    supports_python_udfs = False

    def do_connect(
        self,
        user: str,
        password: str,
        host: str = "localhost",
        port: int = 8563,
        schema: str | None = None,
        encryption: bool = True,
        certificate_validation: bool = True,
        encoding: str = "en_US.UTF-8",
    ) -> None:
        """Create an Ibis client connected to an Exasol database.

        Parameters
        ----------
        user
            Username used for authentication.
        password
            Password used for authentication.
        host
            Hostname to connect to (default: "localhost").
        port
            Port number to connect to (default: 8563).
        schema
            Database schema to open. If `None`, no schema will be opened.
        encryption
            Enables/disables transport layer encryption (default: True).
        certificate_validation
            Enables/disables certificate validation (default: True).
        encoding
            The encoding format (default: "en_US.UTF-8").
        """
        options = [
            "SSLCertificate=SSL_VERIFY_NONE" if not certificate_validation else "",
            f"ENCRYPTION={'yes' if encryption else 'no'}",
            f"CONNECTIONCALL={encoding}",
        ]
        url_template = (
            "exa+websocket://{user}:{password}@{host}:{port}/{schema}?{options}"
        )
        url = sa.engine.url.make_url(
            url_template.format(
                user=user,
                password=password,
                host=host,
                port=port,
                schema=schema,
                options="&".join(options),
            )
        )
        engine = sa.create_engine(url, poolclass=sa.pool.StaticPool)
        super().do_connect(engine)

    def _convert_kwargs(self, kwargs: MutableMapping) -> None:
        def convert_sqla_to_ibis(keyword_arguments):
            sqla_to_ibis = {"tls": "encryption", "username": "user"}
            for sqla_kwarg, ibis_kwarg in sqla_to_ibis.items():
                if sqla_kwarg in keyword_arguments:
                    keyword_arguments[ibis_kwarg] = keyword_arguments.pop(sqla_kwarg)

        def filter_kwargs(keyword_arguments):
            allowed_parameters = [
                "user",
                "password",
                "host",
                "port",
                "schema",
                "encryption",
                "certificate",
                "encoding",
            ]
            to_be_removed = [
                key for key in keyword_arguments if key not in allowed_parameters
            ]
            for parameter_name in to_be_removed:
                del keyword_arguments[parameter_name]

        convert_sqla_to_ibis(kwargs)
        filter_kwargs(kwargs)

    def _from_url(self, url: str, **kwargs) -> BaseBackend:
        """Construct an ibis backend from a SQLAlchemy-conforming URL."""
        kwargs = ChainMap(kwargs)
        _, new_kwargs = self.inspector.dialect.create_connect_args(url)
        kwargs = kwargs.new_child(new_kwargs)
        kwargs = dict(kwargs)
        self._convert_kwargs(kwargs)

        return self.connect(**kwargs)

    @property
    def inspector(self):
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=sa.exc.RemovedIn20Warning)
            return super().inspector

    @contextmanager
    def begin(self):
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=sa.exc.RemovedIn20Warning)
            with super().begin() as con:
                yield con

    def list_tables(self, like=None, database=None):
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=sa.exc.RemovedIn20Warning)
            return super().list_tables(like=like, database=database)

    def _get_sqla_table(
        self,
        name: str,
        autoload: bool = True,
        **kwargs: Any,
    ) -> sa.Table:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=sa.exc.RemovedIn20Warning)
            return super()._get_sqla_table(name=name, autoload=autoload, **kwargs)

    def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
        table = sg.table(util.gen_name("exasol_metadata"))
        create_view = sg.exp.Create(
            kind="VIEW", this=table, expression=sg.parse_one(query, dialect="postgres")
        )
        drop_view = sg.exp.Drop(kind="VIEW", this=table)
        describe = sg.exp.Describe(this=table).sql(dialect="postgres")
        # strip trailing encodings, e.g., UTF8
        varchar_regex = re.compile(r"^(VARCHAR(?:\(\d+\)))?(?:\s+.+)?$")
        with self.begin() as con:
            con.exec_driver_sql(create_view.sql(dialect="postgres"))
            try:
                yield from (
                    (
                        name,
                        PostgresType.from_string(varchar_regex.sub(r"\1", typ)),
                    )
                    for name, typ, *_ in con.exec_driver_sql(describe)
                )
            finally:
                con.exec_driver_sql(drop_view.sql(dialect="postgres"))

    @property
    def current_schema(self) -> str:
        return self._scalar_query(sa.select(sa.text("CURRENT_SCHEMA")))

    @property
    def current_database(self) -> str:
        return None

    def drop_schema(
        self, name: str, database: str | None = None, force: bool = False
    ) -> None:
        if database is not None:
            raise NotImplementedError(
                "`database` argument is not supported for the Exasol backend"
            )
        drop_schema = sg.exp.Drop(
            kind="SCHEMA", this=sg.to_identifier(name), exists=force
        )
        with self.begin() as con:
            con.exec_driver_sql(drop_schema.sql(dialect="postgres"))

    def create_schema(
        self, name: str, database: str | None = None, force: bool = False
    ) -> None:
        if database is not None:
            raise NotImplementedError(
                "`database` argument is not supported for the Exasol backend"
            )
        create_schema = sg.exp.Create(
            kind="SCHEMA", this=sg.to_identifier(name), exists=force
        )
        with self.begin() as con:
            open_schema = self.current_schema
            con.exec_driver_sql(create_schema.sql(dialect="postgres"))
            # Exasol implicitly opens the created schema, so we need to restore
            # the previous context.
            action = (
                sa.text(f"OPEN SCHEMA {open_schema}")
                if open_schema
                else sa.text(f"CLOSE SCHEMA {name}")
            )
            con.exec_driver_sql(action)

    def list_schemas(
        self, like: str | None = None, database: str | None = None
    ) -> list[str]:
        if database is not None:
            raise NotImplementedError(
                "`database` argument is not supported for the Exasol backend"
            )

        schema, table = "SYS", "EXA_SCHEMAS"
        sch = sa.table(
            table,
            sa.column("schema_name", sa.TEXT()),
            schema=schema,
        )

        query = sa.select(sch.c.schema_name)

        with self.begin() as con:
            schemas = list(con.execute(query).scalars())
        return self._filter_with_like(schemas, like=like)
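A brief usage sketch of this backend, assuming it is exposed as `ibis.exasol` via the `exasol` entry point suggested by the `name` attribute; connection details are placeholders:

import ibis

# Placeholder credentials for a locally running Exasol instance.
con = ibis.exasol.connect(
    user="sys",
    password="exasol",
    host="localhost",
    port=8563,
    schema="EXAMPLE",
)
con.list_schemas()   # queries SYS.EXA_SCHEMAS, per list_schemas above
con.current_schema   # returns the currently opened schema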
@@ -0,0 +1,24 @@
from __future__ import annotations

import sqlalchemy as sa

from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.exasol import registry
from ibis.backends.exasol.datatypes import ExasolSQLType


class ExasolExprTranslator(AlchemyExprTranslator):
    _registry = registry.create()
    _rewrites = AlchemyExprTranslator._rewrites.copy()
    _integer_to_timestamp = sa.func.from_unixtime
    _dialect_name = "exa.websocket"
    native_json_type = False
    type_mapper = ExasolSQLType


rewrites = ExasolExprTranslator.rewrites


class ExasolCompiler(AlchemyCompiler):
    translator_class = ExasolExprTranslator
    support_values_syntax_in_select = False
@@ -0,0 +1,26 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import sqlalchemy.types as sa_types

from ibis.backends.base.sql.alchemy.datatypes import AlchemyType

if TYPE_CHECKING:
    import ibis.expr.datatypes as dt


class ExasolSQLType(AlchemyType):
    dialect = "exa.websocket"

    @classmethod
    def from_ibis(cls, dtype: dt.DataType) -> sa_types.TypeEngine:
        if dtype.is_string():
            # see also: https://docs.exasol.com/db/latest/sql_references/data_types/datatypesoverview.htm
            MAX_VARCHAR_SIZE = 2_000_000
            return sa_types.VARCHAR(MAX_VARCHAR_SIZE)
        return super().from_ibis(dtype)

    @classmethod
    def to_ibis(cls, typ: sa_types.TypeEngine, nullable: bool = True) -> dt.DataType:
        return super().to_ibis(typ, nullable=nullable)
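A quick illustrative check of the string mapping above (not part of the diff):

import ibis.expr.datatypes as dt

# Strings map to the widest VARCHAR Exasol supports (see from_ibis above);
# this is expected to print VARCHAR(2000000).
print(ExasolSQLType.from_ibis(dt.string))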