diff --git a/setup.cfg b/setup.cfg index 14a4d56..04e8849 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,7 +26,7 @@ project_urls = [options] packages = find: install_requires = - firebolt-sdk + firebolt-sdk>=0.8.0 sqlalchemy>=1.0.0 python_requires = >=3.6 package_dir = diff --git a/src/firebolt_db/firebolt_dialect.py b/src/firebolt_db/firebolt_dialect.py index b3dafb1..e48f24d 100644 --- a/src/firebolt_db/firebolt_dialect.py +++ b/src/firebolt_db/firebolt_dialect.py @@ -5,6 +5,7 @@ import firebolt.db as dbapi import sqlalchemy.types as sqltypes +from firebolt.client.auth import UsernamePassword from firebolt.db import Cursor from sqlalchemy.engine import Connection as AlchemyConnection from sqlalchemy.engine import ExecutionContext, default @@ -101,24 +102,22 @@ def __init__( def dbapi(cls) -> ModuleType: return dbapi - # Build firebolt-sdk compatible connection arguments. - # URL format : firebolt://username:password@host:port/db_name def create_connect_args(self, url: URL) -> Tuple[List, Dict]: + """ + Build firebolt-sdk compatible connection arguments. + URL format : firebolt://username:password@host:port/db_name + """ + parameters = dict(url.query) + # parameters are all passed as a string, we need to convert + # bool flag to boolean for SDK compatibility + token_cache_flag = bool(strtobool(parameters.pop("use_token_cache", "True"))) kwargs = { "database": url.host or None, - "username": url.username or None, - "password": url.password or None, + "auth": UsernamePassword(url.username, url.password, token_cache_flag), "engine_name": url.database, } - parameters = dict(url.query) if "account_name" in parameters: kwargs["account_name"] = parameters.pop("account_name") - if "use_token_cache" in parameters: - # parameters are all passed as a string, we need to convert it - # to boolean for SDK compatibility - kwargs["use_token_cache"] = bool( - strtobool(parameters.pop("use_token_cache")) - ) self._set_parameters = parameters # If URL override is not provided leave it to the sdk to determine the endpoint if "FIREBOLT_BASE_URL" in os.environ: diff --git a/tests/unit/test_firebolt_dialect.py b/tests/unit/test_firebolt_dialect.py index cd0fcaf..b0a6e43 100644 --- a/tests/unit/test_firebolt_dialect.py +++ b/tests/unit/test_firebolt_dialect.py @@ -38,10 +38,13 @@ def test_create_connect_args(self, dialect: FireboltDialect): with mock.patch.dict(os.environ, {"FIREBOLT_BASE_URL": "test_url"}): result_list, result_dict = dialect.create_connect_args(u) assert result_dict["engine_name"] == "test_engine_name" - assert result_dict["username"] == "test_user@email" - assert result_dict["password"] == "test_password" + assert result_dict["auth"].username == "test_user@email" + assert result_dict["auth"].password == "test_password" + assert result_dict["auth"]._use_token_cache is True assert result_dict["database"] == "test_db_name" assert result_dict["api_endpoint"] == "test_url" + assert "username" not in result_dict + assert "password" not in result_dict assert result_list == [] # No endpoint override with mock.patch.dict(os.environ, {}, clear=True): @@ -74,10 +77,9 @@ def test_create_connect_args_token_cache( ) u = url.make_url(connection_url) result_list, result_dict = dialect.create_connect_args(u) - assert ( - "use_token_cache" in result_dict - ), "use_token_cache was not parsed correctly from connection string" - assert result_dict["use_token_cache"] == expected + assert result_dict["auth"].username == "test_user@email" + assert result_dict["auth"].password == "test_password" + assert result_dict["auth"]._use_token_cache == expected assert dialect._set_parameters == {"param1": "1", "param2": "2"} def test_do_execute(