Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions deepnote_toolkit/sql/duckdb_sql.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import sys

import duckdb
from duckdb_extensions import import_extension
from packaging.version import Version

from deepnote_toolkit.logging import LoggerManager

# Module-level cache for the DuckDB connection; it is created on first use inside
# _get_duckdb_connection().
_DEEPNOTE_DUCKDB_CONNECTION = None
# Sample size passed to _set_sample_size() when a new connection is created.
_DEFAULT_DUCKDB_SAMPLE_SIZE = 20_000

Expand Down Expand Up @@ -40,16 +43,32 @@ def _get_duckdb_connection():
duckdb.Connection: A connection to the DuckDB database.
"""
global _DEEPNOTE_DUCKDB_CONNECTION
logger = LoggerManager().get_logger()

if not _DEEPNOTE_DUCKDB_CONNECTION:
_DEEPNOTE_DUCKDB_CONNECTION = duckdb.connect(
database=":memory:", read_only=False
)

# DuckDB extensions are loaded from included wheels to prevent loading them
# from the internet on every notebook start
#
# Install and load the spatial extension. Primary use case: reading xlsx files
# e.g. SELECT * FROM st_read('excel.xlsx')
_DEEPNOTE_DUCKDB_CONNECTION.execute("install spatial;")
_DEEPNOTE_DUCKDB_CONNECTION.execute("load spatial;")
# There is also an official excel extension; its documentation mentions that Excel support in the
# spatial extension may be removed in the future (see: https://duckdb.org/docs/stable/core_extensions/excel)
for extension_name in ["spatial", "excel"]:
try:
import_extension(
name=extension_name,
force_install=True,
con=_DEEPNOTE_DUCKDB_CONNECTION,
)
_DEEPNOTE_DUCKDB_CONNECTION.load_extension(extension_name)
except Exception as e:
# Extensions are optional and connection still works, users are able to load
# them manually if needed (pulling them from internet in this case as fallback)
logger.warning(f"Failed to load DuckDB {extension_name} extension: {e}")

_set_sample_size(_DEEPNOTE_DUCKDB_CONNECTION, _DEFAULT_DUCKDB_SAMPLE_SIZE)

Expand Down
52 changes: 51 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ dependencies = [
"duckdb>=1.1.0,<2.0.0; python_version < '3.12'",
"duckdb>=1.1.0,<2.0.0; python_version >= '3.12'",
"duckdb>=1.4.1,<2.0.0; python_version >= '3.13'",
"duckdb-extensions>=1.1.0,<2.0.0", # bake in as dependency to not pull extensions from the internet on every notebook start
"duckdb-extension-spatial>=1.1.0,<2.0.0",
"duckdb-extension-excel>=1.1.0,<2.0.0",
"google-cloud-bigquery-storage==2.16.2; python_version < '3.13'",
"google-cloud-bigquery-storage>=2.33.1,<3; python_version>='3.13'",

Expand Down
137 changes: 137 additions & 0 deletions tests/unit/test_duckdb_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from contextlib import contextmanager
from unittest import mock

import pandas as pd
import pytest

from deepnote_toolkit.sql.duckdb_sql import (
_get_duckdb_connection,
_set_sample_size,
_set_scan_all_frames,
)


@contextmanager
def fresh_duckdb_connection():
    """Yield a newly created DuckDB connection instead of a cached one.

    The module-level cached connection is cleared before calling
    ``_get_duckdb_connection()`` so that a new connection is built. On exit the
    connection is closed and the cache is cleared again.
    """
    import deepnote_toolkit.sql.duckdb_sql as sql_module

    # Drop whatever is cached so the getter has to create a connection.
    sql_module._DEEPNOTE_DUCKDB_CONNECTION = None
    connection = _get_duckdb_connection()

    try:
        yield connection
    finally:
        connection.close()
        # Do not leave a closed connection behind in the module cache.
        sql_module._DEEPNOTE_DUCKDB_CONNECTION = None


@pytest.fixture(scope="function")
def duckdb_connection():
    """Provide each test with its own connection from ``fresh_duckdb_connection``."""
    with fresh_duckdb_connection() as fresh_conn:
        yield fresh_conn


@pytest.mark.parametrize("extension_name", ["spatial", "excel"])
def test_extension_installed_and_loadable(duckdb_connection, extension_name):
    """Check that the extension is listed, installed and loaded on a fresh connection."""
    installed_row = duckdb_connection.execute(
        f"SELECT installed FROM duckdb_extensions() WHERE extension_name = '{extension_name}'"
    ).fetchone()

    # The extension has to be listed at all before its flags can be inspected.
    assert (
        installed_row is not None
    ), f"{extension_name} extension should be found in duckdb_extensions()"
    installed_flag = installed_row[0]
    assert installed_flag is True, f"{extension_name} extension should be installed"

    loaded_row = duckdb_connection.execute(
        f"SELECT loaded FROM duckdb_extensions() WHERE extension_name = '{extension_name}'"
    ).fetchone()
    loaded_flag = loaded_row[0]
    assert loaded_flag is True, f"{extension_name} extension should be loaded"


def test_connection_singleton_pattern():
    """Two consecutive calls to the getter must return the very same object."""
    first_connection = _get_duckdb_connection()
    second_connection = _get_duckdb_connection()

    assert first_connection is second_connection, "Connection should be a singleton"


def test_set_sample_size(duckdb_connection):
    """After ``_set_sample_size`` the ``pandas_analyze_sample`` setting holds the new value."""
    _set_sample_size(duckdb_connection, 50000)

    setting_row = duckdb_connection.execute(
        "SELECT value FROM duckdb_settings() WHERE name = 'pandas_analyze_sample'"
    ).fetchone()
    reported_size = int(setting_row[0])

    assert reported_size == 50000


def test_set_scan_all_frames(duckdb_connection):
    """Toggling ``_set_scan_all_frames`` is reflected in ``python_scan_all_frames``."""
    setting_query = (
        "SELECT value FROM duckdb_settings() WHERE name = 'python_scan_all_frames'"
    )

    # Disable first, then enable, checking the reported setting after each change.
    for enabled, expected_value in ((False, "false"), (True, "true")):
        _set_scan_all_frames(duckdb_connection, enabled)
        setting_row = duckdb_connection.execute(setting_query).fetchone()
        assert setting_row[0] == expected_value


@mock.patch("deepnote_toolkit.sql.duckdb_sql.import_extension")
def test_connection_returns_successfully_when_import_extension_fails(
    mock_import_extension,
):
    """Connection creation must survive a failing ``import_extension`` call.

    With ``import_extension`` patched to raise, a fresh connection must still be
    returned and usable, and neither the spatial nor the excel extension may be
    reported as loaded.
    """
    mock_import_extension.side_effect = Exception("Failed to import extension")

    with fresh_duckdb_connection() as conn:
        assert conn is not None
        result = conn.execute(
            "SELECT extension_name, loaded FROM duckdb_extensions()"
        ).df()
        assert result is not None
        # check that spatial and excel extensions are not loaded as import extension failed
        result = result[result["extension_name"].isin(["spatial", "excel"])]
        # Guard against a vacuous pass: with no matching rows there is nothing to verify.
        assert not result.empty
        # `all(...) is False` would also pass when only one of the extensions is
        # unloaded; require that none of them is loaded.
        assert not result["loaded"].any()


@mock.patch("duckdb.DuckDBPyConnection.load_extension")
def test_connection_returns_successfully_when_load_extension_fails(mock_load_extension):
    """Connection creation must survive a failing ``load_extension`` call.

    With ``DuckDBPyConnection.load_extension`` patched to raise, a fresh
    connection must still be returned and usable, and neither the spatial nor
    the excel extension may be reported as loaded.
    """
    mock_load_extension.side_effect = Exception("Failed to load extension")

    with fresh_duckdb_connection() as conn:
        assert conn is not None
        result = conn.execute(
            "SELECT extension_name, loaded FROM duckdb_extensions()"
        ).df()
        assert result is not None
        # check that spatial and excel extensions are not loaded as load extension failed
        result = result[result["extension_name"].isin(["spatial", "excel"])]
        # Guard against a vacuous pass: with no matching rows there is nothing to verify.
        assert not result.empty
        # `all(...) is False` would also pass when only one of the extensions is
        # unloaded; require that none of them is loaded.
        assert not result["loaded"].any()


def test_excel_extension_roundtrip(duckdb_connection, tmp_path):
    """Write a DataFrame to xlsx through DuckDB and read it back unchanged.

    The file is produced with ``COPY ... (FORMAT xlsx)`` and then read once with
    ``st_read`` and once with ``read_xlsx``; both results must match the
    original frame.
    """
    test_data = pd.DataFrame(
        {
            "id": [1, 2, 3],
            "name": ["Alice", "Bob", "Charlie"],
            "score": [95.5, 87.3, 91.2],
        }
    )
    excel_path = tmp_path / "test_data.xlsx"

    # Expose the frame to DuckDB only for the duration of the export.
    duckdb_connection.register("test_table", test_data)
    duckdb_connection.execute(
        f"COPY test_table TO '{excel_path}' WITH (FORMAT xlsx, HEADER true)"
    )
    duckdb_connection.unregister("test_table")

    assert excel_path.exists(), "Excel file should be created"

    read_queries = (
        # read with spatial extension
        f"SELECT * FROM st_read('{excel_path}')",
        # read with excel extension
        f"SELECT * FROM read_xlsx('{excel_path}')",
    )
    for read_query in read_queries:
        loaded_frame = duckdb_connection.execute(read_query).df()
        differences = test_data.compare(loaded_frame)
        assert differences.empty, "Data should be the same"