Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ celerybeat.pid

# Environments
.env
.env.*
.venv
env/
venv/
Expand Down
3 changes: 3 additions & 0 deletions deepnote_toolkit/sql/jinjasql_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@ def render_jinja_sql_template(template, param_style=None):
Args:
template (str): The Jinja SQL template to render.
param_style (str, optional): The parameter style to use. Defaults to "pyformat".
Common styles: "qmark" (?), "format" (%s), "pyformat" (%(name)s)

Returns:
str: The rendered SQL query.
"""

escaped_template = _escape_jinja_template(template)

# Default to pyformat for backwards compatibility
# Note: Some databases like Trino require "qmark" or "format" style
jinja_sql = JinjaSql(
param_style=param_style if param_style is not None else "pyformat"
)
Expand Down
17 changes: 16 additions & 1 deletion deepnote_toolkit/sql/sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,16 @@ class ExecuteSqlError(Exception):
del sql_alchemy_dict["params"]["snowflake_private_key_passphrase"]

param_style = sql_alchemy_dict.get("param_style")

# Auto-detect param_style for databases that don't support pyformat default
if param_style is None:
url_obj = make_url(sql_alchemy_dict["url"])
# Mapping of SQLAlchemy dialect names to their required param_style
dialect_param_styles = {
"trino": "qmark", # Trino requires ? placeholders with list/tuple params
}
param_style = dialect_param_styles.get(url_obj.drivername)

skip_template_render = re.search(
"^snowflake.*host=.*.proxy.cloud.getdbt.com", sql_alchemy_dict["url"]
)
Expand Down Expand Up @@ -425,10 +435,15 @@ def _execute_sql_on_engine(engine, query, bind_params):
connection.connection if needs_raw_connection else connection
)

# pandas.read_sql_query expects params as tuple (not list) for qmark/format style
params_for_pandas = (
tuple(bind_params) if isinstance(bind_params, list) else bind_params
)

return pd.read_sql_query(
query,
con=connection_for_pandas,
params=bind_params,
params=params_for_pandas,
coerce_float=coerce_float,
)
except ResourceClosedError:
Expand Down
17 changes: 16 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ dev = [
"poetry-dynamic-versioning>=1.4.0,<2.0.0",
"twine>=6.1.0,<7.0.0",
"codespell>=2.3.0,<3.0.0",
"pytest-subtests>=0.15.0,<0.16.0"
"pytest-subtests>=0.15.0,<0.16.0",
"python-dotenv>=1.2.1,<2.0.0"
]
license-check = [
# Dependencies needed for license checking that aren't in main production dependencies
Expand Down
227 changes: 227 additions & 0 deletions tests/integration/test_trino.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
import json
import os
from contextlib import contextmanager
from pathlib import Path
from unittest import mock
from urllib.parse import quote

import pandas as pd
import pytest
from dotenv import load_dotenv
from trino import dbapi
from trino.auth import BasicAuthentication

from deepnote_toolkit import env as dnenv
from deepnote_toolkit.sql.sql_execution import execute_sql


@contextmanager
def use_trino_sql_connection(connection_json, env_var_name="TEST_TRINO_CONNECTION"):
    """Expose a connection JSON through the toolkit env helper for a block.

    The value is registered with ``dnenv.set_env`` on entry and removed with
    ``dnenv.unset_env`` on exit, including when the body raises.

    Args:
        connection_json (str): Serialized connection definition to register.
        env_var_name (str, optional): Name to register the value under.
            Defaults to "TEST_TRINO_CONNECTION".

    Yields:
        str: ``env_var_name``, so callers can hand it straight to
        ``execute_sql``.
    """
    dnenv.set_env(env_var_name, connection_json)
    try:
        yield env_var_name
    finally:
        # Always clean up so one test's connection never leaks into the next.
        dnenv.unset_env(env_var_name)


@pytest.fixture(scope="module")
def trino_credentials():
    """Collect Trino connection settings from the environment.

    If a ``.env`` file exists three directories above this test module it is
    loaded first. Variables that are unset *or set to an empty string* (for
    example a ``TRINO_PORT=`` placeholder line in ``.env``) fall back to
    their defaults.

    Returns:
        dict: Keys ``host``, ``port`` (int), ``user``, ``password`` (``None``
        when not configured), ``catalog``, ``schema`` and ``http_scheme``.

    Raises:
        ValueError: If ``TRINO_PORT`` is set to something that is not an
            integer.

    The whole module is skipped when ``TRINO_HOST`` or ``TRINO_USER`` is
    missing.
    """
    env_path = Path(__file__).parent.parent.parent / ".env"

    if env_path.exists():
        load_dotenv(env_path)

    def _setting(name, default=None):
        # os.getenv(name, default) only uses the default when the variable is
        # absent; an empty value would otherwise slip through as "".
        return os.getenv(name) or default

    host = _setting("TRINO_HOST")
    user = _setting("TRINO_USER")

    if not host or not user:
        pytest.skip(
            "Trino credentials not found. "
            "Please set TRINO_HOST and TRINO_USER in .env file"
        )

    raw_port = _setting("TRINO_PORT", "8080")
    try:
        port = int(raw_port)
    except ValueError as err:
        # Name the offending variable instead of surfacing a bare int() error.
        raise ValueError(
            f"TRINO_PORT must be an integer, got {raw_port!r}"
        ) from err

    return {
        "host": host,
        "port": port,
        "user": user,
        "password": _setting("TRINO_PASSWORD"),
        "catalog": _setting("TRINO_CATALOG", "system"),
        "schema": _setting("TRINO_SCHEMA", "runtime"),
        "http_scheme": _setting("TRINO_HTTP_SCHEME", "https"),
    }


@pytest.fixture(scope="module")
def trino_connection(trino_credentials):
    """Yield a DB-API connection to Trino built from ``trino_credentials``.

    Basic authentication is attached only when a password is configured.
    The connection is closed when the module's tests are done.
    """
    creds = trino_credentials

    # Without a password the client connects unauthenticated (auth=None).
    auth = (
        BasicAuthentication(creds["user"], creds["password"])
        if creds["password"]
        else None
    )

    # These settings map one-to-one onto dbapi.connect keyword arguments.
    passthrough = ("host", "port", "user", "http_scheme", "catalog", "schema")
    connect_kwargs = {key: creds[key] for key in passthrough}

    conn = dbapi.connect(auth=auth, **connect_kwargs)
    try:
        yield conn
    finally:
        conn.close()


class TestTrinoConnection:
    """Test Trino database connection."""

    def test_connection_established(self, trino_connection):
        """Test that connection to Trino is established."""
        cursor = trino_connection.cursor()
        try:
            cursor.execute("SELECT 1")
            result = cursor.fetchone()

            assert result is not None
            assert result[0] == 1
        finally:
            # The connection fixture is module-scoped, so close the cursor
            # even when the query or an assertion fails.
            cursor.close()

    def test_show_catalogs(self, trino_connection):
        """Test listing available catalogs."""
        cursor = trino_connection.cursor()
        try:
            cursor.execute("SHOW CATALOGS")
            catalogs = cursor.fetchall()

            assert len(catalogs) > 0
            assert any("system" in str(catalog) for catalog in catalogs)
        finally:
            # Same reasoning as above: never leave a cursor open on failure.
            cursor.close()


@pytest.fixture
def trino_toolkit_connection(trino_credentials):
    """Register a Trino connection JSON and yield the variable name holding it.

    The JSON carries an explicit ``param_style`` of ``qmark``; the variable is
    removed again once the test finishes.
    """
    creds = trino_credentials

    # Percent-encode every reserved character so credentials cannot break
    # the URL structure.
    userinfo = quote(creds["user"], safe="")
    if creds["password"]:
        userinfo = f"{userinfo}:{quote(creds['password'], safe='')}"

    location = f"{creds['host']}:{creds['port']}"
    database = f"{creds['catalog']}/{creds['schema']}"
    connection_url = f"trino://{userinfo}@{location}/{database}"

    # Trino uses the `qmark` paramstyle (`?` placeholders with list/tuple
    # params) rather than pyformat, which is the toolkit default.
    payload = {
        "url": connection_url,
        "params": {},
        "param_style": "qmark",
    }

    with use_trino_sql_connection(json.dumps(payload)) as env_var_name:
        yield env_var_name


class TestTrinoWithDeepnoteToolkit:
    """Exercise Trino through the Toolkit's ``execute_sql`` entry point."""

    # Dotted path of the variable lookup that the tests replace.
    _VARIABLE_LOOKUP = "deepnote_toolkit.sql.jinjasql_utils._get_variable_value"

    @staticmethod
    def _build_connection_url(credentials):
        """Return a ``trino://`` URL with percent-encoded user info."""
        userinfo = quote(credentials["user"], safe="")
        if credentials["password"]:
            userinfo = f"{userinfo}:{quote(credentials['password'], safe='')}"
        return (
            f"trino://{userinfo}"
            f"@{credentials['host']}:{credentials['port']}"
            f"/{credentials['catalog']}/{credentials['schema']}"
        )

    def test_execute_sql_simple_query(self, trino_toolkit_connection):
        """A literal query comes back as a one-row DataFrame."""
        frame = execute_sql(
            template="SELECT 1 as test_value",
            sql_alchemy_json_env_var=trino_toolkit_connection,
        )

        assert isinstance(frame, pd.DataFrame)
        assert len(frame) == 1
        assert "test_value" in frame.columns
        assert frame["test_value"].iloc[0] == 1

    def test_execute_sql_with_jinja_template(self, trino_toolkit_connection):
        """Jinja variables are bound as query parameters and round-trip."""
        test_string = "test string"
        test_number = 123
        variables = {
            "test_string_var": test_string,
            "test_number_var": test_number,
        }

        # An unknown variable name raises KeyError, which keeps accidental
        # lookups visible instead of silently returning a placeholder.
        patched_lookup = mock.patch(
            self._VARIABLE_LOOKUP,
            side_effect=lambda variable_name: variables[variable_name],
        )

        with patched_lookup:
            frame = execute_sql(
                template="SELECT {{test_string_var}} as message, {{test_number_var}} as number",
                sql_alchemy_json_env_var=trino_toolkit_connection,
            )

        assert isinstance(frame, pd.DataFrame)
        assert len(frame) == 1
        assert "message" in frame.columns
        assert "number" in frame.columns
        assert frame["message"].iloc[0] == test_string
        assert frame["number"].iloc[0] == test_number

    def test_execute_sql_with_autodetection(self, trino_credentials):
        """
        Test execute_sql with auto-detection of param_style
        (regression reported in BLU-5135)

        This simulates the real-world scenario where the backend provides a
        connection JSON without explicit param_style, and Toolkit must
        auto-detect it.
        """
        # Deliberately NO param_style key - it should auto-detect to `qmark`
        # for Trino.
        payload = {
            "url": self._build_connection_url(trino_credentials),
            "params": {},
        }
        connection_json = json.dumps(payload)

        test_value = "test value"

        with use_trino_sql_connection(
            connection_json, "TEST_TRINO_AUTODETECT"
        ) as env_var_name:
            with mock.patch(self._VARIABLE_LOOKUP, return_value=test_value):
                frame = execute_sql(
                    template="SELECT {{test_var}} as detected",
                    sql_alchemy_json_env_var=env_var_name,
                )

        assert isinstance(frame, pd.DataFrame)
        assert len(frame) == 1
        assert "detected" in frame.columns
        assert frame["detected"].iloc[0] == test_value
Loading