83 changes: 32 additions & 51 deletions ibis/backends/dask/tests/conftest.py
@@ -1,20 +1,16 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import Any

import dask
import pandas as pd
import pandas.testing as tm
import pytest

import ibis
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.pandas.tests.conftest import TestConf as PandasTest
from ibis.backends.tests.data import array_types, win

if TYPE_CHECKING:
from pathlib import Path

dd = pytest.importorskip("dask.dataframe")
from ibis.backends.tests.data import array_types, json_types, win

# FIXME Dask issue with non deterministic groupby results, relates to the
# shuffle method on a local cluster. Manually setting the shuffle method
@@ -32,52 +28,35 @@ def npartitions():

class TestConf(PandasTest):
supports_structs = False
deps = ("dask.dataframe",)

@staticmethod
def connect(data_directory: Path):
# Note - we use `dd.from_pandas(pd.read_csv(...))` instead of
# `dd.read_csv` due to https://github.com/dask/dask/issues/6970

return ibis.dask.connect(
{
"functional_alltypes": dd.from_pandas(
pd.read_parquet(
data_directory / "parquet" / "functional_alltypes.parquet"
),
npartitions=NPARTITIONS,
),
"batting": dd.from_pandas(
pd.read_parquet(data_directory / "parquet" / "batting.parquet"),
npartitions=NPARTITIONS,
),
"awards_players": dd.from_pandas(
pd.read_parquet(
data_directory / "parquet" / "awards_players.parquet"
),
npartitions=NPARTITIONS,
),
'diamonds': dd.from_pandas(
pd.read_parquet(data_directory / "parquet" / "diamonds.parquet"),
npartitions=NPARTITIONS,
),
'json_t': dd.from_pandas(
pd.DataFrame(
{
"js": [
'{"a": [1,2,3,4], "b": 1}',
'{"a":null,"b":2}',
'{"a":"foo", "c":null}',
"null",
"[42,47,55]",
"[]",
]
}
),
npartitions=NPARTITIONS,
),
"win": dd.from_pandas(win, npartitions=NPARTITIONS),
"array_types": dd.from_pandas(array_types, npartitions=NPARTITIONS),
}
def connect(*, tmpdir, worker_id, **kw):
return ibis.dask.connect(**kw)

def _load_data(self, **_: Any) -> None:
import dask.dataframe as dd

con = self.connection
for table_name in TEST_TABLES:
path = self.data_dir / "parquet" / f"{table_name}.parquet"
con.create_table(
table_name,
dd.from_pandas(pd.read_parquet(path), npartitions=NPARTITIONS),
)

con.create_table(
"array_types",
dd.from_pandas(array_types, npartitions=NPARTITIONS),
overwrite=True,
)
con.create_table(
"win", dd.from_pandas(win, npartitions=NPARTITIONS), overwrite=True
)
con.create_table(
"json_t",
dd.from_pandas(json_types, npartitions=NPARTITIONS),
overwrite=True,
)

@classmethod
@@ -93,6 +72,8 @@ def assert_series_equal(

@pytest.fixture
def dataframe(npartitions):
dd = pytest.importorskip("dask.dataframe")

return dd.from_pandas(
pd.DataFrame(
{
8 changes: 4 additions & 4 deletions ibis/backends/dask/tests/execution/conftest.py
@@ -62,8 +62,8 @@ def df(npartitions):


@pytest.fixture(scope='module')
def batting_df(data_directory):
df = dd.read_parquet(data_directory / 'parquet' / 'batting.parquet')
def batting_df(data_dir):
df = dd.read_parquet(data_dir / 'parquet' / 'batting.parquet')
# Dask dataframe thinks the columns are of type int64,
# but when computed they are all float64.
non_float_cols = ['playerID', 'yearID', 'stint', 'teamID', 'lgID', 'G']
@@ -73,8 +73,8 @@ def batting_df(data_directory):


@pytest.fixture(scope='module')
def awards_players_df(data_directory):
return dd.read_parquet(data_directory / 'parquet' / 'awards_players.parquet')
def awards_players_df(data_dir):
return dd.read_parquet(data_dir / 'parquet' / 'awards_players.parquet')


@pytest.fixture(scope='module')
43 changes: 16 additions & 27 deletions ibis/backends/datafusion/tests/conftest.py
@@ -1,17 +1,13 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any

import pytest

import ibis
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.tests.base import BackendTest, RoundAwayFromZero

if TYPE_CHECKING:
from pathlib import Path

pa = pytest.importorskip("pyarrow")


class TestConf(BackendTest, RoundAwayFromZero):
# check_names = False
@@ -20,35 +16,28 @@ class TestConf(BackendTest, RoundAwayFromZero):
supports_structs = False
supports_json = False
supports_arrays = False
stateful = False
deps = ("datafusion",)

def _load_data(self, **_: Any) -> None:
con = self.connection
for table_name in TEST_TABLES:
path = self.data_dir / "parquet" / f"{table_name}.parquet"
con.register(path, table_name=table_name)

@staticmethod
def connect(data_directory: Path):
client = ibis.datafusion.connect({})
client.register(
data_directory / "parquet" / 'functional_alltypes.parquet',
table_name='functional_alltypes',
)
client.register(
data_directory / "parquet" / 'batting.parquet', table_name='batting'
)
client.register(
data_directory / "parquet" / 'awards_players.parquet',
table_name='awards_players',
)
client.register(
data_directory / "parquet" / 'diamonds.parquet', table_name='diamonds'
)
return client
def connect(*, tmpdir, worker_id, **kw):
return ibis.datafusion.connect(**kw)


@pytest.fixture(scope='session')
def client(data_directory):
return TestConf.connect(data_directory)
def con(data_dir, tmp_path_factory, worker_id):
return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection


@pytest.fixture(scope='session')
def alltypes(client):
return client.table("functional_alltypes")
def alltypes(con):
return con.table("functional_alltypes")


@pytest.fixture(scope='session')
8 changes: 4 additions & 4 deletions ibis/backends/datafusion/tests/test_register.py
@@ -13,13 +13,13 @@ def conn():
return ibis.datafusion.connect()


def test_read_csv(conn, data_directory):
t = conn.read_csv(data_directory / "csv" / "functional_alltypes.csv")
def test_read_csv(conn, data_dir):
t = conn.read_csv(data_dir / "csv" / "functional_alltypes.csv")
assert t.count().execute()


def test_read_parquet(conn, data_directory):
t = conn.read_parquet(data_directory / "parquet" / "functional_alltypes.parquet")
def test_read_parquet(conn, data_dir):
t = conn.read_parquet(data_dir / "parquet" / "functional_alltypes.parquet")
assert t.count().execute()


42 changes: 13 additions & 29 deletions ibis/backends/druid/tests/conftest.py
@@ -6,17 +6,13 @@
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from itertools import chain, repeat
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Iterable

import pytest
from requests import Session

import ibis
from ibis.backends.tests.base import (
RoundHalfToEven,
ServiceBackendTest,
ServiceSpec,
)
from ibis.backends.tests.base import RoundHalfToEven, ServiceBackendTest

if TYPE_CHECKING:
from pathlib import Path
@@ -102,17 +98,14 @@ class TestConf(ServiceBackendTest, RoundHalfToEven):
native_bool = True
supports_structs = False
supports_json = False # it does, but we haven't implemented it
service_name = "druid-middlemanager"
deps = ("pydruid.db.sqlalchemy",)

@classmethod
def service_spec(cls, data_dir: Path) -> ServiceSpec:
return ServiceSpec(
name="druid-middlemanager",
data_volume="/data",
files=data_dir.joinpath("parquet").glob("*.parquet"),
)
@property
def test_files(self) -> Iterable[Path]:
return self.data_dir.joinpath("parquet").glob("*.parquet")

@staticmethod
def _load_data(data_dir: Path, script_dir: Path, **_: Any) -> None:
def _load_data(self, **_: Any) -> None:
"""Load test data into a druid backend instance.
Parameters
@@ -122,28 +115,19 @@ def _load_data(data_dir: Path, script_dir: Path, **_: Any) -> None:
script_dir
Location of scripts defining schemas
"""
# copy data into the volume mount
queries = filter(
None,
map(
str.strip,
(script_dir / "schema" / "druid.sql").read_text().split(";"),
),
)

# run queries concurrently using threads; lots of time is spent on IO
# making requests to check whether data loading is complete
with Session() as session, ThreadPoolExecutor() as executor:
for fut in as_completed(
executor.submit(run_query, session, query) for query in queries
executor.submit(run_query, session, query) for query in self.ddl_script
):
fut.result()

@staticmethod
def connect(_: Path):
return ibis.connect(DRUID_URL)
def connect(*, tmpdir, worker_id, **kw):
return ibis.connect(DRUID_URL, **kw)


@pytest.fixture(scope='session')
def con():
return ibis.connect(DRUID_URL)
def con(data_dir, tmp_path_factory, worker_id):
return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection
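The Druid changes above show the new shape for container-backed backends: the `ServiceSpec` indirection is replaced by `service_name`/`data_volume` class attributes plus a `test_files` property, and the shared `preload` hook copies those files into the running service. A condensed sketch of that pattern for a hypothetical service-backed backend follows; the class name, service name, and connection details are illustrative only, not part of this PR.

from __future__ import annotations

from pathlib import Path
from typing import Iterable

import ibis
from ibis.backends.tests.base import ServiceBackendTest


class ExampleServiceTestConf(ServiceBackendTest):
    # compose service that ServiceBackendTest.preload copies test_files into
    service_name = "example-service"
    # destination directory inside the container; "/data" is the base-class default
    data_volume = "/data"
    # each entry is import-checked via pytest.importorskip before any data loading
    deps = ("sqlalchemy",)

    @property
    def test_files(self) -> Iterable[Path]:
        # every path yielded here is shipped into the service during preload()
        return self.data_dir.joinpath("csv").glob("*.csv")

    @staticmethod
    def connect(*, tmpdir, worker_id, **kw):
        # illustrative connection; a real backend calls its own ibis.<backend>.connect
        return ibis.postgres.connect(
            host="localhost",
            user="ibis",
            password="ibis",
            database="ibis_testing",
            **kw,
        )

The inherited `_load_data` runs the statements yielded by `ddl_script` against `self.connection`, so a subclass like this only overrides it when it needs bulk loaders or CSV ingestion on top of the DDL.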
72 changes: 26 additions & 46 deletions ibis/backends/duckdb/tests/conftest.py
@@ -1,68 +1,48 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Iterator

import pytest

import ibis
from ibis import util
from ibis.backends.conftest import SANDBOXED, TEST_TABLES
from ibis.backends.tests.base import BackendTest, RoundAwayFromZero

if TYPE_CHECKING:
from pathlib import Path

from ibis.backends.base import BaseBackend


class TestConf(BackendTest, RoundAwayFromZero):
supports_map = True
deps = "duckdb", "duckdb_engine"
stateful = False

def __init__(self, data_directory: Path, **kwargs: Any) -> None:
self.connection = con = self.connect(data_directory, **kwargs)

def preload(self):
if not SANDBOXED:
con._load_extensions(["httpfs", "postgres_scanner", "sqlite_scanner"])

script_dir = data_directory.parent

parquet_dir = data_directory / "parquet"
stmts = [
f"""
CREATE OR REPLACE TABLE {table} AS
SELECT * FROM read_parquet('{parquet_dir / f'{table}.parquet'}')
"""
for table in TEST_TABLES
]
stmts.extend(
stripped_stmt
for stmt in script_dir.joinpath("schema", "duckdb.sql")
.read_text()
.split(";")
if (stripped_stmt := stmt.strip())
)
with con.begin() as c:
util.consume(map(c.exec_driver_sql, stmts))
self.connection._load_extensions(
["httpfs", "postgres_scanner", "sqlite_scanner"]
)

@property
def ddl_script(self) -> Iterator[str]:
parquet_dir = self.data_dir / "parquet"
for table in TEST_TABLES:
yield (
f"""
CREATE OR REPLACE TABLE {table} AS
SELECT * FROM read_parquet('{parquet_dir / f'{table}.parquet'}')
"""
)
yield from super().ddl_script

@staticmethod
def _load_data(data_dir, script_dir, **_: Any) -> None:
"""Load test data into a DuckDB backend instance.
Parameters
----------
data_dir
Location of test data
"""
return TestConf(data_directory=data_dir)

@staticmethod
def connect(data_directory: Path, **kwargs: Any) -> BaseBackend:
pytest.importorskip("duckdb")
return ibis.duckdb.connect(**kwargs) # type: ignore
def connect(*, tmpdir, worker_id, **kw) -> BaseBackend:
# extension directory per test worker to prevent simultaneous downloads
return ibis.duckdb.connect(
extension_directory=str(tmpdir.mktemp(f"{worker_id}_exts")), **kw
)


@pytest.fixture(scope="session")
def con(data_directory, tmp_path_factory, worker_id):
return TestConf(
data_directory, extension_directory=str(tmp_path_factory.mktemp(worker_id))
).connection
def con(data_dir, tmp_path_factory, worker_id):
return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection
20 changes: 10 additions & 10 deletions ibis/backends/duckdb/tests/test_register.py
@@ -18,14 +18,14 @@
from ibis.backends.conftest import LINUX, SANDBOXED


def test_read_csv(data_directory):
t = ibis.read_csv(data_directory / "csv" / "functional_alltypes.csv")
def test_read_csv(data_dir):
t = ibis.read_csv(data_dir / "csv" / "functional_alltypes.csv")
assert t.count().execute()


def test_read_csv_with_columns(data_directory):
def test_read_csv_with_columns(data_dir):
t = ibis.read_csv(
data_directory / "csv" / "awards_players.csv",
data_dir / "csv" / "awards_players.csv",
header=True,
columns={
'playerID': 'VARCHAR',
@@ -40,16 +40,16 @@ def test_read_csv_with_columns(data_directory):
assert t.count().execute()


def test_read_parquet(data_directory):
t = ibis.read_parquet(data_directory / "parquet" / "functional_alltypes.parquet")
def test_read_parquet(data_dir):
t = ibis.read_parquet(data_dir / "parquet" / "functional_alltypes.parquet")
assert t.count().execute()


@pytest.mark.xfail_version(
duckdb=["duckdb<0.7.0"], reason="read_json_auto doesn't exist", raises=exc.IbisError
)
def test_read_json(data_directory, tmp_path):
pqt = ibis.read_parquet(data_directory / "parquet" / "functional_alltypes.parquet")
def test_read_json(data_dir, tmp_path):
pqt = ibis.read_parquet(data_dir / "parquet" / "functional_alltypes.parquet")

path = tmp_path.joinpath("ft.json")
path.write_text(pqt.execute().to_json(orient="records", lines=True))
@@ -161,13 +161,13 @@ def test_register_sqlite(con, tmp_path):
reason="nix on linux cannot download duckdb extensions or data due to sandboxing",
raises=duckdb.IOException,
)
def test_attach_sqlite(data_directory, tmp_path):
def test_attach_sqlite(data_dir, tmp_path):
import sqlite3

test_db_path = tmp_path / "test.db"
with sqlite3.connect(test_db_path) as scon:
for line in (
Path(data_directory.parent / "schema" / "sqlite.sql").read_text().split(";")
Path(data_dir.parent / "schema" / "sqlite.sql").read_text().split(";")
):
scon.execute(line)

87 changes: 26 additions & 61 deletions ibis/backends/impala/tests/conftest.py
@@ -31,9 +31,9 @@ class TestConf(UnorderedComparator, BackendTest, RoundAwayFromZero):
returned_timestamp_unit = 's'
supports_structs = False
supports_json = False
deps = "fsspec", "requests", "impala"

@staticmethod
def _load_data(data_dir: Path, script_dir: Path, **_: Any) -> None:
def _load_data(self, **_: Any) -> None:
"""Load test data into an Impala backend instance.
Parameters
@@ -43,10 +43,11 @@ def _load_data(data_dir: Path, script_dir: Path, **_: Any) -> None:
script_dir
Location of scripts defining schemas
"""
fsspec = pytest.importorskip("fsspec")
import fsspec

fs = fsspec.filesystem("file")

data_files = fs.find(data_dir / "impala")
data_files = fs.find(self.data_dir / "impala")

# without setting the pool size
# connections are dropped from the urllib3
@@ -56,19 +57,9 @@

env = IbisTestEnv()
futures = []
with contextlib.closing(
ibis.impala.connect(
host=env.impala_host,
port=env.impala_port,
hdfs_client=fsspec.filesystem(
env.hdfs_protocol,
host=env.nn_host,
port=env.hdfs_port,
user=env.hdfs_user,
),
pool_size=URLLIB_DEFAULT_POOL_SIZE,
)
) as con, concurrent.futures.ThreadPoolExecutor(
con = self.connection
con.create_database(env.test_data_db, force=True)
with concurrent.futures.ThreadPoolExecutor(
max_workers=int(
os.environ.get("IBIS_DATA_MAX_WORKERS", URLLIB_DEFAULT_POOL_SIZE)
)
@@ -90,7 +81,7 @@ def _load_data(data_dir: Path, script_dir: Path, **_: Any) -> None:
data_file,
os.path.join(
env.test_data_dir,
os.path.relpath(data_file, data_dir),
os.path.relpath(data_file, self.data_dir),
),
)
for data_file in data_files
@@ -137,17 +128,16 @@ def _load_data(data_dir: Path, script_dir: Path, **_: Any) -> None:
for fut in concurrent.futures.as_completed(futures):
fut.result()

def postload(self, **kw):
env = IbisTestEnv()
self.connection = self.connect(database=env.test_data_db, **kw)

@staticmethod
def connect(
data_directory: Path,
database: str
| None = os.environ.get("IBIS_TEST_DATA_DB", "ibis_testing"), # noqa: B008
with_hdfs: bool = True,
):
fsspec = pytest.importorskip("fsspec")
def connect(*, tmpdir, worker_id, **kw):
import fsspec

env = IbisTestEnv()
return ibis.impala.connect(
con = ibis.impala.connect(
host=env.impala_host,
port=env.impala_port,
auth_mechanism=env.auth_mechanism,
@@ -156,11 +146,10 @@ def connect(
host=env.nn_host,
port=env.hdfs_port,
user=env.hdfs_user,
)
if with_hdfs
else None,
database=database,
),
**kw,
)
return con

def _get_original_column_names(self, tablename: str) -> list[str]:
return list(TEST_TABLES[tablename].names)
@@ -282,36 +271,22 @@ def hdfs(env, tmp_dir):


@pytest.fixture(scope="session")
def backend(tmp_path_factory, data_directory, script_directory, worker_id):
return TestConf.load_data(
data_directory,
script_directory,
tmp_path_factory,
worker_id,
)


@pytest.fixture(scope="module")
def con_no_hdfs(env, data_directory, backend):
con = backend.connect(data_directory, with_hdfs=False)
con.disable_codegen(disabled=not env.use_codegen)
assert con.get_options()['DISABLE_CODEGEN'] == str(int(not env.use_codegen))
yield con
con.close()
def backend(tmp_path_factory, data_dir, worker_id):
return TestConf.load_data(data_dir, tmp_path_factory, worker_id)


@pytest.fixture(scope="module")
def con(env, data_directory, backend):
con = backend.connect(data_directory)
def con(env, backend):
con = backend.connection
con.disable_codegen(disabled=not env.use_codegen)
assert con.get_options()['DISABLE_CODEGEN'] == str(int(not env.use_codegen))
yield con
con.close()


@pytest.fixture
def tmp_db(env, con, test_data_db):
impala = pytest.importorskip("impala")
def tmp_db(env, con):
import impala

tmp_db = env.tmp_db

@@ -328,16 +303,6 @@ def tmp_db(env, con, test_data_db):
con.drop_database(tmp_db, force=True)


@pytest.fixture(scope="module")
def con_no_db(env, data_directory, backend):
con = backend.connect(data_directory, database=None)
if not env.use_codegen:
con.disable_codegen()
assert con.get_options()['DISABLE_CODEGEN'] == '1'
yield con
con.close()


@pytest.fixture(scope="module")
def alltypes(con):
return con.table("functional_alltypes")
@@ -349,7 +314,7 @@ def alltypes_df(alltypes):


@pytest.fixture
def temp_database(con, test_data_db):
def temp_database(con):
name = util.gen_name('database')
con.create_database(name)
yield name
22 changes: 6 additions & 16 deletions ibis/backends/impala/tests/test_client.py
@@ -19,24 +19,18 @@
thrift = pytest.importorskip("thrift")


@pytest.fixture
def db(con, test_data_db):
return con.database(test_data_db)


def test_hdfs_connect_function_is_public():
assert hasattr(ibis.impala, "hdfs_connect")


def test_raise_ibis_error_no_hdfs(con_no_hdfs):
def test_raise_ibis_error_no_hdfs(env):
con = ibis.impala.connect(
host=env.impala_host, port=env.impala_port, auth_mechanism=env.auth_mechanism
)

# GH299
with pytest.raises(com.IbisError):
con_no_hdfs.hdfs # noqa: B018


def test_get_table_ref(db):
assert isinstance(db.functional_alltypes, ir.Table)
assert isinstance(db['functional_alltypes'], ir.Table)
con.hdfs # noqa: B018


def test_run_sql(con, test_data_db):
@@ -192,10 +186,6 @@ def test_sql_query_limits(con, test_data_db):
assert table.count().execute(limit=10) == 25


def test_database_repr(db, test_data_db):
assert test_data_db in repr(db)


def test_database_default_current_database(con):
db = con.database()
assert db.name == con.current_database
53 changes: 22 additions & 31 deletions ibis/backends/mssql/tests/conftest.py
@@ -1,14 +1,14 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Iterable

import pytest
import sqlalchemy as sa

import ibis
from ibis.backends.conftest import init_database
from ibis.backends.tests.base import RoundHalfToEven, ServiceBackendTest, ServiceSpec
from ibis.backends.tests.base import RoundHalfToEven, ServiceBackendTest

if TYPE_CHECKING:
from pathlib import Path
@@ -29,19 +29,16 @@ class TestConf(ServiceBackendTest, RoundHalfToEven):
supports_arrays_outside_of_select = supports_arrays
supports_structs = False
supports_json = False
service_name = "mssql"
deps = "pymssql", "sqlalchemy"

@classmethod
def service_spec(cls, data_dir: Path) -> ServiceSpec:
return ServiceSpec(
name=cls.name(),
data_volume="/data",
files=data_dir.joinpath("csv").glob("*.csv"),
)
@property
def test_files(self) -> Iterable[Path]:
return self.data_dir.joinpath("csv").glob("*.csv")

@staticmethod
def _load_data(
data_dir: Path,
script_dir: Path,
self,
*,
user: str = MSSQL_USER,
password: str = MSSQL_PASS,
host: str = MSSQL_HOST,
@@ -58,34 +55,28 @@ def _load_data(
script_dir
Location of scripts defining schemas
"""
with open(script_dir / 'schema' / 'mssql.sql') as schema:
init_database(
url=sa.engine.make_url(
f"mssql+pymssql://{user}:{password}@{host}:{port:d}/{database}"
),
database=database,
schema=schema,
isolation_level="AUTOCOMMIT",
recreate=False,
)
init_database(
url=sa.engine.make_url(
f"mssql+pymssql://{user}:{password}@{host}:{port:d}/{database}"
),
database=database,
schema=self.ddl_script,
isolation_level="AUTOCOMMIT",
recreate=False,
)

@staticmethod
def connect(_: Path):
def connect(*, tmpdir, worker_id, **kw):
return ibis.mssql.connect(
host=MSSQL_HOST,
user=MSSQL_USER,
password=MSSQL_PASS,
database=IBIS_TEST_MSSQL_DB,
port=MSSQL_PORT,
**kw,
)


@pytest.fixture(scope='session')
def con():
return ibis.mssql.connect(
host=MSSQL_HOST,
user=MSSQL_USER,
password=MSSQL_PASS,
database=IBIS_TEST_MSSQL_DB,
port=MSSQL_PORT,
)
def con(data_dir, tmp_path_factory, worker_id):
return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection
88 changes: 38 additions & 50 deletions ibis/backends/mysql/tests/conftest.py
@@ -1,15 +1,15 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Iterable

import pytest
import sqlalchemy as sa
from packaging.version import parse as parse_version

import ibis
from ibis.backends.conftest import TEST_TABLES, init_database
from ibis.backends.tests.base import RoundHalfToEven, ServiceBackendTest, ServiceSpec
from ibis.backends.tests.base import RoundHalfToEven, ServiceBackendTest

if TYPE_CHECKING:
from pathlib import Path
@@ -24,38 +24,32 @@
class TestConf(ServiceBackendTest, RoundHalfToEven):
# mysql has the same rounding behavior as postgres
check_dtype = False
supports_window_operations = False
returned_timestamp_unit = 's'
supports_arrays = False
supports_arrays_outside_of_select = supports_arrays
native_bool = False
supports_structs = False
service_name = "mysql"
deps = "pymysql", "sqlalchemy"

@classmethod
def service_spec(cls, data_dir: Path) -> ServiceSpec:
return ServiceSpec(
name=cls.name(),
data_volume="/data",
files=data_dir.joinpath("csv").glob("*.csv"),
)
@property
def test_files(self) -> Iterable[Path]:
return self.data_dir.joinpath("csv").glob("*.csv")

def __init__(self, data_directory: Path) -> None:
super().__init__(data_directory)
@property
def supports_window_operations(self) -> bool:
con = self.connection
with con.begin() as c:
version = c.exec_driver_sql("SELECT VERSION()").scalar()
version = c.execute(sa.select(sa.func.version())).scalar()

# mariadb supports window operations after version 10.2
# mysql supports window operations after version 8
min_version = "10.2" if "MariaDB" in version else "8.0"
self.__class__.supports_window_operations = parse_version(
con.version
) >= parse_version(min_version)
return parse_version(con.version) >= parse_version(min_version)

@staticmethod
def _load_data(
data_dir: Path,
script_dir: Path,
self,
*,
user: str = MYSQL_USER,
password: str = MYSQL_PASS,
host: str = MYSQL_HOST,
@@ -72,37 +66,37 @@ def _load_data(
script_dir
Location of scripts defining schemas
"""
with open(script_dir / 'schema' / 'mysql.sql') as schema:
engine = init_database(
url=sa.engine.make_url(
f"mysql+pymysql://{user}:{password}@{host}:{port:d}?local_infile=1",
),
database=database,
schema=schema,
isolation_level="AUTOCOMMIT",
recreate=False,
)
with engine.begin() as con:
for table in TEST_TABLES:
csv_path = data_dir / "csv" / f"{table}.csv"
lines = [
f"LOAD DATA LOCAL INFILE {str(csv_path)!r}",
f"INTO TABLE {table}",
"COLUMNS TERMINATED BY ','",
"""OPTIONALLY ENCLOSED BY '"'""",
"LINES TERMINATED BY '\\n'",
"IGNORE 1 LINES",
]
con.exec_driver_sql("\n".join(lines))
engine = init_database(
url=sa.engine.make_url(
f"mysql+pymysql://{user}:{password}@{host}:{port:d}?local_infile=1",
),
database=database,
schema=self.ddl_script,
isolation_level="AUTOCOMMIT",
recreate=False,
)
with engine.begin() as con:
for table in TEST_TABLES:
csv_path = self.data_dir / "csv" / f"{table}.csv"
lines = [
f"LOAD DATA LOCAL INFILE {str(csv_path)!r}",
f"INTO TABLE {table}",
"COLUMNS TERMINATED BY ','",
"""OPTIONALLY ENCLOSED BY '"'""",
"LINES TERMINATED BY '\\n'",
"IGNORE 1 LINES",
]
con.exec_driver_sql("\n".join(lines))

@staticmethod
def connect(_: Path):
def connect(*, tmpdir, worker_id, **kw):
return ibis.mysql.connect(
host=MYSQL_HOST,
user=MYSQL_USER,
password=MYSQL_PASS,
database=IBIS_TEST_MYSQL_DB,
port=MYSQL_PORT,
**kw,
)


@@ -121,14 +115,8 @@ def setup_privs():


@pytest.fixture(scope='session')
def con():
return ibis.mysql.connect(
host=MYSQL_HOST,
user=MYSQL_USER,
password=MYSQL_PASS,
database=IBIS_TEST_MYSQL_DB,
port=MYSQL_PORT,
)
def con(tmp_path_factory, data_dir, worker_id):
return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection


@pytest.fixture(scope='session')
67 changes: 28 additions & 39 deletions ibis/backends/oracle/tests/conftest.py
@@ -5,13 +5,13 @@
import itertools
import os
import subprocess
from typing import TYPE_CHECKING, Any, TextIO
from typing import TYPE_CHECKING, Any, Iterable

import pytest
import sqlalchemy as sa

import ibis
from ibis.backends.tests.base import RoundHalfToEven, ServiceBackendTest, ServiceSpec
from ibis.backends.tests.base import RoundHalfToEven, ServiceBackendTest

if TYPE_CHECKING:
from pathlib import Path
@@ -33,22 +33,20 @@ class TestConf(ServiceBackendTest, RoundHalfToEven):
native_bool = False
supports_structs = False
supports_json = False

@classmethod
def service_spec(cls, data_dir: Path) -> ServiceSpec:
return ServiceSpec(
name=cls.name(),
data_volume="/opt/oracle/data",
files=itertools.chain(
data_dir.joinpath("csv").glob("*.csv"),
data_dir.parent.joinpath("schema", "oracle").glob("*.ctl"),
),
data_volume = "/opt/oracle/data"
service_name = "oracle"
deps = "oracledb", "sqlalchemy"

@property
def test_files(self) -> Iterable[Path]:
return itertools.chain(
self.data_dir.joinpath("csv").glob("*.csv"),
self.script_dir.joinpath("oracle").glob("*.ctl"),
)

@staticmethod
def _load_data(
data_dir: Path,
script_dir: Path,
self,
*,
user: str = ORACLE_USER,
password: str = ORACLE_PASS,
host: str = ORACLE_HOST,
@@ -81,15 +79,14 @@ def _load_data(
]
)

with open(script_dir / 'schema' / 'oracle.sql') as schema:
init_oracle_database(
url=sa.engine.make_url(
f"oracle://{user}:{password}@{host}:{port:d}/{database}",
),
database=database,
schema=schema,
connect_args=dict(service_name=database),
)
init_oracle_database(
url=sa.engine.make_url(
f"oracle://{user}:{password}@{host}:{port:d}/{database}",
),
database=database,
schema=self.ddl_script,
connect_args=dict(service_name=database),
)

# then call sqlldr to ingest
with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -107,18 +104,19 @@ def _load_data(
],
stdout=subprocess.DEVNULL,
)
for ctl_file in script_dir.joinpath("schema", "oracle").glob("*.ctl")
for ctl_file in self.script_dir.joinpath("oracle").glob("*.ctl")
):
fut.result()

@staticmethod
def connect(_: Path):
def connect(*, tmpdir, worker_id, **kw):
return ibis.oracle.connect(
host=ORACLE_HOST,
user=ORACLE_USER,
password=ORACLE_PASS,
database="IBIS_TESTING",
port=ORACLE_PORT,
**kw,
)

@staticmethod
@@ -127,20 +125,14 @@ def format_table(name: str) -> str:


@pytest.fixture(scope='session')
def con():
return ibis.oracle.connect(
host=ORACLE_HOST,
user=ORACLE_USER,
password=ORACLE_PASS,
database="IBIS_TESTING",
port=ORACLE_PORT,
)
def con(data_dir, tmp_path_factory, worker_id):
return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection


def init_oracle_database(
url: sa.engine.url.URL,
database: str,
schema: TextIO | None = None,
schema: str | None = None,
**kwargs: Any,
) -> sa.engine.Engine:
"""Initialise `database` at `url` with `schema`.
@@ -170,10 +162,7 @@ def init_oracle_database(

if schema:
with engine.begin() as conn:
for stmt in filter(
None,
map(str.strip, schema.read().split(';')),
):
for stmt in schema:
# XXX: maybe should just remove the comments in the sql file
# so we don't end up writing an entire parser here.
if not stmt.startswith("--"):
38 changes: 17 additions & 21 deletions ibis/backends/pandas/tests/conftest.py
@@ -1,37 +1,33 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pandas as pd
from typing import Any

import ibis
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.tests.base import BackendTest, RoundHalfToEven
from ibis.backends.tests.data import array_types, json_types, struct_types, win

if TYPE_CHECKING:
from pathlib import Path


class TestConf(BackendTest, RoundHalfToEven):
check_names = False
supported_to_timestamp_units = BackendTest.supported_to_timestamp_units | {'ns'}
supports_divide_by_zero = True
returned_timestamp_unit = 'ns'
stateful = False
deps = ("pandas",)

def _load_data(self, **_: Any) -> None:
import pandas as pd

con = self.connection
for table_name in TEST_TABLES:
path = self.data_dir / "parquet" / f"{table_name}.parquet"
con.create_table(table_name, pd.read_parquet(path))
con.create_table("array_types", array_types, overwrite=True)
con.create_table("struct", struct_types, overwrite=True)
con.create_table("win", win, overwrite=True)
con.create_table("json_t", json_types, overwrite=True)

@staticmethod
def connect(data_directory: Path):
return ibis.pandas.connect(
dictionary={
**{
table: pd.read_parquet(
data_directory / "parquet" / f"{table}.parquet"
)
for table in TEST_TABLES.keys()
},
'struct': struct_types,
'json_t': json_types,
'array_types': array_types,
'win': win,
}
)
def connect(*, tmpdir, worker_id, **kw):
return ibis.pandas.connect(**kw)
8 changes: 4 additions & 4 deletions ibis/backends/pandas/tests/execution/conftest.py
@@ -75,18 +75,18 @@ def df():


@pytest.fixture(scope='module')
def batting_df(data_directory):
def batting_df(data_dir):
num_rows = 1000
start_index = 30
df = pd.read_parquet(data_directory / 'parquet' / 'batting.parquet').iloc[
df = pd.read_parquet(data_dir / 'parquet' / 'batting.parquet').iloc[
start_index : start_index + num_rows
]
return df.reset_index(drop=True)


@pytest.fixture(scope='module')
def awards_players_df(data_directory):
return pd.read_parquet(data_directory / 'parquet' / 'awards_players.parquet')
def awards_players_df(data_dir):
return pd.read_parquet(data_dir / 'parquet' / 'awards_players.parquet')


@pytest.fixture(scope='module')
50 changes: 19 additions & 31 deletions ibis/backends/polars/tests/conftest.py
@@ -1,56 +1,44 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any

import pytest

import ibis
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.tests.base import BackendTest, RoundAwayFromZero
from ibis.backends.tests.data import array_types, struct_types, win

if TYPE_CHECKING:
from pathlib import Path

pl = pytest.importorskip("polars")


class TestConf(BackendTest, RoundAwayFromZero):
supports_structs = True
supports_json = False
reduction_tolerance = 1e-3
stateful = False
deps = ("polars",)

def _load_data(self, **_: Any) -> None:
con = self.connection
for table_name in TEST_TABLES:
path = self.data_dir / 'parquet' / f'{table_name}.parquet'
con.register(path, table_name=table_name)
con.register(array_types, table_name='array_types')
con.register(struct_types, table_name='struct')
con.register(win, table_name="win")

@staticmethod
def connect(data_directory: Path):
client = ibis.polars.connect({})
client.register(
data_directory / 'parquet' / 'functional_alltypes.parquet',
table_name='functional_alltypes',
)
client.register(
data_directory / "parquet" / 'batting.parquet', table_name='batting'
)
client.register(
data_directory / "parquet" / 'awards_players.parquet',
table_name='awards_players',
)
client.register(
data_directory / "parquet" / 'diamonds.parquet', table_name='diamonds'
)
client.register(array_types, table_name='array_types')
client.register(struct_types, table_name='struct')
client.register(win, table_name="win")

return client
def connect(*, tmpdir, worker_id, **kw):
return ibis.polars.connect(**kw)


@pytest.fixture(scope='session')
def client(data_directory):
return TestConf.connect(data_directory)
def con(data_dir, tmp_path_factory, worker_id):
return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection


@pytest.fixture(scope='session')
def alltypes(client):
return client.table("functional_alltypes")
def alltypes(con):
return con.table("functional_alltypes")


@pytest.fixture(scope='session')
54 changes: 23 additions & 31 deletions ibis/backends/postgres/tests/conftest.py
@@ -14,14 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Iterable

import pytest
import sqlalchemy as sa

import ibis
from ibis.backends.conftest import init_database
from ibis.backends.tests.base import RoundHalfToEven, ServiceBackendTest, ServiceSpec
from ibis.backends.tests.base import RoundHalfToEven, ServiceBackendTest

if TYPE_CHECKING:
from pathlib import Path
@@ -47,19 +47,16 @@ class TestConf(ServiceBackendTest, RoundHalfToEven):

returned_timestamp_unit = 's'
supports_structs = False
service_name = "postgres"
deps = "psycopg2", "sqlalchemy"

@classmethod
def service_spec(cls, data_dir: Path) -> ServiceSpec:
return ServiceSpec(
name=cls.name(),
data_volume="/data",
files=data_dir.joinpath("csv").glob("*.csv"),
)
@property
def test_files(self) -> Iterable[Path]:
return self.data_dir.joinpath("csv").glob("*.csv")

@staticmethod
def _load_data(
data_dir: Path,
script_dir: Path,
self,
*,
user: str = PG_USER,
password: str = PG_PASS,
host: str = PG_HOST,
@@ -76,36 +73,31 @@ def _load_data(
script_dir
Location of scripts defining schemas
"""
with open(script_dir / 'schema' / 'postgresql.sql') as schema:
init_database(
url=sa.engine.make_url(
f"postgresql://{user}:{password}@{host}:{port:d}/{database}"
),
database=database,
schema=schema,
isolation_level="AUTOCOMMIT",
recreate=False,
)
init_database(
url=sa.engine.make_url(
f"postgresql://{user}:{password}@{host}:{port:d}/{database}"
),
database=database,
schema=self.ddl_script,
isolation_level="AUTOCOMMIT",
recreate=False,
)

@staticmethod
def connect(data_directory: Path):
def connect(*, tmpdir, worker_id, port: int | None = None, **kw):
return ibis.postgres.connect(
host=PG_HOST,
port=PG_PORT,
port=port or PG_PORT,
user=PG_USER,
password=PG_PASS,
database=IBIS_TEST_POSTGRES_DB,
**kw,
)


@pytest.fixture(scope='session')
def con(tmp_path_factory, data_directory, script_directory, worker_id):
return TestConf.load_data(
data_directory,
script_directory,
tmp_path_factory,
worker_id,
).connect(data_directory)
def con(tmp_path_factory, data_dir, worker_id):
return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection


@pytest.fixture(scope='module')
294 changes: 138 additions & 156 deletions ibis/backends/pyspark/tests/conftest.py
@@ -2,6 +2,7 @@

import os
from datetime import datetime, timedelta, timezone
from typing import Any

import numpy as np
import pandas as pd
@@ -11,178 +12,159 @@
from ibis import util
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.tests.base import BackendTest, RoundAwayFromZero
from ibis.backends.tests.data import win

pytest.importorskip("pyspark")

import pyspark.sql.functions as F # noqa: E402
import pyspark.sql.types as pt # noqa: E402
from pyspark.sql import Row, SparkSession # noqa: E402


def get_common_spark_testing_client(data_directory, connect):
spark = (
SparkSession.builder.appName("ibis_testing")
.master("local[1]")
.config("spark.cores.max", 1)
.config("spark.executor.heartbeatInterval", "3600s")
.config("spark.executor.instances", 1)
.config("spark.network.timeout", "4200s")
.config("spark.sql.execution.arrow.pyspark.enabled", False)
.config("spark.sql.legacy.timeParserPolicy", "LEGACY")
.config("spark.storage.blockManagerSlaveTimeoutMs", "4200s")
.config("spark.ui.showConsoleProgress", False)
.config('spark.default.parallelism', 1)
.config('spark.dynamicAllocation.enabled', False)
.config('spark.rdd.compress', False)
.config('spark.serializer', 'org.apache.spark.serializer.KryoSerializer')
.config('spark.shuffle.compress', False)
.config('spark.shuffle.spill.compress', False)
.config('spark.sql.shuffle.partitions', 1)
.config('spark.ui.enabled', False)
.getOrCreate()
)
_spark_testing_client = connect(spark)
s: SparkSession = _spark_testing_client._session
num_partitions = 4

sort_cols = {"functional_alltypes": "id"}
from ibis.backends.tests.data import json_types, win

for name in TEST_TABLES.keys():
path = str(data_directory / "parquet" / f"{name}.parquet")
t = s.read.parquet(path).repartition(num_partitions)
if (sort_col := sort_cols.get(name)) is not None:
t = t.sort(sort_col)
t.createOrReplaceTempView(name)

s.createDataFrame([(1, 'a')], ['foo', 'bar']).createOrReplaceTempView('simple')
def set_pyspark_database(con, database):
con._session.catalog.setCurrentDatabase(database)

s.createDataFrame(
[
Row(abc=Row(a=1.0, b='banana', c=2)),
Row(abc=Row(a=2.0, b='apple', c=3)),
Row(abc=Row(a=3.0, b='orange', c=4)),
Row(abc=Row(a=None, b='banana', c=2)),
Row(abc=Row(a=2.0, b=None, c=3)),
Row(abc=None),
Row(abc=Row(a=3.0, b='orange', c=None)),
],
).createOrReplaceTempView('struct')

s.createDataFrame(
[([1, 2], [[3, 4], [5, 6]], {'a': [[2, 4], [3, 5]]})],
[
'list_of_ints',
'list_of_list_of_ints',
'map_string_list_of_list_of_ints',
],
).createOrReplaceTempView('nested_types')
s.createDataFrame(
[
(
[1, 2, 3],
['a', 'b', 'c'],
[1.0, 2.0, 3.0],
'a',
1.0,
[[], [1, 2, 3], None],
),
([4, 5], ['d', 'e'], [4.0, 5.0], 'a', 2.0, []),
([6, None], ['f', None], [6.0, None], 'a', 3.0, [None, [], None]),
(
[None, 1, None],
[None, 'a', None],
[],
'b',
4.0,
[[1], [2], [], [3, 4, 5]],
),
([2, None, 3], ['b', None, 'c'], None, 'b', 5.0, None),
(
[4, None, None, 5],
['d', None, None, 'e'],
[4.0, None, None, 5.0],
'c',
6.0,
[[1, 2, 3]],
),
],
["x", "y", "z", "grouper", "scalar_column", "multi_dim"],
).createOrReplaceTempView("array_types")

s.createDataFrame(
[({(1, 3): [[2, 4], [3, 5]]},)], ['map_tuple_list_of_list_of_ints']
).createOrReplaceTempView('complicated')

s.createDataFrame(
[('a', 1, 4.0, 'a'), ('b', 2, 5.0, 'a'), ('c', 3, 6.0, 'b')],
['a', 'b', 'c', 'key'],
).createOrReplaceTempView('udf')

s.createDataFrame(
pd.DataFrame(
{
'a': np.arange(10, dtype=float),
'b': [3.0, np.NaN] * 5,
'key': list('ddeefffggh'),
}
)
).createOrReplaceTempView('udf_nan')

s.createDataFrame(
[(float(i), None if i % 2 else 3.0, 'ddeefffggh'[i]) for i in range(10)],
['a', 'b', 'key'],
).createOrReplaceTempView('udf_null')

s.createDataFrame(
pd.DataFrame(
{
'a': np.arange(4.0).tolist() + np.random.rand(3).tolist(),
'b': np.arange(4.0).tolist() + np.random.rand(3).tolist(),
'key': list('ddeefff'),
}
)
).createOrReplaceTempView('udf_random')

s.createDataFrame(
pd.DataFrame(
{
"js": [
'{"a": [1,2,3,4], "b": 1}',
'{"a":null,"b":2}',
'{"a":"foo", "c":null}',
"null",
"[42,47,55]",
"[]",
]
}
)
).createOrReplaceTempView("json_t")

s.createDataFrame(win).createOrReplaceTempView("win")
class TestConf(BackendTest, RoundAwayFromZero):
supported_to_timestamp_units = {'s'}
deps = ("pyspark",)

return _spark_testing_client
def _load_data(self, **_: Any) -> None:
from pyspark.sql import Row

s = self.connection._session
num_partitions = 4

def get_pyspark_testing_client(data_directory):
return get_common_spark_testing_client(data_directory, ibis.pyspark.connect)
sort_cols = {"functional_alltypes": "id"}

for name in TEST_TABLES.keys():
path = str(self.data_dir / "parquet" / f"{name}.parquet")
t = s.read.parquet(path).repartition(num_partitions)
if (sort_col := sort_cols.get(name)) is not None:
t = t.sort(sort_col)
t.createOrReplaceTempView(name)

def set_pyspark_database(con, database):
con._session.catalog.setCurrentDatabase(database)
s.createDataFrame([(1, 'a')], ['foo', 'bar']).createOrReplaceTempView('simple')

s.createDataFrame(
[
Row(abc=Row(a=1.0, b='banana', c=2)),
Row(abc=Row(a=2.0, b='apple', c=3)),
Row(abc=Row(a=3.0, b='orange', c=4)),
Row(abc=Row(a=None, b='banana', c=2)),
Row(abc=Row(a=2.0, b=None, c=3)),
Row(abc=None),
Row(abc=Row(a=3.0, b='orange', c=None)),
],
).createOrReplaceTempView('struct')

s.createDataFrame(
[([1, 2], [[3, 4], [5, 6]], {'a': [[2, 4], [3, 5]]})],
[
'list_of_ints',
'list_of_list_of_ints',
'map_string_list_of_list_of_ints',
],
).createOrReplaceTempView('nested_types')
s.createDataFrame(
[
(
[1, 2, 3],
['a', 'b', 'c'],
[1.0, 2.0, 3.0],
'a',
1.0,
[[], [1, 2, 3], None],
),
([4, 5], ['d', 'e'], [4.0, 5.0], 'a', 2.0, []),
([6, None], ['f', None], [6.0, None], 'a', 3.0, [None, [], None]),
(
[None, 1, None],
[None, 'a', None],
[],
'b',
4.0,
[[1], [2], [], [3, 4, 5]],
),
([2, None, 3], ['b', None, 'c'], None, 'b', 5.0, None),
(
[4, None, None, 5],
['d', None, None, 'e'],
[4.0, None, None, 5.0],
'c',
6.0,
[[1, 2, 3]],
),
],
["x", "y", "z", "grouper", "scalar_column", "multi_dim"],
).createOrReplaceTempView("array_types")

s.createDataFrame(
[({(1, 3): [[2, 4], [3, 5]]},)], ['map_tuple_list_of_list_of_ints']
).createOrReplaceTempView('complicated')

s.createDataFrame(
[('a', 1, 4.0, 'a'), ('b', 2, 5.0, 'a'), ('c', 3, 6.0, 'b')],
['a', 'b', 'c', 'key'],
).createOrReplaceTempView('udf')

s.createDataFrame(
pd.DataFrame(
{
'a': np.arange(10, dtype=float),
'b': [3.0, np.NaN] * 5,
'key': list('ddeefffggh'),
}
)
).createOrReplaceTempView('udf_nan')

s.createDataFrame(
[(float(i), None if i % 2 else 3.0, 'ddeefffggh'[i]) for i in range(10)],
['a', 'b', 'key'],
).createOrReplaceTempView('udf_null')

s.createDataFrame(
pd.DataFrame(
{
'a': np.arange(4.0).tolist() + np.random.rand(3).tolist(),
'b': np.arange(4.0).tolist() + np.random.rand(3).tolist(),
'key': list('ddeefff'),
}
)
).createOrReplaceTempView('udf_random')

class TestConf(BackendTest, RoundAwayFromZero):
supported_to_timestamp_units = {'s'}
s.createDataFrame(json_types).createOrReplaceTempView("json_t")
s.createDataFrame(win).createOrReplaceTempView("win")

@staticmethod
def connect(data_directory):
return get_pyspark_testing_client(data_directory)
def connect(*, tmpdir, worker_id, **kw):
from pyspark.sql import SparkSession

spark = (
SparkSession.builder.appName("ibis_testing")
.master("local[1]")
.config("spark.cores.max", 1)
.config("spark.executor.heartbeatInterval", "3600s")
.config("spark.executor.instances", 1)
.config("spark.network.timeout", "4200s")
.config("spark.sql.execution.arrow.pyspark.enabled", False)
.config("spark.sql.legacy.timeParserPolicy", "LEGACY")
.config("spark.storage.blockManagerSlaveTimeoutMs", "4200s")
.config("spark.ui.showConsoleProgress", False)
.config('spark.default.parallelism', 1)
.config('spark.dynamicAllocation.enabled', False)
.config('spark.rdd.compress', False)
.config('spark.serializer', 'org.apache.spark.serializer.KryoSerializer')
.config('spark.shuffle.compress', False)
.config('spark.shuffle.spill.compress', False)
.config('spark.sql.shuffle.partitions', 1)
.config('spark.ui.enabled', False)
.getOrCreate()
)
return ibis.pyspark.connect(spark, **kw)


@pytest.fixture(scope='session')
def con(data_directory):
con = TestConf.connect(data_directory)
def con(data_dir, tmp_path_factory, worker_id):
import pyspark.sql.functions as F
import pyspark.sql.types as pt

backend_test = TestConf.load_data(data_dir, tmp_path_factory, worker_id)
con = backend_test.connection

df = con._session.range(0, 10)
df = df.withColumn("str_col", F.lit('value'))
34 changes: 9 additions & 25 deletions ibis/backends/snowflake/tests/conftest.py
@@ -55,26 +55,10 @@ def copy_into(con, data_dir: Path, table: str) -> None:
class TestConf(BackendTest, RoundAwayFromZero):
supports_map = True
default_identifier_case_fn = staticmethod(str.upper)
deps = ("snowflake.connector", "snowflake.sqlalchemy")

def __init__(self, data_directory: Path) -> None:
self.connection = self.connect(data_directory)

@staticmethod
def _load_data(
data_dir, script_dir, database: str = "ibis_testing", **_: Any
) -> None:
"""Load test data into a Snowflake backend instance.
Parameters
----------
data_dir
Location of test data
script_dir
Location of scripts defining schemas
"""
pytest.importorskip("snowflake.connector")
pytest.importorskip("snowflake.sqlalchemy")

def _load_data(self, **_: Any) -> None:
"""Load test data into a Snowflake backend instance."""
snowflake_url = _get_url()

raw_url = sa.engine.make_url(snowflake_url)
@@ -93,25 +77,25 @@ def _load_data(
USE DATABASE ibis_testing;
CREATE SCHEMA IF NOT EXISTS {dbschema};
USE SCHEMA {dbschema};
{script_dir.joinpath("schema", "snowflake.sql").read_text()}"""
{self.script_dir.joinpath("snowflake.sql").read_text()}"""
)

with con.begin() as c:
# not much we can do to make this faster, but running these in
# multiple threads seems to save about 2x
with concurrent.futures.ThreadPoolExecutor() as exe:
for future in concurrent.futures.as_completed(
exe.submit(copy_into, c, data_dir, table)
exe.submit(copy_into, c, self.data_dir, table)
for table in TEST_TABLES.keys()
):
future.result()

@staticmethod
@functools.lru_cache(maxsize=None)
def connect(data_directory: Path) -> BaseBackend:
return ibis.connect(_get_url())
def connect(*, tmpdir, worker_id, **kw) -> BaseBackend:
return ibis.connect(_get_url(), **kw)


@pytest.fixture(scope="session")
def con(data_directory):
return TestConf.connect(data_directory)
def con(data_dir, tmp_path_factory, worker_id):
return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection
46 changes: 12 additions & 34 deletions ibis/backends/sqlite/tests/conftest.py
@@ -3,7 +3,7 @@
import contextlib
import csv
import sqlite3
from typing import TYPE_CHECKING, Any
from typing import Any

import pytest

@@ -12,11 +12,6 @@
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.tests.base import BackendTest, RoundAwayFromZero

if TYPE_CHECKING:
from pathlib import Path

from ibis.backends.base import BaseBackend


class TestConf(BackendTest, RoundAwayFromZero):
supports_arrays = False
@@ -25,45 +20,28 @@ class TestConf(BackendTest, RoundAwayFromZero):
check_dtype = False
returned_timestamp_unit = 's'
supports_structs = False
stateful = False
deps = ("sqlalchemy",)

def __init__(self, data_directory: Path) -> None:
self.connection = self.connect(data_directory)
@staticmethod
def connect(*, tmpdir, worker_id, **kw):
return ibis.sqlite.connect(**kw)

schema = data_directory.parent.joinpath('schema', 'sqlite.sql').read_text()
def _load_data(self, **kw: Any) -> None:
"""Load test data into a SQLite backend instance."""
super()._load_data(**kw)

with self.connection.begin() as con:
for stmt in filter(None, map(str.strip, schema.split(';'))):
con.exec_driver_sql(stmt)

for table in TEST_TABLES:
basename = f"{table}.csv"
with data_directory.joinpath("csv", basename).open("r") as f:
with self.data_dir.joinpath("csv", basename).open("r") as f:
reader = csv.reader(f)
header = next(reader)
assert header, f"empty header for table: `{table}`"
spec = ", ".join("?" * len(header))
with contextlib.closing(con.connection.cursor()) as cur:
cur.executemany(f"INSERT INTO {table} VALUES ({spec})", reader)

@staticmethod
def _load_data(
data_dir: Path, script_dir: Path, database: str | None = None, **_: Any
) -> None:
"""Load test data into a SQLite backend instance.
Parameters
----------
data_dir
Location of test data
script_dir
Location of scripts defining schemas
"""
return TestConf(data_dir)

@staticmethod
def connect(data_directory: Path) -> BaseBackend:
return ibis.sqlite.connect() # type: ignore

@property
def functional_alltypes(self) -> ir.Table:
t = super().functional_alltypes
@@ -80,8 +58,8 @@ def dbpath(tmp_path):


@pytest.fixture(scope="session")
def con(data_directory):
return TestConf(data_directory).connection
def con(data_dir, tmp_path_factory, worker_id):
return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection


@pytest.fixture(scope="session")
107 changes: 68 additions & 39 deletions ibis/backends/tests/base.py
@@ -5,7 +5,7 @@
import inspect
import subprocess
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable, Mapping, NamedTuple
from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping

import numpy as np
import pandas as pd
@@ -81,14 +81,32 @@ class BackendTest(abc.ABC):
supports_map = False # basically nothing does except trino and snowflake
reduction_tolerance = 1e-7
default_identifier_case_fn = staticmethod(toolz.identity)
stateful = True
service_name = None

@property
@abc.abstractmethod
def deps(self) -> Iterable[str]:
"""A list of dependencies that must be present to run tests."""

@property
def ddl_script(self) -> Iterator[str]:
return filter(
None,
map(
str.strip,
self.script_dir.joinpath(f"{self.name()}.sql").read_text().split(";"),
),
)

@staticmethod
def format_table(name: str) -> str:
return name

def __init__(self, data_directory: Path) -> None:
self.connection = self.connect(data_directory)
self.data_directory = data_directory
def __init__(self, *, data_dir: Path, tmpdir, worker_id, **kw) -> None:
self.connection = self.connect(tmpdir=tmpdir, worker_id=worker_id, **kw)
self.data_dir = data_dir
self.script_dir = data_dir.parent / "schema"

def __str__(self):
return f'<BackendTest {self.name()}>'
@@ -100,41 +118,58 @@ def name(cls) -> str:

@staticmethod
@abc.abstractmethod
def connect(data_directory: Path):
"""Return a connection with data loaded from `data_directory`."""
def connect(*, tmpdir, worker_id, **kw: Any):
"""Return a connection with data loaded from `data_dir`."""

def _load_data(self, **_: Any) -> None:
"""Load test data into a backend."""
with self.connection.begin() as con:
for stmt in self.ddl_script:
con.exec_driver_sql(stmt)

@staticmethod # noqa: B027
def _load_data(data_directory: Path, script_directory: Path, **kwargs: Any) -> None:
"""Load test data into a backend.
def stateless_load(self, **kw):
self.preload()
self._load_data(**kw)

Default implementation is a no-op.
"""
def stateful_load(self, fn, **kw):
if not fn.exists():
self.stateless_load(**kw)
fn.touch()

@classmethod
def load_data(
cls, data_dir: Path, script_dir: Path, tmpdir: Path, worker_id: str, **kw: Any
) -> None:
"""Load testdata from `data_directory` into the backend using scripts
in `script_directory`."""
def load_data(cls, data_dir: Path, tmpdir: Path, worker_id: str, **kw: Any) -> None:
"""Load testdata from `data_dir`."""
# handling for multi-processes pytest

# get the temp directory shared by all workers
root_tmp_dir = tmpdir.getbasetemp()
if worker_id != "master":
root_tmp_dir = root_tmp_dir.parent

fn = root_tmp_dir / f"lockfile_{cls.name()}"
fn = root_tmp_dir / (getattr(cls, "service_name", None) or cls.name())
with FileLock(f"{fn}.lock"):
if not fn.exists():
cls.preload(data_dir)
cls._load_data(data_dir, script_dir, **kw)
fn.touch()
return cls(data_dir)

@classmethod # noqa: B027
def preload(cls, data_dir: Path):
cls.skip_if_missing_deps()

inst = cls(data_dir=data_dir, tmpdir=tmpdir, worker_id=worker_id, **kw)

if inst.stateful:
inst.stateful_load(fn, **kw)
else:
inst.stateless_load(**kw)
inst.postload(tmpdir=tmpdir, worker_id=worker_id, **kw)
return inst

@classmethod
def skip_if_missing_deps(cls) -> None:
for dep in cls.deps:
pytest.importorskip(dep)

def preload(self): # noqa: B027
"""Code to execute before loading data."""

def postload(self, **_): # noqa: B027
"""Code to execute after loading data."""

@classmethod
def assert_series_equal(
cls, left: pd.Series, right: pd.Series, *args: Any, **kwargs: Any
@@ -228,23 +263,17 @@ def make_context(self, params: Mapping[ir.Value, Any] | None = None):
return self.api.compiler.make_context(params=params)


class ServiceSpec(NamedTuple):
name: str
data_volume: str
files: Iterable[Path]


class ServiceBackendTest(BackendTest):
@classmethod
data_volume = "/data"

@property
@abc.abstractmethod
def service_spec(data_dir: Path) -> ServiceSpec:
def test_files(self) -> Iterable[Path]:
...

@classmethod
def preload(cls, data_dir: Path):
spec = cls.service_spec(data_dir)
service = spec.name
data_volume = spec.data_volume
def preload(self):
service = self.service_name
data_volume = self.data_volume
with concurrent.futures.ThreadPoolExecutor() as e:
for fut in concurrent.futures.as_completed(
e.submit(
@@ -257,6 +286,6 @@ def preload(cls, data_dir: Path):
f"{service}:{data_volume}/{path.name}",
],
)
for path in spec.files
for path in self.test_files
):
fut.result()
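Taken together, the base.py changes reduce a backend's test config to three pieces: a `deps` tuple, a `connect(*, tmpdir, worker_id, **kw)` constructor, and an optional `_load_data`; `BackendTest.load_data` then handles dependency skipping, the per-xdist-worker lockfile, stateful versus stateless loading, and the `preload`/`postload` hooks. A minimal annotated sketch of that contract, modeled on the pandas conftest shown earlier (the class name is illustrative, not part of this PR):

from __future__ import annotations

from typing import Any

import pandas as pd
import pytest

import ibis
from ibis.backends.tests.base import BackendTest, RoundHalfToEven


class ExampleTestConf(BackendTest, RoundHalfToEven):
    # each entry is passed to pytest.importorskip before loading any data
    deps = ("pandas",)
    # in-process backend: skip the cross-worker lockfile and reload per worker
    stateful = False

    @staticmethod
    def connect(*, tmpdir, worker_id, **kw):
        # tmpdir/worker_id give each pytest-xdist worker isolated scratch space
        return ibis.pandas.connect(**kw)

    def _load_data(self, **_: Any) -> None:
        con = self.connection
        path = self.data_dir / "parquet" / "functional_alltypes.parquet"
        con.create_table("functional_alltypes", pd.read_parquet(path))


@pytest.fixture(scope="session")
def con(data_dir, tmp_path_factory, worker_id):
    # load_data returns the BackendTest instance; tests grab its .connection
    return ExampleTestConf.load_data(data_dir, tmp_path_factory, worker_id).connection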
36 changes: 17 additions & 19 deletions ibis/backends/tests/test_register.py
@@ -29,10 +29,10 @@ def pushd(new_dir):


@pytest.fixture
def gzip_csv(data_directory, tmp_path):
def gzip_csv(data_dir, tmp_path):
basename = "diamonds.csv"
f = tmp_path.joinpath(f"{basename}.gz")
data = data_directory.joinpath("csv", basename).read_bytes()
data = data_dir.joinpath("csv", basename).read_bytes()
f.write_bytes(gzip.compress(data))
return str(f.absolute())

@@ -92,8 +92,8 @@ def gzip_csv(data_directory, tmp_path):
"trino",
]
)
def test_register_csv(con, data_directory, fname, in_table_name, out_table_name):
with pushd(data_directory / "csv"):
def test_register_csv(con, data_dir, fname, in_table_name, out_table_name):
with pushd(data_dir / "csv"):
table = con.register(fname, table_name=in_table_name)

assert any(out_table_name in t for t in con.list_tables())
@@ -117,8 +117,8 @@ def test_register_csv(con, data_directory, fname, in_table_name, out_table_name)
"trino",
]
)
def test_register_csv_gz(con, data_directory, gzip_csv):
with pushd(data_directory):
def test_register_csv_gz(con, data_dir, gzip_csv):
with pushd(data_dir):
table = con.register(gzip_csv)

assert table.count().execute()
@@ -139,11 +139,11 @@ def test_register_csv_gz(con, data_directory, gzip_csv):
"trino",
]
)
def test_register_with_dotted_name(con, data_directory, tmp_path):
def test_register_with_dotted_name(con, data_dir, tmp_path):
basename = "foo.bar.baz/diamonds.csv"
f = tmp_path.joinpath(basename)
f.parent.mkdir()
data = data_directory.joinpath("csv", "diamonds.csv").read_bytes()
data = data_dir.joinpath("csv", "diamonds.csv").read_bytes()
f.write_bytes(data)
table = con.register(str(f.absolute()))

@@ -195,12 +195,12 @@ def read_table(path: Path) -> Iterator[tuple[str, pa.Table]]:
]
)
def test_register_parquet(
con, tmp_path, data_directory, fname, in_table_name, out_table_name
con, tmp_path, data_dir, fname, in_table_name, out_table_name
):
pq = pytest.importorskip("pyarrow.parquet")

fname = Path(fname)
table = read_table(data_directory / "csv" / fname.name)
table = read_table(data_dir / "csv" / fname.name)

pq.write_table(table, tmp_path / fname.name)

Expand Down Expand Up @@ -233,11 +233,11 @@ def test_register_parquet(
def test_register_iterator_parquet(
con,
tmp_path,
data_directory,
data_dir,
):
pq = pytest.importorskip("pyarrow.parquet")

table = read_table(data_directory / "csv" / "functional_alltypes.csv")
table = read_table(data_dir / "csv" / "functional_alltypes.csv")

pq.write_table(table, tmp_path / "functional_alltypes.parquet")

@@ -416,18 +416,16 @@ def test_register_garbage(con, monkeypatch):
"trino",
]
)
def test_read_parquet(
con, tmp_path, data_directory, fname, in_table_name, out_table_name
):
def test_read_parquet(con, tmp_path, data_dir, fname, in_table_name, out_table_name):
pq = pytest.importorskip("pyarrow.parquet")

fname = Path(fname)
fname = Path(data_directory) / "parquet" / fname.name
fname = Path(data_dir) / "parquet" / fname.name
table = pq.read_table(fname)

pq.write_table(table, tmp_path / fname.name)

with pushd(data_directory):
with pushd(data_dir):
if con.name == "pyspark":
# pyspark doesn't respect CWD
fname = str(Path(fname).absolute())
@@ -466,8 +464,8 @@ def test_read_parquet(
"trino",
]
)
def test_read_csv(con, data_directory, fname, in_table_name, out_table_name):
with pushd(data_directory / "csv"):
def test_read_csv(con, data_dir, fname, in_table_name, out_table_name):
with pushd(data_dir / "csv"):
if con.name == "pyspark":
# pyspark doesn't respect CWD
fname = str(Path(fname).absolute())
118 changes: 47 additions & 71 deletions ibis/backends/trino/tests/conftest.py
@@ -1,18 +1,19 @@
from __future__ import annotations

import itertools
import os
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any, Iterable, Iterator

import pandas as pd
import pytest

import ibis
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.tests.base import RoundAwayFromZero, ServiceBackendTest, ServiceSpec
from ibis.backends.postgres.tests.conftest import TestConf as PostgresTestConf
from ibis.backends.tests.base import BackendTest, RoundAwayFromZero
from ibis.backends.tests.data import struct_types
from ibis.util import consume

if TYPE_CHECKING:
from pathlib import Path

TRINO_USER = os.environ.get(
'IBIS_TEST_TRINO_USER', os.environ.get('TRINO_USER', 'user')
@@ -29,56 +30,37 @@
os.environ.get('TRINO_DATABASE', 'memory'),
)

sa = pytest.importorskip("sqlalchemy")

class TrinoPostgresTestConf(PostgresTestConf):
service_name = "trino-postgres"
deps = "sqlalchemy", "psycopg2"

@classmethod
def name(cls) -> str:
return "postgres"

class TestConf(ServiceBackendTest, RoundAwayFromZero):
@property
def test_files(self) -> Iterable[Path]:
return self.data_dir.joinpath("csv").glob("*.csv")


class TestConf(BackendTest, RoundAwayFromZero):
# trino rounds half to even for double precision and half away from zero
# for numeric and decimal

returned_timestamp_unit = 's'
supports_structs = True
supports_map = True
service_name = "trino"
deps = ("sqlalchemy", "trino.sqlalchemy")

@classmethod
def service_spec(cls, data_dir: Path) -> ServiceSpec:
return ServiceSpec(
name="trino-postgres",
data_volume="/data",
files=data_dir.joinpath("csv").glob("*.csv"),
)

@staticmethod
def _load_data(data_dir: Path, script_dir: Path, **_: Any) -> None:
"""Load test data into a Trino backend instance.
Parameters
----------
data_dir
Location of test data
script_dir
Location of scripts defining schemas
"""
from ibis.backends.postgres.tests.conftest import (
IBIS_TEST_POSTGRES_DB,
PG_HOST,
PG_PASS,
PG_USER,
)
from ibis.backends.postgres.tests.conftest import TestConf as PostgresTestConf

PostgresTestConf._load_data(data_dir, script_dir, port=5433)
pgcon = ibis.postgres.connect(
host=PG_HOST,
port=5433,
user=PG_USER,
password=PG_PASS,
database=IBIS_TEST_POSTGRES_DB,
schema="public",
)

con = TestConf.connect(data_dir)
def load_data(cls, data_dir: Path, tmpdir: Path, worker_id: str, **kw: Any) -> None:
TrinoPostgresTestConf.load_data(data_dir, tmpdir, worker_id, port=5433)
return super().load_data(data_dir, tmpdir, worker_id, **kw)

@property
def ddl_script(self) -> Iterator[str]:
selects = []
for row in struct_types.abc:
if pd.isna(row):
@@ -93,37 +75,36 @@ def _load_data(data_dir: Path, script_dir: Path, **_: Any) -> None:
# mirror the existing tables except for intervals which are not supported
# and maps which we do natively in trino, because trino has more extensive
# map support
unsupported_memory_tables = {"intervals", "not_supported_intervals", "map"}
lines = []
for table in frozenset(pgcon.list_tables()) - unsupported_memory_tables:
unsupported_memory_tables = ("intervals", "not_supported_intervals", "map")
with self.connection.begin() as c:
pg_tables = c.exec_driver_sql(
f"""
SELECT table_name
FROM postgresql.information_schema.tables
WHERE table_schema = 'public'
AND table_name NOT IN {unsupported_memory_tables!r}
"""
).scalars()

for table in pg_tables:
dest = f"memory.default.{table}"
lines.append(f"DROP VIEW IF EXISTS {dest}")
lines.append(
f"CREATE VIEW {dest} AS SELECT * FROM postgresql.public.{table}"
)

lines.extend(
itertools.chain(
[
"DROP VIEW IF EXISTS struct",
f"CREATE VIEW struct AS {' UNION ALL '.join(selects)}",
],
Path(script_dir, "schema", "trino.sql").read_text().split(";"),
)
)
yield f"DROP VIEW IF EXISTS {dest}"
yield f"CREATE VIEW {dest} AS SELECT * FROM postgresql.public.{table}"

with con.begin() as c:
consume(map(c.exec_driver_sql, filter(None, map(str.strip, lines))))
yield "DROP VIEW IF EXISTS struct"
yield f"CREATE VIEW struct AS {' UNION ALL '.join(selects)}"
yield from super().ddl_script

@staticmethod
def connect(data_directory: Path):
def connect(*, tmpdir, worker_id, **kw):
return ibis.trino.connect(
host=TRINO_HOST,
port=TRINO_PORT,
user=TRINO_USER,
password=TRINO_PASS,
database=IBIS_TEST_TRINO_DB,
schema="default",
**kw,
)

def _remap_column_names(self, table_name: str) -> dict[str, str]:
@@ -142,13 +123,8 @@ def awards_players(self):


@pytest.fixture(scope='session')
def con(tmp_path_factory, data_directory, script_directory, worker_id):
return TestConf.load_data(
data_directory,
script_directory,
tmp_path_factory,
worker_id,
).connect(data_directory)
def con(tmp_path_factory, data_dir, worker_id):
return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection


@pytest.fixture(scope='module')