diff --git a/deepnote_toolkit/sql/duckdb_sql.py b/deepnote_toolkit/sql/duckdb_sql.py index bf90164..60a8493 100644 --- a/deepnote_toolkit/sql/duckdb_sql.py +++ b/deepnote_toolkit/sql/duckdb_sql.py @@ -1,8 +1,11 @@ import sys import duckdb +from duckdb_extensions import import_extension from packaging.version import Version +from deepnote_toolkit.logging import LoggerManager + _DEEPNOTE_DUCKDB_CONNECTION = None _DEFAULT_DUCKDB_SAMPLE_SIZE = 20_000 @@ -40,16 +43,32 @@ def _get_duckdb_connection(): duckdb.Connection: A connection to the DuckDB database. """ global _DEEPNOTE_DUCKDB_CONNECTION + logger = LoggerManager().get_logger() if not _DEEPNOTE_DUCKDB_CONNECTION: _DEEPNOTE_DUCKDB_CONNECTION = duckdb.connect( database=":memory:", read_only=False ) + # DuckDB extensions are loaded from included wheels to prevent loading them + # from the internet on every notebook start + # # Install and load the spatial extension. Primary use case: reading xlsx files # e.g. SELECT * FROM st_read('excel.xlsx') - _DEEPNOTE_DUCKDB_CONNECTION.execute("install spatial;") - _DEEPNOTE_DUCKDB_CONNECTION.execute("load spatial;") + # there is also official excel extension, which mentions that Excel support from spatial extension + # may be removed in the future (see: https://duckdb.org/docs/stable/core_extensions/excel) + for extension_name in ["spatial", "excel"]: + try: + import_extension( + name=extension_name, + force_install=True, + con=_DEEPNOTE_DUCKDB_CONNECTION, + ) + _DEEPNOTE_DUCKDB_CONNECTION.load_extension(extension_name) + except Exception as e: + # Extensions are optional and connection still works, users are able to load + # them manually if needed (pulling them from internet in this case as fallback) + logger.warning(f"Failed to load DuckDB {extension_name} extension: {e}") _set_sample_size(_DEEPNOTE_DUCKDB_CONNECTION, _DEFAULT_DUCKDB_SAMPLE_SIZE) diff --git a/poetry.lock b/poetry.lock index c7b6739..bdf8865 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1655,6 +1655,56 @@ files = [ [package.extras] all = ["adbc-driver-manager", "fsspec", "ipython", "numpy", "pandas", "pyarrow"] +[[package]] +name = "duckdb-extension-excel" +version = "1.4.1" +description = "Duckdb excel extension" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "duckdb_extension_excel-1.4.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:16521c719b0a4d8637c84c965a02148e339f2f2aed616e32aad7e4b303ac57b4"}, + {file = "duckdb_extension_excel-1.4.1-py3-none-macosx_11_0_x86_64.whl", hash = "sha256:06881c3eecadfb48adf6a5cbb039c0ea384d9a9532dc3bc1b8391fef04f44905"}, + {file = "duckdb_extension_excel-1.4.1-py3-none-manylinux2014_aarch64.whl", hash = "sha256:e6a969dea89815bd93af5fc92534091b694f020d39aa930e91de903cbc04dd05"}, + {file = "duckdb_extension_excel-1.4.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9e8e90b087b07b1aa2fc001b15ba4094104232fb3ff04774105e2a2b79db1695"}, + {file = "duckdb_extension_excel-1.4.1-py3-none-win_amd64.whl", hash = "sha256:9e3f14ac386839a4f6dfe8cf33b660ca6e506cc0f4f3ac6bc42a3ef6ff4436f1"}, +] + +[package.dependencies] +duckdb = "1.4.1" + +[[package]] +name = "duckdb-extension-spatial" +version = "1.4.1" +description = "Duckdb spatial extension" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "duckdb_extension_spatial-1.4.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:911e8c0e2816584beb34ff433853745cba03266936c13a587a7c471fb5e82e01"}, + {file = "duckdb_extension_spatial-1.4.1-py3-none-macosx_11_0_x86_64.whl", hash = "sha256:2a3710130cbd545168a1397e005457ba5e0e24b3fe89c576db2acfb88ac6aa6a"}, + {file = "duckdb_extension_spatial-1.4.1-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d32e5745bfeb839e288e9ee741d6f16c91f99ec0514fcc42715a0365934baf3c"}, + {file = "duckdb_extension_spatial-1.4.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:997757a23f3bd8b15a26d0093a6bf83c291cce16d78d3a26edf9443df3fb462e"}, + {file = "duckdb_extension_spatial-1.4.1-py3-none-win_amd64.whl", hash = "sha256:ac19e649d78a31f74c19817023e8332930b9d28c137455e0aeaa6f044a340dc0"}, +] + +[package.dependencies] +duckdb = "1.4.1" + +[[package]] +name = "duckdb-extensions" +version = "1.4.1" +description = "DuckDB extensions as python package" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "duckdb_extensions-1.4.1-py3-none-any.whl", hash = "sha256:153227ced6e3e7ae3a963611bc90299050bd3437c6eed36edbe25a9173add12a"}, +] + +[package.dependencies] +duckdb = "1.4.1" + [[package]] name = "dunamai" version = "1.25.0" @@ -7355,4 +7405,4 @@ server = ["deepnote-python-lsp-server", "jupyter-resource-usage", "jupyter-serve [metadata] lock-version = "2.1" python-versions = ">=3.9.0,<3.14,!=3.9.7" -content-hash = "1f130b5dd4d909faaca0ae2874ea688283574fd8a658bdb8ab75b32a89cb45ac" +content-hash = "d30f63bb45bb945b636bf86f2a91f659db8f0fcd942b760aaa77cb73da685309" diff --git a/pyproject.toml b/pyproject.toml index ce6e235..70e62fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,9 @@ dependencies = [ "duckdb>=1.1.0,<2.0.0; python_version < '3.12'", "duckdb>=1.1.0,<2.0.0; python_version >= '3.12'", "duckdb>=1.4.1,<2.0.0; python_version >= '3.13'", + "duckdb-extensions>=1.1.0,<2.0.0", # bake in as dependency to not pull extensions from the internet on every notebook start + "duckdb-extension-spatial>=1.1.0,<2.0.0", + "duckdb-extension-excel>=1.1.0,<2.0.0", "google-cloud-bigquery-storage==2.16.2; python_version < '3.13'", "google-cloud-bigquery-storage>=2.33.1,<3; python_version>='3.13'", diff --git a/tests/unit/test_duckdb_sql.py b/tests/unit/test_duckdb_sql.py new file mode 100644 index 0000000..556a273 --- /dev/null +++ b/tests/unit/test_duckdb_sql.py @@ -0,0 +1,137 @@ +from contextlib import contextmanager +from unittest import mock + +import pandas as pd +import pytest + +from deepnote_toolkit.sql.duckdb_sql import ( + _get_duckdb_connection, + _set_sample_size, + _set_scan_all_frames, +) + + +@contextmanager +def fresh_duckdb_connection(): + import deepnote_toolkit.sql.duckdb_sql as duckdb_sql_module + + duckdb_sql_module._DEEPNOTE_DUCKDB_CONNECTION = None + conn = _get_duckdb_connection() + + try: + yield conn + finally: + conn.close() + duckdb_sql_module._DEEPNOTE_DUCKDB_CONNECTION = None + + +@pytest.fixture(scope="function") +def duckdb_connection(): + with fresh_duckdb_connection() as conn: + yield conn + + +@pytest.mark.parametrize("extension_name", ["spatial", "excel"]) +def test_extension_installed_and_loadable(duckdb_connection, extension_name): + result = duckdb_connection.execute( + f"SELECT installed FROM duckdb_extensions() WHERE extension_name = '{extension_name}'" + ).fetchone() + + assert ( + result is not None + ), f"{extension_name} extension should be found in duckdb_extensions()" + assert result[0] is True, f"{extension_name} extension should be installed" + + loaded_result = duckdb_connection.execute( + f"SELECT loaded FROM duckdb_extensions() WHERE extension_name = '{extension_name}'" + ).fetchone() + assert loaded_result[0] is True, f"{extension_name} extension should be loaded" + + +def test_connection_singleton_pattern(): + conn1 = _get_duckdb_connection() + conn2 = _get_duckdb_connection() + + assert conn1 is conn2, "Connection should be a singleton" + + +def test_set_sample_size(duckdb_connection): + _set_sample_size(duckdb_connection, 50000) + result = duckdb_connection.execute( + "SELECT value FROM duckdb_settings() WHERE name = 'pandas_analyze_sample'" + ).fetchone() + assert int(result[0]) == 50000 + + +def test_set_scan_all_frames(duckdb_connection): + _set_scan_all_frames(duckdb_connection, False) + result = duckdb_connection.execute( + "SELECT value FROM duckdb_settings() WHERE name = 'python_scan_all_frames'" + ).fetchone() + assert result[0] == "false" + + _set_scan_all_frames(duckdb_connection, True) + result = duckdb_connection.execute( + "SELECT value FROM duckdb_settings() WHERE name = 'python_scan_all_frames'" + ).fetchone() + assert result[0] == "true" + + +@mock.patch("deepnote_toolkit.sql.duckdb_sql.import_extension") +def test_connection_returns_successfully_when_import_extension_fails( + mock_import_extension, +): + mock_import_extension.side_effect = Exception("Failed to import extension") + + with fresh_duckdb_connection() as conn: + assert conn is not None + result = conn.execute( + "SELECT extension_name, loaded FROM duckdb_extensions()" + ).df() + assert result is not None + # check that spatial and excel extensions are not loaded as import extension failed + result = result[result["extension_name"].isin(["spatial", "excel"])] + assert all(result["loaded"]) is False + + +@mock.patch("duckdb.DuckDBPyConnection.load_extension") +def test_connection_returns_successfully_when_load_extension_fails(mock_load_extension): + mock_load_extension.side_effect = Exception("Failed to load extension") + + with fresh_duckdb_connection() as conn: + assert conn is not None + result = conn.execute( + "SELECT extension_name, loaded FROM duckdb_extensions()" + ).df() + assert result is not None + # check that spatial and excel extensions are not loaded as import extension failed + result = result[result["extension_name"].isin(["spatial", "excel"])] + assert all(result["loaded"]) is False + + +def test_excel_extension_roundtrip(duckdb_connection, tmp_path): + test_data = pd.DataFrame( + { + "id": [1, 2, 3], + "name": ["Alice", "Bob", "Charlie"], + "score": [95.5, 87.3, 91.2], + } + ) + duckdb_connection.register("test_table", test_data) + excel_path = tmp_path / "test_data.xlsx" + duckdb_connection.execute( + f"COPY test_table TO '{excel_path}' WITH (FORMAT xlsx, HEADER true)" + ) + duckdb_connection.unregister("test_table") + + assert excel_path.exists(), "Excel file should be created" + + # read with spatial extension + result = duckdb_connection.execute(f"SELECT * FROM st_read('{excel_path}')").df() + diff = test_data.compare(result) + assert diff.empty, "Data should be the same" + + # read with excel extension + result = duckdb_connection.execute(f"SELECT * FROM read_xlsx('{excel_path}')").df() + diff = test_data.compare(result) + assert diff.empty, "Data should be the same"