diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 982fdc5..8268684 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -56,6 +56,8 @@ jobs: env: USER_NAME: ${{ secrets.FIREBOLT_USERNAME }} PASSWORD: ${{ secrets.FIREBOLT_PASSWORD }} + SERVICE_ID: ${{ secrets.SERVICE_ID }} + SERVICE_SECRET: ${{ secrets.SERVICE_SECRET }} DATABASE_NAME: ${{ steps.setup.outputs.database_name }} ENGINE_NAME: ${{ steps.setup.outputs.engine_name }} ENGINE_URL: ${{ steps.setup.outputs.engine_url }} diff --git a/.github/workflows/python-integration-tests.yml b/.github/workflows/python-integration-tests.yml index 3080b31..aa40b8c 100644 --- a/.github/workflows/python-integration-tests.yml +++ b/.github/workflows/python-integration-tests.yml @@ -33,6 +33,8 @@ jobs: env: USER_NAME: ${{ secrets.FIREBOLT_USERNAME }} PASSWORD: ${{ secrets.FIREBOLT_PASSWORD }} + SERVICE_ID: ${{ secrets.SERVICE_ID }} + SERVICE_SECRET: ${{ secrets.SERVICE_SECRET }} DATABASE_NAME: ${{ steps.setup.outputs.database_name }} ENGINE_NAME: ${{ steps.setup.outputs.engine_name }} ENGINE_URL: ${{ steps.setup.outputs.engine_url }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 91d09c5..cf03584 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -40,6 +40,8 @@ jobs: env: USER_NAME: ${{ secrets.FIREBOLT_USERNAME }} PASSWORD: ${{ secrets.FIREBOLT_PASSWORD }} + SERVICE_ID: ${{ secrets.SERVICE_ID }} + SERVICE_SECRET: ${{ secrets.SERVICE_SECRET }} DATABASE_NAME: ${{ steps.setup.outputs.database_name }} ENGINE_NAME: ${{ steps.setup.outputs.engine_name }} ENGINE_URL: ${{ steps.setup.outputs.engine_url }} diff --git a/.legitignore b/.legitignore new file mode 100644 index 0000000..8b109ac --- /dev/null +++ b/.legitignore @@ -0,0 +1 @@ +tests/unit/** # Ignore everything in the unit test directory diff --git a/src/firebolt_db/firebolt_dialect.py b/src/firebolt_db/firebolt_dialect.py index c94270a..192259f 100644 --- a/src/firebolt_db/firebolt_dialect.py +++ b/src/firebolt_db/firebolt_dialect.py @@ -5,7 +5,7 @@ import firebolt.db as dbapi import sqlalchemy.types as sqltypes -from firebolt.client.auth import Auth, UsernamePassword +from firebolt.client.auth import Auth, ServiceAccount, UsernamePassword from firebolt.db import Cursor from sqlalchemy.engine import Connection as AlchemyConnection from sqlalchemy.engine import ExecutionContext, default @@ -111,9 +111,14 @@ def create_connect_args(self, url: URL) -> Tuple[List, Dict]: # parameters are all passed as a string, we need to convert # bool flag to boolean for SDK compatibility token_cache_flag = bool(strtobool(parameters.pop("use_token_cache", "True"))) + auth = ( + ServiceAccount(url.username, url.password, token_cache_flag) + if "@" not in url.username + else UsernamePassword(url.username, url.password, token_cache_flag) + ) kwargs: Dict[str, Union[str, Auth, Dict[str, Any], None]] = { "database": url.host or None, - "auth": UsernamePassword(url.username, url.password, token_cache_flag), + "auth": auth, "engine_name": url.database, "additional_parameters": {}, } diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index b321b36..ac1c031 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -13,6 +13,8 @@ DATABASE_NAME_ENV = "DATABASE_NAME" USERNAME_ENV = "USER_NAME" PASSWORD_ENV = "PASSWORD" +SERVICE_ID = "SERVICE_ID" +SERVICE_SECRET = "SERVICE_SECRET" def must_env(var_name: str) -> str: @@ -41,6 +43,16 @@ def password() -> str: return must_env(PASSWORD_ENV) +@fixture(scope="session") +def service_id() -> str: + return must_env(SERVICE_ID) + + +@fixture(scope="session") +def service_secret() -> str: + return must_env(SERVICE_SECRET) + + @fixture(scope="session") def engine( username: str, password: str, database_name: str, engine_name: str @@ -50,12 +62,27 @@ def engine( ) +@fixture(scope="session") +def engine_service_account( + service_id: str, service_secret: str, database_name: str, engine_name: str +) -> Engine: + return create_engine( + f"firebolt://{service_id}:{service_secret}@{database_name}/{engine_name}" + ) + + @fixture(scope="session") def connection(engine: Engine) -> Connection: with engine.connect() as c: yield c +@fixture(scope="session") +def connection_service_account(engine_service_account: Engine) -> Connection: + with engine_service_account.connect() as c: + yield c + + @fixture(scope="session") def event_loop(): loop = asyncio.get_event_loop() diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 5af0198..a43f67d 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -111,3 +111,7 @@ def test_get_columns(self, engine: Engine, fact_table_name: str): assert row_keys[1] == "type" assert row_keys[2] == "nullable" assert row_keys[3] == "default" + + def test_service_account_connect(self, connection_service_account: Connection): + result = connection_service_account.execute("SELECT 1") + assert result.fetchall() == [(1,)] diff --git a/tests/unit/test_firebolt_dialect.py b/tests/unit/test_firebolt_dialect.py index a53c71d..e08f89c 100644 --- a/tests/unit/test_firebolt_dialect.py +++ b/tests/unit/test_firebolt_dialect.py @@ -30,9 +30,25 @@ def test_create_dialect(self, dialect: FireboltDialect): assert isinstance(dialect.type_compiler, FireboltTypeCompiler) assert dialect.context == {} + def test_create_connect_args_service_account(self, dialect: FireboltDialect): + u = url.make_url( + "test_engine://test-sa-user-key:test_password@test_db_name/test_engine_name" + ) + with mock.patch.dict(os.environ, {"FIREBOLT_BASE_URL": "test_url"}): + result_list, result_dict = dialect.create_connect_args(u) + assert result_dict["engine_name"] == "test_engine_name" + assert result_dict["auth"].client_id == "test-sa-user-key" + assert result_dict["auth"].client_secret == "test_password" + assert result_dict["auth"]._use_token_cache is True + assert result_dict["database"] == "test_db_name" + assert result_dict["api_endpoint"] == "test_url" + assert "username" not in result_dict + assert "password" not in result_dict + assert result_list == [] + def test_create_connect_args(self, dialect: FireboltDialect): connection_url = ( - "test_engine://test_user@email:test_password@test_db_name/test_engine_name" + "test_engine://test_user@email:test_password@test_db_name/test_engine_name?" ) u = url.make_url(connection_url) with mock.patch.dict(os.environ, {"FIREBOLT_BASE_URL": "test_url"}):