20 changes: 20 additions & 0 deletions ci/schema/mssql.sql
@@ -13,6 +13,14 @@ CREATE TABLE diamonds (
z FLOAT
);


-- /data is a volume mount to the ibis testing data
-- used for snappy test data loading
-- DataFrame.to_sql is unusably slow for loading CSVs
BULK INSERT diamonds
FROM '/data/diamonds.csv'
WITH (FORMAT = 'CSV', FIELDTERMINATOR = ',', ROWTERMINATOR = '\n', FIRSTROW = 2)

DROP TABLE IF EXISTS batting;

CREATE TABLE batting (
@@ -40,6 +48,10 @@ CREATE TABLE batting (
"GIDP" BIGINT
);

BULK INSERT batting
FROM '/data/batting.csv'
WITH (FORMAT = 'CSV', FIELDTERMINATOR = ',', ROWTERMINATOR = '\n', FIRSTROW = 2)

DROP TABLE IF EXISTS awards_players;

CREATE TABLE awards_players (
@@ -51,6 +63,10 @@ CREATE TABLE awards_players (
notes VARCHAR(MAX)
);

BULK INSERT awards_players
FROM '/data/awards_players.csv'
WITH (FORMAT = 'CSV', FIELDTERMINATOR = ',', ROWTERMINATOR = '\n', FIRSTROW = 2)

DROP TABLE IF EXISTS functional_alltypes;

CREATE TABLE functional_alltypes (
@@ -71,6 +87,10 @@ CREATE TABLE functional_alltypes (
month INTEGER
);

BULK INSERT functional_alltypes
FROM '/data/functional_alltypes.csv'
WITH (FORMAT = 'CSV', FIELDTERMINATOR = ',', ROWTERMINATOR = '\n', FIRSTROW = 2)

CREATE INDEX "ix_functional_alltypes_index" ON functional_alltypes ("index");

DROP TABLE IF EXISTS win;
2 changes: 1 addition & 1 deletion ci/schema/mysql.sql
@@ -66,7 +66,7 @@ CREATE TABLE functional_alltypes (
double_col DOUBLE,
date_string_col TEXT,
string_col TEXT,
timestamp_col TIMESTAMP,
timestamp_col DATETIME,
year INTEGER,
month INTEGER
) DEFAULT CHARACTER SET = utf8;
10 changes: 7 additions & 3 deletions ci/schema/postgresql.sql
@@ -1,6 +1,4 @@
DROP SEQUENCE IF EXISTS test_sequence;
CREATE SEQUENCE IF NOT EXISTS test_sequence;

CREATE EXTENSION IF NOT EXISTS hstore;
CREATE EXTENSION IF NOT EXISTS postgis;
CREATE EXTENSION IF NOT EXISTS plpython3u;

@@ -204,3 +202,9 @@ INSERT INTO win VALUES
('a', 2, 0),
('a', 3, 1),
('a', 4, 1);

DROP TABLE IF EXISTS map CASCADE;
CREATE TABLE map (kv HSTORE);
INSERT INTO map VALUES
('a=>1,b=>2,c=>3'),
('d=>4,e=>5,c=>6');
31 changes: 23 additions & 8 deletions ci/schema/snowflake.sql
@@ -1,4 +1,12 @@
CREATE OR REPLACE TABLE diamonds (
CREATE OR REPLACE FILE FORMAT ibis_testing
type = 'CSV'
field_delimiter = ','
skip_header = 1
field_optionally_enclosed_by = '"';

CREATE OR REPLACE STAGE ibis_testing file_format = ibis_testing;

CREATE TEMP TABLE diamonds (
"carat" FLOAT,
"cut" TEXT,
"color" TEXT,
@@ -11,7 +19,7 @@ CREATE OR REPLACE TABLE diamonds (
"z" FLOAT
);

CREATE OR REPLACE TABLE batting (
CREATE TEMP TABLE batting (
"playerID" TEXT,
"yearID" BIGINT,
"stint" BIGINT,
@@ -36,7 +44,7 @@ CREATE OR REPLACE TABLE batting (
"GIDP" BIGINT
);

CREATE OR REPLACE TABLE awards_players (
CREATE TEMP TABLE awards_players (
"playerID" TEXT,
"awardID" TEXT,
"yearID" BIGINT,
@@ -45,7 +53,7 @@ CREATE OR REPLACE TABLE awards_players (
"notes" TEXT
);

CREATE OR REPLACE TABLE functional_alltypes (
CREATE TEMP TABLE functional_alltypes (
"index" BIGINT,
"Unnamed: 0" BIGINT,
"id" INTEGER,
@@ -63,7 +71,7 @@ CREATE OR REPLACE TABLE functional_alltypes (
"month" INTEGER
);

CREATE OR REPLACE TABLE array_types (
CREATE TEMP TABLE array_types (
"x" ARRAY,
"y" ARRAY,
"z" ARRAY,
@@ -80,7 +88,14 @@ INSERT INTO array_types ("x", "y", "z", "grouper", "scalar_column", "multi_dim")
SELECT [2, NULL, 3], ['b', NULL, 'c'], NULL, 'b', 5.0, NULL UNION
SELECT [4, NULL, NULL, 5], ['d', NULL, NULL, 'e'], [4.0, NULL, NULL, 5.0], 'c', 6.0, [[1, 2, 3]];

CREATE OR REPLACE TABLE struct ("abc" OBJECT);
CREATE TEMP TABLE map ("kv" OBJECT);

INSERT INTO map ("kv")
SELECT object_construct('a', 1, 'b', 2, 'c', 3) UNION
SELECT object_construct('d', 4, 'e', 5, 'c', 6);


CREATE TEMP TABLE struct ("abc" OBJECT);

INSERT INTO struct ("abc")
SELECT {'a': 1.0, 'b': 'banana', 'c': 2} UNION
@@ -91,7 +106,7 @@ INSERT INTO struct ("abc")
SELECT NULL UNION
SELECT {'a': 3.0, 'b': 'orange', 'c': NULL};

CREATE OR REPLACE TABLE json_t ("js" VARIANT);
CREATE TEMP TABLE json_t ("js" VARIANT);

INSERT INTO json_t ("js")
SELECT parse_json('{"a": [1,2,3,4], "b": 1}') UNION
@@ -101,7 +116,7 @@ INSERT INTO json_t ("js")
SELECT parse_json('[42,47,55]') UNION
SELECT parse_json('[]');

CREATE OR REPLACE TABLE win ("g" TEXT, "x" BIGINT, "y" BIGINT);
CREATE TEMP TABLE win ("g" TEXT, "x" BIGINT, "y" BIGINT);
INSERT INTO win VALUES
('a', 0, 3),
('a', 1, 2),
12 changes: 12 additions & 0 deletions ci/schema/trino.sql
@@ -0,0 +1,12 @@
DROP TABLE IF EXISTS map;
CREATE TABLE map (kv MAP<VARCHAR, BIGINT>);
INSERT INTO map VALUES
(MAP(ARRAY['a', 'b', 'c'], ARRAY[1, 2, 3])),
(MAP(ARRAY['d', 'e', 'f'], ARRAY[4, 5, 6]));

DROP TABLE IF EXISTS ts;
CREATE TABLE ts (x TIMESTAMP(3), y TIMESTAMP(6), z TIMESTAMP(9));
INSERT INTO ts VALUES
(TIMESTAMP '2023-01-07 13:20:05.561',
TIMESTAMP '2023-01-07 13:20:05.561021',
TIMESTAMP '2023-01-07 13:20:05.561000231');
261 changes: 160 additions & 101 deletions conda-lock/linux-64-3.10.lock

Large diffs are not rendered by default.

260 changes: 159 additions & 101 deletions conda-lock/linux-64-3.8.lock

Large diffs are not rendered by default.

261 changes: 160 additions & 101 deletions conda-lock/linux-64-3.9.lock

Large diffs are not rendered by default.

276 changes: 158 additions & 118 deletions conda-lock/osx-64-3.10.lock

Large diffs are not rendered by default.

263 changes: 160 additions & 103 deletions conda-lock/osx-64-3.8.lock

Large diffs are not rendered by default.

273 changes: 156 additions & 117 deletions conda-lock/osx-64-3.9.lock

Large diffs are not rendered by default.

255 changes: 157 additions & 98 deletions conda-lock/osx-arm64-3.10.lock

Large diffs are not rendered by default.

277 changes: 159 additions & 118 deletions conda-lock/osx-arm64-3.8.lock

Large diffs are not rendered by default.

255 changes: 157 additions & 98 deletions conda-lock/osx-arm64-3.9.lock

Large diffs are not rendered by default.

255 changes: 157 additions & 98 deletions conda-lock/win-64-3.10.lock

Large diffs are not rendered by default.

255 changes: 157 additions & 98 deletions conda-lock/win-64-3.8.lock

Large diffs are not rendered by default.

256 changes: 158 additions & 98 deletions conda-lock/win-64-3.9.lock

Large diffs are not rendered by default.

9 changes: 7 additions & 2 deletions docker-compose.yml
@@ -1,7 +1,10 @@
version: "3.4"
services:
clickhouse:
image: clickhouse/clickhouse-server:22.12.2.25-alpine
build:
context: .
dockerfile: ./docker/clickhouse/Dockerfile
image: ibis-clickhouse
ports:
- 8123:8123
- 9000:9000
@@ -122,7 +125,7 @@ services:
retries: 3
test:
- CMD-SHELL
- /opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P "$$MSSQL_SA_PASSWORD" -Q "SELECT 1 AS one"
- /opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P "$$MSSQL_SA_PASSWORD" -Q "IF DB_ID('ibis_testing') IS NULL BEGIN CREATE DATABASE [ibis_testing] END"
timeout: 10s
build:
context: .
@@ -136,6 +139,8 @@ services:
user: postgres
environment:
POSTGRES_PASSWORD: postgres
POSTGRES_DB: ibis_testing
POSTGRES_USER: postgres
healthcheck:
interval: 10s
retries: 3
2 changes: 2 additions & 0 deletions docker/clickhouse/Dockerfile
@@ -0,0 +1,2 @@
FROM clickhouse/clickhouse-server:22.12.3.5-alpine
COPY ./ci/ibis-testing-data /var/lib/clickhouse/user_files
9 changes: 4 additions & 5 deletions docs/CONTRIBUTING.md
@@ -4,8 +4,7 @@ We love new contributors!

To get started:

1. [Set up a development environment](https://ibis-project.org/docs/latest/contribute/01_environment/)
1. [Learn about the commit workflow](https://ibis-project.org/docs/latest/contribute/02_workflow/)
1. [Review the code style guidelines](https://ibis-project.org/docs/latest/contribute/03_style/)
1. [Learn how to run the backend test suite](https://ibis-project.org/docs/latest/contribute/04_backend_tests/)
1. [Dig into the nitty gritty of being a maintainer](https://ibis-project.org/docs/latest/contribute/05_maintainers_guide/)
1. [Set up a development environment](https://ibis-project.org/docs/latest/community/contribute/01_environment/)
1. [Learn about the commit workflow](https://ibis-project.org/docs/latest/community/contribute/02_workflow/)
1. [Review the code style guidelines](https://ibis-project.org/docs/latest/community/contribute/03_style/)
1. [Dig into the nitty gritty of being a maintainer](https://ibis-project.org/docs/latest/community/contribute/05_maintainers_guide/)
1 change: 1 addition & 0 deletions docs/SUMMARY.md
@@ -23,6 +23,7 @@
* [Timestamps + Dates + Times](api/expressions/timestamps.md)
* [Collections](api/expressions/collections.md)
* [Geospatial](api/expressions/geospatial.md)
* [Column Selectors](api/selectors.md)
* [Data Types](api/datatypes.md)
* [Schemas](api/schemas.md)
* [Backend Interfaces](api/backends/)
2 changes: 2 additions & 0 deletions docs/api/expressions/top_level.md
@@ -15,6 +15,7 @@ These methods and objects are available directly in the `ibis` module.
::: ibis.date
::: ibis.desc
::: ibis.difference
::: ibis.get_backend
::: ibis.greatest
::: ibis.ifelse
::: ibis.intersect
@@ -35,6 +36,7 @@ These methods and objects are available directly in the `ibis` module.
::: ibis.read_parquet
::: ibis.row_number
::: ibis.schema
::: ibis.set_backend
::: ibis.struct
::: ibis.table
::: ibis.time
5 changes: 5 additions & 0 deletions docs/api/selectors.md
@@ -0,0 +1,5 @@
# Column Selectors

<!-- prettier-ignore-start -->
::: ibis.expr.selectors
<!-- prettier-ignore-end -->
8 changes: 8 additions & 0 deletions docs/backends/BigQuery.md
@@ -0,0 +1,8 @@
---
backend_name: Google BigQuery
backend_url: https://cloud.google.com/bigquery
backend_module: bigquery
backend_param_style: connection parameters
---

{% include 'backends/template.md' %}
203 changes: 203 additions & 0 deletions docs/backends/app/backend_info_app.py
@@ -0,0 +1,203 @@
import datetime
import tempfile
from pathlib import Path
from typing import List, Optional

import pandas as pd
import requests
import sqlglot
import streamlit as st

import ibis
from ibis import _

ONE_HOUR_IN_SECONDS = datetime.timedelta(hours=1).total_seconds()

st.set_page_config(layout='wide')

# Track all queries. We display them at the bottom of the page.
ibis.options.verbose = True
sql_queries = []
ibis.options.verbose_log = lambda sql: sql_queries.append(sql)


@st.experimental_memo(ttl=ONE_HOUR_IN_SECONDS)
def support_matrix_df():
resp = requests.get(
"https://ibis-project.org/docs/dev/backends/raw_support_matrix.csv"
)
resp.raise_for_status()

with tempfile.NamedTemporaryFile() as f:
f.write(resp.content)
return (
ibis.read_csv(f.name)
.relabel({'FullOperation': 'full_operation'})
.mutate(
short_operation=_.full_operation.split(".")[-1],
operation_category=_.full_operation.split(".")[-2],
)
.execute()
)


@st.experimental_memo(ttl=ONE_HOUR_IN_SECONDS)
def backends_info_df():
return pd.DataFrame(
{
"bigquery": ["string", "sql"],
"clickhouse": ["string", "sql"],
'dask': ["dataframe"],
"datafusion": ["dataframe"],
"duckdb": ["sqlalchemy", "sql"],
"impala": ["string", "sql"],
"mssql": ["sqlalchemy", "sql"],
"mysql": ["sqlalchemy", "sql"],
"pandas": ["dataframe"],
"polars": ["dataframe"],
"postgres": ["sqlalchemy", "sql"],
"pyspark": ["dataframe"],
"snowflake": ["sqlalchemy", "sql"],
"sqlite": ["sqlalchemy", "sql"],
"trino": ["sqlalchemy", "sql"],
}.items(),
columns=['backend_name', 'categories'],
)


backend_info_table = ibis.memtable(backends_info_df())
support_matrix_table = ibis.memtable(support_matrix_df())


@st.experimental_memo(ttl=ONE_HOUR_IN_SECONDS)
def get_all_backend_categories():
return (
backend_info_table.select(category=_.categories.unnest())
.distinct()
.order_by('category')['category']
.execute()
.tolist()
)


@st.experimental_memo(ttl=ONE_HOUR_IN_SECONDS)
def get_all_operation_categories():
return (
support_matrix_table.select(_.operation_category)
.distinct()['operation_category']
.execute()
.tolist()
)


@st.experimental_memo(ttl=ONE_HOUR_IN_SECONDS)
def get_backend_names(categories: Optional[List[str]] = None):
backend_expr = backend_info_table.mutate(category=_.categories.unnest())
if categories:
backend_expr = backend_expr.filter(_.category.isin(categories))
return (
backend_expr.select(_.backend_name).distinct().backend_name.execute().tolist()
)


def get_selected_backend_name():
backend_categories = get_all_backend_categories()
selected_categories_names = st.sidebar.multiselect(
'Backend category',
options=backend_categories,
default=None,
)
if not selected_categories_names:
return get_backend_names()
return get_backend_names(selected_categories_names)


def get_selected_operation_categories():
all_ops_categories = get_all_operation_categories()

selected_ops_categories = st.sidebar.multiselect(
'Operation category',
options=sorted(all_ops_categories),
default=None,
)
if not selected_ops_categories:
selected_ops_categories = all_ops_categories
show_geospatial = st.sidebar.checkbox('Include Geospatial ops', value=True)
if not show_geospatial and 'geospatial' in selected_ops_categories:
selected_ops_categories.remove("geospatial")
return selected_ops_categories


current_backend_names = get_selected_backend_name()
current_ops_categories = get_selected_operation_categories()

hide_supported_by_all_backends = st.sidebar.selectbox(
'Operation compatibility',
['Show all', 'Show supported by all backends', 'Hide supported by all backends'],
0,
)
show_full_ops_name = st.sidebar.checkbox('Show full operation name', False)

# Start ibis expression
table_expr = support_matrix_table

# Add index to result
if show_full_ops_name:
table_expr = table_expr.mutate(index=_.full_operation)
else:
table_expr = table_expr.mutate(index=_.short_operation)
table_expr = table_expr.order_by(_.index)

# Filter operations by selected categories
table_expr = table_expr.filter(_.operation_category.isin(current_ops_categories))

# Filter operation by compatibility
supported_backend_count = sum(
getattr(table_expr, backend_name).ifelse(1, 0)
for backend_name in current_backend_names
)
if hide_supported_by_all_backends == 'Show supported by all backends':
table_expr = table_expr.filter(
supported_backend_count == len(current_backend_names)
)
elif hide_supported_by_all_backends == 'Hide supported by all backends':
table_expr = table_expr.filter(
supported_backend_count != len(current_backend_names)
)

# Show only selected backend
table_expr = table_expr[current_backend_names + ["index"]]

# Execute query
df = table_expr.execute()
df = df.set_index('index')

# Display result
all_visible_ops_count = len(df.index)
if all_visible_ops_count:
# Compute coverage
coverage = (
df.sum()
.sort_values(ascending=False)
.map(lambda n: f"{n} ({round(100 * n / all_visible_ops_count)}%)")
.to_frame(name="API Coverage")
.T
)

table = pd.concat([coverage, df.replace({True: "✔", False: "🚫"})])
st.dataframe(table)
else:
st.write("No data")

with st.expander("SQL queries"):
for sql_query in sql_queries:
pretty_sql_query = sqlglot.transpile(
sql_query, read='duckdb', write='duckdb', pretty=True
)[0]
st.code(
pretty_sql_query,
language='sql',
)

with st.expander("Source code"):
st.code(Path(__file__).read_text())
5 changes: 5 additions & 0 deletions docs/backends/app/requirements.txt
@@ -0,0 +1,5 @@
ibis-framework[duckdb]>=4.0
pandas
requests
streamlit
sqlglot
2 changes: 1 addition & 1 deletion docs/backends/index.md
@@ -9,7 +9,7 @@ database through a driver API for execution.

- [Apache Impala](Impala.md)
- [ClickHouse](ClickHouse.md)
- [Google BigQuery](https://github.com/ibis-project/ibis-bigquery/)
- [Google BigQuery](BigQuery.md)
- [HeavyAI](https://github.com/heavyai/ibis-heavyai)

## Expression Generating Backends
29 changes: 21 additions & 8 deletions docs/backends/support_matrix.md
@@ -7,17 +7,30 @@ hide:

Backends are shown in descending order of the number of supported operations.

!!! tip "The Snowflake backend coverage is an overestimate"
!!! tip "Backends with low coverage are good places to start contributing!"

The Snowflake backend translation functions are reused from the PostgreSQL backend
and some operations that claim coverage may not work.
Each backend implements operations in its own way, but the implementations
are usually very similar to those in other backends. If you want to start
contributing to ibis, adding missing operations to a backend with low
operation coverage is a good place to begin.

The Snowflake backend is a good place to start contributing!
<div class="streamlit-app">
<iframe id="streamlit-app" src="https://ibis-project.streamlit.app/?embedded=true"></iframe>
</div>

## Core Operations
!!! note "This app is built using [`streamlit`](https://streamlit.io/)"

{{ read_csv("docs/backends/core_support_matrix.csv") }}
You can develop the app locally by editing `docs/backends/app/backend_info_app.py` and
opening a PR with your changes.

## Geospatial Operations
Test your changes locally by running

{{ read_csv("docs/backends/geospatial_support_matrix.csv") }}
```sh
$ streamlit run docs/backends/app/backend_info_app.py
```

The changes will show up in the dev docs when your PR is merged!

## Raw Data

You can also download data from the above tables in [CSV format](./raw_support_matrix.csv).
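
To spot-check a single backend without the app, something like the following
should work (a rough sketch; it assumes the `has_operation` helper used by
`gen_matrix.py` and the Streamlit app elsewhere in this changeset is exposed on
the backend entry points):

```python
import ibis
import ibis.expr.operations as ops

# True if the DuckDB backend knows how to compile this operation; swap in any
# other backend entry point (ibis.mssql, ibis.trino, ...) or operation class.
print(ibis.duckdb.has_operation(ops.ExtractDayOfYear))
```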
6 changes: 3 additions & 3 deletions docs/blog/ibis-version-4.0.0-release.md
@@ -13,17 +13,17 @@ Let’s talk about some of the new changes 4.0 brings for Ibis users.

## Backends

Ibis 4.0 brings [Polars](https://ibis-project.org/docs/latest/backends/Polars/), [Snowflake](https://ibis-project.org/docs/dev/backends/Snowflake/), and [Trino](https://ibis-project.org/docs/dev/backends/Trino/) into an already-impressive stock of supported backends.
Ibis 4.0 brings [Polars](https://ibis-project.org/docs/4.0.0/backends/Polars/), [Snowflake](https://ibis-project.org/docs/4.0.0/backends/Snowflake/), and [Trino](https://ibis-project.org/docs/4.0.0/backends/Trino/) into an already-impressive stock of supported backends.
The [Polars](https://www.pola.rs/) backend adds another way for users to work locally with DataFrames.
The [Snowflake](https://www.snowflake.com/en/) and [Trino](https://trino.io/) backends add a free and familiar python API to popular data warehouses.

Alongside these new backends, Google BigQuery and Microsoft SQL have been moved to the main repo and have been updated.
Alongside these new backends, the Google BigQuery and Microsoft SQL Server backends have been moved into the main repo, so their release cycle now follows that of Ibis core.

## Functionality

There are a lot of improvements incoming, but some notable changes include:

- [read API](https://github.com/ibis-project/ibis/pull/5005)): allows users to read various file formats directly into their [configured `default_backend`](https://ibis-project.org/docs/dev/api/config/?h=default#ibis.config.Options) (default DuckDB) through `read_*` functions, which makes working with local files easier than ever.
- [read API](https://github.com/ibis-project/ibis/pull/5005): allows users to read various file formats directly into their [configured `default_backend`](https://ibis-project.org/docs/dev/api/config/?h=default#ibis.config.Options) (default DuckDB) through `read_*` functions, which makes working with local files easier than ever (a short sketch follows this list).
- [to_pyarrow and to_pyarrow_batches](https://github.com/ibis-project/ibis/pull/4454#issuecomment-1262640204): users can now get results back as PyArrow objects (Tables, Arrays, Scalars, RecordBatchReader), which opens up all of the functionality that PyArrow provides
- [JSON getitem](https://github.com/ibis-project/ibis/pull/4525): users can now index into JSON fields with getitem syntax in Ibis expressions on some backends
- [Plotting support through `__array__`](https://github.com/ibis-project/ibis/pull/4547): allows users to plot Ibis expressions out of the box
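
A minimal sketch of the read API and PyArrow output together (the CSV path is a stand-in for any local file):

```python
import ibis

# read_csv loads the file through the configured default backend
# (DuckDB out of the box) and returns a lazy table expression
t = ibis.read_csv("diamonds.csv")  # hypothetical local file

# new in 4.0: materialize a result as a pyarrow.Table
arrow_table = t.limit(5).to_pyarrow()
print(arrow_table.schema)
```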
14 changes: 6 additions & 8 deletions docs/community/contribute/02_workflow.md
@@ -25,6 +25,12 @@ pytest -m core
or any specific backends (`ibis/backends`) this material isn't necessary to
follow to make a pull request.

First, we need to download example data to run the tests successfully:

```sh
just download-data
```

To run the tests for a specific backend (e.g. sqlite):

```sh
@@ -50,14 +56,6 @@ export PGPASSWORD=postgres
psql -t -A -h localhost -U postgres -d ibis_testing -c "select 'success'"
```

## Download Test Data

Backends need to be populated with test data to run the tests successfully:

```sh
just download-data
```

## Writing the commit

Ibis follows the [Conventional Commits](https://www.conventionalcommits.org/) structure.
238 changes: 186 additions & 52 deletions docs/release_notes.md

Large diffs are not rendered by default.

20 changes: 9 additions & 11 deletions docs/sqlalchemy_example.py
@@ -7,21 +7,19 @@

a = (
sa.select(
[
sa.case(
[(i.c.investor_name.is_(None), "NO INVESTOR")],
else_=i.c.investor_name,
).label("investor_name"),
sa.func.count(c.c.permalink.distinct()).label("num_investments"),
sa.func.count(
sa.case([(c.status.in_(("ipo", "acquired")), c.c.permalink)]).distinct()
).label("acq_ipos"),
]
sa.case(
[(i.c.investor_name.is_(None), "NO INVESTOR")],
else_=i.c.investor_name,
).label("investor_name"),
sa.func.count(c.c.permalink.distinct()).label("num_investments"),
sa.func.count(
sa.case([(c.status.in_(("ipo", "acquired")), c.c.permalink)]).distinct()
).label("acq_ipos"),
)
.select_from(
c.join(i, onclause=c.c.permalink == i.c.company_permalink, isouter=True)
)
.group_by(1)
.order_by(sa.desc(2))
)
expr = sa.select([(a.c.acq_ipos / a.c.num_investments).label("acq_rate")])
expr = sa.select((a.c.acq_ipos / a.c.num_investments).label("acq_rate"))
83 changes: 8 additions & 75 deletions docs/stylesheets/extra.css
@@ -39,84 +39,17 @@
text-align: center;
}

.support-matrix .md-typeset__table {
display: table;
min-width: 100%;
}

.support-matrix .md-typeset table:not([class]) {
display: table;
.streamlit-app {
text-align: center;
height: 1000px;
min-width: 100%;
}

body
> div.md-container
> main
> div
> div.md-content
> article
> div.md-typeset__scrollwrap
> div
> table
> thead
> tr
> th:nth-child(1) {
min-width: 9.8rem;
}
body
> div.md-container
> main
> div
> div.md-content
> article
> div.md-typeset__scrollwrap {
overflow-y: auto;
height: 750px;
}

body
> div.md-container
> main
> div
> div.md-content
> article
> div.md-typeset__scrollwrap
> div
> table {
display: table;
}

body
> div.md-container
> main
> div
> div.md-content
> article
> div.md-typeset__scrollwrap
> div
> table
> thead {
position: sticky;
top: 0;
z-index: 2;
background-color: black;
}

body
> div.md-container
> main
> div
> div.md-content
> article
> div.md-typeset__scrollwrap
> div
> table
> tbody
> tr
> td:nth-of-type(1) {
position: sticky;
left: 0;
z-index: 1;
background-color: black;
text-align: right;
#streamlit-app {
height: 1000px;
width: 100%;
border: none;
overflow: hidden;
}
18 changes: 9 additions & 9 deletions flake.lock
2 changes: 1 addition & 1 deletion flake.nix
@@ -121,7 +121,7 @@

default = pkgs.ibis310;

inherit (pkgs) update-lock-files;
inherit (pkgs) update-lock-files gen-all-extras;
};

devShells = rec {
71 changes: 14 additions & 57 deletions gen_matrix.py
@@ -1,7 +1,9 @@
from __future__ import annotations

import io
from pathlib import Path

import mkdocs_gen_files
import pandas as pd
import tomli

@@ -24,79 +26,34 @@ def get_leaf_classes(op):
yield from get_leaf_classes(child_class)


EXCLUDED_OPS = {
INTERNAL_OPS = {
# Never translates into anything
ops.UnresolvedExistsSubquery,
ops.UnresolvedNotExistsSubquery,
ops.ScalarParameter,
} | {
op
for op in frozenset(get_leaf_classes(ops.Value))
if issubclass(op, (ops.GeoSpatialUnOp, ops.GeoSpatialBinOp))
}

INCLUDED_OPS = {
# Parent class of MultiQuantile so it's ignored by `get_backends()`
ops.Quantile,
}


ICONS = {
True: ":material-check-decagram:{ .verified }",
False: ":material-cancel:{ .cancel }",
}
PUBLIC_OPS = (frozenset(get_leaf_classes(ops.Value))) - INTERNAL_OPS


def gen_matrix(basename, possible_ops=None):
if possible_ops is None:
possible_ops = (
frozenset(get_leaf_classes(ops.Value)) | INCLUDED_OPS
) - EXCLUDED_OPS

support = {"operation": [f"`{op.__name__}`" for op in possible_ops]}
def main():
support = {"operation": [f"{op.__module__}.{op.__name__}" for op in PUBLIC_OPS]}
support.update(
(name, list(map(backend.has_operation, possible_ops)))
(name, list(map(backend.has_operation, PUBLIC_OPS)))
for name, backend in get_backends()
)

df = pd.DataFrame(support).set_index("operation").sort_index()

counts = df.sum().sort_values(ascending=False)
counts = counts[counts > 0]
num_ops = len(possible_ops)
coverage = (
counts.map(lambda n: f"_{n} ({round(100 * n / num_ops)}%)_")
.to_frame(name="**API Coverage**")
.T
)
file_path = Path("backends", "raw_support_matrix.csv")
local_path = Path(__file__).parent / "docs" / file_path

ops_table = df.loc[:, counts.index].replace(ICONS)
table = pd.concat([coverage, ops_table])
dst = Path(__file__).parent.joinpath(
"docs",
"backends",
f"{basename}_support_matrix.csv",
)

if dst.exists():
old = pd.read_csv(dst, index_col="Backends")
should_write = not old.equals(table)
else:
should_write = True

if should_write:
table.to_csv(dst, index_label="Backends")
buf = io.StringIO()
df.to_csv(buf, index_label="FullOperation")


def main():
gen_matrix(basename="core")
gen_matrix(
basename="geospatial",
possible_ops=(
frozenset(get_leaf_classes(ops.GeoSpatialUnOp))
| frozenset(get_leaf_classes(ops.GeoSpatialBinOp))
),
)
local_path.write_text(buf.getvalue())
with mkdocs_gen_files.open(file_path, "w") as f:
f.write(buf.getvalue())


main()
4 changes: 2 additions & 2 deletions ibis/__init__.py
@@ -12,9 +12,9 @@
__all__ = ['api', 'ir', 'util', 'BaseBackend', 'IbisError', 'options']
__all__ += api.__all__

__version__ = "4.0.0"
__version__ = "4.1.0"

_KNOWN_BACKENDS = ['bigquery', 'heavyai']
_KNOWN_BACKENDS = ['heavyai']


def __dir__() -> list[str]:
47 changes: 29 additions & 18 deletions ibis/backends/base/__init__.py
@@ -20,19 +20,17 @@
MutableMapping,
)

if TYPE_CHECKING:
import pandas as pd
import pyarrow as pa

import ibis.expr.schema as sch

import ibis
import ibis.common.exceptions as exc
import ibis.config
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis import util

if TYPE_CHECKING:
import pandas as pd
import pyarrow as pa

__all__ = ('BaseBackend', 'Database', 'connect')


@@ -206,6 +204,12 @@ def __dir__(self) -> list[str]:
)
return list(o)

def __repr__(self) -> str:
tables = self._backend.list_tables()
rows = ["Tables", "------"]
rows.extend(f"- {name}" for name in sorted(tables))
return "\n".join(rows)

def _ipython_key_completions_(self) -> list[str]:
return self._backend.list_tables()

@@ -223,16 +227,6 @@ def _import_pyarrow():
else:
return pyarrow

@staticmethod
def _table_or_column_schema(expr: ir.Expr) -> sch.Schema:
from ibis.backends.pyarrow.datatypes import sch

if isinstance(expr, ir.Table):
return expr.schema()
else:
# ColumnExpr has no schema method, define single-column schema
return sch.schema([(expr.get_name(), expr.type())])

@util.experimental
def to_pyarrow(
self,
@@ -275,7 +269,7 @@ def to_pyarrow(
except ValueError:
# The pyarrow batches iterator is empty so pass in an empty
# iterator and a pyarrow schema
schema = self._table_or_column_schema(expr)
schema = expr.as_table().schema()
table = pa.Table.from_batches([], schema=schema.to_pyarrow())

if isinstance(expr, ir.Table):
@@ -317,7 +311,7 @@ def to_pyarrow_batches(
params
Mapping of scalar parameter expressions to value.
chunk_size
Number of rows in each returned record batch.
Maximum number of rows in each returned record batch.
kwargs
Keyword arguments
@@ -522,6 +516,23 @@ def list_tables(
The list of the table names that match the pattern `like`.
"""

@abc.abstractmethod
def table(self, name: str, database: str | None = None) -> ir.Table:
"""Construct a table expression.
Parameters
----------
name
Table name
database
Database name
Returns
-------
Table
Table expression
"""

@functools.cached_property
def tables(self):
"""An accessor for tables in the database.
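
A rough usage sketch for the abstract `table` method and the `tables` accessor `__repr__` added above (the backend and database path are hypothetical):

```python
import ibis

con = ibis.sqlite.connect("ibis_testing.db")  # hypothetical database file

t = con.table("functional_alltypes")  # table expression for an existing table
print(con.tables)                     # the new __repr__ lists table names
```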
47 changes: 16 additions & 31 deletions ibis/backends/base/sql/__init__.py
@@ -112,26 +112,17 @@ def sql(self, query: str, schema: sch.Schema | None = None) -> ir.Table:
def _get_schema_using_query(self, query):
raise NotImplementedError(f"Backend {self.name} does not support .sql()")

def raw_sql(self, query: str) -> Any:
def raw_sql(self, query: str):
"""Execute a query string.
Could have unexpected results if the query modifies the behavior of
the session in a way unknown to Ibis; be careful.
!!! warning "The returned cursor object must be **manually** released if results are returned."
Parameters
----------
query
DML or DDL statement
Returns
-------
Any
Backend cursor
DDL or DML statement
"""
# TODO `self.con` is assumed to be defined in subclasses, but there
# is nothing that enforces it. We should find a way to make sure
# `self.con` is always a DBAPI2 connection, or raise an error
cursor = self.con.execute(query) # type: ignore
cursor = self.con.execute(query)
if cursor:
return cursor
cursor.release()
@@ -163,7 +154,7 @@ def to_pyarrow_batches(
params: Mapping[ir.Scalar, Any] | None = None,
limit: int | str | None = None,
chunk_size: int = 1_000_000,
**kwargs: Any,
**_: Any,
) -> pa.ipc.RecordBatchReader:
"""Execute expression and return an iterator of pyarrow record batches.
@@ -180,32 +171,26 @@ params
params
Mapping of scalar parameter expressions to value.
chunk_size
Number of rows in each returned record batch.
kwargs
Keyword arguments
Maximum number of rows in each returned record batch.
Returns
-------
results
RecordBatchReader
RecordBatchReader
Collection of pyarrow `RecordBatch`s.
"""
pa = self._import_pyarrow()

from ibis.backends.pyarrow.datatypes import ibis_to_pyarrow_struct

schema = self._table_or_column_schema(expr)

def _batches():
schema = expr.as_table().schema()
array_type = schema.as_struct().to_pyarrow()
arrays = (
pa.array(map(tuple, batch), type=array_type)
for batch in self._cursor_batches(
expr, params=params, limit=limit, chunk_size=chunk_size
):
struct_array = pa.array(
map(tuple, batch),
type=ibis_to_pyarrow_struct(schema),
)
yield pa.RecordBatch.from_struct_array(struct_array)
)
)
batches = map(pa.RecordBatch.from_struct_array, arrays)

return pa.ipc.RecordBatchReader.from_batches(schema.to_pyarrow(), _batches())
return pa.ipc.RecordBatchReader.from_batches(schema.to_pyarrow(), batches)

def execute(
self,
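
A sketch of consuming the batches produced by the rewritten `to_pyarrow_batches` (backend, table name, and chunk size are chosen only for illustration):

```python
import ibis

con = ibis.sqlite.connect("ibis_testing.db")  # hypothetical database file
expr = con.table("functional_alltypes")

# chunk_size is an upper bound on rows per batch, per the docstring change above
reader = con.to_pyarrow_batches(expr, chunk_size=100_000)
for batch in reader:  # pyarrow.RecordBatch objects streamed from the cursor
    print(batch.num_rows)
```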
129 changes: 105 additions & 24 deletions ibis/backends/base/sql/alchemy/__init__.py
@@ -1,9 +1,12 @@
from __future__ import annotations

import abc
import atexit
import contextlib
import getpass
import warnings
from operator import methodcaller
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Iterable, Literal, Mapping

import sqlalchemy as sa

@@ -39,6 +42,8 @@
if TYPE_CHECKING:
import pandas as pd

import ibis.expr.datatypes as dt


__all__ = (
'BaseAlchemyBackend',
@@ -91,7 +96,7 @@ def _current_schema(self) -> str | None:
def do_connect(self, con: sa.engine.Engine) -> None:
self.con = con
self._inspector = sa.inspect(self.con)
self.meta = sa.MetaData(bind=self.con)
self.meta = sa.MetaData()
self._schemas: dict[str, sch.Schema] = {}
self._temp_views: set[str] = set()

@@ -236,8 +241,9 @@ def _get_insert_method(self, expr):
return methodcaller("from_select", list(expr.columns), compiled)

def _columns_from_schema(self, name: str, schema: sch.Schema) -> list[sa.Column]:
dialect = self.con.dialect
return [
sa.Column(colname, to_sqla_type(dtype), nullable=dtype.nullable)
sa.Column(colname, to_sqla_type(dialect, dtype), nullable=dtype.nullable)
for colname, dtype in zip(schema.names, schema.types)
]

@@ -274,7 +280,8 @@ def drop_table(
)

t = self._get_sqla_table(table_name, schema=database, autoload=False)
t.drop(checkfirst=force)
with self.begin() as bind:
t.drop(bind=bind, checkfirst=force)

assert not self.inspector.has_table(
table_name
@@ -338,7 +345,8 @@ def truncate_table(
database: str | None = None,
) -> None:
t = self._get_sqla_table(table_name, schema=database)
t.delete().execute()
with self.begin() as con:
con.execute(t.delete())

def schema(self, name: str) -> sch.Schema:
"""Get an ibis schema from the current database for the table `name`.
@@ -369,13 +377,46 @@ def _log(self, sql):
util.log(query_str)

def _get_sqla_table(
self,
name: str,
schema: str | None = None,
autoload: bool = True,
**kwargs: Any,
self, name: str, schema: str | None = None, autoload: bool = True, **kwargs: Any
) -> sa.Table:
return sa.Table(name, self.meta, schema=schema, autoload=autoload)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="Did not recognize type", category=sa.exc.SAWarning
)
table = sa.Table(
name,
self.meta,
schema=schema,
autoload_with=self.con if autoload else None,
)
nulltype_cols = frozenset(
col.name for col in table.c if isinstance(col.type, sa.types.NullType)
)

if not nulltype_cols:
return table
return self._handle_failed_column_type_inference(table, nulltype_cols)

def _handle_failed_column_type_inference(
self, table: sa.Table, nulltype_cols: Iterable[str]
) -> sa.Table:
"""Handle cases where SQLAlchemy cannot infer the column types of `table`."""

self.inspector.reflect_table(table, table.columns)
dialect = self.con.dialect
quoted_name = dialect.identifier_preparer.quote(table.name)

for colname, type in self._metadata(quoted_name):
if colname in nulltype_cols:
# replace null types discovered by sqlalchemy with non null
# types
table.append_column(
sa.Column(
colname, to_sqla_type(dialect, type), nullable=type.nullable
),
replace_existing=True,
)
return table

def _sqla_table_to_expr(self, table: sa.Table) -> ir.Table:
schema = self._schemas.get(table.name)
@@ -387,6 +428,20 @@ def _sqla_table_to_expr(self, table: sa.Table) -> ir.Table:
)
return self.table_expr_class(node)

def raw_sql(self, query) -> None:
"""Execute a query string.
!!! warning "The returned cursor object must be **manually** released."
Parameters
----------
query
DDL or DML statement
"""
return self.con.connect().execute(
sa.text(query) if isinstance(query, str) else query
)

def table(
self,
name: str,
@@ -520,31 +575,57 @@ def insert(
f"The given obj is of type {type(obj).__name__} ."
)

def _quote(self, name: str) -> str:
"""Quote an identifier."""
return self.con.dialect.identifier_preparer.quote(name)

def _get_temp_view_definition(
self,
name: str,
definition: sa.sql.compiler.Compiled,
self, name: str, definition: sa.sql.compiler.Compiled
) -> str:
raise NotImplementedError(
f"The {self.name} backend does not implement temporary view creation"
)

def _register_temp_view_cleanup(self, name: str, raw_name: str) -> None:
pass
query = f"DROP VIEW IF EXISTS {name}"

def drop(self, raw_name: str, query: str):
with self.begin() as con:
con.exec_driver_sql(query)
self._temp_views.discard(raw_name)

atexit.register(drop, self, raw_name, query)

def _create_temp_view(
def _get_compiled_statement(
self,
view: sa.Table,
definition: sa.sql.Selectable,
) -> None:
name: str,
compile_kwargs: Mapping[str, Any] | None = None,
):
if compile_kwargs is None:
compile_kwargs = {}
compiled = definition.compile(
dialect=self.con.dialect, compile_kwargs=compile_kwargs
)
lines = self._get_temp_view_definition(name, definition=compiled)
return lines, compiled.params

def _create_temp_view(self, view: sa.Table, definition: sa.sql.Selectable) -> None:
raw_name = view.name
if raw_name not in self._temp_views and raw_name in self.list_tables():
raise ValueError(f"{raw_name} already exists as a table or view")

name = self.con.dialect.identifier_preparer.quote_identifier(raw_name)
compiled = definition.compile()
defn = self._get_temp_view_definition(name, definition=compiled)
query = sa.text(defn).bindparams(**compiled.params)
self.con.execute(query)
name = self._quote(raw_name)
lines, params = self._get_compiled_statement(definition, name)
with self.begin() as con:
for line in lines:
con.exec_driver_sql(line, parameters=params)
self._temp_views.add(raw_name)
self._register_temp_view_cleanup(name, raw_name)

@abc.abstractmethod
def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
...

def _get_schema_using_query(self, query: str) -> sch.Schema:
"""Return an ibis Schema from a backend-specific SQL string."""
return sch.Schema.from_tuples(self._metadata(query))
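
A short sketch of the manual-release contract documented on `raw_sql` above (backend and query are hypothetical):

```python
import ibis

con = ibis.sqlite.connect("ibis_testing.db")  # hypothetical database file

cur = con.raw_sql("SELECT 1 AS one")
try:
    print(cur.fetchall())
finally:
    cur.close()  # the returned cursor/result is not released for you
```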
2 changes: 1 addition & 1 deletion ibis/backends/base/sql/alchemy/database.py
@@ -22,7 +22,7 @@ def __init__(self, source, sqla_table, name, schema):
if name is None:
name = sqla_table.name
if schema is None:
schema = sch.infer(sqla_table, schema=schema)
schema = sch.infer(sqla_table, schema=schema, dialect=source.con.dialect)
super().__init__(name=name, schema=schema, sqla_table=sqla_table, source=source)

# TODO(kszucs): remove this
229 changes: 167 additions & 62 deletions ibis/backends/base/sql/alchemy/datatypes.py
@@ -1,13 +1,15 @@
from __future__ import annotations

import functools
from typing import Iterable

import sqlalchemy as sa
from sqlalchemy.dialects import mysql, postgresql, sqlite
from multipledispatch import Dispatcher
from sqlalchemy.dialects import mssql, mysql, postgresql, sqlite
from sqlalchemy.dialects.mssql.base import MSDialect
from sqlalchemy.dialects.mysql.base import MySQLDialect
from sqlalchemy.dialects.postgresql.base import PGDialect
from sqlalchemy.dialects.sqlite.base import SQLiteDialect
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.types import UserDefinedType
@@ -20,25 +22,43 @@
import geoalchemy2 as ga


class ArrayType(UserDefinedType):
def __init__(self, value_type: sa.types.TypeEngine):
self.value_type = sa.types.to_instance(value_type)


@compiles(ArrayType, "default")
def compiles_array(element, compiler, **kw):
return f"ARRAY({compiler.process(element.value_type, **kw)})"


class StructType(UserDefinedType):
def __init__(
self,
pairs: Iterable[tuple[str, sa.types.TypeEngine]],
):
self.pairs = [(name, sa.types.to_instance(type)) for name, type in pairs]

def get_col_spec(self, **_):
pairs = ", ".join(f"{k} {v}" for k, v in self.pairs)
return f"STRUCT({pairs})"

@compiles(StructType, "default")
def compiles_struct(element, compiler, **kw):
content = ", ".join(
f"{field} {compiler.process(typ, **kw)}" for field, typ in element.pairs
)
return f"STRUCT({content})"


class MapType(UserDefinedType):
def __init__(self, key_type: sa.types.TypeEngine, value_type: sa.types.TypeEngine):
self.key_type = sa.types.to_instance(key_type)
self.value_type = sa.types.to_instance(value_type)

def get_col_spec(self, **_):
return f"MAP({self.key_type}, {self.value_type})"

@compiles(MapType, "default")
def compiles_map(element, compiler, **kw):
key_type = compiler.process(element.key_type, **kw)
value_type = compiler.process(element.value_type, **kw)
return f"MAP({key_type}, {value_type})"


class UInt64(sa.types.Integer):
@@ -80,8 +100,9 @@ def table_from_schema(name, meta, schema, database: str | None = None):
# Convert Ibis schema to SQLA table
columns = []

dialect = getattr(meta.bind, "dialect", _DEFAULT_DIALECT)
for colname, dtype in zip(schema.names, schema.types):
satype = to_sqla_type(dtype)
satype = to_sqla_type(dialect, dtype)
column = sa.Column(colname, satype, nullable=dtype.nullable)
columns.append(column)

@@ -113,65 +134,63 @@ def table_from_schema(name, meta, schema, database: str | None = None):
dt.UInt32: UInt32,
dt.UInt64: UInt64,
dt.JSON: sa.JSON,
dt.Interval: sa.Interval,
}


@functools.singledispatch
def to_sqla_type(itype, type_map=None):
if type_map is None:
type_map = ibis_type_to_sqla
return type_map[type(itype)]
_DEFAULT_DIALECT = DefaultDialect()

to_sqla_type = Dispatcher("to_sqla_type")

@to_sqla_type.register(dt.Decimal)
def _(itype, **kwargs):
return sa.types.NUMERIC(itype.precision, itype.scale)

@to_sqla_type.register(Dialect, dt.DataType)
def _default(_, itype):
return ibis_type_to_sqla[type(itype)]

@to_sqla_type.register(dt.Interval)
def _(itype, **kwargs):
return sa.types.Interval()

@to_sqla_type.register(Dialect, dt.Decimal)
def _decimal(_, itype):
return sa.types.NUMERIC(itype.precision, itype.scale)


@to_sqla_type.register(dt.Date)
def _(itype, **kwargs):
return sa.Date()
@to_sqla_type.register(Dialect, dt.Timestamp)
def _timestamp(_, itype):
return sa.TIMESTAMP(timezone=bool(itype.timezone))


@to_sqla_type.register(dt.Timestamp)
def _(itype, **kwargs):
return sa.TIMESTAMP(bool(itype.timezone))
@to_sqla_type.register(Dialect, dt.Array)
def _array(dialect, itype):
return ArrayType(to_sqla_type(dialect, itype.value_type))


@to_sqla_type.register(dt.Array)
def _(itype, **kwargs):
@to_sqla_type.register(PGDialect, dt.Array)
def _pg_array(dialect, itype):
# Unwrap the array element type because sqlalchemy doesn't allow arrays of
# arrays. This doesn't affect the underlying data.
while itype.is_array():
itype = itype.value_type
return sa.ARRAY(to_sqla_type(itype, **kwargs))
return sa.ARRAY(to_sqla_type(dialect, itype))


@to_sqla_type.register(dt.Struct)
def _(itype, **_):
return StructType(
[(name, to_sqla_type(type)) for name, type in itype.pairs.items()]
)
@to_sqla_type.register(PGDialect, dt.Map)
def _pg_map(dialect, itype):
if not (itype.key_type.is_string() and itype.value_type.is_string()):
raise TypeError(f"PostgreSQL only supports map<string, string>, got: {itype}")
return postgresql.HSTORE


@to_sqla_type.register(dt.Map)
def _(itype, **_):
return MapType(to_sqla_type(itype.key_type), to_sqla_type(itype.value_type))
@to_sqla_type.register(Dialect, dt.Struct)
def _struct(dialect, itype):
return StructType(
[(name, to_sqla_type(dialect, type)) for name, type in itype.fields.items()]
)


@to_sqla_type.register(dt.GeoSpatial)
def _(itype, **kwargs):
if itype.geotype == 'geometry':
return ga.Geometry
elif itype.geotype == 'geography':
return ga.Geography
else:
return ga.types._GISType
@to_sqla_type.register(Dialect, dt.Map)
def _map(dialect, itype):
return MapType(
to_sqla_type(dialect, itype.key_type), to_sqla_type(dialect, itype.value_type)
)


@dt.dtype.register(Dialect, sa.types.NullType)
@@ -184,21 +203,12 @@ def sa_boolean(_, satype, nullable=True):
return dt.Boolean(nullable=nullable)


@dt.dtype.register(MySQLDialect, mysql.NUMERIC)
@dt.dtype.register(MySQLDialect, sa.NUMERIC)
@dt.dtype.register(MySQLDialect, (sa.NUMERIC, mysql.NUMERIC))
def sa_mysql_numeric(_, satype, nullable=True):
# https://dev.mysql.com/doc/refman/8.0/en/fixed-point-types.html
return dt.Decimal(satype.precision or 10, satype.scale or 0, nullable=nullable)


@dt.dtype.register(MySQLDialect, mysql.TINYBLOB)
@dt.dtype.register(MySQLDialect, mysql.MEDIUMBLOB)
@dt.dtype.register(MySQLDialect, mysql.BLOB)
@dt.dtype.register(MySQLDialect, mysql.LONGBLOB)
def sa_mysql_blob(_, satype, nullable=True):
return dt.Binary(nullable=nullable)


_FLOAT_PREC_TO_TYPE = {
11: dt.Float16,
24: dt.Float32,
@@ -231,16 +241,44 @@ def sa_integer(_, satype, nullable=True):


@dt.dtype.register(Dialect, mysql.TINYINT)
@dt.dtype.register(MSDialect, mssql.TINYINT)
@dt.dtype.register(MySQLDialect, mysql.YEAR)
def sa_mysql_tinyint(_, satype, nullable=True):
return dt.Int8(nullable=nullable)


@dt.dtype.register(MSDialect, mssql.BIT)
def sa_mssql_bit(_, satype, nullable=True):
return dt.Boolean(nullable=nullable)


@dt.dtype.register(MySQLDialect, mysql.BIT)
def sa_mysql_bit(_, satype, nullable=True):
if 1 <= (length := satype.length) <= 8:
return dt.Int8(nullable=nullable)
elif 9 <= length <= 16:
return dt.Int16(nullable=nullable)
elif 17 <= length <= 32:
return dt.Int32(nullable=nullable)
elif 33 <= length <= 64:
return dt.Int64(nullable=nullable)
else:
raise ValueError(f"Invalid MySQL BIT length: {length:d}")


@dt.dtype.register(Dialect, sa.types.BigInteger)
@dt.dtype.register(MSDialect, mssql.MONEY)
def sa_bigint(_, satype, nullable=True):
return dt.Int64(nullable=nullable)


@dt.dtype.register(MSDialect, mssql.SMALLMONEY)
def sa_mssql_smallmoney(_, satype, nullable=True):
return dt.Int32(nullable=nullable)


@dt.dtype.register(Dialect, sa.REAL)
@dt.dtype.register(MySQLDialect, mysql.FLOAT)
def sa_real(_, satype, nullable=True):
return dt.Float32(nullable=nullable)

@@ -253,6 +291,7 @@ def sa_double(_, satype, nullable=True):


@dt.dtype.register(PGDialect, postgresql.UUID)
@dt.dtype.register(MSDialect, mssql.UNIQUEIDENTIFIER)
def sa_uuid(_, satype, nullable=True):
return dt.UUID(nullable=nullable)

@@ -262,6 +301,11 @@ def sa_macaddr(_, satype, nullable=True):
return dt.MACADDR(nullable=nullable)


@dt.dtype.register(PGDialect, postgresql.HSTORE)
def sa_hstore(_, satype, nullable=True):
return dt.Map(dt.string, dt.string, nullable=nullable)


@dt.dtype.register(PGDialect, postgresql.INET)
def sa_inet(_, satype, nullable=True):
return dt.INET(nullable=nullable)
@@ -273,11 +317,26 @@ def sa_json(_, satype, nullable=True):
return dt.JSON(nullable=nullable)


@dt.dtype.register(MySQLDialect, mysql.TIMESTAMP)
def sa_mysql_timestamp(_, satype, nullable=True):
return dt.Timestamp(timezone="UTC", nullable=nullable)


@dt.dtype.register(MySQLDialect, mysql.DATETIME)
def sa_mysql_datetime(_, satype, nullable=True):
return dt.Timestamp(nullable=nullable)


@dt.dtype.register(MySQLDialect, mysql.SET)
def sa_mysql_set(_, satype, nullable=True):
return dt.Set(dt.string, nullable=nullable)


if geospatial_supported:

@dt.dtype.register(Dialect, (ga.Geometry, ga.types._GISType))
def ga_geometry(_, gatype, nullable=True):
t = gatype.geometry_type
t = gatype.geometry_type.upper()
if t == 'POINT':
return dt.Point(nullable=nullable)
if t == 'LINESTRING':
@@ -290,11 +349,20 @@ def ga_geometry(_, gatype, nullable=True):
return dt.MultiPoint(nullable=nullable)
if t == 'MULTIPOLYGON':
return dt.MultiPolygon(nullable=nullable)
if t == 'GEOMETRY':
return dt.Geometry(nullable=nullable)
if t in ('GEOMETRY', 'GEOGRAPHY'):
return getattr(dt, gatype.name.lower())(nullable=nullable)
else:
raise ValueError(f"Unrecognized geometry type: {t}")

@to_sqla_type.register(Dialect, dt.GeoSpatial)
def _(_, itype, **kwargs):
if itype.geotype == 'geometry':
return ga.Geometry
elif itype.geotype == 'geography':
return ga.Geography
else:
return ga.types._GISType


POSTGRES_FIELD_TO_IBIS_UNIT = {
"YEAR": "Y",
@@ -338,6 +406,18 @@ def sa_string(_, satype, nullable=True):


@dt.dtype.register(Dialect, sa.LargeBinary)
@dt.dtype.register(MSDialect, (mssql.BINARY, mssql.TIMESTAMP))
@dt.dtype.register(
MySQLDialect,
(
mysql.TINYBLOB,
mysql.MEDIUMBLOB,
mysql.BLOB,
mysql.LONGBLOB,
mysql.BINARY,
mysql.VARBINARY,
),
)
def sa_binary(_, satype, nullable=True):
return dt.Binary(nullable=nullable)

@@ -358,8 +438,22 @@ def sa_datetime(_, satype, nullable=True, default_timezone='UTC'):
return dt.Timestamp(timezone=timezone, nullable=nullable)


@dt.dtype.register(Dialect, sa.ARRAY)
def sa_array(dialect, satype, nullable=True):
@dt.dtype.register(MSDialect, mssql.DATETIMEOFFSET)
def _datetimeoffset(_, sa_type, nullable=True):
if (prec := sa_type.precision) is None:
prec = 7
return dt.Timestamp(scale=prec, timezone="UTC", nullable=nullable)


@dt.dtype.register(MSDialect, mssql.DATETIME2)
def _datetime2(_, sa_type, nullable=True):
if (prec := sa_type.precision) is None:
prec = 7
return dt.Timestamp(scale=prec, nullable=nullable)


@dt.dtype.register(PGDialect, sa.ARRAY)
def sa_pg_array(dialect, satype, nullable=True):
dimensions = satype.dimensions
if dimensions is not None and dimensions != 1:
raise NotImplementedError(
@@ -376,16 +470,27 @@ def sa_struct(dialect, satype, nullable=True):
return dt.Struct.from_tuples(pairs, nullable=nullable)


@dt.dtype.register(Dialect, ArrayType)
def sa_array(dialect, satype, nullable=True):
return dt.Array(dt.dtype(dialect, satype.value_type), nullable=nullable)


@sch.infer.register((sa.Table, sa.sql.TableClause))
def schema_from_table(table: sa.Table, schema: sch.Schema | None = None) -> sch.Schema:
"""Retrieve an ibis schema from a SQLAlchemy ``Table``.
def schema_from_table(
table: sa.Table,
schema: sch.Schema | None = None,
dialect: sa.engine.interfaces.Dialect | None = None,
) -> sch.Schema:
"""Retrieve an ibis schema from a SQLAlchemy `Table`.
Parameters
----------
table
Table whose schema to infer
schema
Schema to pull types from
dialect
Optional sqlalchemy dialect
Returns
-------
@@ -399,7 +504,7 @@ def schema_from_table(table: sa.Table, schema: sch.Schema | None = None) -> sch.
dtype = dt.dtype(schema[name])
else:
dtype = dt.dtype(
getattr(table.bind, 'dialect', Dialect()),
dialect or getattr(table.bind, "dialect", DefaultDialect()),
column.type,
nullable=column.nullable,
)
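
A rough illustration of the dialect-aware conversions introduced in this file; the dialect and types are chosen only for the example, and the import path assumes the module shown above:

```python
from sqlalchemy.dialects import postgresql

import ibis.expr.datatypes as dt
from ibis.backends.base.sql.alchemy.datatypes import to_sqla_type

dialect = postgresql.dialect()

# ibis -> SQLAlchemy now dispatches on (dialect, ibis type) ...
sa_type = to_sqla_type(dialect, dt.Array(dt.int64))  # plain ARRAY on postgres

# ... and SQLAlchemy -> ibis reuses the same dialect-specific registrations
ibis_type = dt.dtype(dialect, postgresql.HSTORE())   # Map(string, string)
print(sa_type, ibis_type)
```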
73 changes: 27 additions & 46 deletions ibis/backends/base/sql/alchemy/query_builder.py
@@ -6,9 +6,7 @@
from sqlalchemy import sql

import ibis.expr.operations as ops
import ibis.expr.schema as sch
from ibis.backends.base.sql.alchemy.database import AlchemyTable
from ibis.backends.base.sql.alchemy.datatypes import to_sqla_type
from ibis.backends.base.sql.alchemy.translator import (
AlchemyContext,
AlchemyExprTranslator,
@@ -22,10 +20,6 @@
from ibis.backends.base.sql.compiler.base import SetOp


def _schema_to_sqlalchemy_columns(schema: sch.Schema) -> list[sa.Column]:
return [sa.column(n, to_sqla_type(t)) for n, t in schema.items()]


class _AlchemyTableSetFormatter(TableSetFormatter):
def get_result(self):
# Got to unravel the join stack; the nesting order could be
@@ -42,7 +36,7 @@ def get_result(self):
for jtype, table, preds in zip(
self.join_types, self.join_tables[1:], self.join_predicates
):
if len(preds):
if preds:
sqla_preds = [self._translate(pred) for pred in preds]
onclause = functools.reduce(sql.and_, sqla_preds)
else:
@@ -59,9 +53,23 @@ def get_result(self):
elif jtype is ops.OuterJoin:
result = result.outerjoin(table, onclause, full=True)
elif jtype is ops.LeftSemiJoin:
result = result.select().where(sa.exists(sa.select(1).where(onclause)))
# subquery is required for semi and anti joins done using
# sqlalchemy, otherwise multiple references to the original
# select are treated as distinct tables
#
# with a subquery, the result is a distinct table and so there's only one
# thing for subsequent expressions to reference
result = (
result.select()
.where(sa.exists(sa.select(1).where(onclause)))
.subquery()
)
elif jtype is ops.LeftAntiJoin:
result = result.select().where(~sa.exists(sa.select(1).where(onclause)))
result = (
result.select()
.where(~sa.exists(sa.select(1).where(onclause)))
.subquery()
)
else:
raise NotImplementedError(jtype)

Expand All @@ -80,32 +88,32 @@ def _format_table(self, op):

alias = ctx.get_ref(op)

translator = ctx.compiler.translator_class(ref_op, ctx)

if isinstance(ref_op, AlchemyTable):
result = ref_op.sqla_table
elif isinstance(ref_op, ops.UnboundTable):
# use SQLAlchemy's TableClause for unbound tables
result = sa.table(
ref_op.name,
*_schema_to_sqlalchemy_columns(ref_op.schema),
ref_op.name, *translator._schema_to_sqlalchemy_columns(ref_op.schema)
)
elif isinstance(ref_op, ops.SQLQueryResult):
columns = _schema_to_sqlalchemy_columns(ref_op.schema)
columns = translator._schema_to_sqlalchemy_columns(ref_op.schema)
result = sa.text(ref_op.query).columns(*columns)
elif isinstance(ref_op, ops.SQLStringView):
columns = _schema_to_sqlalchemy_columns(ref_op.schema)
columns = translator._schema_to_sqlalchemy_columns(ref_op.schema)
result = sa.text(ref_op.query).columns(*columns).cte(ref_op.name)
elif isinstance(ref_op, ops.View):
# TODO(kszucs): avoid converting to expression
child_expr = ref_op.child.to_expr()
definition = child_expr.compile()
result = sa.table(
ref_op.name,
*_schema_to_sqlalchemy_columns(ref_op.schema),
ref_op.name, *translator._schema_to_sqlalchemy_columns(ref_op.schema)
)
backend = child_expr._find_backend()
backend._create_temp_view(view=result, definition=definition)
elif isinstance(ref_op, ops.InMemoryTable):
columns = _schema_to_sqlalchemy_columns(ref_op.schema)
columns = translator._schema_to_sqlalchemy_columns(ref_op.schema)

if self.context.compiler.cheap_in_memory_tables:
result = sa.table(ref_op.name, *columns)
@@ -214,10 +222,8 @@ def _add_select(self, table_set):
else:
clauses = to_select

if self.exists:
result = sa.exists(clauses)
else:
result = sa.select(clauses)
result_func = sa.exists if self.exists else sa.select
result = result_func(*clauses)

if self.distinct:
result = result.distinct()
@@ -227,32 +233,7 @@ def _add_select(self, table_set):
if has_select_star or table_set is None:
return result

# if we're selecting from something that isn't a subquery e.g., Select,
# Alias, Table
if not isinstance(table_set, sa.sql.Subquery):
return result.select_from(table_set)

final_froms = result.get_final_froms()
num_froms = len(final_froms)

# if the result subquery has no FROMs then we can select from the
# table_set since there's only a single possibility for FROM
if not num_froms:
return result.select_from(table_set)

# we need to replace every occurrence of `result`'s `FROM`
# with `table_set` to handle correlated EXISTs coming from
# semi/anti-join
#
# previously this was `replace_selectable`, but that's deprecated so we
# inline its implementation here
#
# sqlalchemy suggests using the functionality in sa.sql.visitors, but
# that would effectively require reimplementing ClauseAdapter
replaced = sa.sql.util.ClauseAdapter(table_set).traverse(result)
num_froms = len(replaced.get_final_froms())
assert num_froms == 1, f"num_froms == {num_froms:d}"
return replaced
return result.select_from(table_set)

def _add_group_by(self, fragment):
# GROUP BY and HAVING
78 changes: 44 additions & 34 deletions ibis/backends/base/sql/alchemy/registry.py
@@ -81,24 +81,27 @@ def get_sqla_table(ctx, table):
return sa_table


def get_col_or_deferred_col(sa_table, colname):
"""Get a `Column`, or create a "deferred" column.
This is to handle the case when selecting a column from a join, which
happens when a join expression is cached during join traversal
We'd like to avoid generating a subquery just for selection but in
sqlalchemy the Join object is not selectable. However, at this point
know that the column can be referred to unambiguously
Later the expression is assembled into
`sa.select([sa.column(colname)]).select_from(table_set)` (roughly)
where `table_set` is `sa_table` above.
"""
try:
return sa_table.exported_columns[colname]
except KeyError:
return sa.column(colname)
def get_col(sa_table, op: ops.TableColumn) -> sa.sql.ColumnClause:
"""Extract a column from a table."""
cols = sa_table.exported_columns
colname = op.name

if (col := cols.get(colname)) is not None:
return col

# `cols` is a SQLAlchemy column collection that contains columns
    # with names that are secretly prefixed by the table that contains them
#
# for example, in `t0.join(t1).select(t0.a, t1.b)` t0.a will be named `t0_a`
# and t1.b will be named `t1_b`
#
# unfortunately SQLAlchemy doesn't let you select by the *un*prefixed
# column name despite the uniqueness of `colname`
#
# however, in ibis we have already deduplicated column names so we can
# refer to the name by position
colindex = op.table.schema._name_locs[colname]
return cols[colindex]
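
To make the positional fallback concrete, here is a tiny standalone sketch (the names below are invented for illustration, not part of this diff): when the exported name is prefixed, the deduplicated ibis schema still tells us which position the column occupies.

# illustrative only: exported names may be prefixed after a join, but the
# ibis schema's column order still identifies the column unambiguously
schema_names = ["a", "b"]                      # deduplicated ibis schema order
exported = {"t0_a": "col_a", "t1_b": "col_b"}  # what the join selectable exposes
colname = "b"
col = exported.get(colname)
if col is None:
    col = list(exported.values())[schema_names.index(colname)]
print(col)  # col_b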


def _table_column(t, op):
Expand All @@ -107,7 +110,7 @@ def _table_column(t, op):

sa_table = get_sqla_table(ctx, table)

out_expr = get_col_or_deferred_col(sa_table, op.name)
out_expr = get_col(sa_table, op)
out_expr.quote = t._always_quote_columns

# If the column does not originate from the table set in the current SELECT
Expand All @@ -123,26 +126,33 @@ def _table_column(t, op):


def _table_array_view(t, op):
# the table that the TableArrayView op contains (op.table) has
# one or more input relations that we need to "pin" for sqlalchemy's
# auto correlation functionality -- this is what `.correlate_except` does
#
# every relation that is NOT passed to `correlate_except` is considered an
# outer-query table
ctx = t.context
table = ctx.get_compiled_expr(op.table)
return table
# TODO: handle the case of `op.table` being a join
first, *_ = an.find_immediate_parent_tables(op.table, keep_input=False)
ref = ctx.get_ref(first)
return table.correlate_except(ref)
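
A rough sketch of what `correlate_except` does in plain SQLAlchemy may help here (the table names are invented for the sketch): every FROM element passed to it stays in the subquery, while anything else is allowed to correlate to the enclosing query.

import sqlalchemy as sa

outer = sa.table("outer_t", sa.column("id"))
inner = sa.table("inner_t", sa.column("id"), sa.column("outer_id"))

# `inner_t` is pinned to the subquery's FROM list; `outer_t` correlates outward
subq = (
    sa.select(sa.func.count())
    .where(inner.c.outer_id == outer.c.id)
    .correlate_except(inner)
    .scalar_subquery()
)
print(sa.select(outer.c.id, subq))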


def _exists_subquery(t, op):
from ibis.backends.base.sql.alchemy.query_builder import AlchemyCompiler

ctx = t.context

# TODO(kszucs): avoid converting the predicates to expressions
# this should be done by the rewrite step before compilation
filtered = (
op.foreign_table.to_expr()
.filter([pred.to_expr() for pred in op.predicates])
.projection([ir.literal(1).name(ir.core.unnamed)])
.projection([ir.literal(1).name("")])
)

sub_ctx = ctx.subcontext()
clause = AlchemyCompiler.to_sql(filtered, sub_ctx, exists=True)
clause = ctx.compiler.to_sql(filtered, sub_ctx, exists=True)

if isinstance(op, ops.NotExistsSubquery):
clause = sa.not_(clause)
Expand All @@ -153,18 +163,18 @@ def _exists_subquery(t, op):
def _cast(t, op):
arg = op.arg
typ = op.to
arg_dtype = arg.output_dtype

sa_arg = t.translate(arg)
sa_type = t.get_sqla_type(typ)

if isinstance(arg, ir.CategoryValue) and typ == dt.int32:
if arg_dtype.is_category() and typ.is_int32():
return sa_arg

# specialize going from an integer type to a timestamp
if arg.output_dtype.is_integer() and isinstance(sa_type, sa.DateTime):
if arg_dtype.is_integer() and typ.is_timestamp():
return t.integer_to_timestamp(sa_arg)

if arg.output_dtype.is_binary() and typ.is_string():
if arg_dtype.is_binary() and typ.is_string():
return sa.func.encode(sa_arg, 'escape')

if typ.is_binary():
Expand All @@ -174,7 +184,8 @@ def _cast(t, op):

if typ.is_json() and not t.native_json_type:
return sa_arg
return sa.cast(sa_arg, sa_type)

return sa.cast(sa_arg, t.get_sqla_type(typ))


def _contains(func):
Expand Down Expand Up @@ -220,7 +231,7 @@ def _is_null(t, op):

def _not_null(t, op):
arg = t.translate(op.arg)
return arg.isnot(sa.null())
return arg.is_not(sa.null())


def _round(t, op):
Expand Down Expand Up @@ -257,7 +268,7 @@ def _translate_case(t, cases, results, default):
whens = zip(case_args, result_args)
default = t.translate(default)

return sa.case(list(whens), else_=default)
return sa.case(*whens, else_=default)
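
The change from a list of tuples to positional arguments follows SQLAlchemy's 1.4/2.0-style `case()` API; a minimal sketch:

import sqlalchemy as sa

x = sa.column("x")
# 2.0 style: (condition, value) pairs are passed positionally, not as a list
expr = sa.case((x > 0, "positive"), (x < 0, "negative"), else_="zero")
print(expr)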


def _negate(t, op):
Expand Down Expand Up @@ -400,7 +411,7 @@ def compile_expr(t, expr):
def _zero_if_null(t, op):
sa_arg = t.translate(op.arg)
return sa.case(
[(sa_arg.is_(None), sa.cast(0, t.get_sqla_type(op.output_dtype)))],
(sa_arg.is_(None), sa.cast(0, t.get_sqla_type(op.output_dtype))),
else_=sa_arg,
)

Expand Down Expand Up @@ -511,7 +522,6 @@ def translator(t, op: ops.Node):
ops.NotNull: _not_null,
ops.Negate: _negate,
ops.Round: _round,
ops.TypeOf: unary(sa.func.typeof),
ops.Literal: _literal,
ops.NullLiteral: lambda *_: sa.null(),
ops.SimpleCase: _simple_case,
Expand Down Expand Up @@ -597,7 +607,7 @@ def translator(t, op: ops.Node):
ops.Clip: _clip(min_func=sa.func.least, max_func=sa.func.greatest),
ops.Where: fixed_arity(
lambda predicate, value_if_true, value_if_false: sa.case(
[(predicate, value_if_true)],
(predicate, value_if_true),
else_=value_if_false,
),
3,
Expand Down
22 changes: 19 additions & 3 deletions ibis/backends/base/sql/alchemy/translator.py
@@ -1,11 +1,14 @@
from __future__ import annotations

import functools

import sqlalchemy as sa

import ibis
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy.datatypes import ibis_type_to_sqla, to_sqla_type
from ibis.backends.base.sql.alchemy import to_sqla_type
from ibis.backends.base.sql.alchemy.datatypes import _DEFAULT_DIALECT
from ibis.backends.base.sql.alchemy.registry import (
fixed_arity,
sqlalchemy_operation_registry,
Expand Down Expand Up @@ -35,7 +38,6 @@ def subcontext(self):
class AlchemyExprTranslator(ExprTranslator):
_registry = sqlalchemy_operation_registry
_rewrites = ExprTranslator._rewrites.copy()
_type_map = ibis_type_to_sqla

context_class = AlchemyContext

Expand All @@ -54,11 +56,25 @@ class AlchemyExprTranslator(ExprTranslator):
ops.CumeDist,
)

_dialect_name = "default"

@functools.cached_property
def dialect(self) -> sa.engine.interfaces.Dialect:
if (name := self._dialect_name) == "default":
return _DEFAULT_DIALECT
dialect_cls = sa.dialects.registry.load(name)
return dialect_cls()
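
For reference, the property boils down to SQLAlchemy's dialect registry lookup; a small sketch ("sqlite" here is just an example name, nothing this PR registers):

import sqlalchemy as sa

dialect_cls = sa.dialects.registry.load("sqlite")  # resolve the dialect entry point
dialect = dialect_cls()
print(type(dialect).__name__)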

def _schema_to_sqlalchemy_columns(self, schema):
return [
sa.column(name, self.get_sqla_type(dtype)) for name, dtype in schema.items()
]

def name(self, translated, name, force=True):
return translated.label(name)

def get_sqla_type(self, data_type):
return to_sqla_type(data_type, type_map=self._type_map)
return to_sqla_type(self.dialect, data_type)

def _maybe_cast_bool(self, op, arg):
if (
Expand Down
14 changes: 3 additions & 11 deletions ibis/backends/base/sql/compiler/query_builder.py
Expand Up @@ -93,7 +93,6 @@ def _format_table(self, op):

if isinstance(ref_op, ops.InMemoryTable):
result = self._format_in_memory_table(ref_op)
is_subquery = True
elif isinstance(ref_op, ops.PhysicalTable):
name = ref_op.name
# TODO(kszucs): add a mandatory `name` field to the base
Expand All @@ -102,7 +101,6 @@ def _format_table(self, op):
if name is None:
raise com.RelationError(f'Table did not have a name: {op!r}')
result = self._quote_identifier(name)
is_subquery = False
else:
# A subquery
if ctx.is_extracted(ref_op):
Expand All @@ -118,10 +116,8 @@ def _format_table(self, op):

subquery = ctx.get_compiled_expr(op)
result = f'(\n{util.indent(subquery, self.indent)}\n)'
is_subquery = True

if is_subquery or ctx.need_aliases(op):
result += f' {ctx.get_ref(op)}'
result += f' {ctx.get_ref(op)}'

return result

Expand Down Expand Up @@ -302,12 +298,8 @@ def format_select_set(self):
if isinstance(node, ops.Value):
expr_str = self._translate(node, named=True)
elif isinstance(node, ops.TableNode):
# A * selection, possibly prefixed
if context.need_aliases(node):
alias = context.get_ref(node)
expr_str = f'{alias}.*' if alias else '*'
else:
expr_str = '*'
alias = context.get_ref(node)
expr_str = f'{alias}.*' if alias else '*'
else:
raise TypeError(node)
formatted.append(expr_str)
Expand Down
96 changes: 1 addition & 95 deletions ibis/backends/base/sql/compiler/select_builder.py
Expand Up @@ -18,79 +18,6 @@ class _LimitSpec(NamedTuple):
offset: int


class _CorrelatedRefCheck:
def __init__(self, query, node):
self.query = query
self.ctx = query.context
self.node = node
self.query_roots = frozenset(
an.find_immediate_parent_tables(self.query.table_set)
)
self.has_foreign_root = False
self.has_query_root = False
self.seen = set()

def get_result(self):
self.visit(self.node, in_subquery=False)
return self.has_query_root and self.has_foreign_root

def visit(self, node, in_subquery):
if node in self.seen:
return

in_subquery |= self.is_subquery(node)

for arg in node.args:
if isinstance(arg, ops.TableNode):
self.visit_table(arg, in_subquery=in_subquery)
elif isinstance(arg, ops.Node):
self.visit(arg, in_subquery=in_subquery)
elif isinstance(arg, tuple):
for item in arg:
self.visit(item, in_subquery=in_subquery)

self.seen.add(node)

def is_subquery(self, node):
return isinstance(
node,
(
ops.TableArrayView,
ops.ExistsSubquery,
ops.NotExistsSubquery,
),
) or (isinstance(node, ops.TableColumn) and not self.is_root(node.table))

def visit_table(self, node, in_subquery):
if isinstance(node, (ops.PhysicalTable, ops.SelfReference)):
self.ref_check(node, in_subquery=in_subquery)

for arg in node.args:
if isinstance(arg, tuple):
for item in arg:
self.visit(item, in_subquery=in_subquery)
elif isinstance(arg, ops.Node):
self.visit(arg, in_subquery=in_subquery)

def ref_check(self, node, in_subquery) -> None:
ctx = self.ctx

is_root = self.is_root(node)

self.has_query_root |= is_root and in_subquery
self.has_foreign_root |= not is_root and in_subquery

if (
not is_root
and not ctx.has_ref(node)
and (not in_subquery or ctx.has_ref(node, parent_contexts=True))
):
ctx.make_alias(node)

def is_root(self, what: ops.TableNode) -> bool:
return what in self.query_roots


def _get_scalar(field):
def scalar_handler(results):
return results[field][0]
Expand Down Expand Up @@ -149,11 +76,6 @@ def to_select(

return select_query

@staticmethod
def _foreign_ref_check(query, expr):
checker = _CorrelatedRefCheck(query, expr)
return checker.get_result()

@staticmethod
def _adapt_operation(node):
# Non-table expressions need to be adapted to some well-formed table
Expand All @@ -177,7 +99,7 @@ def _adapt_operation(node):
table_expr = node.table.to_expr()[[node.name]]
result_handler = _get_column(node.name)
else:
table_expr = node.to_expr().to_projection()
table_expr = node.to_expr().as_table()
result_handler = _get_column(node.name)

return table_expr.op(), result_handler
Expand Down Expand Up @@ -215,22 +137,6 @@ def _populate_context(self):
if self.table_set is not None:
self._make_table_aliases(self.table_set)

# XXX: This is a temporary solution to the table-aliasing / correlated
# subquery problem. Will need to revisit and come up with a cleaner
# design (also as one way to avoid pathological naming conflicts; for
# example, we could define a table alias before we know that it
# conflicts with the name of a table used in a subquery, join, or
# another part of the query structure)

# There may be correlated subqueries inside the filters, requiring that
# we use an explicit alias when outputting as SQL. For now, we're just
# going to see if any table nodes appearing in the where stack have
# been marked previously by the above code.
for expr in self.filters:
needs_alias = self._foreign_ref_check(self, expr)
if needs_alias:
self.context.set_always_alias()

# TODO(kszucs): should be rewritten using lin.traverse()
def _make_table_aliases(self, node):
ctx = self.context
Expand Down
22 changes: 12 additions & 10 deletions ibis/backends/base/sql/compiler/translator.py
Expand Up @@ -8,7 +8,6 @@
import ibis.common.exceptions as com
import ibis.expr.operations as ops
from ibis.backends.base.sql.registry import operation_registry, quote_identifier
from ibis.expr.types.core import unnamed


class QueryContext:
Expand All @@ -25,7 +24,7 @@ def __init__(self, compiler, indent=2, parent=None, params=None):
self.subquery_memo = {}
self.indent = indent
self.parent = parent
self.always_alias = False
self.always_alias = True
self.query = None
self.params = params if params is not None else {}

Expand Down Expand Up @@ -93,9 +92,6 @@ def make_alias(self, node):
alias = f't{i:d}'
self.set_ref(node, alias)

def need_aliases(self, expr=None):
return self.always_alias or len(self.table_refs) > 1

def _contexts(
self,
*,
Expand All @@ -115,14 +111,20 @@ def has_ref(self, node, parent_contexts=False):
def set_ref(self, node, alias):
self.table_refs[node] = alias

def get_ref(self, node):
def get_ref(self, node, search_parents=False):
"""Return the alias used to refer to an expression."""
assert isinstance(node, ops.Node)
assert isinstance(node, ops.Node), type(node)

if self.is_extracted(node):
return self.top_context.table_refs.get(node)

return self.table_refs.get(node)
if (ref := self.table_refs.get(node)) is not None:
return ref

if search_parents and (parent := self.parent) is not None:
return parent.get_ref(node, search_parents=search_parents)

return None

def is_extracted(self, node):
return node in self.top_context.extracted_subexprs
Expand Down Expand Up @@ -199,7 +201,7 @@ def _needs_name(self, op):
# This column has been given an explicitly different name
return False

return op.name is not unnamed
return bool(op.name)

def name(self, translated, name, force=True):
return f'{translated} AS {quote_identifier(name, force=force)}'
Expand Down Expand Up @@ -248,7 +250,7 @@ def _trans_param(self, op):
if dtype.is_struct():
literal = ibis.struct(raw_value, type=dtype)
elif dtype.is_map():
literal = ibis.map(raw_value, type=dtype)
literal = ibis.map(list(raw_value.keys()), list(raw_value.values()))
else:
literal = ibis.literal(raw_value, type=dtype)
return self.translate(literal.op())
Expand Down
21 changes: 10 additions & 11 deletions ibis/backends/base/sql/ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,17 +238,16 @@ def _pieces(self):
main_schema = self.schema
part_schema = self.partition
if not isinstance(part_schema, sch.Schema):
part_schema = sch.Schema(
part_schema, [self.schema[name] for name in part_schema]
)

to_delete = []
for name in self.partition:
if name in self.schema:
to_delete.append(name)

if len(to_delete):
main_schema = main_schema.delete(to_delete)
part_fields = {name: self.schema[name] for name in part_schema}
part_schema = sch.Schema(part_fields)

to_delete = {name for name in self.partition if name in self.schema}
fields = {
name: dtype
for name, dtype in main_schema.items()
if name not in to_delete
}
main_schema = sch.Schema(fields)

yield format_schema(main_schema)
yield f'PARTITIONED BY {format_schema(part_schema)}'
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/base/sql/registry/binary_infix.py
Expand Up @@ -72,7 +72,7 @@ def translate(translator, op):
ctx.is_foreign_expr(leaf)
for leaf in an.find_immediate_parent_tables(op.options)
):
array = op.options.to_expr().to_projection().to_array().op()
array = op.options.to_expr().as_table().to_array().op()
right = table_array_view(translator, array)
else:
right = translator.translate(op.options)
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/base/sql/registry/literal.py
Expand Up @@ -91,7 +91,7 @@ def literal(translator, op):
elif dtype.is_set():
typeclass = 'set'
else:
raise NotImplementedError
raise NotImplementedError(f'Unsupported type: {dtype!r}')

return literal_formatters[typeclass](translator, op)

Expand Down
20 changes: 7 additions & 13 deletions ibis/backends/base/sql/registry/main.py
Expand Up @@ -155,18 +155,17 @@ def table_column(translator, op):
proj_expr = op.table.to_expr().projection([op.name]).to_array().op()
return table_array_view(translator, proj_expr)

if ctx.need_aliases():
alias = ctx.get_ref(op.table)
if alias is not None:
quoted_name = f'{alias}.{quoted_name}'
alias = ctx.get_ref(op.table, search_parents=True)
if alias is not None:
quoted_name = f"{alias}.{quoted_name}"

return quoted_name


def exists_subquery(translator, op):
ctx = translator.context

dummy = ir.literal(1).name(ir.core.unnamed)
dummy = ir.literal(1).name("")

filtered = op.foreign_table.to_expr().filter(
[pred.to_expr() for pred in op.predicates]
Expand All @@ -175,14 +174,8 @@ def exists_subquery(translator, op):

subquery = ctx.get_compiled_expr(node)

if isinstance(op, ops.ExistsSubquery):
key = 'EXISTS'
elif isinstance(op, ops.NotExistsSubquery):
key = 'NOT EXISTS'
else:
raise NotImplementedError

return f'{key} (\n{util.indent(subquery, ctx.indent)}\n)'
prefix = "NOT " * isinstance(op, ops.NotExistsSubquery)
return f'{prefix}EXISTS (\n{util.indent(subquery, ctx.indent)}\n)'
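
The `"NOT " * isinstance(...)` idiom simply relies on bools being ints; for illustration:

print("NOT " * False + "EXISTS (...)")  # -> EXISTS (...)
print("NOT " * True + "EXISTS (...)")   # -> NOT EXISTS (...)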


# XXX this is not added to operation_registry, but looks like impala is
Expand Down Expand Up @@ -403,5 +396,6 @@ def count_star(translator, op):
ops.DayOfWeekName: timestamp.day_of_week_name,
ops.Strftime: timestamp.strftime,
ops.SortKey: sort_key,
ops.TypeOf: unary('typeof'),
**binary_infix_ops,
}
87 changes: 61 additions & 26 deletions ibis/backends/bigquery/__init__.py
Expand Up @@ -4,6 +4,7 @@

import contextlib
import warnings
from typing import TYPE_CHECKING, Any, Mapping
from urllib.parse import parse_qs, urlparse

import google.auth.credentials
Expand All @@ -17,7 +18,6 @@
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis.backends.base.sql import BaseSQLBackend
from ibis.backends.bigquery import version as ibis_bigquery_version
from ibis.backends.bigquery.client import (
BigQueryCursor,
BigQueryDatabase,
Expand All @@ -32,7 +32,8 @@
with contextlib.suppress(ImportError):
from ibis.backends.bigquery.udf import udf # noqa: F401

__version__: str = ibis_bigquery_version.__version__
if TYPE_CHECKING:
import pyarrow as pa

SCOPES = ["https://www.googleapis.com/auth/bigquery"]
EXTERNAL_DATA_SCOPES = [
Expand Down Expand Up @@ -72,7 +73,7 @@ def _from_url(self, url):
dataset_id=result.path[1:] or params.get("dataset_id", [""])[0],
)

def connect(
def do_connect(
self,
project_id: str | None = None,
dataset_id: str = "",
Expand All @@ -82,7 +83,7 @@ def connect(
auth_external_data: bool = False,
auth_cache: str = "default",
partition_column: str | None = "PARTITIONTIME",
) -> "Backend":
):
"""Create a :class:`Backend` for use with Ibis.
Parameters
Expand Down Expand Up @@ -164,22 +165,18 @@ def connect(

project_id = project_id or default_project_id

new_backend = self.__class__()

(
new_backend.data_project,
new_backend.billing_project,
new_backend.dataset,
self.data_project,
self.billing_project,
self.dataset,
) = parse_project_and_dataset(project_id, dataset_id)

new_backend.client = bq.Client(
project=new_backend.billing_project,
self.client = bq.Client(
project=self.billing_project,
credentials=credentials,
client_info=_create_client_info(application_name),
)
new_backend.partition_column = partition_column

return new_backend
self.partition_column = partition_column

def _parse_project_and_dataset(self, dataset) -> tuple[str, str]:
if not dataset and not self.dataset:
Expand Down Expand Up @@ -219,22 +216,16 @@ def _fully_qualified_name(self, name, database):
def _get_schema_using_query(self, query):
job_config = bq.QueryJobConfig(dry_run=True, use_query_cache=False)
job = self.client.query(query, job_config=job_config)
names, ibis_types = self._adapt_types(job.schema)
return sch.Schema(names, ibis_types)
fields = self._adapt_types(job.schema)
return sch.Schema(fields)

def _get_table_schema(self, qualified_name):
dataset, table = qualified_name.rsplit(".", 1)
assert dataset is not None, "dataset is None"
return self.get_schema(table, database=dataset)

def _adapt_types(self, descr):
names = []
adapted_types = []
for col in descr:
names.append(col.name)
typename = bigquery_field_to_ibis_dtype(col)
adapted_types.append(typename)
return names, adapted_types
return {col.name: bigquery_field_to_ibis_dtype(col) for col in descr}

def _execute(self, stmt, results=True, query_parameters=None):
job_config = bq.job.QueryJobConfig()
Expand Down Expand Up @@ -357,22 +348,66 @@ def exists_table(self, name: str, database: str | None = None) -> bool:
return True

def fetch_from_cursor(self, cursor, schema):
arrow_t = self._cursor_to_arrow(cursor)
df = arrow_t.to_pandas(timestamp_as_object=True)
return schema.apply_to(df)

def _cursor_to_arrow(self, cursor):
query = cursor.query
query_result = query.result()
# workaround potentially not having the ability to create read sessions
# in the dataset project
orig_project = query_result._project
query_result._project = self.billing_project
try:
arrow_t = query_result.to_arrow(
arrow_table = query_result.to_arrow(
progress_bar_type=None,
bqstorage_client=None,
create_bqstorage_client=True,
)
finally:
query_result._project = orig_project
df = arrow_t.to_pandas(timestamp_as_object=True)
return schema.apply_to(df)
return arrow_table

def to_pyarrow(
self,
expr: ir.Expr,
*,
params: Mapping[ir.Scalar, Any] | None = None,
limit: int | str | None = None,
**kwargs: Any,
) -> pa.Table:
self._import_pyarrow()
query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
sql = query_ast.compile()
cursor = self.raw_sql(sql, params=params, **kwargs)
table = self._cursor_to_arrow(cursor)
if isinstance(expr, ir.Scalar):
assert len(table.columns) == 1, "len(table.columns) != 1"
return table[0][0]
elif isinstance(expr, ir.Column):
assert len(table.columns) == 1, "len(table.columns) != 1"
return table[0]
else:
return table

def to_pyarrow_batches(
self,
expr: ir.Expr,
*,
params: Mapping[ir.Scalar, Any] | None = None,
limit: int | str | None = None,
chunk_size: int = 1_000_000,
**kwargs: Any,
):
self._import_pyarrow()

# kind of pointless, but it'll work if there's enough memory
query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
sql = query_ast.compile()
cursor = self.raw_sql(sql, params=params, **kwargs)
table = self._cursor_to_arrow(cursor)
return table.to_reader(chunk_size)
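
A hypothetical usage sketch of the two new Arrow entry points (the project, dataset, and column names below are placeholders):

import ibis

con = ibis.bigquery.connect(project_id="my-project", dataset_id="my_dataset")
t = con.table("functional_alltypes")

tbl = con.to_pyarrow(t.select("id", "double_col"))      # a pyarrow.Table
reader = con.to_pyarrow_batches(t, chunk_size=100_000)  # a RecordBatchReader
for batch in reader:
    print(batch.num_rows)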

def get_schema(self, name, database=None):
table_id = self._fully_qualified_name(name, database)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/bigquery/client.py
Expand Up @@ -48,7 +48,7 @@ def bigquery_field_to_ibis_dtype(field):
assert fields, "RECORD fields are empty"
names = [el.name for el in fields]
ibis_types = list(map(dt.dtype, fields))
ibis_type = dt.Struct(names, ibis_types)
ibis_type = dt.Struct(dict(zip(names, ibis_types)))
else:
ibis_type = _LEGACY_TO_STANDARD.get(typ, typ)
ibis_type = _DTYPE_TO_IBIS_TYPE.get(ibis_type, ibis_type)
Expand Down Expand Up @@ -119,7 +119,7 @@ def bigquery_param(dtype, value, name):

@bigquery_param.register
def bq_param_struct(dtype: dt.Struct, value, name):
fields = dtype.pairs
fields = dtype.fields
field_params = [bigquery_param(fields[k], v, k) for k, v in value.items()]
result = bq.StructQueryParameter(name, *field_params)
return result
Expand Down
8 changes: 7 additions & 1 deletion ibis/backends/bigquery/compiler.py
Expand Up @@ -22,7 +22,7 @@ def __init__(self, expr, context):

def compile(self):
"""Generate UDF string from definition."""
return self.expr.op().js
return self.expr.op().sql


class BigQueryUnion(sql_compiler.Union):
Expand Down Expand Up @@ -52,6 +52,8 @@ def keyword(cls, distinct):

def find_bigquery_udf(op):
"""Filter which includes only UDFs from expression tree."""
if type(op) in BigQueryExprTranslator._rewrites:
op = BigQueryExprTranslator._rewrites[type(op)](op)
if isinstance(op, operations.BigQueryUDFNode):
result = op
else:
Expand Down Expand Up @@ -118,3 +120,7 @@ def _generate_setup_queries(expr, context):
# UDFs are uniquely identified by the name of the Node subclass we
# generate.
return list(toolz.unique(queries, key=lambda x: type(x.expr.op()).__name__))


# Register custom UDFs
import ibis.backends.bigquery.custom_udfs # noqa: F401, E402
39 changes: 39 additions & 0 deletions ibis/backends/bigquery/custom_udfs.py
@@ -0,0 +1,39 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.bigquery.compiler import BigQueryExprTranslator
from ibis.backends.bigquery.udf import udf

# Based on:
# https://github.com/GoogleCloudPlatform/bigquery-utils/blob/45e1ac51367ab6209f68e04b1660d5b00258c131/udfs/community/typeof.sqlx#L1
typeof_ = udf.sql(
name="typeof",
params={"input": 'ANY TYPE'},
output_type=dt.str,
sql_expression=r"""
(
SELECT
CASE
-- Process NUMERIC, DATE, DATETIME, TIME, TIMESTAMP,
WHEN REGEXP_CONTAINS(literal, r'^[A-Z]+ "') THEN REGEXP_EXTRACT(literal, r'^([A-Z]+) "')
WHEN REGEXP_CONTAINS(literal, r'^-?[0-9]*$') THEN 'INT64'
WHEN
REGEXP_CONTAINS(literal, r'^(-?[0-9]+[.e].*|CAST\("([^"]*)" AS FLOAT64\))$')
THEN
'FLOAT64'
WHEN literal IN ('true', 'false') THEN 'BOOL'
WHEN literal LIKE '"%' OR literal LIKE "'%" THEN 'STRING'
WHEN literal LIKE 'b"%' THEN 'BYTES'
WHEN literal LIKE '[%' THEN 'ARRAY'
WHEN REGEXP_CONTAINS(literal, r'^(STRUCT)?\(') THEN 'STRUCT'
WHEN literal LIKE 'ST_%' THEN 'GEOGRAPHY'
WHEN literal = 'NULL' THEN 'NULL'
ELSE
'UNKNOWN'
END
FROM
UNNEST([FORMAT('%T', input)]) AS literal
)
""",
)

BigQueryExprTranslator.rewrites(ops.TypeOf)(lambda op: typeof_(op.arg).op())
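
The same `udf.sql` helper can be used to define other SQL-bodied functions; a made-up example (the name, parameter, and SQL body below are purely illustrative):

import ibis.expr.datatypes as dt
from ibis.backends.bigquery.udf import udf

reverse_words = udf.sql(
    name="reverse_words",
    params={"s": dt.string},
    output_type=dt.string,
    sql_expression="ARRAY_TO_STRING(ARRAY_REVERSE(SPLIT(s, ' ')), ' ')",
)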
106 changes: 40 additions & 66 deletions ibis/backends/bigquery/datatypes.py
Expand Up @@ -4,22 +4,6 @@

import ibis.expr.datatypes as dt


class TypeTranslationContext:
"""A tag class to alter the way a type is translated.
This is used to raise an exception when INT64 types are encountered to
avoid suprising results due to BigQuery's handling of INT64 types in
JavaScript UDFs.
"""

__slots__ = ()


class UDFContext(TypeTranslationContext):
__slots__ = ()


ibis_type_to_bigquery_type = Dispatcher("ibis_type_to_bigquery_type")


Expand All @@ -28,81 +12,60 @@ def trans_string_default(datatype):
return ibis_type_to_bigquery_type(dt.dtype(datatype))


@ibis_type_to_bigquery_type.register(dt.DataType)
def trans_default(t):
return ibis_type_to_bigquery_type(t, TypeTranslationContext())


@ibis_type_to_bigquery_type.register(str, TypeTranslationContext)
def trans_string_context(datatype, context):
return ibis_type_to_bigquery_type(dt.dtype(datatype), context)


@ibis_type_to_bigquery_type.register(dt.Floating, TypeTranslationContext)
def trans_float64(t, context):
@ibis_type_to_bigquery_type.register(dt.Floating)
def trans_float64(t):
return "FLOAT64"


@ibis_type_to_bigquery_type.register(dt.Integer, TypeTranslationContext)
def trans_integer(t, context):
@ibis_type_to_bigquery_type.register(dt.Integer)
def trans_integer(t):
return "INT64"


@ibis_type_to_bigquery_type.register(dt.Binary, TypeTranslationContext)
def trans_binary(t, context):
@ibis_type_to_bigquery_type.register(dt.Binary)
def trans_binary(t):
return "BYTES"


@ibis_type_to_bigquery_type.register(dt.UInt64, (TypeTranslationContext, UDFContext))
def trans_lossy_integer(t, context):
@ibis_type_to_bigquery_type.register(dt.UInt64)
def trans_lossy_integer(t):
raise TypeError("Conversion from uint64 to BigQuery integer type (int64) is lossy")


@ibis_type_to_bigquery_type.register(dt.Array, TypeTranslationContext)
def trans_array(t, context):
return f"ARRAY<{ibis_type_to_bigquery_type(t.value_type, context)}>"
@ibis_type_to_bigquery_type.register(dt.Array)
def trans_array(t):
return f"ARRAY<{ibis_type_to_bigquery_type(t.value_type)}>"


@ibis_type_to_bigquery_type.register(dt.Struct, TypeTranslationContext)
def trans_struct(t, context):
@ibis_type_to_bigquery_type.register(dt.Struct)
def trans_struct(t):
return "STRUCT<{}>".format(
", ".join(
f"{name} {ibis_type_to_bigquery_type(dt.dtype(type), context)}"
for name, type in zip(t.names, t.types)
f"{name} {ibis_type_to_bigquery_type(dt.dtype(type_))}"
for name, type_ in t.fields.items()
)
)


@ibis_type_to_bigquery_type.register(dt.Date, TypeTranslationContext)
def trans_date(t, context):
@ibis_type_to_bigquery_type.register(dt.Date)
def trans_date(t):
return "DATE"


@ibis_type_to_bigquery_type.register(dt.Timestamp, TypeTranslationContext)
def trans_timestamp(t, context):
@ibis_type_to_bigquery_type.register(dt.Timestamp)
def trans_timestamp(t):
if t.timezone is not None:
raise TypeError("BigQuery does not support timestamps with timezones")
return "TIMESTAMP"


@ibis_type_to_bigquery_type.register(dt.DataType, TypeTranslationContext)
def trans_type(t, context):
@ibis_type_to_bigquery_type.register(dt.DataType)
def trans_type(t):
return str(t).upper()


@ibis_type_to_bigquery_type.register(dt.Integer, UDFContext)
def trans_integer_udf(t, context):
# JavaScript does not have integers, only a Number class. BigQuery doesn't
# behave as expected with INT64 inputs or outputs
raise TypeError(
"BigQuery does not support INT64 as an argument type or a return type "
"for UDFs. Replace INT64 with FLOAT64 in your UDF signature and "
"cast all INT64 inputs to FLOAT64."
)


@ibis_type_to_bigquery_type.register(dt.Decimal, TypeTranslationContext)
def trans_numeric(t, context):
@ibis_type_to_bigquery_type.register(dt.Decimal)
def trans_numeric(t):
if (t.precision, t.scale) != (38, 9):
raise TypeError(
"BigQuery only supports decimal types with precision of 38 and "
Expand All @@ -111,11 +74,22 @@ def trans_numeric(t, context):
return "NUMERIC"


@ibis_type_to_bigquery_type.register(dt.Decimal, UDFContext)
def trans_numeric_udf(t, context):
raise TypeError("Decimal types are not supported in BigQuery UDFs")
@ibis_type_to_bigquery_type.register(dt.JSON)
def trans_json(t):
return "JSON"


@ibis_type_to_bigquery_type.register(dt.JSON, TypeTranslationContext)
def trans_json(t, context):
return "JSON"
def spread_type(dt: dt.DataType):
"""Returns a generator that contains all the types in the given type.
For complex types like set and array, it returns the types of the elements.
"""
if dt.is_array():
yield from spread_type(dt.value_type)
elif dt.is_struct():
for type_ in dt.types:
yield from spread_type(type_)
elif dt.is_map():
yield from spread_type(dt.key_type)
yield from spread_type(dt.value_type)
yield dt
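
A quick sketch of what `spread_type` yields for a nested type (the dict-based `Struct` constructor matches the style used elsewhere in this PR):

import ibis.expr.datatypes as dt
from ibis.backends.bigquery.datatypes import spread_type

typ = dt.Array(dt.Struct({"a": dt.int64, "b": dt.Map(dt.string, dt.float64)}))
for component in spread_type(typ):
    print(component)
# leaf types come first, then each containing type, ending with the array itself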
26 changes: 26 additions & 0 deletions ibis/backends/bigquery/registry.py
Expand Up @@ -239,6 +239,12 @@ def _literal(translator, op):
return "FROM_BASE64('{}')".format(
base64.b64encode(value).decode(encoding="utf-8")
)
elif dtype.is_struct():
cols = (
f'{translator.translate(ops.Literal(op.value[name], dtype=type_))} AS {name}'
for name, type_ in zip(dtype.names, dtype.types)
)
return "STRUCT({})".format(", ".join(cols))

try:
return literal(translator, op)
Expand Down Expand Up @@ -478,6 +484,11 @@ def _nullifzero(t, op):
return f"NULLIF({t.translate(op.arg)}, {casted})"


def _zeroifnull(t, op):
casted = bigquery_cast('0', op.output_dtype)
return f"COALESCE({t.translate(op.arg)}, {casted})"


def _array_agg(t, op):
arg = op.arg
if (where := op.where) is not None:
Expand Down Expand Up @@ -561,6 +572,7 @@ def _nth_value(t, op):
ops.IfNull: fixed_arity("IFNULL", 2),
ops.NullIf: fixed_arity("NULLIF", 2),
ops.NullIfZero: _nullifzero,
ops.ZeroIfNull: _zeroifnull,
ops.NotAny: bigquery_compile_notany,
ops.NotAll: bigquery_compile_notall,
# Reductions
Expand All @@ -575,6 +587,12 @@ def _nth_value(t, op):
ops.Clip: _clip,
ops.Degrees: lambda t, op: f"(180 * {t.translate(op.arg)} / ACOS(-1))",
ops.Radians: lambda t, op: f"(ACOS(-1) * {t.translate(op.arg)} / 180)",
ops.BitwiseNot: lambda t, op: f"~ {t.translate(op.arg)}",
ops.BitwiseXor: lambda t, op: f"{t.translate(op.left)} ^ {t.translate(op.right)}",
ops.BitwiseOr: lambda t, op: f"{t.translate(op.left)} | {t.translate(op.right)}",
ops.BitwiseAnd: lambda t, op: f"{t.translate(op.left)} & {t.translate(op.right)}",
ops.BitwiseLeftShift: lambda t, op: f"{t.translate(op.left)} << {t.translate(op.right)}",
ops.BitwiseRightShift: lambda t, op: f"{t.translate(op.left)} >> {t.translate(op.right)}",
# Temporal functions
ops.Date: unary("DATE"),
ops.DateFromYMD: fixed_arity("DATE", 3),
Expand Down Expand Up @@ -694,6 +712,14 @@ def _nth_value(t, op):
ops.FindInSet,
ops.DateDiff,
ops.TimestampDiff,
ops.ExtractAuthority,
ops.ExtractFile,
ops.ExtractFragment,
ops.ExtractHost,
ops.ExtractPath,
ops.ExtractProtocol,
ops.ExtractQuery,
ops.ExtractUserInfo,
}

OPERATION_REGISTRY = {
Expand Down
5 changes: 3 additions & 2 deletions ibis/backends/bigquery/tests/conftest.py
Expand Up @@ -49,7 +49,7 @@ def _(typ: dt.Struct) -> Mapping[str, Any]:
return {
"field_type": "RECORD",
"mode": "NULLABLE" if typ.nullable else "REQUIRED",
"fields": ibis_schema_to_bq_schema(ibis.schema(typ.pairs)),
"fields": ibis_schema_to_bq_schema(ibis.schema(typ.fields)),
}


Expand Down Expand Up @@ -134,7 +134,8 @@ def _load_data(data_dir: Path, script_dir: Path, **_: Any) -> None:
make_job = lambda func, *a, **kw: func(*a, **kw).result()

futures = []
with concurrent.futures.ThreadPoolExecutor() as e:
# 10 is because of urllib3 connection pool size
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as e:
futures.append(
e.submit(
make_job,
Expand Down
18 changes: 9 additions & 9 deletions ibis/backends/bigquery/tests/system/test_client.py
Expand Up @@ -114,13 +114,13 @@ def test_different_partition_col_name(monkeypatch, client):

def test_subquery_scalar_params(alltypes, project_id, dataset_id):
expected = f"""\
SELECT count\\(`foo`\\) AS `count`
SELECT count\\(t0\\.`foo`\\) AS `count`
FROM \\(
SELECT `string_col`, sum\\(`float_col`\\) AS `foo`
SELECT t1\\.`string_col`, sum\\(t1\\.`float_col`\\) AS `foo`
FROM \\(
SELECT `float_col`, `timestamp_col`, `int_col`, `string_col`
FROM `{project_id}\\.{dataset_id}\\.functional_alltypes`
WHERE `timestamp_col` < @param_\\d+
SELECT t2\\.`float_col`, t2\\.`timestamp_col`, t2\\.`int_col`, t2\\.`string_col`
FROM `{project_id}\\.{dataset_id}\\.functional_alltypes` t2
WHERE t2\\.`timestamp_col` < @param_\\d+
\\) t1
GROUP BY 1
\\) t0"""
Expand Down Expand Up @@ -225,11 +225,11 @@ def test_cross_project_query(public):
expr = table[table.tags.contains("ibis")][["title", "tags"]]
result = expr.compile()
expected = """\
SELECT `title`, `tags`
SELECT t0.`title`, t0.`tags`
FROM (
SELECT *
FROM `bigquery-public-data.stackoverflow.posts_questions`
WHERE STRPOS(`tags`, 'ibis') - 1 >= 0
SELECT t1.*
FROM `bigquery-public-data.stackoverflow.posts_questions` t1
WHERE STRPOS(t1.`tags`, 'ibis') - 1 >= 0
) t0"""
assert result == expected
n = 5
Expand Down
38 changes: 35 additions & 3 deletions ibis/backends/bigquery/tests/system/udf/test_udf_execute.py
Expand Up @@ -3,6 +3,7 @@
import pandas as pd
import pandas.testing as tm
import pytest
from pytest import param

import ibis
import ibis.expr.datatypes as dt
Expand All @@ -25,7 +26,11 @@ def df(alltypes):


def test_udf(alltypes, df):
@udf(input_type=[dt.double, dt.double], output_type=dt.double)
@udf(
input_type=[dt.double, dt.double],
output_type=dt.double,
determinism=True,
)
def my_add(a, b):
return a + b

Expand Down Expand Up @@ -56,7 +61,7 @@ def __init__(self, width, height):

return Rectangle(a, b)

result = my_struct_thing.js
result = my_struct_thing.sql
snapshot.assert_match(result, "out.sql")

expr = my_struct_thing(alltypes.double_col, alltypes.double_col)
Expand Down Expand Up @@ -104,7 +109,7 @@ def my_str_len(s):
add = expr.op()

# generated javascript is identical
assert add.left.op().js == add.right.op().js
assert add.left.op().sql == add.right.op().sql
assert client.execute(expr) == 8.0


Expand Down Expand Up @@ -138,3 +143,30 @@ def my_array_len(x):

assert client.execute(my_str_len("aaa")) == 3
assert client.execute(my_array_len(["aaa", "bb"])) == 2


@pytest.mark.parametrize(
("argument_type",),
[
param(
dt.string,
id="string",
),
param(
"ANY TYPE",
id="string",
),
],
)
def test_udf_sql(client, argument_type):
format_t = udf.sql(
"format_t",
params={'input': argument_type},
output_type=dt.string,
sql_expression="FORMAT('%T', input)",
)

s = ibis.literal("abcd")
expr = format_t(s)

client.execute(expr)
@@ -1,2 +1,2 @@
SELECT APPROX_QUANTILES(if(`month` > 0, `double_col`, NULL), 2)[OFFSET(1)] AS `tmp`
FROM functional_alltypes
SELECT APPROX_QUANTILES(if(t0.`month` > 0, t0.`double_col`, NULL), 2)[OFFSET(1)] AS `tmp`
FROM functional_alltypes t0
@@ -1,2 +1,2 @@
SELECT APPROX_COUNT_DISTINCT(if(`month` > 0, `double_col`, NULL)) AS `tmp`
FROM functional_alltypes
SELECT APPROX_COUNT_DISTINCT(if(t0.`month` > 0, t0.`double_col`, NULL)) AS `tmp`
FROM functional_alltypes t0
@@ -1,2 +1,2 @@
SELECT APPROX_QUANTILES(`double_col`, 2)[OFFSET(1)] AS `tmp`
FROM functional_alltypes
SELECT APPROX_QUANTILES(t0.`double_col`, 2)[OFFSET(1)] AS `tmp`
FROM functional_alltypes t0
@@ -1,2 +1,2 @@
SELECT APPROX_COUNT_DISTINCT(`double_col`) AS `tmp`
FROM functional_alltypes
SELECT APPROX_COUNT_DISTINCT(t0.`double_col`) AS `tmp`
FROM functional_alltypes t0
@@ -1,2 +1,2 @@
SELECT CAST(`value` AS BYTES) AS `tmp`
FROM t
SELECT CAST(t0.`value` AS BYTES) AS `tmp`
FROM t t0
@@ -1,2 +1,2 @@
SELECT BIT_AND(if(`bigint_col` > 0, `int_col`, NULL)) AS `tmp`
FROM functional_alltypes
SELECT BIT_AND(if(t0.`bigint_col` > 0, t0.`int_col`, NULL)) AS `tmp`
FROM functional_alltypes t0
@@ -1,2 +1,2 @@
SELECT BIT_OR(if(`bigint_col` > 0, `int_col`, NULL)) AS `tmp`
FROM functional_alltypes
SELECT BIT_OR(if(t0.`bigint_col` > 0, t0.`int_col`, NULL)) AS `tmp`
FROM functional_alltypes t0
@@ -1,2 +1,2 @@
SELECT BIT_XOR(if(`bigint_col` > 0, `int_col`, NULL)) AS `tmp`
FROM functional_alltypes
SELECT BIT_XOR(if(t0.`bigint_col` > 0, t0.`int_col`, NULL)) AS `tmp`
FROM functional_alltypes t0