Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ project_urls =
[options]
packages = find:
install_requires =
firebolt-sdk
firebolt-sdk>=0.8.0
sqlalchemy>=1.0.0
python_requires = >=3.6
package_dir =
Expand Down
21 changes: 10 additions & 11 deletions src/firebolt_db/firebolt_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import firebolt.db as dbapi
import sqlalchemy.types as sqltypes
from firebolt.client.auth import UsernamePassword
from firebolt.db import Cursor
from sqlalchemy.engine import Connection as AlchemyConnection
from sqlalchemy.engine import ExecutionContext, default
Expand Down Expand Up @@ -101,24 +102,22 @@ def __init__(
def dbapi(cls) -> ModuleType:
return dbapi

# Build firebolt-sdk compatible connection arguments.
# URL format : firebolt://username:password@host:port/db_name
def create_connect_args(self, url: URL) -> Tuple[List, Dict]:
"""
Build firebolt-sdk compatible connection arguments.
URL format : firebolt://username:password@host:port/db_name
"""
parameters = dict(url.query)
# parameters are all passed as a string, we need to convert
# bool flag to boolean for SDK compatibility
token_cache_flag = bool(strtobool(parameters.pop("use_token_cache", "True")))
kwargs = {
"database": url.host or None,
"username": url.username or None,
"password": url.password or None,
"auth": UsernamePassword(url.username, url.password, token_cache_flag),
"engine_name": url.database,
}
parameters = dict(url.query)
if "account_name" in parameters:
kwargs["account_name"] = parameters.pop("account_name")
if "use_token_cache" in parameters:
# parameters are all passed as a string, we need to convert it
# to boolean for SDK compatibility
kwargs["use_token_cache"] = bool(
strtobool(parameters.pop("use_token_cache"))
)
self._set_parameters = parameters
# If URL override is not provided leave it to the sdk to determine the endpoint
if "FIREBOLT_BASE_URL" in os.environ:
Expand Down
14 changes: 8 additions & 6 deletions tests/unit/test_firebolt_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,13 @@ def test_create_connect_args(self, dialect: FireboltDialect):
with mock.patch.dict(os.environ, {"FIREBOLT_BASE_URL": "test_url"}):
result_list, result_dict = dialect.create_connect_args(u)
assert result_dict["engine_name"] == "test_engine_name"
assert result_dict["username"] == "test_user@email"
assert result_dict["password"] == "test_password"
assert result_dict["auth"].username == "test_user@email"
assert result_dict["auth"].password == "test_password"
assert result_dict["auth"]._use_token_cache is True
assert result_dict["database"] == "test_db_name"
assert result_dict["api_endpoint"] == "test_url"
assert "username" not in result_dict
assert "password" not in result_dict
assert result_list == []
# No endpoint override
with mock.patch.dict(os.environ, {}, clear=True):
Expand Down Expand Up @@ -74,10 +77,9 @@ def test_create_connect_args_token_cache(
)
u = url.make_url(connection_url)
result_list, result_dict = dialect.create_connect_args(u)
assert (
"use_token_cache" in result_dict
), "use_token_cache was not parsed correctly from connection string"
assert result_dict["use_token_cache"] == expected
assert result_dict["auth"].username == "test_user@email"
assert result_dict["auth"].password == "test_password"
assert result_dict["auth"]._use_token_cache == expected
assert dialect._set_parameters == {"param1": "1", "param2": "2"}

def test_do_execute(
Expand Down