From 05defb6e84bfa08d91049badf8ba61697df4b2db Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Fri, 10 Mar 2023 13:19:50 +0200 Subject: [PATCH 01/28] cleanup and rename auth classes --- src/firebolt/async_db/connection.py | 193 ++++++++---------- src/firebolt/client/auth/__init__.py | 5 +- ...rvice_account.py => client_credentials.py} | 10 +- src/firebolt/client/auth/token.py | 41 ---- src/firebolt/client/auth/username_password.py | 85 -------- src/firebolt/client/auth/utils.py | 37 ---- src/firebolt/common/settings.py | 20 +- src/firebolt/service/manager.py | 20 +- 8 files changed, 105 insertions(+), 306 deletions(-) rename src/firebolt/client/auth/{service_account.py => client_credentials.py} (90%) delete mode 100644 src/firebolt/client/auth/token.py delete mode 100644 src/firebolt/client/auth/username_password.py delete mode 100644 src/firebolt/client/auth/utils.py diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index c7d2dc21fa2..28dae7d6127 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -12,14 +12,13 @@ from firebolt.async_db.cursor import Cursor from firebolt.client import DEFAULT_API_URL, AsyncClient -from firebolt.client.auth import Auth, _get_auth +from firebolt.client.auth import Auth from firebolt.common.base_connection import BaseConnection from firebolt.common.settings import ( DEFAULT_TIMEOUT_SECONDS, KEEPALIVE_FLAG, KEEPIDLE_RATE, ) -from firebolt.common.util import validate_engine_name_and_url from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, @@ -34,18 +33,6 @@ from firebolt.utils.usage_tracker import get_user_agent_header from firebolt.utils.util import fix_url_schema -AUTH_CREDENTIALS_DEPRECATION_MESSAGE = """ Passing connection credentials - directly to the `connect` function is deprecated. - Pass the `Auth` object instead. - Examples: - >>> from firebolt.client.auth import UsernamePassword - >>> ... - >>> connect(auth=UsernamePassword(username, password), ...) - or - >>> from firebolt.client.auth import Token - >>> ... - >>> connect(auth=Token(access_token), ...)""" - logger = logging.getLogger(__name__) @@ -123,6 +110,95 @@ async def _get_database_default_engine_url( ) as e: raise InterfaceError(f"Unable to retrieve default engine endpoint: {e}.") +def _validate_engine_name_and_url( + engine_name: Optional[str], engine_url: Optional[str] +) -> None: + if engine_name and engine_url: + raise ConfigurationError( + "Both engine_name and engine_url are provided. Provide only one to connect." + ) + + +async def connect( + auth: Auth, + database: str = None, + engine_name: Optional[str] = None, + engine_url: Optional[str] = None, + account_name: Optional[str] = None, + api_endpoint: str = DEFAULT_API_URL, + use_token_cache: bool = True, + additional_parameters: Dict[str, Any] = {}, +) -> Connection: + """Connect to Firebolt database. + + Args: + `auth` (Auth) Authentication object. + `database` (str): Name of the database to connect + `engine_name` (Optional[str]): Name of the engine to connect to + `engine_url` (Optional[str]): The engine endpoint to use + `account_name` (Optional[str]): For customers with multiple accounts; + if none, default is used + `api_endpoint` (str): Firebolt API endpoint. Used for authentication + `use_token_cache` (bool): Cached authentication token in filesystem + Default: True + `additional_parameters` (Optional[Dict]): Dictionary of less widely-used + arguments for connection + + Note: + Providing both `engine_name` and `engine_url` will result in an error + + """ + # These parameters are optional in function signature + # but are required to connect. + # PEP 249 recommends making them kwargs. + if not database: + raise ConfigurationError("database name is required to connect.") + + if not auth: + raise ConfigurationError("auth is required to connect.") + + _validate_engine_name_and_url(engine_name, engine_url) + + api_endpoint = fix_url_schema(api_endpoint) + + # Mypy checks, this should never happen + assert database is not None + + if not engine_name and not engine_url: + engine_url = await _get_database_default_engine_url( + database=database, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) + + elif engine_name: + engine_url = await _resolve_engine_url( + engine_name=engine_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) + elif account_name: + # In above if branches account name is validated since it's used to + # resolve or get an engine url. + # We need to manually validate account_name if none of the above + # cases are triggered. + async with AsyncClient( + auth=auth, + base_url=api_endpoint, + account_name=account_name, + api_endpoint=api_endpoint, + ) as client: + await client.account_id + + assert engine_url is not None + + engine_url = fix_url_schema(engine_url) + return Connection( + engine_url, database, auth, api_endpoint, additional_parameters + ) + class OverriddenHttpBackend(AutoBackend): """ @@ -260,92 +336,3 @@ async def __aexit__( self, exc_type: type, exc_val: Exception, exc_tb: TracebackType ) -> None: await self.aclose() - - -async def connect( - database: str = None, - username: Optional[str] = None, - password: Optional[str] = None, - access_token: Optional[str] = None, - auth: Auth = None, - engine_name: Optional[str] = None, - engine_url: Optional[str] = None, - account_name: Optional[str] = None, - api_endpoint: str = DEFAULT_API_URL, - use_token_cache: bool = True, - additional_parameters: Dict[str, Any] = {}, -) -> Connection: - """Connect to Firebolt database. - - Args: - `database` (str): Name of the database to connect - `username` (Optional[str]): User name to use for authentication (Deprecated) - `password` (Optional[str]): Password to use for authentication (Deprecated) - `access_token` (Optional[str]): Authentication token to use instead of - credentials (Deprecated) - `auth` (Auth)L Authentication object. - `engine_name` (Optional[str]): Name of the engine to connect to - `engine_url` (Optional[str]): The engine endpoint to use - `account_name` (Optional[str]): For customers with multiple accounts; - if none, default is used - `api_endpoint` (str): Firebolt API endpoint. Used for authentication - `use_token_cache` (bool): Cached authentication token in filesystem - Default: True - `additional_parameters` (Optional[Dict]): Dictionary of less widely-used - arguments for connection - - Note: - Providing both `engine_name` and `engine_url` will result in an error - - """ - # These parameters are optional in function signature - # but are required to connect. - # PEP 249 recommends making them kwargs. - if not database: - raise ConfigurationError("database name is required to connect.") - - validate_engine_name_and_url(engine_name, engine_url) - - if not auth: - if any([username, password, access_token, api_endpoint, use_token_cache]): - logger.warning(AUTH_CREDENTIALS_DEPRECATION_MESSAGE) - auth = _get_auth(username, password, access_token, use_token_cache) - else: - raise ConfigurationError("No authentication provided.") - api_endpoint = fix_url_schema(api_endpoint) - - # Mypy checks, this should never happen - assert database is not None - - if not engine_name and not engine_url: - engine_url = await _get_database_default_engine_url( - database=database, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - ) - - elif engine_name: - engine_url = await _resolve_engine_url( - engine_name=engine_name, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - ) - elif account_name: - # In above if branches account name is validated since it's used to - # resolve or get an engine url. - # We need to manually validate account_name if none of the above - # cases are triggered. - async with AsyncClient( - auth=auth, - base_url=api_endpoint, - account_name=account_name, - api_endpoint=api_endpoint, - ) as client: - await client.account_id - - assert engine_url is not None - - engine_url = fix_url_schema(engine_url) - return Connection(engine_url, database, auth, api_endpoint, additional_parameters) diff --git a/src/firebolt/client/auth/__init__.py b/src/firebolt/client/auth/__init__.py index 90e4d4149a5..51c9bcdc95d 100644 --- a/src/firebolt/client/auth/__init__.py +++ b/src/firebolt/client/auth/__init__.py @@ -1,5 +1,2 @@ from firebolt.client.auth.base import Auth -from firebolt.client.auth.service_account import ServiceAccount -from firebolt.client.auth.token import Token -from firebolt.client.auth.username_password import UsernamePassword -from firebolt.client.auth.utils import _get_auth +from firebolt.client.auth.client_credentials import ClientCredentials diff --git a/src/firebolt/client/auth/service_account.py b/src/firebolt/client/auth/client_credentials.py similarity index 90% rename from src/firebolt/client/auth/service_account.py rename to src/firebolt/client/auth/client_credentials.py index a9181b82e72..4c7bd413de7 100644 --- a/src/firebolt/client/auth/service_account.py +++ b/src/firebolt/client/auth/client_credentials.py @@ -7,7 +7,7 @@ from firebolt.utils.util import cached_property -class ServiceAccount(_RequestBasedAuth): +class ClientCredentials(_RequestBasedAuth): """Service Account authentication class for Firebolt Database. Gets authentication token using @@ -45,13 +45,15 @@ def __init__( self.client_secret = client_secret super().__init__(use_token_cache) - def copy(self) -> "ServiceAccount": + def copy(self) -> "ClientCredentials": """Make another auth object with same credentials. Returns: - ServiceAccount: Auth object + ClientCredentials: Auth object """ - return ServiceAccount(self.client_id, self.client_secret, self._use_token_cache) + return ClientCredentials( + self.client_id, self.client_secret, self._use_token_cache + ) @cached_property def _token_storage(self) -> Optional[TokenSecureStorage]: diff --git a/src/firebolt/client/auth/token.py b/src/firebolt/client/auth/token.py deleted file mode 100644 index 01864a1cc09..00000000000 --- a/src/firebolt/client/auth/token.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Generator - -from httpx import Request, Response - -from firebolt.client.auth import Auth -from firebolt.utils.exception import AuthorizationError - - -class Token(Auth): - """Token authentication class for Firebolt Database. - - Uses provided token for authentication. Doesn't cache token and doesn't - refresh it on expiration. - - Args: - token (str): Authorization token - - Attributes: - token (str): - """ - - def __init__(self, token: str): - super().__init__(use_token_cache=False) - self._token = token - - def copy(self) -> "Token": - """Make another auth object with same credentials. - - Returns: - Token: Auth object - """ - assert self.token - return Token(self.token) - - def get_new_token_generator(self) -> Generator[Request, Response, None]: - """Raise authorization error since token is invalid or expired. - - Raises: - AuthorizationError: Token is invalid or expired - """ - raise AuthorizationError("Provided token in not valid anymore.") diff --git a/src/firebolt/client/auth/username_password.py b/src/firebolt/client/auth/username_password.py deleted file mode 100644 index 21fbd31936d..00000000000 --- a/src/firebolt/client/auth/username_password.py +++ /dev/null @@ -1,85 +0,0 @@ -from typing import Optional - -from firebolt.client.auth.base import AuthRequest -from firebolt.client.auth.request_auth_base import _RequestBasedAuth -from firebolt.utils.token_storage import TokenSecureStorage -from firebolt.utils.urls import AUTH_URL -from firebolt.utils.util import cached_property - - -class UsernamePassword(_RequestBasedAuth): - """Username/Password authentication class for Firebolt Database. - - Gets authentication token using - provided credentials and updates it when it expires. - - Args: - username (str): Username - password (str): Password - use_token_cache (bool): True if token should be cached in filesystem; - False otherwise - - Attributes: - username (str): Username - password (str): Password - """ - - __slots__ = ( - "username", - "password", - "_token", - "_expires", - "_use_token_cache", - "_user_agent", - ) - - requires_response_body = True - - def __init__( - self, - username: str, - password: str, - use_token_cache: bool = True, - ): - self.username = username - self.password = password - super().__init__(use_token_cache) - - def copy(self) -> "UsernamePassword": - """Make another auth object with same credentials. - - Returns: - UsernamePassword: Auth object - """ - return UsernamePassword(self.username, self.password, self._use_token_cache) - - @cached_property - def _token_storage(self) -> Optional[TokenSecureStorage]: - """Token filesystem cache storage. - - This is evaluated lazily, only if caching is enabled - - Returns: - TokenSecureStorage: Token filesystem cache storage - """ - return TokenSecureStorage(username=self.username, password=self.password) - - def _make_auth_request(self) -> AuthRequest: - """Get new token using username and password. - - Yields: - Request: An http request to get token. Expects Response to be sent back - - Raises: - AuthenticationError: Error while authenticating with provided credentials - """ - response = self.request_class( - "POST", - AUTH_URL, - headers={ - "Content-Type": "application/json;charset=UTF-8", - "User-Agent": self._user_agent, - }, - json={"username": self.username, "password": self.password}, - ) - return response diff --git a/src/firebolt/client/auth/utils.py b/src/firebolt/client/auth/utils.py deleted file mode 100644 index f6bb6b09374..00000000000 --- a/src/firebolt/client/auth/utils.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import Optional - -from firebolt.client.auth import Auth, Token, UsernamePassword -from firebolt.utils.exception import ConfigurationError - - -def _get_auth( - username: Optional[str], - password: Optional[str], - access_token: Optional[str], - use_token_cache: bool, -) -> Auth: - """Create `Auth` class based on provided credentials. - - If `access_token` is provided, it's used for `Auth` creation. - Otherwise, username/password are used. - - Returns: - Auth: `auth object` - - Raises: - `ConfigurationError`: Invalid combination of credentials provided - - """ - if not access_token: - if not username or not password: - raise ConfigurationError( - "Neither username/password nor access_token are provided. Provide one" - " to authenticate." - ) - return UsernamePassword(username, password, use_token_cache) - if username or password: - raise ConfigurationError( - "Username/password and access_token are both provided. Provide only one" - " to authenticate." - ) - return Token(access_token) diff --git a/src/firebolt/common/settings.py b/src/firebolt/common/settings.py index 663b17a7fc4..570311c5468 100644 --- a/src/firebolt/common/settings.py +++ b/src/firebolt/common/settings.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, field from typing import Any, Callable, Optional -from firebolt.client.auth import Auth, UsernamePassword +from firebolt.client.auth import Auth, ClientCredentials logger = logging.getLogger(__name__) @@ -23,9 +23,8 @@ >>> ... >>> settings = Settings(auth=Token(access_token), ...)""" -USERNAME_ENV = "FIREBOLT_USER" -PASSWORD_ENV = "FIREBOLT_PASSWORD" -AUTH_TOKEN_ENV = "FIREBOLT_AUTH_TOKEN" +CLIENT_ID_ENV = "FIREBOLT_CLIENT_ID" +CLIENT_SECRET_ENV = "FIREBOLT_CLIENT_SECRET" ACCOUNT_ENV = "FIREBOLT_ACCOUNT" SERVER_ENV = "FIREBOLT_SERVER" DEFAULT_REGION_ENV = "FIREBOLT_DEFAULT_REGION" @@ -39,10 +38,10 @@ def inner() -> Any: def auth_from_env() -> Optional[Auth]: - username = os.environ.get(USERNAME_ENV, None) - password = os.environ.get(PASSWORD_ENV, None) - if username and password: - return UsernamePassword(username, password) + client_id = os.environ.get(CLIENT_ID_ENV, None) + client_secret = os.environ.get(CLIENT_SECRET_ENV, None) + if client_id and client_secret: + return ClientCredentials(client_id, client_secret) return None @@ -63,11 +62,6 @@ class Settings: """ auth: Optional[Auth] = field(default_factory=auth_from_env) - # Authorization - user: Optional[str] = field(default=None) - password: Optional[str] = field(default=None) - # Or - access_token: Optional[str] = field(default_factory=from_env(AUTH_TOKEN_ENV)) account_name: Optional[str] = field(default_factory=from_env(ACCOUNT_ENV)) server: str = field(default_factory=from_env(SERVER_ENV)) diff --git a/src/firebolt/service/manager.py b/src/firebolt/service/manager.py index 70657301dfb..c8d646bfb78 100644 --- a/src/firebolt/service/manager.py +++ b/src/firebolt/service/manager.py @@ -3,7 +3,6 @@ from httpx import Timeout from firebolt.client import Client, log_request, log_response, raise_on_4xx_5xx -from firebolt.client.auth import Token, UsernamePassword from firebolt.common import Settings from firebolt.service.provider import get_provider_id from firebolt.utils.util import fix_url_schema @@ -28,25 +27,8 @@ class ResourceManager: def __init__(self, settings: Optional[Settings] = None): self.settings = settings or Settings() - - auth = self.settings.auth - - # Deprecated: we shouldn't support passing credentials after 1.0 release - if auth is None: - if self.settings.access_token: - auth = Token(self.settings.access_token) - else: - # mypy checks - assert self.settings.user - assert self.settings.password - auth = UsernamePassword( - self.settings.user, - self.settings.password, - self.settings.use_token_cache, - ) - self.client = Client( - auth=auth, + auth=self.settings.auth, base_url=fix_url_schema(self.settings.server), account_name=self.settings.account_name, api_endpoint=self.settings.server, From 935721a7fbfddbfde5c7f48a5ee4081d1f22d4e7 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Fri, 10 Mar 2023 14:01:06 +0200 Subject: [PATCH 02/28] fix async_db tests --- src/firebolt/async_db/connection.py | 6 +- tests/unit/async_db/conftest.py | 6 +- tests/unit/async_db/test_connection.py | 104 ++++++-------------- tests/unit/conftest.py | 64 +++++------- tests/unit/service/test_resource_manager.py | 40 +------- 5 files changed, 62 insertions(+), 158 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 28dae7d6127..e8ac0cd8ab9 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -118,15 +118,13 @@ def _validate_engine_name_and_url( "Both engine_name and engine_url are provided. Provide only one to connect." ) - async def connect( - auth: Auth, database: str = None, + auth: Auth = None, engine_name: Optional[str] = None, engine_url: Optional[str] = None, account_name: Optional[str] = None, api_endpoint: str = DEFAULT_API_URL, - use_token_cache: bool = True, additional_parameters: Dict[str, Any] = {}, ) -> Connection: """Connect to Firebolt database. @@ -139,8 +137,6 @@ async def connect( `account_name` (Optional[str]): For customers with multiple accounts; if none, default is used `api_endpoint` (str): Firebolt API endpoint. Used for authentication - `use_token_cache` (bool): Cached authentication token in filesystem - Default: True `additional_parameters` (Optional[Dict]): Dictionary of less widely-used arguments for connection diff --git a/tests/unit/async_db/conftest.py b/tests/unit/async_db/conftest.py index fdca43ab438..f6b6abea1f3 100644 --- a/tests/unit/async_db/conftest.py +++ b/tests/unit/async_db/conftest.py @@ -5,18 +5,18 @@ from pytest_asyncio import fixture as asyncio_fixture from firebolt.async_db import ARRAY, DECIMAL, Connection, Cursor, connect +from firebolt.client.auth import Auth from firebolt.common.settings import Settings from tests.unit.db_conftest import * # noqa @asyncio_fixture -async def connection(settings: Settings, db_name: str) -> Connection: +async def connection(settings: Settings, auth: Auth, db_name: str) -> Connection: async with ( await connect( engine_url=settings.server, database=db_name, - username="u", - password="p", + auth=auth, account_name=settings.account_name, api_endpoint=settings.server, ) diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index 6da6cf46445..aef2e7f3ba2 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -8,7 +8,7 @@ from pytest_httpx import HTTPXMock from firebolt.async_db.connection import Connection, connect -from firebolt.client.auth import Auth, Token, UsernamePassword +from firebolt.client.auth import Auth, ClientCredentials from firebolt.common._types import ColType from firebolt.common.settings import Settings from firebolt.utils.exception import ( @@ -69,8 +69,7 @@ async def test_cursor_initialized( await connect( engine_url=url, database=db_name, - username="u", - password="p", + auth=ClientCredentials("cid", "cs"), api_endpoint=settings.server, ) ) as connection: @@ -96,44 +95,6 @@ async def test_connect_empty_parameters(): pass -async def test_connect_access_token( - settings: Settings, - db_name: str, - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - check_token_callback: Callable, - query_url: str, - python_query_data: List[List[ColType]], - access_token: str, -): - httpx_mock.add_callback(check_token_callback, url=query_url) - async with ( - await connect( - engine_url=settings.server, - database=db_name, - access_token=access_token, - api_endpoint=settings.server, - ) - ) as connection: - cursor = connection.cursor() - assert await cursor.execute("select*") == -1 - - with raises(ConfigurationError): - async with await connect(engine_url="engine_url", database="database"): - pass - - with raises(ConfigurationError): - async with await connect( - engine_url="engine_url", - database="database", - username="username", - password="password", - access_token="access_token", - ): - pass - - async def test_connect_engine_name( settings: Settings, db_name: str, @@ -157,8 +118,7 @@ async def test_connect_engine_name( engine_url="engine_url", engine_name="engine_name", database="db", - username="username", - password="password", + auth=ClientCredentials("cid", "cs"), account_name="account_name", ): pass @@ -181,8 +141,7 @@ async def test_connect_engine_name( with raises(FireboltEngineError): async with await connect( database="db", - username="username", - password="password", + auth=ClientCredentials("cid", "cs"), engine_name=engine_name, account_name=settings.account_name, api_endpoint=settings.server, @@ -201,8 +160,7 @@ async def test_connect_engine_name( async with await connect( engine_name=engine_name, database=db_name, - username="u", - password="p", + auth=ClientCredentials("cid", "cs"), account_name=settings.account_name, api_endpoint=settings.server, ) as connection: @@ -236,8 +194,7 @@ async def test_connect_default_engine( ) async with await connect( database=db_name, - username="u", - password="p", + auth=ClientCredentials("cid", "cs"), account_name=settings.account_name, api_endpoint=settings.server, ) as connection: @@ -264,6 +221,8 @@ async def test_connection_token_caching( query_url: str, python_query_data: List[List[ColType]], access_token: str, + client_id: str, + client_secret: str, ) -> None: httpx_mock.add_callback(check_credentials_callback, url=auth_url) httpx_mock.add_callback(query_callback, url=query_url) @@ -271,34 +230,30 @@ async def test_connection_token_caching( with Patcher(): async with await connect( database=db_name, - username=settings.user, - password=settings.password, + auth=ClientCredentials(client_id, client_secret, use_token_cache=True), engine_url=settings.server, account_name=settings.account_name, api_endpoint=settings.server, - use_token_cache=True, ) as connection: assert await connection.cursor().execute("select*") == len( python_query_data ) - ts = TokenSecureStorage(username=settings.user, password=settings.password) + ts = TokenSecureStorage(username=client_id, password=client_secret) assert ts.get_cached_token() == access_token, "Invalid token value cached" # Do the same, but with use_token_cache=False with Patcher(): async with await connect( database=db_name, - username=settings.user, - password=settings.password, + auth=ClientCredentials(client_id, client_secret, use_token_cache=False), engine_url=settings.server, account_name=settings.account_name, api_endpoint=settings.server, - use_token_cache=False, ) as connection: assert await connection.cursor().execute("select*") == len( python_query_data ) - ts = TokenSecureStorage(username=settings.user, password=settings.password) + ts = TokenSecureStorage(username=client_id, password=client_secret) assert ( ts.get_cached_token() is None ), "Token is cached even though caching is disabled" @@ -313,26 +268,19 @@ async def test_connect_with_auth( query_callback: Callable, query_url: str, access_token: str, + auth: Auth, ) -> None: httpx_mock.add_callback(check_credentials_callback, url=auth_url) httpx_mock.add_callback(query_callback, url=query_url) - for auth in ( - UsernamePassword( - settings.user, - settings.password, - use_token_cache=False, - ), - Token(access_token), - ): - async with await connect( - auth=auth, - database=db_name, - engine_url=settings.server, - account_name=settings.account_name, - api_endpoint=settings.server, - ) as connection: - await connection.cursor().execute("select*") + async with await connect( + auth=auth, + database=db_name, + engine_url=settings.server, + account_name=settings.account_name, + api_endpoint=settings.server, + ) as connection: + await connection.cursor().execute("select*") async def test_connect_account_name( @@ -374,10 +322,13 @@ async def test_connect_with_user_agent( db_name: str, query_callback: Callable, query_url: str, + auth_callback: Callable, + auth_url: str, access_token: str, ) -> None: with patch("firebolt.async_db.connection.get_user_agent_header") as ut: ut.return_value = "MyConnector/1.0 DriverA/1.1" + httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( query_callback, url=query_url, @@ -385,7 +336,7 @@ async def test_connect_with_user_agent( ) async with await connect( - auth=Token(access_token), + auth=ClientCredentials("cid", "cs"), database=db_name, engine_url=settings.server, account_name=settings.account_name, @@ -405,16 +356,19 @@ async def test_connect_no_user_agent( db_name: str, query_callback: Callable, query_url: str, + auth_callback: Callable, + auth_url: str, access_token: str, ) -> None: with patch("firebolt.async_db.connection.get_user_agent_header") as ut: ut.return_value = "Python/3.0" + httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( query_callback, url=query_url, match_headers={"User-Agent": "Python/3.0"} ) async with await connect( - auth=Token(access_token), + auth=ClientCredentials("cid", "cs"), database=db_name, engine_url=settings.server, account_name=settings.account_name, diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 1b9e17fde8b..1c0e9c7e570 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,4 +1,3 @@ -from json import loads from re import Pattern, compile from typing import Callable, List @@ -7,7 +6,7 @@ from pyfakefs.fake_filesystem_unittest import Patcher from pytest import fixture -from firebolt.client.auth import Auth, UsernamePassword +from firebolt.client.auth import Auth, ClientCredentials from firebolt.common.settings import Settings from firebolt.model.provider import Provider from firebolt.model.region import Region, RegionKey @@ -30,7 +29,7 @@ ACCOUNT_ENGINE_URL, ACCOUNT_ENGINE_URL_BY_DATABASE_NAME, ACCOUNT_URL, - AUTH_URL, + AUTH_SERVICE_ACCOUNT_URL, DATABASES_URL, ENGINES_URL, ) @@ -52,13 +51,13 @@ def global_fake_fs(request) -> None: @fixture -def username() -> str: - return "email@domain.com" +def client_id() -> str: + return "client_id" @fixture -def password() -> str: - return "*****" +def client_secret() -> str: + return "client_secret" @fixture @@ -117,21 +116,20 @@ def mock_regions(region_1, region_2) -> List[Region]: @fixture -def settings(server: str, region_1: str, username: str, password: str) -> Settings: +def auth(client_id: str, client_secret: str) -> Auth: + return ClientCredentials(client_id, client_secret) + + +@fixture +def settings(server: str, region_1: str, auth: Auth) -> Settings: return Settings( server=server, - user=username, - password=password, + auth=auth, default_region=region_1.name, account_name=None, ) -@fixture -def auth(username: str, password: str) -> Auth: - return UsernamePassword(username, password) - - @fixture def auth_callback(auth_url: str) -> Callable: def do_mock( @@ -149,7 +147,7 @@ def do_mock( @fixture def auth_url(settings: Settings) -> str: - return f"https://{settings.server}{AUTH_URL}" + return f"https://{settings.server}{AUTH_SERVICE_ACCOUNT_URL}" @fixture @@ -327,33 +325,21 @@ def db_api_exceptions(): @fixture -def check_token_callback(access_token: str) -> Callable: - def check_token(request: Request = None, **kwargs) -> Response: - prefix = "Bearer " - assert request, "empty request" - assert "authorization" in request.headers, "missing authorization header" - auth = request.headers["authorization"] - assert auth.startswith(prefix), "invalid authorization header format" - token = auth[len(prefix) :] - assert token == access_token, "invalid authorization token" - - return Response(status_code=httpx.codes.OK, headers={"content-length": "0"}) - - return check_token - - -@fixture -def check_credentials_callback(settings: Settings, access_token: str) -> Callable: +def check_credentials_callback( + client_id: str, client_secret: str, access_token: str +) -> Callable: def check_credentials( - request: Request = None, + request: httpx.Request = None, **kwargs, ) -> Response: assert request, "empty request" - body = loads(request.read()) - assert "username" in body, "Missing username" - assert body["username"] == settings.user, "Invalid username" - assert "password" in body, "Missing password" - assert body["password"] == settings.password, "Invalid password" + body = request.read().decode("utf-8") + assert "client_id" in body, "Missing id" + assert f"client_id={client_id}" in body, "Invalid id" + assert "client_secret" in body, "Missing secret" + assert f"client_secret={client_secret}" in body, "Invalid secret" + assert "grant_type" in body, "Missing grant_type" + assert "grant_type=client_credentials" in body, "Invalid grant_type" return Response( status_code=httpx.codes.OK, diff --git a/tests/unit/service/test_resource_manager.py b/tests/unit/service/test_resource_manager.py index 42060b2785e..36d36b3acc8 100644 --- a/tests/unit/service/test_resource_manager.py +++ b/tests/unit/service/test_resource_manager.py @@ -5,7 +5,7 @@ from pytest import mark, raises from pytest_httpx import HTTPXMock -from firebolt.client.auth import Auth, Token, UsernamePassword +from firebolt.client.auth import Auth from firebolt.common.settings import Settings from firebolt.service.manager import ResourceManager from firebolt.utils.exception import AccountNotFoundError @@ -14,7 +14,6 @@ def test_rm_credentials( httpx_mock: HTTPXMock, - check_token_callback: Callable, check_credentials_callback: Callable, settings: Settings, auth_url: str, @@ -28,47 +27,18 @@ def test_rm_credentials( url = "https://url" httpx_mock.add_callback(check_credentials_callback, url=auth_url) - httpx_mock.add_callback(check_token_callback, url=url) httpx_mock.add_callback(provider_callback, url=provider_url) httpx_mock.add_callback(account_id_callback, url=account_id_url) rm = ResourceManager(settings) rm.client.get(url) - token_settings = Settings( - access_token=access_token, - server=settings.server, - default_region=settings.default_region, - ) - - rm = ResourceManager(token_settings) - rm.client.get(url) - - auth_username_password_settings = Settings( - auth=UsernamePassword(settings.user, settings.password), - server=settings.server, - default_region=settings.default_region, - ) - - rm = ResourceManager(auth_username_password_settings) - rm.client.get(url) - - auth_token_settings = Settings( - auth=Token(access_token), - server=settings.server, - default_region=settings.default_region, - ) - - rm = ResourceManager(auth_token_settings) - rm.client.get(url) - @mark.nofakefs def test_rm_token_cache( httpx_mock: HTTPXMock, check_token_callback: Callable, check_credentials_callback: Callable, - settings: Settings, auth_url: str, account_id_url: Pattern, account_id_callback: Callable, @@ -76,7 +46,7 @@ def test_rm_token_cache( provider_url: str, access_token: str, ) -> None: - """Credentials, that are passed to rm are processed properly.""" + """Credentials, that are passed to rm are cached properly.""" url = "https://url" httpx_mock.add_callback(check_credentials_callback, url=auth_url) @@ -86,8 +56,7 @@ def test_rm_token_cache( with Patcher(): local_settings = Settings( - user=settings.user, - password=settings.password, + auth=settings.auth, server=settings.server, default_region=settings.default_region, use_token_cache=True, @@ -101,8 +70,7 @@ def test_rm_token_cache( # Do the same, but with use_token_cache=False with Patcher(): local_settings = Settings( - user=settings.user, - password=settings.password, + auth=settings.auth, server=settings.server, default_region=settings.default_region, use_token_cache=False, From bf509d000dcba7081b3fd9ed1f4ea9718ebc6561 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Fri, 10 Mar 2023 14:10:35 +0200 Subject: [PATCH 03/28] fix db unit tests --- tests/unit/db/conftest.py | 6 +- tests/unit/db/test_connection.py | 99 ++++++++------------------------ 2 files changed, 28 insertions(+), 77 deletions(-) diff --git a/tests/unit/db/conftest.py b/tests/unit/db/conftest.py index 1c06ed3a253..eaba5fe0849 100644 --- a/tests/unit/db/conftest.py +++ b/tests/unit/db/conftest.py @@ -1,16 +1,16 @@ from pytest import fixture +from firebolt.client.auth import Auth from firebolt.common.settings import Settings from firebolt.db import Connection, Cursor, connect @fixture -def connection(settings: Settings, db_name: str) -> Connection: +def connection(settings: Settings, db_name: str, auth: Auth) -> Connection: with connect( engine_url=settings.server, database=db_name, - username="u", - password="p", + auth=auth, api_endpoint=settings.server, ) as connection: yield connection diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index 571c8d7f06d..df9d7ca0e75 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -8,8 +8,8 @@ from pytest import mark, raises, warns from pytest_httpx import HTTPXMock -from firebolt.client.auth import Auth, Token, UsernamePassword -from firebolt.common._types import ColType +from firebolt.async_db._types import ColType +from firebolt.client.auth import Auth, ClientCredentials from firebolt.common.settings import Settings from firebolt.db import Connection, connect from firebolt.utils.exception import ( @@ -59,6 +59,7 @@ def test_cursor_initialized( query_callback: Callable, query_url: str, python_query_data: List[List[ColType]], + auth: Auth, ) -> None: """Connection initialized its cursors properly.""" httpx_mock.add_callback(auth_callback, url=auth_url) @@ -68,9 +69,8 @@ def test_cursor_initialized( with connect( engine_url=url, database=db_name, - username="u", - password="p", api_endpoint=settings.server, + auth=auth, ) as connection: cursor = connection.cursor() @@ -95,44 +95,6 @@ def test_connect_empty_parameters(): pass -def test_connect_access_token( - settings: Settings, - db_name: str, - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - check_token_callback: Callable, - query_url: str, - python_query_data: List[List[ColType]], - access_token: str, -): - httpx_mock.add_callback(check_token_callback, url=query_url) - with ( - connect( - engine_url=settings.server, - database=db_name, - access_token=access_token, - api_endpoint=settings.server, - ) - ) as connection: - cursor = connection.cursor() - assert cursor.execute("select*") == -1 - - with raises(ConfigurationError): - with connect(engine_url="engine_url", database="database"): - pass - - with raises(ConfigurationError): - with connect( - engine_url="engine_url", - database="database", - username="username", - password="password", - access_token="access_token", - ): - pass - - def test_connect_engine_name( settings: Settings, db_name: str, @@ -148,6 +110,7 @@ def test_connect_engine_name( get_engine_url_by_id_callback: Callable, python_query_data: List[List[ColType]], account_id: str, + auth: Auth, ): """connect properly handles engine_name""" @@ -156,8 +119,7 @@ def test_connect_engine_name( engine_url="engine_url", engine_name="engine_name", database="db", - username="username", - password="password", + auth=auth, ) httpx_mock.add_callback(auth_callback, url=auth_url) @@ -178,8 +140,7 @@ def test_connect_engine_name( with raises(FireboltEngineError): connect( database="db", - username="username", - password="password", + auth=auth, engine_name=engine_name, account_name=settings.account_name, api_endpoint=settings.server, @@ -197,8 +158,7 @@ def test_connect_engine_name( with connect( engine_name=engine_name, database=db_name, - username="u", - password="p", + auth=auth, account_name=settings.account_name, api_endpoint=settings.server, ) as connection: @@ -219,6 +179,7 @@ def test_connect_default_engine( engine_by_db_url: str, python_query_data: List[List[ColType]], account_id: str, + auth: Auth, ): httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(query_callback, url=query_url) @@ -234,8 +195,7 @@ def test_connect_default_engine( ) with connect( database=db_name, - username="u", - password="p", + auth=auth, account_name=settings.account_name, api_endpoint=settings.server, ) as connection: @@ -282,6 +242,8 @@ def test_connection_token_caching( query_url: str, python_query_data: List[List[ColType]], access_token: str, + client_id: str, + client_secret: str, ) -> None: httpx_mock.add_callback(check_credentials_callback, url=auth_url) httpx_mock.add_callback(query_callback, url=query_url) @@ -289,30 +251,26 @@ def test_connection_token_caching( with Patcher(): with connect( database=db_name, - username=settings.user, - password=settings.password, + auth=ClientCredentials(client_id, client_secret, use_token_cache=True), engine_url=settings.server, account_name=settings.account_name, api_endpoint=settings.server, - use_token_cache=True, ) as connection: assert connection.cursor().execute("select*") == len(python_query_data) - ts = TokenSecureStorage(username=settings.user, password=settings.password) + ts = TokenSecureStorage(username=client_id, password=client_secret) assert ts.get_cached_token() == access_token, "Invalid token value cached" # Do the same, but with use_token_cache=False with Patcher(): with connect( database=db_name, - username=settings.user, - password=settings.password, + auth=ClientCredentials(client_id, client_secret, use_token_cache=False), engine_url=settings.server, account_name=settings.account_name, api_endpoint=settings.server, - use_token_cache=False, ) as connection: assert connection.cursor().execute("select*") == len(python_query_data) - ts = TokenSecureStorage(username=settings.user, password=settings.password) + ts = TokenSecureStorage(username=client_id, password=client_secret) assert ( ts.get_cached_token() is None ), "Token is cached even though caching is disabled" @@ -327,26 +285,19 @@ def test_connect_with_auth( query_callback: Callable, query_url: str, access_token: str, + auth: Auth, ) -> None: httpx_mock.add_callback(check_credentials_callback, url=auth_url) httpx_mock.add_callback(query_callback, url=query_url) - for auth in ( - UsernamePassword( - settings.user, - settings.password, - use_token_cache=False, - ), - Token(access_token), - ): - with connect( - auth=auth, - database=db_name, - engine_url=settings.server, - account_name=settings.account_name, - api_endpoint=settings.server, - ) as connection: - connection.cursor().execute("select*") + with connect( + auth=auth, + database=db_name, + engine_url=settings.server, + account_name=settings.account_name, + api_endpoint=settings.server, + ) as connection: + connection.cursor().execute("select*") def test_connect_account_name( From a5bb82c7b80fb8f99357e74cfe53c00132b51fe8 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Fri, 10 Mar 2023 14:55:40 +0200 Subject: [PATCH 04/28] fix remaining tests --- src/firebolt/common/settings.py | 1 - src/firebolt/utils/urls.py | 1 - tests/unit/client/auth/test_auth.py | 46 ++++++------ tests/unit/client/auth/test_auth_async.py | 46 ++++++------ tests/unit/client/auth/test_request_auth.py | 36 ++------- tests/unit/client/auth/test_token.py | 40 ---------- tests/unit/client/test_client.py | 81 ++++++++++----------- tests/unit/client/test_client_async.py | 35 +++------ tests/unit/common/test_settings.py | 36 ++------- tests/unit/conftest.py | 21 ++++++ tests/unit/service/test_resource_manager.py | 23 ++++-- 11 files changed, 147 insertions(+), 219 deletions(-) delete mode 100644 tests/unit/client/auth/test_token.py diff --git a/src/firebolt/common/settings.py b/src/firebolt/common/settings.py index 570311c5468..557b0ddcc2d 100644 --- a/src/firebolt/common/settings.py +++ b/src/firebolt/common/settings.py @@ -66,7 +66,6 @@ class Settings: account_name: Optional[str] = field(default_factory=from_env(ACCOUNT_ENV)) server: str = field(default_factory=from_env(SERVER_ENV)) default_region: str = field(default_factory=from_env(DEFAULT_REGION_ENV)) - use_token_cache: bool = field(default=True) def __post_init__(self) -> None: """Validate that either creds or token is provided. diff --git a/src/firebolt/utils/urls.py b/src/firebolt/utils/urls.py index 0149576e47d..f94d72a0461 100644 --- a/src/firebolt/utils/urls.py +++ b/src/firebolt/utils/urls.py @@ -1,4 +1,3 @@ -AUTH_URL = "/auth/v1/login" AUTH_SERVICE_ACCOUNT_URL = "/auth/v1/token" DATABASES_URL = "/core/v1/account/databases" diff --git a/tests/unit/client/auth/test_auth.py b/tests/unit/client/auth/test_auth.py index 8827546d170..94203b0875a 100644 --- a/tests/unit/client/auth/test_auth.py +++ b/tests/unit/client/auth/test_auth.py @@ -12,7 +12,7 @@ def test_auth_refresh_on_expiration( - httpx_mock: HTTPXMock, test_token: str, test_token2: str + httpx_mock: HTTPXMock, access_token: str, access_token_2: str ) -> None: """Auth refreshes the token on expiration.""" url = "https://host" @@ -29,19 +29,19 @@ def inner(self): auth = Auth(use_token_cache=False) # Get token for the first time - auth.get_new_token_generator = MethodType(set_token(test_token), auth) + auth.get_new_token_generator = MethodType(set_token(access_token), auth) execute_generator_requests(auth.auth_flow(Request("GET", url))) - assert auth.token == test_token, "invalid access token" + assert auth.token == access_token, "invalid access token" assert auth.expired # Refresh token - auth.get_new_token_generator = MethodType(set_token(test_token2), auth) + auth.get_new_token_generator = MethodType(set_token(access_token_2), auth) execute_generator_requests(auth.auth_flow(Request("GET", url))) - assert auth.token == test_token2, "Expired access token was not updated." + assert auth.token == access_token_2, "Expired access token was not updated." def test_auth_uses_same_token_if_valid( - httpx_mock: HTTPXMock, test_token: str, test_token2: str + httpx_mock: HTTPXMock, access_token: str, access_token_2: str ) -> None: """Auth reuses the token until it's expired.""" url = "https://host" @@ -58,37 +58,37 @@ def inner(self): auth = Auth(use_token_cache=False) # Get token for the first time - auth.get_new_token_generator = MethodType(set_token(test_token), auth) + auth.get_new_token_generator = MethodType(set_token(access_token), auth) execute_generator_requests(auth.auth_flow(Request("GET", url))) - assert auth.token == test_token, "invalid access token" + assert auth.token == access_token, "invalid access token" assert not auth.expired # Refresh token - auth.get_new_token_generator = MethodType(set_token(test_token2), auth) + auth.get_new_token_generator = MethodType(set_token(access_token_2), auth) execute_generator_requests(auth.auth_flow(Request("GET", url))) - assert auth.token == test_token, "Should not update token until it expires." + assert auth.token == access_token, "Should not update token until it expires." -def test_auth_adds_header(test_token: str) -> None: +def test_auth_adds_header(access_token: str) -> None: """Auth adds required authentication headers to httpx.Request.""" auth = Auth(use_token_cache=False) - auth._token = test_token + auth._token = access_token auth._expires = 2**32 flow = auth.auth_flow(Request("get", "")) request = next(flow) assert "authorization" in request.headers, "missing authorization header" assert ( - request.headers["authorization"] == f"Bearer {test_token}" + request.headers["authorization"] == f"Bearer {access_token}" ), "missing authorization header" @mark.nofakefs def test_auth_token_storage( httpx_mock: HTTPXMock, - test_username: str, - test_password: str, - test_token, + client_id: str, + client_secret: str, + access_token: str, ) -> None: # Mock auth flow def set_token(token: str) -> callable: @@ -104,26 +104,26 @@ def inner(self): with Patcher(), patch( "firebolt.client.auth.base.Auth._token_storage", new_callable=PropertyMock, - return_value=TokenSecureStorage(test_username, test_password), + return_value=TokenSecureStorage(client_id, client_secret), ): auth = Auth(use_token_cache=True) # Get token - auth.get_new_token_generator = MethodType(set_token(test_token), auth) + auth.get_new_token_generator = MethodType(set_token(access_token), auth) execute_generator_requests(auth.auth_flow(Request("GET", url))) - st = TokenSecureStorage(test_username, test_password) - assert st.get_cached_token() == test_token, "Invalid token value cached" + st = TokenSecureStorage(client_id, client_secret) + assert st.get_cached_token() == access_token, "Invalid token value cached" with Patcher(), patch( "firebolt.client.auth.base.Auth._token_storage", new_callable=PropertyMock, - return_value=TokenSecureStorage(test_username, test_password), + return_value=TokenSecureStorage(client_id, client_secret), ): auth = Auth(use_token_cache=False) # Get token - auth.get_new_token_generator = MethodType(set_token(test_token), auth) + auth.get_new_token_generator = MethodType(set_token(access_token), auth) execute_generator_requests(auth.auth_flow(Request("GET", url))) - st = TokenSecureStorage(test_username, test_password) + st = TokenSecureStorage(client_id, client_secret) assert ( st.get_cached_token() is None ), "Token cached even though caching is disabled" diff --git a/tests/unit/client/auth/test_auth_async.py b/tests/unit/client/auth/test_auth_async.py index 93990283b9e..abd0fd2cc36 100644 --- a/tests/unit/client/auth/test_auth_async.py +++ b/tests/unit/client/auth/test_auth_async.py @@ -12,7 +12,7 @@ async def test_auth_refresh_on_expiration( - httpx_mock: HTTPXMock, test_token: str, test_token2: str + httpx_mock: HTTPXMock, access_token: str, access_token_2: str ) -> None: """Auth refreshes the token on expiration.""" url = "https://host" @@ -29,19 +29,19 @@ def inner(self): auth = Auth(use_token_cache=False) # Get token for the first time - auth.get_new_token_generator = MethodType(set_token(test_token), auth) + auth.get_new_token_generator = MethodType(set_token(access_token), auth) await async_execute_generator_requests(auth.async_auth_flow(Request("GET", url))) - assert auth.token == test_token, "invalid access token" + assert auth.token == access_token, "invalid access token" assert auth.expired # Refresh token - auth.get_new_token_generator = MethodType(set_token(test_token2), auth) + auth.get_new_token_generator = MethodType(set_token(access_token_2), auth) await async_execute_generator_requests(auth.async_auth_flow(Request("GET", url))) - assert auth.token == test_token2, "expired access token was not updated" + assert auth.token == access_token_2, "expired access token was not updated" async def test_auth_uses_same_token_if_valid( - httpx_mock: HTTPXMock, test_token: str, test_token2: str + httpx_mock: HTTPXMock, access_token: str, access_token_2: str ) -> None: """Auth reuses the token until it's expired.""" url = "https://host" @@ -58,40 +58,40 @@ def inner(self): auth = Auth(use_token_cache=False) # Get token for the first time - auth.get_new_token_generator = MethodType(set_token(test_token), auth) + auth.get_new_token_generator = MethodType(set_token(access_token), auth) await async_execute_generator_requests( auth.async_auth_flow(Request("GET", "https://host")) ) - assert auth.token == test_token, "invalid access token" + assert auth.token == access_token, "invalid access token" assert not auth.expired - auth.get_new_token_generator = MethodType(set_token(test_token2), auth) + auth.get_new_token_generator = MethodType(set_token(access_token_2), auth) await async_execute_generator_requests( auth.async_auth_flow(Request("GET", "https://host")) ) - assert auth.token == test_token, "shoud not update token until it expires" + assert auth.token == access_token, "shoud not update token until it expires" -async def test_auth_adds_header(test_token: str) -> None: +async def test_auth_adds_header(access_token: str) -> None: """Auth adds required authentication headers to httpx.Request.""" auth = Auth(use_token_cache=False) - auth._token = test_token + auth._token = access_token auth._expires = 2**32 flow = auth.async_auth_flow(Request("get", "")) request = await flow.__anext__() assert "authorization" in request.headers, "missing authorization header" assert ( - request.headers["authorization"] == f"Bearer {test_token}" + request.headers["authorization"] == f"Bearer {access_token}" ), "missing authorization header" @mark.nofakefs async def test_auth_token_storage( httpx_mock: HTTPXMock, - test_username: str, - test_password: str, - test_token, + client_id: str, + client_secret: str, + access_token, ) -> None: # Mock auth flow def set_token(token: str) -> callable: @@ -107,30 +107,30 @@ def inner(self): with Patcher(), patch( "firebolt.client.auth.base.Auth._token_storage", new_callable=PropertyMock, - return_value=TokenSecureStorage(test_username, test_password), + return_value=TokenSecureStorage(client_id, client_secret), ): auth = Auth(use_token_cache=True) # Get token - auth.get_new_token_generator = MethodType(set_token(test_token), auth) + auth.get_new_token_generator = MethodType(set_token(access_token), auth) await async_execute_generator_requests( auth.async_auth_flow(Request("GET", url)) ) - st = TokenSecureStorage(test_username, test_password) - assert st.get_cached_token() == test_token, "Invalid token value cached" + st = TokenSecureStorage(client_id, client_secret) + assert st.get_cached_token() == access_token, "Invalid token value cached" with Patcher(), patch( "firebolt.client.auth.base.Auth._token_storage", new_callable=PropertyMock, - return_value=TokenSecureStorage(test_username, test_password), + return_value=TokenSecureStorage(client_id, client_secret), ): auth = Auth(use_token_cache=False) # Get token - auth.get_new_token_generator = MethodType(set_token(test_token), auth) + auth.get_new_token_generator = MethodType(set_token(access_token), auth) await async_execute_generator_requests( auth.async_auth_flow(Request("GET", url)) ) - st = TokenSecureStorage(test_username, test_password) + st = TokenSecureStorage(client_id, client_secret) assert ( st.get_cached_token() is None ), "Token cached even though caching is disabled" diff --git a/tests/unit/client/auth/test_request_auth.py b/tests/unit/client/auth/test_request_auth.py index d0e4ef87e54..bcb028d500a 100644 --- a/tests/unit/client/auth/test_request_auth.py +++ b/tests/unit/client/auth/test_request_auth.py @@ -2,56 +2,36 @@ import pytest from httpx import StreamError, codes -from pytest import mark from pytest_httpx import HTTPXMock from pytest_mock import MockerFixture -from firebolt.client.auth import Auth, ServiceAccount, UsernamePassword +from firebolt.client.auth import ClientCredentials from firebolt.utils.exception import AuthenticationError from tests.unit.util import execute_generator_requests def test_auth_service_account( - httpx_mock: HTTPXMock, - mocker: MockerFixture, - check_service_credentials_callback: typing.Callable, - mock_service_id: str, - mock_service_secret: str, - test_token: str, -): - """Auth can retrieve token and expiration values.""" - httpx_mock.add_callback(check_service_credentials_callback) - - mocker.patch("firebolt.client.auth.request_auth_base.time", return_value=0) - auth = ServiceAccount(mock_service_id, mock_service_secret) - execute_generator_requests(auth.get_new_token_generator()) - assert auth.token == test_token, "invalid access token" - assert auth._expires == 2**32, "invalid expiration value" - - -def test_auth_username_password( httpx_mock: HTTPXMock, mocker: MockerFixture, check_credentials_callback: typing.Callable, - test_username, - test_password, - test_token, + client_id: str, + client_secret: str, + access_token: str, ): """Auth can retrieve token and expiration values.""" httpx_mock.add_callback(check_credentials_callback) mocker.patch("firebolt.client.auth.request_auth_base.time", return_value=0) - auth = UsernamePassword(test_username, test_password) + auth = ClientCredentials(client_id, client_secret) execute_generator_requests(auth.get_new_token_generator()) - assert auth.token == test_token, "invalid access token" + assert auth.token == access_token, "invalid access token" assert auth._expires == 2**32, "invalid expiration value" -@mark.parametrize("auth_class", [UsernamePassword, ServiceAccount]) -def test_auth_error_handling(httpx_mock: HTTPXMock, auth_class: Auth): +def test_auth_error_handling(httpx_mock: HTTPXMock, client_id: str, client_secret: str): """Auth handles various errors properly.""" for api_endpoint in ("https://host", "host"): - auth = auth_class("user", "password", use_token_cache=False) + auth = ClientCredentials(client_id, client_secret, use_token_cache=False) # Internal httpx error def http_error(*args, **kwargs): diff --git a/tests/unit/client/auth/test_token.py b/tests/unit/client/auth/test_token.py deleted file mode 100644 index f7cc2df07da..00000000000 --- a/tests/unit/client/auth/test_token.py +++ /dev/null @@ -1,40 +0,0 @@ -from typing import Callable - -from httpx import Request, Response, codes -from pytest import raises -from pytest_httpx import HTTPXMock - -from firebolt.client.auth import Token -from firebolt.utils.exception import AuthorizationError -from tests.unit.util import execute_generator_requests - - -def test_token_happy_path( - httpx_mock: HTTPXMock, test_token: str, check_token_callback: Callable -) -> None: - """Validate that provided token is added to request.""" - httpx_mock.add_callback(check_token_callback) - - auth = Token(test_token) - execute_generator_requests(auth.auth_flow(Request("GET", "https://host"))) - - -def test_token_invalid(httpx_mock: HTTPXMock) -> None: - """Authorization error raised when token is invalid.""" - - def authorization_error(*args, **kwargs) -> Response: - return Response(status_code=codes.UNAUTHORIZED) - - httpx_mock.add_callback(authorization_error) - - auth = Token("token") - with raises(AuthorizationError): - execute_generator_requests(auth.auth_flow(Request("GET", "https://host"))) - - -def test_token_expired() -> None: - """Authorization error is raised when token expires.""" - auth = Token("token") - auth._expires = 0 - with raises(AuthorizationError): - execute_generator_requests(auth.auth_flow(Request("GET", "https://host"))) diff --git a/tests/unit/client/test_client.py b/tests/unit/client/test_client.py index aba06631178..2b8b4a6005e 100644 --- a/tests/unit/client/test_client.py +++ b/tests/unit/client/test_client.py @@ -7,60 +7,57 @@ from pytest_httpx import HTTPXMock from firebolt.client import DEFAULT_API_URL, Client -from firebolt.client.auth import Token, UsernamePassword +from firebolt.client.auth import Auth, ClientCredentials from firebolt.client.resource_manager_hooks import raise_on_4xx_5xx from firebolt.common import Settings from firebolt.utils.token_storage import TokenSecureStorage -from firebolt.utils.urls import AUTH_URL +from firebolt.utils.urls import AUTH_SERVICE_ACCOUNT_URL from firebolt.utils.util import fix_url_schema def test_client_retry( httpx_mock: HTTPXMock, - test_username: str, - test_password: str, - test_token: str, + auth: Auth, + access_token: str, ): """ Client retries with new auth token if first attempt fails with unauthorized error. """ - client = Client(auth=UsernamePassword(test_username, test_password)) + with Client(auth=auth) as client: - # auth get token - httpx_mock.add_response( - status_code=codes.OK, - json={"expires_in": 2**30, "access_token": test_token}, - ) + # auth get token + httpx_mock.add_response( + status_code=codes.OK, + json={"expires_in": 2**30, "access_token": access_token}, + ) - # client request failed authorization - httpx_mock.add_response( - status_code=codes.UNAUTHORIZED, - ) + # client request failed authorization + httpx_mock.add_response( + status_code=codes.UNAUTHORIZED, + ) - # auth get another token - httpx_mock.add_response( - status_code=codes.OK, - json={"expires_in": 2**30, "access_token": test_token}, - ) + # auth get another token + httpx_mock.add_response( + status_code=codes.OK, + json={"expires_in": 2**30, "access_token": access_token}, + ) - # client request success this time - httpx_mock.add_response( - status_code=codes.OK, - ) + # client request success this time + httpx_mock.add_response( + status_code=codes.OK, + ) - assert ( - client.get("https://url").status_code == codes.OK - ), "request failed with firebolt client" + assert ( + client.get("https://url").status_code == codes.OK + ), "request failed with firebolt client" def test_client_different_auths( httpx_mock: HTTPXMock, check_credentials_callback: Callable, check_token_callback: Callable, - test_username: str, - test_password: str, - test_token: str, + auth: Auth, ): """ Client properly handles such auth types: @@ -72,13 +69,12 @@ def test_client_different_auths( httpx_mock.add_callback( check_credentials_callback, - url=f"https://{DEFAULT_API_URL}{AUTH_URL}", + url=f"https://{DEFAULT_API_URL}{AUTH_SERVICE_ACCOUNT_URL}", ) httpx_mock.add_callback(check_token_callback, url="https://url") - Client(auth=UsernamePassword(test_username, test_password)).get("https://url") - Client(auth=Token(test_token)).get("https://url") + Client(auth=auth).get("https://url") # client accepts None auth, but authorization fails with raises(AssertionError) as excinfo: @@ -94,8 +90,7 @@ def test_client_different_auths( def test_client_account_id( httpx_mock: HTTPXMock, - test_username: str, - test_password: str, + auth: Auth, account_id: str, account_id_url: Pattern, account_id_callback: Callable, @@ -107,7 +102,7 @@ def test_client_account_id( httpx_mock.add_callback(auth_callback, url=auth_url) with Client( - auth=UsernamePassword(test_username, test_password), + auth=auth, base_url=fix_url_schema(settings.server), api_endpoint=settings.server, ) as c: @@ -118,19 +113,19 @@ def test_client_account_id( def test_refresh_with_hooks( fs: FakeFilesystem, httpx_mock: HTTPXMock, - test_username: str, - test_password: str, - test_token: str, + client_id: str, + client_secret: str, + access_token: str, ) -> None: """ When hooks are used, the invalid token, fetched from cache, is refreshed """ - tss = TokenSecureStorage(test_username, test_password) - tss.cache_token(test_token, 2**32) + tss = TokenSecureStorage(client_id, client_secret) + tss.cache_token(access_token, 2**32) client = Client( - auth=UsernamePassword(test_username, test_password), + auth=ClientCredentials(client_id, client_secret), event_hooks={ "response": [raise_on_4xx_5xx], }, @@ -144,7 +139,7 @@ def test_refresh_with_hooks( # auth get another token httpx_mock.add_response( status_code=codes.OK, - json={"expires_in": 2**30, "access_token": test_token}, + json={"expires_in": 2**30, "access_token": access_token}, ) # client request success this time diff --git a/tests/unit/client/test_client_async.py b/tests/unit/client/test_client_async.py index 81d513727fc..844ba3ef4a3 100644 --- a/tests/unit/client/test_client_async.py +++ b/tests/unit/client/test_client_async.py @@ -6,30 +6,26 @@ from pytest_httpx import HTTPXMock from firebolt.client import DEFAULT_API_URL, AsyncClient -from firebolt.client.auth import Token, UsernamePassword +from firebolt.client.auth import Auth from firebolt.common import Settings -from firebolt.utils.urls import AUTH_URL +from firebolt.utils.urls import AUTH_SERVICE_ACCOUNT_URL from firebolt.utils.util import fix_url_schema async def test_client_retry( httpx_mock: HTTPXMock, - test_username: str, - test_password: str, - test_token: str, + auth: Auth, + access_token: str, ): """ Client retries with new auth token if first attempt fails with unauthorized error. """ - async with AsyncClient( - auth=UsernamePassword(test_username, test_password) - ) as client: - + async with AsyncClient(auth=auth) as client: # auth get token httpx_mock.add_response( status_code=codes.OK, - json={"expires_in": 2**30, "access_token": test_token}, + json={"expires_in": 2**30, "access_token": access_token}, ) # client request failed authorization @@ -40,7 +36,7 @@ async def test_client_retry( # auth get another token httpx_mock.add_response( status_code=codes.OK, - json={"expires_in": 2**30, "access_token": test_token}, + json={"expires_in": 2**30, "access_token": access_token}, ) # client request success this time @@ -57,9 +53,7 @@ async def test_client_different_auths( httpx_mock: HTTPXMock, check_credentials_callback: Callable, check_token_callback: Callable, - test_username: str, - test_password: str, - test_token: str, + auth: Auth, ): """ Client properly handles such auth types: @@ -71,16 +65,12 @@ async def test_client_different_auths( httpx_mock.add_callback( check_credentials_callback, - url=f"https://{DEFAULT_API_URL}{AUTH_URL}", + url=f"https://{DEFAULT_API_URL}{AUTH_SERVICE_ACCOUNT_URL}", ) httpx_mock.add_callback(check_token_callback, url="https://url") - async with AsyncClient( - auth=UsernamePassword(test_username, test_password) - ) as client: - await client.get("https://url") - async with AsyncClient(auth=Token(test_token)) as client: + async with AsyncClient(auth=auth) as client: await client.get("https://url") # client accepts None auth, but authorization fails @@ -99,8 +89,7 @@ async def test_client_different_auths( async def test_client_account_id( httpx_mock: HTTPXMock, - test_username: str, - test_password: str, + auth: Auth, account_id: str, account_id_url: Pattern, account_id_callback: Callable, @@ -112,7 +101,7 @@ async def test_client_account_id( httpx_mock.add_callback(auth_callback, url=auth_url) async with AsyncClient( - auth=UsernamePassword(test_username, test_password), + auth=auth, base_url=fix_url_schema(settings.server), api_endpoint=settings.server, ) as c: diff --git a/tests/unit/common/test_settings.py b/tests/unit/common/test_settings.py index cd4c23c4c14..508afda9de7 100644 --- a/tests/unit/common/test_settings.py +++ b/tests/unit/common/test_settings.py @@ -8,15 +8,8 @@ from firebolt.common.settings import Settings -@mark.parametrize( - "fields", - ( - ("user", "password", "account_name", "server", "default_region"), - ("access_token", "account_name", "server", "default_region"), - ("auth", "account_name", "server", "default_region"), - ), -) -def test_settings_happy_path(fields: Tuple[str]) -> None: +def test_settings_happy_path() -> None: + fields = ("auth", "account_name", "server", "default_region") kwargs = {f: (f if f != "auth" else Auth()) for f in fields} s = Settings(**kwargs) @@ -27,30 +20,13 @@ def test_settings_happy_path(fields: Tuple[str]) -> None: ), f"Invalid settings value {f}" -creds_fields = ("access_token", "user", "password") -other_fields = ("server", "default_region") - - -@mark.parametrize( - "kwargs", - ( - {f: f for f in other_fields}, - {f: f for f in creds_fields + other_fields}, - {"auth": Auth(), "access_token": "123", **{f: f for f in other_fields}}, - ), -) -def test_settings_auth_credentials(kwargs) -> None: - with raises(ValueError) as exc_info: - Settings(**kwargs) - - @patch("firebolt.common.settings.logger") def test_no_deprecation_warning_with_env(logger_mock: Mock): with patch.dict( os.environ, { - "FIREBOLT_USER": "user", - "FIREBOLT_PASSWORD": "password", + "FIREBOLT_CLIENT_ID": "client_id", + "FIREBOLT_CLIENT_SECRET": "client_secret", "FIREBOLT_SERVER": "dummy.firebolt.io", }, clear=True, @@ -59,5 +35,5 @@ def test_no_deprecation_warning_with_env(logger_mock: Mock): logger_mock.warning.assert_not_called() assert s.server == "dummy.firebolt.io" assert s.auth is not None, "Settings.auth wasn't populated from env variables" - assert s.auth.username == "user", "Invalid username in Settings.auth" - assert s.auth.password == "password", "Invalid password in Settings.auth" + assert s.auth.client_id == "client_id", "Invalid username in Settings.auth" + assert s.auth.client_secret == "client_secret", "Invalid password in Settings.auth" diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 1c0e9c7e570..b47042f2f23 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -75,6 +75,11 @@ def access_token() -> str: return "mock_access_token" +@fixture +def access_token_2() -> str: + return "mock_access_token_2" + + @fixture def provider() -> Provider: return Provider( @@ -324,6 +329,22 @@ def db_api_exceptions(): return exceptions +@fixture +def check_token_callback(access_token: str) -> Callable: + def check_token(request: Request = None, **kwargs) -> Response: + prefix = "Bearer " + assert request, "empty request" + assert "authorization" in request.headers, "missing authorization header" + auth = request.headers["authorization"] + assert auth.startswith(prefix), "invalid authorization header format" + token = auth[len(prefix) :] + assert token == access_token, "invalid authorization token" + + return Response(status_code=httpx.codes.OK, headers={"content-length": "0"}) + + return check_token + + @fixture def check_credentials_callback( client_id: str, client_secret: str, access_token: str diff --git a/tests/unit/service/test_resource_manager.py b/tests/unit/service/test_resource_manager.py index 36d36b3acc8..f3a66d17487 100644 --- a/tests/unit/service/test_resource_manager.py +++ b/tests/unit/service/test_resource_manager.py @@ -5,7 +5,7 @@ from pytest import mark, raises from pytest_httpx import HTTPXMock -from firebolt.client.auth import Auth +from firebolt.client.auth import Auth, ClientCredentials from firebolt.common.settings import Settings from firebolt.service.manager import ResourceManager from firebolt.utils.exception import AccountNotFoundError @@ -14,6 +14,7 @@ def test_rm_credentials( httpx_mock: HTTPXMock, + check_token_callback: Callable, check_credentials_callback: Callable, settings: Settings, auth_url: str, @@ -27,6 +28,7 @@ def test_rm_credentials( url = "https://url" httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(check_token_callback, url=url) httpx_mock.add_callback(provider_callback, url=provider_url) httpx_mock.add_callback(account_id_callback, url=account_id_url) @@ -39,6 +41,7 @@ def test_rm_token_cache( httpx_mock: HTTPXMock, check_token_callback: Callable, check_credentials_callback: Callable, + settings: Settings, auth_url: str, account_id_url: Pattern, account_id_callback: Callable, @@ -56,29 +59,35 @@ def test_rm_token_cache( with Patcher(): local_settings = Settings( - auth=settings.auth, + auth=ClientCredentials( + settings.auth.client_id, + settings.auth.client_secret, + use_token_cache=True, + ), server=settings.server, default_region=settings.default_region, - use_token_cache=True, ) rm = ResourceManager(local_settings) rm.client.get(url) - ts = TokenSecureStorage(settings.user, settings.password) + ts = TokenSecureStorage(settings.auth.client_id, settings.auth.client_secret) assert ts.get_cached_token() == access_token, "Invalid token value cached" # Do the same, but with use_token_cache=False with Patcher(): local_settings = Settings( - auth=settings.auth, + auth=ClientCredentials( + settings.auth.client_id, + settings.auth.client_secret, + use_token_cache=False, + ), server=settings.server, default_region=settings.default_region, - use_token_cache=False, ) rm = ResourceManager(local_settings) rm.client.get(url) - ts = TokenSecureStorage(settings.user, settings.password) + ts = TokenSecureStorage(settings.auth.client_id, settings.auth.client_secret) assert ( ts.get_cached_token() is None ), "Token is cached even though caching is disabled" From 82aa9768890303bb88c963223a9351e9707141f6 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Tue, 14 Mar 2023 11:51:19 +0200 Subject: [PATCH 05/28] update auth flow logic --- src/firebolt/client/auth/client_credentials.py | 3 +++ src/firebolt/client/client.py | 12 ++++++++---- src/firebolt/utils/urls.py | 2 +- src/firebolt/utils/util.py | 14 ++++++++++++++ tests/unit/client/test_client.py | 8 +++++--- tests/unit/client/test_client_async.py | 12 +++++++----- tests/unit/conftest.py | 9 +++++++-- 7 files changed, 45 insertions(+), 15 deletions(-) diff --git a/src/firebolt/client/auth/client_credentials.py b/src/firebolt/client/auth/client_credentials.py index 4c7bd413de7..36ba2a16791 100644 --- a/src/firebolt/client/auth/client_credentials.py +++ b/src/firebolt/client/auth/client_credentials.py @@ -31,6 +31,7 @@ class ClientCredentials(_RequestBasedAuth): "_expires", "_use_token_cache", "_user_agent", + "_audience", ) requires_response_body = True @@ -43,6 +44,7 @@ def __init__( ): self.client_id = client_id self.client_secret = client_secret + self._audience = "" super().__init__(use_token_cache) def copy(self) -> "ClientCredentials": @@ -86,6 +88,7 @@ def _make_auth_request(self) -> AuthRequest: "client_id": self.client_id, "client_secret": self.client_secret, "grant_type": "client_credentials", + "audience": self._audience, }, ) return response diff --git a/src/firebolt/client/client.py b/src/firebolt/client/client.py index f6e378a595f..fcc87be7f7c 100644 --- a/src/firebolt/client/client.py +++ b/src/firebolt/client/client.py @@ -16,6 +16,7 @@ from firebolt.utils.util import ( cached_property, fix_url_schema, + get_auth_endpoint, merge_urls, mixin_for, ) @@ -45,6 +46,7 @@ def __init__( ): self.account_name = account_name self._api_endpoint = URL(fix_url_schema(api_endpoint)) + self._auth_endpoint = get_auth_endpoint(self._api_endpoint) super().__init__(*args, auth=auth, **kwargs) def _build_auth(self, auth: Optional[AuthTypes]) -> Optional[Auth]: @@ -61,13 +63,15 @@ def _build_auth(self, auth: Optional[AuthTypes]) -> Optional[Auth]: Raises: TypeError: Auth argument has unsupported type """ - if auth is None or isinstance(auth, Auth): - return auth - raise TypeError(f'Invalid "auth" argument: {auth!r}') + if not (auth is None or isinstance(auth, Auth)): + raise TypeError(f'Invalid "auth" argument: {auth!r}') + if hasattr(auth, "_audience"): + auth._audience = self._api_endpoint # type: ignore + return auth def _merge_auth_request(self, request: Request) -> Request: if isinstance(request, AuthRequest): - request.url = merge_urls(self._api_endpoint, request.url) + request.url = merge_urls(self._auth_endpoint, request.url) request._prepare(dict(request.headers)) return request diff --git a/src/firebolt/utils/urls.py b/src/firebolt/utils/urls.py index f94d72a0461..ca355c26cf5 100644 --- a/src/firebolt/utils/urls.py +++ b/src/firebolt/utils/urls.py @@ -1,4 +1,4 @@ -AUTH_SERVICE_ACCOUNT_URL = "/auth/v1/token" +AUTH_SERVICE_ACCOUNT_URL = "/oauth/token" DATABASES_URL = "/core/v1/account/databases" diff --git a/src/firebolt/utils/util.py b/src/firebolt/utils/util.py index 96ac7a1ae64..9076f1bbc26 100644 --- a/src/firebolt/utils/util.py +++ b/src/firebolt/utils/util.py @@ -72,6 +72,20 @@ def fix_url_schema(url: str) -> str: return url if url.startswith("http") else f"https://{url}" +def get_auth_endpoint(api_endpoint: URL) -> URL: + """Create auth endpoint from api endpoint. + + Args: + api_endpoint (URL): provided API endpoint + + Returns: + URL: authentication endpoint + """ + return api_endpoint.copy_with( + host=".".join(["id"] + api_endpoint.host.split(".")[1:]) + ) + + def async_to_sync(f: Callable) -> Callable: """Convert async function to sync. diff --git a/tests/unit/client/test_client.py b/tests/unit/client/test_client.py index 2b8b4a6005e..f9374efab7f 100644 --- a/tests/unit/client/test_client.py +++ b/tests/unit/client/test_client.py @@ -6,7 +6,7 @@ from pytest import raises from pytest_httpx import HTTPXMock -from firebolt.client import DEFAULT_API_URL, Client +from firebolt.client import Client from firebolt.client.auth import Auth, ClientCredentials from firebolt.client.resource_manager_hooks import raise_on_4xx_5xx from firebolt.common import Settings @@ -58,6 +58,8 @@ def test_client_different_auths( check_credentials_callback: Callable, check_token_callback: Callable, auth: Auth, + auth_server: str, + server: str, ): """ Client properly handles such auth types: @@ -69,12 +71,12 @@ def test_client_different_auths( httpx_mock.add_callback( check_credentials_callback, - url=f"https://{DEFAULT_API_URL}{AUTH_SERVICE_ACCOUNT_URL}", + url=f"https://{auth_server}{AUTH_SERVICE_ACCOUNT_URL}", ) httpx_mock.add_callback(check_token_callback, url="https://url") - Client(auth=auth).get("https://url") + Client(auth=auth, api_endpoint=server).get("https://url") # client accepts None auth, but authorization fails with raises(AssertionError) as excinfo: diff --git a/tests/unit/client/test_client_async.py b/tests/unit/client/test_client_async.py index 844ba3ef4a3..9792c97351e 100644 --- a/tests/unit/client/test_client_async.py +++ b/tests/unit/client/test_client_async.py @@ -5,7 +5,7 @@ from pytest import raises from pytest_httpx import HTTPXMock -from firebolt.client import DEFAULT_API_URL, AsyncClient +from firebolt.client import AsyncClient from firebolt.client.auth import Auth from firebolt.common import Settings from firebolt.utils.urls import AUTH_SERVICE_ACCOUNT_URL @@ -54,6 +54,8 @@ async def test_client_different_auths( check_credentials_callback: Callable, check_token_callback: Callable, auth: Auth, + auth_server: str, + server: str, ): """ Client properly handles such auth types: @@ -65,21 +67,21 @@ async def test_client_different_auths( httpx_mock.add_callback( check_credentials_callback, - url=f"https://{DEFAULT_API_URL}{AUTH_SERVICE_ACCOUNT_URL}", + url=f"https://{auth_server}{AUTH_SERVICE_ACCOUNT_URL}", ) httpx_mock.add_callback(check_token_callback, url="https://url") - async with AsyncClient(auth=auth) as client: + async with AsyncClient(auth=auth, api_endpoint=server) as client: await client.get("https://url") # client accepts None auth, but authorization fails with raises(AssertionError) as excinfo: - async with AsyncClient(auth=None) as client: + async with AsyncClient(auth=None, api_endpoint=server) as client: await client.get("https://url") with raises(TypeError) as excinfo: - async with AsyncClient(auth=lambda r: r): + async with AsyncClient(auth=lambda r: r, api_endpoint=server): await client.get("https://url") assert str(excinfo.value).startswith( diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index b47042f2f23..233c3701ba8 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -65,6 +65,11 @@ def server() -> str: return "api-dev.mock.firebolt.io" +@fixture +def auth_server() -> str: + return "id.mock.firebolt.io" + + @fixture def account_id() -> str: return "mock_account_id" @@ -151,8 +156,8 @@ def do_mock( @fixture -def auth_url(settings: Settings) -> str: - return f"https://{settings.server}{AUTH_SERVICE_ACCOUNT_URL}" +def auth_url(auth_server: str) -> str: + return f"https://{auth_server}{AUTH_SERVICE_ACCOUNT_URL}" @fixture From 51703b6c5e01a4a06534030dc4a2bea70098f26d Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Mon, 20 Mar 2023 11:58:50 +0200 Subject: [PATCH 06/28] update connection logic --- src/firebolt/async_db/connection.py | 205 ++++++++++++---------------- src/firebolt/async_db/cursor.py | 4 +- src/firebolt/utils/urls.py | 2 + 3 files changed, 95 insertions(+), 116 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index e8ac0cd8ab9..a7ad60acf23 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -2,13 +2,12 @@ import logging import socket -from json import JSONDecodeError from types import TracebackType -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple, Type from httpcore.backends.auto import AutoBackend from httpcore.backends.base import AsyncNetworkStream -from httpx import AsyncHTTPTransport, HTTPStatusError, RequestError, Timeout +from httpx import AsyncHTTPTransport, Timeout, codes from firebolt.async_db.cursor import Cursor from firebolt.client import DEFAULT_API_URL, AsyncClient @@ -22,25 +21,19 @@ from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, - FireboltEngineError, InterfaceError, ) -from firebolt.utils.urls import ( - ACCOUNT_ENGINE_ID_BY_NAME_URL, - ACCOUNT_ENGINE_URL, - ACCOUNT_ENGINE_URL_BY_DATABASE_NAME, -) +from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME from firebolt.utils.usage_tracker import get_user_agent_header from firebolt.utils.util import fix_url_schema logger = logging.getLogger(__name__) -async def _resolve_engine_url( - engine_name: str, +async def _get_system_engine_url( auth: Auth, + account_name: str, api_endpoint: str, - account_name: Optional[str] = None, ) -> str: async with AsyncClient( auth=auth, @@ -49,74 +42,60 @@ async def _resolve_engine_url( api_endpoint=api_endpoint, timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), ) as client: - account_id = await client.account_id - url = ACCOUNT_ENGINE_ID_BY_NAME_URL.format(account_id=account_id) - try: - response = await client.get( - url=url, - params={"engine_name": engine_name}, - ) - response.raise_for_status() - engine_id = response.json()["engine_id"]["engine_id"] - url = ACCOUNT_ENGINE_URL.format(account_id=account_id, engine_id=engine_id) - response = await client.get(url=url) - response.raise_for_status() - return response.json()["engine"]["endpoint"] - except HTTPStatusError as e: - # Engine error would be 404. - if e.response.status_code != 404: - raise InterfaceError( - f"Error {e.__class__.__name__}: Unable to retrieve engine " - f"endpoint {url}." - ) - # Once this is point is reached we've already authenticated with - # the backend so it's safe to assume the cause of the error is - # missing engine. - raise FireboltEngineError(f"Firebolt engine {engine_name} does not exist.") - except (JSONDecodeError, RequestError, RuntimeError) as e: + url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name) + response = await client.get(url=url) + if response.status != codes.OK: raise InterfaceError( - f"Error {e.__class__.__name__}: " - f"Unable to retrieve engine endpoint {url}." + f"Unable to retrieve system engine endpoint {url}: " + f"{response.status} {response.content}" ) + return response.json()["gatewayHost"] async def _get_database_default_engine_url( - database: str, - auth: Auth, - api_endpoint: str, - account_name: Optional[str] = None, + system_engine: Connection, database_name: str ) -> str: - async with AsyncClient( - auth=auth, - base_url=api_endpoint, - account_name=account_name, - api_endpoint=api_endpoint, - timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), - ) as client: - try: - account_id = await client.account_id - response = await client.get( - url=ACCOUNT_ENGINE_URL_BY_DATABASE_NAME.format(account_id=account_id), - params={"database_name": database}, - ) - response.raise_for_status() - return response.json()["engine_url"] - except ( - JSONDecodeError, - RequestError, - RuntimeError, - HTTPStatusError, - KeyError, - ) as e: - raise InterfaceError(f"Unable to retrieve default engine endpoint: {e}.") - -def _validate_engine_name_and_url( - engine_name: Optional[str], engine_url: Optional[str] -) -> None: - if engine_name and engine_url: - raise ConfigurationError( - "Both engine_name and engine_url are provided. Provide only one to connect." - ) + cursor = system_engine.cursor() + await cursor.execute( + """ + SELECT engs.engine_url, engs.status + FROM information_schema.databases AS dbs + INNER JOIN information_schema.engines AS engs + ON engs.attached_to = dbs.database_name + AND engs.engine_name = NULLIF(SPLIT_PART(ARRAY_FIRST( + eng_name -> eng_name LIKE '%(default)', + SPLIT(',', attached_engines) + ), ' ', 1), '') + WHERE database_name = ?; + """, + [database_name], + ) + row = await cursor.fetchone() + if row is None: + raise InterfaceError(f"Database {database_name} doesn't have a default engine") + engine_url, status = row + if status != "Running": + raise InterfaceError(f"A default engine for {database_name} is not running") + return str(engine_url) # Mypy check + +async def _get_engine_url_and_db( + system_engine: Connection, engine_name: str +) -> Tuple[str, str]: + cursor = system_engine.cursor() + await cursor.execute( + """ + SELECT engine_url, attached_to, status FROM information_schema.engines + WHERE engine_name=? + """, + [engine_name], + ) + row = await cursor.fetchone() + if row is None: + raise InterfaceError(f"Engine with name {engine_name} doesn't exist") + engine_url, database, status = row + if status != "Running": + raise InterfaceError(f"Engine {engine_name} is not running") + return str(engine_url), str(database) # Mypy check async def connect( database: str = None, @@ -144,56 +123,53 @@ async def connect( Providing both `engine_name` and `engine_url` will result in an error """ + # These parameters are optional in function signature # but are required to connect. # PEP 249 recommends making them kwargs. - if not database: - raise ConfigurationError("database name is required to connect.") + for name, value in (("auth", auth), (account_name, "account_name")): + if not value: + raise ConfigurationError(f"{name} is required to connect.") - if not auth: - raise ConfigurationError("auth is required to connect.") + system_engine_url = fix_url_schema( + await _get_system_engine_url(auth, account_name, api_endpoint) + ) - _validate_engine_name_and_url(engine_name, engine_url) + if not engine_name and not database: + # Return system engine connection + return connection_class( + system_engine_url, None, auth, api_endpoint, additional_parameters + ) - api_endpoint = fix_url_schema(api_endpoint) + else: + async with Connection( + system_engine_url, None, auth, api_endpoint, additional_parameters + ) as system_engine_connection: + if engine_name: + engine_url, attached_db = await _get_engine_url_and_db( + system_engine_connection, engine_name + ) - # Mypy checks, this should never happen - assert database is not None + if database is not None and database != attached_db: + raise InterfaceError( + f"Engine {engine_name} is not attached to {database}, " + f"but to {attached_db}" + ) + elif database is None: + database = attached_db + + elif database: + # Get database default engine + engine_url = await _get_database_default_engine_url( + system_engine_connection, database + ) - if not engine_name and not engine_url: - engine_url = await _get_database_default_engine_url( - database=database, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - ) + assert engine_url is not None - elif engine_name: - engine_url = await _resolve_engine_url( - engine_name=engine_name, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, + engine_url = fix_url_schema(engine_url) + return Connection( + engine_url, database, auth, api_endpoint, additional_parameters ) - elif account_name: - # In above if branches account name is validated since it's used to - # resolve or get an engine url. - # We need to manually validate account_name if none of the above - # cases are triggered. - async with AsyncClient( - auth=auth, - base_url=api_endpoint, - account_name=account_name, - api_endpoint=api_endpoint, - ) as client: - await client.account_id - - assert engine_url is not None - - engine_url = fix_url_schema(engine_url) - return Connection( - engine_url, database, auth, api_endpoint, additional_parameters - ) class OverriddenHttpBackend(AutoBackend): @@ -237,7 +213,6 @@ async def connect_tcp( # type: ignore [override] ) return stream - class Connection(BaseConnection): """ Firebolt asynchronous database connection class. Implements `PEP 249`_. diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 890e71674a7..f4189a3d0c1 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -119,7 +119,9 @@ async def _raise_if_error(self, resp: Response) -> None: f"Error executing query:\n{resp.read().decode('utf-8')}" ) if resp.status_code == codes.FORBIDDEN: - if not await is_db_available(self.connection, self.connection.database): + if self.connection.database and not await is_db_available( + self.connection, self.connection.database + ): raise FireboltDatabaseError( f"Database {self.connection.database} does not exist" ) diff --git a/src/firebolt/utils/urls.py b/src/firebolt/utils/urls.py index ca355c26cf5..4987b240576 100644 --- a/src/firebolt/utils/urls.py +++ b/src/firebolt/utils/urls.py @@ -28,3 +28,5 @@ PROVIDERS_URL = "/compute/v1/providers" REGIONS_URL = "/compute/v1/regions" + +GATEWAY_HOST_BY_ACCOUNT_NAME = "/v3/getGatewayHostByAccountName/{account_name}" From b628aad08dac376fbb27863f44801c02a8daa9bb Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Mon, 20 Mar 2023 12:01:23 +0200 Subject: [PATCH 07/28] make database optional in connection --- src/firebolt/async_db/cursor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index f4189a3d0c1..d424374e2b9 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -159,15 +159,15 @@ async def _api_request( set parameters are sent. Setting this to False will allow self._set_parameters to be ignored. """ + parameters = {} if use_set_parameters: parameters = {**(self._set_parameters or {}), **(parameters or {})} + if self.connection.database: + parameters["database"] = self.connection.database return await self._client.request( url=f"/{path}", method="POST", - params={ - "database": self.connection.database, - **(parameters or dict()), - }, + params=parameters, content=query, ) From b5e69ec69659ad7362728b9ed0af1686490ad351 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Tue, 21 Mar 2023 14:27:36 +0200 Subject: [PATCH 08/28] fix sync connection tests --- setup.cfg | 2 +- src/firebolt/async_db/connection.py | 24 ++- src/firebolt/async_db/cursor.py | 19 +- tests/unit/conftest.py | 56 +----- tests/unit/db/conftest.py | 18 +- tests/unit/db/test_connection.py | 257 ++++++++++------------------ tests/unit/db_conftest.py | 142 +++++++++++++++ 7 files changed, 270 insertions(+), 248 deletions(-) diff --git a/setup.cfg b/setup.cfg index cc209eb844a..7a3d4e111bf 100755 --- a/setup.cfg +++ b/setup.cfg @@ -30,8 +30,8 @@ install_requires = httpx[http2]==0.24.0 pydantic[dotenv]>=1.8.2 python-dateutil>=2.8.2 - readerwriterlock==1.0.9 sqlparse>=0.4.2 + tricycle>=0.2.2 trio<0.22.0 python_requires = >=3.7 include_package_data = True diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index a7ad60acf23..34ff5b34c10 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -44,7 +44,7 @@ async def _get_system_engine_url( ) as client: url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name) response = await client.get(url=url) - if response.status != codes.OK: + if response.status_code != codes.OK: raise InterfaceError( f"Unable to retrieve system engine endpoint {url}: " f"{response.status} {response.content}" @@ -98,11 +98,10 @@ async def _get_engine_url_and_db( return str(engine_url), str(database) # Mypy check async def connect( - database: str = None, - auth: Auth = None, - engine_name: Optional[str] = None, - engine_url: Optional[str] = None, + auth: Optional[Auth] = None, account_name: Optional[str] = None, + database: Optional[str] = None, + engine_name: Optional[str] = None, api_endpoint: str = DEFAULT_API_URL, additional_parameters: Dict[str, Any] = {}, ) -> Connection: @@ -112,28 +111,25 @@ async def connect( `auth` (Auth) Authentication object. `database` (str): Name of the database to connect `engine_name` (Optional[str]): Name of the engine to connect to - `engine_url` (Optional[str]): The engine endpoint to use `account_name` (Optional[str]): For customers with multiple accounts; if none, default is used `api_endpoint` (str): Firebolt API endpoint. Used for authentication `additional_parameters` (Optional[Dict]): Dictionary of less widely-used arguments for connection - Note: - Providing both `engine_name` and `engine_url` will result in an error - """ - # These parameters are optional in function signature # but are required to connect. # PEP 249 recommends making them kwargs. - for name, value in (("auth", auth), (account_name, "account_name")): + for name, value in (("auth", auth), ("account_name", account_name)): if not value: raise ConfigurationError(f"{name} is required to connect.") - system_engine_url = fix_url_schema( - await _get_system_engine_url(auth, account_name, api_endpoint) - ) + # Type checks + assert auth is not None + assert account_name is not None + + api_endpoint = fix_url_schema(api_endpoint) if not engine_name and not database: # Return system engine connection diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index d424374e2b9..f334c9d35c9 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -17,8 +17,8 @@ Union, ) -from aiorwlock import RWLock from httpx import Response, codes +from tricycle import RWLock from firebolt.async_db.util import is_db_available, is_engine_running from firebolt.common._types import ( @@ -139,10 +139,10 @@ async def _raise_if_error(self, resp: Response) -> None: async def _api_request( self, - query: Optional[str] = "", - parameters: Optional[dict[str, Any]] = {}, - path: Optional[str] = "", - use_set_parameters: Optional[bool] = True, + query: str = "", + parameters: dict[str, Any] = {}, + path: str = "", + use_set_parameters: bool = True, ) -> Response: """ Query API, return Response object. @@ -159,7 +159,6 @@ async def _api_request( set parameters are sent. Setting this to False will allow self._set_parameters to be ignored. """ - parameters = {} if use_set_parameters: parameters = {**(self._set_parameters or {}), **(parameters or {})} if self.connection.database: @@ -431,13 +430,13 @@ async def cancel(self, query_id: str) -> None: @wraps(BaseCursor.fetchone) async def fetchone(self) -> Optional[List[ColType]]: - async with self._async_query_lock.reader: + async with self._async_query_lock.read_locked(): """Fetch the next row of a query result set.""" return super().fetchone() @wraps(BaseCursor.fetchmany) async def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]: - async with self._async_query_lock.reader: + async with self._async_query_lock.read_locked(): """ Fetch the next set of rows of a query result; size is cursor.arraysize by default. @@ -446,11 +445,11 @@ async def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]: @wraps(BaseCursor.fetchall) async def fetchall(self) -> List[List[ColType]]: - async with self._async_query_lock.reader: + async with self._async_query_lock.read_locked(): """Fetch all remaining rows of a query result.""" return super().fetchall() @wraps(BaseCursor.nextset) async def nextset(self) -> None: - async with self._async_query_lock.reader: + async with self._async_query_lock.read_locked(): return super().nextset() diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 233c3701ba8..89c1eb9fcff 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -27,7 +27,6 @@ ACCOUNT_BY_NAME_URL, ACCOUNT_DATABASE_BY_NAME_URL, ACCOUNT_ENGINE_URL, - ACCOUNT_ENGINE_URL_BY_DATABASE_NAME, ACCOUNT_URL, AUTH_SERVICE_ACCOUNT_URL, DATABASES_URL, @@ -75,6 +74,11 @@ def account_id() -> str: return "mock_account_id" +@fixture +def account_name() -> str: + return "mock_account_name" + + @fixture def access_token() -> str: return "mock_access_token" @@ -224,48 +228,6 @@ def get_engine_name_by_id_url( ) -@fixture -def get_engine_url_by_id_url( - settings: Settings, account_id: str, engine_id: str -) -> str: - return f"https://{settings.server}" + ACCOUNT_ENGINE_URL.format( - account_id=account_id, engine_id=engine_id - ) - - -@fixture -def get_engine_url_by_id_callback( - get_engine_url_by_id_url: str, engine_id: str, settings: Settings -) -> Callable: - def do_mock( - request: Request = None, - **kwargs, - ) -> Response: - assert request.url == get_engine_url_by_id_url - return Response( - status_code=httpx.codes.OK, - json={ - "engine": { - "name": "name", - "compute_region_id": { - "provider_id": "provider", - "region_id": "region", - }, - "settings": { - "preset": "", - "auto_stop_delay_duration": "1s", - "minimum_logging_level": "", - "is_read_only": False, - "warm_up": "", - }, - "endpoint": f"https://{settings.server}", - } - }, - ) - - return do_mock - - @fixture def get_engines_url(settings: Settings) -> str: return f"https://{settings.server}{ENGINES_URL}" @@ -309,14 +271,6 @@ def do_mock( return do_mock -@fixture -def engine_by_db_url(settings: Settings, account_id: str) -> str: - return ( - f"https://{settings.server}" - f"{ACCOUNT_ENGINE_URL_BY_DATABASE_NAME.format(account_id=account_id)}" - ) - - @fixture def db_api_exceptions(): exceptions = { diff --git a/tests/unit/db/conftest.py b/tests/unit/db/conftest.py index eaba5fe0849..85b3a1d6082 100644 --- a/tests/unit/db/conftest.py +++ b/tests/unit/db/conftest.py @@ -1,17 +1,27 @@ +from typing import Callable + from pytest import fixture from firebolt.client.auth import Auth -from firebolt.common.settings import Settings from firebolt.db import Connection, Cursor, connect @fixture -def connection(settings: Settings, db_name: str, auth: Auth) -> Connection: +def connection( + server: str, + db_name: str, + auth: Auth, + engine_name: str, + account_name: str, + mock_connection_flow: Callable, +) -> Connection: + mock_connection_flow() with connect( - engine_url=settings.server, + engine_name=engine_name, database=db_name, auth=auth, - api_endpoint=settings.server, + account_name=account_name, + api_endpoint=server, ) as connection: yield connection diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index df9d7ca0e75..626499bdef5 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -1,9 +1,7 @@ import gc import warnings -from re import Pattern from typing import Callable, List -from httpx import codes from pyfakefs.fake_filesystem_unittest import Patcher from pytest import mark, raises, warns from pytest_httpx import HTTPXMock @@ -13,13 +11,11 @@ from firebolt.common.settings import Settings from firebolt.db import Connection, connect from firebolt.utils.exception import ( - AccountNotFoundError, ConfigurationError, ConnectionClosedError, - FireboltEngineError, + InterfaceError, ) from firebolt.utils.token_storage import TokenSecureStorage -from firebolt.utils.urls import ACCOUNT_ENGINE_ID_BY_NAME_URL def test_closed_connection(connection: Connection) -> None: @@ -53,45 +49,40 @@ def test_cursors_closed_on_close(connection: Connection) -> None: def test_cursor_initialized( settings: Settings, db_name: str, - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - query_url: str, + mock_connection_flow: Callable, + mock_query: Callable, + account_name: str, python_query_data: List[List[ColType]], auth: Auth, + engine_name: str, ) -> None: """Connection initialized its cursors properly.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_connection_flow() + mock_query() - for url in (settings.server, f"https://{settings.server}"): - with connect( - engine_url=url, - database=db_name, - api_endpoint=settings.server, - auth=auth, - ) as connection: + with connect( + engine_name=engine_name, + database=db_name, + api_endpoint=settings.server, + auth=auth, + account_name=account_name, + ) as connection: - cursor = connection.cursor() - assert ( - cursor.connection == connection - ), "Invalid cursor connection attribute" - assert ( - cursor._client == connection._client - ), "Invalid cursor _client attribute" + cursor = connection.cursor() + assert cursor.connection == connection, "Invalid cursor connection attribute" + assert cursor._client == connection._client, "Invalid cursor _client attribute" - assert cursor.execute("select*") == len(python_query_data) + assert cursor.execute("select*") == len(python_query_data) - cursor.close() - assert ( - cursor not in connection._cursors - ), "Cursor wasn't removed from connection after close" + cursor.close() + assert ( + cursor not in connection._cursors + ), "Cursor wasn't removed from connection after close" def test_connect_empty_parameters(): with raises(ConfigurationError): - with connect(engine_url="engine_url"): + with connect(engine_name="engine_name"): pass @@ -99,67 +90,48 @@ def test_connect_engine_name( settings: Settings, db_name: str, httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - query_url: str, - account_id_url: Pattern, - account_id_callback: Callable, - engine_id: str, - get_engine_url_by_id_url: str, - get_engine_url_by_id_callback: Callable, + mock_connection_flow: Callable, + mock_query: Callable, + account_name: str, + engine_name: str, python_query_data: List[List[ColType]], - account_id: str, auth: Auth, + system_engine_query_url: str, + get_engine_url_not_running_callback: Callable, + get_engine_url_invalid_db_callback: Callable, + auth_url: str, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + get_engine_url_callback: Callable, ): """connect properly handles engine_name""" + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) - with raises(ConfigurationError): - connect( - engine_url="engine_url", - engine_name="engine_name", - database="db", - auth=auth, - ) - - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(get_engine_url_by_id_callback, url=get_engine_url_by_id_url) - - engine_name = settings.server.split(".")[0] - - # Mock engine id lookup error - httpx_mock.add_response( - url=f"https://{settings.server}" - + ACCOUNT_ENGINE_ID_BY_NAME_URL.format(account_id=account_id) - + f"?engine_name={engine_name}", - status_code=codes.NOT_FOUND, - ) + mock_query() - with raises(FireboltEngineError): - connect( - database="db", - auth=auth, - engine_name=engine_name, - account_name=settings.account_name, - api_endpoint=settings.server, - ) - - # Mock engine id lookup by name - httpx_mock.add_response( - url=f"https://{settings.server}" - + ACCOUNT_ENGINE_ID_BY_NAME_URL.format(account_id=account_id) - + f"?engine_name={engine_name}", - status_code=codes.OK, - json={"engine_id": {"engine_id": engine_id}}, - ) + for callback in ( + get_engine_url_invalid_db_callback, + get_engine_url_not_running_callback, + ): + httpx_mock.add_callback(callback, url=system_engine_query_url) + with raises(InterfaceError): + connect( + database="db", + auth=auth, + engine_name=engine_name, + account_name=account_name, + api_endpoint=settings.server, + ) + + httpx_mock.add_callback(get_engine_url_callback, url=system_engine_query_url) with connect( engine_name=engine_name, database=db_name, auth=auth, - account_name=settings.account_name, + account_name=account_name, api_endpoint=settings.server, ) as connection: assert connection.cursor().execute("select*") == len(python_query_data) @@ -168,35 +140,42 @@ def test_connect_engine_name( def test_connect_default_engine( settings: Settings, db_name: str, + mock_query: Callable, httpx_mock: HTTPXMock, - auth_callback: Callable, auth_url: str, - query_callback: Callable, - query_url: str, - account_id_url: Pattern, - account_id_callback: Callable, + check_credentials_callback: Callable, database_id: str, - engine_by_db_url: str, python_query_data: List[List[ColType]], account_id: str, auth: Auth, + account_name: str, + system_engine_query_url: str, + get_default_db_engine_callback: Callable, + get_default_db_engine_not_running_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, ): - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - engine_by_db_url = f"{engine_by_db_url}?database_name={db_name}" - - httpx_mock.add_response( - url=engine_by_db_url, - status_code=codes.OK, - json={ - "engine_url": settings.server, - }, + mock_query() + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) + httpx_mock.add_callback( + get_default_db_engine_not_running_callback, url=system_engine_query_url ) + with raises(InterfaceError): + with connect( + database=db_name, + auth=auth, + account_name=account_name, + api_endpoint=settings.server, + ) as connection: + connection.cursor().execute("select*") + + httpx_mock.add_callback(get_default_db_engine_callback, url=system_engine_query_url) + with connect( database=db_name, auth=auth, - account_name=settings.account_name, + account_name=account_name, api_endpoint=settings.server, ) as connection: assert connection.cursor().execute("select*") == len(python_query_data) @@ -235,25 +214,24 @@ def test_connection_commit(connection: Connection): def test_connection_token_caching( settings: Settings, db_name: str, - httpx_mock: HTTPXMock, - check_credentials_callback: Callable, - auth_url: str, - query_callback: Callable, - query_url: str, + mock_connection_flow: Callable, + mock_query: Callable, python_query_data: List[List[ColType]], access_token: str, client_id: str, client_secret: str, + engine_name: str, + account_name: str, ) -> None: - httpx_mock.add_callback(check_credentials_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_connection_flow() + mock_query() with Patcher(): with connect( database=db_name, auth=ClientCredentials(client_id, client_secret, use_token_cache=True), - engine_url=settings.server, - account_name=settings.account_name, + engine_name=engine_name, + account_name=account_name, api_endpoint=settings.server, ) as connection: assert connection.cursor().execute("select*") == len(python_query_data) @@ -265,8 +243,8 @@ def test_connection_token_caching( with connect( database=db_name, auth=ClientCredentials(client_id, client_secret, use_token_cache=False), - engine_url=settings.server, - account_name=settings.account_name, + engine_name=engine_name, + account_name=account_name, api_endpoint=settings.server, ) as connection: assert connection.cursor().execute("select*") == len(python_query_data) @@ -274,60 +252,3 @@ def test_connection_token_caching( assert ( ts.get_cached_token() is None ), "Token is cached even though caching is disabled" - - -def test_connect_with_auth( - httpx_mock: HTTPXMock, - settings: Settings, - db_name: str, - check_credentials_callback: Callable, - auth_url: str, - query_callback: Callable, - query_url: str, - access_token: str, - auth: Auth, -) -> None: - httpx_mock.add_callback(check_credentials_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) - - with connect( - auth=auth, - database=db_name, - engine_url=settings.server, - account_name=settings.account_name, - api_endpoint=settings.server, - ) as connection: - connection.cursor().execute("select*") - - -def test_connect_account_name( - httpx_mock: HTTPXMock, - auth: Auth, - settings: Settings, - db_name: str, - auth_url: str, - check_credentials_callback: Callable, - account_id_url: Pattern, - account_id_callback: Callable, -): - httpx_mock.add_callback(check_credentials_callback, url=auth_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - - with raises(AccountNotFoundError): - with connect( - auth=auth, - database=db_name, - engine_url=settings.server, - account_name="invalid", - api_endpoint=settings.server, - ): - pass - - with connect( - auth=auth, - database=db_name, - engine_url=settings.server, - account_name=settings.account_name, - api_endpoint=settings.server, - ): - pass diff --git a/tests/unit/db_conftest.py b/tests/unit/db_conftest.py index b9dccba3cf8..09e14ba6452 100644 --- a/tests/unit/db_conftest.py +++ b/tests/unit/db_conftest.py @@ -5,10 +5,12 @@ from httpx import URL, Request, Response, codes from pytest import fixture +from pytest_httpx import HTTPXMock from firebolt.async_db.cursor import JSON_OUTPUT_FORMAT, ColType, Column from firebolt.common.settings import Settings from firebolt.db import ARRAY, DECIMAL +from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME QUERY_ROW_COUNT: int = 10 @@ -346,3 +348,143 @@ def set_query_url(settings: Settings, db_name: str) -> str: def query_with_params_url(query_url: str, set_params: str) -> str: params_encoded = "&".join([f"{k}={encode_param(v)}" for k, v in set_params.items()]) query_url = f"{query_url}&{params_encoded}" + + +def _get_engine_url_callback(server: str, db_name: str, status="Running") -> Callable: + def do_query(request: Request, **kwargs) -> Response: + set_parameters = request.url.params + assert len(set_parameters) == 1 and "output_format" in set_parameters + data = [[server, db_name, status]] + query_response = { + "meta": [{"name": "name", "type": "Text"} for _ in range(len(data[0]))], + "data": data, + "rows": len(data), + # Real example of statistics field value, not used by our code + "statistics": { + "elapsed": 0.002983335, + "time_before_execution": 0.002729331, + "time_to_execute": 0.000215215, + "rows_read": 1, + "bytes_read": 1, + "scanned_bytes_cache": 0, + "scanned_bytes_storage": 0, + }, + } + return Response(status_code=codes.OK, json=query_response) + + return do_query + + +@fixture +def get_engine_url_callback(server: str, db_name: str, status="Running") -> Callable: + return _get_engine_url_callback(server, db_name) + + +@fixture +def get_engine_url_not_running_callback(engine_name, db_name) -> Callable: + return _get_engine_url_callback(engine_name, db_name, "Stopped") + + +@fixture +def get_engine_url_invalid_db_callback(engine_name, db_name) -> Callable: + return _get_engine_url_callback(engine_name, "not_" + db_name) + + +def _get_default_db_engine_callback(server: str, status="Running") -> Callable: + def do_query(request: Request, **kwargs) -> Response: + set_parameters = request.url.params + assert len(set_parameters) == 1 and "output_format" in set_parameters + data = [[server, status]] + query_response = { + "meta": [{"name": "name", "type": "Text"} for _ in range(len(data[0]))], + "data": data, + "rows": len(data), + # Real example of statistics field value, not used by our code + "statistics": { + "elapsed": 0.002983335, + "time_before_execution": 0.002729331, + "time_to_execute": 0.000215215, + "rows_read": 1, + "bytes_read": 1, + "scanned_bytes_cache": 0, + "scanned_bytes_storage": 0, + }, + } + return Response(status_code=codes.OK, json=query_response) + + return do_query + + +@fixture +def get_default_db_engine_callback(server: str) -> Callable: + return _get_default_db_engine_callback(server) + + +@fixture +def get_default_db_engine_not_running_callback(server: str) -> Callable: + return _get_default_db_engine_callback(server, "Failed") + + +@fixture +def system_engine_url() -> str: + return "https://bravo.a.eu-west-1.aws.mock.firebolt.io" + + +@fixture +def system_engine_query_url(system_engine_url: str) -> str: + return f"{system_engine_url}/?output_format=JSON_Compact" + + +@fixture +def get_system_engine_url(server: str, account_name: str) -> str: + return URL( + f"https://{server}" + f"{GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name)}" + ) + + +@fixture +def get_system_engine_callback(system_engine_url: str) -> Callable: + def inner( + request: Request = None, + **kwargs, + ) -> Response: + assert request, "empty request" + assert request.method == "GET", "invalid request method" + + return Response( + status_code=codes.OK, + json={"gatewayHost": system_engine_url}, + ) + + return inner + + +@fixture +def mock_connection_flow( + httpx_mock: HTTPXMock, + auth_url: str, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + system_engine_query_url: str, + get_engine_url_callback: Callable, +) -> Callable: + def inner() -> None: + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) + httpx_mock.add_callback(get_engine_url_callback, url=system_engine_query_url) + + return inner + + +@fixture +def mock_query( + httpx_mock: HTTPXMock, + query_url: str, + query_callback: Callable, +) -> Callable: + def inner() -> None: + httpx_mock.add_callback(query_callback, url=query_url) + + return inner From 83a5b440cbfb0a922414dee69081181dba97cad2 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Tue, 21 Mar 2023 14:44:44 +0200 Subject: [PATCH 09/28] fix sync cursor tests --- tests/unit/db/test_cursor.py | 119 ++++++++--------------------------- tests/unit/db_conftest.py | 12 ++++ 2 files changed, 39 insertions(+), 92 deletions(-) diff --git a/tests/unit/db/test_cursor.py b/tests/unit/db/test_cursor.py index 8b8e18be77c..8b4959cfc55 100644 --- a/tests/unit/db/test_cursor.py +++ b/tests/unit/db/test_cursor.py @@ -18,15 +18,12 @@ def test_cursor_state( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, + mock_query: Callable, query_url: str, cursor: Cursor, ): """Cursor state changes depending on the operations performed with it.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() assert cursor._state == CursorState.NONE @@ -89,10 +86,7 @@ def test_closed_cursor(cursor: Cursor): def test_cursor_no_query( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - query_url: str, + mock_query: Callable, cursor: Cursor, ): """Some of cursor methods are unavailable until a query is run.""" @@ -103,8 +97,7 @@ def test_cursor_no_query( "nextset", ) - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() for method in methods: with raises(QueryNotRunError): @@ -134,12 +127,8 @@ def test_cursor_no_query( def test_cursor_execute( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - insert_query_callback: Callable, - query_url: str, + mock_query: Callable, + mock_insert_query: Callable, cursor: Cursor, python_query_description: List[Column], python_query_data: List[List[ColType]], @@ -157,8 +146,7 @@ def test_cursor_execute( ), ): # Query with json output - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() assert query() == len( python_query_data ), f"Invalid row count returned for {message}." @@ -181,11 +169,8 @@ def test_cursor_execute( cursor.fetchone() is None ), f"Non-empty fetchone after all data received for {message}." - httpx_mock.reset(True) - # Query with empty output - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(insert_query_callback, url=query_url) + mock_insert_query() assert query() == -1, f"Invalid row count for insert using {message}." assert ( cursor.rowcount == -1 @@ -197,8 +182,6 @@ def test_cursor_execute( def test_cursor_execute_error( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, query_url: str, cursor: Cursor, ): @@ -213,8 +196,6 @@ def test_cursor_execute_error( "server-side synchronous executemany()", ), ): - httpx_mock.add_callback(auth_callback, url=auth_url) - # Internal httpx error def http_error(*args, **kwargs): raise StreamError("httpx error") @@ -254,17 +235,12 @@ def http_error(*args, **kwargs): def test_cursor_fetchone( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - insert_query_callback: Callable, - query_url: str, + mock_query: Callable, + mock_insert_query: Callable, cursor: Cursor, ): """cursor fetchone fetches single row in correct order; if no rows returns None.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() cursor.execute("sql") @@ -279,27 +255,22 @@ def test_cursor_fetchone( cursor.fetchone() is None ), "fetchone should return None when no rows left to fetch" - httpx_mock.add_callback(insert_query_callback, url=query_url) + mock_insert_query() cursor.execute("sql") with raises(DataError): cursor.fetchone() def test_cursor_fetchmany( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - insert_query_callback: Callable, - query_url: str, + mock_query: Callable, + mock_insert_query: Callable, cursor: Cursor, ): """ Cursor's fetchmany fetches the provided amount of rows, or arraysize by default. If not enough rows left, returns less or None if there are no rows. """ - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() cursor.execute("sql") @@ -340,24 +311,19 @@ def test_cursor_fetchmany( len(cursor.fetchmany()) == 0 ), "fetchmany should return empty result set when no rows left to fetch" - httpx_mock.add_callback(insert_query_callback, url=query_url) + mock_insert_query() cursor.execute("sql") with raises(DataError): cursor.fetchmany() def test_cursor_fetchall( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - insert_query_callback: Callable, - query_url: str, + mock_query: Callable, + mock_insert_query: Callable, cursor: Cursor, ): """cursor fetchall fetches all rows that left after last query.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() cursor.execute("sql") @@ -374,28 +340,23 @@ def test_cursor_fetchall( len(cursor.fetchall()) == 0 ), "fetchmany should return empty result set when no rows left to fetch" - httpx_mock.add_callback(insert_query_callback, url=query_url) + mock_insert_query() cursor.execute("sql") with raises(DataError): cursor.fetchall() def test_cursor_multi_statement( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - insert_query_callback: Callable, - query_url: str, + mock_query: Callable, + mock_insert_query: Callable, cursor: Cursor, python_query_description: List[Column], python_query_data: List[List[ColType]], ): """executemany with multiple parameter sets is not supported.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) - httpx_mock.add_callback(insert_query_callback, url=query_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() + mock_insert_query() + mock_query() rc = cursor.execute("select * from t; insert into t values (1, 2); select * from t") assert rc == len(python_query_data), "Invalid row count returned" @@ -432,8 +393,6 @@ def test_cursor_multi_statement( def test_cursor_set_statements( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, query_callback: Callable, select_one_query_callback: Callable, set_query_url: str, @@ -442,7 +401,6 @@ def test_cursor_set_statements( python_query_data: List[List[ColType]], ): """cursor correctly parses and processes set statements.""" - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(select_one_query_callback, url=f"{set_query_url}&a=b") assert len(cursor._set_parameters) == 0 @@ -504,8 +462,6 @@ def test_cursor_set_statements( def test_cursor_set_parameters_sent( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, set_query_url: str, query_url: str, query_with_params_callback: Callable, @@ -514,7 +470,6 @@ def test_cursor_set_parameters_sent( set_params: Dict, ): """Cursor passes provided set parameters to engine.""" - httpx_mock.add_callback(auth_callback, url=auth_url) params = "" @@ -531,16 +486,11 @@ def test_cursor_set_parameters_sent( def test_cursor_skip_parse( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_url: str, - query_callback: Callable, + mock_query: Callable, cursor: Cursor, ): """Cursor doesn't process a query if skip_parsing is provided.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() with patch("firebolt.db.cursor.split_format_sql") as split_format_sql_mock: cursor.execute("non-an-actual-sql") @@ -553,8 +503,6 @@ def test_cursor_skip_parse( def test_cursor_server_side_async_execute( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, server_side_async_id_callback: Callable, server_side_async_id: Callable, query_with_params_url: str, @@ -576,7 +524,6 @@ def test_cursor_server_side_async_execute( ): # Query with json output - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_id_callback, url=query_with_params_url ) @@ -594,8 +541,6 @@ def test_cursor_server_side_async_execute( def test_cursor_server_side_async_cancel( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, server_side_async_cancel_callback: Callable, server_side_async_id: Callable, query_with_params_url: str, @@ -607,7 +552,6 @@ def test_cursor_server_side_async_cancel( """ # Query with json output - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_cancel_callback, url=query_with_params_url ) @@ -616,8 +560,6 @@ def test_cursor_server_side_async_cancel( def test_cursor_server_side_async_get_status_completed( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, server_side_async_get_status_callback: Callable, server_side_async_id: Callable, query_with_params_url: str, @@ -629,7 +571,6 @@ def test_cursor_server_side_async_get_status_completed( """ # Query with json output - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_get_status_callback, url=query_with_params_url ) @@ -639,8 +580,6 @@ def test_cursor_server_side_async_get_status_completed( def test_cursor_server_side_async_get_status_not_yet_available( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, server_side_async_get_status_not_yet_availabe_callback: Callable, server_side_async_id: Callable, query_with_params_url: str, @@ -652,7 +591,6 @@ def test_cursor_server_side_async_get_status_not_yet_available( """ # Query with json output - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_get_status_not_yet_availabe_callback, url=query_with_params_url, @@ -663,15 +601,12 @@ def test_cursor_server_side_async_get_status_not_yet_available( def test_cursor_server_side_async_get_status_error( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, server_side_async_get_status_error: Callable, server_side_async_id: Callable, query_with_params_url: str, cursor: Cursor, ): """ """ - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_get_status_error, url=query_with_params_url ) diff --git a/tests/unit/db_conftest.py b/tests/unit/db_conftest.py index 09e14ba6452..5e32ec0bd22 100644 --- a/tests/unit/db_conftest.py +++ b/tests/unit/db_conftest.py @@ -488,3 +488,15 @@ def inner() -> None: httpx_mock.add_callback(query_callback, url=query_url) return inner + + +@fixture +def mock_insert_query( + httpx_mock: HTTPXMock, + query_url: str, + insert_query_callback: Callable, +) -> Callable: + def inner() -> None: + httpx_mock.add_callback(insert_query_callback, url=query_url) + + return inner From eb66f78020f0be9d1ba82f32f3a56cc79be9458a Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Tue, 21 Mar 2023 16:05:45 +0200 Subject: [PATCH 10/28] fixes for async cursor tests --- setup.cfg | 4 +- tests/unit/async_db/conftest.py | 21 +- tests/unit/async_db/test_connection.py | 320 +++++++++---------------- tests/unit/async_db/test_cursor.py | 130 +++------- tests/unit/db/test_connection.py | 141 +++++++---- tests/unit/db/test_cursor.py | 4 - 6 files changed, 255 insertions(+), 365 deletions(-) diff --git a/setup.cfg b/setup.cfg index 7a3d4e111bf..1f372f432bd 100755 --- a/setup.cfg +++ b/setup.cfg @@ -51,11 +51,11 @@ dev = pre-commit==2.15.0 pyfakefs>=4.5.3 pytest==6.2.5 - pytest-asyncio==0.19.0 pytest-cov==3.0.0 pytest-httpx==0.22.0 pytest-mock==3.6.1 pytest-timeout==2.1.0 + pytest-trio==0.8.0 pytest-xdist==2.5.0 trio-typing[mypy]==0.6.* types-cryptography==3.3.18 @@ -90,4 +90,4 @@ docstring-convention = google inline-quotes = " [tool:pytest] -asyncio_mode = auto +trio_mode = true diff --git a/tests/unit/async_db/conftest.py b/tests/unit/async_db/conftest.py index f6b6abea1f3..8be96d4ccf5 100644 --- a/tests/unit/async_db/conftest.py +++ b/tests/unit/async_db/conftest.py @@ -2,7 +2,6 @@ from typing import Dict from pytest import fixture -from pytest_asyncio import fixture as asyncio_fixture from firebolt.async_db import ARRAY, DECIMAL, Connection, Cursor, connect from firebolt.client.auth import Auth @@ -10,21 +9,29 @@ from tests.unit.db_conftest import * # noqa -@asyncio_fixture -async def connection(settings: Settings, auth: Auth, db_name: str) -> Connection: +@fixture +async def connection( + server: str, + db_name: str, + auth: Auth, + engine_name: str, + account_name: str, + mock_connection_flow: Callable, +) -> Connection: + mock_connection_flow() async with ( await connect( - engine_url=settings.server, + engine_name=engine_name, database=db_name, auth=auth, - account_name=settings.account_name, - api_endpoint=settings.server, + account_name=account_name, + api_endpoint=server, ) ) as connection: yield connection -@asyncio_fixture +@fixture async def cursor(connection: Connection, settings: Settings) -> Cursor: return connection.cursor() diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index aef2e7f3ba2..b6414f1434e 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -1,8 +1,6 @@ -from re import Pattern from typing import Callable, List from unittest.mock import patch -from httpx import codes from pyfakefs.fake_filesystem_unittest import Patcher from pytest import mark, raises from pytest_httpx import HTTPXMock @@ -12,13 +10,11 @@ from firebolt.common._types import ColType from firebolt.common.settings import Settings from firebolt.utils.exception import ( - AccountNotFoundError, ConfigurationError, ConnectionClosedError, - FireboltEngineError, + InterfaceError, ) from firebolt.utils.token_storage import TokenSecureStorage -from firebolt.utils.urls import ACCOUNT_ENGINE_ID_BY_NAME_URL async def test_closed_connection(connection: Connection) -> None: @@ -52,151 +48,121 @@ async def test_cursors_closed_on_close(connection: Connection) -> None: async def test_cursor_initialized( settings: Settings, - db_name: str, - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - query_url: str, + mock_query: Callable, + connection: Connection, python_query_data: List[List[ColType]], ) -> None: """Connection initialized its cursors properly.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() - for url in (settings.server, f"https://{settings.server}"): - async with ( - await connect( - engine_url=url, - database=db_name, - auth=ClientCredentials("cid", "cs"), - api_endpoint=settings.server, - ) - ) as connection: - cursor = connection.cursor() - assert ( - cursor.connection == connection - ), "Invalid cursor connection attribute." - assert ( - cursor._client == connection._client - ), "Invalid cursor _client attribute" + cursor = connection.cursor() + assert cursor.connection == connection, "Invalid cursor connection attribute." + assert cursor._client == connection._client, "Invalid cursor _client attribute" - assert await cursor.execute("select*") == len(python_query_data) + assert await cursor.execute("select*") == len(python_query_data) - cursor.close() - assert ( - cursor not in connection._cursors - ), "Cursor wasn't removed from connection after close." + cursor.close() + assert ( + cursor not in connection._cursors + ), "Cursor wasn't removed from connection after close." async def test_connect_empty_parameters(): with raises(ConfigurationError): - async with await connect(engine_url="engine_url"): + async with await connect(engine_name="engine_name"): pass async def test_connect_engine_name( - settings: Settings, db_name: str, + account_name: str, + engine_name: str, + auth: Auth, + server: str, + python_query_data: List[List[ColType]], httpx_mock: HTTPXMock, - auth_callback: Callable, + mock_query: Callable, + system_engine_query_url: str, + get_engine_url_not_running_callback: Callable, + get_engine_url_invalid_db_callback: Callable, auth_url: str, - query_callback: Callable, - query_url: str, - account_id_url: Pattern, - account_id_callback: Callable, - engine_id: str, - get_engine_url_by_id_url: str, - get_engine_url_by_id_callback: Callable, - python_query_data: List[List[ColType]], - account_id: str, + check_credentials_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, + get_engine_url_callback: Callable, ): """connect properly handles engine_name""" - with raises(ConfigurationError): - async with await connect( - engine_url="engine_url", - engine_name="engine_name", - database="db", - auth=ClientCredentials("cid", "cs"), - account_name="account_name", - ): - pass - - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(get_engine_url_by_id_callback, url=get_engine_url_by_id_url) - - engine_name = settings.server.split(".")[0] + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) - # Mock engine id lookup error - httpx_mock.add_response( - url=f"https://{settings.server}" - + ACCOUNT_ENGINE_ID_BY_NAME_URL.format(account_id=account_id) - + f"?engine_name={engine_name}", - status_code=codes.NOT_FOUND, - ) + mock_query() - with raises(FireboltEngineError): - async with await connect( - database="db", - auth=ClientCredentials("cid", "cs"), - engine_name=engine_name, - account_name=settings.account_name, - api_endpoint=settings.server, - ): - pass + for callback in ( + get_engine_url_invalid_db_callback, + get_engine_url_not_running_callback, + ): + httpx_mock.add_callback(callback, url=system_engine_query_url) + with raises(InterfaceError): + async with await connect( + database=db_name, + auth=auth, + engine_name=engine_name, + account_name=account_name, + api_endpoint=server, + ): + pass - # Mock engine id lookup by name - httpx_mock.add_response( - url=f"https://{settings.server}" - + ACCOUNT_ENGINE_ID_BY_NAME_URL.format(account_id=account_id) - + f"?engine_name={engine_name}", - status_code=codes.OK, - json={"engine_id": {"engine_id": engine_id}}, - ) + httpx_mock.add_callback(get_engine_url_callback, url=system_engine_query_url) async with await connect( engine_name=engine_name, database=db_name, - auth=ClientCredentials("cid", "cs"), - account_name=settings.account_name, - api_endpoint=settings.server, + auth=auth, + account_name=account_name, + api_endpoint=server, ) as connection: assert await connection.cursor().execute("select*") == len(python_query_data) async def test_connect_default_engine( - settings: Settings, db_name: str, - httpx_mock: HTTPXMock, - auth_callback: Callable, auth_url: str, - query_callback: Callable, - query_url: str, - account_id_url: Pattern, - account_id_callback: Callable, - engine_by_db_url: str, + server: str, + auth: Auth, + account_name: str, python_query_data: List[List[ColType]], + mock_query: Callable, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, + system_engine_query_url: str, + get_default_db_engine_callback: Callable, + get_default_db_engine_not_running_callback: Callable, + get_system_engine_url: str, + get_system_engine_callback: Callable, ): - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - engine_by_db_url = f"{engine_by_db_url}?database_name={db_name}" - - httpx_mock.add_response( - url=engine_by_db_url, - status_code=codes.OK, - json={ - "engine_url": settings.server, - }, + mock_query() + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) + httpx_mock.add_callback( + get_default_db_engine_not_running_callback, url=system_engine_query_url ) + with raises(InterfaceError): + async with await connect( + database=db_name, + auth=auth, + account_name=account_name, + api_endpoint=server, + ) as connection: + await connection.cursor().execute("select*") + + httpx_mock.add_callback(get_default_db_engine_callback, url=system_engine_query_url) + async with await connect( database=db_name, - auth=ClientCredentials("cid", "cs"), - account_name=settings.account_name, - api_endpoint=settings.server, + auth=auth, + account_name=account_name, + api_endpoint=server, ) as connection: assert await connection.cursor().execute("select*") == len(python_query_data) @@ -212,28 +178,27 @@ async def test_connection_commit(connection: Connection): @mark.nofakefs async def test_connection_token_caching( - settings: Settings, db_name: str, - httpx_mock: HTTPXMock, - check_credentials_callback: Callable, - auth_url: str, - query_callback: Callable, - query_url: str, - python_query_data: List[List[ColType]], + server: str, access_token: str, client_id: str, client_secret: str, + engine_name: str, + account_name: str, + python_query_data: List[List[ColType]], + mock_connection_flow: Callable, + mock_query: Callable, ) -> None: - httpx_mock.add_callback(check_credentials_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_connection_flow() + mock_query() with Patcher(): async with await connect( database=db_name, auth=ClientCredentials(client_id, client_secret, use_token_cache=True), - engine_url=settings.server, - account_name=settings.account_name, - api_endpoint=settings.server, + engine_name=engine_name, + account_name=account_name, + api_endpoint=server, ) as connection: assert await connection.cursor().execute("select*") == len( python_query_data @@ -246,9 +211,9 @@ async def test_connection_token_caching( async with await connect( database=db_name, auth=ClientCredentials(client_id, client_secret, use_token_cache=False), - engine_url=settings.server, - account_name=settings.account_name, - api_endpoint=settings.server, + engine_name=engine_name, + account_name=account_name, + api_endpoint=server, ) as connection: assert await connection.cursor().execute("select*") == len( python_query_data @@ -259,76 +224,21 @@ async def test_connection_token_caching( ), "Token is cached even though caching is disabled" -async def test_connect_with_auth( - httpx_mock: HTTPXMock, - settings: Settings, +async def test_connect_with_user_agent( + engine_name: str, + account_name: str, + server: str, db_name: str, - check_credentials_callback: Callable, - auth_url: str, - query_callback: Callable, - query_url: str, - access_token: str, - auth: Auth, -) -> None: - httpx_mock.add_callback(check_credentials_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) - - async with await connect( - auth=auth, - database=db_name, - engine_url=settings.server, - account_name=settings.account_name, - api_endpoint=settings.server, - ) as connection: - await connection.cursor().execute("select*") - - -async def test_connect_account_name( - httpx_mock: HTTPXMock, auth: Auth, - settings: Settings, - db_name: str, - auth_url: str, - check_credentials_callback: Callable, - account_id_url: Pattern, - account_id_callback: Callable, -): - httpx_mock.add_callback(check_credentials_callback, url=auth_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - - with raises(AccountNotFoundError): - async with await connect( - auth=auth, - database=db_name, - engine_url=settings.server, - account_name="invalid", - api_endpoint=settings.server, - ): - pass - - async with await connect( - auth=auth, - database=db_name, - engine_url=settings.server, - account_name=settings.account_name, - api_endpoint=settings.server, - ): - pass - - -async def test_connect_with_user_agent( + access_token: str, httpx_mock: HTTPXMock, - settings: Settings, - db_name: str, query_callback: Callable, query_url: str, - auth_callback: Callable, - auth_url: str, - access_token: str, + mock_connection_flow: Callable, ) -> None: with patch("firebolt.async_db.connection.get_user_agent_header") as ut: ut.return_value = "MyConnector/1.0 DriverA/1.1" - httpx_mock.add_callback(auth_callback, url=auth_url) + mock_connection_flow() httpx_mock.add_callback( query_callback, url=query_url, @@ -336,43 +246,45 @@ async def test_connect_with_user_agent( ) async with await connect( - auth=ClientCredentials("cid", "cs"), + auth=auth, database=db_name, - engine_url=settings.server, - account_name=settings.account_name, - api_endpoint=settings.server, + engine_name=engine_name, + account_name=account_name, + api_endpoint=server, additional_parameters={ "user_clients": [("MyConnector", "1.0")], "user_drivers": [("DriverA", "1.1")], }, ) as connection: await connection.cursor().execute("select*") - ut.assert_called_once_with([("DriverA", "1.1")], [("MyConnector", "1.0")]) + ut.assert_called_with([("DriverA", "1.1")], [("MyConnector", "1.0")]) async def test_connect_no_user_agent( - httpx_mock: HTTPXMock, - settings: Settings, + engine_name: str, + account_name: str, + server: str, db_name: str, + auth: Auth, + access_token: str, + httpx_mock: HTTPXMock, query_callback: Callable, query_url: str, - auth_callback: Callable, - auth_url: str, - access_token: str, + mock_connection_flow: Callable, ) -> None: with patch("firebolt.async_db.connection.get_user_agent_header") as ut: ut.return_value = "Python/3.0" - httpx_mock.add_callback(auth_callback, url=auth_url) + mock_connection_flow() httpx_mock.add_callback( query_callback, url=query_url, match_headers={"User-Agent": "Python/3.0"} ) async with await connect( - auth=ClientCredentials("cid", "cs"), + auth=auth, database=db_name, - engine_url=settings.server, - account_name=settings.account_name, - api_endpoint=settings.server, + engine_name=engine_name, + account_name=account_name, + api_endpoint=server, ) as connection: await connection.cursor().execute("select*") - ut.assert_called_once_with([], []) + ut.assert_called_with([], []) diff --git a/tests/unit/async_db/test_cursor.py b/tests/unit/async_db/test_cursor.py index 64f13d4318f..aa6194eb695 100644 --- a/tests/unit/async_db/test_cursor.py +++ b/tests/unit/async_db/test_cursor.py @@ -22,15 +22,12 @@ async def test_cursor_state( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, + mock_query: Callable, query_url: str, cursor: Cursor, ): """Cursor state changes depend on the operations performed with it.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() assert cursor._state == CursorState.NONE @@ -95,11 +92,7 @@ async def test_closed_cursor(cursor: Cursor): async def test_cursor_no_query( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - query_url: str, - server_side_async_id: str, + mock_query: Callable, cursor: Cursor, ): """Some cursor methods are unavailable until a query is run.""" @@ -109,8 +102,7 @@ async def test_cursor_no_query( "fetchall", ) - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() for amethod in async_methods: with raises(QueryNotRunError): @@ -145,12 +137,8 @@ async def test_cursor_no_query( async def test_cursor_execute( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - insert_query_callback: Callable, - query_url: str, + mock_query: Callable, + mock_insert_query: Callable, cursor: Cursor, python_query_description: List[Column], python_query_data: List[List[ColType]], @@ -167,8 +155,7 @@ async def test_cursor_execute( ), ): # Query with json output - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() assert await query() == len( python_query_data ), f"Invalid row count returned for {message}." @@ -189,11 +176,8 @@ async def test_cursor_execute( await cursor.fetchone() is None ), f"Non-empty fetchone after all data received {message}." - httpx_mock.reset(True) - # Query with empty output - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(insert_query_callback, url=query_url) + mock_insert_query() assert await query() == -1, f"Invalid row count for insert using {message}." assert ( cursor.rowcount == -1 @@ -205,11 +189,7 @@ async def test_cursor_execute( async def test_cursor_execute_error( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, query_url: str, - get_engines_url: str, - get_databases_url: str, cursor: Cursor, ): """Cursor handles all types of errors properly.""" @@ -223,8 +203,6 @@ async def test_cursor_execute_error( "server-side synchronous executemany()", ), ): - httpx_mock.add_callback(auth_callback, url=auth_url) - # Internal httpx error def http_error(*args, **kwargs): raise StreamError("httpx error") @@ -297,8 +275,6 @@ def http_error(*args, **kwargs): async def test_cursor_server_side_async_execute_errors( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, query_with_params_url: str, server_side_async_missing_id_callback: Callable, insert_query_callback: str, @@ -318,8 +294,6 @@ async def test_cursor_server_side_async_execute_errors( "server-side asynchronous executemany()", ), ): - # Empty server-side asynchronous execution return. - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(insert_query_callback, url=query_with_params_url) with raises(OperationalError) as excinfo: await query("SELECT * FROM t") @@ -328,7 +302,6 @@ async def test_cursor_server_side_async_execute_errors( assert str(excinfo.value) == ("No response to asynchronous query.") # Missing query_id from server-side asynchronous execution. - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_missing_id_callback, url=query_with_params_url ) @@ -341,7 +314,6 @@ async def test_cursor_server_side_async_execute_errors( ) # Multi-statement queries are not possible with async_execution error. - httpx_mock.add_callback(auth_callback, url=auth_url) with raises(AsyncExecutionUnavailableError) as excinfo: await query("SELECT * FROM t; SELECT * FROM s") @@ -374,23 +346,17 @@ async def test_cursor_server_side_async_execute_errors( ), f"use_standard_sql=0 was allowed for server-side asynchronous queries on {message}." # Have to reauth or next execute fails. Not sure why. - httpx_mock.add_callback(auth_callback, url=auth_url) await cursor.execute("set use_standard_sql=1") httpx_mock.reset(True) async def test_cursor_fetchone( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - insert_query_callback: Callable, - query_url: str, + mock_query: Callable, + mock_insert_query: Callable, cursor: Cursor, ): """cursor fetchone fetches single row in correct order. If no rows, returns None.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() await cursor.execute("sql") @@ -405,27 +371,22 @@ async def test_cursor_fetchone( await cursor.fetchone() is None ), "fetchone should return None when no rows left to fetch." - httpx_mock.add_callback(insert_query_callback, url=query_url) + mock_insert_query() await cursor.execute("sql") with raises(DataError): await cursor.fetchone() async def test_cursor_fetchmany( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - insert_query_callback: Callable, - query_url: str, + mock_query: Callable, + mock_insert_query: Callable, cursor: Cursor, ): """ Cursor's fetchmany fetches the provided amount of rows, or arraysize by default. If not enough rows left, returns less or None if there are no rows. """ - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() await cursor.execute("sql") @@ -466,24 +427,19 @@ async def test_cursor_fetchmany( len(await cursor.fetchmany()) == 0 ), "fetchmany should return empty result set when no rows remain to fetch" - httpx_mock.add_callback(insert_query_callback, url=query_url) + mock_insert_query() await cursor.execute("sql") with raises(DataError): await cursor.fetchmany() async def test_cursor_fetchall( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - insert_query_callback: Callable, - query_url: str, + mock_query: Callable, + mock_insert_query: Callable, cursor: Cursor, ): """cursor fetchall fetches all rows remaining after last query.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() await cursor.execute("sql") @@ -500,28 +456,23 @@ async def test_cursor_fetchall( len(await cursor.fetchall()) == 0 ), "fetchmany should return empty result set when no rows remain to fetch" - httpx_mock.add_callback(insert_query_callback, url=query_url) + mock_insert_query() await cursor.execute("sql") with raises(DataError): await cursor.fetchall() async def test_cursor_multi_statement( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_callback: Callable, - insert_query_callback: Callable, - query_url: str, + mock_query: Callable, + mock_insert_query: Callable, cursor: Cursor, python_query_description: List[Column], python_query_data: List[List[ColType]], ): """executemany with multiple parameter sets is not supported.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) - httpx_mock.add_callback(insert_query_callback, url=query_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() + mock_insert_query() + mock_query() rc = await cursor.execute( "select * from t; insert into t values (1, 2); select * from t" @@ -570,14 +521,11 @@ async def test_cursor_multi_statement( async def test_cursor_set_statements( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, select_one_query_callback: Callable, set_query_url: str, cursor: Cursor, ): """cursor correctly parses and processes set statements.""" - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(select_one_query_callback, url=f"{set_query_url}&a=b") assert len(cursor._set_parameters) == 0 @@ -639,8 +587,6 @@ async def test_cursor_set_statements( async def test_cursor_set_parameters_sent( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, set_query_url: str, query_url: str, query_with_params_callback: Callable, @@ -649,8 +595,6 @@ async def test_cursor_set_parameters_sent( set_params: Dict, ): """Cursor passes provided set parameters to engine.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - params = "" for p, v in set_params.items(): @@ -666,16 +610,11 @@ async def test_cursor_set_parameters_sent( async def test_cursor_skip_parse( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - query_url: str, - query_callback: Callable, + mock_query: Callable, cursor: Cursor, ): """Cursor doesn't process a query if skip_parsing is provided.""" - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(query_callback, url=query_url) + mock_query() with patch("firebolt.async_db.cursor.split_format_sql") as split_format_sql_mock: await cursor.execute("non-an-actual-sql") @@ -688,8 +627,6 @@ async def test_cursor_skip_parse( async def test_cursor_server_side_async_execute( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, server_side_async_id_callback: Callable, server_side_async_id: Callable, query_with_params_url: str, @@ -712,7 +649,6 @@ async def test_cursor_server_side_async_execute( ), ): # Query with json output - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_id_callback, url=query_with_params_url ) @@ -730,8 +666,6 @@ async def test_cursor_server_side_async_execute( async def test_cursor_server_side_async_cancel( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, server_side_async_cancel_callback: Callable, server_side_async_id: Callable, query_with_params_url: str, @@ -743,7 +677,6 @@ async def test_cursor_server_side_async_cancel( """ # Query with json output - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_cancel_callback, url=query_with_params_url ) @@ -752,8 +685,6 @@ async def test_cursor_server_side_async_cancel( async def test_cursor_server_side_async_get_status_completed( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, server_side_async_get_status_callback: Callable, server_side_async_id: Callable, query_with_params_url: str, @@ -765,7 +696,6 @@ async def test_cursor_server_side_async_get_status_completed( """ # Query with json output - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_get_status_callback, url=query_with_params_url ) @@ -775,8 +705,6 @@ async def test_cursor_server_side_async_get_status_completed( async def test_cursor_server_side_async_get_status_not_yet_available( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, server_side_async_get_status_not_yet_availabe_callback: Callable, server_side_async_id: Callable, query_with_params_url: str, @@ -788,7 +716,6 @@ async def test_cursor_server_side_async_get_status_not_yet_available( """ # Query with json output - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_get_status_not_yet_availabe_callback, url=query_with_params_url, @@ -799,15 +726,12 @@ async def test_cursor_server_side_async_get_status_not_yet_available( async def test_cursor_server_side_async_get_status_error( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, server_side_async_get_status_error: Callable, server_side_async_id: Callable, query_with_params_url: str, cursor: Cursor, ): """ """ - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( server_side_async_get_status_error, url=query_with_params_url ) diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index 626499bdef5..59368424391 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -1,6 +1,7 @@ import gc import warnings from typing import Callable, List +from unittest.mock import patch from pyfakefs.fake_filesystem_unittest import Patcher from pytest import mark, raises, warns @@ -48,36 +49,23 @@ def test_cursors_closed_on_close(connection: Connection) -> None: def test_cursor_initialized( settings: Settings, - db_name: str, - mock_connection_flow: Callable, mock_query: Callable, - account_name: str, + connection: Connection, python_query_data: List[List[ColType]], - auth: Auth, - engine_name: str, ) -> None: """Connection initialized its cursors properly.""" - mock_connection_flow() mock_query() - with connect( - engine_name=engine_name, - database=db_name, - api_endpoint=settings.server, - auth=auth, - account_name=account_name, - ) as connection: - - cursor = connection.cursor() - assert cursor.connection == connection, "Invalid cursor connection attribute" - assert cursor._client == connection._client, "Invalid cursor _client attribute" + cursor = connection.cursor() + assert cursor.connection == connection, "Invalid cursor connection attribute" + assert cursor._client == connection._client, "Invalid cursor _client attribute" - assert cursor.execute("select*") == len(python_query_data) + assert cursor.execute("select*") == len(python_query_data) - cursor.close() - assert ( - cursor not in connection._cursors - ), "Cursor wasn't removed from connection after close" + cursor.close() + assert ( + cursor not in connection._cursors + ), "Cursor wasn't removed from connection after close" def test_connect_empty_parameters(): @@ -87,15 +75,14 @@ def test_connect_empty_parameters(): def test_connect_engine_name( - settings: Settings, db_name: str, - httpx_mock: HTTPXMock, - mock_connection_flow: Callable, - mock_query: Callable, account_name: str, engine_name: str, - python_query_data: List[List[ColType]], auth: Auth, + server: str, + python_query_data: List[List[ColType]], + mock_query: Callable, + httpx_mock: HTTPXMock, system_engine_query_url: str, get_engine_url_not_running_callback: Callable, get_engine_url_invalid_db_callback: Callable, @@ -118,11 +105,11 @@ def test_connect_engine_name( httpx_mock.add_callback(callback, url=system_engine_query_url) with raises(InterfaceError): connect( - database="db", + database=db_name, auth=auth, engine_name=engine_name, account_name=account_name, - api_endpoint=settings.server, + api_endpoint=server, ) httpx_mock.add_callback(get_engine_url_callback, url=system_engine_query_url) @@ -132,23 +119,21 @@ def test_connect_engine_name( database=db_name, auth=auth, account_name=account_name, - api_endpoint=settings.server, + api_endpoint=server, ) as connection: assert connection.cursor().execute("select*") == len(python_query_data) def test_connect_default_engine( - settings: Settings, db_name: str, - mock_query: Callable, - httpx_mock: HTTPXMock, auth_url: str, - check_credentials_callback: Callable, - database_id: str, - python_query_data: List[List[ColType]], - account_id: str, + server: str, auth: Auth, account_name: str, + python_query_data: List[List[ColType]], + mock_query: Callable, + httpx_mock: HTTPXMock, + check_credentials_callback: Callable, system_engine_query_url: str, get_default_db_engine_callback: Callable, get_default_db_engine_not_running_callback: Callable, @@ -166,7 +151,7 @@ def test_connect_default_engine( database=db_name, auth=auth, account_name=account_name, - api_endpoint=settings.server, + api_endpoint=server, ) as connection: connection.cursor().execute("select*") @@ -176,7 +161,7 @@ def test_connect_default_engine( database=db_name, auth=auth, account_name=account_name, - api_endpoint=settings.server, + api_endpoint=server, ) as connection: assert connection.cursor().execute("select*") == len(python_query_data) @@ -212,16 +197,16 @@ def test_connection_commit(connection: Connection): @mark.nofakefs def test_connection_token_caching( - settings: Settings, db_name: str, - mock_connection_flow: Callable, - mock_query: Callable, - python_query_data: List[List[ColType]], + server: str, access_token: str, client_id: str, client_secret: str, engine_name: str, account_name: str, + python_query_data: List[List[ColType]], + mock_connection_flow: Callable, + mock_query: Callable, ) -> None: mock_connection_flow() mock_query() @@ -232,7 +217,7 @@ def test_connection_token_caching( auth=ClientCredentials(client_id, client_secret, use_token_cache=True), engine_name=engine_name, account_name=account_name, - api_endpoint=settings.server, + api_endpoint=server, ) as connection: assert connection.cursor().execute("select*") == len(python_query_data) ts = TokenSecureStorage(username=client_id, password=client_secret) @@ -245,10 +230,76 @@ def test_connection_token_caching( auth=ClientCredentials(client_id, client_secret, use_token_cache=False), engine_name=engine_name, account_name=account_name, - api_endpoint=settings.server, + api_endpoint=server, ) as connection: assert connection.cursor().execute("select*") == len(python_query_data) ts = TokenSecureStorage(username=client_id, password=client_secret) assert ( ts.get_cached_token() is None ), "Token is cached even though caching is disabled" + + +def test_connect_with_user_agent( + engine_name: str, + account_name: str, + server: str, + db_name: str, + auth: Auth, + access_token: str, + httpx_mock: HTTPXMock, + query_callback: Callable, + query_url: str, + mock_connection_flow: Callable, +) -> None: + with patch("firebolt.async_db.connection.get_user_agent_header") as ut: + ut.return_value = "MyConnector/1.0 DriverA/1.1" + mock_connection_flow() + httpx_mock.add_callback( + query_callback, + url=query_url, + match_headers={"User-Agent": "MyConnector/1.0 DriverA/1.1"}, + ) + + with connect( + auth=auth, + database=db_name, + engine_name=engine_name, + account_name=account_name, + api_endpoint=server, + additional_parameters={ + "user_clients": [("MyConnector", "1.0")], + "user_drivers": [("DriverA", "1.1")], + }, + ) as connection: + connection.cursor().execute("select*") + ut.assert_called_with([("DriverA", "1.1")], [("MyConnector", "1.0")]) + + +def test_connect_no_user_agent( + engine_name: str, + account_name: str, + server: str, + db_name: str, + auth: Auth, + access_token: str, + httpx_mock: HTTPXMock, + query_callback: Callable, + query_url: str, + mock_connection_flow: Callable, +) -> None: + with patch("firebolt.async_db.connection.get_user_agent_header") as ut: + ut.return_value = "Python/3.0" + mock_connection_flow() + httpx_mock.add_callback( + query_callback, url=query_url, match_headers={"User-Agent": "Python/3.0"} + ) + + with connect( + auth=auth, + database=db_name, + engine_name=engine_name, + account_name=account_name, + api_endpoint=server, + ) as connection: + connection.cursor().execute("select*") + ut.assert_called_with([], []) diff --git a/tests/unit/db/test_cursor.py b/tests/unit/db/test_cursor.py index 8b4959cfc55..04f9082012b 100644 --- a/tests/unit/db/test_cursor.py +++ b/tests/unit/db/test_cursor.py @@ -393,12 +393,9 @@ def test_cursor_multi_statement( def test_cursor_set_statements( httpx_mock: HTTPXMock, - query_callback: Callable, select_one_query_callback: Callable, set_query_url: str, cursor: Cursor, - python_query_description: List[Column], - python_query_data: List[List[ColType]], ): """cursor correctly parses and processes set statements.""" httpx_mock.add_callback(select_one_query_callback, url=f"{set_query_url}&a=b") @@ -470,7 +467,6 @@ def test_cursor_set_parameters_sent( set_params: Dict, ): """Cursor passes provided set parameters to engine.""" - params = "" for p, v in set_params.items(): From f2cc825eb8f1eba7c475290cd5c77dee31633522 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Tue, 21 Mar 2023 16:13:16 +0200 Subject: [PATCH 11/28] fix engine get connection function --- src/firebolt/model/engine.py | 4 ++-- tests/unit/conftest.py | 4 ++-- tests/unit/service/conftest.py | 4 ++-- tests/unit/service/test_engine.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/firebolt/model/engine.py b/src/firebolt/model/engine.py index 8eab5e8e35b..a484ed9d919 100644 --- a/src/firebolt/model/engine.py +++ b/src/firebolt/model/engine.py @@ -194,8 +194,8 @@ def get_connection(self) -> Connection: """ return connect( database=self.database.name, # type: ignore # already checked by decorator - auth=self._service.client.auth, # type: ignore - engine_url=self.endpoint, + auth=self._service.client.auth, + engine_name=self.name, account_name=self._service.settings.account_name, api_endpoint=self._service.settings.server, ) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 89c1eb9fcff..3657daf1115 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -135,12 +135,12 @@ def auth(client_id: str, client_secret: str) -> Auth: @fixture -def settings(server: str, region_1: str, auth: Auth) -> Settings: +def settings(server: str, region_1: str, auth: Auth, account_name: str) -> Settings: return Settings( server=server, auth=auth, default_region=region_1.name, - account_name=None, + account_name=account_name, ) diff --git a/tests/unit/service/conftest.py b/tests/unit/service/conftest.py index 2043f8a4dcf..4cb7a7789a2 100644 --- a/tests/unit/service/conftest.py +++ b/tests/unit/service/conftest.py @@ -290,9 +290,9 @@ def account_engine_url(settings: Settings, account_id, mock_engine) -> str: @fixture -def mock_database(region_1, account_id) -> Database: +def mock_database(region_1: str, account_id: str) -> Database: return Database( - name="mock_db_name", + name="database", description="mock_db_description", compute_region_key=region_1.key, database_key=DatabaseKey( diff --git a/tests/unit/service/test_engine.py b/tests/unit/service/test_engine.py index 18b93af5c44..9686ecbf7ff 100644 --- a/tests/unit/service/test_engine.py +++ b/tests/unit/service/test_engine.py @@ -288,19 +288,19 @@ def test_get_connection( database_url: str, bindings_callback: Callable, bindings_url: str, + mock_connection_flow: Callable, ): - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(provider_callback, url=provider_url) httpx_mock.add_callback( instance_type_region_1_callback, url=instance_type_region_1_url ) httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(region_callback, url=region_url) httpx_mock.add_callback(engine_callback, url=engine_url, method="POST") httpx_mock.add_callback(bindings_callback, url=bindings_url) httpx_mock.add_callback(database_callback, url=database_url) + mock_connection_flow() manager = ResourceManager(settings=settings) engine = manager.engines.create(name=engine_name) From 21b2b89735c35d33a71c2ad3eb9fe0ab88372eba Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Tue, 21 Mar 2023 16:40:23 +0200 Subject: [PATCH 12/28] fix nested loops in trio --- src/firebolt/utils/util.py | 47 +++++++++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/src/firebolt/utils/util.py b/src/firebolt/utils/util.py index 9076f1bbc26..4d3457bb48d 100644 --- a/src/firebolt/utils/util.py +++ b/src/firebolt/utils/util.py @@ -1,5 +1,6 @@ +from contextlib import contextmanager from functools import lru_cache, partial, wraps -from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Generator, Type, TypeVar import trio from httpx import URL @@ -86,6 +87,38 @@ def get_auth_endpoint(api_endpoint: URL) -> URL: ) +@contextmanager +def nested_loop() -> Generator: + from trio._core._run import GLOBAL_RUN_CONTEXT # type: ignore + + s = object() + task, runner, _dict = s, s, s + if hasattr(GLOBAL_RUN_CONTEXT, "__dict__"): + _dict = GLOBAL_RUN_CONTEXT.__dict__ + if hasattr(GLOBAL_RUN_CONTEXT, "task"): + task = GLOBAL_RUN_CONTEXT.task + del GLOBAL_RUN_CONTEXT.task + if hasattr(GLOBAL_RUN_CONTEXT, "runner"): + runner = GLOBAL_RUN_CONTEXT.runner + del GLOBAL_RUN_CONTEXT.runner + + try: + yield + finally: + if task is not s: + GLOBAL_RUN_CONTEXT.task = task + elif hasattr(GLOBAL_RUN_CONTEXT, "task"): + del GLOBAL_RUN_CONTEXT.task + + if runner is not s: + GLOBAL_RUN_CONTEXT.runner = runner + elif hasattr(GLOBAL_RUN_CONTEXT, "runner"): + del GLOBAL_RUN_CONTEXT.runner + + if _dict is not s: + GLOBAL_RUN_CONTEXT.__dict__.update(_dict) + + def async_to_sync(f: Callable) -> Callable: """Convert async function to sync. @@ -98,7 +131,8 @@ def async_to_sync(f: Callable) -> Callable: @wraps(f) def sync(*args: Any, **kwargs: Any) -> Any: - return trio.run(partial(f, *args, **kwargs)) + with nested_loop(): + return trio.run(partial(f, *args, **kwargs)) return sync @@ -119,12 +153,3 @@ def merge_urls(base: URL, merge: URL) -> URL: merge_raw_path = base.raw_path + merge.raw_path.lstrip(b"/") return base.copy_with(raw_path=merge_raw_path) return merge - - -def validate_engine_name_and_url( - engine_name: Optional[str], engine_url: Optional[str] -) -> None: - if engine_name and engine_url: - raise ConfigurationError( - "Both engine_name and engine_url are provided. Provide only one to connect." - ) From af858a42a947a536cad33f97644e0850b3a46429 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 22 Mar 2023 11:40:27 +0200 Subject: [PATCH 13/28] update connection logic, fix engine running check --- src/firebolt/async_db/connection.py | 149 +++++++++++----------------- src/firebolt/async_db/cursor.py | 10 +- src/firebolt/async_db/util.py | 89 ++++++++++------- 3 files changed, 114 insertions(+), 134 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 34ff5b34c10..680dadf3d0f 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -3,13 +3,18 @@ import logging import socket from types import TracebackType -from typing import Any, Callable, Dict, List, Optional, Tuple, Type +from typing import Any, Callable, Dict, List, Optional, Type from httpcore.backends.auto import AutoBackend from httpcore.backends.base import AsyncNetworkStream -from httpx import AsyncHTTPTransport, Timeout, codes +from httpx import AsyncHTTPTransport, Timeout from firebolt.async_db.cursor import Cursor +from firebolt.async_db.util import ( + DEFAULT_TIMEOUT_SECONDS, + _get_engine_url_status_db, + _get_system_engine_url, +) from firebolt.client import DEFAULT_API_URL, AsyncClient from firebolt.client.auth import Auth from firebolt.common.base_connection import BaseConnection @@ -23,80 +28,13 @@ ConnectionClosedError, InterfaceError, ) -from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME from firebolt.utils.usage_tracker import get_user_agent_header from firebolt.utils.util import fix_url_schema + logger = logging.getLogger(__name__) -async def _get_system_engine_url( - auth: Auth, - account_name: str, - api_endpoint: str, -) -> str: - async with AsyncClient( - auth=auth, - base_url=api_endpoint, - account_name=account_name, - api_endpoint=api_endpoint, - timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), - ) as client: - url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name) - response = await client.get(url=url) - if response.status_code != codes.OK: - raise InterfaceError( - f"Unable to retrieve system engine endpoint {url}: " - f"{response.status} {response.content}" - ) - return response.json()["gatewayHost"] - - -async def _get_database_default_engine_url( - system_engine: Connection, database_name: str -) -> str: - cursor = system_engine.cursor() - await cursor.execute( - """ - SELECT engs.engine_url, engs.status - FROM information_schema.databases AS dbs - INNER JOIN information_schema.engines AS engs - ON engs.attached_to = dbs.database_name - AND engs.engine_name = NULLIF(SPLIT_PART(ARRAY_FIRST( - eng_name -> eng_name LIKE '%(default)', - SPLIT(',', attached_engines) - ), ' ', 1), '') - WHERE database_name = ?; - """, - [database_name], - ) - row = await cursor.fetchone() - if row is None: - raise InterfaceError(f"Database {database_name} doesn't have a default engine") - engine_url, status = row - if status != "Running": - raise InterfaceError(f"A default engine for {database_name} is not running") - return str(engine_url) # Mypy check - -async def _get_engine_url_and_db( - system_engine: Connection, engine_name: str -) -> Tuple[str, str]: - cursor = system_engine.cursor() - await cursor.execute( - """ - SELECT engine_url, attached_to, status FROM information_schema.engines - WHERE engine_name=? - """, - [engine_name], - ) - row = await cursor.fetchone() - if row is None: - raise InterfaceError(f"Engine with name {engine_name} doesn't exist") - engine_url, database, status = row - if status != "Running": - raise InterfaceError(f"Engine {engine_name} is not running") - return str(engine_url), str(database) # Mypy check - async def connect( auth: Optional[Auth] = None, account_name: Optional[str] = None, @@ -116,7 +54,6 @@ async def connect( `api_endpoint` (str): Firebolt API endpoint. Used for authentication `additional_parameters` (Optional[Dict]): Dictionary of less widely-used arguments for connection - """ # These parameters are optional in function signature # but are required to connect. @@ -137,35 +74,55 @@ async def connect( system_engine_url, None, auth, api_endpoint, additional_parameters ) + if not engine_name: + # Return system engine connection + return connection_class( + system_engine_url, + database, + auth, + api_endpoint, + None, + additional_parameters, + ) + else: async with Connection( - system_engine_url, None, auth, api_endpoint, additional_parameters + system_engine_url, + database, + auth, + api_endpoint, + None, + additional_parameters, ) as system_engine_connection: if engine_name: - engine_url, attached_db = await _get_engine_url_and_db( + engine_url, status, attached_db = await _get_engine_url_status_db( system_engine_connection, engine_name ) + elif database is None: + database = attached_db - if database is not None and database != attached_db: - raise InterfaceError( - f"Engine {engine_name} is not attached to {database}, " - f"but to {attached_db}" - ) - elif database is None: - database = attached_db - - elif database: - # Get database default engine - engine_url = await _get_database_default_engine_url( - system_engine_connection, database - ) - - assert engine_url is not None + if status != "Running": + raise InterfaceError(f"Engine {engine_name} is not running") - engine_url = fix_url_schema(engine_url) - return Connection( - engine_url, database, auth, api_endpoint, additional_parameters - ) + if database is not None and database != attached_db: + raise InterfaceError( + f"Engine {engine_name} is not attached to {database}, " + f"but to {attached_db}" + ) + elif database is None: + database = attached_db + + assert engine_url is not None + + engine_url = fix_url_schema(engine_url) + return Connection( + engine_url, + database, + auth, + api_endpoint, + system_engine_connection, + additional_parameters, + ) class OverriddenHttpBackend(AutoBackend): @@ -240,6 +197,7 @@ class Connection(BaseConnection): "engine_url", "api_endpoint", "_is_closed", + "_system_engine_connection", ) def __init__( @@ -248,6 +206,7 @@ def __init__( database: str, auth: Auth, api_endpoint: str = DEFAULT_API_URL, + system_engine_connection: Optional["Connection"], additional_parameters: Dict[str, Any] = {}, ): self.api_endpoint = api_endpoint @@ -267,8 +226,14 @@ def __init__( transport=transport, headers={"User-Agent": get_user_agent_header(user_drivers, user_clients)}, ) + self._system_engine_connection = system_engine_connection super().__init__() + @property + def _is_system(self) -> bool: + """`True` if connection is a system engine connection; `False` otherwise.""" + return self._system_engine_connection is not None + def cursor(self, **kwargs: Any) -> Cursor: if self.closed: raise ConnectionClosedError("Unable to create cursor: connection closed.") diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index f334c9d35c9..7a24422cb90 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -29,6 +29,9 @@ SetParameter, split_format_sql, ) + +from firebolt.async_db.util import is_engine_running +from firebolt.client import AsyncClient from firebolt.common.base_cursor import ( BaseCursor, CursorState, @@ -39,7 +42,6 @@ AsyncExecutionUnavailableError, CursorClosedError, EngineNotRunningError, - FireboltDatabaseError, OperationalError, ProgrammingError, QueryNotRunError, @@ -119,12 +121,6 @@ async def _raise_if_error(self, resp: Response) -> None: f"Error executing query:\n{resp.read().decode('utf-8')}" ) if resp.status_code == codes.FORBIDDEN: - if self.connection.database and not await is_db_available( - self.connection, self.connection.database - ): - raise FireboltDatabaseError( - f"Database {self.connection.database} does not exist" - ) raise ProgrammingError(resp.read().decode("utf-8")) if ( resp.status_code == codes.SERVICE_UNAVAILABLE diff --git a/src/firebolt/async_db/util.py b/src/firebolt/async_db/util.py index ab560c75379..d268a8f2753 100644 --- a/src/firebolt/async_db/util.py +++ b/src/firebolt/async_db/util.py @@ -1,26 +1,19 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Tuple -from httpx import URL, Response +from httpx import URL, Timeout, codes -from firebolt.utils.urls import DATABASES_URL, ENGINES_URL +from firebolt.client import AsyncClient +from firebolt.client.auth import Auth +from firebolt.utils.exception import InterfaceError +from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME if TYPE_CHECKING: from firebolt.async_db.connection import Connection - -async def is_db_available(connection: Connection, database_name: str) -> bool: - """ - Verify that the database exists. - - Args: - connection (firebolt.async_db.connection.Connection) - """ - resp = await _filter_request( - connection, DATABASES_URL, {"filter.name_contains": database_name} - ) - return len(resp.json()["edges"]) > 0 +DEFAULT_TIMEOUT_SECONDS = 60 +ENGINE_STATUS_RUNNING = "Running" async def is_engine_running(connection: Connection, engine_url: str) -> bool: @@ -29,29 +22,55 @@ async def is_engine_running(connection: Connection, engine_url: str) -> bool: Args: connection (firebolt.async_db.connection.Connection): connection. + engine_url (str): URL of the engine """ - # Url is not guaranteed to be of this structure, - # but for the sake of error checking this is sufficient. + if connection._is_system: + # System engine is always running + return True + engine_name = URL(engine_url).host.split(".")[0].replace("-", "_") - resp = await _filter_request( - connection, - ENGINES_URL, - { - "filter.name_contains": engine_name, - "filter.current_status_eq": "ENGINE_STATUS_RUNNING_REVISION_SERVING", - }, + assert connection._system_engine_connection is not None # Type check + _, status, _ = await _get_engine_url_status_db( + connection._system_engine_connection, engine_name ) - return len(resp.json()["edges"]) > 0 + return status == ENGINE_STATUS_RUNNING + + +async def _get_system_engine_url( + auth: Auth, + account_name: str, + api_endpoint: str, +) -> str: + async with AsyncClient( + auth=auth, + base_url=api_endpoint, + account_name=account_name, + api_endpoint=api_endpoint, + timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), + ) as client: + url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name) + response = await client.get(url=url) + if response.status_code != codes.OK: + raise InterfaceError( + f"Unable to retrieve system engine endpoint {url}: " + f"{response.status} {response.content}" + ) + return response.json()["gatewayHost"] -async def _filter_request( - connection: Connection, endpoint: str, filters: dict -) -> Response: - resp = await connection._client.request( - # Full url overrides the client url, which contains engine as a prefix. - url=connection.api_endpoint + endpoint, - method="GET", - params=filters, +async def _get_engine_url_status_db( + system_engine: Connection, engine_name: str +) -> Tuple[str, str, str]: + cursor = system_engine.cursor() + await cursor.execute( + """ + SELECT engine_url, attached_to, status FROM information_schema.engines + WHERE engine_name=? + """, + [engine_name], ) - resp.raise_for_status() - return resp + row = await cursor.fetchone() + if row is None: + raise InterfaceError(f"Engine with name {engine_name} doesn't exist") + engine_url, database, status = row + return str(engine_url), str(status), str(database) # Mypy check From 08f9d9169910a767249dcb869e269bd0fa00991c Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 22 Mar 2023 13:09:18 +0200 Subject: [PATCH 14/28] unit tests fixes --- src/firebolt/async_db/connection.py | 36 ++++++++++++-------------- src/firebolt/async_db/util.py | 1 + tests/unit/async_db/test_connection.py | 31 ++++++++++------------ tests/unit/async_db/test_cursor.py | 26 ++++--------------- tests/unit/db/test_connection.py | 35 ++++++++++++------------- tests/unit/db_conftest.py | 13 ++++++++-- 6 files changed, 64 insertions(+), 78 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 680dadf3d0f..8aab8727730 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -86,31 +86,29 @@ async def connect( ) else: - async with Connection( + # Don't use context manager since this will be stored + # and used in a resulting connection + system_engine_connection = Connection( system_engine_url, database, auth, api_endpoint, None, additional_parameters, - ) as system_engine_connection: - if engine_name: - engine_url, status, attached_db = await _get_engine_url_status_db( - system_engine_connection, engine_name - ) - elif database is None: - database = attached_db - - if status != "Running": - raise InterfaceError(f"Engine {engine_name} is not running") - - if database is not None and database != attached_db: - raise InterfaceError( - f"Engine {engine_name} is not attached to {database}, " - f"but to {attached_db}" - ) - elif database is None: - database = attached_db + ) + engine_url, status, attached_db = await _get_engine_url_status_db( + system_engine_connection, engine_name + ) + if status != "Running": + raise InterfaceError(f"Engine {engine_name} is not running") + + if database is not None and database != attached_db: + raise InterfaceError( + f"Engine {engine_name} is not attached to {database}, " + f"but to {attached_db}" + ) + elif database is None: + database = attached_db assert engine_url is not None diff --git a/src/firebolt/async_db/util.py b/src/firebolt/async_db/util.py index d268a8f2753..a8e3703b347 100644 --- a/src/firebolt/async_db/util.py +++ b/src/firebolt/async_db/util.py @@ -24,6 +24,7 @@ async def is_engine_running(connection: Connection, engine_url: str) -> bool: connection (firebolt.async_db.connection.Connection): connection. engine_url (str): URL of the engine """ + if connection._is_system: # System engine is always running return True diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index b6414f1434e..8d771c1e371 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -125,38 +125,35 @@ async def test_connect_engine_name( assert await connection.cursor().execute("select*") == len(python_query_data) -async def test_connect_default_engine( +async def test_connect_database( db_name: str, auth_url: str, server: str, auth: Auth, account_name: str, python_query_data: List[List[ColType]], - mock_query: Callable, httpx_mock: HTTPXMock, + query_callback: str, check_credentials_callback: Callable, system_engine_query_url: str, - get_default_db_engine_callback: Callable, - get_default_db_engine_not_running_callback: Callable, + system_engine_no_db_query_url: str, get_system_engine_url: str, get_system_engine_callback: Callable, ): - mock_query() httpx_mock.add_callback(check_credentials_callback, url=auth_url) httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) - httpx_mock.add_callback( - get_default_db_engine_not_running_callback, url=system_engine_query_url - ) - with raises(InterfaceError): - async with await connect( - database=db_name, - auth=auth, - account_name=account_name, - api_endpoint=server, - ) as connection: - await connection.cursor().execute("select*") + httpx_mock.add_callback(query_callback, url=system_engine_no_db_query_url) + async with await connect( + database=None, + auth=auth, + account_name=account_name, + api_endpoint=server, + ) as connection: + await connection.cursor().execute("select*") - httpx_mock.add_callback(get_default_db_engine_callback, url=system_engine_query_url) + httpx_mock.reset(True) + httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) + httpx_mock.add_callback(query_callback, url=system_engine_query_url) async with await connect( database=db_name, diff --git a/tests/unit/async_db/test_cursor.py b/tests/unit/async_db/test_cursor.py index aa6194eb695..c9e02e69a64 100644 --- a/tests/unit/async_db/test_cursor.py +++ b/tests/unit/async_db/test_cursor.py @@ -13,7 +13,6 @@ CursorClosedError, DataError, EngineNotRunningError, - FireboltDatabaseError, OperationalError, QueryNotRunError, ) @@ -191,6 +190,8 @@ async def test_cursor_execute_error( httpx_mock: HTTPXMock, query_url: str, cursor: Cursor, + get_engine_url_not_running_callback: Callable, + system_engine_query_url: str, ): """Cursor handles all types of errors properly.""" for query, message in ( @@ -239,32 +240,15 @@ def http_error(*args, **kwargs): str(excinfo.value) == "Error executing query:\nQuery error message" ), f"Invalid authentication error message for {message}." - # Database does not exist error - httpx_mock.add_response( - status_code=codes.FORBIDDEN, - content="Query error message", - url=query_url, - ) - httpx_mock.add_response( - json={"edges": []}, - url=get_databases_url + "?filter.name_contains=database", - ) - with raises(FireboltDatabaseError) as excinfo: - await query() - assert cursor._state == CursorState.ERROR - # Engine is not running error httpx_mock.add_response( status_code=codes.SERVICE_UNAVAILABLE, content="Query error message", url=query_url, ) - httpx_mock.add_response( - json={"edges": []}, - url=( - get_engines_url + "?filter.name_contains=api_dev" - "&filter.current_status_eq=ENGINE_STATUS_RUNNING_REVISION_SERVING" - ), + httpx_mock.add_callback( + get_engine_url_not_running_callback, + url=system_engine_query_url, ) with raises(EngineNotRunningError) as excinfo: await query() diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index 59368424391..cca847a42e6 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -124,38 +124,35 @@ def test_connect_engine_name( assert connection.cursor().execute("select*") == len(python_query_data) -def test_connect_default_engine( +def test_connect_database( db_name: str, auth_url: str, server: str, auth: Auth, account_name: str, python_query_data: List[List[ColType]], - mock_query: Callable, httpx_mock: HTTPXMock, + query_callback: str, check_credentials_callback: Callable, system_engine_query_url: str, - get_default_db_engine_callback: Callable, - get_default_db_engine_not_running_callback: Callable, + system_engine_no_db_query_url: str, get_system_engine_url: str, get_system_engine_callback: Callable, ): - mock_query() httpx_mock.add_callback(check_credentials_callback, url=auth_url) httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) - httpx_mock.add_callback( - get_default_db_engine_not_running_callback, url=system_engine_query_url - ) - with raises(InterfaceError): - with connect( - database=db_name, - auth=auth, - account_name=account_name, - api_endpoint=server, - ) as connection: - connection.cursor().execute("select*") + httpx_mock.add_callback(query_callback, url=system_engine_no_db_query_url) + with connect( + database=None, + auth=auth, + account_name=account_name, + api_endpoint=server, + ) as connection: + connection.cursor().execute("select*") - httpx_mock.add_callback(get_default_db_engine_callback, url=system_engine_query_url) + httpx_mock.reset(True) + httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) + httpx_mock.add_callback(query_callback, url=system_engine_query_url) with connect( database=db_name, @@ -167,7 +164,7 @@ def test_connect_default_engine( def test_connection_unclosed_warnings(): - c = Connection("", "", None, "") + c = Connection("", "", None, "", None) with warns(UserWarning) as winfo: del c gc.collect() @@ -178,7 +175,7 @@ def test_connection_unclosed_warnings(): def test_connection_no_warnings(): - c = Connection("", "", None, "") + c = Connection("", "", None, "", None) c.close() with warnings.catch_warnings(): warnings.simplefilter("error") diff --git a/tests/unit/db_conftest.py b/tests/unit/db_conftest.py index 5e32ec0bd22..b27eca33a58 100644 --- a/tests/unit/db_conftest.py +++ b/tests/unit/db_conftest.py @@ -353,7 +353,11 @@ def query_with_params_url(query_url: str, set_params: str) -> str: def _get_engine_url_callback(server: str, db_name: str, status="Running") -> Callable: def do_query(request: Request, **kwargs) -> Response: set_parameters = request.url.params - assert len(set_parameters) == 1 and "output_format" in set_parameters + assert ( + len(set_parameters) == 2 + and "output_format" in set_parameters + and "database" in set_parameters + ) data = [[server, db_name, status]] query_response = { "meta": [{"name": "name", "type": "Text"} for _ in range(len(data[0]))], @@ -431,7 +435,12 @@ def system_engine_url() -> str: @fixture -def system_engine_query_url(system_engine_url: str) -> str: +def system_engine_query_url(system_engine_url: str, db_name: str) -> str: + return f"{system_engine_url}/?output_format=JSON_Compact&database={db_name}" + + +@fixture +def system_engine_no_db_query_url(system_engine_url: str) -> str: return f"{system_engine_url}/?output_format=JSON_Compact" From 30bf6a697b89f1be2f03eb4f397852854ecaa654 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Fri, 14 Apr 2023 11:20:48 +0300 Subject: [PATCH 15/28] integration testing WIP --- src/firebolt/async_db/connection.py | 22 +++--- src/firebolt/async_db/util.py | 1 + .../client/auth/client_credentials.py | 3 +- tests/integration/conftest.py | 45 ++---------- tests/integration/dbapi/async/conftest.py | 50 ++++++------- .../dbapi/async/test_errors_async.py | 71 +++++++------------ tests/integration/dbapi/sync/test_errors.py | 22 +++--- 7 files changed, 79 insertions(+), 135 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 8aab8727730..b4a67c78cbb 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -110,17 +110,17 @@ async def connect( elif database is None: database = attached_db - assert engine_url is not None - - engine_url = fix_url_schema(engine_url) - return Connection( - engine_url, - database, - auth, - api_endpoint, - system_engine_connection, - additional_parameters, - ) + assert engine_url is not None + + engine_url = fix_url_schema(engine_url) + return Connection( + engine_url, + database, + auth, + api_endpoint, + system_engine_connection, + additional_parameters, + ) class OverriddenHttpBackend(AutoBackend): diff --git a/src/firebolt/async_db/util.py b/src/firebolt/async_db/util.py index a8e3703b347..3b337e9b217 100644 --- a/src/firebolt/async_db/util.py +++ b/src/firebolt/async_db/util.py @@ -49,6 +49,7 @@ async def _get_system_engine_url( api_endpoint=api_endpoint, timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), ) as client: + return "https://api.us-east-1.dev.firebolt.io/dynamic" url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name) response = await client.get(url=url) if response.status_code != codes.OK: diff --git a/src/firebolt/client/auth/client_credentials.py b/src/firebolt/client/auth/client_credentials.py index 36ba2a16791..0655a3a025c 100644 --- a/src/firebolt/client/auth/client_credentials.py +++ b/src/firebolt/client/auth/client_credentials.py @@ -88,7 +88,8 @@ def _make_auth_request(self) -> AuthRequest: "client_id": self.client_id, "client_secret": self.client_secret, "grant_type": "client_credentials", - "audience": self._audience, + "audience": "https://dev-firebolt-v2.us.auth0.com/api/v2/" + # "audience": self._audience, }, ) return response diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 810125cf4a9..7798bad930e 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -2,20 +2,15 @@ from os import environ from pytest import fixture -from pytest_asyncio import fixture as async_fixture -from firebolt.client.auth import ServiceAccount, UsernamePassword +from firebolt.client.auth import ClientCredentials from firebolt.service.manager import Settings LOGGER = getLogger(__name__) -ENGINE_URL_ENV = "ENGINE_URL" ENGINE_NAME_ENV = "ENGINE_NAME" -STOPPED_ENGINE_URL_ENV = "STOPPED_ENGINE_URL" STOPPED_ENGINE_NAME_ENV = "STOPPED_ENGINE_NAME" DATABASE_NAME_ENV = "DATABASE_NAME" -USER_NAME_ENV = "USER_NAME" -PASSWORD_ENV = "PASSWORD" ACCOUNT_NAME_ENV = "ACCOUNT_NAME" API_ENDPOINT_ENV = "API_ENDPOINT" SERVICE_ID_ENV = "SERVICE_ID" @@ -28,26 +23,15 @@ def must_env(var_name: str) -> str: return environ[var_name] -@async_fixture(scope="session") -def rm_settings(api_endpoint, username, password) -> Settings: +@fixture(scope="session") +def rm_settings(api_endpoint, auth) -> Settings: return Settings( server=api_endpoint, - user=username, - password=password, + auth=auth, default_region="us-east-1", ) -@fixture(scope="session") -def engine_url() -> str: - return must_env(ENGINE_URL_ENV) - - -@fixture(scope="session") -def stopped_engine_url() -> str: - return must_env(STOPPED_ENGINE_URL_ENV) - - @fixture(scope="session") def engine_name() -> str: return must_env(ENGINE_NAME_ENV) @@ -55,7 +39,7 @@ def engine_name() -> str: @fixture(scope="session") def stopped_engine_name() -> str: - return must_env(STOPPED_ENGINE_URL_ENV) + return must_env(STOPPED_ENGINE_NAME_ENV) @fixture(scope="session") @@ -63,16 +47,6 @@ def database_name() -> str: return must_env(DATABASE_NAME_ENV) -@fixture(scope="session") -def username() -> str: - return must_env(USER_NAME_ENV) - - -@fixture(scope="session") -def password() -> str: - return must_env(PASSWORD_ENV) - - @fixture(scope="session") def account_name() -> str: return must_env(ACCOUNT_NAME_ENV) @@ -94,10 +68,5 @@ def service_secret() -> str: @fixture -def service_auth(service_id, service_secret) -> ServiceAccount: - return ServiceAccount(service_id, service_secret) - - -@fixture(scope="session") -def password_auth(username, password) -> UsernamePassword: - return UsernamePassword(username, password) +def auth(service_id, service_secret) -> ClientCredentials: + return ClientCredentials(service_id, service_secret) diff --git a/tests/integration/dbapi/async/conftest.py b/tests/integration/dbapi/async/conftest.py index 76bae447748..3bfbbc7ff4d 100644 --- a/tests/integration/dbapi/async/conftest.py +++ b/tests/integration/dbapi/async/conftest.py @@ -1,75 +1,67 @@ -from pytest_asyncio import fixture as async_fixture +from pytest import fixture from firebolt.async_db import Connection, connect from firebolt.client.auth.base import Auth -@async_fixture -async def username_password_connection( - engine_url: str, +@fixture +async def connection( + engine_name: str, database_name: str, - password_auth: Auth, + auth: Auth, account_name: str, api_endpoint: str, ) -> Connection: async with await connect( - engine_url=engine_url, + engine_name=engine_name, database=database_name, - auth=password_auth, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, ) as connection: yield connection -@async_fixture -async def connection( - engine_url: str, - database_name: str, - password_auth: Auth, +@fixture +async def connection_no_db( + engine_name: str, + auth: Auth, account_name: str, api_endpoint: str, ) -> Connection: async with await connect( - engine_url=engine_url, - database=database_name, - auth=password_auth, + engine_name=engine_name, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, ) as connection: yield connection -@async_fixture -async def connection_engine_name( - engine_name: str, +@fixture +async def connection_system_engine( database_name: str, - password_auth: Auth, + auth: Auth, account_name: str, api_endpoint: str, ) -> Connection: - async with await connect( - engine_name=engine_name, database=database_name, - auth=password_auth, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, ) as connection: yield connection -@async_fixture -async def connection_no_engine( - database_name: str, - password_auth: Auth, +@fixture +async def connection_system_engine_no_db( + auth: Auth, account_name: str, api_endpoint: str, ) -> Connection: - async with await connect( - database=database_name, - auth=password_auth, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, ) as connection: diff --git a/tests/integration/dbapi/async/test_errors_async.py b/tests/integration/dbapi/async/test_errors_async.py index be19f63a9b7..67d64d70e00 100644 --- a/tests/integration/dbapi/async/test_errors_async.py +++ b/tests/integration/dbapi/async/test_errors_async.py @@ -1,13 +1,11 @@ -from httpx import ConnectError from pytest import mark, raises from firebolt.async_db import Connection, connect -from firebolt.client.auth import UsernamePassword +from firebolt.client.auth import ClientCredentials from firebolt.utils.exception import ( AccountNotFoundError, - EngineNotRunningError, - FireboltDatabaseError, FireboltEngineError, + InterfaceError, OperationalError, ) @@ -15,7 +13,7 @@ async def test_invalid_account( database_name: str, engine_name: str, - password_auth: UsernamePassword, + auth: ClientCredentials, api_endpoint: str, ) -> None: """Connection properly reacts to invalid account error.""" @@ -23,8 +21,8 @@ async def test_invalid_account( with raises(AccountNotFoundError) as exc_info: async with await connect( database=database_name, - engine_name=engine_name, # Omit engine_url to force account_id lookup. - auth=password_auth, + engine_name=engine_name, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, ) as connection: @@ -35,29 +33,10 @@ async def test_invalid_account( ), "Invalid account error message." -async def test_engine_url_not_exists( - engine_url: str, - database_name: str, - password_auth: UsernamePassword, - account_name: str, - api_endpoint: str, -) -> None: - """Connection properly reacts to invalid engine url error.""" - async with await connect( - engine_url=engine_url + "_", - database=database_name, - auth=password_auth, - account_name=account_name, - api_endpoint=api_endpoint, - ) as connection: - with raises(ConnectError): - await connection.cursor().execute("show tables") - - async def test_engine_name_not_exists( engine_name: str, database_name: str, - password_auth: UsernamePassword, + password_auth: ClientCredentials, account_name: str, api_endpoint: str, ) -> None: @@ -66,7 +45,7 @@ async def test_engine_name_not_exists( async with await connect( engine_name=engine_name + "_________", database=database_name, - auth=password_auth, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, ) as connection: @@ -74,18 +53,18 @@ async def test_engine_name_not_exists( async def test_engine_stopped( - stopped_engine_url: str, + stopped_engine_name: str, database_name: str, - password_auth: UsernamePassword, + auth: ClientCredentials, account_name: str, api_endpoint: str, ) -> None: """Connection properly reacts to invalid engine name error.""" - with raises(EngineNotRunningError): + with raises(InterfaceError): async with await connect( - engine_url=stopped_engine_url, + engine_name=stopped_engine_name, database=database_name, - auth=password_auth, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, ) as connection: @@ -94,25 +73,27 @@ async def test_engine_stopped( @mark.skip(reason="Behaviour is different in prod vs dev") async def test_database_not_exists( - engine_url: str, + engine_name: str, database_name: str, - password_auth: UsernamePassword, + auth: ClientCredentials, api_endpoint: str, + account_name: str, ) -> None: """Connection properly reacts to invalid database error.""" new_db_name = database_name + "_" - async with await connect( - engine_url=engine_url, - database=new_db_name, - auth=password_auth, - api_endpoint=api_endpoint, - ) as connection: - with raises(FireboltDatabaseError) as exc_info: + with raises(InterfaceError) as exc_info: + async with await connect( + engine_name=engine_name, + database=new_db_name, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) as connection: await connection.cursor().execute("show tables") - assert ( - str(exc_info.value) == f"Database {new_db_name} does not exist" - ), "Invalid database name error message." + assert ( + str(exc_info.value) == f"Database {new_db_name} does not exist" + ), "Invalid database name error message." async def test_sql_error(connection: Connection) -> None: diff --git a/tests/integration/dbapi/sync/test_errors.py b/tests/integration/dbapi/sync/test_errors.py index dd5017c0202..193dbc617b1 100644 --- a/tests/integration/dbapi/sync/test_errors.py +++ b/tests/integration/dbapi/sync/test_errors.py @@ -1,7 +1,7 @@ from httpx import ConnectError from pytest import mark, raises -from firebolt.client.auth import UsernamePassword +from firebolt.client.auth import ClientCredentials from firebolt.db import Connection, connect from firebolt.utils.exception import ( AccountNotFoundError, @@ -15,7 +15,7 @@ def test_invalid_account( database_name: str, engine_name: str, - password_auth: UsernamePassword, + auth: ClientCredentials, api_endpoint: str, ) -> None: """Connection properly reacts to invalid account error.""" @@ -24,7 +24,7 @@ def test_invalid_account( with connect( database=database_name, engine_name=engine_name, # Omit engine_url to force account_id lookup. - auth=password_auth, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, ) as connection: @@ -38,14 +38,14 @@ def test_invalid_account( def test_engine_url_not_exists( engine_url: str, database_name: str, - password_auth: UsernamePassword, + auth: ClientCredentials, api_endpoint: str, ) -> None: """Connection properly reacts to invalid engine url error.""" with connect( engine_url=engine_url + "_", database=database_name, - auth=password_auth, + auth=auth, api_endpoint=api_endpoint, ) as connection: with raises(ConnectError): @@ -55,7 +55,7 @@ def test_engine_url_not_exists( def test_engine_name_not_exists( engine_name: str, database_name: str, - password_auth: UsernamePassword, + auth: ClientCredentials, api_endpoint: str, ) -> None: """Connection properly reacts to invalid engine name error.""" @@ -63,7 +63,7 @@ def test_engine_name_not_exists( with connect( engine_name=engine_name + "_________", database=database_name, - auth=password_auth, + auth=auth, api_endpoint=api_endpoint, ) as connection: connection.cursor().execute("show tables") @@ -72,7 +72,7 @@ def test_engine_name_not_exists( def test_engine_stopped( stopped_engine_url: str, database_name: str, - password_auth: UsernamePassword, + auth: ClientCredentials, api_endpoint: str, ) -> None: """Connection properly reacts to engine not running error.""" @@ -80,7 +80,7 @@ def test_engine_stopped( with connect( engine_url=stopped_engine_url, database=database_name, - auth=password_auth, + auth=auth, api_endpoint=api_endpoint, ) as connection: connection.cursor().execute("show tables") @@ -90,7 +90,7 @@ def test_engine_stopped( def test_database_not_exists( engine_url: str, database_name: str, - password_auth: UsernamePassword, + auth: ClientCredentials, api_endpoint: str, ) -> None: """Connection properly reacts to invalid database error.""" @@ -98,7 +98,7 @@ def test_database_not_exists( with connect( engine_url=engine_url, database=new_db_name, - auth=password_auth, + auth=auth, api_endpoint=api_endpoint, ) as connection: with raises(FireboltDatabaseError) as exc_info: From e303f8c65453fca412e84ba18c8003d85858ad43 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 10 May 2023 16:14:33 +0300 Subject: [PATCH 16/28] WIP update audience --- src/firebolt/client/auth/client_credentials.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/firebolt/client/auth/client_credentials.py b/src/firebolt/client/auth/client_credentials.py index 0655a3a025c..027d3d11942 100644 --- a/src/firebolt/client/auth/client_credentials.py +++ b/src/firebolt/client/auth/client_credentials.py @@ -88,7 +88,7 @@ def _make_auth_request(self) -> AuthRequest: "client_id": self.client_id, "client_secret": self.client_secret, "grant_type": "client_credentials", - "audience": "https://dev-firebolt-v2.us.auth0.com/api/v2/" + "audience": "https://api.firebolt.io" # "audience": self._audience, }, ) From e97b9ea43094e834ea74d033c4415157b97dc7d3 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Thu, 25 May 2023 14:54:07 +0300 Subject: [PATCH 17/28] WIP main rebase --- src/firebolt/async_db/connection.py | 186 ++++++++++++++-------------- src/firebolt/async_db/cursor.py | 5 +- src/firebolt/async_db/util.py | 2 +- src/firebolt/common/settings.py | 44 +------ src/firebolt/db/connection.py | 32 ++++- src/firebolt/model/engine.py | 3 +- tests/unit/db/test_connection.py | 2 +- 7 files changed, 128 insertions(+), 146 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index b4a67c78cbb..a7249be909a 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -3,7 +3,7 @@ import logging import socket from types import TracebackType -from typing import Any, Callable, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional from httpcore.backends.auto import AutoBackend from httpcore.backends.base import AsyncNetworkStream @@ -11,7 +11,6 @@ from firebolt.async_db.cursor import Cursor from firebolt.async_db.util import ( - DEFAULT_TIMEOUT_SECONDS, _get_engine_url_status_db, _get_system_engine_url, ) @@ -31,98 +30,9 @@ from firebolt.utils.usage_tracker import get_user_agent_header from firebolt.utils.util import fix_url_schema - logger = logging.getLogger(__name__) -async def connect( - auth: Optional[Auth] = None, - account_name: Optional[str] = None, - database: Optional[str] = None, - engine_name: Optional[str] = None, - api_endpoint: str = DEFAULT_API_URL, - additional_parameters: Dict[str, Any] = {}, -) -> Connection: - """Connect to Firebolt database. - - Args: - `auth` (Auth) Authentication object. - `database` (str): Name of the database to connect - `engine_name` (Optional[str]): Name of the engine to connect to - `account_name` (Optional[str]): For customers with multiple accounts; - if none, default is used - `api_endpoint` (str): Firebolt API endpoint. Used for authentication - `additional_parameters` (Optional[Dict]): Dictionary of less widely-used - arguments for connection - """ - # These parameters are optional in function signature - # but are required to connect. - # PEP 249 recommends making them kwargs. - for name, value in (("auth", auth), ("account_name", account_name)): - if not value: - raise ConfigurationError(f"{name} is required to connect.") - - # Type checks - assert auth is not None - assert account_name is not None - - api_endpoint = fix_url_schema(api_endpoint) - - if not engine_name and not database: - # Return system engine connection - return connection_class( - system_engine_url, None, auth, api_endpoint, additional_parameters - ) - - if not engine_name: - # Return system engine connection - return connection_class( - system_engine_url, - database, - auth, - api_endpoint, - None, - additional_parameters, - ) - - else: - # Don't use context manager since this will be stored - # and used in a resulting connection - system_engine_connection = Connection( - system_engine_url, - database, - auth, - api_endpoint, - None, - additional_parameters, - ) - engine_url, status, attached_db = await _get_engine_url_status_db( - system_engine_connection, engine_name - ) - if status != "Running": - raise InterfaceError(f"Engine {engine_name} is not running") - - if database is not None and database != attached_db: - raise InterfaceError( - f"Engine {engine_name} is not attached to {database}, " - f"but to {attached_db}" - ) - elif database is None: - database = attached_db - - assert engine_url is not None - - engine_url = fix_url_schema(engine_url) - return Connection( - engine_url, - database, - auth, - api_endpoint, - system_engine_connection, - additional_parameters, - ) - - class OverriddenHttpBackend(AutoBackend): """ `OverriddenHttpBackend` is a short-term solution for the TCP @@ -164,6 +74,7 @@ async def connect_tcp( # type: ignore [override] ) return stream + class Connection(BaseConnection): """ Firebolt asynchronous database connection class. Implements `PEP 249`_. @@ -201,9 +112,9 @@ class Connection(BaseConnection): def __init__( self, engine_url: str, - database: str, + database: Optional[str], auth: Auth, - api_endpoint: str = DEFAULT_API_URL, + api_endpoint: str, system_engine_connection: Optional["Connection"], additional_parameters: Dict[str, Any] = {}, ): @@ -230,7 +141,7 @@ def __init__( @property def _is_system(self) -> bool: """`True` if connection is a system engine connection; `False` otherwise.""" - return self._system_engine_connection is not None + return self._system_engine_connection is not None def cursor(self, **kwargs: Any) -> Cursor: if self.closed: @@ -266,3 +177,90 @@ async def __aexit__( self, exc_type: type, exc_val: Exception, exc_tb: TracebackType ) -> None: await self.aclose() + + +async def connect( + auth: Optional[Auth] = None, + account_name: Optional[str] = None, + database: Optional[str] = None, + engine_name: Optional[str] = None, + api_endpoint: str = DEFAULT_API_URL, + additional_parameters: Dict[str, Any] = {}, +) -> Connection: + """Connect to Firebolt database. + + Args: + `auth` (Auth) Authentication object. + `database` (str): Name of the database to connect + `engine_name` (Optional[str]): Name of the engine to connect to + `account_name` (Optional[str]): For customers with multiple accounts; + if none, default is used + `api_endpoint` (str): Firebolt API endpoint. Used for authentication + `additional_parameters` (Optional[Dict]): Dictionary of less widely-used + arguments for connection + """ + # These parameters are optional in function signature + # but are required to connect. + # PEP 249 recommends making them kwargs. + for name, value in (("auth", auth), ("account_name", account_name)): + if not value: + raise ConfigurationError(f"{name} is required to connect.") + + # Type checks + assert auth is not None + assert account_name is not None + + api_endpoint = fix_url_schema(api_endpoint) + + system_engine_url = fix_url_schema( + await _get_system_engine_url(auth, account_name, api_endpoint) + ) + + if not engine_name: + # Return system engine connection + return Connection( + system_engine_url, + database, + auth, + api_endpoint, + None, + additional_parameters, + ) + + else: + # Don't use context manager since this will be stored + # and used in a resulting connection + system_engine_connection = Connection( + system_engine_url, + None, + auth, + api_endpoint, + None, + additional_parameters, + ) + engine_url, status, attached_db = await _get_engine_url_status_db( + system_engine_connection, engine_name + ) + + if status != "Running": + raise InterfaceError(f"Engine {engine_name} is not running") + + if database is not None and database != attached_db: + raise InterfaceError( + f"Engine {engine_name} is not attached to {database}, " + f"but to {attached_db}" + ) + elif database is None: + database = attached_db + + assert engine_url is not None + + engine_url = fix_url_schema(engine_url) + return Connection( + engine_url, + database, + auth, + api_endpoint, + system_engine_connection, + additional_parameters, + ) diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 7a24422cb90..1b53c428c7b 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -20,7 +20,7 @@ from httpx import Response, codes from tricycle import RWLock -from firebolt.async_db.util import is_db_available, is_engine_running +from firebolt.async_db.util import is_engine_running from firebolt.common._types import ( ColType, Column, @@ -29,9 +29,6 @@ SetParameter, split_format_sql, ) - -from firebolt.async_db.util import is_engine_running -from firebolt.client import AsyncClient from firebolt.common.base_cursor import ( BaseCursor, CursorState, diff --git a/src/firebolt/async_db/util.py b/src/firebolt/async_db/util.py index 3b337e9b217..b7b273987f5 100644 --- a/src/firebolt/async_db/util.py +++ b/src/firebolt/async_db/util.py @@ -6,13 +6,13 @@ from firebolt.client import AsyncClient from firebolt.client.auth import Auth +from firebolt.common.settings import DEFAULT_TIMEOUT_SECONDS from firebolt.utils.exception import InterfaceError from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME if TYPE_CHECKING: from firebolt.async_db.connection import Connection -DEFAULT_TIMEOUT_SECONDS = 60 ENGINE_STATUS_RUNNING = "Running" diff --git a/src/firebolt/common/settings.py b/src/firebolt/common/settings.py index 557b0ddcc2d..c271d352f84 100644 --- a/src/firebolt/common/settings.py +++ b/src/firebolt/common/settings.py @@ -12,17 +12,6 @@ KEEPIDLE_RATE: int = 60 # seconds DEFAULT_TIMEOUT_SECONDS: int = 60 -AUTH_CREDENTIALS_DEPRECATION_MESSAGE = """ Passing connection credentials directly in Settings is deprecated. - Use Auth object instead. - Examples: - >>> from firebolt.client.auth import UsernamePassword - >>> ... - >>> settings = Settings(auth=UsernamePassword(username, password), ...) - or - >>> from firebolt.client.auth import Token - >>> ... - >>> settings = Settings(auth=Token(access_token), ...)""" - CLIENT_ID_ENV = "FIREBOLT_CLIENT_ID" CLIENT_SECRET_ENV = "FIREBOLT_CLIENT_SECRET" ACCOUNT_ENV = "FIREBOLT_ACCOUNT" @@ -37,12 +26,12 @@ def inner() -> Any: return inner -def auth_from_env() -> Optional[Auth]: +def auth_from_env() -> Auth: client_id = os.environ.get(CLIENT_ID_ENV, None) client_secret = os.environ.get(CLIENT_SECRET_ENV, None) if client_id and client_secret: return ClientCredentials(client_id, client_secret) - return None + raise ValueError("Auth not provided") @dataclass @@ -61,35 +50,8 @@ class Settings: default_region (str): Default region for provisioning """ - auth: Optional[Auth] = field(default_factory=auth_from_env) + auth: Auth = field(default_factory=auth_from_env) account_name: Optional[str] = field(default_factory=from_env(ACCOUNT_ENV)) server: str = field(default_factory=from_env(SERVER_ENV)) default_region: str = field(default_factory=from_env(DEFAULT_REGION_ENV)) - - def __post_init__(self) -> None: - """Validate that either creds or token is provided. - - Args: - values (dict): settings initial values - - Returns: - dict: Validated settings values - - Raises: - ValueError: Either both or none of credentials and token are provided - """ - - params_present = ( - self.user is not None or self.password is not None, - self.access_token is not None, - self.auth is not None, - ) - if sum(params_present) == 0: - raise ValueError( - "Provide at least one of auth, user/password or access_token." - ) - if sum(params_present) > 1: - raise ValueError("Provide only one of auth, user/password or access_token") - if any((self.user, self.password, self.access_token)): - logger.warning(AUTH_CREDENTIALS_DEPRECATION_MESSAGE) diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index 3a29d7852e5..e1f59678b09 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -13,15 +13,13 @@ from readerwriterlock.rwlock import RWLockWrite from firebolt.client import DEFAULT_API_URL, Client -from firebolt.client.auth import Auth, _get_auth +from firebolt.client.auth import Auth, ClientCredentials from firebolt.common.base_connection import BaseConnection from firebolt.common.settings import ( - AUTH_CREDENTIALS_DEPRECATION_MESSAGE, DEFAULT_TIMEOUT_SECONDS, KEEPALIVE_FLAG, KEEPIDLE_RATE, ) -from firebolt.common.util import validate_engine_name_and_url from firebolt.db.cursor import Cursor from firebolt.utils.exception import ( ConfigurationError, @@ -40,6 +38,33 @@ logger = logging.getLogger(__name__) +def _get_auth( + username: Optional[str], + password: Optional[str], + access_token: Optional[str], + use_token_cache: bool, +) -> Auth: + """Create `Auth` class based on provided credentials. + + If `access_token` is provided, it's used for `Auth` creation. + Otherwise, username/password are used. + + Returns: + Auth: `auth object` + + Raises: + `ConfigurationError`: Invalid combination of credentials provided + + """ + return ClientCredentials(username, password) # type: ignore + + +def validate_engine_name_and_url( + engine_name: Optional[str], engine_url: Optional[str] +) -> None: + pass + + def _resolve_engine_url( engine_name: str, auth: Auth, @@ -309,7 +334,6 @@ def connect( if not auth: if any([username, password, access_token, api_endpoint, use_token_cache]): - logger.warning(AUTH_CREDENTIALS_DEPRECATION_MESSAGE) auth = _get_auth(username, password, access_token, use_token_cache) else: raise ConfigurationError("No authentication provided.") diff --git a/src/firebolt/model/engine.py b/src/firebolt/model/engine.py index a484ed9d919..04c9ef378d9 100644 --- a/src/firebolt/model/engine.py +++ b/src/firebolt/model/engine.py @@ -194,7 +194,8 @@ def get_connection(self) -> Connection: """ return connect( database=self.database.name, # type: ignore # already checked by decorator - auth=self._service.client.auth, + # we always have firebolt Auth as a client auth + auth=self._service.client.auth, # type: ignore engine_name=self.name, account_name=self._service.settings.account_name, api_endpoint=self._service.settings.server, diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index cca847a42e6..0bdfe6a8360 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -7,8 +7,8 @@ from pytest import mark, raises, warns from pytest_httpx import HTTPXMock -from firebolt.async_db._types import ColType from firebolt.client.auth import Auth, ClientCredentials +from firebolt.common._types import ColType from firebolt.common.settings import Settings from firebolt.db import Connection, connect from firebolt.utils.exception import ( From 91cc609b79cd15a5d02bb976514774c58448635d Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Thu, 25 May 2023 15:23:24 +0300 Subject: [PATCH 18/28] fix async unit tests --- src/firebolt/async_db/connection.py | 4 ++-- src/firebolt/async_db/util.py | 2 +- src/firebolt/client/auth/client_credentials.py | 5 +---- src/firebolt/client/client.py | 2 -- 4 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index a7249be909a..1c433fc9074 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -141,7 +141,7 @@ def __init__( @property def _is_system(self) -> bool: """`True` if connection is a system engine connection; `False` otherwise.""" - return self._system_engine_connection is not None + return self._system_engine_connection is None def cursor(self, **kwargs: Any) -> Cursor: if self.closed: @@ -232,7 +232,7 @@ async def connect( # and used in a resulting connection system_engine_connection = Connection( system_engine_url, - None, + database, auth, api_endpoint, None, diff --git a/src/firebolt/async_db/util.py b/src/firebolt/async_db/util.py index b7b273987f5..eefc72bf45a 100644 --- a/src/firebolt/async_db/util.py +++ b/src/firebolt/async_db/util.py @@ -49,7 +49,7 @@ async def _get_system_engine_url( api_endpoint=api_endpoint, timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), ) as client: - return "https://api.us-east-1.dev.firebolt.io/dynamic" + # return "https://api.us-east-1.dev.firebolt.io/dynamic" url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name) response = await client.get(url=url) if response.status_code != codes.OK: diff --git a/src/firebolt/client/auth/client_credentials.py b/src/firebolt/client/auth/client_credentials.py index 027d3d11942..6370a2b49f2 100644 --- a/src/firebolt/client/auth/client_credentials.py +++ b/src/firebolt/client/auth/client_credentials.py @@ -31,7 +31,6 @@ class ClientCredentials(_RequestBasedAuth): "_expires", "_use_token_cache", "_user_agent", - "_audience", ) requires_response_body = True @@ -44,7 +43,6 @@ def __init__( ): self.client_id = client_id self.client_secret = client_secret - self._audience = "" super().__init__(use_token_cache) def copy(self) -> "ClientCredentials": @@ -88,8 +86,7 @@ def _make_auth_request(self) -> AuthRequest: "client_id": self.client_id, "client_secret": self.client_secret, "grant_type": "client_credentials", - "audience": "https://api.firebolt.io" - # "audience": self._audience, + "audience": "https://api.firebolt.io", }, ) return response diff --git a/src/firebolt/client/client.py b/src/firebolt/client/client.py index fcc87be7f7c..2c52c155842 100644 --- a/src/firebolt/client/client.py +++ b/src/firebolt/client/client.py @@ -65,8 +65,6 @@ def _build_auth(self, auth: Optional[AuthTypes]) -> Optional[Auth]: """ if not (auth is None or isinstance(auth, Auth)): raise TypeError(f'Invalid "auth" argument: {auth!r}') - if hasattr(auth, "_audience"): - auth._audience = self._api_endpoint # type: ignore return auth def _merge_auth_request(self, request: Request) -> Request: From a551d2c36c463c952bab7fcfb4a4ec2d4595c772 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Thu, 25 May 2023 15:27:54 +0300 Subject: [PATCH 19/28] fix merge unit test issues --- tests/unit/client/conftest.py | 86 -------------------------------- tests/unit/client/test_client.py | 1 - 2 files changed, 87 deletions(-) delete mode 100644 tests/unit/client/conftest.py diff --git a/tests/unit/client/conftest.py b/tests/unit/client/conftest.py deleted file mode 100644 index 8fe56a0ec33..00000000000 --- a/tests/unit/client/conftest.py +++ /dev/null @@ -1,86 +0,0 @@ -import json -import typing - -import httpx -from httpx import Response -from pytest import fixture - -from firebolt.common.settings import Settings - - -@fixture -def test_token(access_token: str) -> str: - return access_token - - -@fixture -def test_token2() -> str: - return "test_token2" - - -@fixture -def test_username(settings: Settings) -> str: - return settings.user - - -@fixture -def test_password(settings: Settings) -> str: - return settings.password - - -@fixture -def mock_service_id() -> str: - return "mock_service_id" - - -@fixture -def mock_service_secret() -> str: - return "my_secret" - - -@fixture -def check_credentials_callback( - test_username: str, test_password: str, test_token: str -) -> typing.Callable: - def check_credentials( - request: httpx.Request = None, - **kwargs, - ) -> Response: - assert request, "empty request" - body = json.loads(request.read()) - assert "username" in body, "Missing username" - assert body["username"] == test_username, "Invalid username" - assert "password" in body, "Missing password" - assert body["password"] == test_password, "Invalid password" - - return Response( - status_code=httpx.codes.OK, - json={"expires_in": 2**32, "access_token": test_token}, - ) - - return check_credentials - - -@fixture -def check_service_credentials_callback( - mock_service_id: str, mock_service_secret: str, test_token: str -) -> typing.Callable: - def check_credentials( - request: httpx.Request = None, - **kwargs, - ) -> Response: - assert request, "empty request" - body = request.read().decode("utf-8") - assert "client_id" in body, "Missing id" - assert f"client_id={mock_service_id}" in body, "Invalid id" - assert "client_secret" in body, "Missing secret" - assert f"client_secret={mock_service_secret}" in body, "Invalid secret" - assert "grant_type" in body, "Missing grant_type" - assert "grant_type=client_credentials" in body, "Invalid grant_type" - - return Response( - status_code=httpx.codes.OK, - json={"expires_in": 2**32, "access_token": test_token}, - ) - - return check_credentials diff --git a/tests/unit/client/test_client.py b/tests/unit/client/test_client.py index f9374efab7f..f365ad4b781 100644 --- a/tests/unit/client/test_client.py +++ b/tests/unit/client/test_client.py @@ -63,7 +63,6 @@ def test_client_different_auths( ): """ Client properly handles such auth types: - - tuple(username, password) - Auth - None All other types should raise TypeError. From 358014363bc06cae4327311d2fe83d05167d8a67 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Tue, 30 May 2023 12:44:49 +0300 Subject: [PATCH 20/28] add account_id to system engine, minor fixes --- src/firebolt/async_db/connection.py | 34 ++++++++---------- src/firebolt/async_db/cursor.py | 8 +++-- src/firebolt/async_db/util.py | 8 ++--- src/firebolt/client/client.py | 53 +++++++++++++++-------------- src/firebolt/common/settings.py | 4 +-- src/firebolt/db/connection.py | 14 +++++--- src/firebolt/utils/urls.py | 5 +-- tests/unit/db_conftest.py | 2 +- 8 files changed, 67 insertions(+), 61 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 1c433fc9074..762ea63a10c 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -115,6 +115,7 @@ def __init__( database: Optional[str], auth: Auth, api_endpoint: str, + account_name: str, system_engine_connection: Optional["Connection"], additional_parameters: Dict[str, Any] = {}, ): @@ -128,6 +129,7 @@ def __init__( user_drivers = additional_parameters.get("user_drivers", []) user_clients = additional_parameters.get("user_clients", []) self._client = AsyncClient( + account_name=account_name, auth=auth, base_url=engine_url, api_endpoint=api_endpoint, @@ -215,29 +217,22 @@ async def connect( system_engine_url = fix_url_schema( await _get_system_engine_url(auth, account_name, api_endpoint) ) + # Don't use context manager since this will be stored + # and used in a resulting connection + system_engine_connection = Connection( + system_engine_url, + database, + auth, + api_endpoint, + account_name, + None, + additional_parameters, + ) if not engine_name: - # Return system engine connection - return Connection( - system_engine_url, - database, - auth, - api_endpoint, - None, - additional_parameters, - ) + return system_engine_connection else: - # Don't use context manager since this will be stored - # and used in a resulting connection - system_engine_connection = Connection( - system_engine_url, - database, - auth, - api_endpoint, - None, - additional_parameters, - ) engine_url, status, attached_db = await _get_engine_url_status_db( system_engine_connection, engine_name ) @@ -261,6 +256,7 @@ async def connect( database, auth, api_endpoint, + account_name, system_engine_connection, additional_parameters, ) diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 1b53c428c7b..4c598cc06f0 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -21,6 +21,7 @@ from tricycle import RWLock from firebolt.async_db.util import is_engine_running +from firebolt.client import AsyncClient from firebolt.common._types import ( ColType, Column, @@ -47,7 +48,6 @@ if TYPE_CHECKING: from firebolt.async_db.connection import Connection -from httpx import AsyncClient as AsyncHttpxClient logger = logging.getLogger(__name__) @@ -102,7 +102,7 @@ class Cursor(BaseCursor): def __init__( self, *args: Any, - client: AsyncHttpxClient, + client: AsyncClient, connection: Connection, **kwargs: Any, ) -> None: @@ -156,8 +156,10 @@ async def _api_request( parameters = {**(self._set_parameters or {}), **(parameters or {})} if self.connection.database: parameters["database"] = self.connection.database + if self.connection._is_system: + parameters["account_id"] = await self._client.account_id return await self._client.request( - url=f"/{path}", + url=f"/{path}" if path else "", method="POST", params=parameters, content=query, diff --git a/src/firebolt/async_db/util.py b/src/firebolt/async_db/util.py index eefc72bf45a..76a97201bab 100644 --- a/src/firebolt/async_db/util.py +++ b/src/firebolt/async_db/util.py @@ -8,7 +8,7 @@ from firebolt.client.auth import Auth from firebolt.common.settings import DEFAULT_TIMEOUT_SECONDS from firebolt.utils.exception import InterfaceError -from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME +from firebolt.utils.urls import DYNAMIC_QUERY, GATEWAY_HOST_BY_ACCOUNT_NAME if TYPE_CHECKING: from firebolt.async_db.connection import Connection @@ -55,9 +55,9 @@ async def _get_system_engine_url( if response.status_code != codes.OK: raise InterfaceError( f"Unable to retrieve system engine endpoint {url}: " - f"{response.status} {response.content}" + f"{response.status_code} {response.content}" ) - return response.json()["gatewayHost"] + return response.json()["engineUrl"] + DYNAMIC_QUERY async def _get_engine_url_status_db( @@ -66,7 +66,7 @@ async def _get_engine_url_status_db( cursor = system_engine.cursor() await cursor.execute( """ - SELECT engine_url, attached_to, status FROM information_schema.engines + SELECT url, attached_to, status FROM information_schema.engines WHERE engine_name=? """, [engine_name], diff --git a/src/firebolt/client/client.py b/src/firebolt/client/client.py index 2c52c155842..7e4d0d71316 100644 --- a/src/firebolt/client/client.py +++ b/src/firebolt/client/client.py @@ -12,7 +12,7 @@ from firebolt.client.auth.base import AuthRequest from firebolt.client.constants import DEFAULT_API_URL from firebolt.utils.exception import AccountNotFoundError -from firebolt.utils.urls import ACCOUNT_BY_NAME_URL, ACCOUNT_URL +from firebolt.utils.urls import ACCOUNT_BY_NAME_URL from firebolt.utils.util import ( cached_property, fix_url_schema, @@ -39,9 +39,9 @@ class FireboltClientMixin(FireboltClientMixinBase): def __init__( self, *args: Any, - account_name: Optional[str] = None, + account_name: str, + auth: Auth, api_endpoint: str = DEFAULT_API_URL, - auth: Optional[Auth] = None, **kwargs: Any, ): self.account_name = account_name @@ -49,7 +49,7 @@ def __init__( self._auth_endpoint = get_auth_endpoint(self._api_endpoint) super().__init__(*args, auth=auth, **kwargs) - def _build_auth(self, auth: Optional[AuthTypes]) -> Optional[Auth]: + def _build_auth(self, auth: Optional[AuthTypes]) -> Auth: """Create Auth object based on auth provided. Overrides ``httpx.Client._build_auth`` @@ -65,6 +65,7 @@ def _build_auth(self, auth: Optional[AuthTypes]) -> Optional[Auth]: """ if not (auth is None or isinstance(auth, Auth)): raise TypeError(f'Invalid "auth" argument: {auth!r}') + assert auth is not None # type check return auth def _merge_auth_request(self, request: Request) -> Request: @@ -73,6 +74,10 @@ def _merge_auth_request(self, request: Request) -> Request: request._prepare(dict(request.headers)) return request + def _enforce_trailing_slash(self, url: URL) -> URL: + """Don't automatically append trailing slach to a base url""" + return url + class Client(FireboltClientMixin, HttpxClient): """An HTTP client, based on httpx.Client. @@ -95,18 +100,16 @@ def account_id(self) -> str: Raises: AccountNotFoundError: No account found with provided name """ - if self.account_name: - response = self.get( - url=ACCOUNT_BY_NAME_URL, params={"account_name": self.account_name} + response = self.get( + url=self._api_endpoint.copy_with( + path=ACCOUNT_BY_NAME_URL.format(account_name=self.account_name) ) - if response.status_code == HttpxCodes.NOT_FOUND: - raise AccountNotFoundError(self.account_name) - # process all other status codes - response.raise_for_status() - return response.json()["account_id"] - - # account_name isn't set, use the default account. - return self.get(url=ACCOUNT_URL).json()["account"]["id"] + ) + if response.status_code == HttpxCodes.NOT_FOUND: + raise AccountNotFoundError(self.account_name) + # process all other status codes + response.raise_for_status() + return response.json()["id"] def _send_handling_redirects( self, request: Request, *args: Any, **kwargs: Any @@ -137,18 +140,16 @@ async def account_id(self) -> str: Raises: AccountNotFoundError: No account found with provided name """ - if self.account_name: - response = await self.get( - url=ACCOUNT_BY_NAME_URL, params={"account_name": self.account_name} + response = await self.get( + url=self._api_endpoint.copy_with( + path=ACCOUNT_BY_NAME_URL.format(account_name=self.account_name) ) - if response.status_code == HttpxCodes.NOT_FOUND: - raise AccountNotFoundError(self.account_name) - # process all other status codes - response.raise_for_status() - return response.json()["account_id"] - - # account_name isn't set; use the default account. - return (await self.get(url=ACCOUNT_URL)).json()["account"]["id"] + ) + if response.status_code == HttpxCodes.NOT_FOUND: + raise AccountNotFoundError(self.account_name) + # process all other status codes + response.raise_for_status() + return response.json()["id"] async def _send_handling_redirects( self, request: Request, *args: Any, **kwargs: Any diff --git a/src/firebolt/common/settings.py b/src/firebolt/common/settings.py index c271d352f84..2c0c3ead742 100644 --- a/src/firebolt/common/settings.py +++ b/src/firebolt/common/settings.py @@ -1,7 +1,7 @@ import logging import os from dataclasses import dataclass, field -from typing import Any, Callable, Optional +from typing import Any, Callable from firebolt.client.auth import Auth, ClientCredentials @@ -52,6 +52,6 @@ class Settings: auth: Auth = field(default_factory=auth_from_env) - account_name: Optional[str] = field(default_factory=from_env(ACCOUNT_ENV)) + account_name: str = field(default_factory=from_env(ACCOUNT_ENV)) server: str = field(default_factory=from_env(SERVER_ENV)) default_region: str = field(default_factory=from_env(DEFAULT_REGION_ENV)) diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index e1f59678b09..7c16331c4e2 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -69,7 +69,7 @@ def _resolve_engine_url( engine_name: str, auth: Auth, api_endpoint: str, - account_name: Optional[str] = None, + account_name: str, ) -> str: with Client( auth=auth, @@ -113,7 +113,7 @@ def _get_database_default_engine_url( database: str, auth: Auth, api_endpoint: str, - account_name: Optional[str] = None, + account_name: str, ) -> str: with Client( auth=auth, @@ -215,6 +215,7 @@ def __init__( engine_url: str, database: str, auth: Auth, + account_name: str, api_endpoint: str = DEFAULT_API_URL, additional_parameters: Dict[str, Any] = {}, ): @@ -228,6 +229,7 @@ def __init__( user_drivers = additional_parameters.get("user_drivers", []) user_clients = additional_parameters.get("user_clients", []) self._client = Client( + account_name=account_name, auth=auth, base_url=engine_url, api_endpoint=api_endpoint, @@ -290,13 +292,13 @@ def __del__(self) -> None: def connect( database: str = None, + account_name: str = None, username: Optional[str] = None, password: Optional[str] = None, access_token: Optional[str] = None, auth: Auth = None, engine_name: Optional[str] = None, engine_url: Optional[str] = None, - account_name: Optional[str] = None, api_endpoint: str = DEFAULT_API_URL, use_token_cache: bool = True, additional_parameters: Dict[str, Any] = {}, @@ -329,6 +331,8 @@ def connect( # PEP 249 recommends making them kwargs. if not database: raise ConfigurationError("database name is required to connect.") + if not account_name: + raise ConfigurationError("account_name is required to connect.") validate_engine_name_and_url(engine_name, engine_url) @@ -373,4 +377,6 @@ def connect( assert engine_url is not None engine_url = fix_url_schema(engine_url) - return Connection(engine_url, database, auth, api_endpoint, additional_parameters) + return Connection( + engine_url, database, auth, account_name, api_endpoint, additional_parameters + ) diff --git a/src/firebolt/utils/urls.py b/src/firebolt/utils/urls.py index 4987b240576..f68f1943946 100644 --- a/src/firebolt/utils/urls.py +++ b/src/firebolt/utils/urls.py @@ -6,7 +6,7 @@ ENGINES_BY_IDS_URL = "/core/v1/engines:getByIds" ACCOUNT_URL = "/iam/v2/account" -ACCOUNT_BY_NAME_URL = "/iam/v2/accounts:getIdByName" +ACCOUNT_BY_NAME_URL = "/web/v3/account/{account_name}/resolve" ACCOUNT_ENGINE_URL = "/core/v1/accounts/{account_id}/engines/{engine_id}" ACCOUNT_ENGINE_START_URL = ACCOUNT_ENGINE_URL + ":start" @@ -29,4 +29,5 @@ PROVIDERS_URL = "/compute/v1/providers" REGIONS_URL = "/compute/v1/regions" -GATEWAY_HOST_BY_ACCOUNT_NAME = "/v3/getGatewayHostByAccountName/{account_name}" +GATEWAY_HOST_BY_ACCOUNT_NAME = "/web/v3/account/{account_name}/engineUrl" +DYNAMIC_QUERY = "/dynamic/query" diff --git a/tests/unit/db_conftest.py b/tests/unit/db_conftest.py index b27eca33a58..6ed5103592e 100644 --- a/tests/unit/db_conftest.py +++ b/tests/unit/db_conftest.py @@ -463,7 +463,7 @@ def inner( return Response( status_code=codes.OK, - json={"gatewayHost": system_engine_url}, + json={"engineUrl": system_engine_url}, ) return inner From 92f0b30e93b3213e49886971ca2781eeb479fd6b Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 31 May 2023 11:34:19 +0300 Subject: [PATCH 21/28] fix async tests --- setup.cfg | 1 + src/firebolt/async_db/util.py | 10 +- src/firebolt/client/client.py | 3 +- .../dbapi/async/test_errors_async.py | 2 +- .../dbapi/async/test_queries_async.py | 91 +++++++++++++------ tests/integration/dbapi/conftest.py | 43 +++++++-- tests/unit/async_db/test_connection.py | 7 ++ tests/unit/client/test_client.py | 16 ++-- tests/unit/client/test_client_async.py | 19 ++-- tests/unit/conftest.py | 22 ++--- tests/unit/db_conftest.py | 16 +++- tests/unit/service/test_resource_manager.py | 3 + tests/unit/util.py | 11 ++- 13 files changed, 167 insertions(+), 77 deletions(-) diff --git a/setup.cfg b/setup.cfg index 1f372f432bd..17fdbd0a149 100755 --- a/setup.cfg +++ b/setup.cfg @@ -26,6 +26,7 @@ install_requires = aiorwlock==1.1.0 appdirs>=1.4.4 appdirs-stubs>=0.1.0 + async-property>=0.2.1 cryptography>=3.4.0 httpx[http2]==0.24.0 pydantic[dotenv]>=1.8.2 diff --git a/src/firebolt/async_db/util.py b/src/firebolt/async_db/util.py index 76a97201bab..7cfa07b73c4 100644 --- a/src/firebolt/async_db/util.py +++ b/src/firebolt/async_db/util.py @@ -7,7 +7,11 @@ from firebolt.client import AsyncClient from firebolt.client.auth import Auth from firebolt.common.settings import DEFAULT_TIMEOUT_SECONDS -from firebolt.utils.exception import InterfaceError +from firebolt.utils.exception import ( + AccountNotFoundError, + FireboltEngineError, + InterfaceError, +) from firebolt.utils.urls import DYNAMIC_QUERY, GATEWAY_HOST_BY_ACCOUNT_NAME if TYPE_CHECKING: @@ -52,6 +56,8 @@ async def _get_system_engine_url( # return "https://api.us-east-1.dev.firebolt.io/dynamic" url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name) response = await client.get(url=url) + if response.status_code == codes.NOT_FOUND: + raise AccountNotFoundError(account_name) if response.status_code != codes.OK: raise InterfaceError( f"Unable to retrieve system engine endpoint {url}: " @@ -73,6 +79,6 @@ async def _get_engine_url_status_db( ) row = await cursor.fetchone() if row is None: - raise InterfaceError(f"Engine with name {engine_name} doesn't exist") + raise FireboltEngineError(f"Engine with name {engine_name} doesn't exist") engine_url, database, status = row return str(engine_url), str(status), str(database) # Mypy check diff --git a/src/firebolt/client/client.py b/src/firebolt/client/client.py index 7e4d0d71316..c341e82cc01 100644 --- a/src/firebolt/client/client.py +++ b/src/firebolt/client/client.py @@ -1,6 +1,7 @@ from typing import Any, Optional from anyio._core._eventloop import get_asynclib +from async_property import async_cached_property # type: ignore from httpx import URL from httpx import AsyncClient as HttpxAsyncClient from httpx import Client as HttpxClient @@ -127,7 +128,7 @@ class AsyncClient(FireboltClientMixin, HttpxAsyncClient): FireboltAuth instance. """ - @cached_property + @async_cached_property async def account_id(self) -> str: """User account id. diff --git a/tests/integration/dbapi/async/test_errors_async.py b/tests/integration/dbapi/async/test_errors_async.py index 67d64d70e00..bd5e6f40842 100644 --- a/tests/integration/dbapi/async/test_errors_async.py +++ b/tests/integration/dbapi/async/test_errors_async.py @@ -36,7 +36,7 @@ async def test_invalid_account( async def test_engine_name_not_exists( engine_name: str, database_name: str, - password_auth: ClientCredentials, + auth: ClientCredentials, account_name: str, api_endpoint: str, ) -> None: diff --git a/tests/integration/dbapi/async/test_queries_async.py b/tests/integration/dbapi/async/test_queries_async.py index d671ae6f029..0eb5cc26512 100644 --- a/tests/integration/dbapi/async/test_queries_async.py +++ b/tests/integration/dbapi/async/test_queries_async.py @@ -98,8 +98,8 @@ async def status_loop( ), f"Failed {query}. Got {status} rather than {final_status}." -async def test_connect_engine_name( - connection_engine_name: Connection, +async def test_connect_no_db( + connection_no_db: Connection, all_types_query: str, all_types_query_description: List[Column], all_types_query_response: List[ColType], @@ -107,24 +107,7 @@ async def test_connect_engine_name( ) -> None: """Connecting with engine name is handled properly.""" await test_select( - connection_engine_name, - all_types_query, - all_types_query_description, - all_types_query_response, - timezone_name, - ) - - -async def test_connect_no_engine( - connection_no_engine: Connection, - all_types_query: str, - all_types_query_description: List[Column], - all_types_query_response: List[ColType], - timezone_name: str, -) -> None: - """Connecting with engine name is handled properly.""" - await test_select( - connection_no_engine, + connection_no_db, all_types_query, all_types_query_description, all_types_query_response, @@ -198,20 +181,18 @@ async def test_long_query( async def test_drop_create(connection: Connection) -> None: """Create and drop table/index queries are handled properly.""" - async def test_query(c: Cursor, query: str) -> None: + async def test_query(c: Cursor, query: str, empty_response=True) -> None: await c.execute(query) assert c.description == None - assert c.rowcount == -1 + assert c.rowcount == (-1 if empty_response else 0) """Create table query is handled properly""" with connection.cursor() as c: # Cleanup - await c.execute("DROP JOIN INDEX IF EXISTS test_drop_create_async_db_join_idx") - await c.execute( - "DROP AGGREGATING INDEX IF EXISTS test_drop_create_async_db_agg_idx" - ) - await c.execute("DROP TABLE IF EXISTS test_drop_create_async_tb") - await c.execute("DROP TABLE IF EXISTS test_drop_create_async_tb_dim") + await c.execute("DROP JOIN INDEX IF EXISTS test_db_join_idx") + await c.execute("DROP AGGREGATING INDEX IF EXISTS test_db_agg_idx") + await c.execute("DROP TABLE IF EXISTS test_drop_create_async") + await c.execute("DROP TABLE IF EXISTS test_drop_create_async_dim") # Fact table await test_query( @@ -239,6 +220,7 @@ async def test_query(c: Cursor, query: str) -> None: c, "CREATE AGGREGATING INDEX test_db_agg_idx ON " "test_drop_create_async(id, sum(f), count(dt))", + empty_response=False, ) # Drop join index @@ -506,3 +488,56 @@ async def test_bytea_roundtrip( assert ( bytes_data.decode("utf-8") == data ), "Invalid bytea data returned after roundtrip" + + +async def test_system_engine( + connection_system_engine: Connection, + all_types_query: str, + all_types_query_description: List[Column], + all_types_query_system_engine_response: List[ColType], + timezone_name: str, +) -> None: + """Connecting with engine name is handled properly.""" + with connection_system_engine.cursor() as c: + assert await c.execute(all_types_query) == 1, "Invalid row count returned" + assert c.rowcount == 1, "Invalid rowcount value" + data = await c.fetchall() + assert len(data) == c.rowcount, "Invalid data length" + assert_deep_eq(data, all_types_query_system_engine_response, "Invalid data") + assert c.description == all_types_query_description, "Invalid description value" + assert len(data[0]) == len(c.description), "Invalid description length" + assert len(await c.fetchall()) == 0, "Redundant data returned by fetchall" + + # Different fetch types + await c.execute(all_types_query) + assert ( + await c.fetchone() == all_types_query_system_engine_response[0] + ), "Invalid fetchone data" + assert await c.fetchone() is None, "Redundant data returned by fetchone" + + await c.execute(all_types_query) + assert len(await c.fetchmany(0)) == 0, "Invalid data size returned by fetchmany" + data = await c.fetchmany() + assert len(data) == 1, "Invalid data size returned by fetchmany" + assert_deep_eq( + data, + all_types_query_system_engine_response, + "Invalid data returned by fetchmany", + ) + + +async def test_system_engine_no_db( + connection_system_engine_no_db: Connection, + all_types_query: str, + all_types_query_description: List[Column], + all_types_query_system_engine_response: List[ColType], + timezone_name: str, +) -> None: + """Connecting with engine name is handled properly.""" + await test_system_engine( + connection_system_engine_no_db, + all_types_query, + all_types_query_description, + all_types_query_system_engine_response, + timezone_name, + ) diff --git a/tests/integration/dbapi/conftest.py b/tests/integration/dbapi/conftest.py index b61bedf4077..7495af88b94 100644 --- a/tests/integration/dbapi/conftest.py +++ b/tests/integration/dbapi/conftest.py @@ -66,16 +66,14 @@ def all_types_query() -> str: "1.2345678901234 as float64, " "'text' as \"string\", " "CAST('2021-03-28' AS DATE) as \"date\", " - "CAST('1860-03-04' AS DATE_EXT) as \"date32\"," "pgdate '0001-01-01' as \"pgdate\", " "CAST('2019-07-31 01:01:01' AS DATETIME) as \"datetime\", " - "CAST('2019-07-31 01:01:01.1234' AS TIMESTAMP_EXT(4)) as \"datetime64\", " "CAST('1111-01-05 17:04:42.123456' as timestampntz) as \"timestampntz\", " "'1111-01-05 17:04:42.123456'::timestamptz as \"timestamptz\", " 'true as "boolean", ' "[1,2,3,4] as \"array\", cast('1231232.123459999990457054844258706536' as " 'decimal(38,30)) as "decimal", ' - 'cast(null as int) as "nullable", ' + 'null as "nullable", ' "'abc123'::bytea as \"bytea\"" ) @@ -95,16 +93,14 @@ def all_types_query_description() -> List[Column]: Column("float64", float, None, None, None, None, None), Column("string", str, None, None, None, None, None), Column("date", date, None, None, None, None, None), - Column("date32", date, None, None, None, None, None), Column("pgdate", date, None, None, None, None, None), Column("datetime", datetime, None, None, None, None, None), - Column("datetime64", datetime, None, None, None, None, None), Column("timestampntz", datetime, None, None, None, None, None), Column("timestamptz", datetime, None, None, None, None, None), Column("boolean", bool, None, None, None, None, None), Column("array", ARRAY(int), None, None, None, None, None), Column("decimal", DECIMAL(38, 30), None, None, None, None, None), - Column("nullable", int, None, None, None, None, None), + Column("nullable", str, None, None, None, None, None), Column("bytea", bytes, None, None, None, None, None), ] @@ -122,13 +118,11 @@ def all_types_query_response(timezone_offset_seconds: int) -> List[ColType]: 30000000000, -30000000000, 1.23, - 1.23456789012, + 1.2345678901234, "text", date(2021, 3, 28), - date(1860, 3, 4), date(1, 1, 1), datetime(2019, 7, 31, 1, 1, 1), - datetime(2019, 7, 31, 1, 1, 1, 123400), datetime(1111, 1, 5, 17, 4, 42, 123456), datetime( 1111, @@ -149,6 +143,37 @@ def all_types_query_response(timezone_offset_seconds: int) -> List[ColType]: ] +@fixture +def all_types_query_system_engine_response( + timezone_offset_seconds: int, +) -> List[ColType]: + return [ + [ + 1, + -1, + 257, + -257, + 80000, + -80000, + 30000000000, + -30000000000, + 1.23, + 1.23456789012, + "text", + date(2021, 3, 28), + date(1, 1, 1), + datetime(2019, 7, 31, 1, 1, 1), + datetime(1111, 1, 5, 17, 4, 42, 123456), + datetime(1111, 1, 5, 17, 4, 42, 123456, tzinfo=timezone.utc), + True, + [1, 2, 3, 4], + Decimal("1231232.123459999990457054844258706536"), + None, + b"abc123", + ] + ] + + @fixture def timezone_name() -> str: return "Asia/Calcutta" diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index 8d771c1e371..9b047305e80 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -90,11 +90,14 @@ async def test_connect_engine_name( get_system_engine_url: str, get_system_engine_callback: Callable, get_engine_url_callback: Callable, + account_id_url: str, + account_id_callback: Callable, ): """connect properly handles engine_name""" httpx_mock.add_callback(check_credentials_callback, url=auth_url) httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) mock_query() @@ -139,10 +142,13 @@ async def test_connect_database( system_engine_no_db_query_url: str, get_system_engine_url: str, get_system_engine_callback: Callable, + account_id_url: str, + account_id_callback: Callable, ): httpx_mock.add_callback(check_credentials_callback, url=auth_url) httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) httpx_mock.add_callback(query_callback, url=system_engine_no_db_query_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) async with await connect( database=None, auth=auth, @@ -154,6 +160,7 @@ async def test_connect_database( httpx_mock.reset(True) httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) httpx_mock.add_callback(query_callback, url=system_engine_query_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) async with await connect( database=db_name, diff --git a/tests/unit/client/test_client.py b/tests/unit/client/test_client.py index f365ad4b781..0db63da2dd2 100644 --- a/tests/unit/client/test_client.py +++ b/tests/unit/client/test_client.py @@ -18,13 +18,14 @@ def test_client_retry( httpx_mock: HTTPXMock, auth: Auth, + account_name: str, access_token: str, ): """ Client retries with new auth token if first attempt fails with unauthorized error. """ - with Client(auth=auth) as client: + with Client(account_name=account_name, auth=auth) as client: # auth get token httpx_mock.add_response( @@ -58,6 +59,7 @@ def test_client_different_auths( check_credentials_callback: Callable, check_token_callback: Callable, auth: Auth, + account_name: str, auth_server: str, server: str, ): @@ -75,14 +77,10 @@ def test_client_different_auths( httpx_mock.add_callback(check_token_callback, url="https://url") - Client(auth=auth, api_endpoint=server).get("https://url") - - # client accepts None auth, but authorization fails - with raises(AssertionError) as excinfo: - Client(auth=None).get("https://url") + Client(account_name=account_name, auth=auth, api_endpoint=server).get("https://url") with raises(TypeError) as excinfo: - Client(auth=lambda r: r).get("https://url") + Client(account_name=account_name, auth=lambda r: r).get("https://url") assert str(excinfo.value).startswith( 'Invalid "auth" argument' @@ -92,6 +90,7 @@ def test_client_different_auths( def test_client_account_id( httpx_mock: HTTPXMock, auth: Auth, + account_name: str, account_id: str, account_id_url: Pattern, account_id_callback: Callable, @@ -103,6 +102,7 @@ def test_client_account_id( httpx_mock.add_callback(auth_callback, url=auth_url) with Client( + account_name=account_name, auth=auth, base_url=fix_url_schema(settings.server), api_endpoint=settings.server, @@ -114,6 +114,7 @@ def test_client_account_id( def test_refresh_with_hooks( fs: FakeFilesystem, httpx_mock: HTTPXMock, + account_name: str, client_id: str, client_secret: str, access_token: str, @@ -126,6 +127,7 @@ def test_refresh_with_hooks( tss.cache_token(access_token, 2**32) client = Client( + account_name=account_name, auth=ClientCredentials(client_id, client_secret), event_hooks={ "response": [raise_on_4xx_5xx], diff --git a/tests/unit/client/test_client_async.py b/tests/unit/client/test_client_async.py index 9792c97351e..acad0c4a286 100644 --- a/tests/unit/client/test_client_async.py +++ b/tests/unit/client/test_client_async.py @@ -15,13 +15,14 @@ async def test_client_retry( httpx_mock: HTTPXMock, auth: Auth, + account_name: str, access_token: str, ): """ Client retries with new auth token if first attempt fails with unauthorized error. """ - async with AsyncClient(auth=auth) as client: + async with AsyncClient(account_name=account_name, auth=auth) as client: # auth get token httpx_mock.add_response( status_code=codes.OK, @@ -54,6 +55,7 @@ async def test_client_different_auths( check_credentials_callback: Callable, check_token_callback: Callable, auth: Auth, + account_name: str, auth_server: str, server: str, ): @@ -72,16 +74,15 @@ async def test_client_different_auths( httpx_mock.add_callback(check_token_callback, url="https://url") - async with AsyncClient(auth=auth, api_endpoint=server) as client: + async with AsyncClient( + account_name=account_name, auth=auth, api_endpoint=server + ) as client: await client.get("https://url") - # client accepts None auth, but authorization fails - with raises(AssertionError) as excinfo: - async with AsyncClient(auth=None, api_endpoint=server) as client: - await client.get("https://url") - with raises(TypeError) as excinfo: - async with AsyncClient(auth=lambda r: r, api_endpoint=server): + async with AsyncClient( + account_name=account_name, auth=lambda r: r, api_endpoint=server + ): await client.get("https://url") assert str(excinfo.value).startswith( @@ -92,6 +93,7 @@ async def test_client_different_auths( async def test_client_account_id( httpx_mock: HTTPXMock, auth: Auth, + account_name: str, account_id: str, account_id_url: Pattern, account_id_callback: Callable, @@ -103,6 +105,7 @@ async def test_client_account_id( httpx_mock.add_callback(auth_callback, url=auth_url) async with AsyncClient( + account_name=account_name, auth=auth, base_url=fix_url_schema(settings.server), api_endpoint=settings.server, diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 3657daf1115..9910471dc29 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -27,7 +27,6 @@ ACCOUNT_BY_NAME_URL, ACCOUNT_DATABASE_BY_NAME_URL, ACCOUNT_ENGINE_URL, - ACCOUNT_URL, AUTH_SERVICE_ACCOUNT_URL, DATABASES_URL, ENGINES_URL, @@ -175,12 +174,12 @@ def db_description() -> str: @fixture -def account_id_url(settings: Settings) -> Pattern: - base = f"https://{settings.server}{ACCOUNT_BY_NAME_URL}?account_name=" - default_base = f"https://{settings.server}{ACCOUNT_URL}" +def account_id_url(settings: Settings, account_name: str) -> Pattern: + account_name_re = r"[^\\\\]*" + base = f"https://{settings.server}{ACCOUNT_BY_NAME_URL}" base = base.replace("/", "\\/").replace("?", "\\?") - default_base = default_base.replace("/", "\\/").replace("?", "\\?") - return compile(f"(?:{base}.*|{default_base})") + base = base.format(account_name=account_name_re) + return compile(base) @fixture @@ -192,14 +191,9 @@ def do_mock( request: Request, **kwargs, ) -> Response: - if "account_name" not in request.url.params: - return Response( - status_code=httpx.codes.OK, json={"account": {"id": account_id}} - ) - # In this case, an account_name *should* be specified. - if request.url.params["account_name"] != settings.account_name: - raise AccountNotFoundError(request.url.params["account_name"]) - return Response(status_code=httpx.codes.OK, json={"account_id": account_id}) + if request.url.path.split("/")[-2] != settings.account_name: + raise AccountNotFoundError(request.url.path.split("/")[-2]) + return Response(status_code=httpx.codes.OK, json={"id": account_id}) return do_mock diff --git a/tests/unit/db_conftest.py b/tests/unit/db_conftest.py index 6ed5103592e..1ad3064d55e 100644 --- a/tests/unit/db_conftest.py +++ b/tests/unit/db_conftest.py @@ -354,9 +354,10 @@ def _get_engine_url_callback(server: str, db_name: str, status="Running") -> Cal def do_query(request: Request, **kwargs) -> Response: set_parameters = request.url.params assert ( - len(set_parameters) == 2 + len(set_parameters) == 3 and "output_format" in set_parameters and "database" in set_parameters + and "account_id" in set_parameters ) data = [[server, db_name, status]] query_response = { @@ -435,13 +436,15 @@ def system_engine_url() -> str: @fixture -def system_engine_query_url(system_engine_url: str, db_name: str) -> str: - return f"{system_engine_url}/?output_format=JSON_Compact&database={db_name}" +def system_engine_query_url( + system_engine_url: str, db_name: str, account_id: str +) -> str: + return f"{system_engine_url}/dynamic/query?output_format=JSON_Compact&database={db_name}&account_id={account_id}" @fixture -def system_engine_no_db_query_url(system_engine_url: str) -> str: - return f"{system_engine_url}/?output_format=JSON_Compact" +def system_engine_no_db_query_url(system_engine_url: str, account_id: str) -> str: + return f"{system_engine_url}/dynamic/query?output_format=JSON_Compact&account_id={account_id}" @fixture @@ -478,11 +481,14 @@ def mock_connection_flow( get_system_engine_callback: Callable, system_engine_query_url: str, get_engine_url_callback: Callable, + account_id_url: str, + account_id_callback: Callable, ) -> Callable: def inner() -> None: httpx_mock.add_callback(check_credentials_callback, url=auth_url) httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) httpx_mock.add_callback(get_engine_url_callback, url=system_engine_query_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) return inner diff --git a/tests/unit/service/test_resource_manager.py b/tests/unit/service/test_resource_manager.py index f3a66d17487..9ddfeacb468 100644 --- a/tests/unit/service/test_resource_manager.py +++ b/tests/unit/service/test_resource_manager.py @@ -43,6 +43,7 @@ def test_rm_token_cache( check_credentials_callback: Callable, settings: Settings, auth_url: str, + account_name: str, account_id_url: Pattern, account_id_callback: Callable, provider_callback: Callable, @@ -59,6 +60,7 @@ def test_rm_token_cache( with Patcher(): local_settings = Settings( + account_name=account_name, auth=ClientCredentials( settings.auth.client_id, settings.auth.client_secret, @@ -76,6 +78,7 @@ def test_rm_token_cache( # Do the same, but with use_token_cache=False with Patcher(): local_settings = Settings( + account_name=account_name, auth=ClientCredentials( settings.auth.client_id, settings.auth.client_secret, diff --git a/tests/unit/util.py b/tests/unit/util.py index 2f5528a9e6a..98b72304fc1 100644 --- a/tests/unit/util.py +++ b/tests/unit/util.py @@ -3,6 +3,7 @@ from httpx import Request, Response from firebolt.client import AsyncClient, Client +from firebolt.client.auth import Auth from firebolt.model import FireboltBaseModel @@ -15,7 +16,10 @@ def execute_generator_requests( ) -> None: request = next(requests) - with Client(api_endpoint=api_endpoint) as client: + with Client( + account_name="account", auth=Auth(), api_endpoint=api_endpoint + ) as client: + client._auth = None while True: response = client.send(request) try: @@ -30,7 +34,10 @@ async def async_execute_generator_requests( ) -> None: request = await requests.__anext__() - async with AsyncClient(api_endpoint=api_endpoint) as client: + async with AsyncClient( + account_name="account", auth=Auth(), api_endpoint=api_endpoint + ) as client: + client._auth = None while True: response = await client.send(request) try: From bf8fd5eab7000a2613724fa63b68776cc8da3f65 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 31 May 2023 15:27:12 +0300 Subject: [PATCH 22/28] fix db unit tests --- src/firebolt/async_db/connection.py | 73 +++---- src/firebolt/async_db/cursor.py | 67 ++----- src/firebolt/async_db/util.py | 47 +++-- src/firebolt/common/base_connection.py | 8 +- src/firebolt/common/base_cursor.py | 14 ++ src/firebolt/db/connection.py | 255 +++++++------------------ src/firebolt/db/cursor.py | 50 ++--- src/firebolt/db/util.py | 106 +++++++--- tests/unit/db/test_connection.py | 24 ++- tests/unit/db_conftest.py | 4 +- 10 files changed, 292 insertions(+), 356 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 762ea63a10c..0be7f850e1b 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -98,7 +98,6 @@ class Connection(BaseConnection): """ - client_class: type __slots__ = ( "_client", "_cursors", @@ -114,11 +113,12 @@ def __init__( engine_url: str, database: Optional[str], auth: Auth, - api_endpoint: str, account_name: str, system_engine_connection: Optional["Connection"], + api_endpoint: str, additional_parameters: Dict[str, Any] = {}, ): + super().__init__() self.api_endpoint = api_endpoint self.engine_url = engine_url self.database = database @@ -138,12 +138,6 @@ def __init__( headers={"User-Agent": get_user_agent_header(user_drivers, user_clients)}, ) self._system_engine_connection = system_engine_connection - super().__init__() - - @property - def _is_system(self) -> bool: - """`True` if connection is a system engine connection; `False` otherwise.""" - return self._system_engine_connection is None def cursor(self, **kwargs: Any) -> Cursor: if self.closed: @@ -175,6 +169,9 @@ async def aclose(self) -> None: await self._client.aclose() self._is_closed = True + if self._system_engine_connection: + await self._system_engine_connection.aclose() + async def __aexit__( self, exc_type: type, exc_val: Exception, exc_tb: TracebackType ) -> None: @@ -189,10 +186,10 @@ async def connect( api_endpoint: str = DEFAULT_API_URL, additional_parameters: Dict[str, Any] = {}, ) -> Connection: - """Connect to Firebolt database. + """Connect to Firebolt. Args: - `auth` (Auth) Authentication object. + `auth` (Auth) Authentication object `database` (str): Name of the database to connect `engine_name` (Optional[str]): Name of the engine to connect to `account_name` (Optional[str]): For customers with multiple accounts; @@ -223,9 +220,9 @@ async def connect( system_engine_url, database, auth, - api_endpoint, account_name, None, + api_endpoint, additional_parameters, ) @@ -233,30 +230,34 @@ async def connect( return system_engine_connection else: - engine_url, status, attached_db = await _get_engine_url_status_db( - system_engine_connection, engine_name - ) - - if status != "Running": - raise InterfaceError(f"Engine {engine_name} is not running") + try: + engine_url, status, attached_db = await _get_engine_url_status_db( + system_engine_connection, engine_name + ) - if database is not None and database != attached_db: - raise InterfaceError( - f"Engine {engine_name} is not attached to {database}, " - f"but to {attached_db}" + if status != "Running": + raise InterfaceError(f"Engine {engine_name} is not running") + + if database is not None and database != attached_db: + raise InterfaceError( + f"Engine {engine_name} is not attached to {database}, " + f"but to {attached_db}" + ) + elif database is None: + database = attached_db + + assert engine_url is not None + + engine_url = fix_url_schema(engine_url) + return Connection( + engine_url, + database, + auth, + account_name, + system_engine_connection, + api_endpoint, + additional_parameters, ) - elif database is None: - database = attached_db - - assert engine_url is not None - - engine_url = fix_url_schema(engine_url) - return Connection( - engine_url, - database, - auth, - api_endpoint, - account_name, - system_engine_connection, - additional_parameters, - ) + except: # noqa + await system_engine_connection.aclose() + raise diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 4c598cc06f0..850039ac451 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -8,7 +8,6 @@ from typing import ( TYPE_CHECKING, Any, - Callable, Iterator, List, Optional, @@ -20,7 +19,7 @@ from httpx import Response, codes from tricycle import RWLock -from firebolt.async_db.util import is_engine_running +from firebolt.async_db.util import is_db_available, is_engine_running from firebolt.client import AsyncClient from firebolt.common._types import ( ColType, @@ -31,18 +30,20 @@ split_format_sql, ) from firebolt.common.base_cursor import ( + JSON_OUTPUT_FORMAT, BaseCursor, CursorState, QueryStatus, Statistics, + check_not_closed, + check_query_executed, ) from firebolt.utils.exception import ( AsyncExecutionUnavailableError, - CursorClosedError, EngineNotRunningError, + FireboltDatabaseError, OperationalError, ProgrammingError, - QueryNotRunError, ) if TYPE_CHECKING: @@ -52,36 +53,6 @@ logger = logging.getLogger(__name__) -JSON_OUTPUT_FORMAT = "JSON_Compact" - - -def check_not_closed(func: Callable) -> Callable: - """(Decorator) ensure cursor is not closed before calling method.""" - - @wraps(func) - def inner(self: Cursor, *args: Any, **kwargs: Any) -> Any: - if self.closed: - raise CursorClosedError(method_name=func.__name__) - return func(self, *args, **kwargs) - - return inner - - -def check_query_executed(func: Callable) -> Callable: - """ - (Decorator) ensure that some query has been executed before - calling cursor method. - """ - - @wraps(func) - def inner(self: Cursor, *args: Any, **kwargs: Any) -> Any: - if self._state == CursorState.NONE: - raise QueryNotRunError(method_name=func.__name__) - return func(self, *args, **kwargs) - - return inner - - class Cursor(BaseCursor): """ Executes async queries to Firebolt Database. @@ -118,6 +89,12 @@ async def _raise_if_error(self, resp: Response) -> None: f"Error executing query:\n{resp.read().decode('utf-8')}" ) if resp.status_code == codes.FORBIDDEN: + if self.connection.database and not await is_db_available( + self.connection, self.connection.database + ): + raise FireboltDatabaseError( + f"Database {self.connection.database} does not exist" + ) raise ProgrammingError(resp.read().decode("utf-8")) if ( resp.status_code == codes.SERVICE_UNAVAILABLE @@ -346,24 +323,6 @@ async def executemany( def __aiter__(self) -> Cursor: return self - def close(self) -> None: - """Terminate an ongoing query (if any) and mark connection as closed.""" - self._state = CursorState.CLOSED - self.connection._remove_cursor(self) - - def __del__(self) -> None: - self.close() - - # Context manager support - @check_not_closed - def __enter__(self) -> Cursor: - return self - - def __exit__( - self, exc_type: type, exc_val: Exception, exc_tb: TracebackType - ) -> None: - self.close() - # TODO: figure out how to implement __aenter__ and __await__ @check_not_closed def __aenter__(self) -> Cursor: @@ -448,3 +407,7 @@ async def fetchall(self) -> List[List[ColType]]: async def nextset(self) -> None: async with self._async_query_lock.read_locked(): return super().nextset() + + @check_not_closed + def __enter__(self) -> Cursor: + return self diff --git a/src/firebolt/async_db/util.py b/src/firebolt/async_db/util.py index 7cfa07b73c4..94cde980306 100644 --- a/src/firebolt/async_db/util.py +++ b/src/firebolt/async_db/util.py @@ -20,6 +20,27 @@ ENGINE_STATUS_RUNNING = "Running" +async def is_db_available(connection: Connection, database_name: str) -> bool: + """ + Verify that the database exists. + + Args: + connection (firebolt.async_db.connection.Connection) + database_name (str): Name of a database + """ + system_engine = connection._system_engine_connection or connection + with system_engine.cursor() as cursor: + return ( + await cursor.execute( + """ + SELECT 1 FROM information_schema.databases WHERE database_name=? + """, + [database_name], + ) + > 0 + ) + + async def is_engine_running(connection: Connection, engine_url: str) -> bool: """ Verify that the engine is running. @@ -69,16 +90,16 @@ async def _get_system_engine_url( async def _get_engine_url_status_db( system_engine: Connection, engine_name: str ) -> Tuple[str, str, str]: - cursor = system_engine.cursor() - await cursor.execute( - """ - SELECT url, attached_to, status FROM information_schema.engines - WHERE engine_name=? - """, - [engine_name], - ) - row = await cursor.fetchone() - if row is None: - raise FireboltEngineError(f"Engine with name {engine_name} doesn't exist") - engine_url, database, status = row - return str(engine_url), str(status), str(database) # Mypy check + with system_engine.cursor() as cursor: + await cursor.execute( + """ + SELECT url, attached_to, status FROM information_schema.engines + WHERE engine_name=? + """, + [engine_name], + ) + row = await cursor.fetchone() + if row is None: + raise FireboltEngineError(f"Engine with name {engine_name} doesn't exist") + engine_url, database, status = row + return str(engine_url), str(status), str(database) # Mypy check diff --git a/src/firebolt/common/base_connection.py b/src/firebolt/common/base_connection.py index 1d47374015c..f66993e52e1 100644 --- a/src/firebolt/common/base_connection.py +++ b/src/firebolt/common/base_connection.py @@ -1,5 +1,6 @@ from typing import Any, List +from firebolt.common.base_cursor import BaseCursor from firebolt.utils.exception import ConnectionClosedError @@ -8,13 +9,18 @@ def __init__(self) -> None: self._cursors: List[Any] = [] self._is_closed = False - def _remove_cursor(self, cursor: Any) -> None: + def _remove_cursor(self, cursor: BaseCursor) -> None: # This way it's atomic try: self._cursors.remove(cursor) except ValueError: pass + @property + def _is_system(self) -> bool: + """`True` if connection is a system engine connection; `False` otherwise.""" + return self._system_engine_connection is None # type: ignore + @property def closed(self) -> bool: """`True` if connection is closed; `False` otherwise.""" diff --git a/src/firebolt/common/base_cursor.py b/src/firebolt/common/base_cursor.py index eb288998116..b0925bfc001 100644 --- a/src/firebolt/common/base_cursor.py +++ b/src/firebolt/common/base_cursor.py @@ -3,6 +3,7 @@ import logging from enum import Enum from functools import wraps +from types import TracebackType from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from httpx import Response @@ -363,3 +364,16 @@ def setinputsizes(self, sizes: List[int]) -> None: @check_not_closed def setoutputsize(self, size: int, column: Optional[int] = None) -> None: """Set a column buffer size for fetches of large columns (does nothing).""" + + def close(self) -> None: + """Terminate an ongoing query (if any) and mark connection as closed.""" + self._state = CursorState.CLOSED + self.connection._remove_cursor(self) # type:ignore + + def __del__(self) -> None: + self.close() + + def __exit__( + self, exc_type: type, exc_val: Exception, exc_tb: TracebackType + ) -> None: + self.close() diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index 7c16331c4e2..6a62c255376 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -2,18 +2,17 @@ import logging import socket -from json import JSONDecodeError from types import TracebackType from typing import Any, Dict, List, Optional from warnings import warn from httpcore.backends.base import NetworkStream from httpcore.backends.sync import SyncBackend -from httpx import HTTPStatusError, HTTPTransport, RequestError, Timeout +from httpx import HTTPTransport, Timeout from readerwriterlock.rwlock import RWLockWrite from firebolt.client import DEFAULT_API_URL, Client -from firebolt.client.auth import Auth, ClientCredentials +from firebolt.client.auth import Auth from firebolt.common.base_connection import BaseConnection from firebolt.common.settings import ( DEFAULT_TIMEOUT_SECONDS, @@ -21,125 +20,18 @@ KEEPIDLE_RATE, ) from firebolt.db.cursor import Cursor +from firebolt.db.util import _get_engine_url_status_db, _get_system_engine_url from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, - FireboltEngineError, InterfaceError, ) -from firebolt.utils.urls import ( - ACCOUNT_ENGINE_ID_BY_NAME_URL, - ACCOUNT_ENGINE_URL, - ACCOUNT_ENGINE_URL_BY_DATABASE_NAME, -) from firebolt.utils.usage_tracker import get_user_agent_header from firebolt.utils.util import fix_url_schema logger = logging.getLogger(__name__) -def _get_auth( - username: Optional[str], - password: Optional[str], - access_token: Optional[str], - use_token_cache: bool, -) -> Auth: - """Create `Auth` class based on provided credentials. - - If `access_token` is provided, it's used for `Auth` creation. - Otherwise, username/password are used. - - Returns: - Auth: `auth object` - - Raises: - `ConfigurationError`: Invalid combination of credentials provided - - """ - return ClientCredentials(username, password) # type: ignore - - -def validate_engine_name_and_url( - engine_name: Optional[str], engine_url: Optional[str] -) -> None: - pass - - -def _resolve_engine_url( - engine_name: str, - auth: Auth, - api_endpoint: str, - account_name: str, -) -> str: - with Client( - auth=auth, - base_url=api_endpoint, - account_name=account_name, - api_endpoint=api_endpoint, - timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), - ) as client: - account_id = client.account_id - url = ACCOUNT_ENGINE_ID_BY_NAME_URL.format(account_id=account_id) - try: - response = client.get( - url=url, - params={"engine_name": engine_name}, - ) - response.raise_for_status() - engine_id = response.json()["engine_id"]["engine_id"] - url = ACCOUNT_ENGINE_URL.format(account_id=account_id, engine_id=engine_id) - response = client.get(url=url) - response.raise_for_status() - return response.json()["engine"]["endpoint"] - except HTTPStatusError as e: - # Engine error would be 404. - if e.response.status_code != 404: - raise InterfaceError( - f"Error {e.__class__.__name__}: Unable to retrieve engine " - f"endpoint {url}." - ) - # Once this is point is reached we've already authenticated with - # the backend so it's safe to assume the cause of the error is - # missing engine. - raise FireboltEngineError(f"Firebolt engine {engine_name} does not exist.") - except (JSONDecodeError, RequestError, RuntimeError) as e: - raise InterfaceError( - f"Error {e.__class__.__name__}: " - f"Unable to retrieve engine endpoint {url}." - ) - - -def _get_database_default_engine_url( - database: str, - auth: Auth, - api_endpoint: str, - account_name: str, -) -> str: - with Client( - auth=auth, - base_url=api_endpoint, - account_name=account_name, - api_endpoint=api_endpoint, - timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), - ) as client: - try: - account_id = client.account_id - response = client.get( - url=ACCOUNT_ENGINE_URL_BY_DATABASE_NAME.format(account_id=account_id), - params={"database_name": database}, - ) - response.raise_for_status() - return response.json()["engine_url"] - except ( - JSONDecodeError, - RequestError, - RuntimeError, - HTTPStatusError, - KeyError, - ) as e: - raise InterfaceError(f"Unable to retrieve default engine endpoint: {e}.") - - class OverriddenHttpBackend(SyncBackend): """ `OverriddenHttpBackend` is a short-term solution for the TCP @@ -199,7 +91,6 @@ class Connection(BaseConnection): are not implemented. """ - client_class: type __slots__ = ( "_client", "_cursors", @@ -208,17 +99,20 @@ class Connection(BaseConnection): "api_endpoint", "_is_closed", "_closing_lock", + "_system_engine_connection", ) def __init__( self, engine_url: str, - database: str, + database: Optional[str], auth: Auth, account_name: str, + system_engine_connection: Optional["Connection"], api_endpoint: str = DEFAULT_API_URL, additional_parameters: Dict[str, Any] = {}, ): + super().__init__() self.api_endpoint = api_endpoint self.engine_url = engine_url self.database = database @@ -237,10 +131,10 @@ def __init__( transport=transport, headers={"User-Agent": get_user_agent_header(user_drivers, user_clients)}, ) + self._system_engine_connection = system_engine_connection # Holding this lock for write means that connection is closing itself. # cursor() should hold this lock for read to read/write state self._closing_lock = RWLockWrite() - super().__init__() def cursor(self, **kwargs: Any) -> Cursor: if self.closed: @@ -251,13 +145,6 @@ def cursor(self, **kwargs: Any) -> Cursor: self._cursors.append(c) return c - def _remove_cursor(self, cursor: Cursor) -> None: - # This way it's atomic - try: - self._cursors.remove(cursor) - except ValueError: - pass - def close(self) -> None: if self.closed: return @@ -274,6 +161,9 @@ def close(self) -> None: self._client.close() self._is_closed = True + if self._system_engine_connection: + self._system_engine_connection.close() + # Context manager support def __enter__(self) -> Connection: if self.closed: @@ -287,96 +177,89 @@ def __exit__( def __del__(self) -> None: if not self.closed: - warn(f"Unclosed {self!r}", UserWarning) + warn(f"Unclosed {self!r} {id(self)}", UserWarning) def connect( - database: str = None, - account_name: str = None, - username: Optional[str] = None, - password: Optional[str] = None, - access_token: Optional[str] = None, - auth: Auth = None, + auth: Optional[Auth] = None, + account_name: Optional[str] = None, + database: Optional[str] = None, engine_name: Optional[str] = None, - engine_url: Optional[str] = None, api_endpoint: str = DEFAULT_API_URL, - use_token_cache: bool = True, additional_parameters: Dict[str, Any] = {}, ) -> Connection: - """Connect to Firebolt database. + """Connect to Firebolt. Args: + `auth` (Auth) Authentication object `database` (str): Name of the database to connect - `username` (Optional[str]): User name to use for authentication (Deprecated) - `password` (Optional[str]): Password to use for authentication (Deprecated) - `access_token` (Optional[str]): Authentication token to use instead of - credentials (Deprecated) - `auth` (Auth)L Authentication object. `engine_name` (Optional[str]): Name of the engine to connect to - `engine_url` (Optional[str]): The engine endpoint to use `account_name` (Optional[str]): For customers with multiple accounts; if none, default is used `api_endpoint` (str): Firebolt API endpoint. Used for authentication - `use_token_cache` (bool): Cached authentication token in filesystem - Default: True `additional_parameters` (Optional[Dict]): Dictionary of less widely-used arguments for connection - Note: - Providing both `engine_name` and `engine_url` will result in an error - """ # These parameters are optional in function signature # but are required to connect. # PEP 249 recommends making them kwargs. - if not database: - raise ConfigurationError("database name is required to connect.") - if not account_name: - raise ConfigurationError("account_name is required to connect.") + for name, value in (("auth", auth), ("account_name", account_name)): + if not value: + raise ConfigurationError(f"{name} is required to connect.") - validate_engine_name_and_url(engine_name, engine_url) + # Type checks + assert auth is not None + assert account_name is not None - if not auth: - if any([username, password, access_token, api_endpoint, use_token_cache]): - auth = _get_auth(username, password, access_token, use_token_cache) - else: - raise ConfigurationError("No authentication provided.") api_endpoint = fix_url_schema(api_endpoint) - # Mypy checks, this should never happen - assert database is not None - - if not engine_name and not engine_url: - engine_url = _get_database_default_engine_url( - database=database, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - ) + system_engine_url = fix_url_schema( + _get_system_engine_url(auth, account_name, api_endpoint) + ) + # Don't use context manager since this will be stored + # and used in a resulting connection + system_engine_connection = Connection( + system_engine_url, + database, + auth, + account_name, + None, + api_endpoint, + additional_parameters, + ) + if not engine_name: + return system_engine_connection - elif engine_name: - engine_url = _resolve_engine_url( - engine_name=engine_name, - auth=auth, - account_name=account_name, - api_endpoint=api_endpoint, - ) - elif account_name: - # In above if branches account name is validated since it's used to - # resolve or get an engine url. - # We need to manually validate account_name if none of the above - # cases are triggered. - with Client( - auth=auth, - base_url=api_endpoint, - account_name=account_name, - api_endpoint=api_endpoint, - ) as client: - client.account_id + else: + try: + engine_url, status, attached_db = _get_engine_url_status_db( + system_engine_connection, engine_name + ) - assert engine_url is not None + if status != "Running": + raise InterfaceError(f"Engine {engine_name} is not running") - engine_url = fix_url_schema(engine_url) - return Connection( - engine_url, database, auth, account_name, api_endpoint, additional_parameters - ) + if database is not None and database != attached_db: + raise InterfaceError( + f"Engine {engine_name} is not attached to {database}, " + f"but to {attached_db}" + ) + elif database is None: + database = attached_db + + assert engine_url is not None + + engine_url = fix_url_schema(engine_url) + return Connection( + engine_url, + database, + auth, + account_name, + system_engine_connection, + api_endpoint, + additional_parameters, + ) + except: # noqa + system_engine_connection.close() + raise diff --git a/src/firebolt/db/cursor.py b/src/firebolt/db/cursor.py index 69c84a81115..c516d1e8e25 100644 --- a/src/firebolt/db/cursor.py +++ b/src/firebolt/db/cursor.py @@ -4,7 +4,6 @@ import re import time from threading import Lock -from types import TracebackType from typing import ( TYPE_CHECKING, Any, @@ -16,10 +15,10 @@ Union, ) -from httpx import Client as HttpxClient from httpx import Response, codes from readerwriterlock.rwlock import RWLockWrite +from firebolt.client import Client from firebolt.common._types import ( ColType, Column, @@ -72,7 +71,7 @@ class Cursor(BaseCursor): ) def __init__( - self, *args: Any, client: HttpxClient, connection: Connection, **kwargs: Any + self, *args: Any, client: Client, connection: Connection, **kwargs: Any ) -> None: super().__init__(*args, **kwargs) self._query_lock = RWLockWrite() @@ -87,7 +86,9 @@ def _raise_if_error(self, resp: Response) -> None: f"Error executing query:\n{resp.read().decode('utf-8')}" ) if resp.status_code == codes.FORBIDDEN: - if not is_db_available(self.connection, self.connection.database): + if self.connection.database and not is_db_available( + self.connection, self.connection.database + ): raise FireboltDatabaseError( f"Database {self.connection.database} does not exist" ) @@ -105,10 +106,10 @@ def _raise_if_error(self, resp: Response) -> None: def _api_request( self, - query: Optional[str] = "", - parameters: Optional[dict[str, Any]] = {}, - path: Optional[str] = "", - use_set_parameters: Optional[bool] = True, + query: str = "", + parameters: dict[str, Any] = {}, + path: str = "", + use_set_parameters: bool = True, ) -> Response: """ Query API, return Response object. @@ -127,13 +128,14 @@ def _api_request( """ if use_set_parameters: parameters = {**(self._set_parameters or {}), **(parameters or {})} + if self.connection.database: + parameters["database"] = self.connection.database + if self.connection._is_system: + parameters["account_id"] = self._client.account_id return self._client.request( - url=f"/{path}", + url=f"/{path}" if path else "", method="POST", - params={ - "database": self.connection.database, - **(parameters or dict()), - }, + params=parameters, content=query, ) @@ -341,24 +343,6 @@ def get_status(self, query_id: str) -> QueryStatus: return QueryStatus.NOT_READY return QueryStatus[resp_json["status"]] - def close(self) -> None: - """Terminate an ongoing query (if any) and mark connection as closed.""" - self._state = CursorState.CLOSED - self.connection._remove_cursor(self) - - def __del__(self) -> None: - self.close() - - # Context manager support - @check_not_closed - def __enter__(self) -> Cursor: - return self - - def __exit__( - self, exc_type: type, exc_val: Exception, exc_tb: TracebackType - ) -> None: - self.close() - @check_not_closed def cancel(self, query_id: str) -> None: """Cancel a server-side async query.""" @@ -377,3 +361,7 @@ def __iter__(self) -> Generator[List[ColType], None, None]: if row is None: return yield row + + @check_not_closed + def __enter__(self) -> Cursor: + return self diff --git a/src/firebolt/db/util.py b/src/firebolt/db/util.py index 94f24c3be7c..44a3a71de47 100644 --- a/src/firebolt/db/util.py +++ b/src/firebolt/db/util.py @@ -1,26 +1,44 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Tuple -from httpx import URL, Response +from httpx import URL, Timeout, codes -from firebolt.utils.urls import DATABASES_URL, ENGINES_URL +from firebolt.client import Client +from firebolt.client.auth import Auth +from firebolt.common.settings import DEFAULT_TIMEOUT_SECONDS +from firebolt.utils.exception import ( + AccountNotFoundError, + FireboltEngineError, + InterfaceError, +) +from firebolt.utils.urls import DYNAMIC_QUERY, GATEWAY_HOST_BY_ACCOUNT_NAME if TYPE_CHECKING: from firebolt.db.connection import Connection +ENGINE_STATUS_RUNNING = "Running" + def is_db_available(connection: Connection, database_name: str) -> bool: """ Verify that the database exists. Args: - connection (firebolt.async_db.connection.Connection) + connection (firebolt.db.connection.Connection) + database_name (str): Name of a database """ - resp = _filter_request( - connection, DATABASES_URL, {"filter.name_contains": database_name} - ) - return len(resp.json()["edges"]) > 0 + system_engine = connection._system_engine_connection or connection + with system_engine.cursor() as cursor: + return ( + cursor.execute( + """ + SELECT 1 FROM information_schema.databases WHERE database_name=? + """, + [database_name], + ) + > 0 + ) def is_engine_running(connection: Connection, engine_url: str) -> bool: @@ -28,28 +46,60 @@ def is_engine_running(connection: Connection, engine_url: str) -> bool: Verify that the engine is running. Args: - connection (firebolt.async_db.connection.Connection): connection. + connection (firebolt.db.connection.Connection): connection. + engine_url (str): URL of the engine """ - # Url is not guaranteed to be of this structure, - # but for the sake of error checking this is sufficient. + + if connection._is_system: + # System engine is always running + return True + engine_name = URL(engine_url).host.split(".")[0].replace("-", "_") - resp = _filter_request( - connection, - ENGINES_URL, - { - "filter.name_contains": engine_name, - "filter.current_status_eq": "ENGINE_STATUS_RUNNING_REVISION_SERVING", - }, + assert connection._system_engine_connection is not None # Type check + _, status, _ = _get_engine_url_status_db( + connection._system_engine_connection, engine_name ) - return len(resp.json()["edges"]) > 0 + return status == ENGINE_STATUS_RUNNING -def _filter_request(connection: Connection, endpoint: str, filters: dict) -> Response: - resp = connection._client.request( - # Full url overrides the client url, which contains engine as a prefix. - url=connection.api_endpoint + endpoint, - method="GET", - params=filters, - ) - resp.raise_for_status() - return resp +def _get_system_engine_url( + auth: Auth, + account_name: str, + api_endpoint: str, +) -> str: + with Client( + auth=auth, + base_url=api_endpoint, + account_name=account_name, + api_endpoint=api_endpoint, + timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), + ) as client: + # return "https://api.us-east-1.dev.firebolt.io/dynamic" + url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name) + response = client.get(url=url) + if response.status_code == codes.NOT_FOUND: + raise AccountNotFoundError(account_name) + if response.status_code != codes.OK: + raise InterfaceError( + f"Unable to retrieve system engine endpoint {url}: " + f"{response.status_code} {response.content}" + ) + return response.json()["engineUrl"] + DYNAMIC_QUERY + + +def _get_engine_url_status_db( + system_engine: Connection, engine_name: str +) -> Tuple[str, str, str]: + with system_engine.cursor() as cursor: + cursor.execute( + """ + SELECT url, attached_to, status FROM information_schema.engines + WHERE engine_name=? + """, + [engine_name], + ) + row = cursor.fetchone() + if row is None: + raise FireboltEngineError(f"Engine with name {engine_name} doesn't exist") + engine_url, database, status = row + return str(engine_url), str(status), str(database) # Mypy check diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index 0bdfe6a8360..7cd350ae41a 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -91,10 +91,13 @@ def test_connect_engine_name( get_system_engine_url: str, get_system_engine_callback: Callable, get_engine_url_callback: Callable, + account_id_url: str, + account_id_callback: Callable, ): """connect properly handles engine_name""" httpx_mock.add_callback(check_credentials_callback, url=auth_url) httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) mock_query() @@ -104,13 +107,16 @@ def test_connect_engine_name( ): httpx_mock.add_callback(callback, url=system_engine_query_url) with raises(InterfaceError): - connect( + c = connect( database=db_name, auth=auth, engine_name=engine_name, account_name=account_name, api_endpoint=server, ) + print(type(c)) + with c: + pass httpx_mock.add_callback(get_engine_url_callback, url=system_engine_query_url) @@ -138,10 +144,13 @@ def test_connect_database( system_engine_no_db_query_url: str, get_system_engine_url: str, get_system_engine_callback: Callable, + account_id_url: str, + account_id_callback: Callable, ): httpx_mock.add_callback(check_credentials_callback, url=auth_url) httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) httpx_mock.add_callback(query_callback, url=system_engine_no_db_query_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) with connect( database=None, auth=auth, @@ -153,6 +162,7 @@ def test_connect_database( httpx_mock.reset(True) httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) httpx_mock.add_callback(query_callback, url=system_engine_query_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) with connect( database=db_name, @@ -163,8 +173,8 @@ def test_connect_database( assert connection.cursor().execute("select*") == len(python_query_data) -def test_connection_unclosed_warnings(): - c = Connection("", "", None, "", None) +def test_connection_unclosed_warnings(auth: Auth): + c = Connection("", "", auth, "", None) with warns(UserWarning) as winfo: del c gc.collect() @@ -174,8 +184,8 @@ def test_connection_unclosed_warnings(): ), "Invalid unclosed connection warning" -def test_connection_no_warnings(): - c = Connection("", "", None, "", None) +def test_connection_no_warnings(auth: Auth): + c = Connection("", "", auth, "", None) c.close() with warnings.catch_warnings(): warnings.simplefilter("error") @@ -248,7 +258,7 @@ def test_connect_with_user_agent( query_url: str, mock_connection_flow: Callable, ) -> None: - with patch("firebolt.async_db.connection.get_user_agent_header") as ut: + with patch("firebolt.db.connection.get_user_agent_header") as ut: ut.return_value = "MyConnector/1.0 DriverA/1.1" mock_connection_flow() httpx_mock.add_callback( @@ -284,7 +294,7 @@ def test_connect_no_user_agent( query_url: str, mock_connection_flow: Callable, ) -> None: - with patch("firebolt.async_db.connection.get_user_agent_header") as ut: + with patch("firebolt.db.connection.get_user_agent_header") as ut: ut.return_value = "Python/3.0" mock_connection_flow() httpx_mock.add_callback( diff --git a/tests/unit/db_conftest.py b/tests/unit/db_conftest.py index 1ad3064d55e..c5a740ac0d9 100644 --- a/tests/unit/db_conftest.py +++ b/tests/unit/db_conftest.py @@ -334,8 +334,8 @@ def set_params() -> Dict: @fixture def query_url(settings: Settings, db_name: str) -> str: return URL( - f"https://{settings.server}/?database={db_name}" - f"&output_format={JSON_OUTPUT_FORMAT}" + f"https://{settings.server}/", + params={"output_format": JSON_OUTPUT_FORMAT, "database": db_name}, ) From d08b08e334f00395d1cf378d635d4757e770d0b6 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 7 Jun 2023 13:05:12 +0300 Subject: [PATCH 23/28] fix integration tests --- src/firebolt/async_db/connection.py | 3 +- src/firebolt/db/connection.py | 3 +- src/firebolt/utils/exception.py | 3 + tests/integration/conftest.py | 5 +- .../dbapi/async/test_errors_async.py | 3 +- .../dbapi/async/test_queries_async.py | 111 +--------- .../dbapi/async/test_system_engine.py | 204 ++++++++++++++++++ tests/integration/dbapi/sync/conftest.py | 81 +++---- tests/integration/dbapi/sync/test_errors.py | 28 +-- tests/integration/dbapi/sync/test_queries.py | 47 ++-- .../dbapi/sync/test_system_engine.py | 81 +++++-- tests/integration/dbapi/utils.py | 9 + 12 files changed, 344 insertions(+), 234 deletions(-) create mode 100644 tests/integration/dbapi/async/test_system_engine.py create mode 100644 tests/integration/dbapi/utils.py diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 0be7f850e1b..4b56635b0ec 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -25,6 +25,7 @@ from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, + EngineNotRunningError, InterfaceError, ) from firebolt.utils.usage_tracker import get_user_agent_header @@ -236,7 +237,7 @@ async def connect( ) if status != "Running": - raise InterfaceError(f"Engine {engine_name} is not running") + raise EngineNotRunningError(engine_name) if database is not None and database != attached_db: raise InterfaceError( diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index 6a62c255376..adb8bf46e71 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -24,6 +24,7 @@ from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, + EngineNotRunningError, InterfaceError, ) from firebolt.utils.usage_tracker import get_user_agent_header @@ -238,7 +239,7 @@ def connect( ) if status != "Running": - raise InterfaceError(f"Engine {engine_name} is not running") + raise EngineNotRunningError(engine_name) if database is not None and database != attached_db: raise InterfaceError( diff --git a/src/firebolt/utils/exception.py b/src/firebolt/utils/exception.py index b5e39420be1..a5e87212ecb 100644 --- a/src/firebolt/utils/exception.py +++ b/src/firebolt/utils/exception.py @@ -9,6 +9,9 @@ class FireboltEngineError(FireboltError): class EngineNotRunningError(FireboltEngineError): """Engine that's being queried is not running.""" + def __init__(self, engine_name: str): + super().__init__(f"Engine {engine_name} is not running") + class NoAttachedDatabaseError(FireboltEngineError): """Engine that's being accessed is not running. diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 7798bad930e..05a86144f58 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -24,8 +24,9 @@ def must_env(var_name: str) -> str: @fixture(scope="session") -def rm_settings(api_endpoint, auth) -> Settings: +def rm_settings(api_endpoint, auth, account_name) -> Settings: return Settings( + account_name=account_name, server=api_endpoint, auth=auth, default_region="us-east-1", @@ -67,6 +68,6 @@ def service_secret() -> str: return must_env(SERVICE_SECRET_ENV) -@fixture +@fixture(scope="session") def auth(service_id, service_secret) -> ClientCredentials: return ClientCredentials(service_id, service_secret) diff --git a/tests/integration/dbapi/async/test_errors_async.py b/tests/integration/dbapi/async/test_errors_async.py index bd5e6f40842..6af46dc14df 100644 --- a/tests/integration/dbapi/async/test_errors_async.py +++ b/tests/integration/dbapi/async/test_errors_async.py @@ -4,6 +4,7 @@ from firebolt.client.auth import ClientCredentials from firebolt.utils.exception import ( AccountNotFoundError, + EngineNotRunningError, FireboltEngineError, InterfaceError, OperationalError, @@ -60,7 +61,7 @@ async def test_engine_stopped( api_endpoint: str, ) -> None: """Connection properly reacts to invalid engine name error.""" - with raises(InterfaceError): + with raises(EngineNotRunningError): async with await connect( engine_name=stopped_engine_name, database=database_name, diff --git a/tests/integration/dbapi/async/test_queries_async.py b/tests/integration/dbapi/async/test_queries_async.py index 0eb5cc26512..4a961cddd2c 100644 --- a/tests/integration/dbapi/async/test_queries_async.py +++ b/tests/integration/dbapi/async/test_queries_async.py @@ -1,6 +1,6 @@ from datetime import date, datetime from decimal import Decimal -from typing import Any, List +from typing import List from pytest import mark, raises @@ -13,67 +13,13 @@ ) from firebolt.async_db.cursor import QueryStatus from firebolt.common._types import ColType, Column +from tests.integration.dbapi.utils import assert_deep_eq VALS_TO_INSERT_2 = ",".join( [f"({i}, {i-3}, '{val}')" for (i, val) in enumerate(range(4, 1000))] ) LONG_INSERT = f"INSERT INTO test_tbl VALUES {VALS_TO_INSERT_2}" -CREATE_EXTERNAL_TABLE = """CREATE EXTERNAL TABLE IF NOT EXISTS ex_lineitem ( - l_orderkey LONG, - l_partkey LONG, - l_suppkey LONG, - l_linenumber INT, - l_quantity LONG, - l_extendedprice LONG, - l_discount LONG, - l_tax LONG, - l_returnflag TEXT, - l_linestatus TEXT, - l_shipdate TEXT, - l_commitdate TEXT, - l_receiptdate TEXT, - l_shipinstruct TEXT, - l_shipmode TEXT, - l_comment TEXT -) -URL = 's3://firebolt-publishing-public/samples/tpc-h/parquet/lineitem/' -OBJECT_PATTERN = '*.parquet' -TYPE = (PARQUET);""" - -CREATE_FACT_TABLE = """CREATE FACT TABLE IF NOT EXISTS lineitem ( --- In this example, these fact table columns --- map directly to the external table columns. - l_orderkey LONG, - l_partkey LONG, - l_suppkey LONG, - l_linenumber INT, - l_quantity LONG, - l_extendedprice LONG, - l_discount LONG, - l_tax LONG, - l_returnflag TEXT, - l_linestatus TEXT, - l_shipdate TEXT, - l_commitdate TEXT, - l_receiptdate TEXT, - l_shipinstruct TEXT, - l_shipmode TEXT, - l_comment TEXT -) -PRIMARY INDEX - l_orderkey, - l_linenumber; -""" - - -def assert_deep_eq(got: Any, expected: Any, msg: str) -> bool: - if type(got) == list and type(expected) == list: - all([assert_deep_eq(f, s, msg) for f, s in zip(got, expected)]) - assert ( - type(got) == type(expected) and got == expected - ), f"{msg}: {got}(got) != {expected}(expected)" - async def status_loop( query_id: str, @@ -488,56 +434,3 @@ async def test_bytea_roundtrip( assert ( bytes_data.decode("utf-8") == data ), "Invalid bytea data returned after roundtrip" - - -async def test_system_engine( - connection_system_engine: Connection, - all_types_query: str, - all_types_query_description: List[Column], - all_types_query_system_engine_response: List[ColType], - timezone_name: str, -) -> None: - """Connecting with engine name is handled properly.""" - with connection_system_engine.cursor() as c: - assert await c.execute(all_types_query) == 1, "Invalid row count returned" - assert c.rowcount == 1, "Invalid rowcount value" - data = await c.fetchall() - assert len(data) == c.rowcount, "Invalid data length" - assert_deep_eq(data, all_types_query_system_engine_response, "Invalid data") - assert c.description == all_types_query_description, "Invalid description value" - assert len(data[0]) == len(c.description), "Invalid description length" - assert len(await c.fetchall()) == 0, "Redundant data returned by fetchall" - - # Different fetch types - await c.execute(all_types_query) - assert ( - await c.fetchone() == all_types_query_system_engine_response[0] - ), "Invalid fetchone data" - assert await c.fetchone() is None, "Redundant data returned by fetchone" - - await c.execute(all_types_query) - assert len(await c.fetchmany(0)) == 0, "Invalid data size returned by fetchmany" - data = await c.fetchmany() - assert len(data) == 1, "Invalid data size returned by fetchmany" - assert_deep_eq( - data, - all_types_query_system_engine_response, - "Invalid data returned by fetchmany", - ) - - -async def test_system_engine_no_db( - connection_system_engine_no_db: Connection, - all_types_query: str, - all_types_query_description: List[Column], - all_types_query_system_engine_response: List[ColType], - timezone_name: str, -) -> None: - """Connecting with engine name is handled properly.""" - await test_system_engine( - connection_system_engine_no_db, - all_types_query, - all_types_query_description, - all_types_query_system_engine_response, - timezone_name, - ) diff --git a/tests/integration/dbapi/async/test_system_engine.py b/tests/integration/dbapi/async/test_system_engine.py new file mode 100644 index 00000000000..5be98a0ece5 --- /dev/null +++ b/tests/integration/dbapi/async/test_system_engine.py @@ -0,0 +1,204 @@ +from typing import List + +from pytest import fixture, mark, raises + +from firebolt.async_db import Connection +from firebolt.common._types import ColType, Column +from firebolt.utils.exception import OperationalError +from tests.integration.dbapi.utils import assert_deep_eq + + +@fixture +def db_name(database_name): + return database_name + "_system_test" + + +@fixture +def second_db_name(database_name): + return database_name + "_system_test_two" + + +@fixture +def region(): + return "us-east-1" + + +@fixture +def engine_name(engine_name): + return engine_name + "_system_test" + + +@fixture +async def setup_dbs( + connection_system_engine, db_name, second_db_name, engine_name, region +): + with connection_system_engine.cursor() as cursor: + + await cursor.execute(f"DROP DATABASE IF EXISTS {db_name}") + await cursor.execute(f"DROP DATABASE IF EXISTS {second_db_name}") + + await cursor.execute(create_database(name=db_name)) + + await cursor.execute(create_engine(engine_name, engine_specs(region))) + + await cursor.execute( + create_database(name=second_db_name, specs=db_specs(region, engine_name)) + ) + + yield + + await cursor.execute(f"DROP ENGINE IF EXISTS {engine_name}") + await cursor.execute(f"DROP DATABASE IF EXISTS {db_name}") + await cursor.execute(f"DROP DATABASE IF EXISTS {second_db_name}") + + +async def test_system_engine( + connection_system_engine: Connection, + all_types_query: str, + all_types_query_description: List[Column], + all_types_query_system_engine_response: List[ColType], + timezone_name: str, +) -> None: + """Connecting with engine name is handled properly.""" + with connection_system_engine.cursor() as c: + assert await c.execute(all_types_query) == 1, "Invalid row count returned" + assert c.rowcount == 1, "Invalid rowcount value" + data = await c.fetchall() + assert len(data) == c.rowcount, "Invalid data length" + assert_deep_eq(data, all_types_query_system_engine_response, "Invalid data") + assert c.description == all_types_query_description, "Invalid description value" + assert len(data[0]) == len(c.description), "Invalid description length" + assert len(await c.fetchall()) == 0, "Redundant data returned by fetchall" + + # Different fetch types + await c.execute(all_types_query) + assert ( + await c.fetchone() == all_types_query_system_engine_response[0] + ), "Invalid fetchone data" + assert await c.fetchone() is None, "Redundant data returned by fetchone" + + await c.execute(all_types_query) + assert len(await c.fetchmany(0)) == 0, "Invalid data size returned by fetchmany" + data = await c.fetchmany() + assert len(data) == 1, "Invalid data size returned by fetchmany" + assert_deep_eq( + data, + all_types_query_system_engine_response, + "Invalid data returned by fetchmany", + ) + + +async def test_system_engine_no_db( + connection_system_engine_no_db: Connection, + all_types_query: str, + all_types_query_description: List[Column], + all_types_query_system_engine_response: List[ColType], + timezone_name: str, +) -> None: + """Connecting with engine name is handled properly.""" + await test_system_engine( + connection_system_engine_no_db, + all_types_query, + all_types_query_description, + all_types_query_system_engine_response, + timezone_name, + ) + + +def engine_specs(region): + return f"REGION = '{region}' " "SPEC = 'B1' " "SCALE = 1" + + +def create_database(name, specs=None): + query = f"CREATE DATABASE {name}" + query += f" WITH {specs}" if specs else "" + return query + + +def create_engine(name, specs=None): + query = f"CREATE ENGINE {name}" + query += f" WITH {specs}" if specs else "" + return query + + +def db_specs(region, attached_engine): + return ( + f"REGION = '{region}' " + f"ATTACHED_ENGINES = ('{attached_engine}') " + "DESCRIPTION = 'Sample description'" + ) + + +@mark.parametrize( + "query", + ["CREATE DIMENSION TABLE dummy(id INT)"], +) +async def test_query_errors(connection_system_engine, query): + with connection_system_engine.cursor() as cursor: + with raises(OperationalError): + await cursor.execute(query) + + +@mark.xdist_group(name="system_engine") +async def test_show_databases(setup_dbs, connection_system_engine, db_name): + with connection_system_engine.cursor() as cursor: + + await cursor.execute("SHOW DATABASES") + + dbs = [row[0] for row in await cursor.fetchall()] + + assert db_name in dbs + assert f"{db_name}_two" in dbs + + +@mark.xdist_group(name="system_engine") +async def test_detach_engine( + setup_dbs, connection_system_engine, engine_name, second_db_name +): + async def check_engine_exists(cursor, engine_name, db_name): + await cursor.execute("SHOW ENGINES") + engines = await cursor.fetchall() + # Results have the following columns + # engine_name, region, spec, scale, status, attached_to, version + assert engine_name in [row[0] for row in engines] + assert (engine_name, db_name) in [(row[0], row[5]) for row in engines] + + with connection_system_engine.cursor() as cursor: + await check_engine_exists(cursor, engine_name, db_name=second_db_name) + await cursor.execute(f"DETACH ENGINE {engine_name} FROM {second_db_name}") + + # When engine not attached db is - + await check_engine_exists(cursor, engine_name, db_name="-") + + await cursor.execute(f"ATTACH ENGINE {engine_name} TO {second_db_name}") + await check_engine_exists(cursor, engine_name, db_name=second_db_name) + + +@mark.xdist_group(name="system_engine") +async def test_alter_engine(setup_dbs, connection_system_engine, engine_name): + with connection_system_engine.cursor() as cursor: + await cursor.execute(f"ALTER ENGINE {engine_name} SET AUTO_STOP = 60") + + await cursor.execute( + "SELECT engine_name, auto_stop FROM information_schema.engines" + ) + engines = await cursor.fetchall() + assert [engine_name, 3600] in engines + + +@mark.xdist_group(name="system_engine") +async def test_start_stop_engine(setup_dbs, connection_system_engine, engine_name): + async def check_engine_status(cursor, engine_name, status): + await cursor.execute("SHOW ENGINES") + engines = await cursor.fetchall() + # Results have the following columns + # engine_name, region, spec, scale, status, attached_to, version + assert engine_name in [row[0] for row in engines] + assert (engine_name, status) in [(row[0], row[4]) for row in engines] + + with connection_system_engine.cursor() as cursor: + await check_engine_status(cursor, engine_name, "Stopped") + await cursor.execute(f"START ENGINE {engine_name}") + await check_engine_status(cursor, engine_name, "Running") + await cursor.execute(f"STOP ENGINE {engine_name}") + await check_engine_status(cursor, engine_name, "Stopped") diff --git a/tests/integration/dbapi/sync/conftest.py b/tests/integration/dbapi/sync/conftest.py index e8c4447ea08..2b561c7fb29 100644 --- a/tests/integration/dbapi/sync/conftest.py +++ b/tests/integration/dbapi/sync/conftest.py @@ -5,91 +5,64 @@ @fixture -def username_password_connection( - engine_url: str, - database_name: str, - password_auth: Auth, - account_name: str, - api_endpoint: str, -) -> Connection: - connection = connect( - engine_url=engine_url, - database=database_name, - auth=password_auth, - account_name=account_name, - api_endpoint=api_endpoint, - ) - yield connection - connection.close() - - -@fixture -async def connection( - engine_url: str, +def connection( + engine_name: str, database_name: str, - password_auth: Auth, + auth: Auth, account_name: str, api_endpoint: str, ) -> Connection: - connection = connect( - engine_url=engine_url, + with connect( + engine_name=engine_name, database=database_name, - auth=password_auth, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, - ) - yield connection - connection.close() + ) as connection: + yield connection @fixture -def connection_engine_name( +def connection_no_db( engine_name: str, - database_name: str, - password_auth: Auth, + auth: Auth, account_name: str, api_endpoint: str, ) -> Connection: - connection = connect( + with connect( engine_name=engine_name, - database=database_name, - auth=password_auth, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, - ) - yield connection - connection.close() + ) as connection: + yield connection @fixture -def connection_no_engine( +def connection_system_engine( database_name: str, - password_auth: Auth, + auth: Auth, account_name: str, api_endpoint: str, ) -> Connection: - connection = connect( + with connect( database=database_name, - auth=password_auth, + auth=auth, account_name=account_name, api_endpoint=api_endpoint, - ) - yield connection - connection.close() + ) as connection: + yield connection -@fixture(scope="session") -def connection_system_engine( - password_auth: Auth, +@fixture +def connection_system_engine_no_db( + auth: Auth, account_name: str, api_endpoint: str, ) -> Connection: - connection = connect( - database="dummy", - engine_name="system", - auth=password_auth, + with connect( + auth=auth, account_name=account_name, api_endpoint=api_endpoint, - ) - yield connection - connection.close() + ) as connection: + yield connection diff --git a/tests/integration/dbapi/sync/test_errors.py b/tests/integration/dbapi/sync/test_errors.py index 193dbc617b1..088041d9833 100644 --- a/tests/integration/dbapi/sync/test_errors.py +++ b/tests/integration/dbapi/sync/test_errors.py @@ -1,4 +1,3 @@ -from httpx import ConnectError from pytest import mark, raises from firebolt.client.auth import ClientCredentials @@ -35,32 +34,17 @@ def test_invalid_account( ), "Invalid account error message." -def test_engine_url_not_exists( - engine_url: str, - database_name: str, - auth: ClientCredentials, - api_endpoint: str, -) -> None: - """Connection properly reacts to invalid engine url error.""" - with connect( - engine_url=engine_url + "_", - database=database_name, - auth=auth, - api_endpoint=api_endpoint, - ) as connection: - with raises(ConnectError): - connection.cursor().execute("show tables") - - def test_engine_name_not_exists( engine_name: str, database_name: str, auth: ClientCredentials, + account_name: str, api_endpoint: str, ) -> None: """Connection properly reacts to invalid engine name error.""" with raises(FireboltEngineError): with connect( + account_name=account_name, engine_name=engine_name + "_________", database=database_name, auth=auth, @@ -70,15 +54,17 @@ def test_engine_name_not_exists( def test_engine_stopped( - stopped_engine_url: str, + stopped_engine_name: str, database_name: str, auth: ClientCredentials, + account_name: str, api_endpoint: str, ) -> None: """Connection properly reacts to engine not running error.""" with raises(EngineNotRunningError): with connect( - engine_url=stopped_engine_url, + account_name=account_name, + engine_name=stopped_engine_name, database=database_name, auth=auth, api_endpoint=api_endpoint, @@ -91,11 +77,13 @@ def test_database_not_exists( engine_url: str, database_name: str, auth: ClientCredentials, + account_name: str, api_endpoint: str, ) -> None: """Connection properly reacts to invalid database error.""" new_db_name = database_name + "_" with connect( + account_name=account_name, engine_url=engine_url, database=new_db_name, auth=auth, diff --git a/tests/integration/dbapi/sync/test_queries.py b/tests/integration/dbapi/sync/test_queries.py index 44eb217c8d0..32c7320fb16 100644 --- a/tests/integration/dbapi/sync/test_queries.py +++ b/tests/integration/dbapi/sync/test_queries.py @@ -16,6 +16,7 @@ OperationalError, connect, ) +from tests.integration.dbapi.utils import assert_deep_eq VALS_TO_INSERT = ",".join([f"({i},'{val}')" for (i, val) in enumerate(range(1, 360))]) LONG_INSERT = f"INSERT INTO test_tbl VALUES {VALS_TO_INSERT}" @@ -52,8 +53,8 @@ def status_loop( ), f"Failed {query}. Got {status} rather than {final_status}." -def test_connect_engine_name( - connection_engine_name: Connection, +def test_connect_no_db( + connection_no_db: Connection, all_types_query: str, all_types_query_description: List[Column], all_types_query_response: List[ColType], @@ -61,24 +62,7 @@ def test_connect_engine_name( ) -> None: """Connecting with engine name is handled properly.""" test_select( - connection_engine_name, - all_types_query, - all_types_query_description, - all_types_query_response, - timezone_name, - ) - - -def test_connect_no_engine( - connection_no_engine: Connection, - all_types_query: str, - all_types_query_description: List[Column], - all_types_query_response: List[ColType], - timezone_name: str, -) -> None: - """Connecting with engine name is handled properly.""" - test_select( - connection_no_engine, + connection_no_db, all_types_query, all_types_query_description, all_types_query_response, @@ -149,10 +133,10 @@ def test_long_query( def test_drop_create(connection: Connection) -> None: """Create and drop table/index queries are handled properly.""" - def test_query(c: Cursor, query: str) -> None: + def test_query(c: Cursor, query: str, empty_response=True) -> None: c.execute(query) assert c.description == None - assert c.rowcount == -1 + assert c.rowcount == (-1 if empty_response else 0) """Create table query is handled properly""" with connection.cursor() as c: @@ -188,6 +172,7 @@ def test_query(c: Cursor, query: str) -> None: c, "CREATE AGGREGATING INDEX test_drop_create_db_agg_idx ON " "test_drop_create_tb(id, sum(f), count(dt))", + empty_response=False, ) # Drop join index @@ -378,9 +363,9 @@ def test_set_invalid_parameter(connection: Connection): # Run test multiple times since the issue is flaky @mark.parametrize("_", range(5)) def test_anyio_backend_import_issue( - engine_url: str, + engine_name: str, database_name: str, - password_auth: Auth, + auth: Auth, account_name: str, api_endpoint: str, _: int, @@ -391,13 +376,13 @@ def test_anyio_backend_import_issue( exceptions = [] def run_query(idx: int): - nonlocal password_auth, database_name, engine_url, account_name, api_endpoint + nonlocal auth, database_name, engine_name, account_name, api_endpoint try: with connect( - auth=password_auth, + auth=auth, database=database_name, account_name=account_name, - engine_url=engine_url, + engine_name=engine_name, api_endpoint=api_endpoint, ) as c: cursor = c.cursor() @@ -468,9 +453,9 @@ async def test_server_side_async_execution_get_status( def test_multi_thread_connection_sharing( - engine_url: str, + engine_name: str, database_name: str, - password_auth: Auth, + auth: Auth, account_name: str, api_endpoint: str, ) -> None: @@ -484,10 +469,10 @@ def test_multi_thread_connection_sharing( exceptions = [] connection = connect( - auth=password_auth, + auth=auth, database=database_name, account_name=account_name, - engine_url=engine_url, + engine_name=engine_name, api_endpoint=api_endpoint, ) diff --git a/tests/integration/dbapi/sync/test_system_engine.py b/tests/integration/dbapi/sync/test_system_engine.py index 2474c323dfc..8deea187e35 100644 --- a/tests/integration/dbapi/sync/test_system_engine.py +++ b/tests/integration/dbapi/sync/test_system_engine.py @@ -1,29 +1,34 @@ +from typing import List + from pytest import fixture, mark, raises +from firebolt.common._types import ColType, Column +from firebolt.db import Connection from firebolt.utils.exception import OperationalError +from tests.integration.dbapi.utils import assert_deep_eq -@fixture(scope="module") +@fixture def db_name(database_name): return database_name + "_system_test" -@fixture(scope="module") +@fixture def second_db_name(database_name): return database_name + "_system_test_two" -@fixture(scope="module") +@fixture def region(): return "us-east-1" -@fixture(scope="module") +@fixture def engine_name(engine_name): return engine_name + "_system_test" -@fixture(scope="module") +@fixture def setup_dbs(connection_system_engine, db_name, second_db_name, engine_name, region): with connection_system_engine.cursor() as cursor: @@ -42,6 +47,59 @@ def setup_dbs(connection_system_engine, db_name, second_db_name, engine_name, re cursor.execute(f"DROP DATABASE IF EXISTS {second_db_name}") +def test_system_engine( + connection_system_engine: Connection, + all_types_query: str, + all_types_query_description: List[Column], + all_types_query_system_engine_response: List[ColType], + timezone_name: str, +) -> None: + """Connecting with engine name is handled properly.""" + with connection_system_engine.cursor() as c: + assert c.execute(all_types_query) == 1, "Invalid row count returned" + assert c.rowcount == 1, "Invalid rowcount value" + data = c.fetchall() + assert len(data) == c.rowcount, "Invalid data length" + assert_deep_eq(data, all_types_query_system_engine_response, "Invalid data") + assert c.description == all_types_query_description, "Invalid description value" + assert len(data[0]) == len(c.description), "Invalid description length" + assert len(c.fetchall()) == 0, "Redundant data returned by fetchall" + + # Different fetch types + c.execute(all_types_query) + assert ( + c.fetchone() == all_types_query_system_engine_response[0] + ), "Invalid fetchone data" + assert c.fetchone() is None, "Redundant data returned by fetchone" + + c.execute(all_types_query) + assert len(c.fetchmany(0)) == 0, "Invalid data size returned by fetchmany" + data = c.fetchmany() + assert len(data) == 1, "Invalid data size returned by fetchmany" + assert_deep_eq( + data, + all_types_query_system_engine_response, + "Invalid data returned by fetchmany", + ) + + +def test_system_engine_no_db( + connection_system_engine_no_db: Connection, + all_types_query: str, + all_types_query_description: List[Column], + all_types_query_system_engine_response: List[ColType], + timezone_name: str, +) -> None: + """Connecting with engine name is handled properly.""" + test_system_engine( + connection_system_engine_no_db, + all_types_query, + all_types_query_description, + all_types_query_system_engine_response, + timezone_name, + ) + + def engine_specs(region): return f"REGION = '{region}' " "SPEC = 'B1' " "SCALE = 1" @@ -114,11 +172,11 @@ def check_engine_exists(cursor, engine_name, db_name): @mark.xdist_group(name="system_engine") def test_alter_engine(setup_dbs, connection_system_engine, engine_name): with connection_system_engine.cursor() as cursor: - cursor.execute(f"ALTER ENGINE {engine_name} SET SPEC = B2") + cursor.execute(f"ALTER ENGINE {engine_name} SET AUTO_STOP = 60") - cursor.execute("SHOW ENGINES") + cursor.execute("SELECT engine_name, auto_stop FROM information_schema.engines") engines = cursor.fetchall() - assert (engine_name, "B2") in [(row[0], row[2]) for row in engines] + assert [engine_name, 3600] in engines @mark.xdist_group(name="system_engine") @@ -137,10 +195,3 @@ def check_engine_status(cursor, engine_name, status): check_engine_status(cursor, engine_name, "Running") cursor.execute(f"STOP ENGINE {engine_name}") check_engine_status(cursor, engine_name, "Stopped") - - -@mark.xdist_group(name="system_engine") -def test_select_one(connection_system_engine): - """SELECT statements are supported""" - with connection_system_engine.cursor() as cursor: - cursor.execute("SELECT 1") diff --git a/tests/integration/dbapi/utils.py b/tests/integration/dbapi/utils.py new file mode 100644 index 00000000000..603c58a7d7b --- /dev/null +++ b/tests/integration/dbapi/utils.py @@ -0,0 +1,9 @@ +from typing import Any + + +def assert_deep_eq(got: Any, expected: Any, msg: str) -> bool: + if type(got) == list and type(expected) == list: + all([assert_deep_eq(f, s, msg) for f, s in zip(got, expected)]) + assert ( + type(got) == type(expected) and got == expected + ), f"{msg}: {got}(got) != {expected}(expected)" From 8afbb389e443bb21483e7fa50475d90dce09be71 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 7 Jun 2023 14:37:46 +0300 Subject: [PATCH 24/28] update trio version --- setup.cfg | 2 +- tests/unit/async_db/test_connection.py | 9 +++++---- tests/unit/db/test_connection.py | 9 +++++---- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/setup.cfg b/setup.cfg index 17fdbd0a149..dffdee12641 100755 --- a/setup.cfg +++ b/setup.cfg @@ -33,7 +33,7 @@ install_requires = python-dateutil>=2.8.2 sqlparse>=0.4.2 tricycle>=0.2.2 - trio<0.22.0 + trio>=0.22.0 python_requires = >=3.7 include_package_data = True package_dir = diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index 9b047305e80..ac63f1f4aa3 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -12,6 +12,7 @@ from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, + EngineNotRunningError, InterfaceError, ) from firebolt.utils.token_storage import TokenSecureStorage @@ -101,12 +102,12 @@ async def test_connect_engine_name( mock_query() - for callback in ( - get_engine_url_invalid_db_callback, - get_engine_url_not_running_callback, + for callback, err_cls in ( + (get_engine_url_invalid_db_callback, InterfaceError), + (get_engine_url_not_running_callback, EngineNotRunningError), ): httpx_mock.add_callback(callback, url=system_engine_query_url) - with raises(InterfaceError): + with raises(err_cls): async with await connect( database=db_name, auth=auth, diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index 7cd350ae41a..89ef9e741cb 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -14,6 +14,7 @@ from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, + EngineNotRunningError, InterfaceError, ) from firebolt.utils.token_storage import TokenSecureStorage @@ -101,12 +102,12 @@ def test_connect_engine_name( mock_query() - for callback in ( - get_engine_url_invalid_db_callback, - get_engine_url_not_running_callback, + for callback, err_cls in ( + (get_engine_url_invalid_db_callback, InterfaceError), + (get_engine_url_not_running_callback, EngineNotRunningError), ): httpx_mock.add_callback(callback, url=system_engine_query_url) - with raises(InterfaceError): + with raises(err_cls): c = connect( database=db_name, auth=auth, From 19a225ffdaae88748ba0c6bcca8b6acaec1485b8 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 14 Jun 2023 15:13:54 +0300 Subject: [PATCH 25/28] address comments --- docsrc/firebolt.common.rst | 2 +- src/firebolt/async_db/connection.py | 1 - src/firebolt/async_db/util.py | 1 - src/firebolt/utils/util.py | 58 +--------------- .../dbapi/async/test_errors_async.py | 3 +- .../dbapi/async/test_system_engine.py | 9 +++ .../dbapi/sync/test_system_engine.py | 9 +++ tests/unit/common/test_util.py | 68 ------------------- 8 files changed, 22 insertions(+), 129 deletions(-) delete mode 100644 tests/unit/common/test_util.py diff --git a/docsrc/firebolt.common.rst b/docsrc/firebolt.common.rst index d2c022d7f6b..561ec83516d 100644 --- a/docsrc/firebolt.common.rst +++ b/docsrc/firebolt.common.rst @@ -42,7 +42,7 @@ Util --------------------------- .. automodule:: firebolt.common.util - :exclude-members: async_to_sync, cached_property, fix_url_schema, mixin_for, prune_dict + :exclude-members: cached_property, fix_url_schema, mixin_for, prune_dict :members: :undoc-members: :show-inheritance: diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 4b56635b0ec..f73fb28627a 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -194,7 +194,6 @@ async def connect( `database` (str): Name of the database to connect `engine_name` (Optional[str]): Name of the engine to connect to `account_name` (Optional[str]): For customers with multiple accounts; - if none, default is used `api_endpoint` (str): Firebolt API endpoint. Used for authentication `additional_parameters` (Optional[Dict]): Dictionary of less widely-used arguments for connection diff --git a/src/firebolt/async_db/util.py b/src/firebolt/async_db/util.py index 94cde980306..180c3efa6a6 100644 --- a/src/firebolt/async_db/util.py +++ b/src/firebolt/async_db/util.py @@ -74,7 +74,6 @@ async def _get_system_engine_url( api_endpoint=api_endpoint, timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), ) as client: - # return "https://api.us-east-1.dev.firebolt.io/dynamic" url = GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name=account_name) response = await client.get(url=url) if response.status_code == codes.NOT_FOUND: diff --git a/src/firebolt/utils/util.py b/src/firebolt/utils/util.py index 4d3457bb48d..c7b7533b847 100644 --- a/src/firebolt/utils/util.py +++ b/src/firebolt/utils/util.py @@ -1,12 +1,8 @@ -from contextlib import contextmanager -from functools import lru_cache, partial, wraps -from typing import TYPE_CHECKING, Any, Callable, Generator, Type, TypeVar +from functools import lru_cache +from typing import TYPE_CHECKING, Callable, Type, TypeVar -import trio from httpx import URL -from firebolt.utils.exception import ConfigurationError - T = TypeVar("T") @@ -87,56 +83,6 @@ def get_auth_endpoint(api_endpoint: URL) -> URL: ) -@contextmanager -def nested_loop() -> Generator: - from trio._core._run import GLOBAL_RUN_CONTEXT # type: ignore - - s = object() - task, runner, _dict = s, s, s - if hasattr(GLOBAL_RUN_CONTEXT, "__dict__"): - _dict = GLOBAL_RUN_CONTEXT.__dict__ - if hasattr(GLOBAL_RUN_CONTEXT, "task"): - task = GLOBAL_RUN_CONTEXT.task - del GLOBAL_RUN_CONTEXT.task - if hasattr(GLOBAL_RUN_CONTEXT, "runner"): - runner = GLOBAL_RUN_CONTEXT.runner - del GLOBAL_RUN_CONTEXT.runner - - try: - yield - finally: - if task is not s: - GLOBAL_RUN_CONTEXT.task = task - elif hasattr(GLOBAL_RUN_CONTEXT, "task"): - del GLOBAL_RUN_CONTEXT.task - - if runner is not s: - GLOBAL_RUN_CONTEXT.runner = runner - elif hasattr(GLOBAL_RUN_CONTEXT, "runner"): - del GLOBAL_RUN_CONTEXT.runner - - if _dict is not s: - GLOBAL_RUN_CONTEXT.__dict__.update(_dict) - - -def async_to_sync(f: Callable) -> Callable: - """Convert async function to sync. - - Args: - f (Callable): function to convert - - Returns: - Callable: regular function, which can be executed synchronously - """ - - @wraps(f) - def sync(*args: Any, **kwargs: Any) -> Any: - with nested_loop(): - return trio.run(partial(f, *args, **kwargs)) - - return sync - - def merge_urls(base: URL, merge: URL) -> URL: """Merge a base and merge urls. diff --git a/tests/integration/dbapi/async/test_errors_async.py b/tests/integration/dbapi/async/test_errors_async.py index 6af46dc14df..127fd327532 100644 --- a/tests/integration/dbapi/async/test_errors_async.py +++ b/tests/integration/dbapi/async/test_errors_async.py @@ -1,4 +1,4 @@ -from pytest import mark, raises +from pytest import raises from firebolt.async_db import Connection, connect from firebolt.client.auth import ClientCredentials @@ -72,7 +72,6 @@ async def test_engine_stopped( await connection.cursor().execute("show tables") -@mark.skip(reason="Behaviour is different in prod vs dev") async def test_database_not_exists( engine_name: str, database_name: str, diff --git a/tests/integration/dbapi/async/test_system_engine.py b/tests/integration/dbapi/async/test_system_engine.py index 5be98a0ece5..f04f27d04c3 100644 --- a/tests/integration/dbapi/async/test_system_engine.py +++ b/tests/integration/dbapi/async/test_system_engine.py @@ -87,6 +87,15 @@ async def test_system_engine( "Invalid data returned by fetchmany", ) + if connection_system_engine.database: + await c.execute("show tables") + with raises(OperationalError): + await c.execute("create table test(id int) primary index id") + else: + await c.execute("show databases") + with raises(OperationalError): + await c.execute("show tables") + async def test_system_engine_no_db( connection_system_engine_no_db: Connection, diff --git a/tests/integration/dbapi/sync/test_system_engine.py b/tests/integration/dbapi/sync/test_system_engine.py index 8deea187e35..58b4ec10418 100644 --- a/tests/integration/dbapi/sync/test_system_engine.py +++ b/tests/integration/dbapi/sync/test_system_engine.py @@ -82,6 +82,15 @@ def test_system_engine( "Invalid data returned by fetchmany", ) + if connection_system_engine.database: + c.execute("show tables") + with raises(OperationalError): + c.execute("create table test(id int) primary index id") + else: + c.execute("show databases") + with raises(OperationalError): + c.execute("show tables") + def test_system_engine_no_db( connection_system_engine_no_db: Connection, diff --git a/tests/unit/common/test_util.py b/tests/unit/common/test_util.py deleted file mode 100644 index 1b02b7b3dc9..00000000000 --- a/tests/unit/common/test_util.py +++ /dev/null @@ -1,68 +0,0 @@ -from asyncio import run -from threading import Thread - -from pytest import raises - -from firebolt.utils.util import async_to_sync - - -def test_async_to_sync_happy_path(): - """async_to_sync properly converts coroutine to sync function.""" - - class JobMarker(Exception): - pass - - async def task(): - raise JobMarker() - - for i in range(3): - with raises(JobMarker): - async_to_sync(task)() - - -def test_async_to_sync_thread(): - """async_to_sync works correctly in threads.""" - - marks = [False] * 3 - - async def task(id: int): - marks[id] = True - - ts = [Thread(target=async_to_sync(task), args=[i]) for i in range(3)] - [t.start() for t in ts] - [t.join() for t in ts] - assert all(marks) - - -def test_async_to_sync_after_run(): - """async_to_sync runs correctly after asyncio.run.""" - - class JobMarker(Exception): - pass - - async def task(): - raise JobMarker() - - with raises(JobMarker): - run(task()) - - # Here local event loop is closed by run - - with raises(JobMarker): - async_to_sync(task)() - - -async def test_nested_loops() -> None: - """async_to_sync works correctly inside a running loop.""" - - class JobMarker(Exception): - pass - - async def task(): - raise JobMarker() - - with raises(JobMarker): - await task() - - with raises(JobMarker): - async_to_sync(task)() From 0352ae111001e6262848af5d6991bffb906cbddf Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 14 Jun 2023 15:26:18 +0300 Subject: [PATCH 26/28] update pytest --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index dffdee12641..f8697cccc23 100755 --- a/setup.cfg +++ b/setup.cfg @@ -51,7 +51,7 @@ dev = mypy==0.910 pre-commit==2.15.0 pyfakefs>=4.5.3 - pytest==6.2.5 + pytest==7.2.0 pytest-cov==3.0.0 pytest-httpx==0.22.0 pytest-mock==3.6.1 From 6c57e8e48fcdfd310ad53c04b2a7817c10bbe2d2 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 14 Jun 2023 15:33:52 +0300 Subject: [PATCH 27/28] fix code checks --- tests/unit/common/test_settings.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/unit/common/test_settings.py b/tests/unit/common/test_settings.py index 508afda9de7..084046ff2bf 100644 --- a/tests/unit/common/test_settings.py +++ b/tests/unit/common/test_settings.py @@ -1,9 +1,6 @@ import os -from typing import Tuple from unittest.mock import Mock, patch -from pytest import mark, raises - from firebolt.client.auth import Auth from firebolt.common.settings import Settings @@ -36,4 +33,6 @@ def test_no_deprecation_warning_with_env(logger_mock: Mock): assert s.server == "dummy.firebolt.io" assert s.auth is not None, "Settings.auth wasn't populated from env variables" assert s.auth.client_id == "client_id", "Invalid username in Settings.auth" - assert s.auth.client_secret == "client_secret", "Invalid password in Settings.auth" + assert ( + s.auth.client_secret == "client_secret" + ), "Invalid password in Settings.auth" From 0239c0afe06290a34bcabe7f9ba5421cee953738 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 14 Jun 2023 15:35:07 +0300 Subject: [PATCH 28/28] add missing requirement --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index f8697cccc23..fee2c304831 100755 --- a/setup.cfg +++ b/setup.cfg @@ -31,6 +31,7 @@ install_requires = httpx[http2]==0.24.0 pydantic[dotenv]>=1.8.2 python-dateutil>=2.8.2 + readerwriterlock>=1.0.9 sqlparse>=0.4.2 tricycle>=0.2.2 trio>=0.22.0